mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
fix: validate custom model inputs (#15200)
### What problem does this PR solve? Closes #15199. The add-custom-model endpoint is routed through `/api/v1/providers/:provider_name/instances/:instance_name/models`, but the handler previously trusted `provider_name` and `instance_name` from the JSON body instead of the path target. A request could therefore hit one provider/instance URL while operating on a different body provider/instance. The same handler only rejected `model_types` when the slice was nil. An empty array passed validation and reached `ModelProviderService.AddCustomModel`, where `request.ModelTypes[0]` could panic. This PR makes the path provider/instance authoritative, rejects mismatched body values, rejects missing or empty `model_types`, and adds a service-level guard so direct service callers cannot hit the same panic path. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"ragflow/internal/common"
|
||||
@@ -723,45 +724,51 @@ func (h *ProviderHandler) EnableOrDisableModel(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func prepareAddCustomModelRequest(req *service.AddCustomModelRequest, providerName, instanceName string) error {
|
||||
if providerName == "" {
|
||||
return errors.New("Provider name is required")
|
||||
}
|
||||
|
||||
if instanceName == "" {
|
||||
return errors.New("Instance name is required")
|
||||
}
|
||||
|
||||
if req.ProviderName != "" && !strings.EqualFold(req.ProviderName, providerName) {
|
||||
return errors.New("Provider name does not match path")
|
||||
}
|
||||
|
||||
if req.InstanceName != "" && !strings.EqualFold(req.InstanceName, instanceName) {
|
||||
return errors.New("Instance name does not match path")
|
||||
}
|
||||
|
||||
if req.ModelName == "" {
|
||||
return errors.New("Model name is required")
|
||||
}
|
||||
|
||||
if len(req.ModelTypes) == 0 {
|
||||
return errors.New("Model type is required")
|
||||
}
|
||||
|
||||
req.ProviderName = providerName
|
||||
req.InstanceName = instanceName
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *ProviderHandler) AddCustomModel(c *gin.Context) {
|
||||
var req service.AddCustomModelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
println("JSON bind error: %v (type: %T)", err, err)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": common.CodeBadRequest,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.ProviderName == "" {
|
||||
if err := prepareAddCustomModelRequest(&req, c.Param("provider_name"), c.Param("instance_name")); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Provider name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.InstanceName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Instance name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model name is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.ModelTypes == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "Model type is required",
|
||||
"code": common.CodeBadRequest,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
111
internal/handler/providers_test.go
Normal file
111
internal/handler/providers_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"ragflow/internal/service"
|
||||
)
|
||||
|
||||
func TestPrepareAddCustomModelRequestUsesPathTarget(t *testing.T) {
|
||||
req := service.AddCustomModelRequest{
|
||||
ModelName: "custom-chat",
|
||||
ModelTypes: []string{"chat"},
|
||||
}
|
||||
|
||||
if err := prepareAddCustomModelRequest(&req, "openai", "default"); err != nil {
|
||||
t.Fatalf("prepareAddCustomModelRequest returned error: %v", err)
|
||||
}
|
||||
if req.ProviderName != "openai" {
|
||||
t.Fatalf("expected provider name from path, got %q", req.ProviderName)
|
||||
}
|
||||
if req.InstanceName != "default" {
|
||||
t.Fatalf("expected instance name from path, got %q", req.InstanceName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareAddCustomModelRequestAcceptsCaseInsensitivePathMatch(t *testing.T) {
|
||||
req := service.AddCustomModelRequest{
|
||||
ProviderName: "openai",
|
||||
InstanceName: "default",
|
||||
ModelName: "custom-chat",
|
||||
ModelTypes: []string{"chat"},
|
||||
}
|
||||
|
||||
if err := prepareAddCustomModelRequest(&req, "OpenAI", "Default"); err != nil {
|
||||
t.Fatalf("prepareAddCustomModelRequest returned error: %v", err)
|
||||
}
|
||||
if req.ProviderName != "OpenAI" {
|
||||
t.Fatalf("expected provider name from path, got %q", req.ProviderName)
|
||||
}
|
||||
if req.InstanceName != "Default" {
|
||||
t.Fatalf("expected instance name from path, got %q", req.InstanceName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareAddCustomModelRequestRejectsPathMismatches(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req service.AddCustomModelRequest
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
name: "provider",
|
||||
req: service.AddCustomModelRequest{
|
||||
ProviderName: "deepseek",
|
||||
InstanceName: "default",
|
||||
ModelName: "custom-chat",
|
||||
ModelTypes: []string{"chat"},
|
||||
},
|
||||
expectedErr: "Provider name does not match path",
|
||||
},
|
||||
{
|
||||
name: "instance",
|
||||
req: service.AddCustomModelRequest{
|
||||
ProviderName: "openai",
|
||||
InstanceName: "other",
|
||||
ModelName: "custom-chat",
|
||||
ModelTypes: []string{"chat"},
|
||||
},
|
||||
expectedErr: "Instance name does not match path",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := prepareAddCustomModelRequest(&tt.req, "openai", "default")
|
||||
if err == nil {
|
||||
t.Fatal("expected mismatch error")
|
||||
}
|
||||
if err.Error() != tt.expectedErr {
|
||||
t.Fatalf("expected %q, got %q", tt.expectedErr, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareAddCustomModelRequestRejectsEmptyModelTypes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelTypes []string
|
||||
}{
|
||||
{name: "nil"},
|
||||
{name: "empty", modelTypes: []string{}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := service.AddCustomModelRequest{
|
||||
ModelName: "custom-chat",
|
||||
ModelTypes: tt.modelTypes,
|
||||
}
|
||||
|
||||
err := prepareAddCustomModelRequest(&req, "openai", "default")
|
||||
if err == nil {
|
||||
t.Fatal("expected empty model_types to return an error")
|
||||
}
|
||||
if err.Error() != "Model type is required" {
|
||||
t.Fatalf("expected model type error, got %q", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1915,6 +1915,13 @@ type AddCustomModelRequest struct {
|
||||
}
|
||||
|
||||
func (m *ModelProviderService) AddCustomModel(request *AddCustomModelRequest, userID string) (common.ErrorCode, error) {
|
||||
if request == nil {
|
||||
return common.CodeBadRequest, errors.New("request is required")
|
||||
}
|
||||
if len(request.ModelTypes) == 0 {
|
||||
return common.CodeBadRequest, errors.New("model type is required")
|
||||
}
|
||||
|
||||
// Get tenant ID from user
|
||||
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
|
||||
if err != nil {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ragflow/internal/common"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
@@ -127,3 +128,45 @@ func TestNewModelDriverForBaseURLRejectsNilDriver(t *testing.T) {
|
||||
t.Fatalf("expected driver not found error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddCustomModelRejectsNilRequest(t *testing.T) {
|
||||
service := &ModelProviderService{}
|
||||
|
||||
code, err := service.AddCustomModel(nil, "user-id")
|
||||
if err == nil {
|
||||
t.Fatal("expected nil request to return an error")
|
||||
}
|
||||
if code != common.CodeBadRequest {
|
||||
t.Fatalf("expected bad request code, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddCustomModelRejectsEmptyModelTypes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelTypes []string
|
||||
}{
|
||||
{name: "nil"},
|
||||
{name: "empty", modelTypes: []string{}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
service := &ModelProviderService{}
|
||||
req := &AddCustomModelRequest{
|
||||
ProviderName: "openai",
|
||||
InstanceName: "default",
|
||||
ModelName: "custom-chat",
|
||||
ModelTypes: tt.modelTypes,
|
||||
}
|
||||
|
||||
code, err := service.AddCustomModel(req, "user-id")
|
||||
if err == nil {
|
||||
t.Fatal("expected empty model_types to return an error")
|
||||
}
|
||||
if code != common.CodeBadRequest {
|
||||
t.Fatalf("expected bad request code, got %v", code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user