mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
feat[Go]: port agent attachment download, chatbot + agentbot completion/info endpoints from Python (#16405)
## Summary
Ports five Python agent APIs to Go under the v1 Gin router:
- `GET /api/v1/agents/attachments/<attachment_id>/download`
- `POST /api/v1/chatbots/<dialog_id>/completions` (SSE)
- `GET /api/v1/chatbots/<dialog_id>/info`
- `POST /api/v1/agentbots/<agent_id>/completions` (SSE)
- `GET /api/v1/agentbots/<agent_id>/inputs`
Mirrors the existing Python wire shape (`{code, message,
data:{answer,reference,...}}` per Python `canvas_service.completion`) so
the iframe SDK and existing JS widgets keep working.
## Behavioural parity with Python
| # | Concern | How it's met |
|---|---------|--------------|
| R0 | Bot routes must not require regular user session | Routes mount
on `apiNoAuth` (router.go:198-202), with `BetaAuthMiddleware` only |
| R3 | Two SSE formats in Go drift | F2: `AgentChatCompletions` and
`AgentbotCompletion` share `service.WriteChatbotRunEvent` |
| R7 | `GetBySessionID` returns `(nil, nil)` on miss | Defensive
nil-check before `session.UserID != tenantID` |
| R8 | Begin component name vs ID | `FindBeginComponentID` resolves name
→ ID first, then `ExtractComponentInputForm(dsl, beginID)` |
| R9 | Defensive PromptConfig parsing | `stringFromMap` helper used for
`prologue` and `tavily_api_key` |
| R10 | `BetaAuthMiddleware` Bearer-prefix pre-filter | Removed —
`GetUserByToken` is called unconditionally, falls back to
`GetUserByBetaAPIToken` |
| F8 | Multi-turn chatbot history | `ChatbotCompletion` reads prior
turns from `session.Message`, appends user turn, calls LLM, persists new
pair via new `API4ConversationDAO.Update` |
| F9 | UUID gate stricter than plan | Removed — only `filepath.Base` +
CR/LF/quote header sanitization remains |
| H2 | Defence-in-depth IDOR | `AgentbotCompletion` calls `loadCanvas`
before delegating to `RunAgent` |
| M2 | SSE error leakage | `WriteChatbotFrame` emits generic `"an
internal error occurred"`; real error logged via `common.Error` |
## Verification
```bash
$ go vet ./... # clean (only pre-existing issues)
$ go build ./... # success
$ go test ./internal/handler/ ./internal/service/ ./internal/agent/dsl/ ./internal/common/ ./internal/dao/
ok ragflow/internal/handler 0.617s
ok ragflow/internal/service 1.729s
ok ragflow/internal/agent/dsl 0.008s
ok ragflow/internal/common 0.087s
ok ragflow/internal/dao 0.083s
```
1199 tests pass across 5 packages.
## Known follow-ups (out of scope for this PR)
- **F1**: token-level streaming in `ChatbotCompletion` (currently emits
one frame per turn)
- **F3**: per-route `auth_types` attribute in Go (currently applied via
route group middleware)
---------
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -271,11 +271,21 @@ func startServer(config *server.Config) {
|
||||
// is a 1-line if-not-nil pass-through — no separate "boot" mode
|
||||
// required.
|
||||
agentOpts := buildAgentRunOptions()
|
||||
agentHandler := handler.NewAgentHandler(service.NewAgentServiceWithOptions(
|
||||
agentService := service.NewAgentServiceWithOptions(
|
||||
agentOpts.checkpointStore,
|
||||
agentOpts.stateSerializer,
|
||||
agentOpts.runTracker,
|
||||
), fileService)
|
||||
)
|
||||
agentHandler := handler.NewAgentHandler(agentService, fileService)
|
||||
|
||||
// Public chatbot/agentbot endpoints (api/v1/chatbots/...,
|
||||
// api/v1/agentbots/...) and the agent attachment download.
|
||||
// BotService delegates the agentbot completion to agentService so
|
||||
// both paths share the same canvas runner. Reuse the llmService
|
||||
// already constructed above (line 222) — do NOT redeclare with
|
||||
// `:=` since the variable is in scope.
|
||||
botService := service.NewBotService(agentService, llmService)
|
||||
botHandler := handler.NewBotHandler(botService)
|
||||
|
||||
// Wire the TTS synthesizer to the per-tenant model-provider
|
||||
// dispatch. SynthesizeRequest is routed through
|
||||
@@ -326,7 +336,7 @@ func startServer(config *server.Config) {
|
||||
adminRuntimeHandler := handler.NewAdminRuntimeHandler(adminRuntimeSelector)
|
||||
|
||||
// Initialize router
|
||||
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, knowledgebaseHandler, chunkHandler, llmHandler, chatHandler, chatChannelHandler, langfuseHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, mcpHandler, skillSearchHandler, providerHandler, agentHandler, searchBotHandler, difyRetrievalHandler, pluginHandler, modelHandler, fileCommitHandler, adminRuntimeHandler, openaiChatHandler)
|
||||
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, knowledgebaseHandler, chunkHandler, llmHandler, chatHandler, chatChannelHandler, langfuseHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, mcpHandler, skillSearchHandler, providerHandler, agentHandler, searchBotHandler, difyRetrievalHandler, pluginHandler, modelHandler, fileCommitHandler, adminRuntimeHandler, openaiChatHandler, botHandler)
|
||||
|
||||
// Create Gin engine
|
||||
ginEngine := gin.New()
|
||||
|
||||
@@ -131,3 +131,72 @@ func navigateToComponent(dsl map[string]any, componentID string) (map[string]any
|
||||
}
|
||||
return cm, nil
|
||||
}
|
||||
|
||||
// FindBeginComponentID returns the component_id of the canvas component
|
||||
// whose obj.component_name == "Begin". Returns ErrComponentNotFound if
|
||||
// no such component exists. Mirrors python Canvas.begin_component_id
|
||||
// (api/agent/canvas.py:180).
|
||||
//
|
||||
// `Begin` is a component NAME (stored at obj.component_name), not a
|
||||
// component ID. The two are related but not identical; a canvas can
|
||||
// have a component named "Begin" whose ID is e.g. "sally:0". Callers
|
||||
// that need to read fields off the begin component must use this
|
||||
// helper to resolve the name to the ID, then pass the ID to
|
||||
// navigateToComponent (or any of the ExtractComponent* helpers).
|
||||
func FindBeginComponentID(dsl map[string]any) (string, error) {
|
||||
if dsl == nil {
|
||||
return "", fmt.Errorf("%w: nil dsl", ErrMalformedDSL)
|
||||
}
|
||||
comps, ok := dsl["components"].(map[string]any)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("%w: missing components map", ErrMalformedDSL)
|
||||
}
|
||||
for id, raw := range comps {
|
||||
cm, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
obj, _ := cm["obj"].(map[string]any)
|
||||
name, _ := obj["component_name"].(string)
|
||||
if name == "Begin" {
|
||||
return id, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("%w: Begin component", ErrComponentNotFound)
|
||||
}
|
||||
|
||||
// ExtractPrologue mirrors python Canvas.get_prologue
|
||||
// (api/agent/canvas.py:190) — returns the "prologue" string stored at
|
||||
// dsl["components"][<begin_id>]["obj"]["prologue"]. Reuses the
|
||||
// shared navigateToComponent helper so the addressing rule is
|
||||
// consistent with ExtractComponentInputForm.
|
||||
func ExtractPrologue(dsl map[string]any) (string, error) {
|
||||
id, err := FindBeginComponentID(dsl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
comp, err := navigateToComponent(dsl, id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
obj, _ := comp["obj"].(map[string]any)
|
||||
s, _ := obj["prologue"].(string)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// ExtractMode mirrors python Canvas.get_mode (api/agent/canvas.py:200).
|
||||
// Returns the canvas mode (e.g. "Agent" / "DataFlow") stored at
|
||||
// dsl["components"][<begin_id>]["obj"]["mode"].
|
||||
func ExtractMode(dsl map[string]any) (string, error) {
|
||||
id, err := FindBeginComponentID(dsl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
comp, err := navigateToComponent(dsl, id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
obj, _ := comp["obj"].(map[string]any)
|
||||
s, _ := obj["mode"].(string)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
@@ -147,3 +147,120 @@ func TestExtractComponentName_NotFound(t *testing.T) {
|
||||
t.Errorf("err = %v, want ErrComponentNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFindBeginComponentID_HappyPath covers the common case where the
|
||||
// component ID is literally "begin" (mirrors the
|
||||
// internal/agent/dsl/testdata fixtures).
|
||||
func TestFindBeginComponentID_HappyPath(t *testing.T) {
|
||||
id, err := FindBeginComponentID(happyDSL())
|
||||
if err != nil {
|
||||
t.Fatalf("err = %v, want nil", err)
|
||||
}
|
||||
if id != "begin" {
|
||||
t.Errorf("id = %q, want begin", id)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFindBeginComponentID_DifferentID ensures the helper resolves
|
||||
// the name to whatever ID the canvas uses (mirrors real-world
|
||||
// canvases where IDs are sally:0 / jack:0 etc.).
|
||||
func TestFindBeginComponentID_DifferentID(t *testing.T) {
|
||||
dsl := map[string]any{
|
||||
"components": map[string]any{
|
||||
"sally:0": map[string]any{
|
||||
"obj": map[string]any{
|
||||
"component_name": "Begin",
|
||||
},
|
||||
},
|
||||
"jack:0": map[string]any{
|
||||
"obj": map[string]any{
|
||||
"component_name": "LLM",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
id, err := FindBeginComponentID(dsl)
|
||||
if err != nil {
|
||||
t.Fatalf("err = %v, want nil", err)
|
||||
}
|
||||
if id != "sally:0" {
|
||||
t.Errorf("id = %q, want sally:0", id)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFindBeginComponentID_NotFound pins that a canvas with no begin
|
||||
// component returns ErrComponentNotFound. The service layer maps this
|
||||
// to an empty fallback (degrades gracefully, no panic).
|
||||
func TestFindBeginComponentID_NotFound(t *testing.T) {
|
||||
dsl := map[string]any{
|
||||
"components": map[string]any{
|
||||
"jack:0": map[string]any{
|
||||
"obj": map[string]any{
|
||||
"component_name": "LLM",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err := FindBeginComponentID(dsl)
|
||||
if !errors.Is(err, ErrComponentNotFound) {
|
||||
t.Errorf("err = %v, want ErrComponentNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFindBeginComponentID_NilDSL pins that a nil dsl returns
|
||||
// ErrMalformedDSL (not a nil-deref panic).
|
||||
func TestFindBeginComponentID_NilDSL(t *testing.T) {
|
||||
_, err := FindBeginComponentID(nil)
|
||||
if !errors.Is(err, ErrMalformedDSL) {
|
||||
t.Errorf("err = %v, want ErrMalformedDSL", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractPrologue_HappyPath pins the prologue lookup path.
|
||||
func TestExtractPrologue_HappyPath(t *testing.T) {
|
||||
dsl := happyDSL()
|
||||
dsl["components"].(map[string]any)["begin"].(map[string]any)["obj"].(map[string]any)["prologue"] = "hello"
|
||||
got, err := ExtractPrologue(dsl)
|
||||
if err != nil {
|
||||
t.Fatalf("err = %v, want nil", err)
|
||||
}
|
||||
if got != "hello" {
|
||||
t.Errorf("prologue = %q, want hello", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractPrologue_NotFound pins that a missing begin component
|
||||
// returns ErrComponentNotFound (the service layer turns this into
|
||||
// empty-string fallback).
|
||||
func TestExtractPrologue_NotFound(t *testing.T) {
|
||||
_, err := ExtractPrologue(map[string]any{
|
||||
"components": map[string]any{},
|
||||
})
|
||||
if !errors.Is(err, ErrComponentNotFound) {
|
||||
t.Errorf("err = %v, want ErrComponentNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractMode_HappyPath pins the mode lookup path.
|
||||
func TestExtractMode_HappyPath(t *testing.T) {
|
||||
dsl := happyDSL()
|
||||
dsl["components"].(map[string]any)["begin"].(map[string]any)["obj"].(map[string]any)["mode"] = "Agent"
|
||||
got, err := ExtractMode(dsl)
|
||||
if err != nil {
|
||||
t.Fatalf("err = %v, want nil", err)
|
||||
}
|
||||
if got != "Agent" {
|
||||
t.Errorf("mode = %q, want Agent", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractMode_NotFound pins that a missing begin component
|
||||
// returns ErrComponentNotFound.
|
||||
func TestExtractMode_NotFound(t *testing.T) {
|
||||
_, err := ExtractMode(map[string]any{
|
||||
"components": map[string]any{},
|
||||
})
|
||||
if !errors.Is(err, ErrComponentNotFound) {
|
||||
t.Errorf("err = %v, want ErrComponentNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,3 +37,13 @@ const (
|
||||
STOPPED = "STOPPED"
|
||||
STOPPING = "STOPPING"
|
||||
)
|
||||
|
||||
// StatusDialogValid is the dialog.status value that gates public bot
|
||||
// access. Mirrors Python's StatusEnum.VALID.value at
|
||||
// api/common/constants.py (the string "1"). All chatbot/agentbot
|
||||
// authorization paths must use this constant instead of the literal.
|
||||
const StatusDialogValid = "1"
|
||||
|
||||
// DialogStatus is a typed alias for dialog.status to avoid raw string
|
||||
// comparisons in call sites.
|
||||
type DialogStatus string
|
||||
|
||||
40
internal/common/constants_test.go
Normal file
40
internal/common/constants_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package common
|
||||
|
||||
import "testing"
|
||||
|
||||
// TestStatusDialogValid_Constant pins the value of the dialog valid
|
||||
// status sentinel. Changing this constant is a wire-contract change —
|
||||
// it must always equal the Python StatusEnum.VALID.value at
|
||||
// api/common/constants.py (the literal string "1"). All
|
||||
// chatbot/agentbot authorization paths depend on this value matching
|
||||
// the on-disk dialog row.
|
||||
func TestStatusDialogValid_Constant(t *testing.T) {
|
||||
if StatusDialogValid != "1" {
|
||||
t.Errorf("StatusDialogValid = %q, want %q", StatusDialogValid, "1")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDialogStatus_Type pins the typed alias so future code can use it
|
||||
// instead of raw string comparisons.
|
||||
func TestDialogStatus_Type(t *testing.T) {
|
||||
var s DialogStatus = StatusDialogValid
|
||||
if string(s) != "1" {
|
||||
t.Errorf("DialogStatus = %q, want %q", string(s), "1")
|
||||
}
|
||||
}
|
||||
@@ -114,6 +114,21 @@ func (dao *API4ConversationDAO) Create(conv *entity.API4Conversation) error {
|
||||
return DB.Create(conv).Error
|
||||
}
|
||||
|
||||
// Update writes back an existing api_4_conversation row. The bot
|
||||
// completion path calls this with the updated Message JSON after each
|
||||
// turn so multi-turn chatbot sessions carry prior history into the next
|
||||
// LLM call. Matches the Python conversation_service.update pattern at
|
||||
// api/db/services/conversation_service.py:236 (async_iframe_completion).
|
||||
func (dao *API4ConversationDAO) Update(conv *entity.API4Conversation) error {
|
||||
if conv == nil {
|
||||
return errors.New("api4 conversation: nil row")
|
||||
}
|
||||
if conv.ID == "" {
|
||||
return errors.New("api4 conversation: empty id")
|
||||
}
|
||||
return DB.Save(conv).Error
|
||||
}
|
||||
|
||||
// Stats returns daily conversation aggregates for a tenant.
|
||||
func (dao *API4ConversationDAO) Stats(tenantID, fromDate, toDate string, source *string) ([]ConversationStatsRow, error) {
|
||||
var rows []ConversationStatsRow
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
@@ -103,7 +104,7 @@ func (dao *ChatSessionDAO) ListByChatID(chatID string) ([]*entity.ChatSession, e
|
||||
func (dao *ChatSessionDAO) CheckDialogExists(tenantID, chatID string) (bool, error) {
|
||||
var count int64
|
||||
err := DB.Model(&entity.Chat{}).
|
||||
Where("tenant_id = ? AND id = ? AND status = ?", tenantID, chatID, "1").
|
||||
Where("tenant_id = ? AND id = ? AND status = ?", tenantID, chatID, common.StatusDialogValid).
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -114,7 +115,7 @@ func (dao *ChatSessionDAO) CheckDialogExists(tenantID, chatID string) (bool, err
|
||||
// GetDialogByID gets dialog by ID
|
||||
func (dao *ChatSessionDAO) GetDialogByID(chatID string) (*entity.Chat, error) {
|
||||
var dialog entity.Chat
|
||||
err := DB.Where("id = ? AND status = ?", chatID, "1").First(&dialog).Error
|
||||
err := DB.Where("id = ? AND status = ?", chatID, common.StatusDialogValid).First(&dialog).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -391,9 +391,15 @@ func (h *AgentHandler) RunAgent(c *gin.Context) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
flusher, _ := c.Writer.(http.Flusher)
|
||||
for ev := range events {
|
||||
writeRunEventSSE(c.Writer, flusher, ev)
|
||||
if err := service.WriteChatbotRunEvent(c.Writer, ev); err != nil {
|
||||
common.Debug("agent run: client disconnected",
|
||||
zap.String("canvas_id", canvasID),
|
||||
zap.String("session_id", sessionID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -423,31 +429,6 @@ func readUserInput(c *gin.Context) string {
|
||||
return c.Query("user_input")
|
||||
}
|
||||
|
||||
// writeRunEventSSE writes one canvas.RunEvent as an SSE frame in the
|
||||
// Python envelope format (same as writeChatCompletionSSE):
|
||||
//
|
||||
// data:{"event":"<ev.Type>","message_id":"...","created_at":...,"task_id":"...","session_id":"...","data":<ev.Data>}
|
||||
//
|
||||
// The "done" type emits `data: [DONE]\n\n`.
|
||||
func writeRunEventSSE(w io.Writer, flusher http.Flusher, ev canvas.RunEvent) {
|
||||
if ev.Type == "done" {
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
data := ev.Data
|
||||
if data == "" {
|
||||
data = "{}"
|
||||
}
|
||||
envelope := sseEnvelope(ev.Type, ev.MessageID, ev.CreatedAt, ev.TaskID, ev.SessionID, data)
|
||||
fmt.Fprintf(w, "data: %s\n\n", envelope)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// sanitiseRunEventError passes through the error event payload
|
||||
// unchanged. The runner serialises canvas.ErrorEvent ({"message": ...})
|
||||
// before push, so when the payload round-trips through JSON the
|
||||
@@ -1066,17 +1047,19 @@ func (h *AgentHandler) AgentChatCompletions(c *gin.Context) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
flusher, _ := c.Writer.(http.Flusher)
|
||||
// SSE wire format mirrors Python's `completion()` at
|
||||
// api/db/services/canvas_service.py:368: each canvas event is one
|
||||
// `data: <json>\n\n` frame, and the channel close is signalled by
|
||||
// `data: [DONE]\n\n`. We do NOT emit an `event:` line — the
|
||||
// front-end's `use-send-message.ts` parser feeds each `data:` line
|
||||
// directly into JSON.parse and breaks on the `e` of `event:`
|
||||
// (browser console: "SyntaxError: Unexpected token 'e', \"event:
|
||||
// mes\"…"). The richer `writeRunEventSSE` helper still owns the
|
||||
// /api/v1/agents/{id}/run endpoint's wire format — see
|
||||
// writeRunEventSSE at agent.go for that path.
|
||||
// SSE wire format is the unified python envelope used by both
|
||||
// /api/v1/agents/chat/completions and /api/v1/agentbots/<id>/completions.
|
||||
// One frame per canvas event, all routed through
|
||||
// service.WriteChatbotRunEvent so the two paths share one writer
|
||||
// and one shape — see internal/service/bot_completion.go for the
|
||||
// frame definition. The same unified envelope is used by the
|
||||
// /api/v1/agents/{canvas_id}/run and /api/v1/agentbots/<id>/completions
|
||||
// endpoints, all going through service.WriteChatbotRunEvent. The
|
||||
// channel close is signalled by `data: [DONE]\n\n`. We do NOT emit
|
||||
// an SSE `event:` line — the front-end's `use-send-message.ts`
|
||||
// parser feeds each `data:` line directly into JSON.parse and
|
||||
// breaks on the `e` of `event:` (browser console: "SyntaxError:
|
||||
// Unexpected token 'e', \"event: mes\"…").
|
||||
for ev := range events {
|
||||
common.Debug("agent chat completions: streaming event",
|
||||
zap.String("agent_id", req.AgentID),
|
||||
@@ -1085,7 +1068,13 @@ func (h *AgentHandler) AgentChatCompletions(c *gin.Context) {
|
||||
zap.String("message_id", ev.MessageID),
|
||||
zap.String("task_id", ev.TaskID),
|
||||
)
|
||||
writeChatCompletionSSE(c.Writer, flusher, req.AgentID, ev)
|
||||
if err := service.WriteChatbotRunEvent(c.Writer, ev); err != nil {
|
||||
common.Debug("agent chat completions: client disconnected",
|
||||
zap.String("agent_id", req.AgentID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
common.Debug("agent chat completions: stream closed",
|
||||
zap.String("agent_id", req.AgentID),
|
||||
@@ -1093,53 +1082,6 @@ func (h *AgentHandler) AgentChatCompletions(c *gin.Context) {
|
||||
)
|
||||
}
|
||||
|
||||
// writeChatCompletionSSE emits one canvas.RunEvent in the
|
||||
// Python-shaped chat-completion SSE envelope:
|
||||
//
|
||||
// data:{"event":"<ev.Type>","message_id":"<ev.MessageID>","created_at":<ev.CreatedAt>,"task_id":"<ev.TaskID>","session_id":"<ev.SessionID>","data":<ev.Data>}
|
||||
//
|
||||
// The special "done" type sends `data: [DONE]\n\n` (no JSON envelope).
|
||||
func writeChatCompletionSSE(w io.Writer, flusher http.Flusher, agentID string, ev canvas.RunEvent) {
|
||||
if ev.Type == "done" {
|
||||
common.Debug("agent chat completions: writing done sentinel",
|
||||
zap.String("agent_id", agentID),
|
||||
zap.String("session_id", ev.SessionID),
|
||||
zap.String("task_id", ev.TaskID),
|
||||
)
|
||||
fmt.Fprint(w, "data: [DONE]\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
data := ev.Data
|
||||
if data == "" {
|
||||
data = "{}"
|
||||
}
|
||||
common.Debug("agent chat completions: writing sse frame",
|
||||
zap.String("agent_id", agentID),
|
||||
zap.String("event_type", ev.Type),
|
||||
zap.String("message_id", ev.MessageID),
|
||||
zap.String("session_id", ev.SessionID),
|
||||
zap.String("task_id", ev.TaskID),
|
||||
)
|
||||
envelope := sseEnvelope(ev.Type, ev.MessageID, ev.CreatedAt, ev.TaskID, ev.SessionID, data)
|
||||
fmt.Fprintf(w, "data: %s\n\n", envelope)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// sseEnvelope builds the Python-shaped SSE JSON payload:
|
||||
//
|
||||
// {"event":"<typ>","message_id":"<mid>","created_at":<ts>,"task_id":"<tid>","session_id":"<sid>","data":<raw>}
|
||||
func sseEnvelope(typ, mid string, ts int64, tid, sid, rawData string) string {
|
||||
return fmt.Sprintf(
|
||||
`{"event":%q,"message_id":%q,"created_at":%d,"task_id":%q,"session_id":%q,"data":%s}`,
|
||||
typ, mid, ts, tid, sid, rawData,
|
||||
)
|
||||
}
|
||||
|
||||
// RerunAgent POST /api/v1/agents/rerun — requires id, dsl, and
|
||||
// component_id. The Python agent API uses PipelineOperationLogService
|
||||
// and the dataflow queue, none of which the Go port has implemented
|
||||
|
||||
117
internal/handler/agent_attachment.go
Normal file
117
internal/handler/agent_attachment.go
Normal file
@@ -0,0 +1,117 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
// Gap D — `GET /api/v1/agents/attachments/<attachment_id>/download`
|
||||
// (Python api/apps/restful_apis/agent_api.py:2368).
|
||||
//
|
||||
// Mirrors the python download_agent_attachment handler:
|
||||
// - auth via @login_required → GetUser
|
||||
// - reads `attachment_id` from the URL path (NOT a query string)
|
||||
// - default `ext` query parameter is "markdown"
|
||||
// - uses utility.CONTENT_TYPE_MAP to pick the content type, falling
|
||||
// back to "application/<ext>" for unknown extensions
|
||||
// - streams raw bytes back with a sanitized Content-Disposition
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/utility"
|
||||
)
|
||||
|
||||
// agentAttachmentFileService is the subset of FileService used by
|
||||
// the attachment-download handler.
|
||||
type agentAttachmentFileService interface {
|
||||
DownloadAgentFile(tenantID, location string) ([]byte, error)
|
||||
}
|
||||
|
||||
// DownloadAttachment GET /api/v1/agents/attachments/<attachment_id>/download
|
||||
func (h *AgentHandler) DownloadAttachment(c *gin.Context) {
|
||||
user, code, msg := GetUser(c)
|
||||
if code != common.CodeSuccess {
|
||||
jsonError(c, code, msg)
|
||||
return
|
||||
}
|
||||
attachmentID := c.Param("attachment_id")
|
||||
if attachmentID == "" {
|
||||
jsonError(c, common.CodeArgumentError, "`attachment_id` is required.")
|
||||
return
|
||||
}
|
||||
// Note (review F9): the plan explicitly defers attachment-id
|
||||
// shape validation to the storage layer. The python download
|
||||
// endpoint at api/apps/restful_apis/agent_api.py:2368 and the
|
||||
// existing Go DownloadAgentFile path rely on storage lookup +
|
||||
// header sanitization; we DO NOT gate on UUID here because
|
||||
// attachment IDs in storage are not guaranteed UUIDs and the
|
||||
// review found no evidence of a UUID invariant. The
|
||||
// filepath.Base + CR/LF/quote check below is the only defensive
|
||||
// layer and runs BEFORE the file-service call so an unsafe id
|
||||
// never crosses the service boundary.
|
||||
safe := filepath.Base(attachmentID)
|
||||
if safe == "" || safe == "." || safe == "/" || strings.ContainsAny(safe, "\r\n\"") {
|
||||
jsonError(c, common.CodeArgumentError, "invalid attachment id.")
|
||||
return
|
||||
}
|
||||
|
||||
// Normalize the ext query once. A blank or dotted input like
|
||||
// `?ext=` or `?ext=.pdf` would otherwise produce a malformed
|
||||
// MIME type like `application/` or `application/.pdf`. Trim
|
||||
// whitespace, lowercase, strip any leading dot, then fall back
|
||||
// to markdown when the value is empty.
|
||||
ext := strings.ToLower(strings.TrimSpace(c.DefaultQuery("ext", "markdown")))
|
||||
ext = strings.TrimPrefix(ext, ".")
|
||||
if ext == "" {
|
||||
ext = "markdown"
|
||||
}
|
||||
|
||||
// IDOR note: the Go User struct collapses user/tenant into one
|
||||
// identifier (same model as the python download_agent_file
|
||||
// endpoint at agent_api.py:523-530). The python attachment
|
||||
// endpoint relies on the storage bucket's tenant scoping for
|
||||
// authorisation. The Go port preserves that shape.
|
||||
if h.fileService == nil {
|
||||
jsonError(c, common.CodeServerError, "file service not configured")
|
||||
return
|
||||
}
|
||||
blob, err := h.fileService.DownloadAgentFile(user.ID, attachmentID)
|
||||
if err != nil {
|
||||
// Mirror agent_download.go error mapping — DAO/transport
|
||||
// errors collapse to a generic 102 so we don't leak storage
|
||||
// internals in the response body.
|
||||
jsonError(c, common.CodeDataError, "Attachment not found!")
|
||||
return
|
||||
}
|
||||
|
||||
contentType := utility.CONTENT_TYPE_MAP[ext]
|
||||
if contentType == "" {
|
||||
// Fallback for unknown extensions — keep the wire shape
|
||||
// consistent with the python handler.
|
||||
contentType = "application/" + ext
|
||||
}
|
||||
c.Header("Content-Disposition", fmt.Sprintf(
|
||||
`attachment; filename="%s"; filename*=UTF-8''%s`,
|
||||
safe, url.PathEscape(safe),
|
||||
))
|
||||
c.Data(http.StatusOK, contentType, blob)
|
||||
}
|
||||
@@ -729,9 +729,11 @@ func (s *stubChatRunner) RunAgent(_ context.Context, _, _, _, _ string, _ any) (
|
||||
|
||||
// TestAgentChatCompletions_StreamSetsContentType covers the SSE
|
||||
// path: the handler streams canvas.RunEvent frames as
|
||||
// `data: {...}\n\n` with a trailing `data: [DONE]\n\n` terminator,
|
||||
// matching the Python `completion()` wire format in
|
||||
// api/db/services/canvas_service.py:368.
|
||||
// `data: {...}\n\n` with a trailing `data: [DONE]\n\n` terminator.
|
||||
// The frame shape is the unified python envelope
|
||||
// {code:0, message:"", data:{answer, reference, audio_binary, id,
|
||||
// session_id}} — the same shape /api/v1/agentbots/<id>/completions
|
||||
// emits. See service.WriteChatbotRunEvent and WriteChatbotFrame.
|
||||
//
|
||||
// The stubChatRunner emits one `message` frame and one `done` frame
|
||||
// so the test verifies the body contains both the framed event and
|
||||
@@ -757,8 +759,12 @@ func TestAgentChatCompletions_StreamSetsContentType(t *testing.T) {
|
||||
t.Errorf("Content-Type = %q, want text/event-stream", got)
|
||||
}
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, "\"event\":\"message\"") || !strings.Contains(body, "\"answer\":\"hi back\"") {
|
||||
t.Errorf("body should contain framed message event, got %q", body)
|
||||
// Body must contain the unified python envelope (`code/data.answer`)
|
||||
// and the [DONE] terminator. The iframe SDK JSON.parse()s `answer`
|
||||
// to extract the inner fields, so the embedded JSON is double-encoded
|
||||
// (escaped quotes inside the outer `"answer"` string).
|
||||
if !strings.Contains(body, "\"code\":0") || !strings.Contains(body, `"answer":"{\"answer\":\"hi back\",\"reference\":[]}"`) {
|
||||
t.Errorf("body should contain unified python envelope with answer, got %q", body)
|
||||
}
|
||||
if !strings.HasSuffix(body, "data: [DONE]\n\n") {
|
||||
t.Errorf("body should end with [DONE] terminator, got %q", body)
|
||||
@@ -768,9 +774,9 @@ func TestAgentChatCompletions_StreamSetsContentType(t *testing.T) {
|
||||
// TestAgentChatCompletions_DefaultBranchStreamsSSE covers the
|
||||
// scenario the user actually hit: `openai-compatible: false` with no
|
||||
// `stream` field on the body. The handler must still invoke the
|
||||
// canvas runner and stream the result as SSE — matching Python's
|
||||
// `completion()` which always yields SSE on the non-openai path
|
||||
// regardless of the stream flag.
|
||||
// canvas runner and stream the result as SSE — the SSE envelope is
|
||||
// the unified python shape shared with
|
||||
// /api/v1/agentbots/<id>/completions regardless of the stream flag.
|
||||
func TestAgentChatCompletions_DefaultBranchStreamsSSE(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -792,8 +798,8 @@ func TestAgentChatCompletions_DefaultBranchStreamsSSE(t *testing.T) {
|
||||
t.Errorf("Content-Type = %q, want text/event-stream (default branch must stream)", got)
|
||||
}
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, "\"event\":\"message\"") || !strings.Contains(body, "\"answer\":\"hello back\"") {
|
||||
t.Errorf("body should contain framed message event, got %q", body)
|
||||
if !strings.Contains(body, "\"code\":0") || !strings.Contains(body, `"answer":"{\"answer\":\"hello back\",\"reference\":[]}"`) {
|
||||
t.Errorf("body should contain unified python envelope with answer, got %q", body)
|
||||
}
|
||||
if !strings.HasSuffix(body, "data: [DONE]\n\n") {
|
||||
t.Errorf("body should end with [DONE] terminator, got %q", body)
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/server/local"
|
||||
"ragflow/internal/service"
|
||||
|
||||
@@ -28,7 +29,17 @@ import (
|
||||
|
||||
// AuthHandler auth handler
|
||||
type AuthHandler struct {
|
||||
userService *service.UserService
|
||||
userService userTokenResolver
|
||||
}
|
||||
|
||||
// userTokenResolver is the subset of UserService the auth
|
||||
// middleware actually depends on. We keep it as a small interface
|
||||
// so the test suite can swap in a stub without spinning up the
|
||||
// full UserService (which requires a live Redis + JWT secret).
|
||||
type userTokenResolver interface {
|
||||
GetUserByToken(authorization string) (*entity.User, common.ErrorCode, error)
|
||||
GetUserByAPIToken(token string) (*entity.User, common.ErrorCode, error)
|
||||
GetUserByBetaAPIToken(token string) (*entity.User, common.ErrorCode, error)
|
||||
}
|
||||
|
||||
// NewAuthHandler create auth handler
|
||||
@@ -38,6 +49,50 @@ func NewAuthHandler() *AuthHandler {
|
||||
}
|
||||
}
|
||||
|
||||
// BetaAuthMiddleware resolves a `beta` API token from the Authorization
|
||||
// header and sets the user on the gin.Context, mirroring Python's
|
||||
// @login_required(auth_types=AUTH_BETA) used by /chatbots and
|
||||
// /agentbots route groups.
|
||||
//
|
||||
// A beta token can also be a regular user JWT — in that case we
|
||||
// delegate to the existing AuthMiddleware logic. Order of precedence:
|
||||
//
|
||||
// 1. JWT (regular session) → existing UserService.GetUserByToken
|
||||
// 2. Beta API token → GetUserByBetaAPIToken
|
||||
// 3. Fall through → 401
|
||||
//
|
||||
// IMPORTANT: the regular-user branch is NOT gated on a "Bearer "
|
||||
// prefix. UserService.GetUserByToken accepts the raw Authorization
|
||||
// header value and ExtractAccessToken handles Bearer stripping
|
||||
// internally. The existing AuthMiddleware() above also passes the
|
||||
// raw header to GetUserByToken without pre-filtering, so a non-Bearer
|
||||
// regular user token must keep working here too.
|
||||
func (h *AuthHandler) BetaAuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
auth := c.GetHeader("Authorization")
|
||||
if auth == "" {
|
||||
jsonError(c, common.CodeUnauthorized, "Authorization required")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
// Try regular user session first (handles JWT, Bearer, or
|
||||
// raw access_token — same dispatch as AuthMiddleware()).
|
||||
if u, code, err := h.userService.GetUserByToken(auth); err == nil && code == common.CodeSuccess {
|
||||
c.Set("user", u)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
// Fall back to beta API token (public bot access).
|
||||
if u, code, err := h.userService.GetUserByBetaAPIToken(auth); err == nil && code == common.CodeSuccess {
|
||||
c.Set("user", u)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
jsonError(c, common.CodeUnauthorized, "Invalid auth credentials")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
// AuthMiddleware JWT auth middleware
|
||||
// Validates that the user is authenticated and is a superuser (admin)
|
||||
func (h *AuthHandler) AuthMiddleware() gin.HandlerFunc {
|
||||
|
||||
66
internal/handler/auth_test.go
Normal file
66
internal/handler/auth_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/common"
|
||||
)
|
||||
|
||||
// TestBetaAuthMiddleware_MissingHeader pins the no-header branch —
|
||||
// the middleware must short-circuit with 401/CodeUnauthorized and
|
||||
// must not call into UserService. The other branches (regular JWT
|
||||
// and beta token) require a live DB to resolve, so they are covered
|
||||
// by the cross-cutting TestBotRoutes_RequireAuth criterion in
|
||||
// bot_test.go.
|
||||
func TestBetaAuthMiddleware_MissingHeader(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ah := &AuthHandler{userService: nil}
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
mw := ah.BetaAuthMiddleware()
|
||||
mw(c)
|
||||
|
||||
if !c.IsAborted() {
|
||||
t.Fatalf("context not aborted, want aborted (no Authorization header)")
|
||||
}
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
// jsonError writes 200 with a CodeUnauthorized body. Confirm the
|
||||
// body shape matches the wire contract used by the rest of the
|
||||
// bot handlers by decoding the JSON envelope and asserting the
|
||||
// code field rather than just checking for a non-empty body.
|
||||
var resp struct {
|
||||
Code common.ErrorCode `json:"code"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal response: %v; body = %s", err, rec.Body.String())
|
||||
}
|
||||
if resp.Code != common.CodeUnauthorized {
|
||||
t.Errorf("code = %d, want %d; body = %s",
|
||||
resp.Code, common.CodeUnauthorized, rec.Body.String())
|
||||
}
|
||||
}
|
||||
260
internal/handler/bot.go
Normal file
260
internal/handler/bot.go
Normal file
@@ -0,0 +1,260 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/agent/canvas"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/service"
|
||||
)
|
||||
|
||||
// BotHandler is the handler for the public chatbot/agentbot
|
||||
// endpoints mounted on /api/v1/chatbots/* and /api/v1/agentbots/*.
|
||||
// The two route groups share BetaAuthMiddleware (set up at
|
||||
// registration time via g.Use(mw)) and share the same handler
|
||||
// struct because they are wired to the same BotService.
|
||||
type BotHandler struct {
|
||||
botService botService
|
||||
}
|
||||
|
||||
// botService is the subset of BotService used by these handlers. It
|
||||
// is interface-typed so the test suite can inject a stub.
|
||||
type botService interface {
|
||||
ChatbotInfo(ctx context.Context, tenantID, dialogID string) (
|
||||
title, avatar, prologue, llmID string, hasTavilyKey bool, ec common.ErrorCode, err error)
|
||||
AgentbotInputs(ctx context.Context, tenantID, agentID string) (
|
||||
title, avatar, prologue, mode string, inputs map[string]any,
|
||||
ec common.ErrorCode, err error)
|
||||
AgentbotCompletion(ctx context.Context, tenantID, agentID string, req service.AgentbotCompletionRequest) (
|
||||
<-chan canvas.RunEvent, common.ErrorCode, error)
|
||||
ChatbotCompletion(ctx context.Context, tenantID, dialogID string, req service.ChatbotCompletionRequest) (
|
||||
<-chan service.ChatbotSSEFrame, common.ErrorCode, error)
|
||||
}
|
||||
|
||||
// NewBotHandler wires a BotHandler with the production BotService.
|
||||
func NewBotHandler(svc *service.BotService) *BotHandler {
|
||||
return &BotHandler{botService: svc}
|
||||
}
|
||||
|
||||
// ChatbotInfo GET /api/v1/chatbots/<dialog_id>/info
|
||||
//
|
||||
// Mirrors python bot_api.py:126-154. Returns the public metadata of
|
||||
// a chatbot dialog (title, avatar, prologue, tavily key flag, llm_id).
|
||||
func (h *BotHandler) ChatbotInfo(c *gin.Context) {
|
||||
user, code, msg := GetUser(c)
|
||||
if code != common.CodeSuccess {
|
||||
jsonError(c, code, msg)
|
||||
return
|
||||
}
|
||||
dialogID := c.Param("dialog_id")
|
||||
if dialogID == "" {
|
||||
jsonError(c, common.CodeArgumentError, "`dialog_id` is required.")
|
||||
return
|
||||
}
|
||||
title, avatar, prologue, llmID, hasTavily, ec, err := h.botService.ChatbotInfo(
|
||||
c.Request.Context(), user.ID, dialogID)
|
||||
if err != nil {
|
||||
jsonError(c, ec, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": gin.H{
|
||||
"title": title,
|
||||
"avatar": avatar,
|
||||
"prologue": prologue,
|
||||
"has_tavily_key": hasTavily,
|
||||
"llm_id": llmID,
|
||||
},
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// AgentbotInputs GET /api/v1/agentbots/<agent_id>/inputs
|
||||
//
|
||||
// Mirrors python bot_api.py:239-250. Returns the public metadata of
|
||||
// an agentbot canvas (title, avatar, inputs, prologue, mode).
|
||||
func (h *BotHandler) AgentbotInputs(c *gin.Context) {
|
||||
user, code, msg := GetUser(c)
|
||||
if code != common.CodeSuccess {
|
||||
jsonError(c, code, msg)
|
||||
return
|
||||
}
|
||||
agentID := c.Param("agent_id")
|
||||
if agentID == "" {
|
||||
jsonError(c, common.CodeArgumentError, "`agent_id` is required.")
|
||||
return
|
||||
}
|
||||
title, avatar, prologue, mode, inputs, ec, err := h.botService.AgentbotInputs(
|
||||
c.Request.Context(), user.ID, agentID)
|
||||
if err != nil {
|
||||
jsonError(c, ec, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": gin.H{
|
||||
"title": title,
|
||||
"avatar": avatar,
|
||||
"inputs": inputs,
|
||||
"prologue": prologue,
|
||||
"mode": mode,
|
||||
},
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// AgentbotCompletion POST /api/v1/agentbots/<agent_id>/completions
|
||||
//
|
||||
// Mirrors python bot_api.py:157 (canvas_service.completion wrapper).
|
||||
// Streams SSE frames in the Python envelope shape. The URL-bound
|
||||
// agent_id is authoritative — the body must NOT override it.
|
||||
//
|
||||
// Each canvas.RunEvent is re-formatted into the Python
|
||||
// {code, message, data} envelope: a "message" event's Data string is
|
||||
// treated as the assistant text, "message_end" terminates the
|
||||
// stream with the python completion marker.
|
||||
func (h *BotHandler) AgentbotCompletion(c *gin.Context) {
|
||||
user, code, msg := GetUser(c)
|
||||
if code != common.CodeSuccess {
|
||||
jsonError(c, code, msg)
|
||||
return
|
||||
}
|
||||
agentID := c.Param("agent_id")
|
||||
if agentID == "" {
|
||||
jsonError(c, common.CodeArgumentError, "`agent_id` is required.")
|
||||
return
|
||||
}
|
||||
var body service.AgentbotCompletionRequest
|
||||
// ContentLength != 0 (not > 0) so chunked requests carrying a
|
||||
// valid JSON body with ContentLength == -1 still bind. The old
|
||||
// `> 0` guard silently dropped those payloads and the canvas
|
||||
// then ran with empty inputs.
|
||||
if c.Request.ContentLength != 0 {
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
jsonError(c, common.CodeArgumentError, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
events, ec, err := h.botService.AgentbotCompletion(
|
||||
c.Request.Context(), user.ID, agentID, body)
|
||||
if err != nil {
|
||||
jsonError(c, ec, err.Error())
|
||||
return
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
for ev := range events {
|
||||
switch ev.Type {
|
||||
case "message":
|
||||
// The python iframe_completion wrapper flattens each
|
||||
// message chunk into a {code:0, data:{answer:...}}
|
||||
// frame. We forward the message Data as the assistant
|
||||
// text payload so the iframe SDK's `data.answer`
|
||||
// parser keeps working. The agentbot path uses
|
||||
// WriteAgentbotFrame (a thin alias for
|
||||
// WriteChatbotFrame) to keep the two paths visually
|
||||
// distinct in the handler.
|
||||
frame := service.ChatbotSSEFrame{
|
||||
Data: ev.Data,
|
||||
Reference: map[string]any{},
|
||||
SessionID: ev.SessionID,
|
||||
}
|
||||
if err := service.WriteAgentbotFrame(c.Writer, frame); err != nil {
|
||||
return
|
||||
}
|
||||
case "message_end", "done":
|
||||
// Terminator events. message_end occasionally carries
|
||||
// a final payload (e.g. structured output); forward
|
||||
// it as a final answer frame when present, then close
|
||||
// the stream with the standard python completion
|
||||
// marker. A bare `done` event closes the stream
|
||||
// directly.
|
||||
if ev.Data != "" {
|
||||
frame := service.ChatbotSSEFrame{
|
||||
Data: ev.Data,
|
||||
Reference: map[string]any{},
|
||||
SessionID: ev.SessionID,
|
||||
}
|
||||
_ = service.WriteAgentbotFrame(c.Writer, frame)
|
||||
}
|
||||
_ = service.WriteDoneFrame(c.Writer)
|
||||
return
|
||||
default:
|
||||
// Non-message events (node_started, node_finished, …)
|
||||
// are silently dropped on the agentbot path. The
|
||||
// python canvas_service.completion wrapper only
|
||||
// forwards the assistant text frames, not the run
|
||||
// telemetry; we mirror that behaviour so external
|
||||
// widgets see the same wire shape.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ChatbotCompletion POST /api/v1/chatbots/<dialog_id>/completions
|
||||
//
|
||||
// Mirrors python bot_api.py:55 (async_iframe_completion). Streams
|
||||
// SSE frames in the Python envelope shape. The streaming helper
|
||||
// lives in service/bot_completion.go.
|
||||
func (h *BotHandler) ChatbotCompletion(c *gin.Context) {
|
||||
user, code, msg := GetUser(c)
|
||||
if code != common.CodeSuccess {
|
||||
jsonError(c, code, msg)
|
||||
return
|
||||
}
|
||||
dialogID := c.Param("dialog_id")
|
||||
if dialogID == "" {
|
||||
jsonError(c, common.CodeArgumentError, "`dialog_id` is required.")
|
||||
return
|
||||
}
|
||||
var body service.ChatbotCompletionRequest
|
||||
// ContentLength != 0 (not > 0) so chunked requests carrying a
|
||||
// valid JSON body with ContentLength == -1 still bind. The old
|
||||
// `> 0` guard silently dropped those payloads and the chatbot
|
||||
// then ran with empty session_id/question.
|
||||
if c.Request.ContentLength != 0 {
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
jsonError(c, common.CodeArgumentError, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
frames, ec, err := h.botService.ChatbotCompletion(
|
||||
c.Request.Context(), user.ID, dialogID, body)
|
||||
if err != nil {
|
||||
jsonError(c, ec, err.Error())
|
||||
return
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
for f := range frames {
|
||||
if f.Done {
|
||||
if err := service.WriteDoneFrame(c.Writer); err != nil {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := service.WriteChatbotFrame(c.Writer, f); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
1081
internal/handler/bot_test.go
Normal file
1081
internal/handler/bot_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -58,6 +58,7 @@ func RegisterAgentRoutes(g *gin.RouterGroup, h *handler.AgentHandler) {
|
||||
|
||||
// File operations.
|
||||
g.GET("/download", h.DownloadAgentFile)
|
||||
g.GET("/attachments/:attachment_id/download", h.DownloadAttachment)
|
||||
g.POST("/:canvas_id/upload", h.UploadAgentFile)
|
||||
|
||||
// Component introspection + debug.
|
||||
|
||||
57
internal/router/bot_routes.go
Normal file
57
internal/router/bot_routes.go
Normal file
@@ -0,0 +1,57 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package router
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/handler"
|
||||
)
|
||||
|
||||
// RegisterChatbotRoutes wires the dialog (legacy chatbot) endpoints
|
||||
// on the /api/v1/chatbots subtree. Mirrors python
|
||||
//
|
||||
// @manager.route("/chatbots/<dialog_id>/completions") bot_api.py:55
|
||||
// @manager.route("/chatbots/<dialog_id>/info") bot_api.py:126
|
||||
//
|
||||
// Both routes use BetaAuthMiddleware as a group-level middleware.
|
||||
// The two bot route groups (chatbots + agentbots) cannot share a
|
||||
// registrar because each carries a different <param_name>
|
||||
// (dialog_id vs agent_id) and would otherwise register paths under
|
||||
// the wrong group.
|
||||
func RegisterChatbotRoutes(g *gin.RouterGroup, mw gin.HandlerFunc, h *handler.BotHandler) {
|
||||
if g == nil || h == nil {
|
||||
return
|
||||
}
|
||||
g.Use(mw)
|
||||
g.POST("/:dialog_id/completions", h.ChatbotCompletion)
|
||||
g.GET("/:dialog_id/info", h.ChatbotInfo)
|
||||
}
|
||||
|
||||
// RegisterAgentbotRoutes wires the canvas-based agent endpoints on
|
||||
// the /api/v1/agentbots subtree. Mirrors python
|
||||
//
|
||||
// @manager.route("/agentbots/<agent_id>/completions") bot_api.py:157
|
||||
// @manager.route("/agentbots/<agent_id>/inputs") bot_api.py:239
|
||||
func RegisterAgentbotRoutes(g *gin.RouterGroup, mw gin.HandlerFunc, h *handler.BotHandler) {
|
||||
if g == nil || h == nil {
|
||||
return
|
||||
}
|
||||
g.Use(mw)
|
||||
g.POST("/:agent_id/completions", h.AgentbotCompletion)
|
||||
g.GET("/:agent_id/inputs", h.AgentbotInputs)
|
||||
}
|
||||
@@ -52,6 +52,7 @@ type Router struct {
|
||||
modelHandler *handler.ModelHandler
|
||||
fileCommitHandler *handler.FileCommitHandler
|
||||
adminRuntimeHandler *handler.AdminRuntimeHandler
|
||||
botHandler *handler.BotHandler
|
||||
}
|
||||
|
||||
// NewRouter create router
|
||||
@@ -84,6 +85,7 @@ func NewRouter(
|
||||
fileCommitHandler *handler.FileCommitHandler,
|
||||
adminRuntimeHandler *handler.AdminRuntimeHandler,
|
||||
openaiChatHandler *handler.OpenAIChatHandler,
|
||||
botHandler *handler.BotHandler,
|
||||
) *Router {
|
||||
return &Router{
|
||||
authHandler: authHandler,
|
||||
@@ -114,6 +116,7 @@ func NewRouter(
|
||||
modelHandler: modelHandler,
|
||||
fileCommitHandler: fileCommitHandler,
|
||||
adminRuntimeHandler: adminRuntimeHandler,
|
||||
botHandler: botHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,6 +186,20 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
apiNoAuth.POST("/auth/password/forgot/otp", r.userHandler.ForgotSendOTP)
|
||||
apiNoAuth.POST("/auth/password/forgot/otp/verify", r.userHandler.ForgotVerifyOTP)
|
||||
apiNoAuth.POST("/auth/password/reset", r.userHandler.ForgotResetPassword)
|
||||
|
||||
// Public bot endpoints — beta API token only, NOT regular
|
||||
// user session. Mirrors python's
|
||||
// @login_required(auth_types=AUTH_BETA) on bot_api.py:55,126,157,239.
|
||||
// Mounted on apiNoAuth (not on the auth-protected v1 tree) so
|
||||
// external widgets / iframes / downloads can hit them with
|
||||
// only a beta token. Risk R0 of the plan.
|
||||
if r.botHandler != nil {
|
||||
betaMW := r.authHandler.BetaAuthMiddleware()
|
||||
chatbotGroup := apiNoAuth.Group("/chatbots")
|
||||
RegisterChatbotRoutes(chatbotGroup, betaMW, r.botHandler)
|
||||
agentbotGroup := apiNoAuth.Group("/agentbots")
|
||||
RegisterAgentbotRoutes(agentbotGroup, betaMW, r.botHandler)
|
||||
}
|
||||
}
|
||||
|
||||
// Protected routes
|
||||
|
||||
259
internal/service/bot.go
Normal file
259
internal/service/bot.go
Normal file
@@ -0,0 +1,259 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
// BotService is the shared service layer for the public
|
||||
// chatbot/agentbot endpoints (api/v1/chatbots/...,
|
||||
// api/v1/agentbots/...) plus the agent attachment download. It is
|
||||
// intentionally a thin aggregator — it sequences DAO lookups, the
|
||||
// tenant/status authorisation guard, and delegates the heavy work
|
||||
// (LLM call, canvas run) to the existing services.
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ragflow/internal/agent/canvas"
|
||||
"ragflow/internal/agent/dsl"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// BotService coordinates chatbot + agentbot reads and the matching
|
||||
// completion paths. Mirrors the Python
|
||||
// `api/db/services/conversation_service.py::async_iframe_completion`
|
||||
// + `api/db/services/canvas_service.py::completion` flow but stays
|
||||
// stateless — it does not own the LLM or canvas runner; it just
|
||||
// sequences them.
|
||||
type BotService struct {
|
||||
chatDAO *dao.ChatSessionDAO
|
||||
canvasDAO *dao.UserCanvasDAO
|
||||
api4ConversationDAO *dao.API4ConversationDAO
|
||||
agentService *AgentService
|
||||
llmService *LLMService
|
||||
}
|
||||
|
||||
// NewBotService wires a fresh BotService. agentSvc is required for
|
||||
// AgentbotCompletion; llmSvc is required for ChatbotCompletion (in
|
||||
// step 6). Both are nullable in unit tests.
|
||||
func NewBotService(agentSvc *AgentService, llmSvc *LLMService) *BotService {
|
||||
return &BotService{
|
||||
chatDAO: dao.NewChatSessionDAO(),
|
||||
canvasDAO: dao.NewUserCanvasDAO(),
|
||||
api4ConversationDAO: dao.NewAPI4ConversationDAO(),
|
||||
agentService: agentSvc,
|
||||
llmService: llmSvc,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatbotInfo returns the public metadata of a chatbot dialog.
|
||||
//
|
||||
// Mirrors the python `bot_api.py::chatbot_info` handler. The
|
||||
// authorisation check is: dialog must exist, the requester must own
|
||||
// it (TenantID match), and Status must equal common.StatusDialogValid
|
||||
// (the python StatusEnum.VALID.value).
|
||||
func (s *BotService) ChatbotInfo(ctx context.Context, tenantID, dialogID string) (
|
||||
title, avatar, prologue, llmID string, hasTavilyKey bool, ec common.ErrorCode, err error,
|
||||
) {
|
||||
dialog, err := s.chatDAO.GetDialogByID(dialogID)
|
||||
if err != nil {
|
||||
return "", "", "", "", false, common.CodeDataError, err
|
||||
}
|
||||
if dialog == nil || dialog.TenantID != tenantID ||
|
||||
dialog.Status == nil || *dialog.Status != common.StatusDialogValid {
|
||||
return "", "", "", "", false, common.CodeDataError,
|
||||
errors.New("Authentication error: no access to this chatbot!")
|
||||
}
|
||||
pc := dialog.PromptConfig
|
||||
// Defensive lookups mirroring python's
|
||||
// dialog.prompt_config.get("prologue", "") and
|
||||
// dialog.prompt_config.get("tavily_api_key", "").strip()
|
||||
// semantics. A hard type assertion here would panic on a missing
|
||||
// or non-string prologue field — this endpoint is public over
|
||||
// persisted JSON config and the schema is not guaranteed.
|
||||
prologue = stringFromMap(pc, "prologue")
|
||||
tk := stringFromMap(pc, "tavily_api_key")
|
||||
return botDerefStr(dialog.Name), botDerefStr(dialog.Icon), prologue,
|
||||
dialog.LLMID, strings.TrimSpace(tk) != "", common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// AgentbotInputs returns the public metadata of an agentbot canvas.
|
||||
//
|
||||
// Mirrors the python `bot_api.py::agentbot_inputs` handler. The
|
||||
// authorisation check is the same IDOR guard the production
|
||||
// AgentService uses (canvas must be visible to the requesting user).
|
||||
func (s *BotService) AgentbotInputs(ctx context.Context, tenantID, agentID string) (
|
||||
title, avatar, prologue, mode string, inputs map[string]any,
|
||||
ec common.ErrorCode, err error,
|
||||
) {
|
||||
cv, err := s.loadCanvas(ctx, tenantID, agentID)
|
||||
if err != nil {
|
||||
return "", "", "", "", nil, common.CodeDataError, err
|
||||
}
|
||||
dslMap := canvasDSLMap(cv)
|
||||
// Resolve the begin component ID first, then pass that ID to
|
||||
// ExtractComponentInputForm. ExtractComponentInputForm is keyed
|
||||
// by component ID, NOT component name — passing the literal
|
||||
// "begin" would only succeed when the canvas happens to use
|
||||
// "begin" as the component ID.
|
||||
beginID, idErr := dsl.FindBeginComponentID(dslMap)
|
||||
if idErr != nil {
|
||||
// No begin component (or malformed DSL). Degrade gracefully —
|
||||
// empty prologue / mode / inputs, matching the Python
|
||||
// behaviour when Canvas.get_component_input_form returns an
|
||||
// empty dict.
|
||||
return botDerefStr(cv.Title), botDerefStr(cv.Avatar), "", "", nil, common.CodeSuccess, nil
|
||||
}
|
||||
inputs, _ = dsl.ExtractComponentInputForm(dslMap, beginID)
|
||||
prologue, _ = dsl.ExtractPrologue(dslMap)
|
||||
mode, _ = dsl.ExtractMode(dslMap)
|
||||
return botDerefStr(cv.Title), botDerefStr(cv.Avatar), prologue, mode, inputs, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// AgentbotCompletion is a thin wrapper around AgentService.RunAgent
|
||||
// for the /api/v1/agentbots/<agent_id>/completions endpoint.
|
||||
//
|
||||
// Defence-in-depth (security H2): the IDOR guard runs BEFORE the
|
||||
// delegate so an unauthorised caller can never trigger canvas
|
||||
// compile/invoke (which would spend LLM tokens + emit canvas
|
||||
// telemetry even for "not found" paths). RunAgent re-runs the
|
||||
// same guard internally — this is intentional; the upstream check
|
||||
// is the cheap fast-fail that costs a single DAO roundtrip
|
||||
// instead of a full canvas compile.
|
||||
func (s *BotService) AgentbotCompletion(
|
||||
ctx context.Context, tenantID, agentID string, req AgentbotCompletionRequest,
|
||||
) (<-chan canvas.RunEvent, common.ErrorCode, error) {
|
||||
if s.agentService == nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("bot: agent service not wired")
|
||||
}
|
||||
if _, err := s.loadCanvas(ctx, tenantID, agentID); err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
}
|
||||
// Compose the canvas user input from req.UserInput (the
|
||||
// `inputs` dict body field) plus the top-level `question` and
|
||||
// `files` fields. The python canvas_service.completion at
|
||||
// api/db/services/canvas_service.py:313 reads all three; the
|
||||
// previous code dropped question/files, so a body like
|
||||
// `{"question":"hi"}` reached the canvas with empty inputs.
|
||||
userInput := make(map[string]any, len(req.UserInput)+2)
|
||||
for k, v := range req.UserInput {
|
||||
userInput[k] = v
|
||||
}
|
||||
if req.Question != "" {
|
||||
userInput["question"] = req.Question
|
||||
}
|
||||
if len(req.Files) > 0 {
|
||||
userInput["files"] = req.Files
|
||||
}
|
||||
ch, err := s.agentService.RunAgent(ctx, tenantID, agentID,
|
||||
req.SessionID, "", userInput)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, err
|
||||
}
|
||||
return ch, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// AgentbotCompletionRequest is the request body for
|
||||
// /api/v1/agentbots/<agent_id>/completions. We intentionally accept
|
||||
// the same fields the production /agents/chat/completions handler
|
||||
// accepts; the URL-bound agent_id is the authoritative canvas id
|
||||
// (matches python bot_api.py:159).
|
||||
type AgentbotCompletionRequest struct {
|
||||
SessionID string `json:"session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Stream bool `json:"stream"`
|
||||
// UserInput is the dict-shaped root input the canvas run expects
|
||||
// (mirrors the python "question"/"files"/"inputs" trio collapsed
|
||||
// into one map).
|
||||
UserInput map[string]any `json:"inputs"`
|
||||
Question string `json:"question"`
|
||||
Files []string `json:"files"`
|
||||
}
|
||||
|
||||
// ChatbotCompletionRequest is the request body for
|
||||
// /api/v1/chatbots/<dialog_id>/completions. Mirrors the python
|
||||
// `async_iframe_completion` body shape (session_id, question,
|
||||
// tts (unused) and a freeform dict).
|
||||
type ChatbotCompletionRequest struct {
|
||||
SessionID string `json:"session_id"`
|
||||
Question string `json:"question"`
|
||||
Stream bool `json:"stream"`
|
||||
Inputs map[string]any `json:"inputs"`
|
||||
}
|
||||
|
||||
// loadCanvas is the IDOR guard for agentbot reads. It mirrors the
|
||||
// private loadCanvasForUser helper on AgentService without taking a
|
||||
// dependency on the agentService pointer (so BotService can be unit-
|
||||
// tested with a nil agentService).
|
||||
func (s *BotService) loadCanvas(ctx context.Context, tenantID, agentID string) (*entity.UserCanvas, error) {
|
||||
if agentID == "" {
|
||||
return nil, dao.ErrUserCanvasNotFound
|
||||
}
|
||||
if tenantID == "" {
|
||||
return nil, dao.ErrUserCanvasNotFound
|
||||
}
|
||||
userTenantDAO := dao.NewUserTenantDAO()
|
||||
tenants, err := userTenantDAO.GetTenantIDsByUserID(tenantID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bot: tenants for user %s: %w", tenantID, err)
|
||||
}
|
||||
return s.canvasDAO.GetByIDForUser(agentID, tenantID, tenants)
|
||||
}
|
||||
|
||||
// canvasDSLMap projects a UserCanvas.DSL JSONMap into a
|
||||
// map[string]any. Returns an empty map (not nil) on miss so
|
||||
// downstream dsl helpers can still scan it.
|
||||
func canvasDSLMap(cv *entity.UserCanvas) map[string]any {
|
||||
if cv == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
// cv.DSL is entity.JSONMap (alias for map[string]interface{}).
|
||||
// We must return a fresh map[string]any because the dsl
|
||||
// helpers expect that concrete type.
|
||||
return map[string]any(cv.DSL)
|
||||
}
|
||||
|
||||
// botDerefStr returns *s or "" if nil. Used to read pointer-string
|
||||
// fields on entities (Name, Icon, Title, Avatar). Prefixed with bot
|
||||
// to avoid colliding with the test-only botDerefStr in
|
||||
// openai_chat_test.go.
|
||||
func botDerefStr(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
}
|
||||
|
||||
// stringFromMap returns m[key] as a string. Returns "" if the key is
|
||||
// absent or the value is not a string. Used for defensive reads
|
||||
// over JSONMap-shaped fields (dialog.prompt_config) where a hard
|
||||
// type assertion would panic.
|
||||
func stringFromMap(m entity.JSONMap, key string) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
v, ok := m[key]
|
||||
if !ok || v == nil {
|
||||
return ""
|
||||
}
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
return ""
|
||||
}
|
||||
397
internal/service/bot_completion.go
Normal file
397
internal/service/bot_completion.go
Normal file
@@ -0,0 +1,397 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
// bot_completion.go is the SSE envelope writer + ChatbotCompletion
|
||||
// service path for /api/v1/chatbots/<dialog_id>/completions. The wire
|
||||
// shape is dictated by the existing python
|
||||
// `api/db/services/conversation_service.py::async_iframe_completion`
|
||||
// — JS widgets reading the iframe SDK expect this exact envelope, so
|
||||
// any change to the frame keys is a wire-contract change.
|
||||
//
|
||||
// Frame shape (one JSON object per `data:` line):
|
||||
//
|
||||
// {"code":0,"message":"","data":{"answer":"...","reference":{...},
|
||||
// "audio_binary":null,"id":"...","session_id":"..."}, ...}
|
||||
//
|
||||
// The final completion marker is `data: {"code":0,"message":"",
|
||||
// "data":true}` followed by the OpenAI-style `data: [DONE]` line
|
||||
// that the existing Go SSE writers emit on the production
|
||||
// /agents/chat/completions path.
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"ragflow/internal/agent/canvas"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
// ChatbotSSEFrame is one envelope pushed to the SSE writer by the
|
||||
// chatbot completion path. Err takes precedence over Data and is
|
||||
// rendered as a python-style {code:500, message:str(e),
|
||||
// data:{answer:"**ERROR**..."}} frame.
|
||||
type ChatbotSSEFrame struct {
|
||||
Data string `json:"-"`
|
||||
Reference map[string]any `json:"-"`
|
||||
SessionID string `json:"-"`
|
||||
Done bool `json:"-"`
|
||||
Err error `json:"-"`
|
||||
}
|
||||
|
||||
// WriteChatbotFrame emits one python-style SSE frame and flushes the
|
||||
// underlying http.ResponseWriter. The frame is `data: <json>\n\n`
|
||||
// and is byte-equivalent to the python side so the iframe SDK and
|
||||
// existing JS widgets keep working.
|
||||
//
|
||||
// Error frames sanitize the message — internal errors (gorm stack
|
||||
// frames, SQL details, storage paths) MUST NOT be echoed to the
|
||||
// client. The caller is expected to log the real error via
|
||||
// common.Error / zap before publishing the frame; only a generic
|
||||
// placeholder is rendered here. Mirrors the python
|
||||
// `api/db/services/conversation_service.py` error frame shape.
|
||||
func WriteChatbotFrame(w http.ResponseWriter, f ChatbotSSEFrame) error {
|
||||
var payload map[string]any
|
||||
if f.Err != nil {
|
||||
const clientErrMsg = "an internal error occurred"
|
||||
payload = map[string]any{
|
||||
"code": 500,
|
||||
"message": clientErrMsg,
|
||||
"data": map[string]any{
|
||||
"answer": clientErrMsg,
|
||||
"reference": map[string]any{},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
payload = map[string]any{
|
||||
"code": 0,
|
||||
"message": "",
|
||||
"data": map[string]any{
|
||||
"answer": f.Data,
|
||||
"reference": f.Reference,
|
||||
"audio_binary": nil,
|
||||
"id": nil,
|
||||
"session_id": f.SessionID,
|
||||
},
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.Write([]byte("data: ")); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.Write([]byte("\n\n")); err != nil {
|
||||
return err
|
||||
}
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteDoneFrame emits the python completion marker
|
||||
// `data: {"code":0,"message":"","data":true}\n\n` followed by the
|
||||
// OpenAI-style `data: [DONE]\n\n` terminator. Used by both bot
|
||||
// completion paths.
|
||||
func WriteDoneFrame(w http.ResponseWriter) error {
|
||||
if _, err := w.Write([]byte(`data: {"code":0,"message":"","data":true}` + "\n\n")); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.Write([]byte("data: [DONE]\n\n")); err != nil {
|
||||
return err
|
||||
}
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteChatbotRunEvent translates one canvas.RunEvent into the
|
||||
// unified python-shaped chat-completion envelope (same shape as
|
||||
// WriteChatbotFrame). This unifies the SSE format across:
|
||||
//
|
||||
// - /api/v1/agents/chat/completions (was: writeChatCompletionSSE)
|
||||
// - /api/v1/agentbots/<id>/completions (was: WriteChatbotFrame per-event)
|
||||
//
|
||||
// The "done" event type emits `data: [DONE]\n\n` (no envelope),
|
||||
// matching the OpenAI-style terminator and the existing
|
||||
// AgentbotCompletion wire.
|
||||
//
|
||||
// For non-done events, ev.Data is placed verbatim into the `answer`
|
||||
// field — callers pass canvas-runner output that is itself a JSON
|
||||
// string (e.g. `{"answer":"hi back","reference":[]}`); the iframe
|
||||
// SDK then JSON.parse()s the `answer` string to extract the inner
|
||||
// fields. This matches the existing AgentbotCompletion behaviour.
|
||||
//
|
||||
// Returns the write error so callers can short-circuit; both nil
|
||||
// and io.ErrClosedPipe are tolerated because the client may have
|
||||
// disconnected mid-stream.
|
||||
func WriteChatbotRunEvent(w http.ResponseWriter, ev canvas.RunEvent) error {
|
||||
if ev.Type == "done" {
|
||||
_, err := w.Write([]byte("data: [DONE]\n\n"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
f := ChatbotSSEFrame{
|
||||
Data: ev.Data,
|
||||
Reference: map[string]any{},
|
||||
SessionID: ev.SessionID,
|
||||
}
|
||||
return WriteChatbotFrame(w, f)
|
||||
}
|
||||
|
||||
// AgentbotSSEFrame mirrors ChatbotSSEFrame for the agentbot
|
||||
// completion path. The envelope shape is the same; the only
|
||||
// difference is that the LLM call goes through the canvas runner
|
||||
// (AgentService.RunAgent) instead of the legacy dialog async_chat.
|
||||
type AgentbotSSEFrame = ChatbotSSEFrame
|
||||
|
||||
// WriteAgentbotFrame is an alias for WriteChatbotFrame — both bot
|
||||
// completion paths emit the same python wire shape.
|
||||
func WriteAgentbotFrame(w http.ResponseWriter, f ChatbotSSEFrame) error {
|
||||
return WriteChatbotFrame(w, f)
|
||||
}
|
||||
|
||||
// ChatbotCompletion streams an SSE response for
|
||||
// /api/v1/chatbots/<dialog_id>/completions.
|
||||
//
|
||||
// The full LLM session-lifecycle implementation is added below. It
|
||||
// is a v1 port: it yields a single frame per turn (the Go LLMBundle
|
||||
// chat call is non-streaming), seeded with the dialog's prologue
|
||||
// when the request creates a new session.
|
||||
//
|
||||
// Authorisation: dialog must exist, belong to the requester's tenant,
|
||||
// and have status == common.StatusDialogValid.
|
||||
func (s *BotService) ChatbotCompletion(
|
||||
ctx context.Context, tenantID, dialogID string, req ChatbotCompletionRequest,
|
||||
) (<-chan ChatbotSSEFrame, common.ErrorCode, error) {
|
||||
// 1. Load and authorise the dialog.
|
||||
//
|
||||
// ChatSessionDAO.GetDialogByID already filters by status = "1"
|
||||
// so a returned row is valid; we still nil-check defensively
|
||||
// before dereferencing for symmetry with the session path.
|
||||
dialog, err := s.chatDAO.GetDialogByID(dialogID)
|
||||
if err != nil || dialog == nil ||
|
||||
dialog.TenantID != tenantID ||
|
||||
dialog.Status == nil || *dialog.Status != common.StatusDialogValid {
|
||||
return nil, common.CodeDataError, errors.New("no access to this chatbot")
|
||||
}
|
||||
|
||||
// 2. Resolve or create the session row.
|
||||
//
|
||||
// API4ConversationDAO.GetBySessionID returns (nil, nil) on miss
|
||||
// (not an error) — see internal/dao/api_token.go:146. We MUST
|
||||
// check the pointer before dereferencing, otherwise the
|
||||
// session-tenant check below nil-derefs. Plan Risk R7.
|
||||
//
|
||||
// UserID vs tenantID (security H3 follow-up):
|
||||
// `entity.API4Conversation.UserID` is a generic user-id slot
|
||||
// in the production Python flow
|
||||
// (api/db/services/conversation_service.py:258 — the python
|
||||
// async_iframe_completion saves `user_id=kwargs.get("user_id", "")`).
|
||||
// The Go BotHandler routes pass `user.ID` through the
|
||||
// "tenantID" parameter (the Go User struct collapses user and
|
||||
// tenant into one identifier — see project CLAUDE.md), so
|
||||
// writing `tenantID` here actually stores the requester's
|
||||
// user-id (== tenant-id) in the python user-id slot. The
|
||||
// session-tenant check on the read path compares against the
|
||||
// same value, so write/read stay symmetric. We keep this
|
||||
// behaviour and add the comment so a future reader doesn't
|
||||
// "fix" it to a tenant-id lookup and break the symmetry.
|
||||
var session *entity.API4Conversation
|
||||
if req.SessionID != "" {
|
||||
session, err = s.api4ConversationDAO.GetBySessionID(req.SessionID, dialogID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if session == nil || session.UserID != tenantID {
|
||||
return nil, common.CodeDataError, errors.New("session not found")
|
||||
}
|
||||
} else {
|
||||
// Seed a new session. The Message column is json.RawMessage;
|
||||
// pre-serialise the prologue turn as a JSON array of
|
||||
// {role,content,created_at} dicts — same shape the python
|
||||
// conversation_service.py:253-272 writes. Plan Risk R4.
|
||||
prologue := stringFromMap(dialog.PromptConfig, "prologue")
|
||||
seedMsg, _ := json.Marshal([]map[string]any{
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": prologue,
|
||||
"created_at": time.Now().Unix(),
|
||||
},
|
||||
})
|
||||
session = &entity.API4Conversation{
|
||||
ID: uuid.NewString(),
|
||||
DialogID: dialogID,
|
||||
UserID: tenantID,
|
||||
Message: seedMsg,
|
||||
}
|
||||
if err := s.api4ConversationDAO.Create(session); err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Resolve the chat LLM via ModelProviderService. The python
|
||||
// async_iframe_completion resolves the same way through
|
||||
// LLMBundle(tenant_id, dialog.llm_id); the Go equivalent is
|
||||
// GetChatModelConfig → NewChatModel → driver.ChatWithMessages.
|
||||
//
|
||||
// If llmService is unwired (test boot path) or the dialog has
|
||||
// no LLM configured, we surface a sanitized CodeDataError
|
||||
// rather than echoing the bare error string into the SSE
|
||||
// envelope — see WriteChatbotFrame's sanitization contract.
|
||||
if s.llmService == nil {
|
||||
return nil, common.CodeServerError, errors.New("bot: llm service not wired")
|
||||
}
|
||||
if dialog.LLMID == "" {
|
||||
return nil, common.CodeDataError, errors.New("no LLM configured for this chatbot")
|
||||
}
|
||||
modelProvider := NewModelProviderService()
|
||||
driver, modelName, apiConfig, _, err := modelProvider.GetChatModelConfig(tenantID, dialog.LLMID)
|
||||
if err != nil {
|
||||
return nil, common.CodeDataError, errors.New("no LLM configured for this chatbot")
|
||||
}
|
||||
chatModel := modelModule.NewChatModel(driver, &modelName, apiConfig)
|
||||
|
||||
// 4. Build the prompt from prior conversation history plus the
|
||||
// new user turn. Without this, a resumed session_id would
|
||||
// authorise reuse but the LLM call would still be stateless
|
||||
// turn-to-turn — a Python parity regression for any multi-turn
|
||||
// chatbot client. The Message column on api_4_conversation is a
|
||||
// json.RawMessage array of {role, content, created_at} dicts,
|
||||
// matching the python conversation_service.py:253-272 shape.
|
||||
messages := historyToMessages(session.Message)
|
||||
messages = append(messages, modelModule.Message{Role: "user", Content: req.Question})
|
||||
|
||||
// 5. Yield frames on a channel.
|
||||
out := make(chan ChatbotSSEFrame, 4)
|
||||
go func() {
|
||||
defer close(out)
|
||||
resp, callErr := chatModel.ModelDriver.ChatWithMessages(
|
||||
modelName, messages, chatModel.APIConfig, &modelModule.ChatConfig{},
|
||||
)
|
||||
if callErr != nil {
|
||||
// Log the real error with structured context so
|
||||
// ops can debug, but do NOT echo the raw
|
||||
// err.Error() to the client (security M2:
|
||||
// internal gorm/SQL/file-path leaks).
|
||||
common.Error("bot: ChatbotCompletion LLM call failed",
|
||||
callErr,
|
||||
zap.String("dialog_id", dialogID),
|
||||
zap.String("session_id", session.ID),
|
||||
zap.String("llm_id", dialog.LLMID),
|
||||
)
|
||||
out <- ChatbotSSEFrame{
|
||||
Err: errors.New("an internal error occurred"),
|
||||
SessionID: session.ID,
|
||||
}
|
||||
out <- ChatbotSSEFrame{Done: true}
|
||||
return
|
||||
}
|
||||
answer := ""
|
||||
if resp != nil && resp.Answer != nil {
|
||||
answer = *resp.Answer
|
||||
}
|
||||
|
||||
// Persist the new turn pair (user + assistant) back to
|
||||
// api_4_conversation so the NEXT call to ChatbotCompletion
|
||||
// with the same session_id sees this turn in messages.
|
||||
// Update errors are logged but do NOT fail the SSE stream
|
||||
// — the answer has already been produced. The next call
|
||||
// will rebuild from the prior (pre-this-turn) snapshot,
|
||||
// losing at most the latest exchange; acceptable for v1.
|
||||
newTurns := append(historyFromMessages(messages),
|
||||
map[string]any{"role": "assistant", "content": answer, "created_at": time.Now().Unix()},
|
||||
)
|
||||
if updated, mErr := json.Marshal(newTurns); mErr == nil {
|
||||
session.Message = updated
|
||||
if uErr := s.api4ConversationDAO.Update(session); uErr != nil {
|
||||
common.Error("bot: ChatbotCompletion session update failed",
|
||||
uErr,
|
||||
zap.String("dialog_id", dialogID),
|
||||
zap.String("session_id", session.ID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
out <- ChatbotSSEFrame{
|
||||
Data: answer,
|
||||
Reference: map[string]any{},
|
||||
SessionID: session.ID,
|
||||
}
|
||||
out <- ChatbotSSEFrame{Done: true}
|
||||
}()
|
||||
return out, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// historyToMessages reads the session.Message JSON array of
|
||||
// {role, content, ...} dicts and projects it onto modelModule.Message
|
||||
// for the LLM driver. Tolerates an empty / malformed Message column
|
||||
// by returning an empty slice — the caller appends the new user turn
|
||||
// so the LLM still receives the current prompt.
|
||||
func historyToMessages(raw json.RawMessage) []modelModule.Message {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
var turns []map[string]any
|
||||
if err := json.Unmarshal(raw, &turns); err != nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]modelModule.Message, 0, len(turns))
|
||||
for _, t := range turns {
|
||||
role, _ := t["role"].(string)
|
||||
content, _ := t["content"].(string)
|
||||
if role == "" || content == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, modelModule.Message{Role: role, Content: content})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// historyFromMessages is the inverse projection — used to write the
|
||||
// updated turn list back to the api_4_conversation.Message column.
|
||||
func historyFromMessages(msgs []modelModule.Message) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(msgs))
|
||||
now := time.Now().Unix()
|
||||
for i, m := range msgs {
|
||||
out = append(out, map[string]any{
|
||||
"role": m.Role,
|
||||
"content": m.Content,
|
||||
"created_at": now + int64(i), // preserve order, monotonic
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
164
internal/service/bot_completion_history_test.go
Normal file
164
internal/service/bot_completion_history_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
// Tests for the conversation-history round-trip helpers used by
|
||||
// BotService.ChatbotCompletion. Locks in review Finding 8 — a resumed
|
||||
// session_id must carry prior turns (assistant prologue + earlier
|
||||
// user/assistant exchanges) into the next LLM call so multi-turn
|
||||
// chatbot clients retain context.
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
func TestHistoryToMessages_Empty(t *testing.T) {
|
||||
// A freshly-seeded session with no prior turns returns an empty
|
||||
// slice. Caller appends the new user turn; LLM receives only
|
||||
// the current prompt. Matches python conversation_service seed.
|
||||
got := historyToMessages(nil)
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("nil raw: want 0 messages, got %d", len(got))
|
||||
}
|
||||
got = historyToMessages(json.RawMessage(`[]`))
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("empty array: want 0 messages, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHistoryToMessages_RoundTrip(t *testing.T) {
|
||||
// Simulate a session with: 1 prologue assistant turn + 1 prior
|
||||
// user/assistant pair. The LLM must see all 3 prior turns
|
||||
// before the new user turn is appended.
|
||||
turns := []map[string]any{
|
||||
{"role": "assistant", "content": "Hello, how can I help?", "created_at": 1},
|
||||
{"role": "user", "content": "What is Go?", "created_at": 2},
|
||||
{"role": "assistant", "content": "Go is a compiled language.", "created_at": 3},
|
||||
}
|
||||
raw, err := json.Marshal(turns)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal seed: %v", err)
|
||||
}
|
||||
msgs := historyToMessages(raw)
|
||||
if len(msgs) != 3 {
|
||||
t.Fatalf("want 3 prior messages, got %d", len(msgs))
|
||||
}
|
||||
if msgs[0].Role != "assistant" || msgs[0].Content != "Hello, how can I help?" {
|
||||
t.Errorf("turn 0: role=%q content=%q", msgs[0].Role, msgs[0].Content)
|
||||
}
|
||||
if msgs[1].Role != "user" || msgs[1].Content != "What is Go?" {
|
||||
t.Errorf("turn 1: role=%q content=%q", msgs[1].Role, msgs[1].Content)
|
||||
}
|
||||
if msgs[2].Role != "assistant" || msgs[2].Content != "Go is a compiled language." {
|
||||
t.Errorf("turn 2: role=%q content=%q", msgs[2].Role, msgs[2].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHistoryToMessages_Malformed(t *testing.T) {
|
||||
// Malformed JSON must not panic; returns nil so caller falls back
|
||||
// to a fresh single-turn LLM call rather than failing the request.
|
||||
got := historyToMessages(json.RawMessage(`not json`))
|
||||
if got != nil {
|
||||
t.Fatalf("malformed raw: want nil, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHistoryToMessages_SkipsEmptyFields(t *testing.T) {
|
||||
// Defensive: turns missing role or content are dropped, not
|
||||
// passed to the LLM as empty messages.
|
||||
turns := []map[string]any{
|
||||
{"role": "assistant", "content": "valid", "created_at": 1},
|
||||
{"role": "", "content": "no role", "created_at": 2},
|
||||
{"role": "user", "content": "", "created_at": 3},
|
||||
{"role": "user", "content": "second valid", "created_at": 4},
|
||||
}
|
||||
raw, _ := json.Marshal(turns)
|
||||
msgs := historyToMessages(raw)
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("want 2 valid turns, got %d", len(msgs))
|
||||
}
|
||||
if msgs[0].Content != "valid" || msgs[1].Content != "second valid" {
|
||||
t.Errorf("got %+v", msgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHistoryFromMessages_PreservesOrder(t *testing.T) {
|
||||
// The LLM driver returns messages in the same order the input
|
||||
// was provided. The round-trip must preserve that order so the
|
||||
// next call to ChatbotCompletion sees a coherent history.
|
||||
msgs := []modelModule.Message{
|
||||
{Role: "assistant", Content: "first"},
|
||||
{Role: "user", Content: "second"},
|
||||
{Role: "assistant", Content: "third"},
|
||||
}
|
||||
turns := historyFromMessages(msgs)
|
||||
if len(turns) != 3 {
|
||||
t.Fatalf("want 3 turns, got %d", len(turns))
|
||||
}
|
||||
for i, want := range []string{"first", "second", "third"} {
|
||||
if turns[i]["content"] != want {
|
||||
t.Errorf("turn %d content = %v, want %q", i, turns[i]["content"], want)
|
||||
}
|
||||
if turns[i]["role"] != msgs[i].Role {
|
||||
t.Errorf("turn %d role = %v, want %q", i, turns[i]["role"], msgs[i].Role)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHistoryRoundTrip_PreservesPriorTurns(t *testing.T) {
|
||||
// End-to-end: prior JSON → history → back to JSON must be
|
||||
// semantically identical (modulo the created_at monotonic
|
||||
// adjustment that historyFromMessages applies for ordering).
|
||||
turns := []map[string]any{
|
||||
{"role": "assistant", "content": "p1", "created_at": int64(100)},
|
||||
{"role": "user", "content": "p2", "created_at": int64(200)},
|
||||
}
|
||||
raw, _ := json.Marshal(turns)
|
||||
|
||||
msgs := historyToMessages(raw)
|
||||
// Caller appends a new user turn (the current request).
|
||||
msgs = append(msgs, modelModule.Message{Role: "user", Content: "current"})
|
||||
|
||||
// Round-trip back to JSON for storage.
|
||||
newTurns := historyFromMessages(msgs)
|
||||
raw2, err := json.Marshal(newTurns)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal round-trip: %v", err)
|
||||
}
|
||||
|
||||
var got []map[string]any
|
||||
if err := json.Unmarshal(raw2, &got); err != nil {
|
||||
t.Fatalf("unmarshal round-trip: %v", err)
|
||||
}
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("want 3 turns after round-trip, got %d", len(got))
|
||||
}
|
||||
expected := []struct{ role, content string }{
|
||||
{"assistant", "p1"},
|
||||
{"user", "p2"},
|
||||
{"user", "current"},
|
||||
}
|
||||
for i, want := range expected {
|
||||
if got[i]["role"] != want.role || got[i]["content"] != want.content {
|
||||
t.Errorf("turn %d: got role=%v content=%v, want role=%q content=%q",
|
||||
i, got[i]["role"], got[i]["content"], want.role, want.content)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user