mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-06 03:18:36 +08:00
411 lines
17 KiB
Python
411 lines
17 KiB
Python
"""Azure Blob Storage data-source connector.
|
|
|
|
Ingests blobs from a user's Azure container into a RAGFlow knowledge
|
|
base. This is distinct from RAGFlow's own Azure storage *backend*
|
|
(``rag/utils/azure_sas_conn.py``, ``rag/utils/azure_spn_conn.py``),
|
|
which stores RAGFlow's own files.
|
|
|
|
Auth supports three mutually exclusive modes, selected explicitly by the
|
|
caller-supplied ``auth_mode`` (the UI hides the other modes' fields but
|
|
does not clear them, so we must not guess from whichever field happens to
|
|
be populated). When ``auth_mode`` is absent (older configs / direct API
|
|
callers) we fall back to field precedence:
|
|
|
|
1. **Connection string** — ``connection_string`` credential; one line,
|
|
everything embedded. Good for dev / testing.
|
|
2. **Account key** — ``account_name`` + ``account_key``; maps to the
|
|
same underlying SAS-less AccountKey credential.
|
|
3. **SAS token** — ``container_url`` + ``sas_token``; the shape that
|
|
``RAGFlowAzureSasBlob`` already uses.
|
|
|
|
Incremental runs are scoped by the poll time window
|
|
(``since_epoch`` < last-modified <= ``until_epoch``).
|
|
Each blob's ETag is also emitted as the document fingerprint, which the
|
|
indexing pipeline persists as ``content_hash`` so unchanged blobs are not
|
|
re-embedded. The connector itself keeps no cross-run ETag state.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Generator
|
|
|
|
from common.data_source.config import INDEX_BATCH_SIZE
|
|
from common.data_source.exceptions import (
|
|
ConnectorMissingCredentialError,
|
|
ConnectorValidationError,
|
|
InsufficientPermissionsError,
|
|
UnexpectedValidationError,
|
|
)
|
|
from common.data_source.interfaces import (
|
|
CheckpointedConnectorWithPermSync,
|
|
SecondsSinceUnixEpoch,
|
|
SlimConnectorWithPermSync,
|
|
)
|
|
from common.data_source.models import ConnectorCheckpoint, SlimDocument
|
|
|
|
logger = logging.getLogger(__name__)

# File extensions this connector ingests. The list mirrors the one used by
# the OneDrive connector so that all file-based sources behave the same way.
_SUPPORTED_EXTENSIONS = {
    # Documents
    ".pdf", ".docx", ".doc",
    # Spreadsheets and tabular data
    ".xlsx", ".xls", ".csv",
    # Presentations
    ".pptx", ".ppt",
    # Plain text and markup
    ".txt", ".md", ".html", ".htm",
    # Structured data
    ".json", ".xml",
}

# Host suffix appended to the storage account name to build the account URL
# used by the account-key auth mode.
_AZURE_ENDPOINT_SUFFIX = "blob.core.windows.net"
|
|
|
|
|
|
class AzureBlobCheckpoint(ConnectorCheckpoint):
    """Checkpoint type used by the Azure Blob connector.

    It adds no fields to ``ConnectorCheckpoint``. The connector holds no
    cross-run state: one ``load_from_checkpoint`` pass lists the container
    a single time and then sets ``has_more=False``. Incremental scoping is
    provided by the poll time window, and per-blob change detection by the
    document fingerprint (the blob's ETag), which the pipeline persists as
    ``content_hash``.
    """
|
|
|
|
|
|
class AzureBlobConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
    """Azure Blob Storage data-source connector.

    Authenticates with one of three credential modes (connection string,
    account key, or SAS token), chosen by ``auth_mode``, and enumerates
    blobs in the configured container under an optional prefix. Each blob's
    ETag is surfaced as the document fingerprint so the pipeline can skip
    re-embedding unchanged blobs across runs.
    """

    def __init__(
        self,
        batch_size: int = INDEX_BATCH_SIZE,
        prefix: str | None = None,
        allow_images: bool = False,
        auth_mode: str | None = None,
    ) -> None:
        """Store connector settings; no network access happens here.

        Args:
            batch_size: Maximum number of documents per yielded batch.
            prefix: Optional blob-name prefix to restrict listing to.
                Leading ``/`` characters are removed.
            allow_images: When true, common image extensions are ingested
                in addition to the document extensions.
            auth_mode: Explicit credential mode: ``"connection_string"``,
                ``"account_key"`` or ``"sas_token"``. Empty / ``None``
                means "pick by field precedence" in ``load_credentials``.
        """
        self.batch_size = batch_size
        self.prefix = (prefix or "").lstrip("/")
        self.allow_images = allow_images
        # Explicitly selected credential mode: "connection_string",
        # "account_key", or "sas_token". Empty falls back to precedence.
        self.auth_mode = (auth_mode or "").strip().lower()
        # Set by load_credentials(); every data method requires it.
        self._container_client = None

    # ------------------------------------------------------------------
    # Auth
    # ------------------------------------------------------------------

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Build the container client from ``credentials``.

        Args:
            credentials: Mapping that may contain ``connection_string``,
                ``account_name``, ``account_key``, ``container_url``,
                ``sas_token`` and ``container_name``.

        Returns:
            Always ``None``; the client is stored on the instance.

        Raises:
            ConnectorMissingCredentialError: If the fields required by the
                selected mode are missing, the explicit ``auth_mode`` is
                not one of the supported values, or the Azure SDK fails to
                construct the client.
        """
        # Imported lazily so the module can be loaded without the Azure SDK.
        from azure.storage.blob import BlobServiceClient, ContainerClient

        conn_str = credentials.get("connection_string")
        account_name = credentials.get("account_name")
        account_key = credentials.get("account_key")
        container_url = (credentials.get("container_url") or "").rstrip("/")
        sas_token = credentials.get("sas_token")
        container_name = credentials.get("container_name") or ""

        # Honor the explicitly selected auth mode. The UI hides inactive
        # credential fields but does not clear them, so a user who fills one
        # mode and then switches can leave stale values behind; selecting by
        # field precedence would then authenticate with the wrong mode.
        # Fall back to precedence only when no auth_mode was supplied.
        mode = self.auth_mode
        if not mode:
            if conn_str:
                mode = "connection_string"
            elif account_name and account_key:
                mode = "account_key"
            elif container_url and sas_token:
                mode = "sas_token"

        try:
            if mode == "connection_string":
                if not conn_str:
                    raise ConnectorMissingCredentialError("Azure Blob: connection_string is required for the connection_string auth mode")
                if not container_name:
                    raise ConnectorMissingCredentialError("Azure Blob: container_name is required together with connection_string")
                svc = BlobServiceClient.from_connection_string(conn_str)
                self._container_client = svc.get_container_client(container_name)
            elif mode == "account_key":
                if not (account_name and account_key):
                    raise ConnectorMissingCredentialError("Azure Blob: account_name and account_key are required for the account_key auth mode")
                if not container_name:
                    raise ConnectorMissingCredentialError("Azure Blob: container_name is required together with account_name + account_key")
                account_url = f"https://{account_name}.{_AZURE_ENDPOINT_SUFFIX}"
                svc = BlobServiceClient(
                    account_url=account_url,
                    credential=account_key,
                )
                self._container_client = svc.get_container_client(container_name)
            elif mode == "sas_token":
                if not (container_url and sas_token):
                    raise ConnectorMissingCredentialError("Azure Blob: container_url and sas_token are required for the sas_token auth mode")
                # mirrors RAGFlowAzureSasBlob; strip a leading "?" so we
                # never produce a double-"?" that breaks SAS auth.
                normalized_sas = str(sas_token).lstrip("?")
                full_url = f"{container_url}?{normalized_sas}"
                self._container_client = ContainerClient.from_container_url(full_url)
            elif mode:
                # An explicit but unrecognised auth_mode is a configuration
                # error, not missing fields; say so instead of reporting
                # "incomplete credentials".
                raise ConnectorMissingCredentialError(
                    f"Azure Blob: unsupported auth_mode {mode!r}. Expected one of: connection_string, account_key, sas_token."
                )
            else:
                raise ConnectorMissingCredentialError(
                    "Azure Blob credentials are incomplete. Provide one of: (a) connection_string + container_name, (b) account_name + account_key + container_name, (c) container_url + sas_token."
                )
        except ConnectorMissingCredentialError:
            raise
        except Exception as exc:
            # Any SDK construction failure is reported as a credential problem.
            raise ConnectorMissingCredentialError(f"Failed to initialise Azure Blob client: {exc}") from exc

        return None

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------

    def validate_connector_settings(self) -> None:
        """Check that the credential and container are usable.

        Issues a single ``get_container_properties()`` call and maps any
        failure onto the connector exception hierarchy.

        Raises:
            ConnectorMissingCredentialError: No client has been loaded, or
                Azure rejected the credential.
            InsufficientPermissionsError: Azure answered with a 403 /
                ``AuthorizationPermissionMismatch``.
            ConnectorValidationError: Azure answered with a 404 /
                ``ContainerNotFound``.
            UnexpectedValidationError: Any other failure.
        """
        if self._container_client is None:
            raise ConnectorMissingCredentialError("Azure Blob")

        try:
            # get_container_properties() costs one API call; it returns
            # the ETag and last-modified of the container, proving both
            # the credential and the container name are valid.
            self._container_client.get_container_properties()
        except Exception as exc:
            msg = str(exc)
            detail = msg[:300]
            code = getattr(getattr(exc, "error_code", None), "value", None) or getattr(exc, "error_code", "")
            code_text = str(code or "")

            # Prefer the HTTP status carried by the exception. Azure error
            # text embeds a request id and a timestamp, so a bare "403" or
            # "404" substring can occur in unrelated failures and would
            # misclassify them. Fall back to the substring test only when
            # the exception exposes no status code at all.
            status = getattr(exc, "status_code", None)
            if status is None:
                forbidden = "403" in msg
                not_found = "404" in msg
            else:
                forbidden = status == 403
                not_found = status == 404

            def _names(*error_names: str) -> bool:
                """True if any Azure error name appears in the code or message."""
                return any(n in code_text or n in msg for n in error_names)

            # Checked first: a rejected credential is itself reported with a
            # 403-class status, and must not be reported as a permission issue.
            if _names("AuthenticationFailed", "InvalidAuthenticationInfo"):
                raise ConnectorMissingCredentialError(f"Azure Blob credential rejected: {detail}") from exc
            if _names("AuthorizationPermissionMismatch") or forbidden:
                raise InsufficientPermissionsError(f"Azure Blob: insufficient permissions on container: {detail}") from exc
            if _names("ContainerNotFound") or not_found:
                raise ConnectorValidationError(f"Azure Blob: container not found: {detail}") from exc
            raise UnexpectedValidationError(f"Azure Blob validation failed ({code}): {detail}") from exc

    # ------------------------------------------------------------------
    # Checkpoint helpers
    # ------------------------------------------------------------------

    def build_dummy_checkpoint(self) -> AzureBlobCheckpoint:
        """Return a fresh checkpoint with ``has_more=True``."""
        return AzureBlobCheckpoint(has_more=True)

    def validate_checkpoint_json(self, checkpoint_json: str) -> AzureBlobCheckpoint:
        """Parse a serialized checkpoint.

        Best-effort by design: any parse/validation failure yields a fresh
        dummy checkpoint instead of raising.
        """
        try:
            return AzureBlobCheckpoint.model_validate_json(checkpoint_json)
        except Exception:
            return self.build_dummy_checkpoint()

    # ------------------------------------------------------------------
    # Core data loading
    # ------------------------------------------------------------------

    def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
        """Return a generator of document batches for the ``start``/``end`` window."""
        return self._iter_documents(since_epoch=start, until_epoch=end)

    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: ConnectorCheckpoint,
    ) -> Any:
        """Return a generator of document batches, marking ``checkpoint`` done.

        A checkpoint that is not an ``AzureBlobCheckpoint`` is replaced by a
        fresh dummy one. Falsy ``start`` / ``end`` values are passed on as
        ``None`` (no bound on that side of the window).
        """
        if not isinstance(checkpoint, AzureBlobCheckpoint):
            checkpoint = self.build_dummy_checkpoint()
        since = start or None
        until = end or None
        return self._iter_documents(checkpoint=checkpoint, since_epoch=since, until_epoch=until)

    def load_from_checkpoint_with_perm_sync(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: ConnectorCheckpoint,
    ) -> Any:
        """Same as ``load_from_checkpoint``; this method simply delegates to it."""
        return self.load_from_checkpoint(start, end, checkpoint)

    def retrieve_all_slim_docs_perm_sync(
        self,
        callback: Any = None,
    ) -> Generator[list[SlimDocument], None, None]:
        """Yield batches of slim documents for prune / permission sync.

        Lists every blob under the configured prefix, keeps those with a
        supported extension, and emits ``SlimDocument(id=<blob name>)`` in
        batches of at most ``batch_size``. No time window is applied.

        Args:
            callback: Optional callable invoked as ``callback(name, name)``
                for each kept blob (the argument meaning is defined by the
                caller — TODO confirm).

        Raises:
            ConnectorMissingCredentialError: No client has been loaded.
            UnexpectedValidationError: Listing (or the callback) raised.
        """
        if self._container_client is None:
            raise ConnectorMissingCredentialError("Azure Blob")

        batch: list[SlimDocument] = []
        try:
            for blob_props in self._container_client.list_blobs(name_starts_with=self.prefix or None):
                name = blob_props.name
                if not _has_supported_extension(name, self.allow_images):
                    continue
                if callback:
                    callback(name, name)
                batch.append(SlimDocument(id=name))
                if len(batch) >= self.batch_size:
                    yield batch
                    batch = []
        except Exception as exc:
            raise UnexpectedValidationError(f"Azure Blob prune listing failed: {exc}") from exc

        if batch:
            yield batch

    # ------------------------------------------------------------------
    # Internal document iteration
    # ------------------------------------------------------------------

    def _iter_documents(
        self,
        checkpoint: AzureBlobCheckpoint | None = None,
        since_epoch: float | None = None,
        until_epoch: float | None = None,
    ):
        """Generate batches of ``Document`` objects for blobs in the window.

        Args:
            checkpoint: If given, its ``has_more`` is set to ``False`` once
                the listing has been fully consumed.
            since_epoch: Exclusive lower bound on last-modified (epoch
                seconds). Falsy means no lower bound.
            until_epoch: Inclusive upper bound on last-modified (epoch
                seconds). Falsy means no upper bound.

        Yields:
            Lists of at most ``batch_size`` documents.

        Raises:
            ConnectorMissingCredentialError: No client has been loaded.
            UnexpectedValidationError: Listing failed, or a blob download
                failed for any reason other than the blob having vanished.
        """
        from common.data_source.models import Document

        if self._container_client is None:
            raise ConnectorMissingCredentialError("Azure Blob")

        batch: list[Document] = []

        try:
            for blob_props in self._container_client.list_blobs(name_starts_with=self.prefix or None):
                name: str = blob_props.name

                if not _has_supported_extension(name, self.allow_images):
                    continue

                # Raw ETag (always present); Azure updates it on every
                # write. Emitted below as the document fingerprint so the
                # pipeline persists it as content_hash and skips re-embedding
                # unchanged blobs across runs.
                current_etag = (blob_props.etag or "").strip('"')

                # Time-window filter: strict lower bound, inclusive upper
                # bound (``since_epoch`` < last-modified <= ``until_epoch``).
                # Excluding last-modified == since_epoch (the prior run's
                # watermark, which that run already yielded) avoids stable
                # duplicate re-fetches on the boundary — matching the
                # Salesforce connector's ``> since``. Enforcing the upper
                # bound keeps blobs modified mid-run from leaking into this
                # window; they're picked up by the next run (whose lower bound
                # is this run's upper bound), so an update can never fall into
                # a gap between windows.
                last_modified: datetime | None = blob_props.last_modified
                if last_modified:
                    ts = last_modified.timestamp()
                    if since_epoch and ts <= since_epoch:
                        continue
                    if until_epoch and ts > until_epoch:
                        continue

                # Download blob content. A blob that was deleted between the
                # listing and this fetch is genuinely gone — skip it. Any
                # other failure (throttling, transient 5xx, network) must
                # abort the run: the sync framework advances its watermark
                # from successfully yielded docs, so silently skipping a
                # transiently-failed blob while newer blobs succeed would
                # move the watermark past it and drop it permanently.
                try:
                    blob_client = self._container_client.get_blob_client(name)
                    data = blob_client.download_blob().readall()
                except Exception as exc:
                    if _is_blob_gone(exc):
                        logger.warning(
                            "Azure Blob: %s vanished between listing and fetch; skipping",
                            name,
                        )
                        continue
                    raise UnexpectedValidationError(f"Azure Blob: failed to download {name}: {exc}") from exc

                # Blobs without a last-modified value are stamped with "now".
                doc_updated_at = last_modified.astimezone(timezone.utc) if last_modified else datetime.now(timezone.utc)

                ext = _extension(name)
                doc = Document(
                    id=name,
                    source="azure_blob",
                    semantic_identifier=name,
                    extension=ext,
                    blob=data,
                    doc_updated_at=doc_updated_at,
                    size_bytes=len(data),
                    fingerprint=current_etag or None,
                    metadata={
                        "container": _container_name(self._container_client),
                        "etag": current_etag,
                        "prefix": self.prefix,
                    },
                )
                batch.append(doc)

                if len(batch) >= self.batch_size:
                    yield batch
                    batch = []
        except UnexpectedValidationError:
            # Already carries a precise message (download failure); keep it.
            raise
        except Exception as exc:
            raise UnexpectedValidationError(f"Azure Blob listing failed: {exc}") from exc

        if batch:
            yield batch

        if checkpoint is not None:
            checkpoint.has_more = False
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# Module-level helpers
|
|
# ----------------------------------------------------------------------
|
|
|
|
|
|
def _extension(name: str) -> str:
    """Return the lower-cased extension of a blob name, dot included.

    Only the final ``/``-separated segment of the name is examined, so a
    dot inside a virtual directory (``"archive.v2/readme"``) is not
    mistaken for an extension.

    Args:
        name: Blob name, optionally containing ``/`` path separators.

    Returns:
        The text from the last dot of the final segment onward, lower-cased
        (e.g. ``".pdf"``), or ``""`` when that segment contains no dot.
    """
    basename = name.rsplit("/", 1)[-1]
    if "." not in basename:
        return ""
    return "." + basename.rsplit(".", 1)[-1].lower()
