Bug fix: Enhance embeding model to give better error message (#15346)

To resolve https://github.com/infiniflow/ragflow/issues/15343 enhance
the model embedding message to give extact failure message to customer.


# QWen

## Retrieval
<img width="3321" height="1033" alt="image"
src="https://github.com/user-attachments/assets/6b82921a-a3a7-4a33-a383-1cf316398ee2"
/>

## Chat
<img width="2241" height="311" alt="image"
src="https://github.com/user-attachments/assets/ec311365-62d5-407a-8915-5c8d72be9716"
/>


# SiliconFlow
## Retrieval
<img width="3321" height="1033" alt="image"
src="https://github.com/user-attachments/assets/ee2cd191-a27d-4729-b53d-2fbdb4e352cd"
/>

## Chat
<img width="1562" height="210" alt="image"
src="https://github.com/user-attachments/assets/10376a8e-a3f4-422f-bc2e-96f2a8a96448"
/>

# Baichuan
## Retrieval
<img width="3321" height="1107" alt="image"
src="https://github.com/user-attachments/assets/dcb5409d-f7fc-4804-b186-5e1ee11e09c4"
/>

## Chat
<img width="2241" height="311" alt="image"
src="https://github.com/user-attachments/assets/ec311365-62d5-407a-8915-5c8d72be9716"
/>


# Zhipu
zhipu is good.
This commit is contained in:
Wang Qi
2026-06-01 19:18:16 +08:00
committed by GitHub
parent 252cc19f93
commit 1a6df01b53
8 changed files with 191 additions and 57 deletions

View File

@@ -280,7 +280,7 @@ jobs:
svc_ready=1
break
fi
echo "Waiting for service to be available... ($i/120)"
echo "Waiting for service to be available... ($i/60)"
sleep 5
done
if [ "$svc_ready" -ne 1 ]; then
@@ -300,7 +300,7 @@ jobs:
svc_ready=1
break
fi
echo "Waiting for service to be available... ($i/120)"
echo "Waiting for service to be available... ($i/60)"
sleep 5
done
if [ "$svc_ready" -ne 1 ]; then
@@ -376,7 +376,7 @@ jobs:
svc_ready=1
break
fi
echo "Waiting for service to be available... ($i/120)"
echo "Waiting for service to be available... ($i/60)"
sleep 5
done
if [ "$svc_ready" -ne 1 ]; then
@@ -496,7 +496,7 @@ jobs:
svc_ready=1
break
fi
echo "Waiting for service to be available... ($i/120)"
echo "Waiting for service to be available... ($i/60)"
sleep 5
done
if [ "$svc_ready" -ne 1 ]; then
@@ -516,7 +516,7 @@ jobs:
svc_ready=1
break
fi
echo "Waiting for service to be available... ($i/120)"
echo "Waiting for service to be available... ($i/60)"
sleep 5
done
if [ "$svc_ready" -ne 1 ]; then
@@ -592,7 +592,7 @@ jobs:
svc_ready=1
break
fi
echo "Waiting for service to be available... ($i/120)"
echo "Waiting for service to be available... ($i/60)"
sleep 5
done
if [ "$svc_ready" -ne 1 ]; then

View File

@@ -34,6 +34,7 @@ from quart_schema import QuartSchema
from common import settings
from api.utils.api_utils import server_error_response, get_json_result
from api.constants import API_VERSION
from common.exceptions import ModelException
from common.misc_utils import get_uuid
settings.init_settings()
@@ -361,6 +362,12 @@ async def unauthorized_werkzeug(error):
return get_json_result(code=error.code, message=error.description), RetCode.UNAUTHORIZED
@app.errorhandler(ModelException)
async def handle_model_exception(error):
logging.warning("Forbidden request")
return get_json_result(code=RetCode.BAD_REQUEST, message=repr(error)), 200
@app.teardown_request
def _db_close(exception):
if exception:

View File

@@ -497,17 +497,11 @@ async def search_datasets(tenant_id):
req, err = await validate_and_parse_json_request(request, SearchDatasetsReq)
if err is not None:
return get_error_argument_result(err)
try:
success, result = await dataset_api_service.search_datasets(tenant_id, req)
if success:
return get_result(data=result)
else:
return get_error_data_result(message=result)
except Exception as e:
logging.exception(e)
if "not_found" in str(e):
return get_error_data_result(message="No chunk found! Check the chunk status please!")
return get_error_data_result(message="Internal server error")
success, result = await dataset_api_service.search_datasets(tenant_id, req)
if success:
return get_result(data=result)
else:
return get_error_data_result(message=result)
@manager.route("/datasets/<dataset_id>/search", methods=["POST"]) # noqa: F821

View File

@@ -147,6 +147,9 @@ def server_error_response(e):
if repr(e).find("index_not_found_exception") >= 0:
return get_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
if "not_found" in str(e):
return get_error_data_result(message="No chunk found! Check the chunk status please!")
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))

View File

@@ -26,3 +26,9 @@ class ArgumentException(Exception):
class NotFoundException(Exception):
def __init__(self, msg):
self.msg = msg
class ModelException(Exception):
def __init__(self, msg, retryable=False):
super().__init__(msg)
self.msg = msg
self.retryable = retryable

98
docker/launch_admin_service.sh Executable file
View File

@@ -0,0 +1,98 @@
#!/bin/bash
# Exit immediately if a command exits with a non-zero status
set -e
# Function to load environment variables from .env file
load_env_file() {
# Get the directory of the current script
local script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
local env_file="$script_dir/.env"
# Check if .env file exists
if [ -f "$env_file" ]; then
echo "Loading environment variables from: $env_file"
# Source the .env file
set -a
source "$env_file"
set +a
else
echo "Warning: .env file not found at: $env_file"
fi
}
# Load environment variables
load_env_file
# Unset HTTP proxies that might be set by Docker daemon
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
export PYTHONPATH=$(pwd)
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/
JEMALLOC_PATH=$(pkg-config --variable=libdir jemalloc)/libjemalloc.so
PY=python3
# Set default number of workers if WS is not set or less than 1
if [[ -z "$WS" || $WS -lt 1 ]]; then
WS=1
fi
# Maximum number of retries for each task executor and server
MAX_RETRIES=5
# Flag to control termination
STOP=false
# Array to keep track of child PIDs
PIDS=()
# Set the path to the NLTK data directory
export NLTK_DATA="./nltk_data"
# Function to handle termination signals
cleanup() {
echo "Termination signal received. Shutting down..."
STOP=true
# Terminate all child processes
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Killing process $pid"
kill "$pid"
fi
done
exit 0
}
# Trap SIGINT and SIGTERM to invoke cleanup
trap cleanup SIGINT SIGTERM
# Function to execute admin_server with retry logic
run_server(){
local retry_count=0
while ! $STOP && [ $retry_count -lt $MAX_RETRIES ]; do
echo "Starting admin_server.py (Attempt $((retry_count+1)))"
$PY admin/server/admin_server.py
EXIT_CODE=$?
if [ $EXIT_CODE -eq 0 ]; then
echo "admin_server.py exited successfully."
break
else
echo "admin_server.py failed with exit code $EXIT_CODE. Retrying..." >&2
retry_count=$((retry_count + 1))
sleep 2
fi
done
if [ $retry_count -ge $MAX_RETRIES ]; then
echo "admin_server.py failed after $MAX_RETRIES attempts. Exiting..." >&2
cleanup
fi
}
# Start the main server
run_server &
PIDS+=($!)
# Wait for all background processes to finish
wait

View File

@@ -27,6 +27,7 @@ from ollama import Client
from openai import OpenAI
from zhipuai import ZhipuAI
from common.exceptions import ModelException
from common.log_utils import log_exception
from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response
from common import settings
@@ -36,6 +37,14 @@ import base64
logger = logging.getLogger(__name__)
def _raise_model_exception_if_failed(resp):
status_code = resp.status_code
if status_code >= 400:
if status_code < 500 and status_code not in [408, 429]:
raise ModelException(f"status: {resp.status_code}, response: {resp.text}", retryable=False)
raise ModelException(f"status: {resp.status_code}, response: {resp.text}", retryable=True)
def _dashscope_base_url_for_log(base_url: str) -> str:
"""Log host/path only (no query string) so secrets in URLs are not printed."""
return base_url.split("?", 1)[0].strip()[:256]
@@ -172,7 +181,10 @@ class OpenAIEmbed(Base):
ress = []
total_tokens = 0
for i in range(0, len(texts), batch_size):
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
try:
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
except Exception as _e:
raise ModelException(f"Error: {_e}")
try:
ress.extend([d.embedding for d in res.data])
total_tokens += total_token_count_from_response(res)
@@ -182,7 +194,10 @@ class OpenAIEmbed(Base):
return np.array(ress), total_tokens
def encode_queries(self, text):
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
try:
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
except Exception as _e:
raise ModelException(f"Error: {_e}")
try:
return np.array(res.data[0].embedding), total_token_count_from_response(res)
except Exception as _e:
@@ -294,20 +309,23 @@ class QWenEmbed(Base):
token_count = 0
texts = [truncate(t, 2048) for t in texts]
for i in range(0, len(texts), batch_size):
retry_max = 5
with _dashscope_native_api_url_scope(self._dashscope_http_api_url):
resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
while (resp["output"] is None or resp["output"].get("embeddings") is None) and retry_max > 0:
time.sleep(10)
retry_max, retry_wait_secs = 5, 10
for retry in range(retry_max):
with _dashscope_native_api_url_scope(self._dashscope_http_api_url):
resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
retry_max -= 1
if retry_max == 0 and (resp["output"] is None or resp["output"].get("embeddings") is None):
if resp.get("message"):
log_exception(ValueError(f"Retry_max reached, calling embedding model failed: {resp['message']}"))
status_code = resp.status_code
if status_code >= 400 and status_code < 500 and status_code not in [408, 429]:
raise ModelException(f"Error, status: {status_code}, response: {resp}")
# No need to retry for 4XX error
if status_code == 200:
break
if retry < retry_max - 1:
logging.warning(f"Got error response from DashScope API (status: {status_code}, response: {resp}). Wait {retry_wait_secs} seconds. Retrying...")
time.sleep(retry_wait_secs)
else:
log_exception(ValueError("Retry_max reached, calling embedding model failed"))
raise
raise ModelException(f"Error after {retry_max} retries., status: {status_code}, response: {resp}")
try:
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
for e in resp["output"]["embeddings"]:
@@ -316,17 +334,21 @@ class QWenEmbed(Base):
token_count += total_token_count_from_response(resp)
except Exception as _e:
log_exception(_e, resp)
raise
raise ModelException(f"Error: {status_code}: {resp}")
return np.array(res), token_count
def encode_queries(self, text):
with _dashscope_native_api_url_scope(self._dashscope_http_api_url):
resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
status_code = resp.status_code
if status_code != 200:
raise ModelException(f"Error: status: {status_code}: code: {resp.get('code')}, message: {resp.get('message')}")
# No need to retry for 4XX error
try:
return np.array(resp["output"]["embeddings"][0]["embedding"]), total_token_count_from_response(resp)
except Exception as _e:
log_exception(_e, resp)
raise Exception(f"Error: {resp}")
raise ModelException(f"Error: {status_code}: {resp}")
class ZhipuEmbed(Base):
@@ -494,6 +516,7 @@ class JinaMultiVecEmbed(Base):
data["truncate"] = True
response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30)
_raise_model_exception_if_failed(response)
try:
res = response.json()
for d in res["data"]:
@@ -772,6 +795,7 @@ class NvidiaEmbed(Base):
"truncate": "END",
}
response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30)
_raise_model_exception_if_failed(response)
try:
res = response.json()
ress.extend([d["embedding"] for d in res["data"]])
@@ -912,6 +936,7 @@ class SILICONFLOWEmbed(Base):
"encoding_format": "float",
}
response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30)
_raise_model_exception_if_failed(response)
try:
res = response.json()
ress.extend([d["embedding"] for d in res["data"]])
@@ -929,6 +954,7 @@ class SILICONFLOWEmbed(Base):
"encoding_format": "float",
}
response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30)
_raise_model_exception_if_failed(response)
try:
res = response.json()
return np.array(res["data"][0]["embedding"]), total_token_count_from_response(res)
@@ -1039,19 +1065,15 @@ class HuggingFaceEmbed(Base):
def encode(self, texts: list):
response = requests.post(f"{self.base_url}/embed", json={"inputs": texts}, headers={"Content-Type": "application/json"}, timeout=30)
if response.status_code == 200:
embeddings = response.json()
else:
raise Exception(f"Error: {response.status_code} - {response.text}")
_raise_model_exception_if_failed(response)
embeddings = response.json()
return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])
def encode_queries(self, text: str):
response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"}, timeout=30)
if response.status_code == 200:
embedding = response.json()[0]
return np.array(embedding), num_tokens_from_string(text)
else:
raise Exception(f"Error: {response.status_code} - {response.text}")
_raise_model_exception_if_failed(response)
embedding = response.json()[0]
return np.array(embedding), num_tokens_from_string(text)
class VolcEngineEmbed(Base):
@@ -1248,6 +1270,7 @@ class PerplexityEmbed(Base):
"encoding_format": "base64_int8",
}
response = requests.post(url, headers=self.headers, json=payload, timeout=30)
_raise_model_exception_if_failed(response)
try:
res = response.json()
for doc in res["data"]:
@@ -1267,6 +1290,7 @@ class PerplexityEmbed(Base):
"encoding_format": "base64_int8",
}
response = requests.post(url, headers=self.headers, json=payload, timeout=30)
_raise_model_exception_if_failed(response)
try:
res = response.json()
for d in res["data"]:

View File

@@ -57,6 +57,16 @@ def _mock_contextualized_response(docs_embeddings_b64, total_tokens=20):
}
def _mock_http_response(json_response=None, status_code=200, text=""):
"""Build a minimal requests.Response-like mock."""
mock_resp = MagicMock()
mock_resp.status_code = status_code
mock_resp.text = text
if json_response is not None:
mock_resp.json.return_value = json_response
return mock_resp
class TestPerplexityEmbedInit:
def test_default_base_url(self):
embed = PerplexityEmbed("test-key", "pplx-embed-v1-0.6b")
@@ -125,8 +135,7 @@ class TestPerplexityEmbedStandardEncode:
@patch("rag.llm.embedding_model.requests.post")
def test_encode_single_text(self, mock_post):
emb_b64 = _make_b64_int8([10, 20, 30])
mock_resp = MagicMock()
mock_resp.json.return_value = _mock_standard_response([emb_b64], total_tokens=5)
mock_resp = _mock_http_response(_mock_standard_response([emb_b64], total_tokens=5))
mock_post.return_value = mock_resp
embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b")
@@ -144,8 +153,7 @@ class TestPerplexityEmbedStandardEncode:
emb1 = _make_b64_int8([1, 2])
emb2 = _make_b64_int8([3, 4])
emb3 = _make_b64_int8([5, 6])
mock_resp = MagicMock()
mock_resp.json.return_value = _mock_standard_response([emb1, emb2, emb3], total_tokens=15)
mock_resp = _mock_http_response(_mock_standard_response([emb1, emb2, emb3], total_tokens=15))
mock_post.return_value = mock_resp
embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b")
@@ -156,8 +164,7 @@ class TestPerplexityEmbedStandardEncode:
@patch("rag.llm.embedding_model.requests.post")
def test_encode_sends_correct_payload(self, mock_post):
mock_resp = MagicMock()
mock_resp.json.return_value = _mock_standard_response([_make_b64_int8([1])], total_tokens=1)
mock_resp = _mock_http_response(_mock_standard_response([_make_b64_int8([1])], total_tokens=1))
mock_post.return_value = mock_resp
embed = PerplexityEmbed("key", "pplx-embed-v1-4b")
@@ -171,9 +178,8 @@ class TestPerplexityEmbedStandardEncode:
@patch("rag.llm.embedding_model.requests.post")
def test_encode_api_error_raises(self, mock_post):
mock_resp = MagicMock()
mock_resp = _mock_http_response(text="Internal Server Error")
mock_resp.json.side_effect = Exception("Invalid JSON")
mock_resp.text = "Internal Server Error"
mock_post.return_value = mock_resp
embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b")
@@ -186,8 +192,7 @@ class TestPerplexityEmbedContextualizedEncode:
def test_contextualized_encode(self, mock_post):
emb1 = _make_b64_int8([10, 20])
emb2 = _make_b64_int8([30, 40])
mock_resp = MagicMock()
mock_resp.json.return_value = _mock_contextualized_response([[emb1], [emb2]], total_tokens=12)
mock_resp = _mock_http_response(_mock_contextualized_response([[emb1], [emb2]], total_tokens=12))
mock_post.return_value = mock_resp
embed = PerplexityEmbed("key", "pplx-embed-context-v1-0.6b")
@@ -200,8 +205,7 @@ class TestPerplexityEmbedContextualizedEncode:
@patch("rag.llm.embedding_model.requests.post")
def test_contextualized_uses_correct_endpoint(self, mock_post):
mock_resp = MagicMock()
mock_resp.json.return_value = _mock_contextualized_response([[_make_b64_int8([1])]], total_tokens=1)
mock_resp = _mock_http_response(_mock_contextualized_response([[_make_b64_int8([1])]], total_tokens=1))
mock_post.return_value = mock_resp
embed = PerplexityEmbed("key", "pplx-embed-context-v1-4b")
@@ -212,8 +216,7 @@ class TestPerplexityEmbedContextualizedEncode:
@patch("rag.llm.embedding_model.requests.post")
def test_contextualized_sends_nested_input(self, mock_post):
mock_resp = MagicMock()
mock_resp.json.return_value = _mock_contextualized_response([[_make_b64_int8([1])]], total_tokens=1)
mock_resp = _mock_http_response(_mock_contextualized_response([[_make_b64_int8([1])]], total_tokens=1))
mock_post.return_value = mock_resp
embed = PerplexityEmbed("key", "pplx-embed-context-v1-0.6b")
@@ -228,8 +231,7 @@ class TestPerplexityEmbedEncodeQueries:
@patch("rag.llm.embedding_model.requests.post")
def test_encode_queries_returns_single_vector(self, mock_post):
emb = _make_b64_int8([5, 10, 15, 20])
mock_resp = MagicMock()
mock_resp.json.return_value = _mock_standard_response([emb], total_tokens=3)
mock_resp = _mock_http_response(_mock_standard_response([emb], total_tokens=3))
mock_post.return_value = mock_resp
embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b")