feat[Go]: implement /api/v1/agents/<agent_id>/sessions (#15705)

### What problem does this PR solve?

As Title
Codes were tested by Postman

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Haruko386
2026-06-08 16:26:27 +08:00
committed by GitHub
parent e2b0da9eea
commit 67ce0c896d
11 changed files with 1310 additions and 31 deletions

View File

@@ -119,6 +119,24 @@ func (dao *API4ConversationDAO) Stats(tenantID, fromDate, toDate string, source
return rows, err
}
func (dao *API4ConversationDAO) GetBySessionID(sessionID, agentID string) (*entity.API4Conversation, error) {
var result entity.API4Conversation
tx := DB.Where("id = ? AND dialog_id = ?", sessionID, agentID).Find(&result)
if tx.Error != nil {
return nil, tx.Error
}
if tx.RowsAffected == 0 {
return nil, nil
}
return &result, nil
}
// DeleteBySessionIDAndAgentID deletes API4Conversations by sessionID and agentID
func (dao *API4ConversationDAO) DeleteBySessionIDAndAgentID(sessionID, agentID string) (int64, error) {
result := DB.Where("id = ? AND dialog_id = ?", sessionID, agentID).Delete(&entity.API4Conversation{})
return result.RowsAffected, result.Error
}
// DeleteByDialogIDs deletes API4Conversations by dialog IDs (hard delete)
func (dao *API4ConversationDAO) DeleteByDialogIDs(dialogIDs []string) (int64, error) {
if len(dialogIDs) == 0 {

View File

@@ -0,0 +1,108 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package dao
import (
"encoding/json"
"testing"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
"ragflow/internal/entity"
)
func setupAPI4ConversationTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
TranslateError: true,
})
if err != nil {
t.Fatalf("failed to open sqlite: %v", err)
}
if err := db.AutoMigrate(&entity.API4Conversation{}); err != nil {
t.Fatalf("failed to migrate: %v", err)
}
return db
}
func createAPI4ConversationForDAOTest(t *testing.T, id, agentID string) {
t.Helper()
if err := DB.Create(&entity.API4Conversation{
ID: id,
DialogID: agentID,
UserID: "user-1",
Message: json.RawMessage(`[]`),
Reference: json.RawMessage(`[]`),
}).Error; err != nil {
t.Fatalf("failed to create api conversation %s: %v", id, err)
}
}
func TestAPI4ConversationDAOGetBySessionID(t *testing.T) {
db := setupAPI4ConversationTestDB(t)
pushDB(t, db)
createAPI4ConversationForDAOTest(t, "session-1", "agent-1")
session, err := NewAPI4ConversationDAO().GetBySessionID("session-1", "agent-1")
if err != nil {
t.Fatalf("GetBySessionID failed: %v", err)
}
if session == nil {
t.Fatal("expected session, got nil")
}
if session.ID != "session-1" {
t.Fatalf("expected session-1, got %s", session.ID)
}
if session.DialogID != "agent-1" {
t.Fatalf("expected agent-1, got %s", session.DialogID)
}
}
func TestAPI4ConversationDAOGetBySessionIDWrongAgent(t *testing.T) {
db := setupAPI4ConversationTestDB(t)
pushDB(t, db)
createAPI4ConversationForDAOTest(t, "session-1", "agent-1")
session, err := NewAPI4ConversationDAO().GetBySessionID("session-1", "agent-2")
if err != nil {
t.Fatalf("GetBySessionID failed: %v", err)
}
if session != nil {
t.Fatalf("expected nil for wrong agent, got %+v", session)
}
}
func TestAPI4ConversationDAOGetBySessionIDNoRows(t *testing.T) {
db := setupAPI4ConversationTestDB(t)
pushDB(t, db)
createAPI4ConversationForDAOTest(t, "session-1", "agent-1")
session, err := NewAPI4ConversationDAO().GetBySessionID("missing-session", "agent-1")
if err != nil {
t.Fatalf("GetBySessionID failed: %v", err)
}
if session != nil {
t.Fatalf("expected nil for missing session, got %+v", session)
}
}

View File

