From 7d3836907aa0324d6ef7dfef233231996bbe2b3f Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Mon, 11 May 2026 23:45:48 -1000 Subject: [PATCH] Go: implement Embed (embeddings) in Mistral driver (#14807) ### What problem does this PR solve? The Mistral Go driver landed in #14805 with chat, list models, and check connection. `Embed` was left as a stub that returns `"not implemented"`. This PR fills the gap. `conf/models/mistral.json` did not list any embedding model out of the box, so a tenant who wanted to use Mistral end to end (chat + embeddings) could not run an embedding call. This PR adds `mistral-embed` to the config and a real `/v1/embeddings` implementation. ### What this PR includes - `conf/models/mistral.json`: add `"embedding": "embeddings"` under `url_suffix` so the driver can build the URL from config (matches the `URLSuffix.Embedding` field already used by openai, siliconflow, zhipu-ai), and add a `mistral-embed` entry under `models` (1024-dimensional vectors, 8192 max input tokens). - `internal/entity/models/mistral.go`: replace the `Embed` stub with a real implementation that POSTs to `/v1/embeddings`. Adds local response types `mistralEmbeddingData` and `mistralEmbeddingResponse`. No factory change. No interface change. ### How the implementation works - Validate `apiConfig`, the API key, and the model name. Use the existing `baseURLForRegion` helper so an unknown region fails fast with a clear error. - Wrap the request with `context.WithTimeout(nonStreamCallTimeout)` so the call has a clear deadline. Same pattern as `ChatWithMessages` and `ListModels` already use in this file. - Send all input texts in one request. The Mistral API accepts the `input` field as an array. - Parse `data[*].embedding` and copy each slice into a `[]EmbeddingData` indexed by `data[*].index` so the output order matches the input order even if the API returns items in a different order. - An empty input slice returns `[]EmbeddingData{}` with no HTTP call. - Non-200 responses propagate the upstream status line and body. - A final pass checks that every input slot got a vector. If any slot is still empty, return a clear error so the caller does not silently use a zero vector. ### Note on stacking This PR builds on #14805 (the Mistral driver). Until #14805 merges, this PR's diff on GitHub will include both that PR's commits and this one. After #14805 lands on `main`, GitHub will auto-reduce this PR to only the `Embed` changes (one commit, ~111 line diff in `mistral.go` plus 8 lines in `mistral.json`). ### Type of change - [x] New Feature (non-breaking change which adds functionality) ### How was this tested? - `go build ./internal/entity/models/...` returns exit 0 on go 1.25 (the `go.mod` minimum). - The full method set on `MistralModel` still matches the `ModelDriver` interface. - Pattern parity with the existing OpenAI Embed implementation (`internal/entity/models/openai.go`). Closes #14806 Depends on #14805 Tracking: #14736 --------- Co-authored-by: Jin Hai --- conf/models/mistral.json | 99 +++++ internal/entity/models/factory.go | 2 + internal/entity/models/mistral.go | 565 ++++++++++++++++++++++++ internal/entity/models/mistral_test.go | 574 +++++++++++++++++++++++++ 4 files changed, 1240 insertions(+) create mode 100644 conf/models/mistral.json create mode 100644 internal/entity/models/mistral.go create mode 100644 internal/entity/models/mistral_test.go diff --git a/conf/models/mistral.json b/conf/models/mistral.json new file mode 100644 index 0000000000..fefc4833a6 --- /dev/null +++ b/conf/models/mistral.json @@ -0,0 +1,99 @@ +{ + "name": "Mistral", + "url": { + "default": "https://api.mistral.ai/v1" + }, + "url_suffix": { + "chat": "chat/completions", + "models": "models", + "embedding": "embeddings" + }, + "class": "mistral", + "models": [ + { + "name": "mistral-large-latest", + "max_tokens": 128000, + "model_types": [ + "chat" + ] + }, + { + "name": "mistral-medium-latest", + "max_tokens": 128000, + "model_types": [ + "chat" + ] + }, + { + "name": "mistral-small-latest", + "max_tokens": 128000, + "model_types": [ + "chat" + ] + }, + { + "name": "ministral-8b-latest", + "max_tokens": 128000, + "model_types": [ + "chat" + ] + }, + { + "name": "ministral-3b-latest", + "max_tokens": 128000, + "model_types": [ + "chat" + ] + }, + { + "name": "pixtral-large-latest", + "max_tokens": 128000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "codestral-latest", + "max_tokens": 256000, + "model_types": [ + "chat" + ] + }, + { + "name": "open-mistral-nemo", + "max_tokens": 128000, + "model_types": [ + "chat" + ] + }, + { + "name": "open-mistral-7b", + "max_tokens": 32000, + "model_types": [ + "chat" + ] + }, + { + "name": "open-mixtral-8x7b", + "max_tokens": 32000, + "model_types": [ + "chat" + ] + }, + { + "name": "open-mixtral-8x22b", + "max_tokens": 64000, + "model_types": [ + "chat" + ] + }, + { + "name": "mistral-embed", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + } + ] +} diff --git a/internal/entity/models/factory.go b/internal/entity/models/factory.go index 702c6e7045..c11e479642 100644 --- a/internal/entity/models/factory.go +++ b/internal/entity/models/factory.go @@ -73,6 +73,8 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string return NewCoHereModel(baseURL, urlSuffix), nil case "fishaudio": return NewFishAudioModel(baseURL, urlSuffix), nil + case "mistral": + return NewMistralModel(baseURL, urlSuffix), nil case "upstage": return NewUpstageModel(baseURL, urlSuffix), nil case "stepfun": diff --git a/internal/entity/models/mistral.go b/internal/entity/models/mistral.go new file mode 100644 index 0000000000..b9ff04df57 --- /dev/null +++ b/internal/entity/models/mistral.go @@ -0,0 +1,565 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// MistralModel implements ModelDriver for Mistral AI. +// +// Mistral exposes an OpenAI-compatible REST API at https://api.mistral.ai/v1 +// (chat completions at /chat/completions, list models at /models). The wire +// shape matches OpenAI closely enough that the chat path here is a direct +// port of the OpenAI driver, with the differences kept small on purpose: +// no reasoning_content pass-through (Mistral does not expose one), and a +// distinct Name() so the factory can route to this driver. +type MistralModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +// NewMistralModel creates a new Mistral model instance. +// +// We clone http.DefaultTransport so we keep Go's defaults for +// ProxyFromEnvironment, DialContext (with KeepAlive), HTTP/2, +// TLSHandshakeTimeout, and ExpectContinueTimeout, and only override +// the connection-pool fields we care about. +// +// The Client itself has no Timeout. http.Client.Timeout would also +// cap the time spent reading the response body, which would cut off +// long-lived SSE streams in ChatStreamlyWithSender. Non-streaming +// callers wrap each request with context.WithTimeout instead. +func NewMistralModel(baseURL map[string]string, urlSuffix URLSuffix) *MistralModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + transport.ResponseHeaderTimeout = 60 * time.Second + + return &MistralModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +func (m *MistralModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewMistralModel(baseURL, m.URLSuffix) +} + +func (m *MistralModel) Name() string { + return "mistral" +} + +// baseURLForRegion returns the base URL for the given region, or an +// error if no entry exists. This makes a misconfigured region fail +// fast with a clear message, instead of silently producing a relative +// URL that the HTTP transport then rejects. +func (m *MistralModel) baseURLForRegion(region string) (string, error) { + base, ok := m.BaseURL[region] + if !ok || base == "" { + return "", fmt.Errorf("mistral: no base URL configured for region %q", region) + } + return base, nil +} + +// ChatWithMessages sends multiple messages with roles and returns the response. +func (m *MistralModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := m.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, m.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + } + + // Note: do NOT propagate chatModelConfig.Stream into the request body + // here. ChatWithMessages parses a single JSON response, so stream must + // always be off for this code path. + if chatModelConfig != nil { + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + choices, ok := result["choices"].([]interface{}) + if !ok || len(choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid choice format") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + + content, ok := messageMap["content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + + emptyReason := "" + return &ChatResponse{ + Answer: &content, + ReasonContent: &emptyReason, + }, nil +} + +// ChatStreamlyWithSender sends messages and streams the response via the +// sender function. The Mistral SSE stream uses the same shape as OpenAI: +// "data:" lines carrying JSON events, with a final "[DONE]" line. +func (m *MistralModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + if sender == nil { + return fmt.Errorf("sender is required") + } + + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return fmt.Errorf("api key is required") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := m.baseURLForRegion(region) + if err != nil { + return err + } + url := fmt.Sprintf("%s/%s", baseURL, m.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": true, + } + + if chatModelConfig != nil { + // Refuse to run if the caller explicitly asked for stream=false. + // The body of this method only knows how to read SSE, so a + // non-SSE JSON response would be parsed as if it were a stream + // and produce no chunks. Better to fail clearly. + if chatModelConfig.Stream != nil && !*chatModelConfig.Stream { + return fmt.Errorf("stream must be true in ChatStreamlyWithSender") + } + + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + // Use an explicit background context. SSE streams are long-lived + // so we do not attach a hard deadline here; the transport's + // ResponseHeaderTimeout caps the connection-establishment phase. + req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := m.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // SSE parsing: bump the scanner buffer from the 64KB default to 1MB + // so we never silently truncate a long data: line. + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + sawTerminal := false + for scanner.Scan() { + line := scanner.Text() + + if !strings.HasPrefix(line, "data:") { + continue + } + + data := strings.TrimSpace(line[5:]) + + if data == "[DONE]" { + sawTerminal = true + break + } + + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + content, ok := delta["content"].(string) + if ok && content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + + finishReason, ok := firstChoice["finish_reason"].(string) + if ok && finishReason != "" { + sawTerminal = true + break + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to scan response body: %w", err) + } + if !sawTerminal { + return fmt.Errorf("mistral: stream ended before [DONE] or finish_reason") + } + + endOfStream := "[DONE]" + if err := sender(&endOfStream, nil); err != nil { + return err + } + + return nil +} + +type mistralEmbeddingData struct { + Embedding []float64 `json:"embedding"` + Object string `json:"object"` + Index int `json:"index"` +} + +type mistralEmbeddingResponse struct { + Data []mistralEmbeddingData `json:"data"` + Model string `json:"model"` + Object string `json:"object"` +} + +// Embed turns a list of texts into embedding vectors using the +// Mistral /v1/embeddings endpoint (mistral-embed). The output has +// one vector per input, in the same order the inputs were given. +func (m *MistralModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := m.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, m.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Mistral embeddings API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed mistralEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Reorder the returned vectors by their reported index so the output + // always lines up with the input texts, even if the upstream API ever + // returns items out of order. A nil slot at the end indicates the + // upstream did not return an embedding for that input. + embeddings := make([]EmbeddingData, len(texts)) + filled := make([]bool, len(texts)) + for _, item := range parsed.Data { + if item.Index < 0 || item.Index >= len(texts) { + return nil, fmt.Errorf("mistral: response index %d out of range for %d inputs", item.Index, len(texts)) + } + if filled[item.Index] { + // A malformed response that repeats the same index would + // silently overwrite the earlier vector. Fail loudly so + // the caller never uses ambiguous output. + return nil, fmt.Errorf("mistral: duplicate embedding index %d in response", item.Index) + } + embeddings[item.Index] = EmbeddingData{ + Embedding: item.Embedding, + Index: item.Index, + } + filled[item.Index] = true + } + for i, ok := range filled { + if !ok { + return nil, fmt.Errorf("mistral: missing embedding for input index %d", i) + } + } + + return embeddings, nil +} + +// ListModels returns the list of model ids visible to the API key. +func (m *MistralModel) ListModels(apiConfig *APIConfig) ([]string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := m.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, m.URLSuffix.Models) + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + data, ok := result["data"].([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid models list format") + } + + models := make([]string, 0) + for _, model := range data { + modelMap, ok := model.(map[string]interface{}) + if !ok { + continue + } + modelName, ok := modelMap["id"].(string) + if !ok { + continue + } + models = append(models, modelName) + } + + return models, nil +} + +// Balance is not exposed by the Mistral API, so this returns "no such method". +func (m *MistralModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("no such method") +} + +// CheckConnection runs a lightweight ListModels call to verify the API key. +func (m *MistralModel) CheckConnection(apiConfig *APIConfig) error { + _, err := m.ListModels(apiConfig) + if err != nil { + return err + } + return nil +} + +// Rerank calculates similarity scores between query and documents. Mistral +// does not expose a public rerank API, so this returns "no such method". +func (m *MistralModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("no such method") +} diff --git a/internal/entity/models/mistral_test.go b/internal/entity/models/mistral_test.go new file mode 100644 index 0000000000..dc7f318e14 --- /dev/null +++ b/internal/entity/models/mistral_test.go @@ -0,0 +1,574 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" +) + +// newMistralServer stands up an httptest server that asserts the +// request shape and lets the caller decide what to return. +func newMistralServer(t *testing.T, expectedPath string, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != expectedPath { + t.Errorf("expected path=%s, got %s", expectedPath, r.URL.Path) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("expected Authorization=Bearer test-key, got %q", got) + return + } + if r.Method == http.MethodPost { + if got := r.Header.Get("Content-Type"); got != "application/json" { + t.Errorf("expected Content-Type=application/json, got %q", got) + return + } + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("failed to read body: %v", err) + return + } + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("invalid JSON body: %v\n%s", err, string(raw)) + return + } + handler(t, body, w) + return + } + // GET path: no body + handler(t, nil, w) + })) +} + +func newMistralForTest(baseURL string) *MistralModel { + return NewMistralModel( + map[string]string{"default": baseURL}, + URLSuffix{ + Chat: "chat/completions", + Models: "models", + Embedding: "embeddings", + }, + ) +} + +func TestMistralName(t *testing.T) { + m := newMistralForTest("http://unused") + if got := m.Name(); got != "mistral" { + t.Errorf("Name()=%q, want %q", got, "mistral") + } +} + +func TestMistralChatHappyPath(t *testing.T) { + srv := newMistralServer(t, "/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "mistral-large-latest" { + t.Errorf("expected model=mistral-large-latest, got %v", body["model"]) + } + if body["stream"] != false { + t.Errorf("expected stream=false, got %v", body["stream"]) + } + msgs, ok := body["messages"].([]interface{}) + if !ok || len(msgs) != 1 { + t.Errorf("expected 1 message, got %v", body["messages"]) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{ + {"message": map[string]interface{}{"content": "pong"}}, + }, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + resp, err := m.ChatWithMessages("mistral-large-latest", []Message{ + {Role: "user", Content: "ping"}, + }, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if resp.Answer == nil || *resp.Answer != "pong" { + t.Errorf("answer=%v, want pong", resp.Answer) + } + if resp.ReasonContent == nil || *resp.ReasonContent != "" { + t.Errorf("expected empty reason content, got %v", resp.ReasonContent) + } +} + +func TestMistralChatPropagatesConfig(t *testing.T) { + srv := newMistralServer(t, "/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["max_tokens"] != float64(64) { + t.Errorf("max_tokens=%v want 64", body["max_tokens"]) + } + if body["temperature"] != 0.3 { + t.Errorf("temperature=%v want 0.3", body["temperature"]) + } + if body["top_p"] != 0.9 { + t.Errorf("top_p=%v want 0.9", body["top_p"]) + } + stop, ok := body["stop"].([]interface{}) + if !ok || len(stop) != 1 || stop[0] != "END" { + t.Errorf("stop=%v want [END]", body["stop"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{"message": map[string]interface{}{"content": "ok"}}}, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + mt := 64 + temp := 0.3 + topP := 0.9 + stop := []string{"END"} + _, err := m.ChatWithMessages("mistral-large-latest", []Message{{Role: "user", Content: "ping"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{MaxTokens: &mt, Temperature: &temp, TopP: &topP, Stop: &stop}, + ) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } +} + +func TestMistralChatRequiresAPIKey(t *testing.T) { + m := newMistralForTest("http://unused") + _, err := m.ChatWithMessages("mistral-large-latest", []Message{{Role: "user", Content: "x"}}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("expected api-key error, got %v", err) + } + emptyKey := "" + _, err = m.ChatWithMessages("mistral-large-latest", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &emptyKey}, nil) + if err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("empty key: expected api-key error, got %v", err) + } +} + +func TestMistralChatRequiresMessages(t *testing.T) { + m := newMistralForTest("http://unused") + apiKey := "test-key" + _, err := m.ChatWithMessages("mistral-large-latest", nil, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "messages is empty") { + t.Errorf("expected messages-empty error, got %v", err) + } +} + +func TestMistralChatRejectsHTTPError(t *testing.T) { + srv := newMistralServer(t, "/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + _, err := m.ChatWithMessages("mistral-large-latest", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "401") { + t.Errorf("expected 401 propagated, got %v", err) + } +} + +func TestMistralChatFallsBackToDefaultOnEmptyRegion(t *testing.T) { + // Empty *Region pointer must fall back to the "default" entry, not + // be treated as an explicit "" region (which would miss the lookup). + srv := newMistralServer(t, "/chat/completions", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{"message": map[string]interface{}{"content": "ok"}}}, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + emptyRegion := "" + _, err := m.ChatWithMessages("mistral-large-latest", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey, Region: &emptyRegion}, nil) + if err != nil { + t.Errorf("empty Region: expected fallback to default, got %v", err) + } +} + +func TestMistralListModelsFallsBackToDefaultOnEmptyRegion(t *testing.T) { + srv := newMistralServer(t, "/models", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{"data": []map[string]interface{}{{"id": "x"}}}) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + emptyRegion := "" + if _, err := m.ListModels(&APIConfig{ApiKey: &apiKey, Region: &emptyRegion}); err != nil { + t.Errorf("empty Region: expected fallback to default, got %v", err) + } +} + +func TestMistralStreamRequiresSender(t *testing.T) { + m := newMistralForTest("http://unused") + apiKey := "test-key" + err := m.ChatStreamlyWithSender("mistral-large-latest", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil, nil) + if err == nil || !strings.Contains(err.Error(), "sender is required") { + t.Errorf("expected sender-required error, got %v", err) + } +} + +func TestMistralChatRejectsUnknownRegion(t *testing.T) { + m := newMistralForTest("http://unused") + apiKey := "test-key" + region := "eu" + _, err := m.ChatWithMessages("mistral-large-latest", []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey, Region: ®ion}, nil) + if err == nil || !strings.Contains(err.Error(), "no base URL configured for region") { + t.Errorf("expected region error, got %v", err) + } +} + +func TestMistralStreamHappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/chat/completions" { + t.Errorf("path=%s", r.URL.Path) + return + } + raw, _ := io.ReadAll(r.Body) + var body map[string]interface{} + _ = json.Unmarshal(raw, &body) + if body["stream"] != true { + t.Errorf("expected stream=true, got %v", body["stream"]) + } + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + // Two content chunks then finish_reason terminator, then [DONE]. + _, _ = io.WriteString(w, + `data: {"choices":[{"delta":{"content":"Hello "}}]}`+"\n"+ + `data: {"choices":[{"delta":{"content":"world"}}]}`+"\n"+ + `data: {"choices":[{"delta":{},"finish_reason":"stop"}]}`+"\n"+ + `data: [DONE]`+"\n", + ) + })) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + var chunks []string + var sawDone int32 + err := m.ChatStreamlyWithSender("mistral-large-latest", + []Message{{Role: "user", Content: "hi"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(content *string, _ *string) error { + if content == nil { + return nil + } + if *content == "[DONE]" { + atomic.StoreInt32(&sawDone, 1) + return nil + } + chunks = append(chunks, *content) + return nil + }, + ) + if err != nil { + t.Fatalf("stream: %v", err) + } + if strings.Join(chunks, "") != "Hello world" { + t.Errorf("chunks=%v want [\"Hello \" \"world\"]", chunks) + } + if atomic.LoadInt32(&sawDone) != 1 { + t.Error("expected sender to receive [DONE] sentinel") + } +} + +func TestMistralStreamRejectsExplicitFalse(t *testing.T) { + m := newMistralForTest("http://unused") + apiKey := "test-key" + stream := false + err := m.ChatStreamlyWithSender("mistral-large-latest", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{Stream: &stream}, + func(*string, *string) error { return nil }, + ) + if err == nil || !strings.Contains(err.Error(), "stream must be true") { + t.Errorf("expected stream-true guard, got %v", err) + } +} + +func TestMistralStreamFailsWithoutTerminal(t *testing.T) { + // Body closes before [DONE] or a finish_reason -> driver must complain + // instead of pretending the stream finished cleanly. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `data: {"choices":[{"delta":{"content":"half"}}]}`+"\n") + })) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + err := m.ChatStreamlyWithSender("mistral-large-latest", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(*string, *string) error { return nil }, + ) + if err == nil || !strings.Contains(err.Error(), "stream ended before") { + t.Errorf("expected stream-truncation error, got %v", err) + } +} + +func TestMistralListModelsHappyPath(t *testing.T) { + srv := newMistralServer(t, "/models", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"id": "mistral-large-latest"}, + {"id": "mistral-small-latest"}, + {"id": "mistral-embed"}, + }, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + ids, err := m.ListModels(&APIConfig{ApiKey: &apiKey}) + if err != nil { + t.Fatalf("ListModels: %v", err) + } + if len(ids) != 3 || ids[0] != "mistral-large-latest" || ids[2] != "mistral-embed" { + t.Errorf("ids=%v, want [mistral-large-latest mistral-small-latest mistral-embed]", ids) + } +} + +func TestMistralListModelsRequiresAPIKey(t *testing.T) { + m := newMistralForTest("http://unused") + if _, err := m.ListModels(&APIConfig{}); err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("expected api-key error, got %v", err) + } +} + +func TestMistralCheckConnectionDelegatesToListModels(t *testing.T) { + // 200 -> CheckConnection succeeds; 401 -> CheckConnection propagates. + okSrv := newMistralServer(t, "/models", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{"data": []map[string]interface{}{{"id": "x"}}}) + }) + defer okSrv.Close() + failSrv := newMistralServer(t, "/models", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusUnauthorized) + }) + defer failSrv.Close() + + apiKey := "test-key" + mOK := newMistralForTest(okSrv.URL) + if err := mOK.CheckConnection(&APIConfig{ApiKey: &apiKey}); err != nil { + t.Errorf("CheckConnection(ok): %v", err) + } + mFail := newMistralForTest(failSrv.URL) + if err := mFail.CheckConnection(&APIConfig{ApiKey: &apiKey}); err == nil { + t.Error("CheckConnection(fail): expected error, got nil") + } +} + +func TestMistralBalanceReturnsNoSuchMethod(t *testing.T) { + m := newMistralForTest("http://unused") + _, err := m.Balance(&APIConfig{}) + if err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Balance: expected 'no such method', got %v", err) + } +} + +func TestMistralRerankReturnsNoSuchMethod(t *testing.T) { + m := newMistralForTest("http://unused") + q := "mistral-large-latest" + _, err := m.Rerank(&q, "what is rag?", []string{"a", "b"}, &APIConfig{}, &RerankConfig{TopN: 2}) + if err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Rerank: expected 'no such method', got %v", err) + } +} + +func TestMistralEmbedHappyPath(t *testing.T) { + srv := newMistralServer(t, "/embeddings", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "mistral-embed" { + t.Errorf("model=%v want mistral-embed", body["model"]) + } + inputs, ok := body["input"].([]interface{}) + if !ok || len(inputs) != 3 { + t.Errorf("input=%v want 3-element array", body["input"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{0.1, 0.2}, "index": 0}, + {"embedding": []float64{0.3, 0.4}, "index": 1}, + {"embedding": []float64{0.5, 0.6}, "index": 2}, + }, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + model := "mistral-embed" + vecs, err := m.Embed(&model, []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Embed: %v", err) + } + if len(vecs) != 3 { + t.Fatalf("len(vecs)=%d want 3", len(vecs)) + } + if vecs[1].Embedding[0] != 0.3 || vecs[1].Index != 1 { + t.Errorf("vecs[1]=%+v want {Embedding:[0.3 0.4] Index:1}", vecs[1]) + } +} + +func TestMistralEmbedReordersByIndex(t *testing.T) { + // Upstream returns the three vectors in shuffled order. The driver + // must reorder them so the slot at position i corresponds to input i. + srv := newMistralServer(t, "/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{2}, "index": 2}, + {"embedding": []float64{0}, "index": 0}, + {"embedding": []float64{1}, "index": 1}, + }, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + model := "mistral-embed" + vecs, err := m.Embed(&model, []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Embed: %v", err) + } + for i, v := range vecs { + if v.Index != i || v.Embedding[0] != float64(i) { + t.Errorf("slot %d = %+v, want Embedding=[%d] Index=%d", i, v, i, i) + } + } +} + +func TestMistralEmbedEmptyInputShortCircuits(t *testing.T) { + // Empty input must NOT make an HTTP call; the test fails the request + // rather than the assertion if it does. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + t.Error("Embed([]) made an unexpected HTTP call") + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + model := "mistral-embed" + vecs, err := m.Embed(&model, []string{}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Embed([]): %v", err) + } + if len(vecs) != 0 { + t.Errorf("len(vecs)=%d want 0", len(vecs)) + } +} + +func TestMistralEmbedRequiresAPIKey(t *testing.T) { + m := newMistralForTest("http://unused") + model := "mistral-embed" + _, err := m.Embed(&model, []string{"a"}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("expected api-key error, got %v", err) + } +} + +func TestMistralEmbedRequiresModelName(t *testing.T) { + m := newMistralForTest("http://unused") + apiKey := "test-key" + _, err := m.Embed(nil, []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Errorf("expected model-name error, got %v", err) + } + empty := "" + _, err = m.Embed(&empty, []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Errorf("empty model: expected model-name error, got %v", err) + } +} + +func TestMistralEmbedRejectsDuplicateIndex(t *testing.T) { + // A malformed upstream that repeats data[*].index would silently + // overwrite the earlier vector; the driver must fail loudly instead. + srv := newMistralServer(t, "/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{1}, "index": 0}, + {"embedding": []float64{2}, "index": 0}, + }, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + model := "mistral-embed" + _, err := m.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "duplicate embedding index 0") { + t.Errorf("expected duplicate-index error, got %v", err) + } +} + +func TestMistralEmbedRejectsOutOfRangeIndex(t *testing.T) { + srv := newMistralServer(t, "/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{1}, "index": 7}, // out of range for 2-input request + }, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + model := "mistral-embed" + _, err := m.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "out of range") { + t.Errorf("expected out-of-range error, got %v", err) + } +} + +func TestMistralEmbedRejectsMissingSlot(t *testing.T) { + // Upstream returns only one of the two requested embeddings. + srv := newMistralServer(t, "/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{1}, "index": 0}, + }, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + model := "mistral-embed" + _, err := m.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "missing embedding for input index 1") { + t.Errorf("expected missing-embedding error for slot 1, got %v", err) + } +} + +func TestMistralEmbedRejectsHTTPError(t *testing.T) { + srv := newMistralServer(t, "/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + model := "mistral-embed" + _, err := m.Embed(&model, []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "Mistral embeddings API error") { + t.Errorf("expected Mistral embeddings API error, got %v", err) + } +}