mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Go: implement Encode (embeddings) in Gitee AI driver (#14698)
### What problem does this PR solve?

The Gitee AI Go driver in `internal/entity/models/gitee.go` shipped with a stub `Encode` method that returned `gitee, no such method`, even though `conf/models/gitee.json` already wires the `embedding` URL suffix. The conf also listed no embedding models, so the picker had nothing to select. This blocked any tenant who wanted to use Gitee AI for chat, rerank (already working, see #14656), and embeddings from a single provider.

This PR fills the gap, mirroring the just-merged Aliyun `Encode` (#14647):

- `internal/entity/models/gitee.go`: replace the `Encode` stub with a real implementation. It validates inputs, resolves the region with a default fallback, POSTs the standard OpenAI-compatible `{"model": "...", "input": [...]}` body to `BaseURL[region] + URLSuffix.Embedding`, parses `data[*].embedding` indexed by `data[*].index` so output order matches input order, handles both `float64` and `float32` element types, and uses a 30s per-call context deadline matching the merged `Rerank`.
- `conf/models/gitee.json`: add `BAAI/bge-m3` so the embedding picker has something to select.

No factory change. No interface change. No URL suffix change.

Verified with `go build`, `go vet`, and `gofmt -l`: all clean.

Closes #14697

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -39,6 +39,13 @@
|
||||
"model_types": [
|
||||
"rerank"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "BAAI/bge-m3",
|
||||
"max_tokens": 8192,
|
||||
"model_types": [
|
||||
"embedding"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -29,6 +29,13 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// giteeEmbeddingItem is a single entry of the "data" array in an embeddings
// response.
type giteeEmbeddingItem = struct {
	// Index is the value of the entry's "index" field; Encode uses it as the
	// position of the matching input text.
	Index int `json:"index"`
	// Embedding holds the raw vector components. They are decoded as untyped
	// values and converted to float64 by the caller.
	Embedding []interface{} `json:"embedding"`
}

// giteeEmbeddingResponse is the subset of the embeddings response body that
// the driver decodes: only the "data" array is read.
type giteeEmbeddingResponse struct {
	Data []giteeEmbeddingItem `json:"data"`
}
|
||||
|
||||
// GiteeModel implements ModelDriver for Gitee
|
||||
type GiteeModel struct {
|
||||
BaseURL map[string]string
|
||||
@@ -400,7 +407,105 @@ func (z *GiteeModel) ChatStreamlyWithSender(modelName string, messages []Message
|
||||
|
||||
// Encode encodes a list of texts into embeddings
|
||||
func (z *GiteeModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", z.Name())
|
||||
if len(texts) == 0 {
|
||||
return [][]float64{}, nil
|
||||
}
|
||||
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return nil, fmt.Errorf("api key is required")
|
||||
}
|
||||
|
||||
if modelName == nil || *modelName == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
|
||||
region := "default"
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
baseURL := z.BaseURL["default"]
|
||||
if region != "default" {
|
||||
if regional, ok := z.BaseURL[region]; ok && regional != "" {
|
||||
baseURL = regional
|
||||
}
|
||||
}
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("gitee: no base URL configured for default region")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), z.URLSuffix.Embedding)
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": *modelName,
|
||||
"input": texts,
|
||||
}
|
||||
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
|
||||
reqBody["dimensions"] = embeddingConfig.Dimension
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Gitee embeddings API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var parsed giteeEmbeddingResponse
|
||||
if err = json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
embeddings := make([][]float64, len(texts))
|
||||
for _, item := range parsed.Data {
|
||||
if item.Index < 0 || item.Index >= len(texts) {
|
||||
return nil, fmt.Errorf("unexpected embedding index %d for %d inputs", item.Index, len(texts))
|
||||
}
|
||||
vec := make([]float64, len(item.Embedding))
|
||||
for j, v := range item.Embedding {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
vec[j] = val
|
||||
case float32:
|
||||
vec[j] = float64(val)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected embedding value type at item %d index %d", item.Index, j)
|
||||
}
|
||||
}
|
||||
embeddings[item.Index] = vec
|
||||
}
|
||||
|
||||
for i, vec := range embeddings {
|
||||
if vec == nil {
|
||||
return nil, fmt.Errorf("missing embedding for input at index %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
type giteeRerankRequest struct {
|
||||
|
||||
Reference in New Issue
Block a user