mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
### What problem does this PR solve? This PR adds Google BigQuery as a first-class data source connector in RAGFlow. It enables users to ingest and sync BigQuery data using the same row-to-document model used by relational database connectors: selected content columns become document text, metadata columns become document metadata, an optional ID column provides stable document IDs, and an optional timestamp column enables cursor-based incremental sync. The connector supports service-account JSON credentials, table mode, custom query mode, GoogleSQL queries, cursor-based incremental sync, deleted-row pruning support, configurable query limits such as `maximum_bytes_billed`, dry-run validation, batch loading, stable document IDs, and BigQuery-aware value serialization.
677 lines
26 KiB
Python
677 lines
26 KiB
Python
#
|
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from html import escape
|
|
from typing import Any
|
|
|
|
from quart import request, make_response
|
|
from google_auth_oauthlib.flow import Flow
|
|
|
|
from api.db import InputType
|
|
from api.db.services.connector_service import ConnectorService, SyncLogsService
|
|
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, validate_request
|
|
from api.utils.pagination_utils import validate_rest_api_page_size
|
|
from common.constants import RetCode, TaskStatus
|
|
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, GMAIL_WEB_OAUTH_REDIRECT_URI, BOX_WEB_OAUTH_REDIRECT_URI, DocumentSource
|
|
from common.data_source.google_util.constant import WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
|
|
from common.misc_utils import get_uuid
|
|
from rag.utils.redis_conn import REDIS_CONN
|
|
from api.apps import login_required, current_user
|
|
from box_sdk_gen import BoxOAuth, OAuthConfig, GetAuthorizeUrlOptions
|
|
|
|
|
|
LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
def _connector_auth_error(connector_id: str, user_id: str):
|
|
"""Return the connector authorization failure response and log the denial."""
|
|
LOGGER.warning("connector access denied: connector_id=%s user_id=%s", connector_id, user_id)
|
|
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
|
|
|
|
|
@manager.route("/connectors/<connector_id>", methods=["PATCH"]) # noqa: F821
|
|
@login_required
|
|
async def update_connector(connector_id):
|
|
"""Update an accessible connector's polling configuration."""
|
|
if not ConnectorService.accessible(connector_id, current_user.id):
|
|
return _connector_auth_error(connector_id, current_user.id)
|
|
|
|
req = await get_request_json()
|
|
if isinstance(req, dict) and isinstance(req.get("data"), dict):
|
|
req = req["data"]
|
|
|
|
e, conn = ConnectorService.get_by_id(connector_id)
|
|
if not e:
|
|
return get_data_error_result(message="Can't find this Connector!")
|
|
|
|
should_sleep = False
|
|
if req:
|
|
update_fields = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
|
|
if update_fields:
|
|
update_fields["id"] = connector_id
|
|
ConnectorService.update_by_id(connector_id, update_fields)
|
|
should_sleep = True
|
|
|
|
if req.get("reschedule"):
|
|
ConnectorService.cancel_tasks(connector_id)
|
|
ConnectorService.schedule_tasks(connector_id)
|
|
elif req.get("status") in [TaskStatus.CANCEL, "CANCEL"]:
|
|
ConnectorService.cancel_tasks(connector_id)
|
|
elif req.get("status") in [TaskStatus.SCHEDULE, "SCHEDULE"]:
|
|
ConnectorService.schedule_tasks(connector_id)
|
|
|
|
if should_sleep:
|
|
await asyncio.sleep(1)
|
|
e, conn = ConnectorService.get_by_id(connector_id)
|
|
if not e:
|
|
return get_data_error_result(message="Can't find this Connector!")
|
|
|
|
return get_json_result(data=conn.to_dict())
|
|
|
|
|
|
@manager.route("/connectors", methods=["POST"]) # noqa: F821
|
|
@login_required
|
|
async def create_connector():
|
|
"""Create a connector owned by the current tenant."""
|
|
req = await get_request_json()
|
|
if req:
|
|
req["id"] = get_uuid()
|
|
conn = {
|
|
"id": req["id"],
|
|
"tenant_id": current_user.id,
|
|
"name": req["name"],
|
|
"source": req["source"],
|
|
"input_type": InputType.POLL,
|
|
"config": req["config"],
|
|
"refresh_freq": int(req.get("refresh_freq", 5)),
|
|
"prune_freq": int(req.get("prune_freq", 5)),
|
|
"timeout_secs": int(req.get("timeout_secs", 60 * 29)),
|
|
"status": TaskStatus.UNSTART,
|
|
}
|
|
ConnectorService.save(**conn)
|
|
|
|
await asyncio.sleep(1)
|
|
e, conn = ConnectorService.get_by_id(req["id"])
|
|
|
|
return get_json_result(data=conn.to_dict())
|
|
|
|
|
|
@manager.route("/connectors", methods=["GET"]) # noqa: F821
|
|
@login_required
|
|
def list_connector():
|
|
"""List connectors owned by the current tenant."""
|
|
return get_json_result(data=ConnectorService.list(current_user.id))
|
|
|
|
|
|
@manager.route("/connectors/<connector_id>", methods=["GET"]) # noqa: F821
|
|
@login_required
|
|
def get_connector(connector_id):
|
|
"""Return connector details when the current user can access it."""
|
|
if not ConnectorService.accessible(connector_id, current_user.id):
|
|
return _connector_auth_error(connector_id, current_user.id)
|
|
|
|
e, conn = ConnectorService.get_by_id(connector_id)
|
|
if not e:
|
|
return get_data_error_result(message="Can't find this Connector!")
|
|
return get_json_result(data=conn.to_dict())
|
|
|
|
|
|
@manager.route("/connectors/<connector_id>/logs", methods=["GET"]) # noqa: F821
|
|
@login_required
|
|
def list_logs(connector_id):
|
|
"""List sync logs for a connector the current user can access."""
|
|
if not ConnectorService.accessible(connector_id, current_user.id):
|
|
return _connector_auth_error(connector_id, current_user.id)
|
|
|
|
req = request.args.to_dict(flat=True)
|
|
arr, total = SyncLogsService.list_sync_tasks(
|
|
connector_id,
|
|
int(req.get("page", 1)),
|
|
validate_rest_api_page_size(int(req.get("page_size", 15))),
|
|
)
|
|
return get_json_result(data={"total": total, "logs": arr})
|
|
|
|
|
|
@manager.route("/connectors/<connector_id>/rebuild", methods=["POST"]) # noqa: F821
|
|
@login_required
|
|
async def rebuild(connector_id):
|
|
"""Schedule a rebuild for an accessible connector and knowledge base."""
|
|
if not ConnectorService.accessible(connector_id, current_user.id):
|
|
return _connector_auth_error(connector_id, current_user.id)
|
|
|
|
req = await get_request_json()
|
|
if "kb_id" not in req:
|
|
return get_json_result(code=RetCode.ARGUMENT_ERROR, message="required argument is missing: kb_id")
|
|
|
|
err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id)
|
|
if err:
|
|
return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR)
|
|
return get_json_result(data=True)
|
|
|
|
|
|
@manager.route("/connectors/<connector_id>", methods=["DELETE"]) # noqa: F821
|
|
@login_required
|
|
def rm_connector(connector_id):
|
|
"""Delete an accessible connector after canceling its sync tasks."""
|
|
if not ConnectorService.accessible(connector_id, current_user.id):
|
|
return _connector_auth_error(connector_id, current_user.id)
|
|
|
|
ConnectorService.cancel_tasks(connector_id)
|
|
ConnectorService.delete_by_id(connector_id)
|
|
return get_json_result(data=True)
|
|
|
|
|
|
@manager.route("/connectors/<connector_id>/test", methods=["POST"]) # noqa: F821
|
|
@login_required
|
|
async def test_connector(connector_id):
|
|
"""Validate connector configuration without persisting changes or triggering sync.
|
|
|
|
For the REST API connector, this uses `RestAPIConnector.validate_config`
|
|
against the existing saved configuration. For BigQuery, it runs `SELECT 1`
|
|
plus a free dry-run of the configured base query under `maximum_bytes_billed`
|
|
so bad credentials, wrong location, or runaway scans surface before scheduled
|
|
syncs run (and incur cost).
|
|
"""
|
|
if not ConnectorService.accessible(connector_id, current_user.id):
|
|
return _connector_auth_error(connector_id, current_user.id)
|
|
|
|
from common.data_source.exceptions import ConnectorMissingCredentialError, ConnectorValidationError
|
|
|
|
ok, conn = ConnectorService.get_by_id(connector_id)
|
|
if not ok:
|
|
return get_data_error_result(message="Can't find this Connector!")
|
|
|
|
config = conn.config or {}
|
|
credentials = config.get("credentials") or {}
|
|
|
|
if conn.source == DocumentSource.REST_API:
|
|
from common.data_source.rest_api_connector import RestAPIConnector
|
|
|
|
try:
|
|
await asyncio.to_thread(
|
|
RestAPIConnector.validate_config,
|
|
config=config,
|
|
credentials=credentials,
|
|
)
|
|
except (ConnectorValidationError, ConnectorMissingCredentialError) as exc:
|
|
return get_json_result(
|
|
code=RetCode.DATA_ERROR,
|
|
message=str(exc),
|
|
data=False,
|
|
)
|
|
except Exception as exc:
|
|
logging.exception("REST API connector validation failed: %s", exc)
|
|
return get_json_result(
|
|
code=RetCode.SERVER_ERROR,
|
|
message="REST API connector validation failed, please check logs.",
|
|
data=False,
|
|
)
|
|
|
|
return get_json_result(data=True)
|
|
|
|
if conn.source == DocumentSource.BIGQUERY:
|
|
from common.data_source.bigquery_connector import BigQueryConnector
|
|
|
|
def _validate_bigquery():
|
|
connector_kwargs = {
|
|
"project_id": config.get("project_id", ""),
|
|
"dataset_id": config.get("dataset_id") or None,
|
|
"table_id": config.get("table_id") or None,
|
|
"location": config.get("location") or None,
|
|
"query": config.get("query", ""),
|
|
"content_columns": config.get("content_columns", ""),
|
|
"metadata_columns": config.get("metadata_columns", ""),
|
|
"id_column": config.get("id_column") or None,
|
|
"timestamp_column": config.get("timestamp_column") or None,
|
|
"use_query_cache": config.get("use_query_cache", True),
|
|
}
|
|
if config.get("page_size") is not None:
|
|
connector_kwargs["page_size"] = int(config["page_size"])
|
|
if config.get("maximum_bytes_billed") is not None:
|
|
connector_kwargs["maximum_bytes_billed"] = int(config["maximum_bytes_billed"])
|
|
if config.get("job_timeout_ms") is not None:
|
|
connector_kwargs["job_timeout_ms"] = int(config["job_timeout_ms"])
|
|
|
|
connector = BigQueryConnector(**connector_kwargs)
|
|
connector.load_credentials(credentials)
|
|
connector.validate_connector_settings()
|
|
|
|
try:
|
|
await asyncio.to_thread(_validate_bigquery)
|
|
except (ConnectorValidationError, ConnectorMissingCredentialError) as exc:
|
|
return get_json_result(
|
|
code=RetCode.DATA_ERROR,
|
|
message=str(exc),
|
|
data=False,
|
|
)
|
|
except Exception as exc:
|
|
logging.exception("BigQuery connector validation failed: %s", exc)
|
|
return get_json_result(
|
|
code=RetCode.SERVER_ERROR,
|
|
message="BigQuery connector validation failed, please check logs.",
|
|
data=False,
|
|
)
|
|
|
|
return get_json_result(data=True)
|
|
|
|
return get_json_result(
|
|
code=RetCode.ARGUMENT_ERROR,
|
|
message="Test endpoint currently supports only REST API and BigQuery connectors.",
|
|
data=False,
|
|
)
|
|
|
|
|
|
WEB_FLOW_TTL_SECS = 15 * 60
|
|
|
|
|
|
def _web_state_cache_key(flow_id: str, source_type: str | None = None) -> str:
|
|
"""Return Redis key for web OAuth state.
|
|
|
|
The default prefix keeps backward compatibility for Google Drive.
|
|
When source_type == "gmail", a different prefix is used so that
|
|
Drive/Gmail flows don't clash in Redis.
|
|
"""
|
|
prefix = f"{source_type}_web_flow_state"
|
|
return f"{prefix}:{flow_id}"
|
|
|
|
|
|
def _web_result_cache_key(flow_id: str, source_type: str | None = None) -> str:
|
|
"""Return Redis key for web OAuth result.
|
|
|
|
Mirrors _web_state_cache_key logic for result storage.
|
|
"""
|
|
prefix = f"{source_type}_web_flow_result"
|
|
return f"{prefix}:{flow_id}"
|
|
|
|
|
|
def _load_credentials(payload: str | dict[str, Any]) -> dict[str, Any]:
|
|
if isinstance(payload, dict):
|
|
return payload
|
|
try:
|
|
return json.loads(payload)
|
|
except json.JSONDecodeError as exc: # pragma: no cover - defensive
|
|
raise ValueError("Invalid Google credentials JSON.") from exc
|
|
|
|
|
|
def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]:
|
|
web_section = credentials.get("web")
|
|
if not isinstance(web_section, dict):
|
|
raise ValueError("Google OAuth JSON must include a 'web' client configuration to use browser-based authorization.")
|
|
return {"web": web_section}
|
|
|
|
|
|
def _exchange_google_web_oauth_code(
|
|
client_config: dict[str, Any],
|
|
scopes: list[str],
|
|
redirect_uri: str,
|
|
code: str,
|
|
code_verifier: str | None,
|
|
) -> Flow:
|
|
flow = Flow.from_client_config(client_config, scopes=scopes)
|
|
flow.redirect_uri = redirect_uri
|
|
fetch_token_kwargs: dict[str, Any] = {"code": code}
|
|
if code_verifier:
|
|
fetch_token_kwargs["code_verifier"] = code_verifier
|
|
flow.fetch_token(**fetch_token_kwargs)
|
|
return flow
|
|
|
|
|
|
async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, source="drive"):
|
|
status = "success" if success else "error"
|
|
auto_close = "window.close();" if success else ""
|
|
escaped_message = escape(message)
|
|
# Drive: ragflow-google-drive-oauth
|
|
# Gmail: ragflow-gmail-oauth
|
|
payload_type = f"ragflow-{source}-oauth"
|
|
payload_json = json.dumps(
|
|
{
|
|
"type": payload_type,
|
|
"status": status,
|
|
"flowId": flow_id or "",
|
|
"message": message,
|
|
}
|
|
)
|
|
# TODO(google-oauth): title/heading/message may need to reflect drive/gmail based on cached type
|
|
html = WEB_OAUTH_POPUP_TEMPLATE.format(
|
|
title=f"Google {source.capitalize()} Authorization",
|
|
heading="Authorization complete" if success else "Authorization failed",
|
|
message=escaped_message,
|
|
payload_json=payload_json,
|
|
auto_close=auto_close,
|
|
)
|
|
response = await make_response(html, 200)
|
|
response.headers["Content-Type"] = "text/html; charset=utf-8"
|
|
return response
|
|
|
|
|
|
@manager.route("/connectors/google/oauth/web/start", methods=["POST"]) # noqa: F821
|
|
@login_required
|
|
@validate_request("credentials")
|
|
async def start_google_web_oauth():
|
|
source = request.args.get("type", "google-drive")
|
|
if source not in ("google-drive", "gmail"):
|
|
return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.")
|
|
|
|
req = await get_request_json()
|
|
|
|
if source == "gmail":
|
|
default_redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI
|
|
scopes = GOOGLE_SCOPES[DocumentSource.GMAIL]
|
|
else:
|
|
default_redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
|
|
scopes = GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE]
|
|
|
|
redirect_uri = req.get("redirect_uri", default_redirect_uri)
|
|
if isinstance(redirect_uri, str):
|
|
redirect_uri = redirect_uri.strip()
|
|
|
|
if not redirect_uri:
|
|
return get_json_result(
|
|
code=RetCode.SERVER_ERROR,
|
|
message="Google OAuth redirect URI is not configured on the server.",
|
|
)
|
|
|
|
raw_credentials = req.get("credentials", "")
|
|
|
|
try:
|
|
credentials = _load_credentials(raw_credentials)
|
|
print(credentials)
|
|
except ValueError as exc:
|
|
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc))
|
|
|
|
if credentials.get("refresh_token"):
|
|
return get_json_result(
|
|
code=RetCode.ARGUMENT_ERROR,
|
|
message="Uploaded credentials already include a refresh token.",
|
|
)
|
|
|
|
try:
|
|
client_config = _get_web_client_config(credentials)
|
|
except ValueError as exc:
|
|
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc))
|
|
|
|
flow_id = str(uuid.uuid4())
|
|
try:
|
|
flow = Flow.from_client_config(client_config, scopes=scopes)
|
|
flow.redirect_uri = redirect_uri
|
|
authorization_url, _ = flow.authorization_url(
|
|
access_type="offline",
|
|
include_granted_scopes="true",
|
|
prompt="consent",
|
|
state=flow_id,
|
|
)
|
|
except Exception as exc: # pragma: no cover - defensive
|
|
logging.exception("Failed to create Google OAuth flow: %s", exc)
|
|
return get_json_result(
|
|
code=RetCode.SERVER_ERROR,
|
|
message="Failed to initialize Google OAuth flow. Please verify the uploaded client configuration.",
|
|
)
|
|
|
|
cache_payload = {
|
|
"user_id": current_user.id,
|
|
"client_config": client_config,
|
|
"redirect_uri": redirect_uri,
|
|
"code_verifier": flow.code_verifier,
|
|
"created_at": int(time.time()),
|
|
}
|
|
REDIS_CONN.set_obj(_web_state_cache_key(flow_id, source), cache_payload, WEB_FLOW_TTL_SECS)
|
|
|
|
return get_json_result(
|
|
data={
|
|
"flow_id": flow_id,
|
|
"authorization_url": authorization_url,
|
|
"expires_in": WEB_FLOW_TTL_SECS,
|
|
}
|
|
)
|
|
|
|
|
|
@manager.route("/connectors/gmail/oauth/web/callback", methods=["GET"]) # noqa: F821
|
|
async def google_gmail_web_oauth_callback():
|
|
state_id = request.args.get("state")
|
|
error = request.args.get("error")
|
|
source = "gmail"
|
|
|
|
error_description = request.args.get("error_description") or error
|
|
|
|
if not state_id:
|
|
return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source)
|
|
|
|
state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source))
|
|
if not state_cache:
|
|
return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source)
|
|
|
|
state_obj = json.loads(state_cache)
|
|
client_config = state_obj.get("client_config")
|
|
redirect_uri = state_obj.get("redirect_uri", GMAIL_WEB_OAUTH_REDIRECT_URI)
|
|
code_verifier = state_obj.get("code_verifier")
|
|
if not client_config:
|
|
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
|
return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source)
|
|
|
|
if error:
|
|
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
|
return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source)
|
|
|
|
code = request.args.get("code")
|
|
if not code:
|
|
return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source)
|
|
|
|
try:
|
|
flow = _exchange_google_web_oauth_code(
|
|
client_config=client_config,
|
|
scopes=GOOGLE_SCOPES[DocumentSource.GMAIL],
|
|
redirect_uri=redirect_uri,
|
|
code=code,
|
|
code_verifier=code_verifier,
|
|
)
|
|
except Exception as exc: # pragma: no cover - defensive
|
|
logging.exception("Failed to exchange Google OAuth code: %s", exc)
|
|
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
|
return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source)
|
|
|
|
creds_json = flow.credentials.to_json()
|
|
result_payload = {
|
|
"user_id": state_obj.get("user_id"),
|
|
"credentials": creds_json,
|
|
}
|
|
REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS)
|
|
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
|
|
|
return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source)
|
|
|
|
|
|
@manager.route("/connectors/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821
|
|
async def google_drive_web_oauth_callback():
|
|
state_id = request.args.get("state")
|
|
error = request.args.get("error")
|
|
source = "google-drive"
|
|
|
|
error_description = request.args.get("error_description") or error
|
|
|
|
if not state_id:
|
|
return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source)
|
|
|
|
state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source))
|
|
if not state_cache:
|
|
return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source)
|
|
|
|
state_obj = json.loads(state_cache)
|
|
client_config = state_obj.get("client_config")
|
|
redirect_uri = state_obj.get("redirect_uri", GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI)
|
|
code_verifier = state_obj.get("code_verifier")
|
|
if not client_config:
|
|
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
|
return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source)
|
|
|
|
if error:
|
|
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
|
return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source)
|
|
|
|
code = request.args.get("code")
|
|
if not code:
|
|
return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source)
|
|
|
|
try:
|
|
flow = _exchange_google_web_oauth_code(
|
|
client_config=client_config,
|
|
scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE],
|
|
redirect_uri=redirect_uri,
|
|
code=code,
|
|
code_verifier=code_verifier,
|
|
)
|
|
except Exception as exc: # pragma: no cover - defensive
|
|
logging.exception("Failed to exchange Google OAuth code: %s", exc)
|
|
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
|
return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source)
|
|
|
|
creds_json = flow.credentials.to_json()
|
|
result_payload = {
|
|
"user_id": state_obj.get("user_id"),
|
|
"credentials": creds_json,
|
|
}
|
|
REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS)
|
|
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
|
|
|
return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source)
|
|
|
|
@manager.route("/connectors/google/oauth/web/result", methods=["POST"]) # noqa: F821
|
|
@login_required
|
|
@validate_request("flow_id")
|
|
async def poll_google_web_result():
|
|
req = await request.json or {}
|
|
source = request.args.get("type")
|
|
if source not in ("google-drive", "gmail"):
|
|
return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.")
|
|
flow_id = req.get("flow_id")
|
|
cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id, source))
|
|
if not cache_raw:
|
|
return get_json_result(code=RetCode.RUNNING, message="Authorization is still pending.")
|
|
|
|
result = json.loads(cache_raw)
|
|
if result.get("user_id") != current_user.id:
|
|
return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.")
|
|
|
|
REDIS_CONN.delete(_web_result_cache_key(flow_id, source))
|
|
return get_json_result(data={"credentials": result.get("credentials")})
|
|
|
|
@manager.route("/connectors/box/oauth/web/start", methods=["POST"]) # noqa: F821
|
|
@login_required
|
|
async def start_box_web_oauth():
|
|
req = await get_request_json()
|
|
|
|
client_id = req.get("client_id")
|
|
client_secret = req.get("client_secret")
|
|
redirect_uri = req.get("redirect_uri", BOX_WEB_OAUTH_REDIRECT_URI)
|
|
|
|
if not client_id or not client_secret:
|
|
return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Box client_id and client_secret are required.")
|
|
|
|
flow_id = str(uuid.uuid4())
|
|
|
|
box_auth = BoxOAuth(
|
|
OAuthConfig(
|
|
client_id=client_id,
|
|
client_secret=client_secret,
|
|
)
|
|
)
|
|
|
|
auth_url = box_auth.get_authorize_url(
|
|
options=GetAuthorizeUrlOptions(
|
|
redirect_uri=redirect_uri,
|
|
state=flow_id,
|
|
)
|
|
)
|
|
|
|
cache_payload = {
|
|
"user_id": current_user.id,
|
|
"auth_url": auth_url,
|
|
"client_id": client_id,
|
|
"client_secret": client_secret,
|
|
"created_at": int(time.time()),
|
|
}
|
|
REDIS_CONN.set_obj(_web_state_cache_key(flow_id, "box"), cache_payload, WEB_FLOW_TTL_SECS)
|
|
return get_json_result(
|
|
data = {
|
|
"flow_id": flow_id,
|
|
"authorization_url": auth_url,
|
|
"expires_in": WEB_FLOW_TTL_SECS,}
|
|
)
|
|
|
|
@manager.route("/connectors/box/oauth/web/callback", methods=["GET"]) # noqa: F821
|
|
async def box_web_oauth_callback():
|
|
flow_id = request.args.get("state")
|
|
if not flow_id:
|
|
return await _render_web_oauth_popup("", False, "Missing OAuth parameters.", "box")
|
|
|
|
code = request.args.get("code")
|
|
if not code:
|
|
return await _render_web_oauth_popup(flow_id, False, "Missing authorization code from Box.", "box")
|
|
|
|
cache_payload = json.loads(REDIS_CONN.get(_web_state_cache_key(flow_id, "box")))
|
|
if not cache_payload:
|
|
return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Box OAuth session expired or invalid.")
|
|
|
|
error = request.args.get("error")
|
|
error_description = request.args.get("error_description") or error
|
|
if error:
|
|
REDIS_CONN.delete(_web_state_cache_key(flow_id, "box"))
|
|
return await _render_web_oauth_popup(flow_id, False, error_description or "Authorization failed.", "box")
|
|
|
|
auth = BoxOAuth(
|
|
OAuthConfig(
|
|
client_id=cache_payload.get("client_id"),
|
|
client_secret=cache_payload.get("client_secret"),
|
|
)
|
|
)
|
|
|
|
auth.get_tokens_authorization_code_grant(code)
|
|
token = auth.retrieve_token()
|
|
result_payload = {
|
|
"user_id": cache_payload.get("user_id"),
|
|
"client_id": cache_payload.get("client_id"),
|
|
"client_secret": cache_payload.get("client_secret"),
|
|
"access_token": token.access_token,
|
|
"refresh_token": token.refresh_token,
|
|
}
|
|
|
|
REDIS_CONN.set_obj(_web_result_cache_key(flow_id, "box"), result_payload, WEB_FLOW_TTL_SECS)
|
|
REDIS_CONN.delete(_web_state_cache_key(flow_id, "box"))
|
|
|
|
return await _render_web_oauth_popup(flow_id, True, "Authorization completed successfully.", "box")
|
|
|
|
@manager.route("/connectors/box/oauth/web/result", methods=["POST"]) # noqa: F821
|
|
@login_required
|
|
@validate_request("flow_id")
|
|
async def poll_box_web_result():
|
|
req = await get_request_json()
|
|
flow_id = req.get("flow_id")
|
|
|
|
cache_blob = REDIS_CONN.get(_web_result_cache_key(flow_id, "box"))
|
|
if not cache_blob:
|
|
return get_json_result(code=RetCode.RUNNING, message="Authorization is still pending.")
|
|
|
|
cache_raw = json.loads(cache_blob)
|
|
if cache_raw.get("user_id") != current_user.id:
|
|
return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.")
|
|
|
|
REDIS_CONN.delete(_web_result_cache_key(flow_id, "box"))
|
|
|
|
return get_json_result(data={"credentials": cache_raw})
|