feat: Langfuse session grouping for multi-turn chat traces (#15679)

## Summary

This PR passes `session_id` into Langfuse trace observations so
multi-turn chat messages can be grouped under the same session in
Langfuse.

Changes include:
- Propagate `session_id` from chat/session APIs into
`dialog_service.async_chat`.
- Pass `session_id` into Langfuse `start_observation(...)`.
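- Share Langfuse `trace_context` with chat, embedding, rerank, and TTS
model bundles where applicable: `LLM4Tenant` now keeps a caller-supplied
`trace_context` and only creates a new trace id when none is given, so the
bundles built for one chat turn report to the same trace.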
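- Add unit coverage verifying that Langfuse observations receive
`session_id` and that `get_models` forwards the trace context and session
id to every model bundle it builds.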
- Update affected test stubs for the new optional Langfuse context
arguments.

## Related Issue
Closes: #15636 

## Change Type
- [x] Feature
- [x] Bug fix
- [x] Test
- [ ] Refactor
- [ ] Documentation
- [ ] Breaking change

## Real Behavior Proof

Before this change:

- Langfuse observations were created without `session_id`.
- Multi-turn chat traces could not be grouped by session in Langfuse.

After this change:

- Chat/session flows pass `session_id` into `async_chat`.
- Langfuse observations include `session_id`.
- Related model bundles receive shared trace context and session
metadata.

Validation result:

```bash
uv run python -m py_compile \
  api/db/services/tenant_llm_service.py \
  api/db/services/llm_service.py \
  api/db/services/dialog_service.py \
  api/db/services/conversation_service.py \
  api/apps/restful_apis/chat_api.py \
  test/unit_test/api/db/services/test_dialog_service_final_answer.py \
  test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py
```
Passed.

```bash
uv run pytest \
  test/unit_test/api/db/services/test_dialog_service_final_answer.py \
  test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py -q
```
Result:

```text
11 passed in 16.89s
```

```bash
git diff --check
```
Passed.

## Checklist

- [x] Analyzed the issue requirement.
- [x] Checked existing Langfuse trace integration.
- [x] Implemented only the requested session grouping behavior.
- [x] Added/updated unit tests.
- [x] Ran focused tests successfully.
- [x] Ran Python compile validation.
- [x] Ran whitespace diff validation.
This commit is contained in:
Jonathan Chang
2026-06-12 09:18:06 +07:00
committed by GitHub
parent 0d836afd34
commit de06c9a60b
7 changed files with 153 additions and 42 deletions

View File

@@ -1242,7 +1242,7 @@ async def session_completion(chat_id_in_arg=""):
"""Yield SSE-formatted chunks from the async chat generator."""
nonlocal dia, msg, req, conv
try:
async for ans in async_chat(dia, msg, True, **req):
async for ans in async_chat(dia, msg, True, session_id=session_id, **req):
ans = _format_answer(ans)
payload = _sanitize_json_floats({"code": 0, "message": "", "data": ans})
yield "data:" + json.dumps(payload, ensure_ascii=False) + "\n\n"
@@ -1262,7 +1262,7 @@ async def session_completion(chat_id_in_arg=""):
return resp
answer = None
async for ans in async_chat(dia, msg, False, **req):
async for ans in async_chat(dia, msg, False, session_id=session_id, **req):
answer = _format_answer(ans)
if conv is not None:
await thread_pool_exec(ConversationService.update_by_id, conv.id, conv.to_dict())

View File

@@ -187,7 +187,7 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses
if stream:
try:
async for ans in async_chat(dia, msg, True, **kwargs):
async for ans in async_chat(dia, msg, True, session_id=session_id, **kwargs):
ans = structure_answer(conv, ans, message_id, session_id)
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict())
@@ -199,7 +199,7 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses
else:
answer = None
async for ans in async_chat(dia, msg, False, **kwargs):
async for ans in async_chat(dia, msg, False, session_id=session_id, **kwargs):
answer = structure_answer(conv, ans, message_id, session_id)
ConversationService.update_by_id(conv.id, conv.to_dict())
break
@@ -275,7 +275,7 @@ async def async_iframe_completion(dialog_id, question, session_id=None, stream=T
if stream:
try:
async for ans in async_chat(dia, msg, True, **kwargs):
async for ans in async_chat(dia, msg, True, session_id=session_id, **kwargs):
ans = structure_answer(conv, ans, message_id, session_id)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
ensure_ascii=False) + "\n\n"
@@ -288,7 +288,7 @@ async def async_iframe_completion(dialog_id, question, session_id=None, stream=T
else:
answer = None
async for ans in async_chat(dia, msg, False, **kwargs):
async for ans in async_chat(dia, msg, False, session_id=session_id, **kwargs):
answer = structure_answer(conv, ans, message_id, session_id)
API4ConversationService.append_message(conv.id, conv.to_dict())
break

View File

@@ -286,7 +286,7 @@ class DialogService(CommonService):
return list(objs)
async def async_chat_solo(dialog, messages, stream=True):
async def async_chat_solo(dialog, messages, stream=True, session_id=None):
llm_types = get_model_type_by_name(dialog.tenant_id, dialog.llm_id)
attachments = ""
image_attachments = []
@@ -303,14 +303,14 @@ async def async_chat_solo(dialog, messages, stream=True):
else:
model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(dialog.tenant_id, model_config)
chat_mdl = LLMBundle(dialog.tenant_id, model_config, langfuse_session_id=session_id)
factory = model_config.get("llm_factory", "") if model_config else ""
prompt_config = dialog.prompt_config
tts_mdl = None
if prompt_config.get("tts"):
default_tts_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS)
tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model)
tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model, trace_context=chat_mdl.trace_context, langfuse_session_id=session_id)
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
if attachments and msg:
msg[-1]["content"] += attachments
@@ -337,7 +337,7 @@ async def async_chat_solo(dialog, messages, stream=True):
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
def get_models(dialog):
def get_models(dialog, trace_context=None, langfuse_session_id=None):
embd_mdl, chat_mdl, rerank_mdl, tts_mdl = None, None, None, None
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
embedding_list = list(set([kb.embd_id for kb in kbs]))
@@ -347,7 +347,7 @@ def get_models(dialog):
if embedding_list:
embd_owner_tenant_id = kbs[0].tenant_id
embd_model_config = get_model_config_from_provider_instance(embd_owner_tenant_id, LLMType.EMBEDDING, embedding_list[0])
embd_mdl = LLMBundle(embd_owner_tenant_id, embd_model_config)
embd_mdl = LLMBundle(embd_owner_tenant_id, embd_model_config, trace_context=trace_context, langfuse_session_id=langfuse_session_id)
if not embd_mdl:
raise LookupError("Embedding model(%s) not found" % embedding_list[0])
@@ -356,15 +356,15 @@ def get_models(dialog):
else:
chat_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(dialog.tenant_id, chat_model_config)
chat_mdl = LLMBundle(dialog.tenant_id, chat_model_config, trace_context=trace_context, langfuse_session_id=langfuse_session_id)
if dialog.rerank_id:
rerank_model_config = get_model_config_from_provider_instance(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
rerank_mdl = LLMBundle(dialog.tenant_id, rerank_model_config)
rerank_mdl = LLMBundle(dialog.tenant_id, rerank_model_config, trace_context=trace_context, langfuse_session_id=langfuse_session_id)
if dialog.prompt_config.get("tts"):
default_tts_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS)
tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model_config)
tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model_config, trace_context=trace_context, langfuse_session_id=langfuse_session_id)
return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl
@@ -541,10 +541,11 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
async def async_chat(dialog, messages, stream=True, **kwargs):
logging.debug("Begin async_chat")
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
session_id = kwargs.get("session_id")
use_web_search = _should_use_web_search(dialog.prompt_config, kwargs.get("internet"))
logging.debug("web_search kb=%s tavily=%s internet=%r enabled=%s", bool(dialog.kb_ids), bool(dialog.prompt_config.get("tavily_api_key")), kwargs.get("internet"), use_web_search)
if not dialog.kb_ids and not use_web_search:
async for ans in async_chat_solo(dialog, messages, stream):
async for ans in async_chat_solo(dialog, messages, stream, session_id=session_id):
yield ans
return
@@ -579,7 +580,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
pass
check_langfuse_tracer_ts = timer()
kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog, trace_context=trace_context, langfuse_session_id=session_id)
toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
if toolcall_session and tools:
chat_mdl.bind_tools(toolcall_session, tools)
@@ -737,7 +738,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
if prompt_config.get("use_kg"):
default_chat_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, default_chat_model))
ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, default_chat_model, trace_context=trace_context, langfuse_session_id=session_id))
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
@@ -866,13 +867,16 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
if langfuse_tracer:
try:
langfuse_generation = langfuse_tracer.start_observation(
as_type="generation",
trace_context=trace_context,
name="chat",
model=llm_model_config["llm_name"],
input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg},
)
observation_kwargs = {
"as_type": "generation",
"trace_context": trace_context,
"name": "chat",
"model": llm_model_config["llm_name"],
"input": {"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg},
}
if session_id:
observation_kwargs["session_id"] = session_id
langfuse_generation = langfuse_tracer.start_observation(**observation_kwargs)
except Exception as e: # noqa: BLE001 - tracing must not break chat flow
logger.warning("Langfuse start_observation failed; continuing without tracing: %s", e)
langfuse_tracer = None

View File

@@ -86,6 +86,11 @@ class LLMBundle(LLM4Tenant):
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
super().__init__(tenant_id, model_config, lang, **kwargs)
def _start_langfuse_observation(self, **kwargs):
if self.langfuse_session_id:
kwargs["session_id"] = self.langfuse_session_id
return self.langfuse.start_observation(**kwargs)
def close(self):
"""Release resources held by this LLMBundle instance."""
super().close()
@@ -107,7 +112,7 @@ class LLMBundle(LLM4Tenant):
def encode(self, texts: list):
if self.langfuse:
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="encode", model=self.model_config["llm_name"], input={"texts": texts})
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="encode", model=self.model_config["llm_name"], input={"texts": texts})
safe_texts = []
for idx, text in enumerate(texts):
@@ -149,7 +154,7 @@ class LLMBundle(LLM4Tenant):
def encode_queries(self, query: str):
if self.langfuse:
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="encode_queries", model=self.model_config["llm_name"], input={"query": query})
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="encode_queries", model=self.model_config["llm_name"], input={"query": query})
if query is None or not str(query).strip():
marker = "None" if query is None else "whitespace-only"
@@ -173,7 +178,7 @@ class LLMBundle(LLM4Tenant):
def similarity(self, query: str, texts: list):
if self.langfuse:
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts})
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts})
sim, used_tokens = self.mdl.similarity(query, texts)
logging.info("LLMBundle.similarity used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"]))
@@ -186,7 +191,7 @@ class LLMBundle(LLM4Tenant):
def describe(self, image, max_tokens=300):
if self.langfuse:
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="describe", metadata={"model": self.model_config["llm_name"]})
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="describe", metadata={"model": self.model_config["llm_name"]})
txt, used_tokens = self.mdl.describe(image)
logging.info("LLMBundle.describe used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"]))
@@ -199,7 +204,7 @@ class LLMBundle(LLM4Tenant):
def describe_with_prompt(self, image, prompt):
if self.langfuse:
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt})
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt})
txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
logging.info("LLMBundle.describe_with_prompt used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"]))
@@ -212,7 +217,7 @@ class LLMBundle(LLM4Tenant):
def transcription(self, audio):
if self.langfuse:
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="transcription", metadata={"model": self.model_config["llm_name"]})
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="transcription", metadata={"model": self.model_config["llm_name"]})
txt, used_tokens = self.mdl.transcription(audio)
logging.info("LLMBundle.transcription used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"]))
@@ -228,7 +233,7 @@ class LLMBundle(LLM4Tenant):
supports_stream = hasattr(mdl, "stream_transcription") and callable(getattr(mdl, "stream_transcription"))
if supports_stream:
if self.langfuse:
generation = self.langfuse.start_observation(as_type="generation",
generation = self._start_langfuse_observation(as_type="generation",
trace_context=self.trace_context,
name="stream_transcription",
metadata={"model": self.model_config["llm_name"]},
@@ -262,7 +267,7 @@ class LLMBundle(LLM4Tenant):
return
if self.langfuse:
generation = self.langfuse.start_observation(as_type="generation",
generation = self._start_langfuse_observation(as_type="generation",
trace_context=self.trace_context,
name="stream_transcription",
metadata={"model": self.model_config["llm_name"]},
@@ -286,7 +291,7 @@ class LLMBundle(LLM4Tenant):
def tts(self, text: str) -> Generator[bytes, None, None]:
if self.langfuse:
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="tts", input={"text": text})
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="tts", input={"text": text})
for chunk in self.mdl.tts(text):
if isinstance(chunk, int):
@@ -408,7 +413,7 @@ class LLMBundle(LLM4Tenant):
generation = None
if self.langfuse:
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history})
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history})
chat_partial = partial(base_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
@@ -449,7 +454,7 @@ class LLMBundle(LLM4Tenant):
generation = None
if self.langfuse:
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
@@ -492,7 +497,7 @@ class LLMBundle(LLM4Tenant):
generation = None
if self.langfuse:
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)

View File

@@ -505,6 +505,8 @@ class TenantLLMService(CommonService):
class LLM4Tenant:
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
self.trace_context = kwargs.pop("trace_context", None) or {}
self.langfuse_session_id = kwargs.pop("langfuse_session_id", None)
self.tenant_id = tenant_id
self.llm_name = model_config["llm_name"]
self.model_config = model_config
@@ -522,8 +524,9 @@ class LLM4Tenant:
try:
if langfuse.auth_check():
self.langfuse = langfuse
trace_id = self.langfuse.create_trace_id()
self.trace_context = {"trace_id": trace_id}
if not self.trace_context:
trace_id = self.langfuse.create_trace_id()
self.trace_context = {"trace_id": trace_id}
except Exception:
# Skip langfuse tracing if connection fails
pass

View File

@@ -406,7 +406,7 @@ def test_async_chat_final_event_carries_decorated_answer(monkeypatch):
# get_models returns (kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl)
monkeypatch.setattr(
dialog_service, "get_models",
lambda _dialog: ([_KB], chat_mdl, None, chat_mdl, None),
lambda _dialog, **_kwargs: ([_KB], chat_mdl, None, chat_mdl, None),
)
monkeypatch.setattr(
dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {}
@@ -473,7 +473,7 @@ def test_async_chat_langfuse_uses_start_observation(monkeypatch):
monkeypatch.setattr(
dialog_service,
"get_models",
lambda _dialog: ([_KB], chat_mdl, None, chat_mdl, None),
lambda _dialog, **_kwargs: ([_KB], chat_mdl, None, chat_mdl, None),
)
monkeypatch.setattr(
dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {}
@@ -511,6 +511,105 @@ def test_async_chat_langfuse_uses_start_observation(monkeypatch):
assert langfuse.observation.ended is True
@pytest.mark.p2
def test_async_chat_langfuse_observation_includes_session_id(monkeypatch):
    """`async_chat(..., session_id=...)` must hand that id to `start_observation`.

    Runs a streamed chat against stubbed models, retriever and Langfuse
    client, then inspects the keyword arguments recorded on the first fake
    Langfuse client instance.
    """
    # Start from a clean fake-client registry with observation failures off.
    _FakeLangfuseClient.instances = []
    monkeypatch.setattr(_FakeLangfuseClient, "fail_start_observation", False)

    streaming_model = _StreamingChatModel("Session traces should be grouped.")
    stub_retriever = _StubRetriever()

    # Model lookup stubs.
    monkeypatch.setattr(
        dialog_service,
        "get_model_type_by_name",
        lambda _tid, _llm_id: ["chat"],
    )
    monkeypatch.setattr(
        dialog_service,
        "get_model_config_from_provider_instance",
        lambda _tid, _type, _llm_id: _LLM_CONFIG,
    )

    # Tenant has Langfuse credentials, so async_chat builds a tracer.
    monkeypatch.setattr(
        dialog_service.TenantLangfuseService,
        "filter_by_tenant",
        lambda tenant_id: SimpleNamespace(
            public_key="public",
            secret_key="secret",
            host="http://langfuse.local",
        ),
    )
    monkeypatch.setattr(dialog_service, "Langfuse", _FakeLangfuseClient)

    # get_models returns (kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl).
    monkeypatch.setattr(
        dialog_service,
        "get_models",
        lambda _dialog, **_kwargs: ([_KB], streaming_model, None, streaming_model, None),
    )

    # Knowledge-base and retrieval stubs.
    monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {})
    monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB])
    monkeypatch.setattr(dialog_service.settings, "retriever", stub_retriever, raising=False)
    monkeypatch.setattr(dialog_service, "label_question", lambda _q, _kbs: "")
    monkeypatch.setattr(
        dialog_service,
        "kb_prompt",
        lambda _kbinfos, _max_tokens, **_kw: ["RAGFlow is a RAG engine."],
    )

    dialog = _make_dialog(streaming_model)
    conversation = [{"role": "user", "content": "What is RAGFlow?"}]

    events = _collect(
        dialog_service.async_chat(
            dialog,
            conversation,
            stream=True,
            quote=True,
            session_id="session-1",
        )
    )

    # The stream must have completed with a final event.
    assert [event for event in events if event.get("final") is True]

    recorded = _FakeLangfuseClient.instances[0].observation_kwargs
    assert recorded["trace_context"] == {"trace_id": "trace-id"}
    assert recorded["session_id"] == "session-1"
@pytest.mark.p2
def test_get_models_passes_langfuse_trace_context_to_llm_bundles(monkeypatch):
    """`get_models` must forward trace context and session id to every bundle.

    `LLMBundle` is replaced by a recording stand-in; with rerank and TTS
    enabled on the dialog, four bundles (embedding, chat, rerank, TTS) are
    expected, each built with the same `trace_context` object and the
    given `langfuse_session_id`.
    """
    constructed = []

    class _RecordingBundle:
        """Stand-in for LLMBundle that records its constructor arguments."""

        def __init__(self, tenant_id, model_config, **kwargs):
            self.tenant_id = tenant_id
            self.model_config = model_config
            self.trace_context = kwargs.get("trace_context")
            self.langfuse_session_id = kwargs.get("langfuse_session_id")
            constructed.append((tenant_id, model_config["model_type"], kwargs))

    def _config_for(model_type):
        # Tag the shared config with the requested type so each bundle can
        # be identified afterwards.
        return {**_LLM_CONFIG, "model_type": model_type}

    monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB])
    monkeypatch.setattr(
        dialog_service,
        "get_model_config_from_provider_instance",
        lambda _tenant_id, model_type, _model_id: _config_for(model_type),
    )
    monkeypatch.setattr(
        dialog_service,
        "get_tenant_default_model_by_type",
        lambda _tenant_id, model_type: _config_for(model_type),
    )
    monkeypatch.setattr(dialog_service, "LLMBundle", _RecordingBundle)

    dialog = _make_dialog(None)
    dialog.rerank_id = "rerank-1"
    dialog.prompt_config["tts"] = True
    shared_context = {"trace_id": "trace-id"}

    dialog_service.get_models(dialog, trace_context=shared_context, langfuse_session_id="session-1")

    assert len(constructed) == 4

    expected_types = {
        dialog_service.LLMType.EMBEDDING,
        dialog_service.LLMType.CHAT,
        dialog_service.LLMType.RERANK,
        dialog_service.LLMType.TTS,
    }
    assert {entry[1] for entry in constructed} == expected_types

    for *_, bundle_kwargs in constructed:
        # Identity, not equality: the very same context object is shared.
        assert bundle_kwargs["trace_context"] is shared_context
        assert bundle_kwargs["langfuse_session_id"] == "session-1"
@pytest.mark.p2
def test_async_chat_continues_when_langfuse_observation_start_fails(monkeypatch):
"""
@@ -543,7 +642,7 @@ def test_async_chat_continues_when_langfuse_observation_start_fails(monkeypatch)
monkeypatch.setattr(
dialog_service,
"get_models",
lambda _dialog: ([_KB], chat_mdl, None, chat_mdl, None),
lambda _dialog, **_kwargs: ([_KB], chat_mdl, None, chat_mdl, None),
)
monkeypatch.setattr(
dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {}

View File

@@ -295,7 +295,7 @@ def test_async_chat_uses_all_docs_when_no_doc_ids_selected(monkeypatch):
monkeypatch.setattr(
dialog_service,
"get_models",
lambda _dialog: ([SimpleNamespace(tenant_id="tenant-id")], object(), None, chat_model, None),
lambda _dialog, **_kwargs: ([SimpleNamespace(tenant_id="tenant-id")], object(), None, chat_model, None),
)
monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {})
monkeypatch.setattr(dialog_service, "label_question", lambda _question, _kbs: None)