From de06c9a60bd5718eaf367793ff53165542c6c0cb Mon Sep 17 00:00:00 2001 From: Jonathan Chang <55106972+jonathanchang31@users.noreply.github.com> Date: Fri, 12 Jun 2026 09:18:06 +0700 Subject: [PATCH] feat: Langfuse session grouping for multi-turn chat traces (#15679) ## Summary This PR passes `session_id` into Langfuse trace observations so multi-turn chat messages can be grouped under the same session in Langfuse. Changes include: - Propagate `session_id` from chat/session APIs into `dialog_service.async_chat`. - Pass `session_id` into Langfuse `start_observation(...)`. - Share Langfuse `trace_context` with chat, embedding, rerank, and TTS model bundles where applicable. - Add unit coverage to verify Langfuse observations receive `session_id`. - Update affected test stubs for the new optional Langfuse context arguments. ## Related Issue Closes: #15636 ## Change Type - [x] Feature - [x] Bug fix - [x] Test - [ ] Refactor - [ ] Documentation - [ ] Breaking change ## Real Behavior Proof Before this change: - Langfuse observations were created without `session_id`. - Multi-turn chat traces could not be grouped by session in Langfuse. After this change: - Chat/session flows pass `session_id` into `async_chat`. - Langfuse observations include `session_id`. - Related model bundles receive shared trace context and session metadata. Validation result: ```bash uv run python -m py_compile \ api/db/services/tenant_llm_service.py \ api/db/services/llm_service.py \ api/db/services/dialog_service.py \ api/db/services/conversation_service.py \ api/apps/restful_apis/chat_api.py \ test/unit_test/api/db/services/test_dialog_service_final_answer.py \ test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py ``` Passed. ```bash uv run pytest \ test/unit_test/api/db/services/test_dialog_service_final_answer.py \ test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py -q ``` Result: ```text 11 passed in 16.89s ``` ```bash git diff --check ``` Passed. ## Checklist - [x] Analyzed the issue requirement. - [x] Checked existing Langfuse trace integration. - [x] Implemented only the requested session grouping behavior. - [x] Added/updated unit tests. - [x] Ran focused tests successfully. - [x] Ran Python compile validation. - [x] Ran whitespace diff validation. --- api/apps/restful_apis/chat_api.py | 4 +- api/db/services/conversation_service.py | 8 +- api/db/services/dialog_service.py | 40 ++++--- api/db/services/llm_service.py | 29 +++-- api/db/services/tenant_llm_service.py | 7 +- .../test_dialog_service_final_answer.py | 105 +++++++++++++++++- ...t_dialog_service_use_sql_source_columns.py | 2 +- 7 files changed, 153 insertions(+), 42 deletions(-) diff --git a/api/apps/restful_apis/chat_api.py b/api/apps/restful_apis/chat_api.py index 52d36282ba..d6750f0907 100644 --- a/api/apps/restful_apis/chat_api.py +++ b/api/apps/restful_apis/chat_api.py @@ -1242,7 +1242,7 @@ async def session_completion(chat_id_in_arg=""): """Yield SSE-formatted chunks from the async chat generator.""" nonlocal dia, msg, req, conv try: - async for ans in async_chat(dia, msg, True, **req): + async for ans in async_chat(dia, msg, True, session_id=session_id, **req): ans = _format_answer(ans) payload = _sanitize_json_floats({"code": 0, "message": "", "data": ans}) yield "data:" + json.dumps(payload, ensure_ascii=False) + "\n\n" @@ -1262,7 +1262,7 @@ async def session_completion(chat_id_in_arg=""): return resp answer = None - async for ans in async_chat(dia, msg, False, **req): + async for ans in async_chat(dia, msg, False, session_id=session_id, **req): answer = _format_answer(ans) if conv is not None: await thread_pool_exec(ConversationService.update_by_id, conv.id, conv.to_dict()) diff --git a/api/db/services/conversation_service.py b/api/db/services/conversation_service.py index 2603676e98..e23d37aa55 100644 --- a/api/db/services/conversation_service.py +++ b/api/db/services/conversation_service.py @@ -187,7 +187,7 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses if stream: try: - async for ans in async_chat(dia, msg, True, **kwargs): + async for ans in async_chat(dia, msg, True, session_id=session_id, **kwargs): ans = structure_answer(conv, ans, message_id, session_id) yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n" ConversationService.update_by_id(conv.id, conv.to_dict()) @@ -199,7 +199,7 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses else: answer = None - async for ans in async_chat(dia, msg, False, **kwargs): + async for ans in async_chat(dia, msg, False, session_id=session_id, **kwargs): answer = structure_answer(conv, ans, message_id, session_id) ConversationService.update_by_id(conv.id, conv.to_dict()) break @@ -275,7 +275,7 @@ async def async_iframe_completion(dialog_id, question, session_id=None, stream=T if stream: try: - async for ans in async_chat(dia, msg, True, **kwargs): + async for ans in async_chat(dia, msg, True, session_id=session_id, **kwargs): ans = structure_answer(conv, ans, message_id, session_id) yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" @@ -288,7 +288,7 @@ async def async_iframe_completion(dialog_id, question, session_id=None, stream=T else: answer = None - async for ans in async_chat(dia, msg, False, **kwargs): + async for ans in async_chat(dia, msg, False, session_id=session_id, **kwargs): answer = structure_answer(conv, ans, message_id, session_id) API4ConversationService.append_message(conv.id, conv.to_dict()) break diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 72e8c1cf1d..f0f5a37f1f 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -286,7 +286,7 @@ class DialogService(CommonService): return list(objs) -async def async_chat_solo(dialog, messages, stream=True): +async def async_chat_solo(dialog, messages, stream=True, session_id=None): llm_types = get_model_type_by_name(dialog.tenant_id, dialog.llm_id) attachments = "" image_attachments = [] @@ -303,14 +303,14 @@ async def async_chat_solo(dialog, messages, stream=True): else: model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(dialog.tenant_id, model_config) + chat_mdl = LLMBundle(dialog.tenant_id, model_config, langfuse_session_id=session_id) factory = model_config.get("llm_factory", "") if model_config else "" prompt_config = dialog.prompt_config tts_mdl = None if prompt_config.get("tts"): default_tts_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS) - tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model) + tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model, trace_context=chat_mdl.trace_context, langfuse_session_id=session_id) msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"] if attachments and msg: msg[-1]["content"] += attachments @@ -337,7 +337,7 @@ async def async_chat_solo(dialog, messages, stream=True): yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()} -def get_models(dialog): +def get_models(dialog, trace_context=None, langfuse_session_id=None): embd_mdl, chat_mdl, rerank_mdl, tts_mdl = None, None, None, None kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids) embedding_list = list(set([kb.embd_id for kb in kbs])) @@ -347,7 +347,7 @@ def get_models(dialog): if embedding_list: embd_owner_tenant_id = kbs[0].tenant_id embd_model_config = get_model_config_from_provider_instance(embd_owner_tenant_id, LLMType.EMBEDDING, embedding_list[0]) - embd_mdl = LLMBundle(embd_owner_tenant_id, embd_model_config) + embd_mdl = LLMBundle(embd_owner_tenant_id, embd_model_config, trace_context=trace_context, langfuse_session_id=langfuse_session_id) if not embd_mdl: raise LookupError("Embedding model(%s) not found" % embedding_list[0]) @@ -356,15 +356,15 @@ def get_models(dialog): else: chat_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(dialog.tenant_id, chat_model_config) + chat_mdl = LLMBundle(dialog.tenant_id, chat_model_config, trace_context=trace_context, langfuse_session_id=langfuse_session_id) if dialog.rerank_id: rerank_model_config = get_model_config_from_provider_instance(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id) - rerank_mdl = LLMBundle(dialog.tenant_id, rerank_model_config) + rerank_mdl = LLMBundle(dialog.tenant_id, rerank_model_config, trace_context=trace_context, langfuse_session_id=langfuse_session_id) if dialog.prompt_config.get("tts"): default_tts_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS) - tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model_config) + tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model_config, trace_context=trace_context, langfuse_session_id=langfuse_session_id) return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl @@ -541,10 +541,11 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): async def async_chat(dialog, messages, stream=True, **kwargs): logging.debug("Begin async_chat") assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." + session_id = kwargs.get("session_id") use_web_search = _should_use_web_search(dialog.prompt_config, kwargs.get("internet")) logging.debug("web_search kb=%s tavily=%s internet=%r enabled=%s", bool(dialog.kb_ids), bool(dialog.prompt_config.get("tavily_api_key")), kwargs.get("internet"), use_web_search) if not dialog.kb_ids and not use_web_search: - async for ans in async_chat_solo(dialog, messages, stream): + async for ans in async_chat_solo(dialog, messages, stream, session_id=session_id): yield ans return @@ -579,7 +580,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): pass check_langfuse_tracer_ts = timer() - kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog) + kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog, trace_context=trace_context, langfuse_session_id=session_id) toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools") if toolcall_session and tools: chat_mdl.bind_tools(toolcall_session, tools) @@ -737,7 +738,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) if prompt_config.get("use_kg"): default_chat_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT) - ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, default_chat_model)) + ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, default_chat_model, trace_context=trace_context, langfuse_session_id=session_id)) if ck["content_with_weight"]: kbinfos["chunks"].insert(0, ck) @@ -866,13 +867,16 @@ async def async_chat(dialog, messages, stream=True, **kwargs): if langfuse_tracer: try: - langfuse_generation = langfuse_tracer.start_observation( - as_type="generation", - trace_context=trace_context, - name="chat", - model=llm_model_config["llm_name"], - input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}, - ) + observation_kwargs = { + "as_type": "generation", + "trace_context": trace_context, + "name": "chat", + "model": llm_model_config["llm_name"], + "input": {"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}, + } + if session_id: + observation_kwargs["session_id"] = session_id + langfuse_generation = langfuse_tracer.start_observation(**observation_kwargs) except Exception as e: # noqa: BLE001 - tracing must not break chat flow logger.warning("Langfuse start_observation failed; continuing without tracing: %s", e) langfuse_tracer = None diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index c808d4add1..aaba3f202e 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -86,6 +86,11 @@ class LLMBundle(LLM4Tenant): def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs): super().__init__(tenant_id, model_config, lang, **kwargs) + def _start_langfuse_observation(self, **kwargs): + if self.langfuse_session_id: + kwargs["session_id"] = self.langfuse_session_id + return self.langfuse.start_observation(**kwargs) + def close(self): """Release resources held by this LLMBundle instance.""" super().close() @@ -107,7 +112,7 @@ class LLMBundle(LLM4Tenant): def encode(self, texts: list): if self.langfuse: - generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="encode", model=self.model_config["llm_name"], input={"texts": texts}) + generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="encode", model=self.model_config["llm_name"], input={"texts": texts}) safe_texts = [] for idx, text in enumerate(texts): @@ -149,7 +154,7 @@ class LLMBundle(LLM4Tenant): def encode_queries(self, query: str): if self.langfuse: - generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="encode_queries", model=self.model_config["llm_name"], input={"query": query}) + generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="encode_queries", model=self.model_config["llm_name"], input={"query": query}) if query is None or not str(query).strip(): marker = "None" if query is None else "whitespace-only" @@ -173,7 +178,7 @@ class LLMBundle(LLM4Tenant): def similarity(self, query: str, texts: list): if self.langfuse: - generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts}) + generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts}) sim, used_tokens = self.mdl.similarity(query, texts) logging.info("LLMBundle.similarity used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"])) @@ -186,7 +191,7 @@ class LLMBundle(LLM4Tenant): def describe(self, image, max_tokens=300): if self.langfuse: - generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="describe", metadata={"model": self.model_config["llm_name"]}) + generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="describe", metadata={"model": self.model_config["llm_name"]}) txt, used_tokens = self.mdl.describe(image) logging.info("LLMBundle.describe used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"])) @@ -199,7 +204,7 @@ class LLMBundle(LLM4Tenant): def describe_with_prompt(self, image, prompt): if self.langfuse: - generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt}) + generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt}) txt, used_tokens = self.mdl.describe_with_prompt(image, prompt) logging.info("LLMBundle.describe_with_prompt used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"])) @@ -212,7 +217,7 @@ class LLMBundle(LLM4Tenant): def transcription(self, audio): if self.langfuse: - generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="transcription", metadata={"model": self.model_config["llm_name"]}) + generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="transcription", metadata={"model": self.model_config["llm_name"]}) txt, used_tokens = self.mdl.transcription(audio) logging.info("LLMBundle.transcription used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"])) @@ -228,7 +233,7 @@ class LLMBundle(LLM4Tenant): supports_stream = hasattr(mdl, "stream_transcription") and callable(getattr(mdl, "stream_transcription")) if supports_stream: if self.langfuse: - generation = self.langfuse.start_observation(as_type="generation", + generation = self._start_langfuse_observation(as_type="generation", trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.model_config["llm_name"]}, @@ -262,7 +267,7 @@ class LLMBundle(LLM4Tenant): return if self.langfuse: - generation = self.langfuse.start_observation(as_type="generation", + generation = self._start_langfuse_observation(as_type="generation", trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.model_config["llm_name"]}, @@ -286,7 +291,7 @@ class LLMBundle(LLM4Tenant): def tts(self, text: str) -> Generator[bytes, None, None]: if self.langfuse: - generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="tts", input={"text": text}) + generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="tts", input={"text": text}) for chunk in self.mdl.tts(text): if isinstance(chunk, int): @@ -408,7 +413,7 @@ class LLMBundle(LLM4Tenant): generation = None if self.langfuse: - generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history}) + generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history}) chat_partial = partial(base_fn, system, history, gen_conf) use_kwargs = self._clean_param(chat_partial, **kwargs) @@ -449,7 +454,7 @@ class LLMBundle(LLM4Tenant): generation = None if self.langfuse: - generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history}) + generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history}) if stream_fn: chat_partial = partial(stream_fn, system, history, gen_conf) @@ -492,7 +497,7 @@ class LLMBundle(LLM4Tenant): generation = None if self.langfuse: - generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history}) + generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history}) if stream_fn: chat_partial = partial(stream_fn, system, history, gen_conf) diff --git a/api/db/services/tenant_llm_service.py b/api/db/services/tenant_llm_service.py index c28dd72efb..5012e1d1df 100644 --- a/api/db/services/tenant_llm_service.py +++ b/api/db/services/tenant_llm_service.py @@ -505,6 +505,8 @@ class TenantLLMService(CommonService): class LLM4Tenant: def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs): + self.trace_context = kwargs.pop("trace_context", None) or {} + self.langfuse_session_id = kwargs.pop("langfuse_session_id", None) self.tenant_id = tenant_id self.llm_name = model_config["llm_name"] self.model_config = model_config @@ -522,8 +524,9 @@ class LLM4Tenant: try: if langfuse.auth_check(): self.langfuse = langfuse - trace_id = self.langfuse.create_trace_id() - self.trace_context = {"trace_id": trace_id} + if not self.trace_context: + trace_id = self.langfuse.create_trace_id() + self.trace_context = {"trace_id": trace_id} except Exception: # Skip langfuse tracing if connection fails pass diff --git a/test/unit_test/api/db/services/test_dialog_service_final_answer.py b/test/unit_test/api/db/services/test_dialog_service_final_answer.py index 7fefb323a9..40335e440c 100644 --- a/test/unit_test/api/db/services/test_dialog_service_final_answer.py +++ b/test/unit_test/api/db/services/test_dialog_service_final_answer.py @@ -406,7 +406,7 @@ def test_async_chat_final_event_carries_decorated_answer(monkeypatch): # get_models returns (kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl) monkeypatch.setattr( dialog_service, "get_models", - lambda _dialog: ([_KB], chat_mdl, None, chat_mdl, None), + lambda _dialog, **_kwargs: ([_KB], chat_mdl, None, chat_mdl, None), ) monkeypatch.setattr( dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {} @@ -473,7 +473,7 @@ def test_async_chat_langfuse_uses_start_observation(monkeypatch): monkeypatch.setattr( dialog_service, "get_models", - lambda _dialog: ([_KB], chat_mdl, None, chat_mdl, None), + lambda _dialog, **_kwargs: ([_KB], chat_mdl, None, chat_mdl, None), ) monkeypatch.setattr( dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {} @@ -511,6 +511,105 @@ def test_async_chat_langfuse_uses_start_observation(monkeypatch): assert langfuse.observation.ended is True +@pytest.mark.p2 +def test_async_chat_langfuse_observation_includes_session_id(monkeypatch): + _FakeLangfuseClient.instances = [] + monkeypatch.setattr(_FakeLangfuseClient, "fail_start_observation", False) + chat_mdl = _StreamingChatModel("Session traces should be grouped.") + retriever = _StubRetriever() + + monkeypatch.setattr( + dialog_service, "get_model_type_by_name", + lambda _tid, _llm_id: ["chat"] + ) + monkeypatch.setattr( + dialog_service, + "get_model_config_from_provider_instance", + lambda _tid, _type, _llm_id: _LLM_CONFIG, + ) + monkeypatch.setattr( + dialog_service.TenantLangfuseService, "filter_by_tenant", + lambda tenant_id: SimpleNamespace( + public_key="public", + secret_key="secret", + host="http://langfuse.local", + ), + ) + monkeypatch.setattr(dialog_service, "Langfuse", _FakeLangfuseClient) + monkeypatch.setattr( + dialog_service, + "get_models", + lambda _dialog, **_kwargs: ([_KB], chat_mdl, None, chat_mdl, None), + ) + monkeypatch.setattr( + dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {} + ) + monkeypatch.setattr( + dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB] + ) + monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False) + monkeypatch.setattr(dialog_service, "label_question", lambda _q, _kbs: "") + monkeypatch.setattr( + dialog_service, + "kb_prompt", + lambda _kbinfos, _max_tokens, **_kw: ["RAGFlow is a RAG engine."], + ) + + dialog = _make_dialog(chat_mdl) + messages = [{"role": "user", "content": "What is RAGFlow?"}] + + events = _collect(dialog_service.async_chat(dialog, messages, stream=True, quote=True, session_id="session-1")) + + assert any(e.get("final") is True for e in events) + langfuse = _FakeLangfuseClient.instances[0] + assert langfuse.observation_kwargs["trace_context"] == {"trace_id": "trace-id"} + assert langfuse.observation_kwargs["session_id"] == "session-1" + + +@pytest.mark.p2 +def test_get_models_passes_langfuse_trace_context_to_llm_bundles(monkeypatch): + captured = [] + + class _FakeBundle: + def __init__(self, tenant_id, model_config, **kwargs): + self.tenant_id = tenant_id + self.model_config = model_config + self.trace_context = kwargs.get("trace_context") + self.langfuse_session_id = kwargs.get("langfuse_session_id") + captured.append((tenant_id, model_config["model_type"], kwargs)) + + monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB]) + monkeypatch.setattr( + dialog_service, + "get_model_config_from_provider_instance", + lambda _tenant_id, model_type, _model_id: {**_LLM_CONFIG, "model_type": model_type}, + ) + monkeypatch.setattr( + dialog_service, + "get_tenant_default_model_by_type", + lambda _tenant_id, model_type: {**_LLM_CONFIG, "model_type": model_type}, + ) + monkeypatch.setattr(dialog_service, "LLMBundle", _FakeBundle) + + dialog = _make_dialog(None) + dialog.rerank_id = "rerank-1" + dialog.prompt_config["tts"] = True + trace_context = {"trace_id": "trace-id"} + + dialog_service.get_models(dialog, trace_context=trace_context, langfuse_session_id="session-1") + + assert len(captured) == 4 + assert {model_type for _, model_type, _ in captured} == { + dialog_service.LLMType.EMBEDDING, + dialog_service.LLMType.CHAT, + dialog_service.LLMType.RERANK, + dialog_service.LLMType.TTS, + } + for _, _, kwargs in captured: + assert kwargs["trace_context"] is trace_context + assert kwargs["langfuse_session_id"] == "session-1" + + @pytest.mark.p2 def test_async_chat_continues_when_langfuse_observation_start_fails(monkeypatch): """ @@ -543,7 +642,7 @@ def test_async_chat_continues_when_langfuse_observation_start_fails(monkeypatch) monkeypatch.setattr( dialog_service, "get_models", - lambda _dialog: ([_KB], chat_mdl, None, chat_mdl, None), + lambda _dialog, **_kwargs: ([_KB], chat_mdl, None, chat_mdl, None), ) monkeypatch.setattr( dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {} diff --git a/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py index 6109010b33..297cf3f4f1 100644 --- a/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py +++ b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py @@ -295,7 +295,7 @@ def test_async_chat_uses_all_docs_when_no_doc_ids_selected(monkeypatch): monkeypatch.setattr( dialog_service, "get_models", - lambda _dialog: ([SimpleNamespace(tenant_id="tenant-id")], object(), None, chat_model, None), + lambda _dialog, **_kwargs: ([SimpleNamespace(tenant_id="tenant-id")], object(), None, chat_model, None), ) monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {}) monkeypatch.setattr(dialog_service, "label_question", lambda _question, _kbs: None)