mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
fix(llm): correct error handling, token accounting, and truncation in embedding providers (#15424)
### Summary Closes #15423 `rag/llm/embedding_model.py` hosts about 40 embedding providers that shared several defects affecting indexing reliability, cost accounting, and error visibility. This PR fixes four concrete bugs. **Masked, inconsistent errors (27 sites).** Nearly every provider ran `log_exception(_e, res)` followed by `raise Exception(f"Error: {res}")`. Because `log_exception` always raises, the second line was dead code, and the surfaced exception varied with whether the SDK response exposed a `.text` attribute. Every failure path now raises a single `EmbeddingError` that includes the underlying response detail, so the cause of a failed embedding is consistent and visible. **Fabricated token counts.** `LocalAIEmbed` returned a hardcoded `1024` and `OllamaEmbed` added `128` per text. These values feed `used_tokens` and therefore billing and usage tracking. Both now report the real count from the API (Ollama `prompt_eval_count`, LocalAI `usage`) and fall back to a local token count only when the server omits it. **Truncation overshoot.** The `8196` limit used by Mistral and Bedrock exceeded the standard `8192` ceiling and could push boundary sized inputs past the model limit. Limits are corrected to `8192` and made intentional per provider, and providers that rely on server side truncation now request it explicitly (Ollama `truncate=True`, Cohere `truncate="END"`). **Missing batching on Zhipu and Ollama.** Both issued one request per text. They now batch like the other OpenAI compatible providers, turning N round trips into `ceil(N / batch_size)`. Batched results are realigned by response `index` so a chunk always keeps its own vector. A shared `Base._batched_encode` helper owns the batch loop, optional truncation, result accumulation, and the single error path. It is the mechanism that lets these fixes live in one place instead of across 27 duplicated sites. The public `encode()` and `encode_queries()` contract stays the same, so existing callers are unaffected. Tests covering all four fixes are added under `test/unit_test/rag/llm/test_embedding_model.py`. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -30,7 +30,6 @@ from zhipuai import ZhipuAI
|
||||
|
||||
from common import settings
|
||||
from common.exceptions import ModelException
|
||||
from common.log_utils import log_exception
|
||||
from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response
|
||||
from rag.llm.key_utils import _normalize_replicate_key
|
||||
import logging
|
||||
@@ -38,6 +37,28 @@ import base64
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Standard token ceiling for the common 8K-context embedding models (OpenAI
|
||||
# text-embedding-*, Mistral, Bedrock Titan, ...). Inputs are truncated to this
|
||||
# many tokens so boundary-sized chunks are not rejected by the provider.
|
||||
DEFAULT_MAX_TOKENS = 8192
|
||||
|
||||
|
||||
class EmbeddingError(ModelException):
|
||||
"""Raised when an embedding provider fails to return usable embeddings.
|
||||
|
||||
A single, deterministic exception type for every provider failure path so
|
||||
callers see consistent behaviour regardless of which SDK raised underneath.
|
||||
Subclasses ``ModelException`` so the API error handler (and its retry
|
||||
semantics) treats embedding failures like any other model failure.
|
||||
"""
|
||||
|
||||
|
||||
def _sorted_by_index(items):
|
||||
"""Order OpenAI-style SDK embedding items by their ``.index`` so batched
|
||||
results stay aligned with input order even if the provider returns them out
|
||||
of order. Stable no-op when items carry no ``index`` attribute."""
|
||||
return sorted(items, key=lambda d: getattr(d, "index", 0))
|
||||
|
||||
|
||||
def _raise_model_exception_if_failed(resp):
|
||||
status_code = resp.status_code
|
||||
@@ -91,8 +112,7 @@ def _dashscope_native_http_api_url(base_url: str | None) -> str | None:
|
||||
)
|
||||
return resolved
|
||||
logger.warning(
|
||||
"DashScope Tongyi-Qianwen embedding: base_url is set but not recognized as a DashScope host; "
|
||||
"using SDK default endpoint (%s)",
|
||||
"DashScope Tongyi-Qianwen embedding: base_url is set but not recognized as a DashScope host; using SDK default endpoint (%s)",
|
||||
safe,
|
||||
)
|
||||
return None
|
||||
@@ -130,6 +150,66 @@ class Base(ABC):
|
||||
def encode_queries(self, text: str):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
def _batched_encode(self, texts: list, call_fn, *, batch_size: int, truncate_to: int | None = None):
|
||||
"""Drive an embedding provider over ``texts`` in batches.
|
||||
|
||||
This is the shared template behind the OpenAI-style providers. It owns:
|
||||
|
||||
* optional per-text truncation to ``truncate_to`` tokens (skipped when
|
||||
``None``) so oversized inputs do not get rejected by the provider;
|
||||
* the batch loop, issuing ``ceil(len(texts) / batch_size)`` calls;
|
||||
* accumulation of the per-text vectors into a single ``np.ndarray``;
|
||||
* summation of the per-batch token counts;
|
||||
* one deterministic, informative error path.
|
||||
|
||||
``call_fn`` is a provider-supplied closure ``call_fn(batch) ->
|
||||
(embeddings, token_count)``. It performs the SDK/HTTP request *and*
|
||||
parses the response (so a malformed/error response surfaces here), and
|
||||
must not assume any particular response shape — the helper never touches
|
||||
the raw response object. ``embeddings`` is a sequence of per-text
|
||||
vectors; ``token_count`` is the real token usage for that batch.
|
||||
|
||||
Any exception raised by ``call_fn`` is wrapped in a single
|
||||
:class:`EmbeddingError` that includes the underlying detail. We log and
|
||||
raise here directly instead of relying on ``log_exception``'s implicit
|
||||
raise (whose surfaced exception varies by SDK response shape).
|
||||
"""
|
||||
if truncate_to is not None:
|
||||
texts = [truncate(t, truncate_to) for t in texts]
|
||||
vectors = []
|
||||
token_count = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
try:
|
||||
embeddings, tokens = call_fn(batch)
|
||||
except ModelException:
|
||||
# Already a structured (and possibly retryable) model error; keep it.
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("%s embedding request failed", type(self).__name__)
|
||||
raise EmbeddingError(f"Embedding request failed for {type(self).__name__}. Error: {e}") from e
|
||||
vectors.extend(embeddings)
|
||||
token_count += tokens
|
||||
return np.array(vectors), token_count
|
||||
|
||||
@staticmethod
|
||||
def _openai_http_embeddings(response):
|
||||
"""Parse an OpenAI-compatible HTTP embeddings ``requests`` response.
|
||||
|
||||
Returns ``(embeddings, token_count)``. Raises a retryable-aware
|
||||
:class:`ModelException` on a bad HTTP status, or surfaces the response
|
||||
body (via :class:`EmbeddingError`) when the payload is not a successful
|
||||
``{"data": [...]}`` response.
|
||||
"""
|
||||
_raise_model_exception_if_failed(response)
|
||||
res = response.json()
|
||||
if not isinstance(res, dict) or "data" not in res:
|
||||
raise ValueError(f"unexpected embeddings response (status {getattr(response, 'status_code', '?')}): {res}")
|
||||
# Keep results aligned with input order: OpenAI-compatible responses carry
|
||||
# a per-item `index`; sorting by it is a no-op (stable) when it is absent.
|
||||
data = sorted(res["data"], key=lambda d: d.get("index", 0))
|
||||
return [d["embedding"] for d in data], total_token_count_from_response(res)
|
||||
|
||||
|
||||
class BuiltinEmbed(Base):
|
||||
_FACTORY_NAME = "Builtin"
|
||||
@@ -176,35 +256,17 @@ class OpenAIEmbed(Base):
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
|
||||
def _call(self, batch):
|
||||
res = self.client.embeddings.create(input=batch, model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
|
||||
return [d.embedding for d in _sorted_by_index(res.data)], total_token_count_from_response(res)
|
||||
|
||||
def encode(self, texts: list):
|
||||
# OpenAI requires batch size <=16
|
||||
batch_size = 16
|
||||
texts = [truncate(t, 8191) for t in texts]
|
||||
ress = []
|
||||
total_tokens = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
try:
|
||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
|
||||
except Exception as _e:
|
||||
raise ModelException(f"Error: {_e}")
|
||||
try:
|
||||
ress.extend([d.embedding for d in res.data])
|
||||
total_tokens += total_token_count_from_response(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
raise Exception(f"Error: {res}")
|
||||
return np.array(ress), total_tokens
|
||||
# OpenAI requires batch size <=16; 8191 is the documented per-input token ceiling.
|
||||
return self._batched_encode(texts, self._call, batch_size=16, truncate_to=8191)
|
||||
|
||||
def encode_queries(self, text):
|
||||
try:
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
|
||||
except Exception as _e:
|
||||
raise ModelException(f"Error: {_e}")
|
||||
try:
|
||||
return np.array(res.data[0].embedding), total_token_count_from_response(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
raise Exception(f"Error: {res}")
|
||||
vectors, token_count = self._batched_encode([text], self._call, batch_size=16, truncate_to=8191)
|
||||
return vectors[0], token_count
|
||||
|
||||
|
||||
class LocalAIEmbed(Base):
|
||||
@@ -217,22 +279,21 @@ class LocalAIEmbed(Base):
|
||||
self.client = OpenAI(api_key="empty", base_url=base_url)
|
||||
self.model_name = model_name.split("___")[0]
|
||||
|
||||
def _call(self, batch):
|
||||
res = self.client.embeddings.create(input=batch, model=self.model_name)
|
||||
# Local servers (LocalAI / LM Studio) usually omit usage data; fall back
|
||||
# to a local tiktoken count rather than fabricating a fixed number.
|
||||
tokens = total_token_count_from_response(res)
|
||||
if not tokens:
|
||||
tokens = sum(num_tokens_from_string(t) for t in batch)
|
||||
return [d.embedding for d in _sorted_by_index(res.data)], tokens
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
|
||||
try:
|
||||
ress.extend([d.embedding for d in res.data])
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
raise Exception(f"Error: {res}")
|
||||
# local embedding for LmStudio donot count tokens
|
||||
return np.array(ress), 1024
|
||||
return self._batched_encode(texts, self._call, batch_size=16)
|
||||
|
||||
def encode_queries(self, text):
|
||||
embds, cnt = self.encode([text])
|
||||
return np.array(embds[0]), cnt
|
||||
vectors, token_count = self._batched_encode([text], self._call, batch_size=16)
|
||||
return vectors[0], token_count
|
||||
|
||||
|
||||
def _resolve_azure_credentials(key):
|
||||
@@ -323,23 +384,21 @@ class QWenEmbed(Base):
|
||||
token_count = 0
|
||||
texts = [truncate(t, 2048) for t in texts]
|
||||
for i in range(0, len(texts), batch_size):
|
||||
|
||||
retry_max, retry_wait_secs = 5, 10
|
||||
for retry in range(retry_max):
|
||||
with _dashscope_native_api_url_scope(self._dashscope_http_api_url):
|
||||
resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
|
||||
status_code = resp.status_code
|
||||
if status_code >= 400 and status_code < 500 and status_code not in [408, 429]:
|
||||
raise ModelException(f"Error, status: {status_code}, response: {resp}")
|
||||
# No need to retry for 4XX error
|
||||
raise ModelException(f"Error, status: {status_code}, response: {resp}")
|
||||
if status_code == 200:
|
||||
break
|
||||
if retry < retry_max - 1:
|
||||
logging.warning(f"Got error response from DashScope API (status: {status_code}, response: {resp}). Wait {retry_wait_secs} seconds. Retrying...")
|
||||
time.sleep(retry_wait_secs)
|
||||
else:
|
||||
raise ModelException(f"Error after {retry_max} retries., status: {status_code}, response: {resp}")
|
||||
|
||||
raise ModelException(f"Error after {retry_max} retries, status: {status_code}, response: {resp}")
|
||||
try:
|
||||
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
|
||||
for e in resp["output"]["embeddings"]:
|
||||
@@ -347,8 +406,8 @@ class QWenEmbed(Base):
|
||||
res.extend(embds)
|
||||
token_count += total_token_count_from_response(resp)
|
||||
except Exception as _e:
|
||||
log_exception(_e, resp)
|
||||
raise ModelException(f"Error: {status_code}: {resp}")
|
||||
logger.exception("QWenEmbed: failed to parse embedding response")
|
||||
raise EmbeddingError(f"Embedding request failed for QWenEmbed. Error: {_e}; response={resp}") from _e
|
||||
return np.array(res), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
@@ -361,8 +420,8 @@ class QWenEmbed(Base):
|
||||
try:
|
||||
return np.array(resp["output"]["embeddings"][0]["embedding"]), total_token_count_from_response(resp)
|
||||
except Exception as _e:
|
||||
log_exception(_e, resp)
|
||||
raise ModelException(f"Error: {status_code}: {resp}")
|
||||
logger.exception("QWenEmbed: failed to parse query embedding response")
|
||||
raise EmbeddingError(f"Embedding request failed for QWenEmbed. Error: {_e}; response={resp}") from _e
|
||||
|
||||
|
||||
class ZhipuEmbed(Base):
|
||||
@@ -372,34 +431,28 @@ class ZhipuEmbed(Base):
|
||||
self.client = ZhipuAI(api_key=key)
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
arr = []
|
||||
tks_num = 0
|
||||
MAX_LEN = -1
|
||||
def _max_len(self):
|
||||
# Per-model input ceilings; fall back to the standard 8K limit for any
|
||||
# other model rather than leaving oversized inputs untruncated.
|
||||
if self.model_name.lower() == "embedding-2":
|
||||
MAX_LEN = 512
|
||||
return 512
|
||||
if self.model_name.lower() == "embedding-3":
|
||||
MAX_LEN = 3072
|
||||
if MAX_LEN > 0:
|
||||
texts = [truncate(t, MAX_LEN) for t in texts]
|
||||
return 3072
|
||||
return DEFAULT_MAX_TOKENS
|
||||
|
||||
for txt in texts:
|
||||
res = self.client.embeddings.create(input=txt, model=self.model_name)
|
||||
try:
|
||||
arr.append(res.data[0].embedding)
|
||||
tks_num += total_token_count_from_response(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
raise Exception(f"Error: {res}")
|
||||
return np.array(arr), tks_num
|
||||
def _call(self, batch):
|
||||
# Batch like the other OpenAI-style providers: one request per batch
|
||||
# instead of one request per text. Sort by index so the batched results
|
||||
# stay aligned with input order.
|
||||
res = self.client.embeddings.create(input=batch, model=self.model_name)
|
||||
return [d.embedding for d in _sorted_by_index(res.data)], total_token_count_from_response(res)
|
||||
|
||||
def encode(self, texts: list):
|
||||
return self._batched_encode(texts, self._call, batch_size=16, truncate_to=self._max_len())
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embeddings.create(input=text, model=self.model_name)
|
||||
try:
|
||||
return np.array(res.data[0].embedding), total_token_count_from_response(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
raise Exception(f"Error: {res}")
|
||||
vectors, token_count = self._batched_encode([text], self._call, batch_size=16, truncate_to=self._max_len())
|
||||
return vectors[0], token_count
|
||||
|
||||
|
||||
class OllamaEmbed(Base):
|
||||
@@ -412,32 +465,33 @@ class OllamaEmbed(Base):
|
||||
self.model_name = model_name
|
||||
self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
|
||||
|
||||
@classmethod
|
||||
def _strip_special(cls, text: str) -> str:
|
||||
for token in cls._special_tokens:
|
||||
text = text.replace(token, "")
|
||||
return text
|
||||
|
||||
def _call(self, batch):
|
||||
# Batch via client.embed (accepts a list `input`) instead of one
|
||||
# client.embeddings request per text. `truncate=True` lets Ollama clip
|
||||
# oversized inputs to the model's real context length server-side, which
|
||||
# is more accurate than a client-side cl100k estimate.
|
||||
cleaned = [self._strip_special(t) for t in batch]
|
||||
res = self.client.embed(model=self.model_name, input=cleaned, truncate=True, options={"use_mmap": True}, keep_alive=self.keep_alive)
|
||||
# Ollama reports real prompt token usage in `prompt_eval_count`; fall
|
||||
# back to a local count only if the server omits it (never a fixed 128).
|
||||
tokens = res.get("prompt_eval_count") or 0
|
||||
if not tokens:
|
||||
tokens = sum(num_tokens_from_string(t) for t in cleaned)
|
||||
return res["embeddings"], tokens
|
||||
|
||||
def encode(self, texts: list):
|
||||
arr = []
|
||||
tks_num = 0
|
||||
for txt in texts:
|
||||
# remove special tokens if they exist base on regex in one request
|
||||
for token in OllamaEmbed._special_tokens:
|
||||
txt = txt.replace(token, "")
|
||||
res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive)
|
||||
try:
|
||||
arr.append(res["embedding"])
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
raise Exception(f"Error: {res}")
|
||||
tks_num += 128
|
||||
return np.array(arr), tks_num
|
||||
# No client-side truncation: Ollama truncates to the model context above.
|
||||
return self._batched_encode(texts, self._call, batch_size=16)
|
||||
|
||||
def encode_queries(self, text):
|
||||
# remove special tokens if they exist
|
||||
for token in OllamaEmbed._special_tokens:
|
||||
text = text.replace(token, "")
|
||||
res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive)
|
||||
try:
|
||||
return np.array(res["embedding"]), 128
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
raise Exception(f"Error: {res}")
|
||||
vectors, token_count = self._batched_encode([text], self._call, batch_size=16)
|
||||
return vectors[0], token_count
|
||||
|
||||
|
||||
class XinferenceEmbed(Base):
|
||||
@@ -448,29 +502,16 @@ class XinferenceEmbed(Base):
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
|
||||
def _call(self, batch):
|
||||
res = self.client.embeddings.create(input=batch, model=self.model_name)
|
||||
return [d.embedding for d in _sorted_by_index(res.data)], total_token_count_from_response(res)
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
total_tokens = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = None
|
||||
try:
|
||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
|
||||
ress.extend([d.embedding for d in res.data])
|
||||
total_tokens += total_token_count_from_response(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
raise Exception(f"Error: {res}")
|
||||
return np.array(ress), total_tokens
|
||||
return self._batched_encode(texts, self._call, batch_size=16)
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = None
|
||||
try:
|
||||
res = self.client.embeddings.create(input=[text], model=self.model_name)
|
||||
return np.array(res.data[0].embedding), total_token_count_from_response(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
raise Exception(f"Error: {res}")
|
||||
vectors, token_count = self._batched_encode([text], self._call, batch_size=16)
|
||||
return vectors[0], token_count
|
||||
|
||||
|
||||
class YoudaoEmbed(Base):
|
||||
@@ -504,55 +545,42 @@ class JinaMultiVecEmbed(Base):
|
||||
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||
self.model_name = model_name
|
||||
|
||||
@staticmethod
|
||||
def _as_input_item(text):
|
||||
if isinstance(text, str):
|
||||
return {"text": text}
|
||||
# bytes -> base64 encoded image
|
||||
try:
|
||||
base64.b64decode(text, validate=True)
|
||||
return {"image": text.decode("utf8")}
|
||||
except Exception:
|
||||
return {"image": base64.b64encode(text).decode("utf8")}
|
||||
|
||||
def encode(self, texts: list[str | bytes], task="retrieval.passage"):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
token_count = 0
|
||||
input = []
|
||||
for text in texts:
|
||||
if isinstance(text, str):
|
||||
input.append({"text": text})
|
||||
elif isinstance(text, bytes):
|
||||
img_b64s = None
|
||||
try:
|
||||
base64.b64decode(text, validate=True)
|
||||
img_b64s = text.decode("utf8")
|
||||
except Exception:
|
||||
img_b64s = base64.b64encode(text).decode("utf8")
|
||||
input.append({"image": img_b64s}) # base64 encoded image
|
||||
for i in range(0, len(texts), batch_size):
|
||||
data = {"model": self.model_name, "input": input[i : i + batch_size]}
|
||||
def _call(batch):
|
||||
data = {"model": self.model_name, "input": [self._as_input_item(t) for t in batch]}
|
||||
if "v4" in self.model_name:
|
||||
data["return_multivector"] = True
|
||||
|
||||
if "v3" in self.model_name or "v4" in self.model_name:
|
||||
data["task"] = task
|
||||
data["truncate"] = True
|
||||
|
||||
data["truncate"] = True # let Jina truncate oversized inputs server-side
|
||||
response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30)
|
||||
_raise_model_exception_if_failed(response)
|
||||
try:
|
||||
res = response.json()
|
||||
for d in res["data"]:
|
||||
if data.get("return_multivector", False): # v4
|
||||
token_embs = np.asarray(d["embeddings"], dtype=np.float32)
|
||||
chunk_emb = token_embs.mean(axis=0)
|
||||
res = response.json()
|
||||
embs = []
|
||||
for d in res["data"]:
|
||||
if data.get("return_multivector", False): # v4
|
||||
embs.append(np.asarray(d["embeddings"], dtype=np.float32).mean(axis=0))
|
||||
else: # v2/v3
|
||||
embs.append(np.asarray(d["embedding"], dtype=np.float32))
|
||||
return embs, total_token_count_from_response(res)
|
||||
|
||||
else:
|
||||
# v2/v3
|
||||
chunk_emb = np.asarray(d["embedding"], dtype=np.float32)
|
||||
|
||||
ress.append(chunk_emb)
|
||||
|
||||
token_count += total_token_count_from_response(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
raise Exception(f"Error: {response}")
|
||||
return np.array(ress), token_count
|
||||
# Inputs may be image bytes, so token truncation is left to the server.
|
||||
return self._batched_encode(texts, _call, batch_size=16)
|
||||
|
||||
def encode_queries(self, text):
|
||||
embds, cnt = self.encode([text], task="retrieval.query")
|
||||
return np.array(embds[0]), cnt
|
||||
vectors, token_count = self.encode([text], task="retrieval.query")
|
||||
return vectors[0], token_count
|
||||
|
||||
|
||||
class MistralEmbed(Base):
|
||||
@@ -568,7 +596,7 @@ class MistralEmbed(Base):
|
||||
import time
|
||||
import random
|
||||
|
||||
texts = [truncate(t, 8196) for t in texts]
|
||||
texts = [truncate(t, DEFAULT_MAX_TOKENS) for t in texts]
|
||||
batch_size = 16
|
||||
ress = []
|
||||
token_count = 0
|
||||
@@ -582,7 +610,8 @@ class MistralEmbed(Base):
|
||||
break
|
||||
except Exception as _e:
|
||||
if retry_max == 1:
|
||||
log_exception(_e)
|
||||
logger.exception("MistralEmbed: embedding request failed after retries")
|
||||
raise EmbeddingError(f"Embedding request failed for MistralEmbed. Error: {_e}") from _e
|
||||
delay = random.uniform(20, 60)
|
||||
time.sleep(delay)
|
||||
retry_max -= 1
|
||||
@@ -595,11 +624,12 @@ class MistralEmbed(Base):
|
||||
retry_max = 5
|
||||
while retry_max > 0:
|
||||
try:
|
||||
res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name)
|
||||
res = self.client.embeddings(input=[truncate(text, DEFAULT_MAX_TOKENS)], model=self.model_name)
|
||||
return np.array(res.data[0].embedding), total_token_count_from_response(res)
|
||||
except Exception as _e:
|
||||
if retry_max == 1:
|
||||
log_exception(_e)
|
||||
logger.exception("MistralEmbed: query embedding request failed after retries")
|
||||
raise EmbeddingError(f"Embedding request failed for MistralEmbed. Error: {_e}") from _e
|
||||
delay = random.randint(20, 60)
|
||||
time.sleep(delay)
|
||||
retry_max -= 1
|
||||
@@ -649,42 +679,41 @@ class BedrockEmbed(Base):
|
||||
else: # assume_role
|
||||
self.client = boto3.client("bedrock-runtime", region_name=self.bedrock_region)
|
||||
|
||||
def _extract_vector(self, model_response):
|
||||
# Titan returns {"embedding": [...]}; Cohere returns {"embeddings": [[...]]}.
|
||||
if self.is_cohere:
|
||||
return model_response["embeddings"][0]
|
||||
return model_response["embedding"]
|
||||
|
||||
def encode(self, texts: list):
|
||||
texts = [truncate(t, 8196) for t in texts]
|
||||
embeddings = []
|
||||
token_count = 0
|
||||
for text in texts:
|
||||
def _call(batch):
|
||||
# Titan accepts a single input per call, so batch_size is 1.
|
||||
text = batch[0]
|
||||
if self.is_amazon:
|
||||
body = {"inputText": text}
|
||||
elif self.is_cohere:
|
||||
body = {"texts": [text], "input_type": "search_document"}
|
||||
|
||||
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
||||
try:
|
||||
model_response = json.loads(response["body"].read())
|
||||
embeddings.extend([model_response["embedding"]])
|
||||
token_count += num_tokens_from_string(text)
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
model_response = json.loads(response["body"].read())
|
||||
# Bedrock does not report token usage; count locally.
|
||||
return [self._extract_vector(model_response)], num_tokens_from_string(text)
|
||||
|
||||
return np.array(embeddings), token_count
|
||||
return self._batched_encode(texts, _call, batch_size=1, truncate_to=DEFAULT_MAX_TOKENS)
|
||||
|
||||
def encode_queries(self, text):
|
||||
embeddings = []
|
||||
text = truncate(text, DEFAULT_MAX_TOKENS)
|
||||
token_count = num_tokens_from_string(text)
|
||||
if self.is_amazon:
|
||||
body = {"inputText": truncate(text, 8196)}
|
||||
body = {"inputText": text}
|
||||
elif self.is_cohere:
|
||||
body = {"texts": [truncate(text, 8196)], "input_type": "search_query"}
|
||||
|
||||
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
||||
body = {"texts": [text], "input_type": "search_query"}
|
||||
try:
|
||||
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
||||
model_response = json.loads(response["body"].read())
|
||||
embeddings.extend(model_response["embedding"])
|
||||
return np.array(self._extract_vector(model_response)), token_count
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
|
||||
return np.array(embeddings), token_count
|
||||
logger.exception("BedrockEmbed: query embedding request failed")
|
||||
raise EmbeddingError(f"Embedding request failed for BedrockEmbed. Error: {_e}") from _e
|
||||
|
||||
|
||||
class GeminiEmbed(Base):
|
||||
@@ -741,28 +770,17 @@ class GeminiEmbed(Base):
|
||||
return self.types.EmbedContentConfig(task_type=task_type)
|
||||
|
||||
def encode(self, texts: list):
|
||||
texts = [truncate(t, 2048) for t in texts]
|
||||
token_count = sum(num_tokens_from_string(text) for text in texts)
|
||||
config = self._build_embedding_config()
|
||||
batch_size = 16
|
||||
ress = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
result = None
|
||||
try:
|
||||
result = self.client.models.embed_content(
|
||||
model=self.model_name,
|
||||
contents=texts[i : i + batch_size],
|
||||
config=config,
|
||||
)
|
||||
ress.extend(self._parse_embedding_response(result))
|
||||
except Exception as _e:
|
||||
log_exception(_e, result)
|
||||
raise Exception(f"Error: {result}")
|
||||
return np.array(ress), token_count
|
||||
|
||||
def _call(batch):
|
||||
result = self.client.models.embed_content(model=self.model_name, contents=batch, config=config)
|
||||
# Gemini embeddings do not report token usage; count locally.
|
||||
return self._parse_embedding_response(result), sum(num_tokens_from_string(t) for t in batch)
|
||||
|
||||
return self._batched_encode(texts, _call, batch_size=16, truncate_to=2048)
|
||||
|
||||
def encode_queries(self, text):
|
||||
config = self._build_embedding_config()
|
||||
result = None
|
||||
token_count = num_tokens_from_string(text)
|
||||
try:
|
||||
result = self.client.models.embed_content(
|
||||
@@ -772,8 +790,8 @@ class GeminiEmbed(Base):
|
||||
)
|
||||
return np.array(self._parse_embedding_response(result)[0]), token_count
|
||||
except Exception as _e:
|
||||
log_exception(_e, result)
|
||||
raise Exception(f"Error: {result}")
|
||||
logger.exception("GeminiEmbed: query embedding request failed")
|
||||
raise EmbeddingError(f"Embedding request failed for GeminiEmbed. Error: {_e}") from _e
|
||||
|
||||
|
||||
class NvidiaEmbed(Base):
|
||||
@@ -796,32 +814,24 @@ class NvidiaEmbed(Base):
|
||||
if model_name == "snowflake/arctic-embed-l":
|
||||
self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"
|
||||
|
||||
def _call(self, batch, input_type="query"):
|
||||
payload = {
|
||||
"input": batch,
|
||||
"input_type": input_type,
|
||||
"model": self.model_name,
|
||||
"encoding_format": "float",
|
||||
"truncate": "END", # NVIDIA truncates oversized inputs server-side.
|
||||
}
|
||||
response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30)
|
||||
return self._openai_http_embeddings(response)
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
token_count = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
payload = {
|
||||
"input": texts[i : i + batch_size],
|
||||
"input_type": "query",
|
||||
"model": self.model_name,
|
||||
"encoding_format": "float",
|
||||
"truncate": "END",
|
||||
}
|
||||
response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30)
|
||||
_raise_model_exception_if_failed(response)
|
||||
try:
|
||||
res = response.json()
|
||||
ress.extend([d["embedding"] for d in res["data"]])
|
||||
token_count += total_token_count_from_response(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
raise Exception(f"Error: {response}")
|
||||
return np.array(ress), token_count
|
||||
# NVIDIA NIM expects "passage" for documents (indexing) and "query" for retrieval.
|
||||
return self._batched_encode(texts, lambda b: self._call(b, "passage"), batch_size=16)
|
||||
|
||||
def encode_queries(self, text):
|
||||
embds, cnt = self.encode([text])
|
||||
return np.array(embds[0]), cnt
|
||||
vectors, token_count = self._batched_encode([text], lambda b: self._call(b, "query"), batch_size=16)
|
||||
return vectors[0], token_count
|
||||
|
||||
|
||||
class LmStudioEmbed(LocalAIEmbed):
|
||||
@@ -855,37 +865,32 @@ class CoHereEmbed(Base):
|
||||
self.client = Client(api_key=key)
|
||||
self.model_name = model_name
|
||||
|
||||
def _call(self, batch):
|
||||
res = self.client.embed(
|
||||
texts=batch,
|
||||
model=self.model_name,
|
||||
input_type="search_document",
|
||||
embedding_types=["float"],
|
||||
truncate="END", # let Cohere clip oversized inputs server-side instead of hard-failing
|
||||
)
|
||||
return list(res.embeddings.float), total_token_count_from_response(res)
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
token_count = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embed(
|
||||
texts=texts[i : i + batch_size],
|
||||
model=self.model_name,
|
||||
input_type="search_document",
|
||||
embedding_types=["float"],
|
||||
)
|
||||
try:
|
||||
ress.extend([d for d in res.embeddings.float])
|
||||
token_count += total_token_count_from_response(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
raise Exception(f"Error: {res}")
|
||||
return np.array(ress), token_count
|
||||
return self._batched_encode(texts, self._call, batch_size=16)
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embed(
|
||||
texts=[text],
|
||||
model=self.model_name,
|
||||
input_type="search_query",
|
||||
embedding_types=["float"],
|
||||
)
|
||||
try:
|
||||
res = self.client.embed(
|
||||
texts=[text],
|
||||
model=self.model_name,
|
||||
input_type="search_query",
|
||||
embedding_types=["float"],
|
||||
truncate="END",
|
||||
)
|
||||
return np.array(res.embeddings.float[0]), int(total_token_count_from_response(res))
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
raise Exception(f"Error: {res}")
|
||||
logger.exception("CoHereEmbed: query embedding request failed")
|
||||
raise EmbeddingError(f"Embedding request failed for CoHereEmbed. Error: {_e}") from _e
|
||||
|
||||
|
||||
class TogetherAIEmbed(OpenAIEmbed):
|
||||
@@ -932,49 +937,27 @@ class SILICONFLOWEmbed(Base):
|
||||
self.base_url = normalized_base_url
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
token_count = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
texts_batch = texts[i : i + batch_size]
|
||||
if self.model_name in ["BAAI/bge-large-zh-v1.5", "BAAI/bge-large-en-v1.5"]:
|
||||
# limit 512, 340 is almost safe
|
||||
texts_batch = [" " if not text.strip() else truncate(text, 256) for text in texts_batch]
|
||||
else:
|
||||
texts_batch = [" " if not text.strip() else text for text in texts_batch]
|
||||
def _clean_batch(self, batch):
|
||||
if self.model_name in ["BAAI/bge-large-zh-v1.5", "BAAI/bge-large-en-v1.5"]:
|
||||
# limit 512, 340 is almost safe
|
||||
return [" " if not text.strip() else truncate(text, 256) for text in batch]
|
||||
return [" " if not text.strip() else text for text in batch]
|
||||
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"input": texts_batch,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30)
|
||||
_raise_model_exception_if_failed(response)
|
||||
try:
|
||||
res = response.json()
|
||||
ress.extend([d["embedding"] for d in res["data"]])
|
||||
token_count += total_token_count_from_response(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
raise Exception(f"Error: {response}")
|
||||
|
||||
return np.array(ress), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
def _call(self, batch):
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"input": text,
|
||||
"input": self._clean_batch(batch),
|
||||
"encoding_format": "float",
|
||||
}
|
||||
response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30)
|
||||
_raise_model_exception_if_failed(response)
|
||||
try:
|
||||
res = response.json()
|
||||
return np.array(res["data"][0]["embedding"]), total_token_count_from_response(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
raise Exception(f"Error: {response}")
|
||||
return self._openai_http_embeddings(response)
|
||||
|
||||
def encode(self, texts: list):
|
||||
return self._batched_encode(texts, self._call, batch_size=16)
|
||||
|
||||
def encode_queries(self, text):
|
||||
vectors, token_count = self._batched_encode([text], self._call, batch_size=16)
|
||||
return vectors[0], token_count
|
||||
|
||||
|
||||
class ReplicateEmbed(Base):
|
||||
@@ -1013,26 +996,26 @@ class BaiduYiyanEmbed(Base):
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list, batch_size=16):
|
||||
res = self.client.do(model=self.model_name, texts=texts).body
|
||||
try:
|
||||
res = self.client.do(model=self.model_name, texts=texts).body
|
||||
return (
|
||||
np.array([r["embedding"] for r in res["data"]]),
|
||||
total_token_count_from_response(res),
|
||||
)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
raise Exception(f"Error: {res}")
|
||||
logger.exception("BaiduYiyanEmbed: embedding request failed")
|
||||
raise EmbeddingError(f"Embedding request failed for BaiduYiyanEmbed. Error: {_e}") from _e
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.do(model=self.model_name, texts=[text]).body
|
||||
try:
|
||||
res = self.client.do(model=self.model_name, texts=[text]).body
|
||||
return (
|
||||
np.array([r["embedding"] for r in res["data"]]),
|
||||
total_token_count_from_response(res),
|
||||
)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
raise Exception(f"Error: {res}")
|
||||
logger.exception("BaiduYiyanEmbed: query embedding request failed")
|
||||
raise EmbeddingError(f"Embedding request failed for BaiduYiyanEmbed. Error: {_e}") from _e
|
||||
|
||||
|
||||
class VoyageEmbed(Base):
|
||||
@@ -1044,27 +1027,22 @@ class VoyageEmbed(Base):
|
||||
self.client = voyageai.Client(api_key=key)
|
||||
self.model_name = model_name
|
||||
|
||||
def _call(self, batch):
|
||||
res = self.client.embed(texts=batch, model=self.model_name, input_type="document")
|
||||
# `_batched_encode` accumulates these per-batch vectors and returns a
|
||||
# single np.ndarray, so encode() keeps the np.ndarray contract.
|
||||
return res.embeddings, res.total_tokens
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
token_count = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embed(texts=texts[i : i + batch_size], model=self.model_name, input_type="document")
|
||||
try:
|
||||
ress.extend(res.embeddings)
|
||||
token_count += res.total_tokens
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
raise Exception(f"Error: {res}")
|
||||
return np.array(ress), token_count
|
||||
return self._batched_encode(texts, self._call, batch_size=16)
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embed(texts=text, model=self.model_name, input_type="query")
|
||||
try:
|
||||
res = self.client.embed(texts=text, model=self.model_name, input_type="query")
|
||||
return np.array(res.embeddings)[0], res.total_tokens
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
raise Exception(f"Error: {res}")
|
||||
logger.exception("VoyageEmbed: query embedding request failed")
|
||||
raise EmbeddingError(f"Embedding request failed for VoyageEmbed. Error: {_e}") from _e
|
||||
|
||||
|
||||
class HuggingFaceEmbed(Base):
|
||||
@@ -1080,14 +1058,13 @@ class HuggingFaceEmbed(Base):
|
||||
def encode(self, texts: list):
|
||||
response = requests.post(f"{self.base_url}/embed", json={"inputs": texts}, headers={"Content-Type": "application/json"}, timeout=30)
|
||||
_raise_model_exception_if_failed(response)
|
||||
embeddings = response.json()
|
||||
return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])
|
||||
# TEI auto-truncates oversized inputs, so no client-side truncation is needed.
|
||||
return np.array(response.json()), sum([num_tokens_from_string(text) for text in texts])
|
||||
|
||||
def encode_queries(self, text: str):
|
||||
response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"}, timeout=30)
|
||||
_raise_model_exception_if_failed(response)
|
||||
embedding = response.json()[0]
|
||||
return np.array(embedding), num_tokens_from_string(text)
|
||||
return np.array(response.json()[0]), num_tokens_from_string(text)
|
||||
|
||||
|
||||
class VolcEngineEmbed(Base):
|
||||
@@ -1141,13 +1118,14 @@ class VolcEngineEmbed(Base):
|
||||
request_body = {"model": self.model_name, "input": [{"type": "text", "text": text}]}
|
||||
response = sync_request(method="POST", url=url, headers=headers, json=request_body, timeout=60)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error: {response.status_code} - {response.text}")
|
||||
raise EmbeddingError(f"Embedding request failed for VolcEngineEmbed. Error: {response.status_code} - {response.text}")
|
||||
result = response.json()
|
||||
try:
|
||||
ress.append(self._extract_embedding(result))
|
||||
total_tokens += total_token_count_from_response(result)
|
||||
except Exception as _e:
|
||||
log_exception(_e)
|
||||
logger.exception("VolcEngineEmbed: failed to parse embedding response")
|
||||
raise EmbeddingError(f"Embedding request failed for VolcEngineEmbed. Error: {_e}; response={result}") from _e
|
||||
|
||||
return np.array(ress), total_tokens
|
||||
|
||||
@@ -1295,8 +1273,8 @@ class PerplexityEmbed(Base):
|
||||
ress.append(self._decode_base64_int8(chunk_emb["embedding"]))
|
||||
token_count += res.get("usage", {}).get("total_tokens", 0)
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
raise Exception(f"Error: {response.text}")
|
||||
logger.exception("PerplexityEmbed: failed to parse contextualized embedding response")
|
||||
raise EmbeddingError(f"Embedding request failed for PerplexityEmbed. Error: {response.text}") from _e
|
||||
else:
|
||||
url = f"{self.base_url}/v1/embeddings"
|
||||
for i in range(0, len(texts), batch_size):
|
||||
@@ -1314,8 +1292,8 @@ class PerplexityEmbed(Base):
|
||||
ress.append(self._decode_base64_int8(d["embedding"]))
|
||||
token_count += res.get("usage", {}).get("total_tokens", 0)
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
raise Exception(f"Error: {response.text}")
|
||||
logger.exception("PerplexityEmbed: failed to parse embedding response")
|
||||
raise EmbeddingError(f"Embedding request failed for PerplexityEmbed. Error: {response.text}") from _e
|
||||
|
||||
return np.array(ress), token_count
|
||||
|
||||
|
||||
Reference in New Issue
Block a user