diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fc4233504b..2ff30e628f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -240,23 +240,14 @@ jobs: echo "Start to run test sdk on Infinity" source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} --junitxml=pytest-infinity-sdk.xml --cov=sdk/python/ragflow_sdk --cov-branch --cov-report=xml:coverage-infinity-sdk.xml test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log - - name: Run web api tests against Infinity + - name: Run New RESTFUL api tests against Infinity run: | export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null 2>&1; do echo "Waiting for service to be available... (last exit code: $?)" sleep 5 done - source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api/test_chunk_feedback 2>&1 | tee infinity_web_api_test.log - - - name: Run http api tests against Infinity - run: | - export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" - until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null 2>&1; do - echo "Waiting for service to be available... 
(last exit code: $?)" - sleep 5 - done - source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log + source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/restful_api 2>&1 | tee infinity_restful_api_test.log - name: RAGFlow CLI retrieval test Infinity env: @@ -432,24 +423,15 @@ jobs: done echo "Start to run test sdk on Elasticsearch" source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} --junitxml=pytest-infinity-sdk.xml --cov=sdk/python/ragflow_sdk --cov-branch --cov-report=xml:coverage-es-sdk.xml test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log - - - name: Run web api tests against Elasticsearch + + - name: Run New RESTFUL api tests against Elasticsearch run: | export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null 2>&1; do echo "Waiting for service to be available... (last exit code: $?)" sleep 5 done - source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api 2>&1 | tee es_web_api_test.log - - - name: Run http api tests against Elasticsearch - run: | - export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" - until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null 2>&1; do - echo "Waiting for service to be available... 
(last exit code: $?)" - sleep 5 - done - source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee es_http_api_test.log + source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/restful_api 2>&1 | tee es_restful_api_test.log - name: RAGFlow CLI retrieval test Elasticsearch env: diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index 45290c520d..7ffd9f13d4 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -318,7 +318,10 @@ class InfinityConnection(InfinityConnectionBase): "authors_sm_tks"]: fields.add(field) res_fields = self.get_fields(res, list(fields)) - return res_fields.get(chunk_id, None) + chunk = res_fields.get(chunk_id, None) + if chunk is not None: + chunk["id"] = chunk_id + return chunk def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]: ''' diff --git a/test/testcases/restful_api/conftest.py b/test/testcases/restful_api/conftest.py new file mode 100644 index 0000000000..b24f0bda45 --- /dev/null +++ b/test/testcases/restful_api/conftest.py @@ -0,0 +1,163 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
#

import pytest

from libs.auth import RAGFlowHttpApiAuth
from test.testcases.restful_api.helpers.client import RestClient
from utils.file_utils import create_txt_file
from utils import wait_for

# API codes tolerated from cleanup calls: 0 plus 102.  The fixtures accept 102
# because the resource may already have been removed by the test itself or by
# an earlier cleanup step.
_CLEANUP_CODES = (0, 102)


def _json_with_code(res, *codes):
    """Assert HTTP 200 and an accepted API ``code``; return the JSON payload.

    Args:
        res: response object exposing ``status_code``, ``text`` and ``json()``.
        *codes: API-level ``code`` values that count as success.

    Returns:
        The decoded JSON body.
    """
    assert res.status_code == 200, res.text
    payload = res.json()
    assert payload["code"] in codes, payload
    return payload


@pytest.fixture(scope="session")
def RestApiAuth(token):
    """Session-wide ``RAGFlowHttpApiAuth`` built from the ``token`` fixture."""
    return RAGFlowHttpApiAuth(token)


@pytest.fixture(scope="session")
def rest_client(token):
    """Session-wide authenticated ``RestClient``."""
    return RestClient(token=token)


@pytest.fixture(scope="session")
def rest_client_noauth():
    """Session-wide ``RestClient`` that sends no token."""
    return RestClient(token=None)


@pytest.fixture
def clear_datasets(rest_client):
    """Delete every dataset after the test has finished."""
    yield
    res = rest_client.delete("/datasets", json={"ids": None, "delete_all": True})
    _json_with_code(res, *_CLEANUP_CODES)


@pytest.fixture
def clear_chats(rest_client):
    """Delete every chat after the test has finished."""
    yield
    res = rest_client.delete("/chats", json={"ids": None, "delete_all": True})
    _json_with_code(res, *_CLEANUP_CODES)


@pytest.fixture
def create_dataset(rest_client, clear_datasets):
    """Yield a factory ``(name) -> dataset_id``; created datasets are deleted on teardown."""
    created_ids: list[str] = []

    def _create(name: str = "restful_dataset") -> str:
        payload = _json_with_code(rest_client.post("/datasets", json={"name": name}), 0)
        dataset_id = payload["data"]["id"]
        created_ids.append(dataset_id)
        return dataset_id

    yield _create

    if created_ids:
        # Dataset may already be removed by test logic/cleanup.
        res = rest_client.delete("/datasets", json={"ids": created_ids})
        _json_with_code(res, *_CLEANUP_CODES)


@pytest.fixture
def create_chat(rest_client, clear_chats):
    """Yield a factory ``(name) -> chat_id``; created chats are deleted on teardown."""
    created_ids: list[str] = []

    def _create(name: str = "restful_chat") -> str:
        res = rest_client.post("/chats", json={"name": name, "dataset_ids": []})
        chat_id = _json_with_code(res, 0)["data"]["id"]
        created_ids.append(chat_id)
        return chat_id

    yield _create

    if created_ids:
        res = rest_client.delete("/chats", json={"ids": created_ids})
        _json_with_code(res, *_CLEANUP_CODES)


@pytest.fixture
def create_document(rest_client, create_dataset, tmp_path):
    """Yield a factory ``(name) -> (dataset_id, document_id)``.

    Every call creates a fresh dataset, uploads one generated text file into
    it, and records the pair so the document is deleted on teardown.
    """
    created_docs: list[tuple[str, str]] = []

    def _create(name: str = "restful_doc.txt") -> tuple[str, str]:
        dataset_id = create_dataset("dataset_for_doc")
        fp = create_txt_file(tmp_path / name)
        # Keep the file open for the duration of the multipart upload.
        with fp.open("rb") as file_obj:
            files = [("file", (fp.name, file_obj))]
            res = rest_client.post(f"/datasets/{dataset_id}/documents", files=files)
        payload = _json_with_code(res, 0)
        document_id = payload["data"][0]["id"]
        created_docs.append((dataset_id, document_id))
        return dataset_id, document_id

    yield _create

    for dataset_id, document_id in created_docs:
        res = rest_client.delete(f"/datasets/{dataset_id}/documents", json={"ids": [document_id]})
        _json_with_code(res, *_CLEANUP_CODES)


@wait_for(60, 1, "Document parsing timeout in RESTful batch2 tests")
def _parsed(rest_client: RestClient, dataset_id: str, document_id: str):
    """Return ``True`` once the document's ``run`` field reads ``"DONE"``.

    Any non-200 response, non-zero API code or empty ``docs`` list yields
    ``False``.  NOTE(review): ``wait_for`` is defined elsewhere; presumably it
    re-invokes this predicate (60 / 1 look like timeout / interval) and fails
    with the given message -- confirm against ``utils.wait_for``.
    """
    res = rest_client.get(f"/datasets/{dataset_id}/documents", params={"id": document_id})
    if res.status_code != 200:
        return False
    payload = res.json()
    if payload["code"] != 0:
        return False
    docs = payload["data"]["docs"]
    if not docs:
        return False
    return docs[0].get("run") == "DONE"


@pytest.fixture
def ensure_parsed_document(rest_client, create_document):
    """Return a factory that creates a document, triggers parsing and waits for it."""

    def _ensure() -> tuple[str, str]:
        dataset_id, document_id = create_document()
        res = rest_client.post(
            f"/datasets/{dataset_id}/documents/parse",
            json={"document_ids": [document_id]},
        )
        _json_with_code(res, 0)
        _parsed(rest_client, dataset_id, document_id)
        return dataset_id, document_id

    return _ensure


# ---- test/testcases/restful_api/helpers/__init__.py (new file in this diff) ----
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.

# ---- test/testcases/restful_api/helpers/client.py (new file in this diff) ----
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+# + +from dataclasses import dataclass +from typing import Any + +import requests +from configs import HOST_ADDRESS, VERSION + + +@dataclass +class RestClient: + token: str | None = None + timeout: int = 30 + + @property + def api_root(self) -> str: + return f"{HOST_ADDRESS}/api/{VERSION}" + + def _headers(self, headers: dict[str, str] | None = None) -> dict[str, str]: + merged: dict[str, str] = {"Content-Type": "application/json"} + if headers: + merged.update(headers) + if self.token and "Authorization" not in merged: + merged["Authorization"] = f"Bearer {self.token}" + return merged + + def request( + self, + method: str, + path: str, + *, + headers: dict[str, str] | None = None, + params: dict[str, Any] | None = None, + json: dict[str, Any] | None = None, + data: Any = None, + files: Any = None, + **request_kwargs: Any, + ) -> requests.Response: + req_headers = self._headers(headers) + if files is not None: + # requests sets multipart boundary automatically. + req_headers.pop("Content-Type", None) + + timeout = request_kwargs.pop("timeout", self.timeout) + normalized_path = f"/{path.lstrip('/')}" if path else "/" + return requests.request( + method=method, + url=f"{self.api_root}{normalized_path}", + headers=req_headers, + params=params, + json=json, + data=data, + files=files, + timeout=timeout, + **request_kwargs, + ) + + def get(self, path: str, **kwargs) -> requests.Response: + return self.request("GET", path, **kwargs) + + def post(self, path: str, **kwargs) -> requests.Response: + return self.request("POST", path, **kwargs) + + def delete(self, path: str, **kwargs) -> requests.Response: + return self.request("DELETE", path, **kwargs) + + def put(self, path: str, **kwargs) -> requests.Response: + return self.request("PUT", path, **kwargs) + + def patch(self, path: str, **kwargs) -> requests.Response: + return self.request("PATCH", path, **kwargs) diff --git a/test/testcases/restful_api/test_agents.py b/test/testcases/restful_api/test_agents.py new file 
# ---- test/testcases/restful_api/test_agents.py (new file in this diff) ----
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json

import pytest


# Smallest agent graph used by these tests: a Begin component feeding a
# Message component that emits ``{sys.query}``.
MINIMAL_DSL = {
    "components": {
        "begin": {
            "obj": {"component_name": "Begin", "params": {}},
            "downstream": ["message"],
            "upstream": [],
        },
        "message": {
            "obj": {"component_name": "Message", "params": {"content": ["{sys.query}"]}},
            "downstream": [],
            "upstream": ["begin"],
        },
    },
    "history": [],
    "retrieval": [],
    "path": [],
    "globals": {
        "sys.query": "",
        "sys.user_id": "",
        "sys.conversation_turns": 0,
        "sys.files": [],
    },
    "variables": {},
}


def _sse_events(response_text: str) -> list[str]:
    """Return the payload of every ``data:`` line, prefix removed.

    Only the five-character ``data:`` prefix is cut; whitespace after it is
    kept, so callers strip where it matters.
    """
    return [line[5:] for line in response_text.splitlines() if line.startswith("data:")]


def _expect_code(res, code: int = 0):
    """Assert HTTP 200 and the given API ``code``; return the JSON payload."""
    assert res.status_code == 200, res.text
    payload = res.json()
    assert payload["code"] == code, payload
    return payload


def _create_session(rest_client, agent_id: str, name: str) -> str:
    """Create a session for ``agent_id`` and return its id."""
    res = rest_client.post(f"/agents/{agent_id}/sessions", json={"name": name})
    return _expect_code(res)["data"]["id"]


@pytest.fixture
def create_agent_resource(rest_client):
    """Yield a factory ``(title) -> agent_id``; created agents are deleted on teardown.

    Teardown attempts every deletion before failing, so one failed delete
    does not leave the remaining agents behind.  API codes 0 and 103 are both
    accepted from the delete call.
    """
    created_agent_ids: list[str] = []

    def _create(title: str = "restful_agent") -> str:
        res = rest_client.post("/agents", json={"title": title, "dsl": MINIMAL_DSL})
        agent_id = _expect_code(res)["data"]["id"]
        created_agent_ids.append(agent_id)
        return agent_id

    yield _create

    cleanup_errors = []
    for agent_id in created_agent_ids:
        res = rest_client.delete(f"/agents/{agent_id}")
        if res.status_code != 200:
            cleanup_errors.append((agent_id, res.status_code, res.text))
            continue
        payload = res.json()
        if payload["code"] not in (0, 103):
            cleanup_errors.append((agent_id, res.status_code, payload))
    assert not cleanup_errors, f"Agent cleanup failed: {cleanup_errors}"


@pytest.mark.p2
def test_agents_crud_validation_contract(rest_client, create_agent_resource):
    """List/create/get/update/delete an agent, including the validation errors."""
    listing = _expect_code(rest_client.get("/agents", params={"title": "missing_restful_agent"}))
    assert "canvas" in listing["data"], listing
    assert "total" in listing["data"], listing

    missing_dsl = _expect_code(rest_client.post("/agents", json={"title": "missing_dsl_agent"}), 101)
    assert "No DSL data in request" in missing_dsl["message"], missing_dsl

    missing_title = _expect_code(rest_client.post("/agents", json={"dsl": MINIMAL_DSL}), 101)
    assert "No title in request" in missing_title["message"], missing_title

    agent_id = create_agent_resource("restful_agent_crud")

    duplicate = _expect_code(
        rest_client.post("/agents", json={"title": "restful_agent_crud", "dsl": MINIMAL_DSL}),
        102,
    )
    assert "already exists" in duplicate["message"], duplicate

    fetched = _expect_code(rest_client.get(f"/agents/{agent_id}"))
    assert fetched["data"]["id"] == agent_id, fetched

    _expect_code(
        rest_client.put(
            f"/agents/{agent_id}",
            json={"title": "restful_agent_crud_updated", "dsl": MINIMAL_DSL},
        )
    )

    after_update = _expect_code(rest_client.get("/agents", params={"title": "restful_agent_crud_updated"}))
    assert after_update["data"]["total"] >= 1, after_update

    deleted = _expect_code(rest_client.delete(f"/agents/{agent_id}"))
    assert deleted["data"] is True, deleted


@pytest.mark.p2
def test_agent_sessions_crud(rest_client, create_agent_resource):
    """Create, list, fetch and delete one agent session."""
    agent_id = create_agent_resource("restful_agent_sessions")
    session_id = _create_session(rest_client, agent_id, "agent_session_1")

    sessions = _expect_code(rest_client.get(f"/agents/{agent_id}/sessions"))
    assert isinstance(sessions["data"], list), sessions
    assert any(item["id"] == session_id for item in sessions["data"]), sessions

    session = _expect_code(rest_client.get(f"/agents/{agent_id}/sessions/{session_id}"))
    assert session["data"]["id"] == session_id, session

    _expect_code(rest_client.delete(f"/agents/{agent_id}/sessions/{session_id}"))


@pytest.mark.p2
def test_agent_chat_completion_validation(rest_client):
    """A completion request without ``agent_id`` is rejected with code 101."""
    res = rest_client.post(
        "/agents/chat/completions",
        json={"query": "hello", "stream": False},
    )
    payload = _expect_code(res, 101)
    assert "`agent_id` is required." in payload["message"], payload


@pytest.mark.p2
def test_agent_chat_completion_nonstream(rest_client, create_agent_resource):
    """A non-streaming completion returns a nested ``data`` dict with ``content``."""
    agent_id = create_agent_resource("restful_agent_nonstream")
    session_id = _create_session(rest_client, agent_id, "agent_completion_session")

    res = rest_client.post(
        "/agents/chat/completions",
        json={"agent_id": agent_id, "query": "hello", "stream": False, "session_id": session_id},
        timeout=60,
    )
    payload = _expect_code(res)
    assert isinstance(payload["data"], dict), payload
    assert isinstance(payload["data"].get("data"), dict), payload
    assert "content" in payload["data"]["data"], payload


@pytest.mark.p2
def test_agent_chat_completion_stream_structure_and_done(rest_client, create_agent_resource):
    """A streaming completion is SSE, ends with ``[DONE]`` and carries JSON events."""
    agent_id = create_agent_resource("restful_agent_stream")
    session_id = _create_session(rest_client, agent_id, "agent_stream_session")

    res = rest_client.post(
        "/agents/chat/completions",
        json={
            "agent_id": agent_id,
            "query": "hello",
            "stream": True,
            "session_id": session_id,
            "return_trace": True,
        },
        timeout=60,
    )
    assert res.status_code == 200
    content_type = res.headers.get("Content-Type", "")
    assert "text/event-stream" in content_type, content_type

    events = _sse_events(res.text)
    assert events, res.text
    assert events[-1].strip() == "[DONE]", events[-1]

    json_events = [json.loads(evt) for evt in events if evt.strip() != "[DONE]"]
    assert json_events, events
    assert any(isinstance(evt, dict) for evt in json_events), json_events


@pytest.mark.p2
def test_agent_openai_compatible_mode(rest_client, create_agent_resource):
    """OpenAI-compatible mode: empty messages rejected, both stream modes answer."""
    agent_id = create_agent_resource("restful_agent_openai_compat")

    missing_messages = _expect_code(
        rest_client.post(
            "/agents/chat/completions",
            json={"agent_id": agent_id, "openai-compatible": True, "model": "model", "messages": []},
        ),
        102,
    )
    assert "at least one message" in missing_messages["message"], missing_messages

    nonstream = rest_client.post(
        "/agents/chat/completions",
        json={
            "agent_id": agent_id,
            "openai-compatible": True,
            "model": "model",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": False,
        },
        timeout=60,
    )
    assert nonstream.status_code == 200
    # This body is only checked for ``choices``; no API ``code`` is asserted.
    nonstream_payload = nonstream.json()
    assert isinstance(nonstream_payload, dict), nonstream_payload
    assert "choices" in nonstream_payload, nonstream_payload

    stream = rest_client.post(
        "/agents/chat/completions",
        json={
            "agent_id": agent_id,
            "openai-compatible": True,
            "model": "model",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": True,
        },
        timeout=60,
    )
    assert stream.status_code == 200
    stream_content_type = stream.headers.get("Content-Type", "")
    assert "text/event-stream" in stream_content_type, stream_content_type


@pytest.mark.p2
def test_agent_support_routes_auth_and_contracts(rest_client, rest_client_noauth, create_agent_resource):
    """Prompts, templates, versions and logs routes: auth and response shapes."""
    prompts_unauth = rest_client_noauth.get("/agents/prompts")
    assert prompts_unauth.status_code == 401
    assert prompts_unauth.json()["code"] == 401

    prompts = _expect_code(rest_client.get("/agents/prompts"))
    assert "task_analysis" in prompts["data"], prompts
    assert "citation_guidelines" in prompts["data"], prompts

    templates = _expect_code(rest_client.get("/agents/templates"))
    assert isinstance(templates["data"], list), templates

    agent_id = create_agent_resource("restful_agent_support")
    versions = _expect_code(rest_client.get(f"/agents/{agent_id}/versions"))
    assert isinstance(versions["data"], list), versions

    logs = _expect_code(rest_client.get(f"/agents/{agent_id}/logs/missing_message"))
    assert isinstance(logs["data"], dict), logs


@pytest.mark.p2
def test_agent_webhook_logs_empty_poll_contract(rest_client, create_agent_resource):
    """Polling webhook logs of a fresh agent yields no events and is not finished."""
    agent_id = create_agent_resource("restful_agent_webhook_logs")
    payload = _expect_code(rest_client.get(f"/agents/{agent_id}/webhook/logs", params={"since_ts": 0}))
    assert payload["data"]["events"] == [], payload
    assert payload["data"]["finished"] is False, payload
    assert "next_since_ts" in payload["data"], payload


@pytest.mark.p2
def test_agent_db_connection_validates_required_fields(rest_client):
    """``test_db_connection`` with only ``db_type`` reports missing arguments."""
    payload = _expect_code(rest_client.post("/agents/test_db_connection", json={"db_type": "mysql"}), 101)
    assert "required argument are missing" in payload["message"], payload


@pytest.mark.p2
def test_agent_rerun_requires_required_fields(rest_client):
    """``rerun`` with only ``id`` reports missing arguments."""
    payload = _expect_code(rest_client.post("/agents/rerun", json={"id": "flow-1"}), 101)
    assert "required argument are missing" in payload["message"], payload


# ---- test/testcases/restful_api/test_chats.py (new file in this diff) ----
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pytest


def _expect_code(res, code: int = 0):
    """Assert HTTP 200 and the given API ``code``; return the JSON payload."""
    assert res.status_code == 200, res.text
    payload = res.json()
    assert payload["code"] == code, payload
    return payload


@pytest.mark.p1
class TestChatsAuthorization:
    """Requests without a token must be rejected."""

    def test_create_requires_auth(self, rest_client_noauth):
        res = rest_client_noauth.post("/chats", json={"name": "chat_auth", "dataset_ids": []})
        assert res.status_code == 401


@pytest.mark.p1
def test_chat_crud_cycle(rest_client, clear_chats):
    """Create, list, get, put, patch and delete one chat, then confirm it is gone."""
    created = _expect_code(
        rest_client.post(
            "/chats",
            json={"name": "restful_chat_crud", "dataset_ids": []},
        )
    )
    chat_id = created["data"]["id"]

    listed = _expect_code(rest_client.get("/chats", params={"id": chat_id}))
    chats = listed["data"]["chats"]
    assert len(chats) == 1, listed
    assert chats[0]["id"] == chat_id, listed

    fetched = _expect_code(rest_client.get(f"/chats/{chat_id}"))
    assert fetched["data"]["id"] == chat_id, fetched

    updated = _expect_code(
        rest_client.put(
            f"/chats/{chat_id}",
            json={"name": "restful_chat_crud_updated", "dataset_ids": []},
        )
    )
    assert updated["data"]["name"] == "restful_chat_crud_updated", updated

    patched = _expect_code(rest_client.patch(f"/chats/{chat_id}", json={"name": "restful_chat_crud_patched"}))
    assert patched["data"]["name"] == "restful_chat_crud_patched", patched

    deleted = _expect_code(rest_client.delete("/chats", json={"ids": [chat_id]}))
    assert deleted["data"]["success_count"] == 1, deleted

    after_delete = _expect_code(rest_client.get("/chats", params={"id": chat_id}))
    assert after_delete["data"]["chats"] == [], after_delete


@pytest.mark.p2
@pytest.mark.parametrize(
    "name, expected_fragment",
    [
        ("", "`name` is required."),
        (" ", "`name` is required."),
    ],
)
def test_chat_create_name_validation(rest_client, clear_chats, name, expected_fragment):
    """Empty and blank chat names are rejected with code 102."""
    payload = _expect_code(rest_client.post("/chats", json={"name": name, "dataset_ids": []}), 102)
    assert expected_fragment in payload["message"], payload


@pytest.mark.p2
def test_chat_duplicate_name_validation(rest_client, clear_chats):
    """Creating a second chat with the same name is rejected with code 102."""
    _expect_code(rest_client.post("/chats", json={"name": "duplicate_chat_name", "dataset_ids": []}))

    second = _expect_code(
        rest_client.post("/chats", json={"name": "duplicate_chat_name", "dataset_ids": []}),
        102,
    )
    assert "Duplicated chat name" in second["message"], second


@pytest.mark.p2
def test_chat_list_pagination(rest_client, clear_chats):
    """With three chats present, ``page_size=2`` returns two and ``total`` is at least 3."""
    for i in range(3):
        _expect_code(rest_client.post("/chats", json={"name": f"chat_page_{i}", "dataset_ids": []}))

    page = _expect_code(
        rest_client.get(
            "/chats",
            params={"page": 1, "page_size": 2, "orderby": "create_time", "desc": "true"},
        )
    )
    assert len(page["data"]["chats"]) == 2, page
    assert page["data"]["total"] >= 3, page


# ---- test/testcases/restful_api/test_chunks.py (new file in this diff) ----
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


def _assert_created_chunk_id(payload):
    """Return the created chunk's id after checking it is a non-blank string."""
    chunk_id = payload["data"]["chunk"].get("id")
    assert chunk_id, payload
    assert isinstance(chunk_id, str), payload
    assert chunk_id.strip(), payload
    return chunk_id


@pytest.mark.p1
def test_chunks_add_list_get_update_delete_cycle(rest_client, create_document):
    """Add, list, get and update one chunk; add and delete a second one."""
    dataset_id, document_id = create_document("chunk_cycle.txt")
    base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"

    added = _expect_code(
        rest_client.post(
            base_path,
            json={"content": "batch2 chunk content", "important_keywords": ["batch2"], "questions": ["what is batch2?"]},
        )
    )
    chunk_id = _assert_created_chunk_id(added)

    listed = _expect_code(rest_client.get(base_path, params={"id": chunk_id}))
    assert listed["data"]["total"] == 1, listed
    assert listed["data"]["chunks"][0]["id"] == chunk_id, listed

    fetched = _expect_code(rest_client.get(f"{base_path}/{chunk_id}"))
    assert fetched["data"]["id"] == chunk_id, fetched

    _expect_code(
        rest_client.patch(
            f"{base_path}/{chunk_id}",
            json={"content": "batch2 chunk content updated"},
        )
    )

    refetched = _expect_code(rest_client.get(f"{base_path}/{chunk_id}"))
    assert refetched["data"]["content_with_weight"] == "batch2 chunk content updated", refetched

    delete_candidate = _expect_code(rest_client.post(base_path, json={"content": "batch2 chunk content to delete"}))
    delete_candidate_id = _assert_created_chunk_id(delete_candidate)

    _expect_code(rest_client.delete(base_path, json={"chunk_ids": [delete_candidate_id]}))

    # After deletion both lookups still answer HTTP 200 but with a non-zero
    # API code; the exact code is deliberately not pinned.
    deleted_list_res = rest_client.get(base_path, params={"id": delete_candidate_id})
    assert deleted_list_res.status_code == 200
    deleted_list_payload = deleted_list_res.json()
    assert deleted_list_payload["code"] != 0, deleted_list_payload

    deleted_get_res = rest_client.get(f"{base_path}/{delete_candidate_id}")
    assert deleted_get_res.status_code == 200
    deleted_get_payload = deleted_get_res.json()
    assert deleted_get_payload["code"] != 0, deleted_get_payload


@pytest.mark.p2
def test_chunks_add_requires_content(rest_client, create_document):
    """Blank chunk content is rejected with code 102."""
    dataset_id, document_id = create_document("chunk_requires_content.txt")
    res = rest_client.post(
        f"/datasets/{dataset_id}/documents/{document_id}/chunks",
        json={"content": " "},
    )
    payload = _expect_code(res, 102)
    assert payload["message"] == "`content` is required", payload


@pytest.mark.p2
def test_chunks_list_empty_document(rest_client, create_document):
    """Listing chunks of a fresh document returns the ``chunks`` and ``doc`` keys."""
    dataset_id, document_id = create_document("chunk_list_empty.txt")
    base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
    listed = _expect_code(rest_client.get(base_path))
    assert "chunks" in listed["data"], listed
    assert "doc" in listed["data"], listed


@pytest.mark.p2
def test_chunks_delete_partial_invalid(rest_client, create_document):
    """Deleting an unknown chunk id is rejected with code 102."""
    dataset_id, document_id = create_document("chunk_delete_partial.txt")
    base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
    deleted = _expect_code(rest_client.delete(base_path, json={"chunk_ids": ["invalid_chunk_id"]}), 102)
    assert "expect 1" in deleted["message"], deleted


