diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index ecc227a225..b2268db6b5 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)
 from datetime import datetime
 from functools import partial
 from timeit import default_timer as timer
-from langfuse import Langfuse
+from langfuse import Langfuse, propagate_attributes
 from peewee import fn
 from api.db.services.file_service import FileService
 from common.constants import LLMType, ParserType, StatusEnum
@@ -875,8 +875,10 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
                 "input": {"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg},
             }
             if session_id:
-                observation_kwargs["session_id"] = session_id
-            langfuse_generation = langfuse_tracer.start_observation(**observation_kwargs)
+                with propagate_attributes(session_id=session_id):
+                    langfuse_generation = langfuse_tracer.start_observation(**observation_kwargs)
+            else:
+                langfuse_generation = langfuse_tracer.start_observation(**observation_kwargs)
         except Exception as e:  # noqa: BLE001 - tracing must not break chat flow
             logger.warning("Langfuse start_observation failed; continuing without tracing: %s", e)
             langfuse_tracer = None
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index aaba3f202e..02852fa137 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -22,6 +22,8 @@ import threading
 from functools import partial
 from typing import Generator
 
+from langfuse import propagate_attributes
+
 from api.db.db_models import LLM
 from api.db.services.common_service import CommonService
 from api.db.services.tenant_llm_service import LLM4Tenant
@@ -88,7 +90,8 @@ class LLMBundle(LLM4Tenant):
 
     def _start_langfuse_observation(self, **kwargs):
         if self.langfuse_session_id:
-            kwargs["session_id"] = self.langfuse_session_id
+            with propagate_attributes(session_id=self.langfuse_session_id):
+                return self.langfuse.start_observation(**kwargs)
         return self.langfuse.start_observation(**kwargs)
 
     def close(self):
diff --git a/api/db/services/tenant_model_service.py b/api/db/services/tenant_model_service.py
index ebef70137c..2c32d45295 100644
--- a/api/db/services/tenant_model_service.py
+++ b/api/db/services/tenant_model_service.py
@@ -83,7 +83,7 @@ class TenantModelService(CommonService):
                     extra="{}"
                 )
             return len(operation.get("add", [])) + len(operation.get("delete", []))
-        model_record_example = [model_record for model_record in model_type_records if model_record.status != ActiveStatusEnum.UNSUPPORTED]
+        model_record_example = [model_record for model_record in model_type_records if model_record.status != ActiveStatusEnum.UNSUPPORTED.value]
         extra_fields = model_record_example[0].extra if model_record_example else "{}"
         model_status = model_record_example[0].status if model_record_example else ActiveStatusEnum.ACTIVE.value
         type_record_map = {record.model_type: record for record in model_type_records}
diff --git a/test/unit_test/api/db/services/test_dialog_service_final_answer.py b/test/unit_test/api/db/services/test_dialog_service_final_answer.py
index 56fb74a83e..c3fd15ea92 100644
--- a/test/unit_test/api/db/services/test_dialog_service_final_answer.py
+++ b/test/unit_test/api/db/services/test_dialog_service_final_answer.py
@@ -140,6 +140,22 @@ class _StubRetriever:
         return answer, set()
 
 
+class _FakePropagateAttributesContext:
+    """No-op context manager for fake propagate_attributes."""
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        pass
+
+
+def _fake_propagate_attributes(**kwargs):
+    """Fake propagate_attributes (Langfuse v4) that records kwargs and returns a no-op context manager."""
+    _propagate_attributes_calls.append(kwargs)
+    return _FakePropagateAttributesContext()
+
+
 class _FakeLangfuseObservation:
     def __init__(self):
         self.updates = []
@@ -152,6 +168,9 @@ class _FakeLangfuseObservation:
         self.ended = True
 
 
+_propagate_attributes_calls = []
+
+
 class _FakeLangfuseClient:
     instances = []
     fail_start_observation = False
@@ -460,6 +479,8 @@ def test_async_chat_langfuse_uses_start_observation(monkeypatch):
         ),
     )
     monkeypatch.setattr(dialog_service, "Langfuse", _FakeLangfuseClient)
+    _propagate_attributes_calls.clear()
+    monkeypatch.setattr(dialog_service, "propagate_attributes", _fake_propagate_attributes)
     monkeypatch.setattr(
         dialog_service,
         "get_models",
@@ -504,6 +525,7 @@ def test_async_chat_langfuse_uses_start_observation(monkeypatch):
 @pytest.mark.p2
 def test_async_chat_langfuse_observation_includes_session_id(monkeypatch):
     _FakeLangfuseClient.instances = []
+    _propagate_attributes_calls.clear()
     monkeypatch.setattr(_FakeLangfuseClient, "fail_start_observation", False)
     chat_mdl = _StreamingChatModel("Session traces should be grouped.")
     retriever = _StubRetriever()
@@ -526,6 +548,7 @@ def test_async_chat_langfuse_observation_includes_session_id(monkeypatch):
         ),
     )
     monkeypatch.setattr(dialog_service, "Langfuse", _FakeLangfuseClient)
+    monkeypatch.setattr(dialog_service, "propagate_attributes", _fake_propagate_attributes)
     monkeypatch.setattr(
         dialog_service,
         "get_models",
@@ -553,7 +576,7 @@ def test_async_chat_langfuse_observation_includes_session_id(monkeypatch):
     assert any(e.get("final") is True for e in events)
     langfuse = _FakeLangfuseClient.instances[0]
     assert langfuse.observation_kwargs["trace_context"] == {"trace_id": "trace-id"}
-    assert langfuse.observation_kwargs["session_id"] == "session-1"
+    assert _propagate_attributes_calls[0]["session_id"] == "session-1"
 
 
 @pytest.mark.p2
@@ -629,6 +652,8 @@ def test_async_chat_continues_when_langfuse_observation_start_fails(monkeypatch)
         ),
     )
     monkeypatch.setattr(dialog_service, "Langfuse", _FakeLangfuseClient)
+    _propagate_attributes_calls.clear()
+    monkeypatch.setattr(dialog_service, "propagate_attributes", _fake_propagate_attributes)
     monkeypatch.setattr(
         dialog_service,
         "get_models",