mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
What problem does this PR solve? It adds four new models: deepseek-ai/DeepSeek-V4-Pro, deepseek-ai/DeepSeek-V4-Flash, Pro/moonshotai/Kimi-K2.6, and Pro/zai-org/GLM-5.1. Type of change: [x] New Feature (non-breaking change which adds functionality).
274 lines
8.5 KiB
Go
274 lines
8.5 KiB
Go
//
|
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
|
|
package entity
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
modeldrivers "ragflow/internal/entity/models"
|
|
"testing"
|
|
)
|
|
|
|
// readProviderConfig returns the contents of conf/models/<fileName>.
//
// Two locations relative to the working directory are tried in order:
// ../../conf/models (presumably the repository root when the tests run from
// this package's directory — confirm if the package moves) and
// ./conf/models. The first candidate that can be read is returned.
//
// If no candidate can be read, the test is aborted. The failure message lists
// every path tried and the last read error, so a permission or I/O problem is
// distinguishable from a genuinely missing file.
func readProviderConfig(t *testing.T, fileName string) []byte {
	t.Helper()

	candidates := []string{
		filepath.Join("..", "..", "conf", "models", fileName),
		filepath.Join("conf", "models", fileName),
	}

	var lastErr error
	for _, candidate := range candidates {
		data, err := os.ReadFile(candidate)
		if err == nil {
			return data
		}
		// Remember why this candidate failed so the final message is actionable.
		lastErr = err
	}

	t.Fatalf("could not locate conf/models/%s (tried %q): %v", fileName, candidates, lastErr)
	return nil
}
|
|
|
|
func readPPIOProviderConfig(t *testing.T) []byte {
|
|
t.Helper()
|
|
return readProviderConfig(t, "ppio.json")
|
|
}
|
|
|
|
func TestHostedProviderConfigsLoadSharedDrivers(t *testing.T) {
|
|
dir := t.TempDir()
|
|
for _, fileName := range []string{"mineru.json", "paddleocr.json"} {
|
|
if err := os.WriteFile(filepath.Join(dir, fileName), readProviderConfig(t, fileName), 0o600); err != nil {
|
|
t.Fatalf("write %s config: %v", fileName, err)
|
|
}
|
|
}
|
|
|
|
pm, err := NewProviderManager(dir)
|
|
if err != nil {
|
|
t.Fatalf("NewProviderManager: %v", err)
|
|
}
|
|
|
|
minerU := pm.FindProvider("MinerU.Net")
|
|
if minerU == nil {
|
|
t.Fatal("MinerU.Net provider not found")
|
|
}
|
|
if _, ok := minerU.ModelDriver.(*modeldrivers.MinerUModel); !ok {
|
|
t.Fatalf("MinerU.Net ModelDriver=%T, want *models.MinerUModel", minerU.ModelDriver)
|
|
}
|
|
if minerU.Class != "mineru.net" {
|
|
t.Errorf("MinerU.Net class=%q", minerU.Class)
|
|
}
|
|
if minerU.URLSuffix.DocumentParse != "v4/extract/task" {
|
|
t.Errorf("MinerU.Net doc_parse suffix=%q", minerU.URLSuffix.DocumentParse)
|
|
}
|
|
|
|
paddleOCR := pm.FindProvider("PaddleOCR.Net")
|
|
if paddleOCR == nil {
|
|
t.Fatal("PaddleOCR.Net provider not found")
|
|
}
|
|
if _, ok := paddleOCR.ModelDriver.(*modeldrivers.PaddleOCRModel); !ok {
|
|
t.Fatalf("PaddleOCR.Net ModelDriver=%T, want *models.PaddleOCRModel", paddleOCR.ModelDriver)
|
|
}
|
|
if paddleOCR.Class != "paddleocr.net" {
|
|
t.Errorf("PaddleOCR.Net class=%q", paddleOCR.Class)
|
|
}
|
|
if paddleOCR.URLSuffix.OCR != "v2/ocr/jobs" {
|
|
t.Errorf("PaddleOCR.Net OCR suffix=%q", paddleOCR.URLSuffix.OCR)
|
|
}
|
|
}
|
|
|
|
func TestLocalOCRProviderConfigsLoadLocalDrivers(t *testing.T) {
|
|
dir := t.TempDir()
|
|
for _, fileName := range []string{"mineru_local.json", "paddleocr_local.json"} {
|
|
if err := os.WriteFile(filepath.Join(dir, fileName), readProviderConfig(t, fileName), 0o600); err != nil {
|
|
t.Fatalf("write %s config: %v", fileName, err)
|
|
}
|
|
}
|
|
|
|
pm, err := NewProviderManager(dir)
|
|
if err != nil {
|
|
t.Fatalf("NewProviderManager: %v", err)
|
|
}
|
|
|
|
minerU := pm.FindProvider("MinerU")
|
|
if minerU == nil {
|
|
t.Fatal("MinerU provider not found")
|
|
}
|
|
if _, ok := minerU.ModelDriver.(*modeldrivers.MinerULocalModel); !ok {
|
|
t.Fatalf("MinerU ModelDriver=%T, want *models.MinerULocalModel", minerU.ModelDriver)
|
|
}
|
|
if minerU.URLSuffix.DocumentParse != "file_parse" {
|
|
t.Errorf("MinerU doc_parse suffix=%q", minerU.URLSuffix.DocumentParse)
|
|
}
|
|
|
|
paddleOCR := pm.FindProvider("PaddleOCR")
|
|
if paddleOCR == nil {
|
|
t.Fatal("PaddleOCR provider not found")
|
|
}
|
|
if _, ok := paddleOCR.ModelDriver.(*modeldrivers.PaddleOCRLocalModel); !ok {
|
|
t.Fatalf("PaddleOCR ModelDriver=%T, want *models.PaddleOCRLocalModel", paddleOCR.ModelDriver)
|
|
}
|
|
if paddleOCR.URLSuffix.OCR != "layout-parsing" {
|
|
t.Errorf("PaddleOCR OCR suffix=%q", paddleOCR.URLSuffix.OCR)
|
|
}
|
|
}
|
|
|
|
func TestPPIOProviderConfigLoadsIntoProviderManager(t *testing.T) {
|
|
dir := t.TempDir()
|
|
if err := os.WriteFile(filepath.Join(dir, "ppio.json"), readPPIOProviderConfig(t), 0o600); err != nil {
|
|
t.Fatalf("write ppio config: %v", err)
|
|
}
|
|
|
|
pm, err := NewProviderManager(dir)
|
|
if err != nil {
|
|
t.Fatalf("NewProviderManager: %v", err)
|
|
}
|
|
|
|
provider := pm.FindProvider("ppio")
|
|
if provider == nil {
|
|
t.Fatal("PPIO provider not found")
|
|
}
|
|
if provider.Name != "PPIO" {
|
|
t.Errorf("provider.Name=%q", provider.Name)
|
|
}
|
|
if provider.URL["default"] != "https://api.ppio.com/openai/v1" {
|
|
t.Errorf("default URL=%q", provider.URL["default"])
|
|
}
|
|
if provider.URL["us"] != "https://api.ppinfra.com/v3/openai" {
|
|
t.Errorf("us URL=%q", provider.URL["us"])
|
|
}
|
|
if provider.URLSuffix.Chat != "chat/completions" {
|
|
t.Errorf("chat suffix=%q", provider.URLSuffix.Chat)
|
|
}
|
|
if provider.URLSuffix.Models != "models" {
|
|
t.Errorf("models suffix=%q", provider.URLSuffix.Models)
|
|
}
|
|
if _, ok := provider.ModelDriver.(*modeldrivers.PPIOModel); !ok {
|
|
t.Fatalf("ModelDriver=%T, want *models.PPIOModel", provider.ModelDriver)
|
|
}
|
|
if provider.ModelDriver.Name() != "ppio" {
|
|
t.Errorf("ModelDriver.Name()=%q", provider.ModelDriver.Name())
|
|
}
|
|
if len(provider.Models) != 21 {
|
|
t.Fatalf("PPIO model count=%d, want 21", len(provider.Models))
|
|
}
|
|
for _, model := range provider.Models {
|
|
if !model.ModelTypeMap["chat"] {
|
|
t.Errorf("model %q missing chat type map", model.Name)
|
|
}
|
|
if model.Class == nil || *model.Class != "PPIO" {
|
|
t.Errorf("model %q class=%v", model.Name, model.Class)
|
|
}
|
|
}
|
|
|
|
models, err := pm.ListModels("PPIO")
|
|
if err != nil {
|
|
t.Fatalf("ListModels: %v", err)
|
|
}
|
|
if len(models) != 21 {
|
|
t.Errorf("ListModels count=%d, want 21", len(models))
|
|
}
|
|
|
|
model, err := pm.GetModelByName("ppio", "deepseek/deepseek-r1")
|
|
if err != nil {
|
|
t.Fatalf("GetModelByName: %v", err)
|
|
}
|
|
if model.MaxTokens != 64000 {
|
|
t.Errorf("deepseek/deepseek-r1 max_tokens=%d", model.MaxTokens)
|
|
}
|
|
model, err = pm.GetModelByName("ppio", "deepseek/deepseek-v4-pro")
|
|
if err != nil {
|
|
t.Fatalf("GetModelByName v4 pro: %v", err)
|
|
}
|
|
if model.MaxTokens != 1048576 {
|
|
t.Errorf("deepseek/deepseek-v4-pro max_tokens=%d", model.MaxTokens)
|
|
}
|
|
model, err = pm.GetModelByName("ppio", "deepseek/deepseek-v4-flash")
|
|
if err != nil {
|
|
t.Fatalf("GetModelByName v4 flash: %v", err)
|
|
}
|
|
if model.MaxTokens != 1048576 {
|
|
t.Errorf("deepseek/deepseek-v4-flash max_tokens=%d", model.MaxTokens)
|
|
}
|
|
|
|
resp := pm.SearchByType("chat")
|
|
if resp.Code != 0 {
|
|
t.Fatalf("SearchByType code=%d message=%q", resp.Code, resp.Message)
|
|
}
|
|
if len(resp.Data) != 21 {
|
|
t.Errorf("SearchByType data count=%d, want 21", len(resp.Data))
|
|
}
|
|
}
|
|
|
|
func TestSiliconFlowProviderConfigLoadsLatestProModels(t *testing.T) {
|
|
dir := t.TempDir()
|
|
if err := os.WriteFile(filepath.Join(dir, "siliconflow.json"), readProviderConfig(t, "siliconflow.json"), 0o600); err != nil {
|
|
t.Fatalf("write siliconflow config: %v", err)
|
|
}
|
|
|
|
pm, err := NewProviderManager(dir)
|
|
if err != nil {
|
|
t.Fatalf("NewProviderManager: %v", err)
|
|
}
|
|
|
|
provider := pm.FindProvider("SiliconFlow")
|
|
if provider == nil {
|
|
t.Fatal("SiliconFlow provider not found")
|
|
}
|
|
if provider.URL["default"] != "https://api.siliconflow.cn/v1" {
|
|
t.Errorf("default URL=%q", provider.URL["default"])
|
|
}
|
|
if provider.URLSuffix.Chat != "chat/completions" {
|
|
t.Errorf("chat suffix=%q", provider.URLSuffix.Chat)
|
|
}
|
|
if _, ok := provider.ModelDriver.(*modeldrivers.SiliconflowModel); !ok {
|
|
t.Fatalf("ModelDriver=%T, want *models.SiliconflowModel", provider.ModelDriver)
|
|
}
|
|
if provider.ModelDriver.Name() != "siliconflow" {
|
|
t.Errorf("ModelDriver.Name()=%q", provider.ModelDriver.Name())
|
|
}
|
|
if len(provider.Models) != 12 {
|
|
t.Fatalf("SiliconFlow model count=%d, want 12", len(provider.Models))
|
|
}
|
|
|
|
deepSeekV4Pro, err := pm.GetModelByName("SiliconFlow", "Pro/deepseek-ai/DeepSeek-V4-Pro")
|
|
if err != nil {
|
|
t.Fatalf("GetModelByName DeepSeek-V4-Pro: %v", err)
|
|
}
|
|
if deepSeekV4Pro.MaxTokens != 1048576 {
|
|
t.Errorf("DeepSeek-V4-Pro max_tokens=%d", deepSeekV4Pro.MaxTokens)
|
|
}
|
|
if !deepSeekV4Pro.ModelTypeMap["chat"] {
|
|
t.Errorf("DeepSeek-V4-Pro model types=%v, want chat", deepSeekV4Pro.ModelTypes)
|
|
}
|
|
|
|
kimiK26, err := pm.GetModelByName("SiliconFlow", "Pro/moonshotai/Kimi-K2.6")
|
|
if err != nil {
|
|
t.Fatalf("GetModelByName Kimi-K2.6: %v", err)
|
|
}
|
|
if kimiK26.MaxTokens != 262144 {
|
|
t.Errorf("Kimi-K2.6 max_tokens=%d", kimiK26.MaxTokens)
|
|
}
|
|
if !kimiK26.ModelTypeMap["chat"] || !kimiK26.ModelTypeMap["vision"] {
|
|
t.Errorf("Kimi-K2.6 model types=%v, want chat+vision", kimiK26.ModelTypes)
|
|
}
|
|
|
|
glm51, err := pm.GetModelByName("SiliconFlow", "Pro/zai-org/GLM-5.1")
|
|
if err != nil {
|
|
t.Fatalf("GetModelByName GLM-5.1: %v", err)
|
|
}
|
|
if glm51.MaxTokens != 204800 {
|
|
t.Errorf("GLM-5.1 max_tokens=%d", glm51.MaxTokens)
|
|
}
|
|
}
|