Fix: consolidate beta auth (#15581)

Fix: consolidate beta auth
This commit is contained in:
Wang Qi
2026-06-03 19:58:06 +08:00
committed by GitHub
parent 2eed0d4679
commit b946df8ba2
6 changed files with 153 additions and 151 deletions

View File

@@ -87,12 +87,27 @@ commands.register_commands(app)
from functools import wraps
from typing import ParamSpec, TypeVar
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Iterable
from werkzeug.local import LocalProxy
T = TypeVar("T")
P = ParamSpec("P")
AUTH_JWT = "JWT"
AUTH_API = "API"
AUTH_BETA = "BETA"
DEFAULT_AUTH_TYPES = (AUTH_JWT, AUTH_API)
def _normalize_auth_types(auth_types=None):
if auth_types is None:
return set(DEFAULT_AUTH_TYPES)
if isinstance(auth_types, str):
return {auth_types.upper()}
if isinstance(auth_types, Iterable):
return {str(auth_type).upper() for auth_type in auth_types}
return {str(auth_types).upper()}
def _load_user_from_session():
"""Resolve the current user from the session cookie set by ``login_user()``.
@@ -123,76 +138,103 @@ def _load_user_from_session():
if not access_token or len(access_token) < 32 or access_token.startswith("INVALID_"):
return None
logging.debug("Authenticated request via session fallback for user_id=%s", user_id)
g.auth_type = AUTH_JWT
g.user = user
return user
def _load_user():
jwt = Serializer(secret_key=settings.get_secret_key())
def _load_user(auth_types=None):
explicit_auth_types = auth_types is not None
auth_types = _normalize_auth_types(auth_types)
if getattr(g, "user", None) and (not explicit_auth_types or getattr(g, "auth_type", None) in auth_types):
return g.user
# No Authorization header, try to load user from session cookie if JWT auth is allowed
authorization = request.headers.get("Authorization")
g.user = None
g.auth_via_api_token = False
if not authorization:
return _load_user_from_session()
return _load_user_from_session() if AUTH_JWT in auth_types else None
# Extract auth_token based on whether Authorization starts with "bearer" (case-insensitive)
if authorization.lower().startswith("bearer "):
if authorization[:7].lower() == "bearer ":
parts = authorization.split(maxsplit=1)
if len(parts) < 2:
logging.warning("Authorization header has invalid bearer format")
return _load_user_from_session()
return _load_user_from_session() if AUTH_JWT in auth_types else None
auth_token = parts[1]
else:
auth_token = authorization
g.user = None
g.auth_type = None
g.auth_error_message = None
# Try Beta token
if AUTH_BETA in auth_types:
try:
objs = APIToken.query(beta=auth_token)
if objs:
user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value)
if user:
g.auth_type = AUTH_BETA
g.user = user[0]
return user[0]
g.auth_error_message = 'Authentication error: API key is invalid! '
except Exception as e_beta:
logging.warning(f"load_user from beta token got exception {e_beta}")
g.auth_error_message = 'Authentication error: API key is invalid!'
# Try JWT decoding
try:
access_token = str(jwt.loads(auth_token))
if AUTH_JWT in auth_types:
try:
jwt = Serializer(secret_key=settings.get_secret_key())
access_token = str(jwt.loads(auth_token))
if not access_token or not access_token.strip():
logging.warning("Authentication attempt with empty access token")
return _load_user_from_session()
if len(access_token.strip()) < 32:
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
return _load_user_from_session()
user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value)
if user:
if not user[0].access_token or not user[0].access_token.strip():
logging.warning(f"User {user[0].email} has empty access_token in database")
if not access_token or not access_token.strip():
logging.warning("Authentication attempt with empty access token")
return _load_user_from_session()
g.user = user[0]
return user[0]
return _load_user_from_session()
except Exception as e_jwt:
logging.warning(f"load_user from jwt got exception {e_jwt}")
# JWT decode failed, try as api_token
try:
objs = APIToken.query(token=auth_token)
if objs:
user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value)
if len(access_token.strip()) < 32:
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
return _load_user_from_session()
user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value)
if user:
if not user[0].access_token or not user[0].access_token.strip():
logging.warning(f"User {user[0].email} has empty access_token in database")
return _load_user_from_session()
g.auth_via_api_token = True
g.auth_type = AUTH_JWT
g.user = user[0]
return user[0]
logging.warning(f"load_user: No user found for tenant_id={objs[0].tenant_id} from APIToken")
else:
logging.warning(f"load_user: No APIToken found for token={auth_token[:10]}...")
except Exception as e_api_token:
logging.warning(f"load_user from api token got exception {e_api_token}")
return _load_user_from_session()
except Exception as e_jwt:
logging.warning(f"load_user from jwt got exception {e_jwt}")
return _load_user_from_session()
# JWT decode failed, try as api_token
if AUTH_API in auth_types:
try:
objs = APIToken.query(token=auth_token)
if objs:
user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value)
if user:
if not user[0].access_token or not user[0].access_token.strip():
logging.warning(f"User {user[0].email} has empty access_token in database")
return _load_user_from_session() if AUTH_JWT in auth_types else None
g.auth_type = AUTH_API
g.user = user[0]
return user[0]
logging.warning(f"load_user: No user found for tenant_id={objs[0].tenant_id} from APIToken")
else:
logging.warning(f"load_user: No APIToken found for token={auth_token[:10]}...")
except Exception as e_api_token:
logging.warning(f"load_user from api token got exception {e_api_token}")
return _load_user_from_session() if AUTH_JWT in auth_types else None
current_user = LocalProxy(_load_user)
def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
def login_required(func: Callable[P, Awaitable[T]] = None, auth_types=None) -> Callable[P, Awaitable[T]]:
"""A decorator to restrict route access to authenticated users.
This should be used to wrap a route handler (or view function) to
@@ -212,22 +254,32 @@ def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]
"""
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
timing_enabled = os.getenv("RAGFLOW_API_TIMING")
t_start = time.perf_counter() if timing_enabled else None
user = current_user
if timing_enabled:
logging.info(
"api_timing login_required auth_ms=%.2f path=%s",
(time.perf_counter() - t_start) * 1000,
request.path,
)
if not user: # or not session.get("_user_id"):
raise QuartAuthUnauthorized()
return await current_app.ensure_async(func)(*args, **kwargs)
def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
timing_enabled = os.getenv("RAGFLOW_API_TIMING")
t_start = time.perf_counter() if timing_enabled else None
user = _load_user(auth_types)
if timing_enabled:
logging.info(
"api_timing login_required auth_ms=%.2f path=%s",
(time.perf_counter() - t_start) * 1000,
request.path,
)
if not user: # or not session.get("_user_id"):
if _normalize_auth_types(auth_types) == {AUTH_BETA}:
return get_json_result(
code=RetCode.DATA_ERROR,
message=getattr(g, "auth_error_message", None) or "Authorization is not valid!",
)
raise QuartAuthUnauthorized()
return await current_app.ensure_async(func)(*args, **kwargs)
return wrapper
return wrapper
if func is None:
return decorator
return decorator(func)
def login_user(user, remember=False, duration=None, force=False, fresh=True):

View File

@@ -22,7 +22,7 @@ import logging
from quart import Response, request
from agent.canvas import Canvas
from api.db.db_models import APIToken
from api.apps import AUTH_BETA, login_required
from api.db.services.api_service import API4ConversationService
from api.db.services.canvas_service import UserCanvasService
from api.db.services.canvas_service import completion as agent_completion
@@ -37,7 +37,7 @@ from api.db.services.user_service import UserTenantService
from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance
from common.misc_utils import thread_pool_exec
from api.utils.api_utils import get_error_data_result, get_json_result, \
get_result, get_request_json, server_error_response, validate_request
add_tenant_id_to_kwargs, get_result, get_request_json, server_error_response, validate_request
from rag.app.tag import label_question
from rag.prompts.template import load_prompt
from rag.prompts.generator import cross_languages, keyword_extraction
@@ -51,24 +51,12 @@ from api.utils.reference_metadata_utils import (
logger = logging.getLogger(__name__)
def _get_sdk_authorization_token():
token = request.headers.get("Authorization", "").split()
if len(token) != 2:
return None
return token[1]
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
async def chatbot_completions(dialog_id):
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
async def chatbot_completions(dialog_id, tenant_id=None):
req = await get_request_json()
token = _get_sdk_authorization_token()
if not token:
return get_error_data_result(message='Authorization is not valid!')
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
tenant_id = objs[0].tenant_id
exists, dialog = DialogService.get_by_id(dialog_id)
if (not exists
or getattr(dialog, "tenant_id", None) != tenant_id
@@ -135,14 +123,9 @@ async def chatbot_completions(dialog_id):
return None
@manager.route("/chatbots/<dialog_id>/info", methods=["GET"]) # noqa: F821
async def chatbots_inputs(dialog_id):
token = _get_sdk_authorization_token()
if not token:
return get_error_data_result(message='Authorization is not valid!')
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
tenant_id = objs[0].tenant_id
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
async def chatbots_inputs(dialog_id, tenant_id=None):
exists, dialog = await thread_pool_exec(DialogService.get_by_id, dialog_id)
if (not exists
or getattr(dialog, "tenant_id", None) != tenant_id
@@ -170,20 +153,15 @@ async def chatbots_inputs(dialog_id):
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
async def agent_bot_completions(agent_id):
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
async def agent_bot_completions(agent_id, tenant_id=None):
req = await get_request_json()
token = _get_sdk_authorization_token()
if not token:
return get_error_data_result(message='Authorization is not valid!')
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
if req.get("stream", True):
async def stream():
try:
async for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
async for answer in agent_completion(tenant_id, agent_id, **req):
yield answer
except Exception as e:
logging.exception(e)
@@ -209,7 +187,7 @@ async def agent_bot_completions(agent_id):
reference = {}
structured_output = {}
final_ans = {}
async for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
async for answer in agent_completion(tenant_id, agent_id, **req):
# agent_completion yields SSE-formatted strings. A single yielded
# chunk can contain multiple "data:..." frames separated by "\n\n"
# plus blank or comment lines, so parse line-by-line rather than
@@ -257,36 +235,26 @@ async def agent_bot_completions(agent_id):
@manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821
async def begin_inputs(agent_id):
token = _get_sdk_authorization_token()
if not token:
return get_error_data_result(message='Authorization is not valid!')
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
async def begin_inputs(agent_id, tenant_id=None):
e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, agent_id)
if not e:
return get_error_data_result(f"Can't find agent by ID: {agent_id}")
canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id, canvas_id=cvs.id)
canvas = Canvas(json.dumps(cvs.dsl), tenant_id, canvas_id=cvs.id)
return get_result(
data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"),
"prologue": canvas.get_prologue(), "mode": canvas.get_mode()})
@manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
@validate_request("question", "kb_ids")
async def ask_about_embedded():
token = _get_sdk_authorization_token()
if not token:
return get_error_data_result(message='Authorization is not valid!')
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
async def ask_about_embedded(tenant_id=None):
req = await get_request_json()
uid = objs[0].tenant_id
uid = tenant_id
search_id = req.get("search_id", "")
search_config = {}
@@ -314,15 +282,10 @@ async def ask_about_embedded():
@manager.route("/searchbots/retrieval_test", methods=["POST"]) # noqa: F821
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
@validate_request("kb_id", "question")
async def retrieval_test_embedded():
token = _get_sdk_authorization_token()
if not token:
return get_error_data_result(message='Authorization is not valid!')
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
async def retrieval_test_embedded(tenant_id=None):
req = await get_request_json()
page = int(req.get("page", 1))
size = int(req.get("size", 30))
@@ -342,7 +305,6 @@ async def retrieval_test_embedded():
return get_error_data_result("`top_k` must be greater than 0")
langs = req.get("cross_languages", [])
rerank_id = req.get("rerank_id", "")
tenant_id = objs[0].tenant_id
if not tenant_id:
return get_error_data_result(message="permission denined.")
search_config = {}
@@ -456,17 +418,11 @@ async def retrieval_test_embedded():
@manager.route("/searchbots/related_questions", methods=["POST"]) # noqa: F821
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
@validate_request("question")
async def related_questions_embedded():
token = _get_sdk_authorization_token()
if not token:
return get_error_data_result(message='Authorization is not valid!')
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
async def related_questions_embedded(tenant_id=None):
req = await get_request_json()
tenant_id = objs[0].tenant_id
if not tenant_id:
return get_error_data_result(message="permission denined.")
@@ -504,16 +460,10 @@ Related search terms:
@manager.route("/searchbots/detail", methods=["GET"]) # noqa: F821
async def detail_share_embedded():
token = _get_sdk_authorization_token()
if not token:
return get_error_data_result(message='Authorization is not valid!')
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
async def detail_share_embedded(tenant_id=None):
search_id = request.args["search_id"]
tenant_id = objs[0].tenant_id
if not tenant_id:
return get_error_data_result(message="permission denined.")
try:
@@ -534,16 +484,10 @@ async def detail_share_embedded():
@manager.route("/searchbots/mindmap", methods=["POST"]) # noqa: F821
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
@validate_request("question", "kb_ids")
async def mindmap():
token = _get_sdk_authorization_token()
if not token:
return get_error_data_result(message='Authorization is not valid!')
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
tenant_id = objs[0].tenant_id
async def mindmap(tenant_id=None):
req = await get_request_json()
search_id = req.get("search_id", "")

View File

@@ -24,7 +24,7 @@ from quart import request, make_response,send_file
from peewee import OperationalError
from pydantic import ValidationError
from api.apps import login_required
from api.apps import AUTH_JWT, AUTH_API, AUTH_BETA, login_required
from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
from api.apps.services.document_api_service import validate_document_update_fields, map_doc_keys, \
map_doc_keys_with_run_status, update_document_name_only, update_chunk_method, update_document_status_only, \
@@ -1191,6 +1191,7 @@ async def update_metadata_config(tenant_id, dataset_id, document_id):
@manager.route("/thumbnails", methods=["GET"]) # noqa: F821
@login_required(auth_types=[AUTH_JWT, AUTH_API, AUTH_BETA])
def list_thumbnails():
"""
Get thumbnails for documents.
@@ -1687,6 +1688,7 @@ def _content_type_for_document_image(object_name, data):
@manager.route("/documents/images/<image_id>", methods=["GET"]) # noqa: F821
@login_required(auth_types=[AUTH_JWT, AUTH_API, AUTH_BETA])
async def get_document_image(image_id):
"""
Get a document image by ID.
@@ -1908,7 +1910,7 @@ async def batch_update_document_status(tenant_id, dataset_id):
return get_json_result(data=result)
@manager.route("/documents/<doc_id>/preview", methods=["GET"]) # noqa: F821
@login_required
@login_required(auth_types=[AUTH_JWT, AUTH_API, AUTH_BETA])
async def get(doc_id):
"""Return the raw file bytes for a document the requesting user is authorized to read.

