diff --git a/conf/models/cohere.json b/conf/models/cohere.json index 9d3ae1fe7b..1bc484ebfb 100644 --- a/conf/models/cohere.json +++ b/conf/models/cohere.json @@ -6,7 +6,7 @@ "url_suffix": { "chat": "v2/chat", "models": "v1/models", - "embeddings": "v2/embed", + "embedding": "v2/embed", "rerank": "v2/rerank", "asr": "audio/transcriptions" }, diff --git a/conf/models/xai.json b/conf/models/xai.json index cace7f06bd..e272b0985d 100644 --- a/conf/models/xai.json +++ b/conf/models/xai.json @@ -7,7 +7,7 @@ "chat": "chat/completions", "models": "models", "tts": "tts", - "ast": "stt" + "asr": "stt" }, "class": "grok", "models": [ @@ -61,4 +61,3 @@ } ] } - diff --git a/internal/entity/model.go b/internal/entity/model.go index a0b69f7020..db888c67c0 100644 --- a/internal/entity/model.go +++ b/internal/entity/model.go @@ -17,6 +17,7 @@ package entity import ( + "bytes" "encoding/json" "fmt" "os" @@ -186,6 +187,31 @@ type ModelResponse struct { Message string `json:"message"` } +func decodeProviderConfig(data []byte) (Provider, error) { + var provider Provider + if err := json.Unmarshal(data, &provider); err != nil { + return Provider{}, err + } + + var rawProvider struct { + URLSuffix json.RawMessage `json:"url_suffix"` + } + if err := json.Unmarshal(data, &rawProvider); err != nil { + return Provider{}, err + } + if len(rawProvider.URLSuffix) == 0 { + return provider, nil + } + + decoder := json.NewDecoder(bytes.NewReader(rawProvider.URLSuffix)) + decoder.DisallowUnknownFields() + if err := decoder.Decode(&provider.URLSuffix); err != nil { + return Provider{}, err + } + + return provider, nil +} + // NewProviderManager creates a new ProviderManager by reading all JSON files from a directory func NewProviderManager(dirPath string) (*ProviderManager, error) { providers := []Provider{} @@ -222,7 +248,7 @@ func NewProviderManager(dirPath string) (*ProviderManager, error) { // Parse JSON var provider Provider - if err = json.Unmarshal(data, &provider); err != nil { + if provider, err = decodeProviderConfig(data); err != nil { return nil, fmt.Errorf("error parsing JSON from file %s: %w", filePath, err) } diff --git a/internal/entity/model_test.go b/internal/entity/model_test.go index 591e9943df..90c965b4d4 100644 --- a/internal/entity/model_test.go +++ b/internal/entity/model_test.go @@ -20,6 +20,7 @@ import ( "os" "path/filepath" modeldrivers "ragflow/internal/entity/models" + "strings" "testing" ) @@ -123,6 +124,71 @@ func TestLocalOCRProviderConfigsLoadLocalDrivers(t *testing.T) { } } +func TestProviderConfigsLoadURLSuffixKeys(t *testing.T) { + dir := t.TempDir() + for _, fileName := range []string{"cohere.json", "xai.json"} { + if err := os.WriteFile(filepath.Join(dir, fileName), readProviderConfig(t, fileName), 0o600); err != nil { + t.Fatalf("write %s config: %v", fileName, err) + } + } + + pm, err := NewProviderManager(dir) + if err != nil { + t.Fatalf("NewProviderManager: %v", err) + } + + cohere := pm.FindProvider("CoHere") + if cohere == nil { + t.Fatal("CoHere provider not found") + } + if cohere.URLSuffix.Embedding != "v2/embed" { + t.Errorf("CoHere embedding suffix=%q", cohere.URLSuffix.Embedding) + } + + xAI := pm.FindProvider("xAI") + if xAI == nil { + t.Fatal("xAI provider not found") + } + if xAI.URLSuffix.ASR != "stt" { + t.Errorf("xAI ASR suffix=%q", xAI.URLSuffix.ASR) + } +} + +func TestProviderConfigRejectsUnknownURLSuffixKey(t *testing.T) { + dir := t.TempDir() + config := []byte(`{ + "name": "OpenAI", + "url": { + "default": "https://example.com" + }, + "url_suffix": { + "chat": "chat/completions", + "unknown_suffix": "ignored" + }, + "models": [ + { + "name": "test-model", + "max_tokens": 4096, + "model_types": ["chat"] + } + ] +}`) + if err := os.WriteFile(filepath.Join(dir, "unknown_suffix.json"), config, 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + _, err := NewProviderManager(dir) + if err == nil { + t.Fatal("NewProviderManager succeeded with unknown url_suffix key") + } + if !strings.Contains(err.Error(), `unknown field "unknown_suffix"`) { + t.Fatalf("error=%q, want unknown_suffix field", err) + } + if !strings.Contains(err.Error(), "unknown_suffix.json") { + t.Fatalf("error=%q, want config file context", err) + } +} + func TestPPIOProviderConfigLoadsIntoProviderManager(t *testing.T) { dir := t.TempDir() if err := os.WriteFile(filepath.Join(dir, "ppio.json"), readPPIOProviderConfig(t), 0o600); err != nil {