diff --git a/internal/entity/models/tokenhub.go b/internal/entity/models/tokenhub.go index a6726cbc76..6346000e32 100644 --- a/internal/entity/models/tokenhub.go +++ b/internal/entity/models/tokenhub.go @@ -41,12 +41,22 @@ func (t *TokenHubModel) Name() string { return "tokenhub" } -func (t *TokenHubModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { +func validateTokenHubChatRequest(modelName string, messages []Message, apiConfig *APIConfig) error { if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { - return nil, fmt.Errorf("api key is required") + return fmt.Errorf("api key is required") + } + if strings.TrimSpace(modelName) == "" { + return fmt.Errorf("model name is required") } if len(messages) == 0 { - return nil, fmt.Errorf("messages is empty") + return fmt.Errorf("messages is empty") + } + return nil +} + +func (t *TokenHubModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if err := validateTokenHubChatRequest(modelName, messages, apiConfig); err != nil { + return nil, err } var region = "default" @@ -170,11 +180,8 @@ func (t *TokenHubModel) ChatStreamlyWithSender(modelName string, messages []Mess if sender == nil { return fmt.Errorf("sender is required") } - if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { - return fmt.Errorf("api key is required") - } - if len(messages) == 0 { - return fmt.Errorf("messages is empty") + if err := validateTokenHubChatRequest(modelName, messages, apiConfig); err != nil { + return err } var region = "default" @@ -439,20 +446,11 @@ func (t *TokenHubModel) ListModels(apiConfig *APIConfig) ([]string, error) { url := fmt.Sprintf("%s/%s", t.BaseURL[region], t.URLSuffix.Models) - // Build request body - reqBody := map[string]interface{}{} - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData)) + req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) resp, err := t.httpClient.Do(req) diff --git a/internal/entity/models/tokenhub_test.go b/internal/entity/models/tokenhub_test.go index 62e2c41aef..987a02cc1d 100644 --- a/internal/entity/models/tokenhub_test.go +++ b/internal/entity/models/tokenhub_test.go @@ -45,6 +45,17 @@ func newTokenHubServer(t *testing.T, expectedMethod, expectedPath string, handle handler(t, body, w) return } + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + http.Error(w, "read error", http.StatusBadRequest) + return + } + if len(raw) != 0 { + t.Errorf("expected no request body for %s, got %q", r.Method, string(raw)) + http.Error(w, "unexpected body", http.StatusBadRequest) + return + } handler(t, nil, w) })) } @@ -128,6 +139,14 @@ func TestTokenHubChatRequiresAPIKey(t *testing.T) { } } +func TestTokenHubChatRequiresModelName(t *testing.T) { + apiKey := "test-key" + _, err := newTokenHubForTest("http://unused").ChatWithMessages(" ", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Fatalf("expected model-name error, got %v", err) + } +} + func TestTokenHubStreamHappyPath(t *testing.T) { srv := newTokenHubSSEServer(t, "/chat/completions", strings.Join([]string{ `data: {"choices":[{"delta":{"reasoning_content":"thinking"}}]}`, @@ -208,6 +227,20 @@ func TestTokenHubStreamRequiresAPIKey(t *testing.T) { } } +func TestTokenHubStreamRequiresModelName(t *testing.T) { + apiKey := "test-key" + err := newTokenHubForTest("http://unused").ChatStreamlyWithSender( + " ", + []Message{{Role: "user", Content: "ping"}}, + &APIConfig{ApiKey: &apiKey}, + nil, + func(*string, *string) error { return nil }, + ) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Fatalf("expected model-name error, got %v", err) + } +} + func TestTokenHubEmbedHappyPath(t *testing.T) { srv := newTokenHubServer(t, http.MethodPost, "/embeddings", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { if body["model"] != "text-embedding-3-small" { @@ -215,7 +248,9 @@ func TestTokenHubEmbedHappyPath(t *testing.T) { } inputs, ok := body["input"].([]interface{}) if !ok || len(inputs) != 2 { - t.Fatalf("input=%#v", body["input"]) + t.Errorf("input=%#v", body["input"]) + http.Error(w, "invalid input", http.StatusBadRequest) + return } _ = json.NewEncoder(w).Encode(map[string]interface{}{ "data": []map[string]interface{}{