From 0c92a38055b8256cb0bc56544c4d442fe9b8067f Mon Sep 17 00:00:00 2001 From: Hz_ Date: Tue, 16 Jun 2026 12:53:43 +0800 Subject: [PATCH] feat(go-cli): support add models with embedding type (#16020) ### What problem does this PR solve? This PR enhances the CLI parser to support dimension configurations for custom embedding models. Users can now specify the maximum dimension and other supported dimensions directly after the `embedding` keyword. ``` add model 'x1 x2 x3 x4 x5' to provider 'vllm' instance 'test' with tokens 1024 chat think vision, token 2048 chat, token 1024 think vision, token 0 embedding 2048 64 1024 2048, token 0 embedding 2048; ``` - The first integer following `embedding` is the `max_dimension`. - Any subsequent integers are the specific supported sizes, stored as `dimensions`. - If no subsequent integers are provided, `dimensions` defaults to an empty list, indicating that all sizes under `max_dimension` are supported. --- internal/cli/user_parser.go | 34 +++++-- internal/cli/user_parser_test.go | 146 +++++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+), 7 deletions(-) create mode 100644 internal/cli/user_parser_test.go diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go index c9c8d29d35..c98b2d5f10 100644 --- a/internal/cli/user_parser.go +++ b/internal/cli/user_parser.go @@ -889,13 +889,6 @@ func (p *Parser) parseModelNames(raw string) ([]string, error) { return modelNames, nil } -type AddModelConfig struct { - ModelName string - ModelTypes []string - MaxTokens int - Thinking *bool -} - // syntax: add model 'xxx' to provider 'vllm' instance 'test' with tokens 1024 chat think vision; func (p *Parser) parseAddModel() (*Command, error) { p.nextToken() // consume MODEL @@ -953,6 +946,8 @@ func (p *Parser) parseAddModel() (*Command, error) { var modelTypes []string var supportThink *bool = nil maxTokens := 0 + var maxDimension *int = nil + var dimensions []int = nil models := make([]map[string]any, 0, len(modelNames)) if p.curToken.Type != 
TokenWith { @@ -985,6 +980,24 @@ A: case TokenEmbedding: modelTypes = append(modelTypes, "embedding") p.nextToken() + if p.curToken.Type == TokenInteger { + val, err := p.parseNumber() + if err != nil { + return nil, err + } + maxDimension = &val + p.nextToken() + + dimensions = make([]int, 0) + for p.curToken.Type == TokenInteger { + dim, err := p.parseNumber() + if err != nil { + return nil, err + } + dimensions = append(dimensions, int(dim)) + p.nextToken() + } + } case TokenRerank: modelTypes = append(modelTypes, "rerank") @@ -1057,12 +1070,19 @@ A: model["thinking"] = *supportThink } + if maxDimension != nil { + model["max_dimension"] = *maxDimension + model["dimensions"] = dimensions + } + models = append(models, model) i++ modelTypes = nil supportThink = nil maxTokens = 0 + maxDimension = nil + dimensions = nil if p.curToken.Type == TokenComma { p.nextToken() diff --git a/internal/cli/user_parser_test.go b/internal/cli/user_parser_test.go new file mode 100644 index 0000000000..c2b47e8c34 --- /dev/null +++ b/internal/cli/user_parser_test.go @@ -0,0 +1,146 @@ +package cli + +import ( + "reflect" + "testing" +) + +func TestParseAddModelWithDimensions(t *testing.T) { + tests := []struct { + name string + input string + expected *Command + wantErr bool + }{ + { + name: "Add model with detailed embedding dimensions", + input: "add model 'x1 x2 x3 x4 x5' to provider 'vllm' instance 'test' with tokens 1024 chat think vision, token 2048 chat, token 1024 think vision, token 0 embedding 2048 64 1024 2048, token 0 embedding 2048;", + expected: &Command{ + Type: "add_custom_model", + Params: map[string]interface{}{ + "provider_name": "vllm", + "instance_name": "test", + "models": []map[string]interface{}{ + { + "model_name": "x1", + "model_types": []string{"chat", "vision"}, + "max_tokens": 1024, + "thinking": true, + }, + { + "model_name": "x2", + "model_types": []string{"chat"}, + "max_tokens": 2048, + }, + { + "model_name": "x3", + "model_types": []string{"vision"}, + 
"max_tokens": 1024, + "thinking": true, + }, + { + "model_name": "x4", + "model_types": []string{"embedding"}, + "max_tokens": 0, + "max_dimension": 2048, + "dimensions": []int{64, 1024, 2048}, + }, + { + "model_name": "x5", + "model_types": []string{"embedding"}, + "max_tokens": 0, + "max_dimension": 2048, + "dimensions": []int{}, + }, + }, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := NewParser(tt.input) + cmd, err := p.Parse(APIMode) + if (err != nil) != tt.wantErr { + t.Fatalf("Parse() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr { + return + } + + if cmd.Type != tt.expected.Type { + t.Errorf("Command Type = %v, expected = %v", cmd.Type, tt.expected.Type) + } + + // Validate provider name + gotProvider, _ := cmd.Params["provider_name"].(string) + expectedProvider, _ := tt.expected.Params["provider_name"].(string) + if gotProvider != expectedProvider { + t.Errorf("provider_name = %v, expected = %v", gotProvider, expectedProvider) + } + + // Validate instance name + gotInstance, _ := cmd.Params["instance_name"].(string) + expectedInstance, _ := tt.expected.Params["instance_name"].(string) + if gotInstance != expectedInstance { + t.Errorf("instance_name = %v, expected = %v", gotInstance, expectedInstance) + } + + // Validate models + gotModels, ok1 := cmd.Params["models"].([]map[string]interface{}) + if !ok1 { + // Try another type just in case type conversion differs + gotModelsAny, okAny := cmd.Params["models"].([]map[string]any) + if okAny { + gotModels = make([]map[string]interface{}, len(gotModelsAny)) + for idx, val := range gotModelsAny { + gotModels[idx] = val + } + ok1 = true + } + } + expectedModels, _ := tt.expected.Params["models"].([]map[string]interface{}) + + if !ok1 { + t.Fatalf("models param not found or has incorrect type: %T", cmd.Params["models"]) + } + + if len(gotModels) != len(expectedModels) { + t.Fatalf("len(models) = %d, expected = %d", len(gotModels), 
len(expectedModels)) + } + + for idx := range gotModels { + gotModel := gotModels[idx] + expectedModel := expectedModels[idx] + + if gotModel["model_name"] != expectedModel["model_name"] { + t.Errorf("model[%d].model_name = %v, expected = %v", idx, gotModel["model_name"], expectedModel["model_name"]) + } + + if !reflect.DeepEqual(gotModel["model_types"], expectedModel["model_types"]) { + t.Errorf("model[%d].model_types = %v, expected = %v", idx, gotModel["model_types"], expectedModel["model_types"]) + } + + if gotModel["max_tokens"] != expectedModel["max_tokens"] { + t.Errorf("model[%d].max_tokens = %v, expected = %v", idx, gotModel["max_tokens"], expectedModel["max_tokens"]) + } + + if gotModel["thinking"] != expectedModel["thinking"] { + t.Errorf("model[%d].thinking = %v, expected = %v", idx, gotModel["thinking"], expectedModel["thinking"]) + } + + if gotModel["max_dimension"] != expectedModel["max_dimension"] { + t.Errorf("model[%d].max_dimension = %v, expected = %v", idx, gotModel["max_dimension"], expectedModel["max_dimension"]) + } + + if expectedModel["dimensions"] != nil { + if !reflect.DeepEqual(gotModel["dimensions"], expectedModel["dimensions"]) { + t.Errorf("model[%d].dimensions = %v, expected = %v", idx, gotModel["dimensions"], expectedModel["dimensions"]) + } + } + } + }) + } +}