diff --git a/test/testcases/test_sdk_api/common.py b/test/testcases/test_sdk_api/common.py index 4de02830d0..e3a0e3d030 100644 --- a/test/testcases/test_sdk_api/common.py +++ b/test/testcases/test_sdk_api/common.py @@ -20,6 +20,48 @@ from ragflow_sdk import Chat, Chunk, DataSet, Document, RAGFlow, Session from utils.file_utils import create_txt_file +REST_API_MAX_PAGE_SIZE = 100 + + +def list_all_documents(dataset: DataSet, *, limit: int | None = None, page_size: int = REST_API_MAX_PAGE_SIZE) -> list[Document]: + page_size = min(page_size, REST_API_MAX_PAGE_SIZE) + documents: list[Document] = [] + page = 1 + while True: + batch = dataset.list_documents(page=page, page_size=page_size) + documents.extend(batch) + if limit is not None and len(documents) >= limit: + return documents[:limit] + if len(batch) < page_size: + return documents + page += 1 + + +def list_all_sessions(chat_assistant: Chat, *, limit: int | None = None, page_size: int = REST_API_MAX_PAGE_SIZE) -> list[Session]: + page_size = min(page_size, REST_API_MAX_PAGE_SIZE) + sessions: list[Session] = [] + page = 1 + while True: + batch = chat_assistant.list_sessions(page=page, page_size=page_size) + sessions.extend(batch) + if limit is not None and len(sessions) >= limit: + return sessions[:limit] + if len(batch) < page_size: + return sessions + page += 1 + + +def valid_chat_llm_id(client: RAGFlow) -> str: + # SDK tests use the tenant's configured chat model; this helper discovers test fixture state, not SDK behavior. + res = client.get('/users/me/models') + data = res.json() + if data.get('code') == 0: + llm_id = (data.get('data') or {}).get('llm_id') + if llm_id: + return llm_id + raise Exception('No valid chat llm_id is configured for the current tenant') + + # DATASET MANAGEMENT def batch_create_datasets(client: RAGFlow, num: int) -> list[DataSet]: return [client.create_dataset(name=f"dataset_{i}") for i in range(num)] diff --git a/test/testcases/test_sdk_api/test_chat_assistant_management/test_create_chat_assistant.py b/test/testcases/test_sdk_api/test_chat_assistant_management/test_create_chat_assistant.py index f9470b2802..3b5dddd9c0 100644 --- a/test/testcases/test_sdk_api/test_chat_assistant_management/test_create_chat_assistant.py +++ b/test/testcases/test_sdk_api/test_chat_assistant_management/test_create_chat_assistant.py @@ -15,6 +15,7 @@ # import pytest +from common import valid_chat_llm_id from configs import CHAT_ASSISTANT_NAME_LIMIT from utils import encode_avatar from utils.file_utils import create_image_file @@ -127,12 +128,14 @@ class TestChatAssistantCreate: @pytest.mark.parametrize( "llm_id, expected_message", [ - ("glm-4", ""), + (valid_chat_llm_id, ""), ("unknown", "`llm_id` unknown doesn't exist"), ], ) def test_llm_id(self, client, add_chunks, llm_id, expected_message): dataset, _, _ = add_chunks + if callable(llm_id): + llm_id = llm_id(client) if expected_message: with pytest.raises(Exception) as exception_info: diff --git a/test/testcases/test_sdk_api/test_chat_assistant_management/test_update_chat_assistant.py b/test/testcases/test_sdk_api/test_chat_assistant_management/test_update_chat_assistant.py index 66b0044c39..6a3ed2d54c 100644 --- a/test/testcases/test_sdk_api/test_chat_assistant_management/test_update_chat_assistant.py +++ b/test/testcases/test_sdk_api/test_chat_assistant_management/test_update_chat_assistant.py @@ -15,6 +15,7 @@ # import pytest +from common import valid_chat_llm_id from configs import CHAT_ASSISTANT_NAME_LIMIT from utils import encode_avatar from utils.file_utils import create_image_file @@ -109,7 +110,7 @@ class TestChatAssistantUpdate: @pytest.mark.parametrize( "llm_setting, expected_message", [ - ({"model_name": "glm-4"}, ""), + ({"model_name": valid_chat_llm_id}, ""), ({"model_name": "unknown"}, "`llm_id` unknown doesn't exist"), ({"temperature": 0}, ""), ({"temperature": 1}, ""), @@ -142,7 +143,10 @@ class TestChatAssistantUpdate: def test_llm_setting(self, client, add_chat_assistants_func, llm_setting, expected_message): dataset, _, chat_assistants = add_chat_assistants_func chat_assistant = chat_assistants[0] + llm_setting = dict(llm_setting) llm_id = llm_setting.pop("model_name", None) + if callable(llm_id): + llm_id = llm_id(client) payload = {"name": "llm_test", "dataset_ids": [dataset.id], "llm_setting": llm_setting} if llm_id is not None: payload["llm_id"] = llm_id diff --git a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_update_chunk.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_update_chunk.py index 4f4debffab..d49368405d 100644 --- a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_update_chunk.py +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_update_chunk.py @@ -166,4 +166,6 @@ class TestUpdatedChunk: with pytest.raises(Exception) as exception_info: chunks[0].update({}) - assert str(exception_info.value) in [f"You don't own the document {chunks[0].document_id}", f"Can't find this chunk {chunks[0].id}"], str(exception_info.value) + message = str(exception_info.value) + ownership_message = f"You don't own the document {chunks[0].document_id}" + assert message.rstrip(".") == ownership_message or message == f"Can't find this chunk {chunks[0].id}", message diff --git a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_delete_documents.py b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_delete_documents.py index 31627d6e88..ff2fe23129 100644 --- a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_delete_documents.py +++ b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_delete_documents.py @@ -16,7 +16,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import pytest -from common import bulk_upload_documents +from common import bulk_upload_documents, list_all_documents class TestDocumentsDeletion: @@ -114,7 +114,7 @@ def test_delete_1k(add_dataset, tmp_path): count = 1_000 dataset = add_dataset documents = bulk_upload_documents(dataset, count, tmp_path) - assert len(dataset.list_documents(page_size=count * 2)) == count + assert len(list_all_documents(dataset, limit=count + 1)) == count dataset.delete_documents(ids=[doc.id for doc in documents]) assert len(dataset.list_documents()) == 0 diff --git a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_parse_documents.py b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_parse_documents.py index 7003dab6f2..e39c33cace 100644 --- a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_parse_documents.py +++ b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_parse_documents.py @@ -15,7 +15,7 @@ # from concurrent.futures import ThreadPoolExecutor, as_completed import pytest -from common import bulk_upload_documents +from common import bulk_upload_documents, list_all_documents from ragflow_sdk import DataSet from ragflow_sdk.modules.document import Document from utils import wait_for @@ -40,14 +40,17 @@ def condition(_dataset: DataSet, _document_ids: list[str] = None): def validate_document_details(dataset, document_ids): - documents = dataset.list_documents(page_size=100) - for document in documents: - if document.id in document_ids: + target_ids = set(document_ids) + found_ids = set() + for document in list_all_documents(dataset): + if document.id in target_ids: + found_ids.add(document.id) assert document.run == "DONE" assert len(document.process_begin_at) > 0 assert document.process_duration > 0 assert document.progress > 0 assert "Task done" in document.progress_msg + assert found_ids == target_ids class TestDocumentsParse: @@ -228,7 +231,9 @@ def test_async_cancel_parse_documents_raises_on_nonzero_code(add_dataset_func, m def test_parse_100_files(add_dataset_func, tmp_path): @wait_for(200, 1, "Document parsing timeout") def condition_inner(_dataset: DataSet, _count: int): - docs = _dataset.list_documents(page_size=_count * 2) + docs = list_all_documents(_dataset, limit=_count) + if len(docs) < _count: + return False for document in docs: if document.run != "DONE": return False @@ -248,7 +253,9 @@ def test_parse_100_files(add_dataset_func, tmp_path): def test_concurrent_parse(add_dataset_func, tmp_path): @wait_for(200, 1, "Document parsing timeout") def condition_inner(_dataset: DataSet, _count: int): - docs = _dataset.list_documents(page_size=_count * 2) + docs = list_all_documents(_dataset, limit=_count) + if len(docs) < _count: + return False for document in docs: if document.run != "DONE": return False diff --git a/test/testcases/test_sdk_api/test_session_management/test_create_session_with_chat_assistant.py b/test/testcases/test_sdk_api/test_session_management/test_create_session_with_chat_assistant.py index 4969589119..54300ff012 100644 --- a/test/testcases/test_sdk_api/test_session_management/test_create_session_with_chat_assistant.py +++ b/test/testcases/test_sdk_api/test_session_management/test_create_session_with_chat_assistant.py @@ -16,6 +16,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import pytest +from common import list_all_sessions from configs import SESSION_WITH_CHAT_NAME_LIMIT from ragflow_sdk import RAGFlow from ragflow_sdk.modules.session import Session @@ -31,15 +32,6 @@ class _DummyStreamResponse: yield line -@pytest.fixture(scope="session") -def auth(): - return "unit-auth" - - -@pytest.fixture(scope="session", autouse=True) -def set_tenant_info(): - return None - @pytest.mark.usefixtures("clear_session_with_chat_assistants") class TestSessionWithChatAssistantCreate: @@ -84,7 +76,7 @@ class TestSessionWithChatAssistantCreate: responses = list(as_completed(futures)) assert len(responses) == count, responses - updated_sessions = chat_assistant.list_sessions(page_size=count * 2) + updated_sessions = list_all_sessions(chat_assistant, limit=count + 1) assert len(updated_sessions) == count @pytest.mark.p3