fix(llm): correct error handling, token accounting, and truncation in embedding providers (#15424)

### Summary

Closes #15423

`rag/llm/embedding_model.py` hosts about 40 embedding providers that
shared several defects affecting indexing reliability, cost accounting,
and error visibility. This PR fixes four concrete bugs.

**Masked, inconsistent errors (27 sites).** Nearly every provider ran
`log_exception(_e, res)` followed by `raise Exception(f"Error: {res}")`.
Because `log_exception` always raises, the second line was dead code,
and the surfaced exception varied with whether the SDK response exposed
a `.text` attribute. Every failure path now raises a single
`EmbeddingError` that includes the underlying response detail, so the
cause of a failed embedding is consistent and visible.
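
As a rough illustration of the single error path, the sketch below wraps any provider failure in one exception type and chains the original cause. The `EmbeddingError` class here is a stand-in defined for the example, and `call_provider` and `failing_request` are hypothetical names, not code from this PR.

```python
class EmbeddingError(Exception):
    """Stand-in for the PR's exception type; the real definition is not shown here."""


def call_provider(provider_name, request_fn):
    """Run request_fn and convert any failure into one EmbeddingError."""
    try:
        return request_fn()
    except Exception as exc:
        # Chain the original exception so the root cause stays in the traceback.
        raise EmbeddingError(f"{provider_name} Error: {exc}") from exc


def failing_request():
    # Hypothetical upstream failure used only to exercise the wrapper.
    raise RuntimeError("503 upstream down")


try:
    call_provider("OpenAIEmbed", failing_request)
except EmbeddingError as err:
    message = str(err)
    cause = err.__cause__
```

Because there is only one `raise`, the surfaced type no longer depends on the shape of the response object.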

**Fabricated token counts.** `LocalAIEmbed` returned a hardcoded `1024`
and `OllamaEmbed` added `128` per text. These values feed `used_tokens`
and therefore billing and usage tracking. Both now report the real count
from the API (Ollama `prompt_eval_count`, LocalAI `usage`) and fall back
to a local token count only when the server omits it.
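
A minimal sketch of the "reported count first, local count as fallback" rule, assuming a dict-shaped response like the one Ollama returns. `count_tokens` is a hypothetical whitespace tokenizer standing in for the project's real token counter, and `used_tokens` is an illustrative name.

```python
def count_tokens(text):
    # Hypothetical stand-in tokenizer: whitespace split, for illustration only.
    return len(text.split())


def used_tokens(response, texts):
    """Prefer the server-reported count; fall back to a local count."""
    reported = response.get("prompt_eval_count")
    if reported is not None:
        return int(reported)
    # The server omitted the count, so estimate it locally instead of
    # returning a fixed constant.
    return sum(count_tokens(t) for t in texts)


texts = ["some text to embed", "second chunk"]
with_count = used_tokens({"embeddings": [[0.1], [0.2]], "prompt_eval_count": 33}, texts)
without_count = used_tokens({"embeddings": [[0.1], [0.2]]}, texts)
```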

**Truncation overshoot.** The `8196` limit used by Mistral and Bedrock
exceeded the standard `8192` ceiling and could push boundary-sized
inputs past the model limit. Limits are corrected to `8192` and chosen
deliberately per provider, and providers that rely on server-side
truncation now request it explicitly (Ollama `truncate=True`, Cohere
`truncate="END"`).
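
The overshoot is easy to see with plain arithmetic. The sketch below truncates a hypothetical token sequence with the old and the corrected limit; `truncate_tokens` is an illustrative helper and the integer list stands in for real tokenizer output.

```python
MAX_TOKENS = 8192  # ceiling named in the text above


def truncate_tokens(tokens, limit):
    """Keep at most `limit` tokens from the front of the sequence."""
    return tokens[:limit]


tokens = list(range(8200))  # hypothetical input slightly over the ceiling
old = truncate_tokens(tokens, 8196)  # old limit: result still exceeds the ceiling
new = truncate_tokens(tokens, MAX_TOKENS)
overshoot = len(old) - MAX_TOKENS
```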

**Missing batching on Zhipu and Ollama.** Both issued one request per
text. They now batch like the other OpenAI-compatible providers, turning
N round trips into `ceil(N / batch_size)`. Batched results are realigned
by response `index` so a chunk always keeps its own vector.
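
The batching and realignment can be sketched as follows, assuming each response item carries an `index` and an `embedding`. `batched_embed` and `fake_send` are hypothetical names for this example and are not the PR's `Base._batched_encode`.

```python
import math


def batched_embed(texts, batch_size, send_batch):
    """Embed texts in batches, realigning each batch by response index."""
    vectors = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        items = send_batch(batch)
        # Providers may return items in any order; sort by their reported index.
        ordered = sorted(items, key=lambda item: item["index"])
        vectors.extend(item["embedding"] for item in ordered)
    return vectors


calls = []


def fake_send(batch):
    # Hypothetical provider: returns items reversed to simulate reordering.
    calls.append(len(batch))
    items = [{"index": i, "embedding": [float(len(t))]} for i, t in enumerate(batch)]
    return list(reversed(items))


texts = ["x" * (i + 1) for i in range(20)]
vectors = batched_embed(texts, 16, fake_send)
expected_calls = math.ceil(len(texts) / 16)
```

With 20 texts and a batch size of 16, the fake provider is called twice, and each vector still matches the length of its own input text.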

A shared `Base._batched_encode` helper owns the batch loop, optional
truncation, result accumulation, and the single error path. It is the
mechanism that lets these fixes live in one place instead of across 27
duplicated sites. The public `encode()` and `encode_queries()` contract
stays the same, so existing callers are unaffected.

Tests covering all four fixes are added under
`test/unit_test/rag/llm/test_embedding_model.py`.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Author: Dexterity (committed by GitHub)
Date: 2026-06-11 07:29:46 -04:00
Parent: ec89fc036d
Commit: bde2b1fc6d
4 changed files with 733 additions and 316 deletions


@@ -0,0 +1,50 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Shared setup for RAGFlow unit tests.

Several parsers and the chunking pipeline tokenize text with NLTK, which needs
the ``punkt_tab`` and ``wordnet`` data sets. Production provisions these via
``download_deps.py`` (into ``nltk_data``, exported as ``NLTK_DATA`` by
``docker/launch_backend_service.sh``) and ``api.validation`` at startup, but the
unit-test runner has neither. Without the data, tokenizer-backed tests such as
``test_epub_parser`` and ``test_dataflow_service`` fail with
``LookupError: Resource 'punkt_tab' not found``. Make sure the data is reachable
before any test imports a tokenizer: reuse a provisioned ``nltk_data`` directory
when present, and download only what is still missing.
"""

import os

import nltk

# Reuse data already fetched by download_deps.py (the directory the app exports
# as NLTK_DATA) so provisioned environments do not download it again.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
_LOCAL_NLTK_DATA = os.path.join(_REPO_ROOT, "nltk_data")
if os.path.isdir(_LOCAL_NLTK_DATA) and _LOCAL_NLTK_DATA not in nltk.data.path:
    nltk.data.path.insert(0, _LOCAL_NLTK_DATA)

# (download name, resource path used by nltk.data.find)
_REQUIRED_NLTK_DATA = (
    ("punkt_tab", "tokenizers/punkt_tab"),
    ("wordnet", "corpora/wordnet"),
)
for _name, _find_path in _REQUIRED_NLTK_DATA:
    try:
        nltk.data.find(_find_path)
    except LookupError:
        nltk.download(_name, quiet=True)


@@ -0,0 +1,383 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for the embedding-provider fixes in ``rag.llm.embedding_model``:

* a failing embedding call raises a single deterministic, informative
  ``EmbeddingError`` (and the previous unreachable ``raise Exception(f"Error: {res}")``
  can no longer mask it, regardless of whether the SDK response exposes ``.text``);
* token counts reflect real usage, or an honest local fallback, never the old
  fabricated ``1024`` / ``+= 128`` constants;
* inputs at the truncation boundary are not pushed past the model token limit
  (the old ``8196`` overshoot is gone);
* ``ZhipuEmbed`` / ``OllamaEmbed`` now batch: ``ceil(n / batch_size)`` requests
  with input order and output shape preserved.
"""

import json
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from rag.llm.embedding_model import (
    DEFAULT_MAX_TOKENS,
    BedrockEmbed,
    EmbeddingError,
    LocalAIEmbed,
    MistralEmbed,
    NvidiaEmbed,
    OllamaEmbed,
    OpenAIEmbed,
    ZhipuEmbed,
)
from common.exceptions import ModelException
from common.token_utils import num_tokens_from_string

# --------------------------------------------------------------------------- #
# Fakes
# --------------------------------------------------------------------------- #
class _OpenAIResp:
    """Minimal stand-in for an OpenAI embeddings response.

    Unlike ``MagicMock`` it does NOT auto-create a ``usage`` attribute, so
    ``total_token_count_from_response`` correctly returns 0 when ``total_tokens``
    is not supplied (exercising the local-count fallback paths).
    """

    def __init__(self, vectors, total_tokens=None):
        self.data = [SimpleNamespace(embedding=list(v)) for v in vectors]
        if total_tokens is not None:
            self.usage = SimpleNamespace(total_tokens=total_tokens)


def _openai_create(total_tokens=None, dim=3):
    """Build a side_effect that returns one vector per input text."""

    def _create(input, model, **kwargs):
        return _OpenAIResp([[float(i)] * dim for i in range(len(input))], total_tokens=total_tokens)

    return _create


def _make_openai(cls=OpenAIEmbed, total_tokens=None):
    embed = cls("key", "text-embedding-3-small", base_url="https://example.invalid/v1")
    embed.client = MagicMock()
    embed.client.embeddings.create = MagicMock(side_effect=_openai_create(total_tokens=total_tokens))
    return embed

# --------------------------------------------------------------------------- #
# 1. Deterministic, informative error handling (the masked-error bug)
# --------------------------------------------------------------------------- #
class _BadRespWithText:
    """Parsing this raises; it also exposes ``.text``, which the old
    ``log_exception(_e, res)`` path would have re-raised as a bare
    ``Exception(text)``, masking the intended error non-deterministically."""

    text = "Internal Server Error"

    @property
    def data(self):
        raise ValueError("malformed response payload")


class _BadRespNoText:
    @property
    def data(self):
        raise ValueError("malformed response payload")


@pytest.mark.p1
class TestDeterministicErrors:
    def test_api_error_raises_embedding_error(self):
        embed = _make_openai()
        embed.client.embeddings.create = MagicMock(side_effect=RuntimeError("503 upstream down"))
        with pytest.raises(EmbeddingError) as exc:
            embed.encode(["hello"])
        # Informative: surfaces the underlying detail and contains "Error".
        assert "503 upstream down" in str(exc.value)
        assert "Error" in str(exc.value)
        assert "OpenAIEmbed" in str(exc.value)

    def test_same_exception_type_with_and_without_text_attr(self):
        """The surfaced exception must NOT depend on whether the response object
        exposes ``.text`` (the old non-determinism). Both variants -> EmbeddingError."""
        with_text = _make_openai()
        with_text.client.embeddings.create = MagicMock(return_value=_BadRespWithText())
        without_text = _make_openai()
        without_text.client.embeddings.create = MagicMock(return_value=_BadRespNoText())
        with pytest.raises(EmbeddingError) as e1:
            with_text.encode(["x"])
        with pytest.raises(EmbeddingError) as e2:
            without_text.encode(["x"])
        # Deterministic: same type, and the response's ``.text`` did not hijack it.
        assert type(e1.value) is type(e2.value) is EmbeddingError
        assert "Internal Server Error" not in str(e1.value)
        assert "malformed response payload" in str(e1.value)

    def test_query_path_also_deterministic(self):
        embed = _make_openai()
        embed.client.embeddings.create = MagicMock(side_effect=RuntimeError("nope"))
        with pytest.raises(EmbeddingError):
            embed.encode_queries("hi")

    def test_http_bad_status_raises_model_exception_with_body(self):
        """A bad HTTP status surfaces the response body via a retryable-aware
        ModelException, which the API error handler understands."""
        embed = NvidiaEmbed("key", "nvidia/nv-embed-v1")
        bad = MagicMock()
        bad.status_code = 400
        bad.text = '{"error": "bad request: empty input"}'
        with patch("rag.llm.embedding_model.requests.post", return_value=bad):
            with pytest.raises(ModelException) as exc:
                embed.encode(["hello"])
        assert "bad request: empty input" in str(exc.value)

    def test_http_malformed_ok_response_raises_embedding_error(self):
        """A 200 response with an unexpected body still yields a deterministic
        EmbeddingError carrying the payload detail."""
        embed = NvidiaEmbed("key", "nvidia/nv-embed-v1")
        bad = MagicMock()
        bad.status_code = 200
        bad.json.return_value = {"unexpected": "shape"}
        with patch("rag.llm.embedding_model.requests.post", return_value=bad):
            with pytest.raises(EmbeddingError) as exc:
                embed.encode(["hello"])
        assert "unexpected" in str(exc.value)

# --------------------------------------------------------------------------- #
# 2. Token accounting (no fabricated 1024 / += 128)
# --------------------------------------------------------------------------- #
@pytest.mark.p1
class TestTokenAccounting:
    def test_openai_uses_reported_usage(self):
        embed = _make_openai(total_tokens=42)
        _, tokens = embed.encode(["a", "b"])
        assert tokens == 42

    def test_localai_falls_back_to_local_count_not_1024(self):
        embed = _make_openai(cls=LocalAIEmbed)  # no usage in response
        texts = ["hello world", "second chunk of text"]
        _, tokens = embed.encode(texts)
        expected = sum(num_tokens_from_string(t) for t in texts)
        assert tokens == expected
        assert tokens != 1024  # the old fabricated constant

    def test_ollama_uses_prompt_eval_count_not_128(self):
        embed = OllamaEmbed("x", "nomic-embed-text", base_url="http://localhost:11434")
        embed.client = MagicMock()
        embed.client.embed = MagicMock(return_value={"embeddings": [[0.1, 0.2], [0.3, 0.4]], "prompt_eval_count": 33})
        _, tokens = embed.encode(["aaa", "bbb"])
        assert tokens == 33
        assert tokens != 128 * 2  # the old fabricated per-text constant

    def test_ollama_token_fallback_when_server_omits_count(self):
        embed = OllamaEmbed("x", "nomic-embed-text", base_url="http://localhost:11434")
        embed.client = MagicMock()
        # No prompt_eval_count reported -> honest local count, not a fixed number.
        embed.client.embed = MagicMock(return_value={"embeddings": [[0.1, 0.2]]})
        texts = ["some text to embed"]
        _, tokens = embed.encode(texts)
        assert tokens == sum(num_tokens_from_string(t) for t in texts)

# --------------------------------------------------------------------------- #
# 3. Truncation boundary (no 8196 overshoot)
# --------------------------------------------------------------------------- #
@pytest.mark.p2
class TestTruncationBoundary:
    def test_default_limit_is_8192(self):
        assert DEFAULT_MAX_TOKENS == 8192

    def test_openai_input_truncated_below_model_limit(self):
        embed = _make_openai(total_tokens=1)
        # An input far above the 8K ceiling.
        huge = "word " * 12000
        embed.encode([huge])
        sent = embed.client.embeddings.create.call_args.kwargs["input"][0]
        # Truncated to the documented 8191 ceiling, never above the 8192 model limit.
        assert num_tokens_from_string(sent) <= 8191
        assert num_tokens_from_string(sent) <= DEFAULT_MAX_TOKENS

    def test_mistral_truncates_to_8192_not_8196(self):
        embed = MistralEmbed.__new__(MistralEmbed)
        embed.model_name = "mistral-embed"
        captured = {}

        def _embeddings(input, model):
            captured["input"] = input
            return _OpenAIResp([[0.0, 0.0]], total_tokens=1)

        embed.client = MagicMock()
        embed.client.embeddings = MagicMock(side_effect=_embeddings)
        huge = "word " * 12000
        embed.encode([huge])
        assert num_tokens_from_string(captured["input"][0]) <= DEFAULT_MAX_TOKENS

# --------------------------------------------------------------------------- #
# 4. Batching for Zhipu and Ollama (ceil(n / batch_size) requests)
# --------------------------------------------------------------------------- #
@pytest.mark.p1
class TestBatching:
    def test_zhipu_batches_instead_of_per_text(self):
        embed = ZhipuEmbed("key", "embedding-3")
        embed.client = MagicMock()
        embed.client.embeddings.create = MagicMock(side_effect=_openai_create(total_tokens=5))
        texts = [f"t{i}" for i in range(3)]
        vectors, _ = embed.encode(texts)
        # One request for 3 texts (batch_size 16), NOT three per-text requests.
        assert embed.client.embeddings.create.call_count == 1
        assert vectors.shape[0] == 3

    def test_zhipu_issues_ceil_n_over_batch_calls(self):
        embed = ZhipuEmbed("key", "embedding-3")
        embed.client = MagicMock()
        embed.client.embeddings.create = MagicMock(side_effect=_openai_create(total_tokens=5))
        texts = [f"t{i}" for i in range(20)]  # batch_size 16 -> ceil(20/16) == 2
        vectors, _ = embed.encode(texts)
        assert embed.client.embeddings.create.call_count == 2
        assert vectors.shape[0] == 20

    def test_ollama_batches_and_preserves_order(self):
        embed = OllamaEmbed("x", "nomic-embed-text", base_url="http://localhost:11434")
        embed.client = MagicMock()

        def _embed(model, input, **kwargs):
            # Echo a recognisable vector per input so order can be checked.
            return {"embeddings": [[float(len(t))] for t in input], "prompt_eval_count": 1}

        embed.client.embed = MagicMock(side_effect=_embed)
        texts = ["a", "bb", "ccc"]
        vectors, _ = embed.encode(texts)
        # One batched request, not one per text.
        assert embed.client.embed.call_count == 1
        assert vectors.shape == (3, 1)
        # Order preserved: vector value equals input length.
        np.testing.assert_array_equal(vectors[:, 0], np.array([1.0, 2.0, 3.0]))

    def test_zhipu_realigns_out_of_order_response(self):
        """If the provider returns embeddings out of order, the per-item `index`
        must realign them with the input; otherwise chunks get wrong vectors."""
        embed = ZhipuEmbed("key", "embedding-3")
        embed.client = MagicMock()

        def _create(input, model, **kwargs):
            data = [SimpleNamespace(embedding=[float(i)], index=i) for i in range(len(input))]
            return SimpleNamespace(data=list(reversed(data)), usage=SimpleNamespace(total_tokens=1))

        embed.client.embeddings.create = MagicMock(side_effect=_create)
        vectors, _ = embed.encode(["t0", "t1", "t2"])
        np.testing.assert_array_equal(vectors[:, 0], np.array([0.0, 1.0, 2.0]))

    def test_nvidia_http_realigns_out_of_order_response(self):
        embed = NvidiaEmbed("key", "nvidia/nv-embed-v1")
        resp = MagicMock()
        resp.status_code = 200
        resp.json.return_value = {
            "data": [
                {"index": 2, "embedding": [2.0]},
                {"index": 0, "embedding": [0.0]},
                {"index": 1, "embedding": [1.0]},
            ],
            "usage": {"total_tokens": 3},
        }
        with patch("rag.llm.embedding_model.requests.post", return_value=resp):
            vectors, _ = embed.encode(["a", "b", "c"])
        np.testing.assert_array_equal(vectors[:, 0], np.array([0.0, 1.0, 2.0]))

    def test_ollama_issues_ceil_n_over_batch_calls(self):
        embed = OllamaEmbed("x", "nomic-embed-text", base_url="http://localhost:11434")
        embed.client = MagicMock()
        embed.client.embed = MagicMock(side_effect=lambda model, input, **kw: {"embeddings": [[0.0] for _ in input], "prompt_eval_count": 1})
        texts = [f"t{i}" for i in range(20)]  # batch_size 16 -> 2 calls
        vectors, _ = embed.encode(texts)
        assert embed.client.embed.call_count == 2
        assert vectors.shape[0] == 20

# --------------------------------------------------------------------------- #
# 5. Provider-specific request/response shapes
# --------------------------------------------------------------------------- #
@pytest.mark.p2
class TestNvidiaInputType:
    """NVIDIA NIM expects input_type=passage for documents and =query for queries;
    using "query" for documents degrades retrieval (asymmetric embeddings)."""

    def _mock_resp(self):
        resp = MagicMock()
        resp.status_code = 200
        resp.json.return_value = {"data": [{"index": 0, "embedding": [1.0]}], "usage": {"total_tokens": 1}}
        return resp

    def test_documents_use_passage(self):
        embed = NvidiaEmbed("key", "nvidia/nv-embed-v1")
        with patch("rag.llm.embedding_model.requests.post", return_value=self._mock_resp()) as post:
            embed.encode(["a document"])
        assert post.call_args.kwargs["json"]["input_type"] == "passage"

    def test_queries_use_query(self):
        embed = NvidiaEmbed("key", "nvidia/nv-embed-v1")
        with patch("rag.llm.embedding_model.requests.post", return_value=self._mock_resp()) as post:
            embed.encode_queries("a query")
        assert post.call_args.kwargs["json"]["input_type"] == "query"


@pytest.mark.p2
class TestBedrockResponseParsing:
    """Bedrock Titan returns {"embedding": [...]}; Cohere returns
    {"embeddings": [[...]]}. Both must parse without KeyError."""

    @staticmethod
    def _make(model_prefix):
        embed = BedrockEmbed.__new__(BedrockEmbed)
        embed.model_name = f"{model_prefix}.embed-model"
        embed.is_amazon = model_prefix == "amazon"
        embed.is_cohere = model_prefix == "cohere"
        embed.client = MagicMock()
        return embed

    @staticmethod
    def _body(payload):
        body = MagicMock()
        body.read.return_value = json.dumps(payload).encode()
        return {"body": body}

    def test_cohere_reads_embeddings_plural(self):
        embed = self._make("cohere")
        embed.client.invoke_model.return_value = self._body({"embeddings": [[1.0, 2.0]]})
        vectors, _ = embed.encode(["hello"])
        assert vectors.shape == (1, 2)
        np.testing.assert_array_equal(vectors[0], np.array([1.0, 2.0]))

    def test_amazon_reads_embedding_singular(self):
        embed = self._make("amazon")
        embed.client.invoke_model.return_value = self._body({"embedding": [3.0, 4.0]})
        vectors, _ = embed.encode(["hello"])
        np.testing.assert_array_equal(vectors[0], np.array([3.0, 4.0]))

    def test_cohere_query_reads_embeddings_plural(self):
        embed = self._make("cohere")
        embed.client.invoke_model.return_value = self._body({"embeddings": [[5.0, 6.0]]})
        vector, _ = embed.encode_queries("q")
        np.testing.assert_array_equal(vector, np.array([5.0, 6.0]))