mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
fix(go): guard custom base URL driver creation (#15030)
### What problem does this PR solve? Closes #15029. Some custom `base_url` paths in `ModelProviderService` call `NewInstance(newURL)` and then immediately invoke methods on the returned driver. Several real Go model drivers still return `nil` from `NewInstance`, so those paths can panic instead of returning a normal error. This PR: - centralizes custom base URL driver creation in `model_service.go` - skips request-local driver creation when `base_url` is blank or whitespace - preserves the existing region key behavior when building the request-local base URL map - returns a clear error when the provider driver is missing or `NewInstance` returns `nil` - routes list/check/task and active model paths through the guarded helper - adds focused unit coverage for empty-region preservation, regional base URLs, blank base URLs, nil drivers, and nil `NewInstance` results ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) ### Test plan - [x] `git diff --check upstream/main...HEAD` - [x] `/root/go/bin/gofmt -w internal/service/model_service.go internal/service/model_service_test.go` - [x] `GOPATH=/root/gopath GOTOOLCHAIN=local /root/go/bin/go test ./internal/service -run TestNewModelDriverForBaseURL -count=1 -vet=off` - [x] `GOPATH=/root/gopath GOTOOLCHAIN=local /root/go/bin/go build ./internal/service/... ./internal/entity/models/...` Note: the same targeted `go test` command without `-vet=off` is currently blocked by an existing unrelated vet finding in `internal/service/llm.go:355` (`non-constant format string in call to fmt.Errorf`).
This commit is contained in:
@@ -44,6 +44,25 @@ func parseModelName(compositeName string) (modelName, instanceName, providerName
|
||||
}
|
||||
}
|
||||
|
||||
func newModelDriverForBaseURL(driver modelModule.ModelDriver, providerName, region, baseURL string) (modelModule.ModelDriver, error) {
|
||||
if driver == nil {
|
||||
return nil, fmt.Errorf("provider %s driver not found", providerName)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(baseURL) == "" {
|
||||
return driver, nil
|
||||
}
|
||||
|
||||
newDriver := driver.NewInstance(map[string]string{
|
||||
region: baseURL,
|
||||
})
|
||||
if newDriver == nil {
|
||||
return nil, fmt.Errorf("provider %s does not support custom base_url", providerName)
|
||||
}
|
||||
|
||||
return newDriver, nil
|
||||
}
|
||||
|
||||
func NewModelProviderService() *ModelProviderService {
|
||||
return &ModelProviderService{
|
||||
modelProviderDAO: dao.NewTenantModelProviderDAO(),
|
||||
@@ -196,11 +215,10 @@ func (m *ModelProviderService) ListSupportedModels(providerName, instanceName, u
|
||||
|
||||
// For local deployed models
|
||||
if baseURL, ok := extra["base_url"]; ok && baseURL != "" {
|
||||
newURL := map[string]string{
|
||||
region: baseURL,
|
||||
driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
driver = driver.NewInstance(newURL)
|
||||
}
|
||||
|
||||
return driver.ListModels(apiConfig)
|
||||
@@ -447,10 +465,10 @@ func (m *ModelProviderService) CheckProviderConnection(providerName, instanceNam
|
||||
|
||||
driver := providerInfo.ModelDriver
|
||||
if baseURL, ok := extra["base_url"]; ok && baseURL != "" {
|
||||
newURL := map[string]string{
|
||||
region: baseURL,
|
||||
driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL)
|
||||
if err != nil {
|
||||
return common.CodeServerError, err
|
||||
}
|
||||
driver = driver.NewInstance(newURL)
|
||||
}
|
||||
|
||||
err = driver.CheckConnection(apiConfig)
|
||||
@@ -507,10 +525,10 @@ func (m *ModelProviderService) ListTasks(providerName, instanceName, userID stri
|
||||
|
||||
driver := providerInfo.ModelDriver
|
||||
if baseURL, ok := extra["base_url"]; ok && baseURL != "" {
|
||||
newURL := map[string]string{
|
||||
region: baseURL,
|
||||
driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
driver = driver.NewInstance(newURL)
|
||||
}
|
||||
|
||||
var listTaskResponse []modelModule.ListTaskStatus
|
||||
@@ -568,10 +586,10 @@ func (m *ModelProviderService) ShowTask(providerName, instanceName, taskID, user
|
||||
|
||||
driver := providerInfo.ModelDriver
|
||||
if baseURL, ok := extra["base_url"]; ok && baseURL != "" {
|
||||
newURL := map[string]string{
|
||||
region: baseURL,
|
||||
driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
driver = driver.NewInstance(newURL)
|
||||
}
|
||||
|
||||
var taskResponse *modelModule.TaskResponse
|
||||
@@ -897,10 +915,10 @@ func (m *ModelProviderService) ChatToModelWithMessages(providerName, instanceNam
|
||||
|
||||
modelConfig.ModelClass = &providerInfo.Class
|
||||
|
||||
newURL := map[string]string{
|
||||
region: extra["base_url"],
|
||||
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
|
||||
|
||||
var response *modelModule.ChatResponse
|
||||
response, err = newProviderInfo.ChatWithMessages(modelName, messages, apiConfig, modelConfig)
|
||||
@@ -999,10 +1017,10 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc
|
||||
|
||||
modelConfig.ModelClass = &providerInfo.Class
|
||||
|
||||
newURL := map[string]string{
|
||||
region: extra["base_url"],
|
||||
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
|
||||
if err != nil {
|
||||
return common.CodeServerError, err
|
||||
}
|
||||
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
|
||||
|
||||
err = newProviderInfo.ChatStreamlyWithSender(modelName, messages, apiConfig, modelConfig, sender)
|
||||
if err != nil {
|
||||
@@ -1105,10 +1123,10 @@ func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName,
|
||||
apiConfig.Region = ®ion
|
||||
apiConfig.ApiKey = &instance.APIKey
|
||||
|
||||
newURL := map[string]string{
|
||||
region: extra["base_url"],
|
||||
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
|
||||
|
||||
var response []modelModule.EmbeddingData
|
||||
response, err = newProviderInfo.Embed(&modelName, texts, apiConfig, modelConfig)
|
||||
@@ -1213,10 +1231,10 @@ func (m *ModelProviderService) RerankDocument(providerName, instanceName, modelN
|
||||
apiConfig.Region = ®ion
|
||||
apiConfig.ApiKey = &instance.APIKey
|
||||
|
||||
newURL := map[string]string{
|
||||
region: extra["base_url"],
|
||||
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
|
||||
|
||||
var response *modelModule.RerankResponse
|
||||
response, err = newProviderInfo.Rerank(&modelName, query, documents, apiConfig, modelConfig)
|
||||
@@ -1321,10 +1339,10 @@ func (m *ModelProviderService) TranscribeAudio(providerName, instanceName, model
|
||||
apiConfig.Region = ®ion
|
||||
apiConfig.ApiKey = &instance.APIKey
|
||||
|
||||
newURL := map[string]string{
|
||||
region: extra["base_url"],
|
||||
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
|
||||
|
||||
var response *modelModule.ASRResponse
|
||||
response, err = newProviderInfo.TranscribeAudio(&modelName, audioFile, apiConfig, asrConfig)
|
||||
@@ -1420,10 +1438,10 @@ func (m *ModelProviderService) TranscribeAudioStream(providerName, instanceName,
|
||||
apiConfig.Region = ®ion
|
||||
apiConfig.ApiKey = &instance.APIKey
|
||||
|
||||
newURL := map[string]string{
|
||||
region: extra["base_url"],
|
||||
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
|
||||
if err != nil {
|
||||
return common.CodeServerError, err
|
||||
}
|
||||
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
|
||||
|
||||
err = newProviderInfo.TranscribeAudioWithSender(&modelName, audioFile, apiConfig, asrConfig, sender)
|
||||
if err != nil {
|
||||
@@ -1526,10 +1544,10 @@ func (m *ModelProviderService) AudioSpeech(providerName, instanceName, modelName
|
||||
apiConfig.Region = ®ion
|
||||
apiConfig.ApiKey = &instance.APIKey
|
||||
|
||||
newURL := map[string]string{
|
||||
region: extra["base_url"],
|
||||
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
|
||||
|
||||
var response *modelModule.TTSResponse
|
||||
response, err = newProviderInfo.AudioSpeech(&modelName, audioContent, apiConfig, ttsConfig)
|
||||
@@ -1625,10 +1643,10 @@ func (m *ModelProviderService) AudioSpeechStream(providerName, instanceName, mod
|
||||
apiConfig.Region = ®ion
|
||||
apiConfig.ApiKey = &instance.APIKey
|
||||
|
||||
newURL := map[string]string{
|
||||
region: extra["base_url"],
|
||||
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
|
||||
if err != nil {
|
||||
return common.CodeServerError, err
|
||||
}
|
||||
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
|
||||
|
||||
err = newProviderInfo.AudioSpeechWithSender(&modelName, audioContent, apiConfig, ttsConfig, sender)
|
||||
if err != nil {
|
||||
@@ -1730,10 +1748,10 @@ func (m *ModelProviderService) OCRFile(providerName, instanceName, modelName, us
|
||||
apiConfig.Region = ®ion
|
||||
apiConfig.ApiKey = &instance.APIKey
|
||||
|
||||
newURL := map[string]string{
|
||||
region: extra["base_url"],
|
||||
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
|
||||
|
||||
var response *modelModule.OCRFileResponse
|
||||
response, err = newProviderInfo.OCRFile(&modelName, content, url, apiConfig, ocrConfig)
|
||||
@@ -1840,10 +1858,10 @@ func (m *ModelProviderService) ParseFile(providerName, instanceName, modelName,
|
||||
apiConfig.Region = ®ion
|
||||
apiConfig.ApiKey = &instance.APIKey
|
||||
|
||||
newURL := map[string]string{
|
||||
region: extra["base_url"],
|
||||
newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"])
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL)
|
||||
|
||||
var response *modelModule.ParseFileResponse
|
||||
response, err = newProviderInfo.ParseFile(&modelName, content, url, apiConfig, parseFileConfig)
|
||||
|
||||
129
internal/service/model_service_test.go
Normal file
129
internal/service/model_service_test.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
type stubModelDriver struct {
|
||||
modelModule.ModelDriver
|
||||
newInstance func(map[string]string) modelModule.ModelDriver
|
||||
}
|
||||
|
||||
var _ modelModule.ModelDriver = (*stubModelDriver)(nil)
|
||||
|
||||
func (s *stubModelDriver) NewInstance(baseURL map[string]string) modelModule.ModelDriver {
|
||||
if s.newInstance != nil {
|
||||
return s.newInstance(baseURL)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *stubModelDriver) Name() string {
|
||||
return "stub"
|
||||
}
|
||||
|
||||
func TestNewModelDriverForBaseURLPreservesEmptyRegion(t *testing.T) {
|
||||
expected := &stubModelDriver{}
|
||||
var gotBaseURL map[string]string
|
||||
driver := &stubModelDriver{
|
||||
newInstance: func(baseURL map[string]string) modelModule.ModelDriver {
|
||||
gotBaseURL = baseURL
|
||||
return expected
|
||||
},
|
||||
}
|
||||
|
||||
got, err := newModelDriverForBaseURL(driver, "stub", "", "http://localhost:1234")
|
||||
if err != nil {
|
||||
t.Fatalf("newModelDriverForBaseURL returned error: %v", err)
|
||||
}
|
||||
if got != expected {
|
||||
t.Fatalf("expected returned driver %p, got %p", expected, got)
|
||||
}
|
||||
if gotBaseURL[""] != "http://localhost:1234" {
|
||||
t.Fatalf("expected empty-region base URL, got %#v", gotBaseURL)
|
||||
}
|
||||
if _, ok := gotBaseURL["default"]; ok {
|
||||
t.Fatalf("unexpected default region key in base URL map: %#v", gotBaseURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewModelDriverForBaseURLUsesProvidedRegion(t *testing.T) {
|
||||
var gotBaseURL map[string]string
|
||||
driver := &stubModelDriver{
|
||||
newInstance: func(baseURL map[string]string) modelModule.ModelDriver {
|
||||
gotBaseURL = baseURL
|
||||
return &stubModelDriver{}
|
||||
},
|
||||
}
|
||||
|
||||
_, err := newModelDriverForBaseURL(driver, "stub", "cn-hangzhou", "http://localhost:5678")
|
||||
if err != nil {
|
||||
t.Fatalf("newModelDriverForBaseURL returned error: %v", err)
|
||||
}
|
||||
if gotBaseURL["cn-hangzhou"] != "http://localhost:5678" {
|
||||
t.Fatalf("expected regional base URL, got %#v", gotBaseURL)
|
||||
}
|
||||
if _, ok := gotBaseURL["default"]; ok {
|
||||
t.Fatalf("unexpected default region key in base URL map: %#v", gotBaseURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewModelDriverForBaseURLSkipsEmptyBaseURL(t *testing.T) {
|
||||
for _, baseURL := range []string{"", " "} {
|
||||
t.Run(baseURL, func(t *testing.T) {
|
||||
called := false
|
||||
driver := &stubModelDriver{
|
||||
newInstance: func(map[string]string) modelModule.ModelDriver {
|
||||
called = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
got, err := newModelDriverForBaseURL(driver, "deepseek", "default", baseURL)
|
||||
if err != nil {
|
||||
t.Fatalf("newModelDriverForBaseURL returned error: %v", err)
|
||||
}
|
||||
if got != driver {
|
||||
t.Fatalf("expected original driver %p, got %p", driver, got)
|
||||
}
|
||||
if called {
|
||||
t.Fatal("expected empty base URL to skip NewInstance")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewModelDriverForBaseURLRejectsNilInstance(t *testing.T) {
|
||||
driver := &stubModelDriver{
|
||||
newInstance: func(map[string]string) modelModule.ModelDriver {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
got, err := newModelDriverForBaseURL(driver, "deepseek", "default", "http://localhost:1234")
|
||||
if err == nil {
|
||||
t.Fatal("expected nil NewInstance result to return an error")
|
||||
}
|
||||
if got != nil {
|
||||
t.Fatalf("expected nil driver on error, got %T", got)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "deepseek") || !strings.Contains(err.Error(), "custom base_url") {
|
||||
t.Fatalf("expected provider-specific custom base_url error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewModelDriverForBaseURLRejectsNilDriver(t *testing.T) {
|
||||
got, err := newModelDriverForBaseURL(nil, "deepseek", "default", "http://localhost:1234")
|
||||
if err == nil {
|
||||
t.Fatal("expected nil driver to return an error")
|
||||
}
|
||||
if got != nil {
|
||||
t.Fatalf("expected nil driver on error, got %T", got)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "driver not found") {
|
||||
t.Fatalf("expected driver not found error, got %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user