fix(auth): fall back to session-based auth in _load_user (#14569)

## Summary

Closes #13663.

OAuth / OIDC callbacks call `login_user(user)` which writes `_user_id`
into the session cookie, but `_load_user()` in `api/apps/__init__.py`
only ever looked at the `Authorization` header. The SPA's response
interceptor wipes the Authorization value from `localStorage` on the
first 401 it sees — meaning that during the post-redirect window after
an OAuth login, a single transient 401 sends every subsequent request
back to the login page even though `login_user()` had already
established a perfectly good server-side session.

The reporter's analysis traces this all the way through the redirect →
`navigate('/')` → first request → empty header → 401 → `removeAll()` →
infinite-redirect-to-login chain.

## What changed

- New `_load_user_from_session()` helper that reads
`session["_user_id"]`, looks up the user in `UserService` (with the same
`StatusEnum.VALID` and `access_token` checks already used elsewhere),
and assigns `g.user`.
- Every `return None` path in `_load_user()` now routes through that
helper before giving up:
  - missing `Authorization` header
  - malformed `bearer ` prefix
  - empty / too-short JWT payload
  - JWT signature failure
  - JWT-resolved user not found / has no `access_token`
  - `APIToken.query()` fallback exhausted

The JWT and API-token paths still take precedence — the session is only
consulted when those can't authenticate the request. So existing
local-login and SDK callers see no behaviour change; only OAuth / OIDC
users that hit the original race now stay logged in.

The Bearer-prefix issue called out in #13663 (lines 103-110) is already
handled in the current code, so this PR only addresses the second half
of the report.

## Test plan

- [ ] Configure OIDC under `oauth` in `service_conf.yaml`
- [ ] Click the OIDC login button, complete auth at the IdP
- [ ] Confirm that navigating between pages no longer bounces back to
`/login`
- [ ] Confirm local email/password login still issues + accepts JWTs
- [ ] Confirm SDK/API key callers still authenticate via `Authorization:
Bearer <api-token>`

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
Mehmet Karakose
2026-05-11 04:59:52 +03:00
committed by GitHub
parent 6cb4bc2947
commit 7ec87f7cb7
2 changed files with 141 additions and 26 deletions

View File

@@ -56,6 +56,7 @@ def _unauthorized_message(error):
except Exception:
return UNAUTHORIZED_MESSAGE
app = Quart(__name__)
app = cors(app, allow_origin="*")
@@ -92,19 +93,52 @@ T = TypeVar("T")
P = ParamSpec("P")
def _load_user_from_session():
"""Resolve the current user from the session cookie set by ``login_user()``.
OAuth/OIDC callbacks call ``login_user(user)`` which writes ``_user_id``
into the session. The frontend's response interceptor wipes the
Authorization header from localStorage on the first 401, so post-redirect
requests can arrive with no header at all — we still want to honour the
server-side session in that window.
The same access-token validity rules used by the JWT path are applied
here so that tokens revoked by ``logout`` (which rewrites the column to
``INVALID_<hex>``) or shortened by data corruption can't keep a stale
session authenticated.
"""
user_id = session.get("_user_id")
if not user_id:
return None
try:
users = UserService.query(id=user_id, status=StatusEnum.VALID.value)
except Exception:
logging.exception("load_user from session failed")
return None
if not users:
return None
user = users[0]
access_token = str(user.access_token or "").strip()
if not access_token or len(access_token) < 32 or access_token.startswith("INVALID_"):
return None
logging.debug("Authenticated request via session fallback for user_id=%s", user_id)
g.user = user
return user
def _load_user():
jwt = Serializer(secret_key=settings.get_secret_key())
authorization = request.headers.get("Authorization")
g.user = None
if not authorization:
return None
return _load_user_from_session()
# Extract auth_token based on whether Authorization starts with "bearer" (case-insensitive)
if authorization.lower().startswith("bearer "):
parts = authorization.split(maxsplit=1)
if len(parts) < 2:
logging.warning("Authorization header has invalid bearer format")
return None
return _load_user_from_session()
auth_token = parts[1]
else:
auth_token = authorization
@@ -115,20 +149,20 @@ def _load_user():
if not access_token or not access_token.strip():
logging.warning("Authentication attempt with empty access token")
return None
return _load_user_from_session()
if len(access_token.strip()) < 32:
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
return None
return _load_user_from_session()
user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value)
if user:
if not user[0].access_token or not user[0].access_token.strip():
logging.warning(f"User {user[0].email} has empty access_token in database")
return None
return _load_user_from_session()
g.user = user[0]
return user[0]
return None
return _load_user_from_session()
except Exception as e_jwt:
logging.warning(f"load_user from jwt got exception {e_jwt}")
@@ -140,7 +174,7 @@ def _load_user():
if user:
if not user[0].access_token or not user[0].access_token.strip():
logging.warning(f"User {user[0].email} has empty access_token in database")
return None
return _load_user_from_session()
g.user = user[0]
return user[0]
logging.warning(f"load_user: No user found for tenant_id={objs[0].tenant_id} from APIToken")
@@ -149,7 +183,7 @@ def _load_user():
except Exception as e_api_token:
logging.warning(f"load_user from api token got exception {e_api_token}")
return None
return _load_user_from_session()
current_user = LocalProxy(_load_user)
@@ -251,16 +285,10 @@ def logout_user():
def search_pages_path(page_path):
app_path_list = [
path for path in page_path.glob("*_app.py") if not path.name.startswith(".")
]
api_path_list = [
path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")
]
app_path_list = [path for path in page_path.glob("*_app.py") if not path.name.startswith(".")]
api_path_list = [path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")]
app_path_list.extend(api_path_list)
restful_api_path_list = [
path for path in page_path.glob("*restful_apis/*.py") if not path.name.startswith(".")
]
restful_api_path_list = [path for path in page_path.glob("*restful_apis/*.py") if not path.name.startswith(".")]
app_path_list.extend(restful_api_path_list)
return app_path_list
@@ -269,9 +297,7 @@ def register_page(page_path):
path = f"{page_path}"
page_name = page_path.stem.removesuffix("_app")
module_name = ".".join(
page_path.parts[page_path.parts.index("api"): -1] + (page_name,)
)
module_name = ".".join(page_path.parts[page_path.parts.index("api") : -1] + (page_name,))
spec = spec_from_file_location(module_name, page_path)
page = module_from_spec(spec)
@@ -282,9 +308,7 @@ def register_page(page_path):
page_name = getattr(page, "page_name", page_name)
sdk_path = "\\sdk\\" if sys.platform.startswith("win") else "/sdk/"
restful_api_path = "\\restful_apis\\" if sys.platform.startswith("win") else "/restful_apis/"
url_prefix = (
f"/api/{API_VERSION}" if sdk_path in path or restful_api_path in path else f"/{API_VERSION}/{page_name}"
)
url_prefix = f"/api/{API_VERSION}" if sdk_path in path or restful_api_path in path else f"/{API_VERSION}/{page_name}"
app.register_blueprint(page.manager, url_prefix=url_prefix)
return url_prefix
@@ -297,12 +321,11 @@ pages_dir = [
Path(__file__).parent.parent / "api" / "apps" / "sdk",
]
client_urls_prefix = [
register_page(path) for directory in pages_dir for path in search_pages_path(directory)
]
client_urls_prefix = [register_page(path) for directory in pages_dir for path in search_pages_path(directory)]
# Register backward compatibility routes for deprecated APIs
from api.apps.backward_compat import register_backward_compat_routes
register_backward_compat_routes(app)
@@ -336,6 +359,7 @@ async def unauthorized_werkzeug(error):
logging.warning("Unauthorized request (werkzeug)")
return get_json_result(code=error.code, message=error.description), RetCode.UNAUTHORIZED
@app.teardown_request
def _db_close(exception):
if exception:

View File

@@ -175,6 +175,96 @@ def test_load_user_api_token_fallback_and_fallback_exception(monkeypatch, caplog
assert "api token fallback failed" in caplog.text
@pytest.mark.p2
def test_load_user_session_fallback(monkeypatch, caplog):
quart_app, apps_module = _load_apps_module(monkeypatch)
valid_token = "a" * 32
valid_user = SimpleNamespace(id="user-1", email="oidc@example.com", access_token=valid_token)
invalid_token_user = SimpleNamespace(id="user-1", email="oidc@example.com", access_token="INVALID_deadbeef")
short_token_user = SimpleNamespace(id="user-1", email="oidc@example.com", access_token="too-short")
async def _case():
# No Authorization header but a valid session: helper resolves the user.
async with quart_app.test_request_context("/"):
from quart import session
session["_user_id"] = "user-1"
monkeypatch.setattr(apps_module.UserService, "query", lambda **_kw: [valid_user])
assert apps_module._load_user() is valid_user
# Malformed bearer header still falls back to session.
async with quart_app.test_request_context("/", headers={"Authorization": "Bearer"}):
from quart import session
session["_user_id"] = "user-1"
monkeypatch.setattr(apps_module.UserService, "query", lambda **_kw: [valid_user])
assert apps_module._load_user() is valid_user
# Logout-revoked tokens (INVALID_ prefix) are rejected even with a session.
async with quart_app.test_request_context("/"):
from quart import session
session["_user_id"] = "user-1"
monkeypatch.setattr(apps_module.UserService, "query", lambda **_kw: [invalid_token_user])
assert apps_module._load_user() is None
# Short tokens are rejected (matches the JWT-path length floor).
async with quart_app.test_request_context("/"):
from quart import session
session["_user_id"] = "user-1"
monkeypatch.setattr(apps_module.UserService, "query", lambda **_kw: [short_token_user])
assert apps_module._load_user() is None
# No session and no header → still None.
async with quart_app.test_request_context("/"):
assert apps_module._load_user() is None
# Database errors during the session lookup are swallowed and logged.
async with quart_app.test_request_context("/"):
from quart import session
session["_user_id"] = "user-1"
def _raise(**_kw):
raise RuntimeError("db down")
monkeypatch.setattr(apps_module.UserService, "query", _raise)
with caplog.at_level(logging.ERROR):
assert apps_module._load_user() is None
_run(_case())
assert "load_user from session failed" in caplog.text
@pytest.mark.p2
def test_load_user_session_fallback_after_token_paths_fail(monkeypatch):
"""JWT-decode failures and API-token exhaustion must still fall through
to the session and return the user, not None."""
quart_app, apps_module = _load_apps_module(monkeypatch)
valid_token = "b" * 32
valid_user = SimpleNamespace(id="user-1", email="oidc@example.com", access_token=valid_token)
def _raise_decode(_self, _auth):
raise RuntimeError("jwt decode boom")
monkeypatch.setattr(apps_module.Serializer, "loads", _raise_decode)
monkeypatch.setattr(apps_module.APIToken, "query", lambda **_kw: [])
async def _case():
# JWT decode fails AND API-token query returns nothing → session wins.
async with quart_app.test_request_context("/", headers={"Authorization": "Bearer junk"}):
from quart import session
session["_user_id"] = "user-1"
monkeypatch.setattr(apps_module.UserService, "query", lambda **_kw: [valid_user])
assert apps_module._load_user() is valid_user
_run(_case())
@pytest.mark.p2
def test_login_required_timing_and_login_user_inactive(monkeypatch, caplog):
quart_app, apps_module = _load_apps_module(monkeypatch)
@@ -227,6 +317,7 @@ def test_logout_user_not_found_and_unauthorized_handlers(monkeypatch):
assert "Not Found:" in payload["message"]
async with quart_app.test_request_context("/protected"):
@apps_module.login_required
async def _protected():
return {"ok": True}