mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-01 16:25:44 +08:00
Go: implement Encode (embeddings) in vLLM driver (#14688)
### What problem does this PR solve? The vLLM Go driver shipped with a stub \`Encode\` method that returned \`not implemented\`, even though vLLM is one of the most common production-grade self-hosted inference servers and exposes an OpenAI-compatible embeddings endpoint at \`/v1/embeddings\`. Users who self-host \`BAAI/bge-m3\`, \`Qwen3-Embedding-*\`, \`NV-Embed-v2\`, or similar models on vLLM could not run an embedding call through the Go layer. The existing \`ListModels\` already discovers the loaded models, but the embedding path failed because \`Encode\` was a stub. ### What this PR includes - \`conf/models/vllm.json\`: add \`\"embedding\": \"embeddings\"\` under \`url_suffix\` so the driver can build the URL from config. - \`internal/entity/models/vllm.go\`: replace the \`Encode\` stub with a real implementation. Adds a small local response type that matches the OpenAI-compatible shape. No factory change. No interface change. ### How the driver works - Validate the model name. The API key is optional for self-hosted vLLM, so the Authorization header is only set when both \`apiConfig\` and \`ApiKey\` are non-nil and non-empty, the same pattern the recently merged CheckConnection PR (#14614) uses. - Resolve the region with a default fallback. Return a clear "missing base URL" error when the user has not configured the local access address yet. - Use a per-call \`context.WithTimeout(30s)\` and \`http.NewRequestWithContext\`, the same pattern the merged Aliyun Encode (#14647) and in-flight Ollama Encode (#14664) use. - Send \`{model, input: [texts]}\` in one request. - Parse \`data[*].embedding\` and copy each slice into a \`[][]float64\` indexed by \`data[*].index\`, so the output order matches the input order. - Handle both \`float64\` and \`float32\` element types. - Empty input returns \`[][]float64{}\` with no HTTP call. - Length mismatch between input and result, out-of-range index, and any missing slot all return clear errors instead of silent zero vectors. ### Type of change - [x] New Feature (non-breaking change which adds functionality) ### How was this tested? - \`go build ./internal/entity/models/...\` in a clean go 1.25 image returns exit 0. - The full method set on \`VllmModel\` still matches the \`ModelDriver\` interface. - Pattern parity with the merged Aliyun Encode (#14647), the in-flight Ollama Encode (#14664), and the existing SiliconFlow Encode. Closes #14687
This commit is contained in:
@@ -2,7 +2,8 @@
|
||||
"name": "vllm",
|
||||
"url_suffix": {
|
||||
"chat": "chat/completions",
|
||||
"models": "models"
|
||||
"models": "models",
|
||||
"embedding": "embeddings"
|
||||
},
|
||||
"class": "local"
|
||||
}
|
||||
@@ -19,6 +19,7 @@ package models
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -378,8 +379,113 @@ func (z *VllmModel) ChatStreamlyWithSender(modelName string, messages []Message,
|
||||
}
|
||||
|
||||
// Encode encodes a list of texts into embeddings
|
||||
type vllmEmbeddingResponse struct {
|
||||
Data []struct {
|
||||
Index int `json:"index"`
|
||||
Embedding []interface{} `json:"embedding"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
func (z *VllmModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
if len(texts) == 0 {
|
||||
return [][]float64{}, nil
|
||||
}
|
||||
|
||||
if modelName == nil || *modelName == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
|
||||
region := "default"
|
||||
if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
baseURL := z.BaseURL[region]
|
||||
if baseURL == "" {
|
||||
baseURL = z.BaseURL["default"]
|
||||
}
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("missing base URL: please configure the local access address for vLLM (e.g., http://127.0.0.1:8000/v1)")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), z.URLSuffix.Embedding)
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": *modelName,
|
||||
"input": texts,
|
||||
}
|
||||
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
|
||||
reqBody["dimensions"] = embeddingConfig.Dimension
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
}
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("vLLM embeddings API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var parsed vllmEmbeddingResponse
|
||||
if err = json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(parsed.Data) != len(texts) {
|
||||
return nil, fmt.Errorf("vllm embeddings: expected %d results, got %d", len(texts), len(parsed.Data))
|
||||
}
|
||||
|
||||
embeddings := make([][]float64, len(texts))
|
||||
for _, item := range parsed.Data {
|
||||
if item.Index < 0 || item.Index >= len(texts) {
|
||||
return nil, fmt.Errorf("unexpected embedding index %d for %d inputs", item.Index, len(texts))
|
||||
}
|
||||
vec := make([]float64, len(item.Embedding))
|
||||
for j, v := range item.Embedding {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
vec[j] = val
|
||||
case float32:
|
||||
vec[j] = float64(val)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected embedding value type at item %d index %d", item.Index, j)
|
||||
}
|
||||
}
|
||||
embeddings[item.Index] = vec
|
||||
}
|
||||
|
||||
for i, vec := range embeddings {
|
||||
if vec == nil {
|
||||
return nil, fmt.Errorf("missing embedding for input at index %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
func (z *VllmModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
|
||||
Reference in New Issue
Block a user