diff --git a/api/apps/restful_apis/file_api.py b/api/apps/restful_apis/file_api.py index 58c6cde727..b67aa30ffc 100644 --- a/api/apps/restful_apis/file_api.py +++ b/api/apps/restful_apis/file_api.py @@ -335,7 +335,7 @@ async def parent_folder(tenant_id: str = None, file_id: str = None): description: Parent folder information. """ try: - success, result = file_api_service.get_parent_folder(file_id) + success, result = file_api_service.get_parent_folder(file_id, user_id=tenant_id) if success: return get_result(data=result) else: @@ -366,7 +366,7 @@ async def ancestors(tenant_id: str = None, file_id: str = None): description: List of ancestor folders. """ try: - success, result = file_api_service.get_all_parent_folders(file_id) + success, result = file_api_service.get_all_parent_folders(file_id, user_id=tenant_id) if success: return get_result(data=result) else: diff --git a/api/apps/services/file_api_service.py b/api/apps/services/file_api_service.py index 21dfaeb004..cfde3de294 100644 --- a/api/apps/services/file_api_service.py +++ b/api/apps/services/file_api_service.py @@ -174,32 +174,46 @@ def list_files(tenant_id: str, args: dict): -def get_parent_folder(file_id: str): +def get_parent_folder(file_id: str, user_id: str = None): """ - Get parent folder of a file. + Get parent folder of a file with permission check. :param file_id: file ID + :param user_id: user ID for permission validation :return: (success, result) or (success, error_message) """ + from api.common.check_team_permission import check_file_team_permission + e, file = FileService.get_by_id(file_id) if not e: return False, "Folder not found!" + # Permission check + if user_id and not check_file_team_permission(file, user_id): + return False, "No authorization." + parent_folder = FileService.get_parent_folder(file_id) return True, {"parent_folder": parent_folder.to_json()} -def get_all_parent_folders(file_id: str): +def get_all_parent_folders(file_id: str, user_id: str = None): """ - Get all ancestor folders of a file. + Get all ancestor folders of a file with permission check. :param file_id: file ID + :param user_id: user ID for permission validation :return: (success, result) or (success, error_message) """ + from api.common.check_team_permission import check_file_team_permission + e, file = FileService.get_by_id(file_id) if not e: return False, "Folder not found!" + # Permission check + if user_id and not check_file_team_permission(file, user_id): + return False, "No authorization." + parent_folders = FileService.get_all_parent_folders(file_id) return True, {"parent_folders": [pf.to_json() for pf in parent_folders]} diff --git a/internal/handler/file.go b/internal/handler/file.go index 195733146e..8c83e3b1f6 100644 --- a/internal/handler/file.go +++ b/internal/handler/file.go @@ -155,11 +155,12 @@ func (h *FileHandler) GetRootFolder(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/file/parent_folder [get] func (h *FileHandler) GetParentFolder(c *gin.Context) { - _, errorCode, errorMessage := GetUser(c) + user, errorCode, errorMessage := GetUser(c) if errorCode != common.CodeSuccess { jsonError(c, errorCode, errorMessage) return } + userID := user.ID // Get file_id from query fileID := c.Query("file_id") @@ -168,8 +169,8 @@ func (h *FileHandler) GetParentFolder(c *gin.Context) { return } - // Get parent folder - parentFolder, err := h.fileService.GetParentFolder(fileID) + // Get parent folder with permission check + parentFolder, err := h.fileService.GetParentFolder(userID, fileID) if err != nil { jsonError(c, common.CodeServerError, err.Error()) return @@ -192,11 +193,12 @@ func (h *FileHandler) GetParentFolder(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /v1/file/all_parent_folder [get] func (h *FileHandler) GetAllParentFolders(c *gin.Context) { - _, errorCode, errorMessage := GetUser(c) + user, errorCode, errorMessage := GetUser(c) if errorCode != common.CodeSuccess { jsonError(c, errorCode, errorMessage) return } + userID := user.ID // Get file_id from query fileID := c.Query("file_id") @@ -205,8 +207,8 @@ func (h *FileHandler) GetAllParentFolders(c *gin.Context) { return } - // Get all parent folders - parentFolders, err := h.fileService.GetAllParentFolders(fileID) + // Get all parent folders with permission check + parentFolders, err := h.fileService.GetAllParentFolders(userID, fileID) if err != nil { jsonError(c, common.CodeServerError, err.Error()) return @@ -229,11 +231,12 @@ func (h *FileHandler) GetAllParentFolders(c *gin.Context) { // @Success 200 {object} map[string]interface{} // @Router /api/v1/files/{id}/ancestors [get] func (h *FileHandler) GetFileAncestors(c *gin.Context) { - _, errorCode, errorMessage := GetUser(c) + user, errorCode, errorMessage := GetUser(c) if errorCode != common.CodeSuccess { jsonError(c, errorCode, errorMessage) return } + userID := user.ID fileID := c.Param("id") if fileID == "" { @@ -241,7 +244,8 @@ func (h *FileHandler) GetFileAncestors(c *gin.Context) { return } - parentFolders, err := h.fileService.GetAllParentFolders(fileID) + // Get all parent folders with permission check + parentFolders, err := h.fileService.GetAllParentFolders(userID, fileID) if err != nil { jsonError(c, common.CodeServerError, err.Error()) return diff --git a/internal/service/file.go b/internal/service/file.go index 24d27f3acb..662d50010c 100644 --- a/internal/service/file.go +++ b/internal/service/file.go @@ -213,13 +213,19 @@ func (s *FileService) fileInfoToResponse(info *FileInfo) map[string]interface{} return result } -// GetParentFolder gets parent folder of a file -func (s *FileService) GetParentFolder(fileID string) (map[string]interface{}, error) { - // Check if file exists - if _, err := s.fileDAO.GetByID(fileID); err != nil { +// GetParentFolder gets parent folder of a file with permission check +func (s *FileService) GetParentFolder(userID, fileID string) (map[string]interface{}, error) { + // Get file + file, err := s.fileDAO.GetByID(fileID) + if err != nil { return nil, err } + // Permission check + if !s.checkFileTeamPermission(file, userID) { + return nil, fmt.Errorf("No authorization.") + } + // Get parent folder parentFolder, err := s.fileDAO.GetParentFolder(fileID) if err != nil { @@ -229,13 +235,19 @@ func (s *FileService) GetParentFolder(fileID string) (map[string]interface{}, er return s.toFileResponse(parentFolder), nil } -// GetAllParentFolders gets all parent folders in path -func (s *FileService) GetAllParentFolders(fileID string) ([]map[string]interface{}, error) { - // Check if file exists - if _, err := s.fileDAO.GetByID(fileID); err != nil { +// GetAllParentFolders gets all parent folders in path with permission check +func (s *FileService) GetAllParentFolders(userID, fileID string) ([]map[string]interface{}, error) { + // Get file + file, err := s.fileDAO.GetByID(fileID) + if err != nil { return nil, err } + // Permission check + if !s.checkFileTeamPermission(file, userID) { + return nil, fmt.Errorf("No authorization.") + } + // Get all parent folders parentFolders, err := s.fileDAO.GetAllParentFolders(fileID) if err != nil { diff --git a/test/testcases/test_web_api/test_file_app/test_file_routes_unit.py b/test/testcases/test_web_api/test_file_app/test_file_routes_unit.py index 87c37d4667..c1ff639ac1 100644 --- a/test/testcases/test_web_api/test_file_app/test_file_routes_unit.py +++ b/test/testcases/test_web_api/test_file_app/test_file_routes_unit.py @@ -133,8 +133,8 @@ def _load_file_api_module(monkeypatch): True, SimpleNamespace(parent_id="bucket1", location="path1", name="doc.txt", type="doc"), ) - file_api_service_mod.get_parent_folder = lambda _file_id: (True, {"parent_folder": {"id": "parent1"}}) - file_api_service_mod.get_all_parent_folders = lambda _file_id: (True, {"parent_folders": [{"id": "root"}]}) + file_api_service_mod.get_parent_folder = lambda _file_id, user_id=None: (True, {"parent_folder": {"id": "parent1"}}) + file_api_service_mod.get_all_parent_folders = lambda _file_id, user_id=None: (True, {"parent_folders": [{"id": "root"}]}) monkeypatch.setitem(sys.modules, "api.apps.services.file_api_service", file_api_service_mod) services_pkg.file_api_service = file_api_service_mod