|
|
|
|
|
|
def _has_supported_extension(name: str, allow_images: bool) -> bool:
    """Tell whether a blob should be ingested, judging by its extension.

    Args:
        name: Blob name.
        allow_images: Whether image extensions count as supported too.

    Returns:
        ``True`` when the extension is in ``_SUPPORTED_EXTENSIONS``, or when
        ``allow_images`` is set and the extension is a known image type.
    """
    suffix = _extension(name)
    if suffix in _SUPPORTED_EXTENSIONS:
        return True
    image_suffixes = {".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp", ".tiff"}
    return bool(allow_images) and suffix in image_suffixes
|
|
|
|
|
|
def _is_blob_gone(exc: Exception) -> bool:
    """Report whether a download failure means the blob no longer exists.

    Azure raises ``ResourceNotFoundError`` (status 404, error code
    ``BlobNotFound``) when a blob that was listed a moment ago has been
    deleted in the meantime. Nothing is lost by skipping such a blob. The
    check relies on attributes and text only, so the Azure exception type
    does not have to be imported when this module loads.

    Args:
        exc: The exception raised by the download attempt.

    Returns:
        ``True`` if ``exc`` has ``status_code == 404``, or if its
        ``error_code`` contains ``BlobNotFound``, or if its message
        contains ``BlobNotFound`` or ``ResourceNotFound``.
    """
    status = getattr(exc, "status_code", None)
    if status == 404:
        return True

    error_code = str(getattr(exc, "error_code", "") or "")
    if "BlobNotFound" in error_code:
        return True

    text = str(exc)
    return any(marker in text for marker in ("BlobNotFound", "ResourceNotFound"))
|
|
|
|
|
|
def _container_name(client: Any) -> str:
    """Read ``client.container_name`` without needing the Azure SDK types.

    Returns the attribute's value, or an empty string when the object has
    no such attribute.
    """
    return getattr(client, "container_name", "")
|