mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
fix(go-models): validate URL suffix config keys (#15734)
## Summary Fixes typoed model-provider URL suffix keys and adds strict nested decoding so future URL suffix config mistakes fail during provider loading instead of being silently ignored.
This commit is contained in:
@@ -6,7 +6,7 @@
|
||||
"url_suffix": {
|
||||
"chat": "v2/chat",
|
||||
"models": "v1/models",
|
||||
"embeddings": "v2/embed",
|
||||
"embedding": "v2/embed",
|
||||
"rerank": "v2/rerank",
|
||||
"asr": "audio/transcriptions"
|
||||
},
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"chat": "chat/completions",
|
||||
"models": "models",
|
||||
"tts": "tts",
|
||||
"ast": "stt"
|
||||
"asr": "stt"
|
||||
},
|
||||
"class": "grok",
|
||||
"models": [
|
||||
@@ -61,4 +61,3 @@
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
package entity
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -186,6 +187,31 @@ type ModelResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func decodeProviderConfig(data []byte) (Provider, error) {
|
||||
var provider Provider
|
||||
if err := json.Unmarshal(data, &provider); err != nil {
|
||||
return Provider{}, err
|
||||
}
|
||||
|
||||
var rawProvider struct {
|
||||
URLSuffix json.RawMessage `json:"url_suffix"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &rawProvider); err != nil {
|
||||
return Provider{}, err
|
||||
}
|
||||
if len(rawProvider.URLSuffix) == 0 {
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(bytes.NewReader(rawProvider.URLSuffix))
|
||||
decoder.DisallowUnknownFields()
|
||||
if err := decoder.Decode(&provider.URLSuffix); err != nil {
|
||||
return Provider{}, err
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// NewProviderManager creates a new ProviderManager by reading all JSON files from a directory
|
||||
func NewProviderManager(dirPath string) (*ProviderManager, error) {
|
||||
providers := []Provider{}
|
||||
@@ -222,7 +248,7 @@ func NewProviderManager(dirPath string) (*ProviderManager, error) {
|
||||
|
||||
// Parse JSON
|
||||
var provider Provider
|
||||
if err = json.Unmarshal(data, &provider); err != nil {
|
||||
if provider, err = decodeProviderConfig(data); err != nil {
|
||||
return nil, fmt.Errorf("error parsing JSON from file %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
modeldrivers "ragflow/internal/entity/models"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -123,6 +124,71 @@ func TestLocalOCRProviderConfigsLoadLocalDrivers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderConfigsLoadURLSuffixKeys(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
for _, fileName := range []string{"cohere.json", "xai.json"} {
|
||||
if err := os.WriteFile(filepath.Join(dir, fileName), readProviderConfig(t, fileName), 0o600); err != nil {
|
||||
t.Fatalf("write %s config: %v", fileName, err)
|
||||
}
|
||||
}
|
||||
|
||||
pm, err := NewProviderManager(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProviderManager: %v", err)
|
||||
}
|
||||
|
||||
cohere := pm.FindProvider("CoHere")
|
||||
if cohere == nil {
|
||||
t.Fatal("CoHere provider not found")
|
||||
}
|
||||
if cohere.URLSuffix.Embedding != "v2/embed" {
|
||||
t.Errorf("CoHere embedding suffix=%q", cohere.URLSuffix.Embedding)
|
||||
}
|
||||
|
||||
xAI := pm.FindProvider("xAI")
|
||||
if xAI == nil {
|
||||
t.Fatal("xAI provider not found")
|
||||
}
|
||||
if xAI.URLSuffix.ASR != "stt" {
|
||||
t.Errorf("xAI ASR suffix=%q", xAI.URLSuffix.ASR)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderConfigRejectsUnknownURLSuffixKey(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
config := []byte(`{
|
||||
"name": "OpenAI",
|
||||
"url": {
|
||||
"default": "https://example.com"
|
||||
},
|
||||
"url_suffix": {
|
||||
"chat": "chat/completions",
|
||||
"unknown_suffix": "ignored"
|
||||
},
|
||||
"models": [
|
||||
{
|
||||
"name": "test-model",
|
||||
"max_tokens": 4096,
|
||||
"model_types": ["chat"]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
if err := os.WriteFile(filepath.Join(dir, "unknown_suffix.json"), config, 0o600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
_, err := NewProviderManager(dir)
|
||||
if err == nil {
|
||||
t.Fatal("NewProviderManager succeeded with unknown url_suffix key")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `unknown field "unknown_suffix"`) {
|
||||
t.Fatalf("error=%q, want unknown_suffix field", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown_suffix.json") {
|
||||
t.Fatalf("error=%q, want config file context", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPPIOProviderConfigLoadsIntoProviderManager(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "ppio.json"), readPPIOProviderConfig(t), 0o600); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user