@@ -17,12 +17,30 @@
package dao
import (
"strings"
"time"
"ragflow/internal/entity"
)
// ChatSessionDAO chat session data access object
type ChatSessionDAO struct{}
type ListAgentSessionsParams struct {
AgentID string
Page int
PageSize int
OrderBy string
Desc bool
SessionID string
UserID string
IncludeDSL bool
Keywords string
FromDate *time.Time
ToDate *time.Time
ExpUserID string
}
// NewChatSessionDAO create chat session DAO
func NewChatSessionDAO() *ChatSessionDAO {
return &ChatSessionDAO{}
@@ -92,3 +110,107 @@ func (dao *ChatSessionDAO) DeleteByDialogIDs(dialogIDs []string) (int64, error)
result := DB.Unscoped().Where("dialog_id IN ?", dialogIDs).Delete(&entity.ChatSession{})
return result.RowsAffected, result.Error
}
func (dao *ChatSessionDAO) ListAgentSessionNames(agentID, expUserID string) ([]map[string]interface{}, error) {
var rows []map[string]interface{}
err := DB.Model(&entity.API4Conversation{}).
Select("id", "name").
Where("dialog_id = ? AND exp_user_id = ?", agentID, expUserID).
Order("create_date DESC").
Find(&rows).Error
return rows, err
}
func normalizeAgentSessionOrderBy(orderBy string) string {
switch orderBy {
case "id":
return "id"
case "name":
return "name"
case "create_time":
return "create_time"
case "create_date":
return "create_date"
case "update_time":
return "update_time"
case "update_date":
return "update_date"
case "tokens":
return "tokens"
case "duration":
return "duration"
case "round":
return "round"
case "thumb_up":
return "thumb_up"
default:
return "update_time"
}
}
func (dao *ChatSessionDAO) ListAgentSessions(params ListAgentSessionsParams) (int64, []*entity.API4Conversation, error) {
query := DB.Model(&entity.API4Conversation{}).Where("dialog_id = ?", params.AgentID)
if !params.IncludeDSL {
query = query.Omit("dsl")
}
if params.SessionID != "" {
query = query.Where("id = ?", params.SessionID)
}
if params.UserID != "" {
query = query.Where("user_id = ?", params.UserID)
}
if params.Keywords != "" {
query = query.Where("LOWER(message) LIKE ?", "%"+strings.ToLower(params.Keywords)+"%")
}
dateColumn := "create_date"
if strings.HasPrefix(params.OrderBy, "update_") {
dateColumn = "update_date"
}
if params.FromDate != nil {
query = query.Where(dateColumn+" >= ?", *params.FromDate)
}
if params.ToDate != nil {
query = query.Where(dateColumn+" <= ?", *params.ToDate)
}
if params.ExpUserID != "" {
query = query.Where("exp_user_id = ?", params.ExpUserID)
}
var total int64
if err := query.Count(&total).Error; err != nil {
return 0, nil, err
}
orderBy := normalizeAgentSessionOrderBy(params.OrderBy)
if params.Desc {
orderBy += " DESC"
} else {
orderBy += " ASC"
}
page := params.Page
if page <= 0 {
page = 1
}
pageSize := params.PageSize
if pageSize <= 0 {
pageSize = 30
}
var sessions []*entity.API4Conversation
err := query.
Order(orderBy).
Offset((page - 1) * pageSize).
Limit(pageSize).
Find(&sessions).Error
return total, sessions, err
}

View File

@@ -0,0 +1,140 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package dao
import (
"encoding/json"
"testing"
"time"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
"ragflow/internal/entity"
)
func setupChatSessionDAOTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
TranslateError: true,
})
if err != nil {
t.Fatalf("failed to open sqlite: %v", err)
}
if err := db.AutoMigrate(&entity.API4Conversation{}); err != nil {
t.Fatalf("failed to migrate: %v", err)
}
return db
}
func createAgentSessionForDAOTest(t *testing.T, db *gorm.DB, id, agentID, userID string, updateTime int64) {
t.Helper()
updateDate := time.UnixMilli(updateTime).Local()
session := &entity.API4Conversation{
ID: id,
DialogID: agentID,
UserID: userID,
Message: json.RawMessage(`[{"role":"assistant","content":"hello"}]`),
Reference: json.RawMessage(`[]`),
BaseModel: entity.BaseModel{
CreateTime: &updateTime,
CreateDate: &updateDate,
UpdateTime: &updateTime,
UpdateDate: &updateDate,
},
}
if err := db.Create(session).Error; err != nil {
t.Fatalf("failed to create session %s: %v", id, err)
}
}
func TestChatSessionDAOListAgentSessionsOrdersByUpdateTimeDesc(t *testing.T) {
db := setupChatSessionDAOTestDB(t)
pushDB(t, db)
createAgentSessionForDAOTest(t, db, "session-old", "agent-1", "user-1", 1000)
createAgentSessionForDAOTest(t, db, "session-new", "agent-1", "user-1", 3000)
createAgentSessionForDAOTest(t, db, "session-middle", "agent-1", "user-1", 2000)
createAgentSessionForDAOTest(t, db, "session-other-agent", "agent-2", "user-1", 9999)
total, sessions, err := NewChatSessionDAO().ListAgentSessions(ListAgentSessionsParams{
AgentID: "agent-1",
Page: 1,
PageSize: 10,
OrderBy: "update_time",
Desc: true,
})
if err != nil {
t.Fatalf("ListAgentSessions failed: %v", err)
}
if total != 3 {
t.Fatalf("expected total 3, got %d", total)
}
if len(sessions) != 3 {
t.Fatalf("expected 3 sessions, got %d", len(sessions))
}
wantIDs := []string{"session-new", "session-middle", "session-old"}
for i, wantID := range wantIDs {
if sessions[i].ID != wantID {
t.Fatalf("session[%d]: expected %s, got %s", i, wantID, sessions[i].ID)
}
if sessions[i].DialogID != "agent-1" {
t.Fatalf("session[%d]: expected agent-1, got %s", i, sessions[i].DialogID)
}
}
}
func TestChatSessionDAOListAgentSessionsFiltersAndPaginates(t *testing.T) {
db := setupChatSessionDAOTestDB(t)
pushDB(t, db)
createAgentSessionForDAOTest(t, db, "session-1", "agent-1", "user-1", 1000)
createAgentSessionForDAOTest(t, db, "session-2", "agent-1", "user-1", 2000)
createAgentSessionForDAOTest(t, db, "session-3", "agent-1", "user-1", 3000)
createAgentSessionForDAOTest(t, db, "session-other-user", "agent-1", "user-2", 4000)
total, sessions, err := NewChatSessionDAO().ListAgentSessions(ListAgentSessionsParams{
AgentID: "agent-1",
UserID: "user-1",
Page: 2,
PageSize: 1,
OrderBy: "update_time",
Desc: false,
})
if err != nil {
t.Fatalf("ListAgentSessions failed: %v", err)
}
if total != 3 {
t.Fatalf("expected total 3 after user filter, got %d", total)
}
if len(sessions) != 1 {
t.Fatalf("expected one paginated session, got %d", len(sessions))
}
if sessions[0].ID != "session-2" {
t.Fatalf("expected second ascending session session-2, got %s", sessions[0].ID)
}
if sessions[0].UserID != "user-1" {
t.Fatalf("expected user-1, got %s", sessions[0].UserID)
}
}

