mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Go: implement connector get API (#15259)
## Summary - Add Go REST support for `GET /api/v1/connectors/:connector_id`. - Reuse the Python API behavior by returning the connector only when the current user can access its tenant. - Add focused handler coverage for success and unauthorized responses. Co-authored-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@@ -19,15 +19,21 @@ package handler
|
||||
import (
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/entity"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/service"
|
||||
)
|
||||
|
||||
type connectorService interface {
|
||||
ListConnectors(userID string) (*service.ListConnectorsResponse, error)
|
||||
GetConnector(connectorID string, userID string) (*entity.Connector, common.ErrorCode, error)
|
||||
}
|
||||
|
||||
// ConnectorHandler connector handler
|
||||
type ConnectorHandler struct {
|
||||
connectorService *service.ConnectorService
|
||||
connectorService connectorService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
@@ -71,3 +77,31 @@ func (h *ConnectorHandler) ListConnectors(c *gin.Context) {
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// GetConnector get connector
|
||||
// @Summary Get Connector
|
||||
// @Description Get connector details for the current user
|
||||
// @Tags connector
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/v1/connectors/{connector_id} [get]
|
||||
func (h *ConnectorHandler) GetConnector(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
connector, code, err := h.connectorService.GetConnector(c.Param("connector_id"), user.ID)
|
||||
if err != nil {
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeSuccess,
|
||||
"data": connector,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
97
internal/handler/connector_test.go
Normal file
97
internal/handler/connector_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
"ragflow/internal/service"
|
||||
)
|
||||
|
||||
type fakeConnectorService struct {
|
||||
connector *entity.Connector
|
||||
code common.ErrorCode
|
||||
err error
|
||||
}
|
||||
|
||||
func (s fakeConnectorService) ListConnectors(string) (*service.ListConnectorsResponse, error) {
|
||||
return &service.ListConnectorsResponse{Connectors: []*dao.ConnectorListItem{}}, nil
|
||||
}
|
||||
|
||||
func (s fakeConnectorService) GetConnector(string, string) (*entity.Connector, common.ErrorCode, error) {
|
||||
if s.err != nil {
|
||||
return nil, s.code, s.err
|
||||
}
|
||||
return s.connector, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func TestConnectorHandlerGetConnector(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
service fakeConnectorService
|
||||
wantCode common.ErrorCode
|
||||
wantID string
|
||||
wantMsg string
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
service: fakeConnectorService{connector: &entity.Connector{
|
||||
ID: "connector-1",
|
||||
TenantID: "tenant-1",
|
||||
Name: "Docs",
|
||||
Source: "google_drive",
|
||||
Status: "unstart",
|
||||
Config: entity.JSONMap{"folder": "docs"},
|
||||
}},
|
||||
wantCode: common.CodeSuccess,
|
||||
wantID: "connector-1",
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
service: fakeConnectorService{code: common.CodeAuthenticationError, err: fmt.Errorf("No authorization.")},
|
||||
wantCode: common.CodeAuthenticationError,
|
||||
wantMsg: "No authorization.",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &ConnectorHandler{connectorService: tt.service}
|
||||
router := gin.New()
|
||||
router.GET("/api/v1/connectors/:connector_id", func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: "tenant-1"})
|
||||
h.GetConnector(c)
|
||||
})
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/connectors/connector-1", nil)
|
||||
router.ServeHTTP(resp, req)
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
if body["code"] != float64(tt.wantCode) {
|
||||
t.Fatalf("code=%v body=%v", body["code"], body)
|
||||
}
|
||||
if tt.wantMsg != "" && body["message"] != tt.wantMsg {
|
||||
t.Fatalf("message=%v", body["message"])
|
||||
}
|
||||
if tt.wantID != "" && body["data"].(map[string]interface{})["id"] != tt.wantID {
|
||||
t.Fatalf("data=%v", body["data"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -308,6 +308,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
connector := v1.Group("/connectors")
|
||||
{
|
||||
connector.GET("/", r.connectorHandler.ListConnectors)
|
||||
connector.GET("/:connector_id", r.connectorHandler.GetConnector)
|
||||
}
|
||||
|
||||
system := v1.Group("/system")
|
||||
|
||||
@@ -17,7 +17,13 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ConnectorService connector service
|
||||
@@ -39,6 +45,37 @@ type ListConnectorsResponse struct {
|
||||
Connectors []*dao.ConnectorListItem `json:"connectors"`
|
||||
}
|
||||
|
||||
// GetConnector returns one connector when the user can access its tenant.
|
||||
func (s *ConnectorService) GetConnector(connectorID string, userID string) (*entity.Connector, common.ErrorCode, error) {
|
||||
if connectorID == "" {
|
||||
return nil, common.CodeDataError, fmt.Errorf("connector_id is required")
|
||||
}
|
||||
|
||||
connector, err := s.connectorDAO.GetByID(connectorID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, common.CodeAuthenticationError, fmt.Errorf("No authorization.")
|
||||
}
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
|
||||
if connector.TenantID == userID {
|
||||
return connector, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
tenantIDs, err := s.userTenantDAO.GetTenantIDsByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
for _, tenantID := range tenantIDs {
|
||||
if tenantID == connector.TenantID {
|
||||
return connector, common.CodeSuccess, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, common.CodeAuthenticationError, fmt.Errorf("No authorization.")
|
||||
}
|
||||
|
||||
// ListConnectors list connectors for a user
|
||||
// Equivalent to Python's ConnectorService.list(current_user.id)
|
||||
func (s *ConnectorService) ListConnectors(userID string) (*ListConnectorsResponse, error) {
|
||||
|
||||
Reference in New Issue
Block a user