mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Refactor: consolidate to use @login_required (#15652)
Refactor: consolidate to use @login_required
This commit is contained in:
@@ -19,7 +19,6 @@ import functools
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from copy import deepcopy
|
||||
@@ -32,7 +31,7 @@ from quart import (
|
||||
request,
|
||||
has_app_context,
|
||||
)
|
||||
from werkzeug.exceptions import BadRequest as WerkzeugBadRequest, Unauthorized as WerkzeugUnauthorized
|
||||
from werkzeug.exceptions import BadRequest as WerkzeugBadRequest
|
||||
|
||||
try:
|
||||
from quart.exceptions import BadRequest as QuartBadRequest
|
||||
@@ -42,7 +41,6 @@ except ImportError: # pragma: no cover - optional dependency
|
||||
from peewee import OperationalError
|
||||
|
||||
from common.constants import ActiveEnum, LLMType
|
||||
from api.db.db_models import APIToken
|
||||
from api.utils.json_encode import CustomJSONEncoder
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService
|
||||
@@ -252,28 +250,6 @@ def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=Non
|
||||
return _safe_jsonify(response)
|
||||
|
||||
|
||||
def apikey_required(func):
|
||||
@wraps(func)
|
||||
async def decorated_function(*args, **kwargs):
|
||||
authorization = request.headers.get("Authorization")
|
||||
if not authorization:
|
||||
return build_error_result(message="Authorization header is missing!", code=RetCode.FORBIDDEN)
|
||||
parts = authorization.split()
|
||||
if len(parts) < 2:
|
||||
return build_error_result(message="Please check your authorization format.", code=RetCode.FORBIDDEN)
|
||||
token = parts[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return build_error_result(message="API-KEY is invalid!", code=RetCode.FORBIDDEN)
|
||||
kwargs["tenant_id"] = objs[0].tenant_id
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
|
||||
|
||||
def build_error_result(code=RetCode.FORBIDDEN, message="success"):
|
||||
response = {"code": code, "message": message}
|
||||
response = _safe_jsonify(response)
|
||||
@@ -288,69 +264,6 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da
|
||||
return _safe_jsonify({"code": code, "message": message, "data": data})
|
||||
|
||||
|
||||
def token_required(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Validate the token (API Key)
|
||||
if os.environ.get("DISABLE_SDK"):
|
||||
err = WerkzeugUnauthorized(description="`Authorization` can't be empty")
|
||||
err.code = RetCode.SUCCESS
|
||||
raise err
|
||||
|
||||
authorization_str = request.headers.get("Authorization")
|
||||
if not authorization_str:
|
||||
err = WerkzeugUnauthorized(description="`Authorization` can't be empty")
|
||||
err.code = RetCode.SUCCESS
|
||||
raise err
|
||||
|
||||
authorization_list = authorization_str.split()
|
||||
if len(authorization_list) < 2:
|
||||
err = WerkzeugUnauthorized(description="Please check your authorization format.")
|
||||
err.code = RetCode.AUTHENTICATION_ERROR
|
||||
raise err
|
||||
|
||||
token = authorization_list[1]
|
||||
|
||||
# First try API token (explicit API token authentication)
|
||||
objs = APIToken.query(token=token)
|
||||
if objs:
|
||||
# On success, inject tenant_id into the route function's kwargs
|
||||
kwargs["tenant_id"] = objs[0].tenant_id
|
||||
result = func(*args, **kwargs)
|
||||
if inspect.iscoroutine(result):
|
||||
return await result
|
||||
return result
|
||||
|
||||
# Fallback: try login token (for clients that use login token as API token)
|
||||
# Login tokens are JWT-encoded (URLSafeTimedSerializer), need to decode to get raw access_token
|
||||
from api.db.services.user_service import UserService
|
||||
from common.constants import StatusEnum
|
||||
from common import settings
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
try:
|
||||
jwt = Serializer(secret_key=settings.get_secret_key())
|
||||
raw_token = str(jwt.loads(token))
|
||||
user = UserService.query(access_token=raw_token, status=StatusEnum.VALID.value)
|
||||
if user:
|
||||
# On success, inject tenant_id from user's tenant
|
||||
from api.db.services.user_service import UserTenantService
|
||||
tenants = UserTenantService.query(user_id=user[0].id)
|
||||
if tenants:
|
||||
kwargs["tenant_id"] = tenants[0].tenant_id
|
||||
result = func(*args, **kwargs)
|
||||
if inspect.iscoroutine(result):
|
||||
return await result
|
||||
return result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
err = WerkzeugUnauthorized(description="Authentication error: API key is invalid!")
|
||||
err.code = RetCode.AUTHENTICATION_ERROR
|
||||
raise err
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_result(code=RetCode.SUCCESS, message="", data=None, total=None):
|
||||
"""
|
||||
Standard API response format:
|
||||
|
||||
Reference in New Issue
Block a user