feat[Go]: implement /api/v1/connectors/<connector_id> PATCH (#15512)

### What problem does this PR solve?

As title, all test are passed

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
ちー
2026-06-02 19:34:07 +08:00
committed by GitHub
parent 9f969feb89
commit 5f8926410d
4 changed files with 245 additions and 0 deletions

View File

@@ -17,6 +17,7 @@
package dao
import (
"errors"
"ragflow/internal/entity"
"time"
@@ -114,6 +115,36 @@ func (dao *ConnectorDAO) CancelRunningOrScheduledLogs(connectorID string) error
Update("status", string(entity.TaskStatusCancel)).Error
}
// ScheduleConnectorTasks schedules sync and optional prune tasks for a connector.
func (dao *ConnectorDAO) ScheduleConnectorTasks(connectorID string) error {
return DB.Transaction(func(tx *gorm.DB) error {
var connector entity.Connector
if err := tx.Where("id = ?", connectorID).First(&connector).Error; err != nil {
return err
}
var mappings []entity.Connector2Kb
if err := tx.Where("connector_id = ?", connectorID).Find(&mappings).Error; err != nil {
return err
}
for _, mapping := range mappings {
if err := scheduleConnectorTask(tx, connectorID, mapping.KbID, connectorTaskTypeSync, false); err != nil {
return err
}
if connectorConfigBool(connector.Config, "sync_deleted_files") {
if err := scheduleConnectorTask(tx, connectorID, mapping.KbID, connectorTaskTypePrune, false); err != nil {
return err
}
}
}
return tx.Model(&entity.Connector{}).
Where("id = ?", connectorID).
Update("status", string(entity.TaskStatusSchedule)).Error
})
}
// ListDocumentsByKBAndSourceType lists connector documents in a dataset.
func (dao *ConnectorDAO) ListDocumentsByKBAndSourceType(kbID, sourceType string) ([]*entity.Document, error) {
var documents []*entity.Document
@@ -224,6 +255,68 @@ func createRebuildSyncLog(tx *gorm.DB, connectorID, kbID, taskType string, reind
}).Error
}
func scheduleConnectorTask(tx *gorm.DB, connectorID, kbID, taskType string, reindex bool) error {
var existing int64
if err := tx.Model(&entity.SyncLogs{}).
Where("connector_id = ? AND kb_id = ? AND task_type = ? AND status = ?", connectorID, kbID, taskType, string(entity.TaskStatusSchedule)).
Count(&existing).Error; err != nil {
return err
}
if existing > 0 {
return nil
}
var pollRangeStart *string
var totalDocsIndexed int64
if taskType == connectorTaskTypeSync {
var latest entity.SyncLogs
err := tx.Where("connector_id = ? AND kb_id = ? AND task_type = ? AND status = ?", connectorID, kbID, taskType, string(entity.TaskStatusDone)).
Order("update_time DESC").
First(&latest).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
if err == nil {
pollRangeStart = latest.PollRangeEnd
totalDocsIndexed = latest.TotalDocsIndexed
}
}
fromBeginning := "0"
if reindex {
fromBeginning = "1"
}
now := time.Now().Local()
return tx.Create(&entity.SyncLogs{
ID: generateUUID(),
ConnectorID: connectorID,
KbID: kbID,
TaskType: taskType,
Status: string(entity.TaskStatusSchedule),
FromBeginning: &fromBeginning,
PollRangeStart: pollRangeStart,
TimeStarted: &now,
ErrorMsg: "",
TotalDocsIndexed: totalDocsIndexed,
}).Error
}
func connectorConfigBool(config map[string]interface{}, key string) bool {
value, ok := config[key]
if !ok {
return false
}
switch typed := value.(type) {
case bool:
return typed
case string:
return typed == "1" || typed == "true" || typed == "TRUE"
default:
return false
}
}
// ListLogsByConnectorID lists sync logs for one connector with pagination.
func (dao *ConnectorDAO) ListLogsByConnectorID(connectorID string, offset, limit int) ([]*entity.ConnectorSyncLog, int64, error) {
baseQuery := DB.Model(&entity.SyncLogs{}).

View File

@@ -17,6 +17,7 @@
package handler
import (
"encoding/json"
"errors"
"net/http"
"strconv"
@@ -37,6 +38,7 @@ type connectorServiceIface interface {
DeleteConnector(connectorID, userID string) (bool, common.ErrorCode, error)
RebuildConnector(connectorID, userID, kbID string) (bool, common.ErrorCode, error)
TestConnector(connectorID, userID string) error
UpdateConnector(connectorID, userID string, req *service.UpdateConnectorRequest) (*entity.Connector, common.ErrorCode, error)
}
// ConnectorHandler connector handler
@@ -133,6 +135,60 @@ func (h *ConnectorHandler) GetConnector(c *gin.Context) {
})
}
// UpdateConnector Update an accessible connector's polling configuration.
func (h *ConnectorHandler) UpdateConnector(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
req, err := decodeUpdateConnectorRequest(c)
if err != nil {
jsonError(c, common.CodeBadRequest, err.Error())
return
}
connector, code, err := h.connectorService.UpdateConnector(c.Param("connector_id"), user.ID, req)
if err != nil {
jsonError(c, code, err.Error())
return
}
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"data": connector,
"message": "success",
})
}
func decodeUpdateConnectorRequest(c *gin.Context) (*service.UpdateConnectorRequest, error) {
var raw map[string]json.RawMessage
if err := c.ShouldBindJSON(&raw); err != nil {
return nil, err
}
payload := raw
if dataRaw, ok := raw["data"]; ok {
var nested map[string]json.RawMessage
if err := json.Unmarshal(dataRaw, &nested); err == nil && nested != nil {
payload = nested
}
}
data, err := json.Marshal(payload)
if err != nil {
return nil, err
}
var req service.UpdateConnectorRequest
if err := json.Unmarshal(data, &req); err != nil {
return nil, err
}
return &req, nil
}
// ListLogs list connector sync logs.
// @Summary List Connector Logs
// @Description List sync logs for a connector the current user can access

View File

@@ -45,6 +45,13 @@ func (s fakeConnectorService) GetConnector(string, string) (*entity.Connector, c
return s.connector, common.CodeSuccess, nil
}
func (s fakeConnectorService) UpdateConnector(string, string, *service.UpdateConnectorRequest) (*entity.Connector, common.ErrorCode, error) {
if s.err != nil {
return nil, s.code, s.err
}
return s.connector, common.CodeSuccess, nil
}
func (s fakeConnectorService) ListLog(string, string, int, int) ([]*entity.ConnectorSyncLog, int64, common.ErrorCode, error) {
if s.err != nil {
return nil, 0, s.code, s.err

View File

@@ -20,6 +20,7 @@ import (
"context"
"errors"
"fmt"
"strings"
"gorm.io/gorm"
@@ -288,6 +289,94 @@ func (s *ConnectorService) DeleteConnector(connectorID, userID string) (bool, co
return true, common.CodeSuccess, nil
}
type UpdateConnectorRequest struct {
PruneFreq *int64 `json:"prune_freq,omitempty"`
RefreshFreq *int64 `json:"refresh_freq,omitempty"`
Config entity.JSONMap `json:"config,omitempty"`
TimeoutSecs *int64 `json:"timeout_secs,omitempty"`
Reschedule bool `json:"reschedule,omitempty"`
Status string `json:"status,omitempty"`
}
func (s *ConnectorService) UpdateConnector(connectorID, userID string, req *UpdateConnectorRequest) (*entity.Connector, common.ErrorCode, error) {
if connectorID == "" {
return nil, common.CodeDataError, fmt.Errorf("connector_id is required")
}
connector, err := s.connectorDAO.GetByID(connectorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, common.CodeDataError, fmt.Errorf("Can't find this Connector!")
}
return nil, common.CodeServerError, err
}
if !s.canAccessConnector(connector, userID) {
return nil, common.CodeAuthenticationError, fmt.Errorf("No authorization.")
}
updates := map[string]interface{}{}
if req != nil {
if req.PruneFreq != nil {
updates["prune_freq"] = *req.PruneFreq
}
if req.RefreshFreq != nil {
updates["refresh_freq"] = *req.RefreshFreq
}
if req.Config != nil {
updates["config"] = req.Config
}
if req.TimeoutSecs != nil {
updates["timeout_secs"] = *req.TimeoutSecs
}
}
if len(updates) > 0 {
if err := s.connectorDAO.UpdateByID(connectorID, updates); err != nil {
return nil, common.CodeServerError, err
}
}
if req != nil {
if req.Reschedule {
if err := s.cancelConnectorTasks(connectorID); err != nil {
return nil, common.CodeServerError, err
}
if err := s.connectorDAO.ScheduleConnectorTasks(connectorID); err != nil {
return nil, common.CodeServerError, err
}
} else if isConnectorCancelStatus(req.Status) {
if err := s.cancelConnectorTasks(connectorID); err != nil {
return nil, common.CodeServerError, err
}
} else if isConnectorScheduleStatus(req.Status) {
if err := s.connectorDAO.ScheduleConnectorTasks(connectorID); err != nil {
return nil, common.CodeServerError, err
}
}
}
connector, err = s.connectorDAO.GetByID(connectorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, common.CodeDataError, fmt.Errorf("Can't find this Connector!")
}
return nil, common.CodeServerError, err
}
return connector, common.CodeSuccess, nil
}
func isConnectorCancelStatus(status string) bool {
status = strings.TrimSpace(status)
return status == string(entity.TaskStatusCancel) || strings.EqualFold(status, "CANCEL")
}
func isConnectorScheduleStatus(status string) bool {
status = strings.TrimSpace(status)
return status == string(entity.TaskStatusSchedule) || strings.EqualFold(status, "SCHEDULE")
}
// RebuildConnector schedules a rebuild for an accessible connector and knowledge base.
func (s *ConnectorService) RebuildConnector(connectorID, userID, kbID string) (bool, common.ErrorCode, error) {
if connectorID == "" {