From f4f8bed9f7aff4e6107b4c54b71f52f04a36b130 Mon Sep 17 00:00:00 2001 From: Joseff Date: Sun, 10 May 2026 23:24:21 -0400 Subject: [PATCH] Go: implement Encode (embeddings) in Google Gemini driver (#14682) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? - Implements the `Encode` method in the Google Gemini driver, which was previously a stub returning `not implemented` - Uses the `google.golang.org/genai` SDK's `EmbedContent` API, which routes to the `batchEmbedContents` endpoint internally — all texts are sent in a single request - Adds `text-embedding-004` (max 2048 tokens) to `conf/models/google.json` - Response values are `[]float32` from the SDK and are cast to `[]float64` to satisfy the `ModelDriver` interface ## Files changed - `internal/entity/models/google.go` — full `Encode` implementation - `conf/models/google.json` — adds `text-embedding-004` embedding model ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- conf/models/google.json | 7 ++++ internal/entity/models/google.go | 58 ++++++++++++++++++++++++++++++-- 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/conf/models/google.json b/conf/models/google.json index 2e4cf30525..a1d5f129f0 100644 --- a/conf/models/google.json +++ b/conf/models/google.json @@ -18,6 +18,13 @@ "default_value": true, "clear_thinking": true } + }, + { + "name": "text-embedding-004", + "max_tokens": 2048, + "model_types": [ + "embedding" + ] } ], "features": { diff --git a/internal/entity/models/google.go b/internal/entity/models/google.go index b5679ac8da..052801a0d9 100644 --- a/internal/entity/models/google.go +++ b/internal/entity/models/google.go @@ -212,9 +212,60 @@ func (z *GoogleModel) ChatStreamlyWithSender(modelName string, messages []Messag return err } -// Encode encodes a list of texts into embeddings +// Encode generates embeddings for a batch of texts using the Gemini embeddings API. +// The SDK routes to batchEmbedContents internally, so all texts are sent in one request. func (z *GoogleModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { - return nil, fmt.Errorf("not implemented") + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + if len(texts) == 0 { + return nil, fmt.Errorf("texts is empty") + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + client, err := genai.NewClient(ctx, &genai.ClientConfig{ + APIKey: *apiConfig.ApiKey, + Backend: genai.BackendGeminiAPI, + }) + if err != nil { + return nil, fmt.Errorf("failed to create client: %w", err) + } + + contents := make([]*genai.Content, len(texts)) + for i, text := range texts { + contents[i] = genai.NewContentFromText(text, genai.RoleUser) + } + + var cfg *genai.EmbedContentConfig + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + dim := int32(embeddingConfig.Dimension) + cfg = &genai.EmbedContentConfig{OutputDimensionality: &dim} + } + + resp, err := client.Models.EmbedContent(ctx, *modelName, contents, cfg) + if err != nil { + return nil, fmt.Errorf("failed to embed content: %w", err) + } + + if len(resp.Embeddings) != len(texts) { + return nil, fmt.Errorf("expected %d embeddings, got %d", len(texts), len(resp.Embeddings)) + } + + result := make([][]float64, len(resp.Embeddings)) + for i, emb := range resp.Embeddings { + vec := make([]float64, len(emb.Values)) + for j, v := range emb.Values { + vec[j] = float64(v) + } + result[i] = vec + } + + return result, nil } func (z *GoogleModel) ListModels(apiConfig *APIConfig) ([]string, error) { @@ -245,7 +296,8 @@ func (z *GoogleModel) Balance(apiConfig *APIConfig) (map[string]interface{}, err } func (z *GoogleModel) CheckConnection(apiConfig *APIConfig) error { - return fmt.Errorf("no such method") + _, err := z.ListModels(apiConfig) + return err } // Rerank calculates similarity scores between query and documents