Mirror of https://github.com/infiniflow/ragflow.git, synced 2026-06-29 23:41:12 +08:00
### Summary

Closes #15423

`rag/llm/embedding_model.py` hosts about 40 embedding providers that shared several defects affecting indexing reliability, cost accounting, and error visibility. This PR fixes four concrete bugs.

**Masked, inconsistent errors (27 sites).** Nearly every provider ran `log_exception(_e, res)` followed by `raise Exception(f"Error: {res}")`. Because `log_exception` always raises, the second line was dead code, and the surfaced exception varied depending on whether the SDK response exposed a `.text` attribute. Every failure path now raises a single `EmbeddingError` that includes the underlying response detail, so the cause of a failed embedding is consistent and visible.

**Fabricated token counts.** `LocalAIEmbed` returned a hardcoded `1024`, and `OllamaEmbed` added `128` per text. These values feed `used_tokens` and therefore billing and usage tracking. Both now report the real count from the API (Ollama `prompt_eval_count`, LocalAI `usage`) and fall back to a local token count only when the server omits it.

**Truncation overshoot.** The `8196` limit used by Mistral and Bedrock exceeded the standard `8192` ceiling and could push boundary-sized inputs past the model limit. Limits are corrected to `8192` and set explicitly per provider, and providers that rely on server-side truncation now request it explicitly (Ollama `truncate=True`, Cohere `truncate="END"`).

**Missing batching on Zhipu and Ollama.** Both issued one request per text. They now batch like the other OpenAI-compatible providers, turning N round trips into `ceil(N / batch_size)`. Batched results are realigned by the response `index`, so each chunk always keeps its own vector.

A shared `Base._batched_encode` helper owns the batch loop, optional truncation, result accumulation, and the single error path. Centralizing this logic lets the fixes live in one place instead of across 27 duplicated sites.
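The batching, realignment, token-accounting, and error-path behavior described above can be sketched as one standalone function. This is a minimal sketch under stated assumptions, not the PR's `Base._batched_encode`: the free function `batched_encode`, the `embed_batch` callable, the OpenAI-style response shape (`data[].index`, `data[].embedding`, optional `usage.total_tokens`), and the whitespace-based token counting and truncation are all hypothetical stand-ins chosen for illustration.

```python
from __future__ import annotations

from typing import Callable, Optional, Sequence


class EmbeddingError(Exception):
    """Stand-in for the single embedding error type described in the PR."""


def _whitespace_tokens(text: str) -> int:
    # Hypothetical local token counter; a real provider would use its tokenizer.
    return len(text.split())


def batched_encode(
    texts: Sequence[str],
    embed_batch: Callable[[list], dict],
    batch_size: int = 16,
    max_tokens: Optional[int] = None,
    count_tokens: Callable[[str], int] = _whitespace_tokens,
) -> tuple[list, int]:
    """Embed texts in batches; return (vectors in input order, used token count).

    Assumes an OpenAI-style response (an assumption, not a RAGFlow API):
    {"data": [{"index": int, "embedding": [...]}, ...], "usage": {"total_tokens": int}}
    where "usage" may be missing.
    """
    vectors: list = [None] * len(texts)
    used_tokens = 0
    for start in range(0, len(texts), batch_size):
        batch = list(texts[start:start + batch_size])
        if max_tokens is not None:
            # Client-side truncation by whitespace tokens (hypothetical; real
            # providers truncate with their own tokenizer or server-side).
            batch = [" ".join(t.split()[:max_tokens]) for t in batch]
        try:
            res = embed_batch(batch)
            items = res["data"]
        except Exception as e:
            # Single error path: every failure surfaces with the underlying detail.
            raise EmbeddingError(f"batch at offset {start} failed: {e!r}") from e
        for item in items:
            idx = item["index"]
            if not 0 <= idx < len(batch):
                raise EmbeddingError(f"response index {idx} outside batch of {len(batch)}")
            # Realign by response index so each chunk keeps its own vector.
            vectors[start + idx] = item["embedding"]
        total = (res.get("usage") or {}).get("total_tokens")
        # Prefer the server's count; fall back to a local count only if omitted.
        used_tokens += total if total is not None else sum(count_tokens(t) for t in batch)
    if any(v is None for v in vectors):
        raise EmbeddingError("response did not return a vector for every input")
    return vectors, used_tokens
```

Because each vector is written to `start + index` rather than appended, a provider that returns a batch out of order still maps every vector back to its own input, and N texts cost `ceil(N / batch_size)` requests.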
The public `encode()` and `encode_queries()` contract stays the same, so existing callers are unaffected. Tests covering all four fixes are added under `test/unit_test/rag/llm/test_embedding_model.py`.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
51 lines
2.1 KiB
Python
```python
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Shared setup for RAGFlow unit tests.

Several parsers and the chunking pipeline tokenize text with NLTK, which needs
the ``punkt_tab`` and ``wordnet`` data sets. Production provisions these via
``download_deps.py`` (into ``nltk_data``, exported as ``NLTK_DATA`` by
``docker/launch_backend_service.sh``) and ``api.validation`` at startup, but the
unit-test runner has neither. Without the data, tokenizer-backed tests such as
``test_epub_parser`` and ``test_dataflow_service`` fail with
``LookupError: Resource 'punkt_tab' not found``. Make sure the data is reachable
before any test imports a tokenizer: reuse a provisioned ``nltk_data`` directory
when present, and download only what is still missing.
"""

import os

import nltk

# Reuse data already fetched by download_deps.py (the directory the app exports
# as NLTK_DATA) so provisioned environments do not download it again.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
_LOCAL_NLTK_DATA = os.path.join(_REPO_ROOT, "nltk_data")
if os.path.isdir(_LOCAL_NLTK_DATA) and _LOCAL_NLTK_DATA not in nltk.data.path:
    nltk.data.path.insert(0, _LOCAL_NLTK_DATA)

# (download name, resource path used by nltk.data.find)
_REQUIRED_NLTK_DATA = (
    ("punkt_tab", "tokenizers/punkt_tab"),
    ("wordnet", "corpora/wordnet"),
)
for _name, _find_path in _REQUIRED_NLTK_DATA:
    try:
        nltk.data.find(_find_path)
    except LookupError:
        nltk.download(_name, quiet=True)
```
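The two module-level steps in this conftest (prefer a provisioned `nltk_data` directory, then download only what is still missing) can be factored into small functions that take the search path and the lookup and download steps as arguments, so the behavior can be exercised without NLTK or network access. This is a hypothetical refactor for illustration, not code from the repository; `prepend_data_dir`, `ensure_resources`, and their parameters are invented names.

```python
import os
from typing import Callable, Iterable, List, Tuple


def prepend_data_dir(search_path: List[str], directory: str) -> bool:
    """Put a provisioned data directory first on a search path, at most once."""
    if os.path.isdir(directory) and directory not in search_path:
        search_path.insert(0, directory)
        return True
    return False


def ensure_resources(
    required: Iterable[Tuple[str, str]],
    find: Callable[[str], object],
    download: Callable[[str], object],
) -> List[str]:
    """Download each resource whose lookup raises LookupError; return those names."""
    downloaded: List[str] = []
    for name, find_path in required:
        try:
            # With NLTK this is nltk.data.find, which raises LookupError if absent.
            find(find_path)
        except LookupError:
            download(name)
            downloaded.append(name)
    return downloaded
```

Wired to NLTK, the calls would be `prepend_data_dir(nltk.data.path, _LOCAL_NLTK_DATA)` and `ensure_resources(_REQUIRED_NLTK_DATA, nltk.data.find, lambda n: nltk.download(n, quiet=True))`, which preserves the conftest's behavior while letting tests pass stub callables instead.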