mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
fix(auth): return HTTP 401 for token-auth failures (#13420)
Follow-up to #12488 #13386 ### What problem does this PR solve? Previously, token authentication failures returned HTTP 200 with an error code in the response body. This PR updates `token_required` to raise `Unauthorized` and relies on the global error handler to return a structured JSON response with HTTP 401 status. The response body structure (`code`, `message`, `data`) remains unchanged to preserve compatibility with the official SDK. Frontend logic has been updated to handle HTTP 401 responses in addition to checking `data.code`. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -46,15 +46,15 @@ UNAUTHORIZED_MESSAGE = "<Unauthorized '401: Unauthorized'>"
|
||||
def _unauthorized_message(error):
|
||||
if error is None:
|
||||
return UNAUTHORIZED_MESSAGE
|
||||
|
||||
description = getattr(error, "description", None)
|
||||
if description:
|
||||
return description
|
||||
|
||||
try:
|
||||
msg = repr(error)
|
||||
return repr(error)
|
||||
except Exception:
|
||||
return UNAUTHORIZED_MESSAGE
|
||||
if msg == UNAUTHORIZED_MESSAGE:
|
||||
return msg
|
||||
if "Unauthorized" in msg and "401" in msg:
|
||||
return msg
|
||||
return UNAUTHORIZED_MESSAGE
|
||||
|
||||
app = Quart(__name__)
|
||||
app = cors(app, allow_origin="*")
|
||||
@@ -316,7 +316,7 @@ async def unauthorized_quart_auth(error):
|
||||
@app.errorhandler(WerkzeugUnauthorized)
|
||||
async def unauthorized_werkzeug(error):
|
||||
logging.warning("Unauthorized request (werkzeug)")
|
||||
return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED
|
||||
return get_json_result(code=error.code, message=error.description), RetCode.UNAUTHORIZED
|
||||
|
||||
@app.teardown_request
|
||||
def _db_close(exception):
|
||||
|
||||
@@ -33,7 +33,7 @@ from quart import (
|
||||
request,
|
||||
has_app_context,
|
||||
)
|
||||
from werkzeug.exceptions import BadRequest as WerkzeugBadRequest
|
||||
from werkzeug.exceptions import BadRequest as WerkzeugBadRequest, Unauthorized as WerkzeugUnauthorized
|
||||
|
||||
try:
|
||||
from quart.exceptions import BadRequest as QuartBadRequest
|
||||
@@ -270,39 +270,41 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da
|
||||
|
||||
|
||||
def token_required(func):
|
||||
def get_tenant_id(**kwargs):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Validate the token (API Key)
|
||||
if os.environ.get("DISABLE_SDK"):
|
||||
return False, get_json_result(data=False, message="`Authorization` can't be empty")
|
||||
err = WerkzeugUnauthorized(description="`Authorization` can't be empty")
|
||||
err.code = RetCode.SUCCESS
|
||||
raise err
|
||||
|
||||
authorization_str = request.headers.get("Authorization")
|
||||
if not authorization_str:
|
||||
return False, get_json_result(data=False, message="`Authorization` can't be empty")
|
||||
err = WerkzeugUnauthorized(description="`Authorization` can't be empty")
|
||||
err.code = RetCode.SUCCESS
|
||||
raise err
|
||||
|
||||
authorization_list = authorization_str.split()
|
||||
if len(authorization_list) < 2:
|
||||
return False, get_json_result(data=False, message="Please check your authorization format.")
|
||||
err = WerkzeugUnauthorized(description="Please check your authorization format.")
|
||||
err.code = RetCode.AUTHENTICATION_ERROR
|
||||
raise err
|
||||
|
||||
token = authorization_list[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return False, get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR)
|
||||
err = WerkzeugUnauthorized(description="Authentication error: API key is invalid!")
|
||||
err.code = RetCode.AUTHENTICATION_ERROR
|
||||
raise err
|
||||
|
||||
# On success, inject tenant_id into the route function's kwargs
|
||||
kwargs["tenant_id"] = objs[0].tenant_id
|
||||
return True, kwargs
|
||||
result = func(*args, **kwargs)
|
||||
if inspect.iscoroutine(result):
|
||||
return await result
|
||||
return result
|
||||
|
||||
@wraps(func)
|
||||
def decorated_function(*args, **kwargs):
|
||||
e, kwargs = get_tenant_id(**kwargs)
|
||||
if not e:
|
||||
return kwargs
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
async def adecorated_function(*args, **kwargs):
|
||||
e, kwargs = get_tenant_id(**kwargs)
|
||||
if not e:
|
||||
return kwargs
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return adecorated_function
|
||||
return decorated_function
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_result(code=RetCode.SUCCESS, message="", data=None, total=None):
|
||||
|
||||
@@ -116,7 +116,8 @@ def download_document(auth, dataset_id, document_id, save_path):
|
||||
url = f"{HOST_ADDRESS}{FILE_API_URL}/{document_id}".format(dataset_id=dataset_id)
|
||||
res = requests.get(url=url, auth=auth, stream=True)
|
||||
try:
|
||||
if res.status_code == 200:
|
||||
# available for unauthed downloads
|
||||
if res.status_code in (200, 401):
|
||||
with open(save_path, "wb") as f:
|
||||
for chunk in res.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
@@ -40,7 +40,7 @@ class TestAuthorization:
|
||||
)
|
||||
def test_invalid_auth(self, invalid_auth, tmp_path, expected_code, expected_message):
|
||||
res = download_document(invalid_auth, "dataset_id", "document_id", tmp_path / "ragflow_tes.txt")
|
||||
assert res.status_code == codes.ok
|
||||
assert res.status_code == 401
|
||||
with (tmp_path / "ragflow_tes.txt").open("r") as f:
|
||||
response_json = json.load(f)
|
||||
assert response_json["code"] == expected_code
|
||||
|
||||
@@ -108,15 +108,14 @@ def test_module_init_and_unauthorized_message_variants(monkeypatch):
|
||||
def __repr__(self):
|
||||
return "Unauthorized 401 from upstream"
|
||||
|
||||
class _OtherRepr:
|
||||
def __repr__(self):
|
||||
return "Forbidden 403"
|
||||
class _WithDescription:
|
||||
description = "Custom description"
|
||||
|
||||
assert apps_module._unauthorized_message(None) == apps_module.UNAUTHORIZED_MESSAGE
|
||||
assert apps_module._unauthorized_message(_BrokenRepr()) == apps_module.UNAUTHORIZED_MESSAGE
|
||||
assert apps_module._unauthorized_message(_ExactUnauthorizedRepr()) == apps_module.UNAUTHORIZED_MESSAGE
|
||||
assert apps_module._unauthorized_message(_Unauthorized401Repr()) == "Unauthorized 401 from upstream"
|
||||
assert apps_module._unauthorized_message(_OtherRepr()) == apps_module.UNAUTHORIZED_MESSAGE
|
||||
assert apps_module._unauthorized_message(_WithDescription()) == "Custom description"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
|
||||
@@ -73,6 +73,9 @@ const errorHandler = (error: {
|
||||
return response ?? { data: { code: 1999 } };
|
||||
};
|
||||
|
||||
// avoid duplicate 401 redirects
|
||||
let isRedirecting = false;
|
||||
|
||||
const request = axios.create({
|
||||
// errorHandler,
|
||||
timeout: 300000,
|
||||
@@ -123,13 +126,16 @@ request.interceptors.response.use(
|
||||
if (data?.code === 100) {
|
||||
message.error(data?.message);
|
||||
} else if (data?.code === 401) {
|
||||
notification.error({
|
||||
message: data?.message,
|
||||
description: data?.message,
|
||||
duration: 3,
|
||||
});
|
||||
authorizationUtil.removeAll();
|
||||
redirectToLogin();
|
||||
if (!isRedirecting) {
|
||||
isRedirecting = true;
|
||||
notification.error({
|
||||
message: data?.message,
|
||||
description: data?.message,
|
||||
duration: 3,
|
||||
});
|
||||
authorizationUtil.removeAll();
|
||||
redirectToLogin();
|
||||
}
|
||||
} else if (data?.code !== 0) {
|
||||
notification.error({
|
||||
message: `${i18n.t('message.hint')} : ${data?.code}`,
|
||||
@@ -141,6 +147,26 @@ request.interceptors.response.use(
|
||||
},
|
||||
function (error) {
|
||||
console.log('🚀 ~ error:', error);
|
||||
|
||||
// Handle HTTP 401 (token expired / invalid)
|
||||
const status = error?.response?.status;
|
||||
if (status === 401) {
|
||||
if (!isRedirecting) {
|
||||
isRedirecting = true;
|
||||
const messageText =
|
||||
error?.response?.data?.message || RetcodeMessage[401];
|
||||
notification.error({
|
||||
message: messageText,
|
||||
description: messageText,
|
||||
duration: 3,
|
||||
});
|
||||
authorizationUtil.removeAll();
|
||||
redirectToLogin();
|
||||
}
|
||||
|
||||
return Promise.reject(error);
|
||||
}
|
||||
|
||||
errorHandler(error);
|
||||
return Promise.reject(error);
|
||||
},
|
||||
|
||||
@@ -80,6 +80,9 @@ const request: RequestMethod = extend({
|
||||
getResponse: true,
|
||||
});
|
||||
|
||||
// avoid duplicate 401 redirects
|
||||
let isRedirecting = false;
|
||||
|
||||
request.interceptors.request.use((url: string, options: any) => {
|
||||
const data = convertTheKeysOfTheObjectToSnake(options.data);
|
||||
const params = convertTheKeysOfTheObjectToSnake(options.params);
|
||||
@@ -109,6 +112,27 @@ request.interceptors.response.use(async (response: Response, options) => {
|
||||
message.error(RetcodeMessage[response?.status as ResultCode]);
|
||||
}
|
||||
|
||||
// Handle HTTP 401
|
||||
if (response?.status === 401) {
|
||||
if (!isRedirecting) {
|
||||
isRedirecting = true;
|
||||
|
||||
const data = await response.clone().json().catch(() => ({}));
|
||||
|
||||
const messageText =
|
||||
data?.message || RetcodeMessage[401];
|
||||
notification.error({
|
||||
message: messageText,
|
||||
description: messageText,
|
||||
duration: 3,
|
||||
});
|
||||
authorizationUtil.removeAll();
|
||||
redirectToLogin();
|
||||
}
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
if (options.responseType === 'blob') {
|
||||
return response;
|
||||
}
|
||||
@@ -126,11 +150,16 @@ request.interceptors.response.use(async (response: Response, options) => {
|
||||
if (data?.code === 100) {
|
||||
message.error(data?.message);
|
||||
} else if (data?.code === 401) {
|
||||
notification.error({
|
||||
message: data?.message,
|
||||
description: data?.message,
|
||||
duration: 3,
|
||||
});
|
||||
if (!isRedirecting) {
|
||||
isRedirecting = true;
|
||||
notification.error({
|
||||
message: data?.message,
|
||||
description: data?.message,
|
||||
duration: 3,
|
||||
});
|
||||
authorizationUtil.removeAll();
|
||||
redirectToLogin();
|
||||
}
|
||||
authorizationUtil.removeAll();
|
||||
redirectToLogin();
|
||||
} else if (data?.code !== 0) {
|
||||
|
||||
Reference in New Issue
Block a user