From bde2b1fc6d4d8515adda8e76e32c4aef5f380e27 Mon Sep 17 00:00:00 2001 From: Dexterity <173429049+Dexterity104@users.noreply.github.com> Date: Thu, 11 Jun 2026 07:29:46 -0400 Subject: [PATCH] fix(llm): correct error handling, token accounting, and truncation in embedding providers (#15424) ### Summary Closes #15423 `rag/llm/embedding_model.py` hosts about 40 embedding providers that shared several defects affecting indexing reliability, cost accounting, and error visibility. This PR fixes four concrete bugs. **Masked, inconsistent errors (27 sites).** Nearly every provider ran `log_exception(_e, res)` followed by `raise Exception(f"Error: {res}")`. Because `log_exception` always raises, the second line was dead code, and the surfaced exception varied with whether the SDK response exposed a `.text` attribute. Every failure path now raises a single `EmbeddingError` that includes the underlying response detail, so the cause of a failed embedding is consistent and visible. **Fabricated token counts.** `LocalAIEmbed` returned a hardcoded `1024` and `OllamaEmbed` added `128` per text. These values feed `used_tokens` and therefore billing and usage tracking. Both now report the real count from the API (Ollama `prompt_eval_count`, LocalAI `usage`) and fall back to a local token count only when the server omits it. **Truncation overshoot.** The `8196` limit used by Mistral and Bedrock exceeded the standard `8192` ceiling and could push boundary sized inputs past the model limit. Limits are corrected to `8192` and made intentional per provider, and providers that rely on server side truncation now request it explicitly (Ollama `truncate=True`, Cohere `truncate="END"`). **Missing batching on Zhipu and Ollama.** Both issued one request per text. They now batch like the other OpenAI compatible providers, turning N round trips into `ceil(N / batch_size)`. Batched results are realigned by response `index` so a chunk always keeps its own vector. A shared `Base._batched_encode` helper owns the batch loop, optional truncation, result accumulation, and the single error path. It is the mechanism that lets these fixes live in one place instead of across 27 duplicated sites. The public `encode()` and `encode_queries()` contract stays the same, so existing callers are unaffected. Tests covering all four fixes are added under `test/unit_test/rag/llm/test_embedding_model.py`. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- download_deps.py | 6 + rag/llm/embedding_model.py | 610 +++++++++--------- test/unit_test/conftest.py | 50 ++ .../unit_test/rag/llm/test_embedding_model.py | 383 +++++++++++ 4 files changed, 733 insertions(+), 316 deletions(-) create mode 100644 test/unit_test/conftest.py create mode 100644 test/unit_test/rag/llm/test_embedding_model.py diff --git a/download_deps.py b/download_deps.py index b707e03622..df29eaac91 100644 --- a/download_deps.py +++ b/download_deps.py @@ -64,6 +64,12 @@ if __name__ == "__main__": urls = get_urls(args.china_mirrors) + # Some mirrors (e.g. archive.ubuntu.com) reject the default urllib + # User-Agent with HTTP 403, so install an opener with a browser-like UA. + opener = urllib.request.build_opener() + opener.addheaders = [("User-Agent", "Mozilla/5.0")] + urllib.request.install_opener(opener) + for url in urls: download_url = url[0] if isinstance(url, list) else url filename = url[1] if isinstance(url, list) else url.split("/")[-1] diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 616382cf72..c8ed17b34a 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -30,7 +30,6 @@ from zhipuai import ZhipuAI from common import settings from common.exceptions import ModelException -from common.log_utils import log_exception from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response from rag.llm.key_utils import _normalize_replicate_key import logging @@ -38,6 +37,28 @@ import base64 logger = logging.getLogger(__name__) +# Standard token ceiling for the common 8K-context embedding models (OpenAI +# text-embedding-*, Mistral, Bedrock Titan, ...). Inputs are truncated to this +# many tokens so boundary-sized chunks are not rejected by the provider. +DEFAULT_MAX_TOKENS = 8192 + + +class EmbeddingError(ModelException): + """Raised when an embedding provider fails to return usable embeddings. + + A single, deterministic exception type for every provider failure path so + callers see consistent behaviour regardless of which SDK raised underneath. + Subclasses ``ModelException`` so the API error handler (and its retry + semantics) treats embedding failures like any other model failure. + """ + + +def _sorted_by_index(items): + """Order OpenAI-style SDK embedding items by their ``.index`` so batched + results stay aligned with input order even if the provider returns them out + of order. Stable no-op when items carry no ``index`` attribute.""" + return sorted(items, key=lambda d: getattr(d, "index", 0)) + def _raise_model_exception_if_failed(resp): status_code = resp.status_code @@ -91,8 +112,7 @@ def _dashscope_native_http_api_url(base_url: str | None) -> str | None: ) return resolved logger.warning( - "DashScope Tongyi-Qianwen embedding: base_url is set but not recognized as a DashScope host; " - "using SDK default endpoint (%s)", + "DashScope Tongyi-Qianwen embedding: base_url is set but not recognized as a DashScope host; using SDK default endpoint (%s)", safe, ) return None @@ -130,6 +150,66 @@ class Base(ABC): def encode_queries(self, text: str): raise NotImplementedError("Please implement encode method!") + def _batched_encode(self, texts: list, call_fn, *, batch_size: int, truncate_to: int | None = None): + """Drive an embedding provider over ``texts`` in batches. + + This is the shared template behind the OpenAI-style providers. It owns: + + * optional per-text truncation to ``truncate_to`` tokens (skipped when + ``None``) so oversized inputs do not get rejected by the provider; + * the batch loop, issuing ``ceil(len(texts) / batch_size)`` calls; + * accumulation of the per-text vectors into a single ``np.ndarray``; + * summation of the per-batch token counts; + * one deterministic, informative error path. + + ``call_fn`` is a provider-supplied closure ``call_fn(batch) -> + (embeddings, token_count)``. It performs the SDK/HTTP request *and* + parses the response (so a malformed/error response surfaces here), and + must not assume any particular response shape — the helper never touches + the raw response object. ``embeddings`` is a sequence of per-text + vectors; ``token_count`` is the real token usage for that batch. + + Any exception raised by ``call_fn`` is wrapped in a single + :class:`EmbeddingError` that includes the underlying detail. We log and + raise here directly instead of relying on ``log_exception``'s implicit + raise (whose surfaced exception varies by SDK response shape). + """ + if truncate_to is not None: + texts = [truncate(t, truncate_to) for t in texts] + vectors = [] + token_count = 0 + for i in range(0, len(texts), batch_size): + batch = texts[i : i + batch_size] + try: + embeddings, tokens = call_fn(batch) + except ModelException: + # Already a structured (and possibly retryable) model error; keep it. + raise + except Exception as e: + logger.exception("%s embedding request failed", type(self).__name__) + raise EmbeddingError(f"Embedding request failed for {type(self).__name__}. Error: {e}") from e + vectors.extend(embeddings) + token_count += tokens + return np.array(vectors), token_count + + @staticmethod + def _openai_http_embeddings(response): + """Parse an OpenAI-compatible HTTP embeddings ``requests`` response. + + Returns ``(embeddings, token_count)``. Raises a retryable-aware + :class:`ModelException` on a bad HTTP status, or surfaces the response + body (via :class:`EmbeddingError`) when the payload is not a successful + ``{"data": [...]}`` response. + """ + _raise_model_exception_if_failed(response) + res = response.json() + if not isinstance(res, dict) or "data" not in res: + raise ValueError(f"unexpected embeddings response (status {getattr(response, 'status_code', '?')}): {res}") + # Keep results aligned with input order: OpenAI-compatible responses carry + # a per-item `index`; sorting by it is a no-op (stable) when it is absent. + data = sorted(res["data"], key=lambda d: d.get("index", 0)) + return [d["embedding"] for d in data], total_token_count_from_response(res) + class BuiltinEmbed(Base): _FACTORY_NAME = "Builtin" @@ -176,35 +256,17 @@ class OpenAIEmbed(Base): self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name + def _call(self, batch): + res = self.client.embeddings.create(input=batch, model=self.model_name, encoding_format="float", extra_body={"drop_params": True}) + return [d.embedding for d in _sorted_by_index(res.data)], total_token_count_from_response(res) + def encode(self, texts: list): - # OpenAI requires batch size <=16 - batch_size = 16 - texts = [truncate(t, 8191) for t in texts] - ress = [] - total_tokens = 0 - for i in range(0, len(texts), batch_size): - try: - res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True}) - except Exception as _e: - raise ModelException(f"Error: {_e}") - try: - ress.extend([d.embedding for d in res.data]) - total_tokens += total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") - return np.array(ress), total_tokens + # OpenAI requires batch size <=16; 8191 is the documented per-input token ceiling. + return self._batched_encode(texts, self._call, batch_size=16, truncate_to=8191) def encode_queries(self, text): - try: - res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float", extra_body={"drop_params": True}) - except Exception as _e: - raise ModelException(f"Error: {_e}") - try: - return np.array(res.data[0].embedding), total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") + vectors, token_count = self._batched_encode([text], self._call, batch_size=16, truncate_to=8191) + return vectors[0], token_count class LocalAIEmbed(Base): @@ -217,22 +279,21 @@ class LocalAIEmbed(Base): self.client = OpenAI(api_key="empty", base_url=base_url) self.model_name = model_name.split("___")[0] + def _call(self, batch): + res = self.client.embeddings.create(input=batch, model=self.model_name) + # Local servers (LocalAI / LM Studio) usually omit usage data; fall back + # to a local tiktoken count rather than fabricating a fixed number. + tokens = total_token_count_from_response(res) + if not tokens: + tokens = sum(num_tokens_from_string(t) for t in batch) + return [d.embedding for d in _sorted_by_index(res.data)], tokens + def encode(self, texts: list): - batch_size = 16 - ress = [] - for i in range(0, len(texts), batch_size): - res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name) - try: - ress.extend([d.embedding for d in res.data]) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") - # local embedding for LmStudio donot count tokens - return np.array(ress), 1024 + return self._batched_encode(texts, self._call, batch_size=16) def encode_queries(self, text): - embds, cnt = self.encode([text]) - return np.array(embds[0]), cnt + vectors, token_count = self._batched_encode([text], self._call, batch_size=16) + return vectors[0], token_count def _resolve_azure_credentials(key): @@ -323,23 +384,21 @@ class QWenEmbed(Base): token_count = 0 texts = [truncate(t, 2048) for t in texts] for i in range(0, len(texts), batch_size): - retry_max, retry_wait_secs = 5, 10 for retry in range(retry_max): with _dashscope_native_api_url_scope(self._dashscope_http_api_url): resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document") status_code = resp.status_code if status_code >= 400 and status_code < 500 and status_code not in [408, 429]: - raise ModelException(f"Error, status: {status_code}, response: {resp}") # No need to retry for 4XX error + raise ModelException(f"Error, status: {status_code}, response: {resp}") if status_code == 200: break if retry < retry_max - 1: logging.warning(f"Got error response from DashScope API (status: {status_code}, response: {resp}). Wait {retry_wait_secs} seconds. Retrying...") time.sleep(retry_wait_secs) else: - raise ModelException(f"Error after {retry_max} retries., status: {status_code}, response: {resp}") - + raise ModelException(f"Error after {retry_max} retries, status: {status_code}, response: {resp}") try: embds = [[] for _ in range(len(resp["output"]["embeddings"]))] for e in resp["output"]["embeddings"]: @@ -347,8 +406,8 @@ class QWenEmbed(Base): res.extend(embds) token_count += total_token_count_from_response(resp) except Exception as _e: - log_exception(_e, resp) - raise ModelException(f"Error: {status_code}: {resp}") + logger.exception("QWenEmbed: failed to parse embedding response") + raise EmbeddingError(f"Embedding request failed for QWenEmbed. Error: {_e}; response={resp}") from _e return np.array(res), token_count def encode_queries(self, text): @@ -361,8 +420,8 @@ class QWenEmbed(Base): try: return np.array(resp["output"]["embeddings"][0]["embedding"]), total_token_count_from_response(resp) except Exception as _e: - log_exception(_e, resp) - raise ModelException(f"Error: {status_code}: {resp}") + logger.exception("QWenEmbed: failed to parse query embedding response") + raise EmbeddingError(f"Embedding request failed for QWenEmbed. Error: {_e}; response={resp}") from _e class ZhipuEmbed(Base): @@ -372,34 +431,28 @@ class ZhipuEmbed(Base): self.client = ZhipuAI(api_key=key) self.model_name = model_name - def encode(self, texts: list): - arr = [] - tks_num = 0 - MAX_LEN = -1 + def _max_len(self): + # Per-model input ceilings; fall back to the standard 8K limit for any + # other model rather than leaving oversized inputs untruncated. if self.model_name.lower() == "embedding-2": - MAX_LEN = 512 + return 512 if self.model_name.lower() == "embedding-3": - MAX_LEN = 3072 - if MAX_LEN > 0: - texts = [truncate(t, MAX_LEN) for t in texts] + return 3072 + return DEFAULT_MAX_TOKENS - for txt in texts: - res = self.client.embeddings.create(input=txt, model=self.model_name) - try: - arr.append(res.data[0].embedding) - tks_num += total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") - return np.array(arr), tks_num + def _call(self, batch): + # Batch like the other OpenAI-style providers: one request per batch + # instead of one request per text. Sort by index so the batched results + # stay aligned with input order. + res = self.client.embeddings.create(input=batch, model=self.model_name) + return [d.embedding for d in _sorted_by_index(res.data)], total_token_count_from_response(res) + + def encode(self, texts: list): + return self._batched_encode(texts, self._call, batch_size=16, truncate_to=self._max_len()) def encode_queries(self, text): - res = self.client.embeddings.create(input=text, model=self.model_name) - try: - return np.array(res.data[0].embedding), total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") + vectors, token_count = self._batched_encode([text], self._call, batch_size=16, truncate_to=self._max_len()) + return vectors[0], token_count class OllamaEmbed(Base): @@ -412,32 +465,33 @@ class OllamaEmbed(Base): self.model_name = model_name self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1))) + @classmethod + def _strip_special(cls, text: str) -> str: + for token in cls._special_tokens: + text = text.replace(token, "") + return text + + def _call(self, batch): + # Batch via client.embed (accepts a list `input`) instead of one + # client.embeddings request per text. `truncate=True` lets Ollama clip + # oversized inputs to the model's real context length server-side, which + # is more accurate than a client-side cl100k estimate. + cleaned = [self._strip_special(t) for t in batch] + res = self.client.embed(model=self.model_name, input=cleaned, truncate=True, options={"use_mmap": True}, keep_alive=self.keep_alive) + # Ollama reports real prompt token usage in `prompt_eval_count`; fall + # back to a local count only if the server omits it (never a fixed 128). + tokens = res.get("prompt_eval_count") or 0 + if not tokens: + tokens = sum(num_tokens_from_string(t) for t in cleaned) + return res["embeddings"], tokens + def encode(self, texts: list): - arr = [] - tks_num = 0 - for txt in texts: - # remove special tokens if they exist base on regex in one request - for token in OllamaEmbed._special_tokens: - txt = txt.replace(token, "") - res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive) - try: - arr.append(res["embedding"]) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") - tks_num += 128 - return np.array(arr), tks_num + # No client-side truncation: Ollama truncates to the model context above. + return self._batched_encode(texts, self._call, batch_size=16) def encode_queries(self, text): - # remove special tokens if they exist - for token in OllamaEmbed._special_tokens: - text = text.replace(token, "") - res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive) - try: - return np.array(res["embedding"]), 128 - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") + vectors, token_count = self._batched_encode([text], self._call, batch_size=16) + return vectors[0], token_count class XinferenceEmbed(Base): @@ -448,29 +502,16 @@ class XinferenceEmbed(Base): self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name + def _call(self, batch): + res = self.client.embeddings.create(input=batch, model=self.model_name) + return [d.embedding for d in _sorted_by_index(res.data)], total_token_count_from_response(res) + def encode(self, texts: list): - batch_size = 16 - ress = [] - total_tokens = 0 - for i in range(0, len(texts), batch_size): - res = None - try: - res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name) - ress.extend([d.embedding for d in res.data]) - total_tokens += total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") - return np.array(ress), total_tokens + return self._batched_encode(texts, self._call, batch_size=16) def encode_queries(self, text): - res = None - try: - res = self.client.embeddings.create(input=[text], model=self.model_name) - return np.array(res.data[0].embedding), total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") + vectors, token_count = self._batched_encode([text], self._call, batch_size=16) + return vectors[0], token_count class YoudaoEmbed(Base): @@ -504,55 +545,42 @@ class JinaMultiVecEmbed(Base): self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"} self.model_name = model_name + @staticmethod + def _as_input_item(text): + if isinstance(text, str): + return {"text": text} + # bytes -> base64 encoded image + try: + base64.b64decode(text, validate=True) + return {"image": text.decode("utf8")} + except Exception: + return {"image": base64.b64encode(text).decode("utf8")} + def encode(self, texts: list[str | bytes], task="retrieval.passage"): - batch_size = 16 - ress = [] - token_count = 0 - input = [] - for text in texts: - if isinstance(text, str): - input.append({"text": text}) - elif isinstance(text, bytes): - img_b64s = None - try: - base64.b64decode(text, validate=True) - img_b64s = text.decode("utf8") - except Exception: - img_b64s = base64.b64encode(text).decode("utf8") - input.append({"image": img_b64s}) # base64 encoded image - for i in range(0, len(texts), batch_size): - data = {"model": self.model_name, "input": input[i : i + batch_size]} + def _call(batch): + data = {"model": self.model_name, "input": [self._as_input_item(t) for t in batch]} if "v4" in self.model_name: data["return_multivector"] = True - if "v3" in self.model_name or "v4" in self.model_name: data["task"] = task - data["truncate"] = True - + data["truncate"] = True # let Jina truncate oversized inputs server-side response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30) _raise_model_exception_if_failed(response) - try: - res = response.json() - for d in res["data"]: - if data.get("return_multivector", False): # v4 - token_embs = np.asarray(d["embeddings"], dtype=np.float32) - chunk_emb = token_embs.mean(axis=0) + res = response.json() + embs = [] + for d in res["data"]: + if data.get("return_multivector", False): # v4 + embs.append(np.asarray(d["embeddings"], dtype=np.float32).mean(axis=0)) + else: # v2/v3 + embs.append(np.asarray(d["embedding"], dtype=np.float32)) + return embs, total_token_count_from_response(res) - else: - # v2/v3 - chunk_emb = np.asarray(d["embedding"], dtype=np.float32) - - ress.append(chunk_emb) - - token_count += total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, response) - raise Exception(f"Error: {response}") - return np.array(ress), token_count + # Inputs may be image bytes, so token truncation is left to the server. + return self._batched_encode(texts, _call, batch_size=16) def encode_queries(self, text): - embds, cnt = self.encode([text], task="retrieval.query") - return np.array(embds[0]), cnt + vectors, token_count = self.encode([text], task="retrieval.query") + return vectors[0], token_count class MistralEmbed(Base): @@ -568,7 +596,7 @@ class MistralEmbed(Base): import time import random - texts = [truncate(t, 8196) for t in texts] + texts = [truncate(t, DEFAULT_MAX_TOKENS) for t in texts] batch_size = 16 ress = [] token_count = 0 @@ -582,7 +610,8 @@ class MistralEmbed(Base): break except Exception as _e: if retry_max == 1: - log_exception(_e) + logger.exception("MistralEmbed: embedding request failed after retries") + raise EmbeddingError(f"Embedding request failed for MistralEmbed. Error: {_e}") from _e delay = random.uniform(20, 60) time.sleep(delay) retry_max -= 1 @@ -595,11 +624,12 @@ class MistralEmbed(Base): retry_max = 5 while retry_max > 0: try: - res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name) + res = self.client.embeddings(input=[truncate(text, DEFAULT_MAX_TOKENS)], model=self.model_name) return np.array(res.data[0].embedding), total_token_count_from_response(res) except Exception as _e: if retry_max == 1: - log_exception(_e) + logger.exception("MistralEmbed: query embedding request failed after retries") + raise EmbeddingError(f"Embedding request failed for MistralEmbed. Error: {_e}") from _e delay = random.randint(20, 60) time.sleep(delay) retry_max -= 1 @@ -649,42 +679,41 @@ class BedrockEmbed(Base): else: # assume_role self.client = boto3.client("bedrock-runtime", region_name=self.bedrock_region) + def _extract_vector(self, model_response): + # Titan returns {"embedding": [...]}; Cohere returns {"embeddings": [[...]]}. + if self.is_cohere: + return model_response["embeddings"][0] + return model_response["embedding"] + def encode(self, texts: list): - texts = [truncate(t, 8196) for t in texts] - embeddings = [] - token_count = 0 - for text in texts: + def _call(batch): + # Titan accepts a single input per call, so batch_size is 1. + text = batch[0] if self.is_amazon: body = {"inputText": text} elif self.is_cohere: body = {"texts": [text], "input_type": "search_document"} - response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) - try: - model_response = json.loads(response["body"].read()) - embeddings.extend([model_response["embedding"]]) - token_count += num_tokens_from_string(text) - except Exception as _e: - log_exception(_e, response) + model_response = json.loads(response["body"].read()) + # Bedrock does not report token usage; count locally. + return [self._extract_vector(model_response)], num_tokens_from_string(text) - return np.array(embeddings), token_count + return self._batched_encode(texts, _call, batch_size=1, truncate_to=DEFAULT_MAX_TOKENS) def encode_queries(self, text): - embeddings = [] + text = truncate(text, DEFAULT_MAX_TOKENS) token_count = num_tokens_from_string(text) if self.is_amazon: - body = {"inputText": truncate(text, 8196)} + body = {"inputText": text} elif self.is_cohere: - body = {"texts": [truncate(text, 8196)], "input_type": "search_query"} - - response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) + body = {"texts": [text], "input_type": "search_query"} try: + response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) model_response = json.loads(response["body"].read()) - embeddings.extend(model_response["embedding"]) + return np.array(self._extract_vector(model_response)), token_count except Exception as _e: - log_exception(_e, response) - - return np.array(embeddings), token_count + logger.exception("BedrockEmbed: query embedding request failed") + raise EmbeddingError(f"Embedding request failed for BedrockEmbed. Error: {_e}") from _e class GeminiEmbed(Base): @@ -741,28 +770,17 @@ class GeminiEmbed(Base): return self.types.EmbedContentConfig(task_type=task_type) def encode(self, texts: list): - texts = [truncate(t, 2048) for t in texts] - token_count = sum(num_tokens_from_string(text) for text in texts) config = self._build_embedding_config() - batch_size = 16 - ress = [] - for i in range(0, len(texts), batch_size): - result = None - try: - result = self.client.models.embed_content( - model=self.model_name, - contents=texts[i : i + batch_size], - config=config, - ) - ress.extend(self._parse_embedding_response(result)) - except Exception as _e: - log_exception(_e, result) - raise Exception(f"Error: {result}") - return np.array(ress), token_count + + def _call(batch): + result = self.client.models.embed_content(model=self.model_name, contents=batch, config=config) + # Gemini embeddings do not report token usage; count locally. + return self._parse_embedding_response(result), sum(num_tokens_from_string(t) for t in batch) + + return self._batched_encode(texts, _call, batch_size=16, truncate_to=2048) def encode_queries(self, text): config = self._build_embedding_config() - result = None token_count = num_tokens_from_string(text) try: result = self.client.models.embed_content( @@ -772,8 +790,8 @@ class GeminiEmbed(Base): ) return np.array(self._parse_embedding_response(result)[0]), token_count except Exception as _e: - log_exception(_e, result) - raise Exception(f"Error: {result}") + logger.exception("GeminiEmbed: query embedding request failed") + raise EmbeddingError(f"Embedding request failed for GeminiEmbed. Error: {_e}") from _e class NvidiaEmbed(Base): @@ -796,32 +814,24 @@ class NvidiaEmbed(Base): if model_name == "snowflake/arctic-embed-l": self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings" + def _call(self, batch, input_type="query"): + payload = { + "input": batch, + "input_type": input_type, + "model": self.model_name, + "encoding_format": "float", + "truncate": "END", # NVIDIA truncates oversized inputs server-side. + } + response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30) + return self._openai_http_embeddings(response) + def encode(self, texts: list): - batch_size = 16 - ress = [] - token_count = 0 - for i in range(0, len(texts), batch_size): - payload = { - "input": texts[i : i + batch_size], - "input_type": "query", - "model": self.model_name, - "encoding_format": "float", - "truncate": "END", - } - response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30) - _raise_model_exception_if_failed(response) - try: - res = response.json() - ress.extend([d["embedding"] for d in res["data"]]) - token_count += total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, response) - raise Exception(f"Error: {response}") - return np.array(ress), token_count + # NVIDIA NIM expects "passage" for documents (indexing) and "query" for retrieval. + return self._batched_encode(texts, lambda b: self._call(b, "passage"), batch_size=16) def encode_queries(self, text): - embds, cnt = self.encode([text]) - return np.array(embds[0]), cnt + vectors, token_count = self._batched_encode([text], lambda b: self._call(b, "query"), batch_size=16) + return vectors[0], token_count class LmStudioEmbed(LocalAIEmbed): @@ -855,37 +865,32 @@ class CoHereEmbed(Base): self.client = Client(api_key=key) self.model_name = model_name + def _call(self, batch): + res = self.client.embed( + texts=batch, + model=self.model_name, + input_type="search_document", + embedding_types=["float"], + truncate="END", # let Cohere clip oversized inputs server-side instead of hard-failing + ) + return list(res.embeddings.float), total_token_count_from_response(res) + def encode(self, texts: list): - batch_size = 16 - ress = [] - token_count = 0 - for i in range(0, len(texts), batch_size): - res = self.client.embed( - texts=texts[i : i + batch_size], - model=self.model_name, - input_type="search_document", - embedding_types=["float"], - ) - try: - ress.extend([d for d in res.embeddings.float]) - token_count += total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") - return np.array(ress), token_count + return self._batched_encode(texts, self._call, batch_size=16) def encode_queries(self, text): - res = self.client.embed( - texts=[text], - model=self.model_name, - input_type="search_query", - embedding_types=["float"], - ) try: + res = self.client.embed( + texts=[text], + model=self.model_name, + input_type="search_query", + embedding_types=["float"], + truncate="END", + ) return np.array(res.embeddings.float[0]), int(total_token_count_from_response(res)) except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") + logger.exception("CoHereEmbed: query embedding request failed") + raise EmbeddingError(f"Embedding request failed for CoHereEmbed. Error: {_e}") from _e class TogetherAIEmbed(OpenAIEmbed): @@ -932,49 +937,27 @@ class SILICONFLOWEmbed(Base): self.base_url = normalized_base_url self.model_name = model_name - def encode(self, texts: list): - batch_size = 16 - ress = [] - token_count = 0 - for i in range(0, len(texts), batch_size): - texts_batch = texts[i : i + batch_size] - if self.model_name in ["BAAI/bge-large-zh-v1.5", "BAAI/bge-large-en-v1.5"]: - # limit 512, 340 is almost safe - texts_batch = [" " if not text.strip() else truncate(text, 256) for text in texts_batch] - else: - texts_batch = [" " if not text.strip() else text for text in texts_batch] + def _clean_batch(self, batch): + if self.model_name in ["BAAI/bge-large-zh-v1.5", "BAAI/bge-large-en-v1.5"]: + # limit 512, 340 is almost safe + return [" " if not text.strip() else truncate(text, 256) for text in batch] + return [" " if not text.strip() else text for text in batch] - payload = { - "model": self.model_name, - "input": texts_batch, - "encoding_format": "float", - } - response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30) - _raise_model_exception_if_failed(response) - try: - res = response.json() - ress.extend([d["embedding"] for d in res["data"]]) - token_count += total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, response) - raise Exception(f"Error: {response}") - - return np.array(ress), token_count - - def encode_queries(self, text): + def _call(self, batch): payload = { "model": self.model_name, - "input": text, + "input": self._clean_batch(batch), "encoding_format": "float", } response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30) - _raise_model_exception_if_failed(response) - try: - res = response.json() - return np.array(res["data"][0]["embedding"]), total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, response) - raise Exception(f"Error: {response}") + return self._openai_http_embeddings(response) + + def encode(self, texts: list): + return self._batched_encode(texts, self._call, batch_size=16) + + def encode_queries(self, text): + vectors, token_count = self._batched_encode([text], self._call, batch_size=16) + return vectors[0], token_count class ReplicateEmbed(Base): @@ -1013,26 +996,26 @@ class BaiduYiyanEmbed(Base): self.model_name = model_name def encode(self, texts: list, batch_size=16): - res = self.client.do(model=self.model_name, texts=texts).body try: + res = self.client.do(model=self.model_name, texts=texts).body return ( np.array([r["embedding"] for r in res["data"]]), total_token_count_from_response(res), ) except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") + logger.exception("BaiduYiyanEmbed: embedding request failed") + raise EmbeddingError(f"Embedding request failed for BaiduYiyanEmbed. Error: {_e}") from _e def encode_queries(self, text): - res = self.client.do(model=self.model_name, texts=[text]).body try: + res = self.client.do(model=self.model_name, texts=[text]).body return ( np.array([r["embedding"] for r in res["data"]]), total_token_count_from_response(res), ) except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") + logger.exception("BaiduYiyanEmbed: query embedding request failed") + raise EmbeddingError(f"Embedding request failed for BaiduYiyanEmbed. Error: {_e}") from _e class VoyageEmbed(Base): @@ -1044,27 +1027,22 @@ class VoyageEmbed(Base): self.client = voyageai.Client(api_key=key) self.model_name = model_name + def _call(self, batch): + res = self.client.embed(texts=batch, model=self.model_name, input_type="document") + # `_batched_encode` accumulates these per-batch vectors and returns a + # single np.ndarray, so encode() keeps the np.ndarray contract. + return res.embeddings, res.total_tokens + def encode(self, texts: list): - batch_size = 16 - ress = [] - token_count = 0 - for i in range(0, len(texts), batch_size): - res = self.client.embed(texts=texts[i : i + batch_size], model=self.model_name, input_type="document") - try: - ress.extend(res.embeddings) - token_count += res.total_tokens - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") - return np.array(ress), token_count + return self._batched_encode(texts, self._call, batch_size=16) def encode_queries(self, text): - res = self.client.embed(texts=text, model=self.model_name, input_type="query") try: + res = self.client.embed(texts=text, model=self.model_name, input_type="query") return np.array(res.embeddings)[0], res.total_tokens except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") + logger.exception("VoyageEmbed: query embedding request failed") + raise EmbeddingError(f"Embedding request failed for VoyageEmbed. Error: {_e}") from _e class HuggingFaceEmbed(Base): @@ -1080,14 +1058,13 @@ class HuggingFaceEmbed(Base): def encode(self, texts: list): response = requests.post(f"{self.base_url}/embed", json={"inputs": texts}, headers={"Content-Type": "application/json"}, timeout=30) _raise_model_exception_if_failed(response) - embeddings = response.json() - return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts]) + # TEI auto-truncates oversized inputs, so no client-side truncation is needed. + return np.array(response.json()), sum([num_tokens_from_string(text) for text in texts]) def encode_queries(self, text: str): response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"}, timeout=30) _raise_model_exception_if_failed(response) - embedding = response.json()[0] - return np.array(embedding), num_tokens_from_string(text) + return np.array(response.json()[0]), num_tokens_from_string(text) class VolcEngineEmbed(Base): @@ -1141,13 +1118,14 @@ class VolcEngineEmbed(Base): request_body = {"model": self.model_name, "input": [{"type": "text", "text": text}]} response = sync_request(method="POST", url=url, headers=headers, json=request_body, timeout=60) if response.status_code != 200: - raise Exception(f"Error: {response.status_code} - {response.text}") + raise EmbeddingError(f"Embedding request failed for VolcEngineEmbed. Error: {response.status_code} - {response.text}") result = response.json() try: ress.append(self._extract_embedding(result)) total_tokens += total_token_count_from_response(result) except Exception as _e: - log_exception(_e) + logger.exception("VolcEngineEmbed: failed to parse embedding response") + raise EmbeddingError(f"Embedding request failed for VolcEngineEmbed. Error: {_e}; response={result}") from _e return np.array(ress), total_tokens @@ -1295,8 +1273,8 @@ class PerplexityEmbed(Base): ress.append(self._decode_base64_int8(chunk_emb["embedding"])) token_count += res.get("usage", {}).get("total_tokens", 0) except Exception as _e: - log_exception(_e, response) - raise Exception(f"Error: {response.text}") + logger.exception("PerplexityEmbed: failed to parse contextualized embedding response") + raise EmbeddingError(f"Embedding request failed for PerplexityEmbed. Error: {response.text}") from _e else: url = f"{self.base_url}/v1/embeddings" for i in range(0, len(texts), batch_size): @@ -1314,8 +1292,8 @@ class PerplexityEmbed(Base): ress.append(self._decode_base64_int8(d["embedding"])) token_count += res.get("usage", {}).get("total_tokens", 0) except Exception as _e: - log_exception(_e, response) - raise Exception(f"Error: {response.text}") + logger.exception("PerplexityEmbed: failed to parse embedding response") + raise EmbeddingError(f"Embedding request failed for PerplexityEmbed. Error: {response.text}") from _e return np.array(ress), token_count diff --git a/test/unit_test/conftest.py b/test/unit_test/conftest.py new file mode 100644 index 0000000000..bd5e7afe6c --- /dev/null +++ b/test/unit_test/conftest.py @@ -0,0 +1,50 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Shared setup for RAGFlow unit tests. + +Several parsers and the chunking pipeline tokenize text with NLTK, which needs +the ``punkt_tab`` and ``wordnet`` data sets. Production provisions these via +``download_deps.py`` (into ``nltk_data``, exported as ``NLTK_DATA`` by +``docker/launch_backend_service.sh``) and ``api.validation`` at startup, but the +unit-test runner has neither. Without the data, tokenizer-backed tests such as +``test_epub_parser`` and ``test_dataflow_service`` fail with +``LookupError: Resource 'punkt_tab' not found``. Make sure the data is reachable +before any test imports a tokenizer: reuse a provisioned ``nltk_data`` directory +when present, and download only what is still missing. +""" + +import os + +import nltk + +# Reuse data already fetched by download_deps.py (the directory the app exports +# as NLTK_DATA) so provisioned environments do not download it again. +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) +_LOCAL_NLTK_DATA = os.path.join(_REPO_ROOT, "nltk_data") +if os.path.isdir(_LOCAL_NLTK_DATA) and _LOCAL_NLTK_DATA not in nltk.data.path: + nltk.data.path.insert(0, _LOCAL_NLTK_DATA) + +# (download name, resource path used by nltk.data.find) +_REQUIRED_NLTK_DATA = ( + ("punkt_tab", "tokenizers/punkt_tab"), + ("wordnet", "corpora/wordnet"), +) +for _name, _find_path in _REQUIRED_NLTK_DATA: + try: + nltk.data.find(_find_path) + except LookupError: + nltk.download(_name, quiet=True) diff --git a/test/unit_test/rag/llm/test_embedding_model.py b/test/unit_test/rag/llm/test_embedding_model.py new file mode 100644 index 0000000000..cbd0d0d7ee --- /dev/null +++ b/test/unit_test/rag/llm/test_embedding_model.py @@ -0,0 +1,383 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for the embedding-provider fixes in ``rag.llm.embedding_model``: + +* a failing embedding call raises a single deterministic, informative + ``EmbeddingError`` (and the previous unreachable ``raise Exception(f"Error: {res}")`` + can no longer mask it, regardless of whether the SDK response exposes ``.text``); +* token counts reflect real usage, or an honest local fallback — never the old + fabricated ``1024`` / ``+= 128`` constants; +* inputs at the truncation boundary are not pushed past the model token limit + (the old ``8196`` overshoot is gone); +* ``ZhipuEmbed`` / ``OllamaEmbed`` now batch — ``ceil(n / batch_size)`` requests + with input order and output shape preserved. +""" + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from rag.llm.embedding_model import ( + DEFAULT_MAX_TOKENS, + BedrockEmbed, + EmbeddingError, + LocalAIEmbed, + MistralEmbed, + NvidiaEmbed, + OllamaEmbed, + OpenAIEmbed, + ZhipuEmbed, +) +from common.exceptions import ModelException +from common.token_utils import num_tokens_from_string + + +# --------------------------------------------------------------------------- # +# Fakes +# --------------------------------------------------------------------------- # +class _OpenAIResp: + """Minimal stand-in for an OpenAI embeddings response. + + Unlike ``MagicMock`` it does NOT auto-create a ``usage`` attribute, so + ``total_token_count_from_response`` correctly returns 0 when ``total_tokens`` + is not supplied (exercising the local-count fallback paths). + """ + + def __init__(self, vectors, total_tokens=None): + self.data = [SimpleNamespace(embedding=list(v)) for v in vectors] + if total_tokens is not None: + self.usage = SimpleNamespace(total_tokens=total_tokens) + + +def _openai_create(total_tokens=None, dim=3): + """Build a side_effect that returns one vector per input text.""" + + def _create(input, model, **kwargs): + return _OpenAIResp([[float(i)] * dim for i in range(len(input))], total_tokens=total_tokens) + + return _create + + +def _make_openai(cls=OpenAIEmbed, total_tokens=None): + embed = cls("key", "text-embedding-3-small", base_url="https://example.invalid/v1") + embed.client = MagicMock() + embed.client.embeddings.create = MagicMock(side_effect=_openai_create(total_tokens=total_tokens)) + return embed + + +# --------------------------------------------------------------------------- # +# 1. Deterministic, informative error handling (the masked-error bug) +# --------------------------------------------------------------------------- # +class _BadRespWithText: + """Parsing this raises; it also exposes ``.text`` — which the old + ``log_exception(_e, res)`` path would have re-raised as a bare + ``Exception(text)``, masking the intended error non-deterministically.""" + + text = "Internal Server Error" + + @property + def data(self): + raise ValueError("malformed response payload") + + +class _BadRespNoText: + @property + def data(self): + raise ValueError("malformed response payload") + + +@pytest.mark.p1 +class TestDeterministicErrors: + def test_api_error_raises_embedding_error(self): + embed = _make_openai() + embed.client.embeddings.create = MagicMock(side_effect=RuntimeError("503 upstream down")) + with pytest.raises(EmbeddingError) as exc: + embed.encode(["hello"]) + # Informative: surfaces the underlying detail and contains "Error". + assert "503 upstream down" in str(exc.value) + assert "Error" in str(exc.value) + assert "OpenAIEmbed" in str(exc.value) + + def test_same_exception_type_with_and_without_text_attr(self): + """The surfaced exception must NOT depend on whether the response object + exposes ``.text`` (the old non-determinism). Both variants -> EmbeddingError.""" + with_text = _make_openai() + with_text.client.embeddings.create = MagicMock(return_value=_BadRespWithText()) + without_text = _make_openai() + without_text.client.embeddings.create = MagicMock(return_value=_BadRespNoText()) + + with pytest.raises(EmbeddingError) as e1: + with_text.encode(["x"]) + with pytest.raises(EmbeddingError) as e2: + without_text.encode(["x"]) + + # Deterministic: same type, and the response's ``.text`` did not hijack it. + assert type(e1.value) is type(e2.value) is EmbeddingError + assert "Internal Server Error" not in str(e1.value) + assert "malformed response payload" in str(e1.value) + + def test_query_path_also_deterministic(self): + embed = _make_openai() + embed.client.embeddings.create = MagicMock(side_effect=RuntimeError("nope")) + with pytest.raises(EmbeddingError): + embed.encode_queries("hi") + + def test_http_bad_status_raises_model_exception_with_body(self): + """A bad HTTP status surfaces the response body via a retryable-aware + ModelException, which the API error handler understands.""" + embed = NvidiaEmbed("key", "nvidia/nv-embed-v1") + bad = MagicMock() + bad.status_code = 400 + bad.text = '{"error": "bad request: empty input"}' + with patch("rag.llm.embedding_model.requests.post", return_value=bad): + with pytest.raises(ModelException) as exc: + embed.encode(["hello"]) + assert "bad request: empty input" in str(exc.value) + + def test_http_malformed_ok_response_raises_embedding_error(self): + """A 200 response with an unexpected body still yields a deterministic + EmbeddingError carrying the payload detail.""" + embed = NvidiaEmbed("key", "nvidia/nv-embed-v1") + bad = MagicMock() + bad.status_code = 200 + bad.json.return_value = {"unexpected": "shape"} + with patch("rag.llm.embedding_model.requests.post", return_value=bad): + with pytest.raises(EmbeddingError) as exc: + embed.encode(["hello"]) + assert "unexpected" in str(exc.value) + + +# --------------------------------------------------------------------------- # +# 2. Token accounting (no fabricated 1024 / += 128) +# --------------------------------------------------------------------------- # +@pytest.mark.p1 +class TestTokenAccounting: + def test_openai_uses_reported_usage(self): + embed = _make_openai(total_tokens=42) + _, tokens = embed.encode(["a", "b"]) + assert tokens == 42 + + def test_localai_falls_back_to_local_count_not_1024(self): + embed = _make_openai(cls=LocalAIEmbed) # no usage in response + texts = ["hello world", "second chunk of text"] + _, tokens = embed.encode(texts) + expected = sum(num_tokens_from_string(t) for t in texts) + assert tokens == expected + assert tokens != 1024 # the old fabricated constant + + def test_ollama_uses_prompt_eval_count_not_128(self): + embed = OllamaEmbed("x", "nomic-embed-text", base_url="http://localhost:11434") + embed.client = MagicMock() + embed.client.embed = MagicMock(return_value={"embeddings": [[0.1, 0.2], [0.3, 0.4]], "prompt_eval_count": 33}) + _, tokens = embed.encode(["aaa", "bbb"]) + assert tokens == 33 + assert tokens != 128 * 2 # the old fabricated per-text constant + + def test_ollama_token_fallback_when_server_omits_count(self): + embed = OllamaEmbed("x", "nomic-embed-text", base_url="http://localhost:11434") + embed.client = MagicMock() + # No prompt_eval_count reported -> honest local count, not a fixed number. + embed.client.embed = MagicMock(return_value={"embeddings": [[0.1, 0.2]]}) + texts = ["some text to embed"] + _, tokens = embed.encode(texts) + assert tokens == sum(num_tokens_from_string(t) for t in texts) + + +# --------------------------------------------------------------------------- # +# 3. Truncation boundary (no 8196 overshoot) +# --------------------------------------------------------------------------- # +@pytest.mark.p2 +class TestTruncationBoundary: + def test_default_limit_is_8192(self): + assert DEFAULT_MAX_TOKENS == 8192 + + def test_openai_input_truncated_below_model_limit(self): + embed = _make_openai(total_tokens=1) + # An input far above the 8K ceiling. + huge = "word " * 12000 + embed.encode([huge]) + sent = embed.client.embeddings.create.call_args.kwargs["input"][0] + # Truncated to the documented 8191 ceiling, never above the 8192 model limit. + assert num_tokens_from_string(sent) <= 8191 + assert num_tokens_from_string(sent) <= DEFAULT_MAX_TOKENS + + def test_mistral_truncates_to_8192_not_8196(self): + embed = MistralEmbed.__new__(MistralEmbed) + embed.model_name = "mistral-embed" + captured = {} + + def _embeddings(input, model): + captured["input"] = input + return _OpenAIResp([[0.0, 0.0]], total_tokens=1) + + embed.client = MagicMock() + embed.client.embeddings = MagicMock(side_effect=_embeddings) + huge = "word " * 12000 + embed.encode([huge]) + assert num_tokens_from_string(captured["input"][0]) <= DEFAULT_MAX_TOKENS + + +# --------------------------------------------------------------------------- # +# 4. Batching for Zhipu and Ollama (ceil(n / batch_size) requests) +# --------------------------------------------------------------------------- # +@pytest.mark.p1 +class TestBatching: + def test_zhipu_batches_instead_of_per_text(self): + embed = ZhipuEmbed("key", "embedding-3") + embed.client = MagicMock() + embed.client.embeddings.create = MagicMock(side_effect=_openai_create(total_tokens=5)) + texts = [f"t{i}" for i in range(3)] + vectors, _ = embed.encode(texts) + # One request for 3 texts (batch_size 16) — NOT three per-text requests. + assert embed.client.embeddings.create.call_count == 1 + assert vectors.shape[0] == 3 + + def test_zhipu_issues_ceil_n_over_batch_calls(self): + embed = ZhipuEmbed("key", "embedding-3") + embed.client = MagicMock() + embed.client.embeddings.create = MagicMock(side_effect=_openai_create(total_tokens=5)) + texts = [f"t{i}" for i in range(20)] # batch_size 16 -> ceil(20/16) == 2 + vectors, _ = embed.encode(texts) + assert embed.client.embeddings.create.call_count == 2 + assert vectors.shape[0] == 20 + + def test_ollama_batches_and_preserves_order(self): + embed = OllamaEmbed("x", "nomic-embed-text", base_url="http://localhost:11434") + embed.client = MagicMock() + + def _embed(model, input, **kwargs): + # Echo a recognisable vector per input so order can be checked. + return {"embeddings": [[float(len(t))] for t in input], "prompt_eval_count": 1} + + embed.client.embed = MagicMock(side_effect=_embed) + texts = ["a", "bb", "ccc"] + vectors, _ = embed.encode(texts) + + # One batched request, not one per text. + assert embed.client.embed.call_count == 1 + assert vectors.shape == (3, 1) + # Order preserved: vector value equals input length. + np.testing.assert_array_equal(vectors[:, 0], np.array([1.0, 2.0, 3.0])) + + def test_zhipu_realigns_out_of_order_response(self): + """If the provider returns embeddings out of order, the per-item `index` + must realign them with the input — otherwise chunks get wrong vectors.""" + embed = ZhipuEmbed("key", "embedding-3") + embed.client = MagicMock() + + def _create(input, model, **kwargs): + data = [SimpleNamespace(embedding=[float(i)], index=i) for i in range(len(input))] + return SimpleNamespace(data=list(reversed(data)), usage=SimpleNamespace(total_tokens=1)) + + embed.client.embeddings.create = MagicMock(side_effect=_create) + vectors, _ = embed.encode(["t0", "t1", "t2"]) + np.testing.assert_array_equal(vectors[:, 0], np.array([0.0, 1.0, 2.0])) + + def test_nvidia_http_realigns_out_of_order_response(self): + embed = NvidiaEmbed("key", "nvidia/nv-embed-v1") + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = { + "data": [ + {"index": 2, "embedding": [2.0]}, + {"index": 0, "embedding": [0.0]}, + {"index": 1, "embedding": [1.0]}, + ], + "usage": {"total_tokens": 3}, + } + with patch("rag.llm.embedding_model.requests.post", return_value=resp): + vectors, _ = embed.encode(["a", "b", "c"]) + np.testing.assert_array_equal(vectors[:, 0], np.array([0.0, 1.0, 2.0])) + + def test_ollama_issues_ceil_n_over_batch_calls(self): + embed = OllamaEmbed("x", "nomic-embed-text", base_url="http://localhost:11434") + embed.client = MagicMock() + embed.client.embed = MagicMock(side_effect=lambda model, input, **kw: {"embeddings": [[0.0] for _ in input], "prompt_eval_count": 1}) + texts = [f"t{i}" for i in range(20)] # batch_size 16 -> 2 calls + vectors, _ = embed.encode(texts) + assert embed.client.embed.call_count == 2 + assert vectors.shape[0] == 20 + + +# --------------------------------------------------------------------------- # +# 5. Provider-specific request/response shapes +# --------------------------------------------------------------------------- # +@pytest.mark.p2 +class TestNvidiaInputType: + """NVIDIA NIM expects input_type=passage for documents and =query for queries; + using "query" for documents degrades retrieval (asymmetric embeddings).""" + + def _mock_resp(self): + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = {"data": [{"index": 0, "embedding": [1.0]}], "usage": {"total_tokens": 1}} + return resp + + def test_documents_use_passage(self): + embed = NvidiaEmbed("key", "nvidia/nv-embed-v1") + with patch("rag.llm.embedding_model.requests.post", return_value=self._mock_resp()) as post: + embed.encode(["a document"]) + assert post.call_args.kwargs["json"]["input_type"] == "passage" + + def test_queries_use_query(self): + embed = NvidiaEmbed("key", "nvidia/nv-embed-v1") + with patch("rag.llm.embedding_model.requests.post", return_value=self._mock_resp()) as post: + embed.encode_queries("a query") + assert post.call_args.kwargs["json"]["input_type"] == "query" + + +@pytest.mark.p2 +class TestBedrockResponseParsing: + """Bedrock Titan returns {"embedding": [...]}; Cohere returns + {"embeddings": [[...]]}. Both must parse without KeyError.""" + + @staticmethod + def _make(model_prefix): + embed = BedrockEmbed.__new__(BedrockEmbed) + embed.model_name = f"{model_prefix}.embed-model" + embed.is_amazon = model_prefix == "amazon" + embed.is_cohere = model_prefix == "cohere" + embed.client = MagicMock() + return embed + + @staticmethod + def _body(payload): + body = MagicMock() + body.read.return_value = json.dumps(payload).encode() + return {"body": body} + + def test_cohere_reads_embeddings_plural(self): + embed = self._make("cohere") + embed.client.invoke_model.return_value = self._body({"embeddings": [[1.0, 2.0]]}) + vectors, _ = embed.encode(["hello"]) + assert vectors.shape == (1, 2) + np.testing.assert_array_equal(vectors[0], np.array([1.0, 2.0])) + + def test_amazon_reads_embedding_singular(self): + embed = self._make("amazon") + embed.client.invoke_model.return_value = self._body({"embedding": [3.0, 4.0]}) + vectors, _ = embed.encode(["hello"]) + np.testing.assert_array_equal(vectors[0], np.array([3.0, 4.0])) + + def test_cohere_query_reads_embeddings_plural(self): + embed = self._make("cohere") + embed.client.invoke_model.return_value = self._body({"embeddings": [[5.0, 6.0]]}) + vector, _ = embed.encode_queries("q") + np.testing.assert_array_equal(vector, np.array([5.0, 6.0]))