diff --git a/internal/dao/api_token.go b/internal/dao/api_token.go index 67cde37b37..b41dec6216 100644 --- a/internal/dao/api_token.go +++ b/internal/dao/api_token.go @@ -131,6 +131,13 @@ func (dao *API4ConversationDAO) GetBySessionID(sessionID, agentID string) (*enti return &result, nil } +// ListIDsByAgentID lists conversation IDs for one agent. +func (dao *API4ConversationDAO) ListIDsByAgentID(agentID string) ([]string, error) { + var ids []string + err := DB.Model(&entity.API4Conversation{}).Where("dialog_id = ?", agentID).Pluck("id", &ids).Error + return ids, err +} + // DeleteBySessionIDAndAgentID deletes API4Conversations by sessionID and agentID func (dao *API4ConversationDAO) DeleteBySessionIDAndAgentID(sessionID, agentID string) (int64, error) { result := DB.Where("id = ? AND dialog_id = ?", sessionID, agentID).Delete(&entity.API4Conversation{}) diff --git a/internal/dao/api_token_test.go b/internal/dao/api_token_test.go index 180a1138b1..87eeef5cfa 100644 --- a/internal/dao/api_token_test.go +++ b/internal/dao/api_token_test.go @@ -18,6 +18,7 @@ package dao import ( "encoding/json" + "sort" "testing" "github.com/glebarez/sqlite" @@ -106,3 +107,43 @@ func TestAPI4ConversationDAOGetBySessionIDNoRows(t *testing.T) { t.Fatalf("expected nil for missing session, got %+v", session) } } + +func TestAPI4ConversationDAOListIDsByAgentID(t *testing.T) { + db := setupAPI4ConversationTestDB(t) + pushDB(t, db) + + createAPI4ConversationForDAOTest(t, "session-1", "agent-1") + createAPI4ConversationForDAOTest(t, "session-2", "agent-1") + createAPI4ConversationForDAOTest(t, "session-other", "agent-2") + + ids, err := NewAPI4ConversationDAO().ListIDsByAgentID("agent-1") + if err != nil { + t.Fatalf("ListIDsByAgentID failed: %v", err) + } + + sort.Strings(ids) + want := []string{"session-1", "session-2"} + if len(ids) != len(want) { + t.Fatalf("expected %d ids, got %d: %v", len(want), len(ids), ids) + } + for i := range want { + if ids[i] != want[i] { + t.Fatalf("expected ids %v, got %v", want, ids) + } + } +} + +func TestAPI4ConversationDAOListIDsByAgentIDNoRows(t *testing.T) { + db := setupAPI4ConversationTestDB(t) + pushDB(t, db) + + createAPI4ConversationForDAOTest(t, "session-1", "agent-1") + + ids, err := NewAPI4ConversationDAO().ListIDsByAgentID("agent-2") + if err != nil { + t.Fatalf("ListIDsByAgentID failed: %v", err) + } + if len(ids) != 0 { + t.Fatalf("expected empty ids for missing agent, got %v", ids) + } +} diff --git a/internal/handler/agent.go b/internal/handler/agent.go index 97c7a8651b..bb223bbd6a 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -19,6 +19,7 @@ package handler import ( "errors" "fmt" + "io" "mime/multipart" "net/http" "strconv" @@ -307,6 +308,89 @@ func (h *AgentHandler) DeleteAgentSessionItem(c *gin.Context) { }) } +type deleteAgentSessionsRequest struct { + IDs []string `json:"ids"` + DeleteAll bool `json:"delete_all,omitempty"` +} + +func (h *AgentHandler) DeleteAgentSessions(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + agentID := strings.TrimSpace(c.Param("agent_id")) + if agentID == "" { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeOperatingError, + "data": nil, + "message": "agent_id is required", + }) + return + } + + var req deleteAgentSessionsRequest + if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "data": false, + "message": err.Error(), + }) + return + } + + result, code, err := h.agentService.DeleteAgentSessions(strings.TrimSpace(user.ID), agentID, req.IDs, req.DeleteAll) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + }) + return + } + + response := gin.H{"code": common.CodeSuccess} + if result != nil && result.Data != nil { + response["data"] = result.Data + } + if result != nil && result.Message != "" { + response["message"] = result.Message + } + c.JSON(http.StatusOK, response) +} + +// TestDBConnection Test DB connection +func (h *AgentHandler) TestDBConnection(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + var req service.TestDBConnectionRequest + if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + }) + return + } + + code, err := h.agentService.TestDBConnection(user.ID, &req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": "success", + }) +} + // ListAgentVersions returns versions for a specific agent. // @Summary List Agent Versions // @Description Returns all versions for a specific agent, ordered by update_time DESC. diff --git a/internal/handler/agent_test.go b/internal/handler/agent_test.go index 99097c6902..92a217b18a 100644 --- a/internal/handler/agent_test.go +++ b/internal/handler/agent_test.go @@ -508,6 +508,215 @@ func TestDeleteAgentSessionItemHandlerIgnoresSessionFromAnotherAgent(t *testing. } } +func TestDeleteAgentSessionsHandlerDeletesDuplicateIDsPartially(t *testing.T) { + c, w, db := setupGinContextWithUserAndDB(t, http.MethodDelete, "/api/v1/agents/canvas-1/sessions") + c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/agents/canvas-1/sessions", strings.NewReader(`{"ids":["session-1","session-1"]}`)) + c.Request.Header.Set("Content-Type", "application/json") + c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}} + + db.Create(&entity.UserCanvas{ + ID: "canvas-1", + UserID: "user-1", + Title: sptr("Test Agent"), + }) + db.Create(&entity.API4Conversation{ + ID: "session-1", + DialogID: "canvas-1", + UserID: "user-1", + Message: json.RawMessage(`[]`), + Reference: json.RawMessage(`[]`), + }) + + h := NewAgentHandler(service.NewAgentService(), nil) + h.DeleteAgentSessions(c) + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse response: %v", err) + } + if resp["code"] != float64(common.CodeSuccess) { + t.Fatalf("expected code %d, got %v: %v", common.CodeSuccess, resp["code"], resp["message"]) + } + data, ok := resp["data"].(map[string]interface{}) + if !ok { + t.Fatalf("expected partial data object, got %T", resp["data"]) + } + if data["success_count"] != float64(1) { + t.Fatalf("expected success_count 1, got %v", data["success_count"]) + } + errorsList, ok := data["errors"].([]interface{}) + if !ok || len(errorsList) != 1 { + t.Fatalf("expected one duplicate error, got %v", data["errors"]) + } + if errorsList[0] != "Duplicate session ids: session-1" { + t.Fatalf("unexpected duplicate error: %v", errorsList[0]) + } + + var count int64 + if err := db.Model(&entity.API4Conversation{}).Where("id = ?", "session-1").Count(&count).Error; err != nil { + t.Fatalf("failed to count deleted session: %v", err) + } + if count != 0 { + t.Fatalf("expected session-1 to be deleted, count=%d", count) + } +} + +func TestDeleteAgentSessionsHandlerDeleteAll(t *testing.T) { + c, w, db := setupGinContextWithUserAndDB(t, http.MethodDelete, "/api/v1/agents/canvas-1/sessions") + c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/agents/canvas-1/sessions", strings.NewReader(`{"delete_all":true}`)) + c.Request.Header.Set("Content-Type", "application/json") + c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}} + + db.Create(&entity.UserCanvas{ + ID: "canvas-1", + UserID: "user-1", + Title: sptr("Test Agent"), + }) + db.Create(&entity.API4Conversation{ + ID: "session-1", + DialogID: "canvas-1", + UserID: "user-1", + Message: json.RawMessage(`[]`), + Reference: json.RawMessage(`[]`), + }) + db.Create(&entity.API4Conversation{ + ID: "session-2", + DialogID: "canvas-1", + UserID: "user-1", + Message: json.RawMessage(`[]`), + Reference: json.RawMessage(`[]`), + }) + db.Create(&entity.API4Conversation{ + ID: "session-other", + DialogID: "canvas-other", + UserID: "user-1", + Message: json.RawMessage(`[]`), + Reference: json.RawMessage(`[]`), + }) + + h := NewAgentHandler(service.NewAgentService(), nil) + h.DeleteAgentSessions(c) + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse response: %v", err) + } + if resp["code"] != float64(common.CodeSuccess) { + t.Fatalf("expected code %d, got %v: %v", common.CodeSuccess, resp["code"], resp["message"]) + } + + var ownCount int64 + if err := db.Model(&entity.API4Conversation{}).Where("dialog_id = ?", "canvas-1").Count(&ownCount).Error; err != nil { + t.Fatalf("failed to count own sessions: %v", err) + } + if ownCount != 0 { + t.Fatalf("expected all canvas-1 sessions to be deleted, count=%d", ownCount) + } + + var otherCount int64 + if err := db.Model(&entity.API4Conversation{}).Where("id = ?", "session-other").Count(&otherCount).Error; err != nil { + t.Fatalf("failed to count other session: %v", err) + } + if otherCount != 1 { + t.Fatalf("expected other agent session to remain, count=%d", otherCount) + } +} + +func TestDeleteAgentSessionsHandlerRequiresOwner(t *testing.T) { + c, w, db := setupGinContextWithUserAndDB(t, http.MethodDelete, "/api/v1/agents/canvas-1/sessions") + c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/agents/canvas-1/sessions", strings.NewReader(`{"ids":["session-1"]}`)) + c.Request.Header.Set("Content-Type", "application/json") + c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}} + + db.Create(&entity.UserCanvas{ + ID: "canvas-1", + UserID: "user-2", + Permission: "team", + Title: sptr("Team Agent"), + }) + db.Create(&entity.API4Conversation{ + ID: "session-1", + DialogID: "canvas-1", + UserID: "user-1", + Message: json.RawMessage(`[]`), + Reference: json.RawMessage(`[]`), + }) + + h := NewAgentHandler(service.NewAgentService(), nil) + h.DeleteAgentSessions(c) + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse response: %v", err) + } + if resp["code"] != float64(common.CodeDataError) { + t.Fatalf("expected code %d, got %v: %v", common.CodeDataError, resp["code"], resp["message"]) + } + + var count int64 + if err := db.Model(&entity.API4Conversation{}).Where("id = ?", "session-1").Count(&count).Error; err != nil { + t.Fatalf("failed to count session: %v", err) + } + if count != 1 { + t.Fatalf("expected session to remain, count=%d", count) + } +} + +func TestTestDBConnectionHandlerMissingFields(t *testing.T) { + c, w, _ := setupGinContextWithUserAndDB(t, http.MethodPost, "/api/v1/agents/test_db_connection") + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/agents/test_db_connection", strings.NewReader(`{"db_type":"mysql"}`)) + c.Request.Header.Set("Content-Type", "application/json") + + h := NewAgentHandler(service.NewAgentService(), nil) + h.TestDBConnection(c) + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse response: %v", err) + } + if resp["code"] != float64(common.CodeArgumentError) { + t.Fatalf("expected code %d, got %v: %v", common.CodeArgumentError, resp["code"], resp["message"]) + } + if resp["data"] != nil { + t.Fatalf("expected nil data, got %v", resp["data"]) + } + want := "required argument are missing: database,username,host,port,password; " + if resp["message"] != want { + t.Fatalf("expected message %q, got %v", want, resp["message"]) + } +} + +func TestTestDBConnectionHandlerRejectsLocalhost(t *testing.T) { + c, w, _ := setupGinContextWithUserAndDB(t, http.MethodPost, "/api/v1/agents/test_db_connection") + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/agents/test_db_connection", strings.NewReader(`{ + "db_type":"mysql", + "database":"rag_flow", + "username":"root", + "host":"localhost", + "port":3306, + "password":"infini_rag_flow" + }`)) + c.Request.Header.Set("Content-Type", "application/json") + + h := NewAgentHandler(service.NewAgentService(), nil) + h.TestDBConnection(c) + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse response: %v", err) + } + if resp["code"] != float64(common.CodeDataError) { + t.Fatalf("expected code %d, got %v: %v", common.CodeDataError, resp["code"], resp["message"]) + } + if resp["data"] != nil { + t.Fatalf("expected nil data, got %v", resp["data"]) + } + message, ok := resp["message"].(string) + if !ok || !strings.Contains(message, "non-public address") { + t.Fatalf("expected non-public host message, got %v", resp["message"]) + } +} + func TestUpdateAgentTagsHandlerSuccess(t *testing.T) { c, w, db := setupGinContextWithUserAndDB(t, http.MethodPut, "/api/v1/agents/canvas-1/tags") c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/agents/canvas-1/tags", strings.NewReader(`{"tags":["alpha","beta","alpha"]}`)) diff --git a/internal/router/router.go b/internal/router/router.go index 98b02fdba7..d8ac7f0bdd 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -388,6 +388,7 @@ func (r *Router) Setup(engine *gin.Engine) { agents.GET("", r.agentHandler.ListAgents) agents.GET("/prompts", r.agentHandler.GetPrompts) agents.GET("/templates", r.agentHandler.ListTemplates) + agents.POST("/test_db_connection", r.agentHandler.TestDBConnection) agents.GET("/:agent_id/versions", r.agentHandler.ListAgentVersions) agents.GET("/:agent_id/versions/:version_id", r.agentHandler.GetAgentVersion) agents.POST("/:agent_id/upload", r.agentHandler.UploadAgentFile) @@ -395,6 +396,7 @@ func (r *Router) Setup(engine *gin.Engine) { agents.GET("/:agent_id/sessions", r.agentHandler.ListAgentSessions) agents.GET("/:agent_id/sessions/:session_id", r.agentHandler.GetAgentSession) agents.DELETE("/:agent_id/sessions/:session_id", r.agentHandler.DeleteAgentSessionItem) + agents.DELETE("/:agent_id/sessions", r.agentHandler.DeleteAgentSessions) } // Plugin routes diff --git a/internal/service/agent.go b/internal/service/agent.go index 6694403a22..b066d7f41e 100644 --- a/internal/service/agent.go +++ b/internal/service/agent.go @@ -17,14 +17,21 @@ package service import ( + "context" + "database/sql" "encoding/json" "errors" "fmt" + "net" + "net/netip" "sort" "strconv" "strings" "time" + "github.com/go-sql-driver/mysql" + "go.uber.org/zap" + "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/entity" @@ -163,6 +170,25 @@ type ListAgentSessionsResponse struct { Total int64 `json:"total"` } +type DeleteAgentSessionsResult struct { + Data *DeleteAgentSessionsResponse + Message string +} + +type DeleteAgentSessionsResponse struct { + SuccessCount int `json:"success_count"` + Errors []string `json:"errors,omitempty"` +} + +type TestDBConnectionRequest struct { + DBType string `json:"db_type"` + Database string `json:"database"` + Username string `json:"username"` + Host string `json:"host"` + Port interface{} `json:"port"` + Password string `json:"password"` +} + func parseAgentSessionDate(value string, isEnd bool) (*time.Time, error) { if value == "" { return nil, nil @@ -370,6 +396,29 @@ func firstNonNil(values ...interface{}) interface{} { return nil } +// checkDuplicateSessionIDs check duplicated ID in IDS and add some messages for it +func checkDuplicateSessionIDs(ids []string) ([]string, []string) { + seen := make(map[string]int, len(ids)) + uniqueIDs := make([]string, 0, len(ids)) + + for _, id := range ids { + id = strings.TrimSpace(id) + seen[id]++ + if seen[id] == 1 { + uniqueIDs = append(uniqueIDs, id) + } + } + + duplicateMessages := make([]string, 0) + for _, id := range uniqueIDs { + if seen[id] > 1 { + duplicateMessages = append(duplicateMessages, fmt.Sprintf("Duplicate session ids: %s", id)) + } + } + + return uniqueIDs, duplicateMessages +} + func (s *AgentService) ListAgentSessions(userID, tenantID, agentID string, req ListAgentSessionsRequest) (*ListAgentSessionsResponse, common.ErrorCode, error) { if agentID == "" { return nil, common.CodeArgumentError, errors.New("agent_id is required") @@ -474,6 +523,301 @@ func (s *AgentService) DeleteAgentSessionItem(userID, agentID, sessionID string) return true, common.CodeSuccess, nil } +// DeleteAgentSessions Delete sessions by ids +func (s *AgentService) DeleteAgentSessions(userID, agentID string, ids []string, deleteAll bool) (*DeleteAgentSessionsResult, common.ErrorCode, error) { + if agentID == "" { + return nil, common.CodeArgumentError, errors.New("agent_id is required") + } + + canvas, err := s.canvasDAO.GetByID(agentID) + if err != nil || canvas == nil || canvas.UserID != userID { + return nil, common.CodeDataError, fmt.Errorf("You don't own the agent %s", agentID) + } + + if len(ids) == 0 { + if !deleteAll { + return &DeleteAgentSessionsResult{}, common.CodeSuccess, nil + } + + ids, err = s.api4ConversationDAO.ListIDsByAgentID(agentID) + if err != nil { + return nil, common.CodeServerError, err + } + if len(ids) == 0 { + return &DeleteAgentSessionsResult{}, common.CodeSuccess, nil + } + } + + sessionIDs, duplicateMessages := checkDuplicateSessionIDs(ids) + errorsList := make([]string, 0) + successCount := 0 + + for _, sessionID := range sessionIDs { + sessionID = strings.TrimSpace(sessionID) + if sessionID == "" { + errorsList = append(errorsList, "The agent doesn't own the session ") + continue + } + + conv, err := s.api4ConversationDAO.GetBySessionID(sessionID, agentID) + if err != nil { + return nil, common.CodeServerError, err + } + if conv == nil { + errorsList = append(errorsList, fmt.Sprintf("The agent doesn't own the session %s", sessionID)) + continue + } + + if _, err := s.api4ConversationDAO.DeleteBySessionIDAndAgentID(sessionID, agentID); err != nil { + return nil, common.CodeServerError, err + } + successCount++ + } + + if len(errorsList) > 0 { + if successCount > 0 { + return &DeleteAgentSessionsResult{ + Message: fmt.Sprintf("Partially deleted %d sessions with %d errors", successCount, len(errorsList)), + Data: &DeleteAgentSessionsResponse{ + SuccessCount: successCount, + Errors: errorsList, + }, + }, common.CodeSuccess, nil + } + return nil, common.CodeDataError, errors.New(strings.Join(errorsList, "; ")) + } + + if len(duplicateMessages) > 0 { + if successCount > 0 { + return &DeleteAgentSessionsResult{ + Message: fmt.Sprintf("Partially deleted %d sessions with %d errors", successCount, len(duplicateMessages)), + Data: &DeleteAgentSessionsResponse{ + SuccessCount: successCount, + Errors: duplicateMessages, + }, + }, common.CodeSuccess, nil + } + return nil, common.CodeDataError, errors.New(strings.Join(duplicateMessages, ";")) + } + + return &DeleteAgentSessionsResult{}, common.CodeSuccess, nil +} + +// AssertHostIsSafe checks whether host resolves only to public IP addresses. +func AssertHostIsSafe(host string) (string, error) { + host = strings.TrimSpace(host) + if host == "" { + return "", errors.New("Host must not be empty.") + } + + ips, err := net.LookupIP(host) + if err != nil { + zap.L().Warn("SSRF guard could not resolve host", + zap.String("host", host), + zap.Error(err), + ) + return "", fmt.Errorf("Could not resolve host %q: %w", host, err) + } + + if len(ips) == 0 { + zap.L().Warn("SSRF guard blocked host: resolved to no addresses", + zap.String("host", host), + ) + return "", fmt.Errorf("Host %q resolved to no addresses.", host) + } + + var resolvedIP string + + for _, ip := range ips { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return "", fmt.Errorf("invalid resolved IP %q for host %q", ip.String(), host) + } + + // Normalize IPv4-mapped IPv6, equivalent to Python _effective_ip(). + addr = addr.Unmap() + + if !isPublicAddr(addr) { + zap.L().Warn("SSRF guard blocked host", + zap.String("host", host), + zap.String("resolved_ip", addr.String()), + ) + return "", fmt.Errorf("Host resolves to a non-public address (%s), which is not allowed.", addr.String()) + } + + if resolvedIP == "" { + resolvedIP = addr.String() + } + } + + if resolvedIP == "" { + return "", fmt.Errorf("Host %q resolved to no addresses.", host) + } + + return resolvedIP, nil +} + +func isPublicAddr(addr netip.Addr) bool { + addr = addr.Unmap() + + if !addr.IsValid() { + return false + } + + if !addr.IsGlobalUnicast() { + return false + } + + if addr.IsPrivate() || + addr.IsLoopback() || + addr.IsLinkLocalUnicast() || + addr.IsLinkLocalMulticast() || + addr.IsMulticast() || + addr.IsUnspecified() { + return false + } + + return !isSpecialUseAddr(addr) +} + +func isSpecialUseAddr(addr netip.Addr) bool { + addr = addr.Unmap() + + specialCIDRs := []string{ + // IPv4 special-use / documentation / reserved ranges. + "0.0.0.0/8", + "100.64.0.0/10", + "127.0.0.0/8", + "169.254.0.0/16", + "192.0.0.0/24", + "192.0.2.0/24", + "198.18.0.0/15", + "198.51.100.0/24", + "203.0.113.0/24", + "224.0.0.0/4", + "240.0.0.0/4", + + // IPv6 special-use / documentation / local ranges. + "::/128", + "::1/128", + "64:ff9b:1::/48", + "100::/64", + "2001::/23", + "2001:2::/48", + "fc00::/7", + "fe80::/10", + "ff00::/8", + "2001:db8::/32", + "2002::/16", + } + + for _, cidr := range specialCIDRs { + prefix := netip.MustParsePrefix(cidr) + if prefix.Contains(addr) { + return true + } + } + + return false +} + +// missingDBConnectionFields Check if request is missing something +func missingDBConnectionFields(req *TestDBConnectionRequest) []string { + missing := make([]string, 0, 6) + if req == nil || strings.TrimSpace(req.DBType) == "" { + missing = append(missing, "db_type") + } + if req == nil || strings.TrimSpace(req.Database) == "" { + missing = append(missing, "database") + } + if req == nil || strings.TrimSpace(req.Username) == "" { + missing = append(missing, "username") + } + if req == nil || strings.TrimSpace(req.Host) == "" { + missing = append(missing, "host") + } + if req == nil || dbConnectionPort(req.Port) == "" { + missing = append(missing, "port") + } + if req == nil || req.Password == "" { + missing = append(missing, "password") + } + return missing +} + +func dbConnectionPort(port interface{}) string { + switch value := port.(type) { + case nil: + return "" + case string: + return strings.TrimSpace(value) + case float64: + return strconv.Itoa(int(value)) + case float32: + return strconv.Itoa(int(value)) + case int: + return strconv.Itoa(value) + case int64: + return strconv.FormatInt(value, 10) + case json.Number: + return value.String() + default: + return strings.TrimSpace(fmt.Sprint(value)) + } +} + +func (s *AgentService) TestDBConnection(userID string, req *TestDBConnectionRequest) (common.ErrorCode, error) { + if missing := missingDBConnectionFields(req); len(missing) > 0 { + return common.CodeArgumentError, fmt.Errorf("required argument are missing: %s; ", strings.Join(missing, ",")) + } + + safeHost, err := AssertHostIsSafe(req.Host) + if err != nil { + zap.L().Warn( + "Rejected test_db_connection: unsafe host", + zap.String("host", req.Host), + zap.String("db_type", req.DBType), + zap.String("user", userID), + zap.Error(err), + ) + return common.CodeDataError, err + } + + switch req.DBType { + case "mysql", "mariadb", "oceanbase": + port := dbConnectionPort(req.Port) + dbProbeTimeout := 5 * time.Second + config := mysql.Config{ + User: req.Username, + Passwd: req.Password, + Net: "tcp", + Addr: net.JoinHostPort(safeHost, port), + DBName: req.Database, + Timeout: dbProbeTimeout, + AllowNativePasswords: true, + } + db, err := sql.Open("mysql", config.FormatDSN()) + if err != nil { + return common.CodeExceptionError, err + } + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), dbProbeTimeout) + defer cancel() + + if err := db.PingContext(ctx); err != nil { + return common.CodeExceptionError, err + } + if _, err := db.ExecContext(ctx, "SELECT 1"); err != nil { + return common.CodeExceptionError, err + } + default: + return common.CodeExceptionError, errors.New("Unsupported database type.") + } + + return common.CodeSuccess, nil +} + // normalizeAgentTags returns an error for unsupported tag payload types func normalizeAgentTags(rawTags interface{}) (string, error) { cleaned := make([]string, 0) diff --git a/internal/service/agent_test.go b/internal/service/agent_test.go index 07ba83a25f..50be627546 100644 --- a/internal/service/agent_test.go +++ b/internal/service/agent_test.go @@ -17,9 +17,13 @@ package service import ( + "encoding/json" + "net/netip" + "strings" "testing" "time" + "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/entity" ) @@ -344,5 +348,486 @@ func TestGetVersion_NotFound(t *testing.T) { } } +func setupAgentSessionServiceTest(t *testing.T) { + t.Helper() + + testDB := setupServiceTestDB(t) + if err := testDB.AutoMigrate( + &entity.User{}, + &entity.UserCanvas{}, + &entity.UserTenant{}, + &entity.API4Conversation{}, + ); err != nil { + t.Fatalf("failed to migrate: %v", err) + } + + orig := dao.DB + dao.DB = testDB + t.Cleanup(func() { dao.DB = orig }) +} + +func createAgentSessionTestCanvas(t *testing.T, id, userID string) { + t.Helper() + if err := dao.DB.Create(&entity.UserCanvas{ + ID: id, + UserID: userID, + Title: sptr("Test Agent"), + CanvasCategory: "agent_canvas", + }).Error; err != nil { + t.Fatalf("failed to create canvas %s: %v", id, err) + } +} + +func createAgentSessionTestConversation(t *testing.T, id, agentID, userID string, updateTime int64) { + t.Helper() + updateDate := time.UnixMilli(updateTime) + if err := dao.DB.Create(&entity.API4Conversation{ + ID: id, + DialogID: agentID, + UserID: userID, + Message: json.RawMessage(`[{"role":"assistant","content":"hello","prompt":"hidden"}]`), + Reference: json.RawMessage(`[]`), + BaseModel: entity.BaseModel{ + CreateTime: ptr(updateTime), + CreateDate: &updateDate, + UpdateTime: ptr(updateTime), + UpdateDate: &updateDate, + }, + }).Error; err != nil { + t.Fatalf("failed to create session %s: %v", id, err) + } +} + +func TestListAgentSessionsServiceSuccess(t *testing.T) { + setupAgentSessionServiceTest(t) + + createAgentSessionTestCanvas(t, "canvas-1", "user-1") + createAgentSessionTestConversation(t, "session-old", "canvas-1", "user-1", 1000) + createAgentSessionTestConversation(t, "session-new", "canvas-1", "user-1", 3000) + createAgentSessionTestConversation(t, "session-other-agent", "canvas-other", "user-1", 9999) + + resp, code, err := NewAgentService().ListAgentSessions("user-1", "user-1", "canvas-1", ListAgentSessionsRequest{ + Page: 1, + PageSize: 10, + OrderBy: "update_time", + Desc: true, + IncludeDSL: false, + }) + if err != nil { + t.Fatalf("ListAgentSessions failed: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("expected code %d, got %d", common.CodeSuccess, code) + } + if resp.Total != 2 { + t.Fatalf("expected total 2, got %d", resp.Total) + } + if len(resp.Data) != 2 { + t.Fatalf("expected 2 sessions, got %d", len(resp.Data)) + } + if resp.Data[0]["id"] != "session-new" { + t.Fatalf("expected newest session first, got %v", resp.Data[0]["id"]) + } + if resp.Data[0]["agent_id"] != "canvas-1" { + t.Fatalf("expected agent_id canvas-1, got %v", resp.Data[0]["agent_id"]) + } + messages, ok := resp.Data[0]["message"].([]map[string]interface{}) + if !ok || len(messages) != 1 { + t.Fatalf("expected normalized message slice, got %T %v", resp.Data[0]["message"], resp.Data[0]["message"]) + } + if _, ok := messages[0]["prompt"]; ok { + t.Fatal("expected prompt to be removed from normalized session message") + } +} + +func TestListAgentSessionsServiceDenied(t *testing.T) { + setupAgentSessionServiceTest(t) + + createAgentSessionTestCanvas(t, "canvas-1", "user-2") + + resp, code, err := NewAgentService().ListAgentSessions("user-1", "user-1", "canvas-1", ListAgentSessionsRequest{}) + if err == nil { + t.Fatal("expected permission error") + } + if code != common.CodeOperatingError { + t.Fatalf("expected code %d, got %d", common.CodeOperatingError, code) + } + if resp != nil { + t.Fatalf("expected nil response, got %+v", resp) + } +} + +func TestGetAgentSessionServiceSuccess(t *testing.T) { + setupAgentSessionServiceTest(t) + + createAgentSessionTestCanvas(t, "canvas-1", "user-1") + createAgentSessionTestConversation(t, "session-1", "canvas-1", "user-1", 1000) + + session, code, err := NewAgentService().GetAgentSession("user-1", "canvas-1", "session-1") + if err != nil { + t.Fatalf("GetAgentSession failed: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("expected code %d, got %d", common.CodeSuccess, code) + } + if session == nil { + t.Fatal("expected session, got nil") + } + if session.ID != "session-1" { + t.Fatalf("expected session-1, got %s", session.ID) + } + if session.DialogID != "canvas-1" { + t.Fatalf("expected dialog_id canvas-1, got %s", session.DialogID) + } +} + +func TestGetAgentSessionServiceNotFoundWhenSessionBelongsToAnotherAgent(t *testing.T) { + setupAgentSessionServiceTest(t) + + createAgentSessionTestCanvas(t, "canvas-1", "user-1") + createAgentSessionTestConversation(t, "session-other", "canvas-other", "user-1", 1000) + + session, code, err := NewAgentService().GetAgentSession("user-1", "canvas-1", "session-other") + if err == nil { + t.Fatal("expected not found error") + } + if code != common.CodeNotFound { + t.Fatalf("expected code %d, got %d", common.CodeNotFound, code) + } + if session != nil { + t.Fatalf("expected nil session, got %+v", session) + } +} + +func TestDeleteAgentSessionItemServiceDeletesMatchingSession(t *testing.T) { + setupAgentSessionServiceTest(t) + + createAgentSessionTestCanvas(t, "canvas-1", "user-1") + createAgentSessionTestConversation(t, "session-1", "canvas-1", "user-1", 1000) + createAgentSessionTestConversation(t, "session-other", "canvas-other", "user-1", 2000) + + deleted, code, err := NewAgentService().DeleteAgentSessionItem("user-1", "canvas-1", "session-1") + if err != nil { + t.Fatalf("DeleteAgentSessionItem failed: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("expected code %d, got %d", common.CodeSuccess, code) + } + if !deleted { + t.Fatal("expected session to be deleted") + } + + var count int64 + if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-1").Count(&count).Error; err != nil { + t.Fatalf("failed to count deleted session: %v", err) + } + if count != 0 { + t.Fatalf("expected session-1 to be deleted, count=%d", count) + } + + if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-other").Count(&count).Error; err != nil { + t.Fatalf("failed to count other session: %v", err) + } + if count != 1 { + t.Fatalf("expected other agent session to remain, count=%d", count) + } +} + +func TestDeleteAgentSessionItemServiceNoopForSessionFromAnotherAgent(t *testing.T) { + setupAgentSessionServiceTest(t) + + createAgentSessionTestCanvas(t, "canvas-1", "user-1") + createAgentSessionTestConversation(t, "session-other", "canvas-other", "user-1", 1000) + + deleted, code, err := NewAgentService().DeleteAgentSessionItem("user-1", "canvas-1", "session-other") + if err != nil { + t.Fatalf("DeleteAgentSessionItem failed: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("expected code %d, got %d", common.CodeSuccess, code) + } + if deleted { + t.Fatal("expected cross-agent session delete to be a noop") + } + + var count int64 + if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-other").Count(&count).Error; err != nil { + t.Fatalf("failed to count other session: %v", err) + } + if count != 1 { + t.Fatalf("expected cross-agent session to remain, count=%d", count) + } +} + +func TestDeleteAgentSessionsServiceDeleteAll(t *testing.T) { + setupAgentSessionServiceTest(t) + + createAgentSessionTestCanvas(t, "canvas-1", "user-1") + createAgentSessionTestConversation(t, "session-1", "canvas-1", "user-1", 1000) + createAgentSessionTestConversation(t, "session-2", "canvas-1", "user-1", 2000) + createAgentSessionTestConversation(t, "session-other", "canvas-other", "user-1", 3000) + + result, code, err := NewAgentService().DeleteAgentSessions("user-1", "canvas-1", nil, true) + if err != nil { + t.Fatalf("DeleteAgentSessions failed: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("expected code %d, got %d", common.CodeSuccess, code) + } + if result == nil { + t.Fatal("expected non-nil result") + } + if result.Data != nil { + t.Fatalf("expected no partial data on full success, got %+v", result.Data) + } + + var ownCount int64 + if err := dao.DB.Model(&entity.API4Conversation{}).Where("dialog_id = ?", "canvas-1").Count(&ownCount).Error; err != nil { + t.Fatalf("failed to count own sessions: %v", err) + } + if ownCount != 0 { + t.Fatalf("expected all canvas-1 sessions to be deleted, count=%d", ownCount) + } + + var otherCount int64 + if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-other").Count(&otherCount).Error; err != nil { + t.Fatalf("failed to count other session: %v", err) + } + if otherCount != 1 { + t.Fatalf("expected other agent session to remain, count=%d", otherCount) + } +} + +func TestDeleteAgentSessionsServiceDuplicateIDsPartial(t *testing.T) { + setupAgentSessionServiceTest(t) + + createAgentSessionTestCanvas(t, "canvas-1", "user-1") + createAgentSessionTestConversation(t, "session-1", "canvas-1", "user-1", 1000) + + result, code, err := NewAgentService().DeleteAgentSessions("user-1", "canvas-1", []string{"session-1", "session-1"}, false) + if err != nil { + t.Fatalf("DeleteAgentSessions failed: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("expected code %d, got %d", common.CodeSuccess, code) + } + if result == nil || result.Data == nil { + t.Fatalf("expected partial result data, got %+v", result) + } + if result.Data.SuccessCount != 1 { + t.Fatalf("expected success_count 1, got %d", result.Data.SuccessCount) + } + if len(result.Data.Errors) != 1 || result.Data.Errors[0] != "Duplicate session ids: session-1" { + t.Fatalf("unexpected duplicate errors: %v", result.Data.Errors) + } + + var count int64 + if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-1").Count(&count).Error; err != nil { + t.Fatalf("failed to count deleted session: %v", err) + } + if count != 0 { + t.Fatalf("expected session-1 to be deleted, count=%d", count) + } +} + +func TestDeleteAgentSessionsServiceMissingSessionError(t *testing.T) { + setupAgentSessionServiceTest(t) + + createAgentSessionTestCanvas(t, "canvas-1", "user-1") + + result, code, err := NewAgentService().DeleteAgentSessions("user-1", "canvas-1", []string{"missing-session"}, false) + if err == nil { + t.Fatal("expected missing session error") + } + if code != common.CodeDataError { + t.Fatalf("expected code %d, got %d", common.CodeDataError, code) + } + if result != nil { + t.Fatalf("expected nil result, got %+v", result) + } + if !strings.Contains(err.Error(), "The agent doesn't own the session missing-session") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestDeleteAgentSessionsServiceRequiresOwner(t *testing.T) { + setupAgentSessionServiceTest(t) + + createAgentSessionTestCanvas(t, "canvas-1", "user-2") + createAgentSessionTestConversation(t, "session-1", "canvas-1", "user-1", 1000) + + result, code, err := NewAgentService().DeleteAgentSessions("user-1", "canvas-1", []string{"session-1"}, false) + if err == nil { + t.Fatal("expected owner error") + } + if code != common.CodeDataError { + t.Fatalf("expected code %d, got %d", common.CodeDataError, code) + } + if result != nil { + t.Fatalf("expected nil result, got %+v", result) + } + + var count int64 + if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-1").Count(&count).Error; err != nil { + t.Fatalf("failed to count session: %v", err) + } + if count != 1 { + t.Fatalf("expected session to remain, count=%d", count) + } +} + +func TestUpdateAgentTagsServiceSuccess(t *testing.T) { + setupAgentSessionServiceTest(t) + + createAgentSessionTestCanvas(t, "canvas-1", "user-1") + + ok, code, err := NewAgentService().UpdateAgentTags("user-1", "canvas-1", []interface{}{"alpha", "beta", "alpha", "with,comma"}) + if err != nil { + t.Fatalf("UpdateAgentTags failed: %v", err) + } + if code != common.CodeSuccess { + t.Fatalf("expected code %d, got %d", common.CodeSuccess, code) + } + if !ok { + t.Fatal("expected update to succeed") + } + + canvas, err := dao.NewUserCanvasDAO().GetByID("canvas-1") + if err != nil { + t.Fatalf("failed to get canvas: %v", err) + } + if canvas.Tags != "alpha,beta,with comma" { + t.Fatalf("expected normalized tags, got %q", canvas.Tags) + } +} + +func TestUpdateAgentTagsServiceInvalidPayload(t *testing.T) { + setupAgentSessionServiceTest(t) + + createAgentSessionTestCanvas(t, "canvas-1", "user-1") + + ok, code, err := NewAgentService().UpdateAgentTags("user-1", "canvas-1", map[string]string{"tag": "alpha"}) + if err == nil { + t.Fatal("expected invalid tags error") + } + if code != common.CodeBadRequest { + t.Fatalf("expected code %d, got %d", common.CodeBadRequest, code) + } + if ok { + t.Fatal("expected update to fail") + } + + canvas, err := dao.NewUserCanvasDAO().GetByID("canvas-1") + if err != nil { + t.Fatalf("failed to get canvas: %v", err) + } + if canvas.Tags != "" { + t.Fatalf("expected tags to remain unchanged, got %q", canvas.Tags) + } +} + +func TestUpdateAgentTagsServiceNoPermission(t *testing.T) { + setupAgentSessionServiceTest(t) + + createAgentSessionTestCanvas(t, "canvas-1", "user-2") + + ok, code, err := NewAgentService().UpdateAgentTags("user-1", "canvas-1", []string{"alpha"}) + if err == nil { + t.Fatal("expected permission error") + } + if code != common.CodeOperatingError { + t.Fatalf("expected code %d, got %d", common.CodeOperatingError, code) + } + if ok { + t.Fatal("expected update to fail") + } + + canvas, err := dao.NewUserCanvasDAO().GetByID("canvas-1") + if err != nil { + t.Fatalf("failed to get canvas: %v", err) + } + if canvas.Tags != "" { + t.Fatalf("expected tags to remain unchanged, got %q", canvas.Tags) + } +} + +func TestIsPublicAddr(t *testing.T) { + tests := []struct { + name string + addr string + want bool + }{ + {name: "public IPv4", addr: "8.8.8.8", want: true}, + {name: "loopback", addr: "127.0.0.1", want: false}, + {name: "private", addr: "192.168.1.1", want: false}, + {name: "carrier NAT", addr: "100.64.0.1", want: false}, + {name: "documentation", addr: "203.0.113.1", want: false}, + {name: "IPv4 mapped loopback", addr: "::ffff:127.0.0.1", want: false}, + {name: "IPv6 documentation", addr: "2001:db8::1", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isPublicAddr(netip.MustParseAddr(tt.addr)) + if got != tt.want { + t.Fatalf("isPublicAddr(%s): expected %v, got %v", tt.addr, tt.want, got) + } + }) + } +} + +func TestAssertHostIsSafeRejectsLocalhost(t *testing.T) { + _, err := AssertHostIsSafe("localhost") + if err == nil { + t.Fatal("expected localhost to be rejected") + } + if !strings.Contains(err.Error(), "non-public address") { + t.Fatalf("expected non-public address error, got %v", err) + } +} + +func TestTestDBConnectionMissingFields(t *testing.T) { + code, err := NewAgentService().TestDBConnection("user-1", &TestDBConnectionRequest{DBType: "mysql"}) + if err == nil { + t.Fatal("expected missing field error") + } + if code != common.CodeArgumentError { + t.Fatalf("expected code %d, got %d", common.CodeArgumentError, code) + } + want := "required argument are missing: database,username,host,port,password; " + if err.Error() != want { + t.Fatalf("expected %q, got %q", want, err.Error()) + } +} + +func TestTestDBConnectionUnsupportedDatabaseType(t *testing.T) { + code, err := NewAgentService().TestDBConnection("user-1", &TestDBConnectionRequest{ + DBType: "postgres", + Database: "rag_flow", + Username: "root", + Host: "8.8.8.8", + Port: 5432, + Password: "password", + }) + if err == nil { + t.Fatal("expected unsupported database type error") + } + if code != common.CodeExceptionError { + t.Fatalf("expected code %d, got %d", common.CodeExceptionError, code) + } + if err.Error() != "Unsupported database type." { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestDBConnectionPortAcceptsStringAndNumber(t *testing.T) { + if got := dbConnectionPort("3306"); got != "3306" { + t.Fatalf("expected string port 3306, got %q", got) + } + if got := dbConnectionPort(float64(3306)); got != "3306" { + t.Fatalf("expected numeric port 3306, got %q", got) + } +} + // ptr returns a pointer to the given int64. -func ptr(v int64) *int64 { return &v } \ No newline at end of file +func ptr(v int64) *int64 { return &v }