From 5e25e2600bac3b9ac82513393915ca2cbbf03b65 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Sun, 7 Jun 2026 19:09:36 -1000 Subject: [PATCH] Go: implement Xiaomi chat provider (#15626) ### What problem does this PR solve? Implements the Xiaomi MiMo chat provider for the Go model provider layer. Reference issue: #14736 Official docs used: - Xiaomi MiMo OpenAI-compatible chat API: https://platform.xiaomimimo.com/docs/en-US/api/chat/openai-api - Xiaomi MiMo model and rate limits: https://platform.xiaomimimo.com/docs/en-US/quick-start/model - Xiaomi MiMo model hyperparameters: https://platform.xiaomimimo.com/docs/en-US/quick-start/model-hyperparameters --- conf/llm_factories.json | 30 ++ conf/models/xiaomi.json | 34 ++ internal/entity/models/factory.go | 2 + internal/entity/models/xiaomi.go | 439 ++++++++++++++++++++++++++ internal/entity/models/xiaomi_test.go | 355 +++++++++++++++++++++ 5 files changed, 860 insertions(+) create mode 100644 conf/models/xiaomi.json create mode 100644 internal/entity/models/xiaomi.go create mode 100644 internal/entity/models/xiaomi_test.go diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 333cd8b310..83ab0480ab 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -6795,6 +6795,36 @@ "status": "1", "llm": [] }, + { + "name": "Xiaomi", + "logo": "", + "tags": "LLM,IMAGE2TEXT", + "status": "1", + "url": "https://api.xiaomimimo.com/v1", + "llm": [ + { + "llm_name": "mimo-v2.5-pro", + "tags": "LLM,CHAT,1M", + "max_tokens": 1048576, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "mimo-v2.5", + "tags": "LLM,CHAT,1M,IMAGE2TEXT", + "max_tokens": 1048576, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "mimo-v2-flash", + "tags": "LLM,CHAT,256K", + "max_tokens": 262144, + "model_type": "chat", + "is_tools": true + } + ] + }, { "name": "Perplexity", "logo": "", diff --git a/conf/models/xiaomi.json b/conf/models/xiaomi.json new file mode 100644 index 0000000000..8902867415 --- /dev/null +++ b/conf/models/xiaomi.json @@ -0,0 +1,34 @@ +{ + "name": "Xiaomi", + "url": { + "default": "https://api.xiaomimimo.com" + }, + "url_suffix": { + "chat": "v1/chat/completions" + }, + "class": "xiaomi", + "models": [ + { + "name": "mimo-v2.5-pro", + "max_tokens": 1048576, + "model_types": [ + "chat" + ] + }, + { + "name": "mimo-v2.5", + "max_tokens": 1048576, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "mimo-v2-flash", + "max_tokens": 262144, + "model_types": [ + "chat" + ] + } + ] +} diff --git a/internal/entity/models/factory.go b/internal/entity/models/factory.go index 985961f898..714c772acc 100644 --- a/internal/entity/models/factory.go +++ b/internal/entity/models/factory.go @@ -151,6 +151,8 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string return NewHuaweiCloudModel(baseURL, urlSuffix), nil case "qiniu": return NewQiniuModel(baseURL, urlSuffix), nil + case "xiaomi": + return NewXiaomiModel(baseURL, urlSuffix), nil default: return NewDummyModel(baseURL, urlSuffix), nil } diff --git a/internal/entity/models/xiaomi.go b/internal/entity/models/xiaomi.go new file mode 100644 index 0000000000..52809f3e79 --- /dev/null +++ b/internal/entity/models/xiaomi.go @@ -0,0 +1,439 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// XiaomiModel implements ModelDriver for Xiaomi MiMo chat models. +// +// Xiaomi MiMo documents an OpenAI-compatible chat completions endpoint at +// https://api.xiaomimimo.com/v1/chat/completions. The documented request +// sample uses api-key authentication and max_completion_tokens, so this +// driver follows that wire shape instead of blindly reusing max_tokens. +type XiaomiModel struct { + baseModel BaseModel +} + +func NewXiaomiModel(baseURL map[string]string, urlSuffix URLSuffix) *XiaomiModel { + defaultTransport, ok := http.DefaultTransport.(*http.Transport) + var transport *http.Transport + if ok { + transport = defaultTransport.Clone() + } else { + transport = &http.Transport{ + Proxy: http.ProxyFromEnvironment, + } + } + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + transport.ResponseHeaderTimeout = 60 * time.Second + + return &XiaomiModel{ + baseModel: BaseModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + }, + } +} + +func (m *XiaomiModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewXiaomiModel(baseURL, m.baseModel.URLSuffix) +} + +func (m *XiaomiModel) Name() string { + return "xiaomi" +} + +func (m *XiaomiModel) baseURLForRegion(region string) (string, error) { + keys := []string{region} + if region != "" { + keys = append(keys, "", "default") + } else { + keys = append(keys, "default") + } + for _, key := range keys { + if base := strings.TrimRight(m.baseModel.BaseURL[key], "/"); base != "" { + return base, nil + } + } + return "", fmt.Errorf("xiaomi: no base URL configured for region %q", region) +} + +func (m *XiaomiModel) endpointURL(apiConfig *APIConfig) (string, error) { + if apiConfig != nil && apiConfig.BaseURL != nil && *apiConfig.BaseURL != "" { + return fmt.Sprintf("%s/%s", strings.TrimRight(*apiConfig.BaseURL, "/"), strings.TrimLeft(m.baseModel.URLSuffix.Chat, "/")), nil + } + + region := "" + if apiConfig != nil && apiConfig.Region != nil { + region = *apiConfig.Region + } + baseURL, err := m.baseURLForRegion(region) + if err != nil { + return "", err + } + return fmt.Sprintf("%s/%s", baseURL, strings.TrimLeft(m.baseModel.URLSuffix.Chat, "/")), nil +} + +func xiaomiAPIKey(apiConfig *APIConfig) (string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || strings.TrimSpace(*apiConfig.ApiKey) == "" { + return "", fmt.Errorf("api key is required") + } + return *apiConfig.ApiKey, nil +} + +type xiaomiAPIMessage struct { + Role string `json:"role"` + Content interface{} `json:"content"` +} + +type xiaomiThinking struct { + Type string `json:"type"` +} + +type xiaomiChatRequest struct { + Model string `json:"model"` + Messages []xiaomiAPIMessage `json:"messages"` + Stream bool `json:"stream"` + MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + Stop *[]string `json:"stop,omitempty"` + Thinking *xiaomiThinking `json:"thinking,omitempty"` +} + +func buildXiaomiChatRequest(modelName string, messages []Message, stream bool, chatModelConfig *ChatConfig) xiaomiChatRequest { + apiMessages := make([]xiaomiAPIMessage, len(messages)) + for i, msg := range messages { + apiMessages[i] = xiaomiAPIMessage{ + Role: msg.Role, + Content: msg.Content, + } + } + + reqBody := xiaomiChatRequest{ + Model: modelName, + Messages: apiMessages, + Stream: stream, + } + if chatModelConfig != nil { + reqBody.MaxCompletionTokens = chatModelConfig.MaxTokens + reqBody.Temperature = chatModelConfig.Temperature + reqBody.TopP = chatModelConfig.TopP + reqBody.Stop = chatModelConfig.Stop + if chatModelConfig.Thinking != nil { + if *chatModelConfig.Thinking { + reqBody.Thinking = &xiaomiThinking{Type: "enabled"} + } else { + reqBody.Thinking = &xiaomiThinking{Type: "disabled"} + } + } + } + return reqBody +} + +type xiaomiChatMessage struct { + Content *string `json:"content"` + ReasoningContent string `json:"reasoning_content"` +} + +type xiaomiChatDelta struct { + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content"` +} + +type xiaomiChatChoice struct { + Message xiaomiChatMessage `json:"message"` + Delta xiaomiChatDelta `json:"delta"` + FinishReason string `json:"finish_reason"` +} + +type xiaomiChatResponse struct { + Choices []xiaomiChatChoice `json:"choices"` + Error interface{} `json:"error"` +} + +func newXiaomiJSONRequest(ctx context.Context, method, endpoint string, payload interface{}, apiKey string) (*http.Request, error) { + var body io.Reader + if payload != nil { + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + body = bytes.NewBuffer(jsonData) + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint, body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("api-key", apiKey) + return req, nil +} + +func (m *XiaomiModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + apiKey, err := xiaomiAPIKey(apiConfig) + if err != nil { + return nil, err + } + if strings.TrimSpace(modelName) == "" { + return nil, fmt.Errorf("model name is required") + } + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + endpoint, err := m.endpointURL(apiConfig) + if err != nil { + return nil, err + } + + reqBody := buildXiaomiChatRequest(modelName, messages, false, chatModelConfig) + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := newXiaomiJSONRequest(ctx, http.MethodPost, endpoint, reqBody, apiKey) + if err != nil { + return nil, err + } + + resp, err := m.baseModel.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("xiaomi chat API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed xiaomiChatResponse + if err := json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + if parsed.Error != nil { + return nil, fmt.Errorf("xiaomi: upstream error: %v", parsed.Error) + } + if len(parsed.Choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + if parsed.Choices[0].Message.Content == nil { + return nil, fmt.Errorf("invalid content format") + } + + content := *parsed.Choices[0].Message.Content + reasonContent := parsed.Choices[0].Message.ReasoningContent + return &ChatResponse{ + Answer: &content, + ReasonContent: &reasonContent, + }, nil +} + +func (m *XiaomiModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + if sender == nil { + return fmt.Errorf("sender is required") + } + apiKey, err := xiaomiAPIKey(apiConfig) + if err != nil { + return err + } + if strings.TrimSpace(modelName) == "" { + return fmt.Errorf("model name is required") + } + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + if chatModelConfig != nil && chatModelConfig.Stream != nil && !*chatModelConfig.Stream { + return fmt.Errorf("stream must be true in ChatStreamlyWithSender") + } + + endpoint, err := m.endpointURL(apiConfig) + if err != nil { + return err + } + reqBody := buildXiaomiChatRequest(modelName, messages, true, chatModelConfig) + req, err := newXiaomiJSONRequest(context.Background(), http.MethodPost, endpoint, reqBody, apiKey) + if err != nil { + return err + } + + resp, err := m.baseModel.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("xiaomi chat stream API error: %s, body: %s", resp.Status, string(body)) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + sawTerminal := false + var dataLines []string + dispatchEvent := func() (bool, error) { + if len(dataLines) == 0 { + return false, nil + } + payload := strings.Join(dataLines, "\n") + dataLines = dataLines[:0] + if payload == "[DONE]" { + sawTerminal = true + return true, nil + } + + var event xiaomiChatResponse + if err := json.Unmarshal([]byte(payload), &event); err != nil { + return false, fmt.Errorf("xiaomi: invalid SSE event: %w", err) + } + if event.Error != nil { + return false, fmt.Errorf("xiaomi: upstream stream error: %v", event.Error) + } + if len(event.Choices) == 0 { + return false, nil + } + choice := event.Choices[0] + if choice.Delta.ReasoningContent != "" { + r := choice.Delta.ReasoningContent + if err := sender(nil, &r); err != nil { + return false, err + } + } + if choice.Delta.Content != "" { + c := choice.Delta.Content + if err := sender(&c, nil); err != nil { + return false, err + } + } + if choice.FinishReason != "" { + sawTerminal = true + return true, nil + } + return false, nil + } + + for scanner.Scan() { + line := strings.TrimSuffix(scanner.Text(), "\r") + if line == "" { + stop, err := dispatchEvent() + if err != nil { + return err + } + if stop { + break + } + continue + } + if strings.HasPrefix(line, "data:") { + value := line[5:] + if strings.HasPrefix(value, " ") { + value = value[1:] + } + dataLines = append(dataLines, value) + } + } + if !sawTerminal { + if _, err := dispatchEvent(); err != nil { + return err + } + } + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to scan response body: %w", err) + } + if !sawTerminal { + return fmt.Errorf("xiaomi: stream ended before [DONE] or finish_reason") + } + + endOfStream := "[DONE]" + if err := sender(&endOfStream, nil); err != nil { + return err + } + return nil +} + +func (m *XiaomiModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +func (m *XiaomiModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +func (m *XiaomiModel) ListModels(apiConfig *APIConfig) ([]string, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +func (m *XiaomiModel) CheckConnection(apiConfig *APIConfig) error { + return fmt.Errorf("%s, no such method", m.Name()) +} + +func (m *XiaomiModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +func (m *XiaomiModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +func (m *XiaomiModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", m.Name()) +} + +func (m *XiaomiModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +func (m *XiaomiModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", m.Name()) +} + +func (m *XiaomiModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +func (m *XiaomiModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +func (m *XiaomiModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +func (m *XiaomiModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} diff --git a/internal/entity/models/xiaomi_test.go b/internal/entity/models/xiaomi_test.go new file mode 100644 index 0000000000..1ae65c809c --- /dev/null +++ b/internal/entity/models/xiaomi_test.go @@ -0,0 +1,355 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newXiaomiServer(t *testing.T, expectedPath string, handler func(t *testing.T, r *http.Request, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != expectedPath { + t.Errorf("expected path=%s, got %s", expectedPath, r.URL.Path) + return + } + if got := r.Header.Get("api-key"); got != "test-key" { + t.Errorf("expected api-key=test-key, got %q", got) + return + } + if got := r.Header.Get("Authorization"); got != "" { + t.Errorf("expected no Authorization header, got %q", got) + return + } + if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/json") { + t.Errorf("expected Content-Type to start with application/json, got %q", got) + return + } + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + return + } + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("unmarshal: %v\nraw=%s", err, string(raw)) + return + } + handler(t, r, body, w) + })) +} + +func newXiaomiForTest(baseURL string) *XiaomiModel { + return NewXiaomiModel( + map[string]string{"default": baseURL}, + URLSuffix{Chat: "v1/chat/completions"}, + ) +} + +func TestXiaomiName(t *testing.T) { + if got := newXiaomiForTest("http://unused").Name(); got != "xiaomi" { + t.Errorf("Name()=%q, want xiaomi", got) + } +} + +func TestXiaomiFactory(t *testing.T) { + driver, err := NewModelFactory().CreateModelDriver("Xiaomi", map[string]string{"default": "http://unused"}, URLSuffix{}) + if err != nil { + t.Fatalf("factory: %v", err) + } + if _, ok := driver.(*XiaomiModel); !ok { + t.Fatalf("driver type=%T, want *XiaomiModel", driver) + } +} + +func TestXiaomiNewModelWithCustomDefaultTransport(t *testing.T) { + original := http.DefaultTransport + http.DefaultTransport = roundTripperFunc(func(*http.Request) (*http.Response, error) { + return nil, nil + }) + t.Cleanup(func() { + http.DefaultTransport = original + }) + + if model := NewXiaomiModel(map[string]string{"default": "http://unused"}, URLSuffix{}); model == nil { + t.Fatal("NewXiaomiModel returned nil") + } +} + +func TestXiaomiChatHappyPath(t *testing.T) { + srv := newXiaomiServer(t, "/v1/chat/completions", func(t *testing.T, _ *http.Request, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "mimo-v2.5-pro" { + t.Errorf("model=%v", body["model"]) + } + if body["stream"] != false { + t.Errorf("stream=%v want false", body["stream"]) + } + if body["max_tokens"] != nil { + t.Errorf("max_tokens must not be sent: %v", body["max_tokens"]) + } + if body["max_completion_tokens"] != float64(1024) { + t.Errorf("max_completion_tokens=%v", body["max_completion_tokens"]) + } + thinking, ok := body["thinking"].(map[string]interface{}) + if !ok || thinking["type"] != "disabled" { + t.Errorf("thinking=%v", body["thinking"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{ + "message": map[string]interface{}{ + "role": "assistant", + "content": "pong", + }, + }}, + }) + }) + defer srv.Close() + + apiKey := "test-key" + maxTokens := 1024 + thinking := false + resp, err := newXiaomiForTest(srv.URL).ChatWithMessages( + "mimo-v2.5-pro", + []Message{{Role: "user", Content: "ping"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{MaxTokens: &maxTokens, Thinking: &thinking}, + ) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if resp.Answer == nil || *resp.Answer != "pong" { + t.Fatalf("answer=%v", resp.Answer) + } + if resp.ReasonContent == nil || *resp.ReasonContent != "" { + t.Fatalf("reasoning=%v", resp.ReasonContent) + } +} + +func TestXiaomiUsesEmptyRegionBaseURLOverride(t *testing.T) { + srv := newXiaomiServer(t, "/v1/chat/completions", func(t *testing.T, _ *http.Request, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{ + "message": map[string]interface{}{ + "content": "pong", + }, + }}, + }) + }) + defer srv.Close() + + apiKey := "test-key" + m := NewXiaomiModel( + map[string]string{"": srv.URL}, + URLSuffix{Chat: "v1/chat/completions"}, + ) + resp, err := m.ChatWithMessages("mimo-v2.5-pro", []Message{{Role: "user", Content: "ping"}}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if resp.Answer == nil || *resp.Answer != "pong" { + t.Fatalf("answer=%v", resp.Answer) + } +} + +func TestXiaomiAPIConfigBaseURLOverridesRegionMap(t *testing.T) { + srv := newXiaomiServer(t, "/override/chat", func(t *testing.T, _ *http.Request, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{ + "message": map[string]interface{}{ + "content": "override", + }, + }}, + }) + }) + defer srv.Close() + + apiKey := "test-key" + baseURL := srv.URL + m := NewXiaomiModel( + map[string]string{"default": "http://unused"}, + URLSuffix{Chat: "override/chat"}, + ) + resp, err := m.ChatWithMessages("mimo-v2.5-pro", []Message{{Role: "user", Content: "ping"}}, &APIConfig{ApiKey: &apiKey, BaseURL: &baseURL}, nil) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if resp.Answer == nil || *resp.Answer != "override" { + t.Fatalf("answer=%v", resp.Answer) + } +} + +func TestXiaomiChatExtractsReasoningContent(t *testing.T) { + srv := newXiaomiServer(t, "/v1/chat/completions", func(t *testing.T, _ *http.Request, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{ + "message": map[string]interface{}{ + "content": "final", + "reasoning_content": "think", + }, + }}, + }) + }) + defer srv.Close() + + apiKey := "test-key" + resp, err := newXiaomiForTest(srv.URL).ChatWithMessages("mimo-v2.5-pro", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if *resp.Answer != "final" || *resp.ReasonContent != "think" { + t.Fatalf("response=%+v", resp) + } +} + +func TestXiaomiChatRequiresInputs(t *testing.T) { + apiKey := "test-key" + m := newXiaomiForTest("http://unused") + if _, err := m.ChatWithMessages("mimo-v2.5-pro", []Message{{Role: "user", Content: "x"}}, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("api key guard: %v", err) + } + if _, err := m.ChatWithMessages("", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil); err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Errorf("model guard: %v", err) + } + if _, err := m.ChatWithMessages("mimo-v2.5-pro", nil, &APIConfig{ApiKey: &apiKey}, nil); err == nil || !strings.Contains(err.Error(), "messages is empty") { + t.Errorf("messages guard: %v", err) + } +} + +func TestXiaomiChatRejectsHTTPError(t *testing.T) { + srv := newXiaomiServer(t, "/v1/chat/completions", func(t *testing.T, _ *http.Request, _ map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = io.WriteString(w, `{"error":"unauthorized"}`) + }) + defer srv.Close() + + apiKey := "test-key" + _, err := newXiaomiForTest(srv.URL).ChatWithMessages("mimo-v2.5-pro", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "401") { + t.Errorf("expected 401 propagated, got %v", err) + } +} + +func TestXiaomiStreamHappyPath(t *testing.T) { + srv := newXiaomiServer(t, "/v1/chat/completions", func(t *testing.T, _ *http.Request, body map[string]interface{}, w http.ResponseWriter) { + if body["stream"] != true { + t.Errorf("stream=%v want true", body["stream"]) + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, + `data: {"choices":[{"delta":{"reasoning_content":"step "}}]}`+"\n\n"+ + `data: {"choices":[{"delta":{"content":"Hello"}}]}`+"\n\n"+ + `data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}]}`+"\n\n", + ) + }) + defer srv.Close() + + apiKey := "test-key" + var content, reasoning []string + var sawDone bool + err := newXiaomiForTest(srv.URL).ChatStreamlyWithSender( + "mimo-v2.5-pro", + []Message{{Role: "user", Content: "hi"}}, + &APIConfig{ApiKey: &apiKey}, + nil, + func(c *string, r *string) error { + if c != nil && *c == "[DONE]" { + sawDone = true + return nil + } + if c != nil { + content = append(content, *c) + } + if r != nil { + reasoning = append(reasoning, *r) + } + return nil + }, + ) + if err != nil { + t.Fatalf("stream: %v", err) + } + if strings.Join(content, "") != "Hello world" { + t.Errorf("content=%v", content) + } + if strings.Join(reasoning, "") != "step " { + t.Errorf("reasoning=%v", reasoning) + } + if !sawDone { + t.Error("expected [DONE] sentinel") + } +} + +func TestXiaomiStreamHandlesCRLFFrames(t *testing.T) { + srv := newXiaomiServer(t, "/v1/chat/completions", func(t *testing.T, _ *http.Request, _ map[string]interface{}, w http.ResponseWriter) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, + "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\r\n\r\n"+ + "data: {\"choices\":[{\"delta\":{\"content\":\" world\"},\"finish_reason\":\"stop\"}]}\r\n\r\n", + ) + }) + defer srv.Close() + + apiKey := "test-key" + var content []string + err := newXiaomiForTest(srv.URL).ChatStreamlyWithSender( + "mimo-v2.5-pro", + []Message{{Role: "user", Content: "hi"}}, + &APIConfig{ApiKey: &apiKey}, + nil, + func(c *string, _ *string) error { + if c != nil && *c != "[DONE]" { + content = append(content, *c) + } + return nil + }, + ) + if err != nil { + t.Fatalf("stream: %v", err) + } + if strings.Join(content, "") != "Hello world" { + t.Errorf("content=%v", content) + } +} + +func TestXiaomiStreamRejectsMalformedFrame(t *testing.T) { + srv := newXiaomiServer(t, "/v1/chat/completions", func(t *testing.T, _ *http.Request, _ map[string]interface{}, w http.ResponseWriter) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, "data: {bad json}\n\n") + }) + defer srv.Close() + + apiKey := "test-key" + err := newXiaomiForTest(srv.URL).ChatStreamlyWithSender("mimo-v2.5-pro", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil, func(*string, *string) error { return nil }) + if err == nil || !strings.Contains(err.Error(), "invalid SSE event") { + t.Errorf("expected invalid-SSE error, got %v", err) + } +} + +func TestXiaomiUnsupportedMethods(t *testing.T) { + m := newXiaomiForTest("http://unused") + model := "mimo-v2.5-pro" + apiKey := "test-key" + cfg := &APIConfig{ApiKey: &apiKey} + + if _, err := m.Embed(&model, []string{"x"}, cfg, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Embed: %v", err) + } + if _, err := m.Rerank(&model, "q", []string{"d"}, cfg, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Rerank: %v", err) + } + if err := m.CheckConnection(cfg); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("CheckConnection: %v", err) + } + if _, err := m.TranscribeAudio(&model, nil, cfg, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("TranscribeAudio: %v", err) + } + if _, err := m.AudioSpeech(&model, nil, cfg, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("AudioSpeech: %v", err) + } + if _, err := m.OCRFile(&model, nil, nil, cfg, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("OCRFile: %v", err) + } +}