From b946df8ba20662577b1e4501d4c2bfb0473806db Mon Sep 17 00:00:00 2001 From: Wang Qi Date: Wed, 3 Jun 2026 19:58:06 +0800 Subject: [PATCH] Fix: consolidate beta auth (#15581) Fix: consolidate beta auth --- api/apps/__init__.py | 162 ++++++++++++------ api/apps/restful_apis/bot_api.py | 122 ++++--------- api/apps/restful_apis/document_api.py | 6 +- .../restful_api/test_document_raw_routes.py | 4 +- test/testcases/restful_api/test_retrieval.py | 8 +- test/testcases/restful_api/test_sessions.py | 2 +- 6 files changed, 153 insertions(+), 151 deletions(-) diff --git a/api/apps/__init__.py b/api/apps/__init__.py index b8da01423c..07c28b0087 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -87,12 +87,27 @@ commands.register_commands(app) from functools import wraps from typing import ParamSpec, TypeVar -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Iterable from werkzeug.local import LocalProxy T = TypeVar("T") P = ParamSpec("P") +AUTH_JWT = "JWT" +AUTH_API = "API" +AUTH_BETA = "BETA" +DEFAULT_AUTH_TYPES = (AUTH_JWT, AUTH_API) + + +def _normalize_auth_types(auth_types=None): + if auth_types is None: + return set(DEFAULT_AUTH_TYPES) + if isinstance(auth_types, str): + return {auth_types.upper()} + if isinstance(auth_types, Iterable): + return {str(auth_type).upper() for auth_type in auth_types} + return {str(auth_types).upper()} + def _load_user_from_session(): """Resolve the current user from the session cookie set by ``login_user()``. @@ -123,76 +138,103 @@ def _load_user_from_session(): if not access_token or len(access_token) < 32 or access_token.startswith("INVALID_"): return None logging.debug("Authenticated request via session fallback for user_id=%s", user_id) + g.auth_type = AUTH_JWT g.user = user return user -def _load_user(): - jwt = Serializer(secret_key=settings.get_secret_key()) +def _load_user(auth_types=None): + explicit_auth_types = auth_types is not None + auth_types = _normalize_auth_types(auth_types) + if getattr(g, "user", None) and (not explicit_auth_types or getattr(g, "auth_type", None) in auth_types): + return g.user + + # No Authorization header, try to load user from session cookie if JWT auth is allowed authorization = request.headers.get("Authorization") - g.user = None - g.auth_via_api_token = False if not authorization: - return _load_user_from_session() + return _load_user_from_session() if AUTH_JWT in auth_types else None # Extract auth_token based on whether Authorization starts with "bearer" (case-insensitive) - if authorization.lower().startswith("bearer "): + if authorization[:7].lower() == "bearer ": parts = authorization.split(maxsplit=1) if len(parts) < 2: logging.warning("Authorization header has invalid bearer format") - return _load_user_from_session() + return _load_user_from_session() if AUTH_JWT in auth_types else None auth_token = parts[1] else: auth_token = authorization + g.user = None + g.auth_type = None + g.auth_error_message = None + + # Try Beta token + if AUTH_BETA in auth_types: + try: + objs = APIToken.query(beta=auth_token) + if objs: + user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value) + if user: + g.auth_type = AUTH_BETA + g.user = user[0] + return user[0] + g.auth_error_message = 'Authentication error: API key is invalid! ' + except Exception as e_beta: + logging.warning(f"load_user from beta token got exception {e_beta}") + g.auth_error_message = 'Authentication error: API key is invalid!' + # Try JWT decoding - try: - access_token = str(jwt.loads(auth_token)) + if AUTH_JWT in auth_types: + try: + jwt = Serializer(secret_key=settings.get_secret_key()) + access_token = str(jwt.loads(auth_token)) - if not access_token or not access_token.strip(): - logging.warning("Authentication attempt with empty access token") - return _load_user_from_session() - - if len(access_token.strip()) < 32: - logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars") - return _load_user_from_session() - - user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value) - if user: - if not user[0].access_token or not user[0].access_token.strip(): - logging.warning(f"User {user[0].email} has empty access_token in database") + if not access_token or not access_token.strip(): + logging.warning("Authentication attempt with empty access token") return _load_user_from_session() - g.user = user[0] - return user[0] - return _load_user_from_session() - except Exception as e_jwt: - logging.warning(f"load_user from jwt got exception {e_jwt}") - # JWT decode failed, try as api_token - try: - objs = APIToken.query(token=auth_token) - if objs: - user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value) + if len(access_token.strip()) < 32: + logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars") + return _load_user_from_session() + + user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value) if user: if not user[0].access_token or not user[0].access_token.strip(): logging.warning(f"User {user[0].email} has empty access_token in database") return _load_user_from_session() - g.auth_via_api_token = True + g.auth_type = AUTH_JWT g.user = user[0] return user[0] - logging.warning(f"load_user: No user found for tenant_id={objs[0].tenant_id} from APIToken") - else: - logging.warning(f"load_user: No APIToken found for token={auth_token[:10]}...") - except Exception as e_api_token: - logging.warning(f"load_user from api token got exception {e_api_token}") + return _load_user_from_session() + except Exception as e_jwt: + logging.warning(f"load_user from jwt got exception {e_jwt}") - return _load_user_from_session() + # JWT decode failed, try as api_token + if AUTH_API in auth_types: + try: + objs = APIToken.query(token=auth_token) + if objs: + user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value) + if user: + if not user[0].access_token or not user[0].access_token.strip(): + logging.warning(f"User {user[0].email} has empty access_token in database") + return _load_user_from_session() if AUTH_JWT in auth_types else None + g.auth_type = AUTH_API + g.user = user[0] + return user[0] + logging.warning(f"load_user: No user found for tenant_id={objs[0].tenant_id} from APIToken") + else: + logging.warning(f"load_user: No APIToken found for token={auth_token[:10]}...") + except Exception as e_api_token: + logging.warning(f"load_user from api token got exception {e_api_token}") + + return _load_user_from_session() if AUTH_JWT in auth_types else None current_user = LocalProxy(_load_user) -def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: +def login_required(func: Callable[P, Awaitable[T]] = None, auth_types=None) -> Callable[P, Awaitable[T]]: """A decorator to restrict route access to authenticated users. This should be used to wrap a route handler (or view function) to @@ -212,22 +254,32 @@ def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]] """ - @wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - timing_enabled = os.getenv("RAGFLOW_API_TIMING") - t_start = time.perf_counter() if timing_enabled else None - user = current_user - if timing_enabled: - logging.info( - "api_timing login_required auth_ms=%.2f path=%s", - (time.perf_counter() - t_start) * 1000, - request.path, - ) - if not user: # or not session.get("_user_id"): - raise QuartAuthUnauthorized() - return await current_app.ensure_async(func)(*args, **kwargs) + def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: + @wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + timing_enabled = os.getenv("RAGFLOW_API_TIMING") + t_start = time.perf_counter() if timing_enabled else None + user = _load_user(auth_types) + if timing_enabled: + logging.info( + "api_timing login_required auth_ms=%.2f path=%s", + (time.perf_counter() - t_start) * 1000, + request.path, + ) + if not user: # or not session.get("_user_id"): + if _normalize_auth_types(auth_types) == {AUTH_BETA}: + return get_json_result( + code=RetCode.DATA_ERROR, + message=getattr(g, "auth_error_message", None) or "Authorization is not valid!", + ) + raise QuartAuthUnauthorized() + return await current_app.ensure_async(func)(*args, **kwargs) - return wrapper + return wrapper + + if func is None: + return decorator + return decorator(func) def login_user(user, remember=False, duration=None, force=False, fresh=True): diff --git a/api/apps/restful_apis/bot_api.py b/api/apps/restful_apis/bot_api.py index ee4d0c68a0..bbb79e7d54 100644 --- a/api/apps/restful_apis/bot_api.py +++ b/api/apps/restful_apis/bot_api.py @@ -22,7 +22,7 @@ import logging from quart import Response, request from agent.canvas import Canvas -from api.db.db_models import APIToken +from api.apps import AUTH_BETA, login_required from api.db.services.api_service import API4ConversationService from api.db.services.canvas_service import UserCanvasService from api.db.services.canvas_service import completion as agent_completion @@ -37,7 +37,7 @@ from api.db.services.user_service import UserTenantService from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance from common.misc_utils import thread_pool_exec from api.utils.api_utils import get_error_data_result, get_json_result, \ - get_result, get_request_json, server_error_response, validate_request + add_tenant_id_to_kwargs, get_result, get_request_json, server_error_response, validate_request from rag.app.tag import label_question from rag.prompts.template import load_prompt from rag.prompts.generator import cross_languages, keyword_extraction @@ -51,24 +51,12 @@ from api.utils.reference_metadata_utils import ( logger = logging.getLogger(__name__) -def _get_sdk_authorization_token(): - token = request.headers.get("Authorization", "").split() - if len(token) != 2: - return None - return token[1] - - @manager.route("/chatbots//completions", methods=["POST"]) # noqa: F821 -async def chatbot_completions(dialog_id): +@login_required(auth_types=AUTH_BETA) +@add_tenant_id_to_kwargs +async def chatbot_completions(dialog_id, tenant_id=None): req = await get_request_json() - token = _get_sdk_authorization_token() - if not token: - return get_error_data_result(message='Authorization is not valid!') - objs = await thread_pool_exec(APIToken.query, beta=token) - if not objs: - return get_error_data_result(message='Authentication error: API key is invalid!"') - tenant_id = objs[0].tenant_id exists, dialog = DialogService.get_by_id(dialog_id) if (not exists or getattr(dialog, "tenant_id", None) != tenant_id @@ -135,14 +123,9 @@ async def chatbot_completions(dialog_id): return None @manager.route("/chatbots//info", methods=["GET"]) # noqa: F821 -async def chatbots_inputs(dialog_id): - token = _get_sdk_authorization_token() - if not token: - return get_error_data_result(message='Authorization is not valid!') - objs = await thread_pool_exec(APIToken.query, beta=token) - if not objs: - return get_error_data_result(message='Authentication error: API key is invalid!"') - tenant_id = objs[0].tenant_id +@login_required(auth_types=AUTH_BETA) +@add_tenant_id_to_kwargs +async def chatbots_inputs(dialog_id, tenant_id=None): exists, dialog = await thread_pool_exec(DialogService.get_by_id, dialog_id) if (not exists or getattr(dialog, "tenant_id", None) != tenant_id @@ -170,20 +153,15 @@ async def chatbots_inputs(dialog_id): @manager.route("/agentbots//completions", methods=["POST"]) # noqa: F821 -async def agent_bot_completions(agent_id): +@login_required(auth_types=AUTH_BETA) +@add_tenant_id_to_kwargs +async def agent_bot_completions(agent_id, tenant_id=None): req = await get_request_json() - token = _get_sdk_authorization_token() - if not token: - return get_error_data_result(message='Authorization is not valid!') - objs = await thread_pool_exec(APIToken.query, beta=token) - if not objs: - return get_error_data_result(message='Authentication error: API key is invalid!"') - if req.get("stream", True): async def stream(): try: - async for answer in agent_completion(objs[0].tenant_id, agent_id, **req): + async for answer in agent_completion(tenant_id, agent_id, **req): yield answer except Exception as e: logging.exception(e) @@ -209,7 +187,7 @@ async def agent_bot_completions(agent_id): reference = {} structured_output = {} final_ans = {} - async for answer in agent_completion(objs[0].tenant_id, agent_id, **req): + async for answer in agent_completion(tenant_id, agent_id, **req): # agent_completion yields SSE-formatted strings. A single yielded # chunk can contain multiple "data:..." frames separated by "\n\n" # plus blank or comment lines, so parse line-by-line rather than @@ -257,36 +235,26 @@ async def agent_bot_completions(agent_id): @manager.route("/agentbots//inputs", methods=["GET"]) # noqa: F821 -async def begin_inputs(agent_id): - token = _get_sdk_authorization_token() - if not token: - return get_error_data_result(message='Authorization is not valid!') - objs = await thread_pool_exec(APIToken.query, beta=token) - if not objs: - return get_error_data_result(message='Authentication error: API key is invalid!"') - +@login_required(auth_types=AUTH_BETA) +@add_tenant_id_to_kwargs +async def begin_inputs(agent_id, tenant_id=None): e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, agent_id) if not e: return get_error_data_result(f"Can't find agent by ID: {agent_id}") - canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id, canvas_id=cvs.id) + canvas = Canvas(json.dumps(cvs.dsl), tenant_id, canvas_id=cvs.id) return get_result( data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"), "prologue": canvas.get_prologue(), "mode": canvas.get_mode()}) @manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821 +@login_required(auth_types=AUTH_BETA) +@add_tenant_id_to_kwargs @validate_request("question", "kb_ids") -async def ask_about_embedded(): - token = _get_sdk_authorization_token() - if not token: - return get_error_data_result(message='Authorization is not valid!') - objs = await thread_pool_exec(APIToken.query, beta=token) - if not objs: - return get_error_data_result(message='Authentication error: API key is invalid!"') - +async def ask_about_embedded(tenant_id=None): req = await get_request_json() - uid = objs[0].tenant_id + uid = tenant_id search_id = req.get("search_id", "") search_config = {} @@ -314,15 +282,10 @@ async def ask_about_embedded(): @manager.route("/searchbots/retrieval_test", methods=["POST"]) # noqa: F821 +@login_required(auth_types=AUTH_BETA) +@add_tenant_id_to_kwargs @validate_request("kb_id", "question") -async def retrieval_test_embedded(): - token = _get_sdk_authorization_token() - if not token: - return get_error_data_result(message='Authorization is not valid!') - objs = await thread_pool_exec(APIToken.query, beta=token) - if not objs: - return get_error_data_result(message='Authentication error: API key is invalid!"') - +async def retrieval_test_embedded(tenant_id=None): req = await get_request_json() page = int(req.get("page", 1)) size = int(req.get("size", 30)) @@ -342,7 +305,6 @@ async def retrieval_test_embedded(): return get_error_data_result("`top_k` must be greater than 0") langs = req.get("cross_languages", []) rerank_id = req.get("rerank_id", "") - tenant_id = objs[0].tenant_id if not tenant_id: return get_error_data_result(message="permission denined.") search_config = {} @@ -456,17 +418,11 @@ async def retrieval_test_embedded(): @manager.route("/searchbots/related_questions", methods=["POST"]) # noqa: F821 +@login_required(auth_types=AUTH_BETA) +@add_tenant_id_to_kwargs @validate_request("question") -async def related_questions_embedded(): - token = _get_sdk_authorization_token() - if not token: - return get_error_data_result(message='Authorization is not valid!') - objs = await thread_pool_exec(APIToken.query, beta=token) - if not objs: - return get_error_data_result(message='Authentication error: API key is invalid!"') - +async def related_questions_embedded(tenant_id=None): req = await get_request_json() - tenant_id = objs[0].tenant_id if not tenant_id: return get_error_data_result(message="permission denined.") @@ -504,16 +460,10 @@ Related search terms: @manager.route("/searchbots/detail", methods=["GET"]) # noqa: F821 -async def detail_share_embedded(): - token = _get_sdk_authorization_token() - if not token: - return get_error_data_result(message='Authorization is not valid!') - objs = await thread_pool_exec(APIToken.query, beta=token) - if not objs: - return get_error_data_result(message='Authentication error: API key is invalid!"') - +@login_required(auth_types=AUTH_BETA) +@add_tenant_id_to_kwargs +async def detail_share_embedded(tenant_id=None): search_id = request.args["search_id"] - tenant_id = objs[0].tenant_id if not tenant_id: return get_error_data_result(message="permission denined.") try: @@ -534,16 +484,10 @@ async def detail_share_embedded(): @manager.route("/searchbots/mindmap", methods=["POST"]) # noqa: F821 +@login_required(auth_types=AUTH_BETA) +@add_tenant_id_to_kwargs @validate_request("question", "kb_ids") -async def mindmap(): - token = _get_sdk_authorization_token() - if not token: - return get_error_data_result(message='Authorization is not valid!') - objs = await thread_pool_exec(APIToken.query, beta=token) - if not objs: - return get_error_data_result(message='Authentication error: API key is invalid!"') - - tenant_id = objs[0].tenant_id +async def mindmap(tenant_id=None): req = await get_request_json() search_id = req.get("search_id", "") diff --git a/api/apps/restful_apis/document_api.py b/api/apps/restful_apis/document_api.py index a639c5e7b8..868ab28615 100644 --- a/api/apps/restful_apis/document_api.py +++ b/api/apps/restful_apis/document_api.py @@ -24,7 +24,7 @@ from quart import request, make_response,send_file from peewee import OperationalError from pydantic import ValidationError -from api.apps import login_required +from api.apps import AUTH_JWT, AUTH_API, AUTH_BETA, login_required from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX from api.apps.services.document_api_service import validate_document_update_fields, map_doc_keys, \ map_doc_keys_with_run_status, update_document_name_only, update_chunk_method, update_document_status_only, \ @@ -1191,6 +1191,7 @@ async def update_metadata_config(tenant_id, dataset_id, document_id): @manager.route("/thumbnails", methods=["GET"]) # noqa: F821 +@login_required(auth_types=[AUTH_JWT, AUTH_API, AUTH_BETA]) def list_thumbnails(): """ Get thumbnails for documents. @@ -1687,6 +1688,7 @@ def _content_type_for_document_image(object_name, data): @manager.route("/documents/images/", methods=["GET"]) # noqa: F821 +@login_required(auth_types=[AUTH_JWT, AUTH_API, AUTH_BETA]) async def get_document_image(image_id): """ Get a document image by ID. @@ -1908,7 +1910,7 @@ async def batch_update_document_status(tenant_id, dataset_id): return get_json_result(data=result) @manager.route("/documents//preview", methods=["GET"]) # noqa: F821 -@login_required +@login_required(auth_types=[AUTH_JWT, AUTH_API, AUTH_BETA]) async def get(doc_id): """Return the raw file bytes for a document the requesting user is authorized to read. diff --git a/test/testcases/restful_api/test_document_raw_routes.py b/test/testcases/restful_api/test_document_raw_routes.py index d1ff2520f1..36866650cf 100644 --- a/test/testcases/restful_api/test_document_raw_routes.py +++ b/test/testcases/restful_api/test_document_raw_routes.py @@ -20,8 +20,8 @@ from test.testcases.restful_api.helpers.client import RestClient @pytest.mark.p2 -def test_document_image_invalid_id_contract(rest_client_noauth): - res = rest_client_noauth.get("/documents/images/not-a-valid-image-id") +def test_document_image_invalid_id_contract(rest_client): + res = rest_client.get("/documents/images/not-a-valid-image-id") assert res.status_code == 200 payload = res.json() assert payload["code"] == 102, payload diff --git a/test/testcases/restful_api/test_retrieval.py b/test/testcases/restful_api/test_retrieval.py index d75f331651..212f4d6261 100644 --- a/test/testcases/restful_api/test_retrieval.py +++ b/test/testcases/restful_api/test_retrieval.py @@ -390,7 +390,7 @@ def test_related_questions_contract(rest_client, rest_client_noauth): assert success_payload["code"] == 0, success_payload assert isinstance(success_payload["data"], list), success_payload - missing_res = rest_client.post("/searchbots/related_questions", json={"industry": "search"}) + missing_res = success_client.post("/searchbots/related_questions", json={"industry": "search"}) assert missing_res.status_code == 200 missing_payload = missing_res.json() assert missing_payload["code"] == 101, missing_payload @@ -404,4 +404,8 @@ def test_related_questions_contract(rest_client, rest_client_noauth): assert invalid_auth_res.status_code == 200 invalid_auth_payload = invalid_auth_res.json() assert invalid_auth_payload["code"] == 102, invalid_auth_payload - assert "Authorization is not valid!" in invalid_auth_payload["message"], invalid_auth_payload + assert invalid_auth_payload["message"].strip() in { + "Authorization is not valid!", + 'Authentication error: API key is invalid!"', + "Authentication error: API key is invalid!", + }, invalid_auth_payload diff --git a/test/testcases/restful_api/test_sessions.py b/test/testcases/restful_api/test_sessions.py index 3d7d30a713..7a23a39790 100644 --- a/test/testcases/restful_api/test_sessions.py +++ b/test/testcases/restful_api/test_sessions.py @@ -611,7 +611,7 @@ def test_related_questions_compatibility_requires_auth(rest_client_noauth): assert res.status_code == 200 payload = res.json() assert payload["code"] == 102, payload - assert payload["message"] in { + assert payload["message"].strip() in { "Authorization is not valid!", 'Authentication error: API key is invalid!"', "Authentication error: API key is invalid!",