diff --git a/internal/service/model_service.go b/internal/service/model_service.go index fe179d54e2..9813010a28 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -44,6 +44,25 @@ func parseModelName(compositeName string) (modelName, instanceName, providerName } } +func newModelDriverForBaseURL(driver modelModule.ModelDriver, providerName, region, baseURL string) (modelModule.ModelDriver, error) { + if driver == nil { + return nil, fmt.Errorf("provider %s driver not found", providerName) + } + + if strings.TrimSpace(baseURL) == "" { + return driver, nil + } + + newDriver := driver.NewInstance(map[string]string{ + region: baseURL, + }) + if newDriver == nil { + return nil, fmt.Errorf("provider %s does not support custom base_url", providerName) + } + + return newDriver, nil +} + func NewModelProviderService() *ModelProviderService { return &ModelProviderService{ modelProviderDAO: dao.NewTenantModelProviderDAO(), @@ -196,11 +215,10 @@ func (m *ModelProviderService) ListSupportedModels(providerName, instanceName, u // For local deployed models if baseURL, ok := extra["base_url"]; ok && baseURL != "" { - newURL := map[string]string{ - region: baseURL, + driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL) + if err != nil { + return nil, err } - - driver = driver.NewInstance(newURL) } return driver.ListModels(apiConfig) @@ -447,10 +465,10 @@ func (m *ModelProviderService) CheckProviderConnection(providerName, instanceNam driver := providerInfo.ModelDriver if baseURL, ok := extra["base_url"]; ok && baseURL != "" { - newURL := map[string]string{ - region: baseURL, + driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL) + if err != nil { + return common.CodeServerError, err } - driver = driver.NewInstance(newURL) } err = driver.CheckConnection(apiConfig) @@ -507,10 +525,10 @@ func (m *ModelProviderService) ListTasks(providerName, instanceName, userID stri driver := providerInfo.ModelDriver if baseURL, ok := extra["base_url"]; ok && baseURL != "" { - newURL := map[string]string{ - region: baseURL, + driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL) + if err != nil { + return nil, common.CodeServerError, err } - driver = driver.NewInstance(newURL) } var listTaskResponse []modelModule.ListTaskStatus @@ -568,10 +586,10 @@ func (m *ModelProviderService) ShowTask(providerName, instanceName, taskID, user driver := providerInfo.ModelDriver if baseURL, ok := extra["base_url"]; ok && baseURL != "" { - newURL := map[string]string{ - region: baseURL, + driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL) + if err != nil { + return nil, common.CodeServerError, err } - driver = driver.NewInstance(newURL) } var taskResponse *modelModule.TaskResponse @@ -897,10 +915,10 @@ func (m *ModelProviderService) ChatToModelWithMessages(providerName, instanceNam modelConfig.ModelClass = &providerInfo.Class - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) var response *modelModule.ChatResponse response, err = newProviderInfo.ChatWithMessages(modelName, messages, apiConfig, modelConfig) @@ -999,10 +1017,10 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc modelConfig.ModelClass = &providerInfo.Class - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) err = newProviderInfo.ChatStreamlyWithSender(modelName, messages, apiConfig, modelConfig, sender) if err != nil { @@ -1105,10 +1123,10 @@ func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName, apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) var response []modelModule.EmbeddingData response, err = newProviderInfo.Embed(&modelName, texts, apiConfig, modelConfig) @@ -1213,10 +1231,10 @@ func (m *ModelProviderService) RerankDocument(providerName, instanceName, modelN apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) var response *modelModule.RerankResponse response, err = newProviderInfo.Rerank(&modelName, query, documents, apiConfig, modelConfig) @@ -1321,10 +1339,10 @@ func (m *ModelProviderService) TranscribeAudio(providerName, instanceName, model apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) var response *modelModule.ASRResponse response, err = newProviderInfo.TranscribeAudio(&modelName, audioFile, apiConfig, asrConfig) @@ -1420,10 +1438,10 @@ func (m *ModelProviderService) TranscribeAudioStream(providerName, instanceName, apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) err = newProviderInfo.TranscribeAudioWithSender(&modelName, audioFile, apiConfig, asrConfig, sender) if err != nil { @@ -1526,10 +1544,10 @@ func (m *ModelProviderService) AudioSpeech(providerName, instanceName, modelName apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) var response *modelModule.TTSResponse response, err = newProviderInfo.AudioSpeech(&modelName, audioContent, apiConfig, ttsConfig) @@ -1625,10 +1643,10 @@ func (m *ModelProviderService) AudioSpeechStream(providerName, instanceName, mod apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) err = newProviderInfo.AudioSpeechWithSender(&modelName, audioContent, apiConfig, ttsConfig, sender) if err != nil { @@ -1730,10 +1748,10 @@ func (m *ModelProviderService) OCRFile(providerName, instanceName, modelName, us apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) var response *modelModule.OCRFileResponse response, err = newProviderInfo.OCRFile(&modelName, content, url, apiConfig, ocrConfig) @@ -1840,10 +1858,10 @@ func (m *ModelProviderService) ParseFile(providerName, instanceName, modelName, apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) var response *modelModule.ParseFileResponse response, err = newProviderInfo.ParseFile(&modelName, content, url, apiConfig, parseFileConfig) diff --git a/internal/service/model_service_test.go b/internal/service/model_service_test.go new file mode 100644 index 0000000000..7b6e138c4c --- /dev/null +++ b/internal/service/model_service_test.go @@ -0,0 +1,129 @@ +package service + +import ( + "strings" + "testing" + + modelModule "ragflow/internal/entity/models" +) + +type stubModelDriver struct { + modelModule.ModelDriver + newInstance func(map[string]string) modelModule.ModelDriver +} + +var _ modelModule.ModelDriver = (*stubModelDriver)(nil) + +func (s *stubModelDriver) NewInstance(baseURL map[string]string) modelModule.ModelDriver { + if s.newInstance != nil { + return s.newInstance(baseURL) + } + return s +} + +func (s *stubModelDriver) Name() string { + return "stub" +} + +func TestNewModelDriverForBaseURLPreservesEmptyRegion(t *testing.T) { + expected := &stubModelDriver{} + var gotBaseURL map[string]string + driver := &stubModelDriver{ + newInstance: func(baseURL map[string]string) modelModule.ModelDriver { + gotBaseURL = baseURL + return expected + }, + } + + got, err := newModelDriverForBaseURL(driver, "stub", "", "http://localhost:1234") + if err != nil { + t.Fatalf("newModelDriverForBaseURL returned error: %v", err) + } + if got != expected { + t.Fatalf("expected returned driver %p, got %p", expected, got) + } + if gotBaseURL[""] != "http://localhost:1234" { + t.Fatalf("expected empty-region base URL, got %#v", gotBaseURL) + } + if _, ok := gotBaseURL["default"]; ok { + t.Fatalf("unexpected default region key in base URL map: %#v", gotBaseURL) + } +} + +func TestNewModelDriverForBaseURLUsesProvidedRegion(t *testing.T) { + var gotBaseURL map[string]string + driver := &stubModelDriver{ + newInstance: func(baseURL map[string]string) modelModule.ModelDriver { + gotBaseURL = baseURL + return &stubModelDriver{} + }, + } + + _, err := newModelDriverForBaseURL(driver, "stub", "cn-hangzhou", "http://localhost:5678") + if err != nil { + t.Fatalf("newModelDriverForBaseURL returned error: %v", err) + } + if gotBaseURL["cn-hangzhou"] != "http://localhost:5678" { + t.Fatalf("expected regional base URL, got %#v", gotBaseURL) + } + if _, ok := gotBaseURL["default"]; ok { + t.Fatalf("unexpected default region key in base URL map: %#v", gotBaseURL) + } +} + +func TestNewModelDriverForBaseURLSkipsEmptyBaseURL(t *testing.T) { + for _, baseURL := range []string{"", " "} { + t.Run(baseURL, func(t *testing.T) { + called := false + driver := &stubModelDriver{ + newInstance: func(map[string]string) modelModule.ModelDriver { + called = true + return nil + }, + } + + got, err := newModelDriverForBaseURL(driver, "deepseek", "default", baseURL) + if err != nil { + t.Fatalf("newModelDriverForBaseURL returned error: %v", err) + } + if got != driver { + t.Fatalf("expected original driver %p, got %p", driver, got) + } + if called { + t.Fatal("expected empty base URL to skip NewInstance") + } + }) + } +} + +func TestNewModelDriverForBaseURLRejectsNilInstance(t *testing.T) { + driver := &stubModelDriver{ + newInstance: func(map[string]string) modelModule.ModelDriver { + return nil + }, + } + + got, err := newModelDriverForBaseURL(driver, "deepseek", "default", "http://localhost:1234") + if err == nil { + t.Fatal("expected nil NewInstance result to return an error") + } + if got != nil { + t.Fatalf("expected nil driver on error, got %T", got) + } + if !strings.Contains(err.Error(), "deepseek") || !strings.Contains(err.Error(), "custom base_url") { + t.Fatalf("expected provider-specific custom base_url error, got %v", err) + } +} + +func TestNewModelDriverForBaseURLRejectsNilDriver(t *testing.T) { + got, err := newModelDriverForBaseURL(nil, "deepseek", "default", "http://localhost:1234") + if err == nil { + t.Fatal("expected nil driver to return an error") + } + if got != nil { + t.Fatalf("expected nil driver on error, got %T", got) + } + if !strings.Contains(err.Error(), "driver not found") { + t.Fatalf("expected driver not found error, got %v", err) + } +}