mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-02 00:35:46 +08:00
## Summary

- Add custom `base_url` support to the Google Go model driver.
- Preserve the Google URL suffix configuration when creating driver instances with a custom base URL.
- Validate Google chat/stream request inputs before constructing the SDK client.
- Cover Google model listing, connection checks, base URL resolution, and request validation with focused tests.

## What changed

- `GoogleModel.NewInstance` now returns a Google driver configured with the supplied base URL map.
- Google SDK client creation now resolves configured base URLs through `genai.HTTPOptions.BaseURL`.
- Base URL lookup supports configured regions, empty-region keys, and a `default` fallback.
- Google chat, streaming chat, embeddings, and model listing now reject blank API keys before creating SDK clients.
- Google chat and streaming chat now reject blank model names locally, and streaming chat rejects a nil sender.
- Existing message handling, embeddings, pagination, and provider errors are preserved.

## Why

Google custom model instances could not use configured base URLs because `NewInstance` returned `nil` and the SDK client path ignored the driver's base URL map. The new request validation keeps invalid Google calls (those with blank credentials or incomplete chat inputs) from reaching SDK client construction.
392 lines
11 KiB
Go
392 lines
11 KiB
Go
package models
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
|
|
"google.golang.org/genai"
|
|
)
|
|
|
|
var googleListModelsMu sync.Mutex
|
|
|
|
func withGoogleListModelsStub(t *testing.T, fn func(context.Context, *genai.ClientConfig) ([]string, error)) {
|
|
t.Helper()
|
|
|
|
googleListModelsMu.Lock()
|
|
original := googleListModels
|
|
googleListModels = fn
|
|
t.Cleanup(func() {
|
|
googleListModels = original
|
|
googleListModelsMu.Unlock()
|
|
})
|
|
}
|
|
|
|
func TestGoogleModelListModelsRequiresAPIKey(t *testing.T) {
|
|
model := &GoogleModel{}
|
|
cases := []struct {
|
|
name string
|
|
apiConfig *APIConfig
|
|
}{
|
|
{
|
|
name: "nil config",
|
|
apiConfig: nil,
|
|
},
|
|
{
|
|
name: "nil api key",
|
|
apiConfig: &APIConfig{},
|
|
},
|
|
{
|
|
name: "empty api key",
|
|
apiConfig: &APIConfig{
|
|
ApiKey: stringPtr(""),
|
|
},
|
|
},
|
|
{
|
|
name: "blank api key",
|
|
apiConfig: &APIConfig{
|
|
ApiKey: stringPtr(" \t\n "),
|
|
},
|
|
},
|
|
}
|
|
|
|
calls := 0
|
|
withGoogleListModelsStub(t, func(context.Context, *genai.ClientConfig) ([]string, error) {
|
|
calls++
|
|
return nil, nil
|
|
})
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
models, err := model.ListModels(tc.apiConfig)
|
|
if err == nil {
|
|
t.Fatal("expected an API key error")
|
|
}
|
|
if !strings.Contains(err.Error(), "api key is required") {
|
|
t.Fatalf("expected API key error, got %v", err)
|
|
}
|
|
if models != nil {
|
|
t.Fatalf("expected no models, got %v", models)
|
|
}
|
|
})
|
|
}
|
|
|
|
if calls != 0 {
|
|
t.Fatalf("expected no ListModels calls without an API key, got %d", calls)
|
|
}
|
|
}
|
|
|
|
func TestGoogleModelListModelsReturnsModelNames(t *testing.T) {
|
|
model := &GoogleModel{}
|
|
apiKey := "test-api-key"
|
|
configuredAPIKey := " " + apiKey + " "
|
|
expected := []string{"models/gemini-2.5-flash", "models/gemini-2.5-pro"}
|
|
|
|
withGoogleListModelsStub(t, func(_ context.Context, config *genai.ClientConfig) ([]string, error) {
|
|
if config.APIKey != apiKey {
|
|
t.Fatalf("expected API key %q, got %q", apiKey, config.APIKey)
|
|
}
|
|
return expected, nil
|
|
})
|
|
|
|
models, err := model.ListModels(&APIConfig{ApiKey: &configuredAPIKey})
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
if !reflect.DeepEqual(models, expected) {
|
|
t.Fatalf("expected models %v, got %v", expected, models)
|
|
}
|
|
}
|
|
|
|
func TestGoogleModelCheckConnectionUsesListModels(t *testing.T) {
|
|
customBaseURL := "https://check-connection.example.test/google"
|
|
model := NewGoogleModel(map[string]string{"default": customBaseURL}, URLSuffix{})
|
|
apiKey := "test-api-key"
|
|
calls := 0
|
|
|
|
withGoogleListModelsStub(t, func(_ context.Context, config *genai.ClientConfig) ([]string, error) {
|
|
calls++
|
|
if config.APIKey != apiKey {
|
|
t.Fatalf("expected API key %q, got %q", apiKey, config.APIKey)
|
|
}
|
|
if config.HTTPOptions.BaseURL != customBaseURL {
|
|
t.Fatalf("expected base URL %q, got %q", customBaseURL, config.HTTPOptions.BaseURL)
|
|
}
|
|
return []string{"models/gemini-2.5-flash"}, nil
|
|
})
|
|
|
|
if err := model.CheckConnection(&APIConfig{ApiKey: &apiKey}); err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
if calls != 1 {
|
|
t.Fatalf("expected one ListModels call, got %d", calls)
|
|
}
|
|
}
|
|
|
|
func TestGoogleModelCheckConnectionRequiresAPIKey(t *testing.T) {
|
|
model := &GoogleModel{}
|
|
calls := 0
|
|
|
|
withGoogleListModelsStub(t, func(context.Context, *genai.ClientConfig) ([]string, error) {
|
|
calls++
|
|
return nil, nil
|
|
})
|
|
|
|
cases := []struct {
|
|
name string
|
|
apiConfig *APIConfig
|
|
}{
|
|
{
|
|
name: "nil config",
|
|
apiConfig: nil,
|
|
},
|
|
{
|
|
name: "nil api key",
|
|
apiConfig: &APIConfig{},
|
|
},
|
|
{
|
|
name: "empty api key",
|
|
apiConfig: &APIConfig{
|
|
ApiKey: stringPtr(""),
|
|
},
|
|
},
|
|
{
|
|
name: "blank api key",
|
|
apiConfig: &APIConfig{
|
|
ApiKey: stringPtr(" \t\n "),
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
err := model.CheckConnection(tc.apiConfig)
|
|
if err == nil {
|
|
t.Fatal("expected an API key error")
|
|
}
|
|
if !strings.Contains(err.Error(), "api key is required") {
|
|
t.Fatalf("expected API key error, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
if calls != 0 {
|
|
t.Fatalf("expected no ListModels calls without an API key, got %d", calls)
|
|
}
|
|
}
|
|
|
|
func TestGoogleModelCheckConnectionReturnsListModelsError(t *testing.T) {
|
|
model := &GoogleModel{}
|
|
apiKey := "test-api-key"
|
|
listErr := errors.New("list models failed")
|
|
|
|
withGoogleListModelsStub(t, func(context.Context, *genai.ClientConfig) ([]string, error) {
|
|
return nil, listErr
|
|
})
|
|
|
|
err := model.CheckConnection(&APIConfig{ApiKey: &apiKey})
|
|
if !errors.Is(err, listErr) {
|
|
t.Fatalf("expected ListModels error %v, got %v", listErr, err)
|
|
}
|
|
}
|
|
|
|
func TestGoogleModelChatStreamlyRequiresAPIKey(t *testing.T) {
|
|
model := &GoogleModel{}
|
|
messages := []Message{{Role: "user", Content: "hello"}}
|
|
cases := []struct {
|
|
name string
|
|
apiConfig *APIConfig
|
|
}{
|
|
{name: "nil config"},
|
|
{name: "nil api key", apiConfig: &APIConfig{}},
|
|
{name: "empty api key", apiConfig: &APIConfig{ApiKey: stringPtr("")}},
|
|
{name: "blank api key", apiConfig: &APIConfig{ApiKey: stringPtr(" \t\n ")}},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
err := model.ChatStreamlyWithSender("gemini-2.5-flash", messages, tc.apiConfig, nil, func(*string, *string) error {
|
|
t.Errorf("sender should not be called without an API key")
|
|
return nil
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected an API key error")
|
|
}
|
|
if !strings.Contains(err.Error(), "api key is nil or empty") {
|
|
t.Fatalf("expected API key error, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGoogleModelChatRequiresModelName(t *testing.T) {
|
|
model := &GoogleModel{}
|
|
apiKey := "test-api-key"
|
|
messages := []Message{{Role: "user", Content: "hello"}}
|
|
|
|
response, err := model.ChatWithMessages("", messages, &APIConfig{ApiKey: &apiKey}, nil)
|
|
if err == nil {
|
|
t.Fatal("expected a model name error")
|
|
}
|
|
if !strings.Contains(err.Error(), "model name is empty") {
|
|
t.Fatalf("expected model name error, got %v", err)
|
|
}
|
|
if response != nil {
|
|
t.Fatalf("expected no response, got %v", response)
|
|
}
|
|
|
|
err = model.ChatStreamlyWithSender("", messages, &APIConfig{ApiKey: &apiKey}, nil, func(*string, *string) error {
|
|
t.Errorf("sender should not be called without a model name")
|
|
return nil
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected a model name error")
|
|
}
|
|
if !strings.Contains(err.Error(), "model name is empty") {
|
|
t.Fatalf("expected model name error, got %v", err)
|
|
}
|
|
|
|
err = model.ChatStreamlyWithSender("gemini-2.5-flash", messages, &APIConfig{ApiKey: &apiKey}, nil, nil)
|
|
if err == nil {
|
|
t.Fatal("expected a sender error")
|
|
}
|
|
if !strings.Contains(err.Error(), "sender is nil") {
|
|
t.Fatalf("expected sender error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestGoogleModelNewInstancePreservesCustomBaseURL(t *testing.T) {
|
|
model := NewGoogleModel(map[string]string{"default": "https://generativelanguage.googleapis.com"}, URLSuffix{Models: "v1beta/models"})
|
|
customBaseURL := map[string]string{"default": "https://example.test/google"}
|
|
|
|
instance := model.NewInstance(customBaseURL)
|
|
google, ok := instance.(*GoogleModel)
|
|
if !ok {
|
|
t.Fatalf("expected *GoogleModel, got %T", instance)
|
|
}
|
|
if google.BaseURL["default"] != customBaseURL["default"] {
|
|
t.Fatalf("expected base URL %q, got %q", customBaseURL["default"], google.BaseURL["default"])
|
|
}
|
|
if google.URLSuffix != model.URLSuffix {
|
|
t.Fatalf("expected URL suffix %v, got %v", model.URLSuffix, google.URLSuffix)
|
|
}
|
|
}
|
|
|
|
func TestGoogleModelListModelsPassesBaseURL(t *testing.T) {
|
|
apiKey := "test-api-key"
|
|
cases := []struct {
|
|
name string
|
|
baseURL map[string]string
|
|
region *string
|
|
expectedBaseURL string
|
|
}{
|
|
{
|
|
name: "default custom base URL",
|
|
baseURL: map[string]string{"default": "https://example.test/google"},
|
|
expectedBaseURL: "https://example.test/google",
|
|
},
|
|
{
|
|
name: "regional custom base URL",
|
|
baseURL: map[string]string{"east": "https://east.example.test/google", "default": "https://default.example.test/google"},
|
|
region: stringPtr("east"),
|
|
expectedBaseURL: "https://east.example.test/google",
|
|
},
|
|
{
|
|
name: "empty region custom base URL",
|
|
baseURL: map[string]string{"": "https://empty-region.example.test/google"},
|
|
region: stringPtr(""),
|
|
expectedBaseURL: "https://empty-region.example.test/google",
|
|
},
|
|
{
|
|
name: "missing region falls back to default base URL",
|
|
baseURL: map[string]string{"default": "https://default.example.test/google"},
|
|
region: stringPtr("missing"),
|
|
expectedBaseURL: "https://default.example.test/google",
|
|
},
|
|
{
|
|
name: "SDK default base URL",
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
model := NewGoogleModel(tc.baseURL, URLSuffix{})
|
|
withGoogleListModelsStub(t, func(_ context.Context, config *genai.ClientConfig) ([]string, error) {
|
|
if config.HTTPOptions.BaseURL != tc.expectedBaseURL {
|
|
t.Fatalf("expected base URL %q, got %q", tc.expectedBaseURL, config.HTTPOptions.BaseURL)
|
|
}
|
|
return []string{"models/gemini-2.5-flash"}, nil
|
|
})
|
|
|
|
if _, err := model.ListModels(&APIConfig{ApiKey: &apiKey, Region: tc.region}); err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCollectGoogleModelNamesPaginates(t *testing.T) {
|
|
pages := []googleModelPage{
|
|
{items: []string{"models/gemini-2.5-flash"}, nextPageToken: "page-2"},
|
|
{items: []string{"models/gemini-2.5-pro"}, nextPageToken: ""},
|
|
}
|
|
var pageTokens []string
|
|
|
|
models, err := collectGoogleModelNames(context.Background(), func(_ context.Context, pageToken string) (googleModelPage, error) {
|
|
pageTokens = append(pageTokens, pageToken)
|
|
if len(pageTokens) > len(pages) {
|
|
t.Fatalf("unexpected extra page request with token %q", pageToken)
|
|
}
|
|
return pages[len(pageTokens)-1], nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
|
|
expectedModels := []string{"models/gemini-2.5-flash", "models/gemini-2.5-pro"}
|
|
if !reflect.DeepEqual(models, expectedModels) {
|
|
t.Fatalf("expected models %v, got %v", expectedModels, models)
|
|
}
|
|
expectedPageTokens := []string{"", "page-2"}
|
|
if !reflect.DeepEqual(pageTokens, expectedPageTokens) {
|
|
t.Fatalf("expected page tokens %v, got %v", expectedPageTokens, pageTokens)
|
|
}
|
|
}
|
|
|
|
func TestCollectGoogleModelNamesPreservesEmptyResult(t *testing.T) {
|
|
models, err := collectGoogleModelNames(context.Background(), func(context.Context, string) (googleModelPage, error) {
|
|
return googleModelPage{}, nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
if models != nil {
|
|
t.Fatalf("expected nil models, got %v", models)
|
|
}
|
|
}
|
|
|
|
func TestCollectGoogleModelNamesReturnsPageError(t *testing.T) {
|
|
pageErr := errors.New("next page failed")
|
|
calls := 0
|
|
|
|
models, err := collectGoogleModelNames(context.Background(), func(context.Context, string) (googleModelPage, error) {
|
|
calls++
|
|
if calls == 1 {
|
|
return googleModelPage{items: []string{"models/gemini-2.5-flash"}, nextPageToken: "page-2"}, nil
|
|
}
|
|
return googleModelPage{}, pageErr
|
|
})
|
|
if !errors.Is(err, pageErr) {
|
|
t.Fatalf("expected page error %v, got %v", pageErr, err)
|
|
}
|
|
if models != nil {
|
|
t.Fatalf("expected no models on error, got %v", models)
|
|
}
|
|
}
|
|
|
|
// stringPtr returns a pointer to a fresh copy of value. It is used to populate
// optional *string fields, such as APIConfig.ApiKey and APIConfig.Region, in
// test tables.
func stringPtr(value string) *string {
	p := new(string)
	*p = value
	return p
}
|