feat[Go]: implement /api/v1/agents/<agent_id> and test_db_connection (#15771)

### What problem does this PR solve?

Add two API in go
```
/api/v1/agents/test_db_connection POST

/api/v1/agents/<agent_id>/sessions DELETE
```

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
This commit is contained in:
Haruko386
2026-06-10 09:54:07 +08:00
committed by GitHub
parent 87b8062df4
commit a396b1ace2
7 changed files with 1173 additions and 1 deletions

View File

@@ -131,6 +131,13 @@ func (dao *API4ConversationDAO) GetBySessionID(sessionID, agentID string) (*enti
return &result, nil
}
// ListIDsByAgentID lists conversation IDs for one agent.
func (dao *API4ConversationDAO) ListIDsByAgentID(agentID string) ([]string, error) {
var ids []string
err := DB.Model(&entity.API4Conversation{}).Where("dialog_id = ?", agentID).Pluck("id", &ids).Error
return ids, err
}
// DeleteBySessionIDAndAgentID deletes API4Conversations by sessionID and agentID
func (dao *API4ConversationDAO) DeleteBySessionIDAndAgentID(sessionID, agentID string) (int64, error) {
result := DB.Where("id = ? AND dialog_id = ?", sessionID, agentID).Delete(&entity.API4Conversation{})

View File

@@ -18,6 +18,7 @@ package dao
import (
"encoding/json"
"sort"
"testing"
"github.com/glebarez/sqlite"
@@ -106,3 +107,43 @@ func TestAPI4ConversationDAOGetBySessionIDNoRows(t *testing.T) {
t.Fatalf("expected nil for missing session, got %+v", session)
}
}
func TestAPI4ConversationDAOListIDsByAgentID(t *testing.T) {
db := setupAPI4ConversationTestDB(t)
pushDB(t, db)
createAPI4ConversationForDAOTest(t, "session-1", "agent-1")
createAPI4ConversationForDAOTest(t, "session-2", "agent-1")
createAPI4ConversationForDAOTest(t, "session-other", "agent-2")
ids, err := NewAPI4ConversationDAO().ListIDsByAgentID("agent-1")
if err != nil {
t.Fatalf("ListIDsByAgentID failed: %v", err)
}
sort.Strings(ids)
want := []string{"session-1", "session-2"}
if len(ids) != len(want) {
t.Fatalf("expected %d ids, got %d: %v", len(want), len(ids), ids)
}
for i := range want {
if ids[i] != want[i] {
t.Fatalf("expected ids %v, got %v", want, ids)
}
}
}
func TestAPI4ConversationDAOListIDsByAgentIDNoRows(t *testing.T) {
db := setupAPI4ConversationTestDB(t)
pushDB(t, db)
createAPI4ConversationForDAOTest(t, "session-1", "agent-1")
ids, err := NewAPI4ConversationDAO().ListIDsByAgentID("agent-2")
if err != nil {
t.Fatalf("ListIDsByAgentID failed: %v", err)
}
if len(ids) != 0 {
t.Fatalf("expected empty ids for missing agent, got %v", ids)
}
}

View File

@@ -19,6 +19,7 @@ package handler
import (
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"strconv"
@@ -307,6 +308,89 @@ func (h *AgentHandler) DeleteAgentSessionItem(c *gin.Context) {
})
}
type deleteAgentSessionsRequest struct {
IDs []string `json:"ids"`
DeleteAll bool `json:"delete_all,omitempty"`
}
func (h *AgentHandler) DeleteAgentSessions(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
agentID := strings.TrimSpace(c.Param("agent_id"))
if agentID == "" {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeOperatingError,
"data": nil,
"message": "agent_id is required",
})
return
}
var req deleteAgentSessionsRequest
if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeBadRequest,
"data": false,
"message": err.Error(),
})
return
}
result, code, err := h.agentService.DeleteAgentSessions(strings.TrimSpace(user.ID), agentID, req.IDs, req.DeleteAll)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": code,
"message": err.Error(),
})
return
}
response := gin.H{"code": common.CodeSuccess}
if result != nil && result.Data != nil {
response["data"] = result.Data
}
if result != nil && result.Message != "" {
response["message"] = result.Message
}
c.JSON(http.StatusOK, response)
}
// TestDBConnection Test DB connection
func (h *AgentHandler) TestDBConnection(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
var req service.TestDBConnectionRequest
if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeBadRequest,
"message": err.Error(),
})
return
}
code, err := h.agentService.TestDBConnection(user.ID, &req)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": code,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"message": "success",
})
}
// ListAgentVersions returns versions for a specific agent.
// @Summary List Agent Versions
// @Description Returns all versions for a specific agent, ordered by update_time DESC.

View File

