mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
feat(go-cli): support batch model add/remove and optional embedding dimension (#15631)
## Summary This PR improves the Go CLI in two areas: 1. It adds batch model management support, allowing multiple models to be added or removed in a single command. 2. It makes the `dimension` argument optional for the `embed text` command. These changes keep the existing single-model and explicit-dimension behaviors compatible while making the CLI more convenient for common workflows. ## What Changed ### 1. Batch model add/remove support The CLI now supports operating on multiple model names provided in a single quoted string. Supported commands include: ``` add model 'x1 x2 x3' to provider 'vllm' instance 'test' with tokens 1024 chat think vision, token 2048 chat, token 1024 think vision; drop model 'x1 x2 x3' from 'vllm' 'test'; remove model 'x1 x2 x3' from 'vllm' 'test'; ``` For add model, each config segment after with is matched to the corresponding model name by position. Example mapping: - x1 -> tokens 1024, chat + vision, thinking=true - x2 -> tokens 2048, chat - x3 -> tokens 1024, vision, thinking=true The existing single-model syntax remains supported. ### 2. Optional embedding dimension Previously, the Go CLI required dimension to be explicitly provided for embed text. Before: embed text 'what is rag' 'who are you' with 'model@test@provider' dimension 8192; Now both forms are supported: embed text 'what is rag' 'who are you' with 'model@test@provider' dimension 8192; embed text 'what is rag' 'who are you' with 'model@test@provider'; When omitted, the CLI leaves dimension unset and relies on provider/backend behavior. ## Tests Added parser tests covering: - Multiple models with multiple config segments - Model type deduplication - Model/config count mismatch - Drop/remove multiple models - Optional embedding dimension parsing
This commit is contained in:
@@ -1442,7 +1442,8 @@ func (c *RAGFlowClient) DropProviderInstance(cmd *Command) (ResponseIf, error) {
|
||||
}
|
||||
|
||||
// DropInstanceModel deletes a provider instance, only works for local deployed model
|
||||
// DROP MODEL <name> FROM <provider_name> <instance_name>
|
||||
// DROP MODEL <name1 name2 name3> FROM <provider_name> <instance_name>
|
||||
// Remove MODEL <name1 name2 name3> FROM <provider_name> <instance_name>
|
||||
func (c *RAGFlowClient) DropInstanceModel(cmd *Command) (ResponseIf, error) {
|
||||
if c.ServerType != "user" {
|
||||
return nil, fmt.Errorf("this command is only allowed in USER mode")
|
||||
@@ -1458,13 +1459,13 @@ func (c *RAGFlowClient) DropInstanceModel(cmd *Command) (ResponseIf, error) {
|
||||
return nil, fmt.Errorf("provider name not provided")
|
||||
}
|
||||
|
||||
modelName, ok := cmd.Params["model_name"].(string)
|
||||
modelNames, ok := cmd.Params["model_names"].([]string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model name not provided")
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"models": []string{modelName},
|
||||
"models": modelNames,
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("/providers/%s/instances/%s/models", providerName, instanceName)
|
||||
@@ -2681,35 +2682,17 @@ func (c *RAGFlowClient) AddCustomModel(cmd *Command) (ResponseIf, error) {
|
||||
return nil, fmt.Errorf("instance name not provided")
|
||||
}
|
||||
|
||||
modelName, ok := cmd.Params["model_name"].(string)
|
||||
models, ok := cmd.Params["models"].([]map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model name not provided")
|
||||
}
|
||||
|
||||
// chat, vision, embedding, rerank, tts, asr, ocr
|
||||
modelTypes, ok := cmd.Params["model_types"].([]string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model type not provided")
|
||||
}
|
||||
|
||||
maxTokens, ok := cmd.Params["max_tokens"].(int)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("max tokens not provided")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("/providers/%s/instances/%s/models", providerName, instanceName)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"provider_name": providerName,
|
||||
"instance_name": instanceName,
|
||||
"model_name": modelName,
|
||||
"model_types": modelTypes,
|
||||
"max_tokens": maxTokens,
|
||||
}
|
||||
|
||||
supportThink, ok := cmd.Params["support_think"].(bool)
|
||||
if ok {
|
||||
payload["thinking"] = supportThink
|
||||
"models": models,
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Request("POST", url, "web", nil, payload)
|
||||
|
||||
@@ -841,6 +841,31 @@ func (p *Parser) parseAddProvider() (*Command, error) {
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseModelNames(raw string) ([]string, error) {
|
||||
modelNames := strings.Fields(raw)
|
||||
|
||||
if len(modelNames) == 0 {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(modelNames))
|
||||
for _, modelName := range modelNames {
|
||||
if _, ok := seen[modelName]; ok {
|
||||
return nil, fmt.Errorf("duplicate model name: %s", modelName)
|
||||
}
|
||||
seen[modelName] = struct{}{}
|
||||
}
|
||||
|
||||
return modelNames, nil
|
||||
}
|
||||
|
||||
type AddModelConfig struct {
|
||||
ModelName string
|
||||
ModelTypes []string
|
||||
MaxTokens int
|
||||
Thinking *bool
|
||||
}
|
||||
|
||||
// syntax: add model 'xxx' to provider 'vllm' instance 'test' with tokens 1024 chat think vision;
|
||||
func (p *Parser) parseAddModel() (*Command, error) {
|
||||
p.nextToken() // consume MODEL
|
||||
@@ -849,7 +874,11 @@ func (p *Parser) parseAddModel() (*Command, error) {
|
||||
return nil, fmt.Errorf("expected model name")
|
||||
}
|
||||
|
||||
modelName, err := p.parseQuotedString()
|
||||
rawModelNames, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
modelNames, err := p.parseModelNames(rawModelNames)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -890,77 +919,145 @@ func (p *Parser) parseAddModel() (*Command, error) {
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
i := 0
|
||||
var modelTypes []string
|
||||
var supportThink *bool = nil
|
||||
maxTokens := 0
|
||||
if p.curToken.Type == TokenWith {
|
||||
p.nextToken() // pass WITH
|
||||
optionsLoop:
|
||||
for {
|
||||
switch p.curToken.Type {
|
||||
case TokenThink:
|
||||
if supportThink != nil {
|
||||
return nil, fmt.Errorf("think model is already set")
|
||||
}
|
||||
supportThink = new(bool)
|
||||
p.nextToken()
|
||||
*supportThink = true
|
||||
case TokenVision:
|
||||
p.nextToken()
|
||||
modelTypes = append(modelTypes, "vision")
|
||||
case TokenChat:
|
||||
p.nextToken()
|
||||
modelTypes = append(modelTypes, "chat")
|
||||
case TokenEmbedding:
|
||||
p.nextToken()
|
||||
modelTypes = append(modelTypes, "embedding")
|
||||
case TokenRerank:
|
||||
p.nextToken()
|
||||
modelTypes = append(modelTypes, "rerank")
|
||||
case TokenOCR:
|
||||
p.nextToken()
|
||||
modelTypes = append(modelTypes, "ocr")
|
||||
case TokenDocParse:
|
||||
p.nextToken()
|
||||
modelTypes = append(modelTypes, "doc_parse")
|
||||
case TokenTTS:
|
||||
p.nextToken()
|
||||
modelTypes = append(modelTypes, "tts")
|
||||
case TokenASR:
|
||||
p.nextToken()
|
||||
modelTypes = append(modelTypes, "asr")
|
||||
case TokenTokens:
|
||||
p.nextToken() // pass TOKENS
|
||||
if maxTokens != 0 {
|
||||
return nil, fmt.Errorf("max tokens is already given %d", maxTokens)
|
||||
}
|
||||
if p.curToken.Type != TokenInteger {
|
||||
return nil, fmt.Errorf("expected integer")
|
||||
}
|
||||
maxTokens, err = p.parseNumber()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.nextToken() // consume
|
||||
case TokenSemicolon:
|
||||
p.nextToken()
|
||||
break optionsLoop // done
|
||||
default:
|
||||
// No more options to process
|
||||
break optionsLoop
|
||||
}
|
||||
|
||||
models := make([]map[string]any, 0, len(modelNames))
|
||||
if p.curToken.Type != TokenWith {
|
||||
return nil, fmt.Errorf("expected with")
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
A:
|
||||
for {
|
||||
if i >= len(modelNames) {
|
||||
return nil, fmt.Errorf("too many model configs: got more configs than model names")
|
||||
}
|
||||
switch p.curToken.Type {
|
||||
case TokenThink:
|
||||
if supportThink != nil {
|
||||
return nil, fmt.Errorf("think model is already set for model %s", modelNames[i])
|
||||
}
|
||||
value := true
|
||||
supportThink = &value
|
||||
p.nextToken()
|
||||
|
||||
case TokenVision:
|
||||
modelTypes = append(modelTypes, "vision")
|
||||
p.nextToken()
|
||||
|
||||
case TokenChat:
|
||||
modelTypes = append(modelTypes, "chat")
|
||||
p.nextToken()
|
||||
|
||||
case TokenEmbedding:
|
||||
modelTypes = append(modelTypes, "embedding")
|
||||
p.nextToken()
|
||||
|
||||
case TokenRerank:
|
||||
modelTypes = append(modelTypes, "rerank")
|
||||
p.nextToken()
|
||||
|
||||
case TokenOCR:
|
||||
modelTypes = append(modelTypes, "ocr")
|
||||
p.nextToken()
|
||||
|
||||
case TokenDocParse:
|
||||
modelTypes = append(modelTypes, "doc_parse")
|
||||
p.nextToken()
|
||||
|
||||
case TokenTTS:
|
||||
modelTypes = append(modelTypes, "tts")
|
||||
p.nextToken()
|
||||
|
||||
case TokenASR:
|
||||
modelTypes = append(modelTypes, "asr")
|
||||
p.nextToken()
|
||||
|
||||
case TokenToken, TokenTokens:
|
||||
p.nextToken()
|
||||
if maxTokens != 0 {
|
||||
return nil, fmt.Errorf("max tokens is already given %d for model %s", maxTokens, modelNames[i])
|
||||
}
|
||||
if p.curToken.Type != TokenInteger {
|
||||
return nil, fmt.Errorf("expected integer")
|
||||
}
|
||||
var err error
|
||||
maxTokens, err = p.parseNumber()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.nextToken() // consume number
|
||||
|
||||
case TokenComma, TokenSemicolon, TokenEOF:
|
||||
if len(modelTypes) == 0 {
|
||||
return nil, fmt.Errorf("model type is required for model %s", modelNames[i])
|
||||
}
|
||||
|
||||
seenTypes := make(map[string]struct{}, len(modelTypes))
|
||||
dedupedModelTypes := make([]string, 0, len(modelTypes))
|
||||
|
||||
for _, modelType := range modelTypes {
|
||||
modelType = strings.TrimSpace(modelType)
|
||||
if modelType == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := seenTypes[modelType]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
seenTypes[modelType] = struct{}{}
|
||||
dedupedModelTypes = append(dedupedModelTypes, modelType)
|
||||
}
|
||||
|
||||
modelTypes = dedupedModelTypes
|
||||
if len(modelTypes) == 0 {
|
||||
return nil, fmt.Errorf("model type is required for model %s", modelNames[i])
|
||||
}
|
||||
|
||||
model := map[string]any{
|
||||
"model_name": modelNames[i],
|
||||
"model_types": modelTypes,
|
||||
"max_tokens": maxTokens,
|
||||
}
|
||||
if supportThink != nil {
|
||||
model["thinking"] = *supportThink
|
||||
}
|
||||
|
||||
models = append(models, model)
|
||||
|
||||
i++
|
||||
modelTypes = nil
|
||||
supportThink = nil
|
||||
maxTokens = 0
|
||||
|
||||
if p.curToken.Type == TokenComma {
|
||||
p.nextToken()
|
||||
continue
|
||||
}
|
||||
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
break A
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected token type: %s", p.curToken.Value)
|
||||
}
|
||||
|
||||
}
|
||||
if len(models) != len(modelNames) {
|
||||
return nil, fmt.Errorf("model config count %d does not match model name count %d", len(models), len(modelNames))
|
||||
}
|
||||
|
||||
cmd := NewCommand("add_custom_model")
|
||||
cmd.Params["model_name"] = modelName
|
||||
cmd.Params["model_types"] = modelTypes
|
||||
cmd.Params["provider_name"] = providerName
|
||||
cmd.Params["instance_name"] = instanceName
|
||||
if supportThink != nil {
|
||||
cmd.Params["support_think"] = *supportThink
|
||||
}
|
||||
cmd.Params["max_tokens"] = maxTokens
|
||||
|
||||
cmd.Params["models"] = models
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
@@ -1085,6 +1182,8 @@ func (p *Parser) parseRemoveCommand() (*Command, error) {
|
||||
return p.parseRemoveTags()
|
||||
case TokenChunks, TokenAll:
|
||||
return p.parseRemoveChunk()
|
||||
case TokenModel:
|
||||
return p.parseRemoveInstanceModel()
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown REMOVE target: %s", p.curToken.Value)
|
||||
}
|
||||
@@ -1711,16 +1810,24 @@ func (p *Parser) parseDropInstance() (*Command, error) {
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseRemoveInstanceModel() (*Command, error) {
|
||||
return p.parseDropInstanceModel()
|
||||
}
|
||||
|
||||
// parseDropInstanceModel parses DROP MODEL <name> FROM <provider_name> <instance_name> command
|
||||
// Only works for local deployed model
|
||||
func (p *Parser) parseDropInstanceModel() (*Command, error) {
|
||||
p.nextToken() // consume MODEL
|
||||
|
||||
modelName, err := p.parseQuotedString()
|
||||
rawModelNames, err := p.parseQuotedString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expected instance name: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
p.nextToken()
|
||||
modelNames, err := p.parseModelNames(rawModelNames)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.nextToken() // consume model name
|
||||
|
||||
if p.curToken.Type != TokenFrom {
|
||||
return nil, fmt.Errorf("expected FROM")
|
||||
@@ -1742,7 +1849,7 @@ func (p *Parser) parseDropInstanceModel() (*Command, error) {
|
||||
cmd := NewCommand("drop_instance_model")
|
||||
cmd.Params["instance_name"] = instanceName
|
||||
cmd.Params["provider_name"] = providerName
|
||||
cmd.Params["model_name"] = modelName
|
||||
cmd.Params["model_names"] = modelNames
|
||||
|
||||
p.nextToken()
|
||||
// Semicolon is optional
|
||||
@@ -2827,21 +2934,35 @@ textLoop:
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
if p.curToken.Type != TokenDimension {
|
||||
return nil, fmt.Errorf("expected DIMENSION")
|
||||
}
|
||||
p.nextToken() // consume WITH
|
||||
dimension := 0
|
||||
if p.curToken.Type == TokenDimension {
|
||||
p.nextToken() // consume DIMENSION
|
||||
|
||||
dimension, err := p.parseNumber()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if p.curToken.Type != TokenInteger {
|
||||
return nil, fmt.Errorf("expected integer after DIMENSION")
|
||||
}
|
||||
|
||||
var err error
|
||||
dimension, err = p.parseNumber()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.nextToken()
|
||||
}
|
||||
|
||||
if p.curToken.Type == TokenSemicolon {
|
||||
p.nextToken()
|
||||
}
|
||||
if p.curToken.Type != TokenEOF {
|
||||
return nil, fmt.Errorf("unexpected token after embed command: %s", p.curToken.Value)
|
||||
}
|
||||
p.nextToken()
|
||||
|
||||
cmd := NewCommand("embed_user_text")
|
||||
cmd.Params["composite_model_name"] = compositeModelName
|
||||
cmd.Params["texts"] = texts
|
||||
cmd.Params["dimension"] = dimension
|
||||
if dimension > 0 {
|
||||
cmd.Params["dimension"] = dimension
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user