Go: set and list default models (#14191)

### What problem does this PR solve?

```
RAGFlow(user)> set default vlm "zhipu-ai" "ccc" "glm-4.6v-flash";
SUCCESS
RAGFlow(user)> list default models;
+--------+----------------+----------------+----------------+------------+
| enable | model_instance | model_name     | model_provider | model_type |
+--------+----------------+----------------+----------------+------------+
| true   | ccc            | glm-4.6v-flash | zhipu-ai       | llm        |
| true   | ccc            | glm-4.6v-flash | zhipu-ai       | image2text |
+--------+----------------+----------------+----------------+------------+
```

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2026-04-17 18:05:33 +08:00
committed by GitHub
parent 28d8b1c883
commit 94106646e7
8 changed files with 466 additions and 19 deletions

View File

@@ -26,6 +26,15 @@
],
"features": {}
},
{
"name": "glm-4.6v-Flash",
"max_tokens": 128000,
"model_types": [
"chat",
"image2text"
],
"features": {}
},
{
"name": "glm-4.5-x",
"max_tokens": 128000,

View File

@@ -248,6 +248,10 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) {
return c.UseModel(cmd)
case "show_current_model":
return c.ShowCurrentModel(cmd)
case "set_default_model":
return c.SetDefaultModel(cmd)
case "list_user_default_models":
return c.ListDefaultModels(cmd)
// Dataset, metadata commands
case "create_dataset_table":
return c.CreateDatasetInDocEngine(cmd)

View File

@@ -373,6 +373,75 @@ func (c *RAGFlowClient) ShowModel(cmd *Command) (ResponseIf, error) {
return &result, nil
}
func (c *RAGFlowClient) SetDefaultModel(cmd *Command) (ResponseIf, error) {
modeType, ok := cmd.Params["model_type"].(string)
if !ok {
return nil, fmt.Errorf("model_type not provided")
}
modelProvider, ok := cmd.Params["model_provider"].(string)
if !ok {
return nil, fmt.Errorf("model_provider not provided")
}
modelInstance, ok := cmd.Params["model_instance"].(string)
if !ok {
return nil, fmt.Errorf("model_instance not provided")
}
modelName, ok := cmd.Params["model_name"].(string)
if !ok {
return nil, fmt.Errorf("model_name not provided")
}
payload := map[string]interface{}{
"model_type": modeType,
"model_provider": modelProvider,
"model_instance": modelInstance,
"model_name": modelName,
}
resp, err := c.HTTPClient.Request("PATCH", "/models", true, "web", nil, payload)
if err != nil {
return nil, fmt.Errorf("failed to set default model: %w", err)
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("failed to set default model: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
}
var result SimpleResponse
if err = json.Unmarshal(resp.Body, &result); err != nil {
return nil, fmt.Errorf("failed to set default model: invalid JSON (%w)", err)
}
if result.Code != 0 {
return nil, fmt.Errorf("%s", result.Message)
}
result.Duration = resp.Duration
return &result, nil
}
func (c *RAGFlowClient) ListDefaultModels(cmd *Command) (ResponseIf, error) {
resp, err := c.HTTPClient.Request("GET", "/models", true, "web", nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to list default models: %w", err)
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("failed to list default models: HTTP %d, body: %s", resp.StatusCode, string(resp.Body))
}
var result CommonResponse
if err = json.Unmarshal(resp.Body, &result); err != nil {
return nil, fmt.Errorf("failed to list default models: invalid JSON (%w)", err)
}
if result.Code != 0 {
return nil, fmt.Errorf("%s", result.Message)
}
result.Duration = resp.Duration
return &result, nil
}
// readPassword reads password from terminal without echoing
func ReadPassword() (string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) {

View File

@@ -1597,35 +1597,49 @@ func (p *Parser) parseSetVariable() (*Command, error) {
func (p *Parser) parseSetDefault() (*Command, error) {
p.nextToken() // consume DEFAULT
var modelType, modelID string
var modelType, modelProvider, modelInstance, modelName string
var err error
switch p.curToken.Type {
case TokenLLM:
modelType = "llm_id"
modelType = "chat"
case TokenVLM:
modelType = "img2txt_id"
modelType = "image2text"
case TokenEmbedding:
modelType = "embd_id"
modelType = "embedding"
case TokenReranker:
modelType = "reranker_id"
modelType = "rerank"
case TokenASR:
modelType = "asr_id"
modelType = "asr"
case TokenTTS:
modelType = "tts_id"
modelType = "tts"
default:
return nil, fmt.Errorf("unknown model type: %s", p.curToken.Value)
}
p.nextToken()
id, err := p.parseQuotedString()
modelProvider, err = p.parseQuotedString()
if err != nil {
return nil, err
}
p.nextToken()
modelInstance, err = p.parseQuotedString()
if err != nil {
return nil, err
}
p.nextToken()
modelName, err = p.parseQuotedString()
if err != nil {
return nil, err
}
modelID = id
cmd := NewCommand("set_default_model")
cmd.Params["model_type"] = modelType
cmd.Params["model_id"] = modelID
cmd.Params["model_provider"] = modelProvider
cmd.Params["model_instance"] = modelInstance
cmd.Params["model_name"] = modelName
p.nextToken()
// Semicolon is optional for UNSET TOKEN
@@ -2601,7 +2615,6 @@ func (p *Parser) parseRemoveTags() (*Command, error) {
return cmd, nil
}
// parseRemoveChunk parses:
// - REMOVE CHUNKS 'chunk_id1', 'chunk_id2' FROM DOCUMENT 'doc_id';
// - REMOVE ALL CHUNKS FROM DOCUMENT 'doc_id';

View File

@@ -149,7 +149,7 @@ type Provider struct {
Tags string `json:"tags"`
URL string `json:"url"`
URLSuffix models.URLSuffix `json:"url_suffix"`
Models []Model `json:"models"`
Models []*Model `json:"models"`
ModelDriver models.ModelDriver
}
@@ -547,7 +547,7 @@ func (pm *ProviderManager) FindProvider(name string) *Provider {
func (pm *ProviderManager) findModel(provider *Provider, modelName string) *Model {
for i := range provider.Models {
if strings.EqualFold(provider.Models[i].Name, modelName) {
return &provider.Models[i]
return provider.Models[i]
}
}
return nil

View File

@@ -42,6 +42,81 @@ func NewTenantHandler(tenantService *service.TenantService, userService *service
}
}
func (h *TenantHandler) GetModels(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
defaultModels, err := h.tenantService.ListTenantDefaultModels(user.ID)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeExceptionError,
"message": err.Error(),
"data": false,
})
return
}
if defaultModels == nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeDataError,
"message": "No default models",
"data": nil,
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"message": "success",
"data": defaultModels,
})
}
type SetModelRequest struct {
ModelProvider string `json:"model_provider" binding:"required"`
ModelInstance string `json:"model_instance" binding:"required"`
ModelName string `json:"model_name" binding:"required"`
ModelType string `json:"model_type" binding:"required"`
}
func (h *TenantHandler) SetModels(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
// Parse request body (same as Python get_request_json())
var req SetModelRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"code": common.CodeBadRequest,
"data": nil,
"message": "Invalid request body: " + err.Error(),
})
return
}
err := h.tenantService.SetTenantDefaultModels(user.ID, req.ModelProvider, req.ModelInstance, req.ModelName, req.ModelType)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeExceptionError,
"message": err.Error(),
"data": false,
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"message": "success",
"data": nil,
})
}
// TenantInfo get tenant information
// @Summary Get Tenant Information
// @Description Get current user's tenant information (owner tenant)

