From e1f19f6679eaaf301beb54980169bbd0121b9f35 Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Wed, 3 Jun 2026 13:23:20 +0800 Subject: [PATCH] Go: fix gitee balance api (#15554) ``` RAGFlow(user)> create provider 'gitee' instance 'intl' key 'api-token' url 'https://ai.gitee.com/v1' region 'intl'; SUCCESS ``` --------- Signed-off-by: Jin Hai --- internal/cli/response.go | 19 ++++++--- internal/cli/user_command.go | 2 +- internal/cli/user_parser.go | 42 +++++++++---------- internal/dao/tenant_model_instance.go | 22 +++++++++- internal/entity/models/gitee.go | 26 ++++++++---- internal/entity/models/types.go | 5 ++- internal/service/model_service.go | 2 + .../test_openai_stream_no_duplicate.py | 8 ++++ 8 files changed, 86 insertions(+), 40 deletions(-) diff --git a/internal/cli/response.go b/internal/cli/response.go index 76337e7703..7bf7150b6d 100644 --- a/internal/cli/response.go +++ b/internal/cli/response.go @@ -80,7 +80,14 @@ func (r *CommonDataResponse) SetOutputFormat(format OutputFormat) { func (r *CommonDataResponse) PrintOut() { if r.Code == 0 { table := make([]map[string]interface{}, 0) - table = append(table, r.Data) + for key, value := range r.Data { + elem := map[string]interface{}{ + "field": key, + "value": value, + } + table = append(table, elem) + } + //table = append(table, r.Data) PrintTableSimpleByFormat(table, r.OutputFormat) } else { fmt.Println("ERROR") @@ -156,9 +163,9 @@ func (r *ChunkResponse) PrintOut() { } type MetadataResponse struct { - Code int `json:"code"` - Data map[string]interface{} `json:"data"` - Message string `json:"message"` + Code int `json:"code"` + Data map[string]interface{} `json:"data"` + Message string `json:"message"` Duration float64 OutputFormat OutputFormat } @@ -213,8 +220,8 @@ func printFlattenedMetadata(data map[string]interface{}, format OutputFormat) { docIDStr = fmt.Sprintf("%v", docIDs) } tableData = append(tableData, map[string]interface{}{ - "field": field, - "value": value, + "field": field, + "value": value, "document_ids": docIDStr, }) } diff --git a/internal/cli/user_command.go b/internal/cli/user_command.go index 35b2c9dc7f..e41c03b22f 100644 --- a/internal/cli/user_command.go +++ b/internal/cli/user_command.go @@ -1166,7 +1166,7 @@ func (c *RAGFlowClient) DeleteProvider(cmd *Command) (ResponseIf, error) { } // CreateProviderInstance creates a new provider instance -// CREATE PROVIDER INSTANCE +// CREATE PROVIDER INSTANCE KEY URL REGION func (c *RAGFlowClient) CreateProviderInstance(cmd *Command) (ResponseIf, error) { if c.ServerType != "user" { return nil, fmt.Errorf("this command is only allowed in USER mode") diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go index 4919ad13a7..1b107b7c50 100644 --- a/internal/cli/user_parser.go +++ b/internal/cli/user_parser.go @@ -1435,7 +1435,7 @@ func (p *Parser) parseAlterProvider() (*Command, error) { return cmd, nil } -// parseCreateProviderInstance parses CREATE PROVIDER INSTANCE KEY URL command +// parseCreateProviderInstance parses CREATE PROVIDER INSTANCE KEY URL REGION command // instance_name cannot be "default" func (p *Parser) parseCreateProviderInstance() (*Command, error) { p.nextToken() // consume PROVIDER @@ -1444,8 +1444,8 @@ func (p *Parser) parseCreateProviderInstance() (*Command, error) { if err != nil { return nil, fmt.Errorf("expected provider name: %w", err) } - p.nextToken() + if p.curToken.Type != TokenInstance { return nil, fmt.Errorf("expected INSTANCE after provider name") } @@ -1455,7 +1455,6 @@ func (p *Parser) parseCreateProviderInstance() (*Command, error) { if err != nil { return nil, fmt.Errorf("expected instance name: %w", err) } - p.nextToken() if p.curToken.Type != TokenKey { @@ -1470,23 +1469,27 @@ func (p *Parser) parseCreateProviderInstance() (*Command, error) { p.nextToken() baseURL := "" - if p.curToken.Type == TokenURL { - p.nextToken() - baseURL, err = p.parseQuotedString() - if err != nil { - return nil, fmt.Errorf("expected base URL: %w", err) - } - p.nextToken() - } - region := "" - if p.curToken.Type == TokenRegion { - p.nextToken() - region, err = p.parseQuotedString() - if err != nil { - return nil, fmt.Errorf("expected base URL: %w", err) +optionsLoop: + for { + switch p.curToken.Type { + case TokenRegion: + p.nextToken() + region, err = p.parseQuotedString() + if err != nil { + return nil, fmt.Errorf("expected region: %w", err) + } + p.nextToken() + case TokenURL: + p.nextToken() + baseURL, err = p.parseQuotedString() + if err != nil { + return nil, fmt.Errorf("expected base URL: %w", err) + } + p.nextToken() + default: + break optionsLoop } - p.nextToken() } cmd := NewCommand("create_provider_instance") @@ -1496,9 +1499,6 @@ func (p *Parser) parseCreateProviderInstance() (*Command, error) { if baseURL != "" { // Only local model provider need to set URL cmd.Params["base_url"] = baseURL - if region == "" { - region = instanceName - } } if region != "" { diff --git a/internal/dao/tenant_model_instance.go b/internal/dao/tenant_model_instance.go index d8d7f6026b..56590e9b6a 100644 --- a/internal/dao/tenant_model_instance.go +++ b/internal/dao/tenant_model_instance.go @@ -17,7 +17,11 @@ package dao import ( + "errors" + "fmt" "ragflow/internal/entity" + + "gorm.io/gorm" ) // TenantModelInstanceDAO tenant model instance data access object @@ -29,7 +33,23 @@ func NewTenantModelInstanceDAO() *TenantModelInstanceDAO { } func (dao *TenantModelInstanceDAO) Create(instance *entity.TenantModelInstance) error { - return DB.Create(instance).Error + // begin tx and check if the same provider instance exists + tx := DB.Begin() + defer tx.Rollback() + var existingInstance entity.TenantModelInstance + err := tx.Where("provider_id = ? AND instance_name = ?", instance.ProviderID, instance.InstanceName).First(&existingInstance).Error + if err == nil { + return fmt.Errorf("instance %s already exists", instance.InstanceName) + } + if !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + err = tx.Create(instance).Error + if err != nil { + return err + } + tx.Commit() + return nil } func (dao *TenantModelInstanceDAO) GetAllInstancesByProviderID(providerID string) ([]*entity.TenantModelInstance, error) { diff --git a/internal/entity/models/gitee.go b/internal/entity/models/gitee.go index 22e7ccbf73..e6c468c164 100644 --- a/internal/entity/models/gitee.go +++ b/internal/entity/models/gitee.go @@ -605,8 +605,8 @@ func (g *GiteeModel) TranscribeAudio(modelName *string, file *string, apiConfig return nil, fmt.Errorf("%s, no such method", g.Name()) } -func (z *GiteeModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { - return fmt.Errorf("%s, no such method", z.Name()) +func (g *GiteeModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", g.Name()) } // AudioSpeech convert text to audio @@ -614,8 +614,8 @@ func (g *GiteeModel) AudioSpeech(modelName *string, audioContent *string, apiCon return nil, fmt.Errorf("%s, no such method", g.Name()) } -func (z *GiteeModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { - return fmt.Errorf("%s, no such method", z.Name()) +func (g *GiteeModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", g.Name()) } type giteeOCRResponse struct { @@ -910,7 +910,6 @@ func (g *GiteeModel) ListModels(apiConfig *APIConfig) ([]string, error) { } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) resp, err := g.httpClient.Do(req) if err != nil { @@ -946,12 +945,21 @@ func (g *GiteeModel) ListModels(apiConfig *APIConfig) ([]string, error) { } func (g *GiteeModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { - var region = "default" - if apiConfig.Region != nil && *apiConfig.Region != "" { - region = *apiConfig.Region + + var baseURL = "" + if apiConfig.BaseURL != nil && *apiConfig.BaseURL != "" { + baseURL = *apiConfig.BaseURL } - url := fmt.Sprintf("%s/%s", g.BaseURL[region], g.URLSuffix.Balance) + if baseURL == "" { + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + baseURL = g.BaseURL[region] + } + + url := fmt.Sprintf("%s/%s", baseURL, g.URLSuffix.Balance) // Build request body reqBody := map[string]interface{}{} diff --git a/internal/entity/models/types.go b/internal/entity/models/types.go index 1253485104..7963910e43 100644 --- a/internal/entity/models/types.go +++ b/internal/entity/models/types.go @@ -130,8 +130,9 @@ type ChatConfig struct { } type APIConfig struct { - ApiKey *string - Region *string + ApiKey *string + Region *string + BaseURL *string } type EmbeddingConfig struct { diff --git a/internal/service/model_service.go b/internal/service/model_service.go index cee461b59d..ce618978bd 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -422,8 +422,10 @@ func (m *ModelProviderService) ShowInstanceBalance(providerName, instanceName, u } region := extra["region"] + baseURL := extra["base_url"] apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey + apiConfig.BaseURL = &baseURL var result map[string]interface{} result, err = providerInfo.ModelDriver.Balance(apiConfig) diff --git a/test/unit_test/api/apps/restful_apis/test_openai_stream_no_duplicate.py b/test/unit_test/api/apps/restful_apis/test_openai_stream_no_duplicate.py index 8ab2e37d01..410dd61044 100644 --- a/test/unit_test/api/apps/restful_apis/test_openai_stream_no_duplicate.py +++ b/test/unit_test/api/apps/restful_apis/test_openai_stream_no_duplicate.py @@ -53,6 +53,14 @@ def _load_openai_api(monkeypatch): """Load api/apps/restful_apis/openai_api.py with the heavy deps stubbed.""" _stub(monkeypatch, "quart", Response=object, jsonify=lambda *a, **k: None) _stub(monkeypatch, "api.apps", current_user=SimpleNamespace(id="tenant-1"), login_required=lambda func: func) + # Pre-register nested modules so importlib finds them directly in + # sys.modules without trying to traverse the stubbed parent package. + _stub( + monkeypatch, + "api.apps.restful_apis._generation_params", + extract_generation_config=lambda *a, **k: ({}, {}), + merge_generation_config=lambda *a, **k: None, + ) _stub(monkeypatch, "api.db.services.dialog_service", DialogService=SimpleNamespace(), async_chat=lambda *_a, **_k: None) _stub(monkeypatch, "api.db.services.doc_metadata_service", DocMetadataService=SimpleNamespace()) _stub(