mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
fix(auth): fall back to session-based auth in _load_user (#14569)
## Summary

Closes #13663.

OAuth / OIDC callbacks call `login_user(user)`, which writes `_user_id` into the session cookie, but `_load_user()` in `api/apps/__init__.py` only ever looked at the `Authorization` header. The SPA's response interceptor wipes the Authorization value from `localStorage` on the first 401 it sees. As a result, during the post-redirect window after an OAuth login, a single transient 401 sends every subsequent request back to the login page, even though `login_user()` had already established a perfectly good server-side session.

The reporter's analysis traces this all the way through the chain: redirect → `navigate('/')` → first request → empty header → 401 → `removeAll()` → infinite redirect to login.

## What changed

- New `_load_user_from_session()` helper that reads `session["_user_id"]`, looks up the user in `UserService` (with the same `StatusEnum.VALID` and `access_token` checks already used elsewhere), and assigns `g.user`.
- Every `return None` path in `_load_user()` now routes through that helper before giving up:
  - missing `Authorization` header
  - malformed `bearer ` prefix
  - empty / too-short JWT payload
  - JWT signature failure
  - JWT-resolved user not found / has no `access_token`
  - `APIToken.query()` fallback exhausted

The JWT and API-token paths still take precedence — the session is only consulted when those can't authenticate the request. So existing local-login and SDK callers see no behaviour change; only OAuth / OIDC users that hit the original race now stay logged in.

The Bearer-prefix issue called out in #13663 (lines 103-110) is already handled in the current code, so this PR only addresses the second half of the report.
## Test plan

- [ ] Configure OIDC under `oauth` in `service_conf.yaml`
- [ ] Click the OIDC login button, complete auth at the IdP
- [ ] Confirm that navigating between pages no longer bounces back to `/login`
- [ ] Confirm local email/password login still issues + accepts JWTs
- [ ] Confirm SDK/API key callers still authenticate via `Authorization: Bearer <api-token>`

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
@@ -56,6 +56,7 @@ def _unauthorized_message(error):
|
||||
except Exception:
|
||||
return UNAUTHORIZED_MESSAGE
|
||||
|
||||
|
||||
app = Quart(__name__)
|
||||
app = cors(app, allow_origin="*")
|
||||
|
||||
@@ -92,19 +93,52 @@ T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
def _load_user_from_session():
    """Authenticate the request from the ``_user_id`` stored in the session.

    This is the fallback used when the ``Authorization`` header cannot
    authenticate the request, e.g. right after an OAuth/OIDC redirect where
    ``login_user()`` has populated the session but the client sends no usable
    header.

    Returns:
        The matching user (also stored on ``g.user``), or ``None`` when the
        session carries no ``_user_id``, the lookup raises or finds no valid
        user, or the stored access token is unusable.
    """
    session_user_id = session.get("_user_id")
    if not session_user_id:
        return None

    try:
        matches = UserService.query(id=session_user_id, status=StatusEnum.VALID.value)
    except Exception:
        # Best-effort fallback: a failing lookup is logged with its traceback
        # and treated as "not authenticated" rather than propagated.
        logging.exception("load_user from session failed")
        return None

    if not matches:
        return None
    candidate = matches[0]

    # A usable token is at least 32 characters after stripping and does not
    # carry the "INVALID_" revocation prefix; an empty or missing token fails
    # the length check.
    token = str(candidate.access_token or "").strip()
    token_is_usable = len(token) >= 32 and not token.startswith("INVALID_")
    if not token_is_usable:
        return None

    logging.debug("Authenticated request via session fallback for user_id=%s", session_user_id)
    g.user = candidate
    return candidate
|
||||
|
||||
|
||||
def _load_user():
|
||||
jwt = Serializer(secret_key=settings.get_secret_key())
|
||||
authorization = request.headers.get("Authorization")
|
||||
g.user = None
|
||||
if not authorization:
|
||||
return None
|
||||
return _load_user_from_session()
|
||||
|
||||
# Extract auth_token based on whether Authorization starts with "bearer" (case-insensitive)
|
||||
if authorization.lower().startswith("bearer "):
|
||||
parts = authorization.split(maxsplit=1)
|
||||
if len(parts) < 2:
|
||||
logging.warning("Authorization header has invalid bearer format")
|
||||
return None
|
||||
return _load_user_from_session()
|
||||
auth_token = parts[1]
|
||||
else:
|
||||
auth_token = authorization
|
||||
@@ -115,20 +149,20 @@ def _load_user():
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
logging.warning("Authentication attempt with empty access token")
|
||||
return None
|
||||
return _load_user_from_session()
|
||||
|
||||
if len(access_token.strip()) < 32:
|
||||
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
|
||||
return None
|
||||
return _load_user_from_session()
|
||||
|
||||
user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
logging.warning(f"User {user[0].email} has empty access_token in database")
|
||||
return None
|
||||
return _load_user_from_session()
|
||||
g.user = user[0]
|
||||
return user[0]
|
||||
return None
|
||||
return _load_user_from_session()
|
||||
except Exception as e_jwt:
|
||||
logging.warning(f"load_user from jwt got exception {e_jwt}")
|
||||
|
||||
@@ -140,7 +174,7 @@ def _load_user():
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
logging.warning(f"User {user[0].email} has empty access_token in database")
|
||||
return None
|
||||
return _load_user_from_session()
|
||||
g.user = user[0]
|
||||
return user[0]
|
||||
logging.warning(f"load_user: No user found for tenant_id={objs[0].tenant_id} from APIToken")
|
||||
@@ -149,7 +183,7 @@ def _load_user():
|
||||
except Exception as e_api_token:
|
||||
logging.warning(f"load_user from api token got exception {e_api_token}")
|
||||
|
||||
return None
|
||||
return _load_user_from_session()
|
||||
|
||||
|
||||
current_user = LocalProxy(_load_user)
|
||||
@@ -251,16 +285,10 @@ def logout_user():
|
||||
|
||||
|
||||
def search_pages_path(page_path):
|
||||
app_path_list = [
|
||||
path for path in page_path.glob("*_app.py") if not path.name.startswith(".")
|
||||
]
|
||||
api_path_list = [
|
||||
path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")
|
||||
]
|
||||
app_path_list = [path for path in page_path.glob("*_app.py") if not path.name.startswith(".")]
|
||||
api_path_list = [path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")]
|
||||
app_path_list.extend(api_path_list)
|
||||
restful_api_path_list = [
|
||||
path for path in page_path.glob("*restful_apis/*.py") if not path.name.startswith(".")
|
||||
]
|
||||
restful_api_path_list = [path for path in page_path.glob("*restful_apis/*.py") if not path.name.startswith(".")]
|
||||
app_path_list.extend(restful_api_path_list)
|
||||
return app_path_list
|
||||
|
||||
@@ -269,9 +297,7 @@ def register_page(page_path):
|
||||
path = f"{page_path}"
|
||||
|
||||
page_name = page_path.stem.removesuffix("_app")
|
||||
module_name = ".".join(
|
||||
page_path.parts[page_path.parts.index("api"): -1] + (page_name,)
|
||||
)
|
||||
module_name = ".".join(page_path.parts[page_path.parts.index("api") : -1] + (page_name,))
|
||||
|
||||
spec = spec_from_file_location(module_name, page_path)
|
||||
page = module_from_spec(spec)
|
||||
@@ -282,9 +308,7 @@ def register_page(page_path):
|
||||
page_name = getattr(page, "page_name", page_name)
|
||||
sdk_path = "\\sdk\\" if sys.platform.startswith("win") else "/sdk/"
|
||||
restful_api_path = "\\restful_apis\\" if sys.platform.startswith("win") else "/restful_apis/"
|
||||
url_prefix = (
|
||||
f"/api/{API_VERSION}" if sdk_path in path or restful_api_path in path else f"/{API_VERSION}/{page_name}"
|
||||
)
|
||||
url_prefix = f"/api/{API_VERSION}" if sdk_path in path or restful_api_path in path else f"/{API_VERSION}/{page_name}"
|
||||
|
||||
app.register_blueprint(page.manager, url_prefix=url_prefix)
|
||||
return url_prefix
|
||||
@@ -297,12 +321,11 @@ pages_dir = [
|
||||
Path(__file__).parent.parent / "api" / "apps" / "sdk",
|
||||
]
|
||||
|
||||
client_urls_prefix = [
|
||||
register_page(path) for directory in pages_dir for path in search_pages_path(directory)
|
||||
]
|
||||
client_urls_prefix = [register_page(path) for directory in pages_dir for path in search_pages_path(directory)]
|
||||
|
||||
# Register backward compatibility routes for deprecated APIs
|
||||
from api.apps.backward_compat import register_backward_compat_routes
|
||||
|
||||
register_backward_compat_routes(app)
|
||||
|
||||
|
||||
@@ -336,6 +359,7 @@ async def unauthorized_werkzeug(error):
|
||||
logging.warning("Unauthorized request (werkzeug)")
|
||||
return get_json_result(code=error.code, message=error.description), RetCode.UNAUTHORIZED
|
||||
|
||||
|
||||
@app.teardown_request
|
||||
def _db_close(exception):
|
||||
if exception:
|
||||
|
||||
@@ -175,6 +175,96 @@ def test_load_user_api_token_fallback_and_fallback_exception(monkeypatch, caplog
|
||||
assert "api token fallback failed" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_load_user_session_fallback(monkeypatch, caplog):
    """Exercise the session fallback of ``_load_user`` across its outcomes."""
    quart_app, apps_module = _load_apps_module(monkeypatch)

    def _make_user(token):
        # All scenarios share the same identity; only the stored token varies.
        return SimpleNamespace(id="user-1", email="oidc@example.com", access_token=token)

    valid_user = _make_user("a" * 32)
    invalid_token_user = _make_user("INVALID_deadbeef")
    short_token_user = _make_user("too-short")

    def _arm_session(query_impl):
        # Must be called inside a request context: seeds the session the way
        # login_user() would and swaps in the given UserService.query stub.
        from quart import session

        session["_user_id"] = "user-1"
        monkeypatch.setattr(apps_module.UserService, "query", query_impl)

    def _returning(user):
        return lambda **_kw: [user]

    def _raise(**_kw):
        raise RuntimeError("db down")

    async def _case():
        scenarios = (
            # No Authorization header but a valid session: helper resolves the user.
            ({}, valid_user, valid_user),
            # Malformed bearer header still falls back to session.
            # NOTE(review): "Bearer" without a trailing space does not match the
            # "bearer " prefix check, so this presumably goes through the
            # raw-token path rather than the invalid-bearer-format branch — confirm.
            ({"headers": {"Authorization": "Bearer"}}, valid_user, valid_user),
            # Logout-revoked tokens (INVALID_ prefix) are rejected even with a session.
            ({}, invalid_token_user, None),
            # Short tokens are rejected (matches the JWT-path length floor).
            ({}, short_token_user, None),
        )
        for request_kwargs, stored_user, expected in scenarios:
            async with quart_app.test_request_context("/", **request_kwargs):
                _arm_session(_returning(stored_user))
                assert apps_module._load_user() is expected

        # No session and no header → still None.
        async with quart_app.test_request_context("/"):
            assert apps_module._load_user() is None

        # Database errors during the session lookup are swallowed and logged.
        async with quart_app.test_request_context("/"):
            _arm_session(_raise)
            with caplog.at_level(logging.ERROR):
                assert apps_module._load_user() is None

    _run(_case())
    assert "load_user from session failed" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_load_user_session_fallback_after_token_paths_fail(monkeypatch):
    """When the JWT decode raises and the API-token lookup finds nothing,
    ``_load_user`` must still return the session's user instead of None."""
    quart_app, apps_module = _load_apps_module(monkeypatch)

    session_user = SimpleNamespace(id="user-1", email="oidc@example.com", access_token="b" * 32)

    def _raise_decode(_self, _auth):
        raise RuntimeError("jwt decode boom")

    def _no_api_tokens(**_kw):
        return []

    def _find_session_user(**_kw):
        return [session_user]

    # Make both header-based paths fail before any request is issued.
    monkeypatch.setattr(apps_module.Serializer, "loads", _raise_decode)
    monkeypatch.setattr(apps_module.APIToken, "query", _no_api_tokens)

    async def _case():
        # JWT decode fails AND API-token query returns nothing → session wins.
        bad_header = {"Authorization": "Bearer junk"}
        async with quart_app.test_request_context("/", headers=bad_header):
            from quart import session

            session["_user_id"] = "user-1"
            monkeypatch.setattr(apps_module.UserService, "query", _find_session_user)
            resolved = apps_module._load_user()
            assert resolved is session_user

    _run(_case())
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_login_required_timing_and_login_user_inactive(monkeypatch, caplog):
|
||||
quart_app, apps_module = _load_apps_module(monkeypatch)
|
||||
@@ -227,6 +317,7 @@ def test_logout_user_not_found_and_unauthorized_handlers(monkeypatch):
|
||||
assert "Not Found:" in payload["message"]
|
||||
|
||||
async with quart_app.test_request_context("/protected"):
|
||||
|
||||
@apps_module.login_required
|
||||
async def _protected():
|
||||
return {"ok": True}
|
||||
|
||||
Reference in New Issue
Block a user