fix: validate custom model inputs (#15200)

### What problem does this PR solve?

Closes #15199.

The add-custom-model endpoint is routed through
`/api/v1/providers/:provider_name/instances/:instance_name/models`, but
the handler previously trusted `provider_name` and `instance_name` from
the JSON body instead of the path target. A request could therefore hit
one provider/instance URL while operating on a different body
provider/instance.

The same handler only rejected `model_types` when the slice was nil. An
empty array passed validation and reached
`ModelProviderService.AddCustomModel`, where `request.ModelTypes[0]`
could panic.

This PR makes the path provider/instance authoritative, rejects
mismatched body values, rejects missing or empty `model_types`, and adds
a service-level guard so direct service callers cannot hit the same
panic path.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
bitloi
2026-05-28 23:15:01 -03:00
committed by GitHub
parent 550bdf215c
commit ea3a5dba11
4 changed files with 196 additions and 28 deletions

View File

@@ -17,6 +17,7 @@
package handler
import (
"errors"
"fmt"
"net/http"
"ragflow/internal/common"
@@ -723,45 +724,51 @@ func (h *ProviderHandler) EnableOrDisableModel(c *gin.Context) {
})
}
func prepareAddCustomModelRequest(req *service.AddCustomModelRequest, providerName, instanceName string) error {
if providerName == "" {
return errors.New("Provider name is required")
}
if instanceName == "" {
return errors.New("Instance name is required")
}
if req.ProviderName != "" && !strings.EqualFold(req.ProviderName, providerName) {
return errors.New("Provider name does not match path")
}
if req.InstanceName != "" && !strings.EqualFold(req.InstanceName, instanceName) {
return errors.New("Instance name does not match path")
}
if req.ModelName == "" {
return errors.New("Model name is required")
}
if len(req.ModelTypes) == 0 {
return errors.New("Model type is required")
}
req.ProviderName = providerName
req.InstanceName = instanceName
return nil
}
func (h *ProviderHandler) AddCustomModel(c *gin.Context) {
var req service.AddCustomModelRequest
if err := c.ShouldBindJSON(&req); err != nil {
println("JSON bind error: %v (type: %T)", err, err)
c.JSON(http.StatusOK, gin.H{
c.JSON(http.StatusBadRequest, gin.H{
"code": common.CodeBadRequest,
"message": err.Error(),
})
return
}
if req.ProviderName == "" {
if err := prepareAddCustomModelRequest(&req, c.Param("provider_name"), c.Param("instance_name")); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"code": 400,
"message": "Provider name is required",
})
return
}
if req.InstanceName == "" {
c.JSON(http.StatusBadRequest, gin.H{
"code": 400,
"message": "Instance name is required",
})
return
}
if req.ModelName == "" {
c.JSON(http.StatusBadRequest, gin.H{
"code": 400,
"message": "Model name is required",
})
return
}
if req.ModelTypes == nil {
c.JSON(http.StatusBadRequest, gin.H{
"code": 400,
"message": "Model type is required",
"code": common.CodeBadRequest,
"message": err.Error(),
})
return
}

View File

@@ -0,0 +1,111 @@
package handler
import (
"testing"
"ragflow/internal/service"
)
func TestPrepareAddCustomModelRequestUsesPathTarget(t *testing.T) {
req := service.AddCustomModelRequest{
ModelName: "custom-chat",
ModelTypes: []string{"chat"},
}
if err := prepareAddCustomModelRequest(&req, "openai", "default"); err != nil {
t.Fatalf("prepareAddCustomModelRequest returned error: %v", err)
}
if req.ProviderName != "openai" {
t.Fatalf("expected provider name from path, got %q", req.ProviderName)
}
if req.InstanceName != "default" {
t.Fatalf("expected instance name from path, got %q", req.InstanceName)
}
}
func TestPrepareAddCustomModelRequestAcceptsCaseInsensitivePathMatch(t *testing.T) {
req := service.AddCustomModelRequest{
ProviderName: "openai",
InstanceName: "default",
ModelName: "custom-chat",
ModelTypes: []string{"chat"},
}
if err := prepareAddCustomModelRequest(&req, "OpenAI", "Default"); err != nil {
t.Fatalf("prepareAddCustomModelRequest returned error: %v", err)
}
if req.ProviderName != "OpenAI" {
t.Fatalf("expected provider name from path, got %q", req.ProviderName)
}
if req.InstanceName != "Default" {
t.Fatalf("expected instance name from path, got %q", req.InstanceName)
}
}
func TestPrepareAddCustomModelRequestRejectsPathMismatches(t *testing.T) {
tests := []struct {
name string
req service.AddCustomModelRequest
expectedErr string
}{
{
name: "provider",
req: service.AddCustomModelRequest{
ProviderName: "deepseek",
InstanceName: "default",
ModelName: "custom-chat",
ModelTypes: []string{"chat"},
},
expectedErr: "Provider name does not match path",
},
{
name: "instance",
req: service.AddCustomModelRequest{
ProviderName: "openai",
InstanceName: "other",
ModelName: "custom-chat",
ModelTypes: []string{"chat"},
},
expectedErr: "Instance name does not match path",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := prepareAddCustomModelRequest(&tt.req, "openai", "default")
if err == nil {
t.Fatal("expected mismatch error")
}
if err.Error() != tt.expectedErr {
t.Fatalf("expected %q, got %q", tt.expectedErr, err.Error())
}
})
}
}
func TestPrepareAddCustomModelRequestRejectsEmptyModelTypes(t *testing.T) {
tests := []struct {
name string
modelTypes []string
}{
{name: "nil"},
{name: "empty", modelTypes: []string{}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := service.AddCustomModelRequest{
ModelName: "custom-chat",
ModelTypes: tt.modelTypes,
}
err := prepareAddCustomModelRequest(&req, "openai", "default")
if err == nil {
t.Fatal("expected empty model_types to return an error")
}
if err.Error() != "Model type is required" {
t.Fatalf("expected model type error, got %q", err.Error())
}
})
}
}

