mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
feat: Langfuse session grouping for multi-turn chat traces (#15679)
## Summary

This PR passes `session_id` into Langfuse trace observations so multi-turn chat messages can be grouped under the same session in Langfuse.

Changes include:

- Propagate `session_id` from chat/session APIs into `dialog_service.async_chat`.
- Pass `session_id` into Langfuse `start_observation(...)`.
- Share Langfuse `trace_context` with chat, embedding, rerank, and TTS model bundles where applicable.
- Add unit coverage to verify Langfuse observations receive `session_id`.
- Update affected test stubs for the new optional Langfuse context arguments.

## Related Issue

Closes: #15636

## Change Type

- [x] Feature
- [x] Bug fix
- [x] Test
- [ ] Refactor
- [ ] Documentation
- [ ] Breaking change

## Real Behavior Proof

Before this change:

- Langfuse observations were created without `session_id`.
- Multi-turn chat traces could not be grouped by session in Langfuse.

After this change:

- Chat/session flows pass `session_id` into `async_chat`.
- Langfuse observations include `session_id`.
- Related model bundles receive shared trace context and session metadata.

Validation result:

```bash
uv run python -m py_compile \
  api/db/services/tenant_llm_service.py \
  api/db/services/llm_service.py \
  api/db/services/dialog_service.py \
  api/db/services/conversation_service.py \
  api/apps/restful_apis/chat_api.py \
  test/unit_test/api/db/services/test_dialog_service_final_answer.py \
  test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py
```

Passed.

```bash
uv run pytest \
  test/unit_test/api/db/services/test_dialog_service_final_answer.py \
  test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py -q
```

Result:

```text
11 passed in 16.89s
```

```bash
git diff --check
```

Passed.

## Checklist

- [x] Analyzed the issue requirement.
- [x] Checked existing Langfuse trace integration.
- [x] Implemented only the requested session grouping behavior.
- [x] Added/updated unit tests.
- [x] Ran focused tests successfully.
- [x] Ran Python compile validation.
- [x] Ran whitespace diff validation.
This commit is contained in:
@@ -1242,7 +1242,7 @@ async def session_completion(chat_id_in_arg=""):
|
||||
"""Yield SSE-formatted chunks from the async chat generator."""
|
||||
nonlocal dia, msg, req, conv
|
||||
try:
|
||||
async for ans in async_chat(dia, msg, True, **req):
|
||||
async for ans in async_chat(dia, msg, True, session_id=session_id, **req):
|
||||
ans = _format_answer(ans)
|
||||
payload = _sanitize_json_floats({"code": 0, "message": "", "data": ans})
|
||||
yield "data:" + json.dumps(payload, ensure_ascii=False) + "\n\n"
|
||||
@@ -1262,7 +1262,7 @@ async def session_completion(chat_id_in_arg=""):
|
||||
return resp
|
||||
|
||||
answer = None
|
||||
async for ans in async_chat(dia, msg, False, **req):
|
||||
async for ans in async_chat(dia, msg, False, session_id=session_id, **req):
|
||||
answer = _format_answer(ans)
|
||||
if conv is not None:
|
||||
await thread_pool_exec(ConversationService.update_by_id, conv.id, conv.to_dict())
|
||||
|
||||
@@ -187,7 +187,7 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses
|
||||
|
||||
if stream:
|
||||
try:
|
||||
async for ans in async_chat(dia, msg, True, **kwargs):
|
||||
async for ans in async_chat(dia, msg, True, session_id=session_id, **kwargs):
|
||||
ans = structure_answer(conv, ans, message_id, session_id)
|
||||
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
@@ -199,7 +199,7 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses
|
||||
|
||||
else:
|
||||
answer = None
|
||||
async for ans in async_chat(dia, msg, False, **kwargs):
|
||||
async for ans in async_chat(dia, msg, False, session_id=session_id, **kwargs):
|
||||
answer = structure_answer(conv, ans, message_id, session_id)
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
break
|
||||
@@ -275,7 +275,7 @@ async def async_iframe_completion(dialog_id, question, session_id=None, stream=T
|
||||
|
||||
if stream:
|
||||
try:
|
||||
async for ans in async_chat(dia, msg, True, **kwargs):
|
||||
async for ans in async_chat(dia, msg, True, session_id=session_id, **kwargs):
|
||||
ans = structure_answer(conv, ans, message_id, session_id)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
@@ -288,7 +288,7 @@ async def async_iframe_completion(dialog_id, question, session_id=None, stream=T
|
||||
|
||||
else:
|
||||
answer = None
|
||||
async for ans in async_chat(dia, msg, False, **kwargs):
|
||||
async for ans in async_chat(dia, msg, False, session_id=session_id, **kwargs):
|
||||
answer = structure_answer(conv, ans, message_id, session_id)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
break
|
||||
|
||||
@@ -286,7 +286,7 @@ class DialogService(CommonService):
|
||||
return list(objs)
|
||||
|
||||
|
||||
async def async_chat_solo(dialog, messages, stream=True):
|
||||
async def async_chat_solo(dialog, messages, stream=True, session_id=None):
|
||||
llm_types = get_model_type_by_name(dialog.tenant_id, dialog.llm_id)
|
||||
attachments = ""
|
||||
image_attachments = []
|
||||
@@ -303,14 +303,14 @@ async def async_chat_solo(dialog, messages, stream=True):
|
||||
else:
|
||||
model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
|
||||
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, model_config)
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, model_config, langfuse_session_id=session_id)
|
||||
factory = model_config.get("llm_factory", "") if model_config else ""
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
tts_mdl = None
|
||||
if prompt_config.get("tts"):
|
||||
default_tts_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS)
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model)
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model, trace_context=chat_mdl.trace_context, langfuse_session_id=session_id)
|
||||
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
|
||||
if attachments and msg:
|
||||
msg[-1]["content"] += attachments
|
||||
@@ -337,7 +337,7 @@ async def async_chat_solo(dialog, messages, stream=True):
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
|
||||
|
||||
|
||||
def get_models(dialog):
|
||||
def get_models(dialog, trace_context=None, langfuse_session_id=None):
|
||||
embd_mdl, chat_mdl, rerank_mdl, tts_mdl = None, None, None, None
|
||||
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
|
||||
embedding_list = list(set([kb.embd_id for kb in kbs]))
|
||||
@@ -347,7 +347,7 @@ def get_models(dialog):
|
||||
if embedding_list:
|
||||
embd_owner_tenant_id = kbs[0].tenant_id
|
||||
embd_model_config = get_model_config_from_provider_instance(embd_owner_tenant_id, LLMType.EMBEDDING, embedding_list[0])
|
||||
embd_mdl = LLMBundle(embd_owner_tenant_id, embd_model_config)
|
||||
embd_mdl = LLMBundle(embd_owner_tenant_id, embd_model_config, trace_context=trace_context, langfuse_session_id=langfuse_session_id)
|
||||
if not embd_mdl:
|
||||
raise LookupError("Embedding model(%s) not found" % embedding_list[0])
|
||||
|
||||
@@ -356,15 +356,15 @@ def get_models(dialog):
|
||||
else:
|
||||
chat_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
|
||||
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, chat_model_config)
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, chat_model_config, trace_context=trace_context, langfuse_session_id=langfuse_session_id)
|
||||
|
||||
if dialog.rerank_id:
|
||||
rerank_model_config = get_model_config_from_provider_instance(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
|
||||
rerank_mdl = LLMBundle(dialog.tenant_id, rerank_model_config)
|
||||
rerank_mdl = LLMBundle(dialog.tenant_id, rerank_model_config, trace_context=trace_context, langfuse_session_id=langfuse_session_id)
|
||||
|
||||
if dialog.prompt_config.get("tts"):
|
||||
default_tts_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS)
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model_config)
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model_config, trace_context=trace_context, langfuse_session_id=langfuse_session_id)
|
||||
return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl
|
||||
|
||||
|
||||
@@ -541,10 +541,11 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
|
||||
async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
logging.debug("Begin async_chat")
|
||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||
session_id = kwargs.get("session_id")
|
||||
use_web_search = _should_use_web_search(dialog.prompt_config, kwargs.get("internet"))
|
||||
logging.debug("web_search kb=%s tavily=%s internet=%r enabled=%s", bool(dialog.kb_ids), bool(dialog.prompt_config.get("tavily_api_key")), kwargs.get("internet"), use_web_search)
|
||||
if not dialog.kb_ids and not use_web_search:
|
||||
async for ans in async_chat_solo(dialog, messages, stream):
|
||||
async for ans in async_chat_solo(dialog, messages, stream, session_id=session_id):
|
||||
yield ans
|
||||
return
|
||||
|
||||
@@ -579,7 +580,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
pass
|
||||
|
||||
check_langfuse_tracer_ts = timer()
|
||||
kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
|
||||
kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog, trace_context=trace_context, langfuse_session_id=session_id)
|
||||
toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
|
||||
if toolcall_session and tools:
|
||||
chat_mdl.bind_tools(toolcall_session, tools)
|
||||
@@ -737,7 +738,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||
if prompt_config.get("use_kg"):
|
||||
default_chat_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
|
||||
ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, default_chat_model))
|
||||
ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, default_chat_model, trace_context=trace_context, langfuse_session_id=session_id))
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
|
||||
@@ -866,13 +867,16 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
|
||||
if langfuse_tracer:
|
||||
try:
|
||||
langfuse_generation = langfuse_tracer.start_observation(
|
||||
as_type="generation",
|
||||
trace_context=trace_context,
|
||||
name="chat",
|
||||
model=llm_model_config["llm_name"],
|
||||
input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg},
|
||||
)
|
||||
observation_kwargs = {
|
||||
"as_type": "generation",
|
||||
"trace_context": trace_context,
|
||||
"name": "chat",
|
||||
"model": llm_model_config["llm_name"],
|
||||
"input": {"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg},
|
||||
}
|
||||
if session_id:
|
||||
observation_kwargs["session_id"] = session_id
|
||||
langfuse_generation = langfuse_tracer.start_observation(**observation_kwargs)
|
||||
except Exception as e: # noqa: BLE001 - tracing must not break chat flow
|
||||
logger.warning("Langfuse start_observation failed; continuing without tracing: %s", e)
|
||||
langfuse_tracer = None
|
||||
|
||||
@@ -86,6 +86,11 @@ class LLMBundle(LLM4Tenant):
|
||||
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
|
||||
super().__init__(tenant_id, model_config, lang, **kwargs)
|
||||
|
||||
def _start_langfuse_observation(self, **kwargs):
    """Start a Langfuse observation, tagged with this bundle's session id.

    When ``self.langfuse_session_id`` is set it is forwarded to
    ``self.langfuse.start_observation`` as ``session_id`` (overriding any
    ``session_id`` already present in ``kwargs``). If the client rejects
    the call with ``TypeError``, the observation is started again without
    ``session_id`` so that tracing problems do not abort the model call.

    Args:
        **kwargs: Keyword arguments passed through to
            ``self.langfuse.start_observation``.

    Returns:
        Whatever ``self.langfuse.start_observation`` returns.
    """
    if not self.langfuse_session_id:
        return self.langfuse.start_observation(**kwargs)

    tagged = {**kwargs, "session_id": self.langfuse_session_id}
    try:
        return self.langfuse.start_observation(**tagged)
    except TypeError as exc:
        # NOTE(review): assumes a TypeError here means the client does not
        # accept ``session_id``; confirm against the pinned Langfuse SDK.
        logging.warning(
            "Langfuse start_observation rejected session_id; continuing without session grouping: %s",
            exc,
        )
        untagged = {key: value for key, value in kwargs.items() if key != "session_id"}
        return self.langfuse.start_observation(**untagged)
|
||||
|
||||
def close(self):
|
||||
"""Release resources held by this LLMBundle instance."""
|
||||
super().close()
|
||||
@@ -107,7 +112,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def encode(self, texts: list):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="encode", model=self.model_config["llm_name"], input={"texts": texts})
|
||||
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="encode", model=self.model_config["llm_name"], input={"texts": texts})
|
||||
|
||||
safe_texts = []
|
||||
for idx, text in enumerate(texts):
|
||||
@@ -149,7 +154,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def encode_queries(self, query: str):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="encode_queries", model=self.model_config["llm_name"], input={"query": query})
|
||||
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="encode_queries", model=self.model_config["llm_name"], input={"query": query})
|
||||
|
||||
if query is None or not str(query).strip():
|
||||
marker = "None" if query is None else "whitespace-only"
|
||||
@@ -173,7 +178,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts})
|
||||
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts})
|
||||
|
||||
sim, used_tokens = self.mdl.similarity(query, texts)
|
||||
logging.info("LLMBundle.similarity used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"]))
|
||||
@@ -186,7 +191,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def describe(self, image, max_tokens=300):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="describe", metadata={"model": self.model_config["llm_name"]})
|
||||
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="describe", metadata={"model": self.model_config["llm_name"]})
|
||||
|
||||
txt, used_tokens = self.mdl.describe(image)
|
||||
logging.info("LLMBundle.describe used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"]))
|
||||
@@ -199,7 +204,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def describe_with_prompt(self, image, prompt):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt})
|
||||
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt})
|
||||
|
||||
txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
|
||||
logging.info("LLMBundle.describe_with_prompt used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"]))
|
||||
@@ -212,7 +217,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def transcription(self, audio):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="transcription", metadata={"model": self.model_config["llm_name"]})
|
||||
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="transcription", metadata={"model": self.model_config["llm_name"]})
|
||||
|
||||
txt, used_tokens = self.mdl.transcription(audio)
|
||||
logging.info("LLMBundle.transcription used_tokens: {}, llm_name: {}".format(used_tokens, self.model_config["llm_name"]))
|
||||
@@ -228,7 +233,7 @@ class LLMBundle(LLM4Tenant):
|
||||
supports_stream = hasattr(mdl, "stream_transcription") and callable(getattr(mdl, "stream_transcription"))
|
||||
if supports_stream:
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_observation(as_type="generation",
|
||||
generation = self._start_langfuse_observation(as_type="generation",
|
||||
trace_context=self.trace_context,
|
||||
name="stream_transcription",
|
||||
metadata={"model": self.model_config["llm_name"]},
|
||||
@@ -262,7 +267,7 @@ class LLMBundle(LLM4Tenant):
|
||||
return
|
||||
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_observation(as_type="generation",
|
||||
generation = self._start_langfuse_observation(as_type="generation",
|
||||
trace_context=self.trace_context,
|
||||
name="stream_transcription",
|
||||
metadata={"model": self.model_config["llm_name"]},
|
||||
@@ -286,7 +291,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def tts(self, text: str) -> Generator[bytes, None, None]:
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="tts", input={"text": text})
|
||||
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="tts", input={"text": text})
|
||||
|
||||
for chunk in self.mdl.tts(text):
|
||||
if isinstance(chunk, int):
|
||||
@@ -408,7 +413,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
generation = None
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history})
|
||||
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history})
|
||||
|
||||
chat_partial = partial(base_fn, system, history, gen_conf)
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
@@ -449,7 +454,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
generation = None
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
|
||||
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
|
||||
|
||||
if stream_fn:
|
||||
chat_partial = partial(stream_fn, system, history, gen_conf)
|
||||
@@ -492,7 +497,7 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
generation = None
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
|
||||
generation = self._start_langfuse_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
|
||||
|
||||
if stream_fn:
|
||||
chat_partial = partial(stream_fn, system, history, gen_conf)
|
||||
|
||||
@@ -505,6 +505,8 @@ class TenantLLMService(CommonService):
|
||||
|
||||
class LLM4Tenant:
|
||||
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
|
||||
self.trace_context = kwargs.pop("trace_context", None) or {}
|
||||
self.langfuse_session_id = kwargs.pop("langfuse_session_id", None)
|
||||
self.tenant_id = tenant_id
|
||||
self.llm_name = model_config["llm_name"]
|
||||
self.model_config = model_config
|
||||
@@ -522,8 +524,9 @@ class LLM4Tenant:
|
||||
try:
|
||||
if langfuse.auth_check():
|
||||
self.langfuse = langfuse
|
||||
trace_id = self.langfuse.create_trace_id()
|
||||
self.trace_context = {"trace_id": trace_id}
|
||||
if not self.trace_context:
|
||||
trace_id = self.langfuse.create_trace_id()
|
||||
self.trace_context = {"trace_id": trace_id}
|
||||
except Exception:
|
||||
# Skip langfuse tracing if connection fails
|
||||
pass
|
||||
|
||||
@@ -406,7 +406,7 @@ def test_async_chat_final_event_carries_decorated_answer(monkeypatch):
|
||||
# get_models returns (kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl)
|
||||
monkeypatch.setattr(
|
||||
dialog_service, "get_models",
|
||||
lambda _dialog: ([_KB], chat_mdl, None, chat_mdl, None),
|
||||
lambda _dialog, **_kwargs: ([_KB], chat_mdl, None, chat_mdl, None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {}
|
||||
@@ -473,7 +473,7 @@ def test_async_chat_langfuse_uses_start_observation(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
dialog_service,
|
||||
"get_models",
|
||||
lambda _dialog: ([_KB], chat_mdl, None, chat_mdl, None),
|
||||
lambda _dialog, **_kwargs: ([_KB], chat_mdl, None, chat_mdl, None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {}
|
||||
@@ -511,6 +511,105 @@ def test_async_chat_langfuse_uses_start_observation(monkeypatch):
|
||||
assert langfuse.observation.ended is True
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_async_chat_langfuse_observation_includes_session_id(monkeypatch):
    """async_chat must forward ``session_id`` to the Langfuse observation."""
    # Reset the fake client so instances[0] is the one created by this run.
    _FakeLangfuseClient.instances = []
    monkeypatch.setattr(_FakeLangfuseClient, "fail_start_observation", False)

    chat_mdl = _StreamingChatModel("Session traces should be grouped.")
    retriever = _StubRetriever()
    kb_service = dialog_service.KnowledgebaseService

    # (target, attribute, replacement) triples applied in the original order.
    patches = [
        (dialog_service, "get_model_type_by_name", lambda _tid, _llm_id: ["chat"]),
        (
            dialog_service,
            "get_model_config_from_provider_instance",
            lambda _tid, _type, _llm_id: _LLM_CONFIG,
        ),
        (
            dialog_service.TenantLangfuseService,
            "filter_by_tenant",
            lambda tenant_id: SimpleNamespace(
                public_key="public",
                secret_key="secret",
                host="http://langfuse.local",
            ),
        ),
        (dialog_service, "Langfuse", _FakeLangfuseClient),
        (
            dialog_service,
            "get_models",
            lambda _dialog, **_kwargs: ([_KB], chat_mdl, None, chat_mdl, None),
        ),
        (kb_service, "get_field_map", lambda _kb_ids: {}),
        (kb_service, "get_by_ids", lambda _ids: [_KB]),
    ]
    for target, attribute, replacement in patches:
        monkeypatch.setattr(target, attribute, replacement)

    monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False)
    monkeypatch.setattr(dialog_service, "label_question", lambda _q, _kbs: "")
    monkeypatch.setattr(
        dialog_service,
        "kb_prompt",
        lambda _kbinfos, _max_tokens, **_kw: ["RAGFlow is a RAG engine."],
    )

    dialog = _make_dialog(chat_mdl)
    messages = [{"role": "user", "content": "What is RAGFlow?"}]

    stream = dialog_service.async_chat(dialog, messages, stream=True, quote=True, session_id="session-1")
    events = _collect(stream)

    assert any(e.get("final") is True for e in events)
    langfuse = _FakeLangfuseClient.instances[0]
    assert langfuse.observation_kwargs["trace_context"] == {"trace_id": "trace-id"}
    assert langfuse.observation_kwargs["session_id"] == "session-1"
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_get_models_passes_langfuse_trace_context_to_llm_bundles(monkeypatch):
    """get_models must hand the shared trace context and session id to every bundle."""
    captured = []

    class _FakeBundle:
        """Stand-in for LLMBundle that records its constructor arguments."""

        def __init__(self, tenant_id, model_config, **kwargs):
            self.tenant_id = tenant_id
            self.model_config = model_config
            self.trace_context = kwargs.get("trace_context")
            self.langfuse_session_id = kwargs.get("langfuse_session_id")
            captured.append((tenant_id, model_config["model_type"], kwargs))

    monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB])
    monkeypatch.setattr(
        dialog_service,
        "get_model_config_from_provider_instance",
        lambda _tenant_id, model_type, _model_id: {**_LLM_CONFIG, "model_type": model_type},
    )
    monkeypatch.setattr(
        dialog_service,
        "get_tenant_default_model_by_type",
        lambda _tenant_id, model_type: {**_LLM_CONFIG, "model_type": model_type},
    )
    monkeypatch.setattr(dialog_service, "LLMBundle", _FakeBundle)

    # Enable the optional rerank and TTS bundles so all four model kinds are built.
    dialog = _make_dialog(None)
    dialog.rerank_id = "rerank-1"
    dialog.prompt_config["tts"] = True
    trace_context = {"trace_id": "trace-id"}

    dialog_service.get_models(dialog, trace_context=trace_context, langfuse_session_id="session-1")

    assert len(captured) == 4
    seen_types = {model_type for _, model_type, _ in captured}
    assert seen_types == {
        dialog_service.LLMType.EMBEDDING,
        dialog_service.LLMType.CHAT,
        dialog_service.LLMType.RERANK,
        dialog_service.LLMType.TTS,
    }
    for _, _, bundle_kwargs in captured:
        assert bundle_kwargs["trace_context"] is trace_context
        assert bundle_kwargs["langfuse_session_id"] == "session-1"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_async_chat_continues_when_langfuse_observation_start_fails(monkeypatch):
|
||||
"""
|
||||
@@ -543,7 +642,7 @@ def test_async_chat_continues_when_langfuse_observation_start_fails(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
dialog_service,
|
||||
"get_models",
|
||||
lambda _dialog: ([_KB], chat_mdl, None, chat_mdl, None),
|
||||
lambda _dialog, **_kwargs: ([_KB], chat_mdl, None, chat_mdl, None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {}
|
||||
|
||||
@@ -295,7 +295,7 @@ def test_async_chat_uses_all_docs_when_no_doc_ids_selected(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
dialog_service,
|
||||
"get_models",
|
||||
lambda _dialog: ([SimpleNamespace(tenant_id="tenant-id")], object(), None, chat_model, None),
|
||||
lambda _dialog, **_kwargs: ([SimpleNamespace(tenant_id="tenant-id")], object(), None, chat_model, None),
|
||||
)
|
||||
monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {})
|
||||
monkeypatch.setattr(dialog_service, "label_question", lambda _question, _kbs: None)
|
||||
|
||||
Reference in New Issue
Block a user