diff --git a/api/apps/__init__.py b/api/apps/__init__.py index e05bbb03d4..e26b2c39af 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -56,6 +56,7 @@ def _unauthorized_message(error): except Exception: return UNAUTHORIZED_MESSAGE + app = Quart(__name__) app = cors(app, allow_origin="*") @@ -92,19 +93,52 @@ T = TypeVar("T") P = ParamSpec("P") +def _load_user_from_session(): + """Resolve the current user from the session cookie set by ``login_user()``. + + OAuth/OIDC callbacks call ``login_user(user)`` which writes ``_user_id`` + into the session. The frontend's response interceptor wipes the + Authorization header from localStorage on the first 401, so post-redirect + requests can arrive with no header at all — we still want to honour the + server-side session in that window. + + The same access-token validity rules used by the JWT path are applied + here so that tokens revoked by ``logout`` (which rewrites the column to + ``INVALID_``) or shortened by data corruption can't keep a stale + session authenticated. + """ + user_id = session.get("_user_id") + if not user_id: + return None + try: + users = UserService.query(id=user_id, status=StatusEnum.VALID.value) + except Exception: + logging.exception("load_user from session failed") + return None + if not users: + return None + user = users[0] + access_token = str(user.access_token or "").strip() + if not access_token or len(access_token) < 32 or access_token.startswith("INVALID_"): + return None + logging.debug("Authenticated request via session fallback for user_id=%s", user_id) + g.user = user + return user + + def _load_user(): jwt = Serializer(secret_key=settings.get_secret_key()) authorization = request.headers.get("Authorization") g.user = None if not authorization: - return None + return _load_user_from_session() # Extract auth_token based on whether Authorization starts with "bearer" (case-insensitive) if authorization.lower().startswith("bearer "): parts = authorization.split(maxsplit=1) if len(parts) < 2: logging.warning("Authorization header has invalid bearer format") - return None + return _load_user_from_session() auth_token = parts[1] else: auth_token = authorization @@ -115,20 +149,20 @@ def _load_user(): if not access_token or not access_token.strip(): logging.warning("Authentication attempt with empty access token") - return None + return _load_user_from_session() if len(access_token.strip()) < 32: logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars") - return None + return _load_user_from_session() user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value) if user: if not user[0].access_token or not user[0].access_token.strip(): logging.warning(f"User {user[0].email} has empty access_token in database") - return None + return _load_user_from_session() g.user = user[0] return user[0] - return None + return _load_user_from_session() except Exception as e_jwt: logging.warning(f"load_user from jwt got exception {e_jwt}") @@ -140,7 +174,7 @@ def _load_user(): if user: if not user[0].access_token or not user[0].access_token.strip(): logging.warning(f"User {user[0].email} has empty access_token in database") - return None + return _load_user_from_session() g.user = user[0] return user[0] logging.warning(f"load_user: No user found for tenant_id={objs[0].tenant_id} from APIToken") @@ -149,7 +183,7 @@ def _load_user(): except Exception as e_api_token: logging.warning(f"load_user from api token got exception {e_api_token}") - return None + return _load_user_from_session() current_user = LocalProxy(_load_user) @@ -251,16 +285,10 @@ def logout_user(): def search_pages_path(page_path): - app_path_list = [ - path for path in page_path.glob("*_app.py") if not path.name.startswith(".") - ] - api_path_list = [ - path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".") - ] + app_path_list = [path for path in page_path.glob("*_app.py") if not path.name.startswith(".")] + api_path_list = [path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")] app_path_list.extend(api_path_list) - restful_api_path_list = [ - path for path in page_path.glob("*restful_apis/*.py") if not path.name.startswith(".") - ] + restful_api_path_list = [path for path in page_path.glob("*restful_apis/*.py") if not path.name.startswith(".")] app_path_list.extend(restful_api_path_list) return app_path_list @@ -269,9 +297,7 @@ def register_page(page_path): path = f"{page_path}" page_name = page_path.stem.removesuffix("_app") - module_name = ".".join( - page_path.parts[page_path.parts.index("api"): -1] + (page_name,) - ) + module_name = ".".join(page_path.parts[page_path.parts.index("api") : -1] + (page_name,)) spec = spec_from_file_location(module_name, page_path) page = module_from_spec(spec) @@ -282,9 +308,7 @@ def register_page(page_path): page_name = getattr(page, "page_name", page_name) sdk_path = "\\sdk\\" if sys.platform.startswith("win") else "/sdk/" restful_api_path = "\\restful_apis\\" if sys.platform.startswith("win") else "/restful_apis/" - url_prefix = ( - f"/api/{API_VERSION}" if sdk_path in path or restful_api_path in path else f"/{API_VERSION}/{page_name}" - ) + url_prefix = f"/api/{API_VERSION}" if sdk_path in path or restful_api_path in path else f"/{API_VERSION}/{page_name}" app.register_blueprint(page.manager, url_prefix=url_prefix) return url_prefix @@ -297,12 +321,11 @@ pages_dir = [ Path(__file__).parent.parent / "api" / "apps" / "sdk", ] -client_urls_prefix = [ - register_page(path) for directory in pages_dir for path in search_pages_path(directory) -] +client_urls_prefix = [register_page(path) for directory in pages_dir for path in search_pages_path(directory)] # Register backward compatibility routes for deprecated APIs from api.apps.backward_compat import register_backward_compat_routes + register_backward_compat_routes(app) @@ -336,6 +359,7 @@ async def unauthorized_werkzeug(error): logging.warning("Unauthorized request (werkzeug)") return get_json_result(code=error.code, message=error.description), RetCode.UNAUTHORIZED + @app.teardown_request def _db_close(exception): if exception: diff --git a/test/testcases/test_web_api/test_system_app/test_apps_init_unit.py b/test/testcases/test_web_api/test_system_app/test_apps_init_unit.py index e183100cd3..c7d951270a 100644 --- a/test/testcases/test_web_api/test_system_app/test_apps_init_unit.py +++ b/test/testcases/test_web_api/test_system_app/test_apps_init_unit.py @@ -175,6 +175,96 @@ def test_load_user_api_token_fallback_and_fallback_exception(monkeypatch, caplog assert "api token fallback failed" in caplog.text +@pytest.mark.p2 +def test_load_user_session_fallback(monkeypatch, caplog): + quart_app, apps_module = _load_apps_module(monkeypatch) + + valid_token = "a" * 32 + valid_user = SimpleNamespace(id="user-1", email="oidc@example.com", access_token=valid_token) + invalid_token_user = SimpleNamespace(id="user-1", email="oidc@example.com", access_token="INVALID_deadbeef") + short_token_user = SimpleNamespace(id="user-1", email="oidc@example.com", access_token="too-short") + + async def _case(): + # No Authorization header but a valid session: helper resolves the user. + async with quart_app.test_request_context("/"): + from quart import session + + session["_user_id"] = "user-1" + monkeypatch.setattr(apps_module.UserService, "query", lambda **_kw: [valid_user]) + assert apps_module._load_user() is valid_user + + # Malformed bearer header still falls back to session. + async with quart_app.test_request_context("/", headers={"Authorization": "Bearer"}): + from quart import session + + session["_user_id"] = "user-1" + monkeypatch.setattr(apps_module.UserService, "query", lambda **_kw: [valid_user]) + assert apps_module._load_user() is valid_user + + # Logout-revoked tokens (INVALID_ prefix) are rejected even with a session. + async with quart_app.test_request_context("/"): + from quart import session + + session["_user_id"] = "user-1" + monkeypatch.setattr(apps_module.UserService, "query", lambda **_kw: [invalid_token_user]) + assert apps_module._load_user() is None + + # Short tokens are rejected (matches the JWT-path length floor). + async with quart_app.test_request_context("/"): + from quart import session + + session["_user_id"] = "user-1" + monkeypatch.setattr(apps_module.UserService, "query", lambda **_kw: [short_token_user]) + assert apps_module._load_user() is None + + # No session and no header → still None. + async with quart_app.test_request_context("/"): + assert apps_module._load_user() is None + + # Database errors during the session lookup are swallowed and logged. + async with quart_app.test_request_context("/"): + from quart import session + + session["_user_id"] = "user-1" + + def _raise(**_kw): + raise RuntimeError("db down") + + monkeypatch.setattr(apps_module.UserService, "query", _raise) + with caplog.at_level(logging.ERROR): + assert apps_module._load_user() is None + + _run(_case()) + assert "load_user from session failed" in caplog.text + + +@pytest.mark.p2 +def test_load_user_session_fallback_after_token_paths_fail(monkeypatch): + """JWT-decode failures and API-token exhaustion must still fall through + to the session and return the user, not None.""" + quart_app, apps_module = _load_apps_module(monkeypatch) + + valid_token = "b" * 32 + valid_user = SimpleNamespace(id="user-1", email="oidc@example.com", access_token=valid_token) + + def _raise_decode(_self, _auth): + raise RuntimeError("jwt decode boom") + + monkeypatch.setattr(apps_module.Serializer, "loads", _raise_decode) + monkeypatch.setattr(apps_module.APIToken, "query", lambda **_kw: []) + + async def _case(): + # JWT decode fails AND API-token query returns nothing → session wins. + async with quart_app.test_request_context("/", headers={"Authorization": "Bearer junk"}): + from quart import session + + session["_user_id"] = "user-1" + monkeypatch.setattr(apps_module.UserService, "query", lambda **_kw: [valid_user]) + assert apps_module._load_user() is valid_user + + _run(_case()) + + @pytest.mark.p2 def test_login_required_timing_and_login_user_inactive(monkeypatch, caplog): quart_app, apps_module = _load_apps_module(monkeypatch) @@ -227,6 +317,7 @@ def test_logout_user_not_found_and_unauthorized_handlers(monkeypatch): assert "Not Found:" in payload["message"] async with quart_app.test_request_context("/protected"): + @apps_module.login_required async def _protected(): return {"ok": True}