mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Ports the agent canvas subsystem from Python to Go.
## What's included
### Canvas Engine (Phase 0/1)
- State engine, scheduler, variable resolver, Redis checkpoint store,
cancel protocol
- **209 tests** across canvas / component / io packages
### 22 Components (P0–P4)
| Tier | Components |
|---|---|
| P0 T1+T2+T3 | LLM, Agent, ExitLoop, Switch, Categorize, Begin,
Message, Invoke |
| P1 T3 | VariableAggregator, VariableAssigner, StringTransform,
ListOperations, DataOperations |
| P2 T3 | Iteration, IterationItem, Loop, LoopItem |
| P3 T3 | UserFillUp, Fillup |
| P4 T5 | Browser, ExcelProcessor, DocsGenerator |
### DSL v2 Schema (Phase 2.5)
- Typed v2 in-memory model with v1-to-v2 auto-detect converter
- v1 legacy field stripping per plan §2.11.7
### HTTP Endpoints & Bug Fixes (Plans PR1–PR3)
- **DELETE SQL bug fix**: gorm v2 `Where("id = ?", id).Delete(...)`
pattern
- **CreateAgent validation**: title/DSL required, duplicate check, 103
envelope
- **13 new endpoints**: templates, prompts, tags, sessions CRUD,
chat/completions (SSE + non-stream stubs), rerun, test_db_connection,
logs, webhook/logs
- **756 Go unit tests** (745 → 756, +18)
- **17 → 0 Python integration test failures** (test_agents.py +
test_session_management/)
### Tools
21 eino tools: HTTPHelper, search tools, financial/data tools, mandatory
stubs
### Infrastructure
OTel observability, NATS message queue, DeepDoc gRPC client, SSRF
guards, IDOR mitigation
389 lines
11 KiB
Go
389 lines
11 KiB
Go
package models
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
|
|
"google.golang.org/genai"
|
|
)
|
|
|
|
// googleListModelsMu is held by withGoogleListModelsStub from the moment it
// replaces the package-level googleListModels hook until the test's cleanup
// restores the previous hook.
var googleListModelsMu sync.Mutex
|
|
|
|
func withGoogleListModelsStub(t *testing.T, fn func(context.Context, *genai.ClientConfig) ([]ListModelResponse, error)) {
|
|
t.Helper()
|
|
|
|
googleListModelsMu.Lock()
|
|
original := googleListModels
|
|
googleListModels = fn
|
|
t.Cleanup(func() {
|
|
googleListModels = original
|
|
googleListModelsMu.Unlock()
|
|
})
|
|
}
|
|
|
|
func TestGoogleModelListModelsRequiresAPIKey(t *testing.T) {
|
|
model := &GoogleModel{}
|
|
cases := []struct {
|
|
name string
|
|
apiConfig *APIConfig
|
|
}{
|
|
{
|
|
name: "nil config",
|
|
apiConfig: nil,
|
|
},
|
|
{
|
|
name: "nil api key",
|
|
apiConfig: &APIConfig{},
|
|
},
|
|
{
|
|
name: "empty api key",
|
|
apiConfig: &APIConfig{
|
|
ApiKey: stringPtr(""),
|
|
},
|
|
},
|
|
{
|
|
name: "blank api key",
|
|
apiConfig: &APIConfig{
|
|
ApiKey: stringPtr(" \t\n "),
|
|
},
|
|
},
|
|
}
|
|
|
|
calls := 0
|
|
withGoogleListModelsStub(t, func(context.Context, *genai.ClientConfig) ([]ListModelResponse, error) {
|
|
calls++
|
|
return nil, nil
|
|
})
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
models, err := model.ListModels(tc.apiConfig)
|
|
if err == nil {
|
|
t.Fatal("expected an API key error")
|
|
}
|
|
if !strings.Contains(err.Error(), "api key is required") {
|
|
t.Fatalf("expected API key error, got %v", err)
|
|
}
|
|
if models != nil {
|
|
t.Fatalf("expected no models, got %v", models)
|
|
}
|
|
})
|
|
}
|
|
|
|
if calls != 0 {
|
|
t.Fatalf("expected no ListModels calls without an API key, got %d", calls)
|
|
}
|
|
}
|
|
|
|
func TestGoogleModelListModelsReturnsModelNames(t *testing.T) {
|
|
model := &GoogleModel{}
|
|
apiKey := "test-api-key"
|
|
configuredAPIKey := " " + apiKey + " "
|
|
expected := []ListModelResponse{{Name: "models/gemini-2.5-flash"}, {Name: "models/gemini-2.5-pro"}}
|
|
|
|
withGoogleListModelsStub(t, func(_ context.Context, config *genai.ClientConfig) ([]ListModelResponse, error) {
|
|
if config.APIKey != apiKey {
|
|
t.Fatalf("expected API key %q, got %q", apiKey, config.APIKey)
|
|
}
|
|
return expected, nil
|
|
})
|
|
|
|
models, err := model.ListModels(&APIConfig{ApiKey: &configuredAPIKey})
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
if !reflect.DeepEqual(models, expected) {
|
|
t.Fatalf("expected models %v, got %v", expected, models)
|
|
}
|
|
}
|
|
|
|
func TestGoogleModelCheckConnectionUsesListModels(t *testing.T) {
|
|
customBaseURL := "https://check-connection.example.test/google"
|
|
model := NewGoogleModel(map[string]string{"default": customBaseURL}, URLSuffix{})
|
|
apiKey := "test-api-key"
|
|
calls := 0
|
|
|
|
withGoogleListModelsStub(t, func(_ context.Context, config *genai.ClientConfig) ([]ListModelResponse, error) {
|
|
calls++
|
|
if config.APIKey != apiKey {
|
|
t.Fatalf("expected API key %q, got %q", apiKey, config.APIKey)
|
|
}
|
|
if config.HTTPOptions.BaseURL != customBaseURL {
|
|
t.Fatalf("expected base URL %q, got %q", customBaseURL, config.HTTPOptions.BaseURL)
|
|
}
|
|
return []ListModelResponse{{Name: "models/gemini-2.5-flash"}}, nil
|
|
})
|
|
|
|
if err := model.CheckConnection(&APIConfig{ApiKey: &apiKey}); err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
if calls != 1 {
|
|
t.Fatalf("expected one ListModels call, got %d", calls)
|
|
}
|
|
}
|
|
|
|
func TestGoogleModelCheckConnectionRequiresAPIKey(t *testing.T) {
|
|
model := &GoogleModel{}
|
|
calls := 0
|
|
|
|
withGoogleListModelsStub(t, func(context.Context, *genai.ClientConfig) ([]ListModelResponse, error) {
|
|
calls++
|
|
return nil, nil
|
|
})
|
|
|
|
cases := []struct {
|
|
name string
|
|
apiConfig *APIConfig
|
|
}{
|
|
{
|
|
name: "nil config",
|
|
apiConfig: nil,
|
|
},
|
|
{
|
|
name: "nil api key",
|
|
apiConfig: &APIConfig{},
|
|
},
|
|
{
|
|
name: "empty api key",
|
|
apiConfig: &APIConfig{
|
|
ApiKey: stringPtr(""),
|
|
},
|
|
},
|
|
{
|
|
name: "blank api key",
|
|
apiConfig: &APIConfig{
|
|
ApiKey: stringPtr(" \t\n "),
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
err := model.CheckConnection(tc.apiConfig)
|
|
if err == nil {
|
|
t.Fatal("expected an API key error")
|
|
}
|
|
if !strings.Contains(err.Error(), "api key is required") {
|
|
t.Fatalf("expected API key error, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
if calls != 0 {
|
|
t.Fatalf("expected no ListModels calls without an API key, got %d", calls)
|
|
}
|
|
}
|
|
|
|
func TestGoogleModelCheckConnectionReturnsListModelsError(t *testing.T) {
|
|
model := &GoogleModel{}
|
|
apiKey := "test-api-key"
|
|
listErr := errors.New("list models failed")
|
|
|
|
withGoogleListModelsStub(t, func(context.Context, *genai.ClientConfig) ([]ListModelResponse, error) {
|
|
return nil, listErr
|
|
})
|
|
|
|
err := model.CheckConnection(&APIConfig{ApiKey: &apiKey})
|
|
if !errors.Is(err, listErr) {
|
|
t.Fatalf("expected ListModels error %v, got %v", listErr, err)
|
|
}
|
|
}
|
|
|
|
func TestGoogleModelChatStreamlyRequiresAPIKey(t *testing.T) {
|
|
model := &GoogleModel{}
|
|
messages := []Message{{Role: "user", Content: "hello"}}
|
|
cases := []struct {
|
|
name string
|
|
apiConfig *APIConfig
|
|
}{
|
|
{name: "nil config"},
|
|
{name: "nil api key", apiConfig: &APIConfig{}},
|
|
{name: "empty api key", apiConfig: &APIConfig{ApiKey: stringPtr("")}},
|
|
{name: "blank api key", apiConfig: &APIConfig{ApiKey: stringPtr(" \t\n ")}},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
err := model.ChatStreamlyWithSender("gemini-2.5-flash", messages, tc.apiConfig, nil, func(*string, *string) error {
|
|
t.Errorf("sender should not be called without an API key")
|
|
return nil
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected an API key error")
|
|
}
|
|
if !strings.Contains(err.Error(), "api key is required") {
|
|
t.Fatalf("expected API key error, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGoogleModelChatRequiresModelName(t *testing.T) {
|
|
model := &GoogleModel{}
|
|
apiKey := "test-api-key"
|
|
messages := []Message{{Role: "user", Content: "hello"}}
|
|
|
|
response, err := model.ChatWithMessages("", messages, &APIConfig{ApiKey: &apiKey}, nil)
|
|
if err == nil {
|
|
t.Fatal("expected a model name error")
|
|
}
|
|
if !strings.Contains(err.Error(), "model name is empty") {
|
|
t.Fatalf("expected model name error, got %v", err)
|
|
}
|
|
if response != nil {
|
|
t.Fatalf("expected no response, got %v", response)
|
|
}
|
|
|
|
err = model.ChatStreamlyWithSender("", messages, &APIConfig{ApiKey: &apiKey}, nil, func(*string, *string) error {
|
|
t.Errorf("sender should not be called without a model name")
|
|
return nil
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected a model name error")
|
|
}
|
|
if !strings.Contains(err.Error(), "model name is empty") {
|
|
t.Fatalf("expected model name error, got %v", err)
|
|
}
|
|
|
|
err = model.ChatStreamlyWithSender("gemini-2.5-flash", messages, &APIConfig{ApiKey: &apiKey}, nil, nil)
|
|
if err == nil {
|
|
t.Fatal("expected a sender error")
|
|
}
|
|
if !strings.Contains(err.Error(), "sender is nil") {
|
|
t.Fatalf("expected sender error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestGoogleModelNewInstancePreservesCustomBaseURL(t *testing.T) {
|
|
model := NewGoogleModel(map[string]string{"default": "https://generativelanguage.googleapis.com"}, URLSuffix{Models: "v1beta/models"})
|
|
customBaseURL := map[string]string{"default": "https://example.test/google"}
|
|
|
|
instance := model.NewInstance(customBaseURL)
|
|
google, ok := instance.(*GoogleModel)
|
|
if !ok {
|
|
t.Fatalf("expected *GoogleModel, got %T", instance)
|
|
}
|
|
if google.baseModel.BaseURL["default"] != customBaseURL["default"] {
|
|
t.Fatalf("expected base URL %q, got %q", customBaseURL["default"], google.baseModel.BaseURL["default"])
|
|
}
|
|
if google.baseModel.URLSuffix != model.baseModel.URLSuffix {
|
|
t.Fatalf("expected URL suffix %v, got %v", model.baseModel.URLSuffix, google.baseModel.URLSuffix)
|
|
}
|
|
}
|
|
|
|
func TestGoogleModelListModelsPassesBaseURL(t *testing.T) {
|
|
apiKey := "test-api-key"
|
|
cases := []struct {
|
|
name string
|
|
baseURL map[string]string
|
|
region *string
|
|
expectedBaseURL string
|
|
}{
|
|
{
|
|
name: "default custom base URL",
|
|
baseURL: map[string]string{"default": "https://example.test/google"},
|
|
expectedBaseURL: "https://example.test/google",
|
|
},
|
|
{
|
|
name: "regional custom base URL",
|
|
baseURL: map[string]string{"east": "https://east.example.test/google", "default": "https://default.example.test/google"},
|
|
region: stringPtr("east"),
|
|
expectedBaseURL: "https://east.example.test/google",
|
|
},
|
|
{
|
|
name: "empty region custom base URL",
|
|
baseURL: map[string]string{"": "https://empty-region.example.test/google"},
|
|
region: stringPtr(""),
|
|
expectedBaseURL: "https://empty-region.example.test/google",
|
|
},
|
|
{
|
|
name: "missing region falls back to default base URL",
|
|
baseURL: map[string]string{"default": "https://default.example.test/google"},
|
|
region: stringPtr("missing"),
|
|
expectedBaseURL: "https://default.example.test/google",
|
|
},
|
|
{
|
|
name: "SDK default base URL",
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
model := NewGoogleModel(tc.baseURL, URLSuffix{})
|
|
withGoogleListModelsStub(t, func(_ context.Context, config *genai.ClientConfig) ([]ListModelResponse, error) {
|
|
if config.HTTPOptions.BaseURL != tc.expectedBaseURL {
|
|
t.Fatalf("expected base URL %q, got %q", tc.expectedBaseURL, config.HTTPOptions.BaseURL)
|
|
}
|
|
return []ListModelResponse{{Name: "models/gemini-2.5-flash"}}, nil
|
|
})
|
|
|
|
if _, err := model.ListModels(&APIConfig{ApiKey: &apiKey, Region: tc.region}); err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCollectGoogleModelNamesPaginates(t *testing.T) {
|
|
pages := []googleModelPage{
|
|
{items: []DSModel{{ID: "Gemini 2.5 Flash", OwnedBy: "Google"}}, nextPageToken: "page-2"},
|
|
{items: []DSModel{{ID: "Gemini 2.5 Pro", OwnedBy: "Google"}}, nextPageToken: ""},
|
|
}
|
|
var pageTokens []string
|
|
|
|
models, err := collectGoogleModelNames(context.Background(), func(_ context.Context, pageToken string) (googleModelPage, error) {
|
|
pageTokens = append(pageTokens, pageToken)
|
|
if len(pageTokens) > len(pages) {
|
|
t.Fatalf("unexpected extra page request with token %q", pageToken)
|
|
}
|
|
return pages[len(pageTokens)-1], nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
|
|
expectedModels := []ListModelResponse{{Name: "Gemini 2.5 Flash@Google"}, {Name: "Gemini 2.5 Pro@Google"}}
|
|
if !reflect.DeepEqual(models, expectedModels) {
|
|
t.Fatalf("expected models %v, got %v", expectedModels, models)
|
|
}
|
|
expectedPageTokens := []string{"", "page-2"}
|
|
if !reflect.DeepEqual(pageTokens, expectedPageTokens) {
|
|
t.Fatalf("expected page tokens %v, got %v", expectedPageTokens, pageTokens)
|
|
}
|
|
}
|
|
|
|
func TestCollectGoogleModelNamesPreservesEmptyResult(t *testing.T) {
|
|
models, err := collectGoogleModelNames(context.Background(), func(context.Context, string) (googleModelPage, error) {
|
|
return googleModelPage{}, nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
if models != nil {
|
|
t.Fatalf("expected nil models, got %v", models)
|
|
}
|
|
}
|
|
|
|
func TestCollectGoogleModelNamesReturnsPageError(t *testing.T) {
|
|
pageErr := errors.New("next page failed")
|
|
calls := 0
|
|
|
|
_, err := collectGoogleModelNames(context.Background(), func(context.Context, string) (googleModelPage, error) {
|
|
calls++
|
|
if calls == 1 {
|
|
return googleModelPage{items: []DSModel{{ID: "Gemini 2.5 Flash", OwnedBy: "Google"}}, nextPageToken: "page-2"}, nil
|
|
}
|
|
return googleModelPage{}, pageErr
|
|
})
|
|
if !errors.Is(err, pageErr) {
|
|
t.Fatalf("expected page error %v, got %v", pageErr, err)
|
|
}
|
|
}
|
|
|
|
// stringPtr returns a pointer to a fresh string holding value. Each call
// yields a distinct pointer.
func stringPtr(value string) *string {
	ptr := new(string)
	*ptr = value
	return ptr
}
|