From eaa2e46b1e2601584e82c52facc2352c54c62f30 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Mon, 11 May 2026 22:11:06 -1000 Subject: [PATCH] Go: implement Embed (embeddings) in Upstage driver (#14819) ### What problem does this PR solve? The Upstage Go driver landed in #14817 with chat, list models, and check connection. `Embed` was left as a stub that returns `"not implemented"`. This PR fills the gap. Upstage exposes an OpenAI-compatible embeddings endpoint at `https://api.upstage.ai/v1/solar/embeddings` via the `solar-embedding-1-large` family (`solar-embedding-1-large-query` for queries, `solar-embedding-1-large-passage` for passages), and the Python side has had `UpstageEmbed(OpenAIEmbed)` in `rag/llm/embedding_model.py` for a long time targeting this same path. The existing `conf/models/upstage.json` did not list any embedding model out of the box, so a tenant who wanted to use Upstage end to end could not run an embedding call. This PR fills the gap. ### What this PR includes - `conf/models/upstage.json`: add `"embedding": "embeddings"` under `url_suffix` so the driver can build the URL from config (matches the `URLSuffix.Embedding` field already used by openai, mistral, siliconflow, zhipu-ai), and add `solar-embedding-1-large-query` and `solar-embedding-1-large-passage` entries under `models`. - `internal/entity/models/upstage.go`: replace the `Embed` stub with a real implementation that POSTs to `/v1/solar/embeddings`. Adds local response types `upstageEmbeddingData` and `upstageEmbeddingResponse`. No factory change. No interface change. ### How the implementation works - Validate `apiConfig`, the API key, and the model name. Use the existing `baseURLForRegion` helper so an unknown region fails fast with a clear error. - Wrap the request with `context.WithTimeout(nonStreamCallTimeout)` so the call has a clear deadline. Same pattern as `ChatWithMessages` and `ListModels` already use in this file. - Send all input texts in one request. The Upstage API accepts the `input` field as an array. - Parse `data[*].embedding` and copy each slice into a `[]EmbeddingData` indexed by `data[*].index` so the output order matches the input order even if the API returns items in a different order. - An empty input slice returns `[]EmbeddingData{}` with no HTTP call. - Non-200 responses propagate the upstream status line and body. - A final pass checks that every input slot got a vector. If any slot is still empty, return a clear error so the caller does not silently use a zero vector. ### Note on stacking This PR builds on #14817 (the Upstage driver). Until #14817 merges, this PR's diff on GitHub will include both that PR's commits and this one. After #14817 lands on `main`, GitHub will auto-reduce this PR to only the `Embed` changes (one commit, ~119 line diff in `upstage.go` plus ~15 lines in `upstage.json`). ### Type of change - [x] New Feature (non-breaking change which adds functionality) ### How was this tested? - `go build ./internal/entity/models/...` returns exit 0 on go 1.25 (the `go.mod` minimum). - The full method set on `UpstageModel` still matches the `ModelDriver` interface. - Pattern parity with the existing Mistral Embed (`internal/entity/models/mistral.go`) and OpenAI Embed (`internal/entity/models/openai.go`) implementations. Closes #14818 Depends on #14817 Tracking: #14736 --------- Co-authored-by: Jin Hai --- conf/models/upstage.json | 56 +++ internal/entity/models/factory.go | 2 + internal/entity/models/upstage.go | 586 +++++++++++++++++++++++++ internal/entity/models/upstage_test.go | 271 ++++++++++++ 4 files changed, 915 insertions(+) create mode 100644 conf/models/upstage.json create mode 100644 internal/entity/models/upstage.go create mode 100644 internal/entity/models/upstage_test.go diff --git a/conf/models/upstage.json b/conf/models/upstage.json new file mode 100644 index 0000000000..045bcaf693 --- /dev/null +++ b/conf/models/upstage.json @@ -0,0 +1,56 @@ +{ + "name": "Upstage", + "url": { + "default": "https://api.upstage.ai/v1" + }, + "url_suffix": { + "chat": "chat/completions", + "models": "models", + "embedding": "embeddings" + }, + "class": "solar", + "models": [ + { + "name": "solar-pro3", + "max_tokens": 65536, + "model_types": [ + "chat" + ] + }, + { + "name": "solar-pro2", + "max_tokens": 65536, + "model_types": [ + "chat" + ] + }, + { + "name": "solar-pro", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "solar-mini", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "solar-embedding-1-large-query", + "max_tokens": 2000, + "model_types": [ + "embedding" + ] + }, + { + "name": "solar-embedding-1-large-passage", + "max_tokens": 2000, + "model_types": [ + "embedding" + ] + } + ] +} diff --git a/internal/entity/models/factory.go b/internal/entity/models/factory.go index 03a33aaacb..702c6e7045 100644 --- a/internal/entity/models/factory.go +++ b/internal/entity/models/factory.go @@ -73,6 +73,8 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string return NewCoHereModel(baseURL, urlSuffix), nil case "fishaudio": return NewFishAudioModel(baseURL, urlSuffix), nil + case "upstage": + return NewUpstageModel(baseURL, urlSuffix), nil case "stepfun": return NewStepFunModel(baseURL, urlSuffix), nil case "baichuan": diff --git a/internal/entity/models/upstage.go b/internal/entity/models/upstage.go new file mode 100644 index 0000000000..fad7f857ac --- /dev/null +++ b/internal/entity/models/upstage.go @@ -0,0 +1,586 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// UpstageModel implements ModelDriver for Upstage (Solar models). +// +// Upstage exposes an OpenAI-compatible REST API at +// https://api.upstage.ai/v1 (chat completions at /chat/completions, list +// models at /models, embeddings at /embeddings). The wire shape matches +// OpenAI closely enough that the chat path here is a direct port of the +// OpenAI driver. The legacy /v1/solar/* paths still work but the canonical +// base is /v1. +type UpstageModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +// NewUpstageModel creates a new Upstage model instance. +// +// We clone http.DefaultTransport so we keep Go's defaults for +// ProxyFromEnvironment, DialContext (with KeepAlive), HTTP/2, +// TLSHandshakeTimeout, and ExpectContinueTimeout, and only override +// the connection-pool fields we care about. +// +// The Client itself has no Timeout. http.Client.Timeout would also +// cap the time spent reading the response body, which would cut off +// long-lived SSE streams in ChatStreamlyWithSender. Non-streaming +// callers wrap each request with context.WithTimeout instead. +func NewUpstageModel(baseURL map[string]string, urlSuffix URLSuffix) *UpstageModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + transport.ResponseHeaderTimeout = 60 * time.Second + + return &UpstageModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +func (u *UpstageModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewUpstageModel(baseURL, u.URLSuffix) +} + +func (u *UpstageModel) Name() string { + return "upstage" +} + +// baseURLForRegion returns the base URL for the given region, or an +// error if no entry exists. This makes a misconfigured region fail +// fast with a clear message, instead of silently producing a relative +// URL that the HTTP transport then rejects. +func (u *UpstageModel) baseURLForRegion(region string) (string, error) { + base, ok := u.BaseURL[region] + if !ok || base == "" { + return "", fmt.Errorf("upstage: no base URL configured for region %q", region) + } + return base, nil +} + +// ChatWithMessages sends multiple messages with roles and returns the response. +func (u *UpstageModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := u.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, u.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + } + + // Note: do NOT propagate chatModelConfig.Stream into the request body + // here. ChatWithMessages parses a single JSON response, so stream must + // always be off for this code path. + if chatModelConfig != nil { + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + // Upstage Solar reasoning models (solar-pro2 and the upcoming + // solar-pro3) accept reasoning_effort=low|medium|high to trade + // latency for chain-of-thought depth, matching the OpenAI + // o-series shape. ChatConfig.Effort is the canonical carrier. + if chatModelConfig.Effort != nil && *chatModelConfig.Effort != "" { + reqBody["reasoning_effort"] = *chatModelConfig.Effort + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := u.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + choices, ok := result["choices"].([]interface{}) + if !ok || len(choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid choice format") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + + content, ok := messageMap["content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + + // Upstage Solar reasoning models (solar-pro3, solar-pro2 with + // reasoning_effort >= medium) return the chain-of-thought in a + // `reasoning` field on the message. Pass it through when present + // so callers that opted into reasoning can show it. Absent or + // non-string means no reasoning was emitted — leave it empty. + reasonContent := "" + if r, ok := messageMap["reasoning"].(string); ok { + reasonContent = r + } + + return &ChatResponse{ + Answer: &content, + ReasonContent: &reasonContent, + }, nil +} + +// ChatStreamlyWithSender sends messages and streams the response via the +// sender function. The Upstage SSE stream uses the same shape as OpenAI: +// "data:" lines carrying JSON events, with a final "[DONE]" line. +func (u *UpstageModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + if sender == nil { + return fmt.Errorf("sender is required") + } + + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return fmt.Errorf("api key is required") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := u.baseURLForRegion(region) + if err != nil { + return err + } + url := fmt.Sprintf("%s/%s", baseURL, u.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": true, + } + + if chatModelConfig != nil { + // Refuse to run if the caller explicitly asked for stream=false. + // The body of this method only knows how to read SSE, so a + // non-SSE JSON response would be parsed as if it were a stream + // and produce no chunks. Better to fail clearly. + if chatModelConfig.Stream != nil && !*chatModelConfig.Stream { + return fmt.Errorf("stream must be true in ChatStreamlyWithSender") + } + + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + // reasoning_effort: same as the non-streaming path above. + if chatModelConfig.Effort != nil && *chatModelConfig.Effort != "" { + reqBody["reasoning_effort"] = *chatModelConfig.Effort + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + // SSE streams are long-lived. We rely on the transport's + // ResponseHeaderTimeout to cap the connection-establishment phase + // instead of attaching a hard deadline here. + req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := u.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // SSE parsing: bump the scanner buffer from the 64KB default to 1MB + // so we never silently truncate a long data: line. + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + sawTerminal := false + for scanner.Scan() { + line := scanner.Text() + + if !strings.HasPrefix(line, "data:") { + continue + } + + data := strings.TrimSpace(line[5:]) + + if data == "[DONE]" { + sawTerminal = true + break + } + + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + content, ok := delta["content"].(string) + if ok && content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + + finishReason, ok := firstChoice["finish_reason"].(string) + if ok && finishReason != "" { + sawTerminal = true + break + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to scan response body: %w", err) + } + if !sawTerminal { + return fmt.Errorf("upstage: stream ended before [DONE] or finish_reason") + } + + endOfStream := "[DONE]" + if err := sender(&endOfStream, nil); err != nil { + return err + } + + return nil +} + +type upstageEmbeddingData struct { + Embedding []float64 `json:"embedding"` + Object string `json:"object"` + Index int `json:"index"` +} + +type upstageEmbeddingResponse struct { + Data []upstageEmbeddingData `json:"data"` + Model string `json:"model"` + Object string `json:"object"` +} + +// Embed turns a list of texts into embedding vectors using the Upstage +// /v1/solar/embeddings endpoint (solar-embedding-1-large-query for queries, +// solar-embedding-1-large-passage for passages). The output has one vector +// per input, in the same order the inputs were given. +func (u *UpstageModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := u.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, u.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := u.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Upstage embeddings API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed upstageEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Reorder by the reported index so the output always lines up with + // the input texts, even if the upstream API ever returns items out + // of order. A nil slot at the end indicates the upstream did not + // return an embedding for that input. + embeddings := make([]EmbeddingData, len(texts)) + filled := make([]bool, len(texts)) + for _, item := range parsed.Data { + if item.Index < 0 || item.Index >= len(texts) { + return nil, fmt.Errorf("upstage: response index %d out of range for %d inputs", item.Index, len(texts)) + } + if filled[item.Index] { + // A malformed response that repeats the same index would + // silently overwrite the earlier vector. Fail loudly so + // the caller never uses ambiguous output. + return nil, fmt.Errorf("upstage: duplicate embedding index %d in response", item.Index) + } + embeddings[item.Index] = EmbeddingData{ + Embedding: item.Embedding, + Index: item.Index, + } + filled[item.Index] = true + } + for i, ok := range filled { + if !ok { + return nil, fmt.Errorf("upstage: missing embedding for input index %d", i) + } + } + + return embeddings, nil +} + +// ListModels returns the list of model ids visible to the API key. +func (u *UpstageModel) ListModels(apiConfig *APIConfig) ([]string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := u.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, u.URLSuffix.Models) + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := u.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + data, ok := result["data"].([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid models list format") + } + + models := make([]string, 0) + for _, model := range data { + modelMap, ok := model.(map[string]interface{}) + if !ok { + continue + } + modelName, ok := modelMap["id"].(string) + if !ok { + continue + } + models = append(models, modelName) + } + + return models, nil +} + +// Balance is not exposed by the Upstage API, so this returns "no such method". +func (u *UpstageModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("no such method") +} + +// CheckConnection runs a lightweight ListModels call to verify the API key. +func (u *UpstageModel) CheckConnection(apiConfig *APIConfig) error { + _, err := u.ListModels(apiConfig) + if err != nil { + return err + } + return nil +} + +// Rerank calculates similarity scores between query and documents. Upstage +// does not expose a public rerank API, so this returns "no such method". +func (u *UpstageModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("no such method") +} diff --git a/internal/entity/models/upstage_test.go b/internal/entity/models/upstage_test.go new file mode 100644 index 0000000000..cb651df94a --- /dev/null +++ b/internal/entity/models/upstage_test.go @@ -0,0 +1,271 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newUpstageForTest(baseURL string) *UpstageModel { + return NewUpstageModel( + map[string]string{"default": baseURL}, + URLSuffix{ + Chat: "chat/completions", + Models: "models", + Embedding: "embeddings", + }, + ) +} + +// ---------- reasoning_effort / reasoning field ---------- + +func TestUpstageChatPropagatesReasoningEffort(t *testing.T) { + // Per https://console.upstage.ai/api/docs/for-agents/raw, Upstage + // Solar models accept `reasoning_effort: minimal|low|medium|high`. + // ChatConfig.Effort is the canonical carrier; this test asserts it + // flows into the wire body verbatim. + var seen map[string]interface{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(raw, &seen) + _, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + effort := "high" + _, err := u.ChatWithMessages("solar-pro2", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{Effort: &effort}) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if got, ok := seen["reasoning_effort"].(string); !ok || got != "high" { + t.Errorf("reasoning_effort=%v want \"high\"", seen["reasoning_effort"]) + } +} + +func TestUpstageChatOmitsReasoningEffortWhenUnset(t *testing.T) { + // If the caller does not opt in, the field must NOT be sent. Sending + // "minimal" by default would silently change behavior for downstream + // proxies that treat a present field differently from an absent one. + var seen map[string]interface{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(raw, &seen) + _, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + _, err := u.ChatWithMessages("solar-pro2", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{}, // no Effort + ) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if _, present := seen["reasoning_effort"]; present { + t.Errorf("reasoning_effort should be absent when Effort is unset, got %v", seen["reasoning_effort"]) + } +} + +func TestUpstageStreamPropagatesReasoningEffort(t *testing.T) { + var seen map[string]interface{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(raw, &seen) + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, + `data: {"choices":[{"delta":{"content":"hi"},"finish_reason":"stop"}]}`+"\n"+ + `data: [DONE]`+"\n", + ) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + effort := "medium" + err := u.ChatStreamlyWithSender("solar-pro2", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{Effort: &effort}, + func(*string, *string) error { return nil }, + ) + if err != nil { + t.Fatalf("Stream: %v", err) + } + if got, ok := seen["reasoning_effort"].(string); !ok || got != "medium" { + t.Errorf("stream reasoning_effort=%v want \"medium\"", seen["reasoning_effort"]) + } +} + +func TestUpstageChatExtractsReasoningField(t *testing.T) { + // Per the Upstage docs: when reasoning_effort is high|medium for + // solar-pro3 (or high for solar-pro2), the response's + // choices[0].message includes a `reasoning` field. The driver must + // pass it through as ChatResponse.ReasonContent. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"choices":[{"message":{ + "content":"15% of 80 is **12**.", + "reasoning":"15/100 = 0.15; 0.15 * 80 = 12" + }}]}`) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + resp, err := u.ChatWithMessages("solar-pro3", + []Message{{Role: "user", Content: "What is 15% of 80?"}}, + &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if resp.ReasonContent == nil || *resp.ReasonContent != "15/100 = 0.15; 0.15 * 80 = 12" { + t.Errorf("ReasonContent=%v want the reasoning trace", resp.ReasonContent) + } + if resp.Answer == nil || *resp.Answer != "15% of 80 is **12**." { + t.Errorf("Answer=%v", resp.Answer) + } +} + +func TestUpstageChatHandlesAbsentReasoning(t *testing.T) { + // Models without reasoning (solar-mini, syn-pro) or low-effort + // requests return no `reasoning` field. The driver must leave + // ReasonContent empty without crashing. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + resp, err := u.ChatWithMessages("solar-mini", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if resp.ReasonContent == nil || *resp.ReasonContent != "" { + t.Errorf("ReasonContent=%v want empty string for no-reasoning response", resp.ReasonContent) + } + if resp.Answer == nil || *resp.Answer != "ok" { + t.Errorf("Answer=%v want ok", resp.Answer) + } +} + +// Ensure the same JSON shape used by the maintainer's docs (per +// https://console.upstage.ai/api/chat) round-trips through the request +// body for both streaming and non-streaming paths. +func TestUpstageRequestBodyMatchesSolarAPIShape(t *testing.T) { + var seen map[string]interface{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(raw, &seen) + _, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + mt := 256 + temp := 0.7 + topP := 0.9 + stop := []string{"END"} + effort := "high" + _, err := u.ChatWithMessages("solar-pro2", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{MaxTokens: &mt, Temperature: &temp, TopP: &topP, Stop: &stop, Effort: &effort}) + if err != nil { + t.Fatalf("Chat: %v", err) + } + want := map[string]interface{}{ + "model": "solar-pro2", + "stream": false, + "max_tokens": float64(256), + "temperature": 0.7, + "top_p": 0.9, + "reasoning_effort": "high", + } + for k, v := range want { + if got, ok := seen[k]; !ok { + t.Errorf("missing key %q in body", k) + } else if !strings.HasPrefix(k, "stop") && got != v { + t.Errorf("body[%q]=%v want %v", k, got, v) + } + } + if stopArr, ok := seen["stop"].([]interface{}); !ok || len(stopArr) != 1 || stopArr[0] != "END" { + t.Errorf("body[stop]=%v want [END]", seen["stop"]) + } + if _, ok := seen["messages"].([]interface{}); !ok { + t.Errorf("body[messages] missing or wrong type") + } +} + +// ---------- Embed: duplicate / out-of-range / reorder ---------- + +func TestUpstageEmbedRejectsDuplicateIndex(t *testing.T) { + // A malformed upstream that repeats data[*].index would silently + // overwrite the earlier vector; the driver must fail loudly instead. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"data":[ + {"embedding":[1],"index":0}, + {"embedding":[2],"index":0}]}`) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + model := "solar-embedding-1-large-passage" + _, err := u.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "duplicate embedding index 0") { + t.Errorf("expected duplicate-index error, got %v", err) + } +} + +func TestUpstageEmbedRejectsOutOfRangeIndex(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"data":[{"embedding":[1],"index":7}]}`) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + model := "solar-embedding-1-large-passage" + _, err := u.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "out of range") { + t.Errorf("expected out-of-range error, got %v", err) + } +} + +func TestUpstageEmbedHappyPathReordersByIndex(t *testing.T) { + // Upstream returns vectors in shuffled order; driver must realign. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"data":[ + {"embedding":[2],"index":2}, + {"embedding":[0],"index":0}, + {"embedding":[1],"index":1}]}`) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + model := "solar-embedding-1-large-passage" + vecs, err := u.Embed(&model, []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Embed: %v", err) + } + for i, v := range vecs { + if v.Index != i || v.Embedding[0] != float64(i) { + t.Errorf("slot %d = %+v, want index=%d embedding=[%d]", i, v, i, i) + } + } +}