diff --git a/api/apps/restful_apis/connector_api.py b/api/apps/restful_apis/connector_api.py index 8e9403fcd7..99a5893021 100644 --- a/api/apps/restful_apis/connector_api.py +++ b/api/apps/restful_apis/connector_api.py @@ -172,6 +172,22 @@ def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]: return {"web": web_section} +def _exchange_google_web_oauth_code( + client_config: dict[str, Any], + scopes: list[str], + redirect_uri: str, + code: str, + code_verifier: str | None, +) -> Flow: + flow = Flow.from_client_config(client_config, scopes=scopes) + flow.redirect_uri = redirect_uri + fetch_token_kwargs: dict[str, Any] = {"code": code} + if code_verifier: + fetch_token_kwargs["code_verifier"] = code_verifier + flow.fetch_token(**fetch_token_kwargs) + return flow + + async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, source="drive"): status = "success" if success else "error" auto_close = "window.close();" if success else "" @@ -267,6 +283,7 @@ async def start_google_web_oauth(): "user_id": current_user.id, "client_config": client_config, "redirect_uri": redirect_uri, + "code_verifier": flow.code_verifier, "created_at": int(time.time()), } REDIS_CONN.set_obj(_web_state_cache_key(flow_id, source), cache_payload, WEB_FLOW_TTL_SECS) @@ -298,6 +315,7 @@ async def google_gmail_web_oauth_callback(): state_obj = json.loads(state_cache) client_config = state_obj.get("client_config") redirect_uri = state_obj.get("redirect_uri", GMAIL_WEB_OAUTH_REDIRECT_URI) + code_verifier = state_obj.get("code_verifier") if not client_config: REDIS_CONN.delete(_web_state_cache_key(state_id, source)) return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source) @@ -311,10 +329,13 @@ async def google_gmail_web_oauth_callback(): return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source) try: - # TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail) - flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GMAIL]) - flow.redirect_uri = redirect_uri - flow.fetch_token(code=code) + flow = _exchange_google_web_oauth_code( + client_config=client_config, + scopes=GOOGLE_SCOPES[DocumentSource.GMAIL], + redirect_uri=redirect_uri, + code=code, + code_verifier=code_verifier, + ) except Exception as exc: # pragma: no cover - defensive logging.exception("Failed to exchange Google OAuth code: %s", exc) REDIS_CONN.delete(_web_state_cache_key(state_id, source)) @@ -349,6 +370,7 @@ async def google_drive_web_oauth_callback(): state_obj = json.loads(state_cache) client_config = state_obj.get("client_config") redirect_uri = state_obj.get("redirect_uri", GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI) + code_verifier = state_obj.get("code_verifier") if not client_config: REDIS_CONN.delete(_web_state_cache_key(state_id, source)) return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source) @@ -362,10 +384,13 @@ async def google_drive_web_oauth_callback(): return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source) try: - # TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail) - flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE]) - flow.redirect_uri = redirect_uri - flow.fetch_token(code=code) + flow = _exchange_google_web_oauth_code( + client_config=client_config, + scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE], + redirect_uri=redirect_uri, + code=code, + code_verifier=code_verifier, + ) except Exception as exc: # pragma: no cover - defensive logging.exception("Failed to exchange Google OAuth code: %s", exc) REDIS_CONN.delete(_web_state_cache_key(state_id, source)) diff --git a/test/testcases/test_web_api/test_connector_app/test_connector_routes_unit.py b/test/testcases/test_web_api/test_connector_app/test_connector_routes_unit.py index ea3bad9078..9d9e1c9c14 100644 --- a/test/testcases/test_web_api/test_connector_app/test_connector_routes_unit.py +++ b/test/testcases/test_web_api/test_connector_app/test_connector_routes_unit.py @@ -88,13 +88,16 @@ class _FakeFlow: self.credentials = _FakeCredentials() self.auth_kwargs = None self.token_code = None + self.token_code_verifier = None + self.code_verifier = "fake-code-verifier" def authorization_url(self, **kwargs): self.auth_kwargs = dict(kwargs) return f"https://oauth.example/{kwargs['state']}", kwargs["state"] - def fetch_token(self, code): + def fetch_token(self, code, code_verifier=None): self.token_code = code + self.token_code_verifier = code_verifier class _FakeBoxToken: @@ -519,6 +522,8 @@ def test_start_google_web_oauth_matrix(monkeypatch): assert any(call.scopes == module.GOOGLE_SCOPES[module.DocumentSource.GOOGLE_DRIVE] for call in flow_calls) assert "gmail_web_flow_state:flow-gmail" in redis.store assert "google-drive_web_flow_state:flow-drive" in redis.store + assert json.loads(redis.store["gmail_web_flow_state:flow-gmail"])["code_verifier"] == "fake-code-verifier" + assert json.loads(redis.store["google-drive_web_flow_state:flow-drive"])["code_verifier"] == "fake-code-verifier" @pytest.mark.p2 @@ -586,6 +591,7 @@ def test_google_web_oauth_callbacks_matrix(monkeypatch): redis.store[module._web_state_cache_key("sid", source)] = json.dumps({ "user_id": "tenant-1", "client_config": {"web": {"client_id": "cid"}}, + "code_verifier": "state-code-verifier", }) _set_request(module, args={"state": "sid", "code": "code-123"}) success = _run(callback()) @@ -598,6 +604,7 @@ def test_google_web_oauth_callbacks_matrix(monkeypatch): assert flow_calls[-1].redirect_uri == expected_redirect assert flow_calls[-1].scopes == expected_scopes assert flow_calls[-1].token_code == "code-123" + assert flow_calls[-1].token_code_verifier == "state-code-verifier" @pytest.mark.p2