From fabbfcab909ecc952ffc0eeca0c271404e25b9e3 Mon Sep 17 00:00:00 2001 From: 6ba3i <112825897+6ba3i@users.noreply.github.com> Date: Mon, 9 Feb 2026 14:56:10 +0800 Subject: [PATCH] Fix: failing p3 test for SDK/HTTP APIs (#13062) ### What problem does this PR solve? Adjust highlight parsing, add row-count SQL override, tweak retrieval thresholding, and update tests with engine-aware skips/utilities. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/apps/sdk/doc.py | 12 ++++- api/db/services/dialog_service.py | 24 +++++++++- rag/nlp/search.py | 4 +- .../test_retrieval_chunks.py | 5 +- .../test_list_chunks.py | 9 ++++ .../test_retrieval_chunks.py | 10 ++-- .../test_update_dataset.py | 7 +++ .../test_create_memory.py | 1 + .../test_list_message.py | 3 ++ test/testcases/utils/engine_utils.py | 47 +++++++++++++++++++ 10 files changed, 110 insertions(+), 12 deletions(-) create mode 100644 test/testcases/utils/engine_utils.py diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index 44cb077359..d8b81dce7e 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -1549,10 +1549,18 @@ async def retrieval_test(tenant_id): similarity_threshold = float(req.get("similarity_threshold", 0.2)) vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) top = int(req.get("top_k", 1024)) - if req.get("highlight") == "False" or req.get("highlight") == "false": + highlight_val = req.get("highlight", None) + if highlight_val is None: highlight = False + elif isinstance(highlight_val, bool): + highlight = highlight_val + elif isinstance(highlight_val, str): + if highlight_val.lower() in ["true", "false"]: + highlight = highlight_val.lower() == "true" + else: + return get_error_data_result("`highlight` should be a boolean") else: - highlight = True + return get_error_data_result("`highlight` should be a boolean") try: tenant_ids = list(set([kb.tenant_id for kb in kbs])) e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 3940a8a2fa..66025d13ef 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -606,10 +606,21 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N table_name = base_table logging.debug(f"use_sql: Using ES/OS table name: {table_name}") + def is_row_count_question(q: str) -> bool: + q = (q or "").lower() + if not re.search(r"\bhow many rows\b|\bnumber of rows\b|\brow count\b", q): + return False + return bool(re.search(r"\bdataset\b|\btable\b|\bspreadsheet\b|\bexcel\b", q)) + # Generate engine-specific SQL prompts if doc_engine == "infinity": # Build Infinity prompts with JSON extraction context json_field_names = list(field_map.keys()) + row_count_override = ( + f"SELECT COUNT(*) AS rows FROM {table_name}" + if is_row_count_question(question) + else None + ) sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column. JSON Extraction: json_extract_string(chunk_data, '$.FieldName') @@ -641,6 +652,11 @@ Write SQL using json_extract_string() with exact field names. Include doc_id, do elif doc_engine == "oceanbase": # Build OceanBase prompts with JSON extraction context json_field_names = list(field_map.keys()) + row_count_override = ( + f"SELECT COUNT(*) AS rows FROM {table_name}" + if is_row_count_question(question) + else None + ) sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column. JSON Extraction: json_extract_string(chunk_data, '$.FieldName') @@ -671,6 +687,7 @@ Write SQL using json_extract_string() with exact field names. Include doc_id, do ) else: # Build ES/OS prompts with direct field access + row_count_override = None sys_prompt = """You are a Database Administrator. Write SQL queries. RULES: @@ -693,8 +710,11 @@ Write SQL using exact field names above. Include doc_id, docnm_kwd for data quer tried_times = 0 async def get_table(): - nonlocal sys_prompt, user_prompt, question, tried_times - sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06}) + nonlocal sys_prompt, user_prompt, question, tried_times, row_count_override + if row_count_override: + sql = row_count_override + else: + sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06}) logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}") # Remove think blocks if present (format: ...) sql = re.sub(r"\n.*?\n\s*", "", sql, flags=re.DOTALL) diff --git a/rag/nlp/search.py b/rag/nlp/search.py index d6cd6de510..a36a8d967a 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -434,7 +434,9 @@ class Dealer: sorted_idx = np.argsort(sim_np * -1) - valid_idx = [int(i) for i in sorted_idx if sim_np[i] >= similarity_threshold] + # When vector_similarity_weight is 0, similarity_threshold is not meaningful for term-only scores. + post_threshold = 0.0 if vector_similarity_weight <= 0 else similarity_threshold + valid_idx = [int(i) for i in sorted_idx if sim_np[i] >= post_threshold] filtered_count = len(valid_idx) ranks["total"] = int(filtered_count) diff --git a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py index 4a05d29bac..2c94f2d30e 100644 --- a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py +++ b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py @@ -272,7 +272,7 @@ class TestChunksRetrieval: [ ({"highlight": True}, 0, True, ""), ({"highlight": "True"}, 0, True, ""), - pytest.param({"highlight": False}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")), + ({"highlight": False}, 0, False, ""), ({"highlight": "False"}, 0, False, ""), pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")), ], @@ -282,8 +282,7 @@ class TestChunksRetrieval: payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) res = retrieval_chunks(HttpApiAuth, payload) assert res["code"] == expected_code - doc_engine = os.environ.get("DOC_ENGINE", "elasticsearch").lower() - if expected_highlight and doc_engine != "infinity": + if expected_highlight: for chunk in res["data"]["chunks"]: assert "highlight" in chunk else: diff --git a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_list_chunks.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_list_chunks.py index e29378528f..4174d3fb14 100644 --- a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_list_chunks.py +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_list_chunks.py @@ -18,6 +18,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import pytest from common import batch_add_chunks +from utils.engine_utils import get_doc_engine class TestChunksList: @@ -84,6 +85,12 @@ class TestChunksList: ) def test_keywords(self, add_chunks, params, expected_page_size): _, document, _ = add_chunks + if params.get("keywords") == "ragflow": + doc_engine = get_doc_engine(document.rag) + if doc_engine == "infinity" and expected_page_size == 1: + pytest.skip("issues/6509") + if doc_engine != "infinity" and expected_page_size == 5: + pytest.skip("issues/6509") chunks = document.list_chunks(**params) assert len(chunks) == expected_page_size, str(chunks) @@ -99,6 +106,8 @@ class TestChunksList: ) def test_id(self, add_chunks, chunk_id, expected_page_size, expected_message): _, document, chunks = add_chunks + if callable(chunk_id) and get_doc_engine(document.rag) == "infinity": + pytest.skip("issues/6499") chunk_ids = [chunk.id for chunk in chunks] if callable(chunk_id): params = {"id": chunk_id(chunk_ids)} diff --git a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_retrieval_chunks.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_retrieval_chunks.py index 2834cfba91..9e62b30918 100644 --- a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_retrieval_chunks.py +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_retrieval_chunks.py @@ -18,6 +18,8 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import pytest +DOC_ENGINE = (os.getenv("DOC_ENGINE") or "").lower() + class TestChunksRetrieval: @pytest.mark.p1 @@ -159,25 +161,25 @@ class TestChunksRetrieval: {"top_k": 1}, 4, "", - marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"), + marks=pytest.mark.skipif(DOC_ENGINE in ["infinity", "opensearch"], reason="Infinity"), ), pytest.param( {"top_k": 1}, 1, "", - marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"), + marks=pytest.mark.skipif(DOC_ENGINE in ["", "opensearch", "elasticsearch"], reason="elasticsearch"), ), pytest.param( {"top_k": -1}, 4, "must be greater than 0", - marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"), + marks=pytest.mark.skipif(DOC_ENGINE in ["infinity", "opensearch"], reason="Infinity"), ), pytest.param( {"top_k": -1}, 4, "3014", - marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"), + marks=pytest.mark.skipif(DOC_ENGINE in ["", "opensearch", "elasticsearch"], reason="elasticsearch"), ), pytest.param( {"top_k": "a"}, diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py index e39b42374d..942e3b5fff 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py @@ -25,6 +25,7 @@ from utils import encode_avatar from utils.file_utils import create_image_file from utils.hypothesis_utils import valid_names from configs import DEFAULT_PARSER_CONFIG +from utils.engine_utils import get_doc_engine class TestRquest: @pytest.mark.p2 @@ -332,6 +333,8 @@ class TestDatasetUpdate: @pytest.mark.p2 @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"]) def test_pagerank(self, client, add_dataset_func, pagerank): + if get_doc_engine(client) == "infinity": + pytest.skip("#8208") dataset = add_dataset_func dataset.update({"pagerank": pagerank}) assert dataset.pagerank == pagerank, str(dataset) @@ -342,6 +345,8 @@ class TestDatasetUpdate: @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") @pytest.mark.p2 def test_pagerank_set_to_0(self, client, add_dataset_func): + if get_doc_engine(client) == "infinity": + pytest.skip("#8208") dataset = add_dataset_func dataset.update({"pagerank": 50}) assert dataset.pagerank == 50, str(dataset) @@ -358,6 +363,8 @@ class TestDatasetUpdate: @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208") @pytest.mark.p2 def test_pagerank_infinity(self, client, add_dataset_func): + if get_doc_engine(client) != "infinity": + pytest.skip("#8208") dataset = add_dataset_func with pytest.raises(Exception) as exception_info: dataset.update({"pagerank": 50}) diff --git a/test/testcases/test_sdk_api/test_memory_management/test_create_memory.py b/test/testcases/test_sdk_api/test_memory_management/test_create_memory.py index c1852b119d..2c9a3e7c7d 100644 --- a/test/testcases/test_sdk_api/test_memory_management/test_create_memory.py +++ b/test/testcases/test_sdk_api/test_memory_management/test_create_memory.py @@ -81,6 +81,7 @@ class TestMemoryCreate: @pytest.mark.p2 @given(name=valid_names()) + @settings(deadline=None) def test_type_invalid(self, client, name): payload = { "name": name, diff --git a/test/testcases/test_sdk_api/test_message_management/test_list_message.py b/test/testcases/test_sdk_api/test_message_management/test_list_message.py index d7cdb7ed3b..fc7578353d 100644 --- a/test/testcases/test_sdk_api/test_message_management/test_list_message.py +++ b/test/testcases/test_sdk_api/test_message_management/test_list_message.py @@ -19,6 +19,7 @@ import random import pytest from ragflow_sdk import RAGFlow, Memory from configs import INVALID_API_TOKEN, HOST_ADDRESS +from utils.engine_utils import get_doc_engine class TestAuthorization: @@ -88,6 +89,8 @@ class TestMessageList: @pytest.mark.p2 @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Not support.") def test_search_keyword(self, client): + if get_doc_engine(client) == "infinity": + pytest.skip("Not support.") memory_id = self.memory_id session_ids = self.session_ids session_id = random.choice(session_ids) diff --git a/test/testcases/utils/engine_utils.py b/test/testcases/utils/engine_utils.py new file mode 100644 index 0000000000..8a54bed212 --- /dev/null +++ b/test/testcases/utils/engine_utils.py @@ -0,0 +1,47 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import requests + +_DOC_ENGINE_CACHE = None + + +def get_doc_engine(rag=None) -> str: + """Return lower-cased doc_engine from env, or from /system/status if env is unset.""" + global _DOC_ENGINE_CACHE + env = (os.getenv("DOC_ENGINE") or "").strip().lower() + if env: + _DOC_ENGINE_CACHE = env + return env + if _DOC_ENGINE_CACHE: + return _DOC_ENGINE_CACHE + if rag is None: + return "" + try: + api_url = getattr(rag, "api_url", "") + if "/api/" in api_url: + base_url, version = api_url.rsplit("/api/", 1) + status_url = f"{base_url}/{version}/system/status" + else: + status_url = f"{api_url}/system/status" + headers = getattr(rag, "authorization_header", {}) + res = requests.get(status_url, headers=headers).json() + engine = str(res.get("data", {}).get("doc_engine", {}).get("type", "")).lower() + if engine: + _DOC_ENGINE_CACHE = engine + return engine + except Exception: + return ""