diff --git a/conf/models/google.json b/conf/models/google.json index 2e4cf30525..a1d5f129f0 100644 --- a/conf/models/google.json +++ b/conf/models/google.json @@ -18,6 +18,13 @@ "default_value": true, "clear_thinking": true } + }, + { + "name": "text-embedding-004", + "max_tokens": 2048, + "model_types": [ + "embedding" + ] } ], "features": { diff --git a/internal/entity/models/google.go b/internal/entity/models/google.go index b5679ac8da..052801a0d9 100644 --- a/internal/entity/models/google.go +++ b/internal/entity/models/google.go @@ -212,9 +212,60 @@ func (z *GoogleModel) ChatStreamlyWithSender(modelName string, messages []Messag return err } -// Encode encodes a list of texts into embeddings +// Encode generates embeddings for a batch of texts using the Gemini embeddings API. +// The SDK routes to batchEmbedContents internally, so all texts are sent in one request. func (z *GoogleModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { - return nil, fmt.Errorf("not implemented") + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + if len(texts) == 0 { + return nil, fmt.Errorf("texts is empty") + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + client, err := genai.NewClient(ctx, &genai.ClientConfig{ + APIKey: *apiConfig.ApiKey, + Backend: genai.BackendGeminiAPI, + }) + if err != nil { + return nil, fmt.Errorf("failed to create client: %w", err) + } + + contents := make([]*genai.Content, len(texts)) + for i, text := range texts { + contents[i] = genai.NewContentFromText(text, genai.RoleUser) + } + + var cfg *genai.EmbedContentConfig + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + dim := int32(embeddingConfig.Dimension) + cfg = &genai.EmbedContentConfig{OutputDimensionality: &dim} + } + + resp, err := client.Models.EmbedContent(ctx, *modelName, contents, cfg) + if err != nil { + return nil, fmt.Errorf("failed to embed content: %w", err) + } + + if len(resp.Embeddings) != len(texts) { + return nil, fmt.Errorf("expected %d embeddings, got %d", len(texts), len(resp.Embeddings)) + } + + result := make([][]float64, len(resp.Embeddings)) + for i, emb := range resp.Embeddings { + vec := make([]float64, len(emb.Values)) + for j, v := range emb.Values { + vec[j] = float64(v) + } + result[i] = vec + } + + return result, nil } func (z *GoogleModel) ListModels(apiConfig *APIConfig) ([]string, error) { @@ -245,7 +296,8 @@ func (z *GoogleModel) Balance(apiConfig *APIConfig) (map[string]interface{}, err } func (z *GoogleModel) CheckConnection(apiConfig *APIConfig) error { - return fmt.Errorf("no such method") + _, err := z.ListModels(apiConfig) + return err } // Rerank calculates similarity scores between query and documents