From ea3a5dba11d015a6cba5312f2cf4bc5ee46f0e9e Mon Sep 17 00:00:00 2001 From: bitloi <89318445+bitloi@users.noreply.github.com> Date: Thu, 28 May 2026 23:15:01 -0300 Subject: [PATCH] fix: validate custom model inputs (#15200) ### What problem does this PR solve? Closes #15199. The add-custom-model endpoint is routed through `/api/v1/providers/:provider_name/instances/:instance_name/models`, but the handler previously trusted `provider_name` and `instance_name` from the JSON body instead of the path target. A request could therefore hit one provider/instance URL while operating on a different body provider/instance. The same handler only rejected `model_types` when the slice was nil. An empty array passed validation and reached `ModelProviderService.AddCustomModel`, where `request.ModelTypes[0]` could panic. This PR makes the path provider/instance authoritative, rejects mismatched body values, rejects missing or empty `model_types`, and adds a service-level guard so direct service callers cannot hit the same panic path. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- internal/handler/providers.go | 63 +++++++------- internal/handler/providers_test.go | 111 +++++++++++++++++++++++++ internal/service/model_service.go | 7 ++ internal/service/model_service_test.go | 43 ++++++++++ 4 files changed, 196 insertions(+), 28 deletions(-) create mode 100644 internal/handler/providers_test.go diff --git a/internal/handler/providers.go b/internal/handler/providers.go index 9f7b238fc6..9c74250500 100644 --- a/internal/handler/providers.go +++ b/internal/handler/providers.go @@ -17,6 +17,7 @@ package handler import ( + "errors" "fmt" "net/http" "ragflow/internal/common" @@ -723,45 +724,51 @@ func (h *ProviderHandler) EnableOrDisableModel(c *gin.Context) { }) } +func prepareAddCustomModelRequest(req *service.AddCustomModelRequest, providerName, instanceName string) error { + if providerName == "" { + return errors.New("Provider name is required") + } + + if instanceName == "" { + return errors.New("Instance name is required") + } + + if req.ProviderName != "" && !strings.EqualFold(req.ProviderName, providerName) { + return errors.New("Provider name does not match path") + } + + if req.InstanceName != "" && !strings.EqualFold(req.InstanceName, instanceName) { + return errors.New("Instance name does not match path") + } + + if req.ModelName == "" { + return errors.New("Model name is required") + } + + if len(req.ModelTypes) == 0 { + return errors.New("Model type is required") + } + + req.ProviderName = providerName + req.InstanceName = instanceName + return nil +} + func (h *ProviderHandler) AddCustomModel(c *gin.Context) { var req service.AddCustomModelRequest if err := c.ShouldBindJSON(&req); err != nil { println("JSON bind error: %v (type: %T)", err, err) - c.JSON(http.StatusOK, gin.H{ + c.JSON(http.StatusBadRequest, gin.H{ "code": common.CodeBadRequest, "message": err.Error(), }) return } - if req.ProviderName == "" { + if err := prepareAddCustomModelRequest(&req, c.Param("provider_name"), c.Param("instance_name")); err != nil { c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Provider name is required", - }) - return - } - - if req.InstanceName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Instance name is required", - }) - return - } - - if req.ModelName == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Model name is required", - }) - return - } - - if req.ModelTypes == nil { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "Model type is required", + "code": common.CodeBadRequest, + "message": err.Error(), }) return } diff --git a/internal/handler/providers_test.go b/internal/handler/providers_test.go new file mode 100644 index 0000000000..87657da4f5 --- /dev/null +++ b/internal/handler/providers_test.go @@ -0,0 +1,111 @@ +package handler + +import ( + "testing" + + "ragflow/internal/service" +) + +func TestPrepareAddCustomModelRequestUsesPathTarget(t *testing.T) { + req := service.AddCustomModelRequest{ + ModelName: "custom-chat", + ModelTypes: []string{"chat"}, + } + + if err := prepareAddCustomModelRequest(&req, "openai", "default"); err != nil { + t.Fatalf("prepareAddCustomModelRequest returned error: %v", err) + } + if req.ProviderName != "openai" { + t.Fatalf("expected provider name from path, got %q", req.ProviderName) + } + if req.InstanceName != "default" { + t.Fatalf("expected instance name from path, got %q", req.InstanceName) + } +} + +func TestPrepareAddCustomModelRequestAcceptsCaseInsensitivePathMatch(t *testing.T) { + req := service.AddCustomModelRequest{ + ProviderName: "openai", + InstanceName: "default", + ModelName: "custom-chat", + ModelTypes: []string{"chat"}, + } + + if err := prepareAddCustomModelRequest(&req, "OpenAI", "Default"); err != nil { + t.Fatalf("prepareAddCustomModelRequest returned error: %v", err) + } + if req.ProviderName != "OpenAI" { + t.Fatalf("expected provider name from path, got %q", req.ProviderName) + } + if req.InstanceName != "Default" { + t.Fatalf("expected instance name from path, got %q", req.InstanceName) + } +} + +func TestPrepareAddCustomModelRequestRejectsPathMismatches(t *testing.T) { + tests := []struct { + name string + req service.AddCustomModelRequest + expectedErr string + }{ + { + name: "provider", + req: service.AddCustomModelRequest{ + ProviderName: "deepseek", + InstanceName: "default", + ModelName: "custom-chat", + ModelTypes: []string{"chat"}, + }, + expectedErr: "Provider name does not match path", + }, + { + name: "instance", + req: service.AddCustomModelRequest{ + ProviderName: "openai", + InstanceName: "other", + ModelName: "custom-chat", + ModelTypes: []string{"chat"}, + }, + expectedErr: "Instance name does not match path", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := prepareAddCustomModelRequest(&tt.req, "openai", "default") + if err == nil { + t.Fatal("expected mismatch error") + } + if err.Error() != tt.expectedErr { + t.Fatalf("expected %q, got %q", tt.expectedErr, err.Error()) + } + }) + } +} + +func TestPrepareAddCustomModelRequestRejectsEmptyModelTypes(t *testing.T) { + tests := []struct { + name string + modelTypes []string + }{ + {name: "nil"}, + {name: "empty", modelTypes: []string{}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := service.AddCustomModelRequest{ + ModelName: "custom-chat", + ModelTypes: tt.modelTypes, + } + + err := prepareAddCustomModelRequest(&req, "openai", "default") + if err == nil { + t.Fatal("expected empty model_types to return an error") + } + if err.Error() != "Model type is required" { + t.Fatalf("expected model type error, got %q", err.Error()) + } + }) + } +} diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 9813010a28..004649c7c6 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -1915,6 +1915,13 @@ type AddCustomModelRequest struct { } func (m *ModelProviderService) AddCustomModel(request *AddCustomModelRequest, userID string) (common.ErrorCode, error) { + if request == nil { + return common.CodeBadRequest, errors.New("request is required") + } + if len(request.ModelTypes) == 0 { + return common.CodeBadRequest, errors.New("model type is required") + } + // Get tenant ID from user tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") if err != nil { diff --git a/internal/service/model_service_test.go b/internal/service/model_service_test.go index 7b6e138c4c..c6edbadedd 100644 --- a/internal/service/model_service_test.go +++ b/internal/service/model_service_test.go @@ -4,6 +4,7 @@ import ( "strings" "testing" + "ragflow/internal/common" modelModule "ragflow/internal/entity/models" ) @@ -127,3 +128,45 @@ func TestNewModelDriverForBaseURLRejectsNilDriver(t *testing.T) { t.Fatalf("expected driver not found error, got %v", err) } } + +func TestAddCustomModelRejectsNilRequest(t *testing.T) { + service := &ModelProviderService{} + + code, err := service.AddCustomModel(nil, "user-id") + if err == nil { + t.Fatal("expected nil request to return an error") + } + if code != common.CodeBadRequest { + t.Fatalf("expected bad request code, got %v", code) + } +} + +func TestAddCustomModelRejectsEmptyModelTypes(t *testing.T) { + tests := []struct { + name string + modelTypes []string + }{ + {name: "nil"}, + {name: "empty", modelTypes: []string{}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + service := &ModelProviderService{} + req := &AddCustomModelRequest{ + ProviderName: "openai", + InstanceName: "default", + ModelName: "custom-chat", + ModelTypes: tt.modelTypes, + } + + code, err := service.AddCustomModel(req, "user-id") + if err == nil { + t.Fatal("expected empty model_types to return an error") + } + if code != common.CodeBadRequest { + t.Fatalf("expected bad request code, got %v", code) + } + }) + } +}