diff --git a/common/data_source/blob_connector.py b/common/data_source/blob_connector.py index 1ab39189d7..627aa8fba7 100644 --- a/common/data_source/blob_connector.py +++ b/common/data_source/blob_connector.py @@ -10,6 +10,7 @@ from common.data_source.utils import ( download_object, extract_size_bytes, get_file_ext, + is_accepted_file_ext, ) from common.data_source.config import BlobType, DocumentSource, BLOB_STORAGE_SIZE_THRESHOLD, INDEX_BATCH_SIZE from common.data_source.exceptions import ( @@ -18,7 +19,7 @@ from common.data_source.exceptions import ( CredentialExpiredError, InsufficientPermissionsError ) -from common.data_source.interfaces import LoadConnector, PollConnector +from common.data_source.interfaces import LoadConnector, OnyxExtensionType, PollConnector from common.data_source.models import Document, SecondsSinceUnixEpoch, GenerateDocumentsOutput @@ -130,15 +131,23 @@ class BlobStorageConnector(LoadConnector, PollConnector): # Collect all objects first to count filename occurrences all_objects = [] + extension_type = OnyxExtensionType.Plain | OnyxExtensionType.Document + if bool(self._allow_images): + extension_type |= OnyxExtensionType.Multimedia for page in pages: if "Contents" not in page: continue for obj in page["Contents"]: - if obj["Key"].endswith("/"): + key = obj["Key"] + if key.endswith("/"): continue last_modified = obj["LastModified"].replace(tzinfo=timezone.utc) - if start < last_modified <= end: - all_objects.append(obj) + if not (start < last_modified <= end): + continue + file_name = os.path.basename(key) + if not is_accepted_file_ext(get_file_ext(file_name), extension_type): + continue + all_objects.append(obj) # Count filename occurrences to determine which need full paths filename_counts: dict[str, int] = {} diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index e24a8719bb..ac70a6843a 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -267,6 +267,7 @@ class _BlobLikeBase(SyncBase): bucket_name=self.conf["bucket_name"], prefix=self.conf.get("prefix", ""), ) + self.connector.set_allow_images(self.conf.get("allow_images", False)) self.connector.load_credentials(self.conf["credentials"]) document_batch_generator = (