# ---- test/testcases/restful_api/test_connector_routes_unit.py (new file in this diff) ----
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import importlib.util +import json +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest + + +class _DummyManager: + def route(self, *_args, **_kwargs): + def decorator(func): + return func + + return decorator + + +class _AwaitableValue: + def __init__(self, value): + self._value = value + + def __await__(self): + async def _co(): + return self._value + + return _co().__await__() + + +class _Args(dict): + def get(self, key, default=None, type=None): + value = super().get(key, default) + if type is None: + return value + try: + return type(value) + except (TypeError, ValueError): + return default + + def to_dict(self, flat=True): + return dict(self) + + +class _FakeResponse: + def __init__(self, body, status_code): + self.body = body + self.status_code = status_code + self.headers = {} + + +class _FakeConnectorRecord: + def __init__(self, payload): + self._payload = payload + + def to_dict(self): + return dict(self._payload) + + +class _FakeCredentials: + def __init__(self, raw='{"refresh_token":"rt","access_token":"at"}'): + self._raw = raw + + def to_json(self): + return self._raw + + +class _FakeFlow: + def __init__(self, client_config, scopes): + self.client_config = client_config + self.scopes = scopes + self.redirect_uri = None + self.credentials = _FakeCredentials() + self.auth_kwargs = None + self.token_code = None + self.token_code_verifier = None + self.code_verifier = "fake-code-verifier" + + def authorization_url(self, **kwargs): + self.auth_kwargs = dict(kwargs) + return 
f"https://oauth.example/{kwargs['state']}", kwargs["state"] + + def fetch_token(self, code, code_verifier=None): + self.token_code = code + self.token_code_verifier = code_verifier + + +class _FakeBoxToken: + def __init__(self, access_token, refresh_token): + self.access_token = access_token + self.refresh_token = refresh_token + + +class _FakeBoxOAuth: + def __init__(self, config): + self.config = config + self.exchange_code = None + + def get_authorize_url(self, options): + return f"https://box.example/auth?state={options.state}&redirect={options.redirect_uri}" + + def get_tokens_authorization_code_grant(self, code): + self.exchange_code = code + + def retrieve_token(self): + return _FakeBoxToken("box-access", "box-refresh") + + +class _FakeRedis: + def __init__(self): + self.store = {} + self.set_calls = [] + self.deleted = [] + + def get(self, key): + return self.store.get(key) + + def set_obj(self, key, obj, ttl): + self.set_calls.append((key, obj, ttl)) + self.store[key] = json.dumps(obj) + + def delete(self, key): + self.deleted.append(key) + self.store.pop(key, None) + + +def _run(coro): + return asyncio.run(coro) + + +def _set_request(module, *, args=None, json_body=None): + module.request = SimpleNamespace( + args=_Args(args or {}), + json=_AwaitableValue({} if json_body is None else json_body), + ) + + +@pytest.fixture(scope="session") +def auth(): + return "unit-auth" + + +@pytest.fixture(scope="session", autouse=True) +def set_tenant_info(): + return None + + +def _load_connector_app(monkeypatch): + repo_root = Path(__file__).resolve().parents[3] + + api_pkg = ModuleType("api") + api_pkg.__path__ = [str(repo_root / "api")] + monkeypatch.setitem(sys.modules, "api", api_pkg) + + apps_mod = ModuleType("api.apps") + apps_mod.__path__ = [str(repo_root / "api" / "apps")] + apps_mod.current_user = SimpleNamespace(id="tenant-1") + apps_mod.login_required = lambda fn: fn + monkeypatch.setitem(sys.modules, "api.apps", apps_mod) + + db_mod = ModuleType("api.db") 
+ db_mod.InputType = SimpleNamespace(POLL="POLL") + monkeypatch.setitem(sys.modules, "api.db", db_mod) + + services_pkg = ModuleType("api.db.services") + services_pkg.__path__ = [] + monkeypatch.setitem(sys.modules, "api.db.services", services_pkg) + + connector_service_mod = ModuleType("api.db.services.connector_service") + + class _StubConnectorService: + @staticmethod + def update_by_id(*_args, **_kwargs): + return True + + @staticmethod + def save(**_kwargs): + return True + + @staticmethod + def get_by_id(_connector_id): + return True, _FakeConnectorRecord({"id": _connector_id}) + + @staticmethod + def list(_tenant_id): + return [] + + @staticmethod + def resume(*_args, **_kwargs): + return True + + @staticmethod + def rebuild(*_args, **_kwargs): + return None + + @staticmethod + def delete_by_id(*_args, **_kwargs): + return True + + class _StubSyncLogsService: + @staticmethod + def list_sync_tasks(*_args, **_kwargs): + return [], 0 + + connector_service_mod.ConnectorService = _StubConnectorService + connector_service_mod.SyncLogsService = _StubSyncLogsService + monkeypatch.setitem(sys.modules, "api.db.services.connector_service", connector_service_mod) + + api_utils_mod = ModuleType("api.utils.api_utils") + + async def _get_request_json(): + return {} + + api_utils_mod.get_request_json = _get_request_json + api_utils_mod.get_json_result = lambda data=None, message="", code=0: { + "code": code, + "message": message, + "data": data, + } + api_utils_mod.get_data_error_result = lambda message="", code=400, data=None: { + "code": code, + "message": message, + "data": data, + } + api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn) + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + constants_mod = ModuleType("common.constants") + constants_mod.RetCode = SimpleNamespace( + ARGUMENT_ERROR=101, + SERVER_ERROR=500, + RUNNING=102, + PERMISSION_ERROR=403, + ) + constants_mod.TaskStatus = SimpleNamespace(SCHEDULE="schedule", 
CANCEL="cancel") + monkeypatch.setitem(sys.modules, "common.constants", constants_mod) + + config_mod = ModuleType("common.data_source.config") + config_mod.GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = "https://example.com/drive" + config_mod.GMAIL_WEB_OAUTH_REDIRECT_URI = "https://example.com/gmail" + config_mod.BOX_WEB_OAUTH_REDIRECT_URI = "https://example.com/box" + config_mod.DocumentSource = SimpleNamespace(GMAIL="gmail", GOOGLE_DRIVE="google-drive") + monkeypatch.setitem(sys.modules, "common.data_source.config", config_mod) + + google_constants_mod = ModuleType("common.data_source.google_util.constant") + google_constants_mod.WEB_OAUTH_POPUP_TEMPLATE = ( + "{title}" + "

{heading}

{message}

" + ) + google_constants_mod.GOOGLE_SCOPES = { + config_mod.DocumentSource.GMAIL: ["scope-gmail"], + config_mod.DocumentSource.GOOGLE_DRIVE: ["scope-drive"], + } + monkeypatch.setitem(sys.modules, "common.data_source.google_util.constant", google_constants_mod) + + misc_mod = ModuleType("common.misc_utils") + misc_mod.get_uuid = lambda: "uuid-from-helper" + monkeypatch.setitem(sys.modules, "common.misc_utils", misc_mod) + + rag_pkg = ModuleType("rag") + rag_pkg.__path__ = [str(repo_root / "rag")] + monkeypatch.setitem(sys.modules, "rag", rag_pkg) + + rag_utils_pkg = ModuleType("rag.utils") + rag_utils_pkg.__path__ = [str(repo_root / "rag" / "utils")] + monkeypatch.setitem(sys.modules, "rag.utils", rag_utils_pkg) + + redis_mod = ModuleType("rag.utils.redis_conn") + redis_mod.REDIS_CONN = _FakeRedis() + monkeypatch.setitem(sys.modules, "rag.utils.redis_conn", redis_mod) + + quart_mod = ModuleType("quart") + quart_mod.request = SimpleNamespace(args=_Args(), json=_AwaitableValue({})) + + async def _make_response(body, status_code): + return _FakeResponse(body, status_code) + + quart_mod.make_response = _make_response + monkeypatch.setitem(sys.modules, "quart", quart_mod) + + google_pkg = ModuleType("google_auth_oauthlib") + google_pkg.__path__ = [] + monkeypatch.setitem(sys.modules, "google_auth_oauthlib", google_pkg) + + google_flow_mod = ModuleType("google_auth_oauthlib.flow") + + class _StubFlow: + @classmethod + def from_client_config(cls, client_config, scopes): + return _FakeFlow(client_config, scopes) + + google_flow_mod.Flow = _StubFlow + monkeypatch.setitem(sys.modules, "google_auth_oauthlib.flow", google_flow_mod) + + box_mod = ModuleType("box_sdk_gen") + + class _OAuthConfig: + def __init__(self, client_id, client_secret): + self.client_id = client_id + self.client_secret = client_secret + + class _GetAuthorizeUrlOptions: + def __init__(self, redirect_uri, state): + self.redirect_uri = redirect_uri + self.state = state + + box_mod.BoxOAuth = _FakeBoxOAuth + 
box_mod.OAuthConfig = _OAuthConfig + box_mod.GetAuthorizeUrlOptions = _GetAuthorizeUrlOptions + monkeypatch.setitem(sys.modules, "box_sdk_gen", box_mod) + + module_path = repo_root / "api" / "apps" / "restful_apis" / "connector_api.py" + spec = importlib.util.spec_from_file_location("test_connector_routes_unit", module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + spec.loader.exec_module(module) + return module + + +@pytest.mark.p2 +def test_connector_basic_routes_and_task_controls(monkeypatch): + module = _load_connector_app(monkeypatch) + + async def _no_sleep(_secs): + return None + + monkeypatch.setattr(module.asyncio, "sleep", _no_sleep) + + records = {"conn-1": _FakeConnectorRecord({"id": "conn-1", "source": "drive"})} + update_calls = [] + save_calls = [] + resume_calls = [] + delete_calls = [] + + monkeypatch.setattr(module.ConnectorService, "update_by_id", lambda cid, payload: update_calls.append((cid, payload))) + + def _save(**payload): + save_calls.append(payload) + records[payload["id"]] = _FakeConnectorRecord(payload) + + monkeypatch.setattr(module.ConnectorService, "save", _save) + monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda cid: (True, records[cid])) + monkeypatch.setattr(module.ConnectorService, "list", lambda tenant_id: [{"id": "listed", "tenant": tenant_id}]) + monkeypatch.setattr(module.SyncLogsService, "list_sync_tasks", lambda cid, page, page_size: ([{"id": "log-1"}], 9)) + monkeypatch.setattr(module.ConnectorService, "resume", lambda cid, status: resume_calls.append((cid, status))) + monkeypatch.setattr(module.ConnectorService, "delete_by_id", lambda cid: delete_calls.append(cid)) + monkeypatch.setattr(module, "get_uuid", lambda: "generated-id") + + monkeypatch.setattr( + module, + "get_request_json", + lambda: _AwaitableValue({"id": "conn-1", "refresh_freq": 7, "config": {"x": 1}}), + ) + res = _run(module.update_connector("conn-1")) + assert update_calls == [("conn-1", 
{'id': 'conn-1', "refresh_freq": 7, "config": {"x": 1}})] + assert res["data"]["id"] == "conn-1" + + monkeypatch.setattr( + module, + "get_request_json", + lambda: _AwaitableValue({"name": "new", "source": "gmail", "config": {"y": 2}}), + ) + res = _run(module.create_connector()) + assert save_calls[-1]["id"] == "generated-id" + assert save_calls[-1]["tenant_id"] == "tenant-1" + assert save_calls[-1]["input_type"] == module.InputType.POLL + assert res["data"]["id"] == "generated-id" + + list_res = module.list_connector() + assert list_res["data"] == [{"id": "listed", "tenant": "tenant-1"}] + + monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda _cid: (False, None)) + missing_res = module.get_connector("missing") + assert missing_res["message"] == "Can't find this Connector!" + + monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda cid: (True, _FakeConnectorRecord({"id": cid}))) + found_res = module.get_connector("conn-2") + assert found_res["data"]["id"] == "conn-2" + + _set_request(module, args={"page": "2", "page_size": "7"}) + logs_res = module.list_logs("conn-log") + assert logs_res["data"] == {"total": 9, "logs": [{"id": "log-1"}]} + + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": True})) + assert _run(module.resume("conn-r1"))["data"] is True + + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": False})) + assert _run(module.resume("conn-r2"))["data"] is True + assert ("conn-r1", module.TaskStatus.SCHEDULE) in resume_calls + assert ("conn-r2", module.TaskStatus.CANCEL) in resume_calls + + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"kb_id": "kb-1"})) + monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: "rebuild-failed") + failed_rebuild = _run(module.rebuild("conn-rb")) + assert failed_rebuild["code"] == module.RetCode.SERVER_ERROR + assert failed_rebuild["data"] is False + + 
monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: None) + ok_rebuild = _run(module.rebuild("conn-rb")) + assert ok_rebuild["data"] is True + + rm_res = module.rm_connector("conn-rm") + assert rm_res["data"] is True + assert ("conn-rm", module.TaskStatus.CANCEL) in resume_calls + assert delete_calls == ["conn-rm"] + + +@pytest.mark.p2 +def test_connector_oauth_helper_functions(monkeypatch): + module = _load_connector_app(monkeypatch) + + assert module._web_state_cache_key("flow-a", "gmail") == "gmail_web_flow_state:flow-a" + assert module._web_result_cache_key("flow-b", "google-drive") == "google-drive_web_flow_result:flow-b" + + creds_dict = {"web": {"client_id": "id"}} + assert module._load_credentials(creds_dict) == creds_dict + assert module._load_credentials(json.dumps(creds_dict)) == creds_dict + + with pytest.raises(ValueError, match="Invalid Google credentials JSON"): + module._load_credentials("{not-json") + + assert module._get_web_client_config(creds_dict) == {"web": {"client_id": "id"}} + with pytest.raises(ValueError, match="must include a 'web'"): + module._get_web_client_config({"installed": {"client_id": "id"}}) + + popup_ok = _run(module._render_web_oauth_popup("flow-1", True, "done", "gmail")) + assert popup_ok.status_code == 200 + assert popup_ok.headers["Content-Type"] == "text/html; charset=utf-8" + assert "Authorization complete" in popup_ok.body + assert "ragflow-gmail-oauth" in popup_ok.body + + popup_error = _run(module._render_web_oauth_popup("flow-2", False, "", "google-drive")) + assert popup_error.status_code == 200 + assert "Authorization failed" in popup_error.body + assert "<denied>" in popup_error.body + + +@pytest.mark.p2 +def test_start_google_web_oauth_matrix(monkeypatch): + module = _load_connector_app(monkeypatch) + + redis = _FakeRedis() + monkeypatch.setattr(module, "REDIS_CONN", redis) + monkeypatch.setattr(module.time, "time", lambda: 1700000000) + + flow_calls = [] + + def 
_from_client_config(client_config, scopes): + flow = _FakeFlow(client_config, scopes) + flow_calls.append(flow) + return flow + + monkeypatch.setattr(module.Flow, "from_client_config", staticmethod(_from_client_config)) + + _set_request(module, args={"type": "invalid"}) + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": "{}"})) + invalid_type = _run(module.start_google_web_oauth()) + assert invalid_type["code"] == module.RetCode.ARGUMENT_ERROR + + monkeypatch.setattr(module, "GMAIL_WEB_OAUTH_REDIRECT_URI", "") + _set_request(module, args={"type": "gmail"}) + missing_redirect = _run(module.start_google_web_oauth()) + assert missing_redirect["code"] == module.RetCode.SERVER_ERROR + + monkeypatch.setattr(module, "GMAIL_WEB_OAUTH_REDIRECT_URI", "https://example.com/gmail") + monkeypatch.setattr(module, "GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "https://example.com/drive") + + _set_request(module, args={"type": "google-drive"}) + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": "{invalid-json"})) + invalid_credentials = _run(module.start_google_web_oauth()) + assert invalid_credentials["code"] == module.RetCode.ARGUMENT_ERROR + + monkeypatch.setattr( + module, + "get_request_json", + lambda: _AwaitableValue({"credentials": json.dumps({"web": {"client_id": "id"}, "refresh_token": "rt"})}), + ) + has_refresh_token = _run(module.start_google_web_oauth()) + assert has_refresh_token["code"] == module.RetCode.ARGUMENT_ERROR + + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": json.dumps({"installed": {"x": 1}})})) + missing_web = _run(module.start_google_web_oauth()) + assert missing_web["code"] == module.RetCode.ARGUMENT_ERROR + + ids = iter(["flow-gmail", "flow-drive"]) + monkeypatch.setattr(module.uuid, "uuid4", lambda: next(ids)) + + monkeypatch.setattr( + module, + "get_request_json", + lambda: _AwaitableValue({"credentials": json.dumps({"web": {"client_id": 
"id", "client_secret": "secret"}})}), + ) + + _set_request(module, args={"type": "gmail"}) + gmail_ok = _run(module.start_google_web_oauth()) + assert gmail_ok["code"] == 0 + assert gmail_ok["data"]["flow_id"] == "flow-gmail" + assert gmail_ok["data"]["authorization_url"].endswith("flow-gmail") + + _set_request(module, args={}) + drive_ok = _run(module.start_google_web_oauth()) + assert drive_ok["code"] == 0 + assert drive_ok["data"]["flow_id"] == "flow-drive" + assert drive_ok["data"]["authorization_url"].endswith("flow-drive") + + assert any(call.scopes == module.GOOGLE_SCOPES[module.DocumentSource.GMAIL] for call in flow_calls) + assert any(call.scopes == module.GOOGLE_SCOPES[module.DocumentSource.GOOGLE_DRIVE] for call in flow_calls) + assert "gmail_web_flow_state:flow-gmail" in redis.store + assert "google-drive_web_flow_state:flow-drive" in redis.store + assert json.loads(redis.store["gmail_web_flow_state:flow-gmail"])["code_verifier"] == "fake-code-verifier" + assert json.loads(redis.store["google-drive_web_flow_state:flow-drive"])["code_verifier"] == "fake-code-verifier" + + +@pytest.mark.p2 +def test_google_web_oauth_callbacks_matrix(monkeypatch): + module = _load_connector_app(monkeypatch) + + flow_calls = [] + + def _from_client_config(client_config, scopes): + flow = _FakeFlow(client_config, scopes) + flow_calls.append(flow) + return flow + + monkeypatch.setattr(module.Flow, "from_client_config", staticmethod(_from_client_config)) + + callback_specs = [ + ( + module.google_gmail_web_oauth_callback, + "gmail", + module.GMAIL_WEB_OAUTH_REDIRECT_URI, + module.GOOGLE_SCOPES[module.DocumentSource.GMAIL], + ), + ( + module.google_drive_web_oauth_callback, + "google-drive", + module.GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, + module.GOOGLE_SCOPES[module.DocumentSource.GOOGLE_DRIVE], + ), + ] + + for callback, source, expected_redirect, expected_scopes in callback_specs: + redis = _FakeRedis() + monkeypatch.setattr(module, "REDIS_CONN", redis) + + 
_set_request(module, args={}) + missing_state = _run(callback()) + assert "Missing OAuth state parameter." in missing_state.body + + _set_request(module, args={"state": "sid"}) + expired_state = _run(callback()) + assert "Authorization session expired" in expired_state.body + + redis.store[module._web_state_cache_key("sid", source)] = json.dumps({"user_id": "tenant-1"}) + _set_request(module, args={"state": "sid"}) + invalid_state = _run(callback()) + assert "Authorization session was invalid" in invalid_state.body + assert module._web_state_cache_key("sid", source) in redis.deleted + + redis.store[module._web_state_cache_key("sid", source)] = json.dumps({ + "user_id": "tenant-1", + "client_config": {"web": {"client_id": "cid"}}, + }) + _set_request(module, args={"state": "sid", "error": "denied", "error_description": "permission denied"}) + oauth_error = _run(callback()) + assert "permission denied" in oauth_error.body + + redis.store[module._web_state_cache_key("sid", source)] = json.dumps({ + "user_id": "tenant-1", + "client_config": {"web": {"client_id": "cid"}}, + }) + _set_request(module, args={"state": "sid"}) + missing_code = _run(callback()) + assert "Missing authorization code" in missing_code.body + + redis.store[module._web_state_cache_key("sid", source)] = json.dumps({ + "user_id": "tenant-1", + "client_config": {"web": {"client_id": "cid"}}, + "code_verifier": "state-code-verifier", + }) + _set_request(module, args={"state": "sid", "code": "code-123"}) + success = _run(callback()) + assert "Authorization completed successfully." 
in success.body + + result_key = module._web_result_cache_key("sid", source) + assert result_key in redis.store + assert module._web_state_cache_key("sid", source) in redis.deleted + + assert flow_calls[-1].redirect_uri == expected_redirect + assert flow_calls[-1].scopes == expected_scopes + assert flow_calls[-1].token_code == "code-123" + assert flow_calls[-1].token_code_verifier == "state-code-verifier" + + +@pytest.mark.p2 +def test_poll_google_web_result_matrix(monkeypatch): + module = _load_connector_app(monkeypatch) + redis = _FakeRedis() + monkeypatch.setattr(module, "REDIS_CONN", redis) + + _set_request(module, args={"type": "invalid"}, json_body={"flow_id": "flow-1"}) + invalid_type = _run(module.poll_google_web_result()) + assert invalid_type["code"] == module.RetCode.ARGUMENT_ERROR + + _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"}) + pending = _run(module.poll_google_web_result()) + assert pending["code"] == module.RetCode.RUNNING + + redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps( + {"user_id": "another-user", "credentials": "token-x"} + ) + _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"}) + permission_error = _run(module.poll_google_web_result()) + assert permission_error["code"] == module.RetCode.PERMISSION_ERROR + + redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps( + {"user_id": "tenant-1", "credentials": "token-ok"} + ) + _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"}) + success = _run(module.poll_google_web_result()) + assert success["code"] == 0 + assert success["data"] == {"credentials": "token-ok"} + assert module._web_result_cache_key("flow-1", "gmail") in redis.deleted + + +@pytest.mark.p2 +def test_box_oauth_start_callback_and_poll_matrix(monkeypatch): + module = _load_connector_app(monkeypatch) + redis = _FakeRedis() + monkeypatch.setattr(module, "REDIS_CONN", redis) + + created_auth = [] + + class 
_TrackingBoxOAuth(_FakeBoxOAuth): + def __init__(self, config): + super().__init__(config) + created_auth.append(self) + + monkeypatch.setattr(module, "BoxOAuth", _TrackingBoxOAuth) + monkeypatch.setattr(module.uuid, "uuid4", lambda: "flow-box") + monkeypatch.setattr(module.time, "time", lambda: 1800000000) + + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({})) + missing_params = _run(module.start_box_web_oauth()) + assert missing_params["code"] == module.RetCode.ARGUMENT_ERROR + + monkeypatch.setattr( + module, + "get_request_json", + lambda: _AwaitableValue({"client_id": "cid", "client_secret": "sec", "redirect_uri": "https://box.local/callback"}), + ) + start_ok = _run(module.start_box_web_oauth()) + assert start_ok["code"] == 0 + assert start_ok["data"]["flow_id"] == "flow-box" + assert "authorization_url" in start_ok["data"] + assert module._web_state_cache_key("flow-box", "box") in redis.store + + _set_request(module, args={}) + missing_state = _run(module.box_web_oauth_callback()) + assert "Missing OAuth parameters." in missing_state.body + + _set_request(module, args={"state": "flow-box"}) + missing_code = _run(module.box_web_oauth_callback()) + assert "Missing authorization code from Box." 
in missing_code.body + + redis.store[module._web_state_cache_key("flow-null", "box")] = "null" + _set_request(module, args={"state": "flow-null", "code": "abc"}) + invalid_session = _run(module.box_web_oauth_callback()) + assert invalid_session["code"] == module.RetCode.ARGUMENT_ERROR + + redis.store[module._web_state_cache_key("flow-box", "box")] = json.dumps( + {"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"} + ) + _set_request(module, args={"state": "flow-box", "code": "abc", "error": "access_denied", "error_description": "denied"}) + callback_error = _run(module.box_web_oauth_callback()) + assert "denied" in callback_error.body + + redis.store[module._web_state_cache_key("flow-ok", "box")] = json.dumps( + {"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"} + ) + _set_request(module, args={"state": "flow-ok", "code": "code-ok"}) + callback_success = _run(module.box_web_oauth_callback()) + assert "Authorization completed successfully." in callback_success.body + assert created_auth[-1].exchange_code == "code-ok" + assert module._web_result_cache_key("flow-ok", "box") in redis.store + assert module._web_state_cache_key("flow-ok", "box") in redis.deleted + + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"flow_id": "flow-ok"})) + redis.store.pop(module._web_result_cache_key("flow-ok", "box"), None) + pending = _run(module.poll_box_web_result()) + assert pending["code"] == module.RetCode.RUNNING + + redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps({"user_id": "another-user"}) + permission_error = _run(module.poll_box_web_result()) + assert permission_error["code"] == module.RetCode.PERMISSION_ERROR + + redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps( + {"user_id": "tenant-1", "access_token": "at", "refresh_token": "rt"} + ) + poll_success = _run(module.poll_box_web_result()) + assert poll_success["code"] == 0 + assert 
poll_success["data"]["credentials"]["access_token"] == "at" + assert module._web_result_cache_key("flow-ok", "box") in redis.deleted diff --git a/test/testcases/restful_api/test_datasets.py b/test/testcases/restful_api/test_datasets.py new file mode 100644 index 0000000000..a0f1726102 --- /dev/null +++ b/test/testcases/restful_api/test_datasets.py @@ -0,0 +1,335 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +from configs import DATASET_NAME_LIMIT + + +@pytest.mark.p1 +class TestDatasetsAuthorization: + def test_create_requires_auth(self, rest_client_noauth): + res = rest_client_noauth.post("/datasets", json={"name": "auth_test"}) + assert res.status_code == 401 + payload = res.json() + assert payload["code"] == 401, payload + + +@pytest.mark.p1 +def test_dataset_crud_cycle(rest_client, clear_datasets): + create_res = rest_client.post("/datasets", json={"name": "restful_dataset_crud"}) + assert create_res.status_code == 200 + create_payload = create_res.json() + assert create_payload["code"] == 0, create_payload + dataset_id = create_payload["data"]["id"] + + get_res = rest_client.get(f"/datasets/{dataset_id}") + assert get_res.status_code == 200 + get_payload = get_res.json() + assert get_payload["code"] == 0, get_payload + assert get_payload["data"]["id"] == dataset_id, get_payload + + update_res = rest_client.put( + f"/datasets/{dataset_id}", + json={"name": 
"restful_dataset_crud_updated"}, + ) + assert update_res.status_code == 200 + update_payload = update_res.json() + assert update_payload["code"] == 0, update_payload + assert update_payload["data"]["name"] == "restful_dataset_crud_updated", update_payload + + list_res = rest_client.get("/datasets", params={"id": dataset_id}) + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert len(list_payload["data"]) == 1, list_payload + assert list_payload["data"][0]["id"] == dataset_id, list_payload + assert list_payload.get("total_datasets", 0) >= 1, list_payload + + delete_res = rest_client.delete("/datasets", json={"ids": [dataset_id]}) + assert delete_res.status_code == 200 + delete_payload = delete_res.json() + assert delete_payload["code"] == 0, delete_payload + + list_after_delete = rest_client.get("/datasets") + assert list_after_delete.status_code == 200 + list_after_delete_payload = list_after_delete.json() + assert list_after_delete_payload["code"] == 0, list_after_delete_payload + assert all(dataset["id"] != dataset_id for dataset in list_after_delete_payload["data"]), list_after_delete_payload + + +@pytest.mark.p2 +@pytest.mark.parametrize( + "name, expected_fragment", + [ + ("", "String should have at least 1 character"), + (" ", "String should have at least 1 character"), + ("a" * (DATASET_NAME_LIMIT + 1), f"String should have at most {DATASET_NAME_LIMIT} characters"), + ], + ids=["empty", "spaces", "too_long"], +) +def test_dataset_create_name_validation(rest_client, clear_datasets, name, expected_fragment): + res = rest_client.post("/datasets", json={"name": name}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + assert expected_fragment in payload["message"], payload + + +@pytest.mark.p2 +def test_dataset_list_ordering_and_pagination(rest_client, clear_datasets): + for i in range(3): + res = rest_client.post("/datasets", json={"name": 
f"dataset_page_{i}"}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + + list_res = rest_client.get( + "/datasets", + params={"page": 1, "page_size": 2, "orderby": "create_time", "desc": "true"}, + ) + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert len(list_payload["data"]) == 2, list_payload + assert list_payload.get("total_datasets", 0) >= 3, list_payload + + +@pytest.mark.p2 +def test_dataset_search_endpoint(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + res = rest_client.post( + f"/datasets/{dataset_id}/search", + json={"question": "test TXT file", "page": 1, "size": 10}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert "chunks" in payload["data"], payload + + +@pytest.mark.p2 +def test_dataset_search_requires_question(rest_client, create_dataset): + dataset_id = create_dataset("dataset_search_missing_question") + res = rest_client.post(f"/datasets/{dataset_id}/search", json={}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + assert "question" in payload["message"], payload + + +@pytest.mark.p2 +def test_dataset_tags_and_aggregation(rest_client, create_dataset): + dataset_id = create_dataset("dataset_tags") + second_dataset_id = create_dataset("dataset_tags_second") + + list_tags_res = rest_client.get(f"/datasets/{dataset_id}/tags") + assert list_tags_res.status_code == 200 + list_tags_payload = list_tags_res.json() + # Known env/runtime behavior: this route can return 102 when retriever tag + # backend is unavailable for an empty dataset. Keep route-contract coverage. 
+ assert list_tags_payload["code"] in (0, 102), list_tags_payload + + aggregate_res = rest_client.get( + "/datasets/tags/aggregation", + params={"dataset_ids": f"{dataset_id},{second_dataset_id}"}, + ) + assert aggregate_res.status_code == 200 + aggregate_payload = aggregate_res.json() + assert aggregate_payload["code"] in (0, 102), aggregate_payload + + empty_aggregate_res = rest_client.get("/datasets/tags/aggregation") + assert empty_aggregate_res.status_code == 200 + empty_aggregate_payload = empty_aggregate_res.json() + assert empty_aggregate_payload["code"] != 0, empty_aggregate_payload + + +@pytest.mark.p2 +def test_dataset_tags_delete_and_rename_validation(rest_client, create_dataset): + dataset_id = create_dataset("dataset_tag_mutation") + + delete_missing_tags = rest_client.delete(f"/datasets/{dataset_id}/tags", json={}) + assert delete_missing_tags.status_code == 200 + delete_missing_tags_payload = delete_missing_tags.json() + assert delete_missing_tags_payload["code"] != 0, delete_missing_tags_payload + + delete_invalid_tags_type = rest_client.delete(f"/datasets/{dataset_id}/tags", json={"tags": "wrong"}) + assert delete_invalid_tags_type.status_code == 200 + delete_invalid_tags_type_payload = delete_invalid_tags_type.json() + assert delete_invalid_tags_type_payload["code"] != 0, delete_invalid_tags_type_payload + + rename_empty = rest_client.put( + f"/datasets/{dataset_id}/tags", + json={"from_tag": "", "to_tag": ""}, + ) + assert rename_empty.status_code == 200 + rename_empty_payload = rename_empty.json() + assert rename_empty_payload["code"] != 0, rename_empty_payload + + rename_invalid_dataset = rest_client.put( + "/datasets/invalid_id/tags", + json={"from_tag": "old", "to_tag": "new"}, + ) + assert rename_invalid_dataset.status_code == 200 + rename_invalid_dataset_payload = rename_invalid_dataset.json() + assert rename_invalid_dataset_payload["code"] != 0, rename_invalid_dataset_payload + + +@pytest.mark.p2 +def 
test_dataset_flattened_metadata(rest_client, create_dataset): + first_dataset_id = create_dataset("flattened_meta_1") + second_dataset_id = create_dataset("flattened_meta_2") + + flattened_res = rest_client.get( + "/datasets/metadata/flattened", + params={"dataset_ids": f"{first_dataset_id},{second_dataset_id}"}, + ) + assert flattened_res.status_code == 200 + flattened_payload = flattened_res.json() + assert flattened_payload["code"] == 0, flattened_payload + + empty_ids_res = rest_client.get("/datasets/metadata/flattened") + assert empty_ids_res.status_code == 200 + empty_ids_payload = empty_ids_res.json() + assert empty_ids_payload["code"] != 0, empty_ids_payload + + invalid_dataset_res = rest_client.get( + "/datasets/metadata/flattened", + params={"dataset_ids": "invalid_id"}, + ) + assert invalid_dataset_res.status_code == 200 + invalid_dataset_payload = invalid_dataset_res.json() + assert invalid_dataset_payload["code"] != 0, invalid_dataset_payload + + +@pytest.mark.p2 +def test_dataset_ingestion_summary_and_logs(rest_client, create_dataset): + dataset_id = create_dataset("dataset_ingestions") + + summary_res = rest_client.get(f"/datasets/{dataset_id}/ingestions/summary") + assert summary_res.status_code == 200 + summary_payload = summary_res.json() + assert summary_payload["code"] == 0, summary_payload + assert "doc_num" in summary_payload["data"], summary_payload + assert "chunk_num" in summary_payload["data"], summary_payload + assert "token_num" in summary_payload["data"], summary_payload + assert "status" in summary_payload["data"], summary_payload + + logs_res = rest_client.get( + f"/datasets/{dataset_id}/ingestions", + params={"page": 1, "page_size": 10}, + ) + assert logs_res.status_code == 200 + logs_payload = logs_res.json() + assert logs_payload["code"] == 0, logs_payload + assert "total" in logs_payload["data"], logs_payload + assert "logs" in logs_payload["data"], logs_payload + + not_found_log_res = 
rest_client.get(f"/datasets/{dataset_id}/ingestions/nonexistent_log") + assert not_found_log_res.status_code == 200 + not_found_log_payload = not_found_log_res.json() + assert not_found_log_payload["code"] != 0, not_found_log_payload + + +@pytest.mark.p2 +def test_dataset_ingestion_invalid_dataset(rest_client): + summary_res = rest_client.get("/datasets/invalid_id/ingestions/summary") + assert summary_res.status_code == 200 + summary_payload = summary_res.json() + assert summary_payload["code"] != 0, summary_payload + + logs_res = rest_client.get("/datasets/invalid_id/ingestions") + assert logs_res.status_code == 200 + logs_payload = logs_res.json() + assert logs_payload["code"] != 0, logs_payload + + log_res = rest_client.get("/datasets/invalid_id/ingestions/some_log_id") + assert log_res.status_code == 200 + log_payload = log_res.json() + assert log_payload["code"] != 0, log_payload + + +@pytest.mark.p2 +def test_dataset_index_endpoints(rest_client, create_dataset): + dataset_id = create_dataset("dataset_index_endpoints") + + run_invalid_type = rest_client.post( + f"/datasets/{dataset_id}/index", + params={"type": "invalid_type"}, + ) + assert run_invalid_type.status_code == 200 + run_invalid_type_payload = run_invalid_type.json() + assert run_invalid_type_payload["code"] != 0, run_invalid_type_payload + + run_no_docs = rest_client.post( + f"/datasets/{dataset_id}/index", + params={"type": "graph"}, + ) + assert run_no_docs.status_code == 200 + run_no_docs_payload = run_no_docs.json() + assert run_no_docs_payload["code"] == 102, run_no_docs_payload + + trace_no_task = rest_client.get( + f"/datasets/{dataset_id}/index", + params={"type": "graph"}, + ) + assert trace_no_task.status_code == 200 + trace_no_task_payload = trace_no_task.json() + assert trace_no_task_payload["code"] == 0, trace_no_task_payload + assert trace_no_task_payload["data"] == {}, trace_no_task_payload + + delete_graph = rest_client.delete(f"/datasets/{dataset_id}/graph") + assert 
delete_graph.status_code == 200 + delete_graph_payload = delete_graph.json() + assert delete_graph_payload["code"] == 0, delete_graph_payload + + delete_invalid_type = rest_client.delete(f"/datasets/{dataset_id}/invalid_type") + assert delete_invalid_type.status_code == 200 + delete_invalid_type_payload = delete_invalid_type.json() + assert delete_invalid_type_payload["code"] != 0, delete_invalid_type_payload + + +@pytest.mark.p2 +@pytest.mark.parametrize("index_type", ["graph", "raptor", "mindmap"]) +def test_dataset_index_run_with_document_creates_task(rest_client, create_document, index_type): + dataset_id, _ = create_document("dataset_index_graph_source.txt") + run_graph = rest_client.post( + f"/datasets/{dataset_id}/index", + params={"type": index_type}, + ) + assert run_graph.status_code == 200 + run_graph_payload = run_graph.json() + assert run_graph_payload["code"] == 0, run_graph_payload + assert run_graph_payload["data"].get("task_id"), run_graph_payload + + +@pytest.mark.p2 +def test_dataset_embedding_endpoints(rest_client, create_dataset): + dataset_id = create_dataset("dataset_embedding_endpoints") + + run_no_docs_res = rest_client.post(f"/datasets/{dataset_id}/embedding") + assert run_no_docs_res.status_code == 200 + run_no_docs_payload = run_no_docs_res.json() + assert run_no_docs_payload["code"] == 102, run_no_docs_payload + + missing_embd_id_res = rest_client.post(f"/datasets/{dataset_id}/embedding/check", json={}) + assert missing_embd_id_res.status_code == 200 + missing_embd_id_payload = missing_embd_id_res.json() + assert missing_embd_id_payload["code"] != 0, missing_embd_id_payload + + invalid_dataset_res = rest_client.post("/datasets/invalid_id/embedding") + assert invalid_dataset_res.status_code == 200 + invalid_dataset_payload = invalid_dataset_res.json() + assert invalid_dataset_payload["code"] != 0, invalid_dataset_payload diff --git a/test/testcases/restful_api/test_document_raw_routes.py 
b/test/testcases/restful_api/test_document_raw_routes.py new file mode 100644 index 0000000000..07f65230df --- /dev/null +++ b/test/testcases/restful_api/test_document_raw_routes.py @@ -0,0 +1,43 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + + +@pytest.mark.p2 +def test_document_image_invalid_id_contract(rest_client_noauth): + res = rest_client_noauth.get("/documents/images/not-a-valid-image-id") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert payload["message"] == "Image not found.", payload + + +@pytest.mark.p2 +def test_document_artifact_requires_auth(rest_client_noauth): + res = rest_client_noauth.get("/documents/artifact/not-an-artifact.txt") + assert res.status_code == 401 + payload = res.json() + assert payload["code"] == 401, payload + + +@pytest.mark.p2 +def test_document_artifact_rejects_unsafe_filename(rest_client): + res = rest_client.get("/documents/artifact/not-an-artifact.exe") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert payload["message"] == "Invalid file type.", payload diff --git a/test/testcases/restful_api/test_documents.py b/test/testcases/restful_api/test_documents.py new file mode 100644 index 0000000000..59575fc0fc --- /dev/null +++ b/test/testcases/restful_api/test_documents.py @@ -0,0 +1,122 @@ +# +# Copyright 2026 The InfiniFlow 
Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +from utils.file_utils import create_txt_file + + +@pytest.mark.p1 +def test_documents_upload_and_list(rest_client, create_dataset, tmp_path): + dataset_id = create_dataset("dataset_upload_list") + fp = create_txt_file(tmp_path / "upload_and_list.txt") + with fp.open("rb") as file_obj: + res = rest_client.post( + f"/datasets/{dataset_id}/documents", + files=[("file", (fp.name, file_obj))], + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert payload["data"][0]["dataset_id"] == dataset_id, payload + + list_res = rest_client.get(f"/datasets/{dataset_id}/documents") + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert list_payload["data"]["total"] >= 1, list_payload + assert any(doc["name"] == fp.name for doc in list_payload["data"]["docs"]), list_payload + + +@pytest.mark.p2 +def test_documents_upload_missing_file(rest_client, create_dataset): + dataset_id = create_dataset("dataset_upload_missing") + res = rest_client.post(f"/datasets/{dataset_id}/documents") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + assert payload["message"] == "No file part!", payload + + +@pytest.mark.p2 +def test_documents_update_patch_and_delete(rest_client, create_document): + dataset_id, document_id = 
create_document("update_target.txt") + + patch_res = rest_client.patch( + f"/datasets/{dataset_id}/documents/{document_id}", + json={"name": "updated_target.txt"}, + ) + assert patch_res.status_code == 200 + patch_payload = patch_res.json() + assert patch_payload["code"] == 0, patch_payload + assert patch_payload["data"]["name"] == "updated_target.txt", patch_payload + + delete_res = rest_client.delete( + f"/datasets/{dataset_id}/documents", + json={"ids": [document_id]}, + ) + assert delete_res.status_code == 200 + delete_payload = delete_res.json() + assert delete_payload["code"] == 0, delete_payload + assert delete_payload["data"]["deleted"] == 1, delete_payload + + list_res = rest_client.get(f"/datasets/{dataset_id}/documents") + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert all(doc["id"] != document_id for doc in list_payload["data"]["docs"]), list_payload + + +@pytest.mark.p2 +def test_documents_parse_and_stop(rest_client, create_document): + dataset_id, document_id = create_document("parse_target.txt") + + parse_res = rest_client.post( + f"/datasets/{dataset_id}/documents/parse", + json={"document_ids": [document_id]}, + ) + assert parse_res.status_code == 200 + parse_payload = parse_res.json() + assert parse_payload["code"] == 0, parse_payload + + stop_res = rest_client.post( + f"/datasets/{dataset_id}/documents/stop", + json={"document_ids": [document_id]}, + ) + assert stop_res.status_code == 200 + stop_payload = stop_res.json() + # Depending on timing this can be immediate stop success or "already completed". 
+ assert stop_payload["code"] in (0, 102), stop_payload + if stop_payload["code"] == 102: + assert "already completed" in stop_payload["message"], stop_payload + + +@pytest.mark.p2 +def test_documents_metadata_update_path(rest_client, create_document): + dataset_id, document_id = create_document("metadata_target.txt") + + res = rest_client.patch( + f"/datasets/{dataset_id}/documents/metadatas", + json={ + "selector": {"document_ids": [document_id]}, + "updates": [{"key": "author", "value": "qa"}], + "deletes": [], + }, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert payload["data"]["matched_docs"] == 1, payload + assert payload["data"]["updated"] >= 1, payload diff --git a/test/testcases/restful_api/test_file_routes_unit.py b/test/testcases/restful_api/test_file_routes_unit.py new file mode 100644 index 0000000000..39246e97a0 --- /dev/null +++ b/test/testcases/restful_api/test_file_routes_unit.py @@ -0,0 +1,632 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import asyncio +import importlib.util +import sys +from enum import Enum +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest + + +class _DummyManager: + def route(self, *_args, **_kwargs): + def decorator(func): + return func + + return decorator + + +class _AwaitableValue: + def __init__(self, value): + self._value = value + + def __await__(self): + async def _co(): + return self._value + + return _co().__await__() + + +class _DummyFiles(dict): + def __init__(self, file_objs=None): + super().__init__() + self._file_objs = list(file_objs or []) + if file_objs is not None: + self["file"] = self._file_objs + + def getlist(self, key): + if key == "file": + return list(self._file_objs) + return [] + + +class _DummyUploadFile: + def __init__(self, filename, blob=b"blob"): + self.filename = filename + self._blob = blob + + def read(self): + return self._blob + + +class _DummyRequest: + def __init__(self, *, content_type="", form=None, files=None, args=None): + self.content_type = content_type + self.form = _AwaitableValue(form or {}) + self.files = _AwaitableValue(files if files is not None else _DummyFiles()) + self.args = args or {} + + +class _DummyResponse: + def __init__(self, data): + self.data = data + self.headers = {} + + +def _run(coro): + return asyncio.run(coro) + + +def _load_file_api_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[3] + + quart_mod = ModuleType("quart") + quart_mod.request = _DummyRequest() + + async def _make_response(data): + return _DummyResponse(data) + + quart_mod.make_response = _make_response + monkeypatch.setitem(sys.modules, "quart", quart_mod) + + api_pkg = ModuleType("api") + api_pkg.__path__ = [str(repo_root / "api")] + monkeypatch.setitem(sys.modules, "api", api_pkg) + + apps_pkg = ModuleType("api.apps") + apps_pkg.__path__ = [str(repo_root / "api" / "apps")] + apps_pkg.login_required = lambda func: func + monkeypatch.setitem(sys.modules, "api.apps", apps_pkg) + 
api_pkg.apps = apps_pkg + + services_pkg = ModuleType("api.apps.services") + services_pkg.__path__ = [str(repo_root / "api" / "apps" / "services")] + monkeypatch.setitem(sys.modules, "api.apps.services", services_pkg) + apps_pkg.services = services_pkg + + file_api_service_mod = ModuleType("api.apps.services.file_api_service") + + async def _upload_file(_tenant_id, _pf_id, _file_objs): + return True, [{"id": "f1"}] + + async def _create_folder(_tenant_id, _name, _parent_id=None, _file_type=None): + return True, {"id": "folder1"} + + async def _delete_files(_tenant_id, _ids): + return True, True + + async def _move_files(_tenant_id, _src_file_ids, _dest_file_id=None, _new_name=None): + return True, True + + file_api_service_mod.upload_file = _upload_file + file_api_service_mod.create_folder = _create_folder + file_api_service_mod.list_files = lambda _tenant_id, _args: (True, {"files": [], "total": 0}) + file_api_service_mod.delete_files = _delete_files + file_api_service_mod.move_files = _move_files + file_api_service_mod.get_file_content = lambda _tenant_id, _file_id: ( + True, + SimpleNamespace(parent_id="bucket1", location="path1", name="doc.txt", type="doc"), + ) + file_api_service_mod.get_parent_folder = lambda _file_id, user_id=None: (True, {"parent_folder": {"id": "parent1"}}) + file_api_service_mod.get_all_parent_folders = lambda _file_id, user_id=None: (True, {"parent_folders": [{"id": "root"}]}) + monkeypatch.setitem(sys.modules, "api.apps.services.file_api_service", file_api_service_mod) + services_pkg.file_api_service = file_api_service_mod + + db_pkg = ModuleType("api.db") + db_pkg.__path__ = [] + + class _FileType(Enum): + DOC = "doc" + VISUAL = "visual" + + db_pkg.FileType = _FileType + monkeypatch.setitem(sys.modules, "api.db", db_pkg) + api_pkg.db = db_pkg + + file2doc_mod = ModuleType("api.db.services.file2document_service") + file2doc_mod.File2DocumentService = SimpleNamespace(get_storage_address=lambda **_kwargs: ("bucket2", "path2")) + 
monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2doc_mod) + + api_utils_mod = ModuleType("api.utils.api_utils") + api_utils_mod.add_tenant_id_to_kwargs = lambda func: func + api_utils_mod.get_error_argument_result = lambda message: {"code": 400, "data": None, "message": message} + api_utils_mod.get_error_data_result = lambda message: {"code": 500, "data": None, "message": message} + api_utils_mod.get_result = lambda data=None: {"code": 0, "data": data, "message": ""} + api_utils_mod.get_json_result = lambda code=0, message="success", data=None: {"code": code, "data": data, "message": message} + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + validation_mod = ModuleType("api.utils.validation_utils") + validation_mod.CreateFolderReq = object + validation_mod.DeleteFileReq = object + validation_mod.ListFileReq = object + validation_mod.MoveFileReq = object + + async def _validate_json_request(_request, _schema): + return {}, None + + validation_mod.validate_and_parse_json_request = _validate_json_request + validation_mod.validate_and_parse_request_args = lambda _request, _schema: ({}, None) + monkeypatch.setitem(sys.modules, "api.utils.validation_utils", validation_mod) + + web_utils_mod = ModuleType("api.utils.web_utils") + web_utils_mod.CONTENT_TYPE_MAP = {"txt": "text/plain"} + web_utils_mod.apply_safe_file_response_headers = lambda response, content_type, ext: response.headers.update({"content_type": content_type, "ext": ext}) + monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod) + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + common_pkg.settings = SimpleNamespace( + STORAGE_IMPL=SimpleNamespace( + get=lambda *_args, **_kwargs: b"blob", + ) + ) + monkeypatch.setitem(sys.modules, "common", common_pkg) + + misc_utils_mod = ModuleType("common.misc_utils") + + async def thread_pool_exec(func, *args, **kwargs): + return func(*args, **kwargs) + + 
misc_utils_mod.thread_pool_exec = thread_pool_exec + monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod) + + module_path = repo_root / "api" / "apps" / "restful_apis" / "file_api.py" + spec = importlib.util.spec_from_file_location("api.apps.restful_apis.file_api", module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + monkeypatch.setitem(sys.modules, "api.apps.restful_apis.file_api", module) + spec.loader.exec_module(module) + return module + + +@pytest.mark.p2 +def test_create_or_upload_multipart_requires_file(monkeypatch): + module = _load_file_api_module(monkeypatch) + monkeypatch.setattr(module, "request", _DummyRequest(content_type="multipart/form-data", form={}, files=_DummyFiles())) + + res = _run(module.create_or_upload("tenant1")) + assert res["code"] == 400 + assert res["message"] == "No file part!" + + +@pytest.mark.p2 +def test_create_or_upload_uploads_via_new_service(monkeypatch): + module = _load_file_api_module(monkeypatch) + files = _DummyFiles([_DummyUploadFile("a.txt")]) + monkeypatch.setattr(module, "request", _DummyRequest(content_type="multipart/form-data", form={"parent_id": "pf1"}, files=files)) + + seen = {} + + async def _upload_file(tenant_id, pf_id, file_objs): + seen["args"] = (tenant_id, pf_id, [f.filename for f in file_objs]) + return True, [{"id": "f1"}] + + monkeypatch.setattr(module.file_api_service, "upload_file", _upload_file) + res = _run(module.create_or_upload("tenant1")) + + assert seen["args"] == ("tenant1", "pf1", ["a.txt"]) + assert res["code"] == 0 + assert res["data"] == [{"id": "f1"}] + + +@pytest.mark.p2 +def test_create_or_upload_creates_folder_from_json(monkeypatch): + module = _load_file_api_module(monkeypatch) + monkeypatch.setattr(module, "request", _DummyRequest(content_type="application/json")) + + async def _validate(_request, _schema): + return {"name": "folder-a", "parent_id": "pf1", "type": "folder"}, None + + async def _create_folder(tenant_id, 
name, parent_id=None, file_type=None): + return True, {"tenant_id": tenant_id, "name": name, "parent_id": parent_id, "type": file_type} + + monkeypatch.setattr(module, "validate_and_parse_json_request", _validate) + monkeypatch.setattr(module.file_api_service, "create_folder", _create_folder) + + res = _run(module.create_or_upload("tenant1")) + assert res["code"] == 0 + assert res["data"]["tenant_id"] == "tenant1" + assert res["data"]["name"] == "folder-a" + + +@pytest.mark.p2 +def test_list_files_validation_error(monkeypatch): + module = _load_file_api_module(monkeypatch) + monkeypatch.setattr(module, "validate_and_parse_request_args", lambda _request, _schema: (None, "bad args")) + + res = _run(module.list_files("tenant1")) + assert res["code"] == 400 + assert res["message"] == "bad args" + + +@pytest.mark.p2 +def test_move_uses_new_payload_shape(monkeypatch): + module = _load_file_api_module(monkeypatch) + + async def _validate(_request, _schema): + return {"src_file_ids": ["f1"], "dest_file_id": "pf2"}, None + + seen = {} + + async def _move_files(tenant_id, src_file_ids, dest_file_id=None, new_name=None): + seen["args"] = (tenant_id, src_file_ids, dest_file_id, new_name) + return True, True + + monkeypatch.setattr(module, "validate_and_parse_json_request", _validate) + monkeypatch.setattr(module.file_api_service, "move_files", _move_files) + + res = _run(module.move("tenant1")) + assert seen["args"] == ("tenant1", ["f1"], "pf2", None) + assert res["code"] == 0 + assert res["data"] is True + + +@pytest.mark.p2 +def test_rename_via_move_route(monkeypatch): + module = _load_file_api_module(monkeypatch) + + async def _validate(_request, _schema): + return {"src_file_ids": ["file1"], "new_name": "renamed.txt"}, None + + seen = {} + + async def _move_files(tenant_id, src_file_ids, dest_file_id=None, new_name=None): + seen["args"] = (tenant_id, src_file_ids, dest_file_id, new_name) + return True, True + + monkeypatch.setattr(module, "validate_and_parse_json_request", 
_validate) + monkeypatch.setattr(module.file_api_service, "move_files", _move_files) + + res = _run(module.move("tenant1")) + assert seen["args"] == ("tenant1", ["file1"], None, "renamed.txt") + assert res["code"] == 0 + assert res["data"] is True + + +@pytest.mark.p2 +def test_download_falls_back_to_document_storage(monkeypatch): + module = _load_file_api_module(monkeypatch) + storage_calls = [] + + def _get(bucket, location): + storage_calls.append((bucket, location)) + return b"" if len(storage_calls) == 1 else b"fallback-blob" + + monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(get=_get)) + res = _run(module.download("tenant1", "file1")) + + assert storage_calls == [("bucket1", "path1"), ("bucket2", "path2")] + assert res.data == b"fallback-blob" + assert res.headers["content_type"] == "text/plain" + assert res.headers["ext"] == "txt" + + +@pytest.mark.p2 +def test_parent_and_ancestors_use_new_routes(monkeypatch): + module = _load_file_api_module(monkeypatch) + + parent_res = _run(module.parent_folder("tenant1", "file1")) + ancestors_res = _run(module.ancestors("tenant1", "file1")) + + assert parent_res["code"] == 0 + assert parent_res["data"]["parent_folder"]["id"] == "parent1" + assert ancestors_res["code"] == 0 + assert ancestors_res["data"]["parent_folders"][0]["id"] == "root" + +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import asyncio
+import functools
+import importlib.util
+import sys
+from copy import deepcopy
+from enum import Enum
+from pathlib import Path
+from types import ModuleType, SimpleNamespace
+
+import pytest
+
+
+class _DummyManager:
+    """Stand-in for the route manager the API module expects as ``manager``."""
+
+    def route(self, *_args, **_kwargs):
+        """Return a decorator that hands the view function back unchanged."""
+
+        def decorator(func):
+            return func
+
+        return decorator
+
+
+class _DummyFile:
+    """Plain attribute bag used in place of a file record."""
+
+    def __init__(self, file_id, file_type, *, name="file.txt", location="loc", size=1):
+        self.id = file_id
+        self.type = file_type
+        self.name = name
+        self.location = location
+        self.size = size
+
+
+class _FalsyFile(_DummyFile):
+    """File stub that evaluates as false in a boolean context."""
+
+    def __bool__(self):
+        return False
+
+
+def _run(coro):
+    """Run *coro* to completion with ``asyncio.run`` and return its result."""
+    return asyncio.run(coro)
+
+
+def _set_request_json(monkeypatch, module, payload_state):
+    """Patch ``module.get_request_json`` to return a deep copy of *payload_state*.
+
+    The copy is taken on every call, so mutations the test makes to
+    *payload_state* later on are seen by subsequent requests, while the
+    handler under test never receives the test's own dict.
+    """
+
+    async def _req_json():
+        return deepcopy(payload_state)
+
+    monkeypatch.setattr(module, "get_request_json", _req_json)
+
+
+@pytest.fixture(scope="session")
+def auth():
+    """Return a placeholder auth value.
+
+    NOTE(review): presumably shadows a shared ``auth`` fixture so these unit
+    tests need no live login -- confirm against the suite's conftest.
+    """
+    return "unit-auth"
+
+
+@pytest.fixture(scope="session", autouse=True)
+def set_tenant_info():
+    """Autouse fixture that does nothing and returns ``None``.
+
+    NOTE(review): presumably overrides a same-named conftest fixture that
+    needs a running server -- confirm.
+    """
+    return None
+
+
+def _load_file2document_module(monkeypatch):
+    """Execute ``file2document_api.py`` with its imports replaced by stubs.
+
+    Stub modules are registered in ``sys.modules`` through *monkeypatch*, so
+    the entries are undone when the test finishes.  The API source file is
+    then executed under a private module name with ``manager`` pre-set to a
+    ``_DummyManager``.
+
+    Returns:
+        The freshly executed module object.
+    """
+    # parents[3] walks up from this file: restful_api -> testcases -> test -> repo root.
+    repo_root = Path(__file__).resolve().parents[3]
+
+    api_pkg = ModuleType("api")
+    api_pkg.__path__ = [str(repo_root / "api")]
+    monkeypatch.setitem(sys.modules, "api", api_pkg)
+
+    apps_mod = ModuleType("api.apps")
+    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
+    apps_mod.current_user = SimpleNamespace(id="user-1")
+    # Identity decorator: the handlers run without any login check.
+    apps_mod.login_required = lambda func: func
+    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
+    api_pkg.apps = apps_mod
+
+    db_pkg = ModuleType("api.db")
+    db_pkg.__path__ = []
+
+    class _FileType(Enum):
+        FOLDER = "folder"
+        DOC = "doc"
+
+    db_pkg.FileType = _FileType
+    monkeypatch.setitem(sys.modules, "api.db", db_pkg)
+    api_pkg.db = db_pkg
+
+    services_pkg = ModuleType("api.db.services")
+    services_pkg.__path__ = []
+    monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
+
+    common_pkg = ModuleType("api.common")
+    common_pkg.__path__ = []
+    monkeypatch.setitem(sys.modules, "api.common", common_pkg)
+
+    # Both permission checks allow everything by default; tests flip them per case.
+    permission_mod = ModuleType("api.common.check_team_permission")
+    permission_mod.check_file_team_permission = lambda *_args, **_kwargs: True
+    permission_mod.check_kb_team_permission = lambda *_args, **_kwargs: True
+    monkeypatch.setitem(sys.modules, "api.common.check_team_permission", permission_mod)
+    common_pkg.check_team_permission = permission_mod
+
+    file2document_mod = ModuleType("api.db.services.file2document_service")
+
+    class _StubFile2DocumentService:
+        @staticmethod
+        def get_by_file_id(_file_id):
+            return []
+
+        @staticmethod
+        def delete_by_file_id(*_args, **_kwargs):
+            return None
+
+        @staticmethod
+        def insert(_payload):
+            return SimpleNamespace(to_json=lambda: {})
+
+    file2document_mod.File2DocumentService = _StubFile2DocumentService
+    monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2document_mod)
+    services_pkg.file2document_service = file2document_mod
+
+    file_service_mod = ModuleType("api.db.services.file_service")
+
+    class _StubFileService:
+        @staticmethod
+        def get_by_ids(_file_ids):
+            return []
+
+        @staticmethod
+        def get_all_innermost_file_ids(_file_id, _acc):
+            return []
+
+        @staticmethod
+        def get_by_id(_file_id):
+            return True, _DummyFile(_file_id, _FileType.DOC.value)
+
+        @staticmethod
+        def get_parser(_file_type, _file_name, parser_id):
+            return parser_id
+
+    file_service_mod.FileService = _StubFileService
+    monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service_mod)
+    services_pkg.file_service = file_service_mod
+
+    kb_service_mod = ModuleType("api.db.services.knowledgebase_service")
+
+    class _StubKnowledgebaseService:
+        # Default is "dataset not found"; tests patch in a hit when needed.
+        @staticmethod
+        def get_by_id(_kb_id):
+            return False, None
+
+    kb_service_mod.KnowledgebaseService = _StubKnowledgebaseService
+    monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod)
+    services_pkg.knowledgebase_service = kb_service_mod
+
+    document_service_mod = ModuleType("api.db.services.document_service")
+
+    class _StubDocumentService:
+        @staticmethod
+        def get_by_id(doc_id):
+            return True, SimpleNamespace(id=doc_id)
+
+        @staticmethod
+        def get_tenant_id(_doc_id):
+            return "tenant-1"
+
+        @staticmethod
+        def remove_document(*_args, **_kwargs):
+            return True
+
+        @staticmethod
+        def insert(_payload):
+            return SimpleNamespace(id="doc-1")
+
+    document_service_mod.DocumentService = _StubDocumentService
+    monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod)
+    services_pkg.document_service = document_service_mod
+
+    api_utils_mod = ModuleType("api.utils.api_utils")
+
+    # The response helpers return plain dicts so tests can assert on code/message/data.
+    def get_json_result(data=None, message="", code=0):
+        return {"code": code, "data": data, "message": message}
+
+    def get_data_error_result(message=""):
+        return {"code": 102, "data": None, "message": message}
+
+    async def get_request_json():
+        return {}
+
+    def server_error_response(err):
+        return {"code": 500, "data": None, "message": str(err)}
+
+    def validate_request(*_keys):
+        # Performs no validation: the wrapper only awaits the handler.
+        def _decorator(func):
+            @functools.wraps(func)
+            async def _wrapper(*args, **kwargs):
+                return await func(*args, **kwargs)
+
+            return _wrapper
+
+        return _decorator
+
+    api_utils_mod.get_json_result = get_json_result
+    api_utils_mod.get_data_error_result = get_data_error_result
+    api_utils_mod.get_request_json = get_request_json
+    api_utils_mod.server_error_response = server_error_response
+    api_utils_mod.validate_request = validate_request
+    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
+
+    misc_utils_mod = ModuleType("common.misc_utils")
+    misc_utils_mod.get_uuid = lambda: "uuid"
+    monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
+
+    constants_mod = ModuleType("common.constants")
+
+    class _RetCode:
+        ARGUMENT_ERROR = 101
+
+    constants_mod.RetCode = _RetCode
+    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
+
+    module_name = "test_file2document_routes_unit_module"
+    module_path = repo_root / "api" / "apps" / "restful_apis" / "file2document_api.py"
+    spec = importlib.util.spec_from_file_location(module_name, module_path)
+    module = importlib.util.module_from_spec(spec)
+    # ``manager`` is set before execution so it exists while the module body runs.
+    module.manager = _DummyManager()
+    monkeypatch.setitem(sys.modules, module_name, module)
+    spec.loader.exec_module(module)
+    return module
+
+
+@pytest.mark.p2
+def test_convert_branch_matrix_unit(monkeypatch):
+    """Walk ``convert`` through its error and success branches in sequence.
+
+    Each step re-patches one stub and reuses the state left by the previous
+    step, so the order of the sections below matters.
+    """
+    module = _load_file2document_module(monkeypatch)
+    req_state = {"kb_ids": ["kb-1"], "file_ids": ["f1"]}
+    _set_request_json(monkeypatch, module, req_state)
+
+    # Falsy file returns "File not found!" during synchronous validation.
+    monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_FalsyFile("f1", module.FileType.DOC.value)])
+    res = _run(module.convert())
+    assert res["code"] == 102
+    assert res["message"] == "File not found!"
+
+    # Valid file but invalid kb returns "Can't find this dataset!" during synchronous validation.
+    monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_DummyFile("f1", module.FileType.DOC.value)])
+    res = _run(module.convert())
+    assert res["code"] == 102
+    assert res["message"] == "Can't find this dataset!"
+
+    kb = SimpleNamespace(id="kb-1", parser_id="naive", pipeline_id="p1", parser_config={})
+    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
+
+    # Unauthorized file access is rejected before scheduling background work.
+    monkeypatch.setattr(module, "check_file_team_permission", lambda *_args, **_kwargs: False)
+    res = _run(module.convert())
+    assert res["code"] == 102
+    assert res["message"] == "No authorization."
+
+    # Unauthorized dataset access is rejected before scheduling background work.
+    monkeypatch.setattr(module, "check_file_team_permission", lambda *_args, **_kwargs: True)
+    monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: False)
+    res = _run(module.convert())
+    assert res["code"] == 102
+    assert res["message"] == "No authorization."
+
+    # Valid file and kb schedule background work and return data=True immediately.
+    monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: True)
+    res = _run(module.convert())
+    assert res["code"] == 0
+    assert res["data"] is True
+
+    # Folder expansion schedules background work and returns data=True immediately.
+    req_state["file_ids"] = ["folder-1"]
+    monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_DummyFile("folder-1", module.FileType.FOLDER.value, name="folder")])
+    monkeypatch.setattr(module.FileService, "get_all_innermost_file_ids", lambda _file_id, _acc: ["inner-1"])
+    res = _run(module.convert())
+    assert res["code"] == 0
+    assert res["data"] is True
+
+    # Exception in file lookup returns 500.
+    req_state["file_ids"] = ["f1"]
+    monkeypatch.setattr(
+        module.FileService,
+        "get_by_ids",
+        lambda _ids: (_ for _ in ()).throw(RuntimeError("convert boom")),
+    )
+    res = _run(module.convert())
+    assert res["code"] == 500
+    assert "convert boom" in res["message"]
diff --git a/test/testcases/restful_api/test_langfuse_routes.py b/test/testcases/restful_api/test_langfuse_routes.py
new file mode 100644
index 0000000000..deda7fbe3e
--- /dev/null
+++ b/test/testcases/restful_api/test_langfuse_routes.py
@@ -0,0 +1,37 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+
+
+@pytest.mark.p2
+def test_langfuse_api_key_routes_require_auth(rest_client_noauth):
+    """GET/POST/PUT/DELETE on /langfuse/api-key via the no-auth client all yield 401."""
+    body_verbs = {"post", "put"}
+    for verb in ("get", "post", "put", "delete"):
+        # Only the body-carrying verbs send a JSON payload.
+        call_kwargs = {}
+        if verb in body_verbs:
+            call_kwargs["json"] = {"secret_key": "s", "public_key": "p", "host": "http://example.com"}
+        response = getattr(rest_client_noauth, verb)("/langfuse/api-key", **call_kwargs)
+        assert response.status_code == 401
+        body = response.json()
+        assert body["code"] == 401, (verb, body)
+
+
+@pytest.mark.p2
+def test_langfuse_api_key_missing_required_fields(rest_client):
+    """An empty ``secret_key`` is answered with HTTP 200 and an error code/message."""
+    response = rest_client.post(
+        "/langfuse/api-key",
+        json={"secret_key": "", "public_key": "pub", "host": "http://host"},
+    )
+    assert response.status_code == 200
+    body = response.json()
+    assert body["code"] in (101, 102), body
+    # Either wording is accepted; the comparison is case-insensitive.
+    lowered = body["message"].lower()
+    assert "required" in lowered or "missing" in lowered, body
diff --git a/test/testcases/restful_api/test_mcp_routes_unit.py b/test/testcases/restful_api/test_mcp_routes_unit.py
new file mode 100644
index 0000000000..ccd628f0fd
--- /dev/null
+++ b/test/testcases/restful_api/test_mcp_routes_unit.py
@@ -0,0 +1,745 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import asyncio
+import importlib.util
+import inspect
+import json
+import sys
+from functools import wraps
+from pathlib import Path
+from types import ModuleType, SimpleNamespace
+
+import pytest
+
+
+class _DummyManager:
+    """Stand-in for the route manager the API module expects as ``manager``."""
+
+    def route(self, *_args, **_kwargs):
+        """Return a decorator that hands the view function back unchanged."""
+
+        def decorator(func):
+            return func
+
+        return decorator
+
+
+class _Args(dict):
+    """Dict with a ``getlist`` method, used as the stubbed ``request.args``."""
+
+    def getlist(self, key):
+        """Return the value for *key* as a list (``[]`` when the key is absent)."""
+        value = self.get(key, [])
+        if isinstance(value, list):
+            return value
+        return [value]
+
+
+class _Field:
+    """Named placeholder for a model column."""
+
+    def __init__(self, name):
+        self.name = name
+
+    def __eq__(self, other):
+        # Returns a (name, other) tuple rather than a bool.
+        # NOTE(review): presumably imitates an ORM field comparison that builds
+        # a filter expression -- confirm against how mcp_api.py uses it.
+        return (self.name, other)
+
+
+class _DummyMCPServer:
+    """In-memory replacement for the ``MCPServer`` model."""
+
+    # Class-level fields; instances shadow them with plain values in __init__.
+    id = _Field("id")
+    tenant_id = _Field("tenant_id")
+
+    def __init__(self, **kwargs):
+        self.id = kwargs.get("id", "")
+        self.name = kwargs.get("name", "")
+        self.url = kwargs.get("url", "")
+        self.server_type = kwargs.get("server_type", "sse")
+        self.tenant_id = kwargs.get("tenant_id", "tenant_1")
+        self.variables = kwargs.get("variables", {})
+        self.headers = kwargs.get("headers", {})
+
+    def to_dict(self):
+        """Return the instance attributes as a plain dict."""
+        return {
+            "id": self.id,
+            "name": self.name,
+            "url": self.url,
+            "server_type": self.server_type,
+            "tenant_id": self.tenant_id,
+            "variables": self.variables,
+            "headers": self.headers,
+        }
+
+
+class _DummyMCPServerService:
+    """Default service stub; tests monkeypatch individual methods per scenario."""
+
+    @staticmethod
+    def get_servers(*_args, **_kwargs):
+        return []
+
+    @staticmethod
+    def get_or_none(*_args, **_kwargs):
+        return None
+
+    @staticmethod
+    def get_by_id(*_args, **_kwargs):
+        return False, None
+
+    @staticmethod
+    def get_by_name_and_tenant(*_args, **_kwargs):
+        return False, None
+
+    @staticmethod
+    def insert(**_kwargs):
+        return True
+
+    @staticmethod
+    def filter_update(*_args, **_kwargs):
+        return True
+
+    @staticmethod
+    def delete_by_ids(*_args, **_kwargs):
+        return True
+
+
+class _DummyTenantService:
+    """Tenant service stub whose lookup always finds ``tenant_1``."""
+
+    @staticmethod
+    def get_by_id(*_args, **_kwargs):
+        return True, SimpleNamespace(id="tenant_1")
+
+
+class _DummyTool:
+    """Tool stub exposing only ``model_dump``."""
+
+    def __init__(self, name):
+        self._name = name
+
+    def model_dump(self):
+        return {"name": self._name}
+
+
+class _DummyMCPToolCallSession:
+    """Session stub that always offers two tools, ``tool_a`` and ``tool_b``."""
+
+    def __init__(self, _mcp_server, _variables):
+        self._tools = [_DummyTool("tool_a"), _DummyTool("tool_b")]
+
+    def get_tools(self, _timeout):
+        return self._tools
+
+    def tool_call(self, _name, _arguments, _timeout):
+        return "ok"
+
+
+def _run(coro):
+    """Run *coro* to completion with ``asyncio.run`` and return its result."""
+    return asyncio.run(coro)
+
+
+def _set_request_json(monkeypatch, module, payload):
+    """Patch ``module.get_request_json`` to return *payload* (the same object each call)."""
+
+    async def _request_json():
+        return payload
+
+    monkeypatch.setattr(module, "get_request_json", _request_json)
+
+
+@pytest.fixture(scope="session")
+def auth():
+    # NOTE(review): presumably shadows a shared ``auth`` fixture so no live
+    # login is needed -- confirm against the suite's conftest.
+    return "unit-auth"
+
+
+@pytest.fixture(scope="session", autouse=True)
+def set_tenant_info():
+    # Autouse no-op. NOTE(review): presumably overrides a same-named conftest
+    # fixture that needs a running server -- confirm.
+    return None
+
+
+def _load_mcp_api(monkeypatch):
+    """Execute ``mcp_api.py`` with selected imports replaced by stubs.
+
+    Stub modules are registered in ``sys.modules`` through *monkeypatch*, so
+    the entries are undone when the test finishes.  The API source file is
+    executed under a private module name with ``manager`` pre-set to a
+    ``_DummyManager``; the executed module object is returned.
+    """
+    # parents[3] walks up from this file: restful_api -> testcases -> test -> repo root.
+    repo_root = Path(__file__).resolve().parents[3]
+
+    quart_mod = ModuleType("quart")
+    quart_mod.Response = object
+    quart_mod.request = SimpleNamespace(args=_Args({}))
+    monkeypatch.setitem(sys.modules, "quart", quart_mod)
+
+    common_pkg = ModuleType("common")
+    common_pkg.__path__ = [str(repo_root / "common")]
+    monkeypatch.setitem(sys.modules, "common", common_pkg)
+
+    constants_mod = ModuleType("common.constants")
+    constants_mod.VALID_MCP_SERVER_TYPES = {"sse", "streamable-http"}
+    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
+
+    apps_mod = ModuleType("api.apps")
+    apps_mod.current_user = SimpleNamespace(id="tenant_1")
+    # Identity decorator: the handlers run without any login check.
+    apps_mod.login_required = lambda func: func
+    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
+
+    db_models_mod = ModuleType("api.db.db_models")
+    db_models_mod.MCPServer = _DummyMCPServer
+    monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
+
+    mcp_service_mod = ModuleType("api.db.services.mcp_server_service")
+    mcp_service_mod.MCPServerService = _DummyMCPServerService
+    monkeypatch.setitem(sys.modules, "api.db.services.mcp_server_service", mcp_service_mod)
+
+    user_service_mod = ModuleType("api.db.services.user_service")
+    user_service_mod.TenantService = _DummyTenantService
+    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
+
+    mcp_conn_mod = ModuleType("common.mcp_tool_call_conn")
+    mcp_conn_mod.MCPToolCallSession = _DummyMCPToolCallSession
+    mcp_conn_mod.close_multiple_mcp_toolcall_sessions = lambda _sessions: None
+    monkeypatch.setitem(sys.modules, "common.mcp_tool_call_conn", mcp_conn_mod)
+
+    api_utils_mod = ModuleType("api.utils.api_utils")
+
+    # The response helpers return plain dicts so tests can assert on code/message/data.
+    async def _default_request_json():
+        return {}
+
+    def _get_json_result(code=0, message="success", data=None):
+        return {"code": code, "message": message, "data": data}
+
+    def _get_data_error_result(code=102, message="Sorry! Data missing!"):
+        return {"code": code, "message": message}
+
+    def _server_error_response(error):
+        return {"code": 100, "message": repr(error)}
+
+    async def _get_mcp_tools(*_args, **_kwargs):
+        return {}
+
+    def _validate_request(*_args, **_kwargs):
+        # Performs no validation. ``wraps`` records the handler as
+        # ``__wrapped__`` on the wrapper it returns.
+        def _decorator(func):
+            @wraps(func)
+            async def _wrapped(*func_args, **func_kwargs):
+                if inspect.iscoroutinefunction(func):
+                    return await func(*func_args, **func_kwargs)
+                return func(*func_args, **func_kwargs)
+
+            return _wrapped
+
+        return _decorator
+
+    api_utils_mod.get_request_json = _default_request_json
+    api_utils_mod.get_json_result = _get_json_result
+    api_utils_mod.get_data_error_result = _get_data_error_result
+    api_utils_mod.server_error_response = _server_error_response
+    api_utils_mod.validate_request = _validate_request
+    api_utils_mod.get_mcp_tools = _get_mcp_tools
+    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
+
+    web_utils_mod = ModuleType("api.utils.web_utils")
+
+    def _get_float(data, key, default):
+        # Falls back to *default* when the stored value cannot be converted.
+        try:
+            return float(data.get(key, default))
+        except (TypeError, ValueError):
+            return default
+
+    def _safe_json_parse(value):
+        # dict/list pass through; None, "" and unparsable input become {}.
+        if isinstance(value, (dict, list)):
+            return value
+        if value in (None, ""):
+            return {}
+        try:
+            return json.loads(value)
+        except (TypeError, ValueError):
+            return {}
+
+    web_utils_mod.get_float = _get_float
+    web_utils_mod.safe_json_parse = _safe_json_parse
+    monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod)
+
+    module_name = "test_mcp_api_unit_module"
+    module_path = repo_root / "api" / "apps" / "restful_apis" / "mcp_api.py"
+    spec = importlib.util.spec_from_file_location(module_name, module_path)
+    module = importlib.util.module_from_spec(spec)
+    # ``manager`` is set before execution so it exists while the module body runs.
+    module.manager = _DummyManager()
+    monkeypatch.setitem(sys.modules, module_name, module)
+    spec.loader.exec_module(module)
+    return module
+
+
+@pytest.mark.p2
+def test_list_mcp_desc_pagination_and_exception(monkeypatch):
+    """``list_mcp``: page 2 of size 1 over two servers, then ``get_servers`` raising."""
+    module = _load_mcp_api(monkeypatch)
+
+    monkeypatch.setattr(
+        module,
+        "request",
+        SimpleNamespace(args=_Args({"keywords": "k", "page": "2", "page_size": "1", "orderby": "create_time", "desc": "false"})),
+    )
+    _set_request_json(monkeypatch, module, {"mcp_ids": []})
+    monkeypatch.setattr(module.MCPServerService, "get_servers", lambda *_args, **_kwargs: [{"id": "a"}, {"id": "b"}])
+
+    res = _run(module.list_mcp())
+    assert res["code"] == 0
+    assert res["data"]["total"] == 2
+    assert res["data"]["mcp_servers"] == [{"id": "b"}]
+
+    monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({})))
+    _set_request_json(monkeypatch, module, {"mcp_ids": []})
+
+    def _raise_list(*_args, **_kwargs):
+        raise RuntimeError("list explode")
+
+    monkeypatch.setattr(module.MCPServerService, "get_servers", _raise_list)
+    res = _run(module.list_mcp())
+    assert res["code"] == 100
+    assert "list explode" in res["message"]
+
+
+@pytest.mark.p2
+def test_detail_not_found_success_and_exception(monkeypatch):
+    """``detail`` without a mode: missing server, found server, lookup raising."""
+    module = _load_mcp_api(monkeypatch)
+    monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({})))
+
+    monkeypatch.setattr(module.MCPServerService, "get_or_none", lambda **_kwargs: None)
+    res = module.detail("mcp-1")
+    assert res["code"] == 102
+    assert "Cannot find MCP server mcp-1 for user tenant_1" in res["message"]
+
+    monkeypatch.setattr(
+        module.MCPServerService,
+        "get_or_none",
+        lambda **_kwargs: _DummyMCPServer(id="mcp-1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1"),
+    )
+    res = module.detail("mcp-1")
+    assert res["code"] == 0
+    assert res["data"]["id"] == "mcp-1"
+
+    def _raise_detail(**_kwargs):
+        raise RuntimeError("detail explode")
+
+    monkeypatch.setattr(module.MCPServerService, "get_or_none", _raise_detail)
+    res = module.detail("mcp-1")
+    assert res["code"] == 100
+    assert "detail explode" in res["message"]
+
+
+@pytest.mark.p2
+def test_create_validation_guards(monkeypatch):
+    """``create`` rejects a bad server type, empty name, duplicate name and empty url."""
+    module = _load_mcp_api(monkeypatch)
+
+    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None))
+
+    _set_request_json(monkeypatch, module, {"name": "srv", "url": "http://a", "server_type": "invalid"})
+    res = _run(module.create.__wrapped__())
+    assert "Unsupported MCP server type" in res["message"]
+
+    _set_request_json(monkeypatch, module, {"name": "", "url": "http://a", "server_type": "sse"})
+    res = _run(module.create.__wrapped__())
+    assert "Invalid MCP name" in res["message"]
+
+    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (True, object()))
+    _set_request_json(monkeypatch, module, {"name": "srv", "url": "http://a", "server_type": "sse"})
+    res = _run(module.create.__wrapped__())
+    assert "Duplicated MCP server name" in res["message"]
+
+    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None))
+    _set_request_json(monkeypatch, module, {"name": "srv", "url": "", "server_type": "sse"})
+    res = _run(module.create.__wrapped__())
+    assert "Invalid url" in res["message"]
+
+
+@pytest.mark.p2
+def test_create_service_paths(monkeypatch):
+    """``create``: tenant missing, tool-listing error, insert failure, success, exception."""
+    module = _load_mcp_api(monkeypatch)
+
+    base_payload = {
+        "name": "srv",
+        "url": "http://server",
+        "server_type": "sse",
+        "headers": '{"Authorization": "x"}',
+        "variables": '{"tools": {"old": 1}, "token": "abc"}',
+        "timeout": "2.5",
+    }
+
+    monkeypatch.setattr(module, "get_uuid", lambda: "uuid-create")
+    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None))
+
+    # A fresh dict(base_payload) is installed before every call below.
+    _set_request_json(monkeypatch, module, dict(base_payload))
+    monkeypatch.setattr(module.TenantService, "get_by_id", lambda *_args, **_kwargs: (False, None))
+    res = _run(module.create.__wrapped__())
+    assert "Tenant not found" in res["message"]
+
+    _set_request_json(monkeypatch, module, dict(base_payload))
+    monkeypatch.setattr(module.TenantService, "get_by_id", lambda *_args, **_kwargs: (True, object()))
+
+    async def _thread_pool_tools_error(_func, _servers, _timeout):
+        return None, "tools error"
+
+    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error)
+    res = _run(module.create.__wrapped__())
+    assert res["code"] == 102
+    assert "tools error" in res["message"]
+
+    _set_request_json(monkeypatch, module, dict(base_payload))
+
+    async def _thread_pool_ok(_func, servers, _timeout):
+        # Second entry has no "name"; the success case below expects only tool_a to be kept.
+        return {servers[0].name: [{"name": "tool_a"}, {"invalid": True}]}, None
+
+    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok)
+    monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: False)
+    res = _run(module.create.__wrapped__())
+    assert res["code"] == 102
+    assert "Failed to create MCP server" in res["message"]
+
+    _set_request_json(monkeypatch, module, dict(base_payload))
+    monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: True)
+    res = _run(module.create.__wrapped__())
+    assert res["code"] == 0
+    assert res["data"]["id"] == "uuid-create"
+    assert res["data"]["tenant_id"] == "tenant_1"
+    assert res["data"]["variables"]["tools"] == {"tool_a": {"name": "tool_a"}}
+
+    _set_request_json(monkeypatch, module, dict(base_payload))
+
+    async def _thread_pool_raises(_func, _servers, _timeout):
+        raise RuntimeError("create explode")
+
+    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_raises)
+    res = _run(module.create.__wrapped__())
+    assert res["code"] == 100
+    assert "create explode" in res["message"]
+
+
+@pytest.mark.p2
+def test_update_validation_guards(monkeypatch):
+    """``update`` rejects unknown/foreign servers, bad type, over-long name and empty url."""
+    module = _load_mcp_api(monkeypatch)
+
+    existing = _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="tenant_1", variables={}, headers={})
+
+    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"})
+    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
+    res = _run(module.update("mcp-1"))
+    assert "Cannot find MCP server" in res["message"]
+
+    # Server exists but belongs to tenant "other", not the current "tenant_1".
+    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"})
+    monkeypatch.setattr(
+        module.MCPServerService,
+        "get_by_id",
+        lambda _mcp_id: (True, _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="other", variables={}, headers={})),
+    )
+    res = _run(module.update("mcp-1"))
+    assert "Cannot find MCP server" in res["message"]
+
+    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "server_type": "invalid"})
+    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))
+    res = _run(module.update("mcp-1"))
+    assert "Unsupported MCP server type" in res["message"]
+
+    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "name": "a" * 256})
+    res = _run(module.update("mcp-1"))
+    assert "Invalid MCP name" in res["message"]
+
+    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "url": ""})
+    res = _run(module.update("mcp-1"))
+    assert "Invalid url" in res["message"]
+
+
+@pytest.mark.p2
+def test_update_service_paths(monkeypatch):
+    """``update``: tool error, update failure, re-fetch failure, success, exception."""
+    module = _load_mcp_api(monkeypatch)
+
+    existing = _DummyMCPServer(
+        id="mcp-1",
+        name="srv",
+        url="http://server",
+        server_type="sse",
+        tenant_id="tenant_1",
+        variables={"tools": {"old": {"enabled": True}}, "token": "abc"},
+        headers={"Authorization": "old"},
+    )
+    updated = _DummyMCPServer(
+        id="mcp-1",
+        name="srv-new",
+        url="http://server-new",
+        server_type="sse",
+        tenant_id="tenant_1",
+        variables={"tools": {"tool_a": {"name": "tool_a"}}},
+        headers={"Authorization": "new"},
+    )
+
+    base_payload = {
+        "mcp_id": "mcp-1",
+        "name": "srv-new",
+        "url": "http://server-new",
+        "server_type": "sse",
+        "headers": '{"Authorization": "new"}',
+        "variables": '{"tools": {"ignore": 1}, "token": "new"}',
+        "timeout": "3.0",
+    }
+
+    _set_request_json(monkeypatch, module, dict(base_payload))
+    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))
+
+    async def _thread_pool_tools_error(_func, _servers, _timeout):
+        return None, "update tools error"
+
+    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error)
+    res = _run(module.update("mcp-1"))
+    assert res["code"] == 102
+    assert "update tools error" in res["message"]
+
+    _set_request_json(monkeypatch, module, dict(base_payload))
+
+    async def _thread_pool_ok(_func, servers, _timeout):
+        return {servers[0].name: [{"name": "tool_a"}, {"bad": True}]}, None
+
+    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok)
+    monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: False)
+    res = _run(module.update("mcp-1"))
+    assert "Failed to updated MCP server" in res["message"]
+
+    _set_request_json(monkeypatch, module, dict(base_payload))
+    monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: True)
+
+    def _get_by_id_fetch_fail(_mcp_id):
+        # First call succeeds; every later call reports "not found".
+        if _get_by_id_fetch_fail.calls == 0:
+            _get_by_id_fetch_fail.calls += 1
+            return True, existing
+        return False, None
+
+    _get_by_id_fetch_fail.calls = 0
+    monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_fetch_fail)
+    res = _run(module.update("mcp-1"))
+    assert "Failed to fetch updated MCP server" in res["message"]
+
+    _set_request_json(monkeypatch, module, dict(base_payload))
+
+    def _get_by_id_success(_mcp_id):
+        # First call returns the old record; later calls return the updated one.
+        if _get_by_id_success.calls == 0:
+            _get_by_id_success.calls += 1
+            return True, existing
+        return True, updated
+
+    _get_by_id_success.calls = 0
+    monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_success)
+    res = _run(module.update("mcp-1"))
+    assert res["code"] == 0
+    assert res["data"]["id"] == "mcp-1"
+
+    _set_request_json(monkeypatch, module, dict(base_payload))
+    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))
+
+    async def _thread_pool_raises(_func, _servers, _timeout):
+        raise RuntimeError("update explode")
+
+    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_raises)
+    res = _run(module.update("mcp-1"))
+    assert res["code"] == 100
+    assert "update explode" in res["message"]
+
+
+@pytest.mark.p2
+def test_rm_failure_success_and_exception(monkeypatch):
+    """``rm``: delete returning False, delete returning True, delete raising."""
+    module = _load_mcp_api(monkeypatch)
+    server = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1", variables={})
+    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server))
+
+    _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})
+    monkeypatch.setattr(module.MCPServerService, "delete_by_ids", lambda _ids: False)
+    res = _run(module.rm("id1"))
+    assert "Failed to delete MCP servers" in res["message"]
+
+    _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})
+    monkeypatch.setattr(module.MCPServerService, "delete_by_ids", lambda _ids: True)
+    res = _run(module.rm("id1"))
+    assert res["code"] == 0
+    assert res["data"] is True
+
+    _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})
+
+    def _raise_rm(_ids):
+        raise RuntimeError("rm explode")
+
+    monkeypatch.setattr(module.MCPServerService, "delete_by_ids", _raise_rm)
+    res = _run(module.rm("id1"))
+    assert res["code"] == 100
+    assert "rm explode" in res["message"]
+
+
+@pytest.mark.p2
+def test_import_multiple_missing_servers_and_exception(monkeypatch):
+    """``import_multiple``: empty ``mcpServers`` mapping, then the name lookup raising."""
+    module = _load_mcp_api(monkeypatch)
+
+    _set_request_json(monkeypatch, module, {"mcpServers": {}})
+    res = _run(module.import_multiple.__wrapped__())
+    assert "No MCP servers provided" in res["message"]
+
+    _set_request_json(monkeypatch, module, {"mcpServers": {"srv": {"type": "sse", "url": "http://x"}}, "timeout": "1"})
+
+    def _raise_import(**_kwargs):
+        raise RuntimeError("import explode")
+
+    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", _raise_import)
+    res = _run(module.import_multiple.__wrapped__())
+    assert res["code"] == 100
+    assert "import explode" in res["message"]
+
+
+@pytest.mark.p2
+def test_import_multiple_mixed_results(monkeypatch):
+    """``import_multiple`` reports a per-server outcome for a batch of five entries."""
+    module = _load_mcp_api(monkeypatch)
+
+    payload = {
+        "mcpServers": {
+            "missing_fields": {"type": "sse"},
+            "": {"type": "sse", "url": "http://empty"},
+            "dup": {"type": "sse", "url": "http://dup", "authorization_token": "dup-token"},
+            "tool_err": {"type": "sse", "url": "http://err"},
+            "insert_fail": {"type": "sse", "url": "http://fail"},
+        },
+        "timeout": "3",
+    }
+    _set_request_json(monkeypatch, module, payload)
+
+    monkeypatch.setattr(module, "get_uuid", lambda: "uuid-import")
+
+    def _get_by_name_and_tenant(name, tenant_id):
+        # Only the first lookup of "dup" reports an existing server.
+        if name == "dup" and not _get_by_name_and_tenant.first_dup_seen:
+            _get_by_name_and_tenant.first_dup_seen = True
+            return True, object()
+        return False, None
+
+    _get_by_name_and_tenant.first_dup_seen = False
+    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", _get_by_name_and_tenant)
+
+    async def _thread_pool_exec(func, servers, _timeout):
+        mcp_server = servers[0]
+        if mcp_server.name == "tool_err":
+            return None, "tool call failed"
+        return {mcp_server.name: [{"name": "tool_a"}, {"invalid": True}]}, None
+
+    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec)
+
+    def _insert(**kwargs):
+        # Insertion fails only for the server named "insert_fail".
+        return kwargs["name"] != "insert_fail"
+
+    monkeypatch.setattr(module.MCPServerService, "insert", _insert)
+
+    res = _run(module.import_multiple.__wrapped__())
+    assert res["code"] == 0
+
+    results = {item["server"]: item for item in res["data"]["results"]}
+    assert results["missing_fields"]["success"] is False
+    assert "Missing required fields" in results["missing_fields"]["message"]
+    assert results[""]["success"] is False
+    assert "Invalid MCP name" in results[""]["message"]
+    assert results["tool_err"]["success"] is False
+    assert "tool call failed" in results["tool_err"]["message"]
+    assert results["insert_fail"]["success"] is False
+    assert "Failed to create MCP server" in results["insert_fail"]["message"]
+    assert results["dup"]["success"] is True
+    assert results["dup"]["new_name"] == "dup_0"
+    assert "Renamed from 'dup' to 'dup_0' avoid duplication" == results["dup"]["message"]
+
+
+@pytest.mark.p2
+def test_detail_download_success_and_exception(monkeypatch):
+    """``detail`` with ``mode=download``: success, missing, foreign tenant, exception."""
+    module = _load_mcp_api(monkeypatch)
+    monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({"mode": "download"})))
+
+    monkeypatch.setattr(
+        module.MCPServerService,
+        "get_by_id",
+        lambda _mcp_id: (
+            True,
+            _DummyMCPServer(
+                id="id1",
+                name="srv-one",
+                url="http://one",
+                server_type="sse",
+                tenant_id="tenant_1",
+                variables={"authorization_token": "tok", "tools": {"tool_a": {"enabled": True}}},
+            ),
+        ),
+    )
+    res = module.detail("id1")
+    assert res["code"] == 0
+    assert list(res["data"]["mcpServers"].keys()) == ["srv-one"]
+
+    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
+    res = module.detail("missing")
+    assert res["code"] == 102
+    assert "Cannot find MCP server missing for user tenant_1" in res["message"]
+
+    # Server exists but belongs to tenant "other", not the current "tenant_1".
+    monkeypatch.setattr(
+        module.MCPServerService,
+        "get_by_id",
+        lambda _mcp_id: (
+            True,
+            _DummyMCPServer(
+                id="id2",
+                name="srv-two",
+                url="http://two",
+                server_type="sse",
+                tenant_id="other",
+                variables={},
+            ),
+        ),
+    )
+    res = module.detail("id2")
+    assert res["code"] == 102
+    assert "Cannot find MCP server id2 for user tenant_1" in res["message"]
+
+    def _raise_export(_mcp_id):
+        raise RuntimeError("export explode")
+
+    monkeypatch.setattr(module.MCPServerService, "get_by_id", _raise_export)
+    res = module.detail("id1")
+    assert res["code"] == 100
+    assert "export explode" in res["message"]
+
+
+@pytest.mark.p2
+def test_test_mcp_route_matrix_unit(monkeypatch):
+    """``test_mcp``: bad url, bad type, ``get_tools`` failing, success, session ctor failing."""
+    module = _load_mcp_api(monkeypatch)
+
+    _set_request_json(monkeypatch, module, {"url": "", "server_type": "sse"})
+    res = _run(module.test_mcp("mcp-1"))
+    assert "Invalid MCP url" in res["message"]
+
+    _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "invalid"})
+    res = _run(module.test_mcp("mcp-1"))
+    assert "Unsupported MCP server type" in res["message"]
+
+    close_calls = []
+
+    async def _thread_pool_exec_inner_error(func, *args):
+        # Record session-close calls, fail get_tools, run anything else inline.
+        if func is module.close_multiple_mcp_toolcall_sessions:
+            close_calls.append(args[0])
+            return None
+        if getattr(func, "__name__", "") == "get_tools":
+            raise RuntimeError("get tools explode")
+        return func(*args)
+
+    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_inner_error)
+    _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
+    res = _run(module.test_mcp("mcp-1"))
+    assert res["code"] == 102
+    assert "Test MCP error: get tools explode" in res["message"]
+    assert close_calls and len(close_calls[-1]) == 1
+
+    close_calls_success = []
+
+    async def _thread_pool_exec_success(func, *args):
+        # Record session-close calls and run everything else inline.
+        if func is module.close_multiple_mcp_toolcall_sessions:
+            close_calls_success.append(args[0])
+            return None
+        return func(*args)
+
+    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_success)
+    _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
+    res = _run(module.test_mcp("mcp-1"))
+    assert res["code"] == 0
+    assert res["data"][0]["name"] == "tool_a"
+    assert all(tool["enabled"] is True for tool in res["data"])
+    assert close_calls_success and len(close_calls_success[-1]) == 1
+
+    def _raise_session(*_args, **_kwargs):
+        raise RuntimeError("session explode")
+
+    monkeypatch.setattr(module, "MCPToolCallSession", _raise_session)
+    _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
+    res = _run(module.test_mcp("mcp-1"))
+    assert res["code"] == 100
+    assert "session explode" in res["message"]
diff --git a/test/testcases/restful_api/test_memories_messages.py b/test/testcases/restful_api/test_memories_messages.py
new file mode 100644
index 0000000000..12fcdc0df2
--- /dev/null
+++ b/test/testcases/restful_api/test_memories_messages.py
@@ -0,0 +1,210 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +import time +import uuid + +import pytest + + +@pytest.fixture +def memory_cleanup(rest_client): + created_ids: list[str] = [] + + def _cleanup(): + cleanup_errors = [] + for memory_id in created_ids: + delete_res = rest_client.delete(f"/memories/{memory_id}") + if delete_res.status_code != 200: + cleanup_errors.append((memory_id, delete_res.status_code, delete_res.text)) + continue + delete_payload = delete_res.json() + if delete_payload["code"] not in (0, 404): + cleanup_errors.append((memory_id, delete_res.status_code, delete_payload)) + assert not cleanup_errors, f"Memory cleanup failed: {cleanup_errors}" + + yield created_ids + _cleanup() + + +@pytest.fixture +def create_memory_resource(rest_client, memory_cleanup): + def _create(name_prefix: str = "restful_memory") -> str: + payload = { + "name": f"{name_prefix}_{uuid.uuid4().hex[:8]}", + "memory_type": ["raw"], + "embd_id": "BAAI/bge-small-en-v1.5@Builtin", + "llm_id": "glm-4-flash@ZHIPU-AI", + } + res = rest_client.post("/memories", json=payload) + assert res.status_code == 200 + res_payload = res.json() + assert res_payload["code"] == 0, res_payload + memory_id = res_payload["data"]["id"] + memory_cleanup.append(memory_id) + return memory_id + + yield _create + + +def _add_message(rest_client, memory_id: str, user_input: str, agent_response: str) -> None: + add_res = rest_client.post( + "/messages", + json={ + "memory_id": [memory_id], + "agent_id": uuid.uuid4().hex, + "session_id": uuid.uuid4().hex, + "user_id": uuid.uuid4().hex, + "user_input": user_input, + "agent_response": agent_response, + }, + ) + assert add_res.status_code == 200 + add_payload = add_res.json() + assert add_payload["code"] == 0, add_payload + + +def _wait_for_memory_messages(rest_client, memory_id: str, timeout: float = 10, interval: float = 0.2) -> list[dict]: + deadline = time.time() + timeout + last_payload = None + while time.time() < deadline: + res = rest_client.get(f"/memories/{memory_id}") + if res.status_code == 200: 
+ payload = res.json() + last_payload = payload + if payload.get("code") == 0: + message_list = payload.get("data", {}).get("messages", {}).get("message_list", []) + if message_list: + return message_list + time.sleep(interval) + pytest.fail(f"Timed out waiting for memory messages: {last_payload}") + + +@pytest.mark.p1 +def test_memory_crud_cycle(rest_client, create_memory_resource): + memory_id = create_memory_resource("restful_memory_crud") + + list_res = rest_client.get("/memories") + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert any(item["id"] == memory_id for item in list_payload["data"]["memory_list"]), list_payload + + config_res = rest_client.get(f"/memories/{memory_id}/config") + assert config_res.status_code == 200 + config_payload = config_res.json() + assert config_payload["code"] == 0, config_payload + assert config_payload["data"]["id"] == memory_id, config_payload + + update_res = rest_client.put( + f"/memories/{memory_id}", + json={"name": f"updated_{uuid.uuid4().hex[:6]}", "permissions": "me"}, + ) + assert update_res.status_code == 200 + update_payload = update_res.json() + assert update_payload["code"] == 0, update_payload + + delete_res = rest_client.delete(f"/memories/{memory_id}") + assert delete_res.status_code == 200 + delete_payload = delete_res.json() + assert delete_payload["code"] == 0, delete_payload + + +@pytest.mark.p2 +def test_memory_create_missing_required_fields(rest_client): + res = rest_client.post("/memories", json={"name": "missing_models", "memory_type": ["raw"]}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + + +@pytest.mark.p1 +def test_messages_add_list_recent_content_update_forget(rest_client, create_memory_resource): + memory_id = create_memory_resource("restful_message_memory") + _add_message( + rest_client, + memory_id, + user_input="what is coriander?", + agent_response="coriander can 
refer to leaves or seeds", + ) + + message_list = _wait_for_memory_messages(rest_client, memory_id) + + message_id = message_list[0]["message_id"] + + recent_res = rest_client.get("/messages", params={"memory_id": memory_id, "limit": 10}) + assert recent_res.status_code == 200 + recent_payload = recent_res.json() + assert recent_payload["code"] == 0, recent_payload + assert any(item["message_id"] == message_id for item in recent_payload["data"]), recent_payload + + content_res = rest_client.get(f"/messages/{memory_id}:{message_id}/content") + assert content_res.status_code == 200 + content_payload = content_res.json() + assert content_payload["code"] == 0, content_payload + assert content_payload["data"]["content"], content_payload + + update_res = rest_client.put(f"/messages/{memory_id}:{message_id}", json={"status": False}) + assert update_res.status_code == 200 + update_payload = update_res.json() + assert update_payload["code"] == 0, update_payload + + forget_res = rest_client.delete(f"/messages/{memory_id}:{message_id}") + assert forget_res.status_code == 200 + forget_payload = forget_res.json() + assert forget_payload["code"] == 0, forget_payload + + +@pytest.mark.p2 +def test_message_status_validation_requires_boolean(rest_client, create_memory_resource): + memory_id = create_memory_resource("restful_message_status_validation") + _add_message(rest_client, memory_id, user_input="hello", agent_response="hello") + + message_id = _wait_for_memory_messages(rest_client, memory_id)[0]["message_id"] + + invalid_update = rest_client.put(f"/messages/{memory_id}:{message_id}", json={"status": "false"}) + assert invalid_update.status_code == 200 + invalid_payload = invalid_update.json() + assert invalid_payload["code"] == 101, invalid_payload + assert "Status must be a boolean." 
in invalid_payload["message"], invalid_payload + + +@pytest.mark.p2 +def test_messages_recent_requires_memory_ids(rest_client): + res = rest_client.get("/messages") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + assert "memory_ids is required" in payload["message"], payload + + +@pytest.mark.p2 +def test_message_search_route_contract(rest_client, create_memory_resource): + memory_id = create_memory_resource("restful_message_search") + _add_message( + rest_client, + memory_id, + user_input="what is pineapple?", + agent_response="pineapple is a tropical fruit", + ) + + _wait_for_memory_messages(rest_client, memory_id) + + res = rest_client.get("/messages/search", params={"memory_id": memory_id, "query": "pineapple", "top_n": 3}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert isinstance(payload["data"], list), payload diff --git a/test/testcases/restful_api/test_memory_messages.py b/test/testcases/restful_api/test_memory_messages.py new file mode 100644 index 0000000000..dcf5a3704f --- /dev/null +++ b/test/testcases/restful_api/test_memory_messages.py @@ -0,0 +1,165 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import uuid + +import pytest + + +def _memory_payload(name: str) -> dict: + return { + "name": name, + "memory_type": ["raw"], + "embd_id": "BAAI/bge-small-en-v1.5@Builtin", + "llm_id": "glm-4-flash@ZHIPU-AI", + } + + +def _create_memory(rest_client, name: str) -> dict: + res = rest_client.post("/memories", json=_memory_payload(name)) + assert res.status_code == 200 + payload = res.json() + if payload["code"] == 0: + return payload["data"] + + pytest.fail(f"Failed to create memory: {payload}") + + +@pytest.fixture +def memory_resource(rest_client): + memory = _create_memory(rest_client, f"restful_memory_{uuid.uuid4().hex[:8]}") + memory_id = memory["id"] + try: + yield memory + finally: + delete_res = rest_client.delete(f"/memories/{memory_id}") + assert delete_res.status_code == 200, delete_res.text + delete_payload = delete_res.json() + assert delete_payload["code"] in (0, 404), delete_payload + + +@pytest.mark.p2 +def test_memory_and_message_routes_require_auth(rest_client_noauth): + memory_res = rest_client_noauth.get("/memories") + assert memory_res.status_code == 401 + memory_payload = memory_res.json() + assert memory_payload["code"] == 401, memory_payload + + msg_list_res = rest_client_noauth.get("/messages") + assert msg_list_res.status_code == 401 + msg_list_payload = msg_list_res.json() + assert msg_list_payload["code"] == 401, msg_list_payload + + msg_search_res = rest_client_noauth.get("/messages/search") + assert msg_search_res.status_code == 401 + msg_search_payload = msg_search_res.json() + assert msg_search_payload["code"] == 401, msg_search_payload + + +@pytest.mark.p2 +def test_memory_crud_and_config(rest_client): + memory = _create_memory(rest_client, f"restful_memory_crud_{uuid.uuid4().hex[:8]}") + memory_id = memory["id"] + try: + config_res = rest_client.get(f"/memories/{memory_id}/config") + assert config_res.status_code == 200 + config_payload = config_res.json() + assert config_payload["code"] == 0, config_payload + assert 
config_payload["data"]["id"] == memory_id, config_payload + + list_res = rest_client.get("/memories", params={"keywords": memory["name"]}) + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert any(item["id"] == memory_id for item in list_payload["data"]["memory_list"]), list_payload + + update_res = rest_client.put(f"/memories/{memory_id}", json={"name": "restful_memory_updated"}) + assert update_res.status_code == 200 + update_payload = update_res.json() + assert update_payload["code"] == 0, update_payload + finally: + delete_res = rest_client.delete(f"/memories/{memory_id}") + assert delete_res.status_code == 200, delete_res.text + delete_payload = delete_res.json() + assert delete_payload["code"] in (0, 404), delete_payload + + +@pytest.mark.p2 +def test_memory_update_invalid_name(rest_client, memory_resource): + memory_id = memory_resource["id"] + res = rest_client.put(f"/memories/{memory_id}", json={"name": " "}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + assert "cannot be empty" in payload["message"], payload + + +@pytest.mark.p2 +def test_messages_list_and_search_validation_contracts(rest_client, memory_resource): + memory_id = memory_resource["id"] + + list_res = rest_client.get("/messages", params={"memory_id": memory_id, "limit": 10}) + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert isinstance(list_payload["data"], list), list_payload + + missing_memory_res = rest_client.get("/messages") + assert missing_memory_res.status_code == 200 + missing_memory_payload = missing_memory_res.json() + assert missing_memory_payload["code"] == 101, missing_memory_payload + assert "memory_ids is required" in missing_memory_payload["message"], missing_memory_payload + + search_res = rest_client.get("/messages/search", params={"memory_id": memory_id, "query": 
"coriander"}) + assert search_res.status_code == 200 + search_payload = search_res.json() + assert search_payload["code"] == 0, search_payload + assert isinstance(search_payload["data"], list), search_payload + + search_no_memory = rest_client.get("/messages/search", params={"query": "coriander"}) + assert search_no_memory.status_code == 200 + search_no_memory_payload = search_no_memory.json() + assert search_no_memory_payload["code"] == 0, search_no_memory_payload + assert isinstance(search_no_memory_payload["data"], list), search_no_memory_payload + + +@pytest.mark.p2 +def test_message_update_forget_and_content_error_contracts(rest_client, memory_resource): + memory_id = memory_resource["id"] + + invalid_status_res = rest_client.put( + f"/messages/{memory_id}:1", + json={"status": "false"}, + ) + assert invalid_status_res.status_code == 200 + invalid_status_payload = invalid_status_res.json() + assert invalid_status_payload["code"] == 101, invalid_status_payload + assert "Status must be a boolean" in invalid_status_payload["message"], invalid_status_payload + + missing_content_res = rest_client.get(f"/messages/{memory_id}:999999/content") + assert missing_content_res.status_code == 200 + missing_content_payload = missing_content_res.json() + assert missing_content_payload["code"] == 404, missing_content_payload + + invalid_memory_forget = rest_client.delete("/messages/missing_memory_id:1") + assert invalid_memory_forget.status_code == 200 + invalid_memory_forget_payload = invalid_memory_forget.json() + assert invalid_memory_forget_payload["code"] == 404, invalid_memory_forget_payload + + invalid_memory_update = rest_client.put("/messages/missing_memory_id:1", json={"status": False}) + assert invalid_memory_update.status_code == 200 + invalid_memory_update_payload = invalid_memory_update.json() + assert invalid_memory_update_payload["code"] == 404, invalid_memory_update_payload diff --git a/test/testcases/restful_api/test_openai_compatible.py 
b/test/testcases/restful_api/test_openai_compatible.py new file mode 100644 index 0000000000..d858576725 --- /dev/null +++ b/test/testcases/restful_api/test_openai_compatible.py @@ -0,0 +1,212 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json + +import pytest + + +def _sse_events(response_text: str) -> list[str]: + return [line[5:] for line in response_text.splitlines() if line.startswith("data:")] + + +@pytest.mark.p2 +@pytest.mark.parametrize( + "payload, expected_message", + [ + ( + { + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "extra_body": "invalid_extra_body", + }, + "extra_body must be an object.", + ), + ( + { + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "extra_body": {"reference_metadata": "invalid_reference_metadata"}, + }, + "reference_metadata must be an object.", + ), + ( + { + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "extra_body": {"reference_metadata": {"fields": "author"}}, + }, + "reference_metadata.fields must be an array.", + ), + ( + { + "model": "model", + "messages": [], + }, + "You have to provide messages.", + ), + ( + { + "model": "model", + "messages": [{"role": "assistant", "content": "hello"}], + }, + "The last content of this conversation is not from user.", + ), + ], +) +def test_openai_compatible_validation_payloads(rest_client, create_chat, payload, 
expected_message): + chat_id = create_chat("restful_openai_validation_chat") + res = rest_client.post(f"/openai/{chat_id}/chat/completions", json=payload) + assert res.status_code == 200 + data = res.json() + assert data["code"] != 0, data + assert expected_message in data.get("message", ""), data + + +@pytest.mark.p2 +def test_openai_compatible_metadata_condition_requires_object(rest_client, create_chat): + chat_id = create_chat("restful_openai_metadata_condition_chat") + res = rest_client.post( + f"/openai/{chat_id}/chat/completions", + json={ + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "extra_body": {"metadata_condition": "invalid"}, + }, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert "metadata_condition must be an object." in payload["message"], payload + + +@pytest.mark.p2 +def test_openai_compatible_invalid_chat(rest_client): + res = rest_client.post( + "/openai/invalid_chat_id/chat/completions", + json={ + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + }, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] != 0, payload + assert "don't own the chat" in payload["message"], payload + + +@pytest.mark.p2 +def test_openai_compatible_nonstream_shape(rest_client, create_chat): + chat_id = create_chat("restful_openai_nonstream_chat") + res = rest_client.post( + f"/openai/{chat_id}/chat/completions", + json={ + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + }, + timeout=60, + ) + assert res.status_code == 200 + payload = res.json() + + assert payload["object"] == "chat.completion", payload + assert isinstance(payload["choices"], list) and payload["choices"], payload + first_choice = payload["choices"][0] + assert first_choice.get("finish_reason") == "stop", payload + assert first_choice.get("message", {}).get("role") == "assistant", payload + assert 
"content" in first_choice.get("message", {}), payload + + usage = payload.get("usage", {}) + assert "prompt_tokens" in usage, usage + assert "completion_tokens" in usage, usage + assert "total_tokens" in usage, usage + assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"], usage + + +@pytest.mark.p2 +def test_openai_compatible_nonstream_with_reference_output_shape(rest_client, create_chat): + chat_id = create_chat("restful_openai_reference_chat") + res = rest_client.post( + f"/openai/{chat_id}/chat/completions", + json={ + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + "extra_body": { + "reference": True, + "reference_metadata": {"include": True, "fields": ["author"]}, + }, + }, + timeout=60, + ) + assert res.status_code == 200 + payload = res.json() + choice_msg = payload["choices"][0]["message"] + assert "reference" in choice_msg, payload + assert isinstance(choice_msg["reference"], list), payload + + +@pytest.mark.p2 +def test_openai_compatible_stream_shape_and_done_semantics(rest_client, create_chat): + chat_id = create_chat("restful_openai_stream_chat") + res = rest_client.post( + f"/openai/{chat_id}/chat/completions", + json={ + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": True, + "extra_body": {"reference": True}, + }, + timeout=60, + ) + assert res.status_code == 200 + content_type = res.headers.get("Content-Type", "") + assert "text/event-stream" in content_type, content_type + + events = _sse_events(res.text) + assert events, res.text + assert events[-1].strip() == "[DONE]", events[-1] + + json_events = [json.loads(evt) for evt in events if evt.strip() != "[DONE]"] + assert json_events, events + assert any(evt.get("object") == "chat.completion.chunk" for evt in json_events), json_events + assert any(evt.get("choices", [{}])[0].get("finish_reason") == "stop" for evt in json_events), json_events + + +@pytest.mark.p2 +def 
test_openai_compatible_reference_metadata_fields_filter_accepts_array(rest_client, create_chat): + chat_id = create_chat("restful_openai_reference_fields_array_chat") + res = rest_client.post( + f"/openai/{chat_id}/chat/completions", + json={ + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + "extra_body": { + "reference": True, + "reference_metadata": {"include": True, "fields": ["author", "year"]}, + }, + }, + timeout=60, + ) + assert res.status_code == 200 + payload = res.json() + assert payload.get("choices"), payload + choice_msg = payload["choices"][0]["message"] + assert "reference" in choice_msg, payload + assert isinstance(choice_msg["reference"], list), payload diff --git a/test/testcases/restful_api/test_plugin_tools.py b/test/testcases/restful_api/test_plugin_tools.py new file mode 100644 index 0000000000..c151394c29 --- /dev/null +++ b/test/testcases/restful_api/test_plugin_tools.py @@ -0,0 +1,92 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + +import pytest + + +@pytest.mark.p2 +def test_plugin_tools_requires_auth(rest_client_noauth): + res = rest_client_noauth.get("/plugin/tools") + assert res.status_code == 401 + payload = res.json() + assert payload["code"] == 401, payload + + +@pytest.mark.p2 +def test_plugin_tools_contract(rest_client): + res = rest_client.get("/plugin/tools") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert isinstance(payload["data"], list), payload + + +class _DummyManager: + def route(self, *_args, **_kwargs): + def decorator(func): + return func + + return decorator + + +def _load_plugin_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[3] + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + stub_apps = ModuleType("api.apps") + stub_apps.login_required = lambda func: func + monkeypatch.setitem(sys.modules, "api.apps", stub_apps) + + stub_plugin = ModuleType("agent.plugin") + + class _StubGlobalPluginManager: + @staticmethod + def get_llm_tools(): + return [] + + stub_plugin.GlobalPluginManager = _StubGlobalPluginManager + monkeypatch.setitem(sys.modules, "agent.plugin", stub_plugin) + + module_path = repo_root / "api" / "apps" / "restful_apis" / "plugin_api.py" + spec = importlib.util.spec_from_file_location("restful_plugin_api_unit", module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + spec.loader.exec_module(module) + return module + + +@pytest.mark.p2 +def test_plugin_tools_metadata_shape_unit(monkeypatch): + module = _load_plugin_module(monkeypatch) + + class _DummyTool: + def get_metadata(self): + return {"name": "dummy", "description": "test"} + + monkeypatch.setattr(module.GlobalPluginManager, "get_llm_tools", staticmethod(lambda: [_DummyTool()])) + res = 
module.llm_tools() + assert res["code"] == 0 + assert isinstance(res["data"], list) + assert res["data"][0]["name"] == "dummy" + assert res["data"][0]["description"] == "test" diff --git a/test/testcases/restful_api/test_retrieval.py b/test/testcases/restful_api/test_retrieval.py new file mode 100644 index 0000000000..bce37c4cdb --- /dev/null +++ b/test/testcases/restful_api/test_retrieval.py @@ -0,0 +1,109 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest + + +@pytest.mark.p1 +def test_dataset_search_rest_endpoint(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + res = rest_client.post( + f"/datasets/{dataset_id}/search", + json={"question": "test TXT file", "top_k": 5}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert "chunks" in payload["data"], payload + + +@pytest.mark.p2 +def test_multi_dataset_search_rest_endpoint(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + res = rest_client.post( + "/datasets/search", + json={"dataset_ids": [dataset_id], "question": "test TXT file", "top_k": 5}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert "chunks" in payload["data"], payload + + +@pytest.mark.p2 +def test_multi_dataset_search_with_metadata_filter(rest_client, ensure_parsed_document): + dataset_id, document_id = ensure_parsed_document() + meta_res = rest_client.patch( + f"/datasets/{dataset_id}/documents/metadatas", + json={ + "selector": {"document_ids": [document_id]}, + "updates": [{"key": "author", "value": "qa_batch2"}], + "deletes": [], + }, + ) + assert meta_res.status_code == 200 + meta_payload = meta_res.json() + assert meta_payload["code"] == 0, meta_payload + + res = rest_client.post( + "/datasets/search", + json={ + "dataset_ids": [dataset_id], + "question": "test TXT file", + "meta_data_filter": { + "method": "manual", + "logic": "and", + "manual": [{"key": "author", "op": "=", "value": "qa_batch2"}], + }, + }, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert "chunks" in payload["data"], payload + + +@pytest.mark.p2 +def test_retrieval_compatibility_endpoint(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + # /api/v1/retrieval is SDK compatibility endpoint from api/apps/sdk/doc.py. 
+ res = rest_client.post( + "/retrieval", + json={"dataset_ids": [dataset_id], "question": "test TXT file", "top_k": 5}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert "chunks" in payload["data"], payload + + +@pytest.mark.p2 +def test_retrieval_compatibility_requires_dataset_ids(rest_client): + res = rest_client.post("/retrieval", json={"question": "test"}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert payload["message"] == "`dataset_ids` is required.", payload + + +@pytest.mark.p2 +def test_retrieval_compatibility_requires_auth(rest_client_noauth): + res = rest_client_noauth.post("/retrieval", json={"question": "test", "dataset_ids": ["x"]}) + assert res.status_code == 401 + payload = res.json() + # token_required preserves legacy payload code/message while returning HTTP 401. + assert payload["code"] == 0, payload + assert payload["message"] == "`Authorization` can't be empty", payload diff --git a/test/testcases/restful_api/test_router_contracts.py b/test/testcases/restful_api/test_router_contracts.py new file mode 100644 index 0000000000..72683bad8f --- /dev/null +++ b/test/testcases/restful_api/test_router_contracts.py @@ -0,0 +1,28 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest +from configs import VERSION + + +@pytest.mark.p1 +def test_route_not_found_returns_json(rest_client_noauth): + res = rest_client_noauth.get("/__missing_route__") + assert res.status_code == 404 + payload = res.json() + assert payload["code"] == 404, payload + assert payload["error"] == "Not Found", payload + assert payload["message"] == f"Not Found: /api/{VERSION}/__missing_route__", payload diff --git a/test/testcases/restful_api/test_searches.py b/test/testcases/restful_api/test_searches.py new file mode 100644 index 0000000000..1a6923fb50 --- /dev/null +++ b/test/testcases/restful_api/test_searches.py @@ -0,0 +1,155 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import json +import uuid + +import pytest + + +@pytest.fixture +def search_resource(rest_client): + name = f"restful_search_{uuid.uuid4().hex[:8]}" + create_res = rest_client.post("/searches", json={"name": name, "description": "restful search"}) + assert create_res.status_code == 200 + create_payload = create_res.json() + assert create_payload["code"] == 0, create_payload + search_id = create_payload["data"]["search_id"] + + try: + yield search_id + finally: + delete_res = rest_client.delete(f"/searches/{search_id}") + assert delete_res.status_code == 200, delete_res.text + delete_payload = delete_res.json() + assert delete_payload["code"] in (0, 109), delete_payload + + +def _sse_events(response_text: str) -> list[str]: + return [line[5:] for line in response_text.splitlines() if line.startswith("data:")] + + +@pytest.mark.p2 +def test_search_routes_require_auth(rest_client_noauth): + create_res = rest_client_noauth.post("/searches", json={"name": "search_noauth"}) + assert create_res.status_code == 401 + create_payload = create_res.json() + assert create_payload["code"] == 401, create_payload + + list_res = rest_client_noauth.get("/searches") + assert list_res.status_code == 401 + list_payload = list_res.json() + assert list_payload["code"] == 401, list_payload + + +@pytest.mark.p2 +def test_search_crud_contract(rest_client, search_resource): + search_id = search_resource + + list_res = rest_client.get("/searches") + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert any(item.get("id") == search_id for item in list_payload["data"]["search_apps"]), list_payload + + detail_res = rest_client.get(f"/searches/{search_id}") + assert detail_res.status_code == 200 + detail_payload = detail_res.json() + assert detail_payload["code"] == 0, detail_payload + assert detail_payload["data"]["id"] == search_id, detail_payload + + new_name = f"search_updated_{uuid.uuid4().hex[:6]}" + update_res = 
rest_client.put( + f"/searches/{search_id}", + json={"name": new_name, "search_config": {"top_k": 3}}, + ) + assert update_res.status_code == 200 + update_payload = update_res.json() + assert update_payload["code"] == 0, update_payload + assert update_payload["data"]["name"] == new_name, update_payload + + +@pytest.mark.p2 +def test_search_create_invalid_name(rest_client): + res = rest_client.post("/searches", json={"name": ""}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert "empty" in payload["message"], payload + + +@pytest.mark.p2 +def test_search_update_invalid_search_id(rest_client): + res = rest_client.put( + "/searches/invalid_search_id", + json={"name": "invalid", "search_config": {}}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 109, payload + assert "No authorization" in payload["message"], payload + + +@pytest.mark.p2 +def test_search_completion_requires_question(rest_client, search_resource): + search_id = search_resource + + completion_res = rest_client.post(f"/searches/{search_id}/completion", json={}) + assert completion_res.status_code == 200 + completion_payload = completion_res.json() + assert completion_payload["code"] == 101, completion_payload + assert "required argument are missing: question" in completion_payload["message"], completion_payload + + completions_res = rest_client.post(f"/searches/{search_id}/completions", json={}) + assert completions_res.status_code == 200 + completions_payload = completions_res.json() + assert completions_payload["code"] == 101, completions_payload + assert "required argument are missing: question" in completions_payload["message"], completions_payload + + +@pytest.mark.p2 +def test_search_completion_requires_kb_ids(rest_client, search_resource): + search_id = search_resource + for path in ("completion", "completions"): + res = rest_client.post( + f"/searches/{search_id}/{path}", + json={"question": "what is 
coriander?"}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert "`kb_ids` is required" in payload["message"], payload + + +@pytest.mark.p2 +def test_search_completion_sse_shape_when_kb_ids_provided(rest_client, search_resource): + search_id = search_resource + # Even with kb_ids provided, runtime may return an error event in-stream, but + # contract remains SSE with JSON data lines and terminal boolean event. + res = rest_client.post( + f"/searches/{search_id}/completion", + json={"question": "what is coriander?", "kb_ids": ["nonexistent_dataset"]}, + timeout=60, + ) + assert res.status_code == 200 + content_type = res.headers.get("Content-Type", "") + assert "text/event-stream" in content_type, content_type + + events = _sse_events(res.text) + assert events, res.text + parsed = [json.loads(evt) for evt in events] + assert isinstance(parsed[0], dict), parsed + assert parsed[-1].get("data") is True, parsed[-1] diff --git a/test/testcases/restful_api/test_sessions.py b/test/testcases/restful_api/test_sessions.py new file mode 100644 index 0000000000..6a7a9de82c --- /dev/null +++ b/test/testcases/restful_api/test_sessions.py @@ -0,0 +1,219 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import json + +import pytest + + +def _sse_events(response_text: str) -> list[str]: + return [line[5:] for line in response_text.splitlines() if line.startswith("data:")] + + +@pytest.mark.p1 +def test_session_crud_cycle(rest_client, create_chat): + chat_id = create_chat("restful_session_crud_chat") + + create_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_a"}) + assert create_res.status_code == 200 + create_payload = create_res.json() + assert create_payload["code"] == 0, create_payload + session_id = create_payload["data"]["id"] + assert create_payload["data"]["chat_id"] == chat_id, create_payload + + list_res = rest_client.get(f"/chats/{chat_id}/sessions") + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert any(item["id"] == session_id for item in list_payload["data"]), list_payload + + get_res = rest_client.get(f"/chats/{chat_id}/sessions/{session_id}") + assert get_res.status_code == 200 + get_payload = get_res.json() + assert get_payload["code"] == 0, get_payload + assert get_payload["data"]["id"] == session_id, get_payload + + patch_res = rest_client.patch( + f"/chats/{chat_id}/sessions/{session_id}", + json={"name": "session_a_updated"}, + ) + assert patch_res.status_code == 200 + patch_payload = patch_res.json() + assert patch_payload["code"] == 0, patch_payload + assert patch_payload["data"]["name"] == "session_a_updated", patch_payload + + delete_res = rest_client.delete(f"/chats/{chat_id}/sessions", json={"ids": [session_id]}) + assert delete_res.status_code == 200 + delete_payload = delete_res.json() + assert delete_payload["code"] == 0, delete_payload + + list_after_delete = rest_client.get(f"/chats/{chat_id}/sessions") + assert list_after_delete.status_code == 200 + list_after_delete_payload = list_after_delete.json() + assert list_after_delete_payload["code"] == 0, list_after_delete_payload + assert all(item["id"] != session_id for item in 
list_after_delete_payload["data"]), list_after_delete_payload + + +@pytest.mark.p2 +def test_session_create_name_validation(rest_client, create_chat): + chat_id = create_chat("restful_session_name_validation_chat") + + res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": " "}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert "`name` can not be empty." in payload["message"], payload + + +@pytest.mark.p2 +def test_session_update_blocks_messages_and_reference(rest_client, create_chat): + chat_id = create_chat("restful_session_guard_chat") + create_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_guard"}) + assert create_res.status_code == 200 + create_payload = create_res.json() + assert create_payload["code"] == 0, create_payload + session_id = create_payload["data"]["id"] + + msg_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json={"messages": []}) + assert msg_res.status_code == 200 + msg_payload = msg_res.json() + assert msg_payload["code"] == 102, msg_payload + assert "`messages` cannot be changed." in msg_payload["message"], msg_payload + + ref_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json={"reference": []}) + assert ref_res.status_code == 200 + ref_payload = ref_res.json() + assert ref_payload["code"] == 102, ref_payload + assert "`reference` cannot be changed." in ref_payload["message"], ref_payload + + +@pytest.mark.p2 +def test_chat_recommendation_requires_question(rest_client): + res = rest_client.post("/chat/recommendation", json={}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + assert "required argument are missing: question" in payload["message"], payload + + +@pytest.mark.p2 +def test_related_questions_compatibility_requires_auth(rest_client_noauth): + # /api/v1/searchbots/related_questions is an SDK compatibility endpoint. 
+ res = rest_client_noauth.post( + "/searchbots/related_questions", + json={"question": "ragflow"}, + headers={"Authorization": "invalid"}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert "Authorization is not valid!" in payload["message"], payload + + +@pytest.mark.p2 +def test_chat_completion_nonstream_with_session(rest_client, create_chat): + chat_id = create_chat("restful_completion_nonstream_chat") + create_session_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_for_completion"}) + assert create_session_res.status_code == 200 + create_session_payload = create_session_res.json() + assert create_session_payload["code"] == 0, create_session_payload + session_id = create_session_payload["data"]["id"] + + completion_res = rest_client.post( + "/chat/completions", + json={ + "chat_id": chat_id, + "session_id": session_id, + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + }, + timeout=60, + ) + assert completion_res.status_code == 200 + completion_payload = completion_res.json() + assert completion_payload["code"] == 0, completion_payload + assert isinstance(completion_payload["data"], dict), completion_payload + assert completion_payload["data"]["session_id"] == session_id, completion_payload + assert "answer" in completion_payload["data"], completion_payload + assert "reference" in completion_payload["data"], completion_payload + + +@pytest.mark.p2 +def test_chat_completion_stream_events(rest_client, create_chat): + chat_id = create_chat("restful_completion_stream_chat") + stream_res = rest_client.post( + "/chat/completions", + json={ + "chat_id": chat_id, + "messages": [{"role": "user", "content": "hello"}], + "stream": True, + }, + timeout=60, + ) + assert stream_res.status_code == 200 + content_type = stream_res.headers.get("Content-Type", "") + assert "text/event-stream" in content_type, content_type + + events = _sse_events(stream_res.text) + assert 
events, stream_res.text + parsed_events = [] + for event in events: + parsed = json.loads(event) + parsed_events.append(parsed) + + assert any(evt.get("code") == 0 and isinstance(evt.get("data"), dict) for evt in parsed_events), parsed_events + assert parsed_events[-1].get("data") is True, parsed_events[-1] + + +@pytest.mark.p2 +def test_chat_completion_validation_errors(rest_client, create_chat): + chat_id = create_chat("restful_completion_validation_chat") + + missing_messages = rest_client.post( + "/chat/completions", + json={"chat_id": chat_id, "stream": False}, + ) + assert missing_messages.status_code == 200 + missing_messages_payload = missing_messages.json() + assert missing_messages_payload["code"] == 101, missing_messages_payload + assert "required argument are missing: messages" in missing_messages_payload["message"], missing_messages_payload + + missing_chat_for_session = rest_client.post( + "/chat/completions", + json={ + "session_id": "some_session", + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + }, + ) + assert missing_chat_for_session.status_code == 200 + missing_chat_for_session_payload = missing_chat_for_session.json() + assert missing_chat_for_session_payload["code"] == 102, missing_chat_for_session_payload + assert "`chat_id` is required when `session_id` is provided." in missing_chat_for_session_payload["message"], missing_chat_for_session_payload + + invalid_chat = rest_client.post( + "/chat/completions", + json={ + "chat_id": "invalid_chat_id", + "session_id": "invalid_session", + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + }, + ) + assert invalid_chat.status_code == 200 + invalid_chat_payload = invalid_chat.json() + assert invalid_chat_payload["code"] == 109, invalid_chat_payload + assert "No authorization." 
in invalid_chat_payload["message"], invalid_chat_payload diff --git a/test/testcases/restful_api/test_system.py b/test/testcases/restful_api/test_system.py new file mode 100644 index 0000000000..e5022f4861 --- /dev/null +++ b/test/testcases/restful_api/test_system.py @@ -0,0 +1,159 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + + +@pytest.mark.p1 +def test_system_ping(rest_client): + res = rest_client.get("/system/ping") + assert res.status_code == 200 + assert res.text == "pong" + + +@pytest.mark.p1 +def test_system_version(rest_client): + res = rest_client.get("/system/version") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert payload["data"], payload + + +@pytest.mark.p2 +def test_system_status_requires_auth(rest_client_noauth): + res = rest_client_noauth.get("/system/status") + assert res.status_code == 401 + payload = res.json() + assert payload["code"] == 401, payload + assert "Unauthorized" in payload["message"], payload + + +@pytest.mark.p2 +def test_system_status_contract(rest_client): + res = rest_client.get("/system/status") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + for key in ("doc_engine", "storage", "database", "redis"): + assert key in payload["data"], payload + + +@pytest.mark.p2 +def test_system_config_no_auth_required(rest_client_noauth): + res = 
rest_client_noauth.get("/system/config") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert "registerEnabled" in payload["data"], payload + assert "disablePasswordLogin" in payload["data"], payload + + +@pytest.mark.p2 +def test_system_healthz_contract(rest_client_noauth): + res = rest_client_noauth.get("/system/healthz") + assert res.status_code in (200, 500) + payload = res.json() + assert isinstance(payload, dict), payload + assert payload, payload + + +@pytest.mark.p2 +def test_system_tokens_auth_and_crud(rest_client, rest_client_noauth): + unauth_list = rest_client_noauth.get("/system/tokens") + assert unauth_list.status_code == 401 + unauth_list_payload = unauth_list.json() + assert unauth_list_payload["code"] == 401, unauth_list_payload + + create_res = rest_client.post("/system/tokens") + assert create_res.status_code == 200 + create_payload = create_res.json() + assert create_payload["code"] == 0, create_payload + token = create_payload["data"]["token"] + + list_res = rest_client.get("/system/tokens") + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert isinstance(list_payload["data"], list), list_payload + assert any(item.get("token") == token for item in list_payload["data"]), list_payload + + delete_res = rest_client.delete(f"/system/tokens/{token}") + assert delete_res.status_code == 200 + delete_payload = delete_res.json() + assert delete_payload["code"] == 0, delete_payload + assert delete_payload["data"] is True, delete_payload + + delete_missing = rest_client.delete("/system/tokens/missing_token") + assert delete_missing.status_code == 200 + delete_missing_payload = delete_missing.json() + assert delete_missing_payload["code"] == 0, delete_missing_payload + assert delete_missing_payload["data"] is True, delete_missing_payload + + +@pytest.mark.p2 +def test_system_stats_auth_and_shape(rest_client, rest_client_noauth): + 
unauth_res = rest_client_noauth.get("/system/stats") + assert unauth_res.status_code == 401 + unauth_payload = unauth_res.json() + assert unauth_payload["code"] == 401, unauth_payload + + res = rest_client.get("/system/stats") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + data = payload["data"] + for key in ("pv", "uv", "speed", "tokens", "round", "thumb_up"): + assert key in data, payload + assert isinstance(data[key], list), payload + + +@pytest.mark.p2 +def test_system_oceanbase_status_auth_contract(rest_client, rest_client_noauth): + unauth = rest_client_noauth.get("/system/oceanbase/status") + assert unauth.status_code == 401 + assert unauth.json()["code"] == 401 + + res = rest_client.get("/system/oceanbase/status") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] in (0, 500), payload + assert "data" in payload, payload + + +@pytest.mark.p2 +def test_system_log_config_routes_auth_and_validation(rest_client, rest_client_noauth): + unauth = rest_client_noauth.get("/system/config/log") + assert unauth.status_code == 401 + assert unauth.json()["code"] == 401 + + levels = rest_client.get("/system/config/log") + assert levels.status_code == 200 + levels_payload = levels.json() + assert levels_payload["code"] == 0, levels_payload + assert isinstance(levels_payload["data"], dict), levels_payload + + missing_body = rest_client.put("/system/config/log", json={}) + assert missing_body.status_code == 200 + missing_payload = missing_body.json() + assert missing_payload["code"] == 102, missing_payload + assert "pkg_name and level are required" in missing_payload["message"], missing_payload + + invalid_level = rest_client.put("/system/config/log", json={"pkg_name": "rag", "level": "NOT_A_LEVEL"}) + assert invalid_level.status_code == 200 + invalid_payload = invalid_level.json() + assert invalid_payload["code"] == 102, invalid_payload + assert "Invalid log level" in invalid_payload["message"], 
invalid_payload diff --git a/test/testcases/restful_api/test_task_routes.py b/test/testcases/restful_api/test_task_routes.py new file mode 100644 index 0000000000..13a0fc8a9d --- /dev/null +++ b/test/testcases/restful_api/test_task_routes.py @@ -0,0 +1,48 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + + +@pytest.mark.p2 +def test_task_routes_require_auth(rest_client_noauth): + cancel_res = rest_client_noauth.post("/tasks/missing_task/cancel") + assert cancel_res.status_code == 401 + cancel_payload = cancel_res.json() + assert cancel_payload["code"] == 401, cancel_payload + + patch_res = rest_client_noauth.patch("/tasks/missing_task", json={"action": "stop"}) + assert patch_res.status_code == 401 + patch_payload = patch_res.json() + assert patch_payload["code"] == 401, patch_payload + + +@pytest.mark.p2 +def test_patch_task_rejects_unsupported_action(rest_client): + res = rest_client.patch("/tasks/missing_task", json={"action": "pause"}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + assert "Only 'stop' is supported" in payload["message"], payload + + +@pytest.mark.p2 +def test_cancel_missing_task_sets_cancel_contract(rest_client): + res = rest_client.post("/tasks/missing_task/cancel") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert payload["data"] is True, payload diff --git 
a/test/testcases/restful_api/test_user_tenant_routes_unit.py b/test/testcases/restful_api/test_user_tenant_routes_unit.py new file mode 100644 index 0000000000..811a40654c --- /dev/null +++ b/test/testcases/restful_api/test_user_tenant_routes_unit.py @@ -0,0 +1,1382 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import base64 +import importlib.util +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest + + +class _DummyManager: + def route(self, *_args, **_kwargs): + def decorator(func): + return func + + return decorator + + +class _AwaitableValue: + def __init__(self, value): + self._value = value + + def __await__(self): + async def _co(): + return self._value + + return _co().__await__() + + +class _Field: + def __init__(self, name): + self.name = name + + def __eq__(self, other): + return (self.name, other) + + +class _Invitee: + def __init__(self, user_id="invitee-1", email="invitee@example.com"): + self.id = user_id + self.email = email + + def to_dict(self): + return { + "id": self.id, + "avatar": "avatar-url", + "email": self.email, + "nickname": "Invitee", + "password": "ignored", + } + + +def _run(coro): + return asyncio.run(coro) + + +def _set_request_json(monkeypatch, module, payload): + async def _request_json(): + return payload + + monkeypatch.setattr(module, "get_request_json", _request_json) + + +def 
_load_tenant_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[3] + + api_pkg = ModuleType("api") + api_pkg.__path__ = [str(repo_root / "api")] + monkeypatch.setitem(sys.modules, "api", api_pkg) + + apps_mod = ModuleType("api.apps") + apps_mod.__path__ = [str(repo_root / "api" / "apps")] + apps_mod.current_user = SimpleNamespace(id="tenant-1", email="owner@example.com") + apps_mod.login_required = lambda fn: fn + monkeypatch.setitem(sys.modules, "api.apps", apps_mod) + + db_mod = ModuleType("api.db") + db_mod.UserTenantRole = SimpleNamespace(NORMAL="normal", OWNER="owner", INVITE="invite") + monkeypatch.setitem(sys.modules, "api.db", db_mod) + + db_models_mod = ModuleType("api.db.db_models") + db_models_mod.UserTenant = type( + "UserTenant", + (), + { + "tenant_id": _Field("tenant_id"), + "user_id": _Field("user_id"), + }, + ) + monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod) + + services_pkg = ModuleType("api.db.services") + services_pkg.__path__ = [] + monkeypatch.setitem(sys.modules, "api.db.services", services_pkg) + + user_service_mod = ModuleType("api.db.services.user_service") + + class _UserTenantService: + @staticmethod + def get_by_tenant_id(_tenant_id): + return [] + + @staticmethod + def query(**_kwargs): + return [] + + @staticmethod + def save(**_kwargs): + return True + + @staticmethod + def filter_delete(_conditions): + return True + + @staticmethod + def get_tenants_by_user_id(_user_id): + return [] + + @staticmethod + def filter_update(_conditions, _payload): + return True + + class _UserService: + @staticmethod + def query(**_kwargs): + return [] + + @staticmethod + def get_by_id(_user_id): + return False, None + + user_service_mod.UserTenantService = _UserTenantService + user_service_mod.UserService = _UserService + monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod) + + api_utils_mod = ModuleType("api.utils.api_utils") + api_utils_mod.get_json_result = lambda data=None, 
message="", code=0: {"code": code, "message": message, "data": data} + api_utils_mod.get_data_error_result = lambda message="": {"code": 102, "message": message, "data": False} + api_utils_mod.server_error_response = lambda exc: {"code": 100, "message": repr(exc), "data": False} + api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn) + api_utils_mod.get_request_json = lambda: _AwaitableValue({}) + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + web_utils_mod = ModuleType("api.utils.web_utils") + web_utils_mod.send_invite_email = lambda **_kwargs: {"ok": True} + monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod) + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + constants_mod = ModuleType("common.constants") + constants_mod.RetCode = SimpleNamespace(AUTHENTICATION_ERROR=401, SERVER_ERROR=500, DATA_ERROR=102) + constants_mod.StatusEnum = SimpleNamespace(VALID=SimpleNamespace(value=1)) + monkeypatch.setitem(sys.modules, "common.constants", constants_mod) + + misc_utils_mod = ModuleType("common.misc_utils") + misc_utils_mod.get_uuid = lambda: "uuid-1" + monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod) + + time_utils_mod = ModuleType("common.time_utils") + time_utils_mod.delta_seconds = lambda _value: 0 + monkeypatch.setitem(sys.modules, "common.time_utils", time_utils_mod) + + settings_mod = ModuleType("common.settings") + settings_mod.MAIL_FRONTEND_URL = "https://frontend.example/invite" + monkeypatch.setitem(sys.modules, "common.settings", settings_mod) + common_pkg.settings = settings_mod + + sys.modules.pop("test_tenant_app_unit_module", None) + module_path = repo_root / "api" / "apps" / "restful_apis" / "tenant_api.py" + spec = importlib.util.spec_from_file_location("test_tenant_app_unit_module", module_path) + module = importlib.util.module_from_spec(spec) + module.manager = 
_DummyManager() + monkeypatch.setitem(sys.modules, "test_tenant_app_unit_module", module) + spec.loader.exec_module(module) + return module + + +@pytest.mark.p2 +def test_user_list_auth_success_exception_matrix_unit(monkeypatch): + module = _load_tenant_module(monkeypatch) + + module.current_user.id = "other-user" + res = module.user_list("tenant-1") + assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res + assert res["message"] == "No authorization.", res + + module.current_user.id = "tenant-1" + monkeypatch.setattr( + module.UserTenantService, + "get_by_tenant_id", + lambda _tenant_id: [{"id": "u1", "update_date": "2024-01-01 00:00:00"}], + ) + monkeypatch.setattr(module, "delta_seconds", lambda _value: 42) + res = module.user_list("tenant-1") + assert res["code"] == 0, res + assert res["data"][0]["delta_seconds"] == 42, res + + monkeypatch.setattr(module.UserTenantService, "get_by_tenant_id", lambda _tenant_id: (_ for _ in ()).throw(RuntimeError("list boom"))) + res = module.user_list("tenant-1") + assert res["code"] == 100, res + assert "list boom" in res["message"], res + + +@pytest.mark.p2 +def test_create_invite_role_and_email_failure_matrix_unit(monkeypatch): + module = _load_tenant_module(monkeypatch) + + module.current_user.id = "other-user" + _set_request_json(monkeypatch, module, {"email": "invitee@example.com"}) + res = _run(module.create("tenant-1")) + assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res + assert res["message"] == "No authorization.", res + + module.current_user.id = "tenant-1" + monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: []) + res = _run(module.create("tenant-1")) + assert res["message"] == "User not found.", res + + invitee = _Invitee() + monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [invitee]) + monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role=module.UserTenantRole.NORMAL)]) + res = _run(module.create("tenant-1")) + assert 
"already in the team." in res["message"], res + + monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role=module.UserTenantRole.OWNER)]) + res = _run(module.create("tenant-1")) + assert "owner of the team." in res["message"], res + + monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role="strange-role")]) + res = _run(module.create("tenant-1")) + assert "role: strange-role is invalid." in res["message"], res + + saved = [] + scheduled = [] + monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: []) + monkeypatch.setattr(module.UserTenantService, "save", lambda **kwargs: saved.append(kwargs) or True) + monkeypatch.setattr(module.UserService, "get_by_id", lambda _user_id: (True, SimpleNamespace(nickname="Inviter Nick"))) + monkeypatch.setattr(module, "send_invite_email", lambda **kwargs: kwargs) + monkeypatch.setattr(module.asyncio, "create_task", lambda payload: scheduled.append(payload) or SimpleNamespace()) + res = _run(module.create("tenant-1")) + assert res["code"] == 0, res + assert saved and saved[-1]["role"] == module.UserTenantRole.INVITE, saved + assert scheduled and scheduled[-1]["inviter"] == "Inviter Nick", scheduled + assert sorted(res["data"].keys()) == ["avatar", "email", "id", "nickname"], res + + monkeypatch.setattr(module.asyncio, "create_task", lambda _payload: (_ for _ in ()).throw(RuntimeError("send boom"))) + res = _run(module.create("tenant-1")) + assert res["code"] == module.RetCode.SERVER_ERROR, res + assert "Failed to send invite email." 
in res["message"], res + + +@pytest.mark.p2 +def test_rm_and_tenant_list_matrix_unit(monkeypatch): + module = _load_tenant_module(monkeypatch) + + module.current_user.id = "outsider" + _set_request_json(monkeypatch, module, {"user_id": "user-2"}) + res = _run(module.rm("tenant-1")) + assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res + assert res["message"] == "No authorization.", res + + module.current_user.id = "tenant-1" + deleted = [] + monkeypatch.setattr(module.UserTenantService, "filter_delete", lambda conditions: deleted.append(conditions) or True) + res = _run(module.rm("tenant-1")) + assert res["code"] == 0, res + assert res["data"] is True, res + assert deleted, "filter_delete should be called" + + monkeypatch.setattr(module.UserTenantService, "filter_delete", lambda _conditions: (_ for _ in ()).throw(RuntimeError("rm boom"))) + res = _run(module.rm("tenant-1")) + assert res["code"] == 100, res + assert "rm boom" in res["message"], res + + monkeypatch.setattr( + module.UserTenantService, + "get_tenants_by_user_id", + lambda _user_id: [{"id": "tenant-1", "update_date": "2024-01-01 00:00:00"}], + ) + monkeypatch.setattr(module, "delta_seconds", lambda _value: 9) + res = module.tenant_list() + assert res["code"] == 0, res + assert res["data"][0]["delta_seconds"] == 9, res + + monkeypatch.setattr(module.UserTenantService, "get_tenants_by_user_id", lambda _user_id: (_ for _ in ()).throw(RuntimeError("tenant boom"))) + res = module.tenant_list() + assert res["code"] == 100, res + assert "tenant boom" in res["message"], res + + +@pytest.mark.p2 +def test_agree_success_and_exception_unit(monkeypatch): + module = _load_tenant_module(monkeypatch) + + calls = [] + monkeypatch.setattr(module.UserTenantService, "filter_update", lambda conditions, payload: calls.append((conditions, payload)) or True) + res = module.agree("tenant-1") + assert res["code"] == 0, res + assert res["data"] is True, res + assert calls and calls[-1][1]["role"] == 
module.UserTenantRole.NORMAL

    # Failure path: the patched filter_update raises, and agree() is expected
    # to answer with code 100 and echo the error text.
    monkeypatch.setattr(module.UserTenantService, "filter_update", lambda _conditions, _payload: (_ for _ in ()).throw(RuntimeError("agree boom")))
    res = module.agree("tenant-1")
    assert res["code"] == 100, res
    assert "agree boom" in res["message"], res


