mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
### What problem does this PR solve?

S3-family connector syncs currently re-download every in-window object just so we can compute `xxhash128(blob)` and compare against `Document.content_hash`. Anything that bumps `LastModified` without changing bytes (`aws s3 cp` touches, bucket re-encryption, etc.) pays full bandwidth and re-parses files that didn't actually change. #14628 covers the broader incremental-ingestion redesign; this PR is the first slice.

The fix is a pre-listing short-circuit. `BlobStorageConnector` (S3 / R2 / GCS / OCI / S3-compat) now implements a new `FingerprintConnector` interface: `list_keys()` paginates `list_objects_v2` and yields `KeyRecord(key, fingerprint)` where `fingerprint = xxhash128(ETag)`. The orchestrator joins those against the connector's existing `{doc_id: content_hash}` map and only calls `get_value(key)` when the fingerprint differs. Unchanged keys are skipped entirely: no `GetObject`, no re-parse.

No DDL. `xxhash128(ETag)` is 32 hex chars and reuses the existing `Document.content_hash` column per @yingfeng's suggestion; the connector decides at listing time whether to populate it. Local uploads and connectors that don't opt in fall through to the existing post-download `xxhash128(blob)` path with no behavior change.

This is PR-1 of a 4-PR series; the full design lives on #14628. Subsequent PRs extend tier 1 to local FS / WebDAV / Dropbox / Seafile / RDBMS (PR-2), wire up tier 2 cursor connectors with `SyncLogs.next_checkpoint` (PR-3), and unify deletion via `KeyRecord(deleted=True)` reconciliation (PR-4). Holding those back keeps this PR additive and reviewable on its own.
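The short-circuit described above can be sketched roughly as follows. This is a minimal, stdlib-only illustration and not the PR's code: `fingerprint_of`, `plan_sync`, and the sample keys are hypothetical names, the stored map is keyed by the raw key for brevity (the PR keys it by document id), and `hashlib.blake2b` with a 16-byte digest stands in for `xxhash128` so the sketch runs without third-party packages.

```python
import hashlib
from typing import Iterable, NamedTuple, Optional


class KeyRecord(NamedTuple):
    # Simplified stand-in for the PR's KeyRecord model.
    key: str
    fingerprint: Optional[str]


def fingerprint_of(raw_etag: Optional[str]) -> Optional[str]:
    """Hash a listing ETag into a 32-hex-char token (blake2b stands in for xxhash128)."""
    if not raw_etag:
        return None
    return hashlib.blake2b(raw_etag.strip('"').encode(), digest_size=16).hexdigest()


def plan_sync(listing: Iterable[KeyRecord], stored: dict) -> tuple:
    """Split listed keys into (to_fetch, bypassed) by comparing fingerprints."""
    to_fetch, bypassed = [], []
    for record in listing:
        known = stored.get(record.key, "")
        if record.fingerprint and known and record.fingerprint == known:
            bypassed.append(record.key)  # unchanged: no download, no re-parse
        else:
            to_fetch.append(record.key)  # new, changed, or no usable fingerprint
    return to_fetch, bypassed


listing = [
    KeyRecord("a.txt", fingerprint_of('"etag-a"')),
    KeyRecord("b.txt", fingerprint_of('"etag-b-v2"')),
    KeyRecord("c.txt", fingerprint_of('"etag-c"')),
]
# Persisted state from an earlier sync: a.txt unchanged, b.txt stale, c.txt unseen.
stored = {"a.txt": fingerprint_of("etag-a"), "b.txt": fingerprint_of("etag-b-v1")}
to_fetch, bypassed = plan_sync(listing, stored)
print("fetch:", to_fetch, "bypass:", bypassed)
```

Only `b.txt` (changed ETag) and `c.txt` (never seen) end up in the fetch list; `a.txt` is bypassed because its listing fingerprint equals the stored value.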
#### Files touched

- `common/data_source/models.py`: new `KeyRecord`; optional `fingerprint` on `Document`
- `common/data_source/interfaces.py`: `IncrementalCapability` enum, `FingerprintConnector` ABC
- `common/data_source/blob_connector.py`: `BlobStorageConnector` implements `FingerprintConnector`; per-object download factored into `_build_document_from_obj()` so `_yield_blob_objects`, `list_keys`, `get_value` all share it
- `rag/svr/sync_data_source.py`: `_BlobLikeBase._fingerprint_filtered_generator` does the bypass loop; `_run_task_logic` plumbs `doc.fingerprint` into the upload dict
- `api/db/services/document_service.py`: `list_id_content_hash_map_by_kb_and_source_type()` helper
- `api/db/services/connector_service.py` + `file_service.py`: fingerprint flows through `duplicate_and_parse → upload_document` and lands in `content_hash`
- `test/unit_test/common/test_blob_connector_fingerprint.py`: 14 tests covering ETag normalization (single-part, multipart, quoted, empty), `list_keys()` not calling `GetObject`, `get_value()` materializing with fingerprint, deterministic/stable fingerprints, and the bypass loop asserting `GetObject` is *not* called on a match

#### Worth flagging for review

Old `_BlobLikeBase._generate` called `poll_source(start, now)` with a `LastModified` window when `poll_range_start` was set. New code uses `_fingerprint_filtered_generator` (full bucket listing + fingerprint compare) outside of explicit `reindex=1`. Strictly better for unchanged-bucket cases since it skips `GetObject`, but it does mean every sync now paginates the full `list_objects_v2` listing. Should still be cheap for most buckets; flagging in case anyone has a very large bucket where the time-window filter was meaningful.

On migration: existing rows have `content_hash = xxhash128(blob)` from the old code. The first sync after this lands sees ETag-derived fingerprints that don't match, re-fetches every object once, and writes the new fingerprint. From the second sync onward the bypass works as expected. "Slow day one, fast every day after." A `fingerprint_backfill: trust` opt-out is sketched in the design doc but not in this PR.

#### Test plan

- [x] `uv run ruff check`: clean on all 8 touched files
- [x] `uv run pytest test/unit_test/common/test_blob_connector_fingerprint.py -v`: 14 passed
- [x] Broader unit-test suite: no regressions in anything I touched
- [ ] Manual smoke against a real S3 bucket: configure a connector, run sync twice, expect the second sync to log `bypassed=N, fetched=0` and no `GetObject` calls in CloudTrail / bucket access logs
- [ ] Manual smoke with `reindex=1`: confirm the full re-download path still works

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
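The migration behaviour flagged for review ("slow day one, fast every day after") can be illustrated with a small simulation. It is only a sketch under stated assumptions: `h128` and `run_sync` are hypothetical helpers, the bucket is an in-memory dict, and `hashlib.blake2b` again stands in for `xxhash128`.

```python
import hashlib


def h128(data: bytes) -> str:
    # Stand-in for xxhash128: any 128-bit digest rendered as 32 hex chars.
    return hashlib.blake2b(data, digest_size=16).hexdigest()


def run_sync(bucket: dict, stored: dict) -> tuple:
    """One sync pass over {key: (etag, body)}; returns (bypassed, fetched)."""
    bypassed = fetched = 0
    for key, (etag, _body) in bucket.items():
        fingerprint = h128(etag.encode())
        if stored.get(key) == fingerprint:
            bypassed += 1
            continue
        fetched += 1               # the real code would download and re-parse here
        stored[key] = fingerprint  # persist the ETag-derived fingerprint
    return bypassed, fetched


bucket = {"a.txt": ("etag-a", b"alpha"), "b.txt": ("etag-b", b"beta")}
# Rows written by the old code hold a hash of the body, not of the ETag.
stored = {key: h128(body) for key, (_etag, body) in bucket.items()}

first = run_sync(bucket, stored)   # nothing matches yet: every object is fetched once
second = run_sync(bucket, stored)  # fingerprints now match: everything is bypassed
print("first sync:", first, "second sync:", second)
```

The first pass fetches both objects because the stored body hashes never equal the ETag-derived values; the second pass bypasses both.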
@@ -16,7 +16,7 @@
 import logging
 from datetime import datetime
 import os
-from typing import Tuple, List
+from typing import Optional, Tuple, List

 from anthropic import BaseModel
 from peewee import SQL, fn
@@ -276,12 +276,13 @@ class SyncLogsService(CommonService):
             id: str
             filename: str
             blob: bytes
+            fingerprint: Optional[str] = None

             def read(self) -> bytes:
                 return self.blob

         errs = []
-        files = [FileObj(id=d["id"], filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"]) for d in docs]
+        files = [FileObj(id=d["id"], filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"], fingerprint=d.get("fingerprint")) for d in docs]
         doc_ids = []
         err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src)
         errs.extend(err)
@@ -388,6 +388,35 @@ class DocumentService(CommonService):
             offset += page_size
         return res

+    @classmethod
+    @DB.connection_context()
+    def list_id_content_hash_map_by_kb_and_source_type(cls, kb_id, source_type, page_size=500):
+        """Return {doc_id: content_hash} for the connector's existing docs.
+
+        Used by the fingerprint-bypass path to decide which keys can skip a
+        re-fetch -- if the connector's listing fingerprint equals content_hash,
+        the body hasn't changed since the last sync.
+
+        Ordered by create_time so LIMIT/OFFSET pagination is stable under
+        concurrent writes; without this, page boundaries can drop or duplicate
+        rows and the resulting map would silently miss entries.
+        """
+        fields = [cls.model.id, cls.model.content_hash]
+        docs = cls.model.select(*fields).where(
+            cls.model.kb_id == kb_id,
+            cls.model.source_type == source_type,
+        ).order_by(cls.model.create_time.asc())
+        offset = 0
+        result: dict[str, str] = {}
+        while True:
+            batch = list(docs.offset(offset).limit(page_size).dicts())
+            if not batch:
+                break
+            for row in batch:
+                result[row["id"]] = row.get("content_hash") or ""
+            offset += page_size
+        return result
+
     @classmethod
     @DB.connection_context()
     def get_all_docs_by_creator_id(cls, creator_id):
@@ -482,7 +482,12 @@ class FileService(CommonService):
                         err.append(file.filename + ": " + user_msg)
                         continue
                     blob = file.read()
-                    new_hash = xxhash.xxh128(blob).hexdigest()
+                    # Connector-supplied fingerprint (e.g. xxhash128(S3 ETag))
+                    # takes precedence: for connector-sourced docs the bypass
+                    # path uses the fingerprint as content_hash, so reverting
+                    # to xxhash128(blob) here would defeat it.
+                    incoming_fp = getattr(file, "fingerprint", None)
+                    new_hash = incoming_fp or xxhash.xxh128(blob).hexdigest()
                     old_hash = doc.content_hash or ""
                     settings.STORAGE_IMPL.put(kb.id, doc.location, blob, kb.tenant_id)
                     doc.size = len(blob)
@@ -518,6 +523,7 @@ class FileService(CommonService):
                     thumbnail_location = f"thumbnail_{doc_id}.png"
                     settings.STORAGE_IMPL.put(kb.id, thumbnail_location, img)

+                incoming_fp = getattr(file, "fingerprint", None)
                 doc = {
                     "id": doc_id,
                     "kb_id": kb.id,
@@ -532,7 +538,7 @@ class FileService(CommonService):
                     "location": location,
                     "size": len(blob),
                     "thumbnail": thumbnail_location,
-                    "content_hash": xxhash.xxh128(blob).hexdigest(),
+                    "content_hash": incoming_fp or xxhash.xxh128(blob).hexdigest(),
                 }
                 DocumentService.insert(doc)

@@ -1,9 +1,12 @@
 """Blob storage connector"""
 import logging
 import os
+from collections.abc import Iterator
 from datetime import datetime, timezone
 from typing import Any, Optional

+import xxhash
+
 from common.data_source.utils import (
     create_s3_client,
     detect_bucket_region,
@@ -18,9 +21,14 @@ from common.data_source.exceptions import (
     CredentialExpiredError,
     InsufficientPermissionsError
 )
-from common.data_source.interfaces import LoadConnector, PollConnector
+from common.data_source.interfaces import (
+    FingerprintConnector,
+    LoadConnector,
+    PollConnector,
+)
 from common.data_source.models import (
     Document,
+    KeyRecord,
     SecondsSinceUnixEpoch,
     GenerateDocumentsOutput,
     GenerateSlimDocumentOutput,
@@ -28,7 +36,20 @@ from common.data_source.models import (
 )


-class BlobStorageConnector(LoadConnector, PollConnector):
+def _normalize_etag(raw_etag: Optional[str]) -> Optional[str]:
+    """Return a 32-char hex fingerprint derived from an S3 ETag.
+
+    S3 ETags are MD5 (32 hex chars) for single-part uploads and "<md5>-<n>"
+    (34+ chars) for multipart. We always hash so the column format is uniform
+    regardless of upload type or provider quirks; equality of the hashed value
+    is sufficient for change detection.
+    """
+    if not raw_etag:
+        return None
+    return xxhash.xxh128(raw_etag.strip('"').encode()).hexdigest()
+
+
+class BlobStorageConnector(LoadConnector, PollConnector, FingerprintConnector):
     """Blob storage connector"""

     def __init__(
@@ -48,6 +69,11 @@ class BlobStorageConnector(LoadConnector, PollConnector):
         self.size_threshold: int | None = BLOB_STORAGE_SIZE_THRESHOLD
         self.bucket_region: Optional[str] = None
         self.european_residency: bool = european_residency
+        # Populated by list_keys() so a subsequent get_value(key) can find the
+        # raw S3 object metadata (LastModified, ETag, Key, Size) without a second
+        # head_object call. Lifetime is one list_keys() pass.
+        self._listing_cache: dict[str, dict[str, Any]] = {}
+        self._filename_counts: dict[str, int] = {}

     def set_allow_images(self, allow_images: bool) -> None:
         """Set whether to process images"""
@@ -122,6 +148,44 @@ class BlobStorageConnector(LoadConnector, PollConnector):

         return None

+    def _build_document_from_obj(
+        self,
+        obj: dict[str, Any],
+        filename_counts: dict[str, int],
+    ) -> Optional[Document]:
+        """Materialize a Document for one S3 object, downloading its body."""
+        key = obj["Key"]
+        file_name = os.path.basename(key)
+        last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
+
+        size_bytes = extract_size_bytes(obj)
+        if (
+            self.size_threshold is not None
+            and isinstance(size_bytes, int)
+            and size_bytes > self.size_threshold
+        ):
+            logging.warning(
+                f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
+            )
+            return None
+
+        blob = download_object(
+            self.s3_client, self.bucket_name, key, self.size_threshold
+        )
+        if blob is None:
+            return None
+
+        return Document(
+            id=f"{self.bucket_type}:{self.bucket_name}:{key}",
+            blob=blob,
+            source=DocumentSource(self.bucket_type.value),
+            semantic_identifier=self._get_semantic_id(key, file_name, filename_counts),
+            extension=get_file_ext(file_name),
+            doc_updated_at=last_modified,
+            size_bytes=size_bytes if size_bytes else 0,
+            fingerprint=_normalize_etag(obj.get("ETag")),
+        )
+
     def _yield_blob_objects(
         self,
         start: datetime,
@@ -132,51 +196,64 @@ class BlobStorageConnector(LoadConnector, PollConnector):

         batch: list[Document] = []
         for obj in all_objects:
-            last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
-            file_name = os.path.basename(obj["Key"])
-            key = obj["Key"]
-
-            size_bytes = extract_size_bytes(obj)
-            if (
-                self.size_threshold is not None
-                and isinstance(size_bytes, int)
-                and size_bytes > self.size_threshold
-            ):
-                logging.warning(
-                    f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
-                )
-                continue
-
             try:
-                blob = download_object(
-                    self.s3_client, self.bucket_name, key, self.size_threshold
-                )
-                if blob is None:
+                doc = self._build_document_from_obj(obj, filename_counts)
+                if doc is None:
                     continue
-
-                semantic_id = self._get_semantic_id(key, file_name, filename_counts)
-
-                batch.append(
-                    Document(
-                        id=f"{self.bucket_type}:{self.bucket_name}:{key}",
-                        blob=blob,
-                        source=DocumentSource(self.bucket_type.value),
-                        semantic_identifier=semantic_id,
-                        extension=get_file_ext(file_name),
-                        doc_updated_at=last_modified,
-                        size_bytes=size_bytes if size_bytes else 0,
-                    )
-                )
+                batch.append(doc)
                 if len(batch) == self.batch_size:
                     yield batch
                     batch = []
-
             except Exception:
-                logging.exception(f"Error decoding object {key}")
+                logging.exception(f"Error decoding object {obj.get('Key')}")

         if batch:
             yield batch

+    def list_keys(self) -> Iterator[KeyRecord]:
+        """Enumerate the full bucket keyspace with per-object fingerprints.
+
+        Cheap path: relies on list_objects_v2 which returns ETag in the listing,
+        so no GetObject call is needed. Caches each object's metadata so a
+        subsequent get_value(key) call can rebuild the Document without a second
+        round-trip to S3.
+        """
+        if self.s3_client is None:
+            raise ConnectorMissingCredentialError("Blob storage")
+
+        all_objects, filename_counts = self._collect_blob_objects(
+            start=datetime(1970, 1, 1, tzinfo=timezone.utc),
+            end=datetime.now(timezone.utc),
+        )
+        self._filename_counts = filename_counts
+        self._listing_cache = {}
+
+        for obj in all_objects:
+            doc_id = f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}"
+            self._listing_cache[doc_id] = obj
+            yield KeyRecord(
+                key=doc_id,
+                fingerprint=_normalize_etag(obj.get("ETag")),
+            )
+
+    def get_value(self, key: str) -> Document:
+        """Materialize the Document for a key previously yielded by list_keys().
+
+        Must be called within the same list_keys() pass that produced the key,
+        since the metadata cache lives on the connector instance and is reset
+        each list_keys() call.
+        """
+        obj = self._listing_cache.get(key)
+        if obj is None:
+            raise KeyError(
+                f"get_value({key!r}) called before list_keys() yielded the key, "
+                "or after a subsequent list_keys() reset the cache"
+            )
+        doc = self._build_document_from_obj(obj, self._filename_counts)
+        if doc is None:
+            raise RuntimeError(f"Failed to materialize Document for key {key!r}")
+        return doc
+
     def _collect_blob_objects(
         self,
         start: datetime,
@@ -2,7 +2,7 @@
 import abc
 import uuid
 from abc import ABC, abstractmethod
-from enum import IntFlag, auto
+from enum import IntEnum, IntFlag, auto
 from types import TracebackType
 from typing import Any, Dict, Generator, TypeVar, Generic, Callable, TypeAlias
 from collections.abc import Iterator
@@ -10,12 +10,26 @@ from anthropic import BaseModel

 from common.data_source.models import (
     Document,
+    KeyRecord,
     SlimDocument,
     ConnectorCheckpoint,
     ConnectorFailure,
     SecondsSinceUnixEpoch, GenerateSlimDocumentOutput
 )

+
+class IncrementalCapability(IntEnum):
+    """How a connector handles incremental sync.
+
+    FULL_RESYNC -- every sync re-pulls; no per-key state.
+    CURSOR -- "give me everything since cursor X"; opaque cursor persisted across syncs.
+    FINGERPRINT -- list_keys() returns (key, fingerprint) cheaply; bodies fetched lazily.
+    """
+    FULL_RESYNC = 0
+    CURSOR = 1
+    FINGERPRINT = 2
+
+
 GenerateDocumentsOutput = Iterator[list[Document]]

 class LoadConnector(ABC):
@@ -415,3 +429,39 @@ class IndexingHeartbeatInterface(ABC):
         just to act as a keep-alive.
         """

+
+class FingerprintConnector(ABC):
+    """Tier 1 connector: cheap full listing with per-key fingerprint.
+
+    Sources that can enumerate their entire keyspace via a metadata-only call
+    (e.g. S3 list_objects_v2 returning ETag + LastModified) implement this to
+    let the orchestrator skip GetObject for keys whose fingerprint hasn't
+    changed since the last sync.
+
+    The fingerprint is an opaque equality token: two equal fingerprints mean
+    the content is unchanged from the orchestrator's point of view. Format is
+    a 32-char hex string so it fits the existing Document.content_hash column;
+    connectors are responsible for normalizing whatever the source exposes
+    (typically by hashing it with xxhash128).
+    """
+
+    INCREMENTAL_CAPABILITY: IncrementalCapability = IncrementalCapability.FINGERPRINT
+
+    @abstractmethod
+    def list_keys(self) -> Iterator[KeyRecord]:
+        """Yield one KeyRecord per object currently in the source.
+
+        Must enumerate the full current keyspace -- the orchestrator diffs the
+        result against persisted state to detect adds, updates, and deletes.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_value(self, key: str) -> Document:
+        """Fetch the body for a single key, returning a fully populated Document.
+
+        Called only when list_keys()'s fingerprint differs from the persisted
+        content_hash for that key (or when no persisted fingerprint exists).
+        """
+        raise NotImplementedError
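Review note: the two-method contract in the hunk above (list cheaply, fetch lazily) can be exercised with a toy source. This is a hedged sketch, not code from the PR: `FingerprintSource` and `InMemorySource` are hypothetical names, `get_value` returns raw bytes instead of a `Document`, and `hashlib.blake2b` stands in for `xxhash128`.

```python
import hashlib
from abc import ABC, abstractmethod
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Optional


@dataclass
class KeyRecord:
    # Simplified stand-in for the PR's pydantic-style model.
    key: str
    fingerprint: Optional[str] = None
    deleted: bool = False


class FingerprintSource(ABC):
    """Hypothetical mirror of the two-method contract."""

    @abstractmethod
    def list_keys(self) -> Iterator[KeyRecord]: ...

    @abstractmethod
    def get_value(self, key: str) -> bytes: ...


class InMemorySource(FingerprintSource):
    def __init__(self, objects: dict) -> None:
        self._objects = objects  # {key: (version_tag, body)}
        self.fetches: list = []  # records every body fetch

    def list_keys(self) -> Iterator[KeyRecord]:
        for key, (tag, _body) in self._objects.items():
            # Metadata only: the body is never touched while listing.
            digest = hashlib.blake2b(tag.encode(), digest_size=16).hexdigest()
            yield KeyRecord(key=key, fingerprint=digest)

    def get_value(self, key: str) -> bytes:
        self.fetches.append(key)
        return self._objects[key][1]


source = InMemorySource({"x": ("v1", b"hello"), "y": ("v7", b"world")})
records = list(source.list_keys())
fetches_after_listing = list(source.fetches)  # still empty: listing fetched nothing
body = source.get_value("x")
```

Listing leaves `fetches` empty; only the explicit `get_value("x")` call records a fetch, which is the property the bypass loop relies on.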
@@ -99,6 +99,25 @@ class Document(BaseModel):
     primary_owners: Optional[list] = None
     metadata: Optional[dict[str, Any]] = None
     doc_metadata: Optional[dict[str, Any]] = None
+    # Opaque, connector-supplied fingerprint stored in Document.content_hash for
+    # change-detection. 32-char hex string; format is per-source (xxhash128 of
+    # bytes for local uploads, xxhash128(ETag) for blob storage, etc.). When set
+    # on a yielded Document, the orchestrator persists it as content_hash and
+    # skips the post-download xxhash128(blob) recomputation.
+    fingerprint: Optional[str] = None
+
+
+class KeyRecord(BaseModel):
+    """One entry returned by a FingerprintConnector.list_keys() call.
+
+    A KeyRecord is the cheap-listing primitive: connector enumerates all keys
+    it has, attaches a fingerprint when the source exposes one, and the
+    orchestrator only fetches content when the fingerprint differs from what's
+    persisted.
+    """
+    key: str
+    fingerprint: Optional[str] = None
+    deleted: bool = False


 class BasicExpertInfo(BaseModel):
@@ -213,6 +213,8 @@ class SyncBase:
                     }
                     if doc.metadata:
                         d["metadata"] = doc.metadata
+                    if getattr(doc, "fingerprint", None):
+                        d["fingerprint"] = doc.fingerprint
                     docs.append(d)

                 try:
@@ -301,6 +303,81 @@ class SyncBase:
 class _BlobLikeBase(SyncBase):
     DEFAULT_BUCKET_TYPE: str = "s3"

+    def _fingerprint_filtered_generator(self, task: dict):
+        """Generator that uses list_keys() + get_value() to skip unchanged objects.
+
+        Pre-loads {doc_id: content_hash} for the connector's existing docs in
+        this KB, iterates the bucket via list_keys(), and only materializes a
+        Document (one GetObject call) when the listing fingerprint differs from
+        the persisted content_hash. Unchanged objects are skipped entirely --
+        no download, no re-parse.
+
+        Per-key fetch failures are counted and surfaced via SyncLogsService so
+        a partially failing sync (e.g. throttling, IAM regression mid-run)
+        doesn't silently report DONE while half the bucket is unreachable.
+        Connectors yielding KeyRecord(deleted=True) are skipped here -- actual
+        deletion reconciliation lives in the unified delete pass (PR-4).
+        """
+        source_type = f"{self.SOURCE_NAME}/{task['connector_id']}"
+        existing_fingerprints = DocumentService.list_id_content_hash_map_by_kb_and_source_type(
+            task["kb_id"], source_type,
+        )
+
+        bypass_count = 0
+        fetch_count = 0
+        fail_count = 0
+        batch = []
+        for key_record in self.connector.list_keys():
+            if key_record.deleted:
+                continue
+
+            doc_id = hash128(key_record.key)
+            stored = existing_fingerprints.get(doc_id, "")
+            if key_record.fingerprint and stored and key_record.fingerprint == stored:
+                bypass_count += 1
+                continue
+
+            try:
+                doc = self.connector.get_value(key_record.key)
+            except Exception as ex:
+                fail_count += 1
+                logging.exception(
+                    "Failed to fetch %s from %s: %s",
+                    key_record.key,
+                    self.SOURCE_NAME,
+                    ex,
+                )
+                continue
+
+            fetch_count += 1
+            batch.append(doc)
+            if len(batch) >= self.connector.batch_size:
+                yield batch
+                batch = []
+
+        if batch:
+            yield batch
+
+        log_msg = (
+            "[%s] fingerprint sync: %d bypassed, %d fetched, %d failed "
+            "(connector_id=%s, kb_id=%s)"
+        )
+        log_args = (
+            self.SOURCE_NAME,
+            bypass_count,
+            fetch_count,
+            fail_count,
+            task["connector_id"],
+            task["kb_id"],
+        )
+        # Use WARNING when any fetch failed so partial-bucket regressions
+        # (auth, throttling, IAM drift) surface without diving into the
+        # per-exception traces above.
+        if fail_count:
+            logging.warning(log_msg, *log_args)
+        else:
+            logging.info(log_msg, *log_args)
+
     async def _generate(self, task: dict):
         bucket_type = self.conf.get("bucket_type", self.DEFAULT_BUCKET_TYPE)

@@ -313,14 +390,13 @@ class _BlobLikeBase(SyncBase):
         self.connector.load_credentials(self.conf["credentials"])

         file_list = None
-        document_batch_generator = (
-            self.connector.load_from_state()
-            if task["reindex"] == "1" or not task["poll_range_start"]
-            else self.connector.poll_source(
-                task["poll_range_start"].timestamp(),
-                datetime.now(timezone.utc).timestamp(),
-            )
-        )
+        # Fingerprint-bypass path: skip GetObject for unchanged ETags. Disabled
+        # on full reindex (we want to re-fetch everything in that case).
+        use_fingerprint_path = task["reindex"] != "1"
+        if use_fingerprint_path:
+            document_batch_generator = self._fingerprint_filtered_generator(task)
+        else:
+            document_batch_generator = self.connector.load_from_state()

         if (
             task["reindex"] != "1"
@@ -332,9 +408,9 @@ class _BlobLikeBase(SyncBase):
                 file_list.extend(slim_batch)

         _begin_info = (
-            "totally"
-            if task["reindex"] == "1" or not task["poll_range_start"]
-            else "from {}".format(task["poll_range_start"])
+            "fingerprint-bypass"
+            if use_fingerprint_path
+            else "full reindex"
         )

         logging.info(

347
test/unit_test/common/test_blob_connector_fingerprint.py
Normal file
@@ -0,0 +1,347 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Tests for the FingerprintConnector bypass path in BlobStorageConnector."""
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
|
||||
import pytest
|
||||
import xxhash
|
||||
|
||||
|
||||
def _load_blob_connector_module():
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
package_name = "common.data_source"
|
||||
saved_modules = {name: module for name, module in sys.modules.items() if name == package_name or name.startswith(f"{package_name}.")}
|
||||
package_stub = ModuleType(package_name)
|
||||
package_stub.__path__ = [str(repo_root / "common" / "data_source")]
|
||||
sys.modules[package_name] = package_stub
|
||||
|
||||
try:
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"_blob_connector_under_test",
|
||||
repo_root / "common" / "data_source" / "blob_connector.py",
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
finally:
|
||||
for name in list(sys.modules):
|
||||
if name == package_name or name.startswith(f"{package_name}."):
|
||||
if name in saved_modules:
|
||||
sys.modules[name] = saved_modules[name]
|
||||
else:
|
||||
sys.modules.pop(name, None)
|
||||
|
||||
|
||||
blob_connector = _load_blob_connector_module()
|
||||
BlobStorageConnector = blob_connector.BlobStorageConnector
|
||||
_normalize_etag = blob_connector._normalize_etag
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake S3 client wired through a paginator-style interface.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakePaginator:
|
||||
def __init__(self, pages: list[dict]) -> None:
|
||||
self._pages = pages
|
||||
|
||||
def paginate(self, **_kwargs):
|
||||
for page in self._pages:
|
||||
yield page
|
||||
|
||||
|
||||
class _FakeS3Client:
|
||||
"""Captures every call on the connector's S3 client.
|
||||
|
||||
Tests assert against `get_object_calls` to verify that the fingerprint
|
||||
bypass actually skips downloads when ETags haven't changed.
|
||||
"""
|
||||
|
||||
def __init__(self, objects: list[dict]) -> None:
|
||||
self._objects = objects
|
||||
self.get_object_calls: list[tuple[str, str]] = []
|
||||
# Hand objects to the paginator unmodified so the connector exercises
|
||||
# its own directory-placeholder filtering logic.
|
||||
self._paginator = _FakePaginator([{"Contents": list(objects)}])
|
||||
|
||||
def get_paginator(self, name: str):
|
||||
assert name == "list_objects_v2"
|
||||
return self._paginator
|
||||
|
||||
def list_objects_v2(self, **_kwargs):
|
||||
return {"Contents": self._objects, "KeyCount": len(self._objects)}
|
||||
|
||||
def get_object(self, Bucket: str, Key: str): # noqa: N803 (boto3 API)
|
||||
self.get_object_calls.append((Bucket, Key))
|
||||
body_text = f"body-of-{Key}".encode()
|
||||
return {
|
||||
"Body": _FakeBody(body_text),
|
||||
"ContentLength": len(body_text),
|
||||
}
|
||||
|
||||
|
||||
class _FakeBody:
|
||||
"""Minimal stand-in for botocore's StreamingBody.
|
||||
|
||||
The real downloader (common.data_source.utils.download_object) consumes
|
||||
the body via iter_chunks() and then calls close(); fake out both.
|
||||
"""
|
||||
|
||||
def __init__(self, payload: bytes) -> None:
|
||||
self._payload = payload
|
||||
|
||||
def read(self) -> bytes:
|
||||
return self._payload
|
||||
|
||||
def iter_chunks(self, chunk_size: int = 65536):
|
||||
for i in range(0, len(self._payload), chunk_size):
|
||||
yield self._payload[i : i + chunk_size]
|
||||
|
||||
def close(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _make_connector(s3_client) -> BlobStorageConnector:
|
||||
connector = BlobStorageConnector(bucket_type="s3", bucket_name="test-bucket")
|
||||
connector.s3_client = s3_client
|
||||
return connector
|
||||
|
||||
|
||||
def _s3_object(key: str, etag: str, size: int = 12) -> dict:
|
||||
return {
|
||||
"Key": key,
|
||||
"ETag": f'"{etag}"',
|
||||
"LastModified": datetime(2026, 1, 1, 12, tzinfo=timezone.utc),
|
||||
"Size": size,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_normalize_etag_returns_32_char_hex_for_singlepart_etag():
|
||||
fp = _normalize_etag('"d41d8cd98f00b204e9800998ecf8427e"')
|
||||
assert fp is not None
|
||||
assert len(fp) == 32
|
||||
assert all(c in "0123456789abcdef" for c in fp)
|
||||
|
||||
|
||||
def test_normalize_etag_returns_32_char_hex_for_multipart_etag():
|
||||
"""Multipart ETags are 34+ chars; hashing normalizes them to 32."""
|
||||
fp = _normalize_etag('"d41d8cd98f00b204e9800998ecf8427e-7"')
|
||||
assert fp is not None
|
||||
assert len(fp) == 32
|
||||
|
||||
|
||||
def test_normalize_etag_is_deterministic():
|
||||
raw = '"abc123def456abc123def456abc123de"'
|
||||
assert _normalize_etag(raw) == _normalize_etag(raw)
|
||||
|
||||
|
||||
def test_normalize_etag_strips_quotes_so_quoted_and_unquoted_match():
|
||||
quoted = '"d41d8cd98f00b204e9800998ecf8427e"'
|
||||
unquoted = "d41d8cd98f00b204e9800998ecf8427e"
|
||||
assert _normalize_etag(quoted) == _normalize_etag(unquoted)
|
||||
|
||||
|
||||
def test_normalize_etag_returns_none_for_empty_input():
|
||||
assert _normalize_etag("") is None
|
||||
assert _normalize_etag(None) is None
|
||||
|
||||
|
||||
def test_list_keys_yields_one_keyrecord_per_object_with_fingerprint():
|
||||
s3 = _FakeS3Client(
|
||||
[
|
||||
_s3_object("foo.txt", "etag-foo"),
|
||||
_s3_object("bar/baz.txt", "etag-baz"),
|
||||
]
|
||||
)
|
||||
connector = _make_connector(s3)
|
||||
|
||||
records = list(connector.list_keys())
|
||||
|
||||
assert len(records) == 2
|
||||
assert {r.key for r in records} == {
|
||||
"BlobType.S3:test-bucket:foo.txt",
|
||||
"BlobType.S3:test-bucket:bar/baz.txt",
|
||||
}
|
||||
for record in records:
|
||||
assert record.fingerprint is not None
|
||||
assert len(record.fingerprint) == 32
|
||||
assert record.deleted is False
|
||||
|
||||
|
||||
def test_list_keys_does_not_call_get_object():
|
||||
"""list_keys() must be cheap -- no body downloads during enumeration."""
|
||||
s3 = _FakeS3Client([_s3_object("foo.txt", "etag-foo")])
|
||||
connector = _make_connector(s3)
|
||||
|
||||
list(connector.list_keys())
|
||||
|
||||
assert s3.get_object_calls == []
|
||||
|
||||
|
||||
def test_list_keys_skips_directory_placeholder_keys():
    """S3 'folders' are zero-byte keys ending in '/'; they shouldn't yield records."""
    s3 = _FakeS3Client(
        [
            _s3_object("real-file.txt", "etag-real"),
            _s3_object("folder/", "etag-folder"),
        ]
    )
    connector = _make_connector(s3)

    keys = [r.key for r in connector.list_keys()]

    assert keys == ["BlobType.S3:test-bucket:real-file.txt"]


def test_get_value_returns_document_with_fingerprint_set():
    s3 = _FakeS3Client([_s3_object("foo.txt", "etag-foo")])
    connector = _make_connector(s3)
    [record] = list(connector.list_keys())

    doc = connector.get_value(record.key)

    assert doc.id == "BlobType.S3:test-bucket:foo.txt"
    assert doc.fingerprint == record.fingerprint
    assert doc.fingerprint == xxhash.xxh128(b"etag-foo").hexdigest()


def test_get_value_calls_get_object_exactly_once_per_key():
    s3 = _FakeS3Client([_s3_object("foo.txt", "etag-foo")])
    connector = _make_connector(s3)
    [record] = list(connector.list_keys())

    connector.get_value(record.key)

    assert s3.get_object_calls == [("test-bucket", "foo.txt")]


def test_get_value_raises_keyerror_when_called_before_list_keys():
    s3 = _FakeS3Client([_s3_object("foo.txt", "etag-foo")])
    connector = _make_connector(s3)

    with pytest.raises(KeyError):
        connector.get_value("BlobType.S3:test-bucket:foo.txt")


def test_singlepart_and_multipart_etags_yield_different_fingerprints():
    """Sanity: distinct ETags must produce distinct fingerprints."""
    s3 = _FakeS3Client(
        [
            _s3_object("a.bin", "d41d8cd98f00b204e9800998ecf8427e"),
            _s3_object("b.bin", "d41d8cd98f00b204e9800998ecf8427e-3"),
        ]
    )
    connector = _make_connector(s3)

    records = list(connector.list_keys())

    assert records[0].fingerprint != records[1].fingerprint


def test_fingerprint_stable_across_repeated_listings():
    """Same ETag in two list_keys() calls yields the same fingerprint."""
    s3 = _FakeS3Client([_s3_object("foo.txt", "etag-stable")])
    connector = _make_connector(s3)

    fp_first = next(connector.list_keys()).fingerprint
    fp_second = next(connector.list_keys()).fingerprint

    assert fp_first == fp_second


# ---------------------------------------------------------------------------
# Bypass-logic test: simulates what the orchestrator does in
# _BlobLikeBase._fingerprint_filtered_generator. Verifies that a key whose
# fingerprint matches the persisted content_hash is NOT fetched.
# ---------------------------------------------------------------------------


def test_orchestrator_pattern_skips_get_object_when_fingerprint_matches():
    # Use distinct base names: "unchanged.txt".endswith("changed.txt") is True,
    # which would silently break endswith-based lookups in the test setup.
    s3 = _FakeS3Client(
        [
            _s3_object("static.txt", "etag-static"),
            _s3_object("modified.txt", "etag-modified"),
        ]
    )
    connector = _make_connector(s3)

    # Pre-compute the fingerprints the connector would emit, then pretend the
    # DB already stores the one for static.txt but a stale value for
    # modified.txt. This is the steady-state bypass scenario.
    listed = list(connector.list_keys())
    static_record = next(r for r in listed if r.key.endswith(":static.txt"))
    modified_record = next(r for r in listed if r.key.endswith(":modified.txt"))
    persisted = {
        static_record.key: static_record.fingerprint,
        modified_record.key: "stale-fingerprint",
    }

    # Reset the call log so we only count get_object during the bypass loop.
    s3.get_object_calls = []

    fetched = []
    for record in connector.list_keys():
        if record.fingerprint and persisted.get(record.key) == record.fingerprint:
            continue
        fetched.append(connector.get_value(record.key))

    assert [doc.id for doc in fetched] == ["BlobType.S3:test-bucket:modified.txt"]
    assert s3.get_object_calls == [("test-bucket", "modified.txt")]


def test_orchestrator_pattern_skips_deleted_records_without_calling_get_value():
    """KeyRecord(deleted=True) must short-circuit before get_value().

    Reach KeyRecord through the already-loaded blob_connector module to avoid
    triggering common.data_source.__init__'s circular imports.
    """
    KeyRecord = blob_connector.KeyRecord

    s3 = _FakeS3Client([_s3_object("foo.txt", "etag-foo")])
    connector = _make_connector(s3)

    # Manually feed a deleted KeyRecord through the bypass logic to assert the
    # short-circuit holds even when a connector emits one. (BlobStorageConnector
    # itself doesn't yield deleted records yet -- that's PR-4 -- but the
    # orchestrator must already be defensive.)
    deleted_record = KeyRecord(
        key="BlobType.S3:test-bucket:gone.txt",
        fingerprint=None,
        deleted=True,
    )

    # Mirror the orchestrator's loop body verbatim.
    fetched = []
    for record in [deleted_record]:
        if record.deleted:
            continue
        fetched.append(connector.get_value(record.key))

    assert fetched == []
    assert s3.get_object_calls == []
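
The two orchestrator-pattern tests above each exercise one skip condition in isolation. As a rough illustration of how the deleted check and the fingerprint check might sit together in one loop, here is a minimal sketch. It is not the body of `_BlobLikeBase._fingerprint_filtered_generator`, which is not shown in this file. `KeyRecord` below is a local stand-in dataclass whose field names follow the tests, and `iter_changed` is a hypothetical helper name introduced only for this example.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, Iterator, Optional


@dataclass(frozen=True)
class KeyRecord:
    """Local stand-in for the PR's KeyRecord (assumed fields: key, fingerprint, deleted)."""

    key: str
    fingerprint: Optional[str] = None
    deleted: bool = False


def iter_changed(
    records: Iterable[KeyRecord],
    persisted: Dict[str, str],
    get_value: Callable[[str], Any],
) -> Iterator[Any]:
    """Yield fetched values only for keys that look new or changed.

    Hypothetical helper: `persisted` plays the role of the
    {doc_id: content_hash} map, and `get_value` the connector's fetch call.
    """
    for record in records:
        # Deleted records never trigger a fetch.
        if record.deleted:
            continue
        # A non-empty fingerprint equal to the stored hash means "unchanged".
        if record.fingerprint and persisted.get(record.key) == record.fingerprint:
            continue
        # Missing fingerprint, unknown key, or mismatch: fetch the object.
        yield get_value(record.key)
```

Note that a record without a fingerprint is always fetched in this sketch, which matches the `record.fingerprint and ...` guard used in the test loop above.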