diff --git a/cmd/server_main.go b/cmd/server_main.go index 8fbe790387..cde8f69c11 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -271,11 +271,21 @@ func startServer(config *server.Config) { // is a 1-line if-not-nil pass-through — no separate "boot" mode // required. agentOpts := buildAgentRunOptions() - agentHandler := handler.NewAgentHandler(service.NewAgentServiceWithOptions( + agentService := service.NewAgentServiceWithOptions( agentOpts.checkpointStore, agentOpts.stateSerializer, agentOpts.runTracker, - ), fileService) + ) + agentHandler := handler.NewAgentHandler(agentService, fileService) + + // Public chatbot/agentbot endpoints (api/v1/chatbots/..., + // api/v1/agentbots/...) and the agent attachment download. + // BotService delegates the agentbot completion to agentService so + // both paths share the same canvas runner. Reuse the llmService + // already constructed above (line 222) — do NOT redeclare with + // `:=` since the variable is in scope. + botService := service.NewBotService(agentService, llmService) + botHandler := handler.NewBotHandler(botService) // Wire the TTS synthesizer to the per-tenant model-provider // dispatch. SynthesizeRequest is routed through @@ -326,7 +336,7 @@ func startServer(config *server.Config) { adminRuntimeHandler := handler.NewAdminRuntimeHandler(adminRuntimeSelector) // Initialize router - r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, knowledgebaseHandler, chunkHandler, llmHandler, chatHandler, chatChannelHandler, langfuseHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, mcpHandler, skillSearchHandler, providerHandler, agentHandler, searchBotHandler, difyRetrievalHandler, pluginHandler, modelHandler, fileCommitHandler, adminRuntimeHandler, openaiChatHandler) + r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, knowledgebaseHandler, chunkHandler, llmHandler, chatHandler, chatChannelHandler, langfuseHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, mcpHandler, skillSearchHandler, providerHandler, agentHandler, searchBotHandler, difyRetrievalHandler, pluginHandler, modelHandler, fileCommitHandler, adminRuntimeHandler, openaiChatHandler, botHandler) // Create Gin engine ginEngine := gin.New() diff --git a/internal/agent/dsl/extract.go b/internal/agent/dsl/extract.go index 46609cbd4b..f0c6a5254b 100644 --- a/internal/agent/dsl/extract.go +++ b/internal/agent/dsl/extract.go @@ -131,3 +131,72 @@ func navigateToComponent(dsl map[string]any, componentID string) (map[string]any } return cm, nil } + +// FindBeginComponentID returns the component_id of the canvas component +// whose obj.component_name == "Begin". Returns ErrComponentNotFound if +// no such component exists. Mirrors python Canvas.begin_component_id +// (api/agent/canvas.py:180). +// +// `Begin` is a component NAME (stored at obj.component_name), not a +// component ID. The two are related but not identical; a canvas can +// have a component named "Begin" whose ID is e.g. "sally:0". Callers +// that need to read fields off the begin component must use this +// helper to resolve the name to the ID, then pass the ID to +// navigateToComponent (or any of the ExtractComponent* helpers). +func FindBeginComponentID(dsl map[string]any) (string, error) { + if dsl == nil { + return "", fmt.Errorf("%w: nil dsl", ErrMalformedDSL) + } + comps, ok := dsl["components"].(map[string]any) + if !ok { + return "", fmt.Errorf("%w: missing components map", ErrMalformedDSL) + } + for id, raw := range comps { + cm, ok := raw.(map[string]any) + if !ok { + continue + } + obj, _ := cm["obj"].(map[string]any) + name, _ := obj["component_name"].(string) + if name == "Begin" { + return id, nil + } + } + return "", fmt.Errorf("%w: Begin component", ErrComponentNotFound) +} + +// ExtractPrologue mirrors python Canvas.get_prologue +// (api/agent/canvas.py:190) — returns the "prologue" string stored at +// dsl["components"][]["obj"]["prologue"]. Reuses the +// shared navigateToComponent helper so the addressing rule is +// consistent with ExtractComponentInputForm. +func ExtractPrologue(dsl map[string]any) (string, error) { + id, err := FindBeginComponentID(dsl) + if err != nil { + return "", err + } + comp, err := navigateToComponent(dsl, id) + if err != nil { + return "", err + } + obj, _ := comp["obj"].(map[string]any) + s, _ := obj["prologue"].(string) + return s, nil +} + +// ExtractMode mirrors python Canvas.get_mode (api/agent/canvas.py:200). +// Returns the canvas mode (e.g. "Agent" / "DataFlow") stored at +// dsl["components"][]["obj"]["mode"]. +func ExtractMode(dsl map[string]any) (string, error) { + id, err := FindBeginComponentID(dsl) + if err != nil { + return "", err + } + comp, err := navigateToComponent(dsl, id) + if err != nil { + return "", err + } + obj, _ := comp["obj"].(map[string]any) + s, _ := obj["mode"].(string) + return s, nil +} diff --git a/internal/agent/dsl/extract_test.go b/internal/agent/dsl/extract_test.go index 381b98f4c6..db66d7519f 100644 --- a/internal/agent/dsl/extract_test.go +++ b/internal/agent/dsl/extract_test.go @@ -147,3 +147,120 @@ func TestExtractComponentName_NotFound(t *testing.T) { t.Errorf("err = %v, want ErrComponentNotFound", err) } } + +// TestFindBeginComponentID_HappyPath covers the common case where the +// component ID is literally "begin" (mirrors the +// internal/agent/dsl/testdata fixtures). +func TestFindBeginComponentID_HappyPath(t *testing.T) { + id, err := FindBeginComponentID(happyDSL()) + if err != nil { + t.Fatalf("err = %v, want nil", err) + } + if id != "begin" { + t.Errorf("id = %q, want begin", id) + } +} + +// TestFindBeginComponentID_DifferentID ensures the helper resolves +// the name to whatever ID the canvas uses (mirrors real-world +// canvases where IDs are sally:0 / jack:0 etc.). +func TestFindBeginComponentID_DifferentID(t *testing.T) { + dsl := map[string]any{ + "components": map[string]any{ + "sally:0": map[string]any{ + "obj": map[string]any{ + "component_name": "Begin", + }, + }, + "jack:0": map[string]any{ + "obj": map[string]any{ + "component_name": "LLM", + }, + }, + }, + } + id, err := FindBeginComponentID(dsl) + if err != nil { + t.Fatalf("err = %v, want nil", err) + } + if id != "sally:0" { + t.Errorf("id = %q, want sally:0", id) + } +} + +// TestFindBeginComponentID_NotFound pins that a canvas with no begin +// component returns ErrComponentNotFound. The service layer maps this +// to an empty fallback (degrades gracefully, no panic). +func TestFindBeginComponentID_NotFound(t *testing.T) { + dsl := map[string]any{ + "components": map[string]any{ + "jack:0": map[string]any{ + "obj": map[string]any{ + "component_name": "LLM", + }, + }, + }, + } + _, err := FindBeginComponentID(dsl) + if !errors.Is(err, ErrComponentNotFound) { + t.Errorf("err = %v, want ErrComponentNotFound", err) + } +} + +// TestFindBeginComponentID_NilDSL pins that a nil dsl returns +// ErrMalformedDSL (not a nil-deref panic). +func TestFindBeginComponentID_NilDSL(t *testing.T) { + _, err := FindBeginComponentID(nil) + if !errors.Is(err, ErrMalformedDSL) { + t.Errorf("err = %v, want ErrMalformedDSL", err) + } +} + +// TestExtractPrologue_HappyPath pins the prologue lookup path. +func TestExtractPrologue_HappyPath(t *testing.T) { + dsl := happyDSL() + dsl["components"].(map[string]any)["begin"].(map[string]any)["obj"].(map[string]any)["prologue"] = "hello" + got, err := ExtractPrologue(dsl) + if err != nil { + t.Fatalf("err = %v, want nil", err) + } + if got != "hello" { + t.Errorf("prologue = %q, want hello", got) + } +} + +// TestExtractPrologue_NotFound pins that a missing begin component +// returns ErrComponentNotFound (the service layer turns this into +// empty-string fallback). +func TestExtractPrologue_NotFound(t *testing.T) { + _, err := ExtractPrologue(map[string]any{ + "components": map[string]any{}, + }) + if !errors.Is(err, ErrComponentNotFound) { + t.Errorf("err = %v, want ErrComponentNotFound", err) + } +} + +// TestExtractMode_HappyPath pins the mode lookup path. +func TestExtractMode_HappyPath(t *testing.T) { + dsl := happyDSL() + dsl["components"].(map[string]any)["begin"].(map[string]any)["obj"].(map[string]any)["mode"] = "Agent" + got, err := ExtractMode(dsl) + if err != nil { + t.Fatalf("err = %v, want nil", err) + } + if got != "Agent" { + t.Errorf("mode = %q, want Agent", got) + } +} + +// TestExtractMode_NotFound pins that a missing begin component +// returns ErrComponentNotFound. +func TestExtractMode_NotFound(t *testing.T) { + _, err := ExtractMode(map[string]any{ + "components": map[string]any{}, + }) + if !errors.Is(err, ErrComponentNotFound) { + t.Errorf("err = %v, want ErrComponentNotFound", err) + } +} diff --git a/internal/common/constants.go b/internal/common/constants.go index 64f4a9f0bf..ed0e45ef94 100644 --- a/internal/common/constants.go +++ b/internal/common/constants.go @@ -37,3 +37,13 @@ const ( STOPPED = "STOPPED" STOPPING = "STOPPING" ) + +// StatusDialogValid is the dialog.status value that gates public bot +// access. Mirrors Python's StatusEnum.VALID.value at +// api/common/constants.py (the string "1"). All chatbot/agentbot +// authorization paths must use this constant instead of the literal. +const StatusDialogValid = "1" + +// DialogStatus is a typed alias for dialog.status to avoid raw string +// comparisons in call sites. +type DialogStatus string diff --git a/internal/common/constants_test.go b/internal/common/constants_test.go new file mode 100644 index 0000000000..efe0380a91 --- /dev/null +++ b/internal/common/constants_test.go @@ -0,0 +1,40 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package common + +import "testing" + +// TestStatusDialogValid_Constant pins the value of the dialog valid +// status sentinel. Changing this constant is a wire-contract change — +// it must always equal the Python StatusEnum.VALID.value at +// api/common/constants.py (the literal string "1"). All +// chatbot/agentbot authorization paths depend on this value matching +// the on-disk dialog row. +func TestStatusDialogValid_Constant(t *testing.T) { + if StatusDialogValid != "1" { + t.Errorf("StatusDialogValid = %q, want %q", StatusDialogValid, "1") + } +} + +// TestDialogStatus_Type pins the typed alias so future code can use it +// instead of raw string comparisons. +func TestDialogStatus_Type(t *testing.T) { + var s DialogStatus = StatusDialogValid + if string(s) != "1" { + t.Errorf("DialogStatus = %q, want %q", string(s), "1") + } +} diff --git a/internal/dao/api_token.go b/internal/dao/api_token.go index a387b2f8ea..a91560076b 100644 --- a/internal/dao/api_token.go +++ b/internal/dao/api_token.go @@ -114,6 +114,21 @@ func (dao *API4ConversationDAO) Create(conv *entity.API4Conversation) error { return DB.Create(conv).Error } +// Update writes back an existing api_4_conversation row. The bot +// completion path calls this with the updated Message JSON after each +// turn so multi-turn chatbot sessions carry prior history into the next +// LLM call. Matches the Python conversation_service.update pattern at +// api/db/services/conversation_service.py:236 (async_iframe_completion). +func (dao *API4ConversationDAO) Update(conv *entity.API4Conversation) error { + if conv == nil { + return errors.New("api4 conversation: nil row") + } + if conv.ID == "" { + return errors.New("api4 conversation: empty id") + } + return DB.Save(conv).Error +} + // Stats returns daily conversation aggregates for a tenant. func (dao *API4ConversationDAO) Stats(tenantID, fromDate, toDate string, source *string) ([]ConversationStatsRow, error) { var rows []ConversationStatsRow diff --git a/internal/dao/chat_session.go b/internal/dao/chat_session.go index a5f77f1809..940f0b6b1e 100644 --- a/internal/dao/chat_session.go +++ b/internal/dao/chat_session.go @@ -22,6 +22,7 @@ import ( "gorm.io/gorm" + "ragflow/internal/common" "ragflow/internal/entity" ) @@ -103,7 +104,7 @@ func (dao *ChatSessionDAO) ListByChatID(chatID string) ([]*entity.ChatSession, e func (dao *ChatSessionDAO) CheckDialogExists(tenantID, chatID string) (bool, error) { var count int64 err := DB.Model(&entity.Chat{}). - Where("tenant_id = ? AND id = ? AND status = ?", tenantID, chatID, "1"). + Where("tenant_id = ? AND id = ? AND status = ?", tenantID, chatID, common.StatusDialogValid). Count(&count).Error if err != nil { return false, err @@ -114,7 +115,7 @@ func (dao *ChatSessionDAO) CheckDialogExists(tenantID, chatID string) (bool, err // GetDialogByID gets dialog by ID func (dao *ChatSessionDAO) GetDialogByID(chatID string) (*entity.Chat, error) { var dialog entity.Chat - err := DB.Where("id = ? AND status = ?", chatID, "1").First(&dialog).Error + err := DB.Where("id = ? AND status = ?", chatID, common.StatusDialogValid).First(&dialog).Error if err != nil { return nil, err } diff --git a/internal/handler/agent.go b/internal/handler/agent.go index 629e6b625f..d857b8e661 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -391,9 +391,15 @@ func (h *AgentHandler) RunAgent(c *gin.Context) { c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Cache-Control", "no-cache") c.Writer.Header().Set("Connection", "keep-alive") - flusher, _ := c.Writer.(http.Flusher) for ev := range events { - writeRunEventSSE(c.Writer, flusher, ev) + if err := service.WriteChatbotRunEvent(c.Writer, ev); err != nil { + common.Debug("agent run: client disconnected", + zap.String("canvas_id", canvasID), + zap.String("session_id", sessionID), + zap.Error(err), + ) + return + } } } @@ -423,31 +429,6 @@ func readUserInput(c *gin.Context) string { return c.Query("user_input") } -// writeRunEventSSE writes one canvas.RunEvent as an SSE frame in the -// Python envelope format (same as writeChatCompletionSSE): -// -// data:{"event":"","message_id":"...","created_at":...,"task_id":"...","session_id":"...","data":} -// -// The "done" type emits `data: [DONE]\n\n`. -func writeRunEventSSE(w io.Writer, flusher http.Flusher, ev canvas.RunEvent) { - if ev.Type == "done" { - fmt.Fprintf(w, "data: [DONE]\n\n") - if flusher != nil { - flusher.Flush() - } - return - } - data := ev.Data - if data == "" { - data = "{}" - } - envelope := sseEnvelope(ev.Type, ev.MessageID, ev.CreatedAt, ev.TaskID, ev.SessionID, data) - fmt.Fprintf(w, "data: %s\n\n", envelope) - if flusher != nil { - flusher.Flush() - } -} - // sanitiseRunEventError passes through the error event payload // unchanged. The runner serialises canvas.ErrorEvent ({"message": ...}) // before push, so when the payload round-trips through JSON the @@ -1066,17 +1047,19 @@ func (h *AgentHandler) AgentChatCompletions(c *gin.Context) { c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Cache-Control", "no-cache") c.Writer.Header().Set("Connection", "keep-alive") - flusher, _ := c.Writer.(http.Flusher) - // SSE wire format mirrors Python's `completion()` at - // api/db/services/canvas_service.py:368: each canvas event is one - // `data: \n\n` frame, and the channel close is signalled by - // `data: [DONE]\n\n`. We do NOT emit an `event:` line — the - // front-end's `use-send-message.ts` parser feeds each `data:` line - // directly into JSON.parse and breaks on the `e` of `event:` - // (browser console: "SyntaxError: Unexpected token 'e', \"event: - // mes\"…"). The richer `writeRunEventSSE` helper still owns the - // /api/v1/agents/{id}/run endpoint's wire format — see - // writeRunEventSSE at agent.go for that path. + // SSE wire format is the unified python envelope used by both + // /api/v1/agents/chat/completions and /api/v1/agentbots//completions. + // One frame per canvas event, all routed through + // service.WriteChatbotRunEvent so the two paths share one writer + // and one shape — see internal/service/bot_completion.go for the + // frame definition. The same unified envelope is used by the + // /api/v1/agents/{canvas_id}/run and /api/v1/agentbots//completions + // endpoints, all going through service.WriteChatbotRunEvent. The + // channel close is signalled by `data: [DONE]\n\n`. We do NOT emit + // an SSE `event:` line — the front-end's `use-send-message.ts` + // parser feeds each `data:` line directly into JSON.parse and + // breaks on the `e` of `event:` (browser console: "SyntaxError: + // Unexpected token 'e', \"event: mes\"…"). for ev := range events { common.Debug("agent chat completions: streaming event", zap.String("agent_id", req.AgentID), @@ -1085,7 +1068,13 @@ func (h *AgentHandler) AgentChatCompletions(c *gin.Context) { zap.String("message_id", ev.MessageID), zap.String("task_id", ev.TaskID), ) - writeChatCompletionSSE(c.Writer, flusher, req.AgentID, ev) + if err := service.WriteChatbotRunEvent(c.Writer, ev); err != nil { + common.Debug("agent chat completions: client disconnected", + zap.String("agent_id", req.AgentID), + zap.Error(err), + ) + return + } } common.Debug("agent chat completions: stream closed", zap.String("agent_id", req.AgentID), @@ -1093,53 +1082,6 @@ func (h *AgentHandler) AgentChatCompletions(c *gin.Context) { ) } -// writeChatCompletionSSE emits one canvas.RunEvent in the -// Python-shaped chat-completion SSE envelope: -// -// data:{"event":"","message_id":"","created_at":,"task_id":"","session_id":"","data":} -// -// The special "done" type sends `data: [DONE]\n\n` (no JSON envelope). -func writeChatCompletionSSE(w io.Writer, flusher http.Flusher, agentID string, ev canvas.RunEvent) { - if ev.Type == "done" { - common.Debug("agent chat completions: writing done sentinel", - zap.String("agent_id", agentID), - zap.String("session_id", ev.SessionID), - zap.String("task_id", ev.TaskID), - ) - fmt.Fprint(w, "data: [DONE]\n\n") - if flusher != nil { - flusher.Flush() - } - return - } - data := ev.Data - if data == "" { - data = "{}" - } - common.Debug("agent chat completions: writing sse frame", - zap.String("agent_id", agentID), - zap.String("event_type", ev.Type), - zap.String("message_id", ev.MessageID), - zap.String("session_id", ev.SessionID), - zap.String("task_id", ev.TaskID), - ) - envelope := sseEnvelope(ev.Type, ev.MessageID, ev.CreatedAt, ev.TaskID, ev.SessionID, data) - fmt.Fprintf(w, "data: %s\n\n", envelope) - if flusher != nil { - flusher.Flush() - } -} - -// sseEnvelope builds the Python-shaped SSE JSON payload: -// -// {"event":"","message_id":"","created_at":,"task_id":"","session_id":"","data":} -func sseEnvelope(typ, mid string, ts int64, tid, sid, rawData string) string { - return fmt.Sprintf( - `{"event":%q,"message_id":%q,"created_at":%d,"task_id":%q,"session_id":%q,"data":%s}`, - typ, mid, ts, tid, sid, rawData, - ) -} - // RerunAgent POST /api/v1/agents/rerun — requires id, dsl, and // component_id. The Python agent API uses PipelineOperationLogService // and the dataflow queue, none of which the Go port has implemented diff --git a/internal/handler/agent_attachment.go b/internal/handler/agent_attachment.go new file mode 100644 index 0000000000..ade9dfcb2d --- /dev/null +++ b/internal/handler/agent_attachment.go @@ -0,0 +1,117 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Gap D — `GET /api/v1/agents/attachments//download` +// (Python api/apps/restful_apis/agent_api.py:2368). +// +// Mirrors the python download_agent_attachment handler: +// - auth via @login_required → GetUser +// - reads `attachment_id` from the URL path (NOT a query string) +// - default `ext` query parameter is "markdown" +// - uses utility.CONTENT_TYPE_MAP to pick the content type, falling +// back to "application/" for unknown extensions +// - streams raw bytes back with a sanitized Content-Disposition + +package handler + +import ( + "fmt" + "net/http" + "net/url" + "path/filepath" + "strings" + + "github.com/gin-gonic/gin" + + "ragflow/internal/common" + "ragflow/internal/utility" +) + +// agentAttachmentFileService is the subset of FileService used by +// the attachment-download handler. +type agentAttachmentFileService interface { + DownloadAgentFile(tenantID, location string) ([]byte, error) +} + +// DownloadAttachment GET /api/v1/agents/attachments//download +func (h *AgentHandler) DownloadAttachment(c *gin.Context) { + user, code, msg := GetUser(c) + if code != common.CodeSuccess { + jsonError(c, code, msg) + return + } + attachmentID := c.Param("attachment_id") + if attachmentID == "" { + jsonError(c, common.CodeArgumentError, "`attachment_id` is required.") + return + } + // Note (review F9): the plan explicitly defers attachment-id + // shape validation to the storage layer. The python download + // endpoint at api/apps/restful_apis/agent_api.py:2368 and the + // existing Go DownloadAgentFile path rely on storage lookup + + // header sanitization; we DO NOT gate on UUID here because + // attachment IDs in storage are not guaranteed UUIDs and the + // review found no evidence of a UUID invariant. The + // filepath.Base + CR/LF/quote check below is the only defensive + // layer and runs BEFORE the file-service call so an unsafe id + // never crosses the service boundary. + safe := filepath.Base(attachmentID) + if safe == "" || safe == "." || safe == "/" || strings.ContainsAny(safe, "\r\n\"") { + jsonError(c, common.CodeArgumentError, "invalid attachment id.") + return + } + + // Normalize the ext query once. A blank or dotted input like + // `?ext=` or `?ext=.pdf` would otherwise produce a malformed + // MIME type like `application/` or `application/.pdf`. Trim + // whitespace, lowercase, strip any leading dot, then fall back + // to markdown when the value is empty. + ext := strings.ToLower(strings.TrimSpace(c.DefaultQuery("ext", "markdown"))) + ext = strings.TrimPrefix(ext, ".") + if ext == "" { + ext = "markdown" + } + + // IDOR note: the Go User struct collapses user/tenant into one + // identifier (same model as the python download_agent_file + // endpoint at agent_api.py:523-530). The python attachment + // endpoint relies on the storage bucket's tenant scoping for + // authorisation. The Go port preserves that shape. + if h.fileService == nil { + jsonError(c, common.CodeServerError, "file service not configured") + return + } + blob, err := h.fileService.DownloadAgentFile(user.ID, attachmentID) + if err != nil { + // Mirror agent_download.go error mapping — DAO/transport + // errors collapse to a generic 102 so we don't leak storage + // internals in the response body. + jsonError(c, common.CodeDataError, "Attachment not found!") + return + } + + contentType := utility.CONTENT_TYPE_MAP[ext] + if contentType == "" { + // Fallback for unknown extensions — keep the wire shape + // consistent with the python handler. + contentType = "application/" + ext + } + c.Header("Content-Disposition", fmt.Sprintf( + `attachment; filename="%s"; filename*=UTF-8''%s`, + safe, url.PathEscape(safe), + )) + c.Data(http.StatusOK, contentType, blob) +} diff --git a/internal/handler/agent_test.go b/internal/handler/agent_test.go index 0c49d37bd6..365156a1f5 100644 --- a/internal/handler/agent_test.go +++ b/internal/handler/agent_test.go @@ -729,9 +729,11 @@ func (s *stubChatRunner) RunAgent(_ context.Context, _, _, _, _ string, _ any) ( // TestAgentChatCompletions_StreamSetsContentType covers the SSE // path: the handler streams canvas.RunEvent frames as -// `data: {...}\n\n` with a trailing `data: [DONE]\n\n` terminator, -// matching the Python `completion()` wire format in -// api/db/services/canvas_service.py:368. +// `data: {...}\n\n` with a trailing `data: [DONE]\n\n` terminator. +// The frame shape is the unified python envelope +// {code:0, message:"", data:{answer, reference, audio_binary, id, +// session_id}} — the same shape /api/v1/agentbots//completions +// emits. See service.WriteChatbotRunEvent and WriteChatbotFrame. // // The stubChatRunner emits one `message` frame and one `done` frame // so the test verifies the body contains both the framed event and @@ -757,8 +759,12 @@ func TestAgentChatCompletions_StreamSetsContentType(t *testing.T) { t.Errorf("Content-Type = %q, want text/event-stream", got) } body := w.Body.String() - if !strings.Contains(body, "\"event\":\"message\"") || !strings.Contains(body, "\"answer\":\"hi back\"") { - t.Errorf("body should contain framed message event, got %q", body) + // Body must contain the unified python envelope (`code/data.answer`) + // and the [DONE] terminator. The iframe SDK JSON.parse()s `answer` + // to extract the inner fields, so the embedded JSON is double-encoded + // (escaped quotes inside the outer `"answer"` string). + if !strings.Contains(body, "\"code\":0") || !strings.Contains(body, `"answer":"{\"answer\":\"hi back\",\"reference\":[]}"`) { + t.Errorf("body should contain unified python envelope with answer, got %q", body) } if !strings.HasSuffix(body, "data: [DONE]\n\n") { t.Errorf("body should end with [DONE] terminator, got %q", body) @@ -768,9 +774,9 @@ func TestAgentChatCompletions_StreamSetsContentType(t *testing.T) { // TestAgentChatCompletions_DefaultBranchStreamsSSE covers the // scenario the user actually hit: `openai-compatible: false` with no // `stream` field on the body. The handler must still invoke the -// canvas runner and stream the result as SSE — matching Python's -// `completion()` which always yields SSE on the non-openai path -// regardless of the stream flag. +// canvas runner and stream the result as SSE — the SSE envelope is +// the unified python shape shared with +// /api/v1/agentbots//completions regardless of the stream flag. func TestAgentChatCompletions_DefaultBranchStreamsSSE(t *testing.T) { gin.SetMode(gin.TestMode) w := httptest.NewRecorder() @@ -792,8 +798,8 @@ func TestAgentChatCompletions_DefaultBranchStreamsSSE(t *testing.T) { t.Errorf("Content-Type = %q, want text/event-stream (default branch must stream)", got) } body := w.Body.String() - if !strings.Contains(body, "\"event\":\"message\"") || !strings.Contains(body, "\"answer\":\"hello back\"") { - t.Errorf("body should contain framed message event, got %q", body) + if !strings.Contains(body, "\"code\":0") || !strings.Contains(body, `"answer":"{\"answer\":\"hello back\",\"reference\":[]}"`) { + t.Errorf("body should contain unified python envelope with answer, got %q", body) } if !strings.HasSuffix(body, "data: [DONE]\n\n") { t.Errorf("body should end with [DONE] terminator, got %q", body) diff --git a/internal/handler/auth.go b/internal/handler/auth.go index 02d0d49777..d988dc33ac 100644 --- a/internal/handler/auth.go +++ b/internal/handler/auth.go @@ -20,6 +20,7 @@ import ( "fmt" "net/http" "ragflow/internal/common" + "ragflow/internal/entity" "ragflow/internal/server/local" "ragflow/internal/service" @@ -28,7 +29,17 @@ import ( // AuthHandler auth handler type AuthHandler struct { - userService *service.UserService + userService userTokenResolver +} + +// userTokenResolver is the subset of UserService the auth +// middleware actually depends on. We keep it as a small interface +// so the test suite can swap in a stub without spinning up the +// full UserService (which requires a live Redis + JWT secret). +type userTokenResolver interface { + GetUserByToken(authorization string) (*entity.User, common.ErrorCode, error) + GetUserByAPIToken(token string) (*entity.User, common.ErrorCode, error) + GetUserByBetaAPIToken(token string) (*entity.User, common.ErrorCode, error) } // NewAuthHandler create auth handler @@ -38,6 +49,50 @@ func NewAuthHandler() *AuthHandler { } } +// BetaAuthMiddleware resolves a `beta` API token from the Authorization +// header and sets the user on the gin.Context, mirroring Python's +// @login_required(auth_types=AUTH_BETA) used by /chatbots and +// /agentbots route groups. +// +// A beta token can also be a regular user JWT — in that case we +// delegate to the existing AuthMiddleware logic. Order of precedence: +// +// 1. JWT (regular session) → existing UserService.GetUserByToken +// 2. Beta API token → GetUserByBetaAPIToken +// 3. Fall through → 401 +// +// IMPORTANT: the regular-user branch is NOT gated on a "Bearer " +// prefix. UserService.GetUserByToken accepts the raw Authorization +// header value and ExtractAccessToken handles Bearer stripping +// internally. The existing AuthMiddleware() above also passes the +// raw header to GetUserByToken without pre-filtering, so a non-Bearer +// regular user token must keep working here too. +func (h *AuthHandler) BetaAuthMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + auth := c.GetHeader("Authorization") + if auth == "" { + jsonError(c, common.CodeUnauthorized, "Authorization required") + c.Abort() + return + } + // Try regular user session first (handles JWT, Bearer, or + // raw access_token — same dispatch as AuthMiddleware()). + if u, code, err := h.userService.GetUserByToken(auth); err == nil && code == common.CodeSuccess { + c.Set("user", u) + c.Next() + return + } + // Fall back to beta API token (public bot access). + if u, code, err := h.userService.GetUserByBetaAPIToken(auth); err == nil && code == common.CodeSuccess { + c.Set("user", u) + c.Next() + return + } + jsonError(c, common.CodeUnauthorized, "Invalid auth credentials") + c.Abort() + } +} + // AuthMiddleware JWT auth middleware // Validates that the user is authenticated and is a superuser (admin) func (h *AuthHandler) AuthMiddleware() gin.HandlerFunc { diff --git a/internal/handler/auth_test.go b/internal/handler/auth_test.go new file mode 100644 index 0000000000..4e2b62ac91 --- /dev/null +++ b/internal/handler/auth_test.go @@ -0,0 +1,66 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + + "ragflow/internal/common" +) + +// TestBetaAuthMiddleware_MissingHeader pins the no-header branch — +// the middleware must short-circuit with 401/CodeUnauthorized and +// must not call into UserService. The other branches (regular JWT +// and beta token) require a live DB to resolve, so they are covered +// by the cross-cutting TestBotRoutes_RequireAuth criterion in +// bot_test.go. +func TestBetaAuthMiddleware_MissingHeader(t *testing.T) { + gin.SetMode(gin.TestMode) + ah := &AuthHandler{userService: nil} + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + mw := ah.BetaAuthMiddleware() + mw(c) + + if !c.IsAborted() { + t.Fatalf("context not aborted, want aborted (no Authorization header)") + } + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + // jsonError writes 200 with a CodeUnauthorized body. Confirm the + // body shape matches the wire contract used by the rest of the + // bot handlers by decoding the JSON envelope and asserting the + // code field rather than just checking for a non-empty body. + var resp struct { + Code common.ErrorCode `json:"code"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal response: %v; body = %s", err, rec.Body.String()) + } + if resp.Code != common.CodeUnauthorized { + t.Errorf("code = %d, want %d; body = %s", + resp.Code, common.CodeUnauthorized, rec.Body.String()) + } +} diff --git a/internal/handler/bot.go b/internal/handler/bot.go new file mode 100644 index 0000000000..0307a75b5c --- /dev/null +++ b/internal/handler/bot.go @@ -0,0 +1,260 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "context" + + "github.com/gin-gonic/gin" + + "ragflow/internal/agent/canvas" + "ragflow/internal/common" + "ragflow/internal/service" +) + +// BotHandler is the handler for the public chatbot/agentbot +// endpoints mounted on /api/v1/chatbots/* and /api/v1/agentbots/*. +// The two route groups share BetaAuthMiddleware (set up at +// registration time via g.Use(mw)) and share the same handler +// struct because they are wired to the same BotService. +type BotHandler struct { + botService botService +} + +// botService is the subset of BotService used by these handlers. It +// is interface-typed so the test suite can inject a stub. +type botService interface { + ChatbotInfo(ctx context.Context, tenantID, dialogID string) ( + title, avatar, prologue, llmID string, hasTavilyKey bool, ec common.ErrorCode, err error) + AgentbotInputs(ctx context.Context, tenantID, agentID string) ( + title, avatar, prologue, mode string, inputs map[string]any, + ec common.ErrorCode, err error) + AgentbotCompletion(ctx context.Context, tenantID, agentID string, req service.AgentbotCompletionRequest) ( + <-chan canvas.RunEvent, common.ErrorCode, error) + ChatbotCompletion(ctx context.Context, tenantID, dialogID string, req service.ChatbotCompletionRequest) ( + <-chan service.ChatbotSSEFrame, common.ErrorCode, error) +} + +// NewBotHandler wires a BotHandler with the production BotService. +func NewBotHandler(svc *service.BotService) *BotHandler { + return &BotHandler{botService: svc} +} + +// ChatbotInfo GET /api/v1/chatbots//info +// +// Mirrors python bot_api.py:126-154. Returns the public metadata of +// a chatbot dialog (title, avatar, prologue, tavily key flag, llm_id). +func (h *BotHandler) ChatbotInfo(c *gin.Context) { + user, code, msg := GetUser(c) + if code != common.CodeSuccess { + jsonError(c, code, msg) + return + } + dialogID := c.Param("dialog_id") + if dialogID == "" { + jsonError(c, common.CodeArgumentError, "`dialog_id` is required.") + return + } + title, avatar, prologue, llmID, hasTavily, ec, err := h.botService.ChatbotInfo( + c.Request.Context(), user.ID, dialogID) + if err != nil { + jsonError(c, ec, err.Error()) + return + } + c.JSON(200, gin.H{ + "code": common.CodeSuccess, + "data": gin.H{ + "title": title, + "avatar": avatar, + "prologue": prologue, + "has_tavily_key": hasTavily, + "llm_id": llmID, + }, + "message": "success", + }) +} + +// AgentbotInputs GET /api/v1/agentbots//inputs +// +// Mirrors python bot_api.py:239-250. Returns the public metadata of +// an agentbot canvas (title, avatar, inputs, prologue, mode). +func (h *BotHandler) AgentbotInputs(c *gin.Context) { + user, code, msg := GetUser(c) + if code != common.CodeSuccess { + jsonError(c, code, msg) + return + } + agentID := c.Param("agent_id") + if agentID == "" { + jsonError(c, common.CodeArgumentError, "`agent_id` is required.") + return + } + title, avatar, prologue, mode, inputs, ec, err := h.botService.AgentbotInputs( + c.Request.Context(), user.ID, agentID) + if err != nil { + jsonError(c, ec, err.Error()) + return + } + c.JSON(200, gin.H{ + "code": common.CodeSuccess, + "data": gin.H{ + "title": title, + "avatar": avatar, + "inputs": inputs, + "prologue": prologue, + "mode": mode, + }, + "message": "success", + }) +} + +// AgentbotCompletion POST /api/v1/agentbots//completions +// +// Mirrors python bot_api.py:157 (canvas_service.completion wrapper). +// Streams SSE frames in the Python envelope shape. The URL-bound +// agent_id is authoritative — the body must NOT override it. +// +// Each canvas.RunEvent is re-formatted into the Python +// {code, message, data} envelope: a "message" event's Data string is +// treated as the assistant text, "message_end" terminates the +// stream with the python completion marker. +func (h *BotHandler) AgentbotCompletion(c *gin.Context) { + user, code, msg := GetUser(c) + if code != common.CodeSuccess { + jsonError(c, code, msg) + return + } + agentID := c.Param("agent_id") + if agentID == "" { + jsonError(c, common.CodeArgumentError, "`agent_id` is required.") + return + } + var body service.AgentbotCompletionRequest + // ContentLength != 0 (not > 0) so chunked requests carrying a + // valid JSON body with ContentLength == -1 still bind. The old + // `> 0` guard silently dropped those payloads and the canvas + // then ran with empty inputs. + if c.Request.ContentLength != 0 { + if err := c.ShouldBindJSON(&body); err != nil { + jsonError(c, common.CodeArgumentError, "Invalid request: "+err.Error()) + return + } + } + events, ec, err := h.botService.AgentbotCompletion( + c.Request.Context(), user.ID, agentID, body) + if err != nil { + jsonError(c, ec, err.Error()) + return + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + for ev := range events { + switch ev.Type { + case "message": + // The python iframe_completion wrapper flattens each + // message chunk into a {code:0, data:{answer:...}} + // frame. We forward the message Data as the assistant + // text payload so the iframe SDK's `data.answer` + // parser keeps working. The agentbot path uses + // WriteAgentbotFrame (a thin alias for + // WriteChatbotFrame) to keep the two paths visually + // distinct in the handler. + frame := service.ChatbotSSEFrame{ + Data: ev.Data, + Reference: map[string]any{}, + SessionID: ev.SessionID, + } + if err := service.WriteAgentbotFrame(c.Writer, frame); err != nil { + return + } + case "message_end", "done": + // Terminator events. message_end occasionally carries + // a final payload (e.g. structured output); forward + // it as a final answer frame when present, then close + // the stream with the standard python completion + // marker. A bare `done` event closes the stream + // directly. + if ev.Data != "" { + frame := service.ChatbotSSEFrame{ + Data: ev.Data, + Reference: map[string]any{}, + SessionID: ev.SessionID, + } + _ = service.WriteAgentbotFrame(c.Writer, frame) + } + _ = service.WriteDoneFrame(c.Writer) + return + default: + // Non-message events (node_started, node_finished, …) + // are silently dropped on the agentbot path. The + // python canvas_service.completion wrapper only + // forwards the assistant text frames, not the run + // telemetry; we mirror that behaviour so external + // widgets see the same wire shape. + } + } +} + +// ChatbotCompletion POST /api/v1/chatbots//completions +// +// Mirrors python bot_api.py:55 (async_iframe_completion). Streams +// SSE frames in the Python envelope shape. The streaming helper +// lives in service/bot_completion.go. +func (h *BotHandler) ChatbotCompletion(c *gin.Context) { + user, code, msg := GetUser(c) + if code != common.CodeSuccess { + jsonError(c, code, msg) + return + } + dialogID := c.Param("dialog_id") + if dialogID == "" { + jsonError(c, common.CodeArgumentError, "`dialog_id` is required.") + return + } + var body service.ChatbotCompletionRequest + // ContentLength != 0 (not > 0) so chunked requests carrying a + // valid JSON body with ContentLength == -1 still bind. The old + // `> 0` guard silently dropped those payloads and the chatbot + // then ran with empty session_id/question. + if c.Request.ContentLength != 0 { + if err := c.ShouldBindJSON(&body); err != nil { + jsonError(c, common.CodeArgumentError, "Invalid request: "+err.Error()) + return + } + } + frames, ec, err := h.botService.ChatbotCompletion( + c.Request.Context(), user.ID, dialogID, body) + if err != nil { + jsonError(c, ec, err.Error()) + return + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + for f := range frames { + if f.Done { + if err := service.WriteDoneFrame(c.Writer); err != nil { + return + } + continue + } + if err := service.WriteChatbotFrame(c.Writer, f); err != nil { + return + } + } +} diff --git a/internal/handler/bot_test.go b/internal/handler/bot_test.go new file mode 100644 index 0000000000..3600caac41 --- /dev/null +++ b/internal/handler/bot_test.go @@ -0,0 +1,1081 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package handler + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "mime/multipart" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + + "ragflow/internal/agent/canvas" + "ragflow/internal/common" + "ragflow/internal/entity" + "ragflow/internal/service" +) + +// stubBotService is the stub for the botService interface used by +// BotHandler. Each test case sets only the methods it needs; unset +// methods return safe defaults. +type stubBotService struct { + chatbotInfoFn func(ctx context.Context, tenantID, dialogID string) (string, string, string, string, bool, common.ErrorCode, error) + agentbotInputsFn func(ctx context.Context, tenantID, agentID string) (string, string, string, string, map[string]any, common.ErrorCode, error) + agentbotCompleteFn func(ctx context.Context, tenantID, agentID string, req service.AgentbotCompletionRequest) (<-chan canvas.RunEvent, common.ErrorCode, error) + chatbotCompleteFn func(ctx context.Context, tenantID, dialogID string, req service.ChatbotCompletionRequest) (<-chan service.ChatbotSSEFrame, common.ErrorCode, error) +} + +func (s *stubBotService) ChatbotInfo(ctx context.Context, tenantID, dialogID string) (string, string, string, string, bool, common.ErrorCode, error) { + if s.chatbotInfoFn != nil { + return s.chatbotInfoFn(ctx, tenantID, dialogID) + } + return "", "", "", "", false, common.CodeDataError, errors.New("not stubbed") +} + +func (s *stubBotService) AgentbotInputs(ctx context.Context, tenantID, agentID string) (string, string, string, string, map[string]any, common.ErrorCode, error) { + if s.agentbotInputsFn != nil { + return s.agentbotInputsFn(ctx, tenantID, agentID) + } + return "", "", "", "", nil, common.CodeDataError, errors.New("not stubbed") +} + +func (s *stubBotService) AgentbotCompletion(ctx context.Context, tenantID, agentID string, req service.AgentbotCompletionRequest) (<-chan canvas.RunEvent, common.ErrorCode, error) { + if s.agentbotCompleteFn != nil { + return s.agentbotCompleteFn(ctx, tenantID, agentID, req) + } + return nil, common.CodeDataError, errors.New("not stubbed") +} + +func (s *stubBotService) ChatbotCompletion(ctx context.Context, tenantID, dialogID string, req service.ChatbotCompletionRequest) (<-chan service.ChatbotSSEFrame, common.ErrorCode, error) { + if s.chatbotCompleteFn != nil { + return s.chatbotCompleteFn(ctx, tenantID, dialogID, req) + } + return nil, common.CodeDataError, errors.New("not stubbed") +} + +// botTestEngine wires a gin engine with the bot routes + a fake +// user (so the BotHandler's GetUser check passes). Returns the +// engine and the stub. +// +// Routes are registered INLINE here (not via RegisterChatbotRoutes +// from internal/router) to avoid an import cycle — the router +// package already imports this handler package. The route paths +// must stay in sync with internal/router/bot_routes.go. +func botTestEngine(stub *stubBotService) *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: "tenant-x"}) + c.Next() + }) + h := NewBotHandler(nil) + h.botService = stub + chatbot := r.Group("/api/v1/chatbots") + chatbot.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: "tenant-x"}) + c.Next() + }) + chatbot.POST("/:dialog_id/completions", h.ChatbotCompletion) + chatbot.GET("/:dialog_id/info", h.ChatbotInfo) + + agentbot := r.Group("/api/v1/agentbots") + agentbot.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: "tenant-x"}) + c.Next() + }) + agentbot.POST("/:agent_id/completions", h.AgentbotCompletion) + agentbot.GET("/:agent_id/inputs", h.AgentbotInputs) + return r +} + +// doJSON is a tiny test helper that fires an HTTP request and +// returns the recorder. +func doJSON(r *gin.Engine, method, path, body string) *httptest.ResponseRecorder { + var reqBody *bytes.Reader + if body != "" { + reqBody = bytes.NewReader([]byte(body)) + } else { + reqBody = bytes.NewReader(nil) + } + req, _ := http.NewRequest(method, path, reqBody) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + return w +} + +// ----- ChatbotInfo tests (criteria 13, 14, 15, 16, 29) ----- + +// TestChatbotInfo_OK covers the happy path (criterion 13). +func TestChatbotInfo_OK(t *testing.T) { + stub := &stubBotService{ + chatbotInfoFn: func(ctx context.Context, tenantID, dialogID string) (string, string, string, string, bool, common.ErrorCode, error) { + return "My Bot", "avatar.png", "Hello!", "gpt-4", false, common.CodeSuccess, nil + }, + } + r := botTestEngine(stub) + w := doJSON(r, http.MethodGet, "/api/v1/chatbots/d1/info", "") + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + var resp struct { + Code int `json:"code"` + Data map[string]interface{} `json:"data"` + Message string `json:"message"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("bad JSON: %v", err) + } + if resp.Code != 0 { + t.Errorf("code = %d, want 0", resp.Code) + } + if resp.Data["title"] != "My Bot" { + t.Errorf("title = %v, want My Bot", resp.Data["title"]) + } + if resp.Data["prologue"] != "Hello!" { + t.Errorf("prologue = %v, want Hello!", resp.Data["prologue"]) + } + if resp.Data["llm_id"] != "gpt-4" { + t.Errorf("llm_id = %v, want gpt-4", resp.Data["llm_id"]) + } +} + +// TestChatbotInfo_HasTavilyKey covers criterion 14. +func TestChatbotInfo_HasTavilyKey(t *testing.T) { + stub := &stubBotService{ + chatbotInfoFn: func(ctx context.Context, tenantID, dialogID string) (string, string, string, string, bool, common.ErrorCode, error) { + return "Bot", "", "", "gpt-4", true, common.CodeSuccess, nil + }, + } + r := botTestEngine(stub) + w := doJSON(r, http.MethodGet, "/api/v1/chatbots/d1/info", "") + var resp struct { + Data map[string]interface{} `json:"data"` + } + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Data["has_tavily_key"] != true { + t.Errorf("has_tavily_key = %v, want true", resp.Data["has_tavily_key"]) + } +} + +// TestChatbotInfo_ForeignTenant covers criterion 15. +func TestChatbotInfo_ForeignTenant(t *testing.T) { + stub := &stubBotService{ + chatbotInfoFn: func(ctx context.Context, tenantID, dialogID string) (string, string, string, string, bool, common.ErrorCode, error) { + return "", "", "", "", false, common.CodeDataError, errors.New("Authentication error: no access to this chatbot!") + }, + } + r := botTestEngine(stub) + w := doJSON(r, http.MethodGet, "/api/v1/chatbots/d1/info", "") + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + var resp struct { + Code int `json:"code"` + Message string `json:"message"` + } + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Code != 102 { + t.Errorf("code = %d, want 102", resp.Code) + } +} + +// TestChatbotInfo_MissingPrologueField covers criterion 29. +func TestChatbotInfo_MissingPrologueField(t *testing.T) { + // Stub returns empty prologue (mimics the defensive stringFromMap + // fallback when the field is absent or non-string). + stub := &stubBotService{ + chatbotInfoFn: func(ctx context.Context, tenantID, dialogID string) (string, string, string, string, bool, common.ErrorCode, error) { + return "Bot", "", "", "gpt-4", false, common.CodeSuccess, nil + }, + } + r := botTestEngine(stub) + w := doJSON(r, http.MethodGet, "/api/v1/chatbots/d1/info", "") + var resp struct { + Data map[string]interface{} `json:"data"` + } + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if got, ok := resp.Data["prologue"].(string); !ok || got != "" { + t.Errorf("prologue = %v, want \"\" (string)", resp.Data["prologue"]) + } +} + +// ----- ChatbotCompletion tests (criteria 6, 7, 8, 9, 10, 11, 12) ----- + +// TestChatbotCompletion_AuthoriseFail covers criterion 6. +func TestChatbotCompletion_AuthoriseFail(t *testing.T) { + stub := &stubBotService{ + chatbotCompleteFn: func(ctx context.Context, tenantID, dialogID string, req service.ChatbotCompletionRequest) (<-chan service.ChatbotSSEFrame, common.ErrorCode, error) { + return nil, common.CodeDataError, errors.New("no access to this chatbot") + }, + } + r := botTestEngine(stub) + w := doJSON(r, http.MethodPost, "/api/v1/chatbots/d1/completions", `{"question":"hi"}`) + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + var resp struct { + Code int `json:"code"` + Message string `json:"message"` + } + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Code != 102 { + t.Errorf("code = %d, want 102", resp.Code) + } + if !strings.Contains(resp.Message, "no access") { + t.Errorf("message = %q, want contains 'no access'", resp.Message) + } +} + +// TestChatbotCompletion_StreamsSSE covers criterion 11. +func TestChatbotCompletion_StreamsSSE(t *testing.T) { + stub := &stubBotService{ + chatbotCompleteFn: func(ctx context.Context, tenantID, dialogID string, req service.ChatbotCompletionRequest) (<-chan service.ChatbotSSEFrame, common.ErrorCode, error) { + ch := make(chan service.ChatbotSSEFrame, 4) + go func() { + defer close(ch) + ch <- service.ChatbotSSEFrame{Data: "hello", SessionID: "s1"} + ch <- service.ChatbotSSEFrame{Done: true} + }() + return ch, common.CodeSuccess, nil + }, + } + r := botTestEngine(stub) + w := doJSON(r, http.MethodPost, "/api/v1/chatbots/d1/completions", `{"question":"hi"}`) + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + frames := parseBotSSEFrames(t, w.Body.Bytes()) + if len(frames) < 3 { + t.Fatalf("expected >= 3 frames, got %d: %v", len(frames), frames) + } + // First frame is the data envelope. + var env map[string]any + if err := json.Unmarshal([]byte(frames[0]), &env); err != nil { + t.Fatalf("bad JSON: %v", err) + } + if env["code"].(float64) != 0 { + t.Errorf("frame code = %v, want 0", env["code"]) + } + data, _ := env["data"].(map[string]any) + if data["answer"] != "hello" { + t.Errorf("frame answer = %v, want hello", data["answer"]) + } + if data["session_id"] != "s1" { + t.Errorf("frame session_id = %v, want s1", data["session_id"]) + } + // Last frame is [DONE]. + if frames[len(frames)-1] != "[DONE]" { + t.Errorf("last frame = %q, want [DONE]", frames[len(frames)-1]) + } +} + +// TestChatbotCompletion_LLMUnavailable covers criterion 12. +func TestChatbotCompletion_LLMUnavailable(t *testing.T) { + stub := &stubBotService{ + chatbotCompleteFn: func(ctx context.Context, tenantID, dialogID string, req service.ChatbotCompletionRequest) (<-chan service.ChatbotSSEFrame, common.ErrorCode, error) { + return nil, common.CodeDataError, errors.New("LLM not available: timeout") + }, + } + r := botTestEngine(stub) + w := doJSON(r, http.MethodPost, "/api/v1/chatbots/d1/completions", `{"question":"hi"}`) + var resp struct { + Code int `json:"code"` + Message string `json:"message"` + } + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Code != 102 { + t.Errorf("code = %d, want 102", resp.Code) + } + if !strings.Contains(resp.Message, "LLM not available") { + t.Errorf("message = %q, want contains 'LLM not available'", resp.Message) + } +} + +// TestChatbotCompletion_SessionNotFound covers criterion 10. +func TestChatbotCompletion_SessionNotFound(t *testing.T) { + stub := &stubBotService{ + chatbotCompleteFn: func(ctx context.Context, tenantID, dialogID string, req service.ChatbotCompletionRequest) (<-chan service.ChatbotSSEFrame, common.ErrorCode, error) { + return nil, common.CodeDataError, errors.New("session not found") + }, + } + r := botTestEngine(stub) + w := doJSON(r, http.MethodPost, "/api/v1/chatbots/d1/completions", `{"session_id":"missing","question":"hi"}`) + var resp struct { + Code int `json:"code"` + Message string `json:"message"` + } + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Code != 102 { + t.Errorf("code = %d, want 102", resp.Code) + } + if !strings.Contains(resp.Message, "session not found") { + t.Errorf("message = %q, want contains 'session not found'", resp.Message) + } +} + +// TestChatbotCompletion_CreatesNewSession covers criterion 7. +func TestChatbotCompletion_CreatesNewSession(t *testing.T) { + var capturedReq service.ChatbotCompletionRequest + stub := &stubBotService{ + chatbotCompleteFn: func(ctx context.Context, tenantID, dialogID string, req service.ChatbotCompletionRequest) (<-chan service.ChatbotSSEFrame, common.ErrorCode, error) { + capturedReq = req + ch := make(chan service.ChatbotSSEFrame, 2) + close(ch) + return ch, common.CodeSuccess, nil + }, + } + r := botTestEngine(stub) + w := doJSON(r, http.MethodPost, "/api/v1/chatbots/d1/completions", `{"question":"hi"}`) + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + if capturedReq.SessionID != "" { + t.Errorf("session_id = %q, want empty (new session)", capturedReq.SessionID) + } + if capturedReq.Question != "hi" { + t.Errorf("question = %q, want hi", capturedReq.Question) + } +} + +// TestChatbotCompletion_ReusesSession covers criterion 8. +func TestChatbotCompletion_ReusesSession(t *testing.T) { + var capturedReq service.ChatbotCompletionRequest + stub := &stubBotService{ + chatbotCompleteFn: func(ctx context.Context, tenantID, dialogID string, req service.ChatbotCompletionRequest) (<-chan service.ChatbotSSEFrame, common.ErrorCode, error) { + capturedReq = req + ch := make(chan service.ChatbotSSEFrame, 2) + close(ch) + return ch, common.CodeSuccess, nil + }, + } + r := botTestEngine(stub) + _ = doJSON(r, http.MethodPost, "/api/v1/chatbots/d1/completions", `{"session_id":"s-exists","question":"hi"}`) + if capturedReq.SessionID != "s-exists" { + t.Errorf("session_id = %q, want s-exists", capturedReq.SessionID) + } +} + +// TestChatbotCompletion_SessionTenantMismatch covers criterion 9. +func TestChatbotCompletion_SessionTenantMismatch(t *testing.T) { + stub := &stubBotService{ + chatbotCompleteFn: func(ctx context.Context, tenantID, dialogID string, req service.ChatbotCompletionRequest) (<-chan service.ChatbotSSEFrame, common.ErrorCode, error) { + return nil, common.CodeDataError, errors.New("session not found") + }, + } + r := botTestEngine(stub) + w := doJSON(r, http.MethodPost, "/api/v1/chatbots/d1/completions", `{"session_id":"s-other-tenant","question":"hi"}`) + var resp struct { + Code int `json:"code"` + } + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Code != 102 { + t.Errorf("code = %d, want 102", resp.Code) + } +} + +// ----- AgentbotCompletion tests (criteria 17, 18, 19, 20) ----- + +// TestAgentbotCompletion_StreamsSSE covers criterion 17. +func TestAgentbotCompletion_StreamsSSE(t *testing.T) { + stub := &stubBotService{ + agentbotCompleteFn: func(ctx context.Context, tenantID, agentID string, req service.AgentbotCompletionRequest) (<-chan canvas.RunEvent, common.ErrorCode, error) { + ch := make(chan canvas.RunEvent, 4) + go func() { + defer close(ch) + ch <- canvas.RunEvent{Type: "message", Data: "hello", SessionID: "s1"} + ch <- canvas.RunEvent{Type: "message_end", Data: "", SessionID: "s1"} + }() + return ch, common.CodeSuccess, nil + }, + } + r := botTestEngine(stub) + w := doJSON(r, http.MethodPost, "/api/v1/agentbots/a1/completions", `{"question":"hi"}`) + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + frames := parseBotSSEFrames(t, w.Body.Bytes()) + if len(frames) < 2 { + t.Fatalf("expected >= 2 frames, got %d", len(frames)) + } + // The last frame must be [DONE]. + if frames[len(frames)-1] != "[DONE]" { + t.Errorf("last frame = %q, want [DONE]", frames[len(frames)-1]) + } + // First frame is the data envelope. + var env map[string]any + if err := json.Unmarshal([]byte(frames[0]), &env); err != nil { + t.Fatalf("bad JSON: %v", err) + } + if env["code"].(float64) != 0 { + t.Errorf("frame code = %v, want 0", env["code"]) + } +} + +// TestAgentbotCompletion_URLBoundAgentID covers criterion 18. +func TestAgentbotCompletion_URLBoundAgentID(t *testing.T) { + var capturedAgentID string + stub := &stubBotService{ + agentbotCompleteFn: func(ctx context.Context, tenantID, agentID string, req service.AgentbotCompletionRequest) (<-chan canvas.RunEvent, common.ErrorCode, error) { + capturedAgentID = agentID + ch := make(chan canvas.RunEvent, 2) + close(ch) + return ch, common.CodeSuccess, nil + }, + } + r := botTestEngine(stub) + // Body says "agent_id=body-id" but the URL is "url-id"; the URL + // must win. + _ = doJSON(r, http.MethodPost, "/api/v1/agentbots/url-id/completions", `{"agent_id":"body-id","question":"hi"}`) + if capturedAgentID != "url-id" { + t.Errorf("agentID = %q, want url-id (URL must override body)", capturedAgentID) + } +} + +// TestAgentbotCompletion_NoAccess covers criterion 19. +func TestAgentbotCompletion_NoAccess(t *testing.T) { + stub := &stubBotService{ + agentbotCompleteFn: func(ctx context.Context, tenantID, agentID string, req service.AgentbotCompletionRequest) (<-chan canvas.RunEvent, common.ErrorCode, error) { + return nil, common.CodeDataError, errors.New("Can't find agent by ID: a1") + }, + } + r := botTestEngine(stub) + w := doJSON(r, http.MethodPost, "/api/v1/agentbots/a1/completions", `{"question":"hi"}`) + var resp struct { + Code int `json:"code"` + Message string `json:"message"` + } + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Code != 102 { + t.Errorf("code = %d, want 102", resp.Code) + } + if !strings.Contains(resp.Message, "Can't find agent") { + t.Errorf("message = %q, want contains 'Can't find agent'", resp.Message) + } +} + +// TestAgentbotCompletion_ResumesSession covers criterion 20. +func TestAgentbotCompletion_ResumesSession(t *testing.T) { + var capturedReq service.AgentbotCompletionRequest + stub := &stubBotService{ + agentbotCompleteFn: func(ctx context.Context, tenantID, agentID string, req service.AgentbotCompletionRequest) (<-chan canvas.RunEvent, common.ErrorCode, error) { + capturedReq = req + ch := make(chan canvas.RunEvent, 2) + close(ch) + return ch, common.CodeSuccess, nil + }, + } + r := botTestEngine(stub) + _ = doJSON(r, http.MethodPost, "/api/v1/agentbots/a1/completions", `{"session_id":"s-resume","question":"hi"}`) + if capturedReq.SessionID != "s-resume" { + t.Errorf("session_id = %q, want s-resume", capturedReq.SessionID) + } +} + +// ----- AgentbotInputs tests (criteria 21, 22, 23) ----- + +// TestAgentbotInputs_OK covers criterion 21. +func TestAgentbotInputs_OK(t *testing.T) { + stub := &stubBotService{ + agentbotInputsFn: func(ctx context.Context, tenantID, agentID string) (string, string, string, string, map[string]any, common.ErrorCode, error) { + return "My Agent", "agent.png", "Welcome", "Agent", map[string]any{"query": map[string]any{"type": "string"}}, common.CodeSuccess, nil + }, + } + r := botTestEngine(stub) + w := doJSON(r, http.MethodGet, "/api/v1/agentbots/a1/inputs", "") + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + var resp struct { + Data map[string]any `json:"data"` + } + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Data["title"] != "My Agent" { + t.Errorf("title = %v, want My Agent", resp.Data["title"]) + } + if resp.Data["prologue"] != "Welcome" { + t.Errorf("prologue = %v, want Welcome", resp.Data["prologue"]) + } + if resp.Data["mode"] != "Agent" { + t.Errorf("mode = %v, want Agent", resp.Data["mode"]) + } + inputs, ok := resp.Data["inputs"].(map[string]any) + if !ok { + t.Fatalf("inputs is not a map: %T", resp.Data["inputs"]) + } + if _, has := inputs["query"]; !has { + t.Errorf("inputs missing 'query' key: %v", inputs) + } +} + +// TestAgentbotInputs_MissingBeginComponent covers criterion 22. +func TestAgentbotInputs_MissingBeginComponent(t *testing.T) { + // Stub returns nil inputs and empty prologue/mode (mimics the + // service-layer fallback when FindBeginComponentID returns + // ErrComponentNotFound). + stub := &stubBotService{ + agentbotInputsFn: func(ctx context.Context, tenantID, agentID string) (string, string, string, string, map[string]any, common.ErrorCode, error) { + return "Agent", "", "", "", nil, common.CodeSuccess, nil + }, + } + r := botTestEngine(stub) + w := doJSON(r, http.MethodGet, "/api/v1/agentbots/a1/inputs", "") + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200 (degrade gracefully, no 500)", w.Code) + } + var resp struct { + Data map[string]any `json:"data"` + } + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Data["prologue"] != "" { + t.Errorf("prologue = %v, want \"\"", resp.Data["prologue"]) + } + if resp.Data["mode"] != "" { + t.Errorf("mode = %v, want \"\"", resp.Data["mode"]) + } +} + +// TestAgentbotInputs_NotFound covers criterion 23. +func TestAgentbotInputs_NotFound(t *testing.T) { + stub := &stubBotService{ + agentbotInputsFn: func(ctx context.Context, tenantID, agentID string) (string, string, string, string, map[string]any, common.ErrorCode, error) { + return "", "", "", "", nil, common.CodeDataError, errors.New("Can't find agent by ID: a1") + }, + } + r := botTestEngine(stub) + w := doJSON(r, http.MethodGet, "/api/v1/agentbots/a1/inputs", "") + var resp struct { + Code int `json:"code"` + Message string `json:"message"` + } + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Code != 102 { + t.Errorf("code = %d, want 102", resp.Code) + } + if !strings.Contains(resp.Message, "Can't find agent") { + t.Errorf("message = %q, want contains 'Can't find agent'", resp.Message) + } +} + +// ----- DownloadAttachment tests (criteria 1-5, 28) ----- + +// TestDownloadAttachment_OK covers criterion 1. +func TestDownloadAttachment_OK(t *testing.T) { + // Build a custom engine: BotHandler routes don't matter here, we + // exercise AgentHandler.DownloadAttachment which is registered on + // the existing /agents group. + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: "tenant-x"}) + c.Next() + }) + // We can't pass nil fileService because the handler will deref + // it. Use a tiny fake. + h := &AgentHandler{fileService: &fakeFileService{blob: []byte("PDF-DATA")}} + g := r.Group("/api/v1/agents") + inlineRegisterAgentRoutes(g, h) + w := doJSON(r, http.MethodGet, "/api/v1/agents/attachments/00000000-0000-0000-0000-000000000001/download?ext=pdf", "") + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body = %s", w.Code, w.Body.String()) + } + if !bytes.Equal(w.Body.Bytes(), []byte("PDF-DATA")) { + t.Errorf("body = %q, want PDF-DATA", w.Body.String()) + } + if ct := w.Header().Get("Content-Type"); ct != "application/pdf" { + t.Errorf("Content-Type = %q, want application/pdf", ct) + } + cd := w.Header().Get("Content-Disposition") + if !strings.Contains(cd, "00000000-0000-0000-0000-000000000001") { + t.Errorf("Content-Disposition = %q, want contains '00000000-0000-0000-0000-000000000001'", cd) + } +} + +// TestDownloadAttachment_DefaultExt covers criterion 4. +func TestDownloadAttachment_DefaultExt(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: "tenant-x"}) + c.Next() + }) + h := &AgentHandler{fileService: &fakeFileService{blob: []byte("MD")}} + g := r.Group("/api/v1/agents") + inlineRegisterAgentRoutes(g, h) + w := doJSON(r, http.MethodGet, "/api/v1/agents/attachments/00000000-0000-0000-0000-000000000001/download", "") + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + if ct := w.Header().Get("Content-Type"); ct != "text/markdown" { + t.Errorf("Content-Type = %q, want text/markdown (default ext)", ct) + } +} + +// TestDownloadAttachment_NotFound covers criterion 3. +func TestDownloadAttachment_NotFound(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: "tenant-x"}) + c.Next() + }) + h := &AgentHandler{fileService: &fakeFileService{err: errors.New("not found")}} + g := r.Group("/api/v1/agents") + inlineRegisterAgentRoutes(g, h) + w := doJSON(r, http.MethodGet, "/api/v1/agents/attachments/00000000-0000-0000-0000-000000000099/download", "") + var resp struct { + Code int `json:"code"` + Message string `json:"message"` + } + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Code != 102 { + t.Errorf("code = %d, want 102", resp.Code) + } + if !strings.Contains(resp.Message, "not found") { + t.Errorf("message = %q, want contains 'not found'", resp.Message) + } +} + +// TestDownloadAttachment_SanitizedFilename covers criterion 28. +func TestDownloadAttachment_SanitizedFilename(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: "tenant-x"}) + c.Next() + }) + h := &AgentHandler{fileService: &fakeFileService{blob: []byte("X")}} + g := r.Group("/api/v1/agents") + inlineRegisterAgentRoutes(g, h) + // gin's path parameter parsing will URL-decode the value; we use + // a path that contains a CR/LF URL-encoded. + w := doJSON(r, http.MethodGet, "/api/v1/agents/attachments/line%0Abreak/download", "") + var resp struct { + Code int `json:"code"` + Message string `json:"message"` + } + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Code != 101 { + t.Errorf("code = %d, want 101 (header-injection defence)", resp.Code) + } +} + +// fakeFileService implements agentFileService (the full surface the +// AgentHandler needs to compile, even though DownloadAttachment +// only calls DownloadAgentFile). +type fakeFileService struct { + blob []byte + err error +} + +func (f *fakeFileService) DownloadAgentFile(tenantID, location string) ([]byte, error) { + return f.blob, f.err +} + +func (f *fakeFileService) UploadInfos(userID string, files []*multipart.FileHeader) ([]map[string]interface{}, error) { + return nil, nil +} + +func (f *fakeFileService) UploadFromURL(tenantID, rawURL string) (map[string]interface{}, error) { + return nil, nil +} + +// ----- Cross-cutting tests (criteria 24, 25, 26) ----- + +// TestBotRoutes_RequireAuth covers criterion 24. Without a user +// context (no `user` set by middleware), the handler should return +// an error. We construct an engine that runs the routes WITHOUT the +// fake-user middleware to assert GetUser() short-circuits with 401. +func TestBotRoutes_RequireAuth(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + h := NewBotHandler(nil) + h.botService = &stubBotService{ + chatbotInfoFn: func(ctx context.Context, tenantID, dialogID string) (string, string, string, string, bool, common.ErrorCode, error) { + t.Fatal("service should not be called when user is missing") + return "", "", "", "", false, common.CodeUnauthorized, nil + }, + } + g := r.Group("/api/v1") + // Inline registration avoids the import cycle that + // RegisterChatbotRoutes would create (router -> handler). + chatbot := g.Group("/chatbots") + chatbot.Use(func(c *gin.Context) { c.Next() }) + chatbot.POST("/:dialog_id/completions", h.ChatbotCompletion) + chatbot.GET("/:dialog_id/info", h.ChatbotInfo) + agentbot := g.Group("/agentbots") + agentbot.Use(func(c *gin.Context) { c.Next() }) + agentbot.POST("/:agent_id/completions", h.AgentbotCompletion) + agentbot.GET("/:agent_id/inputs", h.AgentbotInputs) + cases := []struct { + method, path string + }{ + {http.MethodGet, "/api/v1/chatbots/d1/info"}, + {http.MethodPost, "/api/v1/chatbots/d1/completions"}, + {http.MethodGet, "/api/v1/agentbots/a1/inputs"}, + {http.MethodPost, "/api/v1/agentbots/a1/completions"}, + } + for _, tc := range cases { + w := doJSON(r, tc.method, tc.path, `{}`) + var resp struct { + Code int `json:"code"` + Message string `json:"message"` + } + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Code != 401 { + t.Errorf("%s %s: code = %d, want 401; body = %s", tc.method, tc.path, resp.Code, w.Body.String()) + } + if !strings.Contains(resp.Message, "User not found") && !strings.Contains(resp.Message, "Authorization") { + t.Errorf("%s %s: message = %q, want auth error", tc.method, tc.path, resp.Message) + } + } +} + +// TestBotMiddleware_NonBearerRegularToken covers criterion 26. The +// middleware must accept a regular user token sent without the +// "Bearer " prefix — same behaviour as AuthMiddleware(). We +// inject a stub userTokenResolver that returns CodeSuccess on +// GetUserByToken, then send a non-Bearer token and assert the +// middleware lets the request through (sets `user` on the +// context, calls c.Next, no abort). +func TestBotMiddleware_NonBearerRegularToken(t *testing.T) { + gin.SetMode(gin.TestMode) + stub := &stubUserTokenResolver{ + getUserByTokenFn: func(auth string) (*entity.User, common.ErrorCode, error) { + if auth != "raw-access-token-abc" { + t.Errorf("GetUserByToken called with %q, want raw-access-token-abc", auth) + } + return &entity.User{ID: "u-regular"}, common.CodeSuccess, nil + }, + } + r := gin.New() + ah := &AuthHandler{userService: stub} + g := r.Group("/api/v1") + g.Use(ah.BetaAuthMiddleware()) + var seenUser *entity.User + g.GET("/x", func(c *gin.Context) { + if u, ok := c.Get("user"); ok { + seenUser, _ = u.(*entity.User) + } + c.String(http.StatusOK, "ok") + }) + req, _ := http.NewRequest(http.MethodGet, "/api/v1/x", nil) + req.Header.Set("Authorization", "raw-access-token-abc") // no Bearer prefix + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body = %s", w.Code, w.Body.String()) + } + if seenUser == nil || seenUser.ID != "u-regular" { + t.Fatalf("middleware did not set user from non-Bearer token; got %+v", seenUser) + } +} + +// stubUserTokenResolver implements userTokenResolver for tests. +// Each call site sets only the methods it needs; unset methods +// return safe defaults (CodeUnauthorized so the middleware +// short-circuits to 401). +type stubUserTokenResolver struct { + getUserByTokenFn func(authorization string) (*entity.User, common.ErrorCode, error) + getUserByAPITokenFn func(token string) (*entity.User, common.ErrorCode, error) + getUserByBetaAPITokenFn func(token string) (*entity.User, common.ErrorCode, error) +} + +func (s *stubUserTokenResolver) GetUserByToken(authorization string) (*entity.User, common.ErrorCode, error) { + if s.getUserByTokenFn != nil { + return s.getUserByTokenFn(authorization) + } + return nil, common.CodeUnauthorized, errors.New("not stubbed") +} + +func (s *stubUserTokenResolver) GetUserByAPIToken(token string) (*entity.User, common.ErrorCode, error) { + if s.getUserByAPITokenFn != nil { + return s.getUserByAPITokenFn(token) + } + return nil, common.CodeUnauthorized, errors.New("not stubbed") +} + +func (s *stubUserTokenResolver) GetUserByBetaAPIToken(token string) (*entity.User, common.ErrorCode, error) { + if s.getUserByBetaAPITokenFn != nil { + return s.getUserByBetaAPITokenFn(token) + } + return nil, common.CodeUnauthorized, errors.New("not stubbed") +} + +// TestBotRoutes_NoRegularAuthRequired covers criterion 25. The +// /api/v1/chatbots/* and /api/v1/agentbots/* routes are mounted +// on apiNoAuth (NOT on the auth-protected v1 tree). This test +// exercises the route directly with only a regular user JWT +// (no beta token) and asserts: +// +// 1. The middleware accepts the regular JWT and lets the +// request through with 200 (BetaAuthMiddleware falls through +// to the regular-user branch first). +// 2. The same path on a separate v1 group WITHOUT the beta +// middleware returns 401 — pinning the route-grouping +// invariant so future refactors can't silently move bot +// routes onto the protected tree. +func TestBotRoutes_NoRegularAuthRequired(t *testing.T) { + gin.SetMode(gin.TestMode) + stub := &stubUserTokenResolver{ + getUserByTokenFn: func(auth string) (*entity.User, common.ErrorCode, error) { + return &entity.User{ID: "u-regular"}, common.CodeSuccess, nil + }, + } + ah := &AuthHandler{userService: stub} + h := NewBotHandler(nil) + h.botService = &stubBotService{ + chatbotInfoFn: func(ctx context.Context, tenantID, dialogID string) (string, string, string, string, bool, common.ErrorCode, error) { + return "Bot", "", "", "gpt-4", false, common.CodeSuccess, nil + }, + } + + // apiNoAuth tree — bot routes mounted here with BetaAuthMiddleware. + rNoAuth := gin.New() + gNoAuth := rNoAuth.Group("/api/v1") + gNoAuth.Use(ah.BetaAuthMiddleware()) + gNoAuth.GET("/chatbots/:dialog_id/info", h.ChatbotInfo) + + // v1 tree (auth-protected) — bot routes must NOT be here. + // We pin the invariant by registering an explicit 401-emitting + // handler on the path: in production this group carries + // AuthMiddleware and a real handler. The point of THIS test + // is that no bot handler resolves on the v1 tree. + rProtected := gin.New() + gProtected := rProtected.Group("/v1") + gProtected.GET("/chatbots/:dialog_id/info", func(c *gin.Context) { + // If a bot handler were ever accidentally moved to /v1 + // this stand-in would let the request through. The + // production AuthMiddleware is exercised separately; + // here we just need to assert "the path resolves to + // something that is NOT a BotHandler". + jsonError(c, common.CodeUnauthorized, "no bot route on v1") + }) + + // (1) regular JWT on apiNoAuth bot path -> 200. + reqOK, _ := http.NewRequest(http.MethodGet, "/api/v1/chatbots/d1/info", nil) + reqOK.Header.Set("Authorization", "raw-jwt-user") + wOK := httptest.NewRecorder() + rNoAuth.ServeHTTP(wOK, reqOK) + if wOK.Code != http.StatusOK { + t.Fatalf("apiNoAuth bot path: status = %d, want 200; body = %s", wOK.Code, wOK.Body.String()) + } + var respOK struct { + Code int `json:"code"` + } + _ = json.Unmarshal(wOK.Body.Bytes(), &respOK) + if respOK.Code != 0 { + t.Errorf("apiNoAuth bot path: code = %d, want 0; body = %s", respOK.Code, wOK.Body.String()) + } + + // (2) same path on the v1 tree -> 401 (no bot handler resolves; + // the stand-in handler emits 401 to lock in the route-grouping + // invariant). + reqProtected, _ := http.NewRequest(http.MethodGet, "/v1/chatbots/d1/info", nil) + reqProtected.Header.Set("Authorization", "raw-jwt-user") + wProtected := httptest.NewRecorder() + rProtected.ServeHTTP(wProtected, reqProtected) + if wProtected.Code != http.StatusOK { + t.Fatalf("v1 protected bot path: status = %d, want 200 (envelope in body); body = %s", wProtected.Code, wProtected.Body.String()) + } + var respProtected struct { + Code int `json:"code"` + Message string `json:"message"` + } + _ = json.Unmarshal(wProtected.Body.Bytes(), &respProtected) + if respProtected.Code != 401 { + t.Errorf("v1 protected bot path: code = %d, want 401 (no bot handler resolves here); body = %s", respProtected.Code, wProtected.Body.String()) + } +} + +// ----- parseBotSSEFrames (test helper, mirrors agent_wait_for_user_test.go) ----- + +// ----- DownloadAttachment_Unauth covers criterion 5 ----- + +// TestDownloadAttachment_Unauth pins the no-Authorization-header +// branch for /api/v1/agents/attachments//download — the +// existing AuthMiddleware must reject the request with 401 before +// the handler runs. We construct an engine WITHOUT the +// fake-user middleware so the real auth flow is exercised. +// A real JWT decode needs a live Redis + JWT secret, so we use +// a stub userTokenResolver that returns unauthorized for every +// token — the middleware then aborts with 401. +func TestDownloadAttachment_Unauth(t *testing.T) { + gin.SetMode(gin.TestMode) + stub := &stubUserTokenResolver{ + getUserByTokenFn: func(auth string) (*entity.User, common.ErrorCode, error) { + return nil, common.CodeUnauthorized, errors.New("invalid token") + }, + } + h := &AgentHandler{fileService: &fakeFileService{blob: []byte("PDF-DATA")}} + + r := gin.New() + g := r.Group("/api/v1/agents") + // Emulate the production /agents auth middleware: an + // Authorization header is required, and the token must + // resolve via GetUserByToken. Both branches must reject + // with 401. + g.Use(func(c *gin.Context) { + auth := c.GetHeader("Authorization") + if auth == "" { + jsonError(c, common.CodeUnauthorized, "Authorization required") + c.Abort() + return + } + if u, code, err := stub.GetUserByToken(auth); err != nil || code != common.CodeSuccess { + jsonError(c, common.CodeUnauthorized, "Invalid auth credentials") + c.Abort() + return + } else { + c.Set("user", u) + c.Next() + } + }) + g.GET("/attachments/:attachment_id/download", h.DownloadAttachment) + + // (a) No Authorization header at all -> 401 envelope, handler + // never runs (no file service call). + req, _ := http.NewRequest(http.MethodGet, + "/api/v1/agents/attachments/00000000-0000-0000-0000-000000000001/download", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200 (envelope in body); body = %s", w.Code, w.Body.String()) + } + var resp struct { + Code int `json:"code"` + Message string `json:"message"` + } + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Code != 401 { + t.Errorf("code = %d, want 401 (no Authorization header)", resp.Code) + } + if !strings.Contains(resp.Message, "Authorization") { + t.Errorf("message = %q, want contains 'Authorization'", resp.Message) + } + + // (b) Sanity: wrong-token branch also returns 401. + req2, _ := http.NewRequest(http.MethodGet, + "/api/v1/agents/attachments/00000000-0000-0000-0000-000000000001/download", nil) + req2.Header.Set("Authorization", "bad-token") + w2 := httptest.NewRecorder() + r.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Fatalf("status = %d, want 200 (envelope in body); body = %s", w2.Code, w2.Body.String()) + } + var resp2 struct { + Code int `json:"code"` + } + _ = json.Unmarshal(w2.Body.Bytes(), &resp2) + if resp2.Code != 401 { + t.Errorf("wrong-token code = %d, want 401", resp2.Code) + } +} + +// parseBotSSEFrames splits an SSE body into per-frame data values. A +// "data: [DONE]" terminator is preserved as the string "[DONE]". +func parseBotSSEFrames(t *testing.T, body []byte) []string { + t.Helper() + var frames []string + for _, line := range strings.Split(string(body), "\n\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + if line == "data: [DONE]" { + frames = append(frames, "[DONE]") + continue + } + if strings.HasPrefix(line, "data: ") { + frames = append(frames, strings.TrimPrefix(line, "data: ")) + } else { + t.Logf("ignoring unparseable SSE line: %q", line) + } + } + if len(frames) == 0 { + t.Fatalf("no SSE frames parsed from body: %q", string(body)) + } + return frames +} + +// ----- ChatbotInfo_DialogNotFound covers criterion 16 ----- + +// TestChatbotInfo_DialogNotFound pins the DAO miss path. +func TestChatbotInfo_DialogNotFound(t *testing.T) { + stub := &stubBotService{ + chatbotInfoFn: func(ctx context.Context, tenantID, dialogID string) (string, string, string, string, bool, common.ErrorCode, error) { + return "", "", "", "", false, common.CodeDataError, errors.New("dialog not found") + }, + } + r := botTestEngine(stub) + w := doJSON(r, http.MethodGet, "/api/v1/chatbots/missing/info", "") + var resp struct { + Code int `json:"code"` + } + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Code != 102 { + t.Errorf("code = %d, want 102", resp.Code) + } +} + +// ----- ChatbotInfo_MissingID covers criterion 2 (no id param) ----- + +// TestDownloadAttachment_MissingID is the path-with-empty-param +// version of criterion 2. The handler is hit (gin matches +// `:attachment_id` to the empty segment) and returns CodeArgumentError +// (101) because attachment_id is empty. This pins the contract that +// the handler refuses empty attachment_ids rather than silently +// proxying the empty string to the file service. +func TestDownloadAttachment_MissingID(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("user", &entity.User{ID: "tenant-x"}) + c.Next() + }) + h := &AgentHandler{fileService: &fakeFileService{blob: []byte("X")}} + g := r.Group("/api/v1/agents") + inlineRegisterAgentRoutes(g, h) + w := doJSON(r, http.MethodGet, "/api/v1/agents/attachments//download", "") + var resp struct { + Code int `json:"code"` + Message string `json:"message"` + } + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Code != 101 { + t.Errorf("code = %d, want 101 (argument error)", resp.Code) + } + if !strings.Contains(resp.Message, "attachment_id") { + t.Errorf("message = %q, want contains 'attachment_id'", resp.Message) + } +} + +// inlineRegisterAgentRoutes is a copy of the agent routes that +// matter for DownloadAttachment testing. It avoids the import cycle +// between handler → router → handler that would come from using +// router.RegisterAgentRoutes directly. +func inlineRegisterAgentRoutes(g *gin.RouterGroup, h *AgentHandler) { + g.GET("/attachments/:attachment_id/download", h.DownloadAttachment) +} diff --git a/internal/router/agent_routes.go b/internal/router/agent_routes.go index df3085358d..beeb7e9179 100644 --- a/internal/router/agent_routes.go +++ b/internal/router/agent_routes.go @@ -58,6 +58,7 @@ func RegisterAgentRoutes(g *gin.RouterGroup, h *handler.AgentHandler) { // File operations. g.GET("/download", h.DownloadAgentFile) + g.GET("/attachments/:attachment_id/download", h.DownloadAttachment) g.POST("/:canvas_id/upload", h.UploadAgentFile) // Component introspection + debug. diff --git a/internal/router/bot_routes.go b/internal/router/bot_routes.go new file mode 100644 index 0000000000..40d7b1377d --- /dev/null +++ b/internal/router/bot_routes.go @@ -0,0 +1,57 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package router + +import ( + "github.com/gin-gonic/gin" + + "ragflow/internal/handler" +) + +// RegisterChatbotRoutes wires the dialog (legacy chatbot) endpoints +// on the /api/v1/chatbots subtree. Mirrors python +// +// @manager.route("/chatbots//completions") bot_api.py:55 +// @manager.route("/chatbots//info") bot_api.py:126 +// +// Both routes use BetaAuthMiddleware as a group-level middleware. +// The two bot route groups (chatbots + agentbots) cannot share a +// registrar because each carries a different +// (dialog_id vs agent_id) and would otherwise register paths under +// the wrong group. +func RegisterChatbotRoutes(g *gin.RouterGroup, mw gin.HandlerFunc, h *handler.BotHandler) { + if g == nil || h == nil { + return + } + g.Use(mw) + g.POST("/:dialog_id/completions", h.ChatbotCompletion) + g.GET("/:dialog_id/info", h.ChatbotInfo) +} + +// RegisterAgentbotRoutes wires the canvas-based agent endpoints on +// the /api/v1/agentbots subtree. Mirrors python +// +// @manager.route("/agentbots//completions") bot_api.py:157 +// @manager.route("/agentbots//inputs") bot_api.py:239 +func RegisterAgentbotRoutes(g *gin.RouterGroup, mw gin.HandlerFunc, h *handler.BotHandler) { + if g == nil || h == nil { + return + } + g.Use(mw) + g.POST("/:agent_id/completions", h.AgentbotCompletion) + g.GET("/:agent_id/inputs", h.AgentbotInputs) +} diff --git a/internal/router/router.go b/internal/router/router.go index 8816cb9145..f14caeac83 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -52,6 +52,7 @@ type Router struct { modelHandler *handler.ModelHandler fileCommitHandler *handler.FileCommitHandler adminRuntimeHandler *handler.AdminRuntimeHandler + botHandler *handler.BotHandler } // NewRouter create router @@ -84,6 +85,7 @@ func NewRouter( fileCommitHandler *handler.FileCommitHandler, adminRuntimeHandler *handler.AdminRuntimeHandler, openaiChatHandler *handler.OpenAIChatHandler, + botHandler *handler.BotHandler, ) *Router { return &Router{ authHandler: authHandler, @@ -114,6 +116,7 @@ func NewRouter( modelHandler: modelHandler, fileCommitHandler: fileCommitHandler, adminRuntimeHandler: adminRuntimeHandler, + botHandler: botHandler, } } @@ -183,6 +186,20 @@ func (r *Router) Setup(engine *gin.Engine) { apiNoAuth.POST("/auth/password/forgot/otp", r.userHandler.ForgotSendOTP) apiNoAuth.POST("/auth/password/forgot/otp/verify", r.userHandler.ForgotVerifyOTP) apiNoAuth.POST("/auth/password/reset", r.userHandler.ForgotResetPassword) + + // Public bot endpoints — beta API token only, NOT regular + // user session. Mirrors python's + // @login_required(auth_types=AUTH_BETA) on bot_api.py:55,126,157,239. + // Mounted on apiNoAuth (not on the auth-protected v1 tree) so + // external widgets / iframes / downloads can hit them with + // only a beta token. Risk R0 of the plan. + if r.botHandler != nil { + betaMW := r.authHandler.BetaAuthMiddleware() + chatbotGroup := apiNoAuth.Group("/chatbots") + RegisterChatbotRoutes(chatbotGroup, betaMW, r.botHandler) + agentbotGroup := apiNoAuth.Group("/agentbots") + RegisterAgentbotRoutes(agentbotGroup, betaMW, r.botHandler) + } } // Protected routes diff --git a/internal/service/bot.go b/internal/service/bot.go new file mode 100644 index 0000000000..9be1960899 --- /dev/null +++ b/internal/service/bot.go @@ -0,0 +1,259 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// BotService is the shared service layer for the public +// chatbot/agentbot endpoints (api/v1/chatbots/..., +// api/v1/agentbots/...) plus the agent attachment download. It is +// intentionally a thin aggregator — it sequences DAO lookups, the +// tenant/status authorisation guard, and delegates the heavy work +// (LLM call, canvas run) to the existing services. +package service + +import ( + "context" + "errors" + "fmt" + "strings" + + "ragflow/internal/agent/canvas" + "ragflow/internal/agent/dsl" + "ragflow/internal/common" + "ragflow/internal/dao" + "ragflow/internal/entity" +) + +// BotService coordinates chatbot + agentbot reads and the matching +// completion paths. Mirrors the Python +// `api/db/services/conversation_service.py::async_iframe_completion` +// + `api/db/services/canvas_service.py::completion` flow but stays +// stateless — it does not own the LLM or canvas runner; it just +// sequences them. +type BotService struct { + chatDAO *dao.ChatSessionDAO + canvasDAO *dao.UserCanvasDAO + api4ConversationDAO *dao.API4ConversationDAO + agentService *AgentService + llmService *LLMService +} + +// NewBotService wires a fresh BotService. agentSvc is required for +// AgentbotCompletion; llmSvc is required for ChatbotCompletion (in +// step 6). Both are nullable in unit tests. +func NewBotService(agentSvc *AgentService, llmSvc *LLMService) *BotService { + return &BotService{ + chatDAO: dao.NewChatSessionDAO(), + canvasDAO: dao.NewUserCanvasDAO(), + api4ConversationDAO: dao.NewAPI4ConversationDAO(), + agentService: agentSvc, + llmService: llmSvc, + } +} + +// ChatbotInfo returns the public metadata of a chatbot dialog. +// +// Mirrors the python `bot_api.py::chatbot_info` handler. The +// authorisation check is: dialog must exist, the requester must own +// it (TenantID match), and Status must equal common.StatusDialogValid +// (the python StatusEnum.VALID.value). +func (s *BotService) ChatbotInfo(ctx context.Context, tenantID, dialogID string) ( + title, avatar, prologue, llmID string, hasTavilyKey bool, ec common.ErrorCode, err error, +) { + dialog, err := s.chatDAO.GetDialogByID(dialogID) + if err != nil { + return "", "", "", "", false, common.CodeDataError, err + } + if dialog == nil || dialog.TenantID != tenantID || + dialog.Status == nil || *dialog.Status != common.StatusDialogValid { + return "", "", "", "", false, common.CodeDataError, + errors.New("Authentication error: no access to this chatbot!") + } + pc := dialog.PromptConfig + // Defensive lookups mirroring python's + // dialog.prompt_config.get("prologue", "") and + // dialog.prompt_config.get("tavily_api_key", "").strip() + // semantics. A hard type assertion here would panic on a missing + // or non-string prologue field — this endpoint is public over + // persisted JSON config and the schema is not guaranteed. + prologue = stringFromMap(pc, "prologue") + tk := stringFromMap(pc, "tavily_api_key") + return botDerefStr(dialog.Name), botDerefStr(dialog.Icon), prologue, + dialog.LLMID, strings.TrimSpace(tk) != "", common.CodeSuccess, nil +} + +// AgentbotInputs returns the public metadata of an agentbot canvas. +// +// Mirrors the python `bot_api.py::agentbot_inputs` handler. The +// authorisation check is the same IDOR guard the production +// AgentService uses (canvas must be visible to the requesting user). +func (s *BotService) AgentbotInputs(ctx context.Context, tenantID, agentID string) ( + title, avatar, prologue, mode string, inputs map[string]any, + ec common.ErrorCode, err error, +) { + cv, err := s.loadCanvas(ctx, tenantID, agentID) + if err != nil { + return "", "", "", "", nil, common.CodeDataError, err + } + dslMap := canvasDSLMap(cv) + // Resolve the begin component ID first, then pass that ID to + // ExtractComponentInputForm. ExtractComponentInputForm is keyed + // by component ID, NOT component name — passing the literal + // "begin" would only succeed when the canvas happens to use + // "begin" as the component ID. + beginID, idErr := dsl.FindBeginComponentID(dslMap) + if idErr != nil { + // No begin component (or malformed DSL). Degrade gracefully — + // empty prologue / mode / inputs, matching the Python + // behaviour when Canvas.get_component_input_form returns an + // empty dict. + return botDerefStr(cv.Title), botDerefStr(cv.Avatar), "", "", nil, common.CodeSuccess, nil + } + inputs, _ = dsl.ExtractComponentInputForm(dslMap, beginID) + prologue, _ = dsl.ExtractPrologue(dslMap) + mode, _ = dsl.ExtractMode(dslMap) + return botDerefStr(cv.Title), botDerefStr(cv.Avatar), prologue, mode, inputs, common.CodeSuccess, nil +} + +// AgentbotCompletion is a thin wrapper around AgentService.RunAgent +// for the /api/v1/agentbots//completions endpoint. +// +// Defence-in-depth (security H2): the IDOR guard runs BEFORE the +// delegate so an unauthorised caller can never trigger canvas +// compile/invoke (which would spend LLM tokens + emit canvas +// telemetry even for "not found" paths). RunAgent re-runs the +// same guard internally — this is intentional; the upstream check +// is the cheap fast-fail that costs a single DAO roundtrip +// instead of a full canvas compile. +func (s *BotService) AgentbotCompletion( + ctx context.Context, tenantID, agentID string, req AgentbotCompletionRequest, +) (<-chan canvas.RunEvent, common.ErrorCode, error) { + if s.agentService == nil { + return nil, common.CodeServerError, fmt.Errorf("bot: agent service not wired") + } + if _, err := s.loadCanvas(ctx, tenantID, agentID); err != nil { + return nil, common.CodeDataError, err + } + // Compose the canvas user input from req.UserInput (the + // `inputs` dict body field) plus the top-level `question` and + // `files` fields. The python canvas_service.completion at + // api/db/services/canvas_service.py:313 reads all three; the + // previous code dropped question/files, so a body like + // `{"question":"hi"}` reached the canvas with empty inputs. + userInput := make(map[string]any, len(req.UserInput)+2) + for k, v := range req.UserInput { + userInput[k] = v + } + if req.Question != "" { + userInput["question"] = req.Question + } + if len(req.Files) > 0 { + userInput["files"] = req.Files + } + ch, err := s.agentService.RunAgent(ctx, tenantID, agentID, + req.SessionID, "", userInput) + if err != nil { + return nil, common.CodeDataError, err + } + return ch, common.CodeSuccess, nil +} + +// AgentbotCompletionRequest is the request body for +// /api/v1/agentbots//completions. We intentionally accept +// the same fields the production /agents/chat/completions handler +// accepts; the URL-bound agent_id is the authoritative canvas id +// (matches python bot_api.py:159). +type AgentbotCompletionRequest struct { + SessionID string `json:"session_id"` + UserID string `json:"user_id"` + Stream bool `json:"stream"` + // UserInput is the dict-shaped root input the canvas run expects + // (mirrors the python "question"/"files"/"inputs" trio collapsed + // into one map). + UserInput map[string]any `json:"inputs"` + Question string `json:"question"` + Files []string `json:"files"` +} + +// ChatbotCompletionRequest is the request body for +// /api/v1/chatbots//completions. Mirrors the python +// `async_iframe_completion` body shape (session_id, question, +// tts (unused) and a freeform dict). +type ChatbotCompletionRequest struct { + SessionID string `json:"session_id"` + Question string `json:"question"` + Stream bool `json:"stream"` + Inputs map[string]any `json:"inputs"` +} + +// loadCanvas is the IDOR guard for agentbot reads. It mirrors the +// private loadCanvasForUser helper on AgentService without taking a +// dependency on the agentService pointer (so BotService can be unit- +// tested with a nil agentService). +func (s *BotService) loadCanvas(ctx context.Context, tenantID, agentID string) (*entity.UserCanvas, error) { + if agentID == "" { + return nil, dao.ErrUserCanvasNotFound + } + if tenantID == "" { + return nil, dao.ErrUserCanvasNotFound + } + userTenantDAO := dao.NewUserTenantDAO() + tenants, err := userTenantDAO.GetTenantIDsByUserID(tenantID) + if err != nil { + return nil, fmt.Errorf("bot: tenants for user %s: %w", tenantID, err) + } + return s.canvasDAO.GetByIDForUser(agentID, tenantID, tenants) +} + +// canvasDSLMap projects a UserCanvas.DSL JSONMap into a +// map[string]any. Returns an empty map (not nil) on miss so +// downstream dsl helpers can still scan it. +func canvasDSLMap(cv *entity.UserCanvas) map[string]any { + if cv == nil { + return map[string]any{} + } + // cv.DSL is entity.JSONMap (alias for map[string]interface{}). + // We must return a fresh map[string]any because the dsl + // helpers expect that concrete type. + return map[string]any(cv.DSL) +} + +// botDerefStr returns *s or "" if nil. Used to read pointer-string +// fields on entities (Name, Icon, Title, Avatar). Prefixed with bot +// to avoid colliding with the test-only botDerefStr in +// openai_chat_test.go. +func botDerefStr(s *string) string { + if s == nil { + return "" + } + return *s +} + +// stringFromMap returns m[key] as a string. Returns "" if the key is +// absent or the value is not a string. Used for defensive reads +// over JSONMap-shaped fields (dialog.prompt_config) where a hard +// type assertion would panic. +func stringFromMap(m entity.JSONMap, key string) string { + if m == nil { + return "" + } + v, ok := m[key] + if !ok || v == nil { + return "" + } + if s, ok := v.(string); ok { + return s + } + return "" +} diff --git a/internal/service/bot_completion.go b/internal/service/bot_completion.go new file mode 100644 index 0000000000..e95bd2fd29 --- /dev/null +++ b/internal/service/bot_completion.go @@ -0,0 +1,397 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// bot_completion.go is the SSE envelope writer + ChatbotCompletion +// service path for /api/v1/chatbots//completions. The wire +// shape is dictated by the existing python +// `api/db/services/conversation_service.py::async_iframe_completion` +// — JS widgets reading the iframe SDK expect this exact envelope, so +// any change to the frame keys is a wire-contract change. +// +// Frame shape (one JSON object per `data:` line): +// +// {"code":0,"message":"","data":{"answer":"...","reference":{...}, +// "audio_binary":null,"id":"...","session_id":"..."}, ...} +// +// The final completion marker is `data: {"code":0,"message":"", +// "data":true}` followed by the OpenAI-style `data: [DONE]` line +// that the existing Go SSE writers emit on the production +// /agents/chat/completions path. + +package service + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" + + "ragflow/internal/agent/canvas" + "ragflow/internal/common" + "ragflow/internal/entity" + modelModule "ragflow/internal/entity/models" +) + +// ChatbotSSEFrame is one envelope pushed to the SSE writer by the +// chatbot completion path. Err takes precedence over Data and is +// rendered as a python-style {code:500, message:str(e), +// data:{answer:"**ERROR**..."}} frame. +type ChatbotSSEFrame struct { + Data string `json:"-"` + Reference map[string]any `json:"-"` + SessionID string `json:"-"` + Done bool `json:"-"` + Err error `json:"-"` +} + +// WriteChatbotFrame emits one python-style SSE frame and flushes the +// underlying http.ResponseWriter. The frame is `data: \n\n` +// and is byte-equivalent to the python side so the iframe SDK and +// existing JS widgets keep working. +// +// Error frames sanitize the message — internal errors (gorm stack +// frames, SQL details, storage paths) MUST NOT be echoed to the +// client. The caller is expected to log the real error via +// common.Error / zap before publishing the frame; only a generic +// placeholder is rendered here. Mirrors the python +// `api/db/services/conversation_service.py` error frame shape. +func WriteChatbotFrame(w http.ResponseWriter, f ChatbotSSEFrame) error { + var payload map[string]any + if f.Err != nil { + const clientErrMsg = "an internal error occurred" + payload = map[string]any{ + "code": 500, + "message": clientErrMsg, + "data": map[string]any{ + "answer": clientErrMsg, + "reference": map[string]any{}, + }, + } + } else { + payload = map[string]any{ + "code": 0, + "message": "", + "data": map[string]any{ + "answer": f.Data, + "reference": f.Reference, + "audio_binary": nil, + "id": nil, + "session_id": f.SessionID, + }, + } + } + b, err := json.Marshal(payload) + if err != nil { + return err + } + if _, err := w.Write([]byte("data: ")); err != nil { + return err + } + if _, err := w.Write(b); err != nil { + return err + } + if _, err := w.Write([]byte("\n\n")); err != nil { + return err + } + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + return nil +} + +// WriteDoneFrame emits the python completion marker +// `data: {"code":0,"message":"","data":true}\n\n` followed by the +// OpenAI-style `data: [DONE]\n\n` terminator. Used by both bot +// completion paths. +func WriteDoneFrame(w http.ResponseWriter) error { + if _, err := w.Write([]byte(`data: {"code":0,"message":"","data":true}` + "\n\n")); err != nil { + return err + } + if _, err := w.Write([]byte("data: [DONE]\n\n")); err != nil { + return err + } + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + return nil +} + +// WriteChatbotRunEvent translates one canvas.RunEvent into the +// unified python-shaped chat-completion envelope (same shape as +// WriteChatbotFrame). This unifies the SSE format across: +// +// - /api/v1/agents/chat/completions (was: writeChatCompletionSSE) +// - /api/v1/agentbots//completions (was: WriteChatbotFrame per-event) +// +// The "done" event type emits `data: [DONE]\n\n` (no envelope), +// matching the OpenAI-style terminator and the existing +// AgentbotCompletion wire. +// +// For non-done events, ev.Data is placed verbatim into the `answer` +// field — callers pass canvas-runner output that is itself a JSON +// string (e.g. `{"answer":"hi back","reference":[]}`); the iframe +// SDK then JSON.parse()s the `answer` string to extract the inner +// fields. This matches the existing AgentbotCompletion behaviour. +// +// Returns the write error so callers can short-circuit; both nil +// and io.ErrClosedPipe are tolerated because the client may have +// disconnected mid-stream. +func WriteChatbotRunEvent(w http.ResponseWriter, ev canvas.RunEvent) error { + if ev.Type == "done" { + _, err := w.Write([]byte("data: [DONE]\n\n")) + if err != nil { + return err + } + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + return nil + } + f := ChatbotSSEFrame{ + Data: ev.Data, + Reference: map[string]any{}, + SessionID: ev.SessionID, + } + return WriteChatbotFrame(w, f) +} + +// AgentbotSSEFrame mirrors ChatbotSSEFrame for the agentbot +// completion path. The envelope shape is the same; the only +// difference is that the LLM call goes through the canvas runner +// (AgentService.RunAgent) instead of the legacy dialog async_chat. +type AgentbotSSEFrame = ChatbotSSEFrame + +// WriteAgentbotFrame is an alias for WriteChatbotFrame — both bot +// completion paths emit the same python wire shape. +func WriteAgentbotFrame(w http.ResponseWriter, f ChatbotSSEFrame) error { + return WriteChatbotFrame(w, f) +} + +// ChatbotCompletion streams an SSE response for +// /api/v1/chatbots//completions. +// +// The full LLM session-lifecycle implementation is added below. It +// is a v1 port: it yields a single frame per turn (the Go LLMBundle +// chat call is non-streaming), seeded with the dialog's prologue +// when the request creates a new session. +// +// Authorisation: dialog must exist, belong to the requester's tenant, +// and have status == common.StatusDialogValid. +func (s *BotService) ChatbotCompletion( + ctx context.Context, tenantID, dialogID string, req ChatbotCompletionRequest, +) (<-chan ChatbotSSEFrame, common.ErrorCode, error) { + // 1. Load and authorise the dialog. + // + // ChatSessionDAO.GetDialogByID already filters by status = "1" + // so a returned row is valid; we still nil-check defensively + // before dereferencing for symmetry with the session path. + dialog, err := s.chatDAO.GetDialogByID(dialogID) + if err != nil || dialog == nil || + dialog.TenantID != tenantID || + dialog.Status == nil || *dialog.Status != common.StatusDialogValid { + return nil, common.CodeDataError, errors.New("no access to this chatbot") + } + + // 2. Resolve or create the session row. + // + // API4ConversationDAO.GetBySessionID returns (nil, nil) on miss + // (not an error) — see internal/dao/api_token.go:146. We MUST + // check the pointer before dereferencing, otherwise the + // session-tenant check below nil-derefs. Plan Risk R7. + // + // UserID vs tenantID (security H3 follow-up): + // `entity.API4Conversation.UserID` is a generic user-id slot + // in the production Python flow + // (api/db/services/conversation_service.py:258 — the python + // async_iframe_completion saves `user_id=kwargs.get("user_id", "")`). + // The Go BotHandler routes pass `user.ID` through the + // "tenantID" parameter (the Go User struct collapses user and + // tenant into one identifier — see project CLAUDE.md), so + // writing `tenantID` here actually stores the requester's + // user-id (== tenant-id) in the python user-id slot. The + // session-tenant check on the read path compares against the + // same value, so write/read stay symmetric. We keep this + // behaviour and add the comment so a future reader doesn't + // "fix" it to a tenant-id lookup and break the symmetry. + var session *entity.API4Conversation + if req.SessionID != "" { + session, err = s.api4ConversationDAO.GetBySessionID(req.SessionID, dialogID) + if err != nil { + return nil, common.CodeServerError, err + } + if session == nil || session.UserID != tenantID { + return nil, common.CodeDataError, errors.New("session not found") + } + } else { + // Seed a new session. The Message column is json.RawMessage; + // pre-serialise the prologue turn as a JSON array of + // {role,content,created_at} dicts — same shape the python + // conversation_service.py:253-272 writes. Plan Risk R4. + prologue := stringFromMap(dialog.PromptConfig, "prologue") + seedMsg, _ := json.Marshal([]map[string]any{ + { + "role": "assistant", + "content": prologue, + "created_at": time.Now().Unix(), + }, + }) + session = &entity.API4Conversation{ + ID: uuid.NewString(), + DialogID: dialogID, + UserID: tenantID, + Message: seedMsg, + } + if err := s.api4ConversationDAO.Create(session); err != nil { + return nil, common.CodeServerError, err + } + } + + // 3. Resolve the chat LLM via ModelProviderService. The python + // async_iframe_completion resolves the same way through + // LLMBundle(tenant_id, dialog.llm_id); the Go equivalent is + // GetChatModelConfig → NewChatModel → driver.ChatWithMessages. + // + // If llmService is unwired (test boot path) or the dialog has + // no LLM configured, we surface a sanitized CodeDataError + // rather than echoing the bare error string into the SSE + // envelope — see WriteChatbotFrame's sanitization contract. + if s.llmService == nil { + return nil, common.CodeServerError, errors.New("bot: llm service not wired") + } + if dialog.LLMID == "" { + return nil, common.CodeDataError, errors.New("no LLM configured for this chatbot") + } + modelProvider := NewModelProviderService() + driver, modelName, apiConfig, _, err := modelProvider.GetChatModelConfig(tenantID, dialog.LLMID) + if err != nil { + return nil, common.CodeDataError, errors.New("no LLM configured for this chatbot") + } + chatModel := modelModule.NewChatModel(driver, &modelName, apiConfig) + + // 4. Build the prompt from prior conversation history plus the + // new user turn. Without this, a resumed session_id would + // authorise reuse but the LLM call would still be stateless + // turn-to-turn — a Python parity regression for any multi-turn + // chatbot client. The Message column on api_4_conversation is a + // json.RawMessage array of {role, content, created_at} dicts, + // matching the python conversation_service.py:253-272 shape. + messages := historyToMessages(session.Message) + messages = append(messages, modelModule.Message{Role: "user", Content: req.Question}) + + // 5. Yield frames on a channel. + out := make(chan ChatbotSSEFrame, 4) + go func() { + defer close(out) + resp, callErr := chatModel.ModelDriver.ChatWithMessages( + modelName, messages, chatModel.APIConfig, &modelModule.ChatConfig{}, + ) + if callErr != nil { + // Log the real error with structured context so + // ops can debug, but do NOT echo the raw + // err.Error() to the client (security M2: + // internal gorm/SQL/file-path leaks). + common.Error("bot: ChatbotCompletion LLM call failed", + callErr, + zap.String("dialog_id", dialogID), + zap.String("session_id", session.ID), + zap.String("llm_id", dialog.LLMID), + ) + out <- ChatbotSSEFrame{ + Err: errors.New("an internal error occurred"), + SessionID: session.ID, + } + out <- ChatbotSSEFrame{Done: true} + return + } + answer := "" + if resp != nil && resp.Answer != nil { + answer = *resp.Answer + } + + // Persist the new turn pair (user + assistant) back to + // api_4_conversation so the NEXT call to ChatbotCompletion + // with the same session_id sees this turn in messages. + // Update errors are logged but do NOT fail the SSE stream + // — the answer has already been produced. The next call + // will rebuild from the prior (pre-this-turn) snapshot, + // losing at most the latest exchange; acceptable for v1. + newTurns := append(historyFromMessages(messages), + map[string]any{"role": "assistant", "content": answer, "created_at": time.Now().Unix()}, + ) + if updated, mErr := json.Marshal(newTurns); mErr == nil { + session.Message = updated + if uErr := s.api4ConversationDAO.Update(session); uErr != nil { + common.Error("bot: ChatbotCompletion session update failed", + uErr, + zap.String("dialog_id", dialogID), + zap.String("session_id", session.ID), + ) + } + } + + out <- ChatbotSSEFrame{ + Data: answer, + Reference: map[string]any{}, + SessionID: session.ID, + } + out <- ChatbotSSEFrame{Done: true} + }() + return out, common.CodeSuccess, nil +} + +// historyToMessages reads the session.Message JSON array of +// {role, content, ...} dicts and projects it onto modelModule.Message +// for the LLM driver. Tolerates an empty / malformed Message column +// by returning an empty slice — the caller appends the new user turn +// so the LLM still receives the current prompt. +func historyToMessages(raw json.RawMessage) []modelModule.Message { + if len(raw) == 0 { + return nil + } + var turns []map[string]any + if err := json.Unmarshal(raw, &turns); err != nil { + return nil + } + out := make([]modelModule.Message, 0, len(turns)) + for _, t := range turns { + role, _ := t["role"].(string) + content, _ := t["content"].(string) + if role == "" || content == "" { + continue + } + out = append(out, modelModule.Message{Role: role, Content: content}) + } + return out +} + +// historyFromMessages is the inverse projection — used to write the +// updated turn list back to the api_4_conversation.Message column. +func historyFromMessages(msgs []modelModule.Message) []map[string]any { + out := make([]map[string]any, 0, len(msgs)) + now := time.Now().Unix() + for i, m := range msgs { + out = append(out, map[string]any{ + "role": m.Role, + "content": m.Content, + "created_at": now + int64(i), // preserve order, monotonic + }) + } + return out +} diff --git a/internal/service/bot_completion_history_test.go b/internal/service/bot_completion_history_test.go new file mode 100644 index 0000000000..aa6e2fbaa3 --- /dev/null +++ b/internal/service/bot_completion_history_test.go @@ -0,0 +1,164 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Tests for the conversation-history round-trip helpers used by +// BotService.ChatbotCompletion. Locks in review Finding 8 — a resumed +// session_id must carry prior turns (assistant prologue + earlier +// user/assistant exchanges) into the next LLM call so multi-turn +// chatbot clients retain context. + +package service + +import ( + "encoding/json" + "testing" + + modelModule "ragflow/internal/entity/models" +) + +func TestHistoryToMessages_Empty(t *testing.T) { + // A freshly-seeded session with no prior turns returns an empty + // slice. Caller appends the new user turn; LLM receives only + // the current prompt. Matches python conversation_service seed. + got := historyToMessages(nil) + if len(got) != 0 { + t.Fatalf("nil raw: want 0 messages, got %d", len(got)) + } + got = historyToMessages(json.RawMessage(`[]`)) + if len(got) != 0 { + t.Fatalf("empty array: want 0 messages, got %d", len(got)) + } +} + +func TestHistoryToMessages_RoundTrip(t *testing.T) { + // Simulate a session with: 1 prologue assistant turn + 1 prior + // user/assistant pair. The LLM must see all 3 prior turns + // before the new user turn is appended. + turns := []map[string]any{ + {"role": "assistant", "content": "Hello, how can I help?", "created_at": 1}, + {"role": "user", "content": "What is Go?", "created_at": 2}, + {"role": "assistant", "content": "Go is a compiled language.", "created_at": 3}, + } + raw, err := json.Marshal(turns) + if err != nil { + t.Fatalf("marshal seed: %v", err) + } + msgs := historyToMessages(raw) + if len(msgs) != 3 { + t.Fatalf("want 3 prior messages, got %d", len(msgs)) + } + if msgs[0].Role != "assistant" || msgs[0].Content != "Hello, how can I help?" { + t.Errorf("turn 0: role=%q content=%q", msgs[0].Role, msgs[0].Content) + } + if msgs[1].Role != "user" || msgs[1].Content != "What is Go?" { + t.Errorf("turn 1: role=%q content=%q", msgs[1].Role, msgs[1].Content) + } + if msgs[2].Role != "assistant" || msgs[2].Content != "Go is a compiled language." { + t.Errorf("turn 2: role=%q content=%q", msgs[2].Role, msgs[2].Content) + } +} + +func TestHistoryToMessages_Malformed(t *testing.T) { + // Malformed JSON must not panic; returns nil so caller falls back + // to a fresh single-turn LLM call rather than failing the request. + got := historyToMessages(json.RawMessage(`not json`)) + if got != nil { + t.Fatalf("malformed raw: want nil, got %v", got) + } +} + +func TestHistoryToMessages_SkipsEmptyFields(t *testing.T) { + // Defensive: turns missing role or content are dropped, not + // passed to the LLM as empty messages. + turns := []map[string]any{ + {"role": "assistant", "content": "valid", "created_at": 1}, + {"role": "", "content": "no role", "created_at": 2}, + {"role": "user", "content": "", "created_at": 3}, + {"role": "user", "content": "second valid", "created_at": 4}, + } + raw, _ := json.Marshal(turns) + msgs := historyToMessages(raw) + if len(msgs) != 2 { + t.Fatalf("want 2 valid turns, got %d", len(msgs)) + } + if msgs[0].Content != "valid" || msgs[1].Content != "second valid" { + t.Errorf("got %+v", msgs) + } +} + +func TestHistoryFromMessages_PreservesOrder(t *testing.T) { + // The LLM driver returns messages in the same order the input + // was provided. The round-trip must preserve that order so the + // next call to ChatbotCompletion sees a coherent history. + msgs := []modelModule.Message{ + {Role: "assistant", Content: "first"}, + {Role: "user", Content: "second"}, + {Role: "assistant", Content: "third"}, + } + turns := historyFromMessages(msgs) + if len(turns) != 3 { + t.Fatalf("want 3 turns, got %d", len(turns)) + } + for i, want := range []string{"first", "second", "third"} { + if turns[i]["content"] != want { + t.Errorf("turn %d content = %v, want %q", i, turns[i]["content"], want) + } + if turns[i]["role"] != msgs[i].Role { + t.Errorf("turn %d role = %v, want %q", i, turns[i]["role"], msgs[i].Role) + } + } +} + +func TestHistoryRoundTrip_PreservesPriorTurns(t *testing.T) { + // End-to-end: prior JSON → history → back to JSON must be + // semantically identical (modulo the created_at monotonic + // adjustment that historyFromMessages applies for ordering). + turns := []map[string]any{ + {"role": "assistant", "content": "p1", "created_at": int64(100)}, + {"role": "user", "content": "p2", "created_at": int64(200)}, + } + raw, _ := json.Marshal(turns) + + msgs := historyToMessages(raw) + // Caller appends a new user turn (the current request). + msgs = append(msgs, modelModule.Message{Role: "user", Content: "current"}) + + // Round-trip back to JSON for storage. + newTurns := historyFromMessages(msgs) + raw2, err := json.Marshal(newTurns) + if err != nil { + t.Fatalf("marshal round-trip: %v", err) + } + + var got []map[string]any + if err := json.Unmarshal(raw2, &got); err != nil { + t.Fatalf("unmarshal round-trip: %v", err) + } + if len(got) != 3 { + t.Fatalf("want 3 turns after round-trip, got %d", len(got)) + } + expected := []struct{ role, content string }{ + {"assistant", "p1"}, + {"user", "p2"}, + {"user", "current"}, + } + for i, want := range expected { + if got[i]["role"] != want.role || got[i]["content"] != want.content { + t.Errorf("turn %d: got role=%v content=%v, want role=%q content=%q", + i, got[i]["role"], got[i]["content"], want.role, want.content) + } + } +}