View File

@@ -1915,6 +1915,13 @@ type AddCustomModelRequest struct {
}
func (m *ModelProviderService) AddCustomModel(request *AddCustomModelRequest, userID string) (common.ErrorCode, error) {
if request == nil {
return common.CodeBadRequest, errors.New("request is required")
}
if len(request.ModelTypes) == 0 {
return common.CodeBadRequest, errors.New("model type is required")
}
// Get tenant ID from user
tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner")
if err != nil {

View File

@@ -4,6 +4,7 @@ import (
"strings"
"testing"
"ragflow/internal/common"
modelModule "ragflow/internal/entity/models"
)
@@ -127,3 +128,45 @@ func TestNewModelDriverForBaseURLRejectsNilDriver(t *testing.T) {
t.Fatalf("expected driver not found error, got %v", err)
}
}
func TestAddCustomModelRejectsNilRequest(t *testing.T) {
service := &ModelProviderService{}
code, err := service.AddCustomModel(nil, "user-id")
if err == nil {
t.Fatal("expected nil request to return an error")
}
if code != common.CodeBadRequest {
t.Fatalf("expected bad request code, got %v", code)
}
}
func TestAddCustomModelRejectsEmptyModelTypes(t *testing.T) {
tests := []struct {
name string
modelTypes []string
}{
{name: "nil"},
{name: "empty", modelTypes: []string{}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
service := &ModelProviderService{}
req := &AddCustomModelRequest{
ProviderName: "openai",
InstanceName: "default",
ModelName: "custom-chat",
ModelTypes: tt.modelTypes,
}
code, err := service.AddCustomModel(req, "user-id")
if err == nil {
t.Fatal("expected empty model_types to return an error")
}
if code != common.CodeBadRequest {
t.Fatalf("expected bad request code, got %v", code)
}
})
}
}