From 5f8926410da2236fd3ea359d929b8e6152658fd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=81=A1=E3=83=BC?= <85853517+Dimon0000000@users.noreply.github.com> Date: Tue, 2 Jun 2026 19:34:07 +0800 Subject: [PATCH] feat[Go]: implement /api/v1/connectors/ PATCH (#15512) ### What problem does this PR solve? As title, all test are passed ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- internal/dao/connector.go | 93 ++++++++++++++++++++++++++++++ internal/handler/connector.go | 56 ++++++++++++++++++ internal/handler/connector_test.go | 7 +++ internal/service/connector.go | 89 ++++++++++++++++++++++++++++ 4 files changed, 245 insertions(+) diff --git a/internal/dao/connector.go b/internal/dao/connector.go index e10a28221d..0961a12401 100644 --- a/internal/dao/connector.go +++ b/internal/dao/connector.go @@ -17,6 +17,7 @@ package dao import ( + "errors" "ragflow/internal/entity" "time" @@ -114,6 +115,36 @@ func (dao *ConnectorDAO) CancelRunningOrScheduledLogs(connectorID string) error Update("status", string(entity.TaskStatusCancel)).Error } +// ScheduleConnectorTasks schedules sync and optional prune tasks for a connector. +func (dao *ConnectorDAO) ScheduleConnectorTasks(connectorID string) error { + return DB.Transaction(func(tx *gorm.DB) error { + var connector entity.Connector + if err := tx.Where("id = ?", connectorID).First(&connector).Error; err != nil { + return err + } + + var mappings []entity.Connector2Kb + if err := tx.Where("connector_id = ?", connectorID).Find(&mappings).Error; err != nil { + return err + } + + for _, mapping := range mappings { + if err := scheduleConnectorTask(tx, connectorID, mapping.KbID, connectorTaskTypeSync, false); err != nil { + return err + } + if connectorConfigBool(connector.Config, "sync_deleted_files") { + if err := scheduleConnectorTask(tx, connectorID, mapping.KbID, connectorTaskTypePrune, false); err != nil { + return err + } + } + } + + return tx.Model(&entity.Connector{}). + Where("id = ?", connectorID). + Update("status", string(entity.TaskStatusSchedule)).Error + }) +} + // ListDocumentsByKBAndSourceType lists connector documents in a dataset. func (dao *ConnectorDAO) ListDocumentsByKBAndSourceType(kbID, sourceType string) ([]*entity.Document, error) { var documents []*entity.Document @@ -224,6 +255,68 @@ func createRebuildSyncLog(tx *gorm.DB, connectorID, kbID, taskType string, reind }).Error } +func scheduleConnectorTask(tx *gorm.DB, connectorID, kbID, taskType string, reindex bool) error { + var existing int64 + if err := tx.Model(&entity.SyncLogs{}). + Where("connector_id = ? AND kb_id = ? AND task_type = ? AND status = ?", connectorID, kbID, taskType, string(entity.TaskStatusSchedule)). + Count(&existing).Error; err != nil { + return err + } + if existing > 0 { + return nil + } + + var pollRangeStart *string + var totalDocsIndexed int64 + if taskType == connectorTaskTypeSync { + var latest entity.SyncLogs + err := tx.Where("connector_id = ? AND kb_id = ? AND task_type = ? AND status = ?", connectorID, kbID, taskType, string(entity.TaskStatusDone)). + Order("update_time DESC"). + First(&latest).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + if err == nil { + pollRangeStart = latest.PollRangeEnd + totalDocsIndexed = latest.TotalDocsIndexed + } + } + + fromBeginning := "0" + if reindex { + fromBeginning = "1" + } + now := time.Now().Local() + return tx.Create(&entity.SyncLogs{ + ID: generateUUID(), + ConnectorID: connectorID, + KbID: kbID, + TaskType: taskType, + Status: string(entity.TaskStatusSchedule), + FromBeginning: &fromBeginning, + PollRangeStart: pollRangeStart, + TimeStarted: &now, + ErrorMsg: "", + TotalDocsIndexed: totalDocsIndexed, + }).Error +} + +func connectorConfigBool(config map[string]interface{}, key string) bool { + value, ok := config[key] + if !ok { + return false + } + + switch typed := value.(type) { + case bool: + return typed + case string: + return typed == "1" || typed == "true" || typed == "TRUE" + default: + return false + } +} + // ListLogsByConnectorID lists sync logs for one connector with pagination. func (dao *ConnectorDAO) ListLogsByConnectorID(connectorID string, offset, limit int) ([]*entity.ConnectorSyncLog, int64, error) { baseQuery := DB.Model(&entity.SyncLogs{}). diff --git a/internal/handler/connector.go b/internal/handler/connector.go index d693b50ad3..999b08aeaf 100644 --- a/internal/handler/connector.go +++ b/internal/handler/connector.go @@ -17,6 +17,7 @@ package handler import ( + "encoding/json" "errors" "net/http" "strconv" @@ -37,6 +38,7 @@ type connectorServiceIface interface { DeleteConnector(connectorID, userID string) (bool, common.ErrorCode, error) RebuildConnector(connectorID, userID, kbID string) (bool, common.ErrorCode, error) TestConnector(connectorID, userID string) error + UpdateConnector(connectorID, userID string, req *service.UpdateConnectorRequest) (*entity.Connector, common.ErrorCode, error) } // ConnectorHandler connector handler @@ -133,6 +135,60 @@ func (h *ConnectorHandler) GetConnector(c *gin.Context) { }) } +// UpdateConnector Update an accessible connector's polling configuration. +func (h *ConnectorHandler) UpdateConnector(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + req, err := decodeUpdateConnectorRequest(c) + if err != nil { + jsonError(c, common.CodeBadRequest, err.Error()) + return + } + + connector, code, err := h.connectorService.UpdateConnector(c.Param("connector_id"), user.ID, req) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "data": connector, + "message": "success", + }) +} + +func decodeUpdateConnectorRequest(c *gin.Context) (*service.UpdateConnectorRequest, error) { + var raw map[string]json.RawMessage + if err := c.ShouldBindJSON(&raw); err != nil { + return nil, err + } + + payload := raw + if dataRaw, ok := raw["data"]; ok { + var nested map[string]json.RawMessage + if err := json.Unmarshal(dataRaw, &nested); err == nil && nested != nil { + payload = nested + } + } + + data, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + var req service.UpdateConnectorRequest + if err := json.Unmarshal(data, &req); err != nil { + return nil, err + } + + return &req, nil +} + // ListLogs list connector sync logs. // @Summary List Connector Logs // @Description List sync logs for a connector the current user can access diff --git a/internal/handler/connector_test.go b/internal/handler/connector_test.go index 11710b5508..a5e17c50f0 100644 --- a/internal/handler/connector_test.go +++ b/internal/handler/connector_test.go @@ -45,6 +45,13 @@ func (s fakeConnectorService) GetConnector(string, string) (*entity.Connector, c return s.connector, common.CodeSuccess, nil } +func (s fakeConnectorService) UpdateConnector(string, string, *service.UpdateConnectorRequest) (*entity.Connector, common.ErrorCode, error) { + if s.err != nil { + return nil, s.code, s.err + } + return s.connector, common.CodeSuccess, nil +} + func (s fakeConnectorService) ListLog(string, string, int, int) ([]*entity.ConnectorSyncLog, int64, common.ErrorCode, error) { if s.err != nil { return nil, 0, s.code, s.err diff --git a/internal/service/connector.go b/internal/service/connector.go index 88449cbbd7..78bd393d5f 100644 --- a/internal/service/connector.go +++ b/internal/service/connector.go @@ -20,6 +20,7 @@ import ( "context" "errors" "fmt" + "strings" "gorm.io/gorm" @@ -288,6 +289,94 @@ func (s *ConnectorService) DeleteConnector(connectorID, userID string) (bool, co return true, common.CodeSuccess, nil } +type UpdateConnectorRequest struct { + PruneFreq *int64 `json:"prune_freq,omitempty"` + RefreshFreq *int64 `json:"refresh_freq,omitempty"` + Config entity.JSONMap `json:"config,omitempty"` + TimeoutSecs *int64 `json:"timeout_secs,omitempty"` + Reschedule bool `json:"reschedule,omitempty"` + Status string `json:"status,omitempty"` +} + +func (s *ConnectorService) UpdateConnector(connectorID, userID string, req *UpdateConnectorRequest) (*entity.Connector, common.ErrorCode, error) { + if connectorID == "" { + return nil, common.CodeDataError, fmt.Errorf("connector_id is required") + } + + connector, err := s.connectorDAO.GetByID(connectorID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, common.CodeDataError, fmt.Errorf("Can't find this Connector!") + } + return nil, common.CodeServerError, err + } + + if !s.canAccessConnector(connector, userID) { + return nil, common.CodeAuthenticationError, fmt.Errorf("No authorization.") + } + + updates := map[string]interface{}{} + if req != nil { + if req.PruneFreq != nil { + updates["prune_freq"] = *req.PruneFreq + } + if req.RefreshFreq != nil { + updates["refresh_freq"] = *req.RefreshFreq + } + if req.Config != nil { + updates["config"] = req.Config + } + if req.TimeoutSecs != nil { + updates["timeout_secs"] = *req.TimeoutSecs + } + } + + if len(updates) > 0 { + if err := s.connectorDAO.UpdateByID(connectorID, updates); err != nil { + return nil, common.CodeServerError, err + } + } + + if req != nil { + if req.Reschedule { + if err := s.cancelConnectorTasks(connectorID); err != nil { + return nil, common.CodeServerError, err + } + if err := s.connectorDAO.ScheduleConnectorTasks(connectorID); err != nil { + return nil, common.CodeServerError, err + } + } else if isConnectorCancelStatus(req.Status) { + if err := s.cancelConnectorTasks(connectorID); err != nil { + return nil, common.CodeServerError, err + } + } else if isConnectorScheduleStatus(req.Status) { + if err := s.connectorDAO.ScheduleConnectorTasks(connectorID); err != nil { + return nil, common.CodeServerError, err + } + } + } + + connector, err = s.connectorDAO.GetByID(connectorID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, common.CodeDataError, fmt.Errorf("Can't find this Connector!") + } + return nil, common.CodeServerError, err + } + + return connector, common.CodeSuccess, nil +} + +func isConnectorCancelStatus(status string) bool { + status = strings.TrimSpace(status) + return status == string(entity.TaskStatusCancel) || strings.EqualFold(status, "CANCEL") +} + +func isConnectorScheduleStatus(status string) bool { + status = strings.TrimSpace(status) + return status == string(entity.TaskStatusSchedule) || strings.EqualFold(status, "SCHEDULE") +} + // RebuildConnector schedules a rebuild for an accessible connector and knowledge base. func (s *ConnectorService) RebuildConnector(connectorID, userID, kbID string) (bool, common.ErrorCode, error) { if connectorID == "" {