From 6bfe0f9a1045619df362a0753c8627f1b96f7e4f Mon Sep 17 00:00:00 2001 From: Panda Dev <56657208+pandadev66@users.noreply.github.com> Date: Sun, 10 May 2026 04:31:37 +0200 Subject: [PATCH] Go: implement Encode (embeddings) in OpenAI driver (#14630) ### What problem does this PR solve? The OpenAI Go driver landed in #14605 with chat, list models, and check connection. Encode was left as a stub that returns \`not implemented\`. \`conf/models/openai.json\` already lists three embedding models out of the box: - text-embedding-ada-002 - text-embedding-3-small - text-embedding-3-large So a tenant who picked one of these in the Go layer could not actually run an embedding call. This PR fills the gap. ### What this PR includes - \`conf/models/openai.json\`: add \`\"embedding\": \"embeddings\"\` under \`url_suffix\` so the driver can build the URL from config. This matches the \`URLSuffix.Embedding\` field used by other drivers (siliconflow, zhipu-ai). - \`internal/entity/models/openai.go\`: replace the Encode stub with a real implementation that POSTs to \`/v1/embeddings\`. Adds a small local response type \`openaiEmbeddingResponse\`. No factory change. No interface change. ### How the implementation works - Validate \`apiConfig\` and the API key, validate the model name. Use the existing \`baseURLForRegion\` helper so an unknown region fails fast with a clear error. - Wrap the request with \`context.WithTimeout(nonStreamCallTimeout)\` so the call has a clear deadline. Same pattern as \`ChatWithMessages\` and \`ListModels\` already use in this file. - Send all input texts in one request. The OpenAI API accepts the \`input\` field as an array. - Parse \`data[*].embedding\` and copy each slice into a \`[][]float64\` indexed by \`data[*].index\` so the output order matches the input order even if the API returns items in a different order. - Handle both \`float64\` and \`float32\` element types, the way the SiliconFlow driver does. - An empty input slice returns \`[][]float64{}\` with no HTTP call. - Non-200 responses propagate the upstream status line and body. - A final pass checks that every input slot got a vector. If any slot is still nil, return a clear error so the caller does not silently use a zero vector. ### Type of change - [x] New Feature (non-breaking change which adds functionality) ### How was this tested? - \`go build ./internal/entity/models/...\` in a clean go 1.25 image (the go.mod minimum) returns exit 0. - The full method set on \`OpenAIModel\` still matches the \`ModelDriver\` interface. - Pattern parity with the existing SiliconFlow Encode implementation (\`internal/entity/models/siliconflow.go\`). Closes #14629 --------- Co-authored-by: Jin Hai --- conf/models/openai.json | 3 +- internal/entity/models/factory.go | 2 + internal/entity/models/openai.go | 113 ++++++++++++++++++++++++++++-- 3 files changed, 112 insertions(+), 6 deletions(-) diff --git a/conf/models/openai.json b/conf/models/openai.json index 696c6f93b3..c78a82b4c2 100644 --- a/conf/models/openai.json +++ b/conf/models/openai.json @@ -5,7 +5,8 @@ }, "url_suffix": { "chat": "chat/completions", - "models": "models" + "models": "models", + "embedding": "embeddings" }, "class": "gpt", "models": [ diff --git a/internal/entity/models/factory.go b/internal/entity/models/factory.go index f4b64271f4..8475049c5b 100644 --- a/internal/entity/models/factory.go +++ b/internal/entity/models/factory.go @@ -57,6 +57,8 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string return NewXAIModel(baseURL, urlSuffix), nil case "lmstudio": return NewLmStudioModel(baseURL, urlSuffix), nil + case "openai": + return NewOpenAIModel(baseURL, urlSuffix), nil case "nvidia": return NewNvidiaModel(baseURL, urlSuffix), nil case "openrouter": diff --git a/internal/entity/models/openai.go b/internal/entity/models/openai.go index 1adbb35cbc..fcacb6d22b 100644 --- a/internal/entity/models/openai.go +++ b/internal/entity/models/openai.go @@ -403,12 +403,115 @@ func (z *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag return nil } -// Encode encodes a list of texts into embeddings. OpenAI does expose -// embedding endpoints (text-embedding-3-* and text-embedding-ada-002), -// but this initial driver intentionally leaves embedding support -// unimplemented. A follow-up PR can add it. +// openaiEmbeddingResponse is the response shape returned by +// /v1/embeddings. The "index" field gives the position of the embedding +// in the input array, which we use to keep the output order stable +// even if the API returns items in a different order. +type openaiEmbeddingResponse struct { + Data []struct { + Index int `json:"index"` + Embedding []interface{} `json:"embedding"` + } `json:"data"` +} + +// Encode turns a list of texts into embedding vectors using the +// OpenAI /v1/embeddings endpoint (e.g. text-embedding-3-small, +// text-embedding-3-large, text-embedding-ada-002). The output has +// one vector per input, in the same order the inputs were given. func (z *OpenAIModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { - return nil, fmt.Errorf("not implemented") + if len(texts) == 0 { + return [][]float64{}, nil + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := z.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, z.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("OpenAI embeddings API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed openaiEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + embeddings := make([][]float64, len(texts)) + for _, item := range parsed.Data { + if item.Index < 0 || item.Index >= len(texts) { + continue + } + vec := make([]float64, len(item.Embedding)) + for j, v := range item.Embedding { + switch val := v.(type) { + case float64: + vec[j] = val + case float32: + vec[j] = float64(val) + default: + return nil, fmt.Errorf("unexpected embedding value type at item %d index %d", item.Index, j) + } + } + embeddings[item.Index] = vec + } + + for i, vec := range embeddings { + if vec == nil { + return nil, fmt.Errorf("missing embedding for input at index %d", i) + } + } + + return embeddings, nil } // ListModels returns the list of model ids visible to the API key.