Go: implement Encode (embeddings) in OpenAI driver (#14630)

### What problem does this PR solve?

The OpenAI Go driver landed in #14605 with chat, list models, and check
connection. Encode was left as a stub that returns \`not implemented\`.

\`conf/models/openai.json\` already lists three embedding models out of
the box:

- text-embedding-ada-002
- text-embedding-3-small
- text-embedding-3-large

So a tenant who picked one of these in the Go layer could not actually
run an embedding call. This PR fills the gap.

### What this PR includes

- \`conf/models/openai.json\`: add \`\"embedding\": \"embeddings\"\`
under \`url_suffix\` so the driver can build the URL from config. This
matches the \`URLSuffix.Embedding\` field used by other drivers
(siliconflow, zhipu-ai).
- \`internal/entity/models/openai.go\`: replace the Encode stub with a
real implementation that POSTs to \`/v1/embeddings\`. Adds a small local
response type \`openaiEmbeddingResponse\`.

No factory change. No interface change.

### How the implementation works

- Validate \`apiConfig\` and the API key, validate the model name. Use
the existing \`baseURLForRegion\` helper so an unknown region fails fast
with a clear error.
- Wrap the request with \`context.WithTimeout(nonStreamCallTimeout)\` so
the call has a clear deadline. Same pattern as \`ChatWithMessages\` and
\`ListModels\` already use in this file.
- Send all input texts in one request. The OpenAI API accepts the
\`input\` field as an array.
- Parse \`data[*].embedding\` and copy each slice into a \`[][]float64\`
indexed by \`data[*].index\` so the output order matches the input order
even if the API returns items in a different order.
- Handle both \`float64\` and \`float32\` element types, the way the
SiliconFlow driver does.
- An empty input slice returns \`[][]float64{}\` with no HTTP call.
- Non-200 responses propagate the upstream status line and body.
- A final pass checks that every input slot got a vector. If any slot is
still nil, return a clear error so the caller does not silently use a
zero vector.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

### How was this tested?

- \`go build ./internal/entity/models/...\` in a clean go 1.25 image
(the go.mod minimum) returns exit 0.
- The full method set on \`OpenAIModel\` still matches the
\`ModelDriver\` interface.
- Pattern parity with the existing SiliconFlow Encode implementation
(\`internal/entity/models/siliconflow.go\`).

Closes #14629

---------

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Panda Dev
2026-05-10 04:31:37 +02:00
committed by GitHub
parent 048ec2fc5c
commit 6bfe0f9a10
3 changed files with 112 additions and 6 deletions

View File

@@ -5,7 +5,8 @@
},
"url_suffix": {
"chat": "chat/completions",
"models": "models"
"models": "models",
"embedding": "embeddings"
},
"class": "gpt",
"models": [

View File

@@ -57,6 +57,8 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string
return NewXAIModel(baseURL, urlSuffix), nil
case "lmstudio":
return NewLmStudioModel(baseURL, urlSuffix), nil
case "openai":
return NewOpenAIModel(baseURL, urlSuffix), nil
case "nvidia":
return NewNvidiaModel(baseURL, urlSuffix), nil
case "openrouter":

View File

@@ -403,12 +403,115 @@ func (z *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag
return nil
}
// Encode encodes a list of texts into embeddings. OpenAI does expose
// embedding endpoints (text-embedding-3-* and text-embedding-ada-002),
// but this initial driver intentionally leaves embedding support
// unimplemented. A follow-up PR can add it.
// openaiEmbeddingResponse is the response shape returned by
// /v1/embeddings. The "index" field gives the position of the embedding
// in the input array, which we use to keep the output order stable
// even if the API returns items in a different order.
type openaiEmbeddingResponse struct {
Data []struct {
Index int `json:"index"`
Embedding []interface{} `json:"embedding"`
} `json:"data"`
}
// Encode turns a list of texts into embedding vectors using the
// OpenAI /v1/embeddings endpoint (e.g. text-embedding-3-small,
// text-embedding-3-large, text-embedding-ada-002). The output has
// one vector per input, in the same order the inputs were given.
func (z *OpenAIModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
return nil, fmt.Errorf("not implemented")
if len(texts) == 0 {
return [][]float64{}, nil
}
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
return nil, fmt.Errorf("api key is required")
}
if modelName == nil || *modelName == "" {
return nil, fmt.Errorf("model name is required")
}
region := "default"
if apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
baseURL, err := z.baseURLForRegion(region)
if err != nil {
return nil, err
}
url := fmt.Sprintf("%s/%s", baseURL, z.URLSuffix.Embedding)
reqBody := map[string]interface{}{
"model": *modelName,
"input": texts,
}
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
reqBody["dimensions"] = embeddingConfig.Dimension
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
resp, err := z.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("OpenAI embeddings API error: %s, body: %s", resp.Status, string(body))
}
var parsed openaiEmbeddingResponse
if err = json.Unmarshal(body, &parsed); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
embeddings := make([][]float64, len(texts))
for _, item := range parsed.Data {
if item.Index < 0 || item.Index >= len(texts) {
continue
}
vec := make([]float64, len(item.Embedding))
for j, v := range item.Embedding {
switch val := v.(type) {
case float64:
vec[j] = val
case float32:
vec[j] = float64(val)
default:
return nil, fmt.Errorf("unexpected embedding value type at item %d index %d", item.Index, j)
}
}
embeddings[item.Index] = vec
}
for i, vec := range embeddings {
if vec == nil {
return nil, fmt.Errorf("missing embedding for input at index %d", i)
}
}
return embeddings, nil
}
// ListModels returns the list of model ids visible to the API key.