View File

@@ -219,6 +219,12 @@ func (r *Router) Setup(engine *gin.Engine) {
provider.POST("/:provider_name/instances/:instance_name/models/:model_name", r.providerHandler.ChatToModel)
}
model := v1.Group("/models")
{
model.GET("/", r.tenantHandler.GetModels)
model.PATCH("/", r.tenantHandler.SetModels)
}
system := v1.Group("/system")
{
system.GET("/version", r.systemHandler.GetVersion)

View File

@@ -28,17 +28,27 @@ import (
// TenantService tenant service
type TenantService struct {
tenantDAO *dao.TenantDAO
userTenantDAO *dao.UserTenantDAO
docEngine engine.DocEngine
tenantDAO *dao.TenantDAO
userTenantDAO *dao.UserTenantDAO
modelProviderDAO *dao.TenantModelProviderDAO
modelInstanceDAO *dao.TenantModelInstanceDAO
modelDAO *dao.TenantModelDAO
modelGroupDAO *dao.TenantModelGroupDAO
modelGroupMappingDAO *dao.TenantModelGroupMappingDAO
docEngine engine.DocEngine
}
// NewTenantService create tenant service
func NewTenantService() *TenantService {
return &TenantService{
tenantDAO: dao.NewTenantDAO(),
userTenantDAO: dao.NewUserTenantDAO(),
docEngine: engine.Get(),
tenantDAO: dao.NewTenantDAO(),
userTenantDAO: dao.NewUserTenantDAO(),
modelProviderDAO: dao.NewTenantModelProviderDAO(),
modelInstanceDAO: dao.NewTenantModelInstanceDAO(),
modelDAO: dao.NewTenantModelDAO(),
modelGroupDAO: dao.NewTenantModelGroupDAO(),
modelGroupMappingDAO: dao.NewTenantModelGroupMappingDAO(),
docEngine: engine.Get(),
}
}
@@ -282,3 +292,264 @@ func (s *TenantService) DeleteMetadataInDocEngine(tenantID string) (common.Error
return common.CodeSuccess, nil
}
type ModelItem struct {
ModelProvider *string `json:"model_provider"`
ModelInstance *string `json:"model_instance"`
ModelName *string `json:"model_name"`
ModelType string `json:"model_type"`
Enable bool `json:"enable"`
}
type DefaultModelResponse struct {
Models []ModelItem `json:"models,omitempty"`
//TenantID string `json:"tenant_id"`
//ChatModelProvider *string `json:"chat_model_provider"`
//ChatModelInstance *string `json:"chat_model_instance"`
//ChatModelName *string `json:"chat_model_name"`
//ChatModelEnable bool `json:"chat_model_enable"`
//EmbeddingModelProvider *string `json:"embedding_model_provider"`
//EmbeddingModelInstance *string `json:"embedding_model_instance"`
//EmbeddingModelName *string `json:"embedding_model_name"`
//EmbeddingModelEnable bool `json:"embedding_model_enable"`
//RerankModelProvider *string `json:"rerank_model_provider"`
//RerankModelInstance *string `json:"rerank_model_instance"`
//RerankModelName *string `json:"rerank_model_name"`
//RerankModelEnable bool `json:"rerank_model_enable"`
//ASRModelProvider *string `json:"asr_model_provider"`
//ASRModelInstance *string `json:"asr_model_instance"`
//ASRModelName *string `json:"asr_model_name"`
//ASREnable bool `json:"asr_enable"`
//Image2TextModelProvider *string `json:"image2text_model_provider"`
//Image2TextModelInstance *string `json:"image2text_model_instance"`
//Image2TextModelName *string `json:"image2text_model_name"`
//Image2TextModelEnable bool `json:"image2text_model_enable"`
//TTSModelProvider *string `json:"tts_model_provider"`
//TTSModelInstance *string `json:"tts_model_instance"`
//TTSModelName *string `json:"tts_model_name"`
//TTSModelEnable bool `json:"tts_model_enable"`
}
func (s *TenantService) GetModelInfo(tenantID string, defaultModel string, modelType string) (*string, *string, *string, bool, error) {
// normally the model string is: modelName@instanceName@providerName, sometimes it's just modelName@providerName
// for the 1st case, parse defaultChatModel into three parts
defaultChatModelParts := strings.Split(defaultModel, "@")
var providerName *string
var instanceName *string
var modelName *string
if len(defaultChatModelParts) == 3 {
providerName = &defaultChatModelParts[2]
instanceName = &defaultChatModelParts[1]
modelName = &defaultChatModelParts[0]
} else if len(defaultChatModelParts) == 2 {
providerName = &defaultChatModelParts[1]
instanceName = new(string)
*instanceName = "default"
modelName = &defaultChatModelParts[0]
} else {
return nil, nil, nil, false, fmt.Errorf("invalid model string: %s", defaultModel)
}
// Check if the provider and instance exists
modelProvider, err := s.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, *providerName)
if err != nil {
return nil, nil, nil, false, err
}
modelInstance, err := s.modelInstanceDAO.GetByProviderIDAndInstanceName(modelProvider.ID, *instanceName)
if err != nil {
return nil, nil, nil, false, err
}
modelSchema, err := dao.GetModelProviderManager().GetModelByName(*providerName, *modelName)
if err != nil {
return nil, nil, nil, false, err
}
if !modelSchema.ModelTypeMap[modelType] {
return nil, nil, nil, false, fmt.Errorf("model %s isn't a chat model", *modelName)
}
var modelEntity *entity.TenantModel
modelEntity, err = s.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(modelProvider.ID, modelInstance.ID, *modelName)
if err != nil {
errString := err.Error()
if !strings.Contains(errString, "record not found") {
return nil, nil, nil, false, err
}
}
enable := modelEntity == nil
return providerName, instanceName, modelName, enable, nil
}
func (s *TenantService) ListTenantDefaultModels(userID string) ([]ModelItem, error) {
tenantInfos, err := s.tenantDAO.GetInfoByUserID(userID)
if err != nil {
return nil, err
}
if len(tenantInfos) == 0 {
return nil, nil // No tenant found (should not happen for valid user)
}
ownedTenant := tenantInfos[0]
var result []ModelItem
defaultChatModelProvider, defaultChatModelInstance, defaultChatModelName, defaultChatModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.LLMID, "chat")
if err == nil {
result = append(result, ModelItem{
ModelProvider: defaultChatModelProvider,
ModelInstance: defaultChatModelInstance,
ModelName: defaultChatModelName,
ModelType: "llm",
Enable: defaultChatModelEnable,
})
}
defaultEmbeddingModelProvider, defaultEmbeddingModelInstance, defaultEmbeddingModelName, defaultEmbeddingModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.EmbDID, "embedding")
if err == nil {
result = append(result, ModelItem{
ModelProvider: defaultEmbeddingModelProvider,
ModelInstance: defaultEmbeddingModelInstance,
ModelName: defaultEmbeddingModelName,
ModelType: "embedding",
Enable: defaultEmbeddingModelEnable,
})
}
defaultRerankModelProvider, defaultRerankModelInstance, defaultRerankModelName, defaultRerankModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.RerankID, "rerank")
if err == nil {
result = append(result, ModelItem{
ModelProvider: defaultRerankModelProvider,
ModelInstance: defaultRerankModelInstance,
ModelName: defaultRerankModelName,
ModelType: "rerank",
Enable: defaultRerankModelEnable,
})
}
defaultASRModelProvider, defaultASRModelInstance, defaultASRModelName, defaultASREnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.ASRID, "asr")
if err == nil {
result = append(result, ModelItem{
ModelProvider: defaultASRModelProvider,
ModelInstance: defaultASRModelInstance,
ModelName: defaultASRModelName,
ModelType: "asr",
Enable: defaultASREnable,
})
}
defaultImage2TextModelProvider, defaultImage2TextModelInstance, defaultImage2TextModelName, defaultImage2TextModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.Img2TxtID, "image2text")
if err == nil {
result = append(result, ModelItem{
ModelProvider: defaultImage2TextModelProvider,
ModelInstance: defaultImage2TextModelInstance,
ModelName: defaultImage2TextModelName,
ModelType: "image2text",
Enable: defaultImage2TextModelEnable,
})
}
if ownedTenant.TTSID == nil {
return result, nil
}
defaultTTSModelProvider, defaultTTSModelInstance, defaultTTSModelName, defaultTTSModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, *ownedTenant.TTSID, "tts")
if err == nil {
result = append(result, ModelItem{
ModelProvider: defaultTTSModelProvider,
ModelInstance: defaultTTSModelInstance,
ModelName: defaultTTSModelName,
ModelType: "tts",
Enable: defaultTTSModelEnable,
})
}
return result, nil
}
func (s *TenantService) checkModelAvailable(tenantID, providerName, instanceName, modelName, modelType string) error {
// Check if the provider and instance exists
modelProvider, err := s.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName)
if err != nil {
return err
}
modelInstance, err := s.modelInstanceDAO.GetByProviderIDAndInstanceName(modelProvider.ID, instanceName)
if err != nil {
return err
}
modelSchema, err := dao.GetModelProviderManager().GetModelByName(providerName, modelName)
if err != nil {
return err
}
if !modelSchema.ModelTypeMap[modelType] {
return fmt.Errorf("model %s isn't a chat model", modelName)
}
var modelEntity *entity.TenantModel
modelEntity, err = s.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(modelProvider.ID, modelInstance.ID, modelName)
if err != nil || modelEntity != nil {
var errString = err.Error()
if errString == "record not found" {
return nil
}
return fmt.Errorf("model %s isn't available", modelName)
}
return nil
}
func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInstance, modelName, modelType string) error {
tenantInfos, err := s.tenantDAO.GetInfoByUserID(userID)
if err != nil {
return err
}
if len(tenantInfos) == 0 {
return nil // No tenant found (should not happen for valid user)
}
ownedTenant := tenantInfos[0]
err = s.checkModelAvailable(ownedTenant.TenantID, modelProvider, modelInstance, modelName, modelType)
if err != nil {
return err
}
var modelTypeID string
if modelType == "chat" {
modelTypeID = "llm_id"
}
if modelType == "embedding" {
modelTypeID = "embd_id"
}
if modelType == "rerank" {
modelTypeID = "rerank_id"
}
if modelType == "asr" {
modelTypeID = "asr_id"
}
if modelType == "image2text" {
modelTypeID = "img2txt_id"
}
if modelType == "tts" {
modelTypeID = "tts_id"
}
if modelTypeID == "" {
return fmt.Errorf("model type %s is invalid", modelType)
}
defaultModel := fmt.Sprintf("%s@%s@%s", modelName, modelInstance, modelProvider)
err = s.tenantDAO.Update(ownedTenant.TenantID, map[string]interface{}{
modelTypeID: defaultModel,
})
return nil
}