mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-05 10:58:34 +08:00
feat[Go]: implement /api/v1/agents/<agent_id> and test_db_connection (#15771)
### What problem does this PR solve? Add two API in go ``` /api/v1/agents/test_db_connection POST /api/v1/agents/<agent_id>/sessions DELETE ``` ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
This commit is contained in:
@@ -131,6 +131,13 @@ func (dao *API4ConversationDAO) GetBySessionID(sessionID, agentID string) (*enti
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ListIDsByAgentID lists conversation IDs for one agent.
|
||||
func (dao *API4ConversationDAO) ListIDsByAgentID(agentID string) ([]string, error) {
|
||||
var ids []string
|
||||
err := DB.Model(&entity.API4Conversation{}).Where("dialog_id = ?", agentID).Pluck("id", &ids).Error
|
||||
return ids, err
|
||||
}
|
||||
|
||||
// DeleteBySessionIDAndAgentID deletes API4Conversations by sessionID and agentID
|
||||
func (dao *API4ConversationDAO) DeleteBySessionIDAndAgentID(sessionID, agentID string) (int64, error) {
|
||||
result := DB.Where("id = ? AND dialog_id = ?", sessionID, agentID).Delete(&entity.API4Conversation{})
|
||||
|
||||
@@ -18,6 +18,7 @@ package dao
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
@@ -106,3 +107,43 @@ func TestAPI4ConversationDAOGetBySessionIDNoRows(t *testing.T) {
|
||||
t.Fatalf("expected nil for missing session, got %+v", session)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPI4ConversationDAOListIDsByAgentID(t *testing.T) {
|
||||
db := setupAPI4ConversationTestDB(t)
|
||||
pushDB(t, db)
|
||||
|
||||
createAPI4ConversationForDAOTest(t, "session-1", "agent-1")
|
||||
createAPI4ConversationForDAOTest(t, "session-2", "agent-1")
|
||||
createAPI4ConversationForDAOTest(t, "session-other", "agent-2")
|
||||
|
||||
ids, err := NewAPI4ConversationDAO().ListIDsByAgentID("agent-1")
|
||||
if err != nil {
|
||||
t.Fatalf("ListIDsByAgentID failed: %v", err)
|
||||
}
|
||||
|
||||
sort.Strings(ids)
|
||||
want := []string{"session-1", "session-2"}
|
||||
if len(ids) != len(want) {
|
||||
t.Fatalf("expected %d ids, got %d: %v", len(want), len(ids), ids)
|
||||
}
|
||||
for i := range want {
|
||||
if ids[i] != want[i] {
|
||||
t.Fatalf("expected ids %v, got %v", want, ids)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPI4ConversationDAOListIDsByAgentIDNoRows(t *testing.T) {
|
||||
db := setupAPI4ConversationTestDB(t)
|
||||
pushDB(t, db)
|
||||
|
||||
createAPI4ConversationForDAOTest(t, "session-1", "agent-1")
|
||||
|
||||
ids, err := NewAPI4ConversationDAO().ListIDsByAgentID("agent-2")
|
||||
if err != nil {
|
||||
t.Fatalf("ListIDsByAgentID failed: %v", err)
|
||||
}
|
||||
if len(ids) != 0 {
|
||||
t.Fatalf("expected empty ids for missing agent, got %v", ids)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ package handler
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -307,6 +308,89 @@ func (h *AgentHandler) DeleteAgentSessionItem(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
type deleteAgentSessionsRequest struct {
|
||||
IDs []string `json:"ids"`
|
||||
DeleteAll bool `json:"delete_all,omitempty"`
|
||||
}
|
||||
|
||||
func (h *AgentHandler) DeleteAgentSessions(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
agentID := strings.TrimSpace(c.Param("agent_id"))
|
||||
if agentID == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeOperatingError,
|
||||
"data": nil,
|
||||
"message": "agent_id is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req deleteAgentSessionsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"data": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
result, code, err := h.agentService.DeleteAgentSessions(strings.TrimSpace(user.ID), agentID, req.IDs, req.DeleteAll)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{"code": common.CodeSuccess}
|
||||
if result != nil && result.Data != nil {
|
||||
response["data"] = result.Data
|
||||
}
|
||||
if result != nil && result.Message != "" {
|
||||
response["message"] = result.Message
|
||||
}
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// TestDBConnection Test DB connection
|
||||
func (h *AgentHandler) TestDBConnection(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
var req service.TestDBConnectionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
code, err := h.agentService.TestDBConnection(user.ID, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// ListAgentVersions returns versions for a specific agent.
|
||||
// @Summary List Agent Versions
|
||||
// @Description Returns all versions for a specific agent, ordered by update_time DESC.
|
||||
|
||||
@@ -508,6 +508,215 @@ func TestDeleteAgentSessionItemHandlerIgnoresSessionFromAnotherAgent(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAgentSessionsHandlerDeletesDuplicateIDsPartially(t *testing.T) {
|
||||
c, w, db := setupGinContextWithUserAndDB(t, http.MethodDelete, "/api/v1/agents/canvas-1/sessions")
|
||||
c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/agents/canvas-1/sessions", strings.NewReader(`{"ids":["session-1","session-1"]}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}}
|
||||
|
||||
db.Create(&entity.UserCanvas{
|
||||
ID: "canvas-1",
|
||||
UserID: "user-1",
|
||||
Title: sptr("Test Agent"),
|
||||
})
|
||||
db.Create(&entity.API4Conversation{
|
||||
ID: "session-1",
|
||||
DialogID: "canvas-1",
|
||||
UserID: "user-1",
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
})
|
||||
|
||||
h := NewAgentHandler(service.NewAgentService(), nil)
|
||||
h.DeleteAgentSessions(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected code %d, got %v: %v", common.CodeSuccess, resp["code"], resp["message"])
|
||||
}
|
||||
data, ok := resp["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected partial data object, got %T", resp["data"])
|
||||
}
|
||||
if data["success_count"] != float64(1) {
|
||||
t.Fatalf("expected success_count 1, got %v", data["success_count"])
|
||||
}
|
||||
errorsList, ok := data["errors"].([]interface{})
|
||||
if !ok || len(errorsList) != 1 {
|
||||
t.Fatalf("expected one duplicate error, got %v", data["errors"])
|
||||
}
|
||||
if errorsList[0] != "Duplicate session ids: session-1" {
|
||||
t.Fatalf("unexpected duplicate error: %v", errorsList[0])
|
||||
}
|
||||
|
||||
var count int64
|
||||
if err := db.Model(&entity.API4Conversation{}).Where("id = ?", "session-1").Count(&count).Error; err != nil {
|
||||
t.Fatalf("failed to count deleted session: %v", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatalf("expected session-1 to be deleted, count=%d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAgentSessionsHandlerDeleteAll(t *testing.T) {
|
||||
c, w, db := setupGinContextWithUserAndDB(t, http.MethodDelete, "/api/v1/agents/canvas-1/sessions")
|
||||
c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/agents/canvas-1/sessions", strings.NewReader(`{"delete_all":true}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}}
|
||||
|
||||
db.Create(&entity.UserCanvas{
|
||||
ID: "canvas-1",
|
||||
UserID: "user-1",
|
||||
Title: sptr("Test Agent"),
|
||||
})
|
||||
db.Create(&entity.API4Conversation{
|
||||
ID: "session-1",
|
||||
DialogID: "canvas-1",
|
||||
UserID: "user-1",
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
})
|
||||
db.Create(&entity.API4Conversation{
|
||||
ID: "session-2",
|
||||
DialogID: "canvas-1",
|
||||
UserID: "user-1",
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
})
|
||||
db.Create(&entity.API4Conversation{
|
||||
ID: "session-other",
|
||||
DialogID: "canvas-other",
|
||||
UserID: "user-1",
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
})
|
||||
|
||||
h := NewAgentHandler(service.NewAgentService(), nil)
|
||||
h.DeleteAgentSessions(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected code %d, got %v: %v", common.CodeSuccess, resp["code"], resp["message"])
|
||||
}
|
||||
|
||||
var ownCount int64
|
||||
if err := db.Model(&entity.API4Conversation{}).Where("dialog_id = ?", "canvas-1").Count(&ownCount).Error; err != nil {
|
||||
t.Fatalf("failed to count own sessions: %v", err)
|
||||
}
|
||||
if ownCount != 0 {
|
||||
t.Fatalf("expected all canvas-1 sessions to be deleted, count=%d", ownCount)
|
||||
}
|
||||
|
||||
var otherCount int64
|
||||
if err := db.Model(&entity.API4Conversation{}).Where("id = ?", "session-other").Count(&otherCount).Error; err != nil {
|
||||
t.Fatalf("failed to count other session: %v", err)
|
||||
}
|
||||
if otherCount != 1 {
|
||||
t.Fatalf("expected other agent session to remain, count=%d", otherCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAgentSessionsHandlerRequiresOwner(t *testing.T) {
|
||||
c, w, db := setupGinContextWithUserAndDB(t, http.MethodDelete, "/api/v1/agents/canvas-1/sessions")
|
||||
c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/agents/canvas-1/sessions", strings.NewReader(`{"ids":["session-1"]}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}}
|
||||
|
||||
db.Create(&entity.UserCanvas{
|
||||
ID: "canvas-1",
|
||||
UserID: "user-2",
|
||||
Permission: "team",
|
||||
Title: sptr("Team Agent"),
|
||||
})
|
||||
db.Create(&entity.API4Conversation{
|
||||
ID: "session-1",
|
||||
DialogID: "canvas-1",
|
||||
UserID: "user-1",
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
})
|
||||
|
||||
h := NewAgentHandler(service.NewAgentService(), nil)
|
||||
h.DeleteAgentSessions(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeDataError) {
|
||||
t.Fatalf("expected code %d, got %v: %v", common.CodeDataError, resp["code"], resp["message"])
|
||||
}
|
||||
|
||||
var count int64
|
||||
if err := db.Model(&entity.API4Conversation{}).Where("id = ?", "session-1").Count(&count).Error; err != nil {
|
||||
t.Fatalf("failed to count session: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("expected session to remain, count=%d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTestDBConnectionHandlerMissingFields(t *testing.T) {
|
||||
c, w, _ := setupGinContextWithUserAndDB(t, http.MethodPost, "/api/v1/agents/test_db_connection")
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/agents/test_db_connection", strings.NewReader(`{"db_type":"mysql"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h := NewAgentHandler(service.NewAgentService(), nil)
|
||||
h.TestDBConnection(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeArgumentError) {
|
||||
t.Fatalf("expected code %d, got %v: %v", common.CodeArgumentError, resp["code"], resp["message"])
|
||||
}
|
||||
if resp["data"] != nil {
|
||||
t.Fatalf("expected nil data, got %v", resp["data"])
|
||||
}
|
||||
want := "required argument are missing: database,username,host,port,password; "
|
||||
if resp["message"] != want {
|
||||
t.Fatalf("expected message %q, got %v", want, resp["message"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestTestDBConnectionHandlerRejectsLocalhost(t *testing.T) {
|
||||
c, w, _ := setupGinContextWithUserAndDB(t, http.MethodPost, "/api/v1/agents/test_db_connection")
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/agents/test_db_connection", strings.NewReader(`{
|
||||
"db_type":"mysql",
|
||||
"database":"rag_flow",
|
||||
"username":"root",
|
||||
"host":"localhost",
|
||||
"port":3306,
|
||||
"password":"infini_rag_flow"
|
||||
}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h := NewAgentHandler(service.NewAgentService(), nil)
|
||||
h.TestDBConnection(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeDataError) {
|
||||
t.Fatalf("expected code %d, got %v: %v", common.CodeDataError, resp["code"], resp["message"])
|
||||
}
|
||||
if resp["data"] != nil {
|
||||
t.Fatalf("expected nil data, got %v", resp["data"])
|
||||
}
|
||||
message, ok := resp["message"].(string)
|
||||
if !ok || !strings.Contains(message, "non-public address") {
|
||||
t.Fatalf("expected non-public host message, got %v", resp["message"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAgentTagsHandlerSuccess(t *testing.T) {
|
||||
c, w, db := setupGinContextWithUserAndDB(t, http.MethodPut, "/api/v1/agents/canvas-1/tags")
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/agents/canvas-1/tags", strings.NewReader(`{"tags":["alpha","beta","alpha"]}`))
|
||||
|
||||
@@ -388,6 +388,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
agents.GET("", r.agentHandler.ListAgents)
|
||||
agents.GET("/prompts", r.agentHandler.GetPrompts)
|
||||
agents.GET("/templates", r.agentHandler.ListTemplates)
|
||||
agents.POST("/test_db_connection", r.agentHandler.TestDBConnection)
|
||||
agents.GET("/:agent_id/versions", r.agentHandler.ListAgentVersions)
|
||||
agents.GET("/:agent_id/versions/:version_id", r.agentHandler.GetAgentVersion)
|
||||
agents.POST("/:agent_id/upload", r.agentHandler.UploadAgentFile)
|
||||
@@ -395,6 +396,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
agents.GET("/:agent_id/sessions", r.agentHandler.ListAgentSessions)
|
||||
agents.GET("/:agent_id/sessions/:session_id", r.agentHandler.GetAgentSession)
|
||||
agents.DELETE("/:agent_id/sessions/:session_id", r.agentHandler.DeleteAgentSessionItem)
|
||||
agents.DELETE("/:agent_id/sessions", r.agentHandler.DeleteAgentSessions)
|
||||
}
|
||||
|
||||
// Plugin routes
|
||||
|
||||
@@ -17,14 +17,21 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
@@ -163,6 +170,25 @@ type ListAgentSessionsResponse struct {
|
||||
Total int64 `json:"total"`
|
||||
}
|
||||
|
||||
type DeleteAgentSessionsResult struct {
|
||||
Data *DeleteAgentSessionsResponse
|
||||
Message string
|
||||
}
|
||||
|
||||
type DeleteAgentSessionsResponse struct {
|
||||
SuccessCount int `json:"success_count"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
}
|
||||
|
||||
type TestDBConnectionRequest struct {
|
||||
DBType string `json:"db_type"`
|
||||
Database string `json:"database"`
|
||||
Username string `json:"username"`
|
||||
Host string `json:"host"`
|
||||
Port interface{} `json:"port"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func parseAgentSessionDate(value string, isEnd bool) (*time.Time, error) {
|
||||
if value == "" {
|
||||
return nil, nil
|
||||
@@ -370,6 +396,29 @@ func firstNonNil(values ...interface{}) interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkDuplicateSessionIDs check duplicated ID in IDS and add some messages for it
|
||||
func checkDuplicateSessionIDs(ids []string) ([]string, []string) {
|
||||
seen := make(map[string]int, len(ids))
|
||||
uniqueIDs := make([]string, 0, len(ids))
|
||||
|
||||
for _, id := range ids {
|
||||
id = strings.TrimSpace(id)
|
||||
seen[id]++
|
||||
if seen[id] == 1 {
|
||||
uniqueIDs = append(uniqueIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
duplicateMessages := make([]string, 0)
|
||||
for _, id := range uniqueIDs {
|
||||
if seen[id] > 1 {
|
||||
duplicateMessages = append(duplicateMessages, fmt.Sprintf("Duplicate session ids: %s", id))
|
||||
}
|
||||
}
|
||||
|
||||
return uniqueIDs, duplicateMessages
|
||||
}
|
||||
|
||||
func (s *AgentService) ListAgentSessions(userID, tenantID, agentID string, req ListAgentSessionsRequest) (*ListAgentSessionsResponse, common.ErrorCode, error) {
|
||||
if agentID == "" {
|
||||
return nil, common.CodeArgumentError, errors.New("agent_id is required")
|
||||
@@ -474,6 +523,301 @@ func (s *AgentService) DeleteAgentSessionItem(userID, agentID, sessionID string)
|
||||
return true, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// DeleteAgentSessions Delete sessions by ids
|
||||
func (s *AgentService) DeleteAgentSessions(userID, agentID string, ids []string, deleteAll bool) (*DeleteAgentSessionsResult, common.ErrorCode, error) {
|
||||
if agentID == "" {
|
||||
return nil, common.CodeArgumentError, errors.New("agent_id is required")
|
||||
}
|
||||
|
||||
canvas, err := s.canvasDAO.GetByID(agentID)
|
||||
if err != nil || canvas == nil || canvas.UserID != userID {
|
||||
return nil, common.CodeDataError, fmt.Errorf("You don't own the agent %s", agentID)
|
||||
}
|
||||
|
||||
if len(ids) == 0 {
|
||||
if !deleteAll {
|
||||
return &DeleteAgentSessionsResult{}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
ids, err = s.api4ConversationDAO.ListIDsByAgentID(agentID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return &DeleteAgentSessionsResult{}, common.CodeSuccess, nil
|
||||
}
|
||||
}
|
||||
|
||||
sessionIDs, duplicateMessages := checkDuplicateSessionIDs(ids)
|
||||
errorsList := make([]string, 0)
|
||||
successCount := 0
|
||||
|
||||
for _, sessionID := range sessionIDs {
|
||||
sessionID = strings.TrimSpace(sessionID)
|
||||
if sessionID == "" {
|
||||
errorsList = append(errorsList, "The agent doesn't own the session ")
|
||||
continue
|
||||
}
|
||||
|
||||
conv, err := s.api4ConversationDAO.GetBySessionID(sessionID, agentID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if conv == nil {
|
||||
errorsList = append(errorsList, fmt.Sprintf("The agent doesn't own the session %s", sessionID))
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := s.api4ConversationDAO.DeleteBySessionIDAndAgentID(sessionID, agentID); err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
successCount++
|
||||
}
|
||||
|
||||
if len(errorsList) > 0 {
|
||||
if successCount > 0 {
|
||||
return &DeleteAgentSessionsResult{
|
||||
Message: fmt.Sprintf("Partially deleted %d sessions with %d errors", successCount, len(errorsList)),
|
||||
Data: &DeleteAgentSessionsResponse{
|
||||
SuccessCount: successCount,
|
||||
Errors: errorsList,
|
||||
},
|
||||
}, common.CodeSuccess, nil
|
||||
}
|
||||
return nil, common.CodeDataError, errors.New(strings.Join(errorsList, "; "))
|
||||
}
|
||||
|
||||
if len(duplicateMessages) > 0 {
|
||||
if successCount > 0 {
|
||||
return &DeleteAgentSessionsResult{
|
||||
Message: fmt.Sprintf("Partially deleted %d sessions with %d errors", successCount, len(duplicateMessages)),
|
||||
Data: &DeleteAgentSessionsResponse{
|
||||
SuccessCount: successCount,
|
||||
Errors: duplicateMessages,
|
||||
},
|
||||
}, common.CodeSuccess, nil
|
||||
}
|
||||
return nil, common.CodeDataError, errors.New(strings.Join(duplicateMessages, ";"))
|
||||
}
|
||||
|
||||
return &DeleteAgentSessionsResult{}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// AssertHostIsSafe checks whether host resolves only to public IP addresses.
|
||||
func AssertHostIsSafe(host string) (string, error) {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return "", errors.New("Host must not be empty.")
|
||||
}
|
||||
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
zap.L().Warn("SSRF guard could not resolve host",
|
||||
zap.String("host", host),
|
||||
zap.Error(err),
|
||||
)
|
||||
return "", fmt.Errorf("Could not resolve host %q: %w", host, err)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
zap.L().Warn("SSRF guard blocked host: resolved to no addresses",
|
||||
zap.String("host", host),
|
||||
)
|
||||
return "", fmt.Errorf("Host %q resolved to no addresses.", host)
|
||||
}
|
||||
|
||||
var resolvedIP string
|
||||
|
||||
for _, ip := range ips {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid resolved IP %q for host %q", ip.String(), host)
|
||||
}
|
||||
|
||||
// Normalize IPv4-mapped IPv6, equivalent to Python _effective_ip().
|
||||
addr = addr.Unmap()
|
||||
|
||||
if !isPublicAddr(addr) {
|
||||
zap.L().Warn("SSRF guard blocked host",
|
||||
zap.String("host", host),
|
||||
zap.String("resolved_ip", addr.String()),
|
||||
)
|
||||
return "", fmt.Errorf("Host resolves to a non-public address (%s), which is not allowed.", addr.String())
|
||||
}
|
||||
|
||||
if resolvedIP == "" {
|
||||
resolvedIP = addr.String()
|
||||
}
|
||||
}
|
||||
|
||||
if resolvedIP == "" {
|
||||
return "", fmt.Errorf("Host %q resolved to no addresses.", host)
|
||||
}
|
||||
|
||||
return resolvedIP, nil
|
||||
}
|
||||
|
||||
func isPublicAddr(addr netip.Addr) bool {
|
||||
addr = addr.Unmap()
|
||||
|
||||
if !addr.IsValid() {
|
||||
return false
|
||||
}
|
||||
|
||||
if !addr.IsGlobalUnicast() {
|
||||
return false
|
||||
}
|
||||
|
||||
if addr.IsPrivate() ||
|
||||
addr.IsLoopback() ||
|
||||
addr.IsLinkLocalUnicast() ||
|
||||
addr.IsLinkLocalMulticast() ||
|
||||
addr.IsMulticast() ||
|
||||
addr.IsUnspecified() {
|
||||
return false
|
||||
}
|
||||
|
||||
return !isSpecialUseAddr(addr)
|
||||
}
|
||||
|
||||
func isSpecialUseAddr(addr netip.Addr) bool {
|
||||
addr = addr.Unmap()
|
||||
|
||||
specialCIDRs := []string{
|
||||
// IPv4 special-use / documentation / reserved ranges.
|
||||
"0.0.0.0/8",
|
||||
"100.64.0.0/10",
|
||||
"127.0.0.0/8",
|
||||
"169.254.0.0/16",
|
||||
"192.0.0.0/24",
|
||||
"192.0.2.0/24",
|
||||
"198.18.0.0/15",
|
||||
"198.51.100.0/24",
|
||||
"203.0.113.0/24",
|
||||
"224.0.0.0/4",
|
||||
"240.0.0.0/4",
|
||||
|
||||
// IPv6 special-use / documentation / local ranges.
|
||||
"::/128",
|
||||
"::1/128",
|
||||
"64:ff9b:1::/48",
|
||||
"100::/64",
|
||||
"2001::/23",
|
||||
"2001:2::/48",
|
||||
"fc00::/7",
|
||||
"fe80::/10",
|
||||
"ff00::/8",
|
||||
"2001:db8::/32",
|
||||
"2002::/16",
|
||||
}
|
||||
|
||||
for _, cidr := range specialCIDRs {
|
||||
prefix := netip.MustParsePrefix(cidr)
|
||||
if prefix.Contains(addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// missingDBConnectionFields Check if request is missing something
|
||||
func missingDBConnectionFields(req *TestDBConnectionRequest) []string {
|
||||
missing := make([]string, 0, 6)
|
||||
if req == nil || strings.TrimSpace(req.DBType) == "" {
|
||||
missing = append(missing, "db_type")
|
||||
}
|
||||
if req == nil || strings.TrimSpace(req.Database) == "" {
|
||||
missing = append(missing, "database")
|
||||
}
|
||||
if req == nil || strings.TrimSpace(req.Username) == "" {
|
||||
missing = append(missing, "username")
|
||||
}
|
||||
if req == nil || strings.TrimSpace(req.Host) == "" {
|
||||
missing = append(missing, "host")
|
||||
}
|
||||
if req == nil || dbConnectionPort(req.Port) == "" {
|
||||
missing = append(missing, "port")
|
||||
}
|
||||
if req == nil || req.Password == "" {
|
||||
missing = append(missing, "password")
|
||||
}
|
||||
return missing
|
||||
}
|
||||
|
||||
func dbConnectionPort(port interface{}) string {
|
||||
switch value := port.(type) {
|
||||
case nil:
|
||||
return ""
|
||||
case string:
|
||||
return strings.TrimSpace(value)
|
||||
case float64:
|
||||
return strconv.Itoa(int(value))
|
||||
case float32:
|
||||
return strconv.Itoa(int(value))
|
||||
case int:
|
||||
return strconv.Itoa(value)
|
||||
case int64:
|
||||
return strconv.FormatInt(value, 10)
|
||||
case json.Number:
|
||||
return value.String()
|
||||
default:
|
||||
return strings.TrimSpace(fmt.Sprint(value))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AgentService) TestDBConnection(userID string, req *TestDBConnectionRequest) (common.ErrorCode, error) {
|
||||
if missing := missingDBConnectionFields(req); len(missing) > 0 {
|
||||
return common.CodeArgumentError, fmt.Errorf("required argument are missing: %s; ", strings.Join(missing, ","))
|
||||
}
|
||||
|
||||
safeHost, err := AssertHostIsSafe(req.Host)
|
||||
if err != nil {
|
||||
zap.L().Warn(
|
||||
"Rejected test_db_connection: unsafe host",
|
||||
zap.String("host", req.Host),
|
||||
zap.String("db_type", req.DBType),
|
||||
zap.String("user", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return common.CodeDataError, err
|
||||
}
|
||||
|
||||
switch req.DBType {
|
||||
case "mysql", "mariadb", "oceanbase":
|
||||
port := dbConnectionPort(req.Port)
|
||||
dbProbeTimeout := 5 * time.Second
|
||||
config := mysql.Config{
|
||||
User: req.Username,
|
||||
Passwd: req.Password,
|
||||
Net: "tcp",
|
||||
Addr: net.JoinHostPort(safeHost, port),
|
||||
DBName: req.Database,
|
||||
Timeout: dbProbeTimeout,
|
||||
AllowNativePasswords: true,
|
||||
}
|
||||
db, err := sql.Open("mysql", config.FormatDSN())
|
||||
if err != nil {
|
||||
return common.CodeExceptionError, err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dbProbeTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
return common.CodeExceptionError, err
|
||||
}
|
||||
if _, err := db.ExecContext(ctx, "SELECT 1"); err != nil {
|
||||
return common.CodeExceptionError, err
|
||||
}
|
||||
default:
|
||||
return common.CodeExceptionError, errors.New("Unsupported database type.")
|
||||
}
|
||||
|
||||
return common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// normalizeAgentTags returns an error for unsupported tag payload types
|
||||
func normalizeAgentTags(rawTags interface{}) (string, error) {
|
||||
cleaned := make([]string, 0)
|
||||
|
||||
@@ -17,9 +17,13 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
@@ -344,5 +348,486 @@ func TestGetVersion_NotFound(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func setupAgentSessionServiceTest(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
testDB := setupServiceTestDB(t)
|
||||
if err := testDB.AutoMigrate(
|
||||
&entity.User{},
|
||||
&entity.UserCanvas{},
|
||||
&entity.UserTenant{},
|
||||
&entity.API4Conversation{},
|
||||
); err != nil {
|
||||
t.Fatalf("failed to migrate: %v", err)
|
||||
}
|
||||
|
||||
orig := dao.DB
|
||||
dao.DB = testDB
|
||||
t.Cleanup(func() { dao.DB = orig })
|
||||
}
|
||||
|
||||
func createAgentSessionTestCanvas(t *testing.T, id, userID string) {
|
||||
t.Helper()
|
||||
if err := dao.DB.Create(&entity.UserCanvas{
|
||||
ID: id,
|
||||
UserID: userID,
|
||||
Title: sptr("Test Agent"),
|
||||
CanvasCategory: "agent_canvas",
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to create canvas %s: %v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
func createAgentSessionTestConversation(t *testing.T, id, agentID, userID string, updateTime int64) {
|
||||
t.Helper()
|
||||
updateDate := time.UnixMilli(updateTime)
|
||||
if err := dao.DB.Create(&entity.API4Conversation{
|
||||
ID: id,
|
||||
DialogID: agentID,
|
||||
UserID: userID,
|
||||
Message: json.RawMessage(`[{"role":"assistant","content":"hello","prompt":"hidden"}]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
BaseModel: entity.BaseModel{
|
||||
CreateTime: ptr(updateTime),
|
||||
CreateDate: &updateDate,
|
||||
UpdateTime: ptr(updateTime),
|
||||
UpdateDate: &updateDate,
|
||||
},
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to create session %s: %v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAgentSessionsServiceSuccess(t *testing.T) {
|
||||
setupAgentSessionServiceTest(t)
|
||||
|
||||
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
|
||||
createAgentSessionTestConversation(t, "session-old", "canvas-1", "user-1", 1000)
|
||||
createAgentSessionTestConversation(t, "session-new", "canvas-1", "user-1", 3000)
|
||||
createAgentSessionTestConversation(t, "session-other-agent", "canvas-other", "user-1", 9999)
|
||||
|
||||
resp, code, err := NewAgentService().ListAgentSessions("user-1", "user-1", "canvas-1", ListAgentSessionsRequest{
|
||||
Page: 1,
|
||||
PageSize: 10,
|
||||
OrderBy: "update_time",
|
||||
Desc: true,
|
||||
IncludeDSL: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ListAgentSessions failed: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("expected code %d, got %d", common.CodeSuccess, code)
|
||||
}
|
||||
if resp.Total != 2 {
|
||||
t.Fatalf("expected total 2, got %d", resp.Total)
|
||||
}
|
||||
if len(resp.Data) != 2 {
|
||||
t.Fatalf("expected 2 sessions, got %d", len(resp.Data))
|
||||
}
|
||||
if resp.Data[0]["id"] != "session-new" {
|
||||
t.Fatalf("expected newest session first, got %v", resp.Data[0]["id"])
|
||||
}
|
||||
if resp.Data[0]["agent_id"] != "canvas-1" {
|
||||
t.Fatalf("expected agent_id canvas-1, got %v", resp.Data[0]["agent_id"])
|
||||
}
|
||||
messages, ok := resp.Data[0]["message"].([]map[string]interface{})
|
||||
if !ok || len(messages) != 1 {
|
||||
t.Fatalf("expected normalized message slice, got %T %v", resp.Data[0]["message"], resp.Data[0]["message"])
|
||||
}
|
||||
if _, ok := messages[0]["prompt"]; ok {
|
||||
t.Fatal("expected prompt to be removed from normalized session message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAgentSessionsServiceDenied(t *testing.T) {
|
||||
setupAgentSessionServiceTest(t)
|
||||
|
||||
createAgentSessionTestCanvas(t, "canvas-1", "user-2")
|
||||
|
||||
resp, code, err := NewAgentService().ListAgentSessions("user-1", "user-1", "canvas-1", ListAgentSessionsRequest{})
|
||||
if err == nil {
|
||||
t.Fatal("expected permission error")
|
||||
}
|
||||
if code != common.CodeOperatingError {
|
||||
t.Fatalf("expected code %d, got %d", common.CodeOperatingError, code)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatalf("expected nil response, got %+v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAgentSessionServiceSuccess(t *testing.T) {
|
||||
setupAgentSessionServiceTest(t)
|
||||
|
||||
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
|
||||
createAgentSessionTestConversation(t, "session-1", "canvas-1", "user-1", 1000)
|
||||
|
||||
session, code, err := NewAgentService().GetAgentSession("user-1", "canvas-1", "session-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetAgentSession failed: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("expected code %d, got %d", common.CodeSuccess, code)
|
||||
}
|
||||
if session == nil {
|
||||
t.Fatal("expected session, got nil")
|
||||
}
|
||||
if session.ID != "session-1" {
|
||||
t.Fatalf("expected session-1, got %s", session.ID)
|
||||
}
|
||||
if session.DialogID != "canvas-1" {
|
||||
t.Fatalf("expected dialog_id canvas-1, got %s", session.DialogID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAgentSessionServiceNotFoundWhenSessionBelongsToAnotherAgent(t *testing.T) {
|
||||
setupAgentSessionServiceTest(t)
|
||||
|
||||
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
|
||||
createAgentSessionTestConversation(t, "session-other", "canvas-other", "user-1", 1000)
|
||||
|
||||
session, code, err := NewAgentService().GetAgentSession("user-1", "canvas-1", "session-other")
|
||||
if err == nil {
|
||||
t.Fatal("expected not found error")
|
||||
}
|
||||
if code != common.CodeNotFound {
|
||||
t.Fatalf("expected code %d, got %d", common.CodeNotFound, code)
|
||||
}
|
||||
if session != nil {
|
||||
t.Fatalf("expected nil session, got %+v", session)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAgentSessionItemServiceDeletesMatchingSession(t *testing.T) {
|
||||
setupAgentSessionServiceTest(t)
|
||||
|
||||
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
|
||||
createAgentSessionTestConversation(t, "session-1", "canvas-1", "user-1", 1000)
|
||||
createAgentSessionTestConversation(t, "session-other", "canvas-other", "user-1", 2000)
|
||||
|
||||
deleted, code, err := NewAgentService().DeleteAgentSessionItem("user-1", "canvas-1", "session-1")
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteAgentSessionItem failed: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("expected code %d, got %d", common.CodeSuccess, code)
|
||||
}
|
||||
if !deleted {
|
||||
t.Fatal("expected session to be deleted")
|
||||
}
|
||||
|
||||
var count int64
|
||||
if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-1").Count(&count).Error; err != nil {
|
||||
t.Fatalf("failed to count deleted session: %v", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatalf("expected session-1 to be deleted, count=%d", count)
|
||||
}
|
||||
|
||||
if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-other").Count(&count).Error; err != nil {
|
||||
t.Fatalf("failed to count other session: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("expected other agent session to remain, count=%d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAgentSessionItemServiceNoopForSessionFromAnotherAgent(t *testing.T) {
|
||||
setupAgentSessionServiceTest(t)
|
||||
|
||||
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
|
||||
createAgentSessionTestConversation(t, "session-other", "canvas-other", "user-1", 1000)
|
||||
|
||||
deleted, code, err := NewAgentService().DeleteAgentSessionItem("user-1", "canvas-1", "session-other")
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteAgentSessionItem failed: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("expected code %d, got %d", common.CodeSuccess, code)
|
||||
}
|
||||
if deleted {
|
||||
t.Fatal("expected cross-agent session delete to be a noop")
|
||||
}
|
||||
|
||||
var count int64
|
||||
if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-other").Count(&count).Error; err != nil {
|
||||
t.Fatalf("failed to count other session: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("expected cross-agent session to remain, count=%d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAgentSessionsServiceDeleteAll(t *testing.T) {
|
||||
setupAgentSessionServiceTest(t)
|
||||
|
||||
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
|
||||
createAgentSessionTestConversation(t, "session-1", "canvas-1", "user-1", 1000)
|
||||
createAgentSessionTestConversation(t, "session-2", "canvas-1", "user-1", 2000)
|
||||
createAgentSessionTestConversation(t, "session-other", "canvas-other", "user-1", 3000)
|
||||
|
||||
result, code, err := NewAgentService().DeleteAgentSessions("user-1", "canvas-1", nil, true)
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteAgentSessions failed: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("expected code %d, got %d", common.CodeSuccess, code)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if result.Data != nil {
|
||||
t.Fatalf("expected no partial data on full success, got %+v", result.Data)
|
||||
}
|
||||
|
||||
var ownCount int64
|
||||
if err := dao.DB.Model(&entity.API4Conversation{}).Where("dialog_id = ?", "canvas-1").Count(&ownCount).Error; err != nil {
|
||||
t.Fatalf("failed to count own sessions: %v", err)
|
||||
}
|
||||
if ownCount != 0 {
|
||||
t.Fatalf("expected all canvas-1 sessions to be deleted, count=%d", ownCount)
|
||||
}
|
||||
|
||||
var otherCount int64
|
||||
if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-other").Count(&otherCount).Error; err != nil {
|
||||
t.Fatalf("failed to count other session: %v", err)
|
||||
}
|
||||
if otherCount != 1 {
|
||||
t.Fatalf("expected other agent session to remain, count=%d", otherCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAgentSessionsServiceDuplicateIDsPartial(t *testing.T) {
|
||||
setupAgentSessionServiceTest(t)
|
||||
|
||||
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
|
||||
createAgentSessionTestConversation(t, "session-1", "canvas-1", "user-1", 1000)
|
||||
|
||||
result, code, err := NewAgentService().DeleteAgentSessions("user-1", "canvas-1", []string{"session-1", "session-1"}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteAgentSessions failed: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("expected code %d, got %d", common.CodeSuccess, code)
|
||||
}
|
||||
if result == nil || result.Data == nil {
|
||||
t.Fatalf("expected partial result data, got %+v", result)
|
||||
}
|
||||
if result.Data.SuccessCount != 1 {
|
||||
t.Fatalf("expected success_count 1, got %d", result.Data.SuccessCount)
|
||||
}
|
||||
if len(result.Data.Errors) != 1 || result.Data.Errors[0] != "Duplicate session ids: session-1" {
|
||||
t.Fatalf("unexpected duplicate errors: %v", result.Data.Errors)
|
||||
}
|
||||
|
||||
var count int64
|
||||
if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-1").Count(&count).Error; err != nil {
|
||||
t.Fatalf("failed to count deleted session: %v", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatalf("expected session-1 to be deleted, count=%d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAgentSessionsServiceMissingSessionError(t *testing.T) {
|
||||
setupAgentSessionServiceTest(t)
|
||||
|
||||
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
|
||||
|
||||
result, code, err := NewAgentService().DeleteAgentSessions("user-1", "canvas-1", []string{"missing-session"}, false)
|
||||
if err == nil {
|
||||
t.Fatal("expected missing session error")
|
||||
}
|
||||
if code != common.CodeDataError {
|
||||
t.Fatalf("expected code %d, got %d", common.CodeDataError, code)
|
||||
}
|
||||
if result != nil {
|
||||
t.Fatalf("expected nil result, got %+v", result)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "The agent doesn't own the session missing-session") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAgentSessionsServiceRequiresOwner(t *testing.T) {
|
||||
setupAgentSessionServiceTest(t)
|
||||
|
||||
createAgentSessionTestCanvas(t, "canvas-1", "user-2")
|
||||
createAgentSessionTestConversation(t, "session-1", "canvas-1", "user-1", 1000)
|
||||
|
||||
result, code, err := NewAgentService().DeleteAgentSessions("user-1", "canvas-1", []string{"session-1"}, false)
|
||||
if err == nil {
|
||||
t.Fatal("expected owner error")
|
||||
}
|
||||
if code != common.CodeDataError {
|
||||
t.Fatalf("expected code %d, got %d", common.CodeDataError, code)
|
||||
}
|
||||
if result != nil {
|
||||
t.Fatalf("expected nil result, got %+v", result)
|
||||
}
|
||||
|
||||
var count int64
|
||||
if err := dao.DB.Model(&entity.API4Conversation{}).Where("id = ?", "session-1").Count(&count).Error; err != nil {
|
||||
t.Fatalf("failed to count session: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("expected session to remain, count=%d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAgentTagsServiceSuccess(t *testing.T) {
|
||||
setupAgentSessionServiceTest(t)
|
||||
|
||||
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
|
||||
|
||||
ok, code, err := NewAgentService().UpdateAgentTags("user-1", "canvas-1", []interface{}{"alpha", "beta", "alpha", "with,comma"})
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateAgentTags failed: %v", err)
|
||||
}
|
||||
if code != common.CodeSuccess {
|
||||
t.Fatalf("expected code %d, got %d", common.CodeSuccess, code)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("expected update to succeed")
|
||||
}
|
||||
|
||||
canvas, err := dao.NewUserCanvasDAO().GetByID("canvas-1")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get canvas: %v", err)
|
||||
}
|
||||
if canvas.Tags != "alpha,beta,with comma" {
|
||||
t.Fatalf("expected normalized tags, got %q", canvas.Tags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAgentTagsServiceInvalidPayload(t *testing.T) {
|
||||
setupAgentSessionServiceTest(t)
|
||||
|
||||
createAgentSessionTestCanvas(t, "canvas-1", "user-1")
|
||||
|
||||
ok, code, err := NewAgentService().UpdateAgentTags("user-1", "canvas-1", map[string]string{"tag": "alpha"})
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid tags error")
|
||||
}
|
||||
if code != common.CodeBadRequest {
|
||||
t.Fatalf("expected code %d, got %d", common.CodeBadRequest, code)
|
||||
}
|
||||
if ok {
|
||||
t.Fatal("expected update to fail")
|
||||
}
|
||||
|
||||
canvas, err := dao.NewUserCanvasDAO().GetByID("canvas-1")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get canvas: %v", err)
|
||||
}
|
||||
if canvas.Tags != "" {
|
||||
t.Fatalf("expected tags to remain unchanged, got %q", canvas.Tags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAgentTagsServiceNoPermission(t *testing.T) {
|
||||
setupAgentSessionServiceTest(t)
|
||||
|
||||
createAgentSessionTestCanvas(t, "canvas-1", "user-2")
|
||||
|
||||
ok, code, err := NewAgentService().UpdateAgentTags("user-1", "canvas-1", []string{"alpha"})
|
||||
if err == nil {
|
||||
t.Fatal("expected permission error")
|
||||
}
|
||||
if code != common.CodeOperatingError {
|
||||
t.Fatalf("expected code %d, got %d", common.CodeOperatingError, code)
|
||||
}
|
||||
if ok {
|
||||
t.Fatal("expected update to fail")
|
||||
}
|
||||
|
||||
canvas, err := dao.NewUserCanvasDAO().GetByID("canvas-1")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get canvas: %v", err)
|
||||
}
|
||||
if canvas.Tags != "" {
|
||||
t.Fatalf("expected tags to remain unchanged, got %q", canvas.Tags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPublicAddr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
want bool
|
||||
}{
|
||||
{name: "public IPv4", addr: "8.8.8.8", want: true},
|
||||
{name: "loopback", addr: "127.0.0.1", want: false},
|
||||
{name: "private", addr: "192.168.1.1", want: false},
|
||||
{name: "carrier NAT", addr: "100.64.0.1", want: false},
|
||||
{name: "documentation", addr: "203.0.113.1", want: false},
|
||||
{name: "IPv4 mapped loopback", addr: "::ffff:127.0.0.1", want: false},
|
||||
{name: "IPv6 documentation", addr: "2001:db8::1", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isPublicAddr(netip.MustParseAddr(tt.addr))
|
||||
if got != tt.want {
|
||||
t.Fatalf("isPublicAddr(%s): expected %v, got %v", tt.addr, tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssertHostIsSafeRejectsLocalhost(t *testing.T) {
|
||||
_, err := AssertHostIsSafe("localhost")
|
||||
if err == nil {
|
||||
t.Fatal("expected localhost to be rejected")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "non-public address") {
|
||||
t.Fatalf("expected non-public address error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTestDBConnectionMissingFields(t *testing.T) {
|
||||
code, err := NewAgentService().TestDBConnection("user-1", &TestDBConnectionRequest{DBType: "mysql"})
|
||||
if err == nil {
|
||||
t.Fatal("expected missing field error")
|
||||
}
|
||||
if code != common.CodeArgumentError {
|
||||
t.Fatalf("expected code %d, got %d", common.CodeArgumentError, code)
|
||||
}
|
||||
want := "required argument are missing: database,username,host,port,password; "
|
||||
if err.Error() != want {
|
||||
t.Fatalf("expected %q, got %q", want, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTestDBConnectionUnsupportedDatabaseType(t *testing.T) {
|
||||
code, err := NewAgentService().TestDBConnection("user-1", &TestDBConnectionRequest{
|
||||
DBType: "postgres",
|
||||
Database: "rag_flow",
|
||||
Username: "root",
|
||||
Host: "8.8.8.8",
|
||||
Port: 5432,
|
||||
Password: "password",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected unsupported database type error")
|
||||
}
|
||||
if code != common.CodeExceptionError {
|
||||
t.Fatalf("expected code %d, got %d", common.CodeExceptionError, code)
|
||||
}
|
||||
if err.Error() != "Unsupported database type." {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBConnectionPortAcceptsStringAndNumber(t *testing.T) {
|
||||
if got := dbConnectionPort("3306"); got != "3306" {
|
||||
t.Fatalf("expected string port 3306, got %q", got)
|
||||
}
|
||||
if got := dbConnectionPort(float64(3306)); got != "3306" {
|
||||
t.Fatalf("expected numeric port 3306, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ptr returns a pointer to the given int64.
|
||||
func ptr(v int64) *int64 { return &v }
|
||||
func ptr(v int64) *int64 { return &v }
|
||||
|
||||
Reference in New Issue
Block a user