mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat[Go]: implement /api/v1/connectors/<connector_id> PATCH (#15512)
### What problem does this PR solve? As title, all test are passed ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"ragflow/internal/entity"
|
||||
"time"
|
||||
|
||||
@@ -114,6 +115,36 @@ func (dao *ConnectorDAO) CancelRunningOrScheduledLogs(connectorID string) error
|
||||
Update("status", string(entity.TaskStatusCancel)).Error
|
||||
}
|
||||
|
||||
// ScheduleConnectorTasks schedules sync and optional prune tasks for a connector.
|
||||
func (dao *ConnectorDAO) ScheduleConnectorTasks(connectorID string) error {
|
||||
return DB.Transaction(func(tx *gorm.DB) error {
|
||||
var connector entity.Connector
|
||||
if err := tx.Where("id = ?", connectorID).First(&connector).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var mappings []entity.Connector2Kb
|
||||
if err := tx.Where("connector_id = ?", connectorID).Find(&mappings).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, mapping := range mappings {
|
||||
if err := scheduleConnectorTask(tx, connectorID, mapping.KbID, connectorTaskTypeSync, false); err != nil {
|
||||
return err
|
||||
}
|
||||
if connectorConfigBool(connector.Config, "sync_deleted_files") {
|
||||
if err := scheduleConnectorTask(tx, connectorID, mapping.KbID, connectorTaskTypePrune, false); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Model(&entity.Connector{}).
|
||||
Where("id = ?", connectorID).
|
||||
Update("status", string(entity.TaskStatusSchedule)).Error
|
||||
})
|
||||
}
|
||||
|
||||
// ListDocumentsByKBAndSourceType lists connector documents in a dataset.
|
||||
func (dao *ConnectorDAO) ListDocumentsByKBAndSourceType(kbID, sourceType string) ([]*entity.Document, error) {
|
||||
var documents []*entity.Document
|
||||
@@ -224,6 +255,68 @@ func createRebuildSyncLog(tx *gorm.DB, connectorID, kbID, taskType string, reind
|
||||
}).Error
|
||||
}
|
||||
|
||||
func scheduleConnectorTask(tx *gorm.DB, connectorID, kbID, taskType string, reindex bool) error {
|
||||
var existing int64
|
||||
if err := tx.Model(&entity.SyncLogs{}).
|
||||
Where("connector_id = ? AND kb_id = ? AND task_type = ? AND status = ?", connectorID, kbID, taskType, string(entity.TaskStatusSchedule)).
|
||||
Count(&existing).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if existing > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var pollRangeStart *string
|
||||
var totalDocsIndexed int64
|
||||
if taskType == connectorTaskTypeSync {
|
||||
var latest entity.SyncLogs
|
||||
err := tx.Where("connector_id = ? AND kb_id = ? AND task_type = ? AND status = ?", connectorID, kbID, taskType, string(entity.TaskStatusDone)).
|
||||
Order("update_time DESC").
|
||||
First(&latest).Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
if err == nil {
|
||||
pollRangeStart = latest.PollRangeEnd
|
||||
totalDocsIndexed = latest.TotalDocsIndexed
|
||||
}
|
||||
}
|
||||
|
||||
fromBeginning := "0"
|
||||
if reindex {
|
||||
fromBeginning = "1"
|
||||
}
|
||||
now := time.Now().Local()
|
||||
return tx.Create(&entity.SyncLogs{
|
||||
ID: generateUUID(),
|
||||
ConnectorID: connectorID,
|
||||
KbID: kbID,
|
||||
TaskType: taskType,
|
||||
Status: string(entity.TaskStatusSchedule),
|
||||
FromBeginning: &fromBeginning,
|
||||
PollRangeStart: pollRangeStart,
|
||||
TimeStarted: &now,
|
||||
ErrorMsg: "",
|
||||
TotalDocsIndexed: totalDocsIndexed,
|
||||
}).Error
|
||||
}
|
||||
|
||||
func connectorConfigBool(config map[string]interface{}, key string) bool {
|
||||
value, ok := config[key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch typed := value.(type) {
|
||||
case bool:
|
||||
return typed
|
||||
case string:
|
||||
return typed == "1" || typed == "true" || typed == "TRUE"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ListLogsByConnectorID lists sync logs for one connector with pagination.
|
||||
func (dao *ConnectorDAO) ListLogsByConnectorID(connectorID string, offset, limit int) ([]*entity.ConnectorSyncLog, int64, error) {
|
||||
baseQuery := DB.Model(&entity.SyncLogs{}).
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -37,6 +38,7 @@ type connectorServiceIface interface {
|
||||
DeleteConnector(connectorID, userID string) (bool, common.ErrorCode, error)
|
||||
RebuildConnector(connectorID, userID, kbID string) (bool, common.ErrorCode, error)
|
||||
TestConnector(connectorID, userID string) error
|
||||
UpdateConnector(connectorID, userID string, req *service.UpdateConnectorRequest) (*entity.Connector, common.ErrorCode, error)
|
||||
}
|
||||
|
||||
// ConnectorHandler connector handler
|
||||
@@ -133,6 +135,60 @@ func (h *ConnectorHandler) GetConnector(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateConnector Update an accessible connector's polling configuration.
|
||||
func (h *ConnectorHandler) UpdateConnector(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
req, err := decodeUpdateConnectorRequest(c)
|
||||
if err != nil {
|
||||
jsonError(c, common.CodeBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
connector, code, err := h.connectorService.UpdateConnector(c.Param("connector_id"), user.ID, req)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": connector,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
func decodeUpdateConnectorRequest(c *gin.Context) (*service.UpdateConnectorRequest, error) {
|
||||
var raw map[string]json.RawMessage
|
||||
if err := c.ShouldBindJSON(&raw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := raw
|
||||
if dataRaw, ok := raw["data"]; ok {
|
||||
var nested map[string]json.RawMessage
|
||||
if err := json.Unmarshal(dataRaw, &nested); err == nil && nested != nil {
|
||||
payload = nested
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var req service.UpdateConnectorRequest
|
||||
if err := json.Unmarshal(data, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
// ListLogs list connector sync logs.
|
||||
// @Summary List Connector Logs
|
||||
// @Description List sync logs for a connector the current user can access
|
||||
|
||||
@@ -45,6 +45,13 @@ func (s fakeConnectorService) GetConnector(string, string) (*entity.Connector, c
|
||||
return s.connector, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (s fakeConnectorService) UpdateConnector(string, string, *service.UpdateConnectorRequest) (*entity.Connector, common.ErrorCode, error) {
|
||||
if s.err != nil {
|
||||
return nil, s.code, s.err
|
||||
}
|
||||
return s.connector, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (s fakeConnectorService) ListLog(string, string, int, int) ([]*entity.ConnectorSyncLog, int64, common.ErrorCode, error) {
|
||||
if s.err != nil {
|
||||
return nil, 0, s.code, s.err
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -288,6 +289,94 @@ func (s *ConnectorService) DeleteConnector(connectorID, userID string) (bool, co
|
||||
return true, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
type UpdateConnectorRequest struct {
|
||||
PruneFreq *int64 `json:"prune_freq,omitempty"`
|
||||
RefreshFreq *int64 `json:"refresh_freq,omitempty"`
|
||||
Config entity.JSONMap `json:"config,omitempty"`
|
||||
TimeoutSecs *int64 `json:"timeout_secs,omitempty"`
|
||||
Reschedule bool `json:"reschedule,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ConnectorService) UpdateConnector(connectorID, userID string, req *UpdateConnectorRequest) (*entity.Connector, common.ErrorCode, error) {
|
||||
if connectorID == "" {
|
||||
return nil, common.CodeDataError, fmt.Errorf("connector_id is required")
|
||||
}
|
||||
|
||||
connector, err := s.connectorDAO.GetByID(connectorID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Can't find this Connector!")
|
||||
}
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
if !s.canAccessConnector(connector, userID) {
|
||||
return nil, common.CodeAuthenticationError, fmt.Errorf("No authorization.")
|
||||
}
|
||||
|
||||
updates := map[string]interface{}{}
|
||||
if req != nil {
|
||||
if req.PruneFreq != nil {
|
||||
updates["prune_freq"] = *req.PruneFreq
|
||||
}
|
||||
if req.RefreshFreq != nil {
|
||||
updates["refresh_freq"] = *req.RefreshFreq
|
||||
}
|
||||
if req.Config != nil {
|
||||
updates["config"] = req.Config
|
||||
}
|
||||
if req.TimeoutSecs != nil {
|
||||
updates["timeout_secs"] = *req.TimeoutSecs
|
||||
}
|
||||
}
|
||||
|
||||
if len(updates) > 0 {
|
||||
if err := s.connectorDAO.UpdateByID(connectorID, updates); err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
}
|
||||
|
||||
if req != nil {
|
||||
if req.Reschedule {
|
||||
if err := s.cancelConnectorTasks(connectorID); err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if err := s.connectorDAO.ScheduleConnectorTasks(connectorID); err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
} else if isConnectorCancelStatus(req.Status) {
|
||||
if err := s.cancelConnectorTasks(connectorID); err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
} else if isConnectorScheduleStatus(req.Status) {
|
||||
if err := s.connectorDAO.ScheduleConnectorTasks(connectorID); err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
connector, err = s.connectorDAO.GetByID(connectorID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Can't find this Connector!")
|
||||
}
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
return connector, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func isConnectorCancelStatus(status string) bool {
|
||||
status = strings.TrimSpace(status)
|
||||
return status == string(entity.TaskStatusCancel) || strings.EqualFold(status, "CANCEL")
|
||||
}
|
||||
|
||||
func isConnectorScheduleStatus(status string) bool {
|
||||
status = strings.TrimSpace(status)
|
||||
return status == string(entity.TaskStatusSchedule) || strings.EqualFold(status, "SCHEDULE")
|
||||
}
|
||||
|
||||
// RebuildConnector schedules a rebuild for an accessible connector and knowledge base.
|
||||
func (s *ConnectorService) RebuildConnector(connectorID, userID, kbID string) (bool, common.ErrorCode, error) {
|
||||
if connectorID == "" {
|
||||
|
||||
Reference in New Issue
Block a user