Go: implement Encode (embeddings) in NVIDIA driver (#14700)

### What problem does this PR solve?

The NVIDIA Go driver in `internal/entity/models/nvidia.go` shipped with
a stub `Encode`
method that returned `no such method`. `conf/models/nvidia.json` already
lists
`nvidia/llama-3.2-nemoretriever-1b-vlm-embed-v1` as an embedding model,
but the conf had
no `embedding` URL suffix, so the picker had nothing wired even if
`Encode` worked.

A tenant who wanted to use NVIDIA NIM for chat (already working) and
embeddings from a
single provider could not, even though the upstream endpoint is public
at
`https://integrate.api.nvidia.com/v1/embeddings` and uses an
OpenAI-compatible request
body extended with the NVIDIA-specific `input_type` and `truncate`
fields. Several other
Go drivers already implement `Encode` (siliconflow, zhipu-ai, aliyun),
so the interface
and the pattern are well-established.

This PR fills the gap.

### What this PR includes

* `conf/models/nvidia.json`: declare the `embedding` URL suffix
alongside the existing
`chat` and `models` entries. The embedding model entry was already
present, so no
  model addition is needed.
* `internal/entity/models/nvidia.go`: replace the `Encode` stub with a
real
implementation. Adds a small local response type that matches the
OpenAI-compatible
  shape NVIDIA NIM returns.

No factory change. No interface change.

### How the driver works

* Validates `apiConfig` and the API key, validates the model name,
resolves the region
with a default fallback (matching the pattern the merged `ListModels`
and
`CheckConnection` paths in this driver already use), and builds the URL
from
  `BaseURL[region] + URLSuffix.Embedding`.
* Sends all input texts in one request as the `input` array, with the
NVIDIA-specific `input_type: "query"`, `encoding_format: "float"`, and
`truncate: "END"`
  fields, mirroring the Python `NvidiaEmbed` reference.
* Parses `data[*].embedding` and copies each slice into `[][]float64`
indexed by
`data[*].index` so the output order matches the input order even if the
API returns
  items in a different order.
* Handles both `float64` and `float32` element types.
* Empty input returns `[][]float64{}` with no HTTP call.
* Non-200 responses propagate the upstream status line and body.
* A final pass checks every input slot got a vector and returns a clear
error if any
  slot is still nil.
* Per-call 30s context deadline so a slow call cannot block forever.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

### How was this tested?

* `go build ./internal/entity/models/...` returns exit 0.
* `go vet ./internal/entity/models/...` is clean.
* `gofmt -l internal/entity/models/nvidia.go` is clean.
* The full method set on `NvidiaModel` still matches the `ModelDriver`
interface.
* Pattern parity with the just-merged Aliyun `Encode` (#14647).

Closes #14699
This commit is contained in:
BitToby
2026-05-10 18:50:50 -10:00
committed by GitHub
parent 8ff623fbc4
commit 4b96362092
2 changed files with 152 additions and 2 deletions

View File

@@ -5,7 +5,8 @@
},
"url_suffix": {
"chat": "chat/completions",
"models": "models"
"models": "models",
"embedding": "embeddings"
},
"class": "nvidia",
"models": [
@@ -16,6 +17,13 @@
"chat"
]
},
{
"name": "baai/bge-m3",
"max_tokens": 8192,
"model_types": [
"embedding"
]
},
{
"name": "bytedance/seed-oss-36b-instruct",
"max_tokens": 32768,
@@ -295,6 +303,13 @@
"embedding"
]
},
{
"name": "nvidia/llama-3.2-nv-embedqa-1b-v2",
"max_tokens": 8192,
"model_types": [
"embedding"
]
},
{
"name": "nvidia/llama-3.3-nemotron-super-49b-v1",
"max_tokens": 131072,
@@ -360,6 +375,27 @@
"chat"
]
},
{
"name": "nvidia/nv-embed-v1",
"max_tokens": 32768,
"model_types": [
"embedding"
]
},
{
"name": "nvidia/nv-embedqa-e5-v5",
"max_tokens": 512,
"model_types": [
"embedding"
]
},
{
"name": "nvidia/nv-embedqa-mistral-7b-v2",
"max_tokens": 512,
"model_types": [
"embedding"
]
},
{
"name": "nvidia/nvidia-nemotron-nano-9b-v2",
"max_tokens": 131072,
@@ -424,6 +460,13 @@
"clear_thinking": true
}
},
{
"name": "snowflake/arctic-embed-l",
"max_tokens": 512,
"model_types": [
"embedding"
]
},
{
"name": "z-ai/glm-5",
"max_tokens": 131072,

View File

@@ -3,6 +3,7 @@ package models
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -329,8 +330,114 @@ func (n *NvidiaModel) ChatStreamlyWithSender(modelName string, messages []Messag
return scanner.Err()
}
type nvidiaEmbeddingResponse struct {
Data []struct {
Index int `json:"index"`
Embedding []interface{} `json:"embedding"`
} `json:"data"`
}
func (n NvidiaModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
return nil, fmt.Errorf("no such method")
if len(texts) == 0 {
return [][]float64{}, nil
}
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
return nil, fmt.Errorf("api key is required")
}
if modelName == nil || *modelName == "" {
return nil, fmt.Errorf("model name is required")
}
region := "default"
if apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
baseURL := n.BaseURL[region]
if baseURL == "" {
baseURL = n.BaseURL["default"]
}
if baseURL == "" {
return nil, fmt.Errorf("nvidia: no base URL configured for region %q", region)
}
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), n.URLSuffix.Embedding)
reqBody := map[string]interface{}{
"model": *modelName,
"input": texts,
"input_type": "query",
"encoding_format": "float",
"truncate": "END",
}
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
reqBody["dimensions"] = embeddingConfig.Dimension
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
resp, err := n.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Nvidia embeddings API error: %s, body: %s", resp.Status, string(body))
}
var parsed nvidiaEmbeddingResponse
if err = json.Unmarshal(body, &parsed); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
embeddings := make([][]float64, len(texts))
for _, item := range parsed.Data {
if item.Index < 0 || item.Index >= len(texts) {
return nil, fmt.Errorf("unexpected embedding index %d for %d inputs", item.Index, len(texts))
}
vec := make([]float64, len(item.Embedding))
for j, v := range item.Embedding {
switch val := v.(type) {
case float64:
vec[j] = val
case float32:
vec[j] = float64(val)
default:
return nil, fmt.Errorf("unexpected embedding value type at item %d index %d", item.Index, j)
}
}
embeddings[item.Index] = vec
}
for i, vec := range embeddings {
if vec == nil {
return nil, fmt.Errorf("missing embedding for input at index %d", i)
}
}
return embeddings, nil
}
func (n NvidiaModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {