diff --git a/admin/client/http_client.py b/admin/client/http_client.py index bf18466ebc..3cc62db1a3 100644 --- a/admin/client/http_client.py +++ b/admin/client/http_client.py @@ -25,14 +25,14 @@ import requests class HttpClient: def __init__( - self, - host: str = "127.0.0.1", - port: int = 9381, - api_version: str = "v1", - api_key: Optional[str] = None, - connect_timeout: float = 5.0, - read_timeout: float = 60.0, - verify_ssl: bool = False, + self, + host: str = "127.0.0.1", + port: int = 9381, + api_version: str = "v1", + api_key: Optional[str] = None, + connect_timeout: float = 5.0, + read_timeout: float = 60.0, + verify_ssl: bool = False, ) -> None: self.host = host self.port = port @@ -71,19 +71,19 @@ class HttpClient: return headers def request( - self, - method: str, - path: str, - *, - use_api_base: bool = True, - auth_kind: Optional[str] = "api", - headers: Optional[Dict[str, str]] = None, - json_body: Optional[Dict[str, Any]] = None, - data: Any = None, - files: Any = None, - params: Optional[Dict[str, Any]] = None, - stream: bool = False, - iterations: int = 1, + self, + method: str, + path: str, + *, + use_api_base: bool = True, + auth_kind: Optional[str] = "api", + headers: Optional[Dict[str, str]] = None, + json_body: Optional[Dict[str, Any]] = None, + data: Any = None, + files: Any = None, + params: Optional[Dict[str, Any]] = None, + stream: bool = False, + iterations: int = 1, ) -> requests.Response | dict: url = self.build_url(path, use_api_base=use_api_base) merged_headers = self._headers(auth_kind, headers) @@ -144,18 +144,18 @@ class HttpClient: # ) def request_json( - self, - method: str, - path: str, - *, - use_api_base: bool = True, - auth_kind: Optional[str] = "api", - headers: Optional[Dict[str, str]] = None, - json_body: Optional[Dict[str, Any]] = None, - data: Any = None, - files: Any = None, - params: Optional[Dict[str, Any]] = None, - stream: bool = False, + self, + method: str, + path: str, + *, + use_api_base: bool = True, + auth_kind: Optional[str] = "api", + headers: Optional[Dict[str, str]] = None, + json_body: Optional[Dict[str, Any]] = None, + data: Any = None, + files: Any = None, + params: Optional[Dict[str, Any]] = None, + stream: bool = False, ) -> Dict[str, Any]: response = self.request( method, diff --git a/admin/client/parser.py b/admin/client/parser.py index 7e668c4e29..8f91352ad1 100644 --- a/admin/client/parser.py +++ b/admin/client/parser.py @@ -336,8 +336,8 @@ reset_default_asr: RESET DEFAULT ASR ";" reset_default_tts: RESET DEFAULT TTS ";" list_user_datasets: LIST DATASETS ";" -create_user_dataset_with_parser: CREATE DATASET quoted_string WITH EMBEDDING quoted_string PARSER quoted_string ";" -create_user_dataset_with_pipeline: CREATE DATASET quoted_string WITH EMBEDDING quoted_string PIPELINE quoted_string ";" +create_user_dataset_with_parser: CREATE DATASET quoted_string WITH EMBEDDING quoted_string PARSER quoted_string ";" +create_user_dataset_with_pipeline: CREATE DATASET quoted_string WITH EMBEDDING quoted_string PIPELINE quoted_string ";" drop_user_dataset: DROP DATASET quoted_string ";" list_user_dataset_files: LIST FILES OF DATASET quoted_string ";" list_user_dataset_documents: LIST DOCUMENTS OF DATASET quoted_string ";" @@ -640,15 +640,13 @@ class RAGFlowCLITransformer(Transformer): dataset_name = items[2].children[0].strip("'\"") embedding = items[5].children[0].strip("'\"") parser_type = items[7].children[0].strip("'\"") - return {"type": "create_user_dataset", "dataset_name": dataset_name, "embedding": embedding, - "parser_type": parser_type} + return {"type": "create_user_dataset", "dataset_name": dataset_name, "embedding": embedding, "parser_type": parser_type} def create_user_dataset_with_pipeline(self, items): dataset_name = items[2].children[0].strip("'\"") embedding = items[5].children[0].strip("'\"") pipeline = items[7].children[0].strip("'\"") - return {"type": "create_user_dataset", "dataset_name": dataset_name, "embedding": embedding, - "pipeline": pipeline} + return {"type": "create_user_dataset", "dataset_name": dataset_name, "embedding": embedding, "pipeline": pipeline} def drop_user_dataset(self, items): dataset_name = items[2].children[0].strip("'\"") @@ -666,7 +664,7 @@ class RAGFlowCLITransformer(Transformer): dataset_names = [] dataset_names.append(items[4].children[0].strip("'\"")) for i in range(5, len(items)): - if items[i] and hasattr(items[i], 'children') and items[i].children: + if items[i] and hasattr(items[i], "children") and items[i].children: dataset_names.append(items[i].children[0].strip("'\"")) return {"type": "list_user_datasets_metadata", "dataset_names": dataset_names} @@ -675,7 +673,7 @@ class RAGFlowCLITransformer(Transformer): doc_ids = [] if len(items) > 6 and items[6] == "DOCUMENTS": for i in range(7, len(items)): - if items[i] and hasattr(items[i], 'children') and items[i].children: + if items[i] and hasattr(items[i], "children") and items[i].children: doc_id = items[i].children[0].strip("'\"") doc_ids.append(doc_id) return {"type": "list_user_documents_metadata_summary", "dataset_name": dataset_name, "document_ids": doc_ids} @@ -698,17 +696,17 @@ class RAGFlowCLITransformer(Transformer): dataset_name = None vector_size = None for i, item in enumerate(items): - if hasattr(item, 'data') and item.data == 'quoted_string': + if hasattr(item, "data") and item.data == "quoted_string": dataset_name = item.children[0].strip("'\"") - if hasattr(item, 'type') and item.type == 'NUMBER': - if i > 0 and items[i-1].type == 'SIZE' and items[i-2].type == 'VECTOR': + if hasattr(item, "type") and item.type == "NUMBER": + if i > 0 and items[i - 1].type == "SIZE" and items[i - 2].type == "VECTOR": vector_size = int(item) return {"type": "create_dataset_table", "dataset_name": dataset_name, "vector_size": vector_size} def drop_dataset_table(self, items): dataset_name = None for item in items: - if hasattr(item, 'data') and item.data == 'quoted_string': + if hasattr(item, "data") and item.data == "quoted_string": dataset_name = item.children[0].strip("'\"") return {"type": "drop_dataset_table", "dataset_name": dataset_name} @@ -792,7 +790,7 @@ class RAGFlowCLITransformer(Transformer): def update_chunk(self, items): def get_quoted_value(item): - if hasattr(item, 'children') and item.children: + if hasattr(item, "children") and item.children: return item.children[0].strip("'\"") return str(item).strip("'\"") @@ -813,16 +811,16 @@ class RAGFlowCLITransformer(Transformer): for i in range(2, len(items)): item = items[i] # Check for FROM token to stop - if hasattr(item, 'type') and item.type == 'FROM': + if hasattr(item, "type") and item.type == "FROM": break - if hasattr(item, 'children') and item.children: + if hasattr(item, "children") and item.children: tag = item.children[0].strip("'\"") tags.append(tag) # Find dataset_name: quoted_string after DATASET dataset_name = None for i, item in enumerate(items): # Check if item is a DATASET token - if hasattr(item, 'type') and item.type == 'DATASET': + if hasattr(item, "type") and item.type == "DATASET": # Next item should be quoted_string dataset_name = items[i + 1].children[0].strip("'\"") break @@ -835,10 +833,10 @@ class RAGFlowCLITransformer(Transformer): # Check if it's "REMOVE ALL CHUNKS" for item in items: - if hasattr(item, 'type') and item.type == 'ALL': + if hasattr(item, "type") and item.type == "ALL": # Find doc_id for j, inner_item in enumerate(items): - if hasattr(inner_item, 'type') and inner_item.type == 'DOCUMENT': + if hasattr(inner_item, "type") and inner_item.type == "DOCUMENT": doc_id = items[j + 1].children[0].strip("'\"") return {"type": "remove_chunks", "doc_id": doc_id, "delete_all": True} @@ -846,12 +844,12 @@ class RAGFlowCLITransformer(Transformer): chunk_ids = [] doc_id = None for i, item in enumerate(items): - if hasattr(item, 'type') and item.type == 'DOCUMENT': + if hasattr(item, "type") and item.type == "DOCUMENT": doc_id = items[i + 1].children[0].strip("'\"") - elif hasattr(item, 'children') and item.children: + elif hasattr(item, "children") and item.children: val = item.children[0].strip("'\"") # Skip if it's "FROM" or "DOCUMENT" - if val.upper() in ['FROM', 'DOCUMENT']: + if val.upper() in ["FROM", "DOCUMENT"]: continue chunk_ids.append(val) diff --git a/admin/client/ragflow_cli.py b/admin/client/ragflow_cli.py index e7378790cc..1a48af07ba 100644 --- a/admin/client/ragflow_cli.py +++ b/admin/client/ragflow_cli.py @@ -36,6 +36,7 @@ from user import login_user warnings.filterwarnings("ignore", category=getpass.GetPassWarning) + def encrypt(input_string): pub = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----" pub_key = RSA.importKey(pub) @@ -49,9 +50,6 @@ def encode_to_base64(input_string): return base64_encoded.decode("utf-8") - - - class RAGFlowCLI(Cmd): def __init__(self): super().__init__() @@ -240,9 +238,9 @@ class RAGFlowCLI(Cmd): print(r""" ____ ___ ______________ ________ ____ / __ \/ | / ____/ ____/ /___ _ __ / ____/ / / _/ - / /_/ / /| |/ / __/ /_ / / __ \ | /| / / / / / / / / - / _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / /___/ /____/ / - /_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ \____/_____/___/ + / /_/ / /| |/ / __/ /_ / / __ \ | /| / / / / / / / / + / _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / /___/ /____/ / + /_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ \____/_____/___/ """) self.cmdloop() @@ -254,15 +252,13 @@ class RAGFlowCLI(Cmd): result = self.parse_command(command) self.execute_command(result) - def parse_connection_args(self, args: List[str]) -> Dict[str, Any]: parser = argparse.ArgumentParser(description="RAGFlow CLI Client", add_help=False) parser.add_argument("-h", "--host", default="127.0.0.1", help="Admin or RAGFlow service host") parser.add_argument("-p", "--port", type=int, default=9381, help="Admin or RAGFlow service port") parser.add_argument("-w", "--password", default="admin", type=str, help="Superuser password") parser.add_argument("-t", "--type", default="admin", type=str, help="CLI mode, admin or user") - parser.add_argument("-u", "--username", default=None, - help="Username (email). In admin mode defaults to admin@ragflow.io, in user mode required.") + parser.add_argument("-u", "--username", default=None, help="Username (email). In admin mode defaults to admin@ragflow.io, in user mode required.") parser.add_argument("command", nargs="?", help="Single command") try: parsed_args, remaining_args = parser.parse_known_args(args) @@ -274,7 +270,7 @@ class RAGFlowCLI(Cmd): if remaining_args: if remaining_args[0] == "command": - command_str = ' '.join(remaining_args[1:]) + ';' + command_str = " ".join(remaining_args[1:]) + ";" auth = True if remaining_args[1] == "register": auth = False @@ -282,28 +278,14 @@ class RAGFlowCLI(Cmd): if username is None: print("Error: username (-u) is required in user mode") return {"error": "Username required"} - return { - "host": parsed_args.host, - "port": parsed_args.port, - "password": parsed_args.password, - "type": parsed_args.type, - "username": username, - "command": command_str, - "auth": auth - } + return {"host": parsed_args.host, "port": parsed_args.port, "password": parsed_args.password, "type": parsed_args.type, "username": username, "command": command_str, "auth": auth} else: return {"error": "Invalid command"} else: auth = True if username is None: auth = False - return { - "host": parsed_args.host, - "port": parsed_args.port, - "type": parsed_args.type, - "username": username, - "auth": auth - } + return {"host": parsed_args.host, "port": parsed_args.port, "type": parsed_args.type, "username": username, "auth": auth} except SystemExit: return {"error": "Invalid connection arguments"} @@ -321,6 +303,7 @@ class RAGFlowCLI(Cmd): # print(f"Parsed command: {command_dict}") run_command(self.ragflow_client, command_dict) + def main(): cli = RAGFlowCLI() diff --git a/admin/client/ragflow_client.py b/admin/client/ragflow_client.py index b8102520ad..a41e8926eb 100644 --- a/admin/client/ragflow_client.py +++ b/admin/client/ragflow_client.py @@ -71,6 +71,7 @@ class RAGFlowClient: user_password: str = command.get("password") if not user_password: import getpass + user_password = getpass.getpass("Password: ") try: token = login_user(self.http_client, self.server_type, email, user_password) @@ -86,8 +87,7 @@ class RAGFlowClient: def ping_server(self, command): iterations = command.get("iterations", 1) if iterations > 1: - response = self.http_client.request("GET", "/system/ping", use_api_base=True, auth_kind="web", - iterations=iterations) + response = self.http_client.request("GET", "/system/ping", use_api_base=True, auth_kind="web", iterations=iterations) return response else: response = self.http_client.request("GET", "/system/ping", use_api_base=True, auth_kind="web") @@ -106,8 +106,7 @@ class RAGFlowClient: enc_password = encrypt_password(password) print(f"Register user: {nickname}, email: {username}, password: ******") payload = {"email": username, "nickname": nickname, "password": enc_password} - response = self.http_client.request(method="POST", path="/users", - json_body=payload, use_api_base=True, auth_kind="web") + response = self.http_client.request(method="POST", path="/users", json_body=payload, use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code == 200: if res_json["code"] == 0: @@ -135,8 +134,7 @@ class RAGFlowClient: service_id: int = command["number"] - response = self.http_client.request("GET", f"/admin/services/{service_id}", use_api_base=True, - auth_kind="admin") + response = self.http_client.request("GET", f"/admin/services/{service_id}", use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: res_data = res_json["data"] @@ -226,9 +224,7 @@ class RAGFlowClient: password_tree: Tree = command["password"] password: str = password_tree.children[0].strip("'\"") print(f"Alter user: {user_name}, password: ******") - response = self.http_client.request("PUT", f"/admin/users/{user_name}/password", - json_body={"new_password": encrypt_password(password)}, use_api_base=True, - auth_kind="admin") + response = self.http_client.request("PUT", f"/admin/users/{user_name}/password", json_body={"new_password": encrypt_password(password)}, use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: print(res_json["message"]) @@ -247,9 +243,7 @@ class RAGFlowClient: print(f"Create user: {user_name}, password: ******, role: {role}") # enpass1 = encrypt(password) enc_password = encrypt_password(password) - response = self.http_client.request(method="POST", path="/admin/users", - json_body={"username": user_name, "password": enc_password, "role": role}, - use_api_base=True, auth_kind="admin") + response = self.http_client.request(method="POST", path="/admin/users", json_body={"username": user_name, "password": enc_password, "role": role}, use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: self._print_table_simple(res_json["data"]) @@ -266,9 +260,7 @@ class RAGFlowClient: activate_status: str = activate_tree.children[0].strip("'\"") if activate_status.lower() in ["on", "off"]: print(f"Alter user {user_name} activate status, turn {activate_status.lower()}.") - response = self.http_client.request("PUT", f"/admin/users/{user_name}/activate", - json_body={"activate_status": activate_status}, use_api_base=True, - auth_kind="admin") + response = self.http_client.request("PUT", f"/admin/users/{user_name}/activate", json_body={"activate_status": activate_status}, use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: print(res_json["message"]) @@ -283,14 +275,12 @@ class RAGFlowClient: user_name_tree: Tree = command["user_name"] user_name: str = user_name_tree.children[0].strip("'\"") - response = self.http_client.request("PUT", f"/admin/users/{user_name}/admin", use_api_base=True, - auth_kind="admin") + response = self.http_client.request("PUT", f"/admin/users/{user_name}/admin", use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: print(res_json["message"]) else: - print( - f"Fail to grant {user_name} admin authorization, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to grant {user_name} admin authorization, code: {res_json['code']}, message: {res_json['message']}") def revoke_admin(self, command): if self.server_type != "admin": @@ -298,14 +288,12 @@ class RAGFlowClient: user_name_tree: Tree = command["user_name"] user_name: str = user_name_tree.children[0].strip("'\"") - response = self.http_client.request("DELETE", f"/admin/users/{user_name}/admin", use_api_base=True, - auth_kind="admin") + response = self.http_client.request("DELETE", f"/admin/users/{user_name}/admin", use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: print(res_json["message"]) else: - print( - f"Fail to revoke {user_name} admin authorization, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to revoke {user_name} admin authorization, code: {res_json['code']}, message: {res_json['message']}") def create_role(self, command): if self.server_type != "admin": @@ -319,10 +307,7 @@ class RAGFlowClient: desc_str = desc_tree.children[0].strip("'\"") print(f"create role name: {role_name}, description: {desc_str}") - response = self.http_client.request("POST", "/admin/roles", - json_body={"role_name": role_name, "description": desc_str}, - use_api_base=True, - auth_kind="admin") + response = self.http_client.request("POST", "/admin/roles", json_body={"role_name": role_name, "description": desc_str}, use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: self._print_table_simple(res_json["data"]) @@ -336,9 +321,7 @@ class RAGFlowClient: role_name_tree: Tree = command["role_name"] role_name: str = role_name_tree.children[0].strip("'\"") print(f"drop role name: {role_name}") - response = self.http_client.request("DELETE", f"/admin/roles/{role_name}", - use_api_base=True, - auth_kind="admin") + response = self.http_client.request("DELETE", f"/admin/roles/{role_name}", use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: self._print_table_simple(res_json["data"]) @@ -355,24 +338,18 @@ class RAGFlowClient: desc_str: str = desc_tree.children[0].strip("'\"") print(f"alter role name: {role_name}, description: {desc_str}") - response = self.http_client.request("PUT", f"/admin/roles/{role_name}", - json_body={"description": desc_str}, - use_api_base=True, - auth_kind="admin") + response = self.http_client.request("PUT", f"/admin/roles/{role_name}", json_body={"description": desc_str}, use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: self._print_table_simple(res_json["data"]) else: - print( - f"Fail to update role {role_name} with description: {desc_str}, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to update role {role_name} with description: {desc_str}, code: {res_json['code']}, message: {res_json['message']}") def list_roles(self, command): if self.server_type != "admin": print("This command is only allowed in ADMIN mode") - response = self.http_client.request("GET", "/admin/roles", - use_api_base=True, - auth_kind="admin") + response = self.http_client.request("GET", "/admin/roles", use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: self._print_table_simple(res_json["data"]) @@ -386,9 +363,7 @@ class RAGFlowClient: role_name_tree: Tree = command["role_name"] role_name: str = role_name_tree.children[0].strip("'\"") print(f"show role: {role_name}") - response = self.http_client.request("GET", f"/admin/roles/{role_name}/permission", - use_api_base=True, - auth_kind="admin") + response = self.http_client.request("GET", f"/admin/roles/{role_name}/permission", use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: self._print_table_simple(res_json["data"]) @@ -409,15 +384,12 @@ class RAGFlowClient: action_str: str = action_tree.children[0].strip("'\"") actions.append(action_str) print(f"grant role_name: {role_name_str}, resource: {resource_str}, actions: {actions}") - response = self.http_client.request("POST", f"/admin/roles/{role_name_str}/permission", - json_body={"actions": actions, "resource": resource_str}, use_api_base=True, - auth_kind="admin") + response = self.http_client.request("POST", f"/admin/roles/{role_name_str}/permission", json_body={"actions": actions, "resource": resource_str}, use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: self._print_table_simple(res_json["data"]) else: - print( - f"Fail to grant role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to grant role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}") def revoke_permission(self, command): if self.server_type != "admin": @@ -433,15 +405,12 @@ class RAGFlowClient: action_str: str = action_tree.children[0].strip("'\"") actions.append(action_str) print(f"revoke role_name: {role_name_str}, resource: {resource_str}, actions: {actions}") - response = self.http_client.request("DELETE", f"/admin/roles/{role_name_str}/permission", - json_body={"actions": actions, "resource": resource_str}, use_api_base=True, - auth_kind="admin") + response = self.http_client.request("DELETE", f"/admin/roles/{role_name_str}/permission", json_body={"actions": actions, "resource": resource_str}, use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: self._print_table_simple(res_json["data"]) else: - print( - f"Fail to revoke role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to revoke role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}") def alter_user_role(self, command): if self.server_type != "admin": @@ -452,15 +421,12 @@ class RAGFlowClient: user_name_tree: Tree = command["user_name"] user_name_str: str = user_name_tree.children[0].strip("'\"") print(f"alter_user_role user_name: {user_name_str}, role_name: {role_name_str}") - response = self.http_client.request("PUT", f"/admin/users/{user_name_str}/role", - json_body={"role_name": role_name_str}, use_api_base=True, - auth_kind="admin") + response = self.http_client.request("PUT", f"/admin/users/{user_name_str}/role", json_body={"role_name": role_name_str}, use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: self._print_table_simple(res_json["data"]) else: - print( - f"Fail to alter user: {user_name_str} to role {role_name_str}, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to alter user: {user_name_str} to role {role_name_str}, code: {res_json['code']}, message: {res_json['message']}") def show_user_permission(self, command): if self.server_type != "admin": @@ -469,14 +435,12 @@ class RAGFlowClient: user_name_tree: Tree = command["user_name"] user_name_str: str = user_name_tree.children[0].strip("'\"") print(f"show_user_permission user_name: {user_name_str}") - response = self.http_client.request("GET", f"/admin/users/{user_name_str}/permission", use_api_base=True, - auth_kind="admin") + response = self.http_client.request("GET", f"/admin/users/{user_name_str}/permission", use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: self._print_table_simple(res_json["data"]) else: - print( - f"Fail to show user: {user_name_str} permission, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to show user: {user_name_str} permission, code: {res_json['code']}, message: {res_json['message']}") def generate_key(self, command: dict[str, Any]) -> None: if self.server_type != "admin": @@ -485,14 +449,12 @@ class RAGFlowClient: username_tree: Tree = command["user_name"] user_name: str = username_tree.children[0].strip("'\"") print(f"Generating API key for user: {user_name}") - response = self.http_client.request("POST", f"/admin/users/{user_name}/keys", use_api_base=True, - auth_kind="admin") + response = self.http_client.request("POST", f"/admin/users/{user_name}/keys", use_api_base=True, auth_kind="admin") res_json: dict[str, Any] = response.json() if response.status_code == 200: self._print_table_simple(res_json["data"]) else: - print( - f"Failed to generate key for user {user_name}, code: {res_json['code']}, message: {res_json['message']}") + print(f"Failed to generate key for user {user_name}, code: {res_json['code']}, message: {res_json['message']}") def list_keys(self, command: dict[str, Any]) -> None: if self.server_type != "admin": @@ -501,8 +463,7 @@ class RAGFlowClient: username_tree: Tree = command["user_name"] user_name: str = username_tree.children[0].strip("'\"") print(f"Listing API keys for user: {user_name}") - response = self.http_client.request("GET", f"/admin/users/{user_name}/keys", use_api_base=True, - auth_kind="admin") + response = self.http_client.request("GET", f"/admin/users/{user_name}/keys", use_api_base=True, auth_kind="admin") res_json: dict[str, Any] = response.json() if response.status_code == 200: self._print_table_simple(res_json["data"]) @@ -520,8 +481,7 @@ class RAGFlowClient: print(f"Dropping API key for user: {user_name}") # URL encode the key to handle special characters encoded_key: str = urllib.parse.quote(key, safe="") - response = self.http_client.request("DELETE", f"/admin/users/{user_name}/keys/{encoded_key}", use_api_base=True, - auth_kind="admin") + response = self.http_client.request("DELETE", f"/admin/users/{user_name}/keys/{encoded_key}", use_api_base=True, auth_kind="admin") res_json: dict[str, Any] = response.json() if response.status_code == 200: print(res_json["message"]) @@ -534,23 +494,19 @@ class RAGFlowClient: var_name = _strip_tree_value(command["var_name"]) var_value = _strip_tree_value(command["var_value"]) - response = self.http_client.request("PUT", "/admin/variables", - json_body={"var_name": var_name, "var_value": var_value}, use_api_base=True, - auth_kind="admin") + response = self.http_client.request("PUT", "/admin/variables", json_body={"var_name": var_name, "var_value": var_value}, use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: print(res_json["message"]) else: - print( - f"Fail to set variable {var_name} to {var_value}, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to set variable {var_name} to {var_value}, code: {res_json['code']}, message: {res_json['message']}") def show_variable(self, command): if self.server_type != "admin": print("This command is only allowed in ADMIN mode") var_name = _strip_tree_value(command["var_name"]) - response = self.http_client.request(method="GET", path="/admin/variables", json_body={"var_name": var_name}, - use_api_base=True, auth_kind="admin") + response = self.http_client.request(method="GET", path="/admin/variables", json_body={"var_name": var_name}, use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: self._print_table_simple(res_json["data"]) @@ -604,8 +560,7 @@ class RAGFlowClient: if self.server_type != "admin": print("This command is only allowed in ADMIN mode") license = command["license"] - response = self.http_client.request("POST", "/admin/license", json_body={"license": license}, use_api_base=True, - auth_kind="admin") + response = self.http_client.request("POST", "/admin/license", json_body={"license": license}, use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: print("Set license successfully") @@ -617,9 +572,7 @@ class RAGFlowClient: print("This command is only allowed in ADMIN mode") value1 = command["value1"] value2 = command["value2"] - response = self.http_client.request("POST", "/admin/license/config", - json_body={"value1": value1, "value2": value2}, use_api_base=True, - auth_kind="admin") + response = self.http_client.request("POST", "/admin/license/config", json_body={"value1": value1, "value2": value2}, use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: print("Set license successfully") @@ -690,8 +643,7 @@ class RAGFlowClient: user_name: str = username_tree.children[0].strip("'\"") print(f"Listing all datasets of user: {user_name}") - response = self.http_client.request("GET", f"/admin/users/{user_name}/datasets", use_api_base=True, - auth_kind="admin") + response = self.http_client.request("GET", f"/admin/users/{user_name}/datasets", use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: table_data = res_json["data"] @@ -708,8 +660,7 @@ class RAGFlowClient: username_tree: Tree = command["user_name"] user_name: str = username_tree.children[0].strip("'\"") print(f"Listing all agents of user: {user_name}") - response = self.http_client.request("GET", f"/admin/users/{user_name}/agents", use_api_base=True, - auth_kind="admin") + response = self.http_client.request("GET", f"/admin/users/{user_name}/agents", use_api_base=True, auth_kind="admin") res_json = response.json() if response.status_code == 200: table_data = res_json["data"] @@ -733,8 +684,7 @@ class RAGFlowClient: # Step 1: Add provider provider_payload = {"provider_name": provider_name} - provider_response = self.http_client.request("PUT", "/providers", json_body=provider_payload, - use_api_base=True, auth_kind="web") + provider_response = self.http_client.request("PUT", "/providers", json_body=provider_payload, use_api_base=True, auth_kind="web") provider_res = provider_response.json() if provider_response.status_code == 200 and provider_res.get("code") == 0: print(f"Success to add provider {provider_name}") @@ -747,15 +697,8 @@ class RAGFlowClient: return # Step 2: Add instance - instance_payload = { - "instance_name": "default", - "api_key": api_key, - "region": "default", - "base_url": "" - } - instance_response = self.http_client.request("POST", f"/providers/{provider_name}/instances", - json_body=instance_payload, use_api_base=True, - auth_kind="web") + instance_payload = {"instance_name": "default", "api_key": api_key, "region": "default", "base_url": ""} + instance_response = self.http_client.request("POST", f"/providers/{provider_name}/instances", json_body=instance_payload, use_api_base=True, auth_kind="web") instance_res = instance_response.json() if instance_response.status_code == 200 and instance_res.get("code") == 0: print(f"Success to add instance for provider {provider_name}") @@ -771,8 +714,7 @@ class RAGFlowClient: print("This command is only allowed in USER mode") return provider_name: str = command["provider_name"] - response = self.http_client.request("DELETE", f"/providers/{provider_name}", use_api_base=True, - auth_kind="web") + response = self.http_client.request("DELETE", f"/providers/{provider_name}", use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json.get("code") == 0: print(f"Success to drop model provider {provider_name}") @@ -810,8 +752,7 @@ class RAGFlowClient: "model_type": model_type, "model_name": model_name, } - response = self.http_client.request("PATCH", "/models/default", json_body=payload, use_api_base=True, - auth_kind="web") + response = self.http_client.request("PATCH", "/models/default", json_body=payload, use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json.get("code") == 0: print(f"Success to set default {model_type} to {model_id}") @@ -830,8 +771,7 @@ class RAGFlowClient: return payload = {"model_type": model_type} - response = self.http_client.request("PATCH", "/models/default", json_body=payload, use_api_base=True, - auth_kind="web") + response = self.http_client.request("PATCH", "/models/default", json_body=payload, use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json.get("code") == 0: print(f"Success to reset default {model_type}") @@ -861,8 +801,7 @@ class RAGFlowClient: iterations = command.get("iterations", 1) if iterations > 1: - response = self.http_client.request("GET", "/datasets", use_api_base=True, auth_kind="web", - iterations=iterations) + response = self.http_client.request("GET", "/datasets", use_api_base=True, auth_kind="web", iterations=iterations) return response else: response = self.http_client.request("GET", "/datasets", use_api_base=True, auth_kind="web") @@ -876,16 +815,12 @@ class RAGFlowClient: def create_user_dataset(self, command): if self.server_type != "user": print("This command is only allowed in USER mode") - payload = { - "name": command["dataset_name"], - "embedding_model": command["embedding"] - } + payload = {"name": command["dataset_name"], "embedding_model": command["embedding"]} if "parser_id" in command: payload["chunk_method"] = command["parser"] if "pipeline" in command: payload["pipeline_id"] = command["pipeline"] - response = self.http_client.request("POST", "/datasets", json_body=payload, use_api_base=True, - auth_kind="web") + response = self.http_client.request("POST", "/datasets", json_body=payload, use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code == 200: self._print_table_simple(res_json["data"]) @@ -981,8 +916,7 @@ class RAGFlowClient: dataset_ids = [dataset_id for _, dataset_id in valid_datasets] kb_ids_param = ",".join(dataset_ids) - response = self.http_client.request("GET", f"/kb/get_meta?kb_ids={kb_ids_param}", - use_api_base=False, auth_kind="web") + response = self.http_client.request("GET", f"/kb/get_meta?kb_ids={kb_ids_param}", use_api_base=False, auth_kind="web") res_json = response.json() if response.status_code != 200: print(f"Fail to get metadata, code: {res_json.get('code')}, message: {res_json.get('message')}") @@ -996,11 +930,7 @@ class RAGFlowClient: table_data = [] for field_name, values_dict in meta.items(): for value, docs in values_dict.items(): - table_data.append({ - "field": field_name, - "value": value, - "doc_ids": ", ".join(docs) - }) + table_data.append({"field": field_name, "value": value, "doc_ids": ", ".join(docs)}) self._print_table_simple(table_data) def list_user_documents_metadata_summary(self, command_dict): @@ -1018,8 +948,7 @@ class RAGFlowClient: payload = {"kb_id": kb_id} if doc_ids: payload["doc_ids"] = doc_ids - response = self.http_client.request("POST", "/document/metadata/summary", json_body=payload, - use_api_base=False, auth_kind="web") + response = self.http_client.request("POST", "/document/metadata/summary", json_body=payload, use_api_base=False, auth_kind="web") res_json = response.json() if response.status_code == 200: summary = res_json.get("data", {}).get("summary", {}) @@ -1086,16 +1015,11 @@ class RAGFlowClient: "quote": True, "keyword": False, "tts": False, - "system": "You are an intelligent assistant. Your primary function is to answer questions based strictly on the provided knowledge base.\n\n **Essential Rules:**\n - Your answer must be derived **solely** from this knowledge base: `{knowledge}`.\n - **When information is available**: Summarize the content to give a detailed answer.\n - **When information is unavailable**: Your response must contain this exact sentence: \"The answer you are looking for is not found in the knowledge base!\"\n - **Always consider** the entire conversation history.", + "system": 'You are an intelligent assistant. Your primary function is to answer questions based strictly on the provided knowledge base.\n\n **Essential Rules:**\n - Your answer must be derived **solely** from this knowledge base: `{knowledge}`.\n - **When information is available**: Summarize the content to give a detailed answer.\n - **When information is unavailable**: Your response must contain this exact sentence: "The answer you are looking for is not found in the knowledge base!"\n - **Always consider** the entire conversation history.', "refine_multiturn": False, "use_kg": False, "reasoning": False, - "parameters": [ - { - "key": "knowledge", - "optional": False - } - ], + "parameters": [{"key": "knowledge", "optional": False}], "toc_enhance": False, }, "similarity_threshold": 0.2, @@ -1136,8 +1060,7 @@ class RAGFlowClient: # Build payload payload = {"kb_id": dataset_id, "vector_size": vector_size} # Call API - response = self.http_client.request("POST", "/kb/doc_engine_table", json_body=payload, - use_api_base=False, auth_kind="web") + response = self.http_client.request("POST", "/kb/doc_engine_table", json_body=payload, use_api_base=False, auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json.get("code") == 0: print(f"Success to create table for dataset: {dataset_name}") @@ -1155,8 +1078,7 @@ class RAGFlowClient: return # Call API to delete table payload = {"kb_id": dataset_id} - response = self.http_client.request("DELETE", "/kb/doc_engine_table", json_body=payload, - use_api_base=False, auth_kind="web") + response = self.http_client.request("DELETE", "/kb/doc_engine_table", json_body=payload, use_api_base=False, auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json.get("code") == 0: print(f"Success to drop table for dataset: {dataset_name}") @@ -1168,8 +1090,7 @@ class RAGFlowClient: print("This command is only allowed in USER mode") return # Call API to create metadata table - response = self.http_client.request("POST", "/tenant/doc_engine_metadata_table", - use_api_base=False, auth_kind="web") + response = self.http_client.request("POST", "/tenant/doc_engine_metadata_table", use_api_base=False, auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json.get("code") == 0: print("Success to create metadata table") @@ -1181,8 +1102,7 @@ class RAGFlowClient: print("This command is only allowed in USER mode") return # Call API to delete metadata table - response = self.http_client.request("DELETE", "/tenant/doc_engine_metadata_table", - use_api_base=False, auth_kind="web") + response = self.http_client.request("DELETE", "/tenant/doc_engine_metadata_table", use_api_base=False, auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json.get("code") == 0: print("Success to drop metadata table") @@ -1225,8 +1145,7 @@ class RAGFlowClient: def _list_chat_sessions(self, dialog_id): """List all sessions (conversations) for a given dialog.""" - response = self.http_client.request("GET", f"/chats/{dialog_id}/conversations", use_api_base=True, - auth_kind="web") + response = self.http_client.request("GET", f"/chats/{dialog_id}/conversations", use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json["code"] == 0: return res_json["data"] @@ -1242,14 +1161,12 @@ class RAGFlowClient: if dialog_id is None: return payload = {"name": "New conversation"} - response = self.http_client.request("POST", f"/chats/{dialog_id}/conversations", json_body=payload, - use_api_base=True, auth_kind="web") + response = self.http_client.request("POST", f"/chats/{dialog_id}/conversations", json_body=payload, use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json["code"] == 0: print(f"Success to create chat session for chat: {chat_name}") else: - print( - f"Fail to create chat session for chat {chat_name}, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to create chat session for chat {chat_name}, code: {res_json['code']}, message: {res_json['message']}") def drop_chat_session(self, command): if self.server_type != "user": @@ -1270,14 +1187,12 @@ class RAGFlowClient: print(f"Chat session '{session_id}' not found in chat '{chat_name}'") return payload = {"ids": to_drop_session_ids} - response = self.http_client.request("DELETE", f"/chats/{dialog_id}/conversations", json_body=payload, - use_api_base=True, auth_kind="web") + response = self.http_client.request("DELETE", f"/chats/{dialog_id}/conversations", json_body=payload, use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json["code"] == 0: print(f"Success to drop chat session '{session_id}' from chat: {chat_name}") else: - print( - f"Fail to drop chat session '{session_id}' from chat {chat_name}, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to drop chat session '{session_id}' from chat {chat_name}, code: {res_json['code']}, message: {res_json['message']}") def list_chat_sessions(self, command): if self.server_type != "user": @@ -1305,13 +1220,9 @@ class RAGFlowClient: # Prepare payload for completion API # Note: stream parameter is not sent, server defaults to stream=True - payload = { - "session_id": session_id, - "messages": [{"role": "user", "content": message}] - } + payload = {"session_id": session_id, "messages": [{"role": "user", "content": message}]} - response = self.http_client.request("POST", "/chat/completions", json_body=payload, - use_api_base=True, auth_kind="web", stream=True) + response = self.http_client.request("POST", "/chat/completions", json_body=payload, use_api_base=True, auth_kind="web", stream=True) if response.status_code != 200: print(f"Fail to chat on session, status code: {response.status_code}") @@ -1322,17 +1233,16 @@ class RAGFlowClient: for line in response.iter_lines(): if not line: continue - line_str = line.decode('utf-8') - if not line_str.startswith('data:'): + line_str = line.decode("utf-8") + if not line_str.startswith("data:"): continue data_str = line_str[5:].strip() - if data_str == '[DONE]': + if data_str == "[DONE]": break try: data_json = json.loads(data_str) if data_json.get("code") != 0: - print( - f"\nFail to chat on session, code: {data_json.get('code')}, message: {data_json.get('message', '')}") + print(f"\nFail to chat on session, code: {data_json.get('code')}, message: {data_json.get('message', '')}") return # Check if it's the final message if data_json.get("data") is True: @@ -1416,14 +1326,12 @@ class RAGFlowClient: print(f"Documents {document_names} not found in {dataset_name}") payload = {"doc_ids": document_ids, "run": 1} - response = self.http_client.request("POST", "/documents/ingest", json_body=payload, use_api_base=True, - auth_kind="web") + response = self.http_client.request("POST", "/documents/ingest", json_body=payload, use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json["code"] == 0: print(f"Success to parse {to_parse_doc_names} of {dataset_name}") else: - print( - f"Fail to parse documents {res_json["data"]["docs"]}, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to parse documents {res_json['data']['docs']}, code: {res_json['code']}, message: {res_json['message']}") def parse_dataset(self, command_dict): if self.server_type != "user": @@ -1442,8 +1350,7 @@ class RAGFlowClient: document_ids.append(doc["id"]) payload = {"doc_ids": document_ids, "run": 1} - response = self.http_client.request("POST", "/documents/ingest", json_body=payload, use_api_base=True, - auth_kind="web") + response = self.http_client.request("POST", "/documents/ingest", json_body=payload, use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json["code"] == 0: pass @@ -1483,15 +1390,7 @@ class RAGFlowClient: encoder = MultipartEncoder(fields=fields) headers = {"Content-Type": encoder.content_type} response = self.http_client.request( - "POST", - f"/datasets/{dataset_id}/documents?return_raw_files=true", - headers=headers, - data=encoder, - json_body=None, - params=None, - stream=False, - auth_kind="web", - use_api_base=True + "POST", f"/datasets/{dataset_id}/documents?return_raw_files=true", headers=headers, data=encoder, json_body=None, params=None, stream=False, auth_kind="web", use_api_base=True ) res = response.json() if res.get("code") == 0: @@ -1526,22 +1425,18 @@ class RAGFlowClient: } iterations = command_dict.get("iterations", 1) if iterations > 1: - response = self.http_client.request("POST", "/retrieval", json_body=payload, use_api_base=True, - auth_kind="web", iterations=iterations) + response = self.http_client.request("POST", "/retrieval", json_body=payload, use_api_base=True, auth_kind="web", iterations=iterations) return response else: - response = self.http_client.request("POST", "/retrieval", json_body=payload, use_api_base=True, - auth_kind="web") + response = self.http_client.request("POST", "/retrieval", json_body=payload, use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code == 200: if res_json["code"] == 0: self._print_table_simple(res_json["data"]["chunks"]) else: - print( - f"Fail to search datasets: {dataset_names}, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to search datasets: {dataset_names}, code: {res_json['code']}, message: {res_json['message']}") else: - print( - f"Fail to search datasets: {dataset_names}, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to search datasets: {dataset_names}, code: {res_json['code']}, message: {res_json['message']}") def get_chunk(self, command_dict): if self.server_type != "user": @@ -1549,8 +1444,7 @@ class RAGFlowClient: return chunk_id = command_dict["chunk_id"] - response = self.http_client.request("GET", f"/chunk/get?chunk_id={chunk_id}", use_api_base=False, - auth_kind="web") + response = self.http_client.request("GET", f"/chunk/get?chunk_id={chunk_id}", use_api_base=False, auth_kind="web") res_json = response.json() if response.status_code == 200: if res_json["code"] == 0: @@ -1568,8 +1462,7 @@ class RAGFlowClient: file_path = command_dict["file_path"] payload = {"file_path": file_path} - response = self.http_client.request("POST", "/kb/insert_from_file", json_body=payload, - use_api_base=False, auth_kind="web") + response = self.http_client.request("POST", "/kb/insert_from_file", json_body=payload, use_api_base=False, auth_kind="web") res_json = response.json() if response.status_code == 200: if res_json["code"] == 0: @@ -1589,8 +1482,7 @@ class RAGFlowClient: file_path = command_dict["file_path"] payload = {"file_path": file_path} - response = self.http_client.request("POST", "/tenant/insert_metadata_from_file", json_body=payload, - use_api_base=False, auth_kind="web") + response = self.http_client.request("POST", "/tenant/insert_metadata_from_file", json_body=payload, use_api_base=False, auth_kind="web") res_json = response.json() if response.status_code == 200: if res_json["code"] == 0: @@ -1617,8 +1509,7 @@ class RAGFlowClient: return # Get doc_id from chunk_id via GET /chunk/get - response = self.http_client.request("GET", f"/chunk/get?chunk_id={chunk_id}", use_api_base=False, - auth_kind="web") + response = self.http_client.request("GET", f"/chunk/get?chunk_id={chunk_id}", use_api_base=False, auth_kind="web") res_json = response.json() if response.status_code != 200: print(f"Fail to get chunk info, code: {res_json.get('code')}, message: {res_json.get('message')}") @@ -1655,14 +1546,8 @@ class RAGFlowClient: else: print(f"Fail to update chunk, HTTP {response.status_code}") - def _get_documents_by_ids(self, ids:list[str]): - response = self.http_client.request( - "POST", - "/document/infos", - json_body={"doc_ids": ids}, - use_api_base=False, - auth_kind="web" - ) + def _get_documents_by_ids(self, ids: list[str]): + response = self.http_client.request("POST", "/document/infos", json_body={"doc_ids": ids}, use_api_base=False, auth_kind="web") if response.status_code != 200: return f"Fail to get document info, HTTP {response.status_code}", None @@ -1687,6 +1572,7 @@ class RAGFlowClient: # Parse JSON string to dict import json + try: meta_fields = json.loads(meta_json_str) except json.JSONDecodeError as e: @@ -1713,13 +1599,7 @@ class RAGFlowClient: "meta_fields": meta_fields, } - response = self.http_client.request( - "PATCH", - f"/datasets/{dataset_id}/documents/{doc_id}", - json_body=payload, - use_api_base=True, - auth_kind="web" - ) + response = self.http_client.request("PATCH", f"/datasets/{dataset_id}/documents/{doc_id}", json_body=payload, use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code == 200: @@ -1747,8 +1627,7 @@ class RAGFlowClient: "tags": tags, } - response = self.http_client.request("POST", f"/kb/{dataset_id}/rm_tags", json_body=payload, - use_api_base=False, auth_kind="web") + response = self.http_client.request("POST", f"/kb/{dataset_id}/rm_tags", json_body=payload, use_api_base=False, auth_kind="web") res_json = response.json() if response.status_code == 200: if res_json.get("code") == 0: @@ -1771,8 +1650,7 @@ class RAGFlowClient: elif command_dict.get("chunk_ids"): payload["chunk_ids"] = command_dict["chunk_ids"] - response = self.http_client.request("POST", "/chunk/rm", json_body=payload, - use_api_base=False, auth_kind="web") + response = self.http_client.request("POST", "/chunk/rm", json_body=payload, use_api_base=False, auth_kind="web") res_json = response.json() if response.status_code == 200: if res_json.get("code") == 0: @@ -1803,15 +1681,14 @@ class RAGFlowClient: if "available_int" in command_dict: payload["available_int"] = command_dict["available_int"] - response = self.http_client.request("POST", "/chunk/list", json_body=payload, use_api_base=False, - auth_kind="web") + response = self.http_client.request("POST", "/chunk/list", json_body=payload, use_api_base=False, auth_kind="web") res_json = response.json() if response.status_code == 200: if res_json["code"] == 0: chunks = res_json["data"]["chunks"] if chunks: for i, chunk in enumerate(chunks): - print(f"\n--- Chunk {i+1} ---") + print(f"\n--- Chunk {i + 1} ---") for key, value in chunk.items(): print(f" {key}: {value}") else: @@ -1845,7 +1722,7 @@ class RAGFlowClient: all_done = True for doc in docs: if doc.get("run") != "DONE": - print(f"Document {doc["name"]} is not done, status: {doc.get("run")}") + print(f"Document {doc['name']} is not done, status: {doc.get('run')}") all_done = False break if all_done: @@ -1856,16 +1733,10 @@ class RAGFlowClient: def _list_documents(self, dataset_name: str, dataset_id: str): # Use the new RESTful API: GET /api/v1/datasets//documents - response = self.http_client.request( - "GET", - f"/datasets/{dataset_id}/documents", - use_api_base=True, - auth_kind="web" - ) + response = self.http_client.request("GET", f"/datasets/{dataset_id}/documents", use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code != 200: - print( - f"Fail to list files from dataset {dataset_name}, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to list files from dataset {dataset_name}, code: {res_json['code']}, message: {res_json['message']}") return None return res_json["data"]["docs"] @@ -2254,22 +2125,14 @@ def run_benchmark(client: RAGFlowClient, command_dict: dict): total_duration = result["duration"] qps = iterations / total_duration if total_duration > 0 else None print(f"command: {command}, Concurrency: {concurrency}, iterations: {iterations}") - print( - f"total duration: {total_duration:.4f}s, QPS: {qps}, COMMAND_COUNT: {iterations}, SUCCESS: {success_count}, FAILURE: {iterations - success_count}") + print(f"total duration: {total_duration:.4f}s, QPS: {qps}, COMMAND_COUNT: {iterations}, SUCCESS: {success_count}, FAILURE: {iterations - success_count}") pass else: results: List[Optional[dict]] = [None] * concurrency mp_context = mp.get_context("spawn") start_time = time.perf_counter() with ProcessPoolExecutor(max_workers=concurrency, mp_context=mp_context) as executor: - future_map = { - executor.submit( - run_command, - client, - command - ): idx - for idx in range(concurrency) - } + future_map = {executor.submit(run_command, client, command): idx for idx in range(concurrency)} for future in as_completed(future_map): idx = future_map[future] results[idx] = future.result() @@ -2291,7 +2154,6 @@ def run_benchmark(client: RAGFlowClient, command_dict: dict): total_command_count = iterations * concurrency qps = total_command_count / total_duration if total_duration > 0 else None print(f"command: {command}, Concurrency: {concurrency} , iterations: {iterations}") - print( - f"total duration: {total_duration:.4f}s, QPS: {qps}, COMMAND_COUNT: {total_command_count}, SUCCESS: {success_count}, FAILURE: {total_command_count - success_count}") + print(f"total duration: {total_duration:.4f}s, QPS: {qps}, COMMAND_COUNT: {total_command_count}, SUCCESS: {success_count}, FAILURE: {total_command_count - success_count}") pass diff --git a/admin/client/user.py b/admin/client/user.py index c86328f388..4d417fa2a0 100644 --- a/admin/client/user.py +++ b/admin/client/user.py @@ -29,6 +29,7 @@ def encrypt_password(password_plain: str) -> str: import base64 from Cryptodome.PublicKey import RSA from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 + def crypt(line): """ decrypt(crypt(input_string)) == base64(input_string), which frontend and ragflow_cli use. @@ -36,13 +37,11 @@ def encrypt_password(password_plain: str) -> str: pub = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----" rsa_key = RSA.importKey(pub) cipher = Cipher_pkcs1_v1_5.new(rsa_key) - password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8") + password_base64 = base64.b64encode(line.encode("utf-8")).decode("utf-8") encrypted_password = cipher.encrypt(password_base64.encode()) - return base64.b64encode(encrypted_password).decode('utf-8') + return base64.b64encode(encrypted_password).decode("utf-8") except Exception as exc: - raise AuthException( - "Password encryption unavailable; install pycryptodomex (uv sync --python 3.13 --group test)." - ) from exc + raise AuthException("Password encryption unavailable; install pycryptodomex (uv sync --python 3.13 --group test).") from exc return crypt(password_plain) diff --git a/admin/server/admin_server.py b/admin/server/admin_server.py index b7c5cd78bb..677d8a1c96 100644 --- a/admin/server/admin_server.py +++ b/admin/server/admin_server.py @@ -15,6 +15,7 @@ # import time + start_ts = time.time() import os @@ -38,26 +39,24 @@ from common.versions import get_ragflow_version stop_event = threading.Event() -if __name__ == '__main__': +if __name__ == "__main__": faulthandler.enable() init_root_logger("admin_service") logging.info(r""" - ____ ___ ______________ ___ __ _ - / __ \/ | / ____/ ____/ /___ _ __ / | ____/ /___ ___ (_)___ + ____ ___ ______________ ___ __ _ + / __ \/ | / ____/ ____/ /___ _ __ / | ____/ /___ ___ (_)___ / /_/ / /| |/ / __/ /_ / / __ \ | /| / / / /| |/ __ / __ `__ \/ / __ \ / _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / ___ / /_/ / / / / / / / / / / - /_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ /_/ |_\__,_/_/ /_/ /_/_/_/ /_/ + /_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ /_/ |_\__,_/_/ /_/ /_/_/_/ /_/ """) app = Flask(__name__) app.register_blueprint(admin_bp) app.config["SESSION_PERMANENT"] = False app.config["SESSION_TYPE"] = "filesystem" - app.config["MAX_CONTENT_LENGTH"] = int( - os.environ.get("MAX_CONTENT_LENGTH", 1024 * 1024 * 1024) - ) + app.config["MAX_CONTENT_LENGTH"] = int(os.environ.get("MAX_CONTENT_LENGTH", 1024 * 1024 * 1024)) Session(app) - logging.info(f'RAGFlow admin version: {get_ragflow_version()}') + logging.info(f"RAGFlow admin version: {get_ragflow_version()}") show_configs() login_manager = LoginManager() login_manager.init_app(app) diff --git a/admin/server/auth.py b/admin/server/auth.py index 36cf60f8f3..891020077a 100644 --- a/admin/server/auth.py +++ b/admin/server/auth.py @@ -70,9 +70,7 @@ def setup_auth(login_manager): logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars") return None - user = UserService.query( - access_token=access_token, status=StatusEnum.VALID.value - ) + user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value) if user: if not user[0].access_token or not user[0].access_token.strip(): logging.warning(f"User {user[0].email} has empty access_token in database") @@ -126,19 +124,13 @@ def add_tenant_for_admin(user_info: dict, role: str): "img2txt_id": settings.IMAGE2TEXT_MDL, "rerank_id": settings.RERANK_MDL, } - usr_tenant = { - "tenant_id": user_info["id"], - "user_id": user_info["id"], - "invited_by": user_info["id"], - "role": role - } + usr_tenant = {"tenant_id": user_info["id"], "user_id": user_info["id"], "invited_by": user_info["id"], "role": role} # tenant_llm = get_init_tenant_llm(user_info["id"]) TenantService.insert(**tenant) UserTenantService.insert(**usr_tenant) # TenantLLMService.insert_many(tenant_llm) - logging.info( - f"Added tenant for email: {user_info['email']}, A default tenant has been set; changing the default models after login is strongly recommended.") + logging.info(f"Added tenant for email: {user_info['email']}, A default tenant has been set; changing the default models after login is strongly recommended.") def check_admin_auth(func): @@ -212,28 +204,17 @@ def login_verify(f): @wraps(f) def decorated(*args, **kwargs): auth = request.authorization - if not auth or 'username' not in auth.parameters or 'password' not in auth.parameters: - return jsonify({ - "code": 401, - "message": "Authentication required", - "data": None - }), 200 + if not auth or "username" not in auth.parameters or "password" not in auth.parameters: + return jsonify({"code": 401, "message": "Authentication required", "data": None}), 200 - username = auth.parameters['username'] - password = auth.parameters['password'] + username = auth.parameters["username"] + password = auth.parameters["password"] try: if not check_admin(username, password): - return jsonify({ - "code": 500, - "message": "Access denied", - "data": None - }), 200 + return jsonify({"code": 500, "message": "Access denied", "data": None}), 200 except Exception: logging.exception("An error occurred during admin login verification.") - return jsonify({ - "code": 500, - "message": "An internal server error occurred." - }), 200 + return jsonify({"code": 500, "message": "An internal server error occurred."}), 200 return f(*args, **kwargs) diff --git a/admin/server/config.py b/admin/server/config.py index 61432ff29f..9982b9964c 100644 --- a/admin/server/config.py +++ b/admin/server/config.py @@ -34,8 +34,7 @@ class BaseConfig(BaseModel): detail_func_name: str def to_dict(self) -> dict[str, Any]: - return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port, - 'service_type': self.service_type} + return {"id": self.id, "name": self.name, "host": self.host, "port": self.port, "service_type": self.service_type} class ServiceConfigs: @@ -63,11 +62,11 @@ class MetaConfig(BaseConfig): def to_dict(self) -> dict[str, Any]: result = super().to_dict() - if 'extra' not in result: - result['extra'] = dict() - extra_dict = result['extra'].copy() - extra_dict['meta_type'] = self.meta_type - result['extra'] = extra_dict + if "extra" not in result: + result["extra"] = dict() + extra_dict = result["extra"].copy() + extra_dict["meta_type"] = self.meta_type + result["extra"] = extra_dict return result @@ -77,21 +76,20 @@ class MySQLConfig(MetaConfig): def to_dict(self) -> dict[str, Any]: result = super().to_dict() - if 'extra' not in result: - result['extra'] = dict() - extra_dict = result['extra'].copy() - extra_dict['username'] = self.username - extra_dict['password'] = self.password - result['extra'] = extra_dict + if "extra" not in result: + result["extra"] = dict() + extra_dict = result["extra"].copy() + extra_dict["username"] = self.username + extra_dict["password"] = self.password + result["extra"] = extra_dict return result class PostgresConfig(MetaConfig): - def to_dict(self) -> dict[str, Any]: result = super().to_dict() - if 'extra' not in result: - result['extra'] = dict() + if "extra" not in result: + result["extra"] = dict() return result @@ -100,11 +98,11 @@ class RetrievalConfig(BaseConfig): def to_dict(self) -> dict[str, Any]: result = super().to_dict() - if 'extra' not in result: - result['extra'] = dict() - extra_dict = result['extra'].copy() - extra_dict['retrieval_type'] = self.retrieval_type - result['extra'] = extra_dict + if "extra" not in result: + result["extra"] = dict() + extra_dict = result["extra"].copy() + extra_dict["retrieval_type"] = self.retrieval_type + result["extra"] = extra_dict return result @@ -113,11 +111,11 @@ class InfinityConfig(RetrievalConfig): def to_dict(self) -> dict[str, Any]: result = super().to_dict() - if 'extra' not in result: - result['extra'] = dict() - extra_dict = result['extra'].copy() - extra_dict['db_name'] = self.db_name - result['extra'] = extra_dict + if "extra" not in result: + result["extra"] = dict() + extra_dict = result["extra"].copy() + extra_dict["db_name"] = self.db_name + result["extra"] = extra_dict return result @@ -127,12 +125,12 @@ class ElasticsearchConfig(RetrievalConfig): def to_dict(self) -> dict[str, Any]: result = super().to_dict() - if 'extra' not in result: - result['extra'] = dict() - extra_dict = result['extra'].copy() - extra_dict['username'] = self.username - extra_dict['password'] = self.password - result['extra'] = extra_dict + if "extra" not in result: + result["extra"] = dict() + extra_dict = result["extra"].copy() + extra_dict["username"] = self.username + extra_dict["password"] = self.password + result["extra"] = extra_dict return result @@ -141,11 +139,11 @@ class MessageQueueConfig(BaseConfig): def to_dict(self) -> dict[str, Any]: result = super().to_dict() - if 'extra' not in result: - result['extra'] = dict() - extra_dict = result['extra'].copy() - extra_dict['mq_type'] = self.mq_type - result['extra'] = extra_dict + if "extra" not in result: + result["extra"] = dict() + extra_dict = result["extra"].copy() + extra_dict["mq_type"] = self.mq_type + result["extra"] = extra_dict return result @@ -155,30 +153,28 @@ class RedisConfig(MessageQueueConfig): def to_dict(self) -> dict[str, Any]: result = super().to_dict() - if 'extra' not in result: - result['extra'] = dict() - extra_dict = result['extra'].copy() - extra_dict['database'] = self.database - extra_dict['password'] = self.password - result['extra'] = extra_dict + if "extra" not in result: + result["extra"] = dict() + extra_dict = result["extra"].copy() + extra_dict["database"] = self.database + extra_dict["password"] = self.password + result["extra"] = extra_dict return result class RabbitMQConfig(MessageQueueConfig): - def to_dict(self) -> dict[str, Any]: result = super().to_dict() - if 'extra' not in result: - result['extra'] = dict() + if "extra" not in result: + result["extra"] = dict() return result class RAGFlowServerConfig(BaseConfig): - def to_dict(self) -> dict[str, Any]: result = super().to_dict() - if 'extra' not in result: - result['extra'] = dict() + if "extra" not in result: + result["extra"] = dict() return result @@ -187,9 +183,9 @@ class TaskExecutorConfig(BaseConfig): def to_dict(self) -> dict[str, Any]: result = super().to_dict() - if 'extra' not in result: - result['extra'] = dict() - result['extra']['message_queue_type'] = self.message_queue_type + if "extra" not in result: + result["extra"] = dict() + result["extra"]["message_queue_type"] = self.message_queue_type return result @@ -198,11 +194,11 @@ class FileStoreConfig(BaseConfig): def to_dict(self) -> dict[str, Any]: result = super().to_dict() - if 'extra' not in result: - result['extra'] = dict() - extra_dict = result['extra'].copy() - extra_dict['store_type'] = self.store_type - result['extra'] = extra_dict + if "extra" not in result: + result["extra"] = dict() + extra_dict = result["extra"].copy() + extra_dict["store_type"] = self.store_type + result["extra"] = extra_dict return result @@ -212,12 +208,12 @@ class MinioConfig(FileStoreConfig): def to_dict(self) -> dict[str, Any]: result = super().to_dict() - if 'extra' not in result: - result['extra'] = dict() - extra_dict = result['extra'].copy() - extra_dict['user'] = self.user - extra_dict['password'] = self.password - result['extra'] = extra_dict + if "extra" not in result: + result["extra"] = dict() + extra_dict = result["extra"].copy() + extra_dict["user"] = self.user + extra_dict["password"] = self.password + result["extra"] = extra_dict return result @@ -229,106 +225,105 @@ def load_configurations(config_path: str) -> list[BaseConfig]: for k, v in raw_configs.items(): match k: case "ragflow": - name: str = f'ragflow_{ragflow_count}' - host: str = v['host'] - http_port: int = v['http_port'] - config = RAGFlowServerConfig(id=id_count, name=name, host=host, port=http_port, - service_type="ragflow_server", - detail_func_name="check_ragflow_server_alive") + name: str = f"ragflow_{ragflow_count}" + host: str = v["host"] + http_port: int = v["http_port"] + config = RAGFlowServerConfig(id=id_count, name=name, host=host, port=http_port, service_type="ragflow_server", detail_func_name="check_ragflow_server_alive") configurations.append(config) id_count += 1 case "es": - name: str = 'elasticsearch' - url = v['hosts'] + name: str = "elasticsearch" + url = v["hosts"] parsed = urlparse(url) host: str = parsed.hostname port: int = parsed.port - username: str = v.get('username') - password: str = v.get('password') - config = ElasticsearchConfig(id=id_count, name=name, host=host, port=port, service_type="retrieval", - retrieval_type="elasticsearch", - username=username, password=password, - detail_func_name="get_es_cluster_stats") + username: str = v.get("username") + password: str = v.get("password") + config = ElasticsearchConfig( + id=id_count, + name=name, + host=host, + port=port, + service_type="retrieval", + retrieval_type="elasticsearch", + username=username, + password=password, + detail_func_name="get_es_cluster_stats", + ) configurations.append(config) id_count += 1 case "infinity": - name: str = 'infinity' - url = v['uri'] - parts = url.split(':', 1) + name: str = "infinity" + url = v["uri"] + parts = url.split(":", 1) host = parts[0] port = int(parts[1]) - database: str = v.get('db_name', 'default_db') - config = InfinityConfig(id=id_count, name=name, host=host, port=port, service_type="retrieval", - retrieval_type="infinity", - db_name=database, detail_func_name="get_infinity_status") + database: str = v.get("db_name", "default_db") + config = InfinityConfig(id=id_count, name=name, host=host, port=port, service_type="retrieval", retrieval_type="infinity", db_name=database, detail_func_name="get_infinity_status") configurations.append(config) id_count += 1 case "minio_0": - name: str = 'minio_0' - url = v['host'] - parts = url.split(':', 1) + name: str = "minio_0" + url = v["host"] + parts = url.split(":", 1) host = parts[0] port = int(parts[1]) - user = v.get('user') - password = v.get('password') - config = MinioConfig(id=id_count, name=name, host=host, port=port, user=user, password=password, - service_type="file_store", - store_type="minio", detail_func_name="check_minio_alive") + user = v.get("user") + password = v.get("password") + config = MinioConfig(id=id_count, name=name, host=host, port=port, user=user, password=password, service_type="file_store", store_type="minio", detail_func_name="check_minio_alive") configurations.append(config) id_count += 1 case "minio": - name: str = 'minio' - url = v['host'] - parts = url.split(':', 1) + name: str = "minio" + url = v["host"] + parts = url.split(":", 1) host = parts[0] port = int(parts[1]) - user = v.get('user') - password = v.get('password') - config = MinioConfig(id=id_count, name=name, host=host, port=port, user=user, password=password, - service_type="file_store", - store_type="minio", detail_func_name="check_minio_alive") + user = v.get("user") + password = v.get("password") + config = MinioConfig(id=id_count, name=name, host=host, port=port, user=user, password=password, service_type="file_store", store_type="minio", detail_func_name="check_minio_alive") configurations.append(config) id_count += 1 case "redis": - name: str = 'redis' - url = v['host'] - parts = url.split(':', 1) + name: str = "redis" + url = v["host"] + parts = url.split(":", 1) host = parts[0] port = int(parts[1]) - password = v.get('password') - db: int = v.get('db') - config = RedisConfig(id=id_count, name=name, host=host, port=port, password=password, database=db, - service_type="message_queue", mq_type="redis", detail_func_name="get_redis_info") + password = v.get("password") + db: int = v.get("db") + config = RedisConfig(id=id_count, name=name, host=host, port=port, password=password, database=db, service_type="message_queue", mq_type="redis", detail_func_name="get_redis_info") configurations.append(config) id_count += 1 case "mysql": - name: str = 'mysql' - host: str = v.get('host') - port: int = v.get('port') - username = v.get('user') - password = v.get('password') - config = MySQLConfig(id=id_count, name=name, host=host, port=port, username=username, password=password, - service_type="meta_data", meta_type="mysql", detail_func_name="get_mysql_status") + name: str = "mysql" + host: str = v.get("host") + port: int = v.get("port") + username = v.get("user") + password = v.get("password") + config = MySQLConfig( + id=id_count, name=name, host=host, port=port, username=username, password=password, service_type="meta_data", meta_type="mysql", detail_func_name="get_mysql_status" + ) configurations.append(config) id_count += 1 case "admin": pass case "task_executor": - name: str = 'task_executor' - host: str = v.get('host', '') - port: int = v.get('port', 0) - message_queue_type: str = v.get('message_queue_type') - config = TaskExecutorConfig(id=id_count, name=name, host=host, port=port, message_queue_type=message_queue_type, - service_type="task_executor", detail_func_name="check_task_executor_alive") + name: str = "task_executor" + host: str = v.get("host", "") + port: int = v.get("port", 0) + message_queue_type: str = v.get("message_queue_type") + config = TaskExecutorConfig( + id=id_count, name=name, host=host, port=port, message_queue_type=message_queue_type, service_type="task_executor", detail_func_name="check_task_executor_alive" + ) configurations.append(config) id_count += 1 case "rabbitmq": - name: str = 'rabbitmq' - host: str = v.get('host') - port: int = v.get('port') - config = RabbitMQConfig(id=id_count, name=name, host=host, port=port, - service_type="message_queue", mq_type="rabbitmq", detail_func_name="check_rabbitmq_alive") + name: str = "rabbitmq" + host: str = v.get("host") + port: int = v.get("port") + config = RabbitMQConfig(id=id_count, name=name, host=host, port=port, service_type="message_queue", mq_type="rabbitmq", detail_func_name="check_rabbitmq_alive") configurations.append(config) id_count += 1 case _: diff --git a/admin/server/exceptions.py b/admin/server/exceptions.py index 5e3021b418..81fd800114 100644 --- a/admin/server/exceptions.py +++ b/admin/server/exceptions.py @@ -4,14 +4,17 @@ class AdminException(Exception): self.code = code self.message = message + class UserNotFoundError(AdminException): def __init__(self, username): super().__init__(f"User '{username}' not found", 404) + class UserAlreadyExistsError(AdminException): def __init__(self, username): super().__init__(f"User '{username}' already exists", 409) + class CannotDeleteAdminError(AdminException): def __init__(self): - super().__init__("Cannot delete admin account", 403) \ No newline at end of file + super().__init__("Cannot delete admin account", 403) diff --git a/admin/server/responses.py b/admin/server/responses.py index c41c4512eb..1264ed0a0c 100644 --- a/admin/server/responses.py +++ b/admin/server/responses.py @@ -17,16 +17,8 @@ from flask import jsonify def success_response(data=None, message="Success", code=0): - return jsonify({ - "code": code, - "message": message, - "data": data - }), 200 + return jsonify({"code": code, "message": message, "data": data}), 200 def error_response(message="Error", code=-1, data=None): - return jsonify({ - "code": code, - "message": message, - "data": data - }), 400 + return jsonify({"code": code, "message": message, "data": data}), 400 diff --git a/admin/server/services.py b/admin/server/services.py index 341dfafdaf..754d690b50 100644 --- a/admin/server/services.py +++ b/admin/server/services.py @@ -489,10 +489,7 @@ class SandboxMgr: """List all available sandbox providers.""" result = [] for provider_id, metadata in SandboxMgr.PROVIDER_REGISTRY.items(): - result.append({ - "id": provider_id, - **metadata - }) + result.append({"id": provider_id, **metadata}) return result @staticmethod @@ -635,6 +632,7 @@ class SandboxMgr: config_json = json.dumps(config) SettingsMgr.update_by_name(f"sandbox.{provider_type}", config_json) from agent.sandbox.client import reload_provider + reload_provider() return {"provider_type": provider_type, "config": config} @@ -727,11 +725,7 @@ def main() -> dict: # Build detailed result message success = execution_result.exit_code == 0 and "TEST_PASSED" in execution_result.stdout - message_parts = [ - f"Test {success and 'PASSED' or 'FAILED'}", - f"Exit code: {execution_result.exit_code}", - f"Execution time: {execution_result.execution_time:.2f}s" - ] + message_parts = [f"Test {success and 'PASSED' or 'FAILED'}", f"Exit code: {execution_result.exit_code}", f"Execution time: {execution_result.execution_time:.2f}s"] if execution_result.stdout.strip(): stdout_preview = execution_result.stdout.strip()[:200] @@ -751,12 +745,13 @@ def main() -> dict: "execution_time": execution_result.execution_time, "stdout": execution_result.stdout, "stderr": execution_result.stderr, - } + }, } except AdminException: raise except Exception as e: import traceback + error_details = traceback.format_exc() raise AdminException(f"Connection test failed: {str(e)}\\n\\nStack trace:\\n{error_details}") diff --git a/agent/canvas.py b/agent/canvas.py index 572a85e09d..2aa7cc40f1 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -161,6 +161,7 @@ class Graph: def close(self): from common.mcp_tool_call_conn import MCPToolCallSession + seen = set() for cpn in self.components.values(): obj = cpn.get("obj") diff --git a/agent/component/__init__.py b/agent/component/__init__.py index d4a481518b..4588df7c8d 100644 --- a/agent/component/__init__.py +++ b/agent/component/__init__.py @@ -22,8 +22,9 @@ from typing import Dict, Type _package_path = os.path.dirname(__file__) __all_classes: Dict[str, Type] = {} + def _import_submodules() -> None: - for filename in os.listdir(_package_path): # noqa: F821 + for filename in os.listdir(_package_path): # noqa: F821 if filename.startswith("__") or not filename.endswith(".py") or filename.startswith("base"): continue module_name = filename[:-3] @@ -34,13 +35,14 @@ def _import_submodules() -> None: except ImportError as e: print(f"Warning: Failed to import module {module_name}: {str(e)}") + def _extract_classes_from_module(module: ModuleType) -> None: for name, obj in inspect.getmembers(module): - if (inspect.isclass(obj) and - obj.__module__ == module.__name__ and not name.startswith("_")): + if inspect.isclass(obj) and obj.__module__ == module.__name__ and not name.startswith("_"): __all_classes[name] = obj globals()[name] = obj + _import_submodules() __all__ = list(__all_classes.keys()) + ["__all_classes"] diff --git a/agent/component/base.py b/agent/component/base.py index a91bec70ef..afb4e09e86 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -28,7 +28,6 @@ from agent import settings from common.connection_utils import timeout - from common.misc_utils import thread_pool_exec _logger = logging.getLogger(__name__) @@ -101,6 +100,7 @@ class ComponentParamBase(ABC): return None logging.warning("ComponentParamBase.__str__: JSON fallback via str() for type=%s", type(obj).__name__) return str(obj) + return json.dumps(self.as_dict(), ensure_ascii=False, default=_serialize_default) def as_dict(self): @@ -135,15 +135,11 @@ class ComponentParamBase(ABC): update_from_raw_conf = conf.get(_IS_RAW_CONF, True) if update_from_raw_conf: deprecated_params_set = self._get_or_init_deprecated_params_set() - feeded_deprecated_params_set = ( - self._get_or_init_feeded_deprecated_params_set() - ) + feeded_deprecated_params_set = self._get_or_init_feeded_deprecated_params_set() user_feeded_params_set = self._get_or_init_user_feeded_params_set() setattr(self, _IS_RAW_CONF, False) else: - feeded_deprecated_params_set = ( - self._get_or_init_feeded_deprecated_params_set(conf) - ) + feeded_deprecated_params_set = self._get_or_init_feeded_deprecated_params_set(conf) user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf) def _recursive_update_param(param, config, depth, prefix): @@ -179,15 +175,11 @@ class ComponentParamBase(ABC): else: # recursive set obj attr - sub_params = _recursive_update_param( - attr, config_value, depth + 1, prefix=f"{prefix}{config_key}." - ) + sub_params = _recursive_update_param(attr, config_value, depth + 1, prefix=f"{prefix}{config_key}.") setattr(param, config_key, sub_params) if not allow_redundant and redundant_attrs: - raise ValueError( - f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`" - ) + raise ValueError(f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`") return param @@ -218,9 +210,7 @@ class ComponentParamBase(ABC): param_validation_path_prefix = home_dir + "/param_validation/" param_name = type(self).__name__ - param_validation_path = "/".join( - [param_validation_path_prefix, param_name + ".json"] - ) + param_validation_path = "/".join([param_validation_path_prefix, param_name + ".json"]) validation_json = None @@ -253,11 +243,7 @@ class ComponentParamBase(ABC): break if not value_legal: - raise ValueError( - "Please check runtime conf, {} = {} does not match user-parameter restriction".format( - variable, value - ) - ) + raise ValueError("Please check runtime conf, {} = {} does not match user-parameter restriction".format(variable, value)) elif variable in validation_json: self._validate_param(attr, validation_json) @@ -335,11 +321,7 @@ class ComponentParamBase(ABC): def _range(value, ranges): in_range = False for left_limit, right_limit in ranges: - if ( - left_limit - settings.FLOAT_ZERO - <= value - <= right_limit + settings.FLOAT_ZERO - ): + if left_limit - settings.FLOAT_ZERO <= value <= right_limit + settings.FLOAT_ZERO: in_range = True break @@ -355,16 +337,11 @@ class ComponentParamBase(ABC): def _warn_deprecated_param(self, param_name, description): if self._deprecated_params_set.get(param_name): - logging.warning( - f"{description} {param_name} is deprecated and ignored in this version." - ) + logging.warning(f"{description} {param_name} is deprecated and ignored in this version.") def _warn_to_deprecate_param(self, param_name, description, new_param): if self._deprecated_params_set.get(param_name): - logging.warning( - f"{description} {param_name} will be deprecated in future release; " - f"please use {new_param} instead." - ) + logging.warning(f"{description} {param_name} will be deprecated in future release; please use {new_param} instead.") return True return False @@ -385,9 +362,7 @@ class ComponentBase(ABC): return """{{ "component_name": "{}", "params": {} - }}""".format(self.component_name, - self._param - ) + }}""".format(self.component_name, self._param) def __init__(self, canvas, id, param: ComponentParamBase): from agent.canvas import Graph # Local import to avoid cyclic dependency @@ -403,7 +378,7 @@ class ComponentBase(ABC): def check_if_canceled(self, message: str = "") -> bool: if self.is_canceled(): - task_id = getattr(self._canvas, 'task_id', 'unknown') + task_id = getattr(self._canvas, "task_id", "unknown") log_message = f"Task {task_id} has been canceled" if message: log_message += f" during {message}" @@ -491,7 +466,9 @@ class ComponentBase(ABC): input_elements = self.get_input_elements() _logger.debug( "[Base] Component '%s' (%s) resolving inputs. Input element keys: %s", - self._id, self.component_name, list(input_elements.keys()), + self._id, + self.component_name, + list(input_elements.keys()), ) for var, o in input_elements.items(): v = self.get_param(var) @@ -504,7 +481,7 @@ class ComponentBase(ABC): _logger.debug("[Base] var '%s': resolved ref '%s' -> %s", var, v, json.dumps(resolved, ensure_ascii=False, default=str)[:200]) elif isinstance(v, str) and re.search(self.variable_ref_patt, v): elements = self.get_input_elements_from_text(v) - kv = {k: e.get('value', '') for k, e in elements.items()} + kv = {k: e.get("value", "") for k, e in elements.items()} self.set_input_value(var, self.string_format(v, kv)) _logger.debug("[Base] var '%s': resolved text refs '%s' -> %s", var, v, json.dumps(kv, ensure_ascii=False, default=str)[:200]) else: @@ -545,7 +522,7 @@ class ComponentBase(ABC): "name": (self._canvas.get_component_name(cpn_id) + f"@{var_nm}") if cpn_id else exp, "value": self._canvas.get_variable_value(exp), "_retrieval": self._canvas.get_variable_value(f"{cpn_id}@_references") if cpn_id else None, - "_cpn_id": cpn_id + "_cpn_id": cpn_id, } for r in re.finditer(self.iteration_alias_patt, txt, flags=re.IGNORECASE | re.DOTALL): exp = r.group(1) @@ -559,7 +536,7 @@ class ComponentBase(ABC): "name": (self._canvas.get_component_name(cpn_id) + f"@{var_nm}"), "value": self._canvas.get_variable_value(ref), "_retrieval": self._canvas.get_variable_value(f"{cpn_id}@_references"), - "_cpn_id": cpn_id + "_cpn_id": cpn_id, } return res @@ -601,33 +578,27 @@ class ComponentBase(ABC): return self._canvas.get_component(pid)["obj"] def get_upstream(self) -> List[str]: - cpn_nms = self._canvas.get_component(self._id)['upstream'] + cpn_nms = self._canvas.get_component(self._id)["upstream"] return cpn_nms def get_downstream(self) -> List[str]: - cpn_nms = self._canvas.get_component(self._id)['downstream'] + cpn_nms = self._canvas.get_component(self._id)["downstream"] return cpn_nms @staticmethod def string_format(content: str, kv: dict[str, str]) -> str: for n, v in kv.items(): + def repl(_match, val=v): return str(val) if val is not None else "" - content = re.sub( - r"\{%s\}" % re.escape(n), - repl, - content - ) + content = re.sub(r"\{%s\}" % re.escape(n), repl, content) return content def exception_handler(self): if not self._param.exception_method: return None - return { - "goto": self._param.exception_goto, - "default_value": self._param.exception_default_value - } + return {"goto": self._param.exception_goto, "default_value": self._param.exception_default_value} def get_exception_default_value(self): if self._param.exception_method != "comment": diff --git a/agent/component/begin.py b/agent/component/begin.py index 7945d64251..e93f5fe302 100644 --- a/agent/component/begin.py +++ b/agent/component/begin.py @@ -17,17 +17,17 @@ from agent.component.fillup import UserFillUpParam, UserFillUp class BeginParam(UserFillUpParam): - """ Define the Begin component parameters. """ + def __init__(self): super().__init__() self.mode = "conversational" self.prologue = "Hi! I'm your smart assistant. What can I do for you?" def check(self): - self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task","Webhook"]) + self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task", "Webhook"]) def get_input_form(self) -> dict[str, dict]: return getattr(self, "inputs") diff --git a/agent/component/browser.py b/agent/component/browser.py index c7f77b1577..19a8a91380 100644 --- a/agent/component/browser.py +++ b/agent/component/browser.py @@ -371,9 +371,7 @@ class Browser(ComponentBase, ABC): sort_keys=True, ensure_ascii=False, ) - raw_canvas_id = ( - f"dsl_{hashlib.sha1(graph_text.encode('utf-8')).hexdigest()[:12]}" - ) + raw_canvas_id = f"dsl_{hashlib.sha1(graph_text.encode('utf-8')).hexdigest()[:12]}" canvas_id = self._safe_path_segment(raw_canvas_id) node_id = self._safe_path_segment(self._id) return os.path.join(root, tenant, canvas_id, node_id) @@ -488,10 +486,7 @@ class Browser(ComponentBase, ABC): # Keep browser-use watchdog fallback in sync with our resolved path. os.environ["BROWSER_USE_BROWSER_BINARY_PATH"] = executable_path else: - logging.warning( - "Browser no local browser executable found. " - "Set BROWSER_USE_EXECUTABLE_PATH or preinstall chromium in image to avoid runtime playwright install." - ) + logging.warning("Browser no local browser executable found. Set BROWSER_USE_EXECUTABLE_PATH or preinstall chromium in image to avoid runtime playwright install.") if profile_dir: browser_kwargs["user_data_dir"] = profile_dir # browser-use expects profile_directory to be a profile name @@ -682,21 +677,13 @@ class Browser(ComponentBase, ABC): try: self._prepare_input_values() user_prompt = self._resolve_text(kwargs.get("prompts", self._param.prompts)) - with tempfile.TemporaryDirectory(prefix="browser_use_upload_") as upload_dir, tempfile.TemporaryDirectory( - prefix="browser_use_download_" - ) as download_dir: + with tempfile.TemporaryDirectory(prefix="browser_use_upload_") as upload_dir, tempfile.TemporaryDirectory(prefix="browser_use_download_") as download_dir: uploaded_files = self._prepare_upload_files(upload_dir) - upload_lines = [ - f"- file_id={item['file_id']}, name={item['name']}, local_path={item['local_path']}" - for item in uploaded_files - ] + upload_lines = [f"- file_id={item['file_id']}, name={item['name']}, local_path={item['local_path']}" for item in uploaded_files] task_text = user_prompt if upload_lines: - task_text += ( - "\n\nYou can upload files from these local paths when operating web pages:\n" - + "\n".join(upload_lines) - ) + task_text += "\n\nYou can upload files from these local paths when operating web pages:\n" + "\n".join(upload_lines) upload_local_paths = [item.get("local_path", "") for item in uploaded_files if item.get("local_path")] if persist_session: @@ -707,11 +694,7 @@ class Browser(ComponentBase, ABC): profile_dir = tempfile.mkdtemp(prefix="browser_use_profile_") except OSError: profile_dir = None - history = asyncio.run( - self._run_browser_use_async( - task_text, download_dir, upload_local_paths, profile_dir - ) - ) + history = asyncio.run(self._run_browser_use_async(task_text, download_dir, upload_local_paths, profile_dir)) target_dir_id = FileService.get_root_folder(self._canvas.get_tenant_id())["id"] downloaded_files = self._save_downloads(download_dir, target_dir_id) diff --git a/agent/component/categorize.py b/agent/component/categorize.py index 4b5c39631c..7ad5315fdf 100644 --- a/agent/component/categorize.py +++ b/agent/component/categorize.py @@ -28,10 +28,10 @@ from rag.llm.chat_model import ERROR_PREFIX class CategorizeParam(LLMParam): - """ Define the categorize component parameters. """ + def __init__(self): super().__init__() self.category_description = {} @@ -50,12 +50,7 @@ class CategorizeParam(LLMParam): raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!") def get_input_form(self) -> dict[str, dict]: - return { - "query": { - "type": "line", - "name": "Query" - } - } + return {"query": {"type": "line", "name": "Query"}} def update_prompt(self): cate_lines = [] @@ -63,13 +58,12 @@ class CategorizeParam(LLMParam): for line in desc.get("examples", []): if not line: continue - cate_lines.append("USER: \"" + re.sub(r"\n", " ", line, flags=re.DOTALL) + "\" → "+c) + cate_lines.append('USER: "' + re.sub(r"\n", " ", line, flags=re.DOTALL) + '" → ' + c) descriptions = [] for c, desc in self.category_description.items(): if desc.get("description"): - descriptions.append( - "\n------\nCategory: {}\nDescription: {}".format(c, desc["description"])) + descriptions.append("\n------\nCategory: {}\nDescription: {}".format(c, desc["description"])) self.sys_prompt = """ You are an advanced classification system that categorizes user questions into specific types. Analyze the input question and classify it into ONE of the following categories: @@ -84,10 +78,7 @@ Here's description of each category: - Return only the category name without explanations - Use "Other" only when no other category fits - """.format( - "\n - ".join(list(self.category_description.keys())), - "\n".join(descriptions) - ) + """.format("\n - ".join(list(self.category_description.keys())), "\n".join(descriptions)) if cate_lines: self.sys_prompt += """ @@ -106,7 +97,7 @@ class Categorize(LLM, ABC): logging.warning(f"[Categorize] input element not detected for query key: {query_key}") return elements - @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))) async def _invoke_async(self, **kwargs): if self.check_if_canceled("Categorize processing"): return @@ -130,7 +121,7 @@ class Categorize(LLM, ABC): user_prompt = """ ---- Real Data ---- {} → -""".format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg])) +""".format(" | ".join(['{}: "{}"'.format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg])) if self.check_if_canceled("Categorize processing"): return @@ -158,7 +149,7 @@ class Categorize(LLM, ABC): self.set_output("category_name", max_category) self.set_output("_next", cpn_ids) - @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))) def _invoke(self, **kwargs): return asyncio.run(self._invoke_async(**kwargs)) diff --git a/agent/component/data_operations.py b/agent/component/data_operations.py index 9cf5c55335..b7884757f2 100644 --- a/agent/component/data_operations.py +++ b/agent/component/data_operations.py @@ -19,55 +19,47 @@ import os from agent.component.base import ComponentBase, ComponentParamBase from api.utils.api_utils import timeout + class DataOperationsParam(ComponentParamBase): """ Define the Data Operations component parameters. """ + def __init__(self): super().__init__() self.query = [] self.operations = "literal_eval" self.select_keys = [] - self.filter_values=[] - self.updates=[] - self.remove_keys=[] - self.rename_keys=[] - self.outputs = { - "result": { - "value": [], - "type": "Array of Object" - } - } - - def check(self): - self.check_valid_value(self.operations, "Support operations", ["select_keys", "literal_eval","combine","filter_values","append_or_update","remove_keys","rename_keys"]) - - + self.filter_values = [] + self.updates = [] + self.remove_keys = [] + self.rename_keys = [] + self.outputs = {"result": {"value": [], "type": "Array of Object"}} -class DataOperations(ComponentBase,ABC): + def check(self): + self.check_valid_value(self.operations, "Support operations", ["select_keys", "literal_eval", "combine", "filter_values", "append_or_update", "remove_keys", "rename_keys"]) + + +class DataOperations(ComponentBase, ABC): component_name = "DataOperations" def get_input_form(self) -> dict[str, dict]: - return { - k: {"name": o.get("name", ""), "type": "line"} - for input_item in (self._param.query or []) - for k, o in self.get_input_elements_from_text(input_item).items() - } + return {k: {"name": o.get("name", ""), "type": "line"} for input_item in (self._param.query or []) for k, o in self.get_input_elements_from_text(input_item).items()} - @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))) def _invoke(self, **kwargs): - self.input_objects=[] + self.input_objects = [] inputs = getattr(self._param, "query", None) if not isinstance(inputs, (list, tuple)): inputs = [inputs] for input_ref in inputs: - input_object=self._canvas.get_variable_value(input_ref) + input_object = self._canvas.get_variable_value(input_ref) self.set_input_value(input_ref, input_object) if input_object is None: continue - if isinstance(input_object,dict): + if isinstance(input_object, dict): self.input_objects.append(input_object) - elif isinstance(input_object,list): + elif isinstance(input_object, list): self.input_objects.extend(x for x in input_object if isinstance(x, dict)) else: continue @@ -85,13 +77,12 @@ class DataOperations(ComponentBase,ABC): self._remove_keys() else: self._rename_keys() - + def _select_keys(self): filter_criteria: list[str] = self._param.select_keys results = [{key: value for key, value in data_dict.items() if key in filter_criteria} for data_dict in self.input_objects] self.set_output("result", results) - def _recursive_eval(self, data): if isinstance(data, dict): return {k: self._recursive_eval(v) for k, v in data.items()} @@ -99,23 +90,19 @@ class DataOperations(ComponentBase,ABC): return [self._recursive_eval(item) for item in data] if isinstance(data, str): try: - if ( - data.strip().startswith(("{", "[", "(", "'", '"')) - or data.strip().lower() in ("true", "false", "none") - or data.strip().replace(".", "").isdigit() - ): + if data.strip().startswith(("{", "[", "(", "'", '"')) or data.strip().lower() in ("true", "false", "none") or data.strip().replace(".", "").isdigit(): return ast.literal_eval(data) except (ValueError, SyntaxError, TypeError, MemoryError): return data else: return data return data - + def _literal_eval(self): self.set_output("result", self._recursive_eval(self.input_objects)) def _combine(self): - result={} + result = {} for obj in self.input_objects: for key, value in obj.items(): if key not in result: @@ -126,15 +113,13 @@ class DataOperations(ComponentBase,ABC): else: result[key].append(value) else: - result[key] = ( - [result[key], value] if not isinstance(value, list) else [result[key], *value] - ) + result[key] = [result[key], value] if not isinstance(value, list) else [result[key], *value] self.set_output("result", result) - - def norm(self,v): + + def norm(self, v): s = "" if v is None else str(v) return s - + def match_rule(self, obj, rule): key = rule.get("key") op = (rule.get("operator") or "equals").lower() @@ -155,10 +140,10 @@ class DataOperations(ComponentBase,ABC): if op == "end with": return v.endswith(target) return False - + def _filter_values(self): - results=[] - rules = (getattr(self._param, "filter_values", None) or []) + results = [] + rules = getattr(self._param, "filter_values", None) or [] for obj in self.input_objects: if not rules: results.append(obj) @@ -166,11 +151,10 @@ class DataOperations(ComponentBase,ABC): if all(self.match_rule(obj, r) for r in rules): results.append(obj) self.set_output("result", results) - - + def _append_or_update(self): - results=[] - updates = getattr(self._param, "updates", []) or [] + results = [] + updates = getattr(self._param, "updates", []) or [] for obj in self.input_objects: new_obj = dict(obj) for item in updates: @@ -187,7 +171,7 @@ class DataOperations(ComponentBase,ABC): results = [] remove_keys = getattr(self._param, "remove_keys", []) or [] - for obj in (self.input_objects or []): + for obj in self.input_objects or []: new_obj = dict(obj) for k in remove_keys: if not isinstance(k, str): @@ -200,7 +184,7 @@ class DataOperations(ComponentBase,ABC): results = [] rename_pairs = getattr(self._param, "rename_keys", []) or [] - for obj in (self.input_objects or []): + for obj in self.input_objects or []: new_obj = dict(obj) for pair in rename_pairs: if not isinstance(pair, dict): diff --git a/agent/component/docs_generator.py b/agent/component/docs_generator.py index 2809a9b1ca..f73ff441aa 100644 --- a/agent/component/docs_generator.py +++ b/agent/component/docs_generator.py @@ -163,6 +163,7 @@ class DocGenerator(Message, ABC): logging.info("Starting document generation, content length: %s chars", len(content)) if content: + def _replace_variable(match_obj: re.Match[str]) -> str: match = match_obj.group(1) try: diff --git a/agent/component/excel_processor.py b/agent/component/excel_processor.py index 65b3a9bd20..df1bfaaf2c 100644 --- a/agent/component/excel_processor.py +++ b/agent/component/excel_processor.py @@ -39,71 +39,52 @@ class ExcelProcessorParam(ComponentParamBase): """ Define the ExcelProcessor component parameters. """ + def __init__(self): super().__init__() # Input configuration self.input_files = [] # Variable references to uploaded files self.operation = "read" # read, merge, transform, output - + # Processing options self.sheet_selection = "all" # all, first, or comma-separated sheet names self.merge_strategy = "concat" # concat, join self.join_on = "" # Column name for join operations - + # Transform options (for LLM-guided transformations) self.transform_instructions = "" self.transform_data = "" # Variable reference to transformation data - + # Output options self.output_format = "xlsx" # xlsx, csv self.output_filename = "output" - + # Component outputs - self.outputs = { - "data": { - "type": "object", - "value": {} - }, - "summary": { - "type": "str", - "value": "" - }, - "markdown": { - "type": "str", - "value": "" - } - } - + self.outputs = {"data": {"type": "object", "value": {}}, "summary": {"type": "str", "value": ""}, "markdown": {"type": "str", "value": ""}} + def check(self): - self.check_valid_value( - self.operation, - "[ExcelProcessor] Operation", - ["read", "merge", "transform", "output"] - ) - self.check_valid_value( - self.output_format, - "[ExcelProcessor] Output format", - ["xlsx", "csv"] - ) + self.check_valid_value(self.operation, "[ExcelProcessor] Operation", ["read", "merge", "transform", "output"]) + self.check_valid_value(self.output_format, "[ExcelProcessor] Output format", ["xlsx", "csv"]) return True class ExcelProcessor(ComponentBase, ABC): """ Excel processing component for RAGFlow agents. - + Operations: - read: Parse Excel files into structured data - merge: Combine multiple Excel files - transform: Apply data transformations based on instructions - output: Generate Excel file output """ + component_name = "ExcelProcessor" def get_input_form(self) -> dict[str, dict]: """Define input form for the component.""" res = {} - for ref in (self._param.input_files or []): + for ref in self._param.input_files or []: for k, o in self.get_input_elements_from_text(ref).items(): res[k] = {"name": o.get("name", ""), "type": "file"} if self._param.transform_data: @@ -111,13 +92,13 @@ class ExcelProcessor(ComponentBase, ABC): res[k] = {"name": o.get("name", ""), "type": "object"} return res - @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))) def _invoke(self, **kwargs): if self.check_if_canceled("ExcelProcessor processing"): return operation = self._param.operation.lower() - + if operation == "read": self._read_excels() elif operation == "merge": @@ -137,7 +118,7 @@ class ExcelProcessor(ComponentBase, ABC): value = self._canvas.get_variable_value(file_ref) if value is None: return None, None - + # Handle different value formats if isinstance(value, dict): # File reference from Begin/UserFillUp component @@ -154,12 +135,13 @@ class ExcelProcessor(ComponentBase, ABC): # Could be base64 encoded or a path if value.startswith("data:"): import base64 + # Extract base64 content _, encoded = value.split(",", 1) return base64.b64decode(encoded), "uploaded.xlsx" - + return None, None - + def _get_file_content_from_list(self, item) -> tuple[bytes, str]: """Extract file content from a list item.""" if isinstance(item, dict): @@ -170,15 +152,15 @@ class ExcelProcessor(ComponentBase, ABC): """Parse Excel content into a dictionary of DataFrames (one per sheet).""" try: excel_file = BytesIO(content) - + if filename.lower().endswith(".csv"): df = pd.read_csv(excel_file) return {"Sheet1": df} else: # Read all sheets - xlsx = pd.ExcelFile(excel_file, engine='openpyxl') + xlsx = pd.ExcelFile(excel_file, engine="openpyxl") sheet_selection = self._param.sheet_selection - + if sheet_selection == "all": sheets_to_read = xlsx.sheet_names elif sheet_selection == "first": @@ -187,12 +169,12 @@ class ExcelProcessor(ComponentBase, ABC): # Comma-separated sheet names requested = [s.strip() for s in sheet_selection.split(",")] sheets_to_read = [s for s in requested if s in xlsx.sheet_names] - + dfs = {} for sheet in sheets_to_read: dfs[sheet] = pd.read_excel(xlsx, sheet_name=sheet) return dfs - + except Exception as e: logging.error(f"Error parsing Excel file {filename}: {e}") return {} @@ -202,36 +184,36 @@ class ExcelProcessor(ComponentBase, ABC): all_data = {} summaries = [] markdown_parts = [] - - for file_ref in (self._param.input_files or []): + + for file_ref in self._param.input_files or []: if self.check_if_canceled("ExcelProcessor reading"): return - + # Get variable value value = self._canvas.get_variable_value(file_ref) self.set_input_value(file_ref, str(value)[:200] if value else "") - + if value is None: continue - + # Handle file content content, filename = self._get_file_content(file_ref) if content is None: continue - + # Parse Excel dfs = self._parse_excel_to_dataframes(content, filename) - + for sheet_name, df in dfs.items(): key = f"{filename}_{sheet_name}" if len(dfs) > 1 else filename all_data[key] = df.to_dict(orient="records") - + # Build summary summaries.append(f"**{key}**: {len(df)} rows, {len(df.columns)} columns ({', '.join(df.columns.tolist()[:5])}{'...' if len(df.columns) > 5 else ''})") - + # Build markdown table markdown_parts.append(f"### {key}\n\n{df.head(10).to_markdown(index=False)}\n") - + # Set outputs self.set_output("data", all_data) self.set_output("summary", "\n".join(summaries) if summaries else "No Excel files found") @@ -240,29 +222,29 @@ class ExcelProcessor(ComponentBase, ABC): def _merge_excels(self): """Merge multiple Excel files/sheets into one.""" all_dfs = [] - - for file_ref in (self._param.input_files or []): + + for file_ref in self._param.input_files or []: if self.check_if_canceled("ExcelProcessor merging"): return - + value = self._canvas.get_variable_value(file_ref) self.set_input_value(file_ref, str(value)[:200] if value else "") - + if value is None: continue - + content, filename = self._get_file_content(file_ref) if content is None: continue - + dfs = self._parse_excel_to_dataframes(content, filename) all_dfs.extend(dfs.values()) - + if not all_dfs: self.set_output("data", {}) self.set_output("summary", "No data to merge") return - + # Merge strategy if self._param.merge_strategy == "concat": merged_df = pd.concat(all_dfs, ignore_index=True) @@ -273,7 +255,7 @@ class ExcelProcessor(ComponentBase, ABC): merged_df = merged_df.merge(df, on=self._param.join_on, how="outer") else: merged_df = pd.concat(all_dfs, ignore_index=True) - + self.set_output("data", {"merged": merged_df.to_dict(orient="records")}) self.set_output("summary", f"Merged {len(all_dfs)} sources into {len(merged_df)} rows, {len(merged_df.columns)} columns") self.set_output("markdown", merged_df.head(20).to_markdown(index=False)) @@ -285,14 +267,14 @@ class ExcelProcessor(ComponentBase, ABC): if not transform_ref: self.set_output("summary", "No transform data reference provided") return - + data = self._canvas.get_variable_value(transform_ref) self.set_input_value(transform_ref, str(data)[:300] if data else "") - + if data is None: self.set_output("summary", "Transform data is empty") return - + # Convert to DataFrame if isinstance(data, dict): # Could be {"sheet": [rows]} format @@ -315,7 +297,7 @@ class ExcelProcessor(ComponentBase, ABC): else: self.set_output("data", {"raw": str(data)}) self.set_output("markdown", str(data)) - + self.set_output("summary", "Transformed data ready for processing") def _output_excel(self): @@ -325,14 +307,14 @@ class ExcelProcessor(ComponentBase, ABC): if not transform_ref: self.set_output("summary", "No data reference for output") return - + data = self._canvas.get_variable_value(transform_ref) self.set_input_value(transform_ref, str(data)[:300] if data else "") - + if data is None: self.set_output("summary", "No data to output") return - + try: # Prepare DataFrames if isinstance(data, dict): @@ -346,10 +328,10 @@ class ExcelProcessor(ComponentBase, ABC): else: self.set_output("summary", "Invalid data format for Excel output") return - + # Generate output doc_id = get_uuid() - + if self._param.output_format == "csv": # For CSV, only output first sheet first_df = list(dfs.values())[0] @@ -358,7 +340,7 @@ class ExcelProcessor(ComponentBase, ABC): else: # Excel output excel_io = BytesIO() - with pd.ExcelWriter(excel_io, engine='openpyxl') as writer: + with pd.ExcelWriter(excel_io, engine="openpyxl") as writer: for sheet_name, df in dfs.items(): # Sanitize sheet name (max 31 chars, no special chars) safe_name = sheet_name[:31].replace("/", "_").replace("\\", "_") @@ -366,23 +348,19 @@ class ExcelProcessor(ComponentBase, ABC): excel_io.seek(0) binary_content = excel_io.read() filename = f"{self._param.output_filename}.xlsx" - + # Store file settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content) - + # Set attachment output - self.set_output("attachment", { - "doc_id": doc_id, - "format": self._param.output_format, - "file_name": filename - }) - + self.set_output("attachment", {"doc_id": doc_id, "format": self._param.output_format, "file_name": filename}) + total_rows = sum(len(df) for df in dfs.values()) self.set_output("summary", f"Generated {filename} with {len(dfs)} sheet(s), {total_rows} total rows") self.set_output("data", {k: v.to_dict(orient="records") for k, v in dfs.items()}) - + logging.info(f"ExcelProcessor: Generated {filename} as {doc_id}") - + except Exception as e: logging.error(f"ExcelProcessor output error: {e}") self.set_output("summary", f"Error generating output: {str(e)}") diff --git a/agent/component/exit_loop.py b/agent/component/exit_loop.py index 9dc0449129..24abeca55f 100644 --- a/agent/component/exit_loop.py +++ b/agent/component/exit_loop.py @@ -29,4 +29,4 @@ class ExitLoop(ComponentBase, ABC): pass def thoughts(self) -> str: - return "" \ No newline at end of file + return "" diff --git a/agent/component/fillup.py b/agent/component/fillup.py index 0206727d2a..bb62657619 100644 --- a/agent/component/fillup.py +++ b/agent/component/fillup.py @@ -25,7 +25,6 @@ _INITIAL_USER_INPUT_CONSUMED_KEY = "sys.__initial_user_input_consumed__" class UserFillUpParam(ComponentParamBase): - def __init__(self): super().__init__() self.enable_tips = True @@ -55,11 +54,7 @@ class UserFillUp(ComponentBase): return {} if isinstance(query, dict): - matched = { - key: value if isinstance(value, dict) else {"value": value} - for key, value in query.items() - if key in fields - } + matched = {key: value if isinstance(value, dict) else {"value": value} for key, value in query.items() if key in fields} if matched: self._canvas.globals[_INITIAL_USER_INPUT_CONSUMED_KEY] = True return matched @@ -108,7 +103,7 @@ class UserFillUp(ComponentBase): ans = v if not ans: ans = "" - content = re.sub(r"\{%s\}"%k, ans, content) + content = re.sub(r"\{%s\}" % k, ans, content) self.set_output("tips", content) layout_recognize = self._param.layout_recognize or None diff --git a/agent/component/iteration.py b/agent/component/iteration.py index ae5c0b6772..3ccae2c0bb 100644 --- a/agent/component/iteration.py +++ b/agent/component/iteration.py @@ -24,6 +24,7 @@ class VariableModel(BaseModel): model_config = ConfigDict(extra="forbid") """ + class IterationParam(ComponentParamBase): """ Define the Iteration component parameters. @@ -32,15 +33,10 @@ class IterationParam(ComponentParamBase): def __init__(self): super().__init__() self.items_ref = "" - self.variable={} + self.variable = {} def get_input_form(self) -> dict[str, dict]: - return { - "items": { - "type": "json", - "name": "Items" - } - } + return {"items": {"type": "json", "name": "Items"}} def check(self): return True @@ -62,10 +58,7 @@ class Iteration(ComponentBase, ABC): arr = self._canvas.get_variable_value(self._param.items_ref) if not isinstance(arr, list): - self.set_output("_ERROR", self._param.items_ref + " must be an array, but its type is "+str(type(arr))) + self.set_output("_ERROR", self._param.items_ref + " must be an array, but its type is " + str(type(arr))) def thoughts(self) -> str: return "Need to process {} items.".format(len(self._canvas.get_variable_value(self._param.items_ref))) - - - diff --git a/agent/component/iterationitem.py b/agent/component/iterationitem.py index c9134e7c77..6abb6aa7b8 100644 --- a/agent/component/iterationitem.py +++ b/agent/component/iterationitem.py @@ -21,6 +21,7 @@ class IterationItemParam(ComponentParamBase): """ Define the IterationItem component parameters. """ + def check(self): return True @@ -40,7 +41,7 @@ class IterationItem(ComponentBase, ABC): arr = self._canvas.get_variable_value(parent._param.items_ref) if not isinstance(arr, list): self._idx = -1 - raise Exception(parent._param.items_ref + " must be an array, but its type is "+str(type(arr))) + raise Exception(parent._param.items_ref + " must be an array, but its type is " + str(type(arr))) if self._idx > 0: if self.check_if_canceled("IterationItem processing"): diff --git a/agent/component/list_operations.py b/agent/component/list_operations.py index 11d4a1e0a9..b81058e74d 100644 --- a/agent/component/list_operations.py +++ b/agent/component/list_operations.py @@ -3,10 +3,12 @@ import os from agent.component.base import ComponentBase, ComponentParamBase from api.utils.api_utils import timeout + class ListOperationsParam(ComponentParamBase): """ Define the List Operations component parameters. """ + def __init__(self): super().__init__() self.query = "" @@ -20,24 +22,8 @@ class ListOperationsParam(ComponentParamBase): # first field). Mirrors internal/agent/component/list_operations.go # parseSortByFieldList + opSort's SortBy path. self.sort_by = "" - self.filter = { - "operator": "=", - "value": "" - } - self.outputs = { - "result": { - "value": [], - "type": "Array of ?" - }, - "first": { - "value": "", - "type": "?" - }, - "last": { - "value": "", - "type": "?" - } - } + self.filter = {"operator": "=", "value": ""} + self.outputs = {"result": {"value": [], "type": "Array of ?"}, "first": {"value": "", "type": "?"}, "last": {"value": "", "type": "?"}} @staticmethod def _normalize_operation_name(operation): @@ -45,7 +31,7 @@ class ListOperationsParam(ComponentParamBase): if op.lower() == "topn": return "head" return op or "nth" - + def check(self): self.check_empty(self.query, "query") self.operations = self._normalize_operation_name(self.operations) @@ -57,14 +43,14 @@ class ListOperationsParam(ComponentParamBase): def get_input_form(self) -> dict[str, dict]: return {} - -class ListOperations(ComponentBase,ABC): + +class ListOperations(ComponentBase, ABC): component_name = "ListOperations" - @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))) def _invoke(self, **kwargs): - self.input_objects=[] + self.input_objects = [] inputs = getattr(self._param, "query", None) self.inputs = self._canvas.get_variable_value(inputs) if not isinstance(self.inputs, list): @@ -83,7 +69,6 @@ class ListOperations(ComponentBase,ABC): elif self._param.operations == "drop_duplicates": self._drop_duplicates() - def _coerce_n(self): try: return int(getattr(self._param, "n", 0)) @@ -99,12 +84,10 @@ class ListOperations(ComponentBase,ABC): def _set_outputs(self, outputs): self._param.outputs["result"]["value"] = outputs self._param.outputs["first"]["value"] = outputs[0] if outputs else None - self._param.outputs["last"]["value"] = outputs[-1] if outputs else None + self._param.outputs["last"]["value"] = outputs[-1] if outputs else None def _raise_strict_range_error(self, operation, n): - raise ValueError( - f"{operation} requires n to be within the valid range in strict mode, got {n}." - ) + raise ValueError(f"{operation} requires n to be within the valid range in strict mode, got {n}.") def _nth(self): n = self._coerce_n() @@ -160,9 +143,9 @@ class ListOperations(ComponentBase,ABC): self._set_outputs(outputs) def _filter(self): - self._set_outputs([i for i in self.inputs if self._eval(self._norm(i),self._param.filter["operator"],self._param.filter["value"])]) + self._set_outputs([i for i in self.inputs if self._eval(self._norm(i), self._param.filter["operator"], self._param.filter["value"])]) - def _norm(self,v): + def _norm(self, v): s = "" if v is None else str(v) return s @@ -222,7 +205,7 @@ class ListOperations(ComponentBase,ABC): outs.append(item) self._set_outputs(outs) - def _hashable(self,x): + def _hashable(self, x): if isinstance(x, dict): return tuple(sorted((k, self._hashable(v)) for k, v in x.items())) if isinstance(x, (list, tuple)): diff --git a/agent/component/loop.py b/agent/component/loop.py index 5337311915..93d126c67e 100644 --- a/agent/component/loop.py +++ b/agent/component/loop.py @@ -25,16 +25,11 @@ class LoopParam(ComponentParamBase): def __init__(self): super().__init__() self.loop_variables = [] - self.loop_termination_condition=[] + self.loop_termination_condition = [] self.maximum_loop_count = 0 def get_input_form(self) -> dict[str, dict]: - return { - "items": { - "type": "json", - "name": "Items" - } - } + return {"items": {"type": "json", "name": "Items"}} def check(self): return True @@ -83,10 +78,10 @@ class Loop(ComponentBase, ABC): for item in self._param.loop_variables: if self._is_incomplete_loop_variable(item): raise ValueError("Loop Variable is not complete.") - if item["input_mode"]=="variable": - self.set_output(item["variable"],self._canvas.get_variable_value(item["value"])) - elif item["input_mode"]=="constant": - self.set_output(item["variable"],item["value"]) + if item["input_mode"] == "variable": + self.set_output(item["variable"], self._canvas.get_variable_value(item["value"])) + elif item["input_mode"] == "constant": + self.set_output(item["variable"], item["value"]) else: if item["type"] == "number": self.set_output(item["variable"], 0) @@ -101,6 +96,5 @@ class Loop(ComponentBase, ABC): else: self.set_output(item["variable"], "") - def thoughts(self) -> str: - return "Loop from canvas." \ No newline at end of file + return "Loop from canvas." diff --git a/agent/component/loopitem.py b/agent/component/loopitem.py index 0cfb500850..8392a2b483 100644 --- a/agent/component/loopitem.py +++ b/agent/component/loopitem.py @@ -21,9 +21,11 @@ class LoopItemParam(ComponentParamBase): """ Define the LoopItem component parameters. """ + def check(self): return True + class LoopItem(ComponentBase, ABC): component_name = "LoopItem" @@ -31,7 +33,6 @@ class LoopItem(ComponentBase, ABC): super().__init__(canvas, id, param) self._idx = 0 - def _invoke(self, **kwargs): if self.check_if_canceled("LoopItem processing"): return @@ -45,7 +46,7 @@ class LoopItem(ComponentBase, ABC): return self._idx += 1 - def evaluate_condition(self,var, operator, value): + def evaluate_condition(self, var, operator, value): if isinstance(var, str): if operator == "contains": return value in var @@ -140,11 +141,7 @@ class LoopItem(ComponentBase, ABC): else: raise ValueError("Invalid input mode.") conditions.append(self.evaluate_condition(var, operator, value)) - should_end = ( - all(conditions) if logical_operator == "and" - else any(conditions) if logical_operator == "or" - else None - ) + should_end = all(conditions) if logical_operator == "and" else any(conditions) if logical_operator == "or" else None if should_end is None: raise ValueError("Invalid logical operator,should be 'and' or 'or'.") @@ -164,4 +161,4 @@ class LoopItem(ComponentBase, ABC): return False def thoughts(self) -> str: - return "Next turn..." \ No newline at end of file + return "Next turn..." diff --git a/agent/component/message.py b/agent/component/message.py index 359bedbe58..6b03b5f445 100644 --- a/agent/component/message.py +++ b/agent/component/message.py @@ -14,8 +14,10 @@ # limitations under the License. # import asyncio + try: import nest_asyncio + nest_asyncio.apply() except Exception: pass @@ -45,20 +47,14 @@ class MessageParam(ComponentParamBase): """ Define the Message component parameters. """ + def __init__(self): super().__init__() self.content = [] self.stream = True self.output_format = None # default output format self.auto_play = False - self.outputs = { - "content": { - "type": "str" - }, - "downloads": { - "type": "list" - } - } + self.outputs = {"content": {"type": "str"}, "downloads": {"type": "list"}} def check(self): self.check_empty(self.content, "[Message] Content") @@ -71,9 +67,7 @@ class Message(ComponentBase): @staticmethod def _is_download_info(value: Any) -> bool: - return isinstance(value, dict) and all( - key in value for key in ("doc_id", "filename", "mime_type") - ) + return isinstance(value, dict) and all(key in value for key in ("doc_id", "filename", "mime_type")) @staticmethod def _download_info_includes_content(value: Any) -> bool: @@ -157,7 +151,7 @@ class Message(ComponentBase): delimiter: str = None, downloads: list[dict[str, Any]] | None = None, ) -> tuple[str, dict[str, str | list | Any]]: - for k,v in self.get_input_elements_from_text(script).items(): + for k, v in self.get_input_elements_from_text(script).items(): if k in kwargs: continue v = v["value"] @@ -191,7 +185,7 @@ class Message(ComponentBase): buf += t return buf - async def _stream(self, rand_cnt:str): + async def _stream(self, rand_cnt: str): s = 0 all_content = "" cache = {} @@ -200,8 +194,8 @@ class Message(ComponentBase): if self.check_if_canceled("Message streaming"): return - all_content += rand_cnt[s: r.start()] - yield rand_cnt[s: r.start()] + all_content += rand_cnt[s : r.start()] + yield rand_cnt[s : r.start()] s = r.end() exp = r.group(1) if exp in cache: @@ -235,9 +229,7 @@ class Message(ComponentBase): continue elif inspect.isawaitable(v): v = await v - v = self._stringify_message_value( - v, downloads=downloads, fallback_to_str=True - ) + v = self._stringify_message_value(v, downloads=downloads, fallback_to_str=True) yield v self.set_input_value(exp, v) all_content += v @@ -247,21 +239,19 @@ class Message(ComponentBase): if self.check_if_canceled("Message streaming"): return - all_content += rand_cnt[s: ] - yield rand_cnt[s: ] + all_content += rand_cnt[s:] + yield rand_cnt[s:] self.set_output("downloads", downloads) self.set_output("content", all_content) self._convert_content(all_content) await self._save_to_memory(all_content) - def _is_jinjia2(self, content:str) -> bool: - patt = [ - r"\{%.*%\}", "{{", "}}" - ] + def _is_jinjia2(self, content: str) -> bool: + patt = [r"\{%.*%\}", "{{", "}}"] return any([re.search(p, content) for p in patt]) - @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))) def _invoke(self, **kwargs): if self.check_if_canceled("Message processing"): return @@ -303,19 +293,19 @@ class Message(ComponentBase): def _parse_markdown_table_lines(self, table_lines: list): """ Parse a list of Markdown table lines into a pandas DataFrame. - + Args: table_lines: List of strings, each representing a row in the Markdown table (excluding separator lines like |---|---|) - + Returns: pandas DataFrame with the table data, or None if parsing fails """ import pandas as pd - + if not table_lines: return None - + rows = [] headers = None @@ -350,29 +340,29 @@ class Message(ComponentBase): return cell return cell - + for line in table_lines: # Split by | and clean up - cells = [cell.strip() for cell in line.split('|')] + cells = [cell.strip() for cell in line.split("|")] # Remove empty first and last elements from split (caused by leading/trailing |) cells = [c for c in cells if c] - + if headers is None: headers = cells else: cells = [_coerce_excel_cell_type(c) for c in cells] rows.append(cells) - + if headers and rows: # Ensure all rows have same number of columns as headers normalized_rows = [] for row in rows: while len(row) < len(headers): - row.append('') - normalized_rows.append(row[:len(headers)]) - + row.append("") + normalized_rows.append(row[: len(headers)]) + return pd.DataFrame(normalized_rows, columns=headers) - + return None def _convert_content(self, content): @@ -380,6 +370,7 @@ class Message(ComponentBase): return import pypandoc + doc_id = get_uuid() if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx", "xlsx"}: @@ -408,49 +399,47 @@ class Message(ComponentBase): # Debug: log the content being parsed logging.info(f"XLSX Parser: Content length={len(content) if content else 0}, first 500 chars: {content[:500] if content else 'None'}") - + # Try to parse ALL Markdown tables from the content # Each table will be written to a separate sheet tables = [] # List of (sheet_name, dataframe) - + if isinstance(content, str): - lines = content.strip().split('\n') + lines = content.strip().split("\n") logging.info(f"XLSX Parser: Total lines={len(lines)}, lines starting with '|': {sum(1 for line in lines if line.strip().startswith('|'))}") current_table_lines = [] current_table_title = None pending_title = None in_table = False table_count = 0 - + for i, line in enumerate(lines): stripped = line.strip() - + # Check for potential table title (lines before a table) # Look for patterns like "Table 1:", "## Table", or markdown headers - if not in_table and stripped and not stripped.startswith('|'): + if not in_table and stripped and not stripped.startswith("|"): # Check if this could be a table title lower_stripped = stripped.lower() - if (lower_stripped.startswith('table') or - stripped.startswith('#') or - ':' in stripped): - pending_title = stripped.lstrip('#').strip() - - if stripped.startswith('|') and '|' in stripped[1:]: + if lower_stripped.startswith("table") or stripped.startswith("#") or ":" in stripped: + pending_title = stripped.lstrip("#").strip() + + if stripped.startswith("|") and "|" in stripped[1:]: # Check if this is a separator line (|---|---|) - cleaned = stripped.replace(' ', '').replace('|', '').replace('-', '').replace(':', '') - if cleaned == '': + cleaned = stripped.replace(" ", "").replace("|", "").replace("-", "").replace(":", "") + if cleaned == "": continue # Skip separator line - + if not in_table: # Starting a new table in_table = True current_table_lines = [] current_table_title = pending_title pending_title = None - + current_table_lines.append(stripped) - - elif in_table and not stripped.startswith('|'): + + elif in_table and not stripped.startswith("|"): # End of current table - save it if current_table_lines: df = self._parse_markdown_table_lines(current_table_lines) @@ -460,24 +449,22 @@ class Message(ComponentBase): if current_table_title: # Clean and truncate title for sheet name sheet_name = current_table_title[:31] - sheet_name = sheet_name.replace('/', '_').replace('\\', '_').replace('*', '').replace('?', '').replace('[', '').replace(']', '').replace(':', '') + sheet_name = sheet_name.replace("/", "_").replace("\\", "_").replace("*", "").replace("?", "").replace("[", "").replace("]", "").replace(":", "") else: sheet_name = f"Table_{table_count}" tables.append((sheet_name, df)) - + # Reset for next table in_table = False current_table_lines = [] current_table_title = None - + # Check if this line could be a title for the next table if stripped: lower_stripped = stripped.lower() - if (lower_stripped.startswith('table') or - stripped.startswith('#') or - ':' in stripped): - pending_title = stripped.lstrip('#').strip() - + if lower_stripped.startswith("table") or stripped.startswith("#") or ":" in stripped: + pending_title = stripped.lstrip("#").strip() + # Don't forget the last table if content ends with a table if in_table and current_table_lines: df = self._parse_markdown_table_lines(current_table_lines) @@ -485,11 +472,11 @@ class Message(ComponentBase): table_count += 1 if current_table_title: sheet_name = current_table_title[:31] - sheet_name = sheet_name.replace('/', '_').replace('\\', '_').replace('*', '').replace('?', '').replace('[', '').replace(']', '').replace(':', '') + sheet_name = sheet_name.replace("/", "_").replace("\\", "_").replace("*", "").replace("?", "").replace("[", "").replace("]", "").replace(":", "") else: sheet_name = f"Table_{table_count}" tables.append((sheet_name, df)) - + # Fallback: if no tables found, create single sheet with content if not tables: df = pd.DataFrame({"Content": [content if content else ""]}) @@ -497,7 +484,7 @@ class Message(ComponentBase): # Write all tables to Excel, each in a separate sheet excel_io = BytesIO() - with pd.ExcelWriter(excel_io, engine='openpyxl') as writer: + with pd.ExcelWriter(excel_io, engine="openpyxl") as writer: used_names = set() for sheet_name, df in tables: # Ensure unique sheet names @@ -505,14 +492,14 @@ class Message(ComponentBase): counter = 1 while sheet_name in used_names: suffix = f"_{counter}" - sheet_name = original_name[:31-len(suffix)] + suffix + sheet_name = original_name[: 31 - len(suffix)] + suffix counter += 1 used_names.add(sheet_name) df.to_excel(writer, sheet_name=sheet_name, index=False) - + excel_io.seek(0) binary_content = excel_io.read() - + logging.info(f"Generated Excel with {len(tables)} sheet(s): {[t[0] for t in tables]}") else: # pdf, docx @@ -543,10 +530,7 @@ class Message(ComponentBase): os.remove(tmp_name) settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content) - self.set_output("attachment", { - "doc_id":doc_id, - "format":self._param.output_format, - "file_name":f"{doc_id[:8]}.{self._param.output_format}"}) + self.set_output("attachment", {"doc_id": doc_id, "format": self._param.output_format, "file_name": f"{doc_id[:8]}.{self._param.output_format}"}) logging.info(f"Converted content uploaded as {doc_id} (format={self._param.output_format})") @@ -560,15 +544,10 @@ class Message(ComponentBase): user_id = self._param.user_id if hasattr(self._param, "user_id") else "" if user_id: import re + # is variable if re.match(r"^{.*}$", user_id): user_id = self._canvas.get_variable_value(user_id) - message_dict = { - "user_id": user_id, - "agent_id": self._canvas._id, - "session_id": self._canvas.task_id, - "user_input": self._canvas.get_sys_query(), - "agent_response": content - } + message_dict = {"user_id": user_id, "agent_id": self._canvas._id, "session_id": self._canvas.task_id, "user_input": self._canvas.get_sys_query(), "agent_response": content} return await queue_save_to_memory_task(self._param.memory_ids, message_dict) diff --git a/agent/component/pipeline_chunker.py b/agent/component/pipeline_chunker.py index 3bf5bd0305..c63b881695 100644 --- a/agent/component/pipeline_chunker.py +++ b/agent/component/pipeline_chunker.py @@ -131,10 +131,7 @@ class PipelineChunker(ComponentBase, ABC): try: return FileService.get_blob(created_by, file_id), filename except Exception as e: - logging.exception( - f"[PipelineChunker] FileService.get_blob failed for " - f"file_id={file_id} created_by={created_by} filename={filename}: {e}" - ) + logging.exception(f"[PipelineChunker] FileService.get_blob failed for file_id={file_id} created_by={created_by} filename={filename}: {e}") return None, None return None, None diff --git a/agent/component/string_transform.py b/agent/component/string_transform.py index 0b152f8f01..db3fdab5bb 100644 --- a/agent/component/string_transform.py +++ b/agent/component/string_transform.py @@ -52,18 +52,10 @@ class StringTransform(Message, ABC): def get_input_form(self) -> dict[str, dict]: if self._param.method == "split": - return { - "line": { - "name": "String", - "type": "line" - } - } - return {k: { - "name": o["name"], - "type": "line" - } for k, o in self.get_input_elements_from_text(self._param.script).items()} + return {"line": {"name": "String", "type": "line"}} + return {k: {"name": o["name"], "type": "line"} for k, o in self.get_input_elements_from_text(self._param.script).items()} - @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))) def _invoke(self, **kwargs): if self.check_if_canceled("StringTransform processing"): return @@ -73,7 +65,7 @@ class StringTransform(Message, ABC): else: self._merge(kwargs) - def _split(self, line:str|None = None): + def _split(self, line: str | None = None): if self.check_if_canceled("StringTransform split processing"): return @@ -84,13 +76,13 @@ class StringTransform(Message, ABC): self.set_input_value(self._param.split_ref, var) res = [] - for i,s in enumerate(re.split(r"(%s)"%("|".join([re.escape(d) for d in self._param.delimiters])), var, flags=re.DOTALL)): + for i, s in enumerate(re.split(r"(%s)" % ("|".join([re.escape(d) for d in self._param.delimiters])), var, flags=re.DOTALL)): if i % 2 == 1: continue res.append(s) self.set_output("result", res) - def _merge(self, kwargs:dict[str, str] = {}): + def _merge(self, kwargs: dict[str, str] = {}): if self.check_if_canceled("StringTransform merge processing"): return @@ -104,7 +96,7 @@ class StringTransform(Message, ABC): except Exception: pass - for k,v in kwargs.items(): + for k, v in kwargs.items(): if v is None: v = "" script = re.sub(k, lambda match: v, script) @@ -113,5 +105,3 @@ class StringTransform(Message, ABC): def thoughts(self) -> str: return f"It's {self._param.method}ing." - - diff --git a/agent/component/switch.py b/agent/component/switch.py index 20e41b7a1c..8612078a87 100644 --- a/agent/component/switch.py +++ b/agent/component/switch.py @@ -40,8 +40,7 @@ class SwitchParam(ComponentParamBase): """ self.conditions = [] self.end_cpn_ids = [] - self.operators = ['contains', 'not contains', 'start with', 'end with', 'empty', 'not empty', '=', '≠', '>', - '<', '≥', '≤'] + self.operators = ["contains", "not contains", "start with", "end with", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤"] def check(self): self.check_empty(self.conditions, "[Switch] conditions") @@ -51,12 +50,8 @@ class SwitchParam(ComponentParamBase): self.check_empty(self.end_cpn_ids, "[Switch] the ELSE/Other destination can not be empty.") def get_input_form(self) -> dict[str, dict]: - return { - "urls": { - "name": "URLs", - "type": "line" - } - } + return {"urls": {"name": "URLs", "type": "line"}} + class Switch(ComponentBase, ABC): component_name = "Switch" @@ -137,7 +132,7 @@ class Switch(ComponentBase, ABC): except Exception: return True if input <= value else False - raise ValueError(f'Not supported operator: {operator}') + raise ValueError(f"Not supported operator: {operator}") def thoughts(self) -> str: return "I’m weighing a few options and will pick the next step shortly." diff --git a/agent/component/variable_aggregator.py b/agent/component/variable_aggregator.py index 63d10aca24..e21b389c7d 100644 --- a/agent/component/variable_aggregator.py +++ b/agent/component/variable_aggregator.py @@ -38,13 +38,9 @@ class VariableAggregatorParam(ComponentParamBase): if not g.get("group_name"): raise ValueError("[VariableAggregator] group_name can not be empty!") if not g.get("variables"): - raise ValueError( - f"[VariableAggregator] variables of group `{g.get('group_name')}` can not be empty" - ) + raise ValueError(f"[VariableAggregator] variables of group `{g.get('group_name')}` can not be empty") if not isinstance(g.get("variables"), list): - raise ValueError( - f"[VariableAggregator] variables of group `{g.get('group_name')}` should be a list of strings" - ) + raise ValueError(f"[VariableAggregator] variables of group `{g.get('group_name')}` should be a list of strings") def get_input_form(self) -> dict[str, dict]: return { @@ -67,11 +63,11 @@ class VariableAggregator(ComponentBase): # record candidate selectors within this group self.set_input_value(f"{gname}.variables", list(group.get("variables", []))) for selector in group.get("variables", []): - val = self._canvas.get_variable_value(selector['value']) + val = self._canvas.get_variable_value(selector["value"]) if val: self.set_output(gname, val) break - + @staticmethod def _to_object(value: Any) -> Any: # Try to convert value to serializable object if it has to_object() diff --git a/agent/component/variable_assigner.py b/agent/component/variable_assigner.py index 163fa27727..10906bda1b 100644 --- a/agent/component/variable_assigner.py +++ b/agent/component/variable_assigner.py @@ -19,32 +19,30 @@ import numbers from agent.component.base import ComponentBase, ComponentParamBase from api.utils.api_utils import timeout + class VariableAssignerParam(ComponentParamBase): """ Define the Variable Assigner component parameters. """ + def __init__(self): super().__init__() - self.variables=[] + self.variables = [] def check(self): return True - - def get_input_form(self) -> dict[str, dict]: - return { - "items": { - "type": "json", - "name": "Items" - } - } -class VariableAssigner(ComponentBase,ABC): + def get_input_form(self) -> dict[str, dict]: + return {"items": {"type": "json", "name": "Items"}} + + +class VariableAssigner(ComponentBase, ABC): component_name = "VariableAssigner" _NO_PARAMETER_OPERATORS = {"clear", "remove_first", "remove_last"} - @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))) def _invoke(self, **kwargs): - if not isinstance(self._param.variables,list): + if not isinstance(self._param.variables, list): return else: for item in self._param.variables: @@ -56,105 +54,105 @@ class VariableAssigner(ComponentBase,ABC): raise ValueError("Variable is not complete.") if operator not in self._NO_PARAMETER_OPERATORS and parameter is None: raise ValueError("Variable is not complete.") - variable_value=self._canvas.get_variable_value(variable) - new_variable=self._operate(variable_value,operator,parameter) + variable_value = self._canvas.get_variable_value(variable) + new_variable = self._operate(variable_value, operator, parameter) self._canvas.set_variable_value(variable, new_variable) - def _operate(self,variable,operator,parameter): + def _operate(self, variable, operator, parameter): if operator == "overwrite": return self._overwrite(parameter) elif operator == "clear": return self._clear(variable) elif operator == "set": - return self._set(variable,parameter) + return self._set(variable, parameter) elif operator == "append": - return self._append(variable,parameter) + return self._append(variable, parameter) elif operator == "extend": - return self._extend(variable,parameter) + return self._extend(variable, parameter) elif operator == "remove_first": return self._remove_first(variable) elif operator == "remove_last": return self._remove_last(variable) elif operator == "+=": - return self._add(variable,parameter) + return self._add(variable, parameter) elif operator == "-=": - return self._subtract(variable,parameter) + return self._subtract(variable, parameter) elif operator == "*=": - return self._multiply(variable,parameter) + return self._multiply(variable, parameter) elif operator == "/=": - return self._divide(variable,parameter) + return self._divide(variable, parameter) else: return - - def _overwrite(self,parameter): + + def _overwrite(self, parameter): return self._canvas.get_variable_value(parameter) - def _clear(self,variable): - if isinstance(variable,list): + def _clear(self, variable): + if isinstance(variable, list): return [] - elif isinstance(variable,str): + elif isinstance(variable, str): return "" - elif isinstance(variable,dict): + elif isinstance(variable, dict): return {} - elif isinstance(variable,bool): + elif isinstance(variable, bool): return False - elif isinstance(variable,int): + elif isinstance(variable, int): return 0 - elif isinstance(variable,float): + elif isinstance(variable, float): return 0.0 else: return None - def _set(self,variable,parameter): + def _set(self, variable, parameter): if variable is None: return self._canvas.get_value_with_variable(parameter) - elif isinstance(variable,str): + elif isinstance(variable, str): return self._canvas.get_value_with_variable(parameter) - elif isinstance(variable,bool): + elif isinstance(variable, bool): return parameter - elif isinstance(variable,int): + elif isinstance(variable, int): return parameter - elif isinstance(variable,float): + elif isinstance(variable, float): return parameter else: return parameter - def _append(self,variable,parameter): - parameter=self._canvas.get_variable_value(parameter) + def _append(self, variable, parameter): + parameter = self._canvas.get_variable_value(parameter) if variable is None: - variable=[] - if not isinstance(variable,list): + variable = [] + if not isinstance(variable, list): return "ERROR:VARIABLE_NOT_LIST" - elif len(variable)!=0 and not isinstance(parameter,type(variable[0])): + elif len(variable) != 0 and not isinstance(parameter, type(variable[0])): return "ERROR:PARAMETER_NOT_LIST_ELEMENT_TYPE" else: variable.append(parameter) return variable - def _extend(self,variable,parameter): - parameter=self._canvas.get_variable_value(parameter) + def _extend(self, variable, parameter): + parameter = self._canvas.get_variable_value(parameter) if variable is None: - variable=[] - if not isinstance(variable,list): + variable = [] + if not isinstance(variable, list): return "ERROR:VARIABLE_NOT_LIST" - elif not isinstance(parameter,list): + elif not isinstance(parameter, list): return "ERROR:PARAMETER_NOT_LIST" - elif len(variable)!=0 and len(parameter)!=0 and not isinstance(parameter[0],type(variable[0])): + elif len(variable) != 0 and len(parameter) != 0 and not isinstance(parameter[0], type(variable[0])): return "ERROR:PARAMETER_NOT_LIST_ELEMENT_TYPE" else: return variable + parameter - def _remove_first(self,variable): - if not isinstance(variable,list): + def _remove_first(self, variable): + if not isinstance(variable, list): return "ERROR:VARIABLE_NOT_LIST" - if len(variable)==0: + if len(variable) == 0: return variable return variable[1:] - def _remove_last(self,variable): - if not isinstance(variable,list): + def _remove_last(self, variable): + if not isinstance(variable, list): return "ERROR:VARIABLE_NOT_LIST" - if len(variable)==0: + if len(variable) == 0: return variable return variable[:-1] @@ -163,32 +161,32 @@ class VariableAssigner(ComponentBase,ABC): return False return isinstance(value, numbers.Number) - def _add(self,variable,parameter): + def _add(self, variable, parameter): if self.is_number(variable) and self.is_number(parameter): return variable + parameter else: return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER" - def _subtract(self,variable,parameter): + def _subtract(self, variable, parameter): if self.is_number(variable) and self.is_number(parameter): return variable - parameter else: return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER" - def _multiply(self,variable,parameter): + def _multiply(self, variable, parameter): if self.is_number(variable) and self.is_number(parameter): return variable * parameter else: return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER" - def _divide(self,variable,parameter): + def _divide(self, variable, parameter): if self.is_number(variable) and self.is_number(parameter): - if parameter==0: + if parameter == 0: return "ERROR:DIVIDE_BY_ZERO" else: - return variable/parameter + return variable / parameter else: - return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER" + return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER" def thoughts(self) -> str: - return "Assign variables from canvas." \ No newline at end of file + return "Assign variables from canvas." diff --git a/agent/dsl_migration.py b/agent/dsl_migration.py index ca4ee894c3..1b995090df 100644 --- a/agent/dsl_migration.py +++ b/agent/dsl_migration.py @@ -56,7 +56,7 @@ def normalize_chunker_dsl(dsl: dict) -> dict: for old_name, new_name in COMPONENT_RENAMES.items(): prefix = f"{old_name}:" if component_id.startswith(prefix): - new_component_id = f"{new_name}:{component_id[len(prefix):]}" + new_component_id = f"{new_name}:{component_id[len(prefix) :]}" break component_id_map[component_id] = new_component_id @@ -66,12 +66,7 @@ def normalize_chunker_dsl(dsl: dict) -> dict: def repl(match: re.Match[str]) -> str: component_id = match.group(2) - return ( - match.group(1) - + component_id_map.get(component_id, component_id) - + match.group(3) - + match.group(4) - ) + return match.group(1) + component_id_map.get(component_id, component_id) + match.group(3) + match.group(4) return VARIABLE_REF_PATTERN.sub(repl, text) @@ -96,15 +91,9 @@ def normalize_chunker_dsl(dsl: dict) -> dict: obj["component_name"] = COMPONENT_RENAMES.get(component_name, component_name) if isinstance(new_component.get("downstream"), list): - new_component["downstream"] = [ - component_id_map.get(component_id, component_id) - for component_id in new_component["downstream"] - ] + new_component["downstream"] = [component_id_map.get(component_id, component_id) for component_id in new_component["downstream"]] if isinstance(new_component.get("upstream"), list): - new_component["upstream"] = [ - component_id_map.get(component_id, component_id) - for component_id in new_component["upstream"] - ] + new_component["upstream"] = [component_id_map.get(component_id, component_id) for component_id in new_component["upstream"]] parent_id = new_component.get("parent_id") if isinstance(parent_id, str): @@ -115,10 +104,7 @@ def normalize_chunker_dsl(dsl: dict) -> dict: normalized["components"] = rewritten_components if isinstance(normalized.get("path"), list): - normalized["path"] = [ - component_id_map.get(component_id, component_id) - for component_id in normalized["path"] - ] + normalized["path"] = [component_id_map.get(component_id, component_id) for component_id in normalized["path"]] graph = normalized.get("graph") if isinstance(graph, dict): diff --git a/agent/plugin/common.py b/agent/plugin/common.py index 7e85e0a13a..d8ed7c6068 100644 --- a/agent/plugin/common.py +++ b/agent/plugin/common.py @@ -1 +1 @@ -PLUGIN_TYPE_LLM_TOOLS = "llm_tools" \ No newline at end of file +PLUGIN_TYPE_LLM_TOOLS = "llm_tools" diff --git a/agent/plugin/embedded_plugins/llm_tools/bad_calculator.py b/agent/plugin/embedded_plugins/llm_tools/bad_calculator.py index 38376aa984..643495aaa6 100644 --- a/agent/plugin/embedded_plugins/llm_tools/bad_calculator.py +++ b/agent/plugin/embedded_plugins/llm_tools/bad_calculator.py @@ -7,6 +7,7 @@ class BadCalculatorPlugin(LLMToolPlugin): A sample LLM tool plugin, will add two numbers with 100. It only presents for demo purpose. Do not use it in production. """ + _version_ = "1.0.0" @classmethod @@ -17,19 +18,9 @@ class BadCalculatorPlugin(LLMToolPlugin): "description": "A tool to calculate the sum of two numbers (will give wrong answer)", "displayDescription": "$t:bad_calculator.description", "parameters": { - "a": { - "type": "number", - "description": "The first number", - "displayDescription": "$t:bad_calculator.params.a", - "required": True - }, - "b": { - "type": "number", - "description": "The second number", - "displayDescription": "$t:bad_calculator.params.b", - "required": True - } - } + "a": {"type": "number", "description": "The first number", "displayDescription": "$t:bad_calculator.params.a", "required": True}, + "b": {"type": "number", "description": "The second number", "displayDescription": "$t:bad_calculator.params.b", "required": True}, + }, } def invoke(self, a: int, b: int) -> str: diff --git a/agent/plugin/llm_tool_plugin.py b/agent/plugin/llm_tool_plugin.py index b0dc4c8e8f..13124b6e66 100644 --- a/agent/plugin/llm_tool_plugin.py +++ b/agent/plugin/llm_tool_plugin.py @@ -38,14 +38,8 @@ def llm_tool_metadata_to_openai_tool(llm_tool_metadata: LLMToolMetadata) -> dict "description": llm_tool_metadata["description"], "parameters": { "type": "object", - "properties": { - k: { - "type": p["type"], - "description": p["description"] - } - for k, p in llm_tool_metadata["parameters"].items() - }, - "required": [k for k, p in llm_tool_metadata["parameters"].items() if p["required"]] - } - } + "properties": {k: {"type": p["type"], "description": p["description"]} for k, p in llm_tool_metadata["parameters"].items()}, + "required": [k for k, p in llm_tool_metadata["parameters"].items() if p["required"]], + }, + }, } diff --git a/agent/plugin/plugin_manager.py b/agent/plugin/plugin_manager.py index 1f1b815912..c24ef61986 100644 --- a/agent/plugin/plugin_manager.py +++ b/agent/plugin/plugin_manager.py @@ -15,10 +15,8 @@ class PluginManager: self._llm_tool_plugins = {} def load_plugins(self) -> None: - loader = pluginlib.PluginLoader( - paths=[str(Path(os.path.dirname(__file__), "embedded_plugins"))] - ) - + loader = pluginlib.PluginLoader(paths=[str(Path(os.path.dirname(__file__), "embedded_plugins"))]) + for type, plugins in loader.plugins.items(): for name, plugin in plugins.items(): logging.info(f"Loaded {type} plugin {name} version {plugin.version}") diff --git a/agent/sandbox/client.py b/agent/sandbox/client.py index daafb0d07f..efcea13cb6 100644 --- a/agent/sandbox/client.py +++ b/agent/sandbox/client.py @@ -111,7 +111,10 @@ def _load_provider_from_settings() -> None: except Exception as e: logger.error(f"Failed to load sandbox provider from settings: {e}") import traceback + traceback.print_exc() + + def _load_provider_config_from_settings(provider_type: str) -> Dict[str, Any]: provider_config_settings = SystemSettingsService.get_by_name(f"sandbox.{provider_type}") if not provider_config_settings: @@ -147,12 +150,7 @@ def reload_provider() -> None: _load_provider_from_settings() -def execute_code( - code: str, - language: str = "python", - timeout: int = 30, - arguments: Optional[Dict[str, Any]] = None -) -> ExecutionResult: +def execute_code(code: str, language: str = "python", timeout: int = 30, arguments: Optional[Dict[str, Any]] = None) -> ExecutionResult: """ Execute code in the configured sandbox. @@ -173,9 +171,7 @@ def execute_code( provider_manager = get_provider_manager() if not provider_manager.is_configured(): - raise RuntimeError( - "No sandbox provider configured. Please configure sandbox settings in the admin panel." - ) + raise RuntimeError("No sandbox provider configured. Please configure sandbox settings in the admin panel.") provider = provider_manager.get_provider() provider_name = provider_manager.get_provider_name() or getattr(provider, "__class__", type(provider)).__name__ @@ -192,13 +188,7 @@ def execute_code( try: # Execute the code - result = provider.execute_code( - instance_id=instance.instance_id, - code=code, - language=language, - timeout=timeout, - arguments=arguments - ) + result = provider.execute_code(instance_id=instance.instance_id, code=code, language=language, timeout=timeout, arguments=arguments) return result diff --git a/agent/sandbox/executor_manager/api/routes.py b/agent/sandbox/executor_manager/api/routes.py index 86a034d6f3..b0ffe908b2 100644 --- a/agent/sandbox/executor_manager/api/routes.py +++ b/agent/sandbox/executor_manager/api/routes.py @@ -22,4 +22,3 @@ router = APIRouter() router.get("/")(healthz_handler) router.get("/healthz")(healthz_handler) router.post("/run")(run_code_handler) - diff --git a/agent/sandbox/executor_manager/services/execution.py b/agent/sandbox/executor_manager/services/execution.py index 48bd96d74f..d8a36bdae3 100644 --- a/agent/sandbox/executor_manager/services/execution.py +++ b/agent/sandbox/executor_manager/services/execution.py @@ -230,9 +230,7 @@ async def execute_code(req: CodeExecutionRequest): if returncode != 0: raise RuntimeError(f"Directory creation failed: {stderr}") - tar_proc = await asyncio.create_subprocess_exec( - "tar", "czf", "-", "-C", workdir, code_name, runner_name, str(bundle["args_name"]), stdout=asyncio.subprocess.PIPE - ) + tar_proc = await asyncio.create_subprocess_exec("tar", "czf", "-", "-C", workdir, code_name, runner_name, str(bundle["args_name"]), stdout=asyncio.subprocess.PIPE) tar_stdout, _ = await tar_proc.communicate() docker_proc = await asyncio.create_subprocess_exec( @@ -334,8 +332,16 @@ async def _collect_artifacts(container: str, task_id: str, host_workdir: str) -> # List files in the artifacts directory inside the container returncode, stdout, _ = await async_run_command( - "docker", "exec", container, "find", artifacts_path, - "-maxdepth", "1", "-type", "f", timeout=5, + "docker", + "exec", + container, + "find", + artifacts_path, + "-maxdepth", + "1", + "-type", + "f", + timeout=5, ) if returncode != 0 or not stdout.strip(): return [] @@ -359,7 +365,14 @@ async def _collect_artifacts(container: str, task_id: str, host_workdir: str) -> # Check file size inside the container returncode, size_str, _ = await async_run_command( - "docker", "exec", container, "stat", "-c", "%s", file_path, timeout=5, + "docker", + "exec", + container, + "stat", + "-c", + "%s", + file_path, + timeout=5, ) if returncode != 0: logger.warning(f"Failed to stat artifact {fname}") @@ -374,7 +387,12 @@ async def _collect_artifacts(container: str, task_id: str, host_workdir: str) -> # Read file content via docker exec (docker cp doesn't work with gVisor tmpfs) returncode, content_b64, stderr = await async_run_command( - "docker", "exec", container, "base64", file_path, timeout=30, + "docker", + "exec", + container, + "base64", + file_path, + timeout=30, ) if returncode != 0: logger.warning(f"Failed to read artifact {fname}: {stderr}") @@ -382,12 +400,14 @@ async def _collect_artifacts(container: str, task_id: str, host_workdir: str) -> content_b64 = content_b64.replace("\n", "").strip() - items.append(ArtifactItem( - name=fname, - mime_type=mime_type, - size=file_size, - content_b64=content_b64, - )) + items.append( + ArtifactItem( + name=fname, + mime_type=mime_type, + size=file_size, + content_b64=content_b64, + ) + ) logger.info(f"Collected artifact: {fname} ({file_size} bytes, {mime_type})") return items diff --git a/agent/sandbox/providers/aliyun_codeinterpreter.py b/agent/sandbox/providers/aliyun_codeinterpreter.py index 144ae15251..fd4adb41bd 100644 --- a/agent/sandbox/providers/aliyun_codeinterpreter.py +++ b/agent/sandbox/providers/aliyun_codeinterpreter.py @@ -232,11 +232,7 @@ class AliyunCodeInterpreterProvider(SandboxProvider): # Wrap code to call main() function # Matches self_managed provider behavior: call main(**arguments) args_json = json.dumps(arguments or {}) - wrapped_code = ( - build_python_wrapper(code, args_json) - if normalized_lang == "python" - else build_javascript_wrapper(code, args_json) - ) + wrapped_code = build_python_wrapper(code, args_json) if normalized_lang == "python" else build_javascript_wrapper(code, args_json) logger.debug(f"Aliyun Code Interpreter: Wrapped code (first 200 chars): {wrapped_code[:200]}") start_time = time.time() diff --git a/agent/sandbox/providers/base.py b/agent/sandbox/providers/base.py index 8f9c04aaa4..915487cf59 100644 --- a/agent/sandbox/providers/base.py +++ b/agent/sandbox/providers/base.py @@ -33,6 +33,7 @@ class SandboxProviderConfigError(Exception): @dataclass class SandboxInstance: """Represents a sandbox execution instance""" + instance_id: str provider: str status: str # running, stopped, error @@ -46,6 +47,7 @@ class SandboxInstance: @dataclass class ExecutionResult: """Result of code execution in a sandbox""" + stdout: str stderr: str exit_code: int @@ -96,14 +98,7 @@ class SandboxProvider(ABC): pass @abstractmethod - def execute_code( - self, - instance_id: str, - code: str, - language: str, - timeout: int = 10, - arguments: Optional[Dict[str, Any]] = None - ) -> ExecutionResult: + def execute_code(self, instance_id: str, code: str, language: str, timeout: int = 10, arguments: Optional[Dict[str, Any]] = None) -> ExecutionResult: """ Execute code in a sandbox instance. diff --git a/agent/sandbox/providers/e2b.py b/agent/sandbox/providers/e2b.py index 5c4bd5d912..0ebd837564 100644 --- a/agent/sandbox/providers/e2b.py +++ b/agent/sandbox/providers/e2b.py @@ -97,16 +97,10 @@ class E2BProvider(SandboxProvider): metadata={ "language": language, "region": self.region, - } + }, ) - def execute_code( - self, - instance_id: str, - code: str, - language: str, - timeout: int = 10 - ) -> ExecutionResult: + def execute_code(self, instance_id: str, code: str, language: str, timeout: int = 10) -> ExecutionResult: """ Execute code in the E2B instance. @@ -130,9 +124,7 @@ class E2BProvider(SandboxProvider): # POST /sandbox/{sandboxID}/execute raise RuntimeError( - "E2B provider is not yet fully implemented. " - "Please use the self-managed provider or implement the E2B API integration. " - "See https://github.com/e2b-dev/e2b for API documentation." + "E2B provider is not yet fully implemented. Please use the self-managed provider or implement the E2B API integration. See https://github.com/e2b-dev/e2b for API documentation." ) def destroy_instance(self, instance_id: str) -> bool: @@ -208,7 +200,7 @@ class E2BProvider(SandboxProvider): "min": 5, "max": 300, "description": "API request timeout for code execution", - } + }, } def _normalize_language(self, language: str) -> str: diff --git a/agent/sandbox/providers/local.py b/agent/sandbox/providers/local.py index ed37cc57d0..6a9ac516bd 100644 --- a/agent/sandbox/providers/local.py +++ b/agent/sandbox/providers/local.py @@ -49,6 +49,8 @@ LOCAL_PYTHON_THREAD_ENV_VARS = ( "BLIS_NUM_THREADS", "VECLIB_MAXIMUM_THREADS", ) + + class LocalProvider(SandboxProvider): """ Execute code as a local child process. diff --git a/agent/sandbox/providers/self_managed.py b/agent/sandbox/providers/self_managed.py index 8b92d0b2c4..85e02cb424 100644 --- a/agent/sandbox/providers/self_managed.py +++ b/agent/sandbox/providers/self_managed.py @@ -108,17 +108,10 @@ class SelfManagedProvider(SandboxProvider): "language": language, "endpoint": self.endpoint, "pool_size": self.pool_size, - } + }, ) - def execute_code( - self, - instance_id: str, - code: str, - language: str, - timeout: int = 10, - arguments: Optional[Dict[str, Any]] = None - ) -> ExecutionResult: + def execute_code(self, instance_id: str, code: str, language: str, timeout: int = 10, arguments: Optional[Dict[str, Any]] = None) -> ExecutionResult: """ Execute code in the sandbox. @@ -144,11 +137,7 @@ class SelfManagedProvider(SandboxProvider): # Prepare request code_b64 = base64.b64encode(code.encode("utf-8")).decode("utf-8") - payload = { - "code_b64": code_b64, - "language": normalized_lang, - "arguments": arguments or {} - } + payload = {"code_b64": code_b64, "language": normalized_lang, "arguments": arguments or {}} url = f"{self.endpoint}/run" exec_timeout = timeout or self.timeout @@ -156,19 +145,12 @@ class SelfManagedProvider(SandboxProvider): start_time = time.time() try: - response = requests.post( - url, - json=payload, - timeout=exec_timeout, - headers={"Content-Type": "application/json"} - ) + response = requests.post(url, json=payload, timeout=exec_timeout, headers={"Content-Type": "application/json"}) execution_time = time.time() - start_time if response.status_code != 200: - raise RuntimeError( - f"HTTP {response.status_code}: {response.text}" - ) + raise RuntimeError(f"HTTP {response.status_code}: {response.text}") result = response.json() structured_result = result.get("result") or {} @@ -188,14 +170,12 @@ class SelfManagedProvider(SandboxProvider): "result_present": structured_result.get("present", False), "result_value": structured_result.get("value"), "result_type": structured_result.get("type"), - } + }, ) except requests.Timeout: execution_time = time.time() - start_time - raise TimeoutError( - f"Execution timed out after {exec_timeout} seconds" - ) + raise TimeoutError(f"Execution timed out after {exec_timeout} seconds") except requests.RequestException as e: raise RuntimeError(f"HTTP request failed: {str(e)}") @@ -388,7 +368,8 @@ class SelfManagedProvider(SandboxProvider): if endpoint: # Check if it's a valid HTTP/HTTPS URL or localhost import re - url_pattern = r'^(https?://|http://localhost|http://[\d\.]+:[a-z]+:[/]|http://[\w\.]+:)' + + url_pattern = r"^(https?://|http://localhost|http://[\d\.]+:[a-z]+:[/]|http://[\w\.]+:)" if not re.match(url_pattern, endpoint): return False, f"Invalid endpoint format: {endpoint}. Must start with http:// or https://" diff --git a/agent/sandbox/providers/ssh.py b/agent/sandbox/providers/ssh.py index 2ac33b0045..001c294300 100644 --- a/agent/sandbox/providers/ssh.py +++ b/agent/sandbox/providers/ssh.py @@ -135,9 +135,7 @@ class SSHProvider(SandboxProvider): timeout=min(self.timeout, 10), ) if exit_code != 0: - raise RuntimeError( - f"Failed to create remote artifacts directory: {stderr or stdout or 'unknown error'}" - ) + raise RuntimeError(f"Failed to create remote artifacts directory: {stderr or stdout or 'unknown error'}") except Exception: sftp.close() client.close() @@ -211,9 +209,7 @@ class SSHProvider(SandboxProvider): "status": "ok" if exit_code == 0 else "error", "timeout": exec_timeout, "command": command, - "artifacts": self._collect_artifacts( - sftp, posixpath.join(remote_work_dir, "artifacts") - ), + "artifacts": self._collect_artifacts(sftp, posixpath.join(remote_work_dir, "artifacts")), "result_present": structured_result.get("present", False), "result_value": structured_result.get("value"), "result_type": structured_result.get("type"), @@ -269,18 +265,13 @@ class SSHProvider(SandboxProvider): timeout=min(self.timeout, 10), ) if exit_code != 0: - raise SandboxProviderConfigError( - f"SSH connectivity check failed on {self.username}@{self.host}:{self.port}: " - f"{stderr or 'remote command returned non-zero exit status'}" - ) + raise SandboxProviderConfigError(f"SSH connectivity check failed on {self.username}@{self.host}:{self.port}: {stderr or 'remote command returned non-zero exit status'}") finally: client.close() except SandboxProviderConfigError: raise except Exception as exc: - raise SandboxProviderConfigError( - f"Failed to connect to SSH host {self.username}@{self.host}:{self.port}: {exc}" - ) from exc + raise SandboxProviderConfigError(f"Failed to connect to SSH host {self.username}@{self.host}:{self.port}: {exc}") from exc def get_supported_languages(self) -> List[str]: return ["python", "javascript", "nodejs"] @@ -470,9 +461,7 @@ class SSHProvider(SandboxProvider): # Match the Go provider's fail-closed posture (see # internal/agent/sandbox/ssh.go::hostKeyCallback). logging.warning("SSH: failed to load configured known_hosts file; refusing connection") - raise SandboxProviderConfigError( - "Failed to load configured SSH known_hosts file." - ) from exc + raise SandboxProviderConfigError("Failed to load configured SSH known_hosts file.") from exc # Reject unknown hosts: this is the default fail-closed posture # to prevent silent MITM. Operators must either ship a populated # known_hosts file or accept the warning (paramiko will fail the @@ -522,9 +511,7 @@ class SSHProvider(SandboxProvider): except Exception as exc: errors.append(str(exc)) - raise SandboxProviderConfigError( - "Failed to load SSH private key. " + "; ".join(error for error in errors if error) - ) + raise SandboxProviderConfigError("Failed to load SSH private key. " + "; ".join(error for error in errors if error)) def _create_remote_workspace(self, client: paramiko.SSHClient) -> str: base_dir = self.work_dir.rstrip("/") or "/tmp" @@ -535,9 +522,7 @@ class SSHProvider(SandboxProvider): timeout=min(self.timeout, 10), ) if exit_code != 0: - raise RuntimeError( - f"Failed to create remote workspace on {self.host}: {stderr or stdout or 'unknown error'}" - ) + raise RuntimeError(f"Failed to create remote workspace on {self.host}: {stderr or stdout or 'unknown error'}") remote_work_dir = stdout.strip().splitlines()[-1] if stdout.strip() else "" if not remote_work_dir: @@ -577,10 +562,7 @@ class SSHProvider(SandboxProvider): else: raise RuntimeError(f"Unsupported language for SSH provider: {language}") - return ( - f"cd {shlex.quote(remote_work_dir)} && " - f"{shlex.quote(executable)} {shlex.quote(remote_script_path)}" - ) + return f"cd {shlex.quote(remote_work_dir)} && {shlex.quote(executable)} {shlex.quote(remote_script_path)}" def _run_remote_command( self, @@ -700,7 +682,5 @@ def _get_paramiko_module(): try: import paramiko except ImportError as exc: - raise SandboxProviderConfigError( - "paramiko is required for the SSH sandbox provider. Install the project dependencies to enable it." - ) from exc + raise SandboxProviderConfigError("paramiko is required for the SSH sandbox provider. Install the project dependencies to enable it.") from exc return paramiko diff --git a/agent/sandbox/result_protocol.py b/agent/sandbox/result_protocol.py index f71e5f4996..385f6fb615 100644 --- a/agent/sandbox/result_protocol.py +++ b/agent/sandbox/result_protocol.py @@ -36,7 +36,7 @@ if __name__ == "__main__": def build_javascript_wrapper(code: str, args_json: str) -> str: - return f'''{code} + return f"""{code} const __ragflowArgs = {args_json}; @@ -55,7 +55,7 @@ const __ragflowArgs = {args_json}; }} console.log('{RESULT_MARKER_PREFIX}' + Buffer.from(payload, 'utf8').toString('base64')); }})(); -''' +""" def extract_structured_result(stdout: str) -> tuple[str, dict[str, Any]]: diff --git a/agent/sandbox/tests/test_aliyun_codeinterpreter_integration.py b/agent/sandbox/tests/test_aliyun_codeinterpreter_integration.py index 491d19ba42..a114b5d4e5 100644 --- a/agent/sandbox/tests/test_aliyun_codeinterpreter_integration.py +++ b/agent/sandbox/tests/test_aliyun_codeinterpreter_integration.py @@ -151,7 +151,7 @@ class TestAliyunCodeInterpreterIntegration: """, language="python", timeout=30, - arguments={"name": "World", "count": 2} + arguments={"name": "World", "count": 2}, ) assert result.exit_code == 0 @@ -211,7 +211,7 @@ class TestAliyunCodeInterpreterIntegration: }""", language="javascript", timeout=30, - arguments={"name": "World", "count": 2} + arguments={"name": "World", "count": 2}, ) assert result.exit_code == 0 diff --git a/agent/sandbox/tests/test_providers.py b/agent/sandbox/tests/test_providers.py index cf90bb79ab..12d3d0a129 100644 --- a/agent/sandbox/tests/test_providers.py +++ b/agent/sandbox/tests/test_providers.py @@ -32,12 +32,7 @@ class TestSandboxDataclasses: def test_sandbox_instance_creation(self): """Test SandboxInstance dataclass creation.""" - instance = SandboxInstance( - instance_id="test-123", - provider="self_managed", - status="running", - metadata={"language": "python"} - ) + instance = SandboxInstance(instance_id="test-123", provider="self_managed", status="running", metadata={"language": "python"}) assert instance.instance_id == "test-123" assert instance.provider == "self_managed" @@ -46,24 +41,13 @@ class TestSandboxDataclasses: def test_sandbox_instance_default_metadata(self): """Test SandboxInstance with None metadata.""" - instance = SandboxInstance( - instance_id="test-123", - provider="self_managed", - status="running", - metadata=None - ) + instance = SandboxInstance(instance_id="test-123", provider="self_managed", status="running", metadata=None) assert instance.metadata == {} def test_execution_result_creation(self): """Test ExecutionResult dataclass creation.""" - result = ExecutionResult( - stdout="Hello, World!", - stderr="", - exit_code=0, - execution_time=1.5, - metadata={"status": "success"} - ) + result = ExecutionResult(stdout="Hello, World!", stderr="", exit_code=0, execution_time=1.5, metadata={"status": "success"}) assert result.stdout == "Hello, World!" assert result.stderr == "" @@ -73,13 +57,7 @@ class TestSandboxDataclasses: def test_execution_result_default_metadata(self): """Test ExecutionResult with None metadata.""" - result = ExecutionResult( - stdout="output", - stderr="error", - exit_code=1, - execution_time=0.5, - metadata=None - ) + result = ExecutionResult(stdout="output", stderr="error", exit_code=1, execution_time=0.5, metadata=None) assert result.metadata == {} @@ -145,7 +123,7 @@ class TestSelfManagedProvider: assert provider.pool_size == 10 assert not provider._initialized - @patch('requests.get') + @patch("requests.get") def test_initialize_success(self, mock_get): """Test successful initialization.""" mock_response = Mock() @@ -153,12 +131,7 @@ class TestSelfManagedProvider: mock_get.return_value = mock_response provider = SelfManagedProvider() - result = provider.initialize({ - "endpoint": "http://test-endpoint:9385", - "timeout": 60, - "max_retries": 5, - "pool_size": 20 - }) + result = provider.initialize({"endpoint": "http://test-endpoint:9385", "timeout": 60, "max_retries": 5, "pool_size": 20}) assert result is True assert provider.endpoint == "http://test-endpoint:9385" @@ -168,7 +141,7 @@ class TestSelfManagedProvider: assert provider._initialized mock_get.assert_called_once_with("http://test-endpoint:9385/healthz", timeout=5) - @patch('requests.get') + @patch("requests.get") def test_initialize_failure(self, mock_get): """Test initialization failure.""" mock_get.side_effect = Exception("Connection error") @@ -181,7 +154,7 @@ class TestSelfManagedProvider: def test_initialize_default_config(self): """Test initialization with default config.""" - with patch('requests.get') as mock_get: + with patch("requests.get") as mock_get: mock_response = Mock() mock_response.status_code = 200 mock_get.return_value = mock_response @@ -222,30 +195,18 @@ class TestSelfManagedProvider: with pytest.raises(RuntimeError, match="Provider not initialized"): provider.create_instance("python") - @patch('requests.post') + @patch("requests.post") def test_execute_code_success(self, mock_post): """Test successful code execution.""" mock_response = Mock() mock_response.status_code = 200 - mock_response.json.return_value = { - "status": "success", - "stdout": '{"result": 42}', - "stderr": "", - "exit_code": 0, - "time_used_ms": 100.0, - "memory_used_kb": 1024.0 - } + mock_response.json.return_value = {"status": "success", "stdout": '{"result": 42}', "stderr": "", "exit_code": 0, "time_used_ms": 100.0, "memory_used_kb": 1024.0} mock_post.return_value = mock_response provider = SelfManagedProvider() provider._initialized = True - result = provider.execute_code( - instance_id="test-123", - code="def main(): return {'result': 42}", - language="python", - timeout=10 - ) + result = provider.execute_code(instance_id="test-123", code="def main(): return {'result': 42}", language="python", timeout=10) assert result.stdout == '{"result": 42}' assert result.stderr == "" @@ -254,7 +215,7 @@ class TestSelfManagedProvider: assert result.metadata["status"] == "success" assert result.metadata["instance_id"] == "test-123" - @patch('requests.post') + @patch("requests.post") def test_execute_code_maps_structured_result_into_metadata(self, mock_post): """Test successful code execution with structured result envelope.""" mock_response = Mock() @@ -277,19 +238,14 @@ class TestSelfManagedProvider: provider = SelfManagedProvider() provider._initialized = True - result = provider.execute_code( - instance_id="test-123", - code="def main(): return {'items': ['a', 'b']}", - language="python", - timeout=10 - ) + result = provider.execute_code(instance_id="test-123", code="def main(): return {'items': ['a', 'b']}", language="python", timeout=10) assert result.stdout == "debug line\n" assert result.metadata["result_present"] is True assert result.metadata["result_value"] == {"items": ["a", "b"]} assert result.metadata["result_type"] == "json" - @patch('requests.post') + @patch("requests.post") def test_execute_code_timeout(self, mock_post): """Test code execution timeout.""" mock_post.side_effect = requests.Timeout() @@ -298,14 +254,9 @@ class TestSelfManagedProvider: provider._initialized = True with pytest.raises(TimeoutError, match="Execution timed out"): - provider.execute_code( - instance_id="test-123", - code="while True: pass", - language="python", - timeout=5 - ) + provider.execute_code(instance_id="test-123", code="while True: pass", language="python", timeout=5) - @patch('requests.post') + @patch("requests.post") def test_execute_code_http_error(self, mock_post): """Test code execution with HTTP error.""" mock_response = Mock() @@ -317,22 +268,14 @@ class TestSelfManagedProvider: provider._initialized = True with pytest.raises(RuntimeError, match="HTTP 500"): - provider.execute_code( - instance_id="test-123", - code="invalid code", - language="python" - ) + provider.execute_code(instance_id="test-123", code="invalid code", language="python") def test_execute_code_not_initialized(self): """Test executing code when provider not initialized.""" provider = SelfManagedProvider() with pytest.raises(RuntimeError, match="Provider not initialized"): - provider.execute_code( - instance_id="test-123", - code="print('hello')", - language="python" - ) + provider.execute_code(instance_id="test-123", code="print('hello')", language="python") def test_destroy_instance(self): """Test destroying an instance (no-op for self-managed).""" @@ -344,7 +287,7 @@ class TestSelfManagedProvider: assert result is True - @patch('requests.get') + @patch("requests.get") def test_health_check_success(self, mock_get): """Test successful health check.""" mock_response = Mock() @@ -358,7 +301,7 @@ class TestSelfManagedProvider: assert result is True mock_get.assert_called_once_with("http://localhost:9385/healthz", timeout=5) - @patch('requests.get') + @patch("requests.get") def test_health_check_failure(self, mock_get): """Test health check failure.""" mock_get.side_effect = Exception("Connection error") @@ -439,20 +382,20 @@ class TestProviderInterface: provider = SelfManagedProvider() # Check all abstract methods are implemented - assert hasattr(provider, 'initialize') + assert hasattr(provider, "initialize") assert callable(provider.initialize) - assert hasattr(provider, 'create_instance') + assert hasattr(provider, "create_instance") assert callable(provider.create_instance) - assert hasattr(provider, 'execute_code') + assert hasattr(provider, "execute_code") assert callable(provider.execute_code) - assert hasattr(provider, 'destroy_instance') + assert hasattr(provider, "destroy_instance") assert callable(provider.destroy_instance) - assert hasattr(provider, 'health_check') + assert hasattr(provider, "health_check") assert callable(provider.health_check) - assert hasattr(provider, 'get_supported_languages') + assert hasattr(provider, "get_supported_languages") assert callable(provider.get_supported_languages) diff --git a/agent/sandbox/tests/test_security.py b/agent/sandbox/tests/test_security.py index dc8d9f8063..95a4ee35d4 100644 --- a/agent/sandbox/tests/test_security.py +++ b/agent/sandbox/tests/test_security.py @@ -76,9 +76,7 @@ def test_python_builtins_import_is_rejected(): assert is_safe is False # Pin the specific reason: rejection must come from the new ``builtins`` # entry in ``DANGEROUS_IMPORTS``, not from some unrelated parse error. - assert any("builtins" in issue for issue, _ in issues), ( - f"expected an issue mentioning 'builtins', got {issues!r}" - ) + assert any("builtins" in issue for issue, _ in issues), f"expected an issue mentioning 'builtins', got {issues!r}" def test_python_attribute_eval_call_is_rejected(): @@ -94,9 +92,7 @@ def test_python_attribute_eval_call_is_rejected(): # not from the ``import builtins`` line above. We assert ``exec`` is in at # least one finding so the test fails if visit_Call's attribute branch is # ever reverted. - assert any("exec" in issue for issue, _ in issues), ( - f"expected an issue mentioning 'exec', got {issues!r}" - ) + assert any("exec" in issue for issue, _ in issues), f"expected an issue mentioning 'exec', got {issues!r}" def test_javascript_safe_code_still_passes(): diff --git a/agent/sandbox/tests/verify_sdk.py b/agent/sandbox/tests/verify_sdk.py index 94aea18f88..e3ebf16066 100644 --- a/agent/sandbox/tests/verify_sdk.py +++ b/agent/sandbox/tests/verify_sdk.py @@ -36,17 +36,14 @@ print("✓ Provider has all required methods") print("\n[3/5] Testing SDK imports...") try: # Check if agentrun SDK is available using importlib - if ( - importlib.util.find_spec("agentrun.sandbox") is None - or importlib.util.find_spec("agentrun.utils.config") is None - or importlib.util.find_spec("agentrun.utils.exception") is None - ): + if importlib.util.find_spec("agentrun.sandbox") is None or importlib.util.find_spec("agentrun.utils.config") is None or importlib.util.find_spec("agentrun.utils.exception") is None: raise ImportError("agentrun SDK not found") # Verify imports work (assign to _ to indicate they're intentionally unused) from agentrun.sandbox import CodeInterpreterSandbox, TemplateType, CodeLanguage from agentrun.utils.config import Config from agentrun.utils.exception import ServerError + _ = (CodeInterpreterSandbox, TemplateType, CodeLanguage, Config, ServerError) print("✓ SDK modules imported successfully") diff --git a/agent/test/client.py b/agent/test/client.py index 26a02b957e..783307cdf2 100644 --- a/agent/test/client.py +++ b/agent/test/client.py @@ -18,29 +18,29 @@ import os from agent.canvas import Canvas from common import settings -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() dsl_default_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "dsl_examples", "retrieval_and_generate.json", ) - parser.add_argument('-s', '--dsl', default=dsl_default_path, help="input dsl", action='store', required=True) - parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True) - parser.add_argument('-m', '--stream', default=False, help="Stream output", action='store_true', required=False) + parser.add_argument("-s", "--dsl", default=dsl_default_path, help="input dsl", action="store", required=True) + parser.add_argument("-t", "--tenant_id", default=False, help="Tenant ID", action="store", required=True) + parser.add_argument("-m", "--stream", default=False, help="Stream output", action="store_true", required=False) args = parser.parse_args() settings.init_settings() canvas = Canvas(open(args.dsl, "r").read(), args.tenant_id) if canvas.get_prologue(): - print(f"==================== Bot =====================\n> {canvas.get_prologue()}", end='') + print(f"==================== Bot =====================\n> {canvas.get_prologue()}", end="") query = "" while True: canvas.reset(True) query = input("\n==================== User =====================\n> ") ans = canvas.run(query=query) - print("==================== Bot =====================\n> ", end='') + print("==================== Bot =====================\n> ", end="") for ans in canvas.run(query=query): - print(ans, end='\n', flush=True) + print(ans, end="\n", flush=True) print(canvas.path) diff --git a/agent/tools/__init__.py b/agent/tools/__init__.py index e002614c8c..9f16bfd9f1 100644 --- a/agent/tools/__init__.py +++ b/agent/tools/__init__.py @@ -22,8 +22,9 @@ from typing import Dict, Type _package_path = os.path.dirname(__file__) __all_classes: Dict[str, Type] = {} + def _import_submodules() -> None: - for filename in os.listdir(_package_path): # noqa: F821 + for filename in os.listdir(_package_path): # noqa: F821 if filename.startswith("__") or not filename.endswith(".py") or filename.startswith("base"): continue module_name = filename[:-3] @@ -34,15 +35,16 @@ def _import_submodules() -> None: except ImportError as e: print(f"Warning: Failed to import module {module_name}: {str(e)}") + def _extract_classes_from_module(module: ModuleType) -> None: for name, obj in inspect.getmembers(module): - if (inspect.isclass(obj) and - obj.__module__ == module.__name__ and not name.startswith("_")): + if inspect.isclass(obj) and obj.__module__ == module.__name__ and not name.startswith("_"): __all_classes[name] = obj globals()[name] = obj + _import_submodules() __all__ = list(__all_classes.keys()) + ["__all_classes"] -del _package_path, _import_submodules, _extract_classes_from_module \ No newline at end of file +del _package_path, _import_submodules, _extract_classes_from_module diff --git a/agent/tools/arxiv.py b/agent/tools/arxiv.py index 10d502c56c..e09df5bf5c 100644 --- a/agent/tools/arxiv.py +++ b/agent/tools/arxiv.py @@ -28,7 +28,7 @@ class ArXivParam(ToolParamBase): """ def __init__(self): - self.meta:ToolMeta = { + self.meta: ToolMeta = { "name": "arxiv_search", "description": """arXiv is a free distribution service and an open-access archive for nearly 2.4 million scholarly articles in the fields of physics, mathematics, computer science, quantitative biology, quantitative finance, statistics, electrical engineering and systems science, and economics. Materials on this site are not peer-reviewed by arXiv.""", "parameters": { @@ -36,26 +36,20 @@ class ArXivParam(ToolParamBase): "type": "string", "description": "The search keywords to execute with arXiv. The keywords should be the most important words/terms(includes synonyms) from the original request.", "default": "{sys.query}", - "required": True + "required": True, } - } + }, } super().__init__() self.top_n = 12 - self.sort_by = 'submittedDate' + self.sort_by = "submittedDate" def check(self): self.check_positive_integer(self.top_n, "Top N") - self.check_valid_value(self.sort_by, "ArXiv Search Sort_by", - ['submittedDate', 'lastUpdatedDate', 'relevance']) + self.check_valid_value(self.sort_by, "ArXiv Search Sort_by", ["submittedDate", "lastUpdatedDate", "relevance"]) def get_input_form(self) -> dict[str, dict]: - return { - "query": { - "name": "Query", - "type": "line" - } - } + return {"query": {"name": "Query", "type": "line"}} class ArXiv(ToolBase, ABC): @@ -71,29 +65,20 @@ class ArXiv(ToolBase, ABC): return "" last_e = "" - for _ in range(self._param.max_retries+1): + for _ in range(self._param.max_retries + 1): if self.check_if_canceled("ArXiv processing"): return try: - sort_choices = {"relevance": arxiv.SortCriterion.Relevance, - "lastUpdatedDate": arxiv.SortCriterion.LastUpdatedDate, - 'submittedDate': arxiv.SortCriterion.SubmittedDate} + sort_choices = {"relevance": arxiv.SortCriterion.Relevance, "lastUpdatedDate": arxiv.SortCriterion.LastUpdatedDate, "submittedDate": arxiv.SortCriterion.SubmittedDate} arxiv_client = arxiv.Client() - search = arxiv.Search( - query=kwargs["query"], - max_results=self._param.top_n, - sort_by=sort_choices[self._param.sort_by] - ) + search = arxiv.Search(query=kwargs["query"], max_results=self._param.top_n, sort_by=sort_choices[self._param.sort_by]) results = list(arxiv_client.results(search)) if self.check_if_canceled("ArXiv processing"): return - self._retrieve_chunks(results, - get_title=lambda r: r.title, - get_url=lambda r: r.pdf_url, - get_content=lambda r: r.summary) + self._retrieve_chunks(results, get_title=lambda r: r.title, get_url=lambda r: r.pdf_url, get_content=lambda r: r.summary) return self.output("formalized_content") except Exception as e: if self.check_if_canceled("ArXiv processing"): diff --git a/agent/tools/base.py b/agent/tools/base.py index 71cf2c593e..1cb1fa23fc 100644 --- a/agent/tools/base.py +++ b/agent/tools/base.py @@ -28,10 +28,9 @@ from common.mcp_tool_call_conn import MCPToolBinding, MCPToolCallSession, ToolCa from timeit import default_timer as timer - - from common.misc_utils import thread_pool_exec + class ToolParameter(TypedDict): type: str description: str @@ -81,7 +80,9 @@ class LLMToolPluginCallSession(ToolCallSession): resp = fallback_output else: resp = fallback_output - logging.warning(f"[ToolCall] resp is None, fallback to output name={name} output_keys={list(fallback_output.keys()) if isinstance(fallback_output, dict) else type(fallback_output).__name__}") + logging.warning( + f"[ToolCall] resp is None, fallback to output name={name} output_keys={list(fallback_output.keys()) if isinstance(fallback_output, dict) else type(fallback_output).__name__}" + ) except Exception as e: logging.warning(f"[ToolCall] resp is None and output fallback failed name={name} err={e}") @@ -96,28 +97,25 @@ class LLMToolPluginCallSession(ToolCallSession): class ToolParamBase(ComponentParamBase): def __init__(self): - #self.meta:ToolMeta = None + # self.meta:ToolMeta = None super().__init__() self._init_inputs() self._init_attr_by_meta() def _init_inputs(self): self.inputs = {} - for k,p in self.meta["parameters"].items(): + for k, p in self.meta["parameters"].items(): self.inputs[k] = deepcopy(p) def _init_attr_by_meta(self): - for k,p in self.meta["parameters"].items(): + for k, p in self.meta["parameters"].items(): if not hasattr(self, k): setattr(self, k, p.get("default")) def get_meta(self): params = {} for k, p in self.meta["parameters"].items(): - params[k] = { - "type": p["type"], - "description": p["description"] - } + params[k] = {"type": p["type"], "description": p["description"]} if "enum" in p: params[k]["enum"] = p["enum"] @@ -129,12 +127,8 @@ class ToolParamBase(ComponentParamBase): "function": { "name": function_name, "description": desc, - "parameters": { - "type": "object", - "properties": params, - "required": [k for k, p in self.meta["parameters"].items() if p["required"]] - } - } + "parameters": {"type": "object", "properties": params, "required": [k for k, p in self.meta["parameters"].items() if p["required"]]}, + }, } @@ -209,20 +203,8 @@ class ToolBase(ComponentBase): title = get_title(r) url = get_url(r) score = get_score(r) if get_score else 1 - chunks.append({ - "chunk_id": id, - "content": content, - "doc_id": id, - "docnm_kwd": title, - "similarity": score, - "url": url - }) - aggs.append({ - "doc_name": title, - "doc_id": id, - "count": 1, - "url": url - }) + chunks.append({"chunk_id": id, "content": content, "doc_id": id, "docnm_kwd": title, "similarity": score, "url": url}) + aggs.append({"doc_name": title, "doc_id": id, "count": 1, "url": url}) self._canvas.add_reference(chunks, aggs) self.set_output("formalized_content", "\n".join(kb_prompt({"chunks": chunks, "doc_aggs": aggs}, 200000, True))) diff --git a/agent/tools/deepl.py b/agent/tools/deepl.py index 272347657a..037648a5ff 100644 --- a/agent/tools/deepl.py +++ b/agent/tools/deepl.py @@ -27,18 +27,53 @@ class DeepLParam(ComponentParamBase): super().__init__() self.auth_key = "xxx" self.parameters = [] - self.source_lang = 'ZH' - self.target_lang = 'EN-GB' + self.source_lang = "ZH" + self.target_lang = "EN-GB" def check(self): - self.check_valid_value(self.source_lang, "Source language", - ['AR', 'BG', 'CS', 'DA', 'DE', 'EL', 'EN', 'ES', 'ET', 'FI', 'FR', 'HU', 'ID', 'IT', - 'JA', 'KO', 'LT', 'LV', 'NB', 'NL', 'PL', 'PT', 'RO', 'RU', 'SK', 'SL', 'SV', 'TR', - 'UK', 'ZH']) - self.check_valid_value(self.target_lang, "Target language", - ['AR', 'BG', 'CS', 'DA', 'DE', 'EL', 'EN-GB', 'EN-US', 'ES', 'ET', 'FI', 'FR', 'HU', - 'ID', 'IT', 'JA', 'KO', 'LT', 'LV', 'NB', 'NL', 'PL', 'PT-BR', 'PT-PT', 'RO', 'RU', - 'SK', 'SL', 'SV', 'TR', 'UK', 'ZH']) + self.check_valid_value( + self.source_lang, + "Source language", + ["AR", "BG", "CS", "DA", "DE", "EL", "EN", "ES", "ET", "FI", "FR", "HU", "ID", "IT", "JA", "KO", "LT", "LV", "NB", "NL", "PL", "PT", "RO", "RU", "SK", "SL", "SV", "TR", "UK", "ZH"], + ) + self.check_valid_value( + self.target_lang, + "Target language", + [ + "AR", + "BG", + "CS", + "DA", + "DE", + "EL", + "EN-GB", + "EN-US", + "ES", + "ET", + "FI", + "FR", + "HU", + "ID", + "IT", + "JA", + "KO", + "LT", + "LV", + "NB", + "NL", + "PL", + "PT-BR", + "PT-PT", + "RO", + "RU", + "SK", + "SL", + "SV", + "TR", + "UK", + "ZH", + ], + ) class DeepL(ComponentBase, ABC): @@ -57,8 +92,7 @@ class DeepL(ComponentBase, ABC): try: translator = deepl.Translator(self._param.auth_key) - result = translator.translate_text(ans, source_lang=self._param.source_lang, - target_lang=self._param.target_lang) + result = translator.translate_text(ans, source_lang=self._param.source_lang, target_lang=self._param.target_lang) return DeepL.be_output(result.text) except Exception as e: diff --git a/agent/tools/duckduckgo.py b/agent/tools/duckduckgo.py index fd2ec1801b..bd072bd242 100644 --- a/agent/tools/duckduckgo.py +++ b/agent/tools/duckduckgo.py @@ -28,7 +28,7 @@ class DuckDuckGoParam(ToolParamBase): """ def __init__(self): - self.meta:ToolMeta = { + self.meta: ToolMeta = { "name": "duckduckgo_search", "description": "DuckDuckGo is a search engine focused on privacy. It offers search capabilities for web pages, images, and provides translation services. DuckDuckGo also features a private AI chat interface, providing users with an AI assistant that prioritizes data protection.", "parameters": { @@ -36,7 +36,7 @@ class DuckDuckGoParam(ToolParamBase): "type": "string", "description": "The search keywords to execute with DuckDuckGo. The keywords should be the most important words/terms(includes synonyms) from the original request.", "default": "{sys.query}", - "required": True + "required": True, }, "channel": { "type": "string", @@ -45,7 +45,7 @@ class DuckDuckGoParam(ToolParamBase): "default": "general", "required": False, }, - } + }, } super().__init__() self.top_n = 10 @@ -56,18 +56,7 @@ class DuckDuckGoParam(ToolParamBase): self.check_valid_value(self.channel, "Web Search or News", ["text", "news"]) def get_input_form(self) -> dict[str, dict]: - return { - "query": { - "name": "Query", - "type": "line" - }, - "channel": { - "name": "Channel", - "type": "options", - "value": "general", - "options": ["general", "news"] - } - } + return {"query": {"name": "Query", "type": "line"}, "channel": {"name": "Channel", "type": "options", "value": "general", "options": ["general", "news"]}} class DuckDuckGo(ToolBase, ABC): @@ -83,7 +72,7 @@ class DuckDuckGo(ToolBase, ABC): return "" last_e = "" - for _ in range(self._param.max_retries+1): + for _ in range(self._param.max_retries + 1): if self.check_if_canceled("DuckDuckGo processing"): return @@ -99,10 +88,7 @@ class DuckDuckGo(ToolBase, ABC): if self.check_if_canceled("DuckDuckGo processing"): return - self._retrieve_chunks(duck_res, - get_title=lambda r: r["title"], - get_url=lambda r: r.get("href", r.get("url")), - get_content=lambda r: r["body"]) + self._retrieve_chunks(duck_res, get_title=lambda r: r["title"], get_url=lambda r: r.get("href", r.get("url")), get_content=lambda r: r["body"]) self.set_output("json", duck_res) return self.output("formalized_content") else: @@ -116,10 +102,7 @@ class DuckDuckGo(ToolBase, ABC): if self.check_if_canceled("DuckDuckGo processing"): return - self._retrieve_chunks(duck_res, - get_title=lambda r: r["title"], - get_url=lambda r: r.get("href", r.get("url")), - get_content=lambda r: r["body"]) + self._retrieve_chunks(duck_res, get_title=lambda r: r["title"], get_url=lambda r: r.get("href", r.get("url")), get_content=lambda r: r["body"]) self.set_output("json", duck_res) return self.output("formalized_content") except Exception as e: diff --git a/agent/tools/email.py b/agent/tools/email.py index aa563cf9cc..45fd355ff1 100644 --- a/agent/tools/email.py +++ b/agent/tools/email.py @@ -32,36 +32,17 @@ class EmailParam(ToolParamBase): """ Define the Email component parameters. """ + def __init__(self): - self.meta:ToolMeta = { + self.meta: ToolMeta = { "name": "email", "description": "The email is a method of electronic communication for sending and receiving information through the Internet. This tool helps users to send emails to one person or to multiple recipients with support for CC, BCC, file attachments, and markdown-to-HTML conversion.", "parameters": { - "to_email": { - "type": "string", - "description": "The target email address.", - "default": "{sys.query}", - "required": True - }, - "cc_email": { - "type": "string", - "description": "The other email addresses needs to be send to. Comma splited.", - "default": "", - "required": False - }, - "content": { - "type": "string", - "description": "The content of the email.", - "default": "", - "required": False - }, - "subject": { - "type": "string", - "description": "The subject/title of the email.", - "default": "", - "required": False - } - } + "to_email": {"type": "string", "description": "The target email address.", "default": "{sys.query}", "required": True}, + "cc_email": {"type": "string", "description": "The other email addresses needs to be send to. Comma splited.", "default": "", "required": False}, + "content": {"type": "string", "description": "The content of the email.", "default": "", "required": False}, + "subject": {"type": "string", "description": "The subject/title of the email.", "default": "", "required": False}, + }, } super().__init__() # Fixed configuration parameters @@ -81,20 +62,9 @@ class EmailParam(ToolParamBase): def get_input_form(self) -> dict[str, dict]: return { - "to_email": { - "name": "To ", - "type": "line" - }, - "subject": { - "name": "Subject", - "type": "line", - "optional": True - }, - "cc_email": { - "name": "CC To", - "type": "line", - "optional": True - }, + "to_email": {"name": "To ", "type": "line"}, + "subject": {"name": "Subject", "type": "line", "optional": True}, + "cc_email": {"name": "CC To", "type": "line", "optional": True}, } @@ -111,7 +81,7 @@ class Email(ToolBase, ABC): return "" last_e = "" - for _ in range(self._param.max_retries+1): + for _ in range(self._param.max_retries + 1): if self.check_if_canceled("Email processing"): return @@ -126,19 +96,19 @@ class Email(ToolBase, ABC): return False # Create email object - msg = MIMEMultipart('alternative') + msg = MIMEMultipart("alternative") # Properly handle sender name encoding - msg['From'] = formataddr((str(Header(self._param.sender_name,'utf-8')), self._param.email)) - msg['To'] = email_data["to_email"] + msg["From"] = formataddr((str(Header(self._param.sender_name, "utf-8")), self._param.email)) + msg["To"] = email_data["to_email"] if email_data.get("cc_email"): - msg['Cc'] = email_data["cc_email"] - msg['Subject'] = Header(email_data.get("subject", "No Subject"), 'utf-8').encode() + msg["Cc"] = email_data["cc_email"] + msg["Subject"] = Header(email_data.get("subject", "No Subject"), "utf-8").encode() # Use content from email_data or default content email_content = email_data.get("content", "No content provided") # msg.attach(MIMEText(email_content, 'plain', 'utf-8')) - msg.attach(MIMEText(email_content, 'html', 'utf-8')) + msg.attach(MIMEText(email_content, "html", "utf-8")) # Connect to SMTP server and send logging.info(f"Connecting to SMTP server {self._param.smtp_server}:{self._param.smtp_port}") @@ -160,7 +130,7 @@ class Email(ToolBase, ABC): # Get all recipient list recipients = [email_data["to_email"]] if email_data.get("cc_email"): - recipients.extend(email_data["cc_email"].split(',')) + recipients.extend(email_data["cc_email"].split(",")) # Send email logging.info(f"Sending email to recipients: {recipients}") diff --git a/agent/tools/exesql.py b/agent/tools/exesql.py index 86935cc49d..11359755b4 100644 --- a/agent/tools/exesql.py +++ b/agent/tools/exesql.py @@ -34,17 +34,10 @@ class ExeSQLParam(ToolParamBase): """ def __init__(self): - self.meta:ToolMeta = { + self.meta: ToolMeta = { "name": "execute_sql", "description": "This is a tool that can execute SQL.", - "parameters": { - "sql": { - "type": "string", - "description": "The SQL needs to be executed.", - "default": "{sys.query}", - "required": True - } - } + "parameters": {"sql": {"type": "string", "description": "The SQL needs to be executed.", "default": "{sys.query}", "required": True}}, } super().__init__() self.db_type = "mysql" @@ -56,7 +49,7 @@ class ExeSQLParam(ToolParamBase): self.max_records = 1024 def check(self): - self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2', 'trino', 'oceanbase']) + self.check_valid_value(self.db_type, "Choose DB type", ["mysql", "postgres", "mariadb", "mssql", "IBM DB2", "trino", "oceanbase"]) self.check_empty(self.database, "Database name") self.check_empty(self.username, "database username") self.check_empty(self.host, "IP Address") @@ -71,12 +64,7 @@ class ExeSQLParam(ToolParamBase): raise ValueError("For the security reason, it does not support database named rag_flow.") def get_input_form(self) -> dict[str, dict]: - return { - "sql": { - "name": "SQL", - "type": "line" - } - } + return {"sql": {"name": "SQL", "type": "line"}} class ExeSQL(ToolBase, ABC): @@ -90,6 +78,7 @@ class ExeSQL(ToolBase, ABC): def convert_decimals(obj): from decimal import Decimal import math + if isinstance(obj, float): # Handle NaN and Infinity which are not valid JSON values if math.isnan(obj) or math.isinf(obj): @@ -140,24 +129,21 @@ class ExeSQL(ToolBase, ABC): sqls = sql.split(";") if self._param.db_type in ["mysql", "mariadb"]: - db = pymysql.connect(db=self._param.database, user=self._param.username, host=safe_host, - port=self._param.port, password=self._param.password) - elif self._param.db_type == 'oceanbase': - db = pymysql.connect(db=self._param.database, user=self._param.username, host=safe_host, - port=self._param.port, password=self._param.password, charset='utf8mb4') - elif self._param.db_type == 'postgres': - db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=safe_host, - port=self._param.port, password=self._param.password) - elif self._param.db_type == 'mssql': + db = pymysql.connect(db=self._param.database, user=self._param.username, host=safe_host, port=self._param.port, password=self._param.password) + elif self._param.db_type == "oceanbase": + db = pymysql.connect(db=self._param.database, user=self._param.username, host=safe_host, port=self._param.port, password=self._param.password, charset="utf8mb4") + elif self._param.db_type == "postgres": + db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=safe_host, port=self._param.port, password=self._param.password) + elif self._param.db_type == "mssql": conn_str = ( - r'DRIVER={ODBC Driver 17 for SQL Server};' - r'SERVER=' + safe_host + ',' + str(self._param.port) + ';' - r'DATABASE=' + self._param.database + ';' - r'UID=' + self._param.username + ';' - r'PWD=' + self._param.password + r"DRIVER={ODBC Driver 17 for SQL Server};" + r"SERVER=" + safe_host + "," + str(self._param.port) + ";" + r"DATABASE=" + self._param.database + ";" + r"UID=" + self._param.username + ";" + r"PWD=" + self._param.password ) db = pyodbc.connect(conn_str) - elif self._param.db_type == 'trino': + elif self._param.db_type == "trino": try: import trino from trino.auth import BasicAuthentication @@ -186,26 +172,14 @@ class ExeSQL(ToolBase, ABC): try: db = trino.dbapi.connect( - host=safe_host, - port=int(self._param.port or 8080), - user=self._param.username or "ragflow", - catalog=catalog, - schema=schema or "default", - http_scheme=http_scheme, - auth=auth + host=safe_host, port=int(self._param.port or 8080), user=self._param.username or "ragflow", catalog=catalog, schema=schema or "default", http_scheme=http_scheme, auth=auth ) except Exception as e: raise Exception("Database Connection Failed! \n" + str(e)) - elif self._param.db_type == 'IBM DB2': + elif self._param.db_type == "IBM DB2": import ibm_db - conn_str = ( - f"DATABASE={self._param.database};" - f"HOSTNAME={safe_host};" - f"PORT={self._param.port};" - f"PROTOCOL=TCPIP;" - f"UID={self._param.username};" - f"PWD={self._param.password};" - ) + + conn_str = f"DATABASE={self._param.database};HOSTNAME={safe_host};PORT={self._param.port};PROTOCOL=TCPIP;UID={self._param.username};PWD={self._param.password};" try: conn = ibm_db.connect(conn_str, "", "") except Exception as e: @@ -275,7 +249,7 @@ class ExeSQL(ToolBase, ABC): if self.check_if_canceled("ExeSQL processing"): return - single_sql = single_sql.replace('```', '').strip() + single_sql = single_sql.replace("```", "").strip() if not single_sql: continue single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql) @@ -288,20 +262,19 @@ class ExeSQL(ToolBase, ABC): if cursor.rowcount == 0: sql_res.append({"content": "No record in the database!"}) break - if self._param.db_type == 'mssql': - single_res = pd.DataFrame.from_records(cursor.fetchmany(self._param.max_records), - columns=[desc[0] for desc in cursor.description]) + if self._param.db_type == "mssql": + single_res = pd.DataFrame.from_records(cursor.fetchmany(self._param.max_records), columns=[desc[0] for desc in cursor.description]) else: single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.max_records)]) single_res.columns = [i[0] for i in cursor.description] for col in single_res.columns: if pd.api.types.is_datetime64_any_dtype(single_res[col]): - single_res[col] = single_res[col].dt.strftime('%Y-%m-%d') + single_res[col] = single_res[col].dt.strftime("%Y-%m-%d") single_res = single_res.where(pd.notnull(single_res), None) - sql_res.append(convert_decimals(single_res.to_dict(orient='records'))) + sql_res.append(convert_decimals(single_res.to_dict(orient="records"))) formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f")) except Exception as e: # A failing statement must not abort the node: report it and keep diff --git a/agent/tools/github.py b/agent/tools/github.py index 4a95ac366a..614e558d5b 100644 --- a/agent/tools/github.py +++ b/agent/tools/github.py @@ -29,7 +29,7 @@ class GitHubParam(ToolParamBase): """ def __init__(self): - self.meta:ToolMeta = { + self.meta: ToolMeta = { "name": "github_search", "description": """GitHub repository search is a feature that enables users to find specific repositories on the GitHub platform. This search functionality allows users to locate projects, codebases, and other content hosted on GitHub based on various criteria.""", "parameters": { @@ -37,9 +37,9 @@ class GitHubParam(ToolParamBase): "type": "string", "description": "The search keywords to execute with GitHub. The keywords should be the most important words/terms(includes synonyms) from the original request.", "default": "{sys.query}", - "required": True + "required": True, } - } + }, } super().__init__() self.top_n = 10 @@ -48,12 +48,8 @@ class GitHubParam(ToolParamBase): self.check_positive_integer(self.top_n, "Top N") def get_input_form(self) -> dict[str, dict]: - return { - "query": { - "name": "Query", - "type": "line" - } - } + return {"query": {"name": "Query", "type": "line"}} + class GitHub(ToolBase, ABC): component_name = "GitHub" @@ -68,24 +64,20 @@ class GitHub(ToolBase, ABC): return "" last_e = "" - for _ in range(self._param.max_retries+1): + for _ in range(self._param.max_retries + 1): if self.check_if_canceled("GitHub processing"): return try: - url = 'https://api.github.com/search/repositories?q=' + kwargs["query"] + '&sort=stars&order=desc&per_page=' + str( - self._param.top_n) - headers = {"Content-Type": "application/vnd.github+json", "X-GitHub-Api-Version": '2022-11-28'} + url = "https://api.github.com/search/repositories?q=" + kwargs["query"] + "&sort=stars&order=desc&per_page=" + str(self._param.top_n) + headers = {"Content-Type": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"} response = requests.get(url=url, headers=headers, timeout=DEFAULT_TIMEOUT).json() if self.check_if_canceled("GitHub processing"): return - self._retrieve_chunks(response['items'], - get_title=lambda r: r["name"], - get_url=lambda r: r["html_url"], - get_content=lambda r: str(r["description"]) + '\n stars:' + str(r['watchers'])) - self.set_output("json", response['items']) + self._retrieve_chunks(response["items"], get_title=lambda r: r["name"], get_url=lambda r: r["html_url"], get_content=lambda r: str(r["description"]) + "\n stars:" + str(r["watchers"])) + self.set_output("json", response["items"]) return self.output("formalized_content") except Exception as e: if self.check_if_canceled("GitHub processing"): diff --git a/agent/tools/google.py b/agent/tools/google.py index 312b5a1fe3..a5fb4bdf29 100644 --- a/agent/tools/google.py +++ b/agent/tools/google.py @@ -28,7 +28,7 @@ class GoogleParam(ToolParamBase): """ def __init__(self): - self.meta:ToolMeta = { + self.meta: ToolMeta = { "name": "google_search", "description": """Search the world's information, including webpages, images, videos and more. Google has many special features to help you find exactly what you're looking ...""", "parameters": { @@ -36,7 +36,7 @@ class GoogleParam(ToolParamBase): "type": "string", "description": "The search keywords to execute with Google. The keywords should be the most important words/terms(includes synonyms) from the original request.", "default": "{sys.query}", - "required": True + "required": True, }, "start": { "type": "integer", @@ -49,8 +49,8 @@ class GoogleParam(ToolParamBase): "description": "Parameter defines the maximum number of results to return. (e.g., 10 (default) returns 10 results, 40 returns 40 results, and 100 returns 100 results). The use of num may introduce latency, and/or prevent the inclusion of specialized result types. It is better to omit this parameter unless it is strictly necessary to increase the number of results per page. Results are not guaranteed to have the number of results specified in num.", "default": "6", "required": False, - } - } + }, + }, } super().__init__() self.start = 0 @@ -61,57 +61,418 @@ class GoogleParam(ToolParamBase): def check(self): self.check_empty(self.api_key, "SerpApi API key") - self.check_valid_value(self.country, "Google Country", - ['af', 'al', 'dz', 'as', 'ad', 'ao', 'ai', 'aq', 'ag', 'ar', 'am', 'aw', 'au', 'at', - 'az', 'bs', 'bh', 'bd', 'bb', 'by', 'be', 'bz', 'bj', 'bm', 'bt', 'bo', 'ba', 'bw', - 'bv', 'br', 'io', 'bn', 'bg', 'bf', 'bi', 'kh', 'cm', 'ca', 'cv', 'ky', 'cf', 'td', - 'cl', 'cn', 'cx', 'cc', 'co', 'km', 'cg', 'cd', 'ck', 'cr', 'ci', 'hr', 'cu', 'cy', - 'cz', 'dk', 'dj', 'dm', 'do', 'ec', 'eg', 'sv', 'gq', 'er', 'ee', 'et', 'fk', 'fo', - 'fj', 'fi', 'fr', 'gf', 'pf', 'tf', 'ga', 'gm', 'ge', 'de', 'gh', 'gi', 'gr', 'gl', - 'gd', 'gp', 'gu', 'gt', 'gn', 'gw', 'gy', 'ht', 'hm', 'va', 'hn', 'hk', 'hu', 'is', - 'in', 'id', 'ir', 'iq', 'ie', 'il', 'it', 'jm', 'jp', 'jo', 'kz', 'ke', 'ki', 'kp', - 'kr', 'kw', 'kg', 'la', 'lv', 'lb', 'ls', 'lr', 'ly', 'li', 'lt', 'lu', 'mo', 'mk', - 'mg', 'mw', 'my', 'mv', 'ml', 'mt', 'mh', 'mq', 'mr', 'mu', 'yt', 'mx', 'fm', 'md', - 'mc', 'mn', 'ms', 'ma', 'mz', 'mm', 'na', 'nr', 'np', 'nl', 'an', 'nc', 'nz', 'ni', - 'ne', 'ng', 'nu', 'nf', 'mp', 'no', 'om', 'pk', 'pw', 'ps', 'pa', 'pg', 'py', 'pe', - 'ph', 'pn', 'pl', 'pt', 'pr', 'qa', 're', 'ro', 'ru', 'rw', 'sh', 'kn', 'lc', 'pm', - 'vc', 'ws', 'sm', 'st', 'sa', 'sn', 'rs', 'sc', 'sl', 'sg', 'sk', 'si', 'sb', 'so', - 'za', 'gs', 'es', 'lk', 'sd', 'sr', 'sj', 'sz', 'se', 'ch', 'sy', 'tw', 'tj', 'tz', - 'th', 'tl', 'tg', 'tk', 'to', 'tt', 'tn', 'tr', 'tm', 'tc', 'tv', 'ug', 'ua', 'ae', - 'uk', 'gb', 'us', 'um', 'uy', 'uz', 'vu', 've', 'vn', 'vg', 'vi', 'wf', 'eh', 'ye', - 'zm', 'zw']) - self.check_valid_value(self.language, "Google languages", - ['af', 'ak', 'sq', 'ws', 'am', 'ar', 'hy', 'az', 'eu', 'be', 'bem', 'bn', 'bh', - 'xx-bork', 'bs', 'br', 'bg', 'bt', 'km', 'ca', 'chr', 'ny', 'zh-cn', 'zh-tw', 'co', - 'hr', 'cs', 'da', 'nl', 'xx-elmer', 'en', 'eo', 'et', 'ee', 'fo', 'tl', 'fi', 'fr', - 'fy', 'gaa', 'gl', 'ka', 'de', 'el', 'kl', 'gn', 'gu', 'xx-hacker', 'ht', 'ha', 'haw', - 'iw', 'hi', 'hu', 'is', 'ig', 'id', 'ia', 'ga', 'it', 'ja', 'jw', 'kn', 'kk', 'rw', - 'rn', 'xx-klingon', 'kg', 'ko', 'kri', 'ku', 'ckb', 'ky', 'lo', 'la', 'lv', 'ln', 'lt', - 'loz', 'lg', 'ach', 'mk', 'mg', 'ms', 'ml', 'mt', 'mv', 'mi', 'mr', 'mfe', 'mo', 'mn', - 'sr-me', 'my', 'ne', 'pcm', 'nso', 'no', 'nn', 'oc', 'or', 'om', 'ps', 'fa', - 'xx-pirate', 'pl', 'pt', 'pt-br', 'pt-pt', 'pa', 'qu', 'ro', 'rm', 'nyn', 'ru', 'gd', - 'sr', 'sh', 'st', 'tn', 'crs', 'sn', 'sd', 'si', 'sk', 'sl', 'so', 'es', 'es-419', 'su', - 'sw', 'sv', 'tg', 'ta', 'tt', 'te', 'th', 'ti', 'to', 'lua', 'tum', 'tr', 'tk', 'tw', - 'ug', 'uk', 'ur', 'uz', 'vu', 'vi', 'cy', 'wo', 'xh', 'yi', 'yo', 'zu'] - ) + self.check_valid_value( + self.country, + "Google Country", + [ + "af", + "al", + "dz", + "as", + "ad", + "ao", + "ai", + "aq", + "ag", + "ar", + "am", + "aw", + "au", + "at", + "az", + "bs", + "bh", + "bd", + "bb", + "by", + "be", + "bz", + "bj", + "bm", + "bt", + "bo", + "ba", + "bw", + "bv", + "br", + "io", + "bn", + "bg", + "bf", + "bi", + "kh", + "cm", + "ca", + "cv", + "ky", + "cf", + "td", + "cl", + "cn", + "cx", + "cc", + "co", + "km", + "cg", + "cd", + "ck", + "cr", + "ci", + "hr", + "cu", + "cy", + "cz", + "dk", + "dj", + "dm", + "do", + "ec", + "eg", + "sv", + "gq", + "er", + "ee", + "et", + "fk", + "fo", + "fj", + "fi", + "fr", + "gf", + "pf", + "tf", + "ga", + "gm", + "ge", + "de", + "gh", + "gi", + "gr", + "gl", + "gd", + "gp", + "gu", + "gt", + "gn", + "gw", + "gy", + "ht", + "hm", + "va", + "hn", + "hk", + "hu", + "is", + "in", + "id", + "ir", + "iq", + "ie", + "il", + "it", + "jm", + "jp", + "jo", + "kz", + "ke", + "ki", + "kp", + "kr", + "kw", + "kg", + "la", + "lv", + "lb", + "ls", + "lr", + "ly", + "li", + "lt", + "lu", + "mo", + "mk", + "mg", + "mw", + "my", + "mv", + "ml", + "mt", + "mh", + "mq", + "mr", + "mu", + "yt", + "mx", + "fm", + "md", + "mc", + "mn", + "ms", + "ma", + "mz", + "mm", + "na", + "nr", + "np", + "nl", + "an", + "nc", + "nz", + "ni", + "ne", + "ng", + "nu", + "nf", + "mp", + "no", + "om", + "pk", + "pw", + "ps", + "pa", + "pg", + "py", + "pe", + "ph", + "pn", + "pl", + "pt", + "pr", + "qa", + "re", + "ro", + "ru", + "rw", + "sh", + "kn", + "lc", + "pm", + "vc", + "ws", + "sm", + "st", + "sa", + "sn", + "rs", + "sc", + "sl", + "sg", + "sk", + "si", + "sb", + "so", + "za", + "gs", + "es", + "lk", + "sd", + "sr", + "sj", + "sz", + "se", + "ch", + "sy", + "tw", + "tj", + "tz", + "th", + "tl", + "tg", + "tk", + "to", + "tt", + "tn", + "tr", + "tm", + "tc", + "tv", + "ug", + "ua", + "ae", + "uk", + "gb", + "us", + "um", + "uy", + "uz", + "vu", + "ve", + "vn", + "vg", + "vi", + "wf", + "eh", + "ye", + "zm", + "zw", + ], + ) + self.check_valid_value( + self.language, + "Google languages", + [ + "af", + "ak", + "sq", + "ws", + "am", + "ar", + "hy", + "az", + "eu", + "be", + "bem", + "bn", + "bh", + "xx-bork", + "bs", + "br", + "bg", + "bt", + "km", + "ca", + "chr", + "ny", + "zh-cn", + "zh-tw", + "co", + "hr", + "cs", + "da", + "nl", + "xx-elmer", + "en", + "eo", + "et", + "ee", + "fo", + "tl", + "fi", + "fr", + "fy", + "gaa", + "gl", + "ka", + "de", + "el", + "kl", + "gn", + "gu", + "xx-hacker", + "ht", + "ha", + "haw", + "iw", + "hi", + "hu", + "is", + "ig", + "id", + "ia", + "ga", + "it", + "ja", + "jw", + "kn", + "kk", + "rw", + "rn", + "xx-klingon", + "kg", + "ko", + "kri", + "ku", + "ckb", + "ky", + "lo", + "la", + "lv", + "ln", + "lt", + "loz", + "lg", + "ach", + "mk", + "mg", + "ms", + "ml", + "mt", + "mv", + "mi", + "mr", + "mfe", + "mo", + "mn", + "sr-me", + "my", + "ne", + "pcm", + "nso", + "no", + "nn", + "oc", + "or", + "om", + "ps", + "fa", + "xx-pirate", + "pl", + "pt", + "pt-br", + "pt-pt", + "pa", + "qu", + "ro", + "rm", + "nyn", + "ru", + "gd", + "sr", + "sh", + "st", + "tn", + "crs", + "sn", + "sd", + "si", + "sk", + "sl", + "so", + "es", + "es-419", + "su", + "sw", + "sv", + "tg", + "ta", + "tt", + "te", + "th", + "ti", + "to", + "lua", + "tum", + "tr", + "tk", + "tw", + "ug", + "uk", + "ur", + "uz", + "vu", + "vi", + "cy", + "wo", + "xh", + "yi", + "yo", + "zu", + ], + ) def get_input_form(self) -> dict[str, dict]: - return { - "q": { - "name": "Query", - "type": "line" - }, - "start": { - "name": "From", - "type": "integer", - "value": 0 - }, - "num": { - "name": "Limit", - "type": "integer", - "value": 12 - } - } + return {"q": {"name": "Query", "type": "line"}, "start": {"name": "From", "type": "integer", "value": 0}, "num": {"name": "Limit", "type": "integer", "value": 12}} + class Google(ToolBase, ABC): component_name = "Google" @@ -125,16 +486,9 @@ class Google(ToolBase, ABC): self.set_output("formalized_content", "") return "" - params = { - "api_key": self._param.api_key, - "engine": "google", - "q": kwargs["q"], - "google_domain": "google.com", - "gl": self._param.country, - "hl": self._param.language - } + params = {"api_key": self._param.api_key, "engine": "google", "q": kwargs["q"], "google_domain": "google.com", "gl": self._param.country, "hl": self._param.language} last_e = "" - for _ in range(self._param.max_retries+1): + for _ in range(self._param.max_retries + 1): if self.check_if_canceled("Google processing"): return @@ -144,11 +498,12 @@ class Google(ToolBase, ABC): if self.check_if_canceled("Google processing"): return - self._retrieve_chunks(search["organic_results"], - get_title=lambda r: r["title"], - get_url=lambda r: r["link"], - get_content=lambda r: r.get("about_this_result", {}).get("source", {}).get("description", r["snippet"]) - ) + self._retrieve_chunks( + search["organic_results"], + get_title=lambda r: r["title"], + get_url=lambda r: r["link"], + get_content=lambda r: r.get("about_this_result", {}).get("source", {}).get("description", r["snippet"]), + ) self.set_output("json", search["organic_results"]) return self.output("formalized_content") except Exception as e: diff --git a/agent/tools/googlescholar.py b/agent/tools/googlescholar.py index 8196304ee8..bc7f01a919 100644 --- a/agent/tools/googlescholar.py +++ b/agent/tools/googlescholar.py @@ -29,7 +29,7 @@ class GoogleScholarParam(ToolParamBase): """ def __init__(self): - self.meta:ToolMeta = { + self.meta: ToolMeta = { "name": "google_scholar_search", "description": """Google Scholar provides a simple way to broadly search for scholarly literature. From one place, you can search across many disciplines and sources: articles, theses, books, abstracts and court opinions, from academic publishers, professional societies, online repositories, universities and other web sites. Google Scholar helps you find relevant work across the world of scholarly research.""", "parameters": { @@ -37,29 +37,25 @@ class GoogleScholarParam(ToolParamBase): "type": "string", "description": "The search keyword to execute with Google Scholar. The keywords should be the most important words/terms(includes synonyms) from the original request.", "default": "{sys.query}", - "required": True + "required": True, } - } + }, } super().__init__() self.top_n = 12 - self.sort_by = 'relevance' + self.sort_by = "relevance" self.year_low = None self.year_high = None self.patents = True def check(self): self.check_positive_integer(self.top_n, "Top N") - self.check_valid_value(self.sort_by, "GoogleScholar Sort_by", ['date', 'relevance']) + self.check_valid_value(self.sort_by, "GoogleScholar Sort_by", ["date", "relevance"]) self.check_boolean(self.patents, "Whether or not to include patents, defaults to True") def get_input_form(self) -> dict[str, dict]: - return { - "query": { - "name": "Query", - "type": "line" - } - } + return {"query": {"name": "Query", "type": "line"}} + class GoogleScholar(ToolBase, ABC): component_name = "GoogleScholar" @@ -77,13 +73,12 @@ class GoogleScholar(ToolBase, ABC): return "" last_e = "" - for _ in range(self._param.max_retries+1): + for _ in range(self._param.max_retries + 1): if self.check_if_canceled("GoogleScholar processing"): return try: - scholar_client = scholarly.search_pubs(kwargs["query"], patents=self._param.patents, year_low=self._param.year_low, - year_high=self._param.year_high, sort_by=self._param.sort_by) + scholar_client = scholarly.search_pubs(kwargs["query"], patents=self._param.patents, year_low=self._param.year_low, year_high=self._param.year_high, sort_by=self._param.sort_by) if self.check_if_canceled("GoogleScholar processing"): return @@ -94,11 +89,12 @@ class GoogleScholar(ToolBase, ABC): # would otherwise leave json empty). results = list(itertools.islice(scholar_client, self._param.top_n)) - self._retrieve_chunks(results, - get_title=lambda r: r['bib']['title'], - get_url=lambda r: r["pub_url"], - get_content=lambda r: "\n author: " + ",".join(r['bib']['author']) + '\n Abstract: ' + r['bib'].get('abstract', 'no abstract') - ) + self._retrieve_chunks( + results, + get_title=lambda r: r["bib"]["title"], + get_url=lambda r: r["pub_url"], + get_content=lambda r: "\n author: " + ",".join(r["bib"]["author"]) + "\n Abstract: " + r["bib"].get("abstract", "no abstract"), + ) self.set_output("json", results) return self.output("formalized_content") except Exception as e: diff --git a/agent/tools/jin10.py b/agent/tools/jin10.py index a37249ca40..d0629155c8 100644 --- a/agent/tools/jin10.py +++ b/agent/tools/jin10.py @@ -30,21 +30,21 @@ class Jin10Param(ComponentParamBase): super().__init__() self.type = "flash" self.secret_key = "xxx" - self.flash_type = '1' - self.calendar_type = 'cj' - self.calendar_datatype = 'data' - self.symbols_type = 'GOODS' - self.symbols_datatype = 'symbols' + self.flash_type = "1" + self.calendar_type = "cj" + self.calendar_datatype = "data" + self.symbols_type = "GOODS" + self.symbols_datatype = "symbols" self.contain = "" self.filter = "" def check(self): - self.check_valid_value(self.type, "Type", ['flash', 'calendar', 'symbols', 'news']) - self.check_valid_value(self.flash_type, "Flash Type", ['1', '2', '3', '4', '5']) - self.check_valid_value(self.calendar_type, "Calendar Type", ['cj', 'qh', 'hk', 'us']) - self.check_valid_value(self.calendar_datatype, "Calendar DataType", ['data', 'event', 'holiday']) - self.check_valid_value(self.symbols_type, "Symbols Type", ['GOODS', 'FOREX', 'FUTURE', 'CRYPTO']) - self.check_valid_value(self.symbols_datatype, 'Symbols DataType', ['symbols', 'quotes']) + self.check_valid_value(self.type, "Type", ["flash", "calendar", "symbols", "news"]) + self.check_valid_value(self.flash_type, "Flash Type", ["1", "2", "3", "4", "5"]) + self.check_valid_value(self.calendar_type, "Calendar Type", ["cj", "qh", "hk", "us"]) + self.check_valid_value(self.calendar_datatype, "Calendar DataType", ["data", "event", "holiday"]) + self.check_valid_value(self.symbols_type, "Symbols Type", ["GOODS", "FOREX", "FUTURE", "CRYPTO"]) + self.check_valid_value(self.symbols_datatype, "Symbols DataType", ["symbols", "quotes"]) class Jin10(ComponentBase, ABC): @@ -60,86 +60,77 @@ class Jin10(ComponentBase, ABC): return Jin10.be_output("") jin10_res = [] - headers = {'secret-key': self._param.secret_key} + headers = {"secret-key": self._param.secret_key} try: if self.check_if_canceled("Jin10 processing"): return if self._param.type == "flash": - params = { - 'category': self._param.flash_type, - 'contain': self._param.contain, - 'filter': self._param.filter - } - response = requests.get( - url='https://open-data-api.jin10.com/data-api/flash?category=' + self._param.flash_type, - headers=headers, data=json.dumps(params), timeout=DEFAULT_TIMEOUT) + params = {"category": self._param.flash_type, "contain": self._param.contain, "filter": self._param.filter} + response = requests.get(url="https://open-data-api.jin10.com/data-api/flash?category=" + self._param.flash_type, headers=headers, data=json.dumps(params), timeout=DEFAULT_TIMEOUT) response = response.json() - for i in response['data']: + for i in response["data"]: if self.check_if_canceled("Jin10 processing"): return - jin10_res.append({"content": i['data']['content']}) + jin10_res.append({"content": i["data"]["content"]}) if self._param.type == "calendar": - params = { - 'category': self._param.calendar_type - } + params = {"category": self._param.calendar_type} response = requests.get( - url='https://open-data-api.jin10.com/data-api/calendar/' + self._param.calendar_datatype + '?category=' + self._param.calendar_type, - headers=headers, data=json.dumps(params), timeout=DEFAULT_TIMEOUT) + url="https://open-data-api.jin10.com/data-api/calendar/" + self._param.calendar_datatype + "?category=" + self._param.calendar_type, + headers=headers, + data=json.dumps(params), + timeout=DEFAULT_TIMEOUT, + ) response = response.json() if self.check_if_canceled("Jin10 processing"): return - jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()}) + jin10_res.append({"content": pd.DataFrame(response["data"]).to_markdown()}) if self._param.type == "symbols": - params = { - 'type': self._param.symbols_type - } + params = {"type": self._param.symbols_type} if self._param.symbols_datatype == "quotes": - params['codes'] = 'BTCUSD' + params["codes"] = "BTCUSD" response = requests.get( - url='https://open-data-api.jin10.com/data-api/' + self._param.symbols_datatype + '?type=' + self._param.symbols_type, - headers=headers, data=json.dumps(params), timeout=DEFAULT_TIMEOUT) + url="https://open-data-api.jin10.com/data-api/" + self._param.symbols_datatype + "?type=" + self._param.symbols_type, + headers=headers, + data=json.dumps(params), + timeout=DEFAULT_TIMEOUT, + ) response = response.json() if self.check_if_canceled("Jin10 processing"): return if self._param.symbols_datatype == "symbols": - for i in response['data']: + for i in response["data"]: if self.check_if_canceled("Jin10 processing"): return - i['Commodity Code'] = i['c'] - i['Stock Exchange'] = i['e'] - i['Commodity Name'] = i['n'] - i['Commodity Type'] = i['t'] - del i['c'], i['e'], i['n'], i['t'] + i["Commodity Code"] = i["c"] + i["Stock Exchange"] = i["e"] + i["Commodity Name"] = i["n"] + i["Commodity Type"] = i["t"] + del i["c"], i["e"], i["n"], i["t"] if self._param.symbols_datatype == "quotes": - for i in response['data']: + for i in response["data"]: if self.check_if_canceled("Jin10 processing"): return - i['Selling Price'] = i['a'] - i['Buying Price'] = i['b'] - i['Commodity Code'] = i['c'] - i['Stock Exchange'] = i['e'] - i['Highest Price'] = i['h'] - i['Yesterday’s Closing Price'] = i['hc'] - i['Lowest Price'] = i['l'] - i['Opening Price'] = i['o'] - i['Latest Price'] = i['p'] - i['Market Quote Time'] = i['t'] - del i['a'], i['b'], i['c'], i['e'], i['h'], i['hc'], i['l'], i['o'], i['p'], i['t'] - jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()}) + i["Selling Price"] = i["a"] + i["Buying Price"] = i["b"] + i["Commodity Code"] = i["c"] + i["Stock Exchange"] = i["e"] + i["Highest Price"] = i["h"] + i["Yesterday’s Closing Price"] = i["hc"] + i["Lowest Price"] = i["l"] + i["Opening Price"] = i["o"] + i["Latest Price"] = i["p"] + i["Market Quote Time"] = i["t"] + del i["a"], i["b"], i["c"], i["e"], i["h"], i["hc"], i["l"], i["o"], i["p"], i["t"] + jin10_res.append({"content": pd.DataFrame(response["data"]).to_markdown()}) if self._param.type == "news": - params = { - 'contain': self._param.contain, - 'filter': self._param.filter - } - response = requests.get( - url='https://open-data-api.jin10.com/data-api/news', - headers=headers, data=json.dumps(params), timeout=DEFAULT_TIMEOUT) + params = {"contain": self._param.contain, "filter": self._param.filter} + response = requests.get(url="https://open-data-api.jin10.com/data-api/news", headers=headers, data=json.dumps(params), timeout=DEFAULT_TIMEOUT) response = response.json() if self.check_if_canceled("Jin10 processing"): return - jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()}) + jin10_res.append({"content": pd.DataFrame(response["data"]).to_markdown()}) except Exception as e: if self.check_if_canceled("Jin10 processing"): return diff --git a/agent/tools/pubmed.py b/agent/tools/pubmed.py index 48117f1567..c5bfbe3d0d 100644 --- a/agent/tools/pubmed.py +++ b/agent/tools/pubmed.py @@ -30,7 +30,7 @@ class PubMedParam(ToolParamBase): """ def __init__(self): - self.meta:ToolMeta = { + self.meta: ToolMeta = { "name": "pubmed_search", "description": """ PubMed is an openly accessible, free database which includes primarily the MEDLINE database of references and abstracts on life sciences and biomedical topics. @@ -47,9 +47,9 @@ In addition to MEDLINE, PubMed provides access to: "type": "string", "description": "The search keywords to execute with PubMed. The keywords should be the most important words/terms(includes synonyms) from the original request.", "default": "{sys.query}", - "required": True + "required": True, } - } + }, } super().__init__() self.top_n = 12 @@ -59,12 +59,8 @@ In addition to MEDLINE, PubMed provides access to: self.check_positive_integer(self.top_n, "Top N") def get_input_form(self) -> dict[str, dict]: - return { - "query": { - "name": "Query", - "type": "line" - } - } + return {"query": {"name": "Query", "type": "line"}} + class PubMed(ToolBase, ABC): component_name = "PubMed" @@ -79,27 +75,28 @@ class PubMed(ToolBase, ABC): return "" last_e = "" - for _ in range(self._param.max_retries+1): + for _ in range(self._param.max_retries + 1): if self.check_if_canceled("PubMed processing"): return try: Entrez.email = self._param.email - pubmedids = Entrez.read(Entrez.esearch(db='pubmed', retmax=self._param.top_n, term=kwargs["query"]))['IdList'] + pubmedids = Entrez.read(Entrez.esearch(db="pubmed", retmax=self._param.top_n, term=kwargs["query"]))["IdList"] if self.check_if_canceled("PubMed processing"): return - pubmedcnt = ET.fromstring(re.sub(r'<(/?)b>|<(/?)i>', '', Entrez.efetch(db='pubmed', id=",".join(pubmedids), - retmode="xml").read().decode("utf-8"))) + pubmedcnt = ET.fromstring(re.sub(r"<(/?)b>|<(/?)i>", "", Entrez.efetch(db="pubmed", id=",".join(pubmedids), retmode="xml").read().decode("utf-8"))) if self.check_if_canceled("PubMed processing"): return - self._retrieve_chunks(pubmedcnt.findall("PubmedArticle"), - get_title=lambda child: child.find("MedlineCitation").find("Article").find("ArticleTitle").text, - get_url=lambda child: "https://pubmed.ncbi.nlm.nih.gov/" + child.find("MedlineCitation").find("PMID").text, - get_content=lambda child: self._format_pubmed_content(child),) + self._retrieve_chunks( + pubmedcnt.findall("PubmedArticle"), + get_title=lambda child: child.find("MedlineCitation").find("Article").find("ArticleTitle").text, + get_url=lambda child: "https://pubmed.ncbi.nlm.nih.gov/" + child.find("MedlineCitation").find("PMID").text, + get_content=lambda child: self._format_pubmed_content(child), + ) return self.output("formalized_content") except Exception as e: if self.check_if_canceled("PubMed processing"): @@ -117,6 +114,7 @@ class PubMed(ToolBase, ABC): def _format_pubmed_content(self, child): """Extract structured reference info from PubMed XML""" + def safe_find(path, base=None): node = child if base is None else base for p in path.split("/"): @@ -149,16 +147,7 @@ class PubMed(ToolBase, ABC): doi = eid.text break - return ( - f"Title: {title}\n" - f"Authors: {authors_str}\n" - f"Journal: {journal}\n" - f"Volume: {volume}\n" - f"Issue: {issue}\n" - f"Pages: {pages}\n" - f"DOI: {doi or '-'}\n" - f"Abstract: {abstract.strip()}" - ) + return f"Title: {title}\nAuthors: {authors_str}\nJournal: {journal}\nVolume: {volume}\nIssue: {issue}\nPages: {pages}\nDOI: {doi or '-'}\nAbstract: {abstract.strip()}" def thoughts(self) -> str: return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!")) diff --git a/agent/tools/qweather.py b/agent/tools/qweather.py index 2a1b7d9772..49f97f9d6a 100644 --- a/agent/tools/qweather.py +++ b/agent/tools/qweather.py @@ -30,7 +30,7 @@ class QWeatherParam(ComponentParamBase): self.web_apikey = "xxx" self.lang = "zh" self.type = "weather" - self.user_type = 'free' + self.user_type = "free" self.error_code = { "204": "The request was successful, but the region you are querying does not have the data you need at this time.", "400": "Request error, may contain incorrect request parameters or missing mandatory request parameters.", @@ -39,20 +39,53 @@ class QWeatherParam(ComponentParamBase): "403": "No access, may be the binding PackageName, BundleID, domain IP address is inconsistent, or the data that requires additional payment.", "404": "The queried data or region does not exist.", "429": "Exceeded the limited QPM (number of accesses per minute), please refer to the QPM description", - "500": "No response or timeout, interface service abnormality please contact us" - } + "500": "No response or timeout, interface service abnormality please contact us", + } # Weather - self.time_period = 'now' + self.time_period = "now" def check(self): self.check_empty(self.web_apikey, "BaiduFanyi APPID") self.check_valid_value(self.type, "Type", ["weather", "indices", "airquality"]) self.check_valid_value(self.user_type, "Free subscription or paid subscription", ["free", "paid"]) - self.check_valid_value(self.lang, "Use language", - ['zh', 'zh-hant', 'en', 'de', 'es', 'fr', 'it', 'ja', 'ko', 'ru', 'hi', 'th', 'ar', 'pt', - 'bn', 'ms', 'nl', 'el', 'la', 'sv', 'id', 'pl', 'tr', 'cs', 'et', 'vi', 'fil', 'fi', - 'he', 'is', 'nb']) - self.check_valid_value(self.time_period, "Time period", ['now', '3d', '7d', '10d', '15d', '30d']) + self.check_valid_value( + self.lang, + "Use language", + [ + "zh", + "zh-hant", + "en", + "de", + "es", + "fr", + "it", + "ja", + "ko", + "ru", + "hi", + "th", + "ar", + "pt", + "bn", + "ms", + "nl", + "el", + "la", + "sv", + "id", + "pl", + "tr", + "cs", + "et", + "vi", + "fil", + "fi", + "he", + "is", + "nb", + ], + ) + self.check_valid_value(self.time_period, "Time period", ["now", "3d", "7d", "10d", "15d", "30d"]) class QWeather(ComponentBase, ABC): @@ -71,9 +104,7 @@ class QWeather(ComponentBase, ABC): if self.check_if_canceled("Qweather processing"): return - response = requests.get( - url="https://geoapi.qweather.com/v2/city/lookup?location=" + ans + "&key=" + self._param.web_apikey, - timeout=DEFAULT_TIMEOUT).json() + response = requests.get(url="https://geoapi.qweather.com/v2/city/lookup?location=" + ans + "&key=" + self._param.web_apikey, timeout=DEFAULT_TIMEOUT).json() if response["code"] == "200": location_id = response["location"][0]["id"] else: @@ -82,7 +113,7 @@ class QWeather(ComponentBase, ABC): if self.check_if_canceled("Qweather processing"): return - base_url = "https://api.qweather.com/v7/" if self._param.user_type == 'paid' else "https://devapi.qweather.com/v7/" + base_url = "https://api.qweather.com/v7/" if self._param.user_type == "paid" else "https://devapi.qweather.com/v7/" if self._param.type == "weather": url = base_url + "weather/" + self._param.time_period + "?location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang @@ -110,8 +141,7 @@ class QWeather(ComponentBase, ABC): if self.check_if_canceled("Qweather processing"): return if response["code"] == "200": - indices_res = response["daily"][0]["date"] + "\n" + "\n".join( - [i["name"] + ": " + i["category"] + ", " + i["text"] for i in response["daily"]]) + indices_res = response["daily"][0]["date"] + "\n" + "\n".join([i["name"] + ": " + i["category"] + ", " + i["text"] for i in response["daily"]]) return QWeather.be_output(indices_res) else: diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index 0d31490b52..2fb649e869 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -40,7 +40,7 @@ class RetrievalParam(ToolParamBase): """ def __init__(self): - self.meta:ToolMeta = { + self.meta: ToolMeta = { "name": "search_my_dateset", "description": "This tool can be utilized for relevant content searching in the datasets.", "parameters": { @@ -48,9 +48,9 @@ class RetrievalParam(ToolParamBase): "type": "string", "description": "The keywords to search the dataset. The keywords should be the most important words/terms(includes synonyms) from the original request.", "default": "", - "required": True + "required": True, } - } + }, } super().__init__() self.function_name = "search_my_dateset" @@ -68,7 +68,7 @@ class RetrievalParam(ToolParamBase): self.use_kg = False self.cross_languages = [] self.toc_enhance = False - self.meta_data_filter={} + self.meta_data_filter = {} def check(self): self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold") @@ -76,12 +76,8 @@ class RetrievalParam(ToolParamBase): self.check_positive_number(self.top_n, "[Retrieval] Top N") def get_input_form(self) -> dict[str, dict]: - return { - "query": { - "name": "Query", - "type": "line" - } - } + return {"query": {"name": "Query", "type": "line"}} + class Retrieval(ToolBase, ABC): component_name = "Retrieval" @@ -101,8 +97,7 @@ class Retrieval(ToolBase, ABC): # if kb_nm is a list kb_nm_list = kb_nm if isinstance(kb_nm, list) else [kb_nm] for nm_or_id in kb_nm_list: - e, kb = KnowledgebaseService.get_by_name(nm_or_id, - self._canvas._tenant_id) + e, kb = KnowledgebaseService.get_by_name(nm_or_id, self._canvas._tenant_id) if not e: e, kb = KnowledgebaseService.get_by_id(nm_or_id) if not e: @@ -153,7 +148,7 @@ class Retrieval(ToolBase, ABC): last = 0 for m in pat.finditer(s): - out_parts.append(s[last:m.start()]) + out_parts.append(s[last : m.start()]) key = m.group(1) v = self._canvas.get_variable_value(key) if v is None: @@ -220,22 +215,16 @@ class Retrieval(ToolBase, ABC): tenant_id = self._canvas._tenant_id chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(tenant_id, chat_model_config) - cks = await settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], - chat_mdl, self._param.top_n) + cks = await settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n) if self.check_if_canceled("Retrieval processing"): return if cks: kbinfos["chunks"] = cks - kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], - [kb.tenant_id for kb in kbs]) + kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs]) if self._param.use_kg: tenant_id = self._canvas.get_tenant_id() chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) - ck = await settings.kg_retriever.retrieval(query, - [kb.tenant_id for kb in kbs], - kb_ids, - embd_mdl, - LLMBundle(tenant_id, chat_model_config)) + ck = await settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], kb_ids, embd_mdl, LLMBundle(tenant_id, chat_model_config)) if self.check_if_canceled("Retrieval processing"): return if ck["content_with_weight"]: @@ -245,8 +234,7 @@ class Retrieval(ToolBase, ABC): if self._param.use_kg and kbs: chat_model_config = get_tenant_default_model_by_type(kbs[0].tenant_id, LLMType.CHAT) - ck = await settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, - LLMBundle(kbs[0].tenant_id, chat_model_config)) + ck = await settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, chat_model_config)) if self.check_if_canceled("Retrieval processing"): return if ck["content_with_weight"]: @@ -293,16 +281,14 @@ class Retrieval(ToolBase, ABC): filter_dict: dict = {"memory_id": memory_ids} if user_id: import re + # is variable if re.match(r"^{.*}$", user_id): user_id = self._canvas.get_variable_value(user_id) filter_dict["user_id"] = user_id - message_list = memory_message_service.query_message(filter_dict, { - "query": query, - "similarity_threshold": self._param.similarity_threshold, - "keywords_similarity_weight": self._param.keywords_similarity_weight, - "top_n": self._param.top_n - }) + message_list = memory_message_service.query_message( + filter_dict, {"query": query, "similarity_threshold": self._param.similarity_threshold, "keywords_similarity_weight": self._param.keywords_similarity_weight, "top_n": self._param.top_n} + ) if not message_list: self.set_output("formalized_content", self._param.empty_response) return "" diff --git a/agent/tools/tavily.py b/agent/tools/tavily.py index 1f1fa01375..61eee6046f 100644 --- a/agent/tools/tavily.py +++ b/agent/tools/tavily.py @@ -28,7 +28,7 @@ class TavilySearchParam(ToolParamBase): """ def __init__(self): - self.meta:ToolMeta = { + self.meta: ToolMeta = { "name": "tavily_search", "description": """ Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results. @@ -43,7 +43,7 @@ When searching: "type": "string", "description": "The search keywords to execute with Tavily. The keywords should be the most important words/terms(includes synonyms) from the original request.", "default": "{sys.query}", - "required": True + "required": True, }, "topic": { "type": "string", @@ -56,27 +56,21 @@ When searching: "type": "array", "description": "default:[]. A list of domains only from which the search results can be included.", "default": [], - "items": { - "type": "string", - "description": "Domain name that must be included, e.g. www.yahoo.com" - }, - "required": False + "items": {"type": "string", "description": "Domain name that must be included, e.g. www.yahoo.com"}, + "required": False, }, "exclude_domains": { "type": "array", "description": "default:[]. A list of domains from which the search results can not be included", "default": [], - "items": { - "type": "string", - "description": "Domain name that must be excluded, e.g. www.yahoo.com" - }, - "required": False + "items": {"type": "string", "description": "Domain name that must be excluded, e.g. www.yahoo.com"}, + "required": False, }, - } + }, } super().__init__() self.api_key = "" - self.search_depth = "basic" # basic/advanced + self.search_depth = "basic" # basic/advanced self.max_results = 6 self.days = 14 self.include_answer = False @@ -91,12 +85,8 @@ When searching: self.check_positive_integer(self.days, "Tavily days should be greater than 1") def get_input_form(self) -> dict[str, dict]: - return { - "query": { - "name": "Query", - "type": "line" - } - } + return {"query": {"name": "Query", "type": "line"}} + class TavilySearch(ToolBase, ABC): component_name = "TavilySearch" @@ -115,7 +105,7 @@ class TavilySearch(ToolBase, ABC): for fld in ["search_depth", "topic", "max_results", "days", "include_answer", "include_raw_content", "include_images", "include_image_descriptions", "include_domains", "exclude_domains"]: if fld not in kwargs: kwargs[fld] = getattr(self._param, fld) - for _ in range(self._param.max_retries+1): + for _ in range(self._param.max_retries + 1): if self.check_if_canceled("TavilySearch processing"): return @@ -126,11 +116,13 @@ class TavilySearch(ToolBase, ABC): if self.check_if_canceled("TavilySearch processing"): return - self._retrieve_chunks(res["results"], - get_title=lambda r: r["title"], - get_url=lambda r: r["url"], - get_content=lambda r: r["raw_content"] if r["raw_content"] else r["content"], - get_score=lambda r: r["score"]) + self._retrieve_chunks( + res["results"], + get_title=lambda r: r["title"], + get_url=lambda r: r["url"], + get_content=lambda r: r["raw_content"] if r["raw_content"] else r["content"], + get_score=lambda r: r["score"], + ) self.set_output("json", res["results"]) return self.output("formalized_content") except Exception as e: @@ -159,7 +151,7 @@ class TavilyExtractParam(ToolParamBase): """ def __init__(self): - self.meta:ToolMeta = { + self.meta: ToolMeta = { "name": "tavily_extract", "description": "Extract web page content from one or more specified URLs using Tavily Extract.", "parameters": { @@ -167,11 +159,8 @@ class TavilyExtractParam(ToolParamBase): "type": "array", "description": "The URLs to extract content from.", "default": "", - "items": { - "type": "string", - "description": "The URL to extract content from, e.g. www.yahoo.com" - }, - "required": True + "items": {"type": "string", "description": "The URL to extract content from, e.g. www.yahoo.com"}, + "required": True, }, "extract_depth": { "type": "string", @@ -186,12 +175,12 @@ class TavilyExtractParam(ToolParamBase): "enum": ["markdown", "text"], "default": "markdown", "required": False, - } - } + }, + }, } super().__init__() self.api_key = "" - self.extract_depth = "basic" # basic/advanced + self.extract_depth = "basic" # basic/advanced self.urls = [] self.format = "markdown" self.include_images = False @@ -201,17 +190,13 @@ class TavilyExtractParam(ToolParamBase): self.check_valid_value(self.format, "Tavily extract format should be in 'markdown/text'", ["markdown", "text"]) def get_input_form(self) -> dict[str, dict]: - return { - "urls": { - "name": "URLs", - "type": "line" - } - } + return {"urls": {"name": "URLs", "type": "line"}} + class TavilyExtract(ToolBase, ABC): component_name = "TavilyExtract" - @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))) def _invoke(self, **kwargs): if self.check_if_canceled("TavilyExtract processing"): return @@ -223,7 +208,7 @@ class TavilyExtract(ToolBase, ABC): kwargs[fld] = getattr(self._param, fld) if kwargs.get("urls") and isinstance(kwargs["urls"], str): kwargs["urls"] = kwargs["urls"].split(",") - for _ in range(self._param.max_retries+1): + for _ in range(self._param.max_retries + 1): if self.check_if_canceled("TavilyExtract processing"): return diff --git a/agent/tools/tushare.py b/agent/tools/tushare.py index cb814f5cdd..c6cc326520 100644 --- a/agent/tools/tushare.py +++ b/agent/tools/tushare.py @@ -37,8 +37,7 @@ class TuShareParam(ComponentParamBase): self.keyword = "" def check(self): - self.check_valid_value(self.src, "Quick News Source", - ["sina", "wallstreetcn", "10jqka", "eastmoney", "yuncaijing", "fenghuang", "jinrongjie"]) + self.check_valid_value(self.src, "Quick News Source", ["sina", "wallstreetcn", "10jqka", "eastmoney", "yuncaijing", "fenghuang", "jinrongjie"]) class TuShare(ComponentBase, ABC): @@ -58,20 +57,15 @@ class TuShare(ComponentBase, ABC): return tus_res = [] - params = { - "api_name": "news", - "token": self._param.token, - "params": {"src": self._param.src, "start_date": self._param.start_date, - "end_date": self._param.end_date} - } - response = requests.post(url="http://api.tushare.pro", data=json.dumps(params).encode('utf-8'), timeout=DEFAULT_TIMEOUT) + params = {"api_name": "news", "token": self._param.token, "params": {"src": self._param.src, "start_date": self._param.start_date, "end_date": self._param.end_date}} + response = requests.post(url="http://api.tushare.pro", data=json.dumps(params).encode("utf-8"), timeout=DEFAULT_TIMEOUT) response = response.json() if self.check_if_canceled("TuShare processing"): return - if response['code'] != 0: - return TuShare.be_output(response['msg']) - df = pd.DataFrame(response['data']['items']) - df.columns = response['data']['fields'] + if response["code"] != 0: + return TuShare.be_output(response["msg"]) + df = pd.DataFrame(response["data"]["items"]) + df.columns = response["data"]["fields"] if self.check_if_canceled("TuShare processing"): return keyword = self._param.keyword or ans @@ -79,13 +73,7 @@ class TuShare(ComponentBase, ABC): "TuShare news filter keyword source=%s", "param.keyword" if self._param.keyword else "upstream_input", ) - tus_res.append( - { - "content": ( - df[df["content"].str.contains(keyword, case=False, na=False, regex=False)] - ).to_markdown() - } - ) + tus_res.append({"content": (df[df["content"].str.contains(keyword, case=False, na=False, regex=False)]).to_markdown()}) except Exception as e: if self.check_if_canceled("TuShare processing"): return diff --git a/agent/tools/wencai.py b/agent/tools/wencai.py index 18e7b14c46..18b53e3b89 100644 --- a/agent/tools/wencai.py +++ b/agent/tools/wencai.py @@ -30,21 +30,14 @@ class WenCaiParam(ToolParamBase): """ def __init__(self): - self.meta:ToolMeta = { + self.meta: ToolMeta = { "name": "iwencai", "description": """ iwencai search: search platform is committed to providing hundreds of millions of investors with the most timely, accurate and comprehensive information, covering news, announcements, research reports, blogs, forums, Weibo, characters, etc. robo-advisor intelligent stock selection platform: through AI technology, is committed to providing investors with intelligent stock selection, quantitative investment, main force tracking, value investment, technical analysis and other types of stock selection technologies. fund selection platform: through AI technology, is committed to providing excellent fund, value investment, quantitative analysis and other fund selection technologies for foundation citizens. """, - "parameters": { - "query": { - "type": "string", - "description": "The question/conditions to select stocks.", - "default": "{sys.query}", - "required": True - } - } + "parameters": {"query": {"type": "string", "description": "The question/conditions to select stocks.", "default": "{sys.query}", "required": True}}, } super().__init__() self.top_n = 10 @@ -52,18 +45,11 @@ fund selection platform: through AI technology, is committed to providing excell def check(self): self.check_positive_integer(self.top_n, "Top N") - self.check_valid_value(self.query_type, "Query type", - ['stock', 'zhishu', 'fund', 'hkstock', 'usstock', 'threeboard', 'conbond', 'insurance', - 'futures', 'lccp', - 'foreign_exchange']) + self.check_valid_value(self.query_type, "Query type", ["stock", "zhishu", "fund", "hkstock", "usstock", "threeboard", "conbond", "insurance", "futures", "lccp", "foreign_exchange"]) def get_input_form(self) -> dict[str, dict]: - return { - "query": { - "name": "Query", - "type": "line" - } - } + return {"query": {"name": "Query", "type": "line"}} + class WenCai(ToolBase, ABC): component_name = "WenCai" @@ -78,7 +64,7 @@ class WenCai(ToolBase, ABC): return "" last_e = "" - for _ in range(self._param.max_retries+1): + for _ in range(self._param.max_retries + 1): if self.check_if_canceled("WenCai processing"): return @@ -103,7 +89,7 @@ class WenCai(ToolBase, ABC): elif isinstance(item[1], dict): if "meta" in item[1].keys(): continue - wencai_res.append(pd.DataFrame.from_dict(item[1], orient='index').to_markdown()) + wencai_res.append(pd.DataFrame.from_dict(item[1], orient="index").to_markdown()) elif isinstance(item[1], pd.DataFrame): if "image_url" in item[1].columns: continue diff --git a/agent/tools/wikipedia.py b/agent/tools/wikipedia.py index 5bd76d80ee..301a271376 100644 --- a/agent/tools/wikipedia.py +++ b/agent/tools/wikipedia.py @@ -28,7 +28,7 @@ class WikipediaParam(ToolParamBase): """ def __init__(self): - self.meta:ToolMeta = { + self.meta: ToolMeta = { "name": "wikipedia_search", "description": """A wide range of how-to and information pages are made available in wikipedia. Since 2001, it has grown rapidly to become the world's largest reference website. From Wikipedia, the free encyclopedia.""", "parameters": { @@ -36,9 +36,9 @@ class WikipediaParam(ToolParamBase): "type": "string", "description": "The search keyword to execute with wikipedia. The keyword MUST be a specific subject that can match the title.", "default": "{sys.query}", - "required": True + "required": True, } - } + }, } super().__init__() self.top_n = 10 @@ -46,20 +46,86 @@ class WikipediaParam(ToolParamBase): def check(self): self.check_positive_integer(self.top_n, "Top N") - self.check_valid_value(self.language, "Wikipedia languages", - ['af', 'pl', 'ar', 'ast', 'az', 'bg', 'nan', 'bn', 'be', 'ca', 'cs', 'cy', 'da', 'de', - 'et', 'el', 'en', 'es', 'eo', 'eu', 'fa', 'fr', 'gl', 'ko', 'hy', 'hi', 'hr', 'id', - 'it', 'he', 'ka', 'lld', 'la', 'lv', 'lt', 'hu', 'mk', 'arz', 'ms', 'min', 'my', 'nl', - 'ja', 'nb', 'nn', 'ce', 'uz', 'pt', 'kk', 'ro', 'ru', 'ceb', 'sk', 'sl', 'sr', 'sh', - 'fi', 'sv', 'ta', 'tt', 'th', 'tg', 'azb', 'tr', 'uk', 'ur', 'vi', 'war', 'zh', 'yue']) + self.check_valid_value( + self.language, + "Wikipedia languages", + [ + "af", + "pl", + "ar", + "ast", + "az", + "bg", + "nan", + "bn", + "be", + "ca", + "cs", + "cy", + "da", + "de", + "et", + "el", + "en", + "es", + "eo", + "eu", + "fa", + "fr", + "gl", + "ko", + "hy", + "hi", + "hr", + "id", + "it", + "he", + "ka", + "lld", + "la", + "lv", + "lt", + "hu", + "mk", + "arz", + "ms", + "min", + "my", + "nl", + "ja", + "nb", + "nn", + "ce", + "uz", + "pt", + "kk", + "ro", + "ru", + "ceb", + "sk", + "sl", + "sr", + "sh", + "fi", + "sv", + "ta", + "tt", + "th", + "tg", + "azb", + "tr", + "uk", + "ur", + "vi", + "war", + "zh", + "yue", + ], + ) def get_input_form(self) -> dict[str, dict]: - return { - "query": { - "name": "Query", - "type": "line" - } - } + return {"query": {"name": "Query", "type": "line"}} + class Wikipedia(ToolBase, ABC): """Wikipedia search tool that retrieves and processes Wikipedia articles.""" @@ -84,7 +150,7 @@ class Wikipedia(ToolBase, ABC): return "" last_e = "" - for _ in range(self._param.max_retries+1): + for _ in range(self._param.max_retries + 1): if self.check_if_canceled("Wikipedia processing"): return @@ -104,10 +170,7 @@ class Wikipedia(ToolBase, ABC): logging.info(f"Wikipedia page not found: '{p}'") except Exception as e: logging.exception(f"Unexpected error fetching Wikipedia page '{p}': {e}") - self._retrieve_chunks(pages, - get_title=lambda r: r.title, - get_url=lambda r: r.url, - get_content=lambda r: r.summary) + self._retrieve_chunks(pages, get_title=lambda r: r.title, get_url=lambda r: r.url, get_content=lambda r: r.summary) return self.output("formalized_content") except Exception as e: if self.check_if_canceled("Wikipedia processing"): diff --git a/agent/tools/yahoofinance.py b/agent/tools/yahoofinance.py index 06a4a9dad4..69254b04fc 100644 --- a/agent/tools/yahoofinance.py +++ b/agent/tools/yahoofinance.py @@ -29,17 +29,10 @@ class YahooFinanceParam(ToolParamBase): """ def __init__(self): - self.meta:ToolMeta = { + self.meta: ToolMeta = { "name": "yahoo_finance", "description": "The Yahoo Finance is a service that provides access to real-time and historical stock market data. It enables users to fetch various types of stock information, such as price quotes, historical prices, company profiles, and financial news. The API offers structured data, allowing developers to integrate market data into their applications and analysis tools.", - "parameters": { - "stock_code": { - "type": "string", - "description": "The stock code or company name.", - "default": "{sys.query}", - "required": True - } - } + "parameters": {"stock_code": {"type": "string", "description": "The stock code or company name.", "default": "{sys.query}", "required": True}}, } super().__init__() self.info = True @@ -62,12 +55,8 @@ class YahooFinanceParam(ToolParamBase): self.check_boolean(self.news, "show news") def get_input_form(self) -> dict[str, dict]: - return { - "stock_code": { - "name": "Stock code/Company name", - "type": "line" - } - } + return {"stock_code": {"name": "Stock code/Company name", "type": "line"}} + class YahooFinance(ToolBase, ABC): component_name = "YahooFinance" @@ -82,7 +71,7 @@ class YahooFinance(ToolBase, ABC): return "" last_e = "" - for _ in range(self._param.max_retries+1): + for _ in range(self._param.max_retries + 1): if self.check_if_canceled("YahooFinance processing"): return None diff --git a/api/apps/auth/__init__.py b/api/apps/auth/__init__.py index f989b6d2e3..662099f1c2 100644 --- a/api/apps/auth/__init__.py +++ b/api/apps/auth/__init__.py @@ -19,14 +19,10 @@ from .oidc import OIDCClient from .github import GithubOAuthClient -CLIENT_TYPES = { - "oauth2": OAuthClient, - "oidc": OIDCClient, - "github": GithubOAuthClient -} +CLIENT_TYPES = {"oauth2": OAuthClient, "oidc": OIDCClient, "github": GithubOAuthClient} -def get_auth_client(config)->OAuthClient: +def get_auth_client(config) -> OAuthClient: channel_type = str(config.get("type", "")).lower() if channel_type == "": if config.get("issuer"): diff --git a/api/apps/auth/github.py b/api/apps/auth/github.py index 918ff60db8..ef173376de 100644 --- a/api/apps/auth/github.py +++ b/api/apps/auth/github.py @@ -23,15 +23,16 @@ class GithubOAuthClient(OAuthClient): """ Initialize the GithubOAuthClient with the provider's configuration. """ - config.update({ - "authorization_url": "https://github.com/login/oauth/authorize", - "token_url": "https://github.com/login/oauth/access_token", - "userinfo_url": "https://api.github.com/user", - "scope": "user:email" - }) + config.update( + { + "authorization_url": "https://github.com/login/oauth/authorize", + "token_url": "https://github.com/login/oauth/access_token", + "userinfo_url": "https://api.github.com/user", + "scope": "user:email", + } + ) super().__init__(config) - def fetch_user_info(self, access_token, **kwargs): """ Fetch GitHub user info (synchronous). @@ -42,9 +43,7 @@ class GithubOAuthClient(OAuthClient): response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout) response.raise_for_status() user_info.update(response.json()) - email_response = sync_request( - "GET", self.userinfo_url + "/emails", headers=headers, timeout=self.http_request_timeout - ) + email_response = sync_request("GET", self.userinfo_url + "/emails", headers=headers, timeout=self.http_request_timeout) email_response.raise_for_status() email_info = email_response.json() user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"] @@ -79,7 +78,6 @@ class GithubOAuthClient(OAuthClient): except Exception as e: raise ValueError(f"Failed to fetch github user info: {e}") - def normalize_user_info(self, user_info): email = user_info.get("email") username = user_info.get("login", str(email).split("@")[0]) diff --git a/api/apps/auth/oauth.py b/api/apps/auth/oauth.py index 5b2afcea1d..56edd589b6 100644 --- a/api/apps/auth/oauth.py +++ b/api/apps/auth/oauth.py @@ -24,7 +24,7 @@ class UserInfo: self.username = username self.nickname = nickname self.avatar_url = avatar_url - + def to_dict(self): return {key: value for key, value in self.__dict__.items()} @@ -44,7 +44,6 @@ class OAuthClient: self.http_request_timeout = 7 - def get_authorization_url(self, state=None): """ Generate the authorization URL for user login. @@ -61,19 +60,12 @@ class OAuthClient: authorization_url = f"{self.authorization_url}?{urllib.parse.urlencode(params)}" return authorization_url - def exchange_code_for_token(self, code): """ Exchange authorization code for access token. """ try: - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "redirect_uri": self.redirect_uri, - "grant_type": "authorization_code" - } + payload = {"client_id": self.client_id, "client_secret": self.client_secret, "code": code, "redirect_uri": self.redirect_uri, "grant_type": "authorization_code"} response = sync_request( "POST", self.token_url, @@ -110,7 +102,6 @@ class OAuthClient: except Exception as e: raise ValueError(f"Failed to exchange authorization code for token: {e}") - def fetch_user_info(self, access_token, **kwargs): """ Fetch user information using access token. @@ -140,7 +131,6 @@ class OAuthClient: except Exception as e: raise ValueError(f"Failed to fetch user info: {e}") - def normalize_user_info(self, user_info): email = user_info.get("email") username = user_info.get("username", str(email).split("@")[0]) diff --git a/api/apps/auth/oidc.py b/api/apps/auth/oidc.py index e28e982805..2f0808551d 100644 --- a/api/apps/auth/oidc.py +++ b/api/apps/auth/oidc.py @@ -26,12 +26,20 @@ from .oauth import OAuthClient # forge tokens by HMAC-signing them with the public key bytes # (RSA/HMAC algorithm-confusion attack, CWE-347). "none" is excluded for the # obvious reason that it disables signature verification entirely. -_ALLOWED_OIDC_SIGNING_ALGS = frozenset({ - "RS256", "RS384", "RS512", - "ES256", "ES384", "ES512", - "PS256", "PS384", "PS512", - "EdDSA", -}) +_ALLOWED_OIDC_SIGNING_ALGS = frozenset( + { + "RS256", + "RS384", + "RS512", + "ES256", + "ES384", + "ES512", + "PS256", + "PS384", + "PS512", + "EdDSA", + } +) # OIDC Core 1.0 § 2 makes RS256 the spec-default ``id_token_signing_alg``, # so this is the safe fallback when a provider's discovery document does not @@ -69,24 +77,25 @@ class OIDCClient(OAuthClient): raise ValueError("Missing issuer in configuration.") oidc_metadata = self._load_oidc_metadata(self.issuer) - config.update({ - 'issuer': oidc_metadata['issuer'], - 'jwks_uri': oidc_metadata['jwks_uri'], - 'authorization_url': oidc_metadata['authorization_endpoint'], - 'token_url': oidc_metadata['token_endpoint'], - 'userinfo_url': oidc_metadata['userinfo_endpoint'] - }) + config.update( + { + "issuer": oidc_metadata["issuer"], + "jwks_uri": oidc_metadata["jwks_uri"], + "authorization_url": oidc_metadata["authorization_endpoint"], + "token_url": oidc_metadata["token_endpoint"], + "userinfo_url": oidc_metadata["userinfo_endpoint"], + } + ) super().__init__(config) - self.issuer = config['issuer'] - self.jwks_uri = config['jwks_uri'] + self.issuer = config["issuer"] + self.jwks_uri = config["jwks_uri"] # Pin the accepted ID-token signing algorithms at construction time # from a trusted source (provider metadata + safe allowlist) so the # JWT verification step in :meth:`parse_id_token` cannot be tricked # by attacker-controlled JWT headers (CWE-345 / CWE-347). self.id_token_signing_algs = _resolve_id_token_signing_algs(oidc_metadata) - @staticmethod def _load_oidc_metadata(issuer): """ @@ -100,7 +109,6 @@ class OIDCClient(OAuthClient): except Exception as e: raise ValueError(f"Failed to fetch OIDC metadata: {e}") - def parse_id_token(self, id_token): """ Parse and validate OIDC ID Token (JWT format) with signature verification. @@ -134,7 +142,6 @@ class OIDCClient(OAuthClient): except Exception as e: raise ValueError(f"Error parsing ID Token: {e}") - def fetch_user_info(self, access_token, id_token=None, **kwargs): """ Fetch user info. @@ -152,6 +159,5 @@ class OIDCClient(OAuthClient): user_info.update((await super().async_fetch_user_info(access_token)).to_dict()) return self.normalize_user_info(user_info) - def normalize_user_info(self, user_info): return super().normalize_user_info(user_info) diff --git a/api/apps/backward_compat.py b/api/apps/backward_compat.py index e9ad3bb505..3b4c682919 100644 --- a/api/apps/backward_compat.py +++ b/api/apps/backward_compat.py @@ -44,6 +44,7 @@ Deprecated APIs and their replacements: - POST /api/v1/sessions/related_questions -> POST /api/v1/chat/recommandation - PUT (chunk update) -> PATCH (chunk update) """ + import logging from quart import Blueprint, jsonify, request @@ -68,6 +69,7 @@ def _index_result(success, result): # System APIs # ============================================================================= + @legacy_v1_manager.route("/system/healthz", methods=["GET"]) async def deprecated_system_healthz(): """ @@ -76,17 +78,16 @@ async def deprecated_system_healthz(): Old path: GET /v1/system/healthz New path: GET /api/v1/system/healthz """ - logging.warning( - "API endpoint /v1/system/healthz is deprecated. " - "Please use /api/v1/system/healthz instead." - ) + logging.warning("API endpoint /v1/system/healthz is deprecated. Please use /api/v1/system/healthz instead.") result, all_ok = run_health_checks() return jsonify(result), (200 if all_ok else 500) + # ============================================================================= # Chat Completion APIs # ============================================================================= + @manager.route("/chats//completions", methods=["POST"]) @login_required async def deprecated_chat_completions(chat_id): @@ -97,8 +98,7 @@ async def deprecated_chat_completions(chat_id): New path: POST /api/v1/chat/completions """ logging.warning( - "API endpoint /api/v1/chats/%s/completions is deprecated. " - "Please use /api/v1/chat/completions instead.", + "API endpoint /api/v1/chats/%s/completions is deprecated. Please use /api/v1/chat/completions instead.", chat_id, ) # Forward to the new API implementation @@ -115,9 +115,9 @@ async def deprecated_openai_chat_completions(chat_id): New path: POST /api/v1/openai/{chat_id}/chat/completions """ logging.warning( - "API endpoint /api/v1/chats_openai/%s/chat/completions is deprecated. " - "Please use /api/v1/openai/%s/chat/completions instead.", - chat_id, chat_id, + "API endpoint /api/v1/chats_openai/%s/chat/completions is deprecated. Please use /api/v1/openai/%s/chat/completions instead.", + chat_id, + chat_id, ) # Forward to the new API implementation return await openai_api.openai_chat_completions(chat_id) @@ -134,8 +134,7 @@ async def deprecated_agents_openai_chat_completions(agent_id, tenant_id=None): New path: POST /api/v1/agents/chat/completions """ logging.warning( - "API endpoint /api/v1/agents_openai/%s/chat/completions is deprecated. " - "Please use /api/v1/agents/chat/completions with `openai-compatible` instead.", + "API endpoint /api/v1/agents_openai/%s/chat/completions is deprecated. Please use /api/v1/agents/chat/completions with `openai-compatible` instead.", agent_id, ) req = dict(await get_request_json()) @@ -148,6 +147,7 @@ async def deprecated_agents_openai_chat_completions(agent_id, tenant_id=None): # Dataset Graph and Index APIs # ============================================================================= + @manager.route("/datasets//knowledge_graph", methods=["GET"]) @login_required async def deprecated_get_knowledge_graph(dataset_id): @@ -158,9 +158,9 @@ async def deprecated_get_knowledge_graph(dataset_id): New path: GET /api/v1/datasets/{dataset_id}/graph """ logging.warning( - "API endpoint /api/v1/datasets/%s/knowledge_graph is deprecated. " - "Please use /api/v1/datasets/%s/graph instead.", - dataset_id, dataset_id, + "API endpoint /api/v1/datasets/%s/knowledge_graph is deprecated. Please use /api/v1/datasets/%s/graph instead.", + dataset_id, + dataset_id, ) return await dataset_api.get_knowledge_graph(dataset_id=dataset_id) @@ -175,9 +175,9 @@ async def deprecated_delete_knowledge_graph(dataset_id): New path: DELETE /api/v1/datasets/{dataset_id}/graph """ logging.warning( - "API endpoint DELETE /api/v1/datasets/%s/knowledge_graph is deprecated. " - "Please use DELETE /api/v1/datasets/%s/graph instead.", - dataset_id, dataset_id, + "API endpoint DELETE /api/v1/datasets/%s/knowledge_graph is deprecated. Please use DELETE /api/v1/datasets/%s/graph instead.", + dataset_id, + dataset_id, ) return await dataset_api.delete_knowledge_graph(dataset_id=dataset_id) @@ -193,9 +193,9 @@ async def deprecated_run_graphrag(dataset_id, tenant_id=None): New path: POST /api/v1/datasets/{dataset_id}/index?type=graph """ logging.warning( - "API endpoint /api/v1/datasets/%s/run_graphrag is deprecated. " - "Please use /api/v1/datasets/%s/index?type=graph instead.", - dataset_id, dataset_id, + "API endpoint /api/v1/datasets/%s/run_graphrag is deprecated. Please use /api/v1/datasets/%s/index?type=graph instead.", + dataset_id, + dataset_id, ) return _index_result(*dataset_api_service.run_index(dataset_id, tenant_id, "graph")) @@ -211,9 +211,9 @@ async def deprecated_trace_graphrag(dataset_id, tenant_id=None): New path: GET /api/v1/datasets/{dataset_id}/index?type=graph """ logging.warning( - "API endpoint /api/v1/datasets/%s/trace_graphrag is deprecated. " - "Please use /api/v1/datasets/%s/index?type=graph instead.", - dataset_id, dataset_id, + "API endpoint /api/v1/datasets/%s/trace_graphrag is deprecated. Please use /api/v1/datasets/%s/index?type=graph instead.", + dataset_id, + dataset_id, ) return _index_result(*dataset_api_service.trace_index(dataset_id, tenant_id, "graph")) @@ -229,9 +229,9 @@ async def deprecated_run_raptor(dataset_id, tenant_id=None): New path: POST /api/v1/datasets/{dataset_id}/index?type=raptor """ logging.warning( - "API endpoint /api/v1/datasets/%s/run_raptor is deprecated. " - "Please use /api/v1/datasets/%s/index?type=raptor instead.", - dataset_id, dataset_id, + "API endpoint /api/v1/datasets/%s/run_raptor is deprecated. Please use /api/v1/datasets/%s/index?type=raptor instead.", + dataset_id, + dataset_id, ) return _index_result(*dataset_api_service.run_index(dataset_id, tenant_id, "raptor")) @@ -247,9 +247,9 @@ async def deprecated_trace_raptor(dataset_id, tenant_id=None): New path: GET /api/v1/datasets/{dataset_id}/index?type=raptor """ logging.warning( - "API endpoint /api/v1/datasets/%s/trace_raptor is deprecated. " - "Please use /api/v1/datasets/%s/index?type=raptor instead.", - dataset_id, dataset_id, + "API endpoint /api/v1/datasets/%s/trace_raptor is deprecated. Please use /api/v1/datasets/%s/index?type=raptor instead.", + dataset_id, + dataset_id, ) return _index_result(*dataset_api_service.trace_index(dataset_id, tenant_id, "raptor")) @@ -258,6 +258,7 @@ async def deprecated_trace_raptor(dataset_id, tenant_id=None): # Chat Session APIs # ============================================================================= + @manager.route("/chats//sessions/", methods=["PUT"]) @login_required async def deprecated_update_session(chat_id, session_id): @@ -268,9 +269,11 @@ async def deprecated_update_session(chat_id, session_id): New path: PATCH /api/v1/chats/{chat_id}/sessions/{session_id} """ logging.warning( - "API endpoint PUT /api/v1/chats/%s/sessions/%s is deprecated. " - "Please use PATCH /api/v1/chats/%s/sessions/%s instead.", - chat_id, session_id, chat_id, session_id, + "API endpoint PUT /api/v1/chats/%s/sessions/%s is deprecated. Please use PATCH /api/v1/chats/%s/sessions/%s instead.", + chat_id, + session_id, + chat_id, + session_id, ) # Forward to the new API implementation return await chat_api.update_session(chat_id, session_id) @@ -280,6 +283,7 @@ async def deprecated_update_session(chat_id, session_id): # File APIs (Old /api/v1/file/* -> New /api/v1/files*) # ============================================================================= + @manager.route("/file/get/", methods=["GET"]) @login_required async def deprecated_file_get(file_id): @@ -290,9 +294,9 @@ async def deprecated_file_get(file_id): New path: GET /api/v1/files/{file_id} """ logging.warning( - "API endpoint /api/v1/file/get/%s is deprecated. " - "Please use /api/v1/files/%s instead.", - file_id, file_id, + "API endpoint /api/v1/file/get/%s is deprecated. Please use /api/v1/files/%s instead.", + file_id, + file_id, ) # Forward to the new API implementation (download) return await file_api.download(file_id=file_id) @@ -307,10 +311,7 @@ async def deprecated_file_list(): Old path: GET /api/v1/file/list?... New path: GET /api/v1/files?... """ - logging.warning( - "API endpoint /api/v1/file/list is deprecated. " - "Please use /api/v1/files instead." - ) + logging.warning("API endpoint /api/v1/file/list is deprecated. Please use /api/v1/files instead.") # Forward to the new API implementation return await file_api.list_files() @@ -328,8 +329,7 @@ async def deprecated_file_all_parent_folder(): if not file_id: return get_data_error_result(message="`file_id` query parameter is required") logging.warning( - "API endpoint /api/v1/file/all_parent_folder is deprecated. " - "Please use /api/v1/files/%s/ancestors instead.", + "API endpoint /api/v1/file/all_parent_folder is deprecated. Please use /api/v1/files/%s/ancestors instead.", file_id, ) # Forward to the new API implementation @@ -349,8 +349,7 @@ async def deprecated_file_parent_folder(): if not file_id: return get_data_error_result(message="`file_id` query parameter is required") logging.warning( - "API endpoint /api/v1/file/parent_folder is deprecated. " - "Please use /api/v1/files/%s/parent instead.", + "API endpoint /api/v1/file/parent_folder is deprecated. Please use /api/v1/files/%s/parent instead.", file_id, ) # Forward to the new API implementation @@ -366,10 +365,7 @@ async def deprecated_file_root_folder(): Old path: GET /api/v1/file/root_folder New path: GET /api/v1/files?parent_id= """ - logging.warning( - "API endpoint /api/v1/file/root_folder is deprecated. " - "Please use /api/v1/files with appropriate parent_id instead." - ) + logging.warning("API endpoint /api/v1/file/root_folder is deprecated. Please use /api/v1/files with appropriate parent_id instead.") # Forward to the new API implementation with empty parent_id to get root return await file_api.list_files() @@ -384,10 +380,7 @@ async def deprecated_file_create(tenant_id=None): Old path: POST /api/v1/file/create New path: POST /api/v1/files """ - logging.warning( - "API endpoint /api/v1/file/create is deprecated. " - "Please use POST /api/v1/files instead." - ) + logging.warning("API endpoint /api/v1/file/create is deprecated. Please use POST /api/v1/files instead.") # Forward to the new API implementation return await file_api.create_or_upload(tenant_id=tenant_id) @@ -402,10 +395,7 @@ async def deprecated_file_upload(tenant_id=None): Old path: POST /api/v1/file/upload New path: POST /api/v1/files """ - logging.warning( - "API endpoint /api/v1/file/upload is deprecated. " - "Please use POST /api/v1/files with multipart/form-data instead." - ) + logging.warning("API endpoint /api/v1/file/upload is deprecated. Please use POST /api/v1/files with multipart/form-data instead.") # Forward to the new API implementation return await file_api.create_or_upload(tenant_id=tenant_id) @@ -419,10 +409,7 @@ async def deprecated_file_convert(): Old path: POST /api/v1/file/convert New path: POST /api/v1/files/link-to-datasets """ - logging.warning( - "API endpoint /api/v1/file/convert is deprecated. " - "Please use POST /api/v1/files/link-to-datasets instead." - ) + logging.warning("API endpoint /api/v1/file/convert is deprecated. Please use POST /api/v1/files/link-to-datasets instead.") return await file2document_api.convert() @@ -436,10 +423,7 @@ async def deprecated_file_mv(tenant_id=None): Old path: POST /api/v1/file/mv New path: POST /api/v1/files/move """ - logging.warning( - "API endpoint /api/v1/file/mv is deprecated. " - "Please use POST /api/v1/files/move instead." - ) + logging.warning("API endpoint /api/v1/file/mv is deprecated. Please use POST /api/v1/files/move instead.") # Forward to the new API implementation return await file_api.move(tenant_id=tenant_id) @@ -454,10 +438,7 @@ async def deprecated_file_rename(tenant_id=None): Old path: POST /api/v1/file/rename New path: POST /api/v1/files/move """ - logging.warning( - "API endpoint /api/v1/file/rename is deprecated. " - "Please use POST /api/v1/files/move with `new_name` instead." - ) + logging.warning("API endpoint /api/v1/file/rename is deprecated. Please use POST /api/v1/files/move with `new_name` instead.") # Transform the old API format to new format req = await request.get_json() # Old API used `file_id` and `name`, new API uses `src_file_ids` and `new_name` @@ -465,9 +446,7 @@ async def deprecated_file_rename(tenant_id=None): new_name = req.get("name") # Call the underlying service directly with transformed data try: - success, result = await file_api_service.move_files( - tenant_id, src_file_ids, None, new_name - ) + success, result = await file_api_service.move_files(tenant_id, src_file_ids, None, new_name) if success: return get_json_result(data=result) else: @@ -487,10 +466,7 @@ async def deprecated_file_rm(tenant_id=None): Old path: POST /api/v1/file/rm New path: DELETE /api/v1/files """ - logging.warning( - "API endpoint /api/v1/file/rm is deprecated. " - "Please use DELETE /api/v1/files instead." - ) + logging.warning("API endpoint /api/v1/file/rm is deprecated. Please use DELETE /api/v1/files instead.") # Transform POST with body to DELETE behavior # The new API expects a JSON body with `ids` return await file_api.delete(tenant_id=tenant_id) @@ -500,6 +476,7 @@ async def deprecated_file_rm(tenant_id=None): # Related Questions API # ============================================================================= + @manager.route("/sessions/related_questions", methods=["POST"]) @login_required async def deprecated_related_questions(): @@ -509,10 +486,7 @@ async def deprecated_related_questions(): Old path: POST /api/v1/sessions/related_questions New path: POST /api/v1/chat/recommendation """ - logging.warning( - "API endpoint /api/v1/sessions/related_questions is deprecated. " - "Please use /api/v1/chat/recommendation instead." - ) + logging.warning("API endpoint /api/v1/sessions/related_questions is deprecated. Please use /api/v1/chat/recommendation instead.") # Forward to the new API implementation return await chat_api.recommendation() @@ -521,6 +495,7 @@ async def deprecated_related_questions(): # Chunk Update API (PUT -> PATCH) # ============================================================================= + @manager.route("/datasets//documents//chunks/", methods=["PUT"]) @login_required async def deprecated_update_chunk(dataset_id, document_id, chunk_id): @@ -531,9 +506,10 @@ async def deprecated_update_chunk(dataset_id, document_id, chunk_id): New path: PATCH /api/v1/datasets/{dataset_id}/documents/{document_id}/chunks/{chunk_id} """ logging.warning( - "API endpoint PUT /api/v1/datasets/%s/documents/%s/chunks/%s is deprecated. " - "Please use PATCH instead.", - dataset_id, document_id, chunk_id, + "API endpoint PUT /api/v1/datasets/%s/documents/%s/chunks/%s is deprecated. Please use PATCH instead.", + dataset_id, + document_id, + chunk_id, ) # Forward to the new API implementation return await chunk_api.update_chunk(dataset_id=dataset_id, document_id=document_id, chunk_id=chunk_id) @@ -543,6 +519,7 @@ async def deprecated_update_chunk(dataset_id, document_id, chunk_id): # File Upload Info API # ============================================================================= + @manager.route("/file/upload_info", methods=["POST"]) @login_required async def deprecated_file_upload_info(): @@ -554,10 +531,7 @@ async def deprecated_file_upload_info(): """ from api.apps import current_user - logging.warning( - "API endpoint /api/v1/file/upload_info is deprecated. " - "Please use POST /api/v1/documents/upload instead." - ) + logging.warning("API endpoint /api/v1/file/upload_info is deprecated. Please use POST /api/v1/documents/upload instead.") # Forward to the new API implementation # Need to pass tenant_id explicitly since we're calling the function directly tenant_id = current_user.id @@ -575,10 +549,7 @@ async def deprecated_legacy_document_upload_info(): """ from api.apps import current_user - logging.warning( - "API endpoint /v1/document/upload_info is deprecated. " - "Please use POST /api/v1/documents/upload instead." - ) + logging.warning("API endpoint /v1/document/upload_info is deprecated. Please use POST /api/v1/documents/upload instead.") tenant_id = current_user.id return await document_api.upload_info(tenant_id=tenant_id) @@ -587,6 +558,7 @@ async def deprecated_legacy_document_upload_info(): # Document APIs # ============================================================================= + @manager.route("/datasets//documents/", methods=["PUT"]) @login_required async def deprecated_update_document(dataset_id, document_id): @@ -597,9 +569,9 @@ async def deprecated_update_document(dataset_id, document_id): New path: PATCH /api/v1/datasets/{dataset_id}/documents/{document_id} """ logging.warning( - "API endpoint PUT /api/v1/datasets/%s/documents/%s is deprecated. " - "Please use PATCH instead.", - dataset_id, document_id, + "API endpoint PUT /api/v1/datasets/%s/documents/%s is deprecated. Please use PATCH instead.", + dataset_id, + document_id, ) # Forward to the new API implementation return await document_api.update_document(dataset_id=dataset_id, document_id=document_id) @@ -615,9 +587,9 @@ async def deprecated_document_get(doc_id): New path: GET /api/v1/documents/{doc_id}/preview """ logging.warning( - "API endpoint /api/v1/document/get/%s is deprecated. " - "Please use /api/v1/documents/%s/preview instead.", - doc_id, doc_id, + "API endpoint /api/v1/document/get/%s is deprecated. Please use /api/v1/documents/%s/preview instead.", + doc_id, + doc_id, ) return await document_api.get(doc_id) @@ -632,9 +604,9 @@ async def deprecated_document_download(doc_id): New path: GET /api/v1/agents/attachments/{doc_id}/download """ logging.warning( - "API endpoint /api/v1/document/download/%s is deprecated. " - "Please use /api/v1/agents/attachments/%s/download instead.", - doc_id, doc_id, + "API endpoint /api/v1/document/download/%s is deprecated. Please use /api/v1/agents/attachments/%s/download instead.", + doc_id, + doc_id, ) return await agent_api.download_attachment(attachment_id=doc_id) @@ -649,16 +621,18 @@ async def document_download_v1(attachment_id): New path: GET /api/v1/agents/attachments/{attachment_id}/download """ logging.warning( - "API endpoint /v1/document/download/%s is deprecated. " - "Please use /api/v1/agents/attachments/%s/download instead.", - attachment_id, attachment_id, + "API endpoint /v1/document/download/%s is deprecated. Please use /api/v1/agents/attachments/%s/download instead.", + attachment_id, + attachment_id, ) return await agent_api.download_attachment(attachment_id=attachment_id) + # ============================================================================= # Agent Chat API # ============================================================================= + @manager.route("/agents//completions", methods=["POST"]) @login_required @add_tenant_id_to_kwargs @@ -670,12 +644,12 @@ async def deprecated_agent_completions(agent_id, tenant_id=None): New path: POST /api/v1/agents/chat/completions """ logging.warning( - "API endpoint /api/v1/agents/%s/completions is deprecated. " - "Please use /api/v1/agents/chat/completions instead.", + "API endpoint /api/v1/agents/%s/completions is deprecated. Please use /api/v1/agents/chat/completions instead.", agent_id, ) return await agent_api.agent_chat_completion(tenant_id=tenant_id, agent_id=agent_id) + def register_backward_compat_routes(app_instance): """ Register all backward compatibility routes with the app. diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index 39b18a9f09..b16a97c468 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -50,27 +50,25 @@ from api.apps.services.canvas_replica_service import CanvasReplicaService from api.db.services.canvas_service import completion as agent_completion -@manager.route('/templates', methods=['GET']) # noqa: F821 +@manager.route("/templates", methods=["GET"]) # noqa: F821 @login_required def templates(): return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.get_all()]) -@manager.route('/rm', methods=['POST']) # noqa: F821 +@manager.route("/rm", methods=["POST"]) # noqa: F821 @validate_request("canvas_ids") @login_required async def rm(): req = await get_request_json() for i in req["canvas_ids"]: if not UserCanvasService.accessible(i, current_user.id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) UserCanvasService.delete_by_id(i) return get_json_result(data=True) -@manager.route('/set', methods=['POST']) # noqa: F821 +@manager.route("/set", methods=["POST"]) # noqa: F821 @validate_request("dsl", "title") @login_required async def save(): @@ -89,9 +87,7 @@ async def save(): return get_data_error_result(message="Fail to save canvas.") else: if not UserCanvasService.accessible(req["id"], current_user.id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) UserCanvasService.update_by_id(req["id"], req) # save version UserCanvasVersionService.save_or_replace_latest( @@ -112,7 +108,7 @@ async def save(): return get_json_result(data=req) -@manager.route('/get/', methods=['GET']) # noqa: F821 +@manager.route("/get/", methods=["GET"]) # noqa: F821 @login_required def get(canvas_id): if not UserCanvasService.accessible(canvas_id, current_user.id): @@ -135,29 +131,25 @@ def get(canvas_id): return get_json_result(data=c) -@manager.route('/getsse/', methods=['GET']) # type: ignore # noqa: F821 +@manager.route("/getsse/", methods=["GET"]) # type: ignore # noqa: F821 def getsse(canvas_id): - token = request.headers.get('Authorization').split() + token = request.headers.get("Authorization").split() if len(token) != 2: - return get_data_error_result(message='Authorization is not valid!') + return get_data_error_result(message="Authorization is not valid!") token = token[1] objs = APIToken.query(beta=token) if not objs: return get_data_error_result(message='Authentication error: API key is invalid!"') tenant_id = objs[0].tenant_id if not UserCanvasService.query(user_id=tenant_id, id=canvas_id): - return get_json_result( - data=False, - message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR - ) + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) e, c = UserCanvasService.get_by_id(canvas_id) if not e or c.user_id != tenant_id: return get_data_error_result(message="canvas not found.") return get_json_result(data=c.to_dict()) -@manager.route('/completion', methods=['POST']) # noqa: F821 +@manager.route("/completion", methods=["POST"]) # noqa: F821 @validate_request("id") @login_required async def run(): @@ -169,9 +161,7 @@ async def run(): runtime_user_id = req.get("user_id") or tenant_id user_id = str(runtime_user_id) if not await thread_pool_exec(UserCanvasService.accessible, req["id"], tenant_id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) replica_payload = CanvasReplicaService.load_for_run( canvas_id=req["id"], @@ -234,7 +224,7 @@ async def run(): resp.headers.add_header("Connection", "keep-alive") resp.headers.add_header("X-Accel-Buffering", "no") resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - #resp.call_on_close(lambda: canvas.cancel_task()) + # resp.call_on_close(lambda: canvas.cancel_task()) return resp @@ -244,6 +234,7 @@ async def exp_agent_completion(canvas_id): tenant_id = current_user.id req = await get_request_json() return_trace = bool(req.get("return_trace", False)) + async def generate(): trace_items = [] async for answer in agent_completion(tenant_id=tenant_id, agent_id=canvas_id, **req): @@ -280,9 +271,9 @@ async def exp_agent_completion(canvas_id): resp.headers.add_header("X-Accel-Buffering", "no") resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") return resp - -@manager.route('/rerun', methods=['POST']) # noqa: F821 + +@manager.route("/rerun", methods=["POST"]) # noqa: F821 @validate_request("id", "dsl", "component_id") @login_required async def rerun(): @@ -310,7 +301,7 @@ async def rerun(): return get_json_result(data=True) -@manager.route('/cancel/', methods=['PUT']) # noqa: F821 +@manager.route("/cancel/", methods=["PUT"]) # noqa: F821 @login_required def cancel(task_id): try: @@ -320,15 +311,13 @@ def cancel(task_id): return get_json_result(data=True) -@manager.route('/reset', methods=['POST']) # noqa: F821 +@manager.route("/reset", methods=["POST"]) # noqa: F821 @validate_request("id") @login_required async def reset(): req = await get_request_json() if not UserCanvasService.accessible(req["id"], current_user.id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) try: e, user_canvas = UserCanvasService.get_by_id(req["id"]) if not e: @@ -361,7 +350,7 @@ async def upload(canvas_id): return server_error_response(e) -@manager.route('/input_form', methods=['GET']) # noqa: F821 +@manager.route("/input_form", methods=["GET"]) # noqa: F821 @login_required def input_form(): cvs_id = request.args.get("id") @@ -371,9 +360,7 @@ def input_form(): if not e: return get_data_error_result(message="canvas not found.") if not UserCanvasService.query(user_id=current_user.id, id=cvs_id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id) return get_json_result(data=canvas.get_component_input_form(cpn_id)) @@ -381,15 +368,13 @@ def input_form(): return server_error_response(e) -@manager.route('/debug', methods=['POST']) # noqa: F821 +@manager.route("/debug", methods=["POST"]) # noqa: F821 @validate_request("id", "component_id", "params") @login_required async def debug(): req = await get_request_json() if not UserCanvasService.accessible(req["id"], current_user.id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) try: e, user_canvas = UserCanvasService.get_by_id(req["id"]) canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id) @@ -400,7 +385,7 @@ async def debug(): if isinstance(component, LLM): component.set_debug_inputs(req["params"]) - component.invoke(**{k: o["value"] for k,o in req["params"].items()}) + component.invoke(**{k: o["value"] for k, o in req["params"].items()}) outputs = component.output() for k in outputs.keys(): if isinstance(outputs[k], partial): @@ -418,59 +403,39 @@ async def debug(): return server_error_response(e) -@manager.route('/test_db_connect', methods=['POST']) # noqa: F821 +@manager.route("/test_db_connect", methods=["POST"]) # noqa: F821 @validate_request("db_type", "database", "username", "host", "port", "password") @login_required async def test_db_connect(): req = await get_request_json() try: if req["db_type"] in ["mysql", "mariadb"]: - db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"], - password=req["password"]) + db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"], password=req["password"]) elif req["db_type"] == "oceanbase": - db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"], - password=req["password"], charset="utf8mb4") - elif req["db_type"] == 'postgres': - db = PostgresqlDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"], - password=req["password"]) - elif req["db_type"] == 'mssql': + db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"], password=req["password"], charset="utf8mb4") + elif req["db_type"] == "postgres": + db = PostgresqlDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"], password=req["password"]) + elif req["db_type"] == "mssql": import pyodbc - connection_string = ( - f"DRIVER={{ODBC Driver 17 for SQL Server}};" - f"SERVER={req['host']},{req['port']};" - f"DATABASE={req['database']};" - f"UID={req['username']};" - f"PWD={req['password']};" - ) + + connection_string = f"DRIVER={{ODBC Driver 17 for SQL Server}};SERVER={req['host']},{req['port']};DATABASE={req['database']};UID={req['username']};PWD={req['password']};" db = pyodbc.connect(connection_string) cursor = db.cursor() cursor.execute("SELECT 1") cursor.close() - elif req["db_type"] == 'IBM DB2': + elif req["db_type"] == "IBM DB2": import ibm_db - conn_str = ( - f"DATABASE={req['database']};" - f"HOSTNAME={req['host']};" - f"PORT={req['port']};" - f"PROTOCOL=TCPIP;" - f"UID={req['username']};" - f"PWD={req['password']};" - ) - redacted_conn_str = ( - f"DATABASE={req['database']};" - f"HOSTNAME={req['host']};" - f"PORT={req['port']};" - f"PROTOCOL=TCPIP;" - f"UID={req['username']};" - f"PWD=****;" - ) + + conn_str = f"DATABASE={req['database']};HOSTNAME={req['host']};PORT={req['port']};PROTOCOL=TCPIP;UID={req['username']};PWD={req['password']};" + redacted_conn_str = f"DATABASE={req['database']};HOSTNAME={req['host']};PORT={req['port']};PROTOCOL=TCPIP;UID={req['username']};PWD=****;" logging.info(redacted_conn_str) conn = ibm_db.connect(conn_str, "", "") stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1") ibm_db.fetch_assoc(stmt) ibm_db.close(conn) return get_json_result(data="Database Connection Successful!") - elif req["db_type"] == 'trino': + elif req["db_type"] == "trino": + def _parse_catalog_schema(db_name: str): if not db_name: return None, None @@ -481,6 +446,7 @@ async def test_db_connect(): else: catalog_name, schema_name = db_name, "default" return catalog_name, schema_name + try: import trino import os @@ -498,13 +464,7 @@ async def test_db_connect(): auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"]) conn = trino.dbapi.connect( - host=req["host"], - port=int(req["port"] or 8080), - user=req["username"] or "ragflow", - catalog=catalog, - schema=schema or "default", - http_scheme=http_scheme, - auth=auth + host=req["host"], port=int(req["port"] or 8080), user=req["username"] or "ragflow", catalog=catalog, schema=schema or "default", http_scheme=http_scheme, auth=auth ) cur = conn.cursor() cur.execute("SELECT 1") @@ -514,7 +474,7 @@ async def test_db_connect(): return get_json_result(data="Database Connection Successful!") else: return server_error_response("Unsupported database type.") - if req["db_type"] != 'mssql': + if req["db_type"] != "mssql": db.connect() db.close() @@ -523,21 +483,21 @@ async def test_db_connect(): return server_error_response(e) -#api get list version dsl of canvas -@manager.route('/getlistversion/', methods=['GET']) # noqa: F821 +# api get list version dsl of canvas +@manager.route("/getlistversion/", methods=["GET"]) # noqa: F821 @login_required def getlistversion(canvas_id): try: - versions =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1) + versions = sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"] * -1) return get_json_result(data=versions) except Exception as e: return get_data_error_result(message=f"Error getting history files: {e}") -#api get version dsl of canvas -@manager.route('/getversion/', methods=['GET']) # noqa: F821 +# api get version dsl of canvas +@manager.route("/getversion/", methods=["GET"]) # noqa: F821 @login_required -def getversion( version_id): +def getversion(version_id): try: e, version = UserCanvasVersionService.get_by_id(version_id) if version: @@ -546,7 +506,7 @@ def getversion( version_id): return get_json_result(data=f"Error getting history file: {e}") -@manager.route('/list', methods=['GET']) # noqa: F821 +@manager.route("/list", methods=["GET"]) # noqa: F821 @login_required def list_canvas(): keywords = request.args.get("keywords", "") @@ -563,18 +523,14 @@ def list_canvas(): tenants = TenantService.get_joined_tenants_by_user_id(current_user.id) tenants = [m["tenant_id"] for m in tenants] tenants.append(current_user.id) - canvas, total = UserCanvasService.get_by_tenant_ids( - tenants, current_user.id, page_number, - items_per_page, orderby, desc, keywords, canvas_category) + canvas, total = UserCanvasService.get_by_tenant_ids(tenants, current_user.id, page_number, items_per_page, orderby, desc, keywords, canvas_category) else: tenants = owner_ids - canvas, total = UserCanvasService.get_by_tenant_ids( - tenants, current_user.id, 0, - 0, orderby, desc, keywords, canvas_category) + canvas, total = UserCanvasService.get_by_tenant_ids(tenants, current_user.id, 0, 0, orderby, desc, keywords, canvas_category) return get_json_result(data={"canvas": canvas, "total": total}) -@manager.route('/setting', methods=['POST']) # noqa: F821 +@manager.route("/setting", methods=["POST"]) # noqa: F821 @validate_request("id", "title", "permission") @login_required async def setting(): @@ -582,11 +538,9 @@ async def setting(): req["user_id"] = current_user.id if not UserCanvasService.accessible(req["id"], current_user.id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) - e,flow = UserCanvasService.get_by_id(req["id"]) + e, flow = UserCanvasService.get_by_id(req["id"]) if not e: return get_data_error_result(message="canvas not found.") flow = flow.to_dict() @@ -596,11 +550,11 @@ async def setting(): if value := req.get(key): flow[key] = value - num= UserCanvasService.update_by_id(req["id"], flow) + num = UserCanvasService.update_by_id(req["id"], flow) return get_json_result(data=num) -@manager.route('/trace', methods=['GET']) # noqa: F821 +@manager.route("/trace", methods=["GET"]) # noqa: F821 def trace(): cvs_id = request.args.get("canvas_id") msg_id = request.args.get("message_id") @@ -614,14 +568,12 @@ def trace(): logging.exception(e) -@manager.route('//sessions', methods=['GET']) # noqa: F821 +@manager.route("//sessions", methods=["GET"]) # noqa: F821 @login_required def sessions(canvas_id): tenant_id = current_user.id if not UserCanvasService.accessible(canvas_id, tenant_id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) user_id = request.args.get("user_id") page_number = int(request.args.get("page", 1)) @@ -639,18 +591,17 @@ def sessions(canvas_id): if exp_user_id: sess = API4ConversationService.get_names(canvas_id, exp_user_id) return get_json_result(data={"total": len(sess), "sessions": sess}) - + # dsl defaults to True in all cases except for False and false include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false" - total, sess = API4ConversationService.get_list(canvas_id, tenant_id, page_number, items_per_page, orderby, desc, - None, user_id, include_dsl, keywords, from_date, to_date, exp_user_id=exp_user_id) + total, sess = API4ConversationService.get_list(canvas_id, tenant_id, page_number, items_per_page, orderby, desc, None, user_id, include_dsl, keywords, from_date, to_date, exp_user_id=exp_user_id) try: return get_json_result(data={"total": total, "sessions": sess}) except Exception as e: return server_error_response(e) -@manager.route('//sessions', methods=['PUT']) # noqa: F821 +@manager.route("//sessions", methods=["PUT"]) # noqa: F821 @login_required async def set_session(canvas_id): req = await get_request_json() @@ -659,63 +610,51 @@ async def set_session(canvas_id): assert e, "Agent not found." if not isinstance(cvs.dsl, str): cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) - session_id=get_uuid() + session_id = get_uuid() canvas = Canvas(cvs.dsl, tenant_id, canvas_id, canvas_id=cvs.id) canvas.reset() - conv = { - "id": session_id, - "name": req.get("name", ""), - "dialog_id": cvs.id, - "user_id": tenant_id, - "exp_user_id": tenant_id, - "message": [], - "source": "agent", - "dsl": cvs.dsl, - "reference": [] - } + conv = {"id": session_id, "name": req.get("name", ""), "dialog_id": cvs.id, "user_id": tenant_id, "exp_user_id": tenant_id, "message": [], "source": "agent", "dsl": cvs.dsl, "reference": []} API4ConversationService.save(**conv) return get_json_result(data=conv) -@manager.route('//sessions/', methods=['GET']) # noqa: F821 +@manager.route("//sessions/", methods=["GET"]) # noqa: F821 @login_required def get_session(canvas_id, session_id): tenant_id = current_user.id if not UserCanvasService.accessible(canvas_id, tenant_id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) _, conv = API4ConversationService.get_by_id(session_id) return get_json_result(data=conv.to_dict()) -@manager.route('//sessions/', methods=['DELETE']) # noqa: F821 +@manager.route("//sessions/", methods=["DELETE"]) # noqa: F821 @login_required def del_session(canvas_id, session_id): tenant_id = current_user.id if not UserCanvasService.accessible(canvas_id, tenant_id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) return get_json_result(data=API4ConversationService.delete_by_id(session_id)) -@manager.route('/prompts', methods=['GET']) # noqa: F821 +@manager.route("/prompts", methods=["GET"]) # noqa: F821 @login_required def prompts(): from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE - return get_json_result(data={ - "task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER, - "plan_generation": NEXT_STEP, - "reflection": REFLECT, - #"context_summary": SUMMARY4MEMORY, - #"context_ranking": RANK_MEMORY, - "citation_guidelines": CITATION_PROMPT_TEMPLATE - }) + return get_json_result( + data={ + "task_analysis": ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER, + "plan_generation": NEXT_STEP, + "reflection": REFLECT, + # "context_summary": SUMMARY4MEMORY, + # "context_ranking": RANK_MEMORY, + "citation_guidelines": CITATION_PROMPT_TEMPLATE, + } + ) -@manager.route('/download', methods=['GET']) # noqa: F821 +@manager.route("/download", methods=["GET"]) # noqa: F821 async def download(): id = request.args.get("id") created_by = request.args.get("created_by") diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 690d54b954..8f52101f16 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -111,6 +111,7 @@ async def set_api_key(): assert factory in ChatModel, f"Chat model from {factory} is not supported yet." mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=base_url, **extra) try: + async def check_streamly(): async for chunk in mdl.async_chat_streamly( None, @@ -149,7 +150,7 @@ async def set_api_key(): break if req.get("verify", False): - return get_json_result(data={"message": msg, "success": len(msg.strip())==0}) + return get_json_result(data={"message": msg, "success": len(msg.strip()) == 0}) if msg: return get_data_error_result(message=msg) @@ -207,7 +208,10 @@ async def add_llm(): saved_llm_name = llm_name + _LLM_NAME_SUFFIX.get(factory, "") logging.debug( "add_llm: attempting api_key recovery factory=%s llm_name=%s saved_llm_name=%s tenant_id=%s", - factory, llm_name, saved_llm_name, current_user.id, + factory, + llm_name, + saved_llm_name, + current_user.id, ) existing_llms = TenantLLMService.query( tenant_id=current_user.id, @@ -216,21 +220,25 @@ async def add_llm(): ) logging.debug( "add_llm: api_key recovery query matched=%d factory=%s saved_llm_name=%s", - len(existing_llms) if existing_llms else 0, factory, saved_llm_name, + len(existing_llms) if existing_llms else 0, + factory, + saved_llm_name, ) if existing_llms: - existing_api_key, _, _ = TenantLLMService._decode_api_key_config( - existing_llms[0].api_key - ) + existing_api_key, _, _ = TenantLLMService._decode_api_key_config(existing_llms[0].api_key) logging.debug( "add_llm: api_key recovery decoded=%s factory=%s saved_llm_name=%s", - "present" if existing_api_key else "absent", factory, saved_llm_name, + "present" if existing_api_key else "absent", + factory, + saved_llm_name, ) if existing_api_key: req["api_key"] = existing_api_key logging.info( "add_llm: recovered saved api_key from existing record factory=%s saved_llm_name=%s tenant_id=%s", - factory, saved_llm_name, current_user.id, + factory, + saved_llm_name, + current_user.id, ) api_key = req.get("api_key", "x") @@ -335,6 +343,7 @@ async def add_llm(): **extra, ) try: + async def check_streamly(): async for chunk in mdl.async_chat_streamly( None, @@ -387,6 +396,7 @@ async def add_llm(): assert factory in TTSModel, f"TTS model from {factory} is not supported yet." mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url) try: + def drain_tts(): for _ in mdl.tts("Hello~ RAGFlower!"): pass diff --git a/api/apps/restful_apis/bot_api.py b/api/apps/restful_apis/bot_api.py index a1878be692..f2ee95ab21 100644 --- a/api/apps/restful_apis/bot_api.py +++ b/api/apps/restful_apis/bot_api.py @@ -39,8 +39,7 @@ from api.db.services.search_service import SearchService from api.db.services.user_service import UserTenantService from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance from common.misc_utils import thread_pool_exec -from api.utils.api_utils import get_error_data_result, get_json_result, \ - add_tenant_id_to_kwargs, get_result, get_request_json, server_error_response, validate_request +from api.utils.api_utils import get_error_data_result, get_json_result, add_tenant_id_to_kwargs, get_result, get_request_json, server_error_response, validate_request from rag.app.tag import label_question from rag.prompts.template import load_prompt from rag.prompts.generator import cross_languages, keyword_extraction @@ -58,7 +57,7 @@ def _get_sdk_authorization_token(): auth_header = request.headers.get("Authorization", "") if not auth_header.startswith("Bearer "): return "" - return auth_header[len("Bearer "):].strip() + return auth_header[len("Bearer ") :].strip() @manager.route("/chatbots//completions", methods=["POST"]) # noqa: F821 @@ -68,9 +67,7 @@ async def chatbot_completions(dialog_id, tenant_id=None): req = await get_request_json() exists, dialog = DialogService.get_by_id(dialog_id) - if (not exists - or getattr(dialog, "tenant_id", None) != tenant_id - or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value): + if not exists or getattr(dialog, "tenant_id", None) != tenant_id or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value: logger.warning( "Denied chatbot access: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s", "no access to this chatbot", @@ -132,14 +129,13 @@ async def chatbot_completions(dialog_id, tenant_id=None): return None + @manager.route("/chatbots//info", methods=["GET"]) # noqa: F821 @login_required(auth_types=AUTH_BETA) @add_tenant_id_to_kwargs async def chatbots_inputs(dialog_id, tenant_id=None): exists, dialog = await thread_pool_exec(DialogService.get_by_id, dialog_id) - if (not exists - or getattr(dialog, "tenant_id", None) != tenant_id - or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value): + if not exists or getattr(dialog, "tenant_id", None) != tenant_id or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value: request_args = getattr(request, "args", {}) or {} request_user_id = request_args.get("user_id") if hasattr(request_args, "get") else None request_session_id = request_args.get("session_id") if hasattr(request_args, "get") else None @@ -178,6 +174,7 @@ async def agent_bot_completions(agent_id, tenant_id=None): return get_error_data_result(message=f"Can't find agent by ID: {agent_id}") if req.get("stream", True): + async def stream(): try: async for answer in agent_completion(tenant_id, agent_id, **req): @@ -185,14 +182,18 @@ async def agent_bot_completions(agent_id, tenant_id=None): except Exception as e: logging.exception(e) error_result = get_error_data_result(message=str(e) or "Unknown error") - yield "data:" + json.dumps( - { - "event": "message", - "data": {"content": f"Error {error_result['code']}: {error_result['message']}\n\n"}, - **error_result, - }, - ensure_ascii=False, - ) + "\n\n" + yield ( + "data:" + + json.dumps( + { + "event": "message", + "data": {"content": f"Error {error_result['code']}: {error_result['message']}\n\n"}, + **error_result, + }, + ensure_ascii=False, + ) + + "\n\n" + ) resp = Response(stream(), mimetype="text/event-stream") resp.headers.add_header("Cache-control", "no-cache") @@ -217,7 +218,7 @@ async def agent_bot_completions(agent_id, tenant_id=None): line = line.strip() if not line.startswith("data:"): continue - payload = line[len("data:"):].strip() + payload = line[len("data:") :].strip() if not payload: continue try: @@ -270,9 +271,7 @@ async def begin_inputs(agent_id, tenant_id=None): return get_error_data_result(f"Can't find agent by ID: {agent_id}") canvas = Canvas(json.dumps(cvs.dsl), tenant_id, canvas_id=cvs.id) - return get_result( - data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"), - "prologue": canvas.get_prologue(), "mode": canvas.get_mode()}) + return get_result(data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"), "prologue": canvas.get_prologue(), "mode": canvas.get_mode()}) @manager.route("/agentbots//logs/", methods=["GET"]) # noqa: F821 @@ -289,9 +288,10 @@ async def agent_bot_logs(shared_id, message_id): if not token: logger.warning( "agent_bot_logs: missing Authorization header (shared_id=%s message_id=%s)", - shared_id, message_id, + shared_id, + message_id, ) - return get_error_data_result(message='Authorization is not valid!') + return get_error_data_result(message="Authorization is not valid!") # Non-reversible fingerprint of the share token: lets operators correlate # auth-failure log lines for the same token without leaking a guessable # substring of the secret itself. @@ -300,7 +300,8 @@ async def agent_bot_logs(shared_id, message_id): if not objs: logger.warning( "agent_bot_logs: invalid beta token (fingerprint=%s shared_id=%s)", - token_fp, shared_id, + token_fp, + shared_id, ) return get_error_data_result(message='Authentication error: API key is invalid!"') @@ -308,9 +309,10 @@ async def agent_bot_logs(shared_id, message_id): if not agent_id: logger.warning( "agent_bot_logs: APIToken has no dialog_id (tenant_id=%s fingerprint=%s)", - objs[0].tenant_id, token_fp, + objs[0].tenant_id, + token_fp, ) - return get_error_data_result(message='API token is not bound to an agent.') + return get_error_data_result(message="API token is not bound to an agent.") try: binary = await thread_pool_exec(REDIS_CONN.get, f"{agent_id}-{message_id}-logs") @@ -348,9 +350,7 @@ async def ask_about_embedded(tenant_id=None): async for ans in async_ask(req["question"], req["kb_ids"], uid, chat_llm_name=chat_llm_name, search_config=search_config): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as e: - yield "data:" + json.dumps( - {"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, - ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" resp = Response(stream(), mimetype="text/event-stream") @@ -374,8 +374,7 @@ async def retrieval_test_embedded(tenant_id=None): if isinstance(kb_ids, str): kb_ids = [kb_ids] if not kb_ids: - return get_json_result(data=False, message='Please specify dataset firstly.', - code=RetCode.DATA_ERROR) + return get_json_result(data=False, message="Please specify dataset firstly.", code=RetCode.DATA_ERROR) doc_ids = req.get("doc_ids", []) similarity_threshold = float(req.get("similarity_threshold", 0.0)) vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) @@ -443,8 +442,7 @@ async def retrieval_test_embedded(tenant_id=None): tenant_ids.append(tenant.tenant_id) break else: - return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", - code=RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR) e, kb = await thread_pool_exec(KnowledgebaseService.get_by_id, kb_ids[0]) if not e: @@ -467,13 +465,23 @@ async def retrieval_test_embedded(tenant_id=None): labels = label_question(_question, [kb]) ranks = await settings.retriever.retrieval( - _question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, - local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels + _question, + embd_mdl, + tenant_ids, + kb_ids, + page, + size, + similarity_threshold, + vector_similarity_weight, + top, + local_doc_ids, + rerank_mdl=rerank_mdl, + highlight=req.get("highlight"), + rank_feature=labels, ) if use_kg: default_chat_model = await thread_pool_exec(get_tenant_default_model_by_type, kb.tenant_id, LLMType.CHAT) - ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, - LLMBundle(kb.tenant_id, default_chat_model)) + ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, default_chat_model)) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) @@ -492,8 +500,7 @@ async def retrieval_test_embedded(tenant_id=None): return await _retrieval() except Exception as e: if "not_found" in str(e): - return get_json_result(data=False, message="No chunk found! Check the chunk status please!", - code=RetCode.DATA_ERROR) + return get_json_result(data=False, message="No chunk found! Check the chunk status please!", code=RetCode.DATA_ERROR) return server_error_response(e) @@ -552,8 +559,7 @@ async def detail_share_embedded(tenant_id=None): if await thread_pool_exec(SearchService.query, tenant_id=tenant.tenant_id, id=search_id): break else: - return get_json_result(data=False, message="Has no permission for this operation.", - code=RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Has no permission for this operation.", code=RetCode.OPERATING_ERROR) search = await thread_pool_exec(SearchService.get_detail, search_id) if not search: @@ -573,7 +579,7 @@ async def mindmap(tenant_id=None): search_id = req.get("search_id", "") search_app = await thread_pool_exec(SearchService.get_detail, search_id) if search_id else {} - mind_map =await gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {})) + mind_map = await gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {})) if "error" in mind_map: return server_error_response(Exception(mind_map["error"])) return get_json_result(data=mind_map) diff --git a/api/apps/restful_apis/chat_api.py b/api/apps/restful_apis/chat_api.py index 605b473137..d9a613f380 100644 --- a/api/apps/restful_apis/chat_api.py +++ b/api/apps/restful_apis/chat_api.py @@ -27,9 +27,7 @@ from quart import Response, request from api.apps import current_user, login_required from api.apps.restful_apis._generation_params import merge_generation_config, pop_generation_config -from api.db.joint_services.tenant_model_service import ( - get_tenant_default_model_by_type, get_model_config_from_provider_instance, get_api_key, split_model_name -) +from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance, get_api_key, split_model_name from api.db.services.chunk_feedback_service import ChunkFeedbackService from api.db.services.conversation_service import ConversationService, structure_answer from api.db.services.dialog_service import DialogService, async_chat, gen_mindmap @@ -52,6 +50,7 @@ from common.misc_utils import get_uuid, thread_pool_exec from rag.prompts.generator import chunks_format from rag.prompts.template import load_prompt + def _sanitize_json_floats(obj): """Replace NaN/Infinity floats with None so the result is RFC 8259 JSON. @@ -86,8 +85,8 @@ def _sanitize_json_floats(obj): _DEFAULT_PROMPT_CONFIG = { "system": ( - 'You are an intelligent assistant. Please summarize the content of the dataset to answer the question. ' - 'Please list the data in the dataset and answer in detail. When all dataset content is irrelevant to the ' + "You are an intelligent assistant. Please summarize the content of the dataset to answer the question. " + "Please list the data in the dataset and answer in detail. When all dataset content is irrelevant to the " 'question, your answer must include the sentence "The answer you are looking for is not found in the dataset!" ' "Answers need to consider chat history.\n" " Here is the knowledge base:\n" @@ -162,10 +161,7 @@ def _build_session_response(conv: dict) -> dict: async def _ensure_owned_chat(chat_id): - return await thread_pool_exec( - DialogService.query, - tenant_id=current_user.id, id=chat_id, status=StatusEnum.VALID.value - ) + return await thread_pool_exec(DialogService.query, tenant_id=current_user.id, id=chat_id, status=StatusEnum.VALID.value) def _build_default_completion_dialog(): @@ -292,10 +288,11 @@ async def _validate_llm_id(llm_id, tenant_id, llm_setting=None): return None + async def _validate_rerank_id(rerank_id, tenant_id): if not rerank_id: return None - parts = rerank_id.split('@') + parts = rerank_id.split("@") llm_name = parts[0] if llm_name in _DEFAULT_RERANK_MODELS: return None @@ -342,7 +339,7 @@ async def _validate_dataset_ids(dataset_ids, tenant_id): embd_ids = [split_model_name(kb.embd_id)[0] for kb in kbs] if len(set(embd_ids)) > 1: - return f'Datasets use different embedding models: {[kb.embd_id for kb in kbs]}' + return f"Datasets use different embedding models: {[kb.embd_id for kb in kbs]}" return normalized_ids @@ -462,7 +459,14 @@ async def list_chats(): if owner_ids: chats, total = await thread_pool_exec( DialogService.get_by_tenant_ids, - owner_ids, current_user.id, 0, 0, orderby, desc, keywords, **exact_filters, + owner_ids, + current_user.id, + 0, + 0, + orderby, + desc, + keywords, + **exact_filters, ) chats = [chat for chat in chats if chat["tenant_id"] in owner_ids] total = len(chats) @@ -472,12 +476,17 @@ async def list_chats(): else: chats, total = await thread_pool_exec( DialogService.get_by_tenant_ids, - [], current_user.id, page_number, items_per_page, orderby, desc, keywords, **exact_filters, + [], + current_user.id, + page_number, + items_per_page, + orderby, + desc, + keywords, + **exact_filters, ) - return get_json_result( - data={"chats": [_build_chat_response(chat) for chat in chats], "total": total} - ) + return get_json_result(data={"chats": [_build_chat_response(chat) for chat in chats], "total": total}) except Exception as ex: return server_error_response(ex) @@ -490,7 +499,9 @@ async def get_chat(chat_id): for tenant in tenants: if await thread_pool_exec( DialogService.query, - tenant_id=tenant.tenant_id, id=chat_id, status=StatusEnum.VALID.value, + tenant_id=tenant.tenant_id, + id=chat_id, + status=StatusEnum.VALID.value, ): break else: @@ -512,9 +523,7 @@ async def get_chat(chat_id): @login_required async def update_chat(chat_id): if not await _ensure_owned_chat(chat_id): - return get_json_result( - data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR - ) + return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: req = await get_request_json() @@ -596,9 +605,7 @@ async def update_chat(chat_id): @login_required async def patch_chat(chat_id): if not await _ensure_owned_chat(chat_id): - return get_json_result( - data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR - ) + return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: req = await get_request_json() @@ -686,9 +693,7 @@ async def patch_chat(chat_id): @login_required async def delete_chat(chat_id): if not await _ensure_owned_chat(chat_id): - return get_json_result( - data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR - ) + return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: if not DialogService.update_by_id(chat_id, {"status": StatusEnum.INVALID.value}): @@ -708,12 +713,7 @@ async def bulk_delete_chats(): ids = req.get("ids") if not ids: if req.get("delete_all") is True: - ids = [ - chat.id - for chat in DialogService.query( - tenant_id=current_user.id, status=StatusEnum.VALID.value - ) - ] + ids = [chat.id for chat in DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value)] if not ids: return get_json_result(data={}) else: @@ -799,9 +799,7 @@ async def list_sessions(chat_id): session_id = request.args.get("id") name = request.args.get("name") user_id = request.args.get("user_id") - convs = ConversationService.get_list( - chat_id, page_number, items_per_page, orderby, desc, session_id, name, user_id - ) + convs = ConversationService.get_list(chat_id, page_number, items_per_page, orderby, desc, session_id, name, user_id) if items_per_page == 0: convs = [] return get_json_result(data=[_build_session_response(c) for c in convs]) @@ -1045,17 +1043,21 @@ async def transcription(): uploaded = files["file"] ALLOWED_EXTS = { - ".wav", ".mp3", ".m4a", ".aac", - ".flac", ".ogg", ".webm", - ".opus", ".wma", + ".wav", + ".mp3", + ".m4a", + ".aac", + ".flac", + ".ogg", + ".webm", + ".opus", + ".wma", } filename = uploaded.filename or "" suffix = os.path.splitext(filename)[-1].lower() if suffix not in ALLOWED_EXTS: - return get_data_error_result( - message=f"Unsupported audio format: {suffix}. Allowed: {', '.join(sorted(ALLOWED_EXTS))}" - ) + return get_data_error_result(message=f"Unsupported audio format: {suffix}. Allowed: {', '.join(sorted(ALLOWED_EXTS))}") fd, temp_audio_path = tempfile.mkstemp(suffix=suffix) os.close(fd) diff --git a/api/apps/restful_apis/chat_channel_api.py b/api/apps/restful_apis/chat_channel_api.py index 93e389a8aa..0ba9cc5f99 100644 --- a/api/apps/restful_apis/chat_channel_api.py +++ b/api/apps/restful_apis/chat_channel_api.py @@ -37,14 +37,7 @@ def _chat_channel_auth_error(channel_id: str, user_id: str): async def create_chat_channel(): """Create a chat channel bot owned by the current tenant.""" req = await get_request_json() - channel = { - "id": get_uuid(), - "tenant_id": current_user.id, - "name": req["name"], - "channel": req["channel"], - "config": req.get("config") or {}, - "chat_id": req.get("chat_id") or None - } + channel = {"id": get_uuid(), "tenant_id": current_user.id, "name": req["name"], "channel": req["channel"], "config": req.get("config") or {}, "chat_id": req.get("chat_id") or None} ChatChannelService.insert(**channel) e, conn = ChatChannelService.get_by_id(channel["id"]) diff --git a/api/apps/restful_apis/connector_api.py b/api/apps/restful_apis/connector_api.py index b7b09bae0f..8348351bd7 100644 --- a/api/apps/restful_apis/connector_api.py +++ b/api/apps/restful_apis/connector_api.py @@ -551,6 +551,7 @@ async def google_drive_web_oauth_callback(): return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source) + @manager.route("/connectors/google/oauth/web/result", methods=["POST"]) # noqa: F821 @login_required @validate_request("flow_id") @@ -571,13 +572,14 @@ async def poll_google_web_result(): REDIS_CONN.delete(_web_result_cache_key(flow_id, source)) return get_json_result(data={"credentials": result.get("credentials")}) + @manager.route("/connectors/box/oauth/web/start", methods=["POST"]) # noqa: F821 @login_required async def start_box_web_oauth(): req = await get_request_json() client_id = req.get("client_id") - client_secret = req.get("client_secret") + client_secret = req.get("client_secret") redirect_uri = req.get("redirect_uri", BOX_WEB_OAUTH_REDIRECT_URI) if not client_id or not client_secret: @@ -608,18 +610,20 @@ async def start_box_web_oauth(): } REDIS_CONN.set_obj(_web_state_cache_key(flow_id, "box"), cache_payload, WEB_FLOW_TTL_SECS) return get_json_result( - data = { + data={ "flow_id": flow_id, "authorization_url": auth_url, - "expires_in": WEB_FLOW_TTL_SECS,} + "expires_in": WEB_FLOW_TTL_SECS, + } ) + @manager.route("/connectors/box/oauth/web/callback", methods=["GET"]) # noqa: F821 async def box_web_oauth_callback(): flow_id = request.args.get("state") if not flow_id: return await _render_web_oauth_popup("", False, "Missing OAuth parameters.", "box") - + code = request.args.get("code") if not code: return await _render_web_oauth_popup(flow_id, False, "Missing authorization code from Box.", "box") @@ -633,7 +637,7 @@ async def box_web_oauth_callback(): if error: REDIS_CONN.delete(_web_state_cache_key(flow_id, "box")) return await _render_web_oauth_popup(flow_id, False, error_description or "Authorization failed.", "box") - + auth = BoxOAuth( OAuthConfig( client_id=cache_payload.get("client_id"), @@ -656,6 +660,7 @@ async def box_web_oauth_callback(): return await _render_web_oauth_popup(flow_id, True, "Authorization completed successfully.", "box") + @manager.route("/connectors/box/oauth/web/result", methods=["POST"]) # noqa: F821 @login_required @validate_request("flow_id") @@ -670,7 +675,7 @@ async def poll_box_web_result(): cache_raw = json.loads(cache_blob) if cache_raw.get("user_id") != current_user.id: return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.") - + REDIS_CONN.delete(_web_result_cache_key(flow_id, "box")) return get_json_result(data={"credentials": cache_raw}) diff --git a/api/apps/restful_apis/dify_retrieval_api.py b/api/apps/restful_apis/dify_retrieval_api.py index d7c2968606..6ede32cc40 100644 --- a/api/apps/restful_apis/dify_retrieval_api.py +++ b/api/apps/restful_apis/dify_retrieval_api.py @@ -108,7 +108,7 @@ def _parse_retrieval_options(retrieval_setting): return retrieval_setting, similarity_threshold, top -@manager.route('/dify/retrieval', methods=['POST', 'GET']) # noqa: F821 +@manager.route("/dify/retrieval", methods=["POST", "GET"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs async def retrieval(tenant_id): @@ -251,7 +251,6 @@ async def retrieval(tenant_id): doc_ids = [] try: - e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND) @@ -279,17 +278,13 @@ async def retrieval(tenant_id): vector_similarity_weight=0.3, top=top, doc_ids=doc_ids, - rank_feature=label_question(question, [kb]) + rank_feature=label_question(question, [kb]), ) ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], [tenant_id]) if use_kg: model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) - ck = await settings.kg_retriever.retrieval(question, - [tenant_id], - [kb_id], - embd_mdl, - LLMBundle(kb.tenant_id, model_config)) + ck = await settings.kg_retriever.retrieval(question, [tenant_id], [kb_id], embd_mdl, LLMBundle(kb.tenant_id, model_config)) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) @@ -303,30 +298,21 @@ async def retrieval(tenant_id): if not doc: continue c.pop("vector", None) - meta = getattr(doc, 'meta_fields', {}) + meta = getattr(doc, "meta_fields", {}) meta["doc_id"] = c["doc_id"] # Dify expects metadata.document_id for external retrieval sources. meta["document_id"] = c["doc_id"] - records.append({ - "content": c["content_with_weight"], - "score": c["similarity"], - "title": c["docnm_kwd"], - "metadata": meta - }) + records.append({"content": c["content_with_weight"], "score": c["similarity"], "title": c["docnm_kwd"], "metadata": meta}) return jsonify({"records": records}) except Exception as e: if "not_found" in str(e): - return build_error_result( - message='No chunk found! Check the chunk status please!', - code=RetCode.NOT_FOUND - ) + return build_error_result(message="No chunk found! Check the chunk status please!", code=RetCode.NOT_FOUND) logging.exception(e) return build_error_result(message=str(e), code=RetCode.SERVER_ERROR) - - -@manager.route('/dify/retrieval/health', methods=['GET']) # noqa: F821 + + +@manager.route("/dify/retrieval/health", methods=["GET"]) # noqa: F821 async def retrieval_health_check(): """Health check endpoint for Dify external knowledge base connectivity verification.""" return get_json_result(data=True) - diff --git a/api/apps/restful_apis/file2document_api.py b/api/apps/restful_apis/file2document_api.py index 9c466a441d..a148fce2b7 100644 --- a/api/apps/restful_apis/file2document_api.py +++ b/api/apps/restful_apis/file2document_api.py @@ -56,27 +56,31 @@ def _convert_files(file_ids, kb_ids, user_id): e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: continue - doc = DocumentService.insert({ - "id": get_uuid(), - "kb_id": kb.id, - "parser_id": FileService.get_parser(file.type, file.name, kb.parser_id), - "pipeline_id": kb.pipeline_id, - "parser_config": kb.parser_config, - "created_by": user_id, - "type": file.type, - "name": file.name, - "suffix": Path(file.name).suffix.lstrip("."), - "location": file.location, - "size": file.size - }) - File2DocumentService.insert({ - "id": get_uuid(), - "file_id": id, - "document_id": doc.id, - }) + doc = DocumentService.insert( + { + "id": get_uuid(), + "kb_id": kb.id, + "parser_id": FileService.get_parser(file.type, file.name, kb.parser_id), + "pipeline_id": kb.pipeline_id, + "parser_config": kb.parser_config, + "created_by": user_id, + "type": file.type, + "name": file.name, + "suffix": Path(file.name).suffix.lstrip("."), + "location": file.location, + "size": file.size, + } + ) + File2DocumentService.insert( + { + "id": get_uuid(), + "file_id": id, + "document_id": doc.id, + } + ) -@manager.route('/files/link-to-datasets', methods=['POST']) # noqa: F821 +@manager.route("/files/link-to-datasets", methods=["POST"]) # noqa: F821 @login_required @validate_request("file_ids", "kb_ids") async def convert(): @@ -162,9 +166,7 @@ async def convert(): # soon as the background task is scheduled. loop = asyncio.get_running_loop() future = loop.run_in_executor(None, _convert_files, all_file_ids, kb_ids, user_id) - future.add_done_callback( - lambda f: logging.error("_convert_files failed: %s", f.exception()) if f.exception() else None - ) + future.add_done_callback(lambda f: logging.error("_convert_files failed: %s", f.exception()) if f.exception() else None) logger.info( "user_id=%s resource_type=file_to_dataset_link resource_id=batch action=schedule_convert result=scheduled file_ids=%s kb_ids=%s", user_id, diff --git a/api/apps/restful_apis/file_api.py b/api/apps/restful_apis/file_api.py index 2815dd681a..5ffd99dede 100644 --- a/api/apps/restful_apis/file_api.py +++ b/api/apps/restful_apis/file_api.py @@ -69,11 +69,11 @@ async def create_or_upload(tenant_id: str = None): form = await request.form pf_id = form.get("parent_id") files = await request.files - if 'file' not in files: + if "file" not in files: return get_error_argument_result("No file part!") - file_objs = files.getlist('file') + file_objs = files.getlist("file") for file_obj in file_objs: - if file_obj.filename == '': + if file_obj.filename == "": return get_error_argument_result("No file selected!") success, result = await file_api_service.upload_file(tenant_id, pf_id, file_objs) @@ -86,9 +86,7 @@ async def create_or_upload(tenant_id: str = None): if err is not None: return get_error_argument_result(err) - success, result = await file_api_service.create_folder( - tenant_id, req["name"], req.get("parent_id"), req.get("type") - ) + success, result = await file_api_service.create_folder(tenant_id, req["name"], req.get("parent_id"), req.get("type")) if success: return get_result(data=result) else: @@ -198,9 +196,7 @@ async def delete(tenant_id: str = None): errors = result.get("errors", []) return get_json_result( code=RetCode.DATA_ERROR, - message=f"Partially deleted {success_count} files with {len(errors)} errors" - if success_count > 0 - else f"Deleted files failed with {len(errors)} errors", + message=f"Partially deleted {success_count} files with {len(errors)} errors" if success_count > 0 else f"Deleted files failed with {len(errors)} errors", data=result, ) return get_error_data_result(message=result) @@ -209,7 +205,6 @@ async def delete(tenant_id: str = None): return get_error_data_result(message="Internal server error") - @manager.route("/files/move", methods=["POST"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs @@ -254,9 +249,7 @@ async def move(tenant_id: str = None): return get_error_argument_result(err) try: - success, result = await file_api_service.move_files( - tenant_id, req["src_file_ids"], req.get("dest_file_id"), req.get("new_name") - ) + success, result = await file_api_service.move_files(tenant_id, req["src_file_ids"], req.get("dest_file_id"), req.get("new_name")) if success: return get_result(data=result) else: diff --git a/api/apps/restful_apis/memory_api.py b/api/apps/restful_apis/memory_api.py index e080f82c45..b455fca1f8 100644 --- a/api/apps/restful_apis/memory_api.py +++ b/api/apps/restful_apis/memory_api.py @@ -35,12 +35,7 @@ async def create_memory(): req = await get_request_json() t_parsed = time.perf_counter() if timing_enabled else None try: - memory_info = { - "name": req["name"], - "memory_type": req["memory_type"], - "embd_id": req["embd_id"], - "llm_id": req["llm_id"] - } + memory_info = {"name": req["name"], "memory_type": req["memory_type"], "embd_id": req["embd_id"], "llm_id": req["llm_id"]} success, res = await memory_api_service.create_memory(memory_info) if timing_enabled: logging.info( @@ -84,10 +79,26 @@ async def create_memory(): @login_required async def update_memory(memory_id): req = await get_request_json() - new_settings = {k: req[k] for k in [ - "name", "permissions", "llm_id", "embd_id", "memory_type", "memory_size", "forgetting_policy", "temperature", - "avatar", "description", "system_prompt", "user_prompt", "tenant_llm_id", "tenant_embd_id" - ] if k in req} + new_settings = { + k: req[k] + for k in [ + "name", + "permissions", + "llm_id", + "embd_id", + "memory_type", + "memory_size", + "forgetting_policy", + "temperature", + "avatar", + "description", + "system_prompt", + "user_prompt", + "tenant_llm_id", + "tenant_embd_id", + ] + if k in req + } try: success, res = await memory_api_service.update_memory(memory_id, new_settings) if success: @@ -122,9 +133,7 @@ async def delete_memory(memory_id): @manager.route("/memories", methods=["GET"]) # noqa: F821 @login_required async def list_memory(): - filter_params = { - k: request.args.get(k) for k in ["memory_type", "tenant_id", "owner_ids", "storage_type"] if k in request.args - } + filter_params = {k: request.args.get(k) for k in ["memory_type", "tenant_id", "owner_ids", "storage_type"] if k in request.args} keywords = request.args.get("keywords") page = int(request.args.get("page", 1)) page_size = validate_rest_api_page_size(int(request.args.get("page_size", 50))) @@ -155,16 +164,14 @@ async def get_memory_config(memory_id): async def get_memory_messages(memory_id): args = request.args agent_ids = args.getlist("agent_id") - if len(agent_ids) == 1 and ',' in agent_ids[0]: - agent_ids = agent_ids[0].split(',') + if len(agent_ids) == 1 and "," in agent_ids[0]: + agent_ids = agent_ids[0].split(",") keywords = args.get("keywords", "") keywords = keywords.strip() page = int(args.get("page", 1)) page_size = validate_rest_api_page_size(int(args.get("page_size", 50))) try: - res = await memory_api_service.get_memory_messages( - memory_id, agent_ids, keywords, page, page_size - ) + res = await memory_api_service.get_memory_messages(memory_id, agent_ids, keywords, page, page_size) return get_json_result(message=True, data=res) except NotFoundException as not_found_exception: logging.error(not_found_exception) @@ -174,7 +181,7 @@ async def get_memory_messages(memory_id): return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error") -@manager.route("/messages", methods=["POST"]) # noqa: F821 +@manager.route("/messages", methods=["POST"]) # noqa: F821 @login_required @validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response") async def add_message(): @@ -206,7 +213,7 @@ async def add_message(): return get_json_result(message="Some messages failed to add. Detail:" + msg, code=RetCode.SERVER_ERROR) -@manager.route("/messages/:", methods=["DELETE"]) # noqa: F821 +@manager.route("/messages/:", methods=["DELETE"]) # noqa: F821 @login_required async def forget_message(memory_id: str, message_id: int): try: @@ -220,7 +227,7 @@ async def forget_message(memory_id: str, message_id: int): return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error") -@manager.route("/messages/:", methods=["PUT"]) # noqa: F821 +@manager.route("/messages/:", methods=["PUT"]) # noqa: F821 @login_required @validate_request("status") async def update_message(memory_id: str, message_id: int): @@ -243,13 +250,13 @@ async def update_message(memory_id: str, message_id: int): return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error") -@manager.route("/messages/search", methods=["GET"]) # noqa: F821 +@manager.route("/messages/search", methods=["GET"]) # noqa: F821 @login_required async def search_message(): args = request.args memory_ids = args.getlist("memory_id") - if len(memory_ids) == 1 and ',' in memory_ids[0]: - memory_ids = memory_ids[0].split(',') + if len(memory_ids) == 1 and "," in memory_ids[0]: + memory_ids = memory_ids[0].split(",") query = args.get("query") similarity_threshold = float(args.get("similarity_threshold", 0.2)) keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7)) @@ -258,28 +265,19 @@ async def search_message(): session_id = args.get("session_id", "") user_id = args.get("user_id", "") - filter_dict = { - "memory_id": memory_ids, - "agent_id": agent_id, - "session_id": session_id, - "user_id": user_id - } - params = { - "query": query, - "similarity_threshold": similarity_threshold, - "keywords_similarity_weight": keywords_similarity_weight, - "top_n": top_n - } + filter_dict = {"memory_id": memory_ids, "agent_id": agent_id, "session_id": session_id, "user_id": user_id} + params = {"query": query, "similarity_threshold": similarity_threshold, "keywords_similarity_weight": keywords_similarity_weight, "top_n": top_n} res = await memory_api_service.search_message(filter_dict, params) return get_json_result(message=True, data=res) -@manager.route("/messages", methods=["GET"]) # noqa: F821 + +@manager.route("/messages", methods=["GET"]) # noqa: F821 @login_required async def get_messages(): args = request.args memory_ids = args.getlist("memory_id") - if len(memory_ids) == 1 and ',' in memory_ids[0]: - memory_ids = memory_ids[0].split(',') + if len(memory_ids) == 1 and "," in memory_ids[0]: + memory_ids = memory_ids[0].split(",") agent_id = args.get("agent_id", "") session_id = args.get("session_id", "") limit = int(args.get("limit", 10)) @@ -293,7 +291,7 @@ async def get_messages(): return get_json_result(code=RetCode.SERVER_ERROR, message="Internal server error") -@manager.route("/messages/:/content", methods=["GET"]) # noqa: F821 +@manager.route("/messages/:/content", methods=["GET"]) # noqa: F821 @login_required async def get_message_content(memory_id: str, message_id: int): try: diff --git a/api/apps/restful_apis/models_api.py b/api/apps/restful_apis/models_api.py index cac0a1cf9e..c19d0f7d34 100644 --- a/api/apps/restful_apis/models_api.py +++ b/api/apps/restful_apis/models_api.py @@ -188,9 +188,7 @@ async def set_default_models(tenant_id: str): model_type = data["model_type"] try: - success, msg = models_api_service.set_tenant_default_models( - tenant_id, model_provider, model_instance, model_name, model_type - ) + success, msg = models_api_service.set_tenant_default_models(tenant_id, model_provider, model_instance, model_name, model_type) if success: logging.info(f"success: {success}, msg: {msg}") return get_result(message=msg) diff --git a/api/apps/restful_apis/openai_api.py b/api/apps/restful_apis/openai_api.py index 0642dd04ca..4d913dccdd 100644 --- a/api/apps/restful_apis/openai_api.py +++ b/api/apps/restful_apis/openai_api.py @@ -54,6 +54,7 @@ def _validate_llm_id(llm_id, tenant_id, llm_setting=None): import logging from api.utils.reference_metadata_utils import enrich_chunks_with_document_metadata + def _build_reference_chunks(reference, include_metadata=False, metadata_fields=None): chunks = chunks_format(reference) if not include_metadata: @@ -195,6 +196,7 @@ async def _stream_chat_completion_sse( yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" yield "data:[DONE]\n\n" + def _normalize_message_content(content): """Convert OpenAI message content to a string for the dialog layer. diff --git a/api/apps/restful_apis/plugin_api.py b/api/apps/restful_apis/plugin_api.py index 6d53fbc626..ad82d149c2 100644 --- a/api/apps/restful_apis/plugin_api.py +++ b/api/apps/restful_apis/plugin_api.py @@ -21,7 +21,7 @@ from api.utils.api_utils import get_json_result from agent.plugin import GlobalPluginManager -@manager.route('/plugin/tools', methods=['GET']) # noqa: F821 +@manager.route("/plugin/tools", methods=["GET"]) # noqa: F821 @login_required def llm_tools() -> Response: tools = GlobalPluginManager.get_llm_tools() diff --git a/api/apps/restful_apis/provider_api.py b/api/apps/restful_apis/provider_api.py index 3c3d524e95..dea0c03b49 100644 --- a/api/apps/restful_apis/provider_api.py +++ b/api/apps/restful_apis/provider_api.py @@ -618,9 +618,7 @@ def list_instance_models(tenant_id: str = None, provider_id_or_name: str = None, """ supported_only = request.args.get("supported", "").lower() == "true" try: - success, result = provider_api_service.list_instance_models( - tenant_id, provider_id_or_name, instance_id_or_name, supported_only - ) + success, result = provider_api_service.list_instance_models(tenant_id, provider_id_or_name, instance_id_or_name, supported_only) if success: return get_result(data=result) else: @@ -680,9 +678,7 @@ async def update_instance_models(tenant_id: str, provider_id_or_name: str, insta model_name = data["model_name"] model_type = data["model_type"] try: - success, msg = provider_api_service.update_instance_models( - tenant_id, provider_id_or_name, instance_id_or_name, model_name, model_type - ) + success, msg = provider_api_service.update_instance_models(tenant_id, provider_id_or_name, instance_id_or_name, model_name, model_type) if success: return get_result(message=msg) else: @@ -755,9 +751,7 @@ async def add_model_to_instance(tenant_id: str, provider_id_or_name: str, instan extra = data.get("extra", {}) try: - success, result = provider_api_service.add_model_to_instance( - tenant_id, provider_id_or_name, instance_id_or_name, model_name, model_type, max_tokens, extra - ) + success, result = provider_api_service.add_model_to_instance(tenant_id, provider_id_or_name, instance_id_or_name, model_name, model_type, max_tokens, extra) if success: return get_result(message=result) else: @@ -827,9 +821,7 @@ async def enable_or_disable_model(tenant_id: str = None, provider_id_or_name: st return get_error_argument_result(message="status must be 'active' or 'inactive'") try: - success, msg = provider_api_service.update_model_status( - tenant_id, provider_id_or_name, instance_id_or_name, model_name, status - ) + success, msg = provider_api_service.update_model_status(tenant_id, provider_id_or_name, instance_id_or_name, model_name, status) if success: return get_result(message=msg) else: @@ -904,15 +896,14 @@ async def chat_to_model(tenant_id: str = None, provider_id_or_name: str = None, thinking = data.get("thinking", False) try: - success, result = await provider_api_service.chat_to_model( - tenant_id, provider_id_or_name, instance_id_or_name, model_name, message, stream, thinking - ) + success, result = await provider_api_service.chat_to_model(tenant_id, provider_id_or_name, instance_id_or_name, model_name, message, stream, thinking) if not success: return get_error_data_result(message=result) if stream and isinstance(result, dict) and result.get("type") == "stream": # Streaming response using SSE from quart import Response + llm = result["llm"] async def generate(): @@ -925,10 +916,14 @@ async def chat_to_model(tenant_id: str = None, provider_id_or_name: str = None, yield f"data: [MESSAGE]{chunk}\n\n" yield "data: [DONE]\n\n" - return Response(generate(), mimetype="text/event-stream", headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }) + return Response( + generate(), + mimetype="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) # Non-streaming response return get_result(data=result) diff --git a/api/apps/restful_apis/search_api.py b/api/apps/restful_apis/search_api.py index 9e164b17f7..69965b2e28 100644 --- a/api/apps/restful_apis/search_api.py +++ b/api/apps/restful_apis/search_api.py @@ -92,7 +92,7 @@ def list_searches(): search_apps = [s for s in search_apps if s["tenant_id"] in owner_ids] total = len(search_apps) if page_number and items_per_page: - search_apps = search_apps[(page_number - 1) * items_per_page: page_number * items_per_page] + search_apps = search_apps[(page_number - 1) * items_per_page : page_number * items_per_page] return get_json_result(data={"search_apps": search_apps, "total": total}) except Exception as e: return server_error_response(e) @@ -151,9 +151,7 @@ async def update(search_id): return get_data_error_result(message="search_config must be a JSON object") req["search_config"] = {**current_config, **new_config} logging.debug( - "Search update weight: search_id=%s user_id=%s " - "incoming_vector_similarity_weight=%s stored_vector_similarity_weight=%s " - "stored_full_text_weight=%s", + "Search update weight: search_id=%s user_id=%s incoming_vector_similarity_weight=%s stored_vector_similarity_weight=%s stored_full_text_weight=%s", search_id, current_user.id, new_config.get("vector_similarity_weight"), @@ -212,8 +210,7 @@ async def completion(search_id): search_config = search_app.get("search_config", {}) logging.debug( - "Search completion loaded weight: search_id=%s user_id=%s " - "stored_vector_similarity_weight=%s stored_full_text_weight=%s", + "Search completion loaded weight: search_id=%s user_id=%s stored_vector_similarity_weight=%s stored_full_text_weight=%s", search_id, uid, search_config.get("vector_similarity_weight", 0.3), @@ -229,10 +226,14 @@ async def completion(search_id): async for ans in async_ask(req["question"], kb_ids, uid, search_config=search_config, search_id=search_id): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as ex: - yield "data:" + json.dumps( - {"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}}, - ensure_ascii=False, - ) + "\n\n" + yield ( + "data:" + + json.dumps( + {"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}}, + ensure_ascii=False, + ) + + "\n\n" + ) yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" resp = Response(stream(), mimetype="text/event-stream") diff --git a/api/apps/restful_apis/stats_api.py b/api/apps/restful_apis/stats_api.py index 7185194327..42be01d77b 100644 --- a/api/apps/restful_apis/stats_api.py +++ b/api/apps/restful_apis/stats_api.py @@ -20,7 +20,8 @@ from api.db.services.user_service import UserTenantService from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response from api.apps import login_required, current_user -@manager.route('/system/stats', methods=['GET']) # noqa: F821 + +@manager.route("/system/stats", methods=["GET"]) # noqa: F821 @login_required def stats(): try: @@ -29,15 +30,10 @@ def stats(): return get_data_error_result(message="Tenant not found!") objs = API4ConversationService.stats( tenants[0].tenant_id, - request.args.get( - "from_date", - (datetime.now() - - timedelta( - days=7)).strftime("%Y-%m-%d 00:00:00")), - request.args.get( - "to_date", - datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - "agent" if "canvas_id" in request.args else None) + request.args.get("from_date", (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d 00:00:00")), + request.args.get("to_date", datetime.now().strftime("%Y-%m-%d %H:%M:%S")), + "agent" if "canvas_id" in request.args else None, + ) res = {"pv": [], "uv": [], "speed": [], "tokens": [], "round": [], "thumb_up": []} @@ -45,8 +41,8 @@ def stats(): dt = obj["dt"] res["pv"].append((dt, obj["pv"])) res["uv"].append((dt, obj["uv"])) - res["speed"].append((dt, float(obj["tokens"]) / (float(obj["duration"]) + 0.1))) # +0.1 to avoid division by zero - res["tokens"].append((dt, float(obj["tokens"]) / 1000.0)) # convert to thousands + res["speed"].append((dt, float(obj["tokens"]) / (float(obj["duration"]) + 0.1))) # +0.1 to avoid division by zero + res["tokens"].append((dt, float(obj["tokens"]) / 1000.0)) # convert to thousands res["round"].append((dt, obj["round"])) res["thumb_up"].append((dt, obj["thumb_up"])) diff --git a/api/apps/restful_apis/system_api.py b/api/apps/restful_apis/system_api.py index 55c34c25a3..efc011bd87 100644 --- a/api/apps/restful_apis/system_api.py +++ b/api/apps/restful_apis/system_api.py @@ -34,10 +34,12 @@ from common.log_utils import get_log_levels, set_log_level from common import settings from rag.utils.redis_conn import REDIS_CONN + @manager.route("/system/ping", methods=["GET"]) # noqa: F821 async def ping(): return "pong", 200 + @manager.route("/system/version", methods=["GET"]) # noqa: F821 def version(): """ @@ -196,13 +198,7 @@ def oceanbase_status(): status_info = get_oceanbase_status() return get_json_result(data=status_info) except Exception as e: - return get_json_result( - data={ - "status": "error", - "message": f"Failed to get OceanBase status: {str(e)}" - }, - code=500 - ) + return get_json_result(data={"status": "error", "message": f"Failed to get OceanBase status: {str(e)}"}, code=500) @manager.route("/system/config", methods=["GET"]) # noqa: F821 @@ -222,16 +218,20 @@ def get_config(): type: integer 0 means disabled, 1 means enabled description: Whether user registration is enabled """ - return get_json_result(data={ - "registerEnabled": settings.REGISTER_ENABLED, - "disablePasswordLogin": settings.DISABLE_PASSWORD_LOGIN, - }) + return get_json_result( + data={ + "registerEnabled": settings.REGISTER_ENABLED, + "disablePasswordLogin": settings.DISABLE_PASSWORD_LOGIN, + } + ) + @manager.route("/system/healthz", methods=["GET"]) # noqa: F821 def healthz(): result, all_ok = run_health_checks() return jsonify(result), (200 if all_ok else 500) + @manager.route("/system/tokens", methods=["GET"]) # noqa: F821 @login_required def token_list(): @@ -409,6 +409,7 @@ async def set_logger_level(): description: Log level updated successfully """ from quart import request + data = await request.get_json() if not data or "pkg_name" not in data or "level" not in data: return get_data_error_result(message="pkg_name and level are required") diff --git a/api/apps/restful_apis/task_api.py b/api/apps/restful_apis/task_api.py index 2bd7a41802..8566c5583d 100644 --- a/api/apps/restful_apis/task_api.py +++ b/api/apps/restful_apis/task_api.py @@ -30,8 +30,7 @@ from rag.utils.redis_conn import REDIS_CONN @manager.route("/tasks//cancel", methods=["POST"]) # noqa: F821 @login_required async def cancel_task(task_id): - """Cancel a running task. - """ + """Cancel a running task.""" return await _cancel_task(task_id) @@ -77,11 +76,7 @@ async def _cancel_task(task_id): TaskService.model.update( progress_msg=TaskService.model.progress_msg + cancel_msg, progress=-1, - ).where( - (TaskService.model.id == task_id) - & (TaskService.model.progress >= 0) - & (TaskService.model.progress < 1) - ).execute() + ).where((TaskService.model.id == task_id) & (TaskService.model.progress >= 0) & (TaskService.model.progress < 1)).execute() except Exception as e: logging.warning("Failed to update task %s progress after cancellation: %s", task_id, str(e)) @@ -89,6 +84,7 @@ async def _cancel_task(task_id): # cancelled so that the UI reflects the state correctly. try: from api.db.services.document_service import DocumentService + doc_id = task.doc_id if doc_id and doc_id not in (CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID): _, doc = DocumentService.get_by_id(doc_id) diff --git a/api/apps/restful_apis/tenant_api.py b/api/apps/restful_apis/tenant_api.py index 56f40444a5..a0a69a3f4d 100644 --- a/api/apps/restful_apis/tenant_api.py +++ b/api/apps/restful_apis/tenant_api.py @@ -82,9 +82,7 @@ async def create(tenant_id): return get_data_error_result(message=f"{invite_user_email} is already in the team.") if user_tenant_role == UserTenantRole.OWNER: return get_data_error_result(message=f"{invite_user_email} is the owner of the team.") - return get_data_error_result( - message=f"{invite_user_email} is in the team, but the role: {user_tenant_role} is invalid." - ) + return get_data_error_result(message=f"{invite_user_email} is in the team, but the role: {user_tenant_role} is invalid.") UserTenantService.save( id=get_uuid(), diff --git a/api/apps/services/canvas_replica_service.py b/api/apps/services/canvas_replica_service.py index 17b6c99cb0..b4d28807e0 100644 --- a/api/apps/services/canvas_replica_service.py +++ b/api/apps/services/canvas_replica_service.py @@ -42,7 +42,6 @@ class CanvasReplicaService: LOCK_RETRY_ATTEMPTS = 3 LOCK_RETRY_SLEEP_SECS = 0.2 - @classmethod def normalize_dsl(cls, dsl): """Normalize DSL to a JSON-serializable dict. Raise ValueError on invalid input.""" @@ -61,17 +60,14 @@ class CanvasReplicaService: except Exception as e: raise ValueError("DSL is not JSON-serializable.") from e - @classmethod def _replica_key(cls, canvas_id: str, tenant_id: str, runtime_user_id: str) -> str: return f"{cls.REPLICA_KEY_PREFIX}:{canvas_id}:{tenant_id}:{runtime_user_id}" - @classmethod def _lock_key(cls, canvas_id: str, tenant_id: str, runtime_user_id: str) -> str: return f"{cls.LOCK_KEY_PREFIX}:{canvas_id}:{tenant_id}:{runtime_user_id}" - @classmethod def _read_payload(cls, replica_key: str): """Read replica payload from Redis; return None on missing/invalid content.""" @@ -88,14 +84,12 @@ class CanvasReplicaService: logging.warning("Failed to parse canvas replica %s: %s", replica_key, e) return None - @classmethod def _write_payload(cls, replica_key: str, payload: dict): """Write payload and refresh TTL.""" payload["updated_at"] = int(time.time()) REDIS_CONN.set_obj(replica_key, payload, cls.TTL_SECS) - @classmethod def _build_payload( cls, @@ -116,7 +110,6 @@ class CanvasReplicaService: "updated_at": int(time.time()), } - @classmethod def create_if_absent( cls, @@ -136,7 +129,6 @@ class CanvasReplicaService: cls._write_payload(replica_key, payload) return payload - @classmethod def bootstrap( cls, @@ -157,14 +149,12 @@ class CanvasReplicaService: title=title, ) - @classmethod def load_for_run(cls, canvas_id: str, tenant_id: str, runtime_user_id: str): """Load current runtime replica used by /completions.""" replica_key = cls._replica_key(canvas_id, str(tenant_id), str(runtime_user_id)) return cls._read_payload(replica_key) - @classmethod def replace_for_set( cls, @@ -203,7 +193,6 @@ class CanvasReplicaService: except Exception: logging.exception("Failed to release canvas replica lock: %s", lock_key) - @classmethod def _acquire_lock_with_retry(cls, lock_key: str): """Acquire distributed lock with bounded retries; return lock object or None.""" @@ -219,7 +208,6 @@ class CanvasReplicaService: time.sleep(cls.LOCK_RETRY_SLEEP_SECS + random.uniform(0, 0.1)) return None - @classmethod def commit_after_run( cls, diff --git a/api/apps/services/file_api_service.py b/api/apps/services/file_api_service.py index 3c0218ad2b..61e11335cd 100644 --- a/api/apps/services/file_api_service.py +++ b/api/apps/services/file_api_service.py @@ -48,15 +48,15 @@ async def upload_file(tenant_id: str, pf_id: str, file_objs: list): file_res = [] for file_obj in file_objs: - MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0)) + MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0)) if 0 < MAX_FILE_NUM_PER_USER <= await thread_pool_exec(DocumentService.get_doc_count, tenant_id): return False, "Exceed the maximum file number of a free user!" if not file_obj.filename: file_obj_names = [pf_folder.name, file_obj.filename] else: - full_path = '/' + file_obj.filename - file_obj_names = full_path.split('/') + full_path = "/" + file_obj.filename + file_obj_names = full_path.split("/") file_len = len(file_obj_names) file_id_list = await thread_pool_exec(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id]) @@ -66,25 +66,19 @@ async def upload_file(tenant_id: str, pf_id: str, file_objs: list): e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 1]) if not e: return False, "Folder not found!" - last_folder = await thread_pool_exec( - FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names, len_id_list, tenant_id, tenant_id - ) + last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names, len_id_list, tenant_id, tenant_id) else: e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 2]) if not e: return False, "Folder not found!" - last_folder = await thread_pool_exec( - FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names, len_id_list, tenant_id, tenant_id - ) + last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names, len_id_list, tenant_id, tenant_id) filetype = filename_type(file_obj_names[file_len - 1]) location = file_obj_names[file_len - 1] while await thread_pool_exec(settings.STORAGE_IMPL.obj_exist, last_folder.id, location): location += "_" blob = await thread_pool_exec(file_obj.read) - filename = await thread_pool_exec( - duplicate_name, FileService.query, name=file_obj_names[file_len - 1], parent_id=last_folder.id - ) + filename = await thread_pool_exec(duplicate_name, FileService.query, name=file_obj_names[file_len - 1], parent_id=last_folder.id) await thread_pool_exec(settings.STORAGE_IMPL.put, last_folder.id, location, blob) file_data = { "id": get_uuid(), @@ -126,16 +120,18 @@ async def create_folder(tenant_id: str, name: str, pf_id: str = None, file_type: else: ft = FileType.VIRTUAL.value - file = FileService.insert({ - "id": get_uuid(), - "parent_id": pf_id, - "tenant_id": tenant_id, - "created_by": tenant_id, - "name": name, - "location": "", - "size": 0, - "type": ft, - }) + file = FileService.insert( + { + "id": get_uuid(), + "parent_id": pf_id, + "tenant_id": tenant_id, + "created_by": tenant_id, + "name": name, + "location": "", + "size": 0, + "type": ft, + } + ) return True, file.to_json() @@ -173,7 +169,6 @@ def list_files(tenant_id: str, args: dict): return True, {"total": total, "files": files, "parent_folder": parent_folder.to_json()} - def get_parent_folder(file_id: str, user_id: str = None): """ Get parent folder of a file with permission check. @@ -235,9 +230,9 @@ async def delete_files(uid: str, file_ids: list, auth_header: str = ""): try: import requests - host = getattr(settings, 'HOST_IP', '127.0.0.1') + host = getattr(settings, "HOST_IP", "127.0.0.1") # Go service runs on port+4 (9384 by default) - port = getattr(settings, 'HOST_PORT', 9380) + 4 + port = getattr(settings, "HOST_PORT", 9380) + 4 service_url = f"http://{host}:{port}" # List all spaces and find the one matching the name @@ -270,9 +265,9 @@ async def delete_files(uid: str, file_ids: list, auth_header: str = ""): from urllib.parse import quote # Construct service URL from settings - host = getattr(settings, 'HOST_IP', '127.0.0.1') + host = getattr(settings, "HOST_IP", "127.0.0.1") # Go service runs on port+4 (9384 by default) - port = getattr(settings, 'HOST_PORT', 9380) + 4 + port = getattr(settings, "HOST_PORT", 9380) + 4 service_url = f"http://{host}:{port}" # Get space UUID from space name @@ -289,37 +284,24 @@ async def delete_files(uid: str, file_ids: list, auth_header: str = ""): try: data = response.json() if data.get("code") == 0: - logging.info( - f"Successfully deleted skill index: space={space_name}, skill={skill_name}, " - f"status={response.status_code}, code=0" - ) + logging.info(f"Successfully deleted skill index: space={space_name}, skill={skill_name}, status={response.status_code}, code=0") return True else: app_code = data.get("code", "unknown") app_msg = data.get("message", "no message") logging.error( - f"Failed to delete skill index: space={space_name}, skill={skill_name}, " - f"status={response.status_code}, app_code={app_code}, app_msg={app_msg}, " - f"response={response.text}" + f"Failed to delete skill index: space={space_name}, skill={skill_name}, status={response.status_code}, app_code={app_code}, app_msg={app_msg}, response={response.text}" ) return False except ValueError as json_err: # JSON decode error - treat as failure - logging.error( - f"Failed to parse delete response JSON: space={space_name}, skill={skill_name}, " - f"error={json_err}, raw_response={response.text}" - ) + logging.error(f"Failed to parse delete response JSON: space={space_name}, skill={skill_name}, error={json_err}, raw_response={response.text}") return False else: - logging.error( - f"Failed to delete skill index: space={space_name}, skill={skill_name}, " - f"status={response.status_code}, response={response.text}" - ) + logging.error(f"Failed to delete skill index: space={space_name}, skill={skill_name}, status={response.status_code}, response={response.text}") return False except Exception as e: - logging.error( - f"Exception deleting skill index: space={space_name}, skill={skill_name}, error={e}" - ) + logging.error(f"Exception deleting skill index: space={space_name}, skill={skill_name}, error={e}") return False def _delete_single_file(file) -> int: @@ -408,18 +390,12 @@ async def delete_files(uid: str, file_ids: list, auth_header: str = ""): logging.info(f"Deleting skill index for skill '{folder.name}' in space '{current_space_name}'") index_deleted = _delete_skill_index(tenant_id, current_space_name, folder.name, auth_header) if not index_deleted: - logging.error( - f"Aborting folder deletion due to index deletion failure: " - f"folder={folder.name}, space={current_space_name}" - ) - errors.append( - f"Failed to delete skill index for folder '{folder.name}' in space '{current_space_name}'. " - f"Folder deletion aborted to prevent orphaned indexes." - ) + logging.error(f"Aborting folder deletion due to index deletion failure: folder={folder.name}, space={current_space_name}") + errors.append(f"Failed to delete skill index for folder '{folder.name}' in space '{current_space_name}'. Folder deletion aborted to prevent orphaned indexes.") return deleted sub_files = FileService.list_all_files_by_parent_id(folder.id) logging.info(f"Folder '{folder.name}': found {len(sub_files)} children to delete") - + for sub_file in sub_files: if sub_file.type == FileType.FOLDER.value: deleted += _delete_folder_recursive(sub_file, tenant_id) @@ -432,16 +408,16 @@ async def delete_files(uid: str, file_ids: list, auth_header: str = ""): errors.append(f"Failed to delete folder record {folder.id}: {e}") else: deleted += 1 - + try: - if hasattr(settings.STORAGE_IMPL, 'remove_bucket'): + if hasattr(settings.STORAGE_IMPL, "remove_bucket"): logging.info(f"Removing storage bucket for folder '{folder.name}' (id={folder.id})") settings.STORAGE_IMPL.remove_bucket(folder.id) else: logging.debug(f"Storage implementation does not support remove_bucket, skipping for folder '{folder.name}'") except Exception as e: logging.warning(f"Failed to remove storage bucket for folder '{folder.name}' (id={folder.id}): {e}") - + return deleted def _rm_sync(): @@ -513,8 +489,7 @@ async def move_files(uid: str, src_file_ids: list, dest_file_id: str = None, new if new_name: file = files_dict[src_file_ids[0]] - if file.type != FileType.FOLDER.value and \ - pathlib.Path(new_name.lower()).suffix != pathlib.Path(file.name.lower()).suffix: + if file.type != FileType.FOLDER.value and pathlib.Path(new_name.lower()).suffix != pathlib.Path(file.name.lower()).suffix: return False, "The extension of file can't be changed" target_parent_id = dest_folder.id if dest_folder else file.parent_id for f in FileService.query(name=new_name, parent_id=target_parent_id): @@ -541,16 +516,18 @@ async def move_files(uid: str, src_file_ids: list, dest_file_id: str = None, new if existing_folder: new_folder = existing_folder[0] else: - new_folder = FileService.insert({ - "id": get_uuid(), - "parent_id": dest_folder_entry.id, - "tenant_id": source_file_entry.tenant_id, - "created_by": source_file_entry.tenant_id, - "name": effective_name, - "location": "", - "size": 0, - "type": FileType.FOLDER.value, - }) + new_folder = FileService.insert( + { + "id": get_uuid(), + "parent_id": dest_folder_entry.id, + "tenant_id": source_file_entry.tenant_id, + "created_by": source_file_entry.tenant_id, + "name": effective_name, + "location": "", + "size": 0, + "type": FileType.FOLDER.value, + } + ) sub_files = FileService.list_all_files_by_parent_id(source_file_entry.id) for sub_file in sub_files: @@ -569,8 +546,10 @@ async def move_files(uid: str, src_file_ids: list, dest_file_id: str = None, new new_location += "_" try: moved = settings.STORAGE_IMPL.move( - source_file_entry.parent_id, source_file_entry.location, - dest_folder_entry.id, new_location, + source_file_entry.parent_id, + source_file_entry.location, + dest_folder_entry.id, + new_location, ) except Exception as storage_err: raise RuntimeError(f"Move file failed at storage layer: {str(storage_err)}") diff --git a/api/apps/services/memory_api_service.py b/api/apps/services/memory_api_service.py index 7d5955407a..a12eb15946 100644 --- a/api/apps/services/memory_api_service.py +++ b/api/apps/services/memory_api_service.py @@ -98,13 +98,7 @@ async def create_memory(memory_info: dict): if invalid_type: raise ArgumentException(f"Memory type '{invalid_type}' is not supported.") memory_type = list(memory_type) - success, res = MemoryService.create_memory( - tenant_id=current_user.id, - name=memory_name, - memory_type=memory_type, - embd_id=memory_info["embd_id"], - llm_id=memory_info["llm_id"] - ) + success, res = MemoryService.create_memory(tenant_id=current_user.id, name=memory_name, memory_type=memory_type, embd_id=memory_info["embd_id"], llm_id=memory_info["llm_id"]) if success: return True, format_ret_data_from_memory(res) else: @@ -243,7 +237,7 @@ async def delete_memory(memory_id): return True -async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size: int = 50): +async def list_memory(filter_params: dict, keywords: str, page: int = 1, page_size: int = 50): """ :param filter_params: { "memory_type": list[str], @@ -268,9 +262,7 @@ async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size) [memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list] - return { - "memory_list": memory_list, "total_count": count - } + return {"memory_list": memory_list, "total_count": count} async def get_memory_config(memory_id): @@ -280,10 +272,9 @@ async def get_memory_config(memory_id): return format_ret_data_from_memory(memory) -async def get_memory_messages(memory_id, agent_ids: list[str], keywords: str, page: int=1, page_size: int = 50): +async def get_memory_messages(memory_id, agent_ids: list[str], keywords: str, page: int = 1, page_size: int = 50): memory = _require_memory_access(memory_id) - messages = MessageService.list_message( - memory.tenant_id, memory_id, agent_ids, keywords, page, page_size) + messages = MessageService.list_message(memory.tenant_id, memory_id, agent_ids, keywords, page, page_size) agent_name_mapping = {} extract_task_mapping = {} if messages["message_list"]: @@ -291,7 +282,7 @@ async def get_memory_messages(memory_id, agent_ids: list[str], keywords: str, pa agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list} task_list = TaskService.get_tasks_progress_by_doc_ids([memory_id]) if task_list: - task_list.sort(key=lambda t: t["create_time"]) # asc, use newer when exist more than one task + task_list.sort(key=lambda t: t["create_time"]) # asc, use newer when exist more than one task for task in task_list: # the 'digest' field carries the source_id when a task is created, so use 'digest' as key extract_task_mapping.update({int(task["digest"]): task}) @@ -324,10 +315,7 @@ async def forget_message(memory_id: str, message_id: int): memory = _require_memory_access(memory_id) forget_time = timestamp_to_date(current_timestamp()) - update_succeed = MessageService.update_message( - {"memory_id": memory_id, "message_id": int(message_id)}, - {"forget_at": forget_time}, - memory.tenant_id, memory_id) + update_succeed = MessageService.update_message({"memory_id": memory_id, "message_id": int(message_id)}, {"forget_at": forget_time}, memory.tenant_id, memory_id) if update_succeed: return True raise Exception(f"Failed to forget message '{message_id}' in memory '{memory_id}'.") @@ -336,10 +324,7 @@ async def forget_message(memory_id: str, message_id: int): async def update_message_status(memory_id: str, message_id: int, status: bool): memory = _require_memory_access(memory_id) - update_succeed = MessageService.update_message( - {"memory_id": memory_id, "message_id": int(message_id)}, - {"status": status}, - memory.tenant_id, memory_id) + update_succeed = MessageService.update_message({"memory_id": memory_id, "message_id": int(message_id)}, {"status": status}, memory.tenant_id, memory_id) if update_succeed: return True raise Exception(f"Failed to set status for message '{message_id}' in memory '{memory_id}'.") @@ -383,13 +368,7 @@ async def get_messages(memory_ids: list[str], agent_id: str = "", session_id: st return [] uids = [memory.tenant_id for memory in memory_list] accessible_memory_ids = [memory.id for memory in memory_list] - res = MessageService.get_recent_messages( - uids, - accessible_memory_ids, - agent_id, - session_id, - limit - ) + res = MessageService.get_recent_messages(uids, accessible_memory_ids, agent_id, session_id, limit) return res diff --git a/api/apps/services/models_api_service.py b/api/apps/services/models_api_service.py index aaef7f90ca..18cee06d0d 100644 --- a/api/apps/services/models_api_service.py +++ b/api/apps/services/models_api_service.py @@ -98,11 +98,7 @@ def _get_model_info(tenant_id: str, default_model: str, model_type: str): # Special case: TEI Builtin embedding model compose_profiles = os.getenv("COMPOSE_PROFILES", "") tei_model = os.getenv("TEI_MODEL", "") - if (model_type == "embedding" - and "tei-" in compose_profiles - and tei_model - and model_name == tei_model - and (not provider_name or provider_name == "Builtin")): + if model_type == "embedding" and "tei-" in compose_profiles and tei_model and model_name == tei_model and (not provider_name or provider_name == "Builtin"): return { "model_provider": "Builtin", "model_instance": "default", @@ -124,9 +120,7 @@ def _get_model_info(tenant_id: str, default_model: str, model_type: str): return None # Check if model is enabled (no TenantModel record or status != inactive means enabled) - model_entity = TenantModelService.get_by_provider_id_and_instance_id_and_model_type_and_model_name( - provider_obj.id, instance_obj.id, model_type, model_name - ) + model_entity = TenantModelService.get_by_provider_id_and_instance_id_and_model_type_and_model_name(provider_obj.id, instance_obj.id, model_type, model_name) enable = model_entity is None or model_entity.status == ActiveStatusEnum.ACTIVE.value if not enable: @@ -134,12 +128,12 @@ def _get_model_info(tenant_id: str, default_model: str, model_type: str): if model_entity: return { - "model_provider": provider_name, - "model_instance": instance_name, - "model_name": model_name, - "model_type": model_type, - "enable": enable, - } + "model_provider": provider_name, + "model_instance": instance_name, + "model_name": model_name, + "model_type": model_type, + "enable": enable, + } # Check if model is in the LLM factory info factory_info = [f for f in (FACTORY_LLM_INFOS or []) if f["name"] == provider_name] @@ -185,12 +179,7 @@ def _check_model_available(tenant_id: str, provider_name: str, instance_name: st return True, None compose_profiles = os.getenv("COMPOSE_PROFILES", "") - is_tei_builtin_embedding = ( - model_type == LLMType.EMBEDDING.value - and "tei-" in compose_profiles - and model_name == os.getenv("TEI_MODEL", "") - and (provider_name == "Builtin" or not provider_name) - ) + is_tei_builtin_embedding = model_type == LLMType.EMBEDDING.value and "tei-" in compose_profiles and model_name == os.getenv("TEI_MODEL", "") and (provider_name == "Builtin" or not provider_name) if is_tei_builtin_embedding: return True, None @@ -210,9 +199,7 @@ def _check_model_available(tenant_id: str, provider_name: str, instance_name: st return False, f"Provider '{provider_name}' not found in factory info" model_type = MODEL_TAG_TO_TYPE.get(model_type, model_type) # Check if model is disabled - model_entity = TenantModelService.get_by_provider_id_and_instance_id_and_model_type_and_model_name( - provider_obj.id, instance_obj.id, model_type, model_name - ) + model_entity = TenantModelService.get_by_provider_id_and_instance_id_and_model_type_and_model_name(provider_obj.id, instance_obj.id, model_type, model_name) if model_entity: if model_entity.status != ActiveStatusEnum.ACTIVE.value: return False, f"Model '{model_name}' isn't available" @@ -297,7 +284,7 @@ def set_tenant_default_models(tenant_id: str, model_provider: str, model_instanc return True, "success" -def list_tenant_added_models(tenant_id: str, model_type_filter: str=None): +def list_tenant_added_models(tenant_id: str, model_type_filter: str = None): """ List all added models for a tenant. @@ -369,14 +356,16 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None): if not model_types: continue - added_models.append({ - "model_type": model_types, - "name": llm["llm_name"], - "provider_id": factory_instance.provider_id, - "provider_name": provider_info_map[factory_instance.provider_id].provider_name if provider_info_map.get(factory_instance.provider_id) else "", - "instance_id": factory_instance.id, - "instance_name": factory_instance.instance_name - }) + added_models.append( + { + "model_type": model_types, + "name": llm["llm_name"], + "provider_id": factory_instance.provider_id, + "provider_name": provider_info_map[factory_instance.provider_id].provider_name if provider_info_map.get(factory_instance.provider_id) else "", + "instance_id": factory_instance.id, + "instance_name": factory_instance.instance_name, + } + ) manual_added_model_record_keys = list(set(model_record_map.keys()) - set(model_key_in_factory)) if manual_added_model_record_keys: @@ -390,33 +379,34 @@ def list_tenant_added_models(tenant_id: str, model_type_filter: str=None): if not model_types: continue - added_models.append({ - "model_type": model_types, - "name": model_name, - "provider_id": provider_id, - "provider_name": provider_info_map[provider_id].provider_name if provider_info_map.get(provider_id) else "", - "instance_id": instance_id, - "instance_name": instance_info_map[instance_id].instance_name if instance_info_map.get(instance_id) else "" - }) + added_models.append( + { + "model_type": model_types, + "name": model_name, + "provider_id": provider_id, + "provider_name": provider_info_map[provider_id].provider_name if provider_info_map.get(provider_id) else "", + "instance_id": instance_id, + "instance_name": instance_info_map[instance_id].instance_name if instance_info_map.get(instance_id) else "", + } + ) # Add TEI Builtin embedding model if configured compose_profiles = os.getenv("COMPOSE_PROFILES", "") tei_model = os.getenv("TEI_MODEL", "") if "tei-" in compose_profiles and tei_model: if not model_type_filter or model_type_filter == "embedding": - tei_already_added = any( - m["provider_name"] == "Builtin" and m["name"] == tei_model - for m in added_models - ) + tei_already_added = any(m["provider_name"] == "Builtin" and m["name"] == tei_model for m in added_models) if not tei_already_added: - added_models.append({ - "model_type": ["embedding"], - "name": tei_model, - "provider_id": "", - "provider_name": "Builtin", - "instance_id": "", - "instance_name": "default", - }) + added_models.append( + { + "model_type": ["embedding"], + "name": tei_model, + "provider_id": "", + "provider_name": "Builtin", + "instance_id": "", + "instance_name": "default", + } + ) added_models.sort(key=lambda x: (factory_rank_mapping.get(x["provider_name"]), x["provider_name"], x["instance_name"])) diff --git a/api/channels/bootstrap.py b/api/channels/bootstrap.py index 212e137d2d..a607d40383 100644 --- a/api/channels/bootstrap.py +++ b/api/channels/bootstrap.py @@ -21,6 +21,7 @@ table: newly added bots are started, deleted ones are stopped, and edited ones messages are answered with a RAG completion routed through the conversation wired to that bot. Replaces the standalone ``server.py`` entrypoint. """ + from __future__ import annotations import asyncio @@ -87,9 +88,7 @@ def _build_one(account_id: str, channel: str, credential: dict): from api.channels.core.registry import build_channels # account_id == chat_channel.id. - instances = build_channels( - {"channels": {channel: {"accounts": {account_id: credential}}}} - ) + instances = build_channels({"channels": {channel: {"accounts": {account_id: credential}}}}) return instances[0] if instances else None @@ -242,9 +241,7 @@ async def _reconcile(running: dict, failed: dict) -> None: active_whatsapp = any(channel == "whatsapp" for channel, _, _ in desired.values()) if not active_whatsapp: - active_whatsapp = any( - entry["ch"].channel_id == "whatsapp" for entry in running.values() - ) + active_whatsapp = any(entry["ch"].channel_id == "whatsapp" for entry in running.values()) from api.channels.whatsapp.gateway import sync_whatsapp_gateway try: diff --git a/api/channels/dingtalk/__init__.py b/api/channels/dingtalk/__init__.py index 08f075b2b2..c53f4fe2e8 100644 --- a/api/channels/dingtalk/__init__.py +++ b/api/channels/dingtalk/__init__.py @@ -1,2 +1 @@ from .channel import _build # noqa: F401 - diff --git a/api/channels/dingtalk/channel.py b/api/channels/dingtalk/channel.py index e0394bc93d..edb2592f10 100644 --- a/api/channels/dingtalk/channel.py +++ b/api/channels/dingtalk/channel.py @@ -136,9 +136,7 @@ class DingTalkChannel(Channel): if msg.type == aiohttp.WSMsgType.TEXT: await self._handle_ws_payload(msg.data) elif msg.type == aiohttp.WSMsgType.BINARY: - await self._handle_ws_payload( - msg.data.decode("utf-8", "ignore") - ) + await self._handle_ws_payload(msg.data.decode("utf-8", "ignore")) elif msg.type in ( aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, @@ -274,27 +272,9 @@ class DingTalkChannel(Channel): return session_webhook = str(data.get("sessionWebhook") or obj.get("sessionWebhook") or "").strip() - callback_message_id = str( - headers.get("messageId") - or obj.get("messageId") - or data.get("messageId") - or data.get("msgId") - or obj.get("msgId") - or "" - ).strip() - chat_id = str( - data.get("conversationId") - or data.get("chatId") - or data.get("openConversationId") - or data.get("msgId") - or "" - ).strip() - sender_id = str( - data.get("senderId") - or data.get("senderStaffId") - or data.get("userId") - or "" - ).strip() + callback_message_id = str(headers.get("messageId") or obj.get("messageId") or data.get("messageId") or data.get("msgId") or obj.get("msgId") or "").strip() + chat_id = str(data.get("conversationId") or data.get("chatId") or data.get("openConversationId") or data.get("msgId") or "").strip() + sender_id = str(data.get("senderId") or data.get("senderStaffId") or data.get("userId") or "").strip() message_id = str(data.get("msgId") or obj.get("msgId") or "").strip() chat_type = str(data.get("conversationType") or "").strip() text = self._extract_text(data) @@ -389,44 +369,21 @@ class DingTalkChannel(Channel): parts = urlsplit(endpoint) query = dict(parse_qsl(parts.query, keep_blank_values=True)) query.setdefault("ticket", ticket) - return urlunsplit( - (parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment) - ) + return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment)) def _build_dedup_key(self, data: Dict[str, Any], obj: Dict[str, Any]) -> str: message_id = str(data.get("msgId") or obj.get("msgId") or "").strip() if message_id: return f"msg:{message_id}" - callback_message_id = str( - obj.get("headers", {}).get("messageId") - or obj.get("messageId") - or data.get("messageId") - or "" - ).strip() + callback_message_id = str(obj.get("headers", {}).get("messageId") or obj.get("messageId") or data.get("messageId") or "").strip() if callback_message_id: return f"callback:{callback_message_id}" - conversation_id = str( - data.get("conversationId") - or data.get("chatId") - or data.get("openConversationId") - or "" - ).strip() - sender_id = str( - data.get("senderId") - or data.get("senderStaffId") - or data.get("userId") - or "" - ).strip() + conversation_id = str(data.get("conversationId") or data.get("chatId") or data.get("openConversationId") or "").strip() + sender_id = str(data.get("senderId") or data.get("senderStaffId") or data.get("userId") or "").strip() text = self._extract_text(data).strip() - event_ts = str( - data.get("eventTime") - or obj.get("eventTime") - or data.get("timestamp") - or obj.get("timestamp") - or "" - ).strip() + event_ts = str(data.get("eventTime") or obj.get("eventTime") or data.get("timestamp") or obj.get("timestamp") or "").strip() if conversation_id and sender_id and text: digest = hashlib.sha1(text.encode("utf-8")).hexdigest()[:16] suffix = f":{event_ts}" if event_ts else "" diff --git a/api/channels/feishu/channel.py b/api/channels/feishu/channel.py index 93d9a05da4..2fbe7e3730 100644 --- a/api/channels/feishu/channel.py +++ b/api/channels/feishu/channel.py @@ -46,14 +46,7 @@ class FeishuChannel(Channel): self._loop: Optional[asyncio.AbstractEventLoop] = None self._ws_client = None self._ws_thread: Optional[threading.Thread] = None - self._rest = ( - lark.Client.builder() - .app_id(account.app_id) - .app_secret(account.app_secret) - .domain(_lark_domain(account.domain)) - .log_level(lark.LogLevel.DEBUG) - .build() - ) + self._rest = lark.Client.builder().app_id(account.app_id).app_secret(account.app_secret).domain(_lark_domain(account.domain)).log_level(lark.LogLevel.DEBUG).build() async def start(self) -> None: # The channel loop is the cross-thread dispatch target for inbound events. @@ -82,11 +75,7 @@ class FeishuChannel(Channel): # server's main loop. lark_ws_client.loop = loop try: - handler = ( - lark.EventDispatcherHandler.builder("", "") - .register_p2_im_message_receive_v1(self._on_message_receive) - .build() - ) + handler = lark.EventDispatcherHandler.builder("", "").register_p2_im_message_receive_v1(self._on_message_receive).build() self._ws_client = lark.ws.Client( self.account.app_id, self.account.app_secret, @@ -123,9 +112,7 @@ class FeishuChannel(Channel): if asyncio.iscoroutine(result): ws_loop = lark_ws_client.loop if ws_loop and not ws_loop.is_closed(): - await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(result, ws_loop) - ) + await asyncio.wrap_future(asyncio.run_coroutine_threadsafe(result, ws_loop)) else: await result except Exception: @@ -139,30 +126,11 @@ class FeishuChannel(Channel): async def send(self, message: OutgoingMessage) -> None: content = json.dumps({"text": message.text}, ensure_ascii=False) if message.reply_to_message_id: - req = ( - ReplyMessageRequest.builder() - .message_id(message.reply_to_message_id) - .request_body( - ReplyMessageRequestBody.builder() - .content(content) - .msg_type("text") - .build() - ) - .build() - ) + req = ReplyMessageRequest.builder().message_id(message.reply_to_message_id).request_body(ReplyMessageRequestBody.builder().content(content).msg_type("text").build()).build() resp = await asyncio.to_thread(self._rest.im.v1.message.reply, req) else: req = ( - CreateMessageRequest.builder() - .receive_id_type("chat_id") - .request_body( - CreateMessageRequestBody.builder() - .receive_id(message.chat_id) - .content(content) - .msg_type("text") - .build() - ) - .build() + CreateMessageRequest.builder().receive_id_type("chat_id").request_body(CreateMessageRequestBody.builder().receive_id(message.chat_id).content(content).msg_type("text").build()).build() ) resp = await asyncio.to_thread(self._rest.im.v1.message.create, req) if not resp.success(): @@ -219,9 +187,7 @@ def _build(account_id: str, cfg: dict) -> Channel: app_id = cfg.get("app_id") app_secret = cfg.get("app_secret") if not app_id or not app_secret: - raise ValueError( - f"feishu account '{account_id}' is missing app_id or app_secret" - ) + raise ValueError(f"feishu account '{account_id}' is missing app_id or app_secret") return FeishuChannel( FeishuAccount( account_id=account_id, diff --git a/api/channels/line/channel.py b/api/channels/line/channel.py index f4222a0f9b..d66a48590d 100644 --- a/api/channels/line/channel.py +++ b/api/channels/line/channel.py @@ -213,9 +213,7 @@ def _build(account_id: str, cfg: dict) -> Channel: channel_secret = cfg.get("channel_secret") channel_access_token = cfg.get("channel_access_token") if not channel_secret or not channel_access_token: - raise ValueError( - f"line account '{account_id}' missing channel_secret or channel_access_token" - ) + raise ValueError(f"line account '{account_id}' missing channel_secret or channel_access_token") return LineChannel( LineAccount( account_id=account_id, diff --git a/api/channels/wecom/channel.py b/api/channels/wecom/channel.py index 4ad04e1f20..b102fafa00 100644 --- a/api/channels/wecom/channel.py +++ b/api/channels/wecom/channel.py @@ -85,9 +85,7 @@ class _SharedWebhookServer: if request.method == "GET": echo_str = request.query.get("echostr", "") try: - decrypted = channel.crypto.check_signature( - signature, timestamp, nonce, echo_str - ) + decrypted = channel.crypto.check_signature(signature, timestamp, nonce, echo_str) return web.Response(text=decrypted) except InvalidSignatureException: return web.Response(status=403, text="bad signature") @@ -148,11 +146,7 @@ class WeComChannel(Channel): self.account = account self.account_id = account.account_id self.connection_type = (account.connection_type or "webhook").strip().lower() - self.crypto = ( - WeChatCrypto(account.token, account.aes_key, account.corp_id) - if self.connection_type == "webhook" - else None - ) + self.crypto = WeChatCrypto(account.token, account.aes_key, account.corp_id) if self.connection_type == "webhook" else None self._server: Optional[_SharedWebhookServer] = None self._access_token: Optional[str] = None self._access_token_expires_at: float = 0.0 @@ -180,9 +174,7 @@ class WeComChannel(Channel): ) return - self._server = await _acquire_server( - self.account.webhook_host, self.account.webhook_port - ) + self._server = await _acquire_server(self.account.webhook_host, self.account.webhook_port) self._server.channels[self.account_id] = self LOGGER.info( "[wecom:%s] registered at path /wecom/%s/callback (agent_id=%s)", @@ -215,9 +207,7 @@ class WeComChannel(Channel): self._ws = None if self._server is not None: self._server.channels.pop(self.account_id, None) - await _release_server( - self.account.webhook_host, self.account.webhook_port - ) + await _release_server(self.account.webhook_host, self.account.webhook_port) self._server = None async def _handle_text_message( @@ -285,18 +275,14 @@ class WeComChannel(Channel): self.account_id, ) await self._subscribe_websocket(ws) - self._heartbeat_task = asyncio.create_task( - self._heartbeat_loop(ws) - ) + self._heartbeat_task = asyncio.create_task(self._heartbeat_loop(ws)) async for msg in ws: if self._stop_requested: break if msg.type == aiohttp.WSMsgType.TEXT: await self._handle_ws_payload(msg.data) elif msg.type == aiohttp.WSMsgType.BINARY: - await self._handle_ws_payload( - msg.data.decode("utf-8", "ignore") - ) + await self._handle_ws_payload(msg.data.decode("utf-8", "ignore")) elif msg.type == aiohttp.WSMsgType.PONG: LOGGER.debug("[wecom:%s] websocket pong", self.account_id) elif msg.type in ( @@ -372,9 +358,7 @@ class WeComChannel(Channel): errcode = int(resp.get("errcode", 0) or 0) if errcode != 0: if errcode == 853000: - raise PermissionError( - f"wecom websocket subscribe failed: invalid bot_id or secret: {resp}" - ) + raise PermissionError(f"wecom websocket subscribe failed: invalid bot_id or secret: {resp}") raise RuntimeError(f"wecom websocket subscribe failed: {resp}") LOGGER.info("[wecom:%s] websocket subscribed", self.account_id) @@ -460,17 +444,13 @@ class WeComChannel(Channel): "corpsecret": self.account.secret, } async with aiohttp.ClientSession() as session: - async with session.get( - f"{WECOM_API_BASE}/gettoken", params=params - ) as resp: + async with session.get(f"{WECOM_API_BASE}/gettoken", params=params) as resp: data = await resp.json(content_type=None) if data.get("errcode", 0) != 0 or "access_token" not in data: raise RuntimeError(f"wecom gettoken failed: {data}") self._access_token = data["access_token"] # 60s safety margin against clock skew / in-flight calls. - self._access_token_expires_at = ( - now + int(data.get("expires_in", 7200)) - 60 - ) + self._access_token_expires_at = now + int(data.get("expires_in", 7200)) - 60 return self._access_token async def send(self, message: OutgoingMessage) -> None: @@ -551,9 +531,7 @@ def _build(account_id: str, cfg: dict) -> Channel: required = ("corp_id", "agent_id", "secret", "token", "aes_key") missing = [k for k in required if not cfg.get(k)] if missing: - raise ValueError( - f"wecom account '{account_id}' missing required fields: {missing}" - ) + raise ValueError(f"wecom account '{account_id}' missing required fields: {missing}") agent_id = 0 aes_key = "" corp_id = str(cfg.get("corp_id") or "") @@ -563,16 +541,12 @@ def _build(account_id: str, cfg: dict) -> Channel: try: agent_id = int(cfg["agent_id"]) except (TypeError, ValueError) as err: - raise ValueError( - f"wecom account '{account_id}' agent_id must be int: {err}" - ) from err + raise ValueError(f"wecom account '{account_id}' agent_id must be int: {err}") from err # WeCom EncodingAESKey is always 43 characters; reject placeholders early so # the failure is a clear message instead of a base64 "Incorrect padding" error. aes_key = str(cfg["aes_key"]) if len(aes_key) != 43: - raise ValueError( - f"wecom account '{account_id}' aes_key (EncodingAESKey) must be 43 characters, got {len(aes_key)}" - ) + raise ValueError(f"wecom account '{account_id}' aes_key (EncodingAESKey) must be 43 characters, got {len(aes_key)}") return WeComChannel( WeComAccount( account_id=account_id, diff --git a/api/channels/whatsapp/channel.py b/api/channels/whatsapp/channel.py index 20988cfb39..b3a9966d2f 100644 --- a/api/channels/whatsapp/channel.py +++ b/api/channels/whatsapp/channel.py @@ -88,10 +88,7 @@ class WhatsAppChannel(Channel): ws_base = f"wss://{base_url.removeprefix('https://')}" else: ws_base = base_url - return ( - f"{ws_base}/whatsapp/{self._session_key()}/events/ws" - f"?after={self._event_cursor}" - ) + return f"{ws_base}/whatsapp/{self._session_key()}/events/ws?after={self._event_cursor}" def _gateway_headers(self) -> dict[str, str]: token = str(self.account.gateway_token or "").strip() @@ -129,9 +126,7 @@ class WhatsAppChannel(Channel): return {} content_type = resp.headers.get("content-type", "") if "application/json" not in content_type.lower(): - raise RuntimeError( - f"unexpected response content-type: {content_type or 'unknown'}, response: {text[:200]}" - ) + raise RuntimeError(f"unexpected response content-type: {content_type or 'unknown'}, response: {text[:200]}") try: return await resp.json() except Exception as ex: @@ -370,12 +365,7 @@ class WhatsAppChannel(Channel): def _build(account_id: str, cfg: dict) -> Channel: - gateway_base_url = str( - cfg.get("gateway_base_url") - or cfg.get("gateway_url") - or cfg.get("control_url") - or _default_gateway_base_url() - ) + gateway_base_url = str(cfg.get("gateway_base_url") or cfg.get("gateway_url") or cfg.get("control_url") or _default_gateway_base_url()) gateway_token = str(cfg.get("gateway_token") or cfg.get("token") or "") session_key = str(cfg.get("session_key") or cfg.get("session_id") or account_id) timeout_secs = int(cfg.get("timeout_secs") or WHATSAPP_DEFAULT_TIMEOUT_SECS) diff --git a/api/channels/whatsapp/gateway.py b/api/channels/whatsapp/gateway.py index 533b2879a8..a58149715d 100644 --- a/api/channels/whatsapp/gateway.py +++ b/api/channels/whatsapp/gateway.py @@ -97,9 +97,7 @@ class WhatsAppGatewayRuntime: npm = shutil.which("npm") if not npm: if not _deps_install_warned: - LOGGER.warning( - "npm is not available; WhatsApp gateway dependencies cannot be installed automatically" - ) + LOGGER.warning("npm is not available; WhatsApp gateway dependencies cannot be installed automatically") _deps_install_warned = True return diff --git a/api/common/base64.py b/api/common/base64.py index 2b37dd2819..7d3dd245f3 100644 --- a/api/common/base64.py +++ b/api/common/base64.py @@ -16,6 +16,7 @@ import base64 + def encode_to_base64(input_string): - base64_encoded = base64.b64encode(input_string.encode('utf-8')) - return base64_encoded.decode('utf-8') \ No newline at end of file + base64_encoded = base64.b64encode(input_string.encode("utf-8")) + return base64_encoded.decode("utf-8") diff --git a/api/constants.py b/api/constants.py index 3f74cc46d6..72661f4761 100644 --- a/api/constants.py +++ b/api/constants.py @@ -26,4 +26,4 @@ DATASET_NAME_LIMIT = 128 FILE_NAME_LEN_LIMIT = 255 MEMORY_NAME_LIMIT = 128 NICKNAME_MAX_LENGTH = 100 -MEMORY_SIZE_LIMIT = 10*1024*1024 # Byte +MEMORY_SIZE_LIMIT = 10 * 1024 * 1024 # Byte diff --git a/api/db/db_utils.py b/api/db/db_utils.py index e86f1234a2..98e44f8d53 100644 --- a/api/db/db_utils.py +++ b/api/db/db_utils.py @@ -30,19 +30,19 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False): for i, data in enumerate(data_source): current_time = current_timestamp() + i current_date = timestamp_to_date(current_time) - if 'create_time' not in data: - data['create_time'] = current_time - data['create_date'] = timestamp_to_date(data['create_time']) - data['update_time'] = current_time - data['update_date'] = current_date + if "create_time" not in data: + data["create_time"] = current_time + data["create_date"] = timestamp_to_date(data["create_time"]) + data["update_time"] = current_time + data["update_date"] = current_date - preserve = tuple(data_source[0].keys() - {'create_time', 'create_date'}) + preserve = tuple(data_source[0].keys() - {"create_time", "create_date"}) batch_size = 1000 for i in range(0, len(data_source), batch_size): with DB.atomic(): - query = model.insert_many(data_source[i:i + batch_size]) + query = model.insert_many(data_source[i : i + batch_size]) if replace_on_conflict: if isinstance(DB, PooledMySQLDatabase): query = query.on_conflict(preserve=preserve) @@ -52,8 +52,7 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False): def get_dynamic_db_model(base, job_id): - return type(base.model( - table_index=get_dynamic_tracking_table_index(job_id=job_id))) + return type(base.model(table_index=get_dynamic_tracking_table_index(job_id=job_id))) def get_dynamic_tracking_table_index(job_id): @@ -62,7 +61,7 @@ def get_dynamic_tracking_table_index(job_id): def fill_db_model_object(model_object, human_model_dict): for k, v in human_model_dict.items(): - attr_name = 'f_%s' % k + attr_name = "f_%s" % k if hasattr(model_object.__class__, attr_name): setattr(model_object, attr_name, v) return model_object @@ -70,53 +69,48 @@ def fill_db_model_object(model_object, human_model_dict): # https://docs.peewee-orm.com/en/latest/peewee/query_operators.html supported_operators = { - '==': operator.eq, - '<': operator.lt, - '<=': operator.le, - '>': operator.gt, - '>=': operator.ge, - '!=': operator.ne, - '<<': operator.lshift, - '>>': operator.rshift, - '%': operator.mod, - '**': operator.pow, - '^': operator.xor, - '~': operator.inv, + "==": operator.eq, + "<": operator.lt, + "<=": operator.le, + ">": operator.gt, + ">=": operator.ge, + "!=": operator.ne, + "<<": operator.lshift, + ">>": operator.rshift, + "%": operator.mod, + "**": operator.pow, + "^": operator.xor, + "~": operator.inv, } -def query_dict2expression( - model: type[DataBaseModel], query: dict[str, bool | int | str | list | tuple]): +def query_dict2expression(model: type[DataBaseModel], query: dict[str, bool | int | str | list | tuple]): expression = [] for field, value in query.items(): if not isinstance(value, (list, tuple)): - value = ('==', value) + value = ("==", value) op, *val = value - field = getattr(model, f'f_{field}') - value = supported_operators[op]( - field, val[0]) if op in supported_operators else getattr( - field, op)( - *val) + field = getattr(model, f"f_{field}") + value = supported_operators[op](field, val[0]) if op in supported_operators else getattr(field, op)(*val) expression.append(value) return reduce(operator.iand, expression) -def query_db(model: type[DataBaseModel], limit: int = 0, offset: int = 0, - query: dict = None, order_by: str | list | tuple | None = None): +def query_db(model: type[DataBaseModel], limit: int = 0, offset: int = 0, query: dict = None, order_by: str | list | tuple | None = None): data = model.select() if query: data = data.where(query_dict2expression(model, query)) count = data.count() if not order_by: - order_by = 'create_time' + order_by = "create_time" if not isinstance(order_by, (list, tuple)): - order_by = (order_by, 'asc') + order_by = (order_by, "asc") order_by, order = order_by - order_by = getattr(model, f'f_{order_by}') + order_by = getattr(model, f"f_{order_by}") order_by = getattr(order_by, order)() data = data.order_by(order_by) diff --git a/api/db/joint_services/memory_message_service.py b/api/db/joint_services/memory_message_service.py index 40b830db59..2e773bb06c 100644 --- a/api/db/joint_services/memory_message_service.py +++ b/api/db/joint_services/memory_message_service.py @@ -52,58 +52,68 @@ async def save_to_memory(memory_id: str, message_dict: dict): return False, f"Memory '{memory_id}' not found." tenant_id = memory.tenant_id - extracted_content = await extract_by_llm( - tenant_id, - memory.tenant_llm_id, - {"temperature": memory.temperature}, - get_memory_type_human(memory.memory_type), - message_dict.get("user_input", ""), - message_dict.get("agent_response", ""), - llm_id=memory.llm_id - ) if memory.memory_type != MemoryType.RAW.value else [] # if only RAW, no need to extract + extracted_content = ( + await extract_by_llm( + tenant_id, + memory.tenant_llm_id, + {"temperature": memory.temperature}, + get_memory_type_human(memory.memory_type), + message_dict.get("user_input", ""), + message_dict.get("agent_response", ""), + llm_id=memory.llm_id, + ) + if memory.memory_type != MemoryType.RAW.value + else [] + ) # if only RAW, no need to extract raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory") - message_list = [{ - "message_id": raw_message_id, - "message_type": MemoryType.RAW.name.lower(), - "source_id": 0, - "memory_id": memory_id, - "user_id": message_dict.get("user_id", ""), - "agent_id": message_dict["agent_id"], - "session_id": message_dict["session_id"], - "content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}", - "valid_at": timestamp_to_date(current_timestamp()), - "invalid_at": None, - "forget_at": None, - "status": True - }, *[{ - "message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"), - "message_type": content["message_type"], - "source_id": raw_message_id, - "memory_id": memory_id, - "user_id": message_dict.get("user_id", ""), - "agent_id": message_dict["agent_id"], - "session_id": message_dict["session_id"], - "content": content["content"], - "valid_at": content["valid_at"], - "invalid_at": content["invalid_at"] if content["invalid_at"] else None, - "forget_at": None, - "status": True - } for content in extracted_content]] + message_list = [ + { + "message_id": raw_message_id, + "message_type": MemoryType.RAW.name.lower(), + "source_id": 0, + "memory_id": memory_id, + "user_id": message_dict.get("user_id", ""), + "agent_id": message_dict["agent_id"], + "session_id": message_dict["session_id"], + "content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}", + "valid_at": timestamp_to_date(current_timestamp()), + "invalid_at": None, + "forget_at": None, + "status": True, + }, + *[ + { + "message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"), + "message_type": content["message_type"], + "source_id": raw_message_id, + "memory_id": memory_id, + "user_id": message_dict.get("user_id", ""), + "agent_id": message_dict["agent_id"], + "session_id": message_dict["session_id"], + "content": content["content"], + "valid_at": content["valid_at"], + "invalid_at": content["invalid_at"] if content["invalid_at"] else None, + "forget_at": None, + "status": True, + } + for content in extracted_content + ], + ] return await embed_and_save(memory, message_list) -async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int, task_id: str=None): +async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int, task_id: str = None): memory = MemoryService.get_by_memory_id(memory_id) if not memory: msg = f"Memory '{memory_id}' not found." if task_id: - TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg}) + TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp()) + " " + msg}) return False, msg if memory.memory_type == MemoryType.RAW.value: msg = f"Memory '{memory_id}' don't need to extract." if task_id: - TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg}) + TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp()) + " " + msg}) return True, msg tenant_id = memory.tenant_id @@ -115,35 +125,48 @@ async def save_extracted_to_memory_only(memory_id: str, message_dict, source_mes message_dict.get("user_input", ""), message_dict.get("agent_response", ""), task_id=task_id, - llm_id=memory.llm_id + llm_id=memory.llm_id, ) - message_list = [{ - "message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"), - "message_type": content["message_type"], - "source_id": source_message_id, - "memory_id": memory_id, - "user_id": message_dict.get("user_id", ""), - "agent_id": message_dict["agent_id"], - "session_id": message_dict["session_id"], - "content": content["content"], - "valid_at": content["valid_at"], - "invalid_at": content["invalid_at"] if content["invalid_at"] else None, - "forget_at": None, - "status": True - } for content in extracted_content] + message_list = [ + { + "message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"), + "message_type": content["message_type"], + "source_id": source_message_id, + "memory_id": memory_id, + "user_id": message_dict.get("user_id", ""), + "agent_id": message_dict["agent_id"], + "session_id": message_dict["session_id"], + "content": content["content"], + "valid_at": content["valid_at"], + "invalid_at": content["invalid_at"] if content["invalid_at"] else None, + "forget_at": None, + "status": True, + } + for content in extracted_content + ] if not message_list: msg = "No memory extracted from raw message." if task_id: - TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg}) + TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp()) + " " + msg}) return True, msg if task_id: - TaskService.update_progress(task_id, {"progress": 0.5, "progress_msg": timestamp_to_date(current_timestamp())+ " " + f"Extracted {len(message_list)} messages from raw dialogue."}) + TaskService.update_progress(task_id, {"progress": 0.5, "progress_msg": timestamp_to_date(current_timestamp()) + " " + f"Extracted {len(message_list)} messages from raw dialogue."}) return await embed_and_save(memory, message_list, task_id) -async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict, memory_type: List[str], user_input: str, - agent_response: str, system_prompt: str = "", user_prompt: str="", task_id: str=None, llm_id: str = "") -> List[dict]: +async def extract_by_llm( + tenant_id: str, + tenant_llm_id: int, + extract_conf: dict, + memory_type: List[str], + user_input: str, + agent_response: str, + system_prompt: str = "", + user_prompt: str = "", + task_id: str = None, + llm_id: str = "", +) -> List[dict]: if not system_prompt: system_prompt = PromptAssembler.assemble_system_prompt({"memory_type": memory_type}) conversation_content = f"User Input: {user_input}\nAgent Response: {agent_response}" @@ -157,36 +180,40 @@ async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict, llm_config = get_model_config_from_provider_instance(tenant_id, LLMType.CHAT, llm_id) with LLMBundle(tenant_id, llm_config) as llm: if task_id: - TaskService.update_progress(task_id, {"progress": 0.15, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared prompts and LLM."}) + TaskService.update_progress(task_id, {"progress": 0.15, "progress_msg": timestamp_to_date(current_timestamp()) + " " + "Prepared prompts and LLM."}) res = await llm.async_chat(system_prompt, user_prompts, extract_conf) res_json = get_json_result_from_llm_response(res) if task_id: - TaskService.update_progress(task_id, {"progress": 0.35, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Get extracted result from LLM."}) - return [{ - "content": extracted_content["content"], - "valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]), - "invalid_at": format_iso_8601_to_ymd_hms(extracted_content["invalid_at"]) if extracted_content.get("invalid_at") else "", - "message_type": message_type - } for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list] + TaskService.update_progress(task_id, {"progress": 0.35, "progress_msg": timestamp_to_date(current_timestamp()) + " " + "Get extracted result from LLM."}) + return [ + { + "content": extracted_content["content"], + "valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]), + "invalid_at": format_iso_8601_to_ymd_hms(extracted_content["invalid_at"]) if extracted_content.get("invalid_at") else "", + "message_type": message_type, + } + for message_type, extracted_content_list in res_json.items() + for extracted_content in extracted_content_list + ] -async def embed_and_save(memory, message_list: list[dict], task_id: str=None): +async def embed_and_save(memory, message_list: list[dict], task_id: str = None): embd_model_config = get_model_config_from_provider_instance(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id) with LLMBundle(memory.tenant_id, embd_model_config) as embedding_model: if task_id: - TaskService.update_progress(task_id, {"progress": 0.65, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared embedding model."}) + TaskService.update_progress(task_id, {"progress": 0.65, "progress_msg": timestamp_to_date(current_timestamp()) + " " + "Prepared embedding model."}) vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list]) for idx, msg in enumerate(message_list): msg["content_embed"] = vector_list[idx] if task_id: - TaskService.update_progress(task_id, {"progress": 0.85, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Embedded extracted content."}) + TaskService.update_progress(task_id, {"progress": 0.85, "progress_msg": timestamp_to_date(current_timestamp()) + " " + "Embedded extracted content."}) vector_dimension = len(vector_list[0]) if not MessageService.has_index(memory.tenant_id, memory.id): created = MessageService.create_index(memory.tenant_id, memory.id, vector_size=vector_dimension) if not created: error_msg = "Failed to create message index." if task_id: - TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg}) + TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp()) + " " + error_msg}) return False, error_msg new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list]) @@ -194,24 +221,23 @@ async def embed_and_save(memory, message_list: list[dict], task_id: str=None): if new_msg_size + current_memory_size > memory.memory_size: size_to_delete = current_memory_size + new_msg_size - memory.memory_size if memory.forgetting_policy == "FIFO": - message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory.id, memory.tenant_id, - size_to_delete) + message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory.id, memory.tenant_id, size_to_delete) MessageService.delete_message({"message_id": message_ids_to_delete}, memory.tenant_id, memory.id) decrease_memory_size_cache(memory.id, delete_size) else: error_msg = "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete." if task_id: - TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg}) + TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp()) + " " + error_msg}) return False, error_msg fail_cases = MessageService.insert_message(message_list, memory.tenant_id, memory.id) if fail_cases: error_msg = "Failed to insert message into memory. Details: " + "; ".join(fail_cases) if task_id: - TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg}) + TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp()) + " " + error_msg}) return False, error_msg if task_id: - TaskService.update_progress(task_id, {"progress": 0.95, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Saved messages to storage."}) + TaskService.update_progress(task_id, {"progress": 0.95, "progress_msg": timestamp_to_date(current_timestamp()) + " " + "Saved messages to storage."}) increase_memory_size_cache(memory.id, new_msg_size) return True, "Message saved successfully." @@ -263,10 +289,7 @@ def init_message_id_sequence(): if not exist_memory_list: REDIS_CONN.set(message_id_redis_key, max_id) else: - max_id = MessageService.get_max_message_id( - uid_list=[m.tenant_id for m in exist_memory_list], - memory_ids=[m.id for m in exist_memory_list] - ) + max_id = MessageService.get_max_message_id(uid_list=[m.tenant_id for m in exist_memory_list], memory_ids=[m.id for m in exist_memory_list]) REDIS_CONN.set(message_id_redis_key, max_id) logging.info(f"Init message_id sequence done, current max id is {max_id}.") @@ -276,10 +299,7 @@ def get_memory_size_cache(memory_id: str, uid: str): if REDIS_CONN.exist(redis_key): return int(REDIS_CONN.get(redis_key)) else: - memory_size_map = MessageService.calculate_memory_size( - [memory_id], - [uid] - ) + memory_size_map = MessageService.calculate_memory_size([memory_id], [uid]) memory_size = memory_size_map.get(memory_id, 0) set_memory_size_cache(memory_id, memory_size) return memory_size @@ -328,7 +348,7 @@ def fix_missing_tokenized_memory(): logging.info("Fix missing tokenized memory done.") -def judge_system_prompt_is_default(system_prompt: str, memory_type: int|list[str]): +def judge_system_prompt_is_default(system_prompt: str, memory_type: int | list[str]): memory_type_list = memory_type if isinstance(memory_type, list) else get_memory_type_human(memory_type) return system_prompt == PromptAssembler.assemble_system_prompt({"memory_type": memory_type_list}) @@ -344,15 +364,9 @@ async def queue_save_to_memory_task(memory_ids: list[str], message_dict: dict): "agent_response": str } """ + def new_task(_memory_id: str, _source_id: int): - return { - "id": get_uuid(), - "doc_id": _memory_id, - "task_type": "memory", - "progress": 0.0, - "begin_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "digest": str(_source_id) - } + return {"id": get_uuid(), "doc_id": _memory_id, "task_type": "memory", "progress": 0.0, "begin_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "digest": str(_source_id)} not_found_memory = [] failed_memory = [] @@ -375,7 +389,7 @@ async def queue_save_to_memory_task(memory_ids: list[str], message_dict: dict): "valid_at": timestamp_to_date(current_timestamp()), "invalid_at": None, "forget_at": None, - "status": True + "status": True, } res, msg = await embed_and_save(memory, [raw_message]) if not res: @@ -391,7 +405,7 @@ async def queue_save_to_memory_task(memory_ids: list[str], message_dict: dict): "memory_id": memory_id, "tenant_id": memory.tenant_id, "source_id": raw_message_id, - "message_dict": message_dict + "message_dict": message_dict, } if not REDIS_CONN.queue_product(settings.get_svr_queue_name(priority=0), message=task_message): failed_memory.append({"memory_id": memory_id, "fail_msg": "Can't access Redis."}) @@ -436,9 +450,9 @@ async def handle_save_to_memory_task(task_param: dict): message_dict = task_param["message_dict"] success, msg = await save_extracted_to_memory_only(memory_id, message_dict, source_id, task.id) if success: - TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg}) + TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp()) + " " + msg}) return True, msg logging.error(msg) - TaskService.update_progress(task.id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg}) + TaskService.update_progress(task.id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp()) + " " + msg}) return False, msg diff --git a/api/db/joint_services/user_account_service.py b/api/db/joint_services/user_account_service.py index 91e6e38362..55e3eb5498 100644 --- a/api/db/joint_services/user_account_service.py +++ b/api/db/joint_services/user_account_service.py @@ -39,6 +39,7 @@ from rag.nlp import search from common.constants import ActiveEnum from common import settings + def create_new_user(user_info: dict) -> dict: """ Add a new user, and create tenant, tenant llm, file folder for new user. @@ -56,8 +57,8 @@ def create_new_user(user_info: dict) -> dict: """ # generate user_id and access_token for user user_id = uuid.uuid1().hex - user_info['id'] = user_id - user_info['access_token'] = uuid.uuid1().hex + user_info["id"] = user_id + user_info["access_token"] = uuid.uuid1().hex # construct tenant info tenant = { "id": user_id, @@ -148,7 +149,7 @@ def delete_user_data(user_id: str) -> dict: tenants = UserTenantService.get_user_tenant_relation_by_user_id(usr.id) owned_tenant = [t for t in tenants if t["role"] == UserTenantRole.OWNER.value] - done_msg = '' + done_msg = "" try: # step1. delete owned tenant info if owned_tenant: @@ -180,14 +181,10 @@ def delete_user_data(user_id: str) -> dict: file_delete_res = FileService.delete_by_ids([f["id"] for f in file_ids]) done_msg += f"- Deleted {file_delete_res} file records.\n" if doc_ids or file_ids: - file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids( - [i["id"] for i in doc_ids], - [f["id"] for f in file_ids] - ) + file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids([i["id"] for i in doc_ids], [f["id"] for f in file_ids]) done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n" # step1.1.3 delete chunk in es - r = settings.docStoreConn.delete({"kb_id": kb_ids}, - search.index_name(tenant_id), kb_ids) + r = settings.docStoreConn.delete({"kb_id": kb_ids}, search.index_name(tenant_id), kb_ids) done_msg += f"- Deleted {r} chunk records.\n" kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids) done_msg += f"- Deleted {kb_delete_res} dataset records.\n" @@ -236,8 +233,8 @@ def delete_user_data(user_id: str) -> dict: created_documents = DocumentService.get_all_docs_by_creator_id(usr.id) if created_documents: # step2.1.1 delete files - doc_file_info = File2DocumentService.get_by_document_ids([d['id'] for d in created_documents]) - created_files = FileService.get_by_ids([f['file_id'] for f in doc_file_info]) + doc_file_info = File2DocumentService.get_by_document_ids([d["id"] for d in created_documents]) + created_files = FileService.get_by_ids([f["file_id"] for f in doc_file_info]) if created_files: # step2.1.1.1 delete file in storage for f in created_files: @@ -247,10 +244,7 @@ def delete_user_data(user_id: str) -> dict: file_delete_res = FileService.delete_by_ids([f.id for f in created_files]) done_msg += f"- Deleted {file_delete_res} file records.\n" # step2.1.2 delete document-file relation record - file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids( - [d['id'] for d in created_documents], - [f.id for f in created_files] - ) + file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids([d["id"] for d in created_documents], [f.id for f in created_files]) done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n" # step2.1.3 delete chunks doc_groups = group_by(created_documents, "tenant_id") @@ -260,31 +254,24 @@ def delete_user_data(user_id: str) -> dict: kb_doc_info = {} for _tenant_id, kb_doc in kb_grouped_doc.items(): for _kb_id, docs in kb_doc.items(): - chunk_delete_res += settings.docStoreConn.delete( - {"doc_id": [d["id"] for d in docs]}, - search.index_name(_tenant_id), _kb_id - ) + chunk_delete_res += settings.docStoreConn.delete({"doc_id": [d["id"] for d in docs]}, search.index_name(_tenant_id), _kb_id) # record doc info if _kb_id in kb_doc_info.keys(): - kb_doc_info[_kb_id]['doc_num'] += 1 - kb_doc_info[_kb_id]['token_num'] += sum([d["token_num"] for d in docs]) - kb_doc_info[_kb_id]['chunk_num'] += sum([d["chunk_num"] for d in docs]) + kb_doc_info[_kb_id]["doc_num"] += 1 + kb_doc_info[_kb_id]["token_num"] += sum([d["token_num"] for d in docs]) + kb_doc_info[_kb_id]["chunk_num"] += sum([d["chunk_num"] for d in docs]) else: - kb_doc_info[_kb_id] = { - 'doc_num': 1, - 'token_num': sum([d["token_num"] for d in docs]), - 'chunk_num': sum([d["chunk_num"] for d in docs]) - } + kb_doc_info[_kb_id] = {"doc_num": 1, "token_num": sum([d["token_num"] for d in docs]), "chunk_num": sum([d["chunk_num"] for d in docs])} done_msg += f"- Deleted {chunk_delete_res} chunks.\n" # step2.1.4 delete tasks - task_delete_res = TaskService.delete_by_doc_ids([d['id'] for d in created_documents]) + task_delete_res = TaskService.delete_by_doc_ids([d["id"] for d in created_documents]) done_msg += f"- Deleted {task_delete_res} tasks.\n" # step2.1.5 delete document record - doc_delete_res = DocumentService.delete_by_ids([d['id'] for d in created_documents]) + doc_delete_res = DocumentService.delete_by_ids([d["id"] for d in created_documents]) done_msg += f"- Deleted {doc_delete_res} documents.\n" for doc in created_documents: try: - DocMetadataService.delete_document_metadata(doc['id'], doc['kb_id'], doc['tenant_id']) + DocMetadataService.delete_document_metadata(doc["id"], doc["kb_id"], doc["tenant_id"]) except Exception as e: logging.warning(f"Failed to delete metadata for document {doc['id']}: {e}") # step2.1.6 update dataset doc&chunk&token cnt @@ -302,7 +289,7 @@ def delete_user_data(user_id: str) -> dict: except Exception as e: logging.exception(e) - return {"success": False, "message": "An internal error occurred during user deletion. Some operations may have completed.","details": done_msg} + return {"success": False, "message": "An internal error occurred during user deletion. Some operations may have completed.", "details": done_msg} def delete_user_agents(user_id: str) -> dict: @@ -316,13 +303,10 @@ def delete_user_agents(user_id: str) -> dict: agents_deleted_count, agents_version_deleted_count = 0, 0 user_agents = UserCanvasService.get_all_agents_by_tenant_ids([user_id], user_id) if user_agents: - agents_version = UserCanvasVersionService.get_all_canvas_version_by_canvas_ids([a['id'] for a in user_agents]) - agents_version_deleted_count = UserCanvasVersionService.delete_by_ids([v['id'] for v in agents_version]) - agents_deleted_count = UserCanvasService.delete_by_ids([a['id'] for a in user_agents]) - return { - "agents_deleted_count": agents_deleted_count, - "version_deleted_count": agents_version_deleted_count - } + agents_version = UserCanvasVersionService.get_all_canvas_version_by_canvas_ids([a["id"] for a in user_agents]) + agents_version_deleted_count = UserCanvasVersionService.delete_by_ids([v["id"] for v in agents_version]) + agents_deleted_count = UserCanvasService.delete_by_ids([a["id"] for a in user_agents]) + return {"agents_deleted_count": agents_deleted_count, "version_deleted_count": agents_version_deleted_count} def delete_user_dialogs(user_id: str) -> dict: @@ -339,17 +323,17 @@ def delete_user_dialogs(user_id: str) -> dict: user_dialogs = DialogService.get_all_dialogs_by_tenant_id(user_id) if user_dialogs: # delete conversation - conversations = ConversationService.get_all_conversation_by_dialog_ids([ud['id'] for ud in user_dialogs]) - conversations_deleted_count = ConversationService.delete_by_ids([c['id'] for c in conversations]) + conversations = ConversationService.get_all_conversation_by_dialog_ids([ud["id"] for ud in user_dialogs]) + conversations_deleted_count = ConversationService.delete_by_ids([c["id"] for c in conversations]) # delete api token api_token_deleted_count = APITokenService.delete_by_tenant_id(user_id) # delete api for conversation - api4conversation_deleted_count = API4ConversationService.delete_by_dialog_ids([ud['id'] for ud in user_dialogs]) + api4conversation_deleted_count = API4ConversationService.delete_by_dialog_ids([ud["id"] for ud in user_dialogs]) # delete dialog at last - dialog_deleted_count = DialogService.delete_by_ids([ud['id'] for ud in user_dialogs]) + dialog_deleted_count = DialogService.delete_by_ids([ud["id"] for ud in user_dialogs]) return { "dialogs_deleted_count": dialog_deleted_count, "conversations_deleted_count": conversations_deleted_count, "api_token_deleted_count": api_token_deleted_count, - "api4conversation_deleted_count": api4conversation_deleted_count + "api4conversation_deleted_count": api4conversation_deleted_count, } diff --git a/api/db/reload_config_base.py b/api/db/reload_config_base.py index be37afc6bc..b9746eaa2c 100644 --- a/api/db/reload_config_base.py +++ b/api/db/reload_config_base.py @@ -18,8 +18,7 @@ class ReloadConfigBase: def get_all(cls): configs = {} for k, v in cls.__dict__.items(): - if not callable(getattr(cls, k)) and not k.startswith( - "__") and not k.startswith("_"): + if not callable(getattr(cls, k)) and not k.startswith("__") and not k.startswith("_"): configs[k] = v return configs diff --git a/api/db/services/__init__.py b/api/db/services/__init__.py index a5e83ea0e4..c18e6a6bb7 100644 --- a/api/db/services/__init__.py +++ b/api/db/services/__init__.py @@ -42,7 +42,7 @@ def _split_name_counter(filename: str) -> tuple[str, int | None]: return filename, None -def duplicate_name(query_func, name_field: str="name", **kwargs) -> str: +def duplicate_name(query_func, name_field: str = "name", **kwargs) -> str: """ Generates a unique filename by appending/incrementing a counter when duplicates exist. diff --git a/api/db/services/api_service.py b/api/db/services/api_service.py index 8f60a1c5ab..9a9d0297f0 100644 --- a/api/db/services/api_service.py +++ b/api/db/services/api_service.py @@ -28,12 +28,12 @@ class APITokenService(CommonService): @classmethod @DB.connection_context() def used(cls, token): - return cls.model.update({ - "update_time": current_timestamp(), - "update_date": datetime_format(datetime.now()), - }).where( - cls.model.token == token - ) + return cls.model.update( + { + "update_time": current_timestamp(), + "update_date": datetime_format(datetime.now()), + } + ).where(cls.model.token == token) @classmethod @DB.connection_context() @@ -54,15 +54,11 @@ class API4ConversationService(CommonService): @classmethod @DB.connection_context() - def get_list(cls, dialog_id, tenant_id, - page_number, items_per_page, - orderby, desc, id=None, user_id=None, include_dsl=True, keywords="", - from_date=None, to_date=None, exp_user_id=None - ): + def get_list(cls, dialog_id, tenant_id, page_number, items_per_page, orderby, desc, id=None, user_id=None, include_dsl=True, keywords="", from_date=None, to_date=None, exp_user_id=None): if include_dsl: sessions = cls.model.select().where(cls.model.dialog_id == dialog_id) else: - fields = [field for field in cls.model._meta.fields.values() if field.name != 'dsl'] + fields = [field for field in cls.model._meta.fields.values() if field.name != "dsl"] sessions = cls.model.select(*fields).where(cls.model.dialog_id == dialog_id) if id: sessions = sessions.where(cls.model.id == id) @@ -85,15 +81,15 @@ class API4ConversationService(CommonService): sessions = sessions.paginate(page_number, items_per_page) return count, list(sessions.dicts()) - + @classmethod @DB.connection_context() def get_names(cls, dialog_id, exp_user_id): - fields = [cls.model.id, cls.model.name,] - sessions = cls.model.select(*fields).where( - cls.model.dialog_id == dialog_id, - cls.model.exp_user_id == exp_user_id - ).order_by(cls.model.getter_by("create_date").desc()) + fields = [ + cls.model.id, + cls.model.name, + ] + sessions = cls.model.select(*fields).where(cls.model.dialog_id == dialog_id, cls.model.exp_user_id == exp_user_id).order_by(cls.model.getter_by("create_date").desc()) return list(sessions.dicts()) @@ -108,25 +104,21 @@ class API4ConversationService(CommonService): def stats(cls, tenant_id, from_date, to_date, source=None): if len(to_date) == 10: to_date += " 23:59:59" - return cls.model.select( - cls.model.create_date.truncate("day").alias("dt"), - peewee.fn.COUNT( - cls.model.id).alias("pv"), - peewee.fn.COUNT( - cls.model.user_id.distinct()).alias("uv"), - peewee.fn.SUM( - cls.model.tokens).alias("tokens"), - peewee.fn.SUM( - cls.model.duration).alias("duration"), - peewee.fn.AVG( - cls.model.round).alias("round"), - peewee.fn.SUM( - cls.model.thumb_up).alias("thumb_up") - ).join(Dialog, on=((cls.model.dialog_id == Dialog.id) & (Dialog.tenant_id == tenant_id))).where( - cls.model.create_date >= from_date, - cls.model.create_date <= to_date, - cls.model.source == source - ).group_by(cls.model.create_date.truncate("day")).dicts() + return ( + cls.model.select( + cls.model.create_date.truncate("day").alias("dt"), + peewee.fn.COUNT(cls.model.id).alias("pv"), + peewee.fn.COUNT(cls.model.user_id.distinct()).alias("uv"), + peewee.fn.SUM(cls.model.tokens).alias("tokens"), + peewee.fn.SUM(cls.model.duration).alias("duration"), + peewee.fn.AVG(cls.model.round).alias("round"), + peewee.fn.SUM(cls.model.thumb_up).alias("thumb_up"), + ) + .join(Dialog, on=((cls.model.dialog_id == Dialog.id) & (Dialog.tenant_id == tenant_id))) + .where(cls.model.create_date >= from_date, cls.model.create_date <= to_date, cls.model.source == source) + .group_by(cls.model.create_date.truncate("day")) + .dicts() + ) @classmethod @DB.connection_context() diff --git a/api/db/services/chunk_feedback_service.py b/api/db/services/chunk_feedback_service.py index 1d9fe23f48..305d609d85 100644 --- a/api/db/services/chunk_feedback_service.py +++ b/api/db/services/chunk_feedback_service.py @@ -39,6 +39,7 @@ Infinity uses row_id (returned by search results since PR #13901) for targeted single-row updates. If a concurrent update changes the row_id, the Infinity connector retries with a fresh row_id lookup. """ + import logging import math import os @@ -245,12 +246,7 @@ class ChunkFeedbackService: return False @classmethod - def apply_feedback( - cls, - tenant_id: str, - reference: dict, - is_positive: bool - ) -> dict: + def apply_feedback(cls, tenant_id: str, reference: dict, is_positive: bool) -> dict: """ Apply user feedback to all chunks referenced in a response. @@ -274,13 +270,16 @@ class ChunkFeedbackService: logging.debug("No chunk IDs found in reference for feedback") return {"success_count": 0, "fail_count": 0, "chunk_ids": []} - signed_budget = ( - UPVOTE_WEIGHT_INCREMENT if is_positive else -DOWNVOTE_WEIGHT_DECREMENT + signed_budget = UPVOTE_WEIGHT_INCREMENT if is_positive else -DOWNVOTE_WEIGHT_DECREMENT + weighting = ( + CHUNK_FEEDBACK_WEIGHTING + if CHUNK_FEEDBACK_WEIGHTING + in ( + "uniform", + "relevance", + ) + else "relevance" ) - weighting = CHUNK_FEEDBACK_WEIGHTING if CHUNK_FEEDBACK_WEIGHTING in ( - "uniform", - "relevance", - ) else "relevance" if weighting == "uniform": deltas = _allocate_deltas_uniform([(r[0], r[1]) for r in rows], signed_budget) @@ -314,8 +313,4 @@ class ChunkFeedbackService: len(chunk_ids), ) - return { - "success_count": success_count, - "fail_count": fail_count, - "chunk_ids": chunk_ids - } + return {"success_count": success_count, "fail_count": fail_count, "chunk_ids": chunk_ids} diff --git a/api/db/services/common_service.py b/api/db/services/common_service.py index 8ef4bb94b4..93fd5f8a83 100644 --- a/api/db/services/common_service.py +++ b/api/db/services/common_service.py @@ -68,6 +68,7 @@ def retry_db_operation(func): ) def wrapper(*args, **kwargs): return func(*args, **kwargs) + return wrapper diff --git a/api/db/services/connector_service.py b/api/db/services/connector_service.py index 8d2b015010..a7c4aace94 100644 --- a/api/db/services/connector_service.py +++ b/api/db/services/connector_service.py @@ -127,23 +127,17 @@ class ConnectorService(CommonService): @classmethod def list(cls, tenant_id): - fields = [ - cls.model.id, - cls.model.name, - cls.model.source, - cls.model.status - ] - return list(cls.model.select(*fields).where( - cls.model.tenant_id == tenant_id - ).dicts()) + fields = [cls.model.id, cls.model.name, cls.model.source, cls.model.status] + return list(cls.model.select(*fields).where(cls.model.tenant_id == tenant_id).dicts()) @classmethod - def rebuild(cls, kb_id:str, connector_id: str, tenant_id:str): + def rebuild(cls, kb_id: str, connector_id: str, tenant_id: str): from api.db.services.file_service import FileService + e, conn = cls.get_by_id(connector_id) if not e: return None - SyncLogsService.filter_delete([SyncLogs.connector_id==connector_id, SyncLogs.kb_id==kb_id]) + SyncLogsService.filter_delete([SyncLogs.connector_id == connector_id, SyncLogs.kb_id == kb_id]) docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}", kb_id=kb_id) err = FileService.delete_docs([d.id for d in docs], tenant_id) SyncLogsService.schedule(connector_id, kb_id, reindex=True, task_type=ConnectorTaskType.SYNC) @@ -176,9 +170,7 @@ class ConnectorService(CommonService): kb_id, source_type, ) - stale_doc_ids = [ - doc["id"] for doc in existing_docs if doc["id"] not in retain_doc_ids - ] + stale_doc_ids = [doc["id"] for doc in existing_docs if doc["id"] not in retain_doc_ids] if not stale_doc_ids: return 0, [] @@ -213,7 +205,6 @@ class ConnectorService(CommonService): class SyncLogsService(CommonService): model = SyncLogs - @classmethod def list_sync_tasks(cls, connector_id=None, page_number=None, items_per_page=15) -> Tuple[List[dict], int]: fields = [ @@ -235,11 +226,13 @@ class SyncLogsService(CommonService): ] if not connector_id: fields.append(Connector.config) - - query = cls.model.select(*fields)\ - .join(Connector, on=(cls.model.connector_id==Connector.id))\ - .join(Connector2Kb, on=(cls.model.kb_id==Connector2Kb.kb_id))\ - .join(Knowledgebase, on=(cls.model.kb_id==Knowledgebase.id)) + + query = ( + cls.model.select(*fields) + .join(Connector, on=(cls.model.connector_id == Connector.id)) + .join(Connector2Kb, on=(cls.model.kb_id == Connector2Kb.kb_id)) + .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) + ) if connector_id: query = query.where(cls.model.connector_id == connector_id) @@ -249,12 +242,7 @@ class SyncLogsService(CommonService): expr = SQL(f"NOW() AT TIME ZONE '{TIMEZONE}' - make_interval(mins => t2.refresh_freq)") else: expr = SQL("NOW() - INTERVAL `t2`.`refresh_freq` MINUTE") - query = query.where( - Connector.input_type == InputType.POLL, - Connector.status == TaskStatus.SCHEDULE, - cls.model.status == TaskStatus.SCHEDULE, - cls.model.update_date < expr - ) + query = query.where(Connector.input_type == InputType.POLL, Connector.status == TaskStatus.SCHEDULE, cls.model.status == TaskStatus.SCHEDULE, cls.model.update_date < expr) query = query.distinct().order_by(cls.model.update_time.desc()) total = query.count() @@ -277,11 +265,11 @@ class SyncLogsService(CommonService): "prune_freq", ) return [ - task for task in tasks + task + for task in tasks # Prune is opt-in at the connector config level; keep the scheduler # blind to prune_freq until the flag is enabled. - if bool((task.get("config") or {}).get("sync_deleted_files")) - and int(task.get("prune_freq") or 0) > 0 + if bool((task.get("config") or {}).get("sync_deleted_files")) and int(task.get("prune_freq") or 0) > 0 ] @classmethod @@ -314,10 +302,12 @@ class SyncLogsService(CommonService): cls.model.update_time, ] - query = cls.model.select(*fields)\ - .join(Connector, on=(cls.model.connector_id==Connector.id))\ - .join(Connector2Kb, on=(cls.model.kb_id==Connector2Kb.kb_id))\ - .join(Knowledgebase, on=(cls.model.kb_id==Knowledgebase.id)) + query = ( + cls.model.select(*fields) + .join(Connector, on=(cls.model.connector_id == Connector.id)) + .join(Connector2Kb, on=(cls.model.kb_id == Connector2Kb.kb_id)) + .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) + ) query = query.where( Connector.input_type == InputType.POLL, @@ -328,9 +318,7 @@ class SyncLogsService(CommonService): database_type = os.getenv("DB_TYPE", "mysql") if "postgres" in database_type.lower(): - expr = SQL( - f"NOW() AT TIME ZONE '{TIMEZONE}' - make_interval(mins => t2.{freq_field})" - ) + expr = SQL(f"NOW() AT TIME ZONE '{TIMEZONE}' - make_interval(mins => t2.{freq_field})") else: expr = SQL(f"NOW() - INTERVAL `t2`.`{freq_field}` MINUTE") query = query.where(cls.model.update_date < expr) @@ -339,7 +327,7 @@ class SyncLogsService(CommonService): @classmethod def start(cls, id, connector_id): - cls.update_by_id(id, {"status": TaskStatus.RUNNING, "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') }) + cls.update_by_id(id, {"status": TaskStatus.RUNNING, "time_started": datetime.now().strftime("%Y-%m-%d %H:%M:%S")}) ConnectorService.update_by_id(connector_id, {"status": TaskStatus.RUNNING}) @classmethod @@ -382,39 +370,41 @@ class SyncLogsService(CommonService): return None reindex = "1" if reindex else "0" ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDULE}) - return cls.save(**{ - "id": get_uuid(), - "kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id, - "task_type": task_type, - "poll_range_start": poll_range_start, "from_beginning": reindex, - "total_docs_indexed": total_docs_indexed, - "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') - }) + return cls.save( + **{ + "id": get_uuid(), + "kb_id": kb_id, + "status": TaskStatus.SCHEDULE, + "connector_id": connector_id, + "task_type": task_type, + "poll_range_start": poll_range_start, + "from_beginning": reindex, + "total_docs_indexed": total_docs_indexed, + "time_started": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + } + ) except Exception as e: logging.exception(e) task = cls.get_latest_task(connector_id, kb_id, task_type) if task: - cls.model.update(status=TaskStatus.SCHEDULE, - poll_range_start=poll_range_start, - error_msg=cls.model.error_msg + str(e), - full_exception_trace=cls.model.full_exception_trace + str(e) - ) \ - .where(cls.model.id == task.id).execute() + cls.model.update( + status=TaskStatus.SCHEDULE, poll_range_start=poll_range_start, error_msg=cls.model.error_msg + str(e), full_exception_trace=cls.model.full_exception_trace + str(e) + ).where(cls.model.id == task.id).execute() ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDULE}) @classmethod def increase_docs(cls, id, max_update, doc_num, err_msg="", error_count=0): # Keep sync monotonic. - cls.model.update(new_docs_indexed=cls.model.new_docs_indexed + doc_num, - total_docs_indexed=cls.model.total_docs_indexed + doc_num, - poll_range_start=fn.COALESCE(fn.GREATEST(cls.model.poll_range_start, max_update), max_update), - poll_range_end=fn.COALESCE(fn.GREATEST(cls.model.poll_range_end, max_update), max_update), - error_msg=cls.model.error_msg + err_msg, - error_count=cls.model.error_count + error_count, - update_time=current_timestamp(), - update_date=timestamp_to_date(current_timestamp()) - )\ - .where(cls.model.id == id).execute() + cls.model.update( + new_docs_indexed=cls.model.new_docs_indexed + doc_num, + total_docs_indexed=cls.model.total_docs_indexed + doc_num, + poll_range_start=fn.COALESCE(fn.GREATEST(cls.model.poll_range_start, max_update), max_update), + poll_range_end=fn.COALESCE(fn.GREATEST(cls.model.poll_range_end, max_update), max_update), + error_msg=cls.model.error_msg + err_msg, + error_count=cls.model.error_count + error_count, + update_time=current_timestamp(), + update_date=timestamp_to_date(current_timestamp()), + ).where(cls.model.id == id).execute() @classmethod def increase_removed_docs(cls, id, removed_count, err_msg="", error_count=0): @@ -429,6 +419,7 @@ class SyncLogsService(CommonService): @classmethod def duplicate_and_parse(cls, kb, docs, tenant_id, src, auto_parse=True): from api.db.services.file_service import FileService + if not docs: return None @@ -442,7 +433,15 @@ class SyncLogsService(CommonService): return self.blob errs = [] - files = [FileObj(id=d["id"], filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"], fingerprint=d.get("fingerprint")) for d in docs] + files = [ + FileObj( + id=d["id"], + filename=d["semantic_identifier"] + (f"{d['extension']}" if d["semantic_identifier"][::-1].find(d["extension"][::-1]) < 0 else ""), + blob=d["blob"], + fingerprint=d.get("fingerprint"), + ) + for d in docs + ] doc_ids = [] err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src) errs.extend(err) @@ -451,17 +450,17 @@ class SyncLogsService(CommonService): metadata_map = {} for d in docs: if d.get("metadata"): - filename = d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else "") + filename = d["semantic_identifier"] + (f"{d['extension']}" if d["semantic_identifier"][::-1].find(d["extension"][::-1]) < 0 else "") metadata_map[filename] = d["metadata"] kb_table_num_map = {} for doc, _ in doc_blob_pairs: doc_ids.append(doc["id"]) - + # Set metadata if available for this document if doc["name"] in metadata_map: DocMetadataService.update_document_metadata(doc["id"], metadata_map[doc["name"]]) - + if not auto_parse or auto_parse == "0": continue DocumentService.run(tenant_id, doc, kb_table_num_map) @@ -470,10 +469,7 @@ class SyncLogsService(CommonService): @classmethod def get_latest_task(cls, connector_id, kb_id, task_type=None): - query = cls.model.select().where( - cls.model.connector_id==connector_id, - cls.model.kb_id == kb_id - ) + query = cls.model.select().where(cls.model.connector_id == connector_id, cls.model.kb_id == kb_id) if task_type is not None: query = query.where(cls.model.task_type == task_type) return query.order_by(cls.model.update_time.desc()).first() @@ -483,7 +479,7 @@ class Connector2KbService(CommonService): model = Connector2Kb @classmethod - def link_connectors(cls, kb_id:str, connectors: list[dict], tenant_id:str): + def link_connectors(cls, kb_id: str, connectors: list[dict], tenant_id: str): arr = cls.query(kb_id=kb_id) old_conn_ids = [a.connector_id for a in arr] connector_ids = [] @@ -491,14 +487,9 @@ class Connector2KbService(CommonService): conn_id = conn["id"] connector_ids.append(conn_id) if conn_id in old_conn_ids: - cls.filter_update([cls.model.connector_id==conn_id, cls.model.kb_id==kb_id], {"auto_parse": conn.get("auto_parse", "1")}) + cls.filter_update([cls.model.connector_id == conn_id, cls.model.kb_id == kb_id], {"auto_parse": conn.get("auto_parse", "1")}) continue - cls.save(**{ - "id": get_uuid(), - "connector_id": conn_id, - "kb_id": kb_id, - "auto_parse": conn.get("auto_parse", "1") - }) + cls.save(**{"id": get_uuid(), "connector_id": conn_id, "kb_id": kb_id, "auto_parse": conn.get("auto_parse", "1")}) SyncLogsService.schedule(conn_id, kb_id, reindex=True, task_type=ConnectorTaskType.SYNC) e, full_conn = ConnectorService.get_by_id(conn_id) if e and (full_conn.config or {}).get("sync_deleted_files"): @@ -508,31 +499,20 @@ class Connector2KbService(CommonService): for conn_id in old_conn_ids: if conn_id in connector_ids: continue - cls.filter_delete([cls.model.kb_id==kb_id, cls.model.connector_id==conn_id]) + cls.filter_delete([cls.model.kb_id == kb_id, cls.model.connector_id == conn_id]) e, conn = ConnectorService.get_by_id(conn_id) if not e: continue - #SyncLogsService.filter_delete([SyncLogs.connector_id==conn_id, SyncLogs.kb_id==kb_id]) + # SyncLogsService.filter_delete([SyncLogs.connector_id==conn_id, SyncLogs.kb_id==kb_id]) # Do not delete docs while unlinking. - SyncLogsService.filter_update([SyncLogs.connector_id==conn_id, SyncLogs.kb_id==kb_id, SyncLogs.status.in_([TaskStatus.SCHEDULE, TaskStatus.RUNNING])], {"status": TaskStatus.CANCEL}) - #docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}") - #err = FileService.delete_docs([d.id for d in docs], tenant_id) - #if err: + SyncLogsService.filter_update([SyncLogs.connector_id == conn_id, SyncLogs.kb_id == kb_id, SyncLogs.status.in_([TaskStatus.SCHEDULE, TaskStatus.RUNNING])], {"status": TaskStatus.CANCEL}) + # docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}") + # err = FileService.delete_docs([d.id for d in docs], tenant_id) + # if err: # errs.append(err) return "\n".join(errs) @classmethod def list_connectors(cls, kb_id): - fields = [ - Connector.id, - Connector.source, - Connector.name, - cls.model.auto_parse, - Connector.status - ] - return list(cls.model.select(*fields)\ - .join(Connector, on=(cls.model.connector_id==Connector.id))\ - .where( - cls.model.kb_id==kb_id - ).dicts() - ) + fields = [Connector.id, Connector.source, Connector.name, cls.model.auto_parse, Connector.status] + return list(cls.model.select(*fields).join(Connector, on=(cls.model.connector_id == Connector.id)).where(cls.model.kb_id == kb_id).dicts()) diff --git a/api/db/services/conversation_service.py b/api/db/services/conversation_service.py index 73afe09e1c..274aae0833 100644 --- a/api/db/services/conversation_service.py +++ b/api/db/services/conversation_service.py @@ -79,9 +79,7 @@ class ConversationService(CommonService): # are still found on the first read after deploy — without # that fallback the writer would create a duplicate # conversation (splitting the channel's history). - sha256_id = hashlib.sha256( - f"{dialog_id}:{channel_id}:{chat_id}".encode("utf-8") - ).hexdigest()[:32] + sha256_id = hashlib.sha256(f"{dialog_id}:{channel_id}:{chat_id}".encode("utf-8")).hexdigest()[:32] # codeql[py/weak-sensitive-data-hashing] Intentional: the # MD5 here is a backward-compatibility lookup for rows # created under the previous hashing scheme. The @@ -89,9 +87,7 @@ class ConversationService(CommonService): # MD5 is read-only and only used to find-and-migrate # existing rows on first access. It is not used for # authentication or any other security-sensitive purpose. - legacy_id = hashlib.md5( - f"{dialog_id}:{channel_id}:{chat_id}".encode("utf-8") - ).hexdigest()[:32] + legacy_id = hashlib.md5(f"{dialog_id}:{channel_id}:{chat_id}".encode("utf-8")).hexdigest()[:32] conv = cls.model.get_or_none(cls.model.id == sha256_id) if conv is not None: # SHA row already present. A previous call may have @@ -245,29 +241,21 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses "dialog_id": chat_id, "name": name, "message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue"), "created_at": time.time()}], - "user_id": kwargs.get("user_id", "") + "user_id": kwargs.get("user_id", ""), } ConversationService.save(**conv) if stream: - yield "data:" + json.dumps({"code": 0, "message": "", - "data": { - "answer": conv["message"][0]["content"], - "reference": {}, - "audio_binary": None, - "id": None, - "session_id": session_id - }}, - ensure_ascii=False) + "\n\n" + yield ( + "data:" + + json.dumps( + {"code": 0, "message": "", "data": {"answer": conv["message"][0]["content"], "reference": {}, "audio_binary": None, "id": None, "session_id": session_id}}, ensure_ascii=False + ) + + "\n\n" + ) yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" return else: - answer = { - "answer": conv["message"][0]["content"], - "reference": {}, - "audio_binary": None, - "id": None, - "session_id": session_id - } + answer = {"answer": conv["message"][0]["content"], "reference": {}, "audio_binary": None, "id": None, "session_id": session_id} yield answer return @@ -277,11 +265,7 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses conv = conv[0] msg = [] - question = { - "content": question, - "role": "user", - "id": str(uuid4()) - } + question = {"content": question, "role": "user", "id": str(uuid4())} # Propagate runtime attachments so downstream chat flow can resolve file content. if isinstance(kwargs.get("files"), list) and kwargs["files"]: @@ -297,7 +281,7 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses message_id = msg[-1].get("id") e, dia = DialogService.get_by_id(conv.dialog_id) - kb_ids = kwargs.get("kb_ids",[]) + kb_ids = kwargs.get("kb_ids", []) dia.kb_ids = list(set(dia.kb_ids + kb_ids)) if not conv.reference: conv.reference = [] @@ -311,9 +295,7 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n" ConversationService.update_by_id(conv.id, conv.to_dict()) except Exception as e: - yield "data:" + json.dumps({"code": 500, "message": str(e), - "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, - ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n" else: @@ -324,15 +306,13 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses break yield answer + async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, tenant_id=None, **kwargs): if tenant_id: exists, dia = DialogService.get_by_id(dialog_id) - if (not exists - or getattr(dia, "tenant_id", None) != tenant_id - or str(getattr(dia, "status", "")) != StatusEnum.VALID.value): + if not exists or getattr(dia, "tenant_id", None) != tenant_id or str(getattr(dia, "status", "")) != StatusEnum.VALID.value: logger.warning( - "Dialog lookup failed for tenant-scoped iframe completion: " - "tenant_id=%s dialog_id=%s required_status=%s", + "Dialog lookup failed for tenant-scoped iframe completion: tenant_id=%s dialog_id=%s required_status=%s", tenant_id, dialog_id, StatusEnum.VALID.value, @@ -343,22 +323,13 @@ async def async_iframe_completion(dialog_id, question, session_id=None, stream=T assert e, "Dialog not found" if not session_id: session_id = get_uuid() - conv = { - "id": session_id, - "dialog_id": dialog_id, - "user_id": kwargs.get("user_id", ""), - "message": [{"role": "assistant", "content": dia.prompt_config["prologue"], "created_at": time.time()}] - } + conv = {"id": session_id, "dialog_id": dialog_id, "user_id": kwargs.get("user_id", ""), "message": [{"role": "assistant", "content": dia.prompt_config["prologue"], "created_at": time.time()}]} API4ConversationService.save(**conv) - yield "data:" + json.dumps({"code": 0, "message": "", - "data": { - "answer": conv["message"][0]["content"], - "reference": {}, - "audio_binary": None, - "id": None, - "session_id": session_id - }}, - ensure_ascii=False) + "\n\n" + yield ( + "data:" + + json.dumps({"code": 0, "message": "", "data": {"answer": conv["message"][0]["content"], "reference": {}, "audio_binary": None, "id": None, "session_id": session_id}}, ensure_ascii=False) + + "\n\n" + ) yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" return else: @@ -370,11 +341,7 @@ async def async_iframe_completion(dialog_id, question, session_id=None, stream=T if not conv.message: conv.message = [] messages = conv.message - question = { - "role": "user", - "content": question, - "id": str(uuid4()) - } + question = {"role": "user", "content": question, "id": str(uuid4())} messages.append(question) msg = [] @@ -396,13 +363,10 @@ async def async_iframe_completion(dialog_id, question, session_id=None, stream=T try: async for ans in async_chat(dia, msg, True, session_id=session_id, **kwargs): ans = structure_answer(conv, ans, message_id, session_id) - yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, - ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" API4ConversationService.append_message(conv.id, conv.to_dict()) except Exception as e: - yield "data:" + json.dumps({"code": 500, "message": str(e), - "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, - ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" else: diff --git a/api/db/services/doc_metadata_service.py b/api/db/services/doc_metadata_service.py index f07c8fb052..8f2f2a875c 100644 --- a/api/db/services/doc_metadata_service.py +++ b/api/db/services/doc_metadata_service.py @@ -93,13 +93,14 @@ class DocMetadataService: if not flat_meta or not isinstance(flat_meta, dict): return {} - meta_fields = flat_meta.get('meta_fields') + meta_fields = flat_meta.get("meta_fields") if not meta_fields: return {} # Parse JSON string if needed if isinstance(meta_fields, str): import json + try: return json.loads(meta_fields) except json.JSONDecodeError: @@ -125,7 +126,7 @@ class DocMetadataService: """ if hit: # ES format: doc is in _source, id is in _id - return hit.get('_id', '') + return hit.get("_id", "") # DataFrame or list format: check multiple possible fields return doc.get("doc_id") or doc.get("_id") or doc.get("id", "") @@ -146,7 +147,7 @@ class DocMetadataService: results = results[0] # Extract DataFrame from tuple # Check if results is a pandas DataFrame (from Infinity) - if hasattr(results, 'iterrows'): + if hasattr(results, "iterrows"): # Handle pandas DataFrame - use iterrows() to iterate over rows for _, row in results.iterrows(): doc = dict(row) # Convert Series to dict @@ -156,11 +157,11 @@ class DocMetadataService: # Check if ES format (has 'hits' key) # Note: ES returns ObjectApiResponse which is dict-like but not isinstance(dict) - elif hasattr(results, 'get') and 'hits' in results: + elif hasattr(results, "get") and "hits" in results: # ES format: {"hits": {"hits": [{"_source": {...}, "_id": "..."}]}} - hits = results.get('hits', {}).get('hits', []) + hits = results.get("hits", {}).get("hits", []) for hit in hits: - doc = hit.get('_source', {}) + doc = hit.get("_source", {}) doc_id = cls._extract_doc_id(doc, hit) if doc_id: yield doc_id, doc @@ -179,7 +180,7 @@ class DocMetadataService: yield doc_id, doc # Check if OceanBase SearchResult format - elif hasattr(results, 'chunks') and hasattr(results, 'total'): + elif hasattr(results, "chunks") and hasattr(results, "total"): # OceanBase format: SearchResult(total=int, chunks=[{...}, {...}]) for doc in results.chunks: doc_id = cls._extract_doc_id(doc) @@ -237,7 +238,7 @@ class DocMetadataService: offset=page * page_size, limit=page_size, index_names=index_name, - knowledgebase_ids=[kb_id] + knowledgebase_ids=[kb_id], ) # Handle different result formats @@ -251,28 +252,28 @@ class DocMetadataService: # Check for Infinity format first (DataFrame, total) tuple if isinstance(results, tuple) and len(results) == 2: df, total_count = results - if hasattr(df, 'iterrows'): + if hasattr(df, "iterrows"): # Pandas DataFrame from Infinity - page_docs = df.to_dict('records') + page_docs = df.to_dict("records") else: page_docs = list(df) if df else [] # Check for ES format (dict with 'hits' key) - elif hasattr(results, 'get') and 'hits' in results: - hits_obj = results.get('hits', {}) - hits = hits_obj.get('hits', []) + elif hasattr(results, "get") and "hits" in results: + hits_obj = results.get("hits", {}) + hits = hits_obj.get("hits", []) page_docs = [] for hit in hits: - doc = hit.get('_source', {}) - doc['id'] = hit.get('_id', '') # Add _id as 'id' for _extract_doc_id to work + doc = hit.get("_source", {}) + doc["id"] = hit.get("_id", "") # Add _id as 'id' for _extract_doc_id to work page_docs.append(doc) # Extract total count from ES response - total_hits = hits_obj.get('total', {}) + total_hits = hits_obj.get("total", {}) if isinstance(total_hits, dict): - total_count = total_hits.get('value', len(page_docs)) + total_count = total_hits.get("value", len(page_docs)) else: total_count = total_hits if total_hits else len(page_docs) # Handle list/iterable results - elif hasattr(results, '__iter__') and not isinstance(results, dict): + elif hasattr(results, "__iter__") and not isinstance(results, dict): page_docs = list(results) else: page_docs = [] @@ -323,7 +324,7 @@ class DocMetadataService: if isinstance(item, str): # Split by common delimiters: Chinese comma (、), regular comma (,), pipe (|), semicolon (;), Chinese semicolon (;) # Also handle mixed delimiters and spaces - split_items = re.split(r'[、,,;;|]+', item.strip()) + split_items = re.split(r"[、,,;;|]+", item.strip()) # Trim whitespace and filter empty strings split_items = [s.strip() for s in split_items if s.strip()] if split_items: @@ -358,9 +359,7 @@ class DocMetadataService: """ try: # Get document with tenant_id (need to join with Knowledgebase) - doc_query = Document.select(Document, Knowledgebase.tenant_id).join( - Knowledgebase, on=(Knowledgebase.id == Document.kb_id) - ).where(Document.id == doc_id) + doc_query = Document.select(Document, Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id)).where(Document.id == doc_id) doc = doc_query.first() if not doc: @@ -406,11 +405,7 @@ class DocMetadataService: logging.debug(f"Metadata table already exists: {index_name}") # Insert into ES/Infinity - result = settings.docStoreConn.insert( - [doc_meta], - index_name, - kb_id - ) + result = settings.docStoreConn.insert([doc_meta], index_name, kb_id) if result: logging.error(f"Failed to insert metadata for document {doc_id}: {result}") @@ -427,11 +422,7 @@ class DocMetadataService: # A failed refresh can leave just-inserted metadata # invisible to subsequent reads; surface it so operators # can correlate stale-read complaints with the cause. - logging.warning( - f"Failed to refresh metadata index {index_name} on backend " - f"{type(settings.docStoreConn).__name__}; " - f"metadata may not be immediately searchable" - ) + logging.warning(f"Failed to refresh metadata index {index_name} on backend {type(settings.docStoreConn).__name__}; metadata may not be immediately searchable") else: logging.debug(f"Backend {type(settings.docStoreConn).__name__} has no refresh_idx; skipping") @@ -460,9 +451,7 @@ class DocMetadataService: """ try: # Get document with tenant_id - doc_query = Document.select(Document, Knowledgebase.tenant_id).join( - Knowledgebase, on=(Knowledgebase.id == Document.kb_id) - ).where(Document.id == doc_id) + doc_query = Document.select(Document, Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id)).where(Document.id == doc_id) doc = doc_query.first() if not doc: @@ -478,8 +467,7 @@ class DocMetadataService: # Post-process to split combined values processed_meta = cls._split_combined_values(meta_fields) - logging.debug( - f"[update_document_metadata] Updating doc_id: {doc_id}, kb_id: {kb_id}, meta_fields: {processed_meta}") + logging.debug(f"[update_document_metadata] Updating doc_id: {doc_id}, kb_id: {kb_id}, meta_fields: {processed_meta}") # For Elasticsearch, use efficient partial update if not settings.DOC_ENGINE_INFINITY and not settings.DOC_ENGINE_OCEANBASE: @@ -487,8 +475,7 @@ class DocMetadataService: index_exists = settings.docStoreConn.index_exist(index_name, "") if not index_exists: # Index doesn't exist - create it and insert directly - logging.debug( - f"[update_document_metadata] Index {index_name} does not exist, creating and inserting") + logging.debug(f"[update_document_metadata] Index {index_name} does not exist, creating and inserting") result = settings.docStoreConn.create_doc_meta_idx(index_name) if result is False: logging.error(f"Failed to create metadata index {index_name}") @@ -497,11 +484,7 @@ class DocMetadataService: # Index exists - check if document exists try: - doc_exists = settings.docStoreConn.get( - doc_id, - index_name, - [kb_id] - ) + doc_exists = settings.docStoreConn.get(doc_id, index_name, [kb_id]) if doc_exists: # Document exists - replace meta_fields entirely. # Using update with a `doc` body would deep-merge the meta_fields @@ -509,13 +492,9 @@ class DocMetadataService: # to a backend-provided scripted assignment that fully overwrites it. replace_meta_fields = getattr(settings.docStoreConn, "replace_meta_fields", None) if callable(replace_meta_fields) and replace_meta_fields(index_name, doc_id, processed_meta): - logging.debug( - f"Successfully updated metadata for document {doc_id} via {type(settings.docStoreConn).__name__}.replace_meta_fields") + logging.debug(f"Successfully updated metadata for document {doc_id} via {type(settings.docStoreConn).__name__}.replace_meta_fields") return True - logging.warning( - f"replace_meta_fields unavailable or failed on backend " - f"{type(settings.docStoreConn).__name__}; falling back to delete+insert" - ) + logging.warning(f"replace_meta_fields unavailable or failed on backend {type(settings.docStoreConn).__name__}; falling back to delete+insert") # Mirror the Infinity fallback below so a failed scripted # replace still guarantees full overwrite semantics rather # than leaking through the "document not found" branch. @@ -570,8 +549,7 @@ class DocMetadataService: # Check if metadata table exists before attempting deletion # This is the key optimization - no table = no metadata = nothing to delete if not settings.docStoreConn.index_exist(index_name, ""): - logging.debug( - f"Metadata table {index_name} does not exist, skipping metadata deletion for document {doc_id}") + logging.debug(f"Metadata table {index_name} does not exist, skipping metadata deletion for document {doc_id}") return True # No metadata to delete is considered success # Try to get the metadata to confirm it exists before deleting @@ -580,7 +558,7 @@ class DocMetadataService: existing_metadata = settings.docStoreConn.get( doc_id, index_name, - [""] # Empty list for metadata tables + [""], # Empty list for metadata tables ) logging.debug(f"[METADATA DELETE] Get result: {existing_metadata is not None}") if not existing_metadata: @@ -598,7 +576,7 @@ class DocMetadataService: deleted_count = settings.docStoreConn.delete( {"id": doc_id}, index_name, - kb_id # Pass actual kb_id (delete() will handle metadata tables correctly) + kb_id, # Pass actual kb_id (delete() will handle metadata tables correctly) ) logging.debug(f"[METADATA DELETE] Deleted count: {deleted_count}") return True @@ -639,7 +617,7 @@ class DocMetadataService: if count_value < 0: raise RuntimeError("native count_idx unavailable or failed") logging.debug(f"[DROP EMPTY TABLE] count_idx API result: {count_value} documents") - is_empty = (count_value == 0) + is_empty = count_value == 0 except Exception as e: logging.warning(f"[DROP EMPTY TABLE] Count API failed, falling back to search: {e}") # Fallback to search if count fails @@ -652,7 +630,7 @@ class DocMetadataService: offset=0, limit=1, # Only need 1 result to know if table is non-empty index_names=index_name, - knowledgebase_ids=[""] # Metadata tables don't filter by KB + knowledgebase_ids=[""], # Metadata tables don't filter by KB ) logging.debug(f"[DROP EMPTY TABLE] Search results type: {type(results)}, results: {results}") @@ -661,25 +639,24 @@ class DocMetadataService: if isinstance(results, tuple) and len(results) == 2: # Infinity returns (DataFrame, int) df, total = results - logging.debug( - f"[DROP EMPTY TABLE] Infinity format - total: {total}, df length: {len(df) if hasattr(df, '__len__') else 'N/A'}") - is_empty = (total == 0 or (hasattr(df, '__len__') and len(df) == 0)) - elif hasattr(results, 'get') and 'hits' in results: + logging.debug(f"[DROP EMPTY TABLE] Infinity format - total: {total}, df length: {len(df) if hasattr(df, '__len__') else 'N/A'}") + is_empty = total == 0 or (hasattr(df, "__len__") and len(df) == 0) + elif hasattr(results, "get") and "hits" in results: # ES format - MUST check this before hasattr(results, '__len__') # because ES response objects also have __len__ - total = results.get('hits', {}).get('total', {}) - hits = results.get('hits', {}).get('hits', []) + total = results.get("hits", {}).get("total", {}) + hits = results.get("hits", {}).get("hits", []) # ES 7.x+: total is a dict like {'value': 0, 'relation': 'eq'} # ES 6.x: total is an int if isinstance(total, dict): - total_count = total.get('value', 0) + total_count = total.get("value", 0) else: total_count = total logging.debug(f"[DROP EMPTY TABLE] ES format - total: {total_count}, hits count: {len(hits)}") - is_empty = (total_count == 0 or len(hits) == 0) - elif hasattr(results, '__len__'): + is_empty = total_count == 0 or len(hits) == 0 + elif hasattr(results, "__len__"): # DataFrame or list (check this AFTER ES format) result_len = len(results) logging.debug(f"[DROP EMPTY TABLE] List/DataFrame format - length: {result_len}") @@ -713,9 +690,7 @@ class DocMetadataService: """ try: # Get document with tenant_id - doc_query = Document.select(Document, Knowledgebase.tenant_id).join( - Knowledgebase, on=(Knowledgebase.id == Document.kb_id) - ).where(Document.id == doc_id) + doc_query = Document.select(Document, Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id)).where(Document.id == doc_id) doc = doc_query.first() if not doc: @@ -729,11 +704,7 @@ class DocMetadataService: index_name = cls._get_doc_meta_index_name(tenant_id) # Try to get metadata from ES/Infinity - metadata_doc = settings.docStoreConn.get( - doc_id, - index_name, - [kb_id] - ) + metadata_doc = settings.docStoreConn.get(doc_id, index_name, [kb_id]) if metadata_doc: # Extract and unflatten metadata @@ -791,7 +762,7 @@ class DocMetadataService: offset=offset, limit=page_size, index_names=index_name, - knowledgebase_ids=kb_ids + knowledgebase_ids=kb_ids, ) batch_docs = list(cls._iter_search_results(batch)) if not batch_docs: @@ -799,7 +770,10 @@ class DocMetadataService: all_results.extend(batch_docs) logging.debug( "[get_flatted_meta_by_kbs] offset=%d batch=%d total=%d kb_ids=%s", - offset, len(batch_docs), len(all_results), kb_ids, + offset, + len(batch_docs), + len(all_results), + kb_ids, ) if len(batch_docs) < page_size: break @@ -826,12 +800,12 @@ class DocMetadataService: doc_count = len(all_results) if doc_count >= 100000: logging.warning( - "[get_flatted_meta_by_kbs] Large result set: %d documents for KBs %s. " - "Consider performance impact.", doc_count, kb_ids, + "[get_flatted_meta_by_kbs] Large result set: %d documents for KBs %s. Consider performance impact.", + doc_count, + kb_ids, ) - logging.debug("[get_flatted_meta_by_kbs] KBs: %s, Retrieved %d documents, metadata: %s", - kb_ids, doc_count, meta) + logging.debug("[get_flatted_meta_by_kbs] KBs: %s, Retrieved %d documents, metadata: %s", kb_ids, doc_count, meta) return meta except Exception as e: @@ -840,11 +814,11 @@ class DocMetadataService: @classmethod def filter_doc_ids_by_meta_pushdown( - cls, - kb_ids: List[str], - filters: List[Dict], - logic: str = "and", - limit: int = 10000, + cls, + kb_ids: List[str], + filters: List[Dict], + logic: str = "and", + limit: int = 10000, ) -> Optional[List[str]]: """Run a metadata filter directly against ES or Infinity, returning matching doc IDs. @@ -878,22 +852,18 @@ class DocMetadataService: return [] if settings.DOC_ENGINE_INFINITY: - return cls._filter_doc_ids_by_metadata_infinity( - index_name, kb_ids, filters, logic - ) + return cls._filter_doc_ids_by_metadata_infinity(index_name, kb_ids, filters, logic) else: - return cls._filter_doc_ids_by_metadata_es( - index_name, kb_ids, filters, logic, limit - ) + return cls._filter_doc_ids_by_metadata_es(index_name, kb_ids, filters, logic, limit) @classmethod def _filter_doc_ids_by_metadata_es( - cls, - index_name: str, - kb_ids: List[str], - filters: List[Dict], - logic: str, - limit: int, + cls, + index_name: str, + kb_ids: List[str], + filters: List[Dict], + logic: str, + limit: int, ) -> Optional[List[str]]: """ES push-down path for metadata filtering.""" from common.metadata_es_filter import ( @@ -942,9 +912,7 @@ class DocMetadataService: unique.append(did) if len(unique) >= limit: - logging.warning( - f"ES metadata filter hit limit {limit} for KBs {kb_ids}" - ) + logging.warning(f"ES metadata filter hit limit {limit} for KBs {kb_ids}") # Detect silent truncation: the push-down is a fast path, not # the system of record. When the query matched more than @@ -956,10 +924,7 @@ class DocMetadataService: # dropping docs. total = _es_response_total(response) if total is not None and total > limit: - logging.warning( - f"ES metadata filter result exceeds push-down cap, falling back to in-memory: " - f"total={total}, cap={limit}, kb_ids={kb_ids}" - ) + logging.warning(f"ES metadata filter result exceeds push-down cap, falling back to in-memory: total={total}, cap={limit}, kb_ids={kb_ids}") return None logging.debug(f"ES metadata filter returned {len(unique)} matches for KBs {kb_ids}") @@ -967,11 +932,11 @@ class DocMetadataService: @classmethod def _filter_doc_ids_by_metadata_infinity( - cls, - index_name: str, - kb_ids: List[str], - filters: List[Dict], - logic: str, + cls, + index_name: str, + kb_ids: List[str], + filters: List[Dict], + logic: str, ) -> Optional[List[str]]: """Infinity push-down path for metadata filtering.""" from common.metadata_infinity_filter import ( @@ -996,8 +961,7 @@ class DocMetadataService: table_instance = db_instance.get_table(index_name) df, _ = table_instance.output(["id"]).filter(where_clause).to_df() doc_ids = extract_doc_ids(df) - logging.debug( - f"Infinity metadata filter returned {len(doc_ids)} doc IDs for kb_ids={kb_ids}, logic={logic}") + logging.debug(f"Infinity metadata filter returned {len(doc_ids)} doc IDs for kb_ids={kb_ids}, logic={logic}") return doc_ids finally: settings.docStoreConn.connPool.release_conn(inf_conn) @@ -1061,14 +1025,12 @@ class DocMetadataService: # Use helper to iterate over results for doc_id, doc in cls._iter_search_results(results): - # Extract metadata (handles both JSON strings and dicts) doc_meta = cls._extract_metadata(doc) if doc_meta: meta_mapping[doc_id] = doc_meta - logging.debug( - f"[get_metadata_for_documents] Found metadata for {len(meta_mapping)}/{len(doc_ids) if doc_ids else 'all'} documents") + logging.debug(f"[get_metadata_for_documents] Found metadata for {len(meta_mapping)}/{len(doc_ids) if doc_ids else 'all'} documents") return meta_mapping except Exception as e: @@ -1099,7 +1061,7 @@ class DocMetadataService: """Check if a string value is an ISO 8601 datetime (e.g., '2026-02-03T00:00:00').""" if not isinstance(value, str): return False - return bool(re.match(r'^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$', value)) + return bool(re.match(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$", value)) def _meta_value_type(value): """Determine the type of a metadata value.""" @@ -1131,7 +1093,6 @@ class DocMetadataService: # Use helper to iterate over results in any format for doc_id, doc in cls._iter_search_results(results): - doc_meta = cls._extract_metadata(doc) for k, v in doc_meta.items(): @@ -1334,8 +1295,7 @@ class DocMetadataService: doc_ids_set = set(doc_ids) missing_doc_ids = doc_ids_set - found_doc_ids if missing_doc_ids and updates: - logging.debug( - f"[batch_update_metadata] Inserting new metadata for documents without metadata rows: {missing_doc_ids}") + logging.debug(f"[batch_update_metadata] Inserting new metadata for documents without metadata rows: {missing_doc_ids}") for doc_id in missing_doc_ids: # Apply updates to create new metadata meta = {} diff --git a/api/db/services/evaluation_service.py b/api/db/services/evaluation_service.py index 8a4878ad3a..d2dadef852 100644 --- a/api/db/services/evaluation_service.py +++ b/api/db/services/evaluation_service.py @@ -50,8 +50,7 @@ class EvaluationService(CommonService): # ==================== Dataset Management ==================== @classmethod - def create_dataset(cls, name: str, description: str, kb_ids: List[str], - tenant_id: str, user_id: str) -> Tuple[bool, str]: + def create_dataset(cls, name: str, description: str, kb_ids: List[str], tenant_id: str, user_id: str) -> Tuple[bool, str]: """ Create a new evaluation dataset. @@ -66,7 +65,7 @@ class EvaluationService(CommonService): (success, dataset_id or error_message) """ try: - timestamp= current_timestamp() + timestamp = current_timestamp() dataset_id = get_uuid() dataset = { "id": dataset_id, @@ -77,7 +76,7 @@ class EvaluationService(CommonService): "created_by": user_id, "create_time": timestamp, "update_time": timestamp, - "status": StatusEnum.VALID.value + "status": StatusEnum.VALID.value, } if not EvaluationDataset.create(**dataset): @@ -101,22 +100,15 @@ class EvaluationService(CommonService): return None @classmethod - def list_datasets(cls, tenant_id: str, user_id: str, - page: int = 1, page_size: int = 20) -> Dict[str, Any]: + def list_datasets(cls, tenant_id: str, user_id: str, page: int = 1, page_size: int = 20) -> Dict[str, Any]: """List datasets for a tenant""" try: - query = EvaluationDataset.select().where( - (EvaluationDataset.tenant_id == tenant_id) & - (EvaluationDataset.status == StatusEnum.VALID.value) - ).order_by(EvaluationDataset.create_time.desc()) + query = EvaluationDataset.select().where((EvaluationDataset.tenant_id == tenant_id) & (EvaluationDataset.status == StatusEnum.VALID.value)).order_by(EvaluationDataset.create_time.desc()) total = query.count() datasets = query.paginate(page, page_size) - return { - "total": total, - "datasets": [d.to_dict() for d in datasets] - } + return {"total": total, "datasets": [d.to_dict() for d in datasets]} except Exception as e: logging.error(f"Error listing datasets: {e}") return {"total": 0, "datasets": []} @@ -126,9 +118,7 @@ class EvaluationService(CommonService): """Update dataset""" try: kwargs["update_time"] = current_timestamp() - return EvaluationDataset.update(**kwargs).where( - EvaluationDataset.id == dataset_id - ).execute() > 0 + return EvaluationDataset.update(**kwargs).where(EvaluationDataset.id == dataset_id).execute() > 0 except Exception as e: logging.error(f"Error updating dataset {dataset_id}: {e}") return False @@ -137,10 +127,7 @@ class EvaluationService(CommonService): def delete_dataset(cls, dataset_id: str) -> bool: """Soft delete dataset""" try: - return EvaluationDataset.update( - status=StatusEnum.INVALID.value, - update_time=current_timestamp() - ).where(EvaluationDataset.id == dataset_id).execute() > 0 + return EvaluationDataset.update(status=StatusEnum.INVALID.value, update_time=current_timestamp()).where(EvaluationDataset.id == dataset_id).execute() > 0 except Exception as e: logging.error(f"Error deleting dataset {dataset_id}: {e}") return False @@ -148,11 +135,15 @@ class EvaluationService(CommonService): # ==================== Test Case Management ==================== @classmethod - def add_test_case(cls, dataset_id: str, question: str, - reference_answer: Optional[str] = None, - relevant_doc_ids: Optional[List[str]] = None, - relevant_chunk_ids: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None) -> Tuple[bool, str]: + def add_test_case( + cls, + dataset_id: str, + question: str, + reference_answer: Optional[str] = None, + relevant_doc_ids: Optional[List[str]] = None, + relevant_chunk_ids: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Tuple[bool, str]: """ Add a test case to a dataset. @@ -177,7 +168,7 @@ class EvaluationService(CommonService): "relevant_doc_ids": relevant_doc_ids, "relevant_chunk_ids": relevant_chunk_ids, "metadata": metadata, - "create_time": current_timestamp() + "create_time": current_timestamp(), } if not EvaluationCase.create(**case): @@ -192,9 +183,7 @@ class EvaluationService(CommonService): def get_test_cases(cls, dataset_id: str) -> List[Dict[str, Any]]: """Get all test cases for a dataset""" try: - cases = EvaluationCase.select().where( - EvaluationCase.dataset_id == dataset_id - ).order_by(EvaluationCase.create_time) + cases = EvaluationCase.select().where(EvaluationCase.dataset_id == dataset_id).order_by(EvaluationCase.create_time) return [c.to_dict() for c in cases] except Exception as e: @@ -205,9 +194,7 @@ class EvaluationService(CommonService): def delete_test_case(cls, case_id: str) -> bool: """Delete a test case""" try: - return EvaluationCase.delete().where( - EvaluationCase.id == case_id - ).execute() > 0 + return EvaluationCase.delete().where(EvaluationCase.id == case_id).execute() > 0 except Exception as e: logging.error(f"Error deleting test case {case_id}: {e}") return False @@ -227,10 +214,10 @@ class EvaluationService(CommonService): success_count = 0 failure_count = 0 case_instances = [] - + if not cases: return success_count, failure_count - + cur_timestamp = current_timestamp() try: @@ -244,7 +231,7 @@ class EvaluationService(CommonService): "relevant_doc_ids": case_data.get("relevant_doc_ids"), "relevant_chunk_ids": case_data.get("relevant_chunk_ids"), "metadata": case_data.get("metadata"), - "create_time": cur_timestamp + "create_time": cur_timestamp, } case_instances.append(EvaluationCase(**case_info)) @@ -262,8 +249,7 @@ class EvaluationService(CommonService): # ==================== Evaluation Execution ==================== @classmethod - def start_evaluation(cls, dataset_id: str, dialog_id: str, - user_id: str, name: Optional[str] = None) -> Tuple[bool, str]: + def start_evaluation(cls, dataset_id: str, dialog_id: str, user_id: str, name: Optional[str] = None) -> Tuple[bool, str]: """ Start an evaluation run. @@ -297,7 +283,7 @@ class EvaluationService(CommonService): "status": "RUNNING", "created_by": user_id, "create_time": current_timestamp(), - "complete_time": None + "complete_time": None, } if not EvaluationRun.create(**run): @@ -324,10 +310,7 @@ class EvaluationService(CommonService): test_cases = cls.get_test_cases(dataset_id) if not test_cases: - EvaluationRun.update( - status="FAILED", - complete_time=current_timestamp() - ).where(EvaluationRun.id == run_id).execute() + EvaluationRun.update(status="FAILED", complete_time=current_timestamp()).where(EvaluationRun.id == run_id).execute() return # Execute each test case @@ -341,22 +324,14 @@ class EvaluationService(CommonService): metrics_summary = cls._compute_summary_metrics(results) # Update run status - EvaluationRun.update( - status="COMPLETED", - metrics_summary=metrics_summary, - complete_time=current_timestamp() - ).where(EvaluationRun.id == run_id).execute() + EvaluationRun.update(status="COMPLETED", metrics_summary=metrics_summary, complete_time=current_timestamp()).where(EvaluationRun.id == run_id).execute() except Exception as e: logging.error(f"Error executing evaluation {run_id}: {e}") - EvaluationRun.update( - status="FAILED", - complete_time=current_timestamp() - ).where(EvaluationRun.id == run_id).execute() + EvaluationRun.update(status="FAILED", complete_time=current_timestamp()).where(EvaluationRun.id == run_id).execute() @classmethod - def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any], - dialog: Any) -> Optional[Dict[str, Any]]: + def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any], dialog: Any) -> Optional[Dict[str, Any]]: """ Evaluate a single test case. @@ -377,7 +352,6 @@ class EvaluationService(CommonService): answer = "" retrieved_chunks = [] - def _sync_from_async_gen(async_gen): result_queue: queue.Queue = queue.Queue() @@ -407,7 +381,6 @@ class EvaluationService(CommonService): raise item yield item - def chat(dialog, messages, stream=True, **kwargs): from api.db.services.dialog_service import async_chat @@ -434,7 +407,7 @@ class EvaluationService(CommonService): reference_answer=case.get("reference_answer"), retrieved_chunks=retrieved_chunks, relevant_chunk_ids=case.get("relevant_chunk_ids"), - dialog=dialog + dialog=dialog, ) # Track token usage: use full prompt from async_chat when available. @@ -446,8 +419,7 @@ class EvaluationService(CommonService): prompt_tokens = num_tokens_from_string(full_prompt) else: logging.debug( - "Evaluation case %s: ans has no 'prompt' key; using question-only count " - "(undercounts system + retrieved context)", + "Evaluation case %s: ans has no 'prompt' key; using question-only count (undercounts system + retrieved context)", case.get("id", "unknown"), ) prompt_tokens = num_tokens_from_string(case.get("question", "") or "") @@ -469,7 +441,7 @@ class EvaluationService(CommonService): "metrics": metrics, "execution_time": execution_time, "token_usage": token_usage, - "create_time": current_timestamp() + "create_time": current_timestamp(), } EvaluationResult.create(**result) @@ -480,11 +452,9 @@ class EvaluationService(CommonService): return None @classmethod - def _compute_metrics(cls, question: str, generated_answer: str, - reference_answer: Optional[str], - retrieved_chunks: List[Dict[str, Any]], - relevant_chunk_ids: Optional[List[str]], - dialog: Any) -> Dict[str, float]: + def _compute_metrics( + cls, question: str, generated_answer: str, reference_answer: Optional[str], retrieved_chunks: List[Dict[str, Any]], relevant_chunk_ids: Optional[List[str]], dialog: Any + ) -> Dict[str, float]: """ Compute evaluation metrics for a single test case. @@ -513,8 +483,7 @@ class EvaluationService(CommonService): return metrics @classmethod - def _compute_retrieval_metrics(cls, retrieved_ids: List[str], - relevant_ids: List[str]) -> Dict[str, float]: + def _compute_retrieval_metrics(cls, retrieved_ids: List[str], relevant_ids: List[str]) -> Dict[str, float]: """ Compute retrieval metrics. @@ -550,13 +519,7 @@ class EvaluationService(CommonService): mrr = 1.0 / i break - return { - "precision": precision, - "recall": recall, - "f1_score": f1, - "hit_rate": hit_rate, - "mrr": mrr - } + return {"precision": precision, "recall": recall, "f1_score": f1, "hit_rate": hit_rate, "mrr": mrr} @classmethod def _compute_summary_metrics(cls, results: List[Dict[str, Any]]) -> Dict[str, Any]: @@ -584,10 +547,7 @@ class EvaluationService(CommonService): metric_counts[key] = metric_counts.get(key, 0) + 1 # Compute averages - summary = { - "total_cases": len(results), - "avg_execution_time": sum(r.get("execution_time", 0) for r in results) / len(results) - } + summary = {"total_cases": len(results), "avg_execution_time": sum(r.get("execution_time", 0) for r in results) / len(results)} for key in metric_sums: summary[f"avg_{key}"] = metric_sums[key] / metric_counts[key] @@ -604,14 +564,9 @@ class EvaluationService(CommonService): if not run: return {} - results = EvaluationResult.select().where( - EvaluationResult.run_id == run_id - ).order_by(EvaluationResult.create_time) + results = EvaluationResult.select().where(EvaluationResult.run_id == run_id).order_by(EvaluationResult.create_time) - return { - "run": run.to_dict(), - "results": [r.to_dict() for r in results] - } + return {"run": run.to_dict(), "results": [r.to_dict() for r in results]} except Exception as e: logging.error(f"Error getting run results {run_id}: {e}") return {} @@ -637,43 +592,41 @@ class EvaluationService(CommonService): # Low precision: retrieving irrelevant chunks if metrics.get("avg_precision", 1.0) < 0.7: - recommendations.append({ - "issue": "Low Precision", - "severity": "high", - "description": "System is retrieving many irrelevant chunks", - "suggestions": [ - "Increase similarity_threshold to filter out less relevant chunks", - "Enable reranking to improve chunk ordering", - "Reduce top_k to return fewer chunks" - ] - }) + recommendations.append( + { + "issue": "Low Precision", + "severity": "high", + "description": "System is retrieving many irrelevant chunks", + "suggestions": ["Increase similarity_threshold to filter out less relevant chunks", "Enable reranking to improve chunk ordering", "Reduce top_k to return fewer chunks"], + } + ) # Low recall: missing relevant chunks if metrics.get("avg_recall", 1.0) < 0.7: - recommendations.append({ - "issue": "Low Recall", - "severity": "high", - "description": "System is missing relevant chunks", - "suggestions": [ - "Increase top_k to retrieve more chunks", - "Lower similarity_threshold to be more inclusive", - "Enable hybrid search (keyword + semantic)", - "Check chunk size - may be too large or too small" - ] - }) + recommendations.append( + { + "issue": "Low Recall", + "severity": "high", + "description": "System is missing relevant chunks", + "suggestions": [ + "Increase top_k to retrieve more chunks", + "Lower similarity_threshold to be more inclusive", + "Enable hybrid search (keyword + semantic)", + "Check chunk size - may be too large or too small", + ], + } + ) # Slow response time if metrics.get("avg_execution_time", 0) > 5.0: - recommendations.append({ - "issue": "Slow Response Time", - "severity": "medium", - "description": f"Average response time is {metrics['avg_execution_time']:.2f}s", - "suggestions": [ - "Reduce top_k to retrieve fewer chunks", - "Optimize embedding model selection", - "Consider caching frequently asked questions" - ] - }) + recommendations.append( + { + "issue": "Slow Response Time", + "severity": "medium", + "description": f"Average response time is {metrics['avg_execution_time']:.2f}s", + "suggestions": ["Reduce top_k to retrieve fewer chunks", "Optimize embedding model selection", "Consider caching frequently asked questions"], + } + ) return recommendations except Exception as e: diff --git a/api/db/services/mcp_server_service.py b/api/db/services/mcp_server_service.py index 1eae882d6f..101555f4b3 100644 --- a/api/db/services/mcp_server_service.py +++ b/api/db/services/mcp_server_service.py @@ -33,8 +33,7 @@ class MCPServerService(CommonService): @classmethod @DB.connection_context() - def get_servers(cls, tenant_id: str, id_list: list[str] | None, page_number, items_per_page, orderby, desc, - keywords): + def get_servers(cls, tenant_id: str, id_list: list[str] | None, page_number, items_per_page, orderby, desc, keywords): """Retrieve all MCP servers associated with a tenant. This method fetches all MCP servers for a given tenant, ordered by creation time. diff --git a/api/db/services/memory_service.py b/api/db/services/memory_service.py index 552747e6f1..1bb16cf796 100644 --- a/api/db/services/memory_service.py +++ b/api/db/services/memory_service.py @@ -66,11 +66,9 @@ class MemoryService(CommonService): cls.model.system_prompt, cls.model.user_prompt, cls.model.create_date, - cls.model.create_time + cls.model.create_time, ] - memory = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where( - cls.model.id == memory_id - ).first() + memory = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(cls.model.id == memory_id).first() return memory @classmethod @@ -89,16 +87,13 @@ class MemoryService(CommonService): cls.model.permissions, cls.model.description, cls.model.create_time, - cls.model.create_date + cls.model.create_date, ] memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)) if filter_dict.get("tenant_id"): memories = memories.where(cls.model.tenant_id.in_(filter_dict["tenant_id"])) if filter_dict.get("accessible_user_id"): - memories = memories.where( - (cls.model.tenant_id == filter_dict["accessible_user_id"]) | - (cls.model.permissions == "team") - ) + memories = memories.where((cls.model.tenant_id == filter_dict["accessible_user_id"]) | (cls.model.permissions == "team")) if filter_dict.get("memory_type"): memory_type_int = calculate_memory_type(filter_dict["memory_type"]) memories = memories.where(cls.model.memory_type.bin_and(memory_type_int) > 0) @@ -116,11 +111,7 @@ class MemoryService(CommonService): @DB.connection_context() def create_memory(cls, tenant_id: str, name: str, memory_type: List[str], embd_id: str, llm_id: str): # Deduplicate name within tenant - memory_name = duplicate_name( - cls.query, - name=name, - tenant_id=tenant_id - ) + memory_name = duplicate_name(cls.query, name=name, tenant_id=tenant_id) if len(memory_name) > MEMORY_NAME_LIMIT: return False, f"Memory name {memory_name} exceeds limit of {MEMORY_NAME_LIMIT}." @@ -159,15 +150,8 @@ class MemoryService(CommonService): if "memory_type" in update_dict and isinstance(update_dict["memory_type"], list): update_dict["memory_type"] = calculate_memory_type(update_dict["memory_type"]) if "name" in update_dict: - update_dict["name"] = duplicate_name( - cls.query, - name=update_dict["name"], - tenant_id=tenant_id - ) - update_dict.update({ - "update_time": current_timestamp(), - "update_date": get_format_time() - }) + update_dict["name"] = duplicate_name(cls.query, name=update_dict["name"], tenant_id=tenant_id) + update_dict.update({"update_time": current_timestamp(), "update_date": get_format_time()}) return cls.model.update(update_dict).where(cls.model.id == memory_id).execute() @@ -179,21 +163,13 @@ class MemoryService(CommonService): @classmethod @DB.connection_context() def get_null_tenant_embd_id_row(cls): - fields = [ - cls.model.id, - cls.model.tenant_id, - cls.model.embd_id - ] + fields = [cls.model.id, cls.model.tenant_id, cls.model.embd_id] objs = cls.model.select(*fields).where(cls.model.tenant_embd_id.is_null()) return list(objs) @classmethod @DB.connection_context() def get_null_tenant_llm_id_row(cls): - fields = [ - cls.model.id, - cls.model.tenant_id, - cls.model.llm_id - ] + fields = [cls.model.id, cls.model.tenant_id, cls.model.llm_id] objs = cls.model.select(*fields).where(cls.model.tenant_llm_id.is_null()) return list(objs) diff --git a/api/db/services/search_service.py b/api/db/services/search_service.py index 90d0e0e605..8e14a84bf2 100644 --- a/api/db/services/search_service.py +++ b/api/db/services/search_service.py @@ -96,8 +96,7 @@ class SearchService(CommonService): query = ( cls.model.select(*fields) .join(User, on=(cls.model.tenant_id == User.id)) - .where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & ( - cls.model.status == StatusEnum.VALID.value)) + .where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value)) ) if keywords: diff --git a/api/db/services/tenant_model_group_mapping_service.py b/api/db/services/tenant_model_group_mapping_service.py index 590c65129f..49f28cf0a5 100644 --- a/api/db/services/tenant_model_group_mapping_service.py +++ b/api/db/services/tenant_model_group_mapping_service.py @@ -28,4 +28,4 @@ class TenantModelGroupMappingService(CommonService): cls.model.provider_id == provider_id, cls.model.instance_id == instance_id, cls.model.model_id == model_id, - ) \ No newline at end of file + ) diff --git a/api/db/services/tenant_model_group_service.py b/api/db/services/tenant_model_group_service.py index 88781eb17e..73682af887 100644 --- a/api/db/services/tenant_model_group_service.py +++ b/api/db/services/tenant_model_group_service.py @@ -18,4 +18,4 @@ from api.db.services.common_service import CommonService class TenantModelGroupService(CommonService): - model = TenantModelGroup \ No newline at end of file + model = TenantModelGroup diff --git a/api/db/services/tenant_model_instance_service.py b/api/db/services/tenant_model_instance_service.py index 0f44de8962..76fcbf2417 100644 --- a/api/db/services/tenant_model_instance_service.py +++ b/api/db/services/tenant_model_instance_service.py @@ -18,6 +18,7 @@ from api.db.db_models import DB, TenantModelInstance from api.db.services.common_service import CommonService from api.db.services import duplicate_name + class TenantModelInstanceService(CommonService): model = TenantModelInstance @@ -48,22 +49,21 @@ class TenantModelInstanceService(CommonService): @classmethod @DB.connection_context() def get_by_provider_id_and_api_key(cls, provider_id, api_key): - return cls.model.get_or_none( - cls.model.provider_id == provider_id, - cls.model.api_key == api_key - ) + return cls.model.get_or_none(cls.model.provider_id == provider_id, cls.model.api_key == api_key) @classmethod @DB.connection_context() def delete_by_provider_id_and_instance_name(cls, provider_id, instance_name): - return cls.model.delete().where( - cls.model.provider_id == provider_id, - cls.model.instance_name == instance_name, - ).execute() + return ( + cls.model.delete() + .where( + cls.model.provider_id == provider_id, + cls.model.instance_name == instance_name, + ) + .execute() + ) @classmethod @DB.connection_context() def delete_by_provider_ids(cls, provider_ids): - return cls.model.delete().where( - cls.model.provider_id.in_(provider_ids) - ).execute() + return cls.model.delete().where(cls.model.provider_id.in_(provider_ids)).execute() diff --git a/api/db/services/tenant_model_provider_service.py b/api/db/services/tenant_model_provider_service.py index ee9a1ab951..91c95f13ed 100644 --- a/api/db/services/tenant_model_provider_service.py +++ b/api/db/services/tenant_model_provider_service.py @@ -49,12 +49,16 @@ class TenantModelProviderService(CommonService): @classmethod @DB.connection_context() def delete_by_tenant_id_and_provider_name(cls, tenant_id, provider_name): - return cls.model.delete().where( - cls.model.tenant_id == tenant_id, - cls.model.provider_name == provider_name, - ).execute() + return ( + cls.model.delete() + .where( + cls.model.tenant_id == tenant_id, + cls.model.provider_name == provider_name, + ) + .execute() + ) @classmethod @DB.connection_context() def list_provider_names_by_tenant_id(cls, tenant_id): - return [row.provider_name for row in cls.model.select(cls.model.provider_name).where(cls.model.tenant_id == tenant_id)] \ No newline at end of file + return [row.provider_name for row in cls.model.select(cls.model.provider_name).where(cls.model.tenant_id == tenant_id)] diff --git a/api/db/services/tenant_model_service.py b/api/db/services/tenant_model_service.py index 2c32d45295..98dc6e892a 100644 --- a/api/db/services/tenant_model_service.py +++ b/api/db/services/tenant_model_service.py @@ -29,21 +29,12 @@ class TenantModelService(CommonService): @classmethod @DB.connection_context() def get_by_provider_id_and_instance_id_and_model_type_and_model_name(cls, provider_id, instance_id, model_type, model_name): - return cls.model.get_or_none( - cls.model.provider_id == provider_id, - cls.model.instance_id == instance_id, - cls.model.model_type == model_type, - cls.model.model_name == model_name - ) + return cls.model.get_or_none(cls.model.provider_id == provider_id, cls.model.instance_id == instance_id, cls.model.model_type == model_type, cls.model.model_name == model_name) @classmethod @DB.connection_context() def get_by_provider_id_and_instance_id_and_model_type(cls, provider_id, instance_id, model_type): - return cls.model.get_or_none( - cls.model.provider_id == provider_id, - cls.model.instance_id == instance_id, - cls.model.model_type == model_type - ) + return cls.model.get_or_none(cls.model.provider_id == provider_id, cls.model.instance_id == instance_id, cls.model.model_type == model_type) @classmethod @DB.connection_context() @@ -66,22 +57,9 @@ class TenantModelService(CommonService): model_type_records = cls.model.select().where(cls.model.provider_id == provider_id, cls.model.instance_id == instance_id, cls.model.model_name == model_name) if not model_type_records: for _type in operation.get("add", []): - cls.insert( - model_name=model_name, - provider_id=provider_id, - instance_id=instance_id, - model_type=_type, - extra="{}" - ) + cls.insert(model_name=model_name, provider_id=provider_id, instance_id=instance_id, model_type=_type, extra="{}") for _type in operation.get("delete", []): - cls.insert( - model_name=model_name, - provider_id=provider_id, - instance_id=instance_id, - model_type=_type, - status=ActiveStatusEnum.UNSUPPORTED, - extra="{}" - ) + cls.insert(model_name=model_name, provider_id=provider_id, instance_id=instance_id, model_type=_type, status=ActiveStatusEnum.UNSUPPORTED, extra="{}") return len(operation.get("add", [])) + len(operation.get("delete", [])) model_record_example = [model_record for model_record in model_type_records if model_record.status != ActiveStatusEnum.UNSUPPORTED.value] extra_fields = model_record_example[0].extra if model_record_example else "{}" @@ -93,27 +71,13 @@ class TenantModelService(CommonService): cls.update_by_id(type_record_map[_type].id, {"status": model_status}) else: - cls.insert( - model_name=model_name, - provider_id=provider_id, - instance_id=instance_id, - model_type=_type, - status=model_status, - extra=extra_fields - ) + cls.insert(model_name=model_name, provider_id=provider_id, instance_id=instance_id, model_type=_type, status=model_status, extra=extra_fields) operated_cnt += 1 for _type in operation.get("delete", []): if type_record_map.get(_type): cls.update_by_id(type_record_map[_type].id, {"status": ActiveStatusEnum.UNSUPPORTED.value}) else: - cls.insert( - model_name=model_name, - provider_id=provider_id, - instance_id=instance_id, - model_type=_type, - status=ActiveStatusEnum.UNSUPPORTED.value, - extra=extra_fields - ) + cls.insert(model_name=model_name, provider_id=provider_id, instance_id=instance_id, model_type=_type, status=ActiveStatusEnum.UNSUPPORTED.value, extra=extra_fields) operated_cnt += 1 return operated_cnt diff --git a/api/db/services/user_canvas_version.py b/api/db/services/user_canvas_version.py index faaca89d10..315312f572 100644 --- a/api/db/services/user_canvas_version.py +++ b/api/db/services/user_canvas_version.py @@ -42,14 +42,7 @@ class UserCanvasVersionService(CommonService): def list_by_canvas_id(cls, user_canvas_id): try: user_canvas_version = cls.model.select( - *[cls.model.id, - cls.model.create_time, - cls.model.title, - cls.model.create_date, - cls.model.update_date, - cls.model.user_canvas_id, - cls.model.update_time, - cls.model.release] + *[cls.model.id, cls.model.create_time, cls.model.title, cls.model.create_date, cls.model.update_date, cls.model.user_canvas_id, cls.model.update_time, cls.model.release] ).where(cls.model.user_canvas_id == user_canvas_id) return user_canvas_version except DoesNotExist: @@ -138,12 +131,7 @@ class UserCanvasVersionService(CommonService): """ try: normalized_dsl = cls._normalize_dsl(dsl) - latest = ( - cls.model.select() - .where(cls.model.user_canvas_id == user_canvas_id) - .order_by(cls.model.create_time.desc()) - .first() - ) + latest = cls.model.select().where(cls.model.user_canvas_id == user_canvas_id).order_by(cls.model.create_time.desc()).first() # Repeated saves with the same DSL only refresh the latest snapshot. if latest and cls._normalize_dsl(latest.dsl) == normalized_dsl: diff --git a/api/db/services/user_service.py b/api/db/services/user_service.py index d6b985dd47..0fb7775ad7 100644 --- a/api/db/services/user_service.py +++ b/api/db/services/user_service.py @@ -39,13 +39,14 @@ class UserService(CommonService): Attributes: model: The User model class for database operations. """ + model = User @classmethod @DB.connection_context() def query(cls, cols=None, reverse=None, order_by=None, **kwargs): - if 'access_token' in kwargs: - access_token = kwargs['access_token'] + if "access_token" in kwargs: + access_token = kwargs["access_token"] # Reject empty, None, or whitespace-only access tokens if not access_token or not str(access_token).strip(): @@ -94,8 +95,7 @@ class UserService(CommonService): Returns: User object if authentication successful, None otherwise. """ - user = cls.model.select().where((cls.model.email == email), - (cls.model.status == StatusEnum.VALID.value)).first() + user = cls.model.select().where((cls.model.email == email), (cls.model.status == StatusEnum.VALID.value)).first() if user and check_password_hash(str(user.password), password): return user else: @@ -113,8 +113,7 @@ class UserService(CommonService): if "id" not in kwargs: kwargs["id"] = get_uuid() if "password" in kwargs: - kwargs["password"] = generate_password_hash( - str(kwargs["password"])) + kwargs["password"] = generate_password_hash(str(kwargs["password"])) current_ts = current_timestamp() current_date = datetime_format(datetime.now()) @@ -130,8 +129,7 @@ class UserService(CommonService): @DB.connection_context() def delete_user(cls, user_ids, update_user_dict): with DB.atomic(): - cls.model.update({"status": 0}).where( - cls.model.id.in_(user_ids)).execute() + cls.model.update({"status": 0}).where(cls.model.id.in_(user_ids)).execute() @classmethod @DB.connection_context() @@ -140,26 +138,19 @@ class UserService(CommonService): if user_dict: user_dict["update_time"] = current_timestamp() user_dict["update_date"] = datetime_format(datetime.now()) - cls.model.update(user_dict).where( - cls.model.id == user_id).execute() + cls.model.update(user_dict).where(cls.model.id == user_id).execute() @classmethod @DB.connection_context() def update_user_password(cls, user_id, new_password): with DB.atomic(): - update_dict = { - "password": generate_password_hash(str(new_password)), - "update_time": current_timestamp(), - "update_date": datetime_format(datetime.now()) - } + update_dict = {"password": generate_password_hash(str(new_password)), "update_time": current_timestamp(), "update_date": datetime_format(datetime.now())} cls.model.update(update_dict).where(cls.model.id == user_id).execute() @classmethod @DB.connection_context() def is_admin(cls, user_id): - return cls.model.select().where( - cls.model.id == user_id, - cls.model.is_superuser == 1).count() > 0 + return cls.model.select().where(cls.model.id == user_id, cls.model.is_superuser == 1).count() > 0 @classmethod @DB.connection_context() @@ -177,6 +168,7 @@ class TenantService(CommonService): Attributes: model: The Tenant model class for database operations. """ + model = Tenant @classmethod @@ -193,31 +185,32 @@ class TenantService(CommonService): cls.model.tts_id, cls.model.ocr_id, cls.model.parser_ids, - UserTenant.role] - return list(cls.model.select(*fields) - .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role == UserTenantRole.OWNER))) - .where(cls.model.status == StatusEnum.VALID.value).dicts()) + UserTenant.role, + ] + return list( + cls.model.select(*fields) + .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role == UserTenantRole.OWNER))) + .where(cls.model.status == StatusEnum.VALID.value) + .dicts() + ) @classmethod @DB.connection_context() def get_joined_tenants_by_user_id(cls, user_id): - fields = [ - cls.model.id.alias("tenant_id"), - cls.model.name, - cls.model.llm_id, - cls.model.embd_id, - cls.model.asr_id, - cls.model.img2txt_id, - UserTenant.role] - return list(cls.model.select(*fields) - .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role == UserTenantRole.NORMAL))) - .where(cls.model.status == StatusEnum.VALID.value).dicts()) + fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role] + return list( + cls.model.select(*fields) + .join( + UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role == UserTenantRole.NORMAL)) + ) + .where(cls.model.status == StatusEnum.VALID.value) + .dicts() + ) @classmethod @DB.connection_context() def decrease(cls, user_id, num): - num = cls.model.update(credit=cls.model.credit - num).where( - cls.model.id == user_id).execute() + num = cls.model.update(credit=cls.model.credit - num).where(cls.model.id == user_id).execute() if num == 0: raise LookupError("Tenant not found which is supposed to be there") @@ -225,12 +218,19 @@ class TenantService(CommonService): @DB.connection_context() def user_gateway(cls, tenant_id): hash_obj = hashlib.sha256(tenant_id.encode("utf-8")) - return int(hash_obj.hexdigest(), 16)%len(settings.MINIO) + return int(hash_obj.hexdigest(), 16) % len(settings.MINIO) @classmethod @DB.connection_context() def get_null_tenant_model_id_rows(cls): - objs = cls.model.select().orwhere(cls.model.tenant_llm_id.is_null(), cls.model.tenant_embd_id.is_null(), cls.model.tenant_asr_id.is_null(), cls.model.tenant_tts_id.is_null(), cls.model.tenant_rerank_id.is_null(), cls.model.tenant_img2txt_id.is_null()) + objs = cls.model.select().orwhere( + cls.model.tenant_llm_id.is_null(), + cls.model.tenant_embd_id.is_null(), + cls.model.tenant_asr_id.is_null(), + cls.model.tenant_tts_id.is_null(), + cls.model.tenant_rerank_id.is_null(), + cls.model.tenant_img2txt_id.is_null(), + ) return list(objs) @@ -243,6 +243,7 @@ class UserTenantService(CommonService): Attributes: model: The UserTenant model class for database operations. """ + model = UserTenant @classmethod @@ -278,36 +279,30 @@ class UserTenantService(CommonService): User.is_anonymous, User.status, User.update_date, - User.is_superuser] - return list(cls.model.select(*fields) - .join(User, on=((cls.model.user_id == User.id) & (cls.model.status == StatusEnum.VALID.value) & (cls.model.role != UserTenantRole.OWNER))) - .where(cls.model.tenant_id == tenant_id) - .dicts()) + User.is_superuser, + ] + return list( + cls.model.select(*fields) + .join(User, on=((cls.model.user_id == User.id) & (cls.model.status == StatusEnum.VALID.value) & (cls.model.role != UserTenantRole.OWNER))) + .where(cls.model.tenant_id == tenant_id) + .dicts() + ) @classmethod @DB.connection_context() def get_tenants_by_user_id(cls, user_id): - fields = [ - cls.model.tenant_id, - cls.model.role, - User.nickname, - User.email, - User.avatar, - User.update_date - ] - return list(cls.model.select(*fields) - .join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value))) - .where(cls.model.status == StatusEnum.VALID.value).dicts()) + fields = [cls.model.tenant_id, cls.model.role, User.nickname, User.email, User.avatar, User.update_date] + return list( + cls.model.select(*fields) + .join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value))) + .where(cls.model.status == StatusEnum.VALID.value) + .dicts() + ) @classmethod @DB.connection_context() def get_user_tenant_relation_by_user_id(cls, user_id): - fields = [ - cls.model.id, - cls.model.user_id, - cls.model.tenant_id, - cls.model.role - ] + fields = [cls.model.id, cls.model.user_id, cls.model.tenant_id, cls.model.role] return list(cls.model.select(*fields).where(cls.model.user_id == user_id).dicts().dicts()) @classmethod @@ -320,10 +315,7 @@ class UserTenantService(CommonService): @DB.connection_context() def filter_by_tenant_and_user_id(cls, tenant_id, user_id): try: - user_tenant = cls.model.select().where( - (cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value) & - (cls.model.user_id == user_id) - ).first() + user_tenant = cls.model.select().where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value) & (cls.model.user_id == user_id)).first() return user_tenant except peewee.DoesNotExist: return None diff --git a/api/ragflow_server.py b/api/ragflow_server.py index 777c995fa7..5f580075d6 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -17,6 +17,7 @@ print("Start RAGFlow server...") import time + start_ts = time.time() import os @@ -48,7 +49,8 @@ from rag.utils.redis_conn import RedisDistributedLock stop_event = threading.Event() -RAGFLOW_DEBUGPY_LISTEN = int(os.environ.get('RAGFLOW_DEBUGPY_LISTEN', "0")) +RAGFLOW_DEBUGPY_LISTEN = int(os.environ.get("RAGFLOW_DEBUGPY_LISTEN", "0")) + def update_progress(): lock_value = str(uuid.uuid4()) @@ -68,6 +70,7 @@ def update_progress(): logging.exception("update_progress exception") stop_event.wait(6) + def signal_handler(sig, frame): logging.info("Received interrupt signal, shutting down...") shutdown_all_mcp_sessions() @@ -75,7 +78,8 @@ def signal_handler(sig, frame): stop_event.wait(1) sys.exit(0) -if __name__ == '__main__': + +if __name__ == "__main__": faulthandler.enable() init_root_logger("ragflow_server") logging.info(r""" @@ -86,12 +90,8 @@ if __name__ == '__main__': /_/ |_|/_/ |_|\____//_/ /_/ \____/ |__/|__/ """) - logging.info( - f'RAGFlow version: {get_ragflow_version()}' - ) - logging.info( - f'project base: {get_project_base_directory()}' - ) + logging.info(f"RAGFlow version: {get_ragflow_version()}") + logging.info(f"project base: {get_project_base_directory()}") show_configs() settings.init_settings() settings.print_rag_settings() @@ -99,6 +99,7 @@ if __name__ == '__main__': if RAGFLOW_DEBUGPY_LISTEN > 0: logging.info(f"debugpy listen on {RAGFLOW_DEBUGPY_LISTEN}") import debugpy + debugpy.listen(("0.0.0.0", RAGFLOW_DEBUGPY_LISTEN)) # init db @@ -108,15 +109,9 @@ if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() - parser.add_argument( - "--version", default=False, help="RAGFlow version", action="store_true" - ) - parser.add_argument( - "--debug", default=False, help="debug mode", action="store_true" - ) - parser.add_argument( - "--init-superuser", default=False, help="init superuser", action="store_true" - ) + parser.add_argument("--version", default=False, help="RAGFlow version", action="store_true") + parser.add_argument("--debug", default=False, help="debug mode", action="store_true") + parser.add_argument("--init-superuser", default=False, help="init superuser", action="store_true") args = parser.parse_args() if args.version: print(get_ragflow_version()) @@ -144,6 +139,7 @@ if __name__ == '__main__': def start_chat_channels(): try: from api.channels.bootstrap import start_channel_server + logging.info("Starting chat channel server thread") t = threading.Thread( target=start_channel_server, diff --git a/api/utils/__init__.py b/api/utils/__init__.py index e7d5615028..fbe96e5475 100644 --- a/api/utils/__init__.py +++ b/api/utils/__init__.py @@ -21,7 +21,6 @@ def from_dict_hook(in_dict: dict): if in_dict["module"] is None: return in_dict["data"] else: - return getattr(importlib.import_module( - in_dict["module"]), in_dict["type"])(**in_dict["data"]) + return getattr(importlib.import_module(in_dict["module"]), in_dict["type"])(**in_dict["data"]) else: return in_dict diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index c95d689c94..973384b8d1 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -51,6 +51,7 @@ from common.misc_utils import thread_pool_exec requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder) + def _safe_jsonify(payload: dict): if has_app_context(): return jsonify(payload) @@ -87,9 +88,11 @@ async def _coerce_request_data() -> dict: request._cached_payload = payload return payload + async def get_request_json(): return await _coerce_request_data() + def serialize_for_json(obj): """ Recursively serialize objects to make them JSON serializable. @@ -211,6 +214,7 @@ def not_allowed_parameters(*params): if inspect.iscoroutinefunction(func): return await func(*args, **kwargs) return func(*args, **kwargs) + return wrapper return decorator @@ -238,10 +242,12 @@ def add_tenant_id_to_kwargs(func): @wraps(func) async def wrapper(**kwargs): from api.apps import current_user + kwargs["tenant_id"] = current_user.id if inspect.iscoroutinefunction(func): return await func(**kwargs) return func(**kwargs) + return wrapper @@ -677,24 +683,15 @@ async def is_strong_enough(chat_model, embedding_model): async def _is_strong_enough(): nonlocal chat_model, embedding_model if embedding_model: - await asyncio.wait_for( - thread_pool_exec(embedding_model.encode, ["Are you strong enough!?"]), - timeout=10 - ) + await asyncio.wait_for(thread_pool_exec(embedding_model.encode, ["Are you strong enough!?"]), timeout=10) if chat_model: - res = await asyncio.wait_for( - chat_model.async_chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}]), - timeout=30 - ) + res = await asyncio.wait_for(chat_model.async_chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}]), timeout=30) if "**ERROR**" in res: raise Exception(res) # Pressure test for GraphRAG task - tasks = [ - asyncio.create_task(_is_strong_enough()) - for _ in range(count) - ] + tasks = [asyncio.create_task(_is_strong_enough()) for _ in range(count)] try: await asyncio.gather(*tasks, return_exceptions=False) except Exception as e: diff --git a/api/utils/commands.py b/api/utils/commands.py index a3df7b507d..d72d41ad04 100644 --- a/api/utils/commands.py +++ b/api/utils/commands.py @@ -24,54 +24,50 @@ from werkzeug.security import generate_password_hash from api.db.services import UserService -@click.command('reset-password', help='Reset the account password.') -@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset') -@click.option('--new-password', prompt=True, help='the new password.') -@click.option('--password-confirm', prompt=True, help='the new password confirm.') +@click.command("reset-password", help="Reset the account password.") +@click.option("--email", prompt=True, help="The email address of the account whose password you need to reset") +@click.option("--new-password", prompt=True, help="the new password.") +@click.option("--password-confirm", prompt=True, help="the new password confirm.") def reset_password(email, new_password, password_confirm): if str(new_password).strip() != str(password_confirm).strip(): - click.echo(click.style('sorry. The two passwords do not match.', fg='red')) + click.echo(click.style("sorry. The two passwords do not match.", fg="red")) return user = UserService.query(email=email) if not user: - click.echo(click.style('sorry. The Email is not registered!.', fg='red')) + click.echo(click.style("sorry. The Email is not registered!.", fg="red")) return - encode_password = base64.b64encode(new_password.encode('utf-8')).decode('utf-8') + encode_password = base64.b64encode(new_password.encode("utf-8")).decode("utf-8") password_hash = generate_password_hash(encode_password) - user_dict = { - 'password': password_hash - } - UserService.update_user(user[0].id,user_dict) - click.echo(click.style('Congratulations! Password has been reset.', fg='green')) + user_dict = {"password": password_hash} + UserService.update_user(user[0].id, user_dict) + click.echo(click.style("Congratulations! Password has been reset.", fg="green")) -@click.command('reset-email', help='Reset the account email.') -@click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset') -@click.option('--new-email', prompt=True, help='the new email.') -@click.option('--email-confirm', prompt=True, help='the new email confirm.') +@click.command("reset-email", help="Reset the account email.") +@click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset") +@click.option("--new-email", prompt=True, help="the new email.") +@click.option("--email-confirm", prompt=True, help="the new email confirm.") def reset_email(email, new_email, email_confirm): if str(new_email).strip() != str(email_confirm).strip(): - click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red')) + click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red")) return if str(new_email).strip() == str(email).strip(): - click.echo(click.style('Sorry, new email and old email are the same.', fg='red')) + click.echo(click.style("Sorry, new email and old email are the same.", fg="red")) return user = UserService.query(email=email) if not user: - click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red')) + click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red")) return if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,4}$", new_email): - click.echo(click.style('sorry. {} is not a valid email. '.format(new_email), fg='red')) + click.echo(click.style("sorry. {} is not a valid email. ".format(new_email), fg="red")) return new_user = UserService.query(email=new_email) if new_user: - click.echo(click.style('sorry. the account: [{}] is exist .'.format(new_email), fg='red')) + click.echo(click.style("sorry. the account: [{}] is exist .".format(new_email), fg="red")) return - user_dict = { - 'email': new_email - } - UserService.update_user(user[0].id,user_dict) - click.echo(click.style('Congratulations!, email has been reset.', fg='green')) + user_dict = {"email": new_email} + UserService.update_user(user[0].id, user_dict) + click.echo(click.style("Congratulations!, email has been reset.", fg="green")) def register_commands(app: Quart): diff --git a/api/utils/common.py b/api/utils/common.py index 4d38c40d21..1b07783eb6 100644 --- a/api/utils/common.py +++ b/api/utils/common.py @@ -17,13 +17,13 @@ import xxhash def string_to_bytes(string): - return string if isinstance( - string, bytes) else string.encode(encoding="utf-8") + return string if isinstance(string, bytes) else string.encode(encoding="utf-8") def bytes_to_string(byte): return byte.decode(encoding="utf-8") + # 128 bit = 32 character def hash128(data: str) -> str: return xxhash.xxh128(data).hexdigest() diff --git a/api/utils/configs.py b/api/utils/configs.py index c3abc13c37..6b5929a43f 100644 --- a/api/utils/configs.py +++ b/api/utils/configs.py @@ -19,21 +19,18 @@ import base64 import pickle from api.utils.common import bytes_to_string, string_to_bytes -safe_module = { - 'numpy', - 'rag_flow' -} +safe_module = {"numpy", "rag_flow"} class RestrictedUnpickler(pickle.Unpickler): def find_class(self, module, name): import importlib - if module.split('.')[0] in safe_module: + + if module.split(".")[0] in safe_module: _module = importlib.import_module(module) return getattr(_module, name) # Forbid everything else. - raise pickle.UnpicklingError("global '%s.%s' is forbidden" % - (module, name)) + raise pickle.UnpicklingError("global '%s.%s' is forbidden" % (module, name)) def restricted_loads(src): @@ -50,7 +47,5 @@ def serialize_b64(src, to_str=False): def deserialize_b64(src): - src = base64.b64decode( - string_to_bytes(src) if isinstance( - src, str) else src) + src = base64.b64decode(string_to_bytes(src) if isinstance(src, str) else src) return restricted_loads(src) diff --git a/api/utils/crypt.py b/api/utils/crypt.py index 0f3a28ae6e..53afcbec0d 100644 --- a/api/utils/crypt.py +++ b/api/utils/crypt.py @@ -30,25 +30,26 @@ def crypt(line): file_path = os.path.join(get_project_base_directory(), "conf", "public.pem") rsa_key = RSA.importKey(Path(file_path).read_text(), "Welcome") cipher = Cipher_pkcs1_v1_5.new(rsa_key) - password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8") + password_base64 = base64.b64encode(line.encode("utf-8")).decode("utf-8") encrypted_password = cipher.encrypt(password_base64.encode()) - return base64.b64encode(encrypted_password).decode('utf-8') + return base64.b64encode(encrypted_password).decode("utf-8") def decrypt(line): file_path = os.path.join(get_project_base_directory(), "conf", "private.pem") rsa_key = RSA.importKey(Path(file_path).read_text(), "Welcome") cipher = Cipher_pkcs1_v1_5.new(rsa_key) - return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8') + return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode("utf-8") def decrypt2(crypt_text): from base64 import b64decode, b16decode from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5 from Crypto.PublicKey import RSA + decode_data = b64decode(crypt_text) if len(decode_data) == 127: - hex_fixed = '00' + decode_data.hex() + hex_fixed = "00" + decode_data.hex() decode_data = b16decode(hex_fixed.upper()) file_path = os.path.join(get_project_base_directory(), "conf", "private.pem") diff --git a/api/utils/health_utils.py b/api/utils/health_utils.py index 34f098b8c9..8e8a574f5f 100644 --- a/api/utils/health_utils.py +++ b/api/utils/health_utils.py @@ -70,14 +70,11 @@ def check_storage() -> tuple[bool, dict]: def get_es_cluster_stats() -> dict: - doc_engine = os.getenv('DOC_ENGINE', 'elasticsearch') - if doc_engine != 'elasticsearch': + doc_engine = os.getenv("DOC_ENGINE", "elasticsearch") + if doc_engine != "elasticsearch": raise Exception("Elasticsearch is not in use.") try: - return { - "status": "alive", - "message": ESConnection().get_cluster_stats() - } + return {"status": "alive", "message": ESConnection().get_cluster_stats()} except Exception as e: return { "status": "timeout", @@ -86,14 +83,11 @@ def get_es_cluster_stats() -> dict: def get_infinity_status(): - doc_engine = os.getenv('DOC_ENGINE', 'elasticsearch') - if doc_engine != 'infinity': + doc_engine = os.getenv("DOC_ENGINE", "elasticsearch") + if doc_engine != "infinity": raise Exception("Infinity is not in use.") try: - return { - "status": "alive", - "message": InfinityConnection().health() - } + return {"status": "alive", "message": InfinityConnection().health()} except Exception as e: return { "status": "timeout", @@ -104,28 +98,22 @@ def get_infinity_status(): def get_oceanbase_status(): """ Get OceanBase health status and performance metrics. - + Returns: dict: OceanBase status with health information and performance metrics """ - doc_engine = os.getenv('DOC_ENGINE', 'elasticsearch') - if doc_engine != 'oceanbase': + doc_engine = os.getenv("DOC_ENGINE", "elasticsearch") + if doc_engine != "oceanbase": raise Exception("OceanBase is not in use.") try: ob_conn = OBConnection() health_info = ob_conn.health() performance_metrics = ob_conn.get_performance_metrics() - + # Combine health and performance metrics status = "alive" if health_info.get("status") == "healthy" else "timeout" - - return { - "status": status, - "message": { - "health": health_info, - "performance": performance_metrics - } - } + + return {"status": status, "message": {"health": health_info, "performance": performance_metrics}} except Exception as e: return { "status": "timeout", @@ -136,7 +124,7 @@ def get_oceanbase_status(): def check_oceanbase_health() -> dict: """ Check OceanBase health status with comprehensive metrics. - + This function provides detailed health information including: - Connection status - Query latency @@ -144,28 +132,22 @@ def check_oceanbase_health() -> dict: - Query throughput (QPS) - Slow query statistics - Connection pool statistics - + Returns: dict: Health status with detailed metrics """ - doc_engine = os.getenv('DOC_ENGINE', 'elasticsearch') - if doc_engine != 'oceanbase': - return { - "status": "not_configured", - "details": { - "connection": "not_configured", - "message": "OceanBase is not configured as the document engine" - } - } - + doc_engine = os.getenv("DOC_ENGINE", "elasticsearch") + if doc_engine != "oceanbase": + return {"status": "not_configured", "details": {"connection": "not_configured", "message": "OceanBase is not configured as the document engine"}} + try: ob_conn = OBConnection() health_info = ob_conn.health() performance_metrics = ob_conn.get_performance_metrics() - + # Determine overall health status connection_status = performance_metrics.get("connection", "unknown") - + # If connection is disconnected, return unhealthy if connection_status == "disconnected" or health_info.get("status") != "healthy": return { @@ -181,16 +163,15 @@ def check_oceanbase_health() -> dict: "max_connections": performance_metrics.get("max_connections", 0), "uri": health_info.get("uri", "unknown"), "version": health_info.get("version_comment", "unknown"), - "error": health_info.get("error", performance_metrics.get("error")) - } + "error": health_info.get("error", performance_metrics.get("error")), + }, } - + # Check if healthy (connected and low latency) is_healthy = ( - connection_status == "connected" and - performance_metrics.get("latency_ms", float('inf')) < 1000 # Latency under 1 second + connection_status == "connected" and performance_metrics.get("latency_ms", float("inf")) < 1000 # Latency under 1 second ) - + return { "status": "healthy" if is_healthy else "degraded", "details": { @@ -203,29 +184,20 @@ def check_oceanbase_health() -> dict: "active_connections": performance_metrics.get("active_connections", 0), "max_connections": performance_metrics.get("max_connections", 0), "uri": health_info.get("uri", "unknown"), - "version": health_info.get("version_comment", "unknown") - } + "version": health_info.get("version_comment", "unknown"), + }, } except Exception as e: - return { - "status": "unhealthy", - "details": { - "connection": "disconnected", - "error": str(e) - } - } + return {"status": "unhealthy", "details": {"connection": "disconnected", "error": str(e)}} def get_mysql_status(): try: cursor = DB.execute_sql("SHOW PROCESSLIST;") res_rows = cursor.fetchall() - headers = ['id', 'user', 'host', 'db', 'command', 'time', 'state', 'info'] + headers = ["id", "user", "host", "db", "command", "time", "state", "info"] cursor.close() - return { - "status": "alive", - "message": [dict(zip(headers, r)) for r in res_rows] - } + return {"status": "alive", "message": [dict(zip(headers, r)) for r in res_rows]} except Exception as e: return { "status": "timeout", @@ -276,10 +248,7 @@ def check_minio_alive(): def get_redis_info(): try: - return { - "status": "alive", - "message": REDIS_CONN.info() - } + return {"status": "alive", "message": REDIS_CONN.info()} except Exception as e: return { "status": "timeout", @@ -290,9 +259,9 @@ def get_redis_info(): def check_ragflow_server_alive(): start_time = timer() try: - url = f'http://{settings.HOST_IP}:{settings.HOST_PORT}/api/v1/system/ping' - if '0.0.0.0' in url: - url = url.replace('0.0.0.0', '127.0.0.1') + url = f"http://{settings.HOST_IP}:{settings.HOST_PORT}/api/v1/system/ping" + if "0.0.0.0" in url: + url = url.replace("0.0.0.0", "127.0.0.1") response = requests.get(url, timeout=10) if response.status_code == 200: return {"status": "alive", "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."} @@ -320,10 +289,7 @@ def check_task_executor_alive(): else: return {"status": "timeout", "message": "Not found any task executor."} except Exception as e: - return { - "status": "timeout", - "message": f"error: {str(e)}" - } + return {"status": "timeout", "message": f"error: {str(e)}"} def run_health_checks() -> tuple[dict, bool]: @@ -358,7 +324,6 @@ def run_health_checks() -> tuple[dict, bool]: except Exception: result["storage"] = "nok" - all_ok = (result.get("db") == "ok") and (result.get("redis") == "ok") and (result.get("doc_engine") == "ok") and ( - result.get("storage") == "ok") + all_ok = (result.get("db") == "ok") and (result.get("redis") == "ok") and (result.get("doc_engine") == "ok") and (result.get("storage") == "ok") result["status"] = "ok" if all_ok else "nok" return result, all_ok diff --git a/api/utils/json_encode.py b/api/utils/json_encode.py index fa5ea973aa..f2d280dcbc 100644 --- a/api/utils/json_encode.py +++ b/api/utils/json_encode.py @@ -43,8 +43,7 @@ class BaseType: data[_k] = _dict(vv) else: data = obj - return {"type": obj.__class__.__name__, - "data": data, "module": module} + return {"type": obj.__class__.__name__, "data": data, "module": module} return _dict(self) @@ -56,9 +55,9 @@ class CustomJSONEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, datetime.datetime): - return obj.strftime('%Y-%m-%d %H:%M:%S') + return obj.strftime("%Y-%m-%d %H:%M:%S") elif isinstance(obj, datetime.date): - return obj.strftime('%Y-%m-%d') + return obj.strftime("%Y-%m-%d") elif isinstance(obj, datetime.timedelta): return str(obj) elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum): @@ -77,11 +76,7 @@ class CustomJSONEncoder(json.JSONEncoder): def json_dumps(src, byte=False, indent=None, with_type=False): - dest = json.dumps( - src, - indent=indent, - cls=CustomJSONEncoder, - with_type=with_type) + dest = json.dumps(src, indent=indent, cls=CustomJSONEncoder, with_type=with_type) if byte: dest = string_to_bytes(dest) return dest @@ -90,5 +85,4 @@ def json_dumps(src, byte=False, indent=None, with_type=False): def json_loads(src, object_hook=None, object_pairs_hook=None): if isinstance(src, bytes): src = bytes_to_string(src) - return json.loads(src, object_hook=object_hook, - object_pairs_hook=object_pairs_hook) + return json.loads(src, object_hook=object_hook, object_pairs_hook=object_pairs_hook) diff --git a/api/utils/memory_utils.py b/api/utils/memory_utils.py index bb78949518..2cdb8fe428 100644 --- a/api/utils/memory_utils.py +++ b/api/utils/memory_utils.py @@ -16,6 +16,7 @@ from typing import List from common.constants import MemoryType + def format_ret_data_from_memory(memory): return { "id": memory.id, @@ -37,7 +38,7 @@ def format_ret_data_from_memory(memory): "create_time": memory.create_time, "create_date": memory.create_date, "update_time": memory.update_time, - "update_date": memory.update_date + "update_date": memory.update_date, } diff --git a/api/utils/reference_metadata_utils.py b/api/utils/reference_metadata_utils.py index 58d5beffb0..e9444f994a 100644 --- a/api/utils/reference_metadata_utils.py +++ b/api/utils/reference_metadata_utils.py @@ -51,8 +51,7 @@ def resolve_reference_metadata_preferences( return include_metadata, None if not isinstance(fields, list): logger.warning( - "reference_metadata.fields is not a list; include_metadata=%s fields=%r type=%s resolved=%r. " - "enrich_chunks_with_document_metadata will skip enrichment.", + "reference_metadata.fields is not a list; include_metadata=%s fields=%r type=%s resolved=%r. enrich_chunks_with_document_metadata will skip enrichment.", include_metadata, fields, type(fields).__name__, @@ -96,12 +95,10 @@ def enrich_chunks_with_document_metadata( # Resolve service lazily so callers/tests that swap service modules at runtime # (e.g. via monkeypatch) don't get stuck with a stale class reference. from api.db.services.doc_metadata_service import DocMetadataService + metadata_getter = getattr(DocMetadataService, "get_metadata_for_documents", None) if not callable(metadata_getter): - logging.warning( - "DocMetadataService.get_metadata_for_documents is unavailable; " - "skipping metadata enrichment." - ) + logging.warning("DocMetadataService.get_metadata_for_documents is unavailable; skipping metadata enrichment.") return meta_by_doc: dict[str, dict] = {} diff --git a/api/validation.py b/api/validation.py index b552b3375a..65f65d991b 100644 --- a/api/validation.py +++ b/api/validation.py @@ -22,9 +22,7 @@ def python_version_validation(): # Check python version required_python_version = (3, 10) if sys.version_info < required_python_version: - logging.info( - f"Required Python: >= {required_python_version[0]}.{required_python_version[1]}. Current Python version: {sys.version_info[0]}.{sys.version_info[1]}." - ) + logging.info(f"Required Python: >= {required_python_version[0]}.{required_python_version[1]}. Current Python version: {sys.version_info[0]}.{sys.version_info[1]}.") sys.exit(1) else: logging.info(f"Python version: {sys.version_info[0]}.{sys.version_info[1]}") @@ -36,14 +34,16 @@ python_version_validation() # Download nltk data def download_nltk_data(): import nltk - nltk.download('wordnet', halt_on_error=False, quiet=True) - nltk.download('punkt_tab', halt_on_error=False, quiet=True) + + nltk.download("wordnet", halt_on_error=False, quiet=True) + nltk.download("punkt_tab", halt_on_error=False, quiet=True) try: from multiprocessing import Pool + pool = Pool(processes=1) thread = pool.apply_async(download_nltk_data) binary = thread.get(timeout=60) except Exception: - print('\x1b[6;37;41m WARNING \x1b[0m' + "Downloading NLTK data failure.", flush=True) + print("\x1b[6;37;41m WARNING \x1b[0m" + "Downloading NLTK data failure.", flush=True) diff --git a/common/__init__.py b/common/__init__.py index e156bc93dd..177b91dd05 100644 --- a/common/__init__.py +++ b/common/__init__.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# \ No newline at end of file +# diff --git a/common/asyncio_utils.py b/common/asyncio_utils.py index 12f5e0220a..26acf1b45a 100644 --- a/common/asyncio_utils.py +++ b/common/asyncio_utils.py @@ -26,9 +26,7 @@ class LoopLocalSemaphore: def __init__(self, value: int): self._value = int(value) - self._semaphores: "weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, asyncio.Semaphore]" = ( - weakref.WeakKeyDictionary() - ) + self._semaphores: "weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, asyncio.Semaphore]" = weakref.WeakKeyDictionary() def _get(self) -> asyncio.Semaphore: loop = asyncio.get_running_loop() diff --git a/common/config_utils.py b/common/config_utils.py index d367536de1..524732265d 100644 --- a/common/config_utils.py +++ b/common/config_utils.py @@ -54,7 +54,7 @@ def conf_realpath(conf_name): def read_config(conf_name=SERVICE_CONF): local_config = {} - local_path = conf_realpath(f'local.{conf_name}') + local_path = conf_realpath(f"local.{conf_name}") # load local config file if os.path.exists(local_path): @@ -128,10 +128,7 @@ def decrypt_database_password(password): raise ValueError("No private key") module_fun = encrypt_module.split("#") - pwdecrypt_fun = getattr( - importlib.import_module( - module_fun[0]), - module_fun[1]) + pwdecrypt_fun = getattr(importlib.import_module(module_fun[0]), module_fun[1]) return pwdecrypt_fun(private_key, password) @@ -152,4 +149,4 @@ def update_config(key, value, conf_name=SERVICE_CONF): with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")): config = load_yaml_conf(conf_path=conf_path) or {} config[key] = value - rewrite_yaml_conf(conf_path=conf_path, config=config) \ No newline at end of file + rewrite_yaml_conf(conf_path=conf_path, config=config) diff --git a/common/connection_utils.py b/common/connection_utils.py index 0218d99a28..48327a1879 100644 --- a/common/connection_utils.py +++ b/common/connection_utils.py @@ -27,8 +27,7 @@ TimeoutException = Union[Type[BaseException], BaseException] OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]] -def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, - on_timeout: Optional[OnTimeoutCallback] = None): +def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None): if isinstance(seconds, str): seconds = float(seconds) @@ -121,6 +120,7 @@ async def construct_response(code=RetCode.SUCCESS, message="success", data=None, def sync_construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None): import flask + result_dict = {"code": code, "message": message, "data": data} response_dict = {} for key, value in result_dict.items(): diff --git a/common/constants.py b/common/constants.py index 17feeb3a4b..0c66d39ab2 100644 --- a/common/constants.py +++ b/common/constants.py @@ -173,7 +173,15 @@ class PipelineTaskType(StrEnum): SKILL = "Skill" -VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP, PipelineTaskType.ARTIFACT, PipelineTaskType.SKILL} +VALID_PIPELINE_TASK_TYPES = { + PipelineTaskType.PARSE, + PipelineTaskType.DOWNLOAD, + PipelineTaskType.RAPTOR, + PipelineTaskType.GRAPH_RAG, + PipelineTaskType.MINDMAP, + PipelineTaskType.ARTIFACT, + PipelineTaskType.SKILL, +} class MCPServerType(StrEnum): diff --git a/common/crypto_utils.py b/common/crypto_utils.py index 5dcbd2937f..9138cebda1 100644 --- a/common/crypto_utils.py +++ b/common/crypto_utils.py @@ -24,14 +24,14 @@ from cryptography.hazmat.primitives import hashes class BaseCrypto: """Base class for cryptographic algorithms""" - + # Magic header to identify encrypted data - ENCRYPTED_MAGIC = b'RAGF' - + ENCRYPTED_MAGIC = b"RAGF" + def __init__(self, key, iv=None, block_size=16, key_length=32, iv_length=16): """ Initialize cryptographic algorithm - + Args: key: Encryption key iv: Initialization vector, automatically generated if None @@ -42,57 +42,57 @@ class BaseCrypto: self.block_size = block_size self.key_length = key_length self.iv_length = iv_length - + # Normalize key self.key = self._normalize_key(key) self.iv = iv - + def _normalize_key(self, key): """Normalize key length""" if isinstance(key, str): - key = key.encode('utf-8') - + key = key.encode("utf-8") + # Use PBKDF2 for key derivation to ensure correct key length kdf = PBKDF2HMAC( algorithm=hashes.SHA256(), length=self.key_length, salt=b"ragflow_crypto_salt", # Fixed salt to ensure consistent key derivation results iterations=100000, - backend=default_backend() + backend=default_backend(), ) - + return kdf.derive(key) - + def encrypt(self, data): """ Encrypt data (template method) - + Args: data: Data to encrypt (bytes) - + Returns: Encrypted data (bytes), format: magic_header + iv + encrypted_data """ # Generate random IV iv = os.urandom(self.iv_length) if not self.iv else self.iv - + # Use PKCS7 padding padder = padding.PKCS7(self.block_size * 8).padder() padded_data = padder.update(data) + padder.finalize() - + # Delegate to subclass for specific encryption ciphertext = self._encrypt(padded_data, iv) - + # Return Magic Header + IV + encrypted data return self.ENCRYPTED_MAGIC + iv + ciphertext - + def decrypt(self, encrypted_data): """ Decrypt data (template method) - + Args: encrypted_data: Encrypted data (bytes) - + Returns: Decrypted data (bytes) """ @@ -100,44 +100,44 @@ class BaseCrypto: if not encrypted_data.startswith(self.ENCRYPTED_MAGIC): # Not encrypted, return as-is return encrypted_data - + # Remove magic header - encrypted_data = encrypted_data[len(self.ENCRYPTED_MAGIC):] - + encrypted_data = encrypted_data[len(self.ENCRYPTED_MAGIC) :] + # Separate IV and encrypted data - iv = encrypted_data[:self.iv_length] - ciphertext = encrypted_data[self.iv_length:] - + iv = encrypted_data[: self.iv_length] + ciphertext = encrypted_data[self.iv_length :] + # Delegate to subclass for specific decryption padded_data = self._decrypt(ciphertext, iv) - + # Remove padding unpadder = padding.PKCS7(self.block_size * 8).unpadder() data = unpadder.update(padded_data) + unpadder.finalize() - + return data - + def _encrypt(self, padded_data, iv): """ Encrypt padded data with specific algorithm - + Args: padded_data: Padded data to encrypt iv: Initialization vector - + Returns: Encrypted data """ raise NotImplementedError("_encrypt method must be implemented by subclass") - + def _decrypt(self, ciphertext, iv): """ Decrypt ciphertext with specific algorithm - + Args: ciphertext: Ciphertext to decrypt iv: Initialization vector - + Returns: Decrypted padded data """ @@ -146,11 +146,11 @@ class BaseCrypto: class AESCrypto(BaseCrypto): """Base class for AES cryptographic algorithm""" - + def __init__(self, key, iv=None, key_length=32): """ Initialize AES cryptographic algorithm - + Args: key: Encryption key iv: Initialization vector, automatically generated if None @@ -161,37 +161,29 @@ class AESCrypto(BaseCrypto): def _encrypt(self, padded_data, iv): """AES encryption implementation""" # Create encryptor - cipher = Cipher( - algorithms.AES(self.key), - modes.CBC(iv), - backend=default_backend() - ) + cipher = Cipher(algorithms.AES(self.key), modes.CBC(iv), backend=default_backend()) encryptor = cipher.encryptor() - + # Encrypt data return encryptor.update(padded_data) + encryptor.finalize() - + def _decrypt(self, ciphertext, iv): """AES decryption implementation""" # Create decryptor - cipher = Cipher( - algorithms.AES(self.key), - modes.CBC(iv), - backend=default_backend() - ) + cipher = Cipher(algorithms.AES(self.key), modes.CBC(iv), backend=default_backend()) decryptor = cipher.decryptor() - + # Decrypt data return decryptor.update(ciphertext) + decryptor.finalize() class AES128CBC(AESCrypto): """AES-128-CBC cryptographic algorithm""" - + def __init__(self, key, iv=None): """ Initialize AES-128-CBC cryptographic algorithm - + Args: key: Encryption key iv: Initialization vector, automatically generated if None @@ -201,11 +193,11 @@ class AES128CBC(AESCrypto): class AES256CBC(AESCrypto): """AES-256-CBC cryptographic algorithm""" - + def __init__(self, key, iv=None): """ Initialize AES-256-CBC cryptographic algorithm - + Args: key: Encryption key iv: Initialization vector, automatically generated if None @@ -215,11 +207,11 @@ class AES256CBC(AESCrypto): class SM4CBC(BaseCrypto): """SM4-CBC cryptographic algorithm using cryptography library for better performance""" - + def __init__(self, key, iv=None): """ Initialize SM4-CBC cryptographic algorithm - + Args: key: Encryption key iv: Initialization vector, automatically generated if None @@ -229,44 +221,32 @@ class SM4CBC(BaseCrypto): def _encrypt(self, padded_data, iv): """SM4 encryption implementation using cryptography library""" # Create encryptor - cipher = Cipher( - algorithms.SM4(self.key), - modes.CBC(iv), - backend=default_backend() - ) + cipher = Cipher(algorithms.SM4(self.key), modes.CBC(iv), backend=default_backend()) encryptor = cipher.encryptor() - + # Encrypt data return encryptor.update(padded_data) + encryptor.finalize() - + def _decrypt(self, ciphertext, iv): """SM4 decryption implementation using cryptography library""" # Create decryptor - cipher = Cipher( - algorithms.SM4(self.key), - modes.CBC(iv), - backend=default_backend() - ) + cipher = Cipher(algorithms.SM4(self.key), modes.CBC(iv), backend=default_backend()) decryptor = cipher.decryptor() - + # Decrypt data return decryptor.update(ciphertext) + decryptor.finalize() class CryptoUtil: """Cryptographic utility class, using factory pattern to create cryptographic algorithm instances""" - + # Supported cryptographic algorithms mapping - SUPPORTED_ALGORITHMS = { - "aes-128-cbc": AES128CBC, - "aes-256-cbc": AES256CBC, - "sm4-cbc": SM4CBC - } - + SUPPORTED_ALGORITHMS = {"aes-128-cbc": AES128CBC, "aes-256-cbc": AES256CBC, "sm4-cbc": SM4CBC} + def __init__(self, algorithm="aes-256-cbc", key=None, iv=None): """ Initialize cryptographic utility - + Args: algorithm: Cryptographic algorithm, default is aes-256-cbc key: Encryption key, uses RAGFLOW_CRYPTO_KEY environment variable if None @@ -274,21 +254,21 @@ class CryptoUtil: """ if algorithm not in self.SUPPORTED_ALGORITHMS: raise ValueError(f"Unsupported algorithm: {algorithm}") - + if not key: raise ValueError("Encryption key not provided and RAGFLOW_CRYPTO_KEY environment variable not set") - + # Create cryptographic algorithm instance self.algorithm_name = algorithm self.crypto = self.SUPPORTED_ALGORITHMS[algorithm](key=key, iv=iv) - + def encrypt(self, data): """ Encrypt data - + Args: data: Data to encrypt (bytes) - + Returns: Encrypted data (bytes) """ @@ -298,14 +278,14 @@ class CryptoUtil: # end_time = time.time() # logging.info(f"Encryption completed, data length: {len(data)} bytes, time: {(end_time - start_time)*1000:.2f} ms") return encrypted - + def decrypt(self, encrypted_data): """ Decrypt data - + Args: encrypted_data: Encrypted data (bytes) - + Returns: Decrypted data (bytes) """ @@ -322,23 +302,23 @@ if __name__ == "__main__": # Test AES encryption crypto = CryptoUtil(algorithm="aes-256-cbc", key="test_key_123456") test_data = b"Hello, RAGFlow! This is a test for encryption." - + encrypted = crypto.encrypt(test_data) decrypted = crypto.decrypt(encrypted) - + print("AES Test:") print(f"Original: {test_data}") print(f"Encrypted: {encrypted}") print(f"Decrypted: {decrypted}") print(f"Success: {test_data == decrypted}") print() - + # Test SM4 encryption try: crypto_sm4 = CryptoUtil(algorithm="sm4-cbc", key="test_key_123456") encrypted_sm4 = crypto_sm4.encrypt(test_data) decrypted_sm4 = crypto_sm4.decrypt(encrypted_sm4) - + print("SM4 Test:") print(f"Original: {test_data}") print(f"Encrypted: {encrypted_sm4}") @@ -347,23 +327,24 @@ if __name__ == "__main__": except Exception as e: print(f"SM4 Test Failed: {e}") import traceback + traceback.print_exc() - + # Test with specific algorithm classes directly print("\nDirect Algorithm Class Test:") - + # Test AES-128-CBC aes128 = AES128CBC(key="test_key_123456") encrypted_aes128 = aes128.encrypt(test_data) decrypted_aes128 = aes128.decrypt(encrypted_aes128) print(f"AES-128-CBC test: {'passed' if decrypted_aes128 == test_data else 'failed'}") - + # Test AES-256-CBC aes256 = AES256CBC(key="test_key_123456") encrypted_aes256 = aes256.encrypt(test_data) decrypted_aes256 = aes256.decrypt(encrypted_aes256) print(f"AES-256-CBC test: {'passed' if decrypted_aes256 == test_data else 'failed'}") - + # Test SM4-CBC try: sm4 = SM4CBC(key="test_key_123456") diff --git a/common/data_source/__init__.py b/common/data_source/__init__.py index 39bfc0a11b..158e5450b5 100644 --- a/common/data_source/__init__.py +++ b/common/data_source/__init__.py @@ -1,4 +1,3 @@ - """ Thanks to https://github.com/onyx-dot-app/onyx @@ -52,13 +51,7 @@ from .webdav_connector import WebDAVConnector from .rest_api_connector import RestAPIConnector from .config import BlobType, DocumentSource from .models import Document, TextSection, ImageSection, BasicExpertInfo -from .exceptions import ( - ConnectorMissingCredentialError, - ConnectorValidationError, - CredentialExpiredError, - InsufficientPermissionsError, - UnexpectedValidationError -) +from .exceptions import ConnectorMissingCredentialError, ConnectorValidationError, CredentialExpiredError, InsufficientPermissionsError, UnexpectedValidationError __all__ = [ "BlobStorageConnector", diff --git a/common/data_source/airtable_connector.py b/common/data_source/airtable_connector.py index f1ab300403..2ab471191c 100644 --- a/common/data_source/airtable_connector.py +++ b/common/data_source/airtable_connector.py @@ -18,11 +18,10 @@ from common.data_source.models import ( ) from common.data_source.utils import extract_size_bytes, get_file_ext + class AirtableClientNotSetUpError(PermissionError): def __init__(self) -> None: - super().__init__( - "Airtable client is not set up. Did you forget to call load_credentials()?" - ) + super().__init__("Airtable client is not set up. Did you forget to call load_credentials()?") class AirtableConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): @@ -52,10 +51,7 @@ class AirtableConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) table = self.airtable_client.table(self.base_id, self.table_name_or_id) records = table.all() - logging.info( - f"Starting Airtable attachment scan for table {self.table_name_or_id}, " - f"{len(records)} records found." - ) + logging.info(f"Starting Airtable attachment scan for table {self.table_name_or_id}, {len(records)} records found.") for record in records: record_id = record.get("id") @@ -119,20 +115,11 @@ class AirtableConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) resp.raise_for_status() content = resp.content except Exception: - logging.exception( - f"Failed to download attachment {filename} " - f"(record={record_id})" - ) + logging.exception(f"Failed to download attachment {filename} (record={record_id})") continue size_bytes = extract_size_bytes(attachment) - if ( - self.size_threshold is not None - and isinstance(size_bytes, int) - and size_bytes > self.size_threshold - ): - logging.warning( - f"{filename} exceeds size threshold of {self.size_threshold}. Skipping." - ) + if self.size_threshold is not None and isinstance(size_bytes, int) and size_bytes > self.size_threshold: + logging.warning(f"{filename} exceeds size threshold of {self.size_threshold}. Skipping.") continue batch.append( Document( @@ -142,7 +129,7 @@ class AirtableConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) semantic_identifier=filename, extension=get_file_ext(filename), size_bytes=size_bytes if size_bytes else 0, - doc_updated_at=datetime.strptime(created_time, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc) + doc_updated_at=datetime.strptime(created_time, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc), ) ) @@ -190,11 +177,12 @@ class AirtableConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) if filtered: yield filtered + if __name__ == "__main__": import os logging.basicConfig(level=logging.DEBUG) - connector = AirtableConnector("xxx","xxx") + connector = AirtableConnector("xxx", "xxx") connector.load_credentials({"airtable_access_token": os.environ.get("AIRTABLE_ACCESS_TOKEN")}) connector.validate_connector_settings() document_batches = connector.load_from_state() diff --git a/common/data_source/asana_connector.py b/common/data_source/asana_connector.py index e3aee9c4f0..2592888807 100644 --- a/common/data_source/asana_connector.py +++ b/common/data_source/asana_connector.py @@ -11,7 +11,6 @@ from common.data_source.models import Document, GenerateDocumentsOutput, Generat from common.data_source.utils import extract_size_bytes, get_file_ext - # https://github.com/Asana/python-asana/tree/master?tab=readme-ov-file#documentation-for-api-endpoints class AsanaTask: def __init__( @@ -37,9 +36,7 @@ class AsanaTask: class AsanaAPI: - def __init__( - self, api_token: str, workspace_gid: str, team_gid: str | None - ) -> None: + def __init__(self, api_token: str, workspace_gid: str, team_gid: str | None) -> None: self._user = None self.workspace_gid = workspace_gid self.team_gid = team_gid @@ -58,36 +55,26 @@ class AsanaAPI: self.configuration.access_token = api_token self.task_count = 0 - def get_tasks( - self, project_gids: list[str] | None, start_date: str - ) -> Iterator[AsanaTask]: + def get_tasks(self, project_gids: list[str] | None, start_date: str) -> Iterator[AsanaTask]: """Get all tasks from the projects with the given gids that were modified since the given date. If project_gids is None, get all tasks from all projects in the workspace.""" projects_list = self._get_project_gids_to_process(project_gids) start_seconds = int(time.mktime(datetime.now().timetuple())) for project_gid in projects_list: - for task in self._get_tasks_for_project( - project_gid, start_date, start_seconds - ): + for task in self._get_tasks_for_project(project_gid, start_date, start_seconds): yield task logging.info(f"Completed fetching {self.task_count} tasks from Asana") if self.api_error_count > 0: - logging.warning( - f"Encountered {self.api_error_count} API errors during task fetching" - ) + logging.warning(f"Encountered {self.api_error_count} API errors during task fetching") - def get_task_ids( - self, project_gids: list[str] | None, start_date: str - ) -> Iterator[str]: + def get_task_ids(self, project_gids: list[str] | None, start_date: str) -> Iterator[str]: """Get task gids without hydrating comments, users, or task text.""" projects_list = self._get_project_gids_to_process(project_gids) for project_gid in projects_list: for task_id in self._get_task_ids_for_project(project_gid, start_date): yield task_id - def _get_project_gids_to_process( - self, project_gids: list[str] | None - ) -> list[str]: + def _get_project_gids_to_process(self, project_gids: list[str] | None) -> list[str]: logging.info("Starting to fetch Asana projects") projects = self.project_api.get_projects( opts={ @@ -102,9 +89,7 @@ class AsanaAPI: if project_gids is None or project_gid in project_gids: projects_list.append(project_gid) else: - logging.debug( - f"Skipping project: {project_gid} - not in accepted project_gids" - ) + logging.debug(f"Skipping project: {project_gid} - not in accepted project_gids") project_count += 1 if project_count % 100 == 0: logging.info(f"Processed {project_count} projects") @@ -121,25 +106,17 @@ class AsanaAPI: logging.info(f"Skipping archived project: {project_name} ({project_gid})") return if not team_gid: - logging.info( - f"Skipping project without a team: {project_name} ({project_gid})" - ) + logging.info(f"Skipping project without a team: {project_name} ({project_gid})") return if project.get("privacy_setting") == "private": if self.team_gid and team_gid != self.team_gid: - logging.info( - f"Skipping private project not in configured team: {project_name} ({project_gid})" - ) + logging.info(f"Skipping private project not in configured team: {project_name} ({project_gid})") return - logging.info( - f"Processing private project in configured team: {project_name} ({project_gid})" - ) + logging.info(f"Processing private project in configured team: {project_name} ({project_gid})") return project - def _get_task_ids_for_project( - self, project_gid: str, start_date: str - ) -> Iterator[str]: + def _get_task_ids_for_project(self, project_gid: str, start_date: str) -> Iterator[str]: project = self._get_project_to_process(project_gid) if project is None: return @@ -156,18 +133,14 @@ class AsanaAPI: if task_id: yield task_id - def _get_tasks_for_project( - self, project_gid: str, start_date: str, start_seconds: int - ) -> Iterator[AsanaTask]: + def _get_tasks_for_project(self, project_gid: str, start_date: str, start_seconds: int) -> Iterator[AsanaTask]: project = self._get_project_to_process(project_gid) if project is None: return project_name = project.get("name", project_gid) simple_start_date = start_date.split(".")[0].split("+")[0] - logging.info( - f"Fetching tasks modified since {simple_start_date} for project: {project_name} ({project_gid})" - ) + logging.info(f"Fetching tasks modified since {simple_start_date} for project: {project_name} ({project_gid})") opts = { "opt_fields": "name,memberships,memberships.project,completed_at,completed_by,created_at," @@ -183,10 +156,7 @@ class AsanaAPI: end_seconds = time.mktime(datetime.now().timetuple()) runtime_seconds = end_seconds - start_seconds if runtime_seconds > 0: - logging.info( - f"Processed {self.task_count} tasks in {runtime_seconds:.0f} seconds " - f"({self.task_count / runtime_seconds:.2f} tasks/second)" - ) + logging.info(f"Processed {self.task_count} tasks in {runtime_seconds:.0f} seconds ({self.task_count / runtime_seconds:.2f} tasks/second)") logging.debug(f"Processing Asana task: {data['name']}") @@ -249,17 +219,13 @@ class AsanaAPI: for story in stories: story_count += 1 if story["resource_subtype"] == "comment_added": - comment = self.stories_api.get_story( - story["gid"], opts={"opt_fields": "text,created_by,created_at"} - ) + comment = self.stories_api.get_story(story["gid"], opts={"opt_fields": "text,created_by,created_at"}) commenter = self.get_user(comment["created_by"]["gid"])["name"] text += f"Comment by {commenter}: {comment['text']}\n\n" comment_count += 1 story_duration = time.time() - story_start - logging.debug( - f"Processed {story_count} stories (including {comment_count} comments) in {story_duration:.2f} seconds" - ) + logging.debug(f"Processed {story_count} stories (including {comment_count} comments) in {story_duration:.2f} seconds") return text @@ -271,30 +237,20 @@ class AsanaAPI: try: # Step 1: list attachment compact records - for att in self.attachments_api.get_attachments_for_object( - parent=task_gid, - opts={} - ): + for att in self.attachments_api.get_attachments_for_object(parent=task_gid, opts={}): gid = att.get("gid") if not gid: continue try: # Step 2: expand to full attachment - full = self.attachments_api.get_attachment( - attachment_gid=gid, - opts={ - "opt_fields": "gid,name,download_url,size,created_at" - } - ) + full = self.attachments_api.get_attachment(attachment_gid=gid, opts={"opt_fields": "gid,name,download_url,size,created_at"}) if full.get("download_url"): attachments.append(full) except Exception: - logging.exception( - f"Failed to fetch attachment detail {gid} for task {task_gid}" - ) + logging.exception(f"Failed to fetch attachment detail {gid} for task {task_gid}") self.api_error_count += 1 except Exception: @@ -310,42 +266,26 @@ class AsanaAPI: team_id: str | None, ): - ws_users = self.users_api.get_users( - opts={ - "workspace": workspace_id, - "opt_fields": "gid,name,email" - } - ) + ws_users = self.users_api.get_users(opts={"workspace": workspace_id, "opt_fields": "gid,name,email"}) - workspace_users = { - u["gid"]: u.get("email") - for u in ws_users - if u.get("email") - } + workspace_users = {u["gid"]: u.get("email") for u in ws_users if u.get("email")} if not project_ids: return set(workspace_users.values()) - project_emails = set() for pid in project_ids: pid = pid.strip() if not pid: continue - project = self.project_api.get_project( - pid, - opts={"opt_fields": "team,privacy_setting"} - ) + project = self.project_api.get_project(pid, opts={"opt_fields": "team,privacy_setting"}) if project.get("privacy_setting") == "private": if team_id and project.get("team", {}).get("gid") != team_id: continue - memberships = self.project_memberships_api.get_project_memberships_for_project( - pid, - opts={"opt_fields": "user.gid,user.email"} - ) + memberships = self.project_memberships_api.get_project_memberships_for_project(pid, opts={"opt_fields": "user.gid,user.email"}) for m in memberships: email = (m.get("user") or {}).get("email") @@ -382,18 +322,12 @@ class AsanaConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE, ) -> None: self.workspace_id = asana_workspace_id - self.project_ids_to_index: list[str] | None = ( - [project_id.strip() for project_id in asana_project_ids.split(",") if project_id.strip()] - if asana_project_ids - else None - ) + self.project_ids_to_index: list[str] | None = [project_id.strip() for project_id in asana_project_ids.split(",") if project_id.strip()] if asana_project_ids else None self.asana_team_id = asana_team_id.strip() if asana_team_id and asana_team_id.strip() else None self.batch_size = batch_size self.continue_on_failure = continue_on_failure self.size_threshold = None - logging.info( - f"AsanaConnector initialized with workspace_id: {asana_workspace_id}" - ) + logging.info(f"AsanaConnector initialized with workspace_id: {asana_workspace_id}") def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self.api_token = credentials["asana_api_token_secret"] @@ -406,9 +340,7 @@ class AsanaConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): logging.info("Asana credentials loaded and API client initialized") return None - def poll_source( - self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch | None - ) -> GenerateDocumentsOutput: + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch | None) -> GenerateDocumentsOutput: start_time = datetime.fromtimestamp(start, tz=timezone.utc).isoformat() end_time = datetime.fromtimestamp(end, tz=timezone.utc) if end is not None else None logging.info(f"Starting Asana poll from {start_time}") @@ -480,14 +412,8 @@ class AsanaConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): file_blob = resp.content filename = att.get("name", "attachment") size_bytes = extract_size_bytes(att) - if ( - self.size_threshold is not None - and isinstance(size_bytes, int) - and size_bytes > self.size_threshold - ): - logging.warning( - f"{filename} exceeds size threshold of {self.size_threshold}. Skipping." - ) + if self.size_threshold is not None and isinstance(size_bytes, int) and size_bytes > self.size_threshold: + logging.warning(f"{filename} exceeds size threshold of {self.size_threshold}. Skipping.") continue docs.append( Document( @@ -502,14 +428,11 @@ class AsanaConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): ) ) except Exception: - logging.exception( - f"Failed to download attachment {att.get('gid')} for task {task.id}" - ) + logging.exception(f"Failed to download attachment {att.get('gid')} for task {task.id}") return docs - if __name__ == "__main__": import time import os diff --git a/common/data_source/azure_blob_connector.py b/common/data_source/azure_blob_connector.py index 771aa13f5b..5ac88b10fc 100644 --- a/common/data_source/azure_blob_connector.py +++ b/common/data_source/azure_blob_connector.py @@ -50,9 +50,20 @@ logger = logging.getLogger(__name__) # Extensions we ingest; mirrors the same set used by the OneDrive # connector so behaviour is consistent across all file-based sources. _SUPPORTED_EXTENSIONS = { - ".pdf", ".docx", ".doc", ".xlsx", ".xls", - ".pptx", ".ppt", ".txt", ".md", ".csv", - ".html", ".htm", ".json", ".xml", + ".pdf", + ".docx", + ".doc", + ".xlsx", + ".xls", + ".pptx", + ".ppt", + ".txt", + ".md", + ".csv", + ".html", + ".htm", + ".json", + ".xml", } _AZURE_ENDPOINT_SUFFIX = "blob.core.windows.net" @@ -125,24 +136,16 @@ class AzureBlobConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPer try: if mode == "connection_string": if not conn_str: - raise ConnectorMissingCredentialError( - "Azure Blob: connection_string is required for the connection_string auth mode" - ) + raise ConnectorMissingCredentialError("Azure Blob: connection_string is required for the connection_string auth mode") if not container_name: - raise ConnectorMissingCredentialError( - "Azure Blob: container_name is required together with connection_string" - ) + raise ConnectorMissingCredentialError("Azure Blob: container_name is required together with connection_string") svc = BlobServiceClient.from_connection_string(conn_str) self._container_client = svc.get_container_client(container_name) elif mode == "account_key": if not (account_name and account_key): - raise ConnectorMissingCredentialError( - "Azure Blob: account_name and account_key are required for the account_key auth mode" - ) + raise ConnectorMissingCredentialError("Azure Blob: account_name and account_key are required for the account_key auth mode") if not container_name: - raise ConnectorMissingCredentialError( - "Azure Blob: container_name is required together with account_name + account_key" - ) + raise ConnectorMissingCredentialError("Azure Blob: container_name is required together with account_name + account_key") account_url = f"https://{account_name}.{_AZURE_ENDPOINT_SUFFIX}" svc = BlobServiceClient( account_url=account_url, @@ -151,9 +154,7 @@ class AzureBlobConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPer self._container_client = svc.get_container_client(container_name) elif mode == "sas_token": if not (container_url and sas_token): - raise ConnectorMissingCredentialError( - "Azure Blob: container_url and sas_token are required for the sas_token auth mode" - ) + raise ConnectorMissingCredentialError("Azure Blob: container_url and sas_token are required for the sas_token auth mode") # mirrors RAGFlowAzureSasBlob; strip a leading "?" so we # never produce a double-"?" that breaks SAS auth. normalized_sas = str(sas_token).lstrip("?") @@ -161,17 +162,12 @@ class AzureBlobConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPer self._container_client = ContainerClient.from_container_url(full_url) else: raise ConnectorMissingCredentialError( - "Azure Blob credentials are incomplete. Provide one of: " - "(a) connection_string + container_name, " - "(b) account_name + account_key + container_name, " - "(c) container_url + sas_token." + "Azure Blob credentials are incomplete. Provide one of: (a) connection_string + container_name, (b) account_name + account_key + container_name, (c) container_url + sas_token." ) except ConnectorMissingCredentialError: raise except Exception as exc: - raise ConnectorMissingCredentialError( - f"Failed to initialise Azure Blob client: {exc}" - ) from exc + raise ConnectorMissingCredentialError(f"Failed to initialise Azure Blob client: {exc}") from exc return None @@ -192,20 +188,12 @@ class AzureBlobConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPer msg = str(exc) code = getattr(getattr(exc, "error_code", None), "value", None) or getattr(exc, "error_code", "") if "AuthenticationFailed" in msg or "InvalidAuthenticationInfo" in msg: - raise ConnectorMissingCredentialError( - f"Azure Blob credential rejected: {msg[:300]}" - ) from exc + raise ConnectorMissingCredentialError(f"Azure Blob credential rejected: {msg[:300]}") from exc if "AuthorizationPermissionMismatch" in msg or "403" in msg: - raise InsufficientPermissionsError( - f"Azure Blob: insufficient permissions on container: {msg[:300]}" - ) from exc + raise InsufficientPermissionsError(f"Azure Blob: insufficient permissions on container: {msg[:300]}") from exc if "ContainerNotFound" in msg or "404" in msg: - raise ConnectorValidationError( - f"Azure Blob: container not found: {msg[:300]}" - ) from exc - raise UnexpectedValidationError( - f"Azure Blob validation failed ({code}): {msg[:300]}" - ) from exc + raise ConnectorValidationError(f"Azure Blob: container not found: {msg[:300]}") from exc + raise UnexpectedValidationError(f"Azure Blob validation failed ({code}): {msg[:300]}") from exc # ------------------------------------------------------------------ # Checkpoint helpers @@ -224,9 +212,7 @@ class AzureBlobConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPer # Core data loading # ------------------------------------------------------------------ - def poll_source( - self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch - ) -> Any: + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any: return self._iter_documents(since_epoch=start, until_epoch=end) def load_from_checkpoint( @@ -239,9 +225,7 @@ class AzureBlobConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPer checkpoint = self.build_dummy_checkpoint() since = start if start else None until = end if end else None - return self._iter_documents( - checkpoint=checkpoint, since_epoch=since, until_epoch=until - ) + return self._iter_documents(checkpoint=checkpoint, since_epoch=since, until_epoch=until) def load_from_checkpoint_with_perm_sync( self, @@ -272,9 +256,7 @@ class AzureBlobConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPer yield batch batch = [] except Exception as exc: - raise UnexpectedValidationError( - f"Azure Blob prune listing failed: {exc}" - ) from exc + raise UnexpectedValidationError(f"Azure Blob prune listing failed: {exc}") from exc if batch: yield batch @@ -297,9 +279,7 @@ class AzureBlobConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPer batch: list[Document] = [] try: - for blob_props in self._container_client.list_blobs( - name_starts_with=self.prefix or None - ): + for blob_props in self._container_client.list_blobs(name_starts_with=self.prefix or None): name: str = blob_props.name if not _has_supported_extension(name, self.allow_images): @@ -346,15 +326,9 @@ class AzureBlobConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPer name, ) continue - raise UnexpectedValidationError( - f"Azure Blob: failed to download {name}: {exc}" - ) from exc + raise UnexpectedValidationError(f"Azure Blob: failed to download {name}: {exc}") from exc - doc_updated_at = ( - last_modified.astimezone(timezone.utc) - if last_modified - else datetime.now(timezone.utc) - ) + doc_updated_at = last_modified.astimezone(timezone.utc) if last_modified else datetime.now(timezone.utc) ext = _extension(name) doc = Document( @@ -380,9 +354,7 @@ class AzureBlobConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPer except UnexpectedValidationError: raise except Exception as exc: - raise UnexpectedValidationError( - f"Azure Blob listing failed: {exc}" - ) from exc + raise UnexpectedValidationError(f"Azure Blob listing failed: {exc}") from exc if batch: yield batch @@ -395,6 +367,7 @@ class AzureBlobConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPer # Module-level helpers # ---------------------------------------------------------------------- + def _extension(name: str) -> str: if "." not in name: return "" diff --git a/common/data_source/bigquery_connector.py b/common/data_source/bigquery_connector.py index 1aeec24a0f..00489c5301 100644 --- a/common/data_source/bigquery_connector.py +++ b/common/data_source/bigquery_connector.py @@ -156,15 +156,11 @@ class BigQueryConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) try: service_account_info = json.loads(raw) except json.JSONDecodeError as exc: - raise ConnectorMissingCredentialError( - f"BigQuery: service_account_json is not valid JSON: {exc}" - ) + raise ConnectorMissingCredentialError(f"BigQuery: service_account_json is not valid JSON: {exc}") elif isinstance(raw, dict): service_account_info = raw else: - raise ConnectorMissingCredentialError( - "BigQuery: service_account_json must be a JSON string or object" - ) + raise ConnectorMissingCredentialError("BigQuery: service_account_json must be a JSON string or object") self._credentials = {"service_account_info": service_account_info} return None @@ -175,9 +171,7 @@ class BigQueryConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) return self._client if bigquery is None or service_account is None: - raise ConnectorValidationError( - "BigQuery client not installed. Please install google-cloud-bigquery." - ) + raise ConnectorValidationError("BigQuery client not installed. Please install google-cloud-bigquery.") service_account_info = self._credentials.get("service_account_info") if not service_account_info: @@ -208,9 +202,7 @@ class BigQueryConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) return self.query.rstrip(";") if self.dataset_id and self.table_id: return f"SELECT * FROM `{self.project_id}.{self.dataset_id}.{self.table_id}`" - raise ConnectorValidationError( - "BigQuery requires either a custom query or both dataset_id and table_id." - ) + raise ConnectorValidationError("BigQuery requires either a custom query or both dataset_id and table_id.") @staticmethod def _wrap_query(base_query: str, select_clause: str = "*") -> str: @@ -256,9 +248,7 @@ class BigQueryConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) for field in schema: if field.name == self.timestamp_column: return field.field_type - raise ConnectorValidationError( - f"BigQuery timestamp column '{self.timestamp_column}' was not found in the schema." - ) + raise ConnectorValidationError(f"BigQuery timestamp column '{self.timestamp_column}' was not found in the schema.") def _resolve_cursor_param_type(self) -> str: if self._cursor_param_type is not None: @@ -266,9 +256,7 @@ class BigQueryConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) field_type = (self._get_cursor_column_field_type() or "").upper() param_type = _CURSOR_PARAM_TYPE_MAP.get(field_type) if param_type is None: - raise ConnectorValidationError( - f"BigQuery timestamp column type '{field_type}' is not supported as a cursor." - ) + raise ConnectorValidationError(f"BigQuery timestamp column type '{field_type}' is not supported as a cursor.") self._cursor_param_type = param_type return param_type @@ -291,10 +279,7 @@ class BigQueryConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) params: List[Any] = [] if start is not None: if self.id_column and start_id is not None: - conditions.append( - f"(ragflow_src.{self.timestamp_column} > @start_cursor OR " - f"(ragflow_src.{self.timestamp_column} = @start_cursor AND ragflow_src.{self.id_column} > @start_cursor_id))" - ) + conditions.append(f"(ragflow_src.{self.timestamp_column} > @start_cursor OR (ragflow_src.{self.timestamp_column} = @start_cursor AND ragflow_src.{self.id_column} > @start_cursor_id))") params.append(self._make_cursor_param("start_cursor", start, param_type)) params.append(self._make_cursor_param("start_cursor_id", start_id, "STRING")) else: @@ -319,10 +304,7 @@ class BigQueryConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) f") " f"GROUP BY ragflow_src.{self.timestamp_column}" ) - return ( - f"SELECT MAX(ragflow_src.{self.timestamp_column}), NULL " - f"FROM ({base_query}) AS ragflow_src" - ) + return f"SELECT MAX(ragflow_src.{self.timestamp_column}), NULL FROM ({base_query}) AS ragflow_src" def _build_slim_query(self, base_query: str) -> str: columns = [self.id_column] if self.id_column else self.content_columns @@ -426,17 +408,10 @@ class BigQueryConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) else: doc_updated_at = ts_value.astimezone(timezone.utc) elif isinstance(ts_value, date): - doc_updated_at = datetime( - ts_value.year, ts_value.month, ts_value.day, tzinfo=timezone.utc - ) + doc_updated_at = datetime(ts_value.year, ts_value.month, ts_value.day, tzinfo=timezone.utc) first_content_col = self.content_columns[0] if self.content_columns else "record" - semantic_id = ( - str(row_dict.get(first_content_col, "bigquery_record")) - .replace("\n", " ") - .replace("\r", " ") - .strip()[:100] - ) + semantic_id = str(row_dict.get(first_content_col, "bigquery_record")).replace("\n", " ").replace("\r", " ").strip()[:100] blob = content.encode("utf-8") return Document( @@ -540,9 +515,7 @@ class BigQueryConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) logging.debug("Loading all records from BigQuery project: %s", self.project_id) return self._yield_documents() - def poll_source( - self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch - ) -> Generator[list[Document], None, None]: + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Generator[list[Document], None, None]: """Poll for new/updated rows. Provided for interface completeness. Orchestration drives full/incremental sync via ``load_from_state`` / @@ -550,9 +523,7 @@ class BigQueryConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) timestamp column. """ if not self.timestamp_column: - logging.warning( - "No timestamp column configured for incremental sync. Falling back to full sync." - ) + logging.warning("No timestamp column configured for incremental sync. Falling back to full sync.") return self.load_from_state() start_dt = datetime.fromtimestamp(start, tz=timezone.utc) if start else None end_dt = datetime.fromtimestamp(end, tz=timezone.utc) if end else None @@ -575,9 +546,7 @@ class BigQueryConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) callback: Any = None, ) -> Generator[list[SlimDocument], None, None]: del callback - yield from self._yield_slim_documents_from_query( - self._build_slim_query(self._build_base_query()) - ) + yield from self._yield_slim_documents_from_query(self._build_slim_query(self._build_base_query())) # ------------------------------------------------------------------ # # Sync-state persistence (success-only cursor) @@ -608,9 +577,7 @@ class BigQueryConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) from api.db.services.connector_service import ConnectorService updated_conf = copy.deepcopy(self._sync_config) - updated_conf["sync_cursor_value"] = self.serialize_cursor_value( - self._pending_sync_cursor_value - ) + updated_conf["sync_cursor_value"] = self.serialize_cursor_value(self._pending_sync_cursor_value) updated_conf["sync_cursor_id"] = self._pending_sync_cursor_id ConnectorService.update_by_id(self._sync_connector_id, {"config": updated_conf}) self._sync_config = updated_conf @@ -627,9 +594,7 @@ class BigQueryConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) if not self.content_columns: raise ConnectorValidationError("At least one content column must be specified.") if not self.query and not (self.dataset_id and self.table_id): - raise ConnectorValidationError( - "BigQuery requires either a custom query or both dataset_id and table_id." - ) + raise ConnectorValidationError("BigQuery requires either a custom query or both dataset_id and table_id.") try: client = self._get_client() @@ -649,10 +614,8 @@ class BigQueryConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) ) estimated_bytes = getattr(dry_run_job, "total_bytes_processed", None) if estimated_bytes is not None: - logging.info( - "BigQuery base query dry-run estimate: %s bytes processed.", estimated_bytes - ) - + logging.info("BigQuery base query dry-run estimate: %s bytes processed.", estimated_bytes) + schema = self._resolve_schema(client, dry_run_job) schema_columns = {field.name for field in schema} @@ -665,9 +628,7 @@ class BigQueryConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync) missing = (required | optional) - schema_columns if missing: - raise ConnectorValidationError( - f"BigQuery configured columns not found in schema: {', '.join(sorted(missing))}" - ) + raise ConnectorValidationError(f"BigQuery configured columns not found in schema: {', '.join(sorted(missing))}") if self.timestamp_column: self._resolve_cursor_param_type() diff --git a/common/data_source/bitbucket/connector.py b/common/data_source/bitbucket/connector.py index 0557d2a503..0e570f7786 100644 --- a/common/data_source/bitbucket/connector.py +++ b/common/data_source/bitbucket/connector.py @@ -13,7 +13,7 @@ from typing_extensions import override from common.data_source.config import INDEX_BATCH_SIZE from common.data_source.config import DocumentSource from common.data_source.config import REQUEST_TIMEOUT_SECONDS -from common.data_source.exceptions import ( +from common.data_source.exceptions import ( ConnectorMissingCredentialError, CredentialExpiredError, InsufficientPermissionsError, @@ -76,14 +76,8 @@ class BitbucketConnector( batch_size: int = INDEX_BATCH_SIZE, ) -> None: self.workspace = workspace - self._repositories = ( - [s.strip() for s in repositories.split(",") if s.strip()] - if repositories - else None - ) - self._projects: list[str] | None = ( - [s.strip() for s in projects.split(",") if s.strip()] if projects else None - ) + self._repositories = [s.strip() for s in repositories.split(",") if s.strip()] if repositories else None + self._projects: list[str] | None = [s.strip() for s in projects.split(",") if s.strip()] if projects else None self.batch_size = batch_size self.email: str | None = None self.api_token: str | None = None @@ -230,11 +224,7 @@ class BitbucketConnector( yield document except Exception as e: pr_id = pr.get("id") - pr_link = ( - f"https://bitbucket.org/{self.workspace}/{repo_slug}/pull-requests/{pr_id}" - if pr_id is not None - else None - ) + pr_link = f"https://bitbucket.org/{self.workspace}/{repo_slug}/pull-requests/{pr_id}" if pr_id is not None else None yield ConnectorFailure( failed_document=DocumentFailure( document_id=( @@ -261,9 +251,7 @@ class BitbucketConnector( return BitbucketConnectorCheckpoint(has_more=True) @override - def validate_checkpoint_json( - self, checkpoint_json: str - ) -> BitbucketConnectorCheckpoint: + def validate_checkpoint_json(self, checkpoint_json: str) -> BitbucketConnectorCheckpoint: """Validate and deserialize a checkpoint instance from JSON.""" return BitbucketConnectorCheckpoint.model_validate_json(checkpoint_json) @@ -276,9 +264,7 @@ class BitbucketConnector( params = self._build_params(fields=SLIM_PR_LIST_RESPONSE_FIELDS) with self._client() as client: for slug in self._iter_target_repositories(client): - for pr in self._iter_pull_requests_for_repo( - client, slug, params=params - ): + for pr in self._iter_pull_requests_for_repo(client, slug, params=params): pr_id = pr["id"] doc_id = f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{slug}:pr:{pr_id}" batch.append(SlimDocument(id=doc_id)) @@ -288,9 +274,7 @@ class BitbucketConnector( if callback: if callback.should_stop(): # Note: this is not actually used for permission sync yet, just pruning - raise RuntimeError( - "bitbucket_pr_sync: Stop signal detected" - ) + raise RuntimeError("bitbucket_pr_sync: Stop signal detected") callback.progress("bitbucket_pr_sync", len(batch)) if batch: yield batch @@ -312,17 +296,11 @@ class BitbucketConnector( timeout=REQUEST_TIMEOUT_SECONDS, ) if resp.status_code == 401: - raise CredentialExpiredError( - "Invalid or expired Bitbucket credentials (HTTP 401)." - ) + raise CredentialExpiredError("Invalid or expired Bitbucket credentials (HTTP 401).") if resp.status_code == 403: - raise InsufficientPermissionsError( - "Insufficient permissions to access Bitbucket workspace (HTTP 403)." - ) + raise InsufficientPermissionsError("Insufficient permissions to access Bitbucket workspace (HTTP 403).") if resp.status_code < 200 or resp.status_code >= 300: - raise UnexpectedValidationError( - f"Unexpected Bitbucket error (status={resp.status_code})." - ) + raise UnexpectedValidationError(f"Unexpected Bitbucket error (status={resp.status_code}).") except Exception as e: # Network or other unexpected errors if isinstance( @@ -335,19 +313,18 @@ class BitbucketConnector( ), ): raise - raise UnexpectedValidationError( - f"Unexpected error while validating Bitbucket settings: {e}" - ) + raise UnexpectedValidationError(f"Unexpected error while validating Bitbucket settings: {e}") + if __name__ == "__main__": - bitbucket = BitbucketConnector( - workspace="" - ) + bitbucket = BitbucketConnector(workspace="") - bitbucket.load_credentials({ - "bitbucket_email": "", - "bitbucket_api_token": "", - }) + bitbucket.load_credentials( + { + "bitbucket_email": "", + "bitbucket_api_token": "", + } + ) bitbucket.validate_connector_settings() print("Credentials validated successfully.") @@ -359,9 +336,8 @@ if __name__ == "__main__": for doc in doc_batch: print(doc) - bitbucket_checkpoint = bitbucket.build_dummy_checkpoint() - + while bitbucket_checkpoint.has_more: gen = bitbucket.load_from_checkpoint( start=start_time.timestamp(), @@ -371,9 +347,8 @@ if __name__ == "__main__": while True: try: - doc = next(gen) + doc = next(gen) print(doc) except StopIteration as e: - bitbucket_checkpoint = e.value + bitbucket_checkpoint = e.value break - diff --git a/common/data_source/bitbucket/utils.py b/common/data_source/bitbucket/utils.py index 4667a96006..3fd3e9a2f1 100644 --- a/common/data_source/bitbucket/utils.py +++ b/common/data_source/bitbucket/utils.py @@ -88,9 +88,7 @@ class BitbucketNonRetriableError(Exception): exceptions=(BitbucketRetriableError, httpx.RequestError), ) @rate_limit_builder(max_calls=60, period=60) -def bitbucket_get( - client: httpx.Client, url: str, params: dict[str, Any] | None = None -) -> httpx.Response: +def bitbucket_get(client: httpx.Client, url: str, params: dict[str, Any] | None = None) -> httpx.Response: """Perform a GET against Bitbucket with retry and rate limiting. Retries on 429 and 5xx responses, and on transport errors. Honors @@ -162,9 +160,7 @@ def paginate( query = None -def list_repositories( - client: httpx.Client, workspace: str, project_key: str | None = None -) -> Iterator[dict[str, Any]]: +def list_repositories(client: httpx.Client, workspace: str, project_key: str | None = None) -> Iterator[dict[str, Any]]: """List repositories in a workspace, optionally filtered by project key.""" base_url = f"https://api.bitbucket.org/2.0/repositories/{workspace}" params: dict[str, Any] = { @@ -189,26 +185,16 @@ def map_pr_to_document(pr: dict[str, Any], workspace: str, repo_slug: str) -> Do reviewers = pr.get("reviewers", []) participants = pr.get("participants", []) - link = pr.get("links", {}).get("html", {}).get("href") or ( - f"https://bitbucket.org/{workspace}/{repo_slug}/pull-requests/{pr_id}" - ) + link = pr.get("links", {}).get("html", {}).get("href") or (f"https://bitbucket.org/{workspace}/{repo_slug}/pull-requests/{pr_id}") created_on = pr.get("created_on") updated_on = pr.get("updated_on") - updated_dt = ( - datetime.fromisoformat(updated_on.replace("Z", "+00:00")).astimezone( - timezone.utc - ) - if isinstance(updated_on, str) - else None - ) + updated_dt = datetime.fromisoformat(updated_on.replace("Z", "+00:00")).astimezone(timezone.utc) if isinstance(updated_on, str) else None source_branch = pr.get("source", {}).get("branch", {}).get("name", "") destination_branch = pr.get("destination", {}).get("branch", {}).get("name", "") - approved_by = [ - _get_user_name(p.get("user", {})) for p in participants if p.get("approved") - ] + approved_by = [_get_user_name(p.get("user", {})) for p in participants if p.get("approved")] primary_owner = None if author: @@ -216,21 +202,16 @@ def map_pr_to_document(pr: dict[str, Any], workspace: str, repo_slug: str) -> Do display_name=_get_user_name(author), ) - # secondary_owners = [ + # secondary_owners = [ # BasicExpertInfo(display_name=_get_user_name(r)) for r in reviewers - # ] or None + # ] or None reviewer_names = [_get_user_name(r) for r in reviewers] # Create a concise summary of key PR info created_date = created_on.split("T")[0] if created_on else "N/A" updated_date = updated_on.split("T")[0] if updated_on else "N/A" - content_text = ( - "Pull Request Information:\n" - f"- Pull Request ID: {pr_id}\n" - f"- Title: {title}\n" - f"- State: {state or 'N/A'} {'(Draft)' if draft else ''}\n" - ) + content_text = f"Pull Request Information:\n- Pull Request ID: {pr_id}\n- Title: {title}\n- State: {state or 'N/A'} {'(Draft)' if draft else ''}\n" if state == "DECLINED": content_text += f"- Reason: {pr.get('reason', 'N/A')}\n" content_text += ( @@ -262,9 +243,7 @@ def map_pr_to_document(pr: dict[str, Any], workspace: str, repo_slug: str) -> Do "updated_on": updated_on or "", "source_branch": source_branch, "destination_branch": destination_branch, - "closed_by": ( - _get_user_name(pr.get("closed_by", {})) if pr.get("closed_by") else "" - ), + "closed_by": (_get_user_name(pr.get("closed_by", {})) if pr.get("closed_by") else ""), "close_source_branch": str(bool(pr.get("close_source_branch", False))), } @@ -285,4 +264,4 @@ def map_pr_to_document(pr: dict[str, Any], workspace: str, repo_slug: str) -> Do def _get_user_name(user: dict[str, Any]) -> str: - return user.get("display_name") or user.get("nickname") or "unknown" \ No newline at end of file + return user.get("display_name") or user.get("nickname") or "unknown" diff --git a/common/data_source/blob_connector.py b/common/data_source/blob_connector.py index e183eb63aa..d66e5a0993 100644 --- a/common/data_source/blob_connector.py +++ b/common/data_source/blob_connector.py @@ -1,4 +1,5 @@ """Blob storage connector""" + import logging import os from collections.abc import Iterator @@ -15,12 +16,7 @@ from common.data_source.utils import ( get_file_ext, ) from common.data_source.config import BlobType, DocumentSource, BLOB_STORAGE_SIZE_THRESHOLD, INDEX_BATCH_SIZE -from common.data_source.exceptions import ( - ConnectorMissingCredentialError, - ConnectorValidationError, - CredentialExpiredError, - InsufficientPermissionsError -) +from common.data_source.exceptions import ConnectorMissingCredentialError, ConnectorValidationError, CredentialExpiredError, InsufficientPermissionsError from common.data_source.interfaces import ( FingerprintConnector, LoadConnector, @@ -82,32 +78,24 @@ class BlobStorageConnector(LoadConnector, PollConnector, FingerprintConnector): def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: """Load credentials""" - logging.debug( - f"Loading credentials for {self.bucket_name} of type {self.bucket_type}" - ) + logging.debug(f"Loading credentials for {self.bucket_name} of type {self.bucket_type}") # Validate credentials if self.bucket_type == BlobType.R2: - if not all( - credentials.get(key) - for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"] - ): + if not all(credentials.get(key) for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"]): raise ConnectorMissingCredentialError("Cloudflare R2") elif self.bucket_type == BlobType.S3: authentication_method = credentials.get("authentication_method", "access_key") if authentication_method == "access_key": - if not all( - credentials.get(key) - for key in ["aws_access_key_id", "aws_secret_access_key"] - ): + if not all(credentials.get(key) for key in ["aws_access_key_id", "aws_secret_access_key"]): raise ConnectorMissingCredentialError("Amazon S3") elif authentication_method == "iam_role": if not credentials.get("aws_role_arn"): raise ConnectorMissingCredentialError("Amazon S3 IAM role ARN is required") - + elif authentication_method == "assume_role": pass @@ -115,32 +103,22 @@ class BlobStorageConnector(LoadConnector, PollConnector, FingerprintConnector): raise ConnectorMissingCredentialError("Unsupported S3 authentication method") elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE: - if not all( - credentials.get(key) for key in ["access_key_id", "secret_access_key"] - ): + if not all(credentials.get(key) for key in ["access_key_id", "secret_access_key"]): raise ConnectorMissingCredentialError("Google Cloud Storage") elif self.bucket_type == BlobType.OCI_STORAGE: - if not all( - credentials.get(key) - for key in ["namespace", "region", "access_key_id", "secret_access_key"] - ): + if not all(credentials.get(key) for key in ["namespace", "region", "access_key_id", "secret_access_key"]): raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure") elif self.bucket_type == BlobType.S3_COMPATIBLE: - if not all( - credentials.get(key) - for key in ["endpoint_url", "aws_access_key_id", "aws_secret_access_key", "addressing_style"] - ): + if not all(credentials.get(key) for key in ["endpoint_url", "aws_access_key_id", "aws_secret_access_key", "addressing_style"]): raise ConnectorMissingCredentialError("S3 Compatible Storage") else: raise ValueError(f"Unsupported bucket type: {self.bucket_type}") # Create S3 client - self.s3_client = create_s3_client( - self.bucket_type, credentials, self.european_residency - ) + self.s3_client = create_s3_client(self.bucket_type, credentials, self.european_residency) # Detect bucket region (only important for S3) if self.bucket_type == BlobType.S3: @@ -159,19 +137,11 @@ class BlobStorageConnector(LoadConnector, PollConnector, FingerprintConnector): last_modified = obj["LastModified"].replace(tzinfo=timezone.utc) size_bytes = extract_size_bytes(obj) - if ( - self.size_threshold is not None - and isinstance(size_bytes, int) - and size_bytes > self.size_threshold - ): - logging.warning( - f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping." - ) + if self.size_threshold is not None and isinstance(size_bytes, int) and size_bytes > self.size_threshold: + logging.warning(f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping.") return None - blob = download_object( - self.s3_client, self.bucket_name, key, self.size_threshold - ) + blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold) if blob is None: return None @@ -245,10 +215,7 @@ class BlobStorageConnector(LoadConnector, PollConnector, FingerprintConnector): """ obj = self._listing_cache.get(key) if obj is None: - raise KeyError( - f"get_value({key!r}) called before list_keys() yielded the key, " - "or after a subsequent list_keys() reset the cache" - ) + raise KeyError(f"get_value({key!r}) called before list_keys() yielded the key, or after a subsequent list_keys() reset the cache") doc = self._build_document_from_obj(obj, self._filename_counts) if doc is None: raise RuntimeError(f"Failed to materialize Document for key {key!r}") @@ -295,7 +262,7 @@ class BlobStorageConnector(LoadConnector, PollConnector, FingerprintConnector): if filename_counts.get(file_name, 0) > 1: relative_path = key if self.prefix and key.startswith(self.prefix): - relative_path = key[len(self.prefix):] + relative_path = key[len(self.prefix) :] return relative_path.replace("/", " / ") if relative_path else file_name return file_name @@ -313,9 +280,7 @@ class BlobStorageConnector(LoadConnector, PollConnector, FingerprintConnector): batch: list[SlimDocument] = [] for obj in all_objects: - batch.append( - SlimDocument(id=f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}") - ) + batch.append(SlimDocument(id=f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}")) if len(batch) == self.batch_size: yield batch batch = [] @@ -331,9 +296,7 @@ class BlobStorageConnector(LoadConnector, PollConnector, FingerprintConnector): end=datetime.now(timezone.utc), ) - def poll_source( - self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch - ) -> GenerateDocumentsOutput: + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput: """Poll source to get documents""" if self.s3_client is None: raise ConnectorMissingCredentialError("Blob storage") @@ -347,24 +310,18 @@ class BlobStorageConnector(LoadConnector, PollConnector, FingerprintConnector): def validate_connector_settings(self) -> None: """Validate connector settings""" if self.s3_client is None: - raise ConnectorMissingCredentialError( - "Blob storage credentials not loaded." - ) + raise ConnectorMissingCredentialError("Blob storage credentials not loaded.") if not self.bucket_name: - raise ConnectorValidationError( - "No bucket name was provided in connector settings." - ) + raise ConnectorValidationError("No bucket name was provided in connector settings.") try: # Lightweight validation step - self.s3_client.list_objects_v2( - Bucket=self.bucket_name, Prefix=self.prefix, MaxKeys=1 - ) + self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=self.prefix, MaxKeys=1) except Exception as e: - error_code = getattr(e, 'response', {}).get('Error', {}).get('Code', '') - status_code = getattr(e, 'response', {}).get('ResponseMetadata', {}).get('HTTPStatusCode') + error_code = getattr(e, "response", {}).get("Error", {}).get("Code", "") + status_code = getattr(e, "response", {}).get("ResponseMetadata", {}).get("HTTPStatusCode") # Common S3 error scenarios if error_code in [ @@ -373,27 +330,16 @@ class BlobStorageConnector(LoadConnector, PollConnector, FingerprintConnector): "SignatureDoesNotMatch", ]: if status_code == 403 or error_code == "AccessDenied": - raise InsufficientPermissionsError( - f"Insufficient permissions to list objects in bucket '{self.bucket_name}'. " - "Please check your bucket policy and/or IAM policy." - ) + raise InsufficientPermissionsError(f"Insufficient permissions to list objects in bucket '{self.bucket_name}'. Please check your bucket policy and/or IAM policy.") if status_code == 401 or error_code == "SignatureDoesNotMatch": - raise CredentialExpiredError( - "Provided blob storage credentials appear invalid or expired." - ) + raise CredentialExpiredError("Provided blob storage credentials appear invalid or expired.") - raise CredentialExpiredError( - f"Credential issue encountered ({error_code})." - ) + raise CredentialExpiredError(f"Credential issue encountered ({error_code}).") if error_code == "NoSuchBucket" or status_code == 404: - raise ConnectorValidationError( - f"Bucket '{self.bucket_name}' does not exist or cannot be found." - ) + raise ConnectorValidationError(f"Bucket '{self.bucket_name}' does not exist or cannot be found.") - raise ConnectorValidationError( - f"Unexpected S3 client error (code={error_code}, status={status_code}): {e}" - ) + raise ConnectorValidationError(f"Unexpected S3 client error (code={error_code}, status={status_code}): {e}") if __name__ == "__main__": diff --git a/common/data_source/box_connector.py b/common/data_source/box_connector.py index cc44f356e8..51be0ffbcc 100644 --- a/common/data_source/box_connector.py +++ b/common/data_source/box_connector.py @@ -1,4 +1,5 @@ """Box connector""" + import logging from datetime import datetime, timezone from typing import Any, Generator @@ -53,18 +54,10 @@ class BoxConnector(LoadConnector, PollConnector): for entry in result.entries: if entry.type == "file": file = self.box_client.files.get_file_by_id(entry.id) - semantic_identifier = ( - f"{relative_folder_path} / {file.name}" - if relative_folder_path - else file.name - ) + semantic_identifier = f"{relative_folder_path} / {file.name}" if relative_folder_path else file.name yield file, semantic_identifier elif entry.type == "folder": - child_relative_path = ( - f"{relative_folder_path} / {entry.name}" - if relative_folder_path - else entry.name - ) + child_relative_path = f"{relative_folder_path} / {entry.name}" if relative_folder_path else entry.name yield from self._iter_files_recursive( folder_id=entry.id, relative_folder_path=child_relative_path, @@ -96,10 +89,7 @@ class BoxConnector(LoadConnector, PollConnector): relative_folder_path=relative_folder_path, ): modified_time: SecondsSinceUnixEpoch | None = None - raw_time = ( - getattr(file, "created_at", None) - or getattr(file, "content_created_at", None) - ) + raw_time = getattr(file, "created_at", None) or getattr(file, "content_created_at", None) if raw_time: modified_time = self._box_datetime_to_epoch_seconds(raw_time) @@ -161,7 +151,6 @@ class BoxConnector(LoadConnector, PollConnector): def poll_source(self, start, end): return self._yield_files_recursive(folder_id=self.folder_id, start=start, end=end) - def load_from_state(self): return self._yield_files_recursive(folder_id=self.folder_id, start=None, end=None) @@ -190,7 +179,7 @@ class BoxConnector(LoadConnector, PollConnector): # AUTH.get_tokens_authorization_code_grant(request.args.get("code")) # box = BoxConnector() # box.load_credentials({"auth": AUTH}) - + # lst = [] # for file in box.load_from_state(): # for f in file: diff --git a/common/data_source/config.py b/common/data_source/config.py index 6e476d0fda..49f063930e 100644 --- a/common/data_source/config.py +++ b/common/data_source/config.py @@ -1,4 +1,5 @@ """Configuration constants and enum definitions""" + import json import os from datetime import datetime, timezone @@ -31,6 +32,7 @@ ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT = 3600 class BlobType(str, Enum): """Supported storage types""" + S3 = "s3" R2 = "r2" GOOGLE_CLOUD_STORAGE = "google_cloud_storage" @@ -40,6 +42,7 @@ class BlobType(str, Enum): class DocumentSource(str, Enum): """Document sources""" + RSS = "rss" S3 = "s3" NOTION = "notion" @@ -78,6 +81,7 @@ class DocumentSource(str, Enum): class FileOrigin(str, Enum): """File origins""" + CONNECTOR = "connector" @@ -127,10 +131,7 @@ BOT_CHANNEL_PERCENTAGE_THRESHOLD = 0.95 DOWNLOAD_CHUNK_SIZE = 1024 * 1024 # 1MB SIZE_THRESHOLD_BUFFER = 64 -NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = ( - os.environ.get("NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP", "").lower() - == "true" -) +NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = os.environ.get("NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP", "").lower() == "true" SLIM_BATCH_SIZE = 100 @@ -146,48 +147,28 @@ _ITERATION_LIMIT = 100_000 # NOTE: Currently only supported in the Confluence and Google Drive connectors + # only handles some failures (Confluence = handles API call failures, Google # Drive = handles failures pulling files / parsing them) -CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get( - "CONTINUE_ON_CONNECTOR_FAILURE", "" -).lower() not in ["false", ""] +CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get("CONTINUE_ON_CONNECTOR_FAILURE", "").lower() not in ["false", ""] ##### # Confluence Connector Configs ##### -CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [ - ignored_tag - for ignored_tag in os.environ.get("CONFLUENCE_CONNECTOR_LABELS_TO_SKIP", "").split( - "," - ) - if ignored_tag -] +CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [ignored_tag for ignored_tag in os.environ.get("CONFLUENCE_CONNECTOR_LABELS_TO_SKIP", "").split(",") if ignored_tag] # Avoid to get archived pages -CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES = ( - os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES", "").lower() == "true" -) +CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES = os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES", "").lower() == "true" # Attachments exceeding this size will not be retrieved (in bytes) -CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int( - os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024) -) +CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)) # Attachments with more chars than this will not be indexed. This is to prevent extremely # large files from freezing indexing. 200,000 is ~100 google doc pages. -CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int( - os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000) -) +CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)) -_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = os.environ.get( - "CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE", "" -) +_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = os.environ.get("CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE", "") CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = cast( list[dict[str, str]] | None, - ( - json.loads(_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE) - if _RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE - else None - ), + (json.loads(_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE) if _RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE else None), ) # enter as a floating point offset from UTC in hours (-24 < val < 24) @@ -196,53 +177,29 @@ CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = cast( # For the default value, we assume that the user's local timezone is more likely to be # correct (i.e. the configured user's timezone or the default server one) than UTC. # https://developer.atlassian.com/cloud/confluence/cql-fields/#created -CONFLUENCE_TIMEZONE_OFFSET = float( - os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset()) -) +CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())) -CONFLUENCE_SYNC_TIME_BUFFER_SECONDS = int( - os.environ.get("CONFLUENCE_SYNC_TIME_BUFFER_SECONDS", ONE_DAY) -) +CONFLUENCE_SYNC_TIME_BUFFER_SECONDS = int(os.environ.get("CONFLUENCE_SYNC_TIME_BUFFER_SECONDS", ONE_DAY)) -GOOGLE_DRIVE_SYNC_TIME_BUFFER_SECONDS = int( - os.environ.get("GOOGLE_DRIVE_SYNC_TIME_BUFFER_SECONDS", ONE_DAY) -) +GOOGLE_DRIVE_SYNC_TIME_BUFFER_SECONDS = int(os.environ.get("GOOGLE_DRIVE_SYNC_TIME_BUFFER_SECONDS", ONE_DAY)) -GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int( - os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024) -) +GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)) -JIRA_CONNECTOR_LABELS_TO_SKIP = [ - ignored_tag - for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",") - if ignored_tag -] -JIRA_CONNECTOR_MAX_TICKET_SIZE = int( - os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024) -) -JIRA_SYNC_TIME_BUFFER_SECONDS = int( - os.environ.get("JIRA_SYNC_TIME_BUFFER_SECONDS", ONE_MINUTE) -) -JIRA_TIMEZONE_OFFSET = float( - os.environ.get("JIRA_TIMEZONE_OFFSET", get_current_tz_offset()) -) +JIRA_CONNECTOR_LABELS_TO_SKIP = [ignored_tag for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",") if ignored_tag] +JIRA_CONNECTOR_MAX_TICKET_SIZE = int(os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)) +JIRA_SYNC_TIME_BUFFER_SECONDS = int(os.environ.get("JIRA_SYNC_TIME_BUFFER_SECONDS", ONE_MINUTE)) +JIRA_TIMEZONE_OFFSET = float(os.environ.get("JIRA_TIMEZONE_OFFSET", get_current_tz_offset())) OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "") OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "") -OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get( - "OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", "" -) +OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", "") -OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get( - "OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", "" -) +OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", "") OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "") OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "") OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "") -OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get( - "OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", "" -) +OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", "") GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = os.environ.get("GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/google-drive/oauth/web/callback") GMAIL_WEB_OAUTH_REDIRECT_URI = os.environ.get("GMAIL_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/gmail/oauth/web/callback") @@ -258,6 +215,7 @@ BOX_WEB_OAUTH_REDIRECT_URI = os.environ.get("BOX_WEB_OAUTH_REDIRECT_URI", "http: GITHUB_CONNECTOR_BASE_URL = os.environ.get("GITHUB_CONNECTOR_BASE_URL") or None + class HtmlBasedConnectorTransformLinksStrategy(str, Enum): # remove links entirely STRIP = "strip" @@ -272,28 +230,16 @@ HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY = os.environ.get( PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() == "true" -WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get( - "WEB_CONNECTOR_IGNORED_CLASSES", "sidebar,footer" -).split(",") -WEB_CONNECTOR_IGNORED_ELEMENTS = os.environ.get( - "WEB_CONNECTOR_IGNORED_ELEMENTS", "nav,footer,meta,script,style,symbol,aside" -).split(",") +WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get("WEB_CONNECTOR_IGNORED_CLASSES", "sidebar,footer").split(",") +WEB_CONNECTOR_IGNORED_ELEMENTS = os.environ.get("WEB_CONNECTOR_IGNORED_ELEMENTS", "nav,footer,meta,script,style,symbol,aside").split(",") -AIRTABLE_CONNECTOR_SIZE_THRESHOLD = int( - os.environ.get("AIRTABLE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024) -) +AIRTABLE_CONNECTOR_SIZE_THRESHOLD = int(os.environ.get("AIRTABLE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)) -ASANA_CONNECTOR_SIZE_THRESHOLD = int( - os.environ.get("ASANA_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024) -) +ASANA_CONNECTOR_SIZE_THRESHOLD = int(os.environ.get("ASANA_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)) -IMAP_CONNECTOR_SIZE_THRESHOLD = int( - os.environ.get("IMAP_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024) -) +IMAP_CONNECTOR_SIZE_THRESHOLD = int(os.environ.get("IMAP_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)) -ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS = os.environ.get( - "ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS", "" -).split(",") +ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS = os.environ.get("ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS", "").split(",") _USER_NOT_FOUND = "Unknown Confluence User" diff --git a/common/data_source/confluence_connector.py b/common/data_source/confluence_connector.py index ef0d6a7760..bb447ecdcd 100644 --- a/common/data_source/confluence_connector.py +++ b/common/data_source/confluence_connector.py @@ -1,6 +1,5 @@ - - """Confluence connector""" + import copy import json import logging @@ -18,41 +17,67 @@ from atlassian.errors import ApiError from atlassian import Confluence from requests.exceptions import HTTPError -from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource, CONTINUE_ON_CONNECTOR_FAILURE, \ - CONFLUENCE_CONNECTOR_LABELS_TO_SKIP, CONFLUENCE_TIMEZONE_OFFSET, CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE, \ - CONFLUENCE_SYNC_TIME_BUFFER_SECONDS, \ - OAUTH_CONFLUENCE_CLOUD_CLIENT_ID, OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET, _DEFAULT_PAGINATION_LIMIT, \ - _PROBLEMATIC_EXPANSIONS, _REPLACEMENT_EXPANSIONS, _USER_NOT_FOUND, _COMMENT_EXPANSION_FIELDS, \ - _ATTACHMENT_EXPANSION_FIELDS, _PAGE_EXPANSION_FIELDS, ONE_DAY, ONE_HOUR, _RESTRICTIONS_EXPANSION_FIELDS, \ - _SLIM_DOC_BATCH_SIZE, CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD -from common.data_source.exceptions import ( - ConnectorMissingCredentialError, - ConnectorValidationError, - InsufficientPermissionsError, - UnexpectedValidationError, CredentialExpiredError +from common.data_source.config import ( + INDEX_BATCH_SIZE, + DocumentSource, + CONTINUE_ON_CONNECTOR_FAILURE, + CONFLUENCE_CONNECTOR_LABELS_TO_SKIP, + CONFLUENCE_TIMEZONE_OFFSET, + CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE, + CONFLUENCE_SYNC_TIME_BUFFER_SECONDS, + OAUTH_CONFLUENCE_CLOUD_CLIENT_ID, + OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET, + _DEFAULT_PAGINATION_LIMIT, + _PROBLEMATIC_EXPANSIONS, + _REPLACEMENT_EXPANSIONS, + _USER_NOT_FOUND, + _COMMENT_EXPANSION_FIELDS, + _ATTACHMENT_EXPANSION_FIELDS, + _PAGE_EXPANSION_FIELDS, + ONE_DAY, + ONE_HOUR, + _RESTRICTIONS_EXPANSION_FIELDS, + _SLIM_DOC_BATCH_SIZE, + CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD, ) +from common.data_source.exceptions import ConnectorMissingCredentialError, ConnectorValidationError, InsufficientPermissionsError, UnexpectedValidationError, CredentialExpiredError from common.data_source.html_utils import format_document_soup from common.data_source.interfaces import ( ConnectorCheckpoint, CredentialsConnector, SecondsSinceUnixEpoch, - SlimConnectorWithPermSync, StaticCredentialsProvider, CheckpointedConnector, SlimConnector, - CredentialsProviderInterface, ConfluenceUser, IndexingHeartbeatInterface, AttachmentProcessingResult, - CheckpointOutput + SlimConnectorWithPermSync, + StaticCredentialsProvider, + CheckpointedConnector, + SlimConnector, + CredentialsProviderInterface, + ConfluenceUser, + IndexingHeartbeatInterface, + AttachmentProcessingResult, + CheckpointOutput, +) +from common.data_source.models import ConnectorFailure, Document, TextSection, ImageSection, BasicExpertInfo, DocumentFailure, GenerateSlimDocumentOutput, SlimDocument, ExternalAccess +from common.data_source.utils import ( + load_all_docs_from_checkpoint_connector, + scoped_url, + process_confluence_user_profiles_override, + confluence_refresh_tokens, + run_with_timeout, + _handle_http_error, + update_param_in_path, + get_start_param_from_url, + build_confluence_document_id, + datetime_from_string, + is_atlassian_date_error, + validate_attachment_filetype, ) -from common.data_source.models import ConnectorFailure, Document, TextSection, ImageSection, BasicExpertInfo, \ - DocumentFailure, GenerateSlimDocumentOutput, SlimDocument, ExternalAccess -from common.data_source.utils import load_all_docs_from_checkpoint_connector, scoped_url, \ - process_confluence_user_profiles_override, confluence_refresh_tokens, run_with_timeout, _handle_http_error, \ - update_param_in_path, get_start_param_from_url, build_confluence_document_id, datetime_from_string, \ - is_atlassian_date_error, validate_attachment_filetype from rag.utils.redis_conn import RedisDB, REDIS_CONN _USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {} _USER_EMAIL_CACHE: dict[str, str | None] = {} -class ConfluenceCheckpoint(ConnectorCheckpoint): +class ConfluenceCheckpoint(ConnectorCheckpoint): next_page_url: str | None @@ -83,9 +108,7 @@ class OnyxConfluence: scoped_token: bool = False, # should generally not be passed in, but making it overridable for # easier testing - confluence_user_profiles_override: list[dict[str, str]] | None = ( - CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE - ), + confluence_user_profiles_override: list[dict[str, str]] | None = (CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE), ) -> None: self.base_url = url #'/'.join(url.rstrip("/").split("/")[:-1]) url = scoped_url(url, "confluence") if scoped_token else url @@ -102,10 +125,7 @@ class OnyxConfluence: self.static_credentials = self._credentials_provider.get_credentials() self._confluence = Confluence(url) - self.credential_key: str = ( - self.CREDENTIAL_PREFIX - + f":credential_{self._credentials_provider.get_provider_key()}" - ) + self.credential_key: str = self.CREDENTIAL_PREFIX + f":credential_{self._credentials_provider.get_provider_key()}" self._kwargs: Any = None @@ -117,11 +137,7 @@ class OnyxConfluence: if timeout: self.shared_base_kwargs["timeout"] = timeout - self._confluence_user_profiles_override = ( - process_confluence_user_profiles_override(confluence_user_profiles_override) - if confluence_user_profiles_override - else None - ) + self._confluence_user_profiles_override = process_confluence_user_profiles_override(confluence_user_profiles_override) if confluence_user_profiles_override else None def _renew_credentials(self) -> tuple[dict[str, Any], bool]: """credential_json - the current json credentials @@ -185,9 +201,7 @@ class OnyxConfluence: # reasonably frequently rather than trying to handle strong synchronization # between the db and redis everywhere the credentials might be updated new_credential_str = json.dumps(new_credentials) - self.redis_client.set( - self.credential_key, new_credential_str, exp=self.CREDENTIAL_TTL - ) + self.redis_client.set(self.credential_key, new_credential_str, exp=self.CREDENTIAL_TTL) self._credentials_provider.set_credentials(new_credentials) return new_credentials, True @@ -198,9 +212,7 @@ class OnyxConfluence: if "confluence_refresh_token" in credentials: oauth2_dict["client_id"] = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID oauth2_dict["token"] = {} - oauth2_dict["token"]["access_token"] = credentials[ - "confluence_access_token" - ] + oauth2_dict["token"]["access_token"] = credentials["confluence_access_token"] return oauth2_dict def _probe_connection( @@ -230,14 +242,10 @@ class OnyxConfluence: r.raise_for_status() except HTTPError as e: if e.response.status_code == 403: - logging.warning( - "scoped token authenticated but not valid for probe endpoint (spaces)" - ) + logging.warning("scoped token authenticated but not valid for probe endpoint (spaces)") else: if "WWW-Authenticate" in e.response.headers: - logging.warning( - f"WWW-Authenticate: {e.response.headers['WWW-Authenticate']}" - ) + logging.warning(f"WWW-Authenticate: {e.response.headers['WWW-Authenticate']}") logging.warning(f"Full error: {e.response.text}") raise e return @@ -246,15 +254,9 @@ class OnyxConfluence: if "confluence_refresh_token" in credentials: logging.info("Probing Confluence with OAuth Access Token.") - oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict( - credentials - ) - url = ( - f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}" - ) - confluence_client_with_minimal_retries = Confluence( - url=url, oauth2=oauth2_dict, **merged_kwargs - ) + oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(credentials) + url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}" + confluence_client_with_minimal_retries = Confluence(url=url, oauth2=oauth2_dict, **merged_kwargs) else: logging.info("Probing Confluence with Personal Access Token.") url = self._url @@ -288,11 +290,7 @@ class OnyxConfluence: # space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space") if not spaces: - raise RuntimeError( - f"No spaces found at {url}! " - "Check your credentials and wiki_base and make sure " - "is_cloud is set correctly." - ) + raise RuntimeError(f"No spaces found at {url}! Check your credentials and wiki_base and make sure is_cloud is set correctly.") logging.info("Confluence probe succeeded.") @@ -304,9 +302,7 @@ class OnyxConfluence: merged_kwargs = {**self.shared_base_kwargs, **kwargs} with self._credentials_provider: credentials, _ = self._renew_credentials() - self._confluence = self._initialize_connection_helper( - credentials, **merged_kwargs - ) + self._confluence = self._initialize_connection_helper(credentials, **merged_kwargs) self._kwargs = merged_kwargs def _initialize_connection_helper( @@ -328,9 +324,7 @@ class OnyxConfluence: url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}" confluence = Confluence(url=url, oauth2=oauth2_dict, **kwargs) else: - logging.info( - f"Connecting to Confluence with Personal Access Token as user: {credentials['confluence_username']}" - ) + logging.info(f"Connecting to Confluence with Personal Access Token as user: {credentials['confluence_username']}") if self._is_cloud: confluence = Confluence( url=self._url, @@ -350,9 +344,7 @@ class OnyxConfluence: # https://developer.atlassian.com/cloud/confluence/rate-limiting/ # This uses the native rate limiting option provided by the # confluence client and otherwise applies a simpler set of error handling. - def _make_rate_limited_confluence_method( - self, name: str, credential_provider: CredentialsProviderInterface | None - ) -> Callable[..., Any]: + def _make_rate_limited_confluence_method(self, name: str, credential_provider: CredentialsProviderInterface | None) -> Callable[..., Any]: def wrapped_call(*args: list[Any], **kwargs: Any) -> Any: MAX_RETRIES = 5 @@ -361,9 +353,7 @@ class OnyxConfluence: for attempt in range(MAX_RETRIES): if time.monotonic() > timeout_at: - raise TimeoutError( - f"Confluence call attempts took longer than {TIMEOUT} seconds." - ) + raise TimeoutError(f"Confluence call attempts took longer than {TIMEOUT} seconds.") # we're relying more on the client to rate limit itself # and applying our own retries in a more specific set of circumstances @@ -372,33 +362,24 @@ class OnyxConfluence: with credential_provider: credentials, renewed = self._renew_credentials() if renewed: - self._confluence = self._initialize_connection_helper( - credentials, **self._kwargs - ) + self._confluence = self._initialize_connection_helper(credentials, **self._kwargs) attr = getattr(self._confluence, name, None) if attr is None: # The underlying Confluence client doesn't have this attribute - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{name}'" - ) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") return attr(*args, **kwargs) else: attr = getattr(self._confluence, name, None) if attr is None: # The underlying Confluence client doesn't have this attribute - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{name}'" - ) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") return attr(*args, **kwargs) except HTTPError as e: delay_until = _handle_http_error(e, attempt) - logging.warning( - f"HTTPError in confluence call. " - f"Retrying in {delay_until} seconds..." - ) + logging.warning(f"HTTPError in confluence call. Retrying in {delay_until} seconds...") while time.monotonic() < delay_until: # in the future, check a signal here to exit time.sleep(1) @@ -408,9 +389,7 @@ class OnyxConfluence: if attempt == MAX_RETRIES - 1: raise e - logging.exception( - "Confluence Client raised an AttributeError. Retrying..." - ) + logging.exception("Confluence Client raised an AttributeError. Retrying...") time.sleep(5) return wrapped_call @@ -420,9 +399,7 @@ class OnyxConfluence: attr = getattr(self._confluence, name, None) if attr is None: # The underlying Confluence client doesn't have this attribute - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{name}'" - ) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") # If it's not a method, just return it after ensuring token validity if not callable(attr): @@ -433,9 +410,7 @@ class OnyxConfluence: return attr # wrap the method with our retry handler - rate_limited_method: Callable[..., Any] = ( - self._make_rate_limited_confluence_method(name, self._credentials_provider) - ) + rate_limited_method: Callable[..., Any] = self._make_rate_limited_confluence_method(name, self._credentials_provider) return rate_limited_method @@ -465,9 +440,7 @@ class OnyxConfluence: for ind in range(limit): try: - temp_url_suffix = update_param_in_path( - url_suffix, "start", str(initial_start + ind) - ) + temp_url_suffix = update_param_in_path(url_suffix, "start", str(initial_start + ind)) temp_url_suffix = update_param_in_path(temp_url_suffix, "limit", "1") logging.info(f"Making recovery confluence call to {temp_url_suffix}") raw_response = self.get(path=temp_url_suffix, advanced_mode=True) @@ -478,16 +451,11 @@ class OnyxConfluence: if not latest_results: # no more results, break out of the loop - logging.info( - f"No results found for call '{temp_url_suffix}'" - "Stopping pagination." - ) + logging.info(f"No results found for call '{temp_url_suffix}'Stopping pagination.") found_empty_page = True break except Exception: - logging.exception( - f"Error in confluence call to {temp_url_suffix}. Continuing." - ) + logging.exception(f"Error in confluence call to {temp_url_suffix}. Continuing.") if found_empty_page: return None @@ -535,10 +503,7 @@ class OnyxConfluence: # with the replacement expansion and try again # If that fails, raise the error if _PROBLEMATIC_EXPANSIONS in url_suffix: - logging.warning( - f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}" - " and trying again." - ) + logging.warning(f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS} and trying again.") url_suffix = url_suffix.replace( _PROBLEMATIC_EXPANSIONS, _REPLACEMENT_EXPANSIONS, @@ -571,20 +536,13 @@ class OnyxConfluence: continue else: - logging.exception( - f"Error in confluence call to {url_suffix} \n" - f"Raw Response Text: {raw_response.text} \n" - f"Full Response: {raw_response.__dict__} \n" - f"Error: {e} \n" - ) + logging.exception(f"Error in confluence call to {url_suffix} \nRaw Response Text: {raw_response.text} \nFull Response: {raw_response.__dict__} \nError: {e} \n") raise try: next_response = raw_response.json() except Exception as e: - logging.exception( - f"Failed to parse response as JSON. Response: {raw_response.__dict__}" - ) + logging.exception(f"Failed to parse response as JSON. Response: {raw_response.__dict__}") raise e # Yield the results individually. @@ -623,16 +581,12 @@ class OnyxConfluence: if not self._is_cloud: # If confluence claims there are more results, we update the start param # based on how many results were returned and try again. - url_suffix = update_param_in_path( - url_suffix, "start", str(updated_start) - ) + url_suffix = update_param_in_path(url_suffix, "start", str(updated_start)) # notify the caller of the new url next_page_callback(url_suffix) elif force_offset_pagination and i == len(results) - 1: - url_suffix = update_param_in_path( - old_url_suffix, "start", str(updated_start) - ) + url_suffix = update_param_in_path(old_url_suffix, "start", str(updated_start)) yield result @@ -640,10 +594,7 @@ class OnyxConfluence: # 0 results. This is a bug with Confluence, so we need to check for it and # stop paginating. if url_suffix and not results: - logging.info( - f"No results found for call '{old_url_suffix}' despite next link " - "being present. Stopping pagination." - ) + logging.info(f"No results found for call '{old_url_suffix}' despite next link being present. Stopping pagination.") break def build_cql_url(self, cql: str, expand: str | None = None) -> str: @@ -675,9 +626,7 @@ class OnyxConfluence: next page links manually. """ try: - yield from self._paginate_url( - cql_url, limit=limit, next_page_callback=next_page_callback - ) + yield from self._paginate_url(cql_url, limit=limit, next_page_callback=next_page_callback) except Exception as e: logging.exception(f"Error in paginated_page_retrieval: {e}") raise e @@ -731,9 +680,7 @@ class OnyxConfluence: url = "rest/api/search/user" expand_string = f"&expand={expand}" if expand else "" url += f"?cql={cql}{expand_string}" - for user_result in self._paginate_url( - url, limit, force_offset_pagination=True - ): + for user_result in self._paginate_url(url, limit, force_offset_pagination=True): user = user_result["user"] yield ConfluenceUser( user_id=user["accountId"], @@ -816,10 +763,7 @@ class OnyxConfluence: response = self.post(url, data=data) logging.debug(f"jsonrpc response: {response}") if not response.get("result"): - logging.warning( - f"No jsonrpc response for space permissions for space {space_key}" - f"\nResponse: {response}" - ) + logging.warning(f"No jsonrpc response for space permissions for space {space_key}\nResponse: {response}") return response.get("result", []) @@ -843,16 +787,12 @@ class OnyxConfluence: response = self.get(url, params=params) except HTTPError as e: if e.response.status_code == 403: - raise ApiPermissionError( - "The calling user does not have permission", reason=e - ) + raise ApiPermissionError("The calling user does not have permission", reason=e) raise return response -def get_user_email_from_username__server( - confluence_client: OnyxConfluence, user_name: str -) -> str | None: +def get_user_email_from_username__server(confluence_client: OnyxConfluence, user_name: str) -> str | None: global _USER_EMAIL_CACHE if _USER_EMAIL_CACHE.get(user_name) is None: try: @@ -930,15 +870,9 @@ def extract_text_from_confluence_html( _remove_macro_stylings(soup=soup) for user in soup.findAll("ri:user"): - user_id = ( - user.attrs["ri:account-id"] - if "ri:account-id" in user.attrs - else user.get("ri:userkey") - ) + user_id = user.attrs["ri:account-id"] if "ri:account-id" in user.attrs else user.get("ri:userkey") if not user_id: - logging.warning( - "ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}" - ) + logging.warning(f"ri:userkey not found in ri:user element. Found attrs: {user.attrs}") continue # Include @ sign for tagging, more clear for LLM user.replaceWith("@" + _get_user(confluence_client, user_id)) @@ -950,17 +884,13 @@ def extract_text_from_confluence_html( page_data = html_page_reference.find("ri:page") if not page_data: - logging.warning( - f"Skipping retrieval of {html_page_reference} because because page data is missing" - ) + logging.warning(f"Skipping retrieval of {html_page_reference} because because page data is missing") continue page_title = page_data.attrs.get("ri:content-title") if not page_title: # only fetch pages that have a title - logging.warning( - f"Skipping retrieval of {html_page_reference} because it has no title" - ) + logging.warning(f"Skipping retrieval of {html_page_reference} because it has no title") continue if page_title in fetched_titles: @@ -984,9 +914,7 @@ def extract_text_from_confluence_html( page_contents = page break except Exception as e: - logging.warning( - f"Error getting page contents for object {confluence_object}: {e}" - ) + logging.warning(f"Error getting page contents for object {confluence_object}: {e}") continue if not page_contents: @@ -1013,9 +941,7 @@ def extract_text_from_confluence_html( # This extracts the text from inline attachments in the page so they can be # represented in the document text as plain text try: - html_attachment.replaceWith( - f"{sanitize_attachment_title(html_attachment.attrs['ri:filename'])}" - ) # to be replaced later + html_attachment.replaceWith(f"{sanitize_attachment_title(html_attachment.attrs['ri:filename'])}") # to be replaced later except Exception as e: logging.warning(f"Error processing ac:attachment: {e}") @@ -1111,20 +1037,16 @@ def _make_attachment_link( download_link = "" from urllib.parse import urlparse - netloc =urlparse(confluence_client.url).hostname + + netloc = urlparse(confluence_client.url).hostname if netloc == "api.atlassian.com" or (netloc and netloc.endswith(".api.atlassian.com")): - # if "api.atlassian.com" in confluence_client.url: + # if "api.atlassian.com" in confluence_client.url: # https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get if not parent_content_id: - logging.warning( - "parent_content_id is required to download attachments from Confluence Cloud!" - ) + logging.warning("parent_content_id is required to download attachments from Confluence Cloud!") return None - download_link = ( - confluence_client.url - + f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download" - ) + download_link = confluence_client.url + f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download" else: download_link = confluence_client.url + attachment["_links"]["download"] @@ -1163,13 +1085,9 @@ def process_attachment( error=f"Unsupported file type: {media_type}", ) - attachment_link = _make_attachment_link( - confluence_client, attachment, parent_content_id - ) + attachment_link = _make_attachment_link(confluence_client, attachment, parent_content_id) if not attachment_link: - return AttachmentProcessingResult( - text=None, file_blob=None, file_name=None, error="Failed to make attachment link" - ) + return AttachmentProcessingResult(text=None, file_blob=None, file_name=None, error="Failed to make attachment link") attachment_size = attachment["extensions"]["fileSize"] @@ -1183,11 +1101,7 @@ def process_attachment( ) else: if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD: - logging.warning( - f"Skipping {attachment_link} due to size. " - f"size={attachment_size} " - f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}" - ) + logging.warning(f"Skipping {attachment_link} due to size. size={attachment_size} threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}") return AttachmentProcessingResult( text=None, file_blob=None, @@ -1195,19 +1109,12 @@ def process_attachment( error=f"Attachment text too long: {attachment_size} chars", ) - logging.info( - f"Downloading attachment: " - f"title={attachment['title']} " - f"length={attachment_size} " - f"link={attachment_link}" - ) + logging.info(f"Downloading attachment: title={attachment['title']} length={attachment_size} link={attachment_link}") # Download the attachment resp: requests.Response = confluence_client._session.get(attachment_link) if resp.status_code != 200: - logging.warning( - f"Failed to fetch {attachment_link} with status code {resp.status_code}" - ) + logging.warning(f"Failed to fetch {attachment_link} with status code {resp.status_code}") return AttachmentProcessingResult( text=None, file_blob=None, @@ -1217,29 +1124,21 @@ def process_attachment( raw_bytes = resp.content if not raw_bytes: - return AttachmentProcessingResult( - text=None, file_blob=None, file_name=None, error="attachment.content is None" - ) + return AttachmentProcessingResult(text=None, file_blob=None, file_name=None, error="attachment.content is None") # Process image attachments if media_type.startswith("image/"): - return _process_image_attachment( - confluence_client, attachment, raw_bytes, media_type - ) + return _process_image_attachment(confluence_client, attachment, raw_bytes, media_type) # Process document attachments try: - return AttachmentProcessingResult(text="",file_blob=raw_bytes, file_name=attachment.get("title", "unknown_title"), error=None) + return AttachmentProcessingResult(text="", file_blob=raw_bytes, file_name=attachment.get("title", "unknown_title"), error=None) except Exception as e: logging.exception(e) - return AttachmentProcessingResult( - text=None, file_blob=None, file_name=None, error=f"Failed to extract text: {e}" - ) + return AttachmentProcessingResult(text=None, file_blob=None, file_name=None, error=f"Failed to extract text: {e}") except Exception as e: - return AttachmentProcessingResult( - text=None, file_blob=None, file_name=None, error=f"Failed to process attachment: {e}" - ) + return AttachmentProcessingResult(text=None, file_blob=None, file_name=None, error=f"Failed to process attachment: {e}") def convert_attachment_to_content( @@ -1257,16 +1156,12 @@ def convert_attachment_to_content( media_type = attachment.get("metadata", {}).get("mediaType", "") # Quick check for unsupported types: if media_type.startswith("video/") or media_type == "application/gliffy+json": - logging.warning( - f"Skipping unsupported attachment type: '{media_type}' for {attachment['title']}" - ) + logging.warning(f"Skipping unsupported attachment type: '{media_type}' for {attachment['title']}") return None result = process_attachment(confluence_client, attachment, page_id, allow_images) if result.error is not None: - logging.warning( - f"Attachment {attachment['title']} encountered error: {result.error}" - ) + logging.warning(f"Attachment {attachment['title']} encountered error: {result.error}") return None return result.file_name, result.file_blob @@ -1340,9 +1235,7 @@ class ConfluenceConnector( self.cql_label_filter = "" if labels_to_skip: labels_to_skip = list(set(labels_to_skip)) - comma_separated_labels = ",".join( - f"'{quote(label)}'" for label in labels_to_skip - ) + comma_separated_labels = ",".join(f"'{quote(label)}'" for label in labels_to_skip) self.cql_label_filter = f" and label not in ({comma_separated_labels})" self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset)) @@ -1365,18 +1258,14 @@ class ConfluenceConnector( logging.info(f"Setting allow_images to {value}.") self.allow_images = value - def _adjust_start_for_query( - self, start: SecondsSinceUnixEpoch | None - ) -> SecondsSinceUnixEpoch | None: + def _adjust_start_for_query(self, start: SecondsSinceUnixEpoch | None) -> SecondsSinceUnixEpoch | None: if not start or start <= 0: return start if self.time_buffer_seconds <= 0: return start return max(0.0, start - self.time_buffer_seconds) - def _is_newer_than_start( - self, doc_time: datetime | None, start: SecondsSinceUnixEpoch | None - ) -> bool: + def _is_newer_than_start(self, doc_time: datetime | None, start: SecondsSinceUnixEpoch | None) -> bool: if not start or start <= 0: return True if doc_time is None: @@ -1395,9 +1284,7 @@ class ConfluenceConnector( raise ConnectorMissingCredentialError("Confluence") return self._low_timeout_confluence_client - def set_credentials_provider( - self, credentials_provider: CredentialsProviderInterface - ) -> None: + def set_credentials_provider(self, credentials_provider: CredentialsProviderInterface) -> None: self.credentials_provider = credentials_provider # raises exception if there's a problem @@ -1443,14 +1330,10 @@ class ConfluenceConnector( # Add time filters query_start = self._adjust_start_for_query(start) if query_start: - formatted_start_time = datetime.fromtimestamp( - query_start, tz=self.timezone - ).strftime("%Y-%m-%d %H:%M") + formatted_start_time = datetime.fromtimestamp(query_start, tz=self.timezone).strftime("%Y-%m-%d %H:%M") page_query += f" and lastmodified >= '{formatted_start_time}'" if end: - formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime( - "%Y-%m-%d %H:%M" - ) + formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime("%Y-%m-%d %H:%M") page_query += f" and lastmodified <= '{formatted_end_time}'" page_query += " order by lastmodified asc" @@ -1468,14 +1351,10 @@ class ConfluenceConnector( # Add time filters to avoid reprocessing unchanged attachments during refresh query_start = self._adjust_start_for_query(start) if query_start: - formatted_start_time = datetime.fromtimestamp( - query_start, tz=self.timezone - ).strftime("%Y-%m-%d %H:%M") + formatted_start_time = datetime.fromtimestamp(query_start, tz=self.timezone).strftime("%Y-%m-%d %H:%M") attachment_query += f" and lastmodified >= '{formatted_start_time}'" if end: - formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime( - "%Y-%m-%d %H:%M" - ) + formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime("%Y-%m-%d %H:%M") attachment_query += f" and lastmodified <= '{formatted_end_time}'" attachment_query += " order by lastmodified asc" @@ -1499,9 +1378,7 @@ class ConfluenceConnector( ) return comment_string - def _convert_page_to_document( - self, page: dict[str, Any] - ) -> Document | ConnectorFailure: + def _convert_page_to_document(self, page: dict[str, Any]) -> Document | ConnectorFailure: """ Converts a Confluence page to a Document object. Includes the page content, comments, and attachments. @@ -1512,38 +1389,36 @@ class ConfluenceConnector( page_id = page["id"] page_title = page["title"] logging.info(f"Converting page {page_title} to document") - page_url = build_confluence_document_id( - self.wiki_base, page["_links"]["webui"], self.is_cloud - ) + page_url = build_confluence_document_id(self.wiki_base, page["_links"]["webui"], self.is_cloud) # Build hierarchical path for semantic identifier space_name = page.get("space", {}).get("name", "") - + # Build path from ancestors path_parts = [] if space_name: path_parts.append(space_name) - + # Add ancestor pages to path if available if "ancestors" in page and page["ancestors"]: for ancestor in page["ancestors"]: ancestor_title = ancestor.get("title", "") if ancestor_title: path_parts.append(ancestor_title) - + # Add current page title path_parts.append(page_title) - + # Track page names for duplicate detection full_path = " / ".join(path_parts) if len(path_parts) > 1 else page_title - + # Count occurrences of this page title if page_title not in self._document_name_counts: self._document_name_counts[page_title] = 0 self._document_name_paths[page_title] = [] self._document_name_counts[page_title] += 1 self._document_name_paths[page_title].append(full_path) - + # Use simple name if no duplicates, otherwise use full path if self._document_name_counts[page_title] == 1: semantic_identifier = page_title @@ -1551,21 +1426,15 @@ class ConfluenceConnector( semantic_identifier = full_path # Get the page content - page_content = extract_text_from_confluence_html( - self.confluence_client, page, self._fetched_titles - ) + page_content = extract_text_from_confluence_html(self.confluence_client, page, self._fetched_titles) # Create the main section for the page content - sections: list[TextSection | ImageSection] = [ - TextSection(text=page_content, link=page_url) - ] + sections: list[TextSection | ImageSection] = [TextSection(text=page_content, link=page_url)] # Process comments if available comment_text = self._get_comment_string_for_page_id(page_id) if comment_text: - sections.append( - TextSection(text=comment_text, link=f"{page_url}#comments") - ) + sections.append(TextSection(text=comment_text, link=f"{page_url}#comments")) # Note: attachments are no longer merged into the page document. # They are indexed as separate documents downstream. @@ -1588,9 +1457,7 @@ class ConfluenceConnector( author = page["version"]["by"] display_name = author.get("displayName", "Unknown") email = author.get("email", "unknown@domain.invalid") - primary_owners.append( - BasicExpertInfo(display_name=display_name, email=email) - ) + primary_owners.append(BasicExpertInfo(display_name=display_name, email=email)) # Create the document return Document( @@ -1643,32 +1510,21 @@ class ConfluenceConnector( # but doing the check here avoids an unnecessary download. Due for refactoring. if not self.allow_images: if media_type.startswith("image/"): - logging.info( - f"Skipping attachment because allow images is False: {attachment['title']}" - ) + logging.info(f"Skipping attachment because allow images is False: {attachment['title']}") continue if not validate_attachment_filetype( attachment, ): - logging.info( - f"Skipping attachment because it is not an accepted file type: {attachment['title']}" - ) + logging.info(f"Skipping attachment because it is not an accepted file type: {attachment['title']}") continue - - logging.info( - f"Processing attachment: {attachment['title']} attached to page {page['title']}" - ) + logging.info(f"Processing attachment: {attachment['title']} attached to page {page['title']}") # Attachment document id: use the download URL for stable identity try: - object_url = build_confluence_document_id( - self.wiki_base, attachment["_links"]["download"], self.is_cloud - ) + object_url = build_confluence_document_id(self.wiki_base, attachment["_links"]["download"], self.is_cloud) except Exception as e: - logging.warning( - f"Invalid attachment url for id {attachment['id']}, skipping" - ) + logging.warning(f"Invalid attachment url for id {attachment['id']}, skipping") logging.debug(f"Error building attachment url: {e}") continue try: @@ -1697,19 +1553,15 @@ class ConfluenceConnector( labels.append(label.get("name", "")) if labels: attachment_metadata["labels"] = labels - page_url = page_url or build_confluence_document_id( - self.wiki_base, page["_links"]["webui"], self.is_cloud - ) + page_url = page_url or build_confluence_document_id(self.wiki_base, page["_links"]["webui"], self.is_cloud) attachment_metadata["parent_page_id"] = page_url - attachment_id = build_confluence_document_id( - self.wiki_base, attachment["_links"]["webui"], self.is_cloud - ) + attachment_id = build_confluence_document_id(self.wiki_base, attachment["_links"]["webui"], self.is_cloud) # Build semantic identifier with space and page context attachment_title = attachment.get("title", object_url) space_name = page.get("space", {}).get("name", "") page_title = page.get("title", "") - + # Create hierarchical name: Space / Page / Attachment attachment_path_parts = [] if space_name: @@ -1717,16 +1569,16 @@ class ConfluenceConnector( if page_title: attachment_path_parts.append(page_title) attachment_path_parts.append(attachment_title) - + full_attachment_path = " / ".join(attachment_path_parts) if len(attachment_path_parts) > 1 else attachment_title - + # Track attachment names for duplicate detection if attachment_title not in self._document_name_counts: self._document_name_counts[attachment_title] = 0 self._document_name_paths[attachment_title] = [] self._document_name_counts[attachment_title] += 1 self._document_name_paths[attachment_title].append(full_attachment_path) - + # Use simple name if no duplicates, otherwise use full path if self._document_name_counts[attachment_title] == 1: attachment_semantic_identifier = attachment_title @@ -1738,13 +1590,10 @@ class ConfluenceConnector( author = attachment["version"]["by"] display_name = author.get("displayName", "Unknown") email = author.get("email", "unknown@domain.invalid") - primary_owners = [ - BasicExpertInfo(display_name=display_name, email=email) - ] + primary_owners = [BasicExpertInfo(display_name=display_name, email=email)] extension = Path(attachment.get("title", "")).suffix or ".unknown" - attachment_doc = Document( id=attachment_id, # sections=sections, @@ -1754,12 +1603,7 @@ class ConfluenceConnector( blob=file_blob, size_bytes=len(file_blob), metadata=attachment_metadata, - doc_updated_at=( - datetime_from_string(attachment["version"]["when"]) - if attachment.get("version") - and attachment["version"].get("when") - else None - ), + doc_updated_at=(datetime_from_string(attachment["version"]["when"]) if attachment.get("version") and attachment["version"].get("when") else None), primary_owners=primary_owners, ) if self._is_newer_than_start(attachment_doc.doc_updated_at, start): @@ -1802,11 +1646,9 @@ class ConfluenceConnector( # use "start" when last_updated is 0 or for confluence server start_ts = start - page_query_url = checkpoint.next_page_url or self._build_page_retrieval_url( - start_ts, end, self.batch_size - ) + page_query_url = checkpoint.next_page_url or self._build_page_retrieval_url(start_ts, end, self.batch_size) logging.debug(f"page_query_url: {page_query_url}") - + # store the next page start for confluence server, cursor for confluence cloud def store_next_page_url(next_page_url: str) -> None: checkpoint.next_page_url = next_page_url @@ -1828,9 +1670,7 @@ class ConfluenceConnector( yield doc_or_failure # Now get attachments for that page: - attachment_docs, attachment_failures = self._fetch_page_attachments( - page, start, end - ) + attachment_docs, attachment_failures = self._fetch_page_attachments(page, start, end) # yield attached docs and failures yield from attachment_docs # yield from attachment_failures @@ -1854,9 +1694,7 @@ class ConfluenceConnector( or paginated_page_retrieval methods. """ page_query = self._construct_page_cql_query(start, end) - cql_url = self.confluence_client.build_cql_url( - page_query, expand=",".join(_PAGE_EXPANSION_FIELDS) - ) + cql_url = self.confluence_client.build_cql_url(page_query, expand=",".join(_PAGE_EXPANSION_FIELDS)) logging.info(f"[Confluence Connector] Building CQL URL {cql_url}") return update_param_in_path(cql_url, "limit", str(limit)) @@ -1925,16 +1763,10 @@ class ConfluenceConnector( space_level_access_info: dict[str, ExternalAccess] = {} if include_permissions: - space_level_access_info = get_all_space_permissions( - self.confluence_client, self.is_cloud - ) + space_level_access_info = get_all_space_permissions(self.confluence_client, self.is_cloud) - def get_external_access( - doc_id: str, restrictions: dict[str, Any], ancestors: list[dict[str, Any]] - ) -> ExternalAccess | None: - return get_page_restrictions( - self.confluence_client, doc_id, restrictions, ancestors - ) or space_level_access_info.get(page_space_key) + def get_external_access(doc_id: str, restrictions: dict[str, Any], ancestors: list[dict[str, Any]]) -> ExternalAccess | None: + return get_page_restrictions(self.confluence_client, doc_id, restrictions, ancestors) or space_level_access_info.get(page_space_key) # Query pages page_query = self.base_cql_page_query + self.cql_label_filter @@ -1948,17 +1780,11 @@ class ConfluenceConnector( page_space_key = page.get("space", {}).get("key") page_ancestors = page.get("ancestors", []) - page_id = build_confluence_document_id( - self.wiki_base, page["_links"]["webui"], self.is_cloud - ) + page_id = build_confluence_document_id(self.wiki_base, page["_links"]["webui"], self.is_cloud) doc_metadata_list.append( SlimDocument( id=page_id, - external_access=( - get_external_access(page_id, page_restrictions, page_ancestors) - if include_permissions - else None - ), + external_access=(get_external_access(page_id, page_restrictions, page_ancestors) if include_permissions else None), ) ) @@ -1992,13 +1818,7 @@ class ConfluenceConnector( doc_metadata_list.append( SlimDocument( id=attachment_id, - external_access=( - get_external_access( - attachment_id, attachment_restrictions, [] - ) - if include_permissions - else None - ), + external_access=(get_external_access(attachment_id, attachment_restrictions, []) if include_permissions else None), ) ) @@ -2007,9 +1827,7 @@ class ConfluenceConnector( doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:] if callback and callback.should_stop(): - raise RuntimeError( - "retrieve_all_slim_docs_perm_sync: Stop signal detected" - ) + raise RuntimeError("retrieve_all_slim_docs_perm_sync: Stop signal detected") if callback: callback.progress("retrieve_all_slim_docs_perm_sync", 1) @@ -2021,35 +1839,21 @@ class ConfluenceConnector( except HTTPError as e: status_code = e.response.status_code if e.response else None if status_code == 401: - raise CredentialExpiredError( - "Invalid or expired Confluence credentials (HTTP 401)." - ) + raise CredentialExpiredError("Invalid or expired Confluence credentials (HTTP 401).") elif status_code == 403: - raise InsufficientPermissionsError( - "Insufficient permissions to access Confluence resources (HTTP 403)." - ) - raise UnexpectedValidationError( - f"Unexpected Confluence error (status={status_code}): {e}" - ) + raise InsufficientPermissionsError("Insufficient permissions to access Confluence resources (HTTP 403).") + raise UnexpectedValidationError(f"Unexpected Confluence error (status={status_code}): {e}") except Exception as e: - raise UnexpectedValidationError( - f"Unexpected error while validating Confluence settings: {e}" - ) + raise UnexpectedValidationError(f"Unexpected error while validating Confluence settings: {e}") if self.space: try: self.low_timeout_confluence_client.get_space(self.space) except ApiError as e: - raise ConnectorValidationError( - "Invalid Confluence space key provided" - ) from e + raise ConnectorValidationError("Invalid Confluence space key provided") from e if not spaces or not spaces.get("results"): - raise ConnectorValidationError( - "No Confluence spaces found. Either your credentials lack permissions, or " - "there truly are no spaces in this Confluence instance." - ) - + raise ConnectorValidationError("No Confluence spaces found. Either your credentials lack permissions, or there truly are no spaces in this Confluence instance.") if __name__ == "__main__": diff --git a/common/data_source/connector_runner.py b/common/data_source/connector_runner.py index d47d651284..ff77f23122 100644 --- a/common/data_source/connector_runner.py +++ b/common/data_source/connector_runner.py @@ -26,14 +26,10 @@ def batched_doc_ids( batch_size: int, ) -> Generator[set[str], None, None]: batch: set[str] = set() - for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()( - checkpoint_connector_generator - ): + for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(checkpoint_connector_generator): if document is not None: batch.add(document.id) - elif ( - failure and failure.failed_document and failure.failed_document.document_id - ): + elif failure and failure.failed_document and failure.failed_document.document_id: batch.add(failure.failed_document.document_id) if len(batch) >= batch_size: @@ -76,14 +72,10 @@ class CheckpointOutputWrapper(Generic[CT]): elif isinstance(document_or_failure, ConnectorFailure): yield None, document_or_failure, None else: - raise ValueError( - f"Invalid document_or_failure type: {type(document_or_failure)}" - ) + raise ValueError(f"Invalid document_or_failure type: {type(document_or_failure)}") if self.next_checkpoint is None: - raise RuntimeError( - "Checkpoint is None. This should never happen - the connector should always return a checkpoint." - ) + raise RuntimeError("Checkpoint is None. This should never happen - the connector should always return a checkpoint.") yield None, None, self.next_checkpoint @@ -105,9 +97,7 @@ class ConnectorRunner(Generic[CT]): time_range: TimeRange | None = None, ): if not isinstance(connector, CheckpointedConnector) and include_permissions: - raise ValueError( - "include_permissions cannot be True for non-checkpointed connectors" - ) + raise ValueError("include_permissions cannot be True for non-checkpointed connectors") self.connector = connector self.time_range = time_range @@ -116,7 +106,9 @@ class ConnectorRunner(Generic[CT]): self.doc_batch: list[Document] = [] - def run(self, checkpoint: CT) -> Generator[ + def run( + self, checkpoint: CT + ) -> Generator[ tuple[list[Document] | None, ConnectorFailure | None, CT | None], None, None, @@ -129,15 +121,9 @@ class ConnectorRunner(Generic[CT]): start = time.monotonic() if self.include_permissions: - if not isinstance( - self.connector, CheckpointedConnectorWithPermSync - ): - raise ValueError( - "Connector does not support permission syncing" - ) - load_from_checkpoint = ( - self.connector.load_from_checkpoint_with_perm_sync - ) + if not isinstance(self.connector, CheckpointedConnectorWithPermSync): + raise ValueError("Connector does not support permission syncing") + load_from_checkpoint = self.connector.load_from_checkpoint_with_perm_sync else: load_from_checkpoint = self.connector.load_from_checkpoint checkpoint_connector_generator = load_from_checkpoint( @@ -147,9 +133,7 @@ class ConnectorRunner(Generic[CT]): ) next_checkpoint: CT | None = None # this is guaranteed to always run at least once with next_checkpoint being non-None - for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()( - checkpoint_connector_generator - ): + for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(checkpoint_connector_generator): if document is not None and isinstance(document, Document): self.doc_batch.append(document) @@ -167,9 +151,7 @@ class ConnectorRunner(Generic[CT]): yield None, None, next_checkpoint - logging.debug( - f"Connector took {time.monotonic() - start} seconds to get to the next checkpoint." - ) + logging.debug(f"Connector took {time.monotonic() - start} seconds to get to the next checkpoint.") else: finished_checkpoint = self.connector.build_dummy_checkpoint() @@ -207,11 +189,6 @@ class ConnectorRunner(Generic[CT]): # Get the local variables from the frame where the exception occurred local_vars = tb.tb_frame.f_locals - local_vars_str = "\n".join( - f"{key}: {value}" for key, value in local_vars.items() - ) - logging.error( - f"Error in connector. type: {exc_type};\n" - f"local_vars below -> \n{local_vars_str[:1024]}" - ) - raise \ No newline at end of file + local_vars_str = "\n".join(f"{key}: {value}" for key, value in local_vars.items()) + logging.error(f"Error in connector. type: {exc_type};\nlocal_vars below -> \n{local_vars_str[:1024]}") + raise diff --git a/common/data_source/cross_connector_utils/rate_limit_wrapper.py b/common/data_source/cross_connector_utils/rate_limit_wrapper.py index bc0e0b470d..c9a7fcce2a 100644 --- a/common/data_source/cross_connector_utils/rate_limit_wrapper.py +++ b/common/data_source/cross_connector_utils/rate_limit_wrapper.py @@ -52,16 +52,11 @@ class _RateLimitDecorator: sleep_cnt = 0 while len(self.call_history) == self.max_calls: sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt) - logging.warning( - f"Rate limit exceeded for function {func.__name__}. " - f"Waiting {sleep_time} seconds before retrying." - ) + logging.warning(f"Rate limit exceeded for function {func.__name__}. Waiting {sleep_time} seconds before retrying.") time.sleep(sleep_time) sleep_cnt += 1 if self.max_num_sleep != 0 and sleep_cnt >= self.max_num_sleep: - raise RateLimitTriedTooManyTimesError( - f"Exceeded '{self.max_num_sleep}' retries for function '{func.__name__}'" - ) + raise RateLimitTriedTooManyTimesError(f"Exceeded '{self.max_num_sleep}' retries for function '{func.__name__}'") self._cleanup() @@ -74,11 +69,7 @@ class _RateLimitDecorator: def _cleanup(self) -> None: curr_time = time.monotonic() time_to_expire_before = curr_time - self.period - self.call_history = [ - call_time - for call_time in self.call_history - if call_time > time_to_expire_before - ] + self.call_history = [call_time for call_time in self.call_history if call_time > time_to_expire_before] rate_limit_builder = _RateLimitDecorator @@ -90,17 +81,13 @@ use the following instead""" R = TypeVar("R", bound=Callable[..., requests.Response]) -def wrap_request_to_handle_ratelimiting( - request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30 -) -> R: +def wrap_request_to_handle_ratelimiting(request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30) -> R: def wrapped_request(*args: list, **kwargs: dict[str, Any]) -> requests.Response: for _ in range(max_waits): response = request_fn(*args, **kwargs) if response.status_code == 429: try: - wait_time = int( - response.headers.get("Retry-After", default_wait_time_sec) - ) + wait_time = int(response.headers.get("Retry-After", default_wait_time_sec)) except ValueError: wait_time = default_wait_time_sec @@ -123,4 +110,4 @@ class _RateLimitedRequest: post = _rate_limited_post -rl_requests = _RateLimitedRequest \ No newline at end of file +rl_requests = _RateLimitedRequest diff --git a/common/data_source/cross_connector_utils/retry_wrapper.py b/common/data_source/cross_connector_utils/retry_wrapper.py index a055847975..9f4a25604f 100644 --- a/common/data_source/cross_connector_utils/retry_wrapper.py +++ b/common/data_source/cross_connector_utils/retry_wrapper.py @@ -13,6 +13,7 @@ from common.data_source.config import REQUEST_TIMEOUT_SECONDS F = TypeVar("F", bound=Callable[..., Any]) logger = logging.getLogger(__name__) + def retry_builder( tries: int = 20, delay: float = 0.1, @@ -85,4 +86,4 @@ def request_with_retries( raise return response - return _make_request() \ No newline at end of file + return _make_request() diff --git a/common/data_source/discord_connector.py b/common/data_source/discord_connector.py index e047148f33..02deddc746 100644 --- a/common/data_source/discord_connector.py +++ b/common/data_source/discord_connector.py @@ -281,12 +281,12 @@ class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): id = doc_batch[0].id min_updated_at = doc_batch[0].doc_updated_at max_updated_at = doc_batch[-1].doc_updated_at - blob = b'' + blob = b"" size_bytes = 0 for d in doc_batch: min_updated_at = min(min_updated_at, d.doc_updated_at) max_updated_at = max(max_updated_at, d.doc_updated_at) - blob += b'\n\n' + d.blob + blob += b"\n\n" + d.blob size_bytes += d.size_bytes return Document( diff --git a/common/data_source/exceptions.py b/common/data_source/exceptions.py index eeb60132bd..330deb9b8b 100644 --- a/common/data_source/exceptions.py +++ b/common/data_source/exceptions.py @@ -3,28 +3,34 @@ class ConnectorMissingCredentialError(Exception): """Missing credentials exception""" + def __init__(self, connector_name: str): super().__init__(f"Missing credentials for {connector_name}") class ConnectorValidationError(Exception): """Connector validation exception""" + pass class CredentialExpiredError(Exception): """Credential expired exception""" + pass class InsufficientPermissionsError(Exception): """Insufficient permissions exception""" + pass class UnexpectedValidationError(Exception): """Unexpected validation exception""" + pass + class RateLimitTriedTooManyTimesError(Exception): - pass \ No newline at end of file + pass diff --git a/common/data_source/file_types.py b/common/data_source/file_types.py index be4d56d7b5..755712b799 100644 --- a/common/data_source/file_types.py +++ b/common/data_source/file_types.py @@ -1,13 +1,7 @@ -PRESENTATION_MIME_TYPE = ( - "application/vnd.openxmlformats-officedocument.presentationml.presentation" -) +PRESENTATION_MIME_TYPE = "application/vnd.openxmlformats-officedocument.presentationml.presentation" -SPREADSHEET_MIME_TYPE = ( - "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" -) -WORD_PROCESSING_MIME_TYPE = ( - "application/vnd.openxmlformats-officedocument.wordprocessingml.document" -) +SPREADSHEET_MIME_TYPE = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" +WORD_PROCESSING_MIME_TYPE = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" PDF_MIME_TYPE = "application/pdf" @@ -35,6 +29,4 @@ class UploadMimeTypes: "application/epub+zip", } - ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES.union( - TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, CSV_MIME_TYPES - ) + ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES.union(TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, CSV_MIME_TYPES) diff --git a/common/data_source/github/connector.py b/common/data_source/github/connector.py index 2d65c995e6..1e9be17d1b 100644 --- a/common/data_source/github/connector.py +++ b/common/data_source/github/connector.py @@ -87,15 +87,11 @@ def get_nextUrl_key(pag_list: PaginatedList[PullRequest | Issue]) -> str: return "" -def get_nextUrl( - pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str -) -> str | None: +def get_nextUrl(pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str) -> str | None: return getattr(pag_list, nextUrl_key) if nextUrl_key else None -def set_nextUrl( - pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str, nextUrl: str -) -> None: +def set_nextUrl(pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str, nextUrl: str) -> None: if nextUrl_key: setattr(pag_list, nextUrl_key, nextUrl) elif nextUrl: @@ -119,11 +115,7 @@ def _paginate_until_error( # over previous calls. Unfortunately, this WILL retrieve all # pages before the one we are resuming from, so we really # don't want this case to be hit often - logging.warning( - "Retrying from a previous cursor-based pagination call. " - "This will retrieve all pages before the one we are resuming from, " - "which may take a while and consume many API calls." - ) + logging.warning("Retrying from a previous cursor-based pagination call. This will retrieve all pages before the one we are resuming from, which may take a while and consume many API calls.") pag_list = cast(PaginatedList[PullRequest | Issue], pag_list[prev_num_objs:]) num_objs = 0 @@ -137,9 +129,7 @@ def _paginate_until_error( cursor_url_callback(get_nextUrl(pag_list, nextUrl_key), num_objs) if num_objs % CURSOR_LOG_FREQUENCY == 0: - logging.info( - f"Retrieved {num_objs} objects with current cursor url: {get_nextUrl(pag_list, nextUrl_key)}" - ) + logging.info(f"Retrieved {num_objs} objects with current cursor url: {get_nextUrl(pag_list, nextUrl_key)}") except Exception as e: logging.exception(f"Error during cursor-based pagination: {e}") @@ -147,14 +137,8 @@ def _paginate_until_error( raise if get_nextUrl(pag_list, nextUrl_key) is not None and not retrying: - logging.info( - "Assuming that this error is due to cursor " - "expiration because no objects were retrieved. " - "Retrying from the first page." - ) - yield from _paginate_until_error( - git_objs, None, prev_num_objs, cursor_url_callback, retrying=True - ) + logging.info("Assuming that this error is due to cursor expiration because no objects were retrieved. Retrying from the first page.") + yield from _paginate_until_error(git_objs, None, prev_num_objs, cursor_url_callback, retrying=True) return # for no cursor url or if we reach this point after a retry, raise the error @@ -174,16 +158,12 @@ def _get_batch_rate_limited( attempt_num: int = 0, ) -> Generator[PullRequest | Issue, None, None]: if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES: - raise RuntimeError( - "Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github" - ) + raise RuntimeError("Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github") try: if cursor_url: # when this is set, we are resuming from an earlier # cursor-based pagination call. - yield from _paginate_until_error( - git_objs, cursor_url, prev_num_objs, cursor_url_callback - ) + yield from _paginate_until_error(git_objs, cursor_url, prev_num_objs, cursor_url_callback) return objs = list(git_objs().get_page(page_num)) # fetch all data here to disable lazy loading later @@ -204,13 +184,7 @@ def _get_batch_rate_limited( attempt_num + 1, ) except GithubException as e: - if not ( - e.status == 422 - and ( - "cursor" in (e.message or "") - or "cursor" in (e.data or {}).get("message", "") - ) - ): + if not (e.status == 422 and ("cursor" in (e.message or "") or "cursor" in (e.data or {}).get("message", ""))): raise # Fallback to a cursor-based pagination strategy # This can happen for "large datasets," but there's no documentation @@ -218,9 +192,7 @@ def _get_batch_rate_limited( # Error message: # "Pagination with the page parameter is not supported for large datasets, # please use cursor based pagination (after/before)" - yield from _paginate_until_error( - git_objs, cursor_url, prev_num_objs, cursor_url_callback - ) + yield from _paginate_until_error(git_objs, cursor_url, prev_num_objs, cursor_url_callback) def _get_userinfo(user: NamedUser) -> dict[str, str]: @@ -242,28 +214,22 @@ def _get_userinfo(user: NamedUser) -> dict[str, str]: } -def _convert_pr_to_document( - pull_request: PullRequest, repo_external_access: ExternalAccess | None -) -> Document: +def _convert_pr_to_document(pull_request: PullRequest, repo_external_access: ExternalAccess | None) -> Document: repo_name = pull_request.base.repo.full_name if pull_request.base else "" doc_metadata = DocMetadata(repo=repo_name) - file_content_byte = pull_request.body.encode('utf-8') if pull_request.body else b"" + file_content_byte = pull_request.body.encode("utf-8") if pull_request.body else b"" name = sanitize_filename(pull_request.title, "md") return Document( id=pull_request.html_url, - blob= file_content_byte, + blob=file_content_byte, source=DocumentSource.GITHUB, external_access=repo_external_access, semantic_identifier=f"{pull_request.number}:{name}", # updated_at is UTC time but is timezone unaware, explicitly add UTC # as there is logic in indexing to prevent wrong timestamped docs # due to local time discrepancies with UTC - doc_updated_at=( - pull_request.updated_at.replace(tzinfo=timezone.utc) - if pull_request.updated_at - else None - ), + doc_updated_at=(pull_request.updated_at.replace(tzinfo=timezone.utc) if pull_request.updated_at else None), extension=".md", # this metadata is used in perm sync size_bytes=len(file_content_byte) if file_content_byte else 0, @@ -277,40 +243,16 @@ def _convert_pr_to_document( "merged": pull_request.merged, "state": pull_request.state, "user": _get_userinfo(pull_request.user) if pull_request.user else None, - "assignees": [ - _get_userinfo(assignee) for assignee in pull_request.assignees - ], - "repo": ( - pull_request.base.repo.full_name if pull_request.base else None - ), + "assignees": [_get_userinfo(assignee) for assignee in pull_request.assignees], + "repo": (pull_request.base.repo.full_name if pull_request.base else None), "num_commits": str(pull_request.commits), "num_files_changed": str(pull_request.changed_files), "labels": [label.name for label in pull_request.labels], - "created_at": ( - pull_request.created_at.replace(tzinfo=timezone.utc) - if pull_request.created_at - else None - ), - "updated_at": ( - pull_request.updated_at.replace(tzinfo=timezone.utc) - if pull_request.updated_at - else None - ), - "closed_at": ( - pull_request.closed_at.replace(tzinfo=timezone.utc) - if pull_request.closed_at - else None - ), - "merged_at": ( - pull_request.merged_at.replace(tzinfo=timezone.utc) - if pull_request.merged_at - else None - ), - "merged_by": ( - _get_userinfo(pull_request.merged_by) - if pull_request.merged_by - else None - ), + "created_at": (pull_request.created_at.replace(tzinfo=timezone.utc) if pull_request.created_at else None), + "updated_at": (pull_request.updated_at.replace(tzinfo=timezone.utc) if pull_request.updated_at else None), + "closed_at": (pull_request.closed_at.replace(tzinfo=timezone.utc) if pull_request.closed_at else None), + "merged_at": (pull_request.merged_at.replace(tzinfo=timezone.utc) if pull_request.merged_at else None), + "merged_by": (_get_userinfo(pull_request.merged_by) if pull_request.merged_by else None), }.items() if v is not None }, @@ -322,12 +264,10 @@ def _fetch_issue_comments(issue: Issue) -> str: return "\nComment: ".join(comment.body for comment in comments) -def _convert_issue_to_document( - issue: Issue, repo_external_access: ExternalAccess | None -) -> Document: +def _convert_issue_to_document(issue: Issue, repo_external_access: ExternalAccess | None) -> Document: repo_name = issue.repository.full_name if issue.repository else "" doc_metadata = DocMetadata(repo=repo_name) - file_content_byte = issue.body.encode('utf-8') if issue.body else b"" + file_content_byte = issue.body.encode("utf-8") if issue.body else b"" name = sanitize_filename(issue.title, "md") return Document( @@ -353,24 +293,10 @@ def _convert_issue_to_document( "assignees": [_get_userinfo(assignee) for assignee in issue.assignees], "repo": issue.repository.full_name if issue.repository else None, "labels": [label.name for label in issue.labels], - "created_at": ( - issue.created_at.replace(tzinfo=timezone.utc) - if issue.created_at - else None - ), - "updated_at": ( - issue.updated_at.replace(tzinfo=timezone.utc) - if issue.updated_at - else None - ), - "closed_at": ( - issue.closed_at.replace(tzinfo=timezone.utc) - if issue.closed_at - else None - ), - "closed_by": ( - _get_userinfo(issue.closed_by) if issue.closed_by else None - ), + "created_at": (issue.created_at.replace(tzinfo=timezone.utc) if issue.created_at else None), + "updated_at": (issue.updated_at.replace(tzinfo=timezone.utc) if issue.updated_at else None), + "closed_at": (issue.closed_at.replace(tzinfo=timezone.utc) if issue.closed_at else None), + "closed_by": (_get_userinfo(issue.closed_by) if issue.closed_by else None), }.items() if v is not None }, @@ -451,13 +377,9 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo return None - def get_github_repo( - self, github_client: Github, attempt_num: int = 0 - ) -> Repository.Repository: + def get_github_repo(self, github_client: Github, attempt_num: int = 0) -> Repository.Repository: if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES: - raise RuntimeError( - "Re-tried fetching repo too many times. Something is going wrong with fetching objects from Github" - ) + raise RuntimeError("Re-tried fetching repo too many times. Something is going wrong with fetching objects from Github") try: return github_client.get_repo(f"{self.repo_owner}/{self.repositories}") @@ -465,21 +387,15 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo sleep_after_rate_limit_exception(github_client) return self.get_github_repo(github_client, attempt_num + 1) - def get_github_repos( - self, github_client: Github, attempt_num: int = 0 - ) -> list[Repository.Repository]: + def get_github_repos(self, github_client: Github, attempt_num: int = 0) -> list[Repository.Repository]: """Get specific repositories based on comma-separated repo_name string.""" if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES: - raise RuntimeError( - "Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github" - ) + raise RuntimeError("Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github") try: repos = [] # Split repo_name by comma and strip whitespace - repo_names = [ - name.strip() for name in (cast(str, self.repositories)).split(",") - ] + repo_names = [name.strip() for name in (cast(str, self.repositories)).split(",")] for repo_name in repo_names: if repo_name: # Skip empty strings @@ -487,22 +403,16 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo repo = github_client.get_repo(f"{self.repo_owner}/{repo_name}") repos.append(repo) except GithubException as e: - logging.warning( - f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}" - ) + logging.warning(f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}") return repos except RateLimitExceededException: sleep_after_rate_limit_exception(github_client) return self.get_github_repos(github_client, attempt_num + 1) - def get_all_repos( - self, github_client: Github, attempt_num: int = 0 - ) -> list[Repository.Repository]: + def get_all_repos(self, github_client: Github, attempt_num: int = 0) -> list[Repository.Repository]: if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES: - raise RuntimeError( - "Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github" - ) + raise RuntimeError("Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github") try: # Try to get organization first @@ -518,19 +428,11 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo sleep_after_rate_limit_exception(github_client) return self.get_all_repos(github_client, attempt_num + 1) - def _pull_requests_func( - self, repo: Repository.Repository - ) -> Callable[[], PaginatedList[PullRequest]]: - return lambda: repo.get_pulls( - state=self.state_filter, sort="updated", direction="desc" - ) + def _pull_requests_func(self, repo: Repository.Repository) -> Callable[[], PaginatedList[PullRequest]]: + return lambda: repo.get_pulls(state=self.state_filter, sort="updated", direction="desc") - def _issues_func( - self, repo: Repository.Repository - ) -> Callable[[], PaginatedList[Issue]]: - return lambda: repo.get_issues( - state=self.state_filter, sort="updated", direction="desc" - ) + def _issues_func(self, repo: Repository.Repository) -> Callable[[], PaginatedList[Issue]]: + return lambda: repo.get_issues(state=self.state_filter, sort="updated", direction="desc") def _fetch_from_github( self, @@ -582,9 +484,7 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo cursor_url_callback = make_cursor_url_callback(checkpoint) repo_external_access: ExternalAccess | None = None if include_permissions: - repo_external_access = get_external_access_permission( - repo, self.github_client - ) + repo_external_access = get_external_access_permission(repo, self.github_client) if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS: logging.info(f"Fetching PRs for repo: {repo.name}") @@ -603,31 +503,19 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo for pr in pr_batch: num_prs += 1 # we iterate backwards in time, so at this point we stop processing prs - if ( - start is not None - and pr.updated_at - and pr.updated_at.replace(tzinfo=timezone.utc) <= start - ): + if start is not None and pr.updated_at and pr.updated_at.replace(tzinfo=timezone.utc) <= start: done_with_prs = True break # Skip PRs updated after the end date - if ( - end is not None - and pr.updated_at - and pr.updated_at.replace(tzinfo=timezone.utc) > end - ): + if end is not None and pr.updated_at and pr.updated_at.replace(tzinfo=timezone.utc) > end: continue try: - yield _convert_pr_to_document( - cast(PullRequest, pr), repo_external_access - ) + yield _convert_pr_to_document(cast(PullRequest, pr), repo_external_access) except Exception as e: error_msg = f"Error converting PR to document: {e}" logging.exception(error_msg) yield ConnectorFailure( - failed_document=DocumentFailure( - document_id=str(pr.id), document_link=pr.html_url - ), + failed_document=DocumentFailure(document_id=str(pr.id), document_link=pr.html_url), failure_message=error_msg, exception=e, ) @@ -676,17 +564,11 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo num_issues += 1 issue = cast(Issue, issue) # we iterate backwards in time, so at this point we stop processing prs - if ( - start is not None - and issue.updated_at.replace(tzinfo=timezone.utc) <= start - ): + if start is not None and issue.updated_at.replace(tzinfo=timezone.utc) <= start: done_with_issues = True break # Skip PRs updated after the end date - if ( - end is not None - and issue.updated_at.replace(tzinfo=timezone.utc) > end - ): + if end is not None and issue.updated_at.replace(tzinfo=timezone.utc) > end: continue if issue.pull_request is not None: @@ -731,9 +613,7 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo checkpoint.reset() if checkpoint.cached_repo_ids: - logging.info( - f"{len(checkpoint.cached_repo_ids)} checkpoint repos remaining (IDs: {checkpoint.cached_repo_ids})" - ) + logging.info(f"{len(checkpoint.cached_repo_ids)} checkpoint repos remaining (IDs: {checkpoint.cached_repo_ids})") else: logging.info("There are no more checkpoint repos left.") @@ -775,9 +655,7 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo end: SecondsSinceUnixEpoch, checkpoint: GithubConnectorCheckpoint, ) -> CheckpointOutput[GithubConnectorCheckpoint]: - return self._load_from_checkpoint( - start, end, checkpoint, include_permissions=False - ) + return self._load_from_checkpoint(start, end, checkpoint, include_permissions=False) @override def load_from_checkpoint_with_perm_sync( @@ -786,18 +664,14 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo end: SecondsSinceUnixEpoch, checkpoint: GithubConnectorCheckpoint, ) -> CheckpointOutput[GithubConnectorCheckpoint]: - return self._load_from_checkpoint( - start, end, checkpoint, include_permissions=True - ) + return self._load_from_checkpoint(start, end, checkpoint, include_permissions=True) def validate_connector_settings(self) -> None: if self.github_client is None: raise ConnectorMissingCredentialError("GitHub credentials not loaded.") if not self.repo_owner: - raise ConnectorValidationError( - "Invalid connector settings: 'repo_owner' must be provided." - ) + raise ConnectorValidationError("Invalid connector settings: 'repo_owner' must be provided.") try: if self.repositories: @@ -805,9 +679,7 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo # Multiple repositories specified repo_names = [name.strip() for name in self.repositories.split(",")] if not repo_names: - raise ConnectorValidationError( - "Invalid connector settings: No valid repository names provided." - ) + raise ConnectorValidationError("Invalid connector settings: No valid repository names provided.") # Validate at least one repository exists and is accessible valid_repos = False @@ -818,32 +690,22 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo continue try: - test_repo = self.github_client.get_repo( - f"{self.repo_owner}/{repo_name}" - ) - logging.info( - f"Successfully accessed repository: {self.repo_owner}/{repo_name}" - ) + test_repo = self.github_client.get_repo(f"{self.repo_owner}/{repo_name}") + logging.info(f"Successfully accessed repository: {self.repo_owner}/{repo_name}") test_repo.get_contents("") valid_repos = True # If at least one repo is valid, we can proceed break except GithubException as e: - validation_errors.append( - f"Repository '{repo_name}': {e.data.get('message', str(e))}" - ) + validation_errors.append(f"Repository '{repo_name}': {e.data.get('message', str(e))}") if not valid_repos: - error_msg = ( - "None of the specified repositories could be accessed: " - ) + error_msg = "None of the specified repositories could be accessed: " error_msg += ", ".join(validation_errors) raise ConnectorValidationError(error_msg) else: # Single repository (backward compatibility) - test_repo = self.github_client.get_repo( - f"{self.repo_owner}/{self.repositories}" - ) + test_repo = self.github_client.get_repo(f"{self.repo_owner}/{self.repositories}") test_repo.get_contents("") else: # Try to get organization first @@ -851,10 +713,7 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo org = self.github_client.get_organization(self.repo_owner) total_count = org.get_repos().totalCount if total_count == 0: - raise ConnectorValidationError( - f"Found no repos for organization: {self.repo_owner}. " - "Does the credential have the right scopes?" - ) + raise ConnectorValidationError(f"Found no repos for organization: {self.repo_owner}. Does the credential have the right scopes?") except GithubException as e: # Check for missing SSO MISSING_SSO_ERROR_MESSAGE = "You must grant your Personal Access token access to this organization".lower() @@ -865,9 +724,7 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo "authorizing-a-personal-access-token-for-use-with-saml-single-sign-on" ) raise ConnectorValidationError( - f"Your GitHub token is missing authorization to access the " - f"`{self.repo_owner}` organization. Please follow the guide to " - f"authorize your token: {SSO_GUIDE_LINK}" + f"Your GitHub token is missing authorization to access the `{self.repo_owner}` organization. Please follow the guide to authorize your token: {SSO_GUIDE_LINK}" ) # If not an org, try as a user user = self.github_client.get_user(self.repo_owner) @@ -875,52 +732,31 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo # Check if we can access any repos total_count = user.get_repos().totalCount if total_count == 0: - raise ConnectorValidationError( - f"Found no repos for user: {self.repo_owner}. " - "Does the credential have the right scopes?" - ) + raise ConnectorValidationError(f"Found no repos for user: {self.repo_owner}. Does the credential have the right scopes?") except RateLimitExceededException: - raise UnexpectedValidationError( - "Validation failed due to GitHub rate-limits being exceeded. Please try again later." - ) + raise UnexpectedValidationError("Validation failed due to GitHub rate-limits being exceeded. Please try again later.") except GithubException as e: if e.status == 401: - raise CredentialExpiredError( - "GitHub credential appears to be invalid or expired (HTTP 401)." - ) + raise CredentialExpiredError("GitHub credential appears to be invalid or expired (HTTP 401).") elif e.status == 403: - raise InsufficientPermissionsError( - "Your GitHub token does not have sufficient permissions for this repository (HTTP 403)." - ) + raise InsufficientPermissionsError("Your GitHub token does not have sufficient permissions for this repository (HTTP 403).") elif e.status == 404: if self.repositories: if "," in self.repositories: - raise ConnectorValidationError( - f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}" - ) + raise ConnectorValidationError(f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}") else: - raise ConnectorValidationError( - f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}" - ) + raise ConnectorValidationError(f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}") else: - raise ConnectorValidationError( - f"GitHub user or organization not found: {self.repo_owner}" - ) + raise ConnectorValidationError(f"GitHub user or organization not found: {self.repo_owner}") else: - raise ConnectorValidationError( - f"Unexpected GitHub error (status={e.status}): {e.data}" - ) + raise ConnectorValidationError(f"Unexpected GitHub error (status={e.status}): {e.data}") except Exception as exc: - raise Exception( - f"Unexpected error during GitHub settings validation: {exc}" - ) + raise Exception(f"Unexpected error during GitHub settings validation: {exc}") - def validate_checkpoint_json( - self, checkpoint_json: str - ) -> GithubConnectorCheckpoint: + def validate_checkpoint_json(self, checkpoint_json: str) -> GithubConnectorCheckpoint: return GithubConnectorCheckpoint.model_validate_json(checkpoint_json) def retrieve_slim_document( @@ -930,17 +766,13 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo callback: Any = None, ) -> GenerateSlimDocumentOutput: start_value = 0.0 if start is None else start - end_value = ( - datetime.now(timezone.utc).timestamp() if end is None else end - ) + end_value = datetime.now(timezone.utc).timestamp() if end is None else end checkpoint = self.build_dummy_checkpoint() slim_batch: list[SlimDocument] = [] while checkpoint.has_more: wrapper = CheckpointOutputWrapper[GithubConnectorCheckpoint]() - for document, failure, next_checkpoint in wrapper( - self.load_from_checkpoint(start_value, end_value, checkpoint) - ): + for document, failure, next_checkpoint in wrapper(self.load_from_checkpoint(start_value, end_value, checkpoint)): if failure is not None: logging.warning( "GitHub connector failure during slim retrieval: %s", @@ -969,9 +801,7 @@ class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpo yield from self.retrieve_slim_document(callback=callback) def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint: - return GithubConnectorCheckpoint( - stage=GithubConnectorStage.PRS, curr_page=0, has_more=True, num_retrieved=0 - ) + return GithubConnectorCheckpoint(stage=GithubConnectorStage.PRS, curr_page=0, has_more=True, num_retrieved=0) if __name__ == "__main__": @@ -982,9 +812,7 @@ if __name__ == "__main__": include_issues=True, include_prs=False, ) - connector.load_credentials( - {"github_access_token": ""} - ) + connector.load_credentials({"github_access_token": ""}) if connector.github_client: get_external_access_permission( @@ -998,9 +826,7 @@ if __name__ == "__main__": time_range = (start_time, end_time) # Initialize the runner with a batch size of 10 - runner: ConnectorRunner[GithubConnectorCheckpoint] = ConnectorRunner( - connector, batch_size=10, include_permissions=False, time_range=time_range - ) + runner: ConnectorRunner[GithubConnectorCheckpoint] = ConnectorRunner(connector, batch_size=10, include_permissions=False, time_range=time_range) # Get initial checkpoint checkpoint = connector.build_dummy_checkpoint() diff --git a/common/data_source/github/models.py b/common/data_source/github/models.py index 9754bfa8db..a87cb0a732 100644 --- a/common/data_source/github/models.py +++ b/common/data_source/github/models.py @@ -12,6 +12,4 @@ class SerializedRepository(BaseModel): raw_data: dict[str, Any] def to_Repository(self, requester: Requester) -> Repository.Repository: - return Repository.Repository( - requester, self.headers, self.raw_data, completed=True - ) \ No newline at end of file + return Repository.Repository(requester, self.headers, self.raw_data, completed=True) diff --git a/common/data_source/github/rate_limit_utils.py b/common/data_source/github/rate_limit_utils.py index d683bad08d..c46a10a3fe 100644 --- a/common/data_source/github/rate_limit_utils.py +++ b/common/data_source/github/rate_limit_utils.py @@ -14,11 +14,7 @@ def sleep_after_rate_limit_exception(github_client: Github) -> None: Args: github_client: The GitHub client that hit the rate limit """ - sleep_time = github_client.get_rate_limit().core.reset.replace( - tzinfo=timezone.utc - ) - datetime.now(tz=timezone.utc) + sleep_time = github_client.get_rate_limit().core.reset.replace(tzinfo=timezone.utc) - datetime.now(tz=timezone.utc) sleep_time += timedelta(minutes=1) # add an extra minute just to be safe - logging.info( - "Ran into Github rate-limit. Sleeping %s seconds.", sleep_time.seconds - ) - time.sleep(sleep_time.total_seconds()) \ No newline at end of file + logging.info("Ran into Github rate-limit. Sleeping %s seconds.", sleep_time.seconds) + time.sleep(sleep_time.total_seconds()) diff --git a/common/data_source/github/utils.py b/common/data_source/github/utils.py index 93b843bc84..f9e2980c33 100644 --- a/common/data_source/github/utils.py +++ b/common/data_source/github/utils.py @@ -8,9 +8,7 @@ from common.data_source.models import ExternalAccess from .models import SerializedRepository -def get_external_access_permission( - repo: Repository, github_client: Github -) -> ExternalAccess: +def get_external_access_permission(repo: Repository, github_client: Github) -> ExternalAccess: """ Get the external access permission for a repository. This functionality requires Enterprise Edition. @@ -20,9 +18,7 @@ def get_external_access_permission( return ExternalAccess.empty() -def deserialize_repository( - cached_repo: SerializedRepository, github_client: Github -) -> Repository: +def deserialize_repository(cached_repo: SerializedRepository, github_client: Github) -> Repository: """ Deserialize a SerializedRepository back into a Repository object. """ @@ -41,4 +37,4 @@ def deserialize_repository( # If all else fails, re-fetch the repo directly logging.warning("Failed to deserialize repository: %s. Attempting to re-fetch.", e) repo_id = cached_repo.id - return github_client.get_repo(repo_id) \ No newline at end of file + return github_client.get_repo(repo_id) diff --git a/common/data_source/gmail_connector.py b/common/data_source/gmail_connector.py index ea4dd993ae..bdd03d6c14 100644 --- a/common/data_source/gmail_connector.py +++ b/common/data_source/gmail_connector.py @@ -100,7 +100,7 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread: if message_metadata.get("updated_at"): updated_at = message_metadata.get("updated_at") - + updated_at_datetime = None if updated_at: updated_at_datetime = gmail_time_str_to_utc(updated_at) @@ -115,12 +115,10 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread: if not semantic_identifier: semantic_identifier = "(no subject)" - combined_sections = "\n\n".join( - sec.text for sec in sections if hasattr(sec, "text") - ) + combined_sections = "\n\n".join(sec.text for sec in sections if hasattr(sec, "text")) blob = combined_sections size_bytes = len(blob) - extension = '.txt' + extension = ".txt" return Document( id=thread_id, @@ -318,6 +316,7 @@ if __name__ == "__main__": import time import os from common.data_source.google_util.util import get_credentials_from_env + logging.basicConfig(level=logging.INFO) try: email = os.environ.get("GMAIL_TEST_EMAIL", "newyorkupperbay@gmail.com") @@ -336,7 +335,7 @@ if __name__ == "__main__": int(time.time()) - 1 * 24 * 60 * 60, int(time.time()), ): - print("new batch","-"*80) + print("new batch", "-" * 80) for f in file: print(f) print("\n\n") diff --git a/common/data_source/google_drive/connector.py b/common/data_source/google_drive/connector.py index 479c60e0b6..6fdae09bc5 100644 --- a/common/data_source/google_drive/connector.py +++ b/common/data_source/google_drive/connector.py @@ -252,9 +252,7 @@ class GoogleDriveConnector(SlimConnectorWithPermSync, CheckpointedConnectorWithP def get_all_drive_ids(self) -> set[str]: if self._all_drive_ids_cache is None: - self._all_drive_ids_cache = self._get_all_drives_for_user( - self.primary_admin_email - ) + self._all_drive_ids_cache = self._get_all_drives_for_user(self.primary_admin_email) return set(self._all_drive_ids_cache) def _get_all_drives_for_user(self, user_email: str) -> set[str]: @@ -752,9 +750,7 @@ class GoogleDriveConnector(SlimConnectorWithPermSync, CheckpointedConnectorWithP if remaining_folders: self.logger.warning(f"Some folders/drives were not retrieved. IDs: {remaining_folders}") - def _adjust_start_for_query( - self, start: SecondsSinceUnixEpoch | None - ) -> SecondsSinceUnixEpoch | None: + def _adjust_start_for_query(self, start: SecondsSinceUnixEpoch | None) -> SecondsSinceUnixEpoch | None: """Subtract the configured time buffer from start to create an overlap window for incremental syncs.""" if not start or start <= 0: return start @@ -1227,6 +1223,7 @@ def yield_all_docs_from_checkpoint_connector( if __name__ == "__main__": import time from common.data_source.google_util.util import get_credentials_from_env + logging.basicConfig(level=logging.DEBUG) try: diff --git a/common/data_source/google_drive/file_retrieval.py b/common/data_source/google_drive/file_retrieval.py index f143cca814..34c1305772 100644 --- a/common/data_source/google_drive/file_retrieval.py +++ b/common/data_source/google_drive/file_retrieval.py @@ -41,10 +41,7 @@ def generate_time_range_filter( time_range_filter = "" if start is not None: time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat() - time_range_filter += ( - f" and ({GoogleFields.MODIFIED_TIME.value} > '{time_start}'" - f" or {GoogleFields.CREATED_TIME.value} >= '{time_start}')" - ) + time_range_filter += f" and ({GoogleFields.MODIFIED_TIME.value} > '{time_start}' or {GoogleFields.CREATED_TIME.value} >= '{time_start}')" if end is not None: time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat() time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'" diff --git a/common/data_source/google_util/auth.py b/common/data_source/google_util/auth.py index 85c2e6c828..bf1ce7314a 100644 --- a/common/data_source/google_util/auth.py +++ b/common/data_source/google_util/auth.py @@ -67,9 +67,7 @@ def get_google_creds( try: credentials_dict = ensure_oauth_token_dict(credentials_dict, source) except Exception as exc: - raise PermissionError( - "Google Drive OAuth credentials are incomplete. Please finish the OAuth flow to generate access tokens." - ) from exc + raise PermissionError("Google Drive OAuth credentials are incomplete. Please finish the OAuth flow to generate access tokens.") from exc credentials_dict_str = json.dumps(credentials_dict) regenerated_from_client_secret = True diff --git a/common/data_source/google_util/util.py b/common/data_source/google_util/util.py index 187c06d6d8..2e65a04013 100644 --- a/common/data_source/google_util/util.py +++ b/common/data_source/google_util/util.py @@ -223,4 +223,4 @@ def clean_string(text: str | None) -> str | None: except UnicodeEncodeError: text = text.encode("utf-8", errors="ignore").decode("utf-8") - return text \ No newline at end of file + return text diff --git a/common/data_source/html_utils.py b/common/data_source/html_utils.py index 5eff624636..b39569e664 100644 --- a/common/data_source/html_utils.py +++ b/common/data_source/html_utils.py @@ -7,9 +7,13 @@ from typing import IO import bs4 -from common.data_source.config import HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY, \ - HtmlBasedConnectorTransformLinksStrategy, WEB_CONNECTOR_IGNORED_CLASSES, WEB_CONNECTOR_IGNORED_ELEMENTS, \ - PARSE_WITH_TRAFILATURA +from common.data_source.config import ( + HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY, + HtmlBasedConnectorTransformLinksStrategy, + WEB_CONNECTOR_IGNORED_CLASSES, + WEB_CONNECTOR_IGNORED_ELEMENTS, + PARSE_WITH_TRAFILATURA, +) MINTLIFY_UNWANTED = ["sticky", "hidden"] @@ -38,11 +42,7 @@ def strip_newlines(document: str) -> str: def format_element_text(element_text: str, link_href: str | None) -> str: element_text_no_newlines = strip_newlines(element_text) - if ( - not link_href - or HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY - == HtmlBasedConnectorTransformLinksStrategy.STRIP - ): + if not link_href or HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY == HtmlBasedConnectorTransformLinksStrategy.STRIP: return element_text_no_newlines return f"[{element_text_no_newlines}]({link_href})" @@ -63,9 +63,7 @@ def parse_html_with_trafilatura(html_content: str) -> str: return strip_excessive_newlines_and_spaces(extracted_text) if extracted_text else "" -def format_document_soup( - document: bs4.BeautifulSoup, table_cell_separator: str = "\t" -) -> str: +def format_document_soup(document: bs4.BeautifulSoup, table_cell_separator: str = "\t") -> str: """Format html to a flat text document. The following goals: @@ -101,16 +99,10 @@ def format_document_soup( last_added_newline = False if element_text: - content_to_add = ( - element_text - if verbatim_output > 0 - else format_element_text(element_text, link_href) - ) + content_to_add = element_text if verbatim_output > 0 else format_element_text(element_text, link_href) # Don't join separate elements without any spacing - if (text and not text[-1].isspace()) and ( - content_to_add and not content_to_add[0].isspace() - ): + if (text and not text[-1].isspace()) and (content_to_add and not content_to_add[0].isspace()): text += " " text += content_to_add @@ -134,9 +126,7 @@ def format_document_soup( elif e.name == "a": href_value = e.get("href", None) # mostly for typing, having multiple hrefs is not valid HTML - link_href = ( - href_value[0] if isinstance(href_value, list) else href_value - ) + link_href = href_value[0] if isinstance(href_value, list) else href_value elif e.name == "/a": link_href = None elif e.name in ["p", "div"]: @@ -185,12 +175,7 @@ def web_html_cleanup( if mintlify_cleanup_enabled: unwanted_classes.extend(MINTLIFY_UNWANTED) for undesired_element in unwanted_classes: - [ - tag.extract() - for tag in soup.find_all( - class_=lambda x: x and undesired_element in x.split() - ) - ] + [tag.extract() for tag in soup.find_all(class_=lambda x: x and undesired_element in x.split())] for undesired_tag in WEB_CONNECTOR_IGNORED_ELEMENTS: [tag.extract() for tag in soup.find_all(undesired_tag)] diff --git a/common/data_source/imap_connector.py b/common/data_source/imap_connector.py index 2f12e6be91..1d3560c2c5 100644 --- a/common/data_source/imap_connector.py +++ b/common/data_source/imap_connector.py @@ -40,14 +40,13 @@ _PAGE_SIZE = 100 _USERNAME_KEY = "imap_username" _PASSWORD_KEY = "imap_password" + class Header(str, Enum): SUBJECT_HEADER = "subject" FROM_HEADER = "from" TO_HEADER = "to" CC_HEADER = "cc" - DELIVERED_TO_HEADER = ( - "Delivered-To" # Used in mailing lists instead of the "to" header. - ) + DELIVERED_TO_HEADER = "Delivered-To" # Used in mailing lists instead of the "to" header. DATE_HEADER = "date" MESSAGE_ID_HEADER = "Message-ID" @@ -77,13 +76,9 @@ class EmailHeaders(BaseModel): for decoded_value, encoding in decoded_fragments: if isinstance(decoded_value, bytes): try: - decoded_strings.append( - decoded_value.decode(encoding or "utf-8", errors="replace") - ) + decoded_strings.append(decoded_value.decode(encoding or "utf-8", errors="replace")) except LookupError: - decoded_strings.append( - decoded_value.decode("utf-8", errors="replace") - ) + decoded_strings.append(decoded_value.decode("utf-8", errors="replace")) elif isinstance(decoded_value, str): decoded_strings.append(decoded_value) else: @@ -121,11 +116,7 @@ class EmailHeaders(BaseModel): sender=from_ or "", recipients=to or "", cc=cc or "", - date_key=( - _as_utc(parsed_date).isoformat() - if parsed_date - else (date_str or "") - ), + date_key=(_as_utc(parsed_date).isoformat() if parsed_date else (date_str or "")), ) # If any of the above are `None`, model validation will fail. @@ -141,6 +132,7 @@ class EmailHeaders(BaseModel): } ) + class CurrentMailbox(BaseModel): mailbox: str todo_email_ids: list[str] @@ -184,9 +176,7 @@ class ImapConnector( @property def credentials(self) -> dict[str, Any]: if not self._credentials: - raise RuntimeError( - "Credentials have not been initialized; call `set_credentials_provider` first" - ) + raise RuntimeError("Credentials have not been initialized; call `set_credentials_provider` first") return self._credentials def _get_mail_client(self) -> imaplib.IMAP4_SSL: @@ -213,9 +203,7 @@ class ImapConnector( if not value: raise RuntimeError(f"Credential item {name=} was not found") if not isinstance(value, str): - raise RuntimeError( - f"Credential item {name=} must be of type str, instead received {type(name)=}" - ) + raise RuntimeError(f"Credential item {name=} must be of type str, instead received {type(name)=}") return value username = get_or_raise(_USERNAME_KEY) @@ -247,21 +235,14 @@ class ImapConnector( if self._mailboxes: checkpoint.todo_mailboxes = _sanitize_mailbox_names(self._mailboxes) else: - fetched_mailboxes = _fetch_all_mailboxes_for_email_account( - mail_client=mail_client - ) + fetched_mailboxes = _fetch_all_mailboxes_for_email_account(mail_client=mail_client) if not fetched_mailboxes: - raise RuntimeError( - "Failed to find any mailboxes for this email account" - ) + raise RuntimeError("Failed to find any mailboxes for this email account") checkpoint.todo_mailboxes = _sanitize_mailbox_names(fetched_mailboxes) return checkpoint - if ( - not checkpoint.current_mailbox - or not checkpoint.current_mailbox.todo_email_ids - ): + if not checkpoint.current_mailbox or not checkpoint.current_mailbox.todo_email_ids: if not checkpoint.todo_mailboxes: checkpoint.has_more = False return checkpoint @@ -278,15 +259,9 @@ class ImapConnector( todo_email_ids=email_ids, ) - _select_mailbox( - mail_client=mail_client, mailbox=checkpoint.current_mailbox.mailbox - ) - current_todos = cast( - list, copy.deepcopy(checkpoint.current_mailbox.todo_email_ids[:_PAGE_SIZE]) - ) - checkpoint.current_mailbox.todo_email_ids = ( - checkpoint.current_mailbox.todo_email_ids[_PAGE_SIZE:] - ) + _select_mailbox(mail_client=mail_client, mailbox=checkpoint.current_mailbox.mailbox) + current_todos = cast(list, copy.deepcopy(checkpoint.current_mailbox.todo_email_ids[:_PAGE_SIZE])) + checkpoint.current_mailbox.todo_email_ids = checkpoint.current_mailbox.todo_email_ids[_PAGE_SIZE:] for email_id in current_todos: email_msg = _fetch_email(mail_client=mail_client, email_id=email_id) @@ -325,9 +300,7 @@ class ImapConnector( # impls for CredentialsConnector - def set_credentials_provider( - self, credentials_provider: CredentialsProviderInterface - ) -> None: + def set_credentials_provider(self, credentials_provider: CredentialsProviderInterface) -> None: self._credentials = credentials_provider.get_credentials() # impls for CheckpointedConnector @@ -338,9 +311,7 @@ class ImapConnector( end: SecondsSinceUnixEpoch, checkpoint: ImapCheckpoint, ) -> CheckpointOutput[ImapCheckpoint]: - return self._load_from_checkpoint( - start=start, end=end, checkpoint=checkpoint, include_perm_sync=False - ) + return self._load_from_checkpoint(start=start, end=end, checkpoint=checkpoint, include_perm_sync=False) def build_dummy_checkpoint(self) -> ImapCheckpoint: return ImapCheckpoint(has_more=True) @@ -356,9 +327,7 @@ class ImapConnector( end: SecondsSinceUnixEpoch, checkpoint: ImapCheckpoint, ) -> CheckpointOutput[ImapCheckpoint]: - return self._load_from_checkpoint( - start=start, end=end, checkpoint=checkpoint, include_perm_sync=True - ) + return self._load_from_checkpoint(start=start, end=end, checkpoint=checkpoint, include_perm_sync=True) def retrieve_all_slim_docs_perm_sync( self, @@ -369,18 +338,14 @@ class ImapConnector( del callback mail_client = self._get_mail_client() start_ts = start if start is not None else 0 - end_ts = ( - end if end is not None else datetime.now(tz=timezone.utc).timestamp() - ) + end_ts = end if end is not None else datetime.now(tz=timezone.utc).timestamp() start_dt = datetime.fromtimestamp(start_ts, tz=timezone.utc) end_dt = datetime.fromtimestamp(end_ts, tz=timezone.utc) if self._mailboxes: mailboxes = _sanitize_mailbox_names(self._mailboxes) else: - mailboxes = _sanitize_mailbox_names( - _fetch_all_mailboxes_for_email_account(mail_client=mail_client) - ) + mailboxes = _sanitize_mailbox_names(_fetch_all_mailboxes_for_email_account(mail_client=mail_client)) slim_doc_batch: list[SlimDocument] = [] for mailbox in mailboxes: @@ -405,11 +370,7 @@ class ImapConnector( slim_doc_batch.append(SlimDocument(id=email_headers.id)) for att in extract_attachments(email_msg): - slim_doc_batch.append( - SlimDocument( - id=_attachment_document_id(email_headers.id, att) - ) - ) + slim_doc_batch.append(SlimDocument(id=_attachment_document_id(email_headers.id, att))) if len(slim_doc_batch) >= _PAGE_SIZE: yield slim_doc_batch @@ -432,9 +393,7 @@ def _fetch_all_mailboxes_for_email_account(mail_client: imaplib.IMAP4_SSL) -> li elif isinstance(mailboxes_raw, str): mailboxes_str = mailboxes_raw else: - logging.warning( - f"Expected the mailbox data to be of type str, instead got {type(mailboxes_raw)=} {mailboxes_raw}; skipping" - ) + logging.warning(f"Expected the mailbox data to be of type str, instead got {type(mailboxes_raw)=} {mailboxes_raw}; skipping") continue # The mailbox LIST response output can be found here: @@ -446,17 +405,13 @@ def _fetch_all_mailboxes_for_email_account(mail_client: imaplib.IMAP4_SSL) -> li # The below regex matches on that pattern; from there, we select the 3rd match (index 2), which is the mailbox-name. match = re.match(r'\([^)]*\)\s+"([^"]+)"\s+"?(.+?)"?$', mailboxes_str) if not match: - logging.warning( - f"Invalid mailbox-data formatting structure: {mailboxes_str=}; skipping" - ) + logging.warning(f"Invalid mailbox-data formatting structure: {mailboxes_str=}; skipping") continue mailbox = match.group(2) mailboxes.append(mailbox) if not mailboxes: - logging.warning( - "No mailboxes parsed from LIST response; falling back to INBOX" - ) + logging.warning("No mailboxes parsed from LIST response; falling back to INBOX") return ["INBOX"] return mailboxes @@ -506,9 +461,7 @@ def _fetch_email(mail_client: imaplib.IMAP4_SSL, email_id: str) -> Message | Non data = msg_data[0] if not isinstance(data, tuple): - raise RuntimeError( - f"Message data should be a tuple; instead got a {type(data)=} {data=}" - ) + raise RuntimeError(f"Message data should be a tuple; instead got a {type(data)=} {data=}") _, raw_email = data return email.message_from_bytes(raw_email) @@ -553,28 +506,13 @@ def _convert_email_headers_and_body_into_document( include_perm_sync: bool, ) -> Document: sender_name, sender_addr = _parse_singular_addr(raw_header=email_headers.sender) - to_addrs = ( - _parse_addrs(email_headers.recipients) - if email_headers.recipients - else [] - ) - cc_addrs = ( - _parse_addrs(email_headers.cc) - if email_headers.cc - else [] - ) + to_addrs = _parse_addrs(email_headers.recipients) if email_headers.recipients else [] + cc_addrs = _parse_addrs(email_headers.cc) if email_headers.cc else [] all_participants = to_addrs + cc_addrs - expert_info_map = { - recipient_addr: BasicExpertInfo( - display_name=recipient_name, email=recipient_addr - ) - for recipient_name, recipient_addr in all_participants - } + expert_info_map = {recipient_addr: BasicExpertInfo(display_name=recipient_name, email=recipient_addr) for recipient_name, recipient_addr in all_participants} if sender_addr not in expert_info_map: - expert_info_map[sender_addr] = BasicExpertInfo( - display_name=sender_name, email=sender_addr - ) + expert_info_map[sender_addr] = BasicExpertInfo(display_name=sender_name, email=sender_addr) email_body = _parse_email_body(email_msg=email_msg, email_headers=email_headers) primary_owners = list(expert_info_map.values()) @@ -594,13 +532,14 @@ def _convert_email_headers_and_body_into_document( size_bytes=len(email_body), semantic_identifier=email_headers.subject, metadata={}, - extension='.txt', + extension=".txt", doc_updated_at=email_headers.date, source=DocumentSource.IMAP, primary_owners=primary_owners, external_access=external_access, ) + def extract_attachments(email_msg: Message, max_bytes: int = IMAP_CONNECTOR_SIZE_THRESHOLD): attachments = [] @@ -614,10 +553,7 @@ def extract_attachments(email_msg: Message, max_bytes: int = IMAP_CONNECTOR_SIZE disposition = (part.get("Content-Disposition") or "").lower() filename = part.get_filename() - if not ( - disposition.startswith("attachment") - or (disposition.startswith("inline") and filename) - ): + if not (disposition.startswith("attachment") or (disposition.startswith("inline") and filename)): continue payload = part.get_payload(decode=True) @@ -627,15 +563,18 @@ def extract_attachments(email_msg: Message, max_bytes: int = IMAP_CONNECTOR_SIZE if len(payload) > max_bytes: continue - attachments.append({ - "filename": filename or "attachment.bin", - "content_type": part.get_content_type(), - "content_bytes": payload, - "size_bytes": len(payload), - }) + attachments.append( + { + "filename": filename or "attachment.bin", + "content_type": part.get_content_type(), + "content_bytes": payload, + "size_bytes": len(payload), + } + ) return attachments + def decode_mime_filename(raw: str | None) -> str | None: if not raw: return None @@ -689,15 +628,14 @@ def attachment_to_document( }, ) + def _parse_email_body( email_msg: Message, email_headers: EmailHeaders, ) -> str: body = _extract_email_body_text(email_msg) if not body: - logging.warning( - f"Email with {email_headers.id=} has an empty body; returning an empty string" - ) + logging.warning(f"Email with {email_headers.id=} has an empty body; returning an empty string") return body @@ -714,10 +652,7 @@ def _extract_email_body_text(email_msg: Message) -> str: try: raw_payload = part.get_payload(decode=True) if not isinstance(raw_payload, bytes): - logging.warning( - "Payload section from email was expected to be an array of bytes, instead got " - f"{type(raw_payload)=}, {raw_payload=}" - ) + logging.warning(f"Payload section from email was expected to be an array of bytes, instead got {type(raw_payload)=}, {raw_payload=}") continue body = raw_payload.decode(charset) break @@ -765,10 +700,7 @@ if __name__ == "__main__": from types import TracebackType from common.data_source.utils import load_all_docs_from_checkpoint_connector - - class OnyxStaticCredentialsProvider( - CredentialsProviderInterface["OnyxStaticCredentialsProvider"] - ): + class OnyxStaticCredentialsProvider(CredentialsProviderInterface["OnyxStaticCredentialsProvider"]): """Implementation (a very simple one!) to handle static credentials.""" def __init__( @@ -808,19 +740,16 @@ if __name__ == "__main__": def is_dynamic(self) -> bool: return False + # from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector # from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider host = os.environ.get("IMAP_HOST") - mailboxes_str = os.environ.get("IMAP_MAILBOXES","INBOX") + mailboxes_str = os.environ.get("IMAP_MAILBOXES", "INBOX") username = os.environ.get("IMAP_USERNAME") password = os.environ.get("IMAP_PASSWORD") - mailboxes = ( - [mailbox.strip() for mailbox in mailboxes_str.split(",")] - if mailboxes_str - else [] - ) + mailboxes = [mailbox.strip() for mailbox in mailboxes_str.split(",")] if mailboxes_str else [] if not host: raise RuntimeError("`IMAP_HOST` must be set") @@ -847,4 +776,4 @@ if __name__ == "__main__": start=START, end=END, ): - print(doc.id,doc.extension) + print(doc.id, doc.extension) diff --git a/common/data_source/interfaces.py b/common/data_source/interfaces.py index fb547d7d92..5c103f0603 100644 --- a/common/data_source/interfaces.py +++ b/common/data_source/interfaces.py @@ -1,4 +1,5 @@ """Interface definitions""" + import abc import uuid from abc import ABC, abstractmethod @@ -8,14 +9,7 @@ from typing import Any, Dict, Generator, TypeVar, Generic, Callable, TypeAlias from collections.abc import Iterator from anthropic import BaseModel -from common.data_source.models import ( - Document, - KeyRecord, - SlimDocument, - ConnectorCheckpoint, - ConnectorFailure, - SecondsSinceUnixEpoch, GenerateSlimDocumentOutput -) +from common.data_source.models import Document, KeyRecord, SlimDocument, ConnectorCheckpoint, ConnectorFailure, SecondsSinceUnixEpoch, GenerateSlimDocumentOutput class IncrementalCapability(IntEnum): @@ -25,6 +19,7 @@ class IncrementalCapability(IntEnum): CURSOR -- "give me everything since cursor X"; opaque cursor persisted across syncs. FINGERPRINT -- list_keys() returns (key, fingerprint) cheaply; bodies fetched lazily. """ + FULL_RESYNC = 0 CURSOR = 1 FINGERPRINT = 2 @@ -32,6 +27,7 @@ class IncrementalCapability(IntEnum): GenerateDocumentsOutput = Iterator[list[Document]] + class LoadConnector(ABC): """Load connector interface""" @@ -165,9 +161,7 @@ class CredentialsProviderInterface(abc.ABC, Generic[T]): raise NotImplementedError -class StaticCredentialsProvider( - CredentialsProviderInterface["StaticCredentialsProvider"] -): +class StaticCredentialsProvider(CredentialsProviderInterface["StaticCredentialsProvider"]): """Implementation (a very simple one!) to handle static credentials.""" def __init__( @@ -224,9 +218,7 @@ class BaseConnector(abc.ABC, Generic[CT]): @staticmethod def parse_metadata(metadata: dict[str, Any]) -> list[str]: """Parse the metadata for a document/chunk into a string to pass to Generative AI as additional context""" - custom_parser_req_msg = ( - "Specific metadata parsing required, connector has not implemented it." - ) + custom_parser_req_msg = "Specific metadata parsing required, connector has not implemented it." metadata_lines = [] for metadata_key, metadata_value in metadata.items(): if isinstance(metadata_value, str): @@ -234,7 +226,7 @@ class BaseConnector(abc.ABC, Generic[CT]): elif isinstance(metadata_value, list): if not all([isinstance(val, str) for val in metadata_value]): raise RuntimeError(custom_parser_req_msg) - metadata_lines.append(f'{metadata_key}: {", ".join(metadata_value)}') + metadata_lines.append(f"{metadata_key}: {', '.join(metadata_value)}") else: raise RuntimeError(custom_parser_req_msg) return metadata_lines @@ -342,14 +334,10 @@ class CheckpointOutputWrapper(Generic[CT]): elif isinstance(document_or_failure, ConnectorFailure): yield None, document_or_failure, None else: - raise ValueError( - f"Invalid document_or_failure type: {type(document_or_failure)}" - ) + raise ValueError(f"Invalid document_or_failure type: {type(document_or_failure)}") if self.next_checkpoint is None: - raise RuntimeError( - "Checkpoint is None. This should never happen - the connector should always return a checkpoint." - ) + raise RuntimeError("Checkpoint is None. This should never happen - the connector should always return a checkpoint.") yield None, None, self.next_checkpoint @@ -464,4 +452,3 @@ class FingerprintConnector(ABC): content_hash for that key (or when no persisted fingerprint exists). """ raise NotImplementedError - diff --git a/common/data_source/jira/connector.py b/common/data_source/jira/connector.py index aa4082f414..c562cf15e0 100644 --- a/common/data_source/jira/connector.py +++ b/common/data_source/jira/connector.py @@ -128,9 +128,7 @@ class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync try: buffer_value = int(time_buffer_seconds) except (TypeError, ValueError) as exc: - raise ConnectorValidationError( - f"Invalid time_buffer_seconds value ({time_buffer_seconds!r}); expected an integer." - ) from exc + raise ConnectorValidationError(f"Invalid time_buffer_seconds value ({time_buffer_seconds!r}); expected an integer.") from exc self.time_buffer_seconds = max(0, buffer_value) # ------------------------------------------------------------------------- @@ -149,10 +147,7 @@ class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync else: logger.warning("[Jira] Scoped token requested but Jira base URL does not appear to be an Atlassian Cloud domain; scoped token ignored.") - user_email = ( - credentials.get("jira_user_email") - or credentials.get("jira_username") - ) + user_email = credentials.get("jira_user_email") or credentials.get("jira_username") api_token = credentials.get("jira_api_token") or credentials.get("token") or credentials.get("api_token") password = credentials.get("jira_password") or credentials.get("password") rest_api_version = credentials.get("rest_api_version") @@ -963,16 +958,7 @@ def main(config: dict[str, Any] | None = None) -> None: if not base_url: raise RuntimeError("Jira base URL must be provided via config or CLI arguments.") - if not ( - credentials.get("jira_api_token") - or ( - ( - credentials.get("jira_user_email") - or credentials.get("jira_username") - ) - and credentials.get("jira_password") - ) - ): + if not (credentials.get("jira_api_token") or ((credentials.get("jira_user_email") or credentials.get("jira_username")) and credentials.get("jira_password"))): raise RuntimeError("Provide either an API token or both email/password for Jira authentication.") connector_options = { diff --git a/common/data_source/models.py b/common/data_source/models.py index 29cb6bc251..abad772b2b 100644 --- a/common/data_source/models.py +++ b/common/data_source/models.py @@ -1,4 +1,5 @@ """Data model definitions for all connectors""" + from dataclasses import dataclass from datetime import datetime from typing import Any, Optional, List, Sequence, NamedTuple @@ -9,7 +10,6 @@ from enum import Enum @dataclass(frozen=True) class ExternalAccess: - # arbitrary limit to prevent excessively large permissions sets # not internally enforced ... the caller can check this before using the instance MAX_NUM_ENTRIES = 5000 @@ -30,12 +30,7 @@ class ExternalAccess: return f"{s_str[:max_len]}... ({len(s)} items)" return s_str - return ( - f"ExternalAccess(" - f"external_user_emails={truncate_set(self.external_user_emails)}, " - f"external_user_group_ids={truncate_set(self.external_user_group_ids)}, " - f"is_public={self.is_public})" - ) + return f"ExternalAccess(external_user_emails={truncate_set(self.external_user_emails)}, external_user_group_ids={truncate_set(self.external_user_group_ids)}, is_public={self.is_public})" @property def num_entries(self) -> int: @@ -76,18 +71,21 @@ class ExtractionResult(NamedTuple): class TextSection(BaseModel): """Text section model""" + link: str text: str class ImageSection(BaseModel): """Image section model""" + link: str image_file_id: str class Document(BaseModel): """Document model""" + id: str source: str semantic_identifier: str @@ -115,6 +113,7 @@ class KeyRecord(BaseModel): orchestrator only fetches content when the fingerprint differs from what's persisted. """ + key: str fingerprint: Optional[str] = None deleted: bool = False @@ -122,6 +121,7 @@ class KeyRecord(BaseModel): class BasicExpertInfo(BaseModel): """Expert information model""" + display_name: Optional[str] = None first_name: Optional[str] = None last_name: Optional[str] = None @@ -143,29 +143,34 @@ class BasicExpertInfo(BaseModel): class SlimDocument(BaseModel): """Simplified document model (contains only ID and permission info)""" + id: str external_access: Optional[Any] = None class ConnectorCheckpoint(BaseModel): """Connector checkpoint model""" + has_more: bool = True class DocumentFailure(BaseModel): """Document processing failure information""" + document_id: str document_link: str class EntityFailure(BaseModel): """Entity processing failure information""" + entity_id: str missed_time_range: tuple[datetime, datetime] class ConnectorFailure(BaseModel): """Connector failure information""" + failed_document: Optional[DocumentFailure] = None failed_entity: Optional[EntityFailure] = None failure_message: str @@ -177,18 +182,21 @@ class ConnectorFailure(BaseModel): # Gmail Models class GmailCredentials(BaseModel): """Gmail authentication credentials model""" + primary_admin_email: str credentials: dict[str, Any] class GmailThread(BaseModel): """Gmail thread data model""" + id: str messages: list[dict[str, Any]] class GmailMessage(BaseModel): """Gmail message data model""" + id: str payload: dict[str, Any] label_ids: Optional[list[str]] = None @@ -197,6 +205,7 @@ class GmailMessage(BaseModel): # Notion Models class NotionPage(BaseModel): """Represents a Notion Page object""" + id: str created_time: str last_edited_time: str @@ -209,6 +218,7 @@ class NotionPage(BaseModel): class NotionBlock(BaseModel): """Represents a Notion Block object""" + id: str # Used for the URL text: str prefix: str # How this block should be joined with existing text @@ -216,6 +226,7 @@ class NotionBlock(BaseModel): class NotionSearchResponse(BaseModel): """Represents the response from the Notion Search API""" + results: list[dict[str, Any]] next_cursor: Optional[str] has_more: bool = False @@ -223,12 +234,14 @@ class NotionSearchResponse(BaseModel): class NotionCredentials(BaseModel): """Notion authentication credentials model""" + integration_token: str # Slack Models class ChannelTopicPurposeType(TypedDict): """Slack channel topic or purpose""" + value: str creator: str last_set: int @@ -236,6 +249,7 @@ class ChannelTopicPurposeType(TypedDict): class ChannelType(TypedDict): """Slack channel""" + id: str name: str is_channel: bool @@ -264,6 +278,7 @@ class ChannelType(TypedDict): class AttachmentType(TypedDict): """Slack message attachment""" + service_name: NotRequired[str] text: NotRequired[str] fallback: NotRequired[str] @@ -275,6 +290,7 @@ class AttachmentType(TypedDict): class BotProfileType(TypedDict): """Slack bot profile""" + id: NotRequired[str] deleted: NotRequired[bool] name: NotRequired[str] @@ -285,6 +301,7 @@ class BotProfileType(TypedDict): class MessageType(TypedDict): """Slack message""" + type: str user: str text: str @@ -303,6 +320,7 @@ ThreadType = List[MessageType] class SlackCheckpoint(TypedDict): """Slack checkpoint""" + channel_ids: List[str] | None channel_completion_map: dict[str, str] current_channel: ChannelType | None @@ -313,12 +331,14 @@ class SlackCheckpoint(TypedDict): class SlackMessageFilterReason(str): """Slack message filter reason""" + BOT = "bot" DISALLOWED = "disallowed" class ProcessedSlackMessage: """Processed Slack message""" + def __init__(self, doc=None, thread_or_message_ts=None, filter_reason=None, failure=None): self.doc = doc self.thread_or_message_ts = thread_or_message_ts @@ -328,9 +348,12 @@ class ProcessedSlackMessage: class SeafileSyncScope(str, Enum): """Defines how much of SeaFile to synchronise.""" - ACCOUNT = "account" # All libraries the token can see - LIBRARY = "library" # A single library (repo) + + ACCOUNT = "account" # All libraries the token can see + LIBRARY = "library" # A single library (repo) DIRECTORY = "directory" # A single directory inside a library + + # Type aliases for type hints SecondsSinceUnixEpoch = float GenerateDocumentsOutput = Any diff --git a/common/data_source/moodle_connector.py b/common/data_source/moodle_connector.py index 850ce5815d..192155e03a 100644 --- a/common/data_source/moodle_connector.py +++ b/common/data_source/moodle_connector.py @@ -51,9 +51,7 @@ class MoodleConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): delimiter = "&" if "?" in file_url else "?" return f"{file_url}{delimiter}token={token}" - def _log_error( - self, context: str, error: Exception, level: str = "warning" - ) -> None: + def _log_error(self, context: str, error: Exception, level: str = "warning") -> None: """Simplified logging wrapper""" msg = f"{context}: {error}" if level == "error": @@ -65,9 +63,7 @@ class MoodleConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """Return latest valid timestamp""" return max((t for t in timestamps if t and t > 0), default=0) - def _yield_in_batches( - self, generator: Generator[Document, None, None] - ) -> Generator[list[Document], None, None]: + def _yield_in_batches(self, generator: Generator[Document, None, None]) -> Generator[list[Document], None, None]: for batch in batch_generator(generator, self.batch_size): yield batch @@ -77,16 +73,12 @@ class MoodleConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): raise ConnectorMissingCredentialError("Moodle API token is required") try: - self.moodle_client = MoodleClient( - self.moodle_url + "/webservice/rest/server.php", token - ) + self.moodle_client = MoodleClient(self.moodle_url + "/webservice/rest/server.php", token) self.moodle_client.core.webservice.get_site_info() except MoodleException as e: if "invalidtoken" in str(e).lower(): raise CredentialExpiredError("Moodle token is invalid or expired") - raise ConnectorMissingCredentialError( - f"Failed to initialize Moodle client: {e}" - ) + raise ConnectorMissingCredentialError(f"Failed to initialize Moodle client: {e}") def validate_connector_settings(self) -> None: if not self.moodle_client: @@ -101,9 +93,7 @@ class MoodleConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): if "invalidtoken" in msg: raise CredentialExpiredError("Moodle token is invalid or expired") if "accessexception" in msg: - raise InsufficientPermissionsError( - "Insufficient permissions. Ensure web services are enabled and permissions are correct." - ) + raise InsufficientPermissionsError("Insufficient permissions. Ensure web services are enabled and permissions are correct.") raise ConnectorValidationError(f"Moodle validation error: {e}") except Exception as e: raise ConnectorValidationError(f"Unexpected validation error: {e}") @@ -124,23 +114,17 @@ class MoodleConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): yield from self._yield_in_batches(self._process_courses(courses)) - def poll_source( - self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch - ) -> Generator[list[Document], None, None]: + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Generator[list[Document], None, None]: if not self.moodle_client: raise ConnectorMissingCredentialError("Moodle client not initialized") - logger.info( - f"Polling Moodle updates between {datetime.fromtimestamp(start)} and {datetime.fromtimestamp(end)}" - ) + logger.info(f"Polling Moodle updates between {datetime.fromtimestamp(start)} and {datetime.fromtimestamp(end)}") courses = self._get_enrolled_courses() if not courses: logger.warning("No courses found to poll") return - yield from self._yield_in_batches( - self._get_updated_content(courses, start, end) - ) + yield from self._yield_in_batches(self._get_updated_content(courses, start, end)) @staticmethod def _slim_doc_id_for_module(module) -> Optional[str]: @@ -248,9 +232,7 @@ class MoodleConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): except Exception as e: self._log_error(f"processing course {course.fullname}", e) - def _get_updated_content( - self, courses, start: float, end: float - ) -> Generator[Document, None, None]: + def _get_updated_content(self, courses, start: float, end: float) -> Generator[Document, None, None]: for course in courses: try: contents = self._get_course_contents(course.id) @@ -261,11 +243,7 @@ class MoodleConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): getattr(module, "timemodified", 0), ] if hasattr(module, "contents"): - times.extend( - getattr(c, "timemodified", 0) - for c in module.contents - if c and getattr(c, "timemodified", 0) - ) + times.extend(getattr(c, "timemodified", 0) for c in module.contents if c and getattr(c, "timemodified", 0)) last_mod = self._get_latest_timestamp(*times) if start < last_mod <= end: doc = self._process_module(course, section, module) @@ -309,9 +287,7 @@ class MoodleConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): ) try: - resp = rl_requests.get( - self._add_token_to_url(file_info.fileurl), timeout=60 - ) + resp = rl_requests.get(self._add_token_to_url(file_info.fileurl), timeout=60) resp.raise_for_status() blob = resp.content ext = os.path.splitext(file_name)[1] or ".bin" @@ -359,9 +335,7 @@ class MoodleConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): return None try: - result = self.moodle_client.mod.forum.get_forum_discussions( - forumid=module.instance - ) + result = self.moodle_client.mod.forum.get_forum_discussions(forumid=module.instance) disc_list = getattr(result, "discussions", []) if not disc_list: return None @@ -440,9 +414,7 @@ class MoodleConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): ) try: - resp = rl_requests.get( - self._add_token_to_url(file_info.fileurl), timeout=60 - ) + resp = rl_requests.get(self._add_token_to_url(file_info.fileurl), timeout=60) resp.raise_for_status() blob = resp.content ext = os.path.splitext(file_name)[1] or ".html" @@ -539,12 +511,7 @@ class MoodleConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): return None contents = module.contents - chapters = [ - c - for c in contents - if getattr(c, "fileurl", None) - and os.path.basename(c.filename) == "index.html" - ] + chapters = [c for c in contents if getattr(c, "fileurl", None) and os.path.basename(c.filename) == "index.html"] if not chapters: return None diff --git a/common/data_source/notion_connector.py b/common/data_source/notion_connector.py index ea3d6d0764..2972809294 100644 --- a/common/data_source/notion_connector.py +++ b/common/data_source/notion_connector.py @@ -457,9 +457,7 @@ class NotionConnector(LoadConnector, PollConnector): if result_type == "child_page": child_pages.append(result_block_id) else: - nested_child_pages, nested_attachment_ids = self._read_slim_blocks( - result_block_id - ) + nested_child_pages, nested_attachment_ids = self._read_slim_blocks(result_block_id) child_pages.extend(nested_child_pages) attachment_ids.extend(nested_attachment_ids) @@ -566,7 +564,13 @@ class NotionConnector(LoadConnector, PollConnector): joined_text = "\n".join(sec.text for sec in sections) blob = joined_text.encode("utf-8") yield Document( - id=page.id, blob=blob, source=DocumentSource.NOTION, semantic_identifier=semantic_identifier, extension=".txt", size_bytes=len(blob), doc_updated_at=datetime_from_string(page.last_edited_time) + id=page.id, + blob=blob, + source=DocumentSource.NOTION, + semantic_identifier=semantic_identifier, + extension=".txt", + size_bytes=len(blob), + doc_updated_at=datetime_from_string(page.last_edited_time), ) for attachment_doc in attachment_docs: @@ -616,11 +620,7 @@ class NotionConnector(LoadConnector, PollConnector): if self.recursive_index_enabled and all_child_page_ids: for child_page_batch_ids in batch_generator(all_child_page_ids, INDEX_BATCH_SIZE): - child_page_batch = [ - self._fetch_page(page_id) - for page_id in child_page_batch_ids - if page_id not in slim_indexed_pages - ] + child_page_batch = [self._fetch_page(page_id) for page_id in child_page_batch_ids if page_id not in slim_indexed_pages] yield from self._read_pages_for_slim_docs( child_page_batch, slim_indexed_pages, diff --git a/common/data_source/onedrive_connector.py b/common/data_source/onedrive_connector.py index ef5353c919..0d2a614595 100644 --- a/common/data_source/onedrive_connector.py +++ b/common/data_source/onedrive_connector.py @@ -27,8 +27,16 @@ _GRAPH_SCOPE = ["https://graph.microsoft.com/.default"] # File extensions we support for ingestion _SUPPORTED_EXTENSIONS = { - ".pdf", ".docx", ".doc", ".xlsx", ".xls", - ".pptx", ".ppt", ".txt", ".md", ".csv", + ".pdf", + ".docx", + ".doc", + ".xlsx", + ".xls", + ".pptx", + ".ppt", + ".txt", + ".md", + ".csv", } @@ -49,6 +57,7 @@ def _normalize_folder_path(folder_path: str | None) -> str | None: class OneDriveCheckpoint(ConnectorCheckpoint): """OneDrive-specific checkpoint tracking delta links per drive.""" + delta_links: dict[str, str] | None = None @@ -81,9 +90,7 @@ class OneDriveConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPerm client_secret = credentials.get("client_secret") if not all([tenant_id, client_id, client_secret]): - raise ConnectorMissingCredentialError( - "OneDrive credentials are incomplete (tenant_id, client_id, client_secret required)" - ) + raise ConnectorMissingCredentialError("OneDrive credentials are incomplete (tenant_id, client_id, client_secret required)") self._tenant_id = tenant_id @@ -96,9 +103,7 @@ class OneDriveConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPerm if "access_token" not in result: error = result.get("error_description", result.get("error", "unknown")) - raise ConnectorMissingCredentialError( - f"Failed to acquire OneDrive access token: {error}" - ) + raise ConnectorMissingCredentialError(f"Failed to acquire OneDrive access token: {error}") self._access_token = result["access_token"] return None @@ -115,24 +120,15 @@ class OneDriveConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPerm # Requires Files.Read.All. resp = self._get(f"{_GRAPH_BASE}/drives?$top=1") if resp.status_code == 401: - raise ConnectorMissingCredentialError( - "OneDrive access token is invalid or expired." - ) + raise ConnectorMissingCredentialError("OneDrive access token is invalid or expired.") if resp.status_code == 403: - raise InsufficientPermissionsError( - "The service principal lacks the 'Files.Read.All' permission " - "required by the OneDrive connector." - ) + raise InsufficientPermissionsError("The service principal lacks the 'Files.Read.All' permission required by the OneDrive connector.") if not resp.ok: - raise UnexpectedValidationError( - f"OneDrive validation failed (HTTP {resp.status_code}): {resp.text[:200]}" - ) + raise UnexpectedValidationError(f"OneDrive validation failed (HTTP {resp.status_code}): {resp.text[:200]}") data = resp.json() if "value" not in data: - raise ConnectorValidationError( - "Unexpected response format from Microsoft Graph /drives." - ) + raise ConnectorValidationError("Unexpected response format from Microsoft Graph /drives.") # ------------------------------------------------------------------ # Checkpoint helpers @@ -151,9 +147,7 @@ class OneDriveConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPerm # Core data loading # ------------------------------------------------------------------ - def poll_source( - self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch - ) -> Any: + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any: """Return documents modified at or after *start* (epoch seconds). Kept for callers that prefer the time-window interface; internally @@ -252,15 +246,11 @@ class OneDriveConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPerm url, body_snippet, ) - raise UnexpectedValidationError( - f"OneDrive Graph request failed ({context}): HTTP {resp.status_code} {body_snippet}" - ) + raise UnexpectedValidationError(f"OneDrive Graph request failed ({context}): HTTP {resp.status_code} {body_snippet}") try: return resp.json() except ValueError as exc: - raise UnexpectedValidationError( - f"OneDrive Graph response is not JSON ({context}): {exc}" - ) + raise UnexpectedValidationError(f"OneDrive Graph response is not JSON ({context}): {exc}") def _list_drive_ids(self) -> list[str]: """Return all drive IDs visible to the service principal.""" @@ -324,9 +314,7 @@ class OneDriveConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPerm modified_ts: float | None = None if modified_str: try: - dt = datetime.fromisoformat( - modified_str.replace("Z", "+00:00") - ) + dt = datetime.fromisoformat(modified_str.replace("Z", "+00:00")) modified_ts = dt.timestamp() except ValueError: pass @@ -335,11 +323,7 @@ class OneDriveConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPerm if since_epoch and modified_ts and modified_ts < since_epoch: continue - doc_updated_at = ( - datetime.fromtimestamp(modified_ts, tz=timezone.utc) - if modified_ts - else datetime.now(timezone.utc) - ) + doc_updated_at = datetime.fromtimestamp(modified_ts, tz=timezone.utc) if modified_ts else datetime.now(timezone.utc) doc = Document( id=item["id"], source="onedrive", @@ -351,11 +335,7 @@ class OneDriveConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPerm metadata={ "drive_id": drive_id, "web_url": item.get("webUrl", ""), - "created_by": ( - item.get("createdBy", {}) - .get("user", {}) - .get("displayName", "") - ), + "created_by": (item.get("createdBy", {}).get("user", {}).get("displayName", "")), }, ) batch.append(doc) diff --git a/common/data_source/outlook_connector.py b/common/data_source/outlook_connector.py index 395f03c31a..08aa210f52 100644 --- a/common/data_source/outlook_connector.py +++ b/common/data_source/outlook_connector.py @@ -58,6 +58,7 @@ def _redact(value: str | None) -> str: class OutlookCheckpoint(ConnectorCheckpoint): """Outlook-specific checkpoint tracking delta links per user mailbox.""" + delta_links: dict[str, str] | None = None @@ -129,10 +130,7 @@ class OutlookConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermS client_secret = credentials.get("client_secret") if not all([tenant_id, client_id, client_secret]): - raise ConnectorMissingCredentialError( - "Outlook credentials are incomplete (tenant_id, client_id, " - "client_secret required)" - ) + raise ConnectorMissingCredentialError("Outlook credentials are incomplete (tenant_id, client_id, client_secret required)") self._tenant_id = tenant_id @@ -145,9 +143,7 @@ class OutlookConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermS if "access_token" not in result: error = result.get("error_description", result.get("error", "unknown")) - raise ConnectorMissingCredentialError( - f"Failed to acquire Outlook access token: {error}" - ) + raise ConnectorMissingCredentialError(f"Failed to acquire Outlook access token: {error}") self._access_token = result["access_token"] return None @@ -161,32 +157,17 @@ class OutlookConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermS raise ConnectorMissingCredentialError("Outlook") # Probe: list one user (or check explicit user mailbox). - probe_url = ( - f"{_GRAPH_BASE}/users/{self.user_ids[0]}" - if self.user_ids - else f"{_GRAPH_BASE}/users?$top=1" - ) + probe_url = f"{_GRAPH_BASE}/users/{self.user_ids[0]}" if self.user_ids else f"{_GRAPH_BASE}/users?$top=1" resp = self._get(probe_url) if resp.status_code == 401: - raise ConnectorMissingCredentialError( - "Outlook access token is invalid or expired." - ) + raise ConnectorMissingCredentialError("Outlook access token is invalid or expired.") if resp.status_code == 403: - raise InsufficientPermissionsError( - "The service principal lacks the 'Mail.Read' (and possibly " - "'User.Read.All') permission required by the Outlook connector." - ) + raise InsufficientPermissionsError("The service principal lacks the 'Mail.Read' (and possibly 'User.Read.All') permission required by the Outlook connector.") if resp.status_code == 404 and self.user_ids: - raise ConnectorValidationError( - f"Configured Outlook mailbox '{self.user_ids[0]}' does not exist " - "in this tenant." - ) + raise ConnectorValidationError(f"Configured Outlook mailbox '{self.user_ids[0]}' does not exist in this tenant.") if not resp.ok: - raise UnexpectedValidationError( - f"Outlook validation failed (HTTP {resp.status_code}): " - f"{resp.text[:200]}" - ) + raise UnexpectedValidationError(f"Outlook validation failed (HTTP {resp.status_code}): {resp.text[:200]}") # ------------------------------------------------------------------ # Checkpoint helpers @@ -205,9 +186,7 @@ class OutlookConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermS # Core data loading # ------------------------------------------------------------------ - def poll_source( - self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch - ) -> Any: + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any: """Return messages received at or after *start* (epoch seconds). Kept for callers that prefer the time-window interface; internally @@ -305,16 +284,11 @@ class OutlookConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermS resp.status_code, body_snippet, ) - raise UnexpectedValidationError( - f"Outlook Graph request failed ({context}): " - f"HTTP {resp.status_code} {body_snippet}" - ) + raise UnexpectedValidationError(f"Outlook Graph request failed ({context}): HTTP {resp.status_code} {body_snippet}") try: return resp.json() except ValueError as exc: - raise UnexpectedValidationError( - f"Outlook Graph response is not JSON ({context}): {exc}" - ) + raise UnexpectedValidationError(f"Outlook Graph response is not JSON ({context}): {exc}") def _list_user_ids(self) -> list[str]: """Return mailbox identifiers to sync.""" @@ -335,14 +309,9 @@ class OutlookConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermS def _delta_url(self, user_id: str, delta_link: str | None = None) -> str: if delta_link: return delta_link - return ( - f"{_GRAPH_BASE}/users/{user_id}/mailFolders/" - f"{self.folder}/messages/delta" - ) + return f"{_GRAPH_BASE}/users/{user_id}/mailFolders/{self.folder}/messages/delta" - def _message_to_document( - self, msg: dict[str, Any], user_id: str - ) -> Document | None: + def _message_to_document(self, msg: dict[str, Any], user_id: str) -> Document | None: subject: str = msg.get("subject") or "(no subject)" body_obj = msg.get("body") or {} @@ -357,25 +326,13 @@ class OutlookConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermS received_dt: datetime | None = None if received_str: try: - received_dt = datetime.fromisoformat( - received_str.replace("Z", "+00:00") - ) + received_dt = datetime.fromisoformat(received_str.replace("Z", "+00:00")) except ValueError: pass - from_addr = ( - msg.get("from", {}).get("emailAddress", {}) if msg.get("from") else {} - ) - to_recipients: list[str] = [ - r.get("emailAddress", {}).get("address", "") - for r in (msg.get("toRecipients") or []) - if r.get("emailAddress", {}).get("address") - ] - cc_recipients: list[str] = [ - r.get("emailAddress", {}).get("address", "") - for r in (msg.get("ccRecipients") or []) - if r.get("emailAddress", {}).get("address") - ] + from_addr = msg.get("from", {}).get("emailAddress", {}) if msg.get("from") else {} + to_recipients: list[str] = [r.get("emailAddress", {}).get("address", "") for r in (msg.get("toRecipients") or []) if r.get("emailAddress", {}).get("address")] + cc_recipients: list[str] = [r.get("emailAddress", {}).get("address", "") for r in (msg.get("ccRecipients") or []) if r.get("emailAddress", {}).get("address")] header_lines = [ f"From: {from_addr.get('name', '')} <{from_addr.get('address', '')}>", @@ -436,9 +393,7 @@ class OutlookConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermS next_delta: str | None = None while url: - data = self._get_json( - url, context=f"delta user={_redact(user_id)}" - ) + data = self._get_json(url, context=f"delta user={_redact(user_id)}") for msg in data.get("value", []): # Skip removed/deleted messages signalled by delta semantics @@ -449,9 +404,7 @@ class OutlookConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermS received_ts: float | None = None if received_str: try: - received_ts = datetime.fromisoformat( - received_str.replace("Z", "+00:00") - ).timestamp() + received_ts = datetime.fromisoformat(received_str.replace("Z", "+00:00")).timestamp() except ValueError: pass diff --git a/common/data_source/rdbms_connector.py b/common/data_source/rdbms_connector.py index 4d38d303f1..33551fb429 100644 --- a/common/data_source/rdbms_connector.py +++ b/common/data_source/rdbms_connector.py @@ -25,6 +25,7 @@ from common.data_source.models import Document, SlimDocument class DatabaseType(str, Enum): """Supported database types.""" + MYSQL = "mysql" POSTGRESQL = "postgresql" MSSQL = "mssql" @@ -43,6 +44,7 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): 6. For incremental sync, treat the timestamp column as an ordered cursor and only compare values by size. 7. For deleted-file sync, read a slim snapshot of current row IDs and let the sync worker remove stale documents. """ + def __init__( self, db_type: str, @@ -58,7 +60,7 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): ) -> None: """ Initialize the RDBMS connector. - + Args: db_type: Database type ('mysql', 'postgresql', or 'mssql') host: Database host @@ -83,7 +85,7 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): self.id_column = id_column.strip() if id_column else None self.timestamp_column = timestamp_column.strip() if timestamp_column else None self.batch_size = batch_size - + self._connection = None self._credentials: Dict[str, Any] = {} self._sync_connector_id: str | None = None @@ -131,12 +133,12 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None: """Load database credentials.""" logging.debug(f"Loading credentials for {self.db_type} database: {self.database}") - + required_keys = ["username", "password"] for key in required_keys: if not credentials.get(key): raise ConnectorMissingCredentialError(f"RDBMS ({self.db_type}): missing {key}") - + self._credentials = credentials return None @@ -144,17 +146,15 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """Create and return a database connection.""" if self._connection is not None: return self._connection - + username = self._credentials.get("username") password = self._credentials.get("password") - + if self.db_type == DatabaseType.MYSQL: try: import mysql.connector except ImportError: - raise ConnectorValidationError( - "MySQL connector not installed. Please install mysql-connector-python." - ) + raise ConnectorValidationError("MySQL connector not installed. Please install mysql-connector-python.") try: self._connection = mysql.connector.connect( host=self.host, @@ -162,7 +162,7 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): database=self.database, user=username, password=password, - charset='utf8mb4', + charset="utf8mb4", use_unicode=True, ) except Exception as e: @@ -171,9 +171,7 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): try: import psycopg2 except ImportError: - raise ConnectorValidationError( - "PostgreSQL connector not installed. Please install psycopg2-binary." - ) + raise ConnectorValidationError("PostgreSQL connector not installed. Please install psycopg2-binary.") try: self._connection = psycopg2.connect( host=self.host, @@ -188,9 +186,7 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): try: import pymssql except ImportError: - raise ConnectorValidationError( - "pymssql not installed. Please install pymssql." - ) + raise ConnectorValidationError("pymssql not installed. Please install pymssql.") try: self._connection = pymssql.connect( server=self.host, @@ -218,26 +214,19 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """Get list of all tables in the database.""" connection = self._get_connection() cursor = connection.cursor() - + try: if self.db_type == DatabaseType.MYSQL: cursor.execute("SHOW TABLES") elif self.db_type == DatabaseType.MSSQL: - cursor.execute( - "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " - "WHERE TABLE_TYPE = 'BASE TABLE'" - ) + cursor.execute("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE'") else: - cursor.execute( - "SELECT table_name FROM information_schema.tables " - "WHERE table_schema = 'public' AND table_type = 'BASE TABLE'" - ) + cursor.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE'") tables = [row[0] for row in cursor.fetchall()] return tables finally: cursor.close() - def _get_base_queries(self) -> list[str]: """Return the list of base SQL queries to execute. @@ -249,7 +238,6 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): return [self.query.rstrip(";")] return [f"SELECT * FROM {table}" for table in self._get_tables()] - @staticmethod def _strip_trailing_order_by(query: str) -> str: """Remove a trailing top-level ORDER BY clause. @@ -275,7 +263,6 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): inner = self._strip_trailing_order_by(base_query) return f"SELECT {select_clause} FROM ({inner}) AS ragflow_src" - @staticmethod def serialize_cursor_value(value: Any) -> Any: """Serialize a cursor value to a JSON-safe representation. @@ -291,7 +278,6 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): } return value - @staticmethod def deserialize_cursor_value(value: Any) -> Any: """Deserialize a cursor value produced by :meth:`serialize_cursor_value`. @@ -299,14 +285,10 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): Recognises the ``__ragflow_rdbms_cursor_type__`` wrapper and converts it back to a ``datetime``. Any other value is returned unchanged. """ - if ( - isinstance(value, dict) - and value.get("__ragflow_rdbms_cursor_type__") == "datetime" - ): + if isinstance(value, dict) and value.get("__ragflow_rdbms_cursor_type__") == "datetime": return datetime.fromisoformat(value["value"]) return value - def _format_sql_value(self, value: Any) -> str: """Format a Python value as a SQL literal suitable for embedding in a WHERE clause. @@ -330,10 +312,7 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): return str(value) if isinstance(value, str): return "'" + value.replace("'", "''") + "'" - raise ConnectorValidationError( - f"Unsupported timestamp cursor value type: {type(value).__name__}" - ) - + raise ConnectorValidationError(f"Unsupported timestamp cursor value type: {type(value).__name__}") def _build_time_filtered_query( self, @@ -358,27 +337,18 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): conditions = [] if start is not None: - conditions.append( - f"ragflow_src.{self.timestamp_column} >= {self._format_sql_value(start)}" - ) + conditions.append(f"ragflow_src.{self.timestamp_column} >= {self._format_sql_value(start)}") if end is not None: - conditions.append( - f"ragflow_src.{self.timestamp_column} <= {self._format_sql_value(end)}" - ) + conditions.append(f"ragflow_src.{self.timestamp_column} <= {self._format_sql_value(end)}") query = self._wrap_query(base_query) if conditions: query = f"{query} WHERE {' AND '.join(conditions)}" return query - def _build_max_timestamp_query(self, base_query: str) -> str: """Build a query that returns the maximum value of the timestamp column.""" - return ( - f"SELECT MAX(ragflow_src.{self.timestamp_column}) " - f"FROM ({base_query}) AS ragflow_src" - ) - + return f"SELECT MAX(ragflow_src.{self.timestamp_column}) FROM ({base_query}) AS ragflow_src" def _build_slim_query(self, base_query: str) -> str: """Build a lightweight query that fetches only the columns needed to identify documents. @@ -395,7 +365,6 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): select_clause = ", ".join(f"ragflow_src.{column}" for column in columns) return self._wrap_query(base_query, select_clause) - def _build_content(self, row_dict: Dict[str, Any]) -> str: """Build the document content string from the resolved content columns of a row.""" content_parts = [] @@ -408,7 +377,6 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): content_parts.append(f"【{col}】:\n{value}") return "\n\n".join(content_parts) - def _build_document_id_from_row(self, row_dict: Dict[str, Any]) -> str: """Derive a stable document id from a database row. @@ -421,7 +389,6 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): content_hash = hashlib.md5(content.encode()).hexdigest() return f"{self.db_type}:{self.database}:{content_hash}" - def _row_to_document( self, row: Union[tuple, list, Dict[str, Any]], @@ -456,12 +423,7 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): resolved_content_columns = self._content_columns_for_row(row_dict) first_content_col = resolved_content_columns[0] if resolved_content_columns else "record" - semantic_id = ( - str(row_dict.get(first_content_col, "database_record")) - .replace("\n", " ") - .replace("\r", " ") - .strip()[:100] - ) + semantic_id = str(row_dict.get(first_content_col, "database_record")).replace("\n", " ").replace("\r", " ").strip()[:100] blob = content.encode("utf-8") return Document( @@ -475,7 +437,6 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): metadata=metadata if metadata else None, ) - def _yield_documents_from_query( self, query: str, @@ -483,28 +444,28 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """Generate documents from a single query.""" connection = self._get_connection() cursor = connection.cursor() - + try: logging.info(f"Executing query: {query[:200]}...") cursor.execute(query) column_names = [desc[0] for desc in cursor.description] - + batch: list[Document] = [] for row in cursor: try: doc = self._row_to_document(row, column_names) batch.append(doc) - + if len(batch) >= self.batch_size: yield batch batch = [] except Exception as e: logging.warning(f"Error converting row to document: {e}") continue - + if batch: yield batch - + finally: try: cursor.fetchall() @@ -512,7 +473,6 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): pass cursor.close() - def _yield_slim_documents_from_query( self, query: str, @@ -547,7 +507,6 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): pass cursor.close() - def get_max_cursor_value(self) -> Any: """Return the maximum value of the timestamp column across all base queries. @@ -577,7 +536,6 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): return max_cursor_value - def _yield_documents( self, start: Any = None, @@ -595,13 +553,11 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): finally: self._close_connection() - def load_from_state(self) -> Generator[list[Document], None, None]: """Load all documents from the database (full sync).""" logging.debug(f"Loading all records from {self.db_type} database: {self.database}") return self._yield_documents() - def retrieve_all_slim_docs_perm_sync( self, callback: Any = None, @@ -615,9 +571,7 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): try: for base_query in base_queries: - yield from self._yield_slim_documents_from_query( - self._build_slim_query(base_query) - ) + yield from self._yield_slim_documents_from_query(self._build_slim_query(base_query)) finally: self._close_connection() @@ -635,14 +589,12 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): return self._pending_sync_cursor_value = self.get_max_cursor_value() - def get_saved_sync_cursor_value(self) -> Any: """Return the cursor value that was persisted at the end of the previous sync run.""" if self._sync_config is None: return None return self.deserialize_cursor_value(self._sync_config.get("sync_cursor_value")) - def persist_sync_state(self) -> None: """Write the pending cursor value back to the connector config in the database. @@ -655,13 +607,10 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): from api.db.services.connector_service import ConnectorService updated_conf = copy.deepcopy(self._sync_config) - updated_conf["sync_cursor_value"] = self.serialize_cursor_value( - self._pending_sync_cursor_value - ) + updated_conf["sync_cursor_value"] = self.serialize_cursor_value(self._pending_sync_cursor_value) ConnectorService.update_by_id(self._sync_connector_id, {"config": updated_conf}) self._sync_config = updated_conf - def load_from_cursor_range( self, start_value: Any = None, @@ -681,28 +630,21 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): return iter(()) return self._yield_documents(start_value, end_value) - - def poll_source( - self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch - ) -> Generator[list[Document], None, None]: + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Generator[list[Document], None, None]: """Poll for new/updated documents since the last sync (incremental sync).""" if not self.timestamp_column: - logging.warning( - "No timestamp column configured for incremental sync. " - "Falling back to full sync." - ) + logging.warning("No timestamp column configured for incremental sync. Falling back to full sync.") return self.load_from_state() return self._yield_documents(start, end) - def validate_connector_settings(self) -> None: """Validate connector settings by testing the connection.""" if not self._credentials: raise ConnectorMissingCredentialError("RDBMS credentials not loaded.") - + if not self.host: raise ConnectorValidationError("Database host is required.") - + if not self.database: raise ConnectorValidationError("Database name is required.") @@ -712,34 +654,32 @@ class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): try: connection = self._get_connection() cursor = connection.cursor() - + test_query = "SELECT 1" cursor.execute(test_query) cursor.fetchone() cursor.close() - + logging.info(f"Successfully connected to {self.db_type} database: {self.database}") - + except ConnectorValidationError: self._close_connection() raise except Exception as e: self._close_connection() - raise ConnectorValidationError( - f"Failed to connect to {self.db_type} database: {str(e)}" - ) + raise ConnectorValidationError(f"Failed to connect to {self.db_type} database: {str(e)}") finally: self._close_connection() if __name__ == "__main__": import os - + credentials_dict = { "username": os.environ.get("DB_USERNAME", "root"), "password": os.environ.get("DB_PASSWORD", ""), } - + connector = RDBMSConnector( db_type="mysql", host=os.environ.get("DB_HOST", "localhost"), @@ -751,16 +691,16 @@ if __name__ == "__main__": id_column="id", timestamp_column="updated_at", ) - + try: connector.load_credentials(credentials_dict) connector.validate_connector_settings() - + for batch in connector.load_from_state(): print(f"Batch of {len(batch)} documents:") for doc in batch: print(f" - {doc.id}: {doc.semantic_identifier}") break - + except Exception as e: print(f"Error: {e}") diff --git a/common/data_source/rest_api_connector.py b/common/data_source/rest_api_connector.py index f7196114c8..91d905089f 100644 --- a/common/data_source/rest_api_connector.py +++ b/common/data_source/rest_api_connector.py @@ -42,7 +42,7 @@ try: except Exception: # pragma: no cover _jsonpath = None -_FIELD_SEGMENT_RE = re.compile(r'^(?P[^\[\]]+)(\[(?P\d+|\*)\])?$') +_FIELD_SEGMENT_RE = re.compile(r"^(?P[^\[\]]+)(\[(?P\d+|\*)\])?$") _DEFAULT_MAX_PAGES = 1000 _REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308}) _MAX_REDIRECTS = 5 @@ -217,17 +217,8 @@ class RestAPIConnector(LoadConnector, PollConnector): logger.debug("Skipping non-IP address resolved from %r: %r", hostname, ip_str) continue - if ( - ip_obj.is_loopback - or ip_obj.is_private - or ip_obj.is_link_local - or ip_obj.is_reserved - or ip_obj.is_multicast - ): - msg = ( - f"REST API connector URL {url!r} resolves to disallowed address {ip_str} " - "(localhost, private, link-local, reserved, or multicast addresses are blocked)." - ) + if ip_obj.is_loopback or ip_obj.is_private or ip_obj.is_link_local or ip_obj.is_reserved or ip_obj.is_multicast: + msg = f"REST API connector URL {url!r} resolves to disallowed address {ip_str} (localhost, private, link-local, reserved, or multicast addresses are blocked)." logger.warning(msg) raise ConnectorValidationError(msg) @@ -266,14 +257,10 @@ class RestAPIConnector(LoadConnector, PollConnector): for k, v_list in parse_qs(parsed.query, keep_blank_values=True).items(): self._url_params[k] = v_list[-1] - self._explicit_query_params: Dict[str, str] = ( - _text_to_dict(query_params) if isinstance(query_params, str) else (query_params or {}) - ) + self._explicit_query_params: Dict[str, str] = _text_to_dict(query_params) if isinstance(query_params, str) else (query_params or {}) self.url = self._base_url self.method = (method or "GET").upper() - self._base_headers: Dict[str, str] = ( - _text_to_dict(headers) if isinstance(headers, str) else (headers or {}) - ) + self._base_headers: Dict[str, str] = _text_to_dict(headers) if isinstance(headers, str) else (headers or {}) self.auth_type = auth_type or AuthType.NONE self.auth_config: Dict[str, Any] = auth_config or {} self.items_path = items_path @@ -282,10 +269,7 @@ class RestAPIConnector(LoadConnector, PollConnector): self.metadata_fields: List[str] = metadata_fields or [] self.pagination_type = pagination_type or PaginationType.NONE self.pagination_config: Dict[str, Any] = pagination_config or {} - self._static_request_body: Dict[str, Any] = ( - request_body if request_body is not None - else self.pagination_config.get("request_body") or {} - ) + self._static_request_body: Dict[str, Any] = request_body if request_body is not None else self.pagination_config.get("request_body") or {} self.poll_timestamp_field = poll_timestamp_field self.batch_size = batch_size self.max_pages = max_pages @@ -320,21 +304,16 @@ class RestAPIConnector(LoadConnector, PollConnector): if self.auth_type == AuthType.API_KEY_HEADER: header_name = self.auth_config.get("header_name") - api_key = ( - self._credentials.get("api_key") - or self.auth_config.get("api_key_value") - or self.auth_config.get("api_key") - ) + api_key = self._credentials.get("api_key") or self.auth_config.get("api_key_value") or self.auth_config.get("api_key") if not header_name or not api_key: logging.warning( - "REST API auth setup failed: header_name=%s, api_key present=%s, " - "credentials keys=%s, auth_config keys=%s", - header_name, bool(api_key), - list(self._credentials.keys()), list(self.auth_config.keys()), - ) - raise ConnectorMissingCredentialError( - "REST API (api_key_header) requires 'header_name' in auth_config and 'api_key' in credentials" + "REST API auth setup failed: header_name=%s, api_key present=%s, credentials keys=%s, auth_config keys=%s", + header_name, + bool(api_key), + list(self._credentials.keys()), + list(self.auth_config.keys()), ) + raise ConnectorMissingCredentialError("REST API (api_key_header) requires 'header_name' in auth_config and 'api_key' in credentials") self._auth_headers[header_name] = str(api_key) logging.info("REST API auth configured: header '%s' set.", header_name) return @@ -455,15 +434,10 @@ class RestAPIConnector(LoadConnector, PollConnector): """Full fetch with pagination.""" return self._yield_documents(time_window=None) - def poll_source( - self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch - ) -> Generator[List[Document], None, None]: + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Generator[List[Document], None, None]: """Incremental fetch; filters by ``poll_timestamp_field`` if configured.""" if not self.poll_timestamp_field: - logging.warning( - "poll_source called without poll_timestamp_field; " - "falling back to full fetch with in-memory filtering." - ) + logging.warning("poll_source called without poll_timestamp_field; falling back to full fetch with in-memory filtering.") return self._yield_documents( time_window=( datetime.fromtimestamp(start, tz=timezone.utc), @@ -585,7 +559,10 @@ class RestAPIConnector(LoadConnector, PollConnector): yield item @retry_builder( - tries=5, delay=1, max_delay=30, backoff=2, + tries=5, + delay=1, + max_delay=30, + backoff=2, exceptions=(requests.ConnectionError, requests.Timeout, requests.HTTPError), ) def _fetch_page(self, params: Dict[str, Any]) -> Any: @@ -601,7 +578,8 @@ class RestAPIConnector(LoadConnector, PollConnector): sensitive = {"authorization", "apikey", "api-key", "x-api-key"} logging.debug( "REST API request: %s %s | params=%s | headers=%s", - self.method, url, + self.method, + url, {k: ("***" if k.lower() in sensitive else v) for k, v in query_params.items()}, {k: ("***" if k.lower() in sensitive else v) for k, v in headers.items()}, ) @@ -631,16 +609,15 @@ class RestAPIConnector(LoadConnector, PollConnector): if status in (401, 403): sensitive = {"authorization", "apikey", "api-key", "x-api-key"} logging.warning( - "REST API %d for %s %s | auth_type=%s | " - "request header keys=%s | auth_header keys=%s", - status, self.method, resp.url, + "REST API %d for %s %s | auth_type=%s | request header keys=%s | auth_header keys=%s", + status, + self.method, + resp.url, self.auth_type, [k for k in headers], [k for k in self._auth_headers], ) - raise ConnectorMissingCredentialError( - f"REST API authentication failed with status {status}" - ) from exc + raise ConnectorMissingCredentialError(f"REST API authentication failed with status {status}") from exc if status is not None and 400 <= status < 500 and status != 429: logging.warning( "REST API client error %d for %s %s; not retrying.", @@ -648,9 +625,7 @@ class RestAPIConnector(LoadConnector, PollConnector): self.method, resp.url, ) - raise ConnectorValidationError( - f"REST API request failed with non-retriable client error status {status}" - ) from exc + raise ConnectorValidationError(f"REST API request failed with non-retriable client error status {status}") from exc raise try: @@ -660,14 +635,16 @@ class RestAPIConnector(LoadConnector, PollConnector): # Headers that carry auth state. Stripped on cross-origin redirects to # prevent credential exfiltration to a third-party host. (Coderabbit MAJOR #3486038792) - _AUTH_SENSITIVE_HEADER_KEYS = frozenset({ - "authorization", - "proxy-authorization", - "apikey", - "api-key", - "x-api-key", - "x-auth-token", - }) + _AUTH_SENSITIVE_HEADER_KEYS = frozenset( + { + "authorization", + "proxy-authorization", + "apikey", + "api-key", + "x-api-key", + "x-auth-token", + } + ) def _safe_request( self, @@ -697,9 +674,7 @@ class RestAPIConnector(LoadConnector, PollConnector): try: hostname, pin_ip = assert_url_is_safe(current_url) except ValueError as exc: - raise ConnectorValidationError( - f"Unsafe REST API URL: {exc}" - ) from exc + raise ConnectorValidationError(f"Unsafe REST API URL: {exc}") from exc with pin_dns(hostname, pin_ip): if current_method == "GET": resp = rl_requests.get( @@ -747,11 +722,7 @@ class RestAPIConnector(LoadConnector, PollConnector): # crosses to a different origin so a public→private redirect chain # cannot exfiltrate Bearer/Basic/API-key headers. if next_netloc and next_netloc != previous_netloc: - headers = { - k: v - for k, v in headers.items() - if k.lower() not in self._AUTH_SENSITIVE_HEADER_KEYS - } + headers = {k: v for k, v in headers.items() if k.lower() not in self._AUTH_SENSITIVE_HEADER_KEYS} current_auth = None previous_netloc = next_netloc @@ -829,9 +800,7 @@ class RestAPIConnector(LoadConnector, PollConnector): try: matches = _jsonpath(response_json, self.items_path) except Exception as exc: - raise ConnectorValidationError( - f"Failed to apply items JSONPath '{self.items_path}': {exc}" - ) from exc + raise ConnectorValidationError(f"Failed to apply items JSONPath '{self.items_path}': {exc}") from exc if not matches: return [] if len(matches) == 1 and isinstance(matches[0], list): @@ -1060,6 +1029,7 @@ class RestAPIConnector(LoadConnector, PollConnector): class _SafeDict(dict): """Dict subclass that returns empty string for missing keys in format_map.""" + def __missing__(self, key: str) -> str: return "" diff --git a/common/data_source/salesforce_connector.py b/common/data_source/salesforce_connector.py index 10479ccf34..5e0c71f1c5 100644 --- a/common/data_source/salesforce_connector.py +++ b/common/data_source/salesforce_connector.py @@ -134,9 +134,7 @@ class SalesforceConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPe client_secret = credentials.get("client_secret") if not all([instance_url, client_id, client_secret]): - raise ConnectorMissingCredentialError( - "Salesforce credentials are incomplete (instance_url, client_id, client_secret required)" - ) + raise ConnectorMissingCredentialError("Salesforce credentials are incomplete (instance_url, client_id, client_secret required)") token_url = urljoin(instance_url + "/", "services/oauth2/token") try: @@ -150,9 +148,7 @@ class SalesforceConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPe timeout=60, ) except requests.RequestException as exc: - raise ConnectorMissingCredentialError( - f"Salesforce token request failed: {exc}" - ) + raise ConnectorMissingCredentialError(f"Salesforce token request failed: {exc}") if not resp.ok: # Salesforce returns {"error": "...", "error_description": "..."} @@ -161,9 +157,7 @@ class SalesforceConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPe detail = body.get("error_description") or body.get("error") or resp.text except ValueError: detail = resp.text[:200] - raise ConnectorMissingCredentialError( - f"Failed to acquire Salesforce access token (HTTP {resp.status_code}): {detail}" - ) + raise ConnectorMissingCredentialError(f"Failed to acquire Salesforce access token (HTTP {resp.status_code}): {detail}") data = resp.json() token = data.get("access_token") @@ -172,9 +166,7 @@ class SalesforceConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPe # correct host even when the configured URL went stale. canonical = (data.get("instance_url") or "").rstrip("/") if not token: - raise ConnectorMissingCredentialError( - "Salesforce token response did not contain access_token" - ) + raise ConnectorMissingCredentialError("Salesforce token response did not contain access_token") self._access_token = token self._instance_url = canonical or instance_url @@ -193,38 +185,24 @@ class SalesforceConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPe # the connected app or the user lacks API access altogether. resp = self._get(f"{self._base()}/sobjects") if resp.status_code == 401: - raise ConnectorMissingCredentialError( - "Salesforce access token is invalid or expired." - ) + raise ConnectorMissingCredentialError("Salesforce access token is invalid or expired.") if resp.status_code == 403: - raise InsufficientPermissionsError( - "The Salesforce execution user lacks API access; enable the 'API Enabled' profile permission." - ) + raise InsufficientPermissionsError("The Salesforce execution user lacks API access; enable the 'API Enabled' profile permission.") if not resp.ok: - raise UnexpectedValidationError( - f"Salesforce validation failed (HTTP {resp.status_code}): {resp.text[:200]}" - ) + raise UnexpectedValidationError(f"Salesforce validation failed (HTTP {resp.status_code}): {resp.text[:200]}") try: payload = resp.json() except ValueError as exc: - raise ConnectorValidationError( - f"Salesforce /sobjects response is not JSON: {exc}" - ) + raise ConnectorValidationError(f"Salesforce /sobjects response is not JSON: {exc}") if "sobjects" not in payload: - raise ConnectorValidationError( - "Unexpected response format from Salesforce /sobjects." - ) + raise ConnectorValidationError("Unexpected response format from Salesforce /sobjects.") # Fail fast on typos / inaccessible objects instead of silently # missing their data during sync. The global describe lists every # object the user can see plus its queryable flag, so we can vet the # configured objects without an extra call per object. - queryable = { - so["name"]: bool(so.get("queryable", False)) - for so in payload.get("sobjects", []) - if isinstance(so, dict) and so.get("name") - } + queryable = {so["name"]: bool(so.get("queryable", False)) for so in payload.get("sobjects", []) if isinstance(so, dict) and so.get("name")} unknown: list[str] = [] not_queryable: list[str] = [] for obj in self.objects: @@ -247,11 +225,7 @@ class SalesforceConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPe problems.append(f"unknown object(s): {', '.join(sorted(unknown))}") if not_queryable: problems.append(f"non-queryable object(s): {', '.join(sorted(not_queryable))}") - raise ConnectorValidationError( - "Salesforce 'objects' configuration is invalid — " - + "; ".join(problems) - + ". Check for typos and that the execution user has read access to each object." - ) + raise ConnectorValidationError("Salesforce 'objects' configuration is invalid — " + "; ".join(problems) + ". Check for typos and that the execution user has read access to each object.") # ------------------------------------------------------------------ # Checkpoint helpers @@ -270,9 +244,7 @@ class SalesforceConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPe # Core data loading # ------------------------------------------------------------------ - def poll_source( - self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch - ) -> Any: + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any: return self._iter_documents(since_epoch=start, until_epoch=end if end else None) def load_from_checkpoint( @@ -367,18 +339,12 @@ class SalesforceConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPe body_snippet, ) if _is_object_unavailable(resp): - raise SalesforceObjectUnavailable( - f"Salesforce object unavailable ({context}): HTTP {resp.status_code} {body_snippet}" - ) - raise UnexpectedValidationError( - f"Salesforce request failed ({context}): HTTP {resp.status_code} {body_snippet}" - ) + raise SalesforceObjectUnavailable(f"Salesforce object unavailable ({context}): HTTP {resp.status_code} {body_snippet}") + raise UnexpectedValidationError(f"Salesforce request failed ({context}): HTTP {resp.status_code} {body_snippet}") try: return resp.json() except ValueError as exc: - raise UnexpectedValidationError( - f"Salesforce response is not JSON ({context}): {exc}" - ) + raise UnexpectedValidationError(f"Salesforce response is not JSON ({context}): {exc}") def _describe_fields(self, obj: str) -> list[str]: """Return field API names for *obj*. Filters out compound types @@ -415,14 +381,10 @@ class SalesforceConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPe field_list = ",".join(fields) filters = [] if since_epoch: - since_iso = datetime.fromtimestamp(since_epoch, tz=timezone.utc).strftime( - "%Y-%m-%dT%H:%M:%SZ" - ) + since_iso = datetime.fromtimestamp(since_epoch, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") filters.append(f"SystemModstamp > {since_iso}") if until_epoch: - until_iso = datetime.fromtimestamp(until_epoch, tz=timezone.utc).strftime( - "%Y-%m-%dT%H:%M:%SZ" - ) + until_iso = datetime.fromtimestamp(until_epoch, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") filters.append(f"SystemModstamp <= {until_iso}") where = f" WHERE {' AND '.join(filters)}" if filters else "" soql = f"SELECT {field_list} FROM {obj}{where} ORDER BY SystemModstamp ASC" @@ -513,9 +475,7 @@ class SalesforceConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPe modified_dt: datetime | None = None if modified_str: try: - modified_dt = datetime.fromisoformat( - modified_str.replace("Z", "+00:00") - ) + modified_dt = datetime.fromisoformat(modified_str.replace("Z", "+00:00")) except ValueError: modified_dt = None @@ -525,12 +485,7 @@ class SalesforceConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPe # ``Subject`` (Case) or ``Title`` (Knowledge); last # resort is ``/`` so the doc list is # never blank-titled. - name = ( - record.get("Name") - or record.get("Subject") - or record.get("Title") - or f"{obj}/{rec_id}" - ) + name = record.get("Name") or record.get("Subject") or record.get("Title") or f"{obj}/{rec_id}" body = self._record_to_text(obj, record) blob = body.encode("utf-8") diff --git a/common/data_source/seafile_connector.py b/common/data_source/seafile_connector.py index 66bcf954fd..c9ee59e083 100644 --- a/common/data_source/seafile_connector.py +++ b/common/data_source/seafile_connector.py @@ -1,4 +1,5 @@ """SeaFile connector with granular sync support""" + import logging from datetime import datetime, timezone from typing import Any, Optional @@ -32,6 +33,7 @@ from common.data_source.models import ( logger = logging.getLogger(__name__) + class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """SeaFile connector supporting account-, library- and directory-level sync. @@ -65,14 +67,13 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): self.repo_id = repo_id self.sync_path = self._normalise_path(sync_path) - self.token: Optional[str] = None # account-level - self.repo_token: Optional[str] = None # library-scoped + self.token: Optional[str] = None # account-level + self.repo_token: Optional[str] = None # library-scoped self.current_user_email: Optional[str] = None self.size_threshold: int = BLOB_STORAGE_SIZE_THRESHOLD self._validate_scope_params() - @staticmethod def _normalise_path(path: Optional[str]) -> str: if not path: @@ -114,26 +115,20 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): logger.warning("Unparseable mtime %r, using current time", raw_mtime) return datetime.now(timezone.utc) - + def _validate_scope_params(self) -> None: if self.sync_scope in (SeafileSyncScope.LIBRARY, SeafileSyncScope.DIRECTORY): if not self.repo_id: - raise ConnectorValidationError( - f"sync_scope={self.sync_scope.value!r} requires 'repo_id'." - ) + raise ConnectorValidationError(f"sync_scope={self.sync_scope.value!r} requires 'repo_id'.") if self.sync_scope == SeafileSyncScope.DIRECTORY: if self.sync_path == "/": - raise ConnectorValidationError( - "sync_scope='directory' requires a non-root 'sync_path'. " - "Use sync_scope='library' to sync an entire library." - ) + raise ConnectorValidationError("sync_scope='directory' requires a non-root 'sync_path'. Use sync_scope='library' to sync an entire library.") @property def _use_repo_token(self) -> bool: """Whether we should use repo-token endpoints.""" return self.repo_token is not None - def _account_headers(self) -> dict[str, str]: if not self.token: raise ConnectorMissingCredentialError("Account token not set") @@ -154,7 +149,10 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """GET against /api2/... using the account token.""" url = f"{self.seafile_url}/api2/{endpoint.lstrip('/')}" resp = rl_requests.get( - url, headers=self._account_headers(), params=params, timeout=60, + url, + headers=self._account_headers(), + params=params, + timeout=60, ) return resp @@ -162,11 +160,13 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """GET against /api/v2.1/via-repo-token/... using the repo token.""" url = f"{self.seafile_url}/api/v2.1/via-repo-token/{endpoint.lstrip('/')}" resp = rl_requests.get( - url, headers=self._repo_token_headers(), params=params, timeout=60, + url, + headers=self._repo_token_headers(), + params=params, + timeout=60, ) return resp - def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: logger.debug("Loading credentials for SeaFile server %s", self.seafile_url) @@ -189,19 +189,14 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): ) if not self.token and not self.repo_token: - raise ConnectorMissingCredentialError( - "SeaFile requires 'seafile_token', 'repo_token', " - "or 'username'/'password'." - ) + raise ConnectorMissingCredentialError("SeaFile requires 'seafile_token', 'repo_token', or 'username'/'password'.") try: self._validate_credentials() except ConnectorMissingCredentialError: raise except Exception as e: - raise CredentialExpiredError( - f"SeaFile credential validation failed: {e}" - ) + raise CredentialExpiredError(f"SeaFile credential validation failed: {e}") return None @@ -218,9 +213,7 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): raise CredentialExpiredError("No token returned") return token except Exception as e: - raise ConnectorMissingCredentialError( - f"Failed to authenticate with SeaFile: {e}" - ) + raise ConnectorMissingCredentialError(f"Failed to authenticate with SeaFile: {e}") def _validate_credentials(self) -> None: if self.token: @@ -247,30 +240,23 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): info = resp.json() logger.info( "Repo token validated — library: %s (id: %s)", - info.get("repo_name", "?"), info.get("repo_id", self.repo_id), + info.get("repo_name", "?"), + info.get("repo_id", self.repo_id), ) # Update repo_id from response if not set if not self.repo_id and info.get("repo_id"): self.repo_id = info["repo_id"] except Exception as e: - raise CredentialExpiredError( - f"Repo token validation failed: {e}" - ) + raise CredentialExpiredError(f"Repo token validation failed: {e}") def _validate_repo_access_via_account(self) -> None: repo_info = self._get_repo_info_via_account(self.repo_id) if not repo_info: - raise ConnectorValidationError( - f"Library {self.repo_id} not accessible with account token." - ) + raise ConnectorValidationError(f"Library {self.repo_id} not accessible with account token.") if self.sync_scope == SeafileSyncScope.DIRECTORY: entries = self._get_directory_entries(self.repo_id, self.sync_path) if entries is None: - raise ConnectorValidationError( - f"Directory {self.sync_path!r} does not exist " - f"in library {self.repo_id}." - ) - + raise ConnectorValidationError(f"Directory {self.sync_path!r} does not exist in library {self.repo_id}.") def validate_connector_settings(self) -> None: if not self.token and not self.repo_token: @@ -284,18 +270,20 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): logger.info("Validated (account scope). %d libraries.", len(libs)) elif self.sync_scope == SeafileSyncScope.LIBRARY: info = self._get_repo_info() - logger.info( - "Validated (library scope): %s", info.get("name", self.repo_id) - ) + logger.info("Validated (library scope): %s", info.get("name", self.repo_id)) elif self.sync_scope == SeafileSyncScope.DIRECTORY: entries = self._get_directory_entries(self.repo_id, self.sync_path) logger.info( "Validated (directory scope): %s:%s (%d entries)", - self.repo_id, self.sync_path, len(entries), + self.repo_id, + self.sync_path, + len(entries), ) except ( - ConnectorValidationError, ConnectorMissingCredentialError, - CredentialExpiredError, InsufficientPermissionsError, + ConnectorValidationError, + ConnectorMissingCredentialError, + CredentialExpiredError, + InsufficientPermissionsError, ): raise except Exception as e: @@ -306,7 +294,6 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): raise InsufficientPermissionsError("Insufficient permissions.") raise ConnectorValidationError(f"Validation failed: {repr(e)}") - @retry(tries=3, delay=1, backoff=2) def _get_libraries(self) -> list[dict]: """List all libraries (account token only).""" @@ -315,11 +302,7 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): libraries = resp.json() if not self.include_shared and self.current_user_email: - libraries = [ - lib for lib in libraries - if lib.get("owner") == self.current_user_email - or lib.get("owner_email") == self.current_user_email - ] + libraries = [lib for lib in libraries if lib.get("owner") == self.current_user_email or lib.get("owner_email") == self.current_user_email] return libraries @@ -378,7 +361,8 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): else: # GET /api2/repos/{repo_id}/dir/?p=/foo resp = self._account_get( - f"/repos/{repo_id}/dir/", params={"p": path}, + f"/repos/{repo_id}/dir/", + params={"p": path}, ) resp.raise_for_status() data = resp.json() @@ -390,27 +374,30 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): except Exception as e: logger.warning( - "Error fetching directory %s in repo %s: %s", path, repo_id, e, + "Error fetching directory %s in repo %s: %s", + path, + repo_id, + e, ) if raise_on_failure: raise return [] @retry(tries=3, delay=1, backoff=2) - def _get_file_download_link( - self, repo_id: str, path: str - ) -> Optional[str]: + def _get_file_download_link(self, repo_id: str, path: str) -> Optional[str]: """Get a temporary download URL for a file.""" try: if self._use_repo_token: # GET /api/v2.1/via-repo-token/download-link/?path=/foo.pdf resp = self._repo_token_get( - "download-link/", params={"path": path}, + "download-link/", + params={"path": path}, ) else: # GET /api2/repos/{repo_id}/file/?p=/foo.pdf&reuse=1 resp = self._account_get( - f"/repos/{repo_id}/file/", params={"p": path, "reuse": 1}, + f"/repos/{repo_id}/file/", + params={"p": path, "reuse": 1}, ) resp.raise_for_status() return resp.text.strip('"') @@ -418,7 +405,6 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): logger.warning("Error getting download link for %s: %s", path, e) return None - def _list_files_recursive( self, repo_id: str, @@ -432,7 +418,9 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): ) -> list[tuple[str, dict, dict]]: files = [] entries = self._get_directory_entries( - repo_id, path, raise_on_failure=strict_listing, + repo_id, + path, + raise_on_failure=strict_listing, ) for entry in entries: @@ -476,31 +464,28 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): def _resolve_libraries_to_scan(self) -> list[dict]: if self.sync_scope == SeafileSyncScope.ACCOUNT: - return [ - {"id": lib["id"], "name": lib.get("name", "Unknown")} - for lib in self._get_libraries() if lib.get("id") - ] + return [{"id": lib["id"], "name": lib.get("name", "Unknown")} for lib in self._get_libraries() if lib.get("id")] info = self._get_repo_info() if info: - return [{"id": info.get("id", self.repo_id), - "name": info.get("name", self.repo_id)}] + return [{"id": info.get("id", self.repo_id), "name": info.get("name", self.repo_id)}] return [{"id": self.repo_id, "name": self.repo_id}] def _root_path_for_repo(self, repo_id: str) -> str: - if (self.sync_scope == SeafileSyncScope.DIRECTORY - and repo_id == self.repo_id): + if self.sync_scope == SeafileSyncScope.DIRECTORY and repo_id == self.repo_id: return self.sync_path return "/" - def _yield_seafile_documents( - self, start: datetime, end: datetime, + self, + start: datetime, + end: datetime, ) -> GenerateDocumentsOutput: libraries = self._resolve_libraries_to_scan() logger.info( "Processing %d library(ies) [scope=%s]", - len(libraries), self.sync_scope.value, + len(libraries), + self.sync_scope.value, ) all_files: list[tuple[str, dict, dict]] = [] @@ -509,7 +494,11 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): logger.debug("Scanning %s starting at %s", lib["name"], root) try: files = self._list_files_recursive( - lib["id"], lib["name"], root, start, end, + lib["id"], + lib["name"], + root, + start, + end, filter_by_mtime=True, strict_listing=False, ) @@ -528,7 +517,7 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): repo_name = library["name"] modified = self._parse_mtime(file_entry.get("mtime")) - + if file_size > self.size_threshold: logger.warning("Skipping large file: %s (%d B)", file_path, file_size) continue @@ -544,15 +533,17 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): if not blob: continue - batch.append(Document( - id=f"seafile:{repo_id}:{file_id}", - blob=blob, - source=DocumentSource.SEAFILE, - semantic_identifier=f"{repo_name}{file_path}", - extension=get_file_ext(file_name), - doc_updated_at=modified, # <-- already parsed - size_bytes=len(blob), - )) + batch.append( + Document( + id=f"seafile:{repo_id}:{file_id}", + blob=blob, + source=DocumentSource.SEAFILE, + semantic_identifier=f"{repo_name}{file_path}", + extension=get_file_ext(file_name), + doc_updated_at=modified, # <-- already parsed + size_bytes=len(blob), + ) + ) if len(batch) >= self.batch_size: yield batch @@ -571,7 +562,9 @@ class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): ) def poll_source( - self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, + self, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, ) -> GenerateDocumentsOutput: start_dt = datetime.fromtimestamp(start, tz=timezone.utc) end_dt = datetime.fromtimestamp(end, tz=timezone.utc) diff --git a/common/data_source/sharepoint_connector.py b/common/data_source/sharepoint_connector.py index ab3384d702..519c41c37f 100644 --- a/common/data_source/sharepoint_connector.py +++ b/common/data_source/sharepoint_connector.py @@ -76,9 +76,7 @@ class SharePointConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPe token = app.acquire_token_for_client(scopes=GRAPH_SCOPES) if "access_token" not in token: detail = token.get("error_description") or token.get("error") or token - raise ConnectorMissingCredentialError( - f"Failed to acquire SharePoint access token: {detail}" - ) + raise ConnectorMissingCredentialError(f"Failed to acquire SharePoint access token: {detail}") return token self.graph_client = GraphClient(_acquire_token) @@ -98,9 +96,7 @@ class SharePointConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPe except Exception as e: message = str(e) if "401" in message or "403" in message: - raise ConnectorValidationError( - "Invalid credentials or insufficient permissions for SharePoint" - ) + raise ConnectorValidationError("Invalid credentials or insufficient permissions for SharePoint") raise ConnectorValidationError(f"SharePoint validation error: {e}") # -- traversal helpers --------------------------------------------------- @@ -205,9 +201,7 @@ class SharePointConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPe logging.exception("SharePoint failed to process drive item") yield ConnectorFailure( failed_document=DocumentFailure( - document_id=self._composite_doc_id(drive_id, drive_item) - if getattr(drive_item, "id", None) is not None - else "unknown", + document_id=self._composite_doc_id(drive_id, drive_item) if getattr(drive_item, "id", None) is not None else "unknown", document_link=getattr(drive_item, "web_url", "") or "", ), failure_message=str(e), diff --git a/common/data_source/slack_connector.py b/common/data_source/slack_connector.py index 441d4b6e9f..fb2e235491 100644 --- a/common/data_source/slack_connector.py +++ b/common/data_source/slack_connector.py @@ -14,22 +14,9 @@ from slack_sdk.errors import SlackApiError from slack_sdk.http_retry import ConnectionErrorRetryHandler from slack_sdk.http_retry.builtin_interval_calculators import FixedValueRetryIntervalCalculator -from common.data_source.config import ( - INDEX_BATCH_SIZE, SLACK_NUM_THREADS, ENABLE_EXPENSIVE_EXPERT_CALLS, - _SLACK_LIMIT, FAST_TIMEOUT, MAX_RETRIES, MAX_CHANNELS_TO_LOG -) -from common.data_source.exceptions import ( - ConnectorMissingCredentialError, - ConnectorValidationError, - CredentialExpiredError, - InsufficientPermissionsError, - UnexpectedValidationError -) -from common.data_source.interfaces import ( - CheckpointedConnectorWithPermSync, - CredentialsConnector, - SlimConnectorWithPermSync -) +from common.data_source.config import INDEX_BATCH_SIZE, SLACK_NUM_THREADS, ENABLE_EXPENSIVE_EXPERT_CALLS, _SLACK_LIMIT, FAST_TIMEOUT, MAX_RETRIES, MAX_CHANNELS_TO_LOG +from common.data_source.exceptions import ConnectorMissingCredentialError, ConnectorValidationError, CredentialExpiredError, InsufficientPermissionsError, UnexpectedValidationError +from common.data_source.interfaces import CheckpointedConnectorWithPermSync, CredentialsConnector, SlimConnectorWithPermSync from common.data_source.models import ( BasicExpertInfo, ConnectorCheckpoint, @@ -38,18 +25,33 @@ from common.data_source.models import ( DocumentFailure, SlimDocument, SecondsSinceUnixEpoch, - GenerateSlimDocumentOutput, MessageType, SlackMessageFilterReason, ChannelType, ThreadType, ProcessedSlackMessage, - CheckpointOutput + GenerateSlimDocumentOutput, + MessageType, + SlackMessageFilterReason, + ChannelType, + ThreadType, + ProcessedSlackMessage, + CheckpointOutput, ) -from common.data_source.utils import make_paginated_slack_api_call, SlackTextCleaner, expert_info_from_slack_id, \ - get_message_link +from common.data_source.utils import make_paginated_slack_api_call, SlackTextCleaner, expert_info_from_slack_id, get_message_link # Disallowed message subtypes list _DISALLOWED_MSG_SUBTYPES = { - "channel_join", "channel_leave", "channel_archive", "channel_unarchive", - "pinned_item", "unpinned_item", "ekm_access_denied", "channel_posting_permissions", - "group_join", "group_leave", "group_archive", "group_unarchive", - "channel_leave", "channel_name", "channel_join", + "channel_join", + "channel_leave", + "channel_archive", + "channel_unarchive", + "pinned_item", + "unpinned_item", + "ekm_access_denied", + "channel_posting_permissions", + "group_join", + "group_leave", + "group_archive", + "group_unarchive", + "channel_leave", + "channel_name", + "channel_join", } @@ -97,7 +99,7 @@ def get_channels( channel_types.append("public_channel") if get_private: channel_types.append("private_channel") - + # First try to get public and private channels try: channels = _collect_paginated_channels( @@ -153,9 +155,7 @@ def get_channel_messages( def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType: threads: list[MessageType] = [] - for result in make_paginated_slack_api_call( - client.conversations_replies, channel=channel_id, ts=thread_id - ): + for result in make_paginated_slack_api_call(client.conversations_replies, channel=channel_id, ts=thread_id): threads.extend(result["messages"]) return threads @@ -179,40 +179,20 @@ def thread_to_doc( ) -> Document: channel_id = channel["id"] - initial_sender_expert_info = expert_info_from_slack_id( - user_id=thread[0].get("user"), client=client, user_cache=user_cache - ) - initial_sender_name = ( - initial_sender_expert_info.get_semantic_name() - if initial_sender_expert_info - else "Unknown" - ) + initial_sender_expert_info = expert_info_from_slack_id(user_id=thread[0].get("user"), client=client, user_cache=user_cache) + initial_sender_name = initial_sender_expert_info.get_semantic_name() if initial_sender_expert_info else "Unknown" valid_experts = None if ENABLE_EXPENSIVE_EXPERT_CALLS: all_sender_ids = [m.get("user") for m in thread] - experts = [ - expert_info_from_slack_id( - user_id=sender_id, client=client, user_cache=user_cache - ) - for sender_id in all_sender_ids - if sender_id - ] + experts = [expert_info_from_slack_id(user_id=sender_id, client=client, user_cache=user_cache) for sender_id in all_sender_ids if sender_id] valid_experts = [expert for expert in experts if expert] - cleaned_messages = [ - slack_cleaner.index_clean(cast(str, m["text"])) for m in thread - ] + cleaned_messages = [slack_cleaner.index_clean(cast(str, m["text"])) for m in thread] first_message = cleaned_messages[0] if cleaned_messages else "" - snippet = ( - first_message[:50].rstrip() + "..." - if len(first_message) > 50 - else first_message - ) + snippet = first_message[:50].rstrip() + "..." if len(first_message) > 50 else first_message - doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}".replace( - "\n", " " - ) + doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}".replace("\n", " ") # The Document model is blob-based (no sections), so flatten the thread's # cleaned messages into a single UTF-8 text blob. @@ -242,14 +222,7 @@ def filter_channels( return all_channels if regex_enabled: - return [ - channel - for channel in all_channels - if any( - re.fullmatch(channel_to_connect, channel["name"]) - for channel_to_connect in channels_to_connect - ) - ] + return [channel for channel in all_channels if any(re.fullmatch(channel_to_connect, channel["name"]) for channel_to_connect in channels_to_connect)] # Validate all specified channels are valid all_channel_names = {channel["name"] for channel in all_channels} @@ -262,9 +235,7 @@ def filter_channels( f"{list(itertools.islice(all_channel_names, MAX_CHANNELS_TO_LOG))}" ) - return [ - channel for channel in all_channels if channel["name"] in channels_to_connect - ] + return [channel for channel in all_channels if channel["name"] in channels_to_connect] def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType: @@ -309,9 +280,7 @@ def _get_messages( messages = cast(list[MessageType], response.get("messages", [])) - cursor = cast(dict[str, Any], response.get("response_metadata", {})).get( - "next_cursor", "" - ) + cursor = cast(dict[str, Any], response.get("response_metadata", {})).get("next_cursor", "") has_more = bool(cursor) return messages, has_more @@ -324,9 +293,7 @@ def _message_to_doc( user_cache: dict[str, BasicExpertInfo | None], seen_thread_ts: set[str], channel_access: Any | None, - msg_filter_func: Callable[ - [MessageType], SlackMessageFilterReason | None - ] = default_msg_filter, + msg_filter_func: Callable[[MessageType], SlackMessageFilterReason | None] = default_msg_filter, ) -> tuple[Document | None, SlackMessageFilterReason | None]: """Convert message to document""" filtered_thread: ThreadType | None = None @@ -337,9 +304,7 @@ def _message_to_doc( if thread_ts in seen_thread_ts: return None, None - thread = get_thread( - client=client, channel_id=channel["id"], thread_id=thread_ts - ) + thread = get_thread(client=client, channel_id=channel["id"], thread_id=thread_ts) filtered_thread = [] for message in thread: @@ -377,9 +342,7 @@ def _process_message( user_cache: dict[str, BasicExpertInfo | None], seen_thread_ts: set[str], channel_access: Any | None, - msg_filter_func: Callable[ - [MessageType], SlackMessageFilterReason | None - ] = default_msg_filter, + msg_filter_func: Callable[[MessageType], SlackMessageFilterReason | None] = default_msg_filter, ) -> ProcessedSlackMessage: thread_ts = message.get("thread_ts") thread_or_message_ts = thread_ts or message["ts"] @@ -408,9 +371,7 @@ def _process_message( filter_reason=None, failure=ConnectorFailure( failed_document=DocumentFailure( - document_id=_build_doc_id( - channel_id=channel["id"], thread_ts=thread_or_message_ts - ), + document_id=_build_doc_id(channel_id=channel["id"], thread_ts=thread_or_message_ts), document_link=get_message_link(message, client, channel["id"]), ), failure_message=str(e), @@ -423,15 +384,11 @@ def _get_all_doc_ids( client: WebClient, channels: list[str] | None = None, channel_name_regex_enabled: bool = False, - msg_filter_func: Callable[ - [MessageType], SlackMessageFilterReason | None - ] = default_msg_filter, + msg_filter_func: Callable[[MessageType], SlackMessageFilterReason | None] = default_msg_filter, callback: Any = None, ) -> GenerateSlimDocumentOutput: all_channels = get_channels(client) - filtered_channels = filter_channels( - all_channels, channels, channel_name_regex_enabled - ) + filtered_channels = filter_channels(all_channels, channels, channel_name_regex_enabled) for channel in filtered_channels: channel_id = channel["id"] @@ -451,9 +408,7 @@ def _get_all_doc_ids( slim_doc_batch.append( SlimDocument( - id=_build_doc_id( - channel_id=channel_id, thread_ts=message["ts"] - ), + id=_build_doc_id(channel_id=channel_id, thread_ts=message["ts"]), external_access=external_access, ) ) @@ -493,9 +448,7 @@ class SlackConnector( @channels.setter def channels(self, channels: list[str] | None) -> None: - self._channels = ( - [channel.removeprefix("#") for channel in channels] if channels else None - ) + self._channels = [channel.removeprefix("#") for channel in channels] if channels else None def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: """Load credentials""" @@ -518,14 +471,10 @@ class SlackConnector( ], ) - self.client = WebClient( - token=bot_token, retry_handlers=[connection_error_retry_handler] - ) + self.client = WebClient(token=bot_token, retry_handlers=[connection_error_retry_handler]) # For fast response requests - self.fast_client = WebClient( - token=bot_token, timeout=FAST_TIMEOUT - ) + self.fast_client = WebClient(token=bot_token, timeout=FAST_TIMEOUT) self.text_cleaner = SlackTextCleaner(client=self.client) self.credentials_provider = credentials_provider @@ -560,9 +509,7 @@ class SlackConnector( raise ConnectorMissingCredentialError("Slack") all_channels = get_channels(self.client) - filtered_channels = filter_channels( - all_channels, self.channels, self.channel_regex_enabled - ) + filtered_channels = filter_channels(all_channels, self.channels, self.channel_regex_enabled) batch: list[Document] = [] for channel in filtered_channels: @@ -655,70 +602,47 @@ class SlackConnector( # 1) Validate workspace connection auth_response = self.fast_client.auth_test() if not auth_response.get("ok", False): - error_msg = auth_response.get( - "error", "Unknown error from Slack auth_test" - ) + error_msg = auth_response.get("error", "Unknown error from Slack auth_test") raise ConnectorValidationError(f"Failed Slack auth_test: {error_msg}") # 2) Confirm listing channels functionality works - test_resp = self.fast_client.conversations_list( - limit=1, types=["public_channel"] - ) + test_resp = self.fast_client.conversations_list(limit=1, types=["public_channel"]) if not test_resp.get("ok", False): error_msg = test_resp.get("error", "Unknown error from Slack") if error_msg == "invalid_auth": - raise ConnectorValidationError( - f"Invalid Slack bot token ({error_msg})." - ) + raise ConnectorValidationError(f"Invalid Slack bot token ({error_msg}).") elif error_msg == "not_authed": - raise CredentialExpiredError( - f"Invalid or expired Slack bot token ({error_msg})." - ) - raise UnexpectedValidationError( - f"Slack API returned a failure: {error_msg}" - ) + raise CredentialExpiredError(f"Invalid or expired Slack bot token ({error_msg}).") + raise UnexpectedValidationError(f"Slack API returned a failure: {error_msg}") # 3) Confirm users:read scope is available (required by thread_to_doc) users_resp = self.fast_client.users_info(user="USLACKBOT") if not users_resp.get("ok", False): error_msg = users_resp.get("error", "") if error_msg in ("missing_scope", "not_allowed_token_type"): - raise InsufficientPermissionsError( - "Slack bot token lacks the 'users:read' scope required to look up message senders. " - "Please add 'users:read' to your Slack app's OAuth scopes." - ) + raise InsufficientPermissionsError("Slack bot token lacks the 'users:read' scope required to look up message senders. Please add 'users:read' to your Slack app's OAuth scopes.") except SlackApiError as e: slack_error = e.response.get("error", "") if slack_error == "ratelimited": retry_after = int(e.response.headers.get("Retry-After", 1)) logging.warning( - f"Slack API rate limited during validation. Retry suggested after {retry_after} seconds. " - "Proceeding with validation, but be aware that connector operations might be throttled." + f"Slack API rate limited during validation. Retry suggested after {retry_after} seconds. Proceeding with validation, but be aware that connector operations might be throttled." ) return elif slack_error == "missing_scope": raise InsufficientPermissionsError( - "Slack bot token lacks the necessary scope to list/access channels. " - "Please ensure your Slack app has 'channels:read' (and/or 'groups:read' for private channels)." + "Slack bot token lacks the necessary scope to list/access channels. Please ensure your Slack app has 'channels:read' (and/or 'groups:read' for private channels)." ) elif slack_error == "invalid_auth": - raise CredentialExpiredError( - f"Invalid Slack bot token ({slack_error})." - ) + raise CredentialExpiredError(f"Invalid Slack bot token ({slack_error}).") elif slack_error == "not_authed": - raise CredentialExpiredError( - f"Invalid or expired Slack bot token ({slack_error})." - ) - raise UnexpectedValidationError( - f"Unexpected Slack error '{slack_error}' during settings validation." - ) + raise CredentialExpiredError(f"Invalid or expired Slack bot token ({slack_error}).") + raise UnexpectedValidationError(f"Unexpected Slack error '{slack_error}' during settings validation.") except ConnectorValidationError as e: raise e except Exception as e: - raise UnexpectedValidationError( - f"Unexpected error during Slack settings validation: {e}" - ) + raise UnexpectedValidationError(f"Unexpected error during Slack settings validation: {e}") if __name__ == "__main__": @@ -731,9 +655,7 @@ if __name__ == "__main__": ) # Simplified version, directly using credentials dictionary - credentials = { - "slack_bot_token": os.environ.get("SLACK_BOT_TOKEN", "test-token") - } + credentials = {"slack_bot_token": os.environ.get("SLACK_BOT_TOKEN", "test-token")} class SimpleCredentialsProvider: def get_credentials(self): diff --git a/common/data_source/teams_connector.py b/common/data_source/teams_connector.py index a4bb75d358..2ca4604217 100644 --- a/common/data_source/teams_connector.py +++ b/common/data_source/teams_connector.py @@ -43,6 +43,7 @@ GRAPH_SCOPES = ["https://graph.microsoft.com/.default"] class TeamsCheckpoint(ConnectorCheckpoint): """Teams-specific checkpoint""" + todo_team_ids: list[str] | None = None @@ -84,9 +85,7 @@ class TeamsConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSyn token = app.acquire_token_for_client(scopes=GRAPH_SCOPES) if "access_token" not in token: detail = token.get("error_description") or token.get("error") or token - raise ConnectorMissingCredentialError( - f"Failed to acquire Microsoft Teams access token: {detail}" - ) + raise ConnectorMissingCredentialError(f"Failed to acquire Microsoft Teams access token: {detail}") return token self.graph_client = GraphClient(_acquire_token) @@ -104,9 +103,7 @@ class TeamsConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSyn except Exception as e: message = str(e) if "401" in message or "403" in message: - raise InsufficientPermissionsError( - "Invalid credentials or insufficient permissions for Microsoft Teams" - ) + raise InsufficientPermissionsError("Invalid credentials or insufficient permissions for Microsoft Teams") raise UnexpectedValidationError(f"Microsoft Teams validation error: {e}") # -- helpers ------------------------------------------------------------- @@ -172,9 +169,7 @@ class TeamsConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSyn contents.append(text) if ctype == "html": content_type = "html" - modified = self._parse_dt(self._prop(item, "lastModifiedDateTime")) or self._parse_dt( - self._prop(item, "createdDateTime") - ) + modified = self._parse_dt(self._prop(item, "lastModifiedDateTime")) or self._parse_dt(self._prop(item, "createdDateTime")) if modified is not None and (latest is None or modified > latest): latest = modified @@ -232,9 +227,7 @@ class TeamsConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSyn for team_id, team_name, channel_id, channel_name, message in self._iter_channel_messages(): try: - modified = self._parse_dt(self._prop(message, "lastModifiedDateTime")) or self._parse_dt( - self._prop(message, "createdDateTime") - ) + modified = self._parse_dt(self._prop(message, "lastModifiedDateTime")) or self._parse_dt(self._prop(message, "createdDateTime")) if modified is not None: ts = modified.timestamp() # start is an exclusive lower bound; full reindex passes start=0. @@ -242,9 +235,7 @@ class TeamsConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSyn continue replies = list(message.replies.get_all().execute_query()) - yield self._message_to_document( - message, replies, team_id, team_name, channel_id, channel_name - ) + yield self._message_to_document(message, replies, team_id, team_name, channel_id, channel_name) except Exception as e: logging.exception("Microsoft Teams failed to process message") yield ConnectorFailure( diff --git a/common/data_source/utils.py b/common/data_source/utils.py index 849e304795..ad4c43c6e4 100644 --- a/common/data_source/utils.py +++ b/common/data_source/utils.py @@ -317,13 +317,12 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea region_name=credentials["region"], ) elif bucket_type == BlobType.S3_COMPATIBLE: - return boto3.client( "s3", endpoint_url=credentials["endpoint_url"], aws_access_key_id=credentials["aws_access_key_id"], aws_secret_access_key=credentials["aws_secret_access_key"], - config=Config(s3={'addressing_style': credentials["addressing_style"]}), + config=Config(s3={"addressing_style": credentials["addressing_style"]}), ) else: @@ -1187,8 +1186,11 @@ def sanitize_filename(name: str, extension: str = "txt") -> str: name += f".{extension}" return name + + F = TypeVar("F", bound=Callable[..., Any]) + class _RateLimitDecorator: """Builds a generic wrapper/decorator for calls to external APIs that prevents making more than `max_calls` requests per `period` @@ -1226,16 +1228,11 @@ class _RateLimitDecorator: sleep_cnt = 0 while len(self.call_history) == self.max_calls: sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt) - logging.warning( - f"Rate limit exceeded for function {func.__name__}. " - f"Waiting {sleep_time} seconds before retrying." - ) + logging.warning(f"Rate limit exceeded for function {func.__name__}. Waiting {sleep_time} seconds before retrying.") time.sleep(sleep_time) sleep_cnt += 1 if self.max_num_sleep != 0 and sleep_cnt >= self.max_num_sleep: - raise RateLimitTriedTooManyTimesError( - f"Exceeded '{self.max_num_sleep}' retries for function '{func.__name__}'" - ) + raise RateLimitTriedTooManyTimesError(f"Exceeded '{self.max_num_sleep}' retries for function '{func.__name__}'") self._cleanup() @@ -1248,14 +1245,12 @@ class _RateLimitDecorator: def _cleanup(self) -> None: curr_time = time.monotonic() time_to_expire_before = curr_time - self.period - self.call_history = [ - call_time - for call_time in self.call_history - if call_time > time_to_expire_before - ] + self.call_history = [call_time for call_time in self.call_history if call_time > time_to_expire_before] + rate_limit_builder = _RateLimitDecorator + def retry_builder( tries: int = 20, delay: float = 0.1, diff --git a/common/data_source/webdav_connector.py b/common/data_source/webdav_connector.py index 8cdd295794..4ba6bd3372 100644 --- a/common/data_source/webdav_connector.py +++ b/common/data_source/webdav_connector.py @@ -1,4 +1,5 @@ """WebDAV connector""" + import logging import os from datetime import datetime, timezone @@ -12,12 +13,7 @@ from common.data_source.utils import ( is_accepted_file_ext, ) from common.data_source.config import DocumentSource, INDEX_BATCH_SIZE, BLOB_STORAGE_SIZE_THRESHOLD -from common.data_source.exceptions import ( - ConnectorMissingCredentialError, - ConnectorValidationError, - CredentialExpiredError, - InsufficientPermissionsError -) +from common.data_source.exceptions import ConnectorMissingCredentialError, ConnectorValidationError, CredentialExpiredError, InsufficientPermissionsError from common.data_source.interfaces import LoadConnector, OnyxExtensionType, PollConnector, SlimConnectorWithPermSync from common.data_source.models import Document, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SecondsSinceUnixEpoch, SlimDocument @@ -32,7 +28,7 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): batch_size: int = INDEX_BATCH_SIZE, ) -> None: """Initialize WebDAV connector - + Args: base_url: Base URL of the WebDAV server (e.g., "https://webdav.example.com") remote_path: Remote path to sync from (default: "/") @@ -111,13 +107,13 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: """Load credentials and initialize WebDAV client - + Args: credentials: Dictionary containing 'username' and 'password' - + Returns: None - + Raises: ConnectorMissingCredentialError: If required credentials are missing """ @@ -125,23 +121,16 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): username = credentials.get("username") password = credentials.get("password") - + if not username or not password: - raise ConnectorMissingCredentialError( - "WebDAV requires 'username' and 'password' credentials" - ) + raise ConnectorMissingCredentialError("WebDAV requires 'username' and 'password' credentials") try: # Initialize WebDAV client - self.client = WebDAVClient( - base_url=self.base_url, - auth=(username, password) - ) + self.client = WebDAVClient(base_url=self.base_url, auth=(username, password)) except Exception as e: logging.error(f"Failed to connect to WebDAV server: {e}") - raise ConnectorMissingCredentialError( - f"Failed to authenticate with WebDAV server: {e}" - ) + raise ConnectorMissingCredentialError(f"Failed to authenticate with WebDAV server: {e}") return None @@ -154,13 +143,13 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): filter_by_mtime: bool = True, ) -> list[tuple[str, dict]]: """Recursively list all files in the given path - + Args: path: Path to list files from start: Start datetime for filtering (ignored when ``filter_by_mtime`` is False) end: End datetime for filtering (ignored when ``filter_by_mtime`` is False) filter_by_mtime: When False, include every supported extension without mtime window - + Returns: List of tuples containing (file_path, file_info) """ @@ -168,18 +157,18 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): raise ConnectorMissingCredentialError("WebDAV client not initialized") files = [] - + try: logging.debug(f"Listing directory: {path}") for item in self.client.ls(path, detail=True): - item_path = item['name'] - - if item_path == path or item_path == path + '/': + item_path = item["name"] + + if item_path == path or item_path == path + "/": continue - + logging.debug(f"Found item: {item_path}, type: {item.get('type')}") - if item.get('type') == 'directory': + if item.get("type") == "directory": try: files.extend( self._list_files_recursive( @@ -199,7 +188,7 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): logging.debug(f"Skipping file {item_path} due to unsupported extension.") continue - modified_time = item.get('modified') + modified_time = item.get("modified") if modified_time: if isinstance(modified_time, datetime): modified = modified_time @@ -207,11 +196,11 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): modified = modified.replace(tzinfo=timezone.utc) elif isinstance(modified_time, str): try: - modified = datetime.strptime(modified_time, '%a, %d %b %Y %H:%M:%S %Z') + modified = datetime.strptime(modified_time, "%a, %d %b %Y %H:%M:%S %Z") modified = modified.replace(tzinfo=timezone.utc) except (ValueError, TypeError): try: - modified = datetime.fromisoformat(modified_time.replace('Z', '+00:00')) + modified = datetime.fromisoformat(modified_time.replace("Z", "+00:00")) except (ValueError, TypeError): logging.warning(f"Could not parse modified time for {item_path}: {modified_time}") modified = datetime.now(timezone.utc) @@ -219,7 +208,6 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): modified = datetime.now(timezone.utc) else: modified = datetime.now(timezone.utc) - logging.debug(f"File {item_path}: modified={modified}, start={start}, end={end}, include={start < modified <= end}") if filter_by_mtime: @@ -232,10 +220,10 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): except Exception as e: logging.error(f"Error processing file {item_path}: {e}") continue - + except Exception as e: logging.error(f"Error listing directory {path}: {e}") - + return files def _yield_webdav_documents( @@ -244,11 +232,11 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): end: datetime, ) -> GenerateDocumentsOutput: """Generate documents from WebDAV server - + Args: start: Start datetime for filtering end: End datetime for filtering - + Yields: Batches of documents """ @@ -258,12 +246,12 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): logging.info(f"Searching for files in {self.remote_path} between {start} and {end}") files = self._list_files_recursive(self.remote_path, start, end) logging.info(f"Found {len(files)} files matching time criteria") - + filename_counts: dict[str, int] = {} for file_path, _ in files: file_name = os.path.basename(file_path) filename_counts[file_name] = filename_counts.get(file_name, 0) + 1 - + batch: list[Document] = [] for file_path, file_info in files: file_name = os.path.basename(file_path) @@ -271,39 +259,30 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): if not self._is_supported_file(file_name): logging.debug(f"Skipping file {file_path} due to unsupported extension.") continue - + size_bytes = self._get_size_bytes(file_info) if self.size_threshold is not None and size_bytes is None: file_identifier = self._get_log_file_identifier(file_info, file_path) - logging.warning( - f"{file_identifier}: size metadata missing from WebDAV server response, " - f"skipping to avoid processing potentially large files." - ) + logging.warning(f"{file_identifier}: size metadata missing from WebDAV server response, skipping to avoid processing potentially large files.") continue - if ( - self.size_threshold is not None - and size_bytes is not None - and size_bytes > self.size_threshold - ): + if self.size_threshold is not None and size_bytes is not None and size_bytes > self.size_threshold: file_identifier = self._get_log_file_identifier(file_info, file_path) - logging.warning( - f"{file_identifier} exceeds size threshold of {self.size_threshold} " - f"(size_bytes={size_bytes}). Skipping." - ) + logging.warning(f"{file_identifier} exceeds size threshold of {self.size_threshold} (size_bytes={size_bytes}). Skipping.") continue - + try: logging.debug(f"Downloading file: {file_path}") from io import BytesIO + buffer = BytesIO() self.client.download_fileobj(file_path, buffer) blob = buffer.getvalue() - + if blob is None or len(blob) == 0: logging.warning(f"Downloaded content is empty for {file_path}") continue - modified_time = file_info.get('modified') + modified_time = file_info.get("modified") if modified_time: if isinstance(modified_time, datetime): modified = modified_time @@ -311,11 +290,11 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): modified = modified.replace(tzinfo=timezone.utc) elif isinstance(modified_time, str): try: - modified = datetime.strptime(modified_time, '%a, %d %b %Y %H:%M:%S %Z') + modified = datetime.strptime(modified_time, "%a, %d %b %Y %H:%M:%S %Z") modified = modified.replace(tzinfo=timezone.utc) except (ValueError, TypeError): try: - modified = datetime.fromisoformat(modified_time.replace('Z', '+00:00')) + modified = datetime.fromisoformat(modified_time.replace("Z", "+00:00")) except (ValueError, TypeError): logging.warning(f"Could not parse modified time for {file_path}: {modified_time}") modified = datetime.now(timezone.utc) @@ -327,10 +306,10 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): if filename_counts.get(file_name, 0) > 1: relative_path = file_path if file_path.startswith(self.remote_path): - relative_path = file_path[len(self.remote_path):] - if relative_path.startswith('/'): + relative_path = file_path[len(self.remote_path) :] + if relative_path.startswith("/"): relative_path = relative_path[1:] - semantic_id = relative_path.replace('/', ' / ') if relative_path else file_name + semantic_id = relative_path.replace("/", " / ") if relative_path else file_name else: semantic_id = file_name @@ -342,23 +321,23 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): semantic_identifier=semantic_id, extension=get_file_ext(file_name), doc_updated_at=modified, - size_bytes=size_bytes if size_bytes is not None else 0 + size_bytes=size_bytes if size_bytes is not None else 0, ) ) - + if len(batch) == self.batch_size: yield batch batch = [] except Exception as e: logging.exception(f"Error downloading file {file_path}: {e}") - + if batch: yield batch def load_from_state(self) -> GenerateDocumentsOutput: """Load all documents from WebDAV server - + Yields: Batches of documents """ @@ -368,15 +347,13 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): end=datetime.now(timezone.utc), ) - def poll_source( - self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch - ) -> GenerateDocumentsOutput: + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput: """Poll WebDAV server for updated documents - + Args: start: Start timestamp (seconds since Unix epoch) end: End timestamp (seconds since Unix epoch) - + Yields: Batches of documents """ @@ -423,25 +400,13 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): size_bytes = self._get_size_bytes(file_info) if self.size_threshold is not None and size_bytes is None: file_identifier = self._get_log_file_identifier(file_info, file_path) - logging.warning( - f"{file_identifier}: size metadata missing from WebDAV server response, " - f"skipping to avoid processing potentially large files." - ) + logging.warning(f"{file_identifier}: size metadata missing from WebDAV server response, skipping to avoid processing potentially large files.") continue - if ( - self.size_threshold is not None - and size_bytes is not None - and size_bytes > self.size_threshold - ): + if self.size_threshold is not None and size_bytes is not None and size_bytes > self.size_threshold: file_identifier = self._get_log_file_identifier(file_info, file_path) - logging.warning( - f"{file_identifier} exceeds size threshold of {self.size_threshold} " - f"(size_bytes={size_bytes}). Skipping." - ) + logging.warning(f"{file_identifier} exceeds size threshold of {self.size_threshold} (size_bytes={size_bytes}). Skipping.") continue - batch.append( - SlimDocument(id=f"webdav:{self.base_url}:{file_path}") - ) + batch.append(SlimDocument(id=f"webdav:{self.base_url}:{file_path}")) total += 1 if len(batch) >= self.batch_size: yield batch @@ -498,20 +463,13 @@ class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): if status == 401: raise CredentialExpiredError("WebDAV credentials appear invalid or expired.") if status == 403: - raise InsufficientPermissionsError( - f"Insufficient permissions to access path '{self.remote_path}' on WebDAV server." - ) + raise InsufficientPermissionsError(f"Insufficient permissions to access path '{self.remote_path}' on WebDAV server.") if status == 404: - raise ConnectorValidationError( - f"Remote path '{self.remote_path}' does not exist on WebDAV server." - ) + raise ConnectorValidationError(f"Remote path '{self.remote_path}' does not exist on WebDAV server.") # Fallback: avoid brittle substring matching that caused false positives. # Provide the original exception for diagnosis. - raise ConnectorValidationError( - f"WebDAV validation failed for path '{test_path}': {repr(e)}" - ) - + raise ConnectorValidationError(f"WebDAV validation failed for path '{test_path}': {repr(e)}") if __name__ == "__main__": @@ -525,8 +483,6 @@ if __name__ == "__main__": "password": "pass", } - - connector = WebDAVConnector( base_url="http://172.17.0.1:8080/", remote_path="/", @@ -535,7 +491,7 @@ if __name__ == "__main__": try: connector.load_credentials(credentials_dict) connector.validate_connector_settings() - + document_batch_generator = connector.load_from_state() for document_batch in document_batch_generator: print("First batch of documents:") diff --git a/common/data_source/zendesk_connector.py b/common/data_source/zendesk_connector.py index c357b500fb..c740e994b9 100644 --- a/common/data_source/zendesk_connector.py +++ b/common/data_source/zendesk_connector.py @@ -15,7 +15,7 @@ from common.data_source.exceptions import ConnectorValidationError, CredentialEx from common.data_source.html_utils import parse_html_page_basic from common.data_source.interfaces import CheckpointOutput, CheckpointOutputWrapper, CheckpointedConnector, IndexingHeartbeatInterface, SlimConnectorWithPermSync from common.data_source.models import BasicExpertInfo, ConnectorCheckpoint, ConnectorFailure, Document, DocumentFailure, GenerateSlimDocumentOutput, SecondsSinceUnixEpoch, SlimDocument -from common.data_source.utils import retry_builder, time_str_to_utc,rate_limit_builder +from common.data_source.utils import retry_builder, time_str_to_utc, rate_limit_builder MAX_PAGE_SIZE = 30 # Zendesk API maximum MAX_AUTHOR_MAP_SIZE = 50_000 # Reset author map cache if it gets too large @@ -24,9 +24,7 @@ _SLIM_BATCH_SIZE = 1000 class ZendeskCredentialsNotSetUpError(PermissionError): def __init__(self) -> None: - super().__init__( - "Zendesk Credentials are not set up, was load_credentials called?" - ) + super().__init__("Zendesk Credentials are not set up, was load_credentials called?") class ZendeskClient: @@ -42,19 +40,11 @@ class ZendeskClient: self.make_request = request_with_rate_limit(self, calls_per_minute) -def request_with_rate_limit( - client: ZendeskClient, max_calls_per_minute: int | None = None -) -> Callable[[str, dict[str, Any]], dict[str, Any]]: +def request_with_rate_limit(client: ZendeskClient, max_calls_per_minute: int | None = None) -> Callable[[str, dict[str, Any]], dict[str, Any]]: @retry_builder() - @( - rate_limit_builder(max_calls=max_calls_per_minute, period=60) - if max_calls_per_minute - else lambda x: x - ) + @(rate_limit_builder(max_calls=max_calls_per_minute, period=60) if max_calls_per_minute else lambda x: x) def make_request(endpoint: str, params: dict[str, Any]) -> dict[str, Any]: - response = requests.get( - f"{client.base_url}/{endpoint}", auth=client.auth, params=params - ) + response = requests.get(f"{client.base_url}/{endpoint}", auth=client.auth, params=params) if response.status_code == 429: retry_after = response.headers.get("Retry-After") @@ -62,10 +52,7 @@ def request_with_rate_limit( # Sleep for the duration indicated by the Retry-After header time.sleep(int(retry_after)) - elif ( - response.status_code == 403 - and response.json().get("error") == "SupportProductInactive" - ): + elif response.status_code == 403 and response.json().get("error") == "SupportProductInactive": return response.json() response.raise_for_status() @@ -102,9 +89,7 @@ def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]: raise Exception(f"Error fetching content tags: {str(e)}") -def _get_articles( - client: ZendeskClient, start_time: int | None = None, page_size: int = MAX_PAGE_SIZE -) -> Iterator[dict[str, Any]]: +def _get_articles(client: ZendeskClient, start_time: int | None = None, page_size: int = MAX_PAGE_SIZE) -> Iterator[dict[str, Any]]: params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"} if start_time is not None: params["start_time"] = start_time @@ -139,9 +124,7 @@ def _get_article_page( ) -def _get_tickets( - client: ZendeskClient, start_time: int | None = None -) -> Iterator[dict[str, Any]]: +def _get_tickets(client: ZendeskClient, start_time: int | None = None) -> Iterator[dict[str, Any]]: params = {"start_time": start_time or 0} while True: @@ -156,9 +139,7 @@ def _get_tickets( # TODO: maybe these don't need to be their own functions? -def _get_tickets_page( - client: ZendeskClient, start_time: int | None = None -) -> ZendeskPageResponse: +def _get_tickets_page(client: ZendeskClient, start_time: int | None = None) -> ZendeskPageResponse: params = {"start_time": start_time or 0} # NOTE: for some reason zendesk doesn't seem to be respecting the start_time param @@ -166,9 +147,7 @@ def _get_tickets_page( # issue in larger deployments data = client.make_request("incremental/tickets.json", params) if data.get("error") == "SupportProductInactive": - raise ValueError( - "Zendesk Support Product is not active for this account, No tickets to index" - ) + raise ValueError("Zendesk Support Product is not active for this account, No tickets to index") return ZendeskPageResponse( data=data["tickets"], meta={"end_time": data["end_time"]}, @@ -176,9 +155,7 @@ def _get_tickets_page( ) -def _fetch_author( - client: ZendeskClient, author_id: str | int -) -> BasicExpertInfo | None: +def _fetch_author(client: ZendeskClient, author_id: str | int) -> BasicExpertInfo | None: # Skip fetching if author_id is invalid # cast to str to avoid issues with zendesk changing their types if not author_id or str(author_id) == "-1": @@ -187,11 +164,7 @@ def _fetch_author( try: author_data = client.make_request(f"users/{author_id}", {}) user = author_data.get("user") - return ( - BasicExpertInfo(display_name=user.get("name"), email=user.get("email")) - if user and user.get("name") and user.get("email") - else None - ) + return BasicExpertInfo(display_name=user.get("name"), email=user.get("email")) if user and user.get("name") and user.get("email") else None except requests.exceptions.HTTPError: # Handle any API errors gracefully return None @@ -207,11 +180,7 @@ def _article_to_document( if not author_id: author = None else: - author = ( - author_map.get(author_id) - if author_id in author_map - else _fetch_author(client, author_id) - ) + author = author_map.get(author_id) if author_id in author_map else _fetch_author(client, author_id) new_author_mapping = {author_id: author} if author_id and author else None @@ -223,11 +192,7 @@ def _article_to_document( # Build metadata metadata: dict[str, str | list[str]] = { "labels": [str(label) for label in article.get("label_names", []) if label], - "content_tags": [ - content_tags[tag_id] - for tag_id in article.get("content_tag_ids", []) - if tag_id in content_tags - ], + "content_tags": [content_tags[tag_id] for tag_id in article.get("content_tag_ids", []) if tag_id in content_tags], } # Remove empty values @@ -248,14 +213,7 @@ def _article_to_document( def _is_indexable_article(article: dict[str, Any]) -> bool: body = article.get("body") - return ( - bool(body) - and not article.get("draft") - and not any( - label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS - for label in article.get("label_names") or [] - ) - ) + return bool(body) and not article.get("draft") and not any(label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS for label in article.get("label_names") or []) def _get_comment_text( @@ -267,11 +225,7 @@ def _get_comment_text( if not author_id: author = None else: - author = ( - author_map.get(author_id) - if author_id in author_map - else _fetch_author(client, author_id) - ) + author = author_map.get(author_id) if author_id in author_map else _fetch_author(client, author_id) new_author_mapping = {author_id: author} if author_id and author else None @@ -290,15 +244,9 @@ def _ticket_to_document( if not submitter_id: submitter = None else: - submitter = ( - author_map.get(submitter_id) - if submitter_id in author_map - else _fetch_author(client, submitter_id) - ) + submitter = author_map.get(submitter_id) if submitter_id in author_map else _fetch_author(client, submitter_id) - new_author_mapping = ( - {submitter_id: submitter} if submitter_id and submitter else None - ) + new_author_mapping = {submitter_id: submitter} if submitter_id and submitter else None updated_at = ticket.get("updated_at") update_time = time_str_to_utc(updated_at) if updated_at else None @@ -319,9 +267,7 @@ def _ticket_to_document( comment_texts = [] for comment in comments: - new_author_mapping, comment_text = _get_comment_text( - comment, author_map, client - ) + new_author_mapping, comment_text = _get_comment_text(comment, author_map, client) if new_author_mapping: author_map.update(new_author_mapping) comment_texts.append(comment_text) @@ -360,9 +306,7 @@ class ZendeskConnectorCheckpoint(ConnectorCheckpoint): cached_content_tags: dict[str, str] | None -class ZendeskConnector( - SlimConnectorWithPermSync, CheckpointedConnector[ZendeskConnectorCheckpoint] -): +class ZendeskConnector(SlimConnectorWithPermSync, CheckpointedConnector[ZendeskConnectorCheckpoint]): def __init__( self, content_type: str = "articles", @@ -376,11 +320,7 @@ class ZendeskConnector( def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: # Subdomain is actually the whole URL - subdomain = ( - credentials["zendesk_subdomain"] - .replace("https://", "") - .split(".zendesk.com")[0] - ) + subdomain = credentials["zendesk_subdomain"].replace("https://", "").split(".zendesk.com")[0] self.subdomain = subdomain self.client = ZendeskClient( @@ -439,9 +379,7 @@ class ZendeskConnector( continue try: - new_author_map, document = _article_to_document( - article, self.content_tags, author_map, self.client - ) + new_author_map, document = _article_to_document(article, self.content_tags, author_map, self.client) except Exception as e: logging.error(f"Error processing article {article['id']}: {e}") yield ConnectorFailure( @@ -477,14 +415,8 @@ class ZendeskConnector( checkpoint.after_cursor_articles = after_cursor last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None - checkpoint.has_more = bool( - end is None - or last_doc_updated_at is None - or last_doc_updated_at.timestamp() <= end - ) - checkpoint.cached_author_map = ( - author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None - ) + checkpoint.has_more = bool(end is None or last_doc_updated_at is None or last_doc_updated_at.timestamp() <= end) + checkpoint.cached_author_map = author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None return checkpoint def _retrieve_tickets( @@ -550,14 +482,8 @@ class ZendeskConnector( yield from doc_batch checkpoint.next_start_time_tickets = next_start_time last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None - checkpoint.has_more = bool( - end is None - or last_doc_updated_at is None - or last_doc_updated_at.timestamp() <= end - ) - checkpoint.cached_author_map = ( - author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None - ) + checkpoint.has_more = bool(end is None or last_doc_updated_at is None or last_doc_updated_at.timestamp() <= end) + checkpoint.cached_author_map = author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None return checkpoint def retrieve_all_slim_docs_perm_sync( @@ -606,26 +532,16 @@ class ZendeskConnector( except HTTPError as e: # Check for HTTP status codes if e.response.status_code == 401: - raise CredentialExpiredError( - "Your Zendesk credentials appear to be invalid or expired (HTTP 401)." - ) from e + raise CredentialExpiredError("Your Zendesk credentials appear to be invalid or expired (HTTP 401).") from e elif e.response.status_code == 403: - raise InsufficientPermissionsError( - "Your Zendesk token does not have sufficient permissions (HTTP 403)." - ) from e + raise InsufficientPermissionsError("Your Zendesk token does not have sufficient permissions (HTTP 403).") from e elif e.response.status_code == 404: - raise ConnectorValidationError( - "Zendesk resource not found (HTTP 404)." - ) from e + raise ConnectorValidationError("Zendesk resource not found (HTTP 404).") from e else: - raise ConnectorValidationError( - f"Unexpected Zendesk error (status={e.response.status_code}): {e}" - ) from e + raise ConnectorValidationError(f"Unexpected Zendesk error (status={e.response.status_code}): {e}") from e @override - def validate_checkpoint_json( - self, checkpoint_json: str - ) -> ZendeskConnectorCheckpoint: + def validate_checkpoint_json(self, checkpoint_json: str) -> ZendeskConnectorCheckpoint: return ZendeskConnectorCheckpoint.model_validate_json(checkpoint_json) @override @@ -657,9 +573,7 @@ if __name__ == "__main__": checkpoint = connector.build_dummy_checkpoint() while checkpoint.has_more: - gen = connector.load_from_checkpoint( - one_day_ago, current, checkpoint - ) + gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint) wrapper = CheckpointOutputWrapper() any_doc = False diff --git a/common/decorator.py b/common/decorator.py index 7dd0319f43..6085498d31 100644 --- a/common/decorator.py +++ b/common/decorator.py @@ -59,6 +59,7 @@ def timing(func=None, *, name=None, context=None): log = logging.getLogger(__name__) if inspect.iscoroutinefunction(func): + @functools.wraps(func) async def async_wrapper(*args, **kwargs): start = time.perf_counter() @@ -70,8 +71,10 @@ def timing(func=None, *, name=None, context=None): log.debug(f"[TIMING] {func_name} took {elapsed:.3f}s") if context is not None: context.record(f"{func_name}_time", elapsed) + return async_wrapper else: + @functools.wraps(func) def sync_wrapper(*args, **kwargs): start = time.perf_counter() @@ -83,4 +86,5 @@ def timing(func=None, *, name=None, context=None): log.debug(f"[TIMING] {func_name} took {elapsed:.3f}s") if context is not None: context.record(f"{func_name}_time", elapsed) - return sync_wrapper \ No newline at end of file + + return sync_wrapper diff --git a/common/doc_store/doc_store_base.py b/common/doc_store/doc_store_base.py index fd684baef2..edecd0531c 100644 --- a/common/doc_store/doc_store_base.py +++ b/common/doc_store/doc_store_base.py @@ -21,6 +21,7 @@ DEFAULT_MATCH_VECTOR_TOPN = 10 DEFAULT_MATCH_SPARSE_TOPN = 10 VEC = list | np.ndarray + @dataclass class SparseVector: indices: list[int] @@ -53,6 +54,7 @@ class SparseVector: def __repr__(self): return str(self) + class MatchTextExpr: def __init__( self, @@ -130,12 +132,15 @@ MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | class OrderByExpr: def __init__(self): self.fields = list() + def asc(self, field: str): self.fields.append((field, 0)) return self + def desc(self, field: str): self.fields.append((field, 1)) return self + def fields(self): return self.fields @@ -190,17 +195,18 @@ class DocStoreConnection(ABC): @abstractmethod def search( - self, select_fields: list[str], - highlight_fields: list[str], - condition: dict, - match_expressions: list[MatchExpr], - order_by: OrderByExpr, - offset: int, - limit: int, - index_names: str|list[str], - dataset_ids: list[str], - agg_fields: list[str] | None = None, - rank_feature: dict | None = None + self, + select_fields: list[str], + highlight_fields: list[str], + condition: dict, + match_expressions: list[MatchExpr], + order_by: OrderByExpr, + offset: int, + limit: int, + index_names: str | list[str], + dataset_ids: list[str], + agg_fields: list[str] | None = None, + rank_feature: dict | None = None, ): """ Search with given conjunctive equivalent filtering condition and return all fields of matched documents @@ -262,6 +268,7 @@ class DocStoreConnection(ABC): """ SQL """ + @abstractmethod def sql(self, sql: str, fetch_size: int, format: str): """ diff --git a/common/doc_store/es_conn_base.py b/common/doc_store/es_conn_base.py index 5d1a718647..e745a697b3 100644 --- a/common/doc_store/es_conn_base.py +++ b/common/doc_store/es_conn_base.py @@ -35,7 +35,7 @@ ATTEMPT_TIME = 2 class ESConnectionBase(DocStoreConnection): - def __init__(self, mapping_file_name: str="mapping.json", logger_name: str='ragflow.es_conn'): + def __init__(self, mapping_file_name: str = "mapping.json", logger_name: str = "ragflow.es_conn"): from common.doc_store.es_conn_pool import ES_CONN self.logger = logging.getLogger(logger_name) @@ -79,42 +79,34 @@ class ESConnectionBase(DocStoreConnection): raw_stats = self.es.cluster.stats() self.logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}") try: - res = { - 'cluster_name': raw_stats['cluster_name'], - 'status': raw_stats['status'] - } - indices_status = raw_stats['indices'] - res.update({ - 'indices': indices_status['count'], - 'indices_shards': indices_status['shards']['total'] - }) - doc_info = indices_status['docs'] - res.update({ - 'docs': doc_info['count'], - 'docs_deleted': doc_info['deleted'] - }) - store_info = indices_status['store'] - res.update({ - 'store_size': convert_bytes(store_info['size_in_bytes']), - 'total_dataset_size': convert_bytes(store_info['total_data_set_size_in_bytes']) - }) - mappings_info = indices_status['mappings'] - res.update({ - 'mappings_fields': mappings_info['total_field_count'], - 'mappings_deduplicated_fields': mappings_info['total_deduplicated_field_count'], - 'mappings_deduplicated_size': convert_bytes(mappings_info['total_deduplicated_mapping_size_in_bytes']) - }) - node_info = raw_stats['nodes'] - res.update({ - 'nodes': node_info['count']['total'], - 'nodes_version': node_info['versions'], - 'os_mem': convert_bytes(node_info['os']['mem']['total_in_bytes']), - 'os_mem_used': convert_bytes(node_info['os']['mem']['used_in_bytes']), - 'os_mem_used_percent': node_info['os']['mem']['used_percent'], - 'jvm_versions': node_info['jvm']['versions'][0]['vm_version'], - 'jvm_heap_used': convert_bytes(node_info['jvm']['mem']['heap_used_in_bytes']), - 'jvm_heap_max': convert_bytes(node_info['jvm']['mem']['heap_max_in_bytes']) - }) + res = {"cluster_name": raw_stats["cluster_name"], "status": raw_stats["status"]} + indices_status = raw_stats["indices"] + res.update({"indices": indices_status["count"], "indices_shards": indices_status["shards"]["total"]}) + doc_info = indices_status["docs"] + res.update({"docs": doc_info["count"], "docs_deleted": doc_info["deleted"]}) + store_info = indices_status["store"] + res.update({"store_size": convert_bytes(store_info["size_in_bytes"]), "total_dataset_size": convert_bytes(store_info["total_data_set_size_in_bytes"])}) + mappings_info = indices_status["mappings"] + res.update( + { + "mappings_fields": mappings_info["total_field_count"], + "mappings_deduplicated_fields": mappings_info["total_deduplicated_field_count"], + "mappings_deduplicated_size": convert_bytes(mappings_info["total_deduplicated_mapping_size_in_bytes"]), + } + ) + node_info = raw_stats["nodes"] + res.update( + { + "nodes": node_info["count"]["total"], + "nodes_version": node_info["versions"], + "os_mem": convert_bytes(node_info["os"]["mem"]["total_in_bytes"]), + "os_mem_used": convert_bytes(node_info["os"]["mem"]["used_in_bytes"]), + "os_mem_used_percent": node_info["os"]["mem"]["used_percent"], + "jvm_versions": node_info["jvm"]["versions"][0]["vm_version"], + "jvm_heap_used": convert_bytes(node_info["jvm"]["mem"]["heap_used_in_bytes"]), + "jvm_heap_max": convert_bytes(node_info["jvm"]["mem"]["heap_max_in_bytes"]), + } + ) return res except Exception as e: @@ -130,9 +122,7 @@ class ESConnectionBase(DocStoreConnection): if self.index_exist(index_name, dataset_id): return True try: - return IndicesClient(self.es).create(index=index_name, - settings=self.mapping["settings"], - mappings=self.mapping["mappings"]) + return IndicesClient(self.es).create(index=index_name, settings=self.mapping["settings"], mappings=self.mapping["mappings"]) except Exception: self.logger.exception("ESConnection.createIndex error %s" % index_name) @@ -153,9 +143,7 @@ class ESConnectionBase(DocStoreConnection): with open(fp_mapping, "r") as f: doc_meta_mapping = json.load(f) - return IndicesClient(self.es).create(index=index_name, - settings=doc_meta_mapping["settings"], - mappings=doc_meta_mapping["mappings"]) + return IndicesClient(self.es).create(index=index_name, settings=doc_meta_mapping["settings"], mappings=doc_meta_mapping["mappings"]) except Exception as e: self.logger.exception(f"Error creating document metadata index {index_name}: {e}") @@ -247,8 +235,11 @@ class ESConnectionBase(DocStoreConnection): def get(self, doc_id: str, index_name: str, dataset_ids: list[str]) -> dict | None: for i in range(ATTEMPT_TIME): try: - res = self.es.get(index=index_name, - id=doc_id, source=True, ) + res = self.es.get( + index=index_name, + id=doc_id, + source=True, + ) if str(res.get("timed_out", "")).lower() == "true": raise Exception("Es Timeout.") doc = res["_source"] @@ -264,17 +255,18 @@ class ESConnectionBase(DocStoreConnection): @abstractmethod def search( - self, select_fields: list[str], - highlight_fields: list[str], - condition: dict, - match_expressions: list[MatchExpr], - order_by: OrderByExpr, - offset: int, - limit: int, - index_names: str | list[str], - dataset_ids: list[str], - agg_fields: list[str] | None = None, - rank_feature: dict | None = None + self, + select_fields: list[str], + highlight_fields: list[str], + condition: dict, + match_expressions: list[MatchExpr], + order_by: OrderByExpr, + offset: int, + limit: int, + index_names: str | list[str], + dataset_ids: list[str], + agg_fields: list[str] | None = None, + rank_feature: dict | None = None, ): raise NotImplementedError("Not implemented") @@ -345,8 +337,7 @@ class ESConnectionBase(DocStoreConnection): txt_list = [] for t in re.split(r"[.?!;\n]", txt): for w in keywords: - t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1\2\3", t, - flags=re.IGNORECASE | re.MULTILINE) + t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1\2\3", t, flags=re.IGNORECASE | re.MULTILINE) if not re.search(r"[^<>]+", t, flags=re.IGNORECASE | re.MULTILINE): continue txt_list.append(t) @@ -372,14 +363,8 @@ class ESConnectionBase(DocStoreConnection): replaces = [] for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql): fld, v = r.group(1), r.group(3) - match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format( - fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v))) - replaces.append( - ("{}{}'{}'".format( - r.group(1), - r.group(2), - r.group(3)), - match)) + match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v))) + replaces.append(("{}{}'{}'".format(r.group(1), r.group(2), r.group(3)), match)) for p, r in replaces: sql = sql.replace(p, r, 1) @@ -387,8 +372,7 @@ class ESConnectionBase(DocStoreConnection): for i in range(ATTEMPT_TIME): try: - res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, - request_timeout="2s") + res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout="2s") return res except ConnectionTimeout: self.logger.exception("ES request timeout") diff --git a/common/doc_store/es_conn_pool.py b/common/doc_store/es_conn_pool.py index c25f5f786b..c9139430cb 100644 --- a/common/doc_store/es_conn_pool.py +++ b/common/doc_store/es_conn_pool.py @@ -25,7 +25,6 @@ ATTEMPT_TIME = 2 @singleton class ElasticSearchConnectionPool: - def __init__(self): if hasattr(settings, "ES"): self.ES_CONFIG = settings.ES @@ -54,10 +53,10 @@ class ElasticSearchConnectionPool: def _connect(self): self.es_conn = Elasticsearch( self.ES_CONFIG["hosts"].split(","), - basic_auth=(self.ES_CONFIG["username"], self.ES_CONFIG[ - "password"]) if "username" in self.ES_CONFIG and "password" in self.ES_CONFIG else None, - verify_certs= self.ES_CONFIG.get("verify_certs", False), - timeout=600 ) + basic_auth=(self.ES_CONFIG["username"], self.ES_CONFIG["password"]) if "username" in self.ES_CONFIG and "password" in self.ES_CONFIG else None, + verify_certs=self.ES_CONFIG.get("verify_certs", False), + timeout=600, + ) if self.es_conn: self.info = self.es_conn.info() return True diff --git a/common/doc_store/infinity_conn_base.py b/common/doc_store/infinity_conn_base.py index 169bd85e87..1a97e2a833 100644 --- a/common/doc_store/infinity_conn_base.py +++ b/common/doc_store/infinity_conn_base.py @@ -74,7 +74,10 @@ def _int_env(name: str, default: int) -> int: return int(raw) except ValueError: logging.getLogger(__name__).warning( - "Ignoring invalid %s=%r, falling back to %d", name, raw, default, + "Ignoring invalid %s=%r, falling back to %d", + name, + raw, + default, ) return default @@ -129,24 +132,29 @@ def _retry_on_meta_contention( last_exc = exc if attempt == max_attempts - 1: break - base = (base_delay_ms / 1000.0) * (2 ** attempt) + base = (base_delay_ms / 1000.0) * (2**attempt) sleep_for = base + random.uniform(0, base * 0.5) log.info( - "INFINITY meta contention on %s (attempt %d/%d), " - "retrying in %.3fs: %s", - op_name, attempt + 1, max_attempts, sleep_for, exc, + "INFINITY meta contention on %s (attempt %d/%d), retrying in %.3fs: %s", + op_name, + attempt + 1, + max_attempts, + sleep_for, + exc, ) time.sleep(sleep_for) log.warning( "INFINITY meta contention on %s exhausted %d attempts: %s", - op_name, max_attempts, last_exc, + op_name, + max_attempts, + last_exc, ) assert last_exc is not None raise last_exc class InfinityConnectionBase(DocStoreConnection): - def __init__(self, mapping_file_name: str = "infinity_mapping.json", logger_name: str = "ragflow.infinity_conn", table_name_prefix: str="ragflow_"): + def __init__(self, mapping_file_name: str = "infinity_mapping.json", logger_name: str = "ragflow.infinity_conn", table_name_prefix: str = "ragflow_"): from common.doc_store.infinity_conn_pool import INFINITY_CONN self.dbName = settings.INFINITY.get("db_name", "default_db") diff --git a/common/doc_store/infinity_conn_pool.py b/common/doc_store/infinity_conn_pool.py index 83ea4d51ff..8eab676f75 100644 --- a/common/doc_store/infinity_conn_pool.py +++ b/common/doc_store/infinity_conn_pool.py @@ -27,16 +27,11 @@ from common.decorator import singleton @singleton class InfinityConnectionPool: - def __init__(self): if hasattr(settings, "INFINITY"): self.INFINITY_CONFIG = settings.INFINITY else: - self.INFINITY_CONFIG = settings.get_base_config("infinity", { - "uri": "infinity:23817", - "postgres_port": 5432, - "db_name": "default_db" - }) + self.INFINITY_CONFIG = settings.get_base_config("infinity", {"uri": "infinity:23817", "postgres_port": 5432, "db_name": "default_db"}) raw_pool_max_size = os.environ.get("INFINITY_POOL_MAX_SIZE", "4") try: diff --git a/common/doc_store/ob_conn_base.py b/common/doc_store/ob_conn_base.py index c42868249e..29b2379c59 100644 --- a/common/doc_store/ob_conn_base.py +++ b/common/doc_store/ob_conn_base.py @@ -72,6 +72,7 @@ def _try_with_lock(lock_name: str, process_func, check_func, timeout: int = None if not check_func(): from rag.utils.redis_conn import RedisDistributedLock + lock = RedisDistributedLock(lock_name) if lock.acquire(): try: @@ -97,7 +98,7 @@ def _try_with_lock(lock_name: str, process_func, check_func, timeout: int = None class OBConnectionBase(DocStoreConnection): """Base class for OceanBase document store connections.""" - def __init__(self, logger_name: str = 'ragflow.ob_conn'): + def __init__(self, logger_name: str = "ragflow.ob_conn"): from common.doc_store.ob_conn_pool import OB_CONN self.logger = logging.getLogger(logger_name) @@ -119,13 +120,13 @@ class OBConnectionBase(DocStoreConnection): def _load_env_vars(self): def is_true(var: str, default: str) -> bool: - return os.getenv(var, default).lower() in ['true', '1', 'yes', 'y'] + return os.getenv(var, default).lower() in ["true", "1", "yes", "y"] - self.enable_fulltext_search = is_true('ENABLE_FULLTEXT_SEARCH', 'true') - self.use_fulltext_hint = is_true('USE_FULLTEXT_HINT', 'true') - self.search_original_content = is_true("SEARCH_ORIGINAL_CONTENT", 'true') - self.enable_hybrid_search = is_true('ENABLE_HYBRID_SEARCH', 'false') - self.use_fulltext_first_fusion_search = is_true('USE_FULLTEXT_FIRST_FUSION_SEARCH', 'true') + self.enable_fulltext_search = is_true("ENABLE_FULLTEXT_SEARCH", "true") + self.use_fulltext_hint = is_true("USE_FULLTEXT_HINT", "true") + self.search_original_content = is_true("SEARCH_ORIGINAL_CONTENT", "true") + self.enable_hybrid_search = is_true("ENABLE_HYBRID_SEARCH", "false") + self.use_fulltext_first_fusion_search = is_true("USE_FULLTEXT_FIRST_FUSION_SEARCH", "true") # Adjust settings based on hybrid search availability if self.es is not None and self.search_original_content: @@ -172,10 +173,7 @@ class OBConnectionBase(DocStoreConnection): return "oceanbase" def health(self) -> dict: - return { - "uri": self.uri, - "version_comment": self._get_variable_value("version_comment") - } + return {"uri": self.uri, "version_comment": self._get_variable_value("version_comment")} def _get_variable_value(self, var_name: str) -> Any: rows = self.client.perform_raw_text_sql(f"SHOW VARIABLES LIKE '{var_name}'") @@ -241,8 +239,7 @@ class OBConnectionBase(DocStoreConnection): for column_name in self.get_index_columns(): _try_with_lock( lock_name=f"{lock_prefix}add_idx_{table_name}_{column_name}", - check_func=lambda cn=column_name: self._index_exists(table_name, - index_name_template % (table_name, cn)), + check_func=lambda cn=column_name: self._index_exists(table_name, index_name_template % (table_name, cn)), process_func=lambda cn=column_name: self._add_index(table_name, cn), ) @@ -338,28 +335,34 @@ class OBConnectionBase(DocStoreConnection): def _get_count(self, table_name: str, filter_list: list[str] = None) -> int: where_clause = "WHERE " + " AND ".join(filter_list) if filter_list and len(filter_list) > 0 else "" - (count,) = self.client.perform_raw_text_sql( - f"SELECT COUNT(*) FROM {table_name} {where_clause}" - ).fetchone() + (count,) = self.client.perform_raw_text_sql(f"SELECT COUNT(*) FROM {table_name} {where_clause}").fetchone() return count def _column_exist(self, table_name: str, column_name: str) -> bool: - return self._get_count( - table_name="INFORMATION_SCHEMA.COLUMNS", - filter_list=[ - f"TABLE_SCHEMA = '{self.db_name}'", - f"TABLE_NAME = '{table_name}'", - f"COLUMN_NAME = '{column_name}'", - ]) > 0 + return ( + self._get_count( + table_name="INFORMATION_SCHEMA.COLUMNS", + filter_list=[ + f"TABLE_SCHEMA = '{self.db_name}'", + f"TABLE_NAME = '{table_name}'", + f"COLUMN_NAME = '{column_name}'", + ], + ) + > 0 + ) def _index_exists(self, table_name: str, idx_name: str) -> bool: - return self._get_count( - table_name="INFORMATION_SCHEMA.STATISTICS", - filter_list=[ - f"TABLE_SCHEMA = '{self.db_name}'", - f"TABLE_NAME = '{table_name}'", - f"INDEX_NAME = '{idx_name}'", - ]) > 0 + return ( + self._get_count( + table_name="INFORMATION_SCHEMA.STATISTICS", + filter_list=[ + f"TABLE_SCHEMA = '{self.db_name}'", + f"TABLE_NAME = '{table_name}'", + f"INDEX_NAME = '{idx_name}'", + ], + ) + > 0 + ) def _create_table_with_columns(self, table_name: str, columns: list[Column]): """Create table with specified columns.""" @@ -418,9 +421,7 @@ class OBConnectionBase(DocStoreConnection): column_names=[vector_field_name], vidx_params="distance=cosine, type=hnsw, lib=vsag", ) - self.logger.info( - f"Created vector index '{vector_idx_name}' on table '{table_name}' with column '{vector_field_name}'." - ) + self.logger.info(f"Created vector index '{vector_idx_name}' on table '{table_name}' with column '{vector_field_name}'.") def _add_column(self, table_name: str, column: Column): try: @@ -496,11 +497,7 @@ class OBConnectionBase(DocStoreConnection): elapsed_time = time.time() - start_time return rows, elapsed_time - def _parse_fulltext_columns( - self, - fulltext_query: str, - fulltext_columns: list[str] - ) -> tuple[dict[str, str], dict[str, float]]: + def _parse_fulltext_columns(self, fulltext_query: str, fulltext_columns: list[str]) -> tuple[dict[str, str], dict[str, float]]: """ Parse fulltext search columns with optional weight suffix and build search expressions. @@ -538,16 +535,7 @@ class OBConnectionBase(DocStoreConnection): return fulltext_search_expr, fulltext_search_weight def _build_vector_search_sql( - self, - table_name: str, - fields_expr: str, - vector_search_score_expr: str, - filters_expr: str, - vector_search_filter: str, - vector_search_expr: str, - limit: int, - vector_topn: int, - offset: int = 0 + self, table_name: str, fields_expr: str, vector_search_score_expr: str, filters_expr: str, vector_search_filter: str, vector_search_expr: str, limit: int, vector_topn: int, offset: int = 0 ) -> str: sql = ( f"SELECT {fields_expr}, {vector_search_score_expr} AS _score" @@ -561,16 +549,7 @@ class OBConnectionBase(DocStoreConnection): return sql def _build_fulltext_search_sql( - self, - table_name: str, - fields_expr: str, - fulltext_search_score_expr: str, - filters_expr: str, - fulltext_search_filter: str, - offset: int, - limit: int, - fulltext_topn: int, - hint: str = "" + self, table_name: str, fields_expr: str, fulltext_search_score_expr: str, filters_expr: str, fulltext_search_filter: str, offset: int, limit: int, fulltext_topn: int, hint: str = "" ) -> str: hint_expr = f"{hint} " if hint else "" return ( @@ -581,28 +560,10 @@ class OBConnectionBase(DocStoreConnection): f" LIMIT {offset}, {limit if limit != 0 else fulltext_topn}" ) - def _build_filter_search_sql( - self, - table_name: str, - fields_expr: str, - filters_expr: str, - order_by_expr: str = "", - limit_expr: str = "" - ) -> str: - return ( - f"SELECT {fields_expr}" - f" FROM {table_name}" - f" WHERE {filters_expr}" - f" {order_by_expr} {limit_expr}" - ) + def _build_filter_search_sql(self, table_name: str, fields_expr: str, filters_expr: str, order_by_expr: str = "", limit_expr: str = "") -> str: + return f"SELECT {fields_expr} FROM {table_name} WHERE {filters_expr} {order_by_expr} {limit_expr}" - def _build_count_sql( - self, - table_name: str, - filters_expr: str, - extra_filter: str = "", - hint: str = "" - ) -> str: + def _build_count_sql(self, table_name: str, filters_expr: str, extra_filter: str = "", hint: str = "") -> str: hint_expr = f"{hint} " if hint else "" where_clause = f"{filters_expr} AND {extra_filter}" if extra_filter else filters_expr return f"SELECT {hint_expr}COUNT(id) FROM {table_name} WHERE {where_clause}" @@ -662,6 +623,7 @@ class OBConnectionBase(DocStoreConnection): condition[self._get_dataset_id_field()] = dataset_id try: from sqlalchemy import text + res = self.client.get( table_name=index_name, ids=None, diff --git a/common/doc_store/ob_conn_pool.py b/common/doc_store/ob_conn_pool.py index 5cb995edb5..d7e23ab748 100644 --- a/common/doc_store/ob_conn_pool.py +++ b/common/doc_store/ob_conn_pool.py @@ -28,12 +28,11 @@ from common.decorator import singleton ATTEMPT_TIME = 2 OB_QUERY_TIMEOUT = int(os.environ.get("OB_QUERY_TIMEOUT", "100_000_000")) -logger = logging.getLogger('ragflow.ob_conn_pool') +logger = logging.getLogger("ragflow.ob_conn_pool") @singleton class OceanBaseConnectionPool: - def __init__(self): self.client = None self.es = None # HybridSearch client @@ -112,9 +111,7 @@ class OceanBaseConnectionPool: ob_version = ObVersion.from_db_version_string(version_str) if ob_version < ObVersion.from_db_version_nums(4, 3, 5, 1): - raise Exception( - f"The version of OceanBase needs to be higher than or equal to 4.3.5.1, current version is {version_str}" - ) + raise Exception(f"The version of OceanBase needs to be higher than or equal to 4.3.5.1, current version is {version_str}") def _try_to_update_ob_query_timeout(self): try: @@ -135,7 +132,7 @@ class OceanBaseConnectionPool: logger.warning(f"Failed to set 'ob_query_timeout' variable: {str(e)}") def _init_hybrid_search(self, max_connections, max_overflow, pool_timeout): - enable_hybrid_search = os.getenv('ENABLE_HYBRID_SEARCH', 'false').lower() in ['true', '1', 'yes', 'y'] + enable_hybrid_search = os.getenv("ENABLE_HYBRID_SEARCH", "false").lower() in ["true", "1", "yes", "y"] if enable_hybrid_search: try: self.es = HybridSearch( diff --git a/common/exceptions.py b/common/exceptions.py index bfbf245228..d5542dba6f 100644 --- a/common/exceptions.py +++ b/common/exceptions.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + class TaskCanceledException(Exception): def __init__(self, msg): self.msg = msg @@ -27,8 +28,9 @@ class NotFoundException(Exception): def __init__(self, msg): self.msg = msg + class ModelException(Exception): def __init__(self, msg, retryable=False): super().__init__(msg) self.msg = msg - self.retryable = retryable \ No newline at end of file + self.retryable = retryable diff --git a/common/file_utils.py b/common/file_utils.py index af691f9fee..3db448c25d 100644 --- a/common/file_utils.py +++ b/common/file_utils.py @@ -41,6 +41,7 @@ def get_project_base_directory(*args): return os.path.join(project_base, *args) return project_base + def traversal_files(base): for root, ds, fs in os.walk(base): for f in fs: diff --git a/common/http_client.py b/common/http_client.py index 28c988ef65..38411ab560 100644 --- a/common/http_client.py +++ b/common/http_client.py @@ -26,9 +26,7 @@ logger = logging.getLogger(__name__) # Default knobs; keep conservative to avoid unexpected behavioural changes. DEFAULT_TIMEOUT = float(os.environ.get("HTTP_CLIENT_TIMEOUT", "15")) # Align with requests default: follow redirects with a max of 30 unless overridden. -DEFAULT_FOLLOW_REDIRECTS = bool( - int(os.environ.get("HTTP_CLIENT_FOLLOW_REDIRECTS", "1")) -) +DEFAULT_FOLLOW_REDIRECTS = bool(int(os.environ.get("HTTP_CLIENT_FOLLOW_REDIRECTS", "1"))) DEFAULT_MAX_REDIRECTS = int(os.environ.get("HTTP_CLIENT_MAX_REDIRECTS", "30")) DEFAULT_MAX_RETRIES = int(os.environ.get("HTTP_CLIENT_MAX_RETRIES", "2")) DEFAULT_BACKOFF_FACTOR = float(os.environ.get("HTTP_CLIENT_BACKOFF_FACTOR", "0.5")) @@ -36,9 +34,7 @@ DEFAULT_PROXY = os.environ.get("HTTP_CLIENT_PROXY") DEFAULT_USER_AGENT = os.environ.get("HTTP_CLIENT_USER_AGENT", "ragflow-http-client") -def _clean_headers( - headers: Optional[Dict[str, str]], auth_token: Optional[str] = None -) -> Optional[Dict[str, str]]: +def _clean_headers(headers: Optional[Dict[str, str]], auth_token: Optional[str] = None) -> Optional[Dict[str, str]]: merged_headers: Dict[str, str] = {} if DEFAULT_USER_AGENT: merged_headers["User-Agent"] = DEFAULT_USER_AGENT @@ -57,6 +53,7 @@ def _get_delay(backoff_factor: float, attempt: int) -> float: # List of sensitive parameters to redact from URLs before logging _SENSITIVE_QUERY_KEYS = {"client_secret", "secret", "code", "access_token", "refresh_token", "password", "token", "app_secret"} + def _redact_sensitive_url_params(url: str) -> str: """ Return a version of the URL that is safe to log. @@ -87,6 +84,7 @@ def _redact_sensitive_url_params(url: str) -> str: # If parsing fails, fall back to omitting the URL entirely. return "" + def _is_sensitive_url(url: str) -> bool: """Return True if URL is one of the configured OAuth endpoints.""" # Collect known sensitive endpoint URLs from settings @@ -116,6 +114,7 @@ def _is_sensitive_url(url: str) -> bool: return True return False + async def async_request( method: str, url: str, @@ -132,14 +131,10 @@ async def async_request( ) -> httpx.Response: """Lightweight async HTTP wrapper using httpx.AsyncClient with safe defaults.""" timeout = request_timeout if request_timeout is not None else DEFAULT_TIMEOUT - follow_redirects = ( - DEFAULT_FOLLOW_REDIRECTS if follow_redirects is None else follow_redirects - ) + follow_redirects = DEFAULT_FOLLOW_REDIRECTS if follow_redirects is None else follow_redirects max_redirects = DEFAULT_MAX_REDIRECTS if max_redirects is None else max_redirects retries = DEFAULT_MAX_RETRIES if retries is None else max(retries, 0) - backoff_factor = ( - DEFAULT_BACKOFF_FACTOR if backoff_factor is None else backoff_factor - ) + backoff_factor = DEFAULT_BACKOFF_FACTOR if backoff_factor is None else backoff_factor headers = _clean_headers(headers, auth_token=auth_token) proxy = DEFAULT_PROXY if proxy is None else proxy @@ -153,9 +148,7 @@ async def async_request( for attempt in range(retries + 1): try: start = time.monotonic() - response = await client.request( - method=method, url=url, headers=headers, **kwargs - ) + response = await client.request(method=method, url=url, headers=headers, **kwargs) duration = time.monotonic() - start if not _is_sensitive_url(url): log_url = _redact_sensitive_url_params(url) @@ -171,9 +164,7 @@ async def async_request( delay = _get_delay(backoff_factor, attempt) if not _is_sensitive_url(url): log_url = _redact_sensitive_url_params(url) - logger.warning( - f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {log_url}; retrying in {delay:.2f}s" - ) + logger.warning(f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {log_url}; retrying in {delay:.2f}s") await asyncio.sleep(delay) raise last_exc # pragma: no cover @@ -194,14 +185,10 @@ def sync_request( ) -> httpx.Response: """Synchronous counterpart to async_request, for CLI/tests or sync contexts.""" timeout = timeout if timeout is not None else DEFAULT_TIMEOUT - follow_redirects = ( - DEFAULT_FOLLOW_REDIRECTS if follow_redirects is None else follow_redirects - ) + follow_redirects = DEFAULT_FOLLOW_REDIRECTS if follow_redirects is None else follow_redirects max_redirects = DEFAULT_MAX_REDIRECTS if max_redirects is None else max_redirects retries = DEFAULT_MAX_RETRIES if retries is None else max(retries, 0) - backoff_factor = ( - DEFAULT_BACKOFF_FACTOR if backoff_factor is None else backoff_factor - ) + backoff_factor = DEFAULT_BACKOFF_FACTOR if backoff_factor is None else backoff_factor headers = _clean_headers(headers, auth_token=auth_token) proxy = DEFAULT_PROXY if proxy is None else proxy @@ -215,25 +202,17 @@ def sync_request( for attempt in range(retries + 1): try: start = time.monotonic() - response = client.request( - method=method, url=url, headers=headers, **kwargs - ) + response = client.request(method=method, url=url, headers=headers, **kwargs) duration = time.monotonic() - start - logger.debug( - f"sync_request {method} {url} -> {response.status_code} in {duration:.3f}s" - ) + logger.debug(f"sync_request {method} {url} -> {response.status_code} in {duration:.3f}s") return response except httpx.RequestError as exc: last_exc = exc if attempt >= retries: - logger.warning( - f"sync_request exhausted retries for {method} {url}: {exc}" - ) + logger.warning(f"sync_request exhausted retries for {method} {url}: {exc}") raise delay = _get_delay(backoff_factor, attempt) - logger.warning( - f"sync_request attempt {attempt + 1}/{retries + 1} failed for {method} {url}: {exc}; retrying in {delay:.2f}s" - ) + logger.warning(f"sync_request attempt {attempt + 1}/{retries + 1} failed for {method} {url}: {exc}; retrying in {delay:.2f}s") time.sleep(delay) raise last_exc # pragma: no cover diff --git a/common/log_utils.py b/common/log_utils.py index af6b20fb2a..e1d2a6a0f5 100644 --- a/common/log_utils.py +++ b/common/log_utils.py @@ -23,6 +23,7 @@ from common.file_utils import get_project_base_directory initialized_root_logger = False pkg_levels = {} # module-level to allow runtime modification + def init_root_logger(logfile_basename: str, log_format: str = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"): global initialized_root_logger, pkg_levels if initialized_root_logger: @@ -36,7 +37,7 @@ def init_root_logger(logfile_basename: str, log_format: str = "%(asctime)-15s %( os.makedirs(os.path.dirname(log_path), exist_ok=True) formatter = logging.Formatter(log_format) - handler1 = RotatingFileHandler(log_path, maxBytes=10*1024*1024, backupCount=5) + handler1 = RotatingFileHandler(log_path, maxBytes=10 * 1024 * 1024, backupCount=5) handler1.setFormatter(formatter) logger.addHandler(handler1) @@ -49,7 +50,7 @@ def init_root_logger(logfile_basename: str, log_format: str = "%(asctime)-15s %( LOG_LEVELS = os.environ.get("LOG_LEVELS", "") for pkg_name_level in LOG_LEVELS.split(","): terms = pkg_name_level.split("=") - if len(terms)!= 2: + if len(terms) != 2: continue pkg_name, pkg_level = terms[0], terms[1] pkg_name = pkg_name.strip() @@ -58,11 +59,11 @@ def init_root_logger(logfile_basename: str, log_format: str = "%(asctime)-15s %( pkg_level = logging.INFO pkg_levels[pkg_name] = logging.getLevelName(pkg_level) - for pkg_name in ['peewee', 'pdfminer']: + for pkg_name in ["peewee", "pdfminer"]: if pkg_name not in pkg_levels: pkg_levels[pkg_name] = logging.getLevelName(logging.WARNING) - if 'root' not in pkg_levels: - pkg_levels['root'] = logging.getLevelName(logging.INFO) + if "root" not in pkg_levels: + pkg_levels["root"] = logging.getLevelName(logging.INFO) for pkg_name, pkg_level in pkg_levels.items(): pkg_logger = logging.getLogger(pkg_name) diff --git a/common/mcp_tool_call_conn.py b/common/mcp_tool_call_conn.py index 676978d052..7cf54ed6c4 100644 --- a/common/mcp_tool_call_conn.py +++ b/common/mcp_tool_call_conn.py @@ -49,7 +49,7 @@ class MCPToolBinding: class MCPToolCallSession(ToolCallSession): _ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet() - def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None, custom_header = None) -> None: + def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None, custom_header=None) -> None: self.__class__._ALL_INSTANCES.add(self) self._custom_header = custom_header @@ -123,8 +123,7 @@ class MCPToolCallSession(ToolCallSession): await self._process_mcp_tasks(None, msg) else: - await self._process_mcp_tasks(None, - f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}") + await self._process_mcp_tasks(None, f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}") async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None: while not self._close: @@ -182,8 +181,7 @@ class MCPToolCallSession(ToolCallSession): raise async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], request_timeout: float | int = 10) -> str: - result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments, - request_timeout=request_timeout) + result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments, request_timeout=request_timeout) if result.isError: return f"MCP server error: {result.content}" @@ -307,8 +305,7 @@ def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> except Exception: logging.exception("Exception during MCP session cleanup thread management") - logging.info( - f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.") + logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.") def shutdown_all_mcp_sessions(): diff --git a/common/metadata_es_filter.py b/common/metadata_es_filter.py index afe0f27386..fcdf168128 100644 --- a/common/metadata_es_filter.py +++ b/common/metadata_es_filter.py @@ -305,9 +305,7 @@ class MetaFilterTranslator: def _translate_start_with(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter: text = _coerce_string(value, flt) - return TranslatedFilter( - must=[{"prefix": {_keyword_path(field_path): {"value": text, "case_insensitive": True}}}] - ) + return TranslatedFilter(must=[{"prefix": {_keyword_path(field_path): {"value": text, "case_insensitive": True}}}]) def _translate_end_with(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter: text = _coerce_string(value, flt) diff --git a/common/metadata_infinity_filter.py b/common/metadata_infinity_filter.py index 076cc2e23e..238afd3596 100644 --- a/common/metadata_infinity_filter.py +++ b/common/metadata_infinity_filter.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -"""Translate RAGflow document-metadata filter lists into Infinity SQL filter expressions. -""" +"""Translate RAGflow document-metadata filter lists into Infinity SQL filter expressions.""" from __future__ import annotations @@ -29,6 +28,7 @@ def _validate_key(key: str, flt: Dict[str, Any]) -> None: if not _KEY_PATTERN.match(key): raise ValueError(f"invalid key format (must be identifier-like): {flt}") + SUPPORTED_OPERATORS: frozenset[str] = frozenset( { "=", @@ -55,6 +55,7 @@ _RANGE_OPS: Dict[str, str] = { "≤": "<=", } + class MetaFilterTranslator: """Translate one user filter clause at a time into Infinity SQL filter strings.""" @@ -293,4 +294,4 @@ def _escape_sql_string(s: str) -> str: def _escape_likeWildcards(text: str) -> str: - return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") \ No newline at end of file + return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") diff --git a/common/metadata_utils.py b/common/metadata_utils.py index a6c6d273dc..130413402d 100644 --- a/common/metadata_utils.py +++ b/common/metadata_utils.py @@ -23,21 +23,8 @@ import json_repair def convert_conditions(metadata_condition): if metadata_condition is None: metadata_condition = {} - op_mapping = { - "is": "=", - "not is": "≠", - ">=": "≥", - "<=": "≤", - "!=": "≠" - } - return [ - { - "op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]), - "key": cond["name"], - "value": cond["value"] - } - for cond in metadata_condition.get("conditions", []) - ] + op_mapping = {"is": "=", "not is": "≠", ">=": "≥", "<=": "≤", "!=": "≠"} + return [{"op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]), "key": cond["name"], "value": cond["value"]} for cond in metadata_condition.get("conditions", [])] def meta_filter(metas: dict, filters: list[dict], logic: str = "and"): @@ -53,30 +40,15 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"): def filter_out(v2docs, operator, value): ids = [] for input, docids in v2docs.items(): - if operator in ["=", "≠", ">", "<", "≥", "≤"]: # Check if input is in YYYY-MM-DD date format input_str = str(input).strip() value_str = str(value).strip() # Strict date format detection: YYYY-MM-DD (must be 10 chars with correct format) - is_input_date = ( - len(input_str) == 10 and - input_str[4] == '-' and - input_str[7] == '-' and - input_str[:4].isdigit() and - input_str[5:7].isdigit() and - input_str[8:10].isdigit() - ) + is_input_date = len(input_str) == 10 and input_str[4] == "-" and input_str[7] == "-" and input_str[:4].isdigit() and input_str[5:7].isdigit() and input_str[8:10].isdigit() - is_value_date = ( - len(value_str) == 10 and - value_str[4] == '-' and - value_str[7] == '-' and - value_str[:4].isdigit() and - value_str[5:7].isdigit() and - value_str[8:10].isdigit() - ) + is_value_date = len(value_str) == 10 and value_str[4] == "-" and value_str[7] == "-" and value_str[:4].isdigit() and value_str[5:7].isdigit() and value_str[8:10].isdigit() if is_value_date: # Query value is in date format @@ -110,23 +82,17 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"): matched = False try: if operator == "contains": - matched = str(input).find(value) >= 0 if not isinstance(input, list) else any( - str(i).find(value) >= 0 for i in input) + matched = str(input).find(value) >= 0 if not isinstance(input, list) else any(str(i).find(value) >= 0 for i in input) elif operator == "not contains": - matched = str(input).find(value) == -1 if not isinstance(input, list) else all( - str(i).find(value) == -1 for i in input) + matched = str(input).find(value) == -1 if not isinstance(input, list) else all(str(i).find(value) == -1 for i in input) elif operator == "in": matched = input in value if not isinstance(input, list) else all(i in value for i in input) elif operator == "not in": matched = input not in value if not isinstance(input, list) else all(i not in value for i in input) elif operator == "start with": - matched = str(input).lower().startswith(str(value).lower()) if not isinstance(input, - list) else "".join( - [str(i).lower() for i in input]).startswith(str(value).lower()) + matched = str(input).lower().startswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).startswith(str(value).lower()) elif operator == "end with": - matched = str(input).lower().endswith(str(value).lower()) if not isinstance(input, - list) else "".join( - [str(i).lower() for i in input]).endswith(str(value).lower()) + matched = str(input).lower().endswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).endswith(str(value).lower()) elif operator == "empty": matched = not input elif operator == "not empty": @@ -173,14 +139,14 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"): async def apply_meta_data_filter( - meta_data_filter: dict | None, - metas: dict | None = None, - question: str = "", - chat_mdl: Any = None, - base_doc_ids: list[str] | None = None, - manual_value_resolver: Callable[[dict], dict] | None = None, - kb_ids: list[str] | None = None, - metas_loader: Callable[[], dict] | None = None, + meta_data_filter: dict | None, + metas: dict | None = None, + question: str = "", + chat_mdl: Any = None, + base_doc_ids: list[str] | None = None, + manual_value_resolver: Callable[[dict], dict] | None = None, + kb_ids: list[str] | None = None, + metas_loader: Callable[[], dict] | None = None, ) -> list[str] | None: """ Apply metadata filtering rules and return the filtered doc_ids. @@ -232,6 +198,7 @@ async def apply_meta_data_filter( if conditions and kb_ids: try: from api.db.services.doc_metadata_service import DocMetadataService + doc_ids = DocMetadataService.filter_doc_ids_by_meta_pushdown(kb_ids, conditions, logic) logging.debug(f"Doc ids filtered by metadata: {doc_ids}") if doc_ids is not None: @@ -285,9 +252,9 @@ async def apply_meta_data_filter( def _try_meta_pushdown( - kb_ids: list[str], - conditions: list[dict], - logic: str, + kb_ids: list[str], + conditions: list[dict], + logic: str, ) -> list[str] | None: """Attempt the ES push-down path; return ``None`` to fall back in-memory. @@ -364,9 +331,7 @@ def metadata_schema(metadata: dict | list | None) -> Dict[str, Any]: if not key: continue - prop_schema = { - "description": item.get("description", "") - } + prop_schema = {"description": item.get("description", "")} if "enum" in item and item["enum"]: prop_schema["enum"] = item["enum"] prop_schema["type"] = "string" diff --git a/common/misc_utils.py b/common/misc_utils.py index 3fe8de5fd2..c1a7bbb50f 100644 --- a/common/misc_utils.py +++ b/common/misc_utils.py @@ -97,7 +97,6 @@ async def download_img(url): if not location: logger.warning( "download_img redirect missing Location header: status=%s redirect_hops=%s", - response.status_code, redirect_hops, ) @@ -106,7 +105,6 @@ async def download_img(url): if response.status_code != 200: logger.warning( "download_img non-200 response: status=%s redirect_hops=%s", - response.status_code, redirect_hops, ) @@ -122,19 +120,13 @@ async def download_img(url): # the URL query string. Only the static # threshold value is logged. "download_img response exceeded max size: max_bytes=%s", - _OAUTH_AVATAR_MAX_BYTES, ) await response.aclose() return ("fail", None) body.extend(chunk) content_type = response.headers.get("Content-Type", "image/jpeg") - data_uri = ( - "data:" - + content_type - + ";base64," - + base64.b64encode(bytes(body)).decode("utf-8") - ) + data_uri = "data:" + content_type + ";base64," + base64.b64encode(bytes(body)).decode("utf-8") return ("data", data_uri) try: @@ -168,16 +160,16 @@ async def download_img(url): # hop count and configured max are logged. logger.warning( "download_img redirect hop limit exceeded: redirect_hops=%s max_redirects=%s", - redirect_hops, _OAUTH_AVATAR_MAX_REDIRECTS, ) return "" -def hash_str2int(line: str, mod: int = 10 ** 8) -> int: +def hash_str2int(line: str, mod: int = 10**8) -> int: return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod + def convert_bytes(size_in_bytes: int) -> str: """ Format size in bytes. @@ -185,7 +177,7 @@ def convert_bytes(size_in_bytes: int) -> str: if size_in_bytes == 0: return "0 B" - units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB'] + units = ["B", "KB", "MB", "GB", "TB", "PB"] i = 0 size = float(size_in_bytes) @@ -227,6 +219,7 @@ def once(func): executed = False result = None lock = threading.Lock() + def wrapper(*args, **kwargs): nonlocal executed, result with lock: @@ -234,12 +227,14 @@ def once(func): executed = True result = func(*args, **kwargs) return result + return wrapper + @once def pip_install_torch(): device = os.getenv("DEVICE", "cpu") - if device=="cpu": + if device == "cpu": return logging.info("Installing pytorch") pkg_names = ["torch>=2.5.0,<3.0.0"] diff --git a/common/query_base.py b/common/query_base.py index ef7ba23d1f..51017d776c 100644 --- a/common/query_base.py +++ b/common/query_base.py @@ -18,7 +18,6 @@ from abc import ABC, abstractmethod class QueryBase(ABC): - @staticmethod def is_chinese(line): arr = re.split(r"[ \t]+", line) @@ -46,7 +45,8 @@ class QueryBase(ABC): (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "), ( r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ", - " ") + " ", + ), ] otxt = txt for r, p in patts: @@ -58,12 +58,12 @@ class QueryBase(ABC): @staticmethod def add_space_between_eng_zh(txt): # (ENG/ENG+NUM) + ZH - txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt) + txt = re.sub(r"([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)", r"\1 \2", txt) # ENG + ZH - txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt) + txt = re.sub(r"([A-Za-z])([\u4e00-\u9fa5]+)", r"\1 \2", txt) # ZH + (ENG/ENG+NUM) - txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt) - txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt) + txt = re.sub(r"([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)", r"\1 \2", txt) + txt = re.sub(r"([\u4e00-\u9fa5]+)([A-Za-z])", r"\1 \2", txt) return txt @abstractmethod diff --git a/common/settings.py b/common/settings.py index 1c313b3494..d5270eb506 100644 --- a/common/settings.py +++ b/common/settings.py @@ -82,9 +82,9 @@ HTTP_APP_KEY = None GITHUB_OAUTH = None FEISHU_OAUTH = None OAUTH_CONFIG = None -DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch') -DOC_ENGINE_INFINITY = (DOC_ENGINE.lower() == "infinity") -DOC_ENGINE_OCEANBASE = (DOC_ENGINE.lower() == "oceanbase") +DOC_ENGINE = os.getenv("DOC_ENGINE", "elasticsearch") +DOC_ENGINE_INFINITY = DOC_ENGINE.lower() == "infinity" +DOC_ENGINE_OCEANBASE = DOC_ENGINE.lower() == "oceanbase" docStoreConn = None @@ -130,21 +130,22 @@ EMBEDDING_BATCH_SIZE: int = 16 PARALLEL_DEVICES: int = 0 -STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO') +STORAGE_IMPL_TYPE = os.getenv("STORAGE_IMPL", "MINIO") STORAGE_IMPL = None + def get_svr_queue_name(priority: int, suffix: str = "common") -> str: """ Generate queue name with two dimensions: priority and suffix. - + Args: priority: Task priority (0=low, 1=high) suffix: Task type suffix (common/resume/graphrag/raptor/mindmap) Currently only "common" is used, other suffixes are reserved. - + Returns: Queue name string - + Examples: get_svr_queue_name(0, "common") -> "te.0.common" get_svr_queue_name(1, "common") -> "te.1.common" @@ -154,10 +155,11 @@ def get_svr_queue_name(priority: int, suffix: str = "common") -> str: return f"{SVR_QUEUE_NAME}.{priority}.common" -def get_svr_queue_names(suffix:str): +def get_svr_queue_names(suffix: str): """Return queue names sorted by priority (high to low).""" return [get_svr_queue_name(priority, suffix) for priority in [1, 0]] + def init_secret_key(): secret_key = os.environ.get("RAGFLOW_SECRET_KEY") if secret_key and len(secret_key) >= 32: @@ -176,6 +178,7 @@ def get_secret_key(): return _get_or_create_secret_key() return SECRET_KEY + def _get_or_create_secret_key(): # secret_key = os.environ.get("RAGFLOW_SECRET_KEY") # if secret_key and len(secret_key) >= 32: @@ -195,6 +198,7 @@ def _get_or_create_secret_key(): logging.warning("SECURITY WARNING: Using auto-generated SECRET_KEY.") return secret_key + class StorageFactory: storage_mapping = { Storage.MINIO: RAGFlowMinio, @@ -215,7 +219,7 @@ def init_settings(): global DATABASE_TYPE, DATABASE DATABASE_TYPE = os.getenv("DB_TYPE", "mysql") DATABASE = decrypt_database_config(name=DATABASE_TYPE) - + global ALLOWED_LLM_FACTORIES, LLM_FACTORY, LLM_BASE_URL llm_settings = get_base_config("user_default_llm", {}) or {} llm_default_models = llm_settings.get("default_models", {}) or {} @@ -285,7 +289,6 @@ def init_settings(): global SECRET_KEY SECRET_KEY = init_secret_key() - # authentication authentication_conf = get_base_config("authentication", {}) @@ -299,18 +302,14 @@ def init_settings(): global DOC_ENGINE, DOC_ENGINE_INFINITY, DOC_ENGINE_OCEANBASE, docStoreConn, ES, OB, OS, INFINITY DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch").strip() - DOC_ENGINE_INFINITY = (DOC_ENGINE.lower() == "infinity") - DOC_ENGINE_OCEANBASE = (DOC_ENGINE.lower() == "oceanbase") + DOC_ENGINE_INFINITY = DOC_ENGINE.lower() == "infinity" + DOC_ENGINE_OCEANBASE = DOC_ENGINE.lower() == "oceanbase" lower_case_doc_engine = DOC_ENGINE.lower() if lower_case_doc_engine == "elasticsearch": ES = get_base_config("es", {}) docStoreConn = rag.utils.es_conn.ESConnection() elif lower_case_doc_engine == "infinity": - INFINITY = get_base_config("infinity", { - "uri": "infinity:23817", - "postgres_port": 5432, - "db_name": "default_db" - }) + INFINITY = get_base_config("infinity", {"uri": "infinity:23817", "postgres_port": 5432, "db_name": "default_db"}) docStoreConn = rag.utils.infinity_conn.InfinityConnection() elif lower_case_doc_engine == "opensearch": OS = get_base_config("os", {}) @@ -330,44 +329,38 @@ def init_settings(): ES = get_base_config("es", {}) msgStoreConn = memory_es_conn.ESConnection() elif DOC_ENGINE == "infinity": - INFINITY = get_base_config("infinity", { - "uri": "infinity:23817", - "postgres_port": 5432, - "db_name": "default_db" - }) + INFINITY = get_base_config("infinity", {"uri": "infinity:23817", "postgres_port": 5432, "db_name": "default_db"}) msgStoreConn = memory_infinity_conn.InfinityConnection() elif lower_case_doc_engine in ["oceanbase", "seekdb"]: msgStoreConn = memory_ob_conn.OBConnection() global AZURE, S3, MINIO, OSS, GCS - if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']: + if STORAGE_IMPL_TYPE in ["AZURE_SPN", "AZURE_SAS"]: AZURE = get_base_config("azure", {}) - elif STORAGE_IMPL_TYPE == 'AWS_S3': + elif STORAGE_IMPL_TYPE == "AWS_S3": S3 = get_base_config("s3", {}) - elif STORAGE_IMPL_TYPE == 'MINIO': + elif STORAGE_IMPL_TYPE == "MINIO": MINIO = decrypt_database_config(name="minio") - elif STORAGE_IMPL_TYPE == 'OSS': + elif STORAGE_IMPL_TYPE == "OSS": OSS = get_base_config("oss", {}) - elif STORAGE_IMPL_TYPE == 'GCS': + elif STORAGE_IMPL_TYPE == "GCS": GCS = get_base_config("gcs", {}) global STORAGE_IMPL storage_impl = StorageFactory.create(Storage[STORAGE_IMPL_TYPE]) - + # Define crypto settings crypto_enabled = os.environ.get("RAGFLOW_CRYPTO_ENABLED", "false").lower() == "true" - + # Check if encryption is enabled if crypto_enabled: try: from rag.utils.encrypted_storage import create_encrypted_storage + algorithm = os.environ.get("RAGFLOW_CRYPTO_ALGORITHM", "aes-256-cbc") crypto_key = os.environ.get("RAGFLOW_CRYPTO_KEY") - - STORAGE_IMPL = create_encrypted_storage(storage_impl, - algorithm=algorithm, - key=crypto_key, - encryption_enabled=crypto_enabled) + + STORAGE_IMPL = create_encrypted_storage(storage_impl, algorithm=algorithm, key=crypto_key, encryption_enabled=crypto_enabled) except Exception as e: logging.error(f"Failed to initialize encrypted storage: {e}") STORAGE_IMPL = storage_impl @@ -412,11 +405,13 @@ def check_and_install_torch(): try: pip_install_torch() import torch.cuda + PARALLEL_DEVICES = torch.cuda.device_count() logging.info(f"found {PARALLEL_DEVICES} gpus") except Exception: logging.info("can't import package 'torch'") + def _parse_model_entry(entry): if isinstance(entry, str): return {"name": entry, "factory": None, "api_key": None, "base_url": None} @@ -447,7 +442,7 @@ def _resolve_per_model_config(entry_dict, backup_factory, backup_api_key, backup "base_url": m_base_url, } + def print_rag_settings(): logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}") logging.info(f"MAX_FILE_COUNT_PER_USER: {int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))}") - diff --git a/common/signal_utils.py b/common/signal_utils.py index eb814325ae..f248b5972e 100644 --- a/common/signal_utils.py +++ b/common/signal_utils.py @@ -21,6 +21,7 @@ import logging import tracemalloc from common.log_utils import get_project_base_directory + # SIGUSR1 handler: start tracemalloc and take snapshot def start_tracemalloc_and_snapshot(signum, frame): if not tracemalloc.is_tracing(): @@ -37,11 +38,13 @@ def start_tracemalloc_and_snapshot(signum, frame): snapshot.dump(snapshot_file) current, peak = tracemalloc.get_traced_memory() if sys.platform == "win32": - import psutil + import psutil + process = psutil.Process() max_rss = process.memory_info().rss / 1024 else: import resource + max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB") diff --git a/common/string_utils.py b/common/string_utils.py index ba8371311b..5271709b00 100644 --- a/common/string_utils.py +++ b/common/string_utils.py @@ -63,11 +63,11 @@ def clean_markdown_block(text): """ # Remove opening ```Markdown tag with optional whitespace and newlines # Matches: optional whitespace + ```markdown + optional whitespace + optional newline - text = re.sub(r'^\s*```markdown\s*\n?', '', text) + text = re.sub(r"^\s*```markdown\s*\n?", "", text) # Remove closing ``` tag with optional whitespace and newlines # Matches: optional newline + optional whitespace + ``` + optional whitespace at end - text = re.sub(r'\n?\s*```\s*$', '', text) + text = re.sub(r"\n?\s*```\s*$", "", text) # Return text with surrounding whitespace removed return text.strip() diff --git a/common/time_utils.py b/common/time_utils.py index 50ea2bcaf1..5f1753e905 100644 --- a/common/time_utils.py +++ b/common/time_utils.py @@ -17,6 +17,7 @@ import datetime import logging import time + def current_timestamp(): """ Get the current timestamp in milliseconds. @@ -74,6 +75,7 @@ def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"): time_stamp = int(time.mktime(time_array) * 1000) return time_stamp + def datetime_format(date_time: datetime.datetime) -> datetime.datetime: """ Normalize a datetime object by removing microsecond component. @@ -92,8 +94,7 @@ def datetime_format(date_time: datetime.datetime) -> datetime.datetime: >>> datetime_format(dt) datetime.datetime(2024, 1, 1, 12, 30, 45) """ - return datetime.datetime(date_time.year, date_time.month, date_time.day, - date_time.hour, date_time.minute, date_time.second) + return datetime.datetime(date_time.year, date_time.month, date_time.day, date_time.hour, date_time.minute, date_time.second) def get_format_time() -> datetime.datetime: diff --git a/common/versions.py b/common/versions.py index 082927c585..2507ab1048 100644 --- a/common/versions.py +++ b/common/versions.py @@ -24,11 +24,7 @@ def get_ragflow_version() -> str: global RAGFLOW_VERSION_INFO if RAGFLOW_VERSION_INFO != "unknown": return RAGFLOW_VERSION_INFO - version_path = os.path.abspath( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), os.pardir, "VERSION" - ) - ) + version_path = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, "VERSION")) if os.path.exists(version_path): with open(version_path, "r") as f: RAGFLOW_VERSION_INFO = f.read().strip() @@ -40,11 +36,7 @@ def get_ragflow_version() -> str: def get_closest_tag_and_count(): try: # Get the current commit hash - version_info = ( - subprocess.check_output(["git", "describe", "--tags", "--match=v*", "--first-parent", "--always"]) - .strip() - .decode("utf-8") - ) + version_info = subprocess.check_output(["git", "describe", "--tags", "--match=v*", "--first-parent", "--always"]).strip().decode("utf-8") return version_info except Exception: return "unknown" diff --git a/deepdoc/__init__.py b/deepdoc/__init__.py index 643f79713c..3c485c025d 100644 --- a/deepdoc/__init__.py +++ b/deepdoc/__init__.py @@ -15,4 +15,5 @@ # from beartype.claw import beartype_this_package + beartype_this_package() diff --git a/deepdoc/parser/docling_parser.py b/deepdoc/parser/docling_parser.py index 097e9c9045..6e4dadd771 100644 --- a/deepdoc/parser/docling_parser.py +++ b/deepdoc/parser/docling_parser.py @@ -40,11 +40,12 @@ except Exception: try: from deepdoc.parser.pdf_parser import RAGFlowPdfParser except Exception: - class RAGFlowPdfParser: + + class RAGFlowPdfParser: pass -from deepdoc.parser.utils import extract_pdf_outlines +from deepdoc.parser.utils import extract_pdf_outlines class DoclingContentType(str, Enum): @@ -56,7 +57,7 @@ class DoclingContentType(str, Enum): @dataclass class _BBox: - page_no: int + page_no: int x0: float y0: float x1: float @@ -67,17 +68,17 @@ def _extract_bbox_from_prov(item, prov_attr: str = "prov") -> Optional[_BBox]: prov = getattr(item, prov_attr, None) if not prov: return None - + prov_item = prov[0] if isinstance(prov, list) else prov pn = getattr(prov_item, "page_no", None) bb = getattr(prov_item, "bbox", None) if pn is None or bb is None: return None - + coords = [getattr(bb, attr) for attr in ("l", "t", "r", "b")] if None in coords: return None - + return _BBox(page_no=int(pn), x0=coords[0], y0=coords[1], x1=coords[2], y1=coords[3]) @@ -92,9 +93,7 @@ class DoclingParser(RAGFlowPdfParser): self.request_timeout = request_timeout def _effective_server_url(self, docling_server_url: Optional[str] = None) -> str: - return (docling_server_url or self.docling_server_url or "").rstrip("/") or ( - os.environ.get("DOCLING_SERVER_URL", "").rstrip("/") - ) + return (docling_server_url or self.docling_server_url or "").rstrip("/") or (os.environ.get("DOCLING_SERVER_URL", "").rstrip("/")) @staticmethod def _is_http_endpoint_valid(url: str, timeout: int = 5) -> bool: @@ -146,16 +145,14 @@ class DoclingParser(RAGFlowPdfParser): if bytes_io: bytes_io.close() - def _make_line_tag(self,bbox: _BBox) -> str: + def _make_line_tag(self, bbox: _BBox) -> str: if bbox is None: return "" - x0,x1, top, bott = bbox.x0, bbox.x1, bbox.y0, bbox.y1 + x0, x1, top, bott = bbox.x0, bbox.x1, bbox.y0, bbox.y1 if hasattr(self, "page_images") and self.page_images and len(self.page_images) >= bbox.page_no: - _, page_height = self.page_images[bbox.page_no-1].size - top, bott = page_height-top ,page_height-bott - return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##".format( - bbox.page_no, x0,x1, top, bott - ) + _, page_height = self.page_images[bbox.page_no - 1].size + top, bott = page_height - top, page_height - bott + return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##".format(bbox.page_no, x0, x1, top, bott) @staticmethod def extract_positions(txt: str) -> list[tuple[list[int], float, float, float, float]]: @@ -183,10 +180,10 @@ class DoclingParser(RAGFlowPdfParser): bottom = top + 4 img0 = self.page_images[pns[0]] x0, y0, x1, y1 = int(left), int(top), int(right), int(min(bottom, img0.size[1])) - + crop0 = img0.crop((x0, y0, x1, y1)) imgs.append(crop0) - if 0 < ii < len(poss)-1: + if 0 < ii < len(poss) - 1: positions.append((pns[0] + self.page_from, x0, x1, y0, y1)) remain_bottom = bottom - img0.size[1] for pn in pns[1:]: @@ -245,8 +242,8 @@ class DoclingParser(RAGFlowPdfParser): section = payload.strip() else: continue - - tag = self._make_line_tag(bbox) if isinstance(bbox,_BBox) else "" + + tag = self._make_line_tag(bbox) if isinstance(bbox, _BBox) else "" if parse_method in {"manual", "pipeline"}: sections.append((section, typ, tag)) elif parse_method == "paper": @@ -268,9 +265,9 @@ class DoclingParser(RAGFlowPdfParser): left, top, right, bott = bbox x0 = float(left) - y0 = float(H-top) + y0 = float(H - top) x1 = float(right) - y1 = float(H-bott) + y1 = float(H - bott) x0, y0 = max(0.0, min(x0, W - 1)), max(0.0, min(y0, H - 1)) x1, y1 = max(x0 + 1.0, min(x1, W)), max(y0 + 1.0, min(y1, H)) @@ -280,7 +277,7 @@ class DoclingParser(RAGFlowPdfParser): except Exception: return None, "" - pos = (page_no-1 if page_no>0 else 0, x0, x1, y0, y1) + pos = (page_no - 1 if page_no > 0 else 0, x0, x1, y0, y1) return crop, [pos] def _transfer_to_tables(self, doc): @@ -355,9 +352,9 @@ class DoclingParser(RAGFlowPdfParser): ): """ Parses a PDF document using a remote Docling server. - - Prioritizes native chunking endpoints (/v1/chunk/source, /v1alpha/chunk/source) - to prevent token overflow, with a graceful fallback to standard conversion + + Prioritizes native chunking endpoints (/v1/chunk/source, /v1alpha/chunk/source) + to prevent token overflow, with a graceful fallback to standard conversion endpoints if chunking is unavailable. """ server_url = self._effective_server_url(docling_server_url) @@ -382,7 +379,7 @@ class DoclingParser(RAGFlowPdfParser): filename = Path(filepath).name or "input.pdf" b64 = base64.b64encode(pdf_bytes).decode("ascii") - + # Standard payloads # Standard fallback payloads (no chunking) v1_payload_standard = { @@ -393,17 +390,17 @@ class DoclingParser(RAGFlowPdfParser): "options": {"from_formats": ["pdf"], "to_formats": ["json", "md", "text"]}, "file_sources": [{"filename": filename, "base64_string": b64}], } - + # --- NEW: Correct API Contract for Chunking --- chunking_opts = { - "from_formats": ["pdf"], + "from_formats": ["pdf"], "to_formats": ["json", "md", "text"], "do_chunking": True, "chunking_options": { "max_tokens": 512, "overlap": 50, - "tokenizer": "sentencepiece" # Required by Docling contract - } + "tokenizer": "sentencepiece", # Required by Docling contract + }, } v1_payload_chunked = { "options": chunking_opts, @@ -434,21 +431,21 @@ class DoclingParser(RAGFlowPdfParser): if resp.status_code < 300: response_json = resp.json() is_chunked_response = chunk_flag - + if chunk_flag: self.logger.info(f"[Docling] Successfully used native chunking on: {endpoint}") else: self.logger.info(f"[Docling] Chunking unavailable, fell back to standard: {endpoint}") break - - # If chunking request is rejected (e.g., 422 Unprocessable Entity on older servers), + + # If chunking request is rejected (e.g., 422 Unprocessable Entity on older servers), # log it and let the loop naturally fall back to the standard payload. if chunk_flag: self.logger.warning(f"[Docling] Server rejected chunking parameters: HTTP {resp.status_code}") continue errors.append(f"{endpoint}: HTTP {resp.status_code} {resp.text[:300]}") - + except Exception as exc: self.logger.error(f"[Docling] Request error on {endpoint}: {exc}") errors.append(f"{endpoint}: {exc}") @@ -458,7 +455,7 @@ class DoclingParser(RAGFlowPdfParser): sections: list[tuple[str, ...]] = [] tables = [] - + # --- NEW: Handle Native Chunked Response --- if is_chunked_response: # The chunking endpoint returns an array of chunk items @@ -470,11 +467,11 @@ class DoclingParser(RAGFlowPdfParser): chunk_text = chunk_data.get("text", "") if not chunk_text and isinstance(chunk_data.get("chunk"), dict): chunk_text = chunk_data["chunk"].get("text", "") - + if isinstance(chunk_text, str) and chunk_text.strip(): # Feed the pre-sliced chunks directly into RAGFlow's expected format sections.extend(self._sections_from_remote_text(chunk_text, parse_method=parse_method)) - + if callback: callback(0.95, f"[Docling] Native chunks received: {len(sections)}") if sections: @@ -511,9 +508,9 @@ class DoclingParser(RAGFlowPdfParser): binary: BytesIO | bytes | None = None, callback: Optional[Callable] = None, *, - output_dir: Optional[str] = None, - lang: Optional[str] = None, - method: str = "auto", + output_dir: Optional[str] = None, + lang: Optional[str] = None, + method: str = "auto", delete_output: bool = True, parse_method: str = "raw", docling_server_url: Optional[str] = None, @@ -559,7 +556,7 @@ class DoclingParser(RAGFlowPdfParser): except Exception as e: self.logger.warning(f"[Docling] render pages failed: {e}") - conv = DocumentConverter() + conv = DocumentConverter() conv_res = conv.convert(str(src_path)) doc = conv_res.document if callback: diff --git a/deepdoc/parser/docx_parser.py b/deepdoc/parser/docx_parser.py index 2d56729b74..48f8db8064 100644 --- a/deepdoc/parser/docx_parser.py +++ b/deepdoc/parser/docx_parser.py @@ -29,6 +29,7 @@ from docx.image.exceptions import ( ) from rag.utils.lazy_image import LazyImage + class RAGFlowDocxParser: def get_picture(self, document, paragraph): imgs = paragraph._element.xpath(".//pic:pic") @@ -69,7 +70,6 @@ class RAGFlowDocxParser: return None return LazyImage(image_blobs) - def __extract_table_content(self, tb): df = [] for row in tb.rows: @@ -91,7 +91,7 @@ class RAGFlowDocxParser: (r"^[0-9A-Z/\._~-]+$", "Ca"), (r"^[A-Z]*[a-z' -]+$", "En"), (r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"), - (r"^.{1}$", "Sg") + (r"^.{1}$", "Sg"), ] for p, n in pattern: if re.search(p, b): @@ -110,16 +110,14 @@ class RAGFlowDocxParser: if len(df) < 2: return [] - max_type = Counter([blockType(str(df.iloc[i, j])) for i in range( - 1, len(df)) for j in range(len(df.iloc[i, :]))]) + max_type = Counter([blockType(str(df.iloc[i, j])) for i in range(1, len(df)) for j in range(len(df.iloc[i, :]))]) max_type = max(max_type.items(), key=lambda x: x[1])[0] colnm = len(df.iloc[0, :]) hdrows = [0] # header is not necessarily appear in the first line if max_type == "Nu": for r in range(1, len(df)): - tys = Counter([blockType(str(df.iloc[r, j])) - for j in range(len(df.iloc[r, :]))]) + tys = Counter([blockType(str(df.iloc[r, j])) for j in range(len(df.iloc[r, :]))]) tys = max(tys.items(), key=lambda x: x[1])[0] if tys != max_type: hdrows.append(r) @@ -160,26 +158,25 @@ class RAGFlowDocxParser: return ["\n".join(lines)] def __call__(self, fnm, from_page=0, to_page=MAXIMUM_PAGE_NUMBER): - self.doc = Document(fnm) if isinstance( - fnm, str) else Document(BytesIO(fnm)) - pn = 0 # parsed page - secs = [] # parsed contents + self.doc = Document(fnm) if isinstance(fnm, str) else Document(BytesIO(fnm)) + pn = 0 # parsed page + secs = [] # parsed contents for p in self.doc.paragraphs: if pn > to_page: break - runs_within_single_paragraph = [] # save runs within the range of pages + runs_within_single_paragraph = [] # save runs within the range of pages for run in p.runs: if pn > to_page: break if from_page <= pn < to_page and p.text.strip(): - runs_within_single_paragraph.append(run.text) # append run.text first + runs_within_single_paragraph.append(run.text) # append run.text first # wrap page break checker into a static method - if 'lastRenderedPageBreak' in run._element.xml: + if "lastRenderedPageBreak" in run._element.xml: pn += 1 - secs.append(("".join(runs_within_single_paragraph), p.style.name if hasattr(p.style, 'name') else '')) # then concat run.text as part of the paragraph + secs.append(("".join(runs_within_single_paragraph), p.style.name if hasattr(p.style, "name") else "")) # then concat run.text as part of the paragraph tbls = [self.__extract_table_content(tb) for tb in self.doc.tables] return secs, tbls diff --git a/deepdoc/parser/epub_parser.py b/deepdoc/parser/epub_parser.py index 5badd7c33b..5b86bc98a4 100644 --- a/deepdoc/parser/epub_parser.py +++ b/deepdoc/parser/epub_parser.py @@ -63,9 +63,7 @@ class RAGFlowEpubParser: continue with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) - sections = html_parser( - item_path, binary=html_bytes, chunk_token_num=chunk_token_num - ) + sections = html_parser(item_path, binary=html_bytes, chunk_token_num=chunk_token_num) all_sections.extend(sections) return all_sections @@ -130,16 +128,9 @@ class RAGFlowEpubParser: continue spine_items.append(opf_dir + href) - return ( - spine_items if spine_items else RAGFlowEpubParser._fallback_xhtml_order(zf) - ) + return spine_items if spine_items else RAGFlowEpubParser._fallback_xhtml_order(zf) @staticmethod def _fallback_xhtml_order(zf): """Fallback: return all .xhtml/.html files sorted alphabetically.""" - return sorted( - n - for n in zf.namelist() - if n.lower().endswith((".xhtml", ".html", ".htm")) - and not n.startswith("META-INF/") - ) + return sorted(n for n in zf.namelist() if n.lower().endswith((".xhtml", ".html", ".htm")) and not n.startswith("META-INF/")) diff --git a/deepdoc/parser/excel_parser.py b/deepdoc/parser/excel_parser.py index 019559d0d7..bc9edf57c2 100644 --- a/deepdoc/parser/excel_parser.py +++ b/deepdoc/parser/excel_parser.py @@ -42,7 +42,7 @@ class RAGFlowExcelParser: try: file_like_object.seek(0) - df = pd.read_csv(file_like_object, on_bad_lines='skip') + df = pd.read_csv(file_like_object, on_bad_lines="skip") return RAGFlowExcelParser._dataframe_to_workbook(df) except Exception as e_csv: @@ -261,7 +261,7 @@ class RAGFlowExcelParser: except Exception as e: logging.warning(f"Parse spreadsheet error: {e}, trying to interpret as CSV file") file_like_object.seek(0) - df = pd.read_csv(file_like_object, on_bad_lines='skip') + df = pd.read_csv(file_like_object, on_bad_lines="skip") df = df.replace(r"^\s*$", "", regex=True) return df.to_markdown(index=False) diff --git a/deepdoc/parser/figure_parser.py b/deepdoc/parser/figure_parser.py index e062f46253..d37a5e7309 100644 --- a/deepdoc/parser/figure_parser.py +++ b/deepdoc/parser/figure_parser.py @@ -27,6 +27,7 @@ from rag.prompts.generator import vision_llm_figure_describe_prompt, vision_llm_ from rag.nlp import append_context2table_image4pdf from rag.utils.lazy_image import ensure_pil_image, open_image_for_processing, is_image_like + # need to delete before pr def vision_figure_parser_figure_data_wrapper(figures_data_without_positions): if not figures_data_without_positions: @@ -44,7 +45,8 @@ def vision_figure_parser_figure_data_wrapper(figures_data_without_positions): ) return res -def vision_figure_parser_docx_wrapper(sections, tbls, callback=None,**kwargs): + +def vision_figure_parser_docx_wrapper(sections, tbls, callback=None, **kwargs): if not sections: return tbls try: @@ -63,7 +65,8 @@ def vision_figure_parser_docx_wrapper(sections, tbls, callback=None,**kwargs): callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.") return tbls -def vision_figure_parser_figure_xlsx_wrapper(images,callback=None, **kwargs): + +def vision_figure_parser_figure_xlsx_wrapper(images, callback=None, **kwargs): tbls = [] if not images: return [] @@ -74,13 +77,18 @@ def vision_figure_parser_figure_xlsx_wrapper(images,callback=None, **kwargs): except Exception: vision_model = None if vision_model: - figures_data = [(( - img["image"], # Image.Image or LazyImage (converted by ensure_pil_image) - [img["image_description"]] # description list (must be list) - ), - [ - (0, 0, 0, 0, 0) # dummy position - ]) for img in images] + figures_data = [ + ( + ( + img["image"], # Image.Image or LazyImage (converted by ensure_pil_image) + [img["image_description"]], # description list (must be list) + ), + [ + (0, 0, 0, 0, 0) # dummy position + ], + ) + for img in images + ] try: parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs) callback(0.22, "Parsing images...") @@ -90,6 +98,7 @@ def vision_figure_parser_figure_xlsx_wrapper(images,callback=None, **kwargs): callback(0.25, f"Excel visual model error: {e}. Skipping vision enhancement.") return tbls + def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs): if not tbls: return [] @@ -142,6 +151,7 @@ def vision_figure_parser_docx_wrapper_naive(chunks, idx_lst, callback=None, **kw except Exception: vision_model = None if vision_model: + @timeout(30, 3) def worker(idx, ck): img, close_after = open_image_for_processing(ck.get("image"), allow_bytes=True) @@ -178,16 +188,15 @@ def vision_figure_parser_docx_wrapper_naive(chunks, idx_lst, callback=None, **kw pass with ThreadPoolExecutor(max_workers=10) as executor: - futures = [ - executor.submit(worker, idx, chunks[idx]) - for idx in idx_lst - ] + futures = [executor.submit(worker, idx, chunks[idx]) for idx in idx_lst] for future in as_completed(futures): idx, description = future.result() - chunks[idx]['text'] += description - -shared_executor = ThreadPoolExecutor(max_workers=10) + chunks[idx]["text"] += description + + +shared_executor = ThreadPoolExecutor(max_workers=10) + class VisionFigureParser: def __init__(self, vision_model, figures_data, *args, **kwargs): @@ -253,7 +262,9 @@ class VisionFigureParser: context_above=context_above, context_below=context_below, ) - logging.info(f"[VisionFigureParser] figure={figure_idx} context_size={self.context_size} context_above_len={len(context_above)} context_below_len={len(context_below)} prompt=with_context") + logging.info( + f"[VisionFigureParser] figure={figure_idx} context_size={self.context_size} context_above_len={len(context_above)} context_below_len={len(context_below)} prompt=with_context" + ) logging.info(f"[VisionFigureParser] figure={figure_idx} context_above_snippet={context_above[:512]}") logging.info(f"[VisionFigureParser] figure={figure_idx} context_below_snippet={context_below[:512]}") else: diff --git a/deepdoc/parser/html_parser.py b/deepdoc/parser/html_parser.py index c28fa99607..334d1e897e 100644 --- a/deepdoc/parser/html_parser.py +++ b/deepdoc/parser/html_parser.py @@ -23,18 +23,14 @@ import chardet from bs4 import BeautifulSoup, NavigableString, Tag, Comment import html -def get_encoding(file): - with open(file,'rb') as f: - tmp = chardet.detect(f.read()) - return tmp['encoding'] -BLOCK_TAGS = [ - "h1", "h2", "h3", "h4", "h5", "h6", - "p", "div", "article", "section", "aside", - "ul", "ol", "li", - "table", "pre", "code", "blockquote", - "figure", "figcaption" -] +def get_encoding(file): + with open(file, "rb") as f: + tmp = chardet.detect(f.read()) + return tmp["encoding"] + + +BLOCK_TAGS = ["h1", "h2", "h3", "h4", "h5", "h6", "p", "div", "article", "section", "aside", "ul", "ol", "li", "table", "pre", "code", "blockquote", "figure", "figcaption"] TITLE_TAGS = {"h1": "#", "h2": "##", "h3": "###", "h4": "####", "h5": "#####", "h6": "######"} @@ -44,7 +40,7 @@ class RAGFlowHtmlParser: encoding = find_codec(binary) txt = binary.decode(encoding, errors="ignore") else: - with open(fnm, "r",encoding=get_encoding(fnm)) as f: + with open(fnm, "r", encoding=get_encoding(fnm)) as f: txt = f.read() return self.parser_txt(txt, chunk_token_num) @@ -64,8 +60,8 @@ class RAGFlowHtmlParser: script_tag.decompose() # delete inline style for tag in soup.find_all(True): - if 'style' in tag.attrs: - del tag.attrs['style'] + if "style" in tag.attrs: + del tag.attrs["style"] # delete HTML comment for comment in soup.find_all(string=lambda text: isinstance(text, Comment)): comment.extract() @@ -133,21 +129,18 @@ class RAGFlowHtmlParser: return_info.append(info) return return_info elif isinstance(element, Tag): - if str.lower(element.name) == "table": table_info_list = [] table_id = str(uuid.uuid1()) table_list = [html.unescape(str(element))] for t in table_list: - table_info_list.append({"content": t, "tag_name": "table", - "metadata": {"table_id": table_id, "index": table_list.index(t)}}) + table_info_list.append({"content": t, "tag_name": "table", "metadata": {"table_id": table_id, "index": table_list.index(t)}}) return table_info_list else: if str.lower(element.name) in BLOCK_TAGS: block_id = str(uuid.uuid1()) for child in element.children: - child_info = cls.read_text_recursively(child, parser_result, chunk_token_num, element.name, - block_id) + child_info = cls.read_text_recursively(child, parser_result, chunk_token_num, element.name, block_id) parser_result.extend(child_info) return [] @@ -230,13 +223,12 @@ class RAGFlowHtmlParser: # A single atom longer than the budget (e.g. a very long # unbroken token): fall back to fixed character windows. logging.debug( - "html_parser: atom of %d chars exceeds chunk_token_num=%d; " - "falling back to character windows", + "html_parser: atom of %d chars exceeds chunk_token_num=%d; falling back to character windows", len(atom), chunk_token_num, ) for i in range(0, len(atom), chunk_token_num): - pieces.append(atom[i:i + chunk_token_num]) + pieces.append(atom[i : i + chunk_token_num]) continue current += atom current_tokens += atom_tokens diff --git a/deepdoc/parser/markdown_parser.py b/deepdoc/parser/markdown_parser.py index 5acf7111bc..583e4ffd2b 100644 --- a/deepdoc/parser/markdown_parser.py +++ b/deepdoc/parser/markdown_parser.py @@ -245,11 +245,7 @@ class MarkdownElementExtractor: return merged def _protected_ranges(self, text): - return self._merge_ranges( - self._fenced_code_ranges(text) - + self._markdown_table_ranges(text) - + self._html_table_ranges(text) - ) + return self._merge_ranges(self._fenced_code_ranges(text) + self._markdown_table_ranges(text) + self._html_table_ranges(text)) def _append_delimited_section(self, sections, text, start, end, include_meta): part = text[start:end] @@ -307,6 +303,7 @@ class MarkdownElementExtractor: if len(dels) > 0: text = "\n".join(self.lines) sections = self._extract_delimited_elements(text, dels, include_meta) + # Attach lone header lines to the section that follows them so that # "## Title\n" never becomes an isolated chunk when the delimiter # splits at every newline. A header is "lone" when it occupies a @@ -354,11 +351,13 @@ class MarkdownElementExtractor: if _is_attachable_body(body_content): combined = "\n".join(header_parts) + "\n" + body_content if include_meta: - merged.append({ - **sections[i], - "content": combined, - "end_line": sections[j]["end_line"], - }) + merged.append( + { + **sections[i], + "content": combined, + "end_line": sections[j]["end_line"], + } + ) else: merged.append(combined) merged_header_count += len(header_parts) diff --git a/deepdoc/parser/mineru_parser.py b/deepdoc/parser/mineru_parser.py index 4b854925f0..080942bc6d 100644 --- a/deepdoc/parser/mineru_parser.py +++ b/deepdoc/parser/mineru_parser.py @@ -60,28 +60,28 @@ class MinerUContentType(StrEnum): # Mapping from language names to MinerU language codes LANGUAGE_TO_MINERU_MAP = { - 'English': 'en', - 'Chinese': 'ch', - 'Traditional Chinese': 'chinese_cht', - 'Russian': 'east_slavic', - 'Ukrainian': 'east_slavic', - 'Indonesian': 'latin', - 'Spanish': 'latin', - 'Vietnamese': 'latin', - 'Japanese': 'japan', - 'Korean': 'korean', - 'Portuguese BR': 'latin', - 'German': 'latin', - 'French': 'latin', - 'Italian': 'latin', - 'Tamil': 'ta', - 'Telugu': 'te', - 'Kannada': 'ka', - 'Thai': 'th', - 'Greek': 'el', - 'Hindi': 'devanagari', - 'Bulgarian': 'cyrillic', - 'Turkish': 'latin', + "English": "en", + "Chinese": "ch", + "Traditional Chinese": "chinese_cht", + "Russian": "east_slavic", + "Ukrainian": "east_slavic", + "Indonesian": "latin", + "Spanish": "latin", + "Vietnamese": "latin", + "Japanese": "japan", + "Korean": "korean", + "Portuguese BR": "latin", + "German": "latin", + "French": "latin", + "Italian": "latin", + "Tamil": "ta", + "Telugu": "te", + "Kannada": "ka", + "Thai": "th", + "Greek": "el", + "Hindi": "devanagari", + "Bulgarian": "cyrillic", + "Turkish": "latin", } @@ -269,14 +269,10 @@ class MinerUParser(RAGFlowPdfParser): return True, reason - def _run_mineru( - self, input_path: Path, output_dir: Path, options: MinerUParseOptions, callback: Optional[Callable] = None - ) -> Path: + def _run_mineru(self, input_path: Path, output_dir: Path, options: MinerUParseOptions, callback: Optional[Callable] = None) -> Path: return self._run_mineru_api(input_path, output_dir, options, callback) - def _run_mineru_api( - self, input_path: Path, output_dir: Path, options: MinerUParseOptions, callback: Optional[Callable] = None - ) -> Path: + def _run_mineru_api(self, input_path: Path, output_dir: Path, options: MinerUParseOptions, callback: Optional[Callable] = None) -> Path: pdf_file_path = str(input_path) if not os.path.exists(pdf_file_path): @@ -352,8 +348,7 @@ class MinerUParser(RAGFlowPdfParser): try: with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf: self.pdf = pdf - self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in - enumerate(self.pdf.pages[page_from:page_to])] + self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in enumerate(self.pdf.pages[page_from:page_to])] except Exception as e: self.page_images = None self.total_page = 0 @@ -420,8 +415,7 @@ class MinerUParser(RAGFlowPdfParser): pos = poss[-1] last_page_idx = pos[0][-1] if not (0 <= last_page_idx < page_count): - self.logger.warning( - f"[MinerU] Last page index {last_page_idx} out of range for {page_count} pages; skipping crop.") + self.logger.warning(f"[MinerU] Last page index {last_page_idx} out of range for {page_count} pages; skipping crop.") if need_position: return None, None return @@ -447,12 +441,10 @@ class MinerUParser(RAGFlowPdfParser): if 0 <= pn - 1 < page_count: bottom += self.page_images[pn - 1].size[1] else: - self.logger.warning( - f"[MinerU] Page index {pn}-1 out of range for {page_count} pages during crop; skipping height accumulation.") + self.logger.warning(f"[MinerU] Page index {pn}-1 out of range for {page_count} pages during crop; skipping height accumulation.") if not (0 <= pns[0] < page_count): - self.logger.warning( - f"[MinerU] Base page index {pns[0]} out of range for {page_count} pages during crop; skipping this segment.") + self.logger.warning(f"[MinerU] Base page index {pns[0]} out of range for {page_count} pages during crop; skipping this segment.") continue img0 = self.page_images[pns[0]] @@ -471,8 +463,7 @@ class MinerUParser(RAGFlowPdfParser): bottom -= img0.size[1] for pn in pns[1:]: if not (0 <= pn < page_count): - self.logger.warning( - f"[MinerU] Page index {pn} out of range for {page_count} pages during crop; skipping this page.") + self.logger.warning(f"[MinerU] Page index {pn} out of range for {page_count} pages during crop; skipping this page.") continue page = self.page_images[pn] x0, y0, x1, y1 = int(left), 0, int(right), int(min(bottom, page.size[1])) @@ -523,8 +514,7 @@ class MinerUParser(RAGFlowPdfParser): poss.append(([int(p) - 1 for p in pn.split("-")], left, right, top, bottom)) return poss - def _read_output(self, output_dir: Path, file_stem: str, method: str = "auto", backend: str = "pipeline") -> list[ - dict[str, Any]]: + def _read_output(self, output_dir: Path, file_stem: str, method: str = "auto", backend: str = "pipeline") -> list[dict[str, Any]]: json_file = None subdir = None attempted = [] @@ -661,13 +651,11 @@ class MinerUParser(RAGFlowPdfParser): case MinerUContentType.TEXT: section = output.get("text", "") case MinerUContentType.TABLE: - section = output.get("table_body", "") + "\n".join(output.get("table_caption", [])) + "\n".join( - output.get("table_footnote", [])) + section = output.get("table_body", "") + "\n".join(output.get("table_caption", [])) + "\n".join(output.get("table_footnote", [])) if not section.strip(): section = "FAILED TO PARSE TABLE" case MinerUContentType.IMAGE: - section = "".join(output.get("image_caption", [])) + "\n" + "".join( - output.get("image_footnote", [])) + section = "".join(output.get("image_caption", [])) + "\n" + "".join(output.get("image_footnote", [])) # If a vision model enriched this image with a semantic # description (see _enhance_images_with_vlm), embed it in # the chunk so it becomes searchable / retrievable. @@ -680,12 +668,7 @@ class MinerUParser(RAGFlowPdfParser): section = output.get("code_body", "") + "\n".join(output.get("code_caption", [])) case MinerUContentType.LIST: section = "\n".join(output.get("list_items", [])) - case ( - MinerUContentType.HEADER - | MinerUContentType.FOOTER - | MinerUContentType.PAGE_NUMBER - | MinerUContentType.DISCARDED - ): + case MinerUContentType.HEADER | MinerUContentType.FOOTER | MinerUContentType.PAGE_NUMBER | MinerUContentType.DISCARDED: continue case _: self.logger.debug("[MinerU] Skip unsupported section type=%s", output.get("type")) @@ -719,13 +702,7 @@ class MinerUParser(RAGFlowPdfParser): from rag.app.picture import vision_llm_chunk from rag.prompts.generator import vision_llm_figure_describe_prompt - image_jobs = [ - (idx, item) - for idx, item in enumerate(outputs) - if item.get("type") == MinerUContentType.IMAGE - and item.get("img_path") - and os.path.exists(item["img_path"]) - ] + image_jobs = [(idx, item) for idx, item in enumerate(outputs) if item.get("type") == MinerUContentType.IMAGE and item.get("img_path") and os.path.exists(item["img_path"])] if not image_jobs: return @@ -752,17 +729,17 @@ class MinerUParser(RAGFlowPdfParser): outputs[idx]["vlm_description"] = desc def parse_pdf( - self, - filepath: str | PathLike[str], - binary: BytesIO | bytes, - callback: Optional[Callable] = None, - *, - output_dir: Optional[str] = None, - backend: str = "pipeline", - server_url: Optional[str] = None, - delete_output: bool = True, - parse_method: str = "raw", - **kwargs, + self, + filepath: str | PathLike[str], + binary: BytesIO | bytes, + callback: Optional[Callable] = None, + *, + output_dir: Optional[str] = None, + backend: str = "pipeline", + server_url: Optional[str] = None, + delete_output: bool = True, + parse_method: str = "raw", + **kwargs, ) -> tuple: import shutil @@ -770,12 +747,12 @@ class MinerUParser(RAGFlowPdfParser): temp_pdf = None created_tmp_dir = False - parser_cfg = kwargs.get('parser_config', {}) - lang = parser_cfg.get('mineru_lang') or kwargs.get('lang', 'English') - mineru_lang_code = LANGUAGE_TO_MINERU_MAP.get(lang, 'ch') # Defaults to Chinese if not matched - mineru_method_raw_str = parser_cfg.get('mineru_parse_method', 'auto') - enable_formula = parser_cfg.get('mineru_formula_enable', True) - enable_table = parser_cfg.get('mineru_table_enable', True) + parser_cfg = kwargs.get("parser_config", {}) + lang = parser_cfg.get("mineru_lang") or kwargs.get("lang", "English") + mineru_lang_code = LANGUAGE_TO_MINERU_MAP.get(lang, "ch") # Defaults to Chinese if not matched + mineru_method_raw_str = parser_cfg.get("mineru_parse_method", "auto") + enable_formula = parser_cfg.get("mineru_formula_enable", True) + enable_table = parser_cfg.get("mineru_table_enable", True) # remove spaces, or mineru crash, and _read_output fail too file_path = Path(filepath) diff --git a/deepdoc/parser/opendataloader_parser.py b/deepdoc/parser/opendataloader_parser.py index ed496d1c49..33d54463aa 100644 --- a/deepdoc/parser/opendataloader_parser.py +++ b/deepdoc/parser/opendataloader_parser.py @@ -1,4 +1,3 @@ - from __future__ import annotations import logging @@ -20,9 +19,11 @@ from common.constants import MAXIMUM_PAGE_NUMBER try: from deepdoc.parser.pdf_parser import RAGFlowPdfParser except Exception: + class RAGFlowPdfParser: pass + from deepdoc.parser.utils import extract_pdf_outlines @@ -137,19 +138,14 @@ class OpenDataLoaderParser(RAGFlowPdfParser): def check_installation(self) -> bool: """Return True when the OpenDataLoader service is reachable.""" if not self.api_url: - self.logger.warning( - "[OpenDataLoader] OPENDATALOADER_APISERVER is not set. " - "Start the opendataloader service and set the env var." - ) + self.logger.warning("[OpenDataLoader] OPENDATALOADER_APISERVER is not set. Start the opendataloader service and set the env var.") return False try: headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {} resp = requests.get(f"{self.api_url}/health", timeout=5, headers=headers) if resp.status_code == 200: return True - self.logger.warning( - f"[OpenDataLoader] Health check returned {resp.status_code}: {resp.text[:200]}" - ) + self.logger.warning(f"[OpenDataLoader] Health check returned {resp.status_code}: {resp.text[:200]}") return False except Exception as exc: self.logger.warning(f"[OpenDataLoader] Health check failed: {exc}") @@ -185,9 +181,7 @@ class OpenDataLoaderParser(RAGFlowPdfParser): _, page_height = self.page_images[bbox.page_no - 1].size top = page_height - bbox.y1 bott = page_height - bbox.y0 - return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##".format( - bbox.page_no, x0, x1, top, bott - ) + return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##".format(bbox.page_no, x0, x1, top, bott) @staticmethod def extract_positions(txt: str) -> list[tuple[list[int], float, float, float, float]]: @@ -345,10 +339,7 @@ class OpenDataLoaderParser(RAGFlowPdfParser): self.outlines = extract_pdf_outlines(binary if binary is not None else filepath) if not self.api_url: - raise RuntimeError( - "[OpenDataLoader] OPENDATALOADER_APISERVER is not configured. " - "Please start the opendataloader service and set the env var." - ) + raise RuntimeError("[OpenDataLoader] OPENDATALOADER_APISERVER is not configured. Please start the opendataloader service and set the env var.") # Render page images locally — used by _make_line_tag() and crop(). # The image rendering stays on the RAGFlow host; only the Java conversion diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index 028edc8df3..1dbb11b753 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -46,7 +46,6 @@ from deepdoc.parser.utils import extract_pdf_outlines from common import settings - from common.misc_utils import thread_pool_exec LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber" @@ -216,7 +215,7 @@ class RAGFlowPdfParser: return True if cp == 0xFFFD: return True - if cp < 0x20 and ch not in ('\t', '\n', '\r'): + if cp < 0x20 and ch not in ("\t", "\n", "\r"): return True if 0x80 <= cp <= 0x9F: return True @@ -292,13 +291,9 @@ class RAGFlowPdfParser: subset_font_count += 1 cp = ord(text[0]) - if (0x2E80 <= cp <= 0x9FFF or 0xF900 <= cp <= 0xFAFF - or 0x20000 <= cp <= 0x2FA1F - or 0xAC00 <= cp <= 0xD7AF - or 0x3040 <= cp <= 0x30FF): + if 0x2E80 <= cp <= 0x9FFF or 0xF900 <= cp <= 0xFAFF or 0x20000 <= cp <= 0x2FA1F or 0xAC00 <= cp <= 0xD7AF or 0x3040 <= cp <= 0x30FF: cjk_like += 1 - elif (0x21 <= cp <= 0x2F or 0x3A <= cp <= 0x40 - or 0x5B <= cp <= 0x60 or 0x7B <= cp <= 0x7E): + elif 0x21 <= cp <= 0x2F or 0x3A <= cp <= 0x40 or 0x5B <= cp <= 0x60 or 0x7B <= cp <= 0x7E: ascii_punct_sym += 1 if total_non_space < min_chars: @@ -758,7 +753,11 @@ class RAGFlowPdfParser: if total_count > 0 and garbled_count / total_count >= 0.5: logging.info( "Page %d: detected garbled pdfplumber text (garbled=%d/%d), falling back to OCR for box at (%.1f, %.1f)", - pagenum, garbled_count, total_count, b["x0"], b["top"], + pagenum, + garbled_count, + total_count, + b["x0"], + b["top"], ) b["text"] = "" continue @@ -767,7 +766,10 @@ class RAGFlowPdfParser: if total_count > 0 and self._is_garbled_by_font_encoding(box_chars, min_chars=5): logging.info( "Page %d: detected font-encoding garbled text (%d chars), falling back to OCR for box at (%.1f, %.1f)", - pagenum, total_count, b["x0"], b["top"], + pagenum, + total_count, + b["x0"], + b["top"], ) b["text"] = "" @@ -1558,19 +1560,18 @@ class RAGFlowPdfParser: sample_text = "".join(c.get("text", "") for c in sample) if self._is_garbled_text(sample_text, threshold=0.3): logging.warning( - "Page %d: pdfplumber extracted mostly garbled characters (%d chars), " - "clearing to use OCR fallback.", - page_from + pi + 1, len(page_ch), + "Page %d: pdfplumber extracted mostly garbled characters (%d chars), clearing to use OCR fallback.", + page_from + pi + 1, + len(page_ch), ) self.page_chars[pi] = [] continue # Strategy 2: font-encoding garbling (CJK mapped to ASCII) if self._is_garbled_by_font_encoding(page_ch): logging.warning( - "Page %d: detected font-encoding garbled text " - "(subset fonts with no CJK output, %d chars), " - "clearing to use OCR fallback.", - page_from + pi + 1, len(page_ch), + "Page %d: detected font-encoding garbled text (subset fonts with no CJK output, %d chars), clearing to use OCR fallback.", + page_from + pi + 1, + len(page_ch), ) self.page_chars[pi] = [] @@ -1856,12 +1857,7 @@ class RAGFlowPdfParser: if isinstance(box.get("position_tag"), str): box["position_tag"] = self._offset_position_tag(box["position_tag"], self.page_from) if isinstance(box.get("positions"), list): - box["positions"] = [ - [int(pos[0]) + self.page_from, *pos[1:]] - if isinstance(pos, list) and len(pos) > 0 and isinstance(pos[0], (int, float)) - else pos - for pos in box["positions"] - ] + box["positions"] = [[int(pos[0]) + self.page_from, *pos[1:]] if isinstance(pos, list) and len(pos) > 0 and isinstance(pos[0], (int, float)) else pos for pos in box["positions"]] return boxes @staticmethod diff --git a/deepdoc/parser/ppt_parser.py b/deepdoc/parser/ppt_parser.py index afff23d7de..9d1537f7c3 100644 --- a/deepdoc/parser/ppt_parser.py +++ b/deepdoc/parser/ppt_parser.py @@ -27,23 +27,20 @@ class RAGFlowPptParser: def __sort_shapes(self, shapes): cache_key = id(shapes) if cache_key not in self._shape_cache: - self._shape_cache[cache_key] = sorted( - shapes, - key=lambda x: ((x.top if x.top is not None else 0) // 10, x.left if x.left is not None else 0) - ) + self._shape_cache[cache_key] = sorted(shapes, key=lambda x: ((x.top if x.top is not None else 0) // 10, x.left if x.left is not None else 0)) return self._shape_cache[cache_key] def __get_bulleted_text(self, paragraph): is_bulleted = bool(paragraph._p.xpath("./a:pPr/a:buChar")) or bool(paragraph._p.xpath("./a:pPr/a:buAutoNum")) or bool(paragraph._p.xpath("./a:pPr/a:buBlip")) if is_bulleted: - return f"{' '* paragraph.level}.{paragraph.text}" + return f"{' ' * paragraph.level}.{paragraph.text}" else: return paragraph.text def __extract(self, shape): try: # First try to get text content - if hasattr(shape, 'has_text_frame') and shape.has_text_frame: + if hasattr(shape, "has_text_frame") and shape.has_text_frame: text_frame = shape.text_frame texts = [] for paragraph in text_frame.paragraphs: @@ -56,7 +53,7 @@ class RAGFlowPptParser: shape_type = shape.shape_type except NotImplementedError: # If shape_type is not available, try to get text content - if hasattr(shape, 'text'): + if hasattr(shape, "text"): return shape.text.strip() return "" @@ -65,8 +62,7 @@ class RAGFlowPptParser: tb = shape.table rows = [] for i in range(1, len(tb.rows)): - rows.append("; ".join([tb.cell( - 0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)])) + rows.append("; ".join([tb.cell(0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)])) return "\n".join(rows) # Handle group shape @@ -85,9 +81,7 @@ class RAGFlowPptParser: return "" def __call__(self, fnm, from_page, to_page, callback=None): - ppt = Presentation(fnm) if isinstance( - fnm, str) else Presentation( - BytesIO(fnm)) + ppt = Presentation(fnm) if isinstance(fnm, str) else Presentation(BytesIO(fnm)) txts = [] self.total_page = len(ppt.slides) for i, slide in enumerate(ppt.slides): diff --git a/deepdoc/parser/resume/__init__.py b/deepdoc/parser/resume/__init__.py index 5006d7fbcc..8f35210cc7 100644 --- a/deepdoc/parser/resume/__init__.py +++ b/deepdoc/parser/resume/__init__.py @@ -77,11 +77,7 @@ def refactor(cv): if work: cv["basic"]["work_start_time"] = work[0].get("start_time", "") - cv["basic"]["management_experience"] = ( - "Y" - if any([w.get("management_experience", "") == "Y" for w in work]) - else "N" - ) + cv["basic"]["management_experience"] = "Y" if any([w.get("management_experience", "") == "Y" for w in work]) else "N" cv["basic"]["annual_salary"] = work[-1].get("annual_salary_from", "0") for n in [ diff --git a/deepdoc/parser/resume/entities/__init__.py b/deepdoc/parser/resume/entities/__init__.py index e156bc93dd..177b91dd05 100644 --- a/deepdoc/parser/resume/entities/__init__.py +++ b/deepdoc/parser/resume/entities/__init__.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# \ No newline at end of file +# diff --git a/deepdoc/parser/resume/entities/corporations.py b/deepdoc/parser/resume/entities/corporations.py index 5035967303..0704ece1cf 100644 --- a/deepdoc/parser/resume/entities/corporations.py +++ b/deepdoc/parser/resume/entities/corporations.py @@ -24,9 +24,7 @@ from . import regions current_file_path = os.path.dirname(os.path.abspath(__file__)) -GOODS = pd.read_csv( - os.path.join(current_file_path, "res/corp_baike_len.csv"), sep="\t", header=0 -).fillna(0) +GOODS = pd.read_csv(os.path.join(current_file_path, "res/corp_baike_len.csv"), sep="\t", header=0).fillna(0) GOODS["cid"] = GOODS["cid"].astype(str) GOODS = GOODS.set_index(["cid"]) with open(os.path.join(current_file_path, "res/corp.tks.freq.json"), "r", encoding="utf-8") as f: @@ -53,9 +51,7 @@ def corpNorm(nm, add_region=True): nm = rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(nm)).lower() nm = re.sub(r"&", "&", nm) nm = re.sub(r"[\(\)()\+'\"\t \*\\【】-]+", " ", nm) - nm = re.sub( - r"([—-]+.*| +co\..*|corp\..*| +inc\..*| +ltd.*)", "", nm, count=10000, flags=re.IGNORECASE - ) + nm = re.sub(r"([—-]+.*| +co\..*|corp\..*| +inc\..*| +ltd.*)", "", nm, count=10000, flags=re.IGNORECASE) nm = re.sub( r"(计算机|技术|(技术|科技|网络)*有限公司|公司|有限|研发中心|中国|总部)$", "", diff --git a/deepdoc/parser/resume/entities/regions.py b/deepdoc/parser/resume/entities/regions.py index ac37f6ea1d..f0a1dc430a 100644 --- a/deepdoc/parser/resume/entities/regions.py +++ b/deepdoc/parser/resume/entities/regions.py @@ -778,7 +778,6 @@ def get_names(id): return nms - def isName(nm): if nm in NM_SET: return True diff --git a/deepdoc/parser/resume/entities/schools.py b/deepdoc/parser/resume/entities/schools.py index 5763ca48be..a6856f623d 100644 --- a/deepdoc/parser/resume/entities/schools.py +++ b/deepdoc/parser/resume/entities/schools.py @@ -21,9 +21,7 @@ import copy import pandas as pd current_file_path = os.path.dirname(os.path.abspath(__file__)) -TBL = pd.read_csv( - os.path.join(current_file_path, "res/schools.csv"), sep="\t", header=0 -).fillna("") +TBL = pd.read_csv(os.path.join(current_file_path, "res/schools.csv"), sep="\t", header=0).fillna("") TBL["name_en"] = TBL["name_en"].map(lambda x: x.lower().strip()) with open(os.path.join(current_file_path, "res/good_sch.json"), "r", encoding="utf-8") as f: GOOD_SCH = json.load(f) @@ -53,12 +51,7 @@ loadRank(os.path.join(current_file_path, "res/school.rank.csv")) def split(txt): tks = [] for t in re.sub(r"[ \t]+", " ", txt).split(): - if ( - tks - and re.match(r".*[a-zA-Z]$", tks[-1]) - and re.match(r"[a-zA-Z]", t) - and tks - ): + if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and re.match(r"[a-zA-Z]", t) and tks: tks[-1] = tks[-1] + " " + t else: tks.append(t) diff --git a/deepdoc/parser/resume/step_one.py b/deepdoc/parser/resume/step_one.py index 4c65a97fe8..edce6a9028 100644 --- a/deepdoc/parser/resume/step_one.py +++ b/deepdoc/parser/resume/step_one.py @@ -18,59 +18,60 @@ import json from deepdoc.parser.resume.entities import degrees, regions, industries FIELDS = [ -"address STRING", -"annual_salary int", -"annual_salary_from int", -"annual_salary_to int", -"birth STRING", -"card STRING", -"certificate_obj string", -"city STRING", -"corporation_id int", -"corporation_name STRING", -"corporation_type STRING", -"degree STRING", -"discipline_name STRING", -"education_obj string", -"email STRING", -"expect_annual_salary int", -"expect_city_names string", -"expect_industry_name STRING", -"expect_position_name STRING", -"expect_salary_from int", -"expect_salary_to int", -"expect_type STRING", -"gender STRING", -"industry_name STRING", -"industry_names STRING", -"is_deleted STRING", -"is_fertility STRING", -"is_house STRING", -"is_management_experience STRING", -"is_marital STRING", -"is_oversea STRING", -"language_obj string", -"name STRING", -"nation STRING", -"phone STRING", -"political_status STRING", -"position_name STRING", -"project_obj string", -"responsibilities string", -"salary_month int", -"scale STRING", -"school_name STRING", -"self_remark string", -"skill_obj string", -"title_name STRING", -"tob_resume_id STRING", -"updated_at Timestamp", -"wechat STRING", -"work_obj string", -"work_experience int", -"work_start_time BIGINT" + "address STRING", + "annual_salary int", + "annual_salary_from int", + "annual_salary_to int", + "birth STRING", + "card STRING", + "certificate_obj string", + "city STRING", + "corporation_id int", + "corporation_name STRING", + "corporation_type STRING", + "degree STRING", + "discipline_name STRING", + "education_obj string", + "email STRING", + "expect_annual_salary int", + "expect_city_names string", + "expect_industry_name STRING", + "expect_position_name STRING", + "expect_salary_from int", + "expect_salary_to int", + "expect_type STRING", + "gender STRING", + "industry_name STRING", + "industry_names STRING", + "is_deleted STRING", + "is_fertility STRING", + "is_house STRING", + "is_management_experience STRING", + "is_marital STRING", + "is_oversea STRING", + "language_obj string", + "name STRING", + "nation STRING", + "phone STRING", + "political_status STRING", + "position_name STRING", + "project_obj string", + "responsibilities string", + "salary_month int", + "scale STRING", + "school_name STRING", + "self_remark string", + "skill_obj string", + "title_name STRING", + "tob_resume_id STRING", + "updated_at Timestamp", + "wechat STRING", + "work_obj string", + "work_experience int", + "work_start_time BIGINT", ] + def refactor(df): def deal_obj(obj, k, kk): if not isinstance(obj, type({})): @@ -100,44 +101,59 @@ def refactor(df): df[c] = df["obj"].map(lambda x: deal_obj(x, cc, c)) else: df[c] = df["obj"].map( - lambda x: json.dumps( - x.get( - c, - {}), - ensure_ascii=False) if isinstance( - x, - type( - {})) and ( - isinstance( - x.get(c), - type( - {})) or not x.get(c)) else str(x).replace( - "None", - "")) + lambda x: json.dumps(x.get(c, {}), ensure_ascii=False) if isinstance(x, type({})) and (isinstance(x.get(c), type({})) or not x.get(c)) else str(x).replace("None", "") + ) - extract(["education", "work", "certificate", "project", "language", - "skill"]) - extract(["wechat", "phone", "is_deleted", - "name", "tel", "email"], "contact") - extract(["nation", "expect_industry_name", "salary_month", - "industry_ids", "is_house", "birth", "annual_salary_from", - "annual_salary_to", "card", - "expect_salary_to", "expect_salary_from", - "expect_position_name", "gender", "city", - "is_fertility", "expect_city_names", - "political_status", "title_name", "expect_annual_salary", - "industry_name", "address", "position_name", "school_name", - "corporation_id", - "is_oversea", "responsibilities", - "work_start_time", "degree", "management_experience", - "expect_type", "corporation_type", "scale", "corporation_name", - "self_remark", "annual_salary", "work_experience", - "discipline_name", "marital", "updated_at"], "basic") + extract(["education", "work", "certificate", "project", "language", "skill"]) + extract(["wechat", "phone", "is_deleted", "name", "tel", "email"], "contact") + extract( + [ + "nation", + "expect_industry_name", + "salary_month", + "industry_ids", + "is_house", + "birth", + "annual_salary_from", + "annual_salary_to", + "card", + "expect_salary_to", + "expect_salary_from", + "expect_position_name", + "gender", + "city", + "is_fertility", + "expect_city_names", + "political_status", + "title_name", + "expect_annual_salary", + "industry_name", + "address", + "position_name", + "school_name", + "corporation_id", + "is_oversea", + "responsibilities", + "work_start_time", + "degree", + "management_experience", + "expect_type", + "corporation_type", + "scale", + "corporation_name", + "self_remark", + "annual_salary", + "work_experience", + "discipline_name", + "marital", + "updated_at", + ], + "basic", + ) df["degree"] = df["degree"].map(lambda x: degrees.get_name(x)) df["address"] = df["address"].map(lambda x: " ".join(regions.get_names(x))) - df["industry_names"] = df["industry_ids"].map(lambda x: " ".join([" ".join(industries.get_names(i)) for i in - str(x).split(",")])) + df["industry_names"] = df["industry_ids"].map(lambda x: " ".join([" ".join(industries.get_names(i)) for i in str(x).split(",")])) clms.append("industry_names") def arr2str(a): @@ -147,16 +163,10 @@ def refactor(df): a = " ".join([str(i) for i in a]) return str(a).replace(",", " ") - df["expect_industry_name"] = df["expect_industry_name"].map( - lambda x: arr2str(x)) - df["gender"] = df["gender"].map( - lambda x: "男" if x == 'M' else ( - "女" if x == 'F' else "")) - for c in ["is_fertility", "is_oversea", "is_house", - "management_experience", "marital"]: - df[c] = df[c].map( - lambda x: '是' if x == 'Y' else ( - '否' if x == 'N' else "")) + df["expect_industry_name"] = df["expect_industry_name"].map(lambda x: arr2str(x)) + df["gender"] = df["gender"].map(lambda x: "男" if x == "M" else ("女" if x == "F" else "")) + for c in ["is_fertility", "is_oversea", "is_house", "management_experience", "marital"]: + df[c] = df[c].map(lambda x: "是" if x == "Y" else ("否" if x == "N" else "")) df["is_management_experience"] = df["management_experience"] df["is_marital"] = df["marital"] clms.extend(["is_management_experience", "is_marital"]) @@ -175,15 +185,8 @@ def refactor(df): clms = list(set(clms)) df = df.reindex(sorted(clms), axis=1) - #print(json.dumps(list(df.columns.values)), "LLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLL") + # print(json.dumps(list(df.columns.values)), "LLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLL") for c in clms: - df[c] = df[c].map( - lambda s: str(s).replace( - "\t", - " ").replace( - "\n", - "\\n").replace( - "\r", - "\\n")) + df[c] = df[c].map(lambda s: str(s).replace("\t", " ").replace("\n", "\\n").replace("\r", "\\n")) # print(df.values.tolist()) return dict(zip([n.split()[0] for n in FIELDS], df.values.tolist()[0])) diff --git a/deepdoc/parser/resume/step_two.py b/deepdoc/parser/resume/step_two.py index f23b6ad204..f32d308dea 100644 --- a/deepdoc/parser/resume/step_two.py +++ b/deepdoc/parser/resume/step_two.py @@ -130,7 +130,7 @@ def forEdu(cv): d = degrees.get_name(n["degree"]) if d: e["degree_kwd"] = d - if d == "本科" and ("专科" in deg or "专升本" in deg or "中专" in deg or "大专" in deg or re.search(r"(成人|自考|自学考试)", n.get("school_name",""))): + if d == "本科" and ("专科" in deg or "专升本" in deg or "中专" in deg or "大专" in deg or re.search(r"(成人|自考|自学考试)", n.get("school_name", ""))): d = "专升本" if d: deg.append(d) @@ -179,7 +179,7 @@ def forEdu(cv): if deg: if "本科" in deg and "专科" in deg: deg.append("专升本") - deg = [d for d in deg if d != '本科'] + deg = [d for d in deg if d != "本科"] cv["degree_kwd"] = deg cv["highest_degree_kwd"] = highest_degree(deg) if edu_end_dt: @@ -194,9 +194,7 @@ def forEdu(cv): logging.exception("forEdu {} {} {}".format(e, edu_end_dt, cv.get("work_exp_flt"))) if sch: cv["school_name_kwd"] = sch - if (len(cv.get("degree_kwd", [])) >= 1 and "本科" in cv["degree_kwd"]) \ - or all([c.lower() in ["硕士", "博士", "mba", "博士后"] for c in cv.get("degree_kwd", [])]) \ - or not cv.get("degree_kwd"): + if (len(cv.get("degree_kwd", [])) >= 1 and "本科" in cv["degree_kwd"]) or all([c.lower() in ["硕士", "博士", "mba", "博士后"] for c in cv.get("degree_kwd", [])]) or not cv.get("degree_kwd"): for c in sch: if schools.is_good(c): if "tag_kwd" not in cv: @@ -204,10 +202,11 @@ def forEdu(cv): cv["tag_kwd"].append("好学校") cv["tag_kwd"].append("好学历") break - if (len(cv.get("degree_kwd", [])) >= 1 and "本科" in cv["degree_kwd"] and - any([d.lower() in ["硕士", "博士", "mba", "博士"] for d in cv.get("degree_kwd", [])])) \ - or all([d.lower() in ["硕士", "博士", "mba", "博士后"] for d in cv.get("degree_kwd", [])]) \ - or any([d in ["mba", "emba", "博士后"] for d in cv.get("degree_kwd", [])]): + if ( + (len(cv.get("degree_kwd", [])) >= 1 and "本科" in cv["degree_kwd"] and any([d.lower() in ["硕士", "博士", "mba", "博士"] for d in cv.get("degree_kwd", [])])) + or all([d.lower() in ["硕士", "博士", "mba", "博士后"] for d in cv.get("degree_kwd", [])]) + or any([d in ["mba", "emba", "博士后"] for d in cv.get("degree_kwd", [])]) + ): if "tag_kwd" not in cv: cv["tag_kwd"] = [] if "好学历" not in cv["tag_kwd"]: @@ -230,9 +229,7 @@ def forProj(cv): return cv pro_nms, desc = [], [] - for i, n in enumerate( - sorted(cv.get("project_obj", []), key=lambda x: str(x.get("updated_at", "")) if isinstance(x, dict) else "", - reverse=True)): + for i, n in enumerate(sorted(cv.get("project_obj", []), key=lambda x: str(x.get("updated_at", "")) if isinstance(x, dict) else "", reverse=True)): if n.get("name"): pro_nms.append(n["name"]) if n.get("describe"): @@ -261,8 +258,7 @@ def forWork(cv): cv["integerity_flt"] *= 0.7 return cv - flds = ["position_name", "corporation_name", "corporation_id", "responsibilities", - "industry_name", "subordinates_count"] + flds = ["position_name", "corporation_name", "corporation_id", "responsibilities", "industry_name", "subordinates_count"] duas = [] scales = [] fea = {c: [] for c in flds} @@ -271,9 +267,7 @@ def forWork(cv): goodcorp_ = False work_st_tm = "" corp_tags = [] - for i, n in enumerate( - sorted(cv.get("work_obj", []), key=lambda x: str(x.get("start_time", "")) if isinstance(x, dict) else "", - reverse=True)): + for i, n in enumerate(sorted(cv.get("work_obj", []), key=lambda x: str(x.get("start_time", "")) if isinstance(x, dict) else "", reverse=True)): if isinstance(n, str): try: n = json_loads(n) @@ -283,7 +277,7 @@ def forWork(cv): if n.get("start_time") and (not work_st_tm or n["start_time"] < work_st_tm): work_st_tm = n["start_time"] for c in flds: - if not n.get(c) or str(n[c]) == '0': + if not n.get(c) or str(n[c]) == "0": fea[c].append("") continue if c == "corporation_name": @@ -368,8 +362,7 @@ def forWork(cv): cv["resp_ltks"] = rag_tokenizer.tokenize(" ".join(fea["responsibilities"][1:])) if fea["subordinates_count"]: - fea["subordinates_count"] = [int(i) for i in fea["subordinates_count"] if - re.match(r"[^0-9]+$", str(i))] + fea["subordinates_count"] = [int(i) for i in fea["subordinates_count"] if re.match(r"[^0-9]+$", str(i))] if fea["subordinates_count"]: cv["max_sub_cnt_int"] = np.max(fea["subordinates_count"]) @@ -445,25 +438,55 @@ def birth(cv): def parse(cv): for k in cv.keys(): - if cv[k] == '\\N': - cv[k] = '' + if cv[k] == "\\N": + cv[k] = "" # cv = cv.asDict() - tks_fld = ["address", "corporation_name", "discipline_name", "email", "expect_city_names", - "expect_industry_name", "expect_position_name", "industry_name", "industry_names", "name", - "position_name", "school_name", "self_remark", "title_name"] + tks_fld = [ + "address", + "corporation_name", + "discipline_name", + "email", + "expect_city_names", + "expect_industry_name", + "expect_position_name", + "industry_name", + "industry_names", + "name", + "position_name", + "school_name", + "self_remark", + "title_name", + ] small_tks_fld = ["corporation_name", "expect_position_name", "position_name", "school_name", "title_name"] - kwd_fld = ["address", "city", "corporation_type", "degree", "discipline_name", "expect_city_names", "email", - "expect_industry_name", "expect_position_name", "expect_type", "gender", "industry_name", - "industry_names", "political_status", "position_name", "scale", "school_name", "phone", "tel"] - num_fld = ["annual_salary", "annual_salary_from", "annual_salary_to", "expect_annual_salary", "expect_salary_from", - "expect_salary_to", "salary_month"] + kwd_fld = [ + "address", + "city", + "corporation_type", + "degree", + "discipline_name", + "expect_city_names", + "email", + "expect_industry_name", + "expect_position_name", + "expect_type", + "gender", + "industry_name", + "industry_names", + "political_status", + "position_name", + "scale", + "school_name", + "phone", + "tel", + ] + num_fld = ["annual_salary", "annual_salary_from", "annual_salary_to", "expect_annual_salary", "expect_salary_from", "expect_salary_to", "salary_month"] is_fld = [ ("is_fertility", "已育", "未育"), ("is_house", "有房", "没房"), ("is_management_experience", "有管理经验", "无管理经验"), ("is_marital", "已婚", "未婚"), - ("is_oversea", "有海外经验", "无海外经验") + ("is_oversea", "有海外经验", "无海外经验"), ] rmkeys = [] @@ -475,15 +498,15 @@ def parse(cv): for k in rmkeys: del cv[k] - integrity = 0. - flds_num = 0. + integrity = 0.0 + flds_num = 0.0 def hasValues(flds): nonlocal integrity, flds_num flds_num += len(flds) for f in flds: v = str(cv.get(f, "")) - if len(v) > 0 and v != '0' and v != '[]': + if len(v) > 0 and v != "0" and v != "[]": integrity += 1 hasValues(tks_fld) @@ -493,24 +516,23 @@ def parse(cv): cv["integerity_flt"] = integrity / flds_num if cv.get("corporation_type"): - for p, r in [(r"(公司|企业|其它|其他|Others*|\n|未填写|Enterprises|Company|companies)", ""), - (r"[//.· <\((]+.*", ""), - (r".*(合资|民企|股份制|中外|私营|个体|Private|创业|Owned|投资).*", "民营"), - (r".*(机关|事业).*", "机关"), - (r".*(非盈利|Non-profit).*", "非盈利"), - (r".*(外企|外商|欧美|foreign|Institution|Australia|港资).*", "外企"), - (r".*国有.*", "国企"), - (r"[ ()\(\)人/·0-9-]+", ""), - (r".*(元|规模|于|=|北京|上海|至今|中国|工资|州|shanghai|强|餐饮|融资|职).*", "")]: + for p, r in [ + (r"(公司|企业|其它|其他|Others*|\n|未填写|Enterprises|Company|companies)", ""), + (r"[//.· <\((]+.*", ""), + (r".*(合资|民企|股份制|中外|私营|个体|Private|创业|Owned|投资).*", "民营"), + (r".*(机关|事业).*", "机关"), + (r".*(非盈利|Non-profit).*", "非盈利"), + (r".*(外企|外商|欧美|foreign|Institution|Australia|港资).*", "外企"), + (r".*国有.*", "国企"), + (r"[ ()\(\)人/·0-9-]+", ""), + (r".*(元|规模|于|=|北京|上海|至今|中国|工资|州|shanghai|强|餐饮|融资|职).*", ""), + ]: cv["corporation_type"] = re.sub(p, r, cv["corporation_type"], count=1000, flags=re.IGNORECASE) if len(cv["corporation_type"]) < 2: del cv["corporation_type"] if cv.get("political_status"): - for p, r in [ - (r".*党员.*", "党员"), - (r".*(无党派|公民).*", "群众"), - (r".*团员.*", "团员")]: + for p, r in [(r".*党员.*", "党员"), (r".*(无党派|公民).*", "群众"), (r".*团员.*", "团员")]: cv["political_status"] = re.sub(p, r, cv["political_status"]) if not re.search(r"[党团群]", cv["political_status"]): del cv["political_status"] @@ -549,10 +571,7 @@ def parse(cv): # keyword fields if k in kwd_fld: - cv[f"{k}_kwd"] = [n.lower() - for n in re.split(r"[\t,,;;. ]", - re.sub(r"([^a-zA-Z])[ ]+([^a-zA-Z ])", r"\1,\2", cv[k]) - ) if n] + cv[f"{k}_kwd"] = [n.lower() for n in re.split(r"[\t,,;;. ]", re.sub(r"([^a-zA-Z])[ ]+([^a-zA-Z ])", r"\1,\2", cv[k])) if n] if k in num_fld and cv.get(k): cv[f"{k}_int"] = cv[k] @@ -575,24 +594,22 @@ def parse(cv): name = cv["name"] # name pingyin and its prefix - cv["name_py_tks"] = " ".join(PY.get_pinyins(nm[:20], '')) + " " + " ".join(PY.get_pinyins(nm[:20], ' ')) + cv["name_py_tks"] = " ".join(PY.get_pinyins(nm[:20], "")) + " " + " ".join(PY.get_pinyins(nm[:20], " ")) cv["name_py_pref0_tks"] = "" cv["name_py_pref_tks"] = "" - for py in PY.get_pinyins(nm[:20], ''): + for py in PY.get_pinyins(nm[:20], ""): for i in range(2, len(py) + 1): cv["name_py_pref_tks"] += " " + py[:i] - for py in PY.get_pinyins(nm[:20], ' '): + for py in PY.get_pinyins(nm[:20], " "): py = py.split() for i in range(1, len(py) + 1): cv["name_py_pref0_tks"] += " " + "".join(py[:i]) cv["name_kwd"] = name - cv["name_pinyin_kwd"] = PY.get_pinyins(nm[:20], ' ')[:3] - cv["name_tks"] = ( - rag_tokenizer.tokenize(name) + " " + (" ".join(list(name)) if not re.match(r"[a-zA-Z ]+$", name) else "") - ) if name else "" + cv["name_pinyin_kwd"] = PY.get_pinyins(nm[:20], " ")[:3] + cv["name_tks"] = (rag_tokenizer.tokenize(name) + " " + (" ".join(list(name)) if not re.match(r"[a-zA-Z ]+$", name) else "")) if name else "" else: - cv["integerity_flt"] /= 2. + cv["integerity_flt"] /= 2.0 if cv.get("phone"): r = re.search(r"(1[3456789][0-9]{9})", cv["phone"]) @@ -603,7 +620,7 @@ def parse(cv): # deal with date fields if cv.get("updated_at") and isinstance(cv["updated_at"], datetime.datetime): - cv["updated_at_dt"] = cv["updated_at"].strftime('%Y-%m-%d %H:%M:%S') + cv["updated_at_dt"] = cv["updated_at"].strftime("%Y-%m-%d %H:%M:%S") else: y, m, d = getYMD(str(cv.get("updated_at", ""))) if not y: @@ -623,9 +640,9 @@ def parse(cv): for f, y, n in is_fld: if f not in cv: continue - if cv[f] == '是': + if cv[f] == "是": fea.append(y) - if cv[f] == '否': + if cv[f] == "否": fea.append(n) if fea: @@ -648,7 +665,7 @@ def parse(cv): if not cv.get("work_exp_flt") and cv.get("work_start_time"): if re.match(r"[0-9]{9,}", str(cv["work_start_time"])): cv["work_start_dt"] = turnTm2Dt(cv["work_start_time"]) - cv["work_exp_flt"] = (time.time() - int(int(cv["work_start_time"]) / 1000)) / 3600. / 24. / 365. + cv["work_exp_flt"] = (time.time() - int(int(cv["work_start_time"]) / 1000)) / 3600.0 / 24.0 / 365.0 elif re.match(r"[0-9]{4}[^0-9]", str(cv["work_start_time"])): y, m, d = getYMD(str(cv["work_start_time"])) cv["work_start_dt"] = "%s-%02d-%02d 00:00:00" % (y, int(m), int(d)) @@ -656,7 +673,7 @@ def parse(cv): except Exception as e: logging.exception("parse {} ==> {}".format(e, cv.get("work_start_time"))) if "work_exp_flt" not in cv and cv.get("work_experience", 0): - cv["work_exp_flt"] = int(cv["work_experience"]) / 12. + cv["work_exp_flt"] = int(cv["work_experience"]) / 12.0 keys = list(cv.keys()) for k in keys: @@ -665,7 +682,7 @@ def parse(cv): for k in cv.keys(): if not re.search("_(kwd|id)$", k) or not isinstance(cv[k], list): continue - cv[k] = list(set([re.sub("(市)$", "", str(n)) for n in cv[k] if n not in ['中国', '0']])) + cv[k] = list(set([re.sub("(市)$", "", str(n)) for n in cv[k] if n not in ["中国", "0"]])) keys = [k for k in cv.keys() if re.search(r"_feas*$", k)] for k in keys: if cv[k] <= 0: diff --git a/deepdoc/parser/tcadp_parser.py b/deepdoc/parser/tcadp_parser.py index 6a37f0befd..37839eb87a 100644 --- a/deepdoc/parser/tcadp_parser.py +++ b/deepdoc/parser/tcadp_parser.py @@ -102,12 +102,12 @@ class TencentCloudAPIClient: logging.info("[TCADP] Detected streaming response") for event in resp: logging.info(f"[TCADP] Received event: {event}") - if event.get('data'): + if event.get("data"): try: - data_dict = json.loads(event['data']) + data_dict = json.loads(event["data"]) logging.info(f"[TCADP] Parsed data: {data_dict}") - if data_dict.get('Progress') == "100": + if data_dict.get("Progress") == "100": parser_result = data_dict logging.info("[TCADP] Document parsing completed!") logging.info(f"[TCADP] Task ID: {data_dict.get('TaskId')}") @@ -141,7 +141,7 @@ class TencentCloudAPIClient: logging.info(f"[TCADP] Event without data: {event}") else: # Non-streaming response logging.info("[TCADP] Detected non-streaming response") - if hasattr(resp, 'data') and resp.data: + if hasattr(resp, "data") and resp.data: try: data_dict = json.loads(resp.data) parser_result = data_dict @@ -198,8 +198,7 @@ class TencentCloudAPIClient: class TCADPParser(RAGFlowPdfParser): - def __init__(self, secret_id: str = None, secret_key: str = None, region: str = "ap-guangzhou", - table_result_type: str = None, markdown_image_response_type: str = None): + def __init__(self, secret_id: str = None, secret_key: str = None, region: str = "ap-guangzhou", table_result_type: str = None, markdown_image_response_type: str = None): super().__init__() # First initialize logger @@ -268,12 +267,12 @@ class TCADPParser(RAGFlowPdfParser): if binary: # If binary data is directly available, convert directly - return base64.b64encode(binary).decode('utf-8') + return base64.b64encode(binary).decode("utf-8") else: # Read from file path and convert - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: file_data = f.read() - return base64.b64encode(file_data).decode('utf-8') + return base64.b64encode(file_data).decode("utf-8") def _extract_content_from_zip(self, zip_path: str) -> list[dict[str, Any]]: """Extract parsing results from downloaded ZIP file""" @@ -436,22 +435,13 @@ class TCADPParser(RAGFlowPdfParser): self.logger.info(f"[TCADP] Retry attempt {attempt + 1}") if callback: callback(0.3 + attempt * 0.1, f"[TCADP] Retry attempt {attempt + 1}") - time.sleep(2 ** attempt) # Exponential backoff + time.sleep(2**attempt) # Exponential backoff - config = { - "TableResultType": self.table_result_type, - "MarkdownImageResponseType": self.markdown_image_response_type - } + config = {"TableResultType": self.table_result_type, "MarkdownImageResponseType": self.markdown_image_response_type} self.logger.info(f"[TCADP] API request config - TableResultType: {self.table_result_type}, MarkdownImageResponseType: {self.markdown_image_response_type}") - result = client.reconstruct_document_sse( - file_type=file_type, - file_base64=file_base64, - file_start_page=file_start_page, - file_end_page=file_end_page, - config=config - ) + result = client.reconstruct_document_sse(file_type=file_type, file_base64=file_base64, file_start_page=file_start_page, file_end_page=file_end_page, config=config) if result: self.logger.info(f"[TCADP] Attempt {attempt + 1} successful") diff --git a/deepdoc/parser/txt_parser.py b/deepdoc/parser/txt_parser.py index 6abf8591da..2a5e999f7d 100644 --- a/deepdoc/parser/txt_parser.py +++ b/deepdoc/parser/txt_parser.py @@ -31,7 +31,7 @@ class RAGFlowTxtParser: raise TypeError("txt type should be str!") cks = [""] tk_nums = [0] - delimiter = delimiter.encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8') + delimiter = delimiter.encode("utf-8").decode("unicode_escape").encode("latin1").decode("utf-8") def add_chunk(t): nonlocal cks, tk_nums, delimiter @@ -51,7 +51,7 @@ class RAGFlowTxtParser: for m in re.finditer(r"`([^`]+)`", delimiter, re.I): f, t = m.span() dels.append(m.group(1)) - dels.extend(list(delimiter[s: f])) + dels.extend(list(delimiter[s:f])) s = t if s < len(delimiter): dels.extend(list(delimiter[s:])) diff --git a/deepdoc/server/deepdoc_server.py b/deepdoc/server/deepdoc_server.py index 4ce7613e6c..7559eaf64c 100644 --- a/deepdoc/server/deepdoc_server.py +++ b/deepdoc/server/deepdoc_server.py @@ -32,32 +32,17 @@ def parse_args(): description="Unified OSS DeepDoc Model Server", formatter_class=argparse.RawDescriptionHelpFormatter, ) - parser.add_argument( - "--port", type=int, default=9390, help="Serving port (default: 9390)" - ) - parser.add_argument( - "--timeout", type=int, default=100, help="Request timeout in seconds (default: 100)" - ) + parser.add_argument("--port", type=int, default=9390, help="Serving port (default: 9390)") + parser.add_argument("--timeout", type=int, default=100, help="Request timeout in seconds (default: 100)") parser.add_argument( "--model-dir", type=str, - default=os.path.join( - os.path.dirname(__file__), "..", "..", "..", "rag", "res", "deepdoc" - ), + default=os.path.join(os.path.dirname(__file__), "..", "..", "..", "rag", "res", "deepdoc"), help="Model file directory", ) - parser.add_argument( - "--disable-dla", action="store_true", dest="disable_dla", default=False, - help="Disable DLA endpoint" - ) - parser.add_argument( - "--disable-ocr", action="store_true", dest="disable_ocr", default=False, - help="Disable OCR endpoint" - ) - parser.add_argument( - "--disable-tsr", action="store_true", dest="disable_tsr", default=False, - help="Disable TSR endpoint" - ) + parser.add_argument("--disable-dla", action="store_true", dest="disable_dla", default=False, help="Disable DLA endpoint") + parser.add_argument("--disable-ocr", action="store_true", dest="disable_ocr", default=False, help="Disable OCR endpoint") + parser.add_argument("--disable-tsr", action="store_true", dest="disable_tsr", default=False, help="Disable TSR endpoint") parser.add_argument("--log-level", type=str, default="INFO", help="Logging level") return parser.parse_args() diff --git a/deepdoc/server/docker_stubs.py b/deepdoc/server/docker_stubs.py index a847f04482..6bdb9e5a38 100644 --- a/deepdoc/server/docker_stubs.py +++ b/deepdoc/server/docker_stubs.py @@ -32,15 +32,20 @@ def write(path: str, content: str) -> None: # Real deepdoc/__init__.py calls beartype_this_package() which requires # the beartype library. -write("deepdoc/__init__.py", """ +write( + "deepdoc/__init__.py", + """ # Minimal deepdoc __init__ for Docker — avoids beartype dependency. -""") +""", +) # Real deepdoc/vision/__init__.py imports pdfplumber and # AscendLayoutRecognizer (requires ais_bench). The Docker server only # needs the four ONNX-based classes below. -write("deepdoc/vision/__init__.py", """ +write( + "deepdoc/vision/__init__.py", + """ # Minimal deepdoc.vision __init__ for Docker — avoids pdfplumber and Ascend imports. from .ocr import OCR from .recognizer import Recognizer @@ -48,13 +53,16 @@ from .layout_recognizer import LayoutRecognizer4YOLOv10 as LayoutRecognizer from .table_structure_recognizer import TableStructureRecognizer __all__ = ["OCR", "Recognizer", "LayoutRecognizer", "TableStructureRecognizer"] -""") +""", +) # ── common ───────────────────────────────────────────────────────────── # Real common.settings imports rag.utils.es_conn and other database/storage # connectors. The server only needs PARALLEL_DEVICES for OCR. -write("common/__init__.py", """ +write( + "common/__init__.py", + """ # Stub common.__init__ for Docker deepdoc service. import os @@ -64,12 +72,15 @@ class _Settings: settings = _Settings() -""") +""", +) # Real common.file_utils derives the project base from __file__. In # Docker the project root is always /app. -write("common/file_utils.py", """ +write( + "common/file_utils.py", + """ # Stub common.file_utils for Docker deepdoc service. import os @@ -83,14 +94,17 @@ def get_project_base_directory(*args): if args: return os.path.join(_PROJECT_BASE, *args) return _PROJECT_BASE -""") +""", +) # Real common.misc_utils imports 15+ modules. The server only calls # pip_install_torch() inside load_model()'s cuda_is_available() guard. # On CPU-only images torch is not installed, so the try/except silently # returns False and onnxruntime falls back to CPUExecutionProvider. -write("common/misc_utils.py", """ +write( + "common/misc_utils.py", + """ # Stub common.misc_utils for Docker deepdoc service. @@ -99,13 +113,17 @@ def pip_install_torch(*args, **kwargs): import torch # noqa: F401 except ImportError: pass -""") +""", +) # ── rag ──────────────────────────────────────────────────────────────── -write("rag/__init__.py", """ +write( + "rag/__init__.py", + """ # Stub rag package for Docker deepdoc service. -""") +""", +) # table_structure_recognizer.py imports rag_tokenizer at module level. # Its tokenize/tag methods are only called from blockType() / @@ -113,7 +131,9 @@ write("rag/__init__.py", """ # __call__() path. The stub exists solely to satisfy the module-level # import; its methods are never called at server runtime. -write("rag/nlp/__init__.py", """ +write( + "rag/nlp/__init__.py", + """ # Stub rag.nlp module for Docker deepdoc service. # Provides minimal rag_tokenizer to satisfy table_structure_recognizer import. @@ -127,14 +147,17 @@ class _StubTokenizer: rag_tokenizer = _StubTokenizer() -""") +""", +) # operators.py imports ensure_pil_image at module level and calls it in # NormalizeImage.__call__ / ToCHWImage.__call__ (OCR text detection path). # The real rag.utils.lazy_image imports concat_img from rag.nlp, pulling # in the entire NLP stack. -write("rag/utils/lazy_image.py", """ +write( + "rag/utils/lazy_image.py", + """ # Stub rag.utils.lazy_image for Docker. from PIL import Image @@ -143,7 +166,8 @@ def ensure_pil_image(img): if isinstance(img, Image.Image): return img return None -""") +""", +) if __name__ == "__main__": diff --git a/deepdoc/server/endpoints/ocr_endpoint.py b/deepdoc/server/endpoints/ocr_endpoint.py index 409ac77ac0..03164678fd 100644 --- a/deepdoc/server/endpoints/ocr_endpoint.py +++ b/deepdoc/server/endpoints/ocr_endpoint.py @@ -50,9 +50,7 @@ class OCREndpoint(ls.LitAPI): operator = operator.strip().lower() if operator not in ("det", "rec"): - raise ValueError( - f"Invalid or missing operator '{operator}' (must be 'det' or 'rec')" - ) + raise ValueError(f"Invalid or missing operator '{operator}' (must be 'det' or 'rec')") return operator, data diff --git a/deepdoc/vision/ocr.py b/deepdoc/vision/ocr.py index f5508dc4f1..d7472a130c 100644 --- a/deepdoc/vision/ocr.py +++ b/deepdoc/vision/ocr.py @@ -35,8 +35,9 @@ from .postprocess import build_post_process loaded_models = {} + def transform(data, ops=None): - """ transform """ + """transform""" if ops is None: ops = [] for op in ops: @@ -53,12 +54,10 @@ def create_operators(op_param_list, global_config=None): Args: params(list): a dict list, used to create some operators """ - assert isinstance( - op_param_list, list), ('operator config should be a list') + assert isinstance(op_param_list, list), "operator config should be a list" ops = [] for operator in op_param_list: - assert isinstance(operator, - dict) and len(operator) == 1, "yaml format error" + assert isinstance(operator, dict) and len(operator) == 1, "yaml format error" op_name = list(operator)[0] param = {} if operator[op_name] is None else operator[op_name] if global_config is not None: @@ -79,13 +78,13 @@ def load_model(model_dir, nm, device_id: int | None = None): return loaded_model if not os.path.exists(model_file_path): - raise ValueError("not find model file path {}".format( - model_file_path)) + raise ValueError("not find model file path {}".format(model_file_path)) def cuda_is_available(): try: pip_install_torch() import torch + target_id = 0 if device_id is None else device_id if torch.cuda.is_available() and torch.cuda.device_count() > target_id: return True @@ -108,27 +107,18 @@ def load_model(model_dir, nm, device_id: int | None = None): arena_strategy = os.environ.get("OCR_ARENA_EXTEND_STRATEGY", "kNextPowerOfTwo") provider_device_id = 0 if device_id is None else device_id cuda_provider_options = { - "device_id": provider_device_id, # Use specific GPU + "device_id": provider_device_id, # Use specific GPU "gpu_mem_limit": max(gpu_mem_limit_mb, 0) * 1024 * 1024, "arena_extend_strategy": arena_strategy, # gpu memory allocation strategy } - sess = ort.InferenceSession( - model_file_path, - options=options, - providers=['CUDAExecutionProvider'], - provider_options=[cuda_provider_options] - ) + sess = ort.InferenceSession(model_file_path, options=options, providers=["CUDAExecutionProvider"], provider_options=[cuda_provider_options]) # Explicit arena shrinkage for GPU to release VRAM back to the system after each run if os.environ.get("OCR_GPUMEM_ARENA_SHRINKAGE") == "1": run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", f"gpu:{provider_device_id}") - logging.info( - f"load_model {model_file_path} enabled GPU memory arena shrinkage on device {provider_device_id}") + logging.info(f"load_model {model_file_path} enabled GPU memory arena shrinkage on device {provider_device_id}") logging.info(f"load_model {model_file_path} uses GPU (device {provider_device_id}, gpu_mem_limit={cuda_provider_options['gpu_mem_limit']}, arena_strategy={arena_strategy})") else: - sess = ort.InferenceSession( - model_file_path, - options=options, - providers=['CPUExecutionProvider']) + sess = ort.InferenceSession(model_file_path, options=options, providers=["CPUExecutionProvider"]) run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu") logging.info(f"load_model {model_file_path} uses CPU") loaded_model = (sess, run_options) @@ -140,13 +130,9 @@ class TextRecognizer: def __init__(self, model_dir, device_id: int | None = None): self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")] self.rec_batch_num = 16 - postprocess_params = { - 'name': 'CTCLabelDecode', - "character_dict_path": os.path.join(model_dir, "ocr.res"), - "use_space_char": True - } + postprocess_params = {"name": "CTCLabelDecode", "character_dict_path": os.path.join(model_dir, "ocr.res"), "use_space_char": True} self.postprocess_op = build_post_process(postprocess_params) - self.predictor, self.run_options = load_model(model_dir, 'rec', device_id) + self.predictor, self.run_options = load_model(model_dir, "rec", device_id) self.input_tensor = self.predictor.get_inputs()[0] def resize_norm_img(self, img, max_wh_ratio): @@ -167,7 +153,7 @@ class TextRecognizer: resized_w = int(math.ceil(imgH * ratio)) resized_image = cv2.resize(img, (resized_w, imgH)) - resized_image = resized_image.astype('float32') + resized_image = resized_image.astype("float32") resized_image = resized_image.transpose((2, 0, 1)) / 255 resized_image -= 0.5 resized_image /= 0.5 @@ -179,9 +165,8 @@ class TextRecognizer: imgC, imgH, imgW = image_shape img = img[:, :, ::-1] # bgr2rgb - resized_image = cv2.resize( - img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) - resized_image = resized_image.astype('float32') + resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_image = resized_image.astype("float32") resized_image = resized_image.transpose((2, 0, 1)) / 255 return resized_image @@ -203,7 +188,7 @@ class TextRecognizer: img_np = np.asarray(img_new) img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) - img_black[:, 0:img_np.shape[1]] = img_np + img_black[:, 0 : img_np.shape[1]] = img_np img_black = img_black[:, :, np.newaxis] row, col, c = img_black.shape @@ -216,49 +201,35 @@ class TextRecognizer: imgC, imgH, imgW = image_shape feature_dim = int((imgH / 8) * (imgW / 8)) - encoder_word_pos = np.array(range(0, feature_dim)).reshape( - (feature_dim, 1)).astype('int64') - gsrm_word_pos = np.array(range(0, max_text_length)).reshape( - (max_text_length, 1)).astype('int64') + encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype("int64") + gsrm_word_pos = np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype("int64") gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) - gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape( - [-1, 1, max_text_length, max_text_length]) - gsrm_slf_attn_bias1 = np.tile( - gsrm_slf_attn_bias1, - [1, num_heads, 1, 1]).astype('float32') * [-1e9] + gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape([-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, [1, num_heads, 1, 1]).astype("float32") * [-1e9] - gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape( - [-1, 1, max_text_length, max_text_length]) - gsrm_slf_attn_bias2 = np.tile( - gsrm_slf_attn_bias2, - [1, num_heads, 1, 1]).astype('float32') * [-1e9] + gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape([-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, [1, num_heads, 1, 1]).astype("float32") * [-1e9] encoder_word_pos = encoder_word_pos[np.newaxis, :] gsrm_word_pos = gsrm_word_pos[np.newaxis, :] - return [ - encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, - gsrm_slf_attn_bias2 - ] + return [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] def process_image_srn(self, img, image_shape, num_heads, max_text_length): norm_img = self.resize_norm_img_srn(img, image_shape) norm_img = norm_img[np.newaxis, :] - [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \ - self.srn_other_inputs(image_shape, num_heads, max_text_length) + [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = self.srn_other_inputs(image_shape, num_heads, max_text_length) gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32) gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32) encoder_word_pos = encoder_word_pos.astype(np.int64) gsrm_word_pos = gsrm_word_pos.astype(np.int64) - return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, - gsrm_slf_attn_bias2) + return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2) - def resize_norm_img_sar(self, img, image_shape, - width_downsample_ratio=0.25): + def resize_norm_img_sar(self, img, image_shape, width_downsample_ratio=0.25): imgC, imgH, imgW_min, imgW_max = image_shape h = img.shape[0] w = img.shape[1] @@ -276,7 +247,7 @@ class TextRecognizer: valid_ratio = min(1.0, 1.0 * resize_w / imgW_max) resize_w = min(imgW_max, resize_w) resized_image = cv2.resize(img, (resize_w, imgH)) - resized_image = resized_image.astype('float32') + resized_image = resized_image.astype("float32") # norm if image_shape[0] == 1: resized_image = resized_image / 255 @@ -312,9 +283,8 @@ class TextRecognizer: def resize_norm_img_svtr(self, img, image_shape): imgC, imgH, imgW = image_shape - resized_image = cv2.resize( - img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) - resized_image = resized_image.astype('float32') + resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_image = resized_image.astype("float32") resized_image = resized_image.transpose((2, 0, 1)) / 255 resized_image -= 0.5 resized_image /= 0.5 @@ -324,24 +294,21 @@ class TextRecognizer: imgC, imgH, imgW = image_shape - resized_image = cv2.resize( - img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) - resized_image = resized_image.astype('float32') - resized_image = resized_image / 255. + resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_image = resized_image.astype("float32") + resized_image = resized_image / 255.0 mean = np.array([0.485, 0.456, 0.406]) std = np.array([0.229, 0.224, 0.225]) - resized_image = ( - resized_image - mean[None, None, ...]) / std[None, None, ...] + resized_image = (resized_image - mean[None, None, ...]) / std[None, None, ...] resized_image = resized_image.transpose((2, 0, 1)) - resized_image = resized_image.astype('float32') + resized_image = resized_image.astype("float32") return resized_image def norm_img_can(self, img, image_shape): - img = cv2.cvtColor( - img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image if self.rec_image_shape[0] == 1: h, w = img.shape @@ -349,19 +316,17 @@ class TextRecognizer: if h < imgH or w < imgW: padding_h = max(imgH - h, 0) padding_w = max(imgW - w, 0) - img_padded = np.pad(img, ((0, padding_h), (0, padding_w)), - 'constant', - constant_values=(255)) + img_padded = np.pad(img, ((0, padding_h), (0, padding_w)), "constant", constant_values=(255)) img = img_padded img = np.expand_dims(img, 0) / 255.0 # h,w,c -> c,h,w - img = img.astype('float32') + img = img.astype("float32") return img def close(self): # close session and release manually - logging.info('Close text recognizer.') + logging.info("Close text recognizer.") if hasattr(self, "predictor"): del self.predictor gc.collect() @@ -374,7 +339,7 @@ class TextRecognizer: width_list.append(img.shape[1] / float(img.shape[0])) # Sorting can speed up the recognition process indices = np.argsort(np.array(width_list)) - rec_res = [['', 0.0]] * img_num + rec_res = [["", 0.0]] * img_num batch_num = self.rec_batch_num st = time.time() @@ -389,8 +354,7 @@ class TextRecognizer: wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) for ino in range(beg_img_no, end_img_no): - norm_img = self.resize_norm_img(img_list[indices[ino]], - max_wh_ratio) + norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) norm_img = norm_img[np.newaxis, :] norm_img_batch.append(norm_img) norm_img_batch = np.concatenate(norm_img_batch) @@ -419,41 +383,28 @@ class TextRecognizer: class TextDetector: def __init__(self, model_dir, device_id: int | None = None): - pre_process_list = [{ - 'DetResizeForTest': { - 'limit_side_len': 960, - 'limit_type': "max", - } - }, { - 'NormalizeImage': { - 'std': [0.229, 0.224, 0.225], - 'mean': [0.485, 0.456, 0.406], - 'scale': '1./255.', - 'order': 'hwc' - } - }, { - 'ToCHWImage': None - }, { - 'KeepKeys': { - 'keep_keys': ['image', 'shape'] - } - }] - postprocess_params = {"name": "DBPostProcess", "thresh": 0.3, "box_thresh": 0.5, "max_candidates": 1000, - "unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"} + pre_process_list = [ + { + "DetResizeForTest": { + "limit_side_len": 960, + "limit_type": "max", + } + }, + {"NormalizeImage": {"std": [0.229, 0.224, 0.225], "mean": [0.485, 0.456, 0.406], "scale": "1./255.", "order": "hwc"}}, + {"ToCHWImage": None}, + {"KeepKeys": {"keep_keys": ["image", "shape"]}}, + ] + postprocess_params = {"name": "DBPostProcess", "thresh": 0.3, "box_thresh": 0.5, "max_candidates": 1000, "unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"} self.postprocess_op = build_post_process(postprocess_params) - self.predictor, self.run_options = load_model(model_dir, 'det', device_id) + self.predictor, self.run_options = load_model(model_dir, "det", device_id) self.input_tensor = self.predictor.get_inputs()[0] img_h, img_w = self.input_tensor.shape[2:] if isinstance(img_h, str) or isinstance(img_w, str): pass elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0: - pre_process_list[0] = { - 'DetResizeForTest': { - 'image_shape': [img_h, img_w] - } - } + pre_process_list[0] = {"DetResizeForTest": {"image_shape": [img_h, img_w]}} self.preprocess_op = create_operators(pre_process_list) def order_points_clockwise(self, pts): @@ -508,7 +459,7 @@ class TextDetector: def __call__(self, img): ori_im = img.copy() - data = {'image': img} + data = {"image": img} st = time.time() data = transform(data, self.preprocess_op) @@ -530,7 +481,7 @@ class TextDetector: time.sleep(5) post_result = self.postprocess_op({"maps": outputs[0]}, shape_list) - dt_boxes = post_result[0]['points'] + dt_boxes = post_result[0]["points"] dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) return dt_boxes, time.time() - st @@ -554,10 +505,8 @@ class OCR: """ if not model_dir: try: - model_dir = os.path.join( - get_project_base_directory(), - "rag/res/deepdoc") - + model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc") + # Append muti-gpus task to the list if settings.PARALLEL_DEVICES > 0: self.text_detector = [] @@ -574,7 +523,7 @@ class OCR: repo_id="InfiniFlow/deepdoc", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), ) - + if settings.PARALLEL_DEVICES > 0: self.text_detector = [] self.text_recognizer = [] @@ -600,23 +549,11 @@ class OCR: points[:, 1] = points[:, 1] - top """ assert len(points) == 4, "shape of points must be 4*2" - img_crop_width = int( - max( - np.linalg.norm(points[0] - points[1]), - np.linalg.norm(points[2] - points[3]))) - img_crop_height = int( - max( - np.linalg.norm(points[0] - points[3]), - np.linalg.norm(points[1] - points[2]))) - pts_std = np.float32([[0, 0], [img_crop_width, 0], - [img_crop_width, img_crop_height], - [0, img_crop_height]]) + img_crop_width = int(max(np.linalg.norm(points[0] - points[1]), np.linalg.norm(points[2] - points[3]))) + img_crop_height = int(max(np.linalg.norm(points[0] - points[3]), np.linalg.norm(points[1] - points[2]))) + pts_std = np.float32([[0, 0], [img_crop_width, 0], [img_crop_width, img_crop_height], [0, img_crop_height]]) M = cv2.getPerspectiveTransform(points, pts_std) - dst_img = cv2.warpPerspective( - img, - M, (img_crop_width, img_crop_height), - borderMode=cv2.BORDER_REPLICATE, - flags=cv2.INTER_CUBIC) + dst_img = cv2.warpPerspective(img, M, (img_crop_width, img_crop_height), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC) dst_img_height, dst_img_width = dst_img.shape[0:2] if dst_img_height * 1.0 / dst_img_width >= 1.5: # Try original orientation @@ -658,8 +595,7 @@ class OCR: for i in range(num_boxes - 1): for j in range(i, -1, -1): - if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \ - (_boxes[j + 1][0][0] < _boxes[j][0][0]): + if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (_boxes[j + 1][0][0] < _boxes[j][0][0]): tmp = _boxes[j] _boxes[j] = _boxes[j + 1] _boxes[j + 1] = tmp @@ -679,8 +615,7 @@ class OCR: if dt_boxes is None: return None - return zip(self.sorted_boxes(dt_boxes), [ - ("", 0) for _ in range(len(dt_boxes))]) + return zip(self.sorted_boxes(dt_boxes), [("", 0) for _ in range(len(dt_boxes))]) def recognize(self, ori_im, box, device_id: int | None = None): if device_id is None: @@ -706,8 +641,8 @@ class OCR: texts.append(text) return texts - def __call__(self, img, device_id = 0, cls=True): - time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0} + def __call__(self, img, device_id=0, cls=True): + time_dict = {"det": 0, "rec": 0, "cls": 0, "all": 0} if device_id is None: device_id = 0 @@ -717,11 +652,11 @@ class OCR: start = time.time() ori_im = img.copy() dt_boxes, elapse = self.text_detector[device_id](img) - time_dict['det'] = elapse + time_dict["det"] = elapse if dt_boxes is None: end = time.time() - time_dict['all'] = end - start + time_dict["all"] = end - start return None, None, time_dict img_crop_list = [] @@ -735,7 +670,7 @@ class OCR: rec_res, elapse = self.text_recognizer[device_id](img_crop_list) - time_dict['rec'] = elapse + time_dict["rec"] = elapse filter_boxes, filter_rec_res = [], [] for box, rec_result in zip(dt_boxes, rec_res): @@ -744,7 +679,7 @@ class OCR: filter_boxes.append(box) filter_rec_res.append(rec_result) end = time.time() - time_dict['all'] = end - start + time_dict["all"] = end - start # for bno in range(len(img_crop_list)): # print(f"{bno}, {rec_res[bno]}") diff --git a/deepdoc/vision/operators.py b/deepdoc/vision/operators.py index 43b55ccd3a..6bcbcb2ee9 100644 --- a/deepdoc/vision/operators.py +++ b/deepdoc/vision/operators.py @@ -26,44 +26,36 @@ from rag.utils.lazy_image import ensure_pil_image class DecodeImage: - """ decode image """ + """decode image""" - def __init__(self, - img_mode='RGB', - channel_first=False, - ignore_orientation=False, - **kwargs): + def __init__(self, img_mode="RGB", channel_first=False, ignore_orientation=False, **kwargs): self.img_mode = img_mode self.channel_first = channel_first self.ignore_orientation = ignore_orientation def __call__(self, data): - img = data['image'] + img = data["image"] if six.PY2: - assert isinstance(img, str) and len( - img) > 0, "invalid input 'img' in DecodeImage" + assert isinstance(img, str) and len(img) > 0, "invalid input 'img' in DecodeImage" else: - assert isinstance(img, bytes) and len( - img) > 0, "invalid input 'img' in DecodeImage" - img = np.frombuffer(img, dtype='uint8') + assert isinstance(img, bytes) and len(img) > 0, "invalid input 'img' in DecodeImage" + img = np.frombuffer(img, dtype="uint8") if self.ignore_orientation: - img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION | - cv2.IMREAD_COLOR) + img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR) else: img = cv2.imdecode(img, 1) if img is None: return None - if self.img_mode == 'GRAY': + if self.img_mode == "GRAY": img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - elif self.img_mode == 'RGB': - assert img.shape[2] == 3, 'invalid shape of image[%s]' % ( - img.shape) + elif self.img_mode == "RGB": + assert img.shape[2] == 3, "invalid shape of image[%s]" % (img.shape) img = img[:, :, ::-1] if self.channel_first: img = img.transpose((2, 0, 1)) - data['image'] = img + data["image"] = img return data @@ -76,7 +68,7 @@ class StandardizeImag: norm_type (str): type in ['mean_std', 'none'] """ - def __init__(self, mean, std, is_scale=True, norm_type='mean_std'): + def __init__(self, mean, std, is_scale=True, norm_type="mean_std"): self.mean = mean self.std = std self.is_scale = is_scale @@ -96,7 +88,7 @@ class StandardizeImag: scale = 1.0 / 255.0 im *= scale - if self.norm_type == 'mean_std': + if self.norm_type == "mean_std": mean = np.array(self.mean)[np.newaxis, np.newaxis, :] std = np.array(self.std)[np.newaxis, np.newaxis, :] im -= mean @@ -105,16 +97,15 @@ class StandardizeImag: class NormalizeImage: - """ normalize image such as subtract mean, divide std - """ + """normalize image such as subtract mean, divide std""" - def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs): + def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs): if isinstance(scale, str): try: scale = float(scale) except ValueError: - if '/' in scale: - parts = scale.split('/') + if "/" in scale: + parts = scale.split("/") scale = ast.literal_eval(parts[0]) / ast.literal_eval(parts[1]) else: scale = ast.literal_eval(scale) @@ -122,37 +113,36 @@ class NormalizeImage: mean = mean if mean is not None else [0.485, 0.456, 0.406] std = std if std is not None else [0.229, 0.224, 0.225] - shape = (3, 1, 1) if order == 'chw' else (1, 1, 3) - self.mean = np.array(mean).reshape(shape).astype('float32') - self.std = np.array(std).reshape(shape).astype('float32') + shape = (3, 1, 1) if order == "chw" else (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype("float32") + self.std = np.array(std).reshape(shape).astype("float32") def __call__(self, data): - img = data['image'] + img = data["image"] from PIL import Image + pil = ensure_pil_image(img) if isinstance(pil, Image.Image): img = np.array(pil) - assert isinstance(img, - np.ndarray), "invalid input 'img' in NormalizeImage" - data['image'] = ( - img.astype('float32') * self.scale - self.mean) / self.std + assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage" + data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std return data class ToCHWImage: - """ convert hwc image to chw image - """ + """convert hwc image to chw image""" def __init__(self, **kwargs): pass def __call__(self, data): - img = data['image'] + img = data["image"] from PIL import Image + pil = ensure_pil_image(img) if isinstance(pil, Image.Image): img = np.array(pil) - data['image'] = img.transpose((2, 0, 1)) + data["image"] = img.transpose((2, 0, 1)) return data @@ -170,8 +160,7 @@ class KeepKeys: class Pad: def __init__(self, size=None, size_div=32, **kwargs): if size is not None and not isinstance(size, (int, list, tuple)): - raise TypeError("Type of target_size is invalid. Now is {}".format( - type(size))) + raise TypeError("Type of target_size is invalid. Now is {}".format(type(size))) if isinstance(size, int): size = [size, size] self.size = size @@ -179,29 +168,16 @@ class Pad: def __call__(self, data): - img = data['image'] + img = data["image"] img_h, img_w = img.shape[0], img.shape[1] if self.size: resize_h2, resize_w2 = self.size - assert ( - img_h < resize_h2 and img_w < resize_w2 - ), '(h, w) of target size should be greater than (img_h, img_w)' + assert img_h < resize_h2 and img_w < resize_w2, "(h, w) of target size should be greater than (img_h, img_w)" else: - resize_h2 = max( - int(math.ceil(img.shape[0] / self.size_div) * self.size_div), - self.size_div) - resize_w2 = max( - int(math.ceil(img.shape[1] / self.size_div) * self.size_div), - self.size_div) - img = cv2.copyMakeBorder( - img, - 0, - resize_h2 - img_h, - 0, - resize_w2 - img_w, - cv2.BORDER_CONSTANT, - value=0) - data['image'] = img + resize_h2 = max(int(math.ceil(img.shape[0] / self.size_div) * self.size_div), self.size_div) + resize_w2 = max(int(math.ceil(img.shape[1] / self.size_div) * self.size_div), self.size_div) + img = cv2.copyMakeBorder(img, 0, resize_h2 - img_h, 0, resize_w2 - img_w, cv2.BORDER_CONSTANT, value=0) + data["image"] = img return data @@ -233,16 +209,9 @@ class LinearResize: assert self.target_size[0] > 0 and self.target_size[1] > 0 _im_channel = im.shape[2] im_scale_y, im_scale_x = self.generate_scale(im) - im = cv2.resize( - im, - None, - None, - fx=im_scale_x, - fy=im_scale_y, - interpolation=self.interp) - im_info['im_shape'] = np.array(im.shape[:2]).astype('float32') - im_info['scale_factor'] = np.array( - [im_scale_y, im_scale_x]).astype('float32') + im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=self.interp) + im_info["im_shape"] = np.array(im.shape[:2]).astype("float32") + im_info["scale_factor"] = np.array([im_scale_y, im_scale_x]).astype("float32") return im, im_info def generate_scale(self, im): @@ -285,20 +254,20 @@ class Resize: return img, [ratio_h, ratio_w] def __call__(self, data): - img = data['image'] - if 'polys' in data: - text_polys = data['polys'] + img = data["image"] + if "polys" in data: + text_polys = data["polys"] img_resize, [ratio_h, ratio_w] = self.resize_image(img) - if 'polys' in data: + if "polys" in data: new_boxes = [] for box in text_polys: new_box = [] for cord in box: new_box.append([cord[0] * ratio_w, cord[1] * ratio_h]) new_boxes.append(new_box) - data['polys'] = np.array(new_boxes, dtype=np.float32) - data['image'] = img_resize + data["polys"] = np.array(new_boxes, dtype=np.float32) + data["image"] = img_resize return data @@ -307,23 +276,23 @@ class DetResizeForTest: super(DetResizeForTest, self).__init__() self.resize_type = 0 self.keep_ratio = False - if 'image_shape' in kwargs: - self.image_shape = kwargs['image_shape'] + if "image_shape" in kwargs: + self.image_shape = kwargs["image_shape"] self.resize_type = 1 - if 'keep_ratio' in kwargs: - self.keep_ratio = kwargs['keep_ratio'] - elif 'limit_side_len' in kwargs: - self.limit_side_len = kwargs['limit_side_len'] - self.limit_type = kwargs.get('limit_type', 'min') - elif 'resize_long' in kwargs: + if "keep_ratio" in kwargs: + self.keep_ratio = kwargs["keep_ratio"] + elif "limit_side_len" in kwargs: + self.limit_side_len = kwargs["limit_side_len"] + self.limit_type = kwargs.get("limit_type", "min") + elif "resize_long" in kwargs: self.resize_type = 2 - self.resize_long = kwargs.get('resize_long', 960) + self.resize_long = kwargs.get("resize_long", 960) else: self.limit_side_len = 736 - self.limit_type = 'min' + self.limit_type = "min" def __call__(self, data): - img = data['image'] + img = data["image"] src_h, src_w, _ = img.shape if sum([src_h, src_w]) < 64: img = self.image_padding(img) @@ -336,8 +305,8 @@ class DetResizeForTest: else: # img, shape = self.resize_image_type1(img) img, [ratio_h, ratio_w] = self.resize_image_type1(img) - data['image'] = img - data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + data["image"] = img + data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w]) return data def image_padding(self, im, value=0): @@ -371,26 +340,26 @@ class DetResizeForTest: h, w, c = img.shape # limit the max side - if self.limit_type == 'max': + if self.limit_type == "max": if max(h, w) > limit_side_len: if h > w: ratio = float(limit_side_len) / h else: ratio = float(limit_side_len) / w else: - ratio = 1. - elif self.limit_type == 'min': + ratio = 1.0 + elif self.limit_type == "min": if min(h, w) < limit_side_len: if h < w: ratio = float(limit_side_len) / h else: ratio = float(limit_side_len) / w else: - ratio = 1. - elif self.limit_type == 'resize_long': + ratio = 1.0 + elif self.limit_type == "resize_long": ratio = float(limit_side_len) / max(h, w) else: - raise Exception('not support limit type, image ') + raise Exception("not support limit type, image ") resize_h = int(h * ratio) resize_w = int(w * ratio) @@ -435,20 +404,18 @@ class DetResizeForTest: class E2EResizeForTest: def __init__(self, **kwargs): super(E2EResizeForTest, self).__init__() - self.max_side_len = kwargs['max_side_len'] - self.valid_set = kwargs['valid_set'] + self.max_side_len = kwargs["max_side_len"] + self.valid_set = kwargs["valid_set"] def __call__(self, data): - img = data['image'] + img = data["image"] src_h, src_w, _ = img.shape - if self.valid_set == 'totaltext': - im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext( - img, max_side_len=self.max_side_len) + if self.valid_set == "totaltext": + im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(img, max_side_len=self.max_side_len) else: - im_resized, (ratio_h, ratio_w) = self.resize_image( - img, max_side_len=self.max_side_len) - data['image'] = im_resized - data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + im_resized, (ratio_h, ratio_w) = self.resize_image(img, max_side_len=self.max_side_len) + data["image"] = im_resized + data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w]) return data def resize_image_for_totaltext(self, im, max_side_len=512): @@ -503,33 +470,29 @@ class E2EResizeForTest: class KieResize: def __init__(self, **kwargs): super(KieResize, self).__init__() - self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[ - 'img_scale'][1] + self.max_side, self.min_side = kwargs["img_scale"][0], kwargs["img_scale"][1] def __call__(self, data): - img = data['image'] - points = data['points'] + img = data["image"] + points = data["points"] src_h, src_w, _ = img.shape - im_resized, scale_factor, [ratio_h, ratio_w - ], [new_h, new_w] = self.resize_image(img) + im_resized, scale_factor, [ratio_h, ratio_w], [new_h, new_w] = self.resize_image(img) resize_points = self.resize_boxes(img, points, scale_factor) - data['ori_image'] = img - data['ori_boxes'] = points - data['points'] = resize_points - data['image'] = im_resized - data['shape'] = np.array([new_h, new_w]) + data["ori_image"] = img + data["ori_boxes"] = points + data["points"] = resize_points + data["image"] = im_resized + data["shape"] = np.array([new_h, new_w]) return data def resize_image(self, img): - norm_img = np.zeros([1024, 1024, 3], dtype='float32') + norm_img = np.zeros([1024, 1024, 3], dtype="float32") scale = [512, 1024] h, w = img.shape[:2] max_long_edge = max(scale) max_short_edge = min(scale) - scale_factor = min(max_long_edge / max(h, w), - max_short_edge / min(h, w)) - resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float( - scale_factor) + 0.5) + scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w)) + resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(scale_factor) + 0.5) max_stride = 32 resize_h = (resize_h + max_stride - 1) // max_stride * max_stride resize_w = (resize_w + max_stride - 1) // max_stride * max_stride @@ -537,8 +500,7 @@ class KieResize: new_h, new_w = im.shape[:2] w_scale = new_w / w h_scale = new_h / h - scale_factor = np.array( - [w_scale, h_scale, w_scale, h_scale], dtype=np.float32) + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32) norm_img[:new_h, :new_w, :] = im return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w] @@ -551,15 +513,7 @@ class KieResize: class SRResize: - def __init__(self, - imgH=32, - imgW=128, - down_sample_scale=4, - keep_ratio=False, - min_ratio=1, - mask=False, - infer_mode=False, - **kwargs): + def __init__(self, imgH=32, imgW=128, down_sample_scale=4, keep_ratio=False, min_ratio=1, mask=False, infer_mode=False, **kwargs): self.imgH = imgH self.imgW = imgW self.keep_ratio = keep_ratio @@ -572,8 +526,7 @@ class SRResize: imgH = self.imgH imgW = self.imgW images_lr = data["image_lr"] - transform2 = ResizeNormalize( - (imgW // self.down_sample_scale, imgH // self.down_sample_scale)) + transform2 = ResizeNormalize((imgW // self.down_sample_scale, imgH // self.down_sample_scale)) images_lr = transform2(images_lr) data["img_lr"] = images_lr if self.infer_mode: @@ -610,16 +563,16 @@ class GrayImageChannelFormat: self.inverse = inverse def __call__(self, data): - img = data['image'] + img = data["image"] img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) img_expanded = np.expand_dims(img_single_channel, 0) if self.inverse: - data['image'] = np.abs(img_expanded - 1) + data["image"] = np.abs(img_expanded - 1) else: - data['image'] = img_expanded + data["image"] = img_expanded - data['src_image'] = img + data["src_image"] = img return data @@ -630,7 +583,9 @@ class Permute: channel_first (bool): whether convert HWC to CHW """ - def __init__(self, ): + def __init__( + self, + ): super(Permute, self).__init__() def __call__(self, im, im_info): @@ -647,7 +602,7 @@ class Permute: class PadStride: - """ padding image for model with FPN, instead PadBatch(pad_to_stride) in original config + """padding image for model with FPN, instead PadBatch(pad_to_stride) in original config Args: stride (bool): model with FPN need image shape % stride == 0 """ @@ -685,24 +640,23 @@ def decode_image(im_file, im_info): im_info (dict): info of processed image """ if isinstance(im_file, str): - with open(im_file, 'rb') as f: + with open(im_file, "rb") as f: im_read = f.read() - data = np.frombuffer(im_read, dtype='uint8') + data = np.frombuffer(im_read, dtype="uint8") im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) else: im = im_file - im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32) - im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32) + im_info["im_shape"] = np.array(im.shape[:2], dtype=np.float32) + im_info["scale_factor"] = np.array([1.0, 1.0], dtype=np.float32) return im, im_info def preprocess(im, preprocess_ops): # process image by preprocess_ops im_info = { - 'scale_factor': np.array( - [1., 1.], dtype=np.float32), - 'im_shape': None, + "scale_factor": np.array([1.0, 1.0], dtype=np.float32), + "im_shape": None, } im, im_info = decode_image(im, im_info) for operator in preprocess_ops: @@ -712,6 +666,7 @@ def preprocess(im, preprocess_ops): def nms(bboxes, scores, iou_thresh): import numpy as np + x1 = bboxes[:, 0] y1 = bboxes[:, 1] x2 = bboxes[:, 2] diff --git a/deepdoc/vision/postprocess.py b/deepdoc/vision/postprocess.py index 7704bc5826..40197c0390 100644 --- a/deepdoc/vision/postprocess.py +++ b/deepdoc/vision/postprocess.py @@ -23,18 +23,17 @@ import pyclipper def build_post_process(config, global_config=None): - support_dict = {'DBPostProcess': DBPostProcess, 'CTCLabelDecode': CTCLabelDecode} + support_dict = {"DBPostProcess": DBPostProcess, "CTCLabelDecode": CTCLabelDecode} config = copy.deepcopy(config) - module_name = config.pop('name') + module_name = config.pop("name") if module_name == "None": return if global_config is not None: config.update(global_config) module_class = support_dict.get(module_name) if module_class is None: - raise ValueError( - 'post process only support {}'.format(list(support_dict))) + raise ValueError("post process only support {}".format(list(support_dict))) return module_class(**config) @@ -43,15 +42,7 @@ class DBPostProcess: The post process for Differentiable Binarization (DB). """ - def __init__(self, - thresh=0.3, - box_thresh=0.7, - max_candidates=1000, - unclip_ratio=2.0, - use_dilation=False, - score_mode="fast", - box_type='quad', - **kwargs): + def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=2.0, use_dilation=False, score_mode="fast", box_type="quad", **kwargs): self.thresh = thresh self.box_thresh = box_thresh self.max_candidates = max_candidates @@ -59,12 +50,9 @@ class DBPostProcess: self.min_size = 3 self.score_mode = score_mode self.box_type = box_type - assert score_mode in [ - "slow", "fast" - ], "Score mode must be in [slow, fast] but got: {}".format(score_mode) + assert score_mode in ["slow", "fast"], "Score mode must be in [slow, fast] but got: {}".format(score_mode) - self.dilation_kernel = None if not use_dilation else np.array( - [[1, 1], [1, 1]]) + self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]]) def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height): """ @@ -78,10 +66,9 @@ class DBPostProcess: boxes = [] scores = [] - contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), - cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) - for contour in contours[:self.max_candidates]: + for contour in contours[: self.max_candidates]: epsilon = 0.002 * cv2.arcLength(contour, True) approx = cv2.approxPolyDP(contour, epsilon, True) points = approx.reshape((-1, 2)) @@ -105,10 +92,8 @@ class DBPostProcess: continue box = np.array(box) - box[:, 0] = np.clip( - np.round(box[:, 0] / width * dest_width), 0, dest_width) - box[:, 1] = np.clip( - np.round(box[:, 1] / height * dest_height), 0, dest_height) + box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height) boxes.append(box.tolist()) scores.append(score) return boxes, scores @@ -122,8 +107,7 @@ class DBPostProcess: bitmap = _bitmap height, width = bitmap.shape - outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, - cv2.CHAIN_APPROX_SIMPLE) + outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) if len(outs) == 3: _img, contours, _ = outs[0], outs[1], outs[2] elif len(outs) == 2: @@ -152,10 +136,8 @@ class DBPostProcess: continue box = np.array(box) - box[:, 0] = np.clip( - np.round(box[:, 0] / width * dest_width), 0, dest_width) - box[:, 1] = np.clip( - np.round(box[:, 1] / height * dest_height), 0, dest_height) + box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height) boxes.append(box.astype("int32")) scores.append(score) return np.array(boxes, dtype="int32"), scores @@ -186,9 +168,7 @@ class DBPostProcess: index_2 = 3 index_3 = 2 - box = [ - points[index_1], points[index_2], points[index_3], points[index_4] - ] + box = [points[index_1], points[index_2], points[index_3], points[index_4]] return box, min(bounding_box[1]) def box_score_fast(self, bitmap, _box): @@ -206,7 +186,7 @@ class DBPostProcess: box[:, 0] = box[:, 0] - xmin box[:, 1] = box[:, 1] - ymin cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1) - return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0] def box_score_slow(self, bitmap, contour): """ @@ -227,10 +207,10 @@ class DBPostProcess: contour[:, 1] = contour[:, 1] - ymin cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1) - return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0] def __call__(self, outs_dict, shape_list): - pred = outs_dict['maps'] + pred = outs_dict["maps"] if not isinstance(pred, np.ndarray): pred = pred.numpy() pred = pred[:, 0, :, :] @@ -240,27 +220,22 @@ class DBPostProcess: for batch_index in range(pred.shape[0]): src_h, src_w, ratio_h, ratio_w = shape_list[batch_index] if self.dilation_kernel is not None: - mask = cv2.dilate( - np.array(segmentation[batch_index]).astype(np.uint8), - self.dilation_kernel) + mask = cv2.dilate(np.array(segmentation[batch_index]).astype(np.uint8), self.dilation_kernel) else: mask = segmentation[batch_index] - if self.box_type == 'poly': - boxes, scores = self.polygons_from_bitmap(pred[batch_index], - mask, src_w, src_h) - elif self.box_type == 'quad': - boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, - src_w, src_h) + if self.box_type == "poly": + boxes, scores = self.polygons_from_bitmap(pred[batch_index], mask, src_w, src_h) + elif self.box_type == "quad": + boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, src_w, src_h) else: - raise ValueError( - "box_type can only be one of ['quad', 'poly']") + raise ValueError("box_type can only be one of ['quad', 'poly']") - boxes_batch.append({'points': boxes}) + boxes_batch.append({"points": boxes}) return boxes_batch class BaseRecLabelDecode: - """ Convert between text-label and text-index """ + """Convert between text-label and text-index""" def __init__(self, character_dict_path=None, use_space_char=False): self.beg_str = "sos" @@ -275,12 +250,12 @@ class BaseRecLabelDecode: with open(character_dict_path, "rb") as fin: lines = fin.readlines() for line in lines: - line = line.decode('utf-8').strip("\n").strip("\r\n") + line = line.decode("utf-8").strip("\n").strip("\r\n") self.character_str.append(line) if use_space_char: self.character_str.append(" ") dict_character = list(self.character_str) - if 'arabic' in character_dict_path: + if "arabic" in character_dict_path: self.reverse = True dict_character = self.add_special_char(dict_character) @@ -291,40 +266,36 @@ class BaseRecLabelDecode: def pred_reverse(self, pred): pred_re = [] - c_current = '' + c_current = "" for c in pred: - if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)): - if c_current != '': + if not bool(re.search("[a-zA-Z0-9 :*./%+-]", c)): + if c_current != "": pred_re.append(c_current) pred_re.append(c) - c_current = '' + c_current = "" else: c_current += c - if c_current != '': + if c_current != "": pred_re.append(c_current) - return ''.join(pred_re[::-1]) + return "".join(pred_re[::-1]) def add_special_char(self, dict_character): return dict_character def decode(self, text_index, text_prob=None, is_remove_duplicate=False): - """ convert text-index into text-label. """ + """convert text-index into text-label.""" result_list = [] ignored_tokens = self.get_ignored_tokens() batch_size = len(text_index) for batch_idx in range(batch_size): selection = np.ones(len(text_index[batch_idx]), dtype=bool) if is_remove_duplicate: - selection[1:] = text_index[batch_idx][1:] != text_index[ - batch_idx][:-1] + selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1] for ignored_token in ignored_tokens: selection &= text_index[batch_idx] != ignored_token - char_list = [ - self.character[text_id] - for text_id in text_index[batch_idx][selection] - ] + char_list = [self.character[text_id] for text_id in text_index[batch_idx][selection]] if text_prob is not None: conf_list = text_prob[batch_idx][selection] else: @@ -332,7 +303,7 @@ class BaseRecLabelDecode: if len(conf_list) == 0: conf_list = [0] - text = ''.join(char_list) + text = "".join(char_list) if self.reverse: # for arabic rec text = self.pred_reverse(text) @@ -345,12 +316,10 @@ class BaseRecLabelDecode: class CTCLabelDecode(BaseRecLabelDecode): - """ Convert between text-label and text-index """ + """Convert between text-label and text-index""" - def __init__(self, character_dict_path=None, use_space_char=False, - **kwargs): - super(CTCLabelDecode, self).__init__(character_dict_path, - use_space_char) + def __init__(self, character_dict_path=None, use_space_char=False, **kwargs): + super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char) def __call__(self, preds, label=None, *args, **kwargs): if isinstance(preds, tuple) or isinstance(preds, list): @@ -366,5 +335,5 @@ class CTCLabelDecode(BaseRecLabelDecode): return text, label def add_special_char(self, dict_character): - dict_character = ['blank'] + dict_character + dict_character = ["blank"] + dict_character return dict_character diff --git a/deepdoc/vision/recognizer.py b/deepdoc/vision/recognizer.py index af259a3842..dcb7af4708 100644 --- a/deepdoc/vision/recognizer.py +++ b/deepdoc/vision/recognizer.py @@ -28,6 +28,7 @@ from .operators import preprocess from . import operators from .ocr import load_model + class Recognizer: def __init__(self, label_list, task_name, model_dir=None): """ @@ -42,9 +43,7 @@ class Recognizer: """ if not model_dir: - model_dir = os.path.join( - get_project_base_directory(), - "rag/res/deepdoc") + model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc") self.ort_sess, self.run_options = load_model(model_dir, task_name) self.input_names = [node.name for node in self.ort_sess.get_inputs()] self.output_names = [node.name for node in self.ort_sess.get_outputs()] @@ -58,6 +57,7 @@ class Recognizer: if abs(diff) < threshold: diff = c1["x0"] - c2["x0"] return diff + arr = sorted(arr, key=cmp_to_key(cmp)) return arr @@ -68,6 +68,7 @@ class Recognizer: if abs(diff) < threshold: diff = c1["top"] - c2["top"] return diff + arr = sorted(arr, key=cmp_to_key(cmp)) return arr @@ -81,11 +82,7 @@ class Recognizer: # restore the order using th if "C" not in arr[j] or "C" not in arr[j + 1]: continue - if arr[j + 1]["C"] < arr[j]["C"] \ - or ( - arr[j + 1]["C"] == arr[j]["C"] - and arr[j + 1]["top"] < arr[j]["top"] - ): + if arr[j + 1]["C"] < arr[j]["C"] or (arr[j + 1]["C"] == arr[j]["C"] and arr[j + 1]["top"] < arr[j]["top"]): tmp = arr[j] arr[j] = arr[j + 1] arr[j + 1] = tmp @@ -100,11 +97,7 @@ class Recognizer: for j in range(i, -1, -1): if "R" not in arr[j] or "R" not in arr[j + 1]: continue - if arr[j + 1]["R"] < arr[j]["R"] \ - or ( - arr[j + 1]["R"] == arr[j]["R"] - and arr[j + 1]["x0"] < arr[j]["x0"] - ): + if arr[j + 1]["R"] < arr[j]["R"] or (arr[j + 1]["R"] == arr[j]["R"] and arr[j + 1]["x0"] < arr[j]["x0"]): tmp = arr[j] arr[j] = arr[j + 1] arr[j + 1] = tmp @@ -119,14 +112,11 @@ class Recognizer: return 0 x0_ = max(b["x0"], x0) x1_ = min(b["x1"], x1) - assert x0_ <= x1_, "Bbox mismatch! T:{},B:{},X0:{},X1:{} ==> {}".format( - tp, btm, x0, x1, b) + assert x0_ <= x1_, "Bbox mismatch! T:{},B:{},X0:{},X1:{} ==> {}".format(tp, btm, x0, x1, b) tp_ = max(b["top"], tp) btm_ = min(b["bottom"], btm) - assert tp_ <= btm_, "Bbox mismatch! T:{},B:{},X0:{},X1:{} => {}".format( - tp, btm, x0, x1, b) - ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \ - x0 != 0 and btm - tp != 0 else 0 + assert tp_ <= btm_, "Bbox mismatch! T:{},B:{},X0:{},X1:{} => {}".format(tp, btm, x0, x1, b) + ov = (btm_ - tp_) * (x1_ - x0_) if x1 - x0 != 0 and btm - tp != 0 else 0 if ov > 0 and ratio: ov /= (x1 - x0) * (btm - tp) return ov @@ -134,23 +124,17 @@ class Recognizer: @staticmethod def layouts_cleanup(boxes, layouts, far=2, thr=0.7): def not_overlapped(a, b): - return any([a["x1"] < b["x0"], - a["x0"] > b["x1"], - a["bottom"] < b["top"], - a["top"] > b["bottom"]]) + return any([a["x1"] < b["x0"], a["x0"] > b["x1"], a["bottom"] < b["top"], a["top"] > b["bottom"]]) i = 0 while i + 1 < len(layouts): j = i + 1 - while j < min(i + far, len(layouts)) \ - and (layouts[i].get("type", "") != layouts[j].get("type", "") - or not_overlapped(layouts[i], layouts[j])): + while j < min(i + far, len(layouts)) and (layouts[i].get("type", "") != layouts[j].get("type", "") or not_overlapped(layouts[i], layouts[j])): j += 1 if j >= min(i + far, len(layouts)): i += 1 continue - if Recognizer.overlapped_area(layouts[i], layouts[j]) < thr \ - and Recognizer.overlapped_area(layouts[j], layouts[i]) < thr: + if Recognizer.overlapped_area(layouts[i], layouts[j]) < thr and Recognizer.overlapped_area(layouts[j], layouts[i]) < thr: i += 1 continue @@ -188,18 +172,16 @@ class Recognizer: im_shape = [] scale_factor = [] if len(imgs) == 1: - inputs['image'] = np.array((imgs[0],)).astype('float32') - inputs['im_shape'] = np.array( - (im_info[0]['im_shape'],)).astype('float32') - inputs['scale_factor'] = np.array( - (im_info[0]['scale_factor'],)).astype('float32') + inputs["image"] = np.array((imgs[0],)).astype("float32") + inputs["im_shape"] = np.array((im_info[0]["im_shape"],)).astype("float32") + inputs["scale_factor"] = np.array((im_info[0]["scale_factor"],)).astype("float32") return inputs - - im_shape = np.array([info['im_shape'] for info in im_info], dtype='float32') - scale_factor = np.array([info['scale_factor'] for info in im_info], dtype='float32') - inputs['im_shape'] = np.concatenate(im_shape, axis=0) - inputs['scale_factor'] = np.concatenate(scale_factor, axis=0) + im_shape = np.array([info["im_shape"] for info in im_info], dtype="float32") + scale_factor = np.array([info["scale_factor"] for info in im_info], dtype="float32") + + inputs["im_shape"] = np.concatenate(im_shape, axis=0) + inputs["scale_factor"] = np.concatenate(scale_factor, axis=0) imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs] max_shape_h = max([e[0] for e in imgs_shape]) @@ -207,11 +189,10 @@ class Recognizer: padding_imgs = [] for img in imgs: im_c, im_h, im_w = img.shape[:] - padding_im = np.zeros( - (im_c, max_shape_h, max_shape_w), dtype=np.float32) + padding_im = np.zeros((im_c, max_shape_h, max_shape_w), dtype=np.float32) padding_im[:, :im_h, :im_w] = img padding_imgs.append(padding_im) - inputs['image'] = np.stack(padding_imgs, axis=0) + inputs["image"] = np.stack(padding_imgs, axis=0) return inputs @staticmethod @@ -254,10 +235,10 @@ class Recognizer: if not boxes: return min_dis, min_i = 1000000, None - for i,b in enumerate(boxes): + for i, b in enumerate(boxes): if box.get("layoutno", "0") != b.get("layoutno", "0"): continue - dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2) + dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"] + box["x1"] - b["x1"] - b["x0"]) / 2) if dis < min_dis: min_i = i min_dis = dis @@ -285,30 +266,29 @@ class Recognizer: if "scale_factor" in self.input_names: preprocess_ops = [] for op_info in [ - {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'}, - {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'}, - {'type': 'Permute'}, - {'stride': 32, 'type': 'PadStride'} + {"interp": 2, "keep_ratio": False, "target_size": [800, 608], "type": "LinearResize"}, + {"is_scale": True, "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225], "type": "StandardizeImage"}, + {"type": "Permute"}, + {"stride": 32, "type": "PadStride"}, ]: new_op_info = op_info.copy() - op_type = new_op_info.pop('type') + op_type = new_op_info.pop("type") preprocess_ops.append(getattr(operators, op_type)(**new_op_info)) for im_path in image_list: im, im_info = preprocess(im_path, preprocess_ops) - inputs.append({"image": np.array((im,)).astype('float32'), - "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')}) + inputs.append({"image": np.array((im,)).astype("float32"), "scale_factor": np.array((im_info["scale_factor"],)).astype("float32")}) else: hh, ww = self.input_shape for img in image_list: h, w = img.shape[:2] img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = cv2.resize(np.array(img).astype('float32'), (ww, hh)) + img = cv2.resize(np.array(img).astype("float32"), (ww, hh)) # Scale input pixel values to 0 to 1 img /= 255.0 img = img.transpose(2, 0, 1) img = img[np.newaxis, :, :, :].astype(np.float32) - inputs.append({self.input_names[0]: img, "scale_factor": [w/ww, h/hh]}) + inputs.append({self.input_names[0]: img, "scale_factor": [w / ww, h / hh]}) return inputs def postprocess(self, boxes, inputs, thr): @@ -320,11 +300,7 @@ class Recognizer: continue if clsid >= len(self.label_list): continue - bb.append({ - "type": self.label_list[clsid].lower(), - "bbox": [float(t) for t in bbox.tolist()], - "score": float(score) - }) + bb.append({"type": self.label_list[clsid].lower(), "bbox": [float(t) for t in bbox.tolist()], "score": float(score)}) return bb def xywh2xyxy(x): @@ -400,11 +376,7 @@ class Recognizer: class_keep_boxes = iou_filter(class_boxes, class_scores, 0.2) indices.extend(class_indices[class_keep_boxes]) - return [{ - "type": self.label_list[class_ids[i]].lower(), - "bbox": [float(t) for t in boxes[i].tolist()], - "score": float(scores[i]) - } for i in indices] + return [{"type": self.label_list[class_ids[i]].lower(), "bbox": [float(t) for t in boxes[i].tolist()], "score": float(scores[i])} for i in indices] def close(self): logging.info("Close recognizer.") @@ -429,14 +401,12 @@ class Recognizer: inputs = self.preprocess(batch_image_list) logging.debug("preprocess") for ins in inputs: - bb = self.postprocess(self.ort_sess.run(None, {k:v for k,v in ins.items() if k in self.input_names}, self.run_options)[0], ins, thr) + bb = self.postprocess(self.ort_sess.run(None, {k: v for k, v in ins.items() if k in self.input_names}, self.run_options)[0], ins, thr) res.append(bb) - #seeit.save_results(image_list, res, self.label_list, threshold=thr) + # seeit.save_results(image_list, res, self.label_list, threshold=thr) return res def __del__(self): self.close() - - diff --git a/deepdoc/vision/seeit.py b/deepdoc/vision/seeit.py index e90c09ed80..f0db23239b 100644 --- a/deepdoc/vision/seeit.py +++ b/deepdoc/vision/seeit.py @@ -20,7 +20,7 @@ import PIL from PIL import ImageDraw -def save_results(image_list, results, labels, output_dir='output/', threshold=0.5): +def save_results(image_list, results, labels, output_dir="output/", threshold=0.5): if not os.path.exists(output_dir): os.makedirs(output_dir) for idx, im in enumerate(image_list): @@ -35,23 +35,18 @@ def draw_box(im, result, labels, threshold=0.5): draw_thickness = min(im.size) // 320 draw = ImageDraw.Draw(im) color_list = get_color_map_list(len(labels)) - clsid2color = {n.lower():color_list[i] for i,n in enumerate(labels)} + clsid2color = {n.lower(): color_list[i] for i, n in enumerate(labels)} result = [r for r in result if r["score"] >= threshold] for dt in result: color = tuple(clsid2color[dt["type"]]) xmin, ymin, xmax, ymax = dt["bbox"] - draw.line( - [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), - (xmin, ymin)], - width=draw_thickness, - fill=color) + draw.line([(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)], width=draw_thickness, fill=color) # draw label text = "{} {:.4f}".format(dt["type"], dt["score"]) tw, th = imagedraw_textsize_c(draw, text) - draw.rectangle( - [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color) + draw.rectangle([(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color) draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255)) return im @@ -68,17 +63,17 @@ def get_color_map_list(num_classes): j = 0 lab = i while lab: - color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j)) - color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j)) - color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j)) + color_map[i * 3] |= ((lab >> 0) & 1) << (7 - j) + color_map[i * 3 + 1] |= ((lab >> 1) & 1) << (7 - j) + color_map[i * 3 + 2] |= ((lab >> 2) & 1) << (7 - j) j += 1 lab >>= 3 - color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)] + color_map = [color_map[i : i + 3] for i in range(0, len(color_map), 3)] return color_map def imagedraw_textsize_c(draw, text): - if int(PIL.__version__.split('.')[0]) < 10: + if int(PIL.__version__.split(".")[0]) < 10: tw, th = draw.textsize(text) else: left, top, right, bottom = draw.textbbox((0, 0), text) diff --git a/deepdoc/vision/t_ocr.py b/deepdoc/vision/t_ocr.py index 58ada1b15e..2b31d27f3d 100644 --- a/deepdoc/vision/t_ocr.py +++ b/deepdoc/vision/t_ocr.py @@ -22,13 +22,7 @@ import sys from common.misc_utils import thread_pool_exec -sys.path.insert( - 0, - os.path.abspath( - os.path.join( - os.path.dirname( - os.path.abspath(__file__)), - '../../'))) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../"))) from deepdoc.vision.seeit import draw_box from deepdoc.vision import OCR, init_in_out @@ -36,7 +30,7 @@ import argparse import numpy as np # os.environ['CUDA_VISIBLE_DEVICES'] = '0,2' #2 gpus, uncontinuous -os.environ['CUDA_VISIBLE_DEVICES'] = '0' #1 gpu +os.environ["CUDA_VISIBLE_DEVICES"] = "0" # 1 gpu # os.environ['CUDA_VISIBLE_DEVICES'] = '' #cpu @@ -52,19 +46,15 @@ def main(args): print("Task {} start".format(i)) bxs = ocr(np.array(img), id) bxs = [(line[0], line[1][0]) for line in bxs] - bxs = [{ - "text": t, - "bbox": [b[0][0], b[0][1], b[1][0], b[-1][1]], - "type": "ocr", - "score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]] - img = draw_box(images[i], bxs, ["ocr"], 1.) + bxs = [{"text": t, "bbox": [b[0][0], b[0][1], b[1][0], b[-1][1]], "type": "ocr", "score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]] + img = draw_box(images[i], bxs, ["ocr"], 1.0) img.save(outputs[i], quality=95) - with open(outputs[i] + ".txt", "w+", encoding='utf-8') as f: + with open(outputs[i] + ".txt", "w+", encoding="utf-8") as f: f.write("\n".join([o["text"] for o in bxs])) print("Task {} done".format(i)) - async def __ocr_thread(i, id, img, limiter = None): + async def __ocr_thread(i, id, img, limiter=None): if limiter: async with limiter: print(f"Task {i} use device {id}") @@ -72,7 +62,6 @@ def main(args): else: await thread_pool_exec(__ocr, i, id, img) - async def __ocr_launcher(): tasks = [] for i, img in enumerate(images): @@ -96,10 +85,7 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--inputs', - help="Directory where to store images or PDFs, or a file path to a single image or PDF", - required=True) - parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './ocr_outputs'", - default="./ocr_outputs") + parser.add_argument("--inputs", help="Directory where to store images or PDFs, or a file path to a single image or PDF", required=True) + parser.add_argument("--output_dir", help="Directory where to store the output images. Default: './ocr_outputs'", default="./ocr_outputs") args = parser.parse_args() main(args) diff --git a/deepdoc/vision/t_recognizer.py b/deepdoc/vision/t_recognizer.py index 264014c860..9e147aab29 100644 --- a/deepdoc/vision/t_recognizer.py +++ b/deepdoc/vision/t_recognizer.py @@ -18,13 +18,7 @@ import logging import os import sys -sys.path.insert( - 0, - os.path.abspath( - os.path.join( - os.path.dirname( - os.path.abspath(__file__)), - '../../'))) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../"))) from deepdoc.vision.seeit import draw_box from deepdoc.vision import LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out @@ -44,15 +38,11 @@ def main(args): layouts = detr(images, thr=float(args.threshold)) for i, lyt in enumerate(layouts): if args.mode.lower() == "tsr": - #lyt = [t for t in lyt if t["type"] == "table column"] + # lyt = [t for t in lyt if t["type"] == "table column"] html = get_table_html(images[i], lyt, ocr) - with open(outputs[i] + ".html", "w+", encoding='utf-8') as f: + with open(outputs[i] + ".html", "w+", encoding="utf-8") as f: f.write(html) - lyt = [{ - "type": t["label"], - "bbox": [t["x0"], t["top"], t["x1"], t["bottom"]], - "score": t["score"] - } for t in lyt] + lyt = [{"type": t["label"], "bbox": [t["x0"], t["top"], t["x1"], t["bottom"]], "score": t["score"]} for t in lyt] img = draw_box(images[i], lyt, detr.labels, float(args.threshold)) img.save(outputs[i], quality=95) logging.info("save result to: " + outputs[i]) @@ -61,26 +51,20 @@ def main(args): def get_table_html(img, tb_cpns, ocr): boxes = ocr(np.array(img)) boxes = LayoutRecognizer.sort_Y_firstly( - [{"x0": b[0][0], "x1": b[1][0], - "top": b[0][1], "text": t[0], - "bottom": b[-1][1], - "layout_type": "table", - "page_number": 0} for b, t in boxes if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]], - np.mean([b[-1][1] - b[0][1] for b, _ in boxes]) / 3 + [{"x0": b[0][0], "x1": b[1][0], "top": b[0][1], "text": t[0], "bottom": b[-1][1], "layout_type": "table", "page_number": 0} for b, t in boxes if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]], + np.mean([b[-1][1] - b[0][1] for b, _ in boxes]) / 3, ) def gather(kwd, fzy=10, ption=0.6): nonlocal boxes - eles = LayoutRecognizer.sort_Y_firstly( - [r for r in tb_cpns if re.match(kwd, r["label"])], fzy) + eles = LayoutRecognizer.sort_Y_firstly([r for r in tb_cpns if re.match(kwd, r["label"])], fzy) eles = LayoutRecognizer.layouts_cleanup(boxes, eles, 5, ption) return LayoutRecognizer.sort_Y_firstly(eles, 0) headers = gather(r".*header$") rows = gather(r".* (row|header)") spans = gather(r".*spanning") - clmns = sorted([r for r in tb_cpns if re.match( - r"table column$", r["label"])], key=lambda x: x["x0"]) + clmns = sorted([r for r in tb_cpns if re.match(r"table column$", r["label"])], key=lambda x: x["x0"]) clmns = LayoutRecognizer.layouts_cleanup(boxes, clmns, 5, 0.5) for b in boxes: @@ -171,16 +155,9 @@ def get_table_html(img, tb_cpns, ocr): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--inputs', - help="Directory where to store images or PDFs, or a file path to a single image or PDF", - required=True) - parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './layouts_outputs'", - default="./layouts_outputs") - parser.add_argument( - '--threshold', - help="A threshold to filter out detections. Default: 0.5", - default=0.5) - parser.add_argument('--mode', help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"], - default="layout") + parser.add_argument("--inputs", help="Directory where to store images or PDFs, or a file path to a single image or PDF", required=True) + parser.add_argument("--output_dir", help="Directory where to store the output images. Default: './layouts_outputs'", default="./layouts_outputs") + parser.add_argument("--threshold", help="A threshold to filter out detections. Default: 0.5", default=0.5) + parser.add_argument("--mode", help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"], default="layout") args = parser.parse_args() main(args) diff --git a/example/sdk/chat_assistant_example.py b/example/sdk/chat_assistant_example.py index 6c2e38f534..b282f7fe91 100644 --- a/example/sdk/chat_assistant_example.py +++ b/example/sdk/chat_assistant_example.py @@ -15,7 +15,7 @@ # """ -The example demonstrates how to create a chat assistant, manage sessions, +The example demonstrates how to create a chat assistant, manage sessions, and perform both standard and streaming chat. """ @@ -39,7 +39,7 @@ try: name="Test Assistant", dataset_ids=[dataset.id], llm_id="deepseek-chat", # Example LLM ID, replace with your actual model ID - prompt_config={"system": "You are a helpful assistant."} + prompt_config={"system": "You are a helpful assistant."}, ) print(f"Assistant created: {assistant.name} (ID: {assistant.id})") @@ -52,12 +52,12 @@ try: print("\n--- Standard Chat ---") question = "What is RAGFlow?" print(f"User: {question}") - + # ask returns a generator of Message objects # for stream=False, it yields once with the full answer for message in session.ask(question=question, stream=False): print(f"Assistant: {message.content}") - if hasattr(message, 'reference') and message.reference: + if hasattr(message, "reference") and message.reference: print(f"References used: {len(message.reference)} chunks") # 5. Streaming chat @@ -65,10 +65,10 @@ try: question = "Tell me more about its features." print(f"User: {question}") print("Assistant: ", end="", flush=True) - + for message in session.ask(question=question, stream=True): # In streaming mode, each message.content usually contains the incremental part - # or the full content so far depending on the SDK implementation. + # or the full content so far depending on the SDK implementation. # Based on RAGFlow SDK, it typically yields incremental parts. print(message.content, end="", flush=True) print("\n") diff --git a/example/sdk/chunk_example.py b/example/sdk/chunk_example.py index aed2d9b235..9aac7e745f 100644 --- a/example/sdk/chunk_example.py +++ b/example/sdk/chunk_example.py @@ -38,21 +38,21 @@ try: print("Uploading document...") # Using a simple text content for example content = "RAGFlow is an open-source RAG (Retrieval-Augmented Generation) engine based on deep document understanding." - docs = dataset.upload_documents([{"display_name": "sample.txt", "blob": content.encode('utf-8')}]) + docs = dataset.upload_documents([{"display_name": "sample.txt", "blob": content.encode("utf-8")}]) doc = docs[0] # 3. Parse the document (required before manual chunk operations if you want it to be processed) print("Parsing document...") dataset.async_parse_documents([doc.id]) - + # Wait for parsing to complete with timeout MAX_WAIT = 120 # seconds elapsed = 0 while elapsed < MAX_WAIT: doc_status = dataset.list_documents(id=doc.id)[0] if doc_status.run == "1" and doc_status.progress >= 1.0: - print("Parsing completed.") - break + print("Parsing completed.") + break print(f"Parsing progress: {doc_status.progress:.2f}") time.sleep(2) elapsed += 2 @@ -75,7 +75,7 @@ try: # 6. Update a chunk print("Updating chunk...") chunk.update({"content": "RAGFlow features a streamlined and powerful RAG workflow."}) - + # 7. Delete the chunk print("Deleting chunk...") doc.delete_chunks([chunk.id]) @@ -83,7 +83,7 @@ try: # Cleanup print("Cleaning up dataset...") rag.delete_datasets(ids=[dataset.id]) - + print("Chunk example done.") sys.exit(0) diff --git a/example/sdk/dataset_example.py b/example/sdk/dataset_example.py index a3931f1432..c15e959f06 100644 --- a/example/sdk/dataset_example.py +++ b/example/sdk/dataset_example.py @@ -32,7 +32,7 @@ try: dataset_instance = ragflow_instance.create_dataset(name="dataset_instance") # update the dataset instance - updated_message = {"name":"updated_dataset"} + updated_message = {"name": "updated_dataset"} updated_dataset = dataset_instance.update(updated_message) # get the dataset (list datasets) @@ -49,5 +49,3 @@ try: except Exception as e: print(str(e)) sys.exit(-1) - - diff --git a/example/sdk/retrieval_example.py b/example/sdk/retrieval_example.py index 70afa776c4..f48885ef2b 100644 --- a/example/sdk/retrieval_example.py +++ b/example/sdk/retrieval_example.py @@ -37,9 +37,9 @@ try: # 2. Upload and parse a document to have content for retrieval print("Uploading and parsing document...") content = "RAGFlow is an open-source RAG engine based on deep document understanding. It features a streamlined RAG workflow for businesses of any size." - docs = dataset.upload_documents([{"display_name": "ragflow_info.txt", "blob": content.encode('utf-8')}]) + docs = dataset.upload_documents([{"display_name": "ragflow_info.txt", "blob": content.encode("utf-8")}]) doc = docs[0] - + # Wait for parsing to complete with timeout print("Parsing document...") dataset.async_parse_documents([doc.id]) @@ -48,7 +48,7 @@ try: while elapsed < MAX_WAIT: doc_status = dataset.list_documents(id=doc.id)[0] if doc_status.run == "1" and doc_status.progress >= 1.0: - break + break print(f"Parsing progress: {doc_status.progress:.2f}") time.sleep(2) elapsed += 2 @@ -61,18 +61,13 @@ try: print("\n--- Performing Retrieval ---") question = "What is RAGFlow?" print(f"Question: {question}") - + # Retrieve relevant chunks from one or more datasets - chunks = rag.retrieve( - dataset_ids=[dataset.id], - question=question, - top_k=5, - similarity_threshold=0.1 - ) + chunks = rag.retrieve(dataset_ids=[dataset.id], question=question, top_k=5, similarity_threshold=0.1) print(f"Found {len(chunks)} relevant chunks:") for i, chunk in enumerate(chunks): - print(f"\nChunk {i+1}:") + print(f"\nChunk {i + 1}:") print(f"Content: {chunk.content[:200]}...") print(f"Similarity Score: {chunk.similarity:.4f}") print(f"Source Document: {chunk.document_name}") @@ -83,10 +78,10 @@ try: dataset_ids=[dataset.id], question="workflow for businesses", top_k=3, - keyword=True # Enable keyword search in addition to semantic search + keyword=True, # Enable keyword search in addition to semantic search ) for i, chunk in enumerate(chunks): - print(f"Chunk {i+1}: {chunk.content[:100]}... (Score: {chunk.similarity:.4f})") + print(f"Chunk {i + 1}: {chunk.content[:100]}... (Score: {chunk.similarity:.4f})") # Cleanup print("\nCleaning up...") diff --git a/internal/agent/canvas/runner.go b/internal/agent/canvas/runner.go index 961a2b8759..0640d00b91 100644 --- a/internal/agent/canvas/runner.go +++ b/internal/agent/canvas/runner.go @@ -335,13 +335,13 @@ func (r *Runner) Run( } } } - push(out, RunEvent{Type: "waiting_for_user", Data: safeEventJSON(waiting), MessageID: messageID, CreatedAt: nowUnix(), TaskID: taskID, SessionID: sessionID}) - // Always close a RunAgent call with the `done` - // terminator so the front-end can rely on a - // channel-end sentinel regardless of whether the run - // completed, errored, or paused for user input. - push(out, RunEvent{Type: "done", Data: "", MessageID: messageID, CreatedAt: nowUnix(), TaskID: taskID, SessionID: sessionID}) - return + push(out, RunEvent{Type: "waiting_for_user", Data: safeEventJSON(waiting), MessageID: messageID, CreatedAt: nowUnix(), TaskID: taskID, SessionID: sessionID}) + // Always close a RunAgent call with the `done` + // terminator so the front-end can rely on a + // channel-end sentinel regardless of whether the run + // completed, errored, or paused for user input. + push(out, RunEvent{Type: "done", Data: "", MessageID: messageID, CreatedAt: nowUnix(), TaskID: taskID, SessionID: sessionID}) + return } if IsInterruptError(runErr) { // Raw InterruptSignal (no wrapped InterruptCtx list diff --git a/internal/agent/tool/bgpt.go b/internal/agent/tool/bgpt.go index 106bd7cb8d..8dc69276c9 100644 --- a/internal/agent/tool/bgpt.go +++ b/internal/agent/tool/bgpt.go @@ -41,29 +41,29 @@ var bgptEndpoint = "https://bgpt.pro/api/mcp-search" // bgptParams is the JSON shape the model sends into InvokableRun. type bgptParams struct { - Query string `json:"query"` + Query string `json:"query"` MaxResults int `json:"num_results"` - APIKey string `json:"api_key,omitempty"` - DaysBack int `json:"days_back,omitempty"` + APIKey string `json:"api_key,omitempty"` + DaysBack int `json:"days_back,omitempty"` } // bgptResult is one paper in the result list. type bgptResult struct { - Title string `json:"title"` - Authors string `json:"authors"` - Journal string `json:"journal"` - Year string `json:"year"` - DOI string `json:"doi"` - URL string `json:"url"` - Abstract string `json:"abstract"` - Methods string `json:"methods"` - SampleSize string `json:"sample_size"` - Results string `json:"results"` - Limitations string `json:"limitations"` - ConflictOfInterest string `json:"conflict_of_interest"` - DataAvailability string `json:"data_availability"` - BlindSpots string `json:"blind_spots"` - Falsify string `json:"falsify"` + Title string `json:"title"` + Authors string `json:"authors"` + Journal string `json:"journal"` + Year string `json:"year"` + DOI string `json:"doi"` + URL string `json:"url"` + Abstract string `json:"abstract"` + Methods string `json:"methods"` + SampleSize string `json:"sample_size"` + Results string `json:"results"` + Limitations string `json:"limitations"` + ConflictOfInterest string `json:"conflict_of_interest"` + DataAvailability string `json:"data_availability"` + BlindSpots string `json:"blind_spots"` + Falsify string `json:"falsify"` } // bgptEnv is what the model sees. @@ -142,7 +142,7 @@ func (b *BGPTTool) InvokableRun(ctx context.Context, argsJSON string, _ ...tool. } reqBody := map[string]interface{}{ - "query": strings.TrimSpace(p.Query), + "query": strings.TrimSpace(p.Query), "num_results": p.MaxResults, } if p.APIKey != "" { @@ -196,21 +196,21 @@ func (b *BGPTTool) InvokableRun(ctx context.Context, argsJSON string, _ ...tool. results := make([]bgptResult, 0, len(raw.Results)) for _, r := range raw.Results { results = append(results, bgptResult{ - Title: strVal(r["title"]), - Authors: strVal(r["authors"]), - Journal: strVal(r["journal"]), - Year: strVal(r["year"]), - DOI: strVal(r["doi"]), - URL: strVal(r["url"]), - Abstract: strVal(r["abstract"]), - Methods: firstStr(r, "methods_and_experimental_techniques", "methods"), - SampleSize: firstStr(r, "sample_size_and_population_characteristics", "sample_size_and_population"), - Results: firstStr(r, "results_and_conclusions", "results"), - Limitations: firstStr(r, "paper_limitations_and_biases", "limitations"), - ConflictOfInterest: firstStr(r, "conflict_of_interest_statements", "conflict_of_interest"), - DataAvailability: firstStr(r, "data_availability_statements", "data_availability"), - BlindSpots: strVal(r["study_blindspots"]), - Falsify: strVal(r["how_to_falsify"]), + Title: strVal(r["title"]), + Authors: strVal(r["authors"]), + Journal: strVal(r["journal"]), + Year: strVal(r["year"]), + DOI: strVal(r["doi"]), + URL: strVal(r["url"]), + Abstract: strVal(r["abstract"]), + Methods: firstStr(r, "methods_and_experimental_techniques", "methods"), + SampleSize: firstStr(r, "sample_size_and_population_characteristics", "sample_size_and_population"), + Results: firstStr(r, "results_and_conclusions", "results"), + Limitations: firstStr(r, "paper_limitations_and_biases", "limitations"), + ConflictOfInterest: firstStr(r, "conflict_of_interest_statements", "conflict_of_interest"), + DataAvailability: firstStr(r, "data_availability_statements", "data_availability"), + BlindSpots: strVal(r["study_blindspots"]), + Falsify: strVal(r["how_to_falsify"]), }) } @@ -241,5 +241,3 @@ func firstStr(m map[string]interface{}, keys ...string) string { } return "" } - - diff --git a/internal/cli/response.go b/internal/cli/response.go index 36a143186d..17eb7edc69 100644 --- a/internal/cli/response.go +++ b/internal/cli/response.go @@ -1293,11 +1293,11 @@ func (r *QuotaSummaryResponse) PrintOut() { // // {"code":0,"data":{"answer":"...","reference":...},"message":""} type ChatCompletionsResponse struct { - Code int `json:"code"` + Code int `json:"code"` Data *chatCompletionData `json:"data"` - Message string `json:"message"` - Duration float64 `json:"-"` - OutputFormat OutputFormat `json:"-"` + Message string `json:"message"` + Duration float64 `json:"-"` + OutputFormat OutputFormat `json:"-"` // raw HTTP body for "raw" output. raw []byte // streamed skips the "Answer:" line in PrintOut to avoid duplication @@ -1313,8 +1313,8 @@ type chatCompletionData struct { ChatID string `json:"chat_id,omitempty"` } -func (r *ChatCompletionsResponse) Type() string { return "chat_completions" } -func (r *ChatCompletionsResponse) TimeCost() float64 { return r.Duration } +func (r *ChatCompletionsResponse) Type() string { return "chat_completions" } +func (r *ChatCompletionsResponse) TimeCost() float64 { return r.Duration } func (r *ChatCompletionsResponse) SetOutputFormat(f OutputFormat) { r.OutputFormat = f } func (r *ChatCompletionsResponse) PrintOut() { diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go index 8a9791d2ff..9f310957bf 100644 --- a/internal/cli/user_parser.go +++ b/internal/cli/user_parser.go @@ -3769,7 +3769,6 @@ optionsLoop: // CHAT COMPLETIONS // [chat_id ] [session ] [llm ] - // parseChatCompletionsBody parses the question and options of a CHAT COMPLETIONS // command. The leading keyword(s) must already have been consumed by the caller. func (p *Parser) parseChatCompletionsBody() (*Command, error) { diff --git a/internal/common/time.go b/internal/common/time.go index ed2127315b..f872d48f73 100644 --- a/internal/common/time.go +++ b/internal/common/time.go @@ -40,10 +40,10 @@ func ParseISO8601(dateString string) (time.Time, error) { } layouts := []string{ - time.RFC3339Nano, // "2006-01-02T15:04:05.999999999Z07:00" - time.RFC3339, // "2006-01-02T15:04:05Z07:00" - "2006-01-02T15:04:05", // no timezone - "2006-01-02", // date only + time.RFC3339Nano, // "2006-01-02T15:04:05.999999999Z07:00" + time.RFC3339, // "2006-01-02T15:04:05Z07:00" + "2006-01-02T15:04:05", // no timezone + "2006-01-02", // date only } for _, layout := range layouts { var t time.Time diff --git a/internal/deepdoc/parser/pdf/parser_test.go b/internal/deepdoc/parser/pdf/parser_test.go index 6deb03f6c5..db37502918 100644 --- a/internal/deepdoc/parser/pdf/parser_test.go +++ b/internal/deepdoc/parser/pdf/parser_test.go @@ -3,10 +3,10 @@ package pdf import ( "context" "image" + "math" "strings" "sync" "testing" - "math" lyt "ragflow/internal/deepdoc/parser/pdf/layout" tbl "ragflow/internal/deepdoc/parser/pdf/table" diff --git a/internal/handler/chunk_test.go b/internal/handler/chunk_test.go index 1b6decbe15..77ec780ecb 100644 --- a/internal/handler/chunk_test.go +++ b/internal/handler/chunk_test.go @@ -182,9 +182,9 @@ func TestChunkHandlerListChunksMapsAvailableFalse(t *testing.T) { t.Fatalf("available_int = %v, want 0", req.AvailableInt) } return &service.ListChunksResponse{ - Total: 0, + Total: 0, Chunks: []map[string]interface{}{}, - Doc: map[string]interface{}{"id": "doc-1"}, + Doc: map[string]interface{}{"id": "doc-1"}, }, nil } diff --git a/memory/services/messages.py b/memory/services/messages.py index 748321385c..8863f5a2f5 100644 --- a/memory/services/messages.py +++ b/memory/services/messages.py @@ -21,16 +21,17 @@ from common import settings from common.constants import MemoryType from common.doc_store.doc_store_base import OrderByExpr, MatchExpr + def _es_index_prefix() -> str: return os.environ.get("ES_INDEX_PREFIX", "").strip() + def index_name(uid: str): prefix = _es_index_prefix() return f"memory_{prefix}_{uid}" if prefix else f"memory_{uid}" class MessageService: - @classmethod def has_index(cls, uid: str, memory_id: str): index = index_name(uid) @@ -49,10 +50,7 @@ class MessageService: @classmethod def insert_message(cls, messages: List[dict], uid: str, memory_id: str): index = index_name(uid) - [m.update({ - "id": f'{memory_id}_{m["message_id"]}', - "status": 1 if m["status"] else 0 - }) for m in messages] + [m.update({"id": f"{memory_id}_{m['message_id']}", "status": 1 if m["status"] else 0}) for m in messages] return settings.msgStoreConn.insert(messages, index, memory_id) @classmethod @@ -68,32 +66,31 @@ class MessageService: return settings.msgStoreConn.delete(condition, index, memory_id) @classmethod - def list_message(cls, uid: str, memory_id: str, agent_ids: List[str]=None, keywords: str=None, page: int=1, page_size: int=50): + def list_message(cls, uid: str, memory_id: str, agent_ids: List[str] = None, keywords: str = None, page: int = 1, page_size: int = 50): index = index_name(uid) filter_dict = {} if agent_ids: filter_dict["agent_id"] = agent_ids if keywords: filter_dict["session_id"] = keywords - select_fields = [ - "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", - "invalid_at", "forget_at", "status" - ] + select_fields = ["message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", "invalid_at", "forget_at", "status"] order_by = OrderByExpr() order_by.desc("valid_at") res, total_count = settings.msgStoreConn.search( select_fields=select_fields, highlight_fields=[], condition={**filter_dict, "message_type": MemoryType.RAW.name.lower()}, - match_expressions=[], order_by=order_by, - offset=(page-1)*page_size, limit=page_size, - index_names=index, memory_ids=[memory_id], agg_fields=[], hide_forgotten=False + match_expressions=[], + order_by=order_by, + offset=(page - 1) * page_size, + limit=page_size, + index_names=index, + memory_ids=[memory_id], + agg_fields=[], + hide_forgotten=False, ) if not total_count: - return { - "message_list": [], - "total_count": 0 - } + return {"message_list": [], "total_count": 0} raw_msg_mapping = settings.msgStoreConn.get_fields(res, select_fields) raw_messages = list(raw_msg_mapping.values()) @@ -102,9 +99,14 @@ class MessageService: select_fields=select_fields, highlight_fields=[], condition=extract_filter, - match_expressions=[], order_by=order_by, - offset=0, limit=512, - index_names=index, memory_ids=[memory_id], agg_fields=[], hide_forgotten=False + match_expressions=[], + order_by=order_by, + offset=0, + limit=512, + index_names=index, + memory_ids=[memory_id], + agg_fields=[], + hide_forgotten=False, ) extract_msg = settings.msgStoreConn.get_fields(extract_res, select_fields) grouped_extract_msg = {} @@ -117,42 +119,36 @@ class MessageService: for raw_msg in raw_messages: raw_msg["extract"] = grouped_extract_msg.get(raw_msg["message_id"], []) - return { - "message_list": raw_messages, - "total_count": total_count - } + return {"message_list": raw_messages, "total_count": total_count} @classmethod def get_recent_messages(cls, uid_list: List[str], memory_ids: List[str], agent_id: str, session_id: str, limit: int): index_names = [index_name(uid) for uid in uid_list] - condition_dict = { - "agent_id": agent_id, - "session_id": session_id - } + condition_dict = {"agent_id": agent_id, "session_id": session_id} order_by = OrderByExpr() order_by.desc("valid_at") res, total_count = settings.msgStoreConn.search( - select_fields=[ - "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", - "invalid_at", "forget_at", "status", "content" - ], + select_fields=["message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", "invalid_at", "forget_at", "status", "content"], highlight_fields=[], condition=condition_dict, - match_expressions=[], order_by=order_by, - offset=0, limit=limit, - index_names=index_names, memory_ids=memory_ids, agg_fields=[] + match_expressions=[], + order_by=order_by, + offset=0, + limit=limit, + index_names=index_names, + memory_ids=memory_ids, + agg_fields=[], ) if not total_count: return [] - doc_mapping = settings.msgStoreConn.get_fields(res, [ - "message_id", "message_type", "source_id", "memory_id","user_id", "agent_id", "session_id", - "valid_at", "invalid_at", "forget_at", "status", "content" - ]) + doc_mapping = settings.msgStoreConn.get_fields( + res, ["message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", "invalid_at", "forget_at", "status", "content"] + ) return list(doc_mapping.values()) @classmethod - def search_message(cls, memory_ids: List[str], condition_dict: dict, uid_list: List[str], match_expressions:list[MatchExpr], top_n: int): + def search_message(cls, memory_ids: List[str], condition_dict: dict, uid_list: List[str], match_expressions: list[MatchExpr], top_n: int): index_names = [index_name(uid) for uid in uid_list] # filter only valid messages by default if "status" not in condition_dict: @@ -161,35 +157,29 @@ class MessageService: order_by = OrderByExpr() order_by.desc("valid_at") res, total_count = settings.msgStoreConn.search( - select_fields=[ - "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", - "valid_at", - "invalid_at", "forget_at", "status", "content" - ], + select_fields=["message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", "invalid_at", "forget_at", "status", "content"], highlight_fields=[], condition=condition_dict, match_expressions=match_expressions, order_by=order_by, - offset=0, limit=top_n, - index_names=index_names, memory_ids=memory_ids, agg_fields=[] + offset=0, + limit=top_n, + index_names=index_names, + memory_ids=memory_ids, + agg_fields=[], ) if not total_count: return [] - docs = settings.msgStoreConn.get_fields(res, [ - "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", - "invalid_at", "forget_at", "status", "content" - ]) + docs = settings.msgStoreConn.get_fields( + res, ["message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", "invalid_at", "forget_at", "status", "content"] + ) return list(docs.values()) @staticmethod def calculate_message_size(message: dict): content_embed = message.get("content_embed") - embed_size = ( - sys.getsizeof(content_embed[0]) * len(content_embed) - if content_embed is not None and len(content_embed) > 0 - else 0 - ) + embed_size = sys.getsizeof(content_embed[0]) * len(content_embed) if content_embed is not None and len(content_embed) > 0 else 0 return sys.getsizeof(message.get("content", "")) + embed_size @classmethod @@ -204,8 +194,12 @@ class MessageService: condition={}, match_expressions=[], order_by=order_by, - offset=0, limit=2048*len(memory_ids), - index_names=index_names, memory_ids=memory_ids, agg_fields=[], hide_forgotten=False + offset=0, + limit=2048 * len(memory_ids), + index_names=index_names, + memory_ids=memory_ids, + agg_fields=[], + hide_forgotten=False, ) if count == 0: @@ -246,8 +240,11 @@ class MessageService: condition={}, match_expressions=[], order_by=order_by, - offset=0, limit=512, - index_names=[_index_name], memory_ids=[memory_id], agg_fields=[] + offset=0, + limit=512, + index_names=[_index_name], + memory_ids=[memory_id], + agg_fields=[], ) docs = settings.msgStoreConn.get_fields(res, select_fields) for doc in docs.values(): @@ -262,12 +259,7 @@ class MessageService: def get_missing_field_messages(cls, memory_id: str, uid: str, field_name: str): select_fields = ["message_id", "content"] _index_name = index_name(uid) - res = settings.msgStoreConn.get_missing_field_message( - select_fields=select_fields, - index_name=_index_name, - memory_id=memory_id, - field_name=field_name - ) + res = settings.msgStoreConn.get_missing_field_message(select_fields=select_fields, index_name=_index_name, memory_id=memory_id, field_name=field_name) if not res: return [] docs = settings.msgStoreConn.get_fields(res, select_fields) @@ -276,7 +268,7 @@ class MessageService: @classmethod def get_by_message_id(cls, memory_id: str, message_id: int, uid: str): index = index_name(uid) - doc_id = f'{memory_id}_{message_id}' + doc_id = f"{memory_id}_{message_id}" return settings.msgStoreConn.get(doc_id, index, [memory_id]) @classmethod @@ -290,9 +282,12 @@ class MessageService: condition={}, match_expressions=[], order_by=order_by, - offset=0, limit=1, - index_names=index_names, memory_ids=memory_ids, - agg_fields=[], hide_forgotten=False + offset=0, + limit=1, + index_names=index_names, + memory_ids=memory_ids, + agg_fields=[], + hide_forgotten=False, ) if not total_count: return 1 diff --git a/memory/services/query.py b/memory/services/query.py index e2bce608b9..041fa31d49 100644 --- a/memory/services/query.py +++ b/memory/services/query.py @@ -23,6 +23,7 @@ from common.float_utils import get_float from rag.nlp import rag_tokenizer, term_weight, synonym from rag.utils.redis_conn import REDIS_CONN + def get_vector(txt, emb_mdl, topk=10, similarity=0.1): if isinstance(similarity, str) and len(similarity) > 0: try: @@ -33,23 +34,19 @@ def get_vector(txt, emb_mdl, topk=10, similarity=0.1): qv, _ = emb_mdl.encode_queries(txt) shape = np.array(qv).shape if len(shape) > 1: - raise Exception( - f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).") + raise Exception(f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).") embedding_data = [get_float(v) for v in qv] vector_column_name = f"q_{len(embedding_data)}_vec" - return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity}) + return MatchDenseExpr(vector_column_name, embedding_data, "float", "cosine", topk, {"similarity": similarity}) class MsgTextQuery(QueryBase): - def __init__(self): self.tw = term_weight.Dealer() self.syn = synonym.Dealer(redis=REDIS_CONN.REDIS if REDIS_CONN.is_alive() else None) - self.query_fields = [ - "content" - ] + self.query_fields = ["content"] - def question(self, txt, tbl="messages", min_match: float=0.6): + def question(self, txt, tbl="messages", min_match: float = 0.6): original_query = txt txt = MsgTextQuery.add_space_between_eng_zh(txt) txt = re.sub( @@ -76,11 +73,10 @@ class MsgTextQuery(QueryBase): # (e.g. WordNet returns "cat-o'-nine-tails" for "cat") syn = re.sub(r"'", "", rag_tokenizer.tokenize(" ".join(syn))).split() keywords.extend(syn) - syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()] + syn = ['"{}"^{:.4f}'.format(s, w / 4.0) for s in syn if s.strip()] syns.append(" ".join(syn)) - q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if - tk and not re.match(r"[.^+\(\)-]", tk)] + q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if tk and not re.match(r"[.^+\(\)-]", tk)] for i in range(1, len(tks_w)): left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip() if not left or not right: @@ -96,9 +92,7 @@ class MsgTextQuery(QueryBase): if not q: q.append(txt) query = " ".join(q) - return MatchTextExpr( - self.query_fields, query, 100, {"original_query": original_query} - ), keywords + return MatchTextExpr(self.query_fields, query, 100, {"original_query": original_query}), keywords def need_fine_grained_tokenize(tk): if len(tk) < 3: @@ -120,11 +114,7 @@ class MsgTextQuery(QueryBase): logging.debug(json.dumps(twts, ensure_ascii=False)) tms = [] for tk, w in sorted(twts, key=lambda x: x[1] * -1): - sm = ( - rag_tokenizer.fine_grained_tokenize(tk).split() - if need_fine_grained_tokenize(tk) - else [] - ) + sm = rag_tokenizer.fine_grained_tokenize(tk).split() if need_fine_grained_tokenize(tk) else [] sm = [ re.sub( r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+", @@ -145,7 +135,7 @@ class MsgTextQuery(QueryBase): if len(keywords) < 32: keywords.extend([s for s in tk_syns if s]) tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] - tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns] + tk_syns = [f'"{s}"' if s.find(" ") > 0 else s for s in tk_syns] if len(keywords) >= 32: break @@ -165,13 +155,7 @@ class MsgTextQuery(QueryBase): if len(twts) > 1: tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt) - syns = " OR ".join( - [ - '"%s"' - % rag_tokenizer.tokenize(self.sub_special_char(s)) - for s in syns - ] - ) + syns = " OR ".join(['"%s"' % rag_tokenizer.tokenize(self.sub_special_char(s)) for s in syns]) if syns and tms: tms = f"({tms})^5 OR ({syns})^0.7" @@ -181,7 +165,5 @@ class MsgTextQuery(QueryBase): query = " OR ".join([f"({t})" for t in qs if t]) if not query: query = otxt - return MatchTextExpr( - self.query_fields, query, 100, {"minimum_should_match": min_match, "original_query": original_query} - ), keywords - return None, keywords \ No newline at end of file + return MatchTextExpr(self.query_fields, query, 100, {"minimum_should_match": min_match, "original_query": original_query}), keywords + return None, keywords diff --git a/memory/utils/es_conn.py b/memory/utils/es_conn.py index 60eda59f62..f49e39705f 100644 --- a/memory/utils/es_conn.py +++ b/memory/utils/es_conn.py @@ -34,7 +34,6 @@ ATTEMPT_TIME = 2 @singleton class ESConnection(ESConnectionBase): - @staticmethod def convert_field_name(field_name: str, use_tokenized_content=False) -> str: match field_name: @@ -111,18 +110,19 @@ class ESConnection(ESConnectionBase): """ def search( - self, select_fields: list[str], - highlight_fields: list[str], - condition: dict, - match_expressions: list[MatchExpr], - order_by: OrderByExpr, - offset: int, - limit: int, - index_names: str | list[str], - memory_ids: list[str], - agg_fields: list[str] | None = None, - rank_feature: dict | None = None, - hide_forgotten: bool = True + self, + select_fields: list[str], + highlight_fields: list[str], + condition: dict, + match_expressions: list[MatchExpr], + order_by: OrderByExpr, + offset: int, + limit: int, + index_names: str | list[str], + memory_ids: list[str], + agg_fields: list[str] | None = None, + rank_feature: dict | None = None, + hide_forgotten: bool = True, ): """ Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html @@ -154,15 +154,17 @@ class ESConnection(ESConnectionBase): elif isinstance(v, str) or isinstance(v, int): bool_query.filter.append(Q("term", **{field_name: v})) else: - raise Exception( - f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") + raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") s = Search() vector_similarity_weight = 0.5 for m in match_expressions: if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params: - assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1], - MatchDenseExpr) and isinstance( - match_expressions[2], FusionExpr) + assert ( + len(match_expressions) == 3 + and isinstance(match_expressions[0], MatchTextExpr) + and isinstance(match_expressions[1], MatchDenseExpr) + and isinstance(match_expressions[2], FusionExpr) + ) weights = m.fusion_params["weights"] vector_similarity_weight = get_float(weights.split(",")[1]) for m in match_expressions: @@ -170,24 +172,31 @@ class ESConnection(ESConnectionBase): minimum_should_match = m.extra_options.get("minimum_should_match", 0.0) if isinstance(minimum_should_match, float): minimum_should_match = str(int(minimum_should_match * 100)) + "%" - bool_query.must.append(Q("query_string", fields=[self.convert_field_name(f, use_tokenized_content=True) for f in m.fields], - type="best_fields", query=m.matching_text, - minimum_should_match=minimum_should_match, - boost=1)) + bool_query.must.append( + Q( + "query_string", + fields=[self.convert_field_name(f, use_tokenized_content=True) for f in m.fields], + type="best_fields", + query=m.matching_text, + minimum_should_match=minimum_should_match, + boost=1, + ) + ) bool_query.boost = 1.0 - vector_similarity_weight elif isinstance(m, MatchDenseExpr): - assert (bool_query is not None) + assert bool_query is not None similarity = 0.0 if "similarity" in m.extra_options: similarity = m.extra_options["similarity"] - s = s.knn(self.convert_field_name(m.vector_column_name), - m.topn, - m.topn * 2, - query_vector=list(m.embedding_data), - filter=bool_query.to_dict(), - similarity=similarity, - ) + s = s.knn( + self.convert_field_name(m.vector_column_name), + m.topn, + m.topn * 2, + query_vector=list(m.embedding_data), + filter=bool_query.to_dict(), + similarity=similarity, + ) if bool_query and rank_feature: for fld, sc in rank_feature.items(): @@ -207,7 +216,7 @@ class ESConnection(ESConnectionBase): if field.endswith("_int") or field.endswith("_flt"): order_info = {"order": order, "unmapped_type": "float"} elif field == "id": - continue # id as "text", not a "keyword", order by it will cause error + continue # id as "text", not a "keyword", order by it will cause error else: order_info = {"order": order, "unmapped_type": "keyword"} orders.append({field: order_info}) @@ -215,22 +224,24 @@ class ESConnection(ESConnectionBase): if agg_fields: for fld in agg_fields: - s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000) + s.aggs.bucket(f"aggs_{fld}", "terms", field=fld, size=1000000) if limit > 0: - s = s[offset:offset + limit] + s = s[offset : offset + limit] q = s.to_dict() self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q)) for i in range(ATTEMPT_TIME): try: - #print(json.dumps(q, ensure_ascii=False)) - res = self.es.search(index=exist_index_list, - body=q, - timeout="600s", - # search_type="dfs_query_then_fetch", - track_total_hits=True, - _source=True) + # print(json.dumps(q, ensure_ascii=False)) + res = self.es.search( + index=exist_index_list, + body=q, + timeout="600s", + # search_type="dfs_query_then_fetch", + track_total_hits=True, + _source=True, + ) if str(res.get("timed_out", "")).lower() == "true": raise Exception("Es Timeout.") self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res)) @@ -249,7 +260,7 @@ class ESConnection(ESConnectionBase): self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!") raise Exception("ESConnection.search timeout.") - def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int=512): + def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int = 512): bool_query = Q("bool", must=[]) bool_query.must.append(Q("exists", field="forget_at")) bool_query.filter.append(Q("term", memory_id=memory_id)) @@ -292,7 +303,7 @@ class ESConnection(ESConnectionBase): self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!") raise Exception("ESConnection.search timeout.") - def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str, limit: int=512): + def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str, limit: int = 512): if not self.index_exist(index_name): return None bool_query = Q("bool", must=[]) @@ -340,8 +351,11 @@ class ESConnection(ESConnectionBase): def get(self, doc_id: str, index_name: str, memory_ids: list[str]) -> dict | None: for i in range(ATTEMPT_TIME): try: - res = self.es.get(index=index_name, - id=doc_id, source=True, ) + res = self.es.get( + index=index_name, + id=doc_id, + source=True, + ) if str(res.get("timed_out", "")).lower() == "true": raise Exception("Es Timeout.") message = res["_source"] @@ -365,15 +379,13 @@ class ESConnection(ESConnectionBase): d_copy = self.map_message_to_es_fields(d_copy_raw) d_copy["memory_id"] = memory_id meta_id = d_copy.pop("id", "") - operations.append( - {"index": {"_index": index_name, "_id": meta_id}}) + operations.append({"index": {"_index": index_name, "_id": meta_id}}) operations.append(d_copy) res = [] for _ in range(ATTEMPT_TIME): try: res = [] - r = self.es.bulk(index=index_name, operations=operations, - refresh=False, timeout="60s") + r = self.es.bulk(index=index_name, operations=operations, refresh=False, timeout="60s") if re.search(r"False", str(r["errors"]), re.IGNORECASE): return res @@ -409,15 +421,14 @@ class ESConnection(ESConnectionBase): if "feas" != k.split("_")[-1]: continue try: - self.es.update(index=index_name, id=message_id, script=f"ctx._source.remove(\"{k}\");") + self.es.update(index=index_name, id=message_id, script=f'ctx._source.remove("{k}");') except Exception: self.logger.exception(f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception") try: self.es.update(index=index_name, id=message_id, doc=update_dict) return True except Exception as e: - self.logger.exception( - f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e)) + self.logger.exception(f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e)) break return False @@ -434,8 +445,7 @@ class ESConnection(ESConnectionBase): elif isinstance(v, str) or isinstance(v, int): bool_query.filter.append(Q("term", **{k: v})) else: - raise Exception( - f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") + raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") scripts = [] params = {} for k, v in update_dict.items(): @@ -465,11 +475,8 @@ class ESConnection(ESConnectionBase): scripts.append(f"ctx._source.{k}=params.pp_{k};") params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False) else: - raise Exception( - f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.") - ubq = UpdateByQuery( - index=index_name).using( - self.es).query(bool_query) + raise Exception(f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.") + ubq = UpdateByQuery(index=index_name).using(self.es).query(bool_query) ubq = ubq.script(source="".join(scripts), params=params) ubq = ubq.params(refresh=True) ubq = ubq.params(slices=5) @@ -521,10 +528,7 @@ class ESConnection(ESConnectionBase): self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict())) for _ in range(ATTEMPT_TIME): try: - res = self.es.delete_by_query( - index=index_name, - body=Search().query(qry).to_dict(), - refresh=True) + res = self.es.delete_by_query(index=index_name, body=Search().query(qry).to_dict(), refresh=True) return res["deleted"] except ConnectionTimeout: self.logger.exception("ES request timeout") diff --git a/memory/utils/infinity_conn.py b/memory/utils/infinity_conn.py index ae350c0c8e..feeabe591f 100644 --- a/memory/utils/infinity_conn.py +++ b/memory/utils/infinity_conn.py @@ -42,7 +42,7 @@ class InfinityConnection(InfinityConnectionBase): return False @staticmethod - def convert_message_field_to_infinity(field_name: str, table_fields: list[str]=None): + def convert_message_field_to_infinity(field_name: str, table_fields: list[str] = None): match field_name: case "message_type": return "message_type_kwd" @@ -68,7 +68,7 @@ class InfinityConnection(InfinityConnectionBase): return "content_embed" return field_name - def convert_select_fields(self, output_fields: list[str], table_fields: list[str]=None) -> list[str]: + def convert_select_fields(self, output_fields: list[str], table_fields: list[str] = None) -> list[str]: return list({self.convert_message_field_to_infinity(f, table_fields) for f in output_fields}) @staticmethod @@ -277,7 +277,7 @@ class InfinityConnection(InfinityConnectionBase): self.logger.debug(f"INFINITY search final result: {str(res)}") return res, total_hits_count - def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int=512): + def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int = 512): condition = {"memory_id": memory_id, "exists": "forget_at_flt"} order_by = OrderByExpr() order_by.asc("forget_at_flt") @@ -309,7 +309,7 @@ class InfinityConnection(InfinityConnectionBase): self.connPool.release_conn(inf_conn) return res - def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str, limit: int=512): + def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str, limit: int = 512): condition = {"memory_id": memory_id, "must_not": {"exists": field_name}} order_by = OrderByExpr() order_by.asc("valid_at_flt") diff --git a/memory/utils/msg_util.py b/memory/utils/msg_util.py index 71a9f6b767..0f4e90e8ca 100644 --- a/memory/utils/msg_util.py +++ b/memory/utils/msg_util.py @@ -27,9 +27,9 @@ def get_json_result_from_llm_response(response_str: str) -> dict: """ try: clean_str = response_str.strip() - if clean_str.startswith('```json'): + if clean_str.startswith("```json"): clean_str = clean_str[7:] # Remove the starting ```json - if clean_str.endswith('```'): + if clean_str.endswith("```"): clean_str = clean_str[:-3] # Remove the ending ``` return json.loads(clean_str.strip()) diff --git a/memory/utils/ob_conn.py b/memory/utils/ob_conn.py index f179992373..45f78c7802 100644 --- a/memory/utils/ob_conn.py +++ b/memory/utils/ob_conn.py @@ -75,7 +75,7 @@ class SearchResult(BaseModel): @singleton class OBConnection(OBConnectionBase): def __init__(self): - super().__init__(logger_name='ragflow.memory_ob_conn') + super().__init__(logger_name="ragflow.memory_ob_conn") self._fulltext_search_columns = FTS_COLUMNS """ @@ -101,10 +101,10 @@ class OBConnection(OBConnectionBase): def _get_vector_column_name_from_table(self, table_name: str) -> Optional[str]: """Get the vector column name from the table (q_{size}_vec pattern).""" sql = f""" - SELECT COLUMN_NAME - FROM INFORMATION_SCHEMA.COLUMNS - WHERE TABLE_SCHEMA = '{self.db_name}' - AND TABLE_NAME = '{table_name}' + SELECT COLUMN_NAME + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = '{self.db_name}' + AND TABLE_NAME = '{table_name}' AND COLUMN_NAME REGEXP '^q_[0-9]+_vec$' LIMIT 1 """ @@ -204,7 +204,7 @@ class OBConnection(OBConnectionBase): memory_ids: list[str], agg_fields: list[str] | None = None, rank_feature: dict | None = None, - hide_forgotten: bool = True + hide_forgotten: bool = True, ): """Search messages in memory storage.""" if isinstance(index_names, str): @@ -273,9 +273,7 @@ class OBConnection(OBConnectionBase): fulltext_query = escape_string(fulltext_query.strip()) fulltext_topn = m.topn - fulltext_search_expr, fulltext_search_weight = self._parse_fulltext_columns( - fulltext_query, self._fulltext_search_columns - ) + fulltext_search_expr, fulltext_search_weight = self._parse_fulltext_columns(fulltext_query, self._fulltext_search_columns) elif isinstance(m, MatchDenseExpr): vector_column_name = m.vector_column_name vector_data = m.embedding_data @@ -339,39 +337,27 @@ class OBConnection(OBConnectionBase): ) self.logger.debug("OBConnection.search with fusion sql: %s", fusion_sql) rows, elapsed_time = self._execute_search_sql(fusion_sql) - self.logger.info( - f"OBConnection.search table {table_name}, search type: fusion, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}" - ) + self.logger.info(f"OBConnection.search table {table_name}, search type: fusion, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}") for row in rows: result.messages.append(self._row_to_entity(row, db_output_fields + ["_score"])) result.total += 1 elif search_type == "vector": - vector_sql = self._build_vector_search_sql( - table_name, fields_expr, vector_search_score_expr, filters_expr, - vector_search_filter, vector_search_expr, limit, vector_topn, offset - ) + vector_sql = self._build_vector_search_sql(table_name, fields_expr, vector_search_score_expr, filters_expr, vector_search_filter, vector_search_expr, limit, vector_topn, offset) self.logger.debug("OBConnection.search with vector sql: %s", vector_sql) rows, elapsed_time = self._execute_search_sql(vector_sql) - self.logger.info( - f"OBConnection.search table {table_name}, search type: vector, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}" - ) + self.logger.info(f"OBConnection.search table {table_name}, search type: vector, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}") for row in rows: result.messages.append(self._row_to_entity(row, db_output_fields + ["_score"])) result.total += 1 elif search_type == "fulltext": - fulltext_sql = self._build_fulltext_search_sql( - table_name, fields_expr, fulltext_search_score_expr, filters_expr, - fulltext_search_filter, offset, limit, fulltext_topn - ) + fulltext_sql = self._build_fulltext_search_sql(table_name, fields_expr, fulltext_search_score_expr, filters_expr, fulltext_search_filter, offset, limit, fulltext_topn) self.logger.debug("OBConnection.search with fulltext sql: %s", fulltext_sql) rows, elapsed_time = self._execute_search_sql(fulltext_sql) - self.logger.info( - f"OBConnection.search table {table_name}, search type: fulltext, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}" - ) + self.logger.info(f"OBConnection.search table {table_name}, search type: fulltext, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}") for row in rows: result.messages.append(self._row_to_entity(row, db_output_fields + ["_score"])) @@ -387,14 +373,10 @@ class OBConnection(OBConnectionBase): order_by_expr = ("ORDER BY " + ", ".join(orders)) if orders else "" limit_expr = f"LIMIT {offset}, {limit}" if limit != 0 else "" - filter_sql = self._build_filter_search_sql( - table_name, fields_expr, filters_expr, order_by_expr, limit_expr - ) + filter_sql = self._build_filter_search_sql(table_name, fields_expr, filters_expr, order_by_expr, limit_expr) self.logger.debug("OBConnection.search with filter sql: %s", filter_sql) rows, elapsed_time = self._execute_search_sql(filter_sql) - self.logger.info( - f"OBConnection.search table {table_name}, search type: filter, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}" - ) + self.logger.info(f"OBConnection.search table {table_name}, search type: filter, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}") for row in rows: result.messages.append(self._row_to_entity(row, db_output_fields)) @@ -413,13 +395,7 @@ class OBConnection(OBConnectionBase): db_output_fields = [self.convert_field_name(f) for f in select_fields] fields_expr = ", ".join(db_output_fields) - sql = ( - f"SELECT {fields_expr}" - f" FROM {index_name}" - f" WHERE memory_id = {get_value_str(memory_id)} AND forget_at IS NOT NULL" - f" ORDER BY forget_at ASC" - f" LIMIT {limit}" - ) + sql = f"SELECT {fields_expr} FROM {index_name} WHERE memory_id = {get_value_str(memory_id)} AND forget_at IS NOT NULL ORDER BY forget_at ASC LIMIT {limit}" self.logger.debug("OBConnection.get_forgotten_messages sql: %s", sql) res = self.client.perform_raw_text_sql(sql) @@ -431,8 +407,7 @@ class OBConnection(OBConnectionBase): return result - def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str, - limit: int = 512): + def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str, limit: int = 512): """Get messages missing a specific field.""" if not self._check_table_exists_cached(index_name): return None @@ -441,13 +416,7 @@ class OBConnection(OBConnectionBase): db_output_fields = [self.convert_field_name(f) for f in select_fields] fields_expr = ", ".join(db_output_fields) - sql = ( - f"SELECT {fields_expr}" - f" FROM {index_name}" - f" WHERE memory_id = {get_value_str(memory_id)} AND {db_field_name} IS NULL" - f" ORDER BY valid_at ASC" - f" LIMIT {limit}" - ) + sql = f"SELECT {fields_expr} FROM {index_name} WHERE memory_id = {get_value_str(memory_id)} AND {db_field_name} IS NULL ORDER BY valid_at ASC LIMIT {limit}" self.logger.debug("OBConnection.get_missing_field_message sql: %s", sql) res = self.client.perform_raw_text_sql(sql) @@ -531,11 +500,7 @@ class OBConnection(OBConnectionBase): if not set_values: return True - update_sql = ( - f"UPDATE {index_name}" - f" SET {', '.join(set_values)}" - f" WHERE {' AND '.join(filters)}" - ) + update_sql = f"UPDATE {index_name} SET {', '.join(set_values)} WHERE {' AND '.join(filters)}" self.logger.debug("OBConnection.update sql: %s", update_sql) try: @@ -557,14 +522,14 @@ class OBConnection(OBConnectionBase): def get_total(self, res) -> int: if isinstance(res, tuple): return res[1] - if hasattr(res, 'total'): + if hasattr(res, "total"): return res.total return 0 def get_doc_ids(self, res) -> list[str]: if isinstance(res, tuple): res = res[0] - if hasattr(res, 'messages'): + if hasattr(res, "messages"): return [row.get("id") for row in res.messages if row.get("id")] return [] @@ -577,7 +542,7 @@ class OBConnection(OBConnectionBase): if not fields: return {} - messages = res.messages if hasattr(res, 'messages') else [] + messages = res.messages if hasattr(res, "messages") else [] for doc in messages: message = self.get_message_from_ob_doc(doc) @@ -588,10 +553,7 @@ class OBConnection(OBConnectionBase): if isinstance(v, list): m[n] = v continue - if n in ["message_id", "source_id", "valid_at", "invalid_at", "forget_at", "status"] and isinstance(v, - (int, - float, - bool)): + if n in ["message_id", "source_id", "valid_at", "invalid_at", "forget_at", "status"] and isinstance(v, (int, float, bool)): m[n] = v continue if not isinstance(v, str): @@ -610,9 +572,7 @@ class OBConnection(OBConnectionBase): if isinstance(res, tuple): res = res[0] messages = getattr(res, "messages", None) - return get_highlight_from_messages( - messages, keywords, field_name, is_english_fn=lambda s: is_english([s]) - ) + return get_highlight_from_messages(messages, keywords, field_name, is_english_fn=lambda s: is_english([s])) def get_aggregation(self, res, field_name: str): """Get aggregation for search results.""" diff --git a/memory/utils/prompt_util.py b/memory/utils/prompt_util.py index e46e1be6ab..422f45cd48 100644 --- a/memory/utils/prompt_util.py +++ b/memory/utils/prompt_util.py @@ -18,8 +18,8 @@ from typing import Optional, List from common.constants import MemoryType from common.time_utils import current_timestamp -class PromptAssembler: +class PromptAssembler: SYSTEM_BASE_TEMPLATE = """**Memory Extraction Specialist** You are an expert at analyzing conversations to extract structured memory. @@ -47,7 +47,6 @@ You are an expert at analyzing conversations to extract structured memory. - invalid_at: When it becomes false (e.g., repeal, disproven) or empty if still true - Default: valid_at = conversation time, invalid_at = "" for timeless facts """, - MemoryType.EPISODIC.name.lower(): """ **EXTRACT EPISODIC KNOWLEDGE:** - Specific experiences, events, personal stories @@ -59,7 +58,6 @@ You are an expert at analyzing conversations to extract structured memory. - invalid_at: Event end time or empty if instantaneous - Extract explicit times: "at 3 PM", "last Monday", "from X to Y" """, - MemoryType.PROCEDURAL.name.lower(): """ **EXTRACT PROCEDURAL KNOWLEDGE:** - Processes, methods, step-by-step instructions @@ -71,7 +69,7 @@ You are an expert at analyzing conversations to extract structured memory. - invalid_at: When it expires/becomes obsolete or empty if current - For version-specific: use release dates - For best practices: invalid_at = "" - """ + """, } OUTPUT_TEMPLATES = { @@ -84,7 +82,6 @@ You are an expert at analyzing conversations to extract structured memory. } ] """, - MemoryType.EPISODIC.name.lower(): """ "episodic": [ { @@ -94,7 +91,6 @@ You are an expert at analyzing conversations to extract structured memory. } ] """, - MemoryType.PROCEDURAL.name.lower(): """ "procedural": [ { @@ -103,7 +99,7 @@ You are an expert at analyzing conversations to extract structured memory. "invalid_at": "procedure expiration timestamp or empty" } ] - """ + """, } BASE_USER_PROMPT = """ @@ -111,7 +107,7 @@ You are an expert at analyzing conversations to extract structured memory. {conversation} **CONVERSATION TIME:** {conversation_time} -**CURRENT TIME:** {current_time} +**CURRENT TIME:** {current_time} """ @classmethod @@ -123,9 +119,7 @@ You are an expert at analyzing conversations to extract structured memory. output_format = cls._generate_output_format(types_to_extract) full_prompt = cls.SYSTEM_BASE_TEMPLATE.format( - type_specific_instructions=type_instructions, - timestamp_format=config.get("timestamp_format", "ISO 8601"), - max_items=config.get("max_items_per_type", 5) + type_specific_instructions=type_instructions, timestamp_format=config.get("timestamp_format", "ISO 8601"), max_items=config.get("max_items_per_type", 5) ) full_prompt += f"\n**REQUIRED OUTPUT FORMAT (JSON):**\n```json\n{{\n{output_format}\n}}\n```\n" @@ -140,7 +134,7 @@ You are an expert at analyzing conversations to extract structured memory. def _get_types_to_extract(requested_types: List[str]) -> List[str]: types = set() for rt in requested_types: - if rt in [e.name.lower() for e in MemoryType] and rt != MemoryType.RAW.name.lower(): + if rt in [e.name.lower() for e in MemoryType] and rt != MemoryType.RAW.name.lower(): types.add(rt) return list(types) @@ -184,12 +178,7 @@ You are an expert at analyzing conversations to extract structured memory. return "\n".join(examples) @classmethod - def assemble_user_prompt( - cls, - conversation: str, - conversation_time: Optional[str] = None, - current_time: Optional[str] = None - ) -> str: + def assemble_user_prompt(cls, conversation: str, conversation_time: Optional[str] = None, current_time: Optional[str] = None) -> str: return cls.BASE_USER_PROMPT.format( conversation=conversation, conversation_time=conversation_time or "Not specified", diff --git a/rag/advanced_rag/__init__.py b/rag/advanced_rag/__init__.py index bde0ff643d..c294e8898d 100644 --- a/rag/advanced_rag/__init__.py +++ b/rag/advanced_rag/__init__.py @@ -17,4 +17,4 @@ from .tree_structured_query_decomposition_retrieval import TreeStructuredQueryDecompositionRetrieval as DeepResearcher -__all__ = ['DeepResearcher'] \ No newline at end of file +__all__ = ["DeepResearcher"] diff --git a/rag/advanced_rag/knowlege_compile/raptor.py b/rag/advanced_rag/knowlege_compile/raptor.py index b3b853642e..e57905bda5 100644 --- a/rag/advanced_rag/knowlege_compile/raptor.py +++ b/rag/advanced_rag/knowlege_compile/raptor.py @@ -862,7 +862,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: random_state=random_state, task_id=task_id, ) - + # Loop-termination guarantee. The outer ``while end - start > 1`` # relies on each layer strictly shrinking the input count. If # the clusterer degenerates and returns one cluster per input, @@ -878,10 +878,9 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: # that ``layers`` is monotonically shrinking. if n_clusters >= len(embeddings): logging.warning( - "RAPTOR clustering did not reduce input count " - "(%d inputs → %d clusters); collapsing this layer " - "into a single summary to prevent a non-terminating loop", - len(embeddings), n_clusters, + "RAPTOR clustering did not reduce input count (%d inputs → %d clusters); collapsing this layer into a single summary to prevent a non-terminating loop", + len(embeddings), + n_clusters, ) n_clusters = 1 lbls = [0] * len(embeddings) diff --git a/rag/advanced_rag/tree_structured_query_decomposition_retrieval.py b/rag/advanced_rag/tree_structured_query_decomposition_retrieval.py index 38d9f9808b..c814bdee51 100644 --- a/rag/advanced_rag/tree_structured_query_decomposition_retrieval.py +++ b/rag/advanced_rag/tree_structured_query_decomposition_retrieval.py @@ -24,13 +24,14 @@ from timeit import default_timer as timer class TreeStructuredQueryDecompositionRetrieval: - def __init__(self, - chat_mdl: LLMBundle, - prompt_config: dict, - kb_retrieve: partial = None, - kg_retrieve: partial = None, - internet_enabled: bool = False, - ): + def __init__( + self, + chat_mdl: LLMBundle, + prompt_config: dict, + kb_retrieve: partial = None, + kg_retrieve: partial = None, + internet_enabled: bool = False, + ): self.chat_mdl = chat_mdl self.prompt_config = prompt_config self._kb_retrieve = kb_retrieve @@ -103,7 +104,7 @@ class TreeStructuredQueryDecompositionRetrieval: async def _research(self, chunk_info, question, query, depth=3, callback=None): if depth == 0: - #if callback: + # if callback: # await callback("Reach the max search depth.") return "" if callback: @@ -111,9 +112,9 @@ class TreeStructuredQueryDecompositionRetrieval: st = timer() ret = await self._retrieve_information(query) if callback: - await callback("Retrieval %d results in %.1fms"%(len(ret["chunks"]), (timer()-st)*1000)) + await callback("Retrieval %d results in %.1fms" % (len(ret["chunks"]), (timer() - st) * 1000)) await self._async_update_chunk_info(chunk_info, ret) - ret = kb_prompt(ret, self.chat_mdl.max_length*0.5) + ret = kb_prompt(ret, self.chat_mdl.max_length * 0.5) if callback: await callback("Checking the sufficiency for retrieved information.") @@ -123,13 +124,13 @@ class TreeStructuredQueryDecompositionRetrieval: await callback(f"Yes, the retrieved information is sufficient for '{question}'.") return ret - #if callback: + # if callback: # await callback("The retrieved information is not sufficient. Planing next steps...") succ_question_info = await multi_queries_gen(self.chat_mdl, question, query, suff.get("missing_information", []), ret) if callback: await callback("Next step is to search for the following questions:
- " + "
- ".join(step["question"] for step in succ_question_info["questions"])) steps = [] for step in succ_question_info["questions"]: - steps.append(asyncio.create_task(self._research(chunk_info, step["question"], step["query"], depth-1, callback))) + steps.append(asyncio.create_task(self._research(chunk_info, step["question"], step["query"], depth - 1, callback))) results = await asyncio.gather(*steps, return_exceptions=True) return "\n".join([str(r) for r in results]) diff --git a/rag/app/__init__.py b/rag/app/__init__.py index e156bc93dd..177b91dd05 100644 --- a/rag/app/__init__.py +++ b/rag/app/__init__.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# \ No newline at end of file +# diff --git a/rag/app/audio.py b/rag/app/audio.py index 2741c91a90..38294eb997 100644 --- a/rag/app/audio.py +++ b/rag/app/audio.py @@ -35,8 +35,7 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs): if not ext: raise RuntimeError("No extension detected.") - if ext not in [".da", ".wave", ".wav", ".mp3", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma", - ".realaudio", ".vqf", ".oggvorbis", ".ape"]: + if ext not in [".da", ".wave", ".wav", ".mp3", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma", ".realaudio", ".vqf", ".oggvorbis", ".ape"]: raise RuntimeError(f"Extension {ext} is not supported yet.") tmp_path = "" diff --git a/rag/app/book.py b/rag/app/book.py index 8611f38401..c19c0d257d 100644 --- a/rag/app/book.py +++ b/rag/app/book.py @@ -87,11 +87,7 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, lang= tbls = vision_figure_parser_docx_wrapper(sections=sections, tbls=tbls, callback=callback, **kwargs) # tbls = [((None, lns), None) for lns in tbls] - sections = [ - (item[0], item[1] if item[1] is not None else "") - for item in sections - if not isinstance(item[1], (Image.Image, LazyImage)) - ] + sections = [(item[0], item[1] if item[1] is not None else "") for item in sections if not isinstance(item[1], (Image.Image, LazyImage))] callback(0.8, "Finish parsing.") elif re.search(r"\.pdf$", filename, re.IGNORECASE): diff --git a/rag/app/email.py b/rag/app/email.py index 9edaddcb79..0e9f4b0419 100644 --- a/rag/app/email.py +++ b/rag/app/email.py @@ -27,13 +27,13 @@ import io def chunk( - filename, - binary=None, - from_page=0, - to_page=MAXIMUM_PAGE_NUMBER, - lang="Chinese", - callback=None, - **kwargs, + filename, + binary=None, + from_page=0, + to_page=MAXIMUM_PAGE_NUMBER, + lang="Chinese", + callback=None, + **kwargs, ): """ Only eml is supported @@ -93,10 +93,7 @@ def chunk( _add_content(msg, msg.get_content_type()) - sections = TxtParser.parser_txt("\n".join(text_txt)) + [ - (line, "") for line in - HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=parser_config["chunk_token_num"]) if line - ] + sections = TxtParser.parser_txt("\n".join(text_txt)) + [(line, "") for line in HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=parser_config["chunk_token_num"]) if line] st = timer() chunks = naive_merge( @@ -116,9 +113,7 @@ def chunk( filename = part.get_filename() payload = part.get_payload(decode=True) try: - attachment_res.extend( - naive_chunk(filename, payload, callback=callback, **kwargs) - ) + attachment_res.extend(naive_chunk(filename, payload, callback=callback, **kwargs)) except Exception: pass @@ -128,9 +123,7 @@ def chunk( if __name__ == "__main__": import sys - def dummy(prog=None, msg=""): pass - chunk(sys.argv[1], callback=dummy) diff --git a/rag/app/manual.py b/rag/app/manual.py index c2e17aeb20..371755e3c6 100644 --- a/rag/app/manual.py +++ b/rag/app/manual.py @@ -183,7 +183,7 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, lang= txt, layoutno, poss = section if isinstance(poss, str): - poss = (getattr(pdf_parser, "extract_positions", lambda _: [])(poss) or [[0, 0, 0, 0, 0]]) + poss = getattr(pdf_parser, "extract_positions", lambda _: [])(poss) or [[0, 0, 0, 0, 0]] if poss: first = poss[0] # tuple: ([pn], x1, x2, y1, y2) pn = first[0] @@ -268,10 +268,7 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, lang= if table_ctx or image_ctx: attach_media_context(res, table_ctx, image_ctx) if res and pdf_parser and getattr(pdf_parser, "outlines", None): - res[0]["__outline__"] = [ - {"title": title, "depth": depth} - for title, depth, *_ in pdf_parser.outlines - ] + res[0]["__outline__"] = [{"title": title, "depth": depth} for title, depth, *_ in pdf_parser.outlines] return res elif re.search(r"\.docx?$", filename, re.IGNORECASE): diff --git a/rag/app/paper.py b/rag/app/paper.py index f578a5fc7a..cc3a898a6d 100644 --- a/rag/app/paper.py +++ b/rag/app/paper.py @@ -20,8 +20,7 @@ import re from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper from common.constants import ParserType, MAXIMUM_PAGE_NUMBER -from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, \ - tokenize_chunks, attach_media_context +from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks, attach_media_context from deepdoc.parser import PdfParser import numpy as np from rag.app.naive import by_plaintext, PARSERS @@ -33,18 +32,12 @@ class Pdf(PdfParser): self.model_species = ParserType.PAPER.value super().__init__() - def __call__(self, filename, binary=None, from_page=0, - to_page=MAXIMUM_PAGE_NUMBER, zoomin=3, callback=None): + def __call__(self, filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, zoomin=3, callback=None): from timeit import default_timer as timer + start = timer() callback(msg="OCR started") - self.__images__( - filename if not binary else binary, - zoomin, - from_page, - to_page, - callback - ) + self.__images__(filename if not binary else binary, zoomin, from_page, to_page, callback) callback(msg="OCR finished ({:.2f}s)".format(timer() - start)) start = timer() @@ -66,25 +59,21 @@ class Pdf(PdfParser): # clean mess if column_width < self.page_images[0].size[0] / zoomin / 2: - logging.debug("two_column................... {} {}".format(column_width, - self.page_images[0].size[0] / zoomin / 2)) + logging.debug("two_column................... {} {}".format(column_width, self.page_images[0].size[0] / zoomin / 2)) self.boxes = self.sort_X_by_page(self.boxes, column_width / 2) for b in self.boxes: b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip()) def _begin(txt): - return re.match( - "[0-9. 一、i]*(introduction|abstract|摘要|引言|keywords|key words|关键词|background|背景|目录|前言|contents)", - txt.lower().strip()) + return re.match("[0-9. 一、i]*(introduction|abstract|摘要|引言|keywords|key words|关键词|background|背景|目录|前言|contents)", txt.lower().strip()) if from_page > 0: return { "title": "", "authors": "", "abstract": "", - "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes if - re.match(r"(text|title)", b.get("layoutno", "text"))], - "tables": tbls + "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes if re.match(r"(text|title)", b.get("layoutno", "text"))], + "tables": tbls, } # get title and authors title = "" @@ -128,10 +117,7 @@ class Pdf(PdfParser): if not abstr: i = 0 - callback( - 0.8, "Page {}~{}: Text merging finished".format( - from_page, min( - to_page, self.total_page))) + callback(0.8, "Page {}~{}: Text merging finished".format(from_page, min(to_page, self.total_page))) for b in self.boxes: logging.debug("{} {}".format(b["text"], b.get("layoutno"))) logging.debug("{}".format(tbls)) @@ -140,25 +126,19 @@ class Pdf(PdfParser): "title": title, "authors": " ".join(authors), "abstract": abstr, - "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if - re.match(r"(text|title)", b.get("layoutno", "text"))], - "tables": tbls + "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if re.match(r"(text|title)", b.get("layoutno", "text"))], + "tables": tbls, } -def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, - lang="Chinese", callback=None, **kwargs): +def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, lang="Chinese", callback=None, **kwargs): """ - Only pdf is supported. - The abstract of the paper will be sliced as an entire chunk, and will not be sliced partly. + Only pdf is supported. + The abstract of the paper will be sliced as an entire chunk, and will not be sliced partly. """ - parser_config = kwargs.get( - "parser_config", { - "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"}) + parser_config = kwargs.get("parser_config", {"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"}) if re.search(r"\.pdf$", filename, re.IGNORECASE): - layout_recognizer, parser_model_name = normalize_layout_recognizer( - parser_config.get("layout_recognize", "DeepDOC") - ) + layout_recognizer, parser_model_name = normalize_layout_recognizer(parser_config.get("layout_recognize", "DeepDOC")) if isinstance(layout_recognizer, bool): layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text" @@ -169,8 +149,7 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, if name == "deepdoc": pdf_parser = Pdf() - paper = pdf_parser(filename if not binary else binary, - from_page=from_page, to_page=to_page, callback=callback) + paper = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback) sections = paper.get("sections", []) else: kwargs.pop("parse_method", None) @@ -186,16 +165,10 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, layout_recognizer=layout_recognizer, mineru_llm_name=parser_model_name, parse_method="paper", - **kwargs + **kwargs, ) - paper = { - "title": filename, - "authors": " ", - "abstract": "", - "sections": sections, - "tables": tables - } + paper = {"title": filename, "authors": " ", "abstract": "", "sections": sections, "tables": tables} tbls = paper["tables"] tbls = vision_figure_parser_pdf_wrapper( @@ -208,8 +181,7 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, else: raise NotImplementedError("file type not supported yet(pdf supported)") - doc = {"docnm_kwd": filename, "authors_tks": rag_tokenizer.tokenize(paper["authors"]), - "title_tks": rag_tokenizer.tokenize(paper["title"] if paper["title"] else filename)} + doc = {"docnm_kwd": filename, "authors_tks": rag_tokenizer.tokenize(paper["authors"]), "title_tks": rag_tokenizer.tokenize(paper["title"] if paper["title"] else filename)} doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"]) doc["authors_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["authors_tks"]) # is it English @@ -223,8 +195,7 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, txt = pdf_parser.remove_tag(paper["abstract"]) d["important_kwd"] = ["abstract", "总结", "概括", "summary", "summarize"] d["important_tks"] = " ".join(d["important_kwd"]) - d["image"], poss = pdf_parser.crop( - paper["abstract"], need_position=True) + d["image"], poss = pdf_parser.crop(paper["abstract"], need_position=True) add_positions(d, poss) tokenize(d, txt, eng) res.append(d) @@ -257,7 +228,7 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, image_ctx = max(0, int(parser_config.get("image_context_size", 0) or 0)) if table_ctx or image_ctx: attach_media_context(res, table_ctx, image_ctx) - + return res @@ -342,9 +313,7 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, if __name__ == "__main__": import sys - def dummy(prog=None, msg=""): pass - chunk(sys.argv[1], callback=dummy) diff --git a/rag/app/picture.py b/rag/app/picture.py index 22aa5fd11c..9b1c27a4e6 100644 --- a/rag/app/picture.py +++ b/rag/app/picture.py @@ -58,8 +58,7 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs): cv_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.IMAGE2TEXT) cv_mdl = LLMBundle(tenant_id, model_config=cv_model_config, lang=lang) video_prompt = str(parser_config.get("video_prompt", "") or "") - ans = asyncio.run( - cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename, video_prompt=video_prompt)) + ans = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename, video_prompt=video_prompt)) callback(0.8, "CV LLM respond: %s ..." % ans[:32]) ans += "\n" + ans tokenize(doc, ans, eng) diff --git a/rag/app/presentation.py b/rag/app/presentation.py index e49d1bd2d8..c5b5025fb6 100644 --- a/rag/app/presentation.py +++ b/rag/app/presentation.py @@ -154,7 +154,7 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, lang= logging.warning(f"python-pptx parsing failed for {filename}: {e}, trying tika as fallback") if callback: callback(0.1, "python-pptx failed, trying tika as fallback") - + try: from tika import parser as tika_parser except Exception as tika_error: @@ -163,18 +163,18 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, lang= callback(0.8, error_msg) logging.warning(f"{error_msg} for {filename}.") raise NotImplementedError(error_msg) - + if binary: binary_data = binary else: - with open(filename, 'rb') as f: + with open(filename, "rb") as f: binary_data = f.read() doc_parsed = tika_parser.from_buffer(BytesIO(binary_data)) - + if doc_parsed.get("content", None) is not None: sections = doc_parsed["content"].split("\n") sections = [s for s in sections if s.strip()] - + for pn, txt in enumerate(sections): d = copy.deepcopy(doc) pn += from_page @@ -184,7 +184,7 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, lang= d["position_int"] = [(pn + 1, 0, 0, 0, 0)] tokenize(d, txt, eng) res.append(d) - + if callback: callback(0.8, "Finish parsing with tika.") return res diff --git a/rag/app/qa.py b/rag/app/qa.py index 7a55f32d3d..293bd53280 100644 --- a/rag/app/qa.py +++ b/rag/app/qa.py @@ -63,31 +63,18 @@ class Excel(ExcelParser): else: fails.append(str(i + 1)) if len(res) % 999 == 0: - callback(len(res) * - 0.6 / - total, ("Extract pairs: {}".format(len(res)) + - (f"{len(fails)} failure, line: %s..." % - (",".join(fails[:3])) if fails else ""))) + callback(len(res) * 0.6 / total, ("Extract pairs: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) - callback(0.6, ("Extract pairs: {}. ".format(len(res)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) - self.is_english = is_english( - [rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1]) + callback(0.6, ("Extract pairs: {}. ".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + self.is_english = is_english([rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1]) return res class Pdf(PdfParser): - def __call__(self, filename, binary=None, from_page=0, - to_page=MAXIMUM_PAGE_NUMBER, zoomin=3, callback=None): + def __call__(self, filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, zoomin=3, callback=None): start = timer() callback(msg="OCR started") - self.__images__( - filename if not binary else binary, - zoomin, - from_page, - to_page, - callback - ) + self.__images__(filename if not binary else binary, zoomin, from_page, to_page, callback) callback(msg="OCR finished ({:.2f}s)".format(timer() - start)) logging.debug("OCR({}~{}): {:.2f}s".format(from_page, to_page, timer() - start)) start = timer() @@ -112,9 +99,9 @@ class Pdf(PdfParser): if q_bull == -1: raise ValueError("Unable to recognize Q&A structure.") qai_list = [] - last_q, last_a, last_tag = '', '', '' + last_q, last_a, last_tag = "", "", "" last_index = -1 - last_box = {'text': ''} + last_box = {"text": ""} last_bull = None def sort_key(element): @@ -125,13 +112,13 @@ class Pdf(PdfParser): tbls.sort(key=sort_key) tbl_index = 0 last_pn, last_bottom = 0, 0 - tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = 1, 0, 0, 0, 0, '@@0\t0\t0\t0\t0##', '' + tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = 1, 0, 0, 0, 0, "@@0\t0\t0\t0\t0##", "" for box in self.boxes: - section, line_tag = box['text'], self._line_tag(box, zoomin) + section, line_tag = box["text"], self._line_tag(box, zoomin) has_bull, index = has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list) last_box, last_index, last_bull = box, index, has_bull - line_pn = get_float(line_tag.lstrip('@@').split('\t')[0]) - line_top = get_float(line_tag.rstrip('##').split('\t')[3]) + line_pn = get_float(line_tag.lstrip("@@").split("\t")[0]) + line_top = get_float(line_tag.rstrip("##").split("\t")[3]) tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index) if not has_bull: # No question bullet if not last_q: @@ -141,34 +128,32 @@ class Pdf(PdfParser): else: sum_tag = line_tag sum_section = section - while ((tbl_pn == last_pn and tbl_top >= last_bottom) or (tbl_pn > last_pn)) \ - and ((tbl_pn == line_pn and tbl_top <= line_top) or ( - tbl_pn < line_pn)): # add image at the middle of current answer - sum_tag = f'{tbl_tag}{sum_tag}' - sum_section = f'{tbl_text}{sum_section}' + while ((tbl_pn == last_pn and tbl_top >= last_bottom) or (tbl_pn > last_pn)) and ( + (tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn) + ): # add image at the middle of current answer + sum_tag = f"{tbl_tag}{sum_tag}" + sum_section = f"{tbl_text}{sum_section}" tbl_index += 1 - tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, - tbl_index) - last_a = f'{last_a}{sum_section}' - last_tag = f'{last_tag}{sum_tag}' + tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index) + last_a = f"{last_a}{sum_section}" + last_tag = f"{last_tag}{sum_tag}" else: if last_q: - while ((tbl_pn == last_pn and tbl_top >= last_bottom) or (tbl_pn > last_pn)) \ - and ((tbl_pn == line_pn and tbl_top <= line_top) or ( - tbl_pn < line_pn)): # add image at the end of last answer - last_tag = f'{last_tag}{tbl_tag}' - last_a = f'{last_a}{tbl_text}' + while ((tbl_pn == last_pn and tbl_top >= last_bottom) or (tbl_pn > last_pn)) and ( + (tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn) + ): # add image at the end of last answer + last_tag = f"{last_tag}{tbl_tag}" + last_a = f"{last_a}{tbl_text}" tbl_index += 1 - tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, - tbl_index) + tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index) image, poss = self.crop(last_tag, need_position=True) qai_list.append((last_q, last_a, image, poss)) - last_q, last_a, last_tag = '', '', '' + last_q, last_a, last_tag = "", "", "" last_q = has_bull.group() _, end = has_bull.span() last_a = section[end:] last_tag = line_tag - last_bottom = float(line_tag.rstrip('##').split('\t')[4]) + last_bottom = float(line_tag.rstrip("##").split("\t")[4]) last_pn = line_pn if last_q: qai_list.append((last_q, last_a, *self.crop(last_tag, need_position=True))) @@ -176,15 +161,14 @@ class Pdf(PdfParser): def get_tbls_info(self, tbls, tbl_index): if tbl_index >= len(tbls): - return 1, 0, 0, 0, 0, '@@0\t0\t0\t0\t0##', '' + return 1, 0, 0, 0, 0, "@@0\t0\t0\t0\t0##", "" tbl_pn = tbls[tbl_index][1][0][0] + 1 tbl_left = tbls[tbl_index][1][0][1] tbl_right = tbls[tbl_index][1][0][2] tbl_top = tbls[tbl_index][1][0][3] tbl_bottom = tbls[tbl_index][1][0][4] - tbl_tag = "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \ - .format(tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom) - _tbl_text = ''.join(tbls[tbl_index][0][1]) + tbl_tag = "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##".format(tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom) + _tbl_text = "".join(tbls[tbl_index][0][1]) return tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, _tbl_text @@ -193,8 +177,7 @@ class Docx(DocxParser): pass def __call__(self, filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, callback=None): - self.doc = Document( - filename) if not binary else Document(BytesIO(binary)) + self.doc = Document(filename) if not binary else Document(BytesIO(binary)) pn = 0 last_answer, last_image = "", None question_stack, level_stack = [], [] @@ -202,19 +185,19 @@ class Docx(DocxParser): for p in self.doc.paragraphs: if pn > to_page: break - question_level, p_text = 0, '' + question_level, p_text = 0, "" if from_page <= pn < to_page and p.text.strip(): question_level, p_text = docx_question_level(p) if not question_level or question_level > 6: # not a question - last_answer = f'{last_answer}\n{p_text}' + last_answer = f"{last_answer}\n{p_text}" current_image = self.get_picture(self.doc, p) last_image = concat_img(last_image, current_image) else: # is a question if last_answer or last_image: - sum_question = '\n'.join(question_stack) + sum_question = "\n".join(question_stack) if sum_question: qai_list.append((sum_question, last_answer, last_image)) - last_answer, last_image = '', None + last_answer, last_image = "", None i = question_level while question_stack and i <= level_stack[-1]: @@ -223,13 +206,13 @@ class Docx(DocxParser): question_stack.append(p_text) level_stack.append(question_level) for run in p.runs: - if 'lastRenderedPageBreak' in run._element.xml: + if "lastRenderedPageBreak" in run._element.xml: pn += 1 continue - if 'w:br' in run._element.xml and 'type="page"' in run._element.xml: + if "w:br" in run._element.xml and 'type="page"' in run._element.xml: pn += 1 if last_answer: - sum_question = '\n'.join(question_stack) + sum_question = "\n".join(question_stack) if sum_question: qai_list.append((sum_question, last_answer, last_image)) @@ -255,15 +238,13 @@ class Docx(DocxParser): def rmPrefix(txt): - return re.sub( - r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t:: ]+", "", txt.strip(), flags=re.IGNORECASE) + return re.sub(r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t:: ]+", "", txt.strip(), flags=re.IGNORECASE) def beAdocPdf(d, q, a, eng, image, poss): qprefix = "Question: " if eng else "问题:" aprefix = "Answer: " if eng else "回答:" - d["content_with_weight"] = "\t".join( - [qprefix + rmPrefix(q), aprefix + rmPrefix(a)]) + d["content_with_weight"] = "\t".join([qprefix + rmPrefix(q), aprefix + rmPrefix(a)]) d["content_ltks"] = rag_tokenizer.tokenize(q) d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) if image: @@ -276,8 +257,7 @@ def beAdocPdf(d, q, a, eng, image, poss): def beAdocDocx(d, q, a, eng, image, row_num=-1): qprefix = "Question: " if eng else "问题:" aprefix = "Answer: " if eng else "回答:" - d["content_with_weight"] = "\t".join( - [qprefix + rmPrefix(q), aprefix + rmPrefix(a)]) + d["content_with_weight"] = "\t".join([qprefix + rmPrefix(q), aprefix + rmPrefix(a)]) d["content_ltks"] = rag_tokenizer.tokenize(q) d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) if image: @@ -291,8 +271,7 @@ def beAdocDocx(d, q, a, eng, image, row_num=-1): def beAdoc(d, q, a, eng, row_num=-1): qprefix = "Question: " if eng else "问题:" aprefix = "Answer: " if eng else "回答:" - d["content_with_weight"] = "\t".join( - [qprefix + rmPrefix(q), aprefix + rmPrefix(a)]) + d["content_with_weight"] = "\t".join([qprefix + rmPrefix(q), aprefix + rmPrefix(a)]) d["content_ltks"] = rag_tokenizer.tokenize(q) d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) if row_num >= 0: @@ -301,28 +280,25 @@ def beAdoc(d, q, a, eng, row_num=-1): def mdQuestionLevel(s): - match = re.match(r'#*', s) - return (len(match.group(0)), s.lstrip('#').lstrip()) if match else (0, s) + match = re.match(r"#*", s) + return (len(match.group(0)), s.lstrip("#").lstrip()) if match else (0, s) def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, lang="Chinese", callback=None, **kwargs): """ - Excel and csv(txt) format files are supported. - If the file is in Excel format, there should be 2 column question and answer without header. - And question column is ahead of answer column. - And it's O.K if it has multiple sheets as long as the columns are rightly composed. + Excel and csv(txt) format files are supported. + If the file is in Excel format, there should be 2 column question and answer without header. + And question column is ahead of answer column. + And it's O.K if it has multiple sheets as long as the columns are rightly composed. - If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate question and answer. + If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate question and answer. - All the deformed lines will be ignored. - Every pair of Q&A will be treated as a chunk. + All the deformed lines will be ignored. + Every pair of Q&A will be treated as a chunk. """ eng = lang.lower() == "english" res = [] - doc = { - "docnm_kwd": filename, - "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)) - } + doc = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))} if re.search(r"\.xlsx?$", filename, re.IGNORECASE): callback(0.1, "Start to parse.") excel_parser = Excel() @@ -358,14 +334,12 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, lang= question, answer = arr i += 1 if len(res) % 999 == 0: - callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) if question: res.append(beAdoc(deepcopy(doc), question, answer, eng, len(lines))) - callback(0.6, ("Extract Q&A: {}".format(len(res)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + callback(0.6, ("Extract Q&A: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) return res @@ -391,21 +365,18 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, lang= res.append(beAdoc(deepcopy(doc), question, answer, eng, i)) question, answer = row if len(res) % 999 == 0: - callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) if question: res.append(beAdoc(deepcopy(doc), question, answer, eng, len(lines))) - callback(0.6, ("Extract Q&A: {}".format(len(res)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + callback(0.6, ("Extract Q&A: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) return res elif re.search(r"\.pdf$", filename, re.IGNORECASE): callback(0.1, "Start to parse.") pdf_parser = Pdf() - qai_list, tbls = pdf_parser(filename if not binary else binary, - from_page=from_page, to_page=to_page, callback=callback) + qai_list, tbls = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback) for q, a, image, poss in qai_list: res.append(beAdocPdf(deepcopy(doc), q, a, eng, image, poss)) return res @@ -418,21 +389,20 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, lang= question_stack, level_stack = [], [] code_block = False for index, line in enumerate(lines): - if line.strip().startswith('```'): + if line.strip().startswith("```"): code_block = not code_block - question_level, question = 0, '' + question_level, question = 0, "" if not code_block: question_level, question = mdQuestionLevel(line) if not question_level or question_level > 6: # not a question - last_answer = f'{last_answer}\n{line}' + last_answer = f"{last_answer}\n{line}" else: # is a question if last_answer.strip(): - sum_question = '\n'.join(question_stack) + sum_question = "\n".join(question_stack) if sum_question: - res.append(beAdoc(deepcopy(doc), sum_question, - markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index)) - last_answer = '' + res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=["markdown.extensions.tables"]), eng, index)) + last_answer = "" i = question_level while question_stack and i <= level_stack[-1]: @@ -441,31 +411,26 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, lang= question_stack.append(question) level_stack.append(question_level) if last_answer.strip(): - sum_question = '\n'.join(question_stack) + sum_question = "\n".join(question_stack) if sum_question: - res.append(beAdoc(deepcopy(doc), sum_question, - markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index)) + res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=["markdown.extensions.tables"]), eng, index)) return res elif re.search(r"\.docx$", filename, re.IGNORECASE): docx_parser = Docx() - qai_list, tbls = docx_parser(filename, binary, - from_page=0, to_page=MAXIMUM_PAGE_NUMBER, callback=callback) + qai_list, tbls = docx_parser(filename, binary, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, callback=callback) res = tokenize_table(tbls, doc, eng) for i, (q, a, image) in enumerate(qai_list): res.append(beAdocDocx(deepcopy(doc), q, a, eng, image, i)) return res - raise NotImplementedError( - "Excel, csv(txt), pdf, markdown and docx format files are supported.") + raise NotImplementedError("Excel, csv(txt), pdf, markdown and docx format files are supported.") if __name__ == "__main__": import sys - def dummy(prog=None, msg=""): pass - chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) diff --git a/rag/app/resume.py b/rag/app/resume.py index a244c75219..e7aed73dbf 100644 --- a/rag/app/resume.py +++ b/rag/app/resume.py @@ -45,12 +45,13 @@ from common.constants import MAXIMUM_PAGE_NUMBER # tiktoken for long random string filtering (ref: SmartResume should_remove strategy) try: import tiktoken + _tiktoken_encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") except ImportError: _tiktoken_encoding = None # Long random string pattern: 40+ char alphanumeric mixed strings (hash, token, tracking ID, etc.) -_LONG_RANDOM_PATTERN = re.compile(r'[a-zA-Z0-9\-~_]{40,}') +_LONG_RANDOM_PATTERN = re.compile(r"[a-zA-Z0-9\-~_]{40,}") import logging as logger from rag.nlp import rag_tokenizer @@ -79,6 +80,7 @@ def _get_layout_recognizer(): if _layout_recognizer is None: try: from deepdoc.vision import LayoutRecognizer + _layout_recognizer = LayoutRecognizer("layout") logger.info("YOLOv10 layout detector loaded successfully") except Exception as e: @@ -86,13 +88,11 @@ def _get_layout_recognizer(): _layout_recognizer = False # Mark as failed to avoid repeated attempts return _layout_recognizer if _layout_recognizer is not False else None + # ==================== Constants ==================== # Fields forbidden from being used as select fields in resume -FORBIDDEN_SELECT_FIELDS = [ - "name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", - "sch_rank_kwd", "edu_fea_kwd" -] +FORBIDDEN_SELECT_FIELDS = ["name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd"] # Field name to description mapping (bilingual versions for chunk construction) FIELD_MAP_ZH = { @@ -243,8 +243,8 @@ def _normalize_whitespace(text: str) -> str: """ Unicode whitespace normalization (ref: SmartResume _clean_text_content) - Replaces various Unicode spaces (\u00A0 non-breaking space, \u3000 fullwidth space, - \u2000-\u200A various width spaces, etc.) with regular spaces, + Replaces various Unicode spaces (\u00a0 non-breaking space, \u3000 fullwidth space, + \u2000-\u200a various width spaces, etc.) with regular spaces, then applies NFKC normalization (fullwidth to halfwidth) and merges consecutive spaces. Args: @@ -255,14 +255,11 @@ def _normalize_whitespace(text: str) -> str: if not text: return "" # NFKC normalization (fullwidth to halfwidth, etc.) - text = unicodedata.normalize('NFKC', text) + text = unicodedata.normalize("NFKC", text) # Unify various Unicode spaces to regular space - text = re.sub( - r'[\u0020\u00A0\u1680\u2000-\u200A\u2028\u2029\u202F\u205F\u3000\u00A7]', - ' ', text - ) + text = re.sub(r"[\u0020\u00A0\u1680\u2000-\u200A\u2028\u2029\u202F\u205F\u3000\u00A7]", " ", text) # Merge consecutive spaces - text = re.sub(r' {2,}', ' ', text) + text = re.sub(r" {2,}", " ", text) return text.strip() @@ -282,11 +279,7 @@ def _should_remove_random_str(match: re.Match) -> bool: if _tiktoken_encoding is None: # When tiktoken is unavailable, use simple heuristic: case/digit alternation frequency s = match.group(0) - changes = sum( - 1 for i in range(1, len(s)) - if s[i].isdigit() != s[i-1].isdigit() - or (s[i].isalpha() and s[i-1].isalpha() and s[i].isupper() != s[i-1].isupper()) - ) + changes = sum(1 for i in range(1, len(s)) if s[i].isdigit() != s[i - 1].isdigit() or (s[i].isalpha() and s[i - 1].isalpha() and s[i].isupper() != s[i - 1].isupper())) return changes / len(s) > 0.3 encoded = _tiktoken_encoding.encode(match.group(0)) return len(encoded) > len(match.group(0)) * 0.5 @@ -306,20 +299,15 @@ def _clean_line_content(text: str) -> str: # Unicode whitespace normalization text = _normalize_whitespace(text) # Filter long random strings (hash, token and other meaningless content) - text = _LONG_RANDOM_PATTERN.sub( - lambda m: '' if _should_remove_random_str(m) else m.group(0), - text - ) + text = _LONG_RANDOM_PATTERN.sub(lambda m: "" if _should_remove_random_str(m) else m.group(0), text) # Clean up extra spaces after filtering - text = re.sub(r' {2,}', ' ', text).strip() + text = re.sub(r" {2,}", " ", text).strip() return text # ==================== Phase 1: PDF Text Fusion and Layout Reconstruction ==================== - - def _is_noise_char(obj: dict) -> bool: """ Determine if a PDF character object is a decorative layer noise character @@ -350,15 +338,13 @@ def _is_noise_char(obj: dict) -> bool: # Whitelist condition 2: Has PDF structure tag tag = obj.get("tag") - if tag in ("Span", "NonStruct", "P", "H1", "H2", "H3", "H4", "H5", "H6", - "TD", "TH", "LI", "L", "Table", "TR", "Figure", "Caption"): + if tag in ("Span", "NonStruct", "P", "H1", "H2", "H3", "H4", "H5", "H6", "TD", "TH", "LI", "L", "Table", "TR", "Figure", "Caption"): return False # Has semantic structure tag = body content # Doesn't meet any whitelist condition, treat as noise return True - def _extract_metadata_text(binary: bytes) -> list[dict]: """ Extract text blocks from PDF metadata (with coordinate info) @@ -377,6 +363,7 @@ def _extract_metadata_text(binary: bytes) -> list[dict]: """ try: import pdfplumber + blocks = [] with pdfplumber.open(BytesIO(binary)) as pdf: for page_idx, page in enumerate(pdf.pages): @@ -387,9 +374,7 @@ def _extract_metadata_text(binary: bytes) -> list[dict]: # may use non-embedded fonts without structure tags, skip filtering to avoid false positives try: original_char_count = len(page.chars) - filtered_page = page.filter( - lambda obj: not _is_noise_char(obj) - ) + filtered_page = page.filter(lambda obj: not _is_noise_char(obj)) filtered_char_count = len(filtered_page.chars) if original_char_count > 0 and filtered_char_count < original_char_count * 0.3: # Filtered out over 70% of chars, likely false positives, fall back to original page @@ -400,9 +385,7 @@ def _extract_metadata_text(binary: bytes) -> list[dict]: # Use extract_words for extraction (with real coordinates) words = [] try: - words = filtered_page.extract_words( - keep_blank_chars=False, use_text_flow=True - ) + words = filtered_page.extract_words(keep_blank_chars=False, use_text_flow=True) except Exception: pass @@ -461,14 +444,16 @@ def _extract_metadata_text(binary: bytes) -> list[dict]: cleaned = line.strip() if not cleaned: continue - blocks.append({ - "text": cleaned, - "x0": 0, - "top": i * line_height, - "x1": page_width, - "bottom": i * line_height + line_height - 2, - "page": page_idx, - }) + blocks.append( + { + "text": cleaned, + "x0": 0, + "top": i * line_height, + "x1": page_width, + "bottom": i * line_height + line_height - 2, + "page": page_idx, + } + ) # Extract table content from the page # Many resumes use table layouts (e.g., personal info section), extract_words may miss table structure @@ -495,14 +480,16 @@ def _extract_metadata_text(binary: bytes) -> list[dict]: break if is_dup: continue - blocks.append({ - "text": row_text, - "x0": 0, - "top": max_top, - "x1": page_width, - "bottom": max_top + row_height - 2, - "page": page_idx, - }) + blocks.append( + { + "text": row_text, + "x0": 0, + "top": max_top, + "x1": page_width, + "bottom": max_top + row_height - 2, + "page": page_idx, + } + ) max_top += row_height except Exception as e: logger.debug(f"PDF table extraction skipped (page {page_idx}): {e}") @@ -511,6 +498,7 @@ def _extract_metadata_text(binary: bytes) -> list[dict]: logger.warning(f"PDF metadata extraction failed: {e}") return [] + def _extract_ocr_text(binary: bytes, meta_blocks: list[dict] | None = None) -> list[dict]: """ Extract OCR text blocks using blackout strategy (with coordinate info). @@ -564,12 +552,16 @@ def _extract_ocr_text(binary: bytes, meta_blocks: list[dict] | None = None) -> l xs = [p[0] for p in coords if isinstance(p, (list, tuple))] ys = [p[1] for p in coords if isinstance(p, (list, tuple))] if xs and ys: - blocks.append({ - "text": text.strip(), - "x0": min(xs), "top": min(ys), - "x1": max(xs), "bottom": max(ys), - "page": page_idx, - }) + blocks.append( + { + "text": text.strip(), + "x0": min(xs), + "top": min(ys), + "x1": max(xs), + "bottom": max(ys), + "page": page_idx, + } + ) return blocks except Exception as e: logger.warning(f"OCR extraction failed: {e}") @@ -613,8 +605,6 @@ def _fuse_text_blocks(meta_blocks: list[dict], ocr_blocks: list[dict]) -> list[d return fused - - def _layout_aware_reorder(blocks: list[dict]) -> list[dict]: """ Layout-aware hierarchical sorting (ref: SmartResume Hierarchical Re-ordering) @@ -752,6 +742,7 @@ def _build_indexed_text(blocks: list[dict]) -> tuple[str, list[str], list[dict]] return indexed_text, lines, line_positions + def _is_valid_line(line: str) -> bool: """ Check if a text line is valid content (not garbled) @@ -770,19 +761,19 @@ def _is_valid_line(line: str) -> bool: # Short lines may be valid content like names, keep them return True - cid_count = len(re.findall(r'\(cid:\d+\)', line)) + cid_count = len(re.findall(r"\(cid:\d+\)", line)) if cid_count >= 3: return False # Valid characters: Chinese (incl. extension), ASCII alphanumeric, common punctuation and spaces, fullwidth chars, CJK punctuation valid_chars = re.findall( - r'[\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff' - r'a-zA-Z0-9\s@.,:;!?()()【】\-_/\\|·•' - r'、,。:;!?\u201c\u201d\u2018\u2019《》' - r'\uff01-\uff5e' - r'\u3000-\u303f' - r'#%&+=~`\u00b7\u2022\u2013\u2014' - r']', - line + r"[\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff" + r"a-zA-Z0-9\s@.,:;!?()()【】\-_/\\|·•" + r"、,。:;!?\u201c\u201d\u2018\u2019《》" + r"\uff01-\uff5e" + r"\u3000-\u303f" + r"#%&+=~`\u00b7\u2022\u2013\u2014" + r"]", + line, ) ratio = len(valid_chars) / len(line) if len(line) > 0 else 0 if ratio < 0.5: @@ -791,7 +782,7 @@ def _is_valid_line(line: str) -> bool: # Detect PDF custom font mapping causing single-character spacing anomaly pattern # Feature: lots of "single letter space single letter space" sequences, e.g. "O U W Z_W V 2 X 3" # Stats: ratio of space-separated single chars among non-space chars - spaced_singles = re.findall(r'(?:^|\s)([a-zA-Z0-9])(?:\s|$)', line) + spaced_singles = re.findall(r"(?:^|\s)([a-zA-Z0-9])(?:\s|$)", line) non_space_len = len(line.replace(" ", "")) if non_space_len > 5 and len(spaced_singles) > 0: # If ratio of space-separated single chars to non-space chars is too high, classify as garbled @@ -801,16 +792,12 @@ def _is_valid_line(line: str) -> bool: # Detect consecutive meaningless mixed-case alphanumeric sequences (e.g. "UJqZX9V2") # Normal English words don't have such frequent case alternation patterns - garbled_seqs = re.findall(r'[a-zA-Z0-9]{4,}', line.replace(" ", "")) + garbled_seqs = re.findall(r"[a-zA-Z0-9]{4,}", line.replace(" ", "")) if garbled_seqs: garbled_count = 0 for seq in garbled_seqs: # Count case alternations - case_changes = sum( - 1 for i in range(1, len(seq)) - if (seq[i].isupper() != seq[i-1].isupper() and seq[i].isalpha() and seq[i-1].isalpha()) - or (seq[i].isdigit() != seq[i-1].isdigit()) - ) + case_changes = sum(1 for i in range(1, len(seq)) if (seq[i].isupper() != seq[i - 1].isupper() and seq[i].isalpha() and seq[i - 1].isalpha()) or (seq[i].isdigit() != seq[i - 1].isdigit())) # Too high alternation frequency = garbled sequence (normal words like "Spring" have only 1 alternation) if len(seq) >= 4 and case_changes / len(seq) > 0.5: garbled_count += 1 @@ -853,14 +840,14 @@ def _fix_split_labels(lines: list[str]) -> list[str]: # Detect in-line split patterns: "X:content Y" where (Y, X) is a split pair for (suffix_char, prefix_char), full_label in split_patterns.items(): # Pattern: "prefix_char:content suffix_char" (first half at line start, second half at line end) - pattern = rf'^({re.escape(prefix_char)})\s*[::]\s*(.+?)\s+{re.escape(suffix_char)}\s*$' + pattern = rf"^({re.escape(prefix_char)})\s*[::]\s*(.+?)\s+{re.escape(suffix_char)}\s*$" m = re.match(pattern, line) if m: content = m.group(2).strip() line = f"{full_label}:{content}" break # Pattern: "suffix_char content prefix_char:" (second half at line start, first half at line end) - pattern2 = rf'^{re.escape(suffix_char)}\s*[::]?\s*(.+?)\s+{re.escape(prefix_char)}\s*$' + pattern2 = rf"^{re.escape(suffix_char)}\s*[::]?\s*(.+?)\s+{re.escape(prefix_char)}\s*$" m2 = re.match(pattern2, line) if m2: content = m2.group(1).strip() @@ -870,9 +857,6 @@ def _fix_split_labels(lines: list[str]) -> list[str]: return fixed - - - def extract_text(filename: str, binary: bytes) -> tuple[str, list[str], list[dict]]: """ Extract text content based on file type (Pipeline Phase 1). @@ -921,9 +905,7 @@ def extract_text(filename: str, binary: bytes) -> tuple[str, list[str], list[dic if total_line_count > 0: valid_ratio = valid_line_count / total_line_count if valid_ratio < 0.6: - logger.info( - f"PDF metadata text quality low (valid line ratio {valid_ratio:.1%}), enabling OCR supplementation" - ) + logger.info(f"PDF metadata text quality low (valid line ratio {valid_ratio:.1%}), enabling OCR supplementation") need_ocr = True if need_ocr: @@ -941,6 +923,7 @@ def extract_text(filename: str, binary: bytes) -> tuple[str, list[str], list[dic elif fname_lower.endswith(".docx"): from docx import Document + doc = Document(BytesIO(binary)) lines = [p.text.strip() for p in doc.paragraphs if p.text.strip()] @@ -997,14 +980,14 @@ def _clean_llm_json_response(response: str) -> str: # Remove markdown code block markers text = text.replace("```json", "").replace("```", "").strip() # Remove reasoning model thinking tags - text = re.sub(r'.*?', '', text, flags=re.DOTALL).strip() + text = re.sub(r".*?", "", text, flags=re.DOTALL).strip() # Clean escaped quotes (SmartResume's approach) text = text.replace('\\"', '"') # SmartResume strategy: locate first { and last } start = text.find("{") end = text.rfind("}") if start != -1 and end != -1 and end > start: - return text[start:end + 1] + return text[start : end + 1] return text @@ -1032,9 +1015,9 @@ def _parse_json_with_repair(text: str) -> dict: # Second attempt: replace Python-style values (ref SmartResume) repaired = text.replace("'", '"') - repaired = repaired.replace('True', 'true') - repaired = repaired.replace('False', 'false') - repaired = repaired.replace('None', 'null') + repaired = repaired.replace("True", "true") + repaired = repaired.replace("False", "false") + repaired = repaired.replace("None", "null") try: return json.loads(repaired) except json.JSONDecodeError: @@ -1051,7 +1034,7 @@ def _parse_json_with_repair(text: str) -> dict: raise json.JSONDecodeError("All JSON repair strategies failed", text, 0) -def _call_llm(prompt: str, tenant_id , lang: str) -> Optional[dict]: +def _call_llm(prompt: str, tenant_id, lang: str) -> Optional[dict]: """ Call LLM and parse JSON response (ref SmartResume's retry + fault-tolerance strategy). @@ -1071,7 +1054,7 @@ def _call_llm(prompt: str, tenant_id , lang: str) -> Optional[dict]: from api.db.services.llm_service import LLMBundle from common.constants import LLMType - llm = LLMBundle(tenant_id, LLMType.CHAT, lang=lang) + llm = LLMBundle(tenant_id, LLMType.CHAT, lang=lang) for attempt in range(_LLM_MAX_RETRIES + 1): try: @@ -1121,9 +1104,10 @@ def _normalize_for_comparison(text: str) -> str: # Unicode NFKC normalization (fullwidth to halfwidth, etc.) text = unicodedata.normalize("NFKC", text) # Remove all whitespace characters - text = re.sub(r'\s+', '', text) + text = re.sub(r"\s+", "", text) return text.lower() + def _calc_single_exp_years(start_str: str, end_str: str) -> float: """ Calculate years for a single experience entry. @@ -1169,9 +1153,7 @@ def _calculate_work_years(experiences: list[dict]) -> float: """ total = 0.0 for exp in experiences: - total += _calc_single_exp_years( - exp.get("start_date", ""), exp.get("end_date", "") - ) + total += _calc_single_exp_years(exp.get("start_date", ""), exp.get("end_date", "")) return round(total, 1) @@ -1214,12 +1196,7 @@ def _parse_date_str(date_str: str) -> Optional[datetime.datetime]: return None - - -def _extract_description_from_range( - index_range: list, lines: list[str], - company: str = "", position: str = "" -) -> str: +def _extract_description_from_range(index_range: list, lines: list[str], company: str = "", position: str = "") -> str: """ Extract description from original text by index range (ref SmartResume's _extract_description_from_range). @@ -1244,7 +1221,7 @@ def _extract_description_from_range( if start_idx < 0 or end_idx >= len(lines) or start_idx > end_idx: return "" - extracted_lines = lines[start_idx:end_idx + 1] + extracted_lines = lines[start_idx : end_idx + 1] # Filter out lines containing both company name and position title (ref SmartResume) if company or position: @@ -1268,44 +1245,44 @@ def _extract_description_from_range( return "\n".join(line.strip() for line in extracted_lines if line.strip()) -def _extract_basic_info(indexed_text: str, tenant_id , lang: str) -> Optional[dict]: +def _extract_basic_info(indexed_text: str, tenant_id, lang: str) -> Optional[dict]: """Extract basic info (subtask 1). Basic info is usually at the beginning of the resume, first 8000 chars suffice. """ prompt = get_basic_info_prompt(lang).format(indexed_text=indexed_text[:8000]) - return _call_llm(prompt,tenant_id, lang) + return _call_llm(prompt, tenant_id, lang) -def _extract_work_experience(indexed_text: str, tenant_id , lang: str) -> Optional[dict]: +def _extract_work_experience(indexed_text: str, tenant_id, lang: str) -> Optional[dict]: """Extract work experience (subtask 2, using index pointers). Work experience may span the middle-to-end of the resume, use full text to avoid truncation. """ prompt = get_work_exp_prompt(lang).format(indexed_text=indexed_text) - return _call_llm(prompt, tenant_id , lang) + return _call_llm(prompt, tenant_id, lang) -def _extract_education(indexed_text: str, tenant_id , lang: str) -> Optional[dict]: +def _extract_education(indexed_text: str, tenant_id, lang: str) -> Optional[dict]: """Extract education background (subtask 3). Education is usually at the end of the resume, must use full text to avoid truncation. Resume text is generally under 30K chars, within LLM context window. """ prompt = get_education_prompt(lang).format(indexed_text=indexed_text) - return _call_llm(prompt,tenant_id, lang) + return _call_llm(prompt, tenant_id, lang) -def _extract_project_experience(indexed_text: str, tenant_id , lang: str) -> Optional[dict]: +def _extract_project_experience(indexed_text: str, tenant_id, lang: str) -> Optional[dict]: """Extract project experience (subtask 4, using index pointers). Project experience may span the middle-to-end of the resume, use full text to avoid truncation. """ prompt = get_project_exp_prompt(lang).format(indexed_text=indexed_text) - return _call_llm(prompt, tenant_id , lang) + return _call_llm(prompt, tenant_id, lang) -def parse_with_llm(indexed_text: str, lines: list[str], tenant_id , lang: str) -> Optional[dict]: +def parse_with_llm(indexed_text: str, lines: list[str], tenant_id, lang: str) -> Optional[dict]: """ Extract resume info using parallel task decomposition strategy (ref SmartResume Section 3.2). @@ -1325,10 +1302,10 @@ def parse_with_llm(indexed_text: str, lines: list[str], tenant_id , lang: str) - try: # Execute four subtasks in parallel with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: - future_basic = executor.submit(_extract_basic_info, indexed_text, tenant_id , lang) - future_work = executor.submit(_extract_work_experience, indexed_text, tenant_id , lang) + future_basic = executor.submit(_extract_basic_info, indexed_text, tenant_id, lang) + future_work = executor.submit(_extract_work_experience, indexed_text, tenant_id, lang) future_edu = executor.submit(_extract_education, indexed_text, tenant_id, lang) - future_project = executor.submit(_extract_project_experience, indexed_text, tenant_id , lang) + future_project = executor.submit(_extract_project_experience, indexed_text, tenant_id, lang) basic_info = future_basic.result(timeout=60) work_exp = future_work.result(timeout=60) @@ -1363,20 +1340,20 @@ def parse_with_llm(indexed_text: str, lines: list[str], tenant_id , lang: str) - if position: positions.append(position) # Save detailed info for each experience entry - work_exp_details.append({ - "company": company, - "position": position, - "start_date": start_date, - "end_date": end_date, - "years": years, - }) + work_exp_details.append( + { + "company": company, + "position": position, + "start_date": start_date, + "end_date": end_date, + "years": years, + } + ) # Index pointer mechanism: extract description from original text by line range # Use _extract_description_from_range to filter header lines (ref SmartResume) desc_lines = exp.get("desc_lines", []) if isinstance(desc_lines, list) and len(desc_lines) == 2: - desc = _extract_description_from_range( - desc_lines, lines, company=company, position=position - ) + desc = _extract_description_from_range(desc_lines, lines, company=company, position=position) if desc.strip(): work_descs.append(desc.strip()) @@ -1426,11 +1403,22 @@ def parse_with_llm(indexed_text: str, lines: list[str], tenant_id , lang: str) - resume["degree_kwd"] = degrees # Infer highest degree (supports both Chinese and English degree names) degree_rank = { - "博士": 5, "PhD": 5, "Doctor": 5, - "硕士": 4, "Master": 4, "MBA": 4, "EMBA": 4, "MPA": 4, - "本科": 3, "Bachelor": 3, - "大专": 2, "专科": 2, "Associate": 2, "Diploma": 2, - "高中": 1, "High School": 1, + "博士": 5, + "PhD": 5, + "Doctor": 5, + "硕士": 4, + "Master": 4, + "MBA": 4, + "EMBA": 4, + "MPA": 4, + "本科": 3, + "Bachelor": 3, + "大专": 2, + "专科": 2, + "Associate": 2, + "Diploma": 2, + "高中": 1, + "High School": 1, } highest = max(degrees, key=lambda d: degree_rank.get(d, 0), default="") if highest: @@ -1450,9 +1438,7 @@ def parse_with_llm(indexed_text: str, lines: list[str], tenant_id , lang: str) - # Index pointer mechanism: extract project description from original text by line range desc_lines = proj.get("desc_lines", []) if isinstance(desc_lines, list) and len(desc_lines) == 2: - desc = _extract_description_from_range( - desc_lines, lines, company=name, position=proj.get("role", "") - ) + desc = _extract_description_from_range(desc_lines, lines, company=name, position=proj.get("role", "")) if desc.strip(): project_descs.append(desc.strip()) @@ -1478,7 +1464,6 @@ def parse_with_llm(indexed_text: str, lines: list[str], tenant_id , lang: str) - # ==================== Phase 3: Regex Fallback Parsing ==================== - def parse_with_regex(text: str, lang: str = "Chinese") -> dict: """ Parse resume text using regex (fallback strategy) @@ -1498,19 +1483,19 @@ def parse_with_regex(text: str, lang: str = "Chinese") -> dict: if _is_english(lang): # English resume: extract from "Name: XXX" format for line in lines[:30]: - name_match = re.search(r'(?:Name|Full\s*Name)\s*[::]\s*([A-Za-z][A-Za-z\s\-\.]{1,40})', line, re.IGNORECASE) + name_match = re.search(r"(?:Name|Full\s*Name)\s*[::]\s*([A-Za-z][A-Za-z\s\-\.]{1,40})", line, re.IGNORECASE) if name_match: resume["name_kwd"] = name_match.group(1).strip() break # English resume strategy 2: first line if short text without digits, may be a name if "name_kwd" not in resume and lines: first = lines[0].strip() - if len(first) <= 40 and not re.search(r"\d", first) and re.match(r'^[A-Za-z][A-Za-z\s\-\.]+$', first): + if len(first) <= 40 and not re.search(r"\d", first) and re.match(r"^[A-Za-z][A-Za-z\s\-\.]+$", first): resume["name_kwd"] = first else: # Chinese resume: extract from "姓名:XXX" format for line in lines[:30]: - name_match = re.search(r'姓\s*名\s*[::]\s*([\u4e00-\u9fa5]{2,4})', line) + name_match = re.search(r"姓\s*名\s*[::]\s*([\u4e00-\u9fa5]{2,4})", line) if name_match: resume["name_kwd"] = name_match.group(1) break @@ -1518,14 +1503,35 @@ def parse_with_regex(text: str, lang: str = "Chinese") -> dict: # Strategy 2: search first 20 lines for standalone Chinese names (2-4 chars), excluding common title words if "name_kwd" not in resume: title_words = { - "个人", "简历", "求职", "应聘", "基本", "信息", "概述", "简介", - "教育", "工作", "经历", "经验", "技能", "项目", "自我", "评价", - "专业", "技术", "证书", "语言", "能力", "培训", "荣誉", "奖项", + "个人", + "简历", + "求职", + "应聘", + "基本", + "信息", + "概述", + "简介", + "教育", + "工作", + "经历", + "经验", + "技能", + "项目", + "自我", + "评价", + "专业", + "技术", + "证书", + "语言", + "能力", + "培训", + "荣誉", + "奖项", } for line in lines[:20]: if any(w in line for w in title_words): continue - if re.search(r'[::]', line) and len(line) > 6: + if re.search(r"[::]", line) and len(line) > 6: continue cleaned = re.sub(r"^[A-Za-z_\-\d\s]+\s+", "", line) cleaned = re.sub(r"\s+[A-Za-z_\-\d\s]+$", "", cleaned).strip() @@ -1537,7 +1543,7 @@ def parse_with_regex(text: str, lang: str = "Chinese") -> dict: if "name_kwd" not in resume and lines: first = lines[0].strip() if len(first) <= 10 and not re.search(r"\d", first): - cn_part = re.findall(r'[\u4e00-\u9fa5]+', first) + cn_part = re.findall(r"[\u4e00-\u9fa5]+", first) if cn_part and 2 <= len(cn_part[0]) <= 4: resume["name_kwd"] = cn_part[0] @@ -1554,17 +1560,17 @@ def parse_with_regex(text: str, lang: str = "Chinese") -> dict: # --- Extract Gender --- if _is_english(lang): # English resume: extract from "Gender: Male/Female" format - gender_label = re.search(r'(?:Gender|Sex)\s*[::]\s*(Male|Female|M|F)', text, re.IGNORECASE) + gender_label = re.search(r"(?:Gender|Sex)\s*[::]\s*(Male|Female|M|F)", text, re.IGNORECASE) if gender_label: raw = gender_label.group(1).strip().upper() resume["gender_kwd"] = "Male" if raw in ("M", "MALE") else "Female" else: - gender_match = re.search(r'\b(Male|Female)\b', text[:500], re.IGNORECASE) + gender_match = re.search(r"\b(Male|Female)\b", text[:500], re.IGNORECASE) if gender_match: resume["gender_kwd"] = gender_match.group(1).capitalize() else: # Chinese resume: extract from "性别:男/女" format - gender_label = re.search(r'性\s*别\s*[::]\s*(男|女)', text) + gender_label = re.search(r"性\s*别\s*[::]\s*(男|女)", text) if gender_label: resume["gender_kwd"] = gender_label.group(1) else: @@ -1575,9 +1581,9 @@ def parse_with_regex(text: str, lang: str = "Chinese") -> dict: # --- Extract Age --- if _is_english(lang): # English resume: match "25 years old" or "Age: 25" - age_match = re.search(r'(?:Age)\s*[::]\s*(\d{1,2})', text, re.IGNORECASE) + age_match = re.search(r"(?:Age)\s*[::]\s*(\d{1,2})", text, re.IGNORECASE) if not age_match: - age_match = re.search(r'(\d{1,2})\s*years?\s*old', text, re.IGNORECASE) + age_match = re.search(r"(\d{1,2})\s*years?\s*old", text, re.IGNORECASE) if age_match: resume["age_int"] = int(age_match.group(1)) else: @@ -1589,7 +1595,7 @@ def parse_with_regex(text: str, lang: str = "Chinese") -> dict: # --- Extract Date of Birth --- if _is_english(lang): # English resume: match "1990-01-15" or "Jan 15, 1990" etc. - birth_match = re.search(r'(?:Birth|DOB|Date\s*of\s*Birth)\s*[::]\s*(.{6,20})', text, re.IGNORECASE) + birth_match = re.search(r"(?:Birth|DOB|Date\s*of\s*Birth)\s*[::]\s*(.{6,20})", text, re.IGNORECASE) if birth_match: resume["birth_dt"] = birth_match.group(1).strip() else: @@ -1604,8 +1610,7 @@ def parse_with_regex(text: str, lang: str = "Chinese") -> dict: # --- Extract Education Level --- degree_keywords_zh = ["博士", "硕士", "本科", "大专", "专科", "高中", "MBA", "EMBA", "MPA"] - degree_keywords_en = ["PhD", "Master", "Bachelor", "Associate", "Diploma", "High School", - "MBA", "EMBA", "MPA", "Doctor"] + degree_keywords_en = ["PhD", "Master", "Bachelor", "Associate", "Diploma", "High School", "MBA", "EMBA", "MPA", "Doctor"] degree_keywords = degree_keywords_en if _is_english(lang) else degree_keywords_zh found_degrees = [d for d in degree_keywords if d in text] if found_degrees: @@ -1614,12 +1619,9 @@ def parse_with_regex(text: str, lang: str = "Chinese") -> dict: # --- Extract School --- if _is_english(lang): # English resume: match "University/College/Institute/School" keywords - schools = re.findall( - r'([A-Z][A-Za-z\s\-&]{2,40}(?:University|College|Institute|School|Academy))', - text - ) + schools = re.findall(r"([A-Z][A-Za-z\s\-&]{2,40}(?:University|College|Institute|School|Academy))", text) # Remove extra whitespace - schools = [re.sub(r'\s+', ' ', s).strip() for s in schools] + schools = [re.sub(r"\s+", " ", s).strip() for s in schools] else: # Chinese resume: match "XX大学/学院/职业技术学院" schools = re.findall(r"[\u4e00-\u9fa5]{2,15}(?:大学|学院|职业技术学院)", text) @@ -1630,10 +1632,7 @@ def parse_with_regex(text: str, lang: str = "Chinese") -> dict: # --- Extract Major --- if _is_english(lang): # English resume: match "Major: XXX" / "Field of Study: XXX" / "Specialization: XXX" - majors = re.findall( - r'(?:Major|Field\s*of\s*Study|Specialization|Concentration)\s*[::]\s*([A-Za-z\s\-&,]{2,40})', - text, re.IGNORECASE - ) + majors = re.findall(r"(?:Major|Field\s*of\s*Study|Specialization|Concentration)\s*[::]\s*([A-Za-z\s\-&,]{2,40})", text, re.IGNORECASE) majors = [m.strip() for m in majors if m.strip()] else: # Chinese resume: match "专业:XXX" @@ -1646,12 +1645,12 @@ def parse_with_regex(text: str, lang: str = "Chinese") -> dict: if _is_english(lang): # English resume: match common company suffixes en_company_patterns = [ - r'([A-Z][A-Za-z\s\-&,\.]{2,40}(?:Inc\.|Corp\.|Ltd\.|LLC|Co\.|Company|Group|Technologies|Technology|Solutions|Consulting|Services|Bank))', + r"([A-Z][A-Za-z\s\-&,\.]{2,40}(?:Inc\.|Corp\.|Ltd\.|LLC|Co\.|Company|Group|Technologies|Technology|Solutions|Consulting|Services|Bank))", ] companies = [] for pattern in en_company_patterns: companies.extend(re.findall(pattern, text)) - companies = [re.sub(r'\s+', ' ', c).strip() for c in companies] + companies = [re.sub(r"\s+", " ", c).strip() for c in companies] else: # Chinese resume: match "XX有限公司" format company_patterns = [ @@ -1665,11 +1664,7 @@ def parse_with_regex(text: str, lang: str = "Chinese") -> dict: unique_companies = [] seen = set() # Filter verb list (bilingual) - filter_verbs = ( - ["completed", "conducted", "implemented", "responsible", "participated", "developed"] - if _is_english(lang) - else ["完成", "进行", "实施", "负责", "参与", "开发"] - ) + filter_verbs = ["completed", "conducted", "implemented", "responsible", "participated", "developed"] if _is_english(lang) else ["完成", "进行", "实施", "负责", "参与", "开发"] min_len = 3 if _is_english(lang) else 6 for c in companies: if len(c) < min_len or any(v in c.lower() for v in filter_verbs) or c in seen: @@ -1693,24 +1688,35 @@ def parse_with_regex(text: str, lang: str = "Chinese") -> dict: # --- Extract Position (improved: context constraints to reduce noise) --- if _is_english(lang): # English resume: Strategy 1 - extract from "Title: XXX" / "Position: XXX" / "Role: XXX" format - position_label_matches = re.findall( - r'(?:Title|Position|Role|Job\s*Title)\s*[::]\s*([A-Za-z\s\-/&]{2,30})', - text, re.IGNORECASE - ) + position_label_matches = re.findall(r"(?:Title|Position|Role|Job\s*Title)\s*[::]\s*([A-Za-z\s\-/&]{2,30})", text, re.IGNORECASE) positions = [p.strip() for p in position_label_matches if p.strip()] # English resume: Strategy 2 - match common position suffix keywords en_position_suffixes = [ - "Engineer", "Manager", "Director", "Supervisor", "Specialist", - "Designer", "Consultant", "Assistant", "Architect", "Analyst", - "Developer", "Lead", "Officer", "Coordinator", "Administrator", - "Intern", "VP", "President", + "Engineer", + "Manager", + "Director", + "Supervisor", + "Specialist", + "Designer", + "Consultant", + "Assistant", + "Architect", + "Analyst", + "Developer", + "Lead", + "Officer", + "Coordinator", + "Administrator", + "Intern", + "VP", + "President", ] for line in lines: if len(line) > 60: continue # Skip overly long lines (usually description text) for suffix in en_position_suffixes: - match = re.search(rf'([A-Za-z\s\-]{{1,25}}{suffix})\b', line, re.IGNORECASE) + match = re.search(rf"([A-Za-z\s\-]{{1,25}}{suffix})\b", line, re.IGNORECASE) if match: pos = match.group(1).strip() # Filter out matches that are clearly not positions (contain verbs) @@ -1719,29 +1725,22 @@ def parse_with_regex(text: str, lang: str = "Chinese") -> dict: positions.append(pos) else: # Chinese resume: Strategy 1 - extract from "职位/岗位:XXX" format - position_label_matches = re.findall( - r'(?:职位|岗位|职务|职称|担任)\s*[::]\s*([\u4e00-\u9fa5a-zA-Z]{2,15})', - text - ) + position_label_matches = re.findall(r"(?:职位|岗位|职务|职称|担任)\s*[::]\s*([\u4e00-\u9fa5a-zA-Z]{2,15})", text) positions = list(position_label_matches) # Chinese resume: Strategy 2 - extract from work experience paragraphs (company name followed by position) for line in lines: - pos_match = re.search( - r'(?:有限公司|集团|银行)\s+([\u4e00-\u9fa5]{2,8}(?:工程师|经理|总监|主管|专员|设计师|顾问|助理|架构师|分析师|运营|产品))', - line - ) + pos_match = re.search(r"(?:有限公司|集团|银行)\s+([\u4e00-\u9fa5]{2,8}(?:工程师|经理|总监|主管|专员|设计师|顾问|助理|架构师|分析师|运营|产品))", line) if pos_match: positions.append(pos_match.group(1)) # Chinese resume: Strategy 3 - position keywords in standalone lines (length-limited to avoid matching description text) - position_suffixes = ["工程师", "经理", "总监", "主管", "专员", "设计师", "顾问", - "助理", "架构师", "分析师", "开发者", "负责人"] + position_suffixes = ["工程师", "经理", "总监", "主管", "专员", "设计师", "顾问", "助理", "架构师", "分析师", "开发者", "负责人"] for line in lines: if len(line) > 20: continue # Skip overly long lines for suffix in position_suffixes: - match = re.search(rf'([\u4e00-\u9fa5]{{1,6}}{suffix})', line) + match = re.search(rf"([\u4e00-\u9fa5]{{1,6}}{suffix})", line) if match: pos = match.group(1) if not any(v in pos for v in ["负责", "参与", "完成", "开发了", "设计了"]): @@ -1760,7 +1759,7 @@ def parse_with_regex(text: str, lang: str = "Chinese") -> dict: # --- Extract Years of Experience --- if _is_english(lang): # English resume: match "5 years experience" / "5+ years of experience" - work_exp_match = re.search(r'(\d+)\+?\s*years?\s*(?:of\s*)?(?:experience|work)', text, re.IGNORECASE) + work_exp_match = re.search(r"(\d+)\+?\s*years?\s*(?:of\s*)?(?:experience|work)", text, re.IGNORECASE) if work_exp_match: resume["work_exp_flt"] = float(work_exp_match.group(1)) else: @@ -1772,7 +1771,7 @@ def parse_with_regex(text: str, lang: str = "Chinese") -> dict: # --- Extract Graduation Year --- if _is_english(lang): # English resume: match "Graduated 2020" / "Graduation: 2020" / "Class of 2020" - grad_match = re.search(r'(?:Graduat(?:ed|ion)|Class\s*of)\s*[::]?\s*((?:19|20)\d{2})', text, re.IGNORECASE) + grad_match = re.search(r"(?:Graduat(?:ed|ion)|Class\s*of)\s*[::]?\s*((?:19|20)\d{2})", text, re.IGNORECASE) if grad_match: resume["edu_end_int"] = int(grad_match.group(1)) else: @@ -1787,7 +1786,6 @@ def parse_with_regex(text: str, lang: str = "Chinese") -> dict: return resume - # ==================== Phase 4: Post-processing Pipeline ==================== @@ -1885,8 +1883,7 @@ def _postprocess_resume(resume: dict, lines: list[str], lang: str = "Chinese") - resume["gender_kwd"] = "Female" if _en else "女" # --- Phase 3: Contextual deduplication --- - for list_field in ["corp_nm_tks", "school_name_tks", "major_tks", - "position_name_tks", "skill_tks"]: + for list_field in ["corp_nm_tks", "school_name_tks", "major_tks", "position_name_tks", "skill_tks"]: if isinstance(resume.get(list_field), list): # Order-preserving deduplication seen = set() @@ -1941,8 +1938,7 @@ def _postprocess_resume(resume: dict, lines: list[str], lang: str = "Chinese") - if dt_start_i <= eff_end_j and dt_start_j <= eff_end_i: is_dup = True break - elif (start_i and start_j and start_i == start_j) or \ - (end_i and end_j and end_i == end_j): + elif (start_i and start_j and start_i == start_j) or (end_i and end_j and end_i == end_j): # Fallback: exact string match if date parsing fails is_dup = True break @@ -1958,7 +1954,7 @@ def _postprocess_resume(resume: dict, lines: list[str], lang: str = "Chinese") - is_dup = True break if is_dup: - dup_corp = corp_names[i] if i < len(corp_names) else f"#{i+1}" + dup_corp = corp_names[i] if i < len(corp_names) else f"#{i + 1}" logger.debug(f"Work desc internal duplicate removed: {dup_corp}") else: kept_indices.append(i) @@ -1978,7 +1974,9 @@ def _postprocess_resume(resume: dict, lines: list[str], lang: str = "Chinese") - recalc_years = round(recalc_years, 1) if recalc_years > 0: resume["work_exp_flt"] = recalc_years - logger.info(f"Work years recalculated: {recalc_years} yrs (before dedup: {_calculate_work_years([{'start_date': d.get('start_date',''), 'end_date': d.get('end_date','')} for d in work_details])} yrs)") + logger.info( + f"Work years recalculated: {recalc_years} yrs (before dedup: {_calculate_work_years([{'start_date': d.get('start_date', ''), 'end_date': d.get('end_date', '')} for d in work_details])} yrs)" + ) new_corps = resume.get("corp_nm_tks", []) if new_corps: resume["corporation_name_tks"] = new_corps[0] @@ -2018,7 +2016,7 @@ def _postprocess_resume(resume: dict, lines: list[str], lang: str = "Chinese") - break if already_exists: skipped_count += 1 - proj_name = project_names[i] if i < len(project_names) else f"#{i+1}" + proj_name = project_names[i] if i < len(project_names) else f"#{i + 1}" logger.debug(f"Project desc already in work_desc, skipped: {proj_name}") else: # Append to work_desc_tks with project name prefix for context @@ -2035,8 +2033,13 @@ def _postprocess_resume(resume: dict, lines: list[str], lang: str = "Chinese") - logger.info(f"Merged project descs into work_desc_tks: {merged_count} merged, {skipped_count} skipped (duplicate)") # --- Phase 4: Field completion --- required_fields = [ - "name_kwd", "gender_kwd", "phone_kwd", "email_tks", - "position_name_tks", "school_name_tks", "major_tks", + "name_kwd", + "gender_kwd", + "phone_kwd", + "email_tks", + "position_name_tks", + "school_name_tks", + "major_tks", ] for field in required_fields: if field not in resume: @@ -2056,7 +2059,7 @@ def _postprocess_resume(resume: dict, lines: list[str], lang: str = "Chinese") - # ==================== Pipeline Orchestration & Chunk Construction ==================== -def parse_resume(filename: str, binary: bytes, tenant_id , lang: str = "Chinese") -> tuple[dict, list[str], list[dict]]: +def parse_resume(filename: str, binary: bytes, tenant_id, lang: str = "Chinese") -> tuple[dict, list[str], list[dict]]: """ Resume parsing pipeline orchestration function @@ -2084,7 +2087,7 @@ def parse_resume(filename: str, binary: bytes, tenant_id , lang: str = "Chinese" return {"name_kwd": default_name}, [], [] # Phase 2: Parallel LLM structured extraction - resume = parse_with_llm(indexed_text, lines, tenant_id , lang) + resume = parse_with_llm(indexed_text, lines, tenant_id, lang) # Phase 3: Fallback to regex parsing when LLM fails if not resume: @@ -2098,8 +2101,7 @@ def parse_resume(filename: str, binary: bytes, tenant_id , lang: str = "Chinese" return resume, lines, line_positions -def _build_chunk_document(filename: str, resume: dict, - lang: str = "Chinese") -> list[dict]: +def _build_chunk_document(filename: str, resume: dict, lang: str = "Chinese") -> list[dict]: """ Build a list of document chunks from structured resume information @@ -2129,17 +2131,14 @@ def _build_chunk_document(filename: str, resume: dict, # Extract key identity fields, redundantly written to each chunk # These fields are small in size but high in information density; once retrieved, the candidate can be immediately identified - _IDENTITY_FIELDS = ("name_kwd", "phone_kwd", "email_tks", "gender_kwd", - "highest_degree_kwd", "work_exp_flt") + _IDENTITY_FIELDS = ("name_kwd", "phone_kwd", "email_tks", "gender_kwd", "highest_degree_kwd", "work_exp_flt") identity_meta = {} for ik in _IDENTITY_FIELDS: iv = resume.get(ik) if not iv: continue if ik.endswith("_tks"): - identity_meta[ik] = rag_tokenizer.tokenize( - " ".join(iv) if isinstance(iv, list) else str(iv) - ) + identity_meta[ik] = rag_tokenizer.tokenize(" ".join(iv) if isinstance(iv, list) else str(iv)) elif ik.endswith("_kwd"): identity_meta[ik] = iv if isinstance(iv, list) else str(iv) elif ik.endswith("_flt"): @@ -2174,27 +2173,46 @@ def _build_chunk_document(filename: str, resume: dict, # Basic info field set: these fields should be merged into one chunk to avoid splitting name, phone, email, etc. _BASIC_INFO_FIELDS = { - "name_kwd", "name_pinyin_kwd", "gender_kwd", "age_int", - "phone_kwd", "email_tks", "birth_dt", "work_exp_flt", - "position_name_tks", "expect_city_names_tks", + "name_kwd", + "name_pinyin_kwd", + "gender_kwd", + "age_int", + "phone_kwd", + "email_tks", + "birth_dt", + "work_exp_flt", + "position_name_tks", + "expect_city_names_tks", "expect_position_name_tks", } # Education field set: degree, school, major, tags, etc. should be merged into one chunk _EDUCATION_FIELDS = { - "first_school_name_tks", "first_degree_kwd", "highest_degree_kwd", - "first_major_tks", "edu_first_fea_kwd", "degree_kwd", "major_tks", - "school_name_tks", "sch_rank_kwd", "edu_fea_kwd", "edu_end_int", + "first_school_name_tks", + "first_degree_kwd", + "highest_degree_kwd", + "first_major_tks", + "edu_first_fea_kwd", + "degree_kwd", + "major_tks", + "school_name_tks", + "sch_rank_kwd", + "edu_fea_kwd", + "edu_end_int", } # Skills & certificates field set: skills, languages, certificates are small, merge into one chunk _SKILL_CERT_FIELDS = { - "skill_tks", "language_tks", "certificate_tks", + "skill_tks", + "language_tks", + "certificate_tks", } # Work overview field set: company list, industry, most recent company merged into one chunk _WORK_OVERVIEW_FIELDS = { - "corporation_name_tks", "corp_nm_tks", "industry_name_tks", + "corporation_name_tks", + "corp_nm_tks", + "industry_name_tks", } # All merge groups: (field_set, group_title) tuple list @@ -2239,9 +2257,7 @@ def _build_chunk_document(filename: str, resume: dict, chunk = { "content_with_weight": content, "content_ltks": rag_tokenizer.tokenize(content), - "content_sm_ltks": rag_tokenizer.fine_grained_tokenize( - rag_tokenizer.tokenize(content) - ), + "content_sm_ltks": rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(content)), } chunk.update(doc) # Redundantly write identity fields @@ -2327,9 +2343,7 @@ def _build_chunk_document(filename: str, resume: dict, chunk = { "content_with_weight": content, "content_ltks": rag_tokenizer.tokenize(content), - "content_sm_ltks": rag_tokenizer.fine_grained_tokenize( - rag_tokenizer.tokenize(content) - ), + "content_sm_ltks": rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(content)), } chunk.update(doc) @@ -2361,9 +2375,7 @@ def _build_chunk_document(filename: str, resume: dict, chunk = { "content_with_weight": content, "content_ltks": rag_tokenizer.tokenize(content), - "content_sm_ltks": rag_tokenizer.fine_grained_tokenize( - rag_tokenizer.tokenize(content) - ), + "content_sm_ltks": rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(content)), } chunk.update(doc) @@ -2402,9 +2414,7 @@ def _build_chunk_document(filename: str, resume: dict, chunk = { "content_with_weight": content, "content_ltks": rag_tokenizer.tokenize(content), - "content_sm_ltks": rag_tokenizer.fine_grained_tokenize( - rag_tokenizer.tokenize(content) - ), + "content_sm_ltks": rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(content)), } chunk.update(doc) chunks.append(chunk) @@ -2429,8 +2439,8 @@ def _build_chunk_document(filename: str, resume: dict, return chunks -def _blackout_text_regions(image: "np.ndarray", meta_blocks: list[dict], page_idx: int, - pdf_to_img_scale: float) -> "np.ndarray": + +def _blackout_text_regions(image: "np.ndarray", meta_blocks: list[dict], page_idx: int, pdf_to_img_scale: float) -> "np.ndarray": """ Black out metadata-extracted text regions on the page image to prevent OCR duplication. @@ -2447,6 +2457,7 @@ def _blackout_text_regions(image: "np.ndarray", meta_blocks: list[dict], page_id Image with text regions blacked out """ import cv2 + blacked = image.copy() page_blocks = [b for b in meta_blocks if b.get("page") == page_idx] # Draw filled black rectangles over each metadata text block @@ -2465,9 +2476,7 @@ def _blackout_text_regions(image: "np.ndarray", meta_blocks: list[dict], page_id return blacked - -def chunk(filename, binary, tenant_id, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, - lang="Chinese", callback=None, **kwargs): +def chunk(filename, binary, tenant_id, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, lang="Chinese", callback=None, **kwargs): """ Resume parsing entry function (compatible with task_executor.py) @@ -2486,7 +2495,9 @@ def chunk(filename, binary, tenant_id, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, Document chunk list """ if callback is None: - def callback(prog, msg): return None + + def callback(prog, msg): + return None if settings.DOC_ENGINE.lower() != "elasticsearch": raise Exception("Resume is supported only with Elasticsearch.") @@ -2495,7 +2506,7 @@ def chunk(filename, binary, tenant_id, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, callback(0.1, "Starting resume parsing...") # Parse resume - resume, lines, line_positions = parse_resume(filename, binary, tenant_id , lang) + resume, lines, line_positions = parse_resume(filename, binary, tenant_id, lang) callback(0.6, "Resume structured extraction complete") # Build document chunks (with coordinate info) @@ -2516,10 +2527,13 @@ def _resort_page_with_layout(page_blocks: list[dict], layout_regions: list[dict] return [] if not layout_regions: - return sorted(page_blocks, key=lambda b: ( - (b.get("top", 0) + b.get("bottom", 0)) / 2, - (b.get("x0", 0) + b.get("x1", 0)) / 2, - )) + return sorted( + page_blocks, + key=lambda b: ( + (b.get("top", 0) + b.get("bottom", 0)) / 2, + (b.get("x0", 0) + b.get("x1", 0)) / 2, + ), + ) type_groups: dict[str, list] = {} for lt in layout_regions: @@ -2531,11 +2545,18 @@ def _resort_page_with_layout(page_blocks: list[dict], layout_regions: list[dict] key = f"{tp}-{idx}" x0, x1 = lt.get("x0", 0), lt.get("x1", 0) top, bottom = lt.get("top", 0), lt.get("bottom", 0) - entries.append({ - "key": key, "type": tp, - "x0": x0, "top": top, "x1": x1, "bottom": bottom, - "cy": (top + bottom) / 2, "cx": (x0 + x1) / 2, - }) + entries.append( + { + "key": key, + "type": tp, + "x0": x0, + "top": top, + "x1": x1, + "bottom": bottom, + "cy": (top + bottom) / 2, + "cx": (x0 + x1) / 2, + } + ) for b in page_blocks: if b.get("layoutno"): @@ -2543,8 +2564,7 @@ def _resort_page_with_layout(page_blocks: list[dict], layout_regions: list[dict] b_cx = (b.get("x0", 0) + b.get("x1", 0)) / 2 b_cy = (b.get("top", 0) + b.get("bottom", 0)) / 2 for entry in entries: - if (entry["x0"] <= b_cx <= entry["x1"] - and entry["top"] <= b_cy <= entry["bottom"]): + if entry["x0"] <= b_cx <= entry["x1"] and entry["top"] <= b_cy <= entry["bottom"]: b["layoutno"] = entry["key"] b["layout_type"] = entry["type"] break @@ -2557,10 +2577,7 @@ def _resort_page_with_layout(page_blocks: list[dict], layout_regions: list[dict] layout_blocks = [b for b in page_blocks if b.get("layoutno") == layout_key] if not layout_blocks: continue - text_total_area = sum( - (b.get("x1", 0) - b.get("x0", 0)) * (b.get("bottom", 0) - b.get("top", 0)) - for b in layout_blocks - ) + text_total_area = sum((b.get("x1", 0) - b.get("x0", 0)) * (b.get("bottom", 0) - b.get("top", 0)) for b in layout_blocks) if text_total_area / layout_area < 0.075: for b in layout_blocks: b["layoutno"] = "" @@ -2605,19 +2622,22 @@ def _resort_page_with_layout(page_blocks: list[dict], layout_regions: list[dict] dx = b_cx - lx2 else: dx = 0 - dist = (dx ** 2 + dy ** 2) ** 0.5 + dist = (dx**2 + dy**2) ** 0.5 if dist < min_dist: min_dist = dist best_cx, best_cy = ae["cx"], ae["cy"] b["_lx_center"] = best_cx b["_ly_center"] = best_cy - sorted_blocks = sorted(page_blocks, key=lambda b: ( - b.get("_ly_center", 0), - b.get("_lx_center", 0), - b.get("_y_center", 0), - b.get("_x_center", 0), - )) + sorted_blocks = sorted( + page_blocks, + key=lambda b: ( + b.get("_ly_center", 0), + b.get("_lx_center", 0), + b.get("_y_center", 0), + b.get("_x_center", 0), + ), + ) for b in sorted_blocks: b.pop("_ly_center", None) @@ -2639,6 +2659,7 @@ def _layout_detect_reorder(blocks: list[dict], binary: bytes) -> list[dict]: try: import pdfplumber + pages_blocks: dict[int, list[dict]] = {} for b in blocks: pg = b.get("page", 0) @@ -2658,22 +2679,22 @@ def _layout_detect_reorder(blocks: list[dict], binary: bytes) -> list[dict]: page_bxs = [] for b in pages_blocks[pg]: - page_bxs.append({ - "x0": float(b["x0"]), - "top": float(b["top"]), - "x1": float(b["x1"]), - "bottom": float(b["bottom"]), - "text": b["text"], - "page": pg, - }) + page_bxs.append( + { + "x0": float(b["x0"]), + "top": float(b["top"]), + "x1": float(b["x1"]), + "bottom": float(b["bottom"]), + "text": b["text"], + "page": pg, + } + ) ocr_res_per_page.append(page_bxs) if not image_list: return _layout_aware_reorder(blocks) - tagged_blocks, page_layouts = recognizer( - image_list, ocr_res_per_page, scale_factor=3, thr=0.2, drop=False - ) + tagged_blocks, page_layouts = recognizer(image_list, ocr_res_per_page, scale_factor=3, thr=0.2, drop=False) if not tagged_blocks: logger.warning("Layout detector unavailable, falling back to heuristic sorting") @@ -2697,8 +2718,7 @@ def _layout_detect_reorder(blocks: list[dict], binary: bytes) -> list[dict]: if "page" not in b: b["page"] = 0 - logger.info(f"YOLOv10 detector completed, {len(sorted_all)} total chunks," - f"checked {total_layout_count} layout") + logger.info(f"YOLOv10 detector completed, {len(sorted_all)} total chunks,checked {total_layout_count} layout") return sorted_all except Exception as e: @@ -2706,7 +2726,6 @@ def _layout_detect_reorder(blocks: list[dict], binary: bytes) -> list[dict]: return _layout_aware_reorder(blocks) - def _text_shingles(text: str, n: int = 5) -> set[tuple[int, ...]]: """ Generate text fingerprint set using tiktoken BPE tokenization + n-gram shingling. @@ -2726,7 +2745,7 @@ def _text_shingles(text: str, n: int = 5) -> set[tuple[int, ...]]: if len(tokens) < n: # Text too short: return the entire token sequence as a single shingle return {tuple(tokens)} if tokens else set() - return {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)} + return {tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)} def _shingling_jaccard(text1: str, text2: str, n: int = 5) -> float: diff --git a/rag/app/table.py b/rag/app/table.py index 5f4fabd527..fd41e5e273 100644 --- a/rag/app/table.py +++ b/rag/app/table.py @@ -38,6 +38,7 @@ from common import settings logger = logging.getLogger(__name__) + class Excel(ExcelParser): def __call__(self, fnm, binary=None, from_page=0, to_page=MAXIMUM_TASK_PAGE_NUMBER, callback=None, **kwargs): if not binary: @@ -56,8 +57,7 @@ class Excel(ExcelParser): images = Excel._extract_images_from_worksheet(ws, sheetname=sheet_name) pending_cell_images = [] if images: - image_descriptions = vision_figure_parser_figure_xlsx_wrapper(images=images, callback=callback, - **kwargs) + image_descriptions = vision_figure_parser_figure_xlsx_wrapper(images=images, callback=callback, **kwargs) if image_descriptions and len(image_descriptions) == len(images): for i, bf in enumerate(image_descriptions): images[i]["image_description"] = "\n".join(bf[0][1]) @@ -118,15 +118,14 @@ class Excel(ExcelParser): ( ( img["image"], # Image.Image or LazyImage - [img["image_description"]] # description list (must be list) + [img["image_description"]], # description list (must be list) ), [ (0, 0, 0, 0, 0) # dummy position - ] + ], ) ) - callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) return res, tables def _parse_headers(self, ws, rows): @@ -321,15 +320,14 @@ def trans_bool(s): def column_data_type(arr): arr = list(arr) counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0} - trans = {t: f for f, t in - [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]} + trans = {t: f for f, t in [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]} float_flag = False for a in arr: if a is None: continue if re.match(r"[+-]?[0-9]+$", str(a).replace("%%", "")) and not str(a).replace("%%", "").startswith("0"): counts["int"] += 1 - if int(str(a)) > 2 ** 63 - 1: + if int(str(a)) > 2**63 - 1: float_flag = True break elif re.match(r"[+-]?[0-9.]{,19}$", str(a).replace("%%", "")) and not str(a).replace("%%", "").startswith("0"): @@ -402,8 +400,7 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_TASK_PAGE_NUMBER, continue rows.append(row) - callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) dfs = [pd.DataFrame(np.array(rows), columns=headers)] elif re.search(r"\.csv$", filename, re.IGNORECASE): @@ -420,17 +417,13 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_TASK_PAGE_NUMBER, fails = [] rows = [] - for i, row in enumerate(all_rows[1 + from_page: 1 + to_page]): + for i, row in enumerate(all_rows[1 + from_page : 1 + to_page]): if len(row) != len(headers): fails.append(str(i + from_page)) continue rows.append(row) - callback( - 0.3, - (f"Extract records: {from_page}~{from_page + len(rows)}" + - (f"{len(fails)} failure, line: {','.join(fails[:3])}..." if fails else "")) - ) + callback(0.3, (f"Extract records: {from_page}~{from_page + len(rows)}" + (f"{len(fails)} failure, line: {','.join(fails[:3])}..." if fails else ""))) dfs = [pd.DataFrame(rows, columns=headers)] else: @@ -446,10 +439,7 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_TASK_PAGE_NUMBER, column_roles = parser_config.get("table_column_roles") or {} else: column_roles = {} - logger.debug( - f"[TABLE_PARSER_DEBUG] effective table_column_mode={parser_config.get('table_column_mode')!r}, " - f"column_roles keys={list(column_roles.keys())}" - ) + logger.debug(f"[TABLE_PARSER_DEBUG] effective table_column_mode={parser_config.get('table_column_mode')!r}, column_roles keys={list(column_roles.keys())}") # Pass 1: infer columns per sheet (multi-sheet Excel => multiple DataFrames). Merge field_map and # table_column_names, then update KB once so the UI role selector sees all columns, not only the last sheet. @@ -474,23 +464,13 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_TASK_PAGE_NUMBER, df[clmns[j]] = cln if ty == "text": txts.extend([str(c) for c in cln if c]) - clmns_map = [(py_clmns[i].lower() + fields_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in - range(len(clmns))] + clmns_map = [(py_clmns[i].lower() + fields_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in range(len(clmns))] # field_map: only columns stored in chunk_data (metadata or both) — used for retrieval/SQL - stored_indices = [ - i for i in range(len(clmns)) - if column_roles.get(clmns[i], "both") in ("metadata", "both") - ] + stored_indices = [i for i in range(len(clmns)) if column_roles.get(clmns[i], "both") in ("metadata", "both")] if settings.DOC_ENGINE_INFINITY or settings.DOC_ENGINE_OCEANBASE: - field_map = { - py_clmns[i].lower(): str(clmns[i]).replace("_", " ") - for i in stored_indices - } + field_map = {py_clmns[i].lower(): str(clmns[i]).replace("_", " ") for i in stored_indices} else: - field_map = { - clmns_map[i][0]: clmns_map[i][1] - for i in stored_indices - } + field_map = {clmns_map[i][0]: clmns_map[i][1] for i in stored_indices} logging.debug(f"Field map (sheet): {field_map}") sheet_specs.append( { @@ -570,13 +550,9 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_TASK_PAGE_NUMBER, formatted_text = "\n".join([f"- {field}: {value}" for field, value in text_fields]) if text_fields else "" tokenize(d, formatted_text, eng) if _debug_row_idx == 1: - logger.debug( - f"[TABLE_PARSER_DEBUG] Chunk content_with_weight length: {len(d.get('content_with_weight', '') or '')}" - ) + logger.debug(f"[TABLE_PARSER_DEBUG] Chunk content_with_weight length: {len(d.get('content_with_weight', '') or '')}") _cd = d.get("chunk_data") - logger.debug( - f"[TABLE_PARSER_DEBUG] Chunk chunk_data keys: {list(_cd.keys()) if isinstance(_cd, dict) else 'N/A'}" - ) + logger.debug(f"[TABLE_PARSER_DEBUG] Chunk chunk_data keys: {list(_cd.keys()) if isinstance(_cd, dict) else 'N/A'}") if not (settings.DOC_ENGINE_INFINITY or settings.DOC_ENGINE_OCEANBASE): _extra = [k for k in d if k not in ("docnm_kwd", "title_tks", "content_with_weight", "content_ltks", "content_sm_ltks")] logger.debug(f"[TABLE_PARSER_DEBUG] Chunk ES extra field keys (sample): {_extra[:20]}") @@ -592,9 +568,7 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_TASK_PAGE_NUMBER, if __name__ == "__main__": import sys - def dummy(prog=None, msg=""): pass - chunk(sys.argv[1], callback=dummy) diff --git a/rag/app/tag.py b/rag/app/tag.py index d43692f612..43d50b8abc 100644 --- a/rag/app/tag.py +++ b/rag/app/tag.py @@ -36,22 +36,19 @@ def beAdoc(d, q, a, eng, row_num=-1): def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): """ - Excel and csv(txt) format files are supported. - If the file is in Excel format, there should be 2 column content and tags without header. - And content column is ahead of tags column. - And it's O.K if it has multiple sheets as long as the columns are rightly composed. + Excel and csv(txt) format files are supported. + If the file is in Excel format, there should be 2 column content and tags without header. + And content column is ahead of tags column. + And it's O.K if it has multiple sheets as long as the columns are rightly composed. - If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate content and tags. + If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate content and tags. - All the deformed lines will be ignored. - Every pair will be treated as a chunk. + All the deformed lines will be ignored. + Every pair will be treated as a chunk. """ eng = lang.lower() == "english" res = [] - doc = { - "docnm_kwd": filename, - "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)) - } + doc = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))} if re.search(r"\.xlsx?$", filename, re.IGNORECASE): callback(0.1, "Start to parse.") excel_parser = Excel() @@ -84,11 +81,9 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): content = "" i += 1 if len(res) % 999 == 0: - callback(len(res) * 0.6 / len(lines), ("Extract TAG: {}".format(len(res)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + callback(len(res) * 0.6 / len(lines), ("Extract TAG: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) - callback(0.6, ("Extract TAG: {}".format(len(res)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + callback(0.6, ("Extract TAG: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) return res @@ -111,20 +106,18 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): res.append(beAdoc(deepcopy(doc), content, row[1], eng, i)) content = "" if len(res) % 999 == 0: - callback(len(res) * 0.6 / len(lines), ("Extract Tags: {}".format(len(res)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + callback(len(res) * 0.6 / len(lines), ("Extract Tags: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) - callback(0.6, ("Extract TAG : {}".format(len(res)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + callback(0.6, ("Extract TAG : {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) return res - raise NotImplementedError( - "Excel, csv(txt) format files are supported.") + raise NotImplementedError("Excel, csv(txt) format files are supported.") def label_question(question, kbs): from api.db.services.knowledgebase_service import KnowledgebaseService from rag.graphrag.utils import get_tags_from_cache, set_tags_to_cache + tags = None tag_kb_ids = [] for kb in kbs: @@ -140,21 +133,14 @@ def label_question(question, kbs): tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids) if not tag_kbs: return tags - tags = settings.retriever.tag_query(question, - list(set([kb.tenant_id for kb in tag_kbs])), - tag_kb_ids, - all_tags, - kb.parser_config.get("topn_tags", 3) - ) + tags = settings.retriever.tag_query(question, list(set([kb.tenant_id for kb in tag_kbs])), tag_kb_ids, all_tags, kb.parser_config.get("topn_tags", 3)) return tags if __name__ == "__main__": import sys - def dummy(prog=None, msg=""): pass - chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) diff --git a/rag/benchmark.py b/rag/benchmark.py index 9b069c1d3a..7866ab6930 100644 --- a/rag/benchmark.py +++ b/rag/benchmark.py @@ -45,8 +45,8 @@ class Benchmark: self.vector_similarity_weight = self.kb.vector_similarity_weight embd_model_config = get_model_config_from_provider_instance(self.kb.tenant_id, LLMType.EMBEDDING, self.kb.embd_id) self.embd_mdl = LLMBundle(self.kb.tenant_id, embd_model_config, lang=self.kb.language) - self.tenant_id = '' - self.index_name = '' + self.tenant_id = "" + self.index_name = "" self.initialized_index = False def _get_retrieval(self, qrels): @@ -55,8 +55,7 @@ class Benchmark: run = defaultdict(dict) query_list = list(qrels.keys()) for query in query_list: - ranks = asyncio.run(settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30, - 0.0, self.vector_similarity_weight)) + ranks = asyncio.run(settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30, 0.0, self.vector_similarity_weight)) if len(ranks["chunks"]) == 0: print(f"deleted query: {query}") del qrels[query] @@ -101,14 +100,9 @@ class Benchmark: for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + fn): if docs_count >= max_docs: break - query = data.iloc[i]['query'] - for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']): - d = { - "id": get_uuid(), - "kb_id": self.kb.id, - "docnm_kwd": "xxxxx", - "doc_id": "ksksks" - } + query = data.iloc[i]["query"] + for rel, text in zip(data.iloc[i]["passages"]["is_selected"], data.iloc[i]["passages"]["passage_text"]): + d = {"id": get_uuid(), "kb_id": self.kb.id, "docnm_kwd": "xxxxx", "doc_id": "ksksks"} tokenize(d, text, "english") docs.append(d) texts[d["id"]] = text @@ -141,15 +135,9 @@ class Benchmark: for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + fn): if docs_count >= max_docs: break - query = data.iloc[i]['question'] - for rel, text in zip(data.iloc[i]["search_results"]['rank'], - data.iloc[i]["search_results"]['search_context']): - d = { - "id": get_uuid(), - "kb_id": self.kb.id, - "docnm_kwd": "xxxxx", - "doc_id": "ksksks" - } + query = data.iloc[i]["question"] + for rel, text in zip(data.iloc[i]["search_results"]["rank"], data.iloc[i]["search_results"]["search_context"]): + d = {"id": get_uuid(), "kb_id": self.kb.id, "docnm_kwd": "xxxxx", "doc_id": "ksksks"} tokenize(d, text, "english") docs.append(d) texts[d["id"]] = text @@ -158,7 +146,7 @@ class Benchmark: docs_count += len(docs) docs, vector_size = self.embedding(docs) self.init_index(vector_size) - settings.docStoreConn.insert(docs,self.index_name) + settings.docStoreConn.insert(docs, self.index_name) docs = [] docs, vector_size = self.embedding(docs) @@ -171,41 +159,35 @@ class Benchmark: for corpus_file in os.listdir(corpus_path): tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True) for index, i in tmp_data.iterrows(): - corpus_total[i['docid']] = i['text'] + corpus_total[i["docid"]] = i["text"] topics_total = {} - for topics_file in os.listdir(os.path.join(file_path, 'topics')): - if 'test' in topics_file: + for topics_file in os.listdir(os.path.join(file_path, "topics")): + if "test" in topics_file: continue - tmp_data = pd.read_csv(os.path.join(file_path, 'topics', topics_file), sep='\t', names=['qid', 'query']) + tmp_data = pd.read_csv(os.path.join(file_path, "topics", topics_file), sep="\t", names=["qid", "query"]) for index, i in tmp_data.iterrows(): - topics_total[i['qid']] = i['query'] + topics_total[i["qid"]] = i["query"] qrels = defaultdict(dict) texts = defaultdict(dict) docs_count = 0 docs = [] - for qrels_file in os.listdir(os.path.join(file_path, 'qrels')): - if 'test' in qrels_file: + for qrels_file in os.listdir(os.path.join(file_path, "qrels")): + if "test" in qrels_file: continue if docs_count >= max_docs: break - tmp_data = pd.read_csv(os.path.join(file_path, 'qrels', qrels_file), sep='\t', - names=['qid', 'Q0', 'docid', 'relevance']) + tmp_data = pd.read_csv(os.path.join(file_path, "qrels", qrels_file), sep="\t", names=["qid", "Q0", "docid", "relevance"]) for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file): if docs_count >= max_docs: break - query = topics_total[tmp_data.iloc[i]['qid']] - text = corpus_total[tmp_data.iloc[i]['docid']] - rel = tmp_data.iloc[i]['relevance'] - d = { - "id": get_uuid(), - "kb_id": self.kb.id, - "docnm_kwd": "xxxxx", - "doc_id": "ksksks" - } - tokenize(d, text, 'english') + query = topics_total[tmp_data.iloc[i]["qid"]] + text = corpus_total[tmp_data.iloc[i]["docid"]] + rel = tmp_data.iloc[i]["relevance"] + d = {"id": get_uuid(), "kb_id": self.kb.id, "docnm_kwd": "xxxxx", "doc_id": "ksksks"} + tokenize(d, text, "english") docs.append(d) texts[d["id"]] = text qrels[query][d["id"]] = int(rel) @@ -226,22 +208,21 @@ class Benchmark: run_keys = list(run.keys()) for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"): key = run_keys[run_i] - keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key], - 'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")}) - keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10']) - with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f: - f.write('## Score For Every Query\n') + keep_result.append({"query": key, "qrel": qrels[key], "run": run[key], "ndcg@10": evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")}) + keep_result = sorted(keep_result, key=lambda kk: kk["ndcg@10"]) + with open(os.path.join(file_path, dataset + "result.md"), "w", encoding="utf-8") as f: + f.write("## Score For Every Query\n") for keep_result_i in keep_result: - f.write('### query: ' + keep_result_i['query'] + ' ndcg@10:' + str(keep_result_i['ndcg@10']) + '\n') - scores = [[i[0], i[1]] for i in keep_result_i['run'].items()] + f.write("### query: " + keep_result_i["query"] + " ndcg@10:" + str(keep_result_i["ndcg@10"]) + "\n") + scores = [[i[0], i[1]] for i in keep_result_i["run"].items()] scores = sorted(scores, key=lambda kk: kk[1]) for score in scores[:10]: - f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n') - json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+", encoding='utf-8'), indent=2) - json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+", encoding='utf-8'), indent=2) - print(os.path.join(file_path, dataset + '_result.md'), 'Saved!') + f.write("- text: " + str(texts[score[0]]) + "\t qrel: " + str(score[1]) + "\n") + json.dump(qrels, open(os.path.join(file_path, dataset + ".qrels.json"), "w+", encoding="utf-8"), indent=2) + json.dump(run, open(os.path.join(file_path, dataset + ".run.json"), "w+", encoding="utf-8"), indent=2) + print(os.path.join(file_path, dataset + "_result.md"), "Saved!") - def __call__(self, dataset, file_path, miracl_corpus=''): + def __call__(self, dataset, file_path, miracl_corpus=""): if dataset == "ms_marco_v1.1": self.tenant_id = "benchmark_ms_marco_v11" self.index_name = search.index_name(self.tenant_id) @@ -257,39 +238,40 @@ class Benchmark: print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"])) self.save_results(qrels, run, texts, dataset, file_path) if dataset == "miracl": - for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th', - 'yo', 'zh']: - if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang)): - print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang) + ' not found!') + for lang in ["ar", "bn", "de", "en", "es", "fa", "fi", "fr", "hi", "id", "ja", "ko", "ru", "sw", "te", "th", "yo", "zh"]: + if not os.path.isdir(os.path.join(file_path, "miracl-v1.0-" + lang)): + print("Directory: " + os.path.join(file_path, "miracl-v1.0-" + lang) + " not found!") continue - if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels')): - print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels') + 'not found!') + if not os.path.isdir(os.path.join(file_path, "miracl-v1.0-" + lang, "qrels")): + print("Directory: " + os.path.join(file_path, "miracl-v1.0-" + lang, "qrels") + "not found!") continue - if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics')): - print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics') + 'not found!') + if not os.path.isdir(os.path.join(file_path, "miracl-v1.0-" + lang, "topics")): + print("Directory: " + os.path.join(file_path, "miracl-v1.0-" + lang, "topics") + "not found!") continue - if not os.path.isdir(os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang)): - print('Directory: ' + os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang) + ' not found!') + if not os.path.isdir(os.path.join(miracl_corpus, "miracl-corpus-v1.0-" + lang)): + print("Directory: " + os.path.join(miracl_corpus, "miracl-corpus-v1.0-" + lang) + " not found!") continue self.tenant_id = "benchmark_miracl_" + lang self.index_name = search.index_name(self.tenant_id) self.initialized_index = False - qrels, texts = self.miracl_index(os.path.join(file_path, 'miracl-v1.0-' + lang), - os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang), - "benchmark_miracl_" + lang) + qrels, texts = self.miracl_index(os.path.join(file_path, "miracl-v1.0-" + lang), os.path.join(miracl_corpus, "miracl-corpus-v1.0-" + lang), "benchmark_miracl_" + lang) run = self._get_retrieval(qrels) print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"])) self.save_results(qrels, run, texts, dataset, file_path) -if __name__ == '__main__': - print('*****************RAGFlow Benchmark*****************') - parser = argparse.ArgumentParser(usage="benchmark.py [])", description='RAGFlow Benchmark') - parser.add_argument('max_docs', metavar='max_docs', type=int, help='max docs to evaluate') - parser.add_argument('kb_id', metavar='kb_id', help='dataset id') - parser.add_argument('dataset', metavar='dataset', help='dataset name, shall be one of ms_marco_v1.1(https://huggingface.co/datasets/microsoft/ms_marco), trivia_qa(https://huggingface.co/datasets/mandarjoshi/trivia_qa>), miracl(https://huggingface.co/datasets/miracl/miracl') - parser.add_argument('dataset_path', metavar='dataset_path', help='dataset path') - parser.add_argument('miracl_corpus_path', metavar='miracl_corpus_path', nargs='?', default="", help='miracl corpus path. Only needed when dataset is miracl') +if __name__ == "__main__": + print("*****************RAGFlow Benchmark*****************") + parser = argparse.ArgumentParser(usage="benchmark.py [])", description="RAGFlow Benchmark") + parser.add_argument("max_docs", metavar="max_docs", type=int, help="max docs to evaluate") + parser.add_argument("kb_id", metavar="kb_id", help="dataset id") + parser.add_argument( + "dataset", + metavar="dataset", + help="dataset name, shall be one of ms_marco_v1.1(https://huggingface.co/datasets/microsoft/ms_marco), trivia_qa(https://huggingface.co/datasets/mandarjoshi/trivia_qa>), miracl(https://huggingface.co/datasets/miracl/miracl", + ) + parser.add_argument("dataset_path", metavar="dataset_path", help="dataset path") + parser.add_argument("miracl_corpus_path", metavar="miracl_corpus_path", nargs="?", default="", help="miracl corpus path. Only needed when dataset is miracl") args = parser.parse_args() max_docs = args.max_docs @@ -303,7 +285,7 @@ if __name__ == '__main__': ex(dataset, dataset_path) elif dataset == "miracl": if len(args) < 5: - print('Please input the correct parameters!') + print("Please input the correct parameters!") exit(1) miracl_corpus_path = args[4] ex(dataset, dataset_path, miracl_corpus=args.miracl_corpus_path) diff --git a/rag/flow/base.py b/rag/flow/base.py index 03005dc034..2581cc9656 100644 --- a/rag/flow/base.py +++ b/rag/flow/base.py @@ -43,10 +43,7 @@ class ProcessBase(ComponentBase): for k, v in kwargs.items(): self.set_output(k, v) try: - await asyncio.wait_for( - self._invoke(**kwargs), - timeout=self._param.timeout - ) + await asyncio.wait_for(self._invoke(**kwargs), timeout=self._param.timeout) self.callback(1, "Done") except Exception as e: if self.get_exception_default_value(): diff --git a/rag/flow/chunker/title_chunker/common.py b/rag/flow/chunker/title_chunker/common.py index 59aad31752..0848ac0ca7 100644 --- a/rag/flow/chunker/title_chunker/common.py +++ b/rag/flow/chunker/title_chunker/common.py @@ -61,7 +61,6 @@ class BaseTitleChunker(ABC): self.param = process._param self.from_upstream = from_upstream - async def invoke(self): self.process.set_output("output_format", "chunks") self.process.callback(random.randint(1, 5) / 100.0, self.start_message) @@ -71,7 +70,6 @@ class BaseTitleChunker(ABC): await self.set_chunks(chunks) self.process.callback(1, "Done.") - def extract_line_records(self): """ Normalize all upstream input payloads into a unified ordered record stream. @@ -79,8 +77,9 @@ class BaseTitleChunker(ABC): decoupling downstream chunking strategies from different upstream output formats. """ import logging + logger = logging.getLogger(__name__) - + payload = None # Extract raw content payload based on upstream output format type if self.from_upstream.output_format == "markdown": @@ -95,7 +94,7 @@ class BaseTitleChunker(ABC): if payload is not None: lines = payload.split("\n") input_line_count = len(lines) - + # Format-branched text processing to preserve original document semantics # Plain text: perform full whitespace stripping and invalid empty line filtering if self.from_upstream.output_format == "text": @@ -103,23 +102,12 @@ class BaseTitleChunker(ABC): # Markdown & HTML: retain original indentation/spacing, only filter pure blank lines else: clean_lines = [line for line in lines if line.strip()] - + output_line_count = len(clean_lines) # Production observability log: added format dimension per project coding guidelines - logger.info( - f"payload filter: format={self.from_upstream.output_format} before={input_line_count} after={output_line_count}" - ) - - return [ - { - "text": line, - "doc_type_kwd": "text", - "img_id": None, - "layout": "", - PDF_POSITIONS_KEY: [] - } - for line in clean_lines - ] + logger.info(f"payload filter: format={self.from_upstream.output_format} before={input_line_count} after={output_line_count}") + + return [{"text": line, "doc_type_kwd": "text", "img_id": None, "layout": "", PDF_POSITIONS_KEY: []} for line in clean_lines] items = self.from_upstream.chunks if self.from_upstream.output_format == "chunks" else self.from_upstream.json_result return [ { @@ -134,17 +122,11 @@ class BaseTitleChunker(ABC): def extract_outlines(self): file = self.from_upstream.file or {} - source = ( - file.get("blob") - or file.get("binary") - or file.get("path") - or file.get("name") - ) + source = file.get("blob") or file.get("binary") or file.get("path") or file.get("name") if not source: return [] return extract_pdf_outlines(source) - @staticmethod def match_regex_level(text, level_group): stripped = text.strip() @@ -153,7 +135,6 @@ class BaseTitleChunker(ABC): return level return None - @staticmethod def select_level_group(lines, raw_levels): if not raw_levels: @@ -185,21 +166,18 @@ class BaseTitleChunker(ABC): return [] return [pattern for pattern in raw_levels[selected] if pattern] - @staticmethod def match_layout_level(text, layout, fallback_level): if re.search(r"(section|title|head)", layout, re.I) and not not_title(text.split("@")[0].strip()): return fallback_level return BODY_LEVEL - @staticmethod def _outline_similarity(left, right): left_pairs = {left[i] + left[i + 1] for i in range(len(left) - 1)} right_pairs = {right[i] + right[i + 1] for i in range(min(len(left), len(right) - 1))} return len(left_pairs & right_pairs) / max(len(left_pairs), len(right_pairs), 1) - def resolve_outline_levels(self, line_records): outlines = self.extract_outlines() if not line_records or len(outlines) / len(line_records) <= 0.03: @@ -225,7 +203,6 @@ class BaseTitleChunker(ABC): "source": "outline", } - def resolve_frequency_levels(self, line_records): level_group = self.select_level_group( [record["text"] for record in line_records], @@ -254,29 +231,23 @@ class BaseTitleChunker(ABC): if level < BODY_LEVEL: most_level = level break - + return { "levels": levels, "most_level": most_level, "source": "frequency", } - def resolve_title_levels(self, line_records): return self.resolve_outline_levels(line_records) or self.resolve_frequency_levels(line_records) - def build_chunks_from_record_groups(self, record_groups): # Strategy code decides record grouping. This method materializes each # group into the output chunk representation. For PDF-like inputs, the # chunk box is defined by merged source positions and the text payload # is normalized by removing parser tags. if self.from_upstream.output_format in ["markdown", "text", "html"]: - chunks = [ - {"text": "".join(record["text"] + "\n" for record in records)} - for records in record_groups - if records - ] + chunks = [{"text": "".join(record["text"] + "\n" for record in records)} for records in record_groups if records] else: chunks = [ ( @@ -296,19 +267,18 @@ class BaseTitleChunker(ABC): for records in record_groups if records ] - + if self.param.root_chunk_as_heading and len(chunks) > 1: root_chunk = chunks[0] root_text = root_chunk.get("text", "") for ck in chunks[1:]: - ck['text'] = root_text + "\n" + ck.get("text", "") - + ck["text"] = root_text + "\n" + ck.get("text", "") + return chunks[1:] return chunks - async def set_chunks(self, chunks): if self.from_upstream.output_format in ["markdown", "text", "html"]: self.process.set_output("chunks", chunks) @@ -319,12 +289,10 @@ class BaseTitleChunker(ABC): await restore_pdf_text_previews(chunks, self.from_upstream, self.process._canvas) self.process.set_output("chunks", [finalize_pdf_chunk(deepcopy(chunk)) for chunk in chunks]) - @abstractmethod def resolve_levels(self, line_records): raise NotImplementedError() - @abstractmethod def build_chunks(self, line_records, resolved): raise NotImplementedError() diff --git a/rag/flow/chunker/title_chunker/group_chunker.py b/rag/flow/chunker/title_chunker/group_chunker.py index ca43a2d0be..bb4d5e9777 100644 --- a/rag/flow/chunker/title_chunker/group_chunker.py +++ b/rag/flow/chunker/title_chunker/group_chunker.py @@ -45,7 +45,6 @@ class GroupTitleChunker(BaseTitleChunker): def resolve_levels(self, line_records): return self.resolve_title_levels(line_records) - def build_chunks(self, line_records, resolved): target_level = _resolve_group_target_level( resolved["levels"], @@ -73,14 +72,7 @@ class GroupTitleChunker(BaseTitleChunker): continue token_count = num_tokens_from_string(text) - should_merge = ( - record_groups - and record_groups[-1][0]["doc_type_kwd"] == "text" - and ( - tk_cnt < MIN_GROUP_TOKENS - or (tk_cnt < MAX_GROUP_TOKENS and sec_id == last_sid) - ) - ) + should_merge = record_groups and record_groups[-1][0]["doc_type_kwd"] == "text" and (tk_cnt < MIN_GROUP_TOKENS or (tk_cnt < MAX_GROUP_TOKENS and sec_id == last_sid)) if should_merge: record_groups[-1].append(record) diff --git a/rag/flow/chunker/title_chunker/hierarchy_chunker.py b/rag/flow/chunker/title_chunker/hierarchy_chunker.py index 430bd2240f..4349904710 100644 --- a/rag/flow/chunker/title_chunker/hierarchy_chunker.py +++ b/rag/flow/chunker/title_chunker/hierarchy_chunker.py @@ -26,15 +26,12 @@ class _ChunkNode: self.body_indexes = body_indexes or [] self.children = [] - def add_child(self, child): self.children.append(child) - def add_body_index(self, index): self.body_indexes.append(index) - def build_tree(self, indexed_lines, depth): stack = [self] for level, index in indexed_lines: @@ -51,13 +48,11 @@ class _ChunkNode: return self - def get_paths(self, depth, include_heading_content): chunk_paths = [] self._dfs(chunk_paths, [], depth, include_heading_content) return chunk_paths - def _dfs(self, chunk_paths, titles, depth, include_heading_content): if self.level == 0 and self.body_indexes: chunk_paths.append(titles + self.body_indexes) @@ -70,11 +65,7 @@ class _ChunkNode: elif not self.children and 1 <= self.level <= depth: chunk_paths.append(path_titles) else: - path_titles = ( - titles + self.title_indexes + self.body_indexes - if 1 <= self.level <= depth - else titles - ) + path_titles = titles + self.title_indexes + self.body_indexes if 1 <= self.level <= depth else titles if not self.children and 1 <= self.level <= depth: chunk_paths.append(path_titles) @@ -89,7 +80,6 @@ class HierarchyTitleChunker(BaseTitleChunker): def resolve_levels(self, line_records): return self.resolve_title_levels(line_records) - def build_chunks(self, line_records, resolved): record_groups = [] text_records = [] diff --git a/rag/flow/chunker/title_chunker/title_chunker.py b/rag/flow/chunker/title_chunker/title_chunker.py index 7fc005b1df..7b21712ac8 100644 --- a/rag/flow/chunker/title_chunker/title_chunker.py +++ b/rag/flow/chunker/title_chunker/title_chunker.py @@ -18,6 +18,7 @@ from rag.flow.chunker.title_chunker.group_chunker import GroupTitleChunker from rag.flow.chunker.title_chunker.hierarchy_chunker import HierarchyTitleChunker from rag.flow.chunker.title_chunker.schema import TitleChunkerFromUpstream + class TitleChunker(ProcessBase): component_name = "TitleChunker" diff --git a/rag/flow/chunker/token_chunker.py b/rag/flow/chunker/token_chunker.py index 7df4b43005..d39a6583e4 100644 --- a/rag/flow/chunker/token_chunker.py +++ b/rag/flow/chunker/token_chunker.py @@ -292,6 +292,7 @@ def _split_chunk_docs_by_children(chunks, pattern): return docs + class TokenChunker(ProcessBase): component_name = "TokenChunker" @@ -316,11 +317,15 @@ class TokenChunker(ProcessBase): self.set_output("chunks", [{"text": payload}] if payload.strip() else []) self.callback(1, "Done.") return - cks = _split_text_by_pattern(payload, delimiter_pattern) if delimiter_pattern else naive_merge( - payload, - self._param.chunk_token_size, - "", - overlapped_percent, + cks = ( + _split_text_by_pattern(payload, delimiter_pattern) + if delimiter_pattern + else naive_merge( + payload, + self._param.chunk_token_size, + "", + overlapped_percent, + ) ) if custom_pattern: docs = [] diff --git a/rag/flow/file.py b/rag/flow/file.py index f35f3d2114..a280810273 100644 --- a/rag/flow/file.py +++ b/rag/flow/file.py @@ -38,13 +38,13 @@ class File(ProcessBase): self.set_output("_ERROR", f"Document({self._canvas._doc_id}) not found!") return - #b, n = File2DocumentService.get_storage_address(doc_id=self._canvas._doc_id) - #self.set_output("blob", STORAGE_IMPL.get(b, n)) + # b, n = File2DocumentService.get_storage_address(doc_id=self._canvas._doc_id) + # self.set_output("blob", STORAGE_IMPL.get(b, n)) self.set_output("name", doc.name) else: file = kwargs.get("file")[0] self.set_output("name", file["name"]) self.set_output("file", file) - #self.set_output("blob", FileService.get_blob(file["created_by"], file["id"])) + # self.set_output("blob", FileService.get_blob(file["created_by"], file["id"])) self.callback(1, "File fetched.") diff --git a/rag/flow/parser/pdf_chunk_metadata.py b/rag/flow/parser/pdf_chunk_metadata.py index 175ac3772e..62314f4291 100644 --- a/rag/flow/parser/pdf_chunk_metadata.py +++ b/rag/flow/parser/pdf_chunk_metadata.py @@ -51,15 +51,9 @@ def _extract_raw_positions(item): position_int = item.get("position_int") if isinstance(position_int, list): - return [ - list(pos) - for pos in position_int - if isinstance(pos, (list, tuple)) and len(pos) >= 5 - ] + return [list(pos) for pos in position_int if isinstance(pos, (list, tuple)) and len(pos) >= 5] - if item.get("page_number") is not None and all( - item.get(key) is not None for key in ["x0", "x1", "top", "bottom"] - ): + if item.get("page_number") is not None and all(item.get(key) is not None for key in ["x0", "x1", "top", "bottom"]): return [[item["page_number"], item["x0"], item["x1"], item["top"], item["bottom"]]] return [] @@ -90,9 +84,7 @@ def extract_pdf_positions(item): elif page_number <= 0: page_number += 1 - normalized_positions.append( - [page_number, float(pos[1]), float(pos[2]), float(pos[3]), float(pos[4])] - ) + normalized_positions.append([page_number, float(pos[1]), float(pos[2]), float(pos[3]), float(pos[4])]) except (TypeError, ValueError): continue @@ -120,12 +112,7 @@ def normalize_pdf_items_metadata(items): def reorder_multi_column_bboxes(pdf_parser, bboxes, zoom=PDF_MULTI_COLUMN_ZOOM): - text_boxes = [ - box - for box in bboxes - if box.get("layout_type") == "text" - and all(box.get(key) is not None for key in ["x0", "x1", "page_number"]) - ] + text_boxes = [box for box in bboxes if box.get("layout_type") == "text" and all(box.get(key) is not None for key in ["x0", "x1", "page_number"])] if not text_boxes or not pdf_parser.page_images: return bboxes @@ -211,10 +198,7 @@ def _fetch_source_blob(from_upstream, canvas): def _load_pdf_page_images(blob, zoom=PDF_PREVIEW_ZOOM): with sys.modules[LOCK_KEY_pdfplumber]: with pdfplumber.open(io.BytesIO(blob)) as pdf: - return [ - page.to_image(resolution=72 * zoom, antialias=True).annotated - for page in pdf.pages - ] + return [page.to_image(resolution=72 * zoom, antialias=True).annotated for page in pdf.pages] def _crop_pdf_preview(page_images, positions, zoom=PDF_PREVIEW_ZOOM): @@ -241,6 +225,7 @@ def _crop_pdf_preview(page_images, positions, zoom=PDF_PREVIEW_ZOOM): max_width = max(right - left for _, left, right, _, _ in normalized_positions) first_page, first_left, _, first_top, _ = normalized_positions[0] last_page, last_left, _, _, last_bottom = normalized_positions[-1] + def page_height(idx): return page_images[idx].size[1] / zoom @@ -253,12 +238,7 @@ def _crop_pdf_preview(page_images, positions, zoom=PDF_PREVIEW_ZOOM): max(first_top - PDF_PREVIEW_GAP, 0), ) ] - crop_positions.extend( - [ - ([page_idx], left, right, top, bottom) - for page_idx, left, right, top, bottom in normalized_positions - ] - ) + crop_positions.extend([([page_idx], left, right, top, bottom) for page_idx, left, right, top, bottom in normalized_positions]) crop_positions.append( ( [last_page], @@ -272,9 +252,7 @@ def _crop_pdf_preview(page_images, positions, zoom=PDF_PREVIEW_ZOOM): imgs = [] for idx, (pages, left, right, top, bottom) in enumerate(crop_positions): page_idx = pages[0] - effective_right = ( - left + max_width if idx in {0, len(crop_positions) - 1} else max(left + 10, right) - ) + effective_right = left + max_width if idx in {0, len(crop_positions) - 1} else max(left + 10, right) imgs.append( page_images[page_idx].crop( ( @@ -309,11 +287,7 @@ async def restore_pdf_text_previews(chunks, from_upstream, canvas): if not chunks or not str(from_upstream.name).lower().endswith(".pdf"): return - text_chunks = [ - chunk - for chunk in chunks - if chunk.get("doc_type_kwd", "text") == "text" and extract_pdf_positions(chunk) - ] + text_chunks = [chunk for chunk in chunks if chunk.get("doc_type_kwd", "text") == "text" and extract_pdf_positions(chunk)] if not text_chunks: return diff --git a/rag/flow/parser/utils.py b/rag/flow/parser/utils.py index 5246acd2fd..48df57ea3c 100644 --- a/rag/flow/parser/utils.py +++ b/rag/flow/parser/utils.py @@ -20,7 +20,8 @@ from bs4 import BeautifulSoup from docx import Document from api.db.services.llm_service import LLMBundle from api.db.joint_services.tenant_model_service import ( - get_tenant_default_model_by_type, get_model_config_from_provider_instance, + get_tenant_default_model_by_type, + get_model_config_from_provider_instance, ) from common.constants import LLMType from deepdoc.parser.figure_parser import VisionFigureParser @@ -68,10 +69,7 @@ def remove_header_footer_docx_sections(items, header_footer_texts): def remove_header_footer_html_blob(blob): soup = BeautifulSoup(blob, "html.parser") - for element in soup.find_all( - lambda tag: tag.name in {"header", "footer"} - or tag.get("role") in {"banner", "contentinfo"} - ): + for element in soup.find_all(lambda tag: tag.name in {"header", "footer"} or tag.get("role") in {"banner", "contentinfo"}): element.decompose() return str(soup).encode("utf-8") @@ -100,7 +98,7 @@ def remove_toc_pdf(items, outlines): for i, (title, level, page_no) in enumerate(outlines): if re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$", title.split("@@")[0].strip().lower()): toc_start_page = page_no - for next_title, next_level, next_page_no in outlines[i + 1:]: + for next_title, next_level, next_page_no in outlines[i + 1 :]: if next_level != level: continue if re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$", next_title.split("@@")[0].strip().lower()): @@ -172,13 +170,9 @@ def enhance_media_sections_with_vision( try: try: - vision_model_config = get_model_config_from_provider_instance( - tenant_id, LLMType.IMAGE2TEXT, vlm_conf["llm_id"] - ) + vision_model_config = get_model_config_from_provider_instance(tenant_id, LLMType.IMAGE2TEXT, vlm_conf["llm_id"]) except Exception: - vision_model_config = get_tenant_default_model_by_type( - tenant_id, LLMType.IMAGE2TEXT - ) + vision_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.IMAGE2TEXT) vision_model = LLMBundle(tenant_id, vision_model_config) except Exception: return sections @@ -188,7 +182,7 @@ def enhance_media_sections_with_vision( continue if item.get("image") is None: continue - + text = item.get("text") or "" try: parsed = VisionFigureParser( diff --git a/rag/flow/pipeline.py b/rag/flow/pipeline.py index 76e19084e0..75951e4bec 100644 --- a/rag/flow/pipeline.py +++ b/rag/flow/pipeline.py @@ -26,7 +26,7 @@ from rag.utils.redis_conn import REDIS_CONN class Pipeline(Graph): - def __init__(self, dsl: str|dict, tenant_id=None, doc_id=None, task_id=None, flow_id=None): + def __init__(self, dsl: str | dict, tenant_id=None, doc_id=None, task_id=None, flow_id=None): if isinstance(dsl, dict): dsl = json.dumps(dsl, ensure_ascii=False) super().__init__(dsl, tenant_id, task_id) @@ -42,6 +42,7 @@ class Pipeline(Graph): def callback(self, component_name: str, progress: float | int | None = None, message: str = "") -> None: from common.exceptions import TaskCanceledException + log_key = f"{self._flow_id}-{self.task_id}-logs" timestamp = timer() if has_canceled(self.task_id): @@ -113,7 +114,6 @@ class Pipeline(Graph): logging.exception(e) return [] - async def run(self, **kwargs): log_key = f"{self._flow_id}-{self.task_id}-logs" try: @@ -130,10 +130,9 @@ class Pipeline(Graph): self.callback(cpn_obj.component_name, -1, self.error) if self._doc_id: - TaskService.update_progress(self.task_id, { - "progress": random.randint(0, 5) / 100.0, - "progress_msg": "Start the pipeline...", - "begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}) + TaskService.update_progress( + self.task_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} + ) idx = len(self.path) - 1 cpn_obj = self.get_component_obj(self.path[idx]) @@ -147,9 +146,9 @@ class Pipeline(Graph): async def invoke(): nonlocal last_cpn, cpn_obj await cpn_obj.invoke(**last_cpn.output()) - #if inspect.iscoroutinefunction(cpn_obj.invoke): + # if inspect.iscoroutinefunction(cpn_obj.invoke): # await cpn_obj.invoke(**last_cpn.output()) - #else: + # else: # cpn_obj.invoke(**last_cpn.output()) tasks = [] @@ -168,8 +167,6 @@ class Pipeline(Graph): if not self.error: return self.get_component_obj(self.path[-1]).output() - TaskService.update_progress(self.task_id, { - "progress": -1, - "progress_msg": f"[ERROR]: {self.error}"}) + TaskService.update_progress(self.task_id, {"progress": -1, "progress_msg": f"[ERROR]: {self.error}"}) return {} diff --git a/rag/flow/tokenizer/tokenizer.py b/rag/flow/tokenizer/tokenizer.py index cdf5ed0071..20d0fe5826 100644 --- a/rag/flow/tokenizer/tokenizer.py +++ b/rag/flow/tokenizer/tokenizer.py @@ -33,6 +33,7 @@ from common.token_utils import truncate from common.misc_utils import thread_pool_exec + class TokenizerParam(ProcessParamBase): def __init__(self): super().__init__() @@ -69,7 +70,7 @@ class Tokenizer(ProcessBase): for i, c in enumerate(chunks): txt = "" if isinstance(self._param.fields, str): - self._param.fields=[self._param.fields] + self._param.fields = [self._param.fields] for f in self._param.fields: f = c.get(f) if isinstance(f, str): @@ -97,7 +98,10 @@ class Tokenizer(ProcessBase): cnts_batches = [] for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE): async with embed_limiter: - vts, c = await thread_pool_exec(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],) + vts, c = await thread_pool_exec( + batch_encode, + texts[i : i + settings.EMBEDDING_BATCH_SIZE], + ) cnts_batches.append(vts) token_count += c if i % 33 == 32: diff --git a/rag/graphrag/entity_resolution.py b/rag/graphrag/entity_resolution.py index 6819d003e0..1efd7a2afd 100644 --- a/rag/graphrag/entity_resolution.py +++ b/rag/graphrag/entity_resolution.py @@ -42,6 +42,7 @@ DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&" @dataclass class EntityResolutionResult: """Entity resolution result class definition.""" + graph: nx.Graph change: GraphChange @@ -56,8 +57,8 @@ class EntityResolution(Extractor): _resolution_result_delimiter_key: str def __init__( - self, - llm_invoker: CompletionLLM, + self, + llm_invoker: CompletionLLM, ): super().__init__(llm_invoker) """Init method definition.""" @@ -68,13 +69,16 @@ class EntityResolution(Extractor): self._resolution_result_delimiter_key = "resolution_result_delimiter" self._input_text_key = "input_text" - async def __call__(self, graph: nx.Graph, - subgraph_nodes: set[str], - prompt_variables: dict[str, Any] | None = None, - callback: Callable | None = None, - task_id: str = "", - checkpoints: dict[str, Any] | None = None, - save_checkpoint: Callable[[str, Any], Awaitable[bool]] | None = None) -> EntityResolutionResult: + async def __call__( + self, + graph: nx.Graph, + subgraph_nodes: set[str], + prompt_variables: dict[str, Any] | None = None, + callback: Callable | None = None, + task_id: str = "", + checkpoints: dict[str, Any] | None = None, + save_checkpoint: Callable[[str, Any], Awaitable[bool]] | None = None, + ) -> EntityResolutionResult: """Call method definition.""" if prompt_variables is None: prompt_variables = {} @@ -82,20 +86,17 @@ class EntityResolution(Extractor): # Wire defaults into the prompt variables self.prompt_variables = { **prompt_variables, - self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key) - or DEFAULT_RECORD_DELIMITER, - self._entity_index_delimiter_key: prompt_variables.get(self._entity_index_delimiter_key) - or DEFAULT_ENTITY_INDEX_DELIMITER, - self._resolution_result_delimiter_key: prompt_variables.get(self._resolution_result_delimiter_key) - or DEFAULT_RESOLUTION_RESULT_DELIMITER, + self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key) or DEFAULT_RECORD_DELIMITER, + self._entity_index_delimiter_key: prompt_variables.get(self._entity_index_delimiter_key) or DEFAULT_ENTITY_INDEX_DELIMITER, + self._resolution_result_delimiter_key: prompt_variables.get(self._resolution_result_delimiter_key) or DEFAULT_RESOLUTION_RESULT_DELIMITER, } nodes = sorted(graph.nodes()) - entity_types = sorted(set(graph.nodes[node].get('entity_type', '-') for node in nodes)) + entity_types = sorted(set(graph.nodes[node].get("entity_type", "-") for node in nodes)) node_clusters = {entity_type: [] for entity_type in entity_types} for node in nodes: - node_clusters[graph.nodes[node].get('entity_type', '-')].append(node) + node_clusters[graph.nodes[node].get("entity_type", "-")].append(node) candidate_resolution = {entity_type: [] for entity_type in entity_types} for k, v in node_clusters.items(): @@ -123,45 +124,32 @@ class EntityResolution(Extractor): if isinstance(pair, (list, tuple)) and len(pair) == 2: result_set.add((pair[0], pair[1])) remain_candidates_to_resolve -= len(candidate_batch[1]) - callback( - msg=f"Replayed {len(candidate_batch[1])} resolved pairs from checkpoint, " - f"{remain_candidates_to_resolve} remain." - ) + callback(msg=f"Replayed {len(candidate_batch[1])} resolved pairs from checkpoint, {remain_candidates_to_resolve} remain.") return enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") timeout_sec = 280 if enable_timeout_assertion else 1_000_000_000 try: - selected_pairs = await asyncio.wait_for( - self._resolve_candidate(candidate_batch, result_set, result_lock, task_id), - timeout=timeout_sec - ) + selected_pairs = await asyncio.wait_for(self._resolve_candidate(candidate_batch, result_set, result_lock, task_id), timeout=timeout_sec) if selected_pairs is not None and save_checkpoint: await save_checkpoint(checkpoint_key, [list(pair) for pair in selected_pairs]) remain_candidates_to_resolve -= len(candidate_batch[1]) - callback( - msg=f"Resolved {len(candidate_batch[1])} pairs, " - f"{remain_candidates_to_resolve} remain." - ) + callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} remain.") except asyncio.TimeoutError: logging.warning(f"Timeout resolving {candidate_batch}, skipping...") remain_candidates_to_resolve -= len(candidate_batch[1]) - callback( - msg=f"Failed to resolve {len(candidate_batch[1])} pairs due to timeout, skipped. " - f"{remain_candidates_to_resolve} remain." - ) + callback(msg=f"Failed to resolve {len(candidate_batch[1])} pairs due to timeout, skipped. {remain_candidates_to_resolve} remain.") except Exception as exception: logging.error(f"Error resolving candidate batch: {exception}") - tasks = [] for key, lst in candidate_resolution.items(): if not lst: continue for i in range(0, len(lst), resolution_batch_size): - batch = (key, lst[i:i + resolution_batch_size]) + batch = (key, lst[i : i + resolution_batch_size]) tasks.append(limited_resolve_candidate(batch, resolution_result, resolution_result_lock)) try: await asyncio.gather(*tasks, return_exceptions=False) @@ -213,19 +201,15 @@ class EntityResolution(Extractor): logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.") raise TaskCanceledException(f"Task {task_id} was cancelled") - pair_txt = [ - f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n'] + pair_txt = [f"When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n"] for index, candidate in enumerate(candidate_resolution_i[1]): - pair_txt.append( - f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}') - sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions' + pair_txt.append(f"Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}") + sent = "question above" if len(pair_txt) == 1 else f"above {len(pair_txt)} questions" pair_txt.append( - f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)') - pair_prompt = '\n'.join(pair_txt) - variables = { - **self.prompt_variables, - self._input_text_key: pair_prompt - } + f"\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)" + ) + pair_prompt = "\n".join(pair_txt) + variables = {**self.prompt_variables, self._input_text_key: pair_prompt} text = perform_variable_replacements(self._resolution_prompt, variables=variables) logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}") async with chat_limiter: @@ -244,49 +228,42 @@ class EntityResolution(Extractor): return None logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}") - result = self._process_results(len(candidate_resolution_i[1]), response, - self.prompt_variables.get(self._record_delimiter_key, - DEFAULT_RECORD_DELIMITER), - self.prompt_variables.get(self._entity_index_delimiter_key, - DEFAULT_ENTITY_INDEX_DELIMITER), - self.prompt_variables.get(self._resolution_result_delimiter_key, - DEFAULT_RESOLUTION_RESULT_DELIMITER)) + result = self._process_results( + len(candidate_resolution_i[1]), + response, + self.prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER), + self.prompt_variables.get(self._entity_index_delimiter_key, DEFAULT_ENTITY_INDEX_DELIMITER), + self.prompt_variables.get(self._resolution_result_delimiter_key, DEFAULT_RESOLUTION_RESULT_DELIMITER), + ) selected_pairs = [candidate_resolution_i[1][result_i[0] - 1] for result_i in result] async with resolution_result_lock: for pair in selected_pairs: resolution_result.add(pair) return selected_pairs - def _process_results( - self, - records_length: int, - results: str, - record_delimiter: str, - entity_index_delimiter: str, - resolution_result_delimiter: str - ) -> list: + def _process_results(self, records_length: int, results: str, record_delimiter: str, entity_index_delimiter: str, resolution_result_delimiter: str) -> list: ans_list = [] records = [r.strip() for r in results.split(record_delimiter)] for record in records: - pattern_int = fr"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}" + pattern_int = rf"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}" match_int = re.search(pattern_int, record) - res_int = int(str(match_int.group(1) if match_int else '0')) + res_int = int(str(match_int.group(1) if match_int else "0")) if res_int > records_length: continue pattern_bool = f"{re.escape(resolution_result_delimiter)}([a-zA-Z]+){re.escape(resolution_result_delimiter)}" match_bool = re.search(pattern_bool, record) - res_bool = str(match_bool.group(1) if match_bool else '') + res_bool = str(match_bool.group(1) if match_bool else "") if res_int and res_bool: - if res_bool.lower() == 'yes': + if res_bool.lower() == "yes": ans_list.append((res_int, "yes")) return ans_list def _has_digit_in_2gram_diff(self, a, b): def to_2gram_set(s): - return {s[i:i+2] for i in range(len(s) - 1)} + return {s[i : i + 2] for i in range(len(s) - 1)} set_a = to_2gram_set(a) set_b = to_2gram_set(b) @@ -308,4 +285,4 @@ class EntityResolution(Extractor): if max_l < 4: return len(a & b) > 1 - return len(a & b)*1./max_l >= 0.8 + return len(a & b) * 1.0 / max_l >= 0.8 diff --git a/rag/graphrag/general/community_report_prompt.py b/rag/graphrag/general/community_report_prompt.py index 8b9fa2f6ed..1b75858a83 100644 --- a/rag/graphrag/general/community_report_prompt.py +++ b/rag/graphrag/general/community_report_prompt.py @@ -155,4 +155,4 @@ where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the rel Do not include information where the supporting evidence for it is not provided. -Output:""" \ No newline at end of file +Output:""" diff --git a/rag/graphrag/general/community_reports_extractor.py b/rag/graphrag/general/community_reports_extractor.py index 207aebd9e7..57f1759e16 100644 --- a/rag/graphrag/general/community_reports_extractor.py +++ b/rag/graphrag/general/community_reports_extractor.py @@ -28,6 +28,7 @@ from rag.llm.chat_model import Base as CompletionLLM from rag.graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter from common.token_utils import num_tokens_from_string + @dataclass class CommunityReportsResult: """Community reports result class definition.""" @@ -44,9 +45,9 @@ class CommunityReportsExtractor(Extractor): _max_report_length: int def __init__( - self, - llm_invoker: CompletionLLM, - max_report_length: int | None = None, + self, + llm_invoker: CompletionLLM, + max_report_length: int | None = None, ): super().__init__(llm_invoker) """Init method definition.""" @@ -116,10 +117,7 @@ class CommunityReportsExtractor(Extractor): k += 1 rela_df = pd.DataFrame(rela_list) - prompt_variables = { - "entity_df": ent_df.to_csv(index_label="id"), - "relation_df": rela_df.to_csv(index_label="id") - } + prompt_variables = {"entity_df": ent_df.to_csv(index_label="id"), "relation_df": rela_df.to_csv(index_label="id")} text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables) async with chat_limiter: try: @@ -143,13 +141,16 @@ class CommunityReportsExtractor(Extractor): logging.error(f"Failed to parse JSON response: {e}") logging.error(f"Response content: {response}") return - if not dict_has_keys_with_types(response, [ - ("title", str), - ("summary", str), - ("findings", list), - ("rating", float), - ("rating_explanation", str), - ]): + if not dict_has_keys_with_types( + response, + [ + ("title", str), + ("summary", str), + ("findings", list), + ("rating", float), + ("rating_explanation", str), + ], + ): return response["weight"] = weight response["entities"] = ents @@ -203,7 +204,5 @@ class CommunityReportsExtractor(Extractor): return "" return finding.get("explanation") - report_sections = "\n\n".join( - f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings - ) + report_sections = "\n\n".join(f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings) return f"# {title}\n\n{summary}\n\n{report_sections}" diff --git a/rag/graphrag/general/extractor.py b/rag/graphrag/general/extractor.py index ae188b2889..0feacd6935 100644 --- a/rag/graphrag/general/extractor.py +++ b/rag/graphrag/general/extractor.py @@ -142,7 +142,6 @@ class Extractor: async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int, task_id=""): nonlocal error_count async with limiter: - if task_id and has_canceled(task_id): raise TaskCanceledException(f"Task {task_id} was cancelled during entity extraction") @@ -158,10 +157,7 @@ class Extractor: if error_count > max_errors: raise Exception(f"Maximum error count ({max_errors}) reached. Last errors: {str(e)}") - tasks = [ - asyncio.create_task(worker((doc_id, ck), i, len(chunks), task_id)) - for i, ck in enumerate(chunks) - ] + tasks = [asyncio.create_task(worker((doc_id, ck), i, len(chunks), task_id)) for i, ck in enumerate(chunks)] try: await asyncio.gather(*tasks, return_exceptions=False) @@ -207,10 +203,7 @@ class Extractor: if task_id and has_canceled(task_id): raise TaskCanceledException(f"Task {task_id} was cancelled before nodes merging") - tasks = [ - asyncio.create_task(self._merge_nodes(en_nm, ents, all_entities_data, task_id)) - for en_nm, ents in maybe_nodes.items() - ] + tasks = [asyncio.create_task(self._merge_nodes(en_nm, ents, all_entities_data, task_id)) for en_nm, ents in maybe_nodes.items()] try: await asyncio.gather(*tasks, return_exceptions=False) except Exception as e: @@ -236,11 +229,7 @@ class Extractor: tasks = [] for (src, tgt), rels in maybe_edges.items(): - tasks.append( - asyncio.create_task( - self._merge_edges(src, tgt, rels, all_relationships_data, task_id) - ) - ) + tasks.append(asyncio.create_task(self._merge_edges(src, tgt, rels, all_relationships_data, task_id))) try: await asyncio.gather(*tasks, return_exceptions=False) except Exception as e: diff --git a/rag/graphrag/general/graph_extractor.py b/rag/graphrag/general/graph_extractor.py index 5f06c36879..38270baec3 100644 --- a/rag/graphrag/general/graph_extractor.py +++ b/rag/graphrag/general/graph_extractor.py @@ -70,16 +70,10 @@ class GraphExtractor(Extractor): self._input_text_key = input_text_key or "input_text" self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter" self._record_delimiter_key = record_delimiter_key or "record_delimiter" - self._completion_delimiter_key = ( - completion_delimiter_key or "completion_delimiter" - ) + self._completion_delimiter_key = completion_delimiter_key or "completion_delimiter" self._entity_types_key = entity_types_key or "entity_types" self._extraction_prompt = GRAPH_EXTRACTION_PROMPT - self._max_gleanings = ( - max_gleanings - if max_gleanings is not None - else ENTITY_EXTRACTION_MAX_GLEANINGS - ) + self._max_gleanings = max_gleanings if max_gleanings is not None else ENTITY_EXTRACTION_MAX_GLEANINGS self._on_error = on_error or (lambda _e, _s, _d: None) self.prompt_token_count = num_tokens_from_string(self._extraction_prompt) @@ -147,4 +141,7 @@ class GraphExtractor(Extractor): maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._prompt_variables[self._tuple_delimiter_key]) out_results.append((maybe_nodes, maybe_edges, token_count)) if self.callback: - self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq+1} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.") + self.callback( + 0.5 + 0.1 * len(out_results) / num_chunks, + msg=f"Entities extraction of chunk {chunk_seq + 1} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.", + ) diff --git a/rag/graphrag/general/graph_prompt.py b/rag/graphrag/general/graph_prompt.py index 54ca3a1d3b..4e8570d3dc 100644 --- a/rag/graphrag/general/graph_prompt.py +++ b/rag/graphrag/general/graph_prompt.py @@ -121,4 +121,4 @@ Use {language} as output language. Entities: {entity_name} Description List: {description_list} ####### -""" \ No newline at end of file +""" diff --git a/rag/graphrag/general/index.py b/rag/graphrag/general/index.py index 9c00ae28ed..3c8314c29b 100644 --- a/rag/graphrag/general/index.py +++ b/rag/graphrag/general/index.py @@ -221,11 +221,7 @@ async def load_subgraph_from_store(tenant_id: str, kb_id: str, doc_id: str): "source_id": [doc_id], } try: - res = await thread_pool_exec( - settings.docStoreConn.search, - fields, [], condition, [], OrderByExpr(), - 0, 1, search.index_name(tenant_id), [kb_id] - ) + res = await thread_pool_exec(settings.docStoreConn.search, fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]) field_map = settings.docStoreConn.get_fields(res, fields) for cid, row in field_map.items(): content = row.get("content_with_weight", "") @@ -237,19 +233,22 @@ async def load_subgraph_from_store(tenant_id: str, kb_id: str, doc_id: str): sg.graph["source_id"] = [doc_id] logging.info( "Checkpoint hit: subgraph for doc %s (tenant=%s kb=%s) found at chunk %s", - doc_id, tenant_id, kb_id, cid, + doc_id, + tenant_id, + kb_id, + cid, ) return sg except Exception: - logging.exception( - "Failed to parse subgraph JSON for doc %s chunk %s", doc_id, cid - ) + logging.exception("Failed to parse subgraph JSON for doc %s chunk %s", doc_id, cid) except Exception: logging.exception("Failed to load subgraph from store for doc %s", doc_id) return None logging.info( "Checkpoint miss: no subgraph for doc %s (tenant=%s kb=%s)", - doc_id, tenant_id, kb_id, + doc_id, + tenant_id, + kb_id, ) return None @@ -327,19 +326,11 @@ async def run_graphrag_for_kb( chunks = [] current_chunk = "" - raw_chunks = list(settings.retriever.chunk_list( - doc_id, - tenant_id, - [kb_id], - fields=fields_for_chunks, - sort_by_position=True, - retrieve_all=True - )) + raw_chunks = list(settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=fields_for_chunks, sort_by_position=True, retrieve_all=True)) callback(msg=f"[GraphRAG] chunk_list returned {len(raw_chunks)} raw chunks for doc:{doc_id}") - contents = [content for chunk in raw_chunks if (content := chunk.get("content_with_weight", "")) -] + contents = [content for chunk in raw_chunks if (content := chunk.get("content_with_weight", ""))] # For NER-based extractionm, no need to batch extract entity and relation if _select_extractor_type(graphrag_config) == "ner": return contents @@ -398,6 +389,7 @@ async def run_graphrag_for_kb( _has_cancel_and_exit(task_id, f"Task {task_id} cancelled before subgraph generation for doc {doc_id}.", callback) try: + async def build_subgraph_attempt(): checkpoint_sg = await load_subgraph_from_store(tenant_id, kb_id, doc_id) if checkpoint_sg: @@ -492,6 +484,7 @@ async def run_graphrag_for_kb( union_nodes.update(set(sg.nodes())) try: + async def merge_subgraph_attempt(): current_graph = await get_graph(tenant_id, kb_id) if current_graph and doc_id in current_graph.graph.get("source_id", []): @@ -717,8 +710,18 @@ async def generate_subgraph( } cid = chunk_id(chunk) _has_cancel_and_exit(task_id, f"Task {task_id} cancelled before saving subgraph for doc {doc_id}.", callback) - await thread_pool_exec(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,) - await thread_pool_exec(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,) + await thread_pool_exec( + settings.docStoreConn.delete, + {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, + search.index_name(tenant_id), + kb_id, + ) + await thread_pool_exec( + settings.docStoreConn.insert, + [{"id": cid, **chunk}], + search.index_name(tenant_id), + kb_id, + ) now = asyncio.get_running_loop().time() callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.") return subgraph @@ -883,8 +886,15 @@ async def extract_community( try: existing_res = await thread_pool_exec( settings.docStoreConn.search, - ["id"], [], {"knowledge_graph_kwd": ["community_report"]}, [], OrderByExpr(), - 0, 10000, search.index_name(tenant_id), [kb_id], + ["id"], + [], + {"knowledge_graph_kwd": ["community_report"]}, + [], + OrderByExpr(), + 0, + 10000, + search.index_name(tenant_id), + [kb_id], ) existing_fields = settings.docStoreConn.get_fields(existing_res, ["id"]) old_ids = list(existing_fields.keys()) diff --git a/rag/graphrag/general/leiden.py b/rag/graphrag/general/leiden.py index b859e7e62f..b1af534155 100644 --- a/rag/graphrag/general/leiden.py +++ b/rag/graphrag/general/leiden.py @@ -70,10 +70,10 @@ def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: def _compute_leiden_communities( - graph: nx.Graph | nx.DiGraph, - max_cluster_size: int, - use_lcc: bool, - seed=0xDEADBEEF, + graph: nx.Graph | nx.DiGraph, + max_cluster_size: int, + use_lcc: bool, + seed=0xDEADBEEF, ) -> dict[int, dict[str, int]]: """Return Leiden root communities.""" results: dict[int, dict[str, int]] = {} @@ -82,9 +82,7 @@ def _compute_leiden_communities( if use_lcc: graph = stable_largest_connected_component(graph) - community_mapping = hierarchical_leiden( - graph, max_cluster_size=max_cluster_size, random_seed=seed - ) + community_mapping = hierarchical_leiden(graph, max_cluster_size=max_cluster_size, random_seed=seed) for partition in community_mapping: results[partition.level] = results.get(partition.level, {}) results[partition.level][partition.node] = partition.cluster @@ -97,9 +95,7 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]: max_cluster_size = args.get("max_cluster_size", 12) use_lcc = args.get("use_lcc", True) if args.get("verbose", False): - logging.debug( - "Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc - ) + logging.debug("Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc) nodes = set(graph.nodes()) if not nodes: return {} diff --git a/rag/graphrag/general/mind_map_prompt.py b/rag/graphrag/general/mind_map_prompt.py index 37f324cfa0..35b49065e5 100644 --- a/rag/graphrag/general/mind_map_prompt.py +++ b/rag/graphrag/general/mind_map_prompt.py @@ -20,16 +20,16 @@ MIND_MAP_EXTRACTION_PROMPT = """ - Step of task: 1. Generate a title for user's 'TEXT'。 2. Classify the 'TEXT' into sections of a mind map. - 3. If the subject matter is really complex, split them into sub-sections and sub-subsections. + 3. If the subject matter is really complex, split them into sub-sections and sub-subsections. 4. Add a shot content summary of the bottom level section. - Output requirement: - Generate at least 4 levels. - - Always try to maximize the number of sub-sections. + - Always try to maximize the number of sub-sections. - In language of 'Text' - MUST IN FORMAT OF MARKDOWN -TEXT- {input_text} -""" \ No newline at end of file +""" diff --git a/rag/graphrag/general/smoke.py b/rag/graphrag/general/smoke.py index 7c8ee2de18..dc84b2fbbe 100644 --- a/rag/graphrag/general/smoke.py +++ b/rag/graphrag/general/smoke.py @@ -90,12 +90,8 @@ async def main(): ) print(json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2)) - await with_resolution( - args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback - ) - community_structure, community_reports = await with_community( - args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback - ) + await with_resolution(args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback) + community_structure, community_reports = await with_community(args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback) print( "------------------ COMMUNITY STRUCTURE--------------------\n", diff --git a/rag/graphrag/light/graph_extractor.py b/rag/graphrag/light/graph_extractor.py index f775468c8a..91d9d89653 100644 --- a/rag/graphrag/light/graph_extractor.py +++ b/rag/graphrag/light/graph_extractor.py @@ -19,6 +19,7 @@ from rag.graphrag.utils import chat_limiter, pack_user_ass_to_openai_messages, s from rag.llm.chat_model import Base as CompletionLLM from common.token_utils import num_tokens_from_string + @dataclass class GraphExtractionResult: """Unipartite graph extraction result class definition.""" @@ -121,5 +122,5 @@ class GraphExtractor(Extractor): if self.callback: self.callback( 0.5 + 0.1 * len(out_results) / num_chunks, - msg=f"Entities extraction of chunk {chunk_seq+1} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.", + msg=f"Entities extraction of chunk {chunk_seq + 1} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.", ) diff --git a/rag/graphrag/ner/dep_relation_extractor.py b/rag/graphrag/ner/dep_relation_extractor.py index 8824c4a905..2fdee15afd 100644 --- a/rag/graphrag/ner/dep_relation_extractor.py +++ b/rag/graphrag/ner/dep_relation_extractor.py @@ -22,6 +22,7 @@ Extracts typed relations using spaCy dependency parse with: - Dynamic confidence scoring - Multi-occurrence entity matching """ + from typing import Dict, List, Optional from .types import Entity, Relation @@ -31,28 +32,21 @@ from .types import Entity, Relation # or a tuple (dep, child_dep) for compound patterns. # None = no standard mapping (language uses different structure) _LANG_DEP_RULES: Dict[str, Dict[str, object]] = { - "en": {"pass_subj": "nsubjpass", "subj": "nsubj", - "agent": ("agent", "pobj"), - "dobj": "dobj", "prep_obj": ("prep", "pobj")}, - "de": {"subj": "sb", - "agent": ("sbp", "nk"), - "prep_obj": ("mo", "nk"), - "root_verb_child": "oc"}, # German ROOT is aux, real verb is "oc" - "fr": {"pass_subj": "nsubj:pass", "subj": "nsubj", - "agent": "obl:agent", - "dobj": "obj", "prep_obj": ("case", "obl")}, - "es": {"subj": "nsubj", - "agent": "obj", - "prep_obj": ("case", "obl")}, - "pt": {"pass_subj": "nsubj:pass", "subj": "nsubj", - "agent": "obl:agent", - "dobj": "obj", "prep_obj": ("case", "obl")}, - "zh": {"subj": "nsubj", - "agent": ("nmod:prep", None, "由"), # case "由" marks agent - "prep_obj": ("case", "nmod")}, - "ja": {"subj": "nsubj", - "agent": ("obl", None, "によって"), # "によって" marks agent - "prep_obj": ("case", "obl")}, + "en": {"pass_subj": "nsubjpass", "subj": "nsubj", "agent": ("agent", "pobj"), "dobj": "dobj", "prep_obj": ("prep", "pobj")}, + "de": {"subj": "sb", "agent": ("sbp", "nk"), "prep_obj": ("mo", "nk"), "root_verb_child": "oc"}, # German ROOT is aux, real verb is "oc" + "fr": {"pass_subj": "nsubj:pass", "subj": "nsubj", "agent": "obl:agent", "dobj": "obj", "prep_obj": ("case", "obl")}, + "es": {"subj": "nsubj", "agent": "obj", "prep_obj": ("case", "obl")}, + "pt": {"pass_subj": "nsubj:pass", "subj": "nsubj", "agent": "obl:agent", "dobj": "obj", "prep_obj": ("case", "obl")}, + "zh": { + "subj": "nsubj", + "agent": ("nmod:prep", None, "由"), # case "由" marks agent + "prep_obj": ("case", "nmod"), + }, + "ja": { + "subj": "nsubj", + "agent": ("obl", None, "によって"), # "によって" marks agent + "prep_obj": ("case", "obl"), + }, } # Multi-hop inference rules: if A rel1 B and B rel2 C then A rel3 C @@ -64,91 +58,137 @@ _MULTI_HOP: Dict[str, Dict[str, str]] = { _VERB_RELATIONS: Dict[str, str] = { # English - "found+by": "founded_by", "co-found+by": "founded_by", - "establish+by": "founded_by", "create+by": "founded_by", - "set+up": "founded_by", "start+by": "founded_by", - "work+for": "works_for", "employ+by": "works_for", - "hire+by": "works_for", "join": "works_for", - "lead+by": "works_for", "manage+by": "works_for", - "head+by": "works_for", "run+by": "works_for", - "own+by": "owns", "develop+by": "develops", - "write+by": "wrote", "publish+by": "published", - "invest+in": "invests_in", "partner+with": "partners_with", + "found+by": "founded_by", + "co-found+by": "founded_by", + "establish+by": "founded_by", + "create+by": "founded_by", + "set+up": "founded_by", + "start+by": "founded_by", + "work+for": "works_for", + "employ+by": "works_for", + "hire+by": "works_for", + "join": "works_for", + "lead+by": "works_for", + "manage+by": "works_for", + "head+by": "works_for", + "run+by": "works_for", + "own+by": "owns", + "develop+by": "develops", + "write+by": "wrote", + "publish+by": "published", + "invest+in": "invests_in", + "partner+with": "partners_with", "collaborate+with": "collaborates_with", - "merge+with": "merged_with", "subsidiar+y": "is_subsidiary_of", - "base+in": "located_in", "locate+in": "located_in", - "situate+in": "located_in", "headquarter+in": "located_in", - "bear+in": "born_in", "bear+on": "born_in", - "acquire+by": "acquired", "buy+by": "acquired", + "merge+with": "merged_with", + "subsidiar+y": "is_subsidiary_of", + "base+in": "located_in", + "locate+in": "located_in", + "situate+in": "located_in", + "headquarter+in": "located_in", + "bear+in": "born_in", + "bear+on": "born_in", + "acquire+by": "acquired", + "buy+by": "acquired", # German (de): spaCy lemmas - "gründen+von": "founded_by", "errichten+von": "founded_by", - "arbeiten+für": "works_for", "beschäftigen+bei": "works_for", - "anstellen+bei": "works_for", "sich+befinden": "located_in", - "liegen+in": "located_in", "sitzen+in": "located_in", - "gebären+in": "born_in", "gebären+am": "born_in", - "erwerben+durch": "acquired", "kaufen+durch": "acquired", + "gründen+von": "founded_by", + "errichten+von": "founded_by", + "arbeiten+für": "works_for", + "beschäftigen+bei": "works_for", + "anstellen+bei": "works_for", + "sich+befinden": "located_in", + "liegen+in": "located_in", + "sitzen+in": "located_in", + "gebären+in": "born_in", + "gebären+am": "born_in", + "erwerben+durch": "acquired", + "kaufen+durch": "acquired", "übernehmen+durch": "acquired", # French (fr): spaCy lemmas - "fonder+par": "founded_by", "créer+par": "founded_by", + "fonder+par": "founded_by", + "créer+par": "founded_by", "établir+par": "founded_by", - "travailler+pour": "works_for", "employer+par": "works_for", + "travailler+pour": "works_for", + "employer+par": "works_for", "embaucher+par": "works_for", - "situer+à": "located_in", "baser+à": "located_in", + "situer+à": "located_in", + "baser+à": "located_in", "implanter+à": "located_in", "naître+à": "born_in", - "acquérir+par": "acquired", "racheter+par": "acquired", + "acquérir+par": "acquired", + "racheter+par": "acquired", # Spanish + Portuguese (shared lemmas, no duplicate keys) - "fundar+por": "founded_by", "crear+por": "founded_by", + "fundar+por": "founded_by", + "crear+por": "founded_by", "criar+por": "founded_by", - "establecer+por": "founded_by", "estabelecer+por": "founded_by", - "trabajar+para": "works_for", "trabalhar+para": "works_for", - "emplear+por": "works_for", "empregar+por": "works_for", + "establecer+por": "founded_by", + "estabelecer+por": "founded_by", + "trabajar+para": "works_for", + "trabalhar+para": "works_for", + "emplear+por": "works_for", + "empregar+por": "works_for", "contratar+por": "works_for", - "ubicar+en": "located_in", "situar+en": "located_in", - "localizar+em": "located_in", "situar+em": "located_in", - "sediar+em": "located_in", "tener+sede": "located_in", - "nacer+en": "born_in", "nascer+em": "born_in", - "adquirir+por": "acquired", "comprar+por": "acquired", + "ubicar+en": "located_in", + "situar+en": "located_in", + "localizar+em": "located_in", + "situar+em": "located_in", + "sediar+em": "located_in", + "tener+sede": "located_in", + "nacer+en": "born_in", + "nascer+em": "born_in", + "adquirir+por": "acquired", + "comprar+por": "acquired", # Chinese: verb + "由" (agent marker) or "被" (passive) - "创立+由": "founded_by", "创建+由": "founded_by", - "成立+由": "founded_by", "创办+由": "founded_by", + "创立+由": "founded_by", + "创建+由": "founded_by", + "成立+由": "founded_by", + "创办+由": "founded_by", "设立+由": "founded_by", - "任职+于": "works_for", "就职+于": "works_for", - "工作+在": "works_for", "位于+在": "located_in", - "坐落+在": "located_in", "总部设+在": "located_in", - "出生+在": "born_in", "出生+于": "born_in", - "收购+由": "acquired", "并购+由": "acquired", + "任职+于": "works_for", + "就职+于": "works_for", + "工作+在": "works_for", + "位于+在": "located_in", + "坐落+在": "located_in", + "总部设+在": "located_in", + "出生+在": "born_in", + "出生+于": "born_in", + "收购+由": "acquired", + "并购+由": "acquired", # Japanese: verb + "によって" (agent marker) - "設立+によって": "founded_by", "創立+によって": "founded_by", - "勤務+で": "works_for", "在籍+で": "works_for", - "位置+に": "located_in", "所在+に": "located_in", + "設立+によって": "founded_by", + "創立+によって": "founded_by", + "勤務+で": "works_for", + "在籍+で": "works_for", + "位置+に": "located_in", + "所在+に": "located_in", "本社+を": "located_in", "出生+に": "born_in", "買収+によって": "acquired", } _COPULA_TITLE_MAP: Dict[str, List[str]] = { - "ceo": ["ceo_of", "works_for"], "cto": ["works_for"], - "cfo": ["works_for"], "coo": ["works_for"], - "vp": ["works_for"], "director": ["works_for"], - "manager": ["works_for"], "engineer": ["works_for"], + "ceo": ["ceo_of", "works_for"], + "cto": ["works_for"], + "cfo": ["works_for"], + "coo": ["works_for"], + "vp": ["works_for"], + "director": ["works_for"], + "manager": ["works_for"], + "engineer": ["works_for"], "employee": ["works_for"], - "founder": ["founded_by"], "co-founder": ["founded_by"], + "founder": ["founded_by"], + "co-founder": ["founded_by"], } class DepRelationExtractor: """Extract typed relations using dependency parse — semantica-aligned.""" - def __init__(self, language: str = "en", - confidence_threshold: float = 0.3, - max_distance: int = 100): + def __init__(self, language: str = "en", confidence_threshold: float = 0.3, max_distance: int = 100): self.language = language self.confidence_threshold = confidence_threshold self.max_distance = max_distance - def extract(self, text: str, entities: List[Entity], - doc=None, **options) -> List[Relation]: + def extract(self, text: str, entities: List[Entity], doc=None, **options) -> List[Relation]: semantica_rels = [] if doc is not None: semantica_rels = self._extract_with_dep(text, doc, entities) @@ -180,12 +220,15 @@ class DepRelationExtractor: if r2.predicate in _MULTI_HOP.get(r.predicate, {}): inferred_rel = _MULTI_HOP[r.predicate][r2.predicate] if inferred_rel: - inferred.append(Relation( - subject=r.subject, predicate=inferred_rel, - obj=r2.obj, confidence=min(r.confidence, r2.confidence) * 0.9, - metadata={"method": "multi_hop", - "via": f"{r.predicate}→{r2.predicate}"}, - )) + inferred.append( + Relation( + subject=r.subject, + predicate=inferred_rel, + obj=r2.obj, + confidence=min(r.confidence, r2.confidence) * 0.9, + metadata={"method": "multi_hop", "via": f"{r.predicate}→{r2.predicate}"}, + ) + ) return relations + inferred # ------------------------------------------------------------------ @@ -221,10 +264,7 @@ class DepRelationExtractor: if dep == parent_dep: if case_marker: # Check if any child has the expected case lemma - has_case = any( - gc.lemma_ == case_marker or gc.text == case_marker - for gc in c.subtree - ) + has_case = any(gc.lemma_ == case_marker or gc.text == case_marker for gc in c.subtree) if not has_case: continue if child_dep is None: @@ -281,6 +321,7 @@ class DepRelationExtractor: # Extract roles (check both the main verb and optional aux parent) def first(lst): return lst[0][0] if lst else None + def get_roles(token): return ( first(self._get_by_role(token, "subj", entity_map)), @@ -340,8 +381,7 @@ class DepRelationExtractor: for prep_entity, prep_l in prep_list: rt = self._lookup(verb_lemma, prep_l) if rt: - relations.append(self._make_rel(effective_nsubj, rt, prep_entity, 0.85, - "active_prep", verb_lemma, prep=prep_l)) + relations.append(self._make_rel(effective_nsubj, rt, prep_entity, 0.85, "active_prep", verb_lemma, prep=prep_l)) # Passive with prep ("is based in") if effective_nsubjpass and prep_list and not agent_entity: @@ -350,8 +390,7 @@ class DepRelationExtractor: if not rt: rt = self._lookup("be+" + verb_lemma, prep_l) if rt: - relations.append(self._make_rel(effective_nsubjpass, rt, prep_entity, 0.85, - "passive_prep", verb_lemma, prep=prep_l)) + relations.append(self._make_rel(effective_nsubjpass, rt, prep_entity, 0.85, "passive_prep", verb_lemma, prep=prep_l)) return relations @@ -360,8 +399,7 @@ class DepRelationExtractor: m = {"method": method, "verb": verb} if prep: m["prep"] = prep - return Relation(subject=subj, predicate=pred, obj=obj, - confidence=conf, metadata=m) + return Relation(subject=subj, predicate=pred, obj=obj, confidence=conf, metadata=m) @staticmethod def _already_has(rels, subj, pred, obj) -> bool: @@ -401,11 +439,16 @@ class DepRelationExtractor: for keyword, rel_types in _COPULA_TITLE_MAP.items(): if keyword in title_lemma: for rt in rel_types: - relations.append(Relation( - subject=subj, predicate=rt, obj=prep_obj, - confidence=0.88, context=text, - metadata={"method": "copula", "title": title_lemma}, - )) + relations.append( + Relation( + subject=subj, + predicate=rt, + obj=prep_obj, + confidence=0.88, + context=text, + metadata={"method": "copula", "title": title_lemma}, + ) + ) break return relations @@ -426,8 +469,7 @@ class DepRelationExtractor: return result @staticmethod - def _find_best_entity(key: str, entity_map: Dict[str, List[Entity]], - fallback_text: str = "") -> Optional[Entity]: + def _find_best_entity(key: str, entity_map: Dict[str, List[Entity]], fallback_text: str = "") -> Optional[Entity]: """Find the best entity match. If multiple, prefer the one whose text is an exact match for fallback_text, or the first one.""" entries = entity_map.get(key.lower(), []) @@ -534,8 +576,8 @@ class DepRelationExtractor: if len(entities) < 2: return [] import re as _re - spans = [(m.start(), m.end()) - for m in _re.finditer(r'[^.!?]+(?:[.!?](?=\s|$))+', text)] + + spans = [(m.start(), m.end()) for m in _re.finditer(r"[^.!?]+(?:[.!?](?=\s|$))+", text)] def same_sent(c1, c2): return any(ss <= c1 < se and ss <= c2 < se for ss, se in spans) @@ -550,9 +592,14 @@ class DepRelationExtractor: continue cs = max(0, min(e1.start_char, e2.start_char) - 20) ce = min(len(text), max(e1.end_char, e2.end_char) + 20) - rels.append(Relation( - subject=e1, predicate="related_to", obj=e2, - confidence=0.4, context=text[cs:ce], - metadata={"method": "cooccurrence"}, - )) + rels.append( + Relation( + subject=e1, + predicate="related_to", + obj=e2, + confidence=0.4, + context=text[cs:ce], + metadata={"method": "cooccurrence"}, + ) + ) return rels diff --git a/rag/graphrag/ner/graph_extractor.py b/rag/graphrag/ner/graph_extractor.py index 7e2fc69de7..d5150c60f7 100644 --- a/rag/graphrag/ner/graph_extractor.py +++ b/rag/graphrag/ner/graph_extractor.py @@ -58,18 +58,14 @@ def _load_spacy_model(model_name: str = "en_core_web_sm"): try: import spacy except ImportError: - raise ImportError( - "spaCy is required for the spacy GraphRAG method. " - "Install it with: pip install spacy && python -m spacy download en_core_web_sm" - ) + raise ImportError("spaCy is required for the spacy GraphRAG method. Install it with: pip install spacy && python -m spacy download en_core_web_sm") try: _nlp = spacy.load(model_name) logging.info("Loaded spaCy model '%s'", model_name) except OSError: - logging.warning( - "spaCy model '%s' not found; downloading automatically …", model_name - ) + logging.warning("spaCy model '%s' not found; downloading automatically …", model_name) from spacy.cli import download as spacy_download + spacy_download(model_name) _nlp = spacy.load(model_name) logging.info("Downloaded and loaded spaCy model '%s'", model_name) @@ -110,19 +106,14 @@ _SKIP_SPACY_LABELS = {"ORDINAL", "CARDINAL"} # MGranRAG-style multi-pass keyword extraction # --------------------------------------------------------------------------- + def _has_uppercase(text: str) -> bool: return any(c.isupper() for c in text) def _replace_word(word: str) -> str: """Normalise spaces around hyphens and apostrophes (from MGranRAG).""" - return ( - word.replace(" - ", "-") - .replace(" -", "-") - .replace("- ", "-") - .replace(" 's", "'s") - .replace(" 'S", "'S") - ) + return word.replace(" - ", "-").replace(" -", "-").replace("- ", "-").replace(" 's", "'s").replace(" 'S", "'S") def extract_keywords(spacy_doc) -> set[str]: @@ -265,16 +256,19 @@ def extract_keywords(spacy_doc) -> set[str]: continue # Truncate trailing lowercase non-noun / non-number words. - if cwl and not _has_uppercase(cwl[-1]) and cpl[-1] not in ( - "PROPN", - "NOUN", - "NUM", - "PART", + if ( + cwl + and not _has_uppercase(cwl[-1]) + and cpl[-1] + not in ( + "PROPN", + "NOUN", + "NUM", + "PART", + ) ): for i in range(len(cpl) - 1, 0, -1): - if cpl[i] in ("PROPN", "NOUN", "NUM", "PART") or _has_uppercase( - cwl[i] - ): + if cpl[i] in ("PROPN", "NOUN", "NUM", "PART") or _has_uppercase(cwl[i]): break word = _replace_word(" ".join(cwl[: i + 1])) keywords.add(word) @@ -330,6 +324,7 @@ def ner_all_keywords(spacy_doc) -> set[str]: # Main extractor class # --------------------------------------------------------------------------- + class GraphExtractor(Extractor): """Extract entities and relationships using spaCy (no LLM calls). @@ -440,12 +435,12 @@ class GraphExtractor(Extractor): sent_idx = self._keyword_sent_idx(doc, kw) # Description: use the containing sentence (LinearRAG semantic bridging). - #sent_text = self._keyword_sent_text(doc, kw) + # sent_text = self._keyword_sent_text(doc, kw) ent_record = dict( entity_name=kw_upper, entity_type=app_type.upper(), - description="", #sent_text or kw, + description="", # sent_text or kw, source_id=chunk_key, ) # A keyword may appear multiple times; keep the first. @@ -463,9 +458,7 @@ class GraphExtractor(Extractor): # Pre-compute TF weights if needed (LinearRAG). entity_tf: dict[str, float] = {} if self._use_tf_weight: - total_count = sum( - content.upper().count(name) for name in ent_records - ) + total_count = sum(content.upper().count(name) for name in ent_records) for name in ent_records: count = content.upper().count(name) entity_tf[name] = count / total_count if total_count > 0 else 0.0 @@ -495,12 +488,11 @@ class GraphExtractor(Extractor): # Relationship description: shared sentence text # (LinearRAG semantic bridging — the sentence is the # semantic bridge between entities). - #desc = self._cooccurrence_description(doc, ea["entity_name"], eb["entity_name"]) + # desc = self._cooccurrence_description(doc, ea["entity_name"], eb["entity_name"]) # Edge weight: TF-normalised (LinearRAG) or fixed. if self._use_tf_weight: - w = (entity_tf.get(ea["entity_name"], 0.0) - + entity_tf.get(eb["entity_name"], 0.0)) + w = entity_tf.get(ea["entity_name"], 0.0) + entity_tf.get(eb["entity_name"], 0.0) weight = max(w, 0.01) else: weight = self._relationship_strength @@ -510,7 +502,7 @@ class GraphExtractor(Extractor): src_id=pair[0], tgt_id=pair[1], weight=weight, - description="", #desc, + description="", # desc, keywords=[ea["entity_name"], eb["entity_name"]], source_id=chunk_key, ) @@ -521,10 +513,7 @@ class GraphExtractor(Extractor): if self.callback: self.callback( 0.5 + 0.1 * len(out_results) / num_chunks, - msg=f"[spacy] Entities extraction of chunk {chunk_seq+1} " - f"{len(out_results)}/{num_chunks} done, " - f"{len(maybe_nodes)} nodes, {len(maybe_edges)} edges, " - f"{token_count} tokens.", + msg=f"[spacy] Entities extraction of chunk {chunk_seq + 1} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.", ) # ------------------------------------------------------------------ diff --git a/rag/graphrag/ner/ner_extractor.py b/rag/graphrag/ner/ner_extractor.py index f97f02c0ed..48fc61a98f 100644 --- a/rag/graphrag/ner/ner_extractor.py +++ b/rag/graphrag/ner/ner_extractor.py @@ -40,9 +40,12 @@ from .types import Entity, ExtractionResult # Language → spaCy model _MODEL_MAP = { - "en": "en_core_web_sm", "zh": "zh_core_web_sm", - "de": "de_core_news_sm", "fr": "fr_core_news_sm", - "es": "es_core_news_sm", "pt": "pt_core_news_sm", + "en": "en_core_web_sm", + "zh": "zh_core_web_sm", + "de": "de_core_news_sm", + "fr": "fr_core_news_sm", + "es": "es_core_news_sm", + "pt": "pt_core_news_sm", "ja": "ja_core_news_sm", } @@ -51,8 +54,7 @@ _SKIP_LABELS = {"ORDINAL", "CARDINAL"} # Labels by confidence tier (for NER confidence scoring) _HIGH_CONF = {"PERSON", "ORG", "GPE", "LOC", "DATE"} -_MED_CONF = {"PRODUCT", "EVENT", "WORK_OF_ART", "LAW", "LANGUAGE", "NORP", - "MONEY", "TIME", "PERCENT", "FAC", "QUANTITY"} +_MED_CONF = {"PRODUCT", "EVENT", "WORK_OF_ART", "LAW", "LANGUAGE", "NORP", "MONEY", "TIME", "PERCENT", "FAC", "QUANTITY"} class NERExtractor: @@ -99,8 +101,7 @@ class NERExtractor: self._nlp_cache[self.model_name] = nlp self._nlp = nlp except Exception as e: - logging.error("Failed to load spaCy model '%s': %s", - self.model_name, e) + logging.error("Failed to load spaCy model '%s': %s", self.model_name, e) raise # ------------------------------------------------------------------ @@ -207,14 +208,16 @@ class NERExtractor: if key in seen: continue seen.add(key) - entities.append(Entity( - text=ent.text, - label=ent.label_, - start_char=ent.start_char, - end_char=ent.end_char, - confidence=confidence, - metadata={"source": "spacy"}, - )) + entities.append( + Entity( + text=ent.text, + label=ent.label_, + start_char=ent.start_char, + end_char=ent.end_char, + confidence=confidence, + metadata={"source": "spacy"}, + ) + ) return entities @staticmethod @@ -240,4 +243,3 @@ class NERExtractor: # Patch ExtractionResult to support metadata - diff --git a/rag/graphrag/ner/types.py b/rag/graphrag/ner/types.py index 80e420df2f..3884de0eb7 100644 --- a/rag/graphrag/ner/types.py +++ b/rag/graphrag/ner/types.py @@ -16,6 +16,7 @@ """ Data types for entity and relation extraction. """ + from dataclasses import dataclass, field from typing import Any, Dict, List @@ -23,6 +24,7 @@ from typing import Any, Dict, List @dataclass class Entity: """Extracted entity.""" + text: str label: str # spaCy NER label: PERSON, ORG, GPE, ... start_char: int @@ -34,8 +36,9 @@ class Entity: @dataclass class Relation: """Extracted relation between two entities.""" + subject: Entity - predicate: str # relation type: "founded_by", "works_for", ... + predicate: str # relation type: "founded_by", "works_for", ... obj: Entity confidence: float = 1.0 context: str = "" # surrounding text @@ -45,6 +48,7 @@ class Relation: @dataclass class ExtractionResult: """Result of a full extraction pass.""" + entities: List[Entity] = field(default_factory=list) relations: List[Relation] = field(default_factory=list) language: str = "en" diff --git a/rag/graphrag/query_analyze_prompt.py b/rag/graphrag/query_analyze_prompt.py index 3e5862e45e..07d4bd846b 100644 --- a/rag/graphrag/query_analyze_prompt.py +++ b/rag/graphrag/query_analyze_prompt.py @@ -4,6 +4,7 @@ Reference: - [LightRag](https://github.com/HKUDS/LightRAG) - [MiniRAG](https://github.com/HKUDS/MiniRAG) """ + PROMPTS = {} PROMPTS["minirag_query2kwd"] = """---Role--- @@ -14,7 +15,7 @@ You are a helpful assistant tasked with identifying both answer-type and low-lev Given the query, list both answer-type and low-level keywords. answer_type_keywords focus on the type of the answer to the certain query, while low-level keywords focus on specific entities, details, or concrete terms. -The answer_type_keywords must be selected from Answer type pool. +The answer_type_keywords must be selected from Answer type pool. This pool is in the form of a dictionary, where the key represents the Type you should choose from and the value represents the example samples. ---Instructions--- diff --git a/rag/graphrag/search.py b/rag/graphrag/search.py index c342ae03cf..34cf0a39d3 100644 --- a/rag/graphrag/search.py +++ b/rag/graphrag/search.py @@ -45,8 +45,7 @@ class KGSearch(Dealer): async def query_rewrite(self, llm, question, idxnms, kb_ids): ty2ents = await get_entity_type2samples(idxnms, kb_ids) - hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question, - TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2)) + hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question, TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2)) result = await self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {}) try: keywords_data = json_repair.loads(result) @@ -55,8 +54,8 @@ class KGSearch(Dealer): return type_keywords, entities_from_query except json_repair.JSONDecodeError: try: - result = result.replace(hint_prompt[:-1], '').replace('user', '').replace('model', '').strip() - result = '{' + result.split('{')[1].split('}')[0] + '}' + result = result.replace(hint_prompt[:-1], "").replace("user", "").replace("model", "").strip() + result = "{" + result.split("{")[1].split("}")[0] + "}" keywords_data = json_repair.loads(result) type_keywords = keywords_data.get("answer_type_keywords", []) entities_from_query = keywords_data.get("entities_from_query", [])[:5] @@ -90,14 +89,13 @@ class KGSearch(Dealer): "sim": get_float(ent.get("_score", 0)), "pagerank": get_float(ent.get("rank_flt", 0)), "n_hop_ents": n_hop_ents, - "description": ent.get("content_with_weight", "{}") + "description": ent.get("content_with_weight", "{}"), } return res def _relation_info_from_(self, es_res, sim_thr=0.3): res = {} - es_res = self.dataStore.get_fields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd", - "weight_int"]) + es_res = self.dataStore.get_fields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd", "weight_int"]) for _, ent in es_res.items(): if get_float(ent.get("_score", 0)) < sim_thr: continue @@ -106,11 +104,7 @@ class KGSearch(Dealer): f = f[0] if isinstance(t, list): t = t[0] - res[(f, t)] = { - "sim": get_float(ent.get("_score", 0)), - "pagerank": get_float(ent.get("weight_int", 0)), - "description": ent["content_with_weight"] - } + res[(f, t)] = {"sim": get_float(ent.get("_score", 0)), "pagerank": get_float(ent.get("weight_int", 0)), "description": ent["content_with_weight"]} return res def get_relevant_ents_by_keywords(self, keywords, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56): @@ -119,9 +113,7 @@ class KGSearch(Dealer): filters = deepcopy(filters) filters["knowledge_graph_kwd"] = "entity" matchDense = self.get_vector(", ".join(keywords), emb_mdl, 1024, sim_thr) - es_res = self.dataStore.search(["content_with_weight", "entity_kwd", "rank_flt", "n_hop_with_weight"], [], filters, [matchDense], - OrderByExpr(), 0, N, - idxnms, kb_ids) + es_res = self.dataStore.search(["content_with_weight", "entity_kwd", "rank_flt", "n_hop_with_weight"], [], filters, [matchDense], OrderByExpr(), 0, N, idxnms, kb_ids) return self._ent_info_from_(es_res, sim_thr) def get_relevant_relations_by_txt(self, txt, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56): @@ -130,9 +122,7 @@ class KGSearch(Dealer): filters = deepcopy(filters) filters["knowledge_graph_kwd"] = "relation" matchDense = self.get_vector(txt, emb_mdl, 1024, sim_thr) - es_res = self.dataStore.search( - ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd", "weight_int"], - [], filters, [matchDense], OrderByExpr(), 0, N, idxnms, kb_ids) + es_res = self.dataStore.search(["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd", "weight_int"], [], filters, [matchDense], OrderByExpr(), 0, N, idxnms, kb_ids) return self._relation_info_from_(es_res, sim_thr) def get_relevant_ents_by_types(self, types, filters, idxnms, kb_ids, N=56): @@ -143,23 +133,24 @@ class KGSearch(Dealer): filters["entity_type_kwd"] = types ordr = OrderByExpr() ordr.desc("rank_flt") - es_res = self.dataStore.search(["entity_kwd", "rank_flt"], [], filters, [], ordr, 0, N, - idxnms, kb_ids) + es_res = self.dataStore.search(["entity_kwd", "rank_flt"], [], filters, [], ordr, 0, N, idxnms, kb_ids) return self._ent_info_from_(es_res, 0) - async def retrieval(self, question: str, - tenant_ids: str | list[str], - kb_ids: list[str], - emb_mdl, - llm, - max_token: int = 8196, - ent_topn: int = 6, - rel_topn: int = 6, - comm_topn: int = 1, - ent_sim_threshold: float = 0.3, - rel_sim_threshold: float = 0.3, - **kwargs - ): + async def retrieval( + self, + question: str, + tenant_ids: str | list[str], + kb_ids: list[str], + emb_mdl, + llm, + max_token: int = 8196, + ent_topn: int = 6, + rel_topn: int = 6, + comm_topn: int = 1, + ent_sim_threshold: float = 0.3, + rel_sim_threshold: float = 0.3, + **kwargs, + ): qst = question filters = self.get_filters({"kb_ids": kb_ids}) if isinstance(tenant_ids, str): @@ -192,9 +183,7 @@ class KGSearch(Dealer): nhop_pathes[(f, t)]["sim"] += ent["sim"] / (2 + i) else: nhop_pathes[(f, t)]["sim"] = ent["sim"] / (2 + i) - nhop_pathes[(f, t)]["pagerank"] = max( - nhop_pathes[(f, t)].get("pagerank", 0), wts[i] - ) + nhop_pathes[(f, t)]["pagerank"] = max(nhop_pathes[(f, t)].get("pagerank", 0), wts[i]) logging.info("Retrieved entities: {}".format(list(ents_from_query.keys()))) logging.info("Retrieved relations: {}".format(list(rels_from_txt.keys()))) @@ -207,7 +196,7 @@ class KGSearch(Dealer): continue ents_from_query[ent]["sim"] *= 2 - for (f, t) in rels_from_txt.keys(): + for f, t in rels_from_txt.keys(): pair = tuple(sorted([f, t])) s = 0 if pair in nhop_pathes: @@ -220,30 +209,21 @@ class KGSearch(Dealer): rels_from_txt[(f, t)]["sim"] *= s + 1 # This is for the relations from n-hop but not by query search - for (f, t) in nhop_pathes.keys(): + for f, t in nhop_pathes.keys(): s = 0 if f in ents_from_types: s += 1 if t in ents_from_types: s += 1 - rels_from_txt[(f, t)] = { - "sim": nhop_pathes[(f, t)]["sim"] * (s + 1), - "pagerank": nhop_pathes[(f, t)]["pagerank"] - } + rels_from_txt[(f, t)] = {"sim": nhop_pathes[(f, t)]["sim"] * (s + 1), "pagerank": nhop_pathes[(f, t)]["pagerank"]} - ents_from_query = sorted(ents_from_query.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[ - :ent_topn] - rels_from_txt = sorted(rels_from_txt.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[ - :rel_topn] + ents_from_query = sorted(ents_from_query.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[:ent_topn] + rels_from_txt = sorted(rels_from_txt.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[:rel_topn] ents = [] relas = [] for n, ent in ents_from_query: - ents.append({ - "Entity": n, - "Score": "%.2f" % (ent["sim"] * ent["pagerank"]), - "Description": json.loads(ent["description"]).get("description", "") if ent["description"] else "" - }) + ents.append({"Entity": n, "Score": "%.2f" % (ent["sim"] * ent["pagerank"]), "Description": json.loads(ent["description"]).get("description", "") if ent["description"] else ""}) max_token -= num_tokens_from_string(str(ents[-1])) if max_token <= 0: ents = ents[:-1] @@ -263,12 +243,7 @@ class KGSearch(Dealer): desc = json.loads(desc).get("description", "") except Exception: pass - relas.append({ - "From Entity": f, - "To Entity": t, - "Score": "%.2f" % (rel["sim"] * rel["pagerank"]), - "Description": desc - }) + relas.append({"From Entity": f, "To Entity": t, "Score": "%.2f" % (rel["sim"] * rel["pagerank"]), "Description": desc}) max_token -= num_tokens_from_string(str(relas[-1])) if max_token <= 0: relas = relas[:-1] @@ -284,21 +259,20 @@ class KGSearch(Dealer): relas = "" return { - "chunk_id": get_uuid(), - "content_ltks": "", - "content_with_weight": ents + relas + self._community_retrieval_([n for n, _ in ents_from_query], filters, kb_ids, idxnms, - comm_topn, max_token), - "doc_id": "", - "docnm_kwd": "Related content in Knowledge Graph", - "kb_id": kb_ids, - "important_kwd": [], - "image_id": "", - "similarity": 1., - "vector_similarity": 1., - "term_similarity": 0, - "vector": [], - "positions": [], - } + "chunk_id": get_uuid(), + "content_ltks": "", + "content_with_weight": ents + relas + self._community_retrieval_([n for n, _ in ents_from_query], filters, kb_ids, idxnms, comm_topn, max_token), + "doc_id": "", + "docnm_kwd": "Related content in Knowledge Graph", + "kb_id": kb_ids, + "important_kwd": [], + "image_id": "", + "similarity": 1.0, + "vector_similarity": 1.0, + "term_similarity": 0, + "vector": [], + "positions": [], + } def _community_retrieval_(self, entities, condition, kb_ids, idxnms, topn, max_token): ## Community retrieval @@ -308,14 +282,12 @@ class KGSearch(Dealer): fltr = deepcopy(condition) fltr["knowledge_graph_kwd"] = "community_report" fltr["entities_kwd"] = entities - comm_res = self.dataStore.search(fields, [], fltr, [], - odr, 0, topn, idxnms, kb_ids) + comm_res = self.dataStore.search(fields, [], fltr, [], odr, 0, topn, idxnms, kb_ids) comm_res_fields = self.dataStore.get_fields(comm_res, fields) txts = [] for ii, (_, row) in enumerate(comm_res_fields.items()): obj = json.loads(row["content_with_weight"]) - txts.append("# {}. {}\n## Content\n{}\n## Evidences\n{}\n".format( - ii + 1, row["docnm_kwd"], obj["report"], obj["evidences"])) + txts.append("# {}. {}\n## Content\n{}\n## Evidences\n{}\n".format(ii + 1, row["docnm_kwd"], obj["report"], obj["evidences"])) max_token -= num_tokens_from_string(str(txts[-1])) if not txts: @@ -333,9 +305,9 @@ if __name__ == "__main__": settings.init_settings() parser = argparse.ArgumentParser() - parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True) - parser.add_argument('-d', '--kb_id', default=False, help="Knowledge base ID", action='store', required=True) - parser.add_argument('-q', '--question', default=False, help="Question", action='store', required=True) + parser.add_argument("-t", "--tenant_id", default=False, help="Tenant ID", action="store", required=True) + parser.add_argument("-d", "--kb_id", default=False, help="Knowledge base ID", action="store", required=True) + parser.add_argument("-q", "--question", default=False, help="Question", action="store", required=True) args = parser.parse_args() kb_id = args.kb_id @@ -346,5 +318,4 @@ if __name__ == "__main__": embed_bdl = LLMBundle(args.tenant_id, embd_model_config) kg = KGSearch(settings.docStoreConn) - print(asyncio.run(kg.retrieval({"question": args.question, "kb_ids": [kb_id]}, - search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl))) + print(asyncio.run(kg.retrieval({"question": args.question, "kb_ids": [kb_id]}, search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl))) diff --git a/rag/graphrag/utils.py b/rag/graphrag/utils.py index 563647ea8e..5c584af901 100644 --- a/rag/graphrag/utils.py +++ b/rag/graphrag/utils.py @@ -90,7 +90,7 @@ async def insert_chunks_bounded(chunks, tenant_id, kb_id, *, callback=None, labe break except asyncio.TimeoutError: if attempt < max_retries - 1: - wait = 2 ** attempt + wait = 2**attempt logging.warning(f"Insert batch at offset {offset}/{total} attempt {attempt + 1} timed out, retrying in {wait}s") await asyncio.sleep(wait) else: @@ -99,7 +99,7 @@ async def insert_chunks_bounded(chunks, tenant_id, kb_id, *, callback=None, labe raise except Exception as e: if attempt < max_retries - 1: - wait = 2 ** attempt + wait = 2**attempt logging.warning(f"Insert batch at offset {offset}/{total} attempt {attempt + 1} failed: {e}, retrying in {wait}s") await asyncio.sleep(wait) else: @@ -170,7 +170,7 @@ def dict_has_keys_with_types(data: dict, expected_fields: list[tuple[str, type]] def get_llm_cache(llmnm, txt, history, genconf): """Return a cached LLM completion for the given model/text/history/config, or None on miss.""" hasher = xxhash.xxh64() - hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8")) + hasher.update((str(llmnm) + str(txt) + str(history) + str(genconf)).encode("utf-8")) k = hasher.hexdigest() bin = REDIS_CONN.get(k) @@ -182,7 +182,7 @@ def get_llm_cache(llmnm, txt, history, genconf): def set_llm_cache(llmnm, txt, v, history, genconf): """Store an LLM completion *v* in Redis keyed by a hash of model/text/history/config.""" hasher = xxhash.xxh64() - hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8")) + hasher.update((str(llmnm) + str(txt) + str(history) + str(genconf)).encode("utf-8")) k = hasher.hexdigest() REDIS_CONN.set(k, v.encode("utf-8"), 24 * 3600) @@ -438,10 +438,7 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks, nhop_neig if ebd is None: async with chat_limiter: timeout = 3 if enable_timeout_assertion else 30000000 - ebd, _ = await asyncio.wait_for( - thread_pool_exec(embd_mdl.encode, [ent_name]), - timeout=timeout - ) + ebd, _ = await asyncio.wait_for(thread_pool_exec(embd_mdl.encode, [ent_name]), timeout=timeout) ebd = ebd[0] set_embed_cache(embd_mdl.llm_name, ent_name, ebd) assert ebd is not None @@ -494,13 +491,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, if ebd is None: async with chat_limiter: timeout = 3 if enable_timeout_assertion else 300000000 - ebd, _ = await asyncio.wait_for( - thread_pool_exec( - embd_mdl.encode, - [txt + f": {meta['description']}"] - ), - timeout=timeout - ) + ebd, _ = await asyncio.wait_for(thread_pool_exec(embd_mdl.encode, [txt + f": {meta['description']}"]), timeout=timeout) ebd = ebd[0] set_embed_cache(embd_mdl.llm_name, txt, ebd) assert ebd is not None @@ -516,11 +507,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id): "knowledge_graph_kwd": ["graph"], "removed_kwd": "N", } - res = await thread_pool_exec( - settings.docStoreConn.search, - fields, [], condition, [], OrderByExpr(), - 0, 1, search.index_name(tenant_id), [kb_id] - ) + res = await thread_pool_exec(settings.docStoreConn.search, fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]) fields2 = settings.docStoreConn.get_fields(res, fields) graph_doc_ids = set() for chunk_id in fields2.keys(): @@ -609,13 +596,12 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang # embd_mdl.encode([single_name]). For 17 k+ nodes that is 17 k round-trips. # Pre-warming the cache here collapses N calls to ceil(N/_INSERT_BULK_SIZE). _node_list = list(change.added_updated_nodes) - _node_misses = await thread_pool_exec( - _batch_embed_cache_misses, embd_mdl.llm_name, _node_list - ) + _node_misses = await thread_pool_exec(_batch_embed_cache_misses, embd_mdl.llm_name, _node_list) _uncached_node_names = [n for n, miss in zip(_node_list, _node_misses) if miss] logging.debug( "set_graph node pre-warm: %d nodes, %d cache misses", - len(_node_list), len(_uncached_node_names), + len(_node_list), + len(_uncached_node_names), ) if _uncached_node_names: _enable_ta = os.environ.get("ENABLE_TIMEOUT_ASSERTION") @@ -635,18 +621,14 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang len(_batch), ) if callback: - callback(msg=f"Batch-embedded {len(_uncached_node_names)} entity names " - f"({(len(_uncached_node_names) + _INSERT_BULK_SIZE - 1) // _INSERT_BULK_SIZE} " - f"batches of {_INSERT_BULK_SIZE}).") + callback(msg=f"Batch-embedded {len(_uncached_node_names)} entity names ({(len(_uncached_node_names) + _INSERT_BULK_SIZE - 1) // _INSERT_BULK_SIZE} batches of {_INSERT_BULK_SIZE}).") # ── end batch pre-warm ────────────────────────────────────────────────────── tasks = [] for ii, node in enumerate(change.added_updated_nodes): node_attrs = graph.nodes[node] nhop_neighbors = n_neighbor(graph, node) - tasks.append(asyncio.create_task( - graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks, nhop_neighbors) - )) + tasks.append(asyncio.create_task(graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks, nhop_neighbors))) if ii % 100 == 9 and callback: callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}") try: @@ -662,28 +644,24 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang # Mirror of the node pre-warm above for relation chunks. # Cache key = "A->B" (matches graph_edge_to_chunk lookup key) # Encoded text = "A->B: " (matches graph_edge_to_chunk encode text) - _all_edge_data = [ - (_fn, _tn, graph.get_edge_data(_fn, _tn)) - for _fn, _tn in change.added_updated_edges - ] + _all_edge_data = [(_fn, _tn, graph.get_edge_data(_fn, _tn)) for _fn, _tn in change.added_updated_edges] _all_edge_data = [(f, t, a) for f, t, a in _all_edge_data if a] _edge_lookup_keys = [f"{f}->{t}" for f, t, _ in _all_edge_data] - _edge_misses = await thread_pool_exec( - _batch_embed_cache_misses, embd_mdl.llm_name, _edge_lookup_keys - ) if _all_edge_data else [] + _edge_misses = await thread_pool_exec(_batch_embed_cache_misses, embd_mdl.llm_name, _edge_lookup_keys) if _all_edge_data else [] _uncached_edge_items = [item for item, miss in zip(_all_edge_data, _edge_misses) if miss] logging.debug( "set_graph edge pre-warm: %d edges, %d cache misses", - len(_all_edge_data), len(_uncached_edge_items), + len(_all_edge_data), + len(_uncached_edge_items), ) if _uncached_edge_items: - _edge_keys = [f"{f}->{t}" for f, t, _ in _uncached_edge_items] + _edge_keys = [f"{f}->{t}" for f, t, _ in _uncached_edge_items] _edge_texts = [f"{f}->{t}: {a['description']}" for f, t, a in _uncached_edge_items] _enable_ta = os.environ.get("ENABLE_TIMEOUT_ASSERTION") _timeout = 3 if _enable_ta else 30000000 for _i in range(0, len(_edge_texts), _INSERT_BULK_SIZE): _btexts = _edge_texts[_i : _i + _INSERT_BULK_SIZE] - _bkeys = _edge_keys [_i : _i + _INSERT_BULK_SIZE] + _bkeys = _edge_keys[_i : _i + _INSERT_BULK_SIZE] async with chat_limiter: _ebds, _ = await asyncio.wait_for( thread_pool_exec(embd_mdl.encode, _btexts), @@ -697,9 +675,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang len(_btexts), ) if callback: - callback(msg=f"Batch-embedded {len(_uncached_edge_items)} edge descriptions " - f"({(len(_uncached_edge_items) + _INSERT_BULK_SIZE - 1) // _INSERT_BULK_SIZE} " - f"batches of {_INSERT_BULK_SIZE}).") + callback(msg=f"Batch-embedded {len(_uncached_edge_items)} edge descriptions ({(len(_uncached_edge_items) + _INSERT_BULK_SIZE - 1) // _INSERT_BULK_SIZE} batches of {_INSERT_BULK_SIZE}).") # ── end batch pre-warm ────────────────────────────────────────────────────── tasks = [] @@ -707,9 +683,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang edge_attrs = graph.get_edge_data(from_node, to_node) if not edge_attrs: continue - tasks.append(asyncio.create_task( - graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks) - )) + tasks.append(asyncio.create_task(graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks))) if ii % 100 == 9 and callback: callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}") try: @@ -729,24 +703,14 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang # All new chunks are ready. Now delete old data and insert the new data. # Deleting only after chunks are built ensures that a crash during embedding # generation above does not destroy the old graph/subgraph checkpoints. - await thread_pool_exec( - settings.docStoreConn.delete, - {"knowledge_graph_kwd": ["graph", "subgraph"]}, - search.index_name(tenant_id), - kb_id - ) + await thread_pool_exec(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id) if change.removed_nodes: BATCH_SIZE = 100 sorted_nodes = sorted(change.removed_nodes) for i in range(0, len(sorted_nodes), BATCH_SIZE): - batch = sorted_nodes[i:i + BATCH_SIZE] - await thread_pool_exec( - settings.docStoreConn.delete, - {"knowledge_graph_kwd": ["entity"], "entity_kwd": batch}, - search.index_name(tenant_id), - kb_id - ) + batch = sorted_nodes[i : i + BATCH_SIZE] + await thread_pool_exec(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": batch}, search.index_name(tenant_id), kb_id) if change.removed_edges: @@ -756,15 +720,12 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang try: async with chat_limiter: await thread_pool_exec( - settings.docStoreConn.delete, - {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, - search.index_name(tenant_id), - kb_id + settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id ) return except Exception as e: if attempt < max_retries - 1: - wait = 2 ** attempt + wait = 2**attempt logging.warning(f"del_edges({from_node}, {to_node}) attempt {attempt + 1} failed: {e}, retrying in {wait}s") await asyncio.sleep(wait) else: @@ -796,6 +757,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang def is_continuous_subsequence(subseq, seq): """Return True if *subseq* appears as a contiguous sub-path within tuple *seq*.""" + def find_all_indexes(tup, value): indexes = [] start = 0 @@ -875,7 +837,7 @@ def n_neighbor(graph: nx.Graph, node, n_hop: int = 2): async def get_entity_type2samples(idxnms, kb_ids: list): """Return a mapping of entity type → sample entity names fetched from the document store.""" - es_res = await settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},idxnms,kb_ids) + es_res = await settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids) res = defaultdict(list) for id in es_res.ids: @@ -910,11 +872,7 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None): flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"] bs = 256 for i in range(0, 1024 * bs, bs): - es_res = await thread_pool_exec( - settings.docStoreConn.search, - flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, - [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id] - ) + es_res = await thread_pool_exec(settings.docStoreConn.search, flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]) # tot = settings.docStoreConn.get_total(es_res) es_res = settings.docStoreConn.get_fields(es_res, flds) diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index 5aae7c9641..d423451c98 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -36,10 +36,9 @@ from rag.nlp import is_english from rag.prompts.generator import vision_llm_describe_prompt - - from common.misc_utils import thread_pool_exec + def _qwen3_no_think_extra_body(model_name: str) -> dict[str, bool] | None: """Build DashScope-compatible options that disable Qwen3.x thinking.""" if "qwen3." in model_name.lower(): @@ -277,11 +276,7 @@ class GptV4(Base): def describe(self, image): b64 = self.image2base64(image) - res = self.client.chat.completions.create( - model=self.model_name, - messages=self.prompt(b64), - extra_body=self.extra_body - ) + res = self.client.chat.completions.create(model=self.model_name, messages=self.prompt(b64), extra_body=self.extra_body) if not res.choices: raise ValueError("LLM returned empty response") # pact: guard empty choices list return res.choices[0].message.content.strip(), total_token_count_from_response(res) @@ -303,9 +298,7 @@ def _resolve_azure_credentials(key): key_obj = json.loads(key) if isinstance(key_obj, dict): return key_obj.get("api_key", ""), key_obj.get("api_version", "2024-02-01") - logging.warning( - "Azure credential payload parsed as JSON but is not an object; using raw api_key string" - ) + logging.warning("Azure credential payload parsed as JSON but is not an object; using raw api_key string") except (json.JSONDecodeError, TypeError): logging.warning("Azure credential payload is not valid JSON; using raw api_key string") return key, "2024-02-01" @@ -586,6 +579,7 @@ class VolcEngineCV(GptV4): self.lang = lang Base.__init__(self, **kwargs) + class LmStudioCV(GptV4): _FACTORY_NAME = "LM-Studio" @@ -803,7 +797,9 @@ class OllamaCV(Base): async def async_chat(self, system, history, gen_conf, images=None, **kwargs): try: - response = await thread_pool_exec(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive) + response = await thread_pool_exec( + self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive + ) ans = response["message"]["content"].strip() return ans, response["eval_count"] + response.get("prompt_eval_count", 0) @@ -813,7 +809,9 @@ class OllamaCV(Base): async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs): ans = "" try: - response = await thread_pool_exec(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive) + response = await thread_pool_exec( + self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive + ) for resp in response: if resp["done"]: yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0) @@ -1282,6 +1280,7 @@ class GoogleCV(AnthropicCV, GeminiCV): self.client = AnthropicVertex(region=region, project_id=project_id) else: from google import genai + client_kwargs = { "vertexai": True, "project": project_id, @@ -1347,12 +1346,13 @@ class FuturMixCV(GptV4): class RAGconCV(GptV4): """ RAGcon CV Provider - routes through LiteLLM proxy - + Supports vision models through LiteLLM. Default Base URL: https://connect.ragcon.ai/v1 """ + _FACTORY_NAME = "RAGcon" - + def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs): if not base_url: @@ -1393,6 +1393,7 @@ class BedrockCV(Base): } elif self.auth_mode == "iam_role": import boto3 + sts_client = boto3.client("sts", region_name=self.aws_region) resp = sts_client.assume_role(RoleArn=self.aws_role_arn, RoleSessionName="BedrockCVSession") creds = resp["Credentials"] @@ -1407,6 +1408,7 @@ class BedrockCV(Base): def describe_with_prompt(self, image, prompt=None): import litellm + b64 = self.image2base64(image) messages = self.vision_llm_prompt(b64, prompt) res = litellm.completion( diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 655895fdc4..64b7526545 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -307,9 +307,7 @@ def _resolve_azure_credentials(key): key_obj = json.loads(key) if isinstance(key_obj, dict): return key_obj.get("api_key", ""), key_obj.get("api_version", "2024-02-01") - logging.warning( - "Azure credential payload parsed as JSON but is not an object; using raw api_key string" - ) + logging.warning("Azure credential payload parsed as JSON but is not an object; using raw api_key string") except (json.JSONDecodeError, TypeError): logging.warning("Azure credential payload is not valid JSON; using raw api_key string") return key, "2024-02-01" diff --git a/rag/llm/model_meta.py b/rag/llm/model_meta.py index 6923cc65f5..d819d0c219 100644 --- a/rag/llm/model_meta.py +++ b/rag/llm/model_meta.py @@ -77,7 +77,6 @@ class VolcEngine(Base): serving_model = [model for model in raw_model_list["data"] if model.get("status", "") != "Shutdown"] res = [] for model in serving_model: - model_types = [] if model.get("domain", "") == "Embedding": @@ -109,13 +108,9 @@ class VolcEngine(Base): if model.get("token_limits", {}).get("max_reasoning_token_length", 0) > 0: features.append("thinking") - res.append({ - "name": model["id"], - "model_types": model_types, - "features": features, - "max_tokens": model.get("token_limits", {}).get("max_input_token_length", 8192), - "status": model.get("status") - }) + res.append( + {"name": model["id"], "model_types": model_types, "features": features, "max_tokens": model.get("token_limits", {}).get("max_input_token_length", 8192), "status": model.get("status")} + ) return res diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index 8067615141..5d6171ab21 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -27,6 +27,7 @@ from yarl import URL from common.log_utils import log_exception from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response + class Base(ABC): def __init__(self, key, model_name, **kwargs): pass @@ -400,6 +401,7 @@ class QWenRerank(Base): def __init__(self, key, model_name="gte-rerank", **kwargs): import dashscope + self.api_key = key self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name # Remove invalid global timeout, use official SDK per-request timeout parameter @@ -409,19 +411,10 @@ class QWenRerank(Base): import dashscope # Pass official request_timeout parameter to both API call branches - if self.model_name.startswith("qwen3-rerank"): - resp = dashscope.TextReRank.call( - api_key=self.api_key, model=self.model_name, - query=query, documents=texts, top_n=len(texts), - request_timeout=self.request_timeout - ) - else: - resp = dashscope.TextReRank.call( - api_key=self.api_key, model=self.model_name, - query=query, documents=texts, - top_n=len(texts), return_documents=False, - request_timeout=self.request_timeout - ) + if self.model_name.startswith("qwen3-rerank"): + resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), request_timeout=self.request_timeout) + else: + resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False, request_timeout=self.request_timeout) rank = np.zeros(len(texts), dtype=float) if resp.status_code == HTTPStatus.OK: @@ -463,9 +456,7 @@ class HuggingfaceRerank(Base): try: # Fix: Add request timeout res = requests.post( - endpoint, headers={"Content-Type": "application/json"}, - json={"query": query, "texts": texts[i:i+batch_size], "raw_scores": False, "truncate": True}, - timeout=30 + endpoint, headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}, timeout=30 ) res.raise_for_status() for o in res.json(): @@ -592,18 +583,17 @@ class FuturMixRerank(OpenAI_APIRerank): class RAGconRerank(Base): _FACTORY_NAME = "RAGcon" - + def __init__(self, key, model_name, base_url=None, **kwargs): if not base_url: base_url = "https://connect.ragcon.com/v1" - + self._api_key = key self._base_url = base_url - + self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"} self.model_name = model_name - - + def _compute_rank(self, query: str, texts: List) -> Tuple[np.ndarray, int]: texts = [truncate(t, 500) for t in texts] data = { diff --git a/rag/llm/sequence2txt_model.py b/rag/llm/sequence2txt_model.py index 1642314dcb..e41c48159e 100644 --- a/rag/llm/sequence2txt_model.py +++ b/rag/llm/sequence2txt_model.py @@ -96,26 +96,9 @@ class QWenSeq2txt(Base): else: audio_input = f"file://{audio_path}" - messages = [ - { - "role": "system", - "content": [{"text": ""}] - }, - { - "role": "user", - "content": [{"audio": audio_input}] - } - ] + messages = [{"role": "system", "content": [{"text": ""}]}, {"role": "user", "content": [{"audio": audio_input}]}] - resp = dashscope.MultiModalConversation.call( - model=self.model_name, - messages=messages, - result_format="message", - asr_options={ - "enable_lid": True, - "enable_itn": False - } - ) + resp = dashscope.MultiModalConversation.call(model=self.model_name, messages=messages, result_format="message", asr_options={"enable_lid": True, "enable_itn": False}) try: text = resp["output"]["choices"][0]["message"].content[0]["text"] @@ -131,27 +114,9 @@ class QWenSeq2txt(Base): else: audio_input = f"file://{audio_path}" - messages = [ - { - "role": "system", - "content": [{"text": ""}] - }, - { - "role": "user", - "content": [{"audio": audio_input}] - } - ] + messages = [{"role": "system", "content": [{"text": ""}]}, {"role": "user", "content": [{"audio": audio_input}]}] - stream = dashscope.MultiModalConversation.call( - model=self.model_name, - messages=messages, - result_format="message", - stream=True, - asr_options={ - "enable_lid": True, - "enable_itn": False - } - ) + stream = dashscope.MultiModalConversation.call(model=self.model_name, messages=messages, result_format="message", stream=True, asr_options={"enable_lid": True, "enable_itn": False}) full = "" for chunk in stream: @@ -164,6 +129,7 @@ class QWenSeq2txt(Base): yield {"event": "final", "text": full} + class AzureSeq2txt(Base): _FACTORY_NAME = "Azure-OpenAI" @@ -346,14 +312,9 @@ class ZhipuSeq2txt(Base): try: import ffmpeg import imageio_ffmpeg as ffmpeg_exe + ffmpeg_path = ffmpeg_exe.get_ffmpeg_exe() - ( - ffmpeg - .input(input_path) - .output(out_path, ar=16000, ac=1) - .overwrite_output() - .run(cmd=ffmpeg_path,quiet=True) - ) + (ffmpeg.input(input_path).output(out_path, ar=16000, ac=1).overwrite_output().run(cmd=ffmpeg_path, quiet=True)) return out_path except Exception as e: raise RuntimeError(f"audio convert failed: {e}") @@ -393,43 +354,41 @@ class ZhipuSeq2txt(Base): class RAGconSeq2txt(Base): """ RAGcon Sequence2Text Provider - routes through LiteLLM proxy - + Speech-to-text models routed through LiteLLM. Default Base URL: https://connect.ragcon.com/v1 """ + _FACTORY_NAME = "RAGcon" - + def __init__(self, key, model_name, base_url=None, lang="English", **kwargs): # Use provided base_url or fallback to default if not base_url: base_url = "https://connect.ragcon.com/v1" - + self.base_url = base_url self.model_name = model_name self.key = key self.lang = lang - + self.client = OpenAI(api_key=key, base_url=self.base_url) - + def transcription(self, audio_path, **kwargs): """ Transcribe audio file using RAGcon's OpenAI-compatible API. Uses Whisper's automatic language detection for German and English audio. - + Args: audio_path: Path to the audio file **kwargs: Additional parameters (currently unused but maintained for compatibility) - + Returns: tuple: (transcribed_text, token_count) """ with open(audio_path, "rb") as audio_file: # Call RAGcon API - Whisper will auto-detect language - transcription = self.client.audio.transcriptions.create( - model=self.model_name, - file=audio_file - ) - + transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio_file) + # Return text and token count text = transcription.text.strip() return text, num_tokens_from_string(text) diff --git a/rag/llm/tool_decorator.py b/rag/llm/tool_decorator.py index a1029c62dc..2870f94ed8 100644 --- a/rag/llm/tool_decorator.py +++ b/rag/llm/tool_decorator.py @@ -188,9 +188,7 @@ class FunctionToolSession: self.tools_map: dict[str, Callable[..., Any]] = {} for fn in tools: if not is_tool(fn): - raise TypeError( - f"{getattr(fn, '__name__', fn)!r} is not a @tool-decorated callable" - ) + raise TypeError(f"{getattr(fn, '__name__', fn)!r} is not a @tool-decorated callable") self.tools_map[fn.openai_schema["function"]["name"]] = fn @property @@ -204,9 +202,7 @@ class FunctionToolSession: if name not in self.tools_map: raise KeyError(f"Tool {name!r} is not registered") if not isinstance(arguments, Mapping): - raise TypeError( - f"Tool arguments for {name} must be an object, got {type(arguments).__name__}" - ) + raise TypeError(f"Tool arguments for {name} must be an object, got {type(arguments).__name__}") fn = self.tools_map[name] logging.info(f"[FunctionTool] invoke name={name} args={str(arguments)[:200]}") if asyncio.iscoroutinefunction(fn): diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py index 4f95b2cb6b..38ff8de54a 100644 --- a/rag/llm/tts_model.py +++ b/rag/llm/tts_model.py @@ -85,28 +85,22 @@ class HTTPBasedTTS(Base): Base class for HTTP-based TTS services. Provides common HTTP request handling and response processing. """ - + def __init__(self, key, model_name, base_url, **kwargs): self.model_name = model_name self.base_url = base_url self.api_key = key - self.headers = { - "Content-Type": "application/json" - } + self.headers = {"Content-Type": "application/json"} if key and key != "x": self.headers["Authorization"] = f"Bearer {self.api_key}" - + def _build_payload(self, text, voice, **kwargs): """ Build payload for TTS request. Subclasses should override this method if they need custom payload structure. """ - return { - "model": self.model_name, - "voice": voice, - "input": text - } - + return {"model": self.model_name, "voice": voice, "input": text} + def _send_request(self, endpoint, payload, stream=True): """ Send HTTP request to TTS service. @@ -119,12 +113,12 @@ class HTTPBasedTTS(Base): stream=stream, timeout=60, ) - + if response.status_code != 200: raise Exception(f"**Error**: {response.status_code}, {response.text}") - + return response - + def _process_response(self, response): """ Process streaming response from TTS service. @@ -132,7 +126,7 @@ class HTTPBasedTTS(Base): for chunk in response.iter_content(): if chunk: yield chunk - + def tts(self, text, voice="alloy"): """ Generate speech from text. @@ -499,36 +493,29 @@ class StepFunTTS(OpenAITTS): class RAGconTTS(Base): """ RAGcon TTS Provider - routes through LiteLLM proxy - + Text-to-speech models routed through LiteLLM. Default Base URL: https://connect.ragcon.ai/v1 """ + _FACTORY_NAME = "RAGcon" - + def __init__(self, key, model_name, base_url=None, **kwargs): if not base_url: base_url = "https://connect.ragcon.com/v1" - + self.base_url = base_url self.api_key = key self.model_name = model_name - self.headers = { - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}" - } - + self.headers = {"accept": "application/json", "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + def tts(self, text, voice="English Female", stream=True): """ Uses LiteLLM's /v1/audio/speech endpoint """ - payload = { - "model": self.model_name, - "input": text, - "voice": voice - } - + payload = {"model": self.model_name, "input": text, "voice": voice} + response = requests.post( f"{self.base_url}/audio/speech", headers=self.headers, @@ -536,10 +523,10 @@ class RAGconTTS(Base): stream=stream, timeout=60, ) - + if response.status_code != 200: raise Exception(f"**Error**: {response.status_code}, {response.text}") - + for chunk in response.iter_content(chunk_size=1024): if chunk: yield chunk diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py index 3f5e7f2913..26a65a4df3 100644 --- a/rag/nlp/__init__.py +++ b/rag/nlp/__init__.py @@ -28,33 +28,122 @@ from PIL import Image import chardet -__all__ = ['rag_tokenizer'] +__all__ = ["rag_tokenizer"] all_codecs = [ - 'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs', - 'cp037', 'cp273', 'cp424', 'cp437', - 'cp500', 'cp720', 'cp737', 'cp775', 'cp850', 'cp852', 'cp855', 'cp856', 'cp857', - 'cp858', 'cp860', 'cp861', 'cp862', 'cp863', 'cp864', 'cp865', 'cp866', 'cp869', - 'cp874', 'cp875', 'cp932', 'cp949', 'cp950', 'cp1006', 'cp1026', 'cp1125', - 'cp1140', 'cp1250', 'cp1251', 'cp1252', 'cp1253', 'cp1254', 'cp1255', 'cp1256', - 'cp1257', 'cp1258', 'euc_jp', 'euc_jis_2004', 'euc_jisx0213', 'euc_kr', - 'gb18030', 'hz', 'iso2022_jp', 'iso2022_jp_1', 'iso2022_jp_2', - 'iso2022_jp_2004', 'iso2022_jp_3', 'iso2022_jp_ext', 'iso2022_kr', 'latin_1', - 'iso8859_2', 'iso8859_3', 'iso8859_4', 'iso8859_5', 'iso8859_6', 'iso8859_7', - 'iso8859_8', 'iso8859_9', 'iso8859_10', 'iso8859_11', 'iso8859_13', - 'iso8859_14', 'iso8859_15', 'iso8859_16', 'johab', 'koi8_r', 'koi8_t', 'koi8_u', - 'kz1048', 'mac_cyrillic', 'mac_greek', 'mac_iceland', 'mac_latin2', 'mac_roman', - 'mac_turkish', 'ptcp154', 'shift_jis', 'shift_jis_2004', 'shift_jisx0213', - 'utf_32', 'utf_32_be', 'utf_32_le', 'utf_16_be', 'utf_16_le', 'utf_7', 'windows-1250', 'windows-1251', - 'windows-1252', 'windows-1253', 'windows-1254', 'windows-1255', 'windows-1256', - 'windows-1257', 'windows-1258', 'latin-2' + "utf-8", + "gb2312", + "gbk", + "utf_16", + "ascii", + "big5", + "big5hkscs", + "cp037", + "cp273", + "cp424", + "cp437", + "cp500", + "cp720", + "cp737", + "cp775", + "cp850", + "cp852", + "cp855", + "cp856", + "cp857", + "cp858", + "cp860", + "cp861", + "cp862", + "cp863", + "cp864", + "cp865", + "cp866", + "cp869", + "cp874", + "cp875", + "cp932", + "cp949", + "cp950", + "cp1006", + "cp1026", + "cp1125", + "cp1140", + "cp1250", + "cp1251", + "cp1252", + "cp1253", + "cp1254", + "cp1255", + "cp1256", + "cp1257", + "cp1258", + "euc_jp", + "euc_jis_2004", + "euc_jisx0213", + "euc_kr", + "gb18030", + "hz", + "iso2022_jp", + "iso2022_jp_1", + "iso2022_jp_2", + "iso2022_jp_2004", + "iso2022_jp_3", + "iso2022_jp_ext", + "iso2022_kr", + "latin_1", + "iso8859_2", + "iso8859_3", + "iso8859_4", + "iso8859_5", + "iso8859_6", + "iso8859_7", + "iso8859_8", + "iso8859_9", + "iso8859_10", + "iso8859_11", + "iso8859_13", + "iso8859_14", + "iso8859_15", + "iso8859_16", + "johab", + "koi8_r", + "koi8_t", + "koi8_u", + "kz1048", + "mac_cyrillic", + "mac_greek", + "mac_iceland", + "mac_latin2", + "mac_roman", + "mac_turkish", + "ptcp154", + "shift_jis", + "shift_jis_2004", + "shift_jisx0213", + "utf_32", + "utf_32_be", + "utf_32_le", + "utf_16_be", + "utf_16_le", + "utf_7", + "windows-1250", + "windows-1251", + "windows-1252", + "windows-1253", + "windows-1254", + "windows-1255", + "windows-1256", + "windows-1257", + "windows-1258", + "latin-2", ] def find_codec(blob): detected = chardet.detect(blob[:1024]) - if detected['confidence'] > 0.5: - if detected['encoding'] == "ascii": + if detected["confidence"] > 0.5: + if detected["encoding"] == "ascii": return "utf-8" for c in all_codecs: @@ -88,44 +177,44 @@ QUESTION_PATTERN = [ def has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list): - section, last_section = box['text'], last_box['text'] - q_reg = r'(\w|\W)*?(?:?|\?|\n|$)+' + section, last_section = box["text"], last_box["text"] + q_reg = r"(\w|\W)*?(?:?|\?|\n|$)+" full_reg = reg + q_reg has_bull = re.match(full_reg, section) index_str = None if has_bull: - if 'x0' not in last_box: - last_box['x0'] = box['x0'] - if 'top' not in last_box: - last_box['top'] = box['top'] - if last_bull and box['x0'] - last_box['x0'] > 10: + if "x0" not in last_box: + last_box["x0"] = box["x0"] + if "top" not in last_box: + last_box["top"] = box["top"] + if last_bull and box["x0"] - last_box["x0"] > 10: return None, last_index - if not last_bull and box['x0'] >= last_box['x0'] and box['top'] - last_box['top'] < 20: + if not last_bull and box["x0"] >= last_box["x0"] and box["top"] - last_box["top"] < 20: return None, last_index avg_bull_x0 = 0 if bull_x0_list: avg_bull_x0 = sum(bull_x0_list) / len(bull_x0_list) else: - avg_bull_x0 = box['x0'] - if box['x0'] - avg_bull_x0 > 10: + avg_bull_x0 = box["x0"] + if box["x0"] - avg_bull_x0 > 10: return None, last_index index_str = has_bull.group(1) index = index_int(index_str) - if last_section[-1] == ':' or last_section[-1] == ':': + if last_section[-1] == ":" or last_section[-1] == ":": return None, last_index if not last_index or index >= last_index: - bull_x0_list.append(box['x0']) + bull_x0_list.append(box["x0"]) return has_bull, index - if section[-1] == '?' or section[-1] == '?': - bull_x0_list.append(box['x0']) + if section[-1] == "?" or section[-1] == "?": + bull_x0_list.append(box["x0"]) return has_bull, index - if box['layout_type'] == 'title': - bull_x0_list.append(box['x0']) + if box["layout_type"] == "title": + bull_x0_list.append(box["x0"]) return has_bull, index pure_section = section.lstrip(re.match(reg, section).group()).lower() - ask_reg = r'(what|when|where|how|why|which|who|whose|为什么|为啥|哪)' + ask_reg = r"(what|when|where|how|why|which|who|whose|为什么|为啥|哪)" if re.match(ask_reg, pure_section): - bull_x0_list.append(box['x0']) + bull_x0_list.append(box["x0"]) return has_bull, index return None, last_index @@ -166,38 +255,38 @@ def qbullets_category(sections): return res, QUESTION_PATTERN[res] -BULLET_PATTERN = [[ - r"第[零一二三四五六七八九十百0-9]+(分?编|部分)", - r"第[零一二三四五六七八九十百0-9]+章", - r"第[零一二三四五六七八九十百0-9]+节", - r"第[零一二三四五六七八九十百0-9]+条", - r"[\((][零一二三四五六七八九十百]+[\))]", -], [ - r"第[0-9]+章", - r"第[0-9]+节", - r"[0-9]{,2}[\. 、]", - r"[0-9]{,2}\.[0-9]{,2}[^a-zA-Z/%~-]", - r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}", - r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}", -], [ - r"第[零一二三四五六七八九十百0-9]+章", - r"第[零一二三四五六七八九十百0-9]+节", - r"[零一二三四五六七八九十百]+[ 、]", - r"[\((][零一二三四五六七八九十百]+[\))]", - r"[\((][0-9]{,2}[\))]", -], [ - r"PART (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)", - r"Chapter (I+V?|VI*|XI|IX|X)", - r"Section [0-9]+", - r"Article [0-9]+" -], [ - r"^#[^#]", - r"^##[^#]", - r"^###.*", - r"^####.*", - r"^#####.*", - r"^######.*", -] +BULLET_PATTERN = [ + [ + r"第[零一二三四五六七八九十百0-9]+(分?编|部分)", + r"第[零一二三四五六七八九十百0-9]+章", + r"第[零一二三四五六七八九十百0-9]+节", + r"第[零一二三四五六七八九十百0-9]+条", + r"[\((][零一二三四五六七八九十百]+[\))]", + ], + [ + r"第[0-9]+章", + r"第[0-9]+节", + r"[0-9]{,2}[\. 、]", + r"[0-9]{,2}\.[0-9]{,2}[^a-zA-Z/%~-]", + r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}", + r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}", + ], + [ + r"第[零一二三四五六七八九十百0-9]+章", + r"第[零一二三四五六七八九十百0-9]+节", + r"[零一二三四五六七八九十百]+[ 、]", + r"[\((][零一二三四五六七八九十百]+[\))]", + r"[\((][0-9]{,2}[\))]", + ], + [r"PART (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)", r"Chapter (I+V?|VI*|XI|IX|X)", r"Section [0-9]+", r"Article [0-9]+"], + [ + r"^#[^#]", + r"^##[^#]", + r"^###.*", + r"^####.*", + r"^#####.*", + r"^######.*", + ], ] @@ -207,9 +296,7 @@ def random_choices(arr, k): def not_bullet(line): - patt = [ - r"0", r"[0-9]+ +[0-9~个只-]", r"[0-9]+\.{2,}" - ] + patt = [r"0", r"[0-9]+ +[0-9~个只-]", r"[0-9]+\.{2,}"] return any([re.match(r, line) for r in patt]) @@ -258,7 +345,7 @@ def is_chinese(text): return False chinese = 0 for ch in text: - if '\u4e00' <= ch <= '\u9fff': + if "\u4e00" <= ch <= "\u9fff": chinese += 1 if chinese / len(text) > 0.2: return True @@ -267,6 +354,7 @@ def is_chinese(text): def tokenize(d, txt, eng): from . import rag_tokenizer + d["content_with_weight"] = txt t = re.sub(r"]{0,12})?>", " ", txt) d["content_ltks"] = rag_tokenizer.tokenize(t) @@ -394,7 +482,7 @@ def tokenize_table(tbls, doc, eng, batch_size=10): de = "; " if eng else "; " for i in range(0, len(rows), batch_size): d = copy.deepcopy(doc) - r = de.join(rows[i:i + batch_size]) + r = de.join(rows[i : i + batch_size]) tokenize(d, r, eng) d["doc_type_kwd"] = "table" if img: @@ -546,7 +634,7 @@ def attach_media_context(chunks, table_context_size=0, image_context_size=0): def collect_context_from_sentences(sentences, boundary_idx, token_budget): prev_ctx = [] remaining_prev = token_budget - for s in reversed(sentences[:boundary_idx + 1]): + for s in reversed(sentences[: boundary_idx + 1]): if remaining_prev <= 0: break tks = num_tokens_from_string(s) @@ -561,7 +649,7 @@ def attach_media_context(chunks, table_context_size=0, image_context_size=0): next_ctx = [] remaining_next = token_budget - for s in sentences[boundary_idx + 1:]: + for s in sentences[boundary_idx + 1 :]: if remaining_next <= 0: break tks = num_tokens_from_string(s) @@ -658,7 +746,7 @@ def attach_media_context(chunks, table_context_size=0, image_context_size=0): if page_order and idx in page_order: pos_in_page = page_order.index(idx) if pos_in_page == 0: - for neighbor in page_order[pos_in_page + 1:]: + for neighbor in page_order[pos_in_page + 1 :]: if is_text_chunk(chunks[neighbor]): best_idx = neighbor break @@ -695,8 +783,7 @@ def attach_media_context(chunks, table_context_size=0, image_context_size=0): if "content_ltks" in ck: ck["content_ltks"] = rag_tokenizer.tokenize(combined) if "content_sm_ltks" in ck: - ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize( - ck.get("content_ltks", rag_tokenizer.tokenize(combined))) + ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck.get("content_ltks", rag_tokenizer.tokenize(combined))) if positioned_indices: chunks[:] = [chunks[i] for i in ordered_indices] @@ -706,7 +793,8 @@ def attach_media_context(chunks, table_context_size=0, image_context_size=0): def append_context2table_image4pdf(sections: list, tabls: list, table_context_size=0, return_context=False): from deepdoc.parser import PdfParser - if table_context_size <=0: + + if table_context_size <= 0: return [] if return_context else tabls page_bucket = defaultdict(list) @@ -750,12 +838,12 @@ def append_context2table_image4pdf(sections: list, tabls: list, table_context_si page -= 1 if page < 0 or page not in page_bucket: break - i = len(page_bucket[page]) -1 + i = len(page_bucket[page]) - 1 blks = page_bucket[page] (_, _, _, _), cnt = blks[i] txts = re.split(r"([。!??;!\n]|\. )", cnt, flags=re.DOTALL)[::-1] for j in range(0, len(txts), 2): - txt = (txts[j+1] if j+1 table_context_size: break i -= 1 @@ -775,7 +863,7 @@ def append_context2table_image4pdf(sections: list, tabls: list, table_context_si (_, _, _, _), cnt = blks[i] txts = re.split(r"([。!??;!\n]|\. )", cnt, flags=re.DOTALL) for j in range(0, len(txts), 2): - txt += txts[j] + (txts[j+1] if j+1 table_context_size: break i += 1 @@ -807,7 +895,7 @@ def append_context2table_image4pdf(sections: list, tabls: list, table_context_si (_, _, t, b), txt = blks[i] if b > top: break - (_, _, _t, _b), _txt = blks[i+1] + (_, _, _t, _b), _txt = blks[i + 1] if _t < _bott: i += 1 continue @@ -847,13 +935,12 @@ def add_positions(d, poss): def remove_contents_table(sections, eng=False): i = 0 while i < len(sections): + def get(i): nonlocal sections - return (sections[i] if isinstance(sections[i], - type("")) else sections[i][0]).strip() + return (sections[i] if isinstance(sections[i], type("")) else sections[i][0]).strip() - if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$", - re.sub(r"( | |\u3000)+", "", get(i).split("@@")[0], flags=re.IGNORECASE)): + if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$", re.sub(r"( | |\u3000)+", "", get(i).split("@@")[0], flags=re.IGNORECASE)): i += 1 continue sections.pop(i) @@ -935,8 +1022,7 @@ def tree_merge(bull, sections, depth): sections = [(s, "") for s in sections] # filter out position information in pdf sections - sections = [(t, o) for t, o in sections if - t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())] + sections = [(t, o) for t, o in sections if t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())] def get_level(bull, section): text, layout = section @@ -982,8 +1068,7 @@ def hierarchical_merge(bull, sections, depth): return [] if isinstance(sections[0], type("")): sections = [(s, "") for s in sections] - sections = [(t, o) for t, o in sections if - t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())] + sections = [(t, o) for t, o in sections if t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())] bullets_size = len(BULLET_PATTERN[bull]) levels = [[] for _ in range(bullets_size + 2)] @@ -1069,6 +1154,7 @@ def hierarchical_merge(bull, sections, depth): def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0): from deepdoc.parser.pdf_parser import RAGFlowPdfParser + if not sections: return [] if isinstance(sections, str): @@ -1086,10 +1172,10 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。; if tnum < 8: pos = "" # Ensure that the length of the merged chunk does not exceed chunk_token_num - if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent) / 100.: + if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent) / 100.0: if cks: overlapped = RAGFlowPdfParser.remove_tag(cks[-1]) - t = overlapped[int(len(overlapped) * (100 - overlapped_percent) / 100.):] + t + t = overlapped[int(len(overlapped) * (100 - overlapped_percent) / 100.0) :] + t # Recount with the overlap prefix included, else chunks overshoot chunk_token_num. tnum = num_tokens_from_string(t) if t.find(pos) < 0: @@ -1140,6 +1226,7 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。; def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0): from deepdoc.parser.pdf_parser import RAGFlowPdfParser + if not texts or len(texts) != len(images): return [], [] cks = [""] @@ -1154,10 +1241,10 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。 if tnum < 8: pos = "" # Ensure that the length of the merged chunk does not exceed chunk_token_num - if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent) / 100.: + if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent) / 100.0: if cks: overlapped = RAGFlowPdfParser.remove_tag(cks[-1]) - t = overlapped[int(len(overlapped) * (100 - overlapped_percent) / 100.):] + t + t = overlapped[int(len(overlapped) * (100 - overlapped_percent) / 100.0) :] + t # Recount with the overlap prefix included, else chunks overshoot chunk_token_num. tnum = num_tokens_from_string(t) if t.find(pos) < 0: @@ -1226,7 +1313,7 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。 def docx_question_level(p, bull=-1): txt = re.sub(r"\u3000", " ", p.text).strip() - if hasattr(p.style, 'name') and p.style.name and p.style.name.startswith('Heading'): + if hasattr(p.style, "name") and p.style.name and p.style.name.startswith("Heading"): # Heading styles are usually "Heading N", but the base "Heading" style, # custom "Heading"-prefixed styles, or "HeadingN" (no space) have no # space-separated trailing integer. Extract the level digits safely and @@ -1250,8 +1337,7 @@ def concat_img(img1, img2): if img1 is img2: return img1 - if (img1 is None or isinstance(img1, LazyImage)) and \ - (img2 is None or isinstance(img2, LazyImage)): + if (img1 is None or isinstance(img1, LazyImage)) and (img2 is None or isinstance(img2, LazyImage)): if img1 and not img2: return img1 if not img1 and img2: @@ -1283,12 +1369,13 @@ def concat_img(img1, img2): new_width = max(width1, width2) new_height = height1 + height2 - new_image = Image.new('RGB', (new_width, new_height)) + new_image = Image.new("RGB", (new_width, new_height)) new_image.paste(img1, (0, 0)) new_image.paste(img2, (0, height1)) return new_image + def _build_cks(sections, delimiter): cks = [] tables = [] @@ -1300,9 +1387,7 @@ def _build_cks(sections, delimiter): if has_custom: # escape delimiters and build alternation pattern, longest first - custom_pattern = "|".join( - re.escape(t) for t in sorted(set(custom_delimiters), key=len, reverse=True) - ) + custom_pattern = "|".join(re.escape(t) for t in sorted(set(custom_delimiters), key=len, reverse=True)) # capture delimiters so they appear in re.split results pattern = r"(%s)" % custom_pattern @@ -1318,24 +1403,28 @@ def _build_cks(sections, delimiter): # table chunk ck_text = text + str(table) idx = len(cks) - cks.append({ - "text": ck_text, - "image": image, - "ck_type": "table", - "tk_nums": num_tokens_from_string(ck_text), - }) + cks.append( + { + "text": ck_text, + "image": image, + "ck_type": "table", + "tk_nums": num_tokens_from_string(ck_text), + } + ) tables.append(idx) continue if image: # image chunk (text kept as-is for context) idx = len(cks) - cks.append({ - "text": text, - "image": image, - "ck_type": "image", - "tk_nums": num_tokens_from_string(text), - }) + cks.append( + { + "text": text, + "image": image, + "ck_type": "image", + "tk_nums": num_tokens_from_string(text), + } + ) images.append(idx) continue @@ -1347,12 +1436,14 @@ def _build_cks(sections, delimiter): if not sub_sec or not sub_sec.strip(): if seg and seg.strip(): s = seg.strip() - cks.append({ - "text": s, - "image": None, - "ck_type": "text", - "tk_nums": num_tokens_from_string(s), - }) + cks.append( + { + "text": s, + "image": None, + "ck_type": "text", + "tk_nums": num_tokens_from_string(s), + } + ) seg = "" continue @@ -1360,37 +1451,42 @@ def _build_cks(sections, delimiter): if re.fullmatch(custom_pattern, sub_sec.strip()): if seg and seg.strip(): s = seg.strip() - cks.append({ - "text": s, - "image": None, - "ck_type": "text", - "tk_nums": num_tokens_from_string(s), - }) + cks.append( + { + "text": s, + "image": None, + "ck_type": "text", + "tk_nums": num_tokens_from_string(s), + } + ) seg = "" continue # ③ normal text content → accumulate seg += sub_sec else: - if text and text.strip(): t = text.strip() - cks.append({ - "text": t, - "image": None, - "ck_type": "text", - "tk_nums": num_tokens_from_string(t), - }) + cks.append( + { + "text": t, + "image": None, + "ck_type": "text", + "tk_nums": num_tokens_from_string(t), + } + ) # final flush after loop (only when custom delimiters are used) if has_custom and seg and seg.strip(): s = seg.strip() - cks.append({ - "text": s, - "image": None, - "ck_type": "text", - "tk_nums": num_tokens_from_string(s), - }) + cks.append( + { + "text": s, + "image": None, + "ck_type": "text", + "tk_nums": num_tokens_from_string(s), + } + ) return cks, tables, images, has_custom @@ -1485,7 +1581,7 @@ def _merge_cks(cks, chunk_token_num, has_custom): image_idxs.append(len(merged) - 1) continue - if prev_text_ck<0 or merged[prev_text_ck]["tk_nums"] >= chunk_token_num or has_custom: + if prev_text_ck < 0 or merged[prev_text_ck]["tk_nums"] >= chunk_token_num or has_custom: merged.append(cks[i]) prev_text_ck = len(merged) - 1 continue @@ -1498,10 +1594,11 @@ def _merge_cks(cks, chunk_token_num, has_custom): def naive_merge_docx( sections, - chunk_token_num = 128, + chunk_token_num=128, delimiter="\n。;!?", table_context_size=0, - image_context_size=0,): + image_context_size=0, +): if not sections: return [], [] @@ -1515,7 +1612,7 @@ def naive_merge_docx( if image_context_size > 0: for i in images: _add_context(cks, i, image_context_size) - + merged_cks, merged_image_idx = _merge_cks(cks, chunk_token_num, has_custom) return merged_cks, merged_image_idx @@ -1532,7 +1629,7 @@ def get_delimiters(delimiters: str): for m in re.finditer(r"`([^`]+)`", delimiters, re.I): f, t = m.span() dels.append(m.group(1)) - dels.extend(list(delimiters[s: f])) + dels.extend(list(delimiters[s:f])) s = t if s < len(delimiters): dels.extend(list(delimiters[s:])) diff --git a/rag/nlp/query.py b/rag/nlp/query.py index aefc15ed4e..081da1325f 100644 --- a/rag/nlp/query.py +++ b/rag/nlp/query.py @@ -70,11 +70,10 @@ class FulltextQueryer(QueryBase): # (e.g. WordNet returns "cat-o'-nine-tails" for "cat") syn = [rag_tokenizer.tokenize(s).replace("'", "") for s in self.syn.lookup(tk)] keywords.extend(syn) - syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()] + syn = ['"{}"^{:.4f}'.format(s, w / 4.0) for s in syn if s.strip()] syns.append(" ".join(syn)) - q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if - tk and not re.match(r"[.^+\(\)-]", tk)] + q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if tk and not re.match(r"[.^+\(\)-]", tk)] for i in range(1, len(tks_w)): left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip() if not left or not right: @@ -90,9 +89,7 @@ class FulltextQueryer(QueryBase): if not q: q.append(txt) query = " ".join(q) - return MatchTextExpr( - self.query_fields, query, 100, {"original_query": original_query} - ), keywords + return MatchTextExpr(self.query_fields, query, 100, {"original_query": original_query}), keywords def need_fine_grained_tokenize(tk): if len(tk) < 3: @@ -114,11 +111,7 @@ class FulltextQueryer(QueryBase): logging.debug(json.dumps(twts, ensure_ascii=False)) tms = [] for tk, w in sorted(twts, key=lambda x: x[1] * -1): - sm = ( - rag_tokenizer.fine_grained_tokenize(tk).split() - if need_fine_grained_tokenize(tk) - else [] - ) + sm = rag_tokenizer.fine_grained_tokenize(tk).split() if need_fine_grained_tokenize(tk) else [] sm = [ re.sub( r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+", @@ -139,7 +132,7 @@ class FulltextQueryer(QueryBase): if len(keywords) < 32: keywords.extend([s for s in tk_syns if s]) tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] - tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns] + tk_syns = [f'"{s}"' if s.find(" ") > 0 else s for s in tk_syns] if len(keywords) >= 32: break @@ -159,13 +152,7 @@ class FulltextQueryer(QueryBase): if len(twts) > 1: tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt) - syns = " OR ".join( - [ - '"%s"' - % rag_tokenizer.tokenize(self.sub_special_char(s)) - for s in syns - ] - ) + syns = " OR ".join(['"%s"' % rag_tokenizer.tokenize(self.sub_special_char(s)) for s in syns]) if syns and tms: tms = f"({tms})^5 OR ({syns})^0.7" @@ -175,9 +162,7 @@ class FulltextQueryer(QueryBase): query = " OR ".join([f"({t})" for t in qs if t]) if not query: query = otxt - return MatchTextExpr( - self.query_fields, query, 100, {"minimum_should_match": min_match, "original_query": original_query} - ), keywords + return MatchTextExpr(self.query_fields, query, 100, {"minimum_should_match": min_match, "original_query": original_query}), keywords return None, keywords def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7): @@ -198,9 +183,9 @@ class FulltextQueryer(QueryBase): wts = self.tw.weights(tks, preprocess=False) for i, (t, c) in enumerate(wts): d[t] += c * 0.4 - if i+1 < len(wts): - _t, _c = wts[i+1] - d[t+_t] += max(c, _c) * 0.6 + if i + 1 < len(wts): + _t, _c = wts[i + 1] + d[t + _t] += max(c, _c) * 0.6 return d atks = to_dict(atks) @@ -232,7 +217,7 @@ class FulltextQueryer(QueryBase): tk_syns = self.syn.lookup(tk) tk_syns = [self.sub_special_char(s) for s in tk_syns] tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] - tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns] + tk_syns = [f'"{s}"' if s.find(" ") > 0 else s for s in tk_syns] tk = self.sub_special_char(tk) if tk.find(" ") > 0: tk = '"%s"' % tk @@ -241,6 +226,4 @@ class FulltextQueryer(QueryBase): if tk: keywords.append(f"{tk}^{w}") - return MatchTextExpr(self.query_fields, " ".join(keywords), 100, - {"minimum_should_match": min(3, round(len(keywords) / 10)), - "original_query": " ".join(origin_keywords)}) + return MatchTextExpr(self.query_fields, " ".join(keywords), 100, {"minimum_should_match": min(3, round(len(keywords) / 10)), "original_query": " ".join(origin_keywords)}) diff --git a/rag/nlp/rag_tokenizer.py b/rag/nlp/rag_tokenizer.py index 369d0448d7..2cbcdfa233 100644 --- a/rag/nlp/rag_tokenizer.py +++ b/rag/nlp/rag_tokenizer.py @@ -15,17 +15,20 @@ # import infinity.rag_tokenizer -class RagTokenizer(infinity.rag_tokenizer.RagTokenizer): + +class RagTokenizer(infinity.rag_tokenizer.RagTokenizer): def tokenize(self, line: str) -> str: - from common import settings # moved from the top of the file to avoid circular import + from common import settings # moved from the top of the file to avoid circular import + if settings.DOC_ENGINE_INFINITY: return line else: return super().tokenize(line) def fine_grained_tokenize(self, tks: str) -> str: - from common import settings # moved from the top of the file to avoid circular import + from common import settings # moved from the top of the file to avoid circular import + if settings.DOC_ENGINE_INFINITY: return tks else: diff --git a/rag/nlp/surname.py b/rag/nlp/surname.py index 39c28be952..0df393a5a8 100644 --- a/rag/nlp/surname.py +++ b/rag/nlp/surname.py @@ -14,131 +14,560 @@ # limitations under the License. # -m = set(["赵", "钱", "孙", "李", - "周", "吴", "郑", "王", - "冯", "陈", "褚", "卫", - "蒋", "沈", "韩", "杨", - "朱", "秦", "尤", "许", - "何", "吕", "施", "张", - "孔", "曹", "严", "华", - "金", "魏", "陶", "姜", - "戚", "谢", "邹", "喻", - "柏", "水", "窦", "章", - "云", "苏", "潘", "葛", - "奚", "范", "彭", "郎", - "鲁", "韦", "昌", "马", - "苗", "凤", "花", "方", - "俞", "任", "袁", "柳", - "酆", "鲍", "史", "唐", - "费", "廉", "岑", "薛", - "雷", "贺", "倪", "汤", - "滕", "殷", "罗", "毕", - "郝", "邬", "安", "常", - "乐", "于", "时", "傅", - "皮", "卞", "齐", "康", - "伍", "余", "元", "卜", - "顾", "孟", "平", "黄", - "和", "穆", "萧", "尹", - "姚", "邵", "湛", "汪", - "祁", "毛", "禹", "狄", - "米", "贝", "明", "臧", - "计", "伏", "成", "戴", - "谈", "宋", "茅", "庞", - "熊", "纪", "舒", "屈", - "项", "祝", "董", "梁", - "杜", "阮", "蓝", "闵", - "席", "季", "麻", "强", - "贾", "路", "娄", "危", - "江", "童", "颜", "郭", - "梅", "盛", "林", "刁", - "钟", "徐", "邱", "骆", - "高", "夏", "蔡", "田", - "樊", "胡", "凌", "霍", - "虞", "万", "支", "柯", - "昝", "管", "卢", "莫", - "经", "房", "裘", "缪", - "干", "解", "应", "宗", - "丁", "宣", "贲", "邓", - "郁", "单", "杭", "洪", - "包", "诸", "左", "石", - "崔", "吉", "钮", "龚", - "程", "嵇", "邢", "滑", - "裴", "陆", "荣", "翁", - "荀", "羊", "於", "惠", - "甄", "曲", "家", "封", - "芮", "羿", "储", "靳", - "汲", "邴", "糜", "松", - "井", "段", "富", "巫", - "乌", "焦", "巴", "弓", - "牧", "隗", "山", "谷", - "车", "侯", "宓", "蓬", - "全", "郗", "班", "仰", - "秋", "仲", "伊", "宫", - "宁", "仇", "栾", "暴", - "甘", "钭", "厉", "戎", - "祖", "武", "符", "刘", - "景", "詹", "束", "龙", - "叶", "幸", "司", "韶", - "郜", "黎", "蓟", "薄", - "印", "宿", "白", "怀", - "蒲", "邰", "从", "鄂", - "索", "咸", "籍", "赖", - "卓", "蔺", "屠", "蒙", - "池", "乔", "阴", "鬱", - "胥", "能", "苍", "双", - "闻", "莘", "党", "翟", - "谭", "贡", "劳", "逄", - "姬", "申", "扶", "堵", - "冉", "宰", "郦", "雍", - "郤", "璩", "桑", "桂", - "濮", "牛", "寿", "通", - "边", "扈", "燕", "冀", - "郏", "浦", "尚", "农", - "温", "别", "庄", "晏", - "柴", "瞿", "阎", "充", - "慕", "连", "茹", "习", - "宦", "艾", "鱼", "容", - "向", "古", "易", "慎", - "戈", "廖", "庾", "终", - "暨", "居", "衡", "步", - "都", "耿", "满", "弘", - "匡", "国", "文", "寇", - "广", "禄", "阙", "东", - "欧", "殳", "沃", "利", - "蔚", "越", "夔", "隆", - "师", "巩", "厍", "聂", - "晁", "勾", "敖", "融", - "冷", "訾", "辛", "阚", - "那", "简", "饶", "空", - "曾", "母", "沙", "乜", - "养", "鞠", "须", "丰", - "巢", "关", "蒯", "相", - "查", "后", "荆", "红", - "游", "竺", "权", "逯", - "盖", "益", "桓", "公", - "兰", "原", "乞", "西", "阿", "肖", "丑", "位", "曽", "巨", "德", "代", "圆", "尉", "仵", "纳", "仝", "脱", - "丘", "但", "展", "迪", "付", "覃", "晗", "特", "隋", "苑", "奥", "漆", "谌", "郄", "练", "扎", "邝", "渠", - "信", "门", "陳", "化", "原", "密", "泮", "鹿", "赫", - "万俟", "司马", "上官", "欧阳", - "夏侯", "诸葛", "闻人", "东方", - "赫连", "皇甫", "尉迟", "公羊", - "澹台", "公冶", "宗政", "濮阳", - "淳于", "单于", "太叔", "申屠", - "公孙", "仲孙", "轩辕", "令狐", - "钟离", "宇文", "长孙", "慕容", - "鲜于", "闾丘", "司徒", "司空", - "亓官", "司寇", "仉督", "子车", - "颛孙", "端木", "巫马", "公西", - "漆雕", "乐正", "壤驷", "公良", - "拓跋", "夹谷", "宰父", "榖梁", - "晋", "楚", "闫", "法", "汝", "鄢", "涂", "钦", - "段干", "百里", "东郭", "南门", - "呼延", "归", "海", "羊舌", "微", "生", - "岳", "帅", "缑", "亢", "况", "后", "有", "琴", - "梁丘", "左丘", "东门", "西门", - "商", "牟", "佘", "佴", "伯", "赏", "南宫", - "墨", "哈", "谯", "笪", "年", "爱", "阳", "佟", - "第五", "言", "福"]) +m = set( + [ + "赵", + "钱", + "孙", + "李", + "周", + "吴", + "郑", + "王", + "冯", + "陈", + "褚", + "卫", + "蒋", + "沈", + "韩", + "杨", + "朱", + "秦", + "尤", + "许", + "何", + "吕", + "施", + "张", + "孔", + "曹", + "严", + "华", + "金", + "魏", + "陶", + "姜", + "戚", + "谢", + "邹", + "喻", + "柏", + "水", + "窦", + "章", + "云", + "苏", + "潘", + "葛", + "奚", + "范", + "彭", + "郎", + "鲁", + "韦", + "昌", + "马", + "苗", + "凤", + "花", + "方", + "俞", + "任", + "袁", + "柳", + "酆", + "鲍", + "史", + "唐", + "费", + "廉", + "岑", + "薛", + "雷", + "贺", + "倪", + "汤", + "滕", + "殷", + "罗", + "毕", + "郝", + "邬", + "安", + "常", + "乐", + "于", + "时", + "傅", + "皮", + "卞", + "齐", + "康", + "伍", + "余", + "元", + "卜", + "顾", + "孟", + "平", + "黄", + "和", + "穆", + "萧", + "尹", + "姚", + "邵", + "湛", + "汪", + "祁", + "毛", + "禹", + "狄", + "米", + "贝", + "明", + "臧", + "计", + "伏", + "成", + "戴", + "谈", + "宋", + "茅", + "庞", + "熊", + "纪", + "舒", + "屈", + "项", + "祝", + "董", + "梁", + "杜", + "阮", + "蓝", + "闵", + "席", + "季", + "麻", + "强", + "贾", + "路", + "娄", + "危", + "江", + "童", + "颜", + "郭", + "梅", + "盛", + "林", + "刁", + "钟", + "徐", + "邱", + "骆", + "高", + "夏", + "蔡", + "田", + "樊", + "胡", + "凌", + "霍", + "虞", + "万", + "支", + "柯", + "昝", + "管", + "卢", + "莫", + "经", + "房", + "裘", + "缪", + "干", + "解", + "应", + "宗", + "丁", + "宣", + "贲", + "邓", + "郁", + "单", + "杭", + "洪", + "包", + "诸", + "左", + "石", + "崔", + "吉", + "钮", + "龚", + "程", + "嵇", + "邢", + "滑", + "裴", + "陆", + "荣", + "翁", + "荀", + "羊", + "於", + "惠", + "甄", + "曲", + "家", + "封", + "芮", + "羿", + "储", + "靳", + "汲", + "邴", + "糜", + "松", + "井", + "段", + "富", + "巫", + "乌", + "焦", + "巴", + "弓", + "牧", + "隗", + "山", + "谷", + "车", + "侯", + "宓", + "蓬", + "全", + "郗", + "班", + "仰", + "秋", + "仲", + "伊", + "宫", + "宁", + "仇", + "栾", + "暴", + "甘", + "钭", + "厉", + "戎", + "祖", + "武", + "符", + "刘", + "景", + "詹", + "束", + "龙", + "叶", + "幸", + "司", + "韶", + "郜", + "黎", + "蓟", + "薄", + "印", + "宿", + "白", + "怀", + "蒲", + "邰", + "从", + "鄂", + "索", + "咸", + "籍", + "赖", + "卓", + "蔺", + "屠", + "蒙", + "池", + "乔", + "阴", + "鬱", + "胥", + "能", + "苍", + "双", + "闻", + "莘", + "党", + "翟", + "谭", + "贡", + "劳", + "逄", + "姬", + "申", + "扶", + "堵", + "冉", + "宰", + "郦", + "雍", + "郤", + "璩", + "桑", + "桂", + "濮", + "牛", + "寿", + "通", + "边", + "扈", + "燕", + "冀", + "郏", + "浦", + "尚", + "农", + "温", + "别", + "庄", + "晏", + "柴", + "瞿", + "阎", + "充", + "慕", + "连", + "茹", + "习", + "宦", + "艾", + "鱼", + "容", + "向", + "古", + "易", + "慎", + "戈", + "廖", + "庾", + "终", + "暨", + "居", + "衡", + "步", + "都", + "耿", + "满", + "弘", + "匡", + "国", + "文", + "寇", + "广", + "禄", + "阙", + "东", + "欧", + "殳", + "沃", + "利", + "蔚", + "越", + "夔", + "隆", + "师", + "巩", + "厍", + "聂", + "晁", + "勾", + "敖", + "融", + "冷", + "訾", + "辛", + "阚", + "那", + "简", + "饶", + "空", + "曾", + "母", + "沙", + "乜", + "养", + "鞠", + "须", + "丰", + "巢", + "关", + "蒯", + "相", + "查", + "后", + "荆", + "红", + "游", + "竺", + "权", + "逯", + "盖", + "益", + "桓", + "公", + "兰", + "原", + "乞", + "西", + "阿", + "肖", + "丑", + "位", + "曽", + "巨", + "德", + "代", + "圆", + "尉", + "仵", + "纳", + "仝", + "脱", + "丘", + "但", + "展", + "迪", + "付", + "覃", + "晗", + "特", + "隋", + "苑", + "奥", + "漆", + "谌", + "郄", + "练", + "扎", + "邝", + "渠", + "信", + "门", + "陳", + "化", + "原", + "密", + "泮", + "鹿", + "赫", + "万俟", + "司马", + "上官", + "欧阳", + "夏侯", + "诸葛", + "闻人", + "东方", + "赫连", + "皇甫", + "尉迟", + "公羊", + "澹台", + "公冶", + "宗政", + "濮阳", + "淳于", + "单于", + "太叔", + "申屠", + "公孙", + "仲孙", + "轩辕", + "令狐", + "钟离", + "宇文", + "长孙", + "慕容", + "鲜于", + "闾丘", + "司徒", + "司空", + "亓官", + "司寇", + "仉督", + "子车", + "颛孙", + "端木", + "巫马", + "公西", + "漆雕", + "乐正", + "壤驷", + "公良", + "拓跋", + "夹谷", + "宰父", + "榖梁", + "晋", + "楚", + "闫", + "法", + "汝", + "鄢", + "涂", + "钦", + "段干", + "百里", + "东郭", + "南门", + "呼延", + "归", + "海", + "羊舌", + "微", + "生", + "岳", + "帅", + "缑", + "亢", + "况", + "后", + "有", + "琴", + "梁丘", + "左丘", + "东门", + "西门", + "商", + "牟", + "佘", + "佴", + "伯", + "赏", + "南宫", + "墨", + "哈", + "谯", + "笪", + "年", + "爱", + "阳", + "佟", + "第五", + "言", + "福", + ] +) -def isit(n): return n.strip() in m +def isit(n): + return n.strip() in m diff --git a/rag/nlp/synonym.py b/rag/nlp/synonym.py index 19744c2542..f29f3cfa52 100644 --- a/rag/nlp/synonym.py +++ b/rag/nlp/synonym.py @@ -30,6 +30,7 @@ try: except Exception: logging.warning("Fail to load wordnet.ensure_loaded()") + class Dealer: def __init__(self, redis=None): @@ -38,17 +39,16 @@ class Dealer: self.dictionary = None path = os.path.join(get_project_base_directory(), "rag/res", "synonym.json") try: - with open(path, 'r') as f: + with open(path, "r") as f: self.dictionary = json.load(f) - self.dictionary = { (k.lower() if isinstance(k, str) else k): v for k, v in self.dictionary.items() } + self.dictionary = {(k.lower() if isinstance(k, str) else k): v for k, v in self.dictionary.items()} except Exception: logging.warning("Missing synonym.json") self.dictionary = {} if not redis: - logging.warning( - "Realtime synonym is disabled, since no redis connection.") + logging.warning("Realtime synonym is disabled, since no redis connection.") if not len(self.dictionary.keys()): logging.warning("Fail to load synonym") @@ -76,7 +76,6 @@ class Dealer: except Exception as e: logging.error("Fail to load synonym!" + str(e)) - def lookup(self, tk, topn=8): if not tk or not isinstance(tk, str): return [] @@ -93,18 +92,15 @@ class Dealer: # 2) If not found and tk is purely alphabetical → fallback to WordNet if re.fullmatch(r"[a-z]+", tk): - wn_set = { - re.sub("_", " ", syn.name().split(".")[0]) - for syn in wordnet.synsets(tk) - } + wn_set = {re.sub("_", " ", syn.name().split(".")[0]) for syn in wordnet.synsets(tk)} wn_set.discard(tk) # Remove the original token itself wn_res = [t for t in wn_set if t] return wn_res[:topn] # 3) Nothing found in either source return [] - -if __name__ == '__main__': + +if __name__ == "__main__": dl = Dealer() print(dl.dictionary) diff --git a/rag/nlp/term_weight.py b/rag/nlp/term_weight.py index 1a7412de9e..43b018a429 100644 --- a/rag/nlp/term_weight.py +++ b/rag/nlp/term_weight.py @@ -26,37 +26,41 @@ from common.file_utils import get_project_base_directory class Dealer: def __init__(self): - self.stop_words = set(["请问", - "您", - "你", - "我", - "他", - "是", - "的", - "就", - "有", - "于", - "及", - "即", - "在", - "为", - "最", - "有", - "从", - "以", - "了", - "将", - "与", - "吗", - "吧", - "中", - "#", - "什么", - "怎么", - "哪个", - "哪些", - "啥", - "相关"]) + self.stop_words = set( + [ + "请问", + "您", + "你", + "我", + "他", + "是", + "的", + "就", + "有", + "于", + "及", + "即", + "在", + "为", + "最", + "有", + "从", + "以", + "了", + "将", + "与", + "吗", + "吧", + "中", + "#", + "什么", + "怎么", + "哪个", + "哪些", + "啥", + "相关", + ] + ) def load_dict(fnm): res = {} @@ -91,19 +95,15 @@ class Dealer: logging.warning("Load term.freq FAIL!") def pretoken(self, txt, num=False, stpwd=True): - patt = [ - r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]" - ] - rewt = [ - ] + patt = [r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"] + rewt = [] for p, r in rewt: txt = re.sub(p, r, txt) res = [] for t in rag_tokenizer.tokenize(txt).split(): tk = t - if (stpwd and tk in self.stop_words) or ( - re.match(r"[0-9]$", tk) and not num): + if (stpwd and tk in self.stop_words) or (re.match(r"[0-9]$", tk) and not num): continue for p in patt: if re.match(p, t): @@ -121,21 +121,19 @@ class Dealer: res, i = [], 0 while i < len(tks): j = i - if i == 0 and one_term(tks[i]) and len( - tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位 + if i == 0 and one_term(tks[i]) and len(tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位 res.append(" ".join(tks[0:2])) i = 2 continue - while j < len( - tks) and tks[j] and tks[j] not in self.stop_words and one_term(tks[j]): + while j < len(tks) and tks[j] and tks[j] not in self.stop_words and one_term(tks[j]): j += 1 if j - i > 1: if j - i < 5: res.append(" ".join(tks[i:j])) i = j else: - res.append(" ".join(tks[i:i + 2])) + res.append(" ".join(tks[i : i + 2])) i = i + 2 else: if len(tks[i]) > 0: @@ -153,9 +151,7 @@ class Dealer: def split(self, txt): tks = [] for t in re.sub(r"[ \t]+", " ", txt).split(): - if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \ - re.match(r".*[a-zA-Z]$", t) and tks and \ - self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func": + if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and re.match(r".*[a-zA-Z]$", t) and tks and self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func": tks[-1] = tks[-1] + " " + t else: tks.append(t) @@ -174,8 +170,7 @@ class Dealer: return 0.01 if not self.ne or t not in self.ne: return 1 - m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3, - "firstnm": 1} + m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3, "firstnm": 1} return m[self.ne[t]] def postag(t): @@ -202,7 +197,7 @@ class Dealer: if not s and len(t) >= 4: s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1] if len(s) > 1: - s = np.min([freq(tt) for tt in s]) / 6. + s = np.min([freq(tt) for tt in s]) / 6.0 else: s = 0 @@ -218,7 +213,7 @@ class Dealer: elif len(t) >= 4: s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1] if len(s) > 1: - return max(3, np.min([df(tt) for tt in s]) / 6.) + return max(3, np.min([df(tt) for tt in s]) / 6.0) return 3 @@ -229,8 +224,7 @@ class Dealer: if not preprocess: idf1 = np.array([idf(freq(t), 10000000) for t in tks]) idf2 = np.array([idf(df(t), 1000000000) for t in tks]) - wts = (0.3 * idf1 + 0.7 * idf2) * \ - np.array([ner(t) * postag(t) for t in tks]) + wts = (0.3 * idf1 + 0.7 * idf2) * np.array([ner(t) * postag(t) for t in tks]) wts = [s for s in wts] tw = list(zip(tks, wts)) else: @@ -238,8 +232,7 @@ class Dealer: tt = self.token_merge(self.pretoken(tk, True)) idf1 = np.array([idf(freq(t), 10000000) for t in tt]) idf2 = np.array([idf(df(t), 1000000000) for t in tt]) - wts = (0.3 * idf1 + 0.7 * idf2) * \ - np.array([ner(t) * postag(t) for t in tt]) + wts = (0.3 * idf1 + 0.7 * idf2) * np.array([ner(t) * postag(t) for t in tt]) wts = [s for s in wts] tw.extend(zip(tt, wts)) diff --git a/rag/prompts/__init__.py b/rag/prompts/__init__.py index a2d991705a..d9b966c335 100644 --- a/rag/prompts/__init__.py +++ b/rag/prompts/__init__.py @@ -1,6 +1,5 @@ from . import generator -__all__ = [name for name in dir(generator) - if not name.startswith('_')] +__all__ = [name for name in dir(generator) if not name.startswith("_")] globals().update({name: getattr(generator, name) for name in __all__}) diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index 62dad5823e..5f13ae4ae1 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -20,6 +20,7 @@ import time + start_ts = time.perf_counter() import asyncio @@ -84,6 +85,7 @@ from common.log_utils import init_root_logger from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc from common.versions import get_ragflow_version from box_sdk_gen import BoxOAuth, OAuthConfig, AccessToken + MAX_CONCURRENT_TASKS = int(os.environ.get("MAX_CONCURRENT_TASKS", "5")) task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS) @@ -108,10 +110,11 @@ def _redact_mailbox(value: str) -> str: class SyncBase: """ Base class for all data source synchronization connectors. - - Defines the standard interface for connecting to external APIs, polling for + + Defines the standard interface for connecting to external APIs, polling for new or updated documents, and managing synchronization state intervals. """ + SOURCE_NAME: str = None def __init__(self, conf: dict) -> None: @@ -129,10 +132,7 @@ class SyncBase: if task.get("reindex") != "1" and task.get("poll_range_start"): window_start = task["poll_range_start"] window_end = datetime.now(timezone.utc) - return ( - f"sync window: {cls._format_window_boundary(window_start)}" - f" -> {cls._format_window_boundary(window_end)}" - ) + return f"sync window: {cls._format_window_boundary(window_start)} -> {cls._format_window_boundary(window_end)}" @classmethod def log_connection( @@ -152,9 +152,9 @@ class SyncBase: async def __call__(self, task: dict): """ Entry point for executing a synchronization task worker. - - Manages task execution boundaries including status logging, asynchronous - timeouts, and top-level exception handling, while delegating the core + + Manages task execution boundaries including status logging, asynchronous + timeouts, and top-level exception handling, while delegating the core ingestion logic to `_run_task_logic`. """ SyncLogsService.start(task["id"], task["connector_id"]) @@ -169,15 +169,13 @@ class SyncBase: return except Exception as ex: - msg = "\n".join([ - "".join(traceback.format_exception_only(None, ex)).strip(), - "".join(traceback.format_exception(None, ex, ex.__traceback__)).strip(), - ]) - SyncLogsService.update_by_id(task["id"], { - "status": TaskStatus.FAIL, - "full_exception_trace": msg, - "error_msg": str(ex) - }) + msg = "\n".join( + [ + "".join(traceback.format_exception_only(None, ex)).strip(), + "".join(traceback.format_exception(None, ex, ex.__traceback__)).strip(), + ] + ) + SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg, "error_msg": str(ex)}) return task_type = task.get("task_type", ConnectorTaskType.SYNC) @@ -254,17 +252,10 @@ class SyncBase: try: e, kb = KnowledgebaseService.get_by_id(task["kb_id"]) - err, dids = SyncLogsService.duplicate_and_parse( - kb, docs, task["tenant_id"], - f"{self.SOURCE_NAME}/{task['connector_id']}", - task["auto_parse"] - ) + err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{self.SOURCE_NAME}/{task['connector_id']}", task["auto_parse"]) if err: had_parse_errors = True - SyncLogsService.increase_docs( - task["id"], max_update, - len(docs), "\n".join(err), len(err) - ) + SyncLogsService.increase_docs(task["id"], max_update, len(docs), "\n".join(err), len(err)) changed_doc_ids = set(dids) updated_in_batch = len(changed_doc_ids & existing_doc_ids) added_in_batch = len(changed_doc_ids) - updated_in_batch @@ -289,20 +280,12 @@ class SyncBase: next_update_info = self._format_window_boundary(next_update) total_changed_docs = added_docs + updated_docs - summary = ( - f"{prefix}sync summary till {next_update_info}: " - f"total={total_changed_docs}, added={added_docs}, " - f"updated={updated_docs}" - ) + summary = f"{prefix}sync summary till {next_update_info}: total={total_changed_docs}, added={added_docs}, updated={updated_docs}" if failed_docs > 0: summary = f"{summary}, skipped={failed_docs}" logging.info(summary) - if ( - isinstance(self, _CursorPersistingSyncBase) - and failed_docs == 0 - and not had_parse_errors - ): + if isinstance(self, _CursorPersistingSyncBase) and failed_docs == 0 and not had_parse_errors: self.connector.persist_sync_state() SyncLogsService.done(task["id"], task["connector_id"]) task["poll_range_start"] = next_update @@ -397,7 +380,8 @@ class _BlobLikeBase(SyncBase): """ source_type = f"{self.SOURCE_NAME}/{task['connector_id']}" existing_fingerprints = DocumentService.list_id_content_hash_map_by_kb_and_source_type( - task["kb_id"], source_type, + task["kb_id"], + source_type, ) bypass_count = 0 @@ -436,10 +420,7 @@ class _BlobLikeBase(SyncBase): if batch: yield batch - log_msg = ( - "[%s] fingerprint sync: %d bypassed, %d fetched, %d failed " - "(connector_id=%s, kb_id=%s)" - ) + log_msg = "[%s] fingerprint sync: %d bypassed, %d fetched, %d failed (connector_id=%s, kb_id=%s)" log_args = ( self.SOURCE_NAME, bypass_count, @@ -475,11 +456,7 @@ class _BlobLikeBase(SyncBase): else: document_batch_generator = self.connector.load_from_state() - _begin_info = ( - "fingerprint-bypass" - if use_fingerprint_path - else "full reindex" - ) + _begin_info = "fingerprint-bypass" if use_fingerprint_path else "full reindex" logging.info( "Connect to {}: {}(prefix/{}) {}".format( @@ -566,12 +543,9 @@ class Confluence(SyncBase): space=space, page_id=page_id, index_recursively=index_recursively, - ) - credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"], - connector_name=DocumentSource.CONFLUENCE, - credential_json=self.conf["credentials"]) + credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"], connector_name=DocumentSource.CONFLUENCE, credential_json=self.conf["credentials"]) self.connector.set_credentials_provider(credentials_provider) # Determine the time range for synchronization based on reindex or poll_range_start @@ -579,7 +553,7 @@ class Confluence(SyncBase): start_time = 0.0 else: start_time = task["poll_range_start"].timestamp() - + end_time = datetime.now(timezone.utc).timestamp() raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE @@ -601,8 +575,7 @@ class Confluence(SyncBase): doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint)) for document, failure, next_checkpoint in doc_generator: if failure is not None: - logging.warning("Confluence connector failure: %s", - getattr(failure, "failure_message", failure)) + logging.warning("Confluence connector failure: %s", getattr(failure, "failure_message", failure)) continue if document is not None: pending_docs.append(document) @@ -636,12 +609,10 @@ class Notion(SyncBase): document_generator = ( self.connector.load_from_state() if task["reindex"] == "1" or not task["poll_range_start"] - else self.connector.poll_source(task["poll_range_start"].timestamp(), - datetime.now(timezone.utc).timestamp()) + else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp()) ) - _begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format( - task["poll_range_start"]) + _begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"]) self.log_connection("Notion", f"root({self.conf['root_page_id']})", task) return document_generator @@ -711,12 +682,10 @@ class Discord(SyncBase): document_generator = ( self.connector.load_from_state() if task["reindex"] == "1" or not task["poll_range_start"] - else self.connector.poll_source(task["poll_range_start"].timestamp(), - datetime.now(timezone.utc).timestamp()) + else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp()) ) - _begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format( - task["poll_range_start"]) + _begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"]) self.log_connection("Discord", f"servers({server_ids}), channel({channel_names})", task) return document_generator @@ -810,6 +779,7 @@ class GoogleDrive(SyncBase): Handles both full re-indexing and incremental polling, including the capability to synchronize deleted files by retrieving a lightweight snapshot of current files. """ + SOURCE_NAME: str = FileSource.GOOGLE_DRIVE async def _generate(self, task: dict): @@ -844,7 +814,7 @@ class GoogleDrive(SyncBase): else: start_time = task["poll_range_start"].timestamp() _begin_info = f"from {task['poll_range_start']}" - + raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE try: batch_size = int(raw_batch_size) @@ -865,8 +835,7 @@ class GoogleDrive(SyncBase): doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint)) for document, failure, next_checkpoint in doc_generator: if failure is not None: - logging.warning("Google Drive connector failure: %s", - getattr(failure, "failure_message", failure)) + logging.warning("Google Drive connector failure: %s", getattr(failure, "failure_message", failure)) continue if document is not None: pending_docs.append(document) @@ -888,7 +857,7 @@ class GoogleDrive(SyncBase): except RuntimeError: admin_email = "unknown" self.log_connection("Google Drive", f"as {admin_email}", task) - + return document_batches() def _persist_rotated_credentials(self, connector_id: str, credentials: dict[str, Any]) -> None: @@ -901,7 +870,8 @@ class GoogleDrive(SyncBase): logging.info("Persisted refreshed Google Drive credentials for connector %s", connector_id) except Exception: logging.exception("Failed to persist refreshed Google Drive credentials for connector %s", connector_id) - + + class Jira(SyncBase): SOURCE_NAME: str = FileSource.JIRA @@ -967,9 +937,7 @@ class Jira(SyncBase): ) for document, failure, next_checkpoint in generator: if failure is not None: - logging.warning( - f"[Jira] Jira connector failure: {getattr(failure, 'failure_message', failure)}" - ) + logging.warning(f"[Jira] Jira connector failure: {getattr(failure, 'failure_message', failure)}") continue if document is not None: pending_docs.append(document) @@ -991,10 +959,7 @@ class Jira(SyncBase): "Jira", connector_kwargs["jira_base_url"], task, - ( - f"sync_batch_size={batch_size}, " - f"overlap_buffer_s={getattr(self.connector, 'time_buffer_seconds', connector_kwargs.get('time_buffer_seconds'))}" - ), + (f"sync_batch_size={batch_size}, overlap_buffer_s={getattr(self.connector, 'time_buffer_seconds', connector_kwargs.get('time_buffer_seconds'))}"), ) return document_batches() @@ -1044,9 +1009,7 @@ class SharePoint(SyncBase): while checkpoint.has_more: wrapper = CheckpointOutputWrapper() - doc_generator = wrapper( - self.connector.load_from_checkpoint(start_time, end_time, checkpoint) - ) + doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint)) for document, failure, next_checkpoint in doc_generator: if failure is not None: logging.warning( @@ -1102,9 +1065,7 @@ class OneDrive(SyncBase): start_ts = task["poll_range_start"].timestamp() end_ts = datetime.now(timezone.utc).timestamp() checkpoint = self.connector.build_dummy_checkpoint() - document_batch_generator = self.connector.load_from_checkpoint( - start_ts, end_ts, checkpoint - ) + document_batch_generator = self.connector.load_from_checkpoint(start_ts, end_ts, checkpoint) self.log_connection( "OneDrive", @@ -1157,9 +1118,7 @@ class Outlook(SyncBase): start_ts = task["poll_range_start"].timestamp() end_ts = datetime.now(timezone.utc).timestamp() checkpoint = self.connector.build_dummy_checkpoint() - document_batch_generator = self.connector.load_from_checkpoint( - start_ts, end_ts, checkpoint - ) + document_batch_generator = self.connector.load_from_checkpoint(start_ts, end_ts, checkpoint) # Redact mailbox identifiers — full UPN / email lists in connector # logs leak PII (the entire org's mail directory ends up in @@ -1228,9 +1187,7 @@ class Salesforce(SyncBase): start_ts = task["poll_range_start"].timestamp() end_ts = datetime.now(timezone.utc).timestamp() checkpoint = self.connector.build_dummy_checkpoint() - document_batch_generator = self.connector.load_from_checkpoint( - start_ts, end_ts, checkpoint - ) + document_batch_generator = self.connector.load_from_checkpoint(start_ts, end_ts, checkpoint) instance_url = (self.conf.get("credentials") or {}).get("instance_url", "") self.log_connection( @@ -1276,15 +1233,9 @@ class AzureBlob(SyncBase): start_ts = task["poll_range_start"].timestamp() end_ts = datetime.now(timezone.utc).timestamp() checkpoint = self.connector.build_dummy_checkpoint() - document_batch_generator = self.connector.load_from_checkpoint( - start_ts, end_ts, checkpoint - ) + document_batch_generator = self.connector.load_from_checkpoint(start_ts, end_ts, checkpoint) - container_hint = ( - credentials.get("container_name") - or credentials.get("container_url", "").rstrip("/").rsplit("/", 1)[-1] - or "" - ) + container_hint = credentials.get("container_name") or credentials.get("container_url", "").rstrip("/").rsplit("/", 1)[-1] or "" self.log_connection( "Azure Blob", f"{container_hint}/{self.conf.get('prefix', '') or ''}", @@ -1393,9 +1344,7 @@ class Teams(SyncBase): while checkpoint.has_more: wrapper = CheckpointOutputWrapper() - doc_generator = wrapper( - self.connector.load_from_checkpoint(start_time, end_time, checkpoint) - ) + doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint)) for document, failure, next_checkpoint in doc_generator: if failure is not None: logging.warning( @@ -1466,10 +1415,7 @@ class Moodle(SyncBase): SOURCE_NAME: str = FileSource.MOODLE async def _generate(self, task: dict): - self.connector = MoodleConnector( - moodle_url=self.conf["moodle_url"], - batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE) - ) + self.connector = MoodleConnector(moodle_url=self.conf["moodle_url"], batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE)) self.connector.load_credentials(self.conf["credentials"]) @@ -1504,18 +1450,18 @@ class BOX(SyncBase): folder_id=self.conf.get("folder_id", "0"), ) - credential = json.loads(self.conf['credentials']['box_tokens']) + credential = json.loads(self.conf["credentials"]["box_tokens"]) auth = BoxOAuth( OAuthConfig( - client_id=credential['client_id'], - client_secret=credential['client_secret'], + client_id=credential["client_id"], + client_secret=credential["client_secret"], ) ) token = AccessToken( - access_token=credential['access_token'], - refresh_token=credential['refresh_token'], + access_token=credential["access_token"], + refresh_token=credential["refresh_token"], ) auth.token_storage.store(token) @@ -1552,9 +1498,7 @@ class Airtable(SyncBase): if "airtable_access_token" not in credentials: raise ValueError("Missing airtable_access_token in credentials") - self.connector.load_credentials( - {"airtable_access_token": credentials["airtable_access_token"]} - ) + self.connector.load_credentials({"airtable_access_token": credentials["airtable_access_token"]}) poll_start = task.get("poll_range_start") @@ -1576,6 +1520,7 @@ class Airtable(SyncBase): return document_generator + class Asana(SyncBase): SOURCE_NAME: str = FileSource.ASANA @@ -1589,9 +1534,7 @@ class Asana(SyncBase): if "asana_api_token_secret" not in credentials: raise ValueError("Missing asana_api_token_secret in credentials") - self.connector.load_credentials( - {"asana_api_token_secret": credentials["asana_api_token_secret"]} - ) + self.connector.load_credentials({"asana_api_token_secret": credentials["asana_api_token_secret"]}) poll_start = task.get("poll_range_start") @@ -1614,6 +1557,7 @@ class Asana(SyncBase): return document_generator + class Github(SyncBase): SOURCE_NAME: str = FileSource.GITHUB @@ -1634,9 +1578,7 @@ class Github(SyncBase): if "github_access_token" not in credentials: raise ValueError("Missing github_access_token in credentials") - self.connector.load_credentials( - {"github_access_token": credentials["github_access_token"]} - ) + self.connector.load_credentials({"github_access_token": credentials["github_access_token"]}) if task.get("reindex") == "1" or not task.get("poll_range_start"): start_time = datetime.fromtimestamp(0, tz=timezone.utc) @@ -1645,12 +1587,7 @@ class Github(SyncBase): end_time = datetime.now(timezone.utc) - runner = ConnectorRunner( - connector=self.connector, - batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE), - include_permissions=False, - time_range=(start_time, end_time) - ) + runner = ConnectorRunner(connector=self.connector, batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE), include_permissions=False, time_range=(start_time, end_time)) def document_batches(): checkpoint = self.connector.build_dummy_checkpoint() @@ -1680,12 +1617,14 @@ class Github(SyncBase): return wrapper() + class IMAP(SyncBase): SOURCE_NAME: str = FileSource.IMAP async def _generate(self, task): from common.data_source.config import DocumentSource from common.data_source.interfaces import StaticCredentialsProvider + self.connector = ImapConnector( host=self.conf.get("imap_host"), port=self.conf.get("imap_port"), @@ -1715,18 +1654,14 @@ class IMAP(SyncBase): try: initial_sync_start = float(initial_sync_start) except (TypeError, ValueError): - initial_sync_start = ( - 0 if task["poll_range_start"] else default_initial_sync_start - ) + initial_sync_start = 0 if task["poll_range_start"] else default_initial_sync_start should_persist_initial_start = True if should_persist_initial_start: updated_conf = copy.deepcopy(self.conf) updated_conf["imap_initial_sync_start"] = initial_sync_start try: - ConnectorService.update_by_id( - task["connector_id"], {"config": updated_conf} - ) + ConnectorService.update_by_id(task["connector_id"], {"config": updated_conf}) self.conf = updated_conf except Exception: logging.exception( @@ -1788,9 +1723,10 @@ class IMAP(SyncBase): def _get_prune_snapshot_kwargs(self, task: dict) -> dict[str, Any]: return getattr(self, "_prune_snapshot_kwargs", {}) -class Zendesk(SyncBase): +class Zendesk(SyncBase): SOURCE_NAME: str = FileSource.ZENDESK + async def _generate(self, task: dict): self.connector = ZendeskConnector(content_type=self.conf.get("zendesk_content_type")) self.connector.load_credentials(self.conf["credentials"]) @@ -1803,11 +1739,7 @@ class Zendesk(SyncBase): start_time = task["poll_range_start"].timestamp() _begin_info = f"from {task['poll_range_start']}" - raw_batch_size = ( - self.conf.get("sync_batch_size") - or self.conf.get("batch_size") - or INDEX_BATCH_SIZE - ) + raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE try: batch_size = int(raw_batch_size) except (TypeError, ValueError): @@ -1824,11 +1756,7 @@ class Zendesk(SyncBase): while checkpoint.has_more: wrapper = CheckpointOutputWrapper() - doc_generator = wrapper( - self.connector.load_from_checkpoint( - start_time, end_time, checkpoint - ) - ) + doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint)) for document, failure, next_checkpoint in doc_generator: if failure is not None: @@ -1849,9 +1777,7 @@ class Zendesk(SyncBase): iterations += 1 if iterations > iteration_limit: - raise RuntimeError( - "Too many iterations while loading Zendesk documents." - ) + raise RuntimeError("Too many iterations while loading Zendesk documents.") if pending_docs: yield pending_docs @@ -1873,11 +1799,11 @@ class Gitlab(SyncBase): """ self.connector = GitlabConnector( - project_owner= self.conf.get("project_owner"), - project_name= self.conf.get("project_name"), - include_mrs = self.conf.get("include_mrs", False), - include_issues = self.conf.get("include_issues", False), - include_code_files= self.conf.get("include_code_files", False), + project_owner=self.conf.get("project_owner"), + project_name=self.conf.get("project_name"), + include_mrs=self.conf.get("include_mrs", False), + include_issues=self.conf.get("include_issues", False), + include_code_files=self.conf.get("include_code_files", False), ) self.connector.load_credentials( @@ -1896,10 +1822,7 @@ class Gitlab(SyncBase): document_generator = self.connector.load_from_state() _begin_info = "totally" else: - document_generator = self.connector.poll_source( - poll_start.timestamp(), - datetime.now(timezone.utc).timestamp() - ) + document_generator = self.connector.poll_source(poll_start.timestamp(), datetime.now(timezone.utc).timestamp()) _begin_info = "from {}".format(poll_start) self.log_connection("Gitlab", f"({self.conf['project_name']})", task) return document_generator @@ -1917,8 +1840,8 @@ class Bitbucket(SyncBase): self.connector.load_credentials( { - "bitbucket_email": self.conf["credentials"].get("bitbucket_account_email"), - "bitbucket_api_token": self.conf["credentials"].get("bitbucket_api_token"), + "bitbucket_email": self.conf["credentials"].get("bitbucket_account_email"), + "bitbucket_api_token": self.conf["credentials"].get("bitbucket_api_token"), } ) @@ -1928,31 +1851,26 @@ class Bitbucket(SyncBase): else: start_time = task.get("poll_range_start") _begin_info = f"from {start_time}" - + end_time = datetime.now(timezone.utc) def document_batches(): checkpoint = self.connector.build_dummy_checkpoint() while checkpoint.has_more: - gen = self.connector.load_from_checkpoint( - start=start_time.timestamp(), - end=end_time.timestamp(), - checkpoint=checkpoint) - + gen = self.connector.load_from_checkpoint(start=start_time.timestamp(), end=end_time.timestamp(), checkpoint=checkpoint) + while True: try: item = next(gen) if isinstance(item, ConnectorFailure): - logging.exception( - "Bitbucket connector failure: %s", - item.failure_message) + logging.exception("Bitbucket connector failure: %s", item.failure_message) break yield [item] except StopIteration as e: checkpoint = e.value break - + def wrapper(): for batch in document_batches(): yield batch @@ -2032,9 +1950,7 @@ class DingTalkAITable(SyncBase): if "access_token" not in credentials: raise ValueError("Missing access_token in credentials") - self.connector.load_credentials( - {"access_token": credentials["access_token"]} - ) + self.connector.load_credentials({"access_token": credentials["access_token"]}) poll_start = task.get("poll_range_start") @@ -2188,11 +2104,7 @@ class BigQuery(_CursorPersistingSyncBase): start_cursor_id, ) - target = ( - f"{self.conf.get('dataset_id')}.{self.conf.get('table_id')}" - if not self.conf.get("query") - else "custom query" - ) + target = f"{self.conf.get('dataset_id')}.{self.conf.get('table_id')}" if not self.conf.get("query") else "custom query" self.log_connection("BigQuery", f"{self.conf.get('project_id')}:{target}", task) return document_generator diff --git a/rag/svr/task_executor_limiter.py b/rag/svr/task_executor_limiter.py index 61b50849b3..6f365bbc23 100644 --- a/rag/svr/task_executor_limiter.py +++ b/rag/svr/task_executor_limiter.py @@ -25,4 +25,4 @@ task_limiter = LoopLocalSemaphore(MAX_CONCURRENT_TASKS) chunk_limiter = LoopLocalSemaphore(MAX_CONCURRENT_CHUNK_BUILDERS) embed_limiter = LoopLocalSemaphore(MAX_CONCURRENT_CHUNK_BUILDERS) minio_limiter = LoopLocalSemaphore(MAX_CONCURRENT_MINIO) -kg_limiter = LoopLocalSemaphore(2) \ No newline at end of file +kg_limiter = LoopLocalSemaphore(2) diff --git a/rag/svr/task_executor_refactor/chunk_builder.py b/rag/svr/task_executor_refactor/chunk_builder.py index b9dc353b4e..0e58382492 100644 --- a/rag/svr/task_executor_refactor/chunk_builder.py +++ b/rag/svr/task_executor_refactor/chunk_builder.py @@ -126,10 +126,7 @@ async def extract_outline(cks: List[Dict], ctx: TaskContext) -> None: ctx.write_interceptor.intercept("DocMetadataService.update_document_metadata") else: temp_doc = DocMetadataService.get_document_metadata(ctx.doc_id) or {} - DocMetadataService.update_document_metadata( - ctx.doc_id, - update_metadata_to({"outline": outline}, temp_doc) - ) + DocMetadataService.update_document_metadata(ctx.doc_id, update_metadata_to({"outline": outline}, temp_doc)) logging.info("Persisted PDF outline (%d entries) for doc %s", len(outline), ctx.doc_id) except Exception as e: diff --git a/rag/svr/task_executor_refactor/chunk_post_processor.py b/rag/svr/task_executor_refactor/chunk_post_processor.py index 1705988ed5..6a809c8635 100644 --- a/rag/svr/task_executor_refactor/chunk_post_processor.py +++ b/rag/svr/task_executor_refactor/chunk_post_processor.py @@ -1114,15 +1114,11 @@ async def run_document_post_chunking_if_last( return False chunking_aborted = is_doc_chunking_aborted(task_doc_id) - remaining_chunking_tasks = ( - 0 if ctx.write_interceptor - else credit_doc_chunking_task(task_doc_id, task_id) - ) + remaining_chunking_tasks = 0 if ctx.write_interceptor else credit_doc_chunking_task(task_doc_id, task_id) if remaining_chunking_tasks != 0: if chunking_aborted: logging.info( - "Chunking for doc %s was aborted before task %s reached post-processing; " - "skip document finalizers.", + "Chunking for doc %s was aborted before task %s reached post-processing; skip document finalizers.", task_doc_id, task_id, ) diff --git a/rag/svr/task_executor_refactor/chunk_service.py b/rag/svr/task_executor_refactor/chunk_service.py index d0fdd836f5..bbac6f707b 100644 --- a/rag/svr/task_executor_refactor/chunk_service.py +++ b/rag/svr/task_executor_refactor/chunk_service.py @@ -109,8 +109,7 @@ class ChunkService: ctx = self._task_context # Validate file size if ctx.size > settings.DOC_MAXIMUM_SIZE: - self._progress(prog=-1, msg="File size exceeds( <= %dMb )" % - (int(settings.DOC_MAXIMUM_SIZE / 1024 / 1024))) + self._progress(prog=-1, msg="File size exceeds( <= %dMb )" % (int(settings.DOC_MAXIMUM_SIZE / 1024 / 1024))) self._task_context.recording_context.record("file_size_exceeded", True) return [] ctx.recording_context.record("file_size_exceeded", False) @@ -123,9 +122,7 @@ class ChunkService: chunk_config = { "parser_id": ctx.parser_id, "chunk_token_num": ctx.parser_config.get("chunk_token_num", 128), - "overlapped_percent": normalize_overlapped_percent( - ctx.parser_config.get("overlapped_percent", 0) - ), + "overlapped_percent": normalize_overlapped_percent(ctx.parser_config.get("overlapped_percent", 0)), "delimiter": ctx.parser_config.get("delimiter", "\n!?。;!?"), "from_page": ctx.from_page, "to_page": ctx.to_page, @@ -160,9 +157,7 @@ class ChunkService: questions = [d for d in docs if d.get("question_kwd")] self._task_context.recording_context.record("questions_generated", questions) - if ctx.parser_config.get("enable_metadata", False) and ( - ctx.parser_config.get("metadata") or ctx.parser_config.get("built_in_metadata") - ): + if ctx.parser_config.get("enable_metadata", False) and (ctx.parser_config.get("metadata") or ctx.parser_config.get("built_in_metadata")): await generate_metadata(docs, ctx) metadata_list = [d for d in docs if d.get("metadata_obj")] self._task_context.recording_context.record("metadata_list_generated", metadata_list) @@ -183,10 +178,7 @@ class ChunkService: """Prepare docs and upload images to MinIO.""" ctx = self._task_context docs = [] - doc = { - "doc_id": ctx.doc_id, - "kb_id": str(ctx.kb_id) - } + doc = {"doc_id": ctx.doc_id, "kb_id": str(ctx.kb_id)} if ctx.pagerank: doc[PAGERANK_FLD] = int(ctx.pagerank) @@ -197,8 +189,7 @@ class ChunkService: try: d = copy.deepcopy(document) d.update(chunk) - d["id"] = xxhash.xxh64( - (chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest() + d["id"] = xxhash.xxh64((chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest() d["create_time"] = str(datetime.now()).replace("T", " ")[:19] d["create_timestamp_flt"] = datetime.now().timestamp() @@ -215,8 +206,7 @@ class ChunkService: await image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=ctx.tenant_id), d["id"], ctx.kb_id) docs.append(d) except Exception: - logging.exception( - "Saving image of chunk {}/{}/{} got exception".format(ctx.location, ctx.name, d["id"])) + logging.exception("Saving image of chunk {}/{}/{} got exception".format(ctx.location, ctx.name, d["id"])) raise tasks = [] @@ -303,11 +293,7 @@ class ChunkService: mom_ck["available_int"] = 0 # Keep only essential fields - allowed_fields = [ - "id", "content_with_weight", "doc_id", "docnm_kwd", - "kb_id", "available_int", "position_int", - "create_timestamp_flt", "page_num_int", "top_int" - ] + allowed_fields = ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int", "position_int", "create_timestamp_flt", "page_num_int", "top_int"] for fld in list(mom_ck.keys()): if fld not in allowed_fields: del mom_ck[fld] @@ -326,11 +312,7 @@ class ChunkService: ) -> bool: """Insert mother chunks in batches.""" for b in range(0, len(mothers), doc_bulk_size): - await self._intercept_doc_store_insert( - mothers[b:b + doc_bulk_size], - search.index_name(task_tenant_id), - task_dataset_id - ) + await self._intercept_doc_store_insert(mothers[b : b + doc_bulk_size], search.index_name(task_tenant_id), task_dataset_id) if self._task_context.has_canceled_func(task_id): self._task_context.progress_cb(-1, msg="Task has been canceled.") @@ -346,7 +328,7 @@ class ChunkService: async def _intercept_doc_store_insert(self, chunks: list, index_name: str, task_dataset_id: str) -> Any: if self._task_context.write_interceptor: - if self._task_context.doc_id == GRAPH_RAPTOR_FAKE_DOC_ID: # raptor - non-determinisic + if self._task_context.doc_id == GRAPH_RAPTOR_FAKE_DOC_ID: # raptor - non-determinisic return self._task_context.write_interceptor.intercept("docStoreConn.insert", []) return self._task_context.write_interceptor.intercept("docStoreConn.insert") else: @@ -362,40 +344,28 @@ class ChunkService: ) -> bool: """Insert main chunks in batches with cancellation handling.""" for b in range(0, len(chunks), doc_bulk_size): - doc_store_result = await self._intercept_doc_store_insert( - chunks[b:b + doc_bulk_size], - search.index_name(task_tenant_id), - task_dataset_id - ) + doc_store_result = await self._intercept_doc_store_insert(chunks[b : b + doc_bulk_size], search.index_name(task_tenant_id), task_dataset_id) if self._task_context.has_canceled_func(task_id): # Roll back partial RAPTOR summary inserts - await self._rollback_raptor_chunks( - task_id, task_tenant_id, task_dataset_id, chunks, b, doc_bulk_size - ) + await self._rollback_raptor_chunks(task_id, task_tenant_id, task_dataset_id, chunks, b, doc_bulk_size) self._task_context.progress_cb(-1, msg="Task has been canceled.") return False if b % 128 == 0: - self._task_context.progress_cb(prog=0.8 + 0.1 * (b + 1) / len(chunks),msg="") + self._task_context.progress_cb(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="") if doc_store_result: - error_message = ( - f"Insert chunk error: {doc_store_result}, " - "please check log file and Elasticsearch/Infinity status!" - ) + error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" self._task_context.progress_cb(-1, msg=error_message) raise Exception(error_message) # Update chunk IDs in task - chunk_ids = [chunk["id"] for chunk in chunks[:b + doc_bulk_size]] + chunk_ids = [chunk["id"] for chunk in chunks[: b + doc_bulk_size]] if not await self._update_task_chunk_ids(task_id, chunk_ids): # Roll back on failure await self._rollback_insertion(task_tenant_id, task_dataset_id, chunk_ids) - self._task_context.progress_cb( - -1, - msg=f"Chunk updates failed since task {task_id} is unknown." - ) + self._task_context.progress_cb(-1, msg=f"Chunk updates failed since task {task_id} is unknown.") return False return True @@ -410,19 +380,15 @@ class ChunkService: doc_bulk_size: int, ): """Roll back partial RAPTOR summary inserts after cancellation.""" - raptor_ids = [ - c["id"] for c in chunks[:up_to_batch + doc_bulk_size] - if c.get("raptor_kwd") == "raptor" - ] + raptor_ids = [c["id"] for c in chunks[: up_to_batch + doc_bulk_size] if c.get("raptor_kwd") == "raptor"] if raptor_ids: try: - await self._intercept_doc_store_delete( - {"id": raptor_ids}, search.index_name(task_tenant_id), task_dataset_id - ) + await self._intercept_doc_store_delete({"id": raptor_ids}, search.index_name(task_tenant_id), task_dataset_id) logging.info( "insert_chunks: rolled back %d partial RAPTOR chunks after cancellation (task=%s)", - len(raptor_ids), task_id, + len(raptor_ids), + task_id, ) except Exception: logging.exception( @@ -454,9 +420,7 @@ class ChunkService: chunk_ids: List[str], ): """Roll back an insertion by deleting chunks and images.""" - await self._intercept_doc_store_delete( - {"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id - ) + await self._intercept_doc_store_delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id) # Delete associated images tasks = [] diff --git a/rag/svr/task_executor_refactor/comparator.py b/rag/svr/task_executor_refactor/comparator.py index 0ace8d9aa0..397eaa5430 100644 --- a/rag/svr/task_executor_refactor/comparator.py +++ b/rag/svr/task_executor_refactor/comparator.py @@ -102,19 +102,17 @@ class ContextComparator: A new dictionary with non-deterministic fields removed. """ import copy + result = copy.copy(data) for key, value in result.items(): if isinstance(value, dict): # Create a new dict without the non-deterministic keys - cleaned = { - k: v for k, v in value.items() - if k not in self.DICT_KEYS_TO_STRIP - } + cleaned = {k: v for k, v in value.items() if k not in self.DICT_KEYS_TO_STRIP} result[key] = cleaned return result @staticmethod - def _get_key_values_to_compare(prod_data_all:dict): + def _get_key_values_to_compare(prod_data_all: dict): prod_data = dict() for key, value in prod_data_all.items(): if key in ALLOWED_METHOD_NAMES: diff --git a/rag/svr/task_executor_refactor/dataflow_service.py b/rag/svr/task_executor_refactor/dataflow_service.py index 9cf345a300..79b396d7a8 100644 --- a/rag/svr/task_executor_refactor/dataflow_service.py +++ b/rag/svr/task_executor_refactor/dataflow_service.py @@ -109,10 +109,7 @@ class DataflowService: dataflow_id = corrected_id # Run pipeline - pipeline = Pipeline( - dsl, tenant_id=ctx.tenant_id, doc_id=doc_id, - task_id=task_id, flow_id=dataflow_id - ) + pipeline = Pipeline(dsl, tenant_id=ctx.tenant_id, doc_id=doc_id, task_id=task_id, flow_id=dataflow_id) chunks = await pipeline.run(file=ctx.file) if ctx.file else await pipeline.run() if doc_id == CANVAS_DEBUG_DOC_ID: @@ -140,9 +137,7 @@ class DataflowService: # Embed chunks if needed keys = [k for o in chunks for k in list(o.keys())] if not any([re.match(r"q_[0-9]+_vec", k) for k in keys]): - chunks, embedding_token_consumption = await self._embed_chunks( - chunks, embedding_token_consumption - ) + chunks, embedding_token_consumption = await self._embed_chunks(chunks, embedding_token_consumption) if chunks is None: self._record_pipeline_log(doc_id, dataflow_id, pipeline) return @@ -157,33 +152,22 @@ class DataflowService: # Insert chunks start_ts = timer() self._progress(prog=0.82, msg="[DOC Engine]:\nStart to index...") - e = await self._insert_chunks( - task_id, ctx.tenant_id, ctx.kb_id, chunks - ) + e = await self._insert_chunks(task_id, ctx.tenant_id, ctx.kb_id, chunks) if not e: self._record_pipeline_log(doc_id, dataflow_id, pipeline) return time_cost = timer() - start_ts task_time_cost = timer() - task_start_ts - self._progress( - prog=1., - msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost) - ) + self._progress(prog=1.0, msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost)) # Update document stats if ctx.write_interceptor: ctx.write_interceptor.intercept("DocumentService.increment_chunk_num") else: - DocumentService.increment_chunk_num( - doc_id, task_dataset_id, embedding_token_consumption, len(chunks), task_time_cost - ) + DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks), task_time_cost) - logging.info( - "[Done], chunks({}), token({}), elapsed:{:.2f}".format( - len(chunks), embedding_token_consumption, task_time_cost - ) - ) + logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost)) ctx.recording_context.record("dataflow_chunks", chunks) self._record_pipeline_log(doc_id, dataflow_id, pipeline) @@ -244,21 +228,17 @@ class DataflowService: return [] @timeout(60) - async def _embed_chunks( - self, chunks: List[Dict], token_consumption: int - ) -> Tuple[Optional[List[Dict]], int]: + async def _embed_chunks(self, chunks: List[Dict], token_consumption: int) -> Tuple[Optional[List[Dict]], int]: """Embed chunks using the embedding model.""" ctx = self._task_context try: self._progress(prog=0.82, msg="\n-------------------------------------\nStart to embedding...") e, kb = self._get_kb_by_id(ctx.kb_id) embedding_id = kb.embd_id - embd_model_config = get_model_config_from_provider_instance( - ctx.tenant_id, LLMType.EMBEDDING, embedding_id - ) + embd_model_config = get_model_config_from_provider_instance(ctx.tenant_id, LLMType.EMBEDDING, embedding_id) from api.db.services.llm_service import LLMBundle - with LLMBundle(ctx.tenant_id, embd_model_config) as embedding_model: + with LLMBundle(ctx.tenant_id, embd_model_config) as embedding_model: # Prepare texts for embedding using EmbeddingUtils texts = EmbeddingUtils.prepare_texts_for_dataflow_embedding(chunks) delta = 0.20 / (len(texts) // self._embedding_batch_size + 1) @@ -267,19 +247,14 @@ class DataflowService: # Batch encode using EmbeddingUtils vects_batches = [] for i in range(0, len(texts), self._embedding_batch_size): - batch = texts[i: i + self._embedding_batch_size] + batch = texts[i : i + self._embedding_batch_size] async with ctx.embed_limiter: - vts, c = await thread_pool_exec( - self._encode_batch, batch, embedding_model - ) + vts, c = await thread_pool_exec(self._encode_batch, batch, embedding_model) vects_batches.append(vts) token_consumption += c prog += delta if i % (len(texts) // self._embedding_batch_size / 100 + 1) == 1: - self._progress( - prog=prog, - msg=f"{i + 1} / {len(texts) // self._embedding_batch_size}" - ) + self._progress(prog=prog, msg=f"{i + 1} / {len(texts) // self._embedding_batch_size}") # Stack vectors using EmbeddingUtils vects = EmbeddingUtils.stack_vectors(vects_batches) @@ -358,11 +333,10 @@ class DataflowService: else: DocMetadataService.update_document_metadata(doc_id, metadata) - async def _insert_chunks( - self, task_id: str, tenant_id: str, kb_id: str, chunks: List[Dict] - ) -> bool: + async def _insert_chunks(self, task_id: str, tenant_id: str, kb_id: str, chunks: List[Dict]) -> bool: """Insert chunks into document store.""" from rag.svr.task_executor_refactor.chunk_service import ChunkService + chunk_service = ChunkService(self._task_context) return await chunk_service.insert_chunks(task_id, tenant_id, kb_id, chunks) @@ -371,15 +345,13 @@ class DataflowService: if self._task_context.write_interceptor: self._task_context.write_interceptor.intercept("PipelineOperationLogService.create") else: - PipelineOperationLogService.create( - document_id=doc_id, pipeline_id=dataflow_id, - task_type=PipelineTaskType.PARSE, dsl=str(pipeline) - ) + PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline)) @classmethod def _get_kb_by_id(cls, kb_id: str): """Get knowledge base by ID.""" from api.db.services.knowledgebase_service import KnowledgebaseService + return KnowledgebaseService.get_by_id(kb_id) def _progress(self, prog=None, msg=None): diff --git a/rag/svr/task_executor_refactor/embedding_service.py b/rag/svr/task_executor_refactor/embedding_service.py index a22c733fc3..84d5ba6387 100644 --- a/rag/svr/task_executor_refactor/embedding_service.py +++ b/rag/svr/task_executor_refactor/embedding_service.py @@ -90,7 +90,7 @@ class EmbeddingService: # Batch encode contents using EmbeddingUtils vects_batches = [] for i in range(0, len(contents), self._embedding_batch_size): - batch = contents[i: i + self._embedding_batch_size] + batch = contents[i : i + self._embedding_batch_size] async with self._task_context.embed_limiter: vts, c = await thread_pool_exec( self._batch_encode_wrapper, diff --git a/rag/svr/task_executor_refactor/embedding_utils.py b/rag/svr/task_executor_refactor/embedding_utils.py index 44eec7e125..70a2c4cf26 100644 --- a/rag/svr/task_executor_refactor/embedding_utils.py +++ b/rag/svr/task_executor_refactor/embedding_utils.py @@ -180,11 +180,7 @@ class EmbeddingUtils: if not title_weight: title_weight = cls.DEFAULT_TITLE_WEIGHT - if ( - title_vecs is not None - and content_vecs.ndim == 2 - and title_vecs.shape == content_vecs.shape - ): + if title_vecs is not None and content_vecs.ndim == 2 and title_vecs.shape == content_vecs.shape: return title_weight * title_vecs + (1 - title_weight) * content_vecs return content_vecs diff --git a/rag/svr/task_executor_refactor/recording_context.py b/rag/svr/task_executor_refactor/recording_context.py index bf64a68644..b8631f844a 100644 --- a/rag/svr/task_executor_refactor/recording_context.py +++ b/rag/svr/task_executor_refactor/recording_context.py @@ -416,4 +416,4 @@ def timed_with_recording( return wrapper - return decorator \ No newline at end of file + return decorator diff --git a/rag/svr/task_executor_refactor/report_generator.py b/rag/svr/task_executor_refactor/report_generator.py index 725fb6ff4b..dcbb414c02 100644 --- a/rag/svr/task_executor_refactor/report_generator.py +++ b/rag/svr/task_executor_refactor/report_generator.py @@ -83,10 +83,7 @@ class ComparisonReport: if self.total_keys == 0: return f"Task {self.task_id}: No keys to compare" match_rate = (self.matched_keys / self.total_keys) * 100 - return ( - f"Task {self.task_id}: {self.matched_keys}/{self.total_keys} " - f"keys matched ({match_rate:.1f}%)" - ) + return f"Task {self.task_id}: {self.matched_keys}/{self.total_keys} keys matched ({match_rate:.1f}%)" def to_dict(self) -> dict: """Convert to dictionary for serialization. diff --git a/rag/svr/task_executor_refactor/task_context.py b/rag/svr/task_executor_refactor/task_context.py index 8d9f41c4db..dd52de26b3 100644 --- a/rag/svr/task_executor_refactor/task_context.py +++ b/rag/svr/task_executor_refactor/task_context.py @@ -143,6 +143,7 @@ class TaskDict(TypedDict, total=False): message_dict: Dict[str, Any] """Message dictionary for memory tasks.""" + # ============================================================================ # Data Classes # ============================================================================ @@ -270,7 +271,6 @@ class TaskContext: self._write_interceptor = write_interceptor self._recording_context = recording_context - # Prepare progress callback and set it on the context progress_cb = partial( callbacks.progress, diff --git a/rag/svr/task_executor_refactor/task_manager.py b/rag/svr/task_executor_refactor/task_manager.py index 041a1b8924..28b1475c3f 100644 --- a/rag/svr/task_executor_refactor/task_manager.py +++ b/rag/svr/task_executor_refactor/task_manager.py @@ -30,7 +30,8 @@ from rag.svr.task_executor_refactor.recording_context import ( BaseRecordingContext, RecordingContext, _NULL_RECORDING_CONTEXT, - set_recording_context, recording_context_manager, + set_recording_context, + recording_context_manager, ) from rag.svr.task_executor_refactor.task_context import TaskContext from rag.svr.task_executor_refactor.task_handler import TaskHandler @@ -166,12 +167,14 @@ class TaskManager: comp_result = comp.compare(task_context.id, recording_ctx1, recording_ctx2) logging.info(f"-------{task_context.name}, compare result:{comp_result.to_markdown()}") if interceptor.remaining_values_count() > 0 or comp_result.mismatched_keys > 0: - logging.info(f"------task:{task_context.id} {task_context.name} differs, " - f"interceptor.remaining_values_count():{interceptor.remaining_values_count()}, " - f"mismatched_keys:{comp_result.mismatched_keys}") + logging.info( + f"------task:{task_context.id} {task_context.name} differs, " + f"interceptor.remaining_values_count():{interceptor.remaining_values_count()}, " + f"mismatched_keys:{comp_result.mismatched_keys}" + ) if interceptor.remaining_values_count() > 0: logging.info(f"------task:{task_context.id}, remaining values:{interceptor.remaining_values()}") if comp_result.mismatched_keys > 0: logging.info(f"-------compare result:{comp_result.details}") else: - logging.info(f"------task:{task_context.id} {task_context.name} same result for prod and dry run ") \ No newline at end of file + logging.info(f"------task:{task_context.id} {task_context.name} same result for prod and dry run ") diff --git a/rag/svr/task_executor_refactor/write_operation_interceptor.py b/rag/svr/task_executor_refactor/write_operation_interceptor.py index fe57f8d9c8..0acb38e502 100644 --- a/rag/svr/task_executor_refactor/write_operation_interceptor.py +++ b/rag/svr/task_executor_refactor/write_operation_interceptor.py @@ -20,6 +20,7 @@ Provides a mechanism to intercept write operations during comparison mode. The interceptor consumes pre-recorded return values (from production execution) and returns them one by one when the corresponding methods are called. """ + import logging from typing import Any, Dict, List @@ -34,7 +35,7 @@ ALLOWED_METHOD_NAMES = { "delete_raptor_chunks", "handle_save_to_memory_task", "docStoreConn.insert", - "docStoreConn.delete" + "docStoreConn.delete", } _NO_DEFAULT = object() @@ -80,7 +81,7 @@ class WriteOperationInterceptor: for key in ALLOWED_METHOD_NAMES: self._recorded_values[key] = list(recorded_values.get(key, [])) - def intercept(self, method_name: str, default_value = _NO_DEFAULT) -> Any: + def intercept(self, method_name: str, default_value=_NO_DEFAULT) -> Any: """Intercept a method call and return the next pre-recorded value. Args: @@ -96,10 +97,7 @@ class WriteOperationInterceptor: IndexError: If the recorded values list for method_name is empty. """ if method_name not in ALLOWED_METHOD_NAMES: - raise ValueError( - f"Cannot intercept method '{method_name}'. " - f"Allowed method names: {ALLOWED_METHOD_NAMES}" - ) + raise ValueError(f"Cannot intercept method '{method_name}'. Allowed method names: {ALLOWED_METHOD_NAMES}") if method_name not in self._recorded_values: raise KeyError(f"No recorded values found for method '{method_name}'") @@ -113,7 +111,6 @@ class WriteOperationInterceptor: return values_list.pop(0) - def remaining_count(self, method_name: str) -> int: """Get the number of remaining recorded values for a method. @@ -127,7 +124,6 @@ class WriteOperationInterceptor: return 0 return len(self._recorded_values[method_name]) - def remaining_values(self): return {k: list(v) for k, v in self._recorded_values.items()} diff --git a/rag/utils/azure_sas_conn.py b/rag/utils/azure_sas_conn.py index 9d43bcbf54..f6af45ee37 100644 --- a/rag/utils/azure_sas_conn.py +++ b/rag/utils/azure_sas_conn.py @@ -27,8 +27,8 @@ from common import settings class RAGFlowAzureSasBlob: def __init__(self): self.conn = None - self.container_url = os.getenv('CONTAINER_URL', settings.AZURE["container_url"]) - self.sas_token = os.getenv('SAS_TOKEN', settings.AZURE["sas_token"]) + self.container_url = os.getenv("CONTAINER_URL", settings.AZURE["container_url"]) + self.sas_token = os.getenv("SAS_TOKEN", settings.AZURE["sas_token"]) self.__open__() def __open__(self): diff --git a/rag/utils/azure_spn_conn.py b/rag/utils/azure_spn_conn.py index ac23ecb172..ebc8585832 100644 --- a/rag/utils/azure_spn_conn.py +++ b/rag/utils/azure_spn_conn.py @@ -34,12 +34,12 @@ _CLOUD_AUTHORITY_MAP = { class RAGFlowAzureSpnBlob: def __init__(self): self.conn = None - self.account_url = os.getenv('ACCOUNT_URL', settings.AZURE["account_url"]) - self.client_id = os.getenv('CLIENT_ID', settings.AZURE["client_id"]) - self.secret = os.getenv('SECRET', settings.AZURE["secret"]) - self.tenant_id = os.getenv('TENANT_ID', settings.AZURE["tenant_id"]) - self.container_name = os.getenv('CONTAINER_NAME', settings.AZURE["container_name"]) - self.cloud = os.getenv('AZURE_CLOUD', settings.AZURE.get("cloud", "public")).lower() + self.account_url = os.getenv("ACCOUNT_URL", settings.AZURE["account_url"]) + self.client_id = os.getenv("CLIENT_ID", settings.AZURE["client_id"]) + self.secret = os.getenv("SECRET", settings.AZURE["secret"]) + self.tenant_id = os.getenv("TENANT_ID", settings.AZURE["tenant_id"]) + self.container_name = os.getenv("CONTAINER_NAME", settings.AZURE["container_name"]) + self.cloud = os.getenv("AZURE_CLOUD", settings.AZURE.get("cloud", "public")).lower() self.__open__() def __open__(self): @@ -51,10 +51,8 @@ class RAGFlowAzureSpnBlob: try: authority = _CLOUD_AUTHORITY_MAP.get(self.cloud, AzureAuthorityHosts.AZURE_PUBLIC_CLOUD) - credentials = ClientSecretCredential(tenant_id=self.tenant_id, client_id=self.client_id, - client_secret=self.secret, authority=authority) - self.conn = FileSystemClient(account_url=self.account_url, file_system_name=self.container_name, - credential=credentials) + credentials = ClientSecretCredential(tenant_id=self.tenant_id, client_id=self.client_id, client_secret=self.secret, authority=authority) + self.conn = FileSystemClient(account_url=self.account_url, file_system_name=self.container_name, credential=credentials) except Exception: logging.exception("Fail to connect %s" % self.account_url) diff --git a/rag/utils/base64_image.py b/rag/utils/base64_image.py index cfdb2a19d3..494fe89cf7 100644 --- a/rag/utils/base64_image.py +++ b/rag/utils/base64_image.py @@ -22,7 +22,6 @@ from io import BytesIO from PIL import Image - from common.misc_utils import thread_pool_exec from rag.utils.lazy_image import open_image_for_processing @@ -82,9 +81,7 @@ async def image2id(d: dict, storage_put_func: partial, objname: str, bucket: str return async with minio_limiter: - await thread_pool_exec( - lambda: storage_put_func(bucket=bucket, fnm=objname, binary=jpeg_binary) - ) + await thread_pool_exec(lambda: storage_put_func(bucket=bucket, fnm=objname, binary=jpeg_binary)) d["img_id"] = f"{bucket}-{objname}" diff --git a/rag/utils/encrypted_storage.py b/rag/utils/encrypted_storage.py index e5ac9cf975..45631aaf87 100644 --- a/rag/utils/encrypted_storage.py +++ b/rag/utils/encrypted_storage.py @@ -20,13 +20,14 @@ from common.crypto_utils import CryptoUtil # from common.decorator import singleton + class EncryptedStorageWrapper: """Encrypted storage wrapper that wraps existing storage implementations to provide transparent encryption""" def __init__(self, storage_impl, algorithm="aes-256-cbc", key=None, iv=None): """ Initialize encrypted storage wrapper - + Args: storage_impl: Original storage implementation instance algorithm: Encryption algorithm, default is aes-256-cbc @@ -49,13 +50,13 @@ class EncryptedStorageWrapper: def put(self, bucket, fnm, binary, tenant_id=None): """ Encrypt and store data - + Args: bucket: Bucket name fnm: File name binary: Original binary data tenant_id: Tenant ID (optional) - + Returns: Storage result """ @@ -73,12 +74,12 @@ class EncryptedStorageWrapper: def get(self, bucket, fnm, tenant_id=None): """ Retrieve and decrypt data - + Args: bucket: Bucket name fnm: File name tenant_id: Tenant ID (optional) - + Returns: Decrypted binary data """ @@ -103,12 +104,12 @@ class EncryptedStorageWrapper: def rm(self, bucket, fnm, tenant_id=None): """ Delete data (same as original storage implementation, no decryption needed) - + Args: bucket: Bucket name fnm: File name tenant_id: Tenant ID (optional) - + Returns: Deletion result """ @@ -117,12 +118,12 @@ class EncryptedStorageWrapper: def obj_exist(self, bucket, fnm, tenant_id=None): """ Check if object exists (same as original storage implementation, no decryption needed) - + Args: bucket: Bucket name fnm: File name tenant_id: Tenant ID (optional) - + Returns: Whether the object exists """ @@ -131,7 +132,7 @@ class EncryptedStorageWrapper: def health(self): """ Health check (uses the original storage implementation's method) - + Returns: Health check result """ @@ -140,10 +141,10 @@ class EncryptedStorageWrapper: def bucket_exists(self, bucket): """ Check if bucket exists (if the original storage implementation has this method) - + Args: bucket: Bucket name - + Returns: Whether the bucket exists """ @@ -154,13 +155,13 @@ class EncryptedStorageWrapper: def get_presigned_url(self, bucket, fnm, expires, tenant_id=None): """ Get presigned URL (if the original storage implementation has this method) - + Args: bucket: Bucket name fnm: File name expires: Expiration time tenant_id: Tenant ID (optional) - + Returns: Presigned URL """ @@ -171,12 +172,12 @@ class EncryptedStorageWrapper: def scan(self, bucket, fnm, tenant_id=None): """ Scan objects (if the original storage implementation has this method) - + Args: bucket: Bucket name fnm: File name prefix tenant_id: Tenant ID (optional) - + Returns: Scan results """ @@ -187,13 +188,13 @@ class EncryptedStorageWrapper: def copy(self, src_bucket, src_path, dest_bucket, dest_path): """ Copy object (if the original storage implementation has this method) - + Args: src_bucket: Source bucket name src_path: Source file path dest_bucket: Destination bucket name dest_path: Destination file path - + Returns: Copy result """ @@ -204,13 +205,13 @@ class EncryptedStorageWrapper: def move(self, src_bucket, src_path, dest_bucket, dest_path): """ Move object (if the original storage implementation has this method) - + Args: src_bucket: Source bucket name src_path: Source file path dest_bucket: Destination bucket name dest_path: Destination file path - + Returns: Move result """ @@ -221,10 +222,10 @@ class EncryptedStorageWrapper: def remove_bucket(self, bucket): """ Remove bucket (if the original storage implementation has this method) - + Args: bucket: Bucket name - + Returns: Remove result """ @@ -247,13 +248,13 @@ class EncryptedStorageWrapper: def create_encrypted_storage(storage_impl, algorithm=None, key=None, encryption_enabled=True): """ Create singleton instance of encrypted storage wrapper - + Args: storage_impl: Original storage implementation instance algorithm: Encryption algorithm, uses environment variable RAGFLOW_CRYPTO_ALGORITHM or default if None key: Encryption key, uses environment variable RAGFLOW_CRYPTO_KEY if None encryption_enabled: Whether to enable encryption functionality - + Returns: Encrypted storage wrapper instance """ diff --git a/rag/utils/file_utils.py b/rag/utils/file_utils.py index c9ec50a36a..8a1006fbf4 100644 --- a/rag/utils/file_utils.py +++ b/rag/utils/file_utils.py @@ -35,7 +35,7 @@ def _is_pdf(h: bytes) -> bool: def _is_ole(h: bytes) -> bool: - return h.startswith(b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1") + return h.startswith(b"\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1") def _sha10(b: bytes) -> str: @@ -70,7 +70,7 @@ def _extract_ole10native_payload(data: bytes) -> bytes: pos = 0 if len(data) < 4: return data - _ = int.from_bytes(data[pos:pos + 4], "little") + _ = int.from_bytes(data[pos : pos + 4], "little") pos += 4 # filename/src/tmp (NUL-terminated ANSI) for _ in range(3): @@ -80,10 +80,10 @@ def _extract_ole10native_payload(data: bytes) -> bytes: pos += 4 if pos + 4 > len(data): return data - size = int.from_bytes(data[pos:pos + 4], "little") + size = int.from_bytes(data[pos : pos + 4], "little") pos += 4 if pos + size <= len(data): - return data[pos:pos + size] + return data[pos : pos + size] except Exception: pass return data @@ -115,10 +115,7 @@ def extract_embed_file(target: Union[bytes, bytearray]) -> List[Tuple[str, bytes if _is_zip(head): try: with zipfile.ZipFile(io.BytesIO(top), "r") as z: - embed_dirs = ( - "word/embeddings/", "word/objects/", "word/activex/", - "xl/embeddings/", "ppt/embeddings/" - ) + embed_dirs = ("word/embeddings/", "word/objects/", "word/activex/", "xl/embeddings/", "ppt/embeddings/") for name in z.namelist(): low = name.lower() if any(low.startswith(d) for d in embed_dirs): @@ -169,9 +166,7 @@ def extract_links_from_docx(docx_bytes: bytes): # Each relationship may represent a hyperlink, image, footer, etc. for rel in document.part.rels.values(): - if rel.reltype == ( - "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink" - ): + if rel.reltype == ("http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink"): links.add(rel.target_ref) return links @@ -212,23 +207,17 @@ def _get_session(headers: Optional[Dict[str, str]] = None) -> requests.Session: global _GLOBAL_SESSION if _GLOBAL_SESSION is None: _GLOBAL_SESSION = requests.Session() - _GLOBAL_SESSION.headers.update({ - "User-Agent": ( - "Mozilla/5.0 (X11; Linux x86_64) " - "AppleWebKit/537.36 (KHTML, like Gecko) " - "Chrome/121.0 Safari/537.36" - ) - }) + _GLOBAL_SESSION.headers.update({"User-Agent": ("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0 Safari/537.36")}) if headers: _GLOBAL_SESSION.headers.update(headers) return _GLOBAL_SESSION def extract_html( - url: str, - timeout: float = 60.0, - headers: Optional[Dict[str, str]] = None, - max_retries: int = 2, + url: str, + timeout: float = 60.0, + headers: Optional[Dict[str, str]] = None, + max_retries: int = 2, ) -> Tuple[Optional[bytes], Dict[str, str]]: """ Extract the full HTML page as raw bytes from a given URL. @@ -254,11 +243,13 @@ def extract_html( resp.raise_for_status() html_bytes = resp.content - metadata.update({ - "final_url": resp.url, - "status_code": str(resp.status_code), - "content_type": resp.headers.get("Content-Type", ""), - }) + metadata.update( + { + "final_url": resp.url, + "status_code": str(resp.status_code), + "content_type": resp.headers.get("Content-Type", ""), + } + ) return html_bytes, metadata except Timeout: diff --git a/rag/utils/gcs_conn.py b/rag/utils/gcs_conn.py index 53aa3d4e5a..a46db79ec0 100644 --- a/rag/utils/gcs_conn.py +++ b/rag/utils/gcs_conn.py @@ -59,7 +59,7 @@ class RAGFlowGCS: blob_path = self._get_blob_path(folder, fnm) blob = bucket_obj.blob(blob_path) - blob.upload_from_file(BytesIO(binary), content_type='application/octet-stream') + blob.upload_from_file(BytesIO(binary), content_type="application/octet-stream") return True except Exception as e: logging.exception(f"Health check failed: {e}") @@ -73,7 +73,7 @@ class RAGFlowGCS: blob_path = self._get_blob_path(bucket, fnm) blob = bucket_obj.blob(blob_path) - blob.upload_from_file(BytesIO(binary), content_type='application/octet-stream') + blob.upload_from_file(BytesIO(binary), content_type="application/octet-stream") return True except NotFound: logging.error(f"Fail to put: Main bucket {self.bucket_name} does not exist.") @@ -145,11 +145,7 @@ class RAGFlowGCS: if isinstance(expires, int): expiration = datetime.timedelta(seconds=expires) - url = blob.generate_signed_url( - version="v4", - expiration=expiration, - method="GET" - ) + url = blob.generate_signed_url(version="v4", expiration=expiration, method="GET") return url except Exception: logging.exception(f"Fail to get_presigned {bucket}/{fnm}:") diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index 9f3274846d..7cb571fa86 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -35,9 +35,7 @@ class InfinityConnection(InfinityConnectionBase): @staticmethod def field_keyword(field_name: str): # Treat "*_kwd" tag-like columns as keyword lists except knowledge_graph_kwd; source_id is also keyword-like. - if field_name == "source_id" or ( - field_name.endswith("_kwd") and field_name not in ["knowledge_graph_kwd", "docnm_kwd", "important_kwd", - "question_kwd"]): + if field_name == "source_id" or (field_name.endswith("_kwd") and field_name not in ["knowledge_graph_kwd", "docnm_kwd", "important_kwd", "question_kwd"]): return True return False @@ -92,18 +90,18 @@ class InfinityConnection(InfinityConnectionBase): """ def search( - self, - select_fields: list[str], - highlight_fields: list[str], - condition: dict, - match_expressions: list[MatchExpr], - order_by: OrderByExpr, - offset: int, - limit: int, - index_names: str | list[str], - knowledgebase_ids: list[str], - agg_fields: list[str] | None = None, - rank_feature: dict | None = None, + self, + select_fields: list[str], + highlight_fields: list[str], + condition: dict, + match_expressions: list[MatchExpr], + order_by: OrderByExpr, + offset: int, + limit: int, + index_names: str | list[str], + knowledgebase_ids: list[str], + agg_fields: list[str] | None = None, + rank_feature: dict | None = None, ) -> tuple[pd.DataFrame, int]: """ BUG: Infinity returns empty for a highlight field if the query string doesn't use that field. @@ -172,8 +170,7 @@ class InfinityConnection(InfinityConnectionBase): if table_found: break if not table_found: - self.logger.error( - f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}") + self.logger.error(f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}") return pd.DataFrame(), 0 for matchExpr in match_expressions: @@ -306,8 +303,7 @@ class InfinityConnection(InfinityConnectionBase): try: table_instance = db_instance.get_table(table_name) except Exception: - self.logger.warning( - f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.") + self.logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.") continue kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunk_id}'").to_df() self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}") @@ -316,9 +312,20 @@ class InfinityConnection(InfinityConnectionBase): self.connPool.release_conn(inf_conn) res = self.concat_dataframes(df_list, ["id"]) fields = set(res.columns.tolist()) - for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd", - "question_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", - "authors_sm_tks"]: + for field in [ + "docnm_kwd", + "title_tks", + "title_sm_tks", + "important_kwd", + "important_tks", + "question_kwd", + "question_tks", + "content_with_weight", + "content_ltks", + "content_sm_ltks", + "authors_tks", + "authors_sm_tks", + ]: fields.add(field) res_fields = self.get_fields(res, list(fields)) chunk = res_fields.get(chunk_id, None) @@ -327,7 +334,7 @@ class InfinityConnection(InfinityConnectionBase): return chunk def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]: - ''' + """ # Save input to file to test inserting from file in GO import datetime import os @@ -339,7 +346,7 @@ class InfinityConnection(InfinityConnectionBase): "chunks": documents }, f, indent=2) self.logger.debug(f"Saved insert input to {debug_file}") - ''' + """ inf_conn = self.connPool.get_conn() try: @@ -369,6 +376,7 @@ class InfinityConnection(InfinityConnectionBase): parser_id = None if "chunk_data" in documents[0] and isinstance(documents[0].get("chunk_data"), dict): from common.constants import ParserType + parser_id = ParserType.TABLE.value self.logger.debug("Detected TABLE parser from document structure") @@ -468,10 +476,20 @@ class InfinityConnection(InfinityConnectionBase): d[k] = v if v else "{}" else: d[k] = v - for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", - "content_with_weight", - "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", - "question_tks"]: + for k in [ + "docnm_kwd", + "title_tks", + "title_sm_tks", + "important_kwd", + "important_tks", + "content_with_weight", + "content_ltks", + "content_sm_ltks", + "authors_tks", + "authors_sm_tks", + "question_kwd", + "question_tks", + ]: if k in d: del d[k] @@ -588,9 +606,20 @@ class InfinityConnection(InfinityConnectionBase): del new_value[k] else: new_value[k] = v - for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", - "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", - "question_tks"]: + for k in [ + "docnm_kwd", + "title_tks", + "title_sm_tks", + "important_kwd", + "important_tks", + "content_with_weight", + "content_ltks", + "content_sm_ltks", + "authors_tks", + "authors_sm_tks", + "question_kwd", + "question_tks", + ]: if k in new_value: del new_value[k] @@ -598,8 +627,7 @@ class InfinityConnection(InfinityConnectionBase): if removeValue: col_to_remove = list(removeValue.keys()) row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter).to_df() - self.logger.debug( - f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}") + self.logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}") row_to_opt = self.get_fields(row_to_opt, col_to_remove) for id, old_v in row_to_opt.items(): for k, remove_v in removeValue.items(): @@ -615,8 +643,7 @@ class InfinityConnection(InfinityConnectionBase): self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.") for update_kv, ids in remove_opt.items(): k, v = json.loads(update_kv) - table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])), - {k: "###".join(v)}) + table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])), {k: "###".join(v)}) table_instance.update(filter, new_value) finally: @@ -624,15 +651,15 @@ class InfinityConnection(InfinityConnectionBase): return True def adjust_chunk_pagerank_fea( - self, - chunk_id: str, - index_name: str, - knowledgebase_id: str, - delta: int, - min_weight: int, - max_weight: int, - row_id: int | None = None, - max_retries: int = 2, + self, + chunk_id: str, + index_name: str, + knowledgebase_id: str, + delta: int, + min_weight: int, + max_weight: int, + row_id: int | None = None, + max_retries: int = 2, ) -> bool: """Adjust pagerank_fea on one chunk row in Infinity. @@ -648,21 +675,18 @@ class InfinityConnection(InfinityConnectionBase): table_instance = db_instance.get_table(table_name) if row_id is None: - df, _ = table_instance.output( - [PAGERANK_FLD, "row_id()"] - ).filter(f"id = '{chunk_id}'").to_df() + df, _ = table_instance.output([PAGERANK_FLD, "row_id()"]).filter(f"id = '{chunk_id}'").to_df() if df.empty: self.logger.warning( "adjust_chunk_pagerank_fea: chunk %s not found in %s", - chunk_id, table_name, + chunk_id, + table_name, ) return False current_weight = int(float(df[PAGERANK_FLD].iloc[0] or 0)) row_id = int(df["row_id"].iloc[0]) else: - df, _ = table_instance.output( - [PAGERANK_FLD] - ).filter(f"id = '{chunk_id}'").to_df() + df, _ = table_instance.output([PAGERANK_FLD]).filter(f"id = '{chunk_id}'").to_df() if df.empty: return False current_weight = int(float(df[PAGERANK_FLD].iloc[0] or 0)) @@ -675,7 +699,11 @@ class InfinityConnection(InfinityConnectionBase): ) self.logger.info( "adjust_chunk_pagerank_fea(chunk=%s, table=%s): %s -> %s via row_id=%s", - chunk_id, table_name, current_weight, new_weight, row_id, + chunk_id, + table_name, + current_weight, + new_weight, + row_id, ) return True @@ -683,18 +711,26 @@ class InfinityConnection(InfinityConnectionBase): if attempt < max_retries: self.logger.warning( "adjust_chunk_pagerank_fea stale row_id=%s for chunk %s (attempt %s/%s): %s", - row_id, chunk_id, attempt + 1, max_retries, e, + row_id, + chunk_id, + attempt + 1, + max_retries, + e, ) row_id = None continue self.logger.error( "adjust_chunk_pagerank_fea failed for chunk %s after %s attempts: %s", - chunk_id, max_retries + 1, e, + chunk_id, + max_retries + 1, + e, ) return False except Exception as e: self.logger.error( - "adjust_chunk_pagerank_fea error for chunk %s: %s", chunk_id, e, + "adjust_chunk_pagerank_fea error for chunk %s: %s", + chunk_id, + e, ) return False finally: @@ -722,10 +758,7 @@ class InfinityConnection(InfinityConnectionBase): if "important_kwd_empty_count" in res.columns: base = res["important_keywords"].apply(lambda raw: raw.split(",") if raw else []) counts = res["important_kwd_empty_count"].fillna(0).astype(int) - res["important_kwd"] = [ - tokens + [""] * empty_count - for tokens, empty_count in zip(base.tolist(), counts.tolist()) - ] + res["important_kwd"] = [tokens + [""] * empty_count for tokens, empty_count in zip(base.tolist(), counts.tolist())] else: res["important_kwd"] = res["important_keywords"].apply(lambda v: v.split(",") if v else []) if "important_tks" in fields_all: @@ -765,10 +798,11 @@ class InfinityConnection(InfinityConnectionBase): # Parse JSON data back to dict for table parser fields res2[column] = res2[column].apply(lambda v: json.loads(v) if v and isinstance(v, str) else v) elif k == "position_int": + def to_position_int(v): if v: arr = [int(hex_val, 16) for hex_val in v.split("_")] - v = [arr[i: i + 5] for i in range(0, len(arr), 5)] + v = [arr[i : i + 5] for i in range(0, len(arr), 5)] else: v = [] return v diff --git a/rag/utils/minio_conn.py b/rag/utils/minio_conn.py index 5e46306cd1..63dc4738ff 100644 --- a/rag/utils/minio_conn.py +++ b/rag/utils/minio_conn.py @@ -44,8 +44,8 @@ class RAGFlowMinio: self.conn = None # Use `or None` to convert empty strings to None, ensuring single-bucket # mode is truly disabled when not configured - self.bucket = settings.MINIO.get('bucket', None) or None - self.prefix_path = settings.MINIO.get('prefix_path', None) or None + self.bucket = settings.MINIO.get("bucket", None) or None + self.prefix_path = settings.MINIO.get("prefix_path", None) or None self.__open__() @staticmethod @@ -58,7 +58,7 @@ class RAGFlowMinio: actual_bucket = self.bucket if self.bucket else bucket if self.bucket: # pass original identifier forward for use by other decorators - kwargs['_orig_bucket'] = original_bucket + kwargs["_orig_bucket"] = original_bucket return method(self, actual_bucket, *args, **kwargs) return wrapper @@ -71,7 +71,7 @@ class RAGFlowMinio: # bucket name and forwarded the original identifier as `_orig_bucket`. # Prefer that original identifier when constructing the key path so # objects are stored under //... - orig_bucket = kwargs.pop('_orig_bucket', None) + orig_bucket = kwargs.pop("_orig_bucket", None) if self.prefix_path: # If a prefix_path is configured, include it and then the identifier @@ -110,8 +110,7 @@ class RAGFlowMinio: http_client=http_client, ) except Exception: - logging.exception( - "Fail to connect %s " % settings.MINIO["host"]) + logging.exception("Fail to connect %s " % settings.MINIO["host"]) def __close__(self): del self.conn @@ -150,10 +149,7 @@ class RAGFlowMinio: if not self.bucket and not self.conn.bucket_exists(bucket): self.conn.make_bucket(bucket) - r = self.conn.put_object(bucket, fnm, - BytesIO(binary), - len(binary) - ) + r = self.conn.put_object(bucket, fnm, BytesIO(binary), len(binary)) return r except Exception: logging.exception(f"Fail to put {bucket}/{fnm}:") @@ -226,7 +222,7 @@ class RAGFlowMinio: @use_default_bucket def remove_bucket(self, bucket, **kwargs): - orig_bucket = kwargs.pop('_orig_bucket', None) + orig_bucket = kwargs.pop("_orig_bucket", None) try: if self.bucket: # Single bucket mode: remove objects with prefix diff --git a/rag/utils/ob_conn.py b/rag/utils/ob_conn.py index a51c923b64..3e1b480412 100644 --- a/rag/utils/ob_conn.py +++ b/rag/utils/ob_conn.py @@ -32,15 +32,18 @@ from common.constants import PAGERANK_FLD, TAG_FLD from common.decorator import singleton from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr from common.doc_store.ob_conn_base import ( - OBConnectionBase, get_value_str, - vector_search_template, vector_column_pattern, - fulltext_index_name_template, doc_meta_column_names, + OBConnectionBase, + get_value_str, + vector_search_template, + vector_column_pattern, + fulltext_index_name_template, + doc_meta_column_names, doc_meta_column_types, ) from common.float_utils import get_float from rag.nlp import rag_tokenizer -logger = logging.getLogger('ragflow.ob_conn') +logger = logging.getLogger("ragflow.ob_conn") column_order_id = Column("_order_id", Integer, nullable=True, comment="chunk order id for maintaining sequence") column_group_id = Column("group_id", String(256), nullable=True, comment="group id for external retrieval") @@ -48,8 +51,7 @@ column_mom_id = Column("mom_id", String(256), nullable=True, comment="parent chu column_chunk_data = Column("chunk_data", JSON, nullable=True, comment="table parser row data") column_raptor_kwd = Column("raptor_kwd", String(256), nullable=True, comment="RAPTOR summary marker") column_raptor_layer_int = Column("raptor_layer_int", Integer, nullable=True, comment="RAPTOR summary layer") -column_n_hop_with_weight = Column("n_hop_with_weight", LONGTEXT, nullable=True, - comment="JSON-encoded n-hop neighbour paths and weights for a graph entity") +column_n_hop_with_weight = Column("n_hop_with_weight", LONGTEXT, nullable=True, comment="JSON-encoded n-hop neighbour paths and weights for a graph entity") column_definitions: list[Column] = [ Column("id", String(256), primary_key=True, comment="chunk id"), @@ -68,10 +70,8 @@ column_definitions: list[Column] = [ Column("question_kwd", ARRAY(String(1024)), nullable=True, comment="questions"), Column("question_tks", TEXT, nullable=True, comment="question tokens"), Column("tag_kwd", ARRAY(String(256)), nullable=True, comment="tags"), - Column("tag_feas", JSON, nullable=True, - comment="tag features used for 'rank_feature', format: [tag -> relevance score]"), - Column("available_int", Integer, nullable=False, index=True, server_default="1", - comment="status of availability, 0 for unavailable, 1 for available"), + Column("tag_feas", JSON, nullable=True, comment="tag features used for 'rank_feature', format: [tag -> relevance score]"), + Column("available_int", Integer, nullable=False, index=True, server_default="1", comment="status of availability, 0 for unavailable, 1 for available"), Column("create_time", String(19), nullable=True, comment="creation time in YYYY-MM-DD HH:MM:SS format"), Column("create_timestamp_flt", Double, nullable=True, comment="creation timestamp in float format"), Column("img_id", String(128), nullable=True, comment="image id"), @@ -89,8 +89,7 @@ column_definitions: list[Column] = [ Column("entities_kwd", ARRAY(String(256)), nullable=True, comment="node ids of entities"), Column("rank_flt", Double, nullable=True, comment="rank of this entity"), column_n_hop_with_weight, - Column("removed_kwd", String(256), nullable=True, index=True, server_default="'N'", - comment="whether it has been deleted"), + Column("removed_kwd", String(256), nullable=True, index=True, server_default="'N'", comment="whether it has been deleted"), column_raptor_kwd, column_raptor_layer_int, column_chunk_data, @@ -320,7 +319,7 @@ def get_filters(condition: dict) -> list[str]: @singleton class OBConnection(OBConnectionBase): def __init__(self): - super().__init__(logger_name='ragflow.ob_conn') + super().__init__(logger_name="ragflow.ob_conn") # Determine which columns to use for full-text search dynamically self._fulltext_search_columns = FTS_COLUMNS_ORIGIN if self.search_original_content else FTS_COLUMNS_TKS @@ -364,16 +363,7 @@ class OBConnection(OBConnectionBase): Returns: dict: Performance metrics including latency, storage, QPS, and slow queries """ - metrics = { - "connection": "connected", - "latency_ms": 0.0, - "storage_used": "0B", - "storage_total": "0B", - "query_per_second": 0, - "slow_queries": 0, - "active_connections": 0, - "max_connections": 0 - } + metrics = {"connection": "connected", "latency_ms": 0.0, "storage_used": "0B", "storage_total": "0B", "query_per_second": 0, "slow_queries": 0, "active_connections": 0, "max_connections": 0} try: # Measure connection latency @@ -426,33 +416,23 @@ class OBConnection(OBConnectionBase): try: # Get database size result = self.client.perform_raw_text_sql( - f"SELECT ROUND(SUM(data_length + index_length) / 1024 / 1024, 2) AS 'size_mb' " - f"FROM information_schema.tables WHERE table_schema = '{self.db_name}'" + f"SELECT ROUND(SUM(data_length + index_length) / 1024 / 1024, 2) AS 'size_mb' FROM information_schema.tables WHERE table_schema = '{self.db_name}'" ).fetchone() size_mb = float(result[0]) if result and result[0] else 0.0 # Try to get total available space (may not be available in all OceanBase versions) try: - result = self.client.perform_raw_text_sql( - "SELECT ROUND(SUM(total_size) / 1024 / 1024 / 1024, 2) AS 'total_gb' " - "FROM oceanbase.__all_disk_stat" - ).fetchone() + result = self.client.perform_raw_text_sql("SELECT ROUND(SUM(total_size) / 1024 / 1024 / 1024, 2) AS 'total_gb' FROM oceanbase.__all_disk_stat").fetchone() total_gb = float(result[0]) if result and result[0] else None except Exception: # Fallback: estimate total space (100GB default if not available) total_gb = 100.0 - return { - "storage_used": f"{size_mb:.2f}MB", - "storage_total": f"{total_gb:.2f}GB" if total_gb else "N/A" - } + return {"storage_used": f"{size_mb:.2f}MB", "storage_total": f"{total_gb:.2f}GB" if total_gb else "N/A"} except Exception as e: logger.warning(f"Failed to get storage info: {str(e)}") - return { - "storage_used": "N/A", - "storage_total": "N/A" - } + return {"storage_used": "N/A", "storage_total": "N/A"} def _get_connection_pool_stats(self) -> dict: """ @@ -467,26 +447,16 @@ class OBConnection(OBConnectionBase): active_connections = len(list(result.fetchall())) # Get max_connections setting - max_conn_result = self.client.perform_raw_text_sql( - "SHOW VARIABLES LIKE 'max_connections'" - ).fetchone() + max_conn_result = self.client.perform_raw_text_sql("SHOW VARIABLES LIKE 'max_connections'").fetchone() max_connections = int(max_conn_result[1]) if max_conn_result and max_conn_result[1] else 0 # Get pool size from client if available - pool_size = getattr(self.client, 'pool_size', None) or 0 + pool_size = getattr(self.client, "pool_size", None) or 0 - return { - "active_connections": active_connections, - "max_connections": max_connections if max_connections > 0 else pool_size, - "pool_size": pool_size - } + return {"active_connections": active_connections, "max_connections": max_connections if max_connections > 0 else pool_size, "pool_size": pool_size} except Exception as e: logger.warning(f"Failed to get connection pool stats: {str(e)}") - return { - "active_connections": 0, - "max_connections": 0, - "pool_size": 0 - } + return {"active_connections": 0, "max_connections": 0, "pool_size": 0} def _get_slow_query_count(self, threshold_seconds: int = 1) -> int: """ @@ -499,10 +469,7 @@ class OBConnection(OBConnectionBase): int: Number of slow queries """ try: - result = self.client.perform_raw_text_sql( - f"SELECT COUNT(*) FROM information_schema.processlist " - f"WHERE time > {threshold_seconds} AND command != 'Sleep'" - ).fetchone() + result = self.client.perform_raw_text_sql(f"SELECT COUNT(*) FROM information_schema.processlist WHERE time > {threshold_seconds} AND command != 'Sleep'").fetchone() return int(result[0]) if result and result[0] else 0 except Exception as e: logger.warning(f"Failed to get slow query count: {str(e)}") @@ -517,9 +484,7 @@ class OBConnection(OBConnectionBase): """ try: # Count active queries (non-Sleep commands) - result = self.client.perform_raw_text_sql( - "SELECT COUNT(*) FROM information_schema.processlist WHERE command != 'Sleep'" - ).fetchone() + result = self.client.perform_raw_text_sql("SELECT COUNT(*) FROM information_schema.processlist WHERE command != 'Sleep'").fetchone() active_queries = int(result[0]) if result and result[0] else 0 # Rough estimate: assume average query takes 0.1 seconds @@ -585,8 +550,7 @@ class OBConnection(OBConnectionBase): if v == 0: bqry.filter.append(Q("range", available_int={"lt": 1})) else: - bqry.filter.append( - Q("bool", must_not=Q("range", available_int={"lt": 1}))) + bqry.filter.append(Q("bool", must_not=Q("range", available_int={"lt": 1}))) continue if not v: continue @@ -595,16 +559,18 @@ class OBConnection(OBConnectionBase): elif isinstance(v, str) or isinstance(v, int): bqry.filter.append(Q("term", **{k: v})) else: - raise Exception( - f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") + raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") s = Search() vector_similarity_weight = 0.5 for m in match_expressions: if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params: - if not (len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance( - match_expressions[1], MatchDenseExpr) and isinstance( - match_expressions[2], FusionExpr)): + if not ( + len(match_expressions) == 3 + and isinstance(match_expressions[0], MatchTextExpr) + and isinstance(match_expressions[1], MatchDenseExpr) + and isinstance(match_expressions[2], FusionExpr) + ): raise ValueError("match_expressions must contain MatchTextExpr, MatchDenseExpr, and FusionExpr") weights = m.fusion_params["weights"] vector_similarity_weight = get_float(weights.split(",")[1]) @@ -613,10 +579,7 @@ class OBConnection(OBConnectionBase): minimum_should_match = m.extra_options.get("minimum_should_match", 0.0) if isinstance(minimum_should_match, float): minimum_should_match = str(int(minimum_should_match * 100)) + "%" - bqry.must.append(Q("query_string", fields=FTS_COLUMNS_TKS, - type="best_fields", query=m.matching_text, - minimum_should_match=minimum_should_match, - boost=1)) + bqry.must.append(Q("query_string", fields=FTS_COLUMNS_TKS, type="best_fields", query=m.matching_text, minimum_should_match=minimum_should_match, boost=1)) bqry.boost = 1.0 - vector_similarity_weight elif isinstance(m, MatchDenseExpr): @@ -625,13 +588,14 @@ class OBConnection(OBConnectionBase): similarity = 0.0 if "similarity" in m.extra_options: similarity = m.extra_options["similarity"] - s = s.knn(m.vector_column_name, - m.topn, - m.topn * 2, - query_vector=list(m.embedding_data), - filter=bqry.to_dict(), - similarity=similarity, - ) + s = s.knn( + m.vector_column_name, + m.topn, + m.topn * 2, + query_vector=list(m.embedding_data), + filter=bqry.to_dict(), + similarity=similarity, + ) if bqry and rank_feature: for fld, sc in rank_feature.items(): @@ -649,8 +613,7 @@ class OBConnection(OBConnectionBase): for field, order in order_by.fields: order = "asc" if order == 0 else "desc" if field in ["page_num_int", "top_int"]: - order_info = {"order": order, "unmapped_type": "float", - "mode": "avg", "numeric_type": "double"} + order_info = {"order": order, "unmapped_type": "float", "mode": "avg", "numeric_type": "double"} elif field.endswith("_int") or field.endswith("_flt"): order_info = {"order": order, "unmapped_type": "float"} else: @@ -659,25 +622,18 @@ class OBConnection(OBConnectionBase): s = s.sort(*orders) for fld in agg_fields: - s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000) + s.aggs.bucket(f"aggs_{fld}", "terms", field=fld, size=1000000) if limit > 0: - s = s[offset:offset + limit] + s = s[offset : offset + limit] q = s.to_dict() logger.debug(f"OBConnection.hybrid_search {str(index_names)} query: " + json.dumps(q)) for index_name in index_names: start_time = time.time() - res = self.es.search(index=index_name, - body=q, - timeout="600s", - track_total_hits=True, - _source=True) + res = self.es.search(index=index_name, body=q, timeout="600s", track_total_hits=True, _source=True) elapsed_time = time.time() - start_time - logger.info( - f"OBConnection.search table {index_name}, search type: hybrid, elapsed time: {elapsed_time:.3f} seconds," - f" got count: {len(res)}" - ) + logger.info(f"OBConnection.search table {index_name}, search type: hybrid, elapsed time: {elapsed_time:.3f} seconds, got count: {len(res)}") for chunk in res: result.chunks.append(self._es_row_to_entity(chunk)) result.total = result.total + 1 @@ -731,9 +687,7 @@ class OBConnection(OBConnectionBase): fulltext_query = escape_string(fulltext_query.strip()) fulltext_topn = m.topn - fulltext_search_expr, fulltext_search_weight = self._parse_fulltext_columns( - fulltext_query, self._fulltext_search_columns - ) + fulltext_search_expr, fulltext_search_weight = self._parse_fulltext_columns(fulltext_query, self._fulltext_search_columns) for column_name in fulltext_search_expr.keys(): fulltext_search_idx_list.append(fulltext_index_name_template % column_name) @@ -786,7 +740,6 @@ class OBConnection(OBConnectionBase): limit = min(fulltext_topn, limit) for index_name in index_names: - if not self._check_table_exists_cached(index_name): continue @@ -919,10 +872,7 @@ class OBConnection(OBConnectionBase): if total_count == 0: continue - vector_sql = self._build_vector_search_sql( - index_name, fields_expr, vector_search_score_expr, filters_expr, - vector_search_filter, vector_search_expr, limit, vector_topn, offset - ) + vector_sql = self._build_vector_search_sql(index_name, fields_expr, vector_search_score_expr, filters_expr, vector_search_filter, vector_search_expr, limit, vector_topn, offset) logger.debug("OBConnection.search with vector sql: %s", vector_sql) rows, elapsed_time = self._execute_search_sql(vector_sql) logger.info( @@ -954,8 +904,7 @@ class OBConnection(OBConnectionBase): continue fulltext_sql = self._build_fulltext_search_sql( - index_name, fields_expr, fulltext_search_score_expr, filters_expr, - fulltext_search_filter, offset, limit, fulltext_topn, fulltext_search_hint + index_name, fields_expr, fulltext_search_score_expr, filters_expr, fulltext_search_filter, offset, limit, fulltext_topn, fulltext_search_hint ) logger.debug("OBConnection.search with fulltext sql: %s", fulltext_sql) rows, elapsed_time = self._execute_search_sql(fulltext_sql) @@ -975,10 +924,7 @@ class OBConnection(OBConnectionBase): raise ValueError("Only one aggregation field is supported in OceanBase.") agg_field = agg_fields[0] if agg_field in array_columns: - res = self.client.perform_raw_text_sql( - f"SELECT {agg_field} FROM {index_name}" - f" WHERE {agg_field} IS NOT NULL AND {filters_expr}" - ) + res = self.client.perform_raw_text_sql(f"SELECT {agg_field} FROM {index_name} WHERE {agg_field} IS NOT NULL AND {filters_expr}") counts = {} for row in res: if row[0]: @@ -997,22 +943,22 @@ class OBConnection(OBConnectionBase): counts[v] = counts.get(v, 0) + 1 for v, count in counts.items(): - result.chunks.append({ - "value": v, - "count": count, - }) + result.chunks.append( + { + "value": v, + "count": count, + } + ) result.total += len(counts) else: - res = self.client.perform_raw_text_sql( - f"SELECT {agg_field}, COUNT(*) as count FROM {index_name}" - f" WHERE {agg_field} IS NOT NULL AND {filters_expr}" - f" GROUP BY {agg_field}" - ) + res = self.client.perform_raw_text_sql(f"SELECT {agg_field}, COUNT(*) as count FROM {index_name} WHERE {agg_field} IS NOT NULL AND {filters_expr} GROUP BY {agg_field}") for row in res: - result.chunks.append({ - "value": row[0], - "count": int(row[1]), - }) + result.chunks.append( + { + "value": row[0], + "count": int(row[1]), + } + ) result.total += 1 else: # only filter @@ -1030,20 +976,14 @@ class OBConnection(OBConnectionBase): rows, elapsed_time = self._execute_search_sql(count_sql) total_count = rows[0][0] if rows else 0 result.total += total_count - logger.info( - f"OBConnection.search table {index_name}, search type: normal, step: 1-count, elapsed time: {elapsed_time:.3f} seconds," - f" condition: '{condition}'," - f" got count: {total_count}" - ) + logger.info(f"OBConnection.search table {index_name}, search type: normal, step: 1-count, elapsed time: {elapsed_time:.3f} seconds, condition: '{condition}', got count: {total_count}") if total_count == 0: continue order_by_expr = ("ORDER BY " + ", ".join(orders)) if len(orders) > 0 else "" limit_expr = f"LIMIT {offset}, {limit}" if limit != 0 else "" - filter_sql = self._build_filter_search_sql( - index_name, fields_expr, filters_expr, order_by_expr, limit_expr - ) + filter_sql = self._build_filter_search_sql(index_name, fields_expr, filters_expr, order_by_expr, limit_expr) logger.debug("OBConnection.search with normal sql: %s", filter_sql) rows, elapsed_time = self._execute_search_sql(filter_sql) logger.info( @@ -1069,10 +1009,7 @@ class OBConnection(OBConnectionBase): return doc except json.JSONDecodeError as e: logger.error(f"JSON decode error when getting chunk {chunk_id}: {str(e)}") - return { - "id": chunk_id, - "error": f"Failed to parse chunk data due to invalid JSON: {str(e)}" - } + return {"id": chunk_id, "error": f"Failed to parse chunk data due to invalid JSON: {str(e)}"} except Exception as e: logger.exception(f"OBConnection.get({chunk_id}) got exception") raise e @@ -1114,10 +1051,10 @@ class OBConnection(OBConnectionBase): for vv in v: if isinstance(vv, str): cleaned_str = vv.strip() - cleaned_str = cleaned_str.replace('\\', '\\\\') - cleaned_str = cleaned_str.replace('\n', '\\n') - cleaned_str = cleaned_str.replace('\r', '\\r') - cleaned_str = cleaned_str.replace('\t', '\\t') + cleaned_str = cleaned_str.replace("\\", "\\\\") + cleaned_str = cleaned_str.replace("\n", "\\n") + cleaned_str = cleaned_str.replace("\r", "\\r") + cleaned_str = cleaned_str.replace("\t", "\\t") cleaned_v.append(cleaned_str) else: cleaned_v.append(vv) @@ -1231,11 +1168,7 @@ class OBConnection(OBConnectionBase): if not set_values: return True - update_sql = ( - f"UPDATE {index_name}" - f" SET {', '.join(set_values)}" - f" WHERE {' AND '.join(filters)}" - ) + update_sql = f"UPDATE {index_name} SET {', '.join(set_values)} WHERE {' AND '.join(filters)}" logger.debug("OBConnection.update sql: %s", update_sql) try: @@ -1330,7 +1263,8 @@ class OBConnection(OBConnectionBase): if question and not self.is_chinese(question): highlighted_txt = re.sub( r"(^|\W)(%s)(\W|$)" % re.escape(question), - r"\1\2\3", highlighted_txt, + r"\1\2\3", + highlighted_txt, flags=re.IGNORECASE | re.MULTILINE, ) if re.search(r"[^<>]+", highlighted_txt, flags=re.IGNORECASE | re.MULTILINE): @@ -1339,11 +1273,11 @@ class OBConnection(OBConnectionBase): for keyword in keywords: highlighted_txt = re.sub( r"(^|\W)(%s)(\W|$)" % re.escape(keyword), - r"\1\2\3", highlighted_txt, + r"\1\2\3", + highlighted_txt, flags=re.IGNORECASE | re.MULTILINE, ) - if len(re.findall(r'', highlighted_txt)) > 0 or len( - re.findall(r'\s*', highlighted_txt)) > 0: + if len(re.findall(r"", highlighted_txt)) > 0 or len(re.findall(r"\s*", highlighted_txt)) > 0: return highlighted_txt else: return None @@ -1361,13 +1295,9 @@ class OBConnection(OBConnectionBase): token_pos = highlighted_txt.rfind(token, 0, last_pos) if token_pos != -1: if token in keywords: - highlighted_txt = ( - highlighted_txt[:token_pos] + - f'{token}' + - highlighted_txt[token_pos + len(token):] - ) + highlighted_txt = highlighted_txt[:token_pos] + f"{token}" + highlighted_txt[token_pos + len(token) :] last_pos = token_pos - return re.sub(r'', '', highlighted_txt) + return re.sub(r"", "", highlighted_txt) def get_highlight(self, res, keywords: list[str], fieldnm: str): ans = {} diff --git a/rag/utils/opendal_conn.py b/rag/utils/opendal_conn.py index b2a364b602..fcf7a7a61d 100644 --- a/rag/utils/opendal_conn.py +++ b/rag/utils/opendal_conn.py @@ -22,9 +22,9 @@ SET GLOBAL max_allowed_packet={} def get_opendal_config(): try: - opendal_config = get_base_config('opendal', {}) - if opendal_config.get("scheme", "mysql") == 'mysql': - mysql_config = get_base_config('mysql', {}) + opendal_config = get_base_config("opendal", {}) + if opendal_config.get("scheme", "mysql") == "mysql": + mysql_config = get_base_config("mysql", {}) max_packet = mysql_config.get("max_allowed_packet", 134217728) kwargs = { "scheme": "mysql", @@ -34,10 +34,9 @@ def get_opendal_config(): "password": mysql_config.get("password", ""), "database": mysql_config.get("name", "test_open_dal"), "table": opendal_config.get("config", {}).get("oss_table", "opendal_storage"), - "max_allowed_packet": str(max_packet) + "max_allowed_packet": str(max_packet), } - kwargs[ - "connection_string"] = f"mysql://{kwargs['user']}:{quote_plus(kwargs['password'])}@{kwargs['host']}:{kwargs['port']}/{kwargs['database']}?max_allowed_packet={max_packet}" + kwargs["connection_string"] = f"mysql://{kwargs['user']}:{quote_plus(kwargs['password'])}@{kwargs['host']}:{kwargs['port']}/{kwargs['database']}?max_allowed_packet={max_packet}" else: scheme = opendal_config.get("scheme") config_data = opendal_config.get("config", {}) @@ -66,8 +65,8 @@ def get_opendal_config(): class OpenDALStorage: def __init__(self): self._kwargs = get_opendal_config() - self._scheme = self._kwargs.get('scheme', 'mysql') - if self._scheme == 'mysql': + self._scheme = self._kwargs.get("scheme", "mysql") + if self._scheme == "mysql": self.init_db_config() self.init_opendal_mysql_table() self._operator = opendal.Operator(**self._kwargs) @@ -96,15 +95,9 @@ class OpenDALStorage: def init_db_config(self): try: - conn = pymysql.connect( - host=self._kwargs['host'], - port=int(self._kwargs['port']), - user=self._kwargs['user'], - password=self._kwargs['password'], - database=self._kwargs['database'] - ) + conn = pymysql.connect(host=self._kwargs["host"], port=int(self._kwargs["port"]), user=self._kwargs["user"], password=self._kwargs["password"], database=self._kwargs["database"]) cursor = conn.cursor() - max_packet = self._kwargs.get('max_allowed_packet', 4194304) # Default to 4MB if not specified + max_packet = self._kwargs.get("max_allowed_packet", 4194304) # Default to 4MB if not specified # Ensure max_packet is a valid integer to prevent SQL injection cursor.execute(SET_MAX_ALLOWED_PACKET_SQL.format(int(max_packet))) conn.commit() @@ -116,18 +109,12 @@ class OpenDALStorage: raise def init_opendal_mysql_table(self): - table_name = self._kwargs['table'] + table_name = self._kwargs["table"] # Validate table name to prevent SQL injection - if not re.match(r'^[a-zA-Z0-9_]+$', table_name): + if not re.match(r"^[a-zA-Z0-9_]+$", table_name): raise ValueError(f"Invalid table name: {table_name}") - conn = pymysql.connect( - host=self._kwargs['host'], - port=int(self._kwargs['port']), - user=self._kwargs['user'], - password=self._kwargs['password'], - database=self._kwargs['database'] - ) + conn = pymysql.connect(host=self._kwargs["host"], port=int(self._kwargs["port"]), user=self._kwargs["user"], password=self._kwargs["password"], database=self._kwargs["database"]) cursor = conn.cursor() cursor.execute(CREATE_TABLE_SQL.format(table_name)) conn.commit() diff --git a/rag/utils/oss_conn.py b/rag/utils/oss_conn.py index 82236f6eb2..8e372b2c95 100644 --- a/rag/utils/oss_conn.py +++ b/rag/utils/oss_conn.py @@ -28,14 +28,14 @@ class RAGFlowOSS: def __init__(self): self.conn = None self.oss_config = settings.OSS - self.access_key = self.oss_config.get('access_key', None) - self.secret_key = self.oss_config.get('secret_key', None) - self.endpoint_url = self.oss_config.get('endpoint_url', None) - self.region = self.oss_config.get('region', None) - self.bucket = self.oss_config.get('bucket', None) - self.prefix_path = self.oss_config.get('prefix_path', None) - self.signature_version = self.oss_config.get('signature_version', None) - self.addressing_style = self.oss_config.get('addressing_style', None) + self.access_key = self.oss_config.get("access_key", None) + self.secret_key = self.oss_config.get("secret_key", None) + self.endpoint_url = self.oss_config.get("endpoint_url", None) + self.region = self.oss_config.get("region", None) + self.bucket = self.oss_config.get("bucket", None) + self.prefix_path = self.oss_config.get("prefix_path", None) + self.signature_version = self.oss_config.get("signature_version", None) + self.addressing_style = self.oss_config.get("addressing_style", None) self.__open__() @staticmethod @@ -67,23 +67,14 @@ class RAGFlowOSS: config_kwargs = {} if self.signature_version: - config_kwargs['signature_version'] = self.signature_version + config_kwargs["signature_version"] = self.signature_version if self.addressing_style: - config_kwargs['s3'] = { - 'addressing_style': self.addressing_style - } + config_kwargs["s3"] = {"addressing_style": self.addressing_style} config = Config(**config_kwargs) if config_kwargs else None # Reference:https://help.aliyun.com/zh/oss/developer-reference/use-amazon-s3-sdks-to-access-oss - self.conn = boto3.client( - 's3', - region_name=self.region, - aws_access_key_id=self.access_key, - aws_secret_access_key=self.secret_key, - endpoint_url=self.endpoint_url, - config=config - ) + self.conn = boto3.client("s3", region_name=self.region, aws_access_key_id=self.access_key, aws_secret_access_key=self.secret_key, endpoint_url=self.endpoint_url, config=config) except Exception: logging.exception(f"Fail to connect at region {self.region}") @@ -150,7 +141,7 @@ class RAGFlowOSS: for _ in range(1): try: r = self.conn.get_object(Bucket=bucket, Key=fnm) - object_data = r['Body'].read() + object_data = r["Body"].read() return object_data except Exception: logging.exception(f"fail get {bucket}/{fnm}") @@ -165,7 +156,7 @@ class RAGFlowOSS: if self.conn.head_object(Bucket=bucket, Key=fnm): return True except ClientError as e: - if e.response['Error']['Code'] == '404': + if e.response["Error"]["Code"] == "404": return False else: raise @@ -175,10 +166,7 @@ class RAGFlowOSS: def get_presigned_url(self, bucket, fnm, expires, tenant_id=None): for _ in range(10): try: - r = self.conn.generate_presigned_url('get_object', - Params={'Bucket': bucket, - 'Key': fnm}, - ExpiresIn=expires) + r = self.conn.generate_presigned_url("get_object", Params={"Bucket": bucket, "Key": fnm}, ExpiresIn=expires) return r except Exception: diff --git a/rag/utils/raptor_utils.py b/rag/utils/raptor_utils.py index e461841abf..083d42b242 100644 --- a/rag/utils/raptor_utils.py +++ b/rag/utils/raptor_utils.py @@ -72,6 +72,7 @@ def _as_extra_dict(extra) -> dict: # Fallback: try parsing Python dict literal (single quotes) try: import ast + parsed = ast.literal_eval(extra) if isinstance(parsed, dict): return parsed @@ -138,10 +139,10 @@ def make_raptor_summary_chunk_id(content: str, doc_id: str) -> str: def is_structured_file_type(file_type: Optional[str]) -> bool: """ Check if a file type is structured data (Excel, CSV, etc.) - + Args: file_type: File extension (e.g., ".xlsx", ".csv") - + Returns: True if file is structured data type """ @@ -159,11 +160,11 @@ def is_structured_file_type(file_type: Optional[str]) -> bool: def is_tabular_pdf(parser_id: str = "", parser_config: Optional[dict] = None) -> bool: """ Check if a PDF is being parsed as tabular data. - + Args: parser_id: Parser ID (e.g., "table", "naive") parser_config: Parser configuration dict - + Returns: True if PDF is being parsed as tabular data """ @@ -180,25 +181,20 @@ def is_tabular_pdf(parser_id: str = "", parser_config: Optional[dict] = None) -> return False -def should_skip_raptor( - file_type: Optional[str] = None, - parser_id: str = "", - parser_config: Optional[dict] = None, - raptor_config: Optional[dict] = None -) -> bool: +def should_skip_raptor(file_type: Optional[str] = None, parser_id: str = "", parser_config: Optional[dict] = None, raptor_config: Optional[dict] = None) -> bool: """ Determine if Raptor should be skipped for a given document. - + This function implements the logic to automatically disable Raptor for: 1. Excel files (.xls, .xlsx, .csv, etc.) 2. PDFs with tabular data (using table parser or html4excel) - + Args: file_type: File extension (e.g., ".xlsx", ".pdf") parser_id: Parser ID being used parser_config: Parser configuration dict raptor_config: Raptor configuration dict (can override with auto_disable_for_structured_data) - + Returns: True if Raptor should be skipped, False otherwise """ @@ -224,19 +220,15 @@ def should_skip_raptor( return False -def get_skip_reason( - file_type: Optional[str] = None, - parser_id: str = "", - parser_config: Optional[dict] = None -) -> str: +def get_skip_reason(file_type: Optional[str] = None, parser_id: str = "", parser_config: Optional[dict] = None) -> str: """ Get a human-readable reason why Raptor was skipped. - + Args: file_type: File extension parser_id: Parser ID being used parser_config: Parser configuration dict - + Returns: Reason string, or empty string if Raptor should not be skipped """ diff --git a/rag/utils/s3_conn.py b/rag/utils/s3_conn.py index fd5fe37fd6..fe2a5ad248 100644 --- a/rag/utils/s3_conn.py +++ b/rag/utils/s3_conn.py @@ -29,15 +29,15 @@ class RAGFlowS3: def __init__(self): self.conn = None self.s3_config = settings.S3 - self.access_key = self.s3_config.get('access_key', None) - self.secret_key = self.s3_config.get('secret_key', None) - self.session_token = self.s3_config.get('session_token', None) - self.region_name = self.s3_config.get('region_name', None) - self.endpoint_url = self.s3_config.get('endpoint_url', None) - self.signature_version = self.s3_config.get('signature_version', None) - self.addressing_style = self.s3_config.get('addressing_style', None) - self.bucket = self.s3_config.get('bucket', None) - self.prefix_path = self.s3_config.get('prefix_path', None) + self.access_key = self.s3_config.get("access_key", None) + self.secret_key = self.s3_config.get("secret_key", None) + self.session_token = self.s3_config.get("session_token", None) + self.region_name = self.s3_config.get("region_name", None) + self.endpoint_url = self.s3_config.get("endpoint_url", None) + self.signature_version = self.s3_config.get("signature_version", None) + self.addressing_style = self.s3_config.get("addressing_style", None) + self.bucket = self.s3_config.get("bucket", None) + self.prefix_path = self.s3_config.get("prefix_path", None) self.__open__() @staticmethod @@ -53,7 +53,7 @@ class RAGFlowS3: def use_prefix_path(method): def wrapper(self, bucket, fnm, *args, **kwargs): # If the prefix path is set, use the prefix path. - # The bucket passed from the upstream call is + # The bucket passed from the upstream call is # used as the file prefix. This is especially useful when you're using the default bucket if self.prefix_path: fnm = f"{self.prefix_path}/{bucket}/{fnm}" @@ -75,25 +75,25 @@ class RAGFlowS3: # see doc: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html#configuring-credentials if self.access_key and self.secret_key: s3_params = { - 'aws_access_key_id': self.access_key, - 'aws_secret_access_key': self.secret_key, - 'aws_session_token': self.session_token, + "aws_access_key_id": self.access_key, + "aws_secret_access_key": self.secret_key, + "aws_session_token": self.session_token, } if self.region_name: - s3_params['region_name'] = self.region_name + s3_params["region_name"] = self.region_name if self.endpoint_url: - s3_params['endpoint_url'] = self.endpoint_url + s3_params["endpoint_url"] = self.endpoint_url # Configure signature_version and addressing_style through Config object if self.signature_version: - config_kwargs['signature_version'] = self.signature_version + config_kwargs["signature_version"] = self.signature_version if self.addressing_style: - config_kwargs['s3'] = {'addressing_style': self.addressing_style} + config_kwargs["s3"] = {"addressing_style": self.addressing_style} if config_kwargs: - s3_params['config'] = Config(**config_kwargs) + s3_params["config"] = Config(**config_kwargs) - self.conn = [boto3.client('s3', **s3_params)] + self.conn = [boto3.client("s3", **s3_params)] except Exception: logging.exception(f"Fail to connect at region {self.region_name} or endpoint {self.endpoint_url}") @@ -160,7 +160,7 @@ class RAGFlowS3: for _ in range(1): try: r = self.conn[0].get_object(Bucket=bucket, Key=fnm) - object_data = r['Body'].read() + object_data = r["Body"].read() return object_data except Exception: logging.exception(f"fail get {bucket}/{fnm}") @@ -175,7 +175,7 @@ class RAGFlowS3: if self.conn[0].head_object(Bucket=bucket, Key=fnm): return True except ClientError as e: - if e.response['Error']['Code'] == '404': + if e.response["Error"]["Code"] == "404": return False else: raise @@ -185,10 +185,7 @@ class RAGFlowS3: def get_presigned_url(self, bucket, fnm, expires, *args, **kwargs): for _ in range(10): try: - r = self.conn[0].generate_presigned_url('get_object', - Params={'Bucket': bucket, - 'Key': fnm}, - ExpiresIn=expires) + r = self.conn[0].generate_presigned_url("get_object", Params={"Bucket": bucket, "Key": fnm}, ExpiresIn=expires) return r except Exception: @@ -207,7 +204,7 @@ class RAGFlowS3: try: actual_src_bucket, actual_src_path = self._resolve_path(src_bucket, src_path) actual_dest_bucket, actual_dest_path = self._resolve_path(dest_bucket, dest_path) - copy_source = {'Bucket': actual_src_bucket, 'Key': actual_src_path} + copy_source = {"Bucket": actual_src_bucket, "Key": actual_src_path} self.conn[0].copy_object( CopySource=copy_source, Bucket=actual_dest_bucket, @@ -226,9 +223,7 @@ class RAGFlowS3: self.conn[0].delete_object(Bucket=actual_src_bucket, Key=actual_src_path) return True except Exception: - logging.exception( - f"Copied but failed to delete source: {src_bucket}/{src_path}" - ) + logging.exception(f"Copied but failed to delete source: {src_bucket}/{src_path}") return False else: logging.error(f"Copy failed, move aborted: {src_bucket}/{src_path}") diff --git a/rag/utils/tavily_conn.py b/rag/utils/tavily_conn.py index 1b391fb1bb..33d265a050 100644 --- a/rag/utils/tavily_conn.py +++ b/rag/utils/tavily_conn.py @@ -25,13 +25,8 @@ class Tavily: def search(self, query): try: - response = self.tavily_client.search( - query=query, - search_depth="advanced", - max_results=6 - ) - return [{"url": res["url"], "title": res["title"], "content": res["content"], "score": res["score"]} for res - in response["results"]] + response = self.tavily_client.search(query=query, search_depth="advanced", max_results=6) + return [{"url": res["url"], "title": res["title"], "content": res["content"], "score": res["score"]} for res in response["results"]] except Exception as e: logging.exception(e) @@ -43,27 +38,24 @@ class Tavily: logging.info("[Tavily]Q: " + question) for r in self.search(question): id = get_uuid() - chunks.append({ - "chunk_id": id, - "content_ltks": rag_tokenizer.tokenize(r["content"]), - "content_with_weight": r["content"], - "doc_id": id, - "docnm_kwd": r["title"], - "kb_id": [], - "important_kwd": [], - "image_id": "", - "similarity": r["score"], - "vector_similarity": 1., - "term_similarity": 0, - "vector": [], - "positions": [], - "url": r["url"] - }) - aggs.append({ - "doc_name": r["title"], - "doc_id": id, - "count": 1, - "url": r["url"] - }) + chunks.append( + { + "chunk_id": id, + "content_ltks": rag_tokenizer.tokenize(r["content"]), + "content_with_weight": r["content"], + "doc_id": id, + "docnm_kwd": r["title"], + "kb_id": [], + "important_kwd": [], + "image_id": "", + "similarity": r["score"], + "vector_similarity": 1.0, + "term_similarity": 0, + "vector": [], + "positions": [], + "url": r["url"], + } + ) + aggs.append({"doc_name": r["title"], "doc_id": id, "count": 1, "url": r["url"]}) logging.info("[Tavily]R: " + r["content"][:128] + "...") return {"chunks": chunks, "doc_aggs": aggs} diff --git a/ragflow_deps/download_deps.py b/ragflow_deps/download_deps.py index cf3839098e..aa4f7605ad 100644 --- a/ragflow_deps/download_deps.py +++ b/ragflow_deps/download_deps.py @@ -70,12 +70,9 @@ def get_urls(use_china_mirrors=False) -> list[Union[str, list[str]]]: # Native static libraries for Go build (pdfium, pdf_oxide, office_oxide) # Used by build.sh's check_*_deps functions — pre-downloaded to avoid # network access during CI. - ["https://github.com/kognitos/pdfium-static/releases/download/chromium%2F7809/pdfium-linux-x64-static.tgz", - "pdfium-linux-x64-static.tgz"], - ["https://github.com/yfedoseev/pdf_oxide/releases/download/v0.3.67/pdf_oxide-go-ffi-linux-amd64.tar.gz", - "pdf_oxide-go-ffi-linux-amd64.tar.gz"], - ["https://github.com/yfedoseev/office_oxide/releases/download/v0.1.2/native-linux-x86_64.tar.gz", - "office_oxide-linux-x86_64.tar.gz"], + ["https://github.com/kognitos/pdfium-static/releases/download/chromium%2F7809/pdfium-linux-x64-static.tgz", "pdfium-linux-x64-static.tgz"], + ["https://github.com/yfedoseev/pdf_oxide/releases/download/v0.3.67/pdf_oxide-go-ffi-linux-amd64.tar.gz", "pdf_oxide-go-ffi-linux-amd64.tar.gz"], + ["https://github.com/yfedoseev/office_oxide/releases/download/v0.1.2/native-linux-x86_64.tar.gz", "office_oxide-linux-x86_64.tar.gz"], ] else: return [ @@ -107,12 +104,9 @@ def get_urls(use_china_mirrors=False) -> list[Union[str, list[str]]]: # Native static libraries for Go build (pdfium, pdf_oxide, office_oxide) # Used by build.sh's check_*_deps functions — pre-downloaded to avoid # network access during CI. - ["https://github.com/kognitos/pdfium-static/releases/download/chromium%2F7809/pdfium-linux-x64-static.tgz", - "pdfium-linux-x64-static.tgz"], - ["https://github.com/yfedoseev/pdf_oxide/releases/download/v0.3.67/pdf_oxide-go-ffi-linux-amd64.tar.gz", - "pdf_oxide-go-ffi-linux-amd64.tar.gz"], - ["https://github.com/yfedoseev/office_oxide/releases/download/v0.1.2/native-linux-x86_64.tar.gz", - "office_oxide-linux-x86_64.tar.gz"], + ["https://github.com/kognitos/pdfium-static/releases/download/chromium%2F7809/pdfium-linux-x64-static.tgz", "pdfium-linux-x64-static.tgz"], + ["https://github.com/yfedoseev/pdf_oxide/releases/download/v0.3.67/pdf_oxide-go-ffi-linux-amd64.tar.gz", "pdf_oxide-go-ffi-linux-amd64.tar.gz"], + ["https://github.com/yfedoseev/office_oxide/releases/download/v0.1.2/native-linux-x86_64.tar.gz", "office_oxide-linux-x86_64.tar.gz"], ] @@ -163,6 +157,7 @@ if __name__ == "__main__": ("office_oxide-linux-x86_64.tar.gz", "office_oxide"), ] import tarfile + for archive, subdir in extractions: archive_path = os.path.join(os.getcwd(), archive) if not os.path.isfile(archive_path): diff --git a/run_tests.py b/run_tests.py index 48b0391873..1a7266aba7 100755 --- a/run_tests.py +++ b/run_tests.py @@ -24,11 +24,12 @@ from typing import List class Colors: """ANSI color codes for terminal output""" - RED = '\033[0;31m' - GREEN = '\033[0;32m' - YELLOW = '\033[1;33m' - BLUE = '\033[0;34m' - NC = '\033[0m' # No Color + + RED = "\033[0;31m" + GREEN = "\033[0;32m" + YELLOW = "\033[1;33m" + BLUE = "\033[0;34m" + NC = "\033[0m" # No Color class TestRunner: @@ -36,7 +37,7 @@ class TestRunner: def __init__(self): self.project_root = Path(__file__).parent.resolve() - self.ut_dir = Path(self.project_root / 'test' / 'unit_test') + self.ut_dir = Path(self.project_root / "test" / "unit_test") # Default options self.coverage = False self.parallel = False @@ -84,7 +85,7 @@ EXAMPLES: # Run in parallel python run_tests.py --parallel - + # Run tests with "-W ignore::SyntaxWarning" option python run_tests.py --ignore @@ -127,17 +128,14 @@ EXAMPLES: if self.coverage: # Relative path from test directory to source code source_path = str(self.project_root / "common") - cmd.extend([ - "--cov", source_path, - "--cov-report", "html", - "--cov-report", "term" - ]) + cmd.extend(["--cov", source_path, "--cov-report", "html", "--cov-report", "term"]) # Add parallel execution if self.parallel: # Try to get number of CPU cores try: import multiprocessing + cpu_count = multiprocessing.cpu_count() cmd.extend(["-n", str(cpu_count)]) except ImportError: @@ -223,53 +221,22 @@ Examples: python run_tests.py --test services/test_dialog_service.py # Run specific test python run_tests.py --markers "unit" # Run only unit tests python run_tests.py --ignore # Run with "-W ignore::SyntaxWarning" option -""" +""", ) - parser.add_argument( - "-c", "--coverage", - action="store_true", - help="Run tests with coverage report" - ) + parser.add_argument("-c", "--coverage", action="store_true", help="Run tests with coverage report") - parser.add_argument( - "-p", "--parallel", - action="store_true", - help="Run tests in parallel (requires pytest-xdist)" - ) + parser.add_argument("-p", "--parallel", action="store_true", help="Run tests in parallel (requires pytest-xdist)") - parser.add_argument( - "-i", "--ignore", - action="store_true", - help="Run tests with '-W ignore::SyntaxWarning' " - ) + parser.add_argument("-i", "--ignore", action="store_true", help="Run tests with '-W ignore::SyntaxWarning' ") - parser.add_argument( - "-v", "--verbose", - action="store_true", - help="Verbose output" - ) + parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output") - parser.add_argument( - "-t", "--test", - type=str, - default="", - help="Run specific test file or directory" - ) + parser.add_argument("-t", "--test", type=str, default="", help="Run specific test file or directory") - parser.add_argument( - "-k", "--keyword", - type=str, - default="", - help="Run tests matching keyword expression (pytest -k)" - ) + parser.add_argument("-k", "--keyword", type=str, default="", help="Run tests matching keyword expression (pytest -k)") - parser.add_argument( - "-m", "--markers", - type=str, - default="", - help="Run tests with specific markers (e.g., 'unit', 'integration')" - ) + parser.add_argument("-m", "--markers", type=str, default="", help="Run tests with specific markers (e.g., 'unit', 'integration')") try: args = parser.parse_args() @@ -311,4 +278,4 @@ def main(): if __name__ == "__main__": - sys.exit(main()) \ No newline at end of file + sys.exit(main()) diff --git a/sdk/python/ragflow_sdk/__init__.py b/sdk/python/ragflow_sdk/__init__.py index 62ddff7160..0bd89e8ee3 100644 --- a/sdk/python/ragflow_sdk/__init__.py +++ b/sdk/python/ragflow_sdk/__init__.py @@ -15,6 +15,7 @@ # from beartype.claw import beartype_this_package + beartype_this_package() import importlib.metadata @@ -30,13 +31,4 @@ from .modules.memory import Memory __version__ = importlib.metadata.version("ragflow_sdk") -__all__ = [ - "RAGFlow", - "DataSet", - "Chat", - "Session", - "Document", - "Chunk", - "Agent", - "Memory" -] +__all__ = ["RAGFlow", "DataSet", "Chat", "Session", "Document", "Chunk", "Agent", "Memory"] diff --git a/sdk/python/ragflow_sdk/modules/__init__.py b/sdk/python/ragflow_sdk/modules/__init__.py index e156bc93dd..177b91dd05 100644 --- a/sdk/python/ragflow_sdk/modules/__init__.py +++ b/sdk/python/ragflow_sdk/modules/__init__.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# \ No newline at end of file +# diff --git a/sdk/python/ragflow_sdk/modules/agent.py b/sdk/python/ragflow_sdk/modules/agent.py index 5e67f40d9e..41ad34989a 100644 --- a/sdk/python/ragflow_sdk/modules/agent.py +++ b/sdk/python/ragflow_sdk/modules/agent.py @@ -20,7 +20,7 @@ from .session import Session class Agent(Base): def __init__(self, rag, res_dict): - self.id = None + self.id = None self.avatar = None self.canvas_type = None self.description = None @@ -30,42 +30,17 @@ class Agent(Base): class Dsl(Base): def __init__(self, rag, res_dict): self.answer = [] - self.components = { - "begin": { - "downstream": ["Answer:China"], - "obj": { - "component_name": "Begin", - "params": {} - }, - "upstream": [] - } - } + self.components = {"begin": {"downstream": ["Answer:China"], "obj": {"component_name": "Begin", "params": {}}, "upstream": []}} self.graph = { "edges": [], - "nodes": [ - { - "data": { - "label": "Begin", - "name": "begin" - }, - "id": "begin", - "position": { - "x": 50, - "y": 200 - }, - "sourcePosition": "left", - "targetPosition": "right", - "type": "beginNode" - } - ] + "nodes": [{"data": {"label": "Begin", "name": "begin"}, "id": "begin", "position": {"x": 50, "y": 200}, "sourcePosition": "left", "targetPosition": "right", "type": "beginNode"}], } - self.history = [] - self.messages = [] - self.path = [] + self.history = [] + self.messages = [] + self.path = [] self.reference = [] super().__init__(rag, res_dict) - def create_session(self, **kwargs) -> Session: res = self.post(f"/agents/{self.id}/sessions", json=kwargs) res = res.json() @@ -73,11 +48,8 @@ class Agent(Base): return Session(self.rag, res.get("data")) raise Exception(res.get("message")) - - def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, - id: str = None) -> list[Session]: - res = self.get(f"/agents/{self.id}/sessions", - {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id}) + def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: str = None) -> list[Session]: + res = self.get(f"/agents/{self.id}/sessions", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id}) res = res.json() if res.get("code") == 0: result_list = [] @@ -86,7 +58,7 @@ class Agent(Base): result_list.append(temp_agent) return result_list raise Exception(res.get("message")) - + def delete_sessions(self, ids: list[str] | None = None, delete_all: bool = False): payload = {"ids": ids} if delete_all: diff --git a/sdk/python/ragflow_sdk/modules/chunk.py b/sdk/python/ragflow_sdk/modules/chunk.py index f6d1da09a3..49092931d8 100644 --- a/sdk/python/ragflow_sdk/modules/chunk.py +++ b/sdk/python/ragflow_sdk/modules/chunk.py @@ -16,6 +16,7 @@ from .base import Base + class ChunkUpdateError(Exception): def __init__(self, code=None, message=None, details=None): self.code = code @@ -23,6 +24,7 @@ class ChunkUpdateError(Exception): self.details = details super().__init__(message) + class Chunk(Base): def __init__(self, rag, res_dict): self.id = "" @@ -48,17 +50,12 @@ class Chunk(Base): res_dict.pop(k) super().__init__(rag, res_dict) - #for backward compatibility + # for backward compatibility if not self.document_name: self.document_name = self.document_keyword - def update(self, update_message: dict): res = self.patch(f"/datasets/{self.dataset_id}/documents/{self.document_id}/chunks/{self.id}", update_message) res = res.json() if res.get("code") != 0: - raise ChunkUpdateError( - code=res.get("code"), - message=res.get("message"), - details=res.get("details") - ) + raise ChunkUpdateError(code=res.get("code"), message=res.get("message"), details=res.get("details")) diff --git a/sdk/python/ragflow_sdk/modules/dataset.py b/sdk/python/ragflow_sdk/modules/dataset.py index de520f3fe4..2de1c0c9c5 100644 --- a/sdk/python/ragflow_sdk/modules/dataset.py +++ b/sdk/python/ragflow_sdk/modules/dataset.py @@ -113,18 +113,21 @@ class DataSet(Base): def _get_documents_status(self, document_ids): import time + terminal_states = {"DONE", "FAIL", "CANCEL"} interval_sec = 1 pending = set(document_ids) finished = [] while pending: for doc_id in list(pending): + def fetch_doc(doc_id: str) -> Document | None: try: docs = self.list_documents(id=doc_id) return docs[0] if docs else None except Exception: return None + doc = fetch_doc(doc_id) if doc is None: continue @@ -137,13 +140,12 @@ class DataSet(Base): if pending: time.sleep(interval_sec) return finished - + def async_parse_documents(self, document_ids): res = self.post(f"/datasets/{self.id}/chunks", {"document_ids": document_ids}) res = res.json() if res.get("code") != 0: raise Exception(res.get("message")) - def parse_documents(self, document_ids): try: @@ -151,9 +153,8 @@ class DataSet(Base): self._get_documents_status(document_ids) except KeyboardInterrupt: self.async_cancel_parse_documents(document_ids) - - return self._get_documents_status(document_ids) + return self._get_documents_status(document_ids) def async_cancel_parse_documents(self, document_ids): res = self.rm(f"/datasets/{self.id}/chunks", {"document_ids": document_ids}) diff --git a/sdk/python/ragflow_sdk/modules/memory.py b/sdk/python/ragflow_sdk/modules/memory.py index 4005deeac3..09b057fd2f 100644 --- a/sdk/python/ragflow_sdk/modules/memory.py +++ b/sdk/python/ragflow_sdk/modules/memory.py @@ -18,7 +18,6 @@ from .base import Base class Memory(Base): - def __init__(self, rag, res_dict): self.id = "" self.name = "" @@ -33,7 +32,7 @@ class Memory(Base): self.description = "" self.memory_size = 5 * 1024 * 1024 self.forgetting_policy = "FIFO" - self.temperature = 0.5, + self.temperature = (0.5,) self.system_prompt = "" self.user_prompt = "" for k in list(res_dict.keys()): @@ -57,13 +56,8 @@ class Memory(Base): self._update_from_dict(self.rag, res.get("data", {})) return self - def list_memory_messages(self, agent_id: str | list[str]=None, keywords: str=None, page: int=1, page_size: int=50): - params = { - "agent_id": agent_id, - "keywords": keywords, - "page": page, - "page_size": page_size - } + def list_memory_messages(self, agent_id: str | list[str] = None, keywords: str = None, page: int = 1, page_size: int = 50): + params = {"agent_id": agent_id, "keywords": keywords, "page": page, "page_size": page_size} res = self.get(f"/memories/{self.id}", params) res = res.json() if res.get("code") != 0: @@ -78,9 +72,7 @@ class Memory(Base): return True def update_message_status(self, message_id: int, status: bool): - update_message = { - "status": status - } + update_message = {"status": status} res = self.put(f"/messages/{self.id}:{message_id}", update_message) res = res.json() if res.get("code") != 0: diff --git a/sdk/python/ragflow_sdk/modules/session.py b/sdk/python/ragflow_sdk/modules/session.py index 5152160f6a..9cbd5d9c34 100644 --- a/sdk/python/ragflow_sdk/modules/session.py +++ b/sdk/python/ragflow_sdk/modules/session.py @@ -36,7 +36,6 @@ class Session(Base): self.__session_type = "agent" super().__init__(rag, res_dict) - def ask( self, question="", @@ -89,8 +88,7 @@ class Session(Base): if inputs is not None or release is not None or return_trace is not None: logger.debug( - "Session.ask explicit-params session_type=%s session_id=%s " - "input_keys=%s release=%s return_trace=%s", + "Session.ask explicit-params session_type=%s session_id=%s input_keys=%s release=%s return_trace=%s", self.__session_type, getattr(self, "id", None), list(inputs.keys()) if isinstance(inputs, dict) else None, @@ -111,7 +109,7 @@ class Session(Base): continue # Skip empty lines line = line.strip() if line.startswith("data:"): - content = line[len("data:"):].strip() + content = line[len("data:") :].strip() if content == "[DONE]": break # End of stream else: @@ -122,14 +120,11 @@ class Session(Base): except json.JSONDecodeError: continue # Skip lines that are not valid JSON - event = json_data.get("event",None) + event = json_data.get("event", None) if event and event != "message": continue - if ( - (self.__session_type == "agent" and event == "message_end") - or (self.__session_type == "chat" and json_data.get("data") is True) - ): + if (self.__session_type == "agent" and event == "message_end") or (self.__session_type == "chat" and json_data.get("data") is True): return if self.__session_type == "agent": yield self._structure_answer(json_data) @@ -141,7 +136,6 @@ class Session(Base): except ValueError: raise Exception(f"Invalid response {res}") yield self._structure_answer(json_data["data"]) - def _structure_answer(self, json_data): answer = "" @@ -150,10 +144,7 @@ class Session(Base): elif self.__session_type == "chat": answer = json_data["answer"] reference = json_data.get("reference", {}) - temp_dict = { - "content": answer, - "role": "assistant" - } + temp_dict = {"content": answer, "role": "assistant"} if reference and "chunks" in reference: chunks = reference["chunks"] temp_dict["reference"] = chunks @@ -163,8 +154,7 @@ class Session(Base): def _ask_chat(self, question: str, stream: bool, **kwargs): json_data = {"question": question, "stream": stream, "session_id": self.id} json_data.update(kwargs) - res = self.post(f"/chats/{self.chat_id}/completions", - json_data, stream=stream) + res = self.post(f"/chats/{self.chat_id}/completions", json_data, stream=stream) return res def _ask_agent(self, question: str, stream: bool, **kwargs): @@ -180,8 +170,7 @@ class Session(Base): return res def update(self, update_message): - res = self.patch(f"/chats/{self.chat_id}/sessions/{self.id}", - update_message) + res = self.patch(f"/chats/{self.chat_id}/sessions/{self.id}", update_message) res = res.json() if res.get("code") != 0: raise Exception(res.get("message")) diff --git a/sdk/python/ragflow_sdk/ragflow.py b/sdk/python/ragflow_sdk/ragflow.py index 5228c18c7f..a339fe1d3f 100644 --- a/sdk/python/ragflow_sdk/ragflow.py +++ b/sdk/python/ragflow_sdk/ragflow.py @@ -196,7 +196,7 @@ class RAGFlow: top_k=1024, rerank_id: str | None = None, keyword: bool = False, - cross_languages: list[str]|None = None, + cross_languages: list[str] | None = None, metadata_condition: dict | None = None, use_kg: bool = False, toc_enhance: bool = False, @@ -217,7 +217,7 @@ class RAGFlow: "cross_languages": cross_languages, "metadata_condition": metadata_condition, "use_kg": use_kg, - "toc_enhance": toc_enhance + "toc_enhance": toc_enhance, } # Send a POST request to the backend service (using requests library as an example, actual implementation may vary) res = self.post("/retrieval", json=data_json) @@ -331,7 +331,7 @@ class RAGFlow: "memory_type": memory_type, "storage_type": storage_type, "keywords": keywords, - } + }, ) res = res.json() if res.get("code") != 0: @@ -339,12 +339,7 @@ class RAGFlow: result_list = [] for data in res["data"]["memory_list"]: result_list.append(Memory(self, data)) - return { - "code": res.get("code", 0), - "message": res.get("message"), - "memory_list": result_list, - "total_count": res["data"]["total_count"] - } + return {"code": res.get("code", 0), "message": res.get("message"), "memory_list": result_list, "total_count": res["data"]["total_count"]} def delete_memory(self, memory_id: str): res = self.delete(f"/memories/{memory_id}", {}) @@ -354,21 +349,24 @@ class RAGFlow: def add_message(self, memory_id: list[str], agent_id: str, session_id: str, user_input: str, agent_response: str, user_id: str = "") -> str: """Append messages to memories; ``user_id`` is forwarded only for API-key auth (external subject).""" - payload = { - "memory_id": memory_id, - "agent_id": agent_id, - "session_id": session_id, - "user_input": user_input, - "agent_response": agent_response, - "user_id": user_id - } + payload = {"memory_id": memory_id, "agent_id": agent_id, "session_id": session_id, "user_input": user_input, "agent_response": agent_response, "user_id": user_id} res = self.post("/messages", payload) res = res.json() if res.get("code") != 0: raise Exception(res["message"]) return res["message"] - def search_message(self, query: str, memory_id: list[str], agent_id: str=None, session_id: str=None, user_id: str=None, similarity_threshold: float=0.2, keywords_similarity_weight: float=0.7, top_n: int=10) -> list[dict]: + def search_message( + self, + query: str, + memory_id: list[str], + agent_id: str = None, + session_id: str = None, + user_id: str = None, + similarity_threshold: float = 0.2, + keywords_similarity_weight: float = 0.7, + top_n: int = 10, + ) -> list[dict]: params = { "query": query, "memory_id": memory_id, @@ -377,7 +375,7 @@ class RAGFlow: "user_id": user_id, "similarity_threshold": similarity_threshold, "keywords_similarity_weight": keywords_similarity_weight, - "top_n": top_n + "top_n": top_n, } res = self.get("/messages/search", params) res = res.json() @@ -385,13 +383,8 @@ class RAGFlow: raise Exception(res["message"]) return res["data"] - def get_recent_messages(self, memory_id: list[str], agent_id: str=None, session_id: str=None, limit: int=10) -> list[dict]: - params = { - "memory_id": memory_id, - "agent_id": agent_id, - "session_id": session_id, - "limit": limit - } + def get_recent_messages(self, memory_id: list[str], agent_id: str = None, session_id: str = None, limit: int = 10) -> list[dict]: + params = {"memory_id": memory_id, "agent_id": agent_id, "session_id": session_id, "limit": limit} res = self.get("/messages", params) res = res.json() if res.get("code") != 0: diff --git a/sdk/python/test.py b/sdk/python/test.py index c670033194..8cc0ae0394 100644 --- a/sdk/python/test.py +++ b/sdk/python/test.py @@ -2,7 +2,7 @@ from .ragflow_sdk import RAGFlow rag_object = RAGFlow(api_key="ragflow-FDfRECsXDRagsKPxb_EfZdDPcmngavSgYEzbU_Blgq4", base_url="http://localhost:9222") assistant = rag_object.get_agent("b0bc46e43dfc11f1b4ff84ba59bc54d9") -session = assistant.create_session() +session = assistant.create_session() print("\n==================== Miss R =====================\n") print("Hello. What can I do for you?") @@ -10,8 +10,8 @@ print("Hello. What can I do for you?") while True: question = input("\n==================== User =====================\n> ") print("\n==================== Miss R =====================\n") - + cont = "" for ans in session.ask(question, stream=True): - print(ans.content[len(cont):], end='', flush=True) + print(ans.content[len(cont) :], end="", flush=True) cont = ans.content diff --git a/sdk/python/test/conftest.py b/sdk/python/test/conftest.py index ef6dd44338..200b86ab90 100644 --- a/sdk/python/test/conftest.py +++ b/sdk/python/test/conftest.py @@ -113,12 +113,7 @@ def add_model_instance(auth): pytest.exit(f"Critical error in add model provider: {add_provider_res.get('message')}") add_instance_api = HOST_ADDRESS + "/api/v1/providers/ZHIPU-AI/instances" - add_instance_response = requests.post(url=add_instance_api, headers=authorization, json={ - "instance_name": "CI", - "api_key": ZHIPU_AI_API_KEY, - "region": "default", - "base_url": "" - }) + add_instance_response = requests.post(url=add_instance_api, headers=authorization, json={"instance_name": "CI", "api_key": ZHIPU_AI_API_KEY, "region": "default", "base_url": ""}) add_instance_res = add_instance_response.json() if add_instance_res.get("code") != 0: pytest.exit(f"Critical error in add model instance: {add_instance_res.get('message')}") @@ -139,41 +134,19 @@ def set_tenant_info(get_auth): url = HOST_ADDRESS + "/api/v1/models/default" authorization = {"Authorization": get_auth} # set chat model - set_default_llm_response = requests.patch( - url=url, - headers=authorization, - json={ - "model_provider": "ZHIPU-AI", - "model_instance": "CI", - "model_type": "chat", - "model_name": "glm-4-flash" - }) + set_default_llm_response = requests.patch(url=url, headers=authorization, json={"model_provider": "ZHIPU-AI", "model_instance": "CI", "model_type": "chat", "model_name": "glm-4-flash"}) llm_res = set_default_llm_response.json() if llm_res.get("code") != 0: raise Exception(llm_res.get("message")) # set embedding model set_default_embedding_response = requests.patch( - url=url, - headers=authorization, - json={ - "model_provider": "Builtin", - "model_instance": "Local", - "model_type": "embedding", - "model_name": "BAAI/bge-small-en-v1.5" - }) + url=url, headers=authorization, json={"model_provider": "Builtin", "model_instance": "Local", "model_type": "embedding", "model_name": "BAAI/bge-small-en-v1.5"} + ) embd_res = set_default_embedding_response.json() if embd_res.get("code") != 0: raise Exception(embd_res.get("message")) # set image to text model - set_default_img2txt_response = requests.patch( - url=url, - headers=authorization, - json={ - "model_provider": "ZHIPU-AI", - "model_instance": "CI", - "model_type": "vision", - "model_name": "glm-4v" - }) + set_default_img2txt_response = requests.patch(url=url, headers=authorization, json={"model_provider": "ZHIPU-AI", "model_instance": "CI", "model_type": "vision", "model_name": "glm-4v"}) img2txt_res = set_default_img2txt_response.json() if img2txt_res.get("code") != 0: raise Exception(img2txt_res.get("message")) diff --git a/sdk/python/test/test_frontend_api/common.py b/sdk/python/test/test_frontend_api/common.py index aafe64a591..646c925b7c 100644 --- a/sdk/python/test/test_frontend_api/common.py +++ b/sdk/python/test/test_frontend_api/common.py @@ -73,20 +73,20 @@ def list_document(auth, dataset_id): def get_docs_info(auth, dataset_id, doc_ids=None, doc_id=None): """ Get document information by IDs. - + Args: auth: Authorization header dataset_id: Dataset ID doc_ids: List of document IDs (use for multiple) - exclusive with doc_id doc_id: Single document ID (use for one) - exclusive with doc_ids - + Raises: ValueError: If both doc_id and doc_ids are provided """ # Validate that id and ids are not used together if doc_id and doc_ids: raise ValueError("Cannot use both 'id' and 'ids' parameters at the same time.") - + authorization = {"Authorization": auth} params = {} if doc_ids: @@ -96,7 +96,7 @@ def get_docs_info(auth, dataset_id, doc_ids=None, doc_id=None): elif doc_id: # Single ID params["id"] = doc_id - + # Use /api/v1 prefix for dataset API url = f"{HOST_ADDRESS}/api/v1/datasets/{dataset_id}/documents" res = requests.get(url=url, headers=authorization, params=params) @@ -113,4 +113,3 @@ def parse_docs(auth, doc_ids): def parse_file(auth, document_id): pass - diff --git a/sdk/python/test/test_frontend_api/get_email.py b/sdk/python/test/test_frontend_api/get_email.py index 181094f5ad..3d0f04c1e1 100644 --- a/sdk/python/test/test_frontend_api/get_email.py +++ b/sdk/python/test/test_frontend_api/get_email.py @@ -14,6 +14,7 @@ # limitations under the License. # + def test_get_email(get_email): - print("\nEmail account:",flush=True) - print(f"{get_email}\n",flush=True) \ No newline at end of file + print("\nEmail account:", flush=True) + print(f"{get_email}\n", flush=True) diff --git a/sdk/python/test/test_frontend_api/test_chunk.py b/sdk/python/test/test_frontend_api/test_chunk.py index b1f7ff1bd1..0bdd4b4b1a 100644 --- a/sdk/python/test/test_frontend_api/test_chunk.py +++ b/sdk/python/test/test_frontend_api/test_chunk.py @@ -40,15 +40,15 @@ def test_parse_txt_document(get_auth): break page_number += 1 - filename = 'ragflow_test.txt' + filename = "ragflow_test.txt" res = upload_file(get_auth, dataset_id, f"../test_sdk_api/test_data/{filename}") assert res.get("code") == 0, f"{res.get('message')}" res = list_document(get_auth, dataset_id) doc_id_list = [] - for doc in res['data']['docs']: - doc_id_list.append(doc['id']) + for doc in res["data"]["docs"]: + doc_id_list.append(doc["id"]) res = get_docs_info(get_auth, dataset_id, doc_ids=doc_id_list) print(doc_id_list) @@ -59,13 +59,13 @@ def test_parse_txt_document(get_auth): while True: res = get_docs_info(get_auth, dataset_id, doc_ids=doc_id_list) finished_count = 0 - for doc_info in res['data']: - if doc_info['progress'] == 1: + for doc_info in res["data"]: + if doc_info["progress"] == 1: finished_count += 1 if finished_count == doc_count: break sleep(1) - print('time cost {:.1f}s'.format(timer() - start_ts)) + print("time cost {:.1f}s".format(timer() - start_ts)) # delete dataset if dataset_list: diff --git a/sdk/python/test/test_frontend_api/test_dataset.py b/sdk/python/test/test_frontend_api/test_dataset.py index bfbc02da2d..61674fb877 100644 --- a/sdk/python/test/test_frontend_api/test_dataset.py +++ b/sdk/python/test/test_frontend_api/test_dataset.py @@ -89,7 +89,7 @@ def test_duplicated_name_dataset(get_auth): if isinstance(data, dict): data = data.get("kbs", []) dataset_list = [] - pattern = r'^test_create_dataset.*' + pattern = r"^test_create_dataset.*" for item in data: dataset_name = item.get("name") dataset_id = item.get("id") @@ -106,10 +106,10 @@ def test_duplicated_name_dataset(get_auth): def test_invalid_name_dataset(get_auth): # create dataset res = create_dataset(get_auth, {"name": 0}) - assert res['code'] != 0 + assert res["code"] != 0 res = create_dataset(get_auth, {"name": ""}) - assert res['code'] != 0 + assert res["code"] != 0 long_string = "" @@ -117,7 +117,7 @@ def test_invalid_name_dataset(get_auth): long_string += random.choice(string.ascii_letters + string.digits) res = create_dataset(get_auth, {"name": long_string}) - assert res['code'] != 0 + assert res["code"] != 0 print(res) @@ -144,13 +144,17 @@ def test_update_different_params_dataset_success(get_auth): print(f"found {len(dataset_list)} datasets") dataset_id = dataset_list[0] - res = update_dataset(get_auth, dataset_id, { - "name": "test_update_dataset", - "description": "test", - "permission": "me", - "chunk_method": "presentation", - "language": "spanish", - }) + res = update_dataset( + get_auth, + dataset_id, + { + "name": "test_update_dataset", + "description": "test", + "permission": "me", + "chunk_method": "presentation", + "language": "spanish", + }, + ) assert res.get("code") == 0, f"{res.get('message')}" # delete dataset diff --git a/test.py b/test.py index 21f395a467..13710b425f 100644 --- a/test.py +++ b/test.py @@ -1,9 +1,15 @@ from fastapi import FastAPI, Request + app = FastAPI() + + @app.post("/") async def echo(request: Request): body = await request.body() return body + + if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file + + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/test/benchmark/auth.py b/test/benchmark/auth.py index 407ef59b9e..1bbb4671e7 100644 --- a/test/benchmark/auth.py +++ b/test/benchmark/auth.py @@ -11,11 +11,10 @@ def encrypt_password(password_plain: str) -> str: try: from api.utils.crypt import crypt except Exception as exc: - raise AuthError( - "Password encryption unavailable; install pycryptodomex (uv sync --python 3.13 --group test)." - ) from exc + raise AuthError("Password encryption unavailable; install pycryptodomex (uv sync --python 3.13 --group test).") from exc return crypt(password_plain) + def register_user(client: HttpClient, email: str, nickname: str, password_enc: str) -> None: payload = {"email": email, "nickname": nickname, "password": password_enc} res = client.request_json("POST", "/users", use_api_base=True, auth_kind=None, json_body=payload) @@ -82,8 +81,7 @@ def set_llm_api_key( "region": "default", "base_url": base_url or "", } - instance_res = client.request_json("POST", f"/providers/{llm_factory}/instances", use_api_base=True, - auth_kind="login", json_body=instance_payload) + instance_res = client.request_json("POST", f"/providers/{llm_factory}/instances", use_api_base=True, auth_kind="login", json_body=instance_payload) instance_msg = instance_res.get("message", "") if instance_res.get("code") != 0 and "already exist" not in instance_msg.lower(): raise AuthError(f"Failed to add instance: {instance_msg}") diff --git a/test/benchmark/report.py b/test/benchmark/report.py index 64008deb26..12a16cf18a 100644 --- a/test/benchmark/report.py +++ b/test/benchmark/report.py @@ -93,9 +93,7 @@ def retrieval_report( lines.append(f"{key}: {value}") lines.extend( [ - "Latency: " - f"avg={_fmt_ms(stats['avg'])}, min={_fmt_ms(stats['min'])}, " - f"p50={_fmt_ms(stats['p50'])}, p90={_fmt_ms(stats['p90'])}, p95={_fmt_ms(stats['p95'])}", + f"Latency: avg={_fmt_ms(stats['avg'])}, min={_fmt_ms(stats['min'])}, p50={_fmt_ms(stats['p50'])}, p90={_fmt_ms(stats['p90'])}, p95={_fmt_ms(stats['p95'])}", f"Total Duration: {_fmt_seconds(total_duration_s)}", f"QPS (requests / total duration): {_fmt_qps(_calc_qps(total_duration_s, iterations))}", ] diff --git a/test/benchmark/utils.py b/test/benchmark/utils.py index d46641344a..2bf5091d7a 100644 --- a/test/benchmark/utils.py +++ b/test/benchmark/utils.py @@ -38,4 +38,3 @@ def split_csv(value): def unique_name(prefix): return f"{prefix}_{int(time.time() * 1000)}" - diff --git a/test/playwright/auth/test_login_success_optional.py b/test/playwright/auth/test_login_success_optional.py index e7fc29fbf5..2360d8ff3c 100644 --- a/test/playwright/auth/test_login_success_optional.py +++ b/test/playwright/auth/test_login_success_optional.py @@ -74,10 +74,7 @@ def step_01_open_login( lowered = seeded_email.lower() example_domain = "infiniflow.io" if lowered.endswith(f"@{example_domain}"): - raise AssertionError( - "SEEDED_USER_EMAIL must be a real account (not *@example.com). " - "Set valid credentials or use DEMO_CREDS=1 for demo mode." - ) + raise AssertionError("SEEDED_USER_EMAIL must be a real account (not *@example.com). Set valid credentials or use DEMO_CREDS=1 for demo mode.") print(f"[AUTH] using email: {seeded_email} (source={source})", flush=True) flow_state["seeded_email"] = seeded_email flow_state["seeded_password"] = seeded_password @@ -162,9 +159,9 @@ def step_03_verify_login( return false; }} """.format( - post_login_path=post_login_path_js, - auth_status_selector=auth_status_selector, - ) + post_login_path=post_login_path_js, + auth_status_selector=auth_status_selector, + ) with step("wait for success or error"): try: @@ -175,9 +172,7 @@ def step_03_verify_login( except PlaywrightTimeoutError as exc: snap("failure") _debug_login_state(page, "wait_for_outcome_timeout") - raise AssertionError( - f"Login result did not resolve in time. url={page.url}" - ) from exc + raise AssertionError(f"Login result did not resolve in time. url={page.url}") from exc with step("verify authenticated UI marker"): outcome = result.json_value() @@ -185,18 +180,13 @@ def step_03_verify_login( snap("error") snap("failure") _debug_login_state(page, "login_error") - raise AssertionError( - "Login error detected. " - f"url={page.url}" - ) + raise AssertionError(f"Login error detected. url={page.url}") path = urlparse(page.url).path if post_login_path: if not path.startswith(post_login_path): snap("failure") _debug_login_state(page, "post_login_path_mismatch") - raise AssertionError( - f"Post-login path mismatch. expected_prefix={post_login_path} url={page.url}" - ) + raise AssertionError(f"Post-login path mismatch. expected_prefix={post_login_path} url={page.url}") elif "/login" in path: snap("failure") _debug_login_state(page, "still_on_login_path") @@ -209,9 +199,7 @@ def step_03_verify_login( except AssertionError as exc: snap("failure") _debug_login_state(page, "login_form_still_visible") - raise AssertionError( - f"Login form still visible after login. url={page.url}" - ) from exc + raise AssertionError(f"Login form still visible after login. url={page.url}") from exc snap("success") diff --git a/test/playwright/auth/test_register_success_optional.py b/test/playwright/auth/test_register_success_optional.py index 1b9cc4184a..650c1ae5ea 100644 --- a/test/playwright/auth/test_register_success_optional.py +++ b/test/playwright/auth/test_register_success_optional.py @@ -26,9 +26,7 @@ def _debug_register_response(page, response_info: dict) -> None: if isinstance(message, str) and len(message) > 300: message = message[:300] print( - "[auth-debug] register_response " - f"url={response_info.get('__url__')} status={response_info.get('__status__')} " - f"code={response_info.get('code')} message={message}", + f"[auth-debug] register_response url={response_info.get('__url__')} status={response_info.get('__status__')} code={response_info.get('code')} message={message}", flush=True, ) try: @@ -54,7 +52,8 @@ def _wait_for_auth_not_loading(page, timeout_ms: int = 5000) -> None: if (!status) return true; return status.getAttribute('data-state') !== 'loading'; } - """ % auth_status_selector, + """ + % auth_status_selector, timeout=timeout_ms, ) @@ -166,15 +165,12 @@ def step_03_submit_registration( ), snap("retry_submitted" if retried else "submitted"), ), - lambda resp: resp.request.method == "POST" - and "/api/v1/users" in resp.url, + lambda resp: resp.request.method == "POST" and "/api/v1/users" in resp.url, timeout_ms=RESULT_TIMEOUT_MS, ) except PlaywrightTimeoutError as exc: snap("failure") - raise AssertionError( - f"Register response not received in time. url={page.url} email={current_email}" - ) from exc + raise AssertionError(f"Register response not received in time. url={page.url} email={current_email}") from exc _debug_register_response(page, response_info) diff --git a/test/playwright/auth/test_register_then_login_flow.py b/test/playwright/auth/test_register_then_login_flow.py index 5c4fce040e..a01c8a6cfc 100644 --- a/test/playwright/auth/test_register_then_login_flow.py +++ b/test/playwright/auth/test_register_then_login_flow.py @@ -27,9 +27,7 @@ def _debug_register_response(page, response_info: dict) -> None: if isinstance(message, str) and len(message) > 300: message = message[:300] print( - "[auth-debug] register_response " - f"url={response_info.get('__url__')} status={response_info.get('__status__')} " - f"code={response_info.get('code')} message={message}", + f"[auth-debug] register_response url={response_info.get('__url__')} status={response_info.get('__status__')} code={response_info.get('code')} message={message}", flush=True, ) try: @@ -41,9 +39,7 @@ def _debug_register_response(page, response_info: dict) -> None: print(f"[auth-debug] sonner_toast_dump_failed: {exc}", flush=True) -def _wait_for_login_outcome( - page, post_login_path: str | None, timeout_ms: int = RESULT_TIMEOUT_MS -): +def _wait_for_login_outcome(page, post_login_path: str | None, timeout_ms: int = RESULT_TIMEOUT_MS): auth_status_selector = json.dumps(AUTH_STATUS) return page.wait_for_function( """ @@ -72,7 +68,8 @@ def _wait_for_login_outcome( if (successByUrl || successMarker) return { state: 'success' }; return false; } - """ % auth_status_selector, + """ + % auth_status_selector, post_login_path, timeout=timeout_ms, ) @@ -171,15 +168,12 @@ def step_03_register_user( auth_click(submit_button, "submit_register"), snap("register_submitted"), ), - lambda resp: resp.request.method == "POST" - and "/api/v1/users" in resp.url, + lambda resp: resp.request.method == "POST" and "/api/v1/users" in resp.url, timeout_ms=RESULT_TIMEOUT_MS, ) except PlaywrightTimeoutError as exc: snap("register_failure") - raise AssertionError( - f"Register response not received in time. url={page.url}" - ) from exc + raise AssertionError(f"Register response not received in time. url={page.url}") from exc _debug_register_response(page, response_info) @@ -187,9 +181,7 @@ def step_03_register_user( snap("register_error_response") snap("register_failure") raise AssertionError( - "Registration error detected. " - f"url={response_info.get('__url__')} status={response_info.get('__status__')} " - f"code={response_info.get('code')} message={response_info.get('message')}" + f"Registration error detected. url={response_info.get('__url__')} status={response_info.get('__status__')} code={response_info.get('code')} message={response_info.get('message')}" ) snap("register_success_response") @@ -258,9 +250,7 @@ def step_05_verify_login( login_result = _wait_for_login_outcome(page, post_login_path) except PlaywrightTimeoutError as exc: snap("login_failure") - raise AssertionError( - f"Login result did not resolve in time. url={page.url}" - ) from exc + raise AssertionError(f"Login result did not resolve in time. url={page.url}") from exc login_outcome = login_result.json_value() if login_outcome.get("state") == "error": @@ -272,9 +262,7 @@ def step_05_verify_login( if post_login_path: if not path.startswith(post_login_path): snap("login_failure") - raise AssertionError( - f"Post-login path mismatch. expected_prefix={post_login_path} url={page.url}" - ) + raise AssertionError(f"Post-login path mismatch. expected_prefix={post_login_path} url={page.url}") elif "/login" in path: snap("login_failure") raise AssertionError(f"URL still on login after submit. url={page.url}") diff --git a/test/playwright/auth/test_smoke_auth_page.py b/test/playwright/auth/test_smoke_auth_page.py index e66e81de63..7b2158ef86 100644 --- a/test/playwright/auth/test_smoke_auth_page.py +++ b/test/playwright/auth/test_smoke_auth_page.py @@ -37,9 +37,7 @@ def step_02_validate_page(ctx: FlowContext, step, snap): input_count = page.locator("input").count() logo_count = page.locator("img[alt='logo']").count() if root_count + input_count + logo_count == 0: - raise AssertionError( - _format_diag(page, response, "No SPA root, inputs, or logo found") - ) + raise AssertionError(_format_diag(page, response, "No SPA root, inputs, or logo found")) STEPS = [ @@ -52,9 +50,7 @@ STEPS = [ @pytest.mark.p0 @pytest.mark.auth @pytest.mark.parametrize("step_fn", flow_params(STEPS)) -def test_auth_page_smoke_flow( - step_fn, flow_page, flow_state, base_url, smoke_login_url, step, snap -): +def test_auth_page_smoke_flow(step_fn, flow_page, flow_state, base_url, smoke_login_url, step, snap): ctx = FlowContext( page=flow_page, state=flow_state, @@ -73,7 +69,4 @@ def _format_diag(page, response, reason: str) -> str: url = page.url title = page.title() snippet = page.content().strip().replace("\n", " ")[:500] - return ( - f"{reason}. url={url} title={title} status={status} " - f"content_type={content_type} snippet={snippet}" - ) + return f"{reason}. url={url} title={title} status={status} content_type={content_type} snippet={snippet}" diff --git a/test/playwright/auth/test_sso_optional.py b/test/playwright/auth/test_sso_optional.py index aae3c1c0fb..51c53b9fd9 100644 --- a/test/playwright/auth/test_sso_optional.py +++ b/test/playwright/auth/test_sso_optional.py @@ -44,7 +44,5 @@ STEPS = [ @pytest.mark.p1 @pytest.mark.auth @pytest.mark.parametrize("step_fn", flow_params(STEPS)) -def test_sso_optional_flow( - step_fn, flow_page, flow_state, login_url, active_auth_context, step, snap -): +def test_sso_optional_flow(step_fn, flow_page, flow_state, login_url, active_auth_context, step, snap): step_fn(flow_page, flow_state, login_url, active_auth_context, step, snap) diff --git a/test/playwright/auth/test_toggle_login_register.py b/test/playwright/auth/test_toggle_login_register.py index 1651db0a04..f6a37c8166 100644 --- a/test/playwright/auth/test_toggle_login_register.py +++ b/test/playwright/auth/test_toggle_login_register.py @@ -13,9 +13,7 @@ def step_01_open_login(flow_page, flow_state, login_url, active_auth_context, st snap("open") -def step_02_switch_to_register( - flow_page, flow_state, login_url, active_auth_context, step, snap -): +def step_02_switch_to_register(flow_page, flow_state, login_url, active_auth_context, step, snap): require(flow_state, "login_opened") form, card = active_auth_context() toggle_button = card.locator(REGISTER_TAB) @@ -29,9 +27,7 @@ def step_02_switch_to_register( snap("toggled_register") -def step_03_assert_register_visible( - flow_page, flow_state, login_url, active_auth_context, step, snap -): +def step_03_assert_register_visible(flow_page, flow_state, login_url, active_auth_context, step, snap): require(flow_state, "login_opened", "register_toggle_available") form, _ = active_auth_context() nickname_input = form.locator(NICKNAME_INPUT) @@ -40,9 +36,7 @@ def step_03_assert_register_visible( snap("register_visible") -def step_04_switch_back_to_login( - flow_page, flow_state, login_url, active_auth_context, step, snap -): +def step_04_switch_back_to_login(flow_page, flow_state, login_url, active_auth_context, step, snap): require(flow_state, "login_opened", "register_toggle_available") form, card = active_auth_context() toggle_back = card.locator(LOGIN_TAB) @@ -52,9 +46,7 @@ def step_04_switch_back_to_login( snap("toggled_login") -def step_05_assert_login_visible( - flow_page, flow_state, login_url, active_auth_context, step, snap -): +def step_05_assert_login_visible(flow_page, flow_state, login_url, active_auth_context, step, snap): require(flow_state, "login_opened", "login_toggled_back") form, _ = active_auth_context() nickname_input = form.locator(NICKNAME_INPUT) @@ -74,7 +66,5 @@ STEPS = [ @pytest.mark.p1 @pytest.mark.auth @pytest.mark.parametrize("step_fn", flow_params(STEPS)) -def test_toggle_login_register_flow( - step_fn, flow_page, flow_state, login_url, active_auth_context, step, snap -): +def test_toggle_login_register_flow(step_fn, flow_page, flow_state, login_url, active_auth_context, step, snap): step_fn(flow_page, flow_state, login_url, active_auth_context, step, snap) diff --git a/test/playwright/auth/test_validation_presence.py b/test/playwright/auth/test_validation_presence.py index 9671b12d20..3e6085425a 100644 --- a/test/playwright/auth/test_validation_presence.py +++ b/test/playwright/auth/test_validation_presence.py @@ -5,9 +5,7 @@ from test.playwright.helpers.auth_selectors import EMAIL_INPUT, SUBMIT_BUTTON from test.playwright.helpers.flow_steps import flow_params, require -def step_01_open_login( - flow_page, flow_state, login_url, active_auth_context, step, snap, auth_click -): +def step_01_open_login(flow_page, flow_state, login_url, active_auth_context, step, snap, auth_click): page = flow_page with step("open login page"): page.goto(login_url, wait_until="domcontentloaded") @@ -15,9 +13,7 @@ def step_01_open_login( snap("open") -def step_02_submit_empty( - flow_page, flow_state, login_url, active_auth_context, step, snap, auth_click -): +def step_02_submit_empty(flow_page, flow_state, login_url, active_auth_context, step, snap, auth_click): require(flow_state, "login_opened") form, _ = active_auth_context() expect(form.locator(EMAIL_INPUT)).to_have_count(1) @@ -30,9 +26,7 @@ def step_02_submit_empty( snap("submitted_empty") -def step_03_assert_validation( - flow_page, flow_state, login_url, active_auth_context, step, snap, auth_click -): +def step_03_assert_validation(flow_page, flow_state, login_url, active_auth_context, step, snap, auth_click): require(flow_state, "login_opened", "submitted_empty") form, _ = active_auth_context() invalid_inputs = form.locator("input[aria-invalid='true']") @@ -52,11 +46,7 @@ def step_03_assert_validation( except AssertionError: pass - raise AssertionError( - "No validation feedback detected after submitting an empty login form. " - "Expected aria-invalid inputs or visible error containers. " - "See artifacts for DOM evidence." - ) + raise AssertionError("No validation feedback detected after submitting an empty login form. Expected aria-invalid inputs or visible error containers. See artifacts for DOM evidence.") STEPS = [ @@ -69,7 +59,5 @@ STEPS = [ @pytest.mark.p1 @pytest.mark.auth @pytest.mark.parametrize("step_fn", flow_params(STEPS)) -def test_validation_presence_flow( - step_fn, flow_page, flow_state, login_url, active_auth_context, step, snap, auth_click -): +def test_validation_presence_flow(step_fn, flow_page, flow_state, login_url, active_auth_context, step, snap, auth_click): step_fn(flow_page, flow_state, login_url, active_auth_context, step, snap, auth_click) diff --git a/test/playwright/conftest.py b/test/playwright/conftest.py index d421064151..31f52719cb 100644 --- a/test/playwright/conftest.py +++ b/test/playwright/conftest.py @@ -1,5 +1,6 @@ import sys from pathlib import Path + _PW_DIR = Path(__file__).resolve().parent if str(_PW_DIR) not in sys.path: sys.path.insert(0, str(_PW_DIR)) @@ -37,15 +38,9 @@ REG_EMAIL_LOCAL_RE = re.compile(r"^[A-Za-z0-9_.-]+$") REG_EMAIL_BACKEND_RE = re.compile(r"^[\w\._-]{1,}@([\w_-]+\.)+[\w-]{2,}$") AUTH_FORM_SELECTOR = "form[data-testid='auth-form']" AUTH_ACTIVE_FORM_SELECTOR = "form[data-testid='auth-form'][data-active='true']" -AUTH_EMAIL_INPUT_SELECTOR = ( - "input[data-testid='auth-email'], [data-testid='auth-email'] input" -) -AUTH_PASSWORD_INPUT_SELECTOR = ( - "input[data-testid='auth-password'], [data-testid='auth-password'] input" -) -AUTH_SUBMIT_SELECTOR = ( - "button[data-testid='auth-submit'], [data-testid='auth-submit'] button, [data-testid='auth-submit']" -) +AUTH_EMAIL_INPUT_SELECTOR = "input[data-testid='auth-email'], [data-testid='auth-email'] input" +AUTH_PASSWORD_INPUT_SELECTOR = "input[data-testid='auth-password'], [data-testid='auth-password'] input" +AUTH_SUBMIT_SELECTOR = "button[data-testid='auth-submit'], [data-testid='auth-submit'] button, [data-testid='auth-submit']" _PUBLIC_KEY_CACHE = None _RSA_CIPHER_CACHE = None @@ -103,9 +98,7 @@ def _sanitize_timeout_ms(value: int | None, fallback: int | None) -> int | None: def _playwright_action_timeout_ms() -> int | None: - raw = _env_int_with_fallback( - "PLAYWRIGHT_ACTION_TIMEOUT_MS", "PW_TIMEOUT_MS", DEFAULT_TIMEOUT_MS - ) + raw = _env_int_with_fallback("PLAYWRIGHT_ACTION_TIMEOUT_MS", "PW_TIMEOUT_MS", DEFAULT_TIMEOUT_MS) return _sanitize_timeout_ms(raw, DEFAULT_TIMEOUT_MS) @@ -119,14 +112,10 @@ def _playwright_auth_ready_timeout_ms() -> int | None: def _playwright_hang_timeout_s() -> int: - raw = _env_int_with_fallback( - "PLAYWRIGHT_HANG_TIMEOUT_S", "HANG_TIMEOUT_S", DEFAULT_HANG_TIMEOUT_S - ) + raw = _env_int_with_fallback("PLAYWRIGHT_HANG_TIMEOUT_S", "HANG_TIMEOUT_S", DEFAULT_HANG_TIMEOUT_S) return raw if raw > 0 else 0 - - def _failure_text(req) -> str: failure = getattr(req, "failure", None) if callable(failure): @@ -316,9 +305,7 @@ def _api_request_json( parsed = json.loads(body.decode("utf-8")) except Exception: parsed = None - raise RuntimeError( - f"{method} {url} failed with HTTPError {exc.code}: {parsed or body!r}" - ) from exc + raise RuntimeError(f"{method} {url} failed with HTTPError {exc.code}: {parsed or body!r}") from exc except URLError as exc: raise RuntimeError(f"{method} {url} failed with URLError: {exc}") from exc @@ -392,10 +379,7 @@ def _extract_auth_header_from_page(page) -> str: """ ) if not token: - raise AssertionError( - "Missing Authorization/Token in localStorage after login. " - "Cannot provision prerequisites via API." - ) + raise AssertionError("Missing Authorization/Token in localStorage after login. Cannot provision prerequisites via API.") return str(token) @@ -406,10 +390,7 @@ def _rsa_encrypt_password(password: str) -> str: from Cryptodome.PublicKey import RSA from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 except Exception as exc: - raise RuntimeError( - "Cryptodome is required to encrypt passwords for API seeding. " - "Set RAGFLOW_SEEDING_MODE=ui to skip API seeding." - ) from exc + raise RuntimeError("Cryptodome is required to encrypt passwords for API seeding. Set RAGFLOW_SEEDING_MODE=ui to skip API seeding.") from exc if _PUBLIC_KEY_CACHE is None: public_key_path = ROOT_DIR / "conf" / "public.pem" if not public_key_path.exists(): @@ -513,9 +494,7 @@ def _wait_for_auth_success(page, card, form) -> None: status_marker = page.locator("[data-testid='auth-status']") if status_marker.count() > 0: try: - expect(status_marker).to_have_attribute( - "data-state", "success", timeout=timeout_ms - ) + expect(status_marker).to_have_attribute("data-state", "success", timeout=timeout_ms) return except AssertionError: pass @@ -528,13 +507,9 @@ def _wait_for_auth_success(page, card, form) -> None: except PlaywrightTimeoutError: pass try: - expect(card.locator("[data-testid='auth-nickname']")).to_have_count( - 0, timeout=timeout_ms - ) + expect(card.locator("[data-testid='auth-nickname']")).to_have_count(0, timeout=timeout_ms) except AssertionError as exc: - raise RuntimeError( - "Auth success marker not detected after registration." - ) from exc + raise RuntimeError("Auth success marker not detected after registration.") from exc def _ui_register_user( @@ -627,8 +602,7 @@ def pytest_sessionstart(session): faulthandler.dump_traceback_later(hang_timeout, repeat=True) _HANG_WATCHDOG_INSTALLED = True print( - "Playwright hang watchdog enabled: dumps after " - f"{hang_timeout}s (set PLAYWRIGHT_HANG_TIMEOUT_S=0 to disable)", + f"Playwright hang watchdog enabled: dumps after {hang_timeout}s (set PLAYWRIGHT_HANG_TIMEOUT_S=0 to disable)", flush=True, ) else: @@ -739,9 +713,7 @@ def context(browser): try: yield context_instance finally: - if getattr(context_instance, "_trace_enabled", False) and not getattr( - context_instance, "_trace_saved", False - ): + if getattr(context_instance, "_trace_enabled", False) and not getattr(context_instance, "_trace_saved", False): try: context_instance.tracing.stop() except Exception: @@ -927,9 +899,7 @@ def seeded_user_credentials(base_url: str, login_url: str, browser) -> tuple[str seeded_via = None if seeding_mode == "api": details = "; ".join(seed_errors) - raise RuntimeError( - f"Failed to seed user via API registration. {details}" - ) from exc + raise RuntimeError(f"Failed to seed user via API registration. {details}") from exc if seeded_via is None and seeding_mode in {"auto", "ui"}: seeded_via = "ui" @@ -946,9 +916,7 @@ def seeded_user_credentials(base_url: str, login_url: str, browser) -> tuple[str except Exception as ui_exc: seed_errors.append(f"ui: {ui_exc}") details = "; ".join(seed_errors) - raise RuntimeError( - f"Failed to seed user via API or UI registration. {details}" - ) from ui_exc + raise RuntimeError(f"Failed to seed user via API or UI registration. {details}") from ui_exc os.environ["SEEDED_USER_EMAIL"] = email os.environ["SEEDED_USER_PASSWORD"] = password @@ -1020,9 +988,7 @@ def ensure_auth_context( def _ensure_model_provider_ready_via_api(base_url: str, auth_header: str) -> dict: headers = {"Authorization": auth_header} - _, my_llms_payload = _api_request_json( - _build_url(base_url, "/v1/llm/my_llms"), headers=headers - ) + _, my_llms_payload = _api_request_json(_build_url(base_url, "/v1/llm/my_llms"), headers=headers) my_llms_data = _response_data(my_llms_payload) has_provider = bool(my_llms_data) created_provider = False @@ -1038,17 +1004,13 @@ def _ensure_model_provider_ready_via_api(base_url: str, auth_header: str) -> dic _response_data(set_key_payload) has_provider = True created_provider = True - _, my_llms_payload = _api_request_json( - _build_url(base_url, "/v1/llm/my_llms"), headers=headers - ) + _, my_llms_payload = _api_request_json(_build_url(base_url, "/v1/llm/my_llms"), headers=headers) my_llms_data = _response_data(my_llms_payload) if not has_provider: pytest.skip("No model provider configured and ZHIPU_AI_API_KEY is not set.") - _, tenant_payload = _api_request_json( - _build_url(base_url, "/api/v1/users/me/models"), headers=headers - ) + _, tenant_payload = _api_request_json(_build_url(base_url, "/api/v1/users/me/models"), headers=headers) tenant_data = _response_data(tenant_payload) tenant_id = tenant_data.get("tenant_id") if not tenant_id: @@ -1067,9 +1029,7 @@ def _ensure_model_provider_ready_via_api(base_url: str, auth_header: str) -> dic if not target_llm and _provider_has_model(my_llms_data, "ZHIPU-AI", "glm-4-flash"): target_llm = "glm-4-flash@ZHIPU-AI" if not target_llm: - pytest.skip( - "Provider exists but no canonical default llm_id could be inferred for tenant setup." - ) + pytest.skip("Provider exists but no canonical default llm_id could be inferred for tenant setup.") target_embd = current_embd if not target_embd or _is_malformed_tenant_model_value(target_embd): @@ -1104,12 +1064,7 @@ def _ensure_model_provider_ready_via_api(base_url: str, auth_header: str) -> dic target_tts = target_tts or "" should_update_tenant_defaults = ( - target_llm != current_llm - or target_embd != current_embd - or target_img2txt != current_img2txt - or target_asr != current_asr - or target_rerank != current_rerank - or target_tts != current_tts + target_llm != current_llm or target_embd != current_embd or target_img2txt != current_img2txt or target_asr != current_asr or target_rerank != current_rerank or target_tts != current_tts ) if should_update_tenant_defaults: @@ -1165,9 +1120,7 @@ def ensure_model_provider_configured( } if _env_bool("PW_FIXTURE_DEBUG", False): print( - "[prereq] provider_ready " - f"email={email} created_provider={payload.get('created_provider', False)} " - f"llm_factories={payload.get('llm_factories', [])}", + f"[prereq] provider_ready email={email} created_provider={payload.get('created_provider', False)} llm_factories={payload.get('llm_factories', [])}", flush=True, ) _PROVIDER_READY_CACHE[cache_key] = payload @@ -1185,9 +1138,7 @@ def _find_dataset_by_name(kbs_payload: dict | None, dataset_name: str) -> dict | return None -def _ensure_dataset_ready_via_api( - base_url: str, auth_header: str, dataset_name: str -) -> dict: +def _ensure_dataset_ready_via_api(base_url: str, auth_header: str, dataset_name: str) -> dict: headers = {"Authorization": auth_header} list_url = _build_url(base_url, "/api/v1/datasets?page=1&page_size=100") @@ -1211,14 +1162,10 @@ def _ensure_dataset_ready_via_api( if kb_id: return {"kb_id": kb_id, "kb_name": dataset_name, "reused": False} - _, list_payload_after = _api_request_json( - list_url, method="GET", headers=headers - ) + _, list_payload_after = _api_request_json(list_url, method="GET", headers=headers) existing_after = _find_dataset_by_name(list_payload_after, dataset_name) if not existing_after: - raise RuntimeError( - f"Dataset {dataset_name!r} not found after /api/v1/datasets create response={create_payload}" - ) + raise RuntimeError(f"Dataset {dataset_name!r} not found after /api/v1/datasets create response={create_payload}") return { "kb_id": existing_after.get("id"), "kb_name": dataset_name, @@ -1250,9 +1197,7 @@ def ensure_dataset_ready( } if _env_bool("PW_FIXTURE_DEBUG", False): print( - "[prereq] dataset_ready " - f"kb_name={payload.get('kb_name')} reused={payload.get('reused')} " - f"kb_id={payload.get('kb_id')}", + f"[prereq] dataset_ready kb_name={payload.get('kb_name')} reused={payload.get('reused')} kb_id={payload.get('kb_id')}", flush=True, ) _DATASET_READY_CACHE[cache_key] = payload @@ -1390,6 +1335,7 @@ def _debug_dump_auth_state(page, label: str, submit_locator=None) -> None: def auth_debug_dump(page, request): if "flow_page" in request.fixturenames: page = request.getfixturevalue("flow_page") + def _dump(label: str, submit_locator=None) -> None: _debug_dump_auth_state(page, label, submit_locator) @@ -1432,9 +1378,7 @@ def _write_artifacts_if_failed(page, context, request) -> None: except Exception as exc: print(f"[artifact] events dump failed: {exc}", flush=True) - if getattr(context, "_trace_enabled", False) and not getattr( - context, "_trace_saved", False - ): + if getattr(context, "_trace_enabled", False) and not getattr(context, "_trace_saved", False): try: context.tracing.stop(path=str(trace_path)) context._trace_saved = True @@ -1513,10 +1457,7 @@ def _write_auth_ready_diagnostics(page, request, reason: str) -> None: try: summary = _auth_ready_summary(page) - summary_text = ( - f"reason: {reason}\nurl: {page.url}\ntitle: {page.title()}\n" - + _format_auth_ready_summary(summary) - ) + summary_text = f"reason: {reason}\nurl: {page.url}\ntitle: {page.title()}\n" + _format_auth_ready_summary(summary) summary_path.write_text(summary_text, encoding="utf-8") print(summary_text, flush=True) except Exception as exc: @@ -1533,23 +1474,13 @@ def _wait_for_auth_ui_ready(page, request) -> None: expect(active_forms).to_have_count(1, timeout=timeout_ms) except AssertionError as exc: _write_auth_ready_diagnostics(page, request, "auth active form not unique") - raise AssertionError( - "Auth UI not ready within " - f"{timeout_ms}ms. Expected a single active auth form." - ) from exc - ready_forms = active_forms.filter( - has=page.locator(password_selector) - ).filter(has=page.locator(email_selector)).filter( - has=page.locator(submit_selector) - ) + raise AssertionError(f"Auth UI not ready within {timeout_ms}ms. Expected a single active auth form.") from exc + ready_forms = active_forms.filter(has=page.locator(password_selector)).filter(has=page.locator(email_selector)).filter(has=page.locator(submit_selector)) try: expect(ready_forms).not_to_have_count(0, timeout=timeout_ms) except AssertionError as exc: _write_auth_ready_diagnostics(page, request, "auth UI readiness timeout") - raise AssertionError( - "Auth UI not ready within " - f"{timeout_ms}ms. Expected a visible form with email-like and password inputs." - ) from exc + raise AssertionError(f"Auth UI not ready within {timeout_ms}ms. Expected a visible form with email-like and password inputs.") from exc def _wait_for_active_form_clickable(page, request, form) -> None: @@ -1610,14 +1541,9 @@ def _wait_for_active_form_clickable(page, request, form) -> None: ) except Exception: pass - _write_auth_ready_diagnostics( - page, request, "active auth form submit not clickable" - ) + _write_auth_ready_diagnostics(page, request, "active auth form submit not clickable") _debug_dump_auth_state(page, "active_form_not_clickable", submit_buttons) - raise AssertionError( - "Active auth form submit button not clickable within " - f"{timeout_ms}ms. The flip animation may still be in progress." - ) from exc + raise AssertionError(f"Active auth form submit button not clickable within {timeout_ms}ms. The flip animation may still be in progress.") from exc def _locator_is_topmost(locator) -> bool: @@ -1650,16 +1576,10 @@ def auth_click(): return except PlaywrightTimeoutError as exc: message = str(exc).lower() - can_force = ( - "intercepts pointer events" in message - or "element was detached" in message - or "element is not stable" in message - ) + can_force = "intercepts pointer events" in message or "element was detached" in message or "element is not stable" in message if not can_force: raise - if "intercepts pointer events" in message and not _locator_is_topmost( - locator - ): + if "intercepts pointer events" in message and not _locator_is_topmost(locator): if idx >= attempts - 1: raise time.sleep(0.15) @@ -1681,6 +1601,7 @@ def auth_click(): def active_auth_context(page, request): if "flow_page" in request.fixturenames: page = request.getfixturevalue("flow_page") + def _mark_active_form() -> None: timeout_ms = _playwright_auth_ready_timeout_ms() try: @@ -1789,14 +1710,9 @@ def active_auth_context(page, request): timeout=timeout_ms, ) except Exception as exc: - _write_auth_ready_diagnostics( - page, request, "active auth form did not become front-facing" - ) + _write_auth_ready_diagnostics(page, request, "active auth form did not become front-facing") _debug_dump_auth_state(page, "active_form_not_front_facing") - raise AssertionError( - "Active auth form not ready within " - f"{timeout_ms}ms. The flip animation may not have settled." - ) from exc + raise AssertionError(f"Active auth form not ready within {timeout_ms}ms. The flip animation may not have settled.") from exc def _get(): _wait_for_auth_ui_ready(page, request) @@ -1806,12 +1722,8 @@ def active_auth_context(page, request): try: expect(form).to_have_count(1, timeout=timeout_ms) except AssertionError as exc: - _write_auth_ready_diagnostics( - page, request, "active auth form selection failed" - ) - raise AssertionError( - "Active auth form not found. The login card may not be visible or the DOM changed." - ) from exc + _write_auth_ready_diagnostics(page, request, "active auth form selection failed") + raise AssertionError("Active auth form not found. The login card may not be visible or the DOM changed.") from exc _wait_for_active_form_clickable(page, request, form) return form, card diff --git a/test/playwright/e2e/test_dataset_upload_parse.py b/test/playwright/e2e/test_dataset_upload_parse.py index 9e918714b2..3990cbf05a 100644 --- a/test/playwright/e2e/test_dataset_upload_parse.py +++ b/test/playwright/e2e/test_dataset_upload_parse.py @@ -27,10 +27,7 @@ RESULT_TIMEOUT_MS = 15000 def make_test_png(path: Path) -> Path: - png_b64 = ( - "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8" - "/w8AAgMBAp6X6QAAAABJRU5ErkJggg==" - ) + png_b64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/w8AAgMBAp6X6QAAAABJRU5ErkJggg==" path.write_bytes(base64.b64decode(png_b64)) return path @@ -101,9 +98,7 @@ def select_combobox_option( option.click(force=True) if preferred_text: - preferred_option = options.filter( - has_text=re.compile(rf"^{re.escape(preferred_text)}$", re.I) - ) + preferred_option = options.filter(has_text=re.compile(rf"^{re.escape(preferred_text)}$", re.I)) if preferred_option.count() > 0: click_option(preferred_option.first) return preferred_text @@ -152,9 +147,7 @@ def select_ragflow_option( expect(options.first).to_be_visible(timeout=RESULT_TIMEOUT_MS) if preferred_text: - preferred_option = options.filter( - has_text=re.compile(rf"^{re.escape(preferred_text)}$", re.I) - ) + preferred_option = options.filter(has_text=re.compile(rf"^{re.escape(preferred_text)}$", re.I)) if preferred_option.count() > 0: preferred_option.first.click() return preferred_text @@ -352,9 +345,7 @@ def step_03_create_dataset( except Exception: if created_kb_id: page.goto( - urljoin( - base_url.rstrip("/") + "/", f"/dataset/dataset/{created_kb_id}" - ), + urljoin(base_url.rstrip("/") + "/", f"/dataset/dataset/{created_kb_id}"), wait_until="domcontentloaded", ) else: @@ -388,47 +379,31 @@ def step_04_set_dataset_settings( with step("open dataset settings page"): page.goto( - urljoin( - base_url.rstrip("/") + "/", f"/dataset/dataset-setting/{dataset_id}" - ), + urljoin(base_url.rstrip("/") + "/", f"/dataset/dataset-setting/{dataset_id}"), wait_until="domcontentloaded", ) - expect(page.get_by_test_id("ds-settings-basic-name-input")).to_be_visible( - timeout=RESULT_TIMEOUT_MS - ) - expect(page.get_by_test_id("ds-settings-page-save-btn")).to_be_visible( - timeout=RESULT_TIMEOUT_MS - ) + expect(page.get_by_test_id("ds-settings-basic-name-input")).to_be_visible(timeout=RESULT_TIMEOUT_MS) + expect(page.get_by_test_id("ds-settings-page-save-btn")).to_be_visible(timeout=RESULT_TIMEOUT_MS) snap("dataset_settings_open") with step("fill base settings"): - page.get_by_test_id("ds-settings-basic-name-input").fill( - f"{dataset_name}-cfg" - ) - select_combobox_option( - page, "ds-settings-basic-language-select", preferred_text="English" - ) + page.get_by_test_id("ds-settings-basic-name-input").fill(f"{dataset_name}-cfg") + select_combobox_option(page, "ds-settings-basic-language-select", preferred_text="English") avatar_path = make_test_png(tmp_path / "avatar-test.png") - page.get_by_test_id("ds-settings-basic-avatar-upload").set_input_files( - str(avatar_path) - ) + page.get_by_test_id("ds-settings-basic-avatar-upload").set_input_files(str(avatar_path)) crop_modal = page.get_by_test_id("ds-settings-basic-avatar-crop-modal") expect(crop_modal).to_be_visible(timeout=RESULT_TIMEOUT_MS) page.get_by_test_id("ds-settings-basic-avatar-crop-confirm-btn").click() expect(crop_modal).not_to_be_visible(timeout=RESULT_TIMEOUT_MS) - page.get_by_test_id("ds-settings-basic-description-input").fill( - "Dataset setting playwright description" - ) + page.get_by_test_id("ds-settings-basic-description-input").fill("Dataset setting playwright description") try: select_combobox_option(page, "ds-settings-basic-permissions-select") except Exception: page.keyboard.press("Escape") - embedding_trigger = page.get_by_test_id( - "ds-settings-basic-embedding-model-select" - ).first + embedding_trigger = page.get_by_test_id("ds-settings-basic-embedding-model-select").first expect(embedding_trigger).to_be_visible(timeout=RESULT_TIMEOUT_MS) if not embedding_trigger.is_disabled(): try: @@ -438,14 +413,10 @@ def step_04_set_dataset_settings( with step("fill parser and metadata settings"): set_number_input(page, "ds-settings-parser-page-rank-input", 12) - select_combobox_option( - page, "ds-settings-parser-pdf-parser-select", preferred_text="Plain Text" - ) + select_combobox_option(page, "ds-settings-parser-pdf-parser-select", preferred_text="Plain Text") set_number_input(page, "ds-settings-parser-recommended-chunk-size-input", 640) set_switch_state(page, "ds-settings-parser-child-chunk-switch", True) - expect( - page.get_by_test_id("ds-settings-parser-child-chunk-delimiter-input") - ).to_be_visible(timeout=RESULT_TIMEOUT_MS) + expect(page.get_by_test_id("ds-settings-parser-child-chunk-delimiter-input")).to_be_visible(timeout=RESULT_TIMEOUT_MS) set_switch_state(page, "ds-settings-parser-page-index-switch", True) set_number_input(page, "ds-settings-parser-image-table-context-window-input", 16) set_switch_state(page, "ds-settings-metadata-switch", True) @@ -480,9 +451,7 @@ def step_04_set_dataset_settings( page.get_by_test_id("ds-settings-metadata-modal-save-btn").click() expect(metadata_modal).not_to_be_visible(timeout=RESULT_TIMEOUT_MS) - overlap_slider = page.get_by_test_id( - "ds-settings-parser-overlapped-percent-slider" - ).first + overlap_slider = page.get_by_test_id("ds-settings-parser-overlapped-percent-slider").first expect(overlap_slider).to_be_visible(timeout=RESULT_TIMEOUT_MS) overlap_slider.focus() overlap_slider.press("ArrowRight") @@ -496,20 +465,14 @@ def step_04_set_dataset_settings( expect(entity_input).to_be_visible(timeout=RESULT_TIMEOUT_MS) entity_input.fill("playwright_entity") entity_input.press("Enter") - select_ragflow_option( - page, "ds-settings-graph-method-select", preferred_text="General" - ) + select_ragflow_option(page, "ds-settings-graph-method-select", preferred_text="General") set_switch_state(page, "ds-settings-graph-entity-resolution-switch", True) set_switch_state(page, "ds-settings-graph-community-reports-switch", True) - raptor_scope_dataset = page.get_by_role( - "radio", name=re.compile(r"^Dataset$", re.I) - ).first + raptor_scope_dataset = page.get_by_role("radio", name=re.compile(r"^Dataset$", re.I)).first raptor_scope_dataset.check(force=True) expect(raptor_scope_dataset).to_be_checked(timeout=RESULT_TIMEOUT_MS) - page.get_by_test_id("ds-settings-raptor-prompt-textarea").fill( - "Playwright prompt for dataset settings" - ) + page.get_by_test_id("ds-settings-raptor-prompt-textarea").fill("Playwright prompt for dataset settings") set_number_input(page, "ds-settings-raptor-max-token-input", 300) set_number_input(page, "ds-settings-raptor-threshold-input", 0.3) set_number_input(page, "ds-settings-raptor-max-cluster-input", 128) @@ -546,23 +509,16 @@ def step_04_set_dataset_settings( assert 200 <= response.status < 400, f"Unexpected /api/v1/datasets update status={response.status}" response_payload = response.json() if isinstance(response_payload, dict): - assert response_payload.get("code") == 0, ( - f"/api/v1/datasets update response code={response_payload.get('code')} " - f"message={response_payload.get('message')}" - ) + assert response_payload.get("code") == 0, f"/api/v1/datasets update response code={response_payload.get('code')} message={response_payload.get('message')}" payload = get_request_json_payload(response) for key in ("name", "language", "parser_config"): assert key in payload, f"Expected key {key!r} in /api/v1/datasets update payload" parser_config = payload.get("parser_config") or {} - assert ( - parser_config.get("image_table_context_window") - == parser_config.get("image_context_size") - == parser_config.get("table_context_size") - ), "Expected image/table context window transform keys to be aligned" - expect(page.locator("[data-sonner-toast]").first).to_be_visible( - timeout=RESULT_TIMEOUT_MS + assert parser_config.get("image_table_context_window") == parser_config.get("image_context_size") == parser_config.get("table_context_size"), ( + "Expected image/table context window transform keys to be aligned" ) + expect(page.locator("[data-sonner-toast]").first).to_be_visible(timeout=RESULT_TIMEOUT_MS) with step("return to dataset detail for upload"): page.goto( @@ -598,9 +554,7 @@ def step_05_upload_files( for idx, file_path in enumerate(file_paths): filename = file_path.name with step(f"open upload modal for {filename}"): - upload_modal = ensure_upload_modal_open( - page, expect, auth_click, timeout_ms=RESULT_TIMEOUT_MS - ) + upload_modal = ensure_upload_modal_open(page, expect, auth_click, timeout_ms=RESULT_TIMEOUT_MS) if idx == 0: snap("upload_modal_open") @@ -611,14 +565,10 @@ def step_05_upload_files( with step(f"upload file {filename}"): upload_file(page, expect, upload_modal, str(file_path), RESULT_TIMEOUT_MS) - expect(upload_modal.locator(f"text={filename}")).to_be_visible( - timeout=RESULT_TIMEOUT_MS - ) + expect(upload_modal.locator(f"text={filename}")).to_be_visible(timeout=RESULT_TIMEOUT_MS) with step(f"submit upload {filename}"): - save_button = upload_modal.locator( - "button", has_text=re.compile("save", re.I) - ).first + save_button = upload_modal.locator("button", has_text=re.compile("save", re.I)).first def trigger(): save_button.click() @@ -626,15 +576,12 @@ def step_05_upload_files( capture_response( page, trigger, - lambda resp: resp.request.method == "POST" - and "/v1/document/upload" in resp.url, + lambda resp: resp.request.method == "POST" and "/v1/document/upload" in resp.url, ) expect(upload_modal).not_to_be_visible(timeout=RESULT_TIMEOUT_MS) snap(f"upload_{filename}_submitted") - row = page.locator( - f"[data-testid='document-row'][data-doc-name={json.dumps(filename)}]" - ) + row = page.locator(f"[data-testid='document-row'][data-doc-name={json.dumps(filename)}]") expect(row).to_be_visible(timeout=RESULT_TIMEOUT_MS) flow_state["uploads_done"] = True @@ -682,16 +629,8 @@ def step_07_delete_one_file( with step(f"delete uploaded file {delete_filename}"): delete_uploaded_file(page, expect, delete_filename, timeout_ms=RESULT_TIMEOUT_MS) snap("file_deleted_doc3") - expect( - page.locator( - f"[data-testid='document-row'][data-doc-name={json.dumps('Doc1.pdf')}]" - ) - ).to_be_visible(timeout=RESULT_TIMEOUT_MS) - expect( - page.locator( - f"[data-testid='document-row'][data-doc-name={json.dumps('Doc2.pdf')}]" - ) - ).to_be_visible(timeout=RESULT_TIMEOUT_MS) + expect(page.locator(f"[data-testid='document-row'][data-doc-name={json.dumps('Doc1.pdf')}]")).to_be_visible(timeout=RESULT_TIMEOUT_MS) + expect(page.locator(f"[data-testid='document-row'][data-doc-name={json.dumps('Doc2.pdf')}]")).to_be_visible(timeout=RESULT_TIMEOUT_MS) snap("success") diff --git a/test/playwright/e2e/test_model_providers_zhipu_ai_defaults.py b/test/playwright/e2e/test_model_providers_zhipu_ai_defaults.py index dbf6f702a3..09ec85b5fb 100644 --- a/test/playwright/e2e/test_model_providers_zhipu_ai_defaults.py +++ b/test/playwright/e2e/test_model_providers_zhipu_ai_defaults.py @@ -137,17 +137,10 @@ def step_05_filter_zhipu( expect(search_input).to_have_count(1) search_input.first.fill("zhipu") available_section = page.locator("[data-testid='available-models-section']") - provider = available_section.locator( - "[data-testid='available-model-card'][data-provider='ZHIPU-AI']" - ).first + provider = available_section.locator("[data-testid='available-model-card'][data-provider='ZHIPU-AI']").first if provider.count() == 0: added_section = page.locator("[data-testid='added-models-section']") - if ( - added_section.locator( - "[data-testid='added-model-card'][data-provider='ZHIPU-AI']" - ).count() - == 0 - ): + if added_section.locator("[data-testid='added-model-card'][data-provider='ZHIPU-AI']").count() == 0: raise AssertionError("ZHIPU-AI provider not found in available or added models.") else: expect(provider).to_be_visible() @@ -169,18 +162,14 @@ def step_06_add_api_key( require(flow_state, "provider_filtered", "api_key") page = flow_page available_section = page.locator("[data-testid='available-models-section']") - provider = available_section.locator( - "[data-testid='available-model-card'][data-provider='ZHIPU-AI']" - ).first + provider = available_section.locator("[data-testid='available-model-card'][data-provider='ZHIPU-AI']").first with step("add ZHIPU-AI api key"): if provider.count() > 0: provider.click() else: added_section = page.locator("[data-testid='added-models-section']") - card = added_section.locator( - "[data-testid='added-model-card'][data-provider='ZHIPU-AI']" - ).first + card = added_section.locator("[data-testid='added-model-card'][data-provider='ZHIPU-AI']").first api_key_button = card.locator("button", has_text=re.compile("API-?Key", re.I)).first expect(api_key_button).to_be_visible() api_key_button.click() @@ -189,6 +178,7 @@ def step_06_add_api_key( api_input = modal.locator("[data-testid='apikey-input']").first save_button = modal.locator("[data-testid='apikey-save']").first try: + def trigger(): api_input.fill(flow_state["api_key"]) save_button.click() @@ -206,11 +196,7 @@ def step_06_add_api_key( with step("confirm added model"): added_section = page.locator("[data-testid='added-models-section']") expect(added_section).to_be_visible() - expect( - added_section.locator( - "[data-testid='added-model-card'][data-provider='ZHIPU-AI']" - ) - ).to_be_visible() + expect(added_section.locator("[data-testid='added-model-card'][data-provider='ZHIPU-AI']")).to_be_visible() flow_state["provider_added"] = True snap("provider_saved") @@ -278,11 +264,7 @@ def step_08_verify_persist( expect(llm_combo).to_contain_text("glm-4-flash") expect(emb_combo).to_contain_text(flow_state.get("selected_emb_text") or "embedding-2") added_section = page.locator("[data-testid='added-models-section']") - expect( - added_section.locator( - "[data-testid='added-model-card'][data-provider='ZHIPU-AI']" - ) - ).to_be_visible() + expect(added_section.locator("[data-testid='added-model-card'][data-provider='ZHIPU-AI']")).to_be_visible() snap("defaults_persisted") snap("success") diff --git a/test/playwright/e2e/test_next_apps_agent.py b/test/playwright/e2e/test_next_apps_agent.py index b0869d971c..11c2ea34bc 100644 --- a/test/playwright/e2e/test_next_apps_agent.py +++ b/test/playwright/e2e/test_next_apps_agent.py @@ -170,12 +170,8 @@ def step_03_create_first_agent( name_input_testid="agent-name-input", save_testid="agent-save", ) - expect(page.locator("[data-testid='agents-list']")).to_be_visible( - timeout=RESULT_TIMEOUT_MS - ) - expect(page.locator("[data-testid='agent-card']").first).to_be_visible( - timeout=RESULT_TIMEOUT_MS - ) + expect(page.locator("[data-testid='agents-list']")).to_be_visible(timeout=RESULT_TIMEOUT_MS) + expect(page.locator("[data-testid='agent-card']").first).to_be_visible(timeout=RESULT_TIMEOUT_MS) flow_state["first_agent_created"] = True snap("agent_first_created") @@ -235,16 +231,12 @@ def step_05_open_imported_agent( with step("open imported agent"): card = page.locator( "[data-testid='agent-card']", - has=page.locator( - "[data-testid='agent-name']", has_text=re.compile(flow_state["second_agent_name"]) - ), + has=page.locator("[data-testid='agent-name']", has_text=re.compile(flow_state["second_agent_name"])), ).first expect(card).to_be_visible(timeout=RESULT_TIMEOUT_MS) auth_click(card, "open_agent") _wait_for_url_regex(page, r"/agent/") - expect(page.locator("[data-testid='agent-detail']")).to_be_visible( - timeout=RESULT_TIMEOUT_MS - ) + expect(page.locator("[data-testid='agent-detail']")).to_be_visible(timeout=RESULT_TIMEOUT_MS) flow_state["agent_detail_open"] = True snap("agent_detail_open") @@ -267,11 +259,7 @@ def step_06_run_agent( run_ui_timeout_ms = int(os.getenv("PW_AGENT_RUN_UI_TIMEOUT_MS", "60000")) run_root = page.locator("[data-testid='agent-run']") - run_ui_selector = ( - "[data-testid='agent-run-chat'], " - "[data-testid='chat-textarea'], " - "[data-testid='agent-run-idle']" - ) + run_ui_selector = "[data-testid='agent-run-chat'], [data-testid='chat-textarea'], [data-testid='agent-run-idle']" run_ui_locator = page.locator(run_ui_selector) try: @@ -367,9 +355,7 @@ def step_07_send_chat( except AssertionError: # Older UI builds do not expose agent-run-idle; fallback to assistant reply. agent_chat = page.locator("[data-testid='agent-run-chat']") - assistant_reply = agent_chat.locator( - "text=/how can i assist|hello/i" - ).first + assistant_reply = agent_chat.locator("text=/how can i assist|hello/i").first try: expect(assistant_reply).to_be_visible(timeout=60000) except AssertionError: diff --git a/test/playwright/e2e/test_next_apps_chat.py b/test/playwright/e2e/test_next_apps_chat.py index e0169a8a59..756bada0cf 100644 --- a/test/playwright/e2e/test_next_apps_chat.py +++ b/test/playwright/e2e/test_next_apps_chat.py @@ -41,9 +41,7 @@ def step_02_open_chat_list(ctx: FlowContext, step, snap): with step("open chat list"): _goto_home(page, ctx.base_url) _nav_click(page, "nav-chat") - expect(page.locator("[data-testid='chats-list']")).to_be_visible( - timeout=RESULT_TIMEOUT_MS - ) + expect(page.locator("[data-testid='chats-list']")).to_be_visible(timeout=RESULT_TIMEOUT_MS) snap("chat_list_open") @@ -237,9 +235,7 @@ def _mm_open_model_options(page, card, option_prefix: str): pass page.wait_for_timeout(120) - raise AssertionError( - f"no model options rendered for prefix={option_prefix!r} in multi-model selector" - ) + raise AssertionError(f"no model options rendered for prefix={option_prefix!r} in multi-model selector") def _mm_click_generic_model_option(page, card_index: int, option_prefix: str) -> str: @@ -265,11 +261,7 @@ def _mm_click_generic_model_option(page, card_index: int, option_prefix: str) -> if chosen_testid: return chosen_testid - chosen_value = ( - chosen.get_attribute("data-value") - or chosen.get_attribute("value") - or f"idx-{choose_index}" - ) + chosen_value = chosen.get_attribute("data-value") or chosen.get_attribute("value") or f"idx-{choose_index}" return f"{option_prefix}{chosen_value}" @@ -285,9 +277,7 @@ def mm_step_01_ensure_authed_and_open_chat_list(ctx: FlowContext, step, snap): ) _goto_home(page, ctx.base_url) _nav_click(page, "nav-chat") - expect(page.locator("[data-testid='chats-list']")).to_be_visible( - timeout=RESULT_TIMEOUT_MS - ) + expect(page.locator("[data-testid='chats-list']")).to_be_visible(timeout=RESULT_TIMEOUT_MS) ctx.state["mm_logged_in"] = True snap("chat_mm_list") @@ -342,28 +332,20 @@ def mm_step_05_sessions_panel_row_ops(ctx: FlowContext, step, snap): expect(sessions_root).to_be_visible(timeout=RESULT_TIMEOUT_MS) page.get_by_test_id("chat-detail-sessions-close").click() - expect(page.get_by_test_id("chat-detail-sessions-open")).to_be_visible( - timeout=RESULT_TIMEOUT_MS - ) + expect(page.get_by_test_id("chat-detail-sessions-open")).to_be_visible(timeout=RESULT_TIMEOUT_MS) page.get_by_test_id("chat-detail-sessions-open").click() expect(sessions_root).to_be_visible(timeout=RESULT_TIMEOUT_MS) page.get_by_test_id("chat-detail-session-new").click() session_rows = page.locator("[data-testid='chat-detail-session-item']") expect(session_rows.first).to_be_visible(timeout=RESULT_TIMEOUT_MS) - active_session = sessions_root.locator( - "li[aria-selected='true'] [data-testid='chat-detail-session-item']" - ) + active_session = sessions_root.locator("li[aria-selected='true'] [data-testid='chat-detail-session-item']") selected_row = active_session.first if active_session.count() > 0 else session_rows.first created_session_id = selected_row.get_attribute("data-session-id") or "" assert created_session_id, "failed to capture created session id" selected_row.click() - expect( - page.locator( - f"[data-testid='chat-detail-session-item'][data-session-id='{created_session_id}']" - ).first - ).to_be_visible(timeout=RESULT_TIMEOUT_MS) + expect(page.locator(f"[data-testid='chat-detail-session-item'][data-session-id='{created_session_id}']").first).to_be_visible(timeout=RESULT_TIMEOUT_MS) search_input = page.get_by_test_id("chat-detail-session-search") expect(search_input).to_be_visible(timeout=RESULT_TIMEOUT_MS) @@ -383,31 +365,19 @@ def mm_step_05_sessions_panel_row_ops(ctx: FlowContext, step, snap): # When only one row exists, some builds keep it visible for temporary sessions. # In that case we still validate the search interaction without forcing impossible narrowing. if row_count_before > 1: - assert ( - min_filtered_count < row_count_before - ), "session search did not narrow visible rows" + assert min_filtered_count < row_count_before, "session search did not narrow visible rows" else: assert min_filtered_count <= row_count_before search_input.fill("") - expect( - page.locator( - f"[data-testid='chat-detail-session-item'][data-session-id='{created_session_id}']" - ).first - ).to_be_visible(timeout=RESULT_TIMEOUT_MS) + expect(page.locator(f"[data-testid='chat-detail-session-item'][data-session-id='{created_session_id}']").first).to_be_visible(timeout=RESULT_TIMEOUT_MS) - row_li = sessions_root.locator( - f"li:has([data-testid='chat-detail-session-item'][data-session-id='{created_session_id}'])" - ).first + row_li = sessions_root.locator(f"li:has([data-testid='chat-detail-session-item'][data-session-id='{created_session_id}'])").first row_li.hover() - actions_btn = page.locator( - f"[data-testid='chat-detail-session-actions'][data-session-id='{created_session_id}']" - ).first + actions_btn = page.locator(f"[data-testid='chat-detail-session-actions'][data-session-id='{created_session_id}']").first expect(actions_btn).to_be_visible(timeout=RESULT_TIMEOUT_MS) actions_btn.click() - row_delete = page.locator( - f"[data-testid='chat-detail-session-delete'][data-session-id='{created_session_id}']" - ).first + row_delete = page.locator(f"[data-testid='chat-detail-session-delete'][data-session-id='{created_session_id}']").first expect(row_delete).to_be_visible(timeout=RESULT_TIMEOUT_MS) row_delete.click() row_delete_dialog = page.get_by_test_id("confirm-delete-dialog") @@ -419,11 +389,7 @@ def mm_step_05_sessions_panel_row_ops(ctx: FlowContext, step, snap): # If no dialog renders in this branch, still dismiss any menu overlay. page.keyboard.press("Escape") - expect( - page.locator( - f"[data-testid='chat-detail-session-item'][data-session-id='{created_session_id}']" - ).first - ).to_be_visible(timeout=RESULT_TIMEOUT_MS) + expect(page.locator(f"[data-testid='chat-detail-session-item'][data-session-id='{created_session_id}']").first).to_be_visible(timeout=RESULT_TIMEOUT_MS) ctx.state["mm_created_session_id"] = created_session_id ctx.state["mm_session_row_checked"] = True @@ -448,21 +414,15 @@ def mm_step_06_selection_mode_batch_delete(ctx: FlowContext, step, snap): page.keyboard.press("Escape") page.mouse.click(5, 5) selection_enable.click(timeout=RESULT_TIMEOUT_MS) - checked_before = page.locator( - "[data-testid='chat-detail-session-checkbox'][data-state='checked']" - ).count() + checked_before = page.locator("[data-testid='chat-detail-session-checkbox'][data-state='checked']").count() page.get_by_test_id("chat-detail-session-select-all").click() - checked_after = page.locator( - "[data-testid='chat-detail-session-checkbox'][data-state='checked']" - ).count() + checked_after = page.locator("[data-testid='chat-detail-session-checkbox'][data-state='checked']").count() if page.locator("[data-testid='chat-detail-session-checkbox']").count() > 1: assert checked_after != checked_before else: assert checked_after >= checked_before - session_checkbox = page.locator( - f"[data-testid='chat-detail-session-checkbox'][data-session-id='{created_session_id}']" - ).first + session_checkbox = page.locator(f"[data-testid='chat-detail-session-checkbox'][data-session-id='{created_session_id}']").first expect(session_checkbox).to_be_visible(timeout=RESULT_TIMEOUT_MS) if _mm_is_checked(session_checkbox): session_checkbox.click() @@ -471,11 +431,7 @@ def mm_step_06_selection_mode_batch_delete(ctx: FlowContext, step, snap): assert _mm_is_checked(session_checkbox), "target session checkbox did not become checked" page.get_by_test_id("chat-detail-session-selection-exit").click() - expect( - page.locator( - f"[data-testid='chat-detail-session-item'][data-session-id='{created_session_id}']" - ).first - ).to_be_visible(timeout=RESULT_TIMEOUT_MS) + expect(page.locator(f"[data-testid='chat-detail-session-item'][data-session-id='{created_session_id}']").first).to_be_visible(timeout=RESULT_TIMEOUT_MS) selection_enable = page.get_by_test_id("chat-detail-session-selection-enable") expect(selection_enable).to_be_visible(timeout=RESULT_TIMEOUT_MS) @@ -485,9 +441,7 @@ def mm_step_06_selection_mode_batch_delete(ctx: FlowContext, step, snap): page.keyboard.press("Escape") page.mouse.click(5, 5) selection_enable.click(timeout=RESULT_TIMEOUT_MS) - session_checkbox = page.locator( - f"[data-testid='chat-detail-session-checkbox'][data-session-id='{created_session_id}']" - ).first + session_checkbox = page.locator(f"[data-testid='chat-detail-session-checkbox'][data-session-id='{created_session_id}']").first expect(session_checkbox).to_be_visible(timeout=RESULT_TIMEOUT_MS) if not _mm_is_checked(session_checkbox): session_checkbox.click() @@ -497,27 +451,14 @@ def mm_step_06_selection_mode_batch_delete(ctx: FlowContext, step, snap): expect(batch_dialog).to_be_visible(timeout=RESULT_TIMEOUT_MS) page.get_by_test_id("chat-detail-session-batch-delete-cancel").click() expect(batch_dialog).not_to_be_visible(timeout=RESULT_TIMEOUT_MS) - expect( - page.locator( - f"[data-testid='chat-detail-session-checkbox'][data-session-id='{created_session_id}']" - ).first - ).to_be_visible(timeout=RESULT_TIMEOUT_MS) + expect(page.locator(f"[data-testid='chat-detail-session-checkbox'][data-session-id='{created_session_id}']").first).to_be_visible(timeout=RESULT_TIMEOUT_MS) page.get_by_test_id("chat-detail-session-batch-delete").click() expect(batch_dialog).to_be_visible(timeout=RESULT_TIMEOUT_MS) page.get_by_test_id("chat-detail-session-batch-delete-confirm").click() expect(batch_dialog).not_to_be_visible(timeout=RESULT_TIMEOUT_MS) - expect( - page.locator( - f"[data-testid='chat-detail-session-item'][data-session-id='{created_session_id}']" - ) - ).to_have_count(0, timeout=RESULT_TIMEOUT_MS) - expect( - sessions_root.locator( - "li[aria-selected='true'] " - f"[data-testid='chat-detail-session-item'][data-session-id='{created_session_id}']" - ) - ).to_have_count(0, timeout=RESULT_TIMEOUT_MS) + expect(page.locator(f"[data-testid='chat-detail-session-item'][data-session-id='{created_session_id}']")).to_have_count(0, timeout=RESULT_TIMEOUT_MS) + expect(sessions_root.locator(f"li[aria-selected='true'] [data-testid='chat-detail-session-item'][data-session-id='{created_session_id}']")).to_have_count(0, timeout=RESULT_TIMEOUT_MS) ctx.state["mm_sessions_cleanup_done"] = True snap("chat_mm_sessions_cleanup_done") @@ -615,18 +556,10 @@ def mm_step_10_select_models_for_two_cards(ctx: FlowContext, step, snap): selected_option_testids: list[str] = [] for card_index in (0, 1): - card = mm_grid.locator( - f"[data-testid='chat-detail-multimodel-card'][data-card-index='{card_index}']" - ).first + card = mm_grid.locator(f"[data-testid='chat-detail-multimodel-card'][data-card-index='{card_index}']").first expect(card).to_be_visible(timeout=RESULT_TIMEOUT_MS) options = _mm_open_model_options(page, card, option_prefix) - option_testids = [ - tid - for tid in options.evaluate_all( - "els => els.map(el => el.getAttribute('data-testid') || '')" - ) - if tid - ] + option_testids = [tid for tid in options.evaluate_all("els => els.map(el => el.getAttribute('data-testid') || '')") if tid] option_testids = list(dict.fromkeys(option_testids)) if option_testids: @@ -654,9 +587,7 @@ def mm_step_11_apply_multimodel_config(ctx: FlowContext, step, snap): expect(mm_grid).to_be_visible(timeout=RESULT_TIMEOUT_MS) _mm_dismiss_open_popovers(page) - apply_btn = mm_grid.locator( - "[data-testid='chat-detail-multimodel-card-apply'][data-card-index='0']" - ).first + apply_btn = mm_grid.locator("[data-testid='chat-detail-multimodel-card-apply'][data-card-index='0']").first expect(apply_btn).to_be_enabled(timeout=RESULT_TIMEOUT_MS) with page.expect_request(_mm_settings_save_request, timeout=RESULT_TIMEOUT_MS) as req_info: apply_btn.click() @@ -676,12 +607,7 @@ def mm_step_12_composer_and_single_send(ctx: FlowContext, step, snap): completion_payloads: list[dict] = [] def _on_completion_request(req): - if ( - req.method.upper() in MM_REQUEST_METHOD_WHITELIST - and "/api/v1/chats/" in req.url - and "/sessions/" in req.url - and req.url.rstrip("/").endswith("/completions") - ): + if req.method.upper() in MM_REQUEST_METHOD_WHITELIST and "/api/v1/chats/" in req.url and "/sessions/" in req.url and req.url.rstrip("/").endswith("/completions"): completion_payloads.append(_mm_payload_from_request(req)) with step("composer interactions and single send in multi-model mode"): @@ -696,9 +622,7 @@ def mm_step_12_composer_and_single_send(ctx: FlowContext, step, snap): file_input = page.locator("input[type='file']").first expect(file_input).to_be_attached(timeout=RESULT_TIMEOUT_MS) file_input.set_input_files(str(attach_path)) - expect(page.locator(f"text={attach_path.name}").first).to_be_visible( - timeout=RESULT_TIMEOUT_MS - ) + expect(page.locator(f"text={attach_path.name}").first).to_be_visible(timeout=RESULT_TIMEOUT_MS) thinking_toggle = page.get_by_test_id("chat-detail-thinking-toggle") expect(thinking_toggle).to_be_visible(timeout=RESULT_TIMEOUT_MS) @@ -736,9 +660,7 @@ def mm_step_12_composer_and_single_send(ctx: FlowContext, step, snap): except AssertionError: pass try: - expect(stream_status.first).to_have_attribute( - "data-status", "idle", timeout=90000 - ) + expect(stream_status.first).to_have_attribute("data-status", "idle", timeout=90000) except AssertionError: expect(stream_status).to_have_count(0, timeout=90000) @@ -753,11 +675,7 @@ def mm_step_12_composer_and_single_send(ctx: FlowContext, step, snap): payloads_with_messages = [p for p in completion_payloads if p.get("messages")] assert payloads_with_messages, "completion requests did not include messages" - selected_model_ids = [ - tid.replace(option_prefix, "") - for tid in selected_option_testids - if tid.startswith(option_prefix) - ] + selected_model_ids = [tid.replace(option_prefix, "") for tid in selected_option_testids if tid.startswith(option_prefix)] has_model_payload = any( (p.get("llm_id") in selected_model_ids) or ("llm_id" in p) @@ -793,9 +711,7 @@ def mm_step_13_remove_extra_card_and_exit(ctx: FlowContext, step, snap): expect(cards).to_have_count(current_count - 1, timeout=RESULT_TIMEOUT_MS) page.get_by_test_id("chat-detail-multimodel-back").click() - expect(page.get_by_test_id("chat-detail-multimodel-root")).not_to_be_visible( - timeout=RESULT_TIMEOUT_MS - ) + expect(page.get_by_test_id("chat-detail-multimodel-root")).not_to_be_visible(timeout=RESULT_TIMEOUT_MS) expect(page.get_by_test_id("chat-detail")).to_be_visible(timeout=RESULT_TIMEOUT_MS) expect(page.get_by_test_id("chat-textarea")).to_be_visible(timeout=RESULT_TIMEOUT_MS) diff --git a/test/playwright/e2e/test_next_apps_search.py b/test/playwright/e2e/test_next_apps_search.py index 7fbbe70ea4..142bda63e2 100644 --- a/test/playwright/e2e/test_next_apps_search.py +++ b/test/playwright/e2e/test_next_apps_search.py @@ -86,9 +86,7 @@ def step_02_open_search_list( with step("open search list"): _goto_home(page, base_url) _nav_click(page, "nav-search") - expect(page.locator("[data-testid='search-list']")).to_be_visible( - timeout=RESULT_TIMEOUT_MS - ) + expect(page.locator("[data-testid='search-list']")).to_be_visible(timeout=RESULT_TIMEOUT_MS) snap("search_list_open") @@ -129,9 +127,7 @@ def step_04_create_search( with step("create search app"): _fill_and_save_create_modal(page, search_name) _wait_for_url_or_testid(page, r"/next-search/", "search-detail") - expect(page.locator("[data-testid='search-detail']")).to_be_visible( - timeout=RESULT_TIMEOUT_MS - ) + expect(page.locator("[data-testid='search-detail']")).to_be_visible(timeout=RESULT_TIMEOUT_MS) flow_state["search_created"] = True snap("search_created") diff --git a/test/playwright/helpers/_auth_helpers.py b/test/playwright/helpers/_auth_helpers.py index 7c7b31474f..ca58722cae 100644 --- a/test/playwright/helpers/_auth_helpers.py +++ b/test/playwright/helpers/_auth_helpers.py @@ -31,9 +31,7 @@ def ensure_authed( email, password = seeded_user_credentials else: email = os.getenv("SEEDED_USER_EMAIL") or os.getenv("E2E_ADMIN_EMAIL") - password = os.getenv("SEEDED_USER_PASSWORD") or os.getenv( - "E2E_ADMIN_PASSWORD" - ) + password = os.getenv("SEEDED_USER_PASSWORD") or os.getenv("E2E_ADMIN_PASSWORD") if not email or not password: pytest.skip("SEEDED_USER_EMAIL/SEEDED_USER_PASSWORD not set.") @@ -58,25 +56,17 @@ def ensure_authed( return form, _ = active_auth_context() - email_input = form.locator( - "input[data-testid='auth-email'], [data-testid='auth-email'] input" - ) - password_input = form.locator( - "input[data-testid='auth-password'], [data-testid='auth-password'] input" - ) + email_input = form.locator("input[data-testid='auth-email'], [data-testid='auth-email'] input") + password_input = form.locator("input[data-testid='auth-password'], [data-testid='auth-password'] input") expect(email_input).to_have_count(1) expect(password_input).to_have_count(1) email_input.fill(email) password_input.fill(password) password_input.blur() - submit_button = form.locator( - "button[data-testid='auth-submit'], [data-testid='auth-submit'] button, [data-testid='auth-submit']" - ) + submit_button = form.locator("button[data-testid='auth-submit'], [data-testid='auth-submit'] button, [data-testid='auth-submit']") expect(submit_button).to_have_count(1) auth_click(submit_button, "submit_login") _wait_for_login_complete(page, timeout_ms=timeout_ms) - expect(page.locator("form[data-testid='auth-form'][data-active='true']")).to_have_count( - 0, timeout=timeout_ms - ) + expect(page.locator("form[data-testid='auth-form'][data-active='true']")).to_have_count(0, timeout=timeout_ms) diff --git a/test/playwright/helpers/_next_apps_helpers.py b/test/playwright/helpers/_next_apps_helpers.py index 3300022757..4c20c54c15 100644 --- a/test/playwright/helpers/_next_apps_helpers.py +++ b/test/playwright/helpers/_next_apps_helpers.py @@ -15,10 +15,7 @@ def _unique_name(prefix: str) -> str: def _assert_not_on_login(page) -> None: if "/login" in page.url or page.locator("input[autocomplete='email']").count() > 0: - raise AssertionError( - "Expected authenticated session; landed on /login. " - "Ensure ensure_authed(...) was called and credentials are set." - ) + raise AssertionError("Expected authenticated session; landed on /login. Ensure ensure_authed(...) was called and credentials are set.") def _goto_home(page, base_url: str) -> None: @@ -71,9 +68,7 @@ def _nav_click(page, testid: str) -> None: else: fallback = page.get_by_text(pattern) if fallback.count() == 0: - fallback = page.locator("button, [role='button'], a, span, div").filter( - has_text=pattern - ) + fallback = page.locator("button, [role='button'], a, span, div").filter(has_text=pattern) expect(fallback.first).to_be_visible(timeout=RESULT_TIMEOUT_MS) fallback.first.click() _ensure_expected_path() @@ -107,9 +102,7 @@ def _open_create_from_list( pattern = create_text_map.get(create_btn_testid) clicked = False if pattern: - fallback_btn = page.get_by_role( - "button", name=re.compile(pattern, re.I) - ) + fallback_btn = page.get_by_role("button", name=re.compile(pattern, re.I)) if fallback_btn.count() > 0 and fallback_btn.first.is_visible(): fallback_btn.first.click() clicked = True @@ -122,17 +115,13 @@ def _open_create_from_list( } empty_pattern = empty_text_map.get(empty_testid) if empty_pattern: - empty_state = page.locator("div, section, article").filter( - has_text=re.compile(empty_pattern, re.I) - ) + empty_state = page.locator("div, section, article").filter(has_text=re.compile(empty_pattern, re.I)) if empty_state.count() > 0 and empty_state.first.is_visible(): empty_state.first.click() clicked = True if not clicked: - fallback_card = page.locator( - ".border-dashed, [class*='border-dashed']" - ).first + fallback_card = page.locator(".border-dashed, [class*='border-dashed']").first expect(fallback_card).to_be_visible(timeout=RESULT_TIMEOUT_MS) fallback_card.click() if modal_testid == "agent-create-modal": @@ -215,9 +204,7 @@ def _select_first_dataset_and_save( combo = search_scope.locator(f"[data-testid='{combobox_testid}']") if combo.count() > 0: return combo - combo = search_scope.locator("[role='combobox']").filter( - has_text=re.compile(r"select|dataset|please", re.I) - ) + combo = search_scope.locator("[role='combobox']").filter(has_text=re.compile(r"select|dataset|please", re.I)) if combo.count() > 0: return combo return search_scope.locator("[role='combobox']") @@ -242,9 +229,7 @@ def _select_first_dataset_and_save( settings_button.first.click() break - settings_dialog = page.locator("[role='dialog']").filter( - has_text=re.compile(r"settings", re.I) - ) + settings_dialog = page.locator("[role='dialog']").filter(has_text=re.compile(r"settings", re.I)) if settings_dialog.count() > 0 and settings_dialog.first.is_visible(): scope_root = settings_dialog.first combobox = _find_dataset_combobox(scope_root) @@ -261,13 +246,9 @@ def _select_first_dataset_and_save( save_button = scope_root.locator(f"[data-testid='{save_testid}']") if save_button.count() == 0: - save_button = scope_root.get_by_role( - "button", name=re.compile(r"^save$", re.I) - ) + save_button = scope_root.get_by_role("button", name=re.compile(r"^save$", re.I)) if save_button.count() == 0: - save_button = scope_root.locator( - "button[type='submit']", has_text=re.compile(r"^save$", re.I) - ).first + save_button = scope_root.locator("button[type='submit']", has_text=re.compile(r"^save$", re.I)).first save_button = save_button.first expect(save_button).to_be_visible(timeout=timeout_ms) @@ -294,10 +275,7 @@ def _select_first_dataset_and_save( last_list_text = list_locator.inner_text() or "" except Exception: last_list_text = "" - raise AssertionError( - "Dataset option popover did not open. " - f"combobox_testid={combobox_testid!r} last_list_text={last_list_text[:200]!r}" - ) + raise AssertionError(f"Dataset option popover did not open. combobox_testid={combobox_testid!r} last_list_text={last_list_text[:200]!r}") def _pick_first_dataset_option(options_root) -> bool: search_input = options_root.locator("[cmdk-input], input[placeholder*='Search']").first @@ -327,12 +305,7 @@ def _select_first_dataset_and_save( text = (candidate.inner_text() or "").strip().lower() except Exception: continue - if ( - not text - or "no results found" in text - or text == "close" - or text == "clear" - ): + if not text or "no results found" in text or text == "close" or text == "clear": continue for _ in range(3): try: @@ -378,9 +351,7 @@ def _select_first_dataset_and_save( kb_ids = payload.get("kb_ids") return isinstance(kb_ids, list) and len(kb_ids) > 0 - response_url_pattern = ( - "/api/v1/chats" if save_testid == "chat-settings-save" else "/api/v1/searches/" - ) + response_url_pattern = "/api/v1/chats" if save_testid == "chat-settings-save" else "/api/v1/searches/" last_payload = {} last_combobox_text = "" last_list_text = "" @@ -388,10 +359,7 @@ def _select_first_dataset_and_save( options, last_list_text = _open_dataset_options() clicked = _pick_first_dataset_option(options) if not clicked: - raise AssertionError( - "Failed to select dataset option after retries. " - f"list_text={last_list_text[:200]!r}" - ) + raise AssertionError(f"Failed to select dataset option after retries. list_text={last_list_text[:200]!r}") page.wait_for_timeout(120) try: @@ -404,8 +372,7 @@ def _select_first_dataset_and_save( response = capture_response( page, lambda: save_button.click(), - lambda resp: response_url_pattern in resp.url - and resp.request.method in ("POST", "PUT", "PATCH"), + lambda resp: response_url_pattern in resp.url and resp.request.method in ("POST", "PUT", "PATCH"), timeout_ms=response_timeout_ms, ) except Exception: @@ -432,15 +399,11 @@ def _select_first_dataset_and_save( page.wait_for_timeout(200 * (attempt + 1)) raise AssertionError( - "Dataset selection did not persist in save payload. " - f"save_testid={save_testid!r} payload={last_payload!r} " - f"combobox_text={last_combobox_text!r} list_text={last_list_text[:200]!r}" + f"Dataset selection did not persist in save payload. save_testid={save_testid!r} payload={last_payload!r} combobox_text={last_combobox_text!r} list_text={last_list_text[:200]!r}" ) -def _send_chat_and_wait_done( - page, text: str, timeout_ms: int = 60000 -) -> None: +def _send_chat_and_wait_done(page, text: str, timeout_ms: int = 60000) -> None: textarea = page.locator("[data-testid='chat-textarea']") expect(textarea).to_be_visible(timeout=RESULT_TIMEOUT_MS) tag_name = "" @@ -457,10 +420,7 @@ def _send_chat_and_wait_done( is_input = tag_name in ("INPUT", "TEXTAREA") is_editable = is_input or contenteditable == "true" if not is_editable: - raise AssertionError( - "chat-textarea is not an editable element. " - f"url={page.url} tag={tag_name!r} contenteditable={contenteditable!r}" - ) + raise AssertionError(f"chat-textarea is not an editable element. url={page.url} tag={tag_name!r} contenteditable={contenteditable!r}") textarea.fill(text) typed_value = "" @@ -484,11 +444,7 @@ def _send_chat_and_wait_done( except Exception: typed_value = "" if text not in (typed_value or ""): - raise AssertionError( - "Failed to type prompt into chat-textarea. " - f"url={page.url} tag={tag_name!r} contenteditable={contenteditable!r} " - f"typed_value={typed_value!r}" - ) + raise AssertionError(f"Failed to type prompt into chat-textarea. url={page.url} tag={tag_name!r} contenteditable={contenteditable!r} typed_value={typed_value!r}") composer = textarea.locator("xpath=ancestor::form[1]") if composer.count() == 0: @@ -496,13 +452,9 @@ def _send_chat_and_wait_done( send_button = None if composer.count() > 0: if hasattr(composer, "get_by_role"): - send_button = composer.get_by_role( - "button", name=re.compile(r"send message", re.I) - ) + send_button = composer.get_by_role("button", name=re.compile(r"send message", re.I)) if send_button is None or send_button.count() == 0: - send_button = composer.locator( - "button", has_text=re.compile(r"send message", re.I) - ) + send_button = composer.locator("button", has_text=re.compile(r"send message", re.I)) if send_button is not None and send_button.count() > 0: send_button.first.click() send_used = True @@ -512,15 +464,11 @@ def _send_chat_and_wait_done( status_marker = page.locator("[data-testid='chat-stream-status']").first try: - expect(status_marker).to_have_attribute( - "data-status", "idle", timeout=timeout_ms - ) + expect(status_marker).to_have_attribute("data-status", "idle", timeout=timeout_ms) except Exception as exc: try: # Some UI builds remove the stream-status marker when generation finishes. - expect(page.locator("[data-testid='chat-stream-status']")).to_have_count( - 0, timeout=timeout_ms - ) + expect(page.locator("[data-testid='chat-stream-status']")).to_have_count(0, timeout=timeout_ms) return except Exception: pass @@ -545,9 +493,7 @@ def _wait_for_url_regex(page, pattern: str, timeout_ms: int = RESULT_TIMEOUT_MS) page.wait_for_url(regex, wait_until="commit", timeout=timeout_ms) -def _wait_for_url_or_testid( - page, url_regex: str, testid: str, timeout_ms: int = RESULT_TIMEOUT_MS -) -> str: +def _wait_for_url_or_testid(page, url_regex: str, testid: str, timeout_ms: int = RESULT_TIMEOUT_MS) -> str: end_time = time.time() + (timeout_ms / 1000) regex = re.compile(url_regex) locator = page.locator(f"[data-testid='{testid}']") @@ -563,6 +509,4 @@ def _wait_for_url_or_testid( except Exception: pass page.wait_for_timeout(100) - raise AssertionError( - f"Timed out waiting for url {url_regex!r} or testid {testid!r}. url={page.url}" - ) + raise AssertionError(f"Timed out waiting for url {url_regex!r} or testid {testid!r}. url={page.url}") diff --git a/test/playwright/helpers/auth_selectors.py b/test/playwright/helpers/auth_selectors.py index 51336a500b..2e3565eb72 100644 --- a/test/playwright/helpers/auth_selectors.py +++ b/test/playwright/helpers/auth_selectors.py @@ -7,10 +7,7 @@ EMAIL_INPUT = "input[data-testid='auth-email'], [data-testid='auth-email'] input PASSWORD_INPUT = "input[data-testid='auth-password'], [data-testid='auth-password'] input" NICKNAME_INPUT = "input[data-testid='auth-nickname'], [data-testid='auth-nickname'] input" -SUBMIT_BUTTON = ( - "button[data-testid='auth-submit'], [data-testid='auth-submit'] button, " - "[data-testid='auth-submit']" -) +SUBMIT_BUTTON = "button[data-testid='auth-submit'], [data-testid='auth-submit'] button, [data-testid='auth-submit']" REGISTER_TAB = "[data-testid='auth-toggle-register']" LOGIN_TAB = "[data-testid='auth-toggle-login']" diff --git a/test/playwright/helpers/auth_waits.py b/test/playwright/helpers/auth_waits.py index 31fae9b542..1ac34a940e 100644 --- a/test/playwright/helpers/auth_waits.py +++ b/test/playwright/helpers/auth_waits.py @@ -1,4 +1,3 @@ - from playwright.sync_api import TimeoutError as PlaywrightTimeoutError try: @@ -37,6 +36,4 @@ def wait_for_login_complete(page, timeout_ms: int | None = None) -> None: ) except Exception: testids = [] - raise AssertionError( - f"Login did not complete within {timeout_ms}ms. url={url} auth_testids={testids}" - ) from exc + raise AssertionError(f"Login did not complete within {timeout_ms}ms. url={url} auth_testids={testids}") from exc diff --git a/test/playwright/helpers/datasets.py b/test/playwright/helpers/datasets.py index 89f832aa0a..d8bd5986b2 100644 --- a/test/playwright/helpers/datasets.py +++ b/test/playwright/helpers/datasets.py @@ -47,12 +47,8 @@ def wait_for_dataset_detail_ready(page, expect, timeout_ms: int) -> None: if env_bool("PW_DEBUG_DUMP"): url = page.url button_count = page.locator("button, [role='button']").count() - body_text = page.evaluate( - "(() => (document.body && document.body.innerText) || '')()" - ) - debug( - f"[dataset] detail_ready_failed url={url} button_count={button_count}" - ) + body_text = page.evaluate("(() => (document.body && document.body.innerText) || '')()") + debug(f"[dataset] detail_ready_failed url={url} button_count={button_count}") debug(f"[dataset] body_text_snippet={body_text[:200]!r}") raise @@ -114,9 +110,7 @@ def ensure_upload_modal_open(page, expect, auth_click, timeout_ms: int): return modal except AssertionError: pass - return open_upload_modal_from_dataset_detail( - page, expect, auth_click, timeout_ms=timeout_ms - ) + return open_upload_modal_from_dataset_detail(page, expect, auth_click, timeout_ms=timeout_ms) def ensure_parse_on(upload_modal, expect) -> None: @@ -136,9 +130,7 @@ def open_upload_modal_from_dataset_detail(page, expect, auth_click, timeout_ms: page.wait_for_selector("button", timeout=timeout_ms) if hasattr(page, "get_by_role"): - tab_locator = page.get_by_role( - "tab", name=re.compile(r"^(files|documents|file)$", re.I) - ) + tab_locator = page.get_by_role("tab", name=re.compile(r"^(files|documents|file)$", re.I)) if tab_locator.count() > 0: tab = tab_locator.first try: @@ -148,16 +140,12 @@ def open_upload_modal_from_dataset_detail(page, expect, auth_click, timeout_ms: except Exception: pass - candidate_names = re.compile( - r"(upload file|upload|add file|add document|add|new)", re.I - ) + candidate_names = re.compile(r"(upload file|upload|add file|add document|add|new)", re.I) trigger_locator = None if hasattr(page, "get_by_role"): trigger_locator = page.get_by_role("button", name=candidate_names) if trigger_locator is None or trigger_locator.count() == 0: - trigger_locator = page.locator("[role='button'], button, a").filter( - has_text=candidate_names - ) + trigger_locator = page.locator("[role='button'], button, a").filter(has_text=candidate_names) trigger = None if trigger_locator.count() > 0: @@ -172,9 +160,7 @@ def open_upload_modal_from_dataset_detail(page, expect, auth_click, timeout_ms: continue if trigger is None: - aria_candidates = page.locator( - "button[aria-label], button[title], [role='button'][aria-label], [role='button'][title]" - ) + aria_candidates = page.locator("button[aria-label], button[title], [role='button'][aria-label], [role='button'][title]") limit = min(aria_candidates.count(), 10) for idx in range(limit): candidate = aria_candidates.nth(idx) @@ -215,13 +201,8 @@ def open_upload_modal_from_dataset_detail(page, expect, auth_click, timeout_ms: title = item.get_attribute("title") except Exception as exc: title = f"" - button_dump.append( - {"text": text, "aria_label": aria_label, "title": title} - ) - raise AssertionError( - "Upload entrypoint not found on dataset detail page. " - f"visible_buttons={button_dump}" - ) + button_dump.append({"text": text, "aria_label": aria_label, "title": title}) + raise AssertionError(f"Upload entrypoint not found on dataset detail page. visible_buttons={button_dump}") try: if trigger.evaluate("el => el.tagName.toLowerCase() === 'button'"): @@ -233,15 +214,9 @@ def open_upload_modal_from_dataset_detail(page, expect, auth_click, timeout_ms: def _click_upload_file_popover_item() -> bool: locators = [ - page.locator("[role='menuitem']").filter( - has_text=re.compile(r"^upload file$", re.I) - ), - page.locator("[role='option']").filter( - has_text=re.compile(r"^upload file$", re.I) - ), - page.locator("div, span, li").filter( - has_text=re.compile(r"^upload file$", re.I) - ), + page.locator("[role='menuitem']").filter(has_text=re.compile(r"^upload file$", re.I)), + page.locator("[role='option']").filter(has_text=re.compile(r"^upload file$", re.I)), + page.locator("div, span, li").filter(has_text=re.compile(r"^upload file$", re.I)), ] for locator in locators: if locator.count() == 0: @@ -278,9 +253,7 @@ def open_upload_modal_from_dataset_detail(page, expect, auth_click, timeout_ms: has_upload_text = page.locator("text=/upload file/i").count() > 0 debug(f"[dataset] upload_item_missing has_upload_text={has_upload_text}") debug(f"[dataset] visible_button_texts={button_texts}") - raise AssertionError( - "Upload file popover item not found after clicking Add trigger." - ) + raise AssertionError("Upload file popover item not found after clicking Add trigger.") try: page.wait_for_load_state("domcontentloaded", timeout=timeout_ms) @@ -296,9 +269,7 @@ def select_chunking_method_general(page, expect, modal, timeout_ms: int) -> None """Select the General chunking method inside the dataset modal.""" trigger_locator = modal.locator( "button", - has=modal.locator( - "span", has_text=re.compile(r"please select a chunking method\\.", re.I) - ), + has=modal.locator("span", has_text=re.compile(r"please select a chunking method\\.", re.I)), ).first if trigger_locator.count() == 0: label = modal.locator("text=/please select a chunking method\\./i").first @@ -314,15 +285,8 @@ def select_chunking_method_general(page, expect, modal, timeout_ms: int) -> None if env_bool("PW_DEBUG_DUMP"): modal_text = modal.inner_text() button_count = modal.locator("button").count() - label_count = modal.locator( - "text=/please select a chunking method\\./i" - ).count() - debug( - "[dataset] chunking_trigger_missing " - f"button_count={button_count} label_count={label_count} " - f"trigger_locator_count={trigger_locator.count()} " - "trigger_handle_found=False" - ) + label_count = modal.locator("text=/please select a chunking method\\./i").count() + debug(f"[dataset] chunking_trigger_missing button_count={button_count} label_count={label_count} trigger_locator_count={trigger_locator.count()} trigger_handle_found=False") debug(f"[dataset] modal_text_snippet={modal_text[:300]!r}") raise AssertionError("Chunking method dropdown trigger not found.") @@ -342,29 +306,20 @@ def select_chunking_method_general(page, expect, modal, timeout_ms: int) -> None option = listbox.locator("span", has_text=re.compile(r"^General$", re.I)).first if option.count() == 0: - option = listbox.locator( - "div", has=page.locator("span", has_text=re.compile(r"^General$", re.I)) - ).first + option = listbox.locator("div", has=page.locator("span", has_text=re.compile(r"^General$", re.I))).first if option.count() == 0 and env_bool("PW_DEBUG_DUMP"): try: listbox_text = listbox.inner_text() except Exception: listbox_text = "" - span_count = listbox.locator( - "span", has_text=re.compile(r"^General$", re.I) - ).count() - debug( - "[dataset] general_option_missing " - f"listbox_count={listbox.count()} span_count={span_count}" - ) + span_count = listbox.locator("span", has_text=re.compile(r"^General$", re.I)).count() + debug(f"[dataset] general_option_missing listbox_count={listbox.count()} span_count={span_count}") debug(f"[dataset] listbox_text_snippet={listbox_text[:300]!r}") expect(option).to_be_visible(timeout=timeout_ms) option.click() if trigger_for_assert is not None: try: - expect(trigger_for_assert).to_contain_text( - re.compile(r"General", re.I), timeout=timeout_ms - ) + expect(trigger_for_assert).to_contain_text(re.compile(r"General", re.I), timeout=timeout_ms) except AssertionError: # Trigger can rerender after selection; verify selected label in modal instead. expect(modal).to_contain_text(re.compile(r"General", re.I), timeout=timeout_ms) @@ -386,9 +341,7 @@ def open_create_dataset_modal(page, expect, timeout_ms: int): except PlaywrightTimeoutError: if env_bool("PW_DEBUG_DUMP"): url = page.url - body_text = page.evaluate( - "(() => (document.body && document.body.innerText) || '')()" - ) + body_text = page.evaluate("(() => (document.body && document.body.innerText) || '')()") lines = body_text.splitlines() snippet = "\n".join(lines[:20])[:500] debug(f"[dataset] entrypoint_wait_timeout url={url} snippet={snippet!r}") @@ -399,11 +352,7 @@ def open_create_dataset_modal(page, expect, timeout_ms: int): locator.click() except Exception as exc: message = str(exc).lower() - if ( - "not attached to the dom" not in message - and "intercepts pointer events" not in message - and "element is not stable" not in message - ): + if "not attached to the dom" not in message and "intercepts pointer events" not in message and "element is not stable" not in message: raise locator.click(force=True) @@ -413,15 +362,11 @@ def open_create_dataset_modal(page, expect, timeout_ms: int): if hasattr(page, "get_by_role"): create_btn = page.get_by_role("button", name=re.compile(r"create dataset", re.I)) if create_btn is None or create_btn.count() == 0: - create_btn = page.locator( - "button", has_text=re.compile(r"create dataset", re.I) - ).first + create_btn = page.locator("button", has_text=re.compile(r"create dataset", re.I)).first if create_btn.count() == 0: if env_bool("PW_DEBUG_DUMP"): url = page.url - body_text = page.evaluate( - "(() => (document.body && document.body.innerText) || '')()" - ) + body_text = page.evaluate("(() => (document.body && document.body.innerText) || '')()") lines = body_text.splitlines() snippet = "\n".join(lines[:20])[:500] debug(f"[dataset] entrypoint_not_found url={url} snippet={snippet!r}") @@ -433,9 +378,7 @@ def open_create_dataset_modal(page, expect, timeout_ms: int): except AssertionError: if env_bool("PW_DEBUG_DUMP"): url = page.url - body_text = page.evaluate( - "(() => (document.body && document.body.innerText) || '')()" - ) + body_text = page.evaluate("(() => (document.body && document.body.innerText) || '')()") lines = body_text.splitlines() snippet = "\n".join(lines[:20])[:500] debug(f"[dataset] entrypoint_not_found url={url} snippet={snippet!r}") @@ -446,9 +389,7 @@ def open_create_dataset_modal(page, expect, timeout_ms: int): if empty_text.count() > 0: debug("[dataset] using empty-state entrypoint") expect(empty_text).to_be_visible(timeout=5000) - entrypoint = empty_text.locator( - "xpath=ancestor-or-self::*[self::button or self::a or @role='button'][1]" - ) + entrypoint = empty_text.locator("xpath=ancestor-or-self::*[self::button or self::a or @role='button'][1]") if entrypoint.count() > 0: expect(entrypoint.first).to_be_visible(timeout=5000) _click_entrypoint(entrypoint.first) @@ -482,17 +423,13 @@ def delete_uploaded_file(page, expect, filename: str, timeout_ms: int) -> None: if by_testid.count() > 0: return by_testid.first - by_label = confirm.locator( - "button:visible", has_text=re.compile("^delete$", re.I) - ) + by_label = confirm.locator("button:visible", has_text=re.compile("^delete$", re.I)) if by_label.count() > 0: return by_label.first return confirm.locator("button:visible").last - row = page.locator( - f"[data-testid='document-row'][data-doc-name={json.dumps(filename)}]" - ) + row = page.locator(f"[data-testid='document-row'][data-doc-name={json.dumps(filename)}]") expect(row).to_be_visible(timeout=timeout_ms) delete_button = row.locator("[data-testid='document-delete']") expect(delete_button).to_be_visible(timeout=timeout_ms) diff --git a/test/playwright/helpers/model_providers.py b/test/playwright/helpers/model_providers.py index 81b63f0b5b..c3d61b9b2b 100644 --- a/test/playwright/helpers/model_providers.py +++ b/test/playwright/helpers/model_providers.py @@ -119,18 +119,11 @@ def _assert_selected_option_value( return if _has_malformed_model_suffix(selected_value): - raise AssertionError( - "Selected combobox option contains malformed model suffix '#': " - f"value={selected_value!r} option_text={option_text!r}" - ) + raise AssertionError(f"Selected combobox option contains malformed model suffix '#': value={selected_value!r} option_text={option_text!r}") expected_prefix = _clean_text(expected_value_prefix) if expected_prefix and not selected_value.lower().startswith(expected_prefix.lower()): - raise AssertionError( - "Selected combobox option does not match expected canonical prefix: " - f"expected_prefix={expected_prefix!r} selected_value={selected_value!r} " - f"option_text={option_text!r}" - ) + raise AssertionError(f"Selected combobox option does not match expected canonical prefix: expected_prefix={expected_prefix!r} selected_value={selected_value!r} option_text={option_text!r}") def click_with_retry(page, expect, locator_factory, attempts: int, timeout_ms: int) -> None: @@ -165,9 +158,7 @@ def select_cmdk_option_by_value_prefix( controls_id = combobox.get_attribute("aria-controls") options_container = None - option_selector = ( - "[data-testid='combobox-option'], [role='option'], [cmdk-item], [data-value]" - ) + option_selector = "[data-testid='combobox-option'], [role='option'], [cmdk-item], [data-value]" if controls_id: controls_selector = f"[id={json.dumps(controls_id)}]:visible" @@ -190,11 +181,7 @@ def select_cmdk_option_by_value_prefix( return page.locator(option_selector) def option_locator(): - by_value = ( - options_container.locator(value_selector) - if options_container is not None - else page.locator(f"{value_selector}:visible") - ) + by_value = options_container.locator(value_selector) if options_container is not None else page.locator(f"{value_selector}:visible") if by_value.count() > 0: return by_value.first return options_locator().filter(has_text=option_pattern).first @@ -218,18 +205,12 @@ def select_cmdk_option_by_value_prefix( selected_value = None click_with_retry(page, expect, lambda: first_option, attempts=3, timeout_ms=timeout_ms) if selected_text: - expect(combobox).to_contain_text( - selected_text, timeout=timeout_ms - ) + expect(combobox).to_contain_text(selected_text, timeout=timeout_ms) try: - expect(combobox).to_have_attribute( - "aria-expanded", "false", timeout=timeout_ms - ) + expect(combobox).to_have_attribute("aria-expanded", "false", timeout=timeout_ms) except AssertionError: page.keyboard.press("Escape") - expect(combobox).to_have_attribute( - "aria-expanded", "false", timeout=timeout_ms - ) + expect(combobox).to_have_attribute("aria-expanded", "false", timeout=timeout_ms) return selected_text or option_text, selected_value dump = [] count = min(options.count(), 30) @@ -306,8 +287,7 @@ def select_default_model( capture_response( page, trigger, - lambda resp: resp.request.method == "PATCH" - and "/api/v1/users/me/models" in resp.url, + lambda resp: resp.request.method == "PATCH" and "/api/v1/users/me/models" in resp.url, ) except PlaywrightTimeoutError: if not selected[0]: @@ -322,8 +302,5 @@ def select_default_model( except Exception: current_text = expected_text if _has_malformed_model_suffix(current_text): - raise AssertionError( - "Combobox text still contains malformed model suffix '#': " - f"text={current_text!r} expected={expected_text!r}" - ) + raise AssertionError(f"Combobox text still contains malformed model suffix '#': text={current_text!r} expected={expected_text!r}") return selected diff --git a/test/playwright/helpers/response_capture.py b/test/playwright/helpers/response_capture.py index f7ad33c6f6..c15257cb1e 100644 --- a/test/playwright/helpers/response_capture.py +++ b/test/playwright/helpers/response_capture.py @@ -1,4 +1,3 @@ - try: from test.playwright.helpers._auth_helpers import RESULT_TIMEOUT_MS as DEFAULT_TIMEOUT_MS except Exception: @@ -12,9 +11,7 @@ def capture_response(page, trigger, predicate, timeout_ms: int = DEFAULT_TIMEOUT trigger() return response_info.value if hasattr(page, "expect_event"): - with page.expect_event( - "response", predicate=predicate, timeout=timeout_ms - ) as response_info: + with page.expect_event("response", predicate=predicate, timeout=timeout_ms) as response_info: trigger() return response_info.value if hasattr(page, "wait_for_event"): @@ -23,9 +20,7 @@ def capture_response(page, trigger, predicate, timeout_ms: int = DEFAULT_TIMEOUT raise RuntimeError("Playwright Page lacks expect_response/expect_event/wait_for_event.") -def capture_response_json( - page, trigger, predicate, timeout_ms: int = DEFAULT_TIMEOUT_MS -) -> dict: +def capture_response_json(page, trigger, predicate, timeout_ms: int = DEFAULT_TIMEOUT_MS) -> dict: response = capture_response(page, trigger, predicate, timeout_ms) info: dict = {"__url__": response.url, "__status__": response.status} try: diff --git a/test/testcases/conftest.py b/test/testcases/conftest.py index 841265dd66..85c44ca5ae 100644 --- a/test/testcases/conftest.py +++ b/test/testcases/conftest.py @@ -180,10 +180,7 @@ def get_added_models(auth, factory_name): # Go server (post-Python port) serializes this field as `model_provider` # in the RESTful `/api/v1/models` response. Fall back to the legacy # `provider_name` key so this conftest works against both. - added_factory = { - model.get("model_provider") or model["provider_name"] - for model in res.get("data", []) - } + added_factory = {model.get("model_provider") or model["provider_name"] for model in res.get("data", [])} if factory_name in added_factory: return True return False @@ -225,12 +222,7 @@ def add_model_instance(auth): # and BAAI/bge-reranker-v2-m3@CI@SILICONFLOW). instance_name = "CI" add_instance_api = HOST_ADDRESS + f"/api/v1/providers/{provider_name}/instances" - add_instance_response = requests.post(url=add_instance_api, headers=authorization, json={ - "instance_name": instance_name, - "api_key": api_key, - "region": "default", - "base_url": "" - }) + add_instance_response = requests.post(url=add_instance_api, headers=authorization, json={"instance_name": instance_name, "api_key": api_key, "region": "default", "base_url": ""}) add_instance_res = add_instance_response.json() if add_instance_res.get("code") != 0: msg = add_instance_res.get("message", "") @@ -248,10 +240,7 @@ def add_model_instance(auth): if "cannot be 'default'" in msg: print("Note: model instance name is reserved, skipping") continue - pytest.exit( - f"Critical error in add model instance {provider_name}/{instance_name}: " - f"{msg}" - ) + pytest.exit(f"Critical error in add model instance {provider_name}/{instance_name}: {msg}") add_success = get_added_models(auth, provider_name) if not add_success: @@ -262,10 +251,7 @@ def add_model_instance(auth): # on PUT. Downgrade to a warning so tests that don't depend # on the model can still run; tests that do will fail with # a real error rather than this opaque setup crash. - print( - "WARNING: provider already exists in catalog but missing from " - "this tenant's /api/v1/models. Tests that depend on it may fail." - ) + print("WARNING: provider already exists in catalog but missing from this tenant's /api/v1/models. Tests that depend on it may fail.") continue pytest.exit(f"Critical error in check added model: {provider_name} add model failed") @@ -280,15 +266,7 @@ def set_tenant_info(auth): url = HOST_ADDRESS + "/api/v1/models/default" authorization = {"Authorization": auth} # set chat model - set_default_llm_response = requests.patch( - url=url, - headers=authorization, - json={ - "model_provider": "ZHIPU-AI", - "model_instance": "CI", - "model_type": "chat", - "model_name": "glm-4-flash" - }) + set_default_llm_response = requests.patch(url=url, headers=authorization, json={"model_provider": "ZHIPU-AI", "model_instance": "CI", "model_type": "chat", "model_name": "glm-4-flash"}) llm_res = set_default_llm_response.json() if llm_res.get("code") != 0: # The Go server (post-Python port) doesn't yet implement @@ -296,40 +274,18 @@ def set_tenant_info(auth): # can't be set via API. Downgrade to a warning so tests that # don't rely on a default LLM can still run; tests that do # will fail with their own real error. - print( - f"WARNING: failed to set default chat LLM via {url}: " - f"{llm_res.get('message')!r}. Continuing." - ) + print(f"WARNING: failed to set default chat LLM via {url}: {llm_res.get('message')!r}. Continuing.") # set embedding model set_default_embedding_response = requests.patch( - url=url, - headers=authorization, - json={ - "model_provider": "Builtin", - "model_instance": "Local", - "model_type": "embedding", - "model_name": "BAAI/bge-small-en-v1.5" - }) + url=url, headers=authorization, json={"model_provider": "Builtin", "model_instance": "Local", "model_type": "embedding", "model_name": "BAAI/bge-small-en-v1.5"} + ) embd_res = set_default_embedding_response.json() if embd_res.get("code") != 0: - print( - f"WARNING: failed to set default embedding LLM via {url}: " - f"{embd_res.get('message')!r}. Continuing." - ) + print(f"WARNING: failed to set default embedding LLM via {url}: {embd_res.get('message')!r}. Continuing.") # set rerank model set_default_rerank_response = requests.patch( - url=url, - headers=authorization, - json={ - "model_provider": "SILICONFLOW", - "model_instance": "CI", - "model_type": "rerank", - "model_name": "BAAI/bge-reranker-v2-m3" - } + url=url, headers=authorization, json={"model_provider": "SILICONFLOW", "model_instance": "CI", "model_type": "rerank", "model_name": "BAAI/bge-reranker-v2-m3"} ) rerank_res = set_default_rerank_response.json() if rerank_res.get("code") != 0: - print( - f"WARNING: failed to set default rerank LLM via {url}: " - f"{rerank_res.get('message')!r}. Continuing." - ) + print(f"WARNING: failed to set default rerank LLM via {url}: {rerank_res.get('message')!r}. Continuing.") diff --git a/test/testcases/restful_api/test_agents.py b/test/testcases/restful_api/test_agents.py index 6f39d6800b..e9343cf718 100644 --- a/test/testcases/restful_api/test_agents.py +++ b/test/testcases/restful_api/test_agents.py @@ -142,8 +142,7 @@ def test_agents_crud_validation_contract(rest_client, create_agent_resource): # code=103 = permission denied (Python: "Only the owner of the agent..."; Go: "Make sure you have permission...") assert invalid_delete_payload["code"] == 103, invalid_delete_payload msg = invalid_delete_payload["message"] - assert ("Only the owner of the agent is authorized" in msg - or "Make sure you have permission" in msg), invalid_delete_payload + assert "Only the owner of the agent is authorized" in msg or "Make sure you have permission" in msg, invalid_delete_payload delete_res = rest_client.delete(f"/agents/{agent_id}") assert delete_res.status_code == 200 diff --git a/test/testcases/restful_api/test_agents_go_mode.py b/test/testcases/restful_api/test_agents_go_mode.py index 41b0bd1e8b..5f3ce5fea4 100644 --- a/test/testcases/restful_api/test_agents_go_mode.py +++ b/test/testcases/restful_api/test_agents_go_mode.py @@ -10,21 +10,22 @@ pytestmark = pytest.mark.skipif( V2_DSL = { "graph": { "nodes": [ - {"id": "begin", "type": "beginNode", "position": {"x": 50, "y": 200}, - "data": {"label": "Begin", "name": "begin"}}, - {"id": "answer:0", "type": "messageNode", "position": {"x": 400, "y": 200}, - "data": {"label": "Answer", "name": "answer"}}, + {"id": "begin", "type": "beginNode", "position": {"x": 50, "y": 200}, "data": {"label": "Begin", "name": "begin"}}, + {"id": "answer:0", "type": "messageNode", "position": {"x": 400, "y": 200}, "data": {"label": "Answer", "name": "answer"}}, ], "edges": [ - {"id": "xy-edge__begin-answer:0", "source": "begin", "target": "answer:0", - "sourceHandle": "end", "targetHandle": "start"}, + {"id": "xy-edge__begin-answer:0", "source": "begin", "target": "answer:0", "sourceHandle": "end", "targetHandle": "start"}, ], }, "components": { - "begin": {"obj": {"component_name": "Begin", "params": {}}, "downstream": ["answer:0"], "upstream": []}, - "answer:0": {"obj": {"component_name": "Answer", "params": {}}, "downstream": [], "upstream": ["begin"]}, + "begin": {"obj": {"component_name": "Begin", "params": {}}, "downstream": ["answer:0"], "upstream": []}, + "answer:0": {"obj": {"component_name": "Answer", "params": {}}, "downstream": [], "upstream": ["begin"]}, }, - "retrieval": [], "history": [], "path": [], "variables": [], "globals": {"sys.query": ""}, + "retrieval": [], + "history": [], + "path": [], + "variables": [], + "globals": {"sys.query": ""}, } @@ -48,15 +49,10 @@ def test_v2_dsl_round_trip_position_preserved(rest_client): # 3. PUT — move begin node to (777, 888) new_dsl = {**V2_DSL, "graph": {**V2_DSL["graph"]}} - new_dsl["graph"]["nodes"] = [ - {**n, "position": {"x": 777, "y": 888}} if n["id"] == "begin" else n - for n in V2_DSL["graph"]["nodes"] - ] + new_dsl["graph"]["nodes"] = [{**n, "position": {"x": 777, "y": 888}} if n["id"] == "begin" else n for n in V2_DSL["graph"]["nodes"]] r = rest_client.put(f"/agents/{agent_id}", json={"title": title, "dsl": new_dsl}) assert r.status_code == 200 and r.json()["code"] == 0, r.text # 4. Re-GET — position preserved - pos = next(n["position"] for n in - rest_client.get(f"/agents/{agent_id}").json()["data"]["dsl"]["graph"]["nodes"] - if n["id"] == "begin") + pos = next(n["position"] for n in rest_client.get(f"/agents/{agent_id}").json()["data"]["dsl"]["graph"]["nodes"] if n["id"] == "begin") assert pos == {"x": 777, "y": 888}, f"position lost: {pos}" diff --git a/test/testcases/restful_api/test_chats.py b/test/testcases/restful_api/test_chats.py index 14f79ae4ea..28f213ac7c 100644 --- a/test/testcases/restful_api/test_chats.py +++ b/test/testcases/restful_api/test_chats.py @@ -35,8 +35,8 @@ from test.testcases.utils.file_utils import create_image_file DEFAULT_CHAT_EMPTY_RESPONSE = "Sorry! No relevant content was found in the knowledge base!" DEFAULT_CHAT_PROLOGUE = "Hi! I'm your assistant. What can I do for you?" DEFAULT_CHAT_SYSTEM_PROMPT = ( - 'You are an intelligent assistant. Please summarize the content of the dataset to answer the question. ' - 'Please list the data in the dataset and answer in detail. When all dataset content is irrelevant to the ' + "You are an intelligent assistant. Please summarize the content of the dataset to answer the question. " + "Please list the data in the dataset and answer in detail. When all dataset content is irrelevant to the " 'question, your answer must include the sentence "The answer you are looking for is not found in the dataset!" ' "Answers need to consider chat history.\n" " Here is the knowledge base:\n" @@ -72,7 +72,6 @@ def _reset_chat_batch(rest_client, prefix, count=5): return ids - @pytest.mark.p1 class TestChatsAuthorization: def test_create_requires_auth(self, rest_client_noauth): @@ -630,6 +629,7 @@ def _load_chat_routes_unit_module(monkeypatch): common_constants_mod.RetCode = _StubRetCode common_constants_mod.StatusEnum = _StubStatusEnum from common.constants import MAXIMUM_PAGE_NUMBER as _MPN, MAXIMUM_TASK_PAGE_NUMBER as _MTPN + common_constants_mod.MAXIMUM_PAGE_NUMBER = _MPN common_constants_mod.MAXIMUM_TASK_PAGE_NUMBER = _MTPN monkeypatch.setitem(sys.modules, "common.constants", common_constants_mod) @@ -765,7 +765,7 @@ def _load_chat_routes_unit_module(monkeypatch): tenant_model_service_mod.get_model_config_from_provider_instance = lambda *_args, **_kwargs: {} tenant_model_service_mod.get_tenant_default_model_by_type = lambda *_args, **_kwargs: {} tenant_model_service_mod.get_api_key = lambda *_args, **_kwargs: SimpleNamespace(id=1) - tenant_model_service_mod.split_model_name = lambda model: (model.split("@")[0],"default", "factory") + tenant_model_service_mod.split_model_name = lambda model: (model.split("@")[0], "default", "factory") monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) user_service_mod = ModuleType("api.db.services.user_service") @@ -811,7 +811,7 @@ def _load_chat_routes_unit_module(monkeypatch): api_utils_mod.get_json_result = lambda data=None, message="", code=0: {"code": code, "data": data, "message": message} api_utils_mod.get_request_json = lambda: _AwaitableValue({}) api_utils_mod.server_error_response = lambda ex: {"code": 500, "data": None, "message": str(ex)} - api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda func: func) + api_utils_mod.validate_request = lambda *_args, **_kwargs: lambda func: func monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) rag_pkg = ModuleType("rag") @@ -1147,11 +1147,9 @@ def test_chat_create_accepts_provider_scoped_rerank_id_unit(monkeypatch): monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB())) def _split_model_name_and_factory(model_name): - return { - "glm-4@ZHIPU-AI": ("glm-4", "default", "ZHIPU-AI"), - "glm-4@CI@ZHIPU-AI": ("glm-4", "CI", "ZHIPU-AI"), - "custom-reranker@OpenAI": ("custom-reranker", "default", "OpenAI") - }.get(model_name, (model_name, None)) + return {"glm-4@ZHIPU-AI": ("glm-4", "default", "ZHIPU-AI"), "glm-4@CI@ZHIPU-AI": ("glm-4", "CI", "ZHIPU-AI"), "custom-reranker@OpenAI": ("custom-reranker", "default", "OpenAI")}.get( + model_name, (model_name, None) + ) monkeypatch.setattr(module, "split_model_name", _split_model_name_and_factory) @@ -1227,7 +1225,7 @@ def test_chat_create_uses_direct_chat_fields_unit(monkeypatch): monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb-1")]) monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [_DummyKB()]) monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB())) - monkeypatch.setattr(module, "split_model_name", lambda model: (model.split("@")[0],"default", "factory")) + monkeypatch.setattr(module, "split_model_name", lambda model: (model.split("@")[0], "default", "factory")) def _save(**kwargs): saved.update(kwargs) @@ -1382,7 +1380,7 @@ def test_patch_chat_drops_response_only_fields_before_update_unit(monkeypatch): monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4"))) monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb-1")]) monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [_DummyKB()]) - monkeypatch.setattr(module, "split_model_name", lambda model: (model.split("@")[0],"default", "factory")) + monkeypatch.setattr(module, "split_model_name", lambda model: (model.split("@")[0], "default", "factory")) monkeypatch.setattr(module, "get_api_key", lambda *args, **kwargs: SimpleNamespace(id=1)) def _update(_chat_id, req): diff --git a/test/testcases/restful_api/test_chunks.py b/test/testcases/restful_api/test_chunks.py index 67faf9f539..1ba7fe20b0 100644 --- a/test/testcases/restful_api/test_chunks.py +++ b/test/testcases/restful_api/test_chunks.py @@ -778,9 +778,7 @@ def test_chunk_update_invalid_target_and_param_contract(rest_client, create_docu @pytest.mark.p2 def test_chunk_update_repeated_concurrent_and_deleted_document_contract(rest_client, create_document): - dataset_id, document_id, chunk_id, base_path = _create_chunk_for_update( - rest_client, create_document, "chunk_update_repeated_concurrent_deleted.txt" - ) + dataset_id, document_id, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, "chunk_update_repeated_concurrent_deleted.txt") first_res = rest_client.patch(f"{base_path}/{chunk_id}", json={"content": "chunk test 1"}) assert first_res.status_code == 200 diff --git a/test/testcases/restful_api/test_connector_routes_unit.py b/test/testcases/restful_api/test_connector_routes_unit.py index 80cd5662a6..53df4c8219 100644 --- a/test/testcases/restful_api/test_connector_routes_unit.py +++ b/test/testcases/restful_api/test_connector_routes_unit.py @@ -241,7 +241,7 @@ def _load_connector_app(monkeypatch): "message": message, "data": data, } - api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn) + api_utils_mod.validate_request = lambda *_args, **_kwargs: lambda fn: fn monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) constants_mod = ModuleType("common.constants") @@ -268,8 +268,7 @@ def _load_connector_app(monkeypatch): google_constants_mod = ModuleType("common.data_source.google_util.constant") google_constants_mod.WEB_OAUTH_POPUP_TEMPLATE = ( - "{title}" - "

{heading}

{message}

" + "{title}

{heading}

{message}

" ) google_constants_mod.GOOGLE_SCOPES = { config_mod.DocumentSource.GMAIL: ["scope-gmail"], @@ -376,7 +375,7 @@ def test_connector_basic_routes_and_task_controls(monkeypatch): lambda: _AwaitableValue({"id": "conn-1", "refresh_freq": 7, "config": {"x": 1}}), ) res = _run(module.update_connector("conn-1")) - assert update_calls == [("conn-1", {'id': 'conn-1', "refresh_freq": 7, "config": {"x": 1}})] + assert update_calls == [("conn-1", {"id": "conn-1", "refresh_freq": 7, "config": {"x": 1}})] assert res["data"]["id"] == "conn-1" monkeypatch.setattr( @@ -609,27 +608,33 @@ def test_google_web_oauth_callbacks_matrix(monkeypatch): assert "Authorization session was invalid" in invalid_state.body assert module._web_state_cache_key("sid", source) in redis.deleted - redis.store[module._web_state_cache_key("sid", source)] = json.dumps({ - "user_id": "tenant-1", - "client_config": {"web": {"client_id": "cid"}}, - }) + redis.store[module._web_state_cache_key("sid", source)] = json.dumps( + { + "user_id": "tenant-1", + "client_config": {"web": {"client_id": "cid"}}, + } + ) _set_request(module, args={"state": "sid", "error": "denied", "error_description": "permission denied"}) oauth_error = _run(callback()) assert "permission denied" in oauth_error.body - redis.store[module._web_state_cache_key("sid", source)] = json.dumps({ - "user_id": "tenant-1", - "client_config": {"web": {"client_id": "cid"}}, - }) + redis.store[module._web_state_cache_key("sid", source)] = json.dumps( + { + "user_id": "tenant-1", + "client_config": {"web": {"client_id": "cid"}}, + } + ) _set_request(module, args={"state": "sid"}) missing_code = _run(callback()) assert "Missing authorization code" in missing_code.body - redis.store[module._web_state_cache_key("sid", source)] = json.dumps({ - "user_id": "tenant-1", - "client_config": {"web": {"client_id": "cid"}}, - "code_verifier": "state-code-verifier", - }) + redis.store[module._web_state_cache_key("sid", source)] = json.dumps( + { + "user_id": "tenant-1", + "client_config": {"web": {"client_id": "cid"}}, + "code_verifier": "state-code-verifier", + } + ) _set_request(module, args={"state": "sid", "code": "code-123"}) success = _run(callback()) assert "Authorization completed successfully." in success.body @@ -658,16 +663,12 @@ def test_poll_google_web_result_matrix(monkeypatch): pending = _run(module.poll_google_web_result()) assert pending["code"] == module.RetCode.RUNNING - redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps( - {"user_id": "another-user", "credentials": "token-x"} - ) + redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps({"user_id": "another-user", "credentials": "token-x"}) _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"}) permission_error = _run(module.poll_google_web_result()) assert permission_error["code"] == module.RetCode.PERMISSION_ERROR - redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps( - {"user_id": "tenant-1", "credentials": "token-ok"} - ) + redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps({"user_id": "tenant-1", "credentials": "token-ok"}) _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"}) success = _run(module.poll_google_web_result()) assert success["code"] == 0 @@ -720,16 +721,12 @@ def test_box_oauth_start_callback_and_poll_matrix(monkeypatch): invalid_session = _run(module.box_web_oauth_callback()) assert invalid_session["code"] == module.RetCode.ARGUMENT_ERROR - redis.store[module._web_state_cache_key("flow-box", "box")] = json.dumps( - {"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"} - ) + redis.store[module._web_state_cache_key("flow-box", "box")] = json.dumps({"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"}) _set_request(module, args={"state": "flow-box", "code": "abc", "error": "access_denied", "error_description": "denied"}) callback_error = _run(module.box_web_oauth_callback()) assert "denied" in callback_error.body - redis.store[module._web_state_cache_key("flow-ok", "box")] = json.dumps( - {"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"} - ) + redis.store[module._web_state_cache_key("flow-ok", "box")] = json.dumps({"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"}) _set_request(module, args={"state": "flow-ok", "code": "code-ok"}) callback_success = _run(module.box_web_oauth_callback()) assert "Authorization completed successfully." in callback_success.body @@ -746,9 +743,7 @@ def test_box_oauth_start_callback_and_poll_matrix(monkeypatch): permission_error = _run(module.poll_box_web_result()) assert permission_error["code"] == module.RetCode.PERMISSION_ERROR - redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps( - {"user_id": "tenant-1", "access_token": "at", "refresh_token": "rt"} - ) + redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps({"user_id": "tenant-1", "access_token": "at", "refresh_token": "rt"}) poll_success = _run(module.poll_box_web_result()) assert poll_success["code"] == 0 assert poll_success["data"]["credentials"]["access_token"] == "at" diff --git a/test/testcases/restful_api/test_datasets.py b/test/testcases/restful_api/test_datasets.py index 31acad3f9d..c898715dca 100644 --- a/test/testcases/restful_api/test_datasets.py +++ b/test/testcases/restful_api/test_datasets.py @@ -636,9 +636,7 @@ def test_dataset_update_content_type_and_payload_contract(rest_client, clear_dat assert bad_content_type_res.status_code == 200 bad_content_type_payload = bad_content_type_res.json() assert bad_content_type_payload["code"] == 101, bad_content_type_payload - assert ( - f"Unsupported content type: Expected application/json, got {bad_content_type}" in bad_content_type_payload["message"] - ), bad_content_type_payload + assert f"Unsupported content type: Expected application/json, got {bad_content_type}" in bad_content_type_payload["message"], bad_content_type_payload malformed_json_res = rest_client.put(f"/datasets/{dataset_id}", data="a") assert malformed_json_res.status_code == 200 @@ -868,10 +866,7 @@ def test_dataset_update_chunk_method_invalid_contract(rest_client, clear_dataset assert create_payload["code"] == 0, create_payload dataset_id = create_payload["data"]["id"] - expected_chunk_message = ( - "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', " - "'picture', 'presentation', 'qa', 'table', 'tag' or 'resume'" - ) + expected_chunk_message = "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table', 'tag' or 'resume'" for chunk_method in ("", "unknown", []): res = rest_client.put(f"/datasets/{dataset_id}", json={"chunk_method": chunk_method}) assert res.status_code == 200 @@ -1191,9 +1186,7 @@ def test_dataset_create_permission_contract(rest_client, clear_datasets, name, p "tenant_no_auth", ], ) -def test_dataset_create_embedding_model_contract( - rest_client, clear_datasets, name, embedding_model, expected_code, expected_embedding_model, expected_message, unauthorized_is_xfail -): +def test_dataset_create_embedding_model_contract(rest_client, clear_datasets, name, embedding_model, expected_code, expected_embedding_model, expected_message, unauthorized_is_xfail): req = {"name": name} if embedding_model != "__UNSET__": req["embedding_model"] = embedding_model @@ -1414,9 +1407,7 @@ def test_dataset_create_parser_config_valid_matrix_contract(rest_client, clear_d ], ids=["only_raptor", "only_graphrag", "both_fields"], ) -def test_dataset_create_parser_config_bugfix_contract( - rest_client, clear_datasets, name, parser_config, expected_raptor, expected_graphrag -): +def test_dataset_create_parser_config_bugfix_contract(rest_client, clear_datasets, name, parser_config, expected_raptor, expected_graphrag): res = rest_client.post("/datasets", json={"name": name, "parser_config": parser_config}) assert res.status_code == 200 body = res.json() @@ -1451,6 +1442,8 @@ def test_dataset_create_parser_config_different_chunk_methods_contract(rest_clie assert "graphrag" in parser_config, body assert parser_config["raptor"]["use_raptor"] is False, body assert parser_config["graphrag"]["use_graphrag"] is False, body + + def test_dataset_create_name_invalid_and_duplicate_contract(rest_client, clear_datasets): invalid_cases = [ ("", "String should have at least 1 character"), @@ -1604,10 +1597,7 @@ def test_dataset_create_permission_and_chunk_method_contract(rest_client, clear_ ("chunk_unknown", "unknown"), ("chunk_type_error", []), ] - expected_chunk_message = ( - "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', " - "'picture', 'presentation', 'qa', 'table', 'tag' or 'resume'" - ) + expected_chunk_message = "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table', 'tag' or 'resume'" for name, chunk_method in chunk_method_invalid_cases: res = rest_client.post("/datasets", json={"name": name, "chunk_method": chunk_method}) assert res.status_code == 200 @@ -1788,9 +1778,7 @@ def test_dataset_delete_contract_matrix(rest_client, clear_datasets): assert bad_content_type_res.status_code == 200 bad_content_type_payload = bad_content_type_res.json() assert bad_content_type_payload["code"] == 101, bad_content_type_payload - assert ( - f"Unsupported content type: Expected application/json, got {bad_content_type}" in bad_content_type_payload["message"] - ), bad_content_type_payload + assert f"Unsupported content type: Expected application/json, got {bad_content_type}" in bad_content_type_payload["message"], bad_content_type_payload malformed_json_res = rest_client.delete("/datasets", data="a") assert malformed_json_res.status_code == 200 diff --git a/test/testcases/restful_api/test_dify_retrieval_routes_unit.py b/test/testcases/restful_api/test_dify_retrieval_routes_unit.py index d7b00b6afa..83048a425f 100644 --- a/test/testcases/restful_api/test_dify_retrieval_routes_unit.py +++ b/test/testcases/restful_api/test_dify_retrieval_routes_unit.py @@ -53,11 +53,7 @@ class _DummyKB: class _DummyRetriever: async def retrieval(self, *_args, **_kwargs): - return { - "chunks": [ - {"doc_id": "doc-1", "content_with_weight": "chunk-content", "similarity": 0.8, "docnm_kwd": "doc-title", "vector": [0.1]} - ] - } + return {"chunks": [{"doc_id": "doc-1", "content_with_weight": "chunk-content", "similarity": 0.8, "docnm_kwd": "doc-title", "vector": [0.1]}]} def retrieval_by_children(self, chunks, _tenant_ids): return chunks diff --git a/test/testcases/restful_api/test_documents.py b/test/testcases/restful_api/test_documents.py index 77519f67fb..2e04607573 100644 --- a/test/testcases/restful_api/test_documents.py +++ b/test/testcases/restful_api/test_documents.py @@ -446,10 +446,7 @@ def test_documents_upload_contract_matrix(rest_client, create_dataset, tmp_path) assert len(multi_payload["data"]) == 20, multi_payload with ThreadPoolExecutor(max_workers=5) as executor: - futures = [ - executor.submit(_upload_files, rest_client, dataset_id, [create_txt_file(tmp_path / f"parallel_upload_{i}.txt")]) - for i in range(20) - ] + futures = [executor.submit(_upload_files, rest_client, dataset_id, [create_txt_file(tmp_path / f"parallel_upload_{i}.txt")]) for i in range(20)] responses = list(as_completed(futures)) assert len(responses) == 20, responses assert all(f.result().json()["code"] == 0 for f in futures) @@ -1171,9 +1168,7 @@ def test_documents_delete_invalid_dataset_partial_duplicate_repeat_and_cross_dat @pytest.mark.p2 def test_documents_delete_concurrent_and_bulk_contract(rest_client, create_dataset, tmp_path): - dataset_id, uploaded_docs = _seed_documents( - rest_client, create_dataset, tmp_path, count=60, timeout=120 - ) + dataset_id, uploaded_docs = _seed_documents(rest_client, create_dataset, tmp_path, count=60, timeout=120) document_ids = [doc["id"] for doc in uploaded_docs] with ThreadPoolExecutor(max_workers=8) as executor: @@ -1199,9 +1194,7 @@ def test_documents_delete_concurrent_and_bulk_contract(rest_client, create_datas assert list_after_payload["code"] == 0, list_after_payload assert list_after_payload["data"]["total"] == 0, list_after_payload - bulk_dataset_id, bulk_docs = _seed_documents( - rest_client, create_dataset, tmp_path, count=120, timeout=120 - ) + bulk_dataset_id, bulk_docs = _seed_documents(rest_client, create_dataset, tmp_path, count=120, timeout=120) bulk_ids = [doc["id"] for doc in bulk_docs] bulk_delete_res = rest_client.delete( f"/datasets/{bulk_dataset_id}/documents", diff --git a/test/testcases/restful_api/test_file_commit_routes_unit.py b/test/testcases/restful_api/test_file_commit_routes_unit.py index 0e9e72e987..79eb07b17f 100644 --- a/test/testcases/restful_api/test_file_commit_routes_unit.py +++ b/test/testcases/restful_api/test_file_commit_routes_unit.py @@ -38,6 +38,7 @@ class _DummyManager: def route(self, *_args, **_kwargs): def decorator(func): return func + return decorator @@ -55,7 +56,7 @@ _request_payload: list = [{}] # FileCommitService (which uses DB.atomic(), .select(), .where(), etc.) # works against real SQL. -sqlite_db = SqliteDatabase(':memory:') +sqlite_db = SqliteDatabase(":memory:") class BaseTestModel(Model): @@ -141,6 +142,7 @@ def _clear_db(): # ── Module loader ───────────────────────────────────────────────────────── + def _load_module(monkeypatch): """Load file_commit_api.py with SQLite in-memory DB and mocked HTTP layer.""" repo_root = Path(__file__).resolve().parents[3] @@ -184,12 +186,11 @@ def _load_module(monkeypatch): payload = await get_request_json() missing = [k for k in required_keys if k not in payload] if missing: - return get_json_result( - code=101, data=None, - message="required argument are missing: " + ", ".join(missing) - ) + return get_json_result(code=101, data=None, message="required argument are missing: " + ", ".join(missing)) return await func(*args, **kwargs) + return _wrapper + return _decorator api_utils_mod.get_json_result = get_json_result @@ -201,6 +202,7 @@ def _load_module(monkeypatch): # Stub: common.misc_utils import uuid + misc_utils_mod = ModuleType("common.misc_utils") misc_utils_mod.get_uuid = lambda: uuid.uuid1().hex monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod) @@ -235,6 +237,7 @@ def _load_module(monkeypatch): def connection_context(): def dec(func): return func + return dec @staticmethod @@ -242,8 +245,10 @@ def _load_module(monkeypatch): class Ctx: def __enter__(self2): return self2 + def __exit__(self2, *args): pass + return Ctx() db_models_mod.DB = _DB @@ -256,9 +261,11 @@ def _load_module(monkeypatch): class _StubFileService: model = FileTestModel # class attribute, not staticmethod — code accesses FileService.model.update(...) + @staticmethod def update_by_id(pid, data): return FileTestModel.update(data).where(FileTestModel.id == pid).execute() + @staticmethod def get_by_id(pid): try: @@ -266,6 +273,7 @@ def _load_module(monkeypatch): return True, obj except Exception: return False, None + @staticmethod def get_or_none(**kwargs): try: @@ -275,6 +283,7 @@ def _load_module(monkeypatch): class CommonServiceBase: model = None + @classmethod def get_by_id(cls, pid): try: @@ -284,6 +293,7 @@ def _load_module(monkeypatch): except Exception: pass return False, None + @classmethod def query(cls, cols=None, reverse=None, order_by=None, **kwargs): q = cls.model.select() @@ -291,9 +301,11 @@ def _load_module(monkeypatch): if f_v is not None and hasattr(cls.model, f_n): q = q.where(getattr(cls.model, f_n) == f_v) return q + @classmethod def update_by_id(cls, pid, data): return cls.model.update(data).where(cls.model.id == pid).execute() + @classmethod def filter_update(cls, filters, update_data): return cls.model.update(update_data).where(*filters).execute() @@ -306,11 +318,11 @@ def _load_module(monkeypatch): # Stub: api.db with real filesystem path so sub-packages can be discovered. db_pkg = ModuleType("api.db") db_pkg.__path__ = [str(repo_root / "api" / "db")] - db_pkg.UserTenantRole = type('UserTenantRole', (), {k: k for k in ('OWNER','ADMIN','NORMAL','INVITE')}) - db_pkg.TenantPermission = type('TenantPermission', (), {'ME': 'me', 'TEAM': 'team'}) - db_pkg.FileType = type('FileType', (), {'FOLDER': 'folder', 'DOC': 'doc', 'VISUAL': 'visual', 'AURAL': 'aural', 'VIRTUAL': 'virtual', 'PDF': 'pdf', 'OTHER': 'other'}) - db_pkg.KNOWLEDGEBASE_FOLDER_NAME = '.knowledgebase' - db_pkg.SKILLS_FOLDER_NAME = 'skills' + db_pkg.UserTenantRole = type("UserTenantRole", (), {k: k for k in ("OWNER", "ADMIN", "NORMAL", "INVITE")}) + db_pkg.TenantPermission = type("TenantPermission", (), {"ME": "me", "TEAM": "team"}) + db_pkg.FileType = type("FileType", (), {"FOLDER": "folder", "DOC": "doc", "VISUAL": "visual", "AURAL": "aural", "VIRTUAL": "virtual", "PDF": "pdf", "OTHER": "other"}) + db_pkg.KNOWLEDGEBASE_FOLDER_NAME = ".knowledgebase" + db_pkg.SKILLS_FOLDER_NAME = "skills" monkeypatch.setitem(sys.modules, "api.db", db_pkg) api_pkg.db = db_pkg @@ -333,6 +345,7 @@ def _load_module(monkeypatch): monkeypatch.setitem(sys.modules, "api.db.services.common_service", common_svc_mod) kb_svc_mod = ModuleType("api.db.services.knowledgebase_service") + # NB: The dataset resolver in the API calls KnowledgebaseService.get_by_id # then accesses .name and .tenant_id. We return a simple object. class _StubKnowledgebaseService: @@ -341,6 +354,7 @@ def _load_module(monkeypatch): if dataset_id == "ds-1": return True, SimpleNamespace(name="test-ds", tenant_id="t1") return False, None + kb_svc_mod.KnowledgebaseService = _StubKnowledgebaseService monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_svc_mod) @@ -364,6 +378,7 @@ def _load_module(monkeypatch): # ── Helpers ─────────────────────────────────────────────────────────────── + def _setup_request(module, json_payload=None, args=None): """Set up a request payload and query args for the next handler call.""" if json_payload is not None: @@ -374,6 +389,7 @@ def _setup_request(module, json_payload=None, args=None): # ── Fixtures ────────────────────────────────────────────────────────────── + @pytest.fixture(scope="session") def auth(): return "test-auth" @@ -392,17 +408,20 @@ def reset_db(): # ── Tests ───────────────────────────────────────────────────────────────── + @pytest.mark.p2 def test_create_commit_success(monkeypatch): module = _load_module(monkeypatch) # Seed a file - FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", - created_by="test-user", name="a.txt", type="txt") + FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", created_by="test-user", name="a.txt", type="txt") - _setup_request(module, json_payload={ - "message": "initial commit", - "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "hello"}], - }) + _setup_request( + module, + json_payload={ + "message": "initial commit", + "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "hello"}], + }, + ) res = _run(module.create_commit("root-folder")) assert res["code"] == 0, f"Expected 0, got {res}" @@ -427,26 +446,30 @@ def test_create_commit_missing_fields(monkeypatch): @pytest.mark.p2 def test_create_commit_modify_and_add(monkeypatch): module = _load_module(monkeypatch) - FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", - created_by="test-user", name="a.txt", type="txt") - FileTestModel.create(id="f2", parent_id="root-folder", tenant_id="t1", - created_by="test-user", name="b.txt", type="txt") + FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", created_by="test-user", name="a.txt", type="txt") + FileTestModel.create(id="f2", parent_id="root-folder", tenant_id="t1", created_by="test-user", name="b.txt", type="txt") # Commit 1: add f1 - _setup_request(module, json_payload={ - "message": "c1", - "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "v1"}], - }) + _setup_request( + module, + json_payload={ + "message": "c1", + "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "v1"}], + }, + ) _run(module.create_commit("root-folder")) # Commit 2: modify f1, add f2 - _setup_request(module, json_payload={ - "message": "c2", - "files": [ - {"file_id": "f1", "file_name": "a.txt", "operation": "modify", "content": "v2"}, - {"file_id": "f2", "file_name": "b.txt", "operation": "add", "content": "world"}, - ], - }) + _setup_request( + module, + json_payload={ + "message": "c2", + "files": [ + {"file_id": "f1", "file_name": "a.txt", "operation": "modify", "content": "v2"}, + {"file_id": "f2", "file_name": "b.txt", "operation": "add", "content": "world"}, + ], + }, + ) res = _run(module.create_commit("root-folder")) assert res["code"] == 0 assert res["data"]["file_count"] == 2 @@ -455,20 +478,25 @@ def test_create_commit_modify_and_add(monkeypatch): @pytest.mark.p2 def test_create_commit_delete(monkeypatch): module = _load_module(monkeypatch) - FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", - created_by="test-user", name="a.txt", type="txt") + FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", created_by="test-user", name="a.txt", type="txt") # Add then delete - _setup_request(module, json_payload={ - "message": "add", - "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "hello"}], - }) + _setup_request( + module, + json_payload={ + "message": "add", + "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "hello"}], + }, + ) _run(module.create_commit("root-folder")) - _setup_request(module, json_payload={ - "message": "delete", - "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "delete"}], - }) + _setup_request( + module, + json_payload={ + "message": "delete", + "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "delete"}], + }, + ) res = _run(module.create_commit("root-folder")) assert res["code"] == 0 @@ -476,21 +504,25 @@ def test_create_commit_delete(monkeypatch): @pytest.mark.p2 def test_create_commit_rename(monkeypatch): module = _load_module(monkeypatch) - FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", - created_by="test-user", name="old.txt", type="txt") + FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", created_by="test-user", name="old.txt", type="txt") - _setup_request(module, json_payload={ - "message": "add", - "files": [{"file_id": "f1", "file_name": "old.txt", "operation": "add", "content": "data"}], - }) + _setup_request( + module, + json_payload={ + "message": "add", + "files": [{"file_id": "f1", "file_name": "old.txt", "operation": "add", "content": "data"}], + }, + ) _run(module.create_commit("root-folder")) # Rename - _setup_request(module, json_payload={ - "message": "rename", - "files": [{"file_id": "f1", "file_name": "old.txt", "operation": "rename", - "old_name": "old.txt", "new_name": "new.txt"}], - }) + _setup_request( + module, + json_payload={ + "message": "rename", + "files": [{"file_id": "f1", "file_name": "old.txt", "operation": "rename", "old_name": "old.txt", "new_name": "new.txt"}], + }, + ) res = _run(module.create_commit("root-folder")) assert res["code"] == 0 @@ -498,20 +530,25 @@ def test_create_commit_rename(monkeypatch): @pytest.mark.p2 def test_list_commits_success(monkeypatch): module = _load_module(monkeypatch) - FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", - created_by="test-user", name="a.txt", type="txt") + FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", created_by="test-user", name="a.txt", type="txt") # Create 2 commits - _setup_request(module, json_payload={ - "message": "c1", - "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "v1"}], - }) + _setup_request( + module, + json_payload={ + "message": "c1", + "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "v1"}], + }, + ) _run(module.create_commit("root-folder")) - _setup_request(module, json_payload={ - "message": "c2", - "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "modify", "content": "v2"}], - }) + _setup_request( + module, + json_payload={ + "message": "c2", + "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "modify", "content": "v2"}], + }, + ) _run(module.create_commit("root-folder")) # List @@ -525,13 +562,15 @@ def test_list_commits_success(monkeypatch): @pytest.mark.p2 def test_get_commit_detail(monkeypatch): module = _load_module(monkeypatch) - FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", - created_by="test-user", name="a.txt", type="txt") + FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", created_by="test-user", name="a.txt", type="txt") - _setup_request(module, json_payload={ - "message": "detail test", - "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "data"}], - }) + _setup_request( + module, + json_payload={ + "message": "detail test", + "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "data"}], + }, + ) create_res = _run(module.create_commit("root-folder")) commit_id = create_res["data"]["id"] @@ -553,26 +592,30 @@ def test_get_commit_not_found(monkeypatch): @pytest.mark.p2 def test_diff_commits(monkeypatch): module = _load_module(monkeypatch) - FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", - created_by="test-user", name="a.txt", type="txt") - FileTestModel.create(id="f2", parent_id="root-folder", tenant_id="t1", - created_by="test-user", name="b.txt", type="txt") + FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", created_by="test-user", name="a.txt", type="txt") + FileTestModel.create(id="f2", parent_id="root-folder", tenant_id="t1", created_by="test-user", name="b.txt", type="txt") # c1: add f1 - _setup_request(module, json_payload={ - "message": "c1", - "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "v1"}], - }) + _setup_request( + module, + json_payload={ + "message": "c1", + "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "v1"}], + }, + ) c1 = _run(module.create_commit("root-folder"))["data"]["id"] # c2: add f2, modify f1 - _setup_request(module, json_payload={ - "message": "c2", - "files": [ - {"file_id": "f2", "file_name": "b.txt", "operation": "add", "content": "world"}, - {"file_id": "f1", "file_name": "a.txt", "operation": "modify", "content": "v2"}, - ], - }) + _setup_request( + module, + json_payload={ + "message": "c2", + "files": [ + {"file_id": "f2", "file_name": "b.txt", "operation": "add", "content": "world"}, + {"file_id": "f1", "file_name": "a.txt", "operation": "modify", "content": "v2"}, + ], + }, + ) c2 = _run(module.create_commit("root-folder"))["data"]["id"] assert c1 != c2, "c1 and c2 must have different IDs" @@ -605,17 +648,18 @@ def test_diff_commits_missing_params(monkeypatch): def test_get_uncommitted_changes(monkeypatch): module = _load_module(monkeypatch) # Seed a file that will be committed - FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", - created_by="test-user", name="a.txt", type="txt") + FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", created_by="test-user", name="a.txt", type="txt") # Seed a file that will NOT be committed (uncommitted add) - FileTestModel.create(id="f2", parent_id="root-folder", tenant_id="t1", - created_by="test-user", name="b.txt", type="txt") + FileTestModel.create(id="f2", parent_id="root-folder", tenant_id="t1", created_by="test-user", name="b.txt", type="txt") # Commit only f1 - _setup_request(module, json_payload={ - "message": "add f1", - "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "hello"}], - }) + _setup_request( + module, + json_payload={ + "message": "add f1", + "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "hello"}], + }, + ) _run(module.create_commit("root-folder")) res = _run(module.get_uncommitted_changes("root-folder")) @@ -629,13 +673,15 @@ def test_get_uncommitted_changes(monkeypatch): @pytest.mark.p2 def test_get_commit_tree(monkeypatch): module = _load_module(monkeypatch) - FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", - created_by="test-user", name="a.txt", type="txt") + FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", created_by="test-user", name="a.txt", type="txt") - _setup_request(module, json_payload={ - "message": "c1", - "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "data"}], - }) + _setup_request( + module, + json_payload={ + "message": "c1", + "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "data"}], + }, + ) create_res = _run(module.create_commit("root-folder")) commit_id = create_res["data"]["id"] @@ -649,13 +695,15 @@ def test_get_commit_tree(monkeypatch): @pytest.mark.p2 def test_get_commit_file_content(monkeypatch): module = _load_module(monkeypatch) - FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", - created_by="test-user", name="a.txt", type="txt") + FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", created_by="test-user", name="a.txt", type="txt") - _setup_request(module, json_payload={ - "message": "c1", - "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "hello world"}], - }) + _setup_request( + module, + json_payload={ + "message": "c1", + "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "hello world"}], + }, + ) create_res = _run(module.create_commit("root-folder")) commit_id = create_res["data"]["id"] @@ -667,20 +715,25 @@ def test_get_commit_file_content(monkeypatch): @pytest.mark.p2 def test_get_file_version_history(monkeypatch): module = _load_module(monkeypatch) - FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", - created_by="test-user", name="a.txt", type="txt") + FileTestModel.create(id="f1", parent_id="root-folder", tenant_id="t1", created_by="test-user", name="a.txt", type="txt") # Two commits modifying f1 - _setup_request(module, json_payload={ - "message": "v1", - "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "v1"}], - }) + _setup_request( + module, + json_payload={ + "message": "v1", + "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "v1"}], + }, + ) _run(module.create_commit("root-folder")) - _setup_request(module, json_payload={ - "message": "v2", - "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "modify", "content": "v2"}], - }) + _setup_request( + module, + json_payload={ + "message": "v2", + "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "modify", "content": "v2"}], + }, + ) _run(module.create_commit("root-folder")) res = _run(module.get_file_version_history("f1")) @@ -692,13 +745,15 @@ def test_get_file_version_history(monkeypatch): def test_workspace_alias(monkeypatch): """Verify /workspace/ alias routes work the same as /folders/.""" module = _load_module(monkeypatch) - FileTestModel.create(id="f1", parent_id="ws-folder", tenant_id="t1", - created_by="test-user", name="a.txt", type="txt") + FileTestModel.create(id="f1", parent_id="ws-folder", tenant_id="t1", created_by="test-user", name="a.txt", type="txt") - _setup_request(module, json_payload={ - "message": "workspace commit", - "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "data"}], - }) + _setup_request( + module, + json_payload={ + "message": "workspace commit", + "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "data"}], + }, + ) res = _run(module.create_commit("ws-folder")) assert res["code"] == 0 @@ -712,13 +767,15 @@ def test_workspace_alias(monkeypatch): @pytest.mark.p2 def test_get_commit_wrong_folder_returns_not_found(monkeypatch): module = _load_module(monkeypatch) - FileTestModel.create(id="f1", parent_id="folder-a", tenant_id="t1", - created_by="test-user", name="a.txt", type="txt") + FileTestModel.create(id="f1", parent_id="folder-a", tenant_id="t1", created_by="test-user", name="a.txt", type="txt") - _setup_request(module, json_payload={ - "message": "c1", - "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "data"}], - }) + _setup_request( + module, + json_payload={ + "message": "c1", + "files": [{"file_id": "f1", "file_name": "a.txt", "operation": "add", "content": "data"}], + }, + ) create_res = _run(module.create_commit("folder-a")) commit_id = create_res["data"]["id"] diff --git a/test/testcases/restful_api/test_file_routes_unit.py b/test/testcases/restful_api/test_file_routes_unit.py index 579e37eb07..6d055936cf 100644 --- a/test/testcases/restful_api/test_file_routes_unit.py +++ b/test/testcases/restful_api/test_file_routes_unit.py @@ -343,6 +343,7 @@ def test_parent_and_ancestors_use_new_routes(monkeypatch): assert ancestors_res["code"] == 0 assert ancestors_res["data"]["parent_folders"][0]["id"] == "root" + # # Copyright 2026 The InfiniFlow Authors. All Rights Reserved. # @@ -756,9 +757,7 @@ def _load_file_api_service(monkeypatch): try: spec.loader.exec_module(module) except Exception: - LOGGER.exception( - "_load_file_api_service: spec.loader.exec_module(module) failed" - ) + LOGGER.exception("_load_file_api_service: spec.loader.exec_module(module) failed") raise LOGGER.debug("_load_file_api_service: spec.loader.exec_module(module) completed") return module @@ -799,12 +798,16 @@ def test_upload_file_success_uses_new_service_layer(monkeypatch): "create_folder", lambda _file, parent_id, _names, _len_id, *_args: SimpleNamespace(id=parent_id), ) - monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace( - obj_exist=lambda *_args, **_kwargs: False, - put=lambda bucket, location, blob: storage_puts.append((bucket, location, blob)), - rm=lambda *_args, **_kwargs: None, - move=lambda *_args, **_kwargs: None, - )) + monkeypatch.setattr( + module.settings, + "STORAGE_IMPL", + SimpleNamespace( + obj_exist=lambda *_args, **_kwargs: False, + put=lambda bucket, location, blob: storage_puts.append((bucket, location, blob)), + rm=lambda *_args, **_kwargs: None, + move=lambda *_args, **_kwargs: None, + ), + ) ok, data = _run(module.upload_file("tenant1", "pf1", [_DummyUploadFile("a.txt", b"hello")])) assert ok is True @@ -867,12 +870,16 @@ def test_move_files_handles_dest_and_storage_move(monkeypatch): "get_by_ids", lambda _ids: [_DummyFile("file1", module.FileType.DOC.value, parent_id="src", location="old", name="a.txt")], ) - monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace( - obj_exist=lambda *_args, **_kwargs: False, - put=lambda *_args, **_kwargs: None, - rm=lambda *_args, **_kwargs: None, - move=lambda old_bucket, old_loc, new_bucket, new_loc: moved.append((old_bucket, old_loc, new_bucket, new_loc)), - )) + monkeypatch.setattr( + module.settings, + "STORAGE_IMPL", + SimpleNamespace( + obj_exist=lambda *_args, **_kwargs: False, + put=lambda *_args, **_kwargs: None, + rm=lambda *_args, **_kwargs: None, + move=lambda old_bucket, old_loc, new_bucket, new_loc: moved.append((old_bucket, old_loc, new_bucket, new_loc)), + ), + ) monkeypatch.setattr(module.FileService, "update_by_id", lambda file_id, data: updated.append((file_id, data)) or True) ok, message = _run(module.move_files("tenant1", ["file1"], "missing")) diff --git a/test/testcases/restful_api/test_llm_routes_unit.py b/test/testcases/restful_api/test_llm_routes_unit.py index a43dbac2f8..fc4f91d54b 100644 --- a/test/testcases/restful_api/test_llm_routes_unit.py +++ b/test/testcases/restful_api/test_llm_routes_unit.py @@ -215,7 +215,7 @@ def _load_llm_app(monkeypatch): api_utils_mod.get_request_json = _get_request_json api_utils_mod.server_error_response = lambda exc: {"code": 500, "message": str(exc), "data": None} - api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn) + api_utils_mod.validate_request = lambda *_args, **_kwargs: lambda fn: fn monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) constants_mod = ModuleType("common.constants") @@ -499,9 +499,7 @@ def test_add_llm_factory_specific_key_assembly_unit(monkeypatch): assert json.loads(_run_case("BaiduYiyan", extra={"yiyan_ak": "ak", "yiyan_sk": "sk"})["api_key"]) == {"yiyan_ak": "ak", "yiyan_sk": "sk"} assert json.loads(_run_case("Fish Audio", extra={"fish_audio_ak": "ak", "fish_audio_refid": "rid"})["api_key"]) == {"fish_audio_ak": "ak", "fish_audio_refid": "rid"} - assert json.loads( - _run_case("Google Cloud", extra={"google_project_id": "pid", "google_region": "us", "google_service_account_key": "sak"})["api_key"] - ) == { + assert json.loads(_run_case("Google Cloud", extra={"google_project_id": "pid", "google_region": "us", "google_service_account_key": "sak"})["api_key"]) == { "google_project_id": "pid", "google_region": "us", "google_service_account_key": "sak", @@ -668,9 +666,9 @@ def test_add_llm_model_type_probe_and_persistence_matrix_unit(monkeypatch): monkeypatch.setattr( module.LLMService, "query", - lambda **kwargs: [] if kwargs.get("llm_factory") == "FUnknown" else [ - _LLMRow(llm_name="m", fid=kwargs.get("llm_factory"), model_type=kwargs.get("model_type", module.LLMType.CHAT.value), max_tokens=4096) - ], + lambda **kwargs: ( + [] if kwargs.get("llm_factory") == "FUnknown" else [_LLMRow(llm_name="m", fid=kwargs.get("llm_factory"), model_type=kwargs.get("model_type", module.LLMType.CHAT.value), max_tokens=4096)] + ), ) _set_request_json(monkeypatch, module, {"llm_factory": "FUnknown", "llm_name": "m", "model_type": "unknown"}) diff --git a/test/testcases/restful_api/test_search_datasets_consistency.py b/test/testcases/restful_api/test_search_datasets_consistency.py index 13736a3bf6..5ae072312a 100644 --- a/test/testcases/restful_api/test_search_datasets_consistency.py +++ b/test/testcases/restful_api/test_search_datasets_consistency.py @@ -31,6 +31,7 @@ values (with the empty-value normalization done in compare_chunks). All datasets and documents are created once at module level, then each test unit runs against the pre-built data. Cleanup happens automatically at module teardown. """ + import logging import os import sys @@ -140,6 +141,8 @@ medical image analysis, climate science, material inspection and board game prog where they have produced results comparable to and in some cases surpassing human expert performance. """ + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -168,9 +171,7 @@ def compare_chunks(python_chunk, go_chunk): if field in ("similarity", "term_similarity", "vector_similarity"): if p_val != g_val: - raise AssertionError( - f"Field '{field}' differs: python={p_val}, go={g_val}, diff={abs(p_val - g_val)}" - ) + raise AssertionError(f"Field '{field}' differs: python={p_val}, go={g_val}, diff={abs(p_val - g_val)}") elif isinstance(p_val, (list, dict)): if p_val != g_val: raise AssertionError(f"Field '{field}' mismatch") @@ -200,9 +201,17 @@ def search_and_compare(rest_client, dataset_ids, cfg): "top_k": cfg.get("top_k", 5), } optional_fields = [ - "rerank_id", "search_id", "keyword", "vector_similarity_weight", - "similarity_threshold", "use_kg", "cross_languages", "page", "size", - "meta_data_filter", "doc_ids", + "rerank_id", + "search_id", + "keyword", + "vector_similarity_weight", + "similarity_threshold", + "use_kg", + "cross_languages", + "page", + "size", + "meta_data_filter", + "doc_ids", ] for field in optional_fields: value = cfg.get(field) @@ -230,13 +239,12 @@ def search_and_compare(rest_client, dataset_ids, cfg): go_chunks = go_data["data"]["chunks"] logger.info(f"python_chunks={len(python_chunks)}, go_chunks={len(go_chunks)}") - logger.info(f" Python chunks: {[(c.get('chunk_id','?'), c.get('similarity',0)) for c in python_chunks]}") - logger.info(f" Go chunks: {[(c.get('chunk_id','?'), c.get('similarity',0)) for c in go_chunks]}") + logger.info(f" Python chunks: {[(c.get('chunk_id', '?'), c.get('similarity', 0)) for c in python_chunks]}") + logger.info(f" Go chunks: {[(c.get('chunk_id', '?'), c.get('similarity', 0)) for c in go_chunks]}") llm_involved = bool(cfg.get("rerank_id") or cfg.get("keyword") or cfg.get("cross_languages")) if not llm_involved: - assert len(python_chunks) == len(go_chunks), \ - f"Chunk count differs: python={len(python_chunks)}, go={len(go_chunks)}" + assert len(python_chunks) == len(go_chunks), f"Chunk count differs: python={len(python_chunks)}, go={len(go_chunks)}" for i, (p_chunk, g_chunk) in enumerate(zip(python_chunks, go_chunks)): try: compare_chunks(p_chunk, g_chunk) @@ -252,11 +260,11 @@ def search_and_compare(rest_client, dataset_ids, cfg): def _upload_and_parse(rest_client, dataset_id, text, filename="doc.txt"): """Upload text as a file and wait for parsing to complete. Returns document_id.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False, encoding='utf-8') as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False, encoding="utf-8") as f: f.write(text) temp_path = f.name - with open(temp_path, 'rb') as f: + with open(temp_path, "rb") as f: files = [("file", (filename, f))] upload_res = rest_client.post(f"/datasets/{dataset_id}/documents", files=files) assert upload_res.status_code == 200, f"Failed to upload {filename}: {upload_res.text}" @@ -331,11 +339,14 @@ def all_datasets(rest_client): # ----------------------------------------------------------------------- # 1) 1 dataset with 2 files (Chinese) # ----------------------------------------------------------------------- - create_res = rest_client.post("/datasets", json={ - "name": "consistency_chinese", - "embedding_model": "BAAI/bge-small-en-v1.5@Builtin", - "parser_config": {"chunk_token_num": 1, "delimiter": "`\n\n`"}, - }) + create_res = rest_client.post( + "/datasets", + json={ + "name": "consistency_chinese", + "embedding_model": "BAAI/bge-small-en-v1.5@Builtin", + "parser_config": {"chunk_token_num": 1, "delimiter": "`\n\n`"}, + }, + ) assert create_res.status_code == 200, create_res.text assert create_res.json()["code"] == 0, create_res.json() ds_chinese_id = create_res.json()["data"]["id"] @@ -357,11 +368,14 @@ def all_datasets(rest_client): # ----------------------------------------------------------------------- # 2) 1 dataset with Three Kingdoms only # ----------------------------------------------------------------------- - create_res = rest_client.post("/datasets", json={ - "name": "consistency_three_kingdoms", - "embedding_model": "BAAI/bge-small-en-v1.5@Builtin", - "parser_config": {"chunk_token_num": 1, "delimiter": "`\n\n`"}, - }) + create_res = rest_client.post( + "/datasets", + json={ + "name": "consistency_three_kingdoms", + "embedding_model": "BAAI/bge-small-en-v1.5@Builtin", + "parser_config": {"chunk_token_num": 1, "delimiter": "`\n\n`"}, + }, + ) assert create_res.status_code == 200, create_res.text assert create_res.json()["code"] == 0, create_res.json() ds_3k_id = create_res.json()["data"]["id"] @@ -379,11 +393,14 @@ def all_datasets(rest_client): # ----------------------------------------------------------------------- # 3) 1 dataset with English text # ----------------------------------------------------------------------- - create_res = rest_client.post("/datasets", json={ - "name": "consistency_english", - "embedding_model": "BAAI/bge-small-en-v1.5@Builtin", - "parser_config": {"chunk_token_num": 1, "delimiter": "`\n\n`"}, - }) + create_res = rest_client.post( + "/datasets", + json={ + "name": "consistency_english", + "embedding_model": "BAAI/bge-small-en-v1.5@Builtin", + "parser_config": {"chunk_token_num": 1, "delimiter": "`\n\n`"}, + }, + ) assert create_res.status_code == 200, create_res.text assert create_res.json()["code"] == 0, create_res.json() ds_en_id = create_res.json()["data"]["id"] @@ -417,6 +434,7 @@ pytestmark = pytest.mark.skipif( reason="GO server is not started in CI", ) + # --------------------------------------------------------------------------- # Test Unit 1: Search consistency — 1 dataset with 2 files # --------------------------------------------------------------------------- @@ -479,7 +497,7 @@ def test_search_datasets_consistency_metadata_filter(rest_client, all_datasets): {"question": "打虎", "meta_data_filter": {"method": "manual", "manual": [{"key": "era", "op": "≠", "value": 960}]}}, {"question": "打虎", "meta_data_filter": {"method": "manual", "manual": [{"key": "era", "op": ">", "value": 220}]}}, {"question": "曹操", "meta_data_filter": {"method": "manual", "manual": [{"key": "source", "op": "contains", "value": "luo"}]}}, - {"question": "努力发展农业", "meta_data_filter": {"method": "manual", "manual": [{"key": "character", "op": "in", "value": ["曹操","孙权"]}]}}, + {"question": "努力发展农业", "meta_data_filter": {"method": "manual", "manual": [{"key": "character", "op": "in", "value": ["曹操", "孙权"]}]}}, {"question": "打虎", "meta_data_filter": {"method": "manual", "manual": [{"key": "character", "op": "=", "value": "武松"}]}}, ] diff --git a/test/testcases/restful_api/test_sessions.py b/test/testcases/restful_api/test_sessions.py index 224d573cd4..30e9c6abbd 100644 --- a/test/testcases/restful_api/test_sessions.py +++ b/test/testcases/restful_api/test_sessions.py @@ -132,7 +132,6 @@ def test_session_create_requires_auth_and_invalid_chat_contract(): assert payload["message"] == "", (scenario_name, payload) - @pytest.mark.p2 def test_session_create_validation_and_deleted_chat_contract(rest_client, create_chat): chat_id = create_chat("restful_session_create_contract") diff --git a/test/testcases/restful_api/test_user_tenant_routes_unit.py b/test/testcases/restful_api/test_user_tenant_routes_unit.py index 30dffe0c01..5b24f88a39 100644 --- a/test/testcases/restful_api/test_user_tenant_routes_unit.py +++ b/test/testcases/restful_api/test_user_tenant_routes_unit.py @@ -161,7 +161,7 @@ def _load_tenant_module(monkeypatch): api_utils_mod.get_json_result = lambda data=None, message="", code=0: {"code": code, "message": message, "data": data} api_utils_mod.get_data_error_result = lambda message="": {"code": 102, "message": message, "data": False} api_utils_mod.server_error_response = lambda exc: {"code": 100, "message": repr(exc), "data": False} - api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn) + api_utils_mod.validate_request = lambda *_args, **_kwargs: lambda fn: fn api_utils_mod.get_request_json = lambda: _AwaitableValue({}) monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) @@ -442,9 +442,7 @@ def _load_user_app(monkeypatch): api_pkg.apps = apps_mod apps_auth_mod = ModuleType("api.apps.auth") - apps_auth_mod.get_auth_client = lambda _config: SimpleNamespace( - get_authorization_url=lambda state: f"https://oauth.example/{state}" - ) + apps_auth_mod.get_auth_client = lambda _config: SimpleNamespace(get_authorization_url=lambda state: f"https://oauth.example/{state}") monkeypatch.setitem(sys.modules, "api.apps.auth", apps_auth_mod) db_mod = ModuleType("api.db") @@ -508,16 +506,7 @@ def _load_user_app(monkeypatch): @staticmethod def get_api_key(tenant_id, model_name, model_type=None): return _MockTableObject( - id=1, - tenant_id=tenant_id, - llm_factory="", - model_type="chat", - llm_name=model_name, - api_key="fake-api-key", - api_base="https://api.example.com", - max_tokens=8192, - used_tokens=0, - status=1 + id=1, tenant_id=tenant_id, llm_factory="", model_type="chat", llm_name=model_name, api_key="fake-api-key", api_base="https://api.example.com", max_tokens=8192, used_tokens=0, status=1 ) tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService @@ -1417,54 +1406,67 @@ def _load_chat_routes_unit_module(monkeypatch): constants_mod.RetCode = SimpleNamespace(SUCCESS=0, DATA_ERROR=102, OPERATING_ERROR=103, AUTHENTICATION_ERROR=109) constants_mod.StatusEnum = SimpleNamespace(VALID=SimpleNamespace(value="1"), INVALID=SimpleNamespace(value="0")) from common.constants import MAXIMUM_PAGE_NUMBER as _MPN, MAXIMUM_TASK_PAGE_NUMBER as _MTPN + constants_mod.MAXIMUM_PAGE_NUMBER = _MPN constants_mod.MAXIMUM_TASK_PAGE_NUMBER = _MTPN monkeypatch.setitem(sys.modules, "common.constants", constants_mod) misc_utils_mod = ModuleType("common.misc_utils") misc_utils_mod.get_uuid = lambda: "generated-chat-id" + async def _thread_pool_exec(func, *args, **kwargs): return func(*args, **kwargs) + misc_utils_mod.thread_pool_exec = _thread_pool_exec monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod) dialog_service_mod = ModuleType("api.db.services.dialog_service") + class _DialogService: - model = SimpleNamespace(_meta=SimpleNamespace(fields={ - "id": None, - "tenant_id": None, - "name": None, - "description": None, - "icon": None, - "kb_ids": None, - "llm_id": None, - "llm_setting": None, - "prompt_config": None, - "similarity_threshold": None, - "vector_similarity_weight": None, - "top_n": None, - "top_k": None, - "rerank_id": None, - "meta_data_filter": None, - "created_by": None, - "create_time": None, - "create_date": None, - "update_time": None, - "update_date": None, - "status": None, - })) + model = SimpleNamespace( + _meta=SimpleNamespace( + fields={ + "id": None, + "tenant_id": None, + "name": None, + "description": None, + "icon": None, + "kb_ids": None, + "llm_id": None, + "llm_setting": None, + "prompt_config": None, + "similarity_threshold": None, + "vector_similarity_weight": None, + "top_n": None, + "top_k": None, + "rerank_id": None, + "meta_data_filter": None, + "created_by": None, + "create_time": None, + "create_date": None, + "update_time": None, + "update_date": None, + "status": None, + } + ) + ) + @staticmethod def query(**_kwargs): return [] + @staticmethod def save(**_kwargs): return True + @staticmethod def get_by_id(_chat_id): return False, None + @staticmethod def get_by_tenant_ids(*_args, **_kwargs): return [], 0 + dialog_service_mod.DialogService = _DialogService dialog_service_mod.async_ask = lambda *_args, **_kwargs: None dialog_service_mod.async_chat = lambda *_args, **_kwargs: None @@ -1477,6 +1479,7 @@ def _load_chat_routes_unit_module(monkeypatch): monkeypatch.setitem(sys.modules, "api.db.services.conversation_service", conversation_service_mod) kb_service_mod = ModuleType("api.db.services.knowledgebase_service") + class _KB: def __init__(self): self.id = "kb-1" @@ -1484,16 +1487,22 @@ def _load_chat_routes_unit_module(monkeypatch): self.chunk_num = 1 self.name = "Dataset A" self.status = "1" - kb_service_mod.KnowledgebaseService = type('KnowledgebaseService', (), { - 'accessible': staticmethod(lambda **_kwargs: [SimpleNamespace(id='kb-1')]), - 'query': staticmethod(lambda **_kwargs: [_KB()]), - 'get_by_id': staticmethod(lambda _id: (True, _KB())), - }) + + kb_service_mod.KnowledgebaseService = type( + "KnowledgebaseService", + (), + { + "accessible": staticmethod(lambda **_kwargs: [SimpleNamespace(id="kb-1")]), + "query": staticmethod(lambda **_kwargs: [_KB()]), + "get_by_id": staticmethod(lambda _id: (True, _KB())), + }, + ) monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod) tenant_model_provider_mod = ModuleType("api.db.joint_services.tenant_model_service") tenant_model_provider_mod.get_model_config_from_provider_instance = lambda *_args, **_kwargs: {} tenant_model_provider_mod.get_tenant_default_model_by_type = lambda *_args, **_kwargs: {} + def _split_model_name(model_name): parts = model_name.split("@") if len(parts) == 1: @@ -1502,6 +1511,7 @@ def _load_chat_routes_unit_module(monkeypatch): return parts[0], "default", parts[1] else: return parts[0], parts[1], parts[2] + tenant_model_provider_mod.split_model_name = staticmethod(_split_model_name) tenant_model_provider_mod.get_api_key = lambda *_args, **_kwargs: SimpleNamespace(id=1) monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_provider_mod) @@ -1515,39 +1525,43 @@ def _load_chat_routes_unit_module(monkeypatch): monkeypatch.setitem(sys.modules, "api.db.services.search_service", search_service_mod) user_service_mod = ModuleType("api.db.services.user_service") - user_service_mod.UserService = type('UserService', (), {}) - user_service_mod.TenantService = type('TenantService', (), { - 'get_by_id': staticmethod(lambda _tenant_id: (True, SimpleNamespace(llm_id='glm-4'))), - 'get_joined_tenants_by_user_id': staticmethod(lambda _user_id: [{'tenant_id': 'tenant-1'}, {'tenant_id': 'team-tenant-2'}]), - }) - user_service_mod.UserTenantService = type('UserTenantService', (), {'query': staticmethod(lambda **_kwargs: [])}) + user_service_mod.UserService = type("UserService", (), {}) + user_service_mod.TenantService = type( + "TenantService", + (), + { + "get_by_id": staticmethod(lambda _tenant_id: (True, SimpleNamespace(llm_id="glm-4"))), + "get_joined_tenants_by_user_id": staticmethod(lambda _user_id: [{"tenant_id": "tenant-1"}, {"tenant_id": "team-tenant-2"}]), + }, + ) + user_service_mod.UserTenantService = type("UserTenantService", (), {"query": staticmethod(lambda **_kwargs: [])}) monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod) chunk_feedback_service_mod = ModuleType("api.db.services.chunk_feedback_service") - chunk_feedback_service_mod.ChunkFeedbackService = type('ChunkFeedbackService', (), {'apply_feedback': staticmethod(lambda **_kwargs: {'success_count': 0, 'fail_count': 0, 'chunk_ids': []})}) + chunk_feedback_service_mod.ChunkFeedbackService = type("ChunkFeedbackService", (), {"apply_feedback": staticmethod(lambda **_kwargs: {"success_count": 0, "fail_count": 0, "chunk_ids": []})}) monkeypatch.setitem(sys.modules, "api.db.services.chunk_feedback_service", chunk_feedback_service_mod) api_utils_mod = ModuleType("api.utils.api_utils") api_utils_mod.check_duplicate_ids = lambda ids, _label: (list(dict.fromkeys(ids or [])), []) - api_utils_mod.get_data_error_result = lambda message='': {'code': 102, 'data': None, 'message': message} - api_utils_mod.get_json_result = lambda data=None, message='', code=0: {'code': code, 'data': data, 'message': message} - api_utils_mod.server_error_response = lambda ex: {'code': 500, 'data': None, 'message': str(ex)} - api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda func: func) + api_utils_mod.get_data_error_result = lambda message="": {"code": 102, "data": None, "message": message} + api_utils_mod.get_json_result = lambda data=None, message="", code=0: {"code": code, "data": data, "message": message} + api_utils_mod.server_error_response = lambda ex: {"code": 500, "data": None, "message": str(ex)} + api_utils_mod.validate_request = lambda *_args, **_kwargs: lambda func: func api_utils_mod.get_request_json = lambda: _AwaitableValue({}) monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) rag_pkg = ModuleType("rag") - rag_pkg.__path__ = [str(repo_root / 'rag')] - monkeypatch.setitem(sys.modules, 'rag', rag_pkg) - rag_prompts_pkg = ModuleType('rag.prompts') - rag_prompts_pkg.__path__ = [str(repo_root / 'rag' / 'prompts')] - monkeypatch.setitem(sys.modules, 'rag.prompts', rag_prompts_pkg) - rag_prompts_generator_mod = ModuleType('rag.prompts.generator') - rag_prompts_generator_mod.chunks_format = lambda reference: reference.get('chunks', []) if isinstance(reference, dict) else [] - monkeypatch.setitem(sys.modules, 'rag.prompts.generator', rag_prompts_generator_mod) - rag_prompts_template_mod = ModuleType('rag.prompts.template') - rag_prompts_template_mod.load_prompt = lambda *_args, **_kwargs: '' - monkeypatch.setitem(sys.modules, 'rag.prompts.template', rag_prompts_template_mod) + rag_pkg.__path__ = [str(repo_root / "rag")] + monkeypatch.setitem(sys.modules, "rag", rag_pkg) + rag_prompts_pkg = ModuleType("rag.prompts") + rag_prompts_pkg.__path__ = [str(repo_root / "rag" / "prompts")] + monkeypatch.setitem(sys.modules, "rag.prompts", rag_prompts_pkg) + rag_prompts_generator_mod = ModuleType("rag.prompts.generator") + rag_prompts_generator_mod.chunks_format = lambda reference: reference.get("chunks", []) if isinstance(reference, dict) else [] + monkeypatch.setitem(sys.modules, "rag.prompts.generator", rag_prompts_generator_mod) + rag_prompts_template_mod = ModuleType("rag.prompts.template") + rag_prompts_template_mod.load_prompt = lambda *_args, **_kwargs: "" + monkeypatch.setitem(sys.modules, "rag.prompts.template", rag_prompts_template_mod) spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) @@ -1564,27 +1578,27 @@ def test_create_chat_uses_tenant_default_llm_when_llm_id_is_null_unit(monkeypatc async def _request_json(): return { - 'name': 'chat-a', - 'dataset_ids': ['kb-1'], - 'llm_id': None, - 'llm_setting': {'temperature': 0.8}, - 'prompt_config': {'system': 'Answer with {knowledge}', 'parameters': [{'key': 'knowledge', 'optional': False}]}, + "name": "chat-a", + "dataset_ids": ["kb-1"], + "llm_id": None, + "llm_setting": {"temperature": 0.8}, + "prompt_config": {"system": "Answer with {knowledge}", "parameters": [{"key": "knowledge", "optional": False}]}, } - monkeypatch.setattr(module, 'get_request_json', _request_json) - monkeypatch.setattr(module.DialogService, 'query', lambda **_kwargs: []) + monkeypatch.setattr(module, "get_request_json", _request_json) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) def _save(**kwargs): saved.update(kwargs) return True - monkeypatch.setattr(module.DialogService, 'save', _save) - monkeypatch.setattr(module.DialogService, 'get_by_id', lambda _id: (True, SimpleNamespace(to_dict=lambda: saved))) + monkeypatch.setattr(module.DialogService, "save", _save) + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, SimpleNamespace(to_dict=lambda: saved))) res = _run(module.create.__wrapped__()) - assert res['code'] == 0 - assert saved['llm_id'] == 'glm-4' - assert saved['llm_setting']['temperature'] == 0.8 + assert res["code"] == 0 + assert saved["llm_id"] == "glm-4" + assert saved["llm_setting"]["temperature"] == 0.8 @pytest.mark.p2 @@ -1593,27 +1607,33 @@ def test_list_chats_authorized_multi_tenant_unit(monkeypatch): captured = {} monkeypatch.setattr( module, - 'request', + "request", SimpleNamespace( args=SimpleNamespace( get=lambda key, default=None: { - 'keywords': '', 'page': '1', 'page_size': '10', 'orderby': 'create_time', 'desc': 'true', 'id': None, 'name': None, + "keywords": "", + "page": "1", + "page_size": "10", + "orderby": "create_time", + "desc": "true", + "id": None, + "name": None, }.get(key, default), - getlist=lambda key: ['tenant-1', 'team-tenant-2'] if key == 'owner_ids' else [], + getlist=lambda key: ["tenant-1", "team-tenant-2"] if key == "owner_ids" else [], ) ), ) def _get_by_tenant_ids(owner_ids, user_id, *args, **kwargs): - captured['owner_ids'] = owner_ids - captured['user_id'] = user_id - return ([{'id': 'c1', 'tenant_id': 'tenant-1'}, {'id': 'c2', 'tenant_id': 'team-tenant-2'}], 2) + captured["owner_ids"] = owner_ids + captured["user_id"] = user_id + return ([{"id": "c1", "tenant_id": "tenant-1"}, {"id": "c2", "tenant_id": "team-tenant-2"}], 2) - monkeypatch.setattr(module.DialogService, 'get_by_tenant_ids', _get_by_tenant_ids) - monkeypatch.setattr(module.KnowledgebaseService, 'get_by_id', lambda _id: (False, None)) + monkeypatch.setattr(module.DialogService, "get_by_tenant_ids", _get_by_tenant_ids) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (False, None)) res = _run(module.list_chats.__wrapped__()) - assert res['code'] == 0 - assert res['data']['total'] == 2 - assert {c['id'] for c in res['data']['chats']} == {'c1', 'c2'} - assert set(captured['owner_ids']) == {'tenant-1', 'team-tenant-2'} - assert captured['user_id'] == 'tenant-1' + assert res["code"] == 0 + assert res["data"]["total"] == 2 + assert {c["id"] for c in res["data"]["chats"]} == {"c1", "c2"} + assert set(captured["owner_ids"]) == {"tenant-1", "team-tenant-2"} + assert captured["user_id"] == "tenant-1" diff --git a/test/testcases/test_admin_api/test_user_api_key_management/test_delete_user_api_key.py b/test/testcases/test_admin_api/test_user_api_key_management/test_delete_user_api_key.py index 6d91d3779d..8e5c8a6732 100644 --- a/test/testcases/test_admin_api/test_user_api_key_management/test_delete_user_api_key.py +++ b/test/testcases/test_admin_api/test_user_api_key_management/test_delete_user_api_key.py @@ -157,7 +157,7 @@ class TestDeleteUserApiKey: res: Any = requests.post(url=url, json=register_data) res: dict[str, Any] = res.json() if res.get("code") != 0 and "has already registered" not in res.get("message"): - raise Exception(f"Failed to create second user: {res.get("message")}") + raise Exception(f"Failed to create second user: {res.get('message')}") # Generate a token for the test user generate_response: dict[str, Any] = generate_user_api_key(admin_session, user_name) diff --git a/test/testcases/test_http_api/common.py b/test/testcases/test_http_api/common.py index 499baee818..54dd485f0b 100644 --- a/test/testcases/test_http_api/common.py +++ b/test/testcases/test_http_api/common.py @@ -523,5 +523,3 @@ def search_dataset(auth, dataset_id, payload=None, *, headers=HEADERS): url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/search" res = requests.post(url=url, headers=headers, auth=auth, json=payload) return res.json() - - diff --git a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py index 70d063c3ab..fea449b48c 100644 --- a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py @@ -212,6 +212,7 @@ def _load_chat_module(monkeypatch): common_constants_mod.StatusEnum = _StubStatusEnum # Import pure-Python constants from the real module (no heavy deps) from common.constants import MAXIMUM_PAGE_NUMBER as _MPN, MAXIMUM_TASK_PAGE_NUMBER as _MTPN + common_constants_mod.MAXIMUM_PAGE_NUMBER = _MPN common_constants_mod.MAXIMUM_TASK_PAGE_NUMBER = _MTPN monkeypatch.setitem(sys.modules, "common.constants", common_constants_mod) @@ -410,7 +411,7 @@ def _load_chat_module(monkeypatch): api_utils_mod.get_json_result = lambda data=None, message="", code=0: {"code": code, "data": data, "message": message} api_utils_mod.get_request_json = lambda: _AwaitableValue({}) api_utils_mod.server_error_response = lambda ex: {"code": 500, "data": None, "message": str(ex)} - api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda func: func) + api_utils_mod.validate_request = lambda *_args, **_kwargs: lambda func: func monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) rag_pkg = ModuleType("rag") diff --git a/test/testcases/test_http_api/test_chat_management/test_table_parser_dataset_chat.py b/test/testcases/test_http_api/test_chat_management/test_table_parser_dataset_chat.py index 3da599300f..71eb487700 100644 --- a/test/testcases/test_http_api/test_chat_management/test_table_parser_dataset_chat.py +++ b/test/testcases/test_http_api/test_chat_management/test_table_parser_dataset_chat.py @@ -90,11 +90,7 @@ TEST_EXCEL_DATA_2 = [ ["Keyboard", "79", "Electronics"], ] -DEFAULT_CHAT_PROMPT = ( - "You are a helpful assistant that answers questions about table data using SQL queries.\n\n" - "Here is the knowledge base:\n{knowledge}\n\n" - "Use this information to answer questions." -) +DEFAULT_CHAT_PROMPT = "You are a helpful assistant that answers questions about table data using SQL queries.\n\nHere is the knowledge base:\n{knowledge}\n\nUse this information to answer questions." @pytest.mark.usefixtures("add_table_parser_dataset") @@ -171,12 +167,7 @@ class TestTableParserDatasetChat: Test that table parser dataset chat works correctly. """ # Use class-level attributes (set by setup fixture) - answer = self._ask_question( - self.__class__.auth, - self.__class__.chat_id, - self.__class__.session_id, - question - ) + answer = self._ask_question(self.__class__.auth, self.__class__.chat_id, self.__class__.session_id, question) # Verify answer matches expected pattern if provided if expected_answer_pattern: @@ -315,7 +306,4 @@ class TestTableParserDatasetChat: answer: The actual answer from the chat assistant pattern: Regular expression pattern to match """ - assert re.search(pattern, answer, re.IGNORECASE), ( - f"Answer does not match expected pattern '{pattern}'.\n" - f"Answer: {answer}" - ) + assert re.search(pattern, answer, re.IGNORECASE), f"Answer does not match expected pattern '{pattern}'.\nAnswer: {answer}" diff --git a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_add_chunk.py b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_add_chunk.py index 74e86f1966..ed0a0526b0 100644 --- a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_add_chunk.py +++ b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_add_chunk.py @@ -107,9 +107,7 @@ class TestAddChunk: assert False, res chunks_count = res["data"]["doc"]["chunk_count"] res = add_chunk(HttpApiAuth, dataset_id, document_id, payload) - assert res["code"] == expected_code, ( - f"Expected code: {expected_code}, got: {res['code']}, message: {res.get('message')}" - ) + assert res["code"] == expected_code, f"Expected code: {expected_code}, got: {res['code']}, message: {res.get('message')}" if expected_code == 0: validate_chunk_details(dataset_id, document_id, payload, res) res = list_chunks(HttpApiAuth, dataset_id, document_id) diff --git a/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py b/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py index c51880c801..a8a3b75a91 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py +++ b/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py @@ -53,11 +53,7 @@ class _DummyKB: class _DummyRetriever: async def retrieval(self, *_args, **_kwargs): - return { - "chunks": [ - {"doc_id": "doc-1", "content_with_weight": "chunk-content", "similarity": 0.8, "docnm_kwd": "doc-title", "vector": [0.1]} - ] - } + return {"chunks": [{"doc_id": "doc-1", "content_with_weight": "chunk-content", "similarity": 0.8, "docnm_kwd": "doc-title", "vector": [0.1]}]} def retrieval_by_children(self, chunks, _tenant_ids): return chunks @@ -115,7 +111,7 @@ def _load_dify_retrieval_module(monkeypatch): # Mock tenant_llm_service for TenantLLMService and TenantService tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service") - + class _MockModelConfig: def __init__(self, tenant_id, model_name): self.tenant_id = tenant_id @@ -128,7 +124,7 @@ def _load_dify_retrieval_module(monkeypatch): self.used_tokens = 0 self.status = 1 self.id = 1 - + def to_dict(self): return { "tenant_id": self.tenant_id, @@ -140,35 +136,27 @@ def _load_dify_retrieval_module(monkeypatch): "max_tokens": self.max_tokens, "used_tokens": self.used_tokens, "status": self.status, - "id": self.id + "id": self.id, } - + class _StubTenantService: @staticmethod def get_by_id(tenant_id): # Return a mock tenant with default model configurations - return True, SimpleNamespace( - id=tenant_id, - llm_id="chat-model", - embd_id="embd-model", - asr_id="asr-model", - img2txt_id="img2txt-model", - rerank_id="rerank-model", - tts_id="tts-model" - ) - + return True, SimpleNamespace(id=tenant_id, llm_id="chat-model", embd_id="embd-model", asr_id="asr-model", img2txt_id="img2txt-model", rerank_id="rerank-model", tts_id="tts-model") + class _StubTenantLLMService: @staticmethod def get_api_key(tenant_id, model_name): return _MockModelConfig(tenant_id, model_name) - + @staticmethod def split_model_name_and_factory(model_name): if "@" in model_name: parts = model_name.split("@") return parts[0], parts[1] return model_name, None - + tenant_llm_service_mod.TenantService = _StubTenantService tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService @@ -180,32 +168,31 @@ def _load_dify_retrieval_module(monkeypatch): # Mock llm_service for LLMService llm_service_mod = ModuleType("api.db.services.llm_service") - + class _StubLLM: def __init__(self, llm_name): self.llm_name = llm_name self.is_tools = False - + class _StubLLMBundle: def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs): self.tenant_id = tenant_id self.model_config = model_config self.lang = lang - + def encode(self, texts: list): import numpy as np + # Return mock embeddings and token usage return [np.array([0.1, 0.2, 0.3]) for _ in texts], len(texts) * 10 - - llm_service_mod.LLMService = SimpleNamespace( - query=lambda llm_name: [_StubLLM(llm_name)] if llm_name else [] - ) + + llm_service_mod.LLMService = SimpleNamespace(query=lambda llm_name: [_StubLLM(llm_name)] if llm_name else []) llm_service_mod.LLMBundle = _StubLLMBundle monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod) # Mock tenant_model_service to ensure it uses mocked services tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service") - + class _MockModelConfig2: def __init__(self, tenant_id, model_name): self.tenant_id = tenant_id @@ -218,7 +205,7 @@ def _load_dify_retrieval_module(monkeypatch): self.used_tokens = 0 self.status = 1 self.id = 1 - + def to_dict(self): return { "tenant_id": self.tenant_id, @@ -230,9 +217,9 @@ def _load_dify_retrieval_module(monkeypatch): "max_tokens": self.max_tokens, "used_tokens": self.used_tokens, "status": self.status, - "id": self.id + "id": self.id, } - + def _get_model_config_by_id( tenant_model_id: int, allowed_tenant_ids=None, @@ -247,16 +234,16 @@ def _load_dify_retrieval_module(monkeypatch): if mock_tenant_id not in allowed_tenant_ids and str(requester_tenant_id) != mock_tenant_id: raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized") return _MockModelConfig2(mock_tenant_id, "model-1").to_dict() - + def _get_model_config_from_provider_instance(tenant_id: str, model_type: str, model_name: str): if not model_name: raise Exception("Model Name is required") return _MockModelConfig2(tenant_id, model_name).to_dict() - + def _get_tenant_default_model_by_type(tenant_id: str, model_type): # Return mock tenant with default model configurations return _MockModelConfig2(tenant_id, "chat-model").to_dict() - + tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id tenant_model_service_mod.get_model_config_from_provider_instance = _get_model_config_from_provider_instance tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type diff --git a/test/testcases/test_http_api/test_dataset_management/test_list_datasets.py b/test/testcases/test_http_api/test_dataset_management/test_list_datasets.py index 15e7fe662f..146ca9d2da 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_list_datasets.py +++ b/test/testcases/test_http_api/test_dataset_management/test_list_datasets.py @@ -146,8 +146,8 @@ class TestDatasetsList: @pytest.mark.parametrize( "params, assertions", [ - ({"orderby": "create_time"}, lambda r: (is_sorted(r["data"], "create_time", True))), - ({"orderby": "update_time"}, lambda r: (is_sorted(r["data"], "update_time", True))), + ({"orderby": "create_time"}, lambda r: is_sorted(r["data"], "create_time", True)), + ({"orderby": "update_time"}, lambda r: is_sorted(r["data"], "update_time", True)), ], ids=["orderby_create_time", "orderby_update_time"], ) @@ -185,16 +185,16 @@ class TestDatasetsList: @pytest.mark.parametrize( "params, assertions", [ - ({"desc": True}, lambda r: (is_sorted(r["data"], "create_time", True))), - ({"desc": False}, lambda r: (is_sorted(r["data"], "create_time", False))), - ({"desc": "true"}, lambda r: (is_sorted(r["data"], "create_time", True))), - ({"desc": "false"}, lambda r: (is_sorted(r["data"], "create_time", False))), - ({"desc": 1}, lambda r: (is_sorted(r["data"], "create_time", True))), - ({"desc": 0}, lambda r: (is_sorted(r["data"], "create_time", False))), - ({"desc": "yes"}, lambda r: (is_sorted(r["data"], "create_time", True))), - ({"desc": "no"}, lambda r: (is_sorted(r["data"], "create_time", False))), - ({"desc": "y"}, lambda r: (is_sorted(r["data"], "create_time", True))), - ({"desc": "n"}, lambda r: (is_sorted(r["data"], "create_time", False))), + ({"desc": True}, lambda r: is_sorted(r["data"], "create_time", True)), + ({"desc": False}, lambda r: is_sorted(r["data"], "create_time", False)), + ({"desc": "true"}, lambda r: is_sorted(r["data"], "create_time", True)), + ({"desc": "false"}, lambda r: is_sorted(r["data"], "create_time", False)), + ({"desc": 1}, lambda r: is_sorted(r["data"], "create_time", True)), + ({"desc": 0}, lambda r: is_sorted(r["data"], "create_time", False)), + ({"desc": "yes"}, lambda r: is_sorted(r["data"], "create_time", True)), + ({"desc": "no"}, lambda r: is_sorted(r["data"], "create_time", False)), + ({"desc": "y"}, lambda r: is_sorted(r["data"], "create_time", True)), + ({"desc": "n"}, lambda r: is_sorted(r["data"], "create_time", False)), ], ids=["desc=True", "desc=False", "desc=true", "desc=false", "desc=1", "desc=0", "desc=yes", "desc=no", "desc=y", "desc=n"], ) diff --git a/test/testcases/test_http_api/test_file_app/test_file_routes.py b/test/testcases/test_http_api/test_file_app/test_file_routes.py index 93774d2908..19c4a9876e 100644 --- a/test/testcases/test_http_api/test_file_app/test_file_routes.py +++ b/test/testcases/test_http_api/test_file_app/test_file_routes.py @@ -223,12 +223,16 @@ def test_upload_file_success_uses_new_service_layer(monkeypatch): "create_folder", lambda _file, parent_id, _names, _len_id, *_args: SimpleNamespace(id=parent_id), ) - monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace( - obj_exist=lambda *_args, **_kwargs: False, - put=lambda bucket, location, blob: storage_puts.append((bucket, location, blob)), - rm=lambda *_args, **_kwargs: None, - move=lambda *_args, **_kwargs: None, - )) + monkeypatch.setattr( + module.settings, + "STORAGE_IMPL", + SimpleNamespace( + obj_exist=lambda *_args, **_kwargs: False, + put=lambda bucket, location, blob: storage_puts.append((bucket, location, blob)), + rm=lambda *_args, **_kwargs: None, + move=lambda *_args, **_kwargs: None, + ), + ) ok, data = _run(module.upload_file("tenant1", "pf1", [_DummyUploadFile("a.txt", b"hello")])) assert ok is True @@ -291,12 +295,16 @@ def test_move_files_handles_dest_and_storage_move(monkeypatch): "get_by_ids", lambda _ids: [_DummyFile("file1", module.FileType.DOC.value, parent_id="src", location="old", name="a.txt")], ) - monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace( - obj_exist=lambda *_args, **_kwargs: False, - put=lambda *_args, **_kwargs: None, - rm=lambda *_args, **_kwargs: None, - move=lambda old_bucket, old_loc, new_bucket, new_loc: moved.append((old_bucket, old_loc, new_bucket, new_loc)), - )) + monkeypatch.setattr( + module.settings, + "STORAGE_IMPL", + SimpleNamespace( + obj_exist=lambda *_args, **_kwargs: False, + put=lambda *_args, **_kwargs: None, + rm=lambda *_args, **_kwargs: None, + move=lambda old_bucket, old_loc, new_bucket, new_loc: moved.append((old_bucket, old_loc, new_bucket, new_loc)), + ), + ) monkeypatch.setattr(module.FileService, "update_by_id", lambda file_id, data: updated.append((file_id, data)) or True) ok, message = _run(module.move_files("tenant1", ["file1"], "missing")) diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py index 663735b14e..6620a102e7 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py @@ -147,8 +147,10 @@ def _load_doc_module(monkeypatch, module_basename="chunk_api"): monkeypatch.setitem(sys.modules, "common.settings", common_settings_mod) common_misc_utils_mod = ModuleType("common.misc_utils") + async def _thread_pool_exec(func, *args, **kwargs): return func(*args, **kwargs) + common_misc_utils_mod.thread_pool_exec = _thread_pool_exec monkeypatch.setitem(sys.modules, "common.misc_utils", common_misc_utils_mod) @@ -243,9 +245,7 @@ def _load_doc_module(monkeypatch, module_basename="chunk_api"): api_utils_mod.get_error_data_result = lambda message="Sorry! Data missing!", code=102: {"code": code, "message": message} api_utils_mod.get_request_json = lambda: _AwaitableValue({}) api_utils_mod.get_result = lambda code=0, message="", data=None, total=None: { - key: value - for key, value in {"code": code, "message": message, "data": data, "total": total}.items() - if value is not None + key: value for key, value in {"code": code, "message": message, "data": data, "total": total}.items() if value is not None } api_utils_mod.server_error_response = lambda e: {"code": 500, "message": str(e)} monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) @@ -255,12 +255,11 @@ def _load_doc_module(monkeypatch, module_basename="chunk_api"): monkeypatch.setitem(sys.modules, "api.utils.image_utils", image_utils_mod) reference_metadata_utils_mod = ModuleType("api.utils.reference_metadata_utils") - reference_metadata_utils_mod.resolve_reference_metadata_preferences = ( - lambda req, *_args, **_kwargs: ( - bool((req.get("reference_metadata") or {}).get("include")), - set((req.get("reference_metadata") or {}).get("fields") or []), - ) + reference_metadata_utils_mod.resolve_reference_metadata_preferences = lambda req, *_args, **_kwargs: ( + bool((req.get("reference_metadata") or {}).get("include")), + set((req.get("reference_metadata") or {}).get("fields") or []), ) + def _enrich_chunks_with_document_metadata(chunks, metadata_fields=None): for chunk in chunks: doc_id = chunk.get("doc_id") or chunk.get("document_id") @@ -325,7 +324,7 @@ def _load_doc_module(monkeypatch, module_basename="chunk_api"): # Mock tenant_llm_service for TenantLLMService and TenantService tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service") - + class _MockModelConfig: def __init__(self, tenant_id, model_name): self.tenant_id = tenant_id @@ -338,7 +337,7 @@ def _load_doc_module(monkeypatch, module_basename="chunk_api"): self.used_tokens = 0 self.status = 1 self.id = 1 - + def to_dict(self): return { "tenant_id": self.tenant_id, @@ -350,46 +349,40 @@ def _load_doc_module(monkeypatch, module_basename="chunk_api"): "max_tokens": self.max_tokens, "used_tokens": self.used_tokens, "status": self.status, - "id": self.id + "id": self.id, } - + class _StubTenantService: @staticmethod def get_by_id(tenant_id): - return True, SimpleNamespace( - id=tenant_id, - llm_id="chat-model", - embd_id="embd-model", - asr_id="asr-model", - img2txt_id="img2txt-model", - rerank_id="rerank-model", - tts_id="tts-model" - ) - + return True, SimpleNamespace(id=tenant_id, llm_id="chat-model", embd_id="embd-model", asr_id="asr-model", img2txt_id="img2txt-model", rerank_id="rerank-model", tts_id="tts-model") + class _StubTenantLLMService: @staticmethod def get_api_key(tenant_id, model_name): return _MockModelConfig(tenant_id, model_name) - + @staticmethod def split_model_name_and_factory(model_name): if "@" in model_name: parts = model_name.split("@") return parts[0], parts[1] return model_name, None - + @staticmethod def get_by_id(tenant_model_id): return True, _MockModelConfig("tenant-1", "model-1") - + @staticmethod def model_instance(model_config): class _EmbedModel: def encode(self, texts): import numpy as np + return [np.array([0.2, 0.8]), np.array([0.3, 0.7])], 1 + return _EmbedModel() - + tenant_llm_service_mod.TenantService = _StubTenantService tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService @@ -401,32 +394,31 @@ def _load_doc_module(monkeypatch, module_basename="chunk_api"): # Mock LLMService llm_service_mod = ModuleType("api.db.services.llm_service") - + class _StubLLM: def __init__(self, llm_name): self.llm_name = llm_name self.is_tools = False - + class _StubLLMBundle: def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs): self.tenant_id = tenant_id self.model_config = model_config self.lang = lang - + def encode(self, texts: list): import numpy as np + # Return mock embeddings and token usage return [np.array([0.2, 0.8]), np.array([0.3, 0.7])], len(texts) * 10 - - llm_service_mod.LLMService = SimpleNamespace( - query=lambda llm_name: [_StubLLM(llm_name)] if llm_name else [] - ) + + llm_service_mod.LLMService = SimpleNamespace(query=lambda llm_name: [_StubLLM(llm_name)] if llm_name else []) llm_service_mod.LLMBundle = _StubLLMBundle monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod) # Mock tenant_model_service to ensure it uses mocked services tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service") - + class _MockModelConfig2: def __init__(self, tenant_id, model_name): self.tenant_id = tenant_id @@ -439,7 +431,7 @@ def _load_doc_module(monkeypatch, module_basename="chunk_api"): self.used_tokens = 0 self.status = 1 self.id = 1 - + def to_dict(self): return { "tenant_id": self.tenant_id, @@ -451,9 +443,9 @@ def _load_doc_module(monkeypatch, module_basename="chunk_api"): "max_tokens": self.max_tokens, "used_tokens": self.used_tokens, "status": self.status, - "id": self.id + "id": self.id, } - + def _get_model_config_by_id( tenant_model_id: int, allowed_tenant_ids=None, @@ -468,16 +460,16 @@ def _load_doc_module(monkeypatch, module_basename="chunk_api"): if mock_tenant_id not in allowed_tenant_ids and str(requester_tenant_id) != mock_tenant_id: raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized") return _MockModelConfig2(mock_tenant_id, "model-1").to_dict() - + def _get_model_config_from_provider_instance(tenant_id: str, model_type: str, model_name: str): if not model_name: raise Exception("Model Name is required") return _MockModelConfig2(tenant_id, model_name).to_dict() - + def _get_tenant_default_model_by_type(tenant_id: str, model_type): # Return mock tenant with default model configurations return _MockModelConfig2(tenant_id, "chat-model").to_dict() - + tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id tenant_model_service_mod.get_model_config_from_provider_instance = _get_model_config_from_provider_instance tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type @@ -636,7 +628,6 @@ class TestDocRoutesUnit: assert res["filename"] == "report.pdf" assert res["mimetype"] == "application/pdf" - def test_parse_branches(self, monkeypatch): module = _load_doc_module(monkeypatch) monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: False) diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_list_documents.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_list_documents.py index f2a2f5c905..af5b43c806 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_list_documents.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_list_documents.py @@ -155,10 +155,10 @@ class TestDocumentsList: @pytest.mark.parametrize( "params, expected_code, assertions, expected_message", [ - ({"orderby": None}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""), - ({"orderby": "create_time"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""), - ({"orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"]["docs"], "update_time", True)), ""), - pytest.param({"orderby": "name", "desc": "False"}, 0, lambda r: (is_sorted(r["data"]["docs"], "name", False)), "", marks=pytest.mark.skip(reason="issues/5851")), + ({"orderby": None}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", True), ""), + ({"orderby": "create_time"}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", True), ""), + ({"orderby": "update_time"}, 0, lambda r: is_sorted(r["data"]["docs"], "update_time", True), ""), + pytest.param({"orderby": "name", "desc": "False"}, 0, lambda r: is_sorted(r["data"]["docs"], "name", False), "", marks=pytest.mark.skip(reason="issues/5851")), pytest.param({"orderby": "unknown"}, 102, 0, "orderby should be create_time or update_time", marks=pytest.mark.skip(reason="issues/5851")), ], ) @@ -184,14 +184,14 @@ class TestDocumentsList: @pytest.mark.parametrize( "params, expected_code, assertions, expected_message", [ - ({"desc": None}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""), - ({"desc": "true"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""), - ({"desc": "True"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""), - ({"desc": True}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""), - pytest.param({"desc": "false"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", False)), "", marks=pytest.mark.skip(reason="issues/5851")), - ({"desc": "False"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", False)), ""), - ({"desc": False}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", False)), ""), - ({"desc": "False", "orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"]["docs"], "update_time", False)), ""), + ({"desc": None}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", True), ""), + ({"desc": "true"}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", True), ""), + ({"desc": "True"}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", True), ""), + ({"desc": True}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", True), ""), + pytest.param({"desc": "false"}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", False), "", marks=pytest.mark.skip(reason="issues/5851")), + ({"desc": "False"}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", False), ""), + ({"desc": False}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", False), ""), + ({"desc": "False", "orderby": "update_time"}, 0, lambda r: is_sorted(r["data"]["docs"], "update_time", False), ""), pytest.param({"desc": "unknown"}, 102, 0, "desc should be true or false", marks=pytest.mark.skip(reason="issues/5851")), ], ) @@ -231,7 +231,6 @@ class TestDocumentsList: assert len(res["data"]["docs"]) == expected_num assert res["data"]["total"] == expected_num - @pytest.mark.p1 @pytest.mark.parametrize( "params, expected_code, expected_num, expected_message", @@ -240,21 +239,21 @@ class TestDocumentsList: ({"name": ""}, 0, 5, ""), ({"name": "ragflow_test_upload_0.txt"}, 0, 1, ""), ( - {"name": "unknown.txt"}, - 102, - 0, - "You don't own the document unknown.txt.", + {"name": "unknown.txt"}, + 102, + 0, + "You don't own the document unknown.txt.", ), ], ) def test_name( - self, - HttpApiAuth, - add_documents, - params, - expected_code, - expected_num, - expected_message, + self, + HttpApiAuth, + add_documents, + params, + expected_code, + expected_num, + expected_message, ): dataset_id, _ = add_documents res = list_documents(HttpApiAuth, dataset_id, params=params) @@ -267,7 +266,6 @@ class TestDocumentsList: else: assert res["message"] == expected_message - @pytest.mark.p1 @pytest.mark.parametrize( "document_id, expected_code, expected_num, expected_message", @@ -279,13 +277,13 @@ class TestDocumentsList: ], ) def test_id( - self, - HttpApiAuth, - add_documents, - document_id, - expected_code, - expected_num, - expected_message, + self, + HttpApiAuth, + add_documents, + document_id, + expected_code, + expected_num, + expected_message, ): dataset_id, document_ids = add_documents if callable(document_id): @@ -304,7 +302,6 @@ class TestDocumentsList: else: assert res["message"] == expected_message - @pytest.mark.p2 @pytest.mark.parametrize( "document_id, name, expected_code, expected_num, expected_message", @@ -313,23 +310,23 @@ class TestDocumentsList: (lambda r: r[0], "ragflow_test_upload_1.txt", 0, 0, ""), (lambda r: r[0], "unknown", 102, 0, "You don't own the document unknown."), ( - "id", - "ragflow_test_upload_0.txt", - 102, - 0, - "You don't own the document id.", + "id", + "ragflow_test_upload_0.txt", + 102, + 0, + "You don't own the document id.", ), ], ) def test_name_and_id( - self, - HttpApiAuth, - add_documents, - document_id, - name, - expected_code, - expected_num, - expected_message, + self, + HttpApiAuth, + add_documents, + document_id, + name, + expected_code, + expected_num, + expected_message, ): dataset_id, document_ids = add_documents if callable(document_id): @@ -343,7 +340,6 @@ class TestDocumentsList: else: assert res["message"] == expected_message - @pytest.mark.p3 def test_concurrent_list(self, HttpApiAuth, add_documents): dataset_id, _ = add_documents @@ -379,9 +375,7 @@ class TestDocumentsList: ), ], ) - def test_metadata_condition_validation( - self, HttpApiAuth, add_documents, params, expected_code, expected_message - ): + def test_metadata_condition_validation(self, HttpApiAuth, add_documents, params, expected_code, expected_message): dataset_id, _ = add_documents res = list_documents(HttpApiAuth, dataset_id, params=params) assert res["code"] == expected_code @@ -399,9 +393,7 @@ class TestDocumentsList: ({"create_time_from": "0", "create_time_to": "9999999999000"}, 0, 5), ], ) - def test_create_time_filter( - self, HttpApiAuth, add_documents, params, expected_code, expected_total - ): + def test_create_time_filter(self, HttpApiAuth, add_documents, params, expected_code, expected_total): dataset_id, _ = add_documents res = list_documents(HttpApiAuth, dataset_id, params=params) @@ -417,9 +409,7 @@ class TestDocumentsList: ({"run": ["INVALID_STATUS"]}, 102, "Invalid filter run status conditions: INVALID_STATUS"), ], ) - def test_run_status_filter_invalid( - self, HttpApiAuth, add_documents, params, expected_code, expected_message - ): + def test_run_status_filter_invalid(self, HttpApiAuth, add_documents, params, expected_code, expected_message): dataset_id, _ = add_documents res = list_documents(HttpApiAuth, dataset_id, params=params) @@ -434,9 +424,7 @@ class TestDocumentsList: ({"run": ["UNSTART"]}, 5), ], ) - def test_run_status_filter_unstart( - self, HttpApiAuth, add_documents, params, expected_size - ): + def test_run_status_filter_unstart(self, HttpApiAuth, add_documents, params, expected_size): dataset_id, _ = add_documents res = list_documents(HttpApiAuth, dataset_id, params=params) diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_batch_update.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_batch_update.py index f2b3060d64..b4413c6830 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_batch_update.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_batch_update.py @@ -19,6 +19,7 @@ End-to-end tests for metadata batch update API. This test file converts the unit test test_metadata_batch_update from test_doc_sdk_routes_unit.py to end-to-end tests that call the actual HTTP API. """ + import pytest from common import ( update_documents_metadata, @@ -376,7 +377,7 @@ class TestMetadataBatchUpdateSuccess: { "selector": { "document_ids": document_ids, - "metadata_condition": {"conditions": [{"comparison_operator":"is", "name": "nonexistent_key", "value": "nonexistent_value"}]}, + "metadata_condition": {"conditions": [{"comparison_operator": "is", "name": "nonexistent_key", "value": "nonexistent_value"}]}, }, "updates": [{"key": "author", "value": "test"}], "deletes": [], diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_retrieval.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_retrieval.py index 77f9312470..1afc41bcfd 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_retrieval.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_retrieval.py @@ -53,10 +53,7 @@ def _condition_parsing_complete(_auth, dataset_id): @pytest.fixture(scope="function") def add_dataset_with_metadata(HttpApiAuth): # First create the dataset - res = create_dataset(HttpApiAuth, { - "name": f"test_metadata_{int(__import__('time').time())}", - "chunk_method": "naive" - }) + res = create_dataset(HttpApiAuth, {"name": f"test_metadata_{int(__import__('time').time())}", "chunk_method": "naive"}) assert res["code"] == 0, f"Failed to create dataset: {res}" dataset_id = res["data"]["id"] @@ -75,7 +72,7 @@ def add_dataset_with_metadata(HttpApiAuth): {"key": "era", "type": "string", "description": "Historical era"}, {"key": "achievements", "type": "list", "description": "Major achievements"}, ] - } + }, ).json() assert res["code"] == 0, f"Failed to configure metadata: {res}" @@ -118,14 +115,10 @@ class TestMetadataWithRetrieval: doc2_id = res["data"][1]["id"] # Add different metadata to each document - res = update_document(HttpApiAuth, dataset_id, doc1_id, { - "meta_fields": {"character": "Zhuge Liang", "era": "Three Kingdoms"} - }) + res = update_document(HttpApiAuth, dataset_id, doc1_id, {"meta_fields": {"character": "Zhuge Liang", "era": "Three Kingdoms"}}) assert res["code"] == 0, f"Failed to update doc1 metadata: {res}" - res = update_document(HttpApiAuth, dataset_id, doc2_id, { - "meta_fields": {"character": "Cao Cao", "era": "Late Eastern Han"} - }) + res = update_document(HttpApiAuth, dataset_id, doc2_id, {"meta_fields": {"character": "Cao Cao", "era": "Late Eastern Han"}}) assert res["code"] == 0, f"Failed to update doc2 metadata: {res}" # Parse both documents @@ -136,20 +129,14 @@ class TestMetadataWithRetrieval: assert _condition_parsing_complete(HttpApiAuth, dataset_id), "Parsing timeout" # Test retrieval WITH metadata filter for "Zhuge Liang" - res = retrieval_chunks(HttpApiAuth, { - "question": "Zhuge Liang", - "dataset_ids": [dataset_id], - "metadata_condition": { - "logic": "and", - "conditions": [ - { - "name": "character", - "comparison_operator": "is", - "value": "Zhuge Liang" - } - ] - } - }) + res = retrieval_chunks( + HttpApiAuth, + { + "question": "Zhuge Liang", + "dataset_ids": [dataset_id], + "metadata_condition": {"logic": "and", "conditions": [{"name": "character", "comparison_operator": "is", "value": "Zhuge Liang"}]}, + }, + ) assert res["code"] == 0, f"Retrieval with metadata filter failed: {res}" chunks_with_filter = res["data"]["chunks"] diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_upload_documents.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_upload_documents.py index 050119ae47..0b21dd6db0 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_upload_documents.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_upload_documents.py @@ -140,7 +140,7 @@ class TestDocumentsUpload: fp = create_txt_file(tmp_path / "ragflow_test.txt") res = upload_documents(HttpApiAuth, "invalid_dataset_id", [fp]) assert res["code"] == 102 - assert res["message"] == "Can\'t find the dataset with ID invalid_dataset_id!" + assert res["message"] == "Can't find the dataset with ID invalid_dataset_id!" @pytest.mark.p2 def test_duplicate_files(self, HttpApiAuth, add_dataset_func, tmp_path): diff --git a/test/testcases/test_http_api/test_session_management/test_agent_completions.py b/test/testcases/test_http_api/test_session_management/test_agent_completions.py index 18f0392471..2006bfde66 100644 --- a/test/testcases/test_http_api/test_session_management/test_agent_completions.py +++ b/test/testcases/test_http_api/test_session_management/test_agent_completions.py @@ -56,6 +56,7 @@ def _agent_items(res): return data.get("canvas", []) return data + @pytest.fixture(scope="function") def agent_id(HttpApiAuth, request): res = list_agents(HttpApiAuth, {"page_size": 100}) diff --git a/test/testcases/test_http_api/test_session_management/test_agent_sessions.py b/test/testcases/test_http_api/test_session_management/test_agent_sessions.py index 6c606c9686..4b8bf5e389 100644 --- a/test/testcases/test_http_api/test_session_management/test_agent_sessions.py +++ b/test/testcases/test_http_api/test_session_management/test_agent_sessions.py @@ -56,6 +56,7 @@ def _agent_items(res): return data.get("canvas", []) return data + @pytest.fixture(scope="function") def agent_id(HttpApiAuth, request): res = list_agents(HttpApiAuth, {"page_size": 100}) @@ -81,7 +82,6 @@ def agent_id(HttpApiAuth, request): class TestAgentSessions: - @pytest.mark.p2 def test_agent_crud_validation_contract(self, HttpApiAuth, agent_id): res = list_agents(HttpApiAuth, {"id": "missing-agent-id", "title": "missing-agent-title"}) diff --git a/test/testcases/test_http_api/test_session_management/test_chat_completions_openai.py b/test/testcases/test_http_api/test_session_management/test_chat_completions_openai.py index 4df694dc63..b45ae2562d 100644 --- a/test/testcases/test_http_api/test_session_management/test_chat_completions_openai.py +++ b/test/testcases/test_http_api/test_session_management/test_chat_completions_openai.py @@ -77,8 +77,7 @@ class TestChatCompletionsOpenAI: assert "prompt_tokens" in usage, f"'usage' should contain 'prompt_tokens': {usage}" assert "completion_tokens" in usage, f"'usage' should contain 'completion_tokens': {usage}" assert "total_tokens" in usage, f"'usage' should contain 'total_tokens': {usage}" - assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"], \ - f"total_tokens should equal prompt_tokens + completion_tokens: {usage}" + assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"], f"total_tokens should equal prompt_tokens + completion_tokens: {usage}" @pytest.mark.p2 def test_openai_chat_completion_token_count_reasonable(self, HttpApiAuth, add_dataset_func, tmp_path, request): diff --git a/test/testcases/test_http_api/test_session_management/test_list_sessions_with_chat_assistant.py b/test/testcases/test_http_api/test_session_management/test_list_sessions_with_chat_assistant.py index 8db09d5208..4d3d79c223 100644 --- a/test/testcases/test_http_api/test_session_management/test_list_sessions_with_chat_assistant.py +++ b/test/testcases/test_http_api/test_session_management/test_list_sessions_with_chat_assistant.py @@ -86,10 +86,10 @@ class TestSessionsWithChatAssistantList: @pytest.mark.parametrize( "params, expected_code, assertions, expected_message", [ - ({"orderby": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), - ({"orderby": "create_time"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), - ({"orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", True)), ""), - ({"orderby": "name", "desc": "False"}, 0, lambda r: (is_sorted(r["data"], "name", False)), ""), + ({"orderby": None}, 0, lambda r: is_sorted(r["data"], "create_time", True), ""), + ({"orderby": "create_time"}, 0, lambda r: is_sorted(r["data"], "create_time", True), ""), + ({"orderby": "update_time"}, 0, lambda r: is_sorted(r["data"], "update_time", True), ""), + ({"orderby": "name", "desc": "False"}, 0, lambda r: is_sorted(r["data"], "name", False), ""), pytest.param({"orderby": "unknown"}, 102, 0, "orderby should be create_time or update_time", marks=pytest.mark.skip(reason="issues/")), ], ) @@ -115,14 +115,14 @@ class TestSessionsWithChatAssistantList: @pytest.mark.parametrize( "params, expected_code, assertions, expected_message", [ - ({"desc": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), - ({"desc": "true"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), - ({"desc": "True"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), - ({"desc": True}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), - ({"desc": "false"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), - ({"desc": "False"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), - ({"desc": False}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), - ({"desc": "False", "orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", False)), ""), + ({"desc": None}, 0, lambda r: is_sorted(r["data"], "create_time", True), ""), + ({"desc": "true"}, 0, lambda r: is_sorted(r["data"], "create_time", True), ""), + ({"desc": "True"}, 0, lambda r: is_sorted(r["data"], "create_time", True), ""), + ({"desc": True}, 0, lambda r: is_sorted(r["data"], "create_time", True), ""), + ({"desc": "false"}, 0, lambda r: is_sorted(r["data"], "create_time", False), ""), + ({"desc": "False"}, 0, lambda r: is_sorted(r["data"], "create_time", False), ""), + ({"desc": False}, 0, lambda r: is_sorted(r["data"], "create_time", False), ""), + ({"desc": "False", "orderby": "update_time"}, 0, lambda r: is_sorted(r["data"], "update_time", False), ""), pytest.param({"desc": "unknown"}, 102, 0, "desc should be true or false", marks=pytest.mark.skip(reason="issues/")), ], ) diff --git a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py index 0287fc75aa..4fec9309d8 100644 --- a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py @@ -247,6 +247,7 @@ def _load_session_module(monkeypatch): common_constants_mod.TAG_FLD = "tag_feas" # Import pure-Python constants from the real module (no heavy deps) from common.constants import MAXIMUM_PAGE_NUMBER as _MPN, MAXIMUM_TASK_PAGE_NUMBER as _MTPN + common_constants_mod.MAXIMUM_PAGE_NUMBER = _MPN common_constants_mod.MAXIMUM_TASK_PAGE_NUMBER = _MTPN monkeypatch.setitem(sys.modules, "common.constants", common_constants_mod) @@ -269,13 +270,11 @@ def _load_session_module(monkeypatch): api_utils_mod.get_error_data_result = lambda message="Sorry! Data missing!", code=_StubRetCode.DATA_ERROR: {"code": code, "message": message} api_utils_mod.get_json_result = lambda code=_StubRetCode.SUCCESS, message="success", data=None: {"code": code, "message": message, "data": data} api_utils_mod.get_result = lambda code=_StubRetCode.SUCCESS, message="", data=None, total=None: { - key: value - for key, value in {"code": code, "message": message, "data": data, "total": total}.items() - if value is not None + key: value for key, value in {"code": code, "message": message, "data": data, "total": total}.items() if value is not None } api_utils_mod.get_request_json = lambda: _AwaitableValue({}) api_utils_mod.server_error_response = lambda e: {"code": _StubRetCode.SERVER_ERROR, "message": str(e)} - api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda func: func) + api_utils_mod.validate_request = lambda *_args, **_kwargs: lambda func: func monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) rag_app_tag_mod = ModuleType("rag.app.tag") @@ -344,7 +343,7 @@ def _load_session_module(monkeypatch): # Mock tenant_llm_service for TenantLLMService and TenantService tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service") - + class _MockModelConfig: def __init__(self, tenant_id, model_name): self.tenant_id = tenant_id @@ -357,7 +356,7 @@ def _load_session_module(monkeypatch): self.used_tokens = 0 self.status = 1 self.id = 1 - + def to_dict(self): return { "tenant_id": self.tenant_id, @@ -369,23 +368,15 @@ def _load_session_module(monkeypatch): "max_tokens": self.max_tokens, "used_tokens": self.used_tokens, "status": self.status, - "id": self.id + "id": self.id, } - + class _StubTenantService: @staticmethod def get_by_id(tenant_id): # Return a mock tenant with default model configurations - return True, SimpleNamespace( - id=tenant_id, - llm_id="chat-model", - embd_id="embd-model", - asr_id="asr-model", - img2txt_id="img2txt-model", - rerank_id="rerank-model", - tts_id="tts-model" - ) - + return True, SimpleNamespace(id=tenant_id, llm_id="chat-model", embd_id="embd-model", asr_id="asr-model", img2txt_id="img2txt-model", rerank_id="rerank-model", tts_id="tts-model") + class _StubTenantLLMService: @staticmethod def get_api_key(tenant_id, model_name): @@ -410,16 +401,14 @@ def _load_session_module(monkeypatch): # Mock LLMService llm_service_mod = ModuleType("api.db.services.llm_service") - + class _StubLLM: def __init__(self, llm_name): self.llm_name = llm_name self.is_tools = False - - llm_service_mod.LLMService = SimpleNamespace( - query=lambda llm_name: [_StubLLM(llm_name)] if llm_name else [] - ) - + + llm_service_mod.LLMService = SimpleNamespace(query=lambda llm_name: [_StubLLM(llm_name)] if llm_name else []) + class _StubLLMBundle: def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs): self.tenant_id = tenant_id @@ -431,13 +420,13 @@ def _load_session_module(monkeypatch): def transcription(self, audio_path): return "mock transcription" - + llm_service_mod.LLMBundle = _StubLLMBundle monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod) # Mock tenant_model_service to ensure it uses mocked services tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service") - + class _MockModelConfig2: def __init__(self, tenant_id, model_name, model_type="chat"): self.tenant_id = tenant_id @@ -462,7 +451,7 @@ def _load_session_module(monkeypatch): "max_tokens": self.max_tokens, "used_tokens": self.used_tokens, "status": self.status, - "id": self.id + "id": self.id, } def _get_model_config_by_id( @@ -501,6 +490,7 @@ def _load_session_module(monkeypatch): def _get_tenant_default_model_by_type(tenant_id: str, model_type): # Check if tenant exists from api.db.services.tenant_llm_service import TenantService + exist, tenant = TenantService.get_by_id(tenant_id) if not exist: raise LookupError("Tenant not found!") @@ -523,19 +513,11 @@ def _load_session_module(monkeypatch): raise Exception("OCR model name is required") if not model_name: # Use friendly model type names - friendly_names = { - "embedding": "Embedding", - "speech2text": "ASR", - "image2text": "Image2Text", - "chat": "Chat", - "rerank": "Rerank", - "tts": "TTS", - "ocr": "OCR" - } + friendly_names = {"embedding": "Embedding", "speech2text": "ASR", "image2text": "Image2Text", "chat": "Chat", "rerank": "Rerank", "tts": "TTS", "ocr": "OCR"} friendly_name = friendly_names.get(model_type_val, model_type_val) raise Exception(f"No default {friendly_name} model is set") return _MockModelConfig2(tenant_id, model_name, model_type_val).to_dict() - + tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id tenant_model_service_mod.get_model_config_from_provider_instance = _get_model_config_from_provider_instance tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type @@ -710,7 +692,7 @@ def _load_session_module(monkeypatch): module.manager = _DummyManager() monkeypatch.setitem(sys.modules, "test_session_sdk_routes_unit_module", module) spec.loader.exec_module(module) - + # Add TenantService to module for test compatibility class _StubTenantServiceForTest: @staticmethod @@ -721,18 +703,10 @@ def _load_session_module(monkeypatch): @staticmethod def get_by_id(tenant_id): # Return mock tenant by id - return True, SimpleNamespace( - id=tenant_id, - llm_id="chat-model", - embd_id="embd-model", - asr_id="asr-model", - img2txt_id="img2txt-model", - rerank_id="rerank-model", - tts_id="tts-model" - ) + return True, SimpleNamespace(id=tenant_id, llm_id="chat-model", embd_id="embd-model", asr_id="asr-model", img2txt_id="img2txt-model", rerank_id="rerank-model", tts_id="tts-model") module.TenantService = _StubTenantServiceForTest - + return module @@ -2202,6 +2176,7 @@ def test_build_reference_chunks_metadata_matrix_unit(monkeypatch): # chat_api unit tests — session user-id spoofing fix # --------------------------------------------------------------------------- + def _load_chat_api_module(monkeypatch): """Load api/apps/restful_apis/chat_api.py with all heavy dependencies mocked.""" repo_root = Path(__file__).resolve().parents[4] @@ -2309,13 +2284,16 @@ def _load_chat_api_module(monkeypatch): dialog_svc_mod.DialogService = SimpleNamespace( model=SimpleNamespace(_meta=SimpleNamespace(fields=[])), query=lambda **_k: [SimpleNamespace(id="chat-1", icon="")], - get_by_id=lambda _id: (True, SimpleNamespace( - prompt_config={"prologue": ""}, - tenant_id="tenant-1", - llm_id="model", - kb_ids=[], - id=_id, - )), + get_by_id=lambda _id: ( + True, + SimpleNamespace( + prompt_config={"prologue": ""}, + tenant_id="tenant-1", + llm_id="model", + kb_ids=[], + id=_id, + ), + ), ) dialog_svc_mod.async_chat = lambda *_a, **_k: None dialog_svc_mod.gen_mindmap = lambda *_a, **_k: None @@ -2355,7 +2333,7 @@ def _load_chat_api_module(monkeypatch): api_utils_mod.get_json_result = lambda code=_RetCode.SUCCESS, message="success", data=None: {"code": code, "message": message, "data": data} api_utils_mod.get_request_json = lambda: _AwaitableValue({}) api_utils_mod.server_error_response = lambda e: {"code": _RetCode.SERVER_ERROR, "message": str(e)} - api_utils_mod.validate_request = lambda *_a, **_k: (lambda func: func) + api_utils_mod.validate_request = lambda *_a, **_k: lambda func: func monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) rag_gen_mod = ModuleType("rag.prompts.generator") @@ -2417,12 +2395,14 @@ def test_session_completion_user_id_not_spoofable(monkeypatch): monkeypatch.setattr( module, "get_request_json", - lambda: _AwaitableValue({ - "messages": [{"role": "user", "content": "hello"}], - "chat_id": "chat-1", - "user_id": "attacker-id", - "stream": False, - }), + lambda: _AwaitableValue( + { + "messages": [{"role": "user", "content": "hello"}], + "chat_id": "chat-1", + "user_id": "attacker-id", + "stream": False, + } + ), ) _run(inspect.unwrap(module.session_completion)()) @@ -2468,16 +2448,18 @@ def test_session_completion_uses_server_history_by_default(monkeypatch): monkeypatch.setattr( module, "get_request_json", - lambda: _AwaitableValue({ - "chat_id": "chat-1", - "session_id": "session-1", - "stream": False, - "messages": [ - {"role": "user", "content": "client old question", "id": "client-old"}, - {"role": "assistant", "content": "client old answer", "id": "client-old"}, - {"role": "user", "content": "latest question", "id": "latest"}, - ], - }), + lambda: _AwaitableValue( + { + "chat_id": "chat-1", + "session_id": "session-1", + "stream": False, + "messages": [ + {"role": "user", "content": "client old question", "id": "client-old"}, + {"role": "assistant", "content": "client old answer", "id": "client-old"}, + {"role": "user", "content": "latest question", "id": "latest"}, + ], + } + ), ) res = _run(inspect.unwrap(module.session_completion)()) @@ -2512,15 +2494,17 @@ def test_session_completion_preserves_zero_generation_params(monkeypatch): monkeypatch.setattr( module, "get_request_json", - lambda: _AwaitableValue({ - "stream": False, - "messages": [{"role": "user", "content": "latest question"}], - "temperature": 0, - "top_p": 0, - "frequency_penalty": 0, - "presence_penalty": 0, - "max_tokens": 0, - }), + lambda: _AwaitableValue( + { + "stream": False, + "messages": [{"role": "user", "content": "latest question"}], + "temperature": 0, + "top_p": 0, + "frequency_penalty": 0, + "presence_penalty": 0, + "max_tokens": 0, + } + ), ) res = _run(inspect.unwrap(module.session_completion)()) @@ -2585,14 +2569,16 @@ def test_session_completion_merges_generation_params_for_existing_chat(monkeypat monkeypatch.setattr( module, "get_request_json", - lambda: _AwaitableValue({ - "chat_id": "chat-1", - "session_id": "session-1", - "stream": False, - "messages": [{"role": "user", "content": "latest question"}], - "temperature": 0, - "presence_penalty": 0, - }), + lambda: _AwaitableValue( + { + "chat_id": "chat-1", + "session_id": "session-1", + "stream": False, + "messages": [{"role": "user", "content": "latest question"}], + "temperature": 0, + "presence_penalty": 0, + } + ), ) res = _run(inspect.unwrap(module.session_completion)()) @@ -2644,17 +2630,19 @@ def test_session_completion_can_use_submitted_full_history(monkeypatch): monkeypatch.setattr( module, "get_request_json", - lambda: _AwaitableValue({ - "chat_id": "chat-1", - "session_id": "session-1", - "stream": False, - "pass_all_history_messages": True, - "messages": [ - {"role": "user", "content": "client old question", "id": "client-old"}, - {"role": "assistant", "content": "client old answer", "id": "client-old"}, - {"role": "user", "content": "latest question", "id": "latest"}, - ], - }), + lambda: _AwaitableValue( + { + "chat_id": "chat-1", + "session_id": "session-1", + "stream": False, + "pass_all_history_messages": True, + "messages": [ + {"role": "user", "content": "client old question", "id": "client-old"}, + {"role": "assistant", "content": "client old answer", "id": "client-old"}, + {"role": "user", "content": "latest question", "id": "latest"}, + ], + } + ), ) res = _run(inspect.unwrap(module.session_completion)()) @@ -2706,12 +2694,14 @@ def test_session_completion_accepts_question_payload(monkeypatch): monkeypatch.setattr( module, "get_request_json", - lambda: _AwaitableValue({ - "chat_id": "chat-1", - "session_id": "session-1", - "stream": False, - "question": "latest question", - }), + lambda: _AwaitableValue( + { + "chat_id": "chat-1", + "session_id": "session-1", + "stream": False, + "question": "latest question", + } + ), ) res = _run(inspect.unwrap(module.session_completion)()) diff --git a/test/testcases/test_sdk_api/common.py b/test/testcases/test_sdk_api/common.py index e3a0e3d030..4f1d7d38ca 100644 --- a/test/testcases/test_sdk_api/common.py +++ b/test/testcases/test_sdk_api/common.py @@ -53,13 +53,13 @@ def list_all_sessions(chat_assistant: Chat, *, limit: int | None = None, page_si def valid_chat_llm_id(client: RAGFlow) -> str: # SDK tests use the tenant's configured chat model; this helper discovers test fixture state, not SDK behavior. - res = client.get('/users/me/models') + res = client.get("/users/me/models") data = res.json() - if data.get('code') == 0: - llm_id = (data.get('data') or {}).get('llm_id') + if data.get("code") == 0: + llm_id = (data.get("data") or {}).get("llm_id") if llm_id: return llm_id - raise Exception('No valid chat llm_id is configured for the current tenant') + raise Exception("No valid chat llm_id is configured for the current tenant") # DATASET MANAGEMENT diff --git a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/conftest.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/conftest.py index 57bbd879a0..4727c85a48 100644 --- a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/conftest.py +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/conftest.py @@ -15,7 +15,6 @@ # - import pytest from common import batch_add_chunks, delete_all_chunks from pytest import FixtureRequest @@ -31,11 +30,13 @@ def condition(_dataset: DataSet): return False return True + @wait_for(30, 1, "Chunk indexing timeout") def chunks_visible(_document: Document, _chunk_ids: list[str]): visible_ids = {chunk.id for chunk in _document.list_chunks(page_size=100)} return set(_chunk_ids).issubset(visible_ids) + @pytest.fixture(scope="function") def add_chunks_func(request: FixtureRequest, add_document: tuple[DataSet, Document]) -> tuple[DataSet, Document, list[Chunk]]: def cleanup(): diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py index 8146b76cae..91673746e9 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py @@ -306,7 +306,7 @@ class TestDatasetCreate: ("qa", "qa"), ("table", "table"), ("tag", "tag"), - ("resume", "resume") + ("resume", "resume"), ], ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag", "resume"], ) @@ -328,7 +328,9 @@ class TestDatasetCreate: payload = {"name": name, "chunk_method": chunk_method} with pytest.raises(Exception) as exception_info: client.create_dataset(**payload) - assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table', 'tag' or 'resume'" in str(exception_info.value), str(exception_info.value) + assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table', 'tag' or 'resume'" in str(exception_info.value), str( + exception_info.value + ) @pytest.mark.p2 def test_chunk_method_unset(self, client): diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py index f29ce98892..a0dfd5cdf5 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py @@ -27,6 +27,7 @@ from utils.hypothesis_utils import valid_names from configs import DEFAULT_PARSER_CONFIG from utils.engine_utils import get_doc_engine + class TestRquest: @pytest.mark.p2 def test_payload_empty(self, add_dataset_func): @@ -317,14 +318,18 @@ class TestDatasetUpdate: dataset = add_dataset_func with pytest.raises(Exception) as exception_info: dataset.update({"chunk_method": chunk_method}) - assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table', 'tag' or 'resume'" in str(exception_info.value), str(exception_info.value) + assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table', 'tag' or 'resume'" in str(exception_info.value), str( + exception_info.value + ) @pytest.mark.p3 def test_chunk_method_none(self, add_dataset_func): dataset = add_dataset_func with pytest.raises(Exception) as exception_info: dataset.update({"chunk_method": None}) - assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table', 'tag' or 'resume'" in str(exception_info.value), str(exception_info.value) + assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table', 'tag' or 'resume'" in str(exception_info.value), str( + exception_info.value + ) @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") @pytest.mark.p2 diff --git a/test/testcases/test_sdk_api/test_file_management_within_dataset/conftest.py b/test/testcases/test_sdk_api/test_file_management_within_dataset/conftest.py index 58d8a7c625..b60f5f2886 100644 --- a/test/testcases/test_sdk_api/test_file_management_within_dataset/conftest.py +++ b/test/testcases/test_sdk_api/test_file_management_within_dataset/conftest.py @@ -37,7 +37,7 @@ def add_document_func(request: FixtureRequest, add_dataset: DataSet, ragflow_tmp def add_documents(request: FixtureRequest, add_dataset: DataSet, ragflow_tmp_dir) -> tuple[DataSet, list[Document]]: dataset = add_dataset documents = bulk_upload_documents(dataset, 5, ragflow_tmp_dir) - + def cleanup(): delete_all_documents(dataset) diff --git a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_list_documents.py b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_list_documents.py index a438512dc0..e77912628c 100644 --- a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_list_documents.py +++ b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_list_documents.py @@ -151,7 +151,6 @@ class TestDocumentsList: documents = dataset.list_documents(**params) assert len(documents) == expected_num, str(documents) - @pytest.mark.p1 @pytest.mark.parametrize( "params, expected_num, expected_message", @@ -223,7 +222,6 @@ class TestDocumentsList: documents = dataset.list_documents(**params) assert len(documents) == expected_num, str(documents) - @pytest.mark.p3 def test_concurrent_list(self, add_documents): dataset, _ = add_documents diff --git a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_update_document.py b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_update_document.py index 2b02c0b19c..f29be2c384 100644 --- a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_update_document.py +++ b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_update_document.py @@ -17,7 +17,8 @@ import pytest from configs import DOCUMENT_NAME_LIMIT from ragflow_sdk import DataSet -from configs import DEFAULT_PARSER_CONFIG +from configs import DEFAULT_PARSER_CONFIG + class TestDocumentsUpdated: @pytest.mark.p1 @@ -317,6 +318,7 @@ DEFAULT_PARSER_CONFIG_FOR_TEST = { }, } + class TestUpdateDocumentParserConfig: @pytest.mark.p2 @pytest.mark.parametrize( @@ -395,12 +397,20 @@ class TestUpdateDocumentParserConfig: {"task_page_size": "1024"}, "Input should be a valid integer", ), - ("naive", {"raptor": {"use_raptor": True, - "prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.", - "max_token": 256, - "threshold": 0.1, - "max_cluster": 64, - "random_seed": 0,}}, ""), + ( + "naive", + { + "raptor": { + "use_raptor": True, + "prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.", + "max_token": 256, + "threshold": 0.1, + "max_cluster": 64, + "random_seed": 0, + } + }, + "", + ), ("naive", {"raptor": {"use_raptor": False}}, ""), ( "naive", diff --git a/test/testcases/test_sdk_api/test_memory_management/conftest.py b/test/testcases/test_sdk_api/test_memory_management/conftest.py index 7027d541e6..992ef74d10 100644 --- a/test/testcases/test_sdk_api/test_memory_management/conftest.py +++ b/test/testcases/test_sdk_api/test_memory_management/conftest.py @@ -16,6 +16,7 @@ import pytest import random + @pytest.fixture(scope="class") def add_memory_func(client, request): def cleanup(): @@ -32,7 +33,7 @@ def add_memory_func(client, request): "name": f"test_memory_{i}", "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)), "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" + "llm_id": "glm-4-flash@ZHIPU-AI", } res = client.create_memory(**payload) memory_ids.append(res.id) diff --git a/test/testcases/test_sdk_api/test_memory_management/test_create_memory.py b/test/testcases/test_sdk_api/test_memory_management/test_create_memory.py index 0e90b1fb9d..7dc1877b8f 100644 --- a/test/testcases/test_sdk_api/test_memory_management/test_create_memory.py +++ b/test/testcases/test_sdk_api/test_memory_management/test_create_memory.py @@ -31,7 +31,7 @@ class TestAuthorization: (None, ""), (INVALID_API_TOKEN, ""), ], - ids=["empty_auth", "invalid_api_token"] + ids=["empty_auth", "invalid_api_token"], ) def test_auth_invalid(self, invalid_auth, expected_message): client = RAGFlow(invalid_auth, HOST_ADDRESS) @@ -51,10 +51,10 @@ class TestMemoryCreate: "name": name, "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)), "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" + "llm_id": "glm-4-flash@ZHIPU-AI", } memory = client.create_memory(**payload) - pattern = rf'^{name}|{name}(?:\((\d+)\))?$' + pattern = rf"^{name}|{name}(?:\((\d+)\))?$" escaped_name = re.escape(memory.name) assert re.match(pattern, escaped_name), str(memory) @@ -64,7 +64,7 @@ class TestMemoryCreate: [ ("", "Memory name cannot be empty or whitespace."), (" ", "Memory name cannot be empty or whitespace."), - ("a" * 129, f"Memory name '{'a'*129}' exceeds limit of 128."), + ("a" * 129, f"Memory name '{'a' * 129}' exceeds limit of 128."), ], ids=["empty_name", "space_name", "too_long_name"], ) @@ -73,7 +73,7 @@ class TestMemoryCreate: "name": name, "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)), "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" + "llm_id": "glm-4-flash@ZHIPU-AI", } with pytest.raises(Exception) as exception_info: client.create_memory(**payload) @@ -83,12 +83,7 @@ class TestMemoryCreate: @given(name=valid_names()) @settings(deadline=None) def test_type_invalid(self, client, name): - payload = { - "name": name, - "memory_type": ["something"], - "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" - } + payload = {"name": name, "memory_type": ["something"], "embd_id": "BAAI/bge-small-en-v1.5@Builtin", "llm_id": "glm-4-flash@ZHIPU-AI"} with pytest.raises(Exception) as exception_info: client.create_memory(**payload) assert str(exception_info.value) == f"Memory type '{ {'something'} }' is not supported.", str(exception_info.value) @@ -100,7 +95,7 @@ class TestMemoryCreate: "name": name, "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)), "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" + "llm_id": "glm-4-flash@ZHIPU-AI", } res1 = client.create_memory(**payload) assert res1.name == name, str(res1) diff --git a/test/testcases/test_sdk_api/test_memory_management/test_list_memory.py b/test/testcases/test_sdk_api/test_memory_management/test_list_memory.py index 774cb59ccc..1131b9d435 100644 --- a/test/testcases/test_sdk_api/test_memory_management/test_list_memory.py +++ b/test/testcases/test_sdk_api/test_memory_management/test_list_memory.py @@ -19,6 +19,7 @@ import pytest from ragflow_sdk import RAGFlow from configs import INVALID_API_TOKEN, HOST_ADDRESS + class TestAuthorization: @pytest.mark.p2 @pytest.mark.parametrize( @@ -45,11 +46,12 @@ class TestCapability: assert len(responses) == count, responses assert all(future.result()["code"] == 0 for future in futures) + @pytest.mark.usefixtures("add_memory_func") class TestMemoryList: @pytest.mark.p2 def test_params_unset(self, client): - res = client.list_memory() + res = client.list_memory() assert len(res["memory_list"]) == 3, str(res) assert res["total_count"] == 3, str(res) @@ -69,8 +71,7 @@ class TestMemoryList: ({"page": 2, "page_size": 2}, 1), ({"page": 5, "page_size": 10}, 0), ], - ids=["normal_first_page", "beyond_max_page", "normal_last_partial_page" , "normal_middle_page", - "full_data_single_page"], + ids=["normal_first_page", "beyond_max_page", "normal_last_partial_page", "normal_middle_page", "full_data_single_page"], ) def test_page(self, client, params, expected_page_size): # have added 3 memories in fixture @@ -110,9 +111,23 @@ class TestMemoryList: memory_id = memory.id memory_config = memory.get_config() assert memory_config.id == memory_id, memory_config - for field in ["name", "avatar", "tenant_id", "owner_name", "memory_type", "storage_type", - "embd_id", "llm_id", "permissions", "description", "memory_size", "forgetting_policy", - "temperature", "system_prompt", "user_prompt"]: + for field in [ + "name", + "avatar", + "tenant_id", + "owner_name", + "memory_type", + "storage_type", + "embd_id", + "llm_id", + "permissions", + "description", + "memory_size", + "forgetting_policy", + "temperature", + "system_prompt", + "user_prompt", + ]: assert hasattr(memory, field), memory_config @pytest.mark.p2 diff --git a/test/testcases/test_sdk_api/test_memory_management/test_rm_memory.py b/test/testcases/test_sdk_api/test_memory_management/test_rm_memory.py index 45c8089149..e5bb48f5ff 100644 --- a/test/testcases/test_sdk_api/test_memory_management/test_rm_memory.py +++ b/test/testcases/test_sdk_api/test_memory_management/test_rm_memory.py @@ -17,6 +17,7 @@ import pytest from ragflow_sdk import RAGFlow from configs import INVALID_API_TOKEN, HOST_ADDRESS + class TestAuthorization: @pytest.mark.p2 @pytest.mark.parametrize( diff --git a/test/testcases/test_sdk_api/test_memory_management/test_update_memory.py b/test/testcases/test_sdk_api/test_memory_management/test_update_memory.py index 5e5b0eae6c..495d5b1acf 100644 --- a/test/testcases/test_sdk_api/test_memory_management/test_update_memory.py +++ b/test/testcases/test_sdk_api/test_memory_management/test_update_memory.py @@ -31,7 +31,7 @@ class TestAuthorization: (None, ""), (INVALID_API_TOKEN, ""), ], - ids=["empty_auth", "invalid_api_token"] + ids=["empty_auth", "invalid_api_token"], ) def test_auth_invalid(self, invalid_auth, expected_message): @@ -41,9 +41,9 @@ class TestAuthorization: memory.update({"name": "New_Name"}) assert str(exception_info.value) == expected_message, str(exception_info.value) + @pytest.mark.usefixtures("add_memory_func") class TestMemoryUpdate: - @pytest.mark.p1 @given(name=valid_names()) @example("f" * 128) @@ -62,7 +62,7 @@ class TestMemoryUpdate: ("", "Memory name cannot be empty or whitespace."), (" ", "Memory name cannot be empty or whitespace."), ("a" * 129, f"Memory name '{'a' * 129}' exceeds limit of 128."), - ] + ], ) def test_name_invalid(self, client, name, expected_message): memory_ids = self.memory_ids @@ -112,14 +112,7 @@ class TestMemoryUpdate: assert res.llm_id == llm_id, str(res) @pytest.mark.p2 - @pytest.mark.parametrize( - "permission", - [ - "me", - "team" - ], - ids=["me", "team"] - ) + @pytest.mark.parametrize("permission", ["me", "team"], ids=["me", "team"]) def test_permission(self, client, permission): memory_ids = self.memory_ids update_dict = {"permissions": permission} diff --git a/test/testcases/test_sdk_api/test_message_management/conftest.py b/test/testcases/test_sdk_api/test_message_management/conftest.py index a93dd6fdf7..ddaa7b7cac 100644 --- a/test/testcases/test_sdk_api/test_message_management/conftest.py +++ b/test/testcases/test_sdk_api/test_message_management/conftest.py @@ -27,13 +27,9 @@ def add_empty_raw_type_memory(client, request): exist_memory_ids = [memory.id for memory in memory_list_res["memory_list"]] for _memory_id in exist_memory_ids: client.delete_memory(_memory_id) + request.addfinalizer(cleanup) - payload = { - "name": "test_memory_0", - "memory_type": ["raw"], - "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" - } + payload = {"name": "test_memory_0", "memory_type": ["raw"], "embd_id": "BAAI/bge-small-en-v1.5@Builtin", "llm_id": "glm-4-flash@ZHIPU-AI"} res = client.create_memory(**payload) memory_id = res.id request.cls.memory_id = memory_id @@ -48,12 +44,13 @@ def add_empty_multiple_type_memory(client, request): exist_memory_ids = [memory.id for memory in memory_list_res["memory_list"]] for _memory_id in exist_memory_ids: client.delete_memory(_memory_id) + request.addfinalizer(cleanup) payload = { "name": "test_memory_0", "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(1, 3)), "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" + "llm_id": "glm-4-flash@ZHIPU-AI", } res = client.create_memory(**payload) memory_id = res.id @@ -77,7 +74,7 @@ def add_2_multiple_type_memory(client, request): "name": f"test_memory_{i}", "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(1, 3)), "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" + "llm_id": "glm-4-flash@ZHIPU-AI", } res = client.create_memory(**payload) memory_ids.append(res.id) @@ -99,7 +96,7 @@ def add_memory_with_multiple_type_message_func(client, request): "name": "test_memory_0", "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(1, 3)), "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" + "llm_id": "glm-4-flash@ZHIPU-AI", } memory = client.create_memory(**payload) memory_id = memory.id @@ -115,7 +112,7 @@ Coriander is a versatile herb with two main edible parts, and its name can refer 1. Leaves and Stems (often called Cilantro or Fresh Coriander): These are the fresh, green, fragrant leaves and tender stems of the plant Coriandrum sativum. They have a bright, citrusy, and sometimes pungent flavor. Cilantro is widely used as a garnish or key ingredient in cuisines like Mexican, Indian, Thai, and Middle Eastern. 2. Seeds (called Coriander Seeds): These are the dried, golden-brown seeds of the same plant. When ground, they become coriander powder. The seeds have a warm, nutty, floral, and slightly citrusy taste, completely different from the fresh leaves. They are a fundamental spice in curries, stews, pickles, and baking. Key Point of Confusion: The naming differs by region. In North America, "coriander" typically refers to the seeds, while "cilantro" refers to the fresh leaves. In the UK, Europe, and many other parts of the world, "coriander" refers to the fresh herb, and the seeds are called "coriander seeds." -""" +""", } client.add_message(**message_payload) request.cls.memory_id = memory_id @@ -134,12 +131,7 @@ def add_memory_with_5_raw_message_func(client, request): request.addfinalizer(cleanup) - payload = { - "name": "test_memory_1", - "memory_type": ["raw"], - "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" - } + payload = {"name": "test_memory_1", "memory_type": ["raw"], "embd_id": "BAAI/bge-small-en-v1.5@Builtin", "llm_id": "glm-4-flash@ZHIPU-AI"} memory = client.create_memory(**payload) memory_id = memory.id agent_ids = [uuid.uuid4().hex for _ in range(2)] @@ -156,11 +148,11 @@ Coriander is a versatile herb with two main edible parts, and its name can refer 1. Leaves and Stems (often called Cilantro or Fresh Coriander): These are the fresh, green, fragrant leaves and tender stems of the plant Coriandrum sativum. They have a bright, citrusy, and sometimes pungent flavor. Cilantro is widely used as a garnish or key ingredient in cuisines like Mexican, Indian, Thai, and Middle Eastern. 2. Seeds (called Coriander Seeds): These are the dried, golden-brown seeds of the same plant. When ground, they become coriander powder. The seeds have a warm, nutty, floral, and slightly citrusy taste, completely different from the fresh leaves. They are a fundamental spice in curries, stews, pickles, and baking. Key Point of Confusion: The naming differs by region. In North America, "coriander" typically refers to the seeds, while "cilantro" refers to the fresh leaves. In the UK, Europe, and many other parts of the world, "coriander" refers to the fresh herb, and the seeds are called "coriander seeds." -""" +""", } client.add_message(**message_payload) request.cls.memory_id = memory_id request.cls.agent_ids = agent_ids request.cls.session_ids = session_ids - time.sleep(2) # make sure refresh to index before search + time.sleep(2) # make sure refresh to index before search return memory_id diff --git a/test/testcases/test_sdk_api/test_message_management/test_add_message.py b/test/testcases/test_sdk_api/test_message_management/test_add_message.py index 44a374bcae..c460710277 100644 --- a/test/testcases/test_sdk_api/test_message_management/test_add_message.py +++ b/test/testcases/test_sdk_api/test_message_management/test_add_message.py @@ -19,6 +19,7 @@ import pytest from ragflow_sdk import RAGFlow, Memory from configs import INVALID_API_TOKEN, HOST_ADDRESS + class TestAuthorization: @pytest.mark.p2 @pytest.mark.parametrize( @@ -31,20 +32,12 @@ class TestAuthorization: def test_auth_invalid(self, invalid_auth, expected_message): client = RAGFlow(invalid_auth, HOST_ADDRESS) with pytest.raises(Exception) as exception_info: - client.add_message(**{ - "memory_id": [""], - "agent_id": "", - "session_id": "", - "user_id": "", - "user_input": "what is pineapple?", - "agent_response": "" - }) + client.add_message(**{"memory_id": [""], "agent_id": "", "session_id": "", "user_id": "", "user_input": "what is pineapple?", "agent_response": ""}) assert str(exception_info.value) == expected_message, str(exception_info.value) @pytest.mark.usefixtures("add_empty_raw_type_memory") class TestAddRawMessage: - @pytest.mark.p1 def test_add_raw_message(self, client): memory_id = self.memory_id @@ -65,7 +58,7 @@ Uses: Pineapples are eaten fresh, cooked, grilled, juiced, or canned. They are a Nutrition: They are a good source of Vitamin C, manganese, and contain an enzyme called bromelain, which aids in digestion and can tenderize meat. Symbolism: The pineapple is a traditional symbol of hospitality and welcome in many cultures. Are you asking about the fruit itself, or its use in a specific context? -""" +""", } add_res = client.add_message(**message_payload) assert add_res == "All add to task.", str(add_res) @@ -80,7 +73,6 @@ Are you asking about the fruit itself, or its use in a specific context? @pytest.mark.usefixtures("add_empty_multiple_type_memory") class TestAddMultipleTypeMessage: - @pytest.mark.p1 def test_add_multiple_type_message(self, client): memory_id = self.memory_id @@ -101,7 +93,7 @@ Uses: Pineapples are eaten fresh, cooked, grilled, juiced, or canned. They are a Nutrition: They are a good source of Vitamin C, manganese, and contain an enzyme called bromelain, which aids in digestion and can tenderize meat. Symbolism: The pineapple is a traditional symbol of hospitality and welcome in many cultures. Are you asking about the fruit itself, or its use in a specific context? -""" +""", } add_res = client.add_message(**message_payload) assert add_res == "All add to task.", str(add_res) @@ -116,7 +108,6 @@ Are you asking about the fruit itself, or its use in a specific context? @pytest.mark.usefixtures("add_2_multiple_type_memory") class TestAddToMultipleMemory: - @pytest.mark.p1 def test_add_to_multiple_memory(self, client): memory_ids = self.memory_ids @@ -137,7 +128,7 @@ Uses: Pineapples are eaten fresh, cooked, grilled, juiced, or canned. They are a Nutrition: They are a good source of Vitamin C, manganese, and contain an enzyme called bromelain, which aids in digestion and can tenderize meat. Symbolism: The pineapple is a traditional symbol of hospitality and welcome in many cultures. Are you asking about the fruit itself, or its use in a specific context? -""" +""", } add_res = client.add_message(**message_payload) assert add_res == "All add to task.", str(add_res) diff --git a/test/testcases/test_sdk_api/test_message_management/test_forget_message.py b/test/testcases/test_sdk_api/test_message_management/test_forget_message.py index a2e3e50138..1447aa02aa 100644 --- a/test/testcases/test_sdk_api/test_message_management/test_forget_message.py +++ b/test/testcases/test_sdk_api/test_message_management/test_forget_message.py @@ -38,7 +38,6 @@ class TestAuthorization: @pytest.mark.usefixtures("add_memory_with_5_raw_message_func") class TestForgetMessage: - @pytest.mark.p1 def test_forget_message(self, client): memory_id = self.memory_id diff --git a/test/testcases/test_sdk_api/test_message_management/test_get_message_content.py b/test/testcases/test_sdk_api/test_message_management/test_get_message_content.py index 6631923176..df0d88606c 100644 --- a/test/testcases/test_sdk_api/test_message_management/test_get_message_content.py +++ b/test/testcases/test_sdk_api/test_message_management/test_get_message_content.py @@ -19,6 +19,7 @@ import pytest from ragflow_sdk import RAGFlow, Memory from configs import INVALID_API_TOKEN, HOST_ADDRESS + class TestAuthorization: @pytest.mark.p2 @pytest.mark.parametrize( @@ -38,9 +39,8 @@ class TestAuthorization: @pytest.mark.usefixtures("add_memory_with_multiple_type_message_func") class TestGetMessageContent: - @pytest.mark.p1 - def test_get_message_content(self,client): + def test_get_message_content(self, client): memory_id = self.memory_id recent_messages = client.get_recent_messages([memory_id]) assert len(recent_messages) > 0, recent_messages diff --git a/test/testcases/test_sdk_api/test_message_management/test_get_recent_message.py b/test/testcases/test_sdk_api/test_message_management/test_get_recent_message.py index 832b8b4978..a13e14a215 100644 --- a/test/testcases/test_sdk_api/test_message_management/test_get_recent_message.py +++ b/test/testcases/test_sdk_api/test_message_management/test_get_recent_message.py @@ -38,7 +38,6 @@ class TestAuthorization: @pytest.mark.usefixtures("add_memory_with_5_raw_message_func") class TestGetRecentMessage: - @pytest.mark.p1 def test_get_recent_messages(self, client): memory_id = self.memory_id diff --git a/test/testcases/test_sdk_api/test_message_management/test_list_message.py b/test/testcases/test_sdk_api/test_message_management/test_list_message.py index fc7578353d..70766b458f 100644 --- a/test/testcases/test_sdk_api/test_message_management/test_list_message.py +++ b/test/testcases/test_sdk_api/test_message_management/test_list_message.py @@ -41,7 +41,6 @@ class TestAuthorization: @pytest.mark.usefixtures("add_memory_with_5_raw_message_func") class TestMessageList: - @pytest.mark.p2 def test_params_unset(self, client): memory_id = self.memory_id @@ -66,8 +65,7 @@ class TestMessageList: ({"page": 3, "page_size": 2}, 1), ({"page": 5, "page_size": 10}, 0), ], - ids=["normal_first_page", "beyond_max_page", "normal_last_partial_page", "normal_middle_page", - "full_data_single_page"], + ids=["normal_first_page", "beyond_max_page", "normal_last_partial_page", "normal_middle_page", "full_data_single_page"], ) def test_page_size(self, client, params, expected_page_size): # have added 5 messages in fixture diff --git a/test/testcases/test_sdk_api/test_message_management/test_search_message.py b/test/testcases/test_sdk_api/test_message_management/test_search_message.py index 4e0329d1b7..3fe50135cd 100644 --- a/test/testcases/test_sdk_api/test_message_management/test_search_message.py +++ b/test/testcases/test_sdk_api/test_message_management/test_search_message.py @@ -17,6 +17,7 @@ import pytest from ragflow_sdk import RAGFlow, Memory from configs import INVALID_API_TOKEN, HOST_ADDRESS + class TestAuthorization: @pytest.mark.p2 @pytest.mark.parametrize( @@ -35,7 +36,6 @@ class TestAuthorization: @pytest.mark.usefixtures("add_memory_with_multiple_type_message_func") class TestSearchMessage: - @pytest.mark.p1 def test_query(self, client): memory_id = self.memory_id @@ -69,11 +69,7 @@ class TestSearchMessage: assert list_res["messages"]["total_count"] > 0 query = "Coriander is a versatile herb with two main edible parts. What's its name can refer to?" - params = { - "similarity_threshold": 0.1, - "keywords_similarity_weight": 0.6, - "top_n": 4 - } + params = {"similarity_threshold": 0.1, "keywords_similarity_weight": 0.6, "top_n": 4} res = client.search_message(**{"memory_id": [memory_id], "query": query, **params}) assert len(res) > 0 assert len(res) <= params["top_n"] diff --git a/test/testcases/test_sdk_api/test_message_management/test_update_message_status.py b/test/testcases/test_sdk_api/test_message_management/test_update_message_status.py index d58699b9b5..c2455b139e 100644 --- a/test/testcases/test_sdk_api/test_message_management/test_update_message_status.py +++ b/test/testcases/test_sdk_api/test_message_management/test_update_message_status.py @@ -39,7 +39,6 @@ class TestAuthorization: @pytest.mark.usefixtures("add_memory_with_5_raw_message_func") class TestUpdateMessageStatus: - @pytest.mark.p1 def test_update_to_false(self, client): memory_id = self.memory_id diff --git a/test/testcases/test_sdk_api/test_session_management/conftest.py b/test/testcases/test_sdk_api/test_session_management/conftest.py index 7361b34849..351dfcf4ad 100644 --- a/test/testcases/test_sdk_api/test_session_management/conftest.py +++ b/test/testcases/test_sdk_api/test_session_management/conftest.py @@ -25,7 +25,7 @@ def add_sessions_with_chat_assistant(request: FixtureRequest, add_chat_assistant for chat_assistant in chat_assistants: try: delete_all_sessions(chat_assistant) - except Exception : + except Exception: pass request.addfinalizer(cleanup) @@ -40,7 +40,7 @@ def add_sessions_with_chat_assistant_func(request: FixtureRequest, add_chat_assi for chat_assistant in chat_assistants: try: delete_all_sessions(chat_assistant) - except Exception : + except Exception: pass request.addfinalizer(cleanup) diff --git a/test/testcases/test_sdk_api/test_session_management/test_create_session_with_chat_assistant.py b/test/testcases/test_sdk_api/test_session_management/test_create_session_with_chat_assistant.py index 54300ff012..7d58629f8c 100644 --- a/test/testcases/test_sdk_api/test_session_management/test_create_session_with_chat_assistant.py +++ b/test/testcases/test_sdk_api/test_session_management/test_create_session_with_chat_assistant.py @@ -32,7 +32,6 @@ class _DummyStreamResponse: yield line - @pytest.mark.usefixtures("clear_session_with_chat_assistants") class TestSessionWithChatAssistantCreate: @pytest.mark.p1 @@ -128,9 +127,7 @@ def test_session_module_streaming_and_helper_paths_unit(monkeypatch): monkeypatch.setattr( chat_done_session, "post", - lambda *_args, **_kwargs: _DummyStreamResponse( - ['{"data":{"answer":"chat-done","reference":{"chunks":[]}}}', "data: [DONE]"] - ), + lambda *_args, **_kwargs: _DummyStreamResponse(['{"data":{"answer":"chat-done","reference":{"chunks":[]}}}', "data: [DONE]"]), ) monkeypatch.setattr(agent_session, "post", _agent_post) diff --git a/test/testcases/test_web_api/conftest.py b/test/testcases/test_web_api/conftest.py index ab55cf7205..4e20efc192 100644 --- a/test/testcases/test_web_api/conftest.py +++ b/test/testcases/test_web_api/conftest.py @@ -83,10 +83,13 @@ def generate_test_files(request: FixtureRequest, tmp_path): def ragflow_tmp_dir(request, tmp_path_factory): class_name = request.cls.__name__ return tmp_path_factory.mktemp(class_name) + + @pytest.fixture(scope="session") def client(token: str) -> RAGFlow: return RAGFlow(api_key=token, base_url=HOST_ADDRESS, version=VERSION) + @pytest.fixture(scope="session") def WebApiAuth(auth): return RAGFlowWebApiAuth(auth) diff --git a/test/testcases/test_web_api/test_agent_app/test_agents_webhook_unit.py b/test/testcases/test_web_api/test_agent_app/test_agents_webhook_unit.py index b3eaba997e..66fc10a492 100644 --- a/test/testcases/test_web_api/test_agent_app/test_agents_webhook_unit.py +++ b/test/testcases/test_web_api/test_agent_app/test_agents_webhook_unit.py @@ -439,6 +439,7 @@ def _load_agents_app(monkeypatch, *, target="rest"): @classmethod def normalize_dsl(cls, dsl): import json + if isinstance(dsl, str): return json.loads(dsl) return dsl @@ -1516,11 +1517,15 @@ def test_webhook_trace_encoded_id_generation(monkeypatch): res = _run(module.webhook_trace("agent-1")) assert res["code"] == module.RetCode.SUCCESS - expected = base64.urlsafe_b64encode( - hmac.new( - b"webhook_id_secret", - b"101.0", - hashlib.sha256, - ).digest() - ).decode("utf-8").rstrip("=") + expected = ( + base64.urlsafe_b64encode( + hmac.new( + b"webhook_id_secret", + b"101.0", + hashlib.sha256, + ).digest() + ) + .decode("utf-8") + .rstrip("=") + ) assert res["data"]["webhook_id"] == expected diff --git a/test/testcases/test_web_api/test_auth_app/test_oauth_client_unit.py b/test/testcases/test_web_api/test_auth_app/test_oauth_client_unit.py index 90f089a908..e65fad17ec 100644 --- a/test/testcases/test_web_api/test_auth_app/test_oauth_client_unit.py +++ b/test/testcases/test_web_api/test_auth_app/test_oauth_client_unit.py @@ -153,9 +153,7 @@ def test_oauth_client_sync_matrix_unit(monkeypatch): assert call_log[1][0] == "GET" assert call_log[1][3]["Authorization"] == "Bearer access-1" - normalized = client.normalize_user_info( - {"email": "fallback@example.com", "username": "fallback-user", "nickname": "fallback-nick", "avatar_url": "direct-avatar"} - ) + normalized = client.normalize_user_info({"email": "fallback@example.com", "username": "fallback-user", "nickname": "fallback-nick", "avatar_url": "direct-avatar"}) assert normalized.to_dict()["avatar_url"] == "direct-avatar" monkeypatch.setattr(oauth_module, "sync_request", lambda *_args, **_kwargs: _FakeResponse(err=RuntimeError("status boom"))) diff --git a/test/testcases/test_web_api/test_auth_app/test_oidc_client_unit.py b/test/testcases/test_web_api/test_auth_app/test_oidc_client_unit.py index 7f48a3b95e..cd6def73de 100644 --- a/test/testcases/test_web_api/test_auth_app/test_oidc_client_unit.py +++ b/test/testcases/test_web_api/test_auth_app/test_oidc_client_unit.py @@ -334,9 +334,7 @@ def test_parse_id_token_passes_pinned_algorithms_to_jwt_decode(monkeypatch): # verification path must not consult it. We sabotage # ``jwt.get_unverified_header`` to prove the code never calls it. def _explode(_token): # pragma: no cover - must not be called - raise AssertionError( - "parse_id_token must not read the algorithm from the unverified JWT header" - ) + raise AssertionError("parse_id_token must not read the algorithm from the unverified JWT header") monkeypatch.setattr(oidc_module.jwt, "get_unverified_header", _explode) monkeypatch.setattr(oidc_module.jwt, "PyJWKClient", _DummyJwkClient) @@ -650,9 +648,7 @@ def test_github_fetch_user_info_async_success_and_error_unit(monkeypatch): {"email": "octo-async@example.com", "primary": True}, ] ) - return _FakeResponse( - {"login": "octocat-async", "name": "Octo Async", "avatar_url": "https://avatar.example/octo-async.png"} - ) + return _FakeResponse({"login": "octocat-async", "name": "Octo Async", "avatar_url": "https://avatar.example/octo-async.png"}) monkeypatch.setattr(github_module, "async_request", _fake_async_request) info = asyncio.run(client.async_fetch_user_info("async-token")) diff --git a/test/testcases/test_web_api/test_canvas_app/test_code_exec_contract_unit.py b/test/testcases/test_web_api/test_canvas_app/test_code_exec_contract_unit.py index 1992105474..d783196768 100644 --- a/test/testcases/test_web_api/test_canvas_app/test_code_exec_contract_unit.py +++ b/test/testcases/test_web_api/test_canvas_app/test_code_exec_contract_unit.py @@ -23,11 +23,7 @@ from unittest.mock import patch import pytest -CODE_EXEC_MODULE_PATH = next( - parent / "agent" / "tools" / "code_exec.py" - for parent in Path(__file__).resolve().parents - if (parent / "agent" / "tools" / "code_exec.py").exists() -) +CODE_EXEC_MODULE_PATH = next(parent / "agent" / "tools" / "code_exec.py" for parent in Path(__file__).resolve().parents if (parent / "agent" / "tools" / "code_exec.py").exists()) def _load_module(): @@ -274,15 +270,9 @@ def test_malformed_array_schema_is_rejected(schema): def test_any_and_empty_expected_type_skip_validation(): module = _load_module() - assert module.build_code_exec_contract({"result": {"value": None, "type": "Any"}}, {"foo": "bar"})["value"] == { - "foo": "bar" - } - assert module.build_code_exec_contract({"result": {"value": None, "type": ""}}, {"foo": "bar"})["value"] == { - "foo": "bar" - } - assert module.build_code_exec_contract({"result": {"value": None, "type": None}}, {"foo": "bar"})["value"] == { - "foo": "bar" - } + assert module.build_code_exec_contract({"result": {"value": None, "type": "Any"}}, {"foo": "bar"})["value"] == {"foo": "bar"} + assert module.build_code_exec_contract({"result": {"value": None, "type": ""}}, {"foo": "bar"})["value"] == {"foo": "bar"} + assert module.build_code_exec_contract({"result": {"value": None, "type": None}}, {"foo": "bar"})["value"] == {"foo": "bar"} def test_legacy_multi_output_schema_is_rejected(): diff --git a/test/testcases/test_web_api/test_canvas_app/test_fillup_unit.py b/test/testcases/test_web_api/test_canvas_app/test_fillup_unit.py index 4d6f8e43aa..1bec402f38 100644 --- a/test/testcases/test_web_api/test_canvas_app/test_fillup_unit.py +++ b/test/testcases/test_web_api/test_canvas_app/test_fillup_unit.py @@ -83,9 +83,7 @@ def _load_fillup_module(monkeypatch): monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service_mod) module_path = repo_root / "agent" / "component" / "fillup.py" - spec = importlib.util.spec_from_file_location( - "test_fillup_unit_module", module_path - ) + spec = importlib.util.spec_from_file_location("test_fillup_unit_module", module_path) module = importlib.util.module_from_spec(spec) monkeypatch.setitem(sys.modules, "test_fillup_unit_module", module) spec.loader.exec_module(module) diff --git a/test/testcases/test_web_api/test_canvas_app/test_iteration_runtime_unit.py b/test/testcases/test_web_api/test_canvas_app/test_iteration_runtime_unit.py index 6f1f159e43..74796312ae 100644 --- a/test/testcases/test_web_api/test_canvas_app/test_iteration_runtime_unit.py +++ b/test/testcases/test_web_api/test_canvas_app/test_iteration_runtime_unit.py @@ -136,16 +136,12 @@ def _load_canvas_runtime(monkeypatch): component_pkg.__path__ = [str(repo_root / "agent" / "component")] monkeypatch.setitem(sys.modules, "agent.component", component_pkg) - base_spec = importlib.util.spec_from_file_location( - "agent.component.base", repo_root / "agent" / "component" / "base.py" - ) + base_spec = importlib.util.spec_from_file_location("agent.component.base", repo_root / "agent" / "component" / "base.py") base_mod = importlib.util.module_from_spec(base_spec) monkeypatch.setitem(sys.modules, "agent.component.base", base_mod) base_spec.loader.exec_module(base_mod) - iteration_spec = importlib.util.spec_from_file_location( - "agent.component.iteration", repo_root / "agent" / "component" / "iteration.py" - ) + iteration_spec = importlib.util.spec_from_file_location("agent.component.iteration", repo_root / "agent" / "component" / "iteration.py") iteration_mod = importlib.util.module_from_spec(iteration_spec) monkeypatch.setitem(sys.modules, "agent.component.iteration", iteration_mod) iteration_spec.loader.exec_module(iteration_mod) @@ -189,9 +185,7 @@ def _load_canvas_runtime(monkeypatch): def _invoke(self, **kwargs): query_text = kwargs.get("query") vars_map = self.get_input_elements_from_text(query_text) - query = self.string_format( - query_text, {key: value["value"] for key, value in vars_map.items()} - ) + query = self.string_format(query_text, {key: value["value"] for key, value in vars_map.items()}) calls = self._canvas.globals.setdefault("probe.calls", []) calls.append(query) self.set_output("result", query) @@ -279,9 +273,7 @@ def _load_canvas_runtime(monkeypatch): component_pkg.component_class = lambda name: class_map[name] - canvas_spec = importlib.util.spec_from_file_location( - "agent.canvas", repo_root / "agent" / "canvas.py" - ) + canvas_spec = importlib.util.spec_from_file_location("agent.canvas", repo_root / "agent" / "canvas.py") canvas_mod = importlib.util.module_from_spec(canvas_spec) monkeypatch.setitem(sys.modules, "agent.canvas", canvas_mod) canvas_spec.loader.exec_module(canvas_mod) @@ -505,9 +497,7 @@ def test_canvas_resume_does_not_emit_duplicate_workflow_started(monkeypatch): "user_inputs", ] - resumed_events = asyncio.run( - _collect_events(canvas.run(query="hello", inputs={"value": {"value": "hello"}})) - ) + resumed_events = asyncio.run(_collect_events(canvas.run(query="hello", inputs={"value": {"value": "hello"}}))) resumed_kinds = [event["event"] for event in resumed_events] assert resumed_kinds[0] == "node_started" assert "workflow_started" not in resumed_kinds diff --git a/test/testcases/test_web_api/test_canvas_app/test_iterationitem_unit.py b/test/testcases/test_web_api/test_canvas_app/test_iterationitem_unit.py index 1151bb60dc..dd732ca344 100644 --- a/test/testcases/test_web_api/test_canvas_app/test_iterationitem_unit.py +++ b/test/testcases/test_web_api/test_canvas_app/test_iterationitem_unit.py @@ -44,16 +44,12 @@ def _load_iterationitem_module(monkeypatch): constants.RetCode = _RetCode monkeypatch.setitem(sys.modules, "common.constants", constants) - conn_spec = importlib.util.spec_from_file_location( - "common.connection_utils", repo_root / "common" / "connection_utils.py" - ) + conn_spec = importlib.util.spec_from_file_location("common.connection_utils", repo_root / "common" / "connection_utils.py") conn_mod = importlib.util.module_from_spec(conn_spec) monkeypatch.setitem(sys.modules, "common.connection_utils", conn_mod) conn_spec.loader.exec_module(conn_mod) - misc_spec = importlib.util.spec_from_file_location( - "common.misc_utils", repo_root / "common" / "misc_utils.py" - ) + misc_spec = importlib.util.spec_from_file_location("common.misc_utils", repo_root / "common" / "misc_utils.py") misc_mod = importlib.util.module_from_spec(misc_spec) monkeypatch.setitem(sys.modules, "common.misc_utils", misc_mod) misc_spec.loader.exec_module(misc_mod) @@ -79,9 +75,7 @@ def _load_iterationitem_module(monkeypatch): canvas_mod.Graph = Graph monkeypatch.setitem(sys.modules, "agent.canvas", canvas_mod) - base_spec = importlib.util.spec_from_file_location( - "agent.component.base", repo_root / "agent" / "component" / "base.py" - ) + base_spec = importlib.util.spec_from_file_location("agent.component.base", repo_root / "agent" / "component" / "base.py") base_mod = importlib.util.module_from_spec(base_spec) monkeypatch.setitem(sys.modules, "agent.component.base", base_mod) base_spec.loader.exec_module(base_mod) @@ -91,9 +85,7 @@ def _load_iterationitem_module(monkeypatch): repo_root / "agent" / "component" / "iterationitem.py", ) iterationitem_mod = importlib.util.module_from_spec(iterationitem_spec) - monkeypatch.setitem( - sys.modules, "agent.component.iterationitem", iterationitem_mod - ) + monkeypatch.setitem(sys.modules, "agent.component.iterationitem", iterationitem_mod) iterationitem_spec.loader.exec_module(iterationitem_mod) return iterationitem_mod diff --git a/test/testcases/test_web_api/test_canvas_app/test_list_operations_unit.py b/test/testcases/test_web_api/test_canvas_app/test_list_operations_unit.py index 35c89bc8a7..1f25098a93 100644 --- a/test/testcases/test_web_api/test_canvas_app/test_list_operations_unit.py +++ b/test/testcases/test_web_api/test_canvas_app/test_list_operations_unit.py @@ -58,13 +58,11 @@ def _load_list_operations_module(monkeypatch): monkeypatch.setitem(sys.modules, "api", api_pkg) api_utils_mod = ModuleType("api.utils.api_utils") - api_utils_mod.timeout = lambda *_args, **_kwargs: (lambda func: func) + api_utils_mod.timeout = lambda *_args, **_kwargs: lambda func: func monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) module_path = repo_root / "agent" / "component" / "list_operations.py" - spec = importlib.util.spec_from_file_location( - "test_list_operations_unit_module", module_path - ) + spec = importlib.util.spec_from_file_location("test_list_operations_unit_module", module_path) module = importlib.util.module_from_spec(spec) monkeypatch.setitem(sys.modules, "test_list_operations_unit_module", module) spec.loader.exec_module(module) @@ -101,9 +99,7 @@ def _make_component(module, *, inputs, operation, n, strict=False): ) def test_nth_behaves_like_lenient_indexing(monkeypatch, n, expected): module = _load_list_operations_module(monkeypatch) - component = _make_component( - module, inputs=["a", "b", "c", "d", "e"], operation="nth", n=n - ) + component = _make_component(module, inputs=["a", "b", "c", "d", "e"], operation="nth", n=n) component._nth() assert component._param.outputs["result"]["value"] == expected @@ -120,9 +116,7 @@ def test_nth_behaves_like_lenient_indexing(monkeypatch, n, expected): ) def test_head_supports_lenient_and_strict(monkeypatch, strict, n, expected): module = _load_list_operations_module(monkeypatch) - component = _make_component( - module, inputs=["a", "b", "c", "d", "e"], operation="head", n=n, strict=strict - ) + component = _make_component(module, inputs=["a", "b", "c", "d", "e"], operation="head", n=n, strict=strict) component._head() assert component._param.outputs["result"]["value"] == expected @@ -131,9 +125,7 @@ def test_head_supports_lenient_and_strict(monkeypatch, strict, n, expected): @pytest.mark.parametrize("n", [0, 10]) def test_head_strict_raises_for_out_of_range(monkeypatch, n): module = _load_list_operations_module(monkeypatch) - component = _make_component( - module, inputs=["a", "b", "c", "d", "e"], operation="head", n=n, strict=True - ) + component = _make_component(module, inputs=["a", "b", "c", "d", "e"], operation="head", n=n, strict=True) with pytest.raises(ValueError, match="head requires n"): component._head() @@ -150,9 +142,7 @@ def test_head_strict_raises_for_out_of_range(monkeypatch, n): ) def test_tail_supports_lenient_and_strict(monkeypatch, strict, n, expected): module = _load_list_operations_module(monkeypatch) - component = _make_component( - module, inputs=["a", "b", "c", "d", "e"], operation="tail", n=n, strict=strict - ) + component = _make_component(module, inputs=["a", "b", "c", "d", "e"], operation="tail", n=n, strict=strict) component._tail() assert component._param.outputs["result"]["value"] == expected @@ -161,9 +151,7 @@ def test_tail_supports_lenient_and_strict(monkeypatch, strict, n, expected): @pytest.mark.parametrize("n", [0, 10]) def test_tail_strict_raises_for_out_of_range(monkeypatch, n): module = _load_list_operations_module(monkeypatch) - component = _make_component( - module, inputs=["a", "b", "c", "d", "e"], operation="tail", n=n, strict=True - ) + component = _make_component(module, inputs=["a", "b", "c", "d", "e"], operation="tail", n=n, strict=True) with pytest.raises(ValueError, match="tail requires n"): component._tail() @@ -172,9 +160,7 @@ def test_tail_strict_raises_for_out_of_range(monkeypatch, n): @pytest.mark.parametrize("n", [0, 6, -6]) def test_nth_strict_raises_for_out_of_range(monkeypatch, n): module = _load_list_operations_module(monkeypatch) - component = _make_component( - module, inputs=["a", "b", "c", "d", "e"], operation="nth", n=n, strict=True - ) + component = _make_component(module, inputs=["a", "b", "c", "d", "e"], operation="nth", n=n, strict=True) with pytest.raises(ValueError, match="nth requires n"): component._nth() @@ -182,9 +168,7 @@ def test_nth_strict_raises_for_out_of_range(monkeypatch, n): @pytest.mark.p2 def test_set_outputs_tracks_first_and_last(monkeypatch): module = _load_list_operations_module(monkeypatch) - component = _make_component( - module, inputs=["a", "b", "c", "d", "e"], operation="tail", n=3 - ) + component = _make_component(module, inputs=["a", "b", "c", "d", "e"], operation="tail", n=3) component._tail() assert component._param.outputs["result"]["value"] == ["c", "d", "e"] assert component._param.outputs["first"]["value"] == "c" diff --git a/test/testcases/test_web_api/test_chunk_app/test_chunk_routes_unit.py b/test/testcases/test_web_api/test_chunk_app/test_chunk_routes_unit.py index dd013838b0..38c0e3956c 100644 --- a/test/testcases/test_web_api/test_chunk_app/test_chunk_routes_unit.py +++ b/test/testcases/test_web_api/test_chunk_app/test_chunk_routes_unit.py @@ -289,7 +289,7 @@ def _load_chunk_module(monkeypatch): api_utils_mod.get_result = lambda data=None, message="", code=0: {"code": code, "message": message, "data": data} api_utils_mod.get_error_data_result = lambda message="": {"code": _DummyRetCode.DATA_ERROR, "message": message, "data": False} api_utils_mod.server_error_response = lambda exc: {"code": _DummyRetCode.EXCEPTION_ERROR, "message": repr(exc), "data": False} - api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn) + api_utils_mod.validate_request = lambda *_args, **_kwargs: lambda fn: fn api_utils_mod.add_tenant_id_to_kwargs = lambda func: func api_utils_mod.check_duplicate_ids = lambda ids, _kind: (list(dict.fromkeys(ids)), [] if len(ids) == len(set(ids)) else [f"Duplicate {_kind} ids"]) api_utils_mod.get_request_json = lambda: _AwaitableValue({}) @@ -386,12 +386,7 @@ def _load_chunk_module(monkeypatch): class _DummyLLMService: @staticmethod def query(**_kwargs): - return [SimpleNamespace( - llm_name="gpt-3.5-turbo", - model_type="chat", - max_tokens=8192, - is_tools=True - )] + return [SimpleNamespace(llm_name="gpt-3.5-turbo", model_type="chat", max_tokens=8192, is_tools=True)] llm_service_mod = ModuleType("api.db.services.llm_service") llm_service_mod.LLMService = _DummyLLMService @@ -427,22 +422,13 @@ def _load_chunk_module(monkeypatch): api_base="https://api.example.com", max_tokens=8192, used_tokens=0, - status=1 + status=1, ) @staticmethod def get_api_key(tenant_id, model_name): return _MockTableObject( - id=1, - tenant_id=tenant_id, - llm_factory="", - model_type="chat", - llm_name=model_name, - api_key="fake-api-key", - api_base="https://api.example.com", - max_tokens=8192, - used_tokens=0, - status=1 + id=1, tenant_id=tenant_id, llm_factory="", model_type="chat", llm_name=model_name, api_key="fake-api-key", api_base="https://api.example.com", max_tokens=8192, used_tokens=0, status=1 ) @staticmethod @@ -471,7 +457,7 @@ def _load_chunk_module(monkeypatch): asr_id="whisper-1", img2txt_id="gpt-4-vision-preview", rerank_id="bge-reranker", - tts_id="tts-1" + tts_id="tts-1", ) tenant_llm_service_mod.TenantLLMService = _TenantLLMService @@ -728,4 +714,3 @@ def test_restful_add_chunk_valid_image_base64_stores_before_insert(monkeypatch): inserted = module.settings.docStoreConn.inserted[-1] assert inserted.get("img_id"), inserted assert inserted.get("doc_type_kwd") == "image", inserted - diff --git a/test/testcases/test_web_api/test_chunk_app/test_create_chunk.py b/test/testcases/test_web_api/test_chunk_app/test_create_chunk.py index f9e6f76070..f796cac943 100644 --- a/test/testcases/test_web_api/test_chunk_app/test_create_chunk.py +++ b/test/testcases/test_web_api/test_chunk_app/test_create_chunk.py @@ -187,10 +187,7 @@ class TestAddChunk: chunks_count = list_chunks(WebApiAuth, dataset_id, document_id)["data"]["doc"]["chunk_count"] with ThreadPoolExecutor(max_workers=5) as executor: - futures = [ - executor.submit(add_chunk, WebApiAuth, dataset_id, document_id, {"content": f"chunk test {i}"}) - for i in range(count) - ] + futures = [executor.submit(add_chunk, WebApiAuth, dataset_id, document_id, {"content": f"chunk test {i}"}) for i in range(count)] responses = list(as_completed(futures)) assert len(responses) == count, responses assert all(future.result()["code"] == 0 for future in futures) diff --git a/test/testcases/test_web_api/test_chunk_app/test_rm_chunks.py b/test/testcases/test_web_api/test_chunk_app/test_rm_chunks.py index 6979ef041e..7d965ec99f 100644 --- a/test/testcases/test_web_api/test_chunk_app/test_rm_chunks.py +++ b/test/testcases/test_web_api/test_chunk_app/test_rm_chunks.py @@ -105,10 +105,7 @@ class TestChunksDeletion: chunk_ids = batch_add_chunks(WebApiAuth, dataset_id, document_id, count) with ThreadPoolExecutor(max_workers=5) as executor: - futures = [ - executor.submit(delete_chunks, WebApiAuth, dataset_id, document_id, {"chunk_ids": chunk_ids[i : i + 1]}) - for i in range(count) - ] + futures = [executor.submit(delete_chunks, WebApiAuth, dataset_id, document_id, {"chunk_ids": chunk_ids[i : i + 1]}) for i in range(count)] responses = list(as_completed(futures)) assert len(responses) == count, responses assert all(future.result()["code"] == 0 for future in futures) diff --git a/test/testcases/test_web_api/test_chunk_feedback/test_chunk_feedback_service.py b/test/testcases/test_web_api/test_chunk_feedback/test_chunk_feedback_service.py index 6166f00476..79e3d2a66d 100644 --- a/test/testcases/test_web_api/test_chunk_feedback/test_chunk_feedback_service.py +++ b/test/testcases/test_web_api/test_chunk_feedback/test_chunk_feedback_service.py @@ -20,6 +20,7 @@ Uses importlib to load chunk_feedback_service.py in isolation so that test/testcases/test_web_api/common.py (a test-helper module) does not shadow the project-level common/ package during collection. """ + import importlib.util import sys from pathlib import Path @@ -72,9 +73,7 @@ def _load_feedback_module(monkeypatch): _REPO_ROOT / "api" / "db" / "services" / "chunk_feedback_service.py", ) mod = importlib.util.module_from_spec(spec) - monkeypatch.setitem( - sys.modules, "api.db.services.chunk_feedback_service", mod - ) + monkeypatch.setitem(sys.modules, "api.db.services.chunk_feedback_service", mod) spec.loader.exec_module(mod) return mod, settings_mod @@ -157,12 +156,7 @@ class TestUpdateChunkWeight: mock_doc_store.update.return_value = True settings_mod.docStoreConn = mock_doc_store - result = mod.ChunkFeedbackService.update_chunk_weight( - tenant_id="tenant1", - chunk_id="chunk1", - kb_id="kb1", - delta=1 - ) + result = mod.ChunkFeedbackService.update_chunk_weight(tenant_id="tenant1", chunk_id="chunk1", kb_id="kb1", delta=1) assert result is True mock_doc_store.update.assert_called_once() @@ -176,12 +170,7 @@ class TestUpdateChunkWeight: mock_doc_store.get.return_value = None settings_mod.docStoreConn = mock_doc_store - result = mod.ChunkFeedbackService.update_chunk_weight( - tenant_id="tenant1", - chunk_id="chunk1", - kb_id="kb1", - delta=1 - ) + result = mod.ChunkFeedbackService.update_chunk_weight(tenant_id="tenant1", chunk_id="chunk1", kb_id="kb1", delta=1) assert result is False @@ -199,7 +188,7 @@ class TestUpdateChunkWeight: tenant_id="tenant1", chunk_id="chunk1", kb_id="kb1", - delta=10 # Would exceed max + delta=10, # Would exceed max ) # Verify the new_value passed to update has clamped weight @@ -221,7 +210,7 @@ class TestUpdateChunkWeight: tenant_id="tenant1", chunk_id="chunk1", kb_id="kb1", - delta=-10 # Would go below min + delta=-10, # Would go below min ) call_args = mock_doc_store.update.call_args @@ -363,11 +352,7 @@ class TestApplyFeedback: mod, _ = feedback_env monkeypatch.setattr(mod, "CHUNK_FEEDBACK_ENABLED", False) - result = mod.ChunkFeedbackService.apply_feedback( - tenant_id="tenant1", - reference={"chunks": [{"id": "chunk1", "dataset_id": "kb1"}]}, - is_positive=True - ) + result = mod.ChunkFeedbackService.apply_feedback(tenant_id="tenant1", reference={"chunks": [{"id": "chunk1", "dataset_id": "kb1"}]}, is_positive=True) assert result["success_count"] == 0 assert result["fail_count"] == 0 @@ -378,9 +363,7 @@ class TestApplyFeedback: mod, _ = feedback_env monkeypatch.setattr(mod, "CHUNK_FEEDBACK_ENABLED", True) mock_update = MagicMock(return_value=True) - monkeypatch.setattr( - mod.ChunkFeedbackService, "update_chunk_weight", mock_update - ) + monkeypatch.setattr(mod.ChunkFeedbackService, "update_chunk_weight", mock_update) reference = { "chunks": [ @@ -388,11 +371,7 @@ class TestApplyFeedback: {"id": "chunk2", "dataset_id": "kb1"}, ] } - result = mod.ChunkFeedbackService.apply_feedback( - tenant_id="tenant1", - reference=reference, - is_positive=True - ) + result = mod.ChunkFeedbackService.apply_feedback(tenant_id="tenant1", reference=reference, is_positive=True) assert result["success_count"] == 1 assert result["fail_count"] == 0 @@ -404,16 +383,10 @@ class TestApplyFeedback: mod, _ = feedback_env monkeypatch.setattr(mod, "CHUNK_FEEDBACK_ENABLED", True) mock_update = MagicMock(return_value=True) - monkeypatch.setattr( - mod.ChunkFeedbackService, "update_chunk_weight", mock_update - ) + monkeypatch.setattr(mod.ChunkFeedbackService, "update_chunk_weight", mock_update) reference = {"chunks": [{"id": "chunk1", "dataset_id": "kb1"}]} - result = mod.ChunkFeedbackService.apply_feedback( - tenant_id="tenant1", - reference=reference, - is_positive=False - ) + result = mod.ChunkFeedbackService.apply_feedback(tenant_id="tenant1", reference=reference, is_positive=False) assert result["success_count"] == 1 mock_update.assert_called_with("tenant1", "chunk1", "kb1", -1, row_id=None) @@ -423,11 +396,7 @@ class TestApplyFeedback: mod, _ = feedback_env monkeypatch.setattr(mod, "CHUNK_FEEDBACK_ENABLED", True) - result = mod.ChunkFeedbackService.apply_feedback( - tenant_id="tenant1", - reference={}, - is_positive=True - ) + result = mod.ChunkFeedbackService.apply_feedback(tenant_id="tenant1", reference={}, is_positive=True) assert result["success_count"] == 0 assert result["fail_count"] == 0 @@ -439,9 +408,7 @@ class TestApplyFeedback: monkeypatch.setattr(mod, "CHUNK_FEEDBACK_ENABLED", True) monkeypatch.setattr(mod, "CHUNK_FEEDBACK_WEIGHTING", "uniform") mock_update = MagicMock(side_effect=[True, False]) - monkeypatch.setattr( - mod.ChunkFeedbackService, "update_chunk_weight", mock_update - ) + monkeypatch.setattr(mod.ChunkFeedbackService, "update_chunk_weight", mock_update) reference = { "chunks": [ @@ -449,11 +416,7 @@ class TestApplyFeedback: {"id": "chunk2", "dataset_id": "kb1"}, ] } - result = mod.ChunkFeedbackService.apply_feedback( - tenant_id="tenant1", - reference=reference, - is_positive=True - ) + result = mod.ChunkFeedbackService.apply_feedback(tenant_id="tenant1", reference=reference, is_positive=True) assert result["success_count"] == 1 assert result["fail_count"] == 1 @@ -464,18 +427,14 @@ class TestApplyFeedback: monkeypatch.setattr(mod, "CHUNK_FEEDBACK_ENABLED", True) monkeypatch.setattr(mod, "CHUNK_FEEDBACK_WEIGHTING", "uniform") mock_update = MagicMock(return_value=True) - monkeypatch.setattr( - mod.ChunkFeedbackService, "update_chunk_weight", mock_update - ) + monkeypatch.setattr(mod.ChunkFeedbackService, "update_chunk_weight", mock_update) reference = { "chunks": [ {"id": "chunk1", "dataset_id": "kb1"}, {"id": "chunk2", "dataset_id": "kb1"}, ] } - mod.ChunkFeedbackService.apply_feedback( - tenant_id="tenant1", reference=reference, is_positive=True - ) + mod.ChunkFeedbackService.apply_feedback(tenant_id="tenant1", reference=reference, is_positive=True) mock_update.assert_any_call("tenant1", "chunk1", "kb1", mod.UPVOTE_WEIGHT_INCREMENT, row_id=None) mock_update.assert_any_call("tenant1", "chunk2", "kb1", mod.UPVOTE_WEIGHT_INCREMENT, row_id=None) @@ -485,18 +444,14 @@ class TestApplyFeedback: monkeypatch.setattr(mod, "CHUNK_FEEDBACK_ENABLED", True) monkeypatch.setattr(mod, "CHUNK_FEEDBACK_WEIGHTING", "relevance") mock_update = MagicMock(return_value=True) - monkeypatch.setattr( - mod.ChunkFeedbackService, "update_chunk_weight", mock_update - ) + monkeypatch.setattr(mod.ChunkFeedbackService, "update_chunk_weight", mock_update) reference = { "chunks": [ {"id": "a", "dataset_id": "kb1", "similarity": 0.9}, {"id": "b", "dataset_id": "kb1", "similarity": 0.1}, ] } - mod.ChunkFeedbackService.apply_feedback( - tenant_id="tenant1", reference=reference, is_positive=True - ) + mod.ChunkFeedbackService.apply_feedback(tenant_id="tenant1", reference=reference, is_positive=True) mock_update.assert_called_once_with("tenant1", "a", "kb1", 1, row_id=None) def test_apply_feedback_passes_row_id_from_reference(self, feedback_env, monkeypatch): @@ -505,17 +460,13 @@ class TestApplyFeedback: monkeypatch.setattr(mod, "CHUNK_FEEDBACK_ENABLED", True) monkeypatch.setattr(mod, "CHUNK_FEEDBACK_WEIGHTING", "relevance") mock_update = MagicMock(return_value=True) - monkeypatch.setattr( - mod.ChunkFeedbackService, "update_chunk_weight", mock_update - ) + monkeypatch.setattr(mod.ChunkFeedbackService, "update_chunk_weight", mock_update) reference = { "chunks": [ {"id": "c1", "dataset_id": "kb1", "similarity": 0.8, "row_id": 99}, ] } - mod.ChunkFeedbackService.apply_feedback( - tenant_id="tenant1", reference=reference, is_positive=True - ) + mod.ChunkFeedbackService.apply_feedback(tenant_id="tenant1", reference=reference, is_positive=True) mock_update.assert_called_once_with("tenant1", "c1", "kb1", 1, row_id=99) @@ -540,11 +491,15 @@ class TestThumbFlipFeedback: if apply_chunk_feedback and reference: if isinstance(prior_thumb, bool) and prior_thumb != new_thumb: r = mod.ChunkFeedbackService.apply_feedback( - tenant_id="t1", reference=reference, is_positive=not prior_thumb, + tenant_id="t1", + reference=reference, + is_positive=not prior_thumb, ) calls.append(("undo", r)) r = mod.ChunkFeedbackService.apply_feedback( - tenant_id="t1", reference=reference, is_positive=new_thumb is True, + tenant_id="t1", + reference=reference, + is_positive=new_thumb is True, ) calls.append(("new", r)) diff --git a/test/testcases/test_web_api/test_common.py b/test/testcases/test_web_api/test_common.py index ce850883e9..5fa53baa35 100644 --- a/test/testcases/test_web_api/test_common.py +++ b/test/testcases/test_web_api/test_common.py @@ -67,6 +67,7 @@ def _log_http_debug(method, url, req_id, payload, status, text, resp_json, elaps print(f"[HTTP DEBUG] response_text={text}") print(f"[HTTP DEBUG] response_json={json.dumps(resp_json, default=str) if resp_json is not None else None}") + def api_stats(auth, params=None, *, headers=HEADERS): res = requests.get(url=f"{HOST_ADDRESS}{API_APP_URL}/stats", headers=headers, auth=auth, params=params) return res.json() @@ -447,7 +448,7 @@ def document_update_metadata_setting(auth, dataset_id, doc_id, payload=None, *, def document_change_status(auth, dataset_id, payload=None, *, headers=HEADERS, data=None): """ Batch update document status within a dataset. - + Args: auth: Authentication credentials dataset_id: ID of the dataset @@ -539,13 +540,13 @@ def create_memory(auth, payload=None): return res.json() -def update_memory(auth, memory_id:str, payload=None): +def update_memory(auth, memory_id: str, payload=None): url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}" res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload) return res.json() -def delete_memory(auth, memory_id:str): +def delete_memory(auth, memory_id: str): url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}" res = requests.delete(url=url, headers=HEADERS, auth=auth) return res.json() @@ -567,7 +568,7 @@ def list_memory(auth, params=None): return res.json() -def get_memory_config(auth, memory_id:str): +def get_memory_config(auth, memory_id: str): url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}/config" res = requests.get(url=url, headers=HEADERS, auth=auth) return res.json() diff --git a/test/testcases/test_web_api/test_connector_app/test_connector_routes_unit.py b/test/testcases/test_web_api/test_connector_app/test_connector_routes_unit.py index 605ec415f1..580a8397ae 100644 --- a/test/testcases/test_web_api/test_connector_app/test_connector_routes_unit.py +++ b/test/testcases/test_web_api/test_connector_app/test_connector_routes_unit.py @@ -241,7 +241,7 @@ def _load_connector_app(monkeypatch): "message": message, "data": data, } - api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn) + api_utils_mod.validate_request = lambda *_args, **_kwargs: lambda fn: fn monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) constants_mod = ModuleType("common.constants") @@ -264,8 +264,7 @@ def _load_connector_app(monkeypatch): google_constants_mod = ModuleType("common.data_source.google_util.constant") google_constants_mod.WEB_OAUTH_POPUP_TEMPLATE = ( - "{title}" - "

{heading}

{message}

" + "{title}

{heading}

{message}

" ) google_constants_mod.GOOGLE_SCOPES = { config_mod.DocumentSource.GMAIL: ["scope-gmail"], @@ -372,7 +371,7 @@ def test_connector_basic_routes_and_task_controls(monkeypatch): lambda: _AwaitableValue({"id": "conn-1", "refresh_freq": 7, "config": {"x": 1}}), ) res = _run(module.update_connector("conn-1")) - assert update_calls == [("conn-1", {'id': 'conn-1', "refresh_freq": 7, "config": {"x": 1}})] + assert update_calls == [("conn-1", {"id": "conn-1", "refresh_freq": 7, "config": {"x": 1}})] assert res["data"]["id"] == "conn-1" monkeypatch.setattr( @@ -604,27 +603,33 @@ def test_google_web_oauth_callbacks_matrix(monkeypatch): assert "Authorization session was invalid" in invalid_state.body assert module._web_state_cache_key("sid", source) in redis.deleted - redis.store[module._web_state_cache_key("sid", source)] = json.dumps({ - "user_id": "tenant-1", - "client_config": {"web": {"client_id": "cid"}}, - }) + redis.store[module._web_state_cache_key("sid", source)] = json.dumps( + { + "user_id": "tenant-1", + "client_config": {"web": {"client_id": "cid"}}, + } + ) _set_request(module, args={"state": "sid", "error": "denied", "error_description": "permission denied"}) oauth_error = _run(callback()) assert "permission denied" in oauth_error.body - redis.store[module._web_state_cache_key("sid", source)] = json.dumps({ - "user_id": "tenant-1", - "client_config": {"web": {"client_id": "cid"}}, - }) + redis.store[module._web_state_cache_key("sid", source)] = json.dumps( + { + "user_id": "tenant-1", + "client_config": {"web": {"client_id": "cid"}}, + } + ) _set_request(module, args={"state": "sid"}) missing_code = _run(callback()) assert "Missing authorization code" in missing_code.body - redis.store[module._web_state_cache_key("sid", source)] = json.dumps({ - "user_id": "tenant-1", - "client_config": {"web": {"client_id": "cid"}}, - "code_verifier": "state-code-verifier", - }) + redis.store[module._web_state_cache_key("sid", source)] = json.dumps( + { + "user_id": "tenant-1", + "client_config": {"web": {"client_id": "cid"}}, + "code_verifier": "state-code-verifier", + } + ) _set_request(module, args={"state": "sid", "code": "code-123"}) success = _run(callback()) assert "Authorization completed successfully." in success.body @@ -653,16 +658,12 @@ def test_poll_google_web_result_matrix(monkeypatch): pending = _run(module.poll_google_web_result()) assert pending["code"] == module.RetCode.RUNNING - redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps( - {"user_id": "another-user", "credentials": "token-x"} - ) + redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps({"user_id": "another-user", "credentials": "token-x"}) _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"}) permission_error = _run(module.poll_google_web_result()) assert permission_error["code"] == module.RetCode.PERMISSION_ERROR - redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps( - {"user_id": "tenant-1", "credentials": "token-ok"} - ) + redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps({"user_id": "tenant-1", "credentials": "token-ok"}) _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"}) success = _run(module.poll_google_web_result()) assert success["code"] == 0 @@ -715,16 +716,12 @@ def test_box_oauth_start_callback_and_poll_matrix(monkeypatch): invalid_session = _run(module.box_web_oauth_callback()) assert invalid_session["code"] == module.RetCode.ARGUMENT_ERROR - redis.store[module._web_state_cache_key("flow-box", "box")] = json.dumps( - {"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"} - ) + redis.store[module._web_state_cache_key("flow-box", "box")] = json.dumps({"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"}) _set_request(module, args={"state": "flow-box", "code": "abc", "error": "access_denied", "error_description": "denied"}) callback_error = _run(module.box_web_oauth_callback()) assert "denied" in callback_error.body - redis.store[module._web_state_cache_key("flow-ok", "box")] = json.dumps( - {"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"} - ) + redis.store[module._web_state_cache_key("flow-ok", "box")] = json.dumps({"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"}) _set_request(module, args={"state": "flow-ok", "code": "code-ok"}) callback_success = _run(module.box_web_oauth_callback()) assert "Authorization completed successfully." in callback_success.body @@ -741,9 +738,7 @@ def test_box_oauth_start_callback_and_poll_matrix(monkeypatch): permission_error = _run(module.poll_box_web_result()) assert permission_error["code"] == module.RetCode.PERMISSION_ERROR - redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps( - {"user_id": "tenant-1", "access_token": "at", "refresh_token": "rt"} - ) + redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps({"user_id": "tenant-1", "access_token": "at", "refresh_token": "rt"}) poll_success = _run(module.poll_box_web_result()) assert poll_success["code"] == 0 assert poll_success["data"]["credentials"]["access_token"] == "at" diff --git a/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py b/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py index 4a2c8a47a0..55bd787638 100644 --- a/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py +++ b/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py @@ -406,9 +406,7 @@ def _load_dataset_module(monkeypatch): def _parse_args(*_args, **_kwargs): return {"name": "", "page": 1, "page_size": 30, "orderby": "create_time", "desc": True}, None - validation_spec = importlib.util.spec_from_file_location( - "api.utils.validation_utils", repo_root / "api" / "utils" / "validation_utils.py" - ) + validation_spec = importlib.util.spec_from_file_location("api.utils.validation_utils", repo_root / "api" / "utils" / "validation_utils.py") validation_mod = importlib.util.module_from_spec(validation_spec) monkeypatch.setitem(sys.modules, "api.utils.validation_utils", validation_mod) validation_spec.loader.exec_module(validation_mod) diff --git a/test/testcases/test_web_api/test_document_app/conftest.py b/test/testcases/test_web_api/test_document_app/conftest.py index 0e719a1527..ab4cbc0a57 100644 --- a/test/testcases/test_web_api/test_document_app/conftest.py +++ b/test/testcases/test_web_api/test_document_app/conftest.py @@ -28,6 +28,7 @@ class _DummyManager: def route(self, *_args, **_kwargs): def decorator(func): return func + return decorator @@ -218,6 +219,7 @@ def document_rest_api_module(monkeypatch): document_api_service_mod = ModuleType("api.apps.services.document_api_service") document_api_service_mod.validate_document_update_fields = lambda *_args, **_kwargs: (None, None) document_api_service_mod.map_doc_keys = lambda doc: doc.to_dict() if hasattr(doc, "to_dict") else doc + def _map_doc_keys_with_run_status(doc, run_status="0"): payload = doc if isinstance(doc, dict) else doc.to_dict() return {**payload, "run": run_status} diff --git a/test/testcases/test_web_api/test_document_app/test_document_metadata.py b/test/testcases/test_web_api/test_document_app/test_document_metadata.py index 19cab84d21..0d5fb3fb02 100644 --- a/test/testcases/test_web_api/test_document_app/test_document_metadata.py +++ b/test/testcases/test_web_api/test_document_app/test_document_metadata.py @@ -55,19 +55,19 @@ class TestAuthorization: assert expected_fragment in res["message"], res ## The inputs has been changed to add 'doc_ids' - ## TODO: - #@pytest.mark.p2 - #@pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - #def test_metadata_summary_auth_invalid(self, invalid_auth, expected_code, expected_fragment): + ## TODO: + # @pytest.mark.p2 + # @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) + # def test_metadata_summary_auth_invalid(self, invalid_auth, expected_code, expected_fragment): # res = document_metadata_summary(invalid_auth, {"kb_id": "kb_id"}) # assert res["code"] == expected_code, res # assert expected_fragment in res["message"], res ## The inputs has been changed to deprecate 'selector' - ## TODO: - #@pytest.mark.p2 - #@pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) - #def test_metadata_update_auth_invalid(self, invalid_auth, expected_code, expected_fragment): + ## TODO: + # @pytest.mark.p2 + # @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) + # def test_metadata_update_auth_invalid(self, invalid_auth, expected_code, expected_fragment): # res = document_metadata_update(invalid_auth, {"kb_id": "kb_id", "selector": {"document_ids": ["doc_id"]}, "updates": []}) # assert res["code"] == expected_code, res # assert expected_fragment in res["message"], res @@ -87,6 +87,7 @@ class TestAuthorization: assert res["code"] == expected_code, res assert expected_fragment in res["message"], res + class TestDocumentMetadata: @pytest.mark.p2 def test_filter(self, WebApiAuth, add_dataset_func): @@ -106,18 +107,18 @@ class TestDocumentMetadata: assert docs[0]["id"] == doc_id, res ## The inputs has been changed to add 'doc_ids' - ## TODO: - #@pytest.mark.p2 - #def test_metadata_summary(self, WebApiAuth, add_document_func): + ## TODO: + # @pytest.mark.p2 + # def test_metadata_summary(self, WebApiAuth, add_document_func): # kb_id, _ = add_document_func # res = document_metadata_summary(WebApiAuth, {"kb_id": kb_id}) # assert res["code"] == 0, res # assert isinstance(res["data"]["summary"], dict), res ## The inputs has been changed to deprecate 'selector' - ## TODO: - #@pytest.mark.p2 - #def test_metadata_update(self, WebApiAuth, add_document_func): + ## TODO: + # @pytest.mark.p2 + # def test_metadata_update(self, WebApiAuth, add_document_func): # kb_id, doc_id = add_document_func # payload = { # "kb_id": kb_id, @@ -132,11 +133,11 @@ class TestDocumentMetadata: # assert info_res["code"] == 0, info_res # meta_fields = info_res["data"][0].get("meta_fields", {}) # assert meta_fields.get("author") == "alice", info_res - + ## The inputs has been changed to deprecate 'selector' - ## TODO: - #@pytest.mark.p2 - #def test_update_metadata_setting(self, WebApiAuth, add_document_func): + ## TODO: + # @pytest.mark.p2 + # def test_update_metadata_setting(self, WebApiAuth, add_document_func): # _, doc_id = add_document_func # metadata = {"source": "test"} # res = document_update_metadata_setting(WebApiAuth, {"doc_id": doc_id, "metadata": metadata}) @@ -156,7 +157,6 @@ class TestDocumentMetadata: assert info_res["code"] == 0, info_res assert info_res["data"]["docs"][0]["status"] == "1", info_res - @pytest.mark.p2 def test_update_document_change_parser(self, WebApiAuth, add_document_func): """Test updating document chunk_method via PATCH /api/v1/datasets//documents/.""" @@ -185,7 +185,6 @@ class TestDocumentMetadata: assert res["code"] == 0, res assert res["data"]["docs"][0]["chunk_method"] == new_parser_id, res - @pytest.mark.p2 def test_update_document_change_pipeline(self, WebApiAuth, add_document_func): """Test updating document pipeline via PATCH /api/v1/datasets//documents/.""" @@ -224,9 +223,9 @@ class TestDocumentMetadataNegative: assert "KB ID" in res["message"], res ## The inputs has been changed to deprecate 'selector' - ## TODO: - #@pytest.mark.p3 - #def test_metadata_update_missing_kb_id(self, WebApiAuth, add_document_func): + ## TODO: + # @pytest.mark.p3 + # def test_metadata_update_missing_kb_id(self, WebApiAuth, add_document_func): # _, doc_id = add_document_func # res = document_metadata_update(WebApiAuth, {"selector": {"document_ids": [doc_id]}, "updates": []}) # assert res["code"] == 101, res @@ -314,33 +313,25 @@ class TestDocumentMetadataUnit: def test_update_metadata_success(self, WebApiAuth, add_document_func): """Test the new unified update_metadata API - success case.""" kb_id, doc_id = add_document_func - res = document_metadata_update( - WebApiAuth, kb_id, - { - "selector": {"document_ids": [doc_id]}, - "updates": [{"key": "author", "value": "test_author"}], - "deletes": [] - } - ) + res = document_metadata_update(WebApiAuth, kb_id, {"selector": {"document_ids": [doc_id]}, "updates": [{"key": "author", "value": "test_author"}], "deletes": []}) assert res["code"] == 0, res - @pytest.mark.p3 def test_update_metadata_invalid_delete_item(self, WebApiAuth, add_document_func): """Test the new unified update_metadata API - invalid delete item.""" kb_id, doc_id = add_document_func res = document_metadata_update( - WebApiAuth, kb_id, + WebApiAuth, + kb_id, { "selector": {"document_ids": [doc_id]}, "updates": [], - "deletes": [{}] # Invalid - missing key - } + "deletes": [{}], # Invalid - missing key + }, ) assert res["code"] == 102 assert "Each delete requires key" in res["message"], res - def test_get_route_not_found_success_and_exception_unit(self, document_app_module, monkeypatch): module = document_app_module @@ -690,9 +681,7 @@ class TestDocumentMetadataUnit: monkeypatch.setattr( module, "apply_safe_file_response_headers", - lambda response, content_type, extension: response.headers.update( - {"content_type": content_type, "extension": extension} - ), + lambda response, content_type, extension: response.headers.update({"content_type": content_type, "extension": extension}), ) monkeypatch.setattr( module.settings, diff --git a/test/testcases/test_web_api/test_document_app/test_list_documents.py b/test/testcases/test_web_api/test_document_app/test_list_documents.py index 37dc23d692..9ace9185ed 100644 --- a/test/testcases/test_web_api/test_document_app/test_list_documents.py +++ b/test/testcases/test_web_api/test_document_app/test_list_documents.py @@ -47,7 +47,6 @@ class TestDocumentsList: assert len(res["data"]["docs"]) == 5 assert res["data"]["total"] == 5 - @pytest.mark.p1 @pytest.mark.parametrize( "params, expected_code, expected_page_size, expected_message", @@ -97,10 +96,10 @@ class TestDocumentsList: @pytest.mark.parametrize( "params, expected_code, assertions, expected_message", [ - ({"orderby": None}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""), - ({"orderby": "create_time"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""), - ({"orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"]["docs"], "update_time", True)), ""), - pytest.param({"orderby": "name", "desc": "False"}, 0, lambda r: (is_sorted(r["data"]["docs"], "name", False)), "", marks=pytest.mark.skip(reason="issues/5851")), + ({"orderby": None}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", True), ""), + ({"orderby": "create_time"}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", True), ""), + ({"orderby": "update_time"}, 0, lambda r: is_sorted(r["data"]["docs"], "update_time", True), ""), + pytest.param({"orderby": "name", "desc": "False"}, 0, lambda r: is_sorted(r["data"]["docs"], "name", False), "", marks=pytest.mark.skip(reason="issues/5851")), pytest.param({"orderby": "unknown"}, 102, 0, "orderby should be create_time or update_time", marks=pytest.mark.skip(reason="issues/5851")), ], ) @@ -118,14 +117,14 @@ class TestDocumentsList: @pytest.mark.parametrize( "params, expected_code, assertions, expected_message", [ - ({"desc": None}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""), - ({"desc": "true"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""), - ({"desc": "True"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""), - ({"desc": True}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""), - pytest.param({"desc": "false"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", False)), "", marks=pytest.mark.skip(reason="issues/5851")), - ({"desc": "False"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", False)), ""), - ({"desc": False}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", False)), ""), - ({"desc": "False", "orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"]["docs"], "update_time", False)), ""), + ({"desc": None}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", True), ""), + ({"desc": "true"}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", True), ""), + ({"desc": "True"}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", True), ""), + ({"desc": True}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", True), ""), + pytest.param({"desc": "false"}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", False), "", marks=pytest.mark.skip(reason="issues/5851")), + ({"desc": "False"}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", False), ""), + ({"desc": False}, 0, lambda r: is_sorted(r["data"]["docs"], "create_time", False), ""), + ({"desc": "False", "orderby": "update_time"}, 0, lambda r: is_sorted(r["data"]["docs"], "update_time", False), ""), pytest.param({"desc": "unknown"}, 102, 0, "desc should be true or false", marks=pytest.mark.skip(reason="issues/5851")), ], ) diff --git a/test/testcases/test_web_api/test_document_app/test_paser_documents.py b/test/testcases/test_web_api/test_document_app/test_paser_documents.py index 4a3980093a..e3b0c00839 100644 --- a/test/testcases/test_web_api/test_document_app/test_paser_documents.py +++ b/test/testcases/test_web_api/test_document_app/test_paser_documents.py @@ -319,8 +319,7 @@ class TestDocumentsParseStop: return True kb_id, document_ids = add_documents_func - parse_documents(WebApiAuth, {"doc_ids": document_ids, "run": - "1"}) + parse_documents(WebApiAuth, {"doc_ids": document_ids, "run": "1"}) if callable(payload): payload = payload(document_ids) diff --git a/test/testcases/test_web_api/test_document_app/test_rm_documents.py b/test/testcases/test_web_api/test_document_app/test_rm_documents.py index f0ba072c9d..48a77eba79 100644 --- a/test/testcases/test_web_api/test_document_app/test_rm_documents.py +++ b/test/testcases/test_web_api/test_document_app/test_rm_documents.py @@ -47,7 +47,7 @@ class TestDocumentsDeletion: "payload, expected_code, expected_message, remaining", [ ({}, 102, "should either provide doc ids or set delete_all(true), dataset:", 3), - ({"invalid_key":[]}, 101, "Field: - Message: - Value: <[]>", 3), + ({"invalid_key": []}, 101, "Field: - Message: - Value: <[]>", 3), ({"ids": ""}, 101, "Field: - Message: - Value: <>", 3), ({"ids": ["invalid_id"]}, 102, "These documents do not belong to dataset", 3), ("not json", 101, "Invalid request payload: expected object, got str", 3), diff --git a/test/testcases/test_web_api/test_document_app/test_upload_documents.py b/test/testcases/test_web_api/test_document_app/test_upload_documents.py index 27431e40af..f0bc07b07c 100644 --- a/test/testcases/test_web_api/test_document_app/test_upload_documents.py +++ b/test/testcases/test_web_api/test_document_app/test_upload_documents.py @@ -23,6 +23,7 @@ from utils.file_utils import create_txt_file from concurrent.futures import ThreadPoolExecutor, as_completed + @pytest.mark.p1 @pytest.mark.usefixtures("clear_datasets") class TestAuthorization: @@ -315,7 +316,6 @@ class TestDocumentsUploadUnit: assert "code" in res - @pytest.mark.p2 class TestWebCrawlUnit: def test_invalid_url(self, document_rest_api_module, monkeypatch): diff --git a/test/testcases/test_web_api/test_document_app/test_upload_info_unit.py b/test/testcases/test_web_api/test_document_app/test_upload_info_unit.py index 443e79ef96..f44edf3e66 100644 --- a/test/testcases/test_web_api/test_document_app/test_upload_info_unit.py +++ b/test/testcases/test_web_api/test_document_app/test_upload_info_unit.py @@ -65,6 +65,7 @@ def _run(coro): # End-to-End Tests # ============================================================================ + @pytest.mark.p2 class TestUploadInfoE2E: """End-to-end tests for the /api/v1/documents/upload endpoint""" diff --git a/test/testcases/test_web_api/test_file_app/test_file_routes_unit.py b/test/testcases/test_web_api/test_file_app/test_file_routes_unit.py index b95592ad9e..30b8c6a8f4 100644 --- a/test/testcases/test_web_api/test_file_app/test_file_routes_unit.py +++ b/test/testcases/test_web_api/test_file_app/test_file_routes_unit.py @@ -354,4 +354,3 @@ def test_parent_and_ancestors_use_new_routes(monkeypatch): assert parent_res["data"]["parent_folder"]["id"] == "parent1" assert ancestors_res["code"] == 0 assert ancestors_res["data"]["parent_folders"][0]["id"] == "root" - diff --git a/test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py b/test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py index 1b4dd47a6a..2bc2b99a3f 100644 --- a/test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py +++ b/test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py @@ -210,7 +210,7 @@ def _load_llm_app(monkeypatch): api_utils_mod.get_request_json = _get_request_json api_utils_mod.server_error_response = lambda exc: {"code": 500, "message": str(exc), "data": None} - api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn) + api_utils_mod.validate_request = lambda *_args, **_kwargs: lambda fn: fn monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) constants_mod = ModuleType("common.constants") diff --git a/test/testcases/test_web_api/test_memory_app/conftest.py b/test/testcases/test_web_api/test_memory_app/conftest.py index 8e1c30515e..678c83d231 100644 --- a/test/testcases/test_web_api/test_memory_app/conftest.py +++ b/test/testcases/test_web_api/test_memory_app/conftest.py @@ -17,6 +17,7 @@ import pytest import random from test_common import create_memory, list_memory, delete_memory + @pytest.fixture(scope="function") def add_memory_func(request, WebApiAuth): def cleanup(): @@ -33,7 +34,7 @@ def add_memory_func(request, WebApiAuth): "name": f"test_memory_{i}", "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)), "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" + "llm_id": "glm-4-flash@ZHIPU-AI", } res = create_memory(WebApiAuth, payload) memory_ids.append(res["data"]["id"]) diff --git a/test/testcases/test_web_api/test_memory_app/test_create_memory.py b/test/testcases/test_web_api/test_memory_app/test_create_memory.py index 27187c765f..9ecaf7c78f 100644 --- a/test/testcases/test_web_api/test_memory_app/test_create_memory.py +++ b/test/testcases/test_web_api/test_memory_app/test_create_memory.py @@ -30,7 +30,7 @@ class TestAuthorization: (None, 401, ""), (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), ], - ids=["empty_auth", "invalid_api_token"] + ids=["empty_auth", "invalid_api_token"], ) def test_auth_invalid(self, invalid_auth, expected_code, expected_message): res = create_memory(invalid_auth) @@ -46,11 +46,11 @@ class TestMemoryCreate: "name": name, "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)), "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" + "llm_id": "glm-4-flash@ZHIPU-AI", } res = create_memory(WebApiAuth, payload) assert res["code"] == 0, res - pattern = rf'^{name}|{name}(?:\((\d+)\))?$' + pattern = rf"^{name}|{name}(?:\((\d+)\))?$" escaped_name = re.escape(res["data"]["name"]) assert re.match(pattern, escaped_name), res @@ -60,7 +60,7 @@ class TestMemoryCreate: [ ("", "Memory name cannot be empty or whitespace."), (" ", "Memory name cannot be empty or whitespace."), - ("a" * 129, f"Memory name '{'a'*129}' exceeds limit of 128."), + ("a" * 129, f"Memory name '{'a' * 129}' exceeds limit of 128."), ], ids=["empty_name", "space_name", "too_long_name"], ) @@ -69,7 +69,7 @@ class TestMemoryCreate: "name": name, "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)), "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" + "llm_id": "glm-4-flash@ZHIPU-AI", } res = create_memory(WebApiAuth, payload) assert res["message"] == expected_message, res @@ -77,12 +77,7 @@ class TestMemoryCreate: @pytest.mark.p2 @pytest.mark.parametrize("name", ["invalid_type_name", "memory_alpha"]) def test_type_invalid(self, WebApiAuth, name): - payload = { - "name": name, - "memory_type": ["something"], - "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" - } + payload = {"name": name, "memory_type": ["something"], "embd_id": "BAAI/bge-small-en-v1.5@Builtin", "llm_id": "glm-4-flash@ZHIPU-AI"} res = create_memory(WebApiAuth, payload) assert res["message"] == f"Memory type '{ {'something'} }' is not supported.", res @@ -93,7 +88,7 @@ class TestMemoryCreate: "name": name, "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)), "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" + "llm_id": "glm-4-flash@ZHIPU-AI", } res1 = create_memory(WebApiAuth, payload) assert res1["code"] == 0, res1 diff --git a/test/testcases/test_web_api/test_memory_app/test_list_memory.py b/test/testcases/test_web_api/test_memory_app/test_list_memory.py index b6ed469b68..3bb6878731 100644 --- a/test/testcases/test_web_api/test_memory_app/test_list_memory.py +++ b/test/testcases/test_web_api/test_memory_app/test_list_memory.py @@ -20,6 +20,7 @@ from test_common import list_memory, get_memory_config from configs import INVALID_API_TOKEN from libs.auth import RAGFlowWebApiAuth + class TestAuthorization: @pytest.mark.p2 @pytest.mark.parametrize( @@ -45,11 +46,12 @@ class TestCapability: assert len(responses) == count, responses assert all(future.result()["code"] == 0 for future in futures) + @pytest.mark.usefixtures("add_memory_func") class TestMemoryList: @pytest.mark.p2 def test_params_unset(self, WebApiAuth): - res = list_memory(WebApiAuth, None) + res = list_memory(WebApiAuth, None) assert res["code"] == 0, res @pytest.mark.p2 @@ -67,8 +69,7 @@ class TestMemoryList: ({"page": 2, "page_size": 2}, 1), ({"page": 5, "page_size": 10}, 0), ], - ids=["normal_first_page", "beyond_max_page", "normal_last_partial_page" , "normal_middle_page", - "full_data_single_page"], + ids=["normal_first_page", "beyond_max_page", "normal_last_partial_page", "normal_middle_page", "full_data_single_page"], ) def test_page(self, WebApiAuth, params, expected_page_size): # have added 3 memories in fixture @@ -112,7 +113,21 @@ class TestMemoryList: memory_config = get_memory_config(WebApiAuth, memory_list["data"]["memory_list"][0]["id"]) assert memory_config["code"] == 0, memory_config assert memory_config["data"]["id"] == memory_list["data"]["memory_list"][0]["id"], memory_config - for field in ["name", "avatar", "tenant_id", "owner_name", "memory_type", "storage_type", - "embd_id", "llm_id", "permissions", "description", "memory_size", "forgetting_policy", - "temperature", "system_prompt", "user_prompt"]: + for field in [ + "name", + "avatar", + "tenant_id", + "owner_name", + "memory_type", + "storage_type", + "embd_id", + "llm_id", + "permissions", + "description", + "memory_size", + "forgetting_policy", + "temperature", + "system_prompt", + "user_prompt", + ]: assert field in memory_config["data"], memory_config diff --git a/test/testcases/test_web_api/test_memory_app/test_rm_memory.py b/test/testcases/test_web_api/test_memory_app/test_rm_memory.py index de04139217..a9866f4db8 100644 --- a/test/testcases/test_web_api/test_memory_app/test_rm_memory.py +++ b/test/testcases/test_web_api/test_memory_app/test_rm_memory.py @@ -14,10 +14,11 @@ # limitations under the License. # import pytest -from test_common import (list_memory, delete_memory) +from test_common import list_memory, delete_memory from configs import INVALID_API_TOKEN from libs.auth import RAGFlowWebApiAuth + class TestAuthorization: @pytest.mark.p2 @pytest.mark.parametrize( diff --git a/test/testcases/test_web_api/test_memory_app/test_update_memory.py b/test/testcases/test_web_api/test_memory_app/test_update_memory.py index 72ecfaa8ec..de47b55a95 100644 --- a/test/testcases/test_web_api/test_memory_app/test_update_memory.py +++ b/test/testcases/test_web_api/test_memory_app/test_update_memory.py @@ -31,7 +31,7 @@ class TestAuthorization: (None, 401, ""), (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), ], - ids=["empty_auth", "invalid_api_token"] + ids=["empty_auth", "invalid_api_token"], ) def test_auth_invalid(self, invalid_auth, expected_code, expected_message): res = update_memory(invalid_auth, "memory_id") @@ -40,7 +40,6 @@ class TestAuthorization: class TestMemoryUpdate: - @pytest.mark.p1 @pytest.mark.parametrize("name", ["updated_memory", "f" * 128]) def test_name(self, WebApiAuth, add_memory_func, name): @@ -58,7 +57,7 @@ class TestMemoryUpdate: ("", "Memory name cannot be empty or whitespace."), (" ", "Memory name cannot be empty or whitespace."), ("a" * 129, f"Memory name '{'a' * 129}' exceeds limit of 128."), - ] + ], ) def test_name_invalid(self, WebApiAuth, add_memory_func, name, expected_message): memory_ids = add_memory_func @@ -115,14 +114,7 @@ class TestMemoryUpdate: assert "Do not set tenant_llm_id or tenant_embd_id directly" in res["message"], res @pytest.mark.p2 - @pytest.mark.parametrize( - "permission", - [ - "me", - "team" - ], - ids=["me", "team"] - ) + @pytest.mark.parametrize("permission", ["me", "team"], ids=["me", "team"]) def test_permission(self, WebApiAuth, add_memory_func, permission): memory_ids = add_memory_func payload = {"permissions": permission} @@ -130,7 +122,6 @@ class TestMemoryUpdate: assert res["code"] == 0, res assert res["data"]["permissions"] == permission.lower().strip(), res - @pytest.mark.p1 def test_memory_size(self, WebApiAuth, add_memory_func): memory_ids = add_memory_func diff --git a/test/testcases/test_web_api/test_message_app/conftest.py b/test/testcases/test_web_api/test_message_app/conftest.py index 6d34930ea7..69b15c82fb 100644 --- a/test/testcases/test_web_api/test_message_app/conftest.py +++ b/test/testcases/test_web_api/test_message_app/conftest.py @@ -28,13 +28,9 @@ def add_empty_raw_type_memory(request, WebApiAuth): exist_memory_ids = [memory["id"] for memory in memory_list_res["data"]["memory_list"]] for _memory_id in exist_memory_ids: delete_memory(WebApiAuth, _memory_id) + request.addfinalizer(cleanup) - payload = { - "name": "test_memory_0", - "memory_type": ["raw"], - "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" - } + payload = {"name": "test_memory_0", "memory_type": ["raw"], "embd_id": "BAAI/bge-small-en-v1.5@Builtin", "llm_id": "glm-4-flash@ZHIPU-AI"} res = create_memory(WebApiAuth, payload) memory_id = res["data"]["id"] request.cls.memory_id = memory_id @@ -49,12 +45,13 @@ def add_empty_multiple_type_memory(request, WebApiAuth): exist_memory_ids = [memory["id"] for memory in memory_list_res["data"]["memory_list"]] for _memory_id in exist_memory_ids: delete_memory(WebApiAuth, _memory_id) + request.addfinalizer(cleanup) payload = { "name": "test_memory_0", "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(1, 3)), "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" + "llm_id": "glm-4-flash@ZHIPU-AI", } res = create_memory(WebApiAuth, payload) memory_id = res["data"]["id"] @@ -78,7 +75,7 @@ def add_2_multiple_type_memory(request, WebApiAuth): "name": f"test_memory_{i}", "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(1, 3)), "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" + "llm_id": "glm-4-flash@ZHIPU-AI", } res = create_memory(WebApiAuth, payload) memory_ids.append(res["data"]["id"]) @@ -100,7 +97,7 @@ def add_memory_with_multiple_type_message_func(request, WebApiAuth): "name": "test_memory_0", "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(1, 3)), "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" + "llm_id": "glm-4-flash@ZHIPU-AI", } res = create_memory(WebApiAuth, payload) memory_id = res["data"]["id"] @@ -116,7 +113,7 @@ Coriander is a versatile herb with two main edible parts, and its name can refer 1. Leaves and Stems (often called Cilantro or Fresh Coriander): These are the fresh, green, fragrant leaves and tender stems of the plant Coriandrum sativum. They have a bright, citrusy, and sometimes pungent flavor. Cilantro is widely used as a garnish or key ingredient in cuisines like Mexican, Indian, Thai, and Middle Eastern. 2. Seeds (called Coriander Seeds): These are the dried, golden-brown seeds of the same plant. When ground, they become coriander powder. The seeds have a warm, nutty, floral, and slightly citrusy taste, completely different from the fresh leaves. They are a fundamental spice in curries, stews, pickles, and baking. Key Point of Confusion: The naming differs by region. In North America, "coriander" typically refers to the seeds, while "cilantro" refers to the fresh leaves. In the UK, Europe, and many other parts of the world, "coriander" refers to the fresh herb, and the seeds are called "coriander seeds." -""" +""", } add_message(WebApiAuth, message_payload) request.cls.memory_id = memory_id @@ -135,12 +132,7 @@ def add_memory_with_5_raw_message_func(request, WebApiAuth): request.addfinalizer(cleanup) - payload = { - "name": "test_memory_1", - "memory_type": ["raw"], - "embd_id": "BAAI/bge-small-en-v1.5@Builtin", - "llm_id": "glm-4-flash@ZHIPU-AI" - } + payload = {"name": "test_memory_1", "memory_type": ["raw"], "embd_id": "BAAI/bge-small-en-v1.5@Builtin", "llm_id": "glm-4-flash@ZHIPU-AI"} res = create_memory(WebApiAuth, payload) memory_id = res["data"]["id"] agent_ids = [uuid.uuid4().hex for _ in range(2)] @@ -157,11 +149,11 @@ Coriander is a versatile herb with two main edible parts, and its name can refer 1. Leaves and Stems (often called Cilantro or Fresh Coriander): These are the fresh, green, fragrant leaves and tender stems of the plant Coriandrum sativum. They have a bright, citrusy, and sometimes pungent flavor. Cilantro is widely used as a garnish or key ingredient in cuisines like Mexican, Indian, Thai, and Middle Eastern. 2. Seeds (called Coriander Seeds): These are the dried, golden-brown seeds of the same plant. When ground, they become coriander powder. The seeds have a warm, nutty, floral, and slightly citrusy taste, completely different from the fresh leaves. They are a fundamental spice in curries, stews, pickles, and baking. Key Point of Confusion: The naming differs by region. In North America, "coriander" typically refers to the seeds, while "cilantro" refers to the fresh leaves. In the UK, Europe, and many other parts of the world, "coriander" refers to the fresh herb, and the seeds are called "coriander seeds." -""" +""", } add_message(WebApiAuth, message_payload) request.cls.memory_id = memory_id request.cls.agent_ids = agent_ids request.cls.session_ids = session_ids - time.sleep(2) # make sure refresh to index before search + time.sleep(2) # make sure refresh to index before search return memory_id diff --git a/test/testcases/test_web_api/test_message_app/test_add_message.py b/test/testcases/test_web_api/test_message_app/test_add_message.py index 43e9152e4f..544aceb63d 100644 --- a/test/testcases/test_web_api/test_message_app/test_add_message.py +++ b/test/testcases/test_web_api/test_message_app/test_add_message.py @@ -21,6 +21,7 @@ from test_common import list_memory_message, add_message from configs import INVALID_API_TOKEN from libs.auth import RAGFlowWebApiAuth + class TestAuthorization: @pytest.mark.p2 @pytest.mark.parametrize( @@ -38,7 +39,6 @@ class TestAuthorization: @pytest.mark.usefixtures("add_empty_raw_type_memory") class TestAddRawMessage: - @pytest.mark.p1 def test_add_raw_message(self, WebApiAuth): memory_id = self.memory_id @@ -59,7 +59,7 @@ Uses: Pineapples are eaten fresh, cooked, grilled, juiced, or canned. They are a Nutrition: They are a good source of Vitamin C, manganese, and contain an enzyme called bromelain, which aids in digestion and can tenderize meat. Symbolism: The pineapple is a traditional symbol of hospitality and welcome in many cultures. Are you asking about the fruit itself, or its use in a specific context? -""" +""", } add_res = add_message(WebApiAuth, message_payload) assert add_res["code"] == 0, add_res @@ -88,7 +88,6 @@ Are you asking about the fruit itself, or its use in a specific context? @pytest.mark.usefixtures("add_empty_multiple_type_memory") class TestAddMultipleTypeMessage: - @pytest.mark.p1 def test_add_multiple_type_message(self, WebApiAuth): memory_id = self.memory_id @@ -109,7 +108,7 @@ Uses: Pineapples are eaten fresh, cooked, grilled, juiced, or canned. They are a Nutrition: They are a good source of Vitamin C, manganese, and contain an enzyme called bromelain, which aids in digestion and can tenderize meat. Symbolism: The pineapple is a traditional symbol of hospitality and welcome in many cultures. Are you asking about the fruit itself, or its use in a specific context? -""" +""", } add_res = add_message(WebApiAuth, message_payload) assert add_res["code"] == 0, add_res @@ -124,7 +123,6 @@ Are you asking about the fruit itself, or its use in a specific context? @pytest.mark.usefixtures("add_2_multiple_type_memory") class TestAddToMultipleMemory: - @pytest.mark.p1 def test_add_to_multiple_memory(self, WebApiAuth): memory_ids = self.memory_ids @@ -145,7 +143,7 @@ Uses: Pineapples are eaten fresh, cooked, grilled, juiced, or canned. They are a Nutrition: They are a good source of Vitamin C, manganese, and contain an enzyme called bromelain, which aids in digestion and can tenderize meat. Symbolism: The pineapple is a traditional symbol of hospitality and welcome in many cultures. Are you asking about the fruit itself, or its use in a specific context? -""" +""", } add_res = add_message(WebApiAuth, message_payload) assert add_res["code"] == 0, add_res diff --git a/test/testcases/test_web_api/test_message_app/test_forget_message.py b/test/testcases/test_web_api/test_message_app/test_forget_message.py index 9428fcb23f..afc44eeb7b 100644 --- a/test/testcases/test_web_api/test_message_app/test_forget_message.py +++ b/test/testcases/test_web_api/test_message_app/test_forget_message.py @@ -38,7 +38,6 @@ class TestAuthorization: @pytest.mark.usefixtures("add_memory_with_5_raw_message_func") class TestForgetMessage: - @pytest.mark.p1 def test_forget_message(self, WebApiAuth): memory_id = self.memory_id diff --git a/test/testcases/test_web_api/test_message_app/test_get_message_content.py b/test/testcases/test_web_api/test_message_app/test_get_message_content.py index ac37ac3ada..098ddf80a8 100644 --- a/test/testcases/test_web_api/test_message_app/test_get_message_content.py +++ b/test/testcases/test_web_api/test_message_app/test_get_message_content.py @@ -20,6 +20,7 @@ from test_common import get_message_content, get_recent_message from configs import INVALID_API_TOKEN from libs.auth import RAGFlowWebApiAuth + class TestAuthorization: @pytest.mark.p2 @pytest.mark.parametrize( @@ -37,7 +38,6 @@ class TestAuthorization: @pytest.mark.usefixtures("add_memory_with_multiple_type_message_func") class TestGetMessageContent: - @pytest.mark.p1 def test_get_message_content(self, WebApiAuth): memory_id = self.memory_id diff --git a/test/testcases/test_web_api/test_message_app/test_get_recent_message.py b/test/testcases/test_web_api/test_message_app/test_get_recent_message.py index 355f328d27..3c912c636a 100644 --- a/test/testcases/test_web_api/test_message_app/test_get_recent_message.py +++ b/test/testcases/test_web_api/test_message_app/test_get_recent_message.py @@ -38,7 +38,6 @@ class TestAuthorization: @pytest.mark.usefixtures("add_memory_with_5_raw_message_func") class TestGetRecentMessage: - @pytest.mark.p1 def test_get_recent_messages(self, WebApiAuth): memory_id = self.memory_id diff --git a/test/testcases/test_web_api/test_message_app/test_list_message.py b/test/testcases/test_web_api/test_message_app/test_list_message.py index a55f8b2924..e1c5a2aaeb 100644 --- a/test/testcases/test_web_api/test_message_app/test_list_message.py +++ b/test/testcases/test_web_api/test_message_app/test_list_message.py @@ -39,7 +39,6 @@ class TestAuthorization: @pytest.mark.usefixtures("add_memory_with_5_raw_message_func") class TestMessageList: - @pytest.mark.p2 def test_params_unset(self, WebApiAuth): memory_id = self.memory_id @@ -64,8 +63,7 @@ class TestMessageList: ({"page": 3, "page_size": 2}, 1), ({"page": 5, "page_size": 10}, 0), ], - ids=["normal_first_page", "beyond_max_page", "normal_last_partial_page", "normal_middle_page", - "full_data_single_page"], + ids=["normal_first_page", "beyond_max_page", "normal_last_partial_page", "normal_middle_page", "full_data_single_page"], ) def test_page_size(self, WebApiAuth, params, expected_page_size): # have added 5 messages in fixture diff --git a/test/testcases/test_web_api/test_message_app/test_search_message.py b/test/testcases/test_web_api/test_message_app/test_search_message.py index 0b05df9b53..beeac7fc62 100644 --- a/test/testcases/test_web_api/test_message_app/test_search_message.py +++ b/test/testcases/test_web_api/test_message_app/test_search_message.py @@ -18,6 +18,7 @@ from test_common import search_message, list_memory_message from configs import INVALID_API_TOKEN from libs.auth import RAGFlowWebApiAuth + class TestAuthorization: @pytest.mark.p2 @pytest.mark.parametrize( @@ -35,7 +36,6 @@ class TestAuthorization: @pytest.mark.usefixtures("add_memory_with_multiple_type_message_func") class TestSearchMessage: - @pytest.mark.p1 def test_query(self, WebApiAuth): memory_id = self.memory_id @@ -71,11 +71,7 @@ class TestSearchMessage: assert list_res["data"]["messages"]["total_count"] > 0 query = "Coriander is a versatile herb with two main edible parts. What's its name can refer to?" - params = { - "similarity_threshold": 0.1, - "keywords_similarity_weight": 0.6, - "top_n": 4 - } + params = {"similarity_threshold": 0.1, "keywords_similarity_weight": 0.6, "top_n": 4} res = search_message(WebApiAuth, {"memory_id": memory_id, "query": query, **params}) assert res["code"] == 0, res assert len(res["data"]) > 0 diff --git a/test/testcases/test_web_api/test_message_app/test_update_message_status.py b/test/testcases/test_web_api/test_message_app/test_update_message_status.py index 107c126d55..fa10ddadc9 100644 --- a/test/testcases/test_web_api/test_message_app/test_update_message_status.py +++ b/test/testcases/test_web_api/test_message_app/test_update_message_status.py @@ -40,7 +40,6 @@ class TestAuthorization: @pytest.mark.usefixtures("add_memory_with_5_raw_message_func") class TestUpdateMessageStatus: - @pytest.mark.p1 def test_update_to_false(self, WebApiAuth): memory_id = self.memory_id diff --git a/test/testcases/test_web_api/test_plugin_app/test_llm_tools.py b/test/testcases/test_web_api/test_plugin_app/test_llm_tools.py index 75a18b20bd..81db453519 100644 --- a/test/testcases/test_web_api/test_plugin_app/test_llm_tools.py +++ b/test/testcases/test_web_api/test_plugin_app/test_llm_tools.py @@ -51,6 +51,7 @@ class _DummyManager: def route(self, *_args, **_kwargs): def decorator(func): return func + return decorator diff --git a/test/testcases/test_web_api/test_system_app/test_apps_init_unit.py b/test/testcases/test_web_api/test_system_app/test_apps_init_unit.py index 0ba7ee463a..eb367c2f94 100644 --- a/test/testcases/test_web_api/test_system_app/test_apps_init_unit.py +++ b/test/testcases/test_web_api/test_system_app/test_apps_init_unit.py @@ -183,10 +183,7 @@ def test_load_user_api_token_fallback_and_fallback_exception(monkeypatch, caplog return [] def _query_user(**kwargs): - if ( - kwargs.get("id") == "tenant-1" - and kwargs.get("status") == apps_module.StatusEnum.VALID.value - ): + if kwargs.get("id") == "tenant-1" and kwargs.get("status") == apps_module.StatusEnum.VALID.value: return [beta_user] return [] diff --git a/test/testcases/test_web_api/test_system_app/test_system_routes_unit.py b/test/testcases/test_web_api/test_system_app/test_system_routes_unit.py index 6a2559b151..a9735fa300 100644 --- a/test/testcases/test_web_api/test_system_app/test_system_routes_unit.py +++ b/test/testcases/test_web_api/test_system_app/test_system_routes_unit.py @@ -122,9 +122,7 @@ def _load_system_module(monkeypatch): monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod) user_service_mod = ModuleType("api.db.services.user_service") - user_service_mod.UserTenantService = SimpleNamespace( - query=lambda **_kwargs: [SimpleNamespace(role="owner", tenant_id="tenant-1")] - ) + user_service_mod.UserTenantService = SimpleNamespace(query=lambda **_kwargs: [SimpleNamespace(role="owner", tenant_id="tenant-1")]) monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod) db_models_mod = ModuleType("api.db.db_models") @@ -214,6 +212,7 @@ def test_status_branch_matrix_unit(monkeypatch): assert "Lost connection!" in res["data"]["redis"]["error"] assert res["data"]["task_executor_heartbeats"] == {} + @pytest.mark.p2 def test_get_config_returns_register_enabled_unit(monkeypatch): module = _load_system_module(monkeypatch) diff --git a/test/testcases/test_web_api/test_user_app/test_tenant_app_unit.py b/test/testcases/test_web_api/test_user_app/test_tenant_app_unit.py index cafe5576e3..d9a99e0e3c 100644 --- a/test/testcases/test_web_api/test_user_app/test_tenant_app_unit.py +++ b/test/testcases/test_web_api/test_user_app/test_tenant_app_unit.py @@ -149,7 +149,7 @@ def _load_tenant_module(monkeypatch): api_utils_mod.get_json_result = lambda data=None, message="", code=0: {"code": code, "message": message, "data": data} api_utils_mod.get_data_error_result = lambda message="": {"code": 102, "message": message, "data": False} api_utils_mod.server_error_response = lambda exc: {"code": 100, "message": repr(exc), "data": False} - api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn) + api_utils_mod.validate_request = lambda *_args, **_kwargs: lambda fn: fn api_utils_mod.get_request_json = lambda: _AwaitableValue({}) monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) diff --git a/test/testcases/test_web_api/test_user_app/test_user_app_unit.py b/test/testcases/test_web_api/test_user_app/test_user_app_unit.py index 08d4ab1f76..5759be169f 100644 --- a/test/testcases/test_web_api/test_user_app/test_user_app_unit.py +++ b/test/testcases/test_web_api/test_user_app/test_user_app_unit.py @@ -166,9 +166,7 @@ def _load_user_app(monkeypatch): api_pkg.apps = apps_mod apps_auth_mod = ModuleType("api.apps.auth") - apps_auth_mod.get_auth_client = lambda _config: SimpleNamespace( - get_authorization_url=lambda state: f"https://oauth.example/{state}" - ) + apps_auth_mod.get_auth_client = lambda _config: SimpleNamespace(get_authorization_url=lambda state: f"https://oauth.example/{state}") monkeypatch.setitem(sys.modules, "api.apps.auth", apps_auth_mod) db_mod = ModuleType("api.db") @@ -232,16 +230,7 @@ def _load_user_app(monkeypatch): @staticmethod def get_api_key(tenant_id, model_name, model_type=None): return _MockTableObject( - id=1, - tenant_id=tenant_id, - llm_factory="", - model_type="chat", - llm_name=model_name, - api_key="fake-api-key", - api_base="https://api.example.com", - max_tokens=8192, - used_tokens=0, - status=1 + id=1, tenant_id=tenant_id, llm_factory="", model_type="chat", llm_name=model_name, api_key="fake-api-key", api_base="https://api.example.com", max_tokens=8192, used_tokens=0, status=1 ) tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService diff --git a/test/testcases/utils/file_utils.py b/test/testcases/utils/file_utils.py index 6ccfb02dc2..788b4611f5 100644 --- a/test/testcases/utils/file_utils.py +++ b/test/testcases/utils/file_utils.py @@ -84,24 +84,14 @@ def create_json_file(path): def create_eml_file(path): - eml_content = ( - "From: sender@example.com\n" - "To: receiver@example.com\n" - "Subject: Test EML File\n\n" - "This is a test email content.\n" - ) + eml_content = "From: sender@example.com\nTo: receiver@example.com\nSubject: Test EML File\n\nThis is a test email content.\n" with open(path, "w", encoding="utf-8") as f: f.write(eml_content) return path def create_html_file(path): - html_content = ( - "\n" - "Test HTML File\n" - "

This is a test HTML file

\n" - "" - ) + html_content = "\nTest HTML File\n

This is a test HTML file

\n" with open(path, "w", encoding="utf-8") as f: f.write(html_content) return path diff --git a/test/unit_test/agent/component/test_akshare.py b/test/unit_test/agent/component/test_akshare.py index 35d6fa2728..d33a50683e 100644 --- a/test/unit_test/agent/component/test_akshare.py +++ b/test/unit_test/agent/component/test_akshare.py @@ -29,7 +29,7 @@ def _make_tool(top_n=10): tool.check_if_canceled = lambda *a, **k: False out = {} tool.set_output = lambda k, v: out.__setitem__(k, v) - tool.output = lambda k=None: (out.get(k) if k else out) + tool.output = lambda k=None: out.get(k) if k else out return tool, out diff --git a/test/unit_test/agent/component/test_browser_use_component.py b/test/unit_test/agent/component/test_browser_use_component.py index 9b3ff8c39c..b2e6144f9a 100644 --- a/test/unit_test/agent/component/test_browser_use_component.py +++ b/test/unit_test/agent/component/test_browser_use_component.py @@ -25,6 +25,7 @@ from api.db import FileType def _install_cv2_stub_if_unavailable(): try: import cv2 # noqa: F401 + return except Exception: pass diff --git a/test/unit_test/agent/component/test_googlescholar.py b/test/unit_test/agent/component/test_googlescholar.py index 0fa4b18e41..27a5f6543f 100644 --- a/test/unit_test/agent/component/test_googlescholar.py +++ b/test/unit_test/agent/component/test_googlescholar.py @@ -57,7 +57,7 @@ def _make_tool(top_n): gs._retrieve_chunks = fake_retrieve gs.set_output = lambda k, v: out.__setitem__(k, v) - gs.output = lambda k=None: (out.get(k) if k else out) + gs.output = lambda k=None: out.get(k) if k else out return gs, captured, out diff --git a/test/unit_test/agent/sandbox/test_local_provider.py b/test/unit_test/agent/sandbox/test_local_provider.py index 25fbe7e03c..7730f628ed 100644 --- a/test/unit_test/agent/sandbox/test_local_provider.py +++ b/test/unit_test/agent/sandbox/test_local_provider.py @@ -56,12 +56,7 @@ def test_local_provider_collects_artifacts(tmp_path): try: result = provider.execute_code( instance.instance_id, - ( - "from pathlib import Path\n" - "def main() -> dict:\n" - " Path('artifacts/chart.png').write_bytes(b'PNGDATA')\n" - " return {'ok': True}\n" - ), + ("from pathlib import Path\ndef main() -> dict:\n Path('artifacts/chart.png').write_bytes(b'PNGDATA')\n return {'ok': True}\n"), "python", timeout=5, ) diff --git a/test/unit_test/agent/sandbox/test_ssh_provider.py b/test/unit_test/agent/sandbox/test_ssh_provider.py index 74787313dd..a07e9b8ed1 100644 --- a/test/unit_test/agent/sandbox/test_ssh_provider.py +++ b/test/unit_test/agent/sandbox/test_ssh_provider.py @@ -56,7 +56,7 @@ class _FakeSFTP: for file_path, payload in self.files.items(): if not file_path.startswith(prefix): continue - relative = file_path[len(prefix):] + relative = file_path[len(prefix) :] if "/" in relative: continue names.append( @@ -115,9 +115,7 @@ def test_ssh_provider_executes_python_main_and_collects_artifacts(monkeypatch): return "", "", 0 if command.startswith("cd /tmp/ws-123 && python3 /tmp/ws-123/main.py"): fake_sftp.files["/tmp/ws-123/artifacts/chart.png"] = b"PNGDATA" - payload = base64.b64encode( - b'{"present":true,"value":{"message":"hello ssh"},"type":"json"}' - ).decode("ascii") + payload = base64.b64encode(b'{"present":true,"value":{"message":"hello ssh"},"type":"json"}').decode("ascii") return f"debug line\n{RESULT_MARKER_PREFIX}{payload}\n", "", 0 if command.startswith("rm -rf "): return "", "", 0 diff --git a/test/unit_test/agent/test_dsl_bridge_roundtrip.py b/test/unit_test/agent/test_dsl_bridge_roundtrip.py index b4c2381b8b..3b4ae0524c 100644 --- a/test/unit_test/agent/test_dsl_bridge_roundtrip.py +++ b/test/unit_test/agent/test_dsl_bridge_roundtrip.py @@ -127,9 +127,7 @@ def _v1_components_to_graph(components: dict[str, Any]) -> tuple[list[dict], lis comp = raw if isinstance(raw, dict) else {} obj = comp.get("obj") if isinstance(comp.get("obj"), dict) else {} name = obj.get("component_name") or comp.get("name") or key - params = obj.get("params") if isinstance(obj.get("params"), dict) else ( - comp.get("params") if isinstance(comp.get("params"), dict) else {} - ) + params = obj.get("params") if isinstance(obj.get("params"), dict) else (comp.get("params") if isinstance(comp.get("params"), dict) else {}) nodes.append( { "id": key, @@ -191,9 +189,7 @@ def _graph_to_v1_components(graph: dict[str, Any]) -> dict[str, Any]: return components -def _build_dsl_components_by_graph( - nodes: list[dict], edges: list[dict], seed: dict[str, Any] -) -> dict[str, Any]: +def _build_dsl_components_by_graph(nodes: list[dict], edges: list[dict], seed: dict[str, Any]) -> dict[str, Any]: """Port of `buildDslComponentsByGraph` (web/src/pages/agent/utils.ts:472). Reverse-derives a v1-style `components` map from React-Flow @@ -247,9 +243,7 @@ def _build_v1_dsl_from_import(raw: dict[str, Any], is_agent: bool) -> dict[str, out["components"] = raw["components"] layout = raw.get("_layout") out["_layout"] = layout - elif isinstance(raw.get("graph"), dict) and ( - isinstance(raw["graph"].get("nodes"), list) and raw["graph"]["nodes"] - ): + elif isinstance(raw.get("graph"), dict) and (isinstance(raw["graph"].get("nodes"), list) and raw["graph"]["nodes"]): graph = raw["graph"] out["components"] = _graph_to_v1_components(graph) # Mirror the TS bridge: stash positions in _layout so the @@ -269,9 +263,7 @@ def _build_v2_dsl_from_import(raw: dict[str, Any], is_agent: bool) -> dict[str, shapes and produces a v2 envelope. """ out: dict[str, Any] = dict(raw) - if isinstance(raw.get("graph"), dict) and ( - isinstance(raw["graph"].get("nodes"), list) and raw["graph"]["nodes"] - ): + if isinstance(raw.get("graph"), dict) and (isinstance(raw["graph"].get("nodes"), list) and raw["graph"]["nodes"]): out["graph"] = raw["graph"] # v2 input that carries its own `components` (e.g. a v2 file # exported from the front-end with `bridge.exportDsl`) keeps @@ -293,11 +285,7 @@ def _build_v2_dsl_from_import(raw: dict[str, Any], is_agent: bool) -> dict[str, # v1 → v2 cross-mode: prefer saved _layout positions over # the default 50/350/200 row layout the inverse-conversion # would produce. The user's drag-and-drop work survives. - if ( - isinstance(layout, dict) - and isinstance(layout.get("nodes"), list) - and layout["nodes"] - ): + if isinstance(layout, dict) and isinstance(layout.get("nodes"), list) and layout["nodes"]: out["graph"] = { "nodes": layout["nodes"], "edges": layout.get("edges") or [], @@ -326,9 +314,7 @@ def dsl_to_graph(dsl: dict[str, Any]) -> tuple[list[dict], list[dict]]: return [], [] -def graph_to_dsl( - mode: str, nodes: list[dict], edges: list[dict], old_dsl: dict[str, Any] -) -> dict[str, Any]: +def graph_to_dsl(mode: str, nodes: list[dict], edges: list[dict], old_dsl: dict[str, Any]) -> dict[str, Any]: """Port of `graphToDsl` (both v1 and v2 branches).""" out = dict(old_dsl) if mode == "v1": @@ -434,10 +420,7 @@ class Diff: """pytest entry point: warn on warnings, fail on failures.""" for w in self.warnings: warnings.warn(f"[React-Flow-internal] {w}", stacklevel=2) - assert self.failures == [], ( - f"{len(self.failures)} round-trip mismatches:\n" - + "\n".join(f" - {f}" for f in self.failures) - ) + assert self.failures == [], f"{len(self.failures)} round-trip mismatches:\n" + "\n".join(f" - {f}" for f in self.failures) def _stable(v: Any) -> str: @@ -475,9 +458,7 @@ def _compare_into(expected: Any, actual: Any, path: str, out: Diff) -> None: if exp_arr: if len(expected) != len(actual): - out.failures.append( - f"{path or ''}: length {len(expected)} != {len(actual)}" - ) + out.failures.append(f"{path or ''}: length {len(expected)} != {len(actual)}") for i in range(min(len(expected), len(actual))): _compare_into(expected[i], actual[i], f"{path}[{i}]", out) return @@ -592,6 +573,4 @@ class TestDslBridgeRoundTrip: # pointer. The diff walks into `position: {x, y}` and # reports each leaf individually, so only `x` shows up # (`y` matches the expected 100). - assert diff.failures == [ - "position.x: value (100 vs 999)" - ] + assert diff.failures == ["position.x: value (100 vs 999)"] diff --git a/test/unit_test/agent/tools/test_exesql_ssrf.py b/test/unit_test/agent/tools/test_exesql_ssrf.py index ec8961d8cd..6321ad2465 100644 --- a/test/unit_test/agent/tools/test_exesql_ssrf.py +++ b/test/unit_test/agent/tools/test_exesql_ssrf.py @@ -25,6 +25,7 @@ must dial the validated/resolved public IP for allowed hosts — mirroring the auto-discover every tool and pull in the full agent framework), with the heavy DB drivers and the agent base classes stubbed so only the real SSRF guard runs. """ + import importlib.util import sys import types @@ -80,12 +81,10 @@ def _load_exesql_module(): # Neutralize the @timeout decorator so _invoke is a plain method. conn_utils = types.ModuleType("common.connection_utils") - conn_utils.timeout = lambda *a, **k: (lambda f: f) + conn_utils.timeout = lambda *a, **k: lambda f: f sys.modules["common.connection_utils"] = conn_utils - spec = importlib.util.spec_from_file_location( - "exesql_uut", _REPO_ROOT / "agent" / "tools" / "exesql.py" - ) + spec = importlib.util.spec_from_file_location("exesql_uut", _REPO_ROOT / "agent" / "tools" / "exesql.py") mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) return mod @@ -99,8 +98,12 @@ def _build_exesql(host, db_type="mysql"): cpn = ExeSQL.__new__(ExeSQL) cpn._canvas = SimpleNamespace() cpn._param = SimpleNamespace( - host=host, port=3306, db_type=db_type, - database="db", username="u", password="p", + host=host, + port=3306, + db_type=db_type, + database="db", + username="u", + password="p", ) # Neutralize the component machinery that runs before the host check. cpn.check_if_canceled = lambda *_a, **_k: False diff --git a/test/unit_test/agent/tools/test_http_timeout.py b/test/unit_test/agent/tools/test_http_timeout.py index 49e9d9be97..bc64c08658 100644 --- a/test/unit_test/agent/tools/test_http_timeout.py +++ b/test/unit_test/agent/tools/test_http_timeout.py @@ -97,9 +97,5 @@ def _tool_files(): @pytest.mark.parametrize("path", _tool_files(), ids=lambda p: p.name) def test_http_calls_have_timeout(path: Path): tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path)) - missing = [ - f"{path.name}:{call.lineno}" - for call in _iter_request_calls(tree) - if not _has_timeout(call) - ] + missing = [f"{path.name}:{call.lineno}" for call in _iter_request_calls(tree) if not _has_timeout(call)] assert not missing, "HTTP request(s) without timeout=: " + ", ".join(missing) diff --git a/test/unit_test/api/apps/restful_apis/test_agentbots_access_control.py b/test/unit_test/api/apps/restful_apis/test_agentbots_access_control.py index beff9cba9a..938f3c5ed7 100644 --- a/test/unit_test/api/apps/restful_apis/test_agentbots_access_control.py +++ b/test/unit_test/api/apps/restful_apis/test_agentbots_access_control.py @@ -67,10 +67,11 @@ def _load_bot_api(monkeypatch, *, accessible, calls): async def _gen(): yield 'data: {"event":"message","data":{"content":"ok"}}\n\n' yield 'data: {"event":"message_end","data":{"content":"ok"}}\n\n' + return _gen() _stub(monkeypatch, "quart", Response=lambda *a, **k: SimpleNamespace(headers=SimpleNamespace(add_header=lambda *aa, **kk: None)), request=SimpleNamespace()) - _stub(monkeypatch, "api.apps", AUTH_BETA="beta", login_required=lambda *_a, **_k: (lambda func: func)) + _stub(monkeypatch, "api.apps", AUTH_BETA="beta", login_required=lambda *_a, **_k: lambda func: func) _stub(monkeypatch, "agent.canvas", Canvas=lambda *a, **k: SimpleNamespace(get_component_input_form=lambda _n: {}, get_prologue=lambda: "", get_mode=lambda: "agent")) _stub(monkeypatch, "api.db.db_models", APIToken=SimpleNamespace(query=lambda **_k: [SimpleNamespace(tenant_id="attacker-tenant")])) _stub(monkeypatch, "api.db.services.api_service", API4ConversationService=SimpleNamespace()) diff --git a/test/unit_test/api/apps/restful_apis/test_attachment_download_missing_blob.py b/test/unit_test/api/apps/restful_apis/test_attachment_download_missing_blob.py index b581918642..2bcdd8840c 100644 --- a/test/unit_test/api/apps/restful_apis/test_attachment_download_missing_blob.py +++ b/test/unit_test/api/apps/restful_apis/test_attachment_download_missing_blob.py @@ -64,7 +64,8 @@ def _load_agent_api(monkeypatch, *, storage_get): return SimpleNamespace(payload=payload, headers={}) _stub( - monkeypatch, "api.apps", + monkeypatch, + "api.apps", current_user=SimpleNamespace(id="tenant-1"), login_required=lambda func: func, ) @@ -72,7 +73,8 @@ def _load_agent_api(monkeypatch, *, storage_get): _stub(monkeypatch, "api.db", CanvasCategory=SimpleNamespace()) _stub(monkeypatch, "api.db.db_models", Task=SimpleNamespace()) _stub( - monkeypatch, "api.db.services.api_service", + monkeypatch, + "api.db.services.api_service", API4ConversationService=SimpleNamespace( get_by_id=lambda _id: (False, None), save=lambda **_k: True, @@ -81,7 +83,8 @@ def _load_agent_api(monkeypatch, *, storage_get): ), ) _stub( - monkeypatch, "api.db.services.canvas_service", + monkeypatch, + "api.db.services.canvas_service", CanvasTemplateService=SimpleNamespace(), UserCanvasService=SimpleNamespace(accessible=lambda *_a, **_k: True, query=lambda **_k: []), completion=lambda *_a, **_k: None, @@ -92,17 +95,23 @@ def _load_agent_api(monkeypatch, *, storage_get): _stub(monkeypatch, "api.db.services.knowledgebase_service", KnowledgebaseService=SimpleNamespace()) _stub(monkeypatch, "api.db.services.pipeline_operation_log_service", PipelineOperationLogService=SimpleNamespace()) _stub( - monkeypatch, "api.db.services.task_service", - CANVAS_DEBUG_DOC_ID="", TaskService=SimpleNamespace(), queue_dataflow=lambda *_a, **_k: None, + monkeypatch, + "api.db.services.task_service", + CANVAS_DEBUG_DOC_ID="", + TaskService=SimpleNamespace(), + queue_dataflow=lambda *_a, **_k: None, ) _stub( - monkeypatch, "api.db.services.user_service", - TenantService=SimpleNamespace(), UserService=SimpleNamespace(get_by_id=lambda *_a, **_k: (False, None)), + monkeypatch, + "api.db.services.user_service", + TenantService=SimpleNamespace(), + UserService=SimpleNamespace(get_by_id=lambda *_a, **_k: (False, None)), ) _stub(monkeypatch, "api.db.services.user_canvas_version", UserCanvasVersionService=SimpleNamespace()) _stub( - monkeypatch, "api.utils.api_utils", + monkeypatch, + "api.utils.api_utils", construct_json_result=lambda **kw: {"kind": "json", **kw}, get_data_error_result=lambda message="", code=0, data=False: {"kind": "data_error", "message": message}, get_error_data_result=lambda *_a, **_k: {"kind": "error"}, @@ -114,11 +123,13 @@ def _load_agent_api(monkeypatch, *, storage_get): # Used as `@validate_request(...)` decorator factory at module level, so it # must return an identity decorator (the lenient fallback would return None # and `@None` raises TypeError during import). - validate_request=lambda *_a, **_k: (lambda func: func), + validate_request=lambda *_a, **_k: lambda func: func, ) _stub( - monkeypatch, "common.settings", - retriever=SimpleNamespace(), kg_retriever=SimpleNamespace(), + monkeypatch, + "common.settings", + retriever=SimpleNamespace(), + kg_retriever=SimpleNamespace(), # download_attachment reads settings.STORAGE_IMPL.get after # `from common import settings` rebinds the module's `settings` name. STORAGE_IMPL=SimpleNamespace(get=storage_get), @@ -135,7 +146,8 @@ def _load_agent_api(monkeypatch, *, storage_get): _stub(monkeypatch, "common.misc_utils", get_uuid=lambda: "uuid", thread_pool_exec=_thread_pool_exec) _stub( - monkeypatch, "api.utils.web_utils", + monkeypatch, + "api.utils.web_utils", CONTENT_TYPE_MAP={"markdown": "text/markdown"}, apply_safe_file_response_headers=lambda *_a, **_k: None, ) diff --git a/test/unit_test/api/apps/restful_apis/test_openai_stream_no_duplicate.py b/test/unit_test/api/apps/restful_apis/test_openai_stream_no_duplicate.py index 3fad964762..77748cc8b2 100644 --- a/test/unit_test/api/apps/restful_apis/test_openai_stream_no_duplicate.py +++ b/test/unit_test/api/apps/restful_apis/test_openai_stream_no_duplicate.py @@ -80,7 +80,7 @@ def _load_openai_api(monkeypatch): "api.utils.api_utils", get_error_data_result=lambda *a, **k: {"code": 102}, get_request_json=lambda: {}, - validate_request=lambda *_a, **_k: (lambda func: func), + validate_request=lambda *_a, **_k: lambda func: func, ) _stub(monkeypatch, "common.constants", RetCode=SimpleNamespace(ARGUMENT_ERROR=102), StatusEnum=SimpleNamespace(VALID=SimpleNamespace(value="1"))) _stub(monkeypatch, "common.metadata_utils", convert_conditions=lambda *_a, **_k: None, meta_filter=lambda *_a, **_k: []) @@ -107,6 +107,7 @@ async def _aiter(events): def _collect_sse(module, events, **kwargs): """Run the SSE generator over `events` and return parsed JSON chunks (the trailing `[DONE]` sentinel excluded).""" + async def run(): out = [] async for raw in module._stream_chat_completion_sse(_aiter(events), **kwargs): @@ -205,11 +206,7 @@ def test_reasoning_content_streamed_separately(monkeypatch): ] chunks = _collect_sse(module, events, need_reference=False, **_BASE_KWARGS) - reasoning = "".join( - c["choices"][0]["delta"].get("reasoning_content") - for c in chunks - if c != "[DONE]" and isinstance(c["choices"][0]["delta"].get("reasoning_content"), str) - ) + reasoning = "".join(c["choices"][0]["delta"].get("reasoning_content") for c in chunks if c != "[DONE]" and isinstance(c["choices"][0]["delta"].get("reasoning_content"), str)) content = "".join(p for p in _content_pieces(chunks) if isinstance(p, str)) assert reasoning == "thinking" assert content == "answer" diff --git a/test/unit_test/api/apps/sdk/test_dify_retrieval.py b/test/unit_test/api/apps/sdk/test_dify_retrieval.py index fbf5b5225f..0b286544d4 100644 --- a/test/unit_test/api/apps/sdk/test_dify_retrieval.py +++ b/test/unit_test/api/apps/sdk/test_dify_retrieval.py @@ -1,4 +1,4 @@ -# +# # Copyright 2026 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -84,6 +84,7 @@ class _FakeKGRetriever: def _load_dify_retrieval(monkeypatch, *, kb, accessible, request_body, tenant_id, chunks=None): """Load dify_retrieval_api.py with minimum stubs to exercise the retrieval handler.""" + def _add_tenant_id_to_kwargs(func): async def wrapper(**kwargs): kwargs["tenant_id"] = tenant_id diff --git a/test/unit_test/api/apps/services/test_update_document_name_only.py b/test/unit_test/api/apps/services/test_update_document_name_only.py index a644bc2fdc..2b1ecac7ed 100644 --- a/test/unit_test/api/apps/services/test_update_document_name_only.py +++ b/test/unit_test/api/apps/services/test_update_document_name_only.py @@ -99,16 +99,8 @@ def _load_update_document_name_only_module(monkeypatch, *, file_lookup): fine_grained_tokenize=lambda tokens: tokens, ) - module_path = ( - Path(__file__).resolve().parents[5] - / "api" - / "apps" - / "services" - / "document_api_service.py" - ) - spec = importlib.util.spec_from_file_location( - "test_update_document_name_only_module", module_path - ) + module_path = Path(__file__).resolve().parents[5] / "api" / "apps" / "services" / "document_api_service.py" + spec = importlib.util.spec_from_file_location("test_update_document_name_only_module", module_path) module = importlib.util.module_from_spec(spec) monkeypatch.setitem(sys.modules, "test_update_document_name_only_module", module) spec.loader.exec_module(module) diff --git a/test/unit_test/api/db/joint_services/test_tenant_model_service_max_tokens.py b/test/unit_test/api/db/joint_services/test_tenant_model_service_max_tokens.py index b52f625b71..c17a28b37b 100644 --- a/test/unit_test/api/db/joint_services/test_tenant_model_service_max_tokens.py +++ b/test/unit_test/api/db/joint_services/test_tenant_model_service_max_tokens.py @@ -65,9 +65,7 @@ def test_max_tokens_falls_back_to_factory_when_model_extra_empty(monkeypatch): ], ) - config = tms.get_model_config_from_provider_instance( - "tenant-1", "chat", "gpt-test@default@OpenAI" - ) + config = tms.get_model_config_from_provider_instance("tenant-1", "chat", "gpt-test@default@OpenAI") assert config["max_tokens"] == 128000 @@ -115,8 +113,6 @@ def test_max_tokens_prefers_model_extra_over_factory(monkeypatch): ], ) - config = tms.get_model_config_from_provider_instance( - "tenant-1", "chat", "gpt-test@default@OpenAI" - ) + config = tms.get_model_config_from_provider_instance("tenant-1", "chat", "gpt-test@default@OpenAI") assert config["max_tokens"] == 32000 diff --git a/test/unit_test/api/db/services/test_dataset_access_permissions.py b/test/unit_test/api/db/services/test_dataset_access_permissions.py index e3db6d0f2a..0d3c0e8d60 100644 --- a/test/unit_test/api/db/services/test_dataset_access_permissions.py +++ b/test/unit_test/api/db/services/test_dataset_access_permissions.py @@ -30,6 +30,7 @@ warnings.filterwarnings( def _install_cv2_stub_if_unavailable(): try: import cv2 # noqa: F401 + return except Exception: pass diff --git a/test/unit_test/api/db/services/test_dialog_service_final_answer.py b/test/unit_test/api/db/services/test_dialog_service_final_answer.py index c3fd15ea92..be1062e348 100644 --- a/test/unit_test/api/db/services/test_dialog_service_final_answer.py +++ b/test/unit_test/api/db/services/test_dialog_service_final_answer.py @@ -48,6 +48,7 @@ warnings.filterwarnings( def _install_cv2_stub_if_unavailable(): try: import cv2 # noqa: F401 + return except Exception: pass @@ -197,6 +198,7 @@ class _FakeLangfuseClient: def _collect(async_gen): async def _run(): return [ev async for ev in async_gen] + return asyncio.run(_run()) @@ -204,6 +206,7 @@ def _collect(async_gen): # Tests for async_ask (production code path) # --------------------------------------------------------------------------- + @pytest.mark.p2 def test_async_ask_final_event_carries_decorated_answer(monkeypatch): """ @@ -218,23 +221,21 @@ def test_async_ask_final_event_carries_decorated_answer(monkeypatch): chat_mdl = _StreamingChatModel(llm_answer) retriever = _StubRetriever() + monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB]) monkeypatch.setattr( - dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB] - ) - monkeypatch.setattr( - dialog_service, "get_model_config_from_provider_instance", + dialog_service, + "get_model_config_from_provider_instance", lambda _tid, _type, _name: _LLM_CONFIG, ) monkeypatch.setattr(dialog_service, "LLMBundle", lambda _tid, _cfg: chat_mdl) monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False) monkeypatch.setattr(dialog_service.settings, "kg_retriever", retriever, raising=False) - monkeypatch.setattr( - dialog_service.DocMetadataService, "get_flatted_meta_by_kbs", lambda _ids: {} - ) + monkeypatch.setattr(dialog_service.DocMetadataService, "get_flatted_meta_by_kbs", lambda _ids: {}) monkeypatch.setattr(dialog_service, "label_question", lambda _q, _kbs: "") # kb_prompt calls DocumentService.get_by_ids which needs a live DB; stub it out. monkeypatch.setattr( - dialog_service, "kb_prompt", + dialog_service, + "kb_prompt", lambda _kbinfos, _max_tokens, **_kw: ["RAGFlow is a RAG engine."], ) @@ -249,9 +250,7 @@ def test_async_ask_final_event_carries_decorated_answer(monkeypatch): assert events, "async_ask must yield at least one event" final_events = [e for e in events if e.get("final") is True] - assert len(final_events) == 1, ( - f"Expected exactly one final event, got {len(final_events)}: {final_events}" - ) + assert len(final_events) == 1, f"Expected exactly one final event, got {len(final_events)}: {final_events}" final = final_events[0] assert "answer" in final @@ -267,22 +266,20 @@ def test_async_ask_delta_events_carry_incremental_text_only(monkeypatch): chat_mdl = _StreamingChatModel("Incremental text for delta test.") retriever = _StubRetriever() + monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB]) monkeypatch.setattr( - dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB] - ) - monkeypatch.setattr( - dialog_service, "get_model_config_from_provider_instance", + dialog_service, + "get_model_config_from_provider_instance", lambda _tid, _type, _name: _LLM_CONFIG, ) monkeypatch.setattr(dialog_service, "LLMBundle", lambda _tid, _cfg: chat_mdl) monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False) monkeypatch.setattr(dialog_service.settings, "kg_retriever", retriever, raising=False) - monkeypatch.setattr( - dialog_service.DocMetadataService, "get_flatted_meta_by_kbs", lambda _ids: {} - ) + monkeypatch.setattr(dialog_service.DocMetadataService, "get_flatted_meta_by_kbs", lambda _ids: {}) monkeypatch.setattr(dialog_service, "label_question", lambda _q, _kbs: "") monkeypatch.setattr( - dialog_service, "kb_prompt", + dialog_service, + "kb_prompt", lambda _kbinfos, _max_tokens, **_kw: ["RAGFlow is a RAG engine."], ) @@ -295,15 +292,13 @@ def test_async_ask_delta_events_carry_incremental_text_only(monkeypatch): ) delta_events = [e for e in events if not e.get("final")] - final_events = [e for e in events if e.get("final") is True] + final_events = [e for e in events if e.get("final") is True] assert len(final_events) == 1, f"Expected exactly one final event, got {len(final_events)}" for ev in delta_events: assert ev["reference"] == {}, f"Delta event must have empty reference, got: {ev['reference']}" - assert "chunks" in final_events[0]["reference"], ( - "Final event reference must contain chunk data from decorate_answer()" - ) + assert "chunks" in final_events[0]["reference"], "Final event reference must contain chunk data from decorate_answer()" @pytest.mark.p2 @@ -311,9 +306,7 @@ def test_async_ask_empty_kb_ids_yields_error_final_event(monkeypatch): """ When kb_ids is empty, async_ask() must not crash with IndexError on kbs[0]. """ - monkeypatch.setattr( - dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [] - ) + monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: []) events = _collect( dialog_service.async_ask( @@ -357,6 +350,7 @@ def test_async_ask_stale_kb_ids_yields_error_final_event(monkeypatch): # Tests for async_chat (production code path) # --------------------------------------------------------------------------- + def _make_dialog(chat_mdl_stub): """Build a minimal dialog SimpleNamespace for async_chat().""" return SimpleNamespace( @@ -405,33 +399,30 @@ def test_async_chat_final_event_carries_decorated_answer(monkeypatch): retriever = _StubRetriever() # Stub out the heavy service/model calls + monkeypatch.setattr(dialog_service, "get_model_type_by_name", lambda _tid, _llm_id: ["chat"]) monkeypatch.setattr( - dialog_service, "get_model_type_by_name", - lambda _tid, _llm_id: ["chat"] - ) - monkeypatch.setattr( - dialog_service, "get_model_config_from_provider_instance", + dialog_service, + "get_model_config_from_provider_instance", lambda _tid, _type, _llm_id: _LLM_CONFIG, ) monkeypatch.setattr( - dialog_service.TenantLangfuseService, "filter_by_tenant", + dialog_service.TenantLangfuseService, + "filter_by_tenant", lambda tenant_id: None, ) # get_models returns (kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl) monkeypatch.setattr( - dialog_service, "get_models", + dialog_service, + "get_models", lambda _dialog, **_kwargs: ([_KB], chat_mdl, None, chat_mdl, None), ) - monkeypatch.setattr( - dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {} - ) - monkeypatch.setattr( - dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB] - ) + monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {}) + monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB]) monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False) monkeypatch.setattr(dialog_service, "label_question", lambda _q, _kbs: "") monkeypatch.setattr( - dialog_service, "kb_prompt", + dialog_service, + "kb_prompt", lambda _kbinfos, _max_tokens, **_kw: ["RAGFlow is a RAG engine."], ) @@ -441,9 +432,7 @@ def test_async_chat_final_event_carries_decorated_answer(monkeypatch): events = _collect(dialog_service.async_chat(dialog, messages, stream=True, quote=True)) final_events = [e for e in events if e.get("final") is True] - assert len(final_events) == 1, ( - f"Expected exactly one final event, got {len(final_events)}: {final_events}" - ) + assert len(final_events) == 1, f"Expected exactly one final event, got {len(final_events)}: {final_events}" final = final_events[0] assert "answer" in final @@ -462,16 +451,15 @@ def test_async_chat_langfuse_uses_start_observation(monkeypatch): chat_mdl = _StreamingChatModel(llm_answer) retriever = _StubRetriever() + monkeypatch.setattr(dialog_service, "get_model_type_by_name", lambda _tid, _llm_id: ["chat"]) monkeypatch.setattr( - dialog_service, "get_model_type_by_name", - lambda _tid, _llm_id: ["chat"] - ) - monkeypatch.setattr( - dialog_service, "get_model_config_from_provider_instance", + dialog_service, + "get_model_config_from_provider_instance", lambda _tid, _type, _llm_id: _LLM_CONFIG, ) monkeypatch.setattr( - dialog_service.TenantLangfuseService, "filter_by_tenant", + dialog_service.TenantLangfuseService, + "filter_by_tenant", lambda tenant_id: SimpleNamespace( public_key="public", secret_key="secret", @@ -486,12 +474,8 @@ def test_async_chat_langfuse_uses_start_observation(monkeypatch): "get_models", lambda _dialog, **_kwargs: ([_KB], chat_mdl, None, chat_mdl, None), ) - monkeypatch.setattr( - dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {} - ) - monkeypatch.setattr( - dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB] - ) + monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {}) + monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB]) monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False) monkeypatch.setattr(dialog_service, "label_question", lambda _q, _kbs: "") monkeypatch.setattr( @@ -530,17 +514,15 @@ def test_async_chat_langfuse_observation_includes_session_id(monkeypatch): chat_mdl = _StreamingChatModel("Session traces should be grouped.") retriever = _StubRetriever() - monkeypatch.setattr( - dialog_service, "get_model_type_by_name", - lambda _tid, _llm_id: ["chat"] - ) + monkeypatch.setattr(dialog_service, "get_model_type_by_name", lambda _tid, _llm_id: ["chat"]) monkeypatch.setattr( dialog_service, "get_model_config_from_provider_instance", lambda _tid, _type, _llm_id: _LLM_CONFIG, ) monkeypatch.setattr( - dialog_service.TenantLangfuseService, "filter_by_tenant", + dialog_service.TenantLangfuseService, + "filter_by_tenant", lambda tenant_id: SimpleNamespace( public_key="public", secret_key="secret", @@ -554,12 +536,8 @@ def test_async_chat_langfuse_observation_includes_session_id(monkeypatch): "get_models", lambda _dialog, **_kwargs: ([_KB], chat_mdl, None, chat_mdl, None), ) - monkeypatch.setattr( - dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {} - ) - monkeypatch.setattr( - dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB] - ) + monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {}) + monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB]) monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False) monkeypatch.setattr(dialog_service, "label_question", lambda _q, _kbs: "") monkeypatch.setattr( @@ -635,16 +613,15 @@ def test_async_chat_continues_when_langfuse_observation_start_fails(monkeypatch) chat_mdl = _StreamingChatModel(llm_answer) retriever = _StubRetriever() + monkeypatch.setattr(dialog_service, "get_model_type_by_name", lambda _tid, _llm_id: ["chat"]) monkeypatch.setattr( - dialog_service, "get_model_type_by_name", - lambda _tid, _llm_id: ["chat"] - ) - monkeypatch.setattr( - dialog_service, "get_model_config_from_provider_instance", + dialog_service, + "get_model_config_from_provider_instance", lambda _tid, _type, _llm_id: _LLM_CONFIG, ) monkeypatch.setattr( - dialog_service.TenantLangfuseService, "filter_by_tenant", + dialog_service.TenantLangfuseService, + "filter_by_tenant", lambda tenant_id: SimpleNamespace( public_key="public", secret_key="secret", @@ -659,12 +636,8 @@ def test_async_chat_continues_when_langfuse_observation_start_fails(monkeypatch) "get_models", lambda _dialog, **_kwargs: ([_KB], chat_mdl, None, chat_mdl, None), ) - monkeypatch.setattr( - dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {} - ) - monkeypatch.setattr( - dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB] - ) + monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {}) + monkeypatch.setattr(dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB]) monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False) monkeypatch.setattr(dialog_service, "label_question", lambda _q, _kbs: "") monkeypatch.setattr( diff --git a/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py index cd87ca960c..aee44eff32 100644 --- a/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py +++ b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py @@ -316,10 +316,7 @@ def test_async_chat_uses_all_docs_when_no_doc_ids_selected(monkeypatch): ) monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False) - monkeypatch.setattr( - dialog_service, "get_model_type_by_name", - lambda _tid, _llm_id: ["chat"] - ) + monkeypatch.setattr(dialog_service, "get_model_type_by_name", lambda _tid, _llm_id: ["chat"]) monkeypatch.setattr( dialog_service, "get_model_config_from_provider_instance", diff --git a/test/unit_test/api/db/services/test_document_service_get_parsing_status.py b/test/unit_test/api/db/services/test_document_service_get_parsing_status.py index 997fe6f861..210594234a 100644 --- a/test/unit_test/api/db/services/test_document_service_get_parsing_status.py +++ b/test/unit_test/api/db/services/test_document_service_get_parsing_status.py @@ -31,6 +31,7 @@ warnings.filterwarnings( def _install_cv2_stub_if_unavailable(): try: import cv2 # noqa: F401 + return except Exception: pass @@ -70,6 +71,7 @@ from common.constants import TaskStatus # noqa: E402 # Helpers to access the original function bypassing @DB.connection_context() # --------------------------------------------------------------------------- + def _unwrapped_get_parsing_status(): """Return the original (un-decorated) get_parsing_status_by_kb_ids function. @@ -84,6 +86,7 @@ def _unwrapped_get_parsing_status(): # Fake ORM helpers – mimic the minimal peewee query chain used by the function # --------------------------------------------------------------------------- + class _FieldStub: """Minimal stand-in for a peewee model field used in select/where/group_by.""" @@ -130,6 +133,7 @@ def _make_fake_model(rows): # Pytest fixture – patch DocumentService.model per test # --------------------------------------------------------------------------- + @pytest.fixture() def call_with_rows(monkeypatch): """Return a helper that runs get_parsing_status_by_kb_ids with fake DB rows.""" @@ -146,14 +150,11 @@ def call_with_rows(monkeypatch): # Tests # --------------------------------------------------------------------------- -_ALL_STATUS_FIELDS = frozenset( - ["unstart_count", "running_count", "cancel_count", "done_count", "fail_count"] -) +_ALL_STATUS_FIELDS = frozenset(["unstart_count", "running_count", "cancel_count", "done_count", "fail_count"]) @pytest.mark.p2 class TestGetParsingStatusByKbIds: - # ------------------------------------------------------------------ # Edge-case: empty input list – must short-circuit before any DB call # ------------------------------------------------------------------ @@ -224,16 +225,13 @@ class TestGetParsingStatusByKbIds: def test_unknown_run_value_ignored(self, call_with_rows): rows = [ - {"kb_id": "kb-1", "run": "9", "cnt": 99}, # "9" is not a TaskStatus + {"kb_id": "kb-1", "run": "9", "cnt": 99}, # "9" is not a TaskStatus {"kb_id": "kb-1", "run": TaskStatus.DONE.value, "cnt": 4}, ] result = call_with_rows(rows=rows, kb_ids=["kb-1"]) assert result["kb-1"]["done_count"] == 4 - assert all( - result["kb-1"][f] == 0 - for f in _ALL_STATUS_FIELDS - {"done_count"} - ) + assert all(result["kb-1"][f] == 0 for f in _ALL_STATUS_FIELDS - {"done_count"}) # ------------------------------------------------------------------ # A row whose kb_id was NOT requested must not appear in the output @@ -298,9 +296,7 @@ class TestGetParsingStatusByKbIds: rows = [ {"kb_id": "kb-with-data", "run": TaskStatus.DONE.value, "cnt": 1}, ] - result = call_with_rows( - rows=rows, kb_ids=["kb-with-data", "kb-empty-1", "kb-empty-2"] - ) + result = call_with_rows(rows=rows, kb_ids=["kb-with-data", "kb-empty-1", "kb-empty-2"]) assert set(result.keys()) == {"kb-with-data", "kb-empty-1", "kb-empty-2"} assert result["kb-empty-1"] == {f: 0 for f in _ALL_STATUS_FIELDS} @@ -320,7 +316,4 @@ class TestGetParsingStatusByKbIds: assert result["kb-1"]["done_count"] == 2 # SCHEDULE is not a tracked bucket assert "schedule_count" not in result["kb-1"] - assert all( - result["kb-1"][f] == 0 - for f in _ALL_STATUS_FIELDS - {"done_count"} - ) + assert all(result["kb-1"][f] == 0 for f in _ALL_STATUS_FIELDS - {"done_count"}) diff --git a/test/unit_test/api/db/services/test_file_service_upload_document.py b/test/unit_test/api/db/services/test_file_service_upload_document.py index 8962ae8a78..de8c65d4ea 100644 --- a/test/unit_test/api/db/services/test_file_service_upload_document.py +++ b/test/unit_test/api/db/services/test_file_service_upload_document.py @@ -127,6 +127,7 @@ def test_upload_document_skips_cross_kb_document_id_collision(monkeypatch): # Helpers shared by TestValidateUrlForCrawl # --------------------------------------------------------------------------- + def _addrinfo(ip_str: str) -> list: """Build a minimal getaddrinfo-style result for a single address string.""" family = socket.AF_INET6 if ":" in ip_str else socket.AF_INET @@ -137,6 +138,7 @@ def _addrinfo(ip_str: str) -> list: # _validate_url_for_crawl SSRF-guard tests # --------------------------------------------------------------------------- + @pytest.mark.p2 class TestValidateUrlForCrawl: """Focused regression suite for the SSRF guard on the URL-crawl path. @@ -268,10 +270,7 @@ class TestValidateUrlForCrawl: monkeypatch.setattr( socket, "getaddrinfo", - lambda h, p: ( - _addrinfo("93.184.216.34") - + _addrinfo("2606:2800:220:1:248:1893:25c8:1946") - ), + lambda h, p: _addrinfo("93.184.216.34") + _addrinfo("2606:2800:220:1:248:1893:25c8:1946"), ) hostname, resolved_ip = FileService._validate_url_for_crawl("https://example.com/") assert hostname == "example.com" diff --git a/test/unit_test/api/db/test_oceanbase_peewee.py b/test/unit_test/api/db/test_oceanbase_peewee.py index 6ddee98507..fe4c56d98a 100644 --- a/test/unit_test/api/db/test_oceanbase_peewee.py +++ b/test/unit_test/api/db/test_oceanbase_peewee.py @@ -20,22 +20,23 @@ class TestOceanBaseDatabase: def test_oceanbase_in_pooled_database_enum(self): """Test that OCEANBASE is in PooledDatabase enum.""" - assert hasattr(PooledDatabase, 'OCEANBASE') + assert hasattr(PooledDatabase, "OCEANBASE") assert PooledDatabase.OCEANBASE.value == RetryingPooledOceanBaseDatabase def test_oceanbase_in_database_lock_enum(self): """Test that OCEANBASE is in DatabaseLock enum.""" - assert hasattr(DatabaseLock, 'OCEANBASE') + assert hasattr(DatabaseLock, "OCEANBASE") def test_oceanbase_in_text_field_type_enum(self): """Test that OCEANBASE is in TextFieldType enum.""" - assert hasattr(TextFieldType, 'OCEANBASE') + assert hasattr(TextFieldType, "OCEANBASE") # OceanBase should use LONGTEXT like MySQL assert TextFieldType.OCEANBASE.value == "LONGTEXT" def test_oceanbase_database_inherits_mysql(self): """Test that OceanBase database inherits from PooledMySQLDatabase.""" from playhouse.pool import PooledMySQLDatabase + assert issubclass(RetryingPooledOceanBaseDatabase, PooledMySQLDatabase) def test_oceanbase_database_init(self): @@ -64,13 +65,13 @@ class TestOceanBaseDatabase: def test_pooled_database_enum_values(self): """Test PooledDatabase enum has all expected values.""" - expected = {'MYSQL', 'OCEANBASE', 'POSTGRES'} + expected = {"MYSQL", "OCEANBASE", "POSTGRES"} actual = {e.name for e in PooledDatabase} assert expected.issubset(actual), f"Missing: {expected - actual}" def test_database_lock_enum_values(self): """Test DatabaseLock enum has all expected values.""" - expected = {'MYSQL', 'OCEANBASE', 'POSTGRES'} + expected = {"MYSQL", "OCEANBASE", "POSTGRES"} actual = set(DatabaseLock.__members__.keys()) assert expected.issubset(actual), f"Missing: {expected - actual}" @@ -81,45 +82,49 @@ class TestOceanBaseConfiguration: def test_settings_default_to_mysql(self): """Test that default DB_TYPE is mysql.""" import os + # Save original value - original = os.environ.get('DB_TYPE') - + original = os.environ.get("DB_TYPE") + try: # Remove DB_TYPE to test default - if 'DB_TYPE' in os.environ: - del os.environ['DB_TYPE'] - + if "DB_TYPE" in os.environ: + del os.environ["DB_TYPE"] + # Reload settings from common import settings + settings.DATABASE_TYPE = os.getenv("DB_TYPE", "mysql") - + assert settings.DATABASE_TYPE == "mysql" finally: # Restore original value if original: - os.environ['DB_TYPE'] = original + os.environ["DB_TYPE"] = original def test_settings_can_use_oceanbase(self): """Test that DB_TYPE can be set to oceanbase.""" import os + # Save original value - original = os.environ.get('DB_TYPE') - + original = os.environ.get("DB_TYPE") + try: - os.environ['DB_TYPE'] = 'oceanbase' - + os.environ["DB_TYPE"] = "oceanbase" + # Reload settings from common import settings + settings.DATABASE_TYPE = os.getenv("DB_TYPE", "mysql") - + assert settings.DATABASE_TYPE == "oceanbase" finally: # Restore original value if original: - os.environ['DB_TYPE'] = original + os.environ["DB_TYPE"] = original else: - if 'DB_TYPE' in os.environ: - del os.environ['DB_TYPE'] + if "DB_TYPE" in os.environ: + del os.environ["DB_TYPE"] if __name__ == "__main__": diff --git a/test/unit_test/api/utils/test_doc_validation.py b/test/unit_test/api/utils/test_doc_validation.py index aa3deb6102..d2f5a0a1fb 100644 --- a/test/unit_test/api/utils/test_doc_validation.py +++ b/test/unit_test/api/utils/test_doc_validation.py @@ -53,7 +53,7 @@ def test_validate_immutable_fields_no_changes(): doc.chunk_num = 10 doc.token_num = 100 doc.progress = 0.5 - + error_msg, error_code = validate_immutable_fields(update_doc_req, doc) assert error_msg is None assert error_code is None @@ -66,7 +66,7 @@ def test_validate_immutable_fields_chunk_count_matches(): doc.chunk_num = 10 doc.token_num = 100 doc.progress = 0.5 - + error_msg, error_code = validate_immutable_fields(update_doc_req, doc) assert error_msg is None assert error_code is None @@ -79,7 +79,7 @@ def test_validate_immutable_fields_token_count_matches(): doc.chunk_num = 10 doc.token_num = 100 doc.progress = 0.5 - + error_msg, error_code = validate_immutable_fields(update_doc_req, doc) assert error_msg is None assert error_code is None @@ -92,7 +92,7 @@ def test_validate_immutable_fields_progress_matches(): doc.chunk_num = 10 doc.token_num = 100 doc.progress = 0.5 - + error_msg, error_code = validate_immutable_fields(update_doc_req, doc) assert error_msg is None assert error_code is None @@ -105,7 +105,7 @@ def test_validate_immutable_fields_chunk_count_mismatch(): doc.chunk_num = 10 doc.token_num = 100 doc.progress = 0.5 - + error_msg, error_code = validate_immutable_fields(update_doc_req, doc) assert error_msg == "Can't change `chunk_count`." assert error_code == RetCode.DATA_ERROR @@ -118,7 +118,7 @@ def test_validate_immutable_fields_token_count_mismatch(): doc.chunk_num = 10 doc.token_num = 100 doc.progress = 0.5 - + error_msg, error_code = validate_immutable_fields(update_doc_req, doc) assert error_msg == "Can't change `token_count`." assert error_code == RetCode.DATA_ERROR @@ -131,7 +131,7 @@ def test_validate_immutable_fields_progress_mismatch(): doc.chunk_num = 10 doc.token_num = 100 doc.progress = 0.5 - + error_msg, error_code = validate_immutable_fields(update_doc_req, doc) assert error_msg == "Can't change `progress`." assert error_code == RetCode.DATA_ERROR @@ -145,18 +145,18 @@ def test_validate_immutable_fields_progress_boundary_values(): doc.chunk_num = 10 doc.token_num = 100 doc.progress = 0.0 - + error_msg, error_code = validate_immutable_fields(update_doc_req, doc) assert error_msg is None assert error_code is None - + # Test with 1.0 update_doc_req = UpdateDocumentReq(progress=1.0) doc = Mock() doc.chunk_num = 10 doc.token_num = 100 doc.progress = 1.0 - + error_msg, error_code = validate_immutable_fields(update_doc_req, doc) assert error_msg is None assert error_code is None @@ -169,7 +169,7 @@ def test_validate_immutable_fields_none_values(): doc.chunk_num = 10 doc.token_num = 100 doc.progress = 0.5 - + error_msg, error_code = validate_immutable_fields(update_doc_req, doc) assert error_msg is None assert error_code is None @@ -240,6 +240,7 @@ def test_validate_document_name_valid(): assert error_msg is None assert error_code is None + def test_validate_document_name_attr_error(): """Test valid document name update.""" req_doc_name = 0 @@ -312,7 +313,7 @@ def test_validate_chunk_method_valid(): doc = Mock() doc.type = FileType.PDF doc.name = "document.pdf" - + error_msg, error_code = validate_chunk_method(doc) assert error_msg is None assert error_code is None @@ -323,7 +324,7 @@ def test_validate_chunk_method_visual_not_supported(): doc = Mock() doc.type = FileType.VISUAL doc.name = "image.jpg" - + error_msg, error_code = validate_chunk_method(doc) assert "Not supported yet!" in error_msg assert error_code == RetCode.DATA_ERROR @@ -334,7 +335,7 @@ def test_validate_chunk_method_ppt_not_supported(): doc = Mock() doc.type = FileType.PDF doc.name = "presentation.ppt" - + error_msg, error_code = validate_chunk_method(doc) assert "Not supported yet!" in error_msg assert error_code == RetCode.DATA_ERROR @@ -345,7 +346,7 @@ def test_validate_chunk_method_pptx_not_supported(): doc = Mock() doc.type = FileType.PDF doc.name = "presentation.pptx" - + error_msg, error_code = validate_chunk_method(doc) assert "Not supported yet!" in error_msg assert error_code == RetCode.DATA_ERROR @@ -356,7 +357,7 @@ def test_validate_chunk_method_pages_not_supported(): doc = Mock() doc.type = FileType.PDF doc.name = "document.pages" - + error_msg, error_code = validate_chunk_method(doc) assert "Not supported yet!" in error_msg assert error_code == RetCode.DATA_ERROR @@ -367,7 +368,7 @@ def test_validate_chunk_method_other_extensions_still_valid(): doc = Mock() doc.type = FileType.PDF doc.name = "document.docx" - + error_msg, error_code = validate_chunk_method(doc) assert error_msg is None assert error_code is None diff --git a/test/unit_test/api/utils/test_health_utils_minio.py b/test/unit_test/api/utils/test_health_utils_minio.py index 176ace64dd..7d041e6a39 100644 --- a/test/unit_test/api/utils/test_health_utils_minio.py +++ b/test/unit_test/api/utils/test_health_utils_minio.py @@ -17,6 +17,7 @@ Unit tests for MinIO health check (check_minio_alive) and scheme/verify helpers. Covers SSL/HTTPS and certificate verification (issues #13158, #13159). """ + from unittest.mock import patch, Mock @@ -27,6 +28,7 @@ class TestMinioSchemeAndVerify: def test_scheme_http_when_secure_false(self, mock_settings): mock_settings.MINIO = {"host": "minio:9000", "secure": False} from api.utils.health_utils import _minio_scheme_and_verify + scheme, verify = _minio_scheme_and_verify() assert scheme == "http" assert verify is True @@ -35,6 +37,7 @@ class TestMinioSchemeAndVerify: def test_scheme_https_when_secure_true(self, mock_settings): mock_settings.MINIO = {"host": "minio:9000", "secure": True} from api.utils.health_utils import _minio_scheme_and_verify + scheme, verify = _minio_scheme_and_verify() assert scheme == "https" assert verify is True @@ -43,6 +46,7 @@ class TestMinioSchemeAndVerify: def test_scheme_https_when_secure_string_true(self, mock_settings): mock_settings.MINIO = {"host": "minio:9000", "secure": "true"} from api.utils.health_utils import _minio_scheme_and_verify + scheme, verify = _minio_scheme_and_verify() assert scheme == "https" @@ -50,6 +54,7 @@ class TestMinioSchemeAndVerify: def test_verify_false_for_self_signed(self, mock_settings): mock_settings.MINIO = {"host": "minio:9000", "secure": True, "verify": False} from api.utils.health_utils import _minio_scheme_and_verify + scheme, verify = _minio_scheme_and_verify() assert scheme == "https" assert verify is False @@ -58,6 +63,7 @@ class TestMinioSchemeAndVerify: def test_verify_string_false(self, mock_settings): mock_settings.MINIO = {"host": "minio:9000", "verify": "false"} from api.utils.health_utils import _minio_scheme_and_verify + _, verify = _minio_scheme_and_verify() assert verify is False @@ -65,6 +71,7 @@ class TestMinioSchemeAndVerify: def test_default_verify_true_when_key_missing(self, mock_settings): mock_settings.MINIO = {"host": "minio:9000"} from api.utils.health_utils import _minio_scheme_and_verify + _, verify = _minio_scheme_and_verify() assert verify is True @@ -80,6 +87,7 @@ class TestCheckMinioAlive: mock_response.status_code = 200 mock_get.return_value = mock_response from api.utils.health_utils import check_minio_alive + result = check_minio_alive() assert result["status"] == "alive" assert "elapsed" in result["message"] @@ -96,6 +104,7 @@ class TestCheckMinioAlive: mock_response.status_code = 200 mock_get.return_value = mock_response from api.utils.health_utils import check_minio_alive + check_minio_alive() call_args = mock_get.call_args assert call_args[0][0] == "https://minio:9000/minio/health/live" @@ -108,6 +117,7 @@ class TestCheckMinioAlive: mock_response.status_code = 200 mock_get.return_value = mock_response from api.utils.health_utils import check_minio_alive + check_minio_alive() call_args = mock_get.call_args assert call_args[1]["verify"] is False @@ -120,6 +130,7 @@ class TestCheckMinioAlive: mock_response.status_code = 503 mock_get.return_value = mock_response from api.utils.health_utils import check_minio_alive + result = check_minio_alive() assert result["status"] == "timeout" @@ -129,6 +140,7 @@ class TestCheckMinioAlive: mock_settings.MINIO = {"host": "minio:9000"} mock_get.side_effect = ConnectionError("Connection refused") from api.utils.health_utils import check_minio_alive + result = check_minio_alive() assert result["status"] == "timeout" assert "error" in result["message"] @@ -141,6 +153,7 @@ class TestCheckMinioAlive: mock_response.status_code = 200 mock_get.return_value = mock_response from api.utils.health_utils import check_minio_alive + check_minio_alive() call_args = mock_get.call_args assert call_args[1]["timeout"] == 10 diff --git a/test/unit_test/api/utils/test_oceanbase_health.py b/test/unit_test/api/utils/test_oceanbase_health.py index fa6d24dd18..92d93cc719 100644 --- a/test/unit_test/api/utils/test_oceanbase_health.py +++ b/test/unit_test/api/utils/test_oceanbase_health.py @@ -16,6 +16,7 @@ """ Unit tests for OceanBase health check and performance monitoring functionality. """ + import inspect import os import types @@ -27,20 +28,15 @@ from api.utils.health_utils import get_oceanbase_status, check_oceanbase_health class TestOceanBaseHealthCheck: """Test cases for OceanBase health check functionality.""" - - @patch('api.utils.health_utils.OBConnection') - @patch.dict(os.environ, {'DOC_ENGINE': 'oceanbase'}) + + @patch("api.utils.health_utils.OBConnection") + @patch.dict(os.environ, {"DOC_ENGINE": "oceanbase"}) def test_get_oceanbase_status_success(self, mock_ob_class): """Test successful OceanBase status retrieval.""" # Setup mock mock_ob_connection = Mock() mock_ob_connection.uri = "localhost:2881" - mock_ob_connection.health.return_value = { - "uri": "localhost:2881", - "version_comment": "OceanBase 4.3.5.1", - "status": "healthy", - "connection": "connected" - } + mock_ob_connection.health.return_value = {"uri": "localhost:2881", "version_comment": "OceanBase 4.3.5.1", "status": "healthy", "connection": "connected"} mock_ob_connection.get_performance_metrics.return_value = { "connection": "connected", "latency_ms": 5.2, @@ -49,13 +45,13 @@ class TestOceanBaseHealthCheck: "query_per_second": 150, "slow_queries": 2, "active_connections": 10, - "max_connections": 300 + "max_connections": 300, } mock_ob_class.return_value = mock_ob_connection - + # Execute result = get_oceanbase_status() - + # Assert assert result["status"] == "alive" assert "message" in result @@ -63,36 +59,31 @@ class TestOceanBaseHealthCheck: assert "performance" in result["message"] assert result["message"]["health"]["status"] == "healthy" assert result["message"]["performance"]["latency_ms"] == 5.2 - - @patch.dict(os.environ, {'DOC_ENGINE': 'elasticsearch'}) + + @patch.dict(os.environ, {"DOC_ENGINE": "elasticsearch"}) def test_get_oceanbase_status_not_configured(self): """Test OceanBase status when not configured.""" with pytest.raises(Exception) as exc_info: get_oceanbase_status() assert "OceanBase is not in use" in str(exc_info.value) - - @patch('api.utils.health_utils.OBConnection') - @patch.dict(os.environ, {'DOC_ENGINE': 'oceanbase'}) + + @patch("api.utils.health_utils.OBConnection") + @patch.dict(os.environ, {"DOC_ENGINE": "oceanbase"}) def test_get_oceanbase_status_connection_error(self, mock_ob_class): """Test OceanBase status when connection fails.""" mock_ob_class.side_effect = Exception("Connection failed") - + result = get_oceanbase_status() - + assert result["status"] == "timeout" assert "error" in result["message"] - - @patch('api.utils.health_utils.OBConnection') - @patch.dict(os.environ, {'DOC_ENGINE': 'oceanbase'}) + + @patch("api.utils.health_utils.OBConnection") + @patch.dict(os.environ, {"DOC_ENGINE": "oceanbase"}) def test_check_oceanbase_health_healthy(self, mock_ob_class): """Test OceanBase health check returns healthy status.""" mock_ob_connection = Mock() - mock_ob_connection.health.return_value = { - "uri": "localhost:2881", - "version_comment": "OceanBase 4.3.5.1", - "status": "healthy", - "connection": "connected" - } + mock_ob_connection.health.return_value = {"uri": "localhost:2881", "version_comment": "OceanBase 4.3.5.1", "status": "healthy", "connection": "connected"} mock_ob_connection.get_performance_metrics.return_value = { "connection": "connected", "latency_ms": 5.2, @@ -101,28 +92,23 @@ class TestOceanBaseHealthCheck: "query_per_second": 150, "slow_queries": 0, "active_connections": 10, - "max_connections": 300 + "max_connections": 300, } mock_ob_class.return_value = mock_ob_connection - + result = check_oceanbase_health() - + assert result["status"] == "healthy" assert result["details"]["connection"] == "connected" assert result["details"]["latency_ms"] == 5.2 assert result["details"]["query_per_second"] == 150 - - @patch('api.utils.health_utils.OBConnection') - @patch.dict(os.environ, {'DOC_ENGINE': 'oceanbase'}) + + @patch("api.utils.health_utils.OBConnection") + @patch.dict(os.environ, {"DOC_ENGINE": "oceanbase"}) def test_check_oceanbase_health_degraded(self, mock_ob_class): """Test OceanBase health check returns degraded status for high latency.""" mock_ob_connection = Mock() - mock_ob_connection.health.return_value = { - "uri": "localhost:2881", - "version_comment": "OceanBase 4.3.5.1", - "status": "healthy", - "connection": "connected" - } + mock_ob_connection.health.return_value = {"uri": "localhost:2881", "version_comment": "OceanBase 4.3.5.1", "status": "healthy", "connection": "connected"} mock_ob_connection.get_performance_metrics.return_value = { "connection": "connected", "latency_ms": 1500.0, # High latency > 1000ms @@ -131,43 +117,35 @@ class TestOceanBaseHealthCheck: "query_per_second": 50, "slow_queries": 5, "active_connections": 10, - "max_connections": 300 + "max_connections": 300, } mock_ob_class.return_value = mock_ob_connection - + result = check_oceanbase_health() - + assert result["status"] == "degraded" assert result["details"]["latency_ms"] == 1500.0 - - @patch('api.utils.health_utils.OBConnection') - @patch.dict(os.environ, {'DOC_ENGINE': 'oceanbase'}) + + @patch("api.utils.health_utils.OBConnection") + @patch.dict(os.environ, {"DOC_ENGINE": "oceanbase"}) def test_check_oceanbase_health_unhealthy(self, mock_ob_class): """Test OceanBase health check returns unhealthy status.""" mock_ob_connection = Mock() - mock_ob_connection.health.return_value = { - "uri": "localhost:2881", - "status": "unhealthy", - "connection": "disconnected", - "error": "Connection timeout" - } - mock_ob_connection.get_performance_metrics.return_value = { - "connection": "disconnected", - "error": "Connection timeout" - } + mock_ob_connection.health.return_value = {"uri": "localhost:2881", "status": "unhealthy", "connection": "disconnected", "error": "Connection timeout"} + mock_ob_connection.get_performance_metrics.return_value = {"connection": "disconnected", "error": "Connection timeout"} mock_ob_class.return_value = mock_ob_connection - + result = check_oceanbase_health() - + assert result["status"] == "unhealthy" assert result["details"]["connection"] == "disconnected" assert "error" in result["details"] - - @patch.dict(os.environ, {'DOC_ENGINE': 'elasticsearch'}) + + @patch.dict(os.environ, {"DOC_ENGINE": "elasticsearch"}) def test_check_oceanbase_health_not_configured(self): """Test OceanBase health check when not configured.""" result = check_oceanbase_health() - + assert result["status"] == "not_configured" assert result["details"]["connection"] == "not_configured" assert "not configured" in result["details"]["message"].lower() @@ -175,29 +153,32 @@ class TestOceanBaseHealthCheck: class TestOBConnectionPerformanceMetrics: """Test cases for OBConnection performance metrics methods.""" - + def _create_mock_connection(self): """Create a mock OBConnection with actual methods.""" + # Create a simple object and bind the real methods to it class MockConn: pass + conn = MockConn() # Get the actual class from the singleton wrapper's closure from rag.utils import ob_conn + # OBConnection is wrapped by @singleton decorator, so it's a function # The original class is stored in the closure of the singleton function # Find the class by checking all closure cells ob_connection_class = None - if hasattr(ob_conn.OBConnection, '__closure__') and ob_conn.OBConnection.__closure__: + if hasattr(ob_conn.OBConnection, "__closure__") and ob_conn.OBConnection.__closure__: for cell in ob_conn.OBConnection.__closure__: cell_value = cell.cell_contents if inspect.isclass(cell_value): ob_connection_class = cell_value break - + if ob_connection_class is None: raise ValueError("Could not find OBConnection class in closure") - + # Bind the actual methods to our mock object conn.get_performance_metrics = types.MethodType(ob_connection_class.get_performance_metrics, conn) conn._get_storage_info = types.MethodType(ob_connection_class._get_storage_info, conn) @@ -205,7 +186,7 @@ class TestOBConnectionPerformanceMetrics: conn._get_slow_query_count = types.MethodType(ob_connection_class._get_slow_query_count, conn) conn._estimate_qps = types.MethodType(ob_connection_class._estimate_qps, conn) return conn - + def test_get_performance_metrics_success(self): """Test successful retrieval of performance metrics.""" # Create mock connection with actual methods @@ -214,29 +195,27 @@ class TestOBConnectionPerformanceMetrics: conn.client = mock_client conn.uri = "localhost:2881" conn.db_name = "test" - + # Mock client methods - create separate mock results for each call mock_result1 = Mock() mock_result1.fetchone.return_value = (1,) - + mock_result2 = Mock() mock_result2.fetchone.return_value = (100.5,) - + mock_result3 = Mock() mock_result3.fetchone.return_value = (100.0,) - + mock_result4 = Mock() - mock_result4.fetchall.return_value = [ - (1, 'user', 'host', 'db', 'Query', 0, 'executing', 'SELECT 1') - ] - mock_result4.fetchone.return_value = ('max_connections', '300') - + mock_result4.fetchall.return_value = [(1, "user", "host", "db", "Query", 0, "executing", "SELECT 1")] + mock_result4.fetchone.return_value = ("max_connections", "300") + mock_result5 = Mock() mock_result5.fetchone.return_value = (0,) - + mock_result6 = Mock() mock_result6.fetchone.return_value = (5,) - + # Setup side_effect to return different mocks for different queries def sql_side_effect(query): if "SELECT 1" in query: @@ -254,21 +233,22 @@ class TestOBConnectionPerformanceMetrics: elif "information_schema.processlist" in query and "COUNT" in query: return mock_result6 return Mock() - + mock_client.perform_raw_text_sql.side_effect = sql_side_effect mock_client.pool_size = 300 - + # Mock logger import logging - conn.logger = logging.getLogger('test') - + + conn.logger = logging.getLogger("test") + result = conn.get_performance_metrics() - + assert result["connection"] == "connected" assert result["latency_ms"] >= 0 assert "storage_used" in result assert "storage_total" in result - + def test_get_performance_metrics_connection_error(self): """Test performance metrics when connection fails.""" # Create mock connection with actual methods @@ -277,14 +257,14 @@ class TestOBConnectionPerformanceMetrics: conn.client = mock_client conn.uri = "localhost:2881" conn.logger = Mock() - + mock_client.perform_raw_text_sql.side_effect = Exception("Connection failed") - + result = conn.get_performance_metrics() - + assert result["connection"] == "disconnected" assert "error" in result - + def test_get_storage_info_success(self): """Test successful retrieval of storage information.""" # Create mock connection with actual methods @@ -293,27 +273,27 @@ class TestOBConnectionPerformanceMetrics: conn.client = mock_client conn.db_name = "test" conn.logger = Mock() - + mock_result1 = Mock() mock_result1.fetchone.return_value = (100.5,) mock_result2 = Mock() mock_result2.fetchone.return_value = (100.0,) - + def sql_side_effect(query): if "information_schema.tables" in query: return mock_result1 elif "__all_disk_stat" in query: return mock_result2 return Mock() - + mock_client.perform_raw_text_sql.side_effect = sql_side_effect - + result = conn._get_storage_info() - + assert "storage_used" in result assert "storage_total" in result assert "MB" in result["storage_used"] - + def test_get_storage_info_fallback(self): """Test storage info with fallback when total space unavailable.""" # Create mock connection with actual methods @@ -322,7 +302,7 @@ class TestOBConnectionPerformanceMetrics: conn.client = mock_client conn.db_name = "test" conn.logger = Mock() - + # First query succeeds, second fails def side_effect(query): if "information_schema.tables" in query: @@ -331,14 +311,14 @@ class TestOBConnectionPerformanceMetrics: return mock_result else: raise Exception("Table not found") - + mock_client.perform_raw_text_sql.side_effect = side_effect - + result = conn._get_storage_info() - + assert "storage_used" in result assert "storage_total" in result - + def test_get_connection_pool_stats(self): """Test retrieval of connection pool statistics.""" # Create mock connection with actual methods @@ -347,31 +327,28 @@ class TestOBConnectionPerformanceMetrics: conn.client = mock_client conn.logger = Mock() mock_client.pool_size = 300 - + mock_result1 = Mock() - mock_result1.fetchall.return_value = [ - (1, 'user', 'host', 'db', 'Query', 0, 'executing', 'SELECT 1'), - (2, 'user', 'host', 'db', 'Sleep', 10, None, None) - ] - + mock_result1.fetchall.return_value = [(1, "user", "host", "db", "Query", 0, "executing", "SELECT 1"), (2, "user", "host", "db", "Sleep", 10, None, None)] + mock_result2 = Mock() - mock_result2.fetchone.return_value = ('max_connections', '300') - + mock_result2.fetchone.return_value = ("max_connections", "300") + def sql_side_effect(query): if "SHOW PROCESSLIST" in query: return mock_result1 elif "SHOW VARIABLES LIKE 'max_connections'" in query: return mock_result2 return Mock() - + mock_client.perform_raw_text_sql.side_effect = sql_side_effect - + result = conn._get_connection_pool_stats() - + assert "active_connections" in result assert "max_connections" in result assert result["active_connections"] >= 0 - + def test_get_slow_query_count(self): """Test retrieval of slow query count.""" # Create mock connection with actual methods @@ -379,16 +356,16 @@ class TestOBConnectionPerformanceMetrics: mock_client = Mock() conn.client = mock_client conn.logger = Mock() - + mock_result = Mock() mock_result.fetchone.return_value = (5,) mock_client.perform_raw_text_sql.return_value = mock_result - + result = conn._get_slow_query_count(threshold_seconds=1) - + assert isinstance(result, int) assert result >= 0 - + def test_estimate_qps(self): """Test QPS estimation.""" # Create mock connection with actual methods @@ -396,17 +373,16 @@ class TestOBConnectionPerformanceMetrics: mock_client = Mock() conn.client = mock_client conn.logger = Mock() - + mock_result = Mock() mock_result.fetchone.return_value = (10,) mock_client.perform_raw_text_sql.return_value = mock_result - + result = conn._estimate_qps() - + assert isinstance(result, int) assert result >= 0 if __name__ == "__main__": pytest.main([__file__, "-v"]) - diff --git a/test/unit_test/common/test_apply_semi_auto_meta_data_filter.py b/test/unit_test/common/test_apply_semi_auto_meta_data_filter.py index 165e283aa1..138ebd9f5e 100644 --- a/test/unit_test/common/test_apply_semi_auto_meta_data_filter.py +++ b/test/unit_test/common/test_apply_semi_auto_meta_data_filter.py @@ -2,51 +2,41 @@ import pytest from common.metadata_utils import apply_meta_data_filter from unittest.mock import MagicMock, AsyncMock, patch + @pytest.mark.asyncio async def test_apply_meta_data_filter_semi_auto_key(): - meta_data_filter = { - "method": "semi_auto", - "semi_auto": ["key1", "key2"] - } - metas = { - "key1": {"val1": ["doc1"]}, - "key2": {"val2": ["doc2"]} - } + meta_data_filter = {"method": "semi_auto", "semi_auto": ["key1", "key2"]} + metas = {"key1": {"val1": ["doc1"]}, "key2": {"val2": ["doc2"]}} question = "find val1" - + chat_mdl = MagicMock() - + with patch("rag.prompts.generator.gen_meta_filter", new_callable=AsyncMock) as mock_gen: mock_gen.return_value = {"conditions": [{"key": "key1", "op": "=", "value": "val1"}], "logic": "and"} - + doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl) assert doc_ids == ["doc1"] - + # Check that constraints is an empty dict by default for legacy mock_gen.assert_called_once() args, kwargs = mock_gen.call_args assert kwargs["constraints"] == {} + @pytest.mark.asyncio async def test_apply_meta_data_filter_semi_auto_key_and_operator(): - meta_data_filter = { - "method": "semi_auto", - "semi_auto": [{"key": "key1", "op": ">"}, "key2"] - } - metas = { - "key1": {"10": ["doc1"]}, - "key2": {"val2": ["doc2"]} - } + meta_data_filter = {"method": "semi_auto", "semi_auto": [{"key": "key1", "op": ">"}, "key2"]} + metas = {"key1": {"10": ["doc1"]}, "key2": {"val2": ["doc2"]}} question = "find key1 > 5" - + chat_mdl = MagicMock() - + with patch("rag.prompts.generator.gen_meta_filter", new_callable=AsyncMock) as mock_gen: mock_gen.return_value = {"conditions": [{"key": "key1", "op": ">", "value": "5"}], "logic": "and"} - + doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl) assert doc_ids == ["doc1"] - + # Check that constraints are correctly passed mock_gen.assert_called_once() args, kwargs = mock_gen.call_args diff --git a/test/unit_test/common/test_decorator.py b/test/unit_test/common/test_decorator.py index a4bc4a24e3..e33b2cbe5a 100644 --- a/test/unit_test/common/test_decorator.py +++ b/test/unit_test/common/test_decorator.py @@ -30,7 +30,6 @@ class TestClass: # Test cases class TestSingleton: - def test_state_persistence(self): """Test that instance state persists across multiple calls""" instance1 = TestClass() diff --git a/test/unit_test/common/test_delete_query_construction.py b/test/unit_test/common/test_delete_query_construction.py index 52e24cf80a..6bc16a0f04 100644 --- a/test/unit_test/common/test_delete_query_construction.py +++ b/test/unit_test/common/test_delete_query_construction.py @@ -16,11 +16,11 @@ """ Unit tests for delete query construction in ES/OpenSearch connectors. -These tests verify that the delete method correctly combines chunk IDs with +These tests verify that the delete method correctly combines chunk IDs with other filter conditions (doc_id, kb_id) to scope deletions properly. -This addresses issue #12520: "Files of deleted slices can still be searched -and displayed in 'reference'" - caused by delete queries not properly +This addresses issue #12520: "Files of deleted slices can still be searched +and displayed in 'reference'" - caused by delete queries not properly combining all filter conditions. Run with: python -m pytest test/unit/test_delete_query_construction.py -v @@ -85,20 +85,20 @@ class TestDeleteQueryConstruction: def test_delete_with_chunk_ids_includes_kb_id(self): """ CRITICAL: When deleting by chunk IDs, kb_id MUST be included in the query. - - This was the root cause of issue #12520 - the original code would + + This was the root cause of issue #12520 - the original code would only use Q("ids") and ignore kb_id. """ condition = {"id": ["chunk1", "chunk2"]} query = self.build_delete_query(condition, "kb123") - + query_dict = query["query"]["bool"] - + # Verify chunk IDs filter is present ids_filter = [f for f in query_dict.get("filter", []) if "ids" in f] assert len(ids_filter) == 1, "Should have ids filter" assert set(ids_filter[0]["ids"]["values"]) == {"chunk1", "chunk2"} - + # Verify kb_id is also in the query (CRITICAL FIX) must_terms = query_dict.get("must", []) kb_id_terms = [t for t in must_terms if "term" in t and "kb_id" in t.get("term", {})] @@ -112,19 +112,19 @@ class TestDeleteQueryConstruction: """ condition = {"id": ["chunk1"], "doc_id": "doc456"} query = self.build_delete_query(condition, "kb123") - + query_dict = query["query"]["bool"] - + # Verify all three conditions are present ids_filter = [f for f in query_dict.get("filter", []) if "ids" in f] assert len(ids_filter) == 1, "Should have ids filter" - + must_terms = query_dict.get("must", []) - + # Check kb_id kb_id_terms = [t for t in must_terms if "term" in t and "kb_id" in t.get("term", {})] assert len(kb_id_terms) == 1, "kb_id must be present" - + # Check doc_id doc_id_terms = [t for t in must_terms if "term" in t and "doc_id" in t.get("term", {})] assert len(doc_id_terms) == 1, "doc_id must be present" @@ -136,7 +136,7 @@ class TestDeleteQueryConstruction: """ condition = {"id": "single_chunk"} query = self.build_delete_query(condition, "kb123") - + query_dict = query["query"]["bool"] ids_filter = [f for f in query_dict.get("filter", []) if "ids" in f] assert len(ids_filter) == 1 @@ -149,13 +149,13 @@ class TestDeleteQueryConstruction: """ condition = {"id": [], "doc_id": "doc456"} query = self.build_delete_query(condition, "kb123") - + query_dict = query["query"]["bool"] - + # Empty chunk_ids should NOT add an ids filter ids_filter = [f for f in query_dict.get("filter", []) if "ids" in f] assert len(ids_filter) == 0, "Empty chunk_ids should not create ids filter" - + # But kb_id and doc_id should still be present must_terms = query_dict.get("must", []) assert any("kb_id" in str(t) for t in must_terms), "kb_id must be present" @@ -167,14 +167,14 @@ class TestDeleteQueryConstruction: """ condition = {"doc_id": "doc456"} query = self.build_delete_query(condition, "kb123") - + query_dict = query["query"]["bool"] must_terms = query_dict.get("must", []) - + # Both doc_id and kb_id should be in query doc_terms = [t for t in must_terms if "term" in t and "doc_id" in t.get("term", {})] kb_terms = [t for t in must_terms if "term" in t and "kb_id" in t.get("term", {})] - + assert len(doc_terms) == 1 assert len(kb_terms) == 1 @@ -184,13 +184,13 @@ class TestDeleteQueryConstruction: """ condition = { "kb_id": "kb123", # Will be overwritten - "must_not": {"exists": "source_id"} + "must_not": {"exists": "source_id"}, } query = self.build_delete_query(condition, "kb123") - + query_dict = query["query"]["bool"] must_not = query_dict.get("must_not", []) - + exists_filters = [f for f in must_not if "exists" in f] assert len(exists_filters) == 1 assert exists_filters[0]["exists"]["field"] == "source_id" @@ -201,10 +201,10 @@ class TestDeleteQueryConstruction: """ condition = {"knowledge_graph_kwd": ["entity", "relation"]} query = self.build_delete_query(condition, "kb123") - + query_dict = query["query"]["bool"] must_terms = query_dict.get("must", []) - + terms_query = [t for t in must_terms if "terms" in t] assert len(terms_query) >= 1 # Find the knowledge_graph_kwd terms @@ -225,21 +225,18 @@ class TestChunkApiDeleteCondition: passed to settings.docStoreConn.delete. """ # Simulate what the rm endpoint should construct - req = { - "doc_id": "doc123", - "chunk_ids": ["chunk1", "chunk2"] - } - + req = {"doc_id": "doc123", "chunk_ids": ["chunk1", "chunk2"]} + # This is what the FIXED code should produce: correct_condition = { "id": req["chunk_ids"], - "doc_id": req["doc_id"] # <-- CRITICAL: doc_id must be included + "doc_id": req["doc_id"], # <-- CRITICAL: doc_id must be included } - + # Verify doc_id is in the condition assert "doc_id" in correct_condition, "doc_id MUST be in delete condition" assert correct_condition["doc_id"] == "doc123" - + # Verify chunk IDs are in the condition assert "id" in correct_condition assert correct_condition["id"] == ["chunk1", "chunk2"] @@ -259,16 +256,13 @@ class TestSDKDocDeleteCondition: # Simulate SDK request document_id = "doc456" chunk_ids = ["chunk1", "chunk2"] - + # The CORRECT condition construction (from restful_apis/chunk_api.py): condition = {"doc_id": document_id} if chunk_ids: condition["id"] = chunk_ids - - assert condition == { - "doc_id": "doc456", - "id": ["chunk1", "chunk2"] - } + + assert condition == {"doc_id": "doc456", "id": ["chunk1", "chunk2"]} def test_sdk_rm_chunk_all_chunks(self): """ @@ -276,11 +270,11 @@ class TestSDKDocDeleteCondition: """ document_id = "doc456" chunk_ids = [] # Delete all - + condition = {"doc_id": document_id} if chunk_ids: condition["id"] = chunk_ids - + # When no chunk_ids, only doc_id should be in condition assert condition == {"doc_id": "doc456"} assert "id" not in condition diff --git a/test/unit_test/common/test_file_utils.py b/test/unit_test/common/test_file_utils.py index 6a38f51ad0..07592d5cb0 100644 --- a/test/unit_test/common/test_file_utils.py +++ b/test/unit_test/common/test_file_utils.py @@ -107,12 +107,15 @@ class TestGetProjectBaseDirectory: # Parameterized tests for different path combinations -@pytest.mark.parametrize("path_args,expected_suffix", [ - ((), ""), # No additional arguments - (("src",), "src"), - (("data", "models"), os.path.join("data", "models")), - (("config", "app", "settings.json"), os.path.join("config", "app", "settings.json")), -]) +@pytest.mark.parametrize( + "path_args,expected_suffix", + [ + ((), ""), # No additional arguments + (("src",), "src"), + (("data", "models"), os.path.join("data", "models")), + (("config", "app", "settings.json"), os.path.join("config", "app", "settings.json")), + ], +) def test_various_path_combinations(path_args, expected_suffix): """Test various combinations of path arguments""" base_path = get_project_base_directory() diff --git a/test/unit_test/common/test_float_utils.py b/test/unit_test/common/test_float_utils.py index cecad1ae77..a33d9c8009 100644 --- a/test/unit_test/common/test_float_utils.py +++ b/test/unit_test/common/test_float_utils.py @@ -17,8 +17,8 @@ import math from common.float_utils import get_float -class TestGetFloat: +class TestGetFloat: def test_valid_float_string(self): """Test conversion of valid float strings""" assert get_float("3.14") == 3.14 @@ -66,8 +66,8 @@ class TestGetFloat: def test_special_float_strings(self): """Test handling of special float strings""" - assert get_float("inf") == float('inf') - assert get_float("-inf") == float('-inf') + assert get_float("inf") == float("inf") + assert get_float("-inf") == float("-inf") # NaN should return -inf according to our function's design result = get_float("nan") @@ -85,4 +85,4 @@ class TestGetFloat: assert get_float(" 3.14 ") == 3.14 result = get_float(" invalid ") assert math.isinf(result) - assert result < 0 \ No newline at end of file + assert result < 0 diff --git a/test/unit_test/common/test_metadata_filter.py b/test/unit_test/common/test_metadata_filter.py index d48b30fb6c..62f69c9101 100644 --- a/test/unit_test/common/test_metadata_filter.py +++ b/test/unit_test/common/test_metadata_filter.py @@ -158,9 +158,7 @@ def test_equal_translates_to_term_with_lowercased_value(es_translator): return f"{META_FIELDS_PREFIX}.{key}" clauses = es_translator.translate({"key": "tag", "op": "=", "value": "Alpha"}).to_clauses() - assert clauses == [ - {"term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}}} - ] + assert clauses == [{"term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}}}] def test_equal_parses_numeric_literal(es_translator): @@ -181,9 +179,7 @@ def test_equal_multiword_uses_keyword_subfield(es_translator): def _field(key: str) -> str: return f"{META_FIELDS_PREFIX}.{key}" - clauses = es_translator.translate( - {"key": "author", "op": "=", "value": "Alice Wonderland"} - ).to_clauses() + clauses = es_translator.translate({"key": "author", "op": "=", "value": "Alice Wonderland"}).to_clauses() assert clauses == [ { "term": { @@ -207,9 +203,7 @@ def test_not_equal_requires_field_to_exist(es_translator): { "bool": { "must": [{"exists": {"field": _field("tag")}}], - "must_not": [ - {"term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}}} - ], + "must_not": [{"term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}}}], } } ] @@ -253,10 +247,7 @@ def test_in_operator_csv_value_lowercased(es_translator): def _string_terms_should(field_path: str, members): return { "bool": { - "should": [ - {"term": {field_path + ".keyword": {"value": m, "case_insensitive": True}}} - for m in members - ], + "should": [{"term": {field_path + ".keyword": {"value": m, "case_insensitive": True}}} for m in members], "minimum_should_match": 1, } } @@ -274,10 +265,7 @@ def test_in_operator_python_list_literal(es_translator): def _string_terms_should(field_path: str, members): return { "bool": { - "should": [ - {"term": {field_path + ".keyword": {"value": m, "case_insensitive": True}}} - for m in members - ], + "should": [{"term": {field_path + ".keyword": {"value": m, "case_insensitive": True}}} for m in members], "minimum_should_match": 1, } } @@ -305,10 +293,7 @@ def test_not_in_negates_with_existence_guard(es_translator): def _string_terms_should(field_path: str, members): return { "bool": { - "should": [ - {"term": {field_path + ".keyword": {"value": m, "case_insensitive": True}}} - for m in members - ], + "should": [{"term": {field_path + ".keyword": {"value": m, "case_insensitive": True}}} for m in members], "minimum_should_match": 1, } } @@ -387,9 +372,7 @@ def test_start_with_uses_prefix(es_translator): return f"{META_FIELDS_PREFIX}.{key}" clauses = es_translator.translate({"key": "name", "op": "start with", "value": "pre"}).to_clauses() - assert clauses == [ - {"prefix": {_field("name") + ".keyword": {"value": "pre", "case_insensitive": True}}} - ] + assert clauses == [{"prefix": {_field("name") + ".keyword": {"value": "pre", "case_insensitive": True}}}] def test_end_with_uses_trailing_wildcard(es_translator): @@ -656,4 +639,4 @@ def test_infinity_string_op_with_empty_value_raises(infinity_translator): def test_infinity_membership_with_empty_csv_raises(infinity_translator): with pytest.raises(ValueError): - infinity_translator.translate({"key": "tag", "op": "in", "value": ""}) \ No newline at end of file + infinity_translator.translate({"key": "tag", "op": "in", "value": ""}) diff --git a/test/unit_test/common/test_misc_utils.py b/test/unit_test/common/test_misc_utils.py index d94de1027b..e1b3490d42 100644 --- a/test/unit_test/common/test_misc_utils.py +++ b/test/unit_test/common/test_misc_utils.py @@ -183,12 +183,12 @@ class TestGetUuid: # UUID v1 hex should be 32 characters (without dashes) assert len(result) == 32 # Should only contain hexadecimal characters - assert all(c in '0123456789abcdef' for c in result) + assert all(c in "0123456789abcdef" for c in result) def test_no_dashes_in_result(self): """Test that result contains no dashes""" result = get_uuid() - assert '-' not in result + assert "-" not in result def test_unique_results(self): """Test that multiple calls return different UUIDs""" @@ -200,7 +200,7 @@ class TestGetUuid: # All should be valid hex strings of correct length for result in results: assert len(result) == 32 - assert all(c in '0123456789abcdef' for c in result) + assert all(c in "0123456789abcdef" for c in result) def test_valid_uuid_structure(self): """Test that the hex string can be converted back to UUID""" @@ -222,7 +222,7 @@ class TestGetUuid: assert uuid_obj.version == 1 # Variant should be RFC 4122 - assert uuid_obj.variant == 'specified in RFC 4122' + assert uuid_obj.variant == "specified in RFC 4122" def test_result_length_consistency(self): """Test that all generated UUIDs have consistent length""" @@ -236,7 +236,7 @@ class TestGetUuid: result = get_uuid() # Should only contain lowercase hex characters (UUID hex is lowercase) assert result.islower() - assert all(c in '0123456789abcdef' for c in result) + assert all(c in "0123456789abcdef" for c in result) class TestDownloadImg: @@ -314,12 +314,12 @@ class TestHashStr2Int: """Test basic string hashing functionality""" result = hash_str2int("hello") assert isinstance(result, int) - assert 0 <= result < 10 ** 8 + assert 0 <= result < 10**8 def test_default_mod_value(self): """Test that default mod value is 10^8""" result = hash_str2int("test") - assert 0 <= result < 10 ** 8 + assert 0 <= result < 10**8 def test_custom_mod_value(self): """Test with custom mod value""" @@ -349,44 +349,32 @@ class TestHashStr2Int: """Test hashing empty string""" result = hash_str2int("") assert isinstance(result, int) - assert 0 <= result < 10 ** 8 + assert 0 <= result < 10**8 def test_unicode_string(self): """Test hashing unicode strings""" - test_strings = [ - "中文", - "🚀火箭", - "café", - "🎉", - "Hello 世界" - ] + test_strings = ["中文", "🚀火箭", "café", "🎉", "Hello 世界"] for test_str in test_strings: result = hash_str2int(test_str) assert isinstance(result, int) - assert 0 <= result < 10 ** 8 + assert 0 <= result < 10**8 def test_special_characters(self): """Test hashing strings with special characters""" - test_strings = [ - "hello@world.com", - "test#123", - "line\nwith\nnewlines", - "tab\tcharacter", - "space in string" - ] + test_strings = ["hello@world.com", "test#123", "line\nwith\nnewlines", "tab\tcharacter", "space in string"] for test_str in test_strings: result = hash_str2int(test_str) assert isinstance(result, int) - assert 0 <= result < 10 ** 8 + assert 0 <= result < 10**8 def test_large_string(self): """Test hashing large string""" large_string = "x" * 10000 result = hash_str2int(large_string) assert isinstance(result, int) - assert 0 <= result < 10 ** 8 + assert 0 <= result < 10**8 def test_mod_value_1(self): """Test with mod value 1 (should always return 0)""" @@ -400,15 +388,15 @@ class TestHashStr2Int: def test_very_large_mod(self): """Test with very large mod value""" - result = hash_str2int("test", mod=10 ** 12) + result = hash_str2int("test", mod=10**12) assert isinstance(result, int) - assert 0 <= result < 10 ** 12 + assert 0 <= result < 10**12 def test_hash_algorithm_sha1(self): """Test that SHA1 algorithm is used""" test_string = "hello" expected_hash = hashlib.sha1(test_string.encode("utf-8")).hexdigest() - expected_int = int(expected_hash, 16) % (10 ** 8) + expected_int = int(expected_hash, 16) % (10**8) result = hash_str2int(test_string) assert result == expected_int @@ -437,7 +425,7 @@ class TestHashStr2Int: test_string = "hello" hash_obj = hashlib.sha1(test_string.encode("utf-8")) hex_digest = hash_obj.hexdigest() - expected_int = int(hex_digest, 16) % (10 ** 8) + expected_int = int(hex_digest, 16) % (10**8) result = hash_str2int(test_string) assert result == expected_int @@ -447,7 +435,7 @@ class TestHashStr2Int: test_strings = ["a", "b", "abc", "hello world", "12345"] for test_str in test_strings: - direct_result = int(hashlib.sha1(test_str.encode("utf-8")).hexdigest(), 16) % (10 ** 8) + direct_result = int(hashlib.sha1(test_str.encode("utf-8")).hexdigest(), 16) % (10**8) function_result = hash_str2int(test_str) assert function_result == direct_result @@ -458,23 +446,16 @@ class TestHashStr2Int: for test_str in test_strings: result = hash_str2int(test_str) assert isinstance(result, int) - assert 0 <= result < 10 ** 8 + assert 0 <= result < 10**8 def test_whitespace_strings(self): """Test hashing strings with various whitespace""" - test_strings = [ - " leading", - "trailing ", - " both ", - "\ttab", - "new\nline", - "\r\nwindows" - ] + test_strings = [" leading", "trailing ", " both ", "\ttab", "new\nline", "\r\nwindows"] for test_str in test_strings: result = hash_str2int(test_str) assert isinstance(result, int) - assert 0 <= result < 10 ** 8 + assert 0 <= result < 10**8 class TestConvertBytes: diff --git a/test/unit_test/common/test_settings_queue.py b/test/unit_test/common/test_settings_queue.py index 6ec582ab9f..3a2c69c7c7 100644 --- a/test/unit_test/common/test_settings_queue.py +++ b/test/unit_test/common/test_settings_queue.py @@ -47,7 +47,7 @@ class TestGetSvrQueueName: def test_suffix_parameter_ignored(self): """Test that suffix parameter is currently ignored (hardcoded to 'common'). - + Note: The function signature accepts a suffix parameter but currently hardcodes 'common' in the return value. This test documents this behavior. """ @@ -137,7 +137,7 @@ class TestGetSvrQueueNames: def test_suffix_parameter_passed_through(self): """Test that suffix parameter is passed to get_svr_queue_name. - + Note: Since get_svr_queue_name currently hardcodes 'common' as the suffix, different suffix values will still produce the same result. """ diff --git a/test/unit_test/common/test_string_utils.py b/test/unit_test/common/test_string_utils.py index 7c33e83558..8833578adb 100644 --- a/test/unit_test/common/test_string_utils.py +++ b/test/unit_test/common/test_string_utils.py @@ -19,7 +19,6 @@ from common.string_utils import remove_redundant_spaces, clean_markdown_block class TestRemoveRedundantSpaces: - # Basic punctuation tests @pytest.mark.skip(reason="Failed") def test_remove_spaces_before_commas(self): @@ -244,7 +243,6 @@ class TestRemoveRedundantSpaces: class TestCleanMarkdownBlock: - def test_standard_markdown_block(self): """Test standard Markdown code block syntax""" input_text = "```markdown\nHello world\n```" @@ -356,4 +354,3 @@ class TestCleanMarkdownBlock: input_text = "```markdown\nFirst line\n```\n```markdown\nSecond line\n```" expected = "First line\n```\n```markdown\nSecond line" assert clean_markdown_block(input_text) == expected - diff --git a/test/unit_test/common/test_think_stream_parser.py b/test/unit_test/common/test_think_stream_parser.py index e98b57006e..ec9ac654d0 100644 --- a/test/unit_test/common/test_think_stream_parser.py +++ b/test/unit_test/common/test_think_stream_parser.py @@ -342,9 +342,7 @@ class TestThinkStreamParser(unittest.TestCase): with self.subTest(case=case_name): buf = io.StringIO() with redirect_stdout(buf): - think_text, answer_text, markers = asyncio.run( - _collect_case(case["chunks"], case["min_tokens"]) - ) + think_text, answer_text, markers = asyncio.run(_collect_case(case["chunks"], case["min_tokens"])) expected = case["expected"] self.assertEqual(think_text, expected["think"], case_name) diff --git a/test/unit_test/common/test_time_utils.py b/test/unit_test/common/test_time_utils.py index c7142df5bc..ae8f5feeba 100644 --- a/test/unit_test/common/test_time_utils.py +++ b/test/unit_test/common/test_time_utils.py @@ -457,13 +457,16 @@ class TestDatetimeFormat: assert result == expected - @pytest.mark.parametrize("year,month,day,hour,minute,second,microsecond", [ - (2024, 1, 1, 0, 0, 0, 0), # Start of day - (2024, 12, 31, 23, 59, 59, 999999), # End of year - (2000, 6, 15, 12, 30, 45, 500000), # Random date - (1970, 1, 1, 0, 0, 0, 123456), # Epoch equivalent - (2030, 3, 20, 6, 15, 30, 750000), # Future date - ]) + @pytest.mark.parametrize( + "year,month,day,hour,minute,second,microsecond", + [ + (2024, 1, 1, 0, 0, 0, 0), # Start of day + (2024, 12, 31, 23, 59, 59, 999999), # End of year + (2000, 6, 15, 12, 30, 45, 500000), # Random date + (1970, 1, 1, 0, 0, 0, 123456), # Epoch equivalent + (2030, 3, 20, 6, 15, 30, 750000), # Future date + ], + ) def test_parametrized_datetimes(self, year, month, day, hour, minute, second, microsecond): """Test multiple datetime scenarios using parametrization""" original_dt = datetime.datetime(year, month, day, hour, minute, second, microsecond) @@ -668,17 +671,13 @@ class TestTimestampToDateCurrentTimeFallback: """None input must resolve to current_timestamp() fallback.""" fixed_ms = 1704067200123 monkeypatch.setattr("common.time_utils.current_timestamp", lambda: fixed_ms) - assert timestamp_to_date(None) == time.strftime( - "%Y-%m-%d %H:%M:%S", time.localtime(fixed_ms / 1000) - ) + assert timestamp_to_date(None) == time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(fixed_ms / 1000)) def test_empty_string_uses_current_time(self, monkeypatch): """Empty-string input must resolve to current_timestamp() fallback.""" fixed_ms = 1704067200123 monkeypatch.setattr("common.time_utils.current_timestamp", lambda: fixed_ms) - assert timestamp_to_date("") == time.strftime( - "%Y-%m-%d %H:%M:%S", time.localtime(fixed_ms / 1000) - ) + assert timestamp_to_date("") == time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(fixed_ms / 1000)) def test_zero_timestamp_is_not_treated_as_empty(self): """Zero timestamp should map to Unix epoch, not fallback to current time.""" @@ -711,4 +710,4 @@ class TestFormatIso8601ToYmdHms: def test_invalid_string_returns_original(self): """Unparseable input is returned unchanged.""" - assert format_iso_8601_to_ymd_hms("not-a-date") == "not-a-date" \ No newline at end of file + assert format_iso_8601_to_ymd_hms("not-a-date") == "not-a-date" diff --git a/test/unit_test/common/test_token_utils.py b/test/unit_test/common/test_token_utils.py index ba8890826d..bcfc84fe37 100644 --- a/test/unit_test/common/test_token_utils.py +++ b/test/unit_test/common/test_token_utils.py @@ -62,8 +62,7 @@ class TestNumTokensFromString: def test_long_text(self): """Test token count for longer text""" - long_text = "This is a longer piece of text that should contain multiple sentences. " \ - "It will help verify that the token counting works correctly for substantial input." + long_text = "This is a longer piece of text that should contain multiple sentences. It will help verify that the token counting works correctly for substantial input." result = num_tokens_from_string(long_text) assert result > 10 @@ -93,13 +92,16 @@ class TestNumTokensFromString: # Additional parameterized tests for efficiency -@pytest.mark.parametrize("input_string,expected_min_tokens", [ - ("a", 1), # Single character - ("test", 1), # Single word - ("hello world", 2), # Two words - ("This is a sentence.", 4), # Short sentence - # ("A" * 100, 100), # Repeated characters -]) +@pytest.mark.parametrize( + "input_string,expected_min_tokens", + [ + ("a", 1), # Single character + ("test", 1), # Single word + ("hello world", 2), # Two words + ("This is a sentence.", 4), # Short sentence + # ("A" * 100, 100), # Repeated characters + ], +) def test_token_count_ranges(input_string, expected_min_tokens): """Parameterized test for various input strings""" result = num_tokens_from_string(input_string) @@ -127,88 +129,42 @@ class TestTotalTokenCountFromResponse: def test_dict_with_usage_total_tokens(self): """Test dictionary response with usage['total_tokens']""" - resp_dict = { - 'usage': { - 'total_tokens': 175 - } - } + resp_dict = {"usage": {"total_tokens": 175}} result = total_token_count_from_response(resp_dict) assert result == 175 def test_dict_with_usage_input_output_tokens(self): """Test dictionary response with input_tokens and output_tokens in usage""" - resp_dict = { - 'usage': { - 'input_tokens': 100, - 'output_tokens': 50 - } - } + resp_dict = {"usage": {"input_tokens": 100, "output_tokens": 50}} result = total_token_count_from_response(resp_dict) assert result == 150 def test_dict_with_meta_tokens_input_output(self): """Test dictionary response with meta.tokens.input_tokens and output_tokens""" - resp_dict = { - 'meta': { - 'tokens': { - 'input_tokens': 80, - 'output_tokens': 40 - } - } - } + resp_dict = {"meta": {"tokens": {"input_tokens": 80, "output_tokens": 40}}} result = total_token_count_from_response(resp_dict) assert result == 120 def test_priority_order_dict_usage_total_tokens_third(self): """Test that dict['usage']['total_tokens'] is third in priority""" - resp_dict = { - 'usage': { - 'total_tokens': 180, - 'input_tokens': 100, - 'output_tokens': 80 - }, - 'meta': { - 'tokens': { - 'input_tokens': 200, - 'output_tokens': 100 - } - } - } + resp_dict = {"usage": {"total_tokens": 180, "input_tokens": 100, "output_tokens": 80}, "meta": {"tokens": {"input_tokens": 200, "output_tokens": 100}}} result = total_token_count_from_response(resp_dict) assert result == 180 # Should use total_tokens from usage def test_priority_order_dict_usage_input_output_fourth(self): """Test that dict['usage']['input_tokens'] + output_tokens is fourth in priority""" - resp_dict = { - 'usage': { - 'input_tokens': 120, - 'output_tokens': 60 - }, - 'meta': { - 'tokens': { - 'input_tokens': 200, - 'output_tokens': 100 - } - } - } + resp_dict = {"usage": {"input_tokens": 120, "output_tokens": 60}, "meta": {"tokens": {"input_tokens": 200, "output_tokens": 100}}} result = total_token_count_from_response(resp_dict) assert result == 180 # Should sum input_tokens + output_tokens from usage def test_priority_order_meta_tokens_last(self): """Test that meta.tokens is the last option in priority""" - resp_dict = { - 'meta': { - 'tokens': { - 'input_tokens': 90, - 'output_tokens': 30 - } - } - } + resp_dict = {"meta": {"tokens": {"input_tokens": 90, "output_tokens": 30}}} result = total_token_count_from_response(resp_dict) assert result == 120 @@ -222,8 +178,8 @@ class TestTotalTokenCountFromResponse: def test_partial_dict_usage_missing_output_tokens(self): """Test dictionary with usage but missing output_tokens""" resp_dict = { - 'usage': { - 'input_tokens': 100 + "usage": { + "input_tokens": 100 # Missing output_tokens } } @@ -234,9 +190,9 @@ class TestTotalTokenCountFromResponse: def test_partial_meta_tokens_missing_input_tokens(self): """Test dictionary with meta.tokens but missing input_tokens""" resp_dict = { - 'meta': { - 'tokens': { - 'output_tokens': 50 + "meta": { + "tokens": { + "output_tokens": 50 # Missing input_tokens } } diff --git a/test/unit_test/data_source/test_bigquery_connector.py b/test/unit_test/data_source/test_bigquery_connector.py index baa856bebe..56eeee9fcb 100644 --- a/test/unit_test/data_source/test_bigquery_connector.py +++ b/test/unit_test/data_source/test_bigquery_connector.py @@ -174,17 +174,12 @@ def test_missing_table_and_query_raises(): @pytest.mark.p2 def test_time_filtered_query_compound_cursor_with_id_column(): - client = _FakeClient(table_schema=[ - _FakeSchemaField("updated_at", "TIMESTAMP"), - _FakeSchemaField("id", "STRING") - ]) + client = _FakeClient(table_schema=[_FakeSchemaField("updated_at", "TIMESTAMP"), _FakeSchemaField("id", "STRING")]) connector = _make_connector(client=client, timestamp_column="updated_at", id_column="id") start = datetime(2026, 1, 1, tzinfo=timezone.utc) end = datetime(2026, 2, 1, tzinfo=timezone.utc) - query, params = connector._build_time_filtered_query( - connector._build_base_query(), start, end, start_id="last-id" - ) + query, params = connector._build_time_filtered_query(connector._build_base_query(), start, end, start_id="last-id") assert "(ragflow_src.updated_at > @start_cursor OR (ragflow_src.updated_at = @start_cursor AND ragflow_src.id > @start_cursor_id))" in query assert "ragflow_src.updated_at <= @end_cursor" in query @@ -194,6 +189,7 @@ def test_time_filtered_query_compound_cursor_with_id_column(): ("end_cursor", "TIMESTAMP", end), ] + @pytest.mark.p2 def test_time_filtered_query_uses_gte_without_id_column(): client = _FakeClient(table_schema=[_FakeSchemaField("updated_at", "TIMESTAMP")]) @@ -201,9 +197,7 @@ def test_time_filtered_query_uses_gte_without_id_column(): start = datetime(2026, 1, 1, tzinfo=timezone.utc) end = datetime(2026, 2, 1, tzinfo=timezone.utc) - query, params = connector._build_time_filtered_query( - connector._build_base_query(), start, end - ) + query, params = connector._build_time_filtered_query(connector._build_base_query(), start, end) assert "ragflow_src.updated_at >= @start_cursor" in query assert "ragflow_src.updated_at <= @end_cursor" in query @@ -251,13 +245,8 @@ def test_value_serialization_types(): assert BigQueryConnector._render_content_value(b"binary") is None assert BigQueryConnector._render_content_value(date(2026, 1, 2)) == "2026-01-02" assert BigQueryConnector._render_content_value(time(3, 4, 5)) == "03:04:05" - assert ( - BigQueryConnector._render_content_value({"a": 1, "b": [1, 2]}) - == '{"a": 1, "b": [1, 2]}' - ) - assert ( - BigQueryConnector._render_content_value("POINT(1 2)") == "POINT(1 2)" - ) # GEOGRAPHY WKT passes through + assert BigQueryConnector._render_content_value({"a": 1, "b": [1, 2]}) == '{"a": 1, "b": [1, 2]}' + assert BigQueryConnector._render_content_value("POINT(1 2)") == "POINT(1 2)" # GEOGRAPHY WKT passes through # Metadata base64-encodes bytes instead of skipping. assert BigQueryConnector._render_metadata_value(b"hi") == "aGk=" @@ -310,10 +299,7 @@ def test_custom_query_mode_id_prefix(): @pytest.mark.p2 def test_batches_accumulate_to_batch_size(): - rows = [ - {"id": i, "name": f"n{i}", "description": "d", "status": "s"} - for i in range(5) - ] + rows = [{"id": i, "name": f"n{i}", "description": "d", "status": "s"} for i in range(5)] client = _FakeClient(rows=rows) connector = _make_connector(client=client, batch_size=2) @@ -329,18 +315,10 @@ def test_cursor_serialize_deserialize_roundtrip(): t = time(12, 0) dec = Decimal("1.23") - assert BigQueryConnector.deserialize_cursor_value( - BigQueryConnector.serialize_cursor_value(dt) - ) == dt - assert BigQueryConnector.deserialize_cursor_value( - BigQueryConnector.serialize_cursor_value(d) - ) == d - assert BigQueryConnector.deserialize_cursor_value( - BigQueryConnector.serialize_cursor_value(t) - ) == t - assert BigQueryConnector.deserialize_cursor_value( - BigQueryConnector.serialize_cursor_value(dec) - ) == dec + assert BigQueryConnector.deserialize_cursor_value(BigQueryConnector.serialize_cursor_value(dt)) == dt + assert BigQueryConnector.deserialize_cursor_value(BigQueryConnector.serialize_cursor_value(d)) == d + assert BigQueryConnector.deserialize_cursor_value(BigQueryConnector.serialize_cursor_value(t)) == t + assert BigQueryConnector.deserialize_cursor_value(BigQueryConnector.serialize_cursor_value(dec)) == dec assert BigQueryConnector.serialize_cursor_value(42) == 42 assert BigQueryConnector.deserialize_cursor_value(42) == 42 @@ -383,6 +361,7 @@ def test_validation_detects_missing_content_column(): with pytest.raises(ConnectorValidationError, match="name"): connector.validate_connector_settings() + @pytest.mark.p2 def test_validation_detects_missing_metadata_column(): client = _FakeClient(table_schema=[_FakeSchemaField("name", "STRING")]) @@ -390,6 +369,7 @@ def test_validation_detects_missing_metadata_column(): with pytest.raises(ConnectorValidationError, match="status"): connector.validate_connector_settings() + @pytest.mark.p2 def test_validation_detects_missing_id_column(): client = _FakeClient(table_schema=[_FakeSchemaField("name", "STRING")]) @@ -397,6 +377,7 @@ def test_validation_detects_missing_id_column(): with pytest.raises(ConnectorValidationError, match="id"): connector.validate_connector_settings() + @pytest.mark.p2 def test_validation_detects_missing_timestamp_column(): client = _FakeClient(table_schema=[_FakeSchemaField("name", "STRING")]) @@ -404,6 +385,7 @@ def test_validation_detects_missing_timestamp_column(): with pytest.raises(ConnectorValidationError, match="ts"): connector.validate_connector_settings() + @pytest.mark.p2 def test_validation_detects_unsupported_cursor_type_early(): client = _FakeClient(table_schema=[_FakeSchemaField("name", "STRING"), _FakeSchemaField("ts", "BOOL")]) diff --git a/test/unit_test/data_source/test_imap_connector_addr_parsing.py b/test/unit_test/data_source/test_imap_connector_addr_parsing.py index c5a54e671f..4603a79dfb 100644 --- a/test/unit_test/data_source/test_imap_connector_addr_parsing.py +++ b/test/unit_test/data_source/test_imap_connector_addr_parsing.py @@ -42,9 +42,7 @@ class TestParseAddrs: assert _parse_addrs("user@example.com") == [("", "user@example.com")] def test_address_with_display_name(self): - assert _parse_addrs("Alice ") == [ - ("Alice", "alice@example.com") - ] + assert _parse_addrs("Alice ") == [("Alice", "alice@example.com")] def test_quoted_display_name_with_comma_returns_single_address(self): # #14963: the bug was that ``split(",")`` produced two bogus tuples. @@ -57,9 +55,7 @@ class TestParseAddrs: assert result == [("", "a@example.com"), ("", "b@example.com")] def test_multiple_addresses_with_quoted_comma_in_name(self): - result = _parse_addrs( - '"Wilkens, Michael" , "Müller, Hans" ' - ) + result = _parse_addrs('"Wilkens, Michael" , "Müller, Hans" ') assert result == [ ("Wilkens, Michael", "m@example.com"), ("Müller, Hans", "h@example.com"), @@ -79,9 +75,7 @@ class TestParseSingularAddr: def test_quoted_comma_display_name_does_not_raise(self): # #14963 cascade: before the fix, ``_parse_addrs`` returned two bogus # tuples and ``_parse_singular_addr`` then raised RuntimeError. - assert _parse_singular_addr( - '"Schlüter, Sabine" ' - ) == ("Schlüter, Sabine", "sabine.schlueter@ihklw.de") + assert _parse_singular_addr('"Schlüter, Sabine" ') == ("Schlüter, Sabine", "sabine.schlueter@ihklw.de") def test_multi_address_header_warns_and_returns_first(self, caplog): # #14964: a legitimately multi-address From header must not crash sync. @@ -89,9 +83,7 @@ class TestParseSingularAddr: with caplog.at_level(logging.WARNING): result = _parse_singular_addr(header) assert result == ("User A", "a@example.com") - assert any( - "Multiple addresses" in rec.message for rec in caplog.records - ), f"expected warning about multiple addresses, got: {caplog.records}" + assert any("Multiple addresses" in rec.message for rec in caplog.records), f"expected warning about multiple addresses, got: {caplog.records}" def test_multi_address_header_does_not_raise(self): # Explicit guard: no RuntimeError should propagate. diff --git a/test/unit_test/data_source/test_onedrive_connector_unit.py b/test/unit_test/data_source/test_onedrive_connector_unit.py index 63ff77b6aa..e8bed87420 100644 --- a/test/unit_test/data_source/test_onedrive_connector_unit.py +++ b/test/unit_test/data_source/test_onedrive_connector_unit.py @@ -26,22 +26,19 @@ _GRAPH_BASE = "https://graph.microsoft.com/v1.0" # folder_path / _delta_url # --------------------------------------------------------------------------- + @pytest.mark.p2 def test_folder_path_prepends_leading_slash_for_delta_url(): connector = OneDriveConnector(folder_path="Documents/Reports") assert connector.folder_path == "/Documents/Reports" - assert connector._delta_url("drive-1") == ( - f"{_GRAPH_BASE}/drives/drive-1/root:/Documents/Reports:/delta" - ) + assert connector._delta_url("drive-1") == (f"{_GRAPH_BASE}/drives/drive-1/root:/Documents/Reports:/delta") @pytest.mark.p2 def test_folder_path_preserves_leading_slash(): connector = OneDriveConnector(folder_path="/Documents/Reports/") assert connector.folder_path == "/Documents/Reports" - assert connector._delta_url("drive-1") == ( - f"{_GRAPH_BASE}/drives/drive-1/root:/Documents/Reports:/delta" - ) + assert connector._delta_url("drive-1") == (f"{_GRAPH_BASE}/drives/drive-1/root:/Documents/Reports:/delta") @pytest.mark.p2 @@ -54,9 +51,7 @@ def test_folder_path_rejects_parent_segments(): def test_folder_path_normalizes_consecutive_slashes(): connector = OneDriveConnector(folder_path="//Documents//Reports") assert connector.folder_path == "/Documents/Reports" - assert connector._delta_url("drive-1") == ( - f"{_GRAPH_BASE}/drives/drive-1/root:/Documents/Reports:/delta" - ) + assert connector._delta_url("drive-1") == (f"{_GRAPH_BASE}/drives/drive-1/root:/Documents/Reports:/delta") @pytest.mark.p2 @@ -83,6 +78,7 @@ def test_folder_path_double_slash_only_uses_drive_root_delta(): # load_credentials # --------------------------------------------------------------------------- + @pytest.mark.p2 def test_load_credentials_missing_fields_raises(): connector = OneDriveConnector() @@ -122,6 +118,7 @@ def test_load_credentials_msal_failure_raises(): # validate_connector_settings # --------------------------------------------------------------------------- + @pytest.mark.p2 def test_validate_without_credentials_raises(): connector = OneDriveConnector() @@ -190,6 +187,7 @@ def test_validate_unexpected_status_raises(): # Checkpoint helpers # --------------------------------------------------------------------------- + @pytest.mark.p2 def test_build_dummy_checkpoint(): connector = OneDriveConnector() @@ -210,6 +208,7 @@ def test_validate_checkpoint_json_invalid_returns_dummy(): # _iter_documents (via poll_source) # --------------------------------------------------------------------------- + def _ok(json_value): """Tiny helper: build a successful MagicMock response with .ok / .json().""" resp = MagicMock() @@ -234,20 +233,22 @@ def test_poll_source_yields_supported_files(): connector._access_token = "tok" drives_resp = _ok({"value": [{"id": "drive-1"}]}) - delta_resp = _ok({ - "value": [ - { - "id": "file-1", - "name": "report.docx", - "file": {}, - "lastModifiedDateTime": "2026-05-20T10:00:00Z", - "webUrl": "https://example.com/report.docx", - "size": 1024, - "createdBy": {"user": {"displayName": "Alice"}}, - } - ], - "@odata.deltaLink": "https://graph.microsoft.com/delta-link", - }) + delta_resp = _ok( + { + "value": [ + { + "id": "file-1", + "name": "report.docx", + "file": {}, + "lastModifiedDateTime": "2026-05-20T10:00:00Z", + "webUrl": "https://example.com/report.docx", + "size": 1024, + "createdBy": {"user": {"displayName": "Alice"}}, + } + ], + "@odata.deltaLink": "https://graph.microsoft.com/delta-link", + } + ) with patch.object(connector, "_get", side_effect=[drives_resp, delta_resp]): batches = list(connector.poll_source(0.0, 9999999999.0)) @@ -263,18 +264,20 @@ def test_poll_source_skips_unsupported_extensions(): connector._access_token = "tok" drives_resp = _ok({"value": [{"id": "drive-1"}]}) - delta_resp = _ok({ - "value": [ - { - "id": "img-1", - "name": "photo.png", # not in _SUPPORTED_EXTENSIONS - "file": {}, - "lastModifiedDateTime": "2026-05-20T10:00:00Z", - "webUrl": "https://example.com/photo.png", - "size": 512, - } - ], - }) + delta_resp = _ok( + { + "value": [ + { + "id": "img-1", + "name": "photo.png", # not in _SUPPORTED_EXTENSIONS + "file": {}, + "lastModifiedDateTime": "2026-05-20T10:00:00Z", + "webUrl": "https://example.com/photo.png", + "size": 512, + } + ], + } + ) with patch.object(connector, "_get", side_effect=[drives_resp, delta_resp]): batches = list(connector.poll_source(0.0, 9999999999.0)) @@ -288,16 +291,18 @@ def test_poll_source_skips_deleted_items(): connector._access_token = "tok" drives_resp = _ok({"value": [{"id": "drive-1"}]}) - delta_resp = _ok({ - "value": [ - { - "id": "file-del", - "name": "gone.docx", - "file": {}, - "deleted": {"state": "deleted"}, - } - ], - }) + delta_resp = _ok( + { + "value": [ + { + "id": "file-del", + "name": "gone.docx", + "file": {}, + "deleted": {"state": "deleted"}, + } + ], + } + ) with patch.object(connector, "_get", side_effect=[drives_resp, delta_resp]): batches = list(connector.poll_source(0.0, 9999999999.0)) @@ -309,6 +314,7 @@ def test_poll_source_skips_deleted_items(): # Non-2xx Graph responses must raise (no silent partial syncs) # --------------------------------------------------------------------------- + @pytest.mark.p1 def test_iter_documents_raises_on_graph_http_500(): """A 500 from the delta endpoint must surface — silently breaking would @@ -353,6 +359,7 @@ def test_list_drive_ids_raises_on_http_error(): # retrieve_all_slim_docs_perm_sync: yields SlimDocument batches for prune # --------------------------------------------------------------------------- + @pytest.mark.p1 def test_retrieve_slim_docs_yields_slimdocument_batches(): """The prune collector does file_list.extend(batch) and reads .id on each @@ -362,13 +369,15 @@ def test_retrieve_slim_docs_yields_slimdocument_batches(): connector._access_token = "tok" drives_resp = _ok({"value": [{"id": "drive-1"}]}) - delta_resp = _ok({ - "value": [ - {"id": "f1", "name": "a.docx", "file": {}}, - {"id": "f2", "name": "b.pdf", "file": {}}, - {"id": "f3", "name": "c.txt", "file": {}}, - ], - }) + delta_resp = _ok( + { + "value": [ + {"id": "f1", "name": "a.docx", "file": {}}, + {"id": "f2", "name": "b.pdf", "file": {}}, + {"id": "f3", "name": "c.txt", "file": {}}, + ], + } + ) with patch.object(connector, "_get", side_effect=[drives_resp, delta_resp]): batches = list(connector.retrieve_all_slim_docs_perm_sync()) @@ -387,13 +396,15 @@ def test_retrieve_slim_docs_skips_folders_and_deleted(): connector._access_token = "tok" drives_resp = _ok({"value": [{"id": "drive-1"}]}) - delta_resp = _ok({ - "value": [ - {"id": "folder-1", "name": "Docs", "folder": {}}, # folder, no "file" - {"id": "del-1", "name": "gone.pdf", "file": {}, "deleted": {"state": "deleted"}}, - {"id": "ok-1", "name": "keep.pdf", "file": {}}, - ], - }) + delta_resp = _ok( + { + "value": [ + {"id": "folder-1", "name": "Docs", "folder": {}}, # folder, no "file" + {"id": "del-1", "name": "gone.pdf", "file": {}, "deleted": {"state": "deleted"}}, + {"id": "ok-1", "name": "keep.pdf", "file": {}}, + ], + } + ) with patch.object(connector, "_get", side_effect=[drives_resp, delta_resp]): batches = list(connector.retrieve_all_slim_docs_perm_sync()) @@ -427,6 +438,7 @@ def test_retrieve_slim_docs_requires_credentials(): # load_from_checkpoint: resumes from delta_links and honors start floor # --------------------------------------------------------------------------- + @pytest.mark.p1 def test_load_from_checkpoint_uses_persisted_delta_link(): """When the checkpoint carries a delta_link for a drive, the connector diff --git a/test/unit_test/data_source/test_outlook_connector_unit.py b/test/unit_test/data_source/test_outlook_connector_unit.py index 37ebe23b7d..8ebac046e1 100644 --- a/test/unit_test/data_source/test_outlook_connector_unit.py +++ b/test/unit_test/data_source/test_outlook_connector_unit.py @@ -29,6 +29,7 @@ _GOOD_CREDS = { # _strip_html # --------------------------------------------------------------------------- + @pytest.mark.p3 def test_strip_html_removes_tags_and_script(): html = "

Hello world

" @@ -45,6 +46,7 @@ def test_strip_html_empty_returns_empty(): # load_credentials # --------------------------------------------------------------------------- + @pytest.mark.p2 def test_load_credentials_missing_fields_raises(): connector = OutlookConnector() @@ -90,6 +92,7 @@ def test_load_credentials_msal_failure_raises(): # validate_connector_settings # --------------------------------------------------------------------------- + @pytest.mark.p2 def test_validate_without_credentials_raises(): connector = OutlookConnector() @@ -170,6 +173,7 @@ def test_validate_5xx_raises_unexpected(): # Checkpoint helpers # --------------------------------------------------------------------------- + @pytest.mark.p2 def test_build_dummy_checkpoint(): connector = OutlookConnector() @@ -190,6 +194,7 @@ def test_validate_checkpoint_json_invalid_returns_dummy(): # _list_user_ids # --------------------------------------------------------------------------- + @pytest.mark.p2 def test_list_user_ids_returns_configured_ids(): connector = OutlookConnector(user_ids=["a@x.com", "b@x.com"]) @@ -224,11 +229,10 @@ def test_list_user_ids_paginates_when_unset(): # _iter_documents (via poll_source) # --------------------------------------------------------------------------- + @pytest.mark.p1 def test_poll_source_yields_messages(): - connector = OutlookConnector( - batch_size=10, user_ids=["alice@example.com"] - ) + connector = OutlookConnector(batch_size=10, user_ids=["alice@example.com"]) connector._access_token = "tok" delta_resp = MagicMock(ok=True) @@ -240,12 +244,8 @@ def test_poll_source_yields_messages(): "body": {"contentType": "text", "content": "Body text"}, "receivedDateTime": "2026-05-20T10:00:00Z", "webLink": "https://outlook.office.com/mail/1", - "from": { - "emailAddress": {"name": "Bob", "address": "bob@example.com"} - }, - "toRecipients": [ - {"emailAddress": {"address": "alice@example.com"}} - ], + "from": {"emailAddress": {"name": "Bob", "address": "bob@example.com"}}, + "toRecipients": [{"emailAddress": {"address": "alice@example.com"}}], "ccRecipients": [], "hasAttachments": False, "conversationId": "conv-1", @@ -268,9 +268,7 @@ def test_poll_source_yields_messages(): @pytest.mark.p2 def test_poll_source_filters_old_messages(): - connector = OutlookConnector( - batch_size=10, user_ids=["alice@example.com"] - ) + connector = OutlookConnector(batch_size=10, user_ids=["alice@example.com"]) connector._access_token = "tok" delta_resp = MagicMock(ok=True) @@ -293,9 +291,7 @@ def test_poll_source_filters_old_messages(): @pytest.mark.p2 def test_poll_source_skips_removed_messages(): - connector = OutlookConnector( - batch_size=10, user_ids=["alice@example.com"] - ) + connector = OutlookConnector(batch_size=10, user_ids=["alice@example.com"]) connector._access_token = "tok" delta_resp = MagicMock(ok=True) @@ -320,9 +316,7 @@ def test_poll_source_skips_removed_messages(): @pytest.mark.p2 def test_poll_source_html_body_is_stripped(): - connector = OutlookConnector( - batch_size=10, user_ids=["alice@example.com"] - ) + connector = OutlookConnector(batch_size=10, user_ids=["alice@example.com"]) connector._access_token = "tok" delta_resp = MagicMock(ok=True) @@ -352,6 +346,7 @@ def test_poll_source_html_body_is_stripped(): # Non-2xx Graph responses must raise (no silent partial syncs) # --------------------------------------------------------------------------- + def _ok(json_value): resp = MagicMock(ok=True, status_code=200) resp.json.return_value = json_value @@ -399,6 +394,7 @@ def test_list_user_ids_raises_on_http_error(): # retrieve_all_slim_docs_perm_sync: yields list[SlimDocument] for prune # --------------------------------------------------------------------------- + @pytest.mark.p1 def test_retrieve_slim_docs_yields_slimdocument_batches(): """The prune collector calls file_list.extend(batch) and reads `.id` on @@ -407,13 +403,15 @@ def test_retrieve_slim_docs_yields_slimdocument_batches(): connector = OutlookConnector(batch_size=2, user_ids=["alice@example.com"]) connector._access_token = "tok" - delta_resp = _ok({ - "value": [ - {"id": "m1", "subject": "a"}, - {"id": "m2", "subject": "b"}, - {"id": "m3", "subject": "c"}, - ], - }) + delta_resp = _ok( + { + "value": [ + {"id": "m1", "subject": "a"}, + {"id": "m2", "subject": "b"}, + {"id": "m3", "subject": "c"}, + ], + } + ) with patch.object(connector, "_get", return_value=delta_resp): batches = list(connector.retrieve_all_slim_docs_perm_sync()) @@ -430,12 +428,14 @@ def test_retrieve_slim_docs_skips_removed(): connector = OutlookConnector(batch_size=10, user_ids=["alice@example.com"]) connector._access_token = "tok" - delta_resp = _ok({ - "value": [ - {"id": "del", "@removed": {"reason": "deleted"}}, - {"id": "keep", "subject": "kept"}, - ], - }) + delta_resp = _ok( + { + "value": [ + {"id": "del", "@removed": {"reason": "deleted"}}, + {"id": "keep", "subject": "kept"}, + ], + } + ) with patch.object(connector, "_get", return_value=delta_resp): batches = list(connector.retrieve_all_slim_docs_perm_sync()) flat = [item for batch in batches for item in batch] @@ -462,6 +462,7 @@ def test_retrieve_slim_docs_requires_credentials(): # load_from_checkpoint: resumes from delta_links # --------------------------------------------------------------------------- + @pytest.mark.p1 def test_load_from_checkpoint_uses_persisted_delta_link(): """With a delta_link for a user the connector must hit that URL — not @@ -490,6 +491,7 @@ def test_load_from_checkpoint_uses_persisted_delta_link(): # _redact: keep debugging hint, drop PII # --------------------------------------------------------------------------- + @pytest.mark.p3 def test_redact_email_masks_local_and_domain(): assert _redact("alice@example.com") == "al***@***" diff --git a/test/unit_test/data_source/test_rest_api_connector.py b/test/unit_test/data_source/test_rest_api_connector.py index 8ae1d2f769..6b93dcf509 100644 --- a/test/unit_test/data_source/test_rest_api_connector.py +++ b/test/unit_test/data_source/test_rest_api_connector.py @@ -53,13 +53,17 @@ def _mocked_rest_api_requests_and_dns(): that wrap `requests.get` / `requests.post` and avoid retry backoff delays. """ mock_rl = MagicMock() - with patch( - "common.data_source.rest_api_connector.socket.getaddrinfo", - return_value=_MOCK_DNS_ADDRINFO, - ), patch.object(_ds_utils._RateLimitedRequest, "get", mock_rl.get), patch.object( - _ds_utils._RateLimitedRequest, - "post", - mock_rl.post, + with ( + patch( + "common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=_MOCK_DNS_ADDRINFO, + ), + patch.object(_ds_utils._RateLimitedRequest, "get", mock_rl.get), + patch.object( + _ds_utils._RateLimitedRequest, + "post", + mock_rl.post, + ), ): yield mock_rl @@ -108,6 +112,7 @@ def _mock_response(json_data, status_code=200): # 1. Config schema validation # # ===================================================================== # + class TestRestAPIConfig: """Test Pydantic RestAPIConnectorConfig schema validation.""" @@ -143,9 +148,7 @@ class TestRestAPIConfig: def test_string_to_dict_coercion_for_headers(self): """A key=value string should be coerced to a dict.""" - cfg = RestAPIConnectorConfig( - url=VALID_URL, content_fields=["t"], headers="X-Custom=hello" - ) + cfg = RestAPIConnectorConfig(url=VALID_URL, content_fields=["t"], headers="X-Custom=hello") assert cfg.headers == {"X-Custom": "hello"} def test_string_to_list_coercion_for_content_fields(self): @@ -158,6 +161,7 @@ class TestRestAPIConfig: # 2. SSRF URL validation # # ===================================================================== # + class TestSSRFValidation: """Test that unsafe URLs are blocked before any HTTP request is made.""" @@ -249,6 +253,7 @@ class TestSSRFValidation: # redirect target. import common.data_source.rest_api_connector as rc_module from unittest.mock import patch as _patch + with _patch.object(rc_module.socket, "getaddrinfo", side_effect=_dns_for_host): # Coderabbit MAJOR #3486038795: SSRF validation failures inside # _safe_request are now wrapped to raise ConnectorValidationError @@ -305,11 +310,11 @@ class TestSSRFValidation: # 3. Authentication setup # # ===================================================================== # + class TestAuthSetup: """Test _build_auth produces the correct headers / auth objects.""" - @patch("common.data_source.rest_api_connector.socket.getaddrinfo", - return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) def test_auth_none(self, _dns): """auth_type=none should produce no auth headers.""" c = _make_connector(auth_type=AuthType.NONE) @@ -317,8 +322,7 @@ class TestAuthSetup: assert c._auth_headers == {} assert c._basic_auth is None - @patch("common.data_source.rest_api_connector.socket.getaddrinfo", - return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) def test_api_key_header(self, _dns): """api_key_header should set the specified header.""" c = _make_connector( @@ -328,16 +332,14 @@ class TestAuthSetup: c.load_credentials({"api_key": "secret123"}) assert c._auth_headers == {"X-API-Key": "secret123"} - @patch("common.data_source.rest_api_connector.socket.getaddrinfo", - return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) def test_bearer_token(self, _dns): """bearer should set Authorization: Bearer .""" c = _make_connector(auth_type=AuthType.BEARER) c.load_credentials({"token": "tok_abc"}) assert c._auth_headers == {"Authorization": "Bearer tok_abc"} - @patch("common.data_source.rest_api_connector.socket.getaddrinfo", - return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) def test_basic_auth(self, _dns): """basic should produce an HTTPBasicAuth object.""" c = _make_connector(auth_type=AuthType.BASIC) @@ -351,14 +353,13 @@ class TestAuthSetup: # 4. Field extraction # # ===================================================================== # + class TestFieldExtraction: """Test _extract_field / _extract_field_values dot-notation paths.""" - @patch("common.data_source.rest_api_connector.socket.getaddrinfo", - return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) def setup_method(self, method, _dns=None): - with patch("common.data_source.rest_api_connector.socket.getaddrinfo", - return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]): + with patch("common.data_source.rest_api_connector.socket.getaddrinfo", return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]): self.connector = _make_connector() def test_simple_field(self): @@ -382,8 +383,7 @@ class TestFieldExtraction: def test_missing_field_with_default(self): """Missing field returns configured default value.""" - with patch("common.data_source.rest_api_connector.socket.getaddrinfo", - return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]): + with patch("common.data_source.rest_api_connector.socket.getaddrinfo", return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]): c = _make_connector(field_default_values={"missing": "fallback"}) result = c._get_typed_field_value("missing", {"other": 1}) assert result == "fallback" @@ -398,14 +398,13 @@ class TestFieldExtraction: # 5. Items array detection # # ===================================================================== # + class TestItemsArrayDetection: """Test _extract_items auto-detection of the items array.""" - @patch("common.data_source.rest_api_connector.socket.getaddrinfo", - return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) def setup_method(self, method, _dns=None): - with patch("common.data_source.rest_api_connector.socket.getaddrinfo", - return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]): + with patch("common.data_source.rest_api_connector.socket.getaddrinfo", return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]): self.connector = _make_connector() def test_items_key(self): @@ -451,6 +450,7 @@ class TestItemsArrayDetection: # 6. HTML stripping # # ===================================================================== # + class TestHTMLStripping: """Test the _strip_html static method.""" @@ -485,14 +485,13 @@ class TestHTMLStripping: # 7. Document creation # # ===================================================================== # + class TestDocumentCreation: """Test _item_to_document mapping.""" - @patch("common.data_source.rest_api_connector.socket.getaddrinfo", - return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) def setup_method(self, method, _dns=None): - with patch("common.data_source.rest_api_connector.socket.getaddrinfo", - return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]): + with patch("common.data_source.rest_api_connector.socket.getaddrinfo", return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]): self.connector = _make_connector( id_field="id", content_fields=["title", "body"], @@ -551,6 +550,7 @@ class TestDocumentCreation: # 8. Pagination behaviour # # ===================================================================== # + class TestPaginationBehavior: """Test pagination iteration with mocked HTTP responses.""" @@ -611,9 +611,7 @@ class TestPaginationBehavior: def test_max_pages_cap(self): """Pagination respects the max_pages safety cap.""" with _mocked_rest_api_requests_and_dns() as mock_rl: - mock_rl.get.return_value = _mock_response( - {"items": [{"title": "A"}, {"title": "B"}]} - ) + mock_rl.get.return_value = _mock_response({"items": [{"title": "A"}, {"title": "B"}]}) c = _make_paged_connector( max_pages=3, @@ -642,6 +640,7 @@ class TestPaginationBehavior: # 9. Non-retriable HTTP errors # # ===================================================================== # + class TestNonRetriableErrors: """Test that HTTP errors are classified correctly in _fetch_page.""" diff --git a/test/unit_test/data_source/test_sharepoint_connector_unit.py b/test/unit_test/data_source/test_sharepoint_connector_unit.py index ed12b6714b..a131d6e8a7 100644 --- a/test/unit_test/data_source/test_sharepoint_connector_unit.py +++ b/test/unit_test/data_source/test_sharepoint_connector_unit.py @@ -26,11 +26,7 @@ def _load_sharepoint_connector_module(): """Load sharepoint_connector.py in isolation (avoid the package __init__).""" repo_root = Path(__file__).resolve().parents[3] package_name = "common.data_source" - saved_modules = { - name: module - for name, module in sys.modules.items() - if name == package_name or name.startswith(f"{package_name}.") - } + saved_modules = {name: module for name, module in sys.modules.items() if name == package_name or name.startswith(f"{package_name}.")} package_stub = ModuleType(package_name) package_stub.__path__ = [str(repo_root / "common" / "data_source")] sys.modules[package_name] = package_stub @@ -214,9 +210,7 @@ def _collect(generator): def test_load_from_checkpoint_walks_libraries_and_downloads(): connector, _jan, _feb = _build_connector_with_tree() - docs, checkpoint = _collect( - connector.load_from_checkpoint(0.0, 9e12, connector.build_dummy_checkpoint()) - ) + docs, checkpoint = _collect(connector.load_from_checkpoint(0.0, 9e12, connector.build_dummy_checkpoint())) assert checkpoint.has_more is False assert {doc.id for doc in docs} == {"drv-A:f1", "drv-A:f2"} @@ -240,9 +234,7 @@ def test_load_from_checkpoint_filters_by_modified_window(): start = datetime(2026, 1, 15, tzinfo=timezone.utc).timestamp() end = datetime(2026, 3, 1, tzinfo=timezone.utc).timestamp() - docs, _ = _collect( - connector.load_from_checkpoint(start, end, connector.build_dummy_checkpoint()) - ) + docs, _ = _collect(connector.load_from_checkpoint(start, end, connector.build_dummy_checkpoint())) assert [doc.id for doc in docs] == ["drv-A:f2"] @@ -274,9 +266,7 @@ def test_document_ids_are_unique_across_drives_with_colliding_item_ids(): connector.graph_client = _FakeGraphClient(site) connector._site_url = "https://contoso.sharepoint.com/sites/MySite" - docs, _ = _collect( - connector.load_from_checkpoint(0.0, 9e12, connector.build_dummy_checkpoint()) - ) + docs, _ = _collect(connector.load_from_checkpoint(0.0, 9e12, connector.build_dummy_checkpoint())) ids = {doc.id for doc in docs} assert ids == {"drv-A:same-id", "drv-B:same-id"} diff --git a/test/unit_test/data_source/test_slack_connector_unit.py b/test/unit_test/data_source/test_slack_connector_unit.py index 56a9551e65..ad8c963e17 100644 --- a/test/unit_test/data_source/test_slack_connector_unit.py +++ b/test/unit_test/data_source/test_slack_connector_unit.py @@ -30,11 +30,7 @@ def _load_slack_connector_module(): """ repo_root = Path(__file__).resolve().parents[3] package_name = "common.data_source" - saved_modules = { - name: module - for name, module in sys.modules.items() - if name == package_name or name.startswith(f"{package_name}.") - } + saved_modules = {name: module for name, module in sys.modules.items() if name == package_name or name.startswith(f"{package_name}.")} package_stub = ModuleType(package_name) package_stub.__path__ = [str(repo_root / "common" / "data_source")] sys.modules[package_name] = package_stub diff --git a/test/unit_test/data_source/test_teams_connector_unit.py b/test/unit_test/data_source/test_teams_connector_unit.py index e03dfc5f60..c9bac0d7cf 100644 --- a/test/unit_test/data_source/test_teams_connector_unit.py +++ b/test/unit_test/data_source/test_teams_connector_unit.py @@ -26,11 +26,7 @@ def _load_teams_connector_module(): """Load teams_connector.py in isolation (avoid the package __init__).""" repo_root = Path(__file__).resolve().parents[3] package_name = "common.data_source" - saved_modules = { - name: module - for name, module in sys.modules.items() - if name == package_name or name.startswith(f"{package_name}.") - } + saved_modules = {name: module for name, module in sys.modules.items() if name == package_name or name.startswith(f"{package_name}.")} package_stub = ModuleType(package_name) package_stub.__path__ = [str(repo_root / "common" / "data_source")] sys.modules[package_name] = package_stub @@ -175,9 +171,7 @@ def test_load_credentials_sets_graph_client(monkeypatch): monkeypatch.setattr(teams_connector, "GraphClient", lambda token_callback: SimpleNamespace(cb=token_callback)) connector = TeamsConnector() - result = connector.load_credentials( - {"tenant_id": "tenant", "client_id": "client", "client_secret": "secret"} - ) + result = connector.load_credentials({"tenant_id": "tenant", "client_id": "client", "client_secret": "secret"}) assert result is None assert connector.graph_client is not None @@ -217,9 +211,7 @@ def test_validate_maps_permission_error(): def test_load_from_checkpoint_flattens_posts_and_replies(): connector = _build_connector() - docs, checkpoint = _collect( - connector.load_from_checkpoint(0.0, 9e12, connector.build_dummy_checkpoint()) - ) + docs, checkpoint = _collect(connector.load_from_checkpoint(0.0, 9e12, connector.build_dummy_checkpoint())) assert checkpoint.has_more is False assert {doc.id for doc in docs} == {"t1__c1__m1", "t1__c1__m2"} @@ -245,9 +237,7 @@ def test_load_from_checkpoint_filters_by_modified_window(): start = datetime(2026, 1, 15, tzinfo=timezone.utc).timestamp() end = datetime(2026, 3, 1, tzinfo=timezone.utc).timestamp() - docs, _ = _collect( - connector.load_from_checkpoint(start, end, connector.build_dummy_checkpoint()) - ) + docs, _ = _collect(connector.load_from_checkpoint(start, end, connector.build_dummy_checkpoint())) assert [doc.id for doc in docs] == ["t1__c1__m2"] diff --git a/test/unit_test/data_source/test_webdav_connector_unit.py b/test/unit_test/data_source/test_webdav_connector_unit.py index c7651d058c..fe6dd9c4b9 100644 --- a/test/unit_test/data_source/test_webdav_connector_unit.py +++ b/test/unit_test/data_source/test_webdav_connector_unit.py @@ -256,10 +256,7 @@ def test_yield_webdav_documents_skips_missing_size_metadata(caplog): "webdav:https://webdav.example:/small.txt", ] assert connector.client.downloaded_paths == ["/small.txt"] - assert ( - "missing.txt: size metadata missing from WebDAV server response, " - "skipping to avoid processing potentially large files." - ) in caplog.text + assert ("missing.txt: size metadata missing from WebDAV server response, skipping to avoid processing potentially large files.") in caplog.text @pytest.mark.p1 @@ -301,7 +298,4 @@ def test_retrieve_all_slim_docs_skips_missing_size_metadata(caplog): assert [doc.id for batch in batches for doc in batch] == [ "webdav:https://webdav.example:/small.txt", ] - assert ( - "missing.txt: size metadata missing from WebDAV server response, " - "skipping to avoid processing potentially large files." - ) in caplog.text + assert ("missing.txt: size metadata missing from WebDAV server response, skipping to avoid processing potentially large files.") in caplog.text diff --git a/test/unit_test/deepdoc/parser/test_markdown_parser.py b/test/unit_test/deepdoc/parser/test_markdown_parser.py index e9c8a92413..a649307e7c 100644 --- a/test/unit_test/deepdoc/parser/test_markdown_parser.py +++ b/test/unit_test/deepdoc/parser/test_markdown_parser.py @@ -194,6 +194,7 @@ class TestMarkdownTableDedup: assert "After" in sections[0] assert "| Name | Value |" not in sections[0] + class TestMarkdownElementExtractorDelimiterHeaders: def test_custom_delimiter_merges_consecutive_lone_headers_with_body(self, markdown_element_extractor): text = "# Title\n## Intro\nBody paragraph" diff --git a/test/unit_test/deepdoc/parser/test_opendataloader_parser.py b/test/unit_test/deepdoc/parser/test_opendataloader_parser.py index 98416a77c4..68c63a256e 100644 --- a/test/unit_test/deepdoc/parser/test_opendataloader_parser.py +++ b/test/unit_test/deepdoc/parser/test_opendataloader_parser.py @@ -30,8 +30,12 @@ for _m in ("pdfplumber", "PIL", "PIL.Image"): # deepdoc.parser.pdf_parser — provide a real base class so OpenDataLoaderParser # inherits a proper Python class, not a MagicMock (which breaks __init__). _pdf_parser_mod = _types.ModuleType("deepdoc.parser.pdf_parser") + + class _RAGFlowPdfParserStub: # noqa: E302 pass + + _pdf_parser_mod.RAGFlowPdfParser = _RAGFlowPdfParserStub sys.modules.setdefault("deepdoc.parser.pdf_parser", _pdf_parser_mod) sys.modules.setdefault("deepdoc", mock.MagicMock()) @@ -60,6 +64,7 @@ OpenDataLoaderParser = _mod.OpenDataLoaderParser # Helpers # --------------------------------------------------------------------------- + def _make_parser(api_url: str = "http://odl:9383") -> OpenDataLoaderParser: p = OpenDataLoaderParser() p.api_url = api_url @@ -78,6 +83,7 @@ def _fake_page_image(width: int = 600, height: int = 800): # check_installation() # --------------------------------------------------------------------------- + class TestCheckInstallation: def test_no_api_url_returns_false(self): p = OpenDataLoaderParser() @@ -106,6 +112,7 @@ class TestCheckInstallation: # parse_pdf() # --------------------------------------------------------------------------- + class TestParsePdf: def _mock_response(self, json_doc=None, md_text=None) -> mock.MagicMock: resp = mock.MagicMock() @@ -127,8 +134,7 @@ class TestParsePdf: pdf.write_bytes(b"%PDF-dummy") resp = self._mock_response(md_text="hello world") - with mock.patch.object(p, "__images__"), \ - mock.patch("requests.post", return_value=resp) as mock_post: + with mock.patch.object(p, "__images__"), mock.patch("requests.post", return_value=resp) as mock_post: p.parse_pdf(filepath=str(pdf)) mock_post.assert_called_once() @@ -140,8 +146,7 @@ class TestParsePdf: pdf_bytes = b"%PDF-binary" resp = self._mock_response(md_text="section text") - with mock.patch.object(p, "__images__"), \ - mock.patch("requests.post", return_value=resp) as mock_post: + with mock.patch.object(p, "__images__"), mock.patch("requests.post", return_value=resp) as mock_post: p.parse_pdf(filepath="file.pdf", binary=pdf_bytes) files_arg = mock_post.call_args.kwargs.get("files", {}) @@ -155,8 +160,7 @@ class TestParsePdf: pdf_bytes = b"%PDF-bytesio" resp = self._mock_response(md_text="text from bytesio") - with mock.patch.object(p, "__images__"), \ - mock.patch("requests.post", return_value=resp) as mock_post: + with mock.patch.object(p, "__images__"), mock.patch("requests.post", return_value=resp) as mock_post: p.parse_pdf(filepath="file.pdf", binary=io.BytesIO(pdf_bytes)) files_arg = mock_post.call_args.kwargs.get("files", {}) @@ -173,8 +177,7 @@ class TestParsePdf: } resp = self._mock_response(json_doc=json_doc) - with mock.patch.object(p, "__images__"), \ - mock.patch("requests.post", return_value=resp): + with mock.patch.object(p, "__images__"), mock.patch("requests.post", return_value=resp): sections, tables = p.parse_pdf(filepath="doc.pdf", binary=b"%PDF", parse_method="pipeline") assert any("Hello from JSON" in s[0] for s in sections) @@ -183,8 +186,7 @@ class TestParsePdf: p = _make_parser() resp = self._mock_response(json_doc=None, md_text="# Markdown heading\n\nBody text.") - with mock.patch.object(p, "__images__"), \ - mock.patch("requests.post", return_value=resp): + with mock.patch.object(p, "__images__"), mock.patch("requests.post", return_value=resp): sections, tables = p.parse_pdf(filepath="doc.pdf", binary=b"%PDF", parse_method="pipeline") assert len(sections) > 0 @@ -194,8 +196,7 @@ class TestParsePdf: p = _make_parser() resp = self._mock_response(md_text="ok") - with mock.patch.object(p, "__images__"), \ - mock.patch("requests.post", return_value=resp) as mock_post: + with mock.patch.object(p, "__images__"), mock.patch("requests.post", return_value=resp) as mock_post: p.parse_pdf(filepath="doc.pdf", binary=b"%PDF", sanitize=True) data_arg = mock_post.call_args.kwargs.get("data", {}) @@ -205,8 +206,7 @@ class TestParsePdf: p = _make_parser() resp = self._mock_response(md_text="ok") - with mock.patch.object(p, "__images__"), \ - mock.patch("requests.post", return_value=resp) as mock_post: + with mock.patch.object(p, "__images__"), mock.patch("requests.post", return_value=resp) as mock_post: p.parse_pdf(filepath="doc.pdf", binary=b"%PDF", sanitize=False) data_arg = mock_post.call_args.kwargs.get("data", {}) @@ -216,10 +216,8 @@ class TestParsePdf: p = _make_parser() resp = self._mock_response(md_text="ok") - with mock.patch.object(p, "__images__"), \ - mock.patch("requests.post", return_value=resp) as mock_post: - p.parse_pdf(filepath="doc.pdf", binary=b"%PDF", - hybrid="docling-fast", image_output="embedded") + with mock.patch.object(p, "__images__"), mock.patch("requests.post", return_value=resp) as mock_post: + p.parse_pdf(filepath="doc.pdf", binary=b"%PDF", hybrid="docling-fast", image_output="embedded") data_arg = mock_post.call_args.kwargs.get("data", {}) assert data_arg.get("hybrid") == "docling-fast" @@ -229,8 +227,7 @@ class TestParsePdf: p = _make_parser() resp = self._mock_response(md_text="ok") - with mock.patch.object(p, "__images__"), \ - mock.patch("requests.post", return_value=resp) as mock_post: + with mock.patch.object(p, "__images__"), mock.patch("requests.post", return_value=resp) as mock_post: p.parse_pdf(filepath="doc.pdf", binary=b"%PDF") data_arg = mock_post.call_args.kwargs.get("data", {}) @@ -243,8 +240,7 @@ class TestParsePdf: resp = self._mock_response(md_text="text") cb = mock.MagicMock() - with mock.patch.object(p, "__images__"), \ - mock.patch("requests.post", return_value=resp): + with mock.patch.object(p, "__images__"), mock.patch("requests.post", return_value=resp): p.parse_pdf(filepath="doc.pdf", binary=b"%PDF", callback=cb) progress_values = [call.args[0] for call in cb.call_args_list] @@ -254,8 +250,7 @@ class TestParsePdf: def test_http_error_raises_runtime_error(self): p = _make_parser() - with mock.patch.object(p, "__images__"), \ - mock.patch("requests.post", side_effect=requests.ConnectionError("down")): + with mock.patch.object(p, "__images__"), mock.patch("requests.post", side_effect=requests.ConnectionError("down")): with pytest.raises(RuntimeError, match="service call failed"): p.parse_pdf(filepath="doc.pdf", binary=b"%PDF") @@ -264,8 +259,7 @@ class TestParsePdf: resp = mock.MagicMock() resp.raise_for_status.side_effect = requests.HTTPError("500 Server Error") - with mock.patch.object(p, "__images__"), \ - mock.patch("requests.post", return_value=resp): + with mock.patch.object(p, "__images__"), mock.patch("requests.post", return_value=resp): with pytest.raises(RuntimeError, match="service call failed"): p.parse_pdf(filepath="doc.pdf", binary=b"%PDF") @@ -274,6 +268,7 @@ class TestParsePdf: # crop() — bounds guard # --------------------------------------------------------------------------- + class TestCrop: def test_returns_none_when_no_page_images(self): p = _make_parser() @@ -307,8 +302,7 @@ class TestCrop: canvas = mock.MagicMock() canvas.paste = mock.MagicMock() try: - with mock.patch.object(_mod.Image, "new", return_value=canvas), \ - mock.patch.object(_mod.Image, "alpha_composite", return_value=img): + with mock.patch.object(_mod.Image, "new", return_value=canvas), mock.patch.object(_mod.Image, "alpha_composite", return_value=img): p.crop(tag) except IndexError: pytest.fail("crop() raised IndexError for a valid page index") diff --git a/test/unit_test/deepdoc/parser/test_pdf_garbled_detection.py b/test/unit_test/deepdoc/parser/test_pdf_garbled_detection.py index fa7c4a8b76..43c572e2eb 100644 --- a/test/unit_test/deepdoc/parser/test_pdf_garbled_detection.py +++ b/test/unit_test/deepdoc/parser/test_pdf_garbled_detection.py @@ -35,18 +35,36 @@ from unittest import mock # We mock the heavy third-party modules so that pdf_parser.py can be loaded # purely for its static detection methods. _MOCK_MODULES = [ - "numpy", "np", "pdfplumber", "xgboost", "xgb", - "huggingface_hub", "PIL", "PIL.Image", "pypdf", - "sklearn", "sklearn.cluster", "sklearn.metrics", - "common", "common.file_utils", "common.misc_utils", "common.settings", + "numpy", + "np", + "pdfplumber", + "xgboost", + "xgb", + "huggingface_hub", + "PIL", + "PIL.Image", + "pypdf", + "sklearn", + "sklearn.cluster", + "sklearn.metrics", + "common", + "common.file_utils", + "common.misc_utils", + "common.settings", "common.token_utils", - "deepdoc", "deepdoc.vision", "deepdoc.parser", - "rag", "rag.nlp", "rag.prompts", "rag.prompts.generator", + "deepdoc", + "deepdoc.vision", + "deepdoc.parser", + "rag", + "rag.nlp", + "rag.prompts", + "rag.prompts.generator", ] for _m in _MOCK_MODULES: if _m not in sys.modules: sys.modules[_m] = mock.MagicMock() + def _find_project_root(marker="pyproject.toml"): """Walk up from this file until a directory containing *marker* is found.""" cur = os.path.dirname(os.path.abspath(__file__)) @@ -96,15 +114,15 @@ class TestIsGarbledChar: assert is_garbled_char(ch) is False def test_common_whitespace_not_garbled(self): - assert is_garbled_char('\t') is False - assert is_garbled_char('\n') is False - assert is_garbled_char('\r') is False - assert is_garbled_char(' ') is False + assert is_garbled_char("\t") is False + assert is_garbled_char("\n") is False + assert is_garbled_char("\r") is False + assert is_garbled_char(" ") is False def test_pua_chars_are_garbled(self): - assert is_garbled_char('\uE000') is True - assert is_garbled_char('\uF000') is True - assert is_garbled_char('\uF8FF') is True + assert is_garbled_char("\ue000") is True + assert is_garbled_char("\uf000") is True + assert is_garbled_char("\uf8ff") is True def test_supplementary_pua_a(self): assert is_garbled_char(chr(0xF0000)) is True @@ -115,20 +133,20 @@ class TestIsGarbledChar: assert is_garbled_char(chr(0x10FFFF)) is True def test_replacement_char(self): - assert is_garbled_char('\uFFFD') is True + assert is_garbled_char("\ufffd") is True def test_c0_control_chars(self): - assert is_garbled_char('\x00') is True - assert is_garbled_char('\x01') is True - assert is_garbled_char('\x1F') is True + assert is_garbled_char("\x00") is True + assert is_garbled_char("\x01") is True + assert is_garbled_char("\x1f") is True def test_c1_control_chars(self): - assert is_garbled_char('\x80') is True - assert is_garbled_char('\x8F') is True - assert is_garbled_char('\x9F') is True + assert is_garbled_char("\x80") is True + assert is_garbled_char("\x8f") is True + assert is_garbled_char("\x9f") is True def test_empty_string(self): - assert is_garbled_char('') is False + assert is_garbled_char("") is False def test_common_punctuation(self): for ch in ".,;:!?()[]{}\"'-/\\@#$%^&*+=<>~`|": @@ -164,15 +182,15 @@ class TestIsGarbledText: assert is_garbled_text(None) is False def test_all_pua_chars(self): - text = "\uE000\uE001\uE002\uE003\uE004" + text = "\ue000\ue001\ue002\ue003\ue004" assert is_garbled_text(text) is True def test_mostly_garbled(self): - text = "\uE000\uE001\uE002好" + text = "\ue000\ue001\ue002好" assert is_garbled_text(text, threshold=0.5) is True def test_few_garbled_below_threshold(self): - text = "这是正常文本\uE000" + text = "这是正常文本\ue000" assert is_garbled_text(text, threshold=0.5) is False def test_cid_pattern_detected(self): @@ -187,23 +205,23 @@ class TestIsGarbledText: assert is_garbled_text(" \t\n ") is False def test_custom_threshold(self): - text = "\uE000正常" + text = "\ue000正常" assert is_garbled_text(text, threshold=0.3) is True assert is_garbled_text(text, threshold=0.5) is False def test_replacement_chars_in_text(self): - text = "文档\uFFFD\uFFFD解析" + text = "文档\ufffd\ufffd解析" assert is_garbled_text(text, threshold=0.5) is False assert is_garbled_text(text, threshold=0.3) is True def test_real_world_garbled_pattern(self): - text = "\uE000\uE001\uE002\uE003\uE004\uE005\uE006\uE007" + text = "\ue000\ue001\ue002\ue003\ue004\ue005\ue006\ue007" assert is_garbled_text(text) is True def test_mixed_garbled_and_normal_at_boundary(self): - text = "AB\uE000\uE001" + text = "AB\ue000\ue001" assert is_garbled_text(text, threshold=0.5) is True - text2 = "ABC\uE000" + text2 = "ABC\ue000" assert is_garbled_text(text2, threshold=0.5) is False @@ -263,7 +281,7 @@ class TestIsGarbledByFontEncoding: def test_ascii_punct_from_subset_font_is_garbled(self): """Simulates GB.18067-2000.pdf: all chars are ASCII punct from subset fonts.""" chars = _make_chars( - list('!"#$%&\'(\'&)\'"*$!"#$%&\'\'()*+,$-'), + list("!\"#$%&'('&)'\"*$!\"#$%&''()*+,$-"), fontname="DY1+ZLQDm1-1", ) assert is_garbled_by_font_encoding(chars) is True @@ -287,7 +305,7 @@ class TestIsGarbledByFontEncoding: def test_non_subset_font_not_flagged(self): """ASCII punct from non-subset fonts should not be flagged.""" chars = _make_chars( - list('!"#$%&\'()*+,-./!"#$%&\'()*+,-./'), + list("!\"#$%&'()*+,-./!\"#$%&'()*+,-./"), fontname="Arial", ) assert is_garbled_by_font_encoding(chars) is False @@ -315,13 +333,13 @@ class TestIsGarbledByFontEncoding: def test_real_world_gb18067_page1(self): """Simulate actual GB.18067-2000.pdf Page 1 character distribution.""" - page_text = '!"#$%&\'(\'&)\'"*$!"#$%&\'\'()*+,$-' + page_text = "!\"#$%&'('&)'\"*$!\"#$%&''()*+,$-" chars = _make_chars(list(page_text), fontname="DY1+ZLQDm1-1") assert is_garbled_by_font_encoding(chars) is True def test_real_world_gb18067_page3(self): """Simulate actual GB.18067-2000.pdf Page 3 character distribution.""" - page_text = '!"#$%&\'()*+,-.*+/0+123456789:;<' + page_text = "!\"#$%&'()*+,-.*+/0+123456789:;<" chars = _make_chars(list(page_text), fontname="DY1+ZLQDnC-1") assert is_garbled_by_font_encoding(chars) is True @@ -342,14 +360,14 @@ class TestIsGarbledByFontEncoding: def test_boundary_cjk_ratio(self): """Just below 5% CJK threshold should still be flagged.""" # 1 CJK out of 25 chars = 4% CJK, rest are punct from subset font - chars = _make_chars(list('!"#$%&\'()*+,-./!@#$%^&*'), fontname="DY1+Font") + chars = _make_chars(list("!\"#$%&'()*+,-./!@#$%^&*"), fontname="DY1+Font") chars.append({"text": "中", "fontname": "DY1+Font"}) assert is_garbled_by_font_encoding(chars, min_chars=5) is True def test_boundary_above_cjk_threshold(self): """Above 5% CJK ratio should NOT be flagged.""" # 3 CJK out of 23 chars = ~13% CJK - chars = _make_chars(list('!"#$%&\'()*+,-./!@#$'), fontname="DY1+Font") + chars = _make_chars(list("!\"#$%&'()*+,-./!@#$"), fontname="DY1+Font") for ch in "中文字": chars.append({"text": ch, "fontname": "DY1+Font"}) assert is_garbled_by_font_encoding(chars, min_chars=5) is False @@ -362,14 +380,14 @@ class TestIsGarbledByFontEncoding: """ # 5 chars from subset font, 20 from normal font -> 20% subset ratio < 30% chars = _make_chars(list('!"#$%'), fontname="DY1+Font") - chars.extend(_make_chars(list('!"#$%&\'()*+,-./!@#$%'), fontname="Arial")) + chars.extend(_make_chars(list("!\"#$%&'()*+,-./!@#$%"), fontname="Arial")) assert is_garbled_by_font_encoding(chars, min_chars=5) is False def test_high_subset_ratio_flagged(self): """When most chars come from subset fonts, detection should trigger.""" # All 30 chars from subset font with punct -> garbled chars = _make_chars( - list('!"#$%&\'()*+,-./!@#$%^&*()[]{}'), + list("!\"#$%&'()*+,-./!@#$%^&*()[]{}"), fontname="BCDGEE+R0015", ) assert is_garbled_by_font_encoding(chars) is True diff --git a/test/unit_test/memory/utils/test_ob_conn_aggregation.py b/test/unit_test/memory/utils/test_ob_conn_aggregation.py index a409a5c255..abe8dd60b7 100644 --- a/test/unit_test/memory/utils/test_ob_conn_aggregation.py +++ b/test/unit_test/memory/utils/test_ob_conn_aggregation.py @@ -42,9 +42,7 @@ class TestAggregateByField: assert set(out) == {("user", 2), ("assistant", 1)} def test_single_doc_result(self): - messages = [ - {"id": "m1", "message_type_kwd": "user", "content_ltks": "x", "message_id": "msg1", "memory_id": "mem1", "status_int": 1} - ] + messages = [{"id": "m1", "message_type_kwd": "user", "content_ltks": "x", "message_id": "msg1", "memory_id": "mem1", "status_int": 1}] out = aggregate_by_field(messages, "message_type_kwd") assert out == [("user", 1)] diff --git a/test/unit_test/memory/utils/test_ob_conn_highlight.py b/test/unit_test/memory/utils/test_ob_conn_highlight.py index 99550cf011..35ff125c39 100644 --- a/test/unit_test/memory/utils/test_ob_conn_highlight.py +++ b/test/unit_test/memory/utils/test_ob_conn_highlight.py @@ -57,9 +57,7 @@ class TestGetHighlightFromMessages: assert get_highlight_from_messages(None, ["k"], "content_ltks") == {} def test_empty_keywords_returns_empty_dict(self): - assert get_highlight_from_messages( - [{"id": "m1", "content_ltks": "hello"}], [], "content_ltks" - ) == {} + assert get_highlight_from_messages([{"id": "m1", "content_ltks": "hello"}], [], "content_ltks") == {} def test_returns_id_to_highlighted_text(self): messages = [ diff --git a/test/unit_test/rag/conftest.py b/test/unit_test/rag/conftest.py index 3ca5e289e9..c11596cb5a 100644 --- a/test/unit_test/rag/conftest.py +++ b/test/unit_test/rag/conftest.py @@ -39,20 +39,14 @@ def _restore_common_data_source_package() -> None: return if not isinstance(mod, types.ModuleType) or not getattr(mod, "__path__", None): return - keys = [ - key - for key in sys.modules - if key == "common.data_source" or key.startswith("common.data_source.") - ] + keys = [key for key in sys.modules if key == "common.data_source" or key.startswith("common.data_source.")] for key in keys: del sys.modules[key] importlib.invalidate_caches() try: importlib.import_module("common.data_source") except Exception as exc: # pragma: no cover - raise ImportError( - "conftest: failed to restore real common.data_source package" - ) from exc + raise ImportError("conftest: failed to restore real common.data_source package") from exc _restore_common_data_source_package() diff --git a/test/unit_test/rag/graphrag/conftest.py b/test/unit_test/rag/graphrag/conftest.py index a980592b5d..ddb86791dd 100644 --- a/test/unit_test/rag/graphrag/conftest.py +++ b/test/unit_test/rag/graphrag/conftest.py @@ -43,7 +43,7 @@ for mod_name in _modules_to_mock: sys.modules[mod_name] = MagicMock() # Ensure `from common.connection_utils import timeout` returns a no-op decorator -sys.modules["common.connection_utils"].timeout = lambda *a, **kw: (lambda fn: fn) +sys.modules["common.connection_utils"].timeout = lambda *a, **kw: lambda fn: fn sys.modules["api.db.services.task_service"].has_canceled = lambda *_a, **_kw: False sys.modules["rag.graphrag.general.leiden"].run = lambda *_a, **_kw: {} sys.modules["rag.graphrag.general.leiden"].add_community_info2graph = lambda *_a, **_kw: None diff --git a/test/unit_test/rag/graphrag/test_graphrag_extractors.py b/test/unit_test/rag/graphrag/test_graphrag_extractors.py index 947307df02..3361e9fbbe 100644 --- a/test/unit_test/rag/graphrag/test_graphrag_extractors.py +++ b/test/unit_test/rag/graphrag/test_graphrag_extractors.py @@ -76,10 +76,7 @@ class TestCommunityReportsExtractor: async def slow_async_chat(*_args, **_kwargs): await asyncio.sleep(0.02) - return ( - '{"title":"Community","summary":"Summary","findings":[],' - '"rating":1.0,"rating_explanation":"Clear"}' - ) + return '{"title":"Community","summary":"Summary","findings":[],"rating":1.0,"rating_explanation":"Clear"}' monkeypatch.setattr(community_reports_module, "timeout", fake_timeout, raising=False) monkeypatch.setattr( diff --git a/test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py b/test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py index 075d4a65f4..7bd6b07a23 100644 --- a/test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py +++ b/test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py @@ -29,6 +29,7 @@ The fix is to default to ``None`` and copy at the call site the ``ast`` module and asserts no parameter named ``gen_conf`` ever has a mutable literal as its default. """ + import ast from pathlib import Path from typing import Union @@ -50,7 +51,7 @@ def _iter_param_defaults(func: Union[ast.FunctionDef, ast.AsyncFunctionDef]): pos_args = args.args pos_defaults = args.defaults # positional defaults are right-aligned with args - for arg, default in zip(pos_args[-len(pos_defaults):], pos_defaults): + for arg, default in zip(pos_args[-len(pos_defaults) :], pos_defaults): yield arg.arg, default for arg, default in zip(args.kwonlyargs, args.kw_defaults): if default is not None: @@ -80,11 +81,7 @@ def test_no_mutable_default_for_gen_conf(path: Path): """No function in chat_model.py / cv_model.py should declare ``gen_conf={}`` (or ``gen_conf=[]``) as a default value.""" bad = _find_mutable_gen_conf_defaults(path) - assert not bad, ( - f"{path.name} has functions declaring `gen_conf` with a mutable " - f"default: {bad}. Use `gen_conf=None` and copy with " - f"`gen_conf = dict(gen_conf or {{}})` at the top of the function." - ) + assert not bad, f"{path.name} has functions declaring `gen_conf` with a mutable default: {bad}. Use `gen_conf=None` and copy with `gen_conf = dict(gen_conf or {{}})` at the top of the function." def test_target_files_exist(): diff --git a/test/unit_test/rag/llm/test_tool_decorator.py b/test/unit_test/rag/llm/test_tool_decorator.py index d2b8439676..8d4d3cbb2b 100644 --- a/test/unit_test/rag/llm/test_tool_decorator.py +++ b/test/unit_test/rag/llm/test_tool_decorator.py @@ -20,6 +20,7 @@ on: each ``@tool`` callable carries a well-formed OpenAI function schema, required vs. optional params are derived from defaults, and the session dispatches both sync and async callables by name. """ + import asyncio import pytest @@ -99,7 +100,8 @@ def test_session_rejects_non_mapping_arguments(): def test_session_rejects_non_tool_callable(): - def plain(x): return x + def plain(x): + return x with pytest.raises(TypeError): FunctionToolSession([plain]) diff --git a/test/unit_test/rag/prompts/test_generator_sandbox.py b/test/unit_test/rag/prompts/test_generator_sandbox.py index 55095788b0..80883f62df 100644 --- a/test/unit_test/rag/prompts/test_generator_sandbox.py +++ b/test/unit_test/rag/prompts/test_generator_sandbox.py @@ -38,9 +38,7 @@ class TestJinjaSandbox: ) def test_ssti_payload_blocked(self, payload): """Verify that SSTI payloads are blocked by SandboxedEnvironment.""" - assert isinstance(PROMPT_JINJA_ENV, SandboxedEnvironment), ( - "PROMPT_JINJA_ENV must use SandboxedEnvironment to prevent SSTI" - ) + assert isinstance(PROMPT_JINJA_ENV, SandboxedEnvironment), "PROMPT_JINJA_ENV must use SandboxedEnvironment to prevent SSTI" template = PROMPT_JINJA_ENV.from_string(payload) # SandboxedEnvironment raises SecurityError, AttributeError, or UndefinedError to block SSTI attacks with pytest.raises((SecurityError, AttributeError, UndefinedError)) as exc_info: @@ -59,8 +57,6 @@ class TestJinjaSandbox: @pytest.mark.p1 def test_loop_and_conditional_rendering(self): """Verify control flow templates work properly.""" - template = PROMPT_JINJA_ENV.from_string( - "{% for item in items %}{{ item }}{% endfor %}" - ) + template = PROMPT_JINJA_ENV.from_string("{% for item in items %}{{ item }}{% endfor %}") result = template.render(items=["a", "b", "c"]) assert result == "abc" diff --git a/test/unit_test/rag/svr/task_executor_refactor/conftest.py b/test/unit_test/rag/svr/task_executor_refactor/conftest.py index 84d06cc959..f36092dec7 100644 --- a/test/unit_test/rag/svr/task_executor_refactor/conftest.py +++ b/test/unit_test/rag/svr/task_executor_refactor/conftest.py @@ -24,6 +24,7 @@ Design principles: - Use real TaskContext, TaskHandler, and service instances - Verify RecordingContext for data flow assertions """ + # ============================================================================= # TensorFlow/UMAP Import Workaround # ============================================================================= @@ -57,6 +58,7 @@ from rag.svr.task_executor_refactor.recording_context import ( # Async Limiter Fixtures # ============================================================================= + class AsyncMockLimiter: """Mock asyncio semaphore that does not actually limit.""" @@ -77,6 +79,7 @@ def mock_limiter(): # Task Dictionary Fixtures # ============================================================================= + @pytest.fixture def standard_task_dict() -> Dict[str, Any]: """Provide a minimal but complete task dict for standard chunking.""" @@ -148,6 +151,7 @@ def memory_task_dict() -> Dict[str, Any]: # TaskContext Fixtures # ============================================================================= + @pytest.fixture def task_context(standard_task_dict, mock_limiter, recording_context): """Provide a real TaskContext instance with mocked limiters.""" @@ -194,6 +198,7 @@ def canceled_task_context(standard_task_dict, mock_limiter, recording_context): # RecordingContext Fixtures # ============================================================================= + @pytest.fixture(autouse=True) def recording_context(): """Provide a fresh RecordingContext for each test. @@ -246,6 +251,7 @@ def cleanup_resources(request): # External System Mocks (Boundary Mocks) # ============================================================================= + class MockEmbeddingModel: """Mock embedding model that returns deterministic vectors.""" @@ -300,6 +306,7 @@ def mock_chat_model(): # Patching Helpers # ============================================================================= + def create_patch_embedding_model(vectors=None, vector_size=128): """Create a patcher for the embedding model binding. @@ -315,15 +322,19 @@ def create_patch_embedding_model(vectors=None, vector_size=128): mock_model.__enter__ = MagicMock(return_value=mock_model) mock_model.__exit__ = MagicMock(return_value=False) - return patch( - "rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance", - return_value=MagicMock(), - ), patch( - "rag.svr.task_executor_refactor.task_handler.LLMBundle", - return_value=mock_model, - ), patch( - "rag.svr.task_executor_refactor.task_handler.get_tenant_default_model_by_type", - return_value=MagicMock(), + return ( + patch( + "rag.svr.task_executor_refactor.task_handler.get_model_config_from_provider_instance", + return_value=MagicMock(), + ), + patch( + "rag.svr.task_executor_refactor.task_handler.LLMBundle", + return_value=mock_model, + ), + patch( + "rag.svr.task_executor_refactor.task_handler.get_tenant_default_model_by_type", + return_value=MagicMock(), + ), ) @@ -356,12 +367,14 @@ def create_patch_parser_chunking(chunks=None): If None, returns a default single chunk. """ if chunks is None: - chunks = [{ - "content_with_weight": "This is a test chunk content.", - "page_num_int": [0], - "top_int": [0], - "position_int": [0, 0, 0, 0], - }] + chunks = [ + { + "content_with_weight": "This is a test chunk content.", + "page_num_int": [0], + "top_int": [0], + "position_int": [0, 0, 0, 0], + } + ] mock_async = AsyncMock(return_value=chunks) return patch( @@ -425,16 +438,18 @@ def create_default_chunks(count: int = 2) -> List[Dict[str, Any]]: """Create default chunk dictionaries for testing.""" chunks = [] for i in range(count): - chunks.append({ - "id": f"chunk_{i}_{uuid.uuid4().hex[:6]}", - "content_with_weight": f"This is test chunk content number {i}.", - "page_num_int": [i], - "top_int": [i * 100], - "position_int": [i, 0, i + 1, 0], - "doc_id": "doc_test", - "kb_id": "kb_test", - "docnm_kwd": "test_document.pdf", - }) + chunks.append( + { + "id": f"chunk_{i}_{uuid.uuid4().hex[:6]}", + "content_with_weight": f"This is test chunk content number {i}.", + "page_num_int": [i], + "top_int": [i * 100], + "position_int": [i, 0, i + 1, 0], + "doc_id": "doc_test", + "kb_id": "kb_test", + "docnm_kwd": "test_document.pdf", + } + ) return chunks @@ -476,6 +491,7 @@ def mock_chunk_service_factory(): # Unified Mock TaskContext Factory # ============================================================================= + def make_task_context(**overrides): """Build a MagicMock TaskContext with sensible defaults for all services. @@ -543,6 +559,7 @@ def make_task_context(**overrides): # RaptorService Fixtures (kept for backward compatibility) # ============================================================================= + def create_mock_raptor_context(): """Create a mock TaskContext suitable for RaptorService tests.""" return make_task_context() @@ -558,6 +575,7 @@ def mock_raptor_context(): # Embedding Binding Patch Helper # ============================================================================= + class patch_embedding_binding: """Context manager that patches embedding model binding at the external boundary. @@ -616,6 +634,7 @@ class patch_embedding_binding: # Common mock callbacks # ============================================================================= + async def mock_thread_return_binary(func, *args, **kwargs): """Reusable mock for thread_pool_exec — returns fake binary.""" return b"fake pdf binary" @@ -630,10 +649,10 @@ async def mock_thread_return_none(func, *args, **kwargs): # Patch helpers for integration tests # ============================================================================= + def patch_get_storage_binary(): """Patch TaskHandler._get_storage_binary to return fake binary.""" - return patch("rag.svr.task_executor_refactor.task_handler.TaskHandler._get_storage_binary", - new_callable=AsyncMock, return_value=b"fake pdf binary") + return patch("rag.svr.task_executor_refactor.task_handler.TaskHandler._get_storage_binary", new_callable=AsyncMock, return_value=b"fake pdf binary") def patch_task_handler_settings(mock_settings): @@ -645,6 +664,7 @@ def patch_task_handler_settings(mock_settings): # Shared Task Dictionary Factory # ============================================================================= + def make_task_dict(**overrides): """Build a task dict with sensible defaults for integration tests. @@ -681,6 +701,7 @@ def make_task_dict(**overrides): # Shared Pipeline Mock Block for Integration Tests # ============================================================================= + class patch_pipeline_mocks: """Context manager bundling common integration-test mock blocks. diff --git a/test/unit_test/rag/svr/task_executor_refactor/test_chunk_builder.py b/test/unit_test/rag/svr/task_executor_refactor/test_chunk_builder.py index bc2eed76de..414811f2ee 100644 --- a/test/unit_test/rag/svr/task_executor_refactor/test_chunk_builder.py +++ b/test/unit_test/rag/svr/task_executor_refactor/test_chunk_builder.py @@ -30,11 +30,26 @@ from test.unit_test.rag.svr.task_executor_refactor.conftest import make_task_con class TestGetParser: """Tests for get_parser function.""" - @pytest.mark.parametrize("parser_id", [ - "naive", "general", "table", "paper", "book", - "picture", "audio", "email", "presentation", "manual", - "laws", "qa", "resume", "one", "tag", - ]) + @pytest.mark.parametrize( + "parser_id", + [ + "naive", + "general", + "table", + "paper", + "book", + "picture", + "audio", + "email", + "presentation", + "manual", + "laws", + "qa", + "resume", + "one", + "tag", + ], + ) def test_get_parser_returns_non_none(self, parser_id): """Test that get_parser returns non-None for all parser types.""" parser = get_parser(parser_id) @@ -43,6 +58,7 @@ class TestGetParser: def test_get_parser_kg(self): """Test getting kg parser (maps to naive).""" from common.constants import ParserType + parser = get_parser(ParserType.KG.value) assert parser is not None @@ -73,8 +89,10 @@ class TestRunChunking: mock_chunker = MagicMock() mock_chunker.chunk = MagicMock(return_value=[]) - with patch("rag.svr.task_executor_refactor.chunk_builder.thread_pool_exec") as mock_thread, \ - patch("rag.svr.task_executor_refactor.chunk_builder.merge_table_parser_config_from_kb") as mock_merge: + with ( + patch("rag.svr.task_executor_refactor.chunk_builder.thread_pool_exec") as mock_thread, + patch("rag.svr.task_executor_refactor.chunk_builder.merge_table_parser_config_from_kb") as mock_merge, + ): mock_thread.return_value = [] mock_merge.return_value = {"chunk_token_num": 128} await run_chunking(mock_chunker, b"binary", ctx) @@ -143,9 +161,7 @@ class TestExtractOutline: outline_data = [{"title": "Chapter 1", "page": 1}] cks = [{"__outline__": outline_data}] await extract_outline(cks, ctx) - ctx.write_interceptor.intercept.assert_called_once_with( - "DocMetadataService.update_document_metadata" - ) + ctx.write_interceptor.intercept.assert_called_once_with("DocMetadataService.update_document_metadata") @pytest.mark.asyncio async def test_extract_outline_persistence_exception(self): diff --git a/test/unit_test/rag/svr/task_executor_refactor/test_comparator.py b/test/unit_test/rag/svr/task_executor_refactor/test_comparator.py index 307015e5d2..29a8ddc33e 100644 --- a/test/unit_test/rag/svr/task_executor_refactor/test_comparator.py +++ b/test/unit_test/rag/svr/task_executor_refactor/test_comparator.py @@ -41,13 +41,7 @@ class TestComparisonResult: def test_init_with_all_fields(self): """Test initialization with all fields.""" - result = ComparisonResult( - key="test_key", - match=False, - production_value=100, - dry_run_value=200, - diff_details="Values differ" - ) + result = ComparisonResult(key="test_key", match=False, production_value=100, dry_run_value=200, diff_details="Values differ") assert result.key == "test_key" assert result.match is False assert result.production_value == 100 @@ -62,11 +56,7 @@ class TestComparisonResult: def test_to_dict_mismatch(self): """Test to_dict for mismatching result.""" - result = ComparisonResult( - key="key", - match=False, - diff_details="Difference" - ) + result = ComparisonResult(key="key", match=False, diff_details="Difference") d = result.to_dict() assert d == {"key": "key", "match": False, "diff_details": "Difference"} @@ -92,24 +82,14 @@ class TestComparisonReport: def test_summary_with_keys(self): """Test summary with keys.""" - report = ComparisonReport( - task_id="task_123", - total_keys=10, - matched_keys=8, - mismatched_keys=2 - ) + report = ComparisonReport(task_id="task_123", total_keys=10, matched_keys=8, mismatched_keys=2) summary = report.summary() assert "8/10" in summary assert "80.0%" in summary def test_to_dict(self): """Test to_dict serialization.""" - report = ComparisonReport( - task_id="task_123", - total_keys=1, - matched_keys=1, - details=[ComparisonResult(key="k", match=True)] - ) + report = ComparisonReport(task_id="task_123", total_keys=1, matched_keys=1, details=[ComparisonResult(key="k", match=True)]) d = report.to_dict() assert d["task_id"] == "task_123" assert d["total_keys"] == 1 @@ -117,15 +97,7 @@ class TestComparisonReport: def test_to_markdown(self): """Test to_markdown serialization.""" - report = ComparisonReport( - task_id="task_123", - total_keys=1, - matched_keys=1, - mismatched_keys=0, - missing_in_production=[], - missing_in_dry_run=[], - details=[ComparisonResult(key="k", match=True)] - ) + report = ComparisonReport(task_id="task_123", total_keys=1, matched_keys=1, mismatched_keys=0, missing_in_production=[], missing_in_dry_run=[], details=[ComparisonResult(key="k", match=True)]) md = report.to_markdown() assert "# Comparison Report: task_123" in md assert "## Summary" in md @@ -222,11 +194,7 @@ class TestContextComparatorCompareValue: def test_compare_chunks_key_uses_chunk_comparison(self): """Test that chunk keys use chunk comparison strategy.""" - result = self.comparator.compare_value( - "raw_chunks", - [{"id": "1", "content_with_weight": "a"}], - [{"id": "1", "content_with_weight": "a"}] - ) + result = self.comparator.compare_value("raw_chunks", [{"id": "1", "content_with_weight": "a"}], [{"id": "1", "content_with_weight": "a"}]) assert result.match is True @@ -328,7 +296,7 @@ class TestContextComparatorCompareChunks: def test_all_chunks_compared_not_sampled(self): """Test that ALL chunks are compared, not just samples. - + This test creates 10 chunks where only the middle one (index 5) differs. With the old sampling strategy, this difference might be missed. With full comparison, the difference should always be detected. @@ -337,7 +305,7 @@ class TestContextComparatorCompareChunks: dry = [{"id": str(i), "content_with_weight": f"content_{i}"} for i in range(10)] # Only modify chunk at index 5 (which might not be sampled in old strategy) dry[5]["content_with_weight"] = "different_content" - + result = self.comparator._compare_chunks("raw_chunks", prod, dry) assert result.match is False assert "Content differs" in result.diff_details @@ -346,7 +314,7 @@ class TestContextComparatorCompareChunks: """Test that first chunk difference is detected.""" prod = [{"id": "1", "content_with_weight": "a"}, {"id": "2", "content_with_weight": "b"}] dry = [{"id": "1", "content_with_weight": "different"}, {"id": "2", "content_with_weight": "b"}] - + result = self.comparator._compare_chunks("raw_chunks", prod, dry) assert result.match is False @@ -354,7 +322,7 @@ class TestContextComparatorCompareChunks: """Test that last chunk difference is detected.""" prod = [{"id": "1", "content_with_weight": "a"}, {"id": "2", "content_with_weight": "b"}] dry = [{"id": "1", "content_with_weight": "a"}, {"id": "2", "content_with_weight": "different"}] - + result = self.comparator._compare_chunks("raw_chunks", prod, dry) assert result.match is False @@ -362,7 +330,7 @@ class TestContextComparatorCompareChunks: """Test that large list of chunks all match.""" prod = [{"id": str(i), "content_with_weight": f"content_{i}"} for i in range(100)] dry = [{"id": str(i), "content_with_weight": f"content_{i}"} for i in range(100)] - + result = self.comparator._compare_chunks("raw_chunks", prod, dry) assert result.match is True @@ -372,7 +340,7 @@ class TestContextComparatorCompareChunks: dry = [{"id": str(i), "content_with_weight": f"content_{i}"} for i in range(100)] # Modify only the last chunk dry[99]["content_with_weight"] = "different" - + result = self.comparator._compare_chunks("raw_chunks", prod, dry) assert result.match is False @@ -487,10 +455,7 @@ class TestContextComparatorStripNonDeterministicFields: def test_strip_seconds_from_dict_value(self): """Test that 'seconds' key is removed from dict values.""" - data = { - "graphrag_result": {"seconds": 45.48, "status": "done"}, - "other_key": "value" - } + data = {"graphrag_result": {"seconds": 45.48, "status": "done"}, "other_key": "value"} result = self.comparator._strip_non_deterministic_fields(data) assert "seconds" not in result["graphrag_result"] assert result["graphrag_result"] == {"status": "done"} @@ -498,11 +463,7 @@ class TestContextComparatorStripNonDeterministicFields: def test_strip_seconds_from_multiple_dict_values(self): """Test that 'seconds' is removed from multiple dict values.""" - data = { - "result1": {"seconds": 10.0, "count": 5}, - "result2": {"seconds": 20.0, "name": "test"}, - "simple_key": 123 - } + data = {"result1": {"seconds": 10.0, "count": 5}, "result2": {"seconds": 20.0, "name": "test"}, "simple_key": 123} result = self.comparator._strip_non_deterministic_fields(data) assert result["result1"] == {"count": 5} assert result["result2"] == {"name": "test"} @@ -510,9 +471,7 @@ class TestContextComparatorStripNonDeterministicFields: def test_strip_does_not_modify_original_dict(self): """Test that the original dict is not modified in place.""" - data = { - "result": {"seconds": 1.0, "value": "test"} - } + data = {"result": {"seconds": 1.0, "value": "test"}} _ = data["result"].copy() self.comparator._strip_non_deterministic_fields(data) # The original nested dict should still have seconds since we only do shallow copy @@ -520,22 +479,14 @@ class TestContextComparatorStripNonDeterministicFields: def test_strip_with_empty_dict_values(self): """Test handling of empty dict values.""" - data = { - "empty_dict": {}, - "normal_key": "value" - } + data = {"empty_dict": {}, "normal_key": "value"} result = self.comparator._strip_non_deterministic_fields(data) assert result["empty_dict"] == {} assert result["normal_key"] == "value" def test_strip_with_non_dict_values(self): """Test that non-dict values are not affected.""" - data = { - "string_val": "test", - "int_val": 42, - "list_val": [1, 2, 3], - "dict_val": {"seconds": 1.0, "name": "test"} - } + data = {"string_val": "test", "int_val": 42, "list_val": [1, 2, 3], "dict_val": {"seconds": 1.0, "name": "test"}} result = self.comparator._strip_non_deterministic_fields(data) assert result["string_val"] == "test" assert result["int_val"] == 42 @@ -544,23 +495,11 @@ class TestContextComparatorStripNonDeterministicFields: def test_strip_seconds_from_graphrag_result(self): """Test the specific case from the bug report: graphrag_result with seconds.""" - prod_data = { - "graphrag_result": { - "seconds": 45.48, - "status": "success", - "entity_count": 100 - } - } - dry_run_data = { - "graphrag_result": { - "seconds": 0.99, - "status": "success", - "entity_count": 100 - } - } + prod_data = {"graphrag_result": {"seconds": 45.48, "status": "success", "entity_count": 100}} + dry_run_data = {"graphrag_result": {"seconds": 0.99, "status": "success", "entity_count": 100}} prod_stripped = self.comparator._strip_non_deterministic_fields(prod_data) dry_run_stripped = self.comparator._strip_non_deterministic_fields(dry_run_data) - + # After stripping, both should be equal (except for seconds) assert prod_stripped["graphrag_result"] == {"status": "success", "entity_count": 100} assert dry_run_stripped["graphrag_result"] == {"status": "success", "entity_count": 100} @@ -572,7 +511,7 @@ class TestContextComparatorStripNonDeterministicFields: ctx2 = RecordingContext() ctx1.record("graphrag_result", {"seconds": 45.48, "status": "success"}) ctx2.record("graphrag_result", {"seconds": 0.99, "status": "success"}) - + report = self.comparator.compare("task_1", ctx1, ctx2) # Should match because seconds is stripped assert report.matched_keys == 1 @@ -584,7 +523,7 @@ class TestContextComparatorStripNonDeterministicFields: ctx2 = RecordingContext() ctx1.record("graphrag_result", {"seconds": 45.48, "status": "success", "count": 100}) ctx2.record("graphrag_result", {"seconds": 0.99, "status": "failed", "count": 50}) - + report = self.comparator.compare("task_1", ctx1, ctx2) # Should mismatch because status and count differ assert report.mismatched_keys == 1 diff --git a/test/unit_test/rag/svr/task_executor_refactor/test_dataflow_service.py b/test/unit_test/rag/svr/task_executor_refactor/test_dataflow_service.py index 6be9261356..3d4d01c7ca 100644 --- a/test/unit_test/rag/svr/task_executor_refactor/test_dataflow_service.py +++ b/test/unit_test/rag/svr/task_executor_refactor/test_dataflow_service.py @@ -61,7 +61,7 @@ class TestDataflowServiceRunDataflow: mock_pipeline.run = AsyncMock(return_value={}) mock_pipeline_class.return_value = mock_pipeline - with patch.object(DataflowService, '_record_pipeline_log'): + with patch.object(DataflowService, "_record_pipeline_log"): service = DataflowService(ctx=task_context) await service.run_dataflow() @@ -86,13 +86,13 @@ class TestDataflowServiceRunDataflow: mock_pipeline.run = AsyncMock(return_value=data) mock_pipeline_class.return_value = mock_pipeline - with patch.object(DataflowService, '_embed_chunks', new_callable=AsyncMock, - return_value=(data[output_key], 5)), \ - patch.object(DataflowService, '_insert_chunks', new_callable=AsyncMock, return_value=True), \ - patch.object(DataflowService, '_update_document_metadata'), \ - patch.object(DataflowService, '_record_pipeline_log'), \ - patch("api.db.services.document_service.DocumentService.increment_chunk_num"): - + with ( + patch.object(DataflowService, "_embed_chunks", new_callable=AsyncMock, return_value=(data[output_key], 5)), + patch.object(DataflowService, "_insert_chunks", new_callable=AsyncMock, return_value=True), + patch.object(DataflowService, "_update_document_metadata"), + patch.object(DataflowService, "_record_pipeline_log"), + patch("api.db.services.document_service.DocumentService.increment_chunk_num"), + ): service = DataflowService(ctx=task_context) await service.run_dataflow() DataflowService._insert_chunks.assert_called_once() @@ -118,8 +118,7 @@ class TestDataflowServiceRunDataflow: mock_pipeline.run = AsyncMock(return_value=chunks) mock_pipeline_class.return_value = mock_pipeline - with patch.object(DataflowService, '_embed_chunks', new_callable=AsyncMock, return_value=(None, 0)), \ - patch.object(DataflowService, '_record_pipeline_log'): + with patch.object(DataflowService, "_embed_chunks", new_callable=AsyncMock, return_value=(None, 0)), patch.object(DataflowService, "_record_pipeline_log"): service = DataflowService(ctx=task_context) await service.run_dataflow() service._record_pipeline_log.assert_called() @@ -152,12 +151,13 @@ class TestDataflowServiceRunDataflow: billing_hook.on_pipeline_success = AsyncMock() billing_hook.on_pipeline_error = AsyncMock() - with patch.object(DataflowService, '_embed_chunks', new_callable=AsyncMock, return_value=(chunks["chunks"], 1)), \ - patch.object(DataflowService, '_insert_chunks', new_callable=AsyncMock, return_value=True), \ - patch.object(DataflowService, '_update_document_metadata'), \ - patch.object(DataflowService, '_record_pipeline_log'), \ - patch("api.db.services.document_service.DocumentService.increment_chunk_num"): - + with ( + patch.object(DataflowService, "_embed_chunks", new_callable=AsyncMock, return_value=(chunks["chunks"], 1)), + patch.object(DataflowService, "_insert_chunks", new_callable=AsyncMock, return_value=True), + patch.object(DataflowService, "_update_document_metadata"), + patch.object(DataflowService, "_record_pipeline_log"), + patch("api.db.services.document_service.DocumentService.increment_chunk_num"), + ): service = DataflowService(ctx=task_context, billing_hook=billing_hook) await service.run_dataflow() billing_hook.on_pipeline_success.assert_called_once() @@ -382,4 +382,4 @@ class TestDataflowServiceLoadDsl: assert dsl == '{"id": "test_pipeline"}' assert corrected_id == "corrected_pipeline_id" - mock_log.get_by_id.assert_called_once_with(dataflow_id) \ No newline at end of file + mock_log.get_by_id.assert_called_once_with(dataflow_id) diff --git a/test/unit_test/rag/svr/task_executor_refactor/test_embedding_service.py b/test/unit_test/rag/svr/task_executor_refactor/test_embedding_service.py index 76ff5c9d19..f01be12a42 100644 --- a/test/unit_test/rag/svr/task_executor_refactor/test_embedding_service.py +++ b/test/unit_test/rag/svr/task_executor_refactor/test_embedding_service.py @@ -147,7 +147,6 @@ class TestEmbeddingServiceEmbedChunks: assert vector_size == 2 assert "q_2_vec" in docs[0] - @pytest.mark.asyncio @patch("rag.svr.task_executor_refactor.embedding_service.thread_pool_exec", new_callable=AsyncMock) async def test_embed_chunks_empty_docs(self, mock_thread_pool): @@ -196,8 +195,7 @@ class TestEmbeddingServiceEmbedChunks: model.max_length = 100 docs = [{"docnm_kwd": "Title1", "content_with_weight": "Content1"}] - _, vector_size = await service.embed_chunks(docs, model, - parser_config={"filename_embd_weight": 0.0}) + _, vector_size = await service.embed_chunks(docs, model, parser_config={"filename_embd_weight": 0.0}) assert vector_size == 2 @@ -214,8 +212,7 @@ class TestEmbeddingServiceEmbedChunks: model.max_length = 100 docs = [{"docnm_kwd": "Title1", "content_with_weight": "Content1"}] - _, vector_size = await service.embed_chunks(docs, model, - parser_config={"filename_embd_weight": 1.0}) + _, vector_size = await service.embed_chunks(docs, model, parser_config={"filename_embd_weight": 1.0}) assert vector_size == 2 @@ -239,6 +236,7 @@ class TestEmbeddingServiceEmbedChunks: @patch("rag.svr.task_executor_refactor.embedding_service.thread_pool_exec", new_callable=AsyncMock) async def test_embed_chunks_multiple_batches(self, mock_thread_pool): """Test embedding with more chunks than batch size — multiple encode calls.""" + # Each call returns vectors matching input count def side_effect(func, texts, *args, **kw): n = len(texts) if isinstance(texts, list) else 1 diff --git a/test/unit_test/rag/svr/task_executor_refactor/test_raptor_service.py b/test/unit_test/rag/svr/task_executor_refactor/test_raptor_service.py index 9f3aa60b5b..c44a3c50dc 100644 --- a/test/unit_test/rag/svr/task_executor_refactor/test_raptor_service.py +++ b/test/unit_test/rag/svr/task_executor_refactor/test_raptor_service.py @@ -218,9 +218,7 @@ class TestRaptorServiceRunRaptorForKb: # ---- Basic dispatch (file-level scope) ---- @pytest.mark.asyncio - async def test_run_raptor_for_kb_file_scope_delegates_to_file_level( - self, mock_raptor_context, sample_chunks, raptor_config_file_scope - ): + async def test_run_raptor_for_kb_file_scope_delegates_to_file_level(self, mock_raptor_context, sample_chunks, raptor_config_file_scope): """When scope='file', _run_file_level_raptor is called.""" svc = RaptorService(mock_raptor_context) doc_ids = ["doc_1", "doc_2"] @@ -228,12 +226,18 @@ class TestRaptorServiceRunRaptorForKb: embd_mdl = MagicMock() vector_size = 128 - with patch.object(svc, "_collect_doc_info", return_value={ - "doc_1": {"name": "a.pdf", "type": "pdf", "parser_id": "naive", "parser_config": {}}, - "doc_2": {"name": "b.pdf", "type": "pdf", "parser_id": "naive", "parser_config": {}}, - }), patch.object(svc, "_run_file_level_raptor", new_callable=AsyncMock) as mock_file, \ - patch.object(svc, "_run_dataset_level_raptor", new_callable=AsyncMock) as mock_dataset: - + with ( + patch.object( + svc, + "_collect_doc_info", + return_value={ + "doc_1": {"name": "a.pdf", "type": "pdf", "parser_id": "naive", "parser_config": {}}, + "doc_2": {"name": "b.pdf", "type": "pdf", "parser_id": "naive", "parser_config": {}}, + }, + ), + patch.object(svc, "_run_file_level_raptor", new_callable=AsyncMock) as mock_file, + patch.object(svc, "_run_dataset_level_raptor", new_callable=AsyncMock) as mock_dataset, + ): mock_file.return_value = (sample_chunks, 42) chunks, tk_count, cleanup = await svc.run_raptor_for_kb(raptor_config_file_scope, chat_mdl, embd_mdl, vector_size, doc_ids) @@ -245,9 +249,7 @@ class TestRaptorServiceRunRaptorForKb: # ---- Basic dispatch (dataset-level scope) ---- @pytest.mark.asyncio - async def test_run_raptor_for_kb_dataset_scope_delegates_to_dataset_level( - self, mock_raptor_context, sample_chunks, raptor_config_dataset_scope - ): + async def test_run_raptor_for_kb_dataset_scope_delegates_to_dataset_level(self, mock_raptor_context, sample_chunks, raptor_config_dataset_scope): """When scope='dataset', _run_dataset_level_raptor is called.""" svc = RaptorService(mock_raptor_context) doc_ids = ["doc_1"] @@ -255,11 +257,17 @@ class TestRaptorServiceRunRaptorForKb: embd_mdl = MagicMock() vector_size = 128 - with patch.object(svc, "_collect_doc_info", return_value={ - "doc_1": {"name": "a.pdf", "type": "pdf", "parser_id": "naive", "parser_config": {}}, - }), patch.object(svc, "_run_file_level_raptor", new_callable=AsyncMock) as mock_file, \ - patch.object(svc, "_run_dataset_level_raptor", new_callable=AsyncMock) as mock_dataset: - + with ( + patch.object( + svc, + "_collect_doc_info", + return_value={ + "doc_1": {"name": "a.pdf", "type": "pdf", "parser_id": "naive", "parser_config": {}}, + }, + ), + patch.object(svc, "_run_file_level_raptor", new_callable=AsyncMock) as mock_file, + patch.object(svc, "_run_dataset_level_raptor", new_callable=AsyncMock) as mock_dataset, + ): mock_dataset.return_value = (sample_chunks, 99) chunks, tk_count, cleanup = await svc.run_raptor_for_kb(raptor_config_dataset_scope, chat_mdl, embd_mdl, vector_size, doc_ids) @@ -277,10 +285,11 @@ class TestRaptorServiceRunRaptorForKb: chat_mdl = MagicMock() embd_mdl = MagicMock() - with patch.object(svc, "_collect_doc_info", return_value={}), \ - patch.object(svc, "_run_file_level_raptor", new_callable=AsyncMock) as mock_file, \ - patch.object(svc, "_run_dataset_level_raptor", new_callable=AsyncMock): - + with ( + patch.object(svc, "_collect_doc_info", return_value={}), + patch.object(svc, "_run_file_level_raptor", new_callable=AsyncMock) as mock_file, + patch.object(svc, "_run_dataset_level_raptor", new_callable=AsyncMock), + ): mock_file.return_value = ([], 0) chunks, tk_count, cleanup = await svc.run_raptor_for_kb(raptor_config_file_scope, chat_mdl, embd_mdl, 128, []) @@ -291,9 +300,7 @@ class TestRaptorServiceRunRaptorForKb: # ---- Cleanup scheduling through the public API ---- @pytest.mark.asyncio - async def test_run_raptor_for_kb_returns_cleanup_list( - self, mock_raptor_context, raptor_config_file_scope - ): + async def test_run_raptor_for_kb_returns_cleanup_list(self, mock_raptor_context, raptor_config_file_scope): """Cleanup list from internal method is propagated to caller.""" svc = RaptorService(mock_raptor_context) doc_ids = ["doc_1"] @@ -302,9 +309,16 @@ class TestRaptorServiceRunRaptorForKb: expected_cleanup = [("doc_1", "tree_builder_a")] - with patch.object(svc, "_collect_doc_info", return_value={ - "doc_1": {"name": "a.pdf", "type": "pdf", "parser_id": "naive", "parser_config": {}}, - }), patch.object(svc, "_run_file_level_raptor", new_callable=AsyncMock) as mock_file: + with ( + patch.object( + svc, + "_collect_doc_info", + return_value={ + "doc_1": {"name": "a.pdf", "type": "pdf", "parser_id": "naive", "parser_config": {}}, + }, + ), + patch.object(svc, "_run_file_level_raptor", new_callable=AsyncMock) as mock_file, + ): async def mock_run_file(*args, **kwargs): cleanup_list = args[11] @@ -319,9 +333,7 @@ class TestRaptorServiceRunRaptorForKb: # ---- Dispatch with missing raptor config key ---- @pytest.mark.asyncio - async def test_run_raptor_for_kb_defaults_to_file_scope_when_no_raptor_key( - self, mock_raptor_context - ): + async def test_run_raptor_for_kb_defaults_to_file_scope_when_no_raptor_key(self, mock_raptor_context): """When kb_parser_config has no 'raptor' key, defaults to file scope.""" svc = RaptorService(mock_raptor_context) doc_ids = ["doc_1"] @@ -329,11 +341,17 @@ class TestRaptorServiceRunRaptorForKb: embd_mdl = MagicMock() config = {} # No raptor key at all - with patch.object(svc, "_collect_doc_info", return_value={ - "doc_1": {"name": "a.pdf", "type": "pdf", "parser_id": "naive", "parser_config": {}}, - }), patch.object(svc, "_run_file_level_raptor", new_callable=AsyncMock) as mock_file, \ - patch.object(svc, "_run_dataset_level_raptor", new_callable=AsyncMock) as mock_dataset: - + with ( + patch.object( + svc, + "_collect_doc_info", + return_value={ + "doc_1": {"name": "a.pdf", "type": "pdf", "parser_id": "naive", "parser_config": {}}, + }, + ), + patch.object(svc, "_run_file_level_raptor", new_callable=AsyncMock) as mock_file, + patch.object(svc, "_run_dataset_level_raptor", new_callable=AsyncMock) as mock_dataset, + ): mock_file.return_value = ([], 0) await svc.run_raptor_for_kb(config, chat_mdl, embd_mdl, 128, doc_ids) @@ -343,9 +361,7 @@ class TestRaptorServiceRunRaptorForKb: # ---- Vector dimension name construction ---- @pytest.mark.asyncio - async def test_run_raptor_for_kb_passes_vector_size_to_file_level( - self, mock_raptor_context, sample_chunks, raptor_config_file_scope - ): + async def test_run_raptor_for_kb_passes_vector_size_to_file_level(self, mock_raptor_context, sample_chunks, raptor_config_file_scope): """Vector size is used to construct vctr_nm and passed to internal method.""" svc = RaptorService(mock_raptor_context) doc_ids = ["doc_1"] @@ -353,10 +369,16 @@ class TestRaptorServiceRunRaptorForKb: embd_mdl = MagicMock() vector_size = 256 - with patch.object(svc, "_collect_doc_info", return_value={ - "doc_1": {"name": "a.pdf", "type": "pdf", "parser_id": "naive", "parser_config": {}}, - }), patch.object(svc, "_run_file_level_raptor", new_callable=AsyncMock) as mock_file: - + with ( + patch.object( + svc, + "_collect_doc_info", + return_value={ + "doc_1": {"name": "a.pdf", "type": "pdf", "parser_id": "naive", "parser_config": {}}, + }, + ), + patch.object(svc, "_run_file_level_raptor", new_callable=AsyncMock) as mock_file, + ): mock_file.return_value = (sample_chunks, 10) await svc.run_raptor_for_kb(raptor_config_file_scope, chat_mdl, embd_mdl, vector_size, doc_ids) @@ -369,9 +391,7 @@ class TestRaptorServiceRunRaptorForKb: # ---- Document info collection through public API ---- @pytest.mark.asyncio - async def test_run_raptor_for_kb_collects_doc_info( - self, mock_raptor_context, raptor_config_file_scope - ): + async def test_run_raptor_for_kb_collects_doc_info(self, mock_raptor_context, raptor_config_file_scope): """Document info is collected before dispatching to internal methods.""" svc = RaptorService(mock_raptor_context) doc_ids = ["doc_a"] @@ -380,9 +400,7 @@ class TestRaptorServiceRunRaptorForKb: expected_info = {"doc_a": {"name": "file.pdf", "type": "pdf", "parser_id": "naive", "parser_config": {}}} - with patch.object(svc, "_collect_doc_info", return_value=expected_info) as mock_collect, \ - patch.object(svc, "_run_file_level_raptor", new_callable=AsyncMock) as mock_file: - + with patch.object(svc, "_collect_doc_info", return_value=expected_info) as mock_collect, patch.object(svc, "_run_file_level_raptor", new_callable=AsyncMock) as mock_file: mock_file.return_value = ([], 0) await svc.run_raptor_for_kb(raptor_config_file_scope, chat_mdl, embd_mdl, 128, doc_ids) @@ -406,44 +424,40 @@ class TestRaptorServiceFileLevelRaptorCheckpoint: svc = RaptorService(ctx) doc_ids = ["doc_1"] - doc_info_by_id = { - "doc_1": {"name": "a.pdf", "type": "pdf", "parser_id": "naive", "parser_config": {}} - } + doc_info_by_id = {"doc_1": {"name": "a.pdf", "type": "pdf", "parser_id": "naive", "parser_config": {}}} raptor_config = { "scope": "file", - "max_cluster": 64, "prompt": "test prompt", - "max_token": 256, "threshold": 0.1, "random_seed": 0, - "clustering_method": "gmm", "tree_builder": "raptor", + "max_cluster": 64, + "prompt": "test prompt", + "max_token": 256, + "threshold": 0.1, + "random_seed": 0, + "clustering_method": "gmm", + "tree_builder": "raptor", "ext": {}, } - with patch.object(svc, "_get_raptor_chunk_methods", new_callable=AsyncMock) as mock_methods, \ - patch.object(svc, "_should_skip_raptor", return_value=False): - + with patch.object(svc, "_get_raptor_chunk_methods", new_callable=AsyncMock) as mock_methods, patch.object(svc, "_should_skip_raptor", return_value=False): mock_methods.return_value = {"raptor"} result = await svc._run_file_level_raptor( - raptor_config=raptor_config, tree_builder="raptor", - clustering_method="gmm", chat_mdl=MagicMock(), - embd_mdl=MagicMock(), vctr_nm="q_128_vec", - doc_ids=doc_ids, doc_info_by_id=doc_info_by_id, - max_errors=3, res=[], tk_count=0, + raptor_config=raptor_config, + tree_builder="raptor", + clustering_method="gmm", + chat_mdl=MagicMock(), + embd_mdl=MagicMock(), + vctr_nm="q_128_vec", + doc_ids=doc_ids, + doc_info_by_id=doc_info_by_id, + max_errors=3, + res=[], + tk_count=0, cleanup_raptor_chunks=[], ) - msg_calls = [ - call.kwargs.get("msg", "") - for call in ctx.progress_cb.call_args_list - if call.kwargs.get("msg") is not None - ] - assert any("already has" in m for m in msg_calls), \ - f"Expected 'already has' progress message, got: {msg_calls}" + msg_calls = [call.kwargs.get("msg", "") for call in ctx.progress_cb.call_args_list if call.kwargs.get("msg") is not None] + assert any("already has" in m for m in msg_calls), f"Expected 'already has' progress message, got: {msg_calls}" - prog_calls = [ - call.kwargs.get("prog") - for call in ctx.progress_cb.call_args_list - if call.kwargs.get("prog") is not None - ] - assert len(prog_calls) > 0, \ - "Expected progress_cb to be called with prog update" + prog_calls = [call.kwargs.get("prog") for call in ctx.progress_cb.call_args_list if call.kwargs.get("prog") is not None] + assert len(prog_calls) > 0, "Expected progress_cb to be called with prog update" assert result[0] == [] diff --git a/test/unit_test/rag/svr/task_executor_refactor/test_recording_context.py b/test/unit_test/rag/svr/task_executor_refactor/test_recording_context.py index d5336393b1..48c02db466 100644 --- a/test/unit_test/rag/svr/task_executor_refactor/test_recording_context.py +++ b/test/unit_test/rag/svr/task_executor_refactor/test_recording_context.py @@ -306,11 +306,11 @@ class TestRecordingContextManager: """Test context manager restores previous context after exit.""" outer_ctx = RecordingContext() set_recording_context(outer_ctx) - + inner_ctx = RecordingContext() with recording_context_manager(inner_ctx): assert get_recording_context() is inner_ctx - + # After exiting, should restore outer_ctx assert get_recording_context() is outer_ctx @@ -322,24 +322,24 @@ class TestTimedWithRecordingDecorator: """Test decorator used without parentheses.""" ctx = RecordingContext() set_recording_context(ctx) - + @timed_with_recording def test_func(): time.sleep(0.01) return 42 - + result = test_func() assert result == 42 def test_decorator_with_parentheses_and_context(self): """Test decorator with explicit context.""" ctx = RecordingContext() - + @timed_with_recording(recording_context=ctx) def test_func(): time.sleep(0.01) return "hello" - + result = test_func() assert result == "hello" @@ -347,11 +347,11 @@ class TestTimedWithRecordingDecorator: """Test decorator raises RuntimeError when no context is available.""" # Ensure no context is set set_recording_context(None) - + @timed_with_recording def test_func(): return 123 - + # Should raise RuntimeError because no context is available with pytest.raises(RuntimeError, match="no context"): test_func() diff --git a/test/unit_test/rag/svr/task_executor_refactor/test_task_context.py b/test/unit_test/rag/svr/task_executor_refactor/test_task_context.py index 1b7017ae14..f470ab2562 100644 --- a/test/unit_test/rag/svr/task_executor_refactor/test_task_context.py +++ b/test/unit_test/rag/svr/task_executor_refactor/test_task_context.py @@ -413,5 +413,5 @@ class TestTaskContextProgressCallback: task = {"id": "task_1", "tenant_id": "tenant_1"} ctx = _make_ctx(task=task) # _progress_cb should be set in __init__ - assert hasattr(ctx, '_progress_cb') + assert hasattr(ctx, "_progress_cb") assert ctx._progress_cb is not None diff --git a/test/unit_test/rag/svr/task_executor_refactor/test_task_raptor_utils.py b/test/unit_test/rag/svr/task_executor_refactor/test_task_raptor_utils.py index 01bc3ecfa8..2eef7d260e 100644 --- a/test/unit_test/rag/svr/task_executor_refactor/test_task_raptor_utils.py +++ b/test/unit_test/rag/svr/task_executor_refactor/test_task_raptor_utils.py @@ -32,8 +32,9 @@ class TestGetRaptorChunkFieldMap: async def test_returns_primary_result_when_raptor_chunks_exist(self): """Test that primary result is returned when RAPTOR chunks exist.""" from common import settings + original_retriever = settings.docStoreConn - + mock_doc_store = MagicMock() mock_doc_store.search.return_value = {"chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": "raptor"}}} mock_doc_store.get_fields.return_value = {"chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": "raptor"}}} @@ -41,8 +42,10 @@ class TestGetRaptorChunkFieldMap: try: with patch("rag.svr.task_executor_refactor.raptor_utils.thread_pool_exec") as mock_thread: + async def mock_exec(*args, **kwargs): return {"chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": "raptor"}}} + mock_thread.side_effect = mock_exec with patch("rag.svr.task_executor_refactor.raptor_utils.collect_raptor_chunk_ids") as mock_collect: @@ -58,13 +61,15 @@ class TestGetRaptorChunkFieldMap: async def test_falls_back_to_secondary_search_when_no_raptor_chunks(self): """Test that fallback search is used when no RAPTOR chunks found.""" from common import settings + original_retriever = settings.docStoreConn - + mock_doc_store = MagicMock() settings.docStoreConn = mock_doc_store try: call_count = 0 + async def mock_exec(*args, **kwargs): nonlocal call_count call_count += 1 @@ -90,14 +95,16 @@ class TestGetRaptorChunkFieldMap: async def test_handles_fallback_search_exception(self): """Test that exception in fallback search is handled gracefully.""" from common import settings + original_retriever = settings.docStoreConn - + mock_doc_store = MagicMock() mock_doc_store.get_fields.return_value = {} settings.docStoreConn = mock_doc_store try: call_count = 0 + async def mock_exec(*args, **kwargs): nonlocal call_count call_count += 1 @@ -128,8 +135,9 @@ class TestDeleteRaptorChunks: async def test_deletes_all_chunks_when_keep_method_is_none(self): """Test that all RAPTOR chunks are deleted when keep_method is None.""" from common import settings + original_retriever = settings.docStoreConn - + mock_doc_store = MagicMock() settings.docStoreConn = mock_doc_store @@ -164,17 +172,15 @@ class TestDeleteRaptorChunks: async def test_deletes_stale_chunks_when_keep_method_specified(self): """Test that stale chunks are deleted when keep_method is specified.""" from common import settings + original_retriever = settings.docStoreConn - + mock_doc_store = MagicMock() settings.docStoreConn = mock_doc_store try: with patch("rag.svr.task_executor_refactor.raptor_utils.get_raptor_chunk_field_map") as mock_get_map: - mock_get_map.return_value = { - "chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}}, - "chunk_2": {"raptor_kwd": "raptor", "extra": {"raptor_method": "raptor"}} - } + mock_get_map.return_value = {"chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}}, "chunk_2": {"raptor_kwd": "raptor", "extra": {"raptor_method": "raptor"}}} with patch("rag.svr.task_executor_refactor.raptor_utils.collect_raptor_chunk_ids") as mock_collect: mock_collect.return_value = {"chunk_1"} # Only chunk_1 is stale (psi, not raptor) @@ -193,16 +199,15 @@ class TestDeleteRaptorChunks: async def test_logs_info_when_removing_stale_chunks(self): """Test that info is logged when removing stale chunks.""" from common import settings + original_retriever = settings.docStoreConn - + mock_doc_store = MagicMock() settings.docStoreConn = mock_doc_store try: with patch("rag.svr.task_executor_refactor.raptor_utils.get_raptor_chunk_field_map") as mock_get_map: - mock_get_map.return_value = { - "chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}} - } + mock_get_map.return_value = {"chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}}} with patch("rag.svr.task_executor_refactor.raptor_utils.collect_raptor_chunk_ids") as mock_collect: mock_collect.return_value = {"chunk_1"} diff --git a/test/unit_test/rag/svr/task_executor_refactor/test_write_operation_interceptor.py b/test/unit_test/rag/svr/task_executor_refactor/test_write_operation_interceptor.py index 59df761051..1a3ec4e068 100644 --- a/test/unit_test/rag/svr/task_executor_refactor/test_write_operation_interceptor.py +++ b/test/unit_test/rag/svr/task_executor_refactor/test_write_operation_interceptor.py @@ -54,7 +54,7 @@ class TestAllowedMethodNames: "PipelineOperationLogService.create", "delete_raptor_chunks", "docStoreConn.insert", - "docStoreConn.delete" + "docStoreConn.delete", } assert ALLOWED_METHOD_NAMES == expected_methods @@ -151,18 +151,21 @@ class TestWriteOperationInterceptorIntercept: result = interceptor.intercept("KnowledgebaseService.update_by_id", default_value=999) assert result == 100 - @pytest.mark.parametrize("default_value", [ - "default_string", - {"status": "success", "data": [1, 2, 3]}, - [1, 2, 3, 4, 5], - (1, "two", 3.0), - True, - False, - 0, - "", - [], - {}, - ]) + @pytest.mark.parametrize( + "default_value", + [ + "default_string", + {"status": "success", "data": [1, 2, 3]}, + [1, 2, 3, 4, 5], + (1, "two", 3.0), + True, + False, + 0, + "", + [], + {}, + ], + ) def test_intercept_with_various_default_values(self, valid_recorded_values, default_value): """Test that intercept returns various default_value types when list is empty.""" interceptor = WriteOperationInterceptor(valid_recorded_values) @@ -177,6 +180,7 @@ class TestWriteOperationInterceptorIntercept: result = interceptor.intercept("DocMetadataService.update_document_metadata") assert result == complex_value + class TestWriteOperationInterceptorRemainingCount: """Tests for WriteOperationInterceptor.remaining_count.""" diff --git a/test/unit_test/rag/svr/test_table_column_roles_helpers.py b/test/unit_test/rag/svr/test_table_column_roles_helpers.py index fe4eed27fe..63ecfd35a7 100644 --- a/test/unit_test/rag/svr/test_table_column_roles_helpers.py +++ b/test/unit_test/rag/svr/test_table_column_roles_helpers.py @@ -80,10 +80,7 @@ class TestEsFieldValueToDocMetadata: assert _es_field_value_to_doc_metadata("Brazil", from_tks_fallback=False) == "Brazil" def test_es_field_value_list_joined(self): - assert ( - _es_field_value_to_doc_metadata(["hello", "world"], from_tks_fallback=True) - == "hello world" - ) + assert _es_field_value_to_doc_metadata(["hello", "world"], from_tks_fallback=True) == "hello world" def test_es_field_value_empty(self): assert _es_field_value_to_doc_metadata(None, from_tks_fallback=True) is None diff --git a/test/unit_test/rag/test_naive_merge.py b/test/unit_test/rag/test_naive_merge.py index b26f8a7e5c..6ec4abb57f 100644 --- a/test/unit_test/rag/test_naive_merge.py +++ b/test/unit_test/rag/test_naive_merge.py @@ -112,9 +112,7 @@ def test_overlap_prefix_is_counted_in_token_budget(): # Pre-fix, the prefix tokens were not counted, so the per-chunk budget check # fired late and chunks systematically overshot chunk_token_num. sentences = [" ".join(["w"] * 10) for _ in range(30)] - chunks = _nonempty( - naive_merge(sentences, chunk_token_num=50, delimiter=DEFAULT_DELIMITER, overlapped_percent=20) - ) + chunks = _nonempty(naive_merge(sentences, chunk_token_num=50, delimiter=DEFAULT_DELIMITER, overlapped_percent=20)) assert len(chunks) > 1 # Each 10-token sentence divides chunk_token_num evenly, so a correct # accounting yields chunks of exactly the budget. The buggy version @@ -156,9 +154,7 @@ def test_images_oversized_section_is_split(): texts = [(section, "")] images = [None] - chunks, imgs = naive_merge_with_images( - texts, images, chunk_token_num=50, delimiter=DEFAULT_DELIMITER - ) + chunks, imgs = naive_merge_with_images(texts, images, chunk_token_num=50, delimiter=DEFAULT_DELIMITER) nonempty = _nonempty(chunks) assert len(nonempty) > 1 # Returned lists stay aligned. @@ -168,9 +164,7 @@ def test_images_oversized_section_is_split(): @pytest.mark.p2 def test_images_custom_delimiter_preserved(): - chunks, imgs = naive_merge_with_images( - [("x##y##z", "")], [None], chunk_token_num=1000, delimiter="`##`" - ) + chunks, imgs = naive_merge_with_images([("x##y##z", "")], [None], chunk_token_num=1000, delimiter="`##`") assert [c.strip() for c in chunks] == ["x", "y", "z"] assert len(chunks) == len(imgs) @@ -180,9 +174,7 @@ def test_images_plain_string_input(): # texts may be plain strings (not tuples). sentence = " ".join(["word"] * 10) section = "\n".join([sentence] * 20) - chunks, imgs = naive_merge_with_images( - [section], [None], chunk_token_num=50, delimiter=DEFAULT_DELIMITER - ) + chunks, imgs = naive_merge_with_images([section], [None], chunk_token_num=50, delimiter=DEFAULT_DELIMITER) assert len(_nonempty(chunks)) > 1 assert len(chunks) == len(imgs) @@ -202,9 +194,7 @@ def test_images_shared_lazyimage_not_stacked_across_split_sentences(): image = LazyImage([b"FAKEBLOB"]) section = "\n".join([" ".join(["word"] * 10)] * 20) - _, imgs = naive_merge_with_images( - [(section, "")], [image], chunk_token_num=50, delimiter=DEFAULT_DELIMITER - ) + _, imgs = naive_merge_with_images([(section, "")], [image], chunk_token_num=50, delimiter=DEFAULT_DELIMITER) for im in imgs: if isinstance(im, LazyImage): assert len(im._blobs) == 1 # never grows beyond the single source blob @@ -219,9 +209,7 @@ def test_images_distinct_lazyimages_are_concatenated(): a = LazyImage([b"BLOB_A"]) b = LazyImage([b"BLOB_B"]) texts = [("alpha beta gamma", ""), ("delta epsilon zeta", "")] - _, imgs = naive_merge_with_images( - texts, [a, b], chunk_token_num=100, delimiter=DEFAULT_DELIMITER - ) + _, imgs = naive_merge_with_images(texts, [a, b], chunk_token_num=100, delimiter=DEFAULT_DELIMITER) nonempty_imgs = [im for im in imgs if im is not None] assert len(nonempty_imgs) == 1 merged = nonempty_imgs[0] diff --git a/test/unit_test/rag/test_search_pagination.py b/test/unit_test/rag/test_search_pagination.py index 75a33d1adb..5427b05777 100644 --- a/test/unit_test/rag/test_search_pagination.py +++ b/test/unit_test/rag/test_search_pagination.py @@ -22,6 +22,7 @@ window is not an exact multiple of page_size, blocks and pages drift apart, so deep pages silently drop results and come back short. These tests pin that invariant and verify cross-block pagination loses nothing. """ + import math import sys import types @@ -48,11 +49,7 @@ _rerank_window = Dealer._rerank_window # (page_size, top) combinations, including the common page sizes (10, 30) that # do NOT divide 64 -- the exact case the old `min(..., 64)` clamp broke -- plus # tiny / large / page-aligned tops. -GRID = [ - (page_size, top) - for page_size in (1, 5, 7, 10, 30, 50, 64) - for top in (0, 5, 30, 50, 55, 64, 100, 1024) -] +GRID = [(page_size, top) for page_size in (1, 5, 7, 10, 30, 50, 64) for top in (0, 5, 30, 50, 55, 64, 100, 1024)] def _paginate(total, page_size, top, rerank): @@ -73,7 +70,7 @@ def _paginate(total, page_size, top, rerank): block_start = block_index * window block = list(range(block_start, min(block_start + window, cap))) begin = global_offset % window - surfaced.extend(block[begin:begin + page_size]) + surfaced.extend(block[begin : begin + page_size]) page += 1 return window, cap, surfaced @@ -96,9 +93,7 @@ def test_pagination_loses_nothing(page_size, top): for rerank in (False, True): window, cap, surfaced = _paginate(total, page_size, top, rerank) assert surfaced == list(range(cap)), ( - f"page_size={page_size} top={top} rerank={rerank} window={window} " - f"cap={cap}: missing={sorted(set(range(cap)) - set(surfaced))[:10]} " - f"dups={len(surfaced) != len(set(surfaced))}" + f"page_size={page_size} top={top} rerank={rerank} window={window} cap={cap}: missing={sorted(set(range(cap)) - set(surfaced))[:10]} dups={len(surfaced) != len(set(surfaced))}" ) @@ -123,6 +118,7 @@ def test_reported_regression_page7_not_short(): def test_matches_legacy_window_on_non_buggy_paths(): """Where the old formula already produced a page-aligned value, the new window is unchanged (no behavioral regression on the non-buggy paths).""" + def legacy(page_size, top, rerank): limit = math.ceil(64 / page_size) * page_size if page_size > 1 else 1 limit = max(30, limit) diff --git a/test/unit_test/rag/test_sync_data_source.py b/test/unit_test/rag/test_sync_data_source.py index 57fbfa043b..c8b6db1287 100644 --- a/test/unit_test/rag/test_sync_data_source.py +++ b/test/unit_test/rag/test_sync_data_source.py @@ -174,7 +174,14 @@ async def test_run_task_logic_skips_multiple_empty_sync_batches(monkeypatch): lambda *_args, **_kwargs: pytest.fail("duplicate_and_parse should not be called for empty batches"), ) - await _FakeSync(iter(([], [],)))._run_task_logic(_make_task()) + await _FakeSync( + iter( + ( + [], + [], + ) + ) + )._run_task_logic(_make_task()) @pytest.mark.asyncio @@ -197,9 +204,7 @@ async def test_run_prune_task_logic_cleans_up_for_empty_snapshot(monkeypatch): task = {**_make_task(), "task_type": sync_data_source.ConnectorTaskType.PRUNE} sync = _FakeSync(iter(())) sync.conf["sync_deleted_files"] = True - sync.connector = types.SimpleNamespace( - retrieve_all_slim_docs_perm_sync=lambda: iter(([],)) - ) + sync.connector = types.SimpleNamespace(retrieve_all_slim_docs_perm_sync=lambda: iter(([],))) await sync._run_task_logic(task) @@ -238,9 +243,7 @@ async def test_run_prune_task_logic_cleans_up_for_non_empty_snapshot(monkeypatch task = {**_make_task(), "task_type": sync_data_source.ConnectorTaskType.PRUNE} sync = _FakeSync(iter(())) sync.conf["sync_deleted_files"] = True - sync.connector = types.SimpleNamespace( - retrieve_all_slim_docs_perm_sync=lambda: iter((file_list,)) - ) + sync.connector = types.SimpleNamespace(retrieve_all_slim_docs_perm_sync=lambda: iter((file_list,))) await sync._run_task_logic(task) @@ -316,7 +319,7 @@ class _FakeRDBMSConnector: def load_from_cursor_range(self, start_value=None, start_id=None, end_value=None): self.load_from_cursor_range_called = True - return iter(([ _make_fake_doc("incremental-doc") ],)) + return iter(([_make_fake_doc("incremental-doc")],)) def persist_sync_state(self): self.persist_sync_state_called = True diff --git a/test/unit_test/rag/utils/test_base64_image.py b/test/unit_test/rag/utils/test_base64_image.py index 64ee86aeb4..fb2e62fc90 100644 --- a/test/unit_test/rag/utils/test_base64_image.py +++ b/test/unit_test/rag/utils/test_base64_image.py @@ -30,9 +30,7 @@ class TestParseStorageCompositeId: def test_hyphenated_object_key(self): """Object keys with hyphens split only on the first separator.""" - result = base64_image.parse_storage_composite_id( - "kb12345678901234567890123456789012-page-1.png" - ) + result = base64_image.parse_storage_composite_id("kb12345678901234567890123456789012-page-1.png") assert result == ("kb12345678901234567890123456789012", "page-1.png") def test_single_hyphen(self): diff --git a/test/unit_test/rag/utils/test_minio_conn_ssl.py b/test/unit_test/rag/utils/test_minio_conn_ssl.py index 5fc87d3304..e64a8a16ba 100644 --- a/test/unit_test/rag/utils/test_minio_conn_ssl.py +++ b/test/unit_test/rag/utils/test_minio_conn_ssl.py @@ -17,6 +17,7 @@ Unit tests for MinIO client SSL/secure configuration (_build_minio_http_client). Covers issue #13158. """ + import ssl from unittest.mock import patch @@ -28,6 +29,7 @@ class TestBuildMinioHttpClient: def test_returns_none_when_verify_true(self, mock_settings): mock_settings.MINIO = {"verify": True} from rag.utils.minio_conn import _build_minio_http_client + client = _build_minio_http_client() assert client is None @@ -35,6 +37,7 @@ class TestBuildMinioHttpClient: def test_returns_none_when_verify_missing(self, mock_settings): mock_settings.MINIO = {} from rag.utils.minio_conn import _build_minio_http_client + client = _build_minio_http_client() assert client is None @@ -42,6 +45,7 @@ class TestBuildMinioHttpClient: def test_returns_pool_manager_when_verify_false(self, mock_settings): mock_settings.MINIO = {"verify": False} from rag.utils.minio_conn import _build_minio_http_client + client = _build_minio_http_client() assert client is not None assert hasattr(client, "connection_pool_kw") @@ -51,6 +55,7 @@ class TestBuildMinioHttpClient: def test_returns_pool_manager_when_verify_string_false(self, mock_settings): mock_settings.MINIO = {"verify": "false"} from rag.utils.minio_conn import _build_minio_http_client + client = _build_minio_http_client() assert client is not None assert client.connection_pool_kw.get("cert_reqs") == ssl.CERT_NONE @@ -59,5 +64,6 @@ class TestBuildMinioHttpClient: def test_returns_none_when_verify_string_1(self, mock_settings): mock_settings.MINIO = {"verify": "1"} from rag.utils.minio_conn import _build_minio_http_client + client = _build_minio_http_client() assert client is None diff --git a/test/unit_test/rag/utils/test_ob_conn.py b/test/unit_test/rag/utils/test_ob_conn.py index c288ad4b81..af11901d27 100644 --- a/test/unit_test/rag/utils/test_ob_conn.py +++ b/test/unit_test/rag/utils/test_ob_conn.py @@ -123,24 +123,14 @@ class TestGetMetadataFilterExpression: def test_simple_is_condition(self): """Test simple 'is' comparison.""" - filter_dict = { - "conditions": [ - {"name": "author", "comparison_operator": "is", "value": "John"} - ], - "logical_operator": "and" - } + filter_dict = {"conditions": [{"name": "author", "comparison_operator": "is", "value": "John"}], "logical_operator": "and"} result = get_metadata_filter_expression(filter_dict) assert "JSON_EXTRACT(metadata, '$.author')" in result assert "= 'John'" in result def test_numeric_comparison_with_zero(self): """Test numeric comparison with zero value (regression test for bug).""" - filter_dict = { - "conditions": [ - {"name": "count", "comparison_operator": "=", "value": 0} - ], - "logical_operator": "and" - } + filter_dict = {"conditions": [{"name": "count", "comparison_operator": "=", "value": 0}], "logical_operator": "and"} result = get_metadata_filter_expression(filter_dict) assert "JSON_EXTRACT(metadata, '$.count')" in result assert "= 0" in result @@ -148,85 +138,49 @@ class TestGetMetadataFilterExpression: def test_numeric_comparison_with_float_zero(self): """Test numeric comparison with 0.0.""" - filter_dict = { - "conditions": [ - {"name": "rating", "comparison_operator": "=", "value": 0.0} - ], - "logical_operator": "and" - } + filter_dict = {"conditions": [{"name": "rating", "comparison_operator": "=", "value": 0.0}], "logical_operator": "and"} result = get_metadata_filter_expression(filter_dict) assert "JSON_EXTRACT(metadata, '$.rating')" in result assert "0.0" in result def test_empty_string_condition(self): """Test condition with empty string value.""" - filter_dict = { - "conditions": [ - {"name": "status", "comparison_operator": "is", "value": ""} - ], - "logical_operator": "and" - } + filter_dict = {"conditions": [{"name": "status", "comparison_operator": "is", "value": ""}], "logical_operator": "and"} result = get_metadata_filter_expression(filter_dict) assert "JSON_EXTRACT(metadata, '$.status')" in result assert "= ''" in result def test_boolean_false_condition(self): """Test condition with False value.""" - filter_dict = { - "conditions": [ - {"name": "active", "comparison_operator": "is", "value": False} - ], - "logical_operator": "and" - } + filter_dict = {"conditions": [{"name": "active", "comparison_operator": "is", "value": False}], "logical_operator": "and"} result = get_metadata_filter_expression(filter_dict) assert "JSON_EXTRACT(metadata, '$.active')" in result assert "false" in result def test_empty_list_condition(self): """Test condition with empty list.""" - filter_dict = { - "conditions": [ - {"name": "tags", "comparison_operator": "is", "value": []} - ], - "logical_operator": "and" - } + filter_dict = {"conditions": [{"name": "tags", "comparison_operator": "is", "value": []}], "logical_operator": "and"} result = get_metadata_filter_expression(filter_dict) assert "JSON_EXTRACT(metadata, '$.tags')" in result assert "'[]'" in result def test_empty_dict_condition(self): """Test condition with empty dict.""" - filter_dict = { - "conditions": [ - {"name": "metadata", "comparison_operator": "is", "value": {}} - ], - "logical_operator": "and" - } + filter_dict = {"conditions": [{"name": "metadata", "comparison_operator": "is", "value": {}}], "logical_operator": "and"} result = get_metadata_filter_expression(filter_dict) assert "JSON_EXTRACT(metadata, '$.metadata')" in result assert "'{}'" in result def test_none_value_condition(self): """Test condition with None value.""" - filter_dict = { - "conditions": [ - {"name": "optional", "comparison_operator": "is", "value": None} - ], - "logical_operator": "and" - } + filter_dict = {"conditions": [{"name": "optional", "comparison_operator": "is", "value": None}], "logical_operator": "and"} result = get_metadata_filter_expression(filter_dict) assert "JSON_EXTRACT(metadata, '$.optional')" in result assert "NULL" in result def test_multiple_conditions_with_and(self): """Test multiple conditions with AND operator.""" - filter_dict = { - "conditions": [ - {"name": "author", "comparison_operator": "is", "value": "John"}, - {"name": "year", "comparison_operator": ">", "value": 2020} - ], - "logical_operator": "and" - } + filter_dict = {"conditions": [{"name": "author", "comparison_operator": "is", "value": "John"}, {"name": "year", "comparison_operator": ">", "value": 2020}], "logical_operator": "and"} result = get_metadata_filter_expression(filter_dict) assert "JSON_EXTRACT(metadata, '$.author')" in result assert "JSON_EXTRACT(metadata, '$.year')" in result @@ -235,11 +189,8 @@ class TestGetMetadataFilterExpression: def test_multiple_conditions_with_or(self): """Test multiple conditions with OR operator.""" filter_dict = { - "conditions": [ - {"name": "status", "comparison_operator": "is", "value": "active"}, - {"name": "status", "comparison_operator": "is", "value": "pending"} - ], - "logical_operator": "or" + "conditions": [{"name": "status", "comparison_operator": "is", "value": "active"}, {"name": "status", "comparison_operator": "is", "value": "pending"}], + "logical_operator": "or", } result = get_metadata_filter_expression(filter_dict) assert "JSON_EXTRACT(metadata, '$.status')" in result @@ -247,71 +198,40 @@ class TestGetMetadataFilterExpression: def test_greater_than_operator(self): """Test greater than comparison.""" - filter_dict = { - "conditions": [ - {"name": "score", "comparison_operator": ">", "value": 90} - ], - "logical_operator": "and" - } + filter_dict = {"conditions": [{"name": "score", "comparison_operator": ">", "value": 90}], "logical_operator": "and"} result = get_metadata_filter_expression(filter_dict) assert ">" in result assert "90" in result def test_less_than_operator(self): """Test less than comparison.""" - filter_dict = { - "conditions": [ - {"name": "age", "comparison_operator": "<", "value": 18} - ], - "logical_operator": "and" - } + filter_dict = {"conditions": [{"name": "age", "comparison_operator": "<", "value": 18}], "logical_operator": "and"} result = get_metadata_filter_expression(filter_dict) assert "<" in result assert "18" in result def test_contains_operator(self): """Test contains operator.""" - filter_dict = { - "conditions": [ - {"name": "title", "comparison_operator": "contains", "value": "Python"} - ], - "logical_operator": "and" - } + filter_dict = {"conditions": [{"name": "title", "comparison_operator": "contains", "value": "Python"}], "logical_operator": "and"} result = get_metadata_filter_expression(filter_dict) assert "JSON_EXTRACT(metadata, '$.title')" in result def test_empty_operator(self): """Test empty operator.""" - filter_dict = { - "conditions": [ - {"name": "description", "comparison_operator": "empty", "value": None} - ], - "logical_operator": "and" - } + filter_dict = {"conditions": [{"name": "description", "comparison_operator": "empty", "value": None}], "logical_operator": "and"} result = get_metadata_filter_expression(filter_dict) assert "JSON_EXTRACT(metadata, '$.description')" in result assert "IS NULL" in result or "= ''" in result def test_not_empty_operator(self): """Test not empty operator.""" - filter_dict = { - "conditions": [ - {"name": "description", "comparison_operator": "not empty", "value": None} - ], - "logical_operator": "and" - } + filter_dict = {"conditions": [{"name": "description", "comparison_operator": "not empty", "value": None}], "logical_operator": "and"} result = get_metadata_filter_expression(filter_dict) assert "JSON_EXTRACT(metadata, '$.description')" in result def test_parentheses_wrapping(self): """Test that result is wrapped in parentheses.""" - filter_dict = { - "conditions": [ - {"name": "field", "comparison_operator": "is", "value": "value"} - ], - "logical_operator": "and" - } + filter_dict = {"conditions": [{"name": "field", "comparison_operator": "is", "value": "value"}], "logical_operator": "and"} result = get_metadata_filter_expression(filter_dict) assert result.startswith("(") assert result.endswith(")") - diff --git a/test/unit_test/rag/utils/test_opensearch_doc_meta.py b/test/unit_test/rag/utils/test_opensearch_doc_meta.py index ead97f6f8b..e111461ab6 100644 --- a/test/unit_test/rag/utils/test_opensearch_doc_meta.py +++ b/test/unit_test/rag/utils/test_opensearch_doc_meta.py @@ -28,6 +28,7 @@ The OpenSearch and Elasticsearch SDKs are imported at module load; mocking the underlying client lets us exercise OSConnection methods in isolation without a live cluster. """ + from __future__ import annotations import sys @@ -140,9 +141,7 @@ class TestOSConnectionMetaSurface: def test_create_doc_meta_idx_exists(self): cls = _resolve_os_connection_class() assert callable(getattr(cls, "create_doc_meta_idx", None)), ( - "OSConnection.create_doc_meta_idx is required so the metadata " - "PATCH path does not raise AttributeError on OpenSearch backends " - "(issue #14570)." + "OSConnection.create_doc_meta_idx is required so the metadata PATCH path does not raise AttributeError on OpenSearch backends (issue #14570)." ) def test_refresh_idx_exists(self): @@ -173,17 +172,21 @@ class TestCreateDocMetaIdx: fake_indices.create.return_value = {"acknowledged": True} cls = _resolve_os_connection_class() - with patch.object(cls, "index_exist", return_value=False), \ - patch("rag.utils.opensearch_conn.os.path.exists", return_value=True), \ - patch( - "rag.utils.opensearch_conn.open", - new=_open_returning_payload({ - "settings": {"index": {"number_of_shards": 2}}, - "mappings": {"properties": {"meta_fields": {"type": "object"}}}, - }), - create=True, - ), \ - patch("opensearchpy.client.IndicesClient", return_value=fake_indices): + with ( + patch.object(cls, "index_exist", return_value=False), + patch("rag.utils.opensearch_conn.os.path.exists", return_value=True), + patch( + "rag.utils.opensearch_conn.open", + new=_open_returning_payload( + { + "settings": {"index": {"number_of_shards": 2}}, + "mappings": {"properties": {"meta_fields": {"type": "object"}}}, + } + ), + create=True, + ), + patch("opensearchpy.client.IndicesClient", return_value=fake_indices), + ): result = conn.create_doc_meta_idx("ragflow_doc_meta_t1") assert result == {"acknowledged": True} @@ -197,8 +200,7 @@ class TestCreateDocMetaIdx: def test_returns_false_when_mapping_file_missing(self): conn = _make_os_connection() cls = _resolve_os_connection_class() - with patch.object(cls, "index_exist", return_value=False), \ - patch("rag.utils.opensearch_conn.os.path.exists", return_value=False): + with patch.object(cls, "index_exist", return_value=False), patch("rag.utils.opensearch_conn.os.path.exists", return_value=False): assert conn.create_doc_meta_idx("ragflow_doc_meta_t1") is False def test_returns_false_when_create_call_explodes(self): @@ -210,14 +212,16 @@ class TestCreateDocMetaIdx: fake_indices = MagicMock() fake_indices.create.side_effect = RuntimeError("opensearch unreachable") - with patch.object(cls, "index_exist", return_value=False), \ - patch("rag.utils.opensearch_conn.os.path.exists", return_value=True), \ - patch( - "rag.utils.opensearch_conn.open", - new=_open_returning_payload({"settings": {}, "mappings": {}}), - create=True, - ), \ - patch("opensearchpy.client.IndicesClient", return_value=fake_indices): + with ( + patch.object(cls, "index_exist", return_value=False), + patch("rag.utils.opensearch_conn.os.path.exists", return_value=True), + patch( + "rag.utils.opensearch_conn.open", + new=_open_returning_payload({"settings": {}, "mappings": {}}), + create=True, + ), + patch("opensearchpy.client.IndicesClient", return_value=fake_indices), + ): assert conn.create_doc_meta_idx("ragflow_doc_meta_t1") is False @@ -229,9 +233,7 @@ class TestRefreshIdx: def test_returns_false_on_not_found(self): conn = _make_os_connection() - conn.os.indices.refresh.side_effect = opensearchpy.NotFoundError( - 404, "index_not_found_exception", {} - ) + conn.os.indices.refresh.side_effect = opensearchpy.NotFoundError(404, "index_not_found_exception", {}) assert conn.refresh_idx("missing_idx") is False def test_swallows_other_errors_and_returns_false(self): @@ -249,9 +251,7 @@ class TestCountIdx: def test_missing_index_reads_as_zero(self): conn = _make_os_connection() - conn.os.count.side_effect = opensearchpy.NotFoundError( - 404, "index_not_found_exception", {} - ) + conn.os.count.side_effect = opensearchpy.NotFoundError(404, "index_not_found_exception", {}) assert conn.count_idx("ragflow_doc_meta_t1") == 0 def test_other_failure_returns_negative_one(self): @@ -282,7 +282,5 @@ class TestReplaceMetaFields: def test_returns_false_when_doc_missing(self): conn = _make_os_connection() - conn.os.update.side_effect = opensearchpy.NotFoundError( - 404, "document_missing_exception", {} - ) + conn.os.update.side_effect = opensearchpy.NotFoundError(404, "document_missing_exception", {}) assert conn.replace_meta_fields("ragflow_doc_meta_t1", "absent", {"a": 1}) is False diff --git a/test/unit_test/rag/utils/test_raptor_utils.py b/test/unit_test/rag/utils/test_raptor_utils.py index b0b8581e31..f9c5f0b060 100644 --- a/test/unit_test/rag/utils/test_raptor_utils.py +++ b/test/unit_test/rag/utils/test_raptor_utils.py @@ -124,7 +124,7 @@ class TestAsExtraDict: def test_returns_empty_dict_for_non_dict_json(self): """Test that non-dict JSON returns empty dict.""" - input_str = '[1, 2, 3]' + input_str = "[1, 2, 3]" result = _as_extra_dict(input_str) assert result == {} @@ -208,32 +208,19 @@ class TestCollectRaptorMethods: def test_collects_methods_from_raptor_chunks(self): """Test that methods are collected from RAPTOR chunks.""" - field_map = { - "chunk_1": { - "raptor_kwd": "raptor", - "extra": {"raptor_method": PSI_TREE_BUILDER} - } - } + field_map = {"chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": PSI_TREE_BUILDER}}} result = collect_raptor_methods(field_map) assert result == {PSI_TREE_BUILDER} def test_skips_non_raptor_chunks(self): """Test that non-RAPTOR chunks are skipped.""" - field_map = { - "chunk_1": { - "raptor_kwd": "other", - "extra": {"raptor_method": PSI_TREE_BUILDER} - } - } + field_map = {"chunk_1": {"raptor_kwd": "other", "extra": {"raptor_method": PSI_TREE_BUILDER}}} result = collect_raptor_methods(field_map) assert result == set() def test_collects_multiple_methods(self): """Test that multiple methods are collected.""" - field_map = { - "chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": "raptor"}}, - "chunk_2": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}} - } + field_map = {"chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": "raptor"}}, "chunk_2": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}}} result = collect_raptor_methods(field_map) assert result == {RAPTOR_TREE_BUILDER, PSI_TREE_BUILDER} @@ -248,28 +235,19 @@ class TestCollectRaptorChunkIds: def test_collects_ids_of_raptor_chunks(self): """Test that IDs of RAPTOR chunks are collected.""" - field_map = { - "chunk_1": {"raptor_kwd": "raptor"}, - "chunk_2": {"raptor_kwd": "raptor"} - } + field_map = {"chunk_1": {"raptor_kwd": "raptor"}, "chunk_2": {"raptor_kwd": "raptor"}} result = collect_raptor_chunk_ids(field_map) assert result == {"chunk_1", "chunk_2"} def test_excludes_specified_methods(self): """Test that specified methods are excluded.""" - field_map = { - "chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": "raptor"}}, - "chunk_2": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}} - } + field_map = {"chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": "raptor"}}, "chunk_2": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}}} result = collect_raptor_chunk_ids(field_map, exclude_methods={"raptor"}) assert result == {"chunk_2"} def test_skips_non_raptor_chunks(self): """Test that non-RAPTOR chunks are skipped.""" - field_map = { - "chunk_1": {"raptor_kwd": "raptor"}, - "chunk_2": {"raptor_kwd": "other"} - } + field_map = {"chunk_1": {"raptor_kwd": "raptor"}, "chunk_2": {"raptor_kwd": "other"}} result = collect_raptor_chunk_ids(field_map) assert result == {"chunk_1"} @@ -400,17 +378,11 @@ class TestShouldSkipRaptor: def test_respects_auto_disable_config_false(self): """Test that auto_disable_for_structured_data=False disables skipping.""" - assert should_skip_raptor( - file_type=".xlsx", - raptor_config={"auto_disable_for_structured_data": False} - ) is False + assert should_skip_raptor(file_type=".xlsx", raptor_config={"auto_disable_for_structured_data": False}) is False def test_respects_auto_disable_config_true(self): """Test that auto_disable_for_structured_data=True enables skipping.""" - assert should_skip_raptor( - file_type=".xlsx", - raptor_config={"auto_disable_for_structured_data": True} - ) is True + assert should_skip_raptor(file_type=".xlsx", raptor_config={"auto_disable_for_structured_data": True}) is True def test_default_auto_disable_is_true(self): """Test that default auto_disable is True.""" diff --git a/tools/chatgpt-on-wechat/plugins/__init__.py b/tools/chatgpt-on-wechat/plugins/__init__.py index 557f0d1f1b..29ad3f6e6f 100644 --- a/tools/chatgpt-on-wechat/plugins/__init__.py +++ b/tools/chatgpt-on-wechat/plugins/__init__.py @@ -15,10 +15,9 @@ # from beartype.claw import beartype_this_package + beartype_this_package() from .ragflow_chat import RAGFlowChat -__all__ = [ - "RAGFlowChat" -] +__all__ = ["RAGFlowChat"] diff --git a/tools/chatgpt-on-wechat/plugins/ragflow_chat.py b/tools/chatgpt-on-wechat/plugins/ragflow_chat.py index fe96f39a74..af3f4a33af 100644 --- a/tools/chatgpt-on-wechat/plugins/ragflow_chat.py +++ b/tools/chatgpt-on-wechat/plugins/ragflow_chat.py @@ -21,6 +21,7 @@ from bridge.reply import Reply, ReplyType # Import Reply, ReplyType from plugins import Plugin, register # Import Plugin and register from plugins.event import Event, EventContext, EventAction # Import event-related classes + @register(name="RAGFlowChat", desc="Use RAGFlow API to chat", version="1.0", author="Your Name") class RAGFlowChat(Plugin): def __init__(self): @@ -34,12 +35,12 @@ class RAGFlowChat(Plugin): logging.info("[RAGFlowChat] Plugin initialized") def on_handle_context(self, e_context: EventContext): - context = e_context['context'] + context = e_context["context"] if context.type != ContextType.TEXT: return # Only process text messages user_input = context.content.strip() - session_id = context['session_id'] + session_id = context["session_id"] # Call RAGFlow API to get a reply reply_text = self.get_ragflow_reply(user_input, session_id) @@ -47,7 +48,7 @@ class RAGFlowChat(Plugin): reply = Reply() reply.type = ReplyType.TEXT reply.content = reply_text - e_context['reply'] = reply + e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS # Skip the default processing logic else: # If no reply is received, pass to the next plugin or default logic @@ -63,19 +64,14 @@ class RAGFlowChat(Plugin): logging.error("[RAGFlowChat] Missing configuration") return "The plugin configuration is incomplete. Please check the configuration." - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json" - } + headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} # Step 1: Get or create conversation_id conversation_id = self.conversations.get(user_id) if not conversation_id: # Create a new conversation url_new_conversation = f"http://{host_address}/v1/api/new_conversation" - params_new_conversation = { - "user_id": user_id - } + params_new_conversation = {"user_id": user_id} try: response = requests.get(url_new_conversation, headers=headers, params=params_new_conversation) logging.debug(f"[RAGFlowChat] New conversation response: {response.text}") @@ -96,17 +92,7 @@ class RAGFlowChat(Plugin): # Step 2: Send the message and get a reply url_completion = f"http://{host_address}/v1/api/completion" - payload_completion = { - "conversation_id": conversation_id, - "messages": [ - { - "role": "user", - "content": user_input - } - ], - "quote": False, - "stream": False - } + payload_completion = {"conversation_id": conversation_id, "messages": [{"role": "user", "content": user_input}], "quote": False, "stream": False} try: response = requests.post(url_completion, headers=headers, json=payload_completion) diff --git a/tools/es-to-oceanbase-migration/src/es_ob_migration/cli.py b/tools/es-to-oceanbase-migration/src/es_ob_migration/cli.py index bb3ec2477f..5e0886752c 100644 --- a/tools/es-to-oceanbase-migration/src/es_ob_migration/cli.py +++ b/tools/es-to-oceanbase-migration/src/es_ob_migration/cli.py @@ -38,7 +38,7 @@ def main(ctx, verbose): Migrate RAGFlow data from Elasticsearch 8+ to OceanBase with schema conversion, vector data mapping, batch import, and resume capability. - + This tool is specifically designed for RAGFlow's data structure. """ ctx.ensure_object(dict) @@ -84,7 +84,7 @@ def migrate( progress_dir, ): """Run RAGFlow data migration from Elasticsearch to OceanBase. - + If --index is omitted, all indices starting with 'ragflow_' will be migrated. If --table is omitted, the same name as the source index will be used. """ @@ -116,14 +116,14 @@ def migrate( # Auto-discover all ragflow_* indices console.print("\n[cyan]Discovering RAGFlow indices...[/]") ragflow_indices = es_client.list_ragflow_indices() - + if not ragflow_indices: console.print("[yellow]No ragflow_* indices found in Elasticsearch[/]") sys.exit(0) - + # Each index maps to a table with the same name indices_to_migrate = [(idx, idx) for idx in ragflow_indices] - + console.print(f"[green]Found {len(indices_to_migrate)} RAGFlow indices:[/]") for idx, _ in indices_to_migrate: doc_count = es_client.count_documents(idx) @@ -144,9 +144,9 @@ def migrate( # Migrate each index for es_index, ob_table in indices_to_migrate: - console.print(f"\n[bold blue]{'='*60}[/]") + console.print(f"\n[bold blue]{'=' * 60}[/]") console.print(f"[bold]Migrating: {es_index} -> {ob_database}.{ob_table}[/]") - console.print(f"[bold blue]{'='*60}[/]") + console.print(f"[bold blue]{'=' * 60}[/]") result = migrator.migrate( es_index=es_index, @@ -155,7 +155,7 @@ def migrate( resume=resume, verify_after=verify, ) - + results.append(result) if result["success"]: total_success += 1 @@ -164,9 +164,9 @@ def migrate( # Summary for multiple indices if len(indices_to_migrate) > 1: - console.print(f"\n[bold]{'='*60}[/]") + console.print(f"\n[bold]{'=' * 60}[/]") console.print("[bold]Migration Summary[/]") - console.print(f"[bold]{'='*60}[/]") + console.print(f"[bold]{'=' * 60}[/]") console.print(f" Total indices: {len(indices_to_migrate)}") console.print(f" [green]Successful: {total_success}[/]") if total_failed > 0: @@ -217,33 +217,34 @@ def schema(ctx, es_host, es_port, es_user, es_password, index, output): migrator = ESToOceanBaseMigrator(es_client, ob_client if ob_client else OBClient.__new__(OBClient)) # Directly use schema converter from .schema import RAGFlowSchemaConverter + converter = RAGFlowSchemaConverter() - + es_mapping = es_client.get_index_mapping(index) analysis = converter.analyze_es_mapping(es_mapping) column_defs = converter.get_column_definitions() # Display analysis console.print(f"\n[bold]ES Index Analysis: {index}[/]\n") - + # Known RAGFlow fields console.print(f"[green]Known RAGFlow fields:[/] {len(analysis['known_fields'])}") - + # Vector fields - if analysis['vector_fields']: + if analysis["vector_fields"]: console.print("\n[cyan]Vector fields detected:[/]") - for vf in analysis['vector_fields']: + for vf in analysis["vector_fields"]: console.print(f" - {vf['name']} (dimension: {vf['dimension']})") - + # Unknown fields - if analysis['unknown_fields']: + if analysis["unknown_fields"]: console.print("\n[yellow]Unknown fields (will be stored in 'extra'):[/]") - for uf in analysis['unknown_fields']: + for uf in analysis["unknown_fields"]: console.print(f" - {uf}") # Display RAGFlow column schema console.print(f"\n[bold]RAGFlow OceanBase Schema ({len(column_defs)} columns):[/]\n") - + table = Table(title="Column Definitions") table.add_column("Column Name", style="cyan") table.add_column("OB Type", style="green") @@ -260,7 +261,7 @@ def schema(ctx, es_host, es_port, es_user, es_password, index, output): special.append("ARRAY") if col.get("is_vector"): special.append("VECTOR") - + table.add_row( col["name"], col["ob_type"], @@ -333,7 +334,8 @@ def verify( verifier = MigrationVerifier(es_client, ob_client) result = verifier.verify( - index, table, + index, + table, sample_size=sample_size, ) @@ -386,7 +388,7 @@ def list_indices(ctx, es_host, es_port, es_user, es_password): for idx in indices: doc_count = es_client.count_documents(idx) total_docs += doc_count - + # Determine index type if idx.startswith("ragflow_doc_meta_"): idx_type = "Metadata" @@ -394,7 +396,7 @@ def list_indices(ctx, es_host, es_port, es_user, es_password): idx_type = "Document Chunks" else: idx_type = "Unknown" - + table.add_row(idx, f"{doc_count:,}", idx_type) table.add_row("", "", "") @@ -488,11 +490,11 @@ def status(ctx, es_host, es_port, ob_host, ob_port, ob_user, ob_password): console.print(f" Cluster: {health.get('cluster_name')}") console.print(f" Status: {health.get('status')}") console.print(f" Version: {info.get('version', {}).get('number', 'unknown')}") - + # List indices indices = es_client.list_indices("*") console.print(f" Indices: {len(indices)}") - + es_client.close() except Exception as e: console.print(f"[red]Elasticsearch ({es_host}:{es_port}): Failed[/]") @@ -542,7 +544,7 @@ def sample(ctx, es_host, es_port, index, size): console.print(f" kb_id: {doc.get('kb_id')}") console.print(f" doc_id: {doc.get('doc_id')}") console.print(f" docnm_kwd: {doc.get('docnm_kwd')}") - + # Check for vector fields vector_fields = [k for k in doc.keys() if k.startswith("q_") and k.endswith("_vec")] if vector_fields: @@ -550,14 +552,14 @@ def sample(ctx, es_host, es_port, index, size): vec = doc.get(vf) if vec: console.print(f" {vf}: [{len(vec)} dimensions]") - + content = doc.get("content_with_weight", "") if content: if isinstance(content, dict): content = json.dumps(content, ensure_ascii=False) preview = content[:100] + "..." if len(str(content)) > 100 else content console.print(f" content: {preview}") - + console.print() es_client.close() diff --git a/tools/es-to-oceanbase-migration/src/es_ob_migration/es_client.py b/tools/es-to-oceanbase-migration/src/es_ob_migration/es_client.py index e04f18e9dd..0dd97b3952 100644 --- a/tools/es-to-oceanbase-migration/src/es_ob_migration/es_client.py +++ b/tools/es-to-oceanbase-migration/src/es_ob_migration/es_client.py @@ -72,11 +72,11 @@ class ESClient: def list_ragflow_indices(self) -> list[str]: """ List all RAGFlow-related indices. - + Returns indices matching patterns: - ragflow_* (document chunks) - ragflow_doc_meta_* (document metadata) - + Returns: List of RAGFlow index names """ @@ -111,18 +111,14 @@ class ESClient: response = self.client.count(index=index_name) return response["count"] - def count_documents_with_filter( - self, - index_name: str, - filters: dict[str, Any] - ) -> int: + def count_documents_with_filter(self, index_name: str, filters: dict[str, Any]) -> int: """ Count documents with filter conditions. - + Args: index_name: Index name filters: Filter conditions (e.g., {"kb_id": "xxx"}) - + Returns: Document count """ @@ -133,30 +129,26 @@ class ESClient: must_clauses.append({"terms": {field: value}}) else: must_clauses.append({"term": {field: value}}) - - query = { - "bool": { - "must": must_clauses - } - } if must_clauses else {"match_all": {}} - + + query = {"bool": {"must": must_clauses}} if must_clauses else {"match_all": {}} + response = self.client.count(index=index_name, query=query) return response["count"] def aggregate_field( - self, - index_name: str, + self, + index_name: str, field: str, size: int = 10000, ) -> dict[str, Any]: """ Aggregate field values (like getting all unique kb_ids). - + Args: index_name: Index name field: Field to aggregate size: Max number of buckets - + Returns: Aggregation result with buckets """ @@ -170,7 +162,7 @@ class ESClient: "size": size, } } - } + }, ) return response["aggregations"]["field_values"] @@ -244,24 +236,21 @@ class ESClient: return None def get_sample_documents( - self, - index_name: str, + self, + index_name: str, size: int = 10, query: dict[str, Any] | None = None, ) -> list[dict[str, Any]]: """ Get sample documents from an index. - + Args: index_name: Index name size: Number of samples query: Optional query filter """ - search_body = { - "query": query if query else {"match_all": {}}, - "size": size - } - + search_body = {"query": query if query else {"match_all": {}}, "size": size} + response = self.client.search(index=index_name, body=search_body) documents = [] for hit in response["hits"]["hits"]: @@ -271,8 +260,8 @@ class ESClient: return documents def get_document_ids( - self, - index_name: str, + self, + index_name: str, size: int = 1000, query: dict[str, Any] | None = None, ) -> list[str]: @@ -282,7 +271,7 @@ class ESClient: "size": size, "_source": False, } - + response = self.client.search(index=index_name, body=search_body) return [hit["_id"] for hit in response["hits"]["hits"]] diff --git a/tools/es-to-oceanbase-migration/src/es_ob_migration/migrator.py b/tools/es-to-oceanbase-migration/src/es_ob_migration/migrator.py index ba194dcd10..20325b93eb 100644 --- a/tools/es-to-oceanbase-migration/src/es_ob_migration/migrator.py +++ b/tools/es-to-oceanbase-migration/src/es_ob_migration/migrator.py @@ -29,7 +29,7 @@ console = Console() class ESToOceanBaseMigrator: """ RAGFlow-specific migration orchestrator. - + This migrator is designed specifically for RAGFlow's data structure, handling the fixed schema and vector embeddings correctly. """ @@ -99,7 +99,7 @@ class ESToOceanBaseMigrator: # Step 2: Analyze ES index console.print("\n[bold blue]Step 2: Analyzing ES index...[/]") analysis = self._analyze_es_index(es_index) - + # Auto-detect vector size from ES mapping vector_size = 768 # Default fallback if analysis["vector_fields"]: @@ -114,7 +114,7 @@ class ESToOceanBaseMigrator: # Step 3: Get total document count total_docs = self.es_client.count_documents(es_index) console.print(f" Total documents: {total_docs:,}") - + result["total_documents"] = total_docs if total_docs == 0: @@ -126,21 +126,13 @@ class ESToOceanBaseMigrator: if resume and self.progress_manager.can_resume(es_index, ob_table): console.print("\n[bold yellow]Resuming from previous progress...[/]") progress = self.progress_manager.load_progress(es_index, ob_table) - console.print( - f" Previously migrated: {progress.migrated_documents:,} documents" - ) + console.print(f" Previously migrated: {progress.migrated_documents:,} documents") else: # Fresh start - check if table already exists if self.ob_client.table_exists(ob_table): - raise RuntimeError( - f"Table '{ob_table}' already exists in OceanBase. " - f"Migration aborted to prevent data conflicts. " - f"Please drop the table manually or use a different table name." - ) + raise RuntimeError(f"Table '{ob_table}' already exists in OceanBase. Migration aborted to prevent data conflicts. Please drop the table manually or use a different table name.") - progress = self.progress_manager.create_progress( - es_index, ob_table, total_docs - ) + progress = self.progress_manager.create_progress(es_index, ob_table, total_docs) # Step 5: Create table if needed if not progress.table_created: @@ -157,7 +149,7 @@ class ESToOceanBaseMigrator: console.print(f" Table '{ob_table}' already exists") # Check and add vector column if needed self.ob_client.add_vector_column(ob_table, vector_size) - + progress.table_created = True progress.indexes_created = True progress.schema_converted = True @@ -186,10 +178,7 @@ class ESToOceanBaseMigrator: if verify_after: console.print("\n[bold blue]Step 5: Verifying migration...[/]") verifier = MigrationVerifier(self.es_client, self.ob_client) - verification = verifier.verify( - es_index, ob_table, - primary_key="id" - ) + verification = verifier.verify(es_index, ob_table, primary_key="id") result["verification"] = { "passed": verification.passed, "message": verification.message, @@ -236,7 +225,7 @@ class ESToOceanBaseMigrator: # Check OceanBase if not self.ob_client.health_check(): raise RuntimeError("OceanBase connection failed") - + ob_version = self.ob_client.get_version() console.print(f" OceanBase connection: OK (version: {ob_version})") @@ -275,7 +264,7 @@ class ESToOceanBaseMigrator: batch_count = 0 for batch in self.es_client.scroll_documents(es_index, batch_size): batch_count += 1 - + # Convert batch to OceanBase format ob_rows = data_converter.convert_batch(batch) @@ -341,7 +330,7 @@ class ESToOceanBaseMigrator: ) -> list[dict[str, Any]]: """ Get sample documents from ES for preview. - + Args: es_index: ES index name sample_size: Number of samples @@ -355,10 +344,10 @@ class ESToOceanBaseMigrator: def list_knowledge_bases(self, es_index: str) -> list[str]: """ List all knowledge base IDs in an ES index. - + Args: es_index: ES index name - + Returns: List of kb_id values """ diff --git a/tools/es-to-oceanbase-migration/src/es_ob_migration/ob_client.py b/tools/es-to-oceanbase-migration/src/es_ob_migration/ob_client.py index 50a6d92de9..12f7e031d3 100644 --- a/tools/es-to-oceanbase-migration/src/es_ob_migration/ob_client.py +++ b/tools/es-to-oceanbase-migration/src/es_ob_migration/ob_client.py @@ -24,7 +24,7 @@ VECTOR_INDEX_NAME_TEMPLATE = "%s_idx" # Columns that need regular indexes INDEX_COLUMNS = [ "kb_id", - "doc_id", + "doc_id", "available_int", "knowledge_graph_kwd", "entity_type_kwd", @@ -110,7 +110,7 @@ class OBClient: ): """ Create a RAGFlow-compatible table in OceanBase. - + This creates a table with the exact schema that RAGFlow expects, including all columns, indexes, and vector columns. @@ -122,21 +122,18 @@ class OBClient: """ # Build column definitions columns = self._build_ragflow_columns() - + # Add vector column vector_column_name = f"q_{vector_size}_vec" - columns.append( - Column(vector_column_name, VECTOR(vector_size), nullable=True, - comment=f"vector embedding ({vector_size} dimensions)") - ) - + columns.append(Column(vector_column_name, VECTOR(vector_size), nullable=True, comment=f"vector embedding ({vector_size} dimensions)")) + # Table options (from RAGFlow) table_options = { "mysql_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci", "mysql_organization": "heap", } - + # Create table self.client.create_table( table_name=table_name, @@ -144,49 +141,49 @@ class OBClient: **table_options, ) logger.info(f"Created table: {table_name}") - + # Create regular indexes if create_indexes: self._create_regular_indexes(table_name) - + # Create fulltext indexes if create_fts_indexes: self._create_fulltext_indexes(table_name) - + # Create vector index self._create_vector_index(table_name, vector_column_name) - + # Refresh metadata self.client.refresh_metadata([table_name]) def _build_ragflow_columns(self) -> list[Column]: """Build SQLAlchemy Column objects for RAGFlow schema.""" columns = [] - + for col_name, col_def in RAGFLOW_COLUMNS.items(): ob_type = col_def["ob_type"] nullable = col_def.get("nullable", True) default = col_def.get("default") is_primary = col_def.get("is_primary", False) is_array = col_def.get("is_array", False) - + # Parse type and create appropriate Column col = self._create_column(col_name, ob_type, nullable, default, is_primary, is_array) columns.append(col) - + return columns def _create_column( - self, - name: str, - ob_type: str, + self, + name: str, + ob_type: str, nullable: bool, default: Any, is_primary: bool, is_array: bool, ) -> Column: """Create a SQLAlchemy Column object based on type string.""" - + # Handle array types if is_array or ob_type.startswith("ARRAY"): # Extract inner type @@ -196,26 +193,22 @@ class OBClient: inner_type = Integer else: inner_type = String(256) - + # Nested array (e.g., ARRAY(ARRAY(Integer))) if ob_type.count("ARRAY") > 1: return Column(name, ARRAY(ARRAY(inner_type)), nullable=nullable) else: return Column(name, ARRAY(inner_type), nullable=nullable) - + # Handle String types with length if ob_type.startswith("String"): # Extract length: String(256) -> 256 import re - match = re.search(r'\((\d+)\)', ob_type) + + match = re.search(r"\((\d+)\)", ob_type) length = int(match.group(1)) if match else 256 - return Column( - name, String(length), - primary_key=is_primary, - nullable=nullable, - server_default=f"'{default}'" if default else None - ) - + return Column(name, String(length), primary_key=is_primary, nullable=nullable, server_default=f"'{default}'" if default else None) + # Map other types type_map = { "Integer": Integer, @@ -225,16 +218,11 @@ class OBClient: "LONGTEXT": LONGTEXT, "TEXT": MYSQL_TEXT, } - + for type_name, type_class in type_map.items(): if type_name in ob_type: - return Column( - name, type_class, - primary_key=is_primary, - nullable=nullable, - server_default=str(default) if default is not None else None - ) - + return Column(name, type_class, primary_key=is_primary, nullable=nullable, server_default=str(default) if default is not None else None) + # Default to String return Column(name, String(256), nullable=nullable) @@ -298,19 +286,19 @@ class OBClient: def add_vector_column(self, table_name: str, vector_size: int): """Add a vector column to an existing table.""" vector_column_name = f"q_{vector_size}_vec" - + # Check if column exists if self._column_exists(table_name, vector_column_name): logger.info(f"Vector column {vector_column_name} already exists") return - + try: self.client.add_columns( table_name=table_name, columns=[Column(vector_column_name, VECTOR(vector_size), nullable=True)], ) logger.info(f"Added vector column: {vector_column_name}") - + # Create index self._create_vector_index(table_name, vector_column_name) except Exception as e: @@ -321,10 +309,7 @@ class OBClient: """Check if a column exists in a table.""" try: result = self.client.perform_raw_text_sql( - f"SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS " - f"WHERE TABLE_SCHEMA = '{self.database}' " - f"AND TABLE_NAME = '{table_name}' " - f"AND COLUMN_NAME = '{column_name}'" + f"SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = '{self.database}' AND TABLE_NAME = '{table_name}' AND COLUMN_NAME = '{column_name}'" ) count = result.fetchone()[0] return count > 0 @@ -335,10 +320,7 @@ class OBClient: """Check if an index exists.""" try: result = self.client.perform_raw_text_sql( - f"SELECT COUNT(*) FROM INFORMATION_SCHEMA.STATISTICS " - f"WHERE TABLE_SCHEMA = '{self.database}' " - f"AND TABLE_NAME = '{table_name}' " - f"AND INDEX_NAME = '{index_name}'" + f"SELECT COUNT(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE TABLE_SCHEMA = '{self.database}' AND TABLE_NAME = '{table_name}' AND INDEX_NAME = '{index_name}'" ) count = result.fetchone()[0] return count > 0 @@ -373,7 +355,7 @@ class OBClient: def count_rows(self, table_name: str, kb_id: str | None = None) -> int: """ Count rows in a table. - + Args: table_name: Table name kb_id: Optional knowledge base ID filter @@ -388,8 +370,8 @@ class OBClient: return 0 def get_sample_rows( - self, - table_name: str, + self, + table_name: str, limit: int = 10, kb_id: str | None = None, ) -> list[dict[str, Any]]: @@ -399,7 +381,7 @@ class OBClient: if kb_id: sql += f" WHERE kb_id = '{kb_id}'" sql += f" LIMIT {limit}" - + result = self.client.perform_raw_text_sql(sql) columns = result.keys() rows = [] diff --git a/tools/es-to-oceanbase-migration/src/es_ob_migration/progress.py b/tools/es-to-oceanbase-migration/src/es_ob_migration/progress.py index 29c23bb4bc..6c44dbc33b 100644 --- a/tools/es-to-oceanbase-migration/src/es_ob_migration/progress.py +++ b/tools/es-to-oceanbase-migration/src/es_ob_migration/progress.py @@ -64,9 +64,7 @@ class ProgressManager: filename = f"{es_index}_to_{ob_table}.json" return self.progress_dir / filename - def load_progress( - self, es_index: str, ob_table: str - ) -> MigrationProgress | None: + def load_progress(self, es_index: str, ob_table: str) -> MigrationProgress | None: """ Load progress from file. @@ -86,9 +84,7 @@ class ProgressManager: with open(progress_file, "r") as f: data = json.load(f) progress = MigrationProgress(**data) - logger.info( - f"Loaded progress: {progress.migrated_documents}/{progress.total_documents} documents" - ) + logger.info(f"Loaded progress: {progress.migrated_documents}/{progress.total_documents} documents") return progress except Exception as e: logger.warning(f"Failed to load progress: {e}") @@ -174,9 +170,7 @@ class ProgressManager: progress.status = "completed" progress.updated_at = datetime.utcnow().isoformat() self.save_progress(progress) - logger.info( - f"Migration completed: {progress.migrated_documents} documents" - ) + logger.info(f"Migration completed: {progress.migrated_documents} documents") def mark_failed(self, progress: MigrationProgress, error: str): """Mark migration as failed.""" @@ -191,9 +185,7 @@ class ProgressManager: progress.status = "paused" progress.updated_at = datetime.utcnow().isoformat() self.save_progress(progress) - logger.info( - f"Migration paused at {progress.migrated_documents}/{progress.total_documents}" - ) + logger.info(f"Migration paused at {progress.migrated_documents}/{progress.total_documents}") def can_resume(self, es_index: str, ob_table: str) -> bool: """Check if migration can be resumed.""" diff --git a/tools/es-to-oceanbase-migration/src/es_ob_migration/schema.py b/tools/es-to-oceanbase-migration/src/es_ob_migration/schema.py index 468bde9574..4a6aedd1dd 100644 --- a/tools/es-to-oceanbase-migration/src/es_ob_migration/schema.py +++ b/tools/es-to-oceanbase-migration/src/es_ob_migration/schema.py @@ -20,23 +20,18 @@ RAGFLOW_COLUMNS = { "id": {"ob_type": "String(256)", "nullable": False, "is_primary": True}, "kb_id": {"ob_type": "String(256)", "nullable": False, "index": True}, "doc_id": {"ob_type": "String(256)", "nullable": True, "index": True}, - # Document metadata "docnm_kwd": {"ob_type": "String(256)", "nullable": True}, # document name "doc_type_kwd": {"ob_type": "String(256)", "nullable": True}, # document type - # Title fields "title_tks": {"ob_type": "String(256)", "nullable": True}, # title tokens "title_sm_tks": {"ob_type": "String(256)", "nullable": True}, # fine-grained title tokens - # Content fields "content_with_weight": {"ob_type": "LONGTEXT", "nullable": True}, # original content "content_ltks": {"ob_type": "LONGTEXT", "nullable": True}, # long text tokens "content_sm_ltks": {"ob_type": "LONGTEXT", "nullable": True}, # fine-grained tokens - # Feature fields "pagerank_fea": {"ob_type": "Integer", "nullable": True}, # page rank priority - # Array fields "important_kwd": {"ob_type": "ARRAY(String(256))", "nullable": True, "is_array": True}, # keywords "important_tks": {"ob_type": "TEXT", "nullable": True}, # keyword tokens @@ -44,22 +39,17 @@ RAGFLOW_COLUMNS = { "question_tks": {"ob_type": "TEXT", "nullable": True}, # question tokens "tag_kwd": {"ob_type": "ARRAY(String(256))", "nullable": True, "is_array": True}, # tags "tag_feas": {"ob_type": "JSON", "nullable": True, "is_json": True}, # tag features - # Status fields "available_int": {"ob_type": "Integer", "nullable": False, "default": 1}, - # Time fields "create_time": {"ob_type": "String(19)", "nullable": True}, "create_timestamp_flt": {"ob_type": "Double", "nullable": True}, - # Image field "img_id": {"ob_type": "String(128)", "nullable": True}, - # Position fields (arrays) "position_int": {"ob_type": "ARRAY(ARRAY(Integer))", "nullable": True, "is_array": True}, "page_num_int": {"ob_type": "ARRAY(Integer)", "nullable": True, "is_array": True}, "top_int": {"ob_type": "ARRAY(Integer)", "nullable": True, "is_array": True}, - # Knowledge graph fields "knowledge_graph_kwd": {"ob_type": "String(256)", "nullable": True, "index": True}, "source_id": {"ob_type": "ARRAY(String(256))", "nullable": True, "is_array": True}, @@ -71,14 +61,11 @@ RAGFLOW_COLUMNS = { "weight_flt": {"ob_type": "Double", "nullable": True}, "entities_kwd": {"ob_type": "ARRAY(String(256))", "nullable": True, "is_array": True}, "rank_flt": {"ob_type": "Double", "nullable": True}, - # Status "removed_kwd": {"ob_type": "String(256)", "nullable": True, "index": True, "default": "N"}, - # JSON fields "metadata": {"ob_type": "JSON", "nullable": True, "is_json": True}, "extra": {"ob_type": "JSON", "nullable": True, "is_json": True}, - # New columns "_order_id": {"ob_type": "Integer", "nullable": True}, "group_id": {"ob_type": "String(256)", "nullable": True}, @@ -86,10 +73,7 @@ RAGFLOW_COLUMNS = { } # Array column names for special handling -ARRAY_COLUMNS = [ - "important_kwd", "question_kwd", "tag_kwd", "source_id", - "entities_kwd", "position_int", "page_num_int", "top_int" -] +ARRAY_COLUMNS = ["important_kwd", "question_kwd", "tag_kwd", "source_id", "entities_kwd", "position_int", "page_num_int", "top_int"] # JSON column names JSON_COLUMNS = ["tag_feas", "metadata", "extra"] @@ -105,7 +89,7 @@ VECTOR_FIELD_PATTERN = re.compile(r"q_(?P\d+)_vec") class RAGFlowSchemaConverter: """ Convert RAGFlow Elasticsearch documents to OceanBase format. - + RAGFlow uses a fixed schema, so this converter knows exactly what fields to expect and how to map them. """ @@ -117,10 +101,10 @@ class RAGFlowSchemaConverter: def analyze_es_mapping(self, es_mapping: dict[str, Any]) -> dict[str, Any]: """ Analyze ES mapping to extract vector field dimensions. - + Args: es_mapping: Elasticsearch index mapping - + Returns: Analysis result with detected fields """ @@ -129,9 +113,9 @@ class RAGFlowSchemaConverter: "vector_fields": [], "unknown_fields": [], } - + properties = es_mapping.get("properties", {}) - + for field_name, field_def in properties.items(): # Check if it's a known RAGFlow field if field_name in RAGFLOW_COLUMNS: @@ -140,59 +124,63 @@ class RAGFlowSchemaConverter: elif VECTOR_FIELD_PATTERN.match(field_name): match = VECTOR_FIELD_PATTERN.match(field_name) vec_size = int(match.group("vector_size")) - result["vector_fields"].append({ - "name": field_name, - "dimension": vec_size, - }) - self.vector_fields.append({ - "name": field_name, - "dimension": vec_size, - }) + result["vector_fields"].append( + { + "name": field_name, + "dimension": vec_size, + } + ) + self.vector_fields.append( + { + "name": field_name, + "dimension": vec_size, + } + ) if self.detected_vector_size is None: self.detected_vector_size = vec_size else: # Unknown field - might be custom field stored in 'extra' result["unknown_fields"].append(field_name) - - logger.info( - f"Analyzed ES mapping: {len(result['known_fields'])} known fields, " - f"{len(result['vector_fields'])} vector fields, " - f"{len(result['unknown_fields'])} unknown fields" - ) - + + logger.info(f"Analyzed ES mapping: {len(result['known_fields'])} known fields, {len(result['vector_fields'])} vector fields, {len(result['unknown_fields'])} unknown fields") + return result def get_column_definitions(self) -> list[dict[str, Any]]: """ Get RAGFlow column definitions for OceanBase table creation. - + Returns: List of column definitions """ columns = [] - + for col_name, col_def in RAGFLOW_COLUMNS.items(): - columns.append({ - "name": col_name, - "ob_type": col_def["ob_type"], - "nullable": col_def.get("nullable", True), - "is_primary": col_def.get("is_primary", False), - "index": col_def.get("index", False), - "is_array": col_def.get("is_array", False), - "is_json": col_def.get("is_json", False), - "default": col_def.get("default"), - }) - + columns.append( + { + "name": col_name, + "ob_type": col_def["ob_type"], + "nullable": col_def.get("nullable", True), + "is_primary": col_def.get("is_primary", False), + "index": col_def.get("index", False), + "is_array": col_def.get("is_array", False), + "is_json": col_def.get("is_json", False), + "default": col_def.get("default"), + } + ) + # Add detected vector fields for vec_field in self.vector_fields: - columns.append({ - "name": vec_field["name"], - "ob_type": f"VECTOR({vec_field['dimension']})", - "nullable": True, - "is_vector": True, - "dimension": vec_field["dimension"], - }) - + columns.append( + { + "name": vec_field["name"], + "ob_type": f"VECTOR({vec_field['dimension']})", + "nullable": True, + "is_vector": True, + "dimension": vec_field["dimension"], + } + ) + return columns def get_vector_fields(self) -> list[dict[str, Any]]: @@ -203,7 +191,7 @@ class RAGFlowSchemaConverter: class RAGFlowDataConverter: """ Convert RAGFlow ES documents to OceanBase row format. - + This converter handles the specific data transformations needed for RAGFlow's data structure. """ @@ -221,111 +209,104 @@ class RAGFlowDataConverter: def convert_document(self, es_doc: dict[str, Any]) -> dict[str, Any]: """ Convert an ES document to OceanBase row format. - + Args: es_doc: Elasticsearch document (with _id and _source) - + Returns: Dictionary ready for OceanBase insertion """ # Extract _id and _source doc_id = es_doc.get("_id") source = es_doc.get("_source", es_doc) - + row = {} - + # Set document ID if doc_id: row["id"] = str(doc_id) elif "id" in source: row["id"] = str(source["id"]) - + # Process each field for field_name, field_def in RAGFLOW_COLUMNS.items(): if field_name == "id": continue # Already handled - + value = source.get(field_name) - + if value is None: # Use default if available default = field_def.get("default") if default is not None: row[field_name] = default continue - + # Convert based on field type - row[field_name] = self._convert_field_value( - field_name, value, field_def - ) - + row[field_name] = self._convert_field_value(field_name, value, field_def) + # Handle vector fields for key, value in source.items(): if VECTOR_FIELD_PATTERN.match(key): if isinstance(value, list): row[key] = value self.vector_fields.add(key) - + # Handle unknown fields -> store in 'extra' extra_fields = {} for key, value in source.items(): if key not in RAGFLOW_COLUMNS and not VECTOR_FIELD_PATTERN.match(key): extra_fields[key] = value - + if extra_fields: existing_extra = row.get("extra") if existing_extra and isinstance(existing_extra, dict): existing_extra.update(extra_fields) else: row["extra"] = json.dumps(extra_fields, ensure_ascii=False) - + return row - def _convert_field_value( - self, - field_name: str, - value: Any, - field_def: dict[str, Any] - ) -> Any: + def _convert_field_value(self, field_name: str, value: Any, field_def: dict[str, Any]) -> Any: """ Convert a field value to the appropriate format for OceanBase. - + Args: field_name: Field name value: Original value from ES field_def: Field definition from RAGFLOW_COLUMNS - + Returns: Converted value """ if value is None: return None - + ob_type = field_def.get("ob_type", "") is_array = field_def.get("is_array", False) is_json = field_def.get("is_json", False) - + # Handle array fields if is_array: return self._convert_array_value(value) - + # Handle JSON fields if is_json: return self._convert_json_value(value) - + # Handle specific types if "Integer" in ob_type: return self._convert_integer(value) - + if "Double" in ob_type or "Float" in ob_type: return self._convert_float(value) - + if "LONGTEXT" in ob_type or "TEXT" in ob_type: return self._convert_text(value) - + if "String" in ob_type: return self._convert_string(value, field_name) - + # Default: convert to string return str(value) if value is not None else None @@ -333,7 +314,7 @@ class RAGFlowDataConverter: """Convert array value to JSON string for OceanBase.""" if value is None: return None - + if isinstance(value, str): # Already a JSON string try: @@ -343,7 +324,7 @@ class RAGFlowDataConverter: except json.JSONDecodeError: # Not valid JSON, wrap in array return json.dumps([value], ensure_ascii=False) - + if isinstance(value, list): # Clean array values cleaned = [] @@ -351,15 +332,15 @@ class RAGFlowDataConverter: if isinstance(item, str): # Clean special characters cleaned_str = item.strip() - cleaned_str = cleaned_str.replace('\\', '\\\\') - cleaned_str = cleaned_str.replace('\n', '\\n') - cleaned_str = cleaned_str.replace('\r', '\\r') - cleaned_str = cleaned_str.replace('\t', '\\t') + cleaned_str = cleaned_str.replace("\\", "\\\\") + cleaned_str = cleaned_str.replace("\n", "\\n") + cleaned_str = cleaned_str.replace("\r", "\\r") + cleaned_str = cleaned_str.replace("\t", "\\t") cleaned.append(cleaned_str) else: cleaned.append(item) return json.dumps(cleaned, ensure_ascii=False) - + # Single value - wrap in array return json.dumps([value], ensure_ascii=False) @@ -367,7 +348,7 @@ class RAGFlowDataConverter: """Convert JSON value to string for OceanBase.""" if value is None: return None - + if isinstance(value, str): # Already a string, validate JSON try: @@ -376,20 +357,20 @@ class RAGFlowDataConverter: except json.JSONDecodeError: # Not valid JSON, return as-is return value - + if isinstance(value, (dict, list)): return json.dumps(value, ensure_ascii=False) - + return str(value) def _convert_integer(self, value: Any) -> int | None: """Convert to integer.""" if value is None: return None - + if isinstance(value, bool): return 1 if value else 0 - + try: return int(value) except (ValueError, TypeError): @@ -399,7 +380,7 @@ class RAGFlowDataConverter: """Convert to float.""" if value is None: return None - + try: return float(value) except (ValueError, TypeError): @@ -409,37 +390,37 @@ class RAGFlowDataConverter: """Convert to text/longtext.""" if value is None: return None - + if isinstance(value, dict): # content_with_weight might be stored as dict return json.dumps(value, ensure_ascii=False) - + if isinstance(value, list): return json.dumps(value, ensure_ascii=False) - + return str(value) def _convert_string(self, value: Any, field_name: str) -> str | None: """Convert to string with length considerations.""" if value is None: return None - + # Handle kb_id which might be a list in ES if field_name == "kb_id" and isinstance(value, list): return str(value[0]) if value else None - + if isinstance(value, (dict, list)): return json.dumps(value, ensure_ascii=False) - + return str(value) def convert_batch(self, es_docs: list[dict[str, Any]]) -> list[dict[str, Any]]: """ Convert a batch of ES documents. - + Args: es_docs: List of Elasticsearch documents - + Returns: List of dictionaries ready for OceanBase insertion """ diff --git a/tools/es-to-oceanbase-migration/src/es_ob_migration/verify.py b/tools/es-to-oceanbase-migration/src/es_ob_migration/verify.py index 0df94bbc92..a67f4ee599 100644 --- a/tools/es-to-oceanbase-migration/src/es_ob_migration/verify.py +++ b/tools/es-to-oceanbase-migration/src/es_ob_migration/verify.py @@ -47,8 +47,13 @@ class MigrationVerifier: # Fields to compare for verification VERIFY_FIELDS = [ - "id", "kb_id", "doc_id", "docnm_kwd", "content_with_weight", - "available_int", "create_time", + "id", + "kb_id", + "doc_id", + "docnm_kwd", + "content_with_weight", + "available_int", + "create_time", ] def __init__( @@ -88,7 +93,7 @@ class MigrationVerifier: VerificationResult with details """ result = VerificationResult( - es_index=es_index, + es_index=es_index, ob_table=ob_table, ) @@ -97,26 +102,21 @@ class MigrationVerifier: # Step 1: Verify document counts logger.info("Verifying document counts...") - + result.es_count = self.es_client.count_documents(es_index) result.ob_count = self.ob_client.count_rows(ob_table) - + result.count_diff = abs(result.es_count - result.ob_count) result.count_match = result.count_diff == 0 - logger.info( - f"Document counts - ES: {result.es_count}, OB: {result.ob_count}, " - f"Diff: {result.count_diff}" - ) + logger.info(f"Document counts - ES: {result.es_count}, OB: {result.ob_count}, Diff: {result.count_diff}") # Step 2: Sample verification result.sample_size = min(sample_size, result.es_count) - + if result.sample_size > 0: logger.info(f"Verifying {result.sample_size} sample documents...") - self._verify_samples( - es_index, ob_table, result, primary_key, verify_fields - ) + self._verify_samples(es_index, ob_table, result, primary_key, verify_fields) # Step 3: Determine overall result self._determine_result(result) @@ -134,9 +134,7 @@ class MigrationVerifier: ): """Verify sample documents.""" # Get sample documents from ES - es_samples = self.es_client.get_sample_documents( - es_index, result.sample_size - ) + es_samples = self.es_client.get_sample_documents(es_index, result.sample_size) for es_doc in es_samples: result.samples_verified += 1 @@ -154,17 +152,17 @@ class MigrationVerifier: continue # Compare documents - match, differences = self._compare_documents( - es_doc, ob_doc, verify_fields - ) - + match, differences = self._compare_documents(es_doc, ob_doc, verify_fields) + if match: result.samples_matched += 1 else: - result.data_mismatches.append({ - "id": doc_id, - "differences": differences, - }) + result.data_mismatches.append( + { + "id": doc_id, + "differences": differences, + } + ) # Calculate match rate if result.samples_verified > 0: @@ -178,7 +176,7 @@ class MigrationVerifier: ) -> tuple[bool, list[dict[str, Any]]]: """ Compare ES document with OceanBase row. - + Returns: Tuple of (match: bool, differences: list) """ @@ -194,20 +192,17 @@ class MigrationVerifier: # Handle special comparisons if not self._values_equal(field_name, es_value, ob_value): - differences.append({ - "field": field_name, - "es_value": es_value, - "ob_value": ob_value, - }) + differences.append( + { + "field": field_name, + "es_value": es_value, + "ob_value": ob_value, + } + ) return len(differences) == 0, differences - def _values_equal( - self, - field_name: str, - es_value: Any, - ob_value: Any - ) -> bool: + def _values_equal(self, field_name: str, es_value: Any, ob_value: Any) -> bool: """Compare two values with type-aware logic.""" if es_value is None and ob_value is None: return True @@ -261,32 +256,19 @@ class MigrationVerifier: """Determine overall verification result.""" # Allow small count differences (e.g., documents added during migration) count_tolerance = 0.01 # 1% tolerance - count_ok = ( - result.count_match or - (result.es_count > 0 and result.count_diff / result.es_count <= count_tolerance) - ) + count_ok = result.count_match or (result.es_count > 0 and result.count_diff / result.es_count <= count_tolerance) if count_ok and result.sample_match_rate >= 0.99: result.passed = True - result.message = ( - f"Verification PASSED. " - f"ES: {result.es_count:,}, OB: {result.ob_count:,}. " - f"Sample match rate: {result.sample_match_rate:.2%}" - ) + result.message = f"Verification PASSED. ES: {result.es_count:,}, OB: {result.ob_count:,}. Sample match rate: {result.sample_match_rate:.2%}" elif count_ok and result.sample_match_rate >= 0.95: result.passed = True - result.message = ( - f"Verification PASSED with warnings. " - f"ES: {result.es_count:,}, OB: {result.ob_count:,}. " - f"Sample match rate: {result.sample_match_rate:.2%}" - ) + result.message = f"Verification PASSED with warnings. ES: {result.es_count:,}, OB: {result.ob_count:,}. Sample match rate: {result.sample_match_rate:.2%}" else: result.passed = False issues = [] if not count_ok: - issues.append( - f"Count mismatch (ES: {result.es_count}, OB: {result.ob_count}, diff: {result.count_diff})" - ) + issues.append(f"Count mismatch (ES: {result.es_count}, OB: {result.ob_count}, diff: {result.count_diff})") if result.sample_match_rate < 0.95: issues.append(f"Low sample match rate: {result.sample_match_rate:.2%}") if result.missing_in_ob: @@ -303,22 +285,24 @@ class MigrationVerifier: f"ES Index: {result.es_index}", f"OB Table: {result.ob_table}", ] - - lines.extend([ - "", - "Document Counts:", - f" Elasticsearch: {result.es_count:,}", - f" OceanBase: {result.ob_count:,}", - f" Difference: {result.count_diff:,}", - f" Match: {'Yes' if result.count_match else 'No'}", - "", - "Sample Verification:", - f" Sample Size: {result.sample_size}", - f" Verified: {result.samples_verified}", - f" Matched: {result.samples_matched}", - f" Match Rate: {result.sample_match_rate:.2%}", - "", - ]) + + lines.extend( + [ + "", + "Document Counts:", + f" Elasticsearch: {result.es_count:,}", + f" OceanBase: {result.ob_count:,}", + f" Difference: {result.count_diff:,}", + f" Match: {'Yes' if result.count_match else 'No'}", + "", + "Sample Verification:", + f" Sample Size: {result.sample_size}", + f" Verified: {result.samples_verified}", + f" Matched: {result.samples_matched}", + f" Match Rate: {result.sample_match_rate:.2%}", + "", + ] + ) if result.missing_in_ob: lines.append(f"Missing in OceanBase ({len(result.missing_in_ob)}):") @@ -338,12 +322,14 @@ class MigrationVerifier: lines.append(f" ... and {len(result.data_mismatches) - 3} more") lines.append("") - lines.extend([ - "=" * 60, - f"Result: {'PASSED' if result.passed else 'FAILED'}", - result.message, - "=" * 60, - "", - ]) + lines.extend( + [ + "=" * 60, + f"Result: {'PASSED' if result.passed else 'FAILED'}", + result.message, + "=" * 60, + "", + ] + ) return "\n".join(lines) diff --git a/tools/es-to-oceanbase-migration/tests/test_progress.py b/tools/es-to-oceanbase-migration/tests/test_progress.py index 0e7368d548..cba107f9f5 100644 --- a/tools/es-to-oceanbase-migration/tests/test_progress.py +++ b/tools/es-to-oceanbase-migration/tests/test_progress.py @@ -20,7 +20,7 @@ class TestMigrationProgress: es_index="ragflow_test", ob_table="ragflow_test", ) - + assert progress.es_index == "ragflow_test" assert progress.ob_table == "ragflow_test" assert progress.total_documents == 0 @@ -37,7 +37,7 @@ class TestMigrationProgress: total_documents=1000, migrated_documents=500, ) - + assert progress.total_documents == 1000 assert progress.migrated_documents == 500 @@ -47,7 +47,7 @@ class TestMigrationProgress: es_index="test_index", ob_table="test_table", ) - + assert progress.failed_documents == 0 assert progress.last_sort_values == [] assert progress.last_batch_ids == [] @@ -99,7 +99,7 @@ class TestProgressManager: ob_table="ragflow_abc123", total_documents=1000, ) - + assert progress.es_index == "ragflow_abc123" assert progress.ob_table == "ragflow_abc123" assert progress.total_documents == 1000 @@ -116,10 +116,10 @@ class TestProgressManager: progress.migrated_documents = 250 progress.last_sort_values = ["doc_250", 1234567890] manager.save_progress(progress) - + # Load loaded = manager.load_progress("ragflow_test", "ragflow_test") - + assert loaded is not None assert loaded.es_index == "ragflow_test" assert loaded.total_documents == 500 @@ -139,13 +139,13 @@ class TestProgressManager: ob_table="ragflow_delete_test", total_documents=100, ) - + # Verify it exists assert manager.load_progress("ragflow_delete_test", "ragflow_delete_test") is not None - + # Delete manager.delete_progress("ragflow_delete_test", "ragflow_delete_test") - + # Verify it's gone assert manager.load_progress("ragflow_delete_test", "ragflow_delete_test") is None @@ -156,7 +156,7 @@ class TestProgressManager: ob_table="ragflow_update", total_documents=1000, ) - + # Update manager.update_progress( progress, @@ -164,7 +164,7 @@ class TestProgressManager: last_sort_values=["doc_100", 9999], last_batch_ids=["id1", "id2", "id3"], ) - + assert progress.migrated_documents == 100 assert progress.last_sort_values == ["doc_100", 9999] assert progress.last_batch_ids == ["id1", "id2", "id3"] @@ -176,11 +176,11 @@ class TestProgressManager: ob_table="ragflow_multi", total_documents=1000, ) - + # Update multiple times for i in range(1, 11): manager.update_progress(progress, migrated_count=100) - + assert progress.migrated_documents == 1000 def test_mark_completed(self, manager): @@ -191,9 +191,9 @@ class TestProgressManager: total_documents=100, ) progress.migrated_documents = 100 - + manager.mark_completed(progress) - + assert progress.status == "completed" def test_mark_failed(self, manager): @@ -203,9 +203,9 @@ class TestProgressManager: ob_table="ragflow_fail", total_documents=100, ) - + manager.mark_failed(progress, "Connection timeout") - + assert progress.status == "failed" assert progress.error_message == "Connection timeout" @@ -217,9 +217,9 @@ class TestProgressManager: total_documents=1000, ) progress.migrated_documents = 500 - + manager.mark_paused(progress) - + assert progress.status == "paused" def test_can_resume_running(self, manager): @@ -229,7 +229,7 @@ class TestProgressManager: ob_table="ragflow_resume_running", total_documents=1000, ) - + assert manager.can_resume("ragflow_resume_running", "ragflow_resume_running") is True def test_can_resume_paused(self, manager): @@ -240,7 +240,7 @@ class TestProgressManager: total_documents=1000, ) manager.mark_paused(progress) - + assert manager.can_resume("ragflow_resume_paused", "ragflow_resume_paused") is True def test_can_resume_completed(self, manager): @@ -252,7 +252,7 @@ class TestProgressManager: ) progress.migrated_documents = 100 manager.mark_completed(progress) - + # Completed migrations should not be resumed assert manager.can_resume("ragflow_resume_complete", "ragflow_resume_complete") is False @@ -272,9 +272,9 @@ class TestProgressManager: progress.schema_converted = True progress.table_created = True manager.save_progress(progress) - + info = manager.get_resume_info("ragflow_info", "ragflow_info") - + assert info is not None assert info["migrated_documents"] == 500 assert info["total_documents"] == 1000 @@ -295,7 +295,7 @@ class TestProgressManager: ob_table="ragflow_abc123", total_documents=100, ) - + expected_file = manager.progress_dir / "ragflow_abc123_to_ragflow_abc123.json" assert expected_file.exists() @@ -308,12 +308,12 @@ class TestProgressManager: ) progress.migrated_documents = 50 manager.save_progress(progress) - + # Read file directly progress_file = manager.progress_dir / "ragflow_json_to_ragflow_json.json" with open(progress_file) as f: data = json.load(f) - + assert data["es_index"] == "ragflow_json" assert data["ob_table"] == "ragflow_json" assert data["total_documents"] == 100 diff --git a/tools/es-to-oceanbase-migration/tests/test_schema.py b/tools/es-to-oceanbase-migration/tests/test_schema.py index cd55f98ceb..e856ae110f 100644 --- a/tools/es-to-oceanbase-migration/tests/test_schema.py +++ b/tools/es-to-oceanbase-migration/tests/test_schema.py @@ -27,7 +27,7 @@ class TestRAGFlowSchemaConverter: def test_analyze_ragflow_mapping(self): """Test analyzing a RAGFlow ES mapping.""" converter = RAGFlowSchemaConverter() - + # Simulate a RAGFlow ES mapping es_mapping = { "properties": { @@ -42,14 +42,14 @@ class TestRAGFlowSchemaConverter: "q_768_vec": {"type": "dense_vector", "dims": 768}, } } - + analysis = converter.analyze_es_mapping(es_mapping) - + # Check known fields assert "id" in analysis["known_fields"] assert "kb_id" in analysis["known_fields"] assert "content_with_weight" in analysis["known_fields"] - + # Check vector fields assert len(analysis["vector_fields"]) == 1 assert analysis["vector_fields"][0]["name"] == "q_768_vec" @@ -58,21 +58,21 @@ class TestRAGFlowSchemaConverter: def test_detect_vector_size(self): """Test automatic vector size detection.""" converter = RAGFlowSchemaConverter() - + es_mapping = { "properties": { "q_1536_vec": {"type": "dense_vector", "dims": 1536}, } } - + converter.analyze_es_mapping(es_mapping) - + assert converter.detected_vector_size == 1536 def test_unknown_fields(self): """Test that unknown fields are properly identified.""" converter = RAGFlowSchemaConverter() - + es_mapping = { "properties": { "id": {"type": "keyword"}, @@ -80,16 +80,16 @@ class TestRAGFlowSchemaConverter: "another_field": {"type": "integer"}, } } - + analysis = converter.analyze_es_mapping(es_mapping) - + assert "custom_field" in analysis["unknown_fields"] assert "another_field" in analysis["unknown_fields"] def test_get_column_definitions(self): """Test getting RAGFlow column definitions.""" converter = RAGFlowSchemaConverter() - + # First analyze to detect vector fields es_mapping = { "properties": { @@ -97,15 +97,15 @@ class TestRAGFlowSchemaConverter: } } converter.analyze_es_mapping(es_mapping) - + columns = converter.get_column_definitions() - + # Check that all RAGFlow columns are present column_names = [c["name"] for c in columns] - + for col_name in RAGFLOW_COLUMNS: assert col_name in column_names, f"Missing column: {col_name}" - + # Check vector column is added assert "q_768_vec" in column_names @@ -116,7 +116,7 @@ class TestRAGFlowDataConverter: def test_convert_basic_document(self): """Test converting a basic RAGFlow document.""" converter = RAGFlowDataConverter() - + es_doc = { "_id": "test-id-123", "_source": { @@ -126,11 +126,11 @@ class TestRAGFlowDataConverter: "docnm_kwd": "test_document.pdf", "content_with_weight": "This is test content", "available_int": 1, - } + }, } - + row = converter.convert_document(es_doc) - + assert row["id"] == "test-id-123" assert row["kb_id"] == "kb-001" assert row["doc_id"] == "doc-001" @@ -141,7 +141,7 @@ class TestRAGFlowDataConverter: def test_convert_with_vector(self): """Test converting document with vector embedding.""" converter = RAGFlowDataConverter() - + embedding = [0.1] * 768 es_doc = { "_id": "vec-doc-001", @@ -149,11 +149,11 @@ class TestRAGFlowDataConverter: "id": "vec-doc-001", "kb_id": "kb-001", "q_768_vec": embedding, - } + }, } - + row = converter.convert_document(es_doc) - + assert row["id"] == "vec-doc-001" assert row["q_768_vec"] == embedding assert "q_768_vec" in converter.vector_fields @@ -161,7 +161,7 @@ class TestRAGFlowDataConverter: def test_convert_array_fields(self): """Test converting array fields.""" converter = RAGFlowDataConverter() - + es_doc = { "_id": "array-doc", "_source": { @@ -170,11 +170,11 @@ class TestRAGFlowDataConverter: "important_kwd": ["keyword1", "keyword2", "keyword3"], "question_kwd": ["What is this?", "How does it work?"], "tag_kwd": ["tag1", "tag2"], - } + }, } - + row = converter.convert_document(es_doc) - + # Array fields should be JSON strings assert isinstance(row["important_kwd"], str) parsed = json.loads(row["important_kwd"]) @@ -183,7 +183,7 @@ class TestRAGFlowDataConverter: def test_convert_json_fields(self): """Test converting JSON fields.""" converter = RAGFlowDataConverter() - + es_doc = { "_id": "json-doc", "_source": { @@ -191,22 +191,22 @@ class TestRAGFlowDataConverter: "kb_id": "kb-001", "tag_feas": {"tag1": 0.8, "tag2": 0.5}, "metadata": {"author": "John", "date": "2024-01-01"}, - } + }, } - + row = converter.convert_document(es_doc) - + # JSON fields should be JSON strings assert isinstance(row["tag_feas"], str) assert isinstance(row["metadata"], str) - + tag_feas = json.loads(row["tag_feas"]) assert tag_feas == {"tag1": 0.8, "tag2": 0.5} def test_convert_unknown_fields_to_extra(self): """Test that unknown fields are stored in 'extra'.""" converter = RAGFlowDataConverter() - + es_doc = { "_id": "extra-doc", "_source": { @@ -214,11 +214,11 @@ class TestRAGFlowDataConverter: "kb_id": "kb-001", "custom_field": "custom_value", "another_custom": 123, - } + }, } - + row = converter.convert_document(es_doc) - + assert "extra" in row extra = json.loads(row["extra"]) assert extra["custom_field"] == "custom_value" @@ -227,24 +227,24 @@ class TestRAGFlowDataConverter: def test_convert_kb_id_list(self): """Test converting kb_id when it's a list (ES format).""" converter = RAGFlowDataConverter() - + es_doc = { "_id": "kb-list-doc", "_source": { "id": "kb-list-doc", "kb_id": ["kb-001", "kb-002"], # Some ES docs have list - } + }, } - + row = converter.convert_document(es_doc) - + # Should take first element assert row["kb_id"] == "kb-001" def test_convert_content_with_weight_dict(self): """Test converting content_with_weight when it's a dict.""" converter = RAGFlowDataConverter() - + es_doc = { "_id": "content-dict-doc", "_source": { @@ -254,11 +254,11 @@ class TestRAGFlowDataConverter: "text": "Some content", "weight": 1.0, }, - } + }, } - + row = converter.convert_document(es_doc) - + # Dict should be JSON serialized assert isinstance(row["content_with_weight"], str) parsed = json.loads(row["content_with_weight"]) @@ -267,14 +267,11 @@ class TestRAGFlowDataConverter: def test_convert_batch(self): """Test batch conversion.""" converter = RAGFlowDataConverter() - - es_docs = [ - {"_id": f"doc-{i}", "_source": {"id": f"doc-{i}", "kb_id": "kb-001"}} - for i in range(5) - ] - + + es_docs = [{"_id": f"doc-{i}", "_source": {"id": f"doc-{i}", "kb_id": "kb-001"}} for i in range(5)] + rows = converter.convert_batch(es_docs) - + assert len(rows) == 5 for i, row in enumerate(rows): assert row["id"] == f"doc-{i}" @@ -291,7 +288,7 @@ class TestVectorFieldPattern: "q_1536_vec", "q_3072_vec", ] - + for name in valid_names: match = VECTOR_FIELD_PATTERN.match(name) assert match is not None, f"Should match: {name}" @@ -305,7 +302,7 @@ class TestVectorFieldPattern: "vector_768", "content_with_weight", ] - + for name in invalid_names: match = VECTOR_FIELD_PATTERN.match(name) assert match is None, f"Should not match: {name}" @@ -322,28 +319,30 @@ class TestConstants: def test_array_columns(self): """Test ARRAY_COLUMNS list.""" - expected = [ - "important_kwd", "question_kwd", "tag_kwd", "source_id", - "entities_kwd", "position_int", "page_num_int", "top_int" - ] - + expected = ["important_kwd", "question_kwd", "tag_kwd", "source_id", "entities_kwd", "position_int", "page_num_int", "top_int"] + for col in expected: assert col in ARRAY_COLUMNS, f"Missing array column: {col}" def test_json_columns(self): """Test JSON_COLUMNS list.""" expected = ["tag_feas", "metadata", "extra"] - + for col in expected: assert col in JSON_COLUMNS, f"Missing JSON column: {col}" def test_ragflow_columns_completeness(self): """Test that RAGFLOW_COLUMNS has all required fields.""" required_fields = [ - "id", "kb_id", "doc_id", "content_with_weight", - "available_int", "metadata", "extra", + "id", + "kb_id", + "doc_id", + "content_with_weight", + "available_int", + "metadata", + "extra", ] - + for field in required_fields: assert field in RAGFLOW_COLUMNS, f"Missing required field: {field}" @@ -357,15 +356,15 @@ class TestConstants: # Primary key assert RAGFLOW_COLUMNS["id"]["is_primary"] is True assert RAGFLOW_COLUMNS["id"]["nullable"] is False - + # Indexed columns assert RAGFLOW_COLUMNS["kb_id"]["index"] is True assert RAGFLOW_COLUMNS["doc_id"]["index"] is True - + # Array columns assert RAGFLOW_COLUMNS["important_kwd"]["is_array"] is True assert RAGFLOW_COLUMNS["question_kwd"]["is_array"] is True - + # JSON columns assert RAGFLOW_COLUMNS["metadata"]["is_json"] is True assert RAGFLOW_COLUMNS["extra"]["is_json"] is True @@ -377,9 +376,9 @@ class TestRAGFlowSchemaConverterEdgeCases: def test_empty_mapping(self): """Test analyzing empty mapping.""" converter = RAGFlowSchemaConverter() - + analysis = converter.analyze_es_mapping({}) - + assert analysis["known_fields"] == [] assert analysis["vector_fields"] == [] assert analysis["unknown_fields"] == [] @@ -387,24 +386,24 @@ class TestRAGFlowSchemaConverterEdgeCases: def test_mapping_without_properties(self): """Test mapping without properties key.""" converter = RAGFlowSchemaConverter() - + analysis = converter.analyze_es_mapping({"some_other_key": {}}) - + assert analysis["known_fields"] == [] def test_multiple_vector_fields(self): """Test detecting multiple vector fields.""" converter = RAGFlowSchemaConverter() - + es_mapping = { "properties": { "q_768_vec": {"type": "dense_vector", "dims": 768}, "q_1024_vec": {"type": "dense_vector", "dims": 1024}, } } - + analysis = converter.analyze_es_mapping(es_mapping) - + assert len(analysis["vector_fields"]) == 2 # First detected should be set assert converter.detected_vector_size in [768, 1024] @@ -412,9 +411,9 @@ class TestRAGFlowSchemaConverterEdgeCases: def test_get_column_definitions_without_analysis(self): """Test getting columns without prior analysis.""" converter = RAGFlowSchemaConverter() - + columns = converter.get_column_definitions() - + # Should have all RAGFlow columns but no vector columns column_names = [c["name"] for c in columns] assert "id" in column_names @@ -423,16 +422,16 @@ class TestRAGFlowSchemaConverterEdgeCases: def test_get_vector_fields(self): """Test getting vector fields.""" converter = RAGFlowSchemaConverter() - + es_mapping = { "properties": { "q_1536_vec": {"type": "dense_vector", "dims": 1536}, } } converter.analyze_es_mapping(es_mapping) - + vec_fields = converter.get_vector_fields() - + assert len(vec_fields) == 1 assert vec_fields[0]["name"] == "q_1536_vec" assert vec_fields[0]["dimension"] == 1536 @@ -444,60 +443,60 @@ class TestRAGFlowDataConverterEdgeCases: def test_convert_empty_document(self): """Test converting empty document.""" converter = RAGFlowDataConverter() - + es_doc = {"_id": "empty_doc", "_source": {}} row = converter.convert_document(es_doc) - + assert row["id"] == "empty_doc" def test_convert_document_without_source(self): """Test converting document without _source.""" converter = RAGFlowDataConverter() - + es_doc = {"_id": "no_source", "id": "no_source", "kb_id": "kb_001"} row = converter.convert_document(es_doc) - + assert row["id"] == "no_source" assert row["kb_id"] == "kb_001" def test_convert_boolean_to_integer(self): """Test converting boolean to integer.""" converter = RAGFlowDataConverter() - + es_doc = { "_id": "bool_doc", "_source": { "id": "bool_doc", "kb_id": "kb_001", "available_int": True, - } + }, } - + row = converter.convert_document(es_doc) - + assert row["available_int"] == 1 def test_convert_invalid_integer(self): """Test converting invalid integer value.""" converter = RAGFlowDataConverter() - + es_doc = { "_id": "invalid_int", "_source": { "id": "invalid_int", "kb_id": "kb_001", "available_int": "not_a_number", - } + }, } - + row = converter.convert_document(es_doc) - + assert row["available_int"] is None def test_convert_float_field(self): """Test converting float fields.""" converter = RAGFlowDataConverter() - + es_doc = { "_id": "float_doc", "_source": { @@ -505,29 +504,29 @@ class TestRAGFlowDataConverterEdgeCases: "kb_id": "kb_001", "weight_flt": 0.85, "rank_flt": "0.95", # String that should become float - } + }, } - + row = converter.convert_document(es_doc) - + assert row["weight_flt"] == 0.85 assert row["rank_flt"] == 0.95 def test_convert_array_with_special_characters(self): """Test converting array with special characters.""" converter = RAGFlowDataConverter() - + es_doc = { "_id": "special_array", "_source": { "id": "special_array", "kb_id": "kb_001", "important_kwd": ["key\nwith\nnewlines", "key\twith\ttabs"], - } + }, } - + row = converter.convert_document(es_doc) - + # Should be JSON string with escaped characters assert isinstance(row["important_kwd"], str) parsed = json.loads(row["important_kwd"]) @@ -536,85 +535,85 @@ class TestRAGFlowDataConverterEdgeCases: def test_convert_already_json_array(self): """Test converting already JSON-encoded array.""" converter = RAGFlowDataConverter() - + es_doc = { "_id": "json_array", "_source": { "id": "json_array", "kb_id": "kb_001", "important_kwd": '["already", "json"]', - } + }, } - + row = converter.convert_document(es_doc) - + assert row["important_kwd"] == '["already", "json"]' def test_convert_single_value_to_array(self): """Test converting single value to array.""" converter = RAGFlowDataConverter() - + es_doc = { "_id": "single_to_array", "_source": { "id": "single_to_array", "kb_id": "kb_001", "important_kwd": "single_keyword", - } + }, } - + row = converter.convert_document(es_doc) - + parsed = json.loads(row["important_kwd"]) assert parsed == ["single_keyword"] def test_detect_vector_fields_from_document(self): """Test detecting vector fields from document.""" converter = RAGFlowDataConverter() - + doc = { "q_768_vec": [0.1] * 768, "q_1024_vec": [0.2] * 1024, } - + converter.detect_vector_fields(doc) - + assert "q_768_vec" in converter.vector_fields assert "q_1024_vec" in converter.vector_fields def test_convert_with_default_values(self): """Test conversion uses default values.""" converter = RAGFlowDataConverter() - + es_doc = { "_id": "default_test", "_source": { "id": "default_test", "kb_id": "kb_001", # available_int not provided, should get default - } + }, } - + row = converter.convert_document(es_doc) - + # available_int has default of 1 assert row.get("available_int") == 1 def test_convert_list_content(self): """Test converting list content to JSON.""" converter = RAGFlowDataConverter() - + es_doc = { "_id": "list_content", "_source": { "id": "list_content", "kb_id": "kb_001", "content_with_weight": ["part1", "part2", "part3"], - } + }, } - + row = converter.convert_document(es_doc) - + assert isinstance(row["content_with_weight"], str) parsed = json.loads(row["content_with_weight"]) assert parsed == ["part1", "part2", "part3"] @@ -622,15 +621,15 @@ class TestRAGFlowDataConverterEdgeCases: def test_convert_batch_empty(self): """Test batch conversion with empty list.""" converter = RAGFlowDataConverter() - + rows = converter.convert_batch([]) - + assert rows == [] def test_existing_extra_field_merged(self): """Test that existing extra field is merged with unknown fields.""" converter = RAGFlowDataConverter() - + es_doc = { "_id": "merge_extra", "_source": { @@ -638,11 +637,11 @@ class TestRAGFlowDataConverterEdgeCases: "kb_id": "kb_001", "extra": {"existing_key": "existing_value"}, "custom_field": "custom_value", - } + }, } - + row = converter.convert_document(es_doc) - + # extra should contain both existing and new fields extra = json.loads(row["extra"]) assert "custom_field" in extra diff --git a/tools/es-to-oceanbase-migration/tests/test_verify.py b/tools/es-to-oceanbase-migration/tests/test_verify.py index d0b9ee225e..dc42cb6fe7 100644 --- a/tools/es-to-oceanbase-migration/tests/test_verify.py +++ b/tools/es-to-oceanbase-migration/tests/test_verify.py @@ -17,7 +17,7 @@ class TestVerificationResult: es_index="ragflow_test", ob_table="ragflow_test", ) - + assert result.es_index == "ragflow_test" assert result.ob_table == "ragflow_test" assert result.es_count == 0 @@ -30,7 +30,7 @@ class TestVerificationResult: es_index="test", ob_table="test", ) - + assert result.count_match is False assert result.count_diff == 0 assert result.sample_size == 0 @@ -50,7 +50,7 @@ class TestVerificationResult: ob_count=1000, count_match=True, ) - + assert result.es_count == 1000 assert result.ob_count == 1000 assert result.count_match is True @@ -85,10 +85,10 @@ class TestMigrationVerifier: mock_es_client.count_documents.return_value = 1000 mock_ob_client.count_rows.return_value = 1000 mock_es_client.get_sample_documents.return_value = [] - + verifier = MigrationVerifier(mock_es_client, mock_ob_client) result = verifier.verify("ragflow_test", "ragflow_test", sample_size=0) - + assert result.es_count == 1000 assert result.ob_count == 1000 assert result.count_match is True @@ -99,10 +99,10 @@ class TestMigrationVerifier: mock_es_client.count_documents.return_value = 1000 mock_ob_client.count_rows.return_value = 950 mock_es_client.get_sample_documents.return_value = [] - + verifier = MigrationVerifier(mock_es_client, mock_ob_client) result = verifier.verify("ragflow_test", "ragflow_test", sample_size=0) - + assert result.es_count == 1000 assert result.ob_count == 950 assert result.count_match is False @@ -111,70 +111,64 @@ class TestMigrationVerifier: def test_verify_samples_all_match(self, mock_es_client, mock_ob_client): """Test sample verification when all samples match.""" # Setup ES samples - es_samples = [ - {"_id": f"doc_{i}", "id": f"doc_{i}", "kb_id": "kb_001", "content_with_weight": f"content_{i}"} - for i in range(10) - ] + es_samples = [{"_id": f"doc_{i}", "id": f"doc_{i}", "kb_id": "kb_001", "content_with_weight": f"content_{i}"} for i in range(10)] mock_es_client.count_documents.return_value = 100 mock_es_client.get_sample_documents.return_value = es_samples - + # Setup OB to return matching documents def get_row(table, doc_id): return {"id": doc_id, "kb_id": "kb_001", "content_with_weight": f"content_{doc_id.split('_')[1]}"} - + mock_ob_client.count_rows.return_value = 100 mock_ob_client.get_row_by_id.side_effect = get_row - + verifier = MigrationVerifier(mock_es_client, mock_ob_client) result = verifier.verify("ragflow_test", "ragflow_test", sample_size=10) - + assert result.samples_verified == 10 assert result.samples_matched == 10 assert result.sample_match_rate == 1.0 def test_verify_samples_some_missing(self, mock_es_client, mock_ob_client): """Test sample verification when some documents are missing.""" - es_samples = [ - {"_id": f"doc_{i}", "id": f"doc_{i}", "kb_id": "kb_001"} - for i in range(10) - ] + es_samples = [{"_id": f"doc_{i}", "id": f"doc_{i}", "kb_id": "kb_001"} for i in range(10)] mock_es_client.count_documents.return_value = 100 mock_es_client.get_sample_documents.return_value = es_samples - + # Only return some documents def get_row(table, doc_id): idx = int(doc_id.split("_")[1]) if idx < 7: # Only return first 7 return {"id": doc_id, "kb_id": "kb_001"} return None - + mock_ob_client.count_rows.return_value = 100 mock_ob_client.get_row_by_id.side_effect = get_row - + verifier = MigrationVerifier(mock_es_client, mock_ob_client) result = verifier.verify("ragflow_test", "ragflow_test", sample_size=10) - + assert result.samples_verified == 10 assert result.samples_matched == 7 assert len(result.missing_in_ob) == 3 def test_verify_samples_data_mismatch(self, mock_es_client, mock_ob_client): """Test sample verification when data doesn't match.""" - es_samples = [ - {"_id": "doc_1", "id": "doc_1", "kb_id": "kb_001", "available_int": 1} - ] + es_samples = [{"_id": "doc_1", "id": "doc_1", "kb_id": "kb_001", "available_int": 1}] mock_es_client.count_documents.return_value = 100 mock_es_client.get_sample_documents.return_value = es_samples - + # Return document with different data mock_ob_client.count_rows.return_value = 100 mock_ob_client.get_row_by_id.return_value = { - "id": "doc_1", "kb_id": "kb_002", "available_int": 0 # Different values + "id": "doc_1", + "kb_id": "kb_002", + "available_int": 0, # Different values } - + verifier = MigrationVerifier(mock_es_client, mock_ob_client) result = verifier.verify("ragflow_test", "ragflow_test", sample_size=1) - + assert result.samples_verified == 1 assert result.samples_matched == 0 assert len(result.data_mismatches) == 1 @@ -188,56 +182,36 @@ class TestMigrationVerifier: def test_values_equal_array_columns(self, verifier): """Test value comparison for array columns.""" # Array stored as JSON string in OB - assert verifier._values_equal( - "important_kwd", - ["key1", "key2"], - '["key1", "key2"]' - ) is True - + assert verifier._values_equal("important_kwd", ["key1", "key2"], '["key1", "key2"]') is True + # Order shouldn't matter for arrays - assert verifier._values_equal( - "important_kwd", - ["key2", "key1"], - '["key1", "key2"]' - ) is True + assert verifier._values_equal("important_kwd", ["key2", "key1"], '["key1", "key2"]') is True def test_values_equal_json_columns(self, verifier): """Test value comparison for JSON columns.""" - assert verifier._values_equal( - "metadata", - {"author": "John"}, - '{"author": "John"}' - ) is True + assert verifier._values_equal("metadata", {"author": "John"}, '{"author": "John"}') is True def test_values_equal_kb_id_list(self, verifier): """Test kb_id comparison when ES has list.""" # ES sometimes stores kb_id as list - assert verifier._values_equal( - "kb_id", - ["kb_001", "kb_002"], - "kb_001" - ) is True + assert verifier._values_equal("kb_id", ["kb_001", "kb_002"], "kb_001") is True def test_values_equal_content_with_weight_dict(self, verifier): """Test content_with_weight comparison when OB has JSON string.""" - assert verifier._values_equal( - "content_with_weight", - {"text": "content", "weight": 1.0}, - '{"text": "content", "weight": 1.0}' - ) is True + assert verifier._values_equal("content_with_weight", {"text": "content", "weight": 1.0}, '{"text": "content", "weight": 1.0}') is True def test_determine_result_passed(self, mock_es_client, mock_ob_client): """Test result determination for passed verification.""" mock_es_client.count_documents.return_value = 1000 mock_ob_client.count_rows.return_value = 1000 - + es_samples = [{"_id": f"doc_{i}", "id": f"doc_{i}", "kb_id": "kb_001"} for i in range(100)] mock_es_client.get_sample_documents.return_value = es_samples mock_ob_client.get_row_by_id.side_effect = lambda t, d: {"id": d, "kb_id": "kb_001"} - + verifier = MigrationVerifier(mock_es_client, mock_ob_client) result = verifier.verify("test", "test", sample_size=100) - + assert result.passed is True assert "PASSED" in result.message @@ -246,10 +220,10 @@ class TestMigrationVerifier: mock_es_client.count_documents.return_value = 1000 mock_ob_client.count_rows.return_value = 500 # Big difference mock_es_client.get_sample_documents.return_value = [] - + verifier = MigrationVerifier(mock_es_client, mock_ob_client) result = verifier.verify("test", "test", sample_size=0) - + assert result.passed is False assert "FAILED" in result.message @@ -257,14 +231,14 @@ class TestMigrationVerifier: """Test result determination when sample verification fails.""" mock_es_client.count_documents.return_value = 100 mock_ob_client.count_rows.return_value = 100 - + es_samples = [{"_id": f"doc_{i}", "id": f"doc_{i}"} for i in range(10)] mock_es_client.get_sample_documents.return_value = es_samples mock_ob_client.get_row_by_id.return_value = None # All missing - + verifier = MigrationVerifier(mock_es_client, mock_ob_client) result = verifier.verify("test", "test", sample_size=10) - + assert result.passed is False def test_generate_report(self, verifier): @@ -283,9 +257,9 @@ class TestMigrationVerifier: passed=True, message="Verification PASSED", ) - + report = verifier.generate_report(result) - + assert "ragflow_test" in report assert "1,000" in report assert "PASSED" in report @@ -308,9 +282,9 @@ class TestMigrationVerifier: passed=False, message="Verification FAILED", ) - + report = verifier.generate_report(result) - + assert "Missing in OceanBase" in report assert "doc_1" in report assert "FAILED" in report @@ -327,20 +301,13 @@ class TestMigrationVerifier: samples_verified=10, samples_matched=8, sample_match_rate=0.8, - data_mismatches=[ - { - "id": "doc_1", - "differences": [ - {"field": "kb_id", "es_value": "kb_001", "ob_value": "kb_002"} - ] - } - ], + data_mismatches=[{"id": "doc_1", "differences": [{"field": "kb_id", "es_value": "kb_001", "ob_value": "kb_002"}]}], passed=False, message="Verification FAILED", ) - + report = verifier.generate_report(result) - + assert "Data Mismatches" in report assert "doc_1" in report assert "kb_id" in report diff --git a/tools/firecrawl/example_usage.py b/tools/firecrawl/example_usage.py index fc8faeed56..647e1597e0 100644 --- a/tools/firecrawl/example_usage.py +++ b/tools/firecrawl/example_usage.py @@ -12,31 +12,31 @@ from .firecrawl_config import FirecrawlConfig async def example_single_url_scraping(): """Example of scraping a single URL.""" print("=== Single URL Scraping Example ===") - + # Configuration config = { "api_key": "fc-your-api-key-here", # Replace with your actual API key "api_url": "https://api.firecrawl.dev", "max_retries": 3, "timeout": 30, - "rate_limit_delay": 1.0 + "rate_limit_delay": 1.0, } - + # Create integration integration = create_firecrawl_integration(config) - + # Test connection connection_test = await integration.test_connection() print(f"Connection test: {connection_test}") - + if not connection_test["success"]: print("Connection failed, please check your API key") return - + # Scrape a single URL urls = ["https://httpbin.org/json"] documents = await integration.scrape_and_import(urls) - + for doc in documents: print(f"Title: {doc.title}") print(f"URL: {doc.source_url}") @@ -49,37 +49,31 @@ async def example_single_url_scraping(): async def example_website_crawling(): """Example of crawling an entire website.""" print("=== Website Crawling Example ===") - + # Configuration config = { "api_key": "fc-your-api-key-here", # Replace with your actual API key "api_url": "https://api.firecrawl.dev", "max_retries": 3, "timeout": 30, - "rate_limit_delay": 1.0 + "rate_limit_delay": 1.0, } - + # Create integration integration = create_firecrawl_integration(config) - + # Crawl a website start_url = "https://httpbin.org" documents = await integration.crawl_and_import( start_url=start_url, limit=5, # Limit to 5 pages for demo - scrape_options={ - "formats": ["markdown", "html"], - "extractOptions": { - "extractMainContent": True, - "excludeTags": ["nav", "footer", "header"] - } - } + scrape_options={"formats": ["markdown", "html"], "extractOptions": {"extractMainContent": True, "excludeTags": ["nav", "footer", "header"]}}, ) - + print(f"Crawled {len(documents)} pages from {start_url}") - + for i, doc in enumerate(documents): - print(f"Page {i+1}: {doc.title}") + print(f"Page {i + 1}: {doc.title}") print(f"URL: {doc.source_url}") print(f"Content length: {len(doc.content)}") print("-" * 30) @@ -88,42 +82,31 @@ async def example_website_crawling(): async def example_batch_processing(): """Example of batch processing multiple URLs.""" print("=== Batch Processing Example ===") - + # Configuration config = { "api_key": "fc-your-api-key-here", # Replace with your actual API key "api_url": "https://api.firecrawl.dev", "max_retries": 3, "timeout": 30, - "rate_limit_delay": 1.0 + "rate_limit_delay": 1.0, } - + # Create integration integration = create_firecrawl_integration(config) - + # Batch scrape multiple URLs - urls = [ - "https://httpbin.org/json", - "https://httpbin.org/html", - "https://httpbin.org/xml" - ] - - documents = await integration.scrape_and_import( - urls=urls, - formats=["markdown", "html"], - extract_options={ - "extractMainContent": True, - "excludeTags": ["nav", "footer", "header"] - } - ) - + urls = ["https://httpbin.org/json", "https://httpbin.org/html", "https://httpbin.org/xml"] + + documents = await integration.scrape_and_import(urls=urls, formats=["markdown", "html"], extract_options={"extractMainContent": True, "excludeTags": ["nav", "footer", "header"]}) + print(f"Processed {len(documents)} URLs") - + for doc in documents: print(f"Title: {doc.title}") print(f"URL: {doc.source_url}") print(f"Content length: {len(doc.content)}") - + # Example of chunking for RAG processing chunks = integration.processor.chunk_content(doc, chunk_size=500, chunk_overlap=100) print(f"Number of chunks: {len(chunks)}") @@ -133,38 +116,34 @@ async def example_batch_processing(): async def example_content_processing(): """Example of content processing and chunking.""" print("=== Content Processing Example ===") - + # Configuration config = { "api_key": "fc-your-api-key-here", # Replace with your actual API key "api_url": "https://api.firecrawl.dev", "max_retries": 3, "timeout": 30, - "rate_limit_delay": 1.0 + "rate_limit_delay": 1.0, } - + # Create integration integration = create_firecrawl_integration(config) - + # Scrape content urls = ["https://httpbin.org/html"] documents = await integration.scrape_and_import(urls) - + for doc in documents: print(f"Original document: {doc.title}") print(f"Content length: {len(doc.content)}") - + # Chunk the content - chunks = integration.processor.chunk_content( - doc, - chunk_size=1000, - chunk_overlap=200 - ) - + chunks = integration.processor.chunk_content(doc, chunk_size=1000, chunk_overlap=200) + print(f"Number of chunks: {len(chunks)}") - + for i, chunk in enumerate(chunks): - print(f"Chunk {i+1}:") + print(f"Chunk {i + 1}:") print(f" ID: {chunk['id']}") print(f" Content length: {len(chunk['content'])}") print(f" Metadata: {chunk['metadata']}") @@ -174,23 +153,17 @@ async def example_content_processing(): async def example_error_handling(): """Example of error handling.""" print("=== Error Handling Example ===") - + # Configuration with invalid API key - config = { - "api_key": "invalid-key", - "api_url": "https://api.firecrawl.dev", - "max_retries": 3, - "timeout": 30, - "rate_limit_delay": 1.0 - } - + config = {"api_key": "invalid-key", "api_url": "https://api.firecrawl.dev", "max_retries": 3, "timeout": 30, "rate_limit_delay": 1.0} + # Create integration integration = create_firecrawl_integration(config) - + # Test connection (should fail) connection_test = await integration.test_connection() print(f"Connection test with invalid key: {connection_test}") - + # Try to scrape (should fail gracefully) try: urls = ["https://httpbin.org/json"] @@ -203,33 +176,27 @@ async def example_error_handling(): async def example_configuration_validation(): """Example of configuration validation.""" print("=== Configuration Validation Example ===") - + # Test various configurations test_configs = [ - { - "api_key": "fc-valid-key", - "api_url": "https://api.firecrawl.dev", - "max_retries": 3, - "timeout": 30, - "rate_limit_delay": 1.0 - }, + {"api_key": "fc-valid-key", "api_url": "https://api.firecrawl.dev", "max_retries": 3, "timeout": 30, "rate_limit_delay": 1.0}, { "api_key": "invalid-key", # Invalid format - "api_url": "https://api.firecrawl.dev" + "api_url": "https://api.firecrawl.dev", }, { "api_key": "fc-valid-key", "api_url": "invalid-url", # Invalid URL "max_retries": 15, # Too high "timeout": 500, # Too high - "rate_limit_delay": 15.0 # Too high - } + "rate_limit_delay": 15.0, # Too high + }, ] - + for i, config in enumerate(test_configs): - print(f"Test configuration {i+1}:") + print(f"Test configuration {i + 1}:") errors = RAGFlowFirecrawlIntegration(FirecrawlConfig.from_dict(config)).validate_config(config) - + if errors: print(" Errors found:") for field, error in errors.items(): @@ -243,17 +210,17 @@ async def main(): """Run all examples.""" # Set up logging logging.basicConfig(level=logging.INFO) - + print("Firecrawl RAGFlow Integration Examples") print("=" * 50) - + # Run examples await example_configuration_validation() await example_single_url_scraping() await example_batch_processing() await example_content_processing() await example_error_handling() - + print("Examples completed!") diff --git a/tools/firecrawl/firecrawl_config.py b/tools/firecrawl/firecrawl_config.py index dc5f9cb38c..be85d67916 100644 --- a/tools/firecrawl/firecrawl_config.py +++ b/tools/firecrawl/firecrawl_config.py @@ -11,52 +11,52 @@ import json @dataclass class FirecrawlConfig: """Configuration class for Firecrawl integration.""" - + api_key: str api_url: str = "https://api.firecrawl.dev" max_retries: int = 3 timeout: int = 30 rate_limit_delay: float = 1.0 max_concurrent_requests: int = 5 - + def __post_init__(self): """Validate configuration after initialization.""" if not self.api_key: raise ValueError("Firecrawl API key is required") - + if not self.api_key.startswith("fc-"): raise ValueError("Invalid Firecrawl API key format. Must start with 'fc-'") - + if self.max_retries < 1 or self.max_retries > 10: raise ValueError("Max retries must be between 1 and 10") - + if self.timeout < 5 or self.timeout > 300: raise ValueError("Timeout must be between 5 and 300 seconds") - + if self.rate_limit_delay < 0.1 or self.rate_limit_delay > 10.0: raise ValueError("Rate limit delay must be between 0.1 and 10.0 seconds") - + @classmethod def from_env(cls) -> "FirecrawlConfig": """Create configuration from environment variables.""" api_key = os.getenv("FIRECRAWL_API_KEY") if not api_key: raise ValueError("FIRECRAWL_API_KEY environment variable not set") - + return cls( api_key=api_key, api_url=os.getenv("FIRECRAWL_API_URL", "https://api.firecrawl.dev"), max_retries=int(os.getenv("FIRECRAWL_MAX_RETRIES", "3")), timeout=int(os.getenv("FIRECRAWL_TIMEOUT", "30")), rate_limit_delay=float(os.getenv("FIRECRAWL_RATE_LIMIT_DELAY", "1.0")), - max_concurrent_requests=int(os.getenv("FIRECRAWL_MAX_CONCURRENT", "5")) + max_concurrent_requests=int(os.getenv("FIRECRAWL_MAX_CONCURRENT", "5")), ) - + @classmethod def from_dict(cls, config_dict: Dict[str, Any]) -> "FirecrawlConfig": """Create configuration from dictionary.""" return cls(**config_dict) - + def to_dict(self) -> Dict[str, Any]: """Convert configuration to dictionary.""" return { @@ -65,13 +65,13 @@ class FirecrawlConfig: "max_retries": self.max_retries, "timeout": self.timeout, "rate_limit_delay": self.rate_limit_delay, - "max_concurrent_requests": self.max_concurrent_requests + "max_concurrent_requests": self.max_concurrent_requests, } - + def to_json(self) -> str: """Convert configuration to JSON string.""" return json.dumps(self.to_dict(), indent=2) - + @classmethod def from_json(cls, json_str: str) -> "FirecrawlConfig": """Create configuration from JSON string.""" diff --git a/tools/firecrawl/firecrawl_connector.py b/tools/firecrawl/firecrawl_connector.py index d587e3a9d7..1ccdbf8b41 100644 --- a/tools/firecrawl/firecrawl_connector.py +++ b/tools/firecrawl/firecrawl_connector.py @@ -15,7 +15,7 @@ from firecrawl_config import FirecrawlConfig @dataclass class ScrapedContent: """Represents scraped content from Firecrawl.""" - + url: str markdown: Optional[str] = None html: Optional[str] = None @@ -29,7 +29,7 @@ class ScrapedContent: @dataclass class CrawlJob: """Represents a crawl job from Firecrawl.""" - + job_id: str status: str total: Optional[int] = None @@ -40,93 +40,82 @@ class CrawlJob: class FirecrawlConnector: """Main connector class for Firecrawl integration with RAGFlow.""" - + def __init__(self, config: FirecrawlConfig): """Initialize the Firecrawl connector.""" self.config = config self.logger = logging.getLogger(__name__) self.session: Optional[aiohttp.ClientSession] = None self._rate_limit_semaphore = asyncio.Semaphore(config.max_concurrent_requests) - + async def __aenter__(self): """Async context manager entry.""" await self._create_session() return self - + async def __aexit__(self, exc_type, exc_val, exc_tb): """Async context manager exit.""" await self._close_session() - + async def _create_session(self): """Create aiohttp session with proper headers.""" - headers = { - "Authorization": f"Bearer {self.config.api_key}", - "Content-Type": "application/json", - "User-Agent": "RAGFlow-Firecrawl-Plugin/1.0.0" - } - + headers = {"Authorization": f"Bearer {self.config.api_key}", "Content-Type": "application/json", "User-Agent": "RAGFlow-Firecrawl-Plugin/1.0.0"} + timeout = aiohttp.ClientTimeout(total=self.config.timeout) - self.session = aiohttp.ClientSession( - headers=headers, - timeout=timeout - ) - + self.session = aiohttp.ClientSession(headers=headers, timeout=timeout) + async def _close_session(self): """Close aiohttp session.""" if self.session: await self.session.close() - + async def _make_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]: """Make HTTP request with rate limiting and retry logic.""" async with self._rate_limit_semaphore: # Rate limiting await asyncio.sleep(self.config.rate_limit_delay) - + url = f"{self.config.api_url}{endpoint}" - + for attempt in range(self.config.max_retries): try: async with self.session.request(method, url, **kwargs) as response: if response.status == 429: # Rate limited - wait_time = 2 ** attempt + wait_time = 2**attempt self.logger.warning(f"Rate limited, waiting {wait_time}s") await asyncio.sleep(wait_time) continue - + response.raise_for_status() return await response.json() - + except aiohttp.ClientError as e: self.logger.error(f"Request failed (attempt {attempt + 1}): {e}") if attempt == self.config.max_retries - 1: raise - await asyncio.sleep(2 ** attempt) - + await asyncio.sleep(2**attempt) + raise Exception("Max retries exceeded") - - async def scrape_url(self, url: str, formats: List[str] = None, - extract_options: Dict[str, Any] = None) -> ScrapedContent: + + async def scrape_url(self, url: str, formats: List[str] = None, extract_options: Dict[str, Any] = None) -> ScrapedContent: """Scrape a single URL.""" if formats is None: formats = ["markdown", "html"] - - payload = { - "url": url, - "formats": formats - } - + + payload = {"url": url, "formats": formats} + if extract_options: payload["extractOptions"] = extract_options - + try: response = await self._make_request("POST", "/v2/scrape", json=payload) - + if not response.get("success"): return ScrapedContent(url=url, error=response.get("error", "Unknown error")) - + data = response.get("data", {}) metadata = data.get("metadata", {}) - + return ScrapedContent( url=url, markdown=data.get("markdown"), @@ -134,118 +123,96 @@ class FirecrawlConnector: metadata=metadata, title=metadata.get("title"), description=metadata.get("description"), - status_code=metadata.get("statusCode") + status_code=metadata.get("statusCode"), ) - + except Exception as e: self.logger.error(f"Failed to scrape {url}: {e}") return ScrapedContent(url=url, error=str(e)) - - async def start_crawl(self, url: str, limit: int = 100, - scrape_options: Dict[str, Any] = None) -> CrawlJob: + + async def start_crawl(self, url: str, limit: int = 100, scrape_options: Dict[str, Any] = None) -> CrawlJob: """Start a crawl job.""" if scrape_options is None: scrape_options = {"formats": ["markdown", "html"]} - - payload = { - "url": url, - "limit": limit, - "scrapeOptions": scrape_options - } - + + payload = {"url": url, "limit": limit, "scrapeOptions": scrape_options} + try: response = await self._make_request("POST", "/v2/crawl", json=payload) - + if not response.get("success"): - return CrawlJob( - job_id="", - status="failed", - error=response.get("error", "Unknown error") - ) - + return CrawlJob(job_id="", status="failed", error=response.get("error", "Unknown error")) + job_id = response.get("id") return CrawlJob(job_id=job_id, status="started") - + except Exception as e: self.logger.error(f"Failed to start crawl for {url}: {e}") return CrawlJob(job_id="", status="failed", error=str(e)) - + async def get_crawl_status(self, job_id: str) -> CrawlJob: """Get the status of a crawl job.""" try: response = await self._make_request("GET", f"/v2/crawl/{job_id}") - + if not response.get("success"): - return CrawlJob( - job_id=job_id, - status="failed", - error=response.get("error", "Unknown error") - ) - + return CrawlJob(job_id=job_id, status="failed", error=response.get("error", "Unknown error")) + status = response.get("status", "unknown") total = response.get("total") data = response.get("data", []) - + # Convert data to ScrapedContent objects scraped_content = [] for item in data: metadata = item.get("metadata", {}) - scraped_content.append(ScrapedContent( - url=metadata.get("sourceURL", ""), - markdown=item.get("markdown"), - html=item.get("html"), - metadata=metadata, - title=metadata.get("title"), - description=metadata.get("description"), - status_code=metadata.get("statusCode") - )) - - return CrawlJob( - job_id=job_id, - status=status, - total=total, - completed=len(scraped_content), - data=scraped_content - ) - + scraped_content.append( + ScrapedContent( + url=metadata.get("sourceURL", ""), + markdown=item.get("markdown"), + html=item.get("html"), + metadata=metadata, + title=metadata.get("title"), + description=metadata.get("description"), + status_code=metadata.get("statusCode"), + ) + ) + + return CrawlJob(job_id=job_id, status=status, total=total, completed=len(scraped_content), data=scraped_content) + except Exception as e: self.logger.error(f"Failed to get crawl status for {job_id}: {e}") return CrawlJob(job_id=job_id, status="failed", error=str(e)) - - async def wait_for_crawl_completion(self, job_id: str, - poll_interval: int = 30) -> CrawlJob: + + async def wait_for_crawl_completion(self, job_id: str, poll_interval: int = 30) -> CrawlJob: """Wait for a crawl job to complete.""" while True: job = await self.get_crawl_status(job_id) - + if job.status in ["completed", "failed", "cancelled"]: return job - + self.logger.info(f"Crawl {job_id} status: {job.status}") await asyncio.sleep(poll_interval) - - async def batch_scrape(self, urls: List[str], - formats: List[str] = None) -> List[ScrapedContent]: + + async def batch_scrape(self, urls: List[str], formats: List[str] = None) -> List[ScrapedContent]: """Scrape multiple URLs concurrently.""" if formats is None: formats = ["markdown", "html"] - + tasks = [self.scrape_url(url, formats) for url in urls] results = await asyncio.gather(*tasks, return_exceptions=True) - + # Handle exceptions processed_results = [] for i, result in enumerate(results): if isinstance(result, Exception): - processed_results.append(ScrapedContent( - url=urls[i], - error=str(result) - )) + processed_results.append(ScrapedContent(url=urls[i], error=str(result))) else: processed_results.append(result) - + return processed_results - + def validate_url(self, url: str) -> bool: """Validate if URL is properly formatted.""" try: @@ -253,7 +220,7 @@ class FirecrawlConnector: return all([result.scheme, result.netloc]) except Exception: return False - + def extract_domain(self, url: str) -> str: """Extract domain from URL.""" try: diff --git a/tools/firecrawl/firecrawl_processor.py b/tools/firecrawl/firecrawl_processor.py index c1cbb7ad54..a76a61cd51 100644 --- a/tools/firecrawl/firecrawl_processor.py +++ b/tools/firecrawl/firecrawl_processor.py @@ -15,7 +15,7 @@ from firecrawl_connector import ScrapedContent @dataclass class RAGFlowDocument: """Represents a document in RAGFlow format.""" - + id: str title: str content: str @@ -31,73 +31,73 @@ class RAGFlowDocument: class FirecrawlProcessor: """Processes Firecrawl content for RAGFlow integration.""" - + def __init__(self): """Initialize the processor.""" self.logger = logging.getLogger(__name__) - + def generate_document_id(self, url: str, content: str) -> str: """Generate a unique document ID.""" # Create a hash based on URL and content content_hash = hashlib.md5(f"{url}:{content[:100]}".encode()).hexdigest() return f"firecrawl_{content_hash}" - + def clean_content(self, content: str) -> str: """Clean and normalize content.""" if not content: return "" - + # Remove excessive whitespace - content = re.sub(r'\s+', ' ', content) - + content = re.sub(r"\s+", " ", content) + # Remove HTML tags if present - content = re.sub(r'<[^>]+>', '', content) - + content = re.sub(r"<[^>]+>", "", content) + # Remove special characters that might cause issues - content = re.sub(r'[^\w\s\.\,\!\?\;\:\-\(\)\[\]\"\']', '', content) - + content = re.sub(r"[^\w\s\.\,\!\?\;\:\-\(\)\[\]\"\']", "", content) + return content.strip() - + def extract_title(self, content: ScrapedContent) -> str: """Extract title from scraped content.""" if content.title: return content.title - + if content.metadata and content.metadata.get("title"): return content.metadata["title"] - + # Extract title from markdown if available if content.markdown: - title_match = re.search(r'^#\s+(.+)$', content.markdown, re.MULTILINE) + title_match = re.search(r"^#\s+(.+)$", content.markdown, re.MULTILINE) if title_match: return title_match.group(1).strip() - + # Fallback to URL - return content.url.split('/')[-1] or content.url - + return content.url.split("/")[-1] or content.url + def extract_description(self, content: ScrapedContent) -> str: """Extract description from scraped content.""" if content.description: return content.description - + if content.metadata and content.metadata.get("description"): return content.metadata["description"] - + # Extract first paragraph from markdown if content.markdown: # Remove headers and get first paragraph - text = re.sub(r'^#+\s+.*$', '', content.markdown, flags=re.MULTILINE) - paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()] + text = re.sub(r"^#+\s+.*$", "", content.markdown, flags=re.MULTILINE) + paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()] if paragraphs: return paragraphs[0][:200] + "..." if len(paragraphs[0]) > 200 else paragraphs[0] - + return "" - + def extract_language(self, content: ScrapedContent) -> str: """Extract language from content metadata.""" if content.metadata and content.metadata.get("language"): return content.metadata["language"] - + # Simple language detection based on common words if content.markdown: text = content.markdown.lower() @@ -109,9 +109,9 @@ class FirecrawlProcessor: return "de" elif any(word in text for word in ["el", "la", "los", "las", "de", "del"]): return "es" - + return "en" # Default to English - + def create_metadata(self, content: ScrapedContent) -> Dict[str, Any]: """Create comprehensive metadata for RAGFlow document.""" metadata = { @@ -122,54 +122,57 @@ class FirecrawlProcessor: "status_code": content.status_code, "content_length": len(content.markdown or ""), "has_html": bool(content.html), - "has_markdown": bool(content.markdown) + "has_markdown": bool(content.markdown), } - + # Add original metadata if available if content.metadata: - metadata.update({ - "original_title": content.metadata.get("title"), - "original_description": content.metadata.get("description"), - "original_language": content.metadata.get("language"), - "original_keywords": content.metadata.get("keywords"), - "original_robots": content.metadata.get("robots"), - "og_title": content.metadata.get("ogTitle"), - "og_description": content.metadata.get("ogDescription"), - "og_image": content.metadata.get("ogImage"), - "og_url": content.metadata.get("ogUrl") - }) - + metadata.update( + { + "original_title": content.metadata.get("title"), + "original_description": content.metadata.get("description"), + "original_language": content.metadata.get("language"), + "original_keywords": content.metadata.get("keywords"), + "original_robots": content.metadata.get("robots"), + "og_title": content.metadata.get("ogTitle"), + "og_description": content.metadata.get("ogDescription"), + "og_image": content.metadata.get("ogImage"), + "og_url": content.metadata.get("ogUrl"), + } + ) + return metadata - + def extract_domain(self, url: str) -> str: """Extract domain from URL.""" try: from urllib.parse import urlparse + return urlparse(url).netloc except Exception: return "" - + def process_content(self, content: ScrapedContent) -> RAGFlowDocument: """Process scraped content into RAGFlow document format.""" if content.error: raise ValueError(f"Content has error: {content.error}") - + # Determine primary content primary_content = content.markdown or content.html or "" if not primary_content: raise ValueError("No content available to process") - + # Clean content cleaned_content = self.clean_content(primary_content) - + # Extract metadata title = self.extract_title(content) language = self.extract_language(content) metadata = self.create_metadata(content) - + # Generate document ID doc_id = self.generate_document_id(content.url, cleaned_content) - + # Create RAGFlow document document = RAGFlowDocument( id=doc_id, @@ -180,15 +183,15 @@ class FirecrawlProcessor: created_at=datetime.utcnow(), updated_at=datetime.utcnow(), content_type="text", - language=language + language=language, ) - + return document - + def process_batch(self, contents: List[ScrapedContent]) -> List[RAGFlowDocument]: """Process multiple scraped contents into RAGFlow documents.""" documents = [] - + for content in contents: try: document = self.process_content(content) @@ -196,80 +199,72 @@ class FirecrawlProcessor: except Exception as e: self.logger.error(f"Failed to process content from {content.url}: {e}") continue - + return documents - - def chunk_content(self, document: RAGFlowDocument, - chunk_size: int = 1000, - chunk_overlap: int = 200) -> List[Dict[str, Any]]: + + def chunk_content(self, document: RAGFlowDocument, chunk_size: int = 1000, chunk_overlap: int = 200) -> List[Dict[str, Any]]: """Chunk document content for RAG processing.""" content = document.content chunks = [] - + if len(content) <= chunk_size: - return [{ - "id": f"{document.id}_chunk_0", - "content": content, - "metadata": { - **document.metadata, - "chunk_index": 0, - "total_chunks": 1 - } - }] - + return [{"id": f"{document.id}_chunk_0", "content": content, "metadata": {**document.metadata, "chunk_index": 0, "total_chunks": 1}}] + # Split content into chunks start = 0 chunk_index = 0 - + while start < len(content): end = start + chunk_size - + # Try to break at sentence boundary if end < len(content): # Look for sentence endings - sentence_end = content.rfind('.', start, end) + sentence_end = content.rfind(".", start, end) if sentence_end > start + chunk_size // 2: end = sentence_end + 1 - + chunk_content = content[start:end].strip() - + if chunk_content: - chunks.append({ - "id": f"{document.id}_chunk_{chunk_index}", - "content": chunk_content, - "metadata": { - **document.metadata, - "chunk_index": chunk_index, - "total_chunks": len(chunks) + 1, # Will be updated - "chunk_start": start, - "chunk_end": end + chunks.append( + { + "id": f"{document.id}_chunk_{chunk_index}", + "content": chunk_content, + "metadata": { + **document.metadata, + "chunk_index": chunk_index, + "total_chunks": len(chunks) + 1, # Will be updated + "chunk_start": start, + "chunk_end": end, + }, } - }) + ) chunk_index += 1 - + # Move start position with overlap start = end - chunk_overlap if start >= len(content): break - + # Update total chunks count for chunk in chunks: chunk["metadata"]["total_chunks"] = len(chunks) - + return chunks - + def validate_document(self, document: RAGFlowDocument) -> bool: """Validate RAGFlow document.""" if not document.id: return False - + if not document.title: return False - + if not document.content: return False - + if not document.source_url: return False - + return True diff --git a/tools/firecrawl/firecrawl_ui.py b/tools/firecrawl/firecrawl_ui.py index 0660a1e4ff..3331c7bf2d 100644 --- a/tools/firecrawl/firecrawl_ui.py +++ b/tools/firecrawl/firecrawl_ui.py @@ -9,15 +9,15 @@ from dataclasses import dataclass @dataclass class FirecrawlUIComponent: """Represents a UI component for Firecrawl integration.""" - + component_type: str props: Dict[str, Any] - children: Optional[List['FirecrawlUIComponent']] = None + children: Optional[List["FirecrawlUIComponent"]] = None class FirecrawlUIBuilder: """Builder for Firecrawl UI components in RAGFlow.""" - + @staticmethod def create_data_source_config() -> Dict[str, Any]: """Create configuration for Firecrawl data source.""" @@ -32,49 +32,16 @@ class FirecrawlUIBuilder: "config_schema": { "type": "object", "properties": { - "api_key": { - "type": "string", - "title": "Firecrawl API Key", - "description": "Your Firecrawl API key (starts with 'fc-')", - "format": "password", - "required": True - }, - "api_url": { - "type": "string", - "title": "API URL", - "description": "Firecrawl API endpoint", - "default": "https://api.firecrawl.dev", - "required": False - }, - "max_retries": { - "type": "integer", - "title": "Max Retries", - "description": "Maximum number of retry attempts", - "default": 3, - "minimum": 1, - "maximum": 10 - }, - "timeout": { - "type": "integer", - "title": "Timeout (seconds)", - "description": "Request timeout in seconds", - "default": 30, - "minimum": 5, - "maximum": 300 - }, - "rate_limit_delay": { - "type": "number", - "title": "Rate Limit Delay", - "description": "Delay between requests in seconds", - "default": 1.0, - "minimum": 0.1, - "maximum": 10.0 - } + "api_key": {"type": "string", "title": "Firecrawl API Key", "description": "Your Firecrawl API key (starts with 'fc-')", "format": "password", "required": True}, + "api_url": {"type": "string", "title": "API URL", "description": "Firecrawl API endpoint", "default": "https://api.firecrawl.dev", "required": False}, + "max_retries": {"type": "integer", "title": "Max Retries", "description": "Maximum number of retry attempts", "default": 3, "minimum": 1, "maximum": 10}, + "timeout": {"type": "integer", "title": "Timeout (seconds)", "description": "Request timeout in seconds", "default": 30, "minimum": 5, "maximum": 300}, + "rate_limit_delay": {"type": "number", "title": "Rate Limit Delay", "description": "Delay between requests in seconds", "default": 1.0, "minimum": 0.1, "maximum": 10.0}, }, - "required": ["api_key"] - } + "required": ["api_key"], + }, } - + @staticmethod def create_scraping_form() -> Dict[str, Any]: """Create form for scraping configuration.""" @@ -88,12 +55,9 @@ class FirecrawlUIBuilder: "type": "array", "title": "URLs to Scrape", "description": "Enter URLs to scrape (one per line)", - "items": { - "type": "string", - "format": "uri" - }, + "items": {"type": "string", "format": "uri"}, "required": True, - "minItems": 1 + "minItems": 1, }, { "name": "scrape_type", @@ -103,19 +67,16 @@ class FirecrawlUIBuilder: "enum": ["single", "crawl", "batch"], "enumNames": ["Single URL", "Crawl Website", "Batch URLs"], "default": "single", - "required": True + "required": True, }, { "name": "formats", "type": "array", "title": "Output Formats", "description": "Select output formats", - "items": { - "type": "string", - "enum": ["markdown", "html", "links", "screenshot"] - }, + "items": {"type": "string", "enum": ["markdown", "html", "links", "screenshot"]}, "default": ["markdown", "html"], - "required": True + "required": True, }, { "name": "crawl_limit", @@ -125,10 +86,7 @@ class FirecrawlUIBuilder: "default": 100, "minimum": 1, "maximum": 1000, - "condition": { - "field": "scrape_type", - "equals": "crawl" - } + "condition": {"field": "scrape_type", "equals": "crawl"}, }, { "name": "extract_options", @@ -136,30 +94,20 @@ class FirecrawlUIBuilder: "title": "Extraction Options", "description": "Advanced extraction settings", "properties": { - "extractMainContent": { - "type": "boolean", - "title": "Extract Main Content Only", - "default": True - }, - "excludeTags": { - "type": "array", - "title": "Exclude Tags", - "description": "HTML tags to exclude", - "items": {"type": "string"}, - "default": ["nav", "footer", "header", "aside"] - }, + "extractMainContent": {"type": "boolean", "title": "Extract Main Content Only", "default": True}, + "excludeTags": {"type": "array", "title": "Exclude Tags", "description": "HTML tags to exclude", "items": {"type": "string"}, "default": ["nav", "footer", "header", "aside"]}, "includeTags": { "type": "array", "title": "Include Tags", "description": "HTML tags to include", "items": {"type": "string"}, - "default": ["main", "article", "section", "div", "p"] - } - } - } - ] + "default": ["main", "article", "section", "div", "p"], + }, + }, + }, + ], } - + @staticmethod def create_progress_component() -> Dict[str, Any]: """Create progress tracking component.""" @@ -167,13 +115,9 @@ class FirecrawlUIBuilder: "type": "progress", "title": "Scraping Progress", "description": "Track the progress of your web scraping job", - "properties": { - "show_percentage": True, - "show_eta": True, - "show_details": True - } + "properties": {"show_percentage": True, "show_eta": True, "show_details": True}, } - + @staticmethod def create_results_view() -> Dict[str, Any]: """Create results display component.""" @@ -181,14 +125,9 @@ class FirecrawlUIBuilder: "type": "results", "title": "Scraping Results", "description": "View and manage scraped content", - "properties": { - "show_preview": True, - "show_metadata": True, - "allow_editing": True, - "show_chunks": True - } + "properties": {"show_preview": True, "show_metadata": True, "allow_editing": True, "show_chunks": True}, } - + @staticmethod def create_error_handler() -> Dict[str, Any]: """Create error handling component.""" @@ -196,33 +135,18 @@ class FirecrawlUIBuilder: "type": "error_handler", "title": "Error Handling", "description": "Handle scraping errors and retries", - "properties": { - "show_retry_button": True, - "show_error_details": True, - "auto_retry": False, - "max_retries": 3 - } + "properties": {"show_retry_button": True, "show_error_details": True, "auto_retry": False, "max_retries": 3}, } - + @staticmethod def create_validation_rules() -> Dict[str, Any]: """Create validation rules for Firecrawl integration.""" return { - "url_validation": { - "pattern": r"^https?://.+", - "message": "URL must start with http:// or https://" - }, - "api_key_validation": { - "pattern": r"^fc-[a-zA-Z0-9]+$", - "message": "API key must start with 'fc-' followed by alphanumeric characters" - }, - "rate_limit_validation": { - "min": 0.1, - "max": 10.0, - "message": "Rate limit delay must be between 0.1 and 10.0 seconds" - } + "url_validation": {"pattern": r"^https?://.+", "message": "URL must start with http:// or https://"}, + "api_key_validation": {"pattern": r"^fc-[a-zA-Z0-9]+$", "message": "API key must start with 'fc-' followed by alphanumeric characters"}, + "rate_limit_validation": {"min": 0.1, "max": 10.0, "message": "Rate limit delay must be between 0.1 and 10.0 seconds"}, } - + @staticmethod def create_help_text() -> Dict[str, str]: """Create help text for users.""" @@ -231,9 +155,9 @@ class FirecrawlUIBuilder: "url_help": "Enter the URLs you want to scrape. You can add multiple URLs for batch processing.", "crawl_help": "Crawling will follow links from the starting URL and scrape all accessible pages within the limit.", "formats_help": "Choose the output formats you need. Markdown is recommended for RAG processing.", - "extract_help": "Extraction options help filter content to get only the main content without navigation and ads." + "extract_help": "Extraction options help filter content to get only the main content without navigation and ads.", } - + @staticmethod def create_ui_schema() -> Dict[str, Any]: """Create complete UI schema for Firecrawl integration.""" @@ -244,16 +168,9 @@ class FirecrawlUIBuilder: "scraping_form": FirecrawlUIBuilder.create_scraping_form(), "progress_component": FirecrawlUIBuilder.create_progress_component(), "results_view": FirecrawlUIBuilder.create_results_view(), - "error_handler": FirecrawlUIBuilder.create_error_handler() + "error_handler": FirecrawlUIBuilder.create_error_handler(), }, "validation_rules": FirecrawlUIBuilder.create_validation_rules(), "help_text": FirecrawlUIBuilder.create_help_text(), - "workflow": [ - "configure_data_source", - "setup_scraping_parameters", - "start_scraping_job", - "monitor_progress", - "review_results", - "import_to_ragflow" - ] + "workflow": ["configure_data_source", "setup_scraping_parameters", "start_scraping_job", "monitor_progress", "review_results", "import_to_ragflow"], } diff --git a/tools/firecrawl/integration.py b/tools/firecrawl/integration.py index b4fbf6cede..f85ba89f99 100644 --- a/tools/firecrawl/integration.py +++ b/tools/firecrawl/integration.py @@ -20,7 +20,7 @@ class FirecrawlRAGFlowPlugin: Main plugin class for Firecrawl integration with RAGFlow. This class provides the interface that RAGFlow expects from integrations. """ - + def __init__(self): """Initialize the Firecrawl plugin.""" self.name = "firecrawl" @@ -30,9 +30,9 @@ class FirecrawlRAGFlowPlugin: self.author = "Firecrawl Team" self.category = "web" self.icon = "🌐" - + logger.info(f"Initialized {self.display_name} plugin v{self.version}") - + def get_plugin_info(self) -> Dict[str, Any]: """Get plugin information for RAGFlow.""" return { @@ -44,17 +44,17 @@ class FirecrawlRAGFlowPlugin: "category": self.category, "icon": self.icon, "supported_formats": ["markdown", "html", "links", "screenshot"], - "supported_scrape_types": ["single", "crawl", "batch"] + "supported_scrape_types": ["single", "crawl", "batch"], } - + def get_config_schema(self) -> Dict[str, Any]: """Get configuration schema for RAGFlow.""" return FirecrawlUIBuilder.create_data_source_config()["config_schema"] - + def get_ui_schema(self) -> Dict[str, Any]: """Get UI schema for RAGFlow.""" return FirecrawlUIBuilder.create_ui_schema() - + def validate_config(self, config: Dict[str, Any]) -> Dict[str, Any]: """Validate configuration and return any errors.""" try: @@ -63,30 +63,27 @@ class FirecrawlRAGFlowPlugin: except Exception as e: logger.error(f"Configuration validation error: {e}") return {"general": str(e)} - + def test_connection(self, config: Dict[str, Any]) -> Dict[str, Any]: """Test connection to Firecrawl API.""" try: integration = create_firecrawl_integration(config) # Run the async test_connection method import asyncio + return asyncio.run(integration.test_connection()) except Exception as e: logger.error(f"Connection test error: {e}") - return { - "success": False, - "error": str(e), - "message": "Connection test failed" - } - + return {"success": False, "error": str(e), "message": "Connection test failed"} + def create_integration(self, config: Dict[str, Any]) -> RAGFlowFirecrawlIntegration: """Create and return a Firecrawl integration instance.""" return create_firecrawl_integration(config) - + def get_help_text(self) -> Dict[str, str]: """Get help text for users.""" return FirecrawlUIBuilder.create_help_text() - + def get_validation_rules(self) -> Dict[str, Any]: """Get validation rules for configuration.""" return FirecrawlUIBuilder.create_validation_rules() @@ -128,11 +125,7 @@ def test_connection(config: Dict[str, Any]) -> Dict[str, Any]: integration = create_firecrawl_integration(config) return integration.test_connection() except Exception as e: - return { - "success": False, - "error": str(e), - "message": "Connection test failed" - } + return {"success": False, "error": str(e), "message": "Connection test failed"} # Export main functions and classes @@ -145,5 +138,5 @@ __all__ = [ "validate_config", "test_connection", "RAGFlowFirecrawlIntegration", - "create_firecrawl_integration" + "create_firecrawl_integration", ] diff --git a/tools/firecrawl/ragflow_integration.py b/tools/firecrawl/ragflow_integration.py index 2d0bfe4b79..60c9d01fd5 100644 --- a/tools/firecrawl/ragflow_integration.py +++ b/tools/firecrawl/ragflow_integration.py @@ -14,75 +14,71 @@ from firecrawl_ui import FirecrawlUIBuilder class RAGFlowFirecrawlIntegration: """Main integration class for Firecrawl with RAGFlow.""" - + def __init__(self, config: FirecrawlConfig): """Initialize the integration.""" self.config = config self.connector = FirecrawlConnector(config) self.processor = FirecrawlProcessor() self.logger = logging.getLogger(__name__) - - async def scrape_and_import(self, urls: List[str], - formats: List[str] = None, - extract_options: Dict[str, Any] = None) -> List[RAGFlowDocument]: + + async def scrape_and_import(self, urls: List[str], formats: List[str] = None, extract_options: Dict[str, Any] = None) -> List[RAGFlowDocument]: """Scrape URLs and convert to RAGFlow documents.""" if formats is None: formats = ["markdown", "html"] - + async with self.connector: # Scrape URLs scraped_contents = await self.connector.batch_scrape(urls, formats) - + # Process into RAGFlow documents documents = self.processor.process_batch(scraped_contents) - + return documents - - async def crawl_and_import(self, start_url: str, - limit: int = 100, - scrape_options: Dict[str, Any] = None) -> List[RAGFlowDocument]: + + async def crawl_and_import(self, start_url: str, limit: int = 100, scrape_options: Dict[str, Any] = None) -> List[RAGFlowDocument]: """Crawl a website and convert to RAGFlow documents.""" if scrape_options is None: scrape_options = {"formats": ["markdown", "html"]} - + async with self.connector: # Start crawl job crawl_job = await self.connector.start_crawl(start_url, limit, scrape_options) - + if crawl_job.error: raise Exception(f"Failed to start crawl: {crawl_job.error}") - + # Wait for completion completed_job = await self.connector.wait_for_crawl_completion(crawl_job.job_id) - + if completed_job.error: raise Exception(f"Crawl failed: {completed_job.error}") - + # Process into RAGFlow documents documents = self.processor.process_batch(completed_job.data or []) - + return documents - + def get_ui_schema(self) -> Dict[str, Any]: """Get UI schema for RAGFlow integration.""" return FirecrawlUIBuilder.create_ui_schema() - + def validate_config(self, config_dict: Dict[str, Any]) -> Dict[str, Any]: """Validate configuration and return any errors.""" errors = {} - + # Validate API key api_key = config_dict.get("api_key", "") if not api_key: errors["api_key"] = "API key is required" elif not api_key.startswith("fc-"): errors["api_key"] = "API key must start with 'fc-'" - + # Validate API URL api_url = config_dict.get("api_url", "https://api.firecrawl.dev") if not api_url.startswith("http"): errors["api_url"] = "API URL must start with http:// or https://" - + # Validate numeric fields try: max_retries = int(config_dict.get("max_retries", 3)) @@ -90,27 +86,27 @@ class RAGFlowFirecrawlIntegration: errors["max_retries"] = "Max retries must be between 1 and 10" except (ValueError, TypeError): errors["max_retries"] = "Max retries must be a valid integer" - + try: timeout = int(config_dict.get("timeout", 30)) if timeout < 5 or timeout > 300: errors["timeout"] = "Timeout must be between 5 and 300 seconds" except (ValueError, TypeError): errors["timeout"] = "Timeout must be a valid integer" - + try: rate_limit_delay = float(config_dict.get("rate_limit_delay", 1.0)) if rate_limit_delay < 0.1 or rate_limit_delay > 10.0: errors["rate_limit_delay"] = "Rate limit delay must be between 0.1 and 10.0 seconds" except (ValueError, TypeError): errors["rate_limit_delay"] = "Rate limit delay must be a valid number" - + return errors - + def create_config(self, config_dict: Dict[str, Any]) -> FirecrawlConfig: """Create FirecrawlConfig from dictionary.""" return FirecrawlConfig.from_dict(config_dict) - + async def test_connection(self) -> Dict[str, Any]: """Test the connection to Firecrawl API.""" try: @@ -118,40 +114,32 @@ class RAGFlowFirecrawlIntegration: # Try to scrape a simple URL to test connection test_url = "https://httpbin.org/json" result = await self.connector.scrape_url(test_url, ["markdown"]) - + if result.error: - return { - "success": False, - "error": result.error, - "message": "Failed to connect to Firecrawl API" - } - + return {"success": False, "error": result.error, "message": "Failed to connect to Firecrawl API"} + return { "success": True, "message": "Successfully connected to Firecrawl API", "test_url": test_url, - "response_time": "N/A" # Could be enhanced to measure actual response time + "response_time": "N/A", # Could be enhanced to measure actual response time } - + except Exception as e: - return { - "success": False, - "error": str(e), - "message": "Connection test failed" - } - + return {"success": False, "error": str(e), "message": "Connection test failed"} + def get_supported_formats(self) -> List[str]: """Get list of supported output formats.""" return ["markdown", "html", "links", "screenshot"] - + def get_supported_scrape_types(self) -> List[str]: """Get list of supported scrape types.""" return ["single", "crawl", "batch"] - + def get_help_text(self) -> Dict[str, str]: """Get help text for users.""" return FirecrawlUIBuilder.create_help_text() - + def get_validation_rules(self) -> Dict[str, Any]: """Get validation rules for configuration.""" return FirecrawlUIBuilder.create_validation_rules() @@ -165,11 +153,4 @@ def create_firecrawl_integration(config_dict: Dict[str, Any]) -> RAGFlowFirecraw # Export main classes and functions -__all__ = [ - "RAGFlowFirecrawlIntegration", - "create_firecrawl_integration", - "FirecrawlConfig", - "FirecrawlConnector", - "FirecrawlProcessor", - "RAGFlowDocument" -] +__all__ = ["RAGFlowFirecrawlIntegration", "create_firecrawl_integration", "FirecrawlConfig", "FirecrawlConnector", "FirecrawlProcessor", "RAGFlowDocument"] diff --git a/tools/scripts/db_schema_sync.py b/tools/scripts/db_schema_sync.py index bfd9b8ba94..0dad101cda 100644 --- a/tools/scripts/db_schema_sync.py +++ b/tools/scripts/db_schema_sync.py @@ -41,103 +41,94 @@ PROJECT_BASE = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(_ sys.path.insert(0, PROJECT_BASE) # Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s' -) +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) def validate_version(version: str) -> bool: """Validate version format: vxx.xx.xx where xx are digits""" - pattern = r'^v\d+\.\d+\.\d+$' + pattern = r"^v\d+\.\d+\.\d+$" return bool(re.match(pattern, version)) def version_to_dirname(version: str) -> str: """Convert version string to valid directory name (e.g., 'v0.26.3' -> 'v0_26_3')""" - return version.replace('.', '_') + return version.replace(".", "_") def load_db_models(): """Load database models from api/db/db_models.py""" - models_path = os.path.join(PROJECT_BASE, 'api', 'db', 'db_models.py') - + models_path = os.path.join(PROJECT_BASE, "api", "db", "db_models.py") + if not os.path.exists(models_path): raise FileNotFoundError(f"db_models.py not found at {models_path}") - + # Import the module spec = importlib.util.spec_from_file_location("db_models", models_path) db_models = importlib.util.module_from_spec(spec) spec.loader.exec_module(db_models) - + # Get all Model subclasses models = [] for name, obj in inspect.getmembers(db_models): if inspect.isclass(obj) and issubclass(obj, Model) and obj is not Model: # Skip base model classes - if obj.__name__ in ['BaseModel', 'DataBaseModel']: + if obj.__name__ in ["BaseModel", "DataBaseModel"]: continue # Check if it has a database attribute (is a proper model) - if hasattr(obj._meta, 'database'): + if hasattr(obj._meta, "database"): models.append(obj) - + return models, db_models def create_database_connection(host: str, port: int, user: str, password: str, database: str): """Create MySQL database connection from command line arguments""" - db = MySQLDatabase( - database, - host=host, - port=port, - user=user, - password=password, - charset='utf8mb4' - ) + db = MySQLDatabase(database, host=host, port=port, user=user, password=password, charset="utf8mb4") return db # MySQL type to Peewee field type mapping MYSQL_TO_PEEWEE_TYPE = { - 'varchar': 'CharField', - 'char': 'CharField', - 'text': 'TextField', - 'longtext': 'TextField', - 'mediumtext': 'TextField', - 'int': 'IntegerField', - 'integer': 'IntegerField', - 'bigint': 'BigIntegerField', - 'float': 'FloatField', - 'double': 'FloatField', - 'decimal': 'FloatField', - 'datetime': 'DateTimeField', - 'timestamp': 'DateTimeField', - 'tinyint(1)': 'BooleanField', - 'tinyint': 'IntegerField', - 'smallint': 'IntegerField', - 'mediumint': 'IntegerField', + "varchar": "CharField", + "char": "CharField", + "text": "TextField", + "longtext": "TextField", + "mediumtext": "TextField", + "int": "IntegerField", + "integer": "IntegerField", + "bigint": "BigIntegerField", + "float": "FloatField", + "double": "FloatField", + "decimal": "FloatField", + "datetime": "DateTimeField", + "timestamp": "DateTimeField", + "tinyint(1)": "BooleanField", + "tinyint": "IntegerField", + "smallint": "IntegerField", + "mediumint": "IntegerField", } PEEWEE_TO_MYSQL_TYPE = { - 'CharField': 'varchar', - 'TextField': 'text', - 'IntegerField': 'int', - 'BigIntegerField': 'bigint', - 'FloatField': 'float', - 'BooleanField': 'tinyint', - 'DateTimeField': 'datetime', + "CharField": "varchar", + "TextField": "text", + "IntegerField": "int", + "BigIntegerField": "bigint", + "FloatField": "float", + "BooleanField": "tinyint", + "DateTimeField": "datetime", } def get_table_columns(db, table_name: str) -> dict: """Get column information from database table - + Returns: dict: {column_name: {type, nullable, default, ...}} """ - cursor = db.execute_sql(""" - SELECT + cursor = db.execute_sql( + """ + SELECT column_name, data_type, column_type, @@ -148,34 +139,36 @@ def get_table_columns(db, table_name: str) -> dict: FROM information_schema.columns WHERE table_schema = %s AND table_name = %s ORDER BY ordinal_position - """, (db.database, table_name)) - + """, + (db.database, table_name), + ) + columns = {} for row in cursor.fetchall(): col_name = row[0] data_type = row[1].lower() column_type = row[2].lower() - is_nullable = row[3] == 'YES' + is_nullable = row[3] == "YES" column_default = row[4] column_key = row[5] - extra = row[6] or '' - + extra = row[6] or "" + # Determine peewee type - if column_type.startswith('tinyint(1)'): - peewee_type = 'BooleanField' + if column_type.startswith("tinyint(1)"): + peewee_type = "BooleanField" else: - peewee_type = MYSQL_TO_PEEWEE_TYPE.get(data_type, 'TextField') - + peewee_type = MYSQL_TO_PEEWEE_TYPE.get(data_type, "TextField") + columns[col_name] = { - 'data_type': data_type, - 'column_type': column_type, - 'peewee_type': peewee_type, - 'nullable': is_nullable, - 'default': column_default, - 'is_primary': column_key == 'PRI', - 'extra': extra, + "data_type": data_type, + "column_type": column_type, + "peewee_type": peewee_type, + "nullable": is_nullable, + "default": column_default, + "is_primary": column_key == "PRI", + "extra": extra, } - + return columns @@ -187,26 +180,36 @@ def get_peewee_field_type(field: Field) -> str: def get_base_field_type(field: Field) -> str: """Get base peewee field type by walking the MRO chain. - + Custom field types (like DateTimeTzField, JSONField) inherit from standard types. This function returns the underlying standard type for comparison. """ # Standard peewee field types we consider as "base" types STANDARD_TYPES = { - 'CharField', 'TextField', 'IntegerField', 'BigIntegerField', - 'FloatField', 'BooleanField', 'DateTimeField', 'DateField', - 'TimeField', 'DecimalField', 'ForeignKeyField', 'ManyToManyField', - 'PrimaryKeyField', 'AutoField' + "CharField", + "TextField", + "IntegerField", + "BigIntegerField", + "FloatField", + "BooleanField", + "DateTimeField", + "DateField", + "TimeField", + "DecimalField", + "ForeignKeyField", + "ManyToManyField", + "PrimaryKeyField", + "AutoField", } - + # Walk through the MRO (Method Resolution Order) to find standard type for cls in field.__class__.__mro__: class_name = cls.__name__ if class_name in STANDARD_TYPES: return class_name - + # Fallback to TextField if no standard type found - return 'TextField' + return "TextField" def normalize_field_type(field: Field) -> str: @@ -216,7 +219,7 @@ def normalize_field_type(field: Field) -> str: def compare_fields(model_fields: dict, db_columns: dict) -> dict: """Compare model fields with database columns - + Returns: dict: { 'added': {field_name: field_obj}, # New fields not in DB @@ -225,72 +228,72 @@ def compare_fields(model_fields: dict, db_columns: dict) -> dict: } """ result = { - 'added': {}, - 'changed': {}, - 'removed': {}, + "added": {}, + "changed": {}, + "removed": {}, } - + # Skip auto-generated fields like id, create_time, etc. - skip_fields = {'id'} - + skip_fields = {"id"} + for field_name, field in model_fields.items(): if field_name in skip_fields: continue - + # Check if field exists in database if field_name not in db_columns: - result['added'][field_name] = field + result["added"][field_name] = field logger.info(f" New field detected: {field_name} ({field.__class__.__name__})") else: # Check if type changed db_col = db_columns[field_name] model_base_type = normalize_field_type(field) - db_type = db_col['peewee_type'] - + db_type = db_col["peewee_type"] + # Type mismatch if model_base_type != db_type: - result['changed'][field_name] = (db_col, field) + result["changed"][field_name] = (db_col, field) logger.info(f" Field type changed: {field_name} ({db_type} -> {model_base_type}, actual: {field.__class__.__name__})") - + # Detect removed fields: columns in DB but not in model for col_name, col_info in db_columns.items(): if col_name in skip_fields: continue if col_name not in model_fields: - result['removed'][col_name] = col_info + result["removed"][col_name] = col_info logger.info(f" Removed field detected: {col_name} ({col_info['column_type']})") - + return result def generate_field_code(field: Field, field_name: str) -> str: """Generate peewee field definition code""" field_class = field.__class__.__name__ - + # Map custom field types to standard peewee types for migration # These custom types will be stored as their underlying standard type custom_to_standard = { - 'LongTextField': 'TextField', - 'JSONField': 'TextField', - 'ListField': 'TextField', - 'SerializedField': 'TextField', - 'DateTimeTzField': 'CharField', + "LongTextField": "TextField", + "JSONField": "TextField", + "ListField": "TextField", + "SerializedField": "TextField", + "DateTimeTzField": "CharField", } - + # Use standard type for custom fields pw_field_class = custom_to_standard.get(field_class, field_class) - + # Build field arguments args = [] - + # max_length for CharField - if pw_field_class == 'CharField' and hasattr(field, 'max_length') and field.max_length is not None: + if pw_field_class == "CharField" and hasattr(field, "max_length") and field.max_length is not None: args.append(f"max_length={field.max_length}") - + # null if field.null: args.append("null=True") - + # default if field.default is not None: default_val = field.default @@ -306,54 +309,54 @@ def generate_field_code(field: Field, field_name: str) -> str: args.append(f"default={default_val}") elif isinstance(default_val, list): args.append(f"default={default_val}") - + # index - if getattr(field, 'index', False): + if getattr(field, "index", False): args.append("index=True") - + # unique - if getattr(field, 'unique', False): + if getattr(field, "unique", False): args.append("unique=True") - - args_str = ', '.join(args) + + args_str = ", ".join(args) return f"pw.{pw_field_class}({args_str})" def generate_add_field_sql(table_name: str, field: Field, field_name: str) -> str: """Generate raw SQL for adding a field to MySQL table. - + This is used for existing tables where migrator.add_fields doesn't work because the model is not registered in migrator.orm. """ field_class = field.__class__.__name__ - + # Determine MySQL column type mysql_type_map = { - 'CharField': f'VARCHAR({field.max_length})' if hasattr(field, 'max_length') and field.max_length else 'VARCHAR(255)', - 'TextField': 'LONGTEXT', - 'LongTextField': 'LONGTEXT', - 'JSONField': 'LONGTEXT', - 'ListField': 'LONGTEXT', - 'SerializedField': 'LONGTEXT', - 'IntegerField': 'INT', - 'BigIntegerField': 'BIGINT', - 'FloatField': 'DOUBLE', - 'BooleanField': 'TINYINT(1)', - 'DateTimeField': 'DATETIME', - 'DateTimeTzField': f'VARCHAR({field.max_length})' if hasattr(field, 'max_length') and field.max_length else 'VARCHAR(255)', + "CharField": f"VARCHAR({field.max_length})" if hasattr(field, "max_length") and field.max_length else "VARCHAR(255)", + "TextField": "LONGTEXT", + "LongTextField": "LONGTEXT", + "JSONField": "LONGTEXT", + "ListField": "LONGTEXT", + "SerializedField": "LONGTEXT", + "IntegerField": "INT", + "BigIntegerField": "BIGINT", + "FloatField": "DOUBLE", + "BooleanField": "TINYINT(1)", + "DateTimeField": "DATETIME", + "DateTimeTzField": f"VARCHAR({field.max_length})" if hasattr(field, "max_length") and field.max_length else "VARCHAR(255)", } - - mysql_type = mysql_type_map.get(field_class, 'LONGTEXT') - + + mysql_type = mysql_type_map.get(field_class, "LONGTEXT") + # Build column definition - parts = [f'`{field_name}`', mysql_type] - + parts = [f"`{field_name}`", mysql_type] + # NULL/NOT NULL if field.null: - parts.append('NULL') + parts.append("NULL") else: - parts.append('NOT NULL') - + parts.append("NOT NULL") + # DEFAULT if field.default is not None: default_val = field.default @@ -366,21 +369,22 @@ def generate_add_field_sql(table_name: str, field: Field, field_name: str) -> st parts.append(f"DEFAULT {default_val}") elif isinstance(default_val, dict) or isinstance(default_val, list): import json + escaped = json.dumps(default_val).replace("'", "''") parts.append(f"DEFAULT '{escaped}'") - + # COMMENT - if hasattr(field, 'help_text') and field.help_text: + if hasattr(field, "help_text") and field.help_text: escaped = field.help_text.replace("'", "''") parts.append(f"COMMENT '{escaped}'") - + sql = f"ALTER TABLE `{table_name}` ADD COLUMN {' '.join(parts)}" - + # Add index if needed index_sql = None - if getattr(field, 'index', False): + if getattr(field, "index", False): index_sql = f"CREATE INDEX `idx_{table_name}_{field_name}` ON `{table_name}` (`{field_name}`)" - + return sql, index_sql @@ -396,22 +400,22 @@ def generate_rollback_field_sql(table_name: str, field_name: str) -> str: def generate_rollback_add_field_sql(table_name: str, col_info: dict, field_name: str) -> str: """Generate SQL for rolling back a dropped field (re-adding it). - + This reconstructs the ADD COLUMN statement from the column info that was captured before the field was dropped. """ - mysql_type = col_info.get('column_type', 'LONGTEXT') - - parts = [f'`{field_name}`', mysql_type] - + mysql_type = col_info.get("column_type", "LONGTEXT") + + parts = [f"`{field_name}`", mysql_type] + # NULL/NOT NULL - if col_info.get('nullable', True): - parts.append('NULL') + if col_info.get("nullable", True): + parts.append("NULL") else: - parts.append('NOT NULL') - + parts.append("NOT NULL") + # DEFAULT - default_val = col_info.get('default') + default_val = col_info.get("default") if default_val is not None: if isinstance(default_val, str): escaped = default_val.replace("'", "''") @@ -420,38 +424,38 @@ def generate_rollback_add_field_sql(table_name: str, col_info: dict, field_name: parts.append(f"DEFAULT {1 if default_val else 0}") elif isinstance(default_val, (int, float)): parts.append(f"DEFAULT {default_val}") - + sql = f"ALTER TABLE `{table_name}` ADD COLUMN {' '.join(parts)}" - + # Re-add index if it was a non-primary key index_sql = None - if col_info.get('column_key') == 'MUL': + if col_info.get("column_key") == "MUL": index_sql = f"CREATE INDEX `idx_{table_name}_{field_name}` ON `{table_name}` (`{field_name}`)" - + return sql, index_sql def generate_rollback_modify_sql(table_name: str, old_info: dict, field_name: str) -> str: """Generate SQL for rolling back a field type change. - + Note: This restores the column type, but data values may need manual handling if the type conversion caused data loss or transformation. """ # Reconstruct MySQL type from old_info - mysql_type = old_info.get('column_type', 'LONGTEXT') - + mysql_type = old_info.get("column_type", "LONGTEXT") + # Build column definition - parts = [f'`{field_name}`', mysql_type] - + parts = [f"`{field_name}`", mysql_type] + # NULL/NOT NULL - if old_info.get('nullable', True): - parts.append('NULL') + if old_info.get("nullable", True): + parts.append("NULL") else: - parts.append('NOT NULL') - + parts.append("NOT NULL") + # DEFAULT (if available) - if old_info.get('default') is not None: - default_val = old_info['default'] + if old_info.get("default") is not None: + default_val = old_info["default"] if isinstance(default_val, str): escaped = default_val.replace("'", "''") parts.append(f"DEFAULT '{escaped}'") @@ -459,41 +463,41 @@ def generate_rollback_modify_sql(table_name: str, old_info: dict, field_name: st parts.append(f"DEFAULT {1 if default_val else 0}") elif isinstance(default_val, (int, float)): parts.append(f"DEFAULT {default_val}") - + return f"ALTER TABLE `{table_name}` MODIFY COLUMN {' '.join(parts)}" def generate_modify_field_sql(table_name: str, field: Field, field_name: str) -> str: """Generate SQL for modifying a field in MySQL table.""" field_class = field.__class__.__name__ - + # Determine MySQL column type mysql_type_map = { - 'CharField': f'VARCHAR({field.max_length})' if hasattr(field, 'max_length') and field.max_length else 'VARCHAR(255)', - 'TextField': 'LONGTEXT', - 'LongTextField': 'LONGTEXT', - 'JSONField': 'LONGTEXT', - 'ListField': 'LONGTEXT', - 'SerializedField': 'LONGTEXT', - 'IntegerField': 'INT', - 'BigIntegerField': 'BIGINT', - 'FloatField': 'DOUBLE', - 'BooleanField': 'TINYINT(1)', - 'DateTimeField': 'DATETIME', - 'DateTimeTzField': f'VARCHAR({field.max_length})' if hasattr(field, 'max_length') and field.max_length else 'VARCHAR(255)', + "CharField": f"VARCHAR({field.max_length})" if hasattr(field, "max_length") and field.max_length else "VARCHAR(255)", + "TextField": "LONGTEXT", + "LongTextField": "LONGTEXT", + "JSONField": "LONGTEXT", + "ListField": "LONGTEXT", + "SerializedField": "LONGTEXT", + "IntegerField": "INT", + "BigIntegerField": "BIGINT", + "FloatField": "DOUBLE", + "BooleanField": "TINYINT(1)", + "DateTimeField": "DATETIME", + "DateTimeTzField": f"VARCHAR({field.max_length})" if hasattr(field, "max_length") and field.max_length else "VARCHAR(255)", } - - mysql_type = mysql_type_map.get(field_class, 'LONGTEXT') - + + mysql_type = mysql_type_map.get(field_class, "LONGTEXT") + # Build column definition - parts = [f'`{field_name}`', mysql_type] - + parts = [f"`{field_name}`", mysql_type] + # NULL/NOT NULL if field.null: - parts.append('NULL') + parts.append("NULL") else: - parts.append('NOT NULL') - + parts.append("NOT NULL") + # DEFAULT if field.default is not None: default_val = field.default @@ -506,14 +510,15 @@ def generate_modify_field_sql(table_name: str, field: Field, field_name: str) -> parts.append(f"DEFAULT {default_val}") elif isinstance(default_val, dict) or isinstance(default_val, list): import json + escaped = json.dumps(default_val).replace("'", "''") parts.append(f"DEFAULT '{escaped}'") - + # COMMENT - if hasattr(field, 'help_text') and field.help_text: + if hasattr(field, "help_text") and field.help_text: escaped = field.help_text.replace("'", "''") parts.append(f"COMMENT '{escaped}'") - + return f"ALTER TABLE `{table_name}` MODIFY COLUMN {' '.join(parts)}" @@ -521,126 +526,126 @@ def generate_migration_content(new_tables: list, field_changes: dict, migrate_di """Generate migration file content""" lines = [ '"""Peewee migrations."""', - '', - 'from contextlib import suppress', - '', - 'import peewee as pw', - 'from peewee_migrate import Migrator', - '', - '', - 'with suppress(ImportError):', - ' import playhouse.postgres_ext as pw_pext', - '', - '', - 'def migrate(migrator: Migrator, database: pw.Database, *, fake=False):', + "", + "from contextlib import suppress", + "", + "import peewee as pw", + "from peewee_migrate import Migrator", + "", + "", + "with suppress(ImportError):", + " import playhouse.postgres_ext as pw_pext", + "", + "", + "def migrate(migrator: Migrator, database: pw.Database, *, fake=False):", ' """Write your migrations here."""', - '', + "", ] - + # Generate create_model for new tables for model in new_tables: table_name = model._meta.table_name model_name = model.__name__ - - lines.append(' @migrator.create_model') - lines.append(f' class {model_name}(pw.Model):') - + + lines.append(" @migrator.create_model") + lines.append(f" class {model_name}(pw.Model):") + # Get all fields fields = model._meta.fields for field_name, field in fields.items(): field_code = generate_field_code(field, field_name) - lines.append(f' {field_name} = {field_code}') - - lines.append('') - lines.append(' class Meta:') + lines.append(f" {field_name} = {field_code}") + + lines.append("") + lines.append(" class Meta:") lines.append(f' table_name = "{table_name}"') - + # Add indexes if defined - indexes = getattr(model._meta, 'indexes', None) + indexes = getattr(model._meta, "indexes", None) if indexes: - lines.append(f' indexes = {indexes}') - - lines.append('') - + lines.append(f" indexes = {indexes}") + + lines.append("") + # Generate SQL for adding new fields to existing tables for table_name, changes in field_changes.items(): - if changes.get('added'): - for field_name, field in changes['added'].items(): + if changes.get("added"): + for field_name, field in changes["added"].items(): sql, index_sql = generate_add_field_sql(table_name, field, field_name) lines.append(f' migrator.sql("{sql}")') if index_sql: lines.append(f' migrator.sql("{index_sql}")') - lines.append('') - + lines.append("") + # Generate SQL for modifying fields in existing tables for table_name, changes in field_changes.items(): - if changes.get('changed'): - for field_name, (old_info, field) in changes['changed'].items(): + if changes.get("changed"): + for field_name, (old_info, field) in changes["changed"].items(): modify_sql = generate_modify_field_sql(table_name, field, field_name) lines.append(f' migrator.sql("{modify_sql}")') - lines.append('') - + lines.append("") + # Generate SQL for dropping removed fields from existing tables if drop_fields: for table_name, changes in field_changes.items(): - if changes.get('removed'): - for field_name, col_info in changes['removed'].items(): + if changes.get("removed"): + for field_name, col_info in changes["removed"].items(): drop_sql = generate_drop_field_sql(table_name, field_name) - lines.append(f' # WARNING: Dropping column `{field_name}` from `{table_name}` - this will permanently delete data!') + lines.append(f" # WARNING: Dropping column `{field_name}` from `{table_name}` - this will permanently delete data!") lines.append(f' migrator.sql("{drop_sql}")') - lines.append('') - + lines.append("") + # Generate rollback - lines.append('') - lines.append('def rollback(migrator: Migrator, database: pw.Database, *, fake=False):') + lines.append("") + lines.append("def rollback(migrator: Migrator, database: pw.Database, *, fake=False):") lines.append(' """Write your rollback migrations here."""') - lines.append('') - + lines.append("") + # Rollback: re-add dropped fields (before other rollbacks, since they may depend on these fields) if drop_fields: for table_name, changes in field_changes.items(): - if changes.get('removed'): - for field_name, col_info in changes['removed'].items(): + if changes.get("removed"): + for field_name, col_info in changes["removed"].items(): add_sql, index_sql = generate_rollback_add_field_sql(table_name, col_info, field_name) - lines.append(f' # Re-add dropped column `{field_name}` to `{table_name}` (data is lost)') + lines.append(f" # Re-add dropped column `{field_name}` to `{table_name}` (data is lost)") lines.append(f' migrator.sql("{add_sql}")') if index_sql: lines.append(f' migrator.sql("{index_sql}")') - + # Rollback: reverse field type changes first (before removing added fields) for table_name, changes in field_changes.items(): - if changes.get('changed'): - for field_name, (old_info, field) in changes['changed'].items(): + if changes.get("changed"): + for field_name, (old_info, field) in changes["changed"].items(): rollback_modify_sql = generate_rollback_modify_sql(table_name, old_info, field_name) - lines.append(' # Note: Data values may need manual handling if type conversion caused data loss') + lines.append(" # Note: Data values may need manual handling if type conversion caused data loss") lines.append(f' migrator.sql("{rollback_modify_sql}")') - + # Rollback: remove added fields using SQL for table_name, changes in field_changes.items(): - if changes.get('added'): - for field_name in changes['added'].keys(): + if changes.get("added"): + for field_name in changes["added"].keys(): rollback_sql = generate_rollback_field_sql(table_name, field_name) lines.append(f' migrator.sql("{rollback_sql}")') - + # Rollback: remove tables (in reverse order) for model in reversed(new_tables): table_name = model._meta.table_name lines.append(f' migrator.remove_model("{table_name}")') - - lines.append('') - - return '\n'.join(lines) + + lines.append("") + + return "\n".join(lines) def create_migration(router: Router, models: list, db, name: str = "auto", drop_fields: bool = False): """Create a new migration by auto-detecting model changes - + Detects: 1. New tables -> generate create_model 2. New fields in existing tables -> generate add_fields 3. Field type changes -> generate change_fields 4. Removed fields (only when --drop is specified) -> generate drop_fields - + Args: router: peewee-migrate Router instance models: List of model classes to compare against database @@ -650,18 +655,15 @@ def create_migration(router: Router, models: list, db, name: str = "auto", drop_ """ try: # Get existing tables from database - cursor = db.execute_sql( - "SELECT table_name FROM information_schema.tables WHERE table_schema = %s", - (db.database,) - ) + cursor = db.execute_sql("SELECT table_name FROM information_schema.tables WHERE table_schema = %s", (db.database,)) existing_tables = {row[0] for row in cursor.fetchall()} - + new_tables = [] field_changes = {} - + for model in models: table_name = model._meta.table_name - + if table_name not in existing_tables: # New table new_tables.append(model) @@ -669,58 +671,58 @@ def create_migration(router: Router, models: list, db, name: str = "auto", drop_ else: # Existing table - check for field changes logger.info(f"Checking existing table: {table_name}") - + # Get model fields (exclude auto-generated) model_fields = {} for field_name, field in model._meta.fields.items(): # Skip id and base model fields - if field_name in ('id', 'create_time', 'create_date', 'update_time', 'update_date'): + if field_name in ("id", "create_time", "create_date", "update_time", "update_date"): continue - if hasattr(field, '_auto_created') and field._auto_created: + if hasattr(field, "_auto_created") and field._auto_created: continue model_fields[field_name] = field - + # Get database columns db_columns = get_table_columns(db, table_name) - + # Compare changes = compare_fields(model_fields, db_columns) - - if changes['added'] or changes['changed'] or changes['removed']: + + if changes["added"] or changes["changed"] or changes["removed"]: field_changes[table_name] = changes - + # Check if any changes detected - has_removed = any(changes.get('removed') for changes in field_changes.values()) + has_removed = any(changes.get("removed") for changes in field_changes.values()) if not drop_fields and has_removed: removed_details = [] for table_name, changes in field_changes.items(): - if changes.get('removed'): - for col_name in changes['removed']: + if changes.get("removed"): + for col_name in changes["removed"]: removed_details.append(f"{table_name}.{col_name}") logger.warning(f"Removed fields detected (not included in migration, use --drop to include): {', '.join(removed_details)}") # Remove 'removed' from changes since we're not acting on them for table_name in field_changes: - field_changes[table_name]['removed'] = {} - - if not new_tables and not any(changes['added'] or changes['changed'] for changes in field_changes.values()): + field_changes[table_name]["removed"] = {} + + if not new_tables and not any(changes["added"] or changes["changed"] for changes in field_changes.values()): if not (drop_fields and has_removed): logger.info("No schema changes detected, migration not created") return None - + # Generate migration file content migration_content = generate_migration_content(new_tables, field_changes, router.migrate_dir, name, drop_fields=drop_fields) - + # Get next migration number (count existing migration files) - existing_migrations = [f for f in os.listdir(router.migrate_dir) if f.endswith('.py') and not f.startswith('_')] + existing_migrations = [f for f in os.listdir(router.migrate_dir) if f.endswith(".py") and not f.startswith("_")] migration_num = len(existing_migrations) + 1 - migration_file = os.path.join(router.migrate_dir, f'{migration_num:03d}_{name}.py') - - with open(migration_file, 'w') as f: + migration_file = os.path.join(router.migrate_dir, f"{migration_num:03d}_{name}.py") + + with open(migration_file, "w") as f: f.write(migration_content) - + logger.info(f"Created migration: {migration_file}") return migration_file - + except Exception as e: logger.error(f"Failed to create migration: {e}") raise @@ -733,7 +735,7 @@ def run_migrations(router: Router): if not diff: logger.info("No pending migrations to run") return - + router.run() logger.info("Migrations completed successfully") except Exception as e: @@ -747,7 +749,7 @@ def list_migrations(router: Router): if not todo: logger.info("No migration files found") return - + logger.info("Available migrations:") done = set(router.done) for migration in todo: @@ -758,194 +760,178 @@ def list_migrations(router: Router): def diff_schema(models: list, db): """Show schema differences between models and database""" logger.info("Checking schema differences...") - + # Tables to ignore (managed by peewee-migrate) - IGNORE_TABLES = {'migratehistory'} - + IGNORE_TABLES = {"migratehistory"} + # Get all model table names model_tables = set() for model in models: table_name = model._meta.table_name model_tables.add(table_name) - + logger.info(f"Found {len(model_tables)} model tables") - + # Get existing tables from database - cursor = db.execute_sql( - "SELECT table_name FROM information_schema.tables WHERE table_schema = %s", - (db.database,) - ) + cursor = db.execute_sql("SELECT table_name FROM information_schema.tables WHERE table_schema = %s", (db.database,)) existing_tables = {row[0] for row in cursor.fetchall() if row[0] not in IGNORE_TABLES} - + # Find tables that exist in models but not in database missing_tables = model_tables - existing_tables if missing_tables: logger.warning(f"Tables not in database ({len(missing_tables)}): {', '.join(sorted(missing_tables))}") - + # Find tables that exist in database but not in models extra_tables = existing_tables - model_tables if extra_tables: logger.info(f"Tables in database but not in models: {', '.join(sorted(extra_tables))}") - + # Check field differences for existing tables common_tables = model_tables & existing_tables if common_tables: logger.info(f"\nChecking field differences for {len(common_tables)} existing tables...") - + total_added = 0 total_changed = 0 total_removed = 0 - + for model in models: table_name = model._meta.table_name if table_name not in common_tables: continue - + # Get model fields model_fields = {} for field_name, field in model._meta.fields.items(): - if field_name in ('id', 'create_time', 'create_date', 'update_time', 'update_date'): + if field_name in ("id", "create_time", "create_date", "update_time", "update_date"): continue model_fields[field_name] = field - + # Get database columns db_columns = get_table_columns(db, table_name) - + # Compare changes = compare_fields(model_fields, db_columns) - - if changes['added']: - total_added += len(changes['added']) - field_details = [f"{k}:{v.__class__.__name__}" for k, v in changes['added'].items()] + + if changes["added"]: + total_added += len(changes["added"]) + field_details = [f"{k}:{v.__class__.__name__}" for k, v in changes["added"].items()] logger.info(f" {table_name}: {len(changes['added'])} new field(s) - {field_details}") - - if changes['changed']: - total_changed += len(changes['changed']) - field_details = [f"{k}:{v[1].__class__.__name__}" for k, v in changes['changed'].items()] + + if changes["changed"]: + total_changed += len(changes["changed"]) + field_details = [f"{k}:{v[1].__class__.__name__}" for k, v in changes["changed"].items()] logger.info(f" {table_name}: {len(changes['changed'])} changed field(s) - {field_details}") - - if changes['removed']: - total_removed += len(changes['removed']) - field_details = [f"{k}:{v['column_type']}" for k, v in changes['removed'].items()] + + if changes["removed"]: + total_removed += len(changes["removed"]) + field_details = [f"{k}:{v['column_type']}" for k, v in changes["removed"].items()] logger.warning(f" {table_name}: {len(changes['removed'])} removed field(s) - {field_details}") - + logger.info(f"\nSummary: {total_added} new fields, {total_changed} changed fields, {total_removed} removed fields") def main(): parser = argparse.ArgumentParser( - description='Database Schema Synchronization Tool using peewee-migrate', + description="Database Schema Synchronization Tool using peewee-migrate", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: # List all migrations python db_schema_sync.py --list --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.26.3 - + # Create migration from model changes python db_schema_sync.py --create --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.26.3 - + # Create migration including dropped fields (destructive!) python db_schema_sync.py --create --drop --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.26.3 - + # Run all pending migrations python db_schema_sync.py --migrate --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.26.3 - + # Show schema differences python db_schema_sync.py --diff --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.26.3 -""" +""", ) - + # Database connection options - parser.add_argument('--host', type=str, required=True, help='MySQL host') - parser.add_argument('--port', type=int, default=3306, help='MySQL port (default: 3306)') - parser.add_argument('--user', type=str, required=True, help='MySQL user') - parser.add_argument('--password', type=str, required=True, help='MySQL password') - parser.add_argument('--database', type=str, required=True, help='MySQL database name') - + parser.add_argument("--host", type=str, required=True, help="MySQL host") + parser.add_argument("--port", type=int, default=3306, help="MySQL port (default: 3306)") + parser.add_argument("--user", type=str, required=True, help="MySQL user") + parser.add_argument("--password", type=str, required=True, help="MySQL password") + parser.add_argument("--database", type=str, required=True, help="MySQL database name") + # Version option - parser.add_argument('--version', '-v', type=str, required=True, - help='Version number in format vxx.xx.xx (e.g., v0.26.3)') - + parser.add_argument("--version", "-v", type=str, required=True, help="Version number in format vxx.xx.xx (e.g., v0.26.3)") + # Action options - parser.add_argument('--list', '-l', action='store_true', help='List all migrations') - parser.add_argument('--create', '-c', action='store_true', - help='Create migration from model changes (auto-detect)') - parser.add_argument('--migrate', '-m', action='store_true', help='Run pending migrations') - parser.add_argument('--diff', '-d', action='store_true', help='Show schema differences') - + parser.add_argument("--list", "-l", action="store_true", help="List all migrations") + parser.add_argument("--create", "-c", action="store_true", help="Create migration from model changes (auto-detect)") + parser.add_argument("--migrate", "-m", action="store_true", help="Run pending migrations") + parser.add_argument("--diff", "-d", action="store_true", help="Show schema differences") + # Migration options - parser.add_argument('--name', '-n', type=str, default='auto', help='Migration name') - parser.add_argument('--drop', action='store_true', - help='Include DROP COLUMN for fields removed from models (destructive - will permanently delete data!)') - + parser.add_argument("--name", "-n", type=str, default="auto", help="Migration name") + parser.add_argument("--drop", action="store_true", help="Include DROP COLUMN for fields removed from models (destructive - will permanently delete data!)") + args = parser.parse_args() - + # Validate version format if not validate_version(args.version): logger.error(f"Invalid version format: {args.version}. Expected format: vxx.xx.xx (e.g., v0.26.3)") sys.exit(1) - + # Validate at least one action is specified if not any([args.list, args.create, args.migrate, args.diff]): parser.print_help() logger.error("Please specify at least one action: --list, --create, --migrate, or --diff") sys.exit(1) - + # Convert version to directory name version_dir = version_to_dirname(args.version) - migrate_dir = os.path.join(PROJECT_BASE, 'tools', 'migrate', version_dir) - + migrate_dir = os.path.join(PROJECT_BASE, "tools", "migrate", version_dir) + logger.info(f"Version: {args.version}") logger.info(f"Migration directory: {migrate_dir}") - + # Create migration directory if it doesn't exist os.makedirs(migrate_dir, exist_ok=True) - + # Load database models logger.info("Loading database models from api/db/db_models.py...") models, _ = load_db_models() logger.info(f"Found {len(models)} model classes") - + # Create database connection - db = create_database_connection( - host=args.host, - port=args.port, - user=args.user, - password=args.password, - database=args.database - ) - + db = create_database_connection(host=args.host, port=args.port, user=args.user, password=args.password, database=args.database) + try: db.connect() logger.info(f"Connected to database: {args.database}") - + # Create router - router = Router( - db, - migrate_dir, - ignore=['basemodel', 'base_model', 'migratehistory'] - ) - + router = Router(db, migrate_dir, ignore=["basemodel", "base_model", "migratehistory"]) + # Execute requested actions if args.list: list_migrations(router) - + if args.create: create_migration(router, models, db, args.name, drop_fields=args.drop) - + if args.migrate: run_migrations(router) - + if args.diff: diff_schema(models, db) - + finally: if not db.is_closed(): db.close() logger.info("Database connection closed") - + logger.info("Done.") -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/tools/scripts/mysql_migration.py b/tools/scripts/mysql_migration.py index 7dbf0971af..536ede6447 100644 --- a/tools/scripts/mysql_migration.py +++ b/tools/scripts/mysql_migration.py @@ -49,10 +49,7 @@ PROJECT_BASE = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(_ sys.path.insert(0, PROJECT_BASE) # Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s' -) +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) @@ -61,34 +58,34 @@ MIGRATION_DB_VERSION_MARKER = "mysql_migration.database.version" class MigrationConfig: """Configuration for MySQL connection""" - - def __init__(self, host: str = 'localhost', port: int = 3306, - user: str = 'root', password: str = '', database: str = 'rag_flow'): + + def __init__(self, host: str = "localhost", port: int = 3306, user: str = "root", password: str = "", database: str = "rag_flow"): self.host = host self.port = port self.user = user self.password = password self.database = database - + @classmethod - def from_config_file(cls, config_path: str) -> 'MigrationConfig': + def from_config_file(cls, config_path: str) -> "MigrationConfig": """Load configuration from YAML config file""" try: from ruamel.yaml import YAML + yaml = YAML(typ="safe", pure=True) - - with open(config_path, 'r') as f: + + with open(config_path, "r") as f: config = yaml.load(f) - + # Try to get database config - db_config = config.get('database', config.get('mysql', {})) - + db_config = config.get("database", config.get("mysql", {})) + return cls( - host=db_config.get('host', 'localhost'), - port=db_config.get('port', 3306), - user=db_config.get('user', 'root'), - password=db_config.get('password', ''), - database=db_config.get('name', db_config.get('database', 'rag_flow')) + host=db_config.get("host", "localhost"), + port=db_config.get("port", 3306), + user=db_config.get("user", "root"), + password=db_config.get("password", ""), + database=db_config.get("name", db_config.get("database", "rag_flow")), ) except Exception as e: logger.warning(f"Failed to load config file: {e}, using defaults") @@ -97,30 +94,25 @@ class MigrationConfig: class MigrationStats: """Track migration statistics""" - + def __init__(self): self.tables_operated = [] self.rows_processed = 0 self.start_time = None self.end_time = None self.stage_stats = [] - + def start(self): self.start_time = time.time() - + def end(self): self.end_time = time.time() - + def add_stage_stats(self, stage_name: str, tables: list, rows: int, duration: float): - self.stage_stats.append({ - 'stage': stage_name, - 'tables': tables, - 'rows': rows, - 'duration': duration - }) + self.stage_stats.append({"stage": stage_name, "tables": tables, "rows": rows, "duration": duration}) self.tables_operated.extend(tables) self.rows_processed += rows - + def print_summary(self): duration = self.end_time - self.start_time if self.end_time and self.start_time else 0 logger.info("=" * 60) @@ -132,52 +124,36 @@ class MigrationStats: logger.info("-" * 60) logger.info("Stage Details:") for stat in self.stage_stats: - logger.info(f" [{stat['stage']}] Tables: {', '.join(stat['tables'])}, " - f"Rows: {stat['rows']}, Duration: {stat['duration']:.2f}s") + logger.info(f" [{stat['stage']}] Tables: {', '.join(stat['tables'])}, Rows: {stat['rows']}, Duration: {stat['duration']:.2f}s") logger.info("=" * 60) class MigrationDatabase: """Database wrapper for migrations""" - + def __init__(self, config: MigrationConfig): self.config = config - self.db = MySQLDatabase( - config.database, - host=config.host, - port=config.port, - user=config.user, - password=config.password, - charset='utf8mb4' - ) + self.db = MySQLDatabase(config.database, host=config.host, port=config.port, user=config.user, password=config.password, charset="utf8mb4") self.migrator = MySQLMigrator(self.db) - + def connect(self): self.db.connect() logger.info(f"Connected to MySQL database: {self.config.database}") - + def close(self): if not self.db.is_closed(): self.db.close() logger.info("Database connection closed") - + def execute_sql(self, sql: str, params=None): return self.db.execute_sql(sql, params) - + def table_exists(self, table_name: str) -> bool: - cursor = self.execute_sql( - "SELECT COUNT(*) FROM information_schema.tables " - "WHERE table_schema = %s AND table_name = %s", - (self.config.database, table_name) - ) + cursor = self.execute_sql("SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = %s AND table_name = %s", (self.config.database, table_name)) return cursor.fetchone()[0] > 0 def column_exists(self, table_name: str, column_name: str) -> bool: - cursor = self.execute_sql( - "SELECT COUNT(*) FROM information_schema.columns " - "WHERE table_schema = %s AND table_name = %s AND column_name = %s", - (self.config.database, table_name, column_name) - ) + cursor = self.execute_sql("SELECT COUNT(*) FROM information_schema.columns WHERE table_schema = %s AND table_name = %s AND column_name = %s", (self.config.database, table_name, column_name)) return cursor.fetchone()[0] > 0 def get_system_setting_value(self, name: str) -> str | None: @@ -252,17 +228,19 @@ def should_skip_migration(current_db_version: str | None, target_version: str) - # Define model classes for migration (not importing from api.db.db_models) class BaseModel(Model): """Base model for migration tables""" + create_time = BigIntegerField(null=True, index=True) create_date = DateTimeField(null=True, index=True) update_time = BigIntegerField(null=True, index=True) update_date = DateTimeField(null=True, index=True) - + class Meta: database = None # Will be set dynamically class TenantLLM(BaseModel): """Tenant LLM model (source table)""" + id = PrimaryKeyField() tenant_id = CharField(max_length=32, null=False, index=True) llm_factory = CharField(max_length=128, null=False, index=True) @@ -273,7 +251,7 @@ class TenantLLM(BaseModel): max_tokens = IntegerField(default=8192, index=True) used_tokens = IntegerField(default=0, index=True) status = CharField(max_length=1, null=False, default="1", index=True) - + class Meta: table_name = "tenant_llm" database = None @@ -281,10 +259,11 @@ class TenantLLM(BaseModel): class TenantModelProvider(BaseModel): """Tenant Model Provider model (target table)""" + id = CharField(max_length=32, primary_key=True) provider_name = CharField(max_length=128, null=False, index=True) tenant_id = CharField(max_length=32, null=False, index=True) - + class Meta: table_name = "tenant_model_provider" database = None @@ -292,25 +271,26 @@ class TenantModelProvider(BaseModel): class MigrationStage: """Base class for migration stages""" - + name = "base_stage" description = "Base migration stage" source_tables = [] target_tables = [] + def __init__(self, db: MigrationDatabase, dry_run: bool = True, create_table_only: bool = False): self.db = db self.dry_run = dry_run self.create_table_only = create_table_only self._noop_completes_migration = False - + def check(self) -> bool: """Check if migration is needed""" raise NotImplementedError - + def execute(self) -> tuple[int, list]: """Execute migration, returns (rows_affected, tables_operated)""" raise NotImplementedError - + def create_target_table(self): """Create target table (override in subclass if needed)""" pass @@ -324,72 +304,66 @@ class MigrationStage: class TenantModelProviderStage(MigrationStage): """Migrate tenant_llm to tenant_model_provider""" - + name = "tenant_model_provider" description = "Migrate tenant_llm.llm_factory to tenant_model_provider.provider_name" source_tables = ["tenant_llm"] target_tables = ["tenant_model_provider"] - + def current_timestamp(self) -> int: return int(time.time()) - + def generate_uuid(self) -> str: """Generate 32-character UUID1""" return uuid.uuid1().hex - + def check(self) -> bool: """Check if migration is needed""" # Check if source table exists if not self.db.table_exists("tenant_llm"): logger.warning("Source table 'tenant_llm' does not exist") return False - + # Check if target table exists if not self.db.table_exists("tenant_model_provider"): if self.dry_run: - logger.info("[DRY RUN] Target table 'tenant_model_provider' does not exist. " - "Use --execute to create and populate the table.") + logger.info("[DRY RUN] Target table 'tenant_model_provider' does not exist. Use --execute to create and populate the table.") return False logger.info("Target table 'tenant_model_provider' does not exist, will create") return True - + # Check if there's data to migrate cursor = self.db.execute_sql( - "SELECT COUNT(*) FROM tenant_llm t1 " - "WHERE NOT EXISTS (" - " SELECT 1 FROM tenant_model_provider t2 " - " WHERE t2.tenant_id = t1.tenant_id AND t2.provider_name = t1.llm_factory" - ")" + "SELECT COUNT(*) FROM tenant_llm t1 WHERE NOT EXISTS ( SELECT 1 FROM tenant_model_provider t2 WHERE t2.tenant_id = t1.tenant_id AND t2.provider_name = t1.llm_factory)" ) count = cursor.fetchone()[0] - + if count == 0: self.mark_noop_completes_migration() logger.info("No new data to migrate from tenant_llm to tenant_model_provider") return False - + logger.info(f"Found {count} rows to migrate from tenant_llm to tenant_model_provider") return True - + def execute(self) -> tuple[int, list]: """Execute migration""" current_ts = self.current_timestamp() rows_inserted = 0 - + # Check if target table exists if not self.db.table_exists("tenant_model_provider"): if self.dry_run: - logger.info("[DRY RUN] Target table 'tenant_model_provider' does not exist. " - "Use --execute to create and populate the table.") + logger.info("[DRY RUN] Target table 'tenant_model_provider' does not exist. Use --execute to create and populate the table.") return 0, [] logger.info("Target table 'tenant_model_provider' does not exist, will create") self.create_target_table() - + # If create_table_only mode, skip data migration if self.create_table_only: logger.info("[CREATE TABLE ONLY] Target table created/verified, skipping data migration") return 0, self.target_tables - + # Get distinct tenant_id, llm_factory pairs that don't exist in target cursor = self.db.execute_sql( "SELECT DISTINCT tenant_id, llm_factory FROM tenant_llm t1 " @@ -398,48 +372,50 @@ class TenantModelProviderStage(MigrationStage): " WHERE t2.tenant_id = t1.tenant_id AND t2.provider_name = t1.llm_factory" ")" ) - + records = cursor.fetchall() - + if not records: logger.info("No records to migrate") return 0, [] - + logger.info(f"Migrating {len(records)} unique tenant_id/llm_factory pairs...") - + if self.dry_run: logger.info(f"[DRY RUN] Would insert {len(records)} records") return len(records), self.target_tables - + # Insert records in batches with parameterized SQL to avoid quote breakage/injection batch_size = 100 for i in range(0, len(records), batch_size): - batch = records[i:i + batch_size] + batch = records[i : i + batch_size] placeholders = [] params = [] for tenant_id, llm_factory in batch: record_id = self.generate_uuid() placeholders.append("(%s, %s, %s, %s, FROM_UNIXTIME(%s), %s, FROM_UNIXTIME(%s))") - params.extend([ - record_id, - llm_factory, - tenant_id, - current_ts * 1000, - current_ts, - current_ts * 1000, - current_ts, - ]) + params.extend( + [ + record_id, + llm_factory, + tenant_id, + current_ts * 1000, + current_ts, + current_ts * 1000, + current_ts, + ] + ) insert_sql = f""" - INSERT INTO tenant_model_provider + INSERT INTO tenant_model_provider (id, provider_name, tenant_id, create_time, create_date, update_time, update_date) - VALUES {', '.join(placeholders)} + VALUES {", ".join(placeholders)} """ self.db.execute_sql(insert_sql, params) rows_inserted += len(batch) logger.info(f"Inserted batch {i // batch_size + 1}: {len(batch)} records") - + return rows_inserted, self.target_tables - + def create_target_table(self): """Create tenant_model_provider table""" create_sql = """ @@ -485,18 +461,15 @@ class TenantModelInstanceStage(MigrationStage): # Check if tenant_model_provider exists (dependency) if not self.db.table_exists("tenant_model_provider"): if self.dry_run: - logger.info("[DRY RUN] Dependency table 'tenant_model_provider' does not exist. " - "Run 'tenant_model_provider' stage first or use --execute.") + logger.info("[DRY RUN] Dependency table 'tenant_model_provider' does not exist. Run 'tenant_model_provider' stage first or use --execute.") return False - logger.warning("Dependency table 'tenant_model_provider' does not exist. " - "Please run 'tenant_model_provider' stage first.") + logger.warning("Dependency table 'tenant_model_provider' does not exist. Please run 'tenant_model_provider' stage first.") return False # Check if target table exists if not self.db.table_exists("tenant_model_instance"): if self.dry_run: - logger.info("[DRY RUN] Target table 'tenant_model_instance' does not exist. " - "Use --execute to create and populate the table.") + logger.info("[DRY RUN] Target table 'tenant_model_instance' does not exist. Use --execute to create and populate the table.") return False logger.info("Target table 'tenant_model_instance' does not exist, will create") return True @@ -531,15 +504,13 @@ class TenantModelInstanceStage(MigrationStage): # Check if tenant_model_provider exists (dependency) if not self.db.table_exists("tenant_model_provider"): - logger.error("Dependency table 'tenant_model_provider' does not exist. " - "Please run 'tenant_model_provider' stage first.") + logger.error("Dependency table 'tenant_model_provider' does not exist. Please run 'tenant_model_provider' stage first.") return 0, [] # Check if target table exists if not self.db.table_exists("tenant_model_instance"): if self.dry_run: - logger.info("[DRY RUN] Target table 'tenant_model_instance' does not exist. " - "Use --execute to create and populate the table.") + logger.info("[DRY RUN] Target table 'tenant_model_instance' does not exist. Use --execute to create and populate the table.") return 0, [] logger.info("Target table 'tenant_model_instance' does not exist, will create") self.create_target_table() @@ -588,22 +559,24 @@ class TenantModelInstanceStage(MigrationStage): # Insert records in batches batch_size = 100 for i in range(0, len(records), batch_size): - batch = records[i:i + batch_size] + batch = records[i : i + batch_size] values = [] for tenant_id, llm_factory, api_key, status, provider_id in batch: record_id = self.generate_uuid() instance_name = "default" api_key_escaped = api_key.replace("'", "''") if api_key else "" status_val = "active" if status in ["1", "active", "enable"] else "inactive" - values.append(f"('{record_id}', '{instance_name}', '{provider_id}', " - f"'{api_key_escaped}', '{status_val}', " - f"{current_ts * 1000}, FROM_UNIXTIME({current_ts}), " - f"{current_ts * 1000}, FROM_UNIXTIME({current_ts}))") + values.append( + f"('{record_id}', '{instance_name}', '{provider_id}', " + f"'{api_key_escaped}', '{status_val}', " + f"{current_ts * 1000}, FROM_UNIXTIME({current_ts}), " + f"{current_ts * 1000}, FROM_UNIXTIME({current_ts}))" + ) insert_sql = f""" - INSERT INTO tenant_model_instance + INSERT INTO tenant_model_instance (id, instance_name, provider_id, api_key, status, create_time, create_date, update_time, update_date) - VALUES {', '.join(values)} + VALUES {", ".join(values)} """ self.db.execute_sql(insert_sql) rows_inserted += len(batch) @@ -688,11 +661,7 @@ class TenantModelInstanceStage(MigrationStage): seen[canonical] = rec else: dup_count += 1 - logger.debug( - f"Dedup api_key for tenant={tenant_id}, factory={llm_factory}, " - f"provider={provider_id}: keeping '{api_key[:20]}...', " - f"dropping '{seen[canonical][2][:20]}...'" - ) + logger.debug(f"Dedup api_key for tenant={tenant_id}, factory={llm_factory}, provider={provider_id}: keeping '{api_key[:20]}...', dropping '{seen[canonical][2][:20]}...'") deduped.extend(seen.values()) if dup_count > 0: @@ -770,28 +739,23 @@ class TenantModelStage(MigrationStage): # Check if tenant_model_provider exists (dependency) if not self.db.table_exists("tenant_model_provider"): if self.dry_run: - logger.info("[DRY RUN] Dependency table 'tenant_model_provider' does not exist. " - "Run 'tenant_model_provider' stage first or use --execute.") + logger.info("[DRY RUN] Dependency table 'tenant_model_provider' does not exist. Run 'tenant_model_provider' stage first or use --execute.") return False - logger.warning("Dependency table 'tenant_model_provider' does not exist. " - "Please run 'tenant_model_provider' stage first.") + logger.warning("Dependency table 'tenant_model_provider' does not exist. Please run 'tenant_model_provider' stage first.") return False # Check if tenant_model_instance exists (dependency) if not self.db.table_exists("tenant_model_instance"): if self.dry_run: - logger.info("[DRY RUN] Dependency table 'tenant_model_instance' does not exist. " - "Run 'tenant_model_instance' stage first or use --execute.") + logger.info("[DRY RUN] Dependency table 'tenant_model_instance' does not exist. Run 'tenant_model_instance' stage first or use --execute.") return False - logger.warning("Dependency table 'tenant_model_instance' does not exist. " - "Please run 'tenant_model_instance' stage first.") + logger.warning("Dependency table 'tenant_model_instance' does not exist. Please run 'tenant_model_instance' stage first.") return False # Check if target table exists if not self.db.table_exists("tenant_model"): if self.dry_run: - logger.info("[DRY RUN] Target table 'tenant_model' does not exist. " - "Use --execute to create and populate the table.") + logger.info("[DRY RUN] Target table 'tenant_model' does not exist. Use --execute to create and populate the table.") return False logger.info("Target table 'tenant_model' does not exist, will create") return True @@ -827,21 +791,18 @@ class TenantModelStage(MigrationStage): # Check if tenant_model_provider exists (dependency) if not self.db.table_exists("tenant_model_provider"): - logger.error("Dependency table 'tenant_model_provider' does not exist. " - "Please run 'tenant_model_provider' stage first.") + logger.error("Dependency table 'tenant_model_provider' does not exist. Please run 'tenant_model_provider' stage first.") return 0, [] # Check if tenant_model_instance exists (dependency) if not self.db.table_exists("tenant_model_instance"): - logger.error("Dependency table 'tenant_model_instance' does not exist. " - "Please run 'tenant_model_instance' stage first.") + logger.error("Dependency table 'tenant_model_instance' does not exist. Please run 'tenant_model_instance' stage first.") return 0, [] # Check if target table exists if not self.db.table_exists("tenant_model"): if self.dry_run: - logger.info("[DRY RUN] Target table 'tenant_model' does not exist. " - "Use --execute to create and populate the table.") + logger.info("[DRY RUN] Target table 'tenant_model' does not exist. Use --execute to create and populate the table.") return 0, [] logger.info("Target table 'tenant_model' does not exist, will create") self.create_target_table() @@ -891,8 +852,7 @@ class TenantModelStage(MigrationStage): if self.dry_run: logger.info(f"[DRY RUN] Would insert {len(resolved_records)} records") for source_id, llm_name, provider_id, instance_id, model_type, status, api_key in resolved_records[:5]: - logger.info(f" model_name={llm_name}, provider_id={provider_id}, " - f"instance_id={instance_id}, model_type={model_type}") + logger.info(f" model_name={llm_name}, provider_id={provider_id}, instance_id={instance_id}, model_type={model_type}") if len(resolved_records) > 5: logger.info(f" ... and {len(resolved_records) - 5} more records") return len(resolved_records), self.target_tables @@ -900,7 +860,7 @@ class TenantModelStage(MigrationStage): # Insert records in batches batch_size = 100 for i in range(0, len(resolved_records), batch_size): - batch = resolved_records[i:i + batch_size] + batch = resolved_records[i : i + batch_size] values = [] for source_id, llm_name, provider_id, instance_id, model_type, status, api_key in batch: record_id = self.generate_uuid() @@ -910,17 +870,19 @@ class TenantModelStage(MigrationStage): # Extract is_tools from api_key JSON and put it in extra extra = self._extract_extra_from_api_key(api_key) extra_escaped = extra.replace("'", "''") if extra else "{}" - values.append(f"('{record_id}', '{model_name_escaped}', '{provider_id}', " - f"'{instance_id}', '{model_type_escaped}', '{status_val}', " - f"'{extra_escaped}', " - f"{current_ts * 1000}, FROM_UNIXTIME({current_ts}), " - f"{current_ts * 1000}, FROM_UNIXTIME({current_ts}))") + values.append( + f"('{record_id}', '{model_name_escaped}', '{provider_id}', " + f"'{instance_id}', '{model_type_escaped}', '{status_val}', " + f"'{extra_escaped}', " + f"{current_ts * 1000}, FROM_UNIXTIME({current_ts}), " + f"{current_ts * 1000}, FROM_UNIXTIME({current_ts}))" + ) insert_sql = f""" - INSERT INTO tenant_model + INSERT INTO tenant_model (id, model_name, provider_id, instance_id, model_type, status, extra, create_time, create_date, update_time, update_date) - VALUES {', '.join(values)} + VALUES {", ".join(values)} """ self.db.execute_sql(insert_sql) rows_inserted += len(batch) @@ -937,9 +899,7 @@ class TenantModelStage(MigrationStage): Returns: dict mapping (provider_id, canonical_api_key) -> instance_id """ - cursor = self.db.execute_sql( - "SELECT id, provider_id, api_key FROM tenant_model_instance" - ) + cursor = self.db.execute_sql("SELECT id, provider_id, api_key FROM tenant_model_instance") lookup = {} for instance_id, provider_id, api_key in cursor.fetchall(): canonical = TenantModelInstanceStage._strip_is_tools_from_api_key(api_key) @@ -974,7 +934,9 @@ class TenantModelStage(MigrationStage): # entropy to be useful to an attacker who reads the log. logger.warning( "No matching instance for tenant_llm id=%s provider_id=%s llm_name=%s", - source_id, provider_id, llm_name, + source_id, + provider_id, + llm_name, ) if skipped > 0: @@ -1102,9 +1064,7 @@ class ModelIdConfigStage(MigrationStage): normalized = {} for key, item in value.items(): key_path = path + (str(key),) - should_normalize = key in self.model_id_fields or ( - key in self.search_config_model_id_fields and "search_config" in path - ) + should_normalize = key in self.model_id_fields or (key in self.search_config_model_id_fields and "search_config" in path) if should_normalize: normalized_item, item_changed = self.normalize_model_id(item) else: @@ -1154,8 +1114,7 @@ class ModelIdConfigStage(MigrationStage): def iter_string_changes(self): for table_name, column_name in self.existing_columns(self.string_columns): cursor = self.db.execute_sql( - f"SELECT id, `{column_name}` FROM `{table_name}` " - f"WHERE `{column_name}` IS NOT NULL AND `{column_name}` != '' AND `{column_name}` LIKE %s", + f"SELECT id, `{column_name}` FROM `{table_name}` WHERE `{column_name}` IS NOT NULL AND `{column_name}` != '' AND `{column_name}` LIKE %s", ("%@%",), ) while True: @@ -1170,8 +1129,7 @@ class ModelIdConfigStage(MigrationStage): def iter_json_changes(self): for table_name, column_name in self.existing_columns(self.json_columns): cursor = self.db.execute_sql( - f"SELECT id, `{column_name}` FROM `{table_name}` " - f"WHERE `{column_name}` IS NOT NULL AND `{column_name}` != '' AND `{column_name}` LIKE %s", + f"SELECT id, `{column_name}` FROM `{table_name}` WHERE `{column_name}` IS NOT NULL AND `{column_name}` != '' AND `{column_name}` LIKE %s", ("%@%",), ) while True: @@ -1271,10 +1229,10 @@ class ModelIdConfigStage(MigrationStage): # Registry of available migration stages MIGRATION_STAGES = { - 'tenant_model_provider': TenantModelProviderStage, - 'tenant_model_instance': TenantModelInstanceStage, - 'tenant_model': TenantModelStage, - 'model_id_config': ModelIdConfigStage, + "tenant_model_provider": TenantModelProviderStage, + "tenant_model_instance": TenantModelInstanceStage, + "tenant_model": TenantModelStage, + "model_id_config": ModelIdConfigStage, } @@ -1298,9 +1256,9 @@ def run_migration( """Run migration with specified stages""" stats = MigrationStats() stats.start() - + db = MigrationDatabase(config) - + try: db.connect() @@ -1325,26 +1283,26 @@ def run_migration( "Database migration version marker is not set, target version is %s", database_version, ) - + total_stages = len(stages) all_stages_completed = True - + for idx, stage_name in enumerate(stages, 1): logger.info(f"{'=' * 60}") logger.info(f"Stage [{idx}/{total_stages}]: {stage_name}") logger.info(f"{'=' * 60}") - + if stage_name not in MIGRATION_STAGES: logger.error(f"Unknown stage: {stage_name}") stats.add_stage_stats(stage_name, [], 0, 0) all_stages_completed = False continue - + stage_cls = MIGRATION_STAGES[stage_name] stage = stage_cls(db, dry_run=dry_run, create_table_only=create_table_only) - + stage_start = time.time() - + # For create_table_only mode, skip check and directly execute if create_table_only: logger.info("[CREATE TABLE ONLY] Skipping check, will create/verify target table") @@ -1355,22 +1313,16 @@ def run_migration( logger.info(f"Stage '{stage_name}' check: no migration needed") stats.add_stage_stats(stage_name, [], 0, time.time() - stage_start) continue - + # Execute migration rows, tables = stage.execute() - + stage_duration = time.time() - stage_start - + stats.add_stage_stats(stage_name, tables, rows, stage_duration) logger.info(f"Stage '{stage_name}' completed: {rows} rows in {stage_duration:.2f}s") - if ( - mark_database_version_on_success - and not dry_run - and not create_table_only - and database_version - and all_stages_completed - ): + if mark_database_version_on_success and not dry_run and not create_table_only and database_version and all_stages_completed: db.set_database_version(database_version) logger.info("Marked database migration version as %s", database_version) @@ -1421,7 +1373,7 @@ def mark_database_version(config: MigrationConfig, version: str) -> None: def main(): parser = argparse.ArgumentParser( - description='MySQL Data Migration Tool', + description="MySQL Data Migration Tool", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: @@ -1433,16 +1385,16 @@ Examples: # Mark database version separately python mysql_migration.py --mark-database-version --database-version v0.26.3 --config /path/to/config.yaml - + # Dry run (default - check only, no write) with config file python mysql_migration.py --stages tenant_model_provider --config /path/to/config.yaml - + # Dry run with command line MySQL connection python mysql_migration.py --stages tenant_model_provider --host localhost --port 3306 --user root --password secret - + # Create target tables only (no data migration) python mysql_migration.py --stages tenant_model_provider --config /path/to/config.yaml --create-table-only - + # Execute full migration (create tables and migrate data) python mysql_migration.py --stages tenant_model_provider --config /path/to/config.yaml --execute @@ -1451,79 +1403,61 @@ Examples: # Execute migration and mark the database version when all stages succeed python mysql_migration.py --stages tenant_model_provider,tenant_model_instance,tenant_model,model_id_config --config /path/to/config.yaml --execute --database-version v0.26.3 --mark-database-version-on-success - + # Normalize legacy model IDs in stored configs python mysql_migration.py --stages model_id_config --config /path/to/config.yaml --execute # Run multiple stages python mysql_migration.py --stages stage1,stage2,stage3 --config /path/to/config.yaml --execute -""" +""", ) - + # MySQL connection options - parser.add_argument('--host', type=str, default='localhost', - help='MySQL host (default: localhost)') - parser.add_argument('--port', type=int, default=3306, - help='MySQL port (default: 3306)') - parser.add_argument('--user', type=str, default='root', - help='MySQL user (default: root)') - parser.add_argument('--password', type=str, default='', - help='MySQL password (default: empty)') - parser.add_argument('--database', type=str, default='rag_flow', - help='MySQL database name (default: rag_flow)') - + parser.add_argument("--host", type=str, default="localhost", help="MySQL host (default: localhost)") + parser.add_argument("--port", type=int, default=3306, help="MySQL port (default: 3306)") + parser.add_argument("--user", type=str, default="root", help="MySQL user (default: root)") + parser.add_argument("--password", type=str, default="", help="MySQL password (default: empty)") + parser.add_argument("--database", type=str, default="rag_flow", help="MySQL database name (default: rag_flow)") + # Configuration options - parser.add_argument('--config', '-c', type=str, help='Path to YAML config file') - + parser.add_argument("--config", "-c", type=str, help="Path to YAML config file") + # Migration options - parser.add_argument('--stages', '-s', type=str, help='Comma-separated list of stages to run') - parser.add_argument('--list-stages', '-l', action='store_true', help='List available stages') - parser.add_argument('--check-database-version', action='store_true', - help='Check whether migration is needed for the target database version') - parser.add_argument('--mark-database-version', action='store_true', - help='Write the database migration version marker and exit') - parser.add_argument('--database-version', type=str, metavar='VERSION', - help='Database migration version used by check/mark commands and as the migration threshold for --stages') - parser.add_argument('--mark-database-version-on-success', action='store_true', - help='When used with --stages and --execute, write --database-version after all stages succeed') - parser.add_argument('--execute', '-e', action='store_true', default=False, - help='Execute full migration: create tables and migrate data') - parser.add_argument('--create-table-only', action='store_true', default=False, - help='Only create target tables, skip data migration') - + parser.add_argument("--stages", "-s", type=str, help="Comma-separated list of stages to run") + parser.add_argument("--list-stages", "-l", action="store_true", help="List available stages") + parser.add_argument("--check-database-version", action="store_true", help="Check whether migration is needed for the target database version") + parser.add_argument("--mark-database-version", action="store_true", help="Write the database migration version marker and exit") + parser.add_argument("--database-version", type=str, metavar="VERSION", help="Database migration version used by check/mark commands and as the migration threshold for --stages") + parser.add_argument("--mark-database-version-on-success", action="store_true", help="When used with --stages and --execute, write --database-version after all stages succeed") + parser.add_argument("--execute", "-e", action="store_true", default=False, help="Execute full migration: create tables and migrate data") + parser.add_argument("--create-table-only", action="store_true", default=False, help="Only create target tables, skip data migration") + args = parser.parse_args() - + # List stages and exit if args.list_stages: list_available_stages() return - + # Load configuration: command line args take precedence over config file if args.config: config = MigrationConfig.from_config_file(args.config) # Override with command line args if provided - if args.host != 'localhost': + if args.host != "localhost": config.host = args.host if args.port != 3306: config.port = args.port - if args.user != 'root': + if args.user != "root": config.user = args.user - if args.password != '': + if args.password != "": config.password = args.password - if args.database != 'rag_flow': + if args.database != "rag_flow": config.database = args.database else: # Use command line args directly - config = MigrationConfig( - host=args.host, - port=args.port, - user=args.user, - password=args.password, - database=args.database - ) - - logger.info(f"MySQL Configuration: host={config.host}, port={config.port}, " - f"user={config.user}, database={config.database}") + config = MigrationConfig(host=args.host, port=args.port, user=args.user, password=args.password, database=args.database) + + logger.info(f"MySQL Configuration: host={config.host}, port={config.port}, user={config.user}, database={config.database}") if args.check_database_version and args.mark_database_version: logger.error("--check-database-version and --mark-database-version are mutually exclusive") @@ -1545,7 +1479,7 @@ Examples: if args.mark_database_version_on_success and not args.database_version: logger.error("--mark-database-version-on-success requires --database-version") sys.exit(1) - + # Three mutually exclusive modes: dry-run (default), create-table-only, execute if args.execute and args.create_table_only: logger.error("--execute and --create-table-only are mutually exclusive") @@ -1555,10 +1489,10 @@ Examples: logger.error("No stages specified. Use --stages to specify stages or --list-stages to see available stages.") sys.exit(1) - stages = [s.strip() for s in args.stages.split(',')] + stages = [s.strip() for s in args.stages.split(",")] dry_run = True create_table_only = False - + if args.create_table_only: logger.info("Running in CREATE TABLE ONLY mode (create tables, no data migration)") dry_run = False @@ -1567,9 +1501,8 @@ Examples: logger.info("Running in EXECUTE mode (create tables and migrate data)") dry_run = False else: - logger.info("Running in DRY-RUN mode (check only, no write). " - "Use --create-table-only to create tables, or --execute for full migration.") - + logger.info("Running in DRY-RUN mode (check only, no write). Use --create-table-only to create tables, or --execute for full migration.") + run_migration( config=config, stages=stages, @@ -1580,5 +1513,5 @@ Examples: ) -if __name__ == '__main__': +if __name__ == "__main__": main()