feat(go-cli): support add models with embedding type (#16020)

### What problem does this PR solve?

This PR enhances the CLI parser to support dimension configurations for
custom embedding models. Users can now specify the maximum dimension and
other supported dimensions directly after the embedding keyword.

```
add model 'x1 x2 x3 x4 x5' to provider 'vllm' instance 'test' with 
tokens 1024 chat think vision, 
token 2048 chat, 
token 1024 think vision,
token 0 embedding 2048 64 1024 2048,
token 0 embedding 2048;
```
- The first integer following embedding represents the max_dimension.
- Any subsequent integers represent specific alternative dimensions.
- If no subsequent integers are provided, dimensions defaults to empty,
indicating all sizes under max_dimension are supported.
This commit is contained in:
Hz_
2026-06-16 12:53:43 +08:00
committed by GitHub
parent 3d7b45bbd7
commit 0c92a38055
2 changed files with 173 additions and 7 deletions

View File

@@ -889,13 +889,6 @@ func (p *Parser) parseModelNames(raw string) ([]string, error) {
return modelNames, nil
}
type AddModelConfig struct {
ModelName string
ModelTypes []string
MaxTokens int
Thinking *bool
}
// syntax: add model 'xxx' to provider 'vllm' instance 'test' with tokens 1024 chat think vision;
func (p *Parser) parseAddModel() (*Command, error) {
p.nextToken() // consume MODEL
@@ -953,6 +946,8 @@ func (p *Parser) parseAddModel() (*Command, error) {
var modelTypes []string
var supportThink *bool = nil
maxTokens := 0
var maxDimension *int = nil
var dimensions []int = nil
models := make([]map[string]any, 0, len(modelNames))
if p.curToken.Type != TokenWith {
@@ -985,6 +980,24 @@ A:
case TokenEmbedding:
modelTypes = append(modelTypes, "embedding")
p.nextToken()
if p.curToken.Type == TokenInteger {
val, err := p.parseNumber()
if err != nil {
return nil, err
}
maxDimension = &val
p.nextToken()
dimensions = make([]int, 0)
for p.curToken.Type == TokenInteger {
dim, err := p.parseNumber()
if err != nil {
return nil, err
}
dimensions = append(dimensions, int(dim))
p.nextToken()
}
}
case TokenRerank:
modelTypes = append(modelTypes, "rerank")
@@ -1057,12 +1070,19 @@ A:
model["thinking"] = *supportThink
}
if maxDimension != nil {
model["max_dimension"] = *maxDimension
model["dimensions"] = dimensions
}
models = append(models, model)
i++
modelTypes = nil
supportThink = nil
maxTokens = 0
maxDimension = nil
dimensions = nil
if p.curToken.Type == TokenComma {
p.nextToken()

View File

@@ -0,0 +1,146 @@
package cli
import (
"reflect"
"testing"
)
func TestParseAddModelWithDimensions(t *testing.T) {
tests := []struct {
name string
input string
expected *Command
wantErr bool
}{
{
name: "Add model with detailed embedding dimensions",
input: "add model 'x1 x2 x3 x4 x5' to provider 'vllm' instance 'test' with tokens 1024 chat think vision, token 2048 chat, token 1024 think vision, token 0 embedding 2048 64 1024 2048, token 0 embedding 2048;",
expected: &Command{
Type: "add_custom_model",
Params: map[string]interface{}{
"provider_name": "vllm",
"instance_name": "test",
"models": []map[string]interface{}{
{
"model_name": "x1",
"model_types": []string{"chat", "vision"},
"max_tokens": 1024,
"thinking": true,
},
{
"model_name": "x2",
"model_types": []string{"chat"},
"max_tokens": 2048,
},
{
"model_name": "x3",
"model_types": []string{"vision"},
"max_tokens": 1024,
"thinking": true,
},
{
"model_name": "x4",
"model_types": []string{"embedding"},
"max_tokens": 0,
"max_dimension": 2048,
"dimensions": []int{64, 1024, 2048},
},
{
"model_name": "x5",
"model_types": []string{"embedding"},
"max_tokens": 0,
"max_dimension": 2048,
"dimensions": []int{},
},
},
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := NewParser(tt.input)
cmd, err := p.Parse(APIMode)
if (err != nil) != tt.wantErr {
t.Fatalf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
return
}
if cmd.Type != tt.expected.Type {
t.Errorf("Command Type = %v, expected = %v", cmd.Type, tt.expected.Type)
}
// Validate provider name
gotProvider, _ := cmd.Params["provider_name"].(string)
expectedProvider, _ := tt.expected.Params["provider_name"].(string)
if gotProvider != expectedProvider {
t.Errorf("provider_name = %v, expected = %v", gotProvider, expectedProvider)
}
// Validate instance name
gotInstance, _ := cmd.Params["instance_name"].(string)
expectedInstance, _ := tt.expected.Params["instance_name"].(string)
if gotInstance != expectedInstance {
t.Errorf("instance_name = %v, expected = %v", gotInstance, expectedInstance)
}
// Validate models
gotModels, ok1 := cmd.Params["models"].([]map[string]interface{})
if !ok1 {
// Try another type just in case type conversion differs
gotModelsAny, okAny := cmd.Params["models"].([]map[string]any)
if okAny {
gotModels = make([]map[string]interface{}, len(gotModelsAny))
for idx, val := range gotModelsAny {
gotModels[idx] = val
}
ok1 = true
}
}
expectedModels, _ := tt.expected.Params["models"].([]map[string]interface{})
if !ok1 {
t.Fatalf("models param not found or has incorrect type: %T", cmd.Params["models"])
}
if len(gotModels) != len(expectedModels) {
t.Fatalf("len(models) = %d, expected = %d", len(gotModels), len(expectedModels))
}
for idx := range gotModels {
gotModel := gotModels[idx]
expectedModel := expectedModels[idx]
if gotModel["model_name"] != expectedModel["model_name"] {
t.Errorf("model[%d].model_name = %v, expected = %v", idx, gotModel["model_name"], expectedModel["model_name"])
}
if !reflect.DeepEqual(gotModel["model_types"], expectedModel["model_types"]) {
t.Errorf("model[%d].model_types = %v, expected = %v", idx, gotModel["model_types"], expectedModel["model_types"])
}
if gotModel["max_tokens"] != expectedModel["max_tokens"] {
t.Errorf("model[%d].max_tokens = %v, expected = %v", idx, gotModel["max_tokens"], expectedModel["max_tokens"])
}
if gotModel["thinking"] != expectedModel["thinking"] {
t.Errorf("model[%d].thinking = %v, expected = %v", idx, gotModel["thinking"], expectedModel["thinking"])
}
if gotModel["max_dimension"] != expectedModel["max_dimension"] {
t.Errorf("model[%d].max_dimension = %v, expected = %v", idx, gotModel["max_dimension"], expectedModel["max_dimension"])
}
if expectedModel["dimensions"] != nil {
if !reflect.DeepEqual(gotModel["dimensions"], expectedModel["dimensions"]) {
t.Errorf("model[%d].dimensions = %v, expected = %v", idx, gotModel["dimensions"], expectedModel["dimensions"])
}
}
}
})
}
}