mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-05 19:08:38 +08:00
Go: add drop instance models (#14485)
### What problem does this PR solve? 1. Add a command to drop a model from a provider instance. 2. Fix an issue where dropping an instance did not also drop its models. ### Type of change - [x] New Feature (non-breaking change which adds functionality) Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@@ -700,8 +700,6 @@ func (p *Parser) parseAdminDropCommand() (*Command, error) {
|
||||
return p.parseDropUser()
|
||||
case TokenRole:
|
||||
return p.parseDropRole()
|
||||
case TokenModel:
|
||||
return p.parseDropModelProvider()
|
||||
case TokenDataset:
|
||||
return p.parseDropDataset()
|
||||
case TokenChat:
|
||||
|
||||
@@ -242,6 +242,8 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) {
|
||||
return c.AlterProviderInstance(cmd)
|
||||
case "drop_provider_instance":
|
||||
return c.DropProviderInstance(cmd)
|
||||
case "drop_instance_model":
|
||||
return c.DropInstanceModel(cmd)
|
||||
case "enable_model":
|
||||
return c.EnableOrDisableModel(cmd, "enable")
|
||||
case "disable_model":
|
||||
|
||||
@@ -1383,6 +1383,56 @@ func (c *RAGFlowClient) DropProviderInstance(cmd *Command) (ResponseIf, error) {
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// DropInstanceModel deletes a provider instance, only works for local deployed model
|
||||
// DROP MODEL <name> FROM <provider_name> <instance_name>
|
||||
func (c *RAGFlowClient) DropInstanceModel(cmd *Command) (ResponseIf, error) {
|
||||
if c.ServerType != "user" {
|
||||
return nil, fmt.Errorf("this command is only allowed in USER mode")
|
||||
}
|
||||
|
||||
instanceName, ok := cmd.Params["instance_name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("instance name not provided")
|
||||
}
|
||||
|
||||
providerName, ok := cmd.Params["provider_name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("provider name not provided")
|
||||
}
|
||||
|
||||
modelName, ok := cmd.Params["model_name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model name not provided")
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"models": []string{modelName},
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("/providers/%s/instances/%s/models", providerName, instanceName)
|
||||
|
||||
resp, err := c.HTTPClient.Request("DELETE", url, true, "web", nil, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to drop instance: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to drop instance: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
|
||||
var result SimpleResponse
|
||||
if err = json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return nil, fmt.Errorf("drop instance failed: invalid JSON (%w)", err)
|
||||
}
|
||||
|
||||
if result.Code != 0 {
|
||||
return nil, fmt.Errorf("%s", result.Message)
|
||||
}
|
||||
|
||||
result.Duration = resp.Duration
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *RAGFlowClient) ListInstanceModels(cmd *Command) (ResponseIf, error) {
|
||||
if c.ServerType != "user" {
|
||||
return nil, fmt.Errorf("this command is only allowed in USER mode")
|
||||
@@ -1722,7 +1772,7 @@ func (c *RAGFlowClient) AddCustomModel(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
|
||||
// chat, vision, embedding, rerank, tts, asr, ocr
|
||||
modelType, ok := cmd.Params["model_type"].(string)
|
||||
modelTypes, ok := cmd.Params["model_types"].([]string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model type not provided")
|
||||
}
|
||||
@@ -1738,7 +1788,7 @@ func (c *RAGFlowClient) AddCustomModel(cmd *Command) (ResponseIf, error) {
|
||||
"provider_name": providerName,
|
||||
"instance_name": instanceName,
|
||||
"model_name": modelName,
|
||||
"model_type": modelType,
|
||||
"model_types": modelTypes,
|
||||
"max_tokens": maxTokens,
|
||||
}
|
||||
|
||||
|
||||
@@ -772,7 +772,7 @@ func (p *Parser) parseAddModel() (*Command, error) {
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
modelType := ""
|
||||
var modelTypes []string
|
||||
var supportThink *bool = nil
|
||||
maxTokens := 0
|
||||
if p.curToken.Type == TokenWith {
|
||||
@@ -789,46 +789,25 @@ func (p *Parser) parseAddModel() (*Command, error) {
|
||||
*supportThink = true
|
||||
case TokenVision:
|
||||
p.nextToken()
|
||||
if modelType != "" {
|
||||
return nil, fmt.Errorf("model type is %s, attempt to change to vision", modelType)
|
||||
}
|
||||
modelType = "vision"
|
||||
modelTypes = append(modelTypes, "vision")
|
||||
case TokenChat:
|
||||
p.nextToken()
|
||||
if modelType != "" {
|
||||
return nil, fmt.Errorf("model type is %s, attempt to change to chat", modelType)
|
||||
}
|
||||
modelType = "chat"
|
||||
modelTypes = append(modelTypes, "chat")
|
||||
case TokenEmbedding:
|
||||
if modelType != "" {
|
||||
return nil, fmt.Errorf("model type is %s, attempt to change to embedding", modelType)
|
||||
}
|
||||
p.nextToken()
|
||||
modelType = "embedding"
|
||||
modelTypes = append(modelTypes, "embedding")
|
||||
case TokenRerank:
|
||||
if modelType != "" {
|
||||
return nil, fmt.Errorf("model type is %s, attempt to change to rerank", modelType)
|
||||
}
|
||||
p.nextToken()
|
||||
modelType = "rerank"
|
||||
modelTypes = append(modelTypes, "rerank")
|
||||
case TokenOCR:
|
||||
if modelType != "" {
|
||||
return nil, fmt.Errorf("model type is %s, attempt to change to OCR", modelType)
|
||||
}
|
||||
p.nextToken()
|
||||
modelType = "ocr"
|
||||
modelTypes = append(modelTypes, "ocr")
|
||||
case TokenTTS:
|
||||
if modelType != "" {
|
||||
return nil, fmt.Errorf("model type is %s, attempt to change to TTS", modelType)
|
||||
}
|
||||
p.nextToken()
|
||||
modelType = "tts"
|
||||
modelTypes = append(modelTypes, "tts")
|
||||
case TokenASR:
|
||||
if modelType != "" {
|
||||
return nil, fmt.Errorf("model type is %s, attempt to change to ASR", modelType)
|
||||
}
|
||||
p.nextToken()
|
||||
modelType = "asr"
|
||||
modelTypes = append(modelTypes, "asr")
|
||||
case TokenTokens:
|
||||
p.nextToken() // pass TOKENS
|
||||
if maxTokens != 0 {
|
||||
@@ -854,7 +833,7 @@ func (p *Parser) parseAddModel() (*Command, error) {
|
||||
|
||||
cmd := NewCommand("add_custom_model")
|
||||
cmd.Params["model_name"] = modelName
|
||||
cmd.Params["model_type"] = modelType
|
||||
cmd.Params["model_types"] = modelTypes
|
||||
cmd.Params["provider_name"] = providerName
|
||||
cmd.Params["instance_name"] = instanceName
|
||||
if supportThink != nil {
|
||||
@@ -862,12 +841,6 @@ func (p *Parser) parseAddModel() (*Command, error) {
|
||||
}
|
||||
cmd.Params["max_tokens"] = maxTokens
|
||||
|
||||
if modelType != "chat" && modelType != "vision" {
|
||||
if supportThink != nil && *supportThink {
|
||||
return nil, fmt.Errorf("think not supported for model type %s", modelType)
|
||||
}
|
||||
}
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
@@ -951,8 +924,6 @@ func (p *Parser) parseDropCommand() (*Command, error) {
|
||||
return p.parseDropUser()
|
||||
case TokenRole:
|
||||
return p.parseDropRole()
|
||||
case TokenModel:
|
||||
return p.parseDropModelProvider()
|
||||
case TokenDataset:
|
||||
return p.parseDropDataset()
|
||||
case TokenChat:
|
||||
@@ -965,6 +936,8 @@ func (p *Parser) parseDropCommand() (*Command, error) {
|
||||
return p.parseDropMetadataTable()
|
||||
case TokenInstance:
|
||||
return p.parseDropInstance()
|
||||
case TokenModel:
|
||||
return p.parseDropInstanceModel()
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown DROP target: %s", p.curToken.Value)
|
||||
}
|
||||
@@ -1099,29 +1072,6 @@ func (p *Parser) parseDropRole() (*Command, error) {
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseDropModelProvider() (*Command, error) {
|
||||
p.nextToken() // consume MODEL
|
||||
if p.curToken.Type != TokenProvider {
|
||||
return nil, fmt.Errorf("expected PROVIDER")
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
providerName, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cmd := NewCommand("drop_model_provider")
|
||||
cmd.Params["provider_name"] = providerName
|
||||
|
||||
p.nextToken()
|
||||
// Semicolon is optional for UNSET TOKEN
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// parseDeleteProvider parses DELETE PROVIDER <name> command
|
||||
func (p *Parser) parseDeleteProvider() (*Command, error) {
|
||||
p.nextToken() // consume PROVIDER
|
||||
@@ -1610,6 +1560,47 @@ func (p *Parser) parseDropInstance() (*Command, error) {
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// parseDropInstanceModel parses DROP MODEL <name> FROM <provider_name> <instance_name> command
|
||||
// Only works for local deployed model
|
||||
func (p *Parser) parseDropInstanceModel() (*Command, error) {
|
||||
p.nextToken() // consume MODEL
|
||||
|
||||
modelName, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expected instance name: %w", err)
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
if p.curToken.Type != TokenFrom {
|
||||
return nil, fmt.Errorf("expected FROM")
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
providerName, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expected provider name after FROM PROVIDER: %w", err)
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
instanceName, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expected instance name after provider name: %w", err)
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
cmd := NewCommand("drop_instance_model")
|
||||
cmd.Params["instance_name"] = instanceName
|
||||
cmd.Params["provider_name"] = providerName
|
||||
cmd.Params["model_name"] = modelName
|
||||
|
||||
p.nextToken()
|
||||
// Semicolon is optional
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseGrantCommand() (*Command, error) {
|
||||
p.nextToken() // consume GRANT
|
||||
|
||||
|
||||
@@ -37,6 +37,16 @@ func (dao *TenantModelDAO) DeleteByModelID(modelID string) (int64, error) {
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
func (dao *TenantModelDAO) DeleteByProviderIDAndInstanceID(provideID, instanceID string) (int64, error) {
|
||||
result := DB.Unscoped().Where("provider_id = ? AND instance_id = ?", provideID, instanceID).Delete(&entity.TenantModel{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
func (dao *TenantModelDAO) DeleteByProviderIDAndInstanceIDAndModelName(provideID, instanceID, modelName string) (int64, error) {
|
||||
result := DB.Unscoped().Where("provider_id = ? AND instance_id = ? AND model_name = ?", provideID, instanceID, modelName).Delete(&entity.TenantModel{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// GetByID get tenant model by primary key (id)
|
||||
func (dao *TenantModelDAO) GetByID(id string) (*entity.TenantModel, error) {
|
||||
var model entity.TenantModel
|
||||
|
||||
@@ -50,6 +50,7 @@ type URLSuffix struct {
|
||||
|
||||
type ChatConfig struct {
|
||||
Stream *bool
|
||||
Vision *bool
|
||||
Thinking *bool
|
||||
MaxTokens *int
|
||||
Temperature *float64
|
||||
|
||||
@@ -682,7 +682,7 @@ func (h *ProviderHandler) AddCustomModel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.ModelType == "" {
|
||||
if req.ModelTypes == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model type is required",
|
||||
@@ -707,6 +707,54 @@ func (h *ProviderHandler) AddCustomModel(c *gin.Context) {
|
||||
|
||||
}
|
||||
|
||||
type DropInstanceModelRequest struct {
|
||||
Models []string `json:"models" binding:"required"`
|
||||
}
|
||||
|
||||
func (h *ProviderHandler) DropInstanceModels(c *gin.Context) {
|
||||
providerName := c.Param("provider_name")
|
||||
if providerName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
instanceName := c.Param("instance_name")
|
||||
if instanceName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Instance name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req DropInstanceModelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userID := c.GetString("user_id")
|
||||
|
||||
_, err := h.modelProviderService.DropInstanceModels(providerName, instanceName, userID, req.Models)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": common.CodeServerError,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
type ChatToModelRequest struct {
|
||||
ProviderName *string `json:"provider_name"`
|
||||
InstanceName *string `json:"instance_name"`
|
||||
@@ -768,6 +816,7 @@ func (h *ProviderHandler) ChatToModel(c *gin.Context) {
|
||||
chatConfig := models.ChatConfig{
|
||||
Thinking: &req.Thinking,
|
||||
Stream: &req.Stream,
|
||||
Vision: nil,
|
||||
Stop: &[]string{},
|
||||
DoSample: nil,
|
||||
MaxTokens: nil,
|
||||
|
||||
@@ -219,6 +219,7 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
provider.GET("/:provider_name/instances/:instance_name/models", r.providerHandler.ListInstanceModels)
|
||||
provider.PATCH("/:provider_name/instances/:instance_name/models/*model_name", r.providerHandler.EnableOrDisableModel)
|
||||
provider.POST("/:provider_name/instances/:instance_name/models", r.providerHandler.AddCustomModel)
|
||||
provider.DELETE("/:provider_name/instances/:instance_name/models", r.providerHandler.DropInstanceModels)
|
||||
v1.POST("/chat/completions", r.providerHandler.ChatToModel)
|
||||
}
|
||||
|
||||
|
||||
@@ -478,7 +478,22 @@ func (m *ModelProviderService) DropProviderInstances(providerName, userID string
|
||||
}
|
||||
|
||||
for _, instanceName := range instances {
|
||||
count, err := m.modelInstanceDAO.DeleteByProviderIDAndInstanceName(provider.ID, instanceName)
|
||||
// Get model instance
|
||||
var tenantModelInstance *entity.TenantModelInstance
|
||||
tenantModelInstance, err = m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName)
|
||||
if err != nil {
|
||||
return common.CodeServerError, err
|
||||
}
|
||||
|
||||
// Delete all models of this instance
|
||||
var count int64 = 0
|
||||
count, err = m.modelDAO.DeleteByProviderIDAndInstanceID(provider.ID, tenantModelInstance.ID)
|
||||
if err != nil {
|
||||
return common.CodeServerError, err
|
||||
}
|
||||
|
||||
// Delete model instance
|
||||
count, err = m.modelInstanceDAO.DeleteByProviderIDAndInstanceName(provider.ID, instanceName)
|
||||
if err != nil {
|
||||
return common.CodeServerError, err
|
||||
}
|
||||
@@ -491,6 +506,48 @@ func (m *ModelProviderService) DropProviderInstances(providerName, userID string
|
||||
return common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (m *ModelProviderService) DropInstanceModels(providerName, instanceName, userID string, models []string) (common.ErrorCode, error) {
|
||||
|
||||
// Get tenant ID from user
|
||||
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
|
||||
if err != nil {
|
||||
return common.CodeServerError, err
|
||||
}
|
||||
|
||||
if len(tenants) == 0 {
|
||||
return common.CodeNotFound, errors.New("user has no tenants")
|
||||
}
|
||||
|
||||
tenantID := tenants[0].TenantID
|
||||
|
||||
// Check if provider exists
|
||||
provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName)
|
||||
if err != nil {
|
||||
return common.CodeServerError, err
|
||||
}
|
||||
|
||||
var modelInstance *entity.TenantModelInstance
|
||||
modelInstance, err = m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName)
|
||||
if err != nil {
|
||||
return common.CodeServerError, err
|
||||
}
|
||||
|
||||
for _, modelName := range models {
|
||||
// Delete all models of this instance
|
||||
var count int64 = 0
|
||||
count, err = m.modelDAO.DeleteByProviderIDAndInstanceIDAndModelName(provider.ID, modelInstance.ID, modelName)
|
||||
if err != nil {
|
||||
return common.CodeServerError, err
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
return common.CodeNotFound, fmt.Errorf("model: %s not found", modelName)
|
||||
}
|
||||
}
|
||||
|
||||
return common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func (m *ModelProviderService) ListInstanceModels(providerName, instanceName, userID string) ([]map[string]interface{}, error) {
|
||||
// Get tenant ID from user
|
||||
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
|
||||
@@ -693,6 +750,9 @@ func (m *ModelProviderService) ChatToModel(providerName, instanceName, modelName
|
||||
apiConfig.Region = &region
|
||||
apiConfig.ApiKey = &instance.APIKey
|
||||
|
||||
modelTypes := extra["model_types"]
|
||||
println(modelTypes)
|
||||
|
||||
modelConfig.ModelClass = &providerInfo.Class
|
||||
|
||||
newURL := map[string]string{
|
||||
@@ -891,12 +951,12 @@ func (m *ModelProviderService) GetChatModel(tenantID, compositeModelName string)
|
||||
}
|
||||
|
||||
type AddCustomModelRequest struct {
|
||||
ProviderName string `json:"provider_name"`
|
||||
InstanceName string `json:"instance_name"`
|
||||
ModelName string `json:"model_name"`
|
||||
ModelType string `json:"model_type"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Thinking *bool `json:"thinking"`
|
||||
ProviderName string `json:"provider_name"`
|
||||
InstanceName string `json:"instance_name"`
|
||||
ModelName string `json:"model_name"`
|
||||
ModelTypes []string `json:"model_types"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Thinking *bool `json:"thinking"`
|
||||
}
|
||||
|
||||
func (m *ModelProviderService) AddCustomModel(request *AddCustomModelRequest, userID string) (common.ErrorCode, error) {
|
||||
@@ -938,6 +998,7 @@ func (m *ModelProviderService) AddCustomModel(request *AddCustomModelRequest, us
|
||||
if request.Thinking != nil {
|
||||
extra["thinking"] = *request.Thinking
|
||||
}
|
||||
extra["model_types"] = request.ModelTypes
|
||||
// convert extra to string
|
||||
extraByte, err := json.Marshal(extra)
|
||||
if err != nil {
|
||||
@@ -948,7 +1009,7 @@ func (m *ModelProviderService) AddCustomModel(request *AddCustomModelRequest, us
|
||||
model := &entity.TenantModel{
|
||||
ID: modelID,
|
||||
ModelName: request.ModelName,
|
||||
ModelType: request.ModelType,
|
||||
ModelType: request.ModelTypes[0],
|
||||
ProviderID: provider.ID,
|
||||
InstanceID: instance.ID,
|
||||
Status: "active",
|
||||
|
||||
Reference in New Issue
Block a user