From 1a6df01b53c708b6c2423840d062d211cbc8e2b5 Mon Sep 17 00:00:00 2001 From: Wang Qi Date: Mon, 1 Jun 2026 19:18:16 +0800 Subject: [PATCH] Bug fix: Enhance embeding model to give better error message (#15346) To resolve https://github.com/infiniflow/ragflow/issues/15343 enhance the model embedding message to give extact failure message to customer. # QWen ## Retrieval image ## Chat image # SiliconFlow ## Retrieval image ## Chat image # Baichuan ## Retrieval image ## Chat image # Zhipu zhipu is good. --- .github/workflows/tests.yml | 12 +-- api/apps/__init__.py | 7 ++ api/apps/restful_apis/dataset_api.py | 16 +-- api/utils/api_utils.py | 3 + common/exceptions.py | 6 ++ docker/launch_admin_service.sh | 98 +++++++++++++++++++ rag/llm/embedding_model.py | 72 +++++++++----- .../rag/llm/test_perplexity_embed.py | 34 ++++--- 8 files changed, 191 insertions(+), 57 deletions(-) create mode 100755 docker/launch_admin_service.sh diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 408e737c19..6e2a9d8998 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -280,7 +280,7 @@ jobs: svc_ready=1 break fi - echo "Waiting for service to be available... ($i/120)" + echo "Waiting for service to be available... ($i/60)" sleep 5 done if [ "$svc_ready" -ne 1 ]; then @@ -300,7 +300,7 @@ jobs: svc_ready=1 break fi - echo "Waiting for service to be available... ($i/120)" + echo "Waiting for service to be available... ($i/60)" sleep 5 done if [ "$svc_ready" -ne 1 ]; then @@ -376,7 +376,7 @@ jobs: svc_ready=1 break fi - echo "Waiting for service to be available... ($i/120)" + echo "Waiting for service to be available... ($i/60)" sleep 5 done if [ "$svc_ready" -ne 1 ]; then @@ -496,7 +496,7 @@ jobs: svc_ready=1 break fi - echo "Waiting for service to be available... ($i/120)" + echo "Waiting for service to be available... ($i/60)" sleep 5 done if [ "$svc_ready" -ne 1 ]; then @@ -516,7 +516,7 @@ jobs: svc_ready=1 break fi - echo "Waiting for service to be available... ($i/120)" + echo "Waiting for service to be available... ($i/60)" sleep 5 done if [ "$svc_ready" -ne 1 ]; then @@ -592,7 +592,7 @@ jobs: svc_ready=1 break fi - echo "Waiting for service to be available... ($i/120)" + echo "Waiting for service to be available... ($i/60)" sleep 5 done if [ "$svc_ready" -ne 1 ]; then diff --git a/api/apps/__init__.py b/api/apps/__init__.py index 041d06ecc2..b8da01423c 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -34,6 +34,7 @@ from quart_schema import QuartSchema from common import settings from api.utils.api_utils import server_error_response, get_json_result from api.constants import API_VERSION +from common.exceptions import ModelException from common.misc_utils import get_uuid settings.init_settings() @@ -361,6 +362,12 @@ async def unauthorized_werkzeug(error): return get_json_result(code=error.code, message=error.description), RetCode.UNAUTHORIZED +@app.errorhandler(ModelException) +async def handle_model_exception(error): + logging.warning("Forbidden request") + return get_json_result(code=RetCode.BAD_REQUEST, message=repr(error)), 200 + + @app.teardown_request def _db_close(exception): if exception: diff --git a/api/apps/restful_apis/dataset_api.py b/api/apps/restful_apis/dataset_api.py index 70fbcc0777..480b949abf 100644 --- a/api/apps/restful_apis/dataset_api.py +++ b/api/apps/restful_apis/dataset_api.py @@ -497,17 +497,11 @@ async def search_datasets(tenant_id): req, err = await validate_and_parse_json_request(request, SearchDatasetsReq) if err is not None: return get_error_argument_result(err) - try: - success, result = await dataset_api_service.search_datasets(tenant_id, req) - if success: - return get_result(data=result) - else: - return get_error_data_result(message=result) - except Exception as e: - logging.exception(e) - if "not_found" in str(e): - return get_error_data_result(message="No chunk found! Check the chunk status please!") - return get_error_data_result(message="Internal server error") + success, result = await dataset_api_service.search_datasets(tenant_id, req) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) @manager.route("/datasets//search", methods=["POST"]) # noqa: F821 diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 4712d9504f..74d95a514f 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -147,6 +147,9 @@ def server_error_response(e): if repr(e).find("index_not_found_exception") >= 0: return get_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.") + if "not_found" in str(e): + return get_error_data_result(message="No chunk found! Check the chunk status please!") + return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) diff --git a/common/exceptions.py b/common/exceptions.py index 9511304720..bfbf245228 100644 --- a/common/exceptions.py +++ b/common/exceptions.py @@ -26,3 +26,9 @@ class ArgumentException(Exception): class NotFoundException(Exception): def __init__(self, msg): self.msg = msg + +class ModelException(Exception): + def __init__(self, msg, retryable=False): + super().__init__(msg) + self.msg = msg + self.retryable = retryable \ No newline at end of file diff --git a/docker/launch_admin_service.sh b/docker/launch_admin_service.sh new file mode 100755 index 0000000000..0afba88b2d --- /dev/null +++ b/docker/launch_admin_service.sh @@ -0,0 +1,98 @@ +#!/bin/bash + +# Exit immediately if a command exits with a non-zero status +set -e + +# Function to load environment variables from .env file +load_env_file() { + # Get the directory of the current script + local script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + local env_file="$script_dir/.env" + + # Check if .env file exists + if [ -f "$env_file" ]; then + echo "Loading environment variables from: $env_file" + # Source the .env file + set -a + source "$env_file" + set +a + else + echo "Warning: .env file not found at: $env_file" + fi +} + +# Load environment variables +load_env_file + +# Unset HTTP proxies that might be set by Docker daemon +export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" +export PYTHONPATH=$(pwd) + +export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/ +JEMALLOC_PATH=$(pkg-config --variable=libdir jemalloc)/libjemalloc.so + +PY=python3 + +# Set default number of workers if WS is not set or less than 1 +if [[ -z "$WS" || $WS -lt 1 ]]; then + WS=1 +fi + +# Maximum number of retries for each task executor and server +MAX_RETRIES=5 + +# Flag to control termination +STOP=false + +# Array to keep track of child PIDs +PIDS=() + +# Set the path to the NLTK data directory +export NLTK_DATA="./nltk_data" + +# Function to handle termination signals +cleanup() { + echo "Termination signal received. Shutting down..." + STOP=true + # Terminate all child processes + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Killing process $pid" + kill "$pid" + fi + done + exit 0 +} + +# Trap SIGINT and SIGTERM to invoke cleanup +trap cleanup SIGINT SIGTERM + +# Function to execute admin_server with retry logic +run_server(){ + local retry_count=0 + while ! $STOP && [ $retry_count -lt $MAX_RETRIES ]; do + echo "Starting admin_server.py (Attempt $((retry_count+1)))" + $PY admin/server/admin_server.py + EXIT_CODE=$? + if [ $EXIT_CODE -eq 0 ]; then + echo "admin_server.py exited successfully." + break + else + echo "admin_server.py failed with exit code $EXIT_CODE. Retrying..." >&2 + retry_count=$((retry_count + 1)) + sleep 2 + fi + done + + if [ $retry_count -ge $MAX_RETRIES ]; then + echo "admin_server.py failed after $MAX_RETRIES attempts. Exiting..." >&2 + cleanup + fi +} + +# Start the main server +run_server & +PIDS+=($!) + +# Wait for all background processes to finish +wait diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index ccaa833901..79fa69eef0 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -27,6 +27,7 @@ from ollama import Client from openai import OpenAI from zhipuai import ZhipuAI +from common.exceptions import ModelException from common.log_utils import log_exception from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response from common import settings @@ -36,6 +37,14 @@ import base64 logger = logging.getLogger(__name__) +def _raise_model_exception_if_failed(resp): + status_code = resp.status_code + if status_code >= 400: + if status_code < 500 and status_code not in [408, 429]: + raise ModelException(f"status: {resp.status_code}, response: {resp.text}", retryable=False) + raise ModelException(f"status: {resp.status_code}, response: {resp.text}", retryable=True) + + def _dashscope_base_url_for_log(base_url: str) -> str: """Log host/path only (no query string) so secrets in URLs are not printed.""" return base_url.split("?", 1)[0].strip()[:256] @@ -172,7 +181,10 @@ class OpenAIEmbed(Base): ress = [] total_tokens = 0 for i in range(0, len(texts), batch_size): - res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True}) + try: + res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True}) + except Exception as _e: + raise ModelException(f"Error: {_e}") try: ress.extend([d.embedding for d in res.data]) total_tokens += total_token_count_from_response(res) @@ -182,7 +194,10 @@ class OpenAIEmbed(Base): return np.array(ress), total_tokens def encode_queries(self, text): - res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float", extra_body={"drop_params": True}) + try: + res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float", extra_body={"drop_params": True}) + except Exception as _e: + raise ModelException(f"Error: {_e}") try: return np.array(res.data[0].embedding), total_token_count_from_response(res) except Exception as _e: @@ -294,20 +309,23 @@ class QWenEmbed(Base): token_count = 0 texts = [truncate(t, 2048) for t in texts] for i in range(0, len(texts), batch_size): - retry_max = 5 - with _dashscope_native_api_url_scope(self._dashscope_http_api_url): - resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document") - while (resp["output"] is None or resp["output"].get("embeddings") is None) and retry_max > 0: - time.sleep(10) + + retry_max, retry_wait_secs = 5, 10 + for retry in range(retry_max): with _dashscope_native_api_url_scope(self._dashscope_http_api_url): resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document") - retry_max -= 1 - if retry_max == 0 and (resp["output"] is None or resp["output"].get("embeddings") is None): - if resp.get("message"): - log_exception(ValueError(f"Retry_max reached, calling embedding model failed: {resp['message']}")) + status_code = resp.status_code + if status_code >= 400 and status_code < 500 and status_code not in [408, 429]: + raise ModelException(f"Error, status: {status_code}, response: {resp}") + # No need to retry for 4XX error + if status_code == 200: + break + if retry < retry_max - 1: + logging.warning(f"Got error response from DashScope API (status: {status_code}, response: {resp}). Wait {retry_wait_secs} seconds. Retrying...") + time.sleep(retry_wait_secs) else: - log_exception(ValueError("Retry_max reached, calling embedding model failed")) - raise + raise ModelException(f"Error after {retry_max} retries., status: {status_code}, response: {resp}") + try: embds = [[] for _ in range(len(resp["output"]["embeddings"]))] for e in resp["output"]["embeddings"]: @@ -316,17 +334,21 @@ class QWenEmbed(Base): token_count += total_token_count_from_response(resp) except Exception as _e: log_exception(_e, resp) - raise + raise ModelException(f"Error: {status_code}: {resp}") return np.array(res), token_count def encode_queries(self, text): with _dashscope_native_api_url_scope(self._dashscope_http_api_url): resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query") + status_code = resp.status_code + if status_code != 200: + raise ModelException(f"Error: status: {status_code}: code: {resp.get('code')}, message: {resp.get('message')}") + # No need to retry for 4XX error try: return np.array(resp["output"]["embeddings"][0]["embedding"]), total_token_count_from_response(resp) except Exception as _e: log_exception(_e, resp) - raise Exception(f"Error: {resp}") + raise ModelException(f"Error: {status_code}: {resp}") class ZhipuEmbed(Base): @@ -494,6 +516,7 @@ class JinaMultiVecEmbed(Base): data["truncate"] = True response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30) + _raise_model_exception_if_failed(response) try: res = response.json() for d in res["data"]: @@ -772,6 +795,7 @@ class NvidiaEmbed(Base): "truncate": "END", } response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30) + _raise_model_exception_if_failed(response) try: res = response.json() ress.extend([d["embedding"] for d in res["data"]]) @@ -912,6 +936,7 @@ class SILICONFLOWEmbed(Base): "encoding_format": "float", } response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30) + _raise_model_exception_if_failed(response) try: res = response.json() ress.extend([d["embedding"] for d in res["data"]]) @@ -929,6 +954,7 @@ class SILICONFLOWEmbed(Base): "encoding_format": "float", } response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30) + _raise_model_exception_if_failed(response) try: res = response.json() return np.array(res["data"][0]["embedding"]), total_token_count_from_response(res) @@ -1039,19 +1065,15 @@ class HuggingFaceEmbed(Base): def encode(self, texts: list): response = requests.post(f"{self.base_url}/embed", json={"inputs": texts}, headers={"Content-Type": "application/json"}, timeout=30) - if response.status_code == 200: - embeddings = response.json() - else: - raise Exception(f"Error: {response.status_code} - {response.text}") + _raise_model_exception_if_failed(response) + embeddings = response.json() return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts]) def encode_queries(self, text: str): response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"}, timeout=30) - if response.status_code == 200: - embedding = response.json()[0] - return np.array(embedding), num_tokens_from_string(text) - else: - raise Exception(f"Error: {response.status_code} - {response.text}") + _raise_model_exception_if_failed(response) + embedding = response.json()[0] + return np.array(embedding), num_tokens_from_string(text) class VolcEngineEmbed(Base): @@ -1248,6 +1270,7 @@ class PerplexityEmbed(Base): "encoding_format": "base64_int8", } response = requests.post(url, headers=self.headers, json=payload, timeout=30) + _raise_model_exception_if_failed(response) try: res = response.json() for doc in res["data"]: @@ -1267,6 +1290,7 @@ class PerplexityEmbed(Base): "encoding_format": "base64_int8", } response = requests.post(url, headers=self.headers, json=payload, timeout=30) + _raise_model_exception_if_failed(response) try: res = response.json() for d in res["data"]: diff --git a/test/unit_test/rag/llm/test_perplexity_embed.py b/test/unit_test/rag/llm/test_perplexity_embed.py index 9edef6736c..2e4a15cd99 100644 --- a/test/unit_test/rag/llm/test_perplexity_embed.py +++ b/test/unit_test/rag/llm/test_perplexity_embed.py @@ -57,6 +57,16 @@ def _mock_contextualized_response(docs_embeddings_b64, total_tokens=20): } +def _mock_http_response(json_response=None, status_code=200, text=""): + """Build a minimal requests.Response-like mock.""" + mock_resp = MagicMock() + mock_resp.status_code = status_code + mock_resp.text = text + if json_response is not None: + mock_resp.json.return_value = json_response + return mock_resp + + class TestPerplexityEmbedInit: def test_default_base_url(self): embed = PerplexityEmbed("test-key", "pplx-embed-v1-0.6b") @@ -125,8 +135,7 @@ class TestPerplexityEmbedStandardEncode: @patch("rag.llm.embedding_model.requests.post") def test_encode_single_text(self, mock_post): emb_b64 = _make_b64_int8([10, 20, 30]) - mock_resp = MagicMock() - mock_resp.json.return_value = _mock_standard_response([emb_b64], total_tokens=5) + mock_resp = _mock_http_response(_mock_standard_response([emb_b64], total_tokens=5)) mock_post.return_value = mock_resp embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b") @@ -144,8 +153,7 @@ class TestPerplexityEmbedStandardEncode: emb1 = _make_b64_int8([1, 2]) emb2 = _make_b64_int8([3, 4]) emb3 = _make_b64_int8([5, 6]) - mock_resp = MagicMock() - mock_resp.json.return_value = _mock_standard_response([emb1, emb2, emb3], total_tokens=15) + mock_resp = _mock_http_response(_mock_standard_response([emb1, emb2, emb3], total_tokens=15)) mock_post.return_value = mock_resp embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b") @@ -156,8 +164,7 @@ class TestPerplexityEmbedStandardEncode: @patch("rag.llm.embedding_model.requests.post") def test_encode_sends_correct_payload(self, mock_post): - mock_resp = MagicMock() - mock_resp.json.return_value = _mock_standard_response([_make_b64_int8([1])], total_tokens=1) + mock_resp = _mock_http_response(_mock_standard_response([_make_b64_int8([1])], total_tokens=1)) mock_post.return_value = mock_resp embed = PerplexityEmbed("key", "pplx-embed-v1-4b") @@ -171,9 +178,8 @@ class TestPerplexityEmbedStandardEncode: @patch("rag.llm.embedding_model.requests.post") def test_encode_api_error_raises(self, mock_post): - mock_resp = MagicMock() + mock_resp = _mock_http_response(text="Internal Server Error") mock_resp.json.side_effect = Exception("Invalid JSON") - mock_resp.text = "Internal Server Error" mock_post.return_value = mock_resp embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b") @@ -186,8 +192,7 @@ class TestPerplexityEmbedContextualizedEncode: def test_contextualized_encode(self, mock_post): emb1 = _make_b64_int8([10, 20]) emb2 = _make_b64_int8([30, 40]) - mock_resp = MagicMock() - mock_resp.json.return_value = _mock_contextualized_response([[emb1], [emb2]], total_tokens=12) + mock_resp = _mock_http_response(_mock_contextualized_response([[emb1], [emb2]], total_tokens=12)) mock_post.return_value = mock_resp embed = PerplexityEmbed("key", "pplx-embed-context-v1-0.6b") @@ -200,8 +205,7 @@ class TestPerplexityEmbedContextualizedEncode: @patch("rag.llm.embedding_model.requests.post") def test_contextualized_uses_correct_endpoint(self, mock_post): - mock_resp = MagicMock() - mock_resp.json.return_value = _mock_contextualized_response([[_make_b64_int8([1])]], total_tokens=1) + mock_resp = _mock_http_response(_mock_contextualized_response([[_make_b64_int8([1])]], total_tokens=1)) mock_post.return_value = mock_resp embed = PerplexityEmbed("key", "pplx-embed-context-v1-4b") @@ -212,8 +216,7 @@ class TestPerplexityEmbedContextualizedEncode: @patch("rag.llm.embedding_model.requests.post") def test_contextualized_sends_nested_input(self, mock_post): - mock_resp = MagicMock() - mock_resp.json.return_value = _mock_contextualized_response([[_make_b64_int8([1])]], total_tokens=1) + mock_resp = _mock_http_response(_mock_contextualized_response([[_make_b64_int8([1])]], total_tokens=1)) mock_post.return_value = mock_resp embed = PerplexityEmbed("key", "pplx-embed-context-v1-0.6b") @@ -228,8 +231,7 @@ class TestPerplexityEmbedEncodeQueries: @patch("rag.llm.embedding_model.requests.post") def test_encode_queries_returns_single_vector(self, mock_post): emb = _make_b64_int8([5, 10, 15, 20]) - mock_resp = MagicMock() - mock_resp.json.return_value = _mock_standard_response([emb], total_tokens=3) + mock_resp = _mock_http_response(_mock_standard_response([emb], total_tokens=3)) mock_post.return_value = mock_resp embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b")