mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
### Summary

Closes #15423

`rag/llm/embedding_model.py` hosts about 40 embedding providers that shared several defects affecting indexing reliability, cost accounting, and error visibility. This PR fixes four concrete bugs.

**Masked, inconsistent errors (27 sites).** Nearly every provider ran `log_exception(_e, res)` followed by `raise Exception(f"Error: {res}")`. Because `log_exception` always raises, the second line was dead code, and the surfaced exception varied with whether the SDK response exposed a `.text` attribute. Every failure path now raises a single `EmbeddingError` that includes the underlying response detail, so the cause of a failed embedding is consistent and visible.

**Fabricated token counts.** `LocalAIEmbed` returned a hardcoded `1024` and `OllamaEmbed` added `128` per text. These values feed `used_tokens` and therefore billing and usage tracking. Both now report the real count from the API (Ollama `prompt_eval_count`, LocalAI `usage`) and fall back to a local token count only when the server omits it.

**Truncation overshoot.** The `8196` limit used by Mistral and Bedrock exceeded the standard `8192` ceiling and could push boundary-sized inputs past the model limit. Limits are corrected to `8192` and made intentional per provider, and providers that rely on server-side truncation now request it explicitly (Ollama `truncate=True`, Cohere `truncate="END"`).

**Missing batching on Zhipu and Ollama.** Both issued one request per text. They now batch like the other OpenAI-compatible providers, turning N round trips into `ceil(N / batch_size)`. Batched results are realigned by response `index` so a chunk always keeps its own vector.

A shared `Base._batched_encode` helper owns the batch loop, optional truncation, result accumulation, and the single error path. It is the mechanism that lets these fixes live in one place instead of across 27 duplicated sites.

The public `encode()` and `encode_queries()` contract stays the same, so existing callers are unaffected. Tests covering all four fixes are added under `test/unit_test/rag/llm/test_embedding_model.py`.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
88 lines
4.1 KiB
Python
88 lines
4.1 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# PEP 723 metadata
|
|
# /// script
|
|
# requires-python = ">=3.10"
|
|
# dependencies = [
|
|
# "nltk",
|
|
# "huggingface-hub"
|
|
# ]
|
|
# ///
|
|
|
|
import argparse
|
|
import os
|
|
import urllib.request
|
|
from typing import Union
|
|
|
|
import nltk
|
|
from huggingface_hub import snapshot_download
|
|
|
|
|
|
def get_urls(use_china_mirrors: bool = False) -> list[Union[str, list[str]]]:
    """Return the list of files the script downloads over plain HTTP(S).

    Each entry is either a URL string, or a two-item list
    ``[url, local_filename]`` for downloads that must be stored under an
    explicit name instead of the last path segment of the URL.

    Args:
        use_china_mirrors: When true, the Ubuntu, Maven and Chrome-for-Testing
            entries point at the Tsinghua, Huawei Cloud and npmmirror hosts
            instead of the upstream hosts. The tiktoken and uv entries are the
            same in both modes.

    Returns:
        A freshly built list, in a fixed order that does not depend on the
        selected mirror set.
    """
    # Only the host prefixes differ between the two mirror sets; the paths and
    # version numbers are written once below so they cannot drift apart.
    if use_china_mirrors:
        ubuntu_amd64 = "http://mirrors.tuna.tsinghua.edu.cn/ubuntu"
        ubuntu_arm64 = "http://mirrors.tuna.tsinghua.edu.cn/ubuntu-ports"
        maven = "https://repo.huaweicloud.com/repository/maven"
        chrome = "https://registry.npmmirror.com/-/binary/chrome-for-testing"
    else:
        ubuntu_amd64 = "http://archive.ubuntu.com/ubuntu"
        ubuntu_arm64 = "http://ports.ubuntu.com"
        maven = "https://repo1.maven.org/maven2"
        chrome = "https://storage.googleapis.com/chrome-for-testing-public"

    return [
        f"{ubuntu_amd64}/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb",
        f"{ubuntu_arm64}/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_arm64.deb",
        f"{maven}/org/apache/tika/tika-server-standard/3.3.0/tika-server-standard-3.3.0.jar",
        f"{maven}/org/apache/tika/tika-server-standard/3.3.0/tika-server-standard-3.3.0.jar.md5",
        "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
        [f"{chrome}/121.0.6167.85/linux64/chrome-linux64.zip", "chrome-linux64-121-0-6167-85"],
        [f"{chrome}/121.0.6167.85/linux64/chromedriver-linux64.zip", "chromedriver-linux64-121-0-6167-85"],
        "https://github.com/astral-sh/uv/releases/download/0.9.16/uv-x86_64-unknown-linux-gnu.tar.gz",
        "https://github.com/astral-sh/uv/releases/download/0.9.16/uv-aarch64-unknown-linux-gnu.tar.gz",
    ]
|
|
|
|
|
|
# Hugging Face repository ids that the script entry point passes to
# download_model(), in this order.
repos: list[str] = [
    "InfiniFlow/text_concat_xgb_v1.0",
    "InfiniFlow/deepdoc",
]
|
|
|
|
|
|
def download_model(repository_id: str) -> None:
    """Download a Hugging Face repository into ``huggingface.co/<repository_id>``.

    The target directory is resolved against the current working directory and
    created if it does not exist yet.

    Args:
        repository_id: Repository id passed to ``snapshot_download`` as
            ``repo_id``, e.g. ``"InfiniFlow/deepdoc"``.
    """
    target_dir = os.path.abspath(os.path.join("huggingface.co", repository_id))
    os.makedirs(target_dir, exist_ok=True)
    snapshot_download(repo_id=repository_id, local_dir=target_dir)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: fetch the plain HTTP(S) downloads, then the NLTK
    # corpora, then the Hugging Face repositories. Everything is written
    # relative to the current working directory.
    parser = argparse.ArgumentParser(description="Download dependencies with optional China mirror support")
    parser.add_argument("--china-mirrors", action="store_true", help="Use China-accessible mirrors for downloads")
    args = parser.parse_args()

    urls = get_urls(args.china_mirrors)

    # Some mirrors (e.g. archive.ubuntu.com) reject the default urllib
    # User-Agent with HTTP 403, so install an opener with a browser-like UA.
    opener = urllib.request.build_opener()
    opener.addheaders = [("User-Agent", "Mozilla/5.0")]
    urllib.request.install_opener(opener)

    for url in urls:
        # An entry is either a bare URL (saved under its last path segment)
        # or a [url, filename] pair.
        download_url = url[0] if isinstance(url, list) else url
        filename = url[1] if isinstance(url, list) else url.split("/")[-1]
        if os.path.exists(filename):
            print(f"Skipping {filename}: already downloaded")
            continue

        print(f"Downloading {filename} from {download_url}...")
        # Download to a side file and promote it only once complete. Writing
        # directly to `filename` would let an interrupted transfer leave a
        # truncated file that the existence check above skips on later runs.
        partial_path = f"{filename}.part"
        try:
            urllib.request.urlretrieve(download_url, partial_path)
            os.replace(partial_path, filename)
        finally:
            # After a successful os.replace the side file is gone; this only
            # removes leftovers from a failed or interrupted transfer.
            if os.path.exists(partial_path):
                os.remove(partial_path)

    local_dir = os.path.abspath("nltk_data")
    for data in ["wordnet", "punkt", "punkt_tab"]:
        print(f"Downloading nltk {data}...")
        nltk.download(data, download_dir=local_dir)

    for repo_id in repos:
        print(f"Downloading huggingface repo {repo_id}...")
        download_model(repo_id)
|