View File

@@ -20,8 +20,8 @@ from test.testcases.restful_api.helpers.client import RestClient
@pytest.mark.p2
def test_document_image_invalid_id_contract(rest_client_noauth):
res = rest_client_noauth.get("/documents/images/not-a-valid-image-id")
def test_document_image_invalid_id_contract(rest_client):
res = rest_client.get("/documents/images/not-a-valid-image-id")
assert res.status_code == 200
payload = res.json()
assert payload["code"] == 102, payload

View File

@@ -390,7 +390,7 @@ def test_related_questions_contract(rest_client, rest_client_noauth):
assert success_payload["code"] == 0, success_payload
assert isinstance(success_payload["data"], list), success_payload
missing_res = rest_client.post("/searchbots/related_questions", json={"industry": "search"})
missing_res = success_client.post("/searchbots/related_questions", json={"industry": "search"})
assert missing_res.status_code == 200
missing_payload = missing_res.json()
assert missing_payload["code"] == 101, missing_payload
@@ -404,4 +404,8 @@ def test_related_questions_contract(rest_client, rest_client_noauth):
assert invalid_auth_res.status_code == 200
invalid_auth_payload = invalid_auth_res.json()
assert invalid_auth_payload["code"] == 102, invalid_auth_payload
assert "Authorization is not valid!" in invalid_auth_payload["message"], invalid_auth_payload
assert invalid_auth_payload["message"].strip() in {
"Authorization is not valid!",
'Authentication error: API key is invalid!"',
"Authentication error: API key is invalid!",
}, invalid_auth_payload

View File

@@ -611,7 +611,7 @@ def test_related_questions_compatibility_requires_auth(rest_client_noauth):
assert res.status_code == 200
payload = res.json()
assert payload["code"] == 102, payload
assert payload["message"] in {
assert payload["message"].strip() in {
"Authorization is not valid!",
'Authentication error: API key is invalid!"',
"Authentication error: API key is invalid!",