class _Args(dict):
    """Dict-backed stand-in for ``request.args`` whose ``get`` accepts a ``type`` converter."""

    def get(self, key, default=None, type=None):
        """Return the value stored under *key*, optionally converted with *type*.

        If *type* is given and the conversion raises ``TypeError`` or
        ``ValueError``, *default* is returned instead.
        """
        value = super().get(key, default)
        if type is None:
            return value
        try:
            return type(value)
        except (TypeError, ValueError):
            return default


class _DummyResponse:
    """Minimal response object: carries the payload and an empty header dict."""

    def __init__(self, data):
        self.data = data
        self.headers = {}


class _DummyHTTPResponse:
    """Fake HTTP client response exposing only ``json()``."""

    def __init__(self, payload):
        self._payload = payload

    def json(self):
        """Return the payload given at construction, unchanged."""
        return self._payload


class _DummyRedis:
    """In-memory replacement for the Redis connection, backed by a plain dict."""

    def __init__(self):
        self.store = {}

    def get(self, key):
        """Return the stored value, or ``None`` when the key is absent."""
        return self.store.get(key)

    def set(self, key, value, _ttl=None):
        """Store *value* under *key*; the TTL argument is accepted and ignored."""
        self.store[key] = value

    def delete(self, key):
        """Remove *key* if present; a missing key is not an error."""
        self.store.pop(key, None)


class _DummyUser:
    """Lightweight user record that counts how often ``save()`` is called."""

    def __init__(self, user_id, email, *, password="stored-password", is_active="1", nickname="nick"):
        self.id = user_id
        self.email = email
        self.password = password
        self.is_active = is_active
        self.nickname = nickname
        self.access_token = ""
        # Incremented by save() so tests can assert on persistence calls.
        self.save_calls = 0

    def save(self):
        """Record one save call; nothing is persisted."""
        self.save_calls += 1

    def get_id(self):
        """Return the user id."""
        return self.id

    def to_json(self):
        """Return id, email and nickname as a dict."""
        return {"id": self.id, "email": self.email, "nickname": self.nickname}

    def to_dict(self):
        """Return id and email as a dict."""
        return {"id": self.id, "email": self.email}


def _set_request_args(monkeypatch, module, args=None):
    """Replace ``module.request`` with a namespace whose ``args`` is an ``_Args`` built from *args*."""
    monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args(args or {})))


# NOTE(review): presumably shadows a same-named fixture from a shared conftest
# so these unit tests need no live service -- confirm against the conftest.
@pytest.fixture(scope="session")
def auth():
    """Session-scoped fixture returning a fixed placeholder auth value."""
    return "unit-auth"


# NOTE(review): presumably neutralises a same-named autouse fixture defined
# elsewhere -- confirm against the conftest.
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    """Session-scoped autouse fixture that does nothing and returns ``None``."""
    return None