View File

@@ -55,14 +55,7 @@ func (dao *UserCanvasDAO) Delete(id string) error {
// GetList get canvases list with pagination and filtering
// Similar to Python UserCanvasService.get_list
func (dao *UserCanvasDAO) GetList(
tenantID string,
pageNumber, itemsPerPage int,
orderby string,
desc bool,
id, title string,
canvasCategory string,
) ([]*entity.UserCanvas, error) {
func (dao *UserCanvasDAO) GetList(tenantID string, pageNumber, itemsPerPage int, orderby string, desc bool, id, title string, canvasCategory string) ([]*entity.UserCanvas, error) {
query := DB.Model(&entity.UserCanvas{}).
Where("user_id = ?", tenantID)
@@ -116,15 +109,7 @@ func (dao *UserCanvasDAO) GetAllCanvasesByTenantIDs(tenantIDs []string, userID s
// ListByTenantIDs lists agent canvases accessible to the given owner IDs with optional
// keyword filter, pagination, and ordering.
// Mirrors Python UserCanvasService.get_by_tenant_ids (list route only).
func (dao *UserCanvasDAO) ListByTenantIDs(
ownerIDs []string,
userID string,
page, pageSize int,
orderby string,
desc bool,
keywords string,
canvasCategory string,
) ([]*entity.UserCanvas, int64, error) {
func (dao *UserCanvasDAO) ListByTenantIDs(ownerIDs []string, userID string, page, pageSize int, orderby string, desc bool, keywords string, canvasCategory string) ([]*entity.UserCanvas, int64, error) {
if len(ownerIDs) == 0 {
return nil, 0, nil
}
@@ -201,6 +186,12 @@ func (dao *UserCanvasDAO) GetAllCanvasIDsByUserID(userID string) ([]string, erro
return canvasIDs, err
}
// UpdateDSL updates a canvas DSL by canvas ID.
func (dao *UserCanvasDAO) UpdateDSL(canvasID string, dsl entity.JSONMap) (int64, error) {
result := DB.Model(&entity.UserCanvas{}).Where("id = ?", canvasID).Update("dsl", dsl)
return result.RowsAffected, result.Error
}
// UpdateTags updates a canvas's comma-separated tags by canvas ID.
func (dao *UserCanvasDAO) UpdateTags(canvasID, tags string) (int64, error) {
result := DB.Model(&entity.UserCanvas{}).Where("id = ?", canvasID).Update("tags", tags)

View File

@@ -0,0 +1,135 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package dao
import (
"testing"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
"ragflow/internal/entity"
)
func setupUserCanvasTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
TranslateError: true,
})
if err != nil {
t.Fatalf("failed to open sqlite: %v", err)
}
if err := db.AutoMigrate(&entity.UserCanvas{}); err != nil {
t.Fatalf("failed to migrate: %v", err)
}
return db
}
func TestUserCanvasDAOUpdateDSL(t *testing.T) {
db := setupUserCanvasTestDB(t)
pushDB(t, db)
dao := NewUserCanvasDAO()
originalDSL := entity.JSONMap{"graph": map[string]interface{}{"nodes": []interface{}{"old"}}}
if err := dao.Create(&entity.UserCanvas{
ID: "canvas-1",
UserID: "user-1",
Title: stringPtr("Test Canvas"),
CanvasCategory: "agent_canvas",
DSL: originalDSL,
}); err != nil {
t.Fatalf("failed to create canvas: %v", err)
}
newDSL := entity.JSONMap{
"graph": map[string]interface{}{
"nodes": []interface{}{"start", "end"},
"edges": []interface{}{"start:end"},
},
"path": []interface{}{"start", "end"},
}
rows, err := dao.UpdateDSL("canvas-1", newDSL)
if err != nil {
t.Fatalf("UpdateDSL failed: %v", err)
}
if rows != 1 {
t.Fatalf("expected 1 row affected, got %d", rows)
}
canvas, err := dao.GetByID("canvas-1")
if err != nil {
t.Fatalf("failed to get canvas: %v", err)
}
graph, ok := canvas.DSL["graph"].(map[string]interface{})
if !ok {
t.Fatalf("expected graph map, got %T", canvas.DSL["graph"])
}
nodes, ok := graph["nodes"].([]interface{})
if !ok {
t.Fatalf("expected nodes slice, got %T", graph["nodes"])
}
if len(nodes) != 2 || nodes[0] != "start" || nodes[1] != "end" {
t.Fatalf("unexpected nodes after update: %v", nodes)
}
path, ok := canvas.DSL["path"].([]interface{})
if !ok {
t.Fatalf("expected path slice, got %T", canvas.DSL["path"])
}
if len(path) != 2 || path[0] != "start" || path[1] != "end" {
t.Fatalf("unexpected path after update: %v", path)
}
}
func TestUserCanvasDAOUpdateDSLNoMatch(t *testing.T) {
db := setupUserCanvasTestDB(t)
pushDB(t, db)
dao := NewUserCanvasDAO()
originalDSL := entity.JSONMap{"path": []interface{}{"old"}}
if err := dao.Create(&entity.UserCanvas{
ID: "canvas-1",
UserID: "user-1",
Title: stringPtr("Test Canvas"),
CanvasCategory: "agent_canvas",
DSL: originalDSL,
}); err != nil {
t.Fatalf("failed to create canvas: %v", err)
}
rows, err := dao.UpdateDSL("missing-canvas", entity.JSONMap{"path": []interface{}{"new"}})
if err != nil {
t.Fatalf("UpdateDSL failed: %v", err)
}
if rows != 0 {
t.Fatalf("expected 0 rows affected, got %d", rows)
}
canvas, err := dao.GetByID("canvas-1")
if err != nil {
t.Fatalf("failed to get canvas: %v", err)
}
path, ok := canvas.DSL["path"].([]interface{})
if !ok {
t.Fatalf("expected path slice, got %T", canvas.DSL["path"])
}
if len(path) != 1 || path[0] != "old" {
t.Fatalf("expected original DSL to remain unchanged, got %v", path)
}
}

View File

@@ -16,6 +16,8 @@
package entity
import "encoding/json"
// APIToken API token model
type APIToken struct {
TenantID string `gorm:"column:tenant_id;size:32;not null;primaryKey" json:"tenant_id"`
@@ -33,20 +35,21 @@ func (APIToken) TableName() string {
// API4Conversation API for conversation model
type API4Conversation struct {
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
Name *string `gorm:"column:name;size:255" json:"name,omitempty"`
DialogID string `gorm:"column:dialog_id;size:32;not null;index" json:"dialog_id"`
UserID string `gorm:"column:user_id;size:255;not null;index" json:"user_id"`
ExpUserID *string `gorm:"column:exp_user_id;size:255;index" json:"exp_user_id,omitempty"`
Message JSONMap `gorm:"column:message;type:longtext" json:"message,omitempty"`
Reference JSONMap `gorm:"column:reference;type:longtext" json:"reference"`
Tokens int64 `gorm:"column:tokens;default:0" json:"tokens"`
Source *string `gorm:"column:source;size:16;index" json:"source,omitempty"`
DSL JSONMap `gorm:"column:dsl;type:longtext" json:"dsl,omitempty"`
Duration float64 `gorm:"column:duration;default:0;index" json:"duration"`
Round int64 `gorm:"column:round;default:0;index" json:"round"`
ThumbUp int64 `gorm:"column:thumb_up;default:0;index" json:"thumb_up"`
Errors *string `gorm:"column:errors;type:longtext" json:"errors,omitempty"`
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
Name *string `gorm:"column:name;size:255" json:"name,omitempty"`
DialogID string `gorm:"column:dialog_id;size:32;not null;index" json:"dialog_id"`
UserID string `gorm:"column:user_id;size:255;not null;index" json:"user_id"`
ExpUserID *string `gorm:"column:exp_user_id;size:255;index" json:"exp_user_id,omitempty"`
Message json.RawMessage `gorm:"column:message;type:longtext" json:"message,omitempty"`
Reference json.RawMessage `gorm:"column:reference;type:longtext" json:"reference,omitempty"`
Tokens int `gorm:"column:tokens" json:"tokens"`
Source *string `gorm:"column:source;size:16" json:"source,omitempty"`
DSL JSONMap `gorm:"column:dsl;type:longtext" json:"dsl,omitempty"`
Duration float64 `gorm:"column:duration" json:"duration"`
Round int `gorm:"column:round" json:"round"`
ThumbUp int `gorm:"column:thumb_up" json:"thumb_up"`
Errors *string `gorm:"column:errors;type:text" json:"errors,omitempty"`
VersionTitle *string `gorm:"column:version_title;size:255" json:"version_title,omitempty"`
BaseModel
}

View File

@@ -131,6 +131,182 @@ func (h *AgentHandler) ListAgents(c *gin.Context) {
})
}
// ListAgentSessions List all sessions
func (h *AgentHandler) ListAgentSessions(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
agentID := c.Param("agent_id")
if agentID == "" {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeArgumentError,
"data": nil,
"message": "agent_id is required",
})
return
}
page := parsePositiveIntQuery(c, "page", 1)
pageSize := parsePositiveIntQuery(c, "page_size", 30)
if pageSize > 100 {
pageSize = 100
}
req := service.ListAgentSessionsRequest{
SessionID: c.Query("id"),
UserID: c.Query("user_id"),
Page: page,
PageSize: pageSize,
Keywords: c.Query("keywords"),
FromDate: c.Query("from_date"),
ToDate: c.Query("to_date"),
OrderBy: defaultQueryString(c.Query("orderby"), "update_time"),
ExpUserID: c.Query("exp_user_id"),
Desc: c.Query("desc") != "False" && c.Query("desc") != "false",
IncludeDSL: c.Query("dsl") != "False" && c.Query("dsl") != "false",
}
tenantID := user.ID
result, code, err := h.agentService.ListAgentSessions(user.ID, tenantID, agentID, req)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": code,
"data": nil,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"data": result.Data,
"message": "success",
"total": result.Total,
})
}
func parsePositiveIntQuery(c *gin.Context, key string, defaultValue int) int {
raw := c.Query(key)
if raw == "" {
return defaultValue
}
value, err := strconv.Atoi(raw)
if err != nil || value <= 0 {
return defaultValue
}
return value
}
func defaultQueryString(value, defaultValue string) string {
if value == "" {
return defaultValue
}
return value
}
func (h *AgentHandler) GetAgentSession(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
agentID := c.Param("agent_id")
if agentID == "" {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeOperatingError,
"data": nil,
"message": "agent_id is required",
})
return
}
sessionID := c.Param("session_id")
if sessionID == "" {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeDataError,
"data": nil,
"message": "session_id is required",
})
return
}
userID := user.ID
userID = strings.TrimSpace(userID)
sessionID = strings.TrimSpace(sessionID)
agentID = strings.TrimSpace(agentID)
data, code, err := h.agentService.GetAgentSession(userID, agentID, sessionID)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": code,
"data": nil,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"data": data,
"message": "success",
})
}
func (h *AgentHandler) DeleteAgentSessionItem(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
agentID := c.Param("agent_id")
if agentID == "" {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeOperatingError,
"data": nil,
"message": "agent_id is required",
})
return
}
sessionID := c.Param("session_id")
if sessionID == "" {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeDataError,
"data": nil,
"message": "session_id is required",
})
return
}
userID := user.ID
userID = strings.TrimSpace(userID)
sessionID = strings.TrimSpace(sessionID)
agentID = strings.TrimSpace(agentID)
ok, code, err := h.agentService.DeleteAgentSessionItem(userID, agentID, sessionID)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": code,
"data": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": code,
"data": ok,
"message": "success",
})
}
// ListAgentVersions returns versions for a specific agent.
// @Summary List Agent Versions
// @Description Returns all versions for a specific agent, ordered by update_time DESC.

View File

@@ -49,6 +49,7 @@ func setupHandlerAgentsTestDB(t *testing.T) *gorm.DB {
&entity.User{},
&entity.UserCanvas{},
&entity.UserCanvasVersion{},
&entity.API4Conversation{},
); err != nil {
t.Fatalf("failed to migrate: %v", err)
}
@@ -262,6 +263,251 @@ func TestGetAgentVersionHandler_VersionNotFound(t *testing.T) {
}
}
func TestListAgentSessionsHandlerSuccess(t *testing.T) {
c, w, db := setupGinContextWithUserAndDB(t, http.MethodGet, "/api/v1/agents/canvas-1/sessions")
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}}
db.Create(&entity.UserCanvas{
ID: "canvas-1",
UserID: "user-1",
Title: sptr("Test Agent"),
})
db.Create(&entity.API4Conversation{
ID: "session-1",
DialogID: "canvas-1",
UserID: "user-1",
Message: json.RawMessage(`[{"role":"assistant","content":"hello","prompt":"hidden"}]`),
Reference: json.RawMessage(`[]`),
BaseModel: entity.BaseModel{
UpdateTime: ptr(time.Now().UnixMilli()),
},
})
db.Create(&entity.API4Conversation{
ID: "session-2",
DialogID: "canvas-1",
UserID: "user-1",
Message: json.RawMessage(`[{"role":"user","content":"question"}]`),
Reference: json.RawMessage(`[]`),
BaseModel: entity.BaseModel{
UpdateTime: ptr(time.Now().Add(-time.Hour).UnixMilli()),
},
})
db.Create(&entity.API4Conversation{
ID: "session-other-agent",
DialogID: "canvas-other",
UserID: "user-1",
Message: json.RawMessage(`[{"role":"assistant","content":"other"}]`),
Reference: json.RawMessage(`[]`),
})
h := NewAgentHandler(service.NewAgentService(), nil)
h.ListAgentSessions(c)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp["code"] != float64(common.CodeSuccess) {
t.Fatalf("expected code %d, got %v: %v", common.CodeSuccess, resp["code"], resp["message"])
}
if resp["total"] != float64(2) {
t.Fatalf("expected total 2, got %v", resp["total"])
}
data, ok := resp["data"].([]interface{})
if !ok {
t.Fatalf("expected data array, got %T", resp["data"])
}
if len(data) != 2 {
t.Fatalf("expected 2 sessions, got %d", len(data))
}
first := data[0].(map[string]interface{})
if first["agent_id"] != "canvas-1" {
t.Fatalf("expected agent_id canvas-1, got %v", first["agent_id"])
}
messages := first["message"].([]interface{})
message := messages[0].(map[string]interface{})
if _, ok := message["prompt"]; ok {
t.Fatalf("expected prompt to be stripped from list response")
}
}
func TestGetAgentSessionHandlerSuccess(t *testing.T) {
c, w, db := setupGinContextWithUserAndDB(t, http.MethodGet, "/api/v1/agents/canvas-1/sessions/session-1")
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}, {Key: "session_id", Value: "session-1"}}
db.Create(&entity.UserCanvas{
ID: "canvas-1",
UserID: "user-1",
Title: sptr("Test Agent"),
})
db.Create(&entity.API4Conversation{
ID: "session-1",
DialogID: "canvas-1",
UserID: "user-1",
Message: json.RawMessage(`[{"role":"assistant","content":"hello"}]`),
Reference: json.RawMessage(`[]`),
})
h := NewAgentHandler(service.NewAgentService(), nil)
h.GetAgentSession(c)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp["code"] != float64(common.CodeSuccess) {
t.Fatalf("expected code %d, got %v: %v", common.CodeSuccess, resp["code"], resp["message"])
}
data, ok := resp["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected data object, got %T", resp["data"])
}
if data["id"] != "session-1" {
t.Fatalf("expected session-1, got %v", data["id"])
}
if data["dialog_id"] != "canvas-1" {
t.Fatalf("expected dialog_id canvas-1, got %v", data["dialog_id"])
}
}
func TestGetAgentSessionHandlerRejectsSessionFromAnotherAgent(t *testing.T) {
c, w, db := setupGinContextWithUserAndDB(t, http.MethodGet, "/api/v1/agents/canvas-1/sessions/session-other")
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}, {Key: "session_id", Value: "session-other"}}
db.Create(&entity.UserCanvas{
ID: "canvas-1",
UserID: "user-1",
Title: sptr("Test Agent"),
})
db.Create(&entity.API4Conversation{
ID: "session-other",
DialogID: "canvas-other",
UserID: "user-1",
Message: json.RawMessage(`[{"role":"assistant","content":"other"}]`),
Reference: json.RawMessage(`[]`),
})
h := NewAgentHandler(service.NewAgentService(), nil)
h.GetAgentSession(c)
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp["code"] == float64(common.CodeSuccess) {
t.Fatalf("expected non-success for cross-agent session, got response %v", resp)
}
if resp["data"] != nil {
t.Fatalf("expected nil data for cross-agent session, got %v", resp["data"])
}
}
func TestDeleteAgentSessionItemHandlerDeletesOnlyMatchingAgent(t *testing.T) {
c, w, db := setupGinContextWithUserAndDB(t, http.MethodDelete, "/api/v1/agents/canvas-1/sessions/session-1")
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}, {Key: "session_id", Value: "session-1"}}
db.Create(&entity.UserCanvas{
ID: "canvas-1",
UserID: "user-1",
Title: sptr("Test Agent"),
})
db.Create(&entity.API4Conversation{
ID: "session-1",
DialogID: "canvas-1",
UserID: "user-1",
Message: json.RawMessage(`[]`),
Reference: json.RawMessage(`[]`),
})
db.Create(&entity.API4Conversation{
ID: "session-other",
DialogID: "canvas-other",
UserID: "user-1",
Message: json.RawMessage(`[]`),
Reference: json.RawMessage(`[]`),
})
h := NewAgentHandler(service.NewAgentService(), nil)
h.DeleteAgentSessionItem(c)
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp["code"] != float64(common.CodeSuccess) {
t.Fatalf("expected code %d, got %v: %v", common.CodeSuccess, resp["code"], resp["message"])
}
if resp["data"] != true {
t.Fatalf("expected data true, got %v", resp["data"])
}
var deletedCount int64
if err := db.Model(&entity.API4Conversation{}).Where("id = ?", "session-1").Count(&deletedCount).Error; err != nil {
t.Fatalf("failed to count deleted session: %v", err)
}
if deletedCount != 0 {
t.Fatalf("expected session-1 to be deleted, count=%d", deletedCount)
}
var otherCount int64
if err := db.Model(&entity.API4Conversation{}).Where("id = ?", "session-other").Count(&otherCount).Error; err != nil {
t.Fatalf("failed to count other session: %v", err)
}
if otherCount != 1 {
t.Fatalf("expected session-other to remain, count=%d", otherCount)
}
}
func TestDeleteAgentSessionItemHandlerIgnoresSessionFromAnotherAgent(t *testing.T) {
c, w, db := setupGinContextWithUserAndDB(t, http.MethodDelete, "/api/v1/agents/canvas-1/sessions/session-other")
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}, {Key: "session_id", Value: "session-other"}}
db.Create(&entity.UserCanvas{
ID: "canvas-1",
UserID: "user-1",
Title: sptr("Test Agent"),
})
db.Create(&entity.API4Conversation{
ID: "session-other",
DialogID: "canvas-other",
UserID: "user-1",
Message: json.RawMessage(`[]`),
Reference: json.RawMessage(`[]`),
})
h := NewAgentHandler(service.NewAgentService(), nil)
h.DeleteAgentSessionItem(c)
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp["code"] != float64(common.CodeSuccess) {
t.Fatalf("expected code %d, got %v: %v", common.CodeSuccess, resp["code"], resp["message"])
}
if resp["data"] != false {
t.Fatalf("expected data false, got %v", resp["data"])
}
var count int64
if err := db.Model(&entity.API4Conversation{}).Where("id = ?", "session-other").Count(&count).Error; err != nil {
t.Fatalf("failed to count other session: %v", err)
}
if count != 1 {
t.Fatalf("expected cross-agent session to remain, count=%d", count)
}
}
func TestUpdateAgentTagsHandlerSuccess(t *testing.T) {
c, w, db := setupGinContextWithUserAndDB(t, http.MethodPut, "/api/v1/agents/canvas-1/tags")
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/agents/canvas-1/tags", strings.NewReader(`{"tags":["alpha","beta","alpha"]}`))

View File

@@ -381,6 +381,9 @@ func (r *Router) Setup(engine *gin.Engine) {
agents.GET("/:agent_id/versions/:version_id", r.agentHandler.GetAgentVersion)
agents.POST("/:agent_id/upload", r.agentHandler.UploadAgentFile)
agents.PUT("/:agent_id/tags", r.agentHandler.UpdateAgentTags)
agents.GET("/:agent_id/sessions", r.agentHandler.ListAgentSessions)
agents.GET("/:agent_id/sessions/:session_id", r.agentHandler.GetAgentSession)
agents.DELETE("/:agent_id/sessions/:session_id", r.agentHandler.DeleteAgentSessionItem)
}
// Plugin routes

View File

@@ -17,8 +17,13 @@
package service
import (
"encoding/json"
"errors"
"fmt"
"sort"
"strconv"
"strings"
"time"
"ragflow/internal/common"
"ragflow/internal/dao"
@@ -36,6 +41,7 @@ type AgentService struct {
userTenantDAO *dao.UserTenantDAO
userCanvasVersionDAO *dao.UserCanvasVersionDAO
canvasTemplateDAO *dao.CanvasTemplateDAO
api4ConversationDAO *dao.API4ConversationDAO
}
// NewAgentService create agent service
@@ -44,6 +50,7 @@ func NewAgentService() *AgentService {
canvasDAO: dao.NewUserCanvasDAO(),
userTenantDAO: dao.NewUserTenantDAO(),
userCanvasVersionDAO: dao.NewUserCanvasVersionDAO(),
api4ConversationDAO: dao.NewAPI4ConversationDAO(),
canvasTemplateDAO: dao.NewCanvasTemplateDAO(),
}
}
@@ -137,6 +144,336 @@ func (s *AgentService) ListAgents(userID string, keywords string, page, pageSize
return &ListAgentsResponse{Canvas: items, Total: total}, common.CodeSuccess, nil
}
type ListAgentSessionsRequest struct {
SessionID string
UserID string
Page int
PageSize int
Keywords string
FromDate string
ToDate string
OrderBy string
Desc bool
ExpUserID string
IncludeDSL bool
}
type ListAgentSessionsResponse struct {
Data []map[string]interface{} `json:"data"`
Total int64 `json:"total"`
}
func parseAgentSessionDate(value string, isEnd bool) (*time.Time, error) {
if value == "" {
return nil, nil
}
if strings.Contains(value, "T") {
normalized := strings.ReplaceAll(value, "Z", "+00:00")
parsed, err := time.Parse(time.RFC3339, normalized)
if err != nil {
return nil, err
}
local := parsed.Local()
return &local, nil
}
if len(value) == 10 {
if isEnd {
value += " 23:59:59"
} else {
value += " 00:00:00"
}
}
parsed, err := time.ParseInLocation("2006-01-02 15:04:05", value, time.Local)
if err != nil {
return nil, err
}
return &parsed, nil
}
func normalizeAgentSession(session *entity.API4Conversation, includeDSL bool) map[string]interface{} {
messages := parseAgentSessionMessages(session.Message)
references := parseAgentSessionReferences(session.Reference)
for _, message := range messages {
delete(message, "prompt")
}
if len(references) > 0 {
assistantMessages := make([]map[string]interface{}, 0)
for i, message := range messages {
role, _ := message["role"].(string)
if i != 0 && role != "user" {
assistantMessages = append(assistantMessages, message)
}
}
for i := 0; i < len(assistantMessages) && i < len(references); i++ {
rawChunks, _ := references[i]["chunks"].([]interface{})
assistantMessages[i]["reference"] = normalizeAgentReferenceChunks(rawChunks)
}
}
result := map[string]interface{}{
"id": session.ID,
"name": session.Name,
"agent_id": session.DialogID,
"user_id": session.UserID,
"exp_user_id": session.ExpUserID,
"message": messages,
"tokens": session.Tokens,
"source": session.Source,
"duration": session.Duration,
"round": session.Round,
"thumb_up": session.ThumbUp,
"errors": session.Errors,
"version_title": session.VersionTitle,
"create_time": session.CreateTime,
"create_date": session.CreateDate,
"update_time": session.UpdateTime,
"update_date": session.UpdateDate,
}
if includeDSL {
result["dsl"] = session.DSL
}
return result
}
func parseAgentSessionReferences(raw json.RawMessage) []map[string]interface{} {
if len(raw) == 0 {
return []map[string]interface{}{}
}
var references []map[string]interface{}
if err := json.Unmarshal(raw, &references); err == nil {
for i, reference := range references {
references[i] = normalizeAgentReferenceEntry(reference)
}
return references
}
var referenceMap map[string]interface{}
if err := json.Unmarshal(raw, &referenceMap); err != nil {
return []map[string]interface{}{}
}
if _, ok := referenceMap["chunks"]; ok {
return []map[string]interface{}{normalizeAgentReferenceEntry(referenceMap)}
}
keys := make([]string, 0, len(referenceMap))
for key := range referenceMap {
keys = append(keys, key)
}
sort.Slice(keys, func(i, j int) bool {
left, _ := strconv.Atoi(keys[i])
right, _ := strconv.Atoi(keys[j])
return left < right
})
result := make([]map[string]interface{}, 0, len(keys))
for _, key := range keys {
reference, ok := referenceMap[key].(map[string]interface{})
if !ok {
continue
}
result = append(result, normalizeAgentReferenceEntry(reference))
}
return result
}
func parseAgentSessionMessages(raw json.RawMessage) []map[string]interface{} {
if len(raw) == 0 {
return []map[string]interface{}{}
}
var messages []map[string]interface{}
if err := json.Unmarshal(raw, &messages); err != nil {
return []map[string]interface{}{}
}
return messages
}
func normalizeAgentReferenceEntry(reference map[string]interface{}) map[string]interface{} {
if reference == nil {
return map[string]interface{}{
"chunks": []interface{}{},
"doc_aggs": []interface{}{},
}
}
if _, ok := reference["chunks"]; ok {
return map[string]interface{}{
"chunks": valueOrEmptySlice(reference["chunks"]),
"doc_aggs": valueOrEmptySlice(reference["doc_aggs"]),
}
}
if _, ok := reference["doc_aggs"]; ok {
return map[string]interface{}{
"chunks": valueOrEmptySlice(reference["chunks"]),
"doc_aggs": valueOrEmptySlice(reference["doc_aggs"]),
}
}
return map[string]interface{}{
"chunks": valueOrEmptySlice(reference["reference"]),
"doc_aggs": valueOrEmptySlice(reference["doc_aggs"]),
}
}
func valueOrEmptySlice(value interface{}) interface{} {
if value == nil {
return []interface{}{}
}
return value
}
func normalizeAgentReferenceChunks(chunks []interface{}) []map[string]interface{} {
result := make([]map[string]interface{}, 0, len(chunks))
for _, rawChunk := range chunks {
chunk, ok := rawChunk.(map[string]interface{})
if !ok {
continue
}
result = append(result, map[string]interface{}{
"id": firstNonNil(chunk["chunk_id"], chunk["id"]),
"content": firstNonNil(chunk["content_with_weight"], chunk["content"]),
"document_id": firstNonNil(chunk["doc_id"], chunk["document_id"]),
"document_name": firstNonNil(chunk["docnm_kwd"], chunk["document_name"]),
"dataset_id": firstNonNil(chunk["kb_id"], chunk["dataset_id"]),
"image_id": firstNonNil(chunk["image_id"], chunk["img_id"]),
"positions": firstNonNil(chunk["positions"], chunk["position_int"]),
})
}
return result
}
func firstNonNil(values ...interface{}) interface{} {
for _, value := range values {
if value != nil {
return value
}
}
return nil
}
func (s *AgentService) ListAgentSessions(userID, tenantID, agentID string, req ListAgentSessionsRequest) (*ListAgentSessionsResponse, common.ErrorCode, error) {
if agentID == "" {
return nil, common.CodeArgumentError, errors.New("agent_id is required")
}
ok, err := s.CheckCanvasAccess(userID, agentID)
if err != nil {
return nil, common.CodeServerError, fmt.Errorf("failed to check agent permission: %w", err)
}
if !ok {
return nil, common.CodeOperatingError, fmt.Errorf("Agent not found or no permission.")
}
sessionDAO := dao.NewChatSessionDAO()
if req.ExpUserID != "" {
rows, err := sessionDAO.ListAgentSessionNames(agentID, req.ExpUserID)
if err != nil {
return nil, common.CodeServerError, err
}
return &ListAgentSessionsResponse{Data: rows, Total: int64(len(rows))}, common.CodeSuccess, nil
}
fromDate, err := parseAgentSessionDate(req.FromDate, false)
if err != nil {
return nil, common.CodeArgumentError, err
}
toDate, err := parseAgentSessionDate(req.ToDate, true)
if err != nil {
return nil, common.CodeArgumentError, err
}
total, sessions, err := sessionDAO.ListAgentSessions(dao.ListAgentSessionsParams{
AgentID: agentID,
Page: req.Page,
PageSize: req.PageSize,
OrderBy: req.OrderBy,
Desc: req.Desc,
SessionID: req.SessionID,
UserID: req.UserID,
IncludeDSL: req.IncludeDSL,
Keywords: req.Keywords,
FromDate: fromDate,
ToDate: toDate,
ExpUserID: req.ExpUserID,
})
if err != nil {
return nil, common.CodeServerError, err
}
data := make([]map[string]interface{}, 0, len(sessions))
for _, session := range sessions {
data = append(data, normalizeAgentSession(session, req.IncludeDSL))
}
return &ListAgentSessionsResponse{Data: data, Total: total}, common.CodeSuccess, nil
}
func (s *AgentService) GetAgentSession(userID, agentID, sessionID string) (*entity.API4Conversation, common.ErrorCode, error) {
if sessionID == "" {
return nil, common.CodeArgumentError, fmt.Errorf("session_id is required")
}
ok, err := s.CheckCanvasAccess(userID, agentID)
if err != nil {
return nil, common.CodeServerError, fmt.Errorf("failed to check agent permission: %w", err)
}
if !ok {
return nil, common.CodeOperatingError, fmt.Errorf("Agent not found or no permission.")
}
data, err := s.api4ConversationDAO.GetBySessionID(sessionID, agentID)
if err != nil {
return nil, common.CodeServerError, fmt.Errorf("failed to fetch session: %w", err)
}
if data == nil {
return nil, common.CodeNotFound, fmt.Errorf("agent session not found")
}
return data, common.CodeSuccess, nil
}
func (s *AgentService) DeleteAgentSessionItem(userID, agentID, sessionID string) (bool, common.ErrorCode, error) {
if sessionID == "" {
return false, common.CodeArgumentError, errors.New("session_id is required")
}
ok, err := s.CheckCanvasAccess(userID, agentID)
if err != nil {
return false, common.CodeServerError, fmt.Errorf("failed to check agent permission: %w", err)
}
if !ok {
return false, common.CodeOperatingError, fmt.Errorf("Agent not found or no permission.")
}
row, err := s.api4ConversationDAO.DeleteBySessionIDAndAgentID(sessionID, agentID)
if err != nil {
return false, common.CodeServerError, err
}
if row == 0 {
return false, common.CodeSuccess, nil
}
return true, common.CodeSuccess, nil
}
// normalizeAgentTags returns an error for unsupported tag payload types
func normalizeAgentTags(rawTags interface{}) (string, error) {
cleaned := make([]string, 0)