diff --git a/download_deps.py b/download_deps.py index b707e03622..df29eaac91 100644 --- a/download_deps.py +++ b/download_deps.py @@ -64,6 +64,12 @@ if __name__ == "__main__": urls = get_urls(args.china_mirrors) + # Some mirrors (e.g. archive.ubuntu.com) reject the default urllib + # User-Agent with HTTP 403, so install an opener with a browser-like UA. + opener = urllib.request.build_opener() + opener.addheaders = [("User-Agent", "Mozilla/5.0")] + urllib.request.install_opener(opener) + for url in urls: download_url = url[0] if isinstance(url, list) else url filename = url[1] if isinstance(url, list) else url.split("/")[-1] diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 616382cf72..c8ed17b34a 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -30,7 +30,6 @@ from zhipuai import ZhipuAI from common import settings from common.exceptions import ModelException -from common.log_utils import log_exception from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response from rag.llm.key_utils import _normalize_replicate_key import logging @@ -38,6 +37,28 @@ import base64 logger = logging.getLogger(__name__) +# Standard token ceiling for the common 8K-context embedding models (OpenAI +# text-embedding-*, Mistral, Bedrock Titan, ...). Inputs are truncated to this +# many tokens so boundary-sized chunks are not rejected by the provider. +DEFAULT_MAX_TOKENS = 8192 + + +class EmbeddingError(ModelException): + """Raised when an embedding provider fails to return usable embeddings. + + A single, deterministic exception type for every provider failure path so + callers see consistent behaviour regardless of which SDK raised underneath. + Subclasses ``ModelException`` so the API error handler (and its retry + semantics) treats embedding failures like any other model failure. 
+ """ + + +def _sorted_by_index(items): + """Order OpenAI-style SDK embedding items by their ``.index`` so batched + results stay aligned with input order even if the provider returns them out + of order. Stable no-op when items carry no ``index`` attribute.""" + return sorted(items, key=lambda d: getattr(d, "index", 0)) + def _raise_model_exception_if_failed(resp): status_code = resp.status_code @@ -91,8 +112,7 @@ def _dashscope_native_http_api_url(base_url: str | None) -> str | None: ) return resolved logger.warning( - "DashScope Tongyi-Qianwen embedding: base_url is set but not recognized as a DashScope host; " - "using SDK default endpoint (%s)", + "DashScope Tongyi-Qianwen embedding: base_url is set but not recognized as a DashScope host; using SDK default endpoint (%s)", safe, ) return None @@ -130,6 +150,66 @@ class Base(ABC): def encode_queries(self, text: str): raise NotImplementedError("Please implement encode method!") + def _batched_encode(self, texts: list, call_fn, *, batch_size: int, truncate_to: int | None = None): + """Drive an embedding provider over ``texts`` in batches. + + This is the shared template behind the OpenAI-style providers. It owns: + + * optional per-text truncation to ``truncate_to`` tokens (skipped when + ``None``) so oversized inputs do not get rejected by the provider; + * the batch loop, issuing ``ceil(len(texts) / batch_size)`` calls; + * accumulation of the per-text vectors into a single ``np.ndarray``; + * summation of the per-batch token counts; + * one deterministic, informative error path. + + ``call_fn`` is a provider-supplied closure ``call_fn(batch) -> + (embeddings, token_count)``. It performs the SDK/HTTP request *and* + parses the response (so a malformed/error response surfaces here), and + must not assume any particular response shape — the helper never touches + the raw response object. ``embeddings`` is a sequence of per-text + vectors; ``token_count`` is the real token usage for that batch. 
+ + Any exception raised by ``call_fn`` is wrapped in a single + :class:`EmbeddingError` that includes the underlying detail. We log and + raise here directly instead of relying on ``log_exception``'s implicit + raise (whose surfaced exception varies by SDK response shape). + """ + if truncate_to is not None: + texts = [truncate(t, truncate_to) for t in texts] + vectors = [] + token_count = 0 + for i in range(0, len(texts), batch_size): + batch = texts[i : i + batch_size] + try: + embeddings, tokens = call_fn(batch) + except ModelException: + # Already a structured (and possibly retryable) model error; keep it. + raise + except Exception as e: + logger.exception("%s embedding request failed", type(self).__name__) + raise EmbeddingError(f"Embedding request failed for {type(self).__name__}. Error: {e}") from e + vectors.extend(embeddings) + token_count += tokens + return np.array(vectors), token_count + + @staticmethod + def _openai_http_embeddings(response): + """Parse an OpenAI-compatible HTTP embeddings ``requests`` response. + + Returns ``(embeddings, token_count)``. Raises a retryable-aware + :class:`ModelException` on a bad HTTP status, or surfaces the response + body (via :class:`EmbeddingError`) when the payload is not a successful + ``{"data": [...]}`` response. + """ + _raise_model_exception_if_failed(response) + res = response.json() + if not isinstance(res, dict) or "data" not in res: + raise ValueError(f"unexpected embeddings response (status {getattr(response, 'status_code', '?')}): {res}") + # Keep results aligned with input order: OpenAI-compatible responses carry + # a per-item `index`; sorting by it is a no-op (stable) when it is absent. 
+ data = sorted(res["data"], key=lambda d: d.get("index", 0)) + return [d["embedding"] for d in data], total_token_count_from_response(res) + class BuiltinEmbed(Base): _FACTORY_NAME = "Builtin" @@ -176,35 +256,17 @@ class OpenAIEmbed(Base): self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name + def _call(self, batch): + res = self.client.embeddings.create(input=batch, model=self.model_name, encoding_format="float", extra_body={"drop_params": True}) + return [d.embedding for d in _sorted_by_index(res.data)], total_token_count_from_response(res) + def encode(self, texts: list): - # OpenAI requires batch size <=16 - batch_size = 16 - texts = [truncate(t, 8191) for t in texts] - ress = [] - total_tokens = 0 - for i in range(0, len(texts), batch_size): - try: - res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True}) - except Exception as _e: - raise ModelException(f"Error: {_e}") - try: - ress.extend([d.embedding for d in res.data]) - total_tokens += total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") - return np.array(ress), total_tokens + # OpenAI requires batch size <=16; 8191 is the documented per-input token ceiling. 
+ return self._batched_encode(texts, self._call, batch_size=16, truncate_to=8191) def encode_queries(self, text): - try: - res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float", extra_body={"drop_params": True}) - except Exception as _e: - raise ModelException(f"Error: {_e}") - try: - return np.array(res.data[0].embedding), total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") + vectors, token_count = self._batched_encode([text], self._call, batch_size=16, truncate_to=8191) + return vectors[0], token_count class LocalAIEmbed(Base): @@ -217,22 +279,21 @@ class LocalAIEmbed(Base): self.client = OpenAI(api_key="empty", base_url=base_url) self.model_name = model_name.split("___")[0] + def _call(self, batch): + res = self.client.embeddings.create(input=batch, model=self.model_name) + # Local servers (LocalAI / LM Studio) usually omit usage data; fall back + # to a local tiktoken count rather than fabricating a fixed number. 
+ tokens = total_token_count_from_response(res) + if not tokens: + tokens = sum(num_tokens_from_string(t) for t in batch) + return [d.embedding for d in _sorted_by_index(res.data)], tokens + def encode(self, texts: list): - batch_size = 16 - ress = [] - for i in range(0, len(texts), batch_size): - res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name) - try: - ress.extend([d.embedding for d in res.data]) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") - # local embedding for LmStudio donot count tokens - return np.array(ress), 1024 + return self._batched_encode(texts, self._call, batch_size=16) def encode_queries(self, text): - embds, cnt = self.encode([text]) - return np.array(embds[0]), cnt + vectors, token_count = self._batched_encode([text], self._call, batch_size=16) + return vectors[0], token_count def _resolve_azure_credentials(key): @@ -323,23 +384,22 @@ class QWenEmbed(Base): token_count = 0 texts = [truncate(t, 2048) for t in texts] for i in range(0, len(texts), batch_size): retry_max, retry_wait_secs = 5, 10 for retry in range(retry_max): with _dashscope_native_api_url_scope(self._dashscope_http_api_url): resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document") status_code = resp.status_code if status_code >= 400 and status_code < 500 and status_code not in [408, 429]: - raise ModelException(f"Error, status: {status_code}, response: {resp}") # No need to retry for 4XX error + raise ModelException(f"Error, status: {status_code}, response: {resp}") if status_code == 200: break if retry < retry_max - 1: logging.warning(f"Got error response from DashScope API (status: {status_code}, response: {resp}). Wait {retry_wait_secs} seconds. 
Retrying...") time.sleep(retry_wait_secs) else: - raise ModelException(f"Error after {retry_max} retries., status: {status_code}, response: {resp}") - + raise ModelException(f"Error after {retry_max} retries, status: {status_code}, response: {resp}") try: embds = [[] for _ in range(len(resp["output"]["embeddings"]))] for e in resp["output"]["embeddings"]: @@ -347,8 +406,8 @@ class QWenEmbed(Base): res.extend(embds) token_count += total_token_count_from_response(resp) except Exception as _e: - log_exception(_e, resp) - raise ModelException(f"Error: {status_code}: {resp}") + logger.exception("QWenEmbed: failed to parse embedding response") + raise EmbeddingError(f"Embedding request failed for QWenEmbed. Error: {_e}; response={resp}") from _e return np.array(res), token_count def encode_queries(self, text): @@ -361,8 +420,8 @@ class QWenEmbed(Base): try: return np.array(resp["output"]["embeddings"][0]["embedding"]), total_token_count_from_response(resp) except Exception as _e: - log_exception(_e, resp) - raise ModelException(f"Error: {status_code}: {resp}") + logger.exception("QWenEmbed: failed to parse query embedding response") + raise EmbeddingError(f"Embedding request failed for QWenEmbed. Error: {_e}; response={resp}") from _e class ZhipuEmbed(Base): @@ -372,34 +431,28 @@ class ZhipuEmbed(Base): self.client = ZhipuAI(api_key=key) self.model_name = model_name - def encode(self, texts: list): - arr = [] - tks_num = 0 - MAX_LEN = -1 + def _max_len(self): + # Per-model input ceilings; fall back to the standard 8K limit for any + # other model rather than leaving oversized inputs untruncated. 
if self.model_name.lower() == "embedding-2": - MAX_LEN = 512 + return 512 if self.model_name.lower() == "embedding-3": - MAX_LEN = 3072 - if MAX_LEN > 0: - texts = [truncate(t, MAX_LEN) for t in texts] + return 3072 + return DEFAULT_MAX_TOKENS - for txt in texts: - res = self.client.embeddings.create(input=txt, model=self.model_name) - try: - arr.append(res.data[0].embedding) - tks_num += total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") - return np.array(arr), tks_num + def _call(self, batch): + # Batch like the other OpenAI-style providers: one request per batch + # instead of one request per text. Sort by index so the batched results + # stay aligned with input order. + res = self.client.embeddings.create(input=batch, model=self.model_name) + return [d.embedding for d in _sorted_by_index(res.data)], total_token_count_from_response(res) + + def encode(self, texts: list): + return self._batched_encode(texts, self._call, batch_size=16, truncate_to=self._max_len()) def encode_queries(self, text): - res = self.client.embeddings.create(input=text, model=self.model_name) - try: - return np.array(res.data[0].embedding), total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") + vectors, token_count = self._batched_encode([text], self._call, batch_size=16, truncate_to=self._max_len()) + return vectors[0], token_count class OllamaEmbed(Base): @@ -412,32 +465,33 @@ class OllamaEmbed(Base): self.model_name = model_name self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1))) + @classmethod + def _strip_special(cls, text: str) -> str: + for token in cls._special_tokens: + text = text.replace(token, "") + return text + + def _call(self, batch): + # Batch via client.embed (accepts a list `input`) instead of one + # client.embeddings request per text. 
`truncate=True` lets Ollama clip + # oversized inputs to the model's real context length server-side, which + # is more accurate than a client-side cl100k estimate. + cleaned = [self._strip_special(t) for t in batch] + res = self.client.embed(model=self.model_name, input=cleaned, truncate=True, options={"use_mmap": True}, keep_alive=self.keep_alive) + # Ollama reports real prompt token usage in `prompt_eval_count`; fall + # back to a local count only if the server omits it (never a fixed 128). + tokens = res.get("prompt_eval_count") or 0 + if not tokens: + tokens = sum(num_tokens_from_string(t) for t in cleaned) + return res["embeddings"], tokens + def encode(self, texts: list): - arr = [] - tks_num = 0 - for txt in texts: - # remove special tokens if they exist base on regex in one request - for token in OllamaEmbed._special_tokens: - txt = txt.replace(token, "") - res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive) - try: - arr.append(res["embedding"]) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") - tks_num += 128 - return np.array(arr), tks_num + # No client-side truncation: Ollama truncates to the model context above. 
+ return self._batched_encode(texts, self._call, batch_size=16) def encode_queries(self, text): - # remove special tokens if they exist - for token in OllamaEmbed._special_tokens: - text = text.replace(token, "") - res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive) - try: - return np.array(res["embedding"]), 128 - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") + vectors, token_count = self._batched_encode([text], self._call, batch_size=16) + return vectors[0], token_count class XinferenceEmbed(Base): @@ -448,29 +502,16 @@ class XinferenceEmbed(Base): self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name + def _call(self, batch): + res = self.client.embeddings.create(input=batch, model=self.model_name) + return [d.embedding for d in _sorted_by_index(res.data)], total_token_count_from_response(res) + def encode(self, texts: list): - batch_size = 16 - ress = [] - total_tokens = 0 - for i in range(0, len(texts), batch_size): - res = None - try: - res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name) - ress.extend([d.embedding for d in res.data]) - total_tokens += total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") - return np.array(ress), total_tokens + return self._batched_encode(texts, self._call, batch_size=16) def encode_queries(self, text): - res = None - try: - res = self.client.embeddings.create(input=[text], model=self.model_name) - return np.array(res.data[0].embedding), total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") + vectors, token_count = self._batched_encode([text], self._call, batch_size=16) + return vectors[0], token_count class YoudaoEmbed(Base): @@ -504,55 +545,42 @@ class JinaMultiVecEmbed(Base): self.headers = {"Content-Type": 
"application/json", "Authorization": f"Bearer {key}"} self.model_name = model_name + @staticmethod + def _as_input_item(text): + if isinstance(text, str): + return {"text": text} + # bytes -> base64 encoded image + try: + base64.b64decode(text, validate=True) + return {"image": text.decode("utf8")} + except Exception: + return {"image": base64.b64encode(text).decode("utf8")} + def encode(self, texts: list[str | bytes], task="retrieval.passage"): - batch_size = 16 - ress = [] - token_count = 0 - input = [] - for text in texts: - if isinstance(text, str): - input.append({"text": text}) - elif isinstance(text, bytes): - img_b64s = None - try: - base64.b64decode(text, validate=True) - img_b64s = text.decode("utf8") - except Exception: - img_b64s = base64.b64encode(text).decode("utf8") - input.append({"image": img_b64s}) # base64 encoded image - for i in range(0, len(texts), batch_size): - data = {"model": self.model_name, "input": input[i : i + batch_size]} + def _call(batch): + data = {"model": self.model_name, "input": [self._as_input_item(t) for t in batch]} if "v4" in self.model_name: data["return_multivector"] = True - if "v3" in self.model_name or "v4" in self.model_name: data["task"] = task - data["truncate"] = True - + data["truncate"] = True # let Jina truncate oversized inputs server-side response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30) _raise_model_exception_if_failed(response) - try: - res = response.json() - for d in res["data"]: - if data.get("return_multivector", False): # v4 - token_embs = np.asarray(d["embeddings"], dtype=np.float32) - chunk_emb = token_embs.mean(axis=0) + res = response.json() + embs = [] + for d in res["data"]: + if data.get("return_multivector", False): # v4 + embs.append(np.asarray(d["embeddings"], dtype=np.float32).mean(axis=0)) + else: # v2/v3 + embs.append(np.asarray(d["embedding"], dtype=np.float32)) + return embs, total_token_count_from_response(res) - else: - # v2/v3 - chunk_emb = 
np.asarray(d["embedding"], dtype=np.float32) - - ress.append(chunk_emb) - - token_count += total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, response) - raise Exception(f"Error: {response}") - return np.array(ress), token_count + # Inputs may be image bytes, so token truncation is left to the server. + return self._batched_encode(texts, _call, batch_size=16) def encode_queries(self, text): - embds, cnt = self.encode([text], task="retrieval.query") - return np.array(embds[0]), cnt + vectors, token_count = self.encode([text], task="retrieval.query") + return vectors[0], token_count class MistralEmbed(Base): @@ -568,7 +596,7 @@ class MistralEmbed(Base): import time import random - texts = [truncate(t, 8196) for t in texts] + texts = [truncate(t, DEFAULT_MAX_TOKENS) for t in texts] batch_size = 16 ress = [] token_count = 0 @@ -582,7 +610,8 @@ class MistralEmbed(Base): break except Exception as _e: if retry_max == 1: - log_exception(_e) + logger.exception("MistralEmbed: embedding request failed after retries") + raise EmbeddingError(f"Embedding request failed for MistralEmbed. Error: {_e}") from _e delay = random.uniform(20, 60) time.sleep(delay) retry_max -= 1 @@ -595,11 +624,12 @@ class MistralEmbed(Base): retry_max = 5 while retry_max > 0: try: - res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name) + res = self.client.embeddings(input=[truncate(text, DEFAULT_MAX_TOKENS)], model=self.model_name) return np.array(res.data[0].embedding), total_token_count_from_response(res) except Exception as _e: if retry_max == 1: - log_exception(_e) + logger.exception("MistralEmbed: query embedding request failed after retries") + raise EmbeddingError(f"Embedding request failed for MistralEmbed. 
Error: {_e}") from _e delay = random.randint(20, 60) time.sleep(delay) retry_max -= 1 @@ -649,42 +679,41 @@ class BedrockEmbed(Base): else: # assume_role self.client = boto3.client("bedrock-runtime", region_name=self.bedrock_region) + def _extract_vector(self, model_response): + # Titan returns {"embedding": [...]}; Cohere returns {"embeddings": [[...]]}. + if self.is_cohere: + return model_response["embeddings"][0] + return model_response["embedding"] + def encode(self, texts: list): - texts = [truncate(t, 8196) for t in texts] - embeddings = [] - token_count = 0 - for text in texts: + def _call(batch): + # Titan accepts a single input per call, so batch_size is 1. + text = batch[0] if self.is_amazon: body = {"inputText": text} elif self.is_cohere: body = {"texts": [text], "input_type": "search_document"} - response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) - try: - model_response = json.loads(response["body"].read()) - embeddings.extend([model_response["embedding"]]) - token_count += num_tokens_from_string(text) - except Exception as _e: - log_exception(_e, response) + model_response = json.loads(response["body"].read()) + # Bedrock does not report token usage; count locally. 
+ return [self._extract_vector(model_response)], num_tokens_from_string(text) - return np.array(embeddings), token_count + return self._batched_encode(texts, _call, batch_size=1, truncate_to=DEFAULT_MAX_TOKENS) def encode_queries(self, text): - embeddings = [] + text = truncate(text, DEFAULT_MAX_TOKENS) token_count = num_tokens_from_string(text) if self.is_amazon: - body = {"inputText": truncate(text, 8196)} + body = {"inputText": text} elif self.is_cohere: - body = {"texts": [truncate(text, 8196)], "input_type": "search_query"} - - response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) + body = {"texts": [text], "input_type": "search_query"} try: + response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) model_response = json.loads(response["body"].read()) - embeddings.extend(model_response["embedding"]) + return np.array(self._extract_vector(model_response)), token_count except Exception as _e: - log_exception(_e, response) - - return np.array(embeddings), token_count + logger.exception("BedrockEmbed: query embedding request failed") + raise EmbeddingError(f"Embedding request failed for BedrockEmbed. 
Error: {_e}") from _e class GeminiEmbed(Base): @@ -741,28 +770,17 @@ class GeminiEmbed(Base): return self.types.EmbedContentConfig(task_type=task_type) def encode(self, texts: list): - texts = [truncate(t, 2048) for t in texts] - token_count = sum(num_tokens_from_string(text) for text in texts) config = self._build_embedding_config() - batch_size = 16 - ress = [] - for i in range(0, len(texts), batch_size): - result = None - try: - result = self.client.models.embed_content( - model=self.model_name, - contents=texts[i : i + batch_size], - config=config, - ) - ress.extend(self._parse_embedding_response(result)) - except Exception as _e: - log_exception(_e, result) - raise Exception(f"Error: {result}") - return np.array(ress), token_count + + def _call(batch): + result = self.client.models.embed_content(model=self.model_name, contents=batch, config=config) + # Gemini embeddings do not report token usage; count locally. + return self._parse_embedding_response(result), sum(num_tokens_from_string(t) for t in batch) + + return self._batched_encode(texts, _call, batch_size=16, truncate_to=2048) def encode_queries(self, text): config = self._build_embedding_config() - result = None token_count = num_tokens_from_string(text) try: result = self.client.models.embed_content( @@ -772,8 +790,8 @@ class GeminiEmbed(Base): ) return np.array(self._parse_embedding_response(result)[0]), token_count except Exception as _e: - log_exception(_e, result) - raise Exception(f"Error: {result}") + logger.exception("GeminiEmbed: query embedding request failed") + raise EmbeddingError(f"Embedding request failed for GeminiEmbed. 
Error: {_e}") from _e class NvidiaEmbed(Base): @@ -796,32 +814,24 @@ class NvidiaEmbed(Base): if model_name == "snowflake/arctic-embed-l": self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings" + def _call(self, batch, input_type="query"): + payload = { + "input": batch, + "input_type": input_type, + "model": self.model_name, + "encoding_format": "float", + "truncate": "END", # NVIDIA truncates oversized inputs server-side. + } + response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30) + return self._openai_http_embeddings(response) + def encode(self, texts: list): - batch_size = 16 - ress = [] - token_count = 0 - for i in range(0, len(texts), batch_size): - payload = { - "input": texts[i : i + batch_size], - "input_type": "query", - "model": self.model_name, - "encoding_format": "float", - "truncate": "END", - } - response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30) - _raise_model_exception_if_failed(response) - try: - res = response.json() - ress.extend([d["embedding"] for d in res["data"]]) - token_count += total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, response) - raise Exception(f"Error: {response}") - return np.array(ress), token_count + # NVIDIA NIM expects "passage" for documents (indexing) and "query" for retrieval. 
+ return self._batched_encode(texts, lambda b: self._call(b, "passage"), batch_size=16) def encode_queries(self, text): - embds, cnt = self.encode([text]) - return np.array(embds[0]), cnt + vectors, token_count = self._batched_encode([text], lambda b: self._call(b, "query"), batch_size=16) + return vectors[0], token_count class LmStudioEmbed(LocalAIEmbed): @@ -855,37 +865,32 @@ class CoHereEmbed(Base): self.client = Client(api_key=key) self.model_name = model_name + def _call(self, batch): + res = self.client.embed( + texts=batch, + model=self.model_name, + input_type="search_document", + embedding_types=["float"], + truncate="END", # let Cohere clip oversized inputs server-side instead of hard-failing + ) + return list(res.embeddings.float), total_token_count_from_response(res) + def encode(self, texts: list): - batch_size = 16 - ress = [] - token_count = 0 - for i in range(0, len(texts), batch_size): - res = self.client.embed( - texts=texts[i : i + batch_size], - model=self.model_name, - input_type="search_document", - embedding_types=["float"], - ) - try: - ress.extend([d for d in res.embeddings.float]) - token_count += total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") - return np.array(ress), token_count + return self._batched_encode(texts, self._call, batch_size=16) def encode_queries(self, text): - res = self.client.embed( - texts=[text], - model=self.model_name, - input_type="search_query", - embedding_types=["float"], - ) try: + res = self.client.embed( + texts=[text], + model=self.model_name, + input_type="search_query", + embedding_types=["float"], + truncate="END", + ) return np.array(res.embeddings.float[0]), int(total_token_count_from_response(res)) except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") + logger.exception("CoHereEmbed: query embedding request failed") + raise EmbeddingError(f"Embedding request failed for CoHereEmbed. 
Error: {_e}") from _e class TogetherAIEmbed(OpenAIEmbed): @@ -932,49 +937,27 @@ class SILICONFLOWEmbed(Base): self.base_url = normalized_base_url self.model_name = model_name - def encode(self, texts: list): - batch_size = 16 - ress = [] - token_count = 0 - for i in range(0, len(texts), batch_size): - texts_batch = texts[i : i + batch_size] - if self.model_name in ["BAAI/bge-large-zh-v1.5", "BAAI/bge-large-en-v1.5"]: - # limit 512, 340 is almost safe - texts_batch = [" " if not text.strip() else truncate(text, 256) for text in texts_batch] - else: - texts_batch = [" " if not text.strip() else text for text in texts_batch] + def _clean_batch(self, batch): + if self.model_name in ["BAAI/bge-large-zh-v1.5", "BAAI/bge-large-en-v1.5"]: + # limit 512, 340 is almost safe + return [" " if not text.strip() else truncate(text, 256) for text in batch] + return [" " if not text.strip() else text for text in batch] - payload = { - "model": self.model_name, - "input": texts_batch, - "encoding_format": "float", - } - response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30) - _raise_model_exception_if_failed(response) - try: - res = response.json() - ress.extend([d["embedding"] for d in res["data"]]) - token_count += total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, response) - raise Exception(f"Error: {response}") - - return np.array(ress), token_count - - def encode_queries(self, text): + def _call(self, batch): payload = { "model": self.model_name, - "input": text, + "input": self._clean_batch(batch), "encoding_format": "float", } response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30) - _raise_model_exception_if_failed(response) - try: - res = response.json() - return np.array(res["data"][0]["embedding"]), total_token_count_from_response(res) - except Exception as _e: - log_exception(_e, response) - raise Exception(f"Error: {response}") + return 
self._openai_http_embeddings(response) + + def encode(self, texts: list): + return self._batched_encode(texts, self._call, batch_size=16) + + def encode_queries(self, text): + vectors, token_count = self._batched_encode([text], self._call, batch_size=16) + return vectors[0], token_count class ReplicateEmbed(Base): @@ -1013,26 +996,26 @@ class BaiduYiyanEmbed(Base): self.model_name = model_name def encode(self, texts: list, batch_size=16): - res = self.client.do(model=self.model_name, texts=texts).body try: + res = self.client.do(model=self.model_name, texts=texts).body return ( np.array([r["embedding"] for r in res["data"]]), total_token_count_from_response(res), ) except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") + logger.exception("BaiduYiyanEmbed: embedding request failed") + raise EmbeddingError(f"Embedding request failed for BaiduYiyanEmbed. Error: {_e}") from _e def encode_queries(self, text): - res = self.client.do(model=self.model_name, texts=[text]).body try: + res = self.client.do(model=self.model_name, texts=[text]).body return ( np.array([r["embedding"] for r in res["data"]]), total_token_count_from_response(res), ) except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") + logger.exception("BaiduYiyanEmbed: query embedding request failed") + raise EmbeddingError(f"Embedding request failed for BaiduYiyanEmbed. Error: {_e}") from _e class VoyageEmbed(Base): @@ -1044,27 +1027,22 @@ class VoyageEmbed(Base): self.client = voyageai.Client(api_key=key) self.model_name = model_name + def _call(self, batch): + res = self.client.embed(texts=batch, model=self.model_name, input_type="document") + # `_batched_encode` accumulates these per-batch vectors and returns a + # single np.ndarray, so encode() keeps the np.ndarray contract. 
+ return res.embeddings, res.total_tokens + def encode(self, texts: list): - batch_size = 16 - ress = [] - token_count = 0 - for i in range(0, len(texts), batch_size): - res = self.client.embed(texts=texts[i : i + batch_size], model=self.model_name, input_type="document") - try: - ress.extend(res.embeddings) - token_count += res.total_tokens - except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") - return np.array(ress), token_count + return self._batched_encode(texts, self._call, batch_size=16) def encode_queries(self, text): - res = self.client.embed(texts=text, model=self.model_name, input_type="query") try: + res = self.client.embed(texts=text, model=self.model_name, input_type="query") return np.array(res.embeddings)[0], res.total_tokens except Exception as _e: - log_exception(_e, res) - raise Exception(f"Error: {res}") + logger.exception("VoyageEmbed: query embedding request failed") + raise EmbeddingError(f"Embedding request failed for VoyageEmbed. Error: {_e}") from _e class HuggingFaceEmbed(Base): @@ -1080,14 +1058,13 @@ class HuggingFaceEmbed(Base): def encode(self, texts: list): response = requests.post(f"{self.base_url}/embed", json={"inputs": texts}, headers={"Content-Type": "application/json"}, timeout=30) _raise_model_exception_if_failed(response) - embeddings = response.json() - return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts]) + # TEI auto-truncates oversized inputs, so no client-side truncation is needed. 
+ return np.array(response.json()), sum([num_tokens_from_string(text) for text in texts]) def encode_queries(self, text: str): response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"}, timeout=30) _raise_model_exception_if_failed(response) - embedding = response.json()[0] - return np.array(embedding), num_tokens_from_string(text) + return np.array(response.json()[0]), num_tokens_from_string(text) class VolcEngineEmbed(Base): @@ -1141,13 +1118,14 @@ class VolcEngineEmbed(Base): request_body = {"model": self.model_name, "input": [{"type": "text", "text": text}]} response = sync_request(method="POST", url=url, headers=headers, json=request_body, timeout=60) if response.status_code != 200: - raise Exception(f"Error: {response.status_code} - {response.text}") + raise EmbeddingError(f"Embedding request failed for VolcEngineEmbed. Error: {response.status_code} - {response.text}") result = response.json() try: ress.append(self._extract_embedding(result)) total_tokens += total_token_count_from_response(result) except Exception as _e: - log_exception(_e) + logger.exception("VolcEngineEmbed: failed to parse embedding response") + raise EmbeddingError(f"Embedding request failed for VolcEngineEmbed. Error: {_e}; response={result}") from _e return np.array(ress), total_tokens @@ -1295,8 +1273,8 @@ class PerplexityEmbed(Base): ress.append(self._decode_base64_int8(chunk_emb["embedding"])) token_count += res.get("usage", {}).get("total_tokens", 0) except Exception as _e: - log_exception(_e, response) - raise Exception(f"Error: {response.text}") + logger.exception("PerplexityEmbed: failed to parse contextualized embedding response") + raise EmbeddingError(f"Embedding request failed for PerplexityEmbed. 
Error: {response.text}") from _e else: url = f"{self.base_url}/v1/embeddings" for i in range(0, len(texts), batch_size): @@ -1314,8 +1292,8 @@ class PerplexityEmbed(Base): ress.append(self._decode_base64_int8(d["embedding"])) token_count += res.get("usage", {}).get("total_tokens", 0) except Exception as _e: - log_exception(_e, response) - raise Exception(f"Error: {response.text}") + logger.exception("PerplexityEmbed: failed to parse embedding response") + raise EmbeddingError(f"Embedding request failed for PerplexityEmbed. Error: {response.text}") from _e return np.array(ress), token_count diff --git a/test/unit_test/conftest.py b/test/unit_test/conftest.py new file mode 100644 index 0000000000..bd5e7afe6c --- /dev/null +++ b/test/unit_test/conftest.py @@ -0,0 +1,50 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Shared setup for RAGFlow unit tests. + +Several parsers and the chunking pipeline tokenize text with NLTK, which needs +the ``punkt_tab`` and ``wordnet`` data sets. Production provisions these via +``download_deps.py`` (into ``nltk_data``, exported as ``NLTK_DATA`` by +``docker/launch_backend_service.sh``) and ``api.validation`` at startup, but the +unit-test runner has neither. Without the data, tokenizer-backed tests such as +``test_epub_parser`` and ``test_dataflow_service`` fail with +``LookupError: Resource 'punkt_tab' not found``. 
Make sure the data is reachable +before any test imports a tokenizer: reuse a provisioned ``nltk_data`` directory +when present, and download only what is still missing. +""" + +import os + +import nltk + +# Reuse data already fetched by download_deps.py (the directory the app exports +# as NLTK_DATA) so provisioned environments do not download it again. +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) +_LOCAL_NLTK_DATA = os.path.join(_REPO_ROOT, "nltk_data") +if os.path.isdir(_LOCAL_NLTK_DATA) and _LOCAL_NLTK_DATA not in nltk.data.path: + nltk.data.path.insert(0, _LOCAL_NLTK_DATA) + +# (download name, resource path used by nltk.data.find) +_REQUIRED_NLTK_DATA = ( + ("punkt_tab", "tokenizers/punkt_tab"), + ("wordnet", "corpora/wordnet"), +) +for _name, _find_path in _REQUIRED_NLTK_DATA: + try: + nltk.data.find(_find_path) + except LookupError: + nltk.download(_name, quiet=True) diff --git a/test/unit_test/rag/llm/test_embedding_model.py b/test/unit_test/rag/llm/test_embedding_model.py new file mode 100644 index 0000000000..cbd0d0d7ee --- /dev/null +++ b/test/unit_test/rag/llm/test_embedding_model.py @@ -0,0 +1,383 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +"""Tests for the embedding-provider fixes in ``rag.llm.embedding_model``: + +* a failing embedding call raises a single deterministic, informative + ``EmbeddingError`` (and the previous unreachable ``raise Exception(f"Error: {res}")`` + can no longer mask it, regardless of whether the SDK response exposes ``.text``); +* token counts reflect real usage, or an honest local fallback — never the old + fabricated ``1024`` / ``+= 128`` constants; +* inputs at the truncation boundary are not pushed past the model token limit + (the old ``8196`` overshoot is gone); +* ``ZhipuEmbed`` / ``OllamaEmbed`` now batch — ``ceil(n / batch_size)`` requests + with input order and output shape preserved. +""" + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from rag.llm.embedding_model import ( + DEFAULT_MAX_TOKENS, + BedrockEmbed, + EmbeddingError, + LocalAIEmbed, + MistralEmbed, + NvidiaEmbed, + OllamaEmbed, + OpenAIEmbed, + ZhipuEmbed, +) +from common.exceptions import ModelException +from common.token_utils import num_tokens_from_string + + +# --------------------------------------------------------------------------- # +# Fakes +# --------------------------------------------------------------------------- # +class _OpenAIResp: + """Minimal stand-in for an OpenAI embeddings response. + + Unlike ``MagicMock`` it does NOT auto-create a ``usage`` attribute, so + ``total_token_count_from_response`` correctly returns 0 when ``total_tokens`` + is not supplied (exercising the local-count fallback paths). 
+ """ + + def __init__(self, vectors, total_tokens=None): + self.data = [SimpleNamespace(embedding=list(v)) for v in vectors] + if total_tokens is not None: + self.usage = SimpleNamespace(total_tokens=total_tokens) + + +def _openai_create(total_tokens=None, dim=3): + """Build a side_effect that returns one vector per input text.""" + + def _create(input, model, **kwargs): + return _OpenAIResp([[float(i)] * dim for i in range(len(input))], total_tokens=total_tokens) + + return _create + + +def _make_openai(cls=OpenAIEmbed, total_tokens=None): + embed = cls("key", "text-embedding-3-small", base_url="https://example.invalid/v1") + embed.client = MagicMock() + embed.client.embeddings.create = MagicMock(side_effect=_openai_create(total_tokens=total_tokens)) + return embed + + +# --------------------------------------------------------------------------- # +# 1. Deterministic, informative error handling (the masked-error bug) +# --------------------------------------------------------------------------- # +class _BadRespWithText: + """Parsing this raises; it also exposes ``.text`` — which the old + ``log_exception(_e, res)`` path would have re-raised as a bare + ``Exception(text)``, masking the intended error non-deterministically.""" + + text = "Internal Server Error" + + @property + def data(self): + raise ValueError("malformed response payload") + + +class _BadRespNoText: + @property + def data(self): + raise ValueError("malformed response payload") + + +@pytest.mark.p1 +class TestDeterministicErrors: + def test_api_error_raises_embedding_error(self): + embed = _make_openai() + embed.client.embeddings.create = MagicMock(side_effect=RuntimeError("503 upstream down")) + with pytest.raises(EmbeddingError) as exc: + embed.encode(["hello"]) + # Informative: surfaces the underlying detail and contains "Error". 
+ assert "503 upstream down" in str(exc.value) + assert "Error" in str(exc.value) + assert "OpenAIEmbed" in str(exc.value) + + def test_same_exception_type_with_and_without_text_attr(self): + """The surfaced exception must NOT depend on whether the response object + exposes ``.text`` (the old non-determinism). Both variants -> EmbeddingError.""" + with_text = _make_openai() + with_text.client.embeddings.create = MagicMock(return_value=_BadRespWithText()) + without_text = _make_openai() + without_text.client.embeddings.create = MagicMock(return_value=_BadRespNoText()) + + with pytest.raises(EmbeddingError) as e1: + with_text.encode(["x"]) + with pytest.raises(EmbeddingError) as e2: + without_text.encode(["x"]) + + # Deterministic: same type, and the response's ``.text`` did not hijack it. + assert type(e1.value) is type(e2.value) is EmbeddingError + assert "Internal Server Error" not in str(e1.value) + assert "malformed response payload" in str(e1.value) + + def test_query_path_also_deterministic(self): + embed = _make_openai() + embed.client.embeddings.create = MagicMock(side_effect=RuntimeError("nope")) + with pytest.raises(EmbeddingError): + embed.encode_queries("hi") + + def test_http_bad_status_raises_model_exception_with_body(self): + """A bad HTTP status surfaces the response body via a retryable-aware + ModelException, which the API error handler understands.""" + embed = NvidiaEmbed("key", "nvidia/nv-embed-v1") + bad = MagicMock() + bad.status_code = 400 + bad.text = '{"error": "bad request: empty input"}' + with patch("rag.llm.embedding_model.requests.post", return_value=bad): + with pytest.raises(ModelException) as exc: + embed.encode(["hello"]) + assert "bad request: empty input" in str(exc.value) + + def test_http_malformed_ok_response_raises_embedding_error(self): + """A 200 response with an unexpected body still yields a deterministic + EmbeddingError carrying the payload detail.""" + embed = NvidiaEmbed("key", "nvidia/nv-embed-v1") + bad = 
MagicMock() + bad.status_code = 200 + bad.json.return_value = {"unexpected": "shape"} + with patch("rag.llm.embedding_model.requests.post", return_value=bad): + with pytest.raises(EmbeddingError) as exc: + embed.encode(["hello"]) + assert "unexpected" in str(exc.value) + + +# --------------------------------------------------------------------------- # +# 2. Token accounting (no fabricated 1024 / += 128) +# --------------------------------------------------------------------------- # +@pytest.mark.p1 +class TestTokenAccounting: + def test_openai_uses_reported_usage(self): + embed = _make_openai(total_tokens=42) + _, tokens = embed.encode(["a", "b"]) + assert tokens == 42 + + def test_localai_falls_back_to_local_count_not_1024(self): + embed = _make_openai(cls=LocalAIEmbed) # no usage in response + texts = ["hello world", "second chunk of text"] + _, tokens = embed.encode(texts) + expected = sum(num_tokens_from_string(t) for t in texts) + assert tokens == expected + assert tokens != 1024 # the old fabricated constant + + def test_ollama_uses_prompt_eval_count_not_128(self): + embed = OllamaEmbed("x", "nomic-embed-text", base_url="http://localhost:11434") + embed.client = MagicMock() + embed.client.embed = MagicMock(return_value={"embeddings": [[0.1, 0.2], [0.3, 0.4]], "prompt_eval_count": 33}) + _, tokens = embed.encode(["aaa", "bbb"]) + assert tokens == 33 + assert tokens != 128 * 2 # the old fabricated per-text constant + + def test_ollama_token_fallback_when_server_omits_count(self): + embed = OllamaEmbed("x", "nomic-embed-text", base_url="http://localhost:11434") + embed.client = MagicMock() + # No prompt_eval_count reported -> honest local count, not a fixed number. + embed.client.embed = MagicMock(return_value={"embeddings": [[0.1, 0.2]]}) + texts = ["some text to embed"] + _, tokens = embed.encode(texts) + assert tokens == sum(num_tokens_from_string(t) for t in texts) + + +# --------------------------------------------------------------------------- # +# 3. 
Truncation boundary (no 8196 overshoot) +# --------------------------------------------------------------------------- # +@pytest.mark.p2 +class TestTruncationBoundary: + def test_default_limit_is_8192(self): + assert DEFAULT_MAX_TOKENS == 8192 + + def test_openai_input_truncated_below_model_limit(self): + embed = _make_openai(total_tokens=1) + # An input far above the 8K ceiling. + huge = "word " * 12000 + embed.encode([huge]) + sent = embed.client.embeddings.create.call_args.kwargs["input"][0] + # Truncated to the documented 8191 ceiling, never above the 8192 model limit. + assert num_tokens_from_string(sent) <= 8191 + assert num_tokens_from_string(sent) <= DEFAULT_MAX_TOKENS + + def test_mistral_truncates_to_8192_not_8196(self): + embed = MistralEmbed.__new__(MistralEmbed) + embed.model_name = "mistral-embed" + captured = {} + + def _embeddings(input, model): + captured["input"] = input + return _OpenAIResp([[0.0, 0.0]], total_tokens=1) + + embed.client = MagicMock() + embed.client.embeddings = MagicMock(side_effect=_embeddings) + huge = "word " * 12000 + embed.encode([huge]) + assert num_tokens_from_string(captured["input"][0]) <= DEFAULT_MAX_TOKENS + + +# --------------------------------------------------------------------------- # +# 4. Batching for Zhipu and Ollama (ceil(n / batch_size) requests) +# --------------------------------------------------------------------------- # +@pytest.mark.p1 +class TestBatching: + def test_zhipu_batches_instead_of_per_text(self): + embed = ZhipuEmbed("key", "embedding-3") + embed.client = MagicMock() + embed.client.embeddings.create = MagicMock(side_effect=_openai_create(total_tokens=5)) + texts = [f"t{i}" for i in range(3)] + vectors, _ = embed.encode(texts) + # One request for 3 texts (batch_size 16) — NOT three per-text requests. 
+ assert embed.client.embeddings.create.call_count == 1 + assert vectors.shape[0] == 3 + + def test_zhipu_issues_ceil_n_over_batch_calls(self): + embed = ZhipuEmbed("key", "embedding-3") + embed.client = MagicMock() + embed.client.embeddings.create = MagicMock(side_effect=_openai_create(total_tokens=5)) + texts = [f"t{i}" for i in range(20)] # batch_size 16 -> ceil(20/16) == 2 + vectors, _ = embed.encode(texts) + assert embed.client.embeddings.create.call_count == 2 + assert vectors.shape[0] == 20 + + def test_ollama_batches_and_preserves_order(self): + embed = OllamaEmbed("x", "nomic-embed-text", base_url="http://localhost:11434") + embed.client = MagicMock() + + def _embed(model, input, **kwargs): + # Echo a recognisable vector per input so order can be checked. + return {"embeddings": [[float(len(t))] for t in input], "prompt_eval_count": 1} + + embed.client.embed = MagicMock(side_effect=_embed) + texts = ["a", "bb", "ccc"] + vectors, _ = embed.encode(texts) + + # One batched request, not one per text. + assert embed.client.embed.call_count == 1 + assert vectors.shape == (3, 1) + # Order preserved: vector value equals input length. 
+ np.testing.assert_array_equal(vectors[:, 0], np.array([1.0, 2.0, 3.0])) + + def test_zhipu_realigns_out_of_order_response(self): + """If the provider returns embeddings out of order, the per-item `index` + must realign them with the input — otherwise chunks get wrong vectors.""" + embed = ZhipuEmbed("key", "embedding-3") + embed.client = MagicMock() + + def _create(input, model, **kwargs): + data = [SimpleNamespace(embedding=[float(i)], index=i) for i in range(len(input))] + return SimpleNamespace(data=list(reversed(data)), usage=SimpleNamespace(total_tokens=1)) + + embed.client.embeddings.create = MagicMock(side_effect=_create) + vectors, _ = embed.encode(["t0", "t1", "t2"]) + np.testing.assert_array_equal(vectors[:, 0], np.array([0.0, 1.0, 2.0])) + + def test_nvidia_http_realigns_out_of_order_response(self): + embed = NvidiaEmbed("key", "nvidia/nv-embed-v1") + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = { + "data": [ + {"index": 2, "embedding": [2.0]}, + {"index": 0, "embedding": [0.0]}, + {"index": 1, "embedding": [1.0]}, + ], + "usage": {"total_tokens": 3}, + } + with patch("rag.llm.embedding_model.requests.post", return_value=resp): + vectors, _ = embed.encode(["a", "b", "c"]) + np.testing.assert_array_equal(vectors[:, 0], np.array([0.0, 1.0, 2.0])) + + def test_ollama_issues_ceil_n_over_batch_calls(self): + embed = OllamaEmbed("x", "nomic-embed-text", base_url="http://localhost:11434") + embed.client = MagicMock() + embed.client.embed = MagicMock(side_effect=lambda model, input, **kw: {"embeddings": [[0.0] for _ in input], "prompt_eval_count": 1}) + texts = [f"t{i}" for i in range(20)] # batch_size 16 -> 2 calls + vectors, _ = embed.encode(texts) + assert embed.client.embed.call_count == 2 + assert vectors.shape[0] == 20 + + +# --------------------------------------------------------------------------- # +# 5. 
Provider-specific request/response shapes +# --------------------------------------------------------------------------- # +@pytest.mark.p2 +class TestNvidiaInputType: + """NVIDIA NIM expects input_type=passage for documents and =query for queries; + using "query" for documents degrades retrieval (asymmetric embeddings).""" + + def _mock_resp(self): + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = {"data": [{"index": 0, "embedding": [1.0]}], "usage": {"total_tokens": 1}} + return resp + + def test_documents_use_passage(self): + embed = NvidiaEmbed("key", "nvidia/nv-embed-v1") + with patch("rag.llm.embedding_model.requests.post", return_value=self._mock_resp()) as post: + embed.encode(["a document"]) + assert post.call_args.kwargs["json"]["input_type"] == "passage" + + def test_queries_use_query(self): + embed = NvidiaEmbed("key", "nvidia/nv-embed-v1") + with patch("rag.llm.embedding_model.requests.post", return_value=self._mock_resp()) as post: + embed.encode_queries("a query") + assert post.call_args.kwargs["json"]["input_type"] == "query" + + +@pytest.mark.p2 +class TestBedrockResponseParsing: + """Bedrock Titan returns {"embedding": [...]}; Cohere returns + {"embeddings": [[...]]}. 
Both must parse without KeyError.""" + + @staticmethod + def _make(model_prefix): + embed = BedrockEmbed.__new__(BedrockEmbed) + embed.model_name = f"{model_prefix}.embed-model" + embed.is_amazon = model_prefix == "amazon" + embed.is_cohere = model_prefix == "cohere" + embed.client = MagicMock() + return embed + + @staticmethod + def _body(payload): + body = MagicMock() + body.read.return_value = json.dumps(payload).encode() + return {"body": body} + + def test_cohere_reads_embeddings_plural(self): + embed = self._make("cohere") + embed.client.invoke_model.return_value = self._body({"embeddings": [[1.0, 2.0]]}) + vectors, _ = embed.encode(["hello"]) + assert vectors.shape == (1, 2) + np.testing.assert_array_equal(vectors[0], np.array([1.0, 2.0])) + + def test_amazon_reads_embedding_singular(self): + embed = self._make("amazon") + embed.client.invoke_model.return_value = self._body({"embedding": [3.0, 4.0]}) + vectors, _ = embed.encode(["hello"]) + np.testing.assert_array_equal(vectors[0], np.array([3.0, 4.0])) + + def test_cohere_query_reads_embeddings_plural(self): + embed = self._make("cohere") + embed.client.invoke_model.return_value = self._body({"embeddings": [[5.0, 6.0]]}) + vector, _ = embed.encode_queries("q") + np.testing.assert_array_equal(vector, np.array([5.0, 6.0]))