mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-03 17:21:59 +08:00
## Summary - Add custom `base_url` support to the Google Go model driver. - Preserve Google URL suffix configuration when creating custom base URL driver instances. - Validate Google chat/stream request inputs before constructing the SDK client. - Cover Google model listing, connection checks, base URL resolution, and request validation with focused tests. ## What changed - `GoogleModel.NewInstance` now returns a Google driver configured with the supplied base URL map. - Google SDK client creation now resolves configured base URLs through `genai.HTTPOptions.BaseURL`. - Base URL lookup supports configured regions, empty-region keys, and `default` fallback. - Google chat, streaming chat, embeddings, and model listing now reject blank API keys before creating SDK clients. - Google chat and streaming chat now reject blank model names locally, and streaming chat rejects a nil sender. - Existing message handling, embeddings, pagination, and provider errors are preserved. ## Why Google custom model instances could not use configured base URLs because `NewInstance` returned `nil` and the SDK client path ignored the driver base URL map. The request validation keeps invalid Google calls from reaching SDK client construction with blank credentials or incomplete chat inputs.
391 lines
12 KiB
Go
391 lines
12 KiB
Go
//
|
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
|
|
package models
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"ragflow/internal/common"
|
|
"strings"
|
|
|
|
"google.golang.org/genai"
|
|
)
|
|
|
|
// googleModelPage is one page of a paginated model listing.
type googleModelPage struct {
	items         []string // model names returned on this page
	nextPageToken string   // token for the following page; "" on the last page
}

// collectGoogleModelNames calls listPage repeatedly, starting with an empty page
// token and following each returned nextPageToken, and returns the names from
// all pages in the order they were received. The first error returned by
// listPage is propagated unchanged.
//
// The loop also stops with an error when ctx is already done before a page is
// requested, or when a page token is returned a second time. Without that guard
// an endpoint that keeps handing back the same (or a cycling) token would make
// this function loop forever.
func collectGoogleModelNames(ctx context.Context, listPage func(context.Context, string) (googleModelPage, error)) ([]string, error) {
	var modelNames []string
	seenTokens := make(map[string]struct{})
	pageToken := ""

	for {
		if err := ctx.Err(); err != nil {
			return nil, err
		}

		page, err := listPage(ctx, pageToken)
		if err != nil {
			return nil, err
		}
		modelNames = append(modelNames, page.items...)

		next := page.nextPageToken
		if next == "" {
			return modelNames, nil
		}
		if _, repeated := seenTokens[next]; repeated {
			return nil, fmt.Errorf("model listing did not advance: page token %q was returned more than once", next)
		}
		seenTokens[next] = struct{}{}
		pageToken = next
	}
}
|
|
|
|
var googleListModels = func(ctx context.Context, config *genai.ClientConfig) ([]string, error) {
|
|
client, err := genai.NewClient(ctx, config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return collectGoogleModelNames(ctx, func(ctx context.Context, pageToken string) (googleModelPage, error) {
|
|
models, err := client.Models.List(ctx, &genai.ListModelsConfig{PageToken: pageToken})
|
|
if err != nil {
|
|
return googleModelPage{}, err
|
|
}
|
|
|
|
var modelNames []string
|
|
for _, m := range models.Items {
|
|
modelNames = append(modelNames, m.Name)
|
|
}
|
|
return googleModelPage{items: modelNames, nextPageToken: models.NextPageToken}, nil
|
|
})
|
|
}
|
|
|
|
// GoogleModel implements ModelDriver for Google AI
|
|
type GoogleModel struct {
|
|
BaseURL map[string]string
|
|
URLSuffix URLSuffix
|
|
}
|
|
|
|
// NewGoogleModel creates a new Google AI model instance
|
|
func NewGoogleModel(baseURL map[string]string, urlSuffix URLSuffix) *GoogleModel {
|
|
return &GoogleModel{
|
|
BaseURL: baseURL,
|
|
URLSuffix: urlSuffix,
|
|
}
|
|
}
|
|
|
|
func (g *GoogleModel) NewInstance(baseURL map[string]string) ModelDriver {
|
|
return NewGoogleModel(baseURL, g.URLSuffix)
|
|
}
|
|
|
|
func (g *GoogleModel) Name() string {
|
|
return "google"
|
|
}
|
|
|
|
func (g *GoogleModel) clientConfig(apiKey string, apiConfig *APIConfig) *genai.ClientConfig {
|
|
return &genai.ClientConfig{APIKey: apiKey, Backend: genai.BackendGeminiAPI, HTTPOptions: genai.HTTPOptions{BaseURL: g.baseURL(apiConfig)}}
|
|
}
|
|
|
|
func (g *GoogleModel) baseURL(apiConfig *APIConfig) string {
|
|
if apiConfig != nil && apiConfig.Region != nil {
|
|
if baseURL := strings.TrimSpace(g.BaseURL[*apiConfig.Region]); baseURL != "" {
|
|
return baseURL
|
|
}
|
|
}
|
|
return strings.TrimSpace(g.BaseURL["default"])
|
|
}
|
|
|
|
func (g *GoogleModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
|
|
if apiConfig == nil || apiConfig.ApiKey == nil || strings.TrimSpace(*apiConfig.ApiKey) == "" {
|
|
return nil, fmt.Errorf("api key is nil or empty")
|
|
}
|
|
if strings.TrimSpace(modelName) == "" {
|
|
return nil, fmt.Errorf("model name is empty")
|
|
}
|
|
|
|
if len(messages) == 0 {
|
|
return nil, fmt.Errorf("messages is empty")
|
|
}
|
|
|
|
ctx := context.Background()
|
|
client, err := genai.NewClient(ctx, g.clientConfig(strings.TrimSpace(*apiConfig.ApiKey), apiConfig))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Convert messages to Google SDK format
|
|
var contents []*genai.Content
|
|
for _, msg := range messages {
|
|
var role genai.Role
|
|
switch msg.Role {
|
|
case "user":
|
|
role = genai.RoleUser
|
|
case "model", "assistant":
|
|
role = genai.RoleModel
|
|
default:
|
|
role = genai.RoleUser
|
|
}
|
|
|
|
// Handle content based on type
|
|
switch c := msg.Content.(type) {
|
|
case string:
|
|
contents = append(contents, genai.NewContentFromText(c, role))
|
|
case []interface{}:
|
|
// Multimodal content - group parts within a single content
|
|
var parts []*genai.Part
|
|
for _, item := range c {
|
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
|
contentType, _ := itemMap["type"].(string)
|
|
switch contentType {
|
|
case "text":
|
|
if text, ok := itemMap["text"].(string); ok {
|
|
parts = append(parts, genai.NewPartFromText(text))
|
|
}
|
|
case "image_url":
|
|
if imgMap, ok := itemMap["image_url"].(map[string]interface{}); ok {
|
|
if url, ok := imgMap["url"].(string); ok {
|
|
parts = append(parts, genai.NewPartFromURI(url, "image/jpeg"))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if len(parts) > 0 {
|
|
contents = append(contents, genai.NewContentFromParts(parts, role))
|
|
}
|
|
}
|
|
}
|
|
|
|
// Generate content (non-streaming)
|
|
response, err := client.Models.GenerateContent(ctx, modelName, contents, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Extract text from response
|
|
answer := response.Text()
|
|
|
|
return &ChatResponse{Answer: &answer}, nil
|
|
}
|
|
|
|
// ChatStreamlyWithSender sends messages and streams response via sender function (best performance, no channel)
|
|
func (g *GoogleModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error {
|
|
if len(messages) == 0 {
|
|
return fmt.Errorf("messages is empty")
|
|
}
|
|
if apiConfig == nil || apiConfig.ApiKey == nil || strings.TrimSpace(*apiConfig.ApiKey) == "" {
|
|
return fmt.Errorf("api key is nil or empty")
|
|
}
|
|
if strings.TrimSpace(modelName) == "" {
|
|
return fmt.Errorf("model name is empty")
|
|
}
|
|
if sender == nil {
|
|
return fmt.Errorf("sender is nil")
|
|
}
|
|
|
|
ctx := context.Background()
|
|
client, err := genai.NewClient(ctx, g.clientConfig(strings.TrimSpace(*apiConfig.ApiKey), apiConfig))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Convert messages to Google SDK format
|
|
var contents []*genai.Content
|
|
for _, msg := range messages {
|
|
var role genai.Role
|
|
switch msg.Role {
|
|
case "user":
|
|
role = genai.RoleUser
|
|
case "model", "assistant":
|
|
role = genai.RoleModel
|
|
default:
|
|
role = genai.RoleUser
|
|
}
|
|
|
|
// Handle content based on type
|
|
switch c := msg.Content.(type) {
|
|
case string:
|
|
contents = append(contents, genai.NewContentFromText(c, role))
|
|
case []interface{}:
|
|
// Multimodal content - group parts within a single content
|
|
var parts []*genai.Part
|
|
for _, item := range c {
|
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
|
contentType, _ := itemMap["type"].(string)
|
|
switch contentType {
|
|
case "text":
|
|
if text, ok := itemMap["text"].(string); ok {
|
|
parts = append(parts, genai.NewPartFromText(text))
|
|
}
|
|
case "image_url":
|
|
if imgMap, ok := itemMap["image_url"].(map[string]interface{}); ok {
|
|
if url, ok := imgMap["url"].(string); ok {
|
|
parts = append(parts, genai.NewPartFromURI(url, "image/jpeg"))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if len(parts) > 0 {
|
|
contents = append(contents, genai.NewContentFromParts(parts, role))
|
|
}
|
|
}
|
|
}
|
|
|
|
for response, err := range client.Models.GenerateContentStream(
|
|
ctx,
|
|
modelName,
|
|
contents,
|
|
nil,
|
|
) {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
content := response.Text()
|
|
|
|
var responseContent string
|
|
if chatModelConfig != nil && chatModelConfig.Thinking != nil && *chatModelConfig.Thinking {
|
|
responseContent = response.Candidates[0].Content.Parts[0].Text
|
|
}
|
|
|
|
if responseContent != "" {
|
|
common.Info(fmt.Sprintf("Thinking: %s", responseContent))
|
|
if err = sender(nil, &responseContent); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if content != "" {
|
|
common.Info(fmt.Sprintf("Answer: %s", content))
|
|
if err = sender(&content, nil); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// Embed generates embeddings for a batch of texts using the Gemini embeddings API.
|
|
// The SDK routes to batchEmbedContents internally, so all texts are sent in one request.
|
|
func (g *GoogleModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) {
|
|
if apiConfig == nil || apiConfig.ApiKey == nil || strings.TrimSpace(*apiConfig.ApiKey) == "" {
|
|
return nil, fmt.Errorf("api key is required")
|
|
}
|
|
if modelName == nil || *modelName == "" {
|
|
return nil, fmt.Errorf("model name is required")
|
|
}
|
|
if len(texts) == 0 {
|
|
return nil, fmt.Errorf("texts is empty")
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout)
|
|
defer cancel()
|
|
|
|
client, err := genai.NewClient(ctx, g.clientConfig(strings.TrimSpace(*apiConfig.ApiKey), apiConfig))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create client: %w", err)
|
|
}
|
|
|
|
contents := make([]*genai.Content, len(texts))
|
|
for i, text := range texts {
|
|
contents[i] = genai.NewContentFromText(text, genai.RoleUser)
|
|
}
|
|
|
|
var cfg *genai.EmbedContentConfig
|
|
if embeddingConfig != nil && embeddingConfig.Dimension > 0 {
|
|
dim := int32(embeddingConfig.Dimension)
|
|
cfg = &genai.EmbedContentConfig{OutputDimensionality: &dim}
|
|
}
|
|
|
|
resp, err := client.Models.EmbedContent(ctx, *modelName, contents, cfg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to embed content: %w", err)
|
|
}
|
|
|
|
if len(resp.Embeddings) != len(texts) {
|
|
return nil, fmt.Errorf("expected %d embeddings, got %d", len(texts), len(resp.Embeddings))
|
|
}
|
|
|
|
result := make([]EmbeddingData, len(resp.Embeddings))
|
|
for i, emb := range resp.Embeddings {
|
|
vec := make([]float64, len(emb.Values))
|
|
for j, v := range emb.Values {
|
|
vec[j] = float64(v)
|
|
}
|
|
result[i] = EmbeddingData{
|
|
Embedding: vec,
|
|
Index: i,
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (g *GoogleModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
|
if apiConfig == nil || apiConfig.ApiKey == nil || strings.TrimSpace(*apiConfig.ApiKey) == "" {
|
|
return nil, fmt.Errorf("api key is required")
|
|
}
|
|
|
|
return googleListModels(context.Background(), g.clientConfig(strings.TrimSpace(*apiConfig.ApiKey), apiConfig))
|
|
}
|
|
|
|
func (g *GoogleModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
|
|
return nil, fmt.Errorf("no such method")
|
|
}
|
|
|
|
func (g *GoogleModel) CheckConnection(apiConfig *APIConfig) error {
|
|
_, err := g.ListModels(apiConfig)
|
|
return err
|
|
}
|
|
|
|
// Rerank calculates similarity scores between query and documents
|
|
func (g *GoogleModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
|
|
return nil, fmt.Errorf("%s, Rerank not implemented", g.Name())
|
|
}
|
|
|
|
// TranscribeAudio transcribe audio
|
|
func (g *GoogleModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
|
|
return nil, fmt.Errorf("%s, no such method", g.Name())
|
|
}
|
|
|
|
func (z *GoogleModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error {
|
|
return fmt.Errorf("%s, no such method", z.Name())
|
|
}
|
|
|
|
// AudioSpeech convert text to audio
|
|
func (g *GoogleModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) {
|
|
return nil, fmt.Errorf("%s, no such method", g.Name())
|
|
}
|
|
|
|
func (z *GoogleModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {
|
|
return fmt.Errorf("%s, no such method", z.Name())
|
|
}
|
|
|
|
// OCRFile OCR file
|
|
func (g *GoogleModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) {
|
|
return nil, fmt.Errorf("%s, no such method", g.Name())
|
|
}
|
|
|
|
// ParseFile parse file
|
|
func (z *GoogleModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) {
|
|
return nil, fmt.Errorf("%s, no such method", z.Name())
|
|
}
|
|
|
|
func (z *GoogleModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) {
|
|
return nil, fmt.Errorf("%s, no such method", z.Name())
|
|
}
|
|
|
|
func (z *GoogleModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) {
|
|
return nil, fmt.Errorf("%s, no such method", z.Name())
|
|
}
|