From 782084780ecf879fce23b8e75331d9e5eca3926d Mon Sep 17 00:00:00 2001 From: Hunnyboy1217 <110440428+hunnyboy1217@users.noreply.github.com> Date: Sat, 9 May 2026 05:03:56 -0700 Subject: [PATCH] feat(connectors): ETag-based bypass for incremental S3 ingestion (#14628) (#14677) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? S3-family connector syncs currently re-download every in-window object just so we can compute `xxhash128(blob)` and compare against `Document.content_hash`. Anything that bumps `LastModified` without changing bytes (`aws s3 cp` touches, bucket re-encryption, etc.) pays full bandwidth and re-parses files that didn't actually change. #14628 covers the broader incremental-ingestion redesign; this PR is the first slice. The fix is a pre-listing short-circuit. `BlobStorageConnector` (S3 / R2 / GCS / OCI / S3-compat) now implements a new `FingerprintConnector` interface: `list_keys()` paginates `list_objects_v2` and yields `KeyRecord(key, fingerprint)` where `fingerprint = xxhash128(ETag)`. The orchestrator joins those against the connector's existing `{doc_id: content_hash}` map and only calls `get_value(key)` when the fingerprint differs. Unchanged keys are skipped entirely — no `GetObject`, no re-parse. No DDL. xxhash128(ETag) is 32 hex chars and reuses the existing `Document.content_hash` column per @yingfeng's suggestion; the connector decides at listing time whether to populate it. Local uploads and connectors that don't opt in fall through to the existing post-download `xxhash128(blob)` path with no behavior change. This is PR-1 of a 4-PR series — full design lives on #14628. Subsequent PRs extend tier 1 to local FS / WebDAV / Dropbox / Seafile / RDBMS (PR-2), wire up tier 2 cursor connectors with `SyncLogs.next_checkpoint` (PR-3), and unify deletion via `KeyRecord(deleted=True)` reconciliation (PR-4). 
Holding those back keeps this PR additive and reviewable on its own. #### Files touched - `common/data_source/models.py` — new `KeyRecord`; optional `fingerprint` on `Document` - `common/data_source/interfaces.py` — `IncrementalCapability` enum, `FingerprintConnector` ABC - `common/data_source/blob_connector.py` — `BlobStorageConnector` implements `FingerprintConnector`; per-object download factored into `_build_document_from_obj()` so `_yield_blob_objects`, `list_keys`, `get_value` all share it - `rag/svr/sync_data_source.py` — `_BlobLikeBase._fingerprint_filtered_generator` does the bypass loop; `_run_task_logic` plumbs `doc.fingerprint` into the upload dict - `api/db/services/document_service.py` — `list_id_content_hash_map_by_kb_and_source_type()` helper - `api/db/services/connector_service.py` + `file_service.py` — fingerprint flows through `duplicate_and_parse → upload_document` and lands in `content_hash` - `test/unit_test/common/test_blob_connector_fingerprint.py` — 14 tests covering ETag normalization (single-part, multipart, quoted, empty), `list_keys()` not calling `GetObject`, `get_value()` materializing with fingerprint, deterministic/stable fingerprints, and the bypass loop asserting `GetObject` is *not* called on a match #### Worth flagging for review Old `_BlobLikeBase._generate` called `poll_source(start, now)` with a `LastModified` window when `poll_range_start` was set. New code uses `_fingerprint_filtered_generator` (full bucket listing + fingerprint compare) outside of explicit `reindex=1`. Strictly better for unchanged-bucket cases since it skips `GetObject`, but it does mean every sync now does a full `list_objects_v2` paginate. Should still be cheap for most buckets — flagging in case anyone has a very large bucket where the time-window filter was meaningful. On migration: existing rows have `content_hash = xxhash128(blob)` from the old code. 
The first sync after this lands sees ETag-derived fingerprints that don't match, re-fetches every object once, and writes the new fingerprint. From the second sync onward the bypass works as expected. "Slow day one, fast every day after." A `fingerprint_backfill: trust` opt-out is sketched in the design doc but not in this PR. #### Test plan - [x] `uv run ruff check` — clean on all 8 touched files - [x] `uv run pytest test/unit_test/common/test_blob_connector_fingerprint.py -v` — 14 passed - [x] Broader unit-test suite — no regressions in anything I touched - [ ] Manual smoke against a real S3 bucket — configure a connector, run sync twice, expect the second sync to log `bypassed=N, fetched=0` and no `GetObject` calls in CloudTrail / bucket access logs - [ ] Manual smoke with `reindex=1` — confirm the full re-download path still works ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Yingfeng --- api/db/services/connector_service.py | 5 +- api/db/services/document_service.py | 29 ++ api/db/services/file_service.py | 10 +- common/data_source/blob_connector.py | 151 ++++++-- common/data_source/interfaces.py | 52 ++- common/data_source/models.py | 19 + rag/svr/sync_data_source.py | 98 ++++- .../common/test_blob_connector_fingerprint.py | 347 ++++++++++++++++++ 8 files changed, 658 insertions(+), 53 deletions(-) create mode 100644 test/unit_test/common/test_blob_connector_fingerprint.py diff --git a/api/db/services/connector_service.py b/api/db/services/connector_service.py index 9f7b0e6ded..ab754101e1 100644 --- a/api/db/services/connector_service.py +++ b/api/db/services/connector_service.py @@ -16,7 +16,7 @@ import logging from datetime import datetime import os -from typing import Tuple, List +from typing import Optional, Tuple, List from anthropic import BaseModel from peewee import SQL, fn @@ -276,12 +276,13 @@ class SyncLogsService(CommonService): id: str filename: str blob: bytes + fingerprint: 
Optional[str] = None def read(self) -> bytes: return self.blob errs = [] - files = [FileObj(id=d["id"], filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"]) for d in docs] + files = [FileObj(id=d["id"], filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"], fingerprint=d.get("fingerprint")) for d in docs] doc_ids = [] err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src) errs.extend(err) diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 7992cdb610..bf6ebacbba 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -388,6 +388,35 @@ class DocumentService(CommonService): offset += page_size return res + @classmethod + @DB.connection_context() + def list_id_content_hash_map_by_kb_and_source_type(cls, kb_id, source_type, page_size=500): + """Return {doc_id: content_hash} for the connector's existing docs. + + Used by the fingerprint-bypass path to decide which keys can skip a + re-fetch -- if the connector's listing fingerprint equals content_hash, + the body hasn't changed since the last sync. + + Ordered by (create_time, id) so LIMIT/OFFSET pagination is deterministic: + create_time alone can tie, and tied rows may come back in a different order + on each page, silently dropping or duplicating entries in the map. 
+ """ + fields = [cls.model.id, cls.model.content_hash] + docs = cls.model.select(*fields).where( + cls.model.kb_id == kb_id, + cls.model.source_type == source_type, + ).order_by(cls.model.create_time.asc(), cls.model.id.asc()) + offset = 0 + result: dict[str, str] = {} + while True: + batch = list(docs.offset(offset).limit(page_size).dicts()) + if not batch: + break + for row in batch: + result[row["id"]] = row.get("content_hash") or "" + offset += page_size + return result + @classmethod @DB.connection_context() def get_all_docs_by_creator_id(cls, creator_id): diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index db8ae4b72f..e8b71a6afd 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -482,7 +482,12 @@ class FileService(CommonService): err.append(file.filename + ": " + user_msg) continue blob = file.read() - new_hash = xxhash.xxh128(blob).hexdigest() + # Connector-supplied fingerprint (e.g. xxhash128(S3 ETag)) + # takes precedence: for connector-sourced docs the bypass + # path uses the fingerprint as content_hash, so reverting + # to xxhash128(blob) here would defeat it. 
+ incoming_fp = getattr(file, "fingerprint", None) + new_hash = incoming_fp or xxhash.xxh128(blob).hexdigest() old_hash = doc.content_hash or "" settings.STORAGE_IMPL.put(kb.id, doc.location, blob, kb.tenant_id) doc.size = len(blob) @@ -518,6 +523,7 @@ class FileService(CommonService): thumbnail_location = f"thumbnail_{doc_id}.png" settings.STORAGE_IMPL.put(kb.id, thumbnail_location, img) + incoming_fp = getattr(file, "fingerprint", None) doc = { "id": doc_id, "kb_id": kb.id, @@ -532,7 +538,7 @@ class FileService(CommonService): "location": location, "size": len(blob), "thumbnail": thumbnail_location, - "content_hash": xxhash.xxh128(blob).hexdigest(), + "content_hash": incoming_fp or xxhash.xxh128(blob).hexdigest(), } DocumentService.insert(doc) diff --git a/common/data_source/blob_connector.py b/common/data_source/blob_connector.py index 7505b878ba..e183eb63aa 100644 --- a/common/data_source/blob_connector.py +++ b/common/data_source/blob_connector.py @@ -1,9 +1,12 @@ """Blob storage connector""" import logging import os +from collections.abc import Iterator from datetime import datetime, timezone from typing import Any, Optional +import xxhash + from common.data_source.utils import ( create_s3_client, detect_bucket_region, @@ -18,9 +21,14 @@ from common.data_source.exceptions import ( CredentialExpiredError, InsufficientPermissionsError ) -from common.data_source.interfaces import LoadConnector, PollConnector +from common.data_source.interfaces import ( + FingerprintConnector, + LoadConnector, + PollConnector, +) from common.data_source.models import ( Document, + KeyRecord, SecondsSinceUnixEpoch, GenerateDocumentsOutput, GenerateSlimDocumentOutput, @@ -28,7 +36,20 @@ from common.data_source.models import ( ) -class BlobStorageConnector(LoadConnector, PollConnector): +def _normalize_etag(raw_etag: Optional[str]) -> Optional[str]: + """Return a 32-char hex fingerprint derived from an S3 ETag. 
+ + S3 ETags are MD5 (32 hex chars) for single-part uploads and "<md5>-<part-count>" + (34+ chars) for multipart. We always hash so the column format is uniform + regardless of upload type or provider quirks; equality of the hashed value + is sufficient for change detection. + """ + if not raw_etag: + return None + return xxhash.xxh128(raw_etag.strip('"').encode()).hexdigest() + + +class BlobStorageConnector(LoadConnector, PollConnector, FingerprintConnector): """Blob storage connector""" def __init__( @@ -48,6 +69,11 @@ class BlobStorageConnector(LoadConnector, PollConnector): self.size_threshold: int | None = BLOB_STORAGE_SIZE_THRESHOLD self.bucket_region: Optional[str] = None self.european_residency: bool = european_residency + # Populated by list_keys() so a subsequent get_value(key) can find the + # raw S3 object metadata (LastModified, ETag, Key, Size) without a second + # head_object call. Lifetime is one list_keys() pass. + self._listing_cache: dict[str, dict[str, Any]] = {} + self._filename_counts: dict[str, int] = {} def set_allow_images(self, allow_images: bool) -> None: """Set whether to process images""" @@ -122,6 +148,44 @@ class BlobStorageConnector(LoadConnector, PollConnector): return None + def _build_document_from_obj( + self, + obj: dict[str, Any], + filename_counts: dict[str, int], + ) -> Optional[Document]: + """Materialize a Document for one S3 object, downloading its body.""" + key = obj["Key"] + file_name = os.path.basename(key) + last_modified = obj["LastModified"].replace(tzinfo=timezone.utc) + + size_bytes = extract_size_bytes(obj) + if ( + self.size_threshold is not None + and isinstance(size_bytes, int) + and size_bytes > self.size_threshold + ): + logging.warning( + f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping." 
+ ) + return None + + blob = download_object( + self.s3_client, self.bucket_name, key, self.size_threshold + ) + if blob is None: + return None + + return Document( + id=f"{self.bucket_type}:{self.bucket_name}:{key}", + blob=blob, + source=DocumentSource(self.bucket_type.value), + semantic_identifier=self._get_semantic_id(key, file_name, filename_counts), + extension=get_file_ext(file_name), + doc_updated_at=last_modified, + size_bytes=size_bytes if size_bytes else 0, + fingerprint=_normalize_etag(obj.get("ETag")), + ) + def _yield_blob_objects( self, start: datetime, @@ -132,51 +196,64 @@ class BlobStorageConnector(LoadConnector, PollConnector): batch: list[Document] = [] for obj in all_objects: - last_modified = obj["LastModified"].replace(tzinfo=timezone.utc) - file_name = os.path.basename(obj["Key"]) - key = obj["Key"] - - size_bytes = extract_size_bytes(obj) - if ( - self.size_threshold is not None - and isinstance(size_bytes, int) - and size_bytes > self.size_threshold - ): - logging.warning( - f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping." 
- ) - continue - try: - blob = download_object( - self.s3_client, self.bucket_name, key, self.size_threshold - ) - if blob is None: + doc = self._build_document_from_obj(obj, filename_counts) + if doc is None: continue - - semantic_id = self._get_semantic_id(key, file_name, filename_counts) - - batch.append( - Document( - id=f"{self.bucket_type}:{self.bucket_name}:{key}", - blob=blob, - source=DocumentSource(self.bucket_type.value), - semantic_identifier=semantic_id, - extension=get_file_ext(file_name), - doc_updated_at=last_modified, - size_bytes=size_bytes if size_bytes else 0, - ) - ) + batch.append(doc) if len(batch) == self.batch_size: yield batch batch = [] - except Exception: - logging.exception(f"Error decoding object {key}") + logging.exception(f"Error decoding object {obj.get('Key')}") if batch: yield batch + def list_keys(self) -> Iterator[KeyRecord]: + """Enumerate the full bucket keyspace with per-object fingerprints. + + Cheap path: relies on list_objects_v2 which returns ETag in the listing, + so no GetObject call is needed. Caches each object's metadata so a + subsequent get_value(key) call needs no extra metadata lookup (head_object); + the object body is still downloaded by get_value() itself. + """ + if self.s3_client is None: + raise ConnectorMissingCredentialError("Blob storage") + + all_objects, filename_counts = self._collect_blob_objects( + start=datetime(1970, 1, 1, tzinfo=timezone.utc), + end=datetime.now(timezone.utc), + ) + self._filename_counts = filename_counts + self._listing_cache = {} + + for obj in all_objects: + doc_id = f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}" + self._listing_cache[doc_id] = obj + yield KeyRecord( + key=doc_id, + fingerprint=_normalize_etag(obj.get("ETag")), + ) + + def get_value(self, key: str) -> Document: + """Materialize the Document for a key previously yielded by list_keys(). 
+ + Must be called within the same list_keys() pass that produced the key, + since the metadata cache lives on the connector instance and is reset + each list_keys() call. + """ + obj = self._listing_cache.get(key) + if obj is None: + raise KeyError( + f"get_value({key!r}) called before list_keys() yielded the key, " + "or after a subsequent list_keys() reset the cache" + ) + doc = self._build_document_from_obj(obj, self._filename_counts) + if doc is None: + raise RuntimeError(f"Failed to materialize Document for key {key!r}") + return doc + def _collect_blob_objects( self, start: datetime, diff --git a/common/data_source/interfaces.py b/common/data_source/interfaces.py index 324293baab..fb547d7d92 100644 --- a/common/data_source/interfaces.py +++ b/common/data_source/interfaces.py @@ -2,7 +2,7 @@ import abc import uuid from abc import ABC, abstractmethod -from enum import IntFlag, auto +from enum import IntEnum, IntFlag, auto from types import TracebackType from typing import Any, Dict, Generator, TypeVar, Generic, Callable, TypeAlias from collections.abc import Iterator @@ -10,12 +10,26 @@ from anthropic import BaseModel from common.data_source.models import ( Document, + KeyRecord, SlimDocument, ConnectorCheckpoint, ConnectorFailure, SecondsSinceUnixEpoch, GenerateSlimDocumentOutput ) + +class IncrementalCapability(IntEnum): + """How a connector handles incremental sync. + + FULL_RESYNC -- every sync re-pulls; no per-key state. + CURSOR -- "give me everything since cursor X"; opaque cursor persisted across syncs. + FINGERPRINT -- list_keys() returns (key, fingerprint) cheaply; bodies fetched lazily. + """ + FULL_RESYNC = 0 + CURSOR = 1 + FINGERPRINT = 2 + + GenerateDocumentsOutput = Iterator[list[Document]] class LoadConnector(ABC): @@ -415,3 +429,39 @@ class IndexingHeartbeatInterface(ABC): just to act as a keep-alive. """ + +class FingerprintConnector(ABC): + """Tier 1 connector: cheap full listing with per-key fingerprint. 
+ + Sources that can enumerate their entire keyspace via a metadata-only call + (e.g. S3 list_objects_v2 returning ETag + LastModified) implement this to + let the orchestrator skip GetObject for keys whose fingerprint hasn't + changed since the last sync. + + The fingerprint is an opaque equality token: two equal fingerprints mean + the content is unchanged from the orchestrator's point of view. Format is + a 32-char hex string so it fits the existing Document.content_hash column; + connectors are responsible for normalizing whatever the source exposes + (typically by hashing it with xxhash128). + """ + + INCREMENTAL_CAPABILITY: IncrementalCapability = IncrementalCapability.FINGERPRINT + + @abstractmethod + def list_keys(self) -> Iterator[KeyRecord]: + """Yield one KeyRecord per object currently in the source. + + Must enumerate the full current keyspace -- the orchestrator diffs the + result against persisted state to detect adds, updates, and deletes. + """ + raise NotImplementedError + + @abstractmethod + def get_value(self, key: str) -> Document: + """Fetch the body for a single key, returning a fully populated Document. + + Called only when list_keys()'s fingerprint differs from the persisted + content_hash for that key (or when no persisted fingerprint exists). + """ + raise NotImplementedError + diff --git a/common/data_source/models.py b/common/data_source/models.py index 71f8c27242..29cb6bc251 100644 --- a/common/data_source/models.py +++ b/common/data_source/models.py @@ -99,6 +99,25 @@ class Document(BaseModel): primary_owners: Optional[list] = None metadata: Optional[dict[str, Any]] = None doc_metadata: Optional[dict[str, Any]] = None + # Opaque, connector-supplied fingerprint stored in Document.content_hash for + # change-detection. 32-char hex string; format is per-source (xxhash128 of + # bytes for local uploads, xxhash128(ETag) for blob storage, etc.). 
When set + # on a yielded Document, the orchestrator persists it as content_hash and + # skips the post-download xxhash128(blob) recomputation. + fingerprint: Optional[str] = None + + +class KeyRecord(BaseModel): + """One entry returned by a FingerprintConnector.list_keys() call. + + A KeyRecord is the cheap-listing primitive: connector enumerates all keys + it has, attaches a fingerprint when the source exposes one, and the + orchestrator only fetches content when the fingerprint differs from what's + persisted. + """ + key: str + fingerprint: Optional[str] = None + deleted: bool = False class BasicExpertInfo(BaseModel): diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index 9a60701e79..92ab86b023 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -213,6 +213,8 @@ class SyncBase: } if doc.metadata: d["metadata"] = doc.metadata + if getattr(doc, "fingerprint", None): + d["fingerprint"] = doc.fingerprint docs.append(d) try: @@ -301,6 +303,81 @@ class SyncBase: class _BlobLikeBase(SyncBase): DEFAULT_BUCKET_TYPE: str = "s3" + def _fingerprint_filtered_generator(self, task: dict): + """Generator that uses list_keys() + get_value() to skip unchanged objects. + + Pre-loads {doc_id: content_hash} for the connector's existing docs in + this KB, iterates the bucket via list_keys(), and only materializes a + Document (one GetObject call) when the listing fingerprint differs from + the persisted content_hash. Unchanged objects are skipped entirely -- + no download, no re-parse. + + Per-key fetch failures are logged and counted; the end-of-run summary is + emitted at WARNING level when any fetch failed, so a partially failing sync + (e.g. throttling, IAM regression mid-run) does not pass unnoticed. + Connectors yielding KeyRecord(deleted=True) are skipped here -- actual + deletion reconciliation lives in the unified delete pass (PR-4). 
+ """ + source_type = f"{self.SOURCE_NAME}/{task['connector_id']}" + existing_fingerprints = DocumentService.list_id_content_hash_map_by_kb_and_source_type( + task["kb_id"], source_type, + ) + + bypass_count = 0 + fetch_count = 0 + fail_count = 0 + batch = [] + for key_record in self.connector.list_keys(): + if key_record.deleted: + continue + + doc_id = hash128(key_record.key) + stored = existing_fingerprints.get(doc_id, "") + if key_record.fingerprint and stored and key_record.fingerprint == stored: + bypass_count += 1 + continue + + try: + doc = self.connector.get_value(key_record.key) + except Exception as ex: + fail_count += 1 + logging.exception( + "Failed to fetch %s from %s: %s", + key_record.key, + self.SOURCE_NAME, + ex, + ) + continue + + fetch_count += 1 + batch.append(doc) + if len(batch) >= self.connector.batch_size: + yield batch + batch = [] + + if batch: + yield batch + + log_msg = ( + "[%s] fingerprint sync: %d bypassed, %d fetched, %d failed " + "(connector_id=%s, kb_id=%s)" + ) + log_args = ( + self.SOURCE_NAME, + bypass_count, + fetch_count, + fail_count, + task["connector_id"], + task["kb_id"], + ) + # Use WARNING when any fetch failed so partial-bucket regressions + # (auth, throttling, IAM drift) surface without diving into the + # per-exception traces above. + if fail_count: + logging.warning(log_msg, *log_args) + else: + logging.info(log_msg, *log_args) + async def _generate(self, task: dict): bucket_type = self.conf.get("bucket_type", self.DEFAULT_BUCKET_TYPE) @@ -313,14 +390,13 @@ class _BlobLikeBase(SyncBase): self.connector.load_credentials(self.conf["credentials"]) file_list = None - document_batch_generator = ( - self.connector.load_from_state() - if task["reindex"] == "1" or not task["poll_range_start"] - else self.connector.poll_source( - task["poll_range_start"].timestamp(), - datetime.now(timezone.utc).timestamp(), - ) - ) + # Fingerprint-bypass path: skip GetObject for unchanged ETags. 
Disabled + # on full reindex (we want to re-fetch everything in that case). + use_fingerprint_path = task["reindex"] != "1" + if use_fingerprint_path: + document_batch_generator = self._fingerprint_filtered_generator(task) + else: + document_batch_generator = self.connector.load_from_state() if ( task["reindex"] != "1" @@ -332,9 +408,9 @@ class _BlobLikeBase(SyncBase): file_list.extend(slim_batch) _begin_info = ( - "totally" - if task["reindex"] == "1" or not task["poll_range_start"] - else "from {}".format(task["poll_range_start"]) + "fingerprint-bypass" + if use_fingerprint_path + else "full reindex" ) logging.info( diff --git a/test/unit_test/common/test_blob_connector_fingerprint.py b/test/unit_test/common/test_blob_connector_fingerprint.py new file mode 100644 index 0000000000..ec133fd697 --- /dev/null +++ b/test/unit_test/common/test_blob_connector_fingerprint.py @@ -0,0 +1,347 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +"""Tests for the FingerprintConnector bypass path in BlobStorageConnector.""" + +import importlib.util +import sys +from datetime import datetime, timezone +from pathlib import Path +from types import ModuleType + +import pytest +import xxhash + + +def _load_blob_connector_module(): + repo_root = Path(__file__).resolve().parents[3] + package_name = "common.data_source" + saved_modules = {name: module for name, module in sys.modules.items() if name == package_name or name.startswith(f"{package_name}.")} + package_stub = ModuleType(package_name) + package_stub.__path__ = [str(repo_root / "common" / "data_source")] + sys.modules[package_name] = package_stub + + try: + spec = importlib.util.spec_from_file_location( + "_blob_connector_under_test", + repo_root / "common" / "data_source" / "blob_connector.py", + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + finally: + for name in list(sys.modules): + if name == package_name or name.startswith(f"{package_name}."): + if name in saved_modules: + sys.modules[name] = saved_modules[name] + else: + sys.modules.pop(name, None) + + +blob_connector = _load_blob_connector_module() +BlobStorageConnector = blob_connector.BlobStorageConnector +_normalize_etag = blob_connector._normalize_etag + + +# --------------------------------------------------------------------------- +# Fake S3 client wired through a paginator-style interface. +# --------------------------------------------------------------------------- + + +class _FakePaginator: + def __init__(self, pages: list[dict]) -> None: + self._pages = pages + + def paginate(self, **_kwargs): + for page in self._pages: + yield page + + +class _FakeS3Client: + """Captures every call on the connector's S3 client. + + Tests assert against `get_object_calls` to verify that the fingerprint + bypass actually skips downloads when ETags haven't changed. 
+ """ + + def __init__(self, objects: list[dict]) -> None: + self._objects = objects + self.get_object_calls: list[tuple[str, str]] = [] + # Hand objects to the paginator unmodified so the connector exercises + # its own directory-placeholder filtering logic. + self._paginator = _FakePaginator([{"Contents": list(objects)}]) + + def get_paginator(self, name: str): + assert name == "list_objects_v2" + return self._paginator + + def list_objects_v2(self, **_kwargs): + return {"Contents": self._objects, "KeyCount": len(self._objects)} + + def get_object(self, Bucket: str, Key: str): # noqa: N803 (boto3 API) + self.get_object_calls.append((Bucket, Key)) + body_text = f"body-of-{Key}".encode() + return { + "Body": _FakeBody(body_text), + "ContentLength": len(body_text), + } + + +class _FakeBody: + """Minimal stand-in for botocore's StreamingBody. + + The real downloader (common.data_source.utils.download_object) consumes + the body via iter_chunks() and then calls close(); fake out both. + """ + + def __init__(self, payload: bytes) -> None: + self._payload = payload + + def read(self) -> bytes: + return self._payload + + def iter_chunks(self, chunk_size: int = 65536): + for i in range(0, len(self._payload), chunk_size): + yield self._payload[i : i + chunk_size] + + def close(self) -> None: + return None + + +def _make_connector(s3_client) -> BlobStorageConnector: + connector = BlobStorageConnector(bucket_type="s3", bucket_name="test-bucket") + connector.s3_client = s3_client + return connector + + +def _s3_object(key: str, etag: str, size: int = 12) -> dict: + return { + "Key": key, + "ETag": f'"{etag}"', + "LastModified": datetime(2026, 1, 1, 12, tzinfo=timezone.utc), + "Size": size, + } + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_normalize_etag_returns_32_char_hex_for_singlepart_etag(): + fp = 
_normalize_etag('"d41d8cd98f00b204e9800998ecf8427e"') + assert fp is not None + assert len(fp) == 32 + assert all(c in "0123456789abcdef" for c in fp) + + +def test_normalize_etag_returns_32_char_hex_for_multipart_etag(): + """Multipart ETags are 34+ chars; hashing normalizes them to 32.""" + fp = _normalize_etag('"d41d8cd98f00b204e9800998ecf8427e-7"') + assert fp is not None + assert len(fp) == 32 + + +def test_normalize_etag_is_deterministic(): + raw = '"abc123def456abc123def456abc123de"' + assert _normalize_etag(raw) == _normalize_etag(raw) + + +def test_normalize_etag_strips_quotes_so_quoted_and_unquoted_match(): + quoted = '"d41d8cd98f00b204e9800998ecf8427e"' + unquoted = "d41d8cd98f00b204e9800998ecf8427e" + assert _normalize_etag(quoted) == _normalize_etag(unquoted) + + +def test_normalize_etag_returns_none_for_empty_input(): + assert _normalize_etag("") is None + assert _normalize_etag(None) is None + + +def test_list_keys_yields_one_keyrecord_per_object_with_fingerprint(): + s3 = _FakeS3Client( + [ + _s3_object("foo.txt", "etag-foo"), + _s3_object("bar/baz.txt", "etag-baz"), + ] + ) + connector = _make_connector(s3) + + records = list(connector.list_keys()) + + assert len(records) == 2 + assert {r.key for r in records} == { + "BlobType.S3:test-bucket:foo.txt", + "BlobType.S3:test-bucket:bar/baz.txt", + } + for record in records: + assert record.fingerprint is not None + assert len(record.fingerprint) == 32 + assert record.deleted is False + + +def test_list_keys_does_not_call_get_object(): + """list_keys() must be cheap -- no body downloads during enumeration.""" + s3 = _FakeS3Client([_s3_object("foo.txt", "etag-foo")]) + connector = _make_connector(s3) + + list(connector.list_keys()) + + assert s3.get_object_calls == [] + + +def test_list_keys_skips_directory_placeholder_keys(): + """S3 'folders' are zero-byte keys ending in '/'; they shouldn't yield records.""" + s3 = _FakeS3Client( + [ + _s3_object("real-file.txt", "etag-real"), + _s3_object("folder/", 
"etag-folder"), + ] + ) + connector = _make_connector(s3) + + keys = [r.key for r in connector.list_keys()] + + assert keys == ["BlobType.S3:test-bucket:real-file.txt"] + + +def test_get_value_returns_document_with_fingerprint_set(): + s3 = _FakeS3Client([_s3_object("foo.txt", "etag-foo")]) + connector = _make_connector(s3) + [record] = list(connector.list_keys()) + + doc = connector.get_value(record.key) + + assert doc.id == "BlobType.S3:test-bucket:foo.txt" + assert doc.fingerprint == record.fingerprint + assert doc.fingerprint == xxhash.xxh128(b"etag-foo").hexdigest() + + +def test_get_value_calls_get_object_exactly_once_per_key(): + s3 = _FakeS3Client([_s3_object("foo.txt", "etag-foo")]) + connector = _make_connector(s3) + [record] = list(connector.list_keys()) + + connector.get_value(record.key) + + assert s3.get_object_calls == [("test-bucket", "foo.txt")] + + +def test_get_value_raises_keyerror_when_called_before_list_keys(): + s3 = _FakeS3Client([_s3_object("foo.txt", "etag-foo")]) + connector = _make_connector(s3) + + with pytest.raises(KeyError): + connector.get_value("BlobType.S3:test-bucket:foo.txt") + + +def test_singlepart_and_multipart_etags_yield_different_fingerprints(): + """Sanity: distinct ETags must produce distinct fingerprints.""" + s3 = _FakeS3Client( + [ + _s3_object("a.bin", "d41d8cd98f00b204e9800998ecf8427e"), + _s3_object("b.bin", "d41d8cd98f00b204e9800998ecf8427e-3"), + ] + ) + connector = _make_connector(s3) + + records = list(connector.list_keys()) + + assert records[0].fingerprint != records[1].fingerprint + + +def test_fingerprint_stable_across_repeated_listings(): + """Same ETag in two list_keys() calls yields the same fingerprint.""" + s3 = _FakeS3Client([_s3_object("foo.txt", "etag-stable")]) + connector = _make_connector(s3) + + fp_first = next(connector.list_keys()).fingerprint + fp_second = next(connector.list_keys()).fingerprint + + assert fp_first == fp_second + + +# 
--------------------------------------------------------------------------- +# Bypass-logic test: simulates what the orchestrator does in +# _BlobLikeBase._fingerprint_filtered_generator. Verifies that a key whose +# fingerprint matches the persisted content_hash is NOT fetched. +# --------------------------------------------------------------------------- + + +def test_orchestrator_pattern_skips_get_object_when_fingerprint_matches(): + # Use distinct base names: "unchanged.txt".endswith("changed.txt") is True, + # which would silently break endswith-based lookups in the test setup. + s3 = _FakeS3Client( + [ + _s3_object("static.txt", "etag-static"), + _s3_object("modified.txt", "etag-modified"), + ] + ) + connector = _make_connector(s3) + + # Pre-compute the fingerprints the connector would emit, then pretend the + # DB already stores the one for static.txt but a stale value for + # modified.txt. This is the steady-state bypass scenario. + listed = list(connector.list_keys()) + static_record = next(r for r in listed if r.key.endswith(":static.txt")) + modified_record = next(r for r in listed if r.key.endswith(":modified.txt")) + persisted = { + static_record.key: static_record.fingerprint, + modified_record.key: "stale-fingerprint", + } + + # Reset the call log so we only count get_object during the bypass loop. + s3.get_object_calls = [] + + fetched = [] + for record in connector.list_keys(): + if record.fingerprint and persisted.get(record.key) == record.fingerprint: + continue + fetched.append(connector.get_value(record.key)) + + assert [doc.id for doc in fetched] == ["BlobType.S3:test-bucket:modified.txt"] + assert s3.get_object_calls == [("test-bucket", "modified.txt")] + + +def test_orchestrator_pattern_skips_deleted_records_without_calling_get_value(): + """KeyRecord(deleted=True) must short-circuit before get_value(). + + Reach KeyRecord through the already-loaded blob_connector module to avoid + triggering common.data_source.__init__'s circular imports. 
+ """ + KeyRecord = blob_connector.KeyRecord + + s3 = _FakeS3Client([_s3_object("foo.txt", "etag-foo")]) + connector = _make_connector(s3) + + # Manually feed a deleted KeyRecord through the bypass logic to assert the + # short-circuit holds even when a connector emits one. (BlobStorageConnector + # itself doesn't yield deleted records yet -- that's PR-4 -- but the + # orchestrator must already be defensive.) + deleted_record = KeyRecord( + key="BlobType.S3:test-bucket:gone.txt", + fingerprint=None, + deleted=True, + ) + + # Mirror the orchestrator's loop body verbatim. + fetched = [] + for record in [deleted_record]: + if record.deleted: + continue + fetched.append(connector.get_value(record.key)) + + assert fetched == [] + assert s3.get_object_calls == []