mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-04 09:39:32 +08:00
Refactor: reformat all code for lefthook using ruff and gofmt (#16585)
This commit is contained in:
@@ -87,15 +87,11 @@ def get_nextUrl_key(pag_list: PaginatedList[PullRequest | Issue]) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def get_nextUrl(
|
||||
pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str
|
||||
) -> str | None:
|
||||
def get_nextUrl(pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str) -> str | None:
|
||||
return getattr(pag_list, nextUrl_key) if nextUrl_key else None
|
||||
|
||||
|
||||
def set_nextUrl(
|
||||
pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str, nextUrl: str
|
||||
) -> None:
|
||||
def set_nextUrl(pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str, nextUrl: str) -> None:
|
||||
if nextUrl_key:
|
||||
setattr(pag_list, nextUrl_key, nextUrl)
|
||||
elif nextUrl:
|
||||
@@ -119,11 +115,7 @@ def _paginate_until_error(
|
||||
# over previous calls. Unfortunately, this WILL retrieve all
|
||||
# pages before the one we are resuming from, so we really
|
||||
# don't want this case to be hit often
|
||||
logging.warning(
|
||||
"Retrying from a previous cursor-based pagination call. "
|
||||
"This will retrieve all pages before the one we are resuming from, "
|
||||
"which may take a while and consume many API calls."
|
||||
)
|
||||
logging.warning("Retrying from a previous cursor-based pagination call. This will retrieve all pages before the one we are resuming from, which may take a while and consume many API calls.")
|
||||
pag_list = cast(PaginatedList[PullRequest | Issue], pag_list[prev_num_objs:])
|
||||
num_objs = 0
|
||||
|
||||
@@ -137,9 +129,7 @@ def _paginate_until_error(
|
||||
cursor_url_callback(get_nextUrl(pag_list, nextUrl_key), num_objs)
|
||||
|
||||
if num_objs % CURSOR_LOG_FREQUENCY == 0:
|
||||
logging.info(
|
||||
f"Retrieved {num_objs} objects with current cursor url: {get_nextUrl(pag_list, nextUrl_key)}"
|
||||
)
|
||||
logging.info(f"Retrieved {num_objs} objects with current cursor url: {get_nextUrl(pag_list, nextUrl_key)}")
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(f"Error during cursor-based pagination: {e}")
|
||||
@@ -147,14 +137,8 @@ def _paginate_until_error(
|
||||
raise
|
||||
|
||||
if get_nextUrl(pag_list, nextUrl_key) is not None and not retrying:
|
||||
logging.info(
|
||||
"Assuming that this error is due to cursor "
|
||||
"expiration because no objects were retrieved. "
|
||||
"Retrying from the first page."
|
||||
)
|
||||
yield from _paginate_until_error(
|
||||
git_objs, None, prev_num_objs, cursor_url_callback, retrying=True
|
||||
)
|
||||
logging.info("Assuming that this error is due to cursor expiration because no objects were retrieved. Retrying from the first page.")
|
||||
yield from _paginate_until_error(git_objs, None, prev_num_objs, cursor_url_callback, retrying=True)
|
||||
return
|
||||
|
||||
# for no cursor url or if we reach this point after a retry, raise the error
|
||||
@@ -174,16 +158,12 @@ def _get_batch_rate_limited(
|
||||
attempt_num: int = 0,
|
||||
) -> Generator[PullRequest | Issue, None, None]:
|
||||
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
|
||||
raise RuntimeError(
|
||||
"Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
|
||||
)
|
||||
raise RuntimeError("Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github")
|
||||
try:
|
||||
if cursor_url:
|
||||
# when this is set, we are resuming from an earlier
|
||||
# cursor-based pagination call.
|
||||
yield from _paginate_until_error(
|
||||
git_objs, cursor_url, prev_num_objs, cursor_url_callback
|
||||
)
|
||||
yield from _paginate_until_error(git_objs, cursor_url, prev_num_objs, cursor_url_callback)
|
||||
return
|
||||
objs = list(git_objs().get_page(page_num))
|
||||
# fetch all data here to disable lazy loading later
|
||||
@@ -204,13 +184,7 @@ def _get_batch_rate_limited(
|
||||
attempt_num + 1,
|
||||
)
|
||||
except GithubException as e:
|
||||
if not (
|
||||
e.status == 422
|
||||
and (
|
||||
"cursor" in (e.message or "")
|
||||
or "cursor" in (e.data or {}).get("message", "")
|
||||
)
|
||||
):
|
||||
if not (e.status == 422 and ("cursor" in (e.message or "") or "cursor" in (e.data or {}).get("message", ""))):
|
||||
raise
|
||||
# Fallback to a cursor-based pagination strategy
|
||||
# This can happen for "large datasets," but there's no documentation
|
||||
@@ -218,9 +192,7 @@ def _get_batch_rate_limited(
|
||||
# Error message:
|
||||
# "Pagination with the page parameter is not supported for large datasets,
|
||||
# please use cursor based pagination (after/before)"
|
||||
yield from _paginate_until_error(
|
||||
git_objs, cursor_url, prev_num_objs, cursor_url_callback
|
||||
)
|
||||
yield from _paginate_until_error(git_objs, cursor_url, prev_num_objs, cursor_url_callback)
|
||||
|
||||
|
||||
def _get_userinfo(user: NamedUser) -> dict[str, str]:
|
||||
@@ -242,28 +214,22 @@ def _get_userinfo(user: NamedUser) -> dict[str, str]:
|
||||
}
|
||||
|
||||
|
||||
def _convert_pr_to_document(
|
||||
pull_request: PullRequest, repo_external_access: ExternalAccess | None
|
||||
) -> Document:
|
||||
def _convert_pr_to_document(pull_request: PullRequest, repo_external_access: ExternalAccess | None) -> Document:
|
||||
repo_name = pull_request.base.repo.full_name if pull_request.base else ""
|
||||
doc_metadata = DocMetadata(repo=repo_name)
|
||||
file_content_byte = pull_request.body.encode('utf-8') if pull_request.body else b""
|
||||
file_content_byte = pull_request.body.encode("utf-8") if pull_request.body else b""
|
||||
name = sanitize_filename(pull_request.title, "md")
|
||||
|
||||
return Document(
|
||||
id=pull_request.html_url,
|
||||
blob= file_content_byte,
|
||||
blob=file_content_byte,
|
||||
source=DocumentSource.GITHUB,
|
||||
external_access=repo_external_access,
|
||||
semantic_identifier=f"{pull_request.number}:{name}",
|
||||
# updated_at is UTC time but is timezone unaware, explicitly add UTC
|
||||
# as there is logic in indexing to prevent wrong timestamped docs
|
||||
# due to local time discrepancies with UTC
|
||||
doc_updated_at=(
|
||||
pull_request.updated_at.replace(tzinfo=timezone.utc)
|
||||
if pull_request.updated_at
|
||||
else None
|
||||
),
|
||||
doc_updated_at=(pull_request.updated_at.replace(tzinfo=timezone.utc) if pull_request.updated_at else None),
|
||||
extension=".md",
|
||||
# this metadata is used in perm sync
|
||||
size_bytes=len(file_content_byte) if file_content_byte else 0,
|
||||
@@ -277,40 +243,16 @@ def _convert_pr_to_document(
|
||||
"merged": pull_request.merged,
|
||||
"state": pull_request.state,
|
||||
"user": _get_userinfo(pull_request.user) if pull_request.user else None,
|
||||
"assignees": [
|
||||
_get_userinfo(assignee) for assignee in pull_request.assignees
|
||||
],
|
||||
"repo": (
|
||||
pull_request.base.repo.full_name if pull_request.base else None
|
||||
),
|
||||
"assignees": [_get_userinfo(assignee) for assignee in pull_request.assignees],
|
||||
"repo": (pull_request.base.repo.full_name if pull_request.base else None),
|
||||
"num_commits": str(pull_request.commits),
|
||||
"num_files_changed": str(pull_request.changed_files),
|
||||
"labels": [label.name for label in pull_request.labels],
|
||||
"created_at": (
|
||||
pull_request.created_at.replace(tzinfo=timezone.utc)
|
||||
if pull_request.created_at
|
||||
else None
|
||||
),
|
||||
"updated_at": (
|
||||
pull_request.updated_at.replace(tzinfo=timezone.utc)
|
||||
if pull_request.updated_at
|
||||
else None
|
||||
),
|
||||
"closed_at": (
|
||||
pull_request.closed_at.replace(tzinfo=timezone.utc)
|
||||
if pull_request.closed_at
|
||||
else None
|
||||
),
|
||||
"merged_at": (
|
||||
pull_request.merged_at.replace(tzinfo=timezone.utc)
|
||||
if pull_request.merged_at
|
||||
else None
|
||||
),
|
||||
"merged_by": (
|
||||
_get_userinfo(pull_request.merged_by)
|
||||
if pull_request.merged_by
|
||||
else None
|
||||
),
|
||||
"created_at": (pull_request.created_at.replace(tzinfo=timezone.utc) if pull_request.created_at else None),
|
||||
"updated_at": (pull_request.updated_at.replace(tzinfo=timezone.utc) if pull_request.updated_at else None),
|
||||
"closed_at": (pull_request.closed_at.replace(tzinfo=timezone.utc) if pull_request.closed_at else None),
|
||||
"merged_at": (pull_request.merged_at.replace(tzinfo=timezone.utc) if pull_request.merged_at else None),
|
||||
"merged_by": (_get_userinfo(pull_request.merged_by) if pull_request.merged_by else None),
|
||||
}.items()
|
||||
if v is not None
|
||||
},
|
||||
@@ -322,12 +264,10 @@ def _fetch_issue_comments(issue: Issue) -> str:
|
||||
return "\nComment: ".join(comment.body for comment in comments)
|
||||
|
||||
|
||||
def _convert_issue_to_document(
|
||||
issue: Issue, repo_external_access: ExternalAccess | None
|
||||
) -> Document:
|
||||
def _convert_issue_to_document(issue: Issue, repo_external_access: ExternalAccess | None) -> Document:
|
||||
repo_name = issue.repository.full_name if issue.repository else ""
|
||||
doc_metadata = DocMetadata(repo=repo_name)
|
||||
file_content_byte = issue.body.encode('utf-8') if issue.body else b""
|
||||
file_content_byte = issue.body.encode("utf-8") if issue.body else b""
|
||||
name = sanitize_filename(issue.title, "md")
|
||||
|
||||
return Document(
|
||||
@@ -353,24 +293,10 @@ def _convert_issue_to_document(
|
||||
"assignees": [_get_userinfo(assignee) for assignee in issue.assignees],
|
||||
"repo": issue.repository.full_name if issue.repository else None,
|
||||
"labels": [label.name for label in issue.labels],
|
||||
"created_at": (
|
||||
issue.created_at.replace(tzinfo=timezone.utc)
|
||||
if issue.created_at
|
||||
else None
|
||||
),
|
||||
"updated_at": (
|
||||
issue.updated_at.replace(tzinfo=timezone.utc)
|
||||
if issue.updated_at
|
||||
else None
|
||||
),
|
||||
"closed_at": (
|
||||
issue.closed_at.replace(tzinfo=timezone.utc)
|
||||
if issue.closed_at
|
||||
else None
|
||||
),
|
||||
"closed_by": (
|
||||
_get_userinfo(issue.closed_by) if issue.closed_by else None
|
||||
),
|
||||
"created_at": (issue.created_at.replace(tzinfo=timezone.utc) if issue.created_at else None),
|
||||
"updated_at": (issue.updated_at.replace(tzinfo=timezone.utc) if issue.updated_at else None),
|
||||
"closed_at": (issue.closed_at.replace(tzinfo=timezone.utc) if issue.closed_at else None),
|
||||
"closed_by": (_get_userinfo(issue.closed_by) if issue.closed_by else None),
|
||||
}.items()
|
||||
if v is not None
|
||||
},
|
||||
@@ -451,13 +377,9 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo
|
||||
|
||||
return None
|
||||
|
||||
def get_github_repo(
|
||||
self, github_client: Github, attempt_num: int = 0
|
||||
) -> Repository.Repository:
|
||||
def get_github_repo(self, github_client: Github, attempt_num: int = 0) -> Repository.Repository:
|
||||
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
|
||||
raise RuntimeError(
|
||||
"Re-tried fetching repo too many times. Something is going wrong with fetching objects from Github"
|
||||
)
|
||||
raise RuntimeError("Re-tried fetching repo too many times. Something is going wrong with fetching objects from Github")
|
||||
|
||||
try:
|
||||
return github_client.get_repo(f"{self.repo_owner}/{self.repositories}")
|
||||
@@ -465,21 +387,15 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo
|
||||
sleep_after_rate_limit_exception(github_client)
|
||||
return self.get_github_repo(github_client, attempt_num + 1)
|
||||
|
||||
def get_github_repos(
|
||||
self, github_client: Github, attempt_num: int = 0
|
||||
) -> list[Repository.Repository]:
|
||||
def get_github_repos(self, github_client: Github, attempt_num: int = 0) -> list[Repository.Repository]:
|
||||
"""Get specific repositories based on comma-separated repo_name string."""
|
||||
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
|
||||
raise RuntimeError(
|
||||
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
|
||||
)
|
||||
raise RuntimeError("Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github")
|
||||
|
||||
try:
|
||||
repos = []
|
||||
# Split repo_name by comma and strip whitespace
|
||||
repo_names = [
|
||||
name.strip() for name in (cast(str, self.repositories)).split(",")
|
||||
]
|
||||
repo_names = [name.strip() for name in (cast(str, self.repositories)).split(",")]
|
||||
|
||||
for repo_name in repo_names:
|
||||
if repo_name: # Skip empty strings
|
||||
@@ -487,22 +403,16 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo
|
||||
repo = github_client.get_repo(f"{self.repo_owner}/{repo_name}")
|
||||
repos.append(repo)
|
||||
except GithubException as e:
|
||||
logging.warning(
|
||||
f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}"
|
||||
)
|
||||
logging.warning(f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}")
|
||||
|
||||
return repos
|
||||
except RateLimitExceededException:
|
||||
sleep_after_rate_limit_exception(github_client)
|
||||
return self.get_github_repos(github_client, attempt_num + 1)
|
||||
|
||||
def get_all_repos(
|
||||
self, github_client: Github, attempt_num: int = 0
|
||||
) -> list[Repository.Repository]:
|
||||
def get_all_repos(self, github_client: Github, attempt_num: int = 0) -> list[Repository.Repository]:
|
||||
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
|
||||
raise RuntimeError(
|
||||
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
|
||||
)
|
||||
raise RuntimeError("Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github")
|
||||
|
||||
try:
|
||||
# Try to get organization first
|
||||
@@ -518,19 +428,11 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo
|
||||
sleep_after_rate_limit_exception(github_client)
|
||||
return self.get_all_repos(github_client, attempt_num + 1)
|
||||
|
||||
def _pull_requests_func(
|
||||
self, repo: Repository.Repository
|
||||
) -> Callable[[], PaginatedList[PullRequest]]:
|
||||
return lambda: repo.get_pulls(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
def _pull_requests_func(self, repo: Repository.Repository) -> Callable[[], PaginatedList[PullRequest]]:
|
||||
return lambda: repo.get_pulls(state=self.state_filter, sort="updated", direction="desc")
|
||||
|
||||
def _issues_func(
|
||||
self, repo: Repository.Repository
|
||||
) -> Callable[[], PaginatedList[Issue]]:
|
||||
return lambda: repo.get_issues(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
def _issues_func(self, repo: Repository.Repository) -> Callable[[], PaginatedList[Issue]]:
|
||||
return lambda: repo.get_issues(state=self.state_filter, sort="updated", direction="desc")
|
||||
|
||||
def _fetch_from_github(
|
||||
self,
|
||||
@@ -582,9 +484,7 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo
|
||||
cursor_url_callback = make_cursor_url_callback(checkpoint)
|
||||
repo_external_access: ExternalAccess | None = None
|
||||
if include_permissions:
|
||||
repo_external_access = get_external_access_permission(
|
||||
repo, self.github_client
|
||||
)
|
||||
repo_external_access = get_external_access_permission(repo, self.github_client)
|
||||
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
|
||||
logging.info(f"Fetching PRs for repo: {repo.name}")
|
||||
|
||||
@@ -603,31 +503,19 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo
|
||||
for pr in pr_batch:
|
||||
num_prs += 1
|
||||
# we iterate backwards in time, so at this point we stop processing prs
|
||||
if (
|
||||
start is not None
|
||||
and pr.updated_at
|
||||
and pr.updated_at.replace(tzinfo=timezone.utc) <= start
|
||||
):
|
||||
if start is not None and pr.updated_at and pr.updated_at.replace(tzinfo=timezone.utc) <= start:
|
||||
done_with_prs = True
|
||||
break
|
||||
# Skip PRs updated after the end date
|
||||
if (
|
||||
end is not None
|
||||
and pr.updated_at
|
||||
and pr.updated_at.replace(tzinfo=timezone.utc) > end
|
||||
):
|
||||
if end is not None and pr.updated_at and pr.updated_at.replace(tzinfo=timezone.utc) > end:
|
||||
continue
|
||||
try:
|
||||
yield _convert_pr_to_document(
|
||||
cast(PullRequest, pr), repo_external_access
|
||||
)
|
||||
yield _convert_pr_to_document(cast(PullRequest, pr), repo_external_access)
|
||||
except Exception as e:
|
||||
error_msg = f"Error converting PR to document: {e}"
|
||||
logging.exception(error_msg)
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=str(pr.id), document_link=pr.html_url
|
||||
),
|
||||
failed_document=DocumentFailure(document_id=str(pr.id), document_link=pr.html_url),
|
||||
failure_message=error_msg,
|
||||
exception=e,
|
||||
)
|
||||
@@ -676,17 +564,11 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo
|
||||
num_issues += 1
|
||||
issue = cast(Issue, issue)
|
||||
# we iterate backwards in time, so at this point we stop processing prs
|
||||
if (
|
||||
start is not None
|
||||
and issue.updated_at.replace(tzinfo=timezone.utc) <= start
|
||||
):
|
||||
if start is not None and issue.updated_at.replace(tzinfo=timezone.utc) <= start:
|
||||
done_with_issues = True
|
||||
break
|
||||
# Skip PRs updated after the end date
|
||||
if (
|
||||
end is not None
|
||||
and issue.updated_at.replace(tzinfo=timezone.utc) > end
|
||||
):
|
||||
if end is not None and issue.updated_at.replace(tzinfo=timezone.utc) > end:
|
||||
continue
|
||||
|
||||
if issue.pull_request is not None:
|
||||
@@ -731,9 +613,7 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo
|
||||
checkpoint.reset()
|
||||
|
||||
if checkpoint.cached_repo_ids:
|
||||
logging.info(
|
||||
f"{len(checkpoint.cached_repo_ids)} checkpoint repos remaining (IDs: {checkpoint.cached_repo_ids})"
|
||||
)
|
||||
logging.info(f"{len(checkpoint.cached_repo_ids)} checkpoint repos remaining (IDs: {checkpoint.cached_repo_ids})")
|
||||
else:
|
||||
logging.info("There are no more checkpoint repos left.")
|
||||
|
||||
@@ -775,9 +655,7 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: GithubConnectorCheckpoint,
|
||||
) -> CheckpointOutput[GithubConnectorCheckpoint]:
|
||||
return self._load_from_checkpoint(
|
||||
start, end, checkpoint, include_permissions=False
|
||||
)
|
||||
return self._load_from_checkpoint(start, end, checkpoint, include_permissions=False)
|
||||
|
||||
@override
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
@@ -786,18 +664,14 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: GithubConnectorCheckpoint,
|
||||
) -> CheckpointOutput[GithubConnectorCheckpoint]:
|
||||
return self._load_from_checkpoint(
|
||||
start, end, checkpoint, include_permissions=True
|
||||
)
|
||||
return self._load_from_checkpoint(start, end, checkpoint, include_permissions=True)
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self.github_client is None:
|
||||
raise ConnectorMissingCredentialError("GitHub credentials not loaded.")
|
||||
|
||||
if not self.repo_owner:
|
||||
raise ConnectorValidationError(
|
||||
"Invalid connector settings: 'repo_owner' must be provided."
|
||||
)
|
||||
raise ConnectorValidationError("Invalid connector settings: 'repo_owner' must be provided.")
|
||||
|
||||
try:
|
||||
if self.repositories:
|
||||
@@ -805,9 +679,7 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo
|
||||
# Multiple repositories specified
|
||||
repo_names = [name.strip() for name in self.repositories.split(",")]
|
||||
if not repo_names:
|
||||
raise ConnectorValidationError(
|
||||
"Invalid connector settings: No valid repository names provided."
|
||||
)
|
||||
raise ConnectorValidationError("Invalid connector settings: No valid repository names provided.")
|
||||
|
||||
# Validate at least one repository exists and is accessible
|
||||
valid_repos = False
|
||||
@@ -818,32 +690,22 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo
|
||||
continue
|
||||
|
||||
try:
|
||||
test_repo = self.github_client.get_repo(
|
||||
f"{self.repo_owner}/{repo_name}"
|
||||
)
|
||||
logging.info(
|
||||
f"Successfully accessed repository: {self.repo_owner}/{repo_name}"
|
||||
)
|
||||
test_repo = self.github_client.get_repo(f"{self.repo_owner}/{repo_name}")
|
||||
logging.info(f"Successfully accessed repository: {self.repo_owner}/{repo_name}")
|
||||
test_repo.get_contents("")
|
||||
valid_repos = True
|
||||
# If at least one repo is valid, we can proceed
|
||||
break
|
||||
except GithubException as e:
|
||||
validation_errors.append(
|
||||
f"Repository '{repo_name}': {e.data.get('message', str(e))}"
|
||||
)
|
||||
validation_errors.append(f"Repository '{repo_name}': {e.data.get('message', str(e))}")
|
||||
|
||||
if not valid_repos:
|
||||
error_msg = (
|
||||
"None of the specified repositories could be accessed: "
|
||||
)
|
||||
error_msg = "None of the specified repositories could be accessed: "
|
||||
error_msg += ", ".join(validation_errors)
|
||||
raise ConnectorValidationError(error_msg)
|
||||
else:
|
||||
# Single repository (backward compatibility)
|
||||
test_repo = self.github_client.get_repo(
|
||||
f"{self.repo_owner}/{self.repositories}"
|
||||
)
|
||||
test_repo = self.github_client.get_repo(f"{self.repo_owner}/{self.repositories}")
|
||||
test_repo.get_contents("")
|
||||
else:
|
||||
# Try to get organization first
|
||||
@@ -851,10 +713,7 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo
|
||||
org = self.github_client.get_organization(self.repo_owner)
|
||||
total_count = org.get_repos().totalCount
|
||||
if total_count == 0:
|
||||
raise ConnectorValidationError(
|
||||
f"Found no repos for organization: {self.repo_owner}. "
|
||||
"Does the credential have the right scopes?"
|
||||
)
|
||||
raise ConnectorValidationError(f"Found no repos for organization: {self.repo_owner}. Does the credential have the right scopes?")
|
||||
except GithubException as e:
|
||||
# Check for missing SSO
|
||||
MISSING_SSO_ERROR_MESSAGE = "You must grant your Personal Access token access to this organization".lower()
|
||||
@@ -865,9 +724,7 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo
|
||||
"authorizing-a-personal-access-token-for-use-with-saml-single-sign-on"
|
||||
)
|
||||
raise ConnectorValidationError(
|
||||
f"Your GitHub token is missing authorization to access the "
|
||||
f"`{self.repo_owner}` organization. Please follow the guide to "
|
||||
f"authorize your token: {SSO_GUIDE_LINK}"
|
||||
f"Your GitHub token is missing authorization to access the `{self.repo_owner}` organization. Please follow the guide to authorize your token: {SSO_GUIDE_LINK}"
|
||||
)
|
||||
# If not an org, try as a user
|
||||
user = self.github_client.get_user(self.repo_owner)
|
||||
@@ -875,52 +732,31 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo
|
||||
# Check if we can access any repos
|
||||
total_count = user.get_repos().totalCount
|
||||
if total_count == 0:
|
||||
raise ConnectorValidationError(
|
||||
f"Found no repos for user: {self.repo_owner}. "
|
||||
"Does the credential have the right scopes?"
|
||||
)
|
||||
raise ConnectorValidationError(f"Found no repos for user: {self.repo_owner}. Does the credential have the right scopes?")
|
||||
|
||||
except RateLimitExceededException:
|
||||
raise UnexpectedValidationError(
|
||||
"Validation failed due to GitHub rate-limits being exceeded. Please try again later."
|
||||
)
|
||||
raise UnexpectedValidationError("Validation failed due to GitHub rate-limits being exceeded. Please try again later.")
|
||||
|
||||
except GithubException as e:
|
||||
if e.status == 401:
|
||||
raise CredentialExpiredError(
|
||||
"GitHub credential appears to be invalid or expired (HTTP 401)."
|
||||
)
|
||||
raise CredentialExpiredError("GitHub credential appears to be invalid or expired (HTTP 401).")
|
||||
elif e.status == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
|
||||
)
|
||||
raise InsufficientPermissionsError("Your GitHub token does not have sufficient permissions for this repository (HTTP 403).")
|
||||
elif e.status == 404:
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
raise ConnectorValidationError(
|
||||
f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}"
|
||||
)
|
||||
raise ConnectorValidationError(f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}")
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}"
|
||||
)
|
||||
raise ConnectorValidationError(f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}")
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"GitHub user or organization not found: {self.repo_owner}"
|
||||
)
|
||||
raise ConnectorValidationError(f"GitHub user or organization not found: {self.repo_owner}")
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"Unexpected GitHub error (status={e.status}): {e.data}"
|
||||
)
|
||||
raise ConnectorValidationError(f"Unexpected GitHub error (status={e.status}): {e.data}")
|
||||
|
||||
except Exception as exc:
|
||||
raise Exception(
|
||||
f"Unexpected error during GitHub settings validation: {exc}"
|
||||
)
|
||||
raise Exception(f"Unexpected error during GitHub settings validation: {exc}")
|
||||
|
||||
def validate_checkpoint_json(
|
||||
self, checkpoint_json: str
|
||||
) -> GithubConnectorCheckpoint:
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> GithubConnectorCheckpoint:
|
||||
return GithubConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
def retrieve_slim_document(
|
||||
@@ -930,17 +766,13 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo
|
||||
callback: Any = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
start_value = 0.0 if start is None else start
|
||||
end_value = (
|
||||
datetime.now(timezone.utc).timestamp() if end is None else end
|
||||
)
|
||||
end_value = datetime.now(timezone.utc).timestamp() if end is None else end
|
||||
checkpoint = self.build_dummy_checkpoint()
|
||||
slim_batch: list[SlimDocument] = []
|
||||
|
||||
while checkpoint.has_more:
|
||||
wrapper = CheckpointOutputWrapper[GithubConnectorCheckpoint]()
|
||||
for document, failure, next_checkpoint in wrapper(
|
||||
self.load_from_checkpoint(start_value, end_value, checkpoint)
|
||||
):
|
||||
for document, failure, next_checkpoint in wrapper(self.load_from_checkpoint(start_value, end_value, checkpoint)):
|
||||
if failure is not None:
|
||||
logging.warning(
|
||||
"GitHub connector failure during slim retrieval: %s",
|
||||
@@ -969,9 +801,7 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo
|
||||
yield from self.retrieve_slim_document(callback=callback)
|
||||
|
||||
def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
|
||||
return GithubConnectorCheckpoint(
|
||||
stage=GithubConnectorStage.PRS, curr_page=0, has_more=True, num_retrieved=0
|
||||
)
|
||||
return GithubConnectorCheckpoint(stage=GithubConnectorStage.PRS, curr_page=0, has_more=True, num_retrieved=0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -982,9 +812,7 @@ if __name__ == "__main__":
|
||||
include_issues=True,
|
||||
include_prs=False,
|
||||
)
|
||||
connector.load_credentials(
|
||||
{"github_access_token": "<Your_GitHub_Access_Token>"}
|
||||
)
|
||||
connector.load_credentials({"github_access_token": "<Your_GitHub_Access_Token>"})
|
||||
|
||||
if connector.github_client:
|
||||
get_external_access_permission(
|
||||
@@ -998,9 +826,7 @@ if __name__ == "__main__":
|
||||
time_range = (start_time, end_time)
|
||||
|
||||
# Initialize the runner with a batch size of 10
|
||||
runner: ConnectorRunner[GithubConnectorCheckpoint] = ConnectorRunner(
|
||||
connector, batch_size=10, include_permissions=False, time_range=time_range
|
||||
)
|
||||
runner: ConnectorRunner[GithubConnectorCheckpoint] = ConnectorRunner(connector, batch_size=10, include_permissions=False, time_range=time_range)
|
||||
|
||||
# Get initial checkpoint
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
|
||||
@@ -12,6 +12,4 @@ class SerializedRepository(BaseModel):
|
||||
raw_data: dict[str, Any]
|
||||
|
||||
def to_Repository(self, requester: Requester) -> Repository.Repository:
|
||||
return Repository.Repository(
|
||||
requester, self.headers, self.raw_data, completed=True
|
||||
)
|
||||
return Repository.Repository(requester, self.headers, self.raw_data, completed=True)
|
||||
|
||||
@@ -14,11 +14,7 @@ def sleep_after_rate_limit_exception(github_client: Github) -> None:
|
||||
Args:
|
||||
github_client: The GitHub client that hit the rate limit
|
||||
"""
|
||||
sleep_time = github_client.get_rate_limit().core.reset.replace(
|
||||
tzinfo=timezone.utc
|
||||
) - datetime.now(tz=timezone.utc)
|
||||
sleep_time = github_client.get_rate_limit().core.reset.replace(tzinfo=timezone.utc) - datetime.now(tz=timezone.utc)
|
||||
sleep_time += timedelta(minutes=1) # add an extra minute just to be safe
|
||||
logging.info(
|
||||
"Ran into Github rate-limit. Sleeping %s seconds.", sleep_time.seconds
|
||||
)
|
||||
time.sleep(sleep_time.total_seconds())
|
||||
logging.info("Ran into Github rate-limit. Sleeping %s seconds.", sleep_time.seconds)
|
||||
time.sleep(sleep_time.total_seconds())
|
||||
|
||||
@@ -8,9 +8,7 @@ from common.data_source.models import ExternalAccess
|
||||
from .models import SerializedRepository
|
||||
|
||||
|
||||
def get_external_access_permission(
|
||||
repo: Repository, github_client: Github
|
||||
) -> ExternalAccess:
|
||||
def get_external_access_permission(repo: Repository, github_client: Github) -> ExternalAccess:
|
||||
"""
|
||||
Get the external access permission for a repository.
|
||||
This functionality requires Enterprise Edition.
|
||||
@@ -20,9 +18,7 @@ def get_external_access_permission(
|
||||
return ExternalAccess.empty()
|
||||
|
||||
|
||||
def deserialize_repository(
|
||||
cached_repo: SerializedRepository, github_client: Github
|
||||
) -> Repository:
|
||||
def deserialize_repository(cached_repo: SerializedRepository, github_client: Github) -> Repository:
|
||||
"""
|
||||
Deserialize a SerializedRepository back into a Repository object.
|
||||
"""
|
||||
@@ -41,4 +37,4 @@ def deserialize_repository(
|
||||
# If all else fails, re-fetch the repo directly
|
||||
logging.warning("Failed to deserialize repository: %s. Attempting to re-fetch.", e)
|
||||
repo_id = cached_repo.id
|
||||
return github_client.get_repo(repo_id)
|
||||
return github_client.get_repo(repo_id)
|
||||
|
||||
Reference in New Issue
Block a user