def _load_user_app(monkeypatch):
    """Load ``api/apps/restful_apis/user_api.py`` with its dependencies stubbed.

    Stub ``ModuleType`` objects are registered in ``sys.modules`` through
    ``monkeypatch.setitem`` (so they are undone at test teardown), then the
    target file is executed under the module name
    ``test_user_app_unit_module``.

    Returns:
        The freshly executed module object.
    """
    repo_root = Path(__file__).resolve().parents[3]

    # --- quart: session dict, request with empty args, async make_response ---
    quart_mod = ModuleType("quart")
    quart_mod.session = {}
    quart_mod.request = SimpleNamespace(args=_Args({}))

    async def _make_response(data):
        return _DummyResponse(data)

    quart_mod.make_response = _make_response
    # redirect() returns a dict so tests can assert on the target URL.
    quart_mod.redirect = lambda url: {"redirect": url}
    monkeypatch.setitem(sys.modules, "quart", quart_mod)

    # --- api / api.apps: package shells pointing at the real directories ---
    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)

    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.current_user = _DummyUser("current-user", "current@example.com")
    # login_required is an identity decorator here.
    apps_mod.login_required = lambda fn: fn
    apps_mod.login_user = lambda _user: True
    apps_mod.logout_user = lambda: True
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
    api_pkg.apps = apps_mod

    apps_auth_mod = ModuleType("api.apps.auth")
    apps_auth_mod.get_auth_client = lambda _config: SimpleNamespace(
        get_authorization_url=lambda state: f"https://oauth.example/{state}"
    )
    monkeypatch.setitem(sys.modules, "api.apps.auth", apps_auth_mod)

    # --- api.db and api.db.db_models ---
    db_mod = ModuleType("api.db")
    db_mod.FileType = SimpleNamespace(FOLDER=SimpleNamespace(value="folder"))
    db_mod.UserTenantRole = SimpleNamespace(OWNER="owner")
    monkeypatch.setitem(sys.modules, "api.db", db_mod)
    api_pkg.db = db_mod

    db_models_mod = ModuleType("api.db.db_models")

    class _DummyTenantLLMModel:
        # _Field is defined outside this chunk.
        tenant_id = _Field("tenant_id")

        @staticmethod
        def delete():
            # Chainable delete query: where() returns self, execute() returns 1.
            class _DeleteQuery:
                def where(self, *_args, **_kwargs):
                    return self

                def execute(self):
                    return 1

            return _DeleteQuery()

    db_models_mod.TenantLLM = _DummyTenantLLMModel
    monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)

    # --- api.db.services.*: service stubs with fixed return values ---
    services_pkg = ModuleType("api.db.services")
    services_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)

    file_service_mod = ModuleType("api.db.services.file_service")

    class _StubFileService:
        @staticmethod
        def insert(_data):
            return True

    file_service_mod.FileService = _StubFileService
    monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service_mod)

    llm_service_mod = ModuleType("api.db.services.llm_service")
    llm_service_mod.get_init_tenant_llm = lambda _user_id: []
    monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)

    tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")

    class _MockTableObject:
        # Turns every keyword argument into an instance attribute.
        def __init__(self, **kwargs):
            for key, value in kwargs.items():
                setattr(self, key, value)

        def to_dict(self):
            return {k: v for k, v in self.__dict__.items()}

    class _StubTenantLLMService:
        @staticmethod
        def insert_many(_payload):
            return True

        @staticmethod
        def get_api_key(tenant_id, model_name, model_type=None):
            # model_type is accepted but the returned row always says "chat".
            return _MockTableObject(
                id=1,
                tenant_id=tenant_id,
                llm_factory="",
                model_type="chat",
                llm_name=model_name,
                api_key="fake-api-key",
                api_base="https://api.example.com",
                max_tokens=8192,
                used_tokens=0,
                status=1
            )

    tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService
    monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)

    user_service_mod = ModuleType("api.db.services.user_service")

    class _StubTenantService:
        @staticmethod
        def insert(**_kwargs):
            return True

        @staticmethod
        def delete_by_id(_tenant_id):
            return True

        @staticmethod
        def get_by_id(_tenant_id):
            return True, SimpleNamespace(id=_tenant_id)

        @staticmethod
        def get_info_by(_user_id):
            return []

        @staticmethod
        def update_by_id(_tenant_id, _payload):
            return True

    class _StubUserService:
        @staticmethod
        def query(**_kwargs):
            return []

        @staticmethod
        def query_user(_email, _password):
            return None

        @staticmethod
        def query_user_by_email(**_kwargs):
            return []

        @staticmethod
        def save(**_kwargs):
            return True

        @staticmethod
        def delete_by_id(_user_id):
            return True

        @staticmethod
        def update_by_id(_user_id, _payload):
            return True

        @staticmethod
        def update_user_password(_user_id, _new_password):
            return True

    class _StubUserTenantService:
        @staticmethod
        def insert(**_kwargs):
            return True

        @staticmethod
        def query(**_kwargs):
            return []

        @staticmethod
        def delete_by_id(_user_tenant_id):
            return True

    user_service_mod.TenantService = _StubTenantService
    user_service_mod.UserService = _StubUserService
    user_service_mod.UserTenantService = _StubUserTenantService
    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)

    # --- api.utils.*: response builders return plain dicts ---
    api_utils_mod = ModuleType("api.utils.api_utils")

    async def _default_request_json():
        return {}

    def _get_json_result(code=0, message="success", data=None):
        return {"code": code, "message": message, "data": data}

    def _get_data_error_result(code=102, message="Sorry! Data missing!", data=None):
        return {"code": code, "message": message, "data": data}

    def _server_error_response(error):
        return {"code": 100, "message": repr(error)}

    def _validate_request(*_args, **_kwargs):
        # Decorator factory whose decorator returns the function unchanged.
        def _decorator(func):
            return func

        return _decorator

    api_utils_mod.get_request_json = _default_request_json
    api_utils_mod.get_json_result = _get_json_result
    api_utils_mod.get_data_error_result = _get_data_error_result
    api_utils_mod.server_error_response = _server_error_response
    api_utils_mod.validate_request = _validate_request
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)

    tenant_utils_mod = ModuleType("api.utils.tenant_utils")
    tenant_utils_mod.ensure_tenant_model_id_for_params = lambda _tenant_id, params: params
    monkeypatch.setitem(sys.modules, "api.utils.tenant_utils", tenant_utils_mod)

    crypt_mod = ModuleType("api.utils.crypt")
    # decrypt is the identity function by default; tests override it as needed.
    crypt_mod.decrypt = lambda value: value
    monkeypatch.setitem(sys.modules, "api.utils.crypt", crypt_mod)

    web_utils_mod = ModuleType("api.utils.web_utils")
    # _AwaitableValue is defined outside this chunk.
    web_utils_mod.send_email_html = lambda *_args, **_kwargs: _AwaitableValue(True)
    web_utils_mod.OTP_LENGTH = 6
    web_utils_mod.OTP_TTL_SECONDS = 600
    web_utils_mod.ATTEMPT_LIMIT = 5
    web_utils_mod.ATTEMPT_LOCK_SECONDS = 600
    web_utils_mod.RESEND_COOLDOWN_SECONDS = 60
    # Key order: code, attempts, last, lock.
    web_utils_mod.otp_keys = lambda email: (
        f"otp:{email}:code",
        f"otp:{email}:attempts",
        f"otp:{email}:last",
        f"otp:{email}:lock",
    )
    web_utils_mod.hash_code = lambda code, _salt: f"hash:{code}"
    web_utils_mod.captcha_key = lambda email: f"captcha:{email}"
    monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod)

    # --- common.*: settings, constants and utility stubs ---
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    settings_mod = ModuleType("common.settings")
    settings_mod.OAUTH_CONFIG = {
        "github": {"display_name": "GitHub", "icon": "gh"},
        "feishu": {"display_name": "Feishu", "icon": "fs"},
    }
    settings_mod.GITHUB_OAUTH = {"url": "https://github.example/oauth", "client_id": "cid", "secret_key": "sk"}
    settings_mod.FEISHU_OAUTH = {
        "app_access_token_url": "https://feishu.example/app_token",
        "user_access_token_url": "https://feishu.example/user_token",
        "app_id": "app-id",
        "app_secret": "app-secret",
        "grant_type": "authorization_code",
    }
    settings_mod.CHAT_MDL = "chat-mdl"
    settings_mod.EMBEDDING_MDL = "embd-mdl"
    settings_mod.ASR_MDL = "asr-mdl"
    settings_mod.PARSERS = []
    settings_mod.IMAGE2TEXT_MDL = "img-mdl"
    settings_mod.RERANK_MDL = "rerank-mdl"
    settings_mod.REGISTER_ENABLED = True
    monkeypatch.setitem(sys.modules, "common.settings", settings_mod)
    common_pkg.settings = settings_mod

    constants_mod = ModuleType("common.constants")
    constants_mod.RetCode = SimpleNamespace(
        AUTHENTICATION_ERROR=401,
        SERVER_ERROR=500,
        FORBIDDEN=403,
        EXCEPTION_ERROR=100,
        OPERATING_ERROR=300,
        ARGUMENT_ERROR=101,
        DATA_ERROR=102,
        NOT_EFFECTIVE=103,
        SUCCESS=0,
    )
    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)

    connection_utils_mod = ModuleType("common.connection_utils")

    async def _construct_response(data=None, auth=None, message=""):
        return {"code": 0, "message": message, "data": data, "auth": auth}

    connection_utils_mod.construct_response = _construct_response
    monkeypatch.setitem(sys.modules, "common.connection_utils", connection_utils_mod)

    # Time helpers return fixed values.
    time_utils_mod = ModuleType("common.time_utils")
    time_utils_mod.current_timestamp = lambda: 111
    time_utils_mod.datetime_format = lambda _dt: "2024-01-01 00:00:00"
    time_utils_mod.get_format_time = lambda: "2024-01-01 00:00:00"
    monkeypatch.setitem(sys.modules, "common.time_utils", time_utils_mod)

    misc_utils_mod = ModuleType("common.misc_utils")
    misc_utils_mod.download_img = lambda _url: "avatar"
    misc_utils_mod.get_uuid = lambda: "uuid-default"
    monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)

    http_client_mod = ModuleType("common.http_client")

    async def _async_request(_method, _url, **_kwargs):
        return _DummyHTTPResponse({})

    http_client_mod.async_request = _async_request
    monkeypatch.setitem(sys.modules, "common.http_client", http_client_mod)

    # --- rag.utils.redis_conn: a fresh in-memory Redis per load ---
    rag_pkg = ModuleType("rag")
    rag_pkg.__path__ = [str(repo_root / "rag")]
    monkeypatch.setitem(sys.modules, "rag", rag_pkg)

    rag_utils_pkg = ModuleType("rag.utils")
    rag_utils_pkg.__path__ = [str(repo_root / "rag" / "utils")]
    monkeypatch.setitem(sys.modules, "rag.utils", rag_utils_pkg)

    redis_mod = ModuleType("rag.utils.redis_conn")
    redis_mod.REDIS_CONN = _DummyRedis()
    monkeypatch.setitem(sys.modules, "rag.utils.redis_conn", redis_mod)

    # --- execute the real module file against the stubs above ---
    module_name = "test_user_app_unit_module"
    module_path = repo_root / "api" / "apps" / "restful_apis" / "user_api.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    # `manager` is set before exec_module runs the file; _DummyManager is
    # defined outside this chunk.
    # NOTE(review): presumably the file's route decorators use `manager` -- confirm.
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module


