From 8dc7f1d95ec250c08bb13ed3fb7c431988387c88 Mon Sep 17 00:00:00 2001 From: Haruko386 Date: Mon, 8 Jun 2026 19:27:45 +0800 Subject: [PATCH] Go: implement ASR and TTS for xiaomi (#15765) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? **Verified from CLI** ``` RAGFlow(user)> chat with 'mimo-v2.5@test@xiaomi' message 'who r u' Answer: Hello! I'm MiMo-v2.5, a large language model developed by Xiaomi's LLM Core Team. You can think of me as a friendly AI assistant ready to help you answer questions, have conversations, or work on creative tasks. My context window can handle up to 1 million tokens, so we can dive into pretty long discussions or documents if you'd like. What can I help you with today? Time: 3.831830 RAGFlow(user)> stream chat with 'mimo-v2.5@test@xiaomi' message 'who r u' Answer: there! I'm MiMo-v2.5, an AI assistant created by the Xiaomi LLM Core Team. I'm here to chat, help out, answer questions, or just have a friendly conversation. Think of me as a helpful buddy with a pretty big memory (1 million tokens worth!). What can I do for you today?😊 Time: 2.421630 RAGFlow(user)> think chat with 'mimo-v2.5@test@xiaomi' message 'who r u' Thinking: The user is asking a simple question about who I am. According to my system prompt, I should: - Identify myself as **MiMo-v2.5** - State that I was developed by the **Xiaomi LLM Core Team** - Answer in first person and be warm and conversational Answer: Hey there! 👋 I'm **MiMo**, an AI assistant created by the **Xiaomi LLM Core Team**. Think of me as a friendly chat buddy who's here to help you with all sorts of questions and tasks! I love having conversations, answering questions, brainstorming ideas, and helping people figure things out. Whether you want to chat, need help with something specific, or just want to explore ideas together — I'm here for it! 😊 What can I help you with today? Time: 6.651589 RAGFlow(user)> tts with 'mimo-v2.5-tts@test@xiaomi' text 'hello? show yourself' play format 'wav' param '{"voice": "Chloe"}' SUCCESS RAGFlow(user)> asr with 'mimo-v2.5-asr@test@xiaomi' audio './internal/test.wav' param '{"language": "zh"}' +------------------------------------------------------------------------------------------------------------------------+ | text | +------------------------------------------------------------------------------------------------------------------------+ | 1 The examination and testimony of the experts enabled the commission to conclude that five shots may have been fired. | +------------------------------------------------------------------------------------------------------------------------+ ``` ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring --- conf/models/xiaomi.json | 26 +- internal/entity/models/xiaomi.go | 1043 +++++++++++++++++++++--------- 2 files changed, 769 insertions(+), 300 deletions(-) diff --git a/conf/models/xiaomi.json b/conf/models/xiaomi.json index 8902867415..e3f9934f90 100644 --- a/conf/models/xiaomi.json +++ b/conf/models/xiaomi.json @@ -1,12 +1,11 @@ { "name": "Xiaomi", "url": { - "default": "https://api.xiaomimimo.com" + "default": "https://api.xiaomimimo.com/v1" }, "url_suffix": { - "chat": "v1/chat/completions" + "chat": "chat/completions" }, - "class": "xiaomi", "models": [ { "name": "mimo-v2.5-pro", @@ -19,15 +18,26 @@ "name": "mimo-v2.5", "max_tokens": 1048576, "model_types": [ - "chat", - "vision" + "chat" ] }, { - "name": "mimo-v2-flash", - "max_tokens": 262144, + "name": "mimo-v2.5-asr", + "max_tokens": 8192, "model_types": [ - "chat" + "asr" + ] + }, + { + "name": "mimo-v2.5-tts", + "model_types": [ + "tts" + ] + }, + { + "name": "mimo-v2-tts", + "model_types": [ + "tts" ] } ] diff --git a/internal/entity/models/xiaomi.go b/internal/entity/models/xiaomi.go index 52809f3e79..7e4446cb68 100644 --- a/internal/entity/models/xiaomi.go +++ b/internal/entity/models/xiaomi.go @@ -20,34 +20,25 @@ import ( "bufio" "bytes" "context" + "encoding/base64" "encoding/json" "fmt" "io" + "mime" "net/http" + "os" + "path/filepath" + "ragflow/internal/common" "strings" "time" ) -// XiaomiModel implements ModelDriver for Xiaomi MiMo chat models. -// -// Xiaomi MiMo documents an OpenAI-compatible chat completions endpoint at -// https://api.xiaomimimo.com/v1/chat/completions. The documented request -// sample uses api-key authentication and max_completion_tokens, so this -// driver follows that wire shape instead of blindly reusing max_tokens. type XiaomiModel struct { baseModel BaseModel } func NewXiaomiModel(baseURL map[string]string, urlSuffix URLSuffix) *XiaomiModel { - defaultTransport, ok := http.DefaultTransport.(*http.Transport) - var transport *http.Transport - if ok { - transport = defaultTransport.Clone() - } else { - transport = &http.Transport{ - Proxy: http.ProxyFromEnvironment, - } - } + transport := http.DefaultTransport.(*http.Transport).Clone() transport.MaxIdleConns = 100 transport.MaxIdleConnsPerHost = 10 transport.IdleConnTimeout = 90 * time.Second @@ -65,169 +56,96 @@ func NewXiaomiModel(baseURL map[string]string, urlSuffix URLSuffix) *XiaomiModel } } -func (m *XiaomiModel) NewInstance(baseURL map[string]string) ModelDriver { - return NewXiaomiModel(baseURL, m.baseModel.URLSuffix) +func (x *XiaomiModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewXiaomiModel(baseURL, x.baseModel.URLSuffix) } -func (m *XiaomiModel) Name() string { +func (x *XiaomiModel) Name() string { return "xiaomi" } -func (m *XiaomiModel) baseURLForRegion(region string) (string, error) { - keys := []string{region} - if region != "" { - keys = append(keys, "", "default") - } else { - keys = append(keys, "default") - } - for _, key := range keys { - if base := strings.TrimRight(m.baseModel.BaseURL[key], "/"); base != "" { - return base, nil - } - } - return "", fmt.Errorf("xiaomi: no base URL configured for region %q", region) -} - -func (m *XiaomiModel) endpointURL(apiConfig *APIConfig) (string, error) { - if apiConfig != nil && apiConfig.BaseURL != nil && *apiConfig.BaseURL != "" { - return fmt.Sprintf("%s/%s", strings.TrimRight(*apiConfig.BaseURL, "/"), strings.TrimLeft(m.baseModel.URLSuffix.Chat, "/")), nil - } - - region := "" - if apiConfig != nil && apiConfig.Region != nil { - region = *apiConfig.Region - } - baseURL, err := m.baseURLForRegion(region) - if err != nil { - return "", err - } - return fmt.Sprintf("%s/%s", baseURL, strings.TrimLeft(m.baseModel.URLSuffix.Chat, "/")), nil -} - -func xiaomiAPIKey(apiConfig *APIConfig) (string, error) { - if apiConfig == nil || apiConfig.ApiKey == nil || strings.TrimSpace(*apiConfig.ApiKey) == "" { - return "", fmt.Errorf("api key is required") - } - return *apiConfig.ApiKey, nil -} - -type xiaomiAPIMessage struct { - Role string `json:"role"` - Content interface{} `json:"content"` -} - -type xiaomiThinking struct { - Type string `json:"type"` -} - -type xiaomiChatRequest struct { - Model string `json:"model"` - Messages []xiaomiAPIMessage `json:"messages"` - Stream bool `json:"stream"` - MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP *float64 `json:"top_p,omitempty"` - Stop *[]string `json:"stop,omitempty"` - Thinking *xiaomiThinking `json:"thinking,omitempty"` -} - -func buildXiaomiChatRequest(modelName string, messages []Message, stream bool, chatModelConfig *ChatConfig) xiaomiChatRequest { - apiMessages := make([]xiaomiAPIMessage, len(messages)) - for i, msg := range messages { - apiMessages[i] = xiaomiAPIMessage{ - Role: msg.Role, - Content: msg.Content, - } - } - - reqBody := xiaomiChatRequest{ - Model: modelName, - Messages: apiMessages, - Stream: stream, - } - if chatModelConfig != nil { - reqBody.MaxCompletionTokens = chatModelConfig.MaxTokens - reqBody.Temperature = chatModelConfig.Temperature - reqBody.TopP = chatModelConfig.TopP - reqBody.Stop = chatModelConfig.Stop - if chatModelConfig.Thinking != nil { - if *chatModelConfig.Thinking { - reqBody.Thinking = &xiaomiThinking{Type: "enabled"} - } else { - reqBody.Thinking = &xiaomiThinking{Type: "disabled"} - } - } - } - return reqBody -} - -type xiaomiChatMessage struct { - Content *string `json:"content"` - ReasoningContent string `json:"reasoning_content"` -} - -type xiaomiChatDelta struct { - Content string `json:"content"` - ReasoningContent string `json:"reasoning_content"` -} - -type xiaomiChatChoice struct { - Message xiaomiChatMessage `json:"message"` - Delta xiaomiChatDelta `json:"delta"` - FinishReason string `json:"finish_reason"` -} - -type xiaomiChatResponse struct { - Choices []xiaomiChatChoice `json:"choices"` - Error interface{} `json:"error"` -} - -func newXiaomiJSONRequest(ctx context.Context, method, endpoint string, payload interface{}, apiKey string) (*http.Request, error) { - var body io.Reader - if payload != nil { - jsonData, err := json.Marshal(payload) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - body = bytes.NewBuffer(jsonData) - } - - req, err := http.NewRequestWithContext(ctx, method, endpoint, body) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("api-key", apiKey) - return req, nil -} - -func (m *XiaomiModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { - apiKey, err := xiaomiAPIKey(apiConfig) - if err != nil { +func (x *XiaomiModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if err := x.baseModel.APIConfigCheck(apiConfig); err != nil { return nil, err } - if strings.TrimSpace(modelName) == "" { - return nil, fmt.Errorf("model name is required") - } if len(messages) == 0 { return nil, fmt.Errorf("messages is empty") } - endpoint, err := m.endpointURL(apiConfig) + resolvedBaseURL, err := x.baseModel.GetBaseURL(apiConfig) if err != nil { return nil, err } + url := fmt.Sprintf("%s/%s", resolvedBaseURL, x.baseModel.URLSuffix.Chat) + + // Convert messages to API format + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + // Build request body + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + "temperature": 1, + } + + if chatModelConfig != nil { + if chatModelConfig.Stream != nil { + reqBody["stream"] = *chatModelConfig.Stream + } + + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + + if chatModelConfig.Thinking != nil { + if *chatModelConfig.Thinking { + reqBody["thinking"] = map[string]interface{}{ + "type": "enabled", + } + } else { + reqBody["thinking"] = map[string]interface{}{ + "type": "disabled", + } + } + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } - reqBody := buildXiaomiChatRequest(modelName, messages, false, chatModelConfig) ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) defer cancel() - req, err := newXiaomiJSONRequest(ctx, http.MethodPost, endpoint, reqBody, apiKey) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create request: %w", err) } - resp, err := m.baseModel.httpClient.Do(req) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("api-key", *apiConfig.ApiKey) + + resp, err := x.baseModel.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } @@ -237,61 +155,146 @@ func (m *XiaomiModel) ChatWithMessages(modelName string, messages []Message, api if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } + if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("xiaomi chat API error: %s, body: %s", resp.Status, string(body)) + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) } - var parsed xiaomiChatResponse - if err := json.Unmarshal(body, &parsed); err != nil { + // Parse response + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { return nil, fmt.Errorf("failed to parse response: %w", err) } - if parsed.Error != nil { - return nil, fmt.Errorf("xiaomi: upstream error: %v", parsed.Error) - } - if len(parsed.Choices) == 0 { + + choices, ok := result["choices"].([]interface{}) + if !ok || len(choices) == 0 { return nil, fmt.Errorf("no choices in response") } - if parsed.Choices[0].Message.Content == nil { + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid choice format") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + + content, ok := messageMap["content"].(string) + if !ok { return nil, fmt.Errorf("invalid content format") } - content := *parsed.Choices[0].Message.Content - reasonContent := parsed.Choices[0].Message.ReasoningContent - return &ChatResponse{ + var reasonContent string + if chatModelConfig != nil && chatModelConfig.Thinking != nil && *chatModelConfig.Thinking { + reasonContent, ok = messageMap["reasoning_content"].(string) + if !ok { + // If reasoning_content not in response, try parsing from content tags + reasoning, answer := GetThinkingAndAnswer(chatModelConfig.ModelClass, &content) + if reasoning != nil { + reasonContent = *reasoning + content = *answer + } + } else { + // if first char of reasonContent is \n remove the '\n' + if reasonContent != "" && reasonContent[0] == '\n' { + reasonContent = reasonContent[1:] + } + } + } + + chatResponse := &ChatResponse{ Answer: &content, ReasonContent: &reasonContent, - }, nil + } + + return chatResponse, nil } -func (m *XiaomiModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { - if sender == nil { - return fmt.Errorf("sender is required") - } - apiKey, err := xiaomiAPIKey(apiConfig) - if err != nil { +func (x *XiaomiModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { + if err := x.baseModel.APIConfigCheck(apiConfig); err != nil { return err } - if strings.TrimSpace(modelName) == "" { - return fmt.Errorf("model name is required") - } + if len(messages) == 0 { return fmt.Errorf("messages is empty") } - if chatModelConfig != nil && chatModelConfig.Stream != nil && !*chatModelConfig.Stream { - return fmt.Errorf("stream must be true in ChatStreamlyWithSender") - } - endpoint, err := m.endpointURL(apiConfig) + baseURL, err := x.baseModel.GetBaseURL(apiConfig) if err != nil { return err } - reqBody := buildXiaomiChatRequest(modelName, messages, true, chatModelConfig) - req, err := newXiaomiJSONRequest(context.Background(), http.MethodPost, endpoint, reqBody, apiKey) - if err != nil { - return err + url := fmt.Sprintf("%s/%s", baseURL, x.baseModel.URLSuffix.Chat) + + // Convert messages to API format + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } } - resp, err := m.baseModel.httpClient.Do(req) + // Build request body with streaming enabled + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": true, + "temperature": 1, + } + + if modelConfig != nil { + if modelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *modelConfig.MaxTokens + } + + if modelConfig.Temperature != nil { + reqBody["temperature"] = *modelConfig.Temperature + } + + if modelConfig.DoSample != nil { + reqBody["do_sample"] = *modelConfig.DoSample + } + + if modelConfig.TopP != nil { + reqBody["top_p"] = *modelConfig.TopP + } + + if modelConfig.Stop != nil { + reqBody["stop"] = *modelConfig.Stop + } + + if modelConfig.Thinking != nil { + if *modelConfig.Thinking { + reqBody["thinking"] = map[string]interface{}{ + "type": "enabled", + } + } else { + reqBody["thinking"] = map[string]interface{}{ + "type": "disabled", + } + } + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), streamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("api-key", *apiConfig.ApiKey) + + resp, err := x.baseModel.httpClient.Do(req) if err != nil { return fmt.Errorf("failed to send request: %w", err) } @@ -299,141 +302,597 @@ func (m *XiaomiModel) ChatStreamlyWithSender(modelName string, messages []Messag if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("xiaomi chat stream API error: %s, body: %s", resp.Status, string(body)) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) } + // SSE parsing: read line by line scanner := bufio.NewScanner(resp.Body) scanner.Buffer(make([]byte, 64*1024), 1024*1024) - sawTerminal := false - var dataLines []string - dispatchEvent := func() (bool, error) { - if len(dataLines) == 0 { - return false, nil - } - payload := strings.Join(dataLines, "\n") - dataLines = dataLines[:0] - if payload == "[DONE]" { - sawTerminal = true - return true, nil - } - - var event xiaomiChatResponse - if err := json.Unmarshal([]byte(payload), &event); err != nil { - return false, fmt.Errorf("xiaomi: invalid SSE event: %w", err) - } - if event.Error != nil { - return false, fmt.Errorf("xiaomi: upstream stream error: %v", event.Error) - } - if len(event.Choices) == 0 { - return false, nil - } - choice := event.Choices[0] - if choice.Delta.ReasoningContent != "" { - r := choice.Delta.ReasoningContent - if err := sender(nil, &r); err != nil { - return false, err - } - } - if choice.Delta.Content != "" { - c := choice.Delta.Content - if err := sender(&c, nil); err != nil { - return false, err - } - } - if choice.FinishReason != "" { - sawTerminal = true - return true, nil - } - return false, nil - } - for scanner.Scan() { - line := strings.TrimSuffix(scanner.Text(), "\r") - if line == "" { - stop, err := dispatchEvent() - if err != nil { - return err - } - if stop { - break - } + line := scanner.Text() + common.Info(line) + + // SSE data line starts with "data:" + if !strings.HasPrefix(line, "data:") { continue } - if strings.HasPrefix(line, "data:") { - value := line[5:] - if strings.HasPrefix(value, " ") { - value = value[1:] + + // Extract JSON after "data:" + data := strings.TrimSpace(line[5:]) + + // [DONE] marks the end of stream + if data == "[DONE]" { + break + } + + // Parse the JSON event + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + reasoningContent, ok := delta["reasoning_content"].(string) + if ok && reasoningContent != "" { + if err := sender(nil, &reasoningContent); err != nil { + return err } - dataLines = append(dataLines, value) + } + + content, ok := delta["content"].(string) + if ok && content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + + finishReason, ok := firstChoice["finish_reason"].(string) + if ok && finishReason != "" { + break } } - if !sawTerminal { - if _, err := dispatchEvent(); err != nil { + + // Send [DONE] marker for OpenAI compatibility + endOfStream := "[DONE]" + if err = sender(&endOfStream, nil); err != nil { + return err + } + + return scanner.Err() +} + +func (x *XiaomiModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + return nil, fmt.Errorf("no such method %s", x.Name()) +} + +func (x *XiaomiModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("no such method %s", x.Name()) +} + +func (x *XiaomiModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + ctx, cancel := context.WithTimeout(context.Background(), longOpCallTimeout) + defer cancel() + + req, err := x.newXiaomiASRRequest(ctx, modelName, file, apiConfig, asrConfig, false) + if err != nil { + return nil, err + } + + resp, err := x.baseModel.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Xiaomi ASR API error: %s, body: %s", resp.Status, string(body)) + } + + var result xiaomiChatCompletionResponse + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse Xiaomi ASR response: %w, body=%s", err, string(body)) + } + if len(result.Choices) == 0 { + return nil, fmt.Errorf("no choices in Xiaomi ASR response") + } + + return &ASRResponse{Text: result.Choices[0].Message.Content}, nil +} + +func (x *XiaomiModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + if sender == nil { + return fmt.Errorf("sender is required") + } + + ctx, cancel := context.WithTimeout(context.Background(), streamCallTimeout) + defer cancel() + + req, err := x.newXiaomiASRRequest(ctx, modelName, file, apiConfig, asrConfig, true) + if err != nil { + return err + } + req.Header.Set("Accept", "text/event-stream") + + resp, err := x.baseModel.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("Xiaomi ASR stream API error: %s, body: %s", resp.Status, string(body)) + } + + if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") { + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + response, err := decodeXiaomiASRResponse(body) + if err != nil { + return err + } + if response.Text != "" { + if err = sender(&response.Text, nil); err != nil { + return err + } + } + done := "[DONE]" + return sender(&done, nil) + } + + return readXiaomiASRStream(resp.Body, sender) +} + +type xiaomiChatCompletionResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + Audio *xiaomiAudioPayload `json:"audio"` + } `json:"message"` + } `json:"choices"` +} + +type xiaomiChatCompletionChunk struct { + Choices []struct { + Delta struct { + Content *string `json:"content"` + Audio *xiaomiAudioPayload `json:"audio"` + } `json:"delta"` + FinishReason *string `json:"finish_reason"` + } `json:"choices"` +} + +type xiaomiAudioPayload struct { + ID string `json:"id"` + Data string `json:"data"` +} + +func (x *XiaomiModel) newXiaomiASRRequest(ctx context.Context, modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, stream bool) (*http.Request, error) { + if err := x.baseModel.APIConfigCheck(apiConfig); err != nil { + return nil, err + } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + if file == nil || *file == "" { + return nil, fmt.Errorf("file is missing") + } + if strings.TrimSpace(x.baseModel.URLSuffix.Chat) == "" { + return nil, fmt.Errorf("xiaomi chat URL suffix is required") + } + + audio, err := os.ReadFile(*file) + if err != nil { + return nil, fmt.Errorf("failed to read audio file: %w", err) + } + + mimeType := xiaomiAudioMIMEType(*file, audio, asrConfig) + reqBody := map[string]interface{}{ + "model": *modelName, + "messages": []map[string]interface{}{ + { + "role": "user", + "content": []map[string]interface{}{ + { + "type": "input_audio", + "input_audio": map[string]interface{}{ + "data": fmt.Sprintf("data:%s;base64,%s", mimeType, base64.StdEncoding.EncodeToString(audio)), + }, + }, + }, + }, + }, + "asr_options": xiaomiASROptions(asrConfig), + } + if stream { + reqBody["stream"] = true + } + if asrConfig != nil && asrConfig.Params != nil { + for key, value := range asrConfig.Params { + switch key { + case "asr_options", "language", "mime", "mime_type", "model", "messages", "stream": + continue + default: + reqBody[key] = value + } + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + baseURL, err := x.baseModel.GetBaseURL(apiConfig) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), strings.TrimPrefix(x.baseModel.URLSuffix.Chat, "/")) + + if ctx == nil { + ctx = context.Background() + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("api-key", *apiConfig.ApiKey) + + return req, nil +} + +func xiaomiASROptions(asrConfig *ASRConfig) map[string]interface{} { + options := map[string]interface{}{"language": "auto"} + if asrConfig == nil || asrConfig.Params == nil { + return options + } + if rawOptions, ok := asrConfig.Params["asr_options"].(map[string]interface{}); ok { + for key, value := range rawOptions { + options[key] = value + } + } + if language, ok := asrConfig.Params["language"]; ok && language != nil && fmt.Sprint(language) != "" { + options["language"] = language + } + return options +} + +func xiaomiAudioMIMEType(file string, audio []byte, asrConfig *ASRConfig) string { + if asrConfig != nil && asrConfig.Params != nil { + for _, key := range []string{"mime_type", "mime"} { + if value, ok := asrConfig.Params[key].(string); ok && strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + } + } + if detected := mime.TypeByExtension(strings.ToLower(filepath.Ext(file))); detected != "" { + return detected + } + if len(audio) > 0 { + return http.DetectContentType(audio) + } + return "application/octet-stream" +} + +func decodeXiaomiASRResponse(body []byte) (*ASRResponse, error) { + var result xiaomiChatCompletionResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse Xiaomi ASR response: %w, body=%s", err, string(body)) + } + if len(result.Choices) == 0 { + return nil, fmt.Errorf("no choices in Xiaomi ASR response") + } + return &ASRResponse{Text: result.Choices[0].Message.Content}, nil +} + +func readXiaomiASRStream(body io.Reader, sender func(*string, *string) error) error { + scanner := bufio.NewScanner(body) + scanner.Buffer(make([]byte, 64*1024), 8*1024*1024) + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data:") { + continue + } + + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if data == "" { + continue + } + if data == "[DONE]" { + break + } + + var chunk xiaomiChatCompletionChunk + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + continue + } + if len(chunk.Choices) == 0 { + continue + } + + content := chunk.Choices[0].Delta.Content + if content != nil && *content != "" { + if err := sender(content, nil); err != nil { + return err + } + } + } + if err := scanner.Err(); err != nil { + return fmt.Errorf("error reading Xiaomi ASR stream: %w", err) + } + + done := "[DONE]" + return sender(&done, nil) +} + +func (x *XiaomiModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + ctx, cancel := context.WithTimeout(context.Background(), longOpCallTimeout) + defer cancel() + + req, err := x.newXiaomiTTSRequest(ctx, modelName, audioContent, apiConfig, ttsConfig, false) + if err != nil { + return nil, err + } + + resp, err := x.baseModel.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Xiaomi TTS API error: %s, body: %s", resp.Status, string(body)) + } + + return decodeXiaomiTTSResponse(body) +} + +func (x *XiaomiModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + if sender == nil { + return fmt.Errorf("sender is required") + } + + ctx, cancel := context.WithTimeout(context.Background(), streamCallTimeout) + defer cancel() + + req, err := x.newXiaomiTTSRequest(ctx, modelName, audioContent, apiConfig, ttsConfig, true) + if err != nil { + return err + } + req.Header.Set("Accept", "text/event-stream") + + resp, err := x.baseModel.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("Xiaomi TTS stream API error: %s, body: %s", resp.Status, string(body)) + } + + if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") { + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + audio, err := decodeXiaomiTTSResponse(body) + if err != nil { + return err + } + if len(audio.Audio) > 0 { + chunk := base64.StdEncoding.EncodeToString(audio.Audio) + return sender(&chunk, nil) + } + return nil + } + + return readXiaomiTTSStream(resp.Body, sender) +} + +func (x *XiaomiModel) newXiaomiTTSRequest(ctx context.Context, modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, stream bool) (*http.Request, error) { + if err := x.baseModel.APIConfigCheck(apiConfig); err != nil { + return nil, err + } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + if audioContent == nil || *audioContent == "" { + return nil, fmt.Errorf("audio content is empty") + } + if strings.TrimSpace(x.baseModel.URLSuffix.Chat) == "" { + return nil, fmt.Errorf("xiaomi chat URL suffix is required") + } + + reqBody := map[string]interface{}{ + "model": *modelName, + "messages": []map[string]interface{}{ + { + "role": "assistant", + "content": *audioContent, + }, + }, + "audio": xiaomiTTSOptions(ttsConfig), + } + if stream { + reqBody["stream"] = true + } + if ttsConfig != nil && ttsConfig.Params != nil { + for key, value := range ttsConfig.Params { + switch key { + case "audio", "format", "voice", "model", "messages", "stream": + continue + default: + reqBody[key] = value + } + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + baseURL, err := x.baseModel.GetBaseURL(apiConfig) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), strings.TrimPrefix(x.baseModel.URLSuffix.Chat, "/")) + + if ctx == nil { + ctx = context.Background() + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("api-key", *apiConfig.ApiKey) + + return req, nil +} + +func xiaomiTTSOptions(ttsConfig *TTSConfig) map[string]interface{} { + options := map[string]interface{}{ + "format": "wav", + "voice": "mimo_default", + } + if ttsConfig == nil { + return options + } + if ttsConfig.Format != "" { + options["format"] = ttsConfig.Format + } + if ttsConfig.Params == nil { + return options + } + if rawOptions, ok := ttsConfig.Params["audio"].(map[string]interface{}); ok { + for key, value := range rawOptions { + options[key] = value + } + } + if format, ok := ttsConfig.Params["format"]; ok && format != nil && fmt.Sprint(format) != "" { + options["format"] = format + } + if voice, ok := ttsConfig.Params["voice"]; ok && voice != nil && fmt.Sprint(voice) != "" { + options["voice"] = voice + } + return options +} + +func decodeXiaomiTTSResponse(body []byte) (*TTSResponse, error) { + var result xiaomiChatCompletionResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse Xiaomi TTS response: %w, body=%s", err, string(body)) + } + if len(result.Choices) == 0 || result.Choices[0].Message.Audio == nil || result.Choices[0].Message.Audio.Data == "" { + return nil, fmt.Errorf("no audio data in Xiaomi TTS response") + } + + audio, err := decodeXiaomiAudioData(result.Choices[0].Message.Audio.Data) + if err != nil { + return nil, fmt.Errorf("failed to decode Xiaomi TTS audio: %w", err) + } + return &TTSResponse{Audio: audio}, nil +} + +func readXiaomiTTSStream(body io.Reader, sender func(*string, *string) error) error { + scanner := bufio.NewScanner(body) + scanner.Buffer(make([]byte, 64*1024), 8*1024*1024) + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data:") { + continue + } + + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if data == "" { + continue + } + if data == "[DONE]" { + break + } + + var chunk xiaomiChatCompletionChunk + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + continue + } + if len(chunk.Choices) == 0 || chunk.Choices[0].Delta.Audio == nil || chunk.Choices[0].Delta.Audio.Data == "" { + continue + } + audioData := chunk.Choices[0].Delta.Audio.Data + if err := sender(&audioData, nil); err != nil { return err } } if err := scanner.Err(); err != nil { - return fmt.Errorf("failed to scan response body: %w", err) - } - if !sawTerminal { - return fmt.Errorf("xiaomi: stream ended before [DONE] or finish_reason") - } - - endOfStream := "[DONE]" - if err := sender(&endOfStream, nil); err != nil { - return err + return fmt.Errorf("error reading Xiaomi TTS stream: %w", err) } return nil } -func (m *XiaomiModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { - return nil, fmt.Errorf("%s, no such method", m.Name()) +func decodeXiaomiAudioData(data string) ([]byte, error) { + if comma := strings.Index(data, ","); strings.HasPrefix(data, "data:") && comma >= 0 { + data = data[comma+1:] + } + return base64.StdEncoding.DecodeString(data) } -func (m *XiaomiModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { - return nil, fmt.Errorf("%s, no such method", m.Name()) +func (x *XiaomiModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("no such method %s", x.Name()) } -func (m *XiaomiModel) ListModels(apiConfig *APIConfig) ([]string, error) { - return nil, fmt.Errorf("%s, no such method", m.Name()) +func (x *XiaomiModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("no such method %s", x.Name()) } -func (m *XiaomiModel) CheckConnection(apiConfig *APIConfig) error { - return fmt.Errorf("%s, no such method", m.Name()) +func (x *XiaomiModel) ListModels(apiConfig *APIConfig) ([]string, error) { + return nil, fmt.Errorf("no such method %s", x.Name()) } -func (m *XiaomiModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { - return nil, fmt.Errorf("%s, no such method", m.Name()) +func (x *XiaomiModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("no such method %s", x.Name()) } -func (m *XiaomiModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { - return nil, fmt.Errorf("%s, no such method", m.Name()) +func (x *XiaomiModel) CheckConnection(apiConfig *APIConfig) error { + if err := x.baseModel.APIConfigCheck(apiConfig); err != nil { + return err + } + _, err := x.baseModel.GetBaseURL(apiConfig) + return err } -func (m *XiaomiModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { - return fmt.Errorf("%s, no such method", m.Name()) +func (x *XiaomiModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("no such method %s", x.Name()) } -func (m *XiaomiModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { - return nil, fmt.Errorf("%s, no such method", m.Name()) -} - -func (m *XiaomiModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { - return fmt.Errorf("%s, no such method", m.Name()) -} - -func (m *XiaomiModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { - return nil, fmt.Errorf("%s, no such method", m.Name()) -} - -func (m *XiaomiModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { - return nil, fmt.Errorf("%s, no such method", m.Name()) -} - -func (m *XiaomiModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { - return nil, fmt.Errorf("%s, no such method", m.Name()) -} - -func (m *XiaomiModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { - return nil, fmt.Errorf("%s, no such method", m.Name()) +func (x *XiaomiModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("no such method %s", x.Name()) }