mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-06 03:18:36 +08:00
Go: implement Encode (embeddings) in NVIDIA driver (#14700)
### What problem does this PR solve?

The NVIDIA Go driver in `internal/entity/models/nvidia.go` shipped with a stub `Encode` method that returned `no such method`. `conf/models/nvidia.json` already lists `nvidia/llama-3.2-nemoretriever-1b-vlm-embed-v1` as an embedding model, but the conf had no `embedding` URL suffix, so the picker had nothing wired even if `Encode` worked. A tenant who wanted to use NVIDIA NIM for chat (already working) and embeddings from a single provider could not, even though the upstream endpoint is public at `https://integrate.api.nvidia.com/v1/embeddings` and uses an OpenAI-compatible request body extended with the NVIDIA-specific `input_type` and `truncate` fields.

Several other Go drivers already implement `Encode` (siliconflow, zhipu-ai, aliyun), so the interface and the pattern are well-established. This PR fills the gap.

### What this PR includes

* `conf/models/nvidia.json`: declare the `embedding` URL suffix alongside the existing `chat` and `models` entries. The embedding model entry was already present, so no model addition is needed.
* `internal/entity/models/nvidia.go`: replace the `Encode` stub with a real implementation. Adds a small local response type that matches the OpenAI-compatible shape NVIDIA NIM returns. No factory change. No interface change.

### How the driver works

* Validates `apiConfig` and the API key, validates the model name, resolves the region with a default fallback (matching the pattern the merged `ListModels` and `CheckConnection` paths in this driver already use), and builds the URL from `BaseURL[region] + URLSuffix.Embedding`.
* Sends all input texts in one request as the `input` array, with the NVIDIA-specific `input_type: "query"`, `encoding_format: "float"`, and `truncate: "END"` fields, mirroring the Python `NvidiaEmbed` reference.
* Parses `data[*].embedding` and copies each slice into `[][]float64` indexed by `data[*].index` so the output order matches the input order even if the API returns items in a different order.
* Handles both `float64` and `float32` element types.
* Empty input returns `[][]float64{}` with no HTTP call.
* Non-200 responses propagate the upstream status line and body.
* A final pass checks every input slot got a vector and returns a clear error if any slot is still nil.
* Per-call 30s context deadline so a slow call cannot block forever.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

### How was this tested?

* `go build ./internal/entity/models/...` returns exit 0.
* `go vet ./internal/entity/models/...` is clean.
* `gofmt -l internal/entity/models/nvidia.go` is clean.
* The full method set on `NvidiaModel` still matches the `ModelDriver` interface.
* Pattern parity with the just-merged Aliyun `Encode` (#14647).

Closes #14699
This commit is contained in:
@@ -5,7 +5,8 @@
|
||||
},
|
||||
"url_suffix": {
|
||||
"chat": "chat/completions",
|
||||
"models": "models"
|
||||
"models": "models",
|
||||
"embedding": "embeddings"
|
||||
},
|
||||
"class": "nvidia",
|
||||
"models": [
|
||||
@@ -16,6 +17,13 @@
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "baai/bge-m3",
|
||||
"max_tokens": 8192,
|
||||
"model_types": [
|
||||
"embedding"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "bytedance/seed-oss-36b-instruct",
|
||||
"max_tokens": 32768,
|
||||
@@ -295,6 +303,13 @@
|
||||
"embedding"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "nvidia/llama-3.2-nv-embedqa-1b-v2",
|
||||
"max_tokens": 8192,
|
||||
"model_types": [
|
||||
"embedding"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "nvidia/llama-3.3-nemotron-super-49b-v1",
|
||||
"max_tokens": 131072,
|
||||
@@ -360,6 +375,27 @@
|
||||
"chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "nvidia/nv-embed-v1",
|
||||
"max_tokens": 32768,
|
||||
"model_types": [
|
||||
"embedding"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "nvidia/nv-embedqa-e5-v5",
|
||||
"max_tokens": 512,
|
||||
"model_types": [
|
||||
"embedding"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "nvidia/nv-embedqa-mistral-7b-v2",
|
||||
"max_tokens": 512,
|
||||
"model_types": [
|
||||
"embedding"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "nvidia/nvidia-nemotron-nano-9b-v2",
|
||||
"max_tokens": 131072,
|
||||
@@ -424,6 +460,13 @@
|
||||
"clear_thinking": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "snowflake/arctic-embed-l",
|
||||
"max_tokens": 512,
|
||||
"model_types": [
|
||||
"embedding"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "z-ai/glm-5",
|
||||
"max_tokens": 131072,
|
||||
|
||||
@@ -3,6 +3,7 @@ package models
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -329,8 +330,114 @@ func (n *NvidiaModel) ChatStreamlyWithSender(modelName string, messages []Messag
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// nvidiaEmbeddingResponse is the subset of the embeddings response body that
// Encode reads: a "data" array whose items carry an "index" and an "embedding".
type nvidiaEmbeddingResponse struct {
	Data []struct {
		// Index is the position reported by the API for this item; Encode
		// uses it as the slot in the output slice.
		Index int `json:"index"`
		// Embedding holds the raw JSON array elements. It is decoded as
		// []interface{} so that non-numeric elements (for example null)
		// can be detected and rejected by the caller.
		Embedding []interface{} `json:"embedding"`
	} `json:"data"`
}
|
||||
|
||||
func (n NvidiaModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("no such method")
|
||||
if len(texts) == 0 {
|
||||
return [][]float64{}, nil
|
||||
}
|
||||
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return nil, fmt.Errorf("api key is required")
|
||||
}
|
||||
|
||||
if modelName == nil || *modelName == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
|
||||
region := "default"
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
baseURL := n.BaseURL[region]
|
||||
if baseURL == "" {
|
||||
baseURL = n.BaseURL["default"]
|
||||
}
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("nvidia: no base URL configured for region %q", region)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), n.URLSuffix.Embedding)
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": *modelName,
|
||||
"input": texts,
|
||||
"input_type": "query",
|
||||
"encoding_format": "float",
|
||||
"truncate": "END",
|
||||
}
|
||||
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
|
||||
reqBody["dimensions"] = embeddingConfig.Dimension
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := n.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Nvidia embeddings API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var parsed nvidiaEmbeddingResponse
|
||||
if err = json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
embeddings := make([][]float64, len(texts))
|
||||
for _, item := range parsed.Data {
|
||||
if item.Index < 0 || item.Index >= len(texts) {
|
||||
return nil, fmt.Errorf("unexpected embedding index %d for %d inputs", item.Index, len(texts))
|
||||
}
|
||||
vec := make([]float64, len(item.Embedding))
|
||||
for j, v := range item.Embedding {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
vec[j] = val
|
||||
case float32:
|
||||
vec[j] = float64(val)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected embedding value type at item %d index %d", item.Index, j)
|
||||
}
|
||||
}
|
||||
embeddings[item.Index] = vec
|
||||
}
|
||||
|
||||
for i, vec := range embeddings {
|
||||
if vec == nil {
|
||||
return nil, fmt.Errorf("missing embedding for input at index %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
func (n NvidiaModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
|
||||
|
||||
Reference in New Issue
Block a user