fix(go): guard custom base URL driver creation (#15030)

### What problem does this PR solve?

Closes #15029.

Some custom `base_url` paths in `ModelProviderService` call
`NewInstance(newURL)` and then immediately invoke methods on the
returned driver. Several real Go model drivers still return `nil` from
`NewInstance`, so those paths can panic instead of returning a normal
error.

This PR:

- centralizes custom base URL driver creation in `model_service.go`
- skips request-local driver creation when `base_url` is blank or
whitespace
- preserves the existing region key behavior when building the
request-local base URL map
- returns a clear error when the provider driver is missing or
`NewInstance` returns `nil`
- routes list/check/task and active model paths through the guarded
helper
- adds focused unit coverage for empty-region preservation, regional
base URLs, blank base URLs, nil drivers, and nil `NewInstance` results

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

### Test plan

- [x] `git diff --check upstream/main...HEAD`
- [x] `/root/go/bin/gofmt -w internal/service/model_service.go
internal/service/model_service_test.go`
- [x] `GOPATH=/root/gopath GOTOOLCHAIN=local /root/go/bin/go test
./internal/service -run TestNewModelDriverForBaseURL -count=1 -vet=off`
- [x] `GOPATH=/root/gopath GOTOOLCHAIN=local /root/go/bin/go build
./internal/service/... ./internal/entity/models/...`

Note: the same targeted `go test` command without `-vet=off` is
currently blocked by an existing unrelated vet finding in
`internal/service/llm.go:355` (`non-constant format string in call to
fmt.Errorf`).
This commit is contained in:
bitloi
2026-05-20 03:58:20 -03:00
committed by GitHub
parent aea90f4e39
commit d69518ea42
2 changed files with 190 additions and 43 deletions

View File

@@ -44,6 +44,25 @@ func parseModelName(compositeName string) (modelName, instanceName, providerName
}
}
func newModelDriverForBaseURL(driver modelModule.ModelDriver, providerName, region, baseURL string) (modelModule.ModelDriver, error) {
if driver == nil {
return nil, fmt.Errorf("provider %s driver not found", providerName)
}
if strings.TrimSpace(baseURL) == "" {
return driver, nil
}
newDriver := driver.NewInstance(map[string]string{
region: baseURL,
})
if newDriver == nil {
return nil, fmt.Errorf("provider %s does not support custom base_url", providerName)
}
return newDriver, nil
}
func NewModelProviderService() *ModelProviderService {
return &ModelProviderService{
modelProviderDAO: dao.NewTenantModelProviderDAO(),
@@ -196,11 +215,10 @@ func (m *ModelProviderService) ListSupportedModels(providerName, instanceName, u
// For local deployed models
if baseURL, ok := extra["base_url"]; ok && baseURL != "" {
newURL := map[string]string{
region: baseURL,
driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL)
if err != nil {
return nil, err
}
driver = driver.NewInstance(newURL)
}
return driver.ListModels(apiConfig)
@@ -447,10 +465,10 @@ func (m *ModelProviderService) CheckProviderConnection(providerName, instanceNam
driver := providerInfo.ModelDriver
if baseURL, ok := extra["base_url"]; ok && baseURL != "" {
newURL := map[string]string{
region: baseURL,
driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL)
if err != nil {
return common.CodeServerError, err
}
driver = driver.NewInstance(newURL)
}
err = driver.CheckConnection(apiConfig)
@@ -507,10 +525,10 @@ func (m *ModelProviderService) ListTasks(providerName, instanceName, userID stri
driver := providerInfo.ModelDriver
if baseURL, ok := extra["base_url"]; ok && baseURL != "" {
newURL := map[string]string{
region: baseURL,
driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL)
if err != nil {
return nil, common.CodeServerError, err
}
driver = driver.NewInstance(newURL)
}
var listTaskResponse []modelModule.ListTaskStatus
@@ -568,10 +586,10 @@ func (m *ModelProviderService) ShowTask(providerName, instanceName, taskID, user
driver := providerInfo.ModelDriver
if baseURL, ok := extra["base_url"]; ok && baseURL != "" {
newURL := map[string]string{
region: baseURL,
driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL)
if err != nil {
return nil, common.CodeServerError, err
}
driver = driver.NewInstance(newURL)
}
var taskResponse *modelModule.TaskResponse
@@ -897,10 +915,10 @@ func (m *ModelProviderService) ChatToModelWithMessages(providerName, instanceNam
modelConfig.ModelClass = &providerInfo.Class
newURL := map[string]string{
region: extra["base_url"],
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
if err != nil {
return nil, common.CodeServerError, err
}
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
var response *modelModule.ChatResponse
response, err = newProviderInfo.ChatWithMessages(modelName, messages, apiConfig, modelConfig)
@@ -999,10 +1017,10 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc
modelConfig.ModelClass = &providerInfo.Class
newURL := map[string]string{
region: extra["base_url"],
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
if err != nil {
return common.CodeServerError, err
}
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
err = newProviderInfo.ChatStreamlyWithSender(modelName, messages, apiConfig, modelConfig, sender)
if err != nil {
@@ -1105,10 +1123,10 @@ func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName,
apiConfig.Region = &region
apiConfig.ApiKey = &instance.APIKey
newURL := map[string]string{
region: extra["base_url"],
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
if err != nil {
return nil, common.CodeServerError, err
}
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
var response []modelModule.EmbeddingData
response, err = newProviderInfo.Embed(&modelName, texts, apiConfig, modelConfig)
@@ -1213,10 +1231,10 @@ func (m *ModelProviderService) RerankDocument(providerName, instanceName, modelN
apiConfig.Region = &region
apiConfig.ApiKey = &instance.APIKey
newURL := map[string]string{
region: extra["base_url"],
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
if err != nil {
return nil, common.CodeServerError, err
}
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
var response *modelModule.RerankResponse
response, err = newProviderInfo.Rerank(&modelName, query, documents, apiConfig, modelConfig)
@@ -1321,10 +1339,10 @@ func (m *ModelProviderService) TranscribeAudio(providerName, instanceName, model
apiConfig.Region = &region
apiConfig.ApiKey = &instance.APIKey
newURL := map[string]string{
region: extra["base_url"],
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
if err != nil {
return nil, common.CodeServerError, err
}
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
var response *modelModule.ASRResponse
response, err = newProviderInfo.TranscribeAudio(&modelName, audioFile, apiConfig, asrConfig)
@@ -1420,10 +1438,10 @@ func (m *ModelProviderService) TranscribeAudioStream(providerName, instanceName,
apiConfig.Region = &region
apiConfig.ApiKey = &instance.APIKey
newURL := map[string]string{
region: extra["base_url"],
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
if err != nil {
return common.CodeServerError, err
}
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
err = newProviderInfo.TranscribeAudioWithSender(&modelName, audioFile, apiConfig, asrConfig, sender)
if err != nil {
@@ -1526,10 +1544,10 @@ func (m *ModelProviderService) AudioSpeech(providerName, instanceName, modelName
apiConfig.Region = &region
apiConfig.ApiKey = &instance.APIKey
newURL := map[string]string{
region: extra["base_url"],
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
if err != nil {
return nil, common.CodeServerError, err
}
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
var response *modelModule.TTSResponse
response, err = newProviderInfo.AudioSpeech(&modelName, audioContent, apiConfig, ttsConfig)
@@ -1625,10 +1643,10 @@ func (m *ModelProviderService) AudioSpeechStream(providerName, instanceName, mod
apiConfig.Region = &region
apiConfig.ApiKey = &instance.APIKey
newURL := map[string]string{
region: extra["base_url"],
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
if err != nil {
return common.CodeServerError, err
}
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
err = newProviderInfo.AudioSpeechWithSender(&modelName, audioContent, apiConfig, ttsConfig, sender)
if err != nil {
@@ -1730,10 +1748,10 @@ func (m *ModelProviderService) OCRFile(providerName, instanceName, modelName, us
apiConfig.Region = &region
apiConfig.ApiKey = &instance.APIKey
newURL := map[string]string{
region: extra["base_url"],
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
if err != nil {
return nil, common.CodeServerError, err
}
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
var response *modelModule.OCRFileResponse
response, err = newProviderInfo.OCRFile(&modelName, content, url, apiConfig, ocrConfig)
@@ -1840,10 +1858,10 @@ func (m *ModelProviderService) ParseFile(providerName, instanceName, modelName,
apiConfig.Region = &region
apiConfig.ApiKey = &instance.APIKey
newURL := map[string]string{
region: extra["base_url"],
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
if err != nil {
return nil, common.CodeServerError, err
}
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
var response *modelModule.ParseFileResponse
response, err = newProviderInfo.ParseFile(&modelName, content, url, apiConfig, parseFileConfig)

View File

@@ -0,0 +1,129 @@
package service
import (
"strings"
"testing"
modelModule "ragflow/internal/entity/models"
)
type stubModelDriver struct {
modelModule.ModelDriver
newInstance func(map[string]string) modelModule.ModelDriver
}
var _ modelModule.ModelDriver = (*stubModelDriver)(nil)
func (s *stubModelDriver) NewInstance(baseURL map[string]string) modelModule.ModelDriver {
if s.newInstance != nil {
return s.newInstance(baseURL)
}
return s
}
func (s *stubModelDriver) Name() string {
return "stub"
}
func TestNewModelDriverForBaseURLPreservesEmptyRegion(t *testing.T) {
expected := &stubModelDriver{}
var gotBaseURL map[string]string
driver := &stubModelDriver{
newInstance: func(baseURL map[string]string) modelModule.ModelDriver {
gotBaseURL = baseURL
return expected
},
}
got, err := newModelDriverForBaseURL(driver, "stub", "", "http://localhost:1234")
if err != nil {
t.Fatalf("newModelDriverForBaseURL returned error: %v", err)
}
if got != expected {
t.Fatalf("expected returned driver %p, got %p", expected, got)
}
if gotBaseURL[""] != "http://localhost:1234" {
t.Fatalf("expected empty-region base URL, got %#v", gotBaseURL)
}
if _, ok := gotBaseURL["default"]; ok {
t.Fatalf("unexpected default region key in base URL map: %#v", gotBaseURL)
}
}
func TestNewModelDriverForBaseURLUsesProvidedRegion(t *testing.T) {
var gotBaseURL map[string]string
driver := &stubModelDriver{
newInstance: func(baseURL map[string]string) modelModule.ModelDriver {
gotBaseURL = baseURL
return &stubModelDriver{}
},
}
_, err := newModelDriverForBaseURL(driver, "stub", "cn-hangzhou", "http://localhost:5678")
if err != nil {
t.Fatalf("newModelDriverForBaseURL returned error: %v", err)
}
if gotBaseURL["cn-hangzhou"] != "http://localhost:5678" {
t.Fatalf("expected regional base URL, got %#v", gotBaseURL)
}
if _, ok := gotBaseURL["default"]; ok {
t.Fatalf("unexpected default region key in base URL map: %#v", gotBaseURL)
}
}
func TestNewModelDriverForBaseURLSkipsEmptyBaseURL(t *testing.T) {
for _, baseURL := range []string{"", " "} {
t.Run(baseURL, func(t *testing.T) {
called := false
driver := &stubModelDriver{
newInstance: func(map[string]string) modelModule.ModelDriver {
called = true
return nil
},
}
got, err := newModelDriverForBaseURL(driver, "deepseek", "default", baseURL)
if err != nil {
t.Fatalf("newModelDriverForBaseURL returned error: %v", err)
}
if got != driver {
t.Fatalf("expected original driver %p, got %p", driver, got)
}
if called {
t.Fatal("expected empty base URL to skip NewInstance")
}
})
}
}
func TestNewModelDriverForBaseURLRejectsNilInstance(t *testing.T) {
driver := &stubModelDriver{
newInstance: func(map[string]string) modelModule.ModelDriver {
return nil
},
}
got, err := newModelDriverForBaseURL(driver, "deepseek", "default", "http://localhost:1234")
if err == nil {
t.Fatal("expected nil NewInstance result to return an error")
}
if got != nil {
t.Fatalf("expected nil driver on error, got %T", got)
}
if !strings.Contains(err.Error(), "deepseek") || !strings.Contains(err.Error(), "custom base_url") {
t.Fatalf("expected provider-specific custom base_url error, got %v", err)
}
}
func TestNewModelDriverForBaseURLRejectsNilDriver(t *testing.T) {
got, err := newModelDriverForBaseURL(nil, "deepseek", "default", "http://localhost:1234")
if err == nil {
t.Fatal("expected nil driver to return an error")
}
if got != nil {
t.Fatalf("expected nil driver on error, got %T", got)
}
if !strings.Contains(err.Error(), "driver not found") {
t.Fatalf("expected driver not found error, got %v", err)
}
}