feat(go-agent): Ported retrieval node, added Keenable web search tool (#16396)

Ported retrieval node, added Keenable web search tool
- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Zhichang Yu
2026-06-26 22:55:49 +08:00
committed by yzc
parent f86a0e7386
commit f58fae5fb7
91 changed files with 5920 additions and 3817 deletions

View File

@@ -29,6 +29,7 @@ import (
"strings"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"ragflow/internal/agent/canvas"
"ragflow/internal/common"
@@ -53,7 +54,7 @@ type agentFileService interface {
// NewAgentHandler assigns the concrete *service.AgentService — which
// satisfies this interface because its RunAgent signature matches.
type chatAgentService interface {
RunAgent(ctx context.Context, userID, canvasID, sessionID, version, userInput string) (<-chan canvas.RunEvent, error)
RunAgent(ctx context.Context, userID, canvasID, sessionID, version string, userInput any) (<-chan canvas.RunEvent, error)
}
// AgentHandler agent handler
@@ -869,6 +870,7 @@ func (h *AgentHandler) DeleteAgentSession(c *gin.Context) {
type agentChatCompletionsRequest struct {
AgentID string `json:"agent_id"`
Query string `json:"query"`
Inputs map[string]interface{} `json:"inputs"`
SessionID string `json:"session_id"`
Stream bool `json:"stream"`
OpenAICompat bool `json:"openai-compatible"`
@@ -895,6 +897,76 @@ func extractLastUserContent(messages []map[string]interface{}) string {
return ""
}
// extractUserInputFromFormInputs mirrors the front-end's wait-for-user submit
// shape: `inputs` is an object keyed by form field name, and each entry carries
// a nested `value`. The current chat-completion resume path consumes a single
// string payload, so we lift the first field's value and stringify it.
func extractUserInputFromFormInputs(inputs map[string]interface{}) interface{} {
if len(inputs) == 0 {
return nil
}
if len(inputs) == 1 {
for _, raw := range inputs {
if field, ok := raw.(map[string]interface{}); ok {
if v, ok := field["value"]; ok {
return v
}
}
return raw
}
}
out := make(map[string]any, len(inputs))
for name, raw := range inputs {
if field, ok := raw.(map[string]interface{}); ok {
if v, ok := field["value"]; ok {
out[name] = v
continue
}
}
out[name] = raw
}
return out
}
func countInputValues(inputs map[string]interface{}) int {
count := 0
for _, raw := range inputs {
if field, ok := raw.(map[string]interface{}); ok {
if _, exists := field["value"]; exists {
count++
}
continue
}
if raw != nil {
count++
}
}
return count
}
func userInputMeta(userInput any) []zap.Field {
fields := []zap.Field{zap.String("user_input_type", fmt.Sprintf("%T", userInput))}
switch v := userInput.(type) {
case nil:
fields = append(fields, zap.Bool("user_input_present", false))
case string:
fields = append(fields,
zap.Bool("user_input_present", true),
zap.Int("user_input_length", len(v)),
zap.Bool("user_input_blank", v == ""),
)
case map[string]interface{}:
fields = append(fields,
zap.Bool("user_input_present", true),
zap.Int("user_input_keys", len(v)),
)
default:
fields = append(fields, zap.Bool("user_input_present", true))
}
return fields
}
func (h *AgentHandler) AgentChatCompletions(c *gin.Context) {
user, code, msg := GetUser(c)
if code != common.CodeSuccess {
@@ -914,6 +986,18 @@ func (h *AgentHandler) AgentChatCompletions(c *gin.Context) {
jsonError(c, common.CodeDataError, "at least one message is required in openai-compatible mode.")
return
}
common.Debug("agent chat completions: request received",
zap.String("user_id", user.ID),
zap.String("agent_id", req.AgentID),
zap.String("session_id", req.SessionID),
zap.Bool("stream", req.Stream),
zap.Bool("openai_compatible", req.OpenAICompat),
zap.Bool("query_present", req.Query != ""),
zap.Int("query_length", len(req.Query)),
zap.Int("inputs_count", len(req.Inputs)),
zap.Int("inputs_with_values_count", countInputValues(req.Inputs)),
zap.Int("messages_count", len(req.Messages)),
)
// TODO(phase5-openai-framing): the openai-compat branches below are
// stubs. They keep the existing "choices"-shape contract for the
@@ -936,13 +1020,31 @@ func (h *AgentHandler) AgentChatCompletions(c *gin.Context) {
// Real canvas run — derive userInput from `query` first, then fall
// back to the last user message (covers the front-end that posts
// running_hint_text without a top-level `query`).
userInput := req.Query
if userInput == "" {
userInput = extractLastUserContent(req.Messages)
var userInput any = req.Query
if req.Query == "" {
if extracted := extractUserInputFromFormInputs(req.Inputs); extracted != nil {
userInput = extracted
} else if extracted := extractLastUserContent(req.Messages); extracted != "" {
userInput = extracted
}
}
common.Debug("agent chat completions: derived user input",
append([]zap.Field{
zap.String("agent_id", req.AgentID),
zap.String("session_id", req.SessionID),
}, userInputMeta(userInput)...)...,
)
events, err := h.chatRunner.RunAgent(c.Request.Context(), user.ID, req.AgentID, req.SessionID, "", userInput)
if err != nil {
common.Warn("agent chat completions: RunAgent failed",
append([]zap.Field{
zap.String("user_id", user.ID),
zap.String("agent_id", req.AgentID),
zap.String("session_id", req.SessionID),
zap.Error(err),
}, userInputMeta(userInput)...)...,
)
ec, em := mapAgentError(err)
jsonError(c, ec, em)
return
@@ -963,8 +1065,19 @@ func (h *AgentHandler) AgentChatCompletions(c *gin.Context) {
// /api/v1/agents/{id}/run endpoint's wire format — see
// writeRunEventSSE at agent.go for that path.
for ev := range events {
writeChatCompletionSSE(c.Writer, flusher, ev)
common.Debug("agent chat completions: streaming event",
zap.String("agent_id", req.AgentID),
zap.String("session_id", req.SessionID),
zap.String("event_type", ev.Type),
zap.String("message_id", ev.MessageID),
zap.String("task_id", ev.TaskID),
)
writeChatCompletionSSE(c.Writer, flusher, req.AgentID, ev)
}
common.Debug("agent chat completions: stream closed",
zap.String("agent_id", req.AgentID),
zap.String("session_id", req.SessionID),
)
}
// writeChatCompletionSSE emits one canvas.RunEvent in the
@@ -973,8 +1086,13 @@ func (h *AgentHandler) AgentChatCompletions(c *gin.Context) {
// data:{"event":"<ev.Type>","message_id":"<ev.MessageID>","created_at":<ev.CreatedAt>,"task_id":"<ev.TaskID>","session_id":"<ev.SessionID>","data":<ev.Data>}
//
// The special "done" type sends `data: [DONE]\n\n` (no JSON envelope).
func writeChatCompletionSSE(w io.Writer, flusher http.Flusher, ev canvas.RunEvent) {
func writeChatCompletionSSE(w io.Writer, flusher http.Flusher, agentID string, ev canvas.RunEvent) {
if ev.Type == "done" {
common.Debug("agent chat completions: writing done sentinel",
zap.String("agent_id", agentID),
zap.String("session_id", ev.SessionID),
zap.String("task_id", ev.TaskID),
)
fmt.Fprint(w, "data: [DONE]\n\n")
if flusher != nil {
flusher.Flush()
@@ -985,6 +1103,13 @@ func writeChatCompletionSSE(w io.Writer, flusher http.Flusher, ev canvas.RunEven
if data == "" {
data = "{}"
}
common.Debug("agent chat completions: writing sse frame",
zap.String("agent_id", agentID),
zap.String("event_type", ev.Type),
zap.String("message_id", ev.MessageID),
zap.String("session_id", ev.SessionID),
zap.String("task_id", ev.TaskID),
)
envelope := sseEnvelope(ev.Type, ev.MessageID, ev.CreatedAt, ev.TaskID, ev.SessionID, data)
fmt.Fprintf(w, "data: %s\n\n", envelope)
if flusher != nil {

View File

@@ -447,7 +447,7 @@ func (f *fullFakeAgentService) UpdateAgent(context.Context, string, string, enti
func (f *fullFakeAgentService) DeleteAgent(context.Context, string, string) error {
return nil
}
func (f *fullFakeAgentService) RunAgent(context.Context, string, string, string, string, string) (<-chan canvas.RunEvent, error) {
func (f *fullFakeAgentService) RunAgent(context.Context, string, string, string, string, any) (<-chan canvas.RunEvent, error) {
ch := make(chan canvas.RunEvent)
close(ch)
return ch, nil
@@ -715,7 +715,7 @@ type stubChatRunner struct {
err error
}
func (s *stubChatRunner) RunAgent(_ context.Context, _, _, _, _, _ string) (<-chan canvas.RunEvent, error) {
func (s *stubChatRunner) RunAgent(_ context.Context, _, _, _, _ string, _ any) (<-chan canvas.RunEvent, error) {
if s.err != nil {
return nil, s.err
}
@@ -815,13 +815,61 @@ func TestAgentChatCompletions_DerivesUserInputFromMessages(t *testing.T) {
c.Set("user", &entity.User{ID: "u1"})
c.Set("user_id", "u1")
var captured string
var captured any
runner := &captureChatRunner{captured: &captured}
h := &AgentHandler{chatRunner: runner}
h.AgentChatCompletions(c)
if captured != "from-messages" {
t.Errorf("userInput = %q, want %q (last user message content)", captured, "from-messages")
t.Errorf("userInput = %#v, want %q (last user message content)", captured, "from-messages")
}
}
// TestAgentChatCompletions_DerivesUserInputFromInputs covers the wait-for-user
// resume path used by the front-end: the follow-up submit posts `inputs`
// instead of a top-level `query`. The handler must lift the nested field value
// and pass it through as the resumed user input.
func TestAgentChatCompletions_DerivesUserInputFromInputs(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/v1/agents/chat/completions",
strings.NewReader(`{"agent_id":"a1","session_id":"s1","inputs":{"text":{"name":"text","value":"a b c d e","type":"line"}}}`))
c.Request.Header.Set("Content-Type", "application/json")
c.Set("user", &entity.User{ID: "u1"})
c.Set("user_id", "u1")
var captured any
runner := &captureChatRunner{captured: &captured}
h := &AgentHandler{chatRunner: runner}
h.AgentChatCompletions(c)
if captured != "a b c d e" {
t.Errorf("userInput = %#v, want %q (nested inputs.value)", captured, "a b c d e")
}
}
func TestAgentChatCompletions_DerivesStructuredUserInputFromInputs(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/v1/agents/chat/completions",
strings.NewReader(`{"agent_id":"a1","session_id":"s1","inputs":{"kb":{"name":"KB","value":"da1","type":"line"},"query":{"name":"Query","value":"合同","type":"line"}}}`))
c.Request.Header.Set("Content-Type", "application/json")
c.Set("user", &entity.User{ID: "u1"})
c.Set("user_id", "u1")
var captured any
runner := &captureChatRunner{captured: &captured}
h := &AgentHandler{chatRunner: runner}
h.AgentChatCompletions(c)
got, ok := captured.(map[string]any)
if !ok {
t.Fatalf("userInput type = %T, want map[string]any", captured)
}
if got["kb"] != "da1" || got["query"] != "合同" {
t.Fatalf("userInput = %#v, want kb=da1 query=合同", got)
}
}
@@ -829,10 +877,10 @@ func TestAgentChatCompletions_DerivesUserInputFromMessages(t *testing.T) {
// returns an empty (closed) channel. Used to assert on argument
// derivation without exercising the runner.
type captureChatRunner struct {
captured *string
captured *any
}
func (c *captureChatRunner) RunAgent(_ context.Context, _, _, _, _, userInput string) (<-chan canvas.RunEvent, error) {
func (c *captureChatRunner) RunAgent(_ context.Context, _, _, _, _ string, userInput any) (<-chan canvas.RunEvent, error) {
*c.captured = userInput
ch := make(chan canvas.RunEvent)
close(ch)

View File

@@ -24,8 +24,8 @@ import (
"strings"
"testing"
"github.com/glebarez/sqlite"
"github.com/gin-gonic/gin"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
"ragflow/internal/common"
@@ -58,10 +58,10 @@ func setupUploadTestDB(t *testing.T) *gorm.DB {
// fakeUploadFileService implements fileUploader for tests.
type fakeUploadFileService struct {
uploaded []map[string]interface{}
err error
lastTenantID string
lastParentID string
uploaded []map[string]interface{}
err error
lastTenantID string
lastParentID string
}
func (f *fakeUploadFileService) UploadFile(tenantID, parentID string, files []*multipart.FileHeader) ([]map[string]interface{}, error) {

View File

@@ -94,7 +94,7 @@ func (f *waitFakeAgentService) DeleteAgent(context.Context, string, string) erro
// RunAgent mimics service.AgentService.RunAgent for the test
// driver. It loads the canvas (a no-op in tests), builds a RunFunc
// from the supplied stub, and hands off to the orchestrator.
func (f *waitFakeAgentService) RunAgent(ctx context.Context, userID, canvasID, sessionID, version, userInput string) (<-chan canvas.RunEvent, error) {
func (f *waitFakeAgentService) RunAgent(ctx context.Context, userID, canvasID, sessionID, version string, userInput any) (<-chan canvas.RunEvent, error) {
_ = ctx
_ = userID
_ = version

View File

@@ -20,7 +20,7 @@ import (
"context"
"encoding/json"
"errors"
"gorm.io/gorm"
"gorm.io/gorm"
"net/http"
"net/http/httptest"
"strings"
@@ -28,10 +28,10 @@ import (
"ragflow/internal/common"
"ragflow/internal/engine"
"ragflow/internal/engine/types"
"ragflow/internal/entity"
modelModule "ragflow/internal/entity/models"
"ragflow/internal/service/nlp"
"ragflow/internal/engine/types"
"github.com/gin-gonic/gin"
)
@@ -135,12 +135,12 @@ type mockDocEngine struct {
engine.DocEngine
}
func (m *mockDocEngine) Close() error { return nil }
func (m *mockDocEngine) Ping(ctx context.Context) error { return nil }
func (m *mockDocEngine) GetType() string { return "mock" }
func (m *mockDocEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
return &types.SearchResult{}, nil
}
func (m *mockDocEngine) Close() error { return nil }
func (m *mockDocEngine) Ping(ctx context.Context) error { return nil }
func (m *mockDocEngine) GetType() string { return "mock" }
func (m *mockDocEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
return &types.SearchResult{}, nil
}
func (m *mockDocEngine) GetChunk(ctx context.Context, _, _ string, _ []string) (interface{}, error) {
return map[string]interface{}{}, nil
}

View File

@@ -300,14 +300,14 @@ func (h *FileCommitHandler) GetCommit(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"data": gin.H{
"id": commit.ID,
"folder_id": commit.FolderID,
"parent_id": commit.ParentID,
"message": commit.Message,
"author_id": commit.AuthorID,
"file_count": commit.FileCount,
"id": commit.ID,
"folder_id": commit.FolderID,
"parent_id": commit.ParentID,
"message": commit.Message,
"author_id": commit.AuthorID,
"file_count": commit.FileCount,
"create_time": ct,
"files": items,
"files": items,
},
"message": common.CodeSuccess.Message(),
})

View File

@@ -31,15 +31,15 @@ import (
// mockFileCommitSvc implements FileCommitServiceInterface for testing
type mockFileCommitSvc struct {
createCommitFn func(folderID, authorID, message string, changes []entity.FileChange) (*entity.FileCommit, error)
listCommitsFn func(folderID string, page, pageSize int, orderBy string, desc bool) ([]*entity.FileCommit, int64, error)
getCommitFn func(commitID string) (*entity.FileCommit, error)
listCommitFilesFn func(commitID string) ([]*entity.FileCommitItem, error)
diffCommitsFn func(fromID, toID string) ([]entity.DiffEntry, error)
getUncommittedChangesFn func(folderID string) ([]entity.DiffEntry, error)
getCommitTreeFn func(commitID string) (map[string]interface{}, error)
getCommitFileContentFn func(folderID, commitID, fileID string) ([]byte, error)
getFileVersionHistoryFn func(fileID string) ([]entity.VersionEntry, error)
createCommitFn func(folderID, authorID, message string, changes []entity.FileChange) (*entity.FileCommit, error)
listCommitsFn func(folderID string, page, pageSize int, orderBy string, desc bool) ([]*entity.FileCommit, int64, error)
getCommitFn func(commitID string) (*entity.FileCommit, error)
listCommitFilesFn func(commitID string) ([]*entity.FileCommitItem, error)
diffCommitsFn func(fromID, toID string) ([]entity.DiffEntry, error)
getUncommittedChangesFn func(folderID string) ([]entity.DiffEntry, error)
getCommitTreeFn func(commitID string) (map[string]interface{}, error)
getCommitFileContentFn func(folderID, commitID, fileID string) ([]byte, error)
getFileVersionHistoryFn func(fileID string) ([]entity.VersionEntry, error)
}
func (m *mockFileCommitSvc) CreateCommit(folderID, authorID, message string, changes []entity.FileChange) (*entity.FileCommit, error) {

View File

@@ -188,7 +188,6 @@ func mcpDetailError(c *gin.Context, code common.ErrorCode, err error) {
})
}
// UpdateMCPServer updates an MCP server for the current user.
func (h *MCPHandler) UpdateMCPServer(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)

View File

@@ -377,15 +377,21 @@ func jsonDecodeMessage(t *testing.T, body []byte) string {
}
func nullableInt(p *int) string {
if p == nil { return "nil" }
if p == nil {
return "nil"
}
return fmt.Sprintf("%d", *p)
}
func nullableBool(p *bool) string {
if p == nil { return "nil" }
if p == nil {
return "nil"
}
return fmt.Sprintf("%v", *p)
}
func nullableFloat(p *float64) string {
if p == nil { return "nil" }
if p == nil {
return "nil"
}
return fmt.Sprintf("%v", *p)
}
func TestSearchBotsRetrieval_EmptyQuestion(t *testing.T) {
@@ -404,6 +410,7 @@ func TestSearchBotsRetrieval_EmptyQuestion(t *testing.T) {
t.Errorf("expected validation error mentioning Question and required, got %q", msg)
}
}
// fakeSearchbotLLM implements searchbotLLM for testing.
type fakeSearchbotLLM struct {
response string
@@ -899,8 +906,6 @@ func TestAskHandler_WhitespaceKbIDFiltered(t *testing.T) {
}
}
// ---- SSE helper direct tests ----
func TestSseAnswer_Final(t *testing.T) {