From 563d8557802df668deb28a3db6c6eede7e94f76f Mon Sep 17 00:00:00 2001 From: qinling0210 <88864212+qinling0210@users.noreply.github.com> Date: Thu, 18 Jun 2026 18:07:27 +0800 Subject: [PATCH] Implement OpenAI chat completions in GO (#16177) ### What problem does this PR solve? Implement OpenAI chat completions in GO POST /api/v1/openai//chat/completions OpenAI chat cli: internal/development.md ### Type of change - [x] Refactoring --- api/db/services/dialog_service.py | 25 +- cmd/server_main.go | 4 +- go.mod | 6 +- go.sum | 8 + internal/cli/cli.go | 60 +- internal/cli/cli_http.go | 5 + internal/cli/http_client.go | 2 +- internal/cli/lexer.go | 2 + internal/cli/parser.go | 2 + internal/cli/response.go | 190 + internal/cli/types.go | 1 + internal/cli/user_command.go | 328 ++ internal/cli/user_parser.go | 261 ++ internal/common/metadata_utils.go | 8 +- internal/common/multimodal.go | 353 ++ internal/common/timer.go | 169 + internal/common/timer_test.go | 224 ++ internal/development.md | 86 +- internal/engine/elasticsearch/sql.go | 206 + internal/engine/elasticsearch/sql_test.go | 493 +++ internal/engine/engine.go | 4 + internal/engine/global.go | 5 + internal/engine/infinity/chunk.go | 1 + internal/engine/infinity/client.go | 17 +- internal/engine/infinity/metadata.go | 121 +- internal/engine/infinity/sql.go | 330 ++ internal/engine/infinity/sql_test.go | 394 ++ internal/entity/models/chat_tools.go | 351 ++ internal/entity/models/openai.go | 161 +- internal/entity/models/siliconflow.go | 46 +- internal/entity/models/types.go | 75 +- internal/handler/openai_chat.go | 125 + internal/handler/openai_chat_test.go | 254 ++ internal/router/router.go | 9 + internal/service/ask_service.go | 4 +- internal/service/chat.go | 133 +- internal/service/chat_pipeline.go | 4275 +++++++++++++++++++++ internal/service/chat_pipeline_test.go | 1464 +++++++ internal/service/chat_session.go | 1099 +----- internal/service/chat_session_test.go | 1392 +++---- internal/service/citation.go | 101 + internal/service/deep_researcher.go | 854 ++++ internal/service/file.go | 60 + internal/service/generator.go | 173 +- internal/service/kb_prompt.go | 46 +- internal/service/kb_prompt_test.go | 30 +- internal/service/kg/scoring.go | 4 +- internal/service/langfuse.go | 258 ++ internal/service/metadata.go | 28 +- internal/service/metadata_filter.go | 76 +- internal/service/model_service.go | 36 + internal/service/nlp/retrieval.go | 86 +- internal/service/openai_chat.go | 846 ++++ internal/service/openai_chat_test.go | 842 ++++ internal/service/tag.go | 11 +- internal/service/tenant.go | 53 +- internal/service/toc_enhancer.go | 605 +++ internal/service/toc_enhancer_test.go | 314 ++ internal/tokenizer/tokenizer.go | 49 +- internal/tokenizer/tokenizer_test.go | 265 ++ rag/prompts/generator.py | 2 +- 61 files changed, 15327 insertions(+), 2105 deletions(-) create mode 100644 internal/common/multimodal.go create mode 100644 internal/common/timer.go create mode 100644 internal/common/timer_test.go create mode 100644 internal/engine/elasticsearch/sql.go create mode 100644 internal/engine/elasticsearch/sql_test.go create mode 100644 internal/engine/infinity/sql.go create mode 100644 internal/engine/infinity/sql_test.go create mode 100644 internal/entity/models/chat_tools.go create mode 100644 internal/handler/openai_chat.go create mode 100644 internal/handler/openai_chat_test.go create mode 100644 internal/service/chat_pipeline.go create mode 100644 internal/service/chat_pipeline_test.go create mode 100644 internal/service/deep_researcher.go create mode 100644 internal/service/langfuse.go create mode 100644 internal/service/openai_chat.go create mode 100644 internal/service/openai_chat_test.go create mode 100644 internal/service/toc_enhancer.go create mode 100644 internal/service/toc_enhancer_test.go create mode 100644 internal/tokenizer/tokenizer_test.go diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index b2268db6b5..7edc94b16c 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -759,10 +759,19 @@ async def async_chat(dialog, messages, stream=True, **kwargs): yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res), "final": True} return - kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges) + # Only overwrite kwargs["knowledge"] when retrieval produced something; + # otherwise preserve any caller-supplied value. + knowledge_text = "\n\n------\n\n".join(knowledges) + if knowledge_text: + kwargs["knowledge"] = "\n------\n" + knowledge_text gen_conf = dialog.llm_setting - msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs) + attachments_}] + system_content = prompt_config["system"].format(**kwargs) + attachments_ + # If knowledge was retrieved but the template has no {knowledge} + # placeholder, auto-append it so the LLM still sees the context. + if knowledges and "{knowledge}" not in prompt_config.get("system", ""): + system_content += kwargs["knowledge"] + msg = [{"role": "system", "content": system_content}] prompt4citation = "" if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): prompt4citation = citation_prompt() @@ -1056,7 +1065,9 @@ RULES: - Question mentions "not null" or "excluding null" - Add NULL check for count specific column - DO NOT add NULL check for COUNT(*) queries (COUNT(*) counts all rows including nulls) -7. Output ONLY the SQL, no explanations""" +7. json_extract_string() returns JSON-quoted strings ("value"), so WHERE comparisons MUST wrap values in double-quotes inside single-quotes (no spaces between quotes): '"value"' (e.g. WHERE json_extract_string(chunk_data, '$.name') = '"Alice"') +8. For partial text search, use LIKE with wildcards: '"%value%"' (e.g. WHERE json_extract_string(chunk_data, '$.name') LIKE '"%Alice%"') +9. Output ONLY the SQL, no explanations""" user_prompt = """Table: {} Fields (EXACT case): {} {} @@ -1128,9 +1139,13 @@ Write SQL using exact field names above. Include doc_id, docnm_kwd for data quer logging.debug(f"use_sql: Executing SQL retrieval (attempt {tried_times})") tbl = settings.retriever.sql_retrieval(sql, format="json") if tbl is None: - logging.debug("use_sql: SQL retrieval returned None") + logging.debug("use_sql: SQL retrieval failed (returned None)") return None, sql - logging.debug(f"use_sql: SQL retrieval completed, got {len(tbl.get('rows', []))} rows") + row_count = len(tbl.get("rows", [])) + if row_count == 0: + logging.debug("use_sql: SQL execution succeeded but returned 0 rows") + else: + logging.debug(f"use_sql: SQL retrieval completed, got {row_count} rows") return tbl, sql async def repair_table_for_missing_source_columns(previous_sql): diff --git a/cmd/server_main.go b/cmd/server_main.go index a37ef19b33..d6c6f1f30c 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -212,6 +212,7 @@ func startServer(config *server.Config) { tenantService := service.NewTenantService() chatService := service.NewChatService() chatSessionService := service.NewChatSessionService() + openaiChatService := service.NewOpenAIChatService() systemService := service.NewSystemService() connectorService := service.NewConnectorService() searchService := service.NewSearchService() @@ -235,6 +236,7 @@ func startServer(config *server.Config) { llmHandler := handler.NewLLMHandler(llmService, userService) chatHandler := handler.NewChatHandler(chatService, userService) chatSessionHandler := handler.NewChatSessionHandler(chatSessionService, userService) + openaiChatHandler := handler.NewOpenAIChatHandler(openaiChatService) connectorHandler := handler.NewConnectorHandler(connectorService, userService) searchHandler := handler.NewSearchHandler(searchService, userService) fileHandler := handler.NewFileHandler(fileService, userService) @@ -307,7 +309,7 @@ func startServer(config *server.Config) { adminRuntimeHandler := handler.NewAdminRuntimeHandler(adminRuntimeSelector) // Initialize router - r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, knowledgebaseHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, mcpHandler, skillSearchHandler, providerHandler, agentHandler, searchBotHandler, difyRetrievalHandler, pluginHandler, modelHandler, fileCommitHandler, adminRuntimeHandler) + r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, knowledgebaseHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, mcpHandler, skillSearchHandler, providerHandler, agentHandler, searchBotHandler, difyRetrievalHandler, pluginHandler, modelHandler, fileCommitHandler, adminRuntimeHandler, openaiChatHandler) // Create Gin engine ginEngine := gin.New() diff --git a/go.mod b/go.mod index a1086e508e..0f2b0b8be6 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module ragflow -go 1.26.2 +go 1.26.4 require ( github.com/DATA-DOG/go-sqlmock v1.5.2 @@ -28,11 +28,13 @@ require ( github.com/infiniflow/infinity-go-sdk v0.0.0-00010101000000-000000000000 github.com/iromli/go-itsdangerous v0.0.0-20220223194502-9c8bef8dac6a github.com/json-iterator/go v1.1.12 + github.com/kaptinlin/jsonrepair v0.4.8 github.com/lib/pq v1.10.9 github.com/minio/minio-go/v7 v7.0.99 github.com/nats-io/nats.go v1.52.0 github.com/nikolalohinski/gonja v1.5.3 github.com/peterh/liner v1.2.2 + github.com/pkoukk/tiktoken-go v0.1.8 github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_model v0.6.2 github.com/redis/go-redis/v9 v9.18.0 @@ -93,6 +95,7 @@ require ( github.com/clbanning/mxj/v2 v2.7.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.10.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/eino-contrib/jsonschema v1.0.3 // indirect github.com/elastic/elastic-transport-go/v8 v8.8.0 // indirect @@ -102,6 +105,7 @@ require ( github.com/gin-contrib/sse v0.1.0 // indirect github.com/glebarez/go-sqlite v1.21.2 // indirect github.com/go-ini/ini v1.67.0 // indirect + github.com/go-json-experiment/json v0.0.0-20260601182631-00ed12fed2a6 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-playground/locales v0.14.1 // indirect diff --git a/go.sum b/go.sum index c6d835d3bc..7ee3534f3c 100644 --- a/go.sum +++ b/go.sum @@ -148,6 +148,8 @@ github.com/denisenkom/go-mssqldb v0.12.3 h1:pBSGx9Tq67pBOTLmxNuirNTeB8Vjmf886Kx+ github.com/denisenkom/go-mssqldb v0.12.3/go.mod h1:k0mtMFOnU+AihqFxPMiF05rtiDrorD1Vrm1KEz5hxDo= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= @@ -185,6 +187,8 @@ github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJY github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= +github.com/go-json-experiment/json v0.0.0-20260601182631-00ed12fed2a6 h1:nxP4pPoyqOAgX8lYDFCfl3DyKeXErCvSvhcyzwGV9CE= +github.com/go-json-experiment/json v0.0.0-20260601182631-00ed12fed2a6/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -274,6 +278,8 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE= +github.com/kaptinlin/jsonrepair v0.4.8 h1:9oaoEe/vaKgm8ko4TLjBLUEog6tBW6WUzZXLPL2yTCk= +github.com/kaptinlin/jsonrepair v0.4.8/go.mod h1:eWRC42KDUT0MHkMplUN6necu59FQFqKOKe+86akpY3g= github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= @@ -354,6 +360,8 @@ github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo= +github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 6841187202..f0852727b4 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -754,10 +754,10 @@ Commands (User Mode): LIST TOKENS; - List API tokens LIST PROVIDERS; - List available LLM providers CREATE TOKEN; - Create new API token - ADD PROVIDER 'name'; - Create a provider without API key - ADD PROVIDER 'name' 'api_key'; - Create a provider with API key + ADD PROVIDER 'name'; - Create a provider without API key + ADD PROVIDER 'name' 'api_key'; - Create a provider with API key DROP TOKEN 'token_value'; - Delete an API token - DELETE PROVIDER 'name'; - Delete a provider + DELETE PROVIDER 'name'; - Delete a provider SET TOKEN 'token_value'; - Set and validate API token SHOW TOKEN; - Show current API token SHOW PROVIDER 'name'; - Show provider details @@ -767,6 +767,8 @@ Commands (User Mode): USE MODEL 'provider/instance/model'; - Set current model for chat CHAT 'message'; - Chat using current model CHAT 'provider/instance/model' 'message'; - Chat with specified model + OPENAI_CHAT 'chat_id' 'message' [options] ; - OpenAI-compatible chat + (run openai_chat -h for detailed options) Filesystem Commands (no quotes): ls [path] - List resources @@ -919,3 +921,55 @@ Datasets syntax (full filter set): ` fmt.Println(help) } + +// printOpenaiChatHelp prints help for the OPENAI_CHAT command. +func printOpenaiChatHelp() { + help := `OPENAI_CHAT — hit POST /api/v1/openai//chat/completions + +Syntax: + OPENAI_CHAT 'chat_id' 'message' + [system "..."] + [history "user:...;assistant:...;user:..."] + [history_delimiter ""] + [model ] + [temperature ] [max_tokens ] [stream ] + [top_p ] [frequency_penalty ] [presence_penalty ] + [extra_body ] ; + +Required positional: + 'chat_id' the dialog id (becomes the URL path segment) + 'message' the user message content + +Named options (any order; all optional with defaults): + system '...' override the system prompt + history '...' prior turns: user:...;assistant:...;user:... + history_delimiter '...' turn separator for history (default ';') + model '...' 'model' (sentinel) or composite (default 'model') + temperature 0..2 (default 0) + max_tokens (default 0 = server/model default) + stream true|false (default false) + top_p 0..1 + frequency_penalty -2..2 + presence_penalty -2..2 + extra_body '{"reference":true,...}' + +Defaults: + model 'model' — server resolves to the dialog's configured LLM + stream false + temperature 0 + history_delimiter ';' — commas in content survive unchanged + +extra_body allowlist: + reference bool + reference_metadata { include?: bool, fields?: string[] } + metadata_condition { logic?: "and"|"or", conditions?: [{key, operator, value}] } + +Examples: + OPENAI_CHAT 'cid' 'Hello, how are you?'; + OPENAI_CHAT 'cid' 'Hello' model 'Qwen/Qwen3-8B@ling@SILICONFLOW' temperature 0.7 max_tokens 512; + OPENAI_CHAT 'cid' 'Hello' stream true; + OPENAI_CHAT 'cid' 'next' system 'You are concise.' history 'user:q1;assistant:a1'; + OPENAI_CHAT 'cid' 'Hello' extra_body '{"reference":true,"metadata_condition":{"logic":"and","conditions":[{"key":"doc_type","operator":"is","value":"faq"}]}}'; +` + fmt.Println(help) +} diff --git a/internal/cli/cli_http.go b/internal/cli/cli_http.go index 9c07579117..13587abe8f 100644 --- a/internal/cli/cli_http.go +++ b/internal/cli/cli_http.go @@ -282,6 +282,11 @@ func (c *CLI) ExecuteUserCommand(cmd *Command) (ResponseIf, error) { return c.ChatToModel(cmd) case "think_chat_to_model": return c.ChatToModel(cmd) + case "openai_chat": + return c.OpenaiChat(cmd) + case "openai_chat_help": + printOpenaiChatHelp() + return nil, nil case "embed_user_text": return c.EmbedUserText(cmd) case "rarank_user_document": diff --git a/internal/cli/http_client.go b/internal/cli/http_client.go index 9f270ef176..919a43194d 100644 --- a/internal/cli/http_client.go +++ b/internal/cli/http_client.go @@ -158,12 +158,12 @@ func (c *HTTPClient) Request(method, path string, authKind string, headers map[s return nil, err } defer resp.Body.Close() - duration := time.Since(startTime).Seconds() respBody, err := io.ReadAll(resp.Body) if err != nil { return nil, err } + duration := time.Since(startTime).Seconds() return &Response{ StatusCode: resp.StatusCode, diff --git a/internal/cli/lexer.go b/internal/cli/lexer.go index 0c4fe24202..4308342a30 100644 --- a/internal/cli/lexer.go +++ b/internal/cli/lexer.go @@ -305,6 +305,8 @@ func (l *Lexer) lookupIdent(ident string) Token { return Token{Type: TokenChats, Value: ident} case "CHAT": return Token{Type: TokenChat, Value: ident} + case "OPENAI_CHAT": + return Token{Type: TokenOpenaiChat, Value: ident} case "MESSAGE": return Token{Type: TokenMessage, Value: ident} case "IMAGE": diff --git a/internal/cli/parser.go b/internal/cli/parser.go index 4af45dceae..73c12bce3b 100644 --- a/internal/cli/parser.go +++ b/internal/cli/parser.go @@ -206,6 +206,8 @@ func (p *Parser) parseUserCommand() (*Command, error) { return p.parseStreamCommand() case TokenChat: return p.parseChatCommand() + case TokenOpenaiChat: + return p.parseOpenaiChatCommand() case TokenThink: return p.parseThinkCommand() case TokenEmbed: diff --git a/internal/cli/response.go b/internal/cli/response.go index c2b3efbfa2..fcbea45844 100644 --- a/internal/cli/response.go +++ b/internal/cli/response.go @@ -17,6 +17,7 @@ package cli import ( + "encoding/json" "fmt" "strings" ) @@ -599,3 +600,192 @@ func (r *FileSystemResponse) SetOutputFormat(format OutputFormat) { r.OutputForm func (r *FileSystemResponse) PrintOut() { fmt.Print(r.Output) } + +type OpenAIChatResponse struct { + Code int `json:"code,omitempty"` + Data *openAIChatData `json:"data,omitempty"` + Message string `json:"message,omitempty"` + Duration float64 `json:"-"` + OutputFormat OutputFormat `json:"-"` + // Reasoning from the model's chain-of-thought. + Reasoning string `json:"-"` + // streamed skips the "Answer:" line in PrintOut to avoid duplication. + streamed bool + // raw HTTP body for the "raw" output format. + raw []byte +} + +type openAIChatData struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []openAIChatChoice `json:"choices"` + Usage *openAIChatUsage `json:"usage"` + ReferencePayload json.RawMessage `json:"reference,omitempty"` +} + +type openAIChatChoice struct { + Index int `json:"index"` + FinishReason string `json:"finish_reason"` + Logprobs interface{} `json:"logprobs"` + Message openAIChatMessage `json:"message"` +} + +type openAIChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` + Reference json.RawMessage `json:"reference,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` +} + +type openAIChatUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + CompletionTokensDetails *struct { + ReasoningTokens int `json:"reasoning_tokens"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens"` + } `json:"completion_tokens_details"` +} + +func (r *OpenAIChatResponse) Type() string { return "openai_chat" } +func (r *OpenAIChatResponse) TimeCost() float64 { return r.Duration } +func (r *OpenAIChatResponse) SetOutputFormat(f OutputFormat) { r.OutputFormat = f } +func (r *OpenAIChatResponse) Raw() []byte { return r.raw } + +func (r *OpenAIChatResponse) SetRaw(b []byte) { r.raw = b } + +func (r *OpenAIChatResponse) Content() string { + if r.Data == nil || len(r.Data.Choices) == 0 { + return "" + } + return r.Data.Choices[0].Message.Content +} + +func (r *OpenAIChatResponse) Model() string { + if r.Data == nil { + return "" + } + return r.Data.Model +} + +func (r *OpenAIChatResponse) Usage() *openAIChatUsage { + if r.Data == nil { + return nil + } + return r.Data.Usage +} + +func (r *OpenAIChatResponse) PrintOut() { + if r.OutputFormat == "raw" && r.raw != nil { + fmt.Println(string(r.raw)) + return + } + if r.Code != 0 { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + return + } + if r.Data == nil { + fmt.Println("(no data)") + return + } + if !r.streamed { + if r.Reasoning != "" { + fmt.Printf("Thinking: %s\n", r.Reasoning) + } + if content := r.Content(); content != "" { + fmt.Printf("Answer: %s\n", content) + } + } + + // Print reference chunks and their document_metadata when available. + // Reference can be on the data-level or on the message-level. + refRaw := r.Data.ReferencePayload + if len(refRaw) == 0 && len(r.Data.Choices) > 0 { + refRaw = r.Data.Choices[0].Message.Reference + } + if len(refRaw) > 0 { + printReferenceChunks(refRaw) + } + + fmt.Printf("Time: %f\n", r.Duration) +} + +// printReferenceChunks parses a reference JSON blob and prints each chunk +// together with its document_metadata (if any). +func printReferenceChunks(raw json.RawMessage) { + var chunks []map[string]interface{} + + // direct array: [...] + if err := json.Unmarshal(raw, &chunks); err != nil { + // object with "chunks" key: {"chunks": [...], "doc_aggs": [...]} + var ref struct { + Chunks []map[string]interface{} `json:"chunks"` + } + if err2 := json.Unmarshal(raw, &ref); err2 != nil || len(ref.Chunks) == 0 { + return + } + chunks = ref.Chunks + } + if len(chunks) == 0 { + return + } + + fmt.Println("Reference:") + for i, chunk := range chunks { + id := chunkID(chunk) + content := chunkContent(chunk) + docName := chunkDocName(chunk) + fmt.Printf(" [ID:%d] id=%s content=%q", i, id, truncateStr(content, 120)) + if docName != "" { + fmt.Printf(" doc=%s", docName) + } + fmt.Println() + + // Print document_metadata if present. + if meta, ok := chunk["document_metadata"].(map[string]interface{}); ok && len(meta) > 0 { + for k, v := range meta { + fmt.Printf(" metadata.%s = %v\n", k, v) + } + } + } +} + +func chunkID(c map[string]interface{}) string { + for _, key := range []string{"chunk_id", "id"} { + if v, ok := c[key]; ok { + return fmt.Sprint(v) + } + } + return "-" +} + +func chunkContent(c map[string]interface{}) string { + if v, ok := c["content"]; ok { + s := fmt.Sprint(v) + return strings.TrimSpace(s) + } + return "" +} + +func chunkDocName(c map[string]interface{}) string { + if v, ok := c["document_name"]; ok { + return fmt.Sprint(v) + } + if v, ok := c["doc_name"]; ok { + return fmt.Sprint(v) + } + return "" +} + +func truncateStr(s string, maxLen int) string { + s = strings.TrimSpace(s) + runes := []rune(s) + if len(runes) <= maxLen { + return s + } + return string(runes[:maxLen]) + "..." +} diff --git a/internal/cli/types.go b/internal/cli/types.go index 1979b6c0a8..c67f385424 100644 --- a/internal/cli/types.go +++ b/internal/cli/types.go @@ -189,6 +189,7 @@ const ( TokenPurge TokenPlan TokenPreview + TokenOpenaiChat TokenLog TokenLevel TokenDebug diff --git a/internal/cli/user_command.go b/internal/cli/user_command.go index 1f6cc1ac9a..639db12e1e 100644 --- a/internal/cli/user_command.go +++ b/internal/cli/user_command.go @@ -32,6 +32,7 @@ import ( "ragflow/internal/ingestion" "ragflow/internal/ingestion/parser" "ragflow/internal/utility" + "regexp" "strings" "time" ) @@ -3644,3 +3645,330 @@ func (c *CLI) ChunkCommand(cmd *Command) (ResponseIf, error) { result.Message = fmt.Sprintf("Success to chunk %s", filename) return &result, nil } + +// OpenaiChat dispatches the parsed OPENAI_CHAT command to either a +// non-streaming oneshot call or a streaming SSE call, depending on the +// `stream` option. +func (c *CLI) OpenaiChat(cmd *Command) (ResponseIf, error) { + if c.Config.CLIMode != APIMode { + return nil, fmt.Errorf("OPENAI_CHAT is only allowed in USER mode") + } + httpClient := c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer] + if httpClient.APIToken == nil && httpClient.LoginToken == nil { + return nil, fmt.Errorf("API token not set. Please login first") + } + + body, err := buildOpenaiChatRequestBody(cmd) + if err != nil { + return nil, err + } + + chatID, _ := cmd.Params["chat_id"].(string) + url := fmt.Sprintf("/openai/%s/chat/completions", chatID) + + stream, _ := cmd.Params["stream"].(bool) + if stream { + return c.streamOpenaiChat(url, body) + } + return c.oneshotOpenaiChat(url, body) +} + +// allowedExtraBodyKeys enumerates every top-level key the server +// accepts under `extra_body`. Anything else is rejected at CLI +// build time so the user gets a clear error before the request +// goes over the wire. +var allowedExtraBodyKeys = map[string]struct{}{ + "reference": {}, + "reference_metadata": {}, + "metadata_condition": {}, +} + +// validateExtraBody checks the shape of an extra_body payload +// supplied by the user. It rejects: +// +// - Unknown top-level keys (typos and unsupported fields). +// - reference_metadata that's not an object, or whose +// sub-fields have the wrong type. +// - metadata_condition that's not an object, or whose +// conditions are missing required fields. +// +// The error message names the offending path so the user can +// fix the JSON literal in their command without having to read +// the server source. +func validateExtraBody(eb map[string]interface{}) error { + for k := range eb { + if _, ok := allowedExtraBodyKeys[k]; !ok { + return fmt.Errorf("OPENAI_CHAT extra_body: unknown field %q (valid: reference, reference_metadata, metadata_condition)", k) + } + } + + // reference_metadata: { include?: bool, fields?: string[] } + if v, present := eb["reference_metadata"]; present { + rm, ok := v.(map[string]interface{}) + if !ok { + return fmt.Errorf("OPENAI_CHAT extra_body.reference_metadata must be an object, got %T", v) + } + if inc, ok := rm["include"]; ok { + if _, ok := inc.(bool); !ok { + return fmt.Errorf("OPENAI_CHAT extra_body.reference_metadata.include must be a boolean, got %T", inc) + } + } + if fields, ok := rm["fields"]; ok { + arr, ok := fields.([]interface{}) + if !ok { + return fmt.Errorf("OPENAI_CHAT extra_body.reference_metadata.fields must be an array, got %T", fields) + } + for i, item := range arr { + if _, ok := item.(string); !ok { + return fmt.Errorf("OPENAI_CHAT extra_body.reference_metadata.fields[%d] must be a string, got %T", i, item) + } + } + } + } + + // metadata_condition: { logic?: "and"|"or", conditions?: [{key, operator, value}, ...] } + if v, present := eb["metadata_condition"]; present { + mc, ok := v.(map[string]interface{}) + if !ok { + return fmt.Errorf("OPENAI_CHAT extra_body.metadata_condition must be an object, got %T", v) + } + if logic, ok := mc["logic"]; ok { + s, ok := logic.(string) + if !ok { + return fmt.Errorf("OPENAI_CHAT extra_body.metadata_condition.logic must be a string, got %T", logic) + } + if s != "and" && s != "or" { + return fmt.Errorf("OPENAI_CHAT extra_body.metadata_condition.logic must be \"and\" or \"or\", got %q", s) + } + } + if conds, ok := mc["conditions"]; ok { + arr, ok := conds.([]interface{}) + if !ok { + return fmt.Errorf("OPENAI_CHAT extra_body.metadata_condition.conditions must be an array, got %T", conds) + } + for i, item := range arr { + cond, ok := item.(map[string]interface{}) + if !ok { + return fmt.Errorf("OPENAI_CHAT extra_body.metadata_condition.conditions[%d] must be an object, got %T", i, item) + } + if _, ok := cond["key"]; !ok { + return fmt.Errorf("OPENAI_CHAT extra_body.metadata_condition.conditions[%d] missing required field 'key'", i) + } + } + } + } + + return nil +} + +// buildOpenaiChatRequestBody assembles the JSON payload that +// /api/v1/openai//chat/completions expects +// +// RAGFlow-specific knobs (e.g. `reference`, `reference_metadata`, +// `metadata_condition`) flow in via the user-supplied `extra_body` +// JSON literal, which is validated against the `allowedExtraBodyKeys` +// allowlist above before the request goes out. `stop` and `user` are +// not first-class CLI options — the Python server does not inspect +// them, and the Go server has dropped them from its request struct; +// the parser rejects them as "unknown option" so there is exactly +// one place to set them. +// +// The `messages` array is built from three optional sources, in +// this order: +// 1. `system` — single system message (if supplied) +// 2. `history` — prior turns encoded as +// "user:...,assistant:..." (if supplied) +// 3. positional — always the trailing user turn +func buildOpenaiChatRequestBody(cmd *Command) (map[string]interface{}, error) { + msg, _ := cmd.Params["message"].(string) + model, _ := cmd.Params["model"].(string) + temp, _ := cmd.Params["temperature"].(float64) + maxTokens, _ := cmd.Params["max_tokens"].(int) + stream, _ := cmd.Params["stream"].(bool) + + messages := make([]map[string]interface{}, 0, 4) + if v, ok := cmd.Params["system"].(string); ok && v != "" { + messages = append(messages, map[string]interface{}{"role": "system", "content": v}) + } + if v, ok := cmd.Params["history_raw"].(string); ok && v != "" { + delimiter, _ := cmd.Params["history_delimiter"].(string) + turns, err := parseHistory(v, delimiter) + if err != nil { + return nil, fmt.Errorf("OPENAI_CHAT history: %w", err) + } + for _, t := range turns { + messages = append(messages, map[string]interface{}{ + "role": t["role"], + "content": t["content"], + }) + } + } + messages = append(messages, map[string]interface{}{"role": "user", "content": msg}) + + body := map[string]interface{}{ + "model": model, + "messages": messages, + "stream": stream, + } + // Only emit generation params when the user actually set them + // (zero is the parser-default for "unset" and matches Python's + // behavior of dropping the field). + if temp != 0.0 { + body["temperature"] = temp + } + if maxTokens != 0 { + body["max_tokens"] = maxTokens + } + if v, ok := cmd.Params["top_p"].(float64); ok && v != 0.0 { + body["top_p"] = v + } + if v, ok := cmd.Params["frequency_penalty"].(float64); ok && v != 0.0 { + body["frequency_penalty"] = v + } + if v, ok := cmd.Params["presence_penalty"].(float64); ok && v != 0.0 { + body["presence_penalty"] = v + } + + var extraBody map[string]interface{} + if v, ok := cmd.Params["extra_body"].(string); ok && v != "" { + if err := json.Unmarshal([]byte(v), &extraBody); err != nil { + return nil, fmt.Errorf("OPENAI_CHAT extra_body: invalid JSON: %w", err) + } + } + // Validate the user's extra_body against the server's accepted + // schema before the request goes over the wire. + if err := validateExtraBody(extraBody); err != nil { + return nil, err + } + if len(extraBody) > 0 { + body["extra_body"] = extraBody + } + + return body, nil +} + +// oneshotOpenaiChat performs a non-streaming POST and returns an +// OpenAIChatResponse parsed from the JSON envelope. It calls the +// same HTTPClient.Request used by every other CLI command. +func (c *CLI) oneshotOpenaiChat(url string, body map[string]interface{}) (ResponseIf, error) { + httpClient := c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer] + resp, err := httpClient.Request("POST", url, "web", nil, body) + if err != nil { + return nil, fmt.Errorf("openai_chat request: %w", err) + } + if resp.StatusCode != 200 { + // Python wraps errors as `{"code":..., "message":...}`. Surface + // the body verbatim so the user can read the upstream error. + return &OpenAIChatResponse{ + Code: resp.StatusCode, + Message: string(resp.Body), + raw: resp.Body, + }, nil + } + out := &OpenAIChatResponse{ + Duration: resp.Duration, + raw: resp.Body, + } + var wrapped struct { + Code int `json:"code"` + Message string `json:"message"` + Data *openAIChatData `json:"data"` + } + if err := json.Unmarshal(resp.Body, &wrapped); err == nil && wrapped.Data != nil { + out.Code = wrapped.Code + out.Message = wrapped.Message + out.Data = wrapped.Data + if len(wrapped.Data.Choices) > 0 { + out.Reasoning = wrapped.Data.Choices[0].Message.ReasoningContent + } + return out, nil + } + // Unwrapped (Go handler) shape. + if err := json.Unmarshal(resp.Body, &out.Data); err != nil { + return nil, fmt.Errorf("openai_chat: invalid response JSON: %w", err) + } + if out.Data != nil && len(out.Data.Choices) > 0 { + out.Reasoning = out.Data.Choices[0].Message.ReasoningContent + } + return out, nil +} + +// streamOpenaiChat performs a streaming POST and prints SSE chunks to +// stdout as they arrive +func (c *CLI) streamOpenaiChat(url string, body map[string]interface{}) (ResponseIf, error) { + httpClient := c.APIServerClientMap[c.Config.APIClientConfig.CurrentAPIServer] + resp, err := httpClient.Request("POST", url, "web", nil, body) + if err != nil { + return nil, fmt.Errorf("openai_chat stream: %w", err) + } + if resp.StatusCode != 200 { + return &OpenAIChatResponse{ + Code: resp.StatusCode, + Message: string(resp.Body), + Duration: resp.Duration, + raw: resp.Body, + }, nil + } + full := string(resp.Body) + var ( + fullContent string + fullReason string + resolvedMod string + ) + for _, line := range strings.Split(full, "\n") { + line = strings.TrimSpace(line) + if !strings.HasPrefix(line, "data:") { + continue + } + payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if payload == "" || payload == "[DONE]" { + continue + } + var chunk struct { + Model string `json:"model"` + Choices []struct { + Delta struct { + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content"` + } `json:"delta"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage *openAIChatUsage `json:"usage"` + } + if err := json.Unmarshal([]byte(payload), &chunk); err != nil { + continue + } + if chunk.Model != "" { + resolvedMod = chunk.Model + } + if len(chunk.Choices) > 0 { + if d := chunk.Choices[0].Delta.Content; d != "" { + fullContent += d + } + if r := chunk.Choices[0].Delta.ReasoningContent; r != "" { + fullReason += r + } + } + } + + fullContent = strings.TrimLeft(fullContent, "\n\r") + fullReason = strings.TrimLeft(fullReason, "\n\r") + fullContent = stripThinkTags(fullContent) + fullReason = stripThinkTags(fullReason) + return &OpenAIChatResponse{ + Duration: resp.Duration, + Reasoning: fullReason, + Data: &openAIChatData{ + Model: resolvedMod, + Choices: []openAIChatChoice{{Message: openAIChatMessage{Content: fullContent, ReasoningContent: fullReason}}}, + }, + streamed: true, + raw: resp.Body, + }, nil +} + +// stripThinkTags removes wrappers from a streamed answer +func stripThinkTags(s string) string { + var thinkTagRE = regexp.MustCompile(`(?s).*?`) + return thinkTagRE.ReplaceAllString(s, "") +} diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go index ffa5e7f27d..915857b541 100644 --- a/internal/cli/user_parser.go +++ b/internal/cli/user_parser.go @@ -1,8 +1,10 @@ package cli import ( + "encoding/json" "fmt" "ragflow/internal/common" + "regexp" "strconv" "strings" ) @@ -2799,6 +2801,16 @@ func (p *Parser) parseRetrieveCommand() (*Command, error) { p.nextToken() } continue + } else if p.curToken.Type == TokenIllegal && p.curToken.Value == "." { + if cmd.Params["path"] == nil { + cmd.Params["path"] = "." + } else { + cmd.Params["path"] = fmt.Sprintf("%s.", cmd.Params["path"]) + } + p.nextToken() + continue + } else { + return nil, fmt.Errorf("unexpected token %q in search path", p.curToken.Value) } } return cmd, nil @@ -4750,3 +4762,252 @@ func (p *Parser) parseChunkCommand(explain bool) (*Command, error) { return cmd, nil } + +// parseOpenaiChatCommand parses: +// +// OPENAI_CHAT +// [model ] [system ] +// [history ] [history_delimiter ] +// [temperature ] [max_tokens ] [stream ] +// [top_p ] [frequency_penalty ] [presence_penalty ] +// [extra_body ] +// ; +// +// Named options can appear in any order. The chat_id and message are +// required positional args; everything else is optional with a default. +// +// `history` is captured as a single string in cmd.Params["history_raw"] +// and is split into turns by cmd.Params["history_delimiter"] (default +// ";") later in buildOpenaiChatRequestBody — this two-step split lets +// `history_delimiter` and `history` appear in either order on the +// command line. The chosen delimiter must not appear inside any +// message body. +// +// `extra_body` is well-formed JSON. The accepted keys are: +// +// reference bool +// reference_metadata { include?: bool, fields?: string[] } +// metadata_condition { logic?: "and"|"or", conditions?: [{key, operator, value}] } +// (See user_command.go:allowedExtraBodyKeys for the authoritative set) +func (p *Parser) parseOpenaiChatCommand() (*Command, error) { + p.nextToken() // consume OPENAI_CHAT + + if p.curToken.Type == TokenDash { + dashCount := 0 + for p.curToken.Type == TokenDash { + dashCount++ + p.nextToken() + } + if dashCount > 0 && p.curToken.Type == TokenIdentifier { + switch strings.ToLower(p.curToken.Value) { + case "h", "help": + return NewCommand("openai_chat_help"), nil + } + } + return nil, fmt.Errorf("OPENAI_CHAT: only -h/--help takes no args; otherwise expected chat_id and message") + } + + cmd := NewCommand("openai_chat") + + // Defaults — match the OpenAI spec / RAGFlow server behavior. + cmd.Params["model"] = "model" // placeholder; server resolves to dialog.llm_id + cmd.Params["temperature"] = 0.0 + cmd.Params["max_tokens"] = 0 + cmd.Params["stream"] = false + + // Required positional: + chatID, err := p.parseQuotedString() + if err != nil { + return nil, fmt.Errorf("OPENAI_CHAT: expected chat_id as first argument: %w", err) + } + cmd.Params["chat_id"] = chatID + p.nextToken() + + message, err := p.parseQuotedString() + if err != nil { + return nil, fmt.Errorf("OPENAI_CHAT: expected message as second argument: %w", err) + } + cmd.Params["message"] = message + p.nextToken() + + // Optional + handleOption := func(name string) error { + switch name { + case "model", "system": + v, err := p.parseQuotedString() + if err != nil { + return fmt.Errorf("OPENAI_CHAT %s: expected quoted string, got %s", name, p.curToken.Value) + } + cmd.Params[name] = v + p.nextToken() + case "temperature", "top_p", "frequency_penalty", "presence_penalty": + v, err := p.parseFloat() + if err != nil { + return fmt.Errorf("OPENAI_CHAT %s: expected number, got %s", name, p.curToken.Value) + } + cmd.Params[name] = v + p.nextToken() + case "max_tokens": + v, err := p.parseNumber() + if err != nil { + return fmt.Errorf("OPENAI_CHAT max_tokens: expected integer, got %s", p.curToken.Value) + } + cmd.Params["max_tokens"] = v + p.nextToken() + case "stream": + v, err := p.parseBool() + if err != nil { + return fmt.Errorf("OPENAI_CHAT %s: expected true|false, got %s", name, p.curToken.Value) + } + cmd.Params[name] = v + // parseBool already advances the cursor. + case "extra_body": + raw, err := p.parseJSONLiteral() + if err != nil { + return fmt.Errorf("OPENAI_CHAT %s: %w", name, err) + } + cmd.Params[name] = raw + p.nextToken() + case "history": + raw, err := p.parseQuotedString() + if err != nil { + return fmt.Errorf("OPENAI_CHAT history: expected quoted string, got %s", p.curToken.Value) + } + cmd.Params["history_raw"] = raw + p.nextToken() + case "history_delimiter": + v, err := p.parseQuotedString() + if err != nil { + return fmt.Errorf("OPENAI_CHAT history_delimiter: expected quoted string, got %s", p.curToken.Value) + } + cmd.Params["history_delimiter"] = v + p.nextToken() + default: + return fmt.Errorf("OPENAI_CHAT: unknown option %q (valid: model, system, history, history_delimiter, temperature, max_tokens, stream, top_p, frequency_penalty, presence_penalty, extra_body)", name) + } + return nil + } + + // Named options, any order, until ';'. +optionsLoop: + for { + switch p.curToken.Type { + case TokenSemicolon: + p.nextToken() + break optionsLoop + case TokenEOF: + break optionsLoop + + case TokenIdentifier, TokenQuotedString: + name := p.curToken.Value + if p.curToken.Type == TokenQuotedString { + name = strings.Trim(name, "'\"") + } + p.nextToken() + if err := handleOption(name); err != nil { + return nil, err + } + + default: + if !isKeyword(p.curToken.Type) { + return nil, fmt.Errorf("OPENAI_CHAT: unexpected token %q in option list (valid options: model, system, history, history_delimiter, temperature, max_tokens, stream, top_p, frequency_penalty, presence_penalty, extra_body)", p.curToken.Value) + } + name := p.curToken.Value + p.nextToken() + if err := handleOption(name); err != nil { + return nil, err + } + } + } + + return cmd, nil +} + +// parseJSONLiteral consumes a TokenQuotedString whose payload is a JSON +// value (object, array, string, number, or boolean) and returns it as +// the original raw string (NOT decoded — the caller decides whether to +// embed it into a larger JSON object or pass it through as-is). +func (p *Parser) parseJSONLiteral() (string, error) { + if p.curToken.Type != TokenQuotedString { + return "", fmt.Errorf("expected JSON literal in single/double quotes, got %s", p.curToken.Value) + } + raw := p.curToken.Value + // Validate it actually parses as JSON so we fail fast on + // typos like `'{}' extra comma'` or `'not json'`. + var probe interface{} + if err := json.Unmarshal([]byte(raw), &probe); err != nil { + return "", fmt.Errorf("invalid JSON literal %q: %w", raw, err) + } + return raw, nil +} + +// parseBool accepts a TokenIdentifier "true"/"false" +func (p *Parser) parseBool() (bool, error) { + switch strings.ToLower(p.curToken.Value) { + case "true": + p.nextToken() + return true, nil + case "false": + p.nextToken() + return false, nil + } + return false, fmt.Errorf("expected true or false, got %q", p.curToken.Value) +} + +// historyRoleRegex matches the role prefix on a turn. The captured +// alternation is the role; the colon is required so we don't +// accidentally split on a word like "user:foo" appearing inside +// other content. +var historyRoleRegex = regexp.MustCompile(`(?i)^(user|assistant):`) + +// defaultHistoryDelimiter is the turn separator used when the +// caller does not pass the `history_delimiter` option. +const defaultHistoryDelimiter = ";" + +// parseHistory splits the history literal into a slice of +// {"role": ..., "content": ...} maps. Format: +// +// "user:question one;assistant:answer one;user:question two" +// +// Turns are separated by `history_delimiter` (default `;`). Each +// segment must start with the role prefix `user:` or `assistant:` +// (case-insensitive). +func parseHistory(literal, delimiter string) ([]map[string]string, error) { + if delimiter == "" { + delimiter = defaultHistoryDelimiter + } + + // Trim a single pair of surrounding quotes if present. + s := strings.TrimSpace(literal) + if len(s) >= 2 { + first, last := s[0], s[len(s)-1] + if (first == '"' && last == '"') || (first == '\'' && last == '\'') { + s = s[1 : len(s)-1] + } + } + + raw := strings.Split(s, delimiter) + turns := make([]map[string]string, 0, len(raw)) + for _, segment := range raw { + segment = strings.TrimSpace(segment) + if segment == "" { + continue + } + m := historyRoleRegex.FindStringSubmatch(segment) + if m == nil { + return nil, fmt.Errorf("history segment %q must start with 'user:' or 'assistant:'", segment) + } + role := strings.ToLower(m[1]) + // Drop the ":" prefix (m[0] is the whole match, e.g. + // "user:"; we want the content AFTER the colon). + content := strings.TrimPrefix(segment, m[0]) + turns = append(turns, map[string]string{ + "role": role, + "content": content, + }) + } + if len(turns) == 0 { + return nil, fmt.Errorf("history is empty or unparseable: %q", literal) + } + return turns, nil +} diff --git a/internal/common/metadata_utils.go b/internal/common/metadata_utils.go index 24cad1bfc1..1786e506f2 100644 --- a/internal/common/metadata_utils.go +++ b/internal/common/metadata_utils.go @@ -49,7 +49,7 @@ var operatorMapping = map[string]string{ ">=": "≥", "<=": "≤", "!=": "≠", - "==": "=", + "==": "=", } // ParseAndConvert converts raw API conditions into MetaFilterInput. @@ -76,10 +76,16 @@ func ParseAndConvert(metadataCondition map[string]interface{}) *MetaFilterInput continue } name, _ := cond["name"].(string) + if name == "" { + name, _ = cond["key"].(string) // OpenAI API metadata_condition uses "key" + } if name == "" { continue } op, _ := cond["comparison_operator"].(string) + if op == "" { + op, _ = cond["operator"].(string) // OpenAI API uses "operator" + } op = convertOperator(op) conditions = append(conditions, MetaCondition{ Operator: op, diff --git a/internal/common/multimodal.go b/internal/common/multimodal.go new file mode 100644 index 0000000000..74d3efa7db --- /dev/null +++ b/internal/common/multimodal.go @@ -0,0 +1,353 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package common + +import ( + "encoding/base64" + "fmt" + "regexp" + "strings" +) + +// ContentPart is the internal representation of a multimodal content +// fragment, decoupled from any provider's wire format. Drivers consume +// the result of RenderContentPartsForFactory to produce their per- +// provider JSON. +type ContentPart struct { + // Type is one of: "text", "image_url", "image", "inline_data". + Type string + // Text is set when Type == "text". + Text string + // ImageURL is set when Type == "image_url" (OpenAI shape). + ImageURL *ImageURL + // Source is set when Type == "image" (Anthropic) or + // Type == "inline_data" (Gemini). + Source *ContentSource +} + +// ImageURL is the OpenAI-shaped image reference. +type ImageURL struct { + URL string `json:"url"` +} + +// ContentSource is the Anthropic / Gemini source payload. +type ContentSource struct { + Type string `json:"type"` // "base64" or "url" + MediaType string `json:"media_type"` // e.g. "image/png" + Data string `json:"data,omitempty"` // base64 payload + URL string `json:"url,omitempty"` +} + +// dataURIRE detects a "data:;base64," string. +var dataURIRE = regexp.MustCompile(`^data:([^;,]+)(?:;base64)?,(.*)$`) + +// parseDataURIOrB64 accepts a string and classifies it as a data URI, +// a plain https URL, or a raw base64 payload +func parseDataURIOrB64(s string) (ContentSource, error) { + if s == "" { + return ContentSource{}, fmt.Errorf("empty image source") + } + if m := dataURIRE.FindStringSubmatch(s); m != nil { + mediaType := strings.TrimSpace(m[1]) + if mediaType == "" { + mediaType = "image/png" + } + return ContentSource{ + Type: "base64", + MediaType: mediaType, + Data: m[2], + }, nil + } + if strings.HasPrefix(s, "http://") || strings.HasPrefix(s, "https://") { + return ContentSource{Type: "url", URL: s}, nil + } + // Assume raw base64 (no data URI, no http scheme). The provider + // uses the file extension or a content-type hint from the call site + // to pick the right media type; we default to image/png. + if _, err := base64.StdEncoding.DecodeString(s); err != nil { + return ContentSource{}, fmt.Errorf("not a valid data URI, URL, or base64: %w", err) + } + return ContentSource{ + Type: "base64", + MediaType: "image/png", + Data: s, + }, nil +} + +// normalizeTextFromContent extracts a single text string from a content +// value that may be a string, []map[string]interface{}, or []interface{}. +func normalizeTextFromContent(content interface{}) string { + switch v := content.(type) { + case string: + return v + case []map[string]interface{}: + var parts []string + for _, p := range v { + if t, ok := p["type"].(string); ok && (t == "text" || t == "input_text") { + if txt, ok := p["text"].(string); ok { + parts = append(parts, txt) + } + } else if txt, ok := p["text"]; ok { + // Fallback: "text" key present even though type didn't match. + switch tv := txt.(type) { + case string: + parts = append(parts, tv) + case float64: + parts = append(parts, fmt.Sprintf("%v", tv)) + case int: + parts = append(parts, fmt.Sprintf("%v", tv)) + } + } + } + return strings.Join(parts, "\n") + case []interface{}: + var parts []string + for _, item := range v { + switch p := item.(type) { + case map[string]interface{}: + if t, ok := p["type"].(string); ok && (t == "text" || t == "input_text") { + if txt, ok := p["text"].(string); ok { + parts = append(parts, txt) + } + } else if txt, ok := p["text"]; ok { + // Fallback: "text" key present even though type didn't match. + switch tv := txt.(type) { + case string: + parts = append(parts, tv) + case float64: + parts = append(parts, fmt.Sprintf("%v", tv)) + case int: + parts = append(parts, fmt.Sprintf("%v", tv)) + } + } + case string: + parts = append(parts, p) + } + } + return strings.Join(parts, "\n") + } + return "" +} + +// extractImageURLs pulls image_url values out of a content value. Used +// by ConvertLastUserMsgToMultimodal to assemble the ContentPart slice. +func extractImageURLs(content interface{}) []string { + var urls []string + process := func(p map[string]interface{}) { + t, _ := p["type"].(string) + if t == "image_url" { + if u, ok := p["image_url"].(string); ok && u != "" { + urls = append(urls, u) + } else if obj, ok := p["image_url"].(map[string]interface{}); ok { + if u, ok := obj["url"].(string); ok && u != "" { + urls = append(urls, u) + } + } + } + } + switch v := content.(type) { + case []map[string]interface{}: + for _, p := range v { + process(p) + } + case []interface{}: + for _, item := range v { + if p, ok := item.(map[string]interface{}); ok { + process(p) + } + } + } + return urls +} + +// ConvertLastUserMsgToMultimodal converts a user message whose content +// is a multimodal parts array into a message whose content is a +// driver-ready content-parts value, dispatched by `factory` (provider +// name). +// +// `imageAttachments` is an additional list of image URLs from the +// `messages[-1]["files"]` array. +// When non-empty, each URL is added to the content as an image +// part regardless of the original message content. +// +// factory values supported: +// - "gemini" → {"text": ...} / {"inline_data": {...}} +// - "anthropic" → {"type": "text", ...} / {"type": "image", "source": {...}} +// - default → {"type": "text", ...} / {"type": "image_url", "image_url": {...}} +// +// If the message is already a string, it is returned unchanged. +// If the message has no image parts and `imageAttachments` is empty, +// the text is returned as a string for compatibility with providers +// that don't accept content arrays. +func ConvertLastUserMsgToMultimodal(msg map[string]interface{}, imageAttachments []string, factory string) (map[string]interface{}, error) { + if msg == nil { + return nil, fmt.Errorf("nil message") + } + originalContent, ok := msg["content"] + if !ok { + return msg, nil + } + // If the content is already a plain string and there are no + // imageAttachments to add, leave it alone. + if _, isString := originalContent.(string); isString && len(imageAttachments) == 0 { + return msg, nil + } + + // Combine images from the content array and from imageAttachments + // (the `files` array on the last user message). + // Order: content-array images first, then files-array images. + textPart := normalizeTextFromContent(originalContent) + imageURLs := extractImageURLs(originalContent) + allImageURLs := append(imageURLs, imageAttachments...) + if len(allImageURLs) == 0 { + // No images — collapse to a string for compatibility. + out := make(map[string]interface{}, len(msg)) + for k, v := range msg { + out[k] = v + } + out["content"] = textPart + return out, nil + } + + // Build ContentPart slice. + parts := make([]ContentPart, 0, 1+len(allImageURLs)) + if textPart != "" { + parts = append(parts, ContentPart{Type: "text", Text: textPart}) + } + for _, u := range allImageURLs { + src, err := parseDataURIOrB64(u) + if err != nil { + return nil, fmt.Errorf("image_url %q: %w", u, err) + } + // OpenAI / default: pass the raw URL through (provider accepts + // both data: and http(s):). Anthropic / Gemini need a Source. + if factory == "anthropic" || factory == "gemini" { + parts = append(parts, ContentPart{ + Type: pickImageType(factory), + Source: &src, + }) + } else { + parts = append(parts, ContentPart{ + Type: "image_url", + ImageURL: &ImageURL{URL: u}, + }) + } + } + + // Render to the driver's wire format. + rendered, err := RenderContentPartsForFactory(parts, factory) + if err != nil { + return nil, err + } + out := make(map[string]interface{}, len(msg)) + for k, v := range msg { + out[k] = v + } + out["content"] = rendered + return out, nil +} + +func pickImageType(factory string) string { + if factory == "gemini" { + return "inline_data" + } + return "image" +} + +// RenderContentPartsForFactory converts internal ContentPart values +// into the per-provider JSON wire format: +// +// - gemini: [{"text": ...}, {"inline_data": {"mime_type": ..., "data": ...}}] +// - anthropic: [{"type": "text", "text": ...}, {"type": "image", "source": {...}}] +// - default: [{"type": "text", "text": ...}, {"type": "image_url", "image_url": {"url": ...}}] +// +// The return value is suitable for direct assignment to a Message's +// `Content` field (`interface{}`). +func RenderContentPartsForFactory(parts []ContentPart, factory string) (interface{}, error) { + factory = strings.ToLower(factory) + switch factory { + case "gemini": + out := make([]map[string]interface{}, 0, len(parts)) + for _, p := range parts { + switch p.Type { + case "text": + out = append(out, map[string]interface{}{"text": p.Text}) + case "image", "inline_data": + if p.Source == nil { + return nil, fmt.Errorf("gemini image part missing source") + } + if p.Source.Type == "url" { + out = append(out, map[string]interface{}{ + "file_data": map[string]interface{}{ + "file_uri": p.Source.URL, + "mime_type": p.Source.MediaType, + }, + }) + } else { + out = append(out, map[string]interface{}{ + "inline_data": map[string]interface{}{ + "mime_type": p.Source.MediaType, + "data": p.Source.Data, + }, + }) + } + } + } + return out, nil + case "anthropic": + out := make([]map[string]interface{}, 0, len(parts)) + for _, p := range parts { + switch p.Type { + case "text": + out = append(out, map[string]interface{}{ + "type": "text", + "text": p.Text, + }) + case "image": + if p.Source == nil { + return nil, fmt.Errorf("anthropic image part missing source") + } + out = append(out, map[string]interface{}{ + "type": "image", + "source": p.Source, + }) + } + } + return out, nil + default: + // OpenAI-compatible. + out := make([]map[string]interface{}, 0, len(parts)) + for _, p := range parts { + switch p.Type { + case "text": + out = append(out, map[string]interface{}{ + "type": "text", + "text": p.Text, + }) + case "image_url": + if p.ImageURL == nil { + return nil, fmt.Errorf("openai image_url part missing URL") + } + out = append(out, map[string]interface{}{ + "type": "image_url", + "image_url": p.ImageURL, + }) + } + } + return out, nil + } +} diff --git a/internal/common/timer.go b/internal/common/timer.go new file mode 100644 index 0000000000..ae8ae14f62 --- /dev/null +++ b/internal/common/timer.go @@ -0,0 +1,169 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package common + +import ( + "encoding/json" + "fmt" + "strings" + "sync" + "time" +) + +// Phase is a named timing bucket in the RAG pipeline +type Phase string + +const ( + PhaseCheckLLM Phase = "check_llm" + PhaseCheckLangfuse Phase = "check_langfuse" + PhaseBindModels Phase = "bind_models" + PhaseQueryRefinement Phase = "query_refinement" + PhaseRetrieval Phase = "retrieval" + PhaseGenerateAnswer Phase = "generate_answer" +) + +// allPhases ordered for Markdown() display. +var allPhases = []Phase{ + PhaseCheckLLM, + PhaseCheckLangfuse, + PhaseBindModels, + PhaseQueryRefinement, + PhaseRetrieval, + PhaseGenerateAnswer, +} + +// Timer tracks elapsed wall-clock time per named Phase. +// Supports reentrant Enter/Exit on the same phase (inner span's duration +// adds to the outer span's accumulated total). +type Timer struct { + mu sync.Mutex + start time.Time + phases map[Phase]time.Duration + entries map[Phase][]time.Time +} + +// NewTimer constructs a Timer. +func NewTimer() *Timer { + return &Timer{ + phases: make(map[Phase]time.Duration, len(allPhases)), + entries: make(map[Phase][]time.Time, len(allPhases)), + } +} + +// Start anchors the timer. Calling Start() twice resets all state. +func (t *Timer) Start() { + t.mu.Lock() + defer t.mu.Unlock() + t.start = time.Now() + t.phases = make(map[Phase]time.Duration, len(allPhases)) + t.entries = make(map[Phase][]time.Time, len(allPhases)) +} + +// Enter marks the start of phase p. Reentrant calls push a new anchor. +func (t *Timer) Enter(p Phase) { + t.mu.Lock() + defer t.mu.Unlock() + t.entries[p] = append(t.entries[p], time.Now()) +} + +// Exit records the duration since the most recent Enter(p). No-op if no Enter. +func (t *Timer) Exit(p Phase) { + t.mu.Lock() + defer t.mu.Unlock() + stack := t.entries[p] + if len(stack) == 0 { + return + } + open := stack[len(stack)-1] + t.entries[p] = stack[:len(stack)-1] + t.phases[p] += time.Since(open) +} + +// Phase returns the accumulated duration for phase p. +func (t *Timer) Phase(p Phase) time.Duration { + t.mu.Lock() + defer t.mu.Unlock() + return t.phases[p] +} + +// Total returns the elapsed time since Start(). +func (t *Timer) Total() time.Duration { + t.mu.Lock() + defer t.mu.Unlock() + if t.start.IsZero() { + return 0 + } + return time.Since(t.start) +} + +// PhaseReport is the JSON-serializable view of a Timer's state. +type PhaseReport struct { + PhasesMs map[string]float64 `json:"phases_ms"` + TotalMs float64 `json:"total_ms"` +} + +// Report returns a JSON-marshalable snapshot with microsecond precision. +func (t *Timer) Report() *PhaseReport { + t.mu.Lock() + defer t.mu.Unlock() + phases := make(map[string]float64, len(allPhases)) + for _, p := range allPhases { + phases[string(p)] = float64(t.phases[p].Microseconds()) / 1000.0 + } + var totalMs float64 + if !t.start.IsZero() { + totalMs = float64(time.Since(t.start).Microseconds()) / 1000.0 + } + return &PhaseReport{PhasesMs: phases, TotalMs: totalMs} +} + +func (t *Timer) MarshalJSON() ([]byte, error) { + return json.Marshal(t.Report()) +} + +// Markdown renders the Timer as a "## Time elapsed:" block matching +func (t *Timer) Markdown() string { + r := t.Report() + var b strings.Builder + b.WriteString("\n## Time elapsed:\n") + b.WriteString(fmt.Sprintf(" - Total: %.1fms\n", r.TotalMs)) + for _, p := range allPhases { + ms := r.PhasesMs[string(p)] + b.WriteString(fmt.Sprintf(" - %s: %.1fms\n", displayName(p), ms)) + } + b.WriteString("\n") + return b.String() +} + +func displayName(p Phase) string { + switch p { + case PhaseCheckLLM: + return "Check LLM" + case PhaseCheckLangfuse: + return "Check Langfuse tracer" + case PhaseBindModels: + return "Bind models" + case PhaseQueryRefinement: + return "Query refinement(LLM)" + case PhaseRetrieval: + return "Retrieval" + case PhaseGenerateAnswer: + return "Generate answer" + default: + return string(p) + } +} diff --git a/internal/common/timer_test.go b/internal/common/timer_test.go new file mode 100644 index 0000000000..0424122325 --- /dev/null +++ b/internal/common/timer_test.go @@ -0,0 +1,224 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package common + +import ( + "encoding/json" + "regexp" + "strings" + "sync" + "testing" + "time" +) + +func TestTimer_BasicSequentialPhases(t *testing.T) { + tm := NewTimer() + tm.Start() + + tm.Enter(PhaseCheckLLM) + time.Sleep(5 * time.Millisecond) + tm.Exit(PhaseCheckLLM) + + tm.Enter(PhaseBindModels) + time.Sleep(3 * time.Millisecond) + tm.Exit(PhaseBindModels) + + got := tm.Phase(PhaseCheckLLM) + if got < 4*time.Millisecond || got > 50*time.Millisecond { + t.Errorf("PhaseCheckLLM = %v, want ~5ms", got) + } + got = tm.Phase(PhaseBindModels) + if got < 2*time.Millisecond || got > 50*time.Millisecond { + t.Errorf("PhaseBindModels = %v, want ~3ms", got) + } + + // Untouched phase should be 0. + if d := tm.Phase(PhaseRetrieval); d != 0 { + t.Errorf("PhaseRetrieval = %v, want 0", d) + } + + total := tm.Total() + if total < 7*time.Millisecond { + t.Errorf("Total = %v, want >= 7ms", total) + } +} + +func TestTimer_NestedPhasesAddUp(t *testing.T) { + tm := NewTimer() + tm.Start() + + tm.Enter(PhaseQueryRefinement) // outer + time.Sleep(2 * time.Millisecond) + tm.Enter(PhaseGenerateAnswer) // inner (LLM call inside pre-retrieval) + time.Sleep(3 * time.Millisecond) + tm.Exit(PhaseGenerateAnswer) + time.Sleep(1 * time.Millisecond) + tm.Exit(PhaseQueryRefinement) + + // Generate answer records the inner 3ms. + got := tm.Phase(PhaseGenerateAnswer) + if got < 2*time.Millisecond || got > 50*time.Millisecond { + t.Errorf("PhaseGenerateAnswer = %v, want ~3ms", got) + } + // Pre-retrieval processing records the WHOLE outer span (2 + 3 + 1 ≈ 6ms). + got = tm.Phase(PhaseQueryRefinement) + if got < 5*time.Millisecond || got > 50*time.Millisecond { + t.Errorf("PhaseQueryRefinement = %v, want ~6ms (outer span)", got) + } +} + +func TestTimer_ExitWithoutEnterIsNoop(t *testing.T) { + tm := NewTimer() + tm.Start() + // Should not panic, should not record anything. + tm.Exit(PhaseRetrieval) + if d := tm.Phase(PhaseRetrieval); d != 0 { + t.Errorf("PhaseRetrieval = %v, want 0", d) + } +} + +func TestTimer_StartResetsState(t *testing.T) { + tm := NewTimer() + tm.Start() + tm.Enter(PhaseCheckLLM) + time.Sleep(2 * time.Millisecond) + tm.Exit(PhaseCheckLLM) + if tm.Phase(PhaseCheckLLM) == 0 { + t.Fatal("precondition: phase must be non-zero before reset") + } + tm.Start() + if d := tm.Phase(PhaseCheckLLM); d != 0 { + t.Errorf("after Start, PhaseCheckLLM = %v, want 0", d) + } + if total := tm.Total(); total > 50*time.Millisecond { + t.Errorf("after Start, Total = %v, want tiny", total) + } +} + +func TestTimer_ConcurrentAccess(t *testing.T) { + tm := NewTimer() + tm.Start() + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + tm.Enter(PhaseRetrieval) + time.Sleep(time.Millisecond) + tm.Exit(PhaseRetrieval) + }() + } + wg.Wait() + got := tm.Phase(PhaseRetrieval) + if got < 9*time.Millisecond { + t.Errorf("PhaseRetrieval = %v, want ~10ms (10 parallel spans)", got) + } +} + +func TestTimer_Report(t *testing.T) { + tm := NewTimer() + tm.Start() + tm.Enter(PhaseCheckLLM) + time.Sleep(2 * time.Millisecond) + tm.Exit(PhaseCheckLLM) + tm.Enter(PhaseBindModels) + time.Sleep(1 * time.Millisecond) + tm.Exit(PhaseBindModels) + + r := tm.Report() + // Required fields + if _, ok := r.PhasesMs[string(PhaseCheckLLM)]; !ok { + t.Errorf("Report missing PhaseCheckLLM: %+v", r.PhasesMs) + } + if _, ok := r.PhasesMs[string(PhaseBindModels)]; !ok { + t.Errorf("Report missing PhaseBindModels: %+v", r.PhasesMs) + } + if _, ok := r.PhasesMs[string(PhaseGenerateAnswer)]; !ok { + t.Errorf("Report missing PhaseGenerateAnswer: %+v", r.PhasesMs) + } + if r.PhasesMs[string(PhaseCheckLLM)] < 1.0 { + t.Errorf("Report PhaseCheckLLM_ms = %v, want >= 1.0", r.PhasesMs[string(PhaseCheckLLM)]) + } + if r.TotalMs < 2.0 { + t.Errorf("Report TotalMs = %v, want >= 2.0", r.TotalMs) + } + + // JSON round-trip + b, err := json.Marshal(r) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if !strings.Contains(string(b), `"phases_ms"`) || !strings.Contains(string(b), `"total_ms"`) { + t.Errorf("JSON missing expected keys: %s", b) + } + + // Direct Marshal of the Timer + b2, err := json.Marshal(tm) + if err != nil { + t.Fatalf("Marshal(Timer) failed: %v", err) + } + if !strings.Contains(string(b2), `"phases_ms"`) { + t.Errorf("Timer JSON missing phases_ms: %s", b2) + } +} + +func TestTimer_Markdown(t *testing.T) { + tm := NewTimer() + tm.Start() + tm.Enter(PhaseCheckLLM) + time.Sleep(2 * time.Millisecond) + tm.Exit(PhaseCheckLLM) + tm.Enter(PhaseRetrieval) + time.Sleep(5 * time.Millisecond) + tm.Exit(PhaseRetrieval) + tm.Enter(PhaseGenerateAnswer) + time.Sleep(50 * time.Millisecond) + tm.Exit(PhaseGenerateAnswer) + + md := tm.Markdown() + + // Should start with newline + "## Time elapsed:" header + if !strings.HasPrefix(md, "\n## Time elapsed:") { + t.Errorf("Markdown missing header: %q", md) + } + // Should contain all 6 phase labels + for _, label := range []string{"Check LLM", "Check Langfuse tracer", "Bind models", "Query refinement(LLM)", "Retrieval", "Generate answer", "Total"} { + if !strings.Contains(md, label+":") { + t.Errorf("Markdown missing label %q: %q", label, md) + } + } + // Phase durations should be numeric with "ms" suffix. + mdRE := regexp.MustCompile(`(?m)^\s*-\s+([A-Za-z ()\.]+):\s+([0-9.]+)ms$`) + matches := mdRE.FindAllStringSubmatch(md, -1) + if len(matches) < 7 { + t.Errorf("expected 7 phase lines, found %d in:\n%s", len(matches), md) + } + // Total should be the sum-ish of the three measured phases. + totalRE := regexp.MustCompile(`Total:\s+([0-9.]+)ms`) + totalMatch := totalRE.FindStringSubmatch(md) + if len(totalMatch) < 2 { + t.Fatalf("Markdown missing Total line: %q", md) + } +} + +func TestTimer_TotalBeforeStart(t *testing.T) { + tm := NewTimer() + // No Start() called. + if total := tm.Total(); total != 0 { + t.Errorf("Total before Start = %v, want 0", total) + } +} diff --git a/internal/development.md b/internal/development.md index c5f7bcf642..3a0d1c0b33 100644 --- a/internal/development.md +++ b/internal/development.md @@ -159,16 +159,54 @@ Time: 76.582520 ``` Note: Both image and video understanding support streaming and thinking modes as well. -### 6.8. Generate Embeddings +### 6.8. Chat with OpenAI compatible API +``` +RAGFlow(api/default)> openai_chat '' 'Hello, how are you?'; +Answer: Hello! I'm just a virtual assistant, so I don't have feelings, but I'm here and ready to help you with anything you need. How can I assist you today? 😊 +Time: 8.487349 +``` + +``` +RAGFlow(api/default)> openai_chat '' 'Great, now what about x^3?' \ + system 'You are a math tutor. Always explain step by step.' \ + history 'user:What is the derivative of x^2?;assistant:The derivative of x^2 is 2x.'; +``` + +``` +RAGFlow(api/default)> openai_chat '' 'Hello, how are you?' temperature 0.7 max_tokens 100; +``` + +``` +RAGFlow(api/default)> openai_chat '' "what's in the doc?" stream true \ + extra_body '{"reference":true,"reference_metadata":{"include":true,"fields":["author","title"]}}'; +``` + +``` +RAGFlow(api/default)> openai_chat '7b1d58f263ca11f18121ab54cc8673a7' 'Hello' \ + extra_body '{"metadata_condition":{"logic":"and","conditions":[{"key":"doc_type","operator":"is","value":"faq"}]}}'; +``` + +``` +RAGFlow(api/default)> openai_chat '' 'Hello, how are you?' temp 100; +CLI error: OPENAI_CHAT: unknown option "temp" (valid: model, system, history, delimiter, temperature, max_tokens, stream, top_p, frequency_penalty, presence_penalty, extra_body) +``` + +``` +RAGFlow(api/default)> openai_chat '' 'Hello, how are you?' extra_body '{"ref":true}'; +CLI error: OPENAI_CHAT extra_body: unknown field "ref" (valid: reference, reference_metadata, metadata_condition) +``` + +### 6.9. Generate Embeddings ``` RAGFlow(api/default)> embed text 'what is rag' 'who are you' with 'embedding-3@test@zhipu-ai' dimension 16; ``` -### 6.9. Document Reranking + +### 6.10. Document Reranking ``` RAGFlow(api/default)> rerank query 'what is rag' document 'rag is retrieval augment generation' 'rag need llm' 'famous rag project includes ragflow' with 'rerank@test@zhipu-ai' top 2; ``` -### 6.10. Get supported models from provider API +### 6.11. Get supported models from provider API ``` RAGFlow(api/default)> list supported models from 'gitee' 'test'; @@ -190,7 +228,7 @@ RAGFlow(api/default)> list supported models from 'gitee' 'test'; +-----------+---------------------------+---------------+------------+-----------------------------------------------------------------+----------------------------------------------------------+---------------------------------------------+ ``` -### 6.11. Get preset models of a provider +### 6.12. Get preset models of a provider ``` RAGFlow(api/default)> list models from 'minimax'; @@ -208,7 +246,7 @@ RAGFlow(api/default)> list models from 'minimax'; +------------+-------------+------------------------+ ``` -### 6.12. List instances of a provider +### 6.13. List instances of a provider ``` RAGFlow(api/default)> list instances from 'zhipu-ai'; @@ -219,7 +257,7 @@ RAGFlow(api/default)> list instances from 'zhipu-ai'; +---------+----------------------+----------------------------------+--------------+----------------------------------+--------+ ``` -### 6.13. Show instance of a provider +### 6.14. Show instance of a provider ``` RAGFlow(api/default)> show instance 'test' from 'zhipu-ai'; +----------------------------------+--------------+----------------------------------+---------+--------+ @@ -229,7 +267,7 @@ RAGFlow(api/default)> show instance 'test' from 'zhipu-ai'; +----------------------------------+--------------+----------------------------------+---------+--------+ ``` -### 6.14. List models of a specific instance +### 6.15. List models of a specific instance ``` RAGFlow(api/default)> list models from 'minimax' 'test'; @@ -247,7 +285,7 @@ RAGFlow(api/default)> list models from 'minimax' 'test'; +------------+-------------+------------------------+--------+ ``` -### 6.15. List added providers +### 6.16. List added providers ``` RAGFlow(api/default)> list providers; +--------------------------------------------------------------------------+-------------+--------------+ @@ -259,7 +297,7 @@ RAGFlow(api/default)> list providers; +--------------------------------------------------------------------------+-------------+--------------+ ``` -### 6.16. Deactivate / activate a model +### 6.17. Deactivate / activate a model ``` RAGFlow(api/default)> disable model 'deepseek-v4-pro' from 'deepseek' 'test'; @@ -275,7 +313,7 @@ RAGFlow(api/default)> enable model 'deepseek-v4-pro' from 'deepseek' 'test'; SUCCESS ``` -### 6.17. Set current model +### 6.18. Set current model ``` RAGFlow(api/default)> use model 'glm-4.5-flash@test@zhipu-ai'; SUCCESS @@ -284,7 +322,7 @@ Answer: Large language models are advanced AI systems. They process text to unde Time: 1.680416 ``` -### 6.18. Set, reset, and list default models +### 6.19. Set, reset, and list default models ``` RAGFlow(api/default)> set default chat model 'zhipu-ai/test/glm-4.5-flash'; SUCCESS @@ -314,7 +352,7 @@ RAGFlow(api/default)> list default models; +--------+----------------+---------------+----------------+------------+ RAGFlow(api/default)> reset default embedding model; SUCCESS -RAGFlow(api/default)> reset default chat model +RAGFlow(api/default)> reset default chat model; SUCCESS RAGFlow(api/default)> list default models; +--------+----------------+--------------+----------------+------------+ @@ -328,7 +366,7 @@ RAGFlow(api/default)> list default models; +--------+----------------+--------------+----------------+------------+ ``` -### 6.19. Show current balance of a provider instance +### 6.20. Show current balance of a provider instance ``` RAGFlow(api/default)> show balance from 'gitee' 'test'; +-------------+----------+ @@ -338,13 +376,13 @@ RAGFlow(api/default)> show balance from 'gitee' 'test'; +-------------+----------+ ``` -### 6.20. Check provider instance availability +### 6.21. Check provider instance availability ``` RAGFlow(api/default)> check instance 'test' from 'zhipu-ai'; SUCCESS ``` -### 6.21. Add local model to RAGFlow, only for local deployed inference server, such as ollama +### 6.22. Add local model to RAGFlow, only for local deployed inference server, such as ollama ``` RAGFlow(api/default)> add model 'Qwen/Qwen2.5-0.5B' to provider 'vllm' instance 'test' with tokens 131072 chat; SUCCESS @@ -358,7 +396,7 @@ RAGFlow(api/default)> drop model 'Qwen/Qwen2.5-0.5B' from 'vllm' 'test'; SUCCESS ``` -### 6.22. List datasets +### 6.23. List datasets ``` RAGFlow(api/default)> list datasets; +-------------+--------------+----------------+----------------------+----------------------------------+----------+------+----------+------------+----------------------------------+-----------+---------------+ @@ -369,14 +407,14 @@ RAGFlow(api/default)> list datasets; +-------------+--------------+----------------+----------------------+----------------------------------+----------+------+----------+------------+----------------------------------+-----------+---------------+ ``` -### 6.23 Text to Speech +### 6.24. Text to Speech ``` RAGFlow(api/default)> tts with 'speech-2.8-hd@test@minimax' text 'He who desires but acts not, breeds pestilence.' play format 'wav' save './internal' param '{"voice_setting": {"voice_id": "English_radiant_girl", "speed": 1, "vol": 1, "pitch": 0}, "audio_setting": {"sample_rate": 32000, "bitrate": 128000, "format": "wav", "channel": 1}, "output_format": "hex"}' Saved to directory: /home/infiniflow/Documents/development/ragflow/internal/speech-2.8-hd_output.wav SUCCESS ``` -### 6.24 Audio to Speech +### 6.25. Audio to Speech ``` RAGFlow(api/default)> asr with 'FunAudioLLM/SenseVoiceSmall@test@siliconflow' audio './internal/test.wav' param '' +----------------------------------------------------------------------------------------------------------------------+ @@ -386,7 +424,7 @@ RAGFlow(api/default)> asr with 'FunAudioLLM/SenseVoiceSmall@test@siliconflow' au +----------------------------------------------------------------------------------------------------------------------+ ``` -### 6.25 Optical Character Recognition\ +### 6.26. Optical Character Recognition ``` RAGFlow(api/default)> ocr with 'paddleocr-vl-0.9b@test@baidu' file './internal/text.jpg' +------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ @@ -396,7 +434,7 @@ RAGFlow(api/default)> ocr with 'paddleocr-vl-0.9b@test@baidu' file './internal/t +------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ ``` -### 6.26 Chunk Management Commands +### 6.27. Chunk Management Commands - Create a chunk store with vector size ``` @@ -443,7 +481,7 @@ RAGFlow(api/default)> SEARCH 'AI' ON DATASETS 'test' RAGFlow(api/default)> GET CHUNK '29cc4f6d7a5c6e7c' OF DATASET 'test' DOCUMENT 'bbe55942535e11f1bc5184ba59049aa3' IN DATASET 'test' ``` -### 6.27 Metadata Management Commands +### 6.28. Metadata Management Commands - Create metadata store ``` @@ -461,12 +499,12 @@ RAGFlow(api/default)> SET METADATA OF DOCUMENT 'bbe55942535e11f1bc5184ba59049aa3 - Delete metadata of a document ``` -DELETE METADATA OF DOCUMENT 'bbe55942535e11f1bc5184ba59049aa3' +RAGFlow(api/default)> DELETE METADATA OF DOCUMENT 'bbe55942535e11f1bc5184ba59049aa3' ``` - Delete metadata keys of a document ``` -DELETE METADATA OF DOCUMENT 'bbe55942535e11f1bc5184ba59049aa3' KEYS '["key1", "key2"]' +RAGFlow(api/default)> DELETE METADATA OF DOCUMENT 'bbe55942535e11f1bc5184ba59049aa3' KEYS '["key1", "key2"]' ``` - Drop metadata store @@ -479,7 +517,7 @@ RAGFlow(api/default)> DROP METADATA STORE RAGFlow(api/default)> GET METADATA OF DATASET 'test' 'test2' ``` -### 6.28 Search datasets +### 6.29. Search datasets - Search datasets ``` diff --git a/internal/engine/elasticsearch/sql.go b/internal/engine/elasticsearch/sql.go new file mode 100644 index 0000000000..2af379a86f --- /dev/null +++ b/internal/engine/elasticsearch/sql.go @@ -0,0 +1,206 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package elasticsearch + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "regexp" + "strings" + "time" + + "ragflow/internal/common" + "ragflow/internal/tokenizer" + + "github.com/elastic/go-elasticsearch/v8/esapi" + "go.uber.org/zap" +) + +const ( + esSQLRequestTimeout = 2 * time.Second + esSQLFetchSize = 128 +) + +const esSQLRetryAttempts = 2 +const esSQLRetryDelay = 3 * time.Second + +var whitespaceRe = regexp.MustCompile("[ `]+") +var lktksMatchRe = regexp.MustCompile(` ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'`) + +// Preprocess normalizes SQL for ES: collapses whitespace/backticks, +// strips '%', and rewrites `_l?tks like/= 'value'` into a +// tokenized MATCH() call. +func Preprocess(sql string) string { + sql = whitespaceRe.ReplaceAllString(sql, " ") + sql = strings.ReplaceAll(sql, "%", "") + + // Collect replacements so we don't re-scan tokens we've already rewritten + type replacement struct { + old, new string + } + var replaces []replacement + for _, m := range lktksMatchRe.FindAllStringSubmatchIndex(sql, -1) { + match := sql[m[0]:m[1]] + fld := sql[m[2]:m[3]] + val := sql[m[6]:m[7]] + tokenized, err := tokenizer.Tokenize(val) + if err != nil { + continue + } + fine, err := tokenizer.FineGrainedTokenize(tokenized) + if err != nil { + continue + } + replaces = append(replaces, replacement{ + old: match, + new: fmt.Sprintf(" MATCH(%s, '%s', 'operator=OR;minimum_should_match=30%%') ", fld, fine), + }) + } + for _, r := range replaces { + sql = strings.Replace(sql, r.old, r.new, 1) + } + return sql +} + +// RunSQL posts SQL to `/_sql`, translates the response into chunk-shaped maps. +// Returns (nil, nil) on empty rows; (nil, error) when retries exhausted. +func (e *elasticsearchEngine) RunSQL(ctx context.Context, tableName string, sqlText string, kbIDs []string, format string) ([]map[string]interface{}, error) { + if e == nil || e.client == nil { + return nil, fmt.Errorf("Elasticsearch RunSQL: client not initialized") + } + if sqlText == "" { + return nil, fmt.Errorf("Elasticsearch RunSQL: empty SQL") + } + + common.Debug("ESConnection.sql get sql", zap.String("sql", sqlText)) + sqlText = Preprocess(sqlText) + common.Debug("ESConnection.sql to es", zap.String("sql", sqlText)) + + var lastErr error + for attempt := 0; attempt < esSQLRetryAttempts; attempt++ { + rows, err := e.runSQLOnce(ctx, sqlText, format) + if err == nil { + return rows, nil + } + lastErr = err + if !isTimeoutError(err) { + common.Warn("ESConnection.sql got exception", + zap.String("sql", sqlText), + zap.Error(err)) + return nil, fmt.Errorf("SQL error: %w\n\nSQL: %s", err, sqlText) + } + common.Warn("ES request timeout", + zap.String("sql", sqlText), + zap.Int("attempt", attempt+1), + zap.Error(err)) + if attempt < esSQLRetryAttempts-1 { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(esSQLRetryDelay): + } + } + } + common.Error(fmt.Sprintf("ESConnection.sql timeout after %d attempts. SQL: %s", esSQLRetryAttempts, sqlText), lastErr) + return nil, fmt.Errorf("Elasticsearch RunSQL: timeout after %d attempts: %w", esSQLRetryAttempts, lastErr) +} + +func (e *elasticsearchEngine) runSQLOnce(ctx context.Context, sqlText string, format string) ([]map[string]interface{}, error) { + ctx, cancel := context.WithTimeout(ctx, esSQLRequestTimeout) + defer cancel() + + body := map[string]interface{}{ + "query": sqlText, + "fetch_size": esSQLFetchSize, + } + buf, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("marshal body: %w", err) + } + + req := esapi.SQLQueryRequest{ + Body: bytes.NewReader(buf), + Format: format, + } + res, err := req.Do(ctx, e.client) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + errBody, _ := io.ReadAll(res.Body) + return nil, fmt.Errorf("status=%d body=%s", res.StatusCode, string(errBody)) + } + + // Parse the SQL response. + var resp struct { + Columns []struct { + Name string `json:"name"` + Type string `json:"type"` + } `json:"columns"` + Rows [][]interface{} `json:"rows"` + } + if err := json.NewDecoder(res.Body).Decode(&resp); err != nil { + return nil, fmt.Errorf("decode response: %w", err) + } + + if len(resp.Rows) == 0 { + return nil, nil + } + + // Convert to chunk-shaped maps. Column names map 1:1 to JSON keys. + out := make([]map[string]interface{}, 0, len(resp.Rows)) + for _, row := range resp.Rows { + cm := make(map[string]interface{}, len(resp.Columns)) + for i, col := range resp.Columns { + if i < len(row) { + cm[col.Name] = row[i] + } + } + out = append(out, cm) + } + return out, nil +} + +// isTimeoutError detects connection-level and per-attempt timeouts +// via context.DeadlineExceeded, net.Error.Timeout(), and substring +// matches (for SDKs that wrap without typed errors). +func isTimeoutError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.DeadlineExceeded) { + return true + } + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return true + } + msg := err.Error() + for _, sub := range []string{"i/o timeout", "deadline exceeded", "connection timeout", "context deadline"} { + if strings.Contains(msg, sub) { + return true + } + } + return false +} diff --git a/internal/engine/elasticsearch/sql_test.go b/internal/engine/elasticsearch/sql_test.go new file mode 100644 index 0000000000..16ecdbda92 --- /dev/null +++ b/internal/engine/elasticsearch/sql_test.go @@ -0,0 +1,493 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package elasticsearch + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "regexp" + "strings" + "sync" + "testing" + "time" + + "ragflow/internal/tokenizer" + + "github.com/elastic/go-elasticsearch/v8" +) + +// capturedRequest holds the request body the test server saw, for +// assertions. +type capturedRequest struct { + mu sync.Mutex + path string + body string + method string +} + +// newCapturingServer returns an httptest.Server that captures each +// incoming request and replies with the given body / status. +func newCapturingServer(t *testing.T, replyStatus int, replyBody string) (*httptest.Server, *capturedRequest) { + t.Helper() + cap := &capturedRequest{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + cap.mu.Lock() + cap.method = r.Method + cap.path = r.URL.Path + cap.body = string(body) + cap.mu.Unlock() + w.Header().Set("X-Elastic-Product", "Elasticsearch") + w.WriteHeader(replyStatus) + _, _ = w.Write([]byte(replyBody)) + })) + t.Cleanup(srv.Close) + return srv, cap +} + +// newTestEngine constructs an elasticsearchEngine pointing at the given +// test server. Bypasses NewEngine (which calls ES Info to verify +// connectivity) — the test server is a stub, not a real ES cluster. +func newTestEngine(t *testing.T, srvURL string) *elasticsearchEngine { + t.Helper() + client, err := elasticsearch.NewClient(elasticsearch.Config{ + Addresses: []string{srvURL}, + }) + if err != nil { + t.Fatalf("elasticsearch.NewClient: %v", err) + } + return &elasticsearchEngine{client: client} +} + +const sampleESResponse = `{ + "columns": [ + {"name": "doc_id", "type": "text"}, + {"name": "docnm", "type": "text"}, + {"name": "count", "type": "long"} + ], + "rows": [ + ["d1", "report.pdf", 5], + ["d2", "spec.pdf", 3] + ] +}` + +// TestRunSQL_NoFilterAdded verifies the request body is exactly +// {"query": } — the redundant `filter` field that the previous +// implementation added is gone. (service.addKBFilter is the source of +// truth for kb_id scoping upstream of RunSQL.) +func TestRunSQL_NoFilterAdded(t *testing.T) { + srv, cap := newCapturingServer(t, http.StatusOK, sampleESResponse) + e := newTestEngine(t, srv.URL) + + rows, err := e.RunSQL(context.Background(), "ragflow_t1", "SELECT doc_id FROM ragflow_t1", nil, "json") + if err != nil { + t.Fatalf("RunSQL: %v", err) + } + if len(rows) != 2 { + t.Fatalf("rows: got %d, want 2", len(rows)) + } + cap.mu.Lock() + got := cap.body + cap.mu.Unlock() + + var body map[string]interface{} + if err := json.Unmarshal([]byte(got), &body); err != nil { + t.Fatalf("body is not JSON: %v\nbody=%q", err, got) + } + if _, has := body["filter"]; has { + t.Errorf("RunSQL request must NOT include top-level filter (addKBFilter is the source of truth upstream). body=%v", body) + } + if _, has := body["query"]; !has { + t.Errorf("RunSQL request must include query. body=%v", body) + } +} + +// TestRunSQL_WhitespaceNormalizedAndPercentStripped verifies the Python +// preprocessing step `re.sub(r"[ `]+", " ", sql)` + `sql.replace("%", "")` +// is applied. Without these, the LLM-generated SQL with stray backticks +// or `%` characters (e.g. from JSON decoding glitches) would fail to +// parse in ES. +func TestRunSQL_WhitespaceNormalizedAndPercentStripped(t *testing.T) { + srv, cap := newCapturingServer(t, http.StatusOK, sampleESResponse) + e := newTestEngine(t, srv.URL) + + // Input SQL has multiple backticks/spaces and trailing % characters. + in := "SELECT doc_id FROM `ragflow_t1` WHERE count > 0 %" + _, err := e.RunSQL(context.Background(), "ragflow_t1", in, nil, "json") + if err != nil { + t.Fatalf("RunSQL: %v", err) + } + cap.mu.Lock() + got := cap.body + cap.mu.Unlock() + + var body map[string]interface{} + if err := json.Unmarshal([]byte(got), &body); err != nil { + t.Fatalf("body is not JSON: %v\nbody=%q", err, got) + } + q, _ := body["query"].(string) + if strings.Contains(q, " ") { + t.Errorf("query still has multiple spaces (whitespace not normalized): %q", q) + } + if strings.Contains(q, "`") { + t.Errorf("query still has backticks (whitespace+backtick regex not applied): %q", q) + } + if strings.Contains(q, "%") { + t.Errorf("query still has %% (percent strip not applied): %q", q) + } +} + +// TestRunSQL_PerAttemptTimeout verifies the derived context has a 2s +// deadline. We send a hanging response from the test server and assert +// the call returns well before 30s (the Go ES client's default +// transport-level timeout). With the retry loop in place, the total +// time is 2s (first attempt) + 3s (sleep) + 2s (second attempt) = ~7s. +func TestRunSQL_PerAttemptTimeout(t *testing.T) { + hang := make(chan struct{}) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-hang + })) + t.Cleanup(func() { + close(hang) + srv.Close() + }) + e := newTestEngine(t, srv.URL) + + start := time.Now() + _, err := e.RunSQL(context.Background(), "ragflow_t1", "SELECT 1", nil, "json") + elapsed := time.Since(start) + if err == nil { + t.Fatalf("RunSQL: got nil error, want timeout error") + } + // 2s + 3s + 2s = 7s; allow generous upper bound for the test runner. + if elapsed < 6*time.Second { + t.Errorf("RunSQL returned in %s; expected ~7s (2 attempts + 3s sleep)", elapsed) + } + if elapsed > 10*time.Second { + t.Errorf("RunSQL took %s; expected ~7s, looks like the retry didn't fire", elapsed) + } + // The final error should mention timeout + 2 attempts. + if !strings.Contains(err.Error(), "timeout after 2 attempts") { + t.Errorf("err: got %q, want substring %q", err.Error(), "timeout after 2 attempts") + } +} + +// TestRunSQL_RetryOnTimeoutThenSucceed simulates Python's +// ConnectionTimeout-retry pattern: the first attempt times out, the +// second attempt returns valid rows. The loop should silently retry +// and return the rows. +func TestRunSQL_RetryOnTimeoutThenSucceed(t *testing.T) { + var ( + mu sync.Mutex + calls int + ) + release := make(chan struct{}) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + calls++ + attempt := calls + mu.Unlock() + if attempt == 1 { + // First attempt: hang so the 2s context fires. + select { + case <-release: + case <-r.Context().Done(): + } + return + } + w.Header().Set("X-Elastic-Product", "Elasticsearch") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(sampleESResponse)) + })) + t.Cleanup(func() { + close(release) + srv.Close() + }) + + e := newTestEngine(t, srv.URL) + rows, err := e.RunSQL(context.Background(), "ragflow_t1", "SELECT 1", nil, "json") + if err != nil { + t.Fatalf("RunSQL: %v", err) + } + if len(rows) != 2 { + t.Errorf("rows: got %d, want 2 (second attempt should succeed)", len(rows)) + } + mu.Lock() + defer mu.Unlock() + if calls != 2 { + t.Errorf("server calls: got %d, want 2 (initial + one retry)", calls) + } +} + +// TestRunSQL_NonTimeoutErrorSurfacesImmediately verifies the non-retry +// path: a 4xx ES response should NOT trigger a retry. The error must +// be wrapped as `SQL error: \n\nSQL: `, matching Python's +// es_conn_base.py:400. +func TestRunSQL_NonTimeoutErrorSurfacesImmediately(t *testing.T) { + var ( + mu sync.Mutex + calls int + ) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + calls++ + mu.Unlock() + w.Header().Set("X-Elastic-Product", "Elasticsearch") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error": "syntax error"}`)) + })) + t.Cleanup(srv.Close) + + e := newTestEngine(t, srv.URL) + _, err := e.RunSQL(context.Background(), "ragflow_t1", "SELECT bad", nil, "json") + if err == nil { + t.Fatalf("RunSQL: got nil error, want error") + } + mu.Lock() + defer mu.Unlock() + if calls != 1 { + t.Errorf("server calls: got %d, want 1 (non-timeout error must NOT retry)", calls) + } + // Python wraps as `f"SQL error: {e}\n\nSQL: {sql}"`. + if !strings.Contains(err.Error(), "SQL error:") { + t.Errorf("err: got %q, want substring 'SQL error:'", err.Error()) + } + if !strings.Contains(err.Error(), "SQL: SELECT bad") { + t.Errorf("err: got %q, want substring 'SQL: SELECT bad'", err.Error()) + } +} + +// TestRunSQL_RequestBodyHasFetchSizeAndFormat verifies the request body +// includes fetch_size=128 and the SQLQueryRequest is built with +// format="json", matching the Python defaults at rag/nlp/search.py:773. +func TestRunSQL_RequestBodyHasFetchSizeAndFormat(t *testing.T) { + srv, cap := newCapturingServer(t, http.StatusOK, sampleESResponse) + e := newTestEngine(t, srv.URL) + + if _, err := e.RunSQL(context.Background(), "ragflow_t1", "SELECT 1", nil, "json"); err != nil { + t.Fatalf("RunSQL: %v", err) + } + cap.mu.Lock() + got := cap.body + cap.mu.Unlock() + + var body map[string]interface{} + if err := json.Unmarshal([]byte(got), &body); err != nil { + t.Fatalf("body is not JSON: %v\nbody=%q", err, got) + } + fs, ok := body["fetch_size"] + if !ok { + t.Errorf("body has no fetch_size; got %v", body) + } + if fmt.Sprint(fs) != "128" { + t.Errorf("fetch_size: got %v, want 128", fs) + } +} + +// TestRunSQL_EmptyRowsReturnsNilNil verifies the (nil, nil) sentinel +// for empty results — callers treat this as "fall through to vector +// retrieval". +func TestRunSQL_EmptyRowsReturnsNilNil(t *testing.T) { + empty := `{"columns": [{"name": "doc_id", "type": "text"}], "rows": []}` + srv, _ := newCapturingServer(t, http.StatusOK, empty) + e := newTestEngine(t, srv.URL) + + rows, err := e.RunSQL(context.Background(), "ragflow_t1", "SELECT doc_id FROM ragflow_t1", nil, "json") + if err != nil { + t.Fatalf("RunSQL: %v", err) + } + if rows != nil { + t.Errorf("rows: got %v, want nil (empty-rows sentinel)", rows) + } +} + +// TestRunSQL_PostsToSQLPath verifies the request goes to the /_sql +// endpoint (the modern ES SQL API; the older /_xpack/sql path is +// deprecated as of ES 7.x). The Go SDK's esapi.SQLQueryRequest hits +// /_sql; the Python ES client is also pinned to the modern endpoint +// at runtime even though the legacy /_xpack/sql name appears in the +// SDK's method (`es.sql.query(...)`). +func TestRunSQL_PostsToSQLPath(t *testing.T) { + srv, cap := newCapturingServer(t, http.StatusOK, sampleESResponse) + e := newTestEngine(t, srv.URL) + + if _, err := e.RunSQL(context.Background(), "ragflow_t1", "SELECT 1", nil, "json"); err != nil { + t.Fatalf("RunSQL: %v", err) + } + cap.mu.Lock() + got := cap.path + cap.mu.Unlock() + if got != "/_sql" { + t.Errorf("path: got %q, want /_sql", got) + } +} + +// TestMain registers the engine as "infinity" so tokenizer.Tokenize and +// tokenizer.FineGrainedTokenize short-circuit and return the input +// as-is. This lets the rewrite tests assert on the SHAPE of the MATCH() +// substitution without depending on a real tokenizer pool. +func TestMain(m *testing.M) { + tokenizer.RegisterEngineType(func() string { return "infinity" }) + m.Run() +} + +func TestPreprocess_WhitespaceAndBackticks(t *testing.T) { + cases := []struct { + in, want string + }{ + {"a b", "a b"}, + {"a b c", "a b c"}, + {"a`b`c", "a b c"}, + {"a `` b", "a b"}, + {" leading and trailing ", " leading and trailing "}, + } + for _, c := range cases { + if got := Preprocess(c.in); got != c.want { + t.Errorf("Preprocess(%q) = %q, want %q", c.in, got, c.want) + } + } +} + +func TestPreprocess_StripsPercent(t *testing.T) { + cases := []struct { + in, want string + }{ + {"count > 0 %", "count > 0 "}, + {"100% match", "100 match"}, + {"%%%", ""}, + } + for _, c := range cases { + if got := Preprocess(c.in); got != c.want { + t.Errorf("Preprocess(%q) = %q, want %q", c.in, got, c.want) + } + } +} + +func TestPreprocess_LktksRewrite(t *testing.T) { + cases := []struct { + name string + in string + field string + expect string + }{ + { + "like with single-token value (ltks suffix)", + "select content_ltks like 'weather'", + "content_ltks", + "MATCH(content_ltks,", + }, + { + "= with multi-word value (ltks suffix)", + "select content_ltks = 'final report'", + "content_ltks", + "MATCH(content_ltks,", + }, + { + "tks (no l) suffix", + "select title_tks = 'hello'", + "title_tks", + "MATCH(title_tks,", + }, + { + "leading-space anchor: no leading space means no match (mirrors Python regex)", + "content_ltks like 'weather'", + "content_ltks", + "content_ltks like 'weather'", + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got := Preprocess(c.in) + isAnchorTest := c.expect == c.in + if isAnchorTest { + if got != c.in { + t.Errorf("Preprocess(%q) = %q, want unchanged (leading-space anchor should prevent match)", c.in, got) + } + return + } + if strings.Contains(got, c.field+" ") { + pattern := regexp.MustCompile(c.field + `( like | ?= ?)`) + if pattern.MatchString(got) { + t.Errorf("Preprocess(%q) = %q, still contains the original ` like/=` pattern", c.in, got) + } + } + if !strings.Contains(got, c.expect) { + t.Errorf("Preprocess(%q) = %q, want substring %q", c.in, got, c.expect) + } + if !strings.Contains(got, "minimum_should_match=30") { + t.Errorf("Preprocess(%q) = %q, want substring minimum_should_match=30", c.in, got) + } + }) + } +} + +func TestPreprocess_NoMatchLeavesSQLAlone(t *testing.T) { + in := "SELECT doc_id FROM ragflow_t1" + got := Preprocess(in) + if got != in { + t.Errorf("Preprocess(%q) = %q, want unchanged", in, got) + } +} + +// fakeNetTimeoutErr implements net.Error with Timeout()==true. +type fakeNetTimeoutErr struct{} + +func (fakeNetTimeoutErr) Error() string { return "i/o timeout" } +func (fakeNetTimeoutErr) Timeout() bool { return true } +func (fakeNetTimeoutErr) Temporary() bool { return true } + +func TestIsTimeoutError(t *testing.T) { + cases := []struct { + name string + err error + want bool + }{ + {"nil error", nil, false}, + {"context.DeadlineExceeded", context.DeadlineExceeded, true}, + {"wrapped context.DeadlineExceeded", fmt.Errorf("wrap: %w", context.DeadlineExceeded), true}, + {"net.Error.Timeout()==true", fakeNetTimeoutErr{}, true}, + {"wrapped net.Error.Timeout", fmt.Errorf("wrap: %w", fakeNetTimeoutErr{}), true}, + {"plain string 'i/o timeout'", errors.New("read tcp: i/o timeout"), true}, + {"plain string 'deadline exceeded'", errors.New("context deadline exceeded"), true}, + {"plain string 'connection timeout'", errors.New("connection timeout while reading"), true}, + {"unrelated error", errors.New("parse: invalid character"), false}, + {"EOF is not a timeout", errors.New("EOF"), false}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := isTimeoutError(c.err); got != c.want { + t.Errorf("isTimeoutError(%v) = %v, want %v", c.err, got, c.want) + } + }) + } +} + +func TestIsTimeoutError_NonTimeoutNetError(t *testing.T) { + e := &net.OpError{ + Op: "dial", + Err: errors.New("connection refused"), + } + if isTimeoutError(e) { + t.Errorf("isTimeoutError(connection-refused) = true, want false") + } +} diff --git a/internal/engine/engine.go b/internal/engine/engine.go index 789fc7c29b..9a08ce8398 100644 --- a/internal/engine/engine.go +++ b/internal/engine/engine.go @@ -61,6 +61,10 @@ type DocEngine interface { GetFields(chunks []map[string]interface{}, fields []string) map[string]map[string]interface{} GetAggregation(chunks []map[string]interface{}, fieldName string) []map[string]interface{} GetHighlight(chunks []map[string]interface{}, keywords []string, fieldName string) map[string]string + + // Run SQL + RunSQL(ctx context.Context, tableName string, sqlText string, kbIDs []string, format string) ([]map[string]interface{}, error) + GetChunkIDs(chunks []map[string]interface{}) []string KNNScores(ctx context.Context, chunks []map[string]interface{}, queryVector []float64, topK int) (map[string]interface{}, error) GetScores(searchResult map[string]interface{}) map[string]float64 diff --git a/internal/engine/global.go b/internal/engine/global.go index 151de6fef4..3a443f7cef 100644 --- a/internal/engine/global.go +++ b/internal/engine/global.go @@ -27,6 +27,7 @@ import ( "ragflow/internal/engine/infinity" "go.uber.org/zap" + "ragflow/internal/tokenizer" ) var ( @@ -40,6 +41,10 @@ var ( func Init(cfg *server.DocEngineConfig) error { var initErr error once.Do(func() { + tokenizer.RegisterEngineType(func() string { + return string(GetEngineType()) + }) + engineType = EngineType(cfg.Type) var err error switch engineType { diff --git a/internal/engine/infinity/chunk.go b/internal/engine/infinity/chunk.go index ad37d51a0e..b3205abc24 100644 --- a/internal/engine/infinity/chunk.go +++ b/internal/engine/infinity/chunk.go @@ -1916,6 +1916,7 @@ func convertMatchingField(fieldWeightStr string) string { "authors_tks": "authors@ft_authors_rag_coarse", "authors_sm_tks": "authors@ft_authors_rag_fine", "tag_kwd": "tag_kwd@ft_tag_kwd_whitespace__", + "toc_kwd": "toc_kwd@ft_toc_kwd_whitespace__", // Skill index fields "name": "name@ft_name_rag_coarse", "tags": "tags@ft_tags_rag_coarse", diff --git a/internal/engine/infinity/client.go b/internal/engine/infinity/client.go index aca5d68074..2eca460f79 100644 --- a/internal/engine/infinity/client.go +++ b/internal/engine/infinity/client.go @@ -30,10 +30,18 @@ import ( infinity "github.com/infiniflow/infinity-go-sdk" ) -// infinityClient Infinity SDK client wrapper type infinityClient struct { conn *infinity.InfinityConnection dbName string + + // Original URI from config, used by RunSQL to extract the host. + hostURI string + + // Port for psql wire-protocol listener (default 5432). + postgresPort int + + // JSON file (under conf/) with the field-name alias map. + mappingFileName string } // NewInfinityClient creates a new Infinity client using the SDK @@ -70,8 +78,11 @@ func NewInfinityClient(cfg *server.InfinityConfig) (*infinityClient, error) { } client := &infinityClient{ - conn: conn, - dbName: cfg.DBName, + conn: conn, + dbName: cfg.DBName, + hostURI: cfg.URI, + postgresPort: cfg.PostgresPort, + mappingFileName: cfg.MappingFileName, } return client, nil diff --git a/internal/engine/infinity/metadata.go b/internal/engine/infinity/metadata.go index 67d43a5dc6..438023dfe2 100644 --- a/internal/engine/infinity/metadata.go +++ b/internal/engine/infinity/metadata.go @@ -24,12 +24,13 @@ import ( "path/filepath" "strings" - infinity "github.com/infiniflow/infinity-go-sdk" "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/engine/types" "ragflow/internal/utility" + infinity "github.com/infiniflow/infinity-go-sdk" + "go.uber.org/zap" ) @@ -568,25 +569,113 @@ func (e *infinityEngine) SearchMetadata(ctx context.Context, req *types.SearchMe }, nil } - // Build search request for metadata - simpler than chunk search, no match expressions - searchReq := &types.SearchRequest{ - IndexNames: []string{tableName}, - Offset: req.Offset, - Limit: req.Limit, - SelectFields: req.SelectFields, - Filter: req.Filter, - MatchExprs: nil, // No match expressions for metadata - OrderBy: req.OrderBy, - RankFeature: nil, + // Build output columns: use caller-specified fields, or "*" for all columns + var outputColumns []string + if len(req.SelectFields) > 0 { + outputColumns = req.SelectFields + } else { + outputColumns = []string{"*"} } - result, err := e.Search(ctx, searchReq) - if err != nil { - return nil, err + // Pagination defaults + pageSize := req.Limit + if pageSize <= 0 { + pageSize = 30 } + offset := req.Offset + if offset < 0 { + offset = 0 + } + + // Build filter from req.Filter + var filterStr string + if req.Filter != nil { + filterStr = equivalentConditionToStr(req.Filter) + } + + // Get database and table + db, err := e.client.conn.GetDatabase(e.client.dbName) + if err != nil { + return nil, fmt.Errorf("failed to get database: %w", err) + } + + tbl, err := db.GetTable(tableName) + if err != nil { + return nil, fmt.Errorf("failed to get metadata table %s: %w", tableName, err) + } + + // Build Infinity query (chainable API) + table := tbl.Output(outputColumns) + if filterStr != "" { + table = table.Filter(filterStr) + } + + // Add order_by if provided + if req.OrderBy != nil && len(req.OrderBy.Fields) > 0 { + var sortFields [][2]interface{} + for _, orderField := range req.OrderBy.Fields { + sortType := infinity.SortTypeAsc + if orderField.Type == types.SortDesc { + sortType = infinity.SortTypeDesc + } + sortFields = append(sortFields, [2]interface{}{orderField.Field, sortType}) + } + table = table.Sort(sortFields) + } + + table = table.Limit(pageSize) + if offset > 0 { + table = table.Offset(offset) + } + table = table.Option(map[string]interface{}{"total_hits_count": true}) + + // Execute query + df, err := table.ToDataFrame() + if err != nil { + common.Warn("Infinity SearchMetadata query failed", + zap.String("tableName", tableName), + zap.Error(err)) + return nil, fmt.Errorf("metadata query failed: %w", err) + } + + // Convert column-oriented DataFrame to row-oriented records + records := make([]map[string]interface{}, 0) + for colName, colData := range df.ColumnData { + for i, val := range colData { + for len(records) <= i { + records = append(records, make(map[string]interface{})) + } + records[i][colName] = val + } + } + + // Handle ROW_ID -> row_id() mapping (Infinity internal column) + for _, rec := range records { + if val, ok := rec["ROW_ID"]; ok { + rec["row_id()"] = val + delete(rec, "ROW_ID") + } + } + + // Realign meta_fields column for multi-row queries (Infinity may + // concatenate values into one blob with 4-byte length prefix) + realignMetaFieldsColumn(records) + + // Parse total_hits_count from ExtraInfo + var totalHits int64 + if df.ExtraInfo != "" { + if t, ok := totalHitsFromInfinityExtraInfo(df.ExtraInfo); ok { + totalHits = t + } + } + + common.Debug("SearchMetadata in Infinity completed", + zap.Int("rows", len(records)), + zap.Int64("total", totalHits)) + return &types.SearchMetadataResult{ - MetadataRecords: result.Chunks, - Total: result.Total, + MetadataRecords: records, + Total: totalHits, }, nil } diff --git a/internal/engine/infinity/sql.go b/internal/engine/infinity/sql.go new file mode 100644 index 0000000000..fedff93354 --- /dev/null +++ b/internal/engine/infinity/sql.go @@ -0,0 +1,330 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package infinity + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "ragflow/internal/common" + + "go.uber.org/zap" +) + +const ( + psqlTimeout = 10 * time.Second + defaultPsqlPath = "/usr/bin/psql" + defaultPsqlHost = "infinity" + defaultPsqlPort = "5432" +) + +var whitespaceRe = regexp.MustCompile("[ `]+") + +var rowCountFooterRe = regexp.MustCompile(`^\(\d+ rows?`) + +// fieldMappingEntry is one entry in infinity_mapping.json. +type fieldMappingEntry struct { + Type string `json:"type"` + Comment string `json:"comment"` +} + +// loadFieldMapping reads infinity_mapping.json and returns alias→actual +// and actual→firstAlias maps. Silently returns empty maps on missing file. +func loadFieldMapping(mappingFileName string) (aliasToActual map[string]string, actualToFirstAlias map[string]string, err error) { + if mappingFileName == "" { + mappingFileName = "infinity_mapping.json" + } + confPath := filepath.Join(projectBaseDir(), "conf", mappingFileName) + data, err := os.ReadFile(confPath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return map[string]string{}, map[string]string{}, nil + } + return nil, nil, fmt.Errorf("load field mapping %q: %w", confPath, err) + } + + fields := map[string]fieldMappingEntry{} + if err := json.Unmarshal(data, &fields); err != nil { + return nil, nil, fmt.Errorf("parse field mapping %q: %w", confPath, err) + } + + aliasToActual = make(map[string]string, len(fields)*2) + actualToFirstAlias = make(map[string]string, len(fields)) + for actual, info := range fields { + if info.Comment == "" { + continue + } + var firstAlias string + for _, raw := range strings.Split(info.Comment, ",") { + alias := strings.TrimSpace(raw) + if alias == "" { + continue + } + aliasToActual[alias] = actual + if firstAlias == "" { + firstAlias = alias + } + } + if firstAlias != "" { + actualToFirstAlias[actual] = firstAlias + } + } + return aliasToActual, actualToFirstAlias, nil +} + +// projectBaseDir returns the project root. Honors RAG_PROJECT_BASE and +// RAG_DEPLOY_BASE env vars; falls back to working directory. +func projectBaseDir() string { + if v := os.Getenv("RAG_PROJECT_BASE"); v != "" { + return v + } + if v := os.Getenv("RAG_DEPLOY_BASE"); v != "" { + return v + } + // Fall back to the repository root. The Go engine package lives at + // internal/engine/infinity/; the repo root is three levels up. + wd, err := os.Getwd() + if err != nil { + return "." + } + return wd +} + +// preprocessSQL collapses spaces/backticks and strips '%'. +func preprocessSQL(sql string) string { + sql = whitespaceRe.ReplaceAllString(sql, " ") + sql = strings.ReplaceAll(sql, "%", "") + return sql +} + +// rewriteFieldAliases rewrites alias field names to actual stored names +// in SELECT, WHERE, ORDER BY, GROUP BY, and HAVING clauses. +func rewriteFieldAliases(sql string, aliasToActual map[string]string) string { + if len(aliasToActual) == 0 { + return sql + } + selectRe := regexp.MustCompile(`(?si)(select\s+)(.+?)(\s+from\b)`) + sql = selectRe.ReplaceAllStringFunc(sql, func(m string) string { + parts := selectRe.FindStringSubmatch(m) + prefix, cols, suffix := parts[1], parts[2], parts[3] + for alias, actual := range aliasToActual { + pat := regexp.MustCompile(`(^|[,\s])` + regexp.QuoteMeta(alias) + `($|[,\s])`) + cols = pat.ReplaceAllString(cols, "${1}"+actual+"${2}") + } + return prefix + cols + suffix + }) + + clauseAliases := func(sql, keyword string) string { + return rewriteFirstAliasAfterKeyword(sql, keyword, aliasToActual) + } + sql = clauseAliases(sql, "where") + sql = clauseAliases(sql, "order by") + sql = clauseAliases(sql, "group by") + sql = clauseAliases(sql, "having") + return sql +} + +func rewriteFirstAliasAfterKeyword(sql, keyword string, aliasToActual map[string]string) string { + for alias, actual := range aliasToActual { + aliasPat := regexp.MustCompile(`\b` + regexp.QuoteMeta(alias) + `\b`) + kwIdx := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(keyword) + `\b`).FindStringIndex(sql) + if kwIdx == nil { + continue + } + tail := sql[kwIdx[1]:] + aliasIdx := aliasPat.FindStringIndex(tail) + if aliasIdx == nil { + continue + } + absStart := kwIdx[1] + aliasIdx[0] + absEnd := kwIdx[1] + aliasIdx[1] + sql = sql[:absStart] + actual + sql[absEnd:] + } + return sql +} + +// psqlResult is the structured parse of a psql table-format output. +type psqlResult struct { + Columns []string + Rows [][]string +} + +// runPsql shells out to psql and parses the table-format output. +func runPsql(ctx context.Context, host, port, sql string) (*psqlResult, error) { + psqlPath, err := findPsqlBinary() + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(ctx, psqlTimeout) + defer cancel() + + cmd := exec.CommandContext(ctx, psqlPath, "-h", host, "-p", port, "-c", sql) + common.Debug("executing psql", + zap.String("path", psqlPath), + zap.String("host", host), + zap.String("port", port), + ) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + if ctx.Err() != nil { + return nil, fmt.Errorf("SQL timeout\n\nSQL: %s", sql) + } + return nil, fmt.Errorf("psql command failed: %s\nSQL: %s", strings.TrimSpace(stderr.String()), sql) + } + return parsePsqlTable(stdout.String()), nil +} + +// findPsqlBinary checks PATH first, then falls back to defaultPsqlPath. +func findPsqlBinary() (string, error) { + if path, err := exec.LookPath("psql"); err == nil { + return path, nil + } + if _, err := os.Stat(defaultPsqlPath); err == nil { + return defaultPsqlPath, nil + } + return "", fmt.Errorf("psql not found on PATH and not at %q", defaultPsqlPath) +} + +// parsePsqlTable parses psql's pipe-delimited output: +// +// col1 | col2 +// -----+----- +// val1 | val2 +func parsePsqlTable(output string) *psqlResult { + res := &psqlResult{} + out := strings.TrimSpace(output) + if out == "" { + return res + } + lines := strings.Split(out, "\n") + if len(lines) == 0 { + return res + } + + for _, raw := range strings.Split(lines[0], "|") { + if col := strings.TrimSpace(raw); col != "" { + res.Columns = append(res.Columns, col) + } + } + + dataStart := 1 + if len(lines) >= 2 && strings.Contains(lines[1], "-") { + dataStart = 2 + } + for i := dataStart; i < len(lines); i++ { + line := strings.TrimSpace(lines[i]) + if line == "" || rowCountFooterRe.MatchString(line) { + continue + } + cells := strings.Split(line, "|") + for j := range cells { + cells[j] = strings.TrimSpace(cells[j]) + } + switch { + case len(cells) == len(res.Columns): + res.Rows = append(res.Rows, cells) + case len(cells) > len(res.Columns): + res.Rows = append(res.Rows, cells[:len(res.Columns)]) + default: + padded := make([]string, len(res.Columns)) + copy(padded, cells) + for k := len(cells); k < len(res.Columns); k++ { + padded[k] = "" + } + res.Rows = append(res.Rows, padded) + } + } + return res +} + +// toRowMaps converts psqlResult to a slice of column-keyed maps. +func toRowMaps(res *psqlResult) []map[string]interface{} { + if res == nil || len(res.Rows) == 0 { + return nil + } + out := make([]map[string]interface{}, 0, len(res.Rows)) + for _, row := range res.Rows { + m := make(map[string]interface{}, len(res.Columns)) + for j, col := range res.Columns { + if j < len(row) { + m[col] = row[j] + } + } + out = append(out, m) + } + return out +} + +func resolvePsqlHostPort(hostURI string, postgresPort int) (host, port string) { + host = defaultPsqlHost + port = defaultPsqlPort + if postgresPort > 0 { + port = strconv.Itoa(postgresPort) + } + if hostURI != "" { + if h, _, ok := strings.Cut(hostURI, ":"); ok && h != "" { + host = h + } + } + return host, port +} + +// RunSQL implements the SQL retrieval path: preprocess, rewrite aliases, +// run psql subprocess, parse output. +func (e *infinityEngine) RunSQL(ctx context.Context, tableName string, sqlText string, kbIDs []string, _ string) ([]map[string]interface{}, error) { + if e == nil || e.client == nil { + return nil, fmt.Errorf("infinity RunSQL: client not initialized") + } + sqlText = strings.TrimSpace(sqlText) + if sqlText == "" { + return nil, fmt.Errorf("infinity RunSQL: empty SQL") + } + + common.Debug("InfinityConnection.sql get sql", zap.String("sql", sqlText)) + + sqlText = preprocessSQL(sqlText) + + aliasMap, _, err := loadFieldMapping(e.client.mappingFileName) + if err != nil { + return nil, fmt.Errorf("infinity RunSQL: %w", err) + } + sqlText = rewriteFieldAliases(sqlText, aliasMap) + + common.Debug("InfinityConnection.sql to execute", zap.String("sql", sqlText)) + + host, port := resolvePsqlHostPort(e.client.hostURI, e.client.postgresPort) + res, err := runPsql(ctx, host, port, sqlText) + if err != nil { + return nil, err + } + + return toRowMaps(res), nil +} diff --git a/internal/engine/infinity/sql_test.go b/internal/engine/infinity/sql_test.go new file mode 100644 index 0000000000..4c7e7b98ac --- /dev/null +++ b/internal/engine/infinity/sql_test.go @@ -0,0 +1,394 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package infinity + +import ( + "os" + "path/filepath" + "reflect" + "testing" +) + +// ----------------------------------------------------------------------------- +// preprocessSQL — mirrors infinity_conn_base.py:788-789. +// ----------------------------------------------------------------------------- + +func TestPreprocessSQL_WhitespaceAndBackticks(t *testing.T) { + cases := []struct { + in, want string + }{ + {"a b", "a b"}, + {"a b c", "a b c"}, + {"a`b`c", "a b c"}, + {"a `` b", "a b"}, + // The regex collapses ALL runs of spaces/backticks — including + // leading and trailing whitespace. Trimming is a separate step + // in RunSQL (strings.TrimSpace before the preprocessing pass). + {" leading and trailing ", " leading and trailing "}, + } + for _, c := range cases { + if got := preprocessSQL(c.in); got != c.want { + t.Errorf("preprocessSQL(%q) = %q, want %q", c.in, got, c.want) + } + } +} + +func TestPreprocessSQL_StripsPercent(t *testing.T) { + cases := []struct { + in, want string + }{ + {"count > 0 %", "count > 0 "}, + {"100% match", "100 match"}, + {"%%%", ""}, + } + for _, c := range cases { + if got := preprocessSQL(c.in); got != c.want { + t.Errorf("preprocessSQL(%q) = %q, want %q", c.in, got, c.want) + } + } +} + +func TestPreprocessSQL_Combined(t *testing.T) { + in := "SELECT docnm_kwd FROM `ragflow_t1` WHERE count > 0 %" + got := preprocessSQL(in) + want := "SELECT docnm_kwd FROM ragflow_t1 WHERE count > 0 " + if got != want { + t.Errorf("preprocessSQL(%q) = %q, want %q", in, got, want) + } +} + +// ----------------------------------------------------------------------------- +// rewriteFieldAliases — mirrors infinity_conn_base.py:809-830. +// ----------------------------------------------------------------------------- + +func TestRewriteFieldAliases_SelectClause(t *testing.T) { + aliases := map[string]string{ + "docnm_kwd": "docnm", + "title_tks": "docnm", + "title_sm_tks": "docnm", + "content_ltks": "content", + } + in := "select docnm_kwd, title_tks, content_ltks from ragflow_t1" + got := rewriteFieldAliases(in, aliases) + want := "select docnm, docnm, content from ragflow_t1" + if got != want { + t.Errorf("rewriteFieldAliases(%q) = %q, want %q", in, got, want) + } +} + +func TestRewriteFieldAliases_WhereClause(t *testing.T) { + aliases := map[string]string{ + "docnm_kwd": "docnm", + } + in := "select doc_id from ragflow_t1 where docnm_kwd = 'foo'" + got := rewriteFieldAliases(in, aliases) + want := "select doc_id from ragflow_t1 where docnm = 'foo'" + if got != want { + t.Errorf("rewriteFieldAliases(%q) = %q, want %q", in, got, want) + } +} + +func TestRewriteFieldAliases_OrderGroupHaving(t *testing.T) { + aliases := map[string]string{ + "docnm_kwd": "docnm", + "important_kwd": "important_keywords", + } + in := "select doc_id from ragflow_t1 order by docnm_kwd group by important_kwd having important_kwd > 0" + got := rewriteFieldAliases(in, aliases) + want := "select doc_id from ragflow_t1 order by docnm group by important_keywords having important_keywords > 0" + if got != want { + t.Errorf("rewriteFieldAliases(%q) = %q, want %q", in, got, want) + } +} + +func TestRewriteFieldAliases_EmptyMapIsNoop(t *testing.T) { + in := "select docnm_kwd from ragflow_t1" + if got := rewriteFieldAliases(in, map[string]string{}); got != in { + t.Errorf("empty alias map should not modify SQL; got %q", got) + } +} + +func TestRewriteFieldAliases_WordBoundaryProtected(t *testing.T) { + // "title" is an alias; "title_sm_tks" should NOT match because + // word boundary is enforced. + aliases := map[string]string{ + "title": "docnm", + } + in := "select title_sm_tks from ragflow_t1" + got := rewriteFieldAliases(in, aliases) + // "title" inside "title_sm_tks" should NOT be rewritten. + want := "select title_sm_tks from ragflow_t1" + if got != want { + t.Errorf("rewriteFieldAliases(%q) = %q, want %q (title_sm_tks must NOT be touched)", in, got, want) + } +} + +func TestRewriteFieldAliases_NoAliasMatchLeavesSQLAlone(t *testing.T) { + aliases := map[string]string{ + "docnm_kwd": "docnm", + } + in := "select content_with_weight from ragflow_t1" + got := rewriteFieldAliases(in, aliases) + if got != in { + t.Errorf("unrelated SQL should be unchanged; got %q", got) + } +} + +// ----------------------------------------------------------------------------- +// parsePsqlTable — mirrors infinity_conn_base.py:894-934. +// ----------------------------------------------------------------------------- + +func TestParsePsqlTable_StandardOutput(t *testing.T) { + // Sample psql table output for `select 1 as a, 2 as b;` + out := ` a | b +---+--- + 1 | 2 +(1 row)` + + res := parsePsqlTable(out) + wantCols := []string{"a", "b"} + if !reflect.DeepEqual(res.Columns, wantCols) { + t.Errorf("columns: got %v, want %v", res.Columns, wantCols) + } + wantRows := [][]string{{"1", "2"}} + if !reflect.DeepEqual(res.Rows, wantRows) { + t.Errorf("rows: got %v, want %v", res.Rows, wantRows) + } +} + +func TestParsePsqlTable_EmptyOutput(t *testing.T) { + res := parsePsqlTable("") + if len(res.Columns) != 0 || len(res.Rows) != 0 { + t.Errorf("empty output should yield (0 cols, 0 rows); got %+v", res) + } +} + +func TestParsePsqlTable_NoSeparatorLine(t *testing.T) { + // Some psql configurations skip the separator line; the parser + // should still recover (data starts at line 1 in that case). + out := "a | b\n1 | 2" + res := parsePsqlTable(out) + if len(res.Rows) != 1 { + t.Errorf("rows: got %d, want 1", len(res.Rows)) + } +} + +func TestParsePsqlTable_MultipleRowsAndRowCountFooter(t *testing.T) { + out := ` id | name +----+------ + 1 | foo + 2 | bar +(2 rows)` + res := parsePsqlTable(out) + wantCols := []string{"id", "name"} + if !reflect.DeepEqual(res.Columns, wantCols) { + t.Errorf("columns: got %v, want %v", res.Columns, wantCols) + } + if len(res.Rows) != 2 { + t.Errorf("rows: got %d, want 2", len(res.Rows)) + } + if res.Rows[0][0] != "1" || res.Rows[0][1] != "foo" { + t.Errorf("row[0]: got %v, want [1 foo]", res.Rows[0]) + } + if res.Rows[1][0] != "2" || res.Rows[1][1] != "bar" { + t.Errorf("row[1]: got %v, want [2 bar]", res.Rows[1]) + } +} + +func TestParsePsqlTable_PadsAndTruncatesRows(t *testing.T) { + // Row with fewer cells → pad with empty strings. + // Row with more cells → truncate. + out := ` a | b | c +---+---+--- + 1 | 2 + 1 | 2 | 3 | 4 +(2 rows)` + res := parsePsqlTable(out) + if len(res.Rows) != 2 { + t.Fatalf("rows: got %d, want 2", len(res.Rows)) + } + // First row: ["1", "2", ""] (padded) + if !reflect.DeepEqual(res.Rows[0], []string{"1", "2", ""}) { + t.Errorf("padded row: got %v, want [1 2 ]", res.Rows[0]) + } + // Second row: ["1", "2", "3"] (truncated) + if !reflect.DeepEqual(res.Rows[1], []string{"1", "2", "3"}) { + t.Errorf("truncated row: got %v, want [1 2 3]", res.Rows[1]) + } +} + +func TestParsePsqlTable_SkipsRowCountFooter(t *testing.T) { + out := " a \n---\n 1 \n(1 row)" + res := parsePsqlTable(out) + if len(res.Rows) != 1 { + t.Errorf("row count footer should be skipped; got %d rows", len(res.Rows)) + } +} + +// ----------------------------------------------------------------------------- +// toRowMaps — chunk-shape conversion. +// ----------------------------------------------------------------------------- + +func TestToRowMaps_EmptyResultsReturnsNil(t *testing.T) { + if rows := toRowMaps(nil); rows != nil { + t.Errorf("nil result: got %v, want nil", rows) + } + if rows := toRowMaps(&psqlResult{}); rows != nil { + t.Errorf("empty result: got %v, want nil", rows) + } +} + +func TestToRowMaps_ConvertsToRowMaps(t *testing.T) { + res := &psqlResult{ + Columns: []string{"id", "name"}, + Rows: [][]string{ + {"1", "foo"}, + {"2", "bar"}, + }, + } + got := toRowMaps(res) + want := []map[string]interface{}{ + {"id": "1", "name": "foo"}, + {"id": "2", "name": "bar"}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("toRowMaps: got %v, want %v", got, want) + } +} + +// ----------------------------------------------------------------------------- +// resolvePsqlHostPort — mirrors infinity_conn_base.py:838-858. +// ----------------------------------------------------------------------------- + +func TestResolvePsqlHostPort_DefaultsWhenConfigEmpty(t *testing.T) { + host, port := resolvePsqlHostPort("", 0) + if host != defaultPsqlHost { + t.Errorf("host: got %q, want %q", host, defaultPsqlHost) + } + if port != defaultPsqlPort { + t.Errorf("port: got %q, want %q", port, defaultPsqlPort) + } +} + +func TestResolvePsqlHostPort_OverridesFromConfig(t *testing.T) { + host, port := resolvePsqlHostPort("10.0.0.1:23817", 5433) + if host != "10.0.0.1" { + t.Errorf("host: got %q, want 10.0.0.1", host) + } + if port != "5433" { + t.Errorf("port: got %q, want 5433", port) + } +} + +func TestResolvePsqlHostPort_EmptyHostInURIFallsBackToDefault(t *testing.T) { + // ":23817" parses via strings.Cut to ("", "23817") — the empty + // host doesn't override the default, matching Python's + // `re.search(r"host=(\S+)", ...)` which only matches a non-empty + // value. + host, port := resolvePsqlHostPort(":23817", 5432) + if host != defaultPsqlHost { + t.Errorf("host: got %q, want default %q (empty host in URI should not override)", host, defaultPsqlHost) + } + if port != "5432" { + t.Errorf("port: got %q, want 5432", port) + } +} + +// ----------------------------------------------------------------------------- +// loadFieldMapping — mirrors infinity_conn_base.py:793-807. +// ----------------------------------------------------------------------------- + +func TestLoadFieldMapping_MissingFileReturnsEmpty(t *testing.T) { + // Use a name that doesn't exist; the function should silently + // return empty maps (matching Python's `os.path.exists` guard). + a2a, r2a, err := loadFieldMapping("nonexistent_mapping_xyz.json") + if err != nil { + t.Fatalf("missing file should be a no-op, got error: %v", err) + } + if len(a2a) != 0 || len(r2a) != 0 { + t.Errorf("missing file should yield empty maps; got a2a=%v r2a=%v", a2a, r2a) + } +} + +func TestLoadFieldMapping_ParsesAliases(t *testing.T) { + // Write a temporary mapping file. + dir := t.TempDir() + mappingPath := filepath.Join(dir, "test_mapping.json") + contents := `{ + "docnm": {"type": "varchar", "comment": "docnm_kwd, title_tks, title_sm_tks"}, + "content": {"type": "varchar", "comment": "content_with_weight, content_ltks"}, + "plain": {"type": "varchar"} + }` + if err := os.WriteFile(mappingPath, []byte(contents), 0o644); err != nil { + t.Fatalf("write mapping: %v", err) + } + + // Set RAG_PROJECT_BASE to the temp dir's parent so loadFieldMapping + // finds the file at /conf/. + os.Setenv("RAG_PROJECT_BASE", dir) + defer os.Unsetenv("RAG_PROJECT_BASE") + + // Need to create conf/ subdir. + if err := os.MkdirAll(filepath.Join(dir, "conf"), 0o755); err != nil { + t.Fatalf("mkdir conf: %v", err) + } + if err := os.WriteFile(filepath.Join(dir, "conf", "test_mapping.json"), []byte(contents), 0o644); err != nil { + t.Fatalf("write conf/mapping: %v", err) + } + + a2a, r2a, err := loadFieldMapping("test_mapping.json") + if err != nil { + t.Fatalf("loadFieldMapping: %v", err) + } + + // alias → actual + expectedAliases := map[string]string{ + "docnm_kwd": "docnm", + "title_tks": "docnm", + "title_sm_tks": "docnm", + "content_with_weight": "content", + "content_ltks": "content", + } + if !reflect.DeepEqual(a2a, expectedAliases) { + t.Errorf("aliasToActual: got %v, want %v", a2a, expectedAliases) + } + + // actual → first alias (mirrors Python at line 807) + if r2a["docnm"] != "docnm_kwd" { + t.Errorf("actualToFirstAlias[docnm]: got %q, want docnm_kwd", r2a["docnm"]) + } + if r2a["content"] != "content_with_weight" { + t.Errorf("actualToFirstAlias[content]: got %q, want content_with_weight", r2a["content"]) + } + // "plain" has no comment, so it shouldn't appear in the reverse map. + if _, ok := r2a["plain"]; ok { + t.Errorf("actualToFirstAlias should not include fields without comments") + } +} + +func TestLoadFieldMapping_EmptyNameDefaultsToInfinityMappingJSON(t *testing.T) { + // Empty name → defaults to "infinity_mapping.json" (line 145). + // We just verify the function doesn't panic and the file-not-found + // path is taken silently. + a2a, r2a, err := loadFieldMapping("") + if err != nil { + t.Fatalf("empty name: %v", err) + } + if len(a2a) != 0 || len(r2a) != 0 { + t.Errorf("empty name + no file should yield empty maps; got a2a=%v r2a=%v", a2a, r2a) + } +} diff --git a/internal/entity/models/chat_tools.go b/internal/entity/models/chat_tools.go new file mode 100644 index 0000000000..09308c71ee --- /dev/null +++ b/internal/entity/models/chat_tools.go @@ -0,0 +1,351 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + "ragflow/internal/tokenizer" +) + +const ( + defaultMaxRetries = 3 + defaultMaxRounds = 5 +) + +// ChatWithTools runs the non-streaming tool-calling loop. +func (cm *ChatModel) ChatWithTools(ctx context.Context, system string, history []Message, chatCfg *ChatConfig) (string, int, error) { + tc := cm.ToolConfig + if tc == nil { + return "", 0, fmt.Errorf("ChatWithTools called without bound tools") + } + + var toolsList interface{} + if err := json.Unmarshal([]byte(tc.Tools), &toolsList); err != nil { + return "", 0, fmt.Errorf("failed to parse tools JSON: %w", err) + } + + maxRounds := tc.MaxRounds + if maxRounds <= 0 { + maxRounds = defaultMaxRounds + } + maxRetries := tc.MaxRetries + if maxRetries <= 0 { + maxRetries = defaultMaxRetries + } + + if system != "" && len(history) > 0 && history[0].Role != "system" { + history = append([]Message{{Role: "system", Content: system}}, history...) + } + + baseHistory := make([]Message, len(history)) + copy(baseHistory, history) + + for attempt := 0; attempt < maxRetries; attempt++ { + select { + case <-ctx.Done(): + return "", 0, ctx.Err() + default: + } + + h := make([]Message, len(baseHistory)) + copy(h, baseHistory) + + answer, tokens, err := runToolLoop(ctx, cm, h, toolsList, chatCfg, maxRounds) + if err == nil { + return answer, tokens, nil + } + } + return "", 0, fmt.Errorf("ChatWithTools failed after %d retries", maxRetries) +} + +func runToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsList interface{}, chatCfg *ChatConfig, maxRounds int) (string, int, error) { + var totalTokens int + + for round := 0; round <= maxRounds; round++ { + select { + case <-ctx.Done(): + return "", totalTokens, ctx.Err() + default: + } + cfg := *chatCfg + cfg.Tools = toolsList + tcChoice := "auto" + cfg.ToolChoice = &tcChoice + + resp, err := cm.ModelDriver.ChatWithMessages(*cm.ModelName, history, cm.APIConfig, &cfg) + if err != nil { + return "", totalTokens, fmt.Errorf("round %d: %w", round, err) + } + if resp == nil { + return "", totalTokens, fmt.Errorf("round %d: nil response", round) + } + + if len(resp.ToolCalls) == 0 { + answer := "" + if resp.Answer != nil { + answer = *resp.Answer + } + if resp.ReasonContent != nil && *resp.ReasonContent != "" { + answer = "" + *resp.ReasonContent + "" + answer + } + totalTokens += tokenizer.NumTokensFromString(answer) + return answer, totalTokens, nil + } + + history = appendToolResults(history, resp.ToolCalls, cm.ToolConfig.ToolCallSession) + } + + // Exceeded max rounds + history = append(history, Message{ + Role: "user", + Content: fmt.Sprintf("Exceed max rounds: %d", maxRounds), + }) + cfg := *chatCfg + resp, err := cm.ModelDriver.ChatWithMessages(*cm.ModelName, history, cm.APIConfig, &cfg) + if err != nil { + return "", totalTokens, fmt.Errorf("final call: %w", err) + } + if resp == nil || resp.Answer == nil { + return "", totalTokens, fmt.Errorf("final call: no answer") + } + totalTokens += tokenizer.NumTokensFromString(*resp.Answer) + return *resp.Answer, totalTokens, nil +} + +// ChatStreamlyWithTools runs the streaming tool-calling loop. +func (cm *ChatModel) ChatStreamlyWithTools(ctx context.Context, system string, history []Message, chatCfg *ChatConfig, sender func(*string, *string) error) (int, error) { + tc := cm.ToolConfig + if tc == nil { + return 0, fmt.Errorf("ChatStreamlyWithTools called without bound tools") + } + + var toolsList interface{} + if err := json.Unmarshal([]byte(tc.Tools), &toolsList); err != nil { + return 0, fmt.Errorf("failed to parse tools JSON: %w", err) + } + + maxRounds := tc.MaxRounds + if maxRounds <= 0 { + maxRounds = defaultMaxRounds + } + maxRetries := tc.MaxRetries + if maxRetries <= 0 { + maxRetries = defaultMaxRetries + } + + if system != "" && len(history) > 0 && history[0].Role != "system" { + history = append([]Message{{Role: "system", Content: system}}, history...) + } + + baseHistory := make([]Message, len(history)) + copy(baseHistory, history) + + for attempt := 0; attempt < maxRetries; attempt++ { + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + } + + h := make([]Message, len(baseHistory)) + copy(h, baseHistory) + + totalTokens, err := runStreamToolLoop(ctx, cm, h, toolsList, chatCfg, maxRounds, sender) + if err == nil { + return totalTokens, nil + } + } + return 0, fmt.Errorf("ChatStreamlyWithTools failed after %d retries", maxRetries) +} + +func runStreamToolLoop(ctx context.Context, cm *ChatModel, history []Message, toolsList interface{}, chatCfg *ChatConfig, maxRounds int, sender func(*string, *string) error) (int, error) { + var totalTokens int + + for round := 0; round <= maxRounds; round++ { + select { + case <-ctx.Done(): + return totalTokens, ctx.Err() + default: + } + cfg := *chatCfg + cfg.Tools = toolsList + tcChoice := "auto" + cfg.ToolChoice = &tcChoice + cfg.Stream = boolPtr(true) + var tcs []map[string]interface{} + cfg.ToolCallsResult = &tcs + + reasoningStarted := false + var answer string + var pendingThinkClose bool + + err := cm.ModelDriver.ChatStreamlyWithSender(*cm.ModelName, history, cm.APIConfig, &cfg, func(delta *string, reason *string) error { + if reason != nil && *reason != "" { + if !reasoningStarted { + reasoningStarted = true + thinkOpen := "" + if e := sender(&thinkOpen, nil); e != nil { + return e + } + } + pendingThinkClose = true + return sender(reason, nil) + } + // Reasoning ended, close the think block if open + if pendingThinkClose { + pendingThinkClose = false + thinkClose := "" + if e := sender(&thinkClose, nil); e != nil { + return e + } + } + if delta != nil && *delta != "" { + if *delta == "[DONE]" { + return nil + } + totalTokens += tokenizer.NumTokensFromString(*delta) + answer += *delta + if e := sender(delta, nil); e != nil { + return e + } + } + return nil + }) + // Close any unclosed think block after stream completes + if pendingThinkClose { + pendingThinkClose = false + thinkClose := "" + if e := sender(&thinkClose, nil); e != nil { + return totalTokens, e + } + } + if err != nil { + return totalTokens, fmt.Errorf("round %d: %w", round, err) + } + + var toolCalls []map[string]interface{} + if cfg.ToolCallsResult != nil { + toolCalls = *cfg.ToolCallsResult + } + + if answer != "" && len(toolCalls) == 0 { + return totalTokens, nil + } + if len(toolCalls) == 0 { + return totalTokens, fmt.Errorf("round %d: no content and no tool_calls", round) + } + + history = appendToolResults(history, toolCalls, cm.ToolConfig.ToolCallSession) + } + + // Exceeded max rounds + history = append(history, Message{ + Role: "user", + Content: fmt.Sprintf("Exceed max rounds: %d", maxRounds), + }) + cfg := *chatCfg + cfg.Stream = boolPtr(true) + return totalTokens, cm.ModelDriver.ChatStreamlyWithSender(*cm.ModelName, history, cm.APIConfig, &cfg, sender) +} + +// appendToolResults executes tool calls concurrently, appends the assistant +// message with tool_calls and individual tool result messages to history. +func appendToolResults(history []Message, toolCalls []map[string]interface{}, session ToolCallSession) []Message { + if session == nil { + history = append(history, Message{ + Role: "assistant", + Content: nil, + ToolCalls: toolCalls, + }) + for _, tc := range toolCalls { + tcID, _ := tc["id"].(string) + history = append(history, Message{ + Role: "tool", + Content: "Error: no tool session configured", + ToolCallID: tcID, + }) + } + return history + } + var mu sync.Mutex + var wg sync.WaitGroup + type toolResult struct { + index int + tcID string + content string + } + results := make([]toolResult, len(toolCalls)) + + for i, tc := range toolCalls { + wg.Add(1) + go func(idx int, tcMap map[string]interface{}) { + defer wg.Done() + var result toolResult + result.index = idx + fn, ok := tcMap["function"].(map[string]interface{}) + if !ok { + mu.Lock() + results[idx] = result + mu.Unlock() + return + } + name, _ := fn["name"].(string) + argsStr, _ := fn["arguments"].(string) + result.tcID, _ = tcMap["id"].(string) + + var args map[string]interface{} + if err := json.Unmarshal([]byte(argsStr), &args); err != nil { + args = map[string]interface{}{"raw_arguments": argsStr} + } + + res, err := session.ToolCall(name, args) + if err != nil { + result.content = fmt.Sprintf("Error: %s", err.Error()) + } else { + result.content = res + } + mu.Lock() + results[idx] = result + mu.Unlock() + }(i, tc) + } + wg.Wait() + + history = append(history, Message{ + Role: "assistant", + Content: nil, + ToolCalls: toolCalls, + }) + + for _, r := range results { + history = append(history, Message{ + Role: "tool", + Content: r.content, + ToolCallID: r.tcID, + }) + } + + return history +} + +func boolPtr(b bool) *bool { + return &b +} diff --git a/internal/entity/models/openai.go b/internal/entity/models/openai.go index f5010da859..6c97b1877c 100644 --- a/internal/entity/models/openai.go +++ b/internal/entity/models/openai.go @@ -73,21 +73,27 @@ func (o *OpenAIModel) ChatWithMessages(modelName string, messages []Message, api baseURL = strings.TrimSuffix(baseURL, "/") url := fmt.Sprintf("%s/%s", baseURL, o.baseModel.URLSuffix.Chat) - // Convert messages to the format expected by the API + // Convert messages to API format (supports multimodal content) apiMessages := make([]map[string]interface{}, len(messages)) for i, msg := range messages { - apiMessages[i] = map[string]interface{}{ + apiMsg := map[string]interface{}{ "role": msg.Role, "content": msg.Content, } + if msg.ToolCallID != "" { + apiMsg["tool_call_id"] = msg.ToolCallID + } + if len(msg.ToolCalls) > 0 { + apiMsg["tool_calls"] = msg.ToolCalls + } + apiMessages[i] = apiMsg } // Build request body reqBody := map[string]interface{}{ - "model": modelName, - "messages": apiMessages, - "stream": false, - "temperature": 1, + "model": modelName, + "messages": apiMessages, + "stream": false, } if chatModelConfig != nil { @@ -106,6 +112,21 @@ func (o *OpenAIModel) ChatWithMessages(modelName string, messages []Message, api if chatModelConfig.Stop != nil { reqBody["stop"] = *chatModelConfig.Stop } + + if chatModelConfig.Tools != nil { + reqBody["tools"] = chatModelConfig.Tools + tc := "auto" + if chatModelConfig.ToolChoice != nil { + tc = *chatModelConfig.ToolChoice + } + reqBody["tool_choice"] = tc + } + } + + // Qwen3 family: disable thinking by default (matches Python's + // _apply_model_family_policies in rag/llm/chat_model.py:119-121). + if strings.Contains(strings.ToLower(modelName), "qwen3") && (chatModelConfig == nil || chatModelConfig.Thinking == nil) { + reqBody["enable_thinking"] = false } jsonData, err := json.Marshal(reqBody) @@ -160,9 +181,9 @@ func (o *OpenAIModel) ChatWithMessages(modelName string, messages []Message, api return nil, fmt.Errorf("invalid message format") } - content, ok := messageMap["content"].(string) - if !ok { - return nil, fmt.Errorf("invalid content format") + var content string + if c, ok := messageMap["content"].(string); ok { + content = c } // OpenAI reasoning models (o-series and similar) return reasoning text in @@ -175,9 +196,19 @@ func (o *OpenAIModel) ChatWithMessages(modelName string, messages []Message, api } } + var toolCalls []map[string]interface{} + if tcs, ok := messageMap["tool_calls"].([]interface{}); ok { + for _, tc := range tcs { + if tcMap, ok := tc.(map[string]interface{}); ok { + toolCalls = append(toolCalls, tcMap) + } + } + } + chatResponse := &ChatResponse{ Answer: &content, ReasonContent: &reasonContent, + ToolCalls: toolCalls, } return chatResponse, nil @@ -200,13 +231,20 @@ func (o *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag baseURL = strings.TrimSuffix(baseURL, "/") url := fmt.Sprintf("%s/%s", baseURL, o.baseModel.URLSuffix.Chat) - // Convert messages to API format (supports multimodal content) + // Convert messages to API format (supports multimodal content and tool messages) apiMessages := make([]map[string]interface{}, len(messages)) for i, msg := range messages { - apiMessages[i] = map[string]interface{}{ + apiMsg := map[string]interface{}{ "role": msg.Role, "content": msg.Content, } + if msg.ToolCallID != "" { + apiMsg["tool_call_id"] = msg.ToolCallID + } + if len(msg.ToolCalls) > 0 { + apiMsg["tool_calls"] = msg.ToolCalls + } + apiMessages[i] = apiMsg } // Build request body with streaming on by default @@ -236,6 +274,20 @@ func (o *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag if chatModelConfig.Stop != nil { reqBody["stop"] = *chatModelConfig.Stop } + + if chatModelConfig.Tools != nil { + reqBody["tools"] = chatModelConfig.Tools + tc := "auto" + if chatModelConfig.ToolChoice != nil { + tc = *chatModelConfig.ToolChoice + } + reqBody["tool_choice"] = tc + } + } + + // Qwen3 family: disable thinking by default. + if strings.Contains(strings.ToLower(modelName), "qwen3") && (chatModelConfig == nil || chatModelConfig.Thinking == nil) { + reqBody["enable_thinking"] = false } jsonData, err := json.Marshal(reqBody) @@ -263,20 +315,75 @@ func (o *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag } sawTerminal := false - done, err := ParseSSEStream[map[string]interface{}](resp.Body, func(event map[string]interface{}) error { + accumulatedToolCalls := make(map[int]map[string]interface{}) + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + + // SSE data line starts with "data:" + if !strings.HasPrefix(line, "data:") { + continue + } + + // Extract JSON after "data:" + data := strings.TrimSpace(line[5:]) + + // [DONE] marks the end of the stream + if data == "[DONE]" { + sawTerminal = true + break + } + + // Parse the JSON event + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + choices, ok := event["choices"].([]interface{}) if !ok || len(choices) == 0 { - return nil + continue } firstChoice, ok := choices[0].(map[string]interface{}) if !ok { - return nil + continue } delta, ok := firstChoice["delta"].(map[string]interface{}) if !ok { - return nil + continue + } + + // Accumulate streaming tool_call deltas (mirrors Python's + // async_chat_streamly_with_tools in rag/llm/chat_model.py:500-509). + if tcs, ok := delta["tool_calls"].([]interface{}); ok { + for _, tc := range tcs { + if tcMap, ok := tc.(map[string]interface{}); ok { + idxF, ok := tcMap["index"].(float64) + if !ok { + continue + } + idx := int(idxF) + existing, hasExisting := accumulatedToolCalls[idx] + if hasExisting { + if fn, ok := tcMap["function"].(map[string]interface{}); ok { + if args, ok := fn["arguments"].(string); ok { + if ef, ok := existing["function"].(map[string]interface{}); ok { + if ea, ok := ef["arguments"].(string); ok { + ef["arguments"] = ea + args + } else { + ef["arguments"] = args + } + } + } + } + } else { + accumulatedToolCalls[idx] = cloneMap(tcMap) + } + } + } + continue // tool_call deltas don't carry content } reasoningContent, ok := delta["reasoning_content"].(string) @@ -297,15 +404,23 @@ func (o *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag if ok && finishReason != "" { sawTerminal = true } - return nil - }) - if err != nil { + } + if err := scanner.Err(); err != nil { return fmt.Errorf("failed to scan response body: %w", err) } - if !done && !sawTerminal { + if !sawTerminal { return fmt.Errorf("openai: stream ended before [DONE] or finish_reason") } + // Populate ToolCallsResult with accumulated streaming tool_calls. + if len(accumulatedToolCalls) > 0 && chatModelConfig != nil { + tcs := make([]map[string]interface{}, 0, len(accumulatedToolCalls)) + for _, tc := range accumulatedToolCalls { + tcs = append(tcs, tc) + } + chatModelConfig.ToolCallsResult = &tcs + } + // Send the [DONE] marker for OpenAI compatibility endOfStream := "[DONE]" if err := sender(&endOfStream, nil); err != nil { @@ -907,3 +1022,11 @@ func (o *OpenAIModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) func (o *OpenAIModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { return nil, fmt.Errorf("%s, no such method", o.Name()) } + +func cloneMap(m map[string]interface{}) map[string]interface{} { + cp := make(map[string]interface{}, len(m)) + for k, v := range m { + cp[k] = v + } + return cp +} diff --git a/internal/entity/models/siliconflow.go b/internal/entity/models/siliconflow.go index 238834b841..79bca41406 100644 --- a/internal/entity/models/siliconflow.go +++ b/internal/entity/models/siliconflow.go @@ -93,10 +93,9 @@ func (s *SiliconflowModel) ChatWithMessages(modelName string, messages []Message // Build request body reqBody := map[string]interface{}{ - "model": modelName, - "messages": apiMessages, - "stream": false, - "temperature": 1, + "model": modelName, + "messages": apiMessages, + "stream": false, } if chatModelConfig != nil { @@ -119,18 +118,12 @@ func (s *SiliconflowModel) ChatWithMessages(modelName string, messages []Message if chatModelConfig.Stop != nil { reqBody["stop"] = *chatModelConfig.Stop } + } - if chatModelConfig.Thinking != nil { - if *chatModelConfig.Thinking { - reqBody["thinking"] = map[string]interface{}{ - "type": "enabled", - } - } else { - reqBody["thinking"] = map[string]interface{}{ - "type": "disabled", - } - } - } + // Qwen3 family: disable thinking by default (matches Python's + // _apply_model_family_policies in rag/llm/chat_model.py:119-121). + if strings.Contains(strings.ToLower(modelName), "qwen3") && (chatModelConfig == nil || chatModelConfig.Thinking == nil) { + reqBody["enable_thinking"] = false } jsonData, err := json.Marshal(reqBody) @@ -243,10 +236,9 @@ func (s *SiliconflowModel) ChatStreamlyWithSender(modelName string, messages []M // Build request body with streaming enabled reqBody := map[string]interface{}{ - "model": modelName, - "messages": apiMessages, - "stream": true, - "temperature": 1, + "model": modelName, + "messages": apiMessages, + "stream": true, } if chatModelConfig != nil { @@ -273,18 +265,12 @@ func (s *SiliconflowModel) ChatStreamlyWithSender(modelName string, messages []M if chatModelConfig.Stop != nil { reqBody["stop"] = *chatModelConfig.Stop } + } - if chatModelConfig.Thinking != nil { - if *chatModelConfig.Thinking { - reqBody["thinking"] = map[string]interface{}{ - "type": "enabled", - } - } else { - reqBody["thinking"] = map[string]interface{}{ - "type": "disabled", - } - } - } + // Qwen3 family: disable thinking by default (matches Python's + // _apply_model_family_policies in rag/llm/chat_model.py:119-121). + if strings.Contains(strings.ToLower(modelName), "qwen3") && (chatModelConfig == nil || chatModelConfig.Thinking == nil) { + reqBody["enable_thinking"] = false } jsonData, err := json.Marshal(reqBody) diff --git a/internal/entity/models/types.go b/internal/entity/models/types.go index 863a33a0d0..36d2ae3610 100644 --- a/internal/entity/models/types.go +++ b/internal/entity/models/types.go @@ -1,5 +1,7 @@ package models +import "encoding/json" + // Message represents a chat message with role and content // // Content is interface{} to support different formats: @@ -7,8 +9,15 @@ package models // - []interface{}: multimodal content array where each element is map[string]interface{} // (e.g., [{"type": "text", "text": "..."}, {"type": "image_url", "image_url": {"url": "..."}}]) type Message struct { - Role string `json:"role"` - Content interface{} `json:"content"` + Role string `json:"role"` + Content interface{} `json:"content"` + ToolCallID string `json:"tool_call_id,omitempty"` + ToolCalls []map[string]interface{} `json:"tool_calls,omitempty"` +} + +// ToolCallSession mirrors Python's common.mcp_tool_call_conn.ToolCallSession protocol. +type ToolCallSession interface { + ToolCall(name string, arguments map[string]interface{}) (string, error) } // EmbeddingModel interface for embedding models @@ -48,8 +57,9 @@ type ModelDriver interface { } type ChatResponse struct { - Answer *string `json:"answer"` - ReasonContent *string `json:"reason_content"` + Answer *string `json:"answer"` + ReasonContent *string `json:"reason_content"` + ToolCalls []map[string]interface{} `json:"tool_calls,omitempty"` } type EmbeddingData struct { @@ -130,17 +140,20 @@ type URLSuffix struct { } type ChatConfig struct { - Stream *bool - Vision *bool - Thinking *bool - MaxTokens *int - Temperature *float64 - TopP *float64 - DoSample *bool - Stop *[]string - ModelClass *string - Effort *string - Verbosity *string + Stream *bool + Vision *bool + Thinking *bool + MaxTokens *int + Temperature *float64 + TopP *float64 + DoSample *bool + Stop *[]string + ModelClass *string + Effort *string + Verbosity *string + Tools interface{} `json:"tools,omitempty"` + ToolChoice *string `json:"tool_choice,omitempty"` + ToolCallsResult *[]map[string]interface{} `json:"-"` } type APIConfig struct { @@ -211,11 +224,20 @@ func (r *RerankModel) Rerank(query string, texts []string, apiConfig *APIConfig, return r.ModelDriver.Rerank(r.ModelName, query, texts, apiConfig, rerankConfig) } +// ToolConfig bundles tool-calling configuration for a ChatModel. +type ToolConfig struct { + Tools string // JSON-encoded tools list + MaxRounds int // max tool-calling rounds (default: 5) + MaxRetries int // max retries on failure (default: 3) + ToolCallSession ToolCallSession // session that executes tool calls +} + // ChatModel wraps a ModelDriver with chat-specific configuration type ChatModel struct { ModelDriver ModelDriver ModelName *string APIConfig *APIConfig + ToolConfig *ToolConfig } // NewChatModel creates a new ChatModel @@ -226,3 +248,26 @@ func NewChatModel(driver ModelDriver, modelName *string, apiConfig *APIConfig) * APIConfig: apiConfig, } } + +// BindTools registers tools for the ChatModel to call. +// Mirrors Python's Base.bind_tools() in rag/llm/chat_model.py. +func (cm *ChatModel) BindTools(session ToolCallSession, tools interface{}) { + // Serialize tools to JSON if it's a list/map. + toolsJSON := "" + switch v := tools.(type) { + case string: + toolsJSON = v + case []byte: + toolsJSON = string(v) + default: + if b, err := json.Marshal(tools); err == nil { + toolsJSON = string(b) + } + } + cm.ToolConfig = &ToolConfig{ + Tools: toolsJSON, + MaxRounds: defaultMaxRounds, + MaxRetries: defaultMaxRetries, + ToolCallSession: session, + } +} diff --git a/internal/handler/openai_chat.go b/internal/handler/openai_chat.go new file mode 100644 index 0000000000..c2ef962756 --- /dev/null +++ b/internal/handler/openai_chat.go @@ -0,0 +1,125 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "encoding/json" + "io" + "ragflow/internal/common" + "ragflow/internal/service" + + "github.com/gin-gonic/gin" +) + +type OpenAIChatHandler struct { + svc *service.OpenAIChatService +} + +func NewOpenAIChatHandler(svc *service.OpenAIChatService) *OpenAIChatHandler { + return &OpenAIChatHandler{svc: svc} +} + +// OpenAIChatCompletions handles the OpenAI-compatible chat completions route. +// @Summary OpenAI Chat Completions +// @Description OpenAI-compatible chat completions endpoint +// @Tags openai +// @Accept json +// @Produce json +// @Param chat_id path string true "dialog id" +// @Param request body service.OpenAIChatRequest true "chat completion request" +// @Success 200 {object} map[string]interface{} +// @Router /api/v1/openai/{chat_id}/chat/completions [post] +func (h *OpenAIChatHandler) OpenAIChatCompletions(c *gin.Context) { + chatID := c.Param("chat_id") + if chatID == "" { + jsonError(c, common.CodeDataError, "You don't own the chat "+chatID) + return + } + + user, code, msg := GetUser(c) + if code != common.CodeSuccess { + jsonError(c, code, msg) + return + } + + bodyBytes, err := io.ReadAll(c.Request.Body) + if err != nil { + jsonError(c, common.CodeArgumentError, err.Error()) + return + } + + // Parse body into the typed request + var req service.OpenAIChatRequest + if err := json.Unmarshal(bodyBytes, &req); err != nil { + jsonError(c, common.CodeArgumentError, err.Error()) + return + } + + // Messages presence + if len(req.Messages) == 0 { + jsonError(c, common.CodeDataError, "You have to provide messages.") + return + } + + // extra_body shape validation + extraBody, extraBodyOK := req.ExtraBody.(map[string]interface{}) + if req.ExtraBody != nil && !extraBodyOK { + jsonError(c, common.CodeArgumentError, "extra_body must be an object.") + return + } + + // reference_metadata shape validation + if extraBody != nil { + if rm, ok := extraBody["reference_metadata"].(map[string]interface{}); ok { + if rawFields, has := rm["fields"]; has { + if rawArr, ok := rawFields.([]interface{}); !ok { + jsonError(c, common.CodeArgumentError, "reference_metadata.fields must be an array.") + return + } else { + for _, item := range rawArr { + if _, ok := item.(string); !ok { + jsonError(c, common.CodeArgumentError, "reference_metadata.fields must be an array.") + return + } + } + } + } + } + } + + // metadata_condition shape validation + if extraBody != nil { + if mc, ok := extraBody["metadata_condition"]; ok && mc != nil { + if _, ok := mc.(map[string]interface{}); !ok { + jsonError(c, common.CodeArgumentError, "metadata_condition must be an object.") + return + } + } + } + + // Last message must be from the user + if last := req.Messages[len(req.Messages)-1]; last != nil { + if role, _ := last["role"].(string); role != "user" { + jsonError(c, common.CodeDataError, "The last content of this conversation is not from user.") + return + } + } + + // All early-rejection checks passed. Delegate to the service for the + // actual LLM call. + h.svc.OpenAIChatCompletions(c, user.ID, chatID, bodyBytes) +} diff --git a/internal/handler/openai_chat_test.go b/internal/handler/openai_chat_test.go new file mode 100644 index 0000000000..2e4e6177b7 --- /dev/null +++ b/internal/handler/openai_chat_test.go @@ -0,0 +1,254 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + + "ragflow/internal/entity" + "ragflow/internal/service" +) + +// TestNormalizeMessageContent and friends moved to +// internal/service/openai_chat_test.go as TestService_NormalizeMessageContent_* +// (the helpers themselves moved to the service package). TestWriteSSE +// also moved to the service package as TestService_WriteSSE_FormatAndFlush. +// Handler tests here focus on the HTTP boundary: rejection at parse / +// presence / forbidden-key checks. + +// fakeOpenAIUser injects a real *entity.User into the context so GetUser +// succeeds. Without this, the handler short-circuits with +// "User not found" before any validation runs. +func fakeOpenAIUser(c *gin.Context) { + c.Set("user", &entity.User{ID: "u1", Email: "u@x"}) +} + +// newOpenAITestContext builds a test context with a real user, the +// chat_id path param, and a POST request carrying the given JSON body. +func newOpenAITestContext(t *testing.T, chatID, body string) (*gin.Context, *httptest.ResponseRecorder) { + t.Helper() + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request, _ = http.NewRequest(http.MethodPost, + "/api/v1/openai/"+chatID+"/chat/completions", + bytes.NewBufferString(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Params = gin.Params{{Key: "chat_id", Value: chatID}} + fakeOpenAIUser(c) + return c, w +} + +// TestChatCompletions_RejectsMissingMessages pins down the validation +// rule "You have to provide messages." (openai_api.py:255-256). The +// peek-and-discard parse in the handler (see OpenAIChatCompletions) +// rejects this BEFORE the service is called, so the test doesn't +// need a DB. +func TestChatCompletions_RejectsMissingMessages(t *testing.T) { + h := NewOpenAIChatHandler(service.NewOpenAIChatService()) + c, w := newOpenAITestContext(t, "c1", `{"model":"model"}`) + + h.OpenAIChatCompletions(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected HTTP 200 (Python convention), got %d", w.Code) + } + respBody := w.Body.String() + if !strings.Contains(respBody, "You have to provide messages.") { + t.Fatalf("expected 'You have to provide messages.' in body, got %s", respBody) + } +} + +// TestChatCompletions_DefaultsMissingModelToModel pins the +// Go-specific behavior: `model` is OPTIONAL on the openai_chat +// endpoint. If absent or empty, the handler injects the OpenAI +// compat sentinel "model" (which the service resolves to the +// dialog's default LLM). Python enforces the +// OpenAI spec strictly via @validate_request("model", "messages") +// at openai_api.py:237; the Go side intentionally relaxes that +// so callers can use the dialog's default without typing +// `"model": "model"` explicitly. +// +// The handler's check is the "model" is defaulted, not the +// service's success. We recover from the service's expected DB +// panic and only assert on what the handler wrote. +func TestChatCompletions_DefaultsMissingModelToModel(t *testing.T) { + h := NewOpenAIChatHandler(service.NewOpenAIChatService()) + c, w := newOpenAITestContext(t, "c1", + `{"messages":[{"role":"user","content":"hi"}]}`) + + // The handler should accept the request and call the service. + // The service will panic on DB access (no DB in tests); we + // recover so the test only asserts the handler's behavior. + func() { + defer func() { + _ = recover() // expected: service panicked on DB call + }() + h.OpenAIChatCompletions(c) + }() + + // The handler must NOT have written a rejection response + // before the service panicked. If the response body has + // "You have to provide messages.", the messages-presence + // check fired by mistake (it shouldn't, since messages IS + // present). + respBody := w.Body.String() + if strings.Contains(respBody, "You have to provide messages.") { + t.Fatalf("missing model should be defaulted, not rejected; got: %s", respBody) + } +} + +// TestChatCompletions_RejectsBadExtraBody pins down the validation +// rule "extra_body must be an object." (openai_api.py:243). +func TestChatCompletions_RejectsBadExtraBody(t *testing.T) { + h := NewOpenAIChatHandler(service.NewOpenAIChatService()) + c, w := newOpenAITestContext(t, "c1", `{ + "model": "model", + "messages": [{"role": "user", "content": "hi"}], + "extra_body": "not an object" + }`) + + h.OpenAIChatCompletions(c) + + respBody := w.Body.String() + if !strings.Contains(respBody, "extra_body must be an object.") { + t.Fatalf("expected 'extra_body must be an object.' in body, got %s", respBody) + } +} + +// TestChatCompletions_RejectsBadMetadataCondition pins down the validation +// rule "metadata_condition must be an object." (openai_api.py:287). +func TestChatCompletions_RejectsBadMetadataCondition(t *testing.T) { + h := NewOpenAIChatHandler(service.NewOpenAIChatService()) + c, w := newOpenAITestContext(t, "c1", `{ + "model": "model", + "messages": [{"role": "user", "content": "hi"}], + "extra_body": {"metadata_condition": "bad"} + }`) + + h.OpenAIChatCompletions(c) + + respBody := w.Body.String() + if !strings.Contains(respBody, "metadata_condition must be an object.") { + t.Fatalf("expected 'metadata_condition must be an object.' in body, got %s", respBody) + } +} + +// TestChatCompletions_RejectsBadReferenceMetadataFields pins down the +// validation rule "reference_metadata.fields must be an array." (openai_api.py:251). +func TestChatCompletions_RejectsBadReferenceMetadataFields(t *testing.T) { + h := NewOpenAIChatHandler(service.NewOpenAIChatService()) + c, w := newOpenAITestContext(t, "c1", `{ + "model": "model", + "messages": [{"role": "user", "content": "hi"}], + "extra_body": {"reference_metadata": {"fields": "author"}} + }`) + + h.OpenAIChatCompletions(c) + + respBody := w.Body.String() + if !strings.Contains(respBody, "reference_metadata.fields must be an array.") { + t.Fatalf("expected 'reference_metadata.fields must be an array.' in body, got %s", respBody) + } +} + +// TestChatCompletions_RejectsLastMessageNotUser pins down the validation +// rule "The last content of this conversation is not from user." (openai_api.py:261). +func TestChatCompletions_RejectsLastMessageNotUser(t *testing.T) { + h := NewOpenAIChatHandler(service.NewOpenAIChatService()) + c, w := newOpenAITestContext(t, "c1", `{ + "model": "model", + "messages": [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "world"}] + }`) + + h.OpenAIChatCompletions(c) + + respBody := w.Body.String() + if !strings.Contains(respBody, "The last content of this conversation is not from user.") { + t.Fatalf("expected 'The last content of this conversation is not from user.' in body, got %s", respBody) + } +} + +// TestChatCompletions_RejectsInvalidJSON pins down the JSON-parse failure +// path. We expect a 4xx-ish error code (Gin's binding error returns +// 400-equivalent message; we accept any non-empty error). +func TestChatCompletions_RejectsInvalidJSON(t *testing.T) { + h := NewOpenAIChatHandler(service.NewOpenAIChatService()) + c, w := newOpenAITestContext(t, "c1", `{ not json`) + + h.OpenAIChatCompletions(c) + + if w.Body.Len() == 0 { + t.Fatalf("expected non-empty error body for invalid JSON") + } +} + +// TestChatCompletions_SilentlyDropsTopLevelStop verifies that top-level +// `stop` is silently dropped rather than rejected. The field is not declared +// on OpenAIChatRequest, so Go's json.Unmarshal discards it — matching the +// OpenAI server convention of ignoring unknown request fields. The CLI parser +// rejects `stop` at parse time for CLI callers. +// +// The payload ends with an assistant turn so validation trips the early +// "last content not from user" rejector before reaching the DB. The +// rejection message proves we got past the stop check; the absence of +// "not supported" proves the field was silently dropped. +func TestChatCompletions_SilentlyDropsTopLevelStop(t *testing.T) { + h := NewOpenAIChatHandler(service.NewOpenAIChatService()) + c, w := newOpenAITestContext(t, "c1", `{ + "model": "model", + "messages": [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "world"}], + "stop": ["END"] + }`) + + h.OpenAIChatCompletions(c) + + respBody := w.Body.String() + if strings.Contains(respBody, "not supported") { + t.Fatalf("did not expect 'stop' rejection, got %s", respBody) + } + if !strings.Contains(respBody, "The last content of this conversation is not from user.") { + t.Fatalf("expected request to flow past stop check to last-message validator, got %s", respBody) + } +} + +// TestChatCompletions_SilentlyDropsTopLevelUser verifies that top-level +// `user` is silently dropped (same rationale and structure as `stop` above). +func TestChatCompletions_SilentlyDropsTopLevelUser(t *testing.T) { + h := NewOpenAIChatHandler(service.NewOpenAIChatService()) + c, w := newOpenAITestContext(t, "c1", `{ + "model": "model", + "messages": [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "world"}], + "user": "session-abc" + }`) + + h.OpenAIChatCompletions(c) + + respBody := w.Body.String() + if strings.Contains(respBody, "not supported") { + t.Fatalf("did not expect 'user' rejection, got %s", respBody) + } + if !strings.Contains(respBody, "The last content of this conversation is not from user.") { + t.Fatalf("expected request to flow past user check to last-message validator, got %s", respBody) + } +} diff --git a/internal/router/router.go b/internal/router/router.go index f1e9f6ec5a..f095364f0e 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -33,6 +33,7 @@ type Router struct { chunkHandler *handler.ChunkHandler llmHandler *handler.LLMHandler chatHandler *handler.ChatHandler + openaiChatHandler *handler.OpenAIChatHandler chatSessionHandler *handler.ChatSessionHandler connectorHandler *handler.ConnectorHandler searchHandler *handler.SearchHandler @@ -77,6 +78,7 @@ func NewRouter( modelHandler *handler.ModelHandler, fileCommitHandler *handler.FileCommitHandler, adminRuntimeHandler *handler.AdminRuntimeHandler, + openaiChatHandler *handler.OpenAIChatHandler, ) *Router { return &Router{ authHandler: authHandler, @@ -89,6 +91,7 @@ func NewRouter( chunkHandler: chunkHandler, llmHandler: llmHandler, chatHandler: chatHandler, + openaiChatHandler: openaiChatHandler, chatSessionHandler: chatSessionHandler, connectorHandler: connectorHandler, searchHandler: searchHandler, @@ -242,6 +245,12 @@ func (r *Router) Setup(engine *gin.Engine) { chats.GET("/:chat_id/sessions", r.chatSessionHandler.ListChatSessions) } + // OpenAI-compatible chat completions route + openai := v1.Group("/openai") + { + openai.POST("/:chat_id/chat/completions", r.openaiChatHandler.OpenAIChatCompletions) + } + // Searchbot routes v1.POST("/searchbots/related_questions", r.searchBotHandler.Handle) v1.POST("/searchbots/retrieval_test", r.searchBotHandler.RetrievalTest) diff --git a/internal/service/ask_service.go b/internal/service/ask_service.go index bd1cab5c57..87e037c9fa 100644 --- a/internal/service/ask_service.go +++ b/internal/service/ask_service.go @@ -216,6 +216,8 @@ func toFloat64Slice(v interface{}) []float64 { for i, x := range val { if f, ok := x.(float64); ok { out[i] = f + } else { + return nil } } return out @@ -224,5 +226,5 @@ func toFloat64Slice(v interface{}) []float64 { } } -func ptrInt(v int) *int { return &v } +func ptrInt(v int) *int { return &v } func ptrFloat64(v float64) *float64 { return &v } diff --git a/internal/service/chat.go b/internal/service/chat.go index 122b0c227b..4e8915a3cc 100644 --- a/internal/service/chat.go +++ b/internal/service/chat.go @@ -209,18 +209,65 @@ type ParameterConfig struct { // PromptConfig prompt configuration type PromptConfig struct { - System string `json:"system"` - Prologue string `json:"prologue"` - Parameters []ParameterConfig `json:"parameters"` - EmptyResponse string `json:"empty_response"` - TavilyAPIKey string `json:"tavily_api_key,omitempty"` - Keyword bool `json:"keyword,omitempty"` - Quote bool `json:"quote,omitempty"` - Reasoning bool `json:"reasoning,omitempty"` - RefineMultiturn bool `json:"refine_multiturn,omitempty"` - TocEnhance bool `json:"toc_enhance,omitempty"` - TTS bool `json:"tts,omitempty"` - UseKG bool `json:"use_kg,omitempty"` + System *string `json:"system"` + Prologue *string `json:"prologue"` + Parameters []ParameterConfig `json:"parameters"` + EmptyResponse *string `json:"empty_response"` + TavilyAPIKey string `json:"tavily_api_key,omitempty"` + Keyword *bool `json:"keyword,omitempty"` + Quote *bool `json:"quote,omitempty"` + Reasoning *bool `json:"reasoning,omitempty"` + RefineMultiturn *bool `json:"refine_multiturn,omitempty"` + TocEnhance *bool `json:"toc_enhance,omitempty"` + TTS *bool `json:"tts,omitempty"` + UseKG *bool `json:"use_kg,omitempty"` + CrossLanguages []string `json:"cross_languages,omitempty"` + ReferenceMetadata map[string]interface{} `json:"reference_metadata,omitempty"` +} + +const ( + pyDefaultSystemPrompt = "You are an intelligent assistant. Please summarize the content of the dataset to answer the question. " + + "Please list the data in the dataset and answer in detail. " + + "When all dataset content is irrelevant to the question, your answer must include the sentence " + + `"The answer you are looking for is not found in the dataset!" ` + + "Answers need to consider chat history.\n" + + " Here is the knowledge base:\n" + + " {knowledge}\n" + + " The above is the knowledge base." + + pyDefaultPrologue = "Hi! I'm your assistant. What can I do for you?" + pyDefaultEmptyResponse = "Sorry! No relevant content was found in the knowledge base!" +) + +// applyPromptDefaults replaces missing keys with default values +func applyPromptDefaults(p *PromptConfig) { + if p.System == nil || *p.System == "" { + s := pyDefaultSystemPrompt + p.System = &s + } + if p.Prologue == nil { + s := pyDefaultPrologue + p.Prologue = &s + } + if p.Parameters == nil { + p.Parameters = []ParameterConfig{{Key: "knowledge", Optional: false}} + } + if p.EmptyResponse == nil { + s := pyDefaultEmptyResponse + p.EmptyResponse = &s + } + if p.Quote == nil { + t := true + p.Quote = &t + } + if p.RefineMultiturn == nil { + t := true + p.RefineMultiturn = &t + } + if p.TTS == nil { + f := false + p.TTS = &f + } } // SetDialogRequest set chat request @@ -347,12 +394,16 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo kbIDs = []string{} } + // Apply default prompt config on create only + if isCreate { + applyPromptDefaults(promptConfig) + } + // Set default parameters for datasets with knowledge retrieval // Check if parameters is missing or empty and kb_ids is provided if len(kbIDs) > 0 && (promptConfig.Parameters == nil || len(promptConfig.Parameters) == 0) { // Check if system prompt uses {knowledge} placeholder - if strings.Contains(promptConfig.System, "{knowledge}") { - // Set default parameters for any dataset with knowledge placeholder + if promptConfig.System != nil && strings.Contains(*promptConfig.System, "{knowledge}") { promptConfig.Parameters = []ParameterConfig{ {Key: "knowledge", Optional: false}, } @@ -361,7 +412,8 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo // For update: validate that {knowledge} is not used when no KBs or Tavily if !isCreate { - if len(kbIDs) == 0 && promptConfig.TavilyAPIKey == "" && strings.Contains(promptConfig.System, "{knowledge}") { + if len(kbIDs) == 0 && promptConfig.TavilyAPIKey == "" && + promptConfig.System != nil && strings.Contains(*promptConfig.System, "{knowledge}") { return nil, errors.New("Please remove `{knowledge}` in system prompt since no dataset / Tavily used here") } } @@ -372,7 +424,7 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo continue } placeholder := fmt.Sprintf("{%s}", p.Key) - if !strings.Contains(promptConfig.System, placeholder) { + if promptConfig.System == nil || !strings.Contains(*promptConfig.System, placeholder) { return nil, fmt.Errorf("Parameter '%s' is not used", p.Key) } } @@ -410,22 +462,47 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo llmID = tenant.LLMID } - // Convert prompt config to JSONMap with all fields - promptConfigMap := entity.JSONMap{ - "system": promptConfig.System, - "prologue": promptConfig.Prologue, - "empty_response": promptConfig.EmptyResponse, - "keyword": promptConfig.Keyword, - "quote": promptConfig.Quote, - "reasoning": promptConfig.Reasoning, - "refine_multiturn": promptConfig.RefineMultiturn, - "toc_enhance": promptConfig.TocEnhance, - "tts": promptConfig.TTS, - "use_kg": promptConfig.UseKG, + // Convert prompt config to JSONMap + promptConfigMap := entity.JSONMap{} + if promptConfig.System != nil && *promptConfig.System != "" { + promptConfigMap["system"] = *promptConfig.System + } + if promptConfig.Prologue != nil { + promptConfigMap["prologue"] = *promptConfig.Prologue + } + if promptConfig.EmptyResponse != nil { + promptConfigMap["empty_response"] = *promptConfig.EmptyResponse + } + if promptConfig.Quote != nil { + promptConfigMap["quote"] = *promptConfig.Quote + } + if promptConfig.RefineMultiturn != nil { + promptConfigMap["refine_multiturn"] = *promptConfig.RefineMultiturn + } + if promptConfig.TTS != nil { + promptConfigMap["tts"] = *promptConfig.TTS + } + if promptConfig.Keyword != nil { + promptConfigMap["keyword"] = *promptConfig.Keyword + } + if promptConfig.Reasoning != nil { + promptConfigMap["reasoning"] = *promptConfig.Reasoning + } + if promptConfig.TocEnhance != nil { + promptConfigMap["toc_enhance"] = *promptConfig.TocEnhance + } + if promptConfig.UseKG != nil { + promptConfigMap["use_kg"] = *promptConfig.UseKG } if promptConfig.TavilyAPIKey != "" { promptConfigMap["tavily_api_key"] = promptConfig.TavilyAPIKey } + if len(promptConfig.CrossLanguages) > 0 { + promptConfigMap["cross_languages"] = promptConfig.CrossLanguages + } + if len(promptConfig.ReferenceMetadata) > 0 { + promptConfigMap["reference_metadata"] = promptConfig.ReferenceMetadata + } if len(promptConfig.Parameters) > 0 { params := make([]map[string]interface{}, len(promptConfig.Parameters)) for i, p := range promptConfig.Parameters { diff --git a/internal/service/chat_pipeline.go b/internal/service/chat_pipeline.go new file mode 100644 index 0000000000..c649fca3c2 --- /dev/null +++ b/internal/service/chat_pipeline.go @@ -0,0 +1,4275 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "ragflow/internal/common" + "ragflow/internal/engine" + "ragflow/internal/entity" + modelModule "ragflow/internal/entity/models" + "ragflow/internal/service/kg" + "ragflow/internal/service/nlp" + "regexp" + "sort" + "strings" + "time" + + "ragflow/internal/dao" + + "go.uber.org/zap" +) + +// ChatPipelineService is the shared RAG chat pipeline engine used by both +// the OpenAI-compatible endpoint (/api/v1/openai//chat/completions) +// and the regular chat completion endpoint (/api/v1/chat/completions). +// +// It owns the core retrieval → generation pipeline (AsyncChat, AsyncChatSolo) +// and all their supporting helpers. Callers (OpenAIChatService, ChatSessionService) +// compose it to avoid code duplication. +type ChatPipelineService struct { + ModelProviderSvc *ModelProviderService + MetadataSvc *MetadataService + KbService *KnowledgebaseService +} + +// NewChatPipelineService creates a new ChatPipelineService with all required dependencies. +func NewChatPipelineService() *ChatPipelineService { + return &ChatPipelineService{ + ModelProviderSvc: NewModelProviderService(), + MetadataSvc: NewMetadataService(), + KbService: NewKnowledgebaseService(), + } +} + +// --------------------------------------------------------------------------- +// AsyncChatResult mirrors the dicts yielded by Python's async_chat / +// async_chat_solo. The handler translates these into OpenAIStreamEvent or +// builds a non-streaming OpenAICompletionResponse. +// --------------------------------------------------------------------------- + +// AsyncChatResult is a single yield from the chat pipeline. +// +// Reasoning carries chain-of-thought text routed by the driver to a +// separate `reason` channel (e.g. OpenAI's `delta.reasoning_content` +// from Qwen / SiliconFlow). It is kept distinct from Answer so the +// SSE handler can map it to `delta.reasoning_content` rather than +// `delta.content`. Mirrors Python's _async_chat_streamly, which wraps +// reasoning_content in markers (rag/llm/chat_model.py:226-232) +// so _stream_with_think_delta can route it via in_think state. In Go +// the driver already separates the two streams, so we surface them as +// separate fields directly instead of merging-then-re-splitting. +type AsyncChatResult struct { + Answer string `json:"answer"` + Reasoning string `json:"reasoning,omitempty"` + Reference map[string]interface{} `json:"reference"` + AudioBinary interface{} `json:"audio_binary"` + Prompt string `json:"prompt"` + CreatedAt float64 `json:"created_at"` + Final bool `json:"final"` + StartToThink bool `json:"start_to_think,omitempty"` + EndToThink bool `json:"end_to_think,omitempty"` + // Internal-only: accumulated answer for building the decorated final result. + accumulatedAnswer string +} + +// AsyncChat is the Go equivalent of Python's async_chat() in +// api/db/services/dialog_service.py:541. +// +// Full pipeline: +// +// ┌───────────────────────────────────────────────────────┐ +// │ 1. Entry validation │ +// │ messages non-empty, last role = "user" │ +// ├───────────────────────────────────────────────────────┤ +// │ No KBs & no web search → AsyncChatSolo (LLM-only) │ +// ├───────────────────────────────────────────────────────┤ +// │ 2. Resolve LLM model config + max_tokens │ +// │ 3. Langfuse trace setup │ +// ├───────────────────────────────────────────────────────┤ +// │ 4. Bind Models: getModels() → embd, rerank, chat, tts │ +// │ + BindTools (toolcall session) │ +// ├───────────────────────────────────────────────────────┤ +// │ 5. Extract questions, attachments, image_files │ +// ├───────────────────────────────────────────────────────┤ +// │ 6. SQL Retrieval (field_map + chat_model) │ +// │ HIT → return structured SQL result directly │ +// │ MISS → fall through to vector retrieval │ +// ├───────────────────────────────────────────────────────┤ +// │ 7. Prompt parameters: resolve param_keys, auto-fix │ +// │ {knowledge} placeholder, validate kwargs │ +// │ 8. Query refinement(LLM): │ +// │ refine_multiturn → cross_languages → │ +// │ meta_data_filter → keyword extraction │ +// ├───────────────────────────────────────────────────────┤ +// │ 9. Retrieval (if hasKnowledgeParam): │ +// │ │ +// │ reasoning=true? │ +// │ YES → DeepResearcher (recursive, maxDepth=3) │ +// │ each layer: KB → Web(Tavily) → KG(use_kg) │ +// │ → sufficiencyCheck → multiQueriesGen → recurse│ +// │ NO → Standard vector retrieval │ +// │ vector/hybrid search → rerank → │ +// │ TOC enhance → child chunk retrieval → │ +// │ Tavily web search → KG retrieval (prepend) │ +// │ │ +// │ enrichChunksWithMetadata (doc metadata) │ +// │ kbPrompt (build knowledge blocks) │ +// ├───────────────────────────────────────────────────────┤ +// │ 10. Build LLM request: │ +// │ empty_response check → formatPrompt → │ +// │ citationPrompt(quote) → messageFitIn(95% budget) │ +// │ → multimodal conversion → adjust max_tokens │ +// ├───────────────────────────────────────────────────────┤ +// │ 11. Drive LLM (stream / non-stream) │ +// │ + answer decoration (citations, references, │ +// │ timing stats, Langfuse trace, TTS synthesis) │ +// └───────────────────────────────────────────────────────┘ +// +// Parameters: +// - chat: the chat/chat entity with KBs, prompt_config, etc. +// - messages: pre-filtered user/assistant messages (system already stripped). +// - stream: if true, yields content deltas as they arrive. +// - kwargs: extra parameters (doc_ids, knowledge, quote, etc.). +func (s *ChatPipelineService) AsyncChat( + ctx context.Context, + chat *entity.Chat, + messages []map[string]interface{}, + stream bool, + kwargs map[string]interface{}, +) (<-chan AsyncChatResult, error) { + + common.Info("AsyncChat started", zap.String("chat_id", chat.ID)) + + // === Phase 1: Entry Validation === + // Guard: messages must be non-empty and the last role must be "user". + common.Info("phase 1: Entry Validation") + if len(messages) == 0 { + return nil, fmt.Errorf("AsyncChat: messages is empty") + } + lastMsg := messages[len(messages)-1] + if role, _ := lastMsg["role"].(string); role != "user" { + return nil, fmt.Errorf("The last content of this conversation is not from user.") + } + + // No KBs & no web search → fast-path to LLM-only chat. + useWebSearch := s.shouldUseWebSearch(chat, kwargs["internet"]) + if useWebSearch { + common.Debug("web_search", + zap.Bool("kb", len(chat.KBIDs) > 0), + zap.Bool("tavily", chat.PromptConfig != nil && chat.PromptConfig["tavily_api_key"] != "" && chat.PromptConfig["tavily_api_key"] != nil), + zap.Any("internet", kwargs["internet"]), + zap.Bool("enabled", useWebSearch)) + } + + if len(chat.KBIDs) == 0 && !useWebSearch { + return s.AsyncChatSolo(ctx, chat, messages, stream) + } + + // Spawn goroutine for the async pipeline. All remaining phases run inside. + out := make(chan AsyncChatResult, 16) + + go func() { + defer close(out) + + timer := common.NewTimer() + timer.Start() + + // === Phase 2: Resolve LLM Model Config + max_tokens === + common.Info("Phase 2: Resolve LLM Model Config + max_tokens") + timer.Enter(common.PhaseCheckLLM) + llmModelConfig, _, _, _, err := s.getLLMModelConfig(chat) + if err != nil { + out <- AsyncChatResult{ + Answer: fmt.Sprintf("**ERROR**: %s", err.Error()), + Final: true, + } + return + } + modelMaxTokens := 8192 + if llmModelConfig != nil { + if mt, ok := llmModelConfig["max_tokens"].(float64); ok { + modelMaxTokens = int(mt) + } + } + timer.Exit(common.PhaseCheckLLM) + + // === Phase 3: Langfuse Trace Setup === + common.Info("Phase 3: Setup Langfuse Trace") + timer.Enter(common.PhaseCheckLangfuse) + var langfuseTraceID string + if lfClient, ok := ctx.Value(langfuseCtxKey).(*LangfuseClient); ok && lfClient != nil { + langfuseTraceID = fmt.Sprintf("trace-%d", time.Now().UnixNano()) + _ = lfClient.PostTrace(ctx, LangfuseTrace{ + ID: langfuseTraceID, + Name: "openai_chat", + UserID: chat.TenantID, + SessionID: chat.ID, + Metadata: map[string]interface{}{ + "stream": stream, + "kb_count": len(chat.KBIDs), + }, + Timestamp: time.Now().UTC().Format(time.RFC3339Nano), + }) + } + timer.Exit(common.PhaseCheckLangfuse) + + // === Phase 4: Bind Models (embedding, rerank, chat, TTS) + ToolCall === + common.Info("Phase 4: Bind Models (embedding, rerank, chat, TTS)") + timer.Enter(common.PhaseBindModels) + kbs, embModel, rerankModel, chatModel, ttsModel := s.getModels(ctx, chat) + + // Toolcall binding + if toolcallSession, hasSession := kwargs["toolcall_session"]; hasSession && toolcallSession != nil { + if tools, hasTools := kwargs["tools"]; hasTools && tools != nil { + if chatModel != nil { + if ts, ok := toolcallSession.(modelModule.ToolCallSession); ok { + common.Info("Bind ToolCall") + chatModel.BindTools(ts, tools) + } + } + } + } + timer.Exit(common.PhaseBindModels) + + // === Phase 5: Extract Questions, doc_ids, Attachments === + common.Info("Phase 5: Extract questions, doc_ids, attachments") + // Retrieve the last 3 user questions. + var questions []string + for _, m := range messages { + if role, _ := m["role"].(string); role == "user" { + if content, ok := m["content"].(string); ok { + questions = append(questions, content) + } + } + } + if len(questions) > 3 { + questions = questions[len(questions)-3:] + } + + common.Debug("Extracted questions", zap.Strings("questions", questions)) + + // Resolve doc_ids from kwargs or the last message. + // Kwargs["doc_ids"] is a comma-separated string. + // messages[-1]["doc_ids"] ALWAYS overrides the kwargs value. + var docIDs []string + if docIDsStr, ok := kwargs["doc_ids"].(string); ok { + for _, p := range strings.Split(docIDsStr, ",") { + p = strings.TrimSpace(p) + if p != "" { + docIDs = append(docIDs, p) + } + } + } + if docIDsRaw, ok := lastMsg["doc_ids"]; ok { + docIDs = nil + if v, ok := docIDsRaw.([]string); ok { + for _, id := range v { + if id != "" { + docIDs = append(docIDs, id) + } + } + } else { + common.Warn("doc_ids in message is not []string, ignoring", + zap.Any("type", fmt.Sprintf("%T", docIDsRaw))) + } + } + if docIDs != nil { + common.Debug("Resolved doc_ids", zap.Strings("doc_ids", docIDs)) + } + + // Parse file attachments from the last message. + // Split text-file URLs (joined with "\n\n") and image URLs. + // Chat model: images → imageAttachments (multimodal conversion). + // Image2text model: images → imageFiles (raw URLs). + var textAttachmentsList []string + var imageAttachments []string + var imageFiles []string + // Joined text attachments (appended to system prompt). + var attachments string + // When files are file dicts, splitFileAttachments fetches blobs + // from storage. When plain strings, falls back to string splitting. + if files, hasFiles := lastMsg["files"]; hasFiles { + modelType := "chat" + if llmModelConfig != nil { + if mt, ok := llmModelConfig["model_type"].(string); ok { + modelType = mt + } + } + if modelType == "chat" { + textAttachmentsList, imageAttachments = splitFileAttachments(files, false) + } else { + textAttachmentsList, imageFiles = splitFileAttachments(files, true) + } + attachments = strings.Join(textAttachmentsList, "\n\n") + common.Debug("Resolved attachments", + zap.Strings("text_attachments_list", textAttachmentsList), + zap.Strings("image_attachments", imageAttachments), + zap.Strings("image_files", imageFiles), + zap.String("attachments", attachments)) + } + + // === Phase 6: SQL Retrieval === + // Retrieve field_map for SQL retrieval (preferred over vector search) + promptConfig := chat.PromptConfig + fieldMap, fmErr := s.KbService.GetFieldMap(kbIDStrings(kbs)) + if fmErr != nil { + common.Warn("get_field_map failed; proceeding without field_map", zap.Error(fmErr)) + fieldMap = nil + } + // Try structured SQL retrieval before vector search. + // Only runs on the last question + // HIT → return structured result directly. + // MISS → fall through to vector search. + if len(fieldMap) > 0 && chatModel != nil && len(kbs) > 0 { + common.Info("Phase 6: Use SQL to retrieval") + common.Debug("field_map retrieved", zap.Any("field_map", fieldMap)) + quote := true + if v, ok := promptConfig["quote"].(bool); ok { + quote = v + } + + ans, sqlErr := s.useSQL( + ctx, chat, kbs, questions[len(questions)-1], chatModel, fieldMap, quote, + ) + if sqlErr != nil { + common.Warn("SQL retrieval error; falling through", zap.Error(sqlErr)) + } + + // For aggregate queries (COUNT, SUM, etc.), chunks may be empty + // but answer is still valid. + chunks := []map[string]interface{}{} + ansStr := "" + if ans != nil { + if refs, ok := ans["reference"].(map[string]interface{}); ok { + if c, ok := refs["chunks"].([]map[string]interface{}); ok { + chunks = c + } + } + ansStr, _ = ans["answer"].(string) + } + if ans != nil && (ansStr != "" || len(chunks) > 0) { + common.Info("SQL retrieval succeeded, skipping vector retrieval") + + // Enrich chunks with document metadata + if includeRefMeta, metadataFields := s.resolveReferenceMetadata(promptConfig, kwargs); includeRefMeta && len(chunks) > 0 { + if len(kbs) != 1 { + hasMissingKBID := false + for _, cm := range chunks { + if _, hasKBID := cm["kb_id"]; !hasKBID { + hasMissingKBID = true + break + } + } + if hasMissingKBID { + common.Warn("Skipping some _enrich_chunks_with_document_metadata results because chat.kb_ids has multiple entries and use_sql returned chunks without kb_id", + zap.Int("kb_count", len(kbs))) + } + } + kbinfos := map[string]interface{}{"chunks": chunks} + s.enrichChunksWithMetadata(kbinfos, chat.TenantID, metadataFields) + } + + out <- AsyncChatResult{ + Answer: ansStr, + Reference: ans["reference"].(map[string]interface{}), + Final: true, + } + return + } + common.Info("SQL retrieval: no valid result, falling back to vector search") + } + + // === Phase 7: Prompt Parameters === + common.Info("Phase 7: Building Prompt Parameters") + // Build param_keys from prompt_config["parameters"]. + // prompt_config["parameters"] is a JSON array of + // {key: string, optional: bool} + // objects declaring which placeholder variables the system prompt + // template expects to be substituted. + // + // hasKnowledgeParam gates the entire RAG retrieval phase below. + // When true: vector / DeepResearcher / TOC / Tavily / KG retrieval + // populates kbinfos (from which knowledges is derived afterward). + // When false: skip retrieval and rely on caller-supplied + // kwargs["knowledge"] or LLM-only. + var parameters []interface{} + if paramsRaw, ok := promptConfig["parameters"]; ok { + if p, ok := paramsRaw.([]interface{}); ok { + parameters = p + } + } + var paramKeys []string + hasKnowledgeParam := false + for _, p := range parameters { + if pMap, ok := p.(map[string]interface{}); ok { + if key, _ := pMap["key"].(string); key != "" { + paramKeys = append(paramKeys, key) + if key == "knowledge" { + hasKnowledgeParam = true + } + } + } + } + + // Auto-fix: ensure "knowledge" is in param_keys when the chat has + // KBs and the system prompt references {knowledge}. + if len(kbs) > 0 && !hasKnowledgeParam { + systemPrompt, _ := promptConfig["system"].(string) + if strings.Contains(systemPrompt, "{knowledge}") { + common.Warn("prompt_config['parameters'] is missing 'knowledge' entry despite kb_ids being set; auto-fixing.") + parameters = append(parameters, map[string]interface{}{ + "key": "knowledge", + "optional": false, + }) + promptConfig["parameters"] = parameters + paramKeys = append(paramKeys, "knowledge") + hasKnowledgeParam = true + } + } + + // Validate prompt template parameters against caller-supplied kwargs. + // - "knowledge" is always skipped (system-injected, not caller-supplied). + // - Missing non-optional param => return error immediately. + // - Missing optional param => replace "{key}" placeholder with space. + systemPrompt, _ := promptConfig["system"].(string) + for _, p := range parameters { + pMap, ok := p.(map[string]interface{}) + if !ok { + continue + } + key, _ := pMap["key"].(string) + if key == "knowledge" { + continue // system-injected, skip caller validation + } + if _, inKwargs := kwargs[key]; !inKwargs { + optional, _ := pMap["optional"].(bool) + if !optional { + // Required parameter missing => fail fast + out <- AsyncChatResult{ + Answer: fmt.Sprintf("**ERROR**: Miss parameter: %s", key), + Final: true, + } + return + } + // Optional parameter missing => erase placeholder from system prompt + systemPrompt = strings.ReplaceAll(systemPrompt, "{"+key+"}", " ") + } + } + promptConfig["system"] = systemPrompt + + common.Debug("Prompt parameters", + zap.Strings("doc_ids", docIDs), + zap.Strings("param_keys", paramKeys), + zap.Bool("has_embd_mdl", embModel != nil), + zap.Any("prompt_config.parameters", promptConfig["parameters"]), + zap.String("prompt_config.system", systemPrompt)) + + // === Phase 8: Query refinement(LLM) === + // Sub-steps: refine_multiturn → cross_languages → meta_data_filter → keyword. + common.Info("Phase 8: Query refinement(LLM)") + timer.Enter(common.PhaseQueryRefinement) + + // refine_multiturn — condense multi-turn conversation into a single + // refined question via LLM. When disabled, simply keep the last question. + if refine, _ := chat.PromptConfig["refine_multiturn"].(bool); refine && len(questions) > 1 && chatModel != nil { + if refined, err := FullQuestion(ctx, chatModel, messages, ""); err == nil && refined != "" { + questions = []string{refined} // replace with refined question + common.Debug("refine_multiturn applied", + zap.String("refined", truncateForLog(refined, 60))) + } else if err != nil { + common.Warn("refine_multiturn failed; using original question", zap.Error(err)) + } + } else { + // Keep only the last question. + questions = questions[len(questions)-1:] + } + + // cross_languages — translate the question into configured target + // languages via LLM, replacing the original. Useful for cross-lingual retrieval. + if crossLangs, ok := chat.PromptConfig["cross_languages"].([]interface{}); ok && len(crossLangs) > 0 && chatModel != nil && len(questions) > 0 { + langs := make([]string, 0, len(crossLangs)) + for _, x := range crossLangs { + if s, ok := x.(string); ok && s != "" { + langs = append(langs, s) + } + } + if len(langs) > 0 { + if translated, err := CrossLanguages(ctx, chat.TenantID, chat.LLMID, questions[0], langs); err == nil && translated != "" { + original := questions[0] + questions = []string{translated} // replace with translated question + common.Debug("cross_languages applied", + zap.String("original_question", original), + zap.String("translated_question", translated)) + } else if err != nil { + common.Warn("cross_languages failed", zap.Error(err)) + } + } + } + + // meta_data_filter — use LLM to map the question to metadata + // criteria, then filter docIDs to matching + // documents only. + if chat.MetaDataFilter != nil && len(*chat.MetaDataFilter) > 0 && len(kbs) > 0 { + kbIDs := kbIDStrings(kbs) + if metaQ := questions[len(questions)-1]; metaQ != "" { + var flattedMeta common.MetaData + var mErr error + if s.MetadataSvc != nil { + flattedMeta, mErr = s.MetadataSvc.GetFlattedMetaByKBs(kbIDs) + } + if mErr == nil { + if filtered, ok := ApplyMetaDataFilter( + ctx, + *chat.MetaDataFilter, + flattedMeta, + metaQ, + chatModel, + docIDs, + kbIDs, + ); ok { + common.Debug("meta_data_filter applied", + zap.Int("filtered_count", len(filtered)), + zap.Int("pre_filter_count", len(docIDs))) + docIDs = filtered + } + } else { + common.Warn("loadMetaData failed; skipping meta_data_filter", zap.Error(mErr)) + } + } + } + + // keyword — extract top-N keywords from the question via LLM and + // append them to the question text to boost lexical retrieval recall. + if useKW, _ := chat.PromptConfig["keyword"].(bool); useKW && chatModel != nil && len(questions) > 0 { + if kw, err := KeywordExtraction(ctx, chatModel, questions[len(questions)-1], 3); err == nil && kw != "" { + original := questions[len(questions)-1] + questions[len(questions)-1] = questions[len(questions)-1] + "," + kw + common.Debug("keyword extraction applied", + zap.String("original_question", original), + zap.String("augmented_question", questions[len(questions)-1])) + } else if err != nil { + common.Warn("keyword extraction failed", zap.Error(err)) + } + } + timer.Exit(common.PhaseQueryRefinement) + + // === Phase 9: Retrieval === + promptReasoning, _ := chat.PromptConfig["reasoning"].(bool) + kwargReasoning, _ := kwargs["reasoning"].(bool) + useReasoning := promptReasoning || kwargReasoning + common.Info("Phase 9: Retrieval", + zap.Bool("has_knowledge_param", hasKnowledgeParam), + zap.Bool("reasoning", useReasoning)) + + timer.Enter(common.PhaseRetrieval) + var kbinfos map[string]interface{} + kbinfos = map[string]interface{}{ + "total": 0, + "chunks": []map[string]interface{}{}, + "doc_aggs": []interface{}{}, + } + var knowledges []string + + // When hasKnowledgeParam is true, runs (mutually exclusive): + // a) If reasoning is enabled: DeepResearcher replaces vector retrieval. + // b) Otherwise: standard retrieval, then: + // - TOC enhancement (if toc_enhance is enabled). + // - Child chunk retrieval. + // - Tavily web search (if internet is enabled). + // - Knowledge graph retrieval (if use_kg is enabled). + // Populates kbinfos (chunks + doc_aggs) and knowledges. + // When false, the entire block is skipped. + if hasKnowledgeParam { + if useReasoning && chatModel != nil && len(kbs) > 0 { + // DeepResearcher — replaces vector retrieval. + // Yields / markers + intermediate messages. + docEngine := engine.Get() + if docEngine != nil { + retSvc := nlp.NewRetrievalService(docEngine, dao.NewDocumentDAO()) + tenantIDs := kbTenantIDStrings(kbs) + kbIDs := kbIDStrings(kbs) + + // KB retrieval callback for the deep researcher + kbRetrieve := func(ctx context.Context, q string) (*nlp.RetrievalResult, error) { + return retSvc.Retrieval(ctx, &nlp.RetrievalRequest{ + Question: q, + TenantIDs: tenantIDs, + KbIDs: kbIDs, + DocIDs: docIDs, + Page: 1, + PageSize: int(chat.TopN), + EmbeddingModel: embModel, + }) + } + + dr := NewDeepResearcher( + chatModel, + map[string]interface{}(chat.PromptConfig), + kbRetrieve, + useWebSearch, + docEngine, + kbIDs, + tenantIDs, + embModel, + ) + question := strings.Join(questions, " ") + + drErr := dr.Research(ctx, kbinfos, question, question, func(msg string) { + switch { + case strings.HasPrefix(msg, ""): + out <- AsyncChatResult{ + Answer: "", + Reference: map[string]interface{}{}, + AudioBinary: nil, + Final: false, + } + case strings.HasPrefix(msg, ""): + out <- AsyncChatResult{ + Answer: "", + Reference: map[string]interface{}{}, + AudioBinary: nil, + Final: false, + } + default: + out <- AsyncChatResult{ + Answer: msg, + Reference: map[string]interface{}{}, + AudioBinary: nil, + Final: false, + } + } + }) + if drErr != nil { + common.Warn("DeepResearcher failed", zap.Error(drErr)) + } else { + // kbinfos now contains real chunks with proper + // chunk_ids from the recursive tree search. + common.Debug("DeepResearcher completed", + zap.Int("chunks", len(kbinfos["chunks"].([]map[string]interface{})))) + } + } + } else { + searchQuestion := strings.Join(questions, " ") + if embModel != nil { + // Retrieval + rankFeature := s.MetadataSvc.LabelQuestion(searchQuestion, kbs) + { + tenantIDs := make([]string, 0) + kbIDs := make([]string, 0) + for _, kb := range kbs { + tenantIDs = append(tenantIDs, kb.TenantID) + kbIDs = append(kbIDs, kb.ID) + } + + docEngine := engine.Get() + documentDAO := dao.NewDocumentDAO() + retrievalSvc := nlp.NewRetrievalService(docEngine, documentDAO) + + top := int(chat.TopK) + threshold := chat.SimilarityThreshold + vsw := chat.VectorSimilarityWeight + topN := int(chat.TopN) + + req := &nlp.RetrievalRequest{ + Question: searchQuestion, + TenantIDs: tenantIDs, + KbIDs: kbIDs, + DocIDs: docIDs, + Page: 1, + PageSize: topN, + Top: &top, + SimilarityThreshold: &threshold, + VectorSimilarityWeight: &vsw, + RankFeature: &rankFeature, + RerankModel: rerankModel, + EmbeddingModel: embModel, + Aggs: func() *bool { v := true; return &v }(), + } + + result, retErr := retrievalSvc.Retrieval(ctx, req) + if retErr != nil { + kbinfos = map[string]interface{}{ + "total": 0, + "chunks": []map[string]interface{}{}, + "doc_aggs": []interface{}{}, + } + err = retErr + } else { + docAggs := make([]interface{}, len(result.DocAggs)) + for i, da := range result.DocAggs { + docAggs[i] = da + } + + kbinfos = map[string]interface{}{ + "total": len(result.Chunks), + "chunks": result.Chunks, + "doc_aggs": docAggs, + } + } + } + if err != nil { + common.Warn("Retrieval failed", zap.Error(err)) + // Continue with empty kbinfos. + } + + // TOC enhancement + if useTOC, _ := chat.PromptConfig["toc_enhance"].(bool); useTOC && chatModel != nil && len(kbs) > 0 { + enhancer := NewTOCEnhancer( + engine.Get(), + chatModel, + kbTenantIDStrings(kbs), + kbIDStrings(kbs), + searchQuestion, + int(chat.TopN), + ) + if added, err := enhancer.Enhance(ctx, kbinfos); err != nil { + common.Warn("TOC enhance failed", zap.Error(err)) + } else if added > 0 { + common.Debug("TOC enhance added chunks", zap.Int("added", added)) + } + } + } + + // Child chunk retrieval + if existingChunks, ok := kbinfos["chunks"].([]map[string]interface{}); ok && len(existingChunks) > 0 { + kbinfos["chunks"] = nlp.RetrievalByChildren(existingChunks, kbTenantIDStrings(kbs), engine.Get(), ctx) + } + + // Web search via Tavily + if s.shouldUseWebSearch(chat, kwargs["internet"]) { + tavilyKey, _ := chat.PromptConfig["tavily_api_key"].(string) + tavResult, tavErr := s.tavilyRetrieve(ctx, tavilyKey, searchQuestion) + if tavErr != nil { + common.Warn("Tavily web search failed", zap.Error(tavErr)) + } else { + // Extend chunks and doc_aggs with web search results. + if existingChunks, ok := kbinfos["chunks"].([]map[string]interface{}); ok { + if newChunks, ok := tavResult["chunks"].([]map[string]interface{}); ok { + kbinfos["chunks"] = append(existingChunks, newChunks...) + } + } + if existingAggs, ok := kbinfos["doc_aggs"].([]interface{}); ok { + if newAggs, ok := tavResult["doc_aggs"].([]interface{}); ok { + kbinfos["doc_aggs"] = append(existingAggs, newAggs...) + } + } + } + } + + // Knowledge Graph retrieval + if useKG, _ := chat.PromptConfig["use_kg"].(bool); useKG && chatModel != nil && len(kbs) > 0 { + kgIDs := kbIDStrings(kbs) + if len(kgIDs) > 0 { + kgPipeline := kg.NewPipeline(engine.Get(), kgIDs, kbTenantIDStrings(kbs), searchQuestion) + kgPipeline.SetChatModel(chatModel) + if embModel != nil { + kgPipeline.SetEmbModel(embModel) + } + kgChunk, kgErr := kgPipeline.Retrieval(ctx) + if kgErr != nil { + common.Warn("KG retrieval failed; falling through to vector-only", + zap.Error(kgErr)) + } else if kgChunk != nil { + if _, hasContent := kgChunk["content_with_weight"]; hasContent { + if existingChunks, ok := kbinfos["chunks"].([]map[string]interface{}); ok { + newChunks := make([]map[string]interface{}, 0, len(existingChunks)+1) + newChunks = append(newChunks, kgChunk) + newChunks = append(newChunks, existingChunks...) + kbinfos["chunks"] = newChunks + common.Debug("KG chunk prepended", + zap.Int("total_chunks", len(newChunks))) + } + } + } + } + } + } + } + + // Enrich chunks with document metadata AFTER all retrieval adds. + // Request values (kwargs) take precedence over config values. + if includeRefMeta, metadataFields := s.resolveReferenceMetadata(promptConfig, kwargs); includeRefMeta { + s.enrichChunksWithMetadata(kbinfos, chat.TenantID, metadataFields) + } + timer.Exit(common.PhaseRetrieval) + + // === Phase 10: Build LLM Request === + // Sub-steps: empty_response check → formatPrompt → citationPrompt → + // messageFitIn (95% token budget) → multimodal conversion → adjust max_tokens. + // If no knowledges and empty_response is configured, yield it and return. + knowledges = s.kbPrompt(kbinfos, modelMaxTokens) + common.Info("Phase 10: Build LLM Request") + common.Debug("Knowledge prompt", + zap.String("question", strings.Join(questions, " ")), + zap.Strings("knowledges", knowledges)) + + // empty_response check + // When no knowledge chunks were retrieved, skip the LLM entirely and + // return the user-configured fallback message (if set). + // If empty_response is not configured, fall through to the LLM call + // with an empty knowledge context. + if len(knowledges) == 0 { + if emptyResp, ok := promptConfig["empty_response"].(string); ok && emptyResp != "" { + out <- AsyncChatResult{ + Answer: emptyResp, + Reference: kbinfos, + AudioBinary: s.synthesizeTTS(ttsModel, emptyResp), + Prompt: fmt.Sprintf("\n\n### Query:\n%s", strings.Join(questions, " ")), + Final: true, + } + return + } + } + + // Format the system prompt with knowledge. + // Only overwrite kwargs["knowledge"] when retrieval produced something; + // otherwise preserve any caller-supplied value. + knowledge := strings.Join(knowledges, "\n\n------\n\n") + if knowledge != "" { + kwargs["knowledge"] = "\n------\n" + knowledge + } + systemPrompt = "" + if sp, ok := promptConfig["system"].(string); ok { + systemPrompt = s.formatPrompt(sp, kwargs) + attachments + // If knowledge was retrieved but the template has no {knowledge} + // placeholder, auto-append it so the LLM still sees the context. + if len(knowledges) > 0 && !strings.Contains(sp, "{knowledge}") { + if kw, ok := kwargs["knowledge"].(string); ok { + systemPrompt += kw + } + } + } + if systemPrompt != "" { + common.Info("System prompt built", + zap.Int("length", len(systemPrompt))) + } + + // Build citation prompt if quoting is enabled. + prompt4citation := "" + quote := true + if v, ok := kwargs["quote"].(bool); ok { + quote = v + } + if promptConfigQuote, ok := promptConfig["quote"].(bool); ok { + quote = quote && promptConfigQuote + } + if len(knowledges) > 0 && quote { + prompt4citation = citationPrompt() + } + + if prompt4citation != "" { + common.Info("Citation prompt built", + zap.Bool("quote", quote), + zap.Int("length", len(prompt4citation))) + } + + // Build the message list: system + cleaned user/assistant messages. + var llmMessages []map[string]interface{} + llmMessages = append(llmMessages, map[string]interface{}{ + "role": "system", + "content": systemPrompt, + }) + factoryName := "" + if llmModelConfig != nil { + if f, ok := llmModelConfig["llm_factory"].(string); ok && f != "" { + factoryName = strings.ToLower(f) + } + } + if factoryName == "" { + factoryName = factoryFromLLMID(chat.LLMID) + } + for _, m := range messages { + role, _ := m["role"].(string) + if role == "system" { + continue + } + content := m["content"] + if contentStr, ok := content.(string); ok { + content = cleanCitationMarkers(contentStr) + } + llmMessages = append(llmMessages, map[string]interface{}{ + "role": role, + "content": content, + }) + } + + // Fit messages within token budget. + usedTokenCount, llmMessages := s.messageFitIn(llmMessages, int(float64(modelMaxTokens)*0.95)) + common.Debug("Messages fitted in token budget", + zap.Int("model max_tokens", modelMaxTokens), + zap.Int("used_token_count", usedTokenCount), + zap.Int("msg_count", len(llmMessages))) + + // Multimodal conversion + allImages := make([]string, 0, len(imageAttachments)+len(imageFiles)) + allImages = append(allImages, imageAttachments...) + allImages = append(allImages, imageFiles...) + if len(llmMessages) >= 2 && len(allImages) > 0 { + lastIdx := len(llmMessages) - 1 + if role, _ := llmMessages[lastIdx]["role"].(string); role == "user" { + if converted, err := common.ConvertLastUserMsgToMultimodal( + llmMessages[lastIdx], + allImages, + factoryName, + ); err == nil { + llmMessages[lastIdx] = converted + } + } + } + + prompt := systemPrompt + if len(llmMessages) > 0 { + if c, ok := llmMessages[0]["content"].(string); ok { + prompt = c + } + } + + if len(llmMessages) < 2 { + out <- AsyncChatResult{ + Answer: "**ERROR**: message_fit_in has bug", + Final: true, + } + return + } + + // Adjust max_tokens so the LLM has room within the total budget. + if chat.LLMSetting != nil { + if mt, ok := chat.LLMSetting["max_tokens"].(float64); ok { + original := int(mt) + adjusted := original + if adjusted > modelMaxTokens-usedTokenCount { + adjusted = modelMaxTokens - usedTokenCount + } + chat.LLMSetting["max_tokens"] = float64(adjusted) + common.Debug("Adjusted max_tokens", zap.Int("max_tokens in chat", adjusted)) + } + } + + // === Phase 11: Drive LLM + Decorate Answer === + // Stream path: accumulate deltas → per-delta TTS → decorate final. + // Non-stream path: one-shot chat → decorate (includes TTS). + // Answer decoration: citation markers, references, timing stats, Langfuse. + common.Info("Phase 11: Drive LLM + Decorate Answer", + zap.Bool("stream", stream), + zap.Int("llm_messages_count", len(llmMessages))) + timer.Enter(common.PhaseGenerateAnswer) + chatDriver := s.buildChatDriver(chat, chatModel) + if chatDriver == nil { + out <- AsyncChatResult{ + Answer: "**ERROR**: No chat model available for this chat.", + Final: true, + } + return + } + chatMessages := s.buildChatMessages(prompt+prompt4citation, llmMessages[1:]) + + // Langfuse generation start observation. + var langfuseGenerationID string + if langfuseTraceID != "" { + if lfClient, ok := ctx.Value(langfuseCtxKey).(*LangfuseClient); ok && lfClient != nil { + langfuseGenerationID = fmt.Sprintf("gen-%s", langfuseTraceID) + modelName := "" + if llmModelConfig != nil { + if mn, ok := llmModelConfig["llm_name"].(string); ok { + modelName = mn + } + } + // PostGeneration creates a start-observation span. + // Error is non-fatal; end-observation fires regardless. + genInput := map[string]interface{}{ + "prompt": prompt, + "prompt4citation": prompt4citation, + "messages": chatMessages, + } + if err := lfClient.PostGeneration(ctx, LangfuseGeneration{ + ID: langfuseGenerationID, + TraceID: langfuseTraceID, + Name: "chat", + Model: modelName, + StartTime: time.Now().UTC().Format(time.RFC3339Nano), + Input: genInput, + }); err != nil { + common.Warn("Langfuse start observation (PostGeneration) failed; continuing without start-side tracing", + zap.String("langfuse_trace_id", langfuseTraceID), + zap.Error(err)) + // Keep langfuseGenerationID set so the end + // Keep langfuseGenerationID set so end-observation fires. + } + } + } + + // Stream path: per-delta callbacks, accumulate answer. + // Non-stream path: one-shot synchronous answer. + if stream { + // Streaming path: accumulate answer, emit deltas. + var fullAnswer string + var fullReasoning string + thinkState := &thinkStreamState{} + + chatCfg := BuildChatConfig(chat, nil) + + // Tool routing: use tool-loop method when tools are bound. + var driverErr error + if chatDriver.ToolConfig != nil { + // Tool streaming path: + // Wraps reasoning in markers. + // inThink tracks local state to route reasoning vs answer. + var inThink bool + _, driverErr = chatDriver.ChatStreamlyWithTools(ctx, prompt+prompt4citation, chatMessages, chatCfg, + func(answerDelta *string, reason *string) error { + if answerDelta == nil || *answerDelta == "" { + return nil + } + text := *answerDelta + fullAnswer += text + + if text == "" { + inThink = true + out <- AsyncChatResult{ + Answer: "", + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + StartToThink: true, + } + return nil + } + if text == "" { + inThink = false + out <- AsyncChatResult{ + Answer: "", + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + EndToThink: true, + } + return nil + } + if inThink { + // Reasoning text — route to Reasoning field so + // the SSE handler maps it to + // `delta.reasoning_content`. + out <- AsyncChatResult{ + Reasoning: text, + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + } + } else { + // Regular answer content + out <- AsyncChatResult{ + Answer: text, + Reference: map[string]interface{}{}, + AudioBinary: s.synthesizeTTS(ttsModel, text), + CreatedAt: float64(time.Now().Unix()), + Final: false, + } + } + return nil + }) + } else { + driverErr = chatDriver.ModelDriver.ChatStreamlyWithSender( + *chatDriver.ModelName, chatMessages, chatDriver.APIConfig, chatCfg, + func(answer *string, reason *string) error { + if reason != nil && *reason != "" { + fullReasoning += *reason + kind, output := processThinkDelta(thinkState, *reason, 16) + if kind == "marker" && output == "" { + // marker — emit StartToThink + out <- AsyncChatResult{ + Answer: "", + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + StartToThink: true, + } + } else if kind == "marker" && output == "" { + // marker — emit EndToThink + out <- AsyncChatResult{ + Answer: "", + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + EndToThink: true, + } + } else if kind == "text" && output != "" { + // Route reasoning text to Reasoning field. + // TTS is nil — chain-of-thought is not narrated. + out <- AsyncChatResult{ + Reasoning: output, + Reference: map[string]interface{}{}, + // TTS only narrates user-visible answer text. + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + } + } + } + if isContentDelta(answer) { + fullAnswer += *answer + out <- AsyncChatResult{ + Answer: *answer, + Reference: map[string]interface{}{}, + // Per-delta TTS for incremental audio playback. + AudioBinary: s.synthesizeTTS(ttsModel, *answer), + CreatedAt: float64(time.Now().Unix()), + Final: false, + } + } + return nil + }, + ) + } + if driverErr != nil { + out <- AsyncChatResult{ + Answer: fmt.Sprintf("**ERROR**: %s", driverErr.Error()), + Final: true, + } + return + } + + // Flush remaining think stream buffer. + // If the LLM ended mid-think, emit remaining reasoning + + // implicit close marker. + remainingText, remainingMarker := flushThinkStream(thinkState) + if remainingText != "" { + // Flushed text belongs in Reasoning, not Answer. + out <- AsyncChatResult{ + Reasoning: remainingText, + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + } + } + if remainingMarker == "" { + out <- AsyncChatResult{ + Answer: "", + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + EndToThink: true, + } + } + + // Decorate and yield the final answer. + visibleAnswer := s.extractVisibleAnswer(fullReasoning + fullAnswer) + + // Pass nil for ttsModel — audio was already produced per-delta. + final := s.decorateAnswer(ctx, visibleAnswer, kbinfos, prompt, questions, usedTokenCount, timer, embModel, chat.VectorSimilarityWeight, quote, nil, langfuseTraceID, llmModelConfig, chat.TenantID, kbTenantIDStrings(kbs), len(knowledges) > 0) + final.Final = true + final.AudioBinary = nil + timer.Exit(common.PhaseGenerateAnswer) + out <- final + } else { + // Non-streaming: get the answer synchronously. + var answer string + var err error + chatCfg := BuildChatConfig(chat, nil) + + // Tool routing: use tool-loop when tools are bound. + if chatDriver.ToolConfig != nil { + answer, _, err = chatDriver.ChatWithTools(ctx, prompt+prompt4citation, chatMessages, chatCfg) + } else { + resp, respErr := chatDriver.ModelDriver.ChatWithMessages( + *chatDriver.ModelName, chatMessages, chatDriver.APIConfig, chatCfg, + ) + if respErr != nil { + err = respErr + } else if resp != nil && resp.Answer != nil { + answer = *resp.Answer + } + } + + if err != nil { + out <- AsyncChatResult{ + Answer: fmt.Sprintf("**ERROR**: %s", err.Error()), + Final: true, + } + return + } + + // Last user message's content for the debug log. + userContent := "[content not available]" + if len(llmMessages) > 1 { + if c, ok := llmMessages[len(llmMessages)-1]["content"].(string); ok { + userContent = c + } + } + common.Debug("User: " + userContent + "|Assistant: " + answer) + + // Synthesize TTS for the full answer (non-stream, one-shot). + final := s.decorateAnswer(ctx, answer, kbinfos, prompt, questions, usedTokenCount, timer, embModel, chat.VectorSimilarityWeight, quote, ttsModel, langfuseTraceID, llmModelConfig, chat.TenantID, kbTenantIDStrings(kbs), len(knowledges) > 0) + final.Final = true + timer.Exit(common.PhaseGenerateAnswer) + out <- final + } + common.Info("AsyncChat completed", zap.String("chat_id", chat.ID)) + }() + + return out, nil +} + +// AsyncChatSolo is the LLM-only chat path (no KBs, no web search). +// Equivalent to Python's async_chat_solo() in dialog_service.py:289-337. +func (s *ChatPipelineService) AsyncChatSolo( + ctx context.Context, + chat *entity.Chat, + messages []map[string]interface{}, + stream bool, +) (<-chan AsyncChatResult, error) { + + out := make(chan AsyncChatResult, 16) + + go func() { + defer close(out) + + // Timer brackets the LLM call; other phases are N/A in solo mode. + timer := common.NewTimer() + timer.Start() + + // 1. Resolve system prompt. + promptConfig := chat.PromptConfig + systemPrompt := "" + if sp, ok := promptConfig["system"].(string); ok { + systemPrompt = sp + } + + // 1b. Resolve LLM model config (needed early for model_type dispatch). + llmModelConfig, _, _, _, err := s.getLLMModelConfig(chat) + factoryName := "" + if err == nil && llmModelConfig != nil { + factoryName, _ = llmModelConfig["llm_factory"].(string) + } + if factoryName == "" { + factoryName = factoryFromLLMID(chat.LLMID) + } + + // 2. Process file attachments (chat → data URIs, image2text → raw URLs). + attachmentsStr := "" + var imageFiles []string + modelType := "chat" + if llmModelConfig != nil { + if mt, ok := llmModelConfig["model_type"].(string); ok && mt != "" { + modelType = mt + } + } + isImage2Text := modelType == "image2text" + if len(messages) > 0 { + if files, hasFiles := messages[len(messages)-1]["files"]; hasFiles { + attachmentsStr = s.processFileAttachments(files) + if isImage2Text { + imageFiles = s.extractRawImageURLs(files) + } else { + imageFiles = s.extractImageFiles(files) + } + } + } + + // 3. Strip citation markers and drop system messages from history. + var msg []map[string]interface{} + for _, m := range messages { + role, _ := m["role"].(string) + if role == "system" { + continue + } + content := m["content"] + if contentStr, ok := content.(string); ok { + content = cleanCitationMarkers(contentStr) + } + msg = append(msg, map[string]interface{}{ + "role": role, + "content": content, + }) + } + // Append text attachments to the last user message (no separator). + if attachmentsStr != "" && len(msg) > 0 { + if lastContent, ok := msg[len(msg)-1]["content"].(string); ok { + msg[len(msg)-1]["content"] = lastContent + attachmentsStr + } + } + + // 4. Build the chat model wrapper. + driver, modelName, apiConfig, _, err := s.ModelProviderSvc.GetChatModelConfig(chat.TenantID, chat.LLMID) + if err != nil { + out <- AsyncChatResult{ + Answer: fmt.Sprintf("**ERROR**: %s", err.Error()), + Final: true, + } + return + } + chatModel := modelModule.NewChatModel(driver, &modelName, apiConfig) + + // 5. Resolve TTS model. Best-effort: warn and proceed without TTS on lookup failure. + var ttsModel *modelModule.ChatModel + if promptConfig != nil { + if useTTS, _ := promptConfig["tts"].(bool); useTTS { + ttsDriver, ttsName, ttsConfig, _, ttsErr := s.ModelProviderSvc.GetTenantDefaultModelByType( + chat.TenantID, entity.ModelTypeTTS, + ) + if ttsErr != nil { + common.Warn("AsyncChatSolo: TTS lookup failed; proceeding without TTS", + zap.String("tenant_id", chat.TenantID), + zap.Error(ttsErr)) + } else { + ttsModel = modelModule.NewChatModel(ttsDriver, &ttsName, ttsConfig) + } + } + } + + // 6. Build messages for driver. Convert last user msg to multimodal if images present. + var chatMessages []modelModule.Message + if systemPrompt != "" { + chatMessages = append(chatMessages, modelModule.Message{ + Role: "system", + Content: systemPrompt, + }) + } + for i, m := range msg { + role, _ := m["role"].(string) + content := m["content"] + // Multimodal conversion for the last user message. + if i == len(msg)-1 && role == "user" && len(imageFiles) > 0 { + if converted, err := common.ConvertLastUserMsgToMultimodal( + map[string]interface{}{"role": role, "content": content}, + imageFiles, + strings.ToLower(factoryName), + ); err == nil { + content = converted["content"] + } + } + chatMessages = append(chatMessages, modelModule.Message{ + Role: role, + Content: content, + }) + } + + // 7. Drive the LLM: stream (per-delta with think markers) or non-stream (one-shot). + if stream { + var fullAnswer string + var fullReasoning string + thinkState := &thinkStreamState{} + chatCfg := BuildChatConfig(chat, nil) + timer.Enter(common.PhaseGenerateAnswer) + driverErr := chatModel.ModelDriver.ChatStreamlyWithSender( + *chatModel.ModelName, chatMessages, chatModel.APIConfig, chatCfg, + func(answer *string, reason *string) error { + if reason != nil && *reason != "" { + fullReasoning += *reason + kind, output := processThinkDelta(thinkState, *reason, 16) + if kind == "marker" && output == "" { + // Start thinking. + out <- AsyncChatResult{ + Answer: "", + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + StartToThink: true, + } + } else if kind == "marker" && output == "" { + // End thinking. + out <- AsyncChatResult{ + Answer: "", + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + EndToThink: true, + } + } else if kind == "text" && output != "" { + // Reasoning text with per-delta TTS. + out <- AsyncChatResult{ + Reasoning: output, + Reference: map[string]interface{}{}, + AudioBinary: s.synthesizeTTS(ttsModel, output), + CreatedAt: float64(time.Now().Unix()), + Final: false, + } + } + } + if isContentDelta(answer) { + fullAnswer += *answer + out <- AsyncChatResult{ + Answer: *answer, + Reference: map[string]interface{}{}, + AudioBinary: s.synthesizeTTS(ttsModel, *answer), + CreatedAt: float64(time.Now().Unix()), + Final: false, + } + } + return nil + }, + ) + if driverErr != nil { + out <- AsyncChatResult{ + Answer: fmt.Sprintf("**ERROR**: %s", driverErr.Error()), + Final: true, + } + return + } + timer.Exit(common.PhaseGenerateAnswer) + // Flush any remaining think buffer. + remainingText, remainingMarker := flushThinkStream(thinkState) + if remainingText != "" { + out <- AsyncChatResult{ + Reasoning: remainingText, + Reference: map[string]interface{}{}, + AudioBinary: s.synthesizeTTS(ttsModel, remainingText), + CreatedAt: float64(time.Now().Unix()), + Final: false, + } + } + if remainingMarker == "" { + out <- AsyncChatResult{ + Answer: "", + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: false, + EndToThink: true, + } + } + // Final aggregate: re-attach reasoning wrapper for non-streaming consumers. + finalAnswer := fullAnswer + if fullReasoning != "" { + finalAnswer = "" + fullReasoning + "" + fullAnswer + } + // Raw answer, no decorate_answer. AudioBinary=nil (per-delta TTS already emitted). + out <- AsyncChatResult{ + Answer: finalAnswer, + Reference: map[string]interface{}{}, + AudioBinary: nil, + CreatedAt: float64(time.Now().Unix()), + Final: true, + } + } else { + // Non-streaming: one-shot call. + chatCfg := BuildChatConfig(chat, nil) + timer.Enter(common.PhaseGenerateAnswer) + resp, err := chatModel.ModelDriver.ChatWithMessages( + *chatModel.ModelName, chatMessages, chatModel.APIConfig, chatCfg, + ) + timer.Exit(common.PhaseGenerateAnswer) + if err != nil { + out <- AsyncChatResult{ + Answer: fmt.Sprintf("**ERROR**: %s", err.Error()), + Final: true, + } + return + } + answer := "" + if resp.Answer != nil { + answer = *resp.Answer + } + // Debug log matching Python's dialog_service.py:335-336. + userContent := "[content not available]" + if len(msg) > 0 { + if c, ok := msg[len(msg)-1]["content"].(string); ok { + userContent = c + } + } + common.Debug("User: " + userContent + "|Assistant: " + answer) + + // Raw answer with full TTS, no decorate_answer. Caller handles decoration. + out <- AsyncChatResult{ + Answer: answer, + Reference: map[string]interface{}{}, + AudioBinary: s.synthesizeTTS(ttsModel, answer), + CreatedAt: float64(time.Now().Unix()), + Final: true, + } + } + }() + + return out, nil +} + +// extractImageFiles extracts data-URI image attachments from the files list. +// Mirrors Python split_file_attachments raw mode. +func (s *ChatPipelineService) extractImageFiles(files interface{}) []string { + // ── File-dict mode ── + if fileDicts, ok := parseFileDicts(files); ok { + fileSvc := NewFileService() + // Use raw=false to get base64 data URIs for images. + _, images, err := fileSvc.GetFileContents(fileDicts, false) + if err != nil { + common.Warn("GetFileContents failed in extractImageFiles", + zap.Error(err)) + return nil + } + return images + } + + // ── String fallback ── + var images []string + switch v := files.(type) { + case []string: + for _, f := range v { + if strings.HasPrefix(f, "data:") { + images = append(images, f) + } + } + case []interface{}: + for _, f := range v { + if s, ok := f.(string); ok && strings.HasPrefix(s, "data:") { + images = append(images, s) + } + } + } + return images +} + +// extractRawImageURLs extracts image references as raw URLs/data-URIs from +// the string-mode files list, WITHOUT fetching blobs and WITHOUT filtering +// to data: prefixes. Used for image2text models that expect URLs in the +// multimodal content (matches Python's `image_files` from +// `split_file_attachments(files, raw=True)` at +// dialog_service.py:371-392). +// +// The downstream ConvertLastUserMsgToMultimodal calls parseDataURIOrB64 +// (multimodal.go:63-92) which correctly handles all three forms: +// - data: URI → base64 source +// - http:// or https:// URL → URL source +// - raw base64 → base64 source (default media type) +// +// File-dict mode is a known limitation: returns empty for now. A future +// FileService.GetFileURLsForChat (mirror of GetFileContents with +// raw=true) would be needed to fully cover the file-dict + image2text +// combination. The Python equivalent has the same limitation +// (split_file_attachments calls FileService.get_files which doesn't +// fetch blobs in raw mode). +func (s *ChatPipelineService) extractRawImageURLs(files interface{}) []string { + if fileDicts, ok := parseFileDicts(files); ok { + _ = fileDicts // see file-dict limitation comment above + common.Debug("AsyncChatSolo: file-dict + image2text not yet supported; image refs dropped", + zap.Int("file_dict_count", len(fileDicts))) + return nil + } + + // String-mode: return all entries as-is. The downstream + // ConvertLastUserMsgToMultimodal + parseDataURIOrB64 will + // dispatch on prefix (data: → base64, http(s): → url, else → + // raw base64). + var urls []string + switch v := files.(type) { + case []string: + for _, f := range v { + if f != "" { + urls = append(urls, f) + } + } + case []interface{}: + for _, f := range v { + if s, ok := f.(string); ok && s != "" { + urls = append(urls, s) + } + } + } + return urls +} + +// --------------------------------------------------------------------------- +// Helper methods +// --------------------------------------------------------------------------- + +// internetTruthyStrings / internetFalsyStrings mirror the case-insensitive, +// whitespace-trimmed alias sets at dialog_service.py:115-117 of +// _normalize_internet_flag. Kept in one place so a future addition (e.g. +// Python accepting "y"/"n") is a one-line change here. +var internetTruthyStrings = map[string]bool{"true": true, "1": true, "yes": true, "on": true} +var internetFalsyStrings = map[string]bool{"false": true, "0": true, "no": true, "off": true, "": true} + +// normalizeInternetFlag is the Go port of Python's +// _normalize_internet_flag (dialog_service.py:108-119). Three-state +// return matches Python: *true → explicit truthy, *false → explicit +// falsy, nil → couldn't interpret (Python's `return None`). The caller +// decides what to do with nil — _should_use_web_search treats it as +// "not enabled," so shouldUseWebSearch below only returns true when +// the normalized result is explicitly true. +// +// Accepted inputs (mirroring Python): +// - bool: returned as-is +// - int / int64 / float64 with value 0 or 1: coerced to bool +// - string (case-insensitive, trimmed): "true"/"1"/"yes"/"on" → true; +// "false"/"0"/"no"/"off"/"" → false +// - everything else (nil, slices, maps, other numeric values, +// unrecognized strings, complex, etc.) → nil +func normalizeInternetFlag(v interface{}) *bool { + switch x := v.(type) { + case bool: + return &x + case string: + s := strings.ToLower(strings.TrimSpace(x)) + if internetTruthyStrings[s] { + t := true + return &t + } + if internetFalsyStrings[s] { + f := false + return &f + } + case int: + if x == 0 || x == 1 { + b := x == 1 + return &b + } + case int64: + if x == 0 || x == 1 { + b := x == 1 + return &b + } + case float64: + if x == 0 || x == 1 { + b := x == 1 + return &b + } + } + return nil +} + +// shouldUseWebSearch returns true if web search should be enabled. +// Mirrors Python's _should_use_web_search (dialog_service.py:122-126): +// Tavily key must be present on chat.PromptConfig AND the internet +// flag must normalize to explicit true. +// +// The second parameter takes the raw internet value (typically +// kwargs["internet"] at the call site) — same shape as Python's +// `_should_use_web_search(chat.prompt_config, kwargs.get("internet"))`. +func (s *ChatPipelineService) shouldUseWebSearch(chat *entity.Chat, internet interface{}) bool { + if chat.PromptConfig == nil { + return false + } + tavilyKey, _ := chat.PromptConfig["tavily_api_key"].(string) + if tavilyKey == "" { + return false + } + normalized := normalizeInternetFlag(internet) + return normalized != nil && *normalized +} + +// tavilyRetrieve calls the Tavily API and returns results in the same chunk +// format used by performRetrieval. Mirrors Python's Tavily.retrieve_chunks() +// in rag/utils/tavily_conn.py. +func (s *ChatPipelineService) tavilyRetrieve(ctx context.Context, apiKey, question string) (map[string]interface{}, error) { + const tavilyURL = "https://api.tavily.com/search" + + body := map[string]interface{}{ + "query": question, + "search_depth": "advanced", + "max_results": 6, + } + bodyBytes, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("tavily: marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tavilyURL, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("tavily: new request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("tavily: do request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("tavily: status %d", resp.StatusCode) + } + + var tavilyResp struct { + Results []struct { + URL string `json:"url"` + Title string `json:"title"` + Content string `json:"content"` + Score float64 `json:"score"` + } `json:"results"` + } + if err := json.NewDecoder(resp.Body).Decode(&tavilyResp); err != nil { + return nil, fmt.Errorf("tavily: decode response: %w", err) + } + + chunks := make([]map[string]interface{}, 0, len(tavilyResp.Results)) + docAggs := make([]interface{}, 0, len(tavilyResp.Results)) + for _, r := range tavilyResp.Results { + id := fmt.Sprintf("tavily-%s", r.URL) + chunk := map[string]interface{}{ + "chunk_id": id, + "content_ltks": tokenizeText(r.Content), // tokenized content + "content_with_weight": r.Content, + "doc_id": id, + "docnm_kwd": r.Title, + "kb_id": []interface{}{}, + "important_kwd": []interface{}{}, + "image_id": "", + "similarity": r.Score, + "vector_similarity": 1.0, + "term_similarity": 0.0, + "vector": []float64{}, // empty; no embedding for web results + "positions": []interface{}{}, + "url": r.URL, + } + chunks = append(chunks, chunk) + docAggs = append(docAggs, map[string]interface{}{ + "doc_name": r.Title, + "doc_id": id, + "count": 1, + "url": r.URL, + }) + } + + common.Info("[Tavily] question: "+question, zap.Int("results", len(chunks))) + return map[string]interface{}{ + "chunks": chunks, + "doc_aggs": docAggs, + }, nil +} + +// tokenizeText is a lightweight tokenizer for Tavily content. +// It lowercases and splits on whitespace, similar to rag_tokenizer.tokenize. +func tokenizeText(text string) string { + // Collapse multiple whitespaces and split. + ws := regexp.MustCompile(`\s+`) + text = ws.ReplaceAllString(text, " ") + // Convert to lowercase for tokenization. + return strings.ToLower(text) +} + +// getLLMModelConfig resolves the LLM model configuration for the chat. +// Mirrors Python's three-branch resolver at dialog_service.py:552-561: +// +// if chat.llm_id: +// if "image2text" in get_model_type_by_name(...): → IMAGE2TEXT +// else: → CHAT +// else: → tenant default CHAT +// +// The returned `cfg` map's "model_type" field carries the chosen type +// so downstream code (e.g. the multimodal-conversion guard in AsyncChat +// at async_chat.go:632) can skip chat-only logic for image2text dialogs. +func (s *ChatPipelineService) getLLMModelConfig(chat *entity.Chat) (map[string]interface{}, string, string, string, error) { + if chat.LLMID == "" { + // Branch 3: no explicit LLM → tenant default chat model. + return s.buildLLMModelConfig( + s.ModelProviderSvc.GetTenantDefaultModelByType(chat.TenantID, entity.ModelTypeChat), + ) + } + + // Branches 1/2: explicit LLM. Probe model types and pick IMAGE2TEXT + // when the LLM is registered as such, otherwise CHAT. + modelType := entity.ModelTypeChat + modelTypeStr := "chat" + if modelTypes, mtErr := s.ModelProviderSvc.GetModelTypeByName(chat.TenantID, chat.LLMID); mtErr == nil { + for _, mt := range modelTypes { + if mt == entity.ModelTypeImage2Text { + modelType = entity.ModelTypeImage2Text + modelTypeStr = "image2text" + break + } + } + } + cfg, modelName, factoryName, baseURL, err := s.buildLLMModelConfig( + s.ModelProviderSvc.GetModelConfigFromProviderInstance(chat.TenantID, modelType, chat.LLMID), + ) + if err != nil { + return nil, "", "", "", err + } + cfg["model_type"] = modelTypeStr + return cfg, modelName, factoryName, baseURL, nil +} + +// buildLLMModelConfig collapses the (driver, modelName, apiConfig, +// _, err) tuple from a model-provider lookup into the dict-shaped +// config the rest of async_chat.go consumes. Default "model_type" is +// "chat"; callers that resolved a different type overwrite the key +// before returning. +func (s *ChatPipelineService) buildLLMModelConfig( + driver modelModule.ModelDriver, + modelName string, + apiConfig *modelModule.APIConfig, + maxTokens int, + err error, +) (map[string]interface{}, string, string, string, error) { + if err != nil { + return nil, "", "", "", err + } + // Match Python: llm.max_tokens if llm.max_tokens else 8192. + if maxTokens == 0 { + maxTokens = 8192 + } + cfg := map[string]interface{}{ + "model_type": "chat", + "llm_name": modelName, + "max_tokens": maxTokens, + "llm_factory": driver.Name(), + } + baseURL := "" + if apiConfig != nil && apiConfig.BaseURL != nil { + baseURL = *apiConfig.BaseURL + } + return cfg, modelName, driver.Name(), baseURL, nil +} + +// getModels resolves all models needed for the RAG pipeline. +// Mirrors Python's get_models() in dialog_service.py:340. +func (s *ChatPipelineService) getModels(ctx context.Context, chat *entity.Chat) ( + []*entity.Knowledgebase, + *modelModule.EmbeddingModel, + *modelModule.RerankModel, + *modelModule.ChatModel, + *modelModule.ChatModel, // TTS model +) { + kbDAO := dao.NewKnowledgebaseDAO() + + // Extract KB ID strings. + kbIDs := make([]string, 0, len(chat.KBIDs)) + for _, raw := range chat.KBIDs { + if id, ok := raw.(string); ok && id != "" { + kbIDs = append(kbIDs, id) + } + } + + var kbs []*entity.Knowledgebase + if len(kbIDs) > 0 { + var err error + kbs, err = kbDAO.GetByIDs(kbIDs) + if err != nil { + common.Warn("Failed to get KBs by IDs; retrieval may be incomplete", + zap.Strings("kbIDs", kbIDs), zap.Error(err)) + } + } + + // Embedding model. + var embModel *modelModule.EmbeddingModel + if len(kbs) > 0 { + // All KBs must share the same embedding model. + embdIDs := make(map[string]bool) + for _, kb := range kbs { + if kb.EmbdID != "" { + embdIDs[kb.EmbdID] = true + } + } + if len(embdIDs) > 1 { + // Multiple embedding models across KBs — error. + common.Warn("Knowledge bases use different embedding models") + } + if len(embdIDs) == 1 { + for embdID := range embdIDs { + embdTenantID := kbs[0].TenantID + driver, modelName, apiConfig, maxTokens, err := s.ModelProviderSvc.GetModelConfigFromProviderInstance( + embdTenantID, entity.ModelTypeEmbedding, embdID, + ) + if err == nil { + embModel = modelModule.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens) + } + } + } + } + + // Chat model. + driver, modelName, apiConfig, _, err := s.ModelProviderSvc.GetChatModelConfig(chat.TenantID, chat.LLMID) + var chatModel *modelModule.ChatModel + if err == nil { + chatModel = modelModule.NewChatModel(driver, &modelName, apiConfig) + } + + // Rerank model. + var rerankModel *modelModule.RerankModel + if chat.RerankID != "" { + rerankDriver, rerankName, rerankConfig, _, err := s.ModelProviderSvc.GetModelConfigFromProviderInstance( + chat.TenantID, entity.ModelTypeRerank, chat.RerankID, + ) + if err == nil { + rerankModel = modelModule.NewRerankModel(rerankDriver, &rerankName, rerankConfig) + } + } + + // TTS model. + var ttsModel *modelModule.ChatModel + if chat.PromptConfig != nil { + if useTTS, _ := chat.PromptConfig["tts"].(bool); useTTS { + ttsDriver, ttsName, ttsConfig, _, err := s.ModelProviderSvc.GetTenantDefaultModelByType( + chat.TenantID, entity.ModelTypeTTS, + ) + if err == nil { + ttsModel = modelModule.NewChatModel(ttsDriver, &ttsName, ttsConfig) + } + } + } + + return kbs, embModel, rerankModel, chatModel, ttsModel +} + +// lastUserQuestion returns the content of the most recent user message in +// `messages`, or "" if there is no user message. Used by the P2 +// meta_data_filter wiring (Python's `questions[-1]` in +// dialog_service.py:655). + +// factoryFromLLMID extracts the provider name from a composite LLM ID +// like "Qwen3-8B@ling@SILICONFLOW" → "SILICONFLOW". When the LLM ID has +// no "@provider" segment, returns "openai" as a default. The lowercase +// return value is what ConvertLastUserMsgToMultimodal / +// RenderContentPartsForFactory dispatch on. +func factoryFromLLMID(llmID string) string { + if llmID == "" { + return "openai" + } + parts := strings.Split(llmID, "@") + if len(parts) < 3 { + return "openai" + } + provider := strings.ToLower(parts[len(parts)-1]) + if provider == "" { + return "openai" + } + return provider +} + +// The handler in openai_chat.go has already rejected requests +// whose last message is not from the user, so this should always succeed. +func lastUserQuestion(messages []map[string]interface{}) string { + for i := len(messages) - 1; i >= 0; i-- { + role, _ := messages[i]["role"].(string) + if role == "user" { + if c, ok := messages[i]["content"].(string); ok { + return c + } + return "" + } + } + return "" +} + +// processFileAttachments extracts text content from file attachments. +// Mirrors Python's split_file_attachments (dialog_service.py:371-392) +// in raw=false mode: returns text attachments joined by "\n\n", +// filtering out data-URI image attachments. +// +// When files are file dicts (Python-compatible format), calls +// FileService.GetFileContents to fetch actual blobs from storage. +func (s *ChatPipelineService) processFileAttachments(files interface{}) string { + // ── File-dict mode ── + if fileDicts, ok := parseFileDicts(files); ok { + fileSvc := NewFileService() + texts, _, err := fileSvc.GetFileContents(fileDicts, false) + if err != nil { + common.Warn("GetFileContents failed in processFileAttachments", + zap.Error(err)) + return "" + } + if len(texts) == 0 { + return "" + } + return strings.Join(texts, "\n\n") + } + + // ── String fallback ── + var texts []string + switch v := files.(type) { + case []string: + for _, f := range v { + if s := strings.TrimSpace(f); s != "" && !strings.HasPrefix(s, "data:") { + texts = append(texts, s) + } + } + case []interface{}: + for _, f := range v { + if s, ok := f.(string); ok && strings.TrimSpace(s) != "" && !strings.HasPrefix(s, "data:") { + texts = append(texts, s) + } + } + } + if len(texts) == 0 { + return "" + } + return strings.Join(texts, "\n\n") +} + +// splitFileAttachments mirrors Python's `split_file_attachments` at +// dialog_service.py:371-392. It separates `messages[-1]["files"]` +// into text-file content and image attachments. +// +// Two modes of operation: +// +// 1. File-dict mode: When `files` is `[]map[string]interface{}` (each dict +// with keys "id", "created_by", "mime_type", "name"), the method calls +// FileService.GetFileContents to fetch actual file blobs from +// storage, mirroring Python's FileService.get_files(). +// +// 2. String-fallback mode: When `files` is `[]string` or `[]interface{}` of +// strings (pre-resolved content), the method does simple string splitting: +// - raw=false: split by "data:" prefix. Text → textAttachments; data: +// URIs → image files. +// - raw=true: all items go to textAttachments (Python's FileService.get_files +// with raw=True pre-separates images, so non-image content arrives here). +func splitFileAttachments(files interface{}, raw bool) (textAttachments []string, imageAttachments []string) { + // ── Mode 1: file dicts (Python-compatible) ── + if fileDicts, ok := parseFileDicts(files); ok { + fileSvc := NewFileService() + texts, images, err := fileSvc.GetFileContents(fileDicts, raw) + if err != nil { + common.Warn("GetFileContents failed, falling back to string splitting", + zap.Error(err)) + } else { + return texts, images + } + } + + // ── Mode 2: string content fallback (backward compat) ── + var texts []string + var images []string + + if raw { + // Mirrors Python raw=True: FileService.get_files already + // separated images; only non-image content arrives here. + switch v := files.(type) { + case []string: + for _, f := range v { + f = strings.TrimSpace(f) + if f != "" { + texts = append(texts, f) + } + } + case []interface{}: + for _, f := range v { + if s, ok := f.(string); ok { + s = strings.TrimSpace(s) + if s != "" { + texts = append(texts, s) + } + } + } + } + return texts, images + } + + // raw=false: split by "data:" prefix. + process := func(f string) { + f = strings.TrimSpace(f) + if f == "" { + return + } + if strings.HasPrefix(f, "data:") { + images = append(images, f) + } else { + texts = append(texts, f) + } + } + switch v := files.(type) { + case []string: + for _, f := range v { + process(f) + } + case []interface{}: + for _, f := range v { + if s, ok := f.(string); ok { + process(s) + } + } + } + return texts, images +} + +// parseFileDicts attempts to parse files as a list of file-dict maps +// (the Python-compatible format from messages[-1]["files"]). +// Returns the parsed slice and true on success. +func parseFileDicts(files interface{}) ([]map[string]interface{}, bool) { + switch v := files.(type) { + case []map[string]interface{}: + if len(v) == 0 { + return nil, false + } + // Verify the first element has a recognizable file-dict key. + if _, ok := v[0]["id"]; ok { + return v, true + } + return nil, false + case []interface{}: + if len(v) == 0 { + return nil, false + } + // Check if the first element is a map with file-dict keys. + if m, ok := v[0].(map[string]interface{}); ok { + if _, hasID := m["id"]; hasID { + result := make([]map[string]interface{}, len(v)) + for i, item := range v { + if m2, mok := item.(map[string]interface{}); mok { + result[i] = m2 + } else { + return nil, false + } + } + return result, true + } + } + } + return nil, false +} + +// cleanTTSText sanitizes text for TTS synthesis. +// Mirrors dialog_service.py:1404-1423. +func cleanTTSText(text string) string { + if text == "" { + return "" + } + // Strip control chars. + controlRe := regexp.MustCompile(`[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]`) + text = controlRe.ReplaceAllString(text, "") + // Strip emojis. + emojiRe := regexp.MustCompile("[\U0001f600-\U0001f64f\U0001f300-\U0001f5ff\U0001f680-\U0001f6ff\U0001f1e0-\U0001f1ff\U00002700-\U000027bf\U0001f900-\U0001f9ff\U0001fa70-\U0001faff\U0001fad0-\U0001faff]+") + text = emojiRe.ReplaceAllString(text, "") + // Collapse whitespace. + wsRe := regexp.MustCompile(`\s+`) + text = wsRe.ReplaceAllString(text, " ") + text = strings.TrimSpace(text) + if len(text) > 500 { + text = text[:500] + } + return text +} + +// synthesizeTTS calls the TTS model to convert text to audio. +// Mirrors dialog_service.py:1426-1432. +func (s *ChatPipelineService) synthesizeTTS(ttsModel *modelModule.ChatModel, text string) interface{} { + if ttsModel == nil || text == "" { + return nil + } + text = cleanTTSText(text) + if text == "" { + return nil + } + ttsResp, err := ttsModel.ModelDriver.AudioSpeech( + ttsModel.ModelName, &text, ttsModel.APIConfig, &modelModule.TTSConfig{Format: "mp3"}, + ) + if err != nil { + common.Warn("TTS synthesis failed", zap.Error(err)) + return nil + } + if ttsResp == nil || len(ttsResp.Audio) == 0 { + return nil + } + return ttsResp.Audio +} + +// truncateForLog returns at most n characters of s, appending an +// ellipsis when truncated. Used to keep zap log lines bounded. +func truncateForLog(s string, n int) string { + if n <= 0 || len(s) <= n { + return s + } + return s[:n] + "..." +} + +// resolveReferenceMetadata mirrors Python's +// `resolve_reference_metadata_preferences` in +// api/utils/reference_metadata_utils.py:22-62. Returns (include, +// fields). The Python algorithm: +// +// resolved = {**config["reference_metadata"], **request["reference_metadata"]} +// if "include_metadata" in request: resolved["include"] = ... +// if "metadata_fields" in request: resolved["fields"] = ... +// include = bool(resolved.get("include", False)) +// fields = resolved.get("fields") → list of strings +// +// Config is `promptConfig["reference_metadata"]`, request is `kwargs`. +// `kwargs` takes precedence for both `include_metadata` (legacy) and +// `metadata_fields` (legacy), and for the entire `reference_metadata` +// sub-dict (preferred). +func (s *ChatPipelineService) resolveReferenceMetadata(promptConfig map[string]interface{}, kwargs map[string]interface{}) (bool, []string) { + resolved := map[string]interface{}{} + + // Layer 1: prompt_config["reference_metadata"] (config). + if promptConfig != nil { + if cfgRef, ok := promptConfig["reference_metadata"].(map[string]interface{}); ok { + for k, v := range cfgRef { + resolved[k] = v + } + } + } + // Layer 2: kwargs["reference_metadata"] (request, takes precedence). + if kwargs != nil { + if reqRef, ok := kwargs["reference_metadata"].(map[string]interface{}); ok { + for k, v := range reqRef { + resolved[k] = v + } + } + // Layer 3: legacy request keys (kwargs). + if v, ok := kwargs["include_metadata"]; ok { + if b, ok := v.(bool); ok { + resolved["include"] = b + } + } + if v, ok := kwargs["metadata_fields"]; ok { + resolved["fields"] = v + } + } + + include, _ := resolved["include"].(bool) + if !include { + return false, nil + } + rawFields, ok := resolved["fields"] + if !ok || rawFields == nil { + return true, nil + } + switch v := rawFields.(type) { + case []string: + return true, v + case []interface{}: + out := make([]string, 0, len(v)) + for _, item := range v { + if s, ok := item.(string); ok { + out = append(out, s) + } + } + return true, out + } + return true, nil +} + +// enrichChunksWithMetadata enriches chunk records in kbinfos with document-level +// metadata. Mirrors Python's enrich_chunks_with_document_metadata() in +// api/utils/reference_metadata_utils.py. +func (s *ChatPipelineService) enrichChunksWithMetadata(kbinfos map[string]interface{}, tenantID string, fields []string) { + chunksRaw, ok := kbinfos["chunks"].([]map[string]interface{}) + if !ok || len(chunksRaw) == 0 { + return + } + + chunks := make([]map[string]interface{}, 0, len(chunksRaw)) + chunks = append(chunks, chunksRaw...) + if len(chunks) == 0 { + return + } + + s.MetadataSvc.EnrichChunksWithDocMetadata(chunks, tenantID, fields) +} + +// kbPrompt builds knowledge prompt blocks from retrieved chunks. +// Mirrors Python's kb_prompt() in rag/prompts/generator.py. +func (s *ChatPipelineService) kbPrompt(kbinfos map[string]interface{}, maxTokens int) []string { + chunksRaw, ok := kbinfos["chunks"].([]map[string]interface{}) + if !ok || len(chunksRaw) == 0 { + return nil + } + + // Pass 1: count content tokens to determine how many chunks fit. + type chunkContent struct { + content string + } + contents := make([]chunkContent, 0, len(chunksRaw)) + for _, ck := range chunksRaw { + c := getMapString(ck, "content", "content_with_weight") + if c == "" { + continue + } + contents = append(contents, chunkContent{content: c}) + } + + usedTokenCount := 0 + chunksNum := 0 + for _, cc := range contents { + usedTokenCount += kg.NumTokensFromString(cc.content) + chunksNum++ + if float64(maxTokens)*0.97 < float64(usedTokenCount) { + common.Warn("Not all the retrieval into prompt", + zap.Int("kept", chunksNum), + zap.Int("total", len(contents))) + break + } + } + + // Pass 2: format chunks with tree structure, capped at chunksNum. + if chunksNum > len(chunksRaw) { + chunksNum = len(chunksRaw) + } + var result []string + for i := 0; i < chunksNum; i++ { + ck := chunksRaw[i] + c := getMapString(ck, "content", "content_with_weight") + if c == "" { + continue + } + + cnt := fmt.Sprintf("\nID: %d", i) + cnt += drawNode("Title", getMapString(ck, "docnm_kwd", "document_name")) + cnt += drawNode("URL", getMapString(ck, "url")) + if meta, ok := ck["document_metadata"].(map[string]interface{}); ok { + for k, v := range meta { + cnt += drawNode(k, v) + } + } + cnt += "\n└── Content:\n" + cnt += c + result = append(result, cnt) + } + + return result +} + +// formatPrompt substitutes {key} placeholders in a prompt string. +func (s *ChatPipelineService) formatPrompt(template string, kwargs map[string]interface{}) string { + result := template + for key, value := range kwargs { + placeholder := "{" + key + "}" + if strings.Contains(result, placeholder) { + strVal := fmt.Sprintf("%v", value) + result = strings.ReplaceAll(result, placeholder, strVal) + } + } + // Replace any remaining {unknown} placeholders with empty string. + for _, key := range []string{"knowledge", "quote"} { + placeholder := "{" + key + "}" + if strings.Contains(result, placeholder) { + result = strings.ReplaceAll(result, placeholder, " ") + } + } + return result +} + +// messageFitIn trims messages to fit within a token budget. +// Mirrors Python's message_fit_in() in rag/prompts/generator.py. +func (s *ChatPipelineService) messageFitIn(messages []map[string]interface{}, maxTokens int) (int, []map[string]interface{}) { + totalTokens := 0 + var result []map[string]interface{} + + // Always keep the system message first. + if len(messages) > 0 { + if role, _ := messages[0]["role"].(string); role == "system" { + if content, ok := messages[0]["content"].(string); ok { + sysTokens := kg.NumTokensFromString(content) + if sysTokens <= maxTokens { + totalTokens = sysTokens + result = append(result, messages[0]) + } + } + } + } + + // Add user/assistant messages from the end, working backwards. + rest := messages[1:] + for i := len(rest) - 1; i >= 0; i-- { + m := rest[i] + content, _ := m["content"].(string) + tokens := kg.NumTokensFromString(content) + if totalTokens+tokens > maxTokens { + break + } + totalTokens += tokens + // Prepend to maintain order. + result = append([]map[string]interface{}{m}, result...) + } + + return totalTokens, result +} + +// buildChatMessages converts the internal message representation to +// modelModule.Message for the driver. +func (s *ChatPipelineService) buildChatMessages(systemContent string, messages []map[string]interface{}) []modelModule.Message { + var result []modelModule.Message + if systemContent != "" { + result = append(result, modelModule.Message{Role: "system", Content: systemContent}) + } + for _, m := range messages { + role, _ := m["role"].(string) + content := m["content"] + if role == "" || content == nil { + continue + } + result = append(result, modelModule.Message{Role: role, Content: content}) + } + return result +} + +// buildChatDriver creates a ChatModel wrapper from the chat. +func (s *ChatPipelineService) buildChatDriver(chat *entity.Chat, chatModel *modelModule.ChatModel) *modelModule.ChatModel { + if chatModel != nil { + return chatModel + } + driver, modelName, apiConfig, _, err := s.ModelProviderSvc.GetChatModelConfig(chat.TenantID, chat.LLMID) + if err != nil { + return nil + } + return modelModule.NewChatModel(driver, &modelName, apiConfig) +} + +// HydrateChunkVectors fills the `vector` field on each chunk in `kbinfos` +// that lacks one, by issuing a single batched fetch via +// RetrievalService.FetchChunkVectors. Mirrors Python's +// `async_chat._hydrate_chunk_vectors` at +// api/db/services/dialog_service.py:62-106. +// +// The vector dimension is auto-detected from chunks that already carry a +// vector. If no chunk has a vector yet, no fetch is attempted (returns 0). +// +// Returns the number of chunks that gained a vector. +// +// Skips: +// - chunks that already have a non-empty `vector` +// - chunks without a `chunk_id` +// +// Errors are non-fatal: caller logs and proceeds with whatever vectors +// are available. InsertCitations tolerates missing vectors by falling +// back to token-only similarity (when the chat's +// vector_similarity_weight allows). +// +// Parameters: +// - tenantIDs: tenant ID(s) to derive index/table names (ragflow_). +// If empty, no fetch is attempted. +func HydrateChunkVectors(ctx context.Context, kbinfos map[string]interface{}, tenantIDs []string, kbIDs []string, docEngine engine.DocEngine) (int, error) { + if kbinfos == nil { + return 0, nil + } + chunksRaw, ok := kbinfos["chunks"].([]map[string]interface{}) + if !ok || len(chunksRaw) == 0 { + return 0, nil + } + if docEngine == nil { + docEngine = engine.Get() + } + if docEngine == nil { + return 0, nil + } + + // Auto-detect vector dimension from chunks that already carry a + // vector. If none do, there is nothing to hydrate against. + var dim int + var missing []string + for _, cm := range chunksRaw { + if cv, ok := cm["vector"].([]float64); ok && len(cv) > 0 { + if dim == 0 { + dim = len(cv) + } + continue + } + if cid, ok := cm["chunk_id"].(string); ok && cid != "" { + missing = append(missing, cid) + } + } + if len(missing) == 0 || dim == 0 || len(tenantIDs) == 0 { + return 0, nil + } + + // Use RetrievalService which mirrors Python's Dealer.fetch_chunk_vectors. + retrievalSvc := nlp.NewRetrievalService(docEngine, dao.NewDocumentDAO()) + vectors, err := retrievalSvc.FetchChunkVectors(ctx, missing, tenantIDs, kbIDs, dim) + if err != nil { + common.Warn("HydrateChunkVectors: FetchChunkVectors failed", zap.Error(err)) + return 0, err + } + + // Stitch the vectors back onto the chunks. + hits := 0 + for _, cm := range chunksRaw { + if cv, ok := cm["vector"].([]float64); ok && len(cv) > 0 { + continue + } + cid, _ := cm["chunk_id"].(string) + if cid == "" { + continue + } + vec, ok := vectors[cid] + if !ok || len(vec) == 0 { + continue + } + cm["vector"] = vec + hits++ + } + common.Debug("HydrateChunkVectors complete", + zap.Int("hits", hits), zap.Int("requested", len(missing))) + return hits, nil +} + +// embeddingModelEmbedder adapts an EmbeddingModel to the Embedder interface. +type embeddingModelEmbedder struct { + embModel *modelModule.EmbeddingModel +} + +func (e *embeddingModelEmbedder) Encode(texts []string) ([][]float64, error) { + config := &modelModule.EmbeddingConfig{Dimension: 0} + embeds, err := e.embModel.ModelDriver.Embed(e.embModel.ModelName, texts, e.embModel.APIConfig, config) + if err != nil { + return nil, err + } + vecs := make([][]float64, len(embeds)) + for i, v := range embeds { + vecs[i] = v.Embedding + } + return vecs, nil +} + +// decorateAnswer applies citation insertion, reference construction, +// timing stats, token accounting, TTS, and Langfuse generation end to +// the final answer. +// +// P1: the `timer` parameter carries the per-phase durations emitted in the +// `## Time elapsed:` block of the prompt. Caller must have called +// timer.Exit() for PhaseGenerateAnswer before invoking this function. +func (s *ChatPipelineService) decorateAnswer( + ctx context.Context, + answer string, + kbinfos map[string]interface{}, + prompt string, + questions []string, + usedTokenCount int, + timer *common.Timer, + embModel *modelModule.EmbeddingModel, + vectorSimilarityWeight float64, + quote bool, + ttsModel *modelModule.ChatModel, + langfuseTraceID string, + llmModelConfig map[string]interface{}, + tenantID string, + tenantIDs []string, + hasKnowledges bool, +) AsyncChatResult { + + // Handle think markers: split on . + think := "" + ans := answer + if strings.Contains(answer, "") { + parts := strings.Split(answer, "") + if len(parts) == 2 { + think = parts[0] + "" + ans = strings.TrimSpace(parts[1]) + } + } + + var citationIdx map[int]struct{} + var refs map[string]interface{} + // Citation insertion: encode answer sentences, score against chunks, + // and insert [ID:N] markers. Mirrors Python's insert_citations(). + // + // P0.11 (CITATION_MARKER_PATTERN pre-check): if the LLM already emitted + // citation markers in canonical or near-canonical form, skip + // insertCitations to avoid double-tagging. Mirrors + // dialog_service.py:790-802. + if hasKnowledges && quote { + chunksRaw, ok := kbinfos["chunks"].([]map[string]interface{}) + if ok && len(chunksRaw) > 0 { + // P7 — _hydrate_chunk_vectors. Mirrors + // dialog_service.py:794. If any chunk lacks a `vector` + // field (true for the ES path; Infinity ships vectors + // inline), fetch them in one batched engine call. We only + // need this when we'll actually call insertCitations + // (i.e., the LLM didn't already emit markers). + if embModel != nil && !HasCitationMarkers(ans) { + if _, err := HydrateChunkVectors(ctx, kbinfos, tenantIDs, nil, engine.Get()); err != nil { + common.Warn("hydrate chunk vectors failed", zap.Error(err)) + } + } + if embModel != nil && !HasCitationMarkers(ans) { + // Build chunkVectors aligned with chunksRaw. + chunkVectors := make([][]float64, len(chunksRaw)) + allVec := len(chunksRaw) > 0 + for i, cm := range chunksRaw { + cv, _ := cm["vector"].([]float64) + chunkVectors[i] = cv + if len(cv) == 0 { + allVec = false + } + } + if allVec { + embedder := &embeddingModelEmbedder{embModel: embModel} + if decorated, cited := InsertCitations(ans, NewSourcedChunks(chunksRaw), embedder, chunkVectors); len(cited) > 0 { + ans = decorated + citationIdx = make(map[int]struct{}) + for _, ci := range cited { + citationIdx[ci] = struct{}{} + } + } + } + } else { + // P0.11 pre-check matched: collect indices from existing + // markers instead of calling insertCitations. + for _, ci := range ExtractCitationMarkers(ans, len(chunksRaw)) { + if citationIdx == nil { + citationIdx = make(map[int]struct{}) + } + citationIdx[ci] = struct{}{} + } + } + } + + // repair_bad_citation_formats — runs even when chunks are empty. + // Mirrors dialog_service.py:818. + if ok { + ans = RepairBadCitationFormats(ans) + for _, ci := range ExtractCitationMarkers(ans, len(chunksRaw)) { + if citationIdx == nil { + citationIdx = make(map[int]struct{}) + } + citationIdx[ci] = struct{}{} + } + } + + // Map cited chunk indices to doc_ids and filter doc_aggs. + // Mirrors dialog_service.py:820-824. + if len(citationIdx) > 0 { + citedDocIDs := make(map[string]struct{}) + if chunksRaw, ok := kbinfos["chunks"].([]map[string]interface{}); ok { + for ci := range citationIdx { + if ci >= 0 && ci < len(chunksRaw) { + cm := chunksRaw[ci] + if docID, ok := cm["doc_id"].(string); ok && docID != "" { + citedDocIDs[docID] = struct{}{} + } + } + } + } + if len(citedDocIDs) > 0 { + if docAggsRaw, ok := kbinfos["doc_aggs"].([]interface{}); ok && len(docAggsRaw) > 0 { + var filtered []interface{} + for _, da := range docAggsRaw { + if dam, ok := da.(map[string]interface{}); ok { + if docID, ok := dam["doc_id"].(string); ok { + if _, cited := citedDocIDs[docID]; cited { + filtered = append(filtered, da) + } + } + } + } + if len(filtered) > 0 { + kbinfos["doc_aggs"] = filtered + } + } + } + } + } + + // Build refs: deepcopy kbinfos and strip vectors — done whenever + // hasKnowledges is true, regardless of quote flag. + // Mirrors dialog_service.py:826-829. + if hasKnowledges { + refs = make(map[string]interface{}) + for k, v := range kbinfos { + refs[k] = v + } + if chunksRaw, ok := refs["chunks"].([]map[string]interface{}); ok { + newChunks := make([]map[string]interface{}, 0, len(chunksRaw)) + for _, cm := range chunksRaw { + newChunk := make(map[string]interface{}) + for ck, cv := range cm { + if ck == "vector" { + continue + } + newChunk[ck] = cv + } + newChunks = append(newChunks, newChunk) + } + refs["chunks"] = newChunks + } + } + + // Check for invalid API key errors (outside knowledges guard). + // Mirrors dialog_service.py:831-832. + if strings.Contains(strings.ToLower(ans), "invalid key") || + strings.Contains(strings.ToLower(ans), "invalid api") { + ans += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'" + } + + finishChatTs := time.Now() + + // Build timing stats. + // P1: emit Timer.Markdown() (6 phase lines + Total) and then the + // token-count / token-speed lines that the existing OpenAI endpoint + // already exposes. Total wall-clock is rounded to ms. + totalMs := timer.Total().Seconds() * 1000 + tkNum := kg.NumTokensFromString(think + ans) + + prompt += fmt.Sprintf("\n\n### Query:\n%s", strings.Join(questions, " ")) + + timeStats := prompt + timer.Markdown() + "\n" + timeStats += fmt.Sprintf(" - Generated tokens(approximately): %d\n", tkNum) + if totalMs > 0 { + timeStats += fmt.Sprintf(" - Token speed: %d/s", int(float64(tkNum)/(totalMs/1000.0))) + } + + // TTS synthesis for the final answer. + audioBinary := s.synthesizeTTS(ttsModel, think+ans) + + // Langfuse generation end observation. + if langfuseTraceID != "" { + if lfClient, ok := ctx.Value(langfuseCtxKey).(*LangfuseClient); ok && lfClient != nil { + // Mirrors dialog_service.py:853-854. Python extracts + // everything from `### Query:` onwards (the time-elapsed + // + token-usage block) and replaces \n with " \n" for + // markdown line breaks. + langfuseOutput := langfuseExtractTimeElapsed(timeStats) + usage := &LangfuseUsage{ + PromptTokens: usedTokenCount, + CompletionTokens: tkNum, + TotalTokens: usedTokenCount + tkNum, + } + modelName := "" + if llmModelConfig != nil { + if mn, ok := llmModelConfig["llm_name"].(string); ok { + modelName = mn + } + } + _ = lfClient.PostGeneration(ctx, LangfuseGeneration{ + ID: fmt.Sprintf("gen-%s", langfuseTraceID), + TraceID: langfuseTraceID, + Name: "chat", + Model: modelName, + StartTime: time.Now().UTC().Format(time.RFC3339Nano), + EndTime: time.Now().UTC().Format(time.RFC3339Nano), + Output: langfuseOutput, + Usage: usage, + }) + } + } + + return AsyncChatResult{ + Answer: think + ans, + Reference: refs, + AudioBinary: audioBinary, + // Fix 7: Apply the markdown line-break substitution + // re.sub(r"\n", " \n", prompt) at the very end, matching + // dialog_service.py:865. This converts single \n to " \n" + // so multi-line prompt text renders as a single markdown + // paragraph instead of being broken into separate lines. + Prompt: strings.ReplaceAll(timeStats, "\n", " \n"), + CreatedAt: float64(finishChatTs.Unix()), + Final: false, // caller sets Final = true + } +} + +// langfuseExtractTimeElapsed extracts the time-elapsed + token-usage +// block from the prompt and applies the \n → " \n" substitution. +// Mirrors dialog_service.py:853-854: +// +// langfuse_output = "\n" + re.sub(r"^.*?(### Query:.*)", r"\1", prompt, flags=re.DOTALL) +// langfuse_output = {"time_elapsed:": re.sub(r"\n", " \n", langfuse_output), ...} +func langfuseExtractTimeElapsed(prompt string) string { + const marker = "### Query:" + idx := strings.Index(prompt, marker) + if idx < 0 { + // Fallback: return the whole prompt with \n substitution. + return strings.ReplaceAll(prompt, "\n", " \n") + } + return strings.ReplaceAll(prompt[idx:], "\n", " \n") +} + +// extractVisibleAnswer mirrors Python's _extract_visible_answer. +// It preserves wrappers and strips stray think tags. +func (s *ChatPipelineService) extractVisibleAnswer(text string) string { + if !strings.Contains(text, "") { + text = strings.ReplaceAll(text, "", "") + text = strings.ReplaceAll(text, "", "") + return text + } + idx := strings.LastIndex(text, "") + thought := text[:idx] + answer := text[idx+len(""):] + thought = strings.ReplaceAll(thought, "", "") + thought = strings.ReplaceAll(thought, "", "") + thought = strings.TrimSpace(thought) + answer = strings.ReplaceAll(answer, "", "") + answer = strings.ReplaceAll(answer, "", "") + if thought == "" { + return answer + } + return "" + thought + "" + answer +} + +// citationPrompt returns the citation instruction prompt. +// Mirrors Python's citation_prompt() in rag/prompts/generator.py. +func citationPrompt() string { + return "\n\n### Citation\nWhen answering, please cite sources using the format [ID:N] " + + "(where N is the chunk number) after each sentence where the information from that chunk is used." +} + +// --------------------------------------------------------------------------- +// Think-marker streaming — mirrors Python's _stream_with_think_delta. +// --------------------------------------------------------------------------- + +// thinkStreamState tracks accumulated reasoning text and emits deltas. +type thinkStreamState struct { + fullText string + lastIdx int + endsWithThink bool + inThink bool + buffer string + postThinkText string +} + +// nextThinkDelta computes the next delta to emit from the accumulated text. +// Mirrors _next_think_delta in dialog_service.py:1460-1487. +func nextThinkDelta(state *thinkStreamState) string { + full := state.fullText + if full == "" || len(full) <= state.lastIdx { + return "" + } + delta := full[state.lastIdx:] + + if strings.HasPrefix(delta, "") { + state.lastIdx += len("") + return "" + } + if idx := strings.Index(delta, ""); idx > 0 { + state.lastIdx += idx + return delta[:idx] + } + if strings.HasSuffix(delta, "") { + state.endsWithThink = true + } else if state.endsWithThink { + state.endsWithThink = false + remainder := delta + if idx := strings.Index(delta, ""); idx >= 0 { + remainder = delta[idx+len(""):] + } + if remainder != "" { + state.postThinkText = remainder + } + state.lastIdx = len(full) + return "" + } + + state.lastIdx = len(full) + if strings.HasSuffix(full, "") { + state.lastIdx -= len("") + } + return strings.ReplaceAll(strings.ReplaceAll(delta, "", ""), "", "") +} + +// processThinkDelta updates the state with a new delta and returns what to emit. +// Returns the kind of emission: "marker" for think tags, "text" for content, "" for nothing. +func processThinkDelta(state *thinkStreamState, delta string, minTokens int) (kind string, output string) { + if delta == "" { + return "", "" + } + state.fullText += delta + d := nextThinkDelta(state) + if d == "" { + return "", "" + } + if d == "" { + if state.inThink { + return "", "" + } + if state.buffer != "" { + kind, out := "text", state.buffer + state.buffer = "" + state.inThink = true + return kind, out + } + state.inThink = true + return "marker", "" + } + if d == "" { + if !state.inThink { + return "", "" + } + state.inThink = false + if state.postThinkText != "" { + state.buffer += state.postThinkText + state.postThinkText = "" + } + return "marker", "" + } + state.buffer += d + if kg.NumTokensFromString(state.buffer) < minTokens { + return "", "" + } + out := state.buffer + state.buffer = "" + return "text", out +} + +// flushThinkStream flushes any remaining buffered text from the think stream. +func flushThinkStream(state *thinkStreamState) (text string, marker string) { + if state.buffer != "" { + text = state.buffer + state.buffer = "" + } + if state.postThinkText != "" { + if text != "" { + text += state.postThinkText + } else { + text = state.postThinkText + } + state.postThinkText = "" + } + if state.endsWithThink { + marker = "" + state.endsWithThink = false + } + return text, marker +} + +// ----------------------------------------------------------------------- +// Moved from sql_fallback.go (2026-06-12). SQL retrieval system, repair +// helpers, and Python parity helpers. Kept in async_chat.go because the +// orchestrator entry point is s.useSQL at async_chat.go:319. +// ----------------------------------------------------------------------- + +// SQL retrieval system + user prompts, dispatched by engine type. +// Mirrors dialog_service.py:1031-1105. The Go port previously used a +// single engine-agnostic prompt, which made Infinity/OceanBase queries +// fail because the LLM didn't know to use json_extract_string. These +// constants restore parity with Python's three-way engine dispatch. + +// infinitySQLSysPrompt is for Infinity's JSON 'chunk_data' column. +// References docnm (no _kwd suffix) per the Python prompt at +// dialog_service.py:1035-1052. +const infinitySQLSysPrompt = `You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column. + +JSON Extraction: json_extract_string(chunk_data, '$.FieldName') +Numeric Cast: CAST(json_extract_string(chunk_data, '$.FieldName') AS INTEGER/FLOAT) +NULL Check: json_extract_isnull(chunk_data, '$.FieldName') == false + +RULES: +1. Use EXACT field names (case-sensitive) from the list below +2. For SELECT: include doc_id, docnm, and json_extract_string() for requested fields +3. For COUNT: use COUNT(*) or COUNT(DISTINCT json_extract_string(...)) +4. Add AS alias for extracted field names +5. DO NOT select 'content' field +6. Only add NULL check (json_extract_isnull() == false) in WHERE clause when: + - Question asks to "show me" or "display" specific columns + - Question mentions "not null" or "excluding null" + - Add NULL check for count specific column + - DO NOT add NULL check for COUNT(*) queries (COUNT(*) counts all rows including nulls) +7. json_extract_string() returns JSON-quoted strings ("value"), so WHERE comparisons MUST wrap values in double-quotes inside single-quotes (no spaces between quotes): '"value"' (e.g. WHERE json_extract_string(chunk_data, '$.name') = '"Alice"') +8. For partial text search, use LIKE with wildcards: '"%value%"' (e.g. WHERE json_extract_string(chunk_data, '$.name') LIKE '"%Alice%"') +9. Output ONLY the SQL, no explanations` + +// infinitySQLUserPromptTemplate has 4 %s placeholders: +// table_name, comma-joined field names, bullet list of field names, +// question. Mirrors dialog_service.py:1053-1059. +const infinitySQLUserPromptTemplate = `Table: %s +Fields (EXACT case): %s +%s +Question: %s +Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.` + +// oceanbaseSQLSysPrompt is identical to Infinity's but uses docnm_kwd +// (the _kwd suffix is the OceanBase convention). Mirrors +// dialog_service.py:1064-1081. +const oceanbaseSQLSysPrompt = `You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column. + +JSON Extraction: json_extract_string(chunk_data, '$.FieldName') +Numeric Cast: CAST(json_extract_string(chunk_data, '$.FieldName') AS INTEGER/FLOAT) +NULL Check: json_extract_isnull(chunk_data, '$.FieldName') == false + +RULES: +1. Use EXACT field names (case-sensitive) from the list below +2. For SELECT: include doc_id, docnm_kwd, and json_extract_string() for requested fields +3. For COUNT: use COUNT(*) or COUNT(DISTINCT json_extract_string(...)) +4. Add AS alias for extracted field names +5. DO NOT select 'content' field +6. Only add NULL check (json_extract_isnull() == false) in WHERE clause when: + - Question asks to "show me" or "display" specific columns + - Question mentions "not null" or "excluding null" + - Add NULL check for count specific column + - DO NOT add NULL check for COUNT(*) queries (COUNT(*) counts all rows including nulls) +7. Output ONLY the SQL, no explanations` + +// oceanbaseSQLUserPromptTemplate — same shape as Infinity, docnm_kwd in +// the trailing sentence. Mirrors dialog_service.py:1082-1088. +const oceanbaseSQLUserPromptTemplate = `Table: %s +Fields (EXACT case): %s +%s +Question: %s +Write SQL using json_extract_string() with exact field names. Include doc_id, docnm_kwd for data queries. Only SQL.` + +// esSQLSysPrompt is for Elasticsearch / OpenSearch / default engines +// where fields are direct columns (no JSON extraction). Mirrors +// dialog_service.py:1092-1100. +const esSQLSysPrompt = `You are a Database Administrator. Write SQL queries. + +RULES: +1. Use EXACT field names from the schema below (e.g., product_tks, not product) +2. Quote field names starting with digit: "123_field" +3. Add IS NOT NULL in WHERE clause when: + - Question asks to "show me" or "display" specific columns +4. Include doc_id/docnm in non-aggregate statement +5. Output ONLY the SQL, no explanations` + +// esSQLUserPromptTemplate — 3 %s placeholders: table_name, bullet +// list with types, question. Mirrors dialog_service.py:1101-1105. +const esSQLUserPromptTemplate = `Table: %s +Available fields: +%s +Question: %s +Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.` + +// SQL retrieval repair prompts, split into TWO flows × TWO engine +// families, mirroring dialog_service.py:repair_table_for_missing_source_columns +// (lines 1129-1156) and the execution-error retry at lines 1164-1205. +// +// The previous single-template repair was generic and could not tell +// the LLM to keep using json_extract_string on Infinity, which led +// to fragile repairs. Per-engine prompts make the syntax intent +// explicit on the repair path too. +// +// Flow A (missing-source-columns): the SQL executed successfully but +// the result set is missing doc_id / docnm* columns. We call the LLM +// to rewrite the SQL with those columns added. +// Flow B (execution-error): the SQL failed to execute at all +// (syntax error, unknown column, etc.). We call the LLM with the +// error message and ask for a corrected SQL. +// +// Engine family A (Infinity / OceanBase): data lives in a JSON +// 'chunk_data' column, so JSON-extraction syntax must be preserved. +// Engine family B (Elasticsearch / OpenSearch / default): fields +// are direct columns. + +// infinityMissingColumnsRepairPromptTemplate — 5 %s args: +// table_name, JSON field bullets, question, previous_sql, +// expected_doc_name_column. Mirrors dialog_service.py:1132-1143. +// OceanBase shares this template (line 1130 dispatch) with +// expected_doc_name_column="docnm_kwd" instead of "docnm". +const infinityMissingColumnsRepairPromptTemplate = `Table name: %s; +JSON fields available in 'chunk_data' column (use exact names): +%s + +Question: %s +Previous SQL: +%s + +The previous SQL result is missing required source columns for citations. +Rewrite SQL to keep the same query intent and include doc_id and %s in the SELECT list. +For extracted JSON fields, use json_extract_string(chunk_data, '$.field_name'). +Return ONLY SQL.` + +// esMissingColumnsRepairPromptTemplate — 4 %s args: table_name, +// ES field bullets (with types), question, previous_sql. Mirrors +// dialog_service.py:1145-1155. +const esMissingColumnsRepairPromptTemplate = `Table name: %s +Available fields: +%s + +Question: %s +Previous SQL: +%s + +The previous SQL result is missing required source columns for citations. +Rewrite SQL to keep the same query intent and include doc_id and docnm_kwd in the SELECT list. +Return ONLY SQL.` + +// infinityExecutionErrorRepairPromptTemplate — 4 %s args: +// table_name, JSON field bullets, question, error. Mirrors +// dialog_service.py:1168-1181. Used for both Infinity and OceanBase +// (line 1165 dispatch). +const infinityExecutionErrorRepairPromptTemplate = ` +Table name: %s; +JSON fields available in 'chunk_data' column (use these exact names in json_extract_string): +%s + +Question: %s +Please write the SQL using json_extract_string(chunk_data, '$.field_name') with the field names from the list above. Only SQL, no explanations. + + +The SQL error you provided last time is as follows: +%s + +Please correct the error and write SQL again using json_extract_string(chunk_data, '$.field_name') syntax with the correct field names. Only SQL, no explanations.` + +// esExecutionErrorRepairPromptTemplate — 4 %s args: table_name, +// ES field bullets (with types), question, error. Mirrors +// dialog_service.py:1184-1198. +const esExecutionErrorRepairPromptTemplate = ` +Table name: %s; +Table of database fields are as follows (use the field names directly in SQL): +%s + +Question are as follows: +%s +Please write the SQL using the exact field names above, only SQL, without any other explanations or text. + + +The SQL error you provided last time is as follows: +%s + +Please correct the error and write SQL again using the exact field names above, only SQL, without any other explanations or text.` + +// useSQL is the Go port of dialog_service.use_sql +// (api/db/services/dialog_service.py:914-1226). It branches on the +// active document engine, asks the chat model to produce SQL, +// optionally repairs it once, and executes the query. +// +// The caller is responsible for resolving fieldMap (typically via +// s.KbService.GetFieldMap in AsyncChat before invoking this) so the +// structured-schema lookup happens once per request and is observable +// in logs at the AsyncChat call site. Pass nil/empty to short-circuit. +// +// Returns: +// +// - ans: a map mirroring the Python use_sql return shape: +// +// {"answer": , "reference": {"chunks": [], "doc_aggs": [], "total": }} +// +// or nil when SQL retrieval doesn't apply / produced no usable +// result. The caller checks `ans != nil && (ans["answer"] != "" +// or non-empty chunks)` to decide whether to short-circuit. +// +// - err: non-nil when something went wrong; caller should log and fall +// through. +func (s *ChatPipelineService) useSQL( + ctx context.Context, + chat *entity.Chat, + kbs []*entity.Knowledgebase, + question string, + chatModel *modelModule.ChatModel, + fieldMap map[string]interface{}, + quote bool, +) (ans map[string]interface{}, err error) { + if chat == nil || chatModel == nil || len(kbs) == 0 { + return nil, nil + } + + if fieldMap == nil || len(fieldMap) == 0 { + // No structured schema → SQL retrieval doesn't apply. + return nil, nil + } + + docEngine := engine.Get() + if docEngine == nil { + return nil, nil + } + + // Entry log. Mirrors `logging.debug(f"use_sql: Question: {question}")` + // at dialog_service.py:934. + common.Debug("SQL retrieval: question", zap.String("question", question)) + + // Build the table name. Infinity: ragflow_{tenant}_{kb_id} (one per + // KB). ES: ragflow_{tenant} (kb_id in WHERE). + tableName := ragflowTableName(chat.TenantID, kbs, docEngine) + + // Build engine-specific prompts. Mirrors the three-way dispatch + // at dialog_service.py:1031-1105. + engineName := docEngine.GetType() + sysPrompt, userPrompt, overrideSQL := buildSQLPrompts(engineName, tableName, question, fieldMap) + + // Step 1: generate SQL. If the question is a "how many rows in the + // dataset" row-count question, buildSQLPrompts returns a hard-coded + // override and we skip the LLM call entirely (matches Python + // row_count_override at dialog_service.py:1034/1063). + var sqlText string + if overrideSQL != "" { + sqlText = normalizeSQL(overrideSQL) + common.Debug("SQL retrieval: using row-count override", + zap.String("sql", sqlText)) + } else { + var sqlErr error + sqlText, sqlErr = generateSQL(ctx, chatModel, sysPrompt, userPrompt) + if sqlErr != nil { + common.Warn("SQL retrieval: LLM generation failed", zap.Error(sqlErr)) + return nil, nil + } + } + + // Step 1.5: inject the kb_id WHERE filter for ES / OS / OceanBase. + // No-op for Infinity (the table name already encodes the KB scope). + // Mirrors add_kb_filter at dialog_service.py:992-1021, called from + // get_table right after normalize_sql. + if filtered, ok := addKBFilter(sqlText, engineName, kbs); ok { + sqlText = filtered + } else { + common.Warn("SQL retrieval: invalid kb_id UUID; SQL will run unfiltered") + } + + // Step 2: try to execute. On failure, repair once with the + // engine-specific execution-error prompt so the LLM regenerates + // correctly (Flow B at dialog_service.py:1164-1205). + rows, execErr := docEngine.RunSQL(ctx, tableName, sqlText, kbIDStrings(kbs), "json") + if execErr != nil { + common.Debug("SQL retrieval: initial execution failed, attempting repair", + zap.String("sql", sqlText), zap.Error(execErr)) + repaired, repairErr := repairSQLForExecutionError( + ctx, chatModel, sysPrompt, tableName, question, execErr.Error(), engineName, fieldMap, + ) + if repairErr != nil { + common.Warn("SQL retrieval: repair failed", zap.Error(repairErr)) + return nil, nil + } + // Re-apply the kb filter after the LLM-driven repair. + if filtered, ok := addKBFilter(repaired, engineName, kbs); ok { + repaired = filtered + } + rows, execErr = docEngine.RunSQL(ctx, tableName, repaired, kbIDStrings(kbs), "json") + if execErr != nil { + common.Warn("SQL retrieval: repaired SQL also failed", zap.Error(execErr)) + return nil, nil + } + } + if len(rows) == 0 { + common.Debug("SQL retrieval: execution succeeded but returned 0 rows") + // Empty result set; let vector retrieval try. + return nil, nil + } + + // Step 3 (Python parity): for non-aggregate SQL, check that the + // result has source-citation columns (Flow A at + // dialog_service.py:1211-1221). If missing, call the LLM to + // rewrite the SQL with the right columns and retry once. If the + // repair doesn't yield source columns, fall through to the + // best-effort answer (matches Python's `returning best-effort + // answer` log at line 1221). + if !isAggregateSQL(sqlText) && !hasSourceColumns(rows) { + common.Debug("SQL retrieval: result missing source columns; attempting repair", + zap.String("sql", sqlText)) + expectedCol := expectedDocNameColumn(engineName) + repaired, repairErr := repairSQLForMissingColumns( + ctx, chatModel, sysPrompt, tableName, question, sqlText, expectedCol, engineName, fieldMap, + ) + if repairErr == nil && repaired != "" { + // Re-apply the kb filter after the LLM-driven repair. + if filtered, ok := addKBFilter(repaired, engineName, kbs); ok { + repaired = filtered + } + repairedRows, repairedErr := docEngine.RunSQL(ctx, tableName, repaired, kbIDStrings(kbs), "json") + if repairedErr == nil && len(repairedRows) > 0 && hasSourceColumns(repairedRows) { + common.Debug("SQL retrieval: missing-columns repair succeeded", + zap.String("sql", repaired)) + rows = repairedRows + sqlText = repaired + } else { + common.Warn("SQL retrieval: missing-columns repair did not yield source columns; using best-effort answer", + zap.String("sql", repaired)) + } + } else if repairErr != nil { + common.Warn("SQL retrieval: missing-columns repair failed; using best-effort answer", + zap.Error(repairErr)) + } + } + + // 4. Build the answer and reference from the rows. Mirrors Python's + // `return {"answer": ..., "reference": {"chunks": ..., + // "doc_aggs": ...}, "prompt": sys_prompt}` at dialog_service.py:1361 + // and 1377-1401. buildSQLReference handles all three branches: + // primary (rows have source columns), aggregate secondary fetch, + // and best-effort empty refs. + answerStr, ref := s.buildSQLReference( + ctx, docEngine, tableName, sqlText, rows, + sysPrompt, engineName, kbs, fieldMap, + ) + return map[string]interface{}{ + "answer": answerStr, + "reference": ref, + "prompt": sysPrompt, + }, nil +} + +// ragflowTableName returns the engine-specific SQL target name. +// Mirrors dialog_service.py:954-963. For Infinity with a single KB, +// validates the kb_id is a canonical UUID before interpolating +// (SQL injection guard matching _assert_valid_uuid at +// dialog_service.py:944-949). +func ragflowTableName(tenantID string, kbs []*entity.Knowledgebase, docEngine engine.DocEngine) string { + if docEngine == nil { + return "ragflow_" + tenantID + } + engineName := docEngine.GetType() + if engineName == "infinity" && len(kbs) == 1 { + if !isValidUUID(kbs[0].ID) { + common.Warn("ragflowTableName: invalid kb_id; falling back to base index", + zap.String("kb_id", kbs[0].ID)) + return "ragflow_" + tenantID + } + return fmt.Sprintf("ragflow_%s_%s", tenantID, kbs[0].ID) + } + // Elasticsearch / OpenSearch / default: single index, kb_id in WHERE. + return "ragflow_" + tenantID +} + +// isValidUUID returns true if s is a canonical UUID string (8-4-4-4-12 +// hex format). Used to validate kb_id before SQL interpolation, matching +// Python's _assert_valid_uuid (dialog_service.py:944-949). +var uuidRe = regexp.MustCompile(`^[0-9a-fA-F]{8}(-?[0-9a-fA-F]{4}){3}-?[0-9a-fA-F]{12}$`) + +func isValidUUID(s string) bool { + if s == "" { + return false + } + return uuidRe.MatchString(s) +} + +// addKBFilter injects a validated kb_id WHERE filter into sqlText for +// ES / OS / OceanBase engines. Infinity is a no-op because the table +// name already encodes the KB scope. Mirrors dialog_service.py:992-1021. +// +// Returns the (possibly modified) SQL and a boolean indicating whether +// all kb_ids passed UUID validation. When validation fails, the SQL is +// returned unchanged — the engine will likely reject the un-filtered +// query, triggering the repair path (Python's `_assert_valid_uuid` raises +// ValueError, which `get_table`'s try/except catches and routes to the +// repair flow). +// +// If the SQL already has a WHERE clause with `kb_id =`, the filter is +// not duplicated. Otherwise a fresh WHERE is appended, or `kb_id = '...' +// AND` is prepended to an existing WHERE. +func addKBFilter(sqlText, engineName string, kbs []*entity.Knowledgebase) (string, bool) { + if engineName == "infinity" || len(kbs) == 0 { + return sqlText, true + } + + // Validate all kb_ids as UUIDs. + for _, kb := range kbs { + if kb == nil || !isValidUUID(kb.ID) { + return sqlText, false + } + } + + kbIDs := kbIDStrings(kbs) + var kbFilter string + if len(kbIDs) == 1 { + kbFilter = fmt.Sprintf("kb_id = '%s'", kbIDs[0]) + } else { + parts := make([]string, len(kbIDs)) + for i, kid := range kbIDs { + parts[i] = fmt.Sprintf("kb_id = '%s'", kid) + } + kbFilter = "(" + strings.Join(parts, " OR ") + ")" + } + + lower := strings.ToLower(sqlText) + if !strings.Contains(lower, "where ") { + // No WHERE clause: append one. Honor ORDER BY if present. + if oIdx := strings.Index(lower, "order by"); oIdx >= 0 { + sqlText = sqlText[:oIdx] + " WHERE " + kbFilter + " order by " + sqlText[oIdx+len("order by"):] + } else { + sqlText += " WHERE " + kbFilter + } + } else if !strings.Contains(lower, "kb_id =") && !strings.Contains(lower, "kb_id=") { + // Has WHERE but no kb_id: insert "kb_id = '...' AND" after WHERE. + whereRe := regexp.MustCompile(`(?i)\bwhere\b `) + sqlText = whereRe.ReplaceAllString(sqlText, "where "+kbFilter+" and ") + } + return sqlText, true +} + +// generateSQL calls the chat model to produce a SQL SELECT. +// sysPrompt and userPrompt are pre-built by buildSQLPrompts and already +// carry engine-specific instructions (json_extract_string for Infinity/ +// OceanBase, direct column access for ES/OS). Thin wrapper over +// chatForSQL. +func generateSQL( + ctx context.Context, + chatModel *modelModule.ChatModel, + sysPrompt, userPrompt string, +) (string, error) { + return chatForSQL(ctx, chatModel, sysPrompt, userPrompt, "sql generation") +} + +// buildSQLPrompts returns the (system, user) prompt pair for the +// active document engine, plus an optional row-count override SQL. +// The override is non-empty only for "how many rows in the dataset/ +// table/spreadsheet/excel" questions, matching Python's +// row_count_override at dialog_service.py:1034 and 1063. +// +// engineName comes from docEngine.GetType() and is one of: +// "infinity", "oceanbase", "elasticsearch", "opensearch", or any +// other value (treated as the ES/OS default). +// +// Field names are sorted alphabetically for stable test output and +// to match the order-independent iteration of Python's dict. +func buildSQLPrompts(engineName, tableName, question string, fieldMap map[string]interface{}) (sysPrompt, userPrompt, overrideSQL string) { + names := make([]string, 0, len(fieldMap)) + for k := range fieldMap { + names = append(names, k) + } + sort.Strings(names) + + switch engineName { + case "infinity": + sysPrompt = infinitySQLSysPrompt + bullets := strings.Builder{} + for _, n := range names { + bullets.WriteString(" - " + n + "\n") + } + userPrompt = fmt.Sprintf( + infinitySQLUserPromptTemplate, + tableName, + strings.Join(names, ", "), + strings.TrimRight(bullets.String(), "\n"), + question, + ) + if isRowCountQuestion(question) { + overrideSQL = fmt.Sprintf("SELECT COUNT(*) AS rows FROM %s", tableName) + } + case "oceanbase": + sysPrompt = oceanbaseSQLSysPrompt + bullets := strings.Builder{} + for _, n := range names { + bullets.WriteString(" - " + n + "\n") + } + userPrompt = fmt.Sprintf( + oceanbaseSQLUserPromptTemplate, + tableName, + strings.Join(names, ", "), + strings.TrimRight(bullets.String(), "\n"), + question, + ) + if isRowCountQuestion(question) { + overrideSQL = fmt.Sprintf("SELECT COUNT(*) AS rows FROM %s", tableName) + } + default: + // Elasticsearch / OpenSearch / unknown — direct column access. + sysPrompt = esSQLSysPrompt + bullets := strings.Builder{} + for _, n := range names { + bullets.WriteString(fmt.Sprintf(" - %s (%v)\n", n, fieldMap[n])) + } + userPrompt = fmt.Sprintf( + esSQLUserPromptTemplate, + tableName, + strings.TrimRight(bullets.String(), "\n"), + question, + ) + } + return +} + +// isRowCountQuestion returns true when the question is asking for a +// total row count of a dataset/table. Mirrors Python's +// is_row_count_question at dialog_service.py:1023-1028. Uses word- +// boundary regex (not Contains) to match the Python implementation. +var rowCountPhraseRe = regexp.MustCompile(`(?i)\b(how many rows|number of rows|row count)\b`) +var rowCountSubjectRe = regexp.MustCompile(`(?i)\b(dataset|table|spreadsheet|excel)\b`) + +func isRowCountQuestion(q string) bool { + q = strings.TrimSpace(q) + if q == "" { + return false + } + return rowCountPhraseRe.MatchString(q) && rowCountSubjectRe.MatchString(q) +} + +// ----------------------------------------------------------------------- +// Repair helpers (Python parity: dialog_service.py:1129-1205) +// ----------------------------------------------------------------------- + +// expectedDocNameColumn returns the column name the engine uses for +// the document name in source-citation joins. "docnm" for Infinity +// (no _kwd suffix), "docnm_kwd" for OceanBase / ES / OS / default. +// Mirrors dialog_service.py:965. +func expectedDocNameColumn(engineName string) string { + if engineName == "infinity" { + return "docnm" + } + return "docnm_kwd" +} + +// hasSourceColumns reports whether the SQL result has the columns +// needed to build source citations: doc_id and (docnm OR docnm_kwd). +// Mirrors dialog_service.py:967-970. Returns false for empty rows +// (no schema to inspect). +func hasSourceColumns(rows []map[string]interface{}) bool { + if len(rows) == 0 { + return false + } + names := map[string]bool{} + for k := range rows[0] { + names[strings.ToLower(k)] = true + } + if !names["doc_id"] { + return false + } + return names["docnm_kwd"] || names["docnm"] +} + +// isAggregateSQL reports whether the SQL contains an aggregate +// function call (count, sum, avg, max, min, distinct). Mirrors +// dialog_service.py:972-974. +var aggregateFnRe = regexp.MustCompile(`(?i)\b(count|sum|avg|max|min|distinct)\s*\(`) + +func isAggregateSQL(sqlText string) bool { + return aggregateFnRe.MatchString(sqlText) +} + +// sortedFieldNames returns the field_map keys in alphabetical order. +// Used to format prompt bullets deterministically (matches Python's +// dict-iteration order on small maps, and gives stable test output). +func sortedFieldNames(fieldMap map[string]interface{}) []string { + names := make([]string, 0, len(fieldMap)) + for k := range fieldMap { + names = append(names, k) + } + sort.Strings(names) + return names +} + +// buildMissingColumnsRepairPrompt returns the engine-specific user +// prompt for the missing-source-columns repair flow. The Infinity +// and OceanBase branches share the JSON-column template; ES and +// OpenSearch share the direct-column template. expectedCol is +// "docnm" for Infinity or "docnm_kwd" for everything else. +func buildMissingColumnsRepairPrompt(engineName, tableName, question, prevSQL, expectedCol string, fieldMap map[string]interface{}) string { + isJSONEngine := engineName == "infinity" || engineName == "oceanbase" + names := sortedFieldNames(fieldMap) + bullets := strings.Builder{} + if isJSONEngine { + for _, n := range names { + bullets.WriteString(" - " + n + "\n") + } + return fmt.Sprintf( + infinityMissingColumnsRepairPromptTemplate, + tableName, + strings.TrimRight(bullets.String(), "\n"), + question, prevSQL, expectedCol, + ) + } + // ES / OS: include types in bullets + for _, n := range names { + bullets.WriteString(fmt.Sprintf(" - %s (%v)\n", n, fieldMap[n])) + } + return fmt.Sprintf( + esMissingColumnsRepairPromptTemplate, + tableName, + strings.TrimRight(bullets.String(), "\n"), + question, prevSQL, + ) +} + +// buildExecutionErrorRepairPrompt returns the engine-specific user +// prompt for the execution-error repair flow. +func buildExecutionErrorRepairPrompt(engineName, tableName, question, errMsg string, fieldMap map[string]interface{}) string { + isJSONEngine := engineName == "infinity" || engineName == "oceanbase" + names := sortedFieldNames(fieldMap) + bullets := strings.Builder{} + if isJSONEngine { + for _, n := range names { + bullets.WriteString(" - " + n + "\n") + } + return fmt.Sprintf( + infinityExecutionErrorRepairPromptTemplate, + tableName, + strings.TrimRight(bullets.String(), "\n"), + question, errMsg, + ) + } + for _, n := range names { + bullets.WriteString(fmt.Sprintf(" - %s (%v)\n", n, fieldMap[n])) + } + return fmt.Sprintf( + esExecutionErrorRepairPromptTemplate, + tableName, + strings.TrimRight(bullets.String(), "\n"), + question, errMsg, + ) +} + +// chatForSQL is the shared chat-model invocation for SQL generation +// and both repair flows. Returns the cleaned (normalized) SQL or an +// error. errPrefix is included in error messages to disambiguate +// which flow failed ("sql generation", "sql repair", etc.). +func chatForSQL( + ctx context.Context, + chatModel *modelModule.ChatModel, + sysPrompt, userPrompt, errPrefix string, +) (string, error) { + if chatModel == nil || chatModel.ModelDriver == nil { + return "", fmt.Errorf("nil chat model") + } + // Python uses 0.06 (dialog_service.py:1115) for all SQL LLM calls. + // Match it for parity — 0.0 made the LLM deterministic but produced + // SQL that diverged from the Python reference on some prompts. + tempLow := 0.06 + cfg := &modelModule.ChatConfig{ + Temperature: &tempLow, + } + modelName := "" + if chatModel.ModelName != nil { + modelName = *chatModel.ModelName + } + msgs := []modelModule.Message{ + modelModule.Message{Role: "system", Content: sysPrompt}, + modelModule.Message{Role: "user", Content: userPrompt}, + } + resp, err := chatModel.ModelDriver.ChatWithMessages( + modelName, msgs, chatModel.APIConfig, cfg, + ) + if err != nil { + return "", err + } + if resp == nil || resp.Answer == nil { + return "", fmt.Errorf("%s: empty response", errPrefix) + } + cleaned := normalizeSQL(*resp.Answer) + if cleaned == "" { + return "", fmt.Errorf("%s: empty after normalize", errPrefix) + } + return cleaned, nil +} + +// repairSQLForExecutionError calls the LLM to fix SQL that the engine +// refused to execute (syntax error, unknown column, etc.). Engine- +// specific user prompt keeps the right syntax (json_extract_string +// on Infinity, direct field access on ES). +func repairSQLForExecutionError( + ctx context.Context, + chatModel *modelModule.ChatModel, + sysPrompt, tableName, question, errMsg, engineName string, + fieldMap map[string]interface{}, +) (string, error) { + userPrompt := buildExecutionErrorRepairPrompt(engineName, tableName, question, errMsg, fieldMap) + return chatForSQL(ctx, chatModel, sysPrompt, userPrompt, "sql repair") +} + +// repairSQLForMissingColumns calls the LLM to fix SQL whose result +// set is missing the source-citation columns (doc_id, expectedCol). +// expectedCol is "docnm" for Infinity or "docnm_kwd" for everything +// else — see expectedDocNameColumn. +func repairSQLForMissingColumns( + ctx context.Context, + chatModel *modelModule.ChatModel, + sysPrompt, tableName, question, prevSQL, expectedCol, engineName string, + fieldMap map[string]interface{}, +) (string, error) { + userPrompt := buildMissingColumnsRepairPrompt(engineName, tableName, question, prevSQL, expectedCol, fieldMap) + return chatForSQL(ctx, chatModel, sysPrompt, userPrompt, "sql missing-columns repair") +} + +// normalizeSQL strips LLM artifacts from a SQL response. Mirrors the +// helper at dialog_service.py:976-990. +func normalizeSQL(s string) string { + if s == "" { + return "" + } + // Remove ... blocks. + thinkRe := regexp.MustCompile(`(?s).*?`) + s = thinkRe.ReplaceAllString(s, "") + // Also strip Chinese reasoning markers (思考...) — some models + // (notably Qwen) emit these instead of . Mirrors + // dialog_service.py:985: `re.sub(r"思考\n.*?\n", "", ...)`. + chineseThinkRe := regexp.MustCompile(`(?s)思考\n.*?\n`) + s = chineseThinkRe.ReplaceAllString(s, "") + // Strip Markdown code fences. + fenceRe := regexp.MustCompile("(?i)```(?:sql)?\\s*") + s = fenceRe.ReplaceAllString(s, "") + fenceEnd := regexp.MustCompile("```\\s*$") + s = fenceEnd.ReplaceAllString(s, "") + // Trim trailing semicolons (ES SQL parser doesn't like them) and + // outer whitespace. + s = strings.TrimSpace(s) + s = strings.TrimRight(s, ";") + return strings.TrimSpace(s) +} + +// ----------------------------------------------------------------------- +// Python parity helpers (dialog_service.py:56-59, 1238-1309, 1321-1365) +// ----------------------------------------------------------------------- + +// Redundant-space cleanup regexes. Mirrors +// common.string_utils.remove_redundant_spaces (string_utils.py:20-46). +// Pass 1: drop spaces after a "left boundary" character (parens, <, >). +// Pass 2: drop spaces before a "right boundary" character (parens, !). +var ( + redundantSpacePass1Re = regexp.MustCompile(`([^a-z0-9.,)>\x{ff08}]) +([^ ])`) // left boundary + space + non-space + redundantSpacePass2Re = regexp.MustCompile(`([^ ]) +([^a-z0-9.,(<])`) // non-space + space + right boundary +) + +// removeRedundantSpaces ports common.string_utils.remove_redundant_spaces. +// Two-pass regex cleanup; both passes use case-insensitive matching. +func removeRedundantSpaces(s string) string { + s = redundantSpacePass1Re.ReplaceAllString(s, "$1$2") + s = redundantSpacePass2Re.ReplaceAllString(s, "$1$2") + return s +} + +// ISO timestamp stripping regex. Mirrors the cleanup at +// dialog_service.py:1309. Matches `T13:24:55|` or `T13:24:55.123Z|`. +var isoTimestampCellRe = regexp.MustCompile(`T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|`) + +// stripISOTimestamps removes ISO-8601 timestamps that end a markdown +// table cell. Operates on the full joined rows string (not per-cell). +func stripISOTimestamps(rows string) string { + return isoTimestampCellRe.ReplaceAllString(rows, "|") +} + +// asAliasRe extracts the `AS alias` portion of a SQL column expression. +var asAliasRe = regexp.MustCompile(`(?i)\s+AS\s+([^\s,)]+)`) + +// parenSuffixRe strips `/...` and Chinese-parenthesized suffixes from +// display names (matches the regex in dialog_service.py:1251, 1255, 1263, +// 1269, 1279). The CJK variant `(...)` is intentionally non-greedy +// and stops at the first nested `(` or `)`. +var parenSuffixRe = regexp.MustCompile(`(/.*|([^()]+))`) + +// cleanDisplay applies the Python suffix-cleanup regex to a display name. +func cleanDisplay(s string) string { + return parenSuffixRe.ReplaceAllString(s, "") +} + +// mapColumnName translates a raw SQL column name to a human-readable +// display name using the field_map. Mirrors +// dialog_service.py:1238-1280 exactly. Algorithm: +// 1. Special case: literal "count(star)" → "COUNT(*)". +// 2. Try to extract `AS alias`; if alias is in fieldMap, return its +// cleaned display value (exact, then case-insensitive, then alias +// unchanged). +// 3. No AS: try fieldMap[colName] exact, then case-insensitive. +// 4. Still no match: bulk-replace each fieldMap key with its display +// value in the raw column name (handles bare json_extract_string +// expressions without AS). +func mapColumnName(colName string, fieldMap map[string]interface{}) string { + if strings.EqualFold(colName, "count(star)") { + return "COUNT(*)" + } + if m := asAliasRe.FindStringSubmatch(colName); len(m) >= 2 { + alias := strings.Trim(m[1], `"'`) + if disp, ok := fieldMap[alias]; ok { + return cleanDisplay(fmt.Sprintf("%v", disp)) + } + for k, v := range fieldMap { + if strings.EqualFold(k, alias) { + return cleanDisplay(fmt.Sprintf("%v", v)) + } + } + return alias + } + if disp, ok := fieldMap[colName]; ok { + return cleanDisplay(fmt.Sprintf("%v", disp)) + } + colLower := strings.ToLower(colName) + for k, v := range fieldMap { + if strings.ToLower(k) == colLower { + return cleanDisplay(fmt.Sprintf("%v", v)) + } + } + result := colName + for k, v := range fieldMap { + result = strings.ReplaceAll(result, k, fmt.Sprintf("%v", v)) + } + return cleanDisplay(result) +} + +// chunkKBIDForDoc resolves the kb_id for a citation chunk. Mirrors +// dialog_service.py:56-59. Single-kb queries use the chat's known +// kb_id; multi-kb queries read it from the row. +func chunkKBIDForDoc(rowDict map[string]interface{}, kbIDs []string, docID interface{}) string { + if len(kbIDs) == 1 { + return kbIDs[0] + } + if v, ok := rowDict["kb_id"]; ok && v != nil && v != "" { + return fmt.Sprintf("%v", v) + } + if v, ok := rowDict["kb_id_kwd"]; ok && v != nil && v != "" { + return fmt.Sprintf("%v", v) + } + return "" +} + +// cleanCellValue renders one cell value: replaces "None" with a space +// then runs the redundant-space cleanup. Mirrors dialog_service.py:1298. +func cleanCellValue(v interface{}) string { + s := fmt.Sprintf("%v", v) + s = strings.ReplaceAll(s, "None", " ") + return removeRedundantSpaces(s) +} + +// extractSourceColumnIndexes returns, for a set of SQL result rows, +// parallel slices of column indices that match `doc_id`, +// `docnm_kwd`/`docnm`, and `kb_id`/`kb_id_kwd` (case-insensitive). Also +// returns the full sorted column-name list. The Go RunSQL result is +// already keyed by column name; this helper derives a positional view +// (sorted alphabetically for stable iteration) that mirrors Python's +// `enumerate(tbl["columns"])`. +func extractSourceColumnIndexes(rows []map[string]interface{}) (docIDIdx, docNameIdx, kbIDIdx []int, columns []string) { + if len(rows) == 0 { + return + } + for k := range rows[0] { + columns = append(columns, k) + } + sort.Strings(columns) + for i, c := range columns { + switch strings.ToLower(c) { + case "doc_id": + docIDIdx = append(docIDIdx, i) + case "docnm_kwd", "docnm": + docNameIdx = append(docNameIdx, i) + case "kb_id", "kb_id_kwd": + kbIDIdx = append(kbIDIdx, i) + } + } + return +} + +// WHERE-clause extraction for the aggregate secondary fetch. +// Mirrors dialog_service.py:1321. +var whereClauseRe = regexp.MustCompile(`(?i)\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)`) + +// limitClauseRe detects whether a SQL already has a LIMIT clause. +var limitClauseRe = regexp.MustCompile(`(?i)\blimit\b`) + +// buildChunkFetchSQL extracts the WHERE clause from the original SQL +// and constructs a secondary SQL to fetch source chunks. Mirrors +// dialog_service.py:1327-1331. Returns ("", false) when no WHERE is +// present. The `multiKB` flag controls whether `kb_id` is included +// in the SELECT list (single-kb queries don't need it because the +// caller already knows the kb_id). +func buildChunkFetchSQL(originalSQL, tableName, expectedCol string, multiKB bool) (string, bool) { + m := whereClauseRe.FindStringSubmatch(originalSQL) + if len(m) < 2 { + return "", false + } + where := strings.TrimSpace(m[1]) + kbCol := "" + if multiKB { + kbCol = ", kb_id" + } + sql := fmt.Sprintf("select doc_id, %s%s from %s where %s", + expectedCol, kbCol, tableName, where) + if !limitClauseRe.MatchString(sql) { + sql += " limit 20" + } + return sql, true +} + +// toIfaceSlice converts a []map[string]interface{} to []interface{} for +// the call-site contract at async_chat.go:334, which type-asserts +// `reference["chunks"].([]map[string]interface{})`. +func toIfaceSlice(maps []map[string]interface{}) []interface{} { + out := make([]interface{}, len(maps)) + for i, m := range maps { + out[i] = m + } + return out +} + +// ----------------------------------------------------------------------- +// Aggregate secondary fetch (dialog_service.py:1311-1367) +// ----------------------------------------------------------------------- + +// fetchAggregateChunks runs the secondary "select doc_id, docnm[, kb_id] +// from where [limit 20]" query and uses the +// result to build chunks and doc_aggs. Mirrors the aggregate path in +// dialog_service.py:1311-1365. +// +// Returns (nil, nil) when the secondary fetch should be skipped or +// fails. Skips on Infinity multi-KB (RunSQL rejects), on missing WHERE +// clause, and on engine errors — all matching Python's try/except +// semantics at dialog_service.py:1333-1364. +func (s *ChatPipelineService) fetchAggregateChunks( + ctx context.Context, + docEngine engine.DocEngine, + tableName, originalSQL, expectedCol string, + kbIDs []string, +) (chunks []map[string]interface{}, docAggs []map[string]interface{}) { + multiKB := len(kbIDs) > 1 + + // Infinity's RunSQL rejects multi-KB (see infinity/sql.go:63-65). + // Python's add_kb_filter is a no-op for Infinity, so this branch is + // never exercised in Python either. Skip explicitly to avoid a + // hard error. + if multiKB && docEngine != nil && docEngine.GetType() == "infinity" { + common.Debug("SQL retrieval: skipping aggregate secondary fetch on Infinity multi-KB", + zap.Strings("kb_ids", kbIDs)) + return nil, nil + } + + chunksSQL, ok := buildChunkFetchSQL(originalSQL, tableName, expectedCol, multiKB) + if !ok { + common.Debug("SQL retrieval: aggregate secondary fetch skipped (no WHERE clause)", + zap.String("sql", originalSQL)) + return nil, nil + } + + rows, err := docEngine.RunSQL(ctx, tableName, chunksSQL, kbIDs, "json") + if err != nil { + common.Warn("SQL retrieval: aggregate secondary fetch failed", + zap.String("sql", chunksSQL), zap.Error(err)) + return nil, nil + } + if len(rows) == 0 { + return nil, nil + } + + docIDIdx, docNameIdx, kbIDIdx, columns := extractSourceColumnIndexes(rows) + if len(docIDIdx) == 0 || len(docNameIdx) == 0 { + common.Warn("SQL retrieval: aggregate secondary fetch missing source columns", + zap.Any("columns", columns)) + return nil, nil + } + + chunks = make([]map[string]interface{}, 0, len(rows)) + docAggMap := map[string]map[string]interface{}{} + for _, r := range rows { + docID := r[columns[docIDIdx[0]]] + docName := r[columns[docNameIdx[0]]] + chunk := map[string]interface{}{"doc_id": docID, "docnm_kwd": docName} + kid := chunkKBIDForDoc(r, kbIDs, docID) + if kid == "" && len(kbIDIdx) > 0 { + if v := r[columns[kbIDIdx[0]]]; v != nil && v != "" { + kid = fmt.Sprintf("%v", v) + } + } + if kid != "" { + chunk["kb_id"] = kid + } + chunks = append(chunks, chunk) + + // doc_aggs aggregation: group by doc_id, count occurrences, + // first-seen doc_name wins. + if entry, ok := docAggMap[fmt.Sprintf("%v", docID)]; ok { + entry["count"] = entry["count"].(int) + 1 + } else { + docAggMap[fmt.Sprintf("%v", docID)] = map[string]interface{}{ + "doc_name": docName, + "count": 1, + } + } + } + + docAggs = make([]map[string]interface{}, 0, len(docAggMap)) + for did, d := range docAggMap { + docAggs = append(docAggs, map[string]interface{}{ + "doc_id": did, + "doc_name": d["doc_name"], + "count": d["count"], + }) + } + common.Debug("SQL retrieval: aggregate secondary fetch produced chunks", + zap.Int("chunks", len(chunks)), + zap.Int("doc_aggs", len(docAggs))) + return chunks, docAggs +} + +// ----------------------------------------------------------------------- +// Answer + reference assembly (replaces renderSQLAnswer) +// ----------------------------------------------------------------------- + +// buildSQLReference renders the Markdown table answer and assembles +// the reference (chunks + doc_aggs) for a SQL retrieval result. Mirrors +// dialog_service.py:1282-1401. +// +// Three branches match Python: +// 1. hasSrc: rows themselves carry doc_id + docnm*. Build chunks/doc_aggs +// from the rows directly. (Python L1369-1401.) +// 2. isAggregateSQL: source columns missing. Run a secondary fetch to +// build chunks/doc_aggs; preserve the rendered table as the answer. +// (Python L1311-1367.) +// 3. Non-aggregate missing source: best-effort answer with empty refs. +// (Python L1367.) +// +// Scalar shortcut: when the result is a single-cell (1 row, 1 column), +// return the value directly without a table — matches the previous +// renderSQLAnswer behavior and the Python non-aggregate path's +// one-cell edge case. +func (s *ChatPipelineService) buildSQLReference( + ctx context.Context, + docEngine engine.DocEngine, + tableName, originalSQL string, + rows []map[string]interface{}, + sysPrompt, engineName string, + kbs []*entity.Knowledgebase, + fieldMap map[string]interface{}, +) (string, map[string]interface{}) { + if len(rows) == 0 { + return "No results.", map[string]interface{}{ + "chunks": []map[string]interface{}{}, + "doc_aggs": []interface{}{}, + "total": 0, + } + } + + // Scalar shortcut — matches the previous renderSQLAnswer behavior. + if len(rows) == 1 && len(rows[0]) == 1 { + for _, v := range rows[0] { + return cleanCellValue(v), map[string]interface{}{ + "chunks": []map[string]interface{}{}, + "doc_aggs": []interface{}{}, + "total": 1, + } + } + } + + kbIDs := kbIDStrings(kbs) + docIDIdx, docNameIdx, kbIDIdx, columns := extractSourceColumnIndexes(rows) + expectedCol := expectedDocNameColumn(engineName) + hasSrc := len(docIDIdx) > 0 && len(docNameIdx) > 0 + + // Build the set of "display column" indices (everything except + // doc_id, docnm*, kb_id*). Python uses set subtraction at + // dialog_service.py:1232. + exclude := map[int]bool{} + for _, i := range docIDIdx { + exclude[i] = true + } + for _, i := range docNameIdx { + exclude[i] = true + } + for _, i := range kbIDIdx { + exclude[i] = true + } + displayCols := make([]int, 0, len(columns)) + for i := range columns { + if !exclude[i] { + displayCols = append(displayCols, i) + } + } + + // --- Header --- + var header strings.Builder + header.WriteString("|") + for _, i := range displayCols { + header.WriteString(mapColumnName(columns[i], fieldMap)) + header.WriteString("|") + } + if hasSrc { + header.WriteString("Source|") + } + + // --- Separator (Python L1285) --- + sep := strings.Repeat("|------", len(displayCols)) + "|" + if hasSrc { + sep += "------|" + } + + // --- Body rows + ##N$$ citation markers --- + bodyRows := make([]string, 0, len(rows)) + for rowIdx, r := range rows { + var cells strings.Builder + cells.WriteString("|") + for _, i := range displayCols { + cells.WriteString(cleanCellValue(r[columns[i]])) + cells.WriteString("|") + } + if hasSrc { + cells.WriteString(fmt.Sprintf(" ##%d$$|", rowIdx)) + } + // Skip rows that are entirely empty/whitespace (Python's + // `if re.sub(r"[ |]+", "", row_str)` filter at L1303). + rowStr := cells.String() + if strings.TrimSpace(strings.ReplaceAll(strings.ReplaceAll(rowStr, "|", ""), " ", "")) != "" { + bodyRows = append(bodyRows, rowStr) + } + } + rowsJoined := stripISOTimestamps(strings.Join(bodyRows, "\n")) + + answer := strings.Join([]string{header.String(), sep, rowsJoined}, "\n") + + // --- Reference: chunks + doc_aggs --- + ref := map[string]interface{}{ + "chunks": []map[string]interface{}{}, + "doc_aggs": []interface{}{}, + "total": len(rows), + } + + if hasSrc { + // Primary path — build chunks and doc_aggs from rows. + chunks := make([]map[string]interface{}, 0, len(rows)) + docAggMap := map[string]map[string]interface{}{} + for _, r := range rows { + did := r[columns[docIDIdx[0]]] + dn := r[columns[docNameIdx[0]]] + entry := map[string]interface{}{"doc_id": did, "docnm_kwd": dn} + if kid := chunkKBIDForDoc(r, kbIDs, did); kid != "" { + entry["kb_id"] = kid + } else if len(kbIDIdx) > 0 { + if v := r[columns[kbIDIdx[0]]]; v != nil && v != "" { + entry["kb_id"] = fmt.Sprintf("%v", v) + } + } + chunks = append(chunks, entry) + + docIDKey := fmt.Sprintf("%v", did) + if e, ok := docAggMap[docIDKey]; ok { + e["count"] = e["count"].(int) + 1 + } else { + docAggMap[docIDKey] = map[string]interface{}{ + "doc_name": dn, + "count": 1, + } + } + } + docAggs := make([]map[string]interface{}, 0, len(docAggMap)) + for did, d := range docAggMap { + docAggs = append(docAggs, map[string]interface{}{ + "doc_id": did, + "doc_name": d["doc_name"], + "count": d["count"], + }) + } + ref["chunks"] = chunks + ref["doc_aggs"] = docAggs + return answer, ref + } + + // Source columns missing — try the aggregate secondary fetch. + if isAggregateSQL(originalSQL) { + chunks, docAggs := s.fetchAggregateChunks(ctx, docEngine, tableName, originalSQL, expectedCol, kbIDs) + if len(chunks) > 0 { + ref["chunks"] = chunks + ref["doc_aggs"] = docAggs + } + return answer, ref + } + + // Non-aggregate, no source columns: best-effort empty refs. + common.Debug("SQL retrieval: non-aggregate SQL missing source columns; returning best-effort answer", + zap.String("sql", originalSQL)) + return answer, ref +} + +// jsonMarshal is a small wrapper around encoding/json to keep this +// file's imports tidy. +func jsonMarshal(v interface{}) ([]byte, error) { + return json.Marshal(v) +} + +// kbIDStrings extracts the string KB IDs from a slice of Knowledgebase. +// Returns nil if no KB has a non-empty ID. Mirrors Python's `kb_ids` +// iteration in dialog_service.py:651-660. +func kbTenantIDStrings(kbs []*entity.Knowledgebase) []string { + if len(kbs) == 0 { + return nil + } + seen := make(map[string]struct{}) + out := make([]string, 0, len(kbs)) + for _, kb := range kbs { + if kb == nil { + continue + } + if kb.TenantID != "" { + if _, ok := seen[kb.TenantID]; !ok { + seen[kb.TenantID] = struct{}{} + out = append(out, kb.TenantID) + } + } + } + if len(out) == 0 { + return nil + } + return out +} + +// BuildChatConfig converts the dialog's LLM setting (with optional +// per-request overrides) into a typed ChatConfig for the LLM driver. +// Dialog values are read first; request config values win when present. +func BuildChatConfig(dialog *entity.Chat, config map[string]interface{}) *modelModule.ChatConfig { + cfg := &modelModule.ChatConfig{} + + if dialog.LLMSetting != nil { + if v, ok := dialog.LLMSetting["stream"].(bool); ok { + cfg.Stream = &v + } + if v, ok := dialog.LLMSetting["thinking"].(bool); ok { + cfg.Thinking = &v + } + if v, ok := dialog.LLMSetting["max_tokens"].(float64); ok { + i := int(v) + cfg.MaxTokens = &i + } + if v, ok := dialog.LLMSetting["temperature"].(float64); ok { + cfg.Temperature = &v + } + if v, ok := dialog.LLMSetting["top_p"].(float64); ok { + cfg.TopP = &v + } + if v, ok := dialog.LLMSetting["do_sample"].(bool); ok { + cfg.DoSample = &v + } + if v, ok := dialog.LLMSetting["stop"].([]interface{}); ok { + stops := make([]string, 0, len(v)) + for _, s := range v { + if str, ok := s.(string); ok { + stops = append(stops, str) + } + } + cfg.Stop = &stops + } + if v, ok := dialog.LLMSetting["model_class"].(string); ok { + cfg.ModelClass = &v + } + if v, ok := dialog.LLMSetting["effort"].(string); ok { + cfg.Effort = &v + } + if v, ok := dialog.LLMSetting["verbosity"].(string); ok { + cfg.Verbosity = &v + } + } + + if config != nil { + if v, ok := config["stream"].(bool); ok { + cfg.Stream = &v + } + if v, ok := config["thinking"].(bool); ok { + cfg.Thinking = &v + } + if v, ok := config["max_tokens"].(float64); ok { + i := int(v) + cfg.MaxTokens = &i + } + if v, ok := config["temperature"].(float64); ok { + cfg.Temperature = &v + } + if v, ok := config["top_p"].(float64); ok { + cfg.TopP = &v + } + if v, ok := config["do_sample"].(bool); ok { + cfg.DoSample = &v + } + if v, ok := config["stop"].([]interface{}); ok { + stops := make([]string, 0, len(v)) + for _, s := range v { + if str, ok := s.(string); ok { + stops = append(stops, str) + } + } + cfg.Stop = &stops + } + if v, ok := config["model_class"].(string); ok { + cfg.ModelClass = &v + } + if v, ok := config["effort"].(string); ok { + cfg.Effort = &v + } + if v, ok := config["verbosity"].(string); ok { + cfg.Verbosity = &v + } + } + + return cfg +} + +func kbIDStrings(kbs []*entity.Knowledgebase) []string { + if len(kbs) == 0 { + return nil + } + out := make([]string, 0, len(kbs)) + for _, kb := range kbs { + if kb == nil { + continue + } + if kb.ID != "" { + out = append(out, kb.ID) + } + } + if len(out) == 0 { + return nil + } + return out +} diff --git a/internal/service/chat_pipeline_test.go b/internal/service/chat_pipeline_test.go new file mode 100644 index 0000000000..f4f3d67526 --- /dev/null +++ b/internal/service/chat_pipeline_test.go @@ -0,0 +1,1464 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "context" + "fmt" + "reflect" + "strings" + "testing" + + "ragflow/internal/common" + "ragflow/internal/engine" + "ragflow/internal/entity" +) + +// dialForTest builds a minimal *entity.Chat suitable for the +// guard-clause tests. KBs are empty so AsyncChat goes through +// AsyncChatSolo. +func dialForTest(llmid string) *entity.Chat { + return &entity.Chat{ + ID: "chat-1", + TenantID: "tenant-1", + LLMID: llmid, + PromptConfig: map[string]interface{}{ + "system": "you are a test assistant.", + "quote": true, + "refine_multiturn": false, + "keyword": false, + "use_kg": false, + "toc_enhance": false, + }, + KBIDs: []interface{}{}, + VectorSimilarityWeight: 0.3, + } +} + +// newTimerAndPrompt builds a fresh Timer with all 6 phases recorded +// (with ~0 durations), so decorateAnswer emits the full Markdown +// block. +func newTimerAndPrompt() (*common.Timer, string) { + t := common.NewTimer() + t.Start() + for _, p := range []common.Phase{ + common.PhaseCheckLLM, + common.PhaseBindModels, + common.PhaseRetrieval, + common.PhaseGenerateAnswer, + } { + t.Enter(p) + t.Exit(p) + } + return t, "Test prompt" +} + +// --- P9 / P5 guard-clause tests on AsyncChat (P0 indirectly: input +// validation runs before any RAG pipeline) --- + +// TestAsyncChat_RejectsNonUserLastMessage covers the assertion at +// chat_pipeline.go:167. The OpenAI handler is supposed to enforce +// this, but a defense-in-depth check inside AsyncChat guards +// against misbehaving callers. +func TestAsyncChat_RejectsNonUserLastMessage(t *testing.T) { + s := &ChatPipelineService{} + messages := []map[string]interface{}{ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "last message must not be assistant"}, + } + _, err := s.AsyncChat(context.Background(), dialForTest(""), messages, false, nil) + if err == nil { + t.Fatal("expected error for non-user last message, got nil") + } + if !strings.Contains(err.Error(), "not from user") { + t.Errorf("unexpected error: %v", err) + } +} + +// TestAsyncChat_EmptyMessages covers the empty-messages case. The +// service should return an error before spawning the goroutine. +func TestAsyncChat_EmptyMessages(t *testing.T) { + s := &ChatPipelineService{} + _, err := s.AsyncChat(context.Background(), dialForTest(""), nil, false, nil) + if err == nil { + t.Fatal("expected error for empty messages, got nil") + } +} + +// --- P1 Timer + decorateAnswer tests (P0/P1/P7 surface) --- + +// TestDecorateAnswer_TimerFormatAlwaysEmitted pins the Markdown +// layout of Timer, ensuring all six phase lines plus Total appear. +func TestDecorateAnswer_TimerFormatAlwaysEmitted(t *testing.T) { + s := &ChatPipelineService{} + timer, _ := newTimerAndPrompt() + result := s.decorateAnswer( + context.Background(), + "hello world", + map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}, + "system prompt", + []string{"question"}, + 0, + timer, + nil, 0.0, false, + nil, + "", + nil, + "", + nil, + false, + ) + md := result.Prompt + for _, must := range []string{ + "## Time elapsed:", + " - Check LLM:", + " - Bind models:", + " - Retrieval:", + " - Generate answer:", + " - Total:", + "Generated tokens(approximately):", + } { + if !strings.Contains(md, must) { + t.Errorf("decorateAnswer prompt missing %q in:\n%s", must, md) + } + } +} + +// TestDecorateAnswer_ThinkMarkersPreserved covers the split +// at decorateAnswer: when the LLM emits a think block, decorateAnswer +// moves the think block to the front of the final answer. +func TestDecorateAnswer_ThinkMarkersPreserved(t *testing.T) { + s := &ChatPipelineService{} + timer, _ := newTimerAndPrompt() + result := s.decorateAnswer( + context.Background(), + "reasoningvisible answer", + map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}, + "system prompt", + []string{"q"}, + 0, + timer, + nil, 0.0, false, + nil, + "", + nil, + "", + nil, + false, + ) + if !strings.HasPrefix(result.Answer, "reasoning") { + t.Errorf("expected think block at start, got %q", result.Answer) + } + if !strings.Contains(result.Answer, "visible answer") { + t.Errorf("expected visible answer in result, got %q", result.Answer) + } +} + +// TestDecorateAnswer_InvalidKeySuffix ensures the "Invalid API key" +// append path runs. This is an LLM error-marker check; the message +// survives cleanup. +func TestDecorateAnswer_InvalidKeySuffix(t *testing.T) { + s := &ChatPipelineService{} + timer, _ := newTimerAndPrompt() + result := s.decorateAnswer( + context.Background(), + "oops: invalid api key", + map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}, + "system prompt", + []string{"q"}, + 0, + timer, + nil, 0.0, false, + nil, + "", + nil, + "", + nil, + false, + ) + if !strings.Contains(result.Answer, "Please set LLM API-Key") { + t.Errorf("expected API-key hint, got %q", result.Answer) + } +} + +// TestDecorateAnswer_LeavesCanonicalMarkers covers P0: the decorator +// passes canonical [ID:N] markers through unchanged when there are +// no chunks to cite (so insertCitations is skipped). +func TestDecorateAnswer_LeavesCanonicalMarkers(t *testing.T) { + s := &ChatPipelineService{} + timer, _ := newTimerAndPrompt() + result := s.decorateAnswer( + context.Background(), + "see [ID:12] for details", + map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}, + "system prompt", + []string{"q"}, + 0, + timer, + nil, 0.0, false, + nil, + "", + nil, + "", + nil, + false, + ) + if !strings.Contains(result.Answer, "[ID:12]") { + t.Errorf("canonical marker must survive decorateAnswer, got %q", result.Answer) + } +} + +// TestDecorateAnswer_RepairNotRunWhenNoQuote covers P0.10: when +// quote=false, the citation-repair branch is gated off and the +// answer is preserved verbatim. +func TestDecorateAnswer_RepairNotRunWhenNoQuote(t *testing.T) { + s := &ChatPipelineService{} + timer, _ := newTimerAndPrompt() + result := s.decorateAnswer( + context.Background(), + "see (ID: 12) for details", + map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}, + "system prompt", + []string{"q"}, + 0, + timer, + nil, 0.0, false, + nil, + "", + nil, + "", + nil, + false, + ) + if result.Answer != "see (ID: 12) for details" { + t.Errorf("quote=false must not repair, got %q", result.Answer) + } +} + +// TestDecorateAnswer_RepairRunsWhenQuote covers P0.10: when quote=true +// and the answer has bad citation shapes, RepairBadCitationFormats +// runs and produces canonical [ID:N] form. +func TestDecorateAnswer_RepairRunsWhenQuote(t *testing.T) { + s := &ChatPipelineService{} + timer, _ := newTimerAndPrompt() + // Repair requires at least one chunk (mirrors Python's + // `if knowledges and ...` guard). We provide one stub chunk so + // the repair block runs. + kb := map[string]interface{}{ + "chunks": []map[string]interface{}{ + map[string]interface{}{ + "chunk_id": "c1", + "content_with_weight": "hello world", + "doc_id": "d1", + }, + }, + "doc_aggs": []interface{}{}, + } + result := s.decorateAnswer( + context.Background(), + "see (ID: 12) for details", + kb, + "system prompt", + []string{"q"}, + 0, + timer, + nil, 0.0, true, // quote=true + nil, + "", + nil, + "", + nil, + true, + ) + if !strings.Contains(result.Answer, "[ID:12]") { + t.Errorf("quote=true must repair to [ID:12], got %q", result.Answer) + } +} + +// TestDecorateAnswer_PreCheckSkipsInsertCitations covers P0.11: when +// the LLM already emitted canonical [ID:N] markers, insertCitations +// is skipped (so we don't double-tag). We verify by checking that +// the final answer keeps the same marker count we sent in. +func TestDecorateAnswer_PreCheckSkipsInsertCitations(t *testing.T) { + s := &ChatPipelineService{} + timer, _ := newTimerAndPrompt() + in := "answer has [ID:3] already in it" + result := s.decorateAnswer( + context.Background(), + in, + map[string]interface{}{ + // No chunks → insertCitations path is gated off anyway, + // but the pre-check still works on the answer. + "chunks": []map[string]interface{}{}, + "doc_aggs": []interface{}{}, + }, + "system prompt", + []string{"q"}, + 0, + timer, + nil, 0.0, true, + nil, + "", + nil, + "", + nil, + false, + ) + // Marker must be preserved (idempotent re-formatting only). + if strings.Count(result.Answer, "[ID:3]") < 1 { + t.Errorf("expected [ID:3] preserved, got %q", result.Answer) + } +} + +// --- P2 helpers --- + +// TestKBIDStrings_ExtractsAndFilters pins the contract of the +// KB-id-string helper used by SQL retrieval, KG retrieval, and +// DeepResearcher. +func TestKBIDStrings_ExtractsAndFilters(t *testing.T) { + cases := []struct { + name string + in []*entity.Knowledgebase + want []string + }{ + {"nil", nil, nil}, + {"empty", []*entity.Knowledgebase{}, nil}, + {"all empty IDs", []*entity.Knowledgebase{{ID: ""}, {ID: ""}}, nil}, + {"mixed", []*entity.Knowledgebase{{ID: "kb-1"}, nil, {ID: "kb-2"}}, []string{"kb-1", "kb-2"}}, + {"all set", []*entity.Knowledgebase{{ID: "a"}, {ID: "b"}}, []string{"a", "b"}}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got := kbIDStrings(c.in) + if len(got) != len(c.want) { + t.Fatalf("kbIDStrings(%v) = %v, want %v", c.in, got, c.want) + } + for i := range got { + if got[i] != c.want[i] { + t.Errorf("kbIDStrings[%d] = %q, want %q", i, got[i], c.want[i]) + } + } + }) + } +} + +// TestLastUserQuestion covers the helper that mirrors Python's +// `questions[-1]` access for meta_data_filter. +func TestLastUserQuestion(t *testing.T) { + cases := []struct { + name string + in []map[string]interface{} + want string + }{ + {"empty", nil, ""}, + {"no user", []map[string]interface{}{{"role": "system", "content": "x"}}, ""}, + {"single user", []map[string]interface{}{{"role": "user", "content": "hello"}}, "hello"}, + {"multi-turn picks last user", []map[string]interface{}{ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "ok"}, + {"role": "user", "content": "second"}, + }, "second"}, + {"non-string content", []map[string]interface{}{ + {"role": "user", "content": map[string]interface{}{"x": 1}}, + }, ""}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := lastUserQuestion(c.in); got != c.want { + t.Errorf("lastUserQuestion(%v) = %q, want %q", c.in, got, c.want) + } + }) + } +} + +// --- P8 factory extraction test --- + +// TestFactoryFromLLMID covers the helper that pulls the provider +// segment out of a composite LLMID for P8 multimodal dispatch. +func TestFactoryFromLLMID(t *testing.T) { + cases := []struct { + in string + want string + }{ + {"", "openai"}, + {"plain-model", "openai"}, + {"qwen@local", "openai"}, // only one @ — fall back + {"Qwen3-8B@ling@SILICONFLOW", "siliconflow"}, + {"GPT-4@openai", "openai"}, + {"claude@user@anthropic", "anthropic"}, + {"gemini-1.5@vertex@GEMINI", "gemini"}, + } + for _, c := range cases { + if got := factoryFromLLMID(c.in); got != c.want { + t.Errorf("factoryFromLLMID(%q) = %q, want %q", c.in, got, c.want) + } + } +} + +// --- P6 FullQuestion helper test --- + +// TestFallbackToLatestUser pins the contract of the helper used by +// FullQuestion: when the LLM fails, we fall back to the latest user +// message content. +func TestFallbackToLatestUser(t *testing.T) { + cases := []struct { + name string + in []map[string]interface{} + want string + }{ + {"empty", nil, ""}, + {"no user", []map[string]interface{}{{"role": "system", "content": "x"}}, ""}, + {"single user", []map[string]interface{}{{"role": "user", "content": "hello"}}, "hello"}, + {"multi-turn picks last user", []map[string]interface{}{ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "ok"}, + {"role": "user", "content": "second"}, + }, "second"}, + {"non-string content", []map[string]interface{}{ + {"role": "user", "content": map[string]interface{}{"x": 1}}, + }, ""}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := fallbackToLatestUser(c.in); got != c.want { + t.Errorf("fallbackToLatestUser(%v) = %q, want %q", c.in, got, c.want) + } + }) + } +} + +// --- P0/Hydration tests --- + +// TestHydrateChunkVectors_NoChunksNoop pins the no-op behavior of +// the hydration helper on empty input. +func TestHydrateChunkVectors_NoChunksNoop(t *testing.T) { + hits, err := HydrateChunkVectors(context.Background(), + map[string]interface{}{"chunks": []interface{}{}}, + nil, nil, nil, + ) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if hits != 0 { + t.Errorf("expected 0 hits, got %d", hits) + } +} + +// TestHydrateChunkVectors_NilKbinfosNoop pins the no-op behavior of +// the hydration helper on nil kbinfos. +func TestHydrateChunkVectors_NilKbinfosNoop(t *testing.T) { + hits, err := HydrateChunkVectors(context.Background(), nil, nil, nil, nil) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if hits != 0 { + t.Errorf("expected 0 hits, got %d", hits) + } +} + +// --- AsyncChatResult zero-value test --- + +// TestAsyncChatResult_FinalFlagDefaultsFalse pins the zero-value +// behavior of AsyncChatResult. A non-final delta must not have +// Final=true; only the terminal result does. +func TestAsyncChatResult_FinalFlagDefaultsFalse(t *testing.T) { + var r AsyncChatResult + if r.Final { + t.Errorf("zero-value AsyncChatResult should not be Final") + } + if r.Answer != "" { + t.Errorf("zero-value Answer = %q, want empty", r.Answer) + } + if r.Prompt != "" { + t.Errorf("zero-value Prompt = %q, want empty", r.Prompt) + } + if r.Reference != nil { + t.Errorf("zero-value Reference = %v, want nil", r.Reference) + } +} + +// --- P5 SQL retrieval normalization --- + +// TestNormalizeSQL_StripsThinkBlocks covers the cleanup that runs on +// the LLM-generated SQL before it's handed to the engine. +func TestNormalizeSQL_StripsThinkBlocks(t *testing.T) { + cases := []struct { + name string + in string + want string + }{ + {"empty", "", ""}, + {"plain", "SELECT 1", "SELECT 1"}, + {"think block", "xSELECT 1", "SELECT 1"}, + {"code fence", "```sql\nSELECT 1\n```", "SELECT 1"}, + {"trailing semicolon", "SELECT 1;", "SELECT 1"}, + {"all of the above", "x```sql\nSELECT 1;\n```", "SELECT 1"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := normalizeSQL(c.in); got != c.want { + t.Errorf("normalizeSQL(%q) = %q, want %q", c.in, got, c.want) + } + }) + } +} + +// TestBuildSQLReference_Scalar covers the single-row scalar shortcut in +// buildSQLReference. The path mirrors the previous +// `TestRenderSQLAnswer_Scalar` but goes through the new entry point, +// which requires constructing a minimal OpenAIChatService. +func TestBuildSQLReference_Scalar(t *testing.T) { + s := &ChatPipelineService{} + ans, ref := s.buildSQLReference( + context.Background(), nil, "", "", + []map[string]interface{}{{"count": 42.0}}, + "", "", nil, nil, + ) + if ans != "42" { + t.Errorf("buildSQLReference scalar answer = %q, want %q", ans, "42") + } + // Scalar branch returns empty chunks/doc_aggs and total=1. + if chunks, _ := ref["chunks"].([]map[string]interface{}); len(chunks) != 0 { + t.Errorf("scalar branch chunks = %v, want empty", chunks) + } + if total, _ := ref["total"].(int); total != 1 { + t.Errorf("scalar branch total = %d, want 1", total) + } +} + +// TestBuildSQLReference_MultiRowTable covers the markdown-table branch +// (multi-row, multi-column) and verifies that display columns and rows +// render correctly. Mirrors the previous +// `TestRenderSQLAnswer_MultiRowTable`. +func TestBuildSQLReference_MultiRowTable(t *testing.T) { + rows := []map[string]interface{}{ + {"id": 1.0, "name": "alice"}, + {"id": 2.0, "name": "bob"}, + } + s := &ChatPipelineService{} + ans, ref := s.buildSQLReference( + context.Background(), nil, "", "select id, name from t", + rows, + "sys", "elasticsearch", nil, nil, + ) + // No source columns → empty chunks/doc_aggs. + if chunks, _ := ref["chunks"].([]map[string]interface{}); len(chunks) != 0 { + t.Errorf("non-source path chunks = %v, want empty", chunks) + } + if !strings.Contains(ans, "|id|") || !strings.Contains(ans, "|name|") { + t.Errorf("expected header row, got:\n%s", ans) + } + if !strings.Contains(ans, "|alice|") || !strings.Contains(ans, "|bob|") { + t.Errorf("expected data rows, got:\n%s", ans) + } + if !strings.Contains(ans, "|------") { + t.Errorf("expected separator row, got:\n%s", ans) + } +} + +// --- P4 _resolve_reference_metadata --- + +// TestResolveReferenceMetadata covers the prompt_config + kwargs +// resolution (matches Python's +// `resolve_reference_metadata_preferences` at +// api/utils/reference_metadata_utils.py:22-62). +func TestResolveReferenceMetadata(t *testing.T) { + s := &ChatPipelineService{} + cases := []struct { + name string + promptCfg map[string]interface{} + kwargs map[string]interface{} + wantInc bool + wantFields []string + }{ + {"all nil", nil, nil, false, nil}, + {"prompt_config only, include=false", map[string]interface{}{ + "reference_metadata": map[string]interface{}{"include": false}, + }, nil, false, nil}, + {"prompt_config only, include=true no fields", map[string]interface{}{ + "reference_metadata": map[string]interface{}{"include": true}, + }, nil, true, nil}, + {"kwargs override prompt_config", map[string]interface{}{ + "reference_metadata": map[string]interface{}{"include": true, "fields": []string{"a"}}, + }, map[string]interface{}{ + "include_metadata": false, + }, false, nil}, + {"kwargs include_metadata true", nil, map[string]interface{}{ + "include_metadata": true, + }, true, nil}, + {"kwargs metadata_fields only", nil, map[string]interface{}{ + "include_metadata": true, + "metadata_fields": []string{"author", "title"}, + }, true, []string{"author", "title"}}, + {"kwargs reference_metadata sub-dict wins", map[string]interface{}{ + "reference_metadata": map[string]interface{}{"include": true, "fields": []string{"from_config"}}, + }, map[string]interface{}{ + "reference_metadata": map[string]interface{}{"include": true, "fields": []string{"from_request"}}, + }, true, []string{"from_request"}}, + {"fields as []interface{} coerced to []string", map[string]interface{}{ + "reference_metadata": map[string]interface{}{"include": true, "fields": []interface{}{"a", "b", "c"}}, + }, nil, true, []string{"a", "b", "c"}}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + inc, fields := s.resolveReferenceMetadata(c.promptCfg, c.kwargs) + if inc != c.wantInc { + t.Errorf("include = %v, want %v", inc, c.wantInc) + } + if !reflect.DeepEqual(fields, c.wantFields) { + t.Errorf("fields = %v, want %v", fields, c.wantFields) + } + }) + } +} + +// TestDecorateAnswer_VectorStrippedFromReference covers the +// reference-construction step: chunks in the Reference map have +// their `vector` field stripped (so they don't bloat the response). +func TestDecorateAnswer_VectorStrippedFromReference(t *testing.T) { + s := &ChatPipelineService{} + timer, _ := newTimerAndPrompt() + kb := map[string]interface{}{ + "chunks": []map[string]interface{}{ + map[string]interface{}{ + "chunk_id": "c1", + "content_with_weight": "hello world", + "vector": []float64{0.1, 0.2, 0.3}, + "doc_id": "d1", + }, + }, + "doc_aggs": []interface{}{}, + } + result := s.decorateAnswer( + context.Background(), + "x", + kb, + "system prompt", + []string{"q"}, + 0, + timer, + nil, 0.0, false, + nil, + "", + nil, + "", + nil, + true, + ) + chunks, ok := result.Reference["chunks"].([]map[string]interface{}) + if !ok || len(chunks) == 0 { + t.Fatalf("Reference.chunks missing: %+v", result.Reference) + } + chunk := chunks[0] + if _, has := chunk["vector"]; has { + t.Errorf("vector field should be stripped from reference chunks, got %+v", chunk) + } +} + +// --- normalizeInternetFlag / shouldUseWebSearch parity with Python --- + +// TestNormalizeInternetFlag_PythonParity pins the three-state return of +// the Go port against every input shape _normalize_internet_flag accepts +// in dialog_service.py:108-119. The key user-visible additions vs the +// previous Go implementation are the truthy aliases "yes" / "on" / "1" +// and the explicit falsy aliases "no" / "off" / "0" / "". +func TestNormalizeInternetFlag_PythonParity(t *testing.T) { + tRue, fAlse := true, false + cases := []struct { + name string + in interface{} + want *bool // nil means "couldn't interpret" + }{ + // bool — straight through + {"bool true", true, &tRue}, + {"bool false", false, &fAlse}, + + // strings — case-insensitive, whitespace-trimmed, alias set + {"string true", "true", &tRue}, + {"string TRUE", "TRUE", &tRue}, + {"string padded true", " True ", &tRue}, + {"string yes", "yes", &tRue}, + {"string on", "on", &tRue}, + {"string 1", "1", &tRue}, + {"string false", "false", &fAlse}, + {"string FALSE", "FALSE", &fAlse}, + {"string no", "no", &fAlse}, + {"string off", "off", &fAlse}, + {"string 0", "0", &fAlse}, + {"string empty", "", &fAlse}, + {"string unknown", "maybe", nil}, + + // numerics — only 0 and 1 are valid (Python: `value in (0, 1)`) + {"int 0", 0, &fAlse}, + {"int 1", 1, &tRue}, + {"int 2", 2, nil}, + {"int64 0", int64(0), &fAlse}, + {"int64 1", int64(1), &tRue}, + {"float64 0", 0.0, &fAlse}, + {"float64 1", 1.0, &tRue}, + {"float64 1.5", 1.5, nil}, + + // other types → nil (couldn't interpret) + {"nil", nil, nil}, + {"slice", []string{"true"}, nil}, + {"map", map[string]string{"a": "b"}, nil}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := normalizeInternetFlag(tc.in) + switch { + case tc.want == nil && got == nil: + return + case tc.want == nil && got != nil: + t.Fatalf("input=%#v: want nil, got *%v", tc.in, *got) + case tc.want != nil && got == nil: + t.Fatalf("input=%#v: want *%v, got nil", tc.in, *tc.want) + case *tc.want != *got: + t.Fatalf("input=%#v: want *%v, got *%v", tc.in, *tc.want, *got) + } + }) + } +} + +// TestShouldUseWebSearch_RequiresTavilyAndTruthyInternet pins the two +// conjuncts of Python's _should_use_web_search (dialog_service.py:122-126): +// tavily_api_key must be set on prompt_config AND the internet flag must +// normalize to explicit true. +func TestShouldUseWebSearch_RequiresTavilyAndTruthyInternet(t *testing.T) { + svc := &ChatPipelineService{} + withTavily := &entity.Chat{ + PromptConfig: entity.JSONMap{"tavily_api_key": "tvly-xxx"}, + } + withoutTavily := &entity.Chat{ + PromptConfig: entity.JSONMap{}, + } + nilPromptConfig := &entity.Chat{} + + cases := []struct { + name string + dialog *entity.Chat + flag interface{} + want bool + }{ + // disqualifying gates + {"nil prompt_config", nilPromptConfig, true, false}, + {"empty tavily key", withoutTavily, true, false}, + {"tavily key + nil flag", withTavily, nil, false}, + {"tavily key + false bool", withTavily, false, false}, + {"tavily key + 'false' string", withTavily, "false", false}, + {"tavily key + unrecognized string", withTavily, "maybe", false}, + + // enabling combinations — all of these were broken before + // the normalizer fix and now work. + {"tavily key + true bool", withTavily, true, true}, + {"tavily key + 'true' string", withTavily, "true", true}, + {"tavily key + 'yes' string (was broken)", withTavily, "yes", true}, + {"tavily key + 'on' string (was broken)", withTavily, "on", true}, + {"tavily key + '1' string (was broken)", withTavily, "1", true}, + {"tavily key + 1 int", withTavily, 1, true}, + {"tavily key + 1.0 float", withTavily, 1.0, true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := svc.shouldUseWebSearch(tc.dialog, tc.flag); got != tc.want { + t.Fatalf("dialog=%+v flag=%#v: want %v, got %v", + tc.dialog.PromptConfig, tc.flag, tc.want, got) + } + }) + } +} + +// --- P5 SQL retrieval parity helpers (Python use_sql alignment) --- + +// TestRemoveRedundantSpaces mirrors common.string_utils.remove_redundant_spaces. +func TestRemoveRedundantSpaces(t *testing.T) { + cases := []struct { + name string + in string + want string + }{ + // Both passes run sequentially — pass 1 strips space after `(`, + // pass 2 strips space before `)`, so both go. + {"pass1+pass2 on ( world )", "hello ( world )", "hello (world)"}, + // Pass 2 strips space before `!`. + {"pass2: space before !", "world !", "world!"}, + // Comma is not a boundary in pass 2 (it's in the negated set + // along with `<` and `(`), so no change. + {"comma not a boundary", "a , b", "a , b"}, + {"no match", "foo bar", "foo bar"}, + {"empty", "", ""}, + {"digit not a boundary", "abc 123", "abc 123"}, + {"left paren kept (no following space)", "(abc)", "(abc)"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := removeRedundantSpaces(tc.in); got != tc.want { + t.Errorf("removeRedundantSpaces(%q) = %q, want %q", tc.in, got, tc.want) + } + }) + } +} + +// TestStripISOTimestamps verifies the dialog_service.py:1309 cleanup. +// The pattern `T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|` strips the +// timestamp + trailing pipe; the leading pipe/space is preserved (the +// function is meant to operate on the cell boundary). Python's +// `re.sub` has identical behavior. +func TestStripISOTimestamps(t *testing.T) { + cases := []struct { + name string + in string + want string + }{ + {"basic T13:24:55|", "abc T13:24:55|def", "abc |def"}, + {"with ms T13:24:55.123Z|", "abc T13:24:55.123Z|def", "abc |def"}, + {"no match", "abc|def", "abc|def"}, + {"multiple", "x T01:02:03|y T04:05:06|z", "x |y |z"}, + {"no space before T", "abcT13:24:55|def", "abc|def"}, + {"empty", "", ""}, + // Realistic markdown cell: |2024-01-15T13:24:55| → |2024-01-15| + {"realistic cell", "|2024-01-15T13:24:55|", "|2024-01-15|"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := stripISOTimestamps(tc.in); got != tc.want { + t.Errorf("stripISOTimestamps(%q) = %q, want %q", tc.in, got, tc.want) + } + }) + } +} + +// TestMapColumnName exercises the map_column_name algorithm at +// dialog_service.py:1238-1280. +func TestMapColumnName(t *testing.T) { + fieldMap := map[string]interface{}{ + "title": "Title", + "issue_date": "Issue Date (/Day/Month/Year)", + "docnm": "Document Name", + "docnm_kwd": "Document Name", + } + cases := []struct { + name string + col string + fm map[string]interface{} + want string + }{ + {"count(star) special case", "count(star)", nil, "COUNT(*)"}, + {"count(star) case-insensitive", "COUNT(STAR)", nil, "COUNT(*)"}, + {"AS alias in field_map", "json_extract_string(c, '$.title') AS title", fieldMap, "Title"}, + {"AS alias not in field_map, case-insensitive", "fn() AS TITLE", fieldMap, "Title"}, + {"AS alias unknown, return as-is", "fn() AS unknown_alias", fieldMap, "unknown_alias"}, + {"direct match", "title", fieldMap, "Title"}, + {"direct case-insensitive", "TITLE", fieldMap, "Title"}, + {"no match, bulk replace", "json_extract_string(c, '$.title')", fieldMap, "json_extract_string(c, '$.Title')"}, + // `(/.*|...)` matches "/Day/Month/Year)" and replaces with "". + // The leading `(` is left intact — this matches Python's + // `re.sub` behavior exactly. + {"paren suffix stripped", "issue_date", fieldMap, "Issue Date ("}, + {"empty field map returns alias", "fn() AS foo", map[string]interface{}{}, "foo"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + fm := tc.fm + if fm == nil && tc.name == "count(star) special case" || tc.name == "count(star) case-insensitive" { + fm = map[string]interface{}{} + } + if got := mapColumnName(tc.col, fm); got != tc.want { + t.Errorf("mapColumnName(%q) = %q, want %q", tc.col, got, tc.want) + } + }) + } +} + +// TestChunkKBIDForDoc mirrors _chunk_kb_id_for_doc at dialog_service.py:56-59. +func TestChunkKBIDForDoc(t *testing.T) { + cases := []struct { + name string + rowDict map[string]interface{} + kbIDs []string + docID interface{} + want string + }{ + { + name: "single kb returns kbIDs[0]", + rowDict: map[string]interface{}{}, + kbIDs: []string{"kb_a"}, + docID: "doc1", + want: "kb_a", + }, + { + name: "multi kb with kb_id in row", + rowDict: map[string]interface{}{"kb_id": "kb_b"}, + kbIDs: []string{"kb_a", "kb_b"}, + docID: "doc1", + want: "kb_b", + }, + { + name: "multi kb with kb_id_kwd in row (no kb_id)", + rowDict: map[string]interface{}{"kb_id_kwd": "kb_c"}, + kbIDs: []string{"kb_a", "kb_b"}, + docID: "doc1", + want: "kb_c", + }, + { + name: "multi kb with neither returns empty", + rowDict: map[string]interface{}{}, + kbIDs: []string{"kb_a", "kb_b"}, + docID: "doc1", + want: "", + }, + { + name: "multi kb with empty kb_id falls through to kb_id_kwd", + rowDict: map[string]interface{}{"kb_id": "", "kb_id_kwd": "kb_d"}, + kbIDs: []string{"kb_a", "kb_b"}, + docID: "doc1", + want: "kb_d", + }, + { + name: "no kbIDs falls through to row lookup", + rowDict: map[string]interface{}{"kb_id": "kb_a"}, + kbIDs: nil, + docID: "doc1", + want: "kb_a", + }, + { + name: "no kbIDs and no row kb_id returns empty", + rowDict: map[string]interface{}{}, + kbIDs: nil, + docID: "doc1", + want: "", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := chunkKBIDForDoc(tc.rowDict, tc.kbIDs, tc.docID); got != tc.want { + t.Errorf("chunkKBIDForDoc = %q, want %q", got, tc.want) + } + }) + } +} + +// TestCleanCellValue verifies the per-cell rendering at +// dialog_service.py:1298 (remove_redundant_spaces + replace None with space). +func TestCleanCellValue(t *testing.T) { + cases := []struct { + name string + in interface{} + want string + }{ + {"string", "hello", "hello"}, + {"float", 42.0, "42"}, + {"int", 42, "42"}, + {"None string literal", "None", " "}, + {"string with redundant space after (", "( world", "(world"}, + {"empty", "", ""}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := cleanCellValue(tc.in); got != tc.want { + t.Errorf("cleanCellValue(%v) = %q, want %q", tc.in, got, tc.want) + } + }) + } +} + +// TestExtractSourceColumnIndexes verifies stable sorted column ordering +// and case-insensitive source-column detection. +func TestExtractSourceColumnIndexes(t *testing.T) { + rows := []map[string]interface{}{ + {"DOC_ID": "d1", "docnm_kwd": "Doc1", "title": "T1", "kb_id": "k1"}, + } + docIDIdx, docNameIdx, kbIDIdx, columns := extractSourceColumnIndexes(rows) + if len(docIDIdx) != 1 { + t.Errorf("expected 1 doc_id index, got %d", len(docIDIdx)) + } + if len(docNameIdx) != 1 { + t.Errorf("expected 1 doc_name index, got %d", len(docNameIdx)) + } + if len(kbIDIdx) != 1 { + t.Errorf("expected 1 kb_id index, got %d", len(kbIDIdx)) + } + // Columns must be sorted alphabetically: DOC_ID, docnm_kwd, kb_id, title + wantCols := []string{"DOC_ID", "docnm_kwd", "kb_id", "title"} + if !reflect.DeepEqual(columns, wantCols) { + t.Errorf("columns = %v, want %v", columns, wantCols) + } + // Empty rows returns empty slices. + emptyDocID, _, _, _ := extractSourceColumnIndexes(nil) + if len(emptyDocID) != 0 { + t.Errorf("empty rows docIDIdx = %v, want empty", emptyDocID) + } +} + +// TestBuildChunkFetchSQL verifies the WHERE-clause extraction and SQL +// construction at dialog_service.py:1321-1331. +func TestBuildChunkFetchSQL(t *testing.T) { + cases := []struct { + name string + sql string + multiKB bool + wantSQL string + wantFound bool + }{ + { + name: "WHERE + GROUP BY (extracts up to GROUP BY)", + sql: "select count(*) from t where x = 1 group by y", + multiKB: false, + wantSQL: "select doc_id, docnm_kwd from t where x = 1 limit 20", + wantFound: true, + }, + { + name: "WHERE only, single KB, no limit", + sql: "select * from t where x = 1", + multiKB: false, + wantSQL: "select doc_id, docnm_kwd from t where x = 1 limit 20", + wantFound: true, + }, + { + name: "WHERE only, multi KB adds kb_id column", + sql: "select * from t where x = 1", + multiKB: true, + wantSQL: "select doc_id, docnm_kwd, kb_id from t where x = 1 limit 20", + wantFound: true, + }, + { + // Python's regex is non-greedy, so WHERE-clause extraction + // stops at the first occurrence of ORDER BY / LIMIT / GROUP BY. + // Python's subsequent SQL string is then + // "select doc_id, ... from t where {where}", which DROPS + // the order by / limit suffixes. Go matches this behavior. + name: "WHERE + ORDER BY + LIMIT 5 (suffixes dropped, no extra limit)", + sql: "select * from t where x = 1 order by y limit 5", + multiKB: false, + wantSQL: "select doc_id, docnm_kwd from t where x = 1 limit 20", + wantFound: true, + }, + { + name: "no WHERE returns not-found", + sql: "select * from t", + multiKB: false, + wantSQL: "", + wantFound: false, + }, + { + // Python's f-string emits a literal lowercase "where"; + // the original case from the input is NOT preserved. + name: "case-insensitive where (output uses lowercase where)", + sql: "select * from t WHERE x = 1", + multiKB: false, + wantSQL: "select doc_id, docnm_kwd from t where x = 1 limit 20", + wantFound: true, + }, + { + name: "Infinity expectedCol is docnm (not _kwd)", + sql: "select * from t where x = 1", + multiKB: false, + wantSQL: "select doc_id, docnm from t where x = 1 limit 20", + wantFound: true, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + expectedCol := "docnm_kwd" + if tc.name == "Infinity expectedCol is docnm (not _kwd)" { + expectedCol = "docnm" + } + gotSQL, gotFound := buildChunkFetchSQL(tc.sql, "t", expectedCol, tc.multiKB) + if gotFound != tc.wantFound { + t.Errorf("found = %v, want %v", gotFound, tc.wantFound) + } + if gotSQL != tc.wantSQL { + t.Errorf("sql = %q, want %q", gotSQL, tc.wantSQL) + } + }) + } +} + +// TestToIfaceSlice verifies the slice type conversion for the call-site +// contract at chat_pipeline.go:3846. +func TestToIfaceSlice(t *testing.T) { + in := []map[string]interface{}{ + {"a": 1}, + {"b": 2}, + } + out := toIfaceSlice(in) + if len(out) != 2 { + t.Fatalf("len = %d, want 2", len(out)) + } + if _, ok := out[0].(map[string]interface{}); !ok { + t.Errorf("element 0 type = %T, want map[string]interface{}", out[0]) + } +} + +// TestExpectedDocNameColumn verifies the engine→column name mapping. +func TestExpectedDocNameColumn(t *testing.T) { + if got := expectedDocNameColumn("infinity"); got != "docnm" { + t.Errorf("infinity = %q, want docnm", got) + } + if got := expectedDocNameColumn("oceanbase"); got != "docnm_kwd" { + t.Errorf("oceanbase = %q, want docnm_kwd", got) + } + if got := expectedDocNameColumn("elasticsearch"); got != "docnm_kwd" { + t.Errorf("elasticsearch = %q, want docnm_kwd", got) + } + if got := expectedDocNameColumn("opensearch"); got != "docnm_kwd" { + t.Errorf("opensearch = %q, want docnm_kwd", got) + } + if got := expectedDocNameColumn("unknown"); got != "docnm_kwd" { + t.Errorf("unknown = %q, want docnm_kwd", got) + } +} + +// TestIsAggregateSQL matches the regex from dialog_service.py:974. +func TestIsAggregateSQL(t *testing.T) { + cases := []struct { + sql string + want bool + }{ + {"select count(*) from t", true}, + {"select sum(x) from t", true}, + {"select avg(x) from t", true}, + {"select max(x), min(y) from t", true}, + {"select count(distinct x) from t", true}, + {"select * from t where x = 1", false}, + {"select distinct x from t", false}, // bare DISTINCT without ( ) doesn't match + {"", false}, + } + for _, tc := range cases { + t.Run(tc.sql, func(t *testing.T) { + if got := isAggregateSQL(tc.sql); got != tc.want { + t.Errorf("isAggregateSQL(%q) = %v, want %v", tc.sql, got, tc.want) + } + }) + } +} + +// sqlFakeEngine is a minimal in-memory engine.DocEngine stub for +// testing fetchAggregateChunks / buildSQLReference without a real +// engine. It embeds engine.DocEngine to satisfy the interface (the +// embedded methods will panic if accidentally called, which is the +// intended loud-failure mode). +type sqlFakeEngine struct { + engine.DocEngine + engineType string + sqlCalls *[]string + rowsBySQL map[string][]map[string]interface{} + errBySQL map[string]error + runSQL func(ctx context.Context, table, sqlText string, kbIDs []string) ([]map[string]interface{}, error) +} + +func (f *sqlFakeEngine) GetType() string { return f.engineType } +func (f *sqlFakeEngine) RunSQL(ctx context.Context, table, sqlText string, kbIDs []string, format string) ([]map[string]interface{}, error) { + if f.runSQL != nil { + return f.runSQL(ctx, table, sqlText, kbIDs) + } + if f.sqlCalls != nil { + *f.sqlCalls = append(*f.sqlCalls, sqlText) + } + if f.errBySQL != nil { + if err, ok := f.errBySQL[sqlText]; ok { + return nil, err + } + } + if f.rowsBySQL != nil { + if rows, ok := f.rowsBySQL[sqlText]; ok { + return rows, nil + } + } + return nil, nil +} + +// TestFetchAggregateChunks_SkipsInfinityMultiKB verifies the +// Infinity multi-KB short-circuit (mirrors Python's add_kb_filter +// no-op for Infinity). +func TestFetchAggregateChunks_SkipsInfinityMultiKB(t *testing.T) { + engine := &sqlFakeEngine{engineType: "infinity"} + s := &ChatPipelineService{} + chunks, docAggs := s.fetchAggregateChunks( + context.Background(), engine, "t", + "select count(*) from t where x = 1", + "docnm", []string{"kb_a", "kb_b"}, + ) + if chunks != nil || docAggs != nil { + t.Errorf("expected nil chunks/docAggs on Infinity multi-KB, got %v / %v", chunks, docAggs) + } +} + +// TestFetchAggregateChunks_SingleKBSuccess verifies the secondary fetch +// path populates chunks and doc_aggs correctly. +func TestFetchAggregateChunks_SingleKBSuccess(t *testing.T) { + chunksSQL := "select doc_id, docnm_kwd from t where x = 1 limit 20" + engine := &sqlFakeEngine{ + engineType: "elasticsearch", + rowsBySQL: map[string][]map[string]interface{}{ + chunksSQL: { + {"doc_id": "d1", "docnm_kwd": "Doc1"}, + {"doc_id": "d2", "docnm_kwd": "Doc2"}, + {"doc_id": "d1", "docnm_kwd": "Doc1"}, + }, + }, + } + s := &ChatPipelineService{} + chunks, docAggs := s.fetchAggregateChunks( + context.Background(), engine, "t", + "select count(*) from t where x = 1", + "docnm_kwd", []string{"kb_a"}, + ) + if len(chunks) != 3 { + t.Fatalf("chunks len = %d, want 3", len(chunks)) + } + if len(docAggs) != 2 { + t.Fatalf("docAggs len = %d, want 2", len(docAggs)) + } + // d1 appears twice → count=2; d2 once → count=1. + counts := map[string]int{} + for _, agg := range docAggs { + counts[agg["doc_id"].(string)] = agg["count"].(int) + } + if counts["d1"] != 2 || counts["d2"] != 1 { + t.Errorf("counts = %v, want d1=2, d2=1", counts) + } + // Single-kb: each chunk gets kb_id from the dialog's kb list. + for i, c := range chunks { + if c["kb_id"] != "kb_a" { + t.Errorf("chunks[%d].kb_id = %v, want kb_a", i, c["kb_id"]) + } + } +} + +// TestFetchAggregateChunks_NoWhereClause verifies the no-WHERE early +// return (matches Python's aggregate fallback at L1365). +func TestFetchAggregateChunks_NoWhereClause(t *testing.T) { + engine := &sqlFakeEngine{engineType: "elasticsearch"} + s := &ChatPipelineService{} + chunks, docAggs := s.fetchAggregateChunks( + context.Background(), engine, "t", + "select count(*) from t", + "docnm_kwd", []string{"kb_a"}, + ) + if chunks != nil || docAggs != nil { + t.Errorf("expected nil on no-WHERE, got %v / %v", chunks, docAggs) + } +} + +// TestFetchAggregateChunks_RunSQLError verifies graceful failure. +func TestFetchAggregateChunks_RunSQLError(t *testing.T) { + engine := &sqlFakeEngine{ + engineType: "elasticsearch", + runSQL: func(ctx context.Context, table, sqlText string, kbIDs []string) ([]map[string]interface{}, error) { + return nil, fmt.Errorf("engine boom") + }, + } + s := &ChatPipelineService{} + chunks, docAggs := s.fetchAggregateChunks( + context.Background(), engine, "t", + "select count(*) from t where x = 1", + "docnm_kwd", []string{"kb_a"}, + ) + if chunks != nil || docAggs != nil { + t.Errorf("expected nil on RunSQL error, got %v / %v", chunks, docAggs) + } +} + +// TestBuildSQLReference_EmptyRows verifies the empty-rows path. +func TestBuildSQLReference_EmptyRows(t *testing.T) { + s := &ChatPipelineService{} + ans, ref := s.buildSQLReference( + context.Background(), nil, "", "", nil, + "", "", nil, nil, + ) + if ans != "No results." { + t.Errorf("ans = %q, want %q", ans, "No results.") + } + if total, _ := ref["total"].(int); total != 0 { + t.Errorf("total = %d, want 0", total) + } +} + +// TestBuildSQLReference_NonAggregateWithSourceColumns verifies that +// chunks and doc_aggs are populated from rows when source columns +// are present. +func TestBuildSQLReference_NonAggregateWithSourceColumns(t *testing.T) { + rows := []map[string]interface{}{ + {"doc_id": "d1", "docnm_kwd": "Doc1", "title": "T1"}, + {"doc_id": "d2", "docnm_kwd": "Doc2", "title": "T2"}, + } + kbs := []*entity.Knowledgebase{{ID: "kb_a"}} + s := &ChatPipelineService{} + ans, ref := s.buildSQLReference( + context.Background(), nil, "t", "select doc_id, docnm_kwd, title from t", + rows, "", "elasticsearch", kbs, nil, + ) + if !strings.Contains(ans, "Source|") { + t.Errorf("expected Source column in answer, got:\n%s", ans) + } + if !strings.Contains(ans, "##0$$") || !strings.Contains(ans, "##1$$") { + t.Errorf("expected ##N$$ citation markers, got:\n%s", ans) + } + chunks, _ := ref["chunks"].([]map[string]interface{}) + if len(chunks) != 2 { + t.Fatalf("chunks len = %d, want 2", len(chunks)) + } + docAggs, _ := ref["doc_aggs"].([]map[string]interface{}) + if len(docAggs) != 2 { + t.Fatalf("docAggs len = %d, want 2", len(docAggs)) + } + // Each chunk must carry kb_id from the single-KB dialog. + for i, cm := range chunks { + if cm["kb_id"] != "kb_a" { + t.Errorf("chunks[%d].kb_id = %v, want kb_a", i, cm["kb_id"]) + } + } +} + +// TestBuildSQLReference_AggregateMissingSourceColumnsSecondaryFetch +// verifies that an aggregate SQL with no source columns triggers the +// secondary fetch and uses its result for chunks/doc_aggs. +// +// The test uses a multi-cell aggregate (1 row, 2 columns) to avoid the +// scalar shortcut at the top of buildSQLReference. +func TestBuildSQLReference_AggregateMissingSourceColumnsSecondaryFetch(t *testing.T) { + rows := []map[string]interface{}{ + {"count": 42.0, "label": "total"}, + } + chunksSQL := "select doc_id, docnm_kwd from t where x = 1 limit 20" + engine := &sqlFakeEngine{ + engineType: "elasticsearch", + rowsBySQL: map[string][]map[string]interface{}{ + chunksSQL: { + {"doc_id": "d1", "docnm_kwd": "Doc1"}, + }, + }, + } + kbs := []*entity.Knowledgebase{{ID: "kb_a"}} + s := &ChatPipelineService{} + ans, ref := s.buildSQLReference( + context.Background(), engine, "t", + "select count(*) from t where x = 1", + rows, "", "elasticsearch", kbs, nil, + ) + // Multi-cell aggregate → renders as a table, not a scalar. + if !strings.Contains(ans, "|42|") { + t.Errorf("ans = %q, want to contain |42|", ans) + } + chunks, _ := ref["chunks"].([]map[string]interface{}) + if len(chunks) != 1 { + t.Errorf("chunks len = %d, want 1 (from secondary fetch)", len(chunks)) + } +} + +// TestBuildSQLReference_NonAggregateMissingSourceEmptyRefs verifies +// that non-aggregate SQL without source columns returns the table but +// empty chunks/doc_aggs (Python's best-effort path at L1367). +func TestBuildSQLReference_NonAggregateMissingSourceEmptyRefs(t *testing.T) { + rows := []map[string]interface{}{ + {"title": "T1"}, + {"title": "T2"}, + } + s := &ChatPipelineService{} + ans, ref := s.buildSQLReference( + context.Background(), nil, "t", "select title from t", + rows, "", "elasticsearch", nil, nil, + ) + if !strings.Contains(ans, "T1") || !strings.Contains(ans, "T2") { + t.Errorf("expected table data in answer, got:\n%s", ans) + } + if strings.Contains(ans, "Source|") { + t.Errorf("expected no Source column, got:\n%s", ans) + } + chunks, _ := ref["chunks"].([]map[string]interface{}) + if len(chunks) != 0 { + t.Errorf("chunks = %v, want empty", chunks) + } + docAggs, _ := ref["doc_aggs"].([]interface{}) + if len(docAggs) != 0 { + t.Errorf("docAggs = %v, want empty", docAggs) + } +} + +// TestBuildSQLReference_DisplayNameTranslation verifies that column +// names are translated via the field_map. +func TestBuildSQLReference_DisplayNameTranslation(t *testing.T) { + rows := []map[string]interface{}{ + {"doc_id": "d1", "docnm_kwd": "Doc1", "title": "Hello"}, + } + fieldMap := map[string]interface{}{"title": "My Title"} + s := &ChatPipelineService{} + ans, _ := s.buildSQLReference( + context.Background(), nil, "t", "select doc_id, docnm_kwd, title from t", + rows, "", "elasticsearch", nil, fieldMap, + ) + if !strings.Contains(ans, "|My Title|") { + t.Errorf("expected translated column name, got:\n%s", ans) + } + if strings.Contains(ans, "|title|") { + t.Errorf("raw column name should not appear, got:\n%s", ans) + } +} + +// TestBuildSQLReference_ISOTimestampStripped verifies that ISO +// timestamps in cell values are stripped from the rendered table. +func TestBuildSQLReference_ISOTimestampStripped(t *testing.T) { + rows := []map[string]interface{}{ + {"doc_id": "d1", "docnm_kwd": "Doc1", "created_at": "2024-01-15T13:24:55"}, + } + s := &ChatPipelineService{} + ans, _ := s.buildSQLReference( + context.Background(), nil, "t", "select doc_id, docnm_kwd, created_at from t", + rows, "", "elasticsearch", nil, nil, + ) + if strings.Contains(ans, "T13:24:55") { + t.Errorf("expected ISO timestamp stripped, got:\n%s", ans) + } + if !strings.Contains(ans, "2024-01-15") { + t.Errorf("expected date portion preserved, got:\n%s", ans) + } +} + +// --- BuildChatConfig unit tests (moved from openai_chat_test.go) --- + +// TestBuildChatConfig_RequestOverrides pins down the merge order: +// dialog.LLMSetting is the base; request fields override. +func TestBuildChatConfig_RequestOverrides(t *testing.T) { + temp := 0.1 + dialog := &entity.Chat{ + LLMSetting: entity.JSONMap{ + "temperature": 0.5, + "top_p": 0.9, + }, + } + req := map[string]interface{}{"temperature": temp} + cfg := BuildChatConfig(dialog, req) + if cfg.Temperature == nil || *cfg.Temperature != temp { + t.Fatalf("expected request temperature %v, got %v", temp, cfg.Temperature) + } + if cfg.TopP == nil || *cfg.TopP != 0.9 { + t.Fatalf("expected dialog top_p 0.9 to be preserved, got %v", cfg.TopP) + } +} + +// TestBuildChatConfig_FromEmptyDialog verifies the merger works even when +// dialog.LLMSetting is nil. +func TestBuildChatConfig_FromEmptyDialog(t *testing.T) { + temp := 0.3 + dialog := &entity.Chat{} + req := map[string]interface{}{"temperature": temp} + cfg := BuildChatConfig(dialog, req) + if cfg.Temperature == nil || *cfg.Temperature != temp { + t.Fatalf("expected temperature %v, got %v", temp, cfg.Temperature) + } +} diff --git a/internal/service/chat_session.go b/internal/service/chat_session.go index 1c1aa36d54..046b21eeb9 100644 --- a/internal/service/chat_session.go +++ b/internal/service/chat_session.go @@ -22,73 +22,47 @@ import ( "errors" "fmt" "ragflow/internal/common" - "ragflow/internal/engine" - "ragflow/internal/service/nlp" "strings" "time" - "go.uber.org/zap" - "ragflow/internal/dao" "ragflow/internal/entity" - modelModule "ragflow/internal/entity/models" ) -type chatKnowledgebaseStore interface { - Accessible(kbID, userID string) bool - GetByIDs(ids []string) ([]*entity.Knowledgebase, error) +// Interfaces for testability — satisfied by the concrete DAO/pipeline types. + +type chatSessionStore interface { + GetByID(id string) (*entity.ChatSession, error) + Create(conv *entity.ChatSession) error + UpdateByID(id string, updates map[string]interface{}) error + DeleteByID(id string) error + ListByChatID(chatID string) ([]*entity.ChatSession, error) + GetDialogByID(chatID string) (*entity.Chat, error) + CheckDialogExists(tenantID, chatID string) (bool, error) } -type chatModelProvider interface { - GetChatModel(tenantID, compositeModelName string) (*modelModule.ChatModel, error) - GetEmbeddingModel(tenantID, compositeModelName string) (*modelModule.EmbeddingModel, error) - GetRerankModel(tenantID, compositeModelName string) (*modelModule.RerankModel, error) - GetModelConfigFromProviderInstance(tenantID string, modelType entity.ModelType, modelName string) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error) - GetTenantDefaultModelByType(tenantID string, modelType entity.ModelType) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error) +type userTenantStore interface { + GetTenantIDsByUserID(userID string) ([]string, error) } -type chatMetadataService interface { - LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64 - GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error) +type chatPipelineRunner interface { + AsyncChat(ctx context.Context, chat *entity.Chat, messages []map[string]interface{}, stream bool, kwargs map[string]interface{}) (<-chan AsyncChatResult, error) } -type chatRetrievalService interface { - Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) -} - -// ChatSessionService chat session (conversation) service +// ChatSessionService chat session (conversation) service. +// The RAG pipeline is delegated to ChatPipelineService. type ChatSessionService struct { - chatSessionDAO *dao.ChatSessionDAO - chatDAO *dao.ChatDAO - userTenantDAO *dao.UserTenantDAO - kbDAO chatKnowledgebaseStore - docEngine engine.DocEngine - modelProviderSvc chatModelProvider - metadataSvc chatMetadataService - retrievalSvc chatRetrievalService + chatSessionDAO chatSessionStore + userTenantDAO userTenantStore + pipeline chatPipelineRunner } // NewChatSessionService create chat session service func NewChatSessionService() *ChatSessionService { - docEngine := engine.Get() - return newChatSessionServiceWithRetrieval(docEngine, nlp.NewRetrievalService(docEngine, dao.NewDocumentDAO())) -} - -// NewChatSessionServiceWithRetrieval creates a chat session service with a retrieval service. -func NewChatSessionServiceWithRetrieval(retrievalSvc chatRetrievalService) *ChatSessionService { - return newChatSessionServiceWithRetrieval(engine.Get(), retrievalSvc) -} - -func newChatSessionServiceWithRetrieval(docEngine engine.DocEngine, retrievalSvc chatRetrievalService) *ChatSessionService { return &ChatSessionService{ - chatSessionDAO: dao.NewChatSessionDAO(), - chatDAO: dao.NewChatDAO(), - userTenantDAO: dao.NewUserTenantDAO(), - kbDAO: dao.NewKnowledgebaseDAO(), - docEngine: docEngine, - modelProviderSvc: NewModelProviderService(), - metadataSvc: NewMetadataService(), - retrievalSvc: retrievalSvc, + chatSessionDAO: dao.NewChatSessionDAO(), + userTenantDAO: dao.NewUserTenantDAO(), + pipeline: NewChatPipelineService(), } } @@ -185,11 +159,6 @@ func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRe return &SetChatSessionResponse{ChatSession: session}, nil } -// RemoveChatSessionRequest remove chat sessions request -type RemoveChatSessionRequest struct { - ChatSessions []string `json:"conversation_ids" binding:"required"` -} - // RemoveChatSessions removes chat sessions (hard delete) func (s *ChatSessionService) RemoveChatSessions(userID string, chatSessions []string) error { // Get user's tenants @@ -294,7 +263,7 @@ func (s *ChatSessionService) ListChatSessions(userID string, chatID string) (*Li return &ListChatSessionsResponse{Sessions: sessions}, nil } -// Completion performs chat completion with full RAG support +// Completion performs chat completion with full RAG support via ChatPipelineService. func (s *ChatSessionService) Completion(userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string) (map[string]interface{}, error) { // Validate the last message is from user if len(messages) == 0 { @@ -336,12 +305,37 @@ func (s *ChatSessionService) Completion(userID string, conversationID string, me } } - // Perform chat completion with RAG - result, err := s.asyncChat(userID, dialog, session, messages, chatModelConfig, messageID, reference, false) + // Perform chat completion via shared RAG pipeline + ctx := context.Background() + kwargs := chatModelConfig + if kwargs == nil { + kwargs = map[string]interface{}{} + } + resultChan, err := s.pipeline.AsyncChat(ctx, dialog, messages, false, kwargs) if err != nil { return nil, err } + // Collect results from the pipeline + var answer strings.Builder + var finalRef map[string]interface{} + for result := range resultChan { + if result.Answer != "" { + answer.WriteString(result.Answer) + } + if result.Reference != nil { + finalRef = result.Reference + } + } + + // Structure the answer + ans := map[string]interface{}{ + "answer": answer.String(), + "reference": finalRef, + "final": true, + } + result := s.structureAnswerWithConv(session, ans, messageID, session.ID, reference) + // Update conversation if not embedded if !isEmbedded { s.updateSessionMessages(session, sessionMessages, reference) @@ -350,7 +344,7 @@ func (s *ChatSessionService) Completion(userID string, conversationID string, me return result, nil } -// CompletionStream performs streaming chat completion with full RAG support +// CompletionStream performs streaming chat completion with full RAG support via ChatPipelineService. func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string, streamChan chan<- string) error { if ctx == nil { ctx = context.Background() @@ -402,19 +396,36 @@ func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string } } - // Perform streaming chat completion with RAG - resultChan, err := s.asyncChatStream(ctx, userID, dialog, session, messages, chatModelConfig, messageID, reference) + // Perform streaming chat via shared RAG pipeline + kwargs := chatModelConfig + if kwargs == nil { + kwargs = map[string]interface{}{} + } + resultChan, err := s.pipeline.AsyncChat(ctx, dialog, messages, true, kwargs) if err != nil { streamChan <- fmt.Sprintf("data: %s\n\n", fmt.Sprintf(`{"code": 500, "message": "%s", "data": {"answer": "**ERROR**: %s", "reference": []}}`, err.Error(), err.Error())) return err } - // Stream results + // Stream results, accumulating the answer + var fullAnswer strings.Builder for result := range resultChan { + if result.Reference != nil && len(reference) > 0 { + reference[len(reference)-1] = result.Reference + } + if result.Final { + if result.Answer != "" { + fullAnswer.Reset() + fullAnswer.WriteString(result.Answer) + } + } else if result.Answer != "" { + fullAnswer.WriteString(result.Answer) + } + ans := s.structureAnswer(session, fullAnswer.String(), messageID, session.ID, reference) data, _ := json.Marshal(map[string]interface{}{ "code": 0, "message": "", - "data": result, + "data": ans, }) streamChan <- fmt.Sprintf("data: %s\n\n", string(data)) } @@ -462,7 +473,7 @@ func (s *ChatSessionService) initializeReference(session *entity.ChatSession) [] } } filtered = append(filtered, map[string]interface{}{ - "chunks": []interface{}{}, + "chunks": []map[string]interface{}{}, "doc_aggs": []interface{}{}, }) return filtered @@ -496,863 +507,13 @@ func (s *ChatSessionService) updateSessionMessages(session *entity.ChatSession, s.chatSessionDAO.UpdateByID(session.ID, updates) } -// asyncChat performs chat with RAG support (non-streaming) -func (s *ChatSessionService) asyncChat(userID string, dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) { - // Check if we need RAG (knowledge base or tavily) - hasKB := len(dialog.KBIDs) > 0 - hasTavily := false - if dialog.PromptConfig != nil { - if tavilyKey, ok := dialog.PromptConfig["tavily_api_key"].(string); ok && tavilyKey != "" { - hasTavily = true - } - } - - if !hasKB && !hasTavily { - // Simple chat without RAG - return s.asyncChatSolo(dialog, session, messages, config, messageID, reference, stream) - } - - if hasKB { - return s.asyncChatWithRetrieval(context.Background(), userID, dialog, session, messages, config, messageID, reference, stream) - } - - common.Warn("Tavily-backed chat retrieval is not implemented in Go; falling back to solo chat", - zap.String("dialog_id", dialog.ID)) - return s.asyncChatSolo(dialog, session, messages, config, messageID, reference, stream) -} - -// asyncChatStream performs streaming chat with RAG support -func (s *ChatSessionService) asyncChatStream(ctx context.Context, userID string, dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}) (<-chan map[string]interface{}, error) { - if ctx == nil { - ctx = context.Background() - } - resultChan := make(chan map[string]interface{}) - - go func() { - defer close(resultChan) - - // Check if we need RAG - hasKB := len(dialog.KBIDs) > 0 - hasTavily := false - if dialog.PromptConfig != nil { - if tavilyKey, ok := dialog.PromptConfig["tavily_api_key"].(string); ok && tavilyKey != "" { - hasTavily = true - } - } - - if !hasKB && !hasTavily { - // Simple chat without RAG - s.asyncChatSoloStream(dialog, session, messages, config, messageID, reference, resultChan) - return - } - - if hasKB { - ragMessages, ragDialog, emptyResponse, err := s.messagesWithRetrievedKnowledge(ctx, userID, dialog, messages, reference) - if err != nil { - resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference) - return - } - if emptyResponse != nil { - resultChan <- s.structureAnswer(session, *emptyResponse, messageID, session.ID, reference) - return - } - s.asyncChatSoloStream(ragDialog, session, ragMessages, config, messageID, reference, resultChan) - return - } - - common.Warn("Tavily-backed streaming chat retrieval is not implemented in Go; falling back to solo chat", - zap.String("dialog_id", dialog.ID)) - s.asyncChatSoloStream(dialog, session, messages, config, messageID, reference, resultChan) - }() - - return resultChan, nil -} - -func (s *ChatSessionService) asyncChatWithRetrieval(ctx context.Context, userID string, dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) { - ragMessages, ragDialog, emptyResponse, err := s.messagesWithRetrievedKnowledge(ctx, userID, dialog, messages, reference) - if err != nil { - return nil, err - } - if emptyResponse != nil { - var lastRef interface{} - if len(reference) > 0 { - lastRef = reference[len(reference)-1] - } - ans := map[string]interface{}{ - "answer": *emptyResponse, - "reference": lastRef, - "final": true, - } - return s.structureAnswerWithConv(session, ans, messageID, session.ID, reference), nil - } - return s.asyncChatSolo(ragDialog, session, ragMessages, config, messageID, reference, stream) -} - -func (s *ChatSessionService) messagesWithRetrievedKnowledge(ctx context.Context, userID string, dialog *entity.Chat, messages []map[string]interface{}, reference []interface{}) ([]map[string]interface{}, *entity.Chat, *string, error) { - kbIDs := stringSliceFromJSON(dialog.KBIDs) - if len(kbIDs) == 0 { - return messages, dialog, nil, nil - } - if s.retrievalSvc == nil { - return nil, nil, nil, errors.New("retrieval service is not configured") - } - - question := latestUserQuestion(messages) - if question == "" { - return messages, dialog, nil, nil - } - - kbs, err := s.kbDAO.GetByIDs(kbIDs) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to load knowledge bases: %w", err) - } - kbs, err = s.knowledgebasesForDialog(userID, dialog, kbIDs, kbs) - if err != nil { - return nil, nil, nil, err - } - embeddingTenantID, embeddingModelName, err := validateKnowledgebaseEmbeddingModels(kbs, dialog.TenantID, resolveEmbeddingModelName) - if err != nil { - return nil, nil, nil, err - } - - embeddingModel, err := s.modelProviderSvc.GetEmbeddingModel(embeddingTenantID, embeddingModelName) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to get embedding model: %w", err) - } - rerankModel, err := s.rerankModelForDialog(dialog) - if err != nil { - return nil, nil, nil, err - } - - top := int(dialog.TopK) - pageSize := int(dialog.TopN) - if pageSize <= 0 { - pageSize = 6 - } - similarityThreshold := dialog.SimilarityThreshold - vectorSimilarityWeight := dialog.VectorSimilarityWeight - var rankFeature map[string]float64 - if s.metadataSvc != nil { - rankFeature = s.metadataSvc.LabelQuestion(question, kbs) - } - baseDocIDs := docIDsFromMessages(messages) - docIDs, err := s.filteredDocIDsForDialog(ctx, dialog, kbIDs, question, baseDocIDs) - if err != nil { - return nil, nil, nil, err - } - tenantIDs := tenantIDsFromKnowledgebases(kbs, dialog.TenantID) - - retrievalResult, err := s.retrievalSvc.Retrieval(ctx, &nlp.RetrievalRequest{ - Question: question, - TenantIDs: tenantIDs, - KbIDs: kbIDs, - DocIDs: docIDs, - Page: 1, - PageSize: pageSize, - Top: &top, - SimilarityThreshold: &similarityThreshold, - VectorSimilarityWeight: &vectorSimilarityWeight, - RankFeature: &rankFeature, - EmbeddingModel: embeddingModel, - RerankModel: rerankModel, - }) - if err != nil { - return nil, nil, nil, fmt.Errorf("retrieval search failed: %w", err) - } - if retrievalResult == nil { - retrievalResult = &nlp.RetrievalResult{} - } - - chunks := retrievalResult.Chunks - if s.docEngine != nil { - chunks = nlp.RetrievalByChildren(chunks, tenantIDs, s.docEngine, ctx) - } - setLatestReference(reference, chunks, retrievalResult.DocAggs) - knowledge := buildKnowledgeBlock(chunks) - if knowledge == "" { - return messages, dialog, emptyResponseForDialog(dialog), nil - } - if ragDialog, ok := dialogWithInjectedKnowledgePrompt(dialog, knowledge); ok { - return copyMessages(messages), ragDialog, nil, nil - } - - return injectKnowledge(messages, knowledge), dialog, nil, nil -} - -type embeddingModelNameResolver func(tenantID string, kb *entity.Knowledgebase) (string, error) - -func validateKnowledgebaseEmbeddingModels(kbs []*entity.Knowledgebase, fallbackTenantID string, resolve embeddingModelNameResolver) (string, string, error) { - if len(kbs) == 0 { - return fallbackTenantID, "", nil - } - - expected := "" - expectedKBID := "" - expectedTenantID := fallbackTenantID - for _, kb := range kbs { - if kb == nil { - return "", "", errors.New("knowledge base is nil") - } - tenantID := kb.TenantID - if tenantID == "" { - tenantID = fallbackTenantID - } - modelName, err := resolve(tenantID, kb) - if err != nil { - return "", "", err - } - modelName = strings.TrimSpace(modelName) - if modelName == "" { - return "", "", fmt.Errorf("knowledge base %s has no embedding model", kb.ID) - } - if expected == "" { - expected = modelName - expectedKBID = kb.ID - expectedTenantID = tenantID - continue - } - if modelName != expected { - return "", "", fmt.Errorf("knowledge bases must use the same embedding model: %s resolves to %q, expected %q from %s", kb.ID, modelName, expected, expectedKBID) - } - } - return expectedTenantID, expected, nil -} - -func (s *ChatSessionService) rerankModelForDialog(dialog *entity.Chat) (*modelModule.RerankModel, error) { - compositeName, err := resolveRerankModelName(dialog) - if err != nil { - return nil, err - } - if compositeName == "" { - return nil, nil - } - rerankModel, err := s.modelProviderSvc.GetRerankModel(dialog.TenantID, compositeName) - if err != nil { - return nil, fmt.Errorf("failed to get rerank model: %w", err) - } - return rerankModel, nil -} - -func (s *ChatSessionService) filteredDocIDsForDialog(ctx context.Context, dialog *entity.Chat, kbIDs []string, question string, baseDocIDs []string) ([]string, error) { - if dialog.MetaDataFilter == nil || len(*dialog.MetaDataFilter) == 0 { - return baseDocIDs, nil - } - if s.metadataSvc == nil { - return nil, errors.New("metadata service is not configured") - } - - filter := make(map[string]interface{}, len(*dialog.MetaDataFilter)) - for key, value := range *dialog.MetaDataFilter { - filter[key] = value - } - - metaData, err := s.metadataSvc.GetFlattedMetaByKBs(kbIDs) - if err != nil { - return nil, fmt.Errorf("failed to get flattened metadata for chat retrieval: %w", err) - } - - var filterChatModel *modelModule.ChatModel - method, _ := filter["method"].(string) - if method == "auto" || method == "semi_auto" { - filterChatModel, err = s.modelProviderSvc.GetChatModel(dialog.TenantID, dialog.LLMID) - if err != nil { - common.Warn("Failed to get chat model for chat metadata filter", zap.Error(err)) - } - } - - docIDs, empty := ApplyMetaDataFilter(ctx, filter, metaData, question, filterChatModel, baseDocIDs, kbIDs) - if empty { - return []string{NoMatchDocIDSentinel}, nil - } - return docIDs, nil -} - -func resolveEmbeddingModelName(tenantID string, kb *entity.Knowledgebase) (string, error) { - if kb.TenantEmbdID != nil && *kb.TenantEmbdID > 0 { - _, compositeName, err := dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *kb.TenantEmbdID) - if err != nil { - return "", fmt.Errorf("failed to get embedding model by tenant_embd_id: %w", err) - } - return compositeName, nil - } - if kb.EmbdID != "" { - if strings.Contains(kb.EmbdID, "@") { - return kb.EmbdID, nil - } - _, compositeName, err := dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantID, kb.EmbdID, entity.ModelTypeEmbedding) - if err != nil { - return "", fmt.Errorf("failed to get embedding model by embd_id: %w", err) - } - return compositeName, nil - } - - tenantLLM, err := dao.NewTenantLLMDAO().GetByTenantAndType(tenantID, entity.ModelTypeEmbedding) - if err != nil { - return "", fmt.Errorf("failed to get tenant default embedding model: %w", err) - } - if tenantLLM == nil || tenantLLM.LLMName == nil || *tenantLLM.LLMName == "" { - return "", fmt.Errorf("no default embedding model found for tenant %s", tenantID) - } - return fmt.Sprintf("%s@%s", *tenantLLM.LLMName, tenantLLM.LLMFactory), nil -} - -func resolveRerankModelName(dialog *entity.Chat) (string, error) { - if dialog.TenantRerankID != nil && *dialog.TenantRerankID > 0 { - _, compositeName, err := dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *dialog.TenantRerankID) - if err != nil { - return "", fmt.Errorf("failed to get rerank model by tenant_rerank_id: %w", err) - } - return compositeName, nil - } - if dialog.RerankID == "" { - return "", nil - } - if strings.Contains(dialog.RerankID, "@") { - return dialog.RerankID, nil - } - _, compositeName, err := dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), dialog.TenantID, dialog.RerankID, entity.ModelTypeRerank) - if err != nil { - return "", fmt.Errorf("failed to get rerank model by rerank_id: %w", err) - } - return compositeName, nil -} - -func stringSliceFromJSON(values entity.JSONSlice) []string { - result := make([]string, 0, len(values)) - seen := make(map[string]struct{}, len(values)) - for _, value := range values { - str, ok := value.(string) - if !ok || str == "" { - continue - } - if _, exists := seen[str]; exists { - continue - } - seen[str] = struct{}{} - result = append(result, str) - } - return result -} - -func tenantIDsFromKnowledgebases(kbs []*entity.Knowledgebase, fallback string) []string { - seen := make(map[string]struct{}, len(kbs)+1) - var tenantIDs []string - for _, kb := range kbs { - if kb == nil || kb.TenantID == "" { - continue - } - if _, exists := seen[kb.TenantID]; exists { - continue - } - seen[kb.TenantID] = struct{}{} - tenantIDs = append(tenantIDs, kb.TenantID) - } - if len(tenantIDs) == 0 && fallback != "" { - tenantIDs = append(tenantIDs, fallback) - } - return tenantIDs -} - -func (s *ChatSessionService) knowledgebasesForDialog(userID string, dialog *entity.Chat, kbIDs []string, loaded []*entity.Knowledgebase) ([]*entity.Knowledgebase, error) { - byID := make(map[string]*entity.Knowledgebase, len(loaded)) - for _, kb := range loaded { - if kb != nil { - byID[kb.ID] = kb - } - } - - kbs := make([]*entity.Knowledgebase, 0, len(kbIDs)) - for _, kbID := range kbIDs { - kb := byID[kbID] - if kb == nil { - return nil, fmt.Errorf("knowledge base %s not found", kbID) - } - if userID != "" && !s.kbDAO.Accessible(kbID, userID) { - return nil, fmt.Errorf("knowledge base %s is not authorized for user", kbID) - } - if userID == "" && kb.TenantID != dialog.TenantID { - return nil, fmt.Errorf("knowledge base %s is not authorized for dialog tenant", kbID) - } - kbs = append(kbs, kb) - } - if len(kbs) == 0 { - return nil, errors.New("no valid knowledge bases found") - } - return kbs, nil -} - -func docIDsFromMessages(messages []map[string]interface{}) []string { - for i := len(messages) - 1; i >= 0; i-- { - if role, _ := messages[i]["role"].(string); role != "user" { - continue - } - return stringSliceFromValue(messages[i]["doc_ids"]) - } - return nil -} - -func latestUserQuestion(messages []map[string]interface{}) string { - for i := len(messages) - 1; i >= 0; i-- { - if role, _ := messages[i]["role"].(string); role != "user" { - continue - } - return textFromMessageContent(messages[i]["content"]) - } - return "" -} - -func stringSliceFromValue(value interface{}) []string { - switch typed := value.(type) { - case nil: - return nil - case []string: - return uniqueNonEmptyStrings(typed) - case []interface{}: - values := make([]string, 0, len(typed)) - for _, item := range typed { - if str, ok := item.(string); ok { - values = append(values, str) - } - } - return uniqueNonEmptyStrings(values) - default: - return nil - } -} - -func uniqueNonEmptyStrings(values []string) []string { - result := make([]string, 0, len(values)) - seen := make(map[string]struct{}, len(values)) - for _, value := range values { - value = strings.TrimSpace(value) - if value == "" { - continue - } - if _, exists := seen[value]; exists { - continue - } - seen[value] = struct{}{} - result = append(result, value) - } - if len(result) == 0 { - return nil - } - return result -} - -func emptyResponseForDialog(dialog *entity.Chat) *string { - if dialog.PromptConfig == nil { - return nil - } - emptyResponse, ok := dialog.PromptConfig["empty_response"].(string) - if !ok || emptyResponse == "" { - return nil - } - return &emptyResponse -} - -func buildKnowledgeBlock(chunks []map[string]interface{}) string { - var builder strings.Builder - for i, chunk := range chunks { - content := chunkText(chunk) - if content == "" { - continue - } - if builder.Len() > 0 { - builder.WriteString("\n\n") - } - builder.WriteString(fmt.Sprintf("[%d]", i+1)) - if docName, ok := chunk["docnm_kwd"].(string); ok && docName != "" { - builder.WriteString(" ") - builder.WriteString(docName) - } - builder.WriteString("\n") - builder.WriteString(content) - } - return builder.String() -} - -func chunkText(chunk map[string]interface{}) string { - for _, key := range []string{"content_with_weight", "content_ltks", "content"} { - if value, ok := chunk[key].(string); ok && strings.TrimSpace(value) != "" { - return strings.TrimSpace(value) - } - } - return "" -} - -func injectKnowledge(messages []map[string]interface{}, knowledge string) []map[string]interface{} { - copied := copyMessages(messages) - if len(copied) == 0 { - return copied - } - - knowledgePrompt := fmt.Sprintf("Use the following knowledge snippets to answer the user's question. If the snippets do not contain the answer, say that the knowledge base does not provide enough information.\n\n%s", knowledge) - for i := len(copied) - 1; i >= 0; i-- { - if role, _ := copied[i]["role"].(string); role != "user" { - continue - } - copied[i]["content"] = injectKnowledgeIntoContent(copied[i]["content"], knowledgePrompt) - return copied - } - - copied = append(copied, map[string]interface{}{ - "role": "system", - "content": knowledgePrompt, - }) - return copied -} - -func injectKnowledgeIntoContent(content interface{}, knowledgePrompt string) interface{} { - switch typed := content.(type) { - case []interface{}: - injected := make([]interface{}, 0, len(typed)+1) - injected = append(injected, knowledgeTextBlock(knowledgePrompt)) - injected = append(injected, typed...) - return injected - case []map[string]interface{}: - injected := make([]interface{}, 0, len(typed)+1) - injected = append(injected, knowledgeTextBlock(knowledgePrompt)) - for _, block := range typed { - injected = append(injected, block) - } - return injected - default: - contentText := "" - if content != nil { - contentText = fmt.Sprint(content) - } - return strings.TrimSpace(knowledgePrompt + "\n\nQuestion:\n" + contentText) - } -} - -func knowledgeTextBlock(knowledgePrompt string) map[string]interface{} { - return map[string]interface{}{ - "type": "text", - "text": knowledgePrompt + "\n\nQuestion:", - } -} - -func textFromMessageContent(content interface{}) string { - switch typed := content.(type) { - case string: - return strings.TrimSpace(typed) - case []interface{}: - return strings.TrimSpace(strings.Join(textsFromContentBlocks(typed), "\n")) - case []map[string]interface{}: - blocks := make([]interface{}, 0, len(typed)) - for _, block := range typed { - blocks = append(blocks, block) - } - return strings.TrimSpace(strings.Join(textsFromContentBlocks(blocks), "\n")) - default: - if content == nil { - return "" - } - return strings.TrimSpace(fmt.Sprint(content)) - } -} - -func textsFromContentBlocks(blocks []interface{}) []string { - texts := make([]string, 0, len(blocks)) - for _, block := range blocks { - switch typed := block.(type) { - case string: - if text := strings.TrimSpace(typed); text != "" { - texts = append(texts, text) - } - case map[string]interface{}: - if text, ok := typed["text"].(string); ok && strings.TrimSpace(text) != "" { - texts = append(texts, strings.TrimSpace(text)) - } - } - } - return texts -} - -func dialogWithInjectedKnowledgePrompt(dialog *entity.Chat, knowledge string) (*entity.Chat, bool) { - if dialog.PromptConfig == nil { - return dialog, false - } - systemPrompt, ok := dialog.PromptConfig["system"].(string) - if !ok || !strings.Contains(systemPrompt, "{knowledge}") { - return dialog, false - } - - copied := cloneJSONMap(dialog.PromptConfig) - copied["system"] = strings.ReplaceAll(systemPrompt, "{knowledge}", knowledge) - dialogCopy := *dialog - dialogCopy.PromptConfig = copied - return &dialogCopy, true -} - -func cloneJSONMap(values entity.JSONMap) entity.JSONMap { - copied := make(entity.JSONMap, len(values)) - for key, value := range values { - copied[key] = value - } - return copied -} - -func copyMessages(messages []map[string]interface{}) []map[string]interface{} { - copied := make([]map[string]interface{}, len(messages)) - for i, msg := range messages { - copied[i] = make(map[string]interface{}, len(msg)) - for key, value := range msg { - copied[i][key] = value - } - } - return copied -} - -func setLatestReference(reference []interface{}, chunks []map[string]interface{}, docAggs []map[string]interface{}) { - ref := map[string]interface{}{ - "chunks": chunksForReference(chunks), - "doc_aggs": mapsForReference(docAggs), - } - if len(reference) == 0 { - return - } - reference[len(reference)-1] = ref -} - -func chunksForReference(chunks []map[string]interface{}) []interface{} { - result := make([]interface{}, 0, len(chunks)) - for _, chunk := range chunks { - copied := make(map[string]interface{}, len(chunk)) - for key, value := range chunk { - if key == "vector" { - continue - } - copied[key] = value - } - result = append(result, copied) - } - return result -} - -func mapsForReference(values []map[string]interface{}) []interface{} { - result := make([]interface{}, 0, len(values)) - for _, value := range values { - result = append(result, value) - } - return result -} - -// asyncChatSolo performs simple chat without RAG (non-streaming) -func (s *ChatSessionService) asyncChatSolo(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) { - common.Info("asyncChatSolo started", - zap.String("tenant_id", dialog.TenantID), - zap.String("llm_id", dialog.LLMID), - zap.String("dialog_id", dialog.ID), - zap.Int("message_count", len(messages))) - - // Get system prompt - systemPrompt := s.buildSystemPrompt(dialog) - - // Process messages - handle attachments and image files - processedMessages := s.processMessages(messages, dialog) - - var ( - driver modelModule.ModelDriver - modelName string - apiConfig *modelModule.APIConfig - err error - ) - if dialog.LLMID != "" { - driver, modelName, apiConfig, _, err = s.modelProviderSvc.GetModelConfigFromProviderInstance( - dialog.TenantID, entity.ModelTypeChat, dialog.LLMID, - ) - } else { - driver, modelName, apiConfig, _, err = s.modelProviderSvc.GetTenantDefaultModelByType( - dialog.TenantID, entity.ModelTypeChat, - ) - } - if err != nil { - common.Error("asyncChatSolo failed to get chat model", err) - return nil, err - } - chatModel := modelModule.NewChatModel(driver, &modelName, apiConfig) - - // Convert messages to Message format - var msgs []modelModule.Message - if systemPrompt != "" { - msgs = append(msgs, modelModule.Message{Role: "system", Content: systemPrompt}) - } - for _, msg := range processedMessages { - role, _ := msg["role"].(string) - if role == "" || role == "system" { - continue - } - - if msg["content"] != nil { - msgs = append(msgs, modelModule.Message{Role: role, Content: msg["content"]}) - } - } - - // Get ChatConfig directly from dialog and config - chatConfig := s.buildChatConfig(dialog, config) - - // Perform chat - response, err := chatModel.ModelDriver.ChatWithMessages(*chatModel.ModelName, msgs, chatModel.APIConfig, chatConfig) - if err != nil { - common.Error("asyncChatSolo chat failed", err) - return nil, err - } - - common.Info("asyncChatSolo completed", - zap.String("tenant_id", dialog.TenantID), - zap.String("llm_id", dialog.LLMID), - zap.Int("response_length", len(*response.Answer))) - - // Structure the answer - ans := map[string]interface{}{ - "answer": *response.Answer, - "reference": reference[len(reference)-1], - "final": true, - } - - return s.structureAnswerWithConv(session, ans, messageID, session.ID, reference), nil -} - -// asyncChatSoloStream performs simple streaming chat without RAG -func (s *ChatSessionService) asyncChatSoloStream(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, resultChan chan<- map[string]interface{}) { - common.Info("asyncChatSoloStream started", - zap.String("tenant_id", dialog.TenantID), - zap.String("llm_id", dialog.LLMID), - zap.String("dialog_id", dialog.ID), - zap.Int("message_count", len(messages))) - - // Get system prompt - systemPrompt := s.buildSystemPrompt(dialog) - - // Process messages - processedMessages := s.processMessages(messages, dialog) - - var ( - driver modelModule.ModelDriver - modelName string - apiConfig *modelModule.APIConfig - err error - ) - if dialog.LLMID != "" { - driver, modelName, apiConfig, _, err = s.modelProviderSvc.GetModelConfigFromProviderInstance( - dialog.TenantID, entity.ModelTypeChat, dialog.LLMID, - ) - } else { - driver, modelName, apiConfig, _, err = s.modelProviderSvc.GetTenantDefaultModelByType( - dialog.TenantID, entity.ModelTypeChat, - ) - } - if err != nil { - common.Error("asyncChatSoloStream failed to get chat model", err) - resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference) - return - } - chatModel := modelModule.NewChatModel(driver, &modelName, apiConfig) - - // Convert messages to []modelModule.Message for ChatStreamlyWithSender - var chatMessages []modelModule.Message - if systemPrompt != "" { - chatMessages = append(chatMessages, modelModule.Message{ - Role: "system", - Content: systemPrompt, - }) - } - for _, msg := range processedMessages { - role, _ := msg["role"].(string) - content := msg["content"] - if role != "" && content != nil && role != "system" { - chatMessages = append(chatMessages, modelModule.Message{ - Role: role, - Content: content, - }) - } - } - - // Get ChatConfig directly from dialog and config - chatConfig := s.buildChatConfig(dialog, config) - - // Perform streaming chat using ChatStreamlyWithSender - fullAnswer := "" - err = chatModel.ModelDriver.ChatStreamlyWithSender(*chatModel.ModelName, chatMessages, chatModel.APIConfig, chatConfig, func(answer *string, reason *string) error { - if reason != nil && *reason != "" { - fullAnswer += *reason - ans := s.structureAnswer(session, fullAnswer, messageID, session.ID, reference) - resultChan <- ans - } - if answer != nil && *answer != "" { - fullAnswer += *answer - fullAnswer = s.removeReasoningContent(fullAnswer) - ans := s.structureAnswer(session, fullAnswer, messageID, session.ID, reference) - resultChan <- ans - } - return nil - }) - if err != nil { - resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference) - return - } - - common.Info("asyncChatSoloStream completed", - zap.String("tenant_id", dialog.TenantID), - zap.String("llm_id", dialog.LLMID), - zap.Int("response_length", len(fullAnswer))) -} - -// buildSystemPrompt builds the system prompt from dialog configuration -func (s *ChatSessionService) buildSystemPrompt(dialog *entity.Chat) string { - if dialog.PromptConfig == nil { - return "" - } - - system, _ := dialog.PromptConfig["system"].(string) - return system -} - -// processMessages processes messages and handles attachments -func (s *ChatSessionService) processMessages(messages []map[string]interface{}, dialog *entity.Chat) []map[string]interface{} { - // Process each message - processed := make([]map[string]interface{}, len(messages)) - for i, msg := range messages { - processed[i] = make(map[string]interface{}) - for k, v := range msg { - processed[i][k] = v - } - - // Clean content - remove file markers - if content, ok := msg["content"].(string); ok { - content = s.cleanContent(content) - processed[i]["content"] = content - } - } - - return processed -} - -// cleanContent removes file markers from content -func (s *ChatSessionService) cleanContent(content string) string { - // Remove ##N$$ markers - // This is a simplified version - full implementation would use regex - return content -} - -// removeReasoningContent removes reasoning/thinking content from answer -func (s *ChatSessionService) removeReasoningContent(answer string) string { - // Remove tags - if strings.HasSuffix(answer, "") { - answer = answer[:len(answer)-len("")] - } - return answer -} - // structureAnswerWithConv structures the answer with conversation update (like Python's structure_answer) func (s *ChatSessionService) structureAnswerWithConv(session *entity.ChatSession, ans map[string]interface{}, messageID, conversationID string, reference []interface{}) map[string]interface{} { // Extract reference from answer ref, _ := ans["reference"].(map[string]interface{}) if ref == nil { ref = map[string]interface{}{ - "chunks": []interface{}{}, + "chunks": []map[string]interface{}{}, "doc_aggs": []interface{}{}, } ans["reference"] = ref @@ -1427,105 +588,21 @@ func (s *ChatSessionService) getLastRole(messages []interface{}) string { } // chunksFormat formats chunks for reference (simplified version) -func (s *ChatSessionService) chunksFormat(reference map[string]interface{}) []interface{} { - chunks, _ := reference["chunks"].([]interface{}) - if chunks == nil { - return []interface{}{} - } - - // Format each chunk - formatted := make([]interface{}, len(chunks)) - for i, chunk := range chunks { - formatted[i] = chunk - } - return formatted -} - -// buildChatConfig builds ChatConfig directly from dialog.LLMSetting and config -func (s *ChatSessionService) buildChatConfig(dialog *entity.Chat, config map[string]interface{}) *modelModule.ChatConfig { - cfg := &modelModule.ChatConfig{} - - // Start with dialog's LLM setting - if dialog.LLMSetting != nil { - if v, ok := dialog.LLMSetting["stream"].(bool); ok { - cfg.Stream = &v - } - if v, ok := dialog.LLMSetting["thinking"].(bool); ok { - cfg.Thinking = &v - } - if v, ok := dialog.LLMSetting["max_tokens"].(float64); ok { - intVal := int(v) - cfg.MaxTokens = &intVal - } - if v, ok := dialog.LLMSetting["temperature"].(float64); ok { - cfg.Temperature = &v - } - if v, ok := dialog.LLMSetting["top_p"].(float64); ok { - cfg.TopP = &v - } - if v, ok := dialog.LLMSetting["do_sample"].(bool); ok { - cfg.DoSample = &v - } - if v, ok := dialog.LLMSetting["stop"].([]interface{}); ok { - stopStrs := make([]string, 0, len(v)) - for _, s := range v { - if str, ok := s.(string); ok { - stopStrs = append(stopStrs, str) - } +func (s *ChatSessionService) chunksFormat(reference map[string]interface{}) []map[string]interface{} { + switch c := reference["chunks"].(type) { + case []map[string]interface{}: + formatted := make([]map[string]interface{}, len(c)) + copy(formatted, c) + return formatted + case []interface{}: + formatted := make([]map[string]interface{}, 0, len(c)) + for _, item := range c { + if m, ok := item.(map[string]interface{}); ok { + formatted = append(formatted, m) } - cfg.Stop = &stopStrs - } - if v, ok := dialog.LLMSetting["model_class"].(string); ok { - cfg.ModelClass = &v - } - if v, ok := dialog.LLMSetting["effort"].(string); ok { - cfg.Effort = &v - } - if v, ok := dialog.LLMSetting["verbosity"].(string); ok { - cfg.Verbosity = &v } + return formatted + default: + return []map[string]interface{}{} } - - // Override with request config - if config != nil { - if v, ok := config["stream"].(bool); ok { - cfg.Stream = &v - } - if v, ok := config["thinking"].(bool); ok { - cfg.Thinking = &v - } - if v, ok := config["max_tokens"].(float64); ok { - intVal := int(v) - cfg.MaxTokens = &intVal - } - if v, ok := config["temperature"].(float64); ok { - cfg.Temperature = &v - } - if v, ok := config["top_p"].(float64); ok { - cfg.TopP = &v - } - if v, ok := config["do_sample"].(bool); ok { - cfg.DoSample = &v - } - if v, ok := config["stop"].([]interface{}); ok { - stopStrs := make([]string, 0, len(v)) - for _, s := range v { - if str, ok := s.(string); ok { - stopStrs = append(stopStrs, str) - } - } - cfg.Stop = &stopStrs - } - if v, ok := config["model_class"].(string); ok { - cfg.ModelClass = &v - } - if v, ok := config["effort"].(string); ok { - cfg.Effort = &v - } - if v, ok := config["verbosity"].(string); ok { - cfg.Verbosity = &v - } - } - - return cfg } diff --git a/internal/service/chat_session_test.go b/internal/service/chat_session_test.go index 98d0c25c24..c6e6d16cea 100644 --- a/internal/service/chat_session_test.go +++ b/internal/service/chat_session_test.go @@ -5,945 +5,707 @@ import ( "encoding/json" "errors" "strings" + "sync" "testing" - "ragflow/internal/common" - "ragflow/internal/engine/types" "ragflow/internal/entity" - modelModule "ragflow/internal/entity/models" - "ragflow/internal/service/nlp" ) -type fakeChatKBStore struct { - kbs []*entity.Knowledgebase - accessible map[string]bool -} +// --------------------------------------------------------------------------- +// Fake implementations +// --------------------------------------------------------------------------- -func (f fakeChatKBStore) Accessible(kbID, userID string) bool { - if f.accessible == nil { - return true +type fakeSessionStore struct { + mu sync.Mutex + sessions map[string]*entity.ChatSession + dialogs map[string]*entity.Chat + dialogExists map[string]bool // key: tenantID|chatID + getByIDErr error + createErr error + updateByIDErr error + deleteByIDErr error + getDialogErr error + // record calls + createCalled []*entity.ChatSession + updateCalled []struct { + id string + updates map[string]interface{} } - return f.accessible[kbID] + deleteByIDIDs []string } -func (f fakeChatKBStore) GetByIDs(ids []string) ([]*entity.Knowledgebase, error) { - return f.kbs, nil -} - -type fakeChatMetadataService struct{} - -func (fakeChatMetadataService) LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64 { - return map[string]float64{"pagerank_fea": 10} -} - -func (fakeChatMetadataService) GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error) { - return common.MetaData{ - "category": common.MetaValueDocs{ - "policy": []string{"doc-policy"}, - }, - }, nil -} - -type failingChatMetadataService struct{} - -func (failingChatMetadataService) LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64 { - return nil -} - -func (failingChatMetadataService) GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error) { - return nil, errors.New("metadata unavailable") -} - -type fakeChatDocEngine struct { - chunk map[string]interface{} -} - -func (f fakeChatDocEngine) CreateChunkStore(ctx context.Context, baseName, datasetID string, vectorSize int, parserID string) error { - return nil -} - -func (f fakeChatDocEngine) InsertChunks(ctx context.Context, chunks []map[string]interface{}, baseName string, datasetID string) ([]string, error) { - return nil, nil -} - -func (f fakeChatDocEngine) UpdateChunks(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, baseName string, datasetID string) error { - return nil -} - -func (f fakeChatDocEngine) DeleteChunks(ctx context.Context, condition map[string]interface{}, baseName string, datasetID string) (int64, error) { - return 0, nil -} - -func (f fakeChatDocEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) { - return nil, nil -} - -func (f fakeChatDocEngine) GetChunk(ctx context.Context, baseName, chunkID string, datasetIDs []string) (interface{}, error) { - return f.chunk, nil -} - -func (f fakeChatDocEngine) DropChunkStore(ctx context.Context, baseName, datasetID string) error { - return nil -} - -func (f fakeChatDocEngine) ChunkStoreExists(ctx context.Context, baseName, datasetID string) (bool, error) { - return true, nil -} - -func (f fakeChatDocEngine) CreateMetadataStore(ctx context.Context, tenantID string) error { - return nil -} - -func (f fakeChatDocEngine) InsertMetadata(ctx context.Context, metadata []map[string]interface{}, tenantID string) ([]string, error) { - return nil, nil -} - -func (f fakeChatDocEngine) UpdateMetadata(ctx context.Context, docID string, datasetID string, metaFields map[string]interface{}, tenantID string) error { - return nil -} - -func (f fakeChatDocEngine) DeleteMetadata(ctx context.Context, condition map[string]interface{}, tenantID string) (int64, error) { - return 0, nil -} - -func (f fakeChatDocEngine) DeleteMetadataKeys(ctx context.Context, docID string, datasetID string, keys []string, tenantID string) error { - return nil -} - -func (f fakeChatDocEngine) DropMetadataStore(ctx context.Context, tenantID string) error { - return nil -} - -func (f fakeChatDocEngine) MetadataStoreExists(ctx context.Context, tenantID string) (bool, error) { - return true, nil -} - -func (f fakeChatDocEngine) SearchMetadata(ctx context.Context, req *types.SearchMetadataRequest) (*types.SearchMetadataResult, error) { - return nil, nil -} - -func (f fakeChatDocEngine) IndexDocument(ctx context.Context, indexName, docID string, doc interface{}) error { - return nil -} - -func (f fakeChatDocEngine) DeleteDocument(ctx context.Context, indexName, docID string) error { - return nil -} - -func (f fakeChatDocEngine) BulkIndex(ctx context.Context, indexName string, docs []interface{}) (interface{}, error) { - return nil, nil -} - -func (f fakeChatDocEngine) GetFields(chunks []map[string]interface{}, fields []string) map[string]map[string]interface{} { - return nil -} - -func (f fakeChatDocEngine) GetAggregation(chunks []map[string]interface{}, fieldName string) []map[string]interface{} { - return nil -} - -func (f fakeChatDocEngine) GetHighlight(chunks []map[string]interface{}, keywords []string, fieldName string) map[string]string { - return nil -} - -func (f fakeChatDocEngine) GetChunkIDs(chunks []map[string]interface{}) []string { - return nil -} - -func (f fakeChatDocEngine) KNNScores(ctx context.Context, chunks []map[string]interface{}, queryVector []float64, topK int) (map[string]interface{}, error) { - return nil, nil -} - -func (f fakeChatDocEngine) GetScores(searchResult map[string]interface{}) map[string]float64 { - return nil -} - -func (f fakeChatDocEngine) FilterDocIdsByMetaPushdown(ctx context.Context, kbIDs []string, conditions []map[string]interface{}, logic string) []string { - return nil -} - -func (f fakeChatDocEngine) Ping(ctx context.Context) error { - return nil -} - -func (f fakeChatDocEngine) Close() error { - return nil -} - -func (f fakeChatDocEngine) GetType() string { - return "fake" -} - -type fakeChatRetrievalService struct { - req *nlp.RetrievalRequest - result *nlp.RetrievalResult -} - -func (f *fakeChatRetrievalService) Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) { - f.req = req - return f.result, nil -} - -type fakeChatModelProvider struct { - driver *fakeChatModelDriver -} - -func (f fakeChatModelProvider) GetChatModel(tenantID, compositeModelName string) (*modelModule.ChatModel, error) { - modelName := compositeModelName - return modelModule.NewChatModel(f.driver, &modelName, &modelModule.APIConfig{}), nil -} - -func (f fakeChatModelProvider) GetEmbeddingModel(tenantID, compositeModelName string) (*modelModule.EmbeddingModel, error) { - modelName := compositeModelName - return modelModule.NewEmbeddingModel(f.driver, &modelName, &modelModule.APIConfig{}, 512), nil -} - -func (f fakeChatModelProvider) GetRerankModel(tenantID, compositeModelName string) (*modelModule.RerankModel, error) { - modelName := compositeModelName - return modelModule.NewRerankModel(f.driver, &modelName, &modelModule.APIConfig{}), nil -} - -func (f fakeChatModelProvider) GetModelConfigFromProviderInstance(tenantID string, modelType entity.ModelType, modelName string) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error) { - return f.driver, modelName, &modelModule.APIConfig{}, 0, nil -} - -func (f fakeChatModelProvider) GetTenantDefaultModelByType(tenantID string, modelType entity.ModelType) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error) { - modelName := "default@factory" - return f.driver, modelName, &modelModule.APIConfig{}, 0, nil -} - -type fakeChatModelDriver struct { - messages []modelModule.Message -} - -func (f *fakeChatModelDriver) NewInstance(baseURL map[string]string) modelModule.ModelDriver { - return f -} - -func (f *fakeChatModelDriver) Name() string { - return "fake" -} - -func (f *fakeChatModelDriver) ChatWithMessages(modelName string, messages []modelModule.Message, apiConfig *modelModule.APIConfig, chatModelConfig *modelModule.ChatConfig) (*modelModule.ChatResponse, error) { - f.messages = messages - answer := "answer from knowledge" - return &modelModule.ChatResponse{Answer: &answer}, nil -} - -func (f *fakeChatModelDriver) ChatStreamlyWithSender(modelName string, messages []modelModule.Message, apiConfig *modelModule.APIConfig, modelConfig *modelModule.ChatConfig, sender func(*string, *string) error) error { - f.messages = messages - answer := "stream answer from knowledge" - return sender(&answer, nil) -} - -func (f *fakeChatModelDriver) Embed(modelName *string, texts []string, apiConfig *modelModule.APIConfig, embeddingConfig *modelModule.EmbeddingConfig) ([]modelModule.EmbeddingData, error) { - return nil, nil -} - -func (f *fakeChatModelDriver) Rerank(modelName *string, query string, documents []string, apiConfig *modelModule.APIConfig, rerankConfig *modelModule.RerankConfig) (*modelModule.RerankResponse, error) { - return nil, nil -} - -func (f *fakeChatModelDriver) TranscribeAudio(modelName *string, file *string, apiConfig *modelModule.APIConfig, asrConfig *modelModule.ASRConfig) (*modelModule.ASRResponse, error) { - return nil, nil -} - -func (f *fakeChatModelDriver) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *modelModule.APIConfig, asrConfig *modelModule.ASRConfig, sender func(*string, *string) error) error { - return nil -} - -func (f *fakeChatModelDriver) AudioSpeech(modelName *string, audioContent *string, apiConfig *modelModule.APIConfig, ttsConfig *modelModule.TTSConfig) (*modelModule.TTSResponse, error) { - return nil, nil -} - -func (f *fakeChatModelDriver) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *modelModule.APIConfig, ttsConfig *modelModule.TTSConfig, sender func(*string, *string) error) error { - return nil -} - -func (f *fakeChatModelDriver) OCRFile(modelName *string, content []byte, url *string, apiConfig *modelModule.APIConfig, ocrConfig *modelModule.OCRConfig) (*modelModule.OCRFileResponse, error) { - return nil, nil -} - -func (f *fakeChatModelDriver) ParseFile(modelName *string, content []byte, url *string, apiConfig *modelModule.APIConfig, parseFileConfig *modelModule.ParseFileConfig) (*modelModule.ParseFileResponse, error) { - return nil, nil -} - -func (f *fakeChatModelDriver) ListModels(apiConfig *modelModule.APIConfig) ([]modelModule.ListModelResponse, error) { - return nil, nil -} - -func (f *fakeChatModelDriver) Balance(apiConfig *modelModule.APIConfig) (map[string]interface{}, error) { - return nil, nil -} - -func (f *fakeChatModelDriver) CheckConnection(apiConfig *modelModule.APIConfig) error { - return nil -} - -func (f *fakeChatModelDriver) ListTasks(apiConfig *modelModule.APIConfig) ([]modelModule.ListTaskStatus, error) { - return nil, nil -} - -func (f *fakeChatModelDriver) ShowTask(taskID string, apiConfig *modelModule.APIConfig) (*modelModule.TaskResponse, error) { - return nil, nil -} - -func TestAsyncChatUsesRetrievedKnowledgeForKBDialog(t *testing.T) { - driver := &fakeChatModelDriver{} - retrieval := &fakeChatRetrievalService{ - result: &nlp.RetrievalResult{ - Chunks: []map[string]interface{}{ - { - "chunk_id": "chunk-1", - "content_with_weight": "RAGFlow stores conversation references alongside the session.", - "doc_id": "doc-1", - "docnm_kwd": "manual.md", - "vector": []float64{0.1, 0.2}, - }, - }, - DocAggs: []map[string]interface{}{ - {"doc_id": "doc-1", "doc_name": "manual.md", "count": 1}, - }, - }, - } - svc := &ChatSessionService{ - kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ - {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, - }}, - modelProviderSvc: fakeChatModelProvider{driver: driver}, - metadataSvc: fakeChatMetadataService{}, - retrievalSvc: retrieval, +func newFakeSessionStore() *fakeSessionStore { + return &fakeSessionStore{ + sessions: make(map[string]*entity.ChatSession), + dialogs: make(map[string]*entity.Chat), + dialogExists: make(map[string]bool), } +} - reference := []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}} - sessionMessage, err := json.Marshal(map[string]interface{}{"messages": []interface{}{}}) - if err != nil { - t.Fatalf("failed to marshal session message: %v", err) +func (f *fakeSessionStore) GetByID(id string) (*entity.ChatSession, error) { + if f.getByIDErr != nil { + return nil, f.getByIDErr } - session := &entity.ChatSession{ID: "session-1", Message: sessionMessage} - dialog := &entity.Chat{ - ID: "dialog-1", - TenantID: "tenant-1", - LLMID: "chat@factory", - PromptConfig: entity.JSONMap{"system": "You are helpful."}, - LLMSetting: entity.JSONMap{}, - KBIDs: entity.JSONSlice{"kb-1"}, - TopN: 3, - TopK: 32, - SimilarityThreshold: 0.2, - VectorSimilarityWeight: 0.3, - } - - result, err := svc.asyncChat("user-1", dialog, session, []map[string]interface{}{ - {"role": "user", "content": "Where are references stored?"}, - }, nil, "message-1", reference, false) - if err != nil { - t.Fatalf("asyncChat returned error: %v", err) - } - - if retrieval.req == nil { - t.Fatal("expected retrieval service to be called") - } - if retrieval.req.Question != "Where are references stored?" { - t.Fatalf("unexpected retrieval question: %q", retrieval.req.Question) - } - if retrieval.req.PageSize != 3 || retrieval.req.Top == nil || *retrieval.req.Top != 32 { - t.Fatalf("unexpected retrieval paging: page_size=%d top=%v", retrieval.req.PageSize, retrieval.req.Top) - } - if len(driver.messages) == 0 { - t.Fatal("expected chat model to receive messages") - } - last := driver.messages[len(driver.messages)-1] - content, ok := last.Content.(string) + s, ok := f.sessions[id] if !ok { - t.Fatalf("expected string content, got %T", last.Content) - } - if !strings.Contains(content, "RAGFlow stores conversation references") { - t.Fatalf("expected retrieved content in prompt, got %q", content) - } - - ref, ok := result["reference"].(map[string]interface{}) - if !ok { - t.Fatalf("expected reference map, got %T", result["reference"]) - } - chunks, ok := ref["chunks"].([]interface{}) - if !ok || len(chunks) != 1 { - t.Fatalf("expected one reference chunk, got %#v", ref["chunks"]) - } - chunk, ok := chunks[0].(map[string]interface{}) - if !ok { - t.Fatalf("expected chunk map, got %T", chunks[0]) - } - if _, exists := chunk["vector"]; exists { - t.Fatal("reference chunk should not expose vector") - } - if result["answer"] != "answer from knowledge" { - t.Fatalf("unexpected answer: %#v", result["answer"]) + return nil, errors.New("record not found") } + return s, nil } -func TestAsyncChatPropagatesRetrievalErrors(t *testing.T) { - retrievalErr := errors.New("search unavailable") - retrieval := &failingChatRetrievalService{err: retrievalErr} - svc := &ChatSessionService{ - kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ - {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, - }}, - modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, - metadataSvc: fakeChatMetadataService{}, - retrievalSvc: retrieval, - } - - _, err := svc.asyncChat("user-1", &entity.Chat{ - ID: "dialog-1", - TenantID: "tenant-1", - LLMID: "chat@factory", - PromptConfig: entity.JSONMap{}, - LLMSetting: entity.JSONMap{}, - KBIDs: entity.JSONSlice{"kb-1"}, - TopN: 3, - TopK: 32, - SimilarityThreshold: 0.2, - VectorSimilarityWeight: 0.3, - }, &entity.ChatSession{ID: "session-1"}, []map[string]interface{}{ - {"role": "user", "content": "question"}, - }, nil, "message-1", []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}, false) - if err == nil || !strings.Contains(err.Error(), "retrieval search failed") { - t.Fatalf("expected retrieval error, got %v", err) +func (f *fakeSessionStore) Create(conv *entity.ChatSession) error { + f.mu.Lock() + defer f.mu.Unlock() + if f.createErr != nil { + return f.createErr } + f.sessions[conv.ID] = conv + f.createCalled = append(f.createCalled, conv) + return nil } -func TestMessagesWithRetrievedKnowledgeFillsSystemPlaceholder(t *testing.T) { - retrieval := &fakeChatRetrievalService{ - result: &nlp.RetrievalResult{ - Chunks: []map[string]interface{}{ - {"content_with_weight": "Knowledge inserted into the system prompt."}, - }, - }, +func (f *fakeSessionStore) UpdateByID(id string, updates map[string]interface{}) error { + f.mu.Lock() + defer f.mu.Unlock() + if f.updateByIDErr != nil { + return f.updateByIDErr } + f.updateCalled = append(f.updateCalled, struct { + id string + updates map[string]interface{} + }{id, updates}) + return nil +} + +func (f *fakeSessionStore) DeleteByID(id string) error { + f.mu.Lock() + defer f.mu.Unlock() + if f.deleteByIDErr != nil { + return f.deleteByIDErr + } + f.deleteByIDIDs = append(f.deleteByIDIDs, id) + delete(f.sessions, id) + return nil +} + +func (f *fakeSessionStore) ListByChatID(chatID string) ([]*entity.ChatSession, error) { + var result []*entity.ChatSession + for _, s := range f.sessions { + if s.DialogID == chatID { + result = append(result, s) + } + } + return result, nil +} + +func (f *fakeSessionStore) GetDialogByID(chatID string) (*entity.Chat, error) { + if f.getDialogErr != nil { + return nil, f.getDialogErr + } + d, ok := f.dialogs[chatID] + if !ok { + return nil, errors.New("dialog not found") + } + return d, nil +} + +func (f *fakeSessionStore) CheckDialogExists(tenantID, chatID string) (bool, error) { + key := tenantID + "|" + chatID + return f.dialogExists[key], nil +} + +// --------------------------------------------------------------------------- + +type fakeTenantStore struct { + tenantIDs []string + err error +} + +func (f *fakeTenantStore) GetTenantIDsByUserID(userID string) ([]string, error) { + return f.tenantIDs, f.err +} + +// --------------------------------------------------------------------------- + +type fakePipeline struct { + resultChan <-chan AsyncChatResult + err error +} + +func (f *fakePipeline) AsyncChat(ctx context.Context, chat *entity.Chat, messages []map[string]interface{}, stream bool, kwargs map[string]interface{}) (<-chan AsyncChatResult, error) { + return f.resultChan, f.err +} + +func makeResultChan(results ...AsyncChatResult) <-chan AsyncChatResult { + ch := make(chan AsyncChatResult, len(results)) + for _, r := range results { + ch <- r + } + close(ch) + return ch +} + +// =================================================================== +// SetChatSession tests +// =================================================================== + +func TestSetChatSession_CreateNew(t *testing.T) { + store := newFakeSessionStore() + dialog := &entity.Chat{ID: "dialog-1", PromptConfig: entity.JSONMap{"prologue": "Welcome!"}} + store.dialogs["dialog-1"] = dialog + svc := &ChatSessionService{ - kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ - {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, - }}, - modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, - metadataSvc: fakeChatMetadataService{}, - retrievalSvc: retrieval, - } - dialog := &entity.Chat{ - ID: "dialog-1", - TenantID: "tenant-1", - PromptConfig: entity.JSONMap{"system": "Answer from this context: {knowledge}"}, - KBIDs: entity.JSONSlice{"kb-1"}, - TopN: 3, - TopK: 32, - } - messages := []map[string]interface{}{ - {"role": "user", "content": "What context is available?"}, + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, } - got, ragDialog, emptyResponse, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", dialog, messages, []interface{}{ - map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}, + resp, err := svc.SetChatSession("user-1", &SetChatSessionRequest{ + DialogID: "dialog-1", + IsNew: true, }) if err != nil { - t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) + t.Fatalf("unexpected error: %v", err) } - if emptyResponse != nil { - t.Fatalf("expected no empty response, got %q", *emptyResponse) + if resp.ID == "" { + t.Fatal("expected session ID to be generated") } - if got[0]["content"] != "What context is available?" { - t.Fatalf("expected user content to stay unchanged, got %q", got[0]["content"]) + if resp.DialogID != "dialog-1" { + t.Fatalf("expected dialog_id=dialog-1, got %s", resp.DialogID) } - originalPrompt, _ := dialog.PromptConfig["system"].(string) - if !strings.Contains(originalPrompt, "{knowledge}") { - t.Fatalf("expected original dialog prompt to remain unchanged, got %q", originalPrompt) + if len(store.createCalled) != 1 { + t.Fatalf("expected 1 Create call, got %d", len(store.createCalled)) } - systemPrompt, _ := ragDialog.PromptConfig["system"].(string) - if strings.Contains(systemPrompt, "{knowledge}") { - t.Fatalf("expected knowledge placeholder to be replaced, got %q", systemPrompt) + + // Verify prologue is in the message + var msgObj map[string]interface{} + if err := json.Unmarshal(store.createCalled[0].Message, &msgObj); err != nil { + t.Fatalf("failed to unmarshal message: %v", err) } - if !strings.Contains(systemPrompt, "Knowledge inserted into the system prompt.") { - t.Fatalf("expected retrieved knowledge in system prompt, got %q", systemPrompt) + msgs, _ := msgObj["messages"].([]interface{}) + if len(msgs) != 1 { + t.Fatalf("expected 1 initial message, got %d", len(msgs)) + } + firstMsg, _ := msgs[0].(map[string]interface{}) + if firstMsg["role"] != "assistant" || firstMsg["content"] != "Welcome!" { + t.Fatalf("unexpected prologue message: %#v", firstMsg) } } -func TestAsyncChatReturnsEmptyResponseWhenRetrievalHasNoKnowledge(t *testing.T) { - driver := &fakeChatModelDriver{} +func TestSetChatSession_CreateNewDefaultPrologue(t *testing.T) { + store := newFakeSessionStore() + store.dialogs["dialog-1"] = &entity.Chat{ID: "dialog-1"} + svc := &ChatSessionService{ - kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ - {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, - }}, - modelProviderSvc: fakeChatModelProvider{driver: driver}, - metadataSvc: fakeChatMetadataService{}, - retrievalSvc: &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}, + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, } - reference := []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}} - sessionMessage, err := json.Marshal(map[string]interface{}{"messages": []interface{}{}}) + + resp, err := svc.SetChatSession("user-1", &SetChatSessionRequest{ + DialogID: "dialog-1", + IsNew: true, + }) if err != nil { - t.Fatalf("failed to marshal session message: %v", err) + t.Fatalf("unexpected error: %v", err) } - result, err := svc.asyncChat("user-1", &entity.Chat{ - ID: "dialog-1", - TenantID: "tenant-1", - LLMID: "chat@factory", - PromptConfig: entity.JSONMap{"empty_response": "No relevant content."}, - LLMSetting: entity.JSONMap{}, - KBIDs: entity.JSONSlice{"kb-1"}, - TopN: 3, - TopK: 32, - }, &entity.ChatSession{ID: "session-1", Message: sessionMessage}, []map[string]interface{}{ - {"role": "user", "content": "question"}, - }, nil, "message-1", reference, false) + if resp.ID == "" { + t.Fatal("expected session ID") + } + // Default prologue + var msgObj map[string]interface{} + json.Unmarshal(store.createCalled[0].Message, &msgObj) + msgs, _ := msgObj["messages"].([]interface{}) + firstMsg, _ := msgs[0].(map[string]interface{}) + if !strings.Contains(firstMsg["content"].(string), "Hi! I'm your assistant") { + t.Fatalf("expected default prologue, got %q", firstMsg["content"]) + } +} + +func TestSetChatSession_CreateNewDialogNotFound(t *testing.T) { + store := newFakeSessionStore() + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, + } + + _, err := svc.SetChatSession("user-1", &SetChatSessionRequest{ + DialogID: "nonexistent", + IsNew: true, + }) + if err == nil || err.Error() != "Dialog not found" { + t.Fatalf("expected 'Dialog not found' error, got %v", err) + } +} + +func TestSetChatSession_UpdateExisting(t *testing.T) { + store := newFakeSessionStore() + store.sessions["session-1"] = &entity.ChatSession{ + ID: "session-1", DialogID: "dialog-1", Name: strPtr("old name"), + } + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, + } + + resp, err := svc.SetChatSession("user-1", &SetChatSessionRequest{ + SessionID: "session-1", + Name: "new name", + IsNew: false, + }) if err != nil { - t.Fatalf("asyncChat returned error: %v", err) + t.Fatalf("unexpected error: %v", err) } - if result["answer"] != "No relevant content." { - t.Fatalf("unexpected empty response answer: %#v", result["answer"]) + if resp.ID != "session-1" { + t.Fatalf("expected session-1, got %s", resp.ID) } - if len(driver.messages) != 0 { - t.Fatal("chat model should not be called when empty_response is returned") + if len(store.updateCalled) != 1 { + t.Fatalf("expected UpdateByID call, got %d", len(store.updateCalled)) } } -func TestMessagesWithRetrievedKnowledgeAppliesMetadataFilter(t *testing.T) { - retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} +func TestSetChatSession_UpdateNotFound(t *testing.T) { + store := newFakeSessionStore() + store.updateByIDErr = errors.New("Chat session not found") + svc := &ChatSessionService{ - kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ - {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, - }}, - modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, - metadataSvc: fakeChatMetadataService{}, - retrievalSvc: retrieval, + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, } - filter := entity.JSONMap{ - "method": "manual", - "manual": []interface{}{ - map[string]interface{}{"key": "category", "op": "=", "value": "policy"}, - }, - "logic": "and", + + _, err := svc.SetChatSession("user-1", &SetChatSessionRequest{ + SessionID: "missing", + IsNew: false, + }) + if err == nil || err.Error() != "Chat session not found" { + t.Fatalf("expected 'Chat session not found' error, got %v", err) } - _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ - ID: "dialog-1", - TenantID: "tenant-1", - PromptConfig: entity.JSONMap{}, - MetaDataFilter: &filter, - KBIDs: entity.JSONSlice{"kb-1"}, - TopN: 3, - TopK: 32, - }, []map[string]interface{}{ - {"role": "user", "content": "question"}, - }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) +} + +func TestSetChatSession_NameTruncation(t *testing.T) { + store := newFakeSessionStore() + store.dialogs["dialog-1"] = &entity.Chat{ID: "dialog-1"} + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, + } + + longName := strings.Repeat("x", 300) + resp, err := svc.SetChatSession("user-1", &SetChatSessionRequest{ + DialogID: "dialog-1", + Name: longName, + IsNew: true, + }) if err != nil { - t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) + t.Fatalf("unexpected error: %v", err) } - if retrieval.req == nil { - t.Fatal("expected retrieval to be called") - } - if len(retrieval.req.DocIDs) != 1 || retrieval.req.DocIDs[0] != "doc-policy" { - t.Fatalf("expected metadata-filtered doc id, got %#v", retrieval.req.DocIDs) + if resp.Name == nil || len(*resp.Name) > 255 { + t.Fatalf("expected name truncated to <=255, got len=%d", len(*resp.Name)) } } -func TestMessagesWithRetrievedKnowledgeIntersectsDocIDsWithMetadataFilter(t *testing.T) { - retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} +// =================================================================== +// RemoveChatSessions tests +// =================================================================== + +func TestRemoveChatSessions_Success(t *testing.T) { + store := newFakeSessionStore() + store.sessions["conv-1"] = &entity.ChatSession{ID: "conv-1", DialogID: "dialog-1"} + store.sessions["conv-2"] = &entity.ChatSession{ID: "conv-2", DialogID: "dialog-1"} + store.dialogExists["user-1|dialog-1"] = true + svc := &ChatSessionService{ - kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ - {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, - }}, - modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, - metadataSvc: fakeChatMetadataService{}, - retrievalSvc: retrieval, - } - filter := entity.JSONMap{ - "method": "manual", - "manual": []interface{}{ - map[string]interface{}{"key": "category", "op": "=", "value": "policy"}, - }, - "logic": "and", + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-1"}}, + pipeline: &fakePipeline{}, } - _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ - ID: "dialog-1", - TenantID: "tenant-1", - PromptConfig: entity.JSONMap{}, - MetaDataFilter: &filter, - KBIDs: entity.JSONSlice{"kb-1"}, - TopN: 3, - TopK: 32, - }, []map[string]interface{}{ - {"role": "user", "content": "question", "doc_ids": []interface{}{"doc-explicit", "doc-policy"}}, - }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) + err := svc.RemoveChatSessions("user-1", []string{"conv-1", "conv-2"}) if err != nil { - t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) + t.Fatalf("unexpected error: %v", err) } - if len(retrieval.req.DocIDs) != 1 || retrieval.req.DocIDs[0] != "doc-policy" { - t.Fatalf("expected metadata and message doc_ids intersection, got %#v", retrieval.req.DocIDs) + if len(store.deleteByIDIDs) != 2 { + t.Fatalf("expected 2 deletes, got %d", len(store.deleteByIDIDs)) } } -func TestMessagesWithRetrievedKnowledgeNoMetadataIntersectionUsesSentinel(t *testing.T) { - retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} +func TestRemoveChatSessions_SessionNotFound(t *testing.T) { + store := newFakeSessionStore() svc := &ChatSessionService{ - kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ - {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, - }}, - modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, - metadataSvc: fakeChatMetadataService{}, - retrievalSvc: retrieval, - } - filter := entity.JSONMap{ - "method": "manual", - "manual": []interface{}{ - map[string]interface{}{"key": "category", "op": "=", "value": "policy"}, - }, - "logic": "and", + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-1"}}, + pipeline: &fakePipeline{}, } - _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ - ID: "dialog-1", - TenantID: "tenant-1", - PromptConfig: entity.JSONMap{}, - MetaDataFilter: &filter, - KBIDs: entity.JSONSlice{"kb-1"}, - TopN: 3, - TopK: 32, - }, []map[string]interface{}{ - {"role": "user", "content": "question", "doc_ids": []interface{}{"doc-explicit"}}, - }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) + err := svc.RemoveChatSessions("user-1", []string{"missing"}) + if err == nil || !strings.Contains(err.Error(), "not found") { + t.Fatalf("expected 'not found' error, got %v", err) + } +} + +func TestRemoveChatSessions_NotOwner(t *testing.T) { + store := newFakeSessionStore() + store.sessions["conv-1"] = &entity.ChatSession{ID: "conv-1", DialogID: "dialog-1"} + // No tenant matches — dialogExists stays false for all combinations + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-other"}}, + pipeline: &fakePipeline{}, + } + + err := svc.RemoveChatSessions("user-1", []string{"conv-1"}) + if err == nil || !strings.Contains(err.Error(), "Only owner") { + t.Fatalf("expected 'Only owner' error, got %v", err) + } +} + +// =================================================================== +// ListChatSessions tests +// =================================================================== + +func TestListChatSessions_Success(t *testing.T) { + store := newFakeSessionStore() + store.sessions["s1"] = &entity.ChatSession{ID: "s1", DialogID: "chat-1"} + store.sessions["s2"] = &entity.ChatSession{ID: "s2", DialogID: "chat-1"} + store.dialogExists["tenant-1|chat-1"] = true + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-1"}}, + pipeline: &fakePipeline{}, + } + + resp, err := svc.ListChatSessions("user-1", "chat-1") if err != nil { - t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) + t.Fatalf("unexpected error: %v", err) } - if len(retrieval.req.DocIDs) != 1 || retrieval.req.DocIDs[0] != NoMatchDocIDSentinel { - t.Fatalf("expected empty metadata/doc_ids intersection sentinel, got %#v", retrieval.req.DocIDs) + if len(resp.Sessions) != 2 { + t.Fatalf("expected 2 sessions, got %d", len(resp.Sessions)) } } -func TestMessagesWithRetrievedKnowledgePreservesEmptyMetadataFilterMatches(t *testing.T) { - retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} - svc := &ChatSessionService{ - kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ - {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, - }}, - modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, - metadataSvc: fakeChatMetadataService{}, - retrievalSvc: retrieval, - } - filter := entity.JSONMap{"method": "auto"} +func TestListChatSessions_NotOwner(t *testing.T) { + store := newFakeSessionStore() - _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ - ID: "dialog-1", - TenantID: "tenant-1", - LLMID: "chat@factory", - PromptConfig: entity.JSONMap{}, - MetaDataFilter: &filter, - KBIDs: entity.JSONSlice{"kb-1"}, - TopN: 3, - TopK: 32, - }, []map[string]interface{}{ - {"role": "user", "content": "question"}, - }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{tenantIDs: []string{"tenant-other"}}, + pipeline: &fakePipeline{}, + } + + _, err := svc.ListChatSessions("user-1", "chat-1") + if err == nil || !strings.Contains(err.Error(), "only owner") { + t.Fatalf("expected 'only owner' error, got %v", err) + } +} + +// =================================================================== +// Completion tests +// =================================================================== + +func TestCompletion_Success(t *testing.T) { + store := newFakeSessionStore() + session := &entity.ChatSession{ + ID: "session-1", DialogID: "dialog-1", + Message: json.RawMessage(`{"messages":[]}`), + Reference: json.RawMessage(`[]`), + } + store.sessions["session-1"] = session + store.dialogs["dialog-1"] = &entity.Chat{ + ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory", + LLMSetting: entity.JSONMap{}, + } + + pipeline := &fakePipeline{ + resultChan: makeResultChan( + AsyncChatResult{Answer: "Hello", Reference: map[string]interface{}{"chunks": []interface{}{}}}, + AsyncChatResult{Answer: " world", Final: true, Reference: map[string]interface{}{"chunks": []interface{}{}}}, + ), + } + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: pipeline, + } + + result, err := svc.Completion("user-1", "session-1", []map[string]interface{}{ + {"role": "user", "content": "hi"}, + }, "", nil, "msg-1") if err != nil { - t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) + t.Fatalf("unexpected error: %v", err) } - if len(retrieval.req.DocIDs) != 1 || retrieval.req.DocIDs[0] != NoMatchDocIDSentinel { - t.Fatalf("expected empty metadata filter sentinel, got %#v", retrieval.req.DocIDs) + ans, _ := result["answer"].(string) + if ans != "Hello world" { + t.Fatalf("expected answer 'Hello world', got %q", ans) } } -func TestMessagesWithRetrievedKnowledgeFailsClosedWhenMetadataUnavailable(t *testing.T) { - retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} +func TestCompletion_EmptyMessages(t *testing.T) { svc := &ChatSessionService{ - kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ - {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, - }}, - modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, - metadataSvc: failingChatMetadataService{}, - retrievalSvc: retrieval, - } - filter := entity.JSONMap{ - "method": "manual", - "manual": []interface{}{ - map[string]interface{}{"key": "category", "op": "=", "value": "policy"}, - }, - "logic": "and", + chatSessionDAO: &fakeSessionStore{}, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, } - _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ - ID: "dialog-1", - TenantID: "tenant-1", - PromptConfig: entity.JSONMap{}, - MetaDataFilter: &filter, - KBIDs: entity.JSONSlice{"kb-1"}, - TopN: 3, - TopK: 32, - }, []map[string]interface{}{ - {"role": "user", "content": "question", "doc_ids": []interface{}{"doc-explicit"}}, - }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) - if err == nil || !strings.Contains(err.Error(), "flattened metadata") { - t.Fatalf("expected metadata filter error, got %v", err) - } - if retrieval.req != nil { - t.Fatal("retrieval should not run when metadata filtering cannot be evaluated") + _, err := svc.Completion("user-1", "session-1", nil, "", nil, "msg-1") + if err == nil || err.Error() != "messages cannot be empty" { + t.Fatalf("expected 'messages cannot be empty', got %v", err) } } -func TestMessagesWithRetrievedKnowledgeExpandsChildChunks(t *testing.T) { - retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{ - Chunks: []map[string]interface{}{ - { - "chunk_id": "child-1", - "mom_id": "parent-1", - "kb_id": "kb-1", - "doc_id": "doc-1", - "docnm_kwd": "doc.md", - "content_ltks": "child tokens", - "content_with_weight": "child-only passage", - "similarity": 0.8, - }, - }, - }} +func TestCompletion_LastMessageNotFromUser(t *testing.T) { svc := &ChatSessionService{ - kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ - {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, - }}, - docEngine: fakeChatDocEngine{chunk: map[string]interface{}{ - "doc_id": "doc-1", - "docnm_kwd": "doc.md", - "kb_id": "kb-1", - "content_with_weight": "parent passage with surrounding context", - "position_int": []interface{}{1}, - }}, - modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, - metadataSvc: fakeChatMetadataService{}, - retrievalSvc: retrieval, + chatSessionDAO: &fakeSessionStore{}, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, } - ragMessages, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ - ID: "dialog-1", - TenantID: "tenant-1", - PromptConfig: entity.JSONMap{}, - KBIDs: entity.JSONSlice{"kb-1"}, - TopN: 3, - TopK: 32, - }, []map[string]interface{}{ - {"role": "user", "content": "question"}, - }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) - if err != nil { - t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) - } - content, _ := ragMessages[0]["content"].(string) - if !strings.Contains(content, "parent passage with surrounding context") { - t.Fatalf("expected expanded parent content in prompt, got %q", content) - } - if strings.Contains(content, "child-only passage") { - t.Fatalf("expected child content to be replaced by expanded parent content, got %q", content) + _, err := svc.Completion("user-1", "session-1", []map[string]interface{}{ + {"role": "assistant", "content": "hello"}, + }, "", nil, "msg-1") + if err == nil || !strings.Contains(err.Error(), "not from user") { + t.Fatalf("expected 'not from user' error, got %v", err) } } -func TestMessagesWithRetrievedKnowledgeRejectsCrossTenantKnowledgebase(t *testing.T) { - retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} +func TestCompletion_ConversationNotFound(t *testing.T) { + store := newFakeSessionStore() + svc := &ChatSessionService{ - kbDAO: fakeChatKBStore{ - kbs: []*entity.Knowledgebase{ - {ID: "kb-1", TenantID: "tenant-2", Name: "Manual", EmbdID: "embed@factory"}, - }, - accessible: map[string]bool{"kb-1": false}, - }, - modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, - metadataSvc: fakeChatMetadataService{}, - retrievalSvc: retrieval, + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, } - _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ - ID: "dialog-1", - TenantID: "tenant-1", - PromptConfig: entity.JSONMap{}, - KBIDs: entity.JSONSlice{"kb-1"}, - TopN: 3, - TopK: 32, - }, []map[string]interface{}{ - {"role": "user", "content": "question"}, - }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) - if err == nil || !strings.Contains(err.Error(), "not authorized") { - t.Fatalf("expected cross-tenant authorization error, got %v", err) - } - if retrieval.req != nil { - t.Fatal("retrieval should not be called for an unauthorized knowledge base") + _, err := svc.Completion("user-1", "missing", []map[string]interface{}{ + {"role": "user", "content": "hi"}, + }, "", nil, "msg-1") + if err == nil || err.Error() != "Conversation not found" { + t.Fatalf("expected 'Conversation not found', got %v", err) } } -func TestMessagesWithRetrievedKnowledgeAllowsAccessibleSharedKnowledgebase(t *testing.T) { - retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} +func TestCompletion_DialogNotFound(t *testing.T) { + store := newFakeSessionStore() + store.sessions["session-1"] = &entity.ChatSession{ + ID: "session-1", DialogID: "dialog-1", + Message: json.RawMessage(`{"messages":[]}`), + Reference: json.RawMessage(`[]`), + } + svc := &ChatSessionService{ - kbDAO: fakeChatKBStore{ - kbs: []*entity.Knowledgebase{ - {ID: "kb-1", TenantID: "tenant-2", Name: "Shared Manual", EmbdID: "embed@factory"}, - }, - accessible: map[string]bool{"kb-1": true}, - }, - modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, - metadataSvc: fakeChatMetadataService{}, - retrievalSvc: retrieval, + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, } - _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ - ID: "dialog-1", - TenantID: "tenant-1", - PromptConfig: entity.JSONMap{}, - KBIDs: entity.JSONSlice{"kb-1"}, - TopN: 3, - TopK: 32, - }, []map[string]interface{}{ - {"role": "user", "content": "question"}, - }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) - if err != nil { - t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) - } - if retrieval.req == nil || len(retrieval.req.TenantIDs) != 1 || retrieval.req.TenantIDs[0] != "tenant-2" { - t.Fatalf("expected retrieval to use shared KB tenant, got %#v", retrieval.req) + _, err := svc.Completion("user-1", "session-1", []map[string]interface{}{ + {"role": "user", "content": "hi"}, + }, "", nil, "msg-1") + if err == nil || err.Error() != "Dialog not found" { + t.Fatalf("expected 'Dialog not found', got %v", err) } } -func TestMessagesWithRetrievedKnowledgeRejectsMixedEmbeddingModels(t *testing.T) { - retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} +func TestCompletion_PipelineError(t *testing.T) { + store := newFakeSessionStore() + store.sessions["session-1"] = &entity.ChatSession{ + ID: "session-1", DialogID: "dialog-1", + Message: json.RawMessage(`{"messages":[]}`), + Reference: json.RawMessage(`[]`), + } + store.dialogs["dialog-1"] = &entity.Chat{ + ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory", + LLMSetting: entity.JSONMap{}, + } + svc := &ChatSessionService{ - kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ - {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed-a@factory"}, - {ID: "kb-2", TenantID: "tenant-1", Name: "FAQ", EmbdID: "embed-b@factory"}, - }}, - modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, - metadataSvc: fakeChatMetadataService{}, - retrievalSvc: retrieval, + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{err: errors.New("model unavailable")}, } - _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ - ID: "dialog-1", - TenantID: "tenant-1", - PromptConfig: entity.JSONMap{}, - KBIDs: entity.JSONSlice{"kb-1", "kb-2"}, - TopN: 3, - TopK: 32, - }, []map[string]interface{}{ - {"role": "user", "content": "question"}, - }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) - if err == nil || !strings.Contains(err.Error(), "same embedding model") { - t.Fatalf("expected mixed embedding model error, got %v", err) - } - if retrieval.req != nil { - t.Fatal("retrieval should not run when knowledge bases use different embedding models") + _, err := svc.Completion("user-1", "session-1", []map[string]interface{}{ + {"role": "user", "content": "hi"}, + }, "", nil, "msg-1") + if err == nil || err.Error() != "model unavailable" { + t.Fatalf("expected 'model unavailable' error, got %v", err) } } -func TestValidateKnowledgebaseEmbeddingModelsComparesResolvedNames(t *testing.T) { - firstTenantEmbdID := int64(1) - secondTenantEmbdID := int64(2) - kbs := []*entity.Knowledgebase{ - {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "same-legacy-name", TenantEmbdID: &firstTenantEmbdID}, - {ID: "kb-2", TenantID: "tenant-1", Name: "FAQ", EmbdID: "same-legacy-name", TenantEmbdID: &secondTenantEmbdID}, - } - resolver := func(tenantID string, kb *entity.Knowledgebase) (string, error) { - if kb.TenantEmbdID != nil && *kb.TenantEmbdID == firstTenantEmbdID { - return "embed-a@factory", nil +// =================================================================== +// CompletionStream tests +// =================================================================== + +func readStreamChan(ch <-chan string, n int) []string { + var msgs []string + for i := 0; i < n; i++ { + select { + case msg, ok := <-ch: + if !ok { + return msgs + } + msgs = append(msgs, msg) + default: + return msgs } - return "embed-b@factory", nil - } - - _, _, err := validateKnowledgebaseEmbeddingModels(kbs, "tenant-1", resolver) - if err == nil || !strings.Contains(err.Error(), "same embedding model") { - t.Fatalf("expected resolved mixed embedding model error, got %v", err) } + return msgs } -func TestMessagesWithRetrievedKnowledgePreservesMultimodalContent(t *testing.T) { - retrieval := &fakeChatRetrievalService{ - result: &nlp.RetrievalResult{ - Chunks: []map[string]interface{}{ - {"content_with_weight": "Knowledge for an image question."}, - }, - }, +func TestCompletionStream_Success(t *testing.T) { + store := newFakeSessionStore() + store.sessions["session-1"] = &entity.ChatSession{ + ID: "session-1", DialogID: "dialog-1", + Message: json.RawMessage(`{"messages":[]}`), + Reference: json.RawMessage(`[]`), } + store.dialogs["dialog-1"] = &entity.Chat{ + ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory", + LLMSetting: entity.JSONMap{}, + } + + pipeline := &fakePipeline{ + resultChan: makeResultChan( + AsyncChatResult{Answer: "stream", Reference: map[string]interface{}{"chunks": []interface{}{}}}, + AsyncChatResult{Answer: " answer", Reference: map[string]interface{}{"chunks": []interface{}{}}}, + ), + } + svc := &ChatSessionService{ - kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ - {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, - }}, - modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, - metadataSvc: fakeChatMetadataService{}, - retrievalSvc: retrieval, - } - imageBlock := map[string]interface{}{"type": "image_url", "image_url": map[string]interface{}{"url": "https://example.com/cat.png"}} - messages := []map[string]interface{}{ - {"role": "user", "content": []interface{}{ - map[string]interface{}{"type": "text", "text": "What is in this image?"}, - imageBlock, - }}, + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: pipeline, } - got, _, emptyResponse, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ - ID: "dialog-1", - TenantID: "tenant-1", - PromptConfig: entity.JSONMap{}, - KBIDs: entity.JSONSlice{"kb-1"}, - TopN: 3, - TopK: 32, - }, messages, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) + streamChan := make(chan string, 10) + err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{ + {"role": "user", "content": "hi"}, + }, "", nil, "msg-1", streamChan) if err != nil { - t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) + t.Fatalf("unexpected error: %v", err) } - if emptyResponse != nil { - t.Fatalf("expected no empty response, got %q", *emptyResponse) + + // Should receive data events and final signal + msgs := readStreamChan(streamChan, 5) + if len(msgs) < 3 { + t.Fatalf("expected at least 3 stream messages, got %d: %v", len(msgs), msgs) } - content, ok := got[0]["content"].([]interface{}) - if !ok { - t.Fatalf("expected multimodal content to stay as blocks, got %T", got[0]["content"]) + // Check final signal + finalFound := false + for _, m := range msgs { + if strings.Contains(m, `"data":true`) { + finalFound = true + break + } } - if len(content) != 3 { - t.Fatalf("expected injected text plus original blocks, got %#v", content) - } - injected, ok := content[0].(map[string]interface{}) - if !ok || injected["type"] != "text" || !strings.Contains(injected["text"].(string), "Knowledge for an image question.") { - t.Fatalf("expected injected knowledge text block, got %#v", content[0]) - } - preservedImage, ok := content[2].(map[string]interface{}) - if !ok || preservedImage["type"] != "image_url" { - t.Fatalf("expected original image block to be preserved, got %#v", content[2]) - } - if retrieval.req == nil || retrieval.req.Question != "What is in this image?" { - t.Fatalf("expected retrieval question from text block, got %#v", retrieval.req) + if !finalFound { + t.Fatal("expected final=true signal in stream") } } -func TestMessagesWithRetrievedKnowledgePassesMessageDocIDs(t *testing.T) { - retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} +func TestCompletionStream_EmptyMessages(t *testing.T) { svc := &ChatSessionService{ - kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ - {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, - }}, - modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, - metadataSvc: fakeChatMetadataService{}, - retrievalSvc: retrieval, + chatSessionDAO: &fakeSessionStore{}, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, } - _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ - ID: "dialog-1", - TenantID: "tenant-1", - PromptConfig: entity.JSONMap{}, - KBIDs: entity.JSONSlice{"kb-1"}, - TopN: 3, - TopK: 32, - }, []map[string]interface{}{ - {"role": "user", "content": "question", "doc_ids": []interface{}{"doc-1", "doc-2", "doc-1"}}, - }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) - if err != nil { - t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) - } - if len(retrieval.req.DocIDs) != 2 || retrieval.req.DocIDs[0] != "doc-1" || retrieval.req.DocIDs[1] != "doc-2" { - t.Fatalf("expected scoped doc ids, got %#v", retrieval.req.DocIDs) + streamChan := make(chan string, 10) + err := svc.CompletionStream(context.Background(), "user-1", "session-1", nil, "", nil, "msg-1", streamChan) + if err == nil || err.Error() != "messages cannot be empty" { + t.Fatalf("expected 'messages cannot be empty', got %v", err) } } -type failingChatRetrievalService struct { - err error +func TestCompletionStream_LastMessageNotFromUser(t *testing.T) { + svc := &ChatSessionService{ + chatSessionDAO: &fakeSessionStore{}, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, + } + + streamChan := make(chan string, 10) + err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{ + {"role": "assistant", "content": "hello"}, + }, "", nil, "msg-1", streamChan) + if err == nil || !strings.Contains(err.Error(), "not from user") { + t.Fatalf("expected 'not from user' error, got %v", err) + } } -func (f *failingChatRetrievalService) Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) { - return nil, f.err +func TestCompletionStream_ConversationNotFound(t *testing.T) { + store := newFakeSessionStore() + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, + } + + streamChan := make(chan string, 10) + err := svc.CompletionStream(context.Background(), "user-1", "missing", []map[string]interface{}{ + {"role": "user", "content": "hi"}, + }, "", nil, "msg-1", streamChan) + if err == nil || err.Error() != "Conversation not found" { + t.Fatalf("expected 'Conversation not found', got %v", err) + } +} + +func TestCompletionStream_DialogNotFound(t *testing.T) { + store := newFakeSessionStore() + store.sessions["session-1"] = &entity.ChatSession{ + ID: "session-1", DialogID: "dialog-1", + Message: json.RawMessage(`{"messages":[]}`), + Reference: json.RawMessage(`[]`), + } + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{}, + } + + streamChan := make(chan string, 10) + err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{ + {"role": "user", "content": "hi"}, + }, "", nil, "msg-1", streamChan) + if err == nil || err.Error() != "Dialog not found" { + t.Fatalf("expected 'Dialog not found', got %v", err) + } +} + +func TestCompletionStream_PipelineError(t *testing.T) { + store := newFakeSessionStore() + store.sessions["session-1"] = &entity.ChatSession{ + ID: "session-1", DialogID: "dialog-1", + Message: json.RawMessage(`{"messages":[]}`), + Reference: json.RawMessage(`[]`), + } + store.dialogs["dialog-1"] = &entity.Chat{ + ID: "dialog-1", TenantID: "tenant-1", LLMID: "chat@factory", + LLMSetting: entity.JSONMap{}, + } + + svc := &ChatSessionService{ + chatSessionDAO: store, + userTenantDAO: &fakeTenantStore{}, + pipeline: &fakePipeline{err: errors.New("model unavailable")}, + } + + streamChan := make(chan string, 10) + err := svc.CompletionStream(context.Background(), "user-1", "session-1", []map[string]interface{}{ + {"role": "user", "content": "hi"}, + }, "", nil, "msg-1", streamChan) + if err == nil || err.Error() != "model unavailable" { + t.Fatalf("expected 'model unavailable' error, got %v", err) + } } diff --git a/internal/service/citation.go b/internal/service/citation.go index 439306fb59..c7c536cfb4 100644 --- a/internal/service/citation.go +++ b/internal/service/citation.go @@ -33,6 +33,18 @@ type Embedder interface { Encode(texts []string) ([][]float64, error) } +// CitationMarkerPattern matches "[ID:N]" or bare "[N]" with Arabic digit support, +// allowing optional whitespace after "ID:" (e.g. "[ID: 12]"). +var CitationMarkerPattern = regexp.MustCompile(`\[(?:ID:\s*)?([0-9\x{0660}-\x{0669}\x{06F0}-\x{06F9}]+)\]`) + +// badCitationPatterns match malformed citation shapes that LLMs sometimes emit +var badCitationPatterns = []*regexp.Regexp{ + regexp.MustCompile(`\(\s*ID\s*[:: ]*\s*([0-9\x{0660}-\x{0669}\x{06F0}-\x{06F9}]+)\s*\)`), // (ID: 12) + regexp.MustCompile(`\[\s*ID\s*[:: ]*\s*([0-9\x{0660}-\x{0669}\x{06F0}-\x{06F9}]+)\s*\]`), // [ID: 12] + regexp.MustCompile(`【\s*ID\s*[:: ]*\s*([0-9\x{0660}-\x{0669}\x{06F0}-\x{06F9}]+)\s*】`), // 【ID: 12】 + regexp.MustCompile(`(?i)\bref\s*([0-9\x{0660}-\x{0669}\x{06F0}-\x{06F9}]+)\b`), // ref12 +} + // InsertCitations decorates answer with [ID:n] citation markers. // // Algorithm mirrors Python Dealer.insert_citations: @@ -260,3 +272,92 @@ func maxRow(row []float64) float64 { } return mx } + +// normalizeArabicDigits converts Arabic-Indic (U+0660-0669) and +// Eastern Arabic-Indic (U+06F0-06F9) digits to ASCII. +func normalizeArabicDigits(s string) string { + if s == "" { + return s + } + var b strings.Builder + b.Grow(len(s)) + for _, r := range s { + switch { + case r >= 0x0660 && r <= 0x0669: + b.WriteRune(r - 0x0660 + '0') + case r >= 0x06F0 && r <= 0x06F9: + b.WriteRune(r - 0x06F0 + '0') + default: + b.WriteRune(r) + } + } + return b.String() +} + +// HasCitationMarkers reports whether answer already contains canonical citation markers. +func HasCitationMarkers(answer string) bool { + if answer == "" { + return false + } + return CitationMarkerPattern.MatchString(normalizeArabicDigits(answer)) +} + +// ExtractCitationMarkers returns chunk indices from citation markers within [0, maxIndex). +// Preserves first-seen order, no duplicates. +func ExtractCitationMarkers(answer string, maxIndex int) []int { + if answer == "" || maxIndex <= 0 { + return nil + } + seen := make(map[int]struct{}) + var out []int + for _, m := range CitationMarkerPattern.FindAllStringSubmatch(normalizeArabicDigits(answer), -1) { + if len(m) < 2 { + continue + } + var n int + for _, r := range m[1] { + if r < '0' || r > '9' { + n = 0 + break + } + n = n*10 + int(r-'0') + } + if n < 0 || n >= maxIndex { + continue + } + if _, ok := seen[n]; ok { + continue + } + seen[n] = struct{}{} + out = append(out, n) + } + return out +} + +// RepairBadCitationFormats rewrites bad citation shapes into canonical "[ID:N]" form +func RepairBadCitationFormats(answer string) string { + if answer == "" { + return answer + } + working := answer + for _, pat := range badCitationPatterns { + matches := pat.FindAllStringSubmatchIndex(working, -1) + if len(matches) == 0 { + continue + } + var b strings.Builder + b.Grow(len(working)) + last := 0 + for _, m := range matches { + b.WriteString(working[last:m[0]]) + digits := normalizeArabicDigits(working[m[2]:m[3]]) + b.WriteString("[ID:") + b.WriteString(digits) + b.WriteString("]") + last = m[1] + } + b.WriteString(working[last:]) + working = b.String() + } + return working +} diff --git a/internal/service/deep_researcher.go b/internal/service/deep_researcher.go new file mode 100644 index 0000000000..a632a2dfb8 --- /dev/null +++ b/internal/service/deep_researcher.go @@ -0,0 +1,854 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "regexp" + "strings" + "sync" + "time" + + "ragflow/internal/common" + "ragflow/internal/engine" + modelModule "ragflow/internal/entity/models" + "ragflow/internal/service/kg" + "ragflow/internal/service/nlp" + "ragflow/internal/tokenizer" + + "github.com/google/uuid" + "github.com/kaptinlin/jsonrepair" + "go.uber.org/zap" +) + +// Prompt templates + +const sufficiencyCheckTemplate = `You are a information retrieval evaluation expert. Please assess whether the currently retrieved content is sufficient to answer the user's question. + +User question: +%s + +Retrieved content: +%s + +Please determine whether these content are sufficient to answer the user's question. + +Output format (JSON): +{ + "is_sufficient": true/false, + "reasoning": "Your reasoning for the judgment", + "missing_information": ["Missing information 1", "Missing information 2"] +} + +Requirements: +1. If the retrieved content contains key information needed to answer the query, judge as sufficient (true). +2. If key information is missing, judge as insufficient (false), and list the missing information. +3. The reasoning should be concise and clear. +4. The missing_information should only be filled when insufficient, otherwise empty array. +` + +const multiQueriesGenTemplate = `You are a query optimization expert. +The user's original query failed to retrieve sufficient information; +please generate multiple complementary improved questions and corresponding queries. + +Original query: +%s + +Original question: +%s + +Currently, retrieved content: +%s + +Missing information: +%s + +Please generate 2-3 complementary queries to help find the missing information. These queries should: +1. Focus on different missing information points. +2. Use different expressions. +3. Avoid being identical to the original query. +4. Remain concise and clear. + +Output format (JSON): +{ + "reasoning": "Explanation of query generation strategy", + "questions": [ + {"question": "Improved question 1", "query": "Improved query 1"}, + {"question": "Improved question 2", "query": "Improved query 2"}, + {"question": "Improved question 3", "query": "Improved query 3"} + ] +} + +Requirements: +1. Questions array contains 1-3 questions and corresponding queries. +2. Each question length is between 5-200 characters. +3. Each query length is between 1-5 keywords. +4. Each query MUST be in the same language as the retrieved content in. +5. DO NOT generate question and query that is similar to the original query. +6. Reasoning explains the generation strategy. +` + +// Types + +// KBRetrieveFunc is the signature for knowledge base retrieval. +type KBRetrieveFunc func(ctx context.Context, question string) (*nlp.RetrievalResult, error) + +// sufficiencyResult is the sufficiency_check JSON output. +type sufficiencyResult struct { + IsSufficient bool `json:"is_sufficient"` + Reasoning string `json:"reasoning"` + MissingInformation []string `json:"missing_information"` +} + +// queryPair is a {question, query} entry. +type queryPair struct { + Question string `json:"question"` + Query string `json:"query"` +} + +// multiQueriesResult is the multi_queries_gen JSON output. +type multiQueriesResult struct { + Reasoning string `json:"reasoning"` + Questions []queryPair `json:"questions"` +} + +// DeepResearcher implements recursive query-decomposition retrieval. +// Each level: retrieve → sufficiency check → if insufficient, generate +// sub-queries → recurse. Accumulates chunks into a shared chunkInfo map. +type DeepResearcher struct { + ChatModel *modelModule.ChatModel + PromptConfig map[string]interface{} + KBRetrieve KBRetrieveFunc + InternetEnabled bool + TavilyAPIKey string + + // Fields needed for KG retrieval (mirrors async_chat.go usage). + DocEngine engine.DocEngine + KbIDs []string + TenantIDs []string + EmbModel *modelModule.EmbeddingModel + + maxDepth int // default 3 + mu sync.Mutex + tavilyURL string +} + +// NewDeepResearcher constructs a DeepResearcher. +func NewDeepResearcher( + chatModel *modelModule.ChatModel, + promptConfig map[string]interface{}, + kbRetrieve KBRetrieveFunc, + internetEnabled bool, + docEngine engine.DocEngine, + kbIDs []string, + tenantIDs []string, + embModel *modelModule.EmbeddingModel, +) *DeepResearcher { + return &DeepResearcher{ + ChatModel: chatModel, + PromptConfig: promptConfig, + KBRetrieve: kbRetrieve, + InternetEnabled: internetEnabled, + TavilyAPIKey: mapStringValue(promptConfig, "tavily_api_key"), + DocEngine: docEngine, + KbIDs: kbIDs, + TenantIDs: tenantIDs, + EmbModel: embModel, + maxDepth: 3, + tavilyURL: "https://api.tavily.com/search", + } +} + +// Research runs the recursive tree search, accumulating chunks into chunkInfo. +func (dr *DeepResearcher) Research( + ctx context.Context, + chunkInfo map[string]interface{}, + question string, + query string, + callback func(string), +) error { + if dr == nil || dr.ChatModel == nil { + return fmt.Errorf("DeepResearcher: missing chat model") + } + + if callback != nil { + callback("") + } + + // Initialize chunkInfo if empty + if _, ok := chunkInfo["chunks"]; !ok { + chunkInfo["chunks"] = []interface{}{} + chunkInfo["doc_aggs"] = []interface{}{} + chunkInfo["total"] = 0 + } + + _, err := dr._research(ctx, chunkInfo, question, query, dr.maxDepth, callback) + if err != nil { + common.Warn("DeepResearcher: research failed", zap.Error(err)) + } + + if callback != nil { + callback("") + } + + return err +} + +// _research is the recursive depth-first worker. +func (dr *DeepResearcher) _research( + ctx context.Context, + chunkInfo map[string]interface{}, + question string, + query string, + depth int, + callback func(string), +) (string, error) { + if depth == 0 { + return "", nil + } + + if callback != nil { + callback(fmt.Sprintf("Searching by `%s`...", query)) + } + + // 1. Retrieve information (KB + optional web) + st := time.Now() + kbinfos, err := dr._retrieve_information(ctx, query) + if err != nil { + return "", err + } + + if callback != nil { + n := len(chunksFromKBInfos(kbinfos)) + elapsed := time.Since(st).Milliseconds() + callback(fmt.Sprintf("Retrieval %d results in %.1fms", n, float64(elapsed))) + } + + // 2. Merge into chunkInfo (dedup by chunk_id) + dr.mergeChunkInfo(chunkInfo, kbinfos) + + // 3. Trim content + maxTokens := dr.ChatModelMaxTokens() / 2 + knowledges := kbPrompt(kbinfos, maxTokens, false) + retContent := strings.Join(knowledges, "\n\n") + + // 4. Sufficiency check + if callback != nil { + callback("Checking the sufficiency for retrieved information.") + } + + suff, err := dr.sufficiencyCheck(ctx, question, retContent) + if err != nil { + common.Warn("DeepResearcher: sufficiency check failed", + zap.Error(err), zap.Int("depth", depth)) + // On error, treat as insufficient + suff = &sufficiencyResult{IsSufficient: false} + } + + if suff.IsSufficient { + if callback != nil { + callback(fmt.Sprintf("Yes, the retrieved information is sufficient for '%s'.", question)) + } + return retContent, nil + } + + // 5. Generate sub-queries + missingStr := strings.Join(suff.MissingInformation, "\n - ") + mg, err := dr.multiQueriesGen(ctx, question, query, missingStr, retContent) + if err != nil { + common.Warn("DeepResearcher: multi_queries_gen failed", + zap.Error(err), zap.Int("depth", depth)) + return retContent, nil + } + + if len(mg.Questions) == 0 { + return retContent, nil + } + + if callback != nil { + var questionStrs []string + for _, q := range mg.Questions { + questionStrs = append(questionStrs, q.Question) + } + callback("Next step is to search for the following questions:
- " + strings.Join(questionStrs, "
- ")) + } + + // 6. Recurse in parallel + var wg sync.WaitGroup + results := make([]string, len(mg.Questions)) + mu := &sync.Mutex{} + + for i, qp := range mg.Questions { + wg.Add(1) + go func(idx int, q queryPair) { + defer wg.Done() + r, err := dr._research(ctx, chunkInfo, q.Question, q.Query, depth-1, callback) + mu.Lock() + defer mu.Unlock() + if err != nil { + // Exceptions become string results (gather with return_exceptions) + results[idx] = err.Error() + common.Warn("DeepResearcher: sub-research failed", + zap.Error(err), zap.String("question", q.Question)) + return + } + results[idx] = r + }(i, qp) + } + + // Wait with context cancellation support. + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + case <-ctx.Done(): + return retContent, ctx.Err() + } + + // 7. Join results + return strings.Join(results, "\n"), nil +} + +// ────────────────────────────────────────────────────────────────────── +// Retrieval (KB + optional Web) +// ────────────────────────────────────────────────────────────────────── + +// _retrieve_information does KB + optional web retrieval. +func (dr *DeepResearcher) _retrieve_information(ctx context.Context, query string) (map[string]interface{}, error) { + kbinfos := map[string]interface{}{ + "total": int64(0), + "chunks": []map[string]interface{}{}, + "doc_aggs": []interface{}{}, + } + + // 1. KB retrieval + if dr.KBRetrieve != nil { + resp, err := dr.KBRetrieve(ctx, query) + if err != nil { + common.Warn("DeepResearcher: KB retrieval error", zap.Error(err)) + } + if resp != nil { + chunks := make([]map[string]interface{}, len(resp.Chunks)) + copy(chunks, resp.Chunks) + docAggs := make([]interface{}, 0, len(resp.DocAggs)) + for _, d := range resp.DocAggs { + docAggs = append(docAggs, d) + } + kbinfos["chunks"] = chunks + kbinfos["doc_aggs"] = docAggs + kbinfos["total"] = resp.Total + } + } + + // 2. Web retrieval (Tavily) + if dr.InternetEnabled && dr.TavilyAPIKey != "" { + tavRes, err := dr.tavilyRetrieve(ctx, query) + if err != nil { + common.Warn("DeepResearcher: web retrieval error", zap.Error(err)) + } else if tavRes != nil { + if chunks, ok := tavRes["chunks"].([]map[string]interface{}); ok { + existing, _ := kbinfos["chunks"].([]map[string]interface{}) + kbinfos["chunks"] = append(existing, chunks...) + } + if aggs, ok := tavRes["doc_aggs"].([]interface{}); ok { + existing, _ := kbinfos["doc_aggs"].([]interface{}) + kbinfos["doc_aggs"] = append(existing, aggs...) + } + } + } + + // 3. Knowledge graph retrieval + if useKG, _ := dr.PromptConfig["use_kg"].(bool); useKG && dr.ChatModel != nil && len(dr.KbIDs) > 0 { + kgPipeline := kg.NewPipeline(dr.DocEngine, dr.KbIDs, dr.TenantIDs, query) + kgPipeline.SetChatModel(dr.ChatModel) + if dr.EmbModel != nil { + kgPipeline.SetEmbModel(dr.EmbModel) + } + kgChunk, kgErr := kgPipeline.Retrieval(ctx) + if kgErr != nil { + common.Warn("DeepResearcher: KG retrieval failed", zap.Error(kgErr)) + } else if kgChunk != nil { + if _, hasContent := kgChunk["content_with_weight"]; hasContent { + if existingChunks, ok := kbinfos["chunks"].([]map[string]interface{}); ok { + newChunks := make([]map[string]interface{}, 0, len(existingChunks)+1) + newChunks = append(newChunks, kgChunk) + newChunks = append(newChunks, existingChunks...) + kbinfos["chunks"] = newChunks + } + } + } + } + + return kbinfos, nil +} + +// tavilyRetrieve calls the Tavily Search API. +func (dr *DeepResearcher) tavilyRetrieve(ctx context.Context, query string) (map[string]interface{}, error) { + reqBody := map[string]interface{}{ + "query": query, + "api_key": dr.TavilyAPIKey, + "search_depth": "advanced", + "max_results": 6, + } + jsonBody, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("tavily marshal: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", dr.tavilyURL, bytes.NewReader(jsonBody)) + if err != nil { + return nil, fmt.Errorf("tavily request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 15 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("tavily call: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("tavily read: %w", err) + } + + var apiResp struct { + Results []struct { + URL string `json:"url"` + Title string `json:"title"` + Content string `json:"content"` + Score float64 `json:"score"` + } `json:"results"` + } + if err := json.Unmarshal(body, &apiResp); err != nil { + return nil, fmt.Errorf("tavily parse: %w", err) + } + + if len(apiResp.Results) == 0 { + return nil, nil + } + + chunks := make([]map[string]interface{}, 0, len(apiResp.Results)) + aggs := make([]interface{}, 0, len(apiResp.Results)) + for _, r := range apiResp.Results { + id := strings.ReplaceAll(uuid.New().String(), "-", "") + chunks = append(chunks, map[string]interface{}{ + "chunk_id": id, + "content_ltks": tokenizeText(r.Content), + "content_with_weight": r.Content, + "doc_id": id, + "docnm_kwd": r.Title, + "kb_id": []interface{}{}, + "important_kwd": []interface{}{}, + "image_id": "", + "similarity": r.Score, + "vector_similarity": 1.0, + "term_similarity": 0, + "vector": []interface{}{}, + "positions": []interface{}{}, + "url": r.URL, + }) + aggs = append(aggs, map[string]interface{}{ + "doc_name": r.Title, + "doc_id": id, + "count": 1, + "url": r.URL, + }) + } + + return map[string]interface{}{ + "chunks": chunks, + "doc_aggs": aggs, + }, nil +} + +// mergeChunkInfo merges kbinfos into chunkInfo, deduplicating by chunk_id / doc_id. +func (dr *DeepResearcher) mergeChunkInfo( + chunkInfo map[string]interface{}, + kbinfos map[string]interface{}, +) { + dr.mu.Lock() + defer dr.mu.Unlock() + + existingChunks, _ := chunkInfo["chunks"].([]map[string]interface{}) + + // First retrieval — copy all keys from kbinfos + if len(existingChunks) == 0 { + for k := range chunkInfo { + chunkInfo[k] = kbinfos[k] + } + return + } + + newChunks, _ := kbinfos["chunks"].([]map[string]interface{}) + if len(newChunks) == 0 { + return + } + + // Build set of existing chunk IDs. + seenChunkIDs := make(map[string]bool) + for _, m := range existingChunks { + if id, ok := m["chunk_id"].(string); ok { + seenChunkIDs[id] = true + } + } + + // Append only new chunks. + for _, m := range newChunks { + id, _ := m["chunk_id"].(string) + if id == "" { + continue + } + if seenChunkIDs[id] { + continue + } + seenChunkIDs[id] = true + existingChunks = append(existingChunks, m) + } + + chunkInfo["chunks"] = existingChunks + + // Merge doc_aggs (dedup by doc_id). + newAggs, _ := kbinfos["doc_aggs"].([]interface{}) + if len(newAggs) == 0 { + return + } + + aggExisting, _ := chunkInfo["doc_aggs"].([]interface{}) + seenDocIDs := make(map[string]bool) + for _, d := range aggExisting { + if m, ok := d.(map[string]interface{}); ok { + if id, ok := m["doc_id"].(string); ok { + seenDocIDs[id] = true + } + } + } + for _, d := range newAggs { + m, ok := d.(map[string]interface{}) + if !ok { + continue + } + id, _ := m["doc_id"].(string) + if id == "" || seenDocIDs[id] { + continue + } + seenDocIDs[id] = true + aggExisting = append(aggExisting, d) + } + + chunkInfo["doc_aggs"] = aggExisting + + // Accumulate total. + existingTotal, _ := chunkInfo["total"].(int64) + newTotal, _ := kbinfos["total"].(int64) + chunkInfo["total"] = existingTotal + newTotal +} + +// genJSON calls the LLM with a system prompt and retries on parse failure. +func (dr *DeepResearcher) genJSON( + ctx context.Context, + systemPrompt string, + cfg *modelModule.ChatConfig, + result interface{}, +) error { + maxRetry := 2 + var lastAns, lastErr string + + for attempt := 0; attempt < maxRetry; attempt++ { + userPrompt := "Output:\n" + if attempt > 0 && lastAns != "" && lastErr != "" { + // Append correction to user message on retry + userPrompt += fmt.Sprintf( + "\nGenerated JSON is as following:\n%s\nBut exception while loading:\n%s\nPlease reconsider and correct it.", + lastAns, lastErr, + ) + } + + resp, err := dr.chatOnce(ctx, systemPrompt, userPrompt, cfg) + if err != nil { + return err + } + + resp = cleanLLMResponse(resp) + lastAns = resp + + repaired, rerr := jsonrepair.Repair(resp) + if rerr != nil { + repaired = resp + } + if err := json.Unmarshal([]byte(repaired), result); err != nil { + lastErr = err.Error() + common.Warn("genJSON: JSON parse failed, retrying", + zap.Error(err), zap.Int("attempt", attempt)) + continue + } + return nil + } + return fmt.Errorf("genJSON: failed after %d attempts: %s", maxRetry, lastErr) +} + +// sufficiencyCheck asks the LLM whether retrieved content is sufficient. +func (dr *DeepResearcher) sufficiencyCheck( + ctx context.Context, + question string, + retContent string, +) (*sufficiencyResult, error) { + systemPrompt := fmt.Sprintf(sufficiencyCheckTemplate, question, retContent) + tempLow := 0.0 + cfg := &modelModule.ChatConfig{Temperature: &tempLow} + + var result sufficiencyResult + if err := dr.genJSON(ctx, systemPrompt, cfg, &result); err != nil { + return nil, err + } + return &result, nil +} + +// multiQueriesGen asks the LLM to generate sub-queries from missing info. +func (dr *DeepResearcher) multiQueriesGen( + ctx context.Context, + originalQuestion string, + originalQuery string, + missingInfo string, + retContent string, +) (*multiQueriesResult, error) { + systemPrompt := fmt.Sprintf(multiQueriesGenTemplate, + originalQuery, originalQuestion, retContent, missingInfo) + tempLow := 0.4 + cfg := &modelModule.ChatConfig{Temperature: &tempLow} + + var result multiQueriesResult + if err := dr.genJSON(ctx, systemPrompt, cfg, &result); err != nil { + return nil, err + } + return &result, nil +} + +// chatOnce is a single-turn LLM call. Returns the answer text. +func (dr *DeepResearcher) chatOnce( + ctx context.Context, + systemPrompt string, + userPrompt string, + cfg *modelModule.ChatConfig, +) (string, error) { + if dr.ChatModel == nil || dr.ChatModel.ModelDriver == nil { + return "", fmt.Errorf("DeepResearcher: no chat model configured") + } + modelName := "" + if dr.ChatModel.ModelName != nil { + modelName = *dr.ChatModel.ModelName + } + msgs := []modelModule.Message{ + modelModule.Message{Role: "system", Content: systemPrompt}, + modelModule.Message{Role: "user", Content: userPrompt}, + } + resp, err := dr.ChatModel.ModelDriver.ChatWithMessages( + modelName, msgs, dr.ChatModel.APIConfig, cfg, + ) + if err != nil { + return "", err + } + if resp == nil || resp.Answer == nil { + return "", fmt.Errorf("empty response from chat model") + } + return *resp.Answer, nil +} + +// cleanLLMResponse strips think tags, markdown fences, and trailing backticks. +var thinkTagRe = regexp.MustCompile(`(?s)^.*?`) +var trailingCommaRe = regexp.MustCompile(`,\s*([}\]])`) +var cleanResponseRe = regexp.MustCompile(`(?s)(^.*?|` + "```json\\n" + `|` + "```\\n*$" + `)`) +var trailingBacktickRe = regexp.MustCompile("```\\n*$") + +func cleanLLMResponse(raw string) string { + // Strip think tags, markdown fences + raw = cleanResponseRe.ReplaceAllString(raw, "") + + // Also handle trailing ```` in case any remain after the regex pass + raw = trailingBacktickRe.ReplaceAllString(raw, "") + + return strings.TrimSpace(raw) +} + +// repairJSON: see metadata_filter.go:737 (canonical implementation). + +// ChatModelMaxTokens returns the token budget for kb_prompt sizing (default 6000). +func (dr *DeepResearcher) ChatModelMaxTokens() int { + return 6000 +} + +// kbPrompt formats retrieval results into knowledge blocks, truncating at 97% of maxTokens. +func kbPrompt(kbinfos map[string]interface{}, maxTokens int, hashID bool) []string { + chunksRaw, _ := kbinfos["chunks"].([]map[string]interface{}) + if len(chunksRaw) == 0 { + return nil + } + + // Extract content strings. + var knowledges []string + for _, m := range chunksRaw { + text := getMapString(m, "content") + if text == "" { + text = getMapString(m, "content_with_weight") + } + if text == "" { + continue + } + knowledges = append(knowledges, text) + } + + if len(knowledges) == 0 { + return nil + } + + // Truncate at 97% token budget. + usedTokens := 0 + chunksNum := 0 + for i, c := range knowledges { + usedTokens += tokenizer.NumTokensFromString(c) + chunksNum++ + if usedTokens > int(float64(maxTokens)*0.97) { + knowledges = knowledges[:i] + common.Warn("kb_prompt: truncating chunks", + zap.Int("kept", len(knowledges)), + zap.Int("total", len(chunksRaw))) + break + } + } + + // Format each chunk. + knowledges = nil // reuse + for i, m := range chunksRaw[:chunksNum] { + + id := i + if hashID { + // Hash chunk ID for stable int ID + if rawID := getMapString(m, "id", "chunk_id"); rawID != "" { + id = hashStrToInt(rawID, 500) + } + } + + cnt := fmt.Sprintf("\nID: %d", id) + cnt += drawNode("Title", getMapString(m, "docnm_kwd", "document_name")) + cnt += drawNode("URL", getMapString(m, "url")) + + if meta, ok := m["document_metadata"].(map[string]interface{}); ok { + for k, v := range meta { + cnt += drawNode(k, v) + } + } + + cnt += "\n└── Content:\n" + text := getMapString(m, "content") + if text == "" { + text = getMapString(m, "content_with_weight") + } + cnt += text + knowledges = append(knowledges, cnt) + } + + return knowledges +} + +// drawNode formats a key-value line with tree-drawing prefix. +func drawNode(k string, v interface{}) string { + var line string + switch val := v.(type) { + case string: + line = val + case fmt.Stringer: + line = val.String() + default: + line = fmt.Sprintf("%v", val) + } + if line == "" { + return "" + } + // Collapse consecutive newlines into a single space + var nb strings.Builder + nb.Grow(len(line)) + inNewlines := false + for _, r := range line { + if r == '\n' { + if !inNewlines { + nb.WriteByte(' ') + inNewlines = true + } + } else { + nb.WriteRune(r) + inNewlines = false + } + } + line = nb.String() + return fmt.Sprintf("\n├── %s: %s", k, line) +} + +// hashStrToInt is an FNV-1a hash modulo mod. +func hashStrToInt(s string, mod int) int { + if s == "" || mod <= 0 { + return 0 + } + var h uint64 = 14695981039346656037 // FNV offset basis + for i := 0; i < len(s); i++ { + h ^= uint64(s[i]) + h *= 1099511628211 // FNV prime + } + return int(h % uint64(mod)) +} + +// getMapString gets a string from a map, trying multiple keys. +func getMapString(m map[string]interface{}, keys ...string) string { + for _, k := range keys { + if v, ok := m[k]; ok { + if s, ok := v.(string); ok { + return s + } + } + } + return "" +} + +// mapStringValue extracts a string value from a map by key. +func mapStringValue(m map[string]interface{}, key string) string { + if v, ok := m[key]; ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +// chunksFromKBInfos extracts chunks list from kbinfos for counting. +func chunksFromKBInfos(kbinfos map[string]interface{}) []map[string]interface{} { + if ch, ok := kbinfos["chunks"].([]map[string]interface{}); ok { + return ch + } + return nil +} + +// Ensure time is referenced (avoids unused import in some build configurations). +var _ = time.Now diff --git a/internal/service/file.go b/internal/service/file.go index 7783343788..f46e3ecb14 100644 --- a/internal/service/file.go +++ b/internal/service/file.go @@ -18,6 +18,7 @@ package service import ( "context" + "encoding/base64" "fmt" "io" "mime/multipart" @@ -27,6 +28,7 @@ import ( "ragflow/internal/dao" "ragflow/internal/engine" "ragflow/internal/entity" + "ragflow/internal/ingestion/parser" "ragflow/internal/storage" "ragflow/internal/utility" "strings" @@ -1008,3 +1010,61 @@ func (s *FileService) DownloadAgentFile(tenantID, location string) ([]byte, erro return blob, nil } + +// GetFileContents fetches file contents (text + image) from storage +// for the given file dicts. +// - raw=false: images returned as base64 data URIs in images; non-images parsed and returned as text. +// - raw=true: images returned as raw bytes in images; non-images parsed and returned as text. +func (s *FileService) GetFileContents(fileDicts []map[string]interface{}, raw bool) (texts []string, images []string, err error) { + storageImpl := storage.GetStorageFactory().GetStorage() + if storageImpl == nil { + return nil, nil, fmt.Errorf("storage not initialized") + } + + for _, fd := range fileDicts { + id, _ := fd["id"].(string) + if id == "" { + continue + } + file, ferr := s.fileDAO.GetByID(id) + if ferr != nil || file == nil || file.Location == nil || *file.Location == "" { + continue + } + data, derr := storageImpl.Get(file.ParentID, *file.Location) + if derr != nil || len(data) == 0 { + continue + } + ft := utility.FilenameType(file.Name) + if ft == utility.FileTypeVISUAL { + if raw { + images = append(images, string(data)) + } else { + ext := utility.GetFileExtension(file.Name) + mime := utility.GetContentType(ext, string(ft)) + images = append(images, "data:"+mime+";base64,"+base64.StdEncoding.EncodeToString(data)) + } + } else { + texts = append(texts, parseFileContent(file.Name, data)) + } + } + return texts, images, nil +} + +// parseFileContent tries to parse a file's contents using the appropriate parser. +// Falls back to returning raw text if no parser is available. +func parseFileContent(filename string, data []byte) string { + fileType := utility.GetFileType(filename) + if fileType == utility.FileTypeOTHER { + return string(data) + } + // Parser config — office_oxide for MS Office formats; other parsers ignore it. + parserCfg := map[string]string{"lib_type": "office_oxide"} + fp, err := parser.GetParser(fileType, parserCfg) + if err != nil { + return string(data) + } + if err := fp.Parse(filename, data); err != nil { + return string(data) + } + return fp.String() +} diff --git a/internal/service/generator.go b/internal/service/generator.go index ede56dee39..f97a6914c1 100644 --- a/internal/service/generator.go +++ b/internal/service/generator.go @@ -17,20 +17,22 @@ package service import ( + "bytes" "context" "fmt" - "ragflow/internal/common" - "ragflow/internal/entity" "regexp" "strings" + "text/template" + "time" + + "ragflow/internal/common" + "ragflow/internal/entity" + modelModule "ragflow/internal/entity/models" "go.uber.org/zap" - - modelModule "ragflow/internal/entity/models" ) // KeywordExtraction extracts keywords from content using LLM. -// Corresponds to rag/prompts/generator.py:keyword_extraction(). // // Uses ChatModel to call the LLM with a keyword extraction prompt. // Returns comma-separated top N important keywords/phrases from the content. @@ -211,3 +213,164 @@ func CrossLanguages(ctx context.Context, tenantID string, llmID string, query st return query, nil } + +// fullQuestionTmpl mirrors the Python Jinja2 template +// rag/prompts/full_question_prompt.md. The rendered output is used as the +// system message; the user message is just "Output: ". +var fullQuestionTmpl = template.Must(template.New("full_question").Parse(`## Role +A helpful assistant. + +## Task & Steps +1. Generate a full user question that would follow the conversation. +2. If the user's question involves relative dates, convert them into absolute dates based on today ({{.Today}}). + - "yesterday" = {{.Yesterday}}, "tomorrow" = {{.Tomorrow}} + +## Requirements & Restrictions +- If the user's latest question is already complete, don't do anything — just return the original question. +- DON'T generate anything except a refined question. +{{- if .Language }} +- Text generated MUST be in {{.Language}}. +{{- else }} +- Text generated MUST be in the same language as the original user's question. +{{- end }} + +--- + +## Examples + +### Example 1 +**Conversation:** + +USER: What is the name of Donald Trump's father? +ASSISTANT: Fred Trump. +USER: And his mother? + +**Output:** What's the name of Donald Trump's mother? + +--- + +### Example 2 +**Conversation:** + +USER: What is the name of Donald Trump's father? +ASSISTANT: Fred Trump. +USER: And his mother? +ASSISTANT: Mary Trump. +USER: What's her full name? + +**Output:** What's the full name of Donald Trump's mother Mary Trump? + +--- + +### Example 3 +**Conversation:** + +USER: What's the weather today in London? +ASSISTANT: Cloudy. +USER: What's about tomorrow in Rochester? + +**Output:** What's the weather in Rochester on {{.Tomorrow}}? + +--- + +## Real Data + +**Conversation:** + +{{.Conversation}} +`)) + +var errorMarkerRE = regexp.MustCompile(`\*\*ERROR\*\*`) + +// FullQuestion rewrites the latest user question in light of prior +// conversation context (pronouns, dates, follow-ups). Falls back to the +// latest user message on LLM error. +// When language is empty, the original language is preserved (matching Python). +// +// The prompt structure mirrors Python's full_question(): +// - System: fullQuestionTmpl (instructions, examples, conversation) +// - User: "Output: " +// +// This matches rag/prompts/full_question_prompt.md rendered via Jinja2. +func FullQuestion( + ctx context.Context, + chatModel *modelModule.ChatModel, + messages []map[string]interface{}, + language string, +) (string, error) { + if chatModel == nil || chatModel.ModelDriver == nil { + return "", fmt.Errorf("FullQuestion: nil chat model") + } + if len(messages) == 0 { + return "", fmt.Errorf("FullQuestion: empty messages") + } + + var convLines []string + for _, m := range messages { + role, _ := m["role"].(string) + if role != "user" && role != "assistant" { + continue + } + content, _ := m["content"].(string) + convLines = append(convLines, fmt.Sprintf("%s: %s", strings.ToUpper(role), content)) + } + conv := strings.Join(convLines, "\n") + + today := time.Now().Format("2006-01-02") + tomorrow := time.Now().Add(24 * time.Hour).Format("2006-01-02") + yesterday := time.Now().Add(-24 * time.Hour).Format("2006-01-02") + + var buf bytes.Buffer + if err := fullQuestionTmpl.Execute(&buf, map[string]string{ + "Today": today, + "Yesterday": yesterday, + "Tomorrow": tomorrow, + "Conversation": conv, + "Language": language, + }); err != nil { + return fallbackToLatestUser(messages), fmt.Errorf("FullQuestion: render template: %w", err) + } + system := buf.String() + + modelName := "" + if chatModel.ModelName != nil { + modelName = *chatModel.ModelName + } + msgs := []modelModule.Message{ + {Role: "system", Content: system}, + {Role: "user", Content: "Output: "}, + } + resp, err := chatModel.ModelDriver.ChatWithMessages( + modelName, msgs, chatModel.APIConfig, nil, + ) + if err != nil { + return fallbackToLatestUser(messages), err + } + if resp == nil || resp.Answer == nil { + return fallbackToLatestUser(messages), fmt.Errorf("FullQuestion: empty response") + } + cleaned := strings.TrimSpace(*resp.Answer) + cleaned = thinkBlockRE.ReplaceAllString(cleaned, "") + cleaned = strings.TrimSpace(cleaned) + if errorMarkerRE.MatchString(cleaned) { + return fallbackToLatestUser(messages), nil + } + if cleaned == "" { + return fallbackToLatestUser(messages), nil + } + return cleaned, nil +} + +// fallbackToLatestUser returns the last user message, or "" if none. +func fallbackToLatestUser(messages []map[string]interface{}) string { + for i := len(messages) - 1; i >= 0; i-- { + role, _ := messages[i]["role"].(string) + if role == "user" { + if c, ok := messages[i]["content"].(string); ok { + return c + } + return "" + } + } + return "" +} diff --git a/internal/service/kb_prompt.go b/internal/service/kb_prompt.go index 737f8e0a41..d6ffaa20cb 100644 --- a/internal/service/kb_prompt.go +++ b/internal/service/kb_prompt.go @@ -19,7 +19,6 @@ package service import ( "fmt" "strings" - "unicode/utf8" "ragflow/internal/tokenizer" ) @@ -34,20 +33,20 @@ func ChunksFormat(chunks []SourcedChunk) []map[string]interface{} { out := make([]map[string]interface{}, len(chunks)) for i, ck := range chunks { out[i] = map[string]interface{}{ - "id": ck.ID, - "content": ck.Content, - "document_id": ck.DocID, - "document_name": ck.DocName, - "dataset_id": ck.DatasetID, - "image_id": ck.ImageID, - "positions": ck.Positions, - "url": ck.URL, - "similarity": ck.Similarity, - "vector_similarity": ck.VectorSimilarity, - "term_similarity": ck.TermSimilarity, - "row_id": ck.ID, // row_id == ID for consistency with Python - "doc_type": ck.DocType, - "document_metadata": ck.DocumentMetadata, + "id": ck.ID, + "content": ck.Content, + "document_id": ck.DocID, + "document_name": ck.DocName, + "dataset_id": ck.DatasetID, + "image_id": ck.ImageID, + "positions": ck.Positions, + "url": ck.URL, + "similarity": ck.Similarity, + "vector_similarity": ck.VectorSimilarity, + "term_similarity": ck.TermSimilarity, + "row_id": ck.ID, // row_id == ID for consistency with Python + "doc_type": ck.DocType, + "document_metadata": ck.DocumentMetadata, } } return out @@ -70,7 +69,7 @@ func KbPrompt(chunks []SourcedChunk, maxTokens int) string { used := 0 for _, ck := range chunks { entry := formatChunkEntry(ck) - tokens := NumTokensFromString(entry) + tokens := tokenizer.NumTokensFromString(entry) if used+tokens > limit { break } @@ -80,21 +79,6 @@ func KbPrompt(chunks []SourcedChunk, maxTokens int) string { return b.String() } -// NumTokensFromString returns the number of tokens in s using the C++ tokenizer. -// Falls back to a rune-based estimate (~2 chars per token) when the tokenizer -// is not available (e.g. CI, development without Infinity dictionaries). -func NumTokensFromString(s string) int { - if s == "" { - return 0 - } - result, err := tokenizer.Tokenize(s) - if err != nil { - // Fallback: ~2 chars per token for mixed language text. - return utf8.RuneCountInString(s) / 2 - } - return len(strings.Fields(result)) -} - // formatChunkEntry renders a single chunk as a tree-structured entry for the // LLM prompt. Format matches Python kb_prompt() in rag/prompts/generator.py: // diff --git a/internal/service/kb_prompt_test.go b/internal/service/kb_prompt_test.go index 3a64d04957..1eb8870f71 100644 --- a/internal/service/kb_prompt_test.go +++ b/internal/service/kb_prompt_test.go @@ -2,6 +2,8 @@ package service import ( "testing" + + "ragflow/internal/tokenizer" ) func TestKbPrompt_Empty(t *testing.T) { @@ -59,7 +61,7 @@ func TestKbPrompt_TokenLimit(t *testing.T) { } // Compute limit dynamically so the test works with both the C++ // tokenizer and the rune-based fallback. - entryTokens := NumTokensFromString(formatChunkEntry(chunks[0])) + entryTokens := tokenizer.NumTokensFromString(formatChunkEntry(chunks[0])) maxToks := int(float64(entryTokens+1) / 0.97) // just enough for first result := KbPrompt(chunks, maxToks) if !contains(result, "ID: 1") { @@ -102,8 +104,6 @@ func TestKbPrompt_NoDocNameOrURL(t *testing.T) { } } - - func contains(s, substr string) bool { for i := 0; i <= len(s)-len(substr); i++ { if s[i:i+len(substr)] == substr { @@ -113,33 +113,15 @@ func contains(s, substr string) bool { return false } - - -func TestNumTokensFromString_Empty(t *testing.T) { - if got := NumTokensFromString(""); got != 0 { - t.Errorf("expected 0 for empty string, got %d", got) - } -} - -func TestNumTokensFromString_Positive(t *testing.T) { - // Either the C++ tokenizer or the fallback must return > 0 for - // non-empty text. The exact count depends on the environment. - for _, s := range []string{"hello world", "你好世界"} { - if got := NumTokensFromString(s); got <= 0 { - t.Errorf("NumTokensFromString(%q) = %d, want >0", s, got) - } - } -} - func TestKbPrompt_TokenLimitAccurate(t *testing.T) { - // Verify truncation uses NumTokensFromString by computing the limit + // Verify truncation uses tokenizer.NumTokensFromString by computing the limit // dynamically from the actual token count (works in both fallback // and C++ tokenizer environments). chunks := []SourcedChunk{ {ID: "1", Content: "hello"}, {ID: "2", Content: "world"}, } - entryTokens := NumTokensFromString(formatChunkEntry(chunks[0])) + entryTokens := tokenizer.NumTokensFromString(formatChunkEntry(chunks[0])) maxToks := int(float64(entryTokens+1) / 0.97) // just enough for first entry result := KbPrompt(chunks, maxToks) if !contains(result, "ID: 1") { @@ -160,5 +142,3 @@ func TestKbPrompt_AllFit(t *testing.T) { t.Error("both chunks should fit under generous limit") } } - - diff --git a/internal/service/kg/scoring.go b/internal/service/kg/scoring.go index 6ee38c5559..61896c9942 100644 --- a/internal/service/kg/scoring.go +++ b/internal/service/kg/scoring.go @@ -24,7 +24,7 @@ import ( "sort" "strings" - "ragflow/internal/service" + "ragflow/internal/tokenizer" ) // AnalyzeNHopPaths decomposes N-hop paths into edges with distance-decayed scores. @@ -150,7 +150,7 @@ func SortAndTrimRelations(relsFromText map[Edge]*KGRelation, topN int) []ScoredR // NumTokensFromString estimates the number of tokens in a string. // Delegates to the shared implementation in the parent service package. func NumTokensFromString(s string) int { - return service.NumTokensFromString(s) + return tokenizer.NumTokensFromString(s) } // formatCSVLine formats fields as a single CSV record with trailing newline. diff --git a/internal/service/langfuse.go b/internal/service/langfuse.go new file mode 100644 index 0000000000..0193562189 --- /dev/null +++ b/internal/service/langfuse.go @@ -0,0 +1,258 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "sync" + "time" + + "ragflow/internal/dao" + "ragflow/internal/entity" + + "gorm.io/gorm" +) + +type langfuseCtxKeyType struct{} + +var langfuseCtxKey = langfuseCtxKeyType{} + +// LangfuseClientFromTenant returns a tracing client for the given tenant, +// or nil if Langfuse is not configured. Failures to look up credentials +// are non-fatal; Langfuse is observability, not a chat path requirement. +func LangfuseClientFromTenant(ctx context.Context, tenantID, userID, chatID, modelName string) *LangfuseClient { + if tenantID == "" { + return nil + } + creds, err := getTenantLangfuse(tenantID) + if err != nil || creds == nil { + return nil + } + if creds.Host == "" || creds.PublicKey == "" || creds.SecretKey == "" { + return nil + } + return NewLangfuseClient(creds.Host, creds.PublicKey, creds.SecretKey) +} + +// getTenantLangfuse returns the Langfuse credentials for a tenant, or +// (nil, nil) when no row exists. +func getTenantLangfuse(tenantID string) (*entity.TenantLangfuse, error) { + if tenantID == "" { + return nil, gorm.ErrInvalidDB + } + var row entity.TenantLangfuse + err := dao.DB.Where("tenant_id = ?", tenantID).First(&row).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, err + } + return &row, nil +} + +// LangfuseClient posts trace and observation events to a Langfuse ingestion +// endpoint. All writes are async (background worker drains a buffered +// channel); reads (none in this minimal version) are direct. +type LangfuseClient struct { + Host string + PublicKey string + SecretKey string + HTTP *http.Client + + events chan []byte + stop chan struct{} + stopped chan struct{} + once sync.Once +} + +// NewLangfuseClient constructs a LangfuseClient with a 2-second HTTP timeout +// and starts a background worker. Call Shutdown to drain pending events. +func NewLangfuseClient(host, publicKey, secretKey string) *LangfuseClient { + c := &LangfuseClient{ + Host: host, + PublicKey: publicKey, + SecretKey: secretKey, + HTTP: &http.Client{Timeout: 2 * time.Second}, + events: make(chan []byte, 1024), + stop: make(chan struct{}), + stopped: make(chan struct{}), + } + go c.worker() + return c +} + +// LangfuseTrace is a single Langfuse trace (one per request). +type LangfuseTrace struct { + ID string `json:"id"` + Name string `json:"name"` + UserID string `json:"userId,omitempty"` + SessionID string `json:"sessionId,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + Timestamp string `json:"timestamp"` +} + +// LangfuseSpan is a unit of work within a trace (e.g. "Pre-retrieval processing"). +type LangfuseSpan struct { + ID string `json:"id"` + TraceID string `json:"traceId"` + ParentObservationID string `json:"parentObservationId,omitempty"` + Name string `json:"name"` + StartTime string `json:"startTime"` + EndTime string `json:"endTime,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + Input interface{} `json:"input,omitempty"` + Output interface{} `json:"output,omitempty"` +} + +// LangfuseGeneration is a span with model, usage, and LLM-specific fields. +type LangfuseGeneration struct { + ID string `json:"id"` + TraceID string `json:"traceId"` + ParentObservationID string `json:"parentObservationId,omitempty"` + Name string `json:"name"` + Model string `json:"model,omitempty"` + StartTime string `json:"startTime"` + EndTime string `json:"endTime,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + Input interface{} `json:"input,omitempty"` + Output interface{} `json:"output,omitempty"` + Usage *LangfuseUsage `json:"usage,omitempty"` +} + +// LangfuseUsage records prompt/completion/total token counts. +type LangfuseUsage struct { + PromptTokens int `json:"promptTokens"` + CompletionTokens int `json:"completionTokens"` + TotalTokens int `json:"totalTokens"` +} + +func (c *LangfuseClient) PostTrace(ctx context.Context, t LangfuseTrace) error { + body, err := json.Marshal(t) + if err != nil { + return err + } + return c.enqueue("traces", body) +} + +func (c *LangfuseClient) PostSpan(ctx context.Context, s LangfuseSpan) error { + body, err := json.Marshal(s) + if err != nil { + return err + } + return c.enqueue("observations", body) +} + +func (c *LangfuseClient) PostGeneration(ctx context.Context, g LangfuseGeneration) error { + body, err := json.Marshal(g) + if err != nil { + return err + } + return c.enqueue("observations", body) +} + +func (c *LangfuseClient) enqueue(kind string, body []byte) error { + if c == nil { + return fmt.Errorf("nil langfuse client") + } + envelope := struct { + Kind string `json:"kind"` + Body []byte `json:"body"` + }{Kind: kind, Body: body} + env, err := json.Marshal(envelope) + if err != nil { + return err + } + select { + case c.events <- env: + return nil + default: + return nil + } +} + +func (c *LangfuseClient) worker() { + defer close(c.stopped) + for { + select { + case <-c.stop: + drainCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + for { + select { + case ev := <-c.events: + c.post(drainCtx, ev) + case <-drainCtx.Done(): + cancel() + return + default: + cancel() + return + } + } + case ev := <-c.events: + c.post(context.Background(), ev) + } + } +} + +func (c *LangfuseClient) post(ctx context.Context, envelope []byte) { + var env struct { + Kind string `json:"kind"` + Body json.RawMessage `json:"body"` + } + if err := json.Unmarshal(envelope, &env); err != nil { + return + } + url := c.Host + "/api/public/" + env.Kind + auth := basicAuth(c.PublicKey, c.SecretKey) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(env.Body)) + if err != nil { + return + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", auth) + res, err := c.HTTP.Do(req) + if err != nil { + return + } + defer res.Body.Close() + io.Copy(io.Discard, res.Body) +} + +func (c *LangfuseClient) Shutdown(ctx context.Context) error { + if c == nil { + return nil + } + c.once.Do(func() { close(c.stop) }) + select { + case <-c.stopped: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func basicAuth(public, secret string) string { + return "Basic " + base64.StdEncoding.EncodeToString([]byte(public+":"+secret)) +} diff --git a/internal/service/metadata.go b/internal/service/metadata.go index 0f8c1c1e20..7d311d3b3f 100644 --- a/internal/service/metadata.go +++ b/internal/service/metadata.go @@ -70,8 +70,8 @@ func (s *MetadataService) GetTenantIDByKBIDs(kbIDs []string) (string, error) { // SearchMetadataResponse holds the result of a metadata search type SearchMetadataResponse struct { - IndexName string - MetadataRecords []map[string]interface{} + IndexName string + MetadataRecords []map[string]interface{} } // SearchMetadata searches the metadata index with the given parameters @@ -92,8 +92,8 @@ func (s *MetadataService) SearchMetadata(kbID, tenantID string, docIDs []string, } return &SearchMetadataResponse{ - IndexName: BuildMetadataIndexName(tenantID), - MetadataRecords: searchResult.MetadataRecords, + IndexName: BuildMetadataIndexName(tenantID), + MetadataRecords: searchResult.MetadataRecords, }, nil } @@ -123,8 +123,8 @@ func (s *MetadataService) SearchMetadataByKBs(kbIDs []string, size int) (*Search } return &SearchMetadataResponse{ - IndexName: BuildMetadataIndexName(tenantID), - MetadataRecords: searchResult.MetadataRecords, + IndexName: BuildMetadataIndexName(tenantID), + MetadataRecords: searchResult.MetadataRecords, }, nil } @@ -242,7 +242,7 @@ func CollectDocIDsByKB(chunks []map[string]interface{}) KBDocIDsMap { seen := make(map[string]struct{}) result := make(KBDocIDsMap) for _, chunk := range chunks { - kbID, _ := chunk["kb_id"].(string) + kbID := extractKBID(chunk) docID := extractDocID(chunk) if kbID == "" || docID == "" { continue @@ -298,6 +298,9 @@ func AttachDocMetaToChunks(chunks []map[string]interface{}, metaByDoc DocMetaMap } for _, chunk := range chunks { docID := extractDocID(chunk) + if docID == "" { + continue + } meta, ok := metaByDoc[docID] if !ok { continue @@ -335,6 +338,17 @@ func (s *MetadataService) EnrichChunksWithDocMetadata(chunks []map[string]interf AttachDocMetaToChunks(chunks, metaByDoc, metadataFields) } +// extractKBID extracts the KB ID from a chunk, checking common field names. +func extractKBID(chunk map[string]interface{}) string { + if id, ok := chunk["kb_id"].(string); ok && id != "" { + return id + } + if id, ok := chunk["dataset_id"].(string); ok && id != "" { + return id + } + return "" +} + // extractDocID extracts the document ID from a chunk, checking both id and doc_id. func extractDocID(chunk map[string]interface{}) string { if id, ok := chunk["id"].(string); ok { diff --git a/internal/service/metadata_filter.go b/internal/service/metadata_filter.go index c5c27bd381..4ed676d665 100644 --- a/internal/service/metadata_filter.go +++ b/internal/service/metadata_filter.go @@ -24,6 +24,8 @@ import ( "ragflow/internal/common" "ragflow/internal/engine" "regexp" + + "github.com/kaptinlin/jsonrepair" "strconv" "strings" "time" @@ -225,7 +227,10 @@ func GenMetaFilter(ctx context.Context, chatModel *modelModule.ChatModel, metaDa var result MetaFilterResult if err := json.Unmarshal([]byte(responseStr), &result); err != nil { // Attempt JSON repair for common LLM output issues - repaired := repairJSON(responseStr) + repaired, rerr := jsonrepair.Repair(responseStr) + if rerr != nil { + repaired = responseStr + } if err2 := json.Unmarshal([]byte(repaired), &result); err2 != nil { common.Warn("Failed to parse meta filter response after repair", zap.String("raw", responseStr[:min(len(responseStr), 200)]), @@ -570,6 +575,29 @@ func metaFilterValues(value interface{}) []string { } } +// MetadataConditionToDocIDs applies metadata_condition against pre-loaded +// metadata and returns a comma-separated doc ID string. +// Returns "-999" when conditions are non-empty but match nothing. +func MetadataConditionToDocIDs(metaData common.MetaData, metadataCondition map[string]interface{}) string { + if metadataCondition == nil { + return "" + } + input := common.ParseAndConvert(metadataCondition) + if input == nil { + return "" + } + filtered := common.MetaFilter(metaData, input) + + rawConditions, _ := metadataCondition["conditions"].([]interface{}) + if len(rawConditions) > 0 && len(filtered) == 0 { + return "-999" + } + if len(filtered) == 0 { + return "" + } + return strings.Join(filtered, ",") +} + // ApplyMetaDataFilter applies metadata filtering rules and returns filtered doc_ids // Supports three modes: // - auto: generate filter conditions via LLM @@ -765,49 +793,3 @@ func constrainDocIDs(baseDocIDs, filteredDocIDs []string) []string { } return result } - -// repairJSON attempts to fix common JSON formatting issues in LLM output. -// This mirrors Python's json_repair.loads() behavior for the most common issues: -// - Trailing commas in arrays/objects -// - Unquoted or single-quoted keys -// - Extra content after closing brace -func repairJSON(s string) string { - s = strings.TrimSpace(s) - - // Find the outermost JSON object { ... } - start := strings.Index(s, "{") - end := strings.LastIndex(s, "}") - if start == -1 || end == -1 || end <= start { - return s - } - s = s[start : end+1] - - // Remove trailing commas before ] or } - s = removeTrailingCommas(s) - - // Fix single-quoted keys and values: 'key' -> "key" - // This is a simplification — only handles the outermost level - s = fixQuotes(s) - - return s -} - -// removeTrailingCommas removes commas that appear immediately before ] or } -func removeTrailingCommas(s string) string { - // Remove , followed by optional whitespace and then ] or } - re := regexp.MustCompile(`,(\s*[}\]])`) - return re.ReplaceAllString(s, "$1") -} - -// fixQuotes converts single quotes to double quotes for JSON keys. -// Only handles simple cases: 'word' -> "word" -func fixQuotes(s string) string { - // Replace single-quoted keys: 'key': -> "key": - re := regexp.MustCompile(`'(\w+)'(\s*):`) - s = re.ReplaceAllString(s, `"$1"$2:`) - // Replace single-quoted string values: : 'value' -> : "value" - // (only when preceded by colon and optional whitespace) - re2 := regexp.MustCompile(`:\s*'([^']*)'(\s*[,}\]])`) - s = re2.ReplaceAllString(s, `: "$1"$2`) - return s -} diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 93d7926365..0ae2b5f9e6 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -2579,3 +2579,39 @@ func (m *ModelProviderService) ListAllModels(pageIndex, pageSize int) ([]map[str func (m *ModelProviderService) ShowModel(modelName string) (*modelModule.Model, error) { return dao.GetModelProviderManager().GetModelByNameOrAlias(modelName), nil } + +// isImage2TextLLM returns true when the named LLM is registered as an +// image2text model for the tenant. +// Returns false on lookup error or empty LLM ID so callers fall back to +// chat — matches Python's branch order where only an EXPLICIT image2text +// registration switches the model type away from chat. +func (m *ModelProviderService) isImage2TextLLM(tenantID, llmID string) bool { + if m == nil || llmID == "" { + return false + } + modelTypes, err := m.GetModelTypeByName(tenantID, llmID) + if err != nil { + return false + } + for _, mt := range modelTypes { + if mt == entity.ModelTypeImage2Text { + return true + } + } + return false +} + +// GetChatModelConfig resolves the model configuration for a chat dialog. +// If llmID is empty, falls back to the tenant's default chat model. +// When the named LLM is registered as an image2text model, returns the +// IMAGE2TEXT driver/config instead of CHAT. +func (m *ModelProviderService) GetChatModelConfig(tenantID string, llmID string) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error) { + if llmID == "" { + return m.GetTenantDefaultModelByType(tenantID, entity.ModelTypeChat) + } + modelType := entity.ModelTypeChat + if m.isImage2TextLLM(tenantID, llmID) { + modelType = entity.ModelTypeImage2Text + } + return m.GetModelConfigFromProviderInstance(tenantID, modelType, llmID) +} diff --git a/internal/service/nlp/retrieval.go b/internal/service/nlp/retrieval.go index 0050b0eaa5..c746a8d4f9 100644 --- a/internal/service/nlp/retrieval.go +++ b/internal/service/nlp/retrieval.go @@ -26,6 +26,7 @@ import ( "ragflow/internal/engine/types" "ragflow/internal/entity/models" "sort" + "strconv" "strings" "ragflow/internal/tokenizer" @@ -1069,11 +1070,88 @@ func (s *RetrievalService) PruneDeletedChunks(result *RetrievalSearchResult) (*R }, nil } -// buildIndexNames creates index names for the given tenant IDs +// buildIndexNames creates index names for the given tenant IDs. +// Each tenantID may be a comma-separated list. func buildIndexNames(tenantIDs []string) []string { - indexNames := make([]string, len(tenantIDs)) - for i, tenantID := range tenantIDs { - indexNames[i] = fmt.Sprintf("ragflow_%s", tenantID) + var indexNames []string + for _, tid := range tenantIDs { + for _, part := range strings.Split(tid, ",") { + part = strings.TrimSpace(part) + if part != "" { + indexNames = append(indexNames, fmt.Sprintf("ragflow_%s", part)) + } + } } return indexNames } + +// FetchChunkVectors returns q_{dim}_vec for the given chunk IDs. +// Missing or wrong-dimension chunks get a zero vector. +func (s *RetrievalService) FetchChunkVectors(ctx context.Context, chunkIDs []string, tenantIDs []string, kbIDs []string, dim int) (map[string][]float64, error) { + if dim <= 0 { + return nil, fmt.Errorf("FetchChunkVectors: dim must be > 0, got %d", dim) + } + if len(chunkIDs) == 0 { + return map[string][]float64{}, nil + } + + vecField := fmt.Sprintf("q_%d_vec", dim) + idxNames := buildIndexNames(tenantIDs) + + req := &types.SearchRequest{ + IndexNames: idxNames, + KbIDs: kbIDs, + Limit: len(chunkIDs), + Offset: 0, + SelectFields: []string{"id", vecField}, + Filter: map[string]interface{}{"id": chunkIDs}, + MatchExprs: []interface{}{}, + } + + result, err := s.docEngine.Search(ctx, req) + if err != nil { + return nil, fmt.Errorf("FetchChunkVectors: engine search failed: %w", err) + } + + out := make(map[string][]float64, len(chunkIDs)) + for _, cid := range chunkIDs { + out[cid] = make([]float64, dim) + } + + for _, chunk := range result.Chunks { + cid, _ := chunk["id"].(string) + if cid == "" { + continue + } + var vec []float64 + switch v := chunk[vecField].(type) { + case []float64: + vec = v + case []interface{}: + vec = make([]float64, len(v)) + for i, val := range v { + if f, ok := val.(float64); ok { + vec[i] = f + } else if f32, ok := val.(float32); ok { + vec[i] = float64(f32) + } + } + case string: + // Tab-separated floats (mirrors Python's split("\t") in + // search.py:435-437 when Infinity returns vectors as a string). + parts := strings.Split(v, "\t") + vec = make([]float64, 0, len(parts)) + for _, p := range parts { + if f, err := strconv.ParseFloat(strings.TrimSpace(p), 64); err == nil { + vec = append(vec, f) + } + } + } + if len(vec) != dim { + vec = make([]float64, dim) + } + out[cid] = vec + } + + return out, nil +} diff --git a/internal/service/openai_chat.go b/internal/service/openai_chat.go new file mode 100644 index 0000000000..c076e770a0 --- /dev/null +++ b/internal/service/openai_chat.go @@ -0,0 +1,846 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "ragflow/internal/entity" + "regexp" + "strings" + "time" + + "ragflow/internal/common" + "ragflow/internal/tokenizer" + + "github.com/gin-gonic/gin" + + "go.uber.org/zap" +) + +type OpenAIRequest struct { + ChatID string + Model string + // Chat is the loaded chat entity, mutated in place by MergeGenerationConfig. + Chat *entity.Chat + // Messages are pre-normalized: system messages removed, leading assistant + // removed, content coerced to string (vision parts dropped). + Messages []map[string]interface{} + Stream bool + NeedReference bool + IncludeRefMetadata bool + MetadataFields []string + MetadataCondition map[string]interface{} + // Internet not plumbed — matches Python's openai_api.py behavior. + GenerationConfig map[string]interface{} +} + +// FormattedChunk is a normalized chunk matching Python's chunks_format output. +type FormattedChunk struct { + ID string `json:"id"` + Content string `json:"content"` + DocumentID string `json:"document_id"` + DocumentName string `json:"document_name"` + DatasetID string `json:"dataset_id"` + ImageID string `json:"image_id"` + Positions interface{} `json:"positions"` + URL interface{} `json:"url"` + Similarity interface{} `json:"similarity"` + VectorSimilarity interface{} `json:"vector_similarity"` + TermSimilarity interface{} `json:"term_similarity"` + RowID interface{} `json:"row_id"` + DocType interface{} `json:"doc_type"` + DocumentMetadata interface{} `json:"document_metadata"` +} + +// OpenAICompletionResponse is the non-streaming response payload. +// The reasoning_tokens quirk (openai_api.py:348-352) lives in the c.JSON call. +type OpenAICompletionResponse struct { + Model string + Content string + Reference []FormattedChunk + PromptTokens int + CompletionTokens int + TotalTokens int + Created int64 +} + +// OpenAIStreamEventKind discriminates stream events. +type OpenAIStreamEventKind int + +const ( + OpenAIEventContent OpenAIStreamEventKind = iota // delta.content + OpenAIEventReasoning // delta.reasoning_content + OpenAIEventFinal // trailing chunk + OpenAIEventError // in-band error +) + +// OpenAIStreamEvent is yielded by the event-translator inside OpenAIChatCompletions. +type OpenAIStreamEvent struct { + Kind OpenAIStreamEventKind + Delta string // for Content / Reasoning + FinalAnswer string // for Final + FinalReference []FormattedChunk + Error string // for Error + PromptTokens int + CompletionTokens int + TotalTokens int +} + +// OpenAIChatService implements the /api/v1/openai//chat/completions route. +// It composes ChatPipelineService for the shared RAG pipeline (AsyncChat) while +// keeping handler-level concerns (message filtering, generation config merge, +// reference metadata enrichment) on the service itself. +type OpenAIChatService struct { + chatSvc *ChatService + tenantLLMSvc *TenantLLMService + pipeline *ChatPipelineService +} + +func NewOpenAIChatService() *OpenAIChatService { + return &OpenAIChatService{ + chatSvc: NewChatService(), + tenantLLMSvc: NewTenantLLMService(), + pipeline: NewChatPipelineService(), + } +} + +// OpenAIChatRequest mirrors the OpenAI Chat Completions request body. +// `stop` and `user` are omitted intentionally — JSON unmarshal silently drops them. +type OpenAIChatRequest struct { + Model string `json:"model"` + Messages []map[string]interface{} `json:"messages"` + Stream *bool `json:"stream,omitempty"` + ExtraBody interface{} `json:"extra_body,omitempty"` + + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` +} + +func (s *OpenAIChatService) OpenAIChatCompletions(c *gin.Context, userID, chatID string, bodyBytes []byte) { + var req OpenAIChatRequest + if err := json.Unmarshal(bodyBytes, &req); err != nil { + s.writeArgError(c, err.Error()) + return + } + common.Info("OpenAIChatCompletions started", zap.String("chat_id", chatID)) + + normalizedMessages, err := normalizeOpenAIMessages(req.Messages) + if err != nil { + s.writeDataError(c, err.Error()) + return + } + if len(normalizedMessages) == 0 { + s.writeDataError(c, "You have to provide messages.") + return + } + + lastRole, _ := normalizedMessages[len(normalizedMessages)-1]["role"].(string) + if lastRole != "user" { + s.writeDataError(c, "The last content of this conversation is not from user.") + return + } + + if req.ExtraBody != nil { + if _, ok := req.ExtraBody.(map[string]interface{}); !ok { + s.writeDataError(c, "extra_body must be an object.") + return + } + } + + var needReference = false + var includeRefMetadata = false + var metadataFields []string + var metadataCondition map[string]interface{} + if eb, ok := req.ExtraBody.(map[string]interface{}); ok { + if v, hasRef := eb["reference"].(bool); hasRef { + needReference = v + } + rawRM, hasRM := eb["reference_metadata"] + if hasRM && rawRM != nil { + rm, ok := rawRM.(map[string]interface{}) + if !ok { + s.writeDataError(c, "reference_metadata must be an object.") + return + } + if inc, hasInc := rm["include"].(bool); hasInc { + includeRefMetadata = inc + } + if rawFields, hasFields := rm["fields"]; hasFields && rawFields != nil { + rawArr, rawOK := rawFields.([]interface{}) + if !rawOK { + s.writeDataError(c, "reference_metadata.fields must be an array.") + return + } + if len(rawArr) == 0 { + metadataFields = []string{} + } else { + for _, f := range rawArr { + str, ok := f.(string) + if !ok { + s.writeDataError(c, "reference_metadata.fields must be an array.") + return + } + metadataFields = append(metadataFields, str) + } + } + } + } + if mc, hasMC := eb["metadata_condition"]; hasMC && mc != nil { + mcMap, isObj := mc.(map[string]interface{}) + if !isObj { + s.writeDataError(c, "metadata_condition must be an object.") + return + } + if len(mcMap) > 0 { + metadataCondition = mcMap + } + } + } + + dialogResp, err := s.chatSvc.GetChat(userID, chatID) + if err != nil { + s.writeDataError(c, err.Error()) + return + } + dialog := dialogResp.Chat + resolvedModel := req.Model + if req.Model == "model" { + resolvedModel = dialog.LLMID + if resolvedModel == "" { + resolvedModel = "model" + } + } + if req.Model != "model" { + if _, _, _, _, mErr := s.pipeline.ModelProviderSvc.GetChatModelConfig(dialog.TenantID, resolvedModel); mErr != nil { + s.writeArgError(c, fmt.Sprintf("`llm_id` %s doesn't exist", req.Model)) + return + } + apiKey, apiErr := s.tenantLLMSvc.GetAPIKeyFromInstance(dialog.TenantID, req.Model) + if apiErr != nil || apiKey == "" { + s.writeDataError(c, fmt.Sprintf("Cannot use specified model %s.", req.Model)) + return + } + dialog.LLMID = resolvedModel + } + + genCfg := extractGenerationConfig(&req) + + s.MergeGenerationConfig(dialog, genCfg) + + stream := req.Stream != nil && *req.Stream + openaiReq := &OpenAIRequest{ + ChatID: chatID, + Model: resolvedModel, + Chat: dialog, + Messages: normalizedMessages, + Stream: stream, + NeedReference: needReference, + IncludeRefMetadata: includeRefMetadata, + MetadataFields: metadataFields, + MetadataCondition: metadataCondition, + GenerationConfig: genCfg, + } + + completionID := fmt.Sprintf("chatcmpl-%s", openaiReq.ChatID) + + ctx := c.Request.Context() + lfClient := LangfuseClientFromTenant(ctx, dialog.TenantID, userID, openaiReq.ChatID, openaiReq.Model) + if lfClient != nil { + ctx = context.WithValue(ctx, langfuseCtxKey, lfClient) + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = lfClient.Shutdown(shutdownCtx) + }() + } + + filteredMessages := s.filterMessages(openaiReq.Messages) + + var docIDsStr string + if openaiReq.MetadataCondition != nil { + common.Debug("metadata_condition filter started", + zap.Any("condition", openaiReq.MetadataCondition)) + kbIDs := make([]string, 0, len(dialog.KBIDs)) + for _, raw := range dialog.KBIDs { + if id, ok := raw.(string); ok && id != "" { + kbIDs = append(kbIDs, id) + } + } + metas, mdErr := s.pipeline.MetadataSvc.GetFlattedMetaByKBs(kbIDs) + if mdErr != nil { + s.writeDataError(c, fmt.Errorf("metadata_condition: load metadata: %w", mdErr).Error()) + return + } + docIDsStr = MetadataConditionToDocIDs(metas, openaiReq.MetadataCondition) + common.Debug("metadata_condition filter ended", zap.String("doc_ids", docIDsStr)) + } + + common.Debug("OpenAI chat config resolved", + zap.String("tenant_id", dialog.TenantID), + zap.String("dialog_id", dialog.ID), + zap.String("llm_id", dialog.LLMID), + zap.Any("llm_setting", dialog.LLMSetting), + zap.Any("request_generation_config", openaiReq.GenerationConfig), + zap.String("doc_ids", docIDsStr)) + + promptTokens := 0 + if lastMsg := filteredMessages[len(filteredMessages)-1]; lastMsg != nil { + if content, ok := lastMsg["content"].(string); ok { + promptTokens = tokenizer.NumTokensFromString(content) + } + } + + chatKwargs := map[string]interface{}{ + "toolcall_session": nil, // no tool calls on OpenAI-compat path + "tools": nil, + "quote": needReference, + } + if docIDsStr != "" { + chatKwargs["doc_ids"] = docIDsStr + } + + asyncResults, asyncErr := s.pipeline.AsyncChat(ctx, dialog, filteredMessages, openaiReq.Stream, chatKwargs) + if asyncErr != nil { + s.writeDataError(c, asyncErr.Error()) + return + } + + if stream { + events := make(chan OpenAIStreamEvent, 16) + go func() { + defer close(events) + defer func() { + if r := recover(); r != nil { + common.Warn("OpenAI streaming goroutine panic", zap.Any("recover", r)) + events <- OpenAIStreamEvent{Kind: OpenAIEventError, Error: fmt.Sprintf("internal error: %v", r)} + } + }() + + var ( + fullContent string + completionTok int + deltaCount int + finalReference []FormattedChunk + lastResult AsyncChatResult + ) + + for result := range asyncResults { + lastResult = result + + if result.StartToThink || result.EndToThink { + // Think markers only toggle routing state; no SSE event + // emitted. Matches Python's _stream_chat_completion_sse + // which ignores start_to_think/end_to_think flags and + // never emits "" or "" as content. + continue + } + + if result.Final { + finalContent := strings.TrimSpace(result.Answer) + fullContent = finalContent + if ref, ok := result.Reference["chunks"]; ok { + if chunks, ok := ref.([]map[string]interface{}); ok { + finalReference = formatChunks(chunks) + } + } + s.enrichChunksWithDocumentMetadata(finalReference, dialog.TenantID, openaiReq.IncludeRefMetadata, openaiReq.MetadataFields) + completionTok = tokenizer.NumTokensFromString(result.Answer) + events <- OpenAIStreamEvent{ + Kind: OpenAIEventFinal, + FinalAnswer: finalContent, + FinalReference: finalReference, + PromptTokens: promptTokens, + CompletionTokens: completionTok, + TotalTokens: promptTokens + completionTok, + } + return + } + + if result.Reasoning != "" { + completionTok += tokenizer.NumTokensFromString(result.Reasoning) + events <- OpenAIStreamEvent{Kind: OpenAIEventReasoning, Delta: result.Reasoning} + } + + if result.Answer != "" { + delta := result.Answer + fullContent += delta + completionTok += tokenizer.NumTokensFromString(delta) + events <- OpenAIStreamEvent{Kind: OpenAIEventContent, Delta: delta} + if deltaCount < 3 { + common.Debug("OpenAI first content delta", + zap.Int("delta_index", deltaCount), + zap.String("delta", result.Answer), + zap.Int("delta_len", len(result.Answer))) + deltaCount++ + } + } + } + + if finalReference == nil && openaiReq.NeedReference { + if ref, ok := lastResult.Reference["chunks"]; ok { + if chunks, ok := ref.([]map[string]interface{}); ok { + finalReference = formatChunks(chunks) + } + } + } + s.enrichChunksWithDocumentMetadata(finalReference, dialog.TenantID, openaiReq.IncludeRefMetadata, openaiReq.MetadataFields) + events <- OpenAIStreamEvent{ + Kind: OpenAIEventFinal, + FinalAnswer: strings.TrimSpace(fullContent), + FinalReference: finalReference, + PromptTokens: promptTokens, + CompletionTokens: completionTok, + TotalTokens: promptTokens + completionTok, + } + }() + if err := streamChatCompletionSSE(c, events, completionID, resolvedModel, openaiReq.NeedReference); err != nil { + s.writeDataError(c, err.Error()) + } + } else { + var finalResult AsyncChatResult + found := false + for result := range asyncResults { + if result.Final { + finalResult = result + found = true + break + } + } + if !found { + s.writeDataError(c, "AsyncChat returned no final result") + return + } + + content := strings.TrimSpace(finalResult.Answer) + completionTokens := tokenizer.NumTokensFromString(content) + resp := &OpenAICompletionResponse{ + Created: time.Now().Unix(), + Model: openaiReq.Model, + Content: content, + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } + if openaiReq.NeedReference { + if ref, ok := finalResult.Reference["chunks"]; ok { + if chunks, ok := ref.([]map[string]interface{}); ok { + resp.Reference = formatChunks(chunks) + } + } + s.enrichChunksWithDocumentMetadata(resp.Reference, dialog.TenantID, openaiReq.IncludeRefMetadata, openaiReq.MetadataFields) + } + + contextUsed := 0 + for _, m := range openaiReq.Messages { + if c, ok := m["content"].(string); ok { + contextUsed += tokenizer.NumTokensFromString(c) + } + } + + choices := []gin.H{{ + "index": 0, + "finish_reason": "stop", + "logprobs": nil, + "message": gin.H{ + "role": "assistant", + "content": resp.Content, + }, + }} + if openaiReq.NeedReference { + choices[0]["message"].(gin.H)["reference"] = resp.Reference + } + + c.JSON(http.StatusOK, gin.H{ + "id": completionID, + "object": "chat.completion", + "created": resp.Created, + "model": resp.Model, + "usage": gin.H{ + "prompt_tokens": resp.PromptTokens, + "completion_tokens": resp.CompletionTokens, + "total_tokens": resp.PromptTokens + resp.CompletionTokens, + "completion_tokens_details": gin.H{ + "reasoning_tokens": contextUsed, + "accepted_prediction_tokens": resp.CompletionTokens, + "rejected_prediction_tokens": 0, + }, + }, + "choices": choices, + }) + } + common.Info("OpenAIChatCompletions completed", zap.String("chat_id", chatID)) +} + +// MergeGenerationConfig merges request config into dialog.LLMSetting (mutating). +func (s *OpenAIChatService) MergeGenerationConfig(dialog *entity.Chat, config map[string]interface{}) { + if config == nil { + return + } + if dialog.LLMSetting == nil { + dialog.LLMSetting = map[string]interface{}{} + } + for k, v := range config { + dialog.LLMSetting[k] = v + } +} + +// filterMessages drops system messages and leading assistant messages. +func (s *OpenAIChatService) filterMessages(messages []map[string]interface{}) []map[string]interface{} { + var out []map[string]interface{} + for _, m := range messages { + role, _ := m["role"].(string) + if role == "system" { + continue + } + if role == "assistant" && len(out) == 0 { + continue + } + out = append(out, m) + } + return out +} + +// cleanCitationMarkers strips "##N$$" markers from the answer. +func cleanCitationMarkers(s string) string { + var citationMarkerRegex = regexp.MustCompile(`##\d+\$\$`) + return citationMarkerRegex.ReplaceAllString(s, "") +} + +// isContentDelta filters out "[DONE]" leaked by some drivers. +func isContentDelta(answer *string) bool { + if answer == nil { + return false + } + if *answer == "" { + return false + } + if *answer == "[DONE]" { + return false + } + return true +} + +// extractGenerationConfig mirrors Python's extract_generation_config. +func extractGenerationConfig(req *OpenAIChatRequest) map[string]interface{} { + cfg := make(map[string]interface{}) + if req.Temperature != nil { + cfg["temperature"] = *req.Temperature + } + if req.TopP != nil { + cfg["top_p"] = *req.TopP + } + if req.MaxTokens != nil { + cfg["max_tokens"] = float64(*req.MaxTokens) + } + if req.FrequencyPenalty != nil { + cfg["frequency_penalty"] = *req.FrequencyPenalty + } + if req.PresencePenalty != nil { + cfg["presence_penalty"] = *req.PresencePenalty + } + return cfg +} + +// normalizeMessageContent coerces content to string (drops non-text parts). +func normalizeMessageContent(content interface{}) (string, error) { + if content == nil { + return "", nil + } + if s, ok := content.(string); ok { + return s, nil + } + if arr, ok := content.([]interface{}); ok { + parts := make([]string, 0, len(arr)) + for _, p := range arr { + pm, ok := p.(map[string]interface{}) + if !ok { + continue + } + if pm["type"] != "text" { + continue + } + t, _ := pm["text"].(string) + parts = append(parts, t) + } + return joinNonEmpty(parts, "\n"), nil + } + return "", fmt.Errorf("messages[].content must be a string or an array of content parts.") +} + +// normalizeOpenAIMessages normalizes message content for all messages. +func normalizeOpenAIMessages(messages []map[string]interface{}) ([]map[string]interface{}, error) { + out := make([]map[string]interface{}, 0, len(messages)) + for _, m := range messages { + normalized := make(map[string]interface{}, len(m)) + for k, v := range m { + normalized[k] = v + } + c, err := normalizeMessageContent(m["content"]) + if err != nil { + return nil, err + } + normalized["content"] = c + out = append(out, normalized) + } + return out, nil +} + +// joinNonEmpty joins strings with sep, skipping empties. +func joinNonEmpty(parts []string, sep string) string { + nonEmpty := make([]string, 0, len(parts)) + for _, p := range parts { + if p != "" { + nonEmpty = append(nonEmpty, p) + } + } + out := "" + for i, p := range nonEmpty { + if i > 0 { + out += sep + } + out += p + } + return out +} + +// getValue reads chunk[m1] falling back to chunk[m2]. +func getValue(chunk map[string]interface{}, k1, k2 string) interface{} { + if v, ok := chunk[k1]; ok { + return v + } + return chunk[k2] +} + +func strVal(v interface{}) string { + if s, ok := v.(string); ok { + return s + } + return "" +} + +// formatChunks normalizes chunk fields to a canonical schema, matching Python's chunks_format. +func formatChunks(chunks []map[string]interface{}) []FormattedChunk { + out := make([]FormattedChunk, 0, len(chunks)) + for _, chunk := range chunks { + out = append(out, FormattedChunk{ + ID: strVal(getValue(chunk, "chunk_id", "id")), + Content: strVal(getValue(chunk, "content", "content_with_weight")), + DocumentID: strVal(getValue(chunk, "doc_id", "document_id")), + DocumentName: strVal(getValue(chunk, "docnm_kwd", "document_name")), + DatasetID: strVal(getValue(chunk, "kb_id", "dataset_id")), + ImageID: strVal(getValue(chunk, "image_id", "img_id")), + Positions: getValue(chunk, "positions", "position_int"), + URL: chunk["url"], + Similarity: chunk["similarity"], + VectorSimilarity: chunk["vector_similarity"], + TermSimilarity: chunk["term_similarity"], + RowID: chunk["row_id"], + DocType: getValue(chunk, "doc_type_kwd", "doc_type"), + DocumentMetadata: chunk["document_metadata"], + }) + } + return out +} + +// enrichChunksWithDocumentMetadata enriches chunks with document metadata. +// Mirrors Python's enrich_chunks_with_document_metadata() in +// api/utils/reference_metadata_utils.py. +// When fields is a non-nil empty slice (explicitly provided as []), enrichment +// is skipped — matching Python's behavior for {"fields": []}. +func (s *OpenAIChatService) enrichChunksWithDocumentMetadata(chunks []FormattedChunk, tenantID string, include bool, fields []string) { + if !include || len(chunks) == 0 || s == nil || s.pipeline.MetadataSvc == nil { + return + } + if fields != nil && len(fields) == 0 { + return + } + maps := make([]map[string]interface{}, len(chunks)) + for i, ch := range chunks { + maps[i] = map[string]interface{}{ + "kb_id": ch.DatasetID, + "doc_id": ch.DocumentID, + "document_metadata": ch.DocumentMetadata, + } + } + s.pipeline.MetadataSvc.EnrichChunksWithDocMetadata(maps, tenantID, fields) + for i, m := range maps { + if md, ok := m["document_metadata"]; ok { + chunks[i].DocumentMetadata = md + } + } +} + +// streamChatCompletionSSE drains events and writes SSE chunks. +func streamChatCompletionSSE( + c *gin.Context, + events <-chan OpenAIStreamEvent, + completionID string, + requestedModel string, + needReference bool, +) error { + c.Header("Cache-control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + c.Header("Content-Type", "text/event-stream; charset=utf-8") + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return fmt.Errorf("streaming unsupported") + } + + writeSSE := func(payload gin.H) { + body, _ := json.Marshal(payload) + _, _ = c.Writer.Write([]byte("data:")) + _, _ = c.Writer.Write(body) + _, _ = c.Writer.Write([]byte("\n\n")) + flusher.Flush() + } + + for ev := range events { + switch ev.Kind { + case OpenAIEventContent: + chunk := gin.H{ + "id": completionID, + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": requestedModel, + "system_fingerprint": "", + "usage": nil, + "choices": []gin.H{{ + "index": 0, + "delta": gin.H{ + "role": "assistant", + "content": ev.Delta, + "reasoning_content": nil, + "function_call": nil, + "tool_calls": nil, + }, + "finish_reason": nil, + "logprobs": nil, + }}, + } + writeSSE(chunk) + + case OpenAIEventReasoning: + chunk := gin.H{ + "id": completionID, + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": requestedModel, + "system_fingerprint": "", + "usage": nil, + "choices": []gin.H{{ + "index": 0, + "delta": gin.H{ + "role": "assistant", + "content": nil, + "reasoning_content": ev.Delta, + "function_call": nil, + "tool_calls": nil, + }, + "finish_reason": nil, + "logprobs": nil, + }}, + } + writeSSE(chunk) + + case OpenAIEventError: + chunk := gin.H{ + "id": completionID, + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": requestedModel, + "system_fingerprint": "", + "usage": nil, + "choices": []gin.H{{ + "index": 0, + "delta": gin.H{ + "role": "assistant", + "content": "**ERROR**: " + ev.Error, + "reasoning_content": nil, + "function_call": nil, + "tool_calls": nil, + }, + "finish_reason": nil, + "logprobs": nil, + }}, + } + writeSSE(chunk) + + case OpenAIEventFinal: + delta := gin.H{ + "role": "assistant", + "content": nil, + "reasoning_content": nil, + "function_call": nil, + "tool_calls": nil, + } + if needReference { + delta["reference"] = ev.FinalReference + delta["final_content"] = ev.FinalAnswer + } + chunk := gin.H{ + "id": completionID, + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": requestedModel, + "system_fingerprint": "", + "usage": gin.H{ + "prompt_tokens": ev.PromptTokens, + "completion_tokens": ev.CompletionTokens, + "total_tokens": ev.TotalTokens, + }, + "choices": []gin.H{{ + "index": 0, + "delta": delta, + "finish_reason": "stop", + "logprobs": nil, + }}, + } + writeSSE(chunk) + } + } + + // Always terminate with data: [DONE]\n\n. + _, _ = c.Writer.Write([]byte("data: [DONE]\n\n")) + flusher.Flush() + return nil +} + +// writeArgError writes a 101 JSON error envelope (malformed request). +func (s *OpenAIChatService) writeArgError(c *gin.Context, msg string) { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeArgumentError, + "data": nil, + "message": msg, + }) +} + +// writeDataError writes a 102 JSON error envelope (service failure). +func (s *OpenAIChatService) writeDataError(c *gin.Context, msg string) { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeDataError, + "data": nil, + "message": msg, + }) +} diff --git a/internal/service/openai_chat_test.go b/internal/service/openai_chat_test.go new file mode 100644 index 0000000000..96a115bef4 --- /dev/null +++ b/internal/service/openai_chat_test.go @@ -0,0 +1,842 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + + "github.com/gin-gonic/gin" + + "ragflow/internal/entity" + "ragflow/internal/tokenizer" +) + +// TestOpenAI_FilterMessagesDropsSystemAndLeadingAssistant pins down the +// Python openai_api.py:301-307 behavior: drop all system messages, then +// drop leading assistant messages until the first user/system message. +func TestOpenAI_FilterMessagesDropsSystemAndLeadingAssistant(t *testing.T) { + svc := &OpenAIChatService{} + in := []map[string]interface{}{ + {"role": "system", "content": "you are a helper"}, + {"role": "assistant", "content": "leading assistant, dropped"}, + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "world"}, + {"role": "system", "content": "another system, dropped"}, + {"role": "user", "content": "second question"}, + } + got := svc.filterMessages(in) + want := []map[string]interface{}{ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "world"}, + {"role": "user", "content": "second question"}, + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("filterMessages: got %#v want %#v", got, want) + } +} + +// TestOpenAI_MergeGenerationConfig_RequestWins pins the merge order: +// dialog.LLMSetting is the base; request fields override. Mirrors +// api/apps/restful_apis/_generation_params.py:merge_generation_config, +// which the Python handler calls at openai_api.py:283. +func TestOpenAI_MergeGenerationConfig_RequestWins(t *testing.T) { + svc := &OpenAIChatService{} + dialog := &entity.Chat{ + LLMSetting: entity.JSONMap{ + "temperature": 0.5, + "top_p": 0.9, + }, + } + req := map[string]interface{}{"temperature": 0.1} + svc.MergeGenerationConfig(dialog, req) + + if got := dialog.LLMSetting["temperature"]; got != 0.1 { + t.Fatalf("temperature: request should win, got %v", got) + } + if got := dialog.LLMSetting["top_p"]; got != 0.9 { + t.Fatalf("top_p: dialog value should be preserved, got %v", got) + } +} + +// TestOpenAI_MergeGenerationConfig_NilConfigIsNoOp verifies the Python +// `if not generation_config: return` early-exit: a nil config does not +// touch the dialog at all. +func TestOpenAI_MergeGenerationConfig_NilConfigIsNoOp(t *testing.T) { + svc := &OpenAIChatService{} + dialog := &entity.Chat{ + LLMSetting: entity.JSONMap{"temperature": 0.5}, + } + svc.MergeGenerationConfig(dialog, nil) + if got := dialog.LLMSetting["temperature"]; got != 0.5 { + t.Fatalf("nil config should be a no-op, got temperature=%v", got) + } +} + +// TestOpenAI_MergeGenerationConfig_NilLLMSettingIsInitialized pins +// the Python `getattr(dialog, "llm_setting", None) or {}` pattern: +// a dialog with no LLMSetting (e.g. freshly created) gets one +// initialized to an empty map before the merge writes into it. +func TestOpenAI_MergeGenerationConfig_NilLLMSettingIsInitialized(t *testing.T) { + svc := &OpenAIChatService{} + dialog := &entity.Chat{} // LLMSetting is nil + req := map[string]interface{}{"temperature": 0.7} + svc.MergeGenerationConfig(dialog, req) + if dialog.LLMSetting == nil { + t.Fatal("expected LLMSetting to be initialized after merge") + } + if got := dialog.LLMSetting["temperature"]; got != 0.7 { + t.Fatalf("expected temperature 0.7, got %v", got) + } +} + +// TestOpenAI_MergeGenerationConfig_AddsNewKeys pins that the merge +// ADDS keys that the dialog didn't have before, matching Python's +// `dict.update`. +func TestOpenAI_MergeGenerationConfig_AddsNewKeys(t *testing.T) { + svc := &OpenAIChatService{} + dialog := &entity.Chat{ + LLMSetting: entity.JSONMap{"temperature": 0.5}, + } + req := map[string]interface{}{"top_p": 0.9, "max_tokens": 256} + svc.MergeGenerationConfig(dialog, req) + if got := dialog.LLMSetting["top_p"]; got != 0.9 { + t.Fatalf("top_p: expected 0.9, got %v", got) + } + if got := dialog.LLMSetting["max_tokens"]; got != 256 { + t.Fatalf("max_tokens: expected 256, got %v", got) + } + if got := dialog.LLMSetting["temperature"]; got != 0.5 { + t.Fatalf("temperature: existing dialog value should be preserved, got %v", got) + } +} + +// TestOpenAI_MergeThenBuild_AllGenerationParamsReachChatConfig is the +// end-to-end contract for the OpenAI path: the handler builds a +// genCfg via extractGenerationConfig, the handler calls +// MergeGenerationConfig(dialog, genCfg), and the RAG pipeline later +// calls BuildChatConfig(dialog, nil) which reads the merged values. +// Verifies that all 5 fields the Python server honors +// (temperature, top_p, max_tokens, frequency_penalty, presence_penalty) +// survive the merge. For 3 of them (temperature, top_p, max_tokens) the +// ChatConfig fields exist and the test asserts the values. For the +// other 2 (frequency_penalty, presence_penalty) the ChatConfig struct +// doesn't have fields yet, so we just assert the dialog's LLMSetting +// preserves them — the structural gap is documented in +// openai_chat_completions.go::extractGenerationConfig. +// +// The handler-side float64 coercion of max_tokens is verified by +// TestExtractGenerationConfig_OnlyKnownFields in the handler package. +// This test uses a float64 here (matching what the handler produces +// after the fix) so the BuildChatConfig type assertion succeeds. +func TestOpenAI_MergeThenBuild_AllGenerationParamsReachChatConfig(t *testing.T) { + svc := &OpenAIChatService{} + dialog := &entity.Chat{} // no LLMSetting, no defaults + req := map[string]interface{}{ + "temperature": 0.7, + "top_p": 0.9, + "max_tokens": float64(256), + "frequency_penalty": 0.1, + "presence_penalty": 0.2, + } + svc.MergeGenerationConfig(dialog, req) + + // The merge itself is type-agnostic — verify the dialog kept the + // raw values so any downstream code (Python-style dict consumer + // or future Go consumer) can read them. + if got := dialog.LLMSetting["temperature"]; got != 0.7 { + t.Fatalf("temperature: expected 0.7, got %v", got) + } + if got := dialog.LLMSetting["top_p"]; got != 0.9 { + t.Fatalf("top_p: expected 0.9, got %v", got) + } + if got, ok := dialog.LLMSetting["max_tokens"].(float64); !ok || got != 256 { + t.Fatalf("max_tokens: expected float64 256, got %v (%T)", dialog.LLMSetting["max_tokens"], dialog.LLMSetting["max_tokens"]) + } + if got := dialog.LLMSetting["frequency_penalty"]; got != 0.1 { + t.Fatalf("frequency_penalty: expected 0.1, got %v", got) + } + if got := dialog.LLMSetting["presence_penalty"]; got != 0.2 { + t.Fatalf("presence_penalty: expected 0.2, got %v", got) + } + + // Now run the RAG-pipeline call: BuildChatConfig(dialog, nil). + // For the 3 fields ChatConfig supports, the values must surface + // on the returned struct. + cfg := BuildChatConfig(dialog, nil) + if cfg.Temperature == nil || *cfg.Temperature != 0.7 { + t.Fatalf("ChatConfig.Temperature: expected 0.7, got %v", cfg.Temperature) + } + if cfg.TopP == nil || *cfg.TopP != 0.9 { + t.Fatalf("ChatConfig.TopP: expected 0.9, got %v", cfg.TopP) + } + if cfg.MaxTokens == nil || *cfg.MaxTokens != 256 { + t.Fatalf("ChatConfig.MaxTokens: expected 256, got %v", cfg.MaxTokens) + } +} + +// bindToolsFires pins the condition logic at chat_pipeline.go:241 — the +// BindTools block fires only when BOTH toolcall_session AND tools are +// present AND non-nil. Mirrors Python's `if toolcall_session and +// tools:` (dialog_service.py:584-585), which short-circuits to false +// on None. Without the explicit nil-check in the Go code, a present- +// with-nil value would flip hasSession/hasTools to true and call +// BindTools(nil, nil) — a no-op in the generic wrapper but a behavior +// change vs. Python's truthy short-circuit. +func bindToolsFires(kwargs map[string]interface{}) bool { + tc, hasTC := kwargs["toolcall_session"] + t, hasT := kwargs["tools"] + return hasTC && tc != nil && hasT && t != nil +} + +func TestOpenAI_BindToolsCondition_NilValuesDoNotFire(t *testing.T) { + cases := []struct { + name string + kw map[string]interface{} + fires bool + }{ + // The OPENAI_CHAT call site at openai_chat.go (this is the + // shape the new asyncKwargs produces): both keys present + // with nil values. Block must NOT fire. + { + "both present with nil (current OPENAI_CHAT call site)", + map[string]interface{}{"toolcall_session": nil, "tools": nil}, + false, + }, + // Both keys absent (legacy / pre-refactor OPENAI_CHAT call + // site): also must not fire. + { + "both absent (legacy / pre-refactor)", + map[string]interface{}{}, + false, + }, + // Both keys present with non-nil values (future OPENAI_CHAT + // with tool support): block MUST fire. + { + "both present with non-nil (future tool support)", + map[string]interface{}{"toolcall_session": "session-1", "tools": []interface{}{}}, + true, + }, + // Mixed: one nil, one present. The Python `if X and Y:` + // short-circuits on the first falsy, so block must NOT fire. + { + "toolcall_session nil, tools present", + map[string]interface{}{"toolcall_session": nil, "tools": []interface{}{}}, + false, + }, + { + "toolcall_session present, tools nil", + map[string]interface{}{"toolcall_session": "session-1", "tools": nil}, + false, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := bindToolsFires(c.kw); got != c.fires { + t.Fatalf("bindToolsFires(%v) = %v, want %v", c.kw, got, c.fires) + } + }) + } +} + +// TestOpenAI_CleanCitationMarkers pins down the citation-marker stripping +// that matches Python's re.sub(r"##\d+\$\$", "", content). +func TestOpenAI_CleanCitationMarkers(t *testing.T) { + cases := []struct { + name string + in string + want string + }{ + {"empty", "", ""}, + {"plain", "hello world", "hello world"}, + {"single", "hello ##1$$ world", "hello world"}, + {"multi", "##12$$foo##34$$bar", "foobar"}, + {"non-numeric", "##abc$$ stays", "##abc$$ stays"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := cleanCitationMarkers(c.in); got != c.want { + t.Fatalf("cleanCitationMarkers(%q) = %q, want %q", c.in, got, c.want) + } + }) + } +} + +// TestOpenAI_SystemPrompt reads dialog.PromptConfig["system"]. +func TestOpenAI_SystemPrompt(t *testing.T) { + systemPrompt := func(dialog *entity.Chat) string { + if dialog.PromptConfig == nil { + return "" + } + s, _ := dialog.PromptConfig["system"].(string) + return s + } + if got := systemPrompt(&entity.Chat{}); got != "" { + t.Fatalf("expected empty system prompt for nil PromptConfig, got %q", got) + } + if got := systemPrompt(&entity.Chat{PromptConfig: entity.JSONMap{"system": "be helpful"}}); got != "be helpful" { + t.Fatalf("expected system prompt 'be helpful', got %q", got) + } +} + +// TestOpenAI_MetadataConditionToDocIDs_NoCondition verifies the no-op path +// when metadata_condition is absent. Mirrors Python's `if +// metadata_condition:` guard at openai_api.py:290 — nil → no filter. +func TestOpenAI_MetadataConditionToDocIDs_NoCondition(t *testing.T) { + got := MetadataConditionToDocIDs(nil, nil) + if got != "" { + t.Fatalf("expected empty string for nil condition, got %q", got) + } +} + +// TestOpenAI_MetadataConditionToDocIDs_NoKBs verifies that empty metadata +// with non-empty conditions yields the "-999" sentinel — matching Python: +// get_flatted_meta_by_kbs([]) returns {}, meta_filter() returns [], and the +// `if conditions and not filtered_doc_ids` branch substitutes ["-999"]. +func TestOpenAI_MetadataConditionToDocIDs_NoKBs(t *testing.T) { + cond := map[string]interface{}{ + "logic": "and", + "conditions": []interface{}{map[string]interface{}{"name": "author", "comparison_operator": "is", "value": "x"}}, + } + got := MetadataConditionToDocIDs(nil, cond) + if got != "-999" { + t.Fatalf("expected sentinel \"-999\" for empty metadata with conditions, got %q", got) + } +} + +// TestOpenAI_ContextTokenUsed sums NumTokensFromString across messages. +func TestOpenAI_ContextTokenUsed(t *testing.T) { + contextTokenUsed := func(messages []map[string]interface{}) int { + total := 0 + for _, m := range messages { + if c, ok := m["content"].(string); ok { + total += tokenizer.NumTokensFromString(c) + } + } + return total + } + msgs := []map[string]interface{}{ + {"role": "user", "content": "0123456789"}, + {"role": "user", "content": "01234567890123"}, + } + got := contextTokenUsed(msgs) + // NumTokensFromString uses cl100k_base BPE encoding, not len(s)/4. + // "0123456789" (10 chars) + "01234567890123" (14 chars) = 9 BPE tokens. + if got != 9 { + t.Fatalf("expected 9 tokens (cl100k_base BPE), got %d", got) + } +} + +// TestOpenAI_DedupePrefix checks the SSE delta helper that strips the +// previous-cumulative prefix from a new cumulative string. +func TestOpenAI_DedupePrefix(t *testing.T) { + dedupePrefix := func(old, new string) string { + if strings.HasPrefix(new, old) { + return new[len(old):] + } + return new + } + cases := []struct { + old, new, want string + }{ + {"", "hello", "hello"}, + {"hello", "hello", ""}, + {"hello", "hello world", " world"}, + {"hello", "world", "world"}, + {"", "", ""}, + } + for _, c := range cases { + if got := dedupePrefix(c.old, c.new); got != c.want { + t.Fatalf("dedupePrefix(%q, %q) = %q, want %q", c.old, c.new, got, c.want) + } + } +} + +// TestOpenAI_IsContentDelta_FiltersSSETerminator guards the central filter +// that strips the OpenAI SSE terminator "[DONE]" out of the content stream. +// ~49 model drivers still call sender(&"[DONE]", nil) "for OpenAI +// compatibility", which used to leak the marker into the assistant +// message. isContentDelta is the single point of truth for whether a +// candidate delta should be appended to fullContent. +func TestOpenAI_IsContentDelta_FiltersSSETerminator(t *testing.T) { + cases := []struct { + name string + in *string + want bool + }{ + {"nil pointer", nil, false}, + {"empty string", strPtr(""), false}, + {"normal content", strPtr("Hello"), true}, + {"SSE terminator alone", strPtr("[DONE]"), false}, + {"content that contains the substring DONE", strPtr("DONE!"), true}, + {"multiline content", strPtr("line1\nline2"), true}, + {"single newline (leading-newline diagnostic)", strPtr("\n\n"), true}, + } + for _, c := range cases { + if got := isContentDelta(c.in); got != c.want { + t.Errorf("%s: isContentDelta(%v) = %v, want %v", c.name, derefStr(c.in), got, c.want) + } + } +} + +func derefStr(s *string) string { + if s == nil { + return "" + } + return *s +} + +// TestOpenAI_PerDeltaTrimSpace_NotApplied pins the contract that the +// sender callback does NOT strip per-delta whitespace, matching +// Python's api/db/services/llm_service.py::async_chat_streamly (which +// yields raw delta content and concatenates without stripping). +// +// Stripping per-delta would destroy inter-delta word boundaries: a model +// that sends "Hello" + " I'm" + " just" + " a" + " chatbot" (with +// leading spaces on most deltas) would concatenate to +// "Hello! I'mjustachatbot" if each piece was trimmed. +// +// The final accumulated answer IS trimmed, mirroring Python's +// response.choices[0].message.content.strip() at +// llm_service.py:1457 — that single trim is the only one we apply. +func TestOpenAI_PerDeltaTrimSpace_NotApplied(t *testing.T) { + // Per-delta: each piece is appended verbatim. Simulate the model + // streaming "Hello" + " I'm" + " just" + " a" + " chatbot" the way + // Qwen3-8B on SiliconFlow actually does. + deltas := []string{"Hello", "! I'm", " just", " a", " chatbot"} + accumulated := "" + for _, d := range deltas { + accumulated += d // sender appends raw delta, no TrimSpace + } + wantConcatenated := "Hello! I'm just a chatbot" + if accumulated != wantConcatenated { + t.Fatalf("per-delta concatenation = %q, want %q (this is the bug that per-delta TrimSpace would cause)", + accumulated, wantConcatenated) + } + + // Final-answer trim: only the final TrimSpace, applied to the full + // accumulated answer, matches Python's behavior. + leadingNewlines := "\n\nHello, world!\n\n" + trimmed := strings.TrimSpace(leadingNewlines) + if trimmed != "Hello, world!" { + t.Fatalf("final-answer TrimSpace(%q) = %q, want %q", + leadingNewlines, trimmed, "Hello, world!") + } +} + +// ============================================================================= +// Request preparation helpers (moved from internal/handler/openai_chat_completions_test.go) +// ============================================================================= + +// TestService_NormalizeMessageContent_String passes a plain string through. +func TestService_NormalizeMessageContent_String(t *testing.T) { + got, err := normalizeMessageContent("hello") + if err != nil || got != "hello" { + t.Fatalf("expected (%q,nil), got (%q,%v)", "hello", got, err) + } +} + +// TestService_NormalizeMessageContent_Nil returns "" with no error. +func TestService_NormalizeMessageContent_Nil(t *testing.T) { + got, err := normalizeMessageContent(nil) + if err != nil || got != "" { + t.Fatalf("expected (\"\",nil), got (%q,%v)", got, err) + } +} + +// TestService_NormalizeMessageContent_ArrayOfTextParts joins text parts +// with "\n" and drops non-text parts (e.g. image_url). Mirrors +// _normalize_message_content in openai_api.py:198-216. +func TestService_NormalizeMessageContent_ArrayOfTextParts(t *testing.T) { + in := []interface{}{ + map[string]interface{}{"type": "text", "text": "first"}, + map[string]interface{}{"type": "image_url", "image_url": map[string]string{"url": "http://x"}}, + map[string]interface{}{"type": "text", "text": "second"}, + } + got, err := normalizeMessageContent(in) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "first\nsecond" { + t.Fatalf("expected %q, got %q", "first\nsecond", got) + } +} + +// TestService_NormalizeMessageContent_InvalidType returns the Python error string. +func TestService_NormalizeMessageContent_InvalidType(t *testing.T) { + _, err := normalizeMessageContent(42) + if err == nil || !strings.Contains(err.Error(), "must be a string or an array") { + t.Fatalf("expected content-type error, got %v", err) + } +} + +// TestService_NormalizeOpenAIMessages_RejectsInvalidContent ensures a +// message with bad content (e.g. a number) is rejected (Python: +// "messages[].content must be a string or an array of content parts."). +func TestService_NormalizeOpenAIMessages_RejectsInvalidContent(t *testing.T) { + in := []map[string]interface{}{{"role": "user", "content": 42}} + _, err := normalizeOpenAIMessages(in) + if err == nil { + t.Fatalf("expected error for non-string content") + } +} + +// TestService_ExtractGenerationConfig_OnlyKnownFields verifies the +// extraction mirrors extract_generation_config in +// _generation_params.py. The float64 coercion of max_tokens is what +// allows it to satisfy the type assertion in BuildChatConfig. +func TestService_ExtractGenerationConfig_OnlyKnownFields(t *testing.T) { + temp := 0.4 + topP := 0.8 + maxTokens := 256 + freq := 0.1 + pres := 0.2 + req := &OpenAIChatRequest{ + Temperature: &temp, + TopP: &topP, + MaxTokens: &maxTokens, + FrequencyPenalty: &freq, + PresencePenalty: &pres, + } + cfg := extractGenerationConfig(req) + if cfg["temperature"] != temp { + t.Fatalf("temperature: got %v want %v", cfg["temperature"], temp) + } + if cfg["top_p"] != topP { + t.Fatalf("top_p: got %v want %v", cfg["top_p"], topP) + } + if cfg["max_tokens"] != float64(maxTokens) { + t.Fatalf("max_tokens: got %v (%T) want %v (float64)", cfg["max_tokens"], cfg["max_tokens"], float64(maxTokens)) + } + if cfg["frequency_penalty"] != freq { + t.Fatalf("frequency_penalty: got %v want %v", cfg["frequency_penalty"], freq) + } + if cfg["presence_penalty"] != pres { + t.Fatalf("presence_penalty: got %v want %v", cfg["presence_penalty"], pres) + } + for _, k := range []string{"stop", "user", "internet", "tools"} { + if _, has := cfg[k]; has { + t.Fatalf("%s should not be in generation config, got %v", k, cfg[k]) + } + } +} + +// TestService_JoinNonEmpty_JoinsWithSeparator matches strings.Join +// semantics for non-empty inputs and skips empties. +func TestService_JoinNonEmpty_JoinsWithSeparator(t *testing.T) { + if got := joinNonEmpty([]string{"a", "b", "c"}, ","); got != "a,b,c" { + t.Fatalf("expected %q, got %q", "a,b,c", got) + } + if got := joinNonEmpty([]string{"a", "", "b"}, ","); got != "a,b" { + t.Fatalf("expected %q, got %q", "a,b", got) + } + if got := joinNonEmpty([]string{}, ","); got != "" { + t.Fatalf("expected empty, got %q", got) + } +} + +// TestMergeGenerationConfig_RequestOverridesDefault is the +// end-to-end contract: a dialog that arrives from the DB with the +// create-time default (api/db/db_models.py:987, applied inline in +// service/chat.go's SetDialog) still lets a request's per-call +// value win on top, mirroring Python's merge_generation_config at +// openai_api.py:283. +func TestMergeGenerationConfig_RequestOverridesDefault(t *testing.T) { + svc := &OpenAIChatService{} + dialog := &entity.Chat{ + LLMSetting: entity.JSONMap{ + "temperature": 0.1, + "top_p": 0.3, + "frequency_penalty": 0.7, + "presence_penalty": 0.4, + "max_tokens": 512, + }, + } + if dialog.LLMSetting["temperature"] != 0.1 { + t.Fatalf("precondition: default temperature should be 0.1, got %v", dialog.LLMSetting["temperature"]) + } + + svc.MergeGenerationConfig(dialog, map[string]interface{}{"temperature": 0.7}) + if got := dialog.LLMSetting["temperature"]; got != 0.7 { + t.Fatalf("request temperature should win, got %v", got) + } + // Other defaults (not in the request) survive intact. + if got := dialog.LLMSetting["top_p"]; got != 0.3 { + t.Fatalf("top_p: default should be intact after merge, got %v", got) + } + if got := dialog.LLMSetting["max_tokens"]; got != 512 { + t.Fatalf("max_tokens: default should be intact after merge, got %v", got) + } +} + +// flushableRecorder wraps httptest.ResponseRecorder with a no-op Flush so +// the gin.Context's c.Writer.(http.Flusher) type assertion succeeds. The +// recorder itself doesn't implement Flusher; without this wrapper, +// streamChatCompletionSSE would return "streaming unsupported" before +// emitting a single byte. +type flushableRecorder struct { + *httptest.ResponseRecorder + flushed int +} + +func (f *flushableRecorder) Flush() { f.flushed++ } + +// TestStreamChatCompletionSSE_HappyPath pins the SSE wire format +// produced by streamChatCompletionSSE: a `data: \n\n` line per +// event, the `chat.completion.chunk` object, role/content/reasoning +// fields, FinalAnswer surfaced only via final_content (not +// delta.content) — the #15286 fix, the [DONE] terminator, and the +// model field coming from the requestedModel arg (matching Python's +// _stream_chat_completion_sse, which uses requested_model in every +// yielded chunk). +func TestStreamChatCompletionSSE_HappyPath(t *testing.T) { + events := make(chan OpenAIStreamEvent, 8) + events <- OpenAIStreamEvent{Kind: OpenAIEventContent, Delta: "Hello"} + events <- OpenAIStreamEvent{Kind: OpenAIEventContent, Delta: " world"} + events <- OpenAIStreamEvent{Kind: OpenAIEventReasoning, Delta: "thinking..."} + events <- OpenAIStreamEvent{ + Kind: OpenAIEventFinal, + FinalAnswer: "Hello world", + FinalReference: []FormattedChunk{{ID: "chunk-1"}, {ID: "chunk-2"}}, + PromptTokens: 5, + CompletionTokens: 2, + TotalTokens: 7, + } + close(events) + + rec := &flushableRecorder{ResponseRecorder: httptest.NewRecorder()} + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat", nil) + + if err := streamChatCompletionSSE(c, events, "chatcmpl-test-1", "test-model", true); err != nil { + t.Fatalf("streamChatCompletionSSE returned error: %v", err) + } + + body := rec.Body.String() + if rec.Header().Get("Content-Type") != "text/event-stream; charset=utf-8" { + t.Errorf("Content-Type: got %q, want text/event-stream", rec.Header().Get("Content-Type")) + } + if rec.flushed < 4 { + t.Errorf("expected at least 4 flushes (3 chunks + [DONE]), got %d", rec.flushed) + } + // Per-chunk shape: data: \n\n, object=chat.completion.chunk. + if got := strings.Count(body, `"object":"chat.completion.chunk"`); got != 4 { + t.Errorf("expected 4 chat.completion.chunk objects, got %d in body:\n%s", got, body) + } + if got := strings.Count(body, "\n\ndata: [DONE]"); got != 1 { + t.Errorf("expected exactly 1 [DONE] terminator, got %d in body:\n%s", got, body) + } + // Model field uses requestedModel in every chunk (matches Python). + if got := strings.Count(body, `"model":"test-model"`); got != 4 { + t.Errorf("expected 4 chunks to carry model=test-model, got %d in body:\n%s", got, body) + } + // Content deltas surfaced as delta.content (not delta.final_content). + if !strings.Contains(body, `"content":"Hello"`) { + t.Errorf("expected content delta for 'Hello', body:\n%s", body) + } + if !strings.Contains(body, `"content":" world"`) { + t.Errorf("expected content delta for ' world', body:\n%s", body) + } + // Reasoning delta surfaced as delta.reasoning_content with content=null. + if !strings.Contains(body, `"reasoning_content":"thinking..."`) { + t.Errorf("expected reasoning_content for 'thinking...', body:\n%s", body) + } + // #15286 fix: FinalAnswer is in delta.final_content, NOT in delta.content. + // The final chunk has content:null, reasoning_content:null, final_content:"Hello world". + if !strings.Contains(body, `"final_content":"Hello world"`) { + t.Errorf("expected final_content='Hello world' in final chunk, body:\n%s", body) + } + if strings.Contains(body, `"content":"Hello world"`) { + t.Errorf("final answer leaked into delta.content — #15286 regression, body:\n%s", body) + } + // Reference is included because NeedReference=true. + if !strings.Contains(body, `"reference"`) { + t.Errorf("expected reference field (NeedReference=true), body:\n%s", body) + } + // Usage block on the final chunk. + if !strings.Contains(body, `"prompt_tokens":5`) { + t.Errorf("expected prompt_tokens=5, body:\n%s", body) + } + if !strings.Contains(body, `"completion_tokens":2`) { + t.Errorf("expected completion_tokens=2, body:\n%s", body) + } + if !strings.Contains(body, `"total_tokens":7`) { + t.Errorf("expected total_tokens=7, body:\n%s", body) + } + // finish_reason:stop on the final chunk. + if !strings.Contains(body, `"finish_reason":"stop"`) { + t.Errorf("expected finish_reason=stop on final chunk, body:\n%s", body) + } +} + +// TestStreamChatCompletionSSE_NoReference pins the +// NeedReference=false path: the final chunk's delta must NOT carry +// final_content or reference (Python omits those when need_reference +// is false, per openai_api.py:187-194). +func TestStreamChatCompletionSSE_NoReference(t *testing.T) { + events := make(chan OpenAIStreamEvent, 2) + events <- OpenAIStreamEvent{ + Kind: OpenAIEventFinal, + FinalAnswer: "answer", + FinalReference: []FormattedChunk{{ID: "chunk-1"}}, + } + close(events) + + rec := &flushableRecorder{ResponseRecorder: httptest.NewRecorder()} + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat", nil) + + if err := streamChatCompletionSSE(c, events, "chatcmpl-test-2", "test-model", false); err != nil { + t.Fatalf("streamChatCompletionSSE returned error: %v", err) + } + + body := rec.Body.String() + if strings.Contains(body, `"final_content"`) { + t.Errorf("NeedReference=false: final_content should be omitted, body:\n%s", body) + } + if strings.Contains(body, `"reference"`) { + t.Errorf("NeedReference=false: reference should be omitted, body:\n%s", body) + } +} + +// TestStreamChatCompletionSSE_ErrorEvent pins the in-band error path: +// an OpenAIEventError becomes a single chunk with delta.content = +// "**ERROR**: ", then the [DONE] terminator (mirrors +// openai_api.py:174-176). +func TestStreamChatCompletionSSE_ErrorEvent(t *testing.T) { + events := make(chan OpenAIStreamEvent, 1) + events <- OpenAIStreamEvent{Kind: OpenAIEventError, Error: "boom"} + close(events) + + rec := &flushableRecorder{ResponseRecorder: httptest.NewRecorder()} + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat", nil) + + if err := streamChatCompletionSSE(c, events, "chatcmpl-test-3", "test-model", false); err != nil { + t.Fatalf("streamChatCompletionSSE returned error: %v", err) + } + + body := rec.Body.String() + if !strings.Contains(body, `"content":"**ERROR**: boom"`) { + t.Errorf("expected error chunk with content='**ERROR**: boom', body:\n%s", body) + } + if !strings.Contains(body, "data: [DONE]") { + t.Errorf("expected [DONE] after error chunk, body:\n%s", body) + } +} + +// TestStreamChatCompletionSSE_FlusherUnsupported is intentionally not +// written: gin's responseWriter wrapper implements http.Flusher +// regardless of the underlying writer (it calls WriteHeaderNow on +// Flush), so the "streaming unsupported" branch is unreachable +// through gin.CreateTestContext. The branch stays in the function as +// a defensive guard against callers who build a gin.Context +// themselves with a custom non-flushable writer. + +// TestStreamChatCompletionSSE_EmptyChannel pins the empty-input +// edge case: a channel closed with no events at all should still +// emit the [DONE] terminator and not crash. +func TestStreamChatCompletionSSE_EmptyChannel(t *testing.T) { + events := make(chan OpenAIStreamEvent) + close(events) + + rec := &flushableRecorder{ResponseRecorder: httptest.NewRecorder()} + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat", nil) + + if err := streamChatCompletionSSE(c, events, "chatcmpl-empty", "test-model", false); err != nil { + t.Fatalf("streamChatCompletionSSE returned error: %v", err) + } + + body := rec.Body.String() + if !strings.Contains(body, "data: [DONE]") { + t.Errorf("expected [DONE] terminator for empty channel, body:\n%s", body) + } + if got := strings.Count(body, `"object":"chat.completion.chunk"`); got != 0 { + t.Errorf("expected 0 chunks for empty channel, got %d", got) + } +} + +// TestStreamChatCompletionSSE_ChunkJSONShape pins the structural +// shape of one chunk by parsing the JSON payload and checking the +// keys exist. Catches accidental field renames or type changes in +// the gin.H literals. +func TestStreamChatCompletionSSE_ChunkJSONShape(t *testing.T) { + events := make(chan OpenAIStreamEvent, 2) + events <- OpenAIStreamEvent{Kind: OpenAIEventContent, Delta: "x"} + events <- OpenAIStreamEvent{ + Kind: OpenAIEventFinal, + FinalAnswer: "x", + PromptTokens: 1, + CompletionTokens: 1, + TotalTokens: 2, + } + close(events) + + rec := &flushableRecorder{ResponseRecorder: httptest.NewRecorder()} + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat", nil) + + if err := streamChatCompletionSSE(c, events, "chatcmpl-shape", "shape-model", false); err != nil { + t.Fatalf("streamChatCompletionSSE returned error: %v", err) + } + + // Pull the first data: line and parse it. + lines := strings.Split(rec.Body.String(), "\n\n") + if len(lines) < 2 { + t.Fatalf("expected at least 2 lines (one chunk + [DONE]), got %d", len(lines)) + } + firstLine := strings.TrimPrefix(lines[0], "data:") + var chunk map[string]interface{} + if err := json.Unmarshal([]byte(firstLine), &chunk); err != nil { + t.Fatalf("first chunk is not valid JSON: %v\nline: %s", err, firstLine) + } + + // Required top-level fields on a chat.completion.chunk. + for _, key := range []string{"id", "object", "created", "model", "choices"} { + if _, ok := chunk[key]; !ok { + t.Errorf("chunk missing top-level %q", key) + } + } + if chunk["object"] != "chat.completion.chunk" { + t.Errorf("chunk.object: got %v, want chat.completion.chunk", chunk["object"]) + } + // Decode the choices[0] shape. + choices, ok := chunk["choices"].([]interface{}) + if !ok || len(choices) != 1 { + t.Fatalf("choices: got %T %v, want []interface{} of length 1", chunk["choices"], chunk["choices"]) + } + choice, ok := choices[0].(map[string]interface{}) + if !ok { + t.Fatalf("choices[0]: got %T, want object", choices[0]) + } + delta, ok := choice["delta"].(map[string]interface{}) + if !ok { + t.Fatalf("choices[0].delta: got %T, want object", choice["delta"]) + } + // Content deltas carry role + content. + if delta["role"] != "assistant" { + t.Errorf("delta.role: got %v, want assistant", delta["role"]) + } + if delta["content"] != "x" { + t.Errorf("delta.content: got %v, want x", delta["content"]) + } +} diff --git a/internal/service/tag.go b/internal/service/tag.go index d93e4e1471..e5be8a2a29 100644 --- a/internal/service/tag.go +++ b/internal/service/tag.go @@ -114,7 +114,8 @@ func SetTagsToCache(kbIDs []string, tags map[string]float64) error { // Knowledgebase type alias for entity.Knowledgebase type Knowledgebase = entity.Knowledgebase -// GetAllTagsInPortion returns the tag distribution for given KBs +// GetAllTagsInPortion returns all tag_kwd values and their occurrence counts +// for documents belonging to the given kbIDs. func (s *MetadataService) GetAllTagsInPortion(tenantID string, kbIDs []string) (map[string]float64, error) { if len(kbIDs) == 0 { return make(map[string]float64), nil @@ -122,12 +123,14 @@ func (s *MetadataService) GetAllTagsInPortion(tenantID string, kbIDs []string) ( indexName := fmt.Sprintf("ragflow_%s", tenantID) - // Search with large limit to get all tag_kwd values searchReq := &types.SearchRequest{ IndexNames: []string{indexName}, KbIDs: kbIDs, Offset: 0, - Limit: 10000, // Large limit to get all docs + // Python passes limit=0 ("unlimited") which Go SearchRequest treats + // as engine default (Infinity/ES: 30), so use an explicit large cap. + Limit: 100000, + SelectFields: []string{"tag_kwd"}, } searchResp, err := s.docEngine.Search(context.Background(), searchReq) @@ -259,7 +262,7 @@ func (s *MetadataService) TagQuery(question string, tenantIDs []string, kbIDs [] // 4. Get tag KBs by IDs // 5. Call TagQuery to get weighted tag features for the question func (s *MetadataService) LabelQuestion(question string, kbs []*Knowledgebase) map[string]float64 { - if len(kbs) == 0 { + if len(kbs) == 0 || question == "" { return nil } diff --git a/internal/service/tenant.go b/internal/service/tenant.go index 095eed7548..ff3ccb5120 100644 --- a/internal/service/tenant.go +++ b/internal/service/tenant.go @@ -109,13 +109,19 @@ type TenantListItem struct { // TenantLLMService tenant LLM service // This service handles operations related to tenant-specific LLM configurations type TenantLLMService struct { - tenantLLMDAO *dao.TenantLLMDAO + tenantLLMDAO *dao.TenantLLMDAO + modelProviderDAO *dao.TenantModelProviderDAO + modelInstanceDAO *dao.TenantModelInstanceDAO + modelDAO *dao.TenantModelDAO } // NewTenantLLMService creates a new TenantLLMService instance func NewTenantLLMService() *TenantLLMService { return &TenantLLMService{ - tenantLLMDAO: dao.NewTenantLLMDAO(), + tenantLLMDAO: dao.NewTenantLLMDAO(), + modelProviderDAO: dao.NewTenantModelProviderDAO(), + modelInstanceDAO: dao.NewTenantModelInstanceDAO(), + modelDAO: dao.NewTenantModelDAO(), } } @@ -176,6 +182,49 @@ func (s *TenantLLMService) SplitModelNameAndFactory(modelName string) (string, s return arr[0], arr[1] } +// GetAPIKeyFromInstance returns the API key for the given composite model name +// by looking it up in the tenant_model_instance table. compositeModelName is in +// "model@instance@provider" or "model@provider" format. +func (s *TenantLLMService) GetAPIKeyFromInstance(tenantID, compositeModelName string) (string, error) { + parts := strings.Split(compositeModelName, "@") + if len(parts) < 2 { + return "", fmt.Errorf("invalid model name format: %s", compositeModelName) + } + + var providerName, instanceName string + switch len(parts) { + case 2: + instanceName = "default" + providerName = parts[1] + case 3: + instanceName = parts[1] + providerName = parts[2] + default: + return "", fmt.Errorf("invalid model name format: %s", compositeModelName) + } + + provider, err := s.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return "", fmt.Errorf("provider %q not found: %w", providerName, err) + } + if provider == nil { + return "", fmt.Errorf("provider %q not found", providerName) + } + + instance, err := s.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return "", fmt.Errorf("instance %q not found: %w", instanceName, err) + } + if instance == nil { + return "", fmt.Errorf("instance %q not found", instanceName) + } + + if instance.APIKey == "" { + return "", fmt.Errorf("no API key configured for model %s", compositeModelName) + } + return instance.APIKey, nil +} + // EnsureTenantModelIDForParams ensures tenant model IDs are populated for LLM-related parameters /** * This method iterates through a predefined list of LLM-related parameter keys (llm_id, embd_id, diff --git a/internal/service/toc_enhancer.go b/internal/service/toc_enhancer.go new file mode 100644 index 0000000000..84c7c2856a --- /dev/null +++ b/internal/service/toc_enhancer.go @@ -0,0 +1,605 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "context" + "encoding/json" + "fmt" + "sort" + "strconv" + "strings" + + "ragflow/internal/common" + "ragflow/internal/engine" + "ragflow/internal/engine/types" + modelModule "ragflow/internal/entity/models" + + "github.com/kaptinlin/jsonrepair" + "go.uber.org/zap" +) + +// flexInt is an int that can unmarshal from either a JSON number or a JSON string. +// This handles the mismatch between DB-stored TOC entries (level as string "1") +// and LLM-emitted scores (level as number 1). +type flexInt int + +func (f *flexInt) UnmarshalJSON(data []byte) error { + var i int + if err := json.Unmarshal(data, &i); err == nil { + *f = flexInt(i) + return nil + } + var s string + if err := json.Unmarshal(data, &s); err == nil { + i, err := strconv.Atoi(s) + if err != nil { + return fmt.Errorf("flexInt: invalid string %q: %w", s, err) + } + *f = flexInt(i) + return nil + } + return fmt.Errorf("flexInt: cannot unmarshal %s", string(data)) +} + +func (f flexInt) MarshalJSON() ([]byte, error) { + return json.Marshal(int(f)) +} + +// tocEntry holds a single entry from a document's TOC chunk. +// Note: level is stored as a string in JSON (e.g. "1"), so we use flexInt. +type tocEntry struct { + Level flexInt `json:"level"` + Title string `json:"title"` + IDs []string `json:"ids,omitempty"` +} + +// tocRelevanceScore is the LLM-emitted score for a single TOC entry. +type tocRelevanceScore struct { + Level int `json:"level"` + Title string `json:"title"` + Score float64 `json:"score"` +} + +const tocRelevanceSystemPrompt = `You are an expert logical reasoning assistant specializing in hierarchical Table of Contents (TOC) relevance evaluation. + +## GOAL +You will receive: +1. A JSON list of TOC items, each with fields: + ` + "```" + `json + { + "level": , // e.g., 1, 2, 3 + "title": // section title +} + +func asMap(v interface{}) map[string]interface{} { + if m, ok := v.(map[string]interface{}); ok { + return m + } + return nil +} + ` + "```" + ` +2. A user query (natural language question). + +You must assign a **relevance score** (integer) to every TOC entry, based on how related its ` + "`" + `title` + "`" + ` is to the ` + "`" + `query` + "`" + `. + +--- + +## RULES + +### Scoring System +- 5 → highly relevant (directly answers or matches the query intent) +- 3 → somewhat related (same topic or partially overlaps) +- 1 → weakly related (vague or tangential) +- 0 → no clear relation +- -1 → explicitly irrelevant or contradictory + +### Hierarchy Traversal +- The TOC is hierarchical: smaller ` + "`" + `level` + "`" + ` = higher layer (e.g., level 1 is top-level, level 2 is a subsection). +- You must traverse in **hierarchical order** — interpret the structure based on levels (1 > 2 > 3). +- If a high-level item (level 1) is strongly related (score 5), its child items (level 2, 3) are likely relevant too. +- If a high-level item is unrelated (-1 or 0), its deeper children are usually less relevant unless the titles clearly match the query. +- Lower (deeper) levels provide more specific content; prefer assigning higher scores if they directly match the query. + +### Output Format +Return a **JSON array**, preserving the input order but adding a new key ` + "`" + `"score"` + "`" + `: + +` + "```" + `json +[ + {"level": 1, "title": "Introduction", "score": 0}, + {"level": 2, "title": "Definition of Sustainability", "score": 5} +] +` + "```" + ` + +### Constraints +- Output **only the JSON array** — no explanations or reasoning text. + +### EXAMPLES + +#### Example 1 +Input TOC: +[ + {"level": 1, "title": "Machine Learning Overview"}, + {"level": 2, "title": "Supervised Learning"}, + {"level": 2, "title": "Unsupervised Learning"}, + {"level": 3, "title": "Applications of Deep Learning"} +] + +Query: +"How is deep learning used in image classification?" + +Output: +[ + {"level": 1, "title": "Machine Learning Overview", "score": 3}, + {"level": 2, "title": "Supervised Learning", "score": 3}, + {"level": 2, "title": "Unsupervised Learning", "score": 0}, + {"level": 3, "title": "Applications of Deep Learning", "score": 5} +] + +--- + +#### Example 2 +Input TOC: +[ + {"level": 1, "title": "Marketing Basics"}, + {"level": 2, "title": "Consumer Behavior"}, + {"level": 2, "title": "Digital Marketing"}, + {"level": 3, "title": "Social Media Campaigns"}, + {"level": 3, "title": "SEO Optimization"} +] + +Query: +"What are the best online marketing methods?" + +Output: +[ + {"level": 1, "title": "Marketing Basics", "score": 3}, + {"level": 2, "title": "Consumer Behavior", "score": 1}, + {"level": 2, "title": "Digital Marketing", "score": 5}, + {"level": 3, "title": "Social Media Campaigns", "score": 5}, + {"level": 3, "title": "SEO Optimization", "score": 5} +] + +--- + +#### Example 3 +Input TOC: +[ + {"level": 1, "title": "Physics Overview"}, + {"level": 2, "title": "Classical Mechanics"}, + {"level": 3, "title": "Newton's Laws"}, + {"level": 2, "title": "Thermodynamics"}, + {"level": 3, "title": "Entropy and Heat Transfer"} +] + +Query: +"What is entropy?" + +Output: +[ + {"level": 1, "title": "Physics Overview", "score": 3}, + {"level": 2, "title": "Classical Mechanics", "score": 0}, + {"level": 3, "title": "Newton's Laws", "score": -1}, + {"level": 2, "title": "Thermodynamics", "score": 5}, + {"level": 3, "title": "Entropy and Heat Transfer", "score": 5} +] +` + +const tocRelevanceUserTemplate = `You will now receive: +1. A JSON list of TOC items (each with ` + "`" + `level` + "`" + ` and ` + "`" + `title` + "`" + `) +2. A user query string. + +Traverse the TOC hierarchically based on level numbers and assign scores (5,3,1,0,-1) according to the rules in the system prompt. +Output **only** the JSON array with the added ` + "`" + `"score"` + "`" + ` field. + +--- + +**Input TOC:** +%s + +**Query:** +%s +` + +// TOCEnhancer picks the top document, fetches its TOC, scores entries via LLM, +// then merges matching chunks into kbinfos["chunks"]. +type TOCEnhancer struct { + docEngine engine.DocEngine + chatModel *modelModule.ChatModel + tenantIDs []string + kbIDs []string + question string + topN int +} + +// NewTOCEnhancer constructs a TOCEnhancer. +func NewTOCEnhancer( + docEngine engine.DocEngine, + chatModel *modelModule.ChatModel, + tenantIDs []string, + kbIDs []string, + question string, + topN int, +) *TOCEnhancer { + return &TOCEnhancer{ + docEngine: docEngine, + chatModel: chatModel, + tenantIDs: tenantIDs, + kbIDs: kbIDs, + question: question, + topN: topN, + } +} + +// Enhance mutates kbinfos["chunks"] by appending/boosting TOC-relevant chunks. +func (e *TOCEnhancer) Enhance(ctx context.Context, kbinfos map[string]interface{}) (int, error) { + if e == nil || e.chatModel == nil { + return 0, nil + } + if kbinfos == nil { + return 0, nil + } + if e.docEngine == nil { + e.docEngine = engine.Get() + } + if e.docEngine == nil { + return 0, nil + } + chunksRaw, ok := kbinfos["chunks"].([]map[string]interface{}) + if !ok || len(chunksRaw) == 0 { + return 0, nil + } + + common.Debug("TOC enhancer: started", + zap.Int("chunk_count", len(chunksRaw)), + zap.String("question", e.question)) + + topDocID, docID2KBID := topDocFromChunks(chunksRaw) + if topDocID == "" { + return 0, nil + } + + filter := map[string]interface{}{ + "doc_id": []string{topDocID}, + "toc_kwd": "toc", + } + indexNames := make([]string, 0, len(e.tenantIDs)) + for _, tid := range e.tenantIDs { + indexNames = append(indexNames, indexName(tid)) + } + tocResp, err := e.docEngine.Search(ctx, &types.SearchRequest{ + IndexNames: indexNames, + KbIDs: e.kbIDs, + Filter: filter, + SelectFields: []string{"content_with_weight"}, + Offset: 0, + Limit: 128, + }) + if err != nil || tocResp == nil || len(tocResp.Chunks) == 0 { + common.Debug("TOC enhancer: no TOC chunks found for top doc", + zap.String("doc_id", topDocID)) + return 0, nil + } + + entries := parseTOCEntries(tocResp.Chunks) + if len(entries) == 0 { + common.Debug("TOC enhancer: TOC content did not parse to entries", + zap.String("doc_id", topDocID)) + return 0, nil + } + + scores, err := e.scoreEntries(ctx, entries, e.topN*2) + if err != nil { + common.Warn("TOC enhancer: LLM scoring failed", + zap.Error(err), zap.String("doc_id", topDocID)) + return 0, nil + } + if len(scores) == 0 { + return 0, nil + } + + id2idx := map[string]int{} + for i, cm := range chunksRaw { + if cid, ok := cm["chunk_id"].(string); ok && cid != "" { + id2idx[cid] = i + } + } + added := 0 + kbID := docID2KBID[topDocID] + for _, sc := range scores { + cid := sc.Title + if idx, exists := id2idx[cid]; exists { + boostSimilarity(chunksRaw[idx], sc.Score) + } else { + fresh, fetchErr := e.fetchChunk(ctx, cid, topDocID, kbID) + if fetchErr != nil || fresh == nil { + continue + } + d := map[string]interface{}{ + "chunk_id": cid, + "content_ltks": getString(fresh, "content_ltks"), + "content_with_weight": getString(fresh, "content_with_weight"), + "doc_id": topDocID, + "docnm_kwd": getStringDef(fresh, "docnm_kwd", ""), + "kb_id": getStringDef(fresh, "kb_id", kbID), + "important_kwd": getSlice(fresh, "important_kwd"), + "image_id": getStringDef(fresh, "img_id", getStringDef(fresh, "image_id", "")), + "similarity": sc.Score, + "vector_similarity": sc.Score, + "term_similarity": sc.Score, + "vector": []float64{}, + "positions": getSlice(fresh, "position_int"), + "doc_type_kwd": getStringDef(fresh, "doc_type_kwd", ""), + } + for k, v := range fresh { + if len(k) >= 4 && k[len(k)-4:] == "_vec" { + if vec := toFloat64Slice(v); vec != nil { + d["vector"] = vec + break + } + } + } + chunksRaw = append(chunksRaw, d) + id2idx[cid] = len(chunksRaw) - 1 + added++ + } + } + + kbinfos["chunks"] = sortAndTrimChunks(chunksRaw, e.topN) + common.Debug("TOC enhancer: finished", + zap.Int("added_chunks", added), + zap.Int("total_chunks", len(chunksRaw)), + zap.String("doc_id", topDocID)) + return added, nil +} + +// topDocFromChunks picks the doc_id with the highest accumulated similarity. +func topDocFromChunks(chunks []map[string]interface{}) (string, map[string]string) { + ranks := map[string]float64{} + docID2KBID := map[string]string{} + for _, cm := range chunks { + docID, _ := cm["doc_id"].(string) + kbID, _ := cm["kb_id"].(string) + sim, _ := cm["similarity"].(float64) + if docID == "" { + continue + } + ranks[docID] += sim + if _, seen := docID2KBID[docID]; !seen && kbID != "" { + docID2KBID[docID] = kbID + } + } + if len(ranks) == 0 { + return "", nil + } + type kv struct { + k string + v float64 + } + pairs := make([]kv, 0, len(ranks)) + for k, v := range ranks { + pairs = append(pairs, kv{k, v}) + } + sort.Slice(pairs, func(i, j int) bool { return pairs[i].v > pairs[j].v }) + return pairs[0].k, docID2KBID +} + +// parseTOCEntries flattens TOC entries across all TOC chunks. +func parseTOCEntries(chunks []map[string]interface{}) []tocEntry { + common.Debug("TOC enhancer: parsing TOC entries", + zap.Int("chunk_count", len(chunks))) + var out []tocEntry + for _, ck := range chunks { + cww, _ := ck["content_with_weight"].(string) + if cww == "" { + continue + } + var arr []tocEntry + if err := json.Unmarshal([]byte(cww), &arr); err == nil { + out = append(out, arr...) + continue + } + var single tocEntry + if err := json.Unmarshal([]byte(cww), &single); err == nil && single.Title != "" { + out = append(out, single) + continue + } + // Debug: log raw content that failed to parse + preview := cww + if len(preview) > 200 { + preview = preview[:200] + "..." + } + chunkID, _ := ck["id"].(string) + docID, _ := ck["doc_id"].(string) + common.Debug("TOC enhancer: chunk content not valid TOC JSON", + zap.String("chunk_id", chunkID), + zap.String("doc_id", docID), + zap.String("content_preview", preview)) + } + return out +} + +// scoreEntries calls the LLM to score TOC entries and returns (chunkID, normalizedScore) pairs. +func (e *TOCEnhancer) scoreEntries(ctx context.Context, entries []tocEntry, limit int) ([]tocRelevanceScore, error) { + if e.chatModel == nil || e.chatModel.ModelDriver == nil || len(entries) == 0 { + return nil, nil + } + + type tocLLMInput struct { + Level int `json:"level"` + Title string `json:"title"` + } + lines := make([]string, len(entries)) + for i, ent := range entries { + b, _ := json.Marshal(tocLLMInput{Level: int(ent.Level), Title: ent.Title}) + lines[i] = string(b) + } + tocStr := fmt.Sprintf("[\n%s\n]\n", strings.Join(lines, "\n")) + + userPrompt := fmt.Sprintf(tocRelevanceUserTemplate, tocStr, e.question) + + tempZero := 0.0 + topP := 0.9 + cfg := &modelModule.ChatConfig{ + Temperature: &tempZero, + TopP: &topP, + } + + var scores []tocRelevanceScore + maxRetry := 2 + var lastAns, lastErr string + for attempt := 0; attempt < maxRetry; attempt++ { + currentUser := userPrompt + if attempt > 0 && lastAns != "" && lastErr != "" { + currentUser += fmt.Sprintf( + "\nGenerated JSON is as following:\n%s\nBut exception while loading:\n%s\nPlease reconsider and correct it.", + lastAns, lastErr, + ) + } + msgs := []modelModule.Message{ + {Role: "system", Content: tocRelevanceSystemPrompt}, + {Role: "user", Content: currentUser}, + } + modelName := "" + if e.chatModel.ModelName != nil { + modelName = *e.chatModel.ModelName + } + resp, err := e.chatModel.ModelDriver.ChatWithMessages( + modelName, msgs, e.chatModel.APIConfig, cfg, + ) + if err != nil { + return nil, err + } + if resp == nil || resp.Answer == nil { + return nil, fmt.Errorf("toc scoring: empty response") + } + + raw := cleanLLMResponse(*resp.Answer) + lastAns = raw + + repaired, rerr := jsonrepair.Repair(raw) + if rerr != nil { + repaired = raw + } + if err := json.Unmarshal([]byte(repaired), &scores); err != nil { + lastErr = err.Error() + common.Warn("TOC enhancer: JSON parse failed, retrying", + zap.Error(err), zap.Int("attempt", attempt)) + continue + } + break + } + if len(scores) == 0 && lastErr != "" { + return nil, fmt.Errorf("toc scoring: parse failed after retries: %s", lastErr) + } + + id2score := make(map[string][]float64) + for i := 0; i < len(scores) && i < len(entries); i++ { + sc := scores[i] + if sc.Score < 1 { + continue + } + norm := sc.Score / 5.0 + for _, cid := range entries[i].IDs { + id2score[cid] = append(id2score[cid], norm) + } + } + + result := make([]tocRelevanceScore, 0, len(id2score)) + for cid, vals := range id2score { + sum := 0.0 + for _, v := range vals { + sum += v + } + avg := sum / float64(len(vals)) + if avg >= 0.3 { + result = append(result, tocRelevanceScore{ + Title: cid, + Score: avg, + }) + } + } + if limit > 0 && len(result) > limit { + result = result[:limit] + } + return result, nil +} + +// fetchChunk loads a single chunk by chunk_id from the engine. +func (e *TOCEnhancer) fetchChunk(ctx context.Context, chunkID, docID, kbID string) (map[string]interface{}, error) { + filter := map[string]interface{}{ + "doc_id": []string{docID}, + "chunk_id": []string{chunkID}, + } + indexNames := make([]string, 0, len(e.tenantIDs)) + for _, tid := range e.tenantIDs { + indexNames = append(indexNames, indexName(tid)) + } + resp, err := e.docEngine.Search(ctx, &types.SearchRequest{ + IndexNames: indexNames, + KbIDs: []string{kbID}, + Filter: filter, + SelectFields: []string{"content_with_weight", "content_ltks", "doc_id", "docnm_kwd", "kb_id", "important_kwd", "image_id", "positions", "doc_type_kwd", "vector", "q_1024_vec"}, + Offset: 0, + Limit: 1, + }) + if err != nil || resp == nil || len(resp.Chunks) == 0 { + return nil, fmt.Errorf("toc enhancer: fetch chunk %s: not found", chunkID) + } + return resp.Chunks[0], nil +} + +// indexName returns the search index name for a tenant. +func indexName(tenantID string) string { + return "ragflow_" + tenantID +} + +func boostSimilarity(cm map[string]interface{}, delta float64) { + cm["similarity"] = getFloat(cm, "similarity") + delta +} + +func getStringDef(m map[string]interface{}, key, def string) string { + if v, ok := m[key].(string); ok { + return v + } + return def +} + +func getSlice(m map[string]interface{}, key string) []interface{} { + if v, ok := m[key].([]interface{}); ok { + return v + } + return nil +} + +// sortAndTrimChunks sorts chunks by similarity descending and trims to top-N. +func sortAndTrimChunks(chunks []map[string]interface{}, topN int) []map[string]interface{} { + sort.SliceStable(chunks, func(i, j int) bool { + return getFloat(chunks[i], "similarity") > getFloat(chunks[j], "similarity") + }) + if topN > 0 && topN < len(chunks) { + chunks = chunks[:topN] + } + return chunks +} + +func asMap(v interface{}) map[string]interface{} { + if m, ok := v.(map[string]interface{}); ok { + return m + } + return nil +} diff --git a/internal/service/toc_enhancer_test.go b/internal/service/toc_enhancer_test.go new file mode 100644 index 0000000000..ff69123713 --- /dev/null +++ b/internal/service/toc_enhancer_test.go @@ -0,0 +1,314 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package service + +import ( + "testing" +) + +func TestGetFloat(t *testing.T) { + tests := []struct { + name string + m map[string]interface{} + key string + want float64 + }{ + {"present", map[string]interface{}{"score": 3.5}, "score", 3.5}, + {"missing", map[string]interface{}{}, "score", 0}, + {"wrong type", map[string]interface{}{"score": "3.5"}, "score", 0}, + {"nil value", map[string]interface{}{"score": nil}, "score", 0}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getFloat(tt.m, tt.key); got != tt.want { + t.Errorf("getFloat() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetStringDef(t *testing.T) { + tests := []struct { + name string + m map[string]interface{} + key string + def string + want string + }{ + {"present", map[string]interface{}{"name": "foo"}, "name", "default", "foo"}, + {"missing", map[string]interface{}{}, "name", "default", "default"}, + {"wrong type", map[string]interface{}{"name": 42}, "name", "default", "default"}, + {"empty string", map[string]interface{}{"name": ""}, "name", "default", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getStringDef(tt.m, tt.key, tt.def); got != tt.want { + t.Errorf("getStringDef() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetSlice(t *testing.T) { + tests := []struct { + name string + m map[string]interface{} + key string + want []interface{} + }{ + {"present", map[string]interface{}{"items": []interface{}{1, 2, 3}}, "items", []interface{}{1, 2, 3}}, + {"missing", map[string]interface{}{}, "items", nil}, + {"wrong type", map[string]interface{}{"items": "not a slice"}, "items", nil}, + {"nil value", map[string]interface{}{"items": nil}, "items", nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := getSlice(tt.m, tt.key) + if len(got) != len(tt.want) { + t.Errorf("getSlice() len = %d, want %d", len(got), len(tt.want)) + return + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("getSlice()[%d] = %v, want %v", i, got[i], tt.want[i]) + } + } + }) + } +} + +func TestToFloat64Slice(t *testing.T) { + tests := []struct { + name string + v interface{} + want []float64 + wantOk bool + }{ + {"valid", []interface{}{1.0, 2.5, 3.0}, []float64{1.0, 2.5, 3.0}, true}, + {"empty", []interface{}{}, []float64{}, true}, + {"wrong type", "not a slice", nil, false}, + {"non-float element", []interface{}{1.0, "x"}, nil, false}, + {"nil", nil, nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := toFloat64Slice(tt.v) + if (got != nil) != tt.wantOk { + t.Errorf("toFloat64Slice() ok = %v, want %v", got != nil, tt.wantOk) + return + } + if len(got) != len(tt.want) { + t.Errorf("toFloat64Slice() len = %d, want %d", len(got), len(tt.want)) + return + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("toFloat64Slice()[%d] = %v, want %v", i, got[i], tt.want[i]) + } + } + }) + } +} + +func TestAsMap(t *testing.T) { + tests := []struct { + name string + v interface{} + want map[string]interface{} + }{ + {"valid map", map[string]interface{}{"a": 1}, map[string]interface{}{"a": 1}}, + {"nil", nil, nil}, + {"wrong type", "not a map", nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := asMap(tt.v) + if len(got) != len(tt.want) { + t.Errorf("asMap() len = %d, want %d", len(got), len(tt.want)) + return + } + for k, v := range tt.want { + if got[k] != v { + t.Errorf("asMap()[%q] = %v, want %v", k, got[k], v) + } + } + }) + } +} + +func TestBoostSimilarity(t *testing.T) { + cm := map[string]interface{}{"similarity": 1.0} + boostSimilarity(cm, 0.5) + if cm["similarity"] != 1.5 { + t.Errorf("expected 1.5, got %v", cm["similarity"]) + } + + boostSimilarity(cm, 0.0) + if cm["similarity"] != 1.5 { + t.Errorf("expected 1.5 (unchanged), got %v", cm["similarity"]) + } +} + +func TestTopDocFromChunks_Empty(t *testing.T) { + docID, docMap := topDocFromChunks(nil) + if docID != "" || docMap != nil { + t.Errorf("expected empty, got docID=%q, map=%v", docID, docMap) + } +} + +func TestTopDocFromChunks_SingleDoc(t *testing.T) { + chunks := []map[string]interface{}{ + map[string]interface{}{"doc_id": "doc1", "kb_id": "kb1", "similarity": 0.8}, + map[string]interface{}{"doc_id": "doc1", "kb_id": "kb1", "similarity": 0.6}, + } + docID, docMap := topDocFromChunks(chunks) + if docID != "doc1" { + t.Errorf("expected doc1, got %q", docID) + } + if docMap["doc1"] != "kb1" { + t.Errorf("expected kb1, got %q", docMap["doc1"]) + } +} + +func TestTopDocFromChunks_MultiDoc(t *testing.T) { + chunks := []map[string]interface{}{ + map[string]interface{}{"doc_id": "doc1", "kb_id": "kb1", "similarity": 0.3}, + map[string]interface{}{"doc_id": "doc2", "kb_id": "kb2", "similarity": 0.9}, + map[string]interface{}{"doc_id": "doc1", "kb_id": "kb1", "similarity": 0.4}, + } + docID, docMap := topDocFromChunks(chunks) + if docID != "doc2" { + t.Errorf("expected doc2 (accumulated 0.9 > doc1's 0.7), got %q", docID) + } + if docMap["doc1"] != "kb1" { + t.Errorf("expected kb1, got %q", docMap["doc1"]) + } + if docMap["doc2"] != "kb2" { + t.Errorf("expected kb2, got %q", docMap["doc2"]) + } +} + +func TestTopDocFromChunks_NoDocID(t *testing.T) { + chunks := []map[string]interface{}{ + map[string]interface{}{"similarity": 0.8}, + } + docID, docMap := topDocFromChunks(chunks) + if docID != "" || docMap != nil { + t.Errorf("expected empty, got docID=%q, map=%v", docID, docMap) + } +} + +func TestParseTOCEntries_Empty(t *testing.T) { + got := parseTOCEntries(nil) + if len(got) != 0 { + t.Errorf("expected 0 entries, got %d", len(got)) + } +} + +func TestParseTOCEntries_SingleEntry(t *testing.T) { + chunks := []map[string]interface{}{ + {"content_with_weight": `{"level": 1, "title": "Intro", "ids": ["c1"]}`}, + } + got := parseTOCEntries(chunks) + if len(got) != 1 { + t.Fatalf("expected 1 entry, got %d", len(got)) + } + if got[0].Level != 1 || got[0].Title != "Intro" || len(got[0].IDs) != 1 || got[0].IDs[0] != "c1" { + t.Errorf("unexpected entry: %+v", got[0]) + } +} + +func TestParseTOCEntries_MultiEntry(t *testing.T) { + chunks := []map[string]interface{}{ + {"content_with_weight": `[ + {"level": 1, "title": "Intro", "ids": ["c1"]}, + {"level": 2, "title": "Details", "ids": ["c2", "c3"]} + ]`}, + } + got := parseTOCEntries(chunks) + if len(got) != 2 { + t.Fatalf("expected 2 entries, got %d", len(got)) + } + if got[0].Title != "Intro" || got[1].Title != "Details" { + t.Errorf("unexpected titles: %q, %q", got[0].Title, got[1].Title) + } + if len(got[1].IDs) != 2 { + t.Errorf("expected 2 IDs for Details, got %v", got[1].IDs) + } +} + +func TestParseTOCEntries_InvalidJSON(t *testing.T) { + chunks := []map[string]interface{}{ + {"content_with_weight": `not json`}, + } + got := parseTOCEntries(chunks) + if len(got) != 0 { + t.Errorf("expected 0 entries for invalid JSON, got %d", len(got)) + } +} + +func TestParseTOCEntries_MissingField(t *testing.T) { + chunks := []map[string]interface{}{ + {"other_field": "value"}, + } + got := parseTOCEntries(chunks) + if len(got) != 0 { + t.Errorf("expected 0 entries for missing content_with_weight, got %d", len(got)) + } +} + +func TestSortAndTrimChunks_Nil(t *testing.T) { + got := sortAndTrimChunks(nil, 3) + if got != nil { + t.Errorf("expected nil, got %v", got) + } +} + +func TestSortAndTrimChunks_Trim(t *testing.T) { + chunks := []map[string]interface{}{ + map[string]interface{}{"similarity": 0.3}, + map[string]interface{}{"similarity": 0.9}, + map[string]interface{}{"similarity": 0.6}, + } + got := sortAndTrimChunks(chunks, 2) + if len(got) != 2 { + t.Fatalf("expected 2 chunks, got %d", len(got)) + } + if got[0]["similarity"] != 0.9 { + t.Errorf("expected first similarity 0.9, got %v", got[0]) + } + if got[1]["similarity"] != 0.6 { + t.Errorf("expected second similarity 0.6, got %v", got[1]) + } +} + +func TestSortAndTrimChunks_AllKept(t *testing.T) { + chunks := []map[string]interface{}{ + map[string]interface{}{"similarity": 0.3}, + map[string]interface{}{"similarity": 0.9}, + } + got := sortAndTrimChunks(chunks, 5) + if len(got) != 2 { + t.Errorf("expected all 2 chunks kept, got %d", len(got)) + } +} + +func TestIndexName(t *testing.T) { + if got := indexName("tenant1"); got != "ragflow_tenant1" { + t.Errorf("expected 'ragflow_tenant1', got %q", got) + } +} diff --git a/internal/tokenizer/tokenizer.go b/internal/tokenizer/tokenizer.go index 83c085160c..f9bde3b7e8 100644 --- a/internal/tokenizer/tokenizer.go +++ b/internal/tokenizer/tokenizer.go @@ -21,17 +21,31 @@ import ( "fmt" "os" "ragflow/internal/common" - "ragflow/internal/engine" "runtime" "sync" "sync/atomic" "time" + "github.com/pkoukk/tiktoken-go" "go.uber.org/zap" rag "ragflow/internal/binding" ) +// engineTypeProvider is injected at startup by engine.RegisterEngineType +// to break the tokenizer → engine import cycle. +var engineTypeProvider = func() string { return "" } + +// RegisterEngineType wires the engine package's GetEngineType into the +// tokenizer, breaking the circular import (engine/elasticsearch → tokenizer → engine). +func RegisterEngineType(get func() string) { + if get == nil { + engineTypeProvider = func() string { return "" } + return + } + engineTypeProvider = get +} + // PoolConfig configures the elastic analyzer pool type PoolConfig struct { DictPath string // Path to dictionary files @@ -417,7 +431,7 @@ func withAnalyzerResult[T any](fn func(*rag.Analyzer) (T, error)) (T, error) { // // NOTE: For Infinity engine, returns input unchanged to match python's behavior func Tokenize(text string) (string, error) { - if engine.GetEngineType() == "infinity" { + if engineTypeProvider() == "infinity" { return text, nil } return withAnalyzerResult(func(a *rag.Analyzer) (string, error) { @@ -454,7 +468,7 @@ func SetFineGrained(fineGrained bool) { // // NOTE: For Infinity engine, returns input unchanged to match python's behavior func FineGrainedTokenize(tokens string) (string, error) { - if engine.GetEngineType() == "infinity" { + if engineTypeProvider() == "infinity" { return tokens, nil } return withAnalyzerResult(func(a *rag.Analyzer) (string, error) { @@ -490,3 +504,32 @@ func GetTermTag(term string) string { }) return result } + +var cl100kEncoder struct { + sync.Once + enc *tiktoken.Tiktoken + err error +} + +func getCL100KEncoder() (*tiktoken.Tiktoken, error) { + cl100kEncoder.Do(func() { + cl100kEncoder.enc, cl100kEncoder.err = tiktoken.GetEncoding("cl100k_base") + }) + return cl100kEncoder.enc, cl100kEncoder.err +} + +// NumTokensFromString returns the number of tokens in s using the cl100k_base +// BPE encoding +func NumTokensFromString(s string) int { + if s == "" { + return 0 + } + enc, err := getCL100KEncoder() + if err != nil { + // Fail closed: avoid dangerous undercounting when encoder is unavailable. + // A conservative byte-length estimate errs on the side of over-counting, + // which is safer for budget enforcement than returning zero. + return len([]byte(s)) + } + return len(enc.Encode(s, nil, nil)) +} diff --git a/internal/tokenizer/tokenizer_test.go b/internal/tokenizer/tokenizer_test.go new file mode 100644 index 0000000000..48bbaddef2 --- /dev/null +++ b/internal/tokenizer/tokenizer_test.go @@ -0,0 +1,265 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package tokenizer + +import ( + "strings" + "testing" +) + +// saveEngineType saves the current engineTypeProvider and returns a function +// to restore it. Use this when a test modifies the engine type to avoid +// leaking global state between tests. +func saveEngineType() func() { + original := engineTypeProvider + return func() { engineTypeProvider = original } +} + +// --------------------------------------------------------------------------- +// NumTokensFromString tests +// --------------------------------------------------------------------------- + +func TestNumTokensFromString_Empty(t *testing.T) { + if got := NumTokensFromString(""); got != 0 { + t.Errorf("expected 0 for empty string, got %d", got) + } +} + +func TestNumTokensFromString_Positive(t *testing.T) { + for _, s := range []string{"hello world", "你好世界"} { + if got := NumTokensFromString(s); got <= 0 { + t.Errorf("NumTokensFromString(%q) = %d, want >0", s, got) + } + } +} + +func TestNumTokensFromString_VariedInputs(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"ascii letters", "hello world"}, + {"chinese characters", "你好世界"}, + {"japanese characters", "こんにちは世界"}, + {"korean characters", "안녕하세요세계"}, + {"emoji", "👋 hello 🌍"}, + {"numbers only", "1234567890"}, + {"special chars", "a+b=c; d!=e"}, + {"newlines and tabs", "line1\nline2\tindented"}, + {"mixed content", "RAGFlow 是一款 开源的 RAG (Retrieval-Augmented Generation) 引擎"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NumTokensFromString(tt.input) + if got <= 0 { + t.Errorf("NumTokensFromString(%q) = %d, want >0", tt.input, got) + } + }) + } +} + +func TestNumTokensFromString_Consistency(t *testing.T) { + inputs := []string{"hello world", "你好世界", "a+b=c; d!=e"} + for _, s := range inputs { + first := NumTokensFromString(s) + second := NumTokensFromString(s) + if first != second { + t.Errorf("NumTokensFromString(%q) is not consistent: %d vs %d", s, first, second) + } + } +} + +func TestNumTokensFromString_LongString(t *testing.T) { + long := strings.Repeat("the quick brown fox jumps over the lazy dog. ", 200) + got := NumTokensFromString(long) + if got <= 0 { + t.Errorf("NumTokensFromString(long_string) = %d, want >0", got) + } +} + +func TestNumTokensFromString_WhitespaceOnly(t *testing.T) { + for _, s := range []string{" ", "\t", "\n", " "} { + got := NumTokensFromString(s) + // Whitespace strings should still produce tokens in BPE encoding + if got == 0 { + t.Logf("NumTokensFromString(%q) = %d", s, got) + } + } +} + +// --------------------------------------------------------------------------- +// RegisterEngineType tests +// --------------------------------------------------------------------------- + +func TestRegisterEngineType_Basic(t *testing.T) { + restore := saveEngineType() + defer restore() + + RegisterEngineType(func() string { return "infinity" }) + if got := engineTypeProvider(); got != "infinity" { + t.Errorf("expected 'infinity', got %q", got) + } +} + +func TestRegisterEngineType_Overwrite(t *testing.T) { + restore := saveEngineType() + defer restore() + + RegisterEngineType(func() string { return "first" }) + RegisterEngineType(func() string { return "second" }) + if got := engineTypeProvider(); got != "second" { + t.Errorf("expected 'second', got %q", got) + } +} + +// --------------------------------------------------------------------------- +// Tokenize tests +// --------------------------------------------------------------------------- + +func TestTokenize_InfinityEngine(t *testing.T) { + restore := saveEngineType() + defer restore() + RegisterEngineType(func() string { return "infinity" }) + + inputs := []string{"hello world", "你好 世界", "", "a single word"} + for _, input := range inputs { + got, err := Tokenize(input) + if err != nil { + t.Errorf("Tokenize(%q) unexpected error: %v", input, err) + } + if got != input { + t.Errorf("Tokenize(%q) = %q, want %q", input, got, input) + } + } +} + +func TestTokenize_PoolNotInitialized(t *testing.T) { + restore := saveEngineType() + defer restore() + // Ensure engine type is not "infinity" so we hit the pool path + RegisterEngineType(func() string { return "" }) + + _, err := Tokenize("hello world") + if err == nil { + t.Error("expected error when pool is not initialized, got nil") + } +} + +// --------------------------------------------------------------------------- +// FineGrainedTokenize tests +// --------------------------------------------------------------------------- + +func TestFineGrainedTokenize_InfinityEngine(t *testing.T) { + restore := saveEngineType() + defer restore() + RegisterEngineType(func() string { return "infinity" }) + + inputs := []string{"hello world", "测试 分词", ""} + for _, input := range inputs { + got, err := FineGrainedTokenize(input) + if err != nil { + t.Errorf("FineGrainedTokenize(%q) unexpected error: %v", input, err) + } + if got != input { + t.Errorf("FineGrainedTokenize(%q) = %q, want %q", input, got, input) + } + } +} + +func TestFineGrainedTokenize_PoolNotInitialized(t *testing.T) { + restore := saveEngineType() + defer restore() + RegisterEngineType(func() string { return "" }) + + _, err := FineGrainedTokenize("hello world") + if err == nil { + t.Error("expected error when pool is not initialized, got nil") + } +} + +// --------------------------------------------------------------------------- +// Error-path tests for functions that require the pool +// --------------------------------------------------------------------------- + +func TestTokenizeWithPosition_PoolNotInitialized(t *testing.T) { + _, err := TokenizeWithPosition("hello world") + if err == nil { + t.Error("expected error when pool is not initialized, got nil") + } +} + +func TestAnalyze_PoolNotInitialized(t *testing.T) { + _, err := Analyze("hello world") + if err == nil { + t.Error("expected error when pool is not initialized, got nil") + } +} + +func TestGetTermFreq_PoolNotInitialized(t *testing.T) { + got := GetTermFreq("hello") + if got != 0 { + t.Errorf("expected 0 when pool is not initialized, got %d", got) + } +} + +func TestGetTermTag_PoolNotInitialized(t *testing.T) { + got := GetTermTag("hello") + if got != "" { + t.Errorf("expected empty string when pool is not initialized, got %q", got) + } +} + +// --------------------------------------------------------------------------- +// Global state tests +// --------------------------------------------------------------------------- + +func TestGetPoolStats_Nil(t *testing.T) { + // Note: globalPool is nil by default in unit tests (pool not initialized) + stats := GetPoolStats() + if stats == nil { + t.Fatal("GetPoolStats returned nil") + } + init, ok := stats["initialized"] + if !ok { + t.Fatal("missing 'initialized' key") + } + if init.(bool) { + t.Error("expected initialized=false when pool is nil") + } +} + +func TestIsInitialized_Default(t *testing.T) { + if IsInitialized() { + t.Error("expected IsInitialized() = false when pool is not initialized") + } +} + +func TestClose_Nil(t *testing.T) { + // Close should be safe to call with nil globalPool + Close() // no panic = pass +} + +func TestClose_NilGlobalPool(t *testing.T) { + // Call Close directly after ensuring globalPool is nil + // (concurrent test may have initialized it, so handle gracefully) + defer func() { + if r := recover(); r != nil { + t.Errorf("Close() panicked: %v", r) + } + }() + Close() +} diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index fd58ddff7a..66ce08a42a 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -204,7 +204,7 @@ PROMPT_JINJA_ENV = SandboxedEnvironment( def citation_prompt(user_defined_prompts: dict = {}) -> str: template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("citation_guidelines", CITATION_PROMPT_TEMPLATE)) - return template.render() + return template.render() + "\n\nIMPORTANT: The example IDs above (45, 46, 78, etc.) are illustrative only. Use the actual chunk IDs from the provided knowledge blocks." def citation_plus(sources: str) -> str: