mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
fix(go-models): validate TokenHub chat requests (#15283)
## Summary - centralize TokenHub chat request validation for chat and streaming calls - reject blank TokenHub model names before sending provider requests - send TokenHub model listing requests as bodyless GET requests ## What changed - Added shared TokenHub chat request validation for API key, model name, and messages. - Updated `ListModels` to call `GET /models` without a request body. - Added focused tests for blank model names and accidental GET request bodies. - Replaced an httptest handler callback `t.Fatalf` with `t.Errorf` plus an HTTP error and return. ## Why TokenHub chat requests should fail locally for invalid model names instead of sending avoidable malformed requests upstream. Model listing should also match normal GET semantics and avoid sending an empty JSON body. Closes #14736 Co-authored-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@@ -41,12 +41,22 @@ func (t *TokenHubModel) Name() string {
|
||||
return "tokenhub"
|
||||
}
|
||||
|
||||
func (t *TokenHubModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
|
||||
func validateTokenHubChatRequest(modelName string, messages []Message, apiConfig *APIConfig) error {
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return nil, fmt.Errorf("api key is required")
|
||||
return fmt.Errorf("api key is required")
|
||||
}
|
||||
if strings.TrimSpace(modelName) == "" {
|
||||
return fmt.Errorf("model name is required")
|
||||
}
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("messages is empty")
|
||||
return fmt.Errorf("messages is empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TokenHubModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
|
||||
if err := validateTokenHubChatRequest(modelName, messages, apiConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var region = "default"
|
||||
@@ -170,11 +180,8 @@ func (t *TokenHubModel) ChatStreamlyWithSender(modelName string, messages []Mess
|
||||
if sender == nil {
|
||||
return fmt.Errorf("sender is required")
|
||||
}
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return fmt.Errorf("api key is required")
|
||||
}
|
||||
if len(messages) == 0 {
|
||||
return fmt.Errorf("messages is empty")
|
||||
if err := validateTokenHubChatRequest(modelName, messages, apiConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var region = "default"
|
||||
@@ -439,20 +446,11 @@ func (t *TokenHubModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
|
||||
url := fmt.Sprintf("%s/%s", t.BaseURL[region], t.URLSuffix.Models)
|
||||
|
||||
// Build request body
|
||||
reqBody := map[string]interface{}{}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData))
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := t.httpClient.Do(req)
|
||||
|
||||
@@ -45,6 +45,17 @@ func newTokenHubServer(t *testing.T, expectedMethod, expectedPath string, handle
|
||||
handler(t, body, w)
|
||||
return
|
||||
}
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Errorf("read body: %v", err)
|
||||
http.Error(w, "read error", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if len(raw) != 0 {
|
||||
t.Errorf("expected no request body for %s, got %q", r.Method, string(raw))
|
||||
http.Error(w, "unexpected body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
handler(t, nil, w)
|
||||
}))
|
||||
}
|
||||
@@ -128,6 +139,14 @@ func TestTokenHubChatRequiresAPIKey(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenHubChatRequiresModelName(t *testing.T) {
|
||||
apiKey := "test-key"
|
||||
_, err := newTokenHubForTest("http://unused").ChatWithMessages(" ", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "model name is required") {
|
||||
t.Fatalf("expected model-name error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenHubStreamHappyPath(t *testing.T) {
|
||||
srv := newTokenHubSSEServer(t, "/chat/completions", strings.Join([]string{
|
||||
`data: {"choices":[{"delta":{"reasoning_content":"thinking"}}]}`,
|
||||
@@ -208,6 +227,20 @@ func TestTokenHubStreamRequiresAPIKey(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenHubStreamRequiresModelName(t *testing.T) {
|
||||
apiKey := "test-key"
|
||||
err := newTokenHubForTest("http://unused").ChatStreamlyWithSender(
|
||||
" ",
|
||||
[]Message{{Role: "user", Content: "ping"}},
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
nil,
|
||||
func(*string, *string) error { return nil },
|
||||
)
|
||||
if err == nil || !strings.Contains(err.Error(), "model name is required") {
|
||||
t.Fatalf("expected model-name error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenHubEmbedHappyPath(t *testing.T) {
|
||||
srv := newTokenHubServer(t, http.MethodPost, "/embeddings", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) {
|
||||
if body["model"] != "text-embedding-3-small" {
|
||||
@@ -215,7 +248,9 @@ func TestTokenHubEmbedHappyPath(t *testing.T) {
|
||||
}
|
||||
inputs, ok := body["input"].([]interface{})
|
||||
if !ok || len(inputs) != 2 {
|
||||
t.Fatalf("input=%#v", body["input"])
|
||||
t.Errorf("input=%#v", body["input"])
|
||||
http.Error(w, "invalid input", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"data": []map[string]interface{}{
|
||||
|
||||
Reference in New Issue
Block a user