From d69518ea420d78cdbeb94002d2c1b65d6c29f034 Mon Sep 17 00:00:00 2001 From: bitloi <89318445+bitloi@users.noreply.github.com> Date: Wed, 20 May 2026 03:58:20 -0300 Subject: [PATCH] fix(go): guard custom base URL driver creation (#15030) ### What problem does this PR solve? Closes #15029. Some custom `base_url` paths in `ModelProviderService` call `NewInstance(newURL)` and then immediately invoke methods on the returned driver. Several real Go model drivers still return `nil` from `NewInstance`, so those paths can panic instead of returning a normal error. This PR: - centralizes custom base URL driver creation in `model_service.go` - skips request-local driver creation when `base_url` is blank or whitespace - preserves the existing region key behavior when building the request-local base URL map - returns a clear error when the provider driver is missing or `NewInstance` returns `nil` - routes list/check/task and active model paths through the guarded helper - adds focused unit coverage for empty-region preservation, regional base URLs, blank base URLs, nil drivers, and nil `NewInstance` results ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) ### Test plan - [x] `git diff --check upstream/main...HEAD` - [x] `/root/go/bin/gofmt -w internal/service/model_service.go internal/service/model_service_test.go` - [x] `GOPATH=/root/gopath GOTOOLCHAIN=local /root/go/bin/go test ./internal/service -run TestNewModelDriverForBaseURL -count=1 -vet=off` - [x] `GOPATH=/root/gopath GOTOOLCHAIN=local /root/go/bin/go build ./internal/service/... ./internal/entity/models/...` Note: the same targeted `go test` command without `-vet=off` is currently blocked by an existing unrelated vet finding in `internal/service/llm.go:355` (`non-constant format string in call to fmt.Errorf`). --- internal/service/model_service.go | 104 +++++++++++--------- internal/service/model_service_test.go | 129 +++++++++++++++++++++++++ 2 files changed, 190 insertions(+), 43 deletions(-) create mode 100644 internal/service/model_service_test.go diff --git a/internal/service/model_service.go b/internal/service/model_service.go index fe179d54e2..9813010a28 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -44,6 +44,25 @@ func parseModelName(compositeName string) (modelName, instanceName, providerName } } +func newModelDriverForBaseURL(driver modelModule.ModelDriver, providerName, region, baseURL string) (modelModule.ModelDriver, error) { + if driver == nil { + return nil, fmt.Errorf("provider %s driver not found", providerName) + } + + if strings.TrimSpace(baseURL) == "" { + return driver, nil + } + + newDriver := driver.NewInstance(map[string]string{ + region: baseURL, + }) + if newDriver == nil { + return nil, fmt.Errorf("provider %s does not support custom base_url", providerName) + } + + return newDriver, nil +} + func NewModelProviderService() *ModelProviderService { return &ModelProviderService{ modelProviderDAO: dao.NewTenantModelProviderDAO(), @@ -196,11 +215,10 @@ func (m *ModelProviderService) ListSupportedModels(providerName, instanceName, u // For local deployed models if baseURL, ok := extra["base_url"]; ok && baseURL != "" { - newURL := map[string]string{ - region: baseURL, + driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL) + if err != nil { + return nil, err } - - driver = driver.NewInstance(newURL) } return driver.ListModels(apiConfig) @@ -447,10 +465,10 @@ func (m *ModelProviderService) CheckProviderConnection(providerName, instanceNam driver := providerInfo.ModelDriver if baseURL, ok := extra["base_url"]; ok && baseURL != "" { - newURL := map[string]string{ - region: baseURL, + driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL) + if err != nil { + return common.CodeServerError, err } - driver = driver.NewInstance(newURL) } err = driver.CheckConnection(apiConfig) @@ -507,10 +525,10 @@ func (m *ModelProviderService) ListTasks(providerName, instanceName, userID stri driver := providerInfo.ModelDriver if baseURL, ok := extra["base_url"]; ok && baseURL != "" { - newURL := map[string]string{ - region: baseURL, + driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL) + if err != nil { + return nil, common.CodeServerError, err } - driver = driver.NewInstance(newURL) } var listTaskResponse []modelModule.ListTaskStatus @@ -568,10 +586,10 @@ func (m *ModelProviderService) ShowTask(providerName, instanceName, taskID, user driver := providerInfo.ModelDriver if baseURL, ok := extra["base_url"]; ok && baseURL != "" { - newURL := map[string]string{ - region: baseURL, + driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL) + if err != nil { + return nil, common.CodeServerError, err } - driver = driver.NewInstance(newURL) } var taskResponse *modelModule.TaskResponse @@ -897,10 +915,10 @@ func (m *ModelProviderService) ChatToModelWithMessages(providerName, instanceNam modelConfig.ModelClass = &providerInfo.Class - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) var response *modelModule.ChatResponse response, err = newProviderInfo.ChatWithMessages(modelName, messages, apiConfig, modelConfig) @@ -999,10 +1017,10 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc modelConfig.ModelClass = &providerInfo.Class - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) err = newProviderInfo.ChatStreamlyWithSender(modelName, messages, apiConfig, modelConfig, sender) if err != nil { @@ -1105,10 +1123,10 @@ func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName, apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) var response []modelModule.EmbeddingData response, err = newProviderInfo.Embed(&modelName, texts, apiConfig, modelConfig) @@ -1213,10 +1231,10 @@ func (m *ModelProviderService) RerankDocument(providerName, instanceName, modelN apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) var response *modelModule.RerankResponse response, err = newProviderInfo.Rerank(&modelName, query, documents, apiConfig, modelConfig) @@ -1321,10 +1339,10 @@ func (m *ModelProviderService) TranscribeAudio(providerName, instanceName, model apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) var response *modelModule.ASRResponse response, err = newProviderInfo.TranscribeAudio(&modelName, audioFile, apiConfig, asrConfig) @@ -1420,10 +1438,10 @@ func (m *ModelProviderService) TranscribeAudioStream(providerName, instanceName, apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) err = newProviderInfo.TranscribeAudioWithSender(&modelName, audioFile, apiConfig, asrConfig, sender) if err != nil { @@ -1526,10 +1544,10 @@ func (m *ModelProviderService) AudioSpeech(providerName, instanceName, modelName apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) var response *modelModule.TTSResponse response, err = newProviderInfo.AudioSpeech(&modelName, audioContent, apiConfig, ttsConfig) @@ -1625,10 +1643,10 @@ func (m *ModelProviderService) AudioSpeechStream(providerName, instanceName, mod apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) err = newProviderInfo.AudioSpeechWithSender(&modelName, audioContent, apiConfig, ttsConfig, sender) if err != nil { @@ -1730,10 +1748,10 @@ func (m *ModelProviderService) OCRFile(providerName, instanceName, modelName, us apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) var response *modelModule.OCRFileResponse response, err = newProviderInfo.OCRFile(&modelName, content, url, apiConfig, ocrConfig) @@ -1840,10 +1858,10 @@ func (m *ModelProviderService) ParseFile(providerName, instanceName, modelName, apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) var response *modelModule.ParseFileResponse response, err = newProviderInfo.ParseFile(&modelName, content, url, apiConfig, parseFileConfig) diff --git a/internal/service/model_service_test.go b/internal/service/model_service_test.go new file mode 100644 index 0000000000..7b6e138c4c --- /dev/null +++ b/internal/service/model_service_test.go @@ -0,0 +1,129 @@ +package service + +import ( + "strings" + "testing" + + modelModule "ragflow/internal/entity/models" +) + +type stubModelDriver struct { + modelModule.ModelDriver + newInstance func(map[string]string) modelModule.ModelDriver +} + +var _ modelModule.ModelDriver = (*stubModelDriver)(nil) + +func (s *stubModelDriver) NewInstance(baseURL map[string]string) modelModule.ModelDriver { + if s.newInstance != nil { + return s.newInstance(baseURL) + } + return s +} + +func (s *stubModelDriver) Name() string { + return "stub" +} + +func TestNewModelDriverForBaseURLPreservesEmptyRegion(t *testing.T) { + expected := &stubModelDriver{} + var gotBaseURL map[string]string + driver := &stubModelDriver{ + newInstance: func(baseURL map[string]string) modelModule.ModelDriver { + gotBaseURL = baseURL + return expected + }, + } + + got, err := newModelDriverForBaseURL(driver, "stub", "", "http://localhost:1234") + if err != nil { + t.Fatalf("newModelDriverForBaseURL returned error: %v", err) + } + if got != expected { + t.Fatalf("expected returned driver %p, got %p", expected, got) + } + if gotBaseURL[""] != "http://localhost:1234" { + t.Fatalf("expected empty-region base URL, got %#v", gotBaseURL) + } + if _, ok := gotBaseURL["default"]; ok { + t.Fatalf("unexpected default region key in base URL map: %#v", gotBaseURL) + } +} + +func TestNewModelDriverForBaseURLUsesProvidedRegion(t *testing.T) { + var gotBaseURL map[string]string + driver := &stubModelDriver{ + newInstance: func(baseURL map[string]string) modelModule.ModelDriver { + gotBaseURL = baseURL + return &stubModelDriver{} + }, + } + + _, err := newModelDriverForBaseURL(driver, "stub", "cn-hangzhou", "http://localhost:5678") + if err != nil { + t.Fatalf("newModelDriverForBaseURL returned error: %v", err) + } + if gotBaseURL["cn-hangzhou"] != "http://localhost:5678" { + t.Fatalf("expected regional base URL, got %#v", gotBaseURL) + } + if _, ok := gotBaseURL["default"]; ok { + t.Fatalf("unexpected default region key in base URL map: %#v", gotBaseURL) + } +} + +func TestNewModelDriverForBaseURLSkipsEmptyBaseURL(t *testing.T) { + for _, baseURL := range []string{"", " "} { + t.Run(baseURL, func(t *testing.T) { + called := false + driver := &stubModelDriver{ + newInstance: func(map[string]string) modelModule.ModelDriver { + called = true + return nil + }, + } + + got, err := newModelDriverForBaseURL(driver, "deepseek", "default", baseURL) + if err != nil { + t.Fatalf("newModelDriverForBaseURL returned error: %v", err) + } + if got != driver { + t.Fatalf("expected original driver %p, got %p", driver, got) + } + if called { + t.Fatal("expected empty base URL to skip NewInstance") + } + }) + } +} + +func TestNewModelDriverForBaseURLRejectsNilInstance(t *testing.T) { + driver := &stubModelDriver{ + newInstance: func(map[string]string) modelModule.ModelDriver { + return nil + }, + } + + got, err := newModelDriverForBaseURL(driver, "deepseek", "default", "http://localhost:1234") + if err == nil { + t.Fatal("expected nil NewInstance result to return an error") + } + if got != nil { + t.Fatalf("expected nil driver on error, got %T", got) + } + if !strings.Contains(err.Error(), "deepseek") || !strings.Contains(err.Error(), "custom base_url") { + t.Fatalf("expected provider-specific custom base_url error, got %v", err) + } +} + +func TestNewModelDriverForBaseURLRejectsNilDriver(t *testing.T) { + got, err := newModelDriverForBaseURL(nil, "deepseek", "default", "http://localhost:1234") + if err == nil { + t.Fatal("expected nil driver to return an error") + } + if got != nil { + t.Fatalf("expected nil driver on error, got %T", got) + } + if !strings.Contains(err.Error(), "driver not found") { + t.Fatalf("expected driver not found error, got %v", err) + } +}