@@ -508,6 +508,215 @@ func TestDeleteAgentSessionItemHandlerIgnoresSessionFromAnotherAgent(t *testing.
}
}
func TestDeleteAgentSessionsHandlerDeletesDuplicateIDsPartially(t *testing.T) {
c, w, db := setupGinContextWithUserAndDB(t, http.MethodDelete, "/api/v1/agents/canvas-1/sessions")
c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/agents/canvas-1/sessions", strings.NewReader(`{"ids":["session-1","session-1"]}`))
c.Request.Header.Set("Content-Type", "application/json")
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}}
db.Create(&entity.UserCanvas{
ID: "canvas-1",
UserID: "user-1",
Title: sptr("Test Agent"),
})
db.Create(&entity.API4Conversation{
ID: "session-1",
DialogID: "canvas-1",
UserID: "user-1",
Message: json.RawMessage(`[]`),
Reference: json.RawMessage(`[]`),
})
h := NewAgentHandler(service.NewAgentService(), nil)
h.DeleteAgentSessions(c)
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp["code"] != float64(common.CodeSuccess) {
t.Fatalf("expected code %d, got %v: %v", common.CodeSuccess, resp["code"], resp["message"])
}
data, ok := resp["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected partial data object, got %T", resp["data"])
}
if data["success_count"] != float64(1) {
t.Fatalf("expected success_count 1, got %v", data["success_count"])
}
errorsList, ok := data["errors"].([]interface{})
if !ok || len(errorsList) != 1 {
t.Fatalf("expected one duplicate error, got %v", data["errors"])
}
if errorsList[0] != "Duplicate session ids: session-1" {
t.Fatalf("unexpected duplicate error: %v", errorsList[0])
}
var count int64
if err := db.Model(&entity.API4Conversation{}).Where("id = ?", "session-1").Count(&count).Error; err != nil {
t.Fatalf("failed to count deleted session: %v", err)
}
if count != 0 {
t.Fatalf("expected session-1 to be deleted, count=%d", count)
}
}
func TestDeleteAgentSessionsHandlerDeleteAll(t *testing.T) {
c, w, db := setupGinContextWithUserAndDB(t, http.MethodDelete, "/api/v1/agents/canvas-1/sessions")
c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/agents/canvas-1/sessions", strings.NewReader(`{"delete_all":true}`))
c.Request.Header.Set("Content-Type", "application/json")
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}}
db.Create(&entity.UserCanvas{
ID: "canvas-1",
UserID: "user-1",
Title: sptr("Test Agent"),
})
db.Create(&entity.API4Conversation{
ID: "session-1",
DialogID: "canvas-1",
UserID: "user-1",
Message: json.RawMessage(`[]`),
Reference: json.RawMessage(`[]`),
})
db.Create(&entity.API4Conversation{
ID: "session-2",
DialogID: "canvas-1",
UserID: "user-1",
Message: json.RawMessage(`[]`),
Reference: json.RawMessage(`[]`),
})
db.Create(&entity.API4Conversation{
ID: "session-other",
DialogID: "canvas-other",
UserID: "user-1",
Message: json.RawMessage(`[]`),
Reference: json.RawMessage(`[]`),
})
h := NewAgentHandler(service.NewAgentService(), nil)
h.DeleteAgentSessions(c)
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp["code"] != float64(common.CodeSuccess) {
t.Fatalf("expected code %d, got %v: %v", common.CodeSuccess, resp["code"], resp["message"])
}
var ownCount int64
if err := db.Model(&entity.API4Conversation{}).Where("dialog_id = ?", "canvas-1").Count(&ownCount).Error; err != nil {
t.Fatalf("failed to count own sessions: %v", err)
}
if ownCount != 0 {
t.Fatalf("expected all canvas-1 sessions to be deleted, count=%d", ownCount)
}
var otherCount int64
if err := db.Model(&entity.API4Conversation{}).Where("id = ?", "session-other").Count(&otherCount).Error; err != nil {
t.Fatalf("failed to count other session: %v", err)
}
if otherCount != 1 {
t.Fatalf("expected other agent session to remain, count=%d", otherCount)
}
}
func TestDeleteAgentSessionsHandlerRequiresOwner(t *testing.T) {
c, w, db := setupGinContextWithUserAndDB(t, http.MethodDelete, "/api/v1/agents/canvas-1/sessions")
c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/agents/canvas-1/sessions", strings.NewReader(`{"ids":["session-1"]}`))
c.Request.Header.Set("Content-Type", "application/json")
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}}
db.Create(&entity.UserCanvas{
ID: "canvas-1",
UserID: "user-2",
Permission: "team",
Title: sptr("Team Agent"),
})
db.Create(&entity.API4Conversation{
ID: "session-1",
DialogID: "canvas-1",
UserID: "user-1",
Message: json.RawMessage(`[]`),
Reference: json.RawMessage(`[]`),
})
h := NewAgentHandler(service.NewAgentService(), nil)
h.DeleteAgentSessions(c)
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp["code"] != float64(common.CodeDataError) {
t.Fatalf("expected code %d, got %v: %v", common.CodeDataError, resp["code"], resp["message"])
}
var count int64
if err := db.Model(&entity.API4Conversation{}).Where("id = ?", "session-1").Count(&count).Error; err != nil {
t.Fatalf("failed to count session: %v", err)
}
if count != 1 {
t.Fatalf("expected session to remain, count=%d", count)
}
}
func TestTestDBConnectionHandlerMissingFields(t *testing.T) {
c, w, _ := setupGinContextWithUserAndDB(t, http.MethodPost, "/api/v1/agents/test_db_connection")
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/agents/test_db_connection", strings.NewReader(`{"db_type":"mysql"}`))
c.Request.Header.Set("Content-Type", "application/json")
h := NewAgentHandler(service.NewAgentService(), nil)
h.TestDBConnection(c)
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp["code"] != float64(common.CodeArgumentError) {
t.Fatalf("expected code %d, got %v: %v", common.CodeArgumentError, resp["code"], resp["message"])
}
if resp["data"] != nil {
t.Fatalf("expected nil data, got %v", resp["data"])
}
want := "required argument are missing: database,username,host,port,password; "
if resp["message"] != want {
t.Fatalf("expected message %q, got %v", want, resp["message"])
}
}
func TestTestDBConnectionHandlerRejectsLocalhost(t *testing.T) {
c, w, _ := setupGinContextWithUserAndDB(t, http.MethodPost, "/api/v1/agents/test_db_connection")
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/agents/test_db_connection", strings.NewReader(`{
"db_type":"mysql",
"database":"rag_flow",
"username":"root",
"host":"localhost",
"port":3306,
"password":"infini_rag_flow"
}`))
c.Request.Header.Set("Content-Type", "application/json")
h := NewAgentHandler(service.NewAgentService(), nil)
h.TestDBConnection(c)
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp["code"] != float64(common.CodeDataError) {
t.Fatalf("expected code %d, got %v: %v", common.CodeDataError, resp["code"], resp["message"])
}
if resp["data"] != nil {
t.Fatalf("expected nil data, got %v", resp["data"])
}
message, ok := resp["message"].(string)
if !ok || !strings.Contains(message, "non-public address") {
t.Fatalf("expected non-public host message, got %v", resp["message"])
}
}
func TestUpdateAgentTagsHandlerSuccess(t *testing.T) {
c, w, db := setupGinContextWithUserAndDB(t, http.MethodPut, "/api/v1/agents/canvas-1/tags")
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/agents/canvas-1/tags", strings.NewReader(`{"tags":["alpha","beta","alpha"]}`))

View File

@@ -388,6 +388,7 @@ func (r *Router) Setup(engine *gin.Engine) {
agents.GET("", r.agentHandler.ListAgents)
agents.GET("/prompts", r.agentHandler.GetPrompts)
agents.GET("/templates", r.agentHandler.ListTemplates)
agents.POST("/test_db_connection", r.agentHandler.TestDBConnection)
agents.GET("/:agent_id/versions", r.agentHandler.ListAgentVersions)
agents.GET("/:agent_id/versions/:version_id", r.agentHandler.GetAgentVersion)
agents.POST("/:agent_id/upload", r.agentHandler.UploadAgentFile)
@@ -395,6 +396,7 @@ func (r *Router) Setup(engine *gin.Engine) {
agents.GET("/:agent_id/sessions", r.agentHandler.ListAgentSessions)
agents.GET("/:agent_id/sessions/:session_id", r.agentHandler.GetAgentSession)
agents.DELETE("/:agent_id/sessions/:session_id", r.agentHandler.DeleteAgentSessionItem)
agents.DELETE("/:agent_id/sessions", r.agentHandler.DeleteAgentSessions)
}
// Plugin routes

View File

@@ -17,14 +17,21 @@
package service
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"net"
"net/netip"
"sort"
"strconv"
"strings"
"time"
"github.com/go-sql-driver/mysql"
"go.uber.org/zap"
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/entity"
@@ -163,6 +170,25 @@ type ListAgentSessionsResponse struct {
Total int64 `json:"total"`
}
type DeleteAgentSessionsResult struct {
Data *DeleteAgentSessionsResponse
Message string
}
type DeleteAgentSessionsResponse struct {
SuccessCount int `json:"success_count"`
Errors []string `json:"errors,omitempty"`
}
type TestDBConnectionRequest struct {
DBType string `json:"db_type"`
Database string `json:"database"`
Username string `json:"username"`
Host string `json:"host"`
Port interface{} `json:"port"`
Password string `json:"password"`
}
func parseAgentSessionDate(value string, isEnd bool) (*time.Time, error) {
if value == "" {
return nil, nil
@@ -370,6 +396,29 @@ func firstNonNil(values ...interface{}) interface{} {
return nil
}
// checkDuplicateSessionIDs check duplicated ID in IDS and add some messages for it
func checkDuplicateSessionIDs(ids []string) ([]string, []string) {
seen := make(map[string]int, len(ids))
uniqueIDs := make([]string, 0, len(ids))
for _, id := range ids {
id = strings.TrimSpace(id)
seen[id]++
if seen[id] == 1 {
uniqueIDs = append(uniqueIDs, id)
}
}
duplicateMessages := make([]string, 0)
for _, id := range uniqueIDs {
if seen[id] > 1 {
duplicateMessages = append(duplicateMessages, fmt.Sprintf("Duplicate session ids: %s", id))
}
}
return uniqueIDs, duplicateMessages
}
func (s *AgentService) ListAgentSessions(userID, tenantID, agentID string, req ListAgentSessionsRequest) (*ListAgentSessionsResponse, common.ErrorCode, error) {
if agentID == "" {
return nil, common.CodeArgumentError, errors.New("agent_id is required")
@@ -474,6 +523,301 @@ func (s *AgentService) DeleteAgentSessionItem(userID, agentID, sessionID string)
return true, common.CodeSuccess, nil
}
// DeleteAgentSessions Delete sessions by ids
func (s *AgentService) DeleteAgentSessions(userID, agentID string, ids []string, deleteAll bool) (*DeleteAgentSessionsResult, common.ErrorCode, error) {
if agentID == "" {
return nil, common.CodeArgumentError, errors.New("agent_id is required")
}
canvas, err := s.canvasDAO.GetByID(agentID)
if err != nil || canvas == nil || canvas.UserID != userID {
return nil, common.CodeDataError, fmt.Errorf("You don't own the agent %s", agentID)
}
if len(ids) == 0 {
if !deleteAll {
return &DeleteAgentSessionsResult{}, common.CodeSuccess, nil
}
ids, err = s.api4ConversationDAO.ListIDsByAgentID(agentID)
if err != nil {
return nil, common.CodeServerError, err
}
if len(ids) == 0 {
return &DeleteAgentSessionsResult{}, common.CodeSuccess, nil
}
}
sessionIDs, duplicateMessages := checkDuplicateSessionIDs(ids)
errorsList := make([]string, 0)
successCount := 0
for _, sessionID := range sessionIDs {
sessionID = strings.TrimSpace(sessionID)
if sessionID == "" {
errorsList = append(errorsList, "The agent doesn't own the session ")
continue
}
conv, err := s.api4ConversationDAO.GetBySessionID(sessionID, agentID)
if err != nil {
return nil, common.CodeServerError, err
}
if conv == nil {
errorsList = append(errorsList, fmt.Sprintf("The agent doesn't own the session %s", sessionID))
continue
}
if _, err := s.api4ConversationDAO.DeleteBySessionIDAndAgentID(sessionID, agentID); err != nil {
return nil, common.CodeServerError, err
}
successCount++
}
if len(errorsList) > 0 {
if successCount > 0 {
return &DeleteAgentSessionsResult{
Message: fmt.Sprintf("Partially deleted %d sessions with %d errors", successCount, len(errorsList)),
Data: &DeleteAgentSessionsResponse{
SuccessCount: successCount,
Errors: errorsList,
},
}, common.CodeSuccess, nil
}
return nil, common.CodeDataError, errors.New(strings.Join(errorsList, "; "))
}
if len(duplicateMessages) > 0 {
if successCount > 0 {
return &DeleteAgentSessionsResult{
Message: fmt.Sprintf("Partially deleted %d sessions with %d errors", successCount, len(duplicateMessages)),
Data: &DeleteAgentSessionsResponse{
SuccessCount: successCount,
Errors: duplicateMessages,
},
}, common.CodeSuccess, nil
}
return nil, common.CodeDataError, errors.New(strings.Join(duplicateMessages, ";"))
}
return &DeleteAgentSessionsResult{}, common.CodeSuccess, nil
}
// AssertHostIsSafe checks whether host resolves only to public IP addresses.
func AssertHostIsSafe(host string) (string, error) {
host = strings.TrimSpace(host)
if host == "" {
return "", errors.New("Host must not be empty.")
}
ips, err := net.LookupIP(host)
if err != nil {
zap.L().Warn("SSRF guard could not resolve host",
zap.String("host", host),
zap.Error(err),
)
return "", fmt.Errorf("Could not resolve host %q: %w", host, err)
}
if len(ips) == 0 {
zap.L().Warn("SSRF guard blocked host: resolved to no addresses",
zap.String("host", host),
)
return "", fmt.Errorf("Host %q resolved to no addresses.", host)
}
var resolvedIP string
for _, ip := range ips {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return "", fmt.Errorf("invalid resolved IP %q for host %q", ip.String(), host)
}
// Normalize IPv4-mapped IPv6, equivalent to Python _effective_ip().
addr = addr.Unmap()
if !isPublicAddr(addr) {
zap.L().Warn("SSRF guard blocked host",
zap.String("host", host),
zap.String("resolved_ip", addr.String()),
)
return "", fmt.Errorf("Host resolves to a non-public address (%s), which is not allowed.", addr.String())
}
if resolvedIP == "" {
resolvedIP = addr.String()
}
}
if resolvedIP == "" {
return "", fmt.Errorf("Host %q resolved to no addresses.", host)
}
return resolvedIP, nil
}
func isPublicAddr(addr netip.Addr) bool {
addr = addr.Unmap()
if !addr.IsValid() {
return false
}
if !addr.IsGlobalUnicast() {
return false
}
if addr.IsPrivate() ||
addr.IsLoopback() ||
addr.IsLinkLocalUnicast() ||
addr.IsLinkLocalMulticast() ||
addr.IsMulticast() ||
addr.IsUnspecified() {
return false
}
return !isSpecialUseAddr(addr)
}
func isSpecialUseAddr(addr netip.Addr) bool {
addr = addr.Unmap()
specialCIDRs := []string{
// IPv4 special-use / documentation / reserved ranges.
"0.0.0.0/8",
"100.64.0.0/10",
"127.0.0.0/8",
"169.254.0.0/16",
"192.0.0.0/24",
"192.0.2.0/24",
"198.18.0.0/15",
"198.51.100.0/24",
"203.0.113.0/24",
"224.0.0.0/4",
"240.0.0.0/4",
// IPv6 special-use / documentation / local ranges.
"::/128",
"::1/128",
"64:ff9b:1::/48",
"100::/64",
"2001::/23",
"2001:2::/48",
"fc00::/7",
"fe80::/10",
"ff00::/8",
"2001:db8::/32",
"2002::/16",
}
for _, cidr := range specialCIDRs {
prefix := netip.MustParsePrefix(cidr)
if prefix.Contains(addr) {
return true
}
}
return false
}
// missingDBConnectionFields Check if request is missing something
func missingDBConnectionFields(req *TestDBConnectionRequest) []string {
missing := make([]string, 0, 6)
if req == nil || strings.TrimSpace(req.DBType) == "" {
missing = append(missing, "db_type")
}
if req == nil || strings.TrimSpace(req.Database) == "" {
missing = append(missing, "database")
}
if req == nil || strings.TrimSpace(req.Username) == "" {
missing = append(missing, "username")
}
if req == nil || strings.TrimSpace(req.Host) == "" {
missing = append(missing, "host")
}
if req == nil || dbConnectionPort(req.Port) == "" {
missing = append(missing, "port")
}
if req == nil || req.Password == "" {
missing = append(missing, "password")
}
return missing
}
func dbConnectionPort(port interface{}) string {
switch value := port.(type) {
case nil:
return ""
case string:
return strings.TrimSpace(value)
case float64:
return strconv.Itoa(int(value))
case float32:
return strconv.Itoa(int(value))
case int:
return strconv.Itoa(value)
case int64:
return strconv.FormatInt(value, 10)
case json.Number:
return value.String()
default:
return strings.TrimSpace(fmt.Sprint(value))
}
}
func (s *AgentService) TestDBConnection(userID string, req *TestDBConnectionRequest) (common.ErrorCode, error) {
if missing := missingDBConnectionFields(req); len(missing) > 0 {
return common.CodeArgumentError, fmt.Errorf("required argument are missing: %s; ", strings.Join(missing, ","))
}
safeHost, err := AssertHostIsSafe(req.Host)
if err != nil {
zap.L().Warn(
"Rejected test_db_connection: unsafe host",
zap.String("host", req.Host),
zap.String("db_type", req.DBType),
zap.String("user", userID),
zap.Error(err),
)
return common.CodeDataError, err
}
switch req.DBType {
case "mysql", "mariadb", "oceanbase":
port := dbConnectionPort(req.Port)
dbProbeTimeout := 5 * time.Second
config := mysql.Config{
User: req.Username,
Passwd: req.Password,
Net: "tcp",
Addr: net.JoinHostPort(safeHost, port),
DBName: req.Database,
Timeout: dbProbeTimeout,
AllowNativePasswords: true,
}
db, err := sql.Open("mysql", config.FormatDSN())
if err != nil {
return common.CodeExceptionError, err
}
defer db.Close()
ctx, cancel := context.WithTimeout(context.Background(), dbProbeTimeout)
defer cancel()
if err := db.PingContext(ctx); err != nil {
return common.CodeExceptionError, err
}
if _, err := db.ExecContext(ctx, "SELECT 1"); err != nil {
return common.CodeExceptionError, err
}
default:
return common.CodeExceptionError, errors.New("Unsupported database type.")
}
return common.CodeSuccess, nil
}
// normalizeAgentTags returns an error for unsupported tag payload types
func normalizeAgentTags(rawTags interface{}) (string, error) {
cleaned := make([]string, 0)

View File

@@ -17,9 +17,13 @@
package service
import (
"encoding/json"
"net/netip"
"strings"
"testing"
"time"
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/entity"
)
@@ -344,5 +348,486 @@ func TestGetVersion_NotFound(t *testing.T) {
}
}
func setupAgentSessionServiceTest(t *testing.T) {
t.Helper()
testDB := setupServiceTestDB(t)
if err := testDB.AutoMigrate(
&entity.User{},
&entity.UserCanvas{},
&entity.UserTenant{},
&entity.API4Conversation{},
); err != nil {
t.Fatalf("failed to migrate: %v", err)
}
orig := dao.DB
dao.DB = testDB
t.Cleanup(func() { dao.DB = orig })
}
func createAgentSessionTestCanvas(t *testing.T, id, userID string) {
t.Helper()
if err := dao.DB.Create(&entity.UserCanvas{
ID: id,
UserID: userID,
Title: sptr("Test Agent"),
CanvasCategory: "agent_canvas",
}).Error; err != nil {
t.Fatalf("failed to create canvas %s: %v", id, err)
}
}
func createAgentSessionTestConversation(t *testing.T, id, agentID, userID string, updateTime int64) {
t.Helper()
updateDate := time.UnixMilli(updateTime)
if err := dao.DB.Create(&entity.API4Conversation{
ID: id,
DialogID: agentID,
UserID: userID,
Message: json.RawMessage(`[{"role":"assistant","content":"hello","prompt":"hidden"}]`),
Reference: json.RawMessage(`[]`),
BaseModel: entity.BaseModel{
CreateTime: ptr(updateTime),
CreateDate: &updateDate,
UpdateTime: ptr(updateTime),
UpdateDate: &updateDate,
},
}).Error; err != nil {
t.Fatalf("failed to create session %s: %v", id, err)
}
}
func TestListAgentSessionsServiceSuccess(t *testing.T) {
setupAgentSessionServiceTest(t)
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
createAgentSessionTestConversation(t, "session-old", "canvas-1", "user-1", 1000)
createAgentSessionTestConversation(t, "session-new", "canvas-1", "user-1", 3000)
createAgentSessionTestConversation(t, "session-other-agent", "canvas-other", "user-1", 9999)
resp, code, err := NewAgentService().ListAgentSessions("user-1", "user-1", "canvas-1", ListAgentSessionsRequest{
Page: 1,
PageSize: 10,
OrderBy: "update_time",
Desc: true,
IncludeDSL: false,
})
if err != nil {
t.Fatalf("ListAgentSessions failed: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("expected code %d, got %d", common.CodeSuccess, code)
}
if resp.Total != 2 {
t.Fatalf("expected total 2, got %d", resp.Total)
}
if len(resp.Data) != 2 {
t.Fatalf("expected 2 sessions, got %d", len(resp.Data))
}
if resp.Data[0]["id"] != "session-new" {
t.Fatalf("expected newest session first, got %v", resp.Data[0]["id"])
}
if resp.Data[0]["agent_id"] != "canvas-1" {
t.Fatalf("expected agent_id canvas-1, got %v", resp.Data[0]["agent_id"])
}
messages, ok := resp.Data[0]["message"].([]map[string]interface{})
if !ok || len(messages) != 1 {
t.Fatalf("expected normalized message slice, got %T %v", resp.Data[0]["message"], resp.Data[0]["message"])
}
if _, ok := messages[0]["prompt"]; ok {
t.Fatal("expected prompt to be removed from normalized session message")
}
}
func TestListAgentSessionsServiceDenied(t *testing.T) {
setupAgentSessionServiceTest(t)
createAgentSessionTestCanvas(t, "canvas-1", "user-2")
resp, code, err := NewAgentService().ListAgentSessions("user-1", "user-1", "canvas-1", ListAgentSessionsRequest{})
if err == nil {
t.Fatal("expected permission error")
}
if code != common.CodeOperatingError {
t.Fatalf("expected code %d, got %d", common.CodeOperatingError, code)
}
if resp != nil {
t.Fatalf("expected nil response, got %+v", resp)
}
}
func TestGetAgentSessionServiceSuccess(t *testing.T) {
setupAgentSessionServiceTest(t)
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
createAgentSessionTestConversation(t, "session-1", "canvas-1", "user-1", 1000)
session, code, err := NewAgentService().GetAgentSession("user-1", "canvas-1", "session-1")
if err != nil {
t.Fatalf("GetAgentSession failed: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("expected code %d, got %d", common.CodeSuccess, code)
}
if session == nil {
t.Fatal("expected session, got nil")
}
if session.ID != "session-1" {
t.Fatalf("expected session-1, got %s", session.ID)
}
if session.DialogID != "canvas-1" {
t.Fatalf("expected dialog_id canvas-1, got %s", session.DialogID)
}
}
func TestGetAgentSessionServiceNotFoundWhenSessionBelongsToAnotherAgent(t *testing.T) {
setupAgentSessionServiceTest(t)
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
createAgentSessionTestConversation(t, "session-other", "canvas-other", "user-1", 1000)
session, code, err := NewAgentService().GetAgentSession("user-1", "canvas-1", "session-other")
if err == nil {
t.Fatal("expected not found error")
}
if code != common.CodeNotFound {
t.Fatalf("expected code %d, got %d", common.CodeNotFound, code)
}
if session != nil {
t.Fatalf("expected nil session, got %+v", session)
}
}
func TestDeleteAgentSessionItemServiceDeletesMatchingSession(t *testing.T) {
setupAgentSessionServiceTest(t)
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
createAgentSessionTestConversation(t, "session-1", "canvas-1", "user-1", 1000)
createAgentSessionTestConversation(t, "session-other", "canvas-other", "user-1", 2000)
deleted, code, err := NewAgentService().DeleteAgentSessionItem("user-1", "canvas-1", "session-1")
if err != nil {
t.Fatalf("DeleteAgentSessionItem failed: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("expected code %d, got %d", common.CodeSuccess, code)
}
if !deleted {
t.Fatal("expected session to be deleted")
}
var count int64
if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-1").Count(&count).Error; err != nil {
t.Fatalf("failed to count deleted session: %v", err)
}
if count != 0 {
t.Fatalf("expected session-1 to be deleted, count=%d", count)
}
if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-other").Count(&count).Error; err != nil {
t.Fatalf("failed to count other session: %v", err)
}
if count != 1 {
t.Fatalf("expected other agent session to remain, count=%d", count)
}
}
func TestDeleteAgentSessionItemServiceNoopForSessionFromAnotherAgent(t *testing.T) {
setupAgentSessionServiceTest(t)
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
createAgentSessionTestConversation(t, "session-other", "canvas-other", "user-1", 1000)
deleted, code, err := NewAgentService().DeleteAgentSessionItem("user-1", "canvas-1", "session-other")
if err != nil {
t.Fatalf("DeleteAgentSessionItem failed: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("expected code %d, got %d", common.CodeSuccess, code)
}
if deleted {
t.Fatal("expected cross-agent session delete to be a noop")
}
var count int64
if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-other").Count(&count).Error; err != nil {
t.Fatalf("failed to count other session: %v", err)
}
if count != 1 {
t.Fatalf("expected cross-agent session to remain, count=%d", count)
}
}
func TestDeleteAgentSessionsServiceDeleteAll(t *testing.T) {
setupAgentSessionServiceTest(t)
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
createAgentSessionTestConversation(t, "session-1", "canvas-1", "user-1", 1000)
createAgentSessionTestConversation(t, "session-2", "canvas-1", "user-1", 2000)
createAgentSessionTestConversation(t, "session-other", "canvas-other", "user-1", 3000)
result, code, err := NewAgentService().DeleteAgentSessions("user-1", "canvas-1", nil, true)
if err != nil {
t.Fatalf("DeleteAgentSessions failed: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("expected code %d, got %d", common.CodeSuccess, code)
}
if result == nil {
t.Fatal("expected non-nil result")
}
if result.Data != nil {
t.Fatalf("expected no partial data on full success, got %+v", result.Data)
}
var ownCount int64
if err := dao.DB.Model(&entity.API4Conversation{}).Where("dialog_id = ?", "canvas-1").Count(&ownCount).Error; err != nil {
t.Fatalf("failed to count own sessions: %v", err)
}
if ownCount != 0 {
t.Fatalf("expected all canvas-1 sessions to be deleted, count=%d", ownCount)
}
var otherCount int64
if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-other").Count(&otherCount).Error; err != nil {
t.Fatalf("failed to count other session: %v", err)
}
if otherCount != 1 {
t.Fatalf("expected other agent session to remain, count=%d", otherCount)
}
}
func TestDeleteAgentSessionsServiceDuplicateIDsPartial(t *testing.T) {
setupAgentSessionServiceTest(t)
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
createAgentSessionTestConversation(t, "session-1", "canvas-1", "user-1", 1000)
result, code, err := NewAgentService().DeleteAgentSessions("user-1", "canvas-1", []string{"session-1", "session-1"}, false)
if err != nil {
t.Fatalf("DeleteAgentSessions failed: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("expected code %d, got %d", common.CodeSuccess, code)
}
if result == nil || result.Data == nil {
t.Fatalf("expected partial result data, got %+v", result)
}
if result.Data.SuccessCount != 1 {
t.Fatalf("expected success_count 1, got %d", result.Data.SuccessCount)
}
if len(result.Data.Errors) != 1 || result.Data.Errors[0] != "Duplicate session ids: session-1" {
t.Fatalf("unexpected duplicate errors: %v", result.Data.Errors)
}
var count int64
if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-1").Count(&count).Error; err != nil {
t.Fatalf("failed to count deleted session: %v", err)
}
if count != 0 {
t.Fatalf("expected session-1 to be deleted, count=%d", count)
}
}
func TestDeleteAgentSessionsServiceMissingSessionError(t *testing.T) {
setupAgentSessionServiceTest(t)
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
result, code, err := NewAgentService().DeleteAgentSessions("user-1", "canvas-1", []string{"missing-session"}, false)
if err == nil {
t.Fatal("expected missing session error")
}
if code != common.CodeDataError {
t.Fatalf("expected code %d, got %d", common.CodeDataError, code)
}
if result != nil {
t.Fatalf("expected nil result, got %+v", result)
}
if !strings.Contains(err.Error(), "The agent doesn't own the session missing-session") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestDeleteAgentSessionsServiceRequiresOwner(t *testing.T) {
setupAgentSessionServiceTest(t)
createAgentSessionTestCanvas(t, "canvas-1", "user-2")
createAgentSessionTestConversation(t, "session-1", "canvas-1", "user-1", 1000)
result, code, err := NewAgentService().DeleteAgentSessions("user-1", "canvas-1", []string{"session-1"}, false)
if err == nil {
t.Fatal("expected owner error")
}
if code != common.CodeDataError {
t.Fatalf("expected code %d, got %d", common.CodeDataError, code)
}
if result != nil {
t.Fatalf("expected nil result, got %+v", result)
}
var count int64
if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-1").Count(&count).Error; err != nil {
t.Fatalf("failed to count session: %v", err)
}
if count != 1 {
t.Fatalf("expected session to remain, count=%d", count)
}
}
func TestUpdateAgentTagsServiceSuccess(t *testing.T) {
setupAgentSessionServiceTest(t)
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
ok, code, err := NewAgentService().UpdateAgentTags("user-1", "canvas-1", []interface{}{"alpha", "beta", "alpha", "with,comma"})
if err != nil {
t.Fatalf("UpdateAgentTags failed: %v", err)
}
if code != common.CodeSuccess {
t.Fatalf("expected code %d, got %d", common.CodeSuccess, code)
}
if !ok {
t.Fatal("expected update to succeed")
}
canvas, err := dao.NewUserCanvasDAO().GetByID("canvas-1")
if err != nil {
t.Fatalf("failed to get canvas: %v", err)
}
if canvas.Tags != "alpha,beta,with comma" {
t.Fatalf("expected normalized tags, got %q", canvas.Tags)
}
}
func TestUpdateAgentTagsServiceInvalidPayload(t *testing.T) {
setupAgentSessionServiceTest(t)
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
ok, code, err := NewAgentService().UpdateAgentTags("user-1", "canvas-1", map[string]string{"tag": "alpha"})
if err == nil {
t.Fatal("expected invalid tags error")
}
if code != common.CodeBadRequest {
t.Fatalf("expected code %d, got %d", common.CodeBadRequest, code)
}
if ok {
t.Fatal("expected update to fail")
}
canvas, err := dao.NewUserCanvasDAO().GetByID("canvas-1")
if err != nil {
t.Fatalf("failed to get canvas: %v", err)
}
if canvas.Tags != "" {
t.Fatalf("expected tags to remain unchanged, got %q", canvas.Tags)
}
}
func TestUpdateAgentTagsServiceNoPermission(t *testing.T) {
setupAgentSessionServiceTest(t)
createAgentSessionTestCanvas(t, "canvas-1", "user-2")
ok, code, err := NewAgentService().UpdateAgentTags("user-1", "canvas-1", []string{"alpha"})
if err == nil {
t.Fatal("expected permission error")
}
if code != common.CodeOperatingError {
t.Fatalf("expected code %d, got %d", common.CodeOperatingError, code)
}
if ok {
t.Fatal("expected update to fail")
}
canvas, err := dao.NewUserCanvasDAO().GetByID("canvas-1")
if err != nil {
t.Fatalf("failed to get canvas: %v", err)
}
if canvas.Tags != "" {
t.Fatalf("expected tags to remain unchanged, got %q", canvas.Tags)
}
}
func TestIsPublicAddr(t *testing.T) {
tests := []struct {
name string
addr string
want bool
}{
{name: "public IPv4", addr: "8.8.8.8", want: true},
{name: "loopback", addr: "127.0.0.1", want: false},
{name: "private", addr: "192.168.1.1", want: false},
{name: "carrier NAT", addr: "100.64.0.1", want: false},
{name: "documentation", addr: "203.0.113.1", want: false},
{name: "IPv4 mapped loopback", addr: "::ffff:127.0.0.1", want: false},
{name: "IPv6 documentation", addr: "2001:db8::1", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isPublicAddr(netip.MustParseAddr(tt.addr))
if got != tt.want {
t.Fatalf("isPublicAddr(%s): expected %v, got %v", tt.addr, tt.want, got)
}
})
}
}
func TestAssertHostIsSafeRejectsLocalhost(t *testing.T) {
_, err := AssertHostIsSafe("localhost")
if err == nil {
t.Fatal("expected localhost to be rejected")
}
if !strings.Contains(err.Error(), "non-public address") {
t.Fatalf("expected non-public address error, got %v", err)
}
}
func TestTestDBConnectionMissingFields(t *testing.T) {
code, err := NewAgentService().TestDBConnection("user-1", &TestDBConnectionRequest{DBType: "mysql"})
if err == nil {
t.Fatal("expected missing field error")
}
if code != common.CodeArgumentError {
t.Fatalf("expected code %d, got %d", common.CodeArgumentError, code)
}
want := "required argument are missing: database,username,host,port,password; "
if err.Error() != want {
t.Fatalf("expected %q, got %q", want, err.Error())
}
}
func TestTestDBConnectionUnsupportedDatabaseType(t *testing.T) {
code, err := NewAgentService().TestDBConnection("user-1", &TestDBConnectionRequest{
DBType: "postgres",
Database: "rag_flow",
Username: "root",
Host: "8.8.8.8",
Port: 5432,
Password: "password",
})
if err == nil {
t.Fatal("expected unsupported database type error")
}
if code != common.CodeExceptionError {
t.Fatalf("expected code %d, got %d", common.CodeExceptionError, code)
}
if err.Error() != "Unsupported database type." {
t.Fatalf("unexpected error: %v", err)
}
}
func TestDBConnectionPortAcceptsStringAndNumber(t *testing.T) {
if got := dbConnectionPort("3306"); got != "3306" {
t.Fatalf("expected string port 3306, got %q", got)
}
if got := dbConnectionPort(float64(3306)); got != "3306" {
t.Fatalf("expected numeric port 3306, got %q", got)
}
}
// ptr returns a pointer to the given int64.
func ptr(v int64) *int64 { return &v }
func ptr(v int64) *int64 { return &v }