diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 7a549b69d0..4289f507b5 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -494,12 +494,14 @@ async def async_chat(dialog, messages, stream=True, **kwargs): retriever = settings.retriever questions = [m["content"] for m in messages if m["role"] == "user"][-3:] - attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else [] + attachments = None + if "doc_ids" in kwargs: + attachments = [doc_id for doc_id in kwargs["doc_ids"].split(",") if doc_id] attachments_= "" image_attachments = [] image_files = [] if "doc_ids" in messages[-1]: - attachments = messages[-1]["doc_ids"] + attachments = [doc_id for doc_id in messages[-1]["doc_ids"] if doc_id] if "files" in messages[-1]: if llm_type == "chat": text_attachments, image_attachments = split_file_attachments(messages[-1]["files"]) @@ -559,7 +561,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} knowledges = [] - if attachments is not None and "knowledge" in param_keys: + if "knowledge" in param_keys: logging.debug("Proceeding with retrieval") tenant_ids = list(set([kb.tenant_id for kb in kbs])) knowledges = [] diff --git a/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py index a79d935817..71941e3874 100644 --- a/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py +++ b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py @@ -17,6 +17,7 @@ import asyncio import sys import types import warnings +from types import SimpleNamespace import pytest @@ -101,6 +102,19 @@ class _StubRetriever: return self._results[idx] +class _StubAsyncRetriever: + def __init__(self, result): + self.result = result + self.calls = [] + + async def retrieval(self, *args, **kwargs): + self.calls.append({"args": args, "kwargs": kwargs}) + return self.result + + def retrieval_by_children(self, chunks, tenant_ids): + return chunks + + @pytest.fixture def force_es_engine(monkeypatch): monkeypatch.setattr(dialog_service.settings, "DOC_ENGINE_INFINITY", False) @@ -219,3 +233,84 @@ def test_use_sql_source_repair_is_bounded_to_single_retry(monkeypatch, force_es_ assert "Source" not in result["answer"] assert len(chat_model.calls) == 2 assert len(retriever.sql_calls) == 2 + + +@pytest.mark.p2 +def test_async_chat_uses_all_docs_when_no_doc_ids_selected(monkeypatch): + retriever = _StubAsyncRetriever( + { + "total": 1, + "chunks": [ + { + "chunk_id": "chunk-1", + "content_ltks": "chunk text", + "content_with_weight": "Chunk text from dataset.", + "doc_id": "doc-1", + "docnm_kwd": "doc.txt", + "kb_id": "kb-1", + "important_kwd": [], + "positions": [], + "vector": [0.1, 0.2], + } + ], + "doc_aggs": [], + } + ) + chat_model = _StubChatModel(["stub answer"]) + dialog = SimpleNamespace( + kb_ids=["kb-1"], + llm_id="chat-model", + tenant_id="tenant-id", + llm_setting={}, + similarity_threshold=0.1, + vector_similarity_weight=0.2, + top_n=8, + top_k=32, + meta_data_filter=None, + prompt_config={ + "quote": False, + "keyword": False, + "tts": False, + "empty_response": "", + "system": "Use only this knowledge: {knowledge}", + "parameters": [{"key": "knowledge", "optional": False}], + "reasoning": False, + "toc_enhance": False, + "use_kg": False, + }, + ) + + monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False) + monkeypatch.setattr(dialog_service.TenantLLMService, "llm_id2llm_type", lambda _llm_id: "chat") + monkeypatch.setattr( + dialog_service.TenantLLMService, + "get_model_config", + lambda *_args, **_kwargs: {"llm_factory": "unit", "max_tokens": 4096}, + ) + monkeypatch.setattr(dialog_service.TenantLangfuseService, "filter_by_tenant", lambda **_kwargs: None) + monkeypatch.setattr( + dialog_service, + "get_models", + lambda _dialog: ([SimpleNamespace(tenant_id="tenant-id")], object(), None, chat_model, None), + ) + monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {}) + monkeypatch.setattr(dialog_service, "label_question", lambda _question, _kbs: None) + monkeypatch.setattr( + dialog_service, + "kb_prompt", + lambda kbinfos, _max_tokens: ["Chunk text from dataset."] if kbinfos["chunks"] else [], + ) + monkeypatch.setattr(dialog_service, "message_fit_in", lambda msg, _max_tokens: (0, msg)) + + async def _collect(): + items = [] + async for item in dialog_service.async_chat(dialog, [{"role": "user", "content": "What does the dataset say?"}], stream=False): + items.append(item) + return items + + result = asyncio.run(_collect()) + + assert len(retriever.calls) == 1 + assert retriever.calls[0]["kwargs"]["doc_ids"] is None + assert "Chunk text from dataset." in chat_model.calls[0]["system_prompt"] + assert result[0]["answer"] == "stub answer"