@pytest.mark.p2
def test_login_route_branch_matrix_unit(monkeypatch):
    """Drive ``login`` through each rejection branch, one after another.

    Order: empty body, unknown email, decrypt failure, disabled account,
    wrong password.
    """
    module = _load_user_app(monkeypatch)
    codes = module.RetCode

    # Empty request body.
    _set_request_json(monkeypatch, module, {})
    reply = _run(module.login())
    assert reply["code"] == codes.AUTHENTICATION_ERROR
    assert "Unauthorized" in reply["message"]

    # Email that no user is registered with.
    _set_request_json(monkeypatch, module, {"email": "unknown@example.com", "password": "enc"})
    monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [])
    reply = _run(module.login())
    assert reply["code"] == codes.AUTHENTICATION_ERROR
    assert "not registered" in reply["message"]

    # Known email, but decrypting the password blows up.
    def _raise_decrypt(_value):
        raise RuntimeError("decrypt explode")

    _set_request_json(monkeypatch, module, {"email": "known@example.com", "password": "enc"})
    monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [SimpleNamespace(email="known@example.com")])
    monkeypatch.setattr(module, "decrypt", _raise_decrypt)
    reply = _run(module.login())
    assert reply["code"] == codes.SERVER_ERROR
    assert "Fail to crypt password" in reply["message"]

    # Credentials match a user whose account is inactive.
    user_inactive = _DummyUser("u-inactive", "known@example.com", is_active="0")
    monkeypatch.setattr(module, "decrypt", lambda value: value)
    monkeypatch.setattr(module.UserService, "query_user", lambda _email, _password: user_inactive)
    reply = _run(module.login())
    assert reply["code"] == codes.FORBIDDEN
    assert "disabled" in reply["message"]

    # Credentials match nobody.
    monkeypatch.setattr(module.UserService, "query_user", lambda _email, _password: None)
    reply = _run(module.login())
    assert reply["code"] == codes.AUTHENTICATION_ERROR
    assert "do not match" in reply["message"]


@pytest.mark.p2
def test_login_channels_and_oauth_login_matrix_unit(monkeypatch):
    """Cover ``get_login_channels`` and ``oauth_login`` success and failure paths."""
    module = _load_user_app(monkeypatch)

    class _BrokenOAuthConfig:
        @staticmethod
        def items():
            raise RuntimeError("broken oauth config")

    class _AuthClient:
        @staticmethod
        def get_authorization_url(state):
            return f"https://oauth.example/{state}"

    def _github_only():
        # A fresh dict per use, exactly as the scenarios below require.
        return {"github": {"display_name": "GitHub", "icon": "gh"}}

    # Channel listing with a single configured provider.
    module.settings.OAUTH_CONFIG = _github_only()
    reply = _run(module.get_login_channels())
    assert reply["code"] == 0
    assert reply["data"][0]["channel"] == "github"

    # Channel listing when iterating the config raises.
    module.settings.OAUTH_CONFIG = _BrokenOAuthConfig()
    reply = _run(module.get_login_channels())
    assert reply["code"] == module.RetCode.EXCEPTION_ERROR
    assert "Load channels failure" in reply["message"]

    # Unknown channel name is rejected with ValueError.
    module.settings.OAUTH_CONFIG = _github_only()
    with pytest.raises(ValueError, match="Invalid channel name: missing"):
        _run(module.oauth_login("missing"))

    # Known channel: redirect to the provider and remember the state.
    module.session.clear()
    monkeypatch.setattr(module, "get_uuid", lambda: "state-123")
    monkeypatch.setattr(module, "get_auth_client", lambda _config: _AuthClient())
    reply = _run(module.oauth_login("github"))
    assert reply["redirect"] == "https://oauth.example/state-123"
    assert module.session["oauth_state"] == "state-123"


@pytest.mark.p2
def test_oauth_callback_matrix_unit(monkeypatch):
    """Exercise ``oauth_callback`` across validation, registration and login branches."""
    module = _load_user_app(monkeypatch)
    module.settings.OAUTH_CONFIG = {"github": {"display_name": "GitHub", "icon": "gh"}}

    class _SyncAuthClient:
        def __init__(self, token_info, user_info):
            self._token_info = token_info
            self._user_info = user_info

        def exchange_code_for_token(self, _code):
            return self._token_info

        def fetch_user_info(self, _token, id_token=None):
            _ = id_token
            return self._user_info

    class _AsyncAuthClient:
        def __init__(self, token_info, user_info):
            self._token_info = token_info
            self._user_info = user_info

        async def async_exchange_code_for_token(self, _code):
            return self._token_info

        async def async_fetch_user_info(self, _token, id_token=None):
            _ = id_token
            return self._user_info

    def _profile(email, nickname):
        # Provider-side user info; only email and nickname vary per scenario.
        return SimpleNamespace(email=email, avatar_url="http://img", nickname=nickname)

    def _install_client(client):
        monkeypatch.setattr(module, "get_auth_client", lambda _config: client)

    def _callback(channel, session_state, query):
        # Reset the session, optionally seed the expected state, set the
        # query string, then run the route.
        module.session.clear()
        if session_state is not None:
            module.session["oauth_state"] = session_state
        _set_request_args(monkeypatch, module, query)
        return _run(module.oauth_callback(channel))

    def _with_code(state):
        return {"state": state, "code": "code"}

    # Unknown channel.
    reply = _callback("missing", None, {"state": "x", "code": "c"})
    assert "Invalid channel name: missing" in reply["redirect"]

    sync_ok = _SyncAuthClient(
        token_info={"access_token": "token-sync", "id_token": "id-sync"},
        user_info=_profile("sync@example.com", "sync"),
    )
    _install_client(sync_ok)

    # State in the query does not match the session.
    reply = _callback("github", "expected", {"state": "wrong", "code": "code"})
    assert reply["redirect"] == "/?error=invalid_state"

    # State matches but no authorization code was supplied.
    reply = _callback("github", "ok-state", {"state": "ok-state"})
    assert reply["redirect"] == "/?error=missing_code"

    # Token exchange yields no access token.
    sync_missing_token = _SyncAuthClient(
        token_info={"id_token": "id-only"},
        user_info=_profile("sync@example.com", "sync"),
    )
    _install_client(sync_missing_token)
    reply = _callback("github", "token-state", _with_code("token-state"))
    assert reply["redirect"] == "/?error=token_failed"

    # Provider profile carries no email.
    sync_missing_email = _SyncAuthClient(
        token_info={"access_token": "token-sync", "id_token": "id-sync"},
        user_info=_profile(None, "sync"),
    )
    _install_client(sync_missing_email)
    reply = _callback("github", "email-state", _with_code("email-state"))
    assert reply["redirect"] == "/?error=email_missing"

    # New user: avatar download fails and registration returns nothing.
    def _raise_download(_url):
        raise RuntimeError("download explode")

    async_new_user = _AsyncAuthClient(
        token_info={"access_token": "token-async", "id_token": "id-async"},
        user_info=_profile("new@example.com", "new-user"),
    )
    rollback_calls = []
    _install_client(async_new_user)
    monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [])
    monkeypatch.setattr(module, "download_img", _raise_download)
    monkeypatch.setattr(module, "user_register", lambda _user_id, _user: None)
    monkeypatch.setattr(module, "rollback_user_registration", lambda user_id: rollback_calls.append(user_id))
    monkeypatch.setattr(module, "get_uuid", lambda: "new-user-id")
    reply = _callback("github", "new-user-state", _with_code("new-user-state"))
    assert "Failed to register new@example.com" in reply["redirect"]
    assert rollback_calls == ["new-user-id"]

    # New user: registration reports two users with the same email.
    monkeypatch.setattr(module, "download_img", lambda _url: "avatar")
    monkeypatch.setattr(
        module,
        "user_register",
        lambda _user_id, _user: [_DummyUser("dup-1", "new@example.com"), _DummyUser("dup-2", "new@example.com")],
    )
    rollback_calls.clear()
    reply = _callback("github", "dup-user-state", _with_code("dup-user-state"))
    assert "Same email: new@example.com exists!" in reply["redirect"]
    assert rollback_calls == ["new-user-id"]

    # New user: registration succeeds and the user is logged in.
    new_user = _DummyUser("new-user", "new@example.com")
    login_calls = []
    monkeypatch.setattr(module, "login_user", lambda user: login_calls.append(user))
    monkeypatch.setattr(module, "user_register", lambda _user_id, _user: [new_user])
    reply = _callback("github", "create-user-state", _with_code("create-user-state"))
    assert reply["redirect"] == "/?auth=new-user"
    assert login_calls and login_calls[-1] is new_user

    # Existing user whose account is inactive.
    async_existing_inactive = _AsyncAuthClient(
        token_info={"access_token": "token-existing", "id_token": "id-existing"},
        user_info=_profile("existing@example.com", "existing"),
    )
    _install_client(async_existing_inactive)
    inactive_user = _DummyUser("existing-user", "existing@example.com", is_active="0")
    monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [inactive_user])
    reply = _callback("github", "inactive-state", _with_code("inactive-state"))
    assert reply["redirect"] == "/?error=user_inactive"

    # Existing active user: new access token, one save, login recorded.
    async_existing_ok = _AsyncAuthClient(
        token_info={"access_token": "token-existing", "id_token": "id-existing"},
        user_info=_profile("existing@example.com", "existing"),
    )
    _install_client(async_existing_ok)
    existing_user = _DummyUser("existing-user", "existing@example.com")
    monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [existing_user])
    login_calls.clear()
    monkeypatch.setattr(module, "login_user", lambda user: login_calls.append(user))
    monkeypatch.setattr(module, "get_uuid", lambda: "existing-token")
    reply = _callback("github", "existing-state", _with_code("existing-state"))
    assert reply["redirect"] == "/?auth=existing-user"
    assert existing_user.access_token == "existing-token"
    assert existing_user.save_calls == 1
    assert login_calls and login_calls[-1] is existing_user


@pytest.mark.p2
def test_logout_setting_profile_matrix_unit(monkeypatch):
    """Cover ``log_out``, the ``setting_user`` branches and ``user_profile``."""
    module = _load_user_app(monkeypatch)

    current_user = _DummyUser("current-user", "current@example.com", password="stored-password")
    logout_calls = []
    monkeypatch.setattr(module, "current_user", current_user)
    monkeypatch.setattr(module.secrets, "token_hex", lambda _n: "abcdef")
    monkeypatch.setattr(module, "logout_user", lambda: logout_calls.append(True))

    # Logout invalidates the token, saves once and calls logout_user.
    reply = _run(module.log_out())
    assert reply["code"] == 0
    assert current_user.access_token == "INVALID_abcdef"
    assert current_user.save_calls == 1
    assert logout_calls == [True]

    # Password change with a wrong current password.
    _set_request_json(monkeypatch, module, {"password": "old-password", "new_password": "new-password"})
    monkeypatch.setattr(module, "decrypt", lambda value: value)
    monkeypatch.setattr(module, "check_password_hash", lambda _hashed, _plain: False)
    reply = _run(module.setting_user())
    assert reply["code"] == module.RetCode.AUTHENTICATION_ERROR
    assert "Password error" in reply["message"]

    # Successful update: email and status must not reach the update payload.
    recorded = {}

    def _update_by_id(user_id, payload):
        recorded["user_id"] = user_id
        recorded["payload"] = payload
        return True

    _set_request_json(
        monkeypatch,
        module,
        {
            "password": "old-password",
            "new_password": "new-password",
            "nickname": "neo",
            "email": "blocked@example.com",
            "status": "disabled",
            "theme": "dark",
        },
    )
    monkeypatch.setattr(module, "check_password_hash", lambda _hashed, _plain: True)
    monkeypatch.setattr(module, "decrypt", lambda value: f"dec:{value}")
    monkeypatch.setattr(module, "generate_password_hash", lambda value: f"hash:{value}")
    monkeypatch.setattr(module.UserService, "update_by_id", _update_by_id)
    reply = _run(module.setting_user())
    assert reply["code"] == 0
    assert reply["data"] is True
    assert recorded["user_id"] == "current-user"
    assert recorded["payload"]["password"] == "hash:dec:new-password"
    assert recorded["payload"]["nickname"] == "neo"
    assert recorded["payload"]["theme"] == "dark"
    assert "email" not in recorded["payload"]
    assert "status" not in recorded["payload"]

    # The update itself raises.
    def _raise_update(_user_id, _payload):
        raise RuntimeError("update explode")

    _set_request_json(monkeypatch, module, {"nickname": "neo"})
    monkeypatch.setattr(module.UserService, "update_by_id", _raise_update)
    reply = _run(module.setting_user())
    assert reply["code"] == module.RetCode.EXCEPTION_ERROR
    assert "Update failure" in reply["message"]

    # Profile returns the current user's dict form.
    reply = _run(module.user_profile())
    assert reply["code"] == 0
    assert reply["data"] == current_user.to_dict()


@pytest.mark.p2
def test_registration_helpers_and_register_route_matrix_unit(monkeypatch):
    """Cover ``rollback_user_registration``, ``user_register`` and the ``user_add`` route."""
    module = _load_user_app(monkeypatch)

    deleted = {"user": 0, "tenant": 0, "user_tenant": 0, "tenant_llm": 0}

    def _count_user_delete(_user_id):
        deleted["user"] += 1

    def _count_tenant_delete(_tenant_id):
        deleted["tenant"] += 1

    def _count_user_tenant_delete(_ut_id):
        deleted["user_tenant"] += 1

    class _DeleteQuery:
        def where(self, *_args, **_kwargs):
            return self

        def execute(self):
            deleted["tenant_llm"] += 1
            return 1

    # Rollback touches each of the four stores exactly once.
    monkeypatch.setattr(module.UserService, "delete_by_id", _count_user_delete)
    monkeypatch.setattr(module.TenantService, "delete_by_id", _count_tenant_delete)
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(id="ut-1")])
    monkeypatch.setattr(module.UserTenantService, "delete_by_id", _count_user_tenant_delete)
    monkeypatch.setattr(module.TenantLLM, "delete", lambda: _DeleteQuery())
    module.rollback_user_registration("user-1")
    assert deleted == {"user": 1, "tenant": 1, "user_tenant": 1, "tenant_llm": 1}, deleted

    # Rollback must not propagate when every step raises.
    def _user_delete_boom(_user_id):
        raise RuntimeError("u boom")

    def _tenant_delete_boom(_tenant_id):
        raise RuntimeError("t boom")

    def _user_tenant_query_boom(**_kwargs):
        raise RuntimeError("ut boom")

    class _RaisingDeleteQuery:
        def where(self, *_args, **_kwargs):
            raise RuntimeError("llm boom")

    monkeypatch.setattr(module.UserService, "delete_by_id", _user_delete_boom)
    monkeypatch.setattr(module.TenantService, "delete_by_id", _tenant_delete_boom)
    monkeypatch.setattr(module.UserTenantService, "query", _user_tenant_query_boom)
    monkeypatch.setattr(module.TenantLLM, "delete", lambda: _RaisingDeleteQuery())
    module.rollback_user_registration("user-2")

    # user_register returns None when saving the user fails.
    monkeypatch.setattr(module.UserService, "save", lambda **_kwargs: False)
    outcome = module.user_register(
        "new-user",
        {
            "nickname": "new",
            "email": "new@example.com",
            "password": "pw",
            "access_token": "tk",
            "login_channel": "password",
            "last_login_time": "2024-01-01 00:00:00",
            "is_superuser": False,
        },
    )
    assert outcome is None

    def _signup(email):
        return {"nickname": "neo", "email": email, "password": "enc"}

    # Registration disabled.
    monkeypatch.setattr(module.settings, "REGISTER_ENABLED", False)
    _set_request_json(monkeypatch, module, _signup("neo@example.com"))
    reply = _run(module.user_add())
    assert reply["code"] == module.RetCode.OPERATING_ERROR, reply
    assert "disabled" in reply["message"], reply

    # Malformed email address.
    monkeypatch.setattr(module.settings, "REGISTER_ENABLED", True)
    _set_request_json(monkeypatch, module, _signup("bad-email"))
    reply = _run(module.user_add())
    assert reply["code"] == module.RetCode.OPERATING_ERROR, reply
    assert "Invalid email address" in reply["message"], reply

    rollback_calls = []
    monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [])
    monkeypatch.setattr(module, "decrypt", lambda value: value)
    monkeypatch.setattr(module, "get_uuid", lambda: "new-user-id")
    monkeypatch.setattr(module, "rollback_user_registration", lambda user_id: rollback_calls.append(user_id))

    # user_register yields nothing: error plus rollback.
    _set_request_json(monkeypatch, module, _signup("neo@example.com"))
    monkeypatch.setattr(module, "user_register", lambda _user_id, _payload: None)
    reply = _run(module.user_add())
    assert reply["code"] == module.RetCode.EXCEPTION_ERROR, reply
    assert "Fail to register neo@example.com." in reply["message"], reply
    assert rollback_calls == ["new-user-id"], rollback_calls

    # user_register yields duplicates: error plus rollback.
    rollback_calls.clear()
    monkeypatch.setattr(
        module,
        "user_register",
        lambda _user_id, _payload: [_DummyUser("dup-1", "neo@example.com"), _DummyUser("dup-2", "neo@example.com")],
    )
    _set_request_json(monkeypatch, module, _signup("neo@example.com"))
    reply = _run(module.user_add())
    assert reply["code"] == module.RetCode.EXCEPTION_ERROR, reply
    assert "Same email: neo@example.com exists!" in reply["message"], reply
    assert rollback_calls == ["new-user-id"], rollback_calls


@pytest.mark.p2
def test_tenant_info_and_set_tenant_info_exception_matrix_unit(monkeypatch):
    """Cover the not-found and exception branches of ``tenant_info`` / ``set_tenant_info``."""
    module = _load_user_app(monkeypatch)

    # No tenant rows for the current user.
    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [])
    reply = _run(module.tenant_info())
    assert reply["code"] == module.RetCode.DATA_ERROR, reply
    assert "Tenant not found" in reply["message"], reply

    # Lookup raises.
    def _raise_tenant_info(_uid):
        raise RuntimeError("tenant info boom")

    monkeypatch.setattr(module.TenantService, "get_info_by", _raise_tenant_info)
    reply = _run(module.tenant_info())
    assert reply["code"] == module.RetCode.EXCEPTION_ERROR, reply
    assert "tenant info boom" in reply["message"], reply

    # Update raises.
    def _raise_update(_tenant_id, _payload):
        raise RuntimeError("tenant update boom")

    _set_request_json(
        monkeypatch,
        module,
        {"tenant_id": "tenant-1", "llm_id": "l", "embd_id": "e", "asr_id": "a", "img2txt_id": "i"},
    )
    monkeypatch.setattr(module.TenantService, "update_by_id", _raise_update)
    reply = _run(module.set_tenant_info())
    assert reply["code"] == module.RetCode.EXCEPTION_ERROR, reply
    assert "tenant update boom" in reply["message"], reply


@pytest.mark.p2
def test_forget_captcha_and_send_otp_matrix_unit(monkeypatch):
    """Cover ``forget_get_captcha`` and every branch of ``forget_send_otp``."""
    module = _load_user_app(monkeypatch)

    class _Headers(dict):
        def set(self, key, value):
            self[key] = value

    async def _make_response(data):
        return SimpleNamespace(data=data, headers=_Headers())

    monkeypatch.setattr(module, "make_response", _make_response)

    # Fake `captcha` package so no real image rendering is needed.
    class _ImageCaptcha:
        def __init__(self, **_kwargs):
            pass

        def generate(self, text):
            return SimpleNamespace(read=lambda: f"img:{text}".encode())

    captcha_pkg = ModuleType("captcha")
    captcha_image_mod = ModuleType("captcha.image")
    captcha_image_mod.ImageCaptcha = _ImageCaptcha
    monkeypatch.setitem(sys.modules, "captcha", captcha_pkg)
    monkeypatch.setitem(sys.modules, "captcha.image", captcha_image_mod)

    # --- forget_get_captcha ---
    _set_request_args(monkeypatch, module, {"email": ""})
    reply = _run(module.forget_get_captcha())
    assert reply["code"] == module.RetCode.ARGUMENT_ERROR, reply

    monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [])
    _set_request_args(monkeypatch, module, {"email": "nobody@example.com"})
    reply = _run(module.forget_get_captcha())
    assert reply["code"] == module.RetCode.DATA_ERROR, reply

    monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [_DummyUser("u1", "ok@example.com")])
    monkeypatch.setattr(module.secrets, "choice", lambda _allowed: "A")
    _set_request_args(monkeypatch, module, {"email": "ok@example.com"})
    reply = _run(module.forget_get_captcha())
    store = module.REDIS_CONN.store
    captcha_slot = module.captcha_key("ok@example.com")
    assert reply.data.startswith(b"img:"), reply
    assert reply.headers["Content-Type"] == "image/JPEG", reply.headers
    assert module.REDIS_CONN.get(captcha_slot), store

    # --- forget_send_otp ---
    # Missing email / captcha.
    _set_request_json(monkeypatch, module, {"email": "", "captcha": ""})
    reply = _run(module.forget_send_otp())
    assert reply["code"] == module.RetCode.ARGUMENT_ERROR, reply

    # Unknown user.
    monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [])
    _set_request_json(monkeypatch, module, {"email": "none@example.com", "captcha": "AAAA"})
    reply = _run(module.forget_send_otp())
    assert reply["code"] == module.RetCode.DATA_ERROR, reply

    # Known user, no captcha stored.
    monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [_DummyUser("u1", "ok@example.com")])
    _set_request_json(monkeypatch, module, {"email": "ok@example.com", "captcha": "AAAA"})
    store.pop(captcha_slot, None)
    reply = _run(module.forget_send_otp())
    assert reply["code"] == module.RetCode.NOT_EFFECTIVE, reply

    # Captcha stored but the submitted one differs.
    store[captcha_slot] = "ABCD"
    _set_request_json(monkeypatch, module, {"email": "ok@example.com", "captcha": "ZZZZ"})
    reply = _run(module.forget_send_otp())
    assert reply["code"] == module.RetCode.AUTHENTICATION_ERROR, reply

    # Correct captcha, last send 10 s before the frozen clock.
    monkeypatch.setattr(module.time, "time", lambda: 1000)
    k_code, k_attempts, k_last, k_lock = module.otp_keys("ok@example.com")
    store[captcha_slot] = "ABCD"
    store[k_last] = "990"
    _set_request_json(monkeypatch, module, {"email": "ok@example.com", "captcha": "ABCD"})
    reply = _run(module.forget_send_otp())
    assert reply["code"] == module.RetCode.NOT_EFFECTIVE, reply
    assert "wait" in reply["message"], reply

    # Unparseable last-send timestamp, and sending the email raises.
    async def _raise_send_email(*_args, **_kwargs):
        raise RuntimeError("send email boom")

    store[captcha_slot] = "ABCD"
    store[k_last] = "bad-timestamp"
    monkeypatch.setattr(module.secrets, "choice", lambda _allowed: "B")
    monkeypatch.setattr(module.os, "urandom", lambda _n: b"\x00" * 16)
    monkeypatch.setattr(module, "hash_code", lambda code, _salt: f"HASH_{code}")
    monkeypatch.setattr(module, "send_email_html", _raise_send_email)
    _set_request_json(monkeypatch, module, {"email": "ok@example.com", "captcha": "ABCD"})
    reply = _run(module.forget_send_otp())
    assert reply["code"] == module.RetCode.SERVER_ERROR, reply
    assert "failed to send email" in reply["message"], reply

    # Happy path: OTP stored, attempts reset to 0, no lock.
    async def _ok_send_email(*_args, **_kwargs):
        return True

    store[captcha_slot] = "ABCD"
    store.pop(k_last, None)
    monkeypatch.setattr(module, "send_email_html", _ok_send_email)
    _set_request_json(monkeypatch, module, {"email": "ok@example.com", "captcha": "ABCD"})
    reply = _run(module.forget_send_otp())
    assert reply["code"] == module.RetCode.SUCCESS, reply
    assert reply["data"] is True, reply
    assert module.REDIS_CONN.get(k_code), store
    assert module.REDIS_CONN.get(k_attempts) == 0, store
    assert module.REDIS_CONN.get(k_lock) is None, store


@pytest.mark.p2
def test_forget_verify_otp_matrix_unit(monkeypatch):
    """Cover every branch of ``forget_verify_otp`` against the in-memory Redis."""
    module = _load_user_app(monkeypatch)
    email = "ok@example.com"
    k_code, k_attempts, k_last, k_lock = module.otp_keys(email)
    salt = b"\x01" * 16
    store = module.REDIS_CONN.store
    monkeypatch.setattr(module, "hash_code", lambda code, _salt: f"HASH_{code}")

    def _submit(otp):
        _set_request_json(monkeypatch, module, {"email": email, "otp": otp})
        return _run(module.forget_verify_otp())

    # Missing arguments.
    _set_request_json(monkeypatch, module, {})
    reply = _run(module.forget_verify_otp())
    assert reply["code"] == module.RetCode.ARGUMENT_ERROR, reply

    # Unknown user.
    monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [])
    reply = _submit("ABCDEF")
    assert reply["code"] == module.RetCode.DATA_ERROR, reply

    # Account currently locked.
    monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [_DummyUser("u1", email)])
    store[k_lock] = "1"
    reply = _submit("ABCDEF")
    assert reply["code"] == module.RetCode.NOT_EFFECTIVE, reply
    store.pop(k_lock, None)

    # No OTP stored.
    store.pop(k_code, None)
    reply = _submit("ABCDEF")
    assert reply["code"] == module.RetCode.NOT_EFFECTIVE, reply

    # Stored OTP record is malformed.
    store[k_code] = "broken"
    reply = _submit("ABCDEF")
    assert reply["code"] == module.RetCode.EXCEPTION_ERROR, reply

    # Wrong OTP with an unparseable attempt counter: counter becomes 1.
    store[k_code] = f"HASH_CORRECT:{salt.hex()}"
    store[k_attempts] = "bad-int"
    reply = _submit("wrong")
    assert reply["code"] == module.RetCode.AUTHENTICATION_ERROR, reply
    assert module.REDIS_CONN.get(k_attempts) == 1, store

    # Wrong OTP on the last allowed attempt: lock is set.
    store[k_code] = f"HASH_CORRECT:{salt.hex()}"
    store[k_attempts] = str(module.ATTEMPT_LIMIT - 1)
    reply = _submit("wrong")
    assert reply["code"] == module.RetCode.AUTHENTICATION_ERROR, reply
    assert module.REDIS_CONN.get(k_lock) is not None, store
    store.pop(k_lock, None)

    # Correct OTP, but writing the "verified" marker fails.
    def _set_with_verified_fail(key, value, _ttl=None):
        if key == module._verified_key(email):
            raise RuntimeError("verified set boom")
        store[key] = value

    store[k_code] = f"HASH_ABCDEF:{salt.hex()}"
    store[k_attempts] = "0"
    store[k_last] = "1000"
    monkeypatch.setattr(module.REDIS_CONN, "set", _set_with_verified_fail)
    reply = _submit("abcdef")
    assert reply["code"] == module.RetCode.SERVER_ERROR, reply

    # Correct OTP, everything works: OTP keys cleared, marker written.
    def _plain_set(key, value, _ttl=None):
        store[key] = value

    monkeypatch.setattr(module.REDIS_CONN, "set", _plain_set)
    store[k_code] = f"HASH_ABCDEF:{salt.hex()}"
    store[k_attempts] = "0"
    store[k_last] = "1000"
    reply = _submit("abcdef")
    assert reply["code"] == module.RetCode.SUCCESS, reply
    assert module.REDIS_CONN.get(k_code) is None, store
    assert module.REDIS_CONN.get(k_attempts) is None, store
    assert module.REDIS_CONN.get(k_last) is None, store
    assert module.REDIS_CONN.get(k_lock) is None, store
    assert module.REDIS_CONN.get(module._verified_key(email)) == "1", store


@pytest.mark.p2
def test_forget_reset_password_matrix_unit(monkeypatch):
    """Cover every branch of ``forget_reset_password``."""
    module = _load_user_app(monkeypatch)
    email = "reset@example.com"
    v_key = module._verified_key(email)
    user = _DummyUser("u-reset", email, nickname="reset-user")
    pwd_a = base64.b64encode(b"new-password").decode()
    pwd_b = base64.b64encode(b"confirm-password").decode()
    pwd_same = base64.b64encode(b"same-password").decode()
    store = module.REDIS_CONN.store
    monkeypatch.setattr(module, "decrypt", lambda value: value)

    def _form(new_password, confirmation):
        return {"email": email, "new_password": new_password, "confirm_new_password": confirmation}

    # No "verified" marker in Redis.
    _set_request_json(monkeypatch, module, _form(pwd_same, pwd_same))
    store.pop(v_key, None)
    reply = _run(module.forget_reset_password())
    assert reply["code"] == module.RetCode.AUTHENTICATION_ERROR, reply

    # Empty passwords after decryption.
    store[v_key] = "1"
    monkeypatch.setattr(module, "decrypt", lambda _value: "")
    _set_request_json(monkeypatch, module, _form("", ""))
    reply = _run(module.forget_reset_password())
    assert reply["code"] == module.RetCode.ARGUMENT_ERROR, reply

    # Password and confirmation differ.
    monkeypatch.setattr(module, "decrypt", lambda value: value)
    store[v_key] = "1"
    _set_request_json(monkeypatch, module, _form(pwd_a, pwd_b))
    reply = _run(module.forget_reset_password())
    assert reply["code"] == module.RetCode.ARGUMENT_ERROR, reply
    assert "do not match" in reply["message"], reply

    # No user with that email.
    store[v_key] = "1"
    monkeypatch.setattr(module.UserService, "query_user_by_email", lambda **_kwargs: [])
    _set_request_json(monkeypatch, module, _form(pwd_same, pwd_same))
    reply = _run(module.forget_reset_password())
    assert reply["code"] == module.RetCode.DATA_ERROR, reply

    # Updating the password raises.
    def _raise_update_password(_user_id, _new_pwd):
        raise RuntimeError("reset boom")

    store[v_key] = "1"
    monkeypatch.setattr(module.UserService, "query_user_by_email", lambda **_kwargs: [user])
    monkeypatch.setattr(module.UserService, "update_user_password", _raise_update_password)
    _set_request_json(monkeypatch, module, _form(pwd_same, pwd_same))
    reply = _run(module.forget_reset_password())
    assert reply["code"] == module.RetCode.EXCEPTION_ERROR, reply

    # Update succeeds; a failing Redis delete must not fail the request.
    def _delete_boom(_key):
        raise RuntimeError("delete boom")

    store[v_key] = "1"
    monkeypatch.setattr(module.UserService, "update_user_password", lambda _user_id, _new_pwd: True)
    monkeypatch.setattr(module.REDIS_CONN, "delete", _delete_boom)
    _set_request_json(monkeypatch, module, _form(pwd_same, pwd_same))
    reply = _run(module.forget_reset_password())
    assert reply["code"] == module.RetCode.SUCCESS, reply
    assert reply["auth"] == user.get_id(), reply

    # Full success: the "verified" marker is removed.
    monkeypatch.setattr(module.REDIS_CONN, "delete", lambda key: module.REDIS_CONN.store.pop(key, None))
    store[v_key] = "1"
    _set_request_json(monkeypatch, module, _form(pwd_same, pwd_same))
    reply = _run(module.forget_reset_password())
    assert reply["code"] == module.RetCode.SUCCESS, reply
    assert reply["auth"] == user.get_id(), reply
    assert module.REDIS_CONN.get(v_key) is None, store