From ae88578451e2e445124c2f63317ce9b743cd23b6 Mon Sep 17 00:00:00 2001 From: Haruko386 Date: Wed, 27 May 2026 14:08:35 +0800 Subject: [PATCH] Go: implement TTS and ASR for X.AI (#15247) ### What problem does this PR solve? As title ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring --- conf/models/xai.json | 35 +++- internal/entity/models/mineru.go | 2 +- internal/entity/models/mineru_local.go | 2 +- internal/entity/models/paddleocr.go | 2 +- internal/entity/models/paddleocr_local.go | 2 +- internal/entity/models/xai.go | 187 +++++++++++++++++++++- 6 files changed, 217 insertions(+), 13 deletions(-) diff --git a/conf/models/xai.json b/conf/models/xai.json index e46351d01f..cace7f06bd 100644 --- a/conf/models/xai.json +++ b/conf/models/xai.json @@ -5,39 +5,60 @@ }, "url_suffix": { "chat": "chat/completions", - "models": "models" + "models": "models", + "tts": "tts", + "ast": "stt" }, "class": "grok", "models": [ { "name": "grok-4", "max_tokens": 256000, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "grok-3", "max_tokens": 131072, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "grok-3-fast", "max_tokens": 131072, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "grok-3-mini", "max_tokens": 131072, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "grok-3-mini-mini-fast", "max_tokens": 131072, - "model_types": ["chat"] + "model_types": [ + "chat" + ] }, { "name": "grok-2-vision", "max_tokens": 32768, - "model_types": ["vision"] + "model_types": [ + "vision" + ] + }, + { + "name": "eve", + "model_types": [ + "tts" + ] } ] } + diff --git a/internal/entity/models/mineru.go b/internal/entity/models/mineru.go index 1ff4697db2..b8a14be6b7 100644 --- a/internal/entity/models/mineru.go +++ b/internal/entity/models/mineru.go @@ -48,7 +48,7 @@ func (m *MinerUModel) NewInstance(baseURL map[string]string) ModelDriver { } func (m *MinerUModel) Name() string { - return "mineru" + return "mineru.net" } func (m *MinerUModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { diff --git a/internal/entity/models/mineru_local.go b/internal/entity/models/mineru_local.go index 177ce77530..4307f0e8b1 100644 --- a/internal/entity/models/mineru_local.go +++ b/internal/entity/models/mineru_local.go @@ -49,7 +49,7 @@ func (m *MinerULocalModel) NewInstance(baseURL map[string]string) ModelDriver { } func (m *MinerULocalModel) Name() string { - return "mineru_local" + return "mineru" } func (m *MinerULocalModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { diff --git a/internal/entity/models/paddleocr.go b/internal/entity/models/paddleocr.go index 49d859976a..2e95d1e095 100644 --- a/internal/entity/models/paddleocr.go +++ b/internal/entity/models/paddleocr.go @@ -51,7 +51,7 @@ func (p PaddleOCRModel) NewInstance(baseURL map[string]string) ModelDriver { } func (p *PaddleOCRModel) Name() string { - return "paddle_ocr" + return "paddle_ocr.net" } func (p *PaddleOCRModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { diff --git a/internal/entity/models/paddleocr_local.go b/internal/entity/models/paddleocr_local.go index 5213c0eff0..c562b4f9ac 100644 --- a/internal/entity/models/paddleocr_local.go +++ b/internal/entity/models/paddleocr_local.go @@ -50,7 +50,7 @@ func (p *PaddleOCRLocalModel) NewInstance(baseURL map[string]string) ModelDriver } func (p *PaddleOCRLocalModel) Name() string { - return "paddleocr_local" + return "paddleocr" } func (p *PaddleOCRLocalModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { diff --git a/internal/entity/models/xai.go b/internal/entity/models/xai.go index 1f9cc03fe1..67dc86b214 100644 --- a/internal/entity/models/xai.go +++ b/internal/entity/models/xai.go @@ -23,7 +23,11 @@ import ( "encoding/json" "fmt" "io" + "mime/multipart" "net/http" + "os" + "path/filepath" + "strconv" "strings" "time" ) @@ -499,7 +503,127 @@ func (z *XAIModel) Rerank(modelName *string, query string, documents []string, a // TranscribeAudio transcribe audio func (o *XAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { - return nil, fmt.Errorf("%s, no such method", o.Name()) + if file == nil || *file == "" { + return nil, fmt.Errorf("file is missing") + } + + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.ASR) + + // multipart body + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + // 1. model field + if modelName != nil && *modelName != "" { + if err := writer.WriteField("model", *modelName); err != nil { + return nil, fmt.Errorf("failed to write model field: %w", err) + } + } + + // 2. extra params + if asrConfig != nil && asrConfig.Params != nil { + for key, value := range asrConfig.Params { + + switch v := value.(type) { + case []interface{}: + for _, item := range v { + if err := writer.WriteField(key, fmt.Sprintf("%v", item)); err != nil { + return nil, fmt.Errorf("failed to write array field %s: %w", key, err) + } + } + case []string: + for _, item := range v { + if err := writer.WriteField(key, item); err != nil { + return nil, fmt.Errorf("failed to write array field %s: %w", key, err) + } + } + default: + var val string + switch v2 := value.(type) { + case string: + val = v2 + case bool: + val = strconv.FormatBool(v2) + case int: + val = strconv.Itoa(v2) + case int64: + val = strconv.FormatInt(v2, 10) + case float32: + val = strconv.FormatFloat(float64(v2), 'f', -1, 32) + case float64: + val = strconv.FormatFloat(v2, 'f', -1, 64) + default: + val = fmt.Sprintf("%v", v2) + } + + if err := writer.WriteField(key, val); err != nil { + return nil, fmt.Errorf("failed to write field %s: %w", key, err) + } + } + } + } + + // open audio file + audioFile, err := os.Open(*file) + if err != nil { + return nil, fmt.Errorf("failed to open audio file: %w", err) + } + defer audioFile.Close() + + // create multipart file field + part, err := writer.CreateFormFile("file", filepath.Base(*file)) + if err != nil { + return nil, fmt.Errorf("failed to create multipart file: %w", err) + } + + // copy file content + if _, err = io.Copy(part, audioFile); err != nil { + return nil, fmt.Errorf("failed to copy audio data: %w", err) + } + if err = writer.Close(); err != nil { + return nil, fmt.Errorf("failed to close multipart writer: %w", err) + } + + // build request + req, err := http.NewRequest("POST", url, &body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("Content-Type", writer.FormDataContentType()) + + // send request + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("xAI ASR API error: %s - %s", resp.Status, string(respBody)) + } + + // xAI response parsing (assuming standard format matching OpenAI) + var result struct { + Text string `json:"text"` + } + + if err = json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w, body=%s", err, string(respBody)) + } + + return &ASRResponse{Text: result.Text}, nil } func (z *XAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { @@ -508,7 +632,66 @@ func (z *XAIModel) TranscribeAudioWithSender(modelName *string, file *string, ap // AudioSpeech convert text to audio func (o *XAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { - return nil, fmt.Errorf("%s, no such method", o.Name()) + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("xai API key is missing") + } + + if audioContent == nil || *audioContent == "" { + return nil, fmt.Errorf("text content is missing") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.TTS) + + reqBody := map[string]interface{}{ + "text": *audioContent, + "voice_id": modelName, + } + + if ttsConfig != nil && ttsConfig.Params != nil { + for key, value := range ttsConfig.Params { + reqBody[key] = value + } + } + if ttsConfig != nil && ttsConfig.Format != "" { + reqBody["output_format"] = map[string]interface{}{ + "codec": ttsConfig.Format, + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%s - %s", resp.Status, string(body)) + } + + return &TTSResponse{Audio: body}, nil } func (z *XAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {