Files
ragflow/internal/entity/models/tokenhub.go
oktofeesh 7fb9a26623 fix(go-models): validate TokenHub chat requests (#15283)
## Summary
- centralize TokenHub chat request validation for chat and streaming
calls
- reject blank TokenHub model names before sending provider requests
- send TokenHub model listing requests as bodyless GET requests

## What changed
- Added shared TokenHub chat request validation for API key, model name,
and messages.
- Updated `ListModels` to call `GET /models` without a request body.
- Added focused tests for blank model names and accidental GET request
bodies.
- Replaced an httptest handler callback `t.Fatalf` with `t.Errorf` plus
an HTTP error and return.

## Why
TokenHub chat requests should fail locally for invalid model names
instead of sending avoidable malformed requests upstream. Model listing
should also match normal GET semantics and avoid sending an empty JSON
body.

Closes #14736

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
2026-05-27 14:39:41 +08:00

514 lines
14 KiB
Go

package models
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// TokenHubModel is the model driver for the TokenHub provider.
type TokenHubModel struct {
// BaseURL maps a region name to the API base URL for that region; the
// "default" entry is used when a request does not name a region.
BaseURL map[string]string
// URLSuffix holds the endpoint paths (chat, embedding, models) that are
// appended to the base URL.
URLSuffix URLSuffix
// httpClient is the HTTP client used for every request made by this driver.
httpClient *http.Client
}
// NewTokenHubModel builds a TokenHub driver for the given region-to-base-URL
// map and endpoint suffixes, backed by its own pooled HTTP client.
//
// The client timeout is 120 seconds. Per net/http, Client.Timeout also covers
// reading the response body, so it bounds the duration of streamed responses
// issued through this client as well.
func NewTokenHubModel(baseURL map[string]string, urlSuffix URLSuffix) *TokenHubModel {
	transport := &http.Transport{
		MaxIdleConns:        10,
		MaxIdleConnsPerHost: 100,
		IdleConnTimeout:     60 * time.Second,
		DisableCompression:  false,
	}
	client := &http.Client{
		Timeout:   120 * time.Second,
		Transport: transport,
	}

	model := new(TokenHubModel)
	model.BaseURL = baseURL
	model.URLSuffix = urlSuffix
	model.httpClient = client
	return model
}
// NewInstance returns a new TokenHub driver for baseURL that reuses the URL
// suffixes of the receiver.
func (t *TokenHubModel) NewInstance(baseURL map[string]string) ModelDriver {
	instance := NewTokenHubModel(baseURL, t.URLSuffix)
	return instance
}
// Name returns the identifier of this driver.
func (t *TokenHubModel) Name() string {
	const driverName = "tokenhub"
	return driverName
}
// validateTokenHubChatRequest checks the inputs shared by the chat and
// streaming chat calls before any request is built.
//
// It returns an error when the API key is missing or empty, when modelName is
// empty or only whitespace, or when messages is empty; the checks are applied
// in that order. It returns nil when all inputs are acceptable.
func validateTokenHubChatRequest(modelName string, messages []Message, apiConfig *APIConfig) error {
	hasAPIKey := apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != ""

	switch {
	case !hasAPIKey:
		return fmt.Errorf("api key is required")
	case strings.TrimSpace(modelName) == "":
		return fmt.Errorf("model name is required")
	case len(messages) == 0:
		return fmt.Errorf("messages is empty")
	default:
		return nil
	}
}
// ChatWithMessages sends a non-streaming chat completion request and returns
// the answer of the first choice in the response.
//
// modelName, messages and apiConfig are checked by
// validateTokenHubChatRequest before anything is sent. The region is taken
// from apiConfig.Region and falls back to "default"; a region without a
// configured base URL is rejected up front instead of producing a scheme-less
// URL that only fails inside the HTTP client.
//
// chatModelConfig is optional. When set, its MaxTokens, Temperature, TopP and
// Stop values override or extend the request body; the temperature defaults to
// 0.6 otherwise.
//
// The returned ChatResponse carries the message content as Answer and, when
// present, the "reasoning_content" (or "reasoning") field as ReasonContent
// with one leading newline removed. An error is returned for transport
// failures, non-200 statuses and responses that do not have the expected
// shape.
func (t *TokenHubModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
	if err := validateTokenHubChatRequest(modelName, messages, apiConfig); err != nil {
		return nil, err
	}

	// Validation above guarantees apiConfig and apiConfig.ApiKey are non-nil.
	region := "default"
	if apiConfig.Region != nil && *apiConfig.Region != "" {
		region = *apiConfig.Region
	}
	baseURL, ok := t.BaseURL[region]
	if !ok || baseURL == "" {
		return nil, fmt.Errorf("no base url configured for region %q", region)
	}
	url := fmt.Sprintf("%s/%s", baseURL, t.URLSuffix.Chat)

	// Convert messages to the format expected by the API.
	apiMessages := make([]map[string]interface{}, len(messages))
	for i, msg := range messages {
		apiMessages[i] = map[string]interface{}{
			"role":    msg.Role,
			"content": msg.Content,
		}
	}

	// Build the request body; optional settings override the defaults.
	reqBody := map[string]interface{}{
		"model":       modelName,
		"messages":    apiMessages,
		"stream":      false,
		"temperature": 0.6,
	}
	if chatModelConfig != nil {
		if chatModelConfig.MaxTokens != nil {
			reqBody["max_tokens"] = *chatModelConfig.MaxTokens
		}
		if chatModelConfig.Temperature != nil {
			reqBody["temperature"] = *chatModelConfig.Temperature
		}
		if chatModelConfig.TopP != nil {
			reqBody["top_p"] = *chatModelConfig.TopP
		}
		if chatModelConfig.Stop != nil {
			reqBody["stop"] = *chatModelConfig.Stop
		}
	}

	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))

	resp, err := t.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
	}

	// Parse the response and walk down to choices[0].message.
	var result map[string]interface{}
	if err = json.Unmarshal(body, &result); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	choices, ok := result["choices"].([]interface{})
	if !ok || len(choices) == 0 {
		return nil, fmt.Errorf("no choices in response")
	}
	firstChoice, ok := choices[0].(map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("invalid choice format")
	}
	messageMap, ok := firstChoice["message"].(map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("invalid message format")
	}
	content, ok := messageMap["content"].(string)
	if !ok {
		return nil, fmt.Errorf("invalid content format")
	}

	// Prefer "reasoning_content"; fall back to "reasoning" when it is absent
	// or empty.
	var reasonContent string
	if rc, exists := messageMap["reasoning_content"].(string); exists && rc != "" {
		reasonContent = rc
	} else if r, exists := messageMap["reasoning"].(string); exists && r != "" {
		reasonContent = r
	}
	// Drop a single leading newline from the reasoning text.
	reasonContent = strings.TrimPrefix(reasonContent, "\n")

	return &ChatResponse{
		Answer:        &content,
		ReasonContent: &reasonContent,
	}, nil
}
// ChatStreamlyWithSender sends a streaming chat completion request and
// forwards each received piece to sender as it arrives.
//
// sender is called as sender(nil, &reasoning) for reasoning text and as
// sender(&content, nil) for answer text. After the stream ends successfully it
// is called once more with the "[DONE]" marker as content. Any error returned
// by sender stops the stream and is returned to the caller.
//
// modelName, messages and apiConfig are checked by
// validateTokenHubChatRequest. modelConfig is optional; an explicit
// Stream=false is rejected, and MaxTokens, Temperature, DoSample, TopP and
// Stop are copied into the request body when set. The region is taken from
// apiConfig.Region and falls back to "default"; a region without a configured
// base URL is rejected before any request is made.
//
// A failure while reading the stream is returned without emitting "[DONE]",
// so a truncated stream is not reported to the sender as complete.
func (t *TokenHubModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error {
	// Largest SSE line accepted. bufio.Scanner's default limit is 64 KiB and
	// a longer line would end the scan with bufio.ErrTooLong.
	const maxSSELineBytes = 1 << 20

	if sender == nil {
		return fmt.Errorf("sender is required")
	}
	if err := validateTokenHubChatRequest(modelName, messages, apiConfig); err != nil {
		return err
	}
	if modelConfig != nil && modelConfig.Stream != nil && !*modelConfig.Stream {
		return fmt.Errorf("stream must be true in ChatStreamlyWithSender")
	}

	// Validation above guarantees apiConfig and apiConfig.ApiKey are non-nil.
	region := "default"
	if apiConfig.Region != nil && *apiConfig.Region != "" {
		region = *apiConfig.Region
	}
	baseURL, ok := t.BaseURL[region]
	if !ok || baseURL == "" {
		return fmt.Errorf("no base url configured for region %q", region)
	}
	url := fmt.Sprintf("%s/%s", baseURL, t.URLSuffix.Chat)

	// Convert messages to the format expected by the API.
	apiMessages := make([]map[string]interface{}, len(messages))
	for i, msg := range messages {
		apiMessages[i] = map[string]interface{}{
			"role":    msg.Role,
			"content": msg.Content,
		}
	}

	// Build the request body with streaming enabled.
	reqBody := map[string]interface{}{
		"model":    modelName,
		"messages": apiMessages,
		"stream":   true,
	}
	if modelConfig != nil {
		if modelConfig.MaxTokens != nil {
			reqBody["max_tokens"] = *modelConfig.MaxTokens
		}
		if modelConfig.Temperature != nil {
			reqBody["temperature"] = *modelConfig.Temperature
		}
		if modelConfig.DoSample != nil {
			reqBody["do_sample"] = *modelConfig.DoSample
		}
		if modelConfig.TopP != nil {
			reqBody["top_p"] = *modelConfig.TopP
		}
		if modelConfig.Stop != nil {
			reqBody["stop"] = *modelConfig.Stop
		}
	}

	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))

	resp, err := t.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body is only used to enrich the error message.
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
	}

	// SSE parsing: read the body line by line.
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxSSELineBytes)
	for scanner.Scan() {
		line := scanner.Text()

		// Only "data:" lines carry events; everything else is skipped.
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		data := strings.TrimSpace(line[5:])

		// [DONE] marks the end of the stream.
		if data == "[DONE]" {
			break
		}

		// Events that are not valid JSON or lack the expected shape are
		// skipped rather than treated as fatal.
		var event map[string]interface{}
		if err := json.Unmarshal([]byte(data), &event); err != nil {
			continue
		}
		choices, ok := event["choices"].([]interface{})
		if !ok || len(choices) == 0 {
			continue
		}
		firstChoice, ok := choices[0].(map[string]interface{})
		if !ok {
			continue
		}
		delta, ok := firstChoice["delta"].(map[string]interface{})
		if !ok {
			continue
		}

		// Prefer "reasoning_content"; fall back to "reasoning".
		reasoningContent, ok := delta["reasoning_content"].(string)
		if !ok || reasoningContent == "" {
			reasoningContent, _ = delta["reasoning"].(string)
		}
		if reasoningContent != "" {
			if err := sender(nil, &reasoningContent); err != nil {
				return err
			}
		}

		content, ok := delta["content"].(string)
		if ok && content != "" {
			if err := sender(&content, nil); err != nil {
				return err
			}
		}

		// A non-empty finish_reason ends the stream.
		finishReason, ok := firstChoice["finish_reason"].(string)
		if ok && finishReason != "" {
			break
		}
	}

	// Report a read failure before signalling completion.
	if err := scanner.Err(); err != nil {
		return err
	}

	// Send the [DONE] marker for OpenAI compatibility.
	endOfStream := "[DONE]"
	return sender(&endOfStream, nil)
}
// Embed requests embeddings for texts from the TokenHub embedding endpoint.
//
// An empty texts slice returns an empty result without any validation or
// network call. Otherwise modelName must be non-nil and not blank (empty or
// whitespace-only, matching the chat request validation) and apiConfig must
// carry a non-empty API key. The region is taken from apiConfig.Region and
// falls back to "default"; a region without a configured base URL is rejected
// before any request is made.
//
// embeddingConfig is accepted for interface compatibility and is not read
// here.
//
// The result holds one EmbeddingData per element of the response's "data"
// array, in response order, each keeping the index reported by the API. An
// error is returned for transport failures, non-200 statuses, undecodable
// bodies and responses without data.
func (t *TokenHubModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) {
	if len(texts) == 0 {
		return []EmbeddingData{}, nil
	}
	if modelName == nil || strings.TrimSpace(*modelName) == "" {
		return nil, fmt.Errorf("model name is required")
	}
	if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
		return nil, fmt.Errorf("api key is required")
	}

	region := "default"
	if apiConfig.Region != nil && *apiConfig.Region != "" {
		region = *apiConfig.Region
	}
	baseURL, ok := t.BaseURL[region]
	if !ok || baseURL == "" {
		return nil, fmt.Errorf("no base url configured for region %q", region)
	}
	url := fmt.Sprintf("%s/%s", baseURL, t.URLSuffix.Embedding)

	reqBody := map[string]interface{}{
		"model": *modelName,
		"input": texts,
	}
	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))

	resp, err := t.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("TokenHub embedding API error: status %d, body: %s", resp.StatusCode, string(body))
	}

	var parsedResponse struct {
		Data []struct {
			Embedding []float64 `json:"embedding"`
			Index     int       `json:"index"`
		} `json:"data"`
	}
	if err = json.Unmarshal(body, &parsedResponse); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	if len(parsedResponse.Data) == 0 {
		return nil, fmt.Errorf("TokenHub embedding response contains no data: %s", string(body))
	}

	// The final length is known, so allocate the result once.
	embeddings := make([]EmbeddingData, 0, len(parsedResponse.Data))
	for _, dataElem := range parsedResponse.Data {
		embeddings = append(embeddings, EmbeddingData{
			Embedding: dataElem.Embedding,
			Index:     dataElem.Index,
		})
	}
	return embeddings, nil
}
// Rerank is not provided by this driver; it always returns an error that
// names the driver.
func (t *TokenHubModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
	driver := t.Name()
	return nil, fmt.Errorf("%s no such method", driver)
}
// TranscribeAudio is not provided by this driver; it always returns an error
// that names the driver.
func (t *TokenHubModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
	driver := t.Name()
	return nil, fmt.Errorf("%s no such method", driver)
}
// TranscribeAudioWithSender is not provided by this driver; it always returns
// an error that names the driver and never calls sender.
func (t *TokenHubModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error {
	driver := t.Name()
	return fmt.Errorf("%s no such method", driver)
}
// AudioSpeech is not provided by this driver; it always returns an error that
// names the driver.
func (t *TokenHubModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) {
	driver := t.Name()
	return nil, fmt.Errorf("%s no such method", driver)
}
// AudioSpeechWithSender is not provided by this driver; it always returns an
// error that names the driver and never calls sender.
func (t *TokenHubModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {
	driver := t.Name()
	return fmt.Errorf("%s no such method", driver)
}
// OCRFile is not provided by this driver; it always returns an error that
// names the driver.
func (t *TokenHubModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) {
	driver := t.Name()
	return nil, fmt.Errorf("%s no such method", driver)
}
// ParseFile is not provided by this driver; it always returns an error that
// names the driver.
func (t *TokenHubModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) {
	driver := t.Name()
	return nil, fmt.Errorf("%s no such method", driver)
}
// ListModels fetches the model list with a bodyless GET request and returns
// the "id" of every entry in the response's "data" array.
//
// apiConfig must carry a non-empty API key. The region is taken from
// apiConfig.Region and falls back to "default"; a region without a configured
// base URL is rejected before any request is made, so callers get a
// configuration error rather than an opaque transport failure.
//
// Entries that are not JSON objects or whose "id" is not a string are skipped.
// An error is returned for transport failures, non-200 statuses, undecodable
// bodies and responses whose "data" field is not an array.
func (t *TokenHubModel) ListModels(apiConfig *APIConfig) ([]string, error) {
	if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
		return nil, fmt.Errorf("api key is required")
	}

	region := "default"
	if apiConfig.Region != nil && *apiConfig.Region != "" {
		region = *apiConfig.Region
	}
	baseURL, ok := t.BaseURL[region]
	if !ok || baseURL == "" {
		return nil, fmt.Errorf("no base url configured for region %q", region)
	}
	url := fmt.Sprintf("%s/%s", baseURL, t.URLSuffix.Models)

	// GET carries no body, hence the nil reader and no Content-Type header.
	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))

	resp, err := t.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
	}

	// Parse the response.
	var result map[string]interface{}
	if err = json.Unmarshal(body, &result); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	data, ok := result["data"].([]interface{})
	if !ok {
		return nil, fmt.Errorf("invalid models list format")
	}

	models := make([]string, 0, len(data))
	for _, model := range data {
		modelMap, ok := model.(map[string]interface{})
		if !ok {
			continue
		}
		modelName, ok := modelMap["id"].(string)
		if !ok {
			continue
		}
		models = append(models, modelName)
	}
	return models, nil
}
// Balance is not provided by this driver; it always returns an error that
// names the driver.
func (t *TokenHubModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
	driver := t.Name()
	return nil, fmt.Errorf("%s no such method", driver)
}
// CheckConnection verifies that the provider is reachable with apiConfig by
// listing models; the model list itself is discarded and only the error, if
// any, is returned.
func (t *TokenHubModel) CheckConnection(apiConfig *APIConfig) error {
	if _, err := t.ListModels(apiConfig); err != nil {
		return err
	}
	return nil
}
// ListTasks is not provided by this driver; it always returns an error that
// names the driver.
func (t *TokenHubModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) {
	driver := t.Name()
	return nil, fmt.Errorf("%s no such method", driver)
}
// ShowTask is not provided by this driver; it always returns an error that
// names the driver.
func (t *TokenHubModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) {
	driver := t.Name()
	return nil, fmt.Errorf("%s no such method", driver)
}