mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-03 09:11:59 +08:00
feat[Go]: implement /api/v1/agents/<agent_id>/sessions (#15705)
### What problem does this PR solve? As Title Codes were tested by Postman ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -119,6 +119,24 @@ func (dao *API4ConversationDAO) Stats(tenantID, fromDate, toDate string, source
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (dao *API4ConversationDAO) GetBySessionID(sessionID, agentID string) (*entity.API4Conversation, error) {
|
||||
var result entity.API4Conversation
|
||||
tx := DB.Where("id = ? AND dialog_id = ?", sessionID, agentID).Find(&result)
|
||||
if tx.Error != nil {
|
||||
return nil, tx.Error
|
||||
}
|
||||
if tx.RowsAffected == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// DeleteBySessionIDAndAgentID deletes API4Conversations by sessionID and agentID
|
||||
func (dao *API4ConversationDAO) DeleteBySessionIDAndAgentID(sessionID, agentID string) (int64, error) {
|
||||
result := DB.Where("id = ? AND dialog_id = ?", sessionID, agentID).Delete(&entity.API4Conversation{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// DeleteByDialogIDs deletes API4Conversations by dialog IDs (hard delete)
|
||||
func (dao *API4ConversationDAO) DeleteByDialogIDs(dialogIDs []string) (int64, error) {
|
||||
if len(dialogIDs) == 0 {
|
||||
|
||||
108
internal/dao/api_token_test.go
Normal file
108
internal/dao/api_token_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package dao
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func setupAPI4ConversationTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
TranslateError: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open sqlite: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&entity.API4Conversation{}); err != nil {
|
||||
t.Fatalf("failed to migrate: %v", err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func createAPI4ConversationForDAOTest(t *testing.T, id, agentID string) {
|
||||
t.Helper()
|
||||
if err := DB.Create(&entity.API4Conversation{
|
||||
ID: id,
|
||||
DialogID: agentID,
|
||||
UserID: "user-1",
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to create api conversation %s: %v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPI4ConversationDAOGetBySessionID(t *testing.T) {
|
||||
db := setupAPI4ConversationTestDB(t)
|
||||
pushDB(t, db)
|
||||
|
||||
createAPI4ConversationForDAOTest(t, "session-1", "agent-1")
|
||||
|
||||
session, err := NewAPI4ConversationDAO().GetBySessionID("session-1", "agent-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetBySessionID failed: %v", err)
|
||||
}
|
||||
if session == nil {
|
||||
t.Fatal("expected session, got nil")
|
||||
}
|
||||
if session.ID != "session-1" {
|
||||
t.Fatalf("expected session-1, got %s", session.ID)
|
||||
}
|
||||
if session.DialogID != "agent-1" {
|
||||
t.Fatalf("expected agent-1, got %s", session.DialogID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPI4ConversationDAOGetBySessionIDWrongAgent(t *testing.T) {
|
||||
db := setupAPI4ConversationTestDB(t)
|
||||
pushDB(t, db)
|
||||
|
||||
createAPI4ConversationForDAOTest(t, "session-1", "agent-1")
|
||||
|
||||
session, err := NewAPI4ConversationDAO().GetBySessionID("session-1", "agent-2")
|
||||
if err != nil {
|
||||
t.Fatalf("GetBySessionID failed: %v", err)
|
||||
}
|
||||
if session != nil {
|
||||
t.Fatalf("expected nil for wrong agent, got %+v", session)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPI4ConversationDAOGetBySessionIDNoRows(t *testing.T) {
|
||||
db := setupAPI4ConversationTestDB(t)
|
||||
pushDB(t, db)
|
||||
|
||||
createAPI4ConversationForDAOTest(t, "session-1", "agent-1")
|
||||
|
||||
session, err := NewAPI4ConversationDAO().GetBySessionID("missing-session", "agent-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetBySessionID failed: %v", err)
|
||||
}
|
||||
if session != nil {
|
||||
t.Fatalf("expected nil for missing session, got %+v", session)
|
||||
}
|
||||
}
|
||||
@@ -17,12 +17,30 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
// ChatSessionDAO chat session data access object
|
||||
type ChatSessionDAO struct{}
|
||||
|
||||
type ListAgentSessionsParams struct {
|
||||
AgentID string
|
||||
Page int
|
||||
PageSize int
|
||||
OrderBy string
|
||||
Desc bool
|
||||
SessionID string
|
||||
UserID string
|
||||
IncludeDSL bool
|
||||
Keywords string
|
||||
FromDate *time.Time
|
||||
ToDate *time.Time
|
||||
ExpUserID string
|
||||
}
|
||||
|
||||
// NewChatSessionDAO create chat session DAO
|
||||
func NewChatSessionDAO() *ChatSessionDAO {
|
||||
return &ChatSessionDAO{}
|
||||
@@ -92,3 +110,107 @@ func (dao *ChatSessionDAO) DeleteByDialogIDs(dialogIDs []string) (int64, error)
|
||||
result := DB.Unscoped().Where("dialog_id IN ?", dialogIDs).Delete(&entity.ChatSession{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
func (dao *ChatSessionDAO) ListAgentSessionNames(agentID, expUserID string) ([]map[string]interface{}, error) {
|
||||
var rows []map[string]interface{}
|
||||
err := DB.Model(&entity.API4Conversation{}).
|
||||
Select("id", "name").
|
||||
Where("dialog_id = ? AND exp_user_id = ?", agentID, expUserID).
|
||||
Order("create_date DESC").
|
||||
Find(&rows).Error
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func normalizeAgentSessionOrderBy(orderBy string) string {
|
||||
switch orderBy {
|
||||
case "id":
|
||||
return "id"
|
||||
case "name":
|
||||
return "name"
|
||||
case "create_time":
|
||||
return "create_time"
|
||||
case "create_date":
|
||||
return "create_date"
|
||||
case "update_time":
|
||||
return "update_time"
|
||||
case "update_date":
|
||||
return "update_date"
|
||||
case "tokens":
|
||||
return "tokens"
|
||||
case "duration":
|
||||
return "duration"
|
||||
case "round":
|
||||
return "round"
|
||||
case "thumb_up":
|
||||
return "thumb_up"
|
||||
default:
|
||||
return "update_time"
|
||||
}
|
||||
}
|
||||
|
||||
func (dao *ChatSessionDAO) ListAgentSessions(params ListAgentSessionsParams) (int64, []*entity.API4Conversation, error) {
|
||||
query := DB.Model(&entity.API4Conversation{}).Where("dialog_id = ?", params.AgentID)
|
||||
if !params.IncludeDSL {
|
||||
query = query.Omit("dsl")
|
||||
}
|
||||
|
||||
if params.SessionID != "" {
|
||||
query = query.Where("id = ?", params.SessionID)
|
||||
}
|
||||
|
||||
if params.UserID != "" {
|
||||
query = query.Where("user_id = ?", params.UserID)
|
||||
}
|
||||
|
||||
if params.Keywords != "" {
|
||||
query = query.Where("LOWER(message) LIKE ?", "%"+strings.ToLower(params.Keywords)+"%")
|
||||
}
|
||||
|
||||
dateColumn := "create_date"
|
||||
if strings.HasPrefix(params.OrderBy, "update_") {
|
||||
dateColumn = "update_date"
|
||||
}
|
||||
|
||||
if params.FromDate != nil {
|
||||
query = query.Where(dateColumn+" >= ?", *params.FromDate)
|
||||
}
|
||||
|
||||
if params.ToDate != nil {
|
||||
query = query.Where(dateColumn+" <= ?", *params.ToDate)
|
||||
}
|
||||
|
||||
if params.ExpUserID != "" {
|
||||
query = query.Where("exp_user_id = ?", params.ExpUserID)
|
||||
}
|
||||
|
||||
var total int64
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
orderBy := normalizeAgentSessionOrderBy(params.OrderBy)
|
||||
if params.Desc {
|
||||
orderBy += " DESC"
|
||||
} else {
|
||||
orderBy += " ASC"
|
||||
}
|
||||
|
||||
page := params.Page
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
|
||||
pageSize := params.PageSize
|
||||
if pageSize <= 0 {
|
||||
pageSize = 30
|
||||
}
|
||||
|
||||
var sessions []*entity.API4Conversation
|
||||
err := query.
|
||||
Order(orderBy).
|
||||
Offset((page - 1) * pageSize).
|
||||
Limit(pageSize).
|
||||
Find(&sessions).Error
|
||||
|
||||
return total, sessions, err
|
||||
}
|
||||
|
||||
140
internal/dao/chat_session_test.go
Normal file
140
internal/dao/chat_session_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package dao
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func setupChatSessionDAOTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
TranslateError: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open sqlite: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&entity.API4Conversation{}); err != nil {
|
||||
t.Fatalf("failed to migrate: %v", err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func createAgentSessionForDAOTest(t *testing.T, db *gorm.DB, id, agentID, userID string, updateTime int64) {
|
||||
t.Helper()
|
||||
|
||||
updateDate := time.UnixMilli(updateTime).Local()
|
||||
session := &entity.API4Conversation{
|
||||
ID: id,
|
||||
DialogID: agentID,
|
||||
UserID: userID,
|
||||
Message: json.RawMessage(`[{"role":"assistant","content":"hello"}]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
BaseModel: entity.BaseModel{
|
||||
CreateTime: &updateTime,
|
||||
CreateDate: &updateDate,
|
||||
UpdateTime: &updateTime,
|
||||
UpdateDate: &updateDate,
|
||||
},
|
||||
}
|
||||
if err := db.Create(session).Error; err != nil {
|
||||
t.Fatalf("failed to create session %s: %v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatSessionDAOListAgentSessionsOrdersByUpdateTimeDesc(t *testing.T) {
|
||||
db := setupChatSessionDAOTestDB(t)
|
||||
pushDB(t, db)
|
||||
|
||||
createAgentSessionForDAOTest(t, db, "session-old", "agent-1", "user-1", 1000)
|
||||
createAgentSessionForDAOTest(t, db, "session-new", "agent-1", "user-1", 3000)
|
||||
createAgentSessionForDAOTest(t, db, "session-middle", "agent-1", "user-1", 2000)
|
||||
createAgentSessionForDAOTest(t, db, "session-other-agent", "agent-2", "user-1", 9999)
|
||||
|
||||
total, sessions, err := NewChatSessionDAO().ListAgentSessions(ListAgentSessionsParams{
|
||||
AgentID: "agent-1",
|
||||
Page: 1,
|
||||
PageSize: 10,
|
||||
OrderBy: "update_time",
|
||||
Desc: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ListAgentSessions failed: %v", err)
|
||||
}
|
||||
|
||||
if total != 3 {
|
||||
t.Fatalf("expected total 3, got %d", total)
|
||||
}
|
||||
if len(sessions) != 3 {
|
||||
t.Fatalf("expected 3 sessions, got %d", len(sessions))
|
||||
}
|
||||
|
||||
wantIDs := []string{"session-new", "session-middle", "session-old"}
|
||||
for i, wantID := range wantIDs {
|
||||
if sessions[i].ID != wantID {
|
||||
t.Fatalf("session[%d]: expected %s, got %s", i, wantID, sessions[i].ID)
|
||||
}
|
||||
if sessions[i].DialogID != "agent-1" {
|
||||
t.Fatalf("session[%d]: expected agent-1, got %s", i, sessions[i].DialogID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatSessionDAOListAgentSessionsFiltersAndPaginates(t *testing.T) {
|
||||
db := setupChatSessionDAOTestDB(t)
|
||||
pushDB(t, db)
|
||||
|
||||
createAgentSessionForDAOTest(t, db, "session-1", "agent-1", "user-1", 1000)
|
||||
createAgentSessionForDAOTest(t, db, "session-2", "agent-1", "user-1", 2000)
|
||||
createAgentSessionForDAOTest(t, db, "session-3", "agent-1", "user-1", 3000)
|
||||
createAgentSessionForDAOTest(t, db, "session-other-user", "agent-1", "user-2", 4000)
|
||||
|
||||
total, sessions, err := NewChatSessionDAO().ListAgentSessions(ListAgentSessionsParams{
|
||||
AgentID: "agent-1",
|
||||
UserID: "user-1",
|
||||
Page: 2,
|
||||
PageSize: 1,
|
||||
OrderBy: "update_time",
|
||||
Desc: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ListAgentSessions failed: %v", err)
|
||||
}
|
||||
|
||||
if total != 3 {
|
||||
t.Fatalf("expected total 3 after user filter, got %d", total)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("expected one paginated session, got %d", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != "session-2" {
|
||||
t.Fatalf("expected second ascending session session-2, got %s", sessions[0].ID)
|
||||
}
|
||||
if sessions[0].UserID != "user-1" {
|
||||
t.Fatalf("expected user-1, got %s", sessions[0].UserID)
|
||||
}
|
||||
}
|
||||
@@ -55,14 +55,7 @@ func (dao *UserCanvasDAO) Delete(id string) error {
|
||||
|
||||
// GetList get canvases list with pagination and filtering
|
||||
// Similar to Python UserCanvasService.get_list
|
||||
func (dao *UserCanvasDAO) GetList(
|
||||
tenantID string,
|
||||
pageNumber, itemsPerPage int,
|
||||
orderby string,
|
||||
desc bool,
|
||||
id, title string,
|
||||
canvasCategory string,
|
||||
) ([]*entity.UserCanvas, error) {
|
||||
func (dao *UserCanvasDAO) GetList(tenantID string, pageNumber, itemsPerPage int, orderby string, desc bool, id, title string, canvasCategory string) ([]*entity.UserCanvas, error) {
|
||||
|
||||
query := DB.Model(&entity.UserCanvas{}).
|
||||
Where("user_id = ?", tenantID)
|
||||
@@ -116,15 +109,7 @@ func (dao *UserCanvasDAO) GetAllCanvasesByTenantIDs(tenantIDs []string, userID s
|
||||
// ListByTenantIDs lists agent canvases accessible to the given owner IDs with optional
|
||||
// keyword filter, pagination, and ordering.
|
||||
// Mirrors Python UserCanvasService.get_by_tenant_ids (list route only).
|
||||
func (dao *UserCanvasDAO) ListByTenantIDs(
|
||||
ownerIDs []string,
|
||||
userID string,
|
||||
page, pageSize int,
|
||||
orderby string,
|
||||
desc bool,
|
||||
keywords string,
|
||||
canvasCategory string,
|
||||
) ([]*entity.UserCanvas, int64, error) {
|
||||
func (dao *UserCanvasDAO) ListByTenantIDs(ownerIDs []string, userID string, page, pageSize int, orderby string, desc bool, keywords string, canvasCategory string) ([]*entity.UserCanvas, int64, error) {
|
||||
if len(ownerIDs) == 0 {
|
||||
return nil, 0, nil
|
||||
}
|
||||
@@ -201,6 +186,12 @@ func (dao *UserCanvasDAO) GetAllCanvasIDsByUserID(userID string) ([]string, erro
|
||||
return canvasIDs, err
|
||||
}
|
||||
|
||||
// UpdateDSL updates a canvas DSL by canvas ID.
|
||||
func (dao *UserCanvasDAO) UpdateDSL(canvasID string, dsl entity.JSONMap) (int64, error) {
|
||||
result := DB.Model(&entity.UserCanvas{}).Where("id = ?", canvasID).Update("dsl", dsl)
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// UpdateTags updates a canvas's comma-separated tags by canvas ID.
|
||||
func (dao *UserCanvasDAO) UpdateTags(canvasID, tags string) (int64, error) {
|
||||
result := DB.Model(&entity.UserCanvas{}).Where("id = ?", canvasID).Update("tags", tags)
|
||||
|
||||
135
internal/dao/user_canvas_test.go
Normal file
135
internal/dao/user_canvas_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package dao
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ragflow/internal/entity"
|
||||
)
|
||||
|
||||
func setupUserCanvasTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
TranslateError: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open sqlite: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&entity.UserCanvas{}); err != nil {
|
||||
t.Fatalf("failed to migrate: %v", err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestUserCanvasDAOUpdateDSL(t *testing.T) {
|
||||
db := setupUserCanvasTestDB(t)
|
||||
pushDB(t, db)
|
||||
|
||||
dao := NewUserCanvasDAO()
|
||||
originalDSL := entity.JSONMap{"graph": map[string]interface{}{"nodes": []interface{}{"old"}}}
|
||||
if err := dao.Create(&entity.UserCanvas{
|
||||
ID: "canvas-1",
|
||||
UserID: "user-1",
|
||||
Title: stringPtr("Test Canvas"),
|
||||
CanvasCategory: "agent_canvas",
|
||||
DSL: originalDSL,
|
||||
}); err != nil {
|
||||
t.Fatalf("failed to create canvas: %v", err)
|
||||
}
|
||||
|
||||
newDSL := entity.JSONMap{
|
||||
"graph": map[string]interface{}{
|
||||
"nodes": []interface{}{"start", "end"},
|
||||
"edges": []interface{}{"start:end"},
|
||||
},
|
||||
"path": []interface{}{"start", "end"},
|
||||
}
|
||||
rows, err := dao.UpdateDSL("canvas-1", newDSL)
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateDSL failed: %v", err)
|
||||
}
|
||||
if rows != 1 {
|
||||
t.Fatalf("expected 1 row affected, got %d", rows)
|
||||
}
|
||||
|
||||
canvas, err := dao.GetByID("canvas-1")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get canvas: %v", err)
|
||||
}
|
||||
graph, ok := canvas.DSL["graph"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected graph map, got %T", canvas.DSL["graph"])
|
||||
}
|
||||
nodes, ok := graph["nodes"].([]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected nodes slice, got %T", graph["nodes"])
|
||||
}
|
||||
if len(nodes) != 2 || nodes[0] != "start" || nodes[1] != "end" {
|
||||
t.Fatalf("unexpected nodes after update: %v", nodes)
|
||||
}
|
||||
path, ok := canvas.DSL["path"].([]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected path slice, got %T", canvas.DSL["path"])
|
||||
}
|
||||
if len(path) != 2 || path[0] != "start" || path[1] != "end" {
|
||||
t.Fatalf("unexpected path after update: %v", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserCanvasDAOUpdateDSLNoMatch(t *testing.T) {
|
||||
db := setupUserCanvasTestDB(t)
|
||||
pushDB(t, db)
|
||||
|
||||
dao := NewUserCanvasDAO()
|
||||
originalDSL := entity.JSONMap{"path": []interface{}{"old"}}
|
||||
if err := dao.Create(&entity.UserCanvas{
|
||||
ID: "canvas-1",
|
||||
UserID: "user-1",
|
||||
Title: stringPtr("Test Canvas"),
|
||||
CanvasCategory: "agent_canvas",
|
||||
DSL: originalDSL,
|
||||
}); err != nil {
|
||||
t.Fatalf("failed to create canvas: %v", err)
|
||||
}
|
||||
|
||||
rows, err := dao.UpdateDSL("missing-canvas", entity.JSONMap{"path": []interface{}{"new"}})
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateDSL failed: %v", err)
|
||||
}
|
||||
if rows != 0 {
|
||||
t.Fatalf("expected 0 rows affected, got %d", rows)
|
||||
}
|
||||
|
||||
canvas, err := dao.GetByID("canvas-1")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get canvas: %v", err)
|
||||
}
|
||||
path, ok := canvas.DSL["path"].([]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected path slice, got %T", canvas.DSL["path"])
|
||||
}
|
||||
if len(path) != 1 || path[0] != "old" {
|
||||
t.Fatalf("expected original DSL to remain unchanged, got %v", path)
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
package entity
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// APIToken API token model
|
||||
type APIToken struct {
|
||||
TenantID string `gorm:"column:tenant_id;size:32;not null;primaryKey" json:"tenant_id"`
|
||||
@@ -33,20 +35,21 @@ func (APIToken) TableName() string {
|
||||
|
||||
// API4Conversation API for conversation model
|
||||
type API4Conversation struct {
|
||||
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
|
||||
Name *string `gorm:"column:name;size:255" json:"name,omitempty"`
|
||||
DialogID string `gorm:"column:dialog_id;size:32;not null;index" json:"dialog_id"`
|
||||
UserID string `gorm:"column:user_id;size:255;not null;index" json:"user_id"`
|
||||
ExpUserID *string `gorm:"column:exp_user_id;size:255;index" json:"exp_user_id,omitempty"`
|
||||
Message JSONMap `gorm:"column:message;type:longtext" json:"message,omitempty"`
|
||||
Reference JSONMap `gorm:"column:reference;type:longtext" json:"reference"`
|
||||
Tokens int64 `gorm:"column:tokens;default:0" json:"tokens"`
|
||||
Source *string `gorm:"column:source;size:16;index" json:"source,omitempty"`
|
||||
DSL JSONMap `gorm:"column:dsl;type:longtext" json:"dsl,omitempty"`
|
||||
Duration float64 `gorm:"column:duration;default:0;index" json:"duration"`
|
||||
Round int64 `gorm:"column:round;default:0;index" json:"round"`
|
||||
ThumbUp int64 `gorm:"column:thumb_up;default:0;index" json:"thumb_up"`
|
||||
Errors *string `gorm:"column:errors;type:longtext" json:"errors,omitempty"`
|
||||
ID string `gorm:"column:id;primaryKey;size:32" json:"id"`
|
||||
Name *string `gorm:"column:name;size:255" json:"name,omitempty"`
|
||||
DialogID string `gorm:"column:dialog_id;size:32;not null;index" json:"dialog_id"`
|
||||
UserID string `gorm:"column:user_id;size:255;not null;index" json:"user_id"`
|
||||
ExpUserID *string `gorm:"column:exp_user_id;size:255;index" json:"exp_user_id,omitempty"`
|
||||
Message json.RawMessage `gorm:"column:message;type:longtext" json:"message,omitempty"`
|
||||
Reference json.RawMessage `gorm:"column:reference;type:longtext" json:"reference,omitempty"`
|
||||
Tokens int `gorm:"column:tokens" json:"tokens"`
|
||||
Source *string `gorm:"column:source;size:16" json:"source,omitempty"`
|
||||
DSL JSONMap `gorm:"column:dsl;type:longtext" json:"dsl,omitempty"`
|
||||
Duration float64 `gorm:"column:duration" json:"duration"`
|
||||
Round int `gorm:"column:round" json:"round"`
|
||||
ThumbUp int `gorm:"column:thumb_up" json:"thumb_up"`
|
||||
Errors *string `gorm:"column:errors;type:text" json:"errors,omitempty"`
|
||||
VersionTitle *string `gorm:"column:version_title;size:255" json:"version_title,omitempty"`
|
||||
BaseModel
|
||||
}
|
||||
|
||||
|
||||
@@ -131,6 +131,182 @@ func (h *AgentHandler) ListAgents(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// ListAgentSessions List all sessions
|
||||
func (h *AgentHandler) ListAgentSessions(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
agentID := c.Param("agent_id")
|
||||
if agentID == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeArgumentError,
|
||||
"data": nil,
|
||||
"message": "agent_id is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
page := parsePositiveIntQuery(c, "page", 1)
|
||||
pageSize := parsePositiveIntQuery(c, "page_size", 30)
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
|
||||
req := service.ListAgentSessionsRequest{
|
||||
SessionID: c.Query("id"),
|
||||
UserID: c.Query("user_id"),
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Keywords: c.Query("keywords"),
|
||||
FromDate: c.Query("from_date"),
|
||||
ToDate: c.Query("to_date"),
|
||||
OrderBy: defaultQueryString(c.Query("orderby"), "update_time"),
|
||||
ExpUserID: c.Query("exp_user_id"),
|
||||
Desc: c.Query("desc") != "False" && c.Query("desc") != "false",
|
||||
IncludeDSL: c.Query("dsl") != "False" && c.Query("dsl") != "false",
|
||||
}
|
||||
|
||||
tenantID := user.ID
|
||||
result, code, err := h.agentService.ListAgentSessions(user.ID, tenantID, agentID, req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"data": nil,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": result.Data,
|
||||
"message": "success",
|
||||
"total": result.Total,
|
||||
})
|
||||
}
|
||||
|
||||
func parsePositiveIntQuery(c *gin.Context, key string, defaultValue int) int {
|
||||
raw := c.Query(key)
|
||||
if raw == "" {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
value, err := strconv.Atoi(raw)
|
||||
if err != nil || value <= 0 {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
func defaultQueryString(value, defaultValue string) string {
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (h *AgentHandler) GetAgentSession(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
agentID := c.Param("agent_id")
|
||||
if agentID == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeOperatingError,
|
||||
"data": nil,
|
||||
"message": "agent_id is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
sessionID := c.Param("session_id")
|
||||
if sessionID == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeDataError,
|
||||
"data": nil,
|
||||
"message": "session_id is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userID := user.ID
|
||||
userID = strings.TrimSpace(userID)
|
||||
sessionID = strings.TrimSpace(sessionID)
|
||||
agentID = strings.TrimSpace(agentID)
|
||||
|
||||
data, code, err := h.agentService.GetAgentSession(userID, agentID, sessionID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"data": nil,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": data,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AgentHandler) DeleteAgentSessionItem(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
agentID := c.Param("agent_id")
|
||||
if agentID == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeOperatingError,
|
||||
"data": nil,
|
||||
"message": "agent_id is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
sessionID := c.Param("session_id")
|
||||
if sessionID == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeDataError,
|
||||
"data": nil,
|
||||
"message": "session_id is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userID := user.ID
|
||||
userID = strings.TrimSpace(userID)
|
||||
sessionID = strings.TrimSpace(sessionID)
|
||||
agentID = strings.TrimSpace(agentID)
|
||||
|
||||
ok, code, err := h.agentService.DeleteAgentSessionItem(userID, agentID, sessionID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"data": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"data": ok,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// ListAgentVersions returns versions for a specific agent.
|
||||
// @Summary List Agent Versions
|
||||
// @Description Returns all versions for a specific agent, ordered by update_time DESC.
|
||||
|
||||
@@ -49,6 +49,7 @@ func setupHandlerAgentsTestDB(t *testing.T) *gorm.DB {
|
||||
&entity.User{},
|
||||
&entity.UserCanvas{},
|
||||
&entity.UserCanvasVersion{},
|
||||
&entity.API4Conversation{},
|
||||
); err != nil {
|
||||
t.Fatalf("failed to migrate: %v", err)
|
||||
}
|
||||
@@ -262,6 +263,251 @@ func TestGetAgentVersionHandler_VersionNotFound(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAgentSessionsHandlerSuccess(t *testing.T) {
|
||||
c, w, db := setupGinContextWithUserAndDB(t, http.MethodGet, "/api/v1/agents/canvas-1/sessions")
|
||||
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}}
|
||||
|
||||
db.Create(&entity.UserCanvas{
|
||||
ID: "canvas-1",
|
||||
UserID: "user-1",
|
||||
Title: sptr("Test Agent"),
|
||||
})
|
||||
db.Create(&entity.API4Conversation{
|
||||
ID: "session-1",
|
||||
DialogID: "canvas-1",
|
||||
UserID: "user-1",
|
||||
Message: json.RawMessage(`[{"role":"assistant","content":"hello","prompt":"hidden"}]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
BaseModel: entity.BaseModel{
|
||||
UpdateTime: ptr(time.Now().UnixMilli()),
|
||||
},
|
||||
})
|
||||
db.Create(&entity.API4Conversation{
|
||||
ID: "session-2",
|
||||
DialogID: "canvas-1",
|
||||
UserID: "user-1",
|
||||
Message: json.RawMessage(`[{"role":"user","content":"question"}]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
BaseModel: entity.BaseModel{
|
||||
UpdateTime: ptr(time.Now().Add(-time.Hour).UnixMilli()),
|
||||
},
|
||||
})
|
||||
db.Create(&entity.API4Conversation{
|
||||
ID: "session-other-agent",
|
||||
DialogID: "canvas-other",
|
||||
UserID: "user-1",
|
||||
Message: json.RawMessage(`[{"role":"assistant","content":"other"}]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
})
|
||||
|
||||
h := NewAgentHandler(service.NewAgentService(), nil)
|
||||
h.ListAgentSessions(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected code %d, got %v: %v", common.CodeSuccess, resp["code"], resp["message"])
|
||||
}
|
||||
if resp["total"] != float64(2) {
|
||||
t.Fatalf("expected total 2, got %v", resp["total"])
|
||||
}
|
||||
|
||||
data, ok := resp["data"].([]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected data array, got %T", resp["data"])
|
||||
}
|
||||
if len(data) != 2 {
|
||||
t.Fatalf("expected 2 sessions, got %d", len(data))
|
||||
}
|
||||
|
||||
first := data[0].(map[string]interface{})
|
||||
if first["agent_id"] != "canvas-1" {
|
||||
t.Fatalf("expected agent_id canvas-1, got %v", first["agent_id"])
|
||||
}
|
||||
messages := first["message"].([]interface{})
|
||||
message := messages[0].(map[string]interface{})
|
||||
if _, ok := message["prompt"]; ok {
|
||||
t.Fatalf("expected prompt to be stripped from list response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAgentSessionHandlerSuccess(t *testing.T) {
|
||||
c, w, db := setupGinContextWithUserAndDB(t, http.MethodGet, "/api/v1/agents/canvas-1/sessions/session-1")
|
||||
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}, {Key: "session_id", Value: "session-1"}}
|
||||
|
||||
db.Create(&entity.UserCanvas{
|
||||
ID: "canvas-1",
|
||||
UserID: "user-1",
|
||||
Title: sptr("Test Agent"),
|
||||
})
|
||||
db.Create(&entity.API4Conversation{
|
||||
ID: "session-1",
|
||||
DialogID: "canvas-1",
|
||||
UserID: "user-1",
|
||||
Message: json.RawMessage(`[{"role":"assistant","content":"hello"}]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
})
|
||||
|
||||
h := NewAgentHandler(service.NewAgentService(), nil)
|
||||
h.GetAgentSession(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected code %d, got %v: %v", common.CodeSuccess, resp["code"], resp["message"])
|
||||
}
|
||||
|
||||
data, ok := resp["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected data object, got %T", resp["data"])
|
||||
}
|
||||
if data["id"] != "session-1" {
|
||||
t.Fatalf("expected session-1, got %v", data["id"])
|
||||
}
|
||||
if data["dialog_id"] != "canvas-1" {
|
||||
t.Fatalf("expected dialog_id canvas-1, got %v", data["dialog_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAgentSessionHandlerRejectsSessionFromAnotherAgent(t *testing.T) {
|
||||
c, w, db := setupGinContextWithUserAndDB(t, http.MethodGet, "/api/v1/agents/canvas-1/sessions/session-other")
|
||||
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}, {Key: "session_id", Value: "session-other"}}
|
||||
|
||||
db.Create(&entity.UserCanvas{
|
||||
ID: "canvas-1",
|
||||
UserID: "user-1",
|
||||
Title: sptr("Test Agent"),
|
||||
})
|
||||
db.Create(&entity.API4Conversation{
|
||||
ID: "session-other",
|
||||
DialogID: "canvas-other",
|
||||
UserID: "user-1",
|
||||
Message: json.RawMessage(`[{"role":"assistant","content":"other"}]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
})
|
||||
|
||||
h := NewAgentHandler(service.NewAgentService(), nil)
|
||||
h.GetAgentSession(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if resp["code"] == float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected non-success for cross-agent session, got response %v", resp)
|
||||
}
|
||||
if resp["data"] != nil {
|
||||
t.Fatalf("expected nil data for cross-agent session, got %v", resp["data"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAgentSessionItemHandlerDeletesOnlyMatchingAgent(t *testing.T) {
|
||||
c, w, db := setupGinContextWithUserAndDB(t, http.MethodDelete, "/api/v1/agents/canvas-1/sessions/session-1")
|
||||
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}, {Key: "session_id", Value: "session-1"}}
|
||||
|
||||
db.Create(&entity.UserCanvas{
|
||||
ID: "canvas-1",
|
||||
UserID: "user-1",
|
||||
Title: sptr("Test Agent"),
|
||||
})
|
||||
db.Create(&entity.API4Conversation{
|
||||
ID: "session-1",
|
||||
DialogID: "canvas-1",
|
||||
UserID: "user-1",
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
})
|
||||
db.Create(&entity.API4Conversation{
|
||||
ID: "session-other",
|
||||
DialogID: "canvas-other",
|
||||
UserID: "user-1",
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
})
|
||||
|
||||
h := NewAgentHandler(service.NewAgentService(), nil)
|
||||
h.DeleteAgentSessionItem(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected code %d, got %v: %v", common.CodeSuccess, resp["code"], resp["message"])
|
||||
}
|
||||
if resp["data"] != true {
|
||||
t.Fatalf("expected data true, got %v", resp["data"])
|
||||
}
|
||||
|
||||
var deletedCount int64
|
||||
if err := db.Model(&entity.API4Conversation{}).Where("id = ?", "session-1").Count(&deletedCount).Error; err != nil {
|
||||
t.Fatalf("failed to count deleted session: %v", err)
|
||||
}
|
||||
if deletedCount != 0 {
|
||||
t.Fatalf("expected session-1 to be deleted, count=%d", deletedCount)
|
||||
}
|
||||
|
||||
var otherCount int64
|
||||
if err := db.Model(&entity.API4Conversation{}).Where("id = ?", "session-other").Count(&otherCount).Error; err != nil {
|
||||
t.Fatalf("failed to count other session: %v", err)
|
||||
}
|
||||
if otherCount != 1 {
|
||||
t.Fatalf("expected session-other to remain, count=%d", otherCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAgentSessionItemHandlerIgnoresSessionFromAnotherAgent(t *testing.T) {
|
||||
c, w, db := setupGinContextWithUserAndDB(t, http.MethodDelete, "/api/v1/agents/canvas-1/sessions/session-other")
|
||||
c.Params = gin.Params{{Key: "agent_id", Value: "canvas-1"}, {Key: "session_id", Value: "session-other"}}
|
||||
|
||||
db.Create(&entity.UserCanvas{
|
||||
ID: "canvas-1",
|
||||
UserID: "user-1",
|
||||
Title: sptr("Test Agent"),
|
||||
})
|
||||
db.Create(&entity.API4Conversation{
|
||||
ID: "session-other",
|
||||
DialogID: "canvas-other",
|
||||
UserID: "user-1",
|
||||
Message: json.RawMessage(`[]`),
|
||||
Reference: json.RawMessage(`[]`),
|
||||
})
|
||||
|
||||
h := NewAgentHandler(service.NewAgentService(), nil)
|
||||
h.DeleteAgentSessionItem(c)
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if resp["code"] != float64(common.CodeSuccess) {
|
||||
t.Fatalf("expected code %d, got %v: %v", common.CodeSuccess, resp["code"], resp["message"])
|
||||
}
|
||||
if resp["data"] != false {
|
||||
t.Fatalf("expected data false, got %v", resp["data"])
|
||||
}
|
||||
|
||||
var count int64
|
||||
if err := db.Model(&entity.API4Conversation{}).Where("id = ?", "session-other").Count(&count).Error; err != nil {
|
||||
t.Fatalf("failed to count other session: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("expected cross-agent session to remain, count=%d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAgentTagsHandlerSuccess(t *testing.T) {
|
||||
c, w, db := setupGinContextWithUserAndDB(t, http.MethodPut, "/api/v1/agents/canvas-1/tags")
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/agents/canvas-1/tags", strings.NewReader(`{"tags":["alpha","beta","alpha"]}`))
|
||||
|
||||
@@ -381,6 +381,9 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
agents.GET("/:agent_id/versions/:version_id", r.agentHandler.GetAgentVersion)
|
||||
agents.POST("/:agent_id/upload", r.agentHandler.UploadAgentFile)
|
||||
agents.PUT("/:agent_id/tags", r.agentHandler.UpdateAgentTags)
|
||||
agents.GET("/:agent_id/sessions", r.agentHandler.ListAgentSessions)
|
||||
agents.GET("/:agent_id/sessions/:session_id", r.agentHandler.GetAgentSession)
|
||||
agents.DELETE("/:agent_id/sessions/:session_id", r.agentHandler.DeleteAgentSessionItem)
|
||||
}
|
||||
|
||||
// Plugin routes
|
||||
|
||||
@@ -17,8 +17,13 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
@@ -36,6 +41,7 @@ type AgentService struct {
|
||||
userTenantDAO *dao.UserTenantDAO
|
||||
userCanvasVersionDAO *dao.UserCanvasVersionDAO
|
||||
canvasTemplateDAO *dao.CanvasTemplateDAO
|
||||
api4ConversationDAO *dao.API4ConversationDAO
|
||||
}
|
||||
|
||||
// NewAgentService create agent service
|
||||
@@ -44,6 +50,7 @@ func NewAgentService() *AgentService {
|
||||
canvasDAO: dao.NewUserCanvasDAO(),
|
||||
userTenantDAO: dao.NewUserTenantDAO(),
|
||||
userCanvasVersionDAO: dao.NewUserCanvasVersionDAO(),
|
||||
api4ConversationDAO: dao.NewAPI4ConversationDAO(),
|
||||
canvasTemplateDAO: dao.NewCanvasTemplateDAO(),
|
||||
}
|
||||
}
|
||||
@@ -137,6 +144,336 @@ func (s *AgentService) ListAgents(userID string, keywords string, page, pageSize
|
||||
return &ListAgentsResponse{Canvas: items, Total: total}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
type ListAgentSessionsRequest struct {
|
||||
SessionID string
|
||||
UserID string
|
||||
Page int
|
||||
PageSize int
|
||||
Keywords string
|
||||
FromDate string
|
||||
ToDate string
|
||||
OrderBy string
|
||||
Desc bool
|
||||
ExpUserID string
|
||||
IncludeDSL bool
|
||||
}
|
||||
|
||||
type ListAgentSessionsResponse struct {
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
Total int64 `json:"total"`
|
||||
}
|
||||
|
||||
func parseAgentSessionDate(value string, isEnd bool) (*time.Time, error) {
|
||||
if value == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if strings.Contains(value, "T") {
|
||||
normalized := strings.ReplaceAll(value, "Z", "+00:00")
|
||||
parsed, err := time.Parse(time.RFC3339, normalized)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
local := parsed.Local()
|
||||
return &local, nil
|
||||
}
|
||||
|
||||
if len(value) == 10 {
|
||||
if isEnd {
|
||||
value += " 23:59:59"
|
||||
} else {
|
||||
value += " 00:00:00"
|
||||
}
|
||||
}
|
||||
|
||||
parsed, err := time.ParseInLocation("2006-01-02 15:04:05", value, time.Local)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &parsed, nil
|
||||
}
|
||||
|
||||
func normalizeAgentSession(session *entity.API4Conversation, includeDSL bool) map[string]interface{} {
|
||||
messages := parseAgentSessionMessages(session.Message)
|
||||
references := parseAgentSessionReferences(session.Reference)
|
||||
|
||||
for _, message := range messages {
|
||||
delete(message, "prompt")
|
||||
}
|
||||
|
||||
if len(references) > 0 {
|
||||
assistantMessages := make([]map[string]interface{}, 0)
|
||||
for i, message := range messages {
|
||||
role, _ := message["role"].(string)
|
||||
if i != 0 && role != "user" {
|
||||
assistantMessages = append(assistantMessages, message)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(assistantMessages) && i < len(references); i++ {
|
||||
rawChunks, _ := references[i]["chunks"].([]interface{})
|
||||
assistantMessages[i]["reference"] = normalizeAgentReferenceChunks(rawChunks)
|
||||
}
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"id": session.ID,
|
||||
"name": session.Name,
|
||||
"agent_id": session.DialogID,
|
||||
"user_id": session.UserID,
|
||||
"exp_user_id": session.ExpUserID,
|
||||
"message": messages,
|
||||
"tokens": session.Tokens,
|
||||
"source": session.Source,
|
||||
"duration": session.Duration,
|
||||
"round": session.Round,
|
||||
"thumb_up": session.ThumbUp,
|
||||
"errors": session.Errors,
|
||||
"version_title": session.VersionTitle,
|
||||
"create_time": session.CreateTime,
|
||||
"create_date": session.CreateDate,
|
||||
"update_time": session.UpdateTime,
|
||||
"update_date": session.UpdateDate,
|
||||
}
|
||||
|
||||
if includeDSL {
|
||||
result["dsl"] = session.DSL
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func parseAgentSessionReferences(raw json.RawMessage) []map[string]interface{} {
|
||||
if len(raw) == 0 {
|
||||
return []map[string]interface{}{}
|
||||
}
|
||||
|
||||
var references []map[string]interface{}
|
||||
if err := json.Unmarshal(raw, &references); err == nil {
|
||||
for i, reference := range references {
|
||||
references[i] = normalizeAgentReferenceEntry(reference)
|
||||
}
|
||||
return references
|
||||
}
|
||||
|
||||
var referenceMap map[string]interface{}
|
||||
if err := json.Unmarshal(raw, &referenceMap); err != nil {
|
||||
return []map[string]interface{}{}
|
||||
}
|
||||
|
||||
if _, ok := referenceMap["chunks"]; ok {
|
||||
return []map[string]interface{}{normalizeAgentReferenceEntry(referenceMap)}
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(referenceMap))
|
||||
for key := range referenceMap {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
left, _ := strconv.Atoi(keys[i])
|
||||
right, _ := strconv.Atoi(keys[j])
|
||||
return left < right
|
||||
})
|
||||
|
||||
result := make([]map[string]interface{}, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
reference, ok := referenceMap[key].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
result = append(result, normalizeAgentReferenceEntry(reference))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func parseAgentSessionMessages(raw json.RawMessage) []map[string]interface{} {
|
||||
if len(raw) == 0 {
|
||||
return []map[string]interface{}{}
|
||||
}
|
||||
|
||||
var messages []map[string]interface{}
|
||||
if err := json.Unmarshal(raw, &messages); err != nil {
|
||||
return []map[string]interface{}{}
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func normalizeAgentReferenceEntry(reference map[string]interface{}) map[string]interface{} {
|
||||
if reference == nil {
|
||||
return map[string]interface{}{
|
||||
"chunks": []interface{}{},
|
||||
"doc_aggs": []interface{}{},
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := reference["chunks"]; ok {
|
||||
return map[string]interface{}{
|
||||
"chunks": valueOrEmptySlice(reference["chunks"]),
|
||||
"doc_aggs": valueOrEmptySlice(reference["doc_aggs"]),
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := reference["doc_aggs"]; ok {
|
||||
return map[string]interface{}{
|
||||
"chunks": valueOrEmptySlice(reference["chunks"]),
|
||||
"doc_aggs": valueOrEmptySlice(reference["doc_aggs"]),
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"chunks": valueOrEmptySlice(reference["reference"]),
|
||||
"doc_aggs": valueOrEmptySlice(reference["doc_aggs"]),
|
||||
}
|
||||
}
|
||||
|
||||
func valueOrEmptySlice(value interface{}) interface{} {
|
||||
if value == nil {
|
||||
return []interface{}{}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func normalizeAgentReferenceChunks(chunks []interface{}) []map[string]interface{} {
|
||||
result := make([]map[string]interface{}, 0, len(chunks))
|
||||
|
||||
for _, rawChunk := range chunks {
|
||||
chunk, ok := rawChunk.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
result = append(result, map[string]interface{}{
|
||||
"id": firstNonNil(chunk["chunk_id"], chunk["id"]),
|
||||
"content": firstNonNil(chunk["content_with_weight"], chunk["content"]),
|
||||
"document_id": firstNonNil(chunk["doc_id"], chunk["document_id"]),
|
||||
"document_name": firstNonNil(chunk["docnm_kwd"], chunk["document_name"]),
|
||||
"dataset_id": firstNonNil(chunk["kb_id"], chunk["dataset_id"]),
|
||||
"image_id": firstNonNil(chunk["image_id"], chunk["img_id"]),
|
||||
"positions": firstNonNil(chunk["positions"], chunk["position_int"]),
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func firstNonNil(values ...interface{}) interface{} {
|
||||
for _, value := range values {
|
||||
if value != nil {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AgentService) ListAgentSessions(userID, tenantID, agentID string, req ListAgentSessionsRequest) (*ListAgentSessionsResponse, common.ErrorCode, error) {
|
||||
if agentID == "" {
|
||||
return nil, common.CodeArgumentError, errors.New("agent_id is required")
|
||||
}
|
||||
|
||||
ok, err := s.CheckCanvasAccess(userID, agentID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to check agent permission: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("Agent not found or no permission.")
|
||||
}
|
||||
|
||||
sessionDAO := dao.NewChatSessionDAO()
|
||||
|
||||
if req.ExpUserID != "" {
|
||||
rows, err := sessionDAO.ListAgentSessionNames(agentID, req.ExpUserID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
return &ListAgentSessionsResponse{Data: rows, Total: int64(len(rows))}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
fromDate, err := parseAgentSessionDate(req.FromDate, false)
|
||||
if err != nil {
|
||||
return nil, common.CodeArgumentError, err
|
||||
}
|
||||
|
||||
toDate, err := parseAgentSessionDate(req.ToDate, true)
|
||||
if err != nil {
|
||||
return nil, common.CodeArgumentError, err
|
||||
}
|
||||
|
||||
total, sessions, err := sessionDAO.ListAgentSessions(dao.ListAgentSessionsParams{
|
||||
AgentID: agentID,
|
||||
Page: req.Page,
|
||||
PageSize: req.PageSize,
|
||||
OrderBy: req.OrderBy,
|
||||
Desc: req.Desc,
|
||||
SessionID: req.SessionID,
|
||||
UserID: req.UserID,
|
||||
IncludeDSL: req.IncludeDSL,
|
||||
Keywords: req.Keywords,
|
||||
FromDate: fromDate,
|
||||
ToDate: toDate,
|
||||
ExpUserID: req.ExpUserID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
data := make([]map[string]interface{}, 0, len(sessions))
|
||||
for _, session := range sessions {
|
||||
data = append(data, normalizeAgentSession(session, req.IncludeDSL))
|
||||
}
|
||||
|
||||
return &ListAgentSessionsResponse{Data: data, Total: total}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (s *AgentService) GetAgentSession(userID, agentID, sessionID string) (*entity.API4Conversation, common.ErrorCode, error) {
|
||||
if sessionID == "" {
|
||||
return nil, common.CodeArgumentError, fmt.Errorf("session_id is required")
|
||||
}
|
||||
ok, err := s.CheckCanvasAccess(userID, agentID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to check agent permission: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, common.CodeOperatingError, fmt.Errorf("Agent not found or no permission.")
|
||||
}
|
||||
|
||||
data, err := s.api4ConversationDAO.GetBySessionID(sessionID, agentID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, fmt.Errorf("failed to fetch session: %w", err)
|
||||
}
|
||||
if data == nil {
|
||||
return nil, common.CodeNotFound, fmt.Errorf("agent session not found")
|
||||
}
|
||||
return data, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (s *AgentService) DeleteAgentSessionItem(userID, agentID, sessionID string) (bool, common.ErrorCode, error) {
|
||||
if sessionID == "" {
|
||||
return false, common.CodeArgumentError, errors.New("session_id is required")
|
||||
}
|
||||
ok, err := s.CheckCanvasAccess(userID, agentID)
|
||||
if err != nil {
|
||||
return false, common.CodeServerError, fmt.Errorf("failed to check agent permission: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
return false, common.CodeOperatingError, fmt.Errorf("Agent not found or no permission.")
|
||||
}
|
||||
|
||||
row, err := s.api4ConversationDAO.DeleteBySessionIDAndAgentID(sessionID, agentID)
|
||||
if err != nil {
|
||||
return false, common.CodeServerError, err
|
||||
}
|
||||
if row == 0 {
|
||||
return false, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
return true, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
// normalizeAgentTags returns an error for unsupported tag payload types
|
||||
func normalizeAgentTags(rawTags interface{}) (string, error) {
|
||||
cleaned := make([]string, 0)
|
||||
|
||||
Reference in New Issue
Block a user