mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Bug fix: Enhance embeding model to give better error message (#15346)
To resolve https://github.com/infiniflow/ragflow/issues/15343 enhance the model embedding message to give extact failure message to customer. # QWen ## Retrieval <img width="3321" height="1033" alt="image" src="https://github.com/user-attachments/assets/6b82921a-a3a7-4a33-a383-1cf316398ee2" /> ## Chat <img width="2241" height="311" alt="image" src="https://github.com/user-attachments/assets/ec311365-62d5-407a-8915-5c8d72be9716" /> # SiliconFlow ## Retrieval <img width="3321" height="1033" alt="image" src="https://github.com/user-attachments/assets/ee2cd191-a27d-4729-b53d-2fbdb4e352cd" /> ## Chat <img width="1562" height="210" alt="image" src="https://github.com/user-attachments/assets/10376a8e-a3f4-422f-bc2e-96f2a8a96448" /> # Baichuan ## Retrieval <img width="3321" height="1107" alt="image" src="https://github.com/user-attachments/assets/dcb5409d-f7fc-4804-b186-5e1ee11e09c4" /> ## Chat <img width="2241" height="311" alt="image" src="https://github.com/user-attachments/assets/ec311365-62d5-407a-8915-5c8d72be9716" /> # Zhipu zhipu is good.
This commit is contained in:
12
.github/workflows/tests.yml
vendored
12
.github/workflows/tests.yml
vendored
@@ -280,7 +280,7 @@ jobs:
|
||||
svc_ready=1
|
||||
break
|
||||
fi
|
||||
echo "Waiting for service to be available... ($i/120)"
|
||||
echo "Waiting for service to be available... ($i/60)"
|
||||
sleep 5
|
||||
done
|
||||
if [ "$svc_ready" -ne 1 ]; then
|
||||
@@ -300,7 +300,7 @@ jobs:
|
||||
svc_ready=1
|
||||
break
|
||||
fi
|
||||
echo "Waiting for service to be available... ($i/120)"
|
||||
echo "Waiting for service to be available... ($i/60)"
|
||||
sleep 5
|
||||
done
|
||||
if [ "$svc_ready" -ne 1 ]; then
|
||||
@@ -376,7 +376,7 @@ jobs:
|
||||
svc_ready=1
|
||||
break
|
||||
fi
|
||||
echo "Waiting for service to be available... ($i/120)"
|
||||
echo "Waiting for service to be available... ($i/60)"
|
||||
sleep 5
|
||||
done
|
||||
if [ "$svc_ready" -ne 1 ]; then
|
||||
@@ -496,7 +496,7 @@ jobs:
|
||||
svc_ready=1
|
||||
break
|
||||
fi
|
||||
echo "Waiting for service to be available... ($i/120)"
|
||||
echo "Waiting for service to be available... ($i/60)"
|
||||
sleep 5
|
||||
done
|
||||
if [ "$svc_ready" -ne 1 ]; then
|
||||
@@ -516,7 +516,7 @@ jobs:
|
||||
svc_ready=1
|
||||
break
|
||||
fi
|
||||
echo "Waiting for service to be available... ($i/120)"
|
||||
echo "Waiting for service to be available... ($i/60)"
|
||||
sleep 5
|
||||
done
|
||||
if [ "$svc_ready" -ne 1 ]; then
|
||||
@@ -592,7 +592,7 @@ jobs:
|
||||
svc_ready=1
|
||||
break
|
||||
fi
|
||||
echo "Waiting for service to be available... ($i/120)"
|
||||
echo "Waiting for service to be available... ($i/60)"
|
||||
sleep 5
|
||||
done
|
||||
if [ "$svc_ready" -ne 1 ]; then
|
||||
|
||||
@@ -34,6 +34,7 @@ from quart_schema import QuartSchema
|
||||
from common import settings
|
||||
from api.utils.api_utils import server_error_response, get_json_result
|
||||
from api.constants import API_VERSION
|
||||
from common.exceptions import ModelException
|
||||
from common.misc_utils import get_uuid
|
||||
|
||||
settings.init_settings()
|
||||
@@ -361,6 +362,12 @@ async def unauthorized_werkzeug(error):
|
||||
return get_json_result(code=error.code, message=error.description), RetCode.UNAUTHORIZED
|
||||
|
||||
|
||||
@app.errorhandler(ModelException)
|
||||
async def handle_model_exception(error):
|
||||
logging.warning("Forbidden request")
|
||||
return get_json_result(code=RetCode.BAD_REQUEST, message=repr(error)), 200
|
||||
|
||||
|
||||
@app.teardown_request
|
||||
def _db_close(exception):
|
||||
if exception:
|
||||
|
||||
@@ -497,17 +497,11 @@ async def search_datasets(tenant_id):
|
||||
req, err = await validate_and_parse_json_request(request, SearchDatasetsReq)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
try:
|
||||
success, result = await dataset_api_service.search_datasets(tenant_id, req)
|
||||
if success:
|
||||
return get_result(data=result)
|
||||
else:
|
||||
return get_error_data_result(message=result)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
if "not_found" in str(e):
|
||||
return get_error_data_result(message="No chunk found! Check the chunk status please!")
|
||||
return get_error_data_result(message="Internal server error")
|
||||
success, result = await dataset_api_service.search_datasets(tenant_id, req)
|
||||
if success:
|
||||
return get_result(data=result)
|
||||
else:
|
||||
return get_error_data_result(message=result)
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/search", methods=["POST"]) # noqa: F821
|
||||
|
||||
@@ -147,6 +147,9 @@ def server_error_response(e):
|
||||
if repr(e).find("index_not_found_exception") >= 0:
|
||||
return get_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
|
||||
|
||||
if "not_found" in str(e):
|
||||
return get_error_data_result(message="No chunk found! Check the chunk status please!")
|
||||
|
||||
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
||||
|
||||
|
||||
|
||||
@@ -26,3 +26,9 @@ class ArgumentException(Exception):
|
||||
class NotFoundException(Exception):
|
||||
def __init__(self, msg):
|
||||
self.msg = msg
|
||||
|
||||
class ModelException(Exception):
|
||||
def __init__(self, msg, retryable=False):
|
||||
super().__init__(msg)
|
||||
self.msg = msg
|
||||
self.retryable = retryable
|
||||
98
docker/launch_admin_service.sh
Executable file
98
docker/launch_admin_service.sh
Executable file
@@ -0,0 +1,98 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Exit immediately if a command exits with a non-zero status
|
||||
set -e
|
||||
|
||||
# Function to load environment variables from .env file
|
||||
load_env_file() {
|
||||
# Get the directory of the current script
|
||||
local script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
local env_file="$script_dir/.env"
|
||||
|
||||
# Check if .env file exists
|
||||
if [ -f "$env_file" ]; then
|
||||
echo "Loading environment variables from: $env_file"
|
||||
# Source the .env file
|
||||
set -a
|
||||
source "$env_file"
|
||||
set +a
|
||||
else
|
||||
echo "Warning: .env file not found at: $env_file"
|
||||
fi
|
||||
}
|
||||
|
||||
# Load environment variables
|
||||
load_env_file
|
||||
|
||||
# Unset HTTP proxies that might be set by Docker daemon
|
||||
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
|
||||
export PYTHONPATH=$(pwd)
|
||||
|
||||
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/
|
||||
JEMALLOC_PATH=$(pkg-config --variable=libdir jemalloc)/libjemalloc.so
|
||||
|
||||
PY=python3
|
||||
|
||||
# Set default number of workers if WS is not set or less than 1
|
||||
if [[ -z "$WS" || $WS -lt 1 ]]; then
|
||||
WS=1
|
||||
fi
|
||||
|
||||
# Maximum number of retries for each task executor and server
|
||||
MAX_RETRIES=5
|
||||
|
||||
# Flag to control termination
|
||||
STOP=false
|
||||
|
||||
# Array to keep track of child PIDs
|
||||
PIDS=()
|
||||
|
||||
# Set the path to the NLTK data directory
|
||||
export NLTK_DATA="./nltk_data"
|
||||
|
||||
# Function to handle termination signals
|
||||
cleanup() {
|
||||
echo "Termination signal received. Shutting down..."
|
||||
STOP=true
|
||||
# Terminate all child processes
|
||||
for pid in "${PIDS[@]}"; do
|
||||
if kill -0 "$pid" 2>/dev/null; then
|
||||
echo "Killing process $pid"
|
||||
kill "$pid"
|
||||
fi
|
||||
done
|
||||
exit 0
|
||||
}
|
||||
|
||||
# Trap SIGINT and SIGTERM to invoke cleanup
|
||||
trap cleanup SIGINT SIGTERM
|
||||
|
||||
# Function to execute admin_server with retry logic
|
||||
run_server(){
|
||||
local retry_count=0
|
||||
while ! $STOP && [ $retry_count -lt $MAX_RETRIES ]; do
|
||||
echo "Starting admin_server.py (Attempt $((retry_count+1)))"
|
||||
$PY admin/server/admin_server.py
|
||||
EXIT_CODE=$?
|
||||
if [ $EXIT_CODE -eq 0 ]; then
|
||||
echo "admin_server.py exited successfully."
|
||||
break
|
||||
else
|
||||
echo "admin_server.py failed with exit code $EXIT_CODE. Retrying..." >&2
|
||||
retry_count=$((retry_count + 1))
|
||||
sleep 2
|
||||
fi
|
||||
done
|
||||
|
||||
if [ $retry_count -ge $MAX_RETRIES ]; then
|
||||
echo "admin_server.py failed after $MAX_RETRIES attempts. Exiting..." >&2
|
||||
cleanup
|
||||
fi
|
||||
}
|
||||
|
||||
# Start the main server
|
||||
run_server &
|
||||
PIDS+=($!)
|
||||
|
||||
# Wait for all background processes to finish
|
||||
wait
|
||||
@@ -27,6 +27,7 @@ from ollama import Client
|
||||
from openai import OpenAI
|
||||
from zhipuai import ZhipuAI
|
||||
|
||||
from common.exceptions import ModelException
|
||||
from common.log_utils import log_exception
|
||||
from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response
|
||||
from common import settings
|
||||
@@ -36,6 +37,14 @@ import base64
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _raise_model_exception_if_failed(resp):
|
||||
status_code = resp.status_code
|
||||
if status_code >= 400:
|
||||
if status_code < 500 and status_code not in [408, 429]:
|
||||
raise ModelException(f"status: {resp.status_code}, response: {resp.text}", retryable=False)
|
||||
raise ModelException(f"status: {resp.status_code}, response: {resp.text}", retryable=True)
|
||||
|
||||
|
||||
def _dashscope_base_url_for_log(base_url: str) -> str:
|
||||
"""Log host/path only (no query string) so secrets in URLs are not printed."""
|
||||
return base_url.split("?", 1)[0].strip()[:256]
|
||||
@@ -172,7 +181,10 @@ class OpenAIEmbed(Base):
|
||||
ress = []
|
||||
total_tokens = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
|
||||
try:
|
||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
|
||||
except Exception as _e:
|
||||
raise ModelException(f"Error: {_e}")
|
||||
try:
|
||||
ress.extend([d.embedding for d in res.data])
|
||||
total_tokens += total_token_count_from_response(res)
|
||||
@@ -182,7 +194,10 @@ class OpenAIEmbed(Base):
|
||||
return np.array(ress), total_tokens
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
|
||||
try:
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
|
||||
except Exception as _e:
|
||||
raise ModelException(f"Error: {_e}")
|
||||
try:
|
||||
return np.array(res.data[0].embedding), total_token_count_from_response(res)
|
||||
except Exception as _e:
|
||||
@@ -294,20 +309,23 @@ class QWenEmbed(Base):
|
||||
token_count = 0
|
||||
texts = [truncate(t, 2048) for t in texts]
|
||||
for i in range(0, len(texts), batch_size):
|
||||
retry_max = 5
|
||||
with _dashscope_native_api_url_scope(self._dashscope_http_api_url):
|
||||
resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
|
||||
while (resp["output"] is None or resp["output"].get("embeddings") is None) and retry_max > 0:
|
||||
time.sleep(10)
|
||||
|
||||
retry_max, retry_wait_secs = 5, 10
|
||||
for retry in range(retry_max):
|
||||
with _dashscope_native_api_url_scope(self._dashscope_http_api_url):
|
||||
resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
|
||||
retry_max -= 1
|
||||
if retry_max == 0 and (resp["output"] is None or resp["output"].get("embeddings") is None):
|
||||
if resp.get("message"):
|
||||
log_exception(ValueError(f"Retry_max reached, calling embedding model failed: {resp['message']}"))
|
||||
status_code = resp.status_code
|
||||
if status_code >= 400 and status_code < 500 and status_code not in [408, 429]:
|
||||
raise ModelException(f"Error, status: {status_code}, response: {resp}")
|
||||
# No need to retry for 4XX error
|
||||
if status_code == 200:
|
||||
break
|
||||
if retry < retry_max - 1:
|
||||
logging.warning(f"Got error response from DashScope API (status: {status_code}, response: {resp}). Wait {retry_wait_secs} seconds. Retrying...")
|
||||
time.sleep(retry_wait_secs)
|
||||
else:
|
||||
log_exception(ValueError("Retry_max reached, calling embedding model failed"))
|
||||
raise
|
||||
raise ModelException(f"Error after {retry_max} retries., status: {status_code}, response: {resp}")
|
||||
|
||||
try:
|
||||
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
|
||||
for e in resp["output"]["embeddings"]:
|
||||
@@ -316,17 +334,21 @@ class QWenEmbed(Base):
|
||||
token_count += total_token_count_from_response(resp)
|
||||
except Exception as _e:
|
||||
log_exception(_e, resp)
|
||||
raise
|
||||
raise ModelException(f"Error: {status_code}: {resp}")
|
||||
return np.array(res), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
with _dashscope_native_api_url_scope(self._dashscope_http_api_url):
|
||||
resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
|
||||
status_code = resp.status_code
|
||||
if status_code != 200:
|
||||
raise ModelException(f"Error: status: {status_code}: code: {resp.get('code')}, message: {resp.get('message')}")
|
||||
# No need to retry for 4XX error
|
||||
try:
|
||||
return np.array(resp["output"]["embeddings"][0]["embedding"]), total_token_count_from_response(resp)
|
||||
except Exception as _e:
|
||||
log_exception(_e, resp)
|
||||
raise Exception(f"Error: {resp}")
|
||||
raise ModelException(f"Error: {status_code}: {resp}")
|
||||
|
||||
|
||||
class ZhipuEmbed(Base):
|
||||
@@ -494,6 +516,7 @@ class JinaMultiVecEmbed(Base):
|
||||
data["truncate"] = True
|
||||
|
||||
response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30)
|
||||
_raise_model_exception_if_failed(response)
|
||||
try:
|
||||
res = response.json()
|
||||
for d in res["data"]:
|
||||
@@ -772,6 +795,7 @@ class NvidiaEmbed(Base):
|
||||
"truncate": "END",
|
||||
}
|
||||
response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30)
|
||||
_raise_model_exception_if_failed(response)
|
||||
try:
|
||||
res = response.json()
|
||||
ress.extend([d["embedding"] for d in res["data"]])
|
||||
@@ -912,6 +936,7 @@ class SILICONFLOWEmbed(Base):
|
||||
"encoding_format": "float",
|
||||
}
|
||||
response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30)
|
||||
_raise_model_exception_if_failed(response)
|
||||
try:
|
||||
res = response.json()
|
||||
ress.extend([d["embedding"] for d in res["data"]])
|
||||
@@ -929,6 +954,7 @@ class SILICONFLOWEmbed(Base):
|
||||
"encoding_format": "float",
|
||||
}
|
||||
response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30)
|
||||
_raise_model_exception_if_failed(response)
|
||||
try:
|
||||
res = response.json()
|
||||
return np.array(res["data"][0]["embedding"]), total_token_count_from_response(res)
|
||||
@@ -1039,19 +1065,15 @@ class HuggingFaceEmbed(Base):
|
||||
|
||||
def encode(self, texts: list):
|
||||
response = requests.post(f"{self.base_url}/embed", json={"inputs": texts}, headers={"Content-Type": "application/json"}, timeout=30)
|
||||
if response.status_code == 200:
|
||||
embeddings = response.json()
|
||||
else:
|
||||
raise Exception(f"Error: {response.status_code} - {response.text}")
|
||||
_raise_model_exception_if_failed(response)
|
||||
embeddings = response.json()
|
||||
return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])
|
||||
|
||||
def encode_queries(self, text: str):
|
||||
response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"}, timeout=30)
|
||||
if response.status_code == 200:
|
||||
embedding = response.json()[0]
|
||||
return np.array(embedding), num_tokens_from_string(text)
|
||||
else:
|
||||
raise Exception(f"Error: {response.status_code} - {response.text}")
|
||||
_raise_model_exception_if_failed(response)
|
||||
embedding = response.json()[0]
|
||||
return np.array(embedding), num_tokens_from_string(text)
|
||||
|
||||
|
||||
class VolcEngineEmbed(Base):
|
||||
@@ -1248,6 +1270,7 @@ class PerplexityEmbed(Base):
|
||||
"encoding_format": "base64_int8",
|
||||
}
|
||||
response = requests.post(url, headers=self.headers, json=payload, timeout=30)
|
||||
_raise_model_exception_if_failed(response)
|
||||
try:
|
||||
res = response.json()
|
||||
for doc in res["data"]:
|
||||
@@ -1267,6 +1290,7 @@ class PerplexityEmbed(Base):
|
||||
"encoding_format": "base64_int8",
|
||||
}
|
||||
response = requests.post(url, headers=self.headers, json=payload, timeout=30)
|
||||
_raise_model_exception_if_failed(response)
|
||||
try:
|
||||
res = response.json()
|
||||
for d in res["data"]:
|
||||
|
||||
@@ -57,6 +57,16 @@ def _mock_contextualized_response(docs_embeddings_b64, total_tokens=20):
|
||||
}
|
||||
|
||||
|
||||
def _mock_http_response(json_response=None, status_code=200, text=""):
|
||||
"""Build a minimal requests.Response-like mock."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = status_code
|
||||
mock_resp.text = text
|
||||
if json_response is not None:
|
||||
mock_resp.json.return_value = json_response
|
||||
return mock_resp
|
||||
|
||||
|
||||
class TestPerplexityEmbedInit:
|
||||
def test_default_base_url(self):
|
||||
embed = PerplexityEmbed("test-key", "pplx-embed-v1-0.6b")
|
||||
@@ -125,8 +135,7 @@ class TestPerplexityEmbedStandardEncode:
|
||||
@patch("rag.llm.embedding_model.requests.post")
|
||||
def test_encode_single_text(self, mock_post):
|
||||
emb_b64 = _make_b64_int8([10, 20, 30])
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = _mock_standard_response([emb_b64], total_tokens=5)
|
||||
mock_resp = _mock_http_response(_mock_standard_response([emb_b64], total_tokens=5))
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b")
|
||||
@@ -144,8 +153,7 @@ class TestPerplexityEmbedStandardEncode:
|
||||
emb1 = _make_b64_int8([1, 2])
|
||||
emb2 = _make_b64_int8([3, 4])
|
||||
emb3 = _make_b64_int8([5, 6])
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = _mock_standard_response([emb1, emb2, emb3], total_tokens=15)
|
||||
mock_resp = _mock_http_response(_mock_standard_response([emb1, emb2, emb3], total_tokens=15))
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b")
|
||||
@@ -156,8 +164,7 @@ class TestPerplexityEmbedStandardEncode:
|
||||
|
||||
@patch("rag.llm.embedding_model.requests.post")
|
||||
def test_encode_sends_correct_payload(self, mock_post):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = _mock_standard_response([_make_b64_int8([1])], total_tokens=1)
|
||||
mock_resp = _mock_http_response(_mock_standard_response([_make_b64_int8([1])], total_tokens=1))
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
embed = PerplexityEmbed("key", "pplx-embed-v1-4b")
|
||||
@@ -171,9 +178,8 @@ class TestPerplexityEmbedStandardEncode:
|
||||
|
||||
@patch("rag.llm.embedding_model.requests.post")
|
||||
def test_encode_api_error_raises(self, mock_post):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp = _mock_http_response(text="Internal Server Error")
|
||||
mock_resp.json.side_effect = Exception("Invalid JSON")
|
||||
mock_resp.text = "Internal Server Error"
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b")
|
||||
@@ -186,8 +192,7 @@ class TestPerplexityEmbedContextualizedEncode:
|
||||
def test_contextualized_encode(self, mock_post):
|
||||
emb1 = _make_b64_int8([10, 20])
|
||||
emb2 = _make_b64_int8([30, 40])
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = _mock_contextualized_response([[emb1], [emb2]], total_tokens=12)
|
||||
mock_resp = _mock_http_response(_mock_contextualized_response([[emb1], [emb2]], total_tokens=12))
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
embed = PerplexityEmbed("key", "pplx-embed-context-v1-0.6b")
|
||||
@@ -200,8 +205,7 @@ class TestPerplexityEmbedContextualizedEncode:
|
||||
|
||||
@patch("rag.llm.embedding_model.requests.post")
|
||||
def test_contextualized_uses_correct_endpoint(self, mock_post):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = _mock_contextualized_response([[_make_b64_int8([1])]], total_tokens=1)
|
||||
mock_resp = _mock_http_response(_mock_contextualized_response([[_make_b64_int8([1])]], total_tokens=1))
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
embed = PerplexityEmbed("key", "pplx-embed-context-v1-4b")
|
||||
@@ -212,8 +216,7 @@ class TestPerplexityEmbedContextualizedEncode:
|
||||
|
||||
@patch("rag.llm.embedding_model.requests.post")
|
||||
def test_contextualized_sends_nested_input(self, mock_post):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = _mock_contextualized_response([[_make_b64_int8([1])]], total_tokens=1)
|
||||
mock_resp = _mock_http_response(_mock_contextualized_response([[_make_b64_int8([1])]], total_tokens=1))
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
embed = PerplexityEmbed("key", "pplx-embed-context-v1-0.6b")
|
||||
@@ -228,8 +231,7 @@ class TestPerplexityEmbedEncodeQueries:
|
||||
@patch("rag.llm.embedding_model.requests.post")
|
||||
def test_encode_queries_returns_single_vector(self, mock_post):
|
||||
emb = _make_b64_int8([5, 10, 15, 20])
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = _mock_standard_response([emb], total_tokens=3)
|
||||
mock_resp = _mock_http_response(_mock_standard_response([emb], total_tokens=3))
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
embed = PerplexityEmbed("key", "pplx-embed-v1-0.6b")
|
||||
|
||||
Reference in New Issue
Block a user