From 09e1fd290a23184706ddcb7d726bbaa93d3059ac Mon Sep 17 00:00:00 2001
From: Idriss Sbaaoui <112825897+6ba3i@users.noreply.github.com>
Date: Wed, 13 May 2026 15:07:23 +0800
Subject: [PATCH] Chore: migrate tests to restful api (#14871)
### What problem does this PR solve?
add new testing suite for the new restful api endpoints meant to replace
http and web api tests
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- [x] Other (please describe): test
---
.github/workflows/tests.yml | 28 +-
rag/utils/infinity_conn.py | 5 +-
test/testcases/restful_api/conftest.py | 163 ++
.../testcases/restful_api/helpers/__init__.py | 1 +
test/testcases/restful_api/helpers/client.py | 85 +
test/testcases/restful_api/test_agents.py | 333 ++++
test/testcases/restful_api/test_chats.py | 123 ++
test/testcases/restful_api/test_chunks.py | 124 ++
.../restful_api/test_connector_routes_unit.py | 718 +++++++++
test/testcases/restful_api/test_datasets.py | 335 ++++
.../restful_api/test_document_raw_routes.py | 43 +
test/testcases/restful_api/test_documents.py | 122 ++
.../restful_api/test_file_routes_unit.py | 632 ++++++++
.../restful_api/test_langfuse_routes.py | 37 +
.../restful_api/test_mcp_routes_unit.py | 745 +++++++++
.../restful_api/test_memories_messages.py | 210 +++
.../restful_api/test_memory_messages.py | 165 ++
.../restful_api/test_openai_compatible.py | 212 +++
.../restful_api/test_plugin_tools.py | 92 ++
test/testcases/restful_api/test_retrieval.py | 109 ++
.../restful_api/test_router_contracts.py | 28 +
test/testcases/restful_api/test_searches.py | 155 ++
test/testcases/restful_api/test_sessions.py | 219 +++
test/testcases/restful_api/test_system.py | 159 ++
.../testcases/restful_api/test_task_routes.py | 48 +
.../test_user_tenant_routes_unit.py | 1382 +++++++++++++++++
26 files changed, 6249 insertions(+), 24 deletions(-)
create mode 100644 test/testcases/restful_api/conftest.py
create mode 100644 test/testcases/restful_api/helpers/__init__.py
create mode 100644 test/testcases/restful_api/helpers/client.py
create mode 100644 test/testcases/restful_api/test_agents.py
create mode 100644 test/testcases/restful_api/test_chats.py
create mode 100644 test/testcases/restful_api/test_chunks.py
create mode 100644 test/testcases/restful_api/test_connector_routes_unit.py
create mode 100644 test/testcases/restful_api/test_datasets.py
create mode 100644 test/testcases/restful_api/test_document_raw_routes.py
create mode 100644 test/testcases/restful_api/test_documents.py
create mode 100644 test/testcases/restful_api/test_file_routes_unit.py
create mode 100644 test/testcases/restful_api/test_langfuse_routes.py
create mode 100644 test/testcases/restful_api/test_mcp_routes_unit.py
create mode 100644 test/testcases/restful_api/test_memories_messages.py
create mode 100644 test/testcases/restful_api/test_memory_messages.py
create mode 100644 test/testcases/restful_api/test_openai_compatible.py
create mode 100644 test/testcases/restful_api/test_plugin_tools.py
create mode 100644 test/testcases/restful_api/test_retrieval.py
create mode 100644 test/testcases/restful_api/test_router_contracts.py
create mode 100644 test/testcases/restful_api/test_searches.py
create mode 100644 test/testcases/restful_api/test_sessions.py
create mode 100644 test/testcases/restful_api/test_system.py
create mode 100644 test/testcases/restful_api/test_task_routes.py
create mode 100644 test/testcases/restful_api/test_user_tenant_routes_unit.py
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index fc4233504b..2ff30e628f 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -240,23 +240,14 @@ jobs:
echo "Start to run test sdk on Infinity"
source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} --junitxml=pytest-infinity-sdk.xml --cov=sdk/python/ragflow_sdk --cov-branch --cov-report=xml:coverage-infinity-sdk.xml test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log
- - name: Run web api tests against Infinity
+ - name: Run New RESTFUL api tests against Infinity
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null 2>&1; do
echo "Waiting for service to be available... (last exit code: $?)"
sleep 5
done
- source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api/test_chunk_feedback 2>&1 | tee infinity_web_api_test.log
-
- - name: Run http api tests against Infinity
- run: |
- export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
- until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null 2>&1; do
- echo "Waiting for service to be available... (last exit code: $?)"
- sleep 5
- done
- source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log
+ source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/restful_api 2>&1 | tee infinity_restful_api_test.log
- name: RAGFlow CLI retrieval test Infinity
env:
@@ -432,24 +423,15 @@ jobs:
done
echo "Start to run test sdk on Elasticsearch"
source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} --junitxml=pytest-infinity-sdk.xml --cov=sdk/python/ragflow_sdk --cov-branch --cov-report=xml:coverage-es-sdk.xml test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log
-
- - name: Run web api tests against Elasticsearch
+
+ - name: Run New RESTFUL api tests against Elasticsearch
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null 2>&1; do
echo "Waiting for service to be available... (last exit code: $?)"
sleep 5
done
- source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api 2>&1 | tee es_web_api_test.log
-
- - name: Run http api tests against Elasticsearch
- run: |
- export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
- until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null 2>&1; do
- echo "Waiting for service to be available... (last exit code: $?)"
- sleep 5
- done
- source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee es_http_api_test.log
+ source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/restful_api 2>&1 | tee es_restful_api_test.log
- name: RAGFlow CLI retrieval test Elasticsearch
env:
diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py
index 45290c520d..7ffd9f13d4 100644
--- a/rag/utils/infinity_conn.py
+++ b/rag/utils/infinity_conn.py
@@ -318,7 +318,10 @@ class InfinityConnection(InfinityConnectionBase):
"authors_sm_tks"]:
fields.add(field)
res_fields = self.get_fields(res, list(fields))
- return res_fields.get(chunk_id, None)
+ chunk = res_fields.get(chunk_id, None)
+ if chunk is not None:
+ chunk["id"] = chunk_id
+ return chunk
def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]:
'''
diff --git a/test/testcases/restful_api/conftest.py b/test/testcases/restful_api/conftest.py
new file mode 100644
index 0000000000..b24f0bda45
--- /dev/null
+++ b/test/testcases/restful_api/conftest.py
@@ -0,0 +1,163 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+
+from libs.auth import RAGFlowHttpApiAuth
+from test.testcases.restful_api.helpers.client import RestClient
+from utils.file_utils import create_txt_file
+from utils import wait_for
+
+
+@pytest.fixture(scope="session")
+def RestApiAuth(token):
+ return RAGFlowHttpApiAuth(token)
+
+
+@pytest.fixture(scope="session")
+def rest_client(token):
+ return RestClient(token=token)
+
+
+@pytest.fixture(scope="session")
+def rest_client_noauth():
+ return RestClient(token=None)
+
+
+@pytest.fixture
+def clear_datasets(rest_client):
+ def _cleanup():
+ res = rest_client.delete("/datasets", json={"ids": None, "delete_all": True})
+ assert res.status_code == 200, res.text
+ payload = res.json()
+ assert payload["code"] in (0, 102), payload
+
+ yield
+ _cleanup()
+
+
+@pytest.fixture
+def clear_chats(rest_client):
+ def _cleanup():
+ res = rest_client.delete("/chats", json={"ids": None, "delete_all": True})
+ assert res.status_code == 200, res.text
+ payload = res.json()
+ assert payload["code"] in (0, 102), payload
+
+ yield
+ _cleanup()
+
+
+@pytest.fixture
+def create_dataset(rest_client, clear_datasets):
+ created_ids: list[str] = []
+
+ def _create(name: str = "restful_dataset") -> str:
+ res = rest_client.post("/datasets", json={"name": name})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ dataset_id = payload["data"]["id"]
+ created_ids.append(dataset_id)
+ return dataset_id
+
+ yield _create
+
+ if created_ids:
+ res = rest_client.delete("/datasets", json={"ids": created_ids})
+ assert res.status_code == 200
+ payload = res.json()
+ # Dataset may already be removed by test logic/cleanup.
+ assert payload["code"] in (0, 102), payload
+
+
+@pytest.fixture
+def create_chat(rest_client, clear_chats):
+ created_ids: list[str] = []
+
+ def _create(name: str = "restful_chat") -> str:
+ res = rest_client.post("/chats", json={"name": name, "dataset_ids": []})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ chat_id = payload["data"]["id"]
+ created_ids.append(chat_id)
+ return chat_id
+
+ yield _create
+
+ if created_ids:
+ res = rest_client.delete("/chats", json={"ids": created_ids})
+ assert res.status_code == 200, res.text
+ payload = res.json()
+ assert payload["code"] in (0, 102), payload
+
+
+@pytest.fixture
+def create_document(rest_client, create_dataset, tmp_path):
+ created_docs: list[tuple[str, str]] = []
+
+ def _create(name: str = "restful_doc.txt") -> tuple[str, str]:
+ dataset_id = create_dataset("dataset_for_doc")
+ fp = create_txt_file(tmp_path / name)
+ with fp.open("rb") as file_obj:
+ files = [("file", (fp.name, file_obj))]
+ res = rest_client.post(f"/datasets/{dataset_id}/documents", files=files)
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ document_id = payload["data"][0]["id"]
+ created_docs.append((dataset_id, document_id))
+ return dataset_id, document_id
+
+ yield _create
+
+ for dataset_id, document_id in created_docs:
+ res = rest_client.delete(f"/datasets/{dataset_id}/documents", json={"ids": [document_id]})
+ assert res.status_code == 200, res.text
+ payload = res.json()
+ assert payload["code"] in (0, 102), payload
+
+
+@wait_for(60, 1, "Document parsing timeout in RESTful batch2 tests")
+def _parsed(rest_client: RestClient, dataset_id: str, document_id: str):
+ res = rest_client.get(f"/datasets/{dataset_id}/documents", params={"id": document_id})
+ if res.status_code != 200:
+ return False
+ payload = res.json()
+ if payload["code"] != 0:
+ return False
+ docs = payload["data"]["docs"]
+ if not docs:
+ return False
+ return docs[0].get("run") == "DONE"
+
+
+@pytest.fixture
+def ensure_parsed_document(rest_client, create_document):
+ def _ensure() -> tuple[str, str]:
+ dataset_id, document_id = create_document()
+ res = rest_client.post(
+ f"/datasets/{dataset_id}/documents/parse",
+ json={"document_ids": [document_id]},
+ )
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ _parsed(rest_client, dataset_id, document_id)
+ return dataset_id, document_id
+
+ return _ensure
diff --git a/test/testcases/restful_api/helpers/__init__.py b/test/testcases/restful_api/helpers/__init__.py
new file mode 100644
index 0000000000..117dea3cf0
--- /dev/null
+++ b/test/testcases/restful_api/helpers/__init__.py
@@ -0,0 +1 @@
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
diff --git a/test/testcases/restful_api/helpers/client.py b/test/testcases/restful_api/helpers/client.py
new file mode 100644
index 0000000000..8c0a198fc2
--- /dev/null
+++ b/test/testcases/restful_api/helpers/client.py
@@ -0,0 +1,85 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from dataclasses import dataclass
+from typing import Any
+
+import requests
+from configs import HOST_ADDRESS, VERSION
+
+
+@dataclass
+class RestClient:
+ token: str | None = None
+ timeout: int = 30
+
+ @property
+ def api_root(self) -> str:
+ return f"{HOST_ADDRESS}/api/{VERSION}"
+
+ def _headers(self, headers: dict[str, str] | None = None) -> dict[str, str]:
+ merged: dict[str, str] = {"Content-Type": "application/json"}
+ if headers:
+ merged.update(headers)
+ if self.token and "Authorization" not in merged:
+ merged["Authorization"] = f"Bearer {self.token}"
+ return merged
+
+ def request(
+ self,
+ method: str,
+ path: str,
+ *,
+ headers: dict[str, str] | None = None,
+ params: dict[str, Any] | None = None,
+ json: dict[str, Any] | None = None,
+ data: Any = None,
+ files: Any = None,
+ **request_kwargs: Any,
+ ) -> requests.Response:
+ req_headers = self._headers(headers)
+ if files is not None:
+ # requests sets multipart boundary automatically.
+ req_headers.pop("Content-Type", None)
+
+ timeout = request_kwargs.pop("timeout", self.timeout)
+ normalized_path = f"/{path.lstrip('/')}" if path else "/"
+ return requests.request(
+ method=method,
+ url=f"{self.api_root}{normalized_path}",
+ headers=req_headers,
+ params=params,
+ json=json,
+ data=data,
+ files=files,
+ timeout=timeout,
+ **request_kwargs,
+ )
+
+ def get(self, path: str, **kwargs) -> requests.Response:
+ return self.request("GET", path, **kwargs)
+
+ def post(self, path: str, **kwargs) -> requests.Response:
+ return self.request("POST", path, **kwargs)
+
+ def delete(self, path: str, **kwargs) -> requests.Response:
+ return self.request("DELETE", path, **kwargs)
+
+ def put(self, path: str, **kwargs) -> requests.Response:
+ return self.request("PUT", path, **kwargs)
+
+ def patch(self, path: str, **kwargs) -> requests.Response:
+ return self.request("PATCH", path, **kwargs)
diff --git a/test/testcases/restful_api/test_agents.py b/test/testcases/restful_api/test_agents.py
new file mode 100644
index 0000000000..9748b9fb96
--- /dev/null
+++ b/test/testcases/restful_api/test_agents.py
@@ -0,0 +1,333 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+
+import pytest
+
+
+MINIMAL_DSL = {
+ "components": {
+ "begin": {
+ "obj": {"component_name": "Begin", "params": {}},
+ "downstream": ["message"],
+ "upstream": [],
+ },
+ "message": {
+ "obj": {"component_name": "Message", "params": {"content": ["{sys.query}"]}},
+ "downstream": [],
+ "upstream": ["begin"],
+ },
+ },
+ "history": [],
+ "retrieval": [],
+ "path": [],
+ "globals": {
+ "sys.query": "",
+ "sys.user_id": "",
+ "sys.conversation_turns": 0,
+ "sys.files": [],
+ },
+ "variables": {},
+}
+
+
+def _sse_events(response_text: str) -> list[str]:
+ return [line[5:] for line in response_text.splitlines() if line.startswith("data:")]
+
+
+@pytest.fixture
+def create_agent_resource(rest_client):
+ created_agent_ids: list[str] = []
+
+ def _create(title: str = "restful_agent") -> str:
+ res = rest_client.post("/agents", json={"title": title, "dsl": MINIMAL_DSL})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ agent_id = payload["data"]["id"]
+ created_agent_ids.append(agent_id)
+ return agent_id
+
+ yield _create
+
+ cleanup_errors = []
+ for agent_id in created_agent_ids:
+ res = rest_client.delete(f"/agents/{agent_id}")
+ if res.status_code != 200:
+ cleanup_errors.append((agent_id, res.status_code, res.text))
+ continue
+ payload = res.json()
+ if payload["code"] not in (0, 103):
+ cleanup_errors.append((agent_id, res.status_code, payload))
+ assert not cleanup_errors, f"Agent cleanup failed: {cleanup_errors}"
+
+
+@pytest.mark.p2
+def test_agents_crud_validation_contract(rest_client, create_agent_resource):
+ list_empty = rest_client.get("/agents", params={"title": "missing_restful_agent"})
+ assert list_empty.status_code == 200
+ list_empty_payload = list_empty.json()
+ assert list_empty_payload["code"] == 0, list_empty_payload
+ assert "canvas" in list_empty_payload["data"], list_empty_payload
+ assert "total" in list_empty_payload["data"], list_empty_payload
+
+ missing_dsl = rest_client.post("/agents", json={"title": "missing_dsl_agent"})
+ assert missing_dsl.status_code == 200
+ missing_dsl_payload = missing_dsl.json()
+ assert missing_dsl_payload["code"] == 101, missing_dsl_payload
+ assert "No DSL data in request" in missing_dsl_payload["message"], missing_dsl_payload
+
+ missing_title = rest_client.post("/agents", json={"dsl": MINIMAL_DSL})
+ assert missing_title.status_code == 200
+ missing_title_payload = missing_title.json()
+ assert missing_title_payload["code"] == 101, missing_title_payload
+ assert "No title in request" in missing_title_payload["message"], missing_title_payload
+
+ agent_id = create_agent_resource("restful_agent_crud")
+
+ duplicate = rest_client.post("/agents", json={"title": "restful_agent_crud", "dsl": MINIMAL_DSL})
+ assert duplicate.status_code == 200
+ duplicate_payload = duplicate.json()
+ assert duplicate_payload["code"] == 102, duplicate_payload
+ assert "already exists" in duplicate_payload["message"], duplicate_payload
+
+ get_res = rest_client.get(f"/agents/{agent_id}")
+ assert get_res.status_code == 200
+ get_payload = get_res.json()
+ assert get_payload["code"] == 0, get_payload
+ assert get_payload["data"]["id"] == agent_id, get_payload
+
+ update_res = rest_client.put(f"/agents/{agent_id}", json={"title": "restful_agent_crud_updated", "dsl": MINIMAL_DSL})
+ assert update_res.status_code == 200
+ update_payload = update_res.json()
+ assert update_payload["code"] == 0, update_payload
+
+ list_after_update = rest_client.get("/agents", params={"title": "restful_agent_crud_updated"})
+ assert list_after_update.status_code == 200
+ list_after_update_payload = list_after_update.json()
+ assert list_after_update_payload["code"] == 0, list_after_update_payload
+ assert list_after_update_payload["data"]["total"] >= 1, list_after_update_payload
+
+ delete_res = rest_client.delete(f"/agents/{agent_id}")
+ assert delete_res.status_code == 200
+ delete_payload = delete_res.json()
+ assert delete_payload["code"] == 0, delete_payload
+ assert delete_payload["data"] is True, delete_payload
+
+
+@pytest.mark.p2
+def test_agent_sessions_crud(rest_client, create_agent_resource):
+ agent_id = create_agent_resource("restful_agent_sessions")
+
+ create_session = rest_client.post(f"/agents/{agent_id}/sessions", json={"name": "agent_session_1"})
+ assert create_session.status_code == 200
+ create_session_payload = create_session.json()
+ assert create_session_payload["code"] == 0, create_session_payload
+ session_id = create_session_payload["data"]["id"]
+
+ list_sessions = rest_client.get(f"/agents/{agent_id}/sessions")
+ assert list_sessions.status_code == 200
+ list_sessions_payload = list_sessions.json()
+ assert list_sessions_payload["code"] == 0, list_sessions_payload
+ assert isinstance(list_sessions_payload["data"], list), list_sessions_payload
+ assert any(item["id"] == session_id for item in list_sessions_payload["data"]), list_sessions_payload
+
+ get_session = rest_client.get(f"/agents/{agent_id}/sessions/{session_id}")
+ assert get_session.status_code == 200
+ get_session_payload = get_session.json()
+ assert get_session_payload["code"] == 0, get_session_payload
+ assert get_session_payload["data"]["id"] == session_id, get_session_payload
+
+ delete_session = rest_client.delete(f"/agents/{agent_id}/sessions/{session_id}")
+ assert delete_session.status_code == 200
+ delete_session_payload = delete_session.json()
+ assert delete_session_payload["code"] == 0, delete_session_payload
+
+
+@pytest.mark.p2
+def test_agent_chat_completion_validation(rest_client):
+ missing_agent_id = rest_client.post(
+ "/agents/chat/completions",
+ json={"query": "hello", "stream": False},
+ )
+ assert missing_agent_id.status_code == 200
+ missing_agent_id_payload = missing_agent_id.json()
+ assert missing_agent_id_payload["code"] == 101, missing_agent_id_payload
+ assert "`agent_id` is required." in missing_agent_id_payload["message"], missing_agent_id_payload
+
+
+@pytest.mark.p2
+def test_agent_chat_completion_nonstream(rest_client, create_agent_resource):
+ agent_id = create_agent_resource("restful_agent_nonstream")
+ create_session = rest_client.post(f"/agents/{agent_id}/sessions", json={"name": "agent_completion_session"})
+ assert create_session.status_code == 200
+ create_session_payload = create_session.json()
+ assert create_session_payload["code"] == 0, create_session_payload
+ session_id = create_session_payload["data"]["id"]
+
+ res = rest_client.post(
+ "/agents/chat/completions",
+ json={"agent_id": agent_id, "query": "hello", "stream": False, "session_id": session_id},
+ timeout=60,
+ )
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ assert isinstance(payload["data"], dict), payload
+ assert isinstance(payload["data"].get("data"), dict), payload
+ assert "content" in payload["data"]["data"], payload
+
+
+@pytest.mark.p2
+def test_agent_chat_completion_stream_structure_and_done(rest_client, create_agent_resource):
+ agent_id = create_agent_resource("restful_agent_stream")
+ create_session = rest_client.post(f"/agents/{agent_id}/sessions", json={"name": "agent_stream_session"})
+ assert create_session.status_code == 200
+ create_session_payload = create_session.json()
+ assert create_session_payload["code"] == 0, create_session_payload
+ session_id = create_session_payload["data"]["id"]
+
+ res = rest_client.post(
+ "/agents/chat/completions",
+ json={
+ "agent_id": agent_id,
+ "query": "hello",
+ "stream": True,
+ "session_id": session_id,
+ "return_trace": True,
+ },
+ timeout=60,
+ )
+ assert res.status_code == 200
+ content_type = res.headers.get("Content-Type", "")
+ assert "text/event-stream" in content_type, content_type
+
+ events = _sse_events(res.text)
+ assert events, res.text
+ assert events[-1].strip() == "[DONE]", events[-1]
+
+ json_events = [json.loads(evt) for evt in events if evt.strip() != "[DONE]"]
+ assert json_events, events
+ assert any(isinstance(evt, dict) for evt in json_events), json_events
+
+
+@pytest.mark.p2
+def test_agent_openai_compatible_mode(rest_client, create_agent_resource):
+ agent_id = create_agent_resource("restful_agent_openai_compat")
+
+ missing_messages = rest_client.post(
+ "/agents/chat/completions",
+ json={"agent_id": agent_id, "openai-compatible": True, "model": "model", "messages": []},
+ )
+ assert missing_messages.status_code == 200
+ missing_messages_payload = missing_messages.json()
+ assert missing_messages_payload["code"] == 102, missing_messages_payload
+ assert "at least one message" in missing_messages_payload["message"], missing_messages_payload
+
+ nonstream = rest_client.post(
+ "/agents/chat/completions",
+ json={
+ "agent_id": agent_id,
+ "openai-compatible": True,
+ "model": "model",
+ "messages": [{"role": "user", "content": "hello"}],
+ "stream": False,
+ },
+ timeout=60,
+ )
+ assert nonstream.status_code == 200
+ nonstream_payload = nonstream.json()
+ assert isinstance(nonstream_payload, dict), nonstream_payload
+ assert "choices" in nonstream_payload, nonstream_payload
+
+ stream = rest_client.post(
+ "/agents/chat/completions",
+ json={
+ "agent_id": agent_id,
+ "openai-compatible": True,
+ "model": "model",
+ "messages": [{"role": "user", "content": "hello"}],
+ "stream": True,
+ },
+ timeout=60,
+ )
+ assert stream.status_code == 200
+ stream_content_type = stream.headers.get("Content-Type", "")
+ assert "text/event-stream" in stream_content_type, stream_content_type
+
+
+@pytest.mark.p2
+def test_agent_support_routes_auth_and_contracts(rest_client, rest_client_noauth, create_agent_resource):
+ prompts_unauth = rest_client_noauth.get("/agents/prompts")
+ assert prompts_unauth.status_code == 401
+ assert prompts_unauth.json()["code"] == 401
+
+ prompts = rest_client.get("/agents/prompts")
+ assert prompts.status_code == 200
+ prompts_payload = prompts.json()
+ assert prompts_payload["code"] == 0, prompts_payload
+ assert "task_analysis" in prompts_payload["data"], prompts_payload
+ assert "citation_guidelines" in prompts_payload["data"], prompts_payload
+
+ templates = rest_client.get("/agents/templates")
+ assert templates.status_code == 200
+ templates_payload = templates.json()
+ assert templates_payload["code"] == 0, templates_payload
+ assert isinstance(templates_payload["data"], list), templates_payload
+
+ agent_id = create_agent_resource("restful_agent_support")
+ versions = rest_client.get(f"/agents/{agent_id}/versions")
+ assert versions.status_code == 200
+ versions_payload = versions.json()
+ assert versions_payload["code"] == 0, versions_payload
+ assert isinstance(versions_payload["data"], list), versions_payload
+
+ logs = rest_client.get(f"/agents/{agent_id}/logs/missing_message")
+ assert logs.status_code == 200
+ logs_payload = logs.json()
+ assert logs_payload["code"] == 0, logs_payload
+ assert isinstance(logs_payload["data"], dict), logs_payload
+
+
+@pytest.mark.p2
+def test_agent_webhook_logs_empty_poll_contract(rest_client, create_agent_resource):
+ agent_id = create_agent_resource("restful_agent_webhook_logs")
+ res = rest_client.get(f"/agents/{agent_id}/webhook/logs", params={"since_ts": 0})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ assert payload["data"]["events"] == [], payload
+ assert payload["data"]["finished"] is False, payload
+ assert "next_since_ts" in payload["data"], payload
+
+
+@pytest.mark.p2
+def test_agent_db_connection_validates_required_fields(rest_client):
+ res = rest_client.post("/agents/test_db_connection", json={"db_type": "mysql"})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 101, payload
+ assert "required argument are missing" in payload["message"], payload
+
+
+@pytest.mark.p2
+def test_agent_rerun_requires_required_fields(rest_client):
+ res = rest_client.post("/agents/rerun", json={"id": "flow-1"})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 101, payload
+ assert "required argument are missing" in payload["message"], payload
diff --git a/test/testcases/restful_api/test_chats.py b/test/testcases/restful_api/test_chats.py
new file mode 100644
index 0000000000..45fe5a8d2a
--- /dev/null
+++ b/test/testcases/restful_api/test_chats.py
@@ -0,0 +1,123 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+
+
+@pytest.mark.p1
+class TestChatsAuthorization:
+ def test_create_requires_auth(self, rest_client_noauth):
+ res = rest_client_noauth.post("/chats", json={"name": "chat_auth", "dataset_ids": []})
+ assert res.status_code == 401
+
+
+@pytest.mark.p1
+def test_chat_crud_cycle(rest_client, clear_chats):
+ create_res = rest_client.post(
+ "/chats",
+ json={"name": "restful_chat_crud", "dataset_ids": []},
+ )
+ assert create_res.status_code == 200
+ create_payload = create_res.json()
+ assert create_payload["code"] == 0, create_payload
+ chat_id = create_payload["data"]["id"]
+
+ list_res = rest_client.get("/chats", params={"id": chat_id})
+ assert list_res.status_code == 200
+ list_payload = list_res.json()
+ assert list_payload["code"] == 0, list_payload
+ chats = list_payload["data"]["chats"]
+ assert len(chats) == 1, list_payload
+ assert chats[0]["id"] == chat_id, list_payload
+
+ get_res = rest_client.get(f"/chats/{chat_id}")
+ assert get_res.status_code == 200
+ get_payload = get_res.json()
+ assert get_payload["code"] == 0, get_payload
+ assert get_payload["data"]["id"] == chat_id, get_payload
+
+ update_res = rest_client.put(
+ f"/chats/{chat_id}",
+ json={"name": "restful_chat_crud_updated", "dataset_ids": []},
+ )
+ assert update_res.status_code == 200
+ update_payload = update_res.json()
+ assert update_payload["code"] == 0, update_payload
+ assert update_payload["data"]["name"] == "restful_chat_crud_updated", update_payload
+
+ patch_res = rest_client.patch(f"/chats/{chat_id}", json={"name": "restful_chat_crud_patched"})
+ assert patch_res.status_code == 200
+ patch_payload = patch_res.json()
+ assert patch_payload["code"] == 0, patch_payload
+ assert patch_payload["data"]["name"] == "restful_chat_crud_patched", patch_payload
+
+ delete_res = rest_client.delete("/chats", json={"ids": [chat_id]})
+ assert delete_res.status_code == 200
+ delete_payload = delete_res.json()
+ assert delete_payload["code"] == 0, delete_payload
+ assert delete_payload["data"]["success_count"] == 1, delete_payload
+
+ list_after_delete = rest_client.get("/chats", params={"id": chat_id})
+ assert list_after_delete.status_code == 200
+ list_after_delete_payload = list_after_delete.json()
+ assert list_after_delete_payload["code"] == 0, list_after_delete_payload
+ assert list_after_delete_payload["data"]["chats"] == [], list_after_delete_payload
+
+
+@pytest.mark.p2
+@pytest.mark.parametrize(
+ "name, expected_fragment",
+ [
+ ("", "`name` is required."),
+ (" ", "`name` is required."),
+ ],
+)
+def test_chat_create_name_validation(rest_client, clear_chats, name, expected_fragment):
+ res = rest_client.post("/chats", json={"name": name, "dataset_ids": []})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 102, payload
+ assert expected_fragment in payload["message"], payload
+
+
+@pytest.mark.p2
+def test_chat_duplicate_name_validation(rest_client, clear_chats):
+ first = rest_client.post("/chats", json={"name": "duplicate_chat_name", "dataset_ids": []})
+ assert first.status_code == 200
+ first_payload = first.json()
+ assert first_payload["code"] == 0, first_payload
+
+ second = rest_client.post("/chats", json={"name": "duplicate_chat_name", "dataset_ids": []})
+ assert second.status_code == 200
+ second_payload = second.json()
+ assert second_payload["code"] == 102, second_payload
+ assert "Duplicated chat name" in second_payload["message"], second_payload
+
+
+@pytest.mark.p2
+def test_chat_list_pagination(rest_client, clear_chats):
+ for i in range(3):
+ res = rest_client.post("/chats", json={"name": f"chat_page_{i}", "dataset_ids": []})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+
+ page_res = rest_client.get("/chats", params={"page": 1, "page_size": 2, "orderby": "create_time", "desc": "true"})
+ assert page_res.status_code == 200
+ page_payload = page_res.json()
+ assert page_payload["code"] == 0, page_payload
+ assert len(page_payload["data"]["chats"]) == 2, page_payload
+ assert page_payload["data"]["total"] >= 3, page_payload
diff --git a/test/testcases/restful_api/test_chunks.py b/test/testcases/restful_api/test_chunks.py
new file mode 100644
index 0000000000..42009a2af5
--- /dev/null
+++ b/test/testcases/restful_api/test_chunks.py
@@ -0,0 +1,124 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+
+
+def _assert_created_chunk_id(payload):
+ chunk_id = payload["data"]["chunk"].get("id")
+ assert chunk_id, payload
+ assert isinstance(chunk_id, str), payload
+ assert chunk_id.strip(), payload
+ return chunk_id
+
+
+@pytest.mark.p1
+def test_chunks_add_list_get_update_delete_cycle(rest_client, create_document):
+ dataset_id, document_id = create_document("chunk_cycle.txt")
+ base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
+
+ add_res = rest_client.post(
+ base_path,
+ json={"content": "batch2 chunk content", "important_keywords": ["batch2"], "questions": ["what is batch2?"]},
+ )
+ assert add_res.status_code == 200
+ add_payload = add_res.json()
+ assert add_payload["code"] == 0, add_payload
+ chunk_id = _assert_created_chunk_id(add_payload)
+
+ list_res = rest_client.get(base_path, params={"id": chunk_id})
+ assert list_res.status_code == 200
+ list_payload = list_res.json()
+ assert list_payload["code"] == 0, list_payload
+ assert list_payload["data"]["total"] == 1, list_payload
+ assert list_payload["data"]["chunks"][0]["id"] == chunk_id, list_payload
+
+ get_res = rest_client.get(f"{base_path}/{chunk_id}")
+ assert get_res.status_code == 200
+ get_payload = get_res.json()
+ assert get_payload["code"] == 0, get_payload
+ assert get_payload["data"]["id"] == chunk_id, get_payload
+
+ update_res = rest_client.patch(
+ f"{base_path}/{chunk_id}",
+ json={"content": "batch2 chunk content updated"},
+ )
+ assert update_res.status_code == 200
+ update_payload = update_res.json()
+ assert update_payload["code"] == 0, update_payload
+
+ get_updated_res = rest_client.get(f"{base_path}/{chunk_id}")
+ assert get_updated_res.status_code == 200
+ get_updated_payload = get_updated_res.json()
+ assert get_updated_payload["code"] == 0, get_updated_payload
+ assert get_updated_payload["data"]["content_with_weight"] == "batch2 chunk content updated", get_updated_payload
+
+ delete_candidate_res = rest_client.post(base_path, json={"content": "batch2 chunk content to delete"})
+ assert delete_candidate_res.status_code == 200
+ delete_candidate_payload = delete_candidate_res.json()
+ assert delete_candidate_payload["code"] == 0, delete_candidate_payload
+ delete_candidate_id = _assert_created_chunk_id(delete_candidate_payload)
+
+ delete_res = rest_client.delete(base_path, json={"chunk_ids": [delete_candidate_id]})
+ assert delete_res.status_code == 200
+ delete_payload = delete_res.json()
+ assert delete_payload["code"] == 0, delete_payload
+
+ deleted_list_res = rest_client.get(base_path, params={"id": delete_candidate_id})
+ assert deleted_list_res.status_code == 200
+ deleted_list_payload = deleted_list_res.json()
+ assert deleted_list_payload["code"] != 0, deleted_list_payload
+
+ deleted_get_res = rest_client.get(f"{base_path}/{delete_candidate_id}")
+ assert deleted_get_res.status_code == 200
+ deleted_get_payload = deleted_get_res.json()
+ assert deleted_get_payload["code"] != 0, deleted_get_payload
+
+
+@pytest.mark.p2
+def test_chunks_add_requires_content(rest_client, create_document):
+ dataset_id, document_id = create_document("chunk_requires_content.txt")
+ res = rest_client.post(
+ f"/datasets/{dataset_id}/documents/{document_id}/chunks",
+ json={"content": " "},
+ )
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 102, payload
+ assert payload["message"] == "`content` is required", payload
+
+
+@pytest.mark.p2
+def test_chunks_list_empty_document(rest_client, create_document):
+ dataset_id, document_id = create_document("chunk_list_empty.txt")
+ base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
+ list_res = rest_client.get(base_path)
+ assert list_res.status_code == 200
+ list_payload = list_res.json()
+ assert list_payload["code"] == 0, list_payload
+ assert "chunks" in list_payload["data"], list_payload
+ assert "doc" in list_payload["data"], list_payload
+
+
+@pytest.mark.p2
+def test_chunks_delete_partial_invalid(rest_client, create_document):
+ dataset_id, document_id = create_document("chunk_delete_partial.txt")
+ base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
+ delete_res = rest_client.delete(base_path, json={"chunk_ids": ["invalid_chunk_id"]})
+ assert delete_res.status_code == 200
+ delete_payload = delete_res.json()
+ assert delete_payload["code"] == 102, delete_payload
+ assert "expect 1" in delete_payload["message"], delete_payload
diff --git a/test/testcases/restful_api/test_connector_routes_unit.py b/test/testcases/restful_api/test_connector_routes_unit.py
new file mode 100644
index 0000000000..ad47aef379
--- /dev/null
+++ b/test/testcases/restful_api/test_connector_routes_unit.py
@@ -0,0 +1,718 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import asyncio
+import importlib.util
+import json
+import sys
+from pathlib import Path
+from types import ModuleType, SimpleNamespace
+
+import pytest
+
+
+class _DummyManager:
+ def route(self, *_args, **_kwargs):
+ def decorator(func):
+ return func
+
+ return decorator
+
+
+class _AwaitableValue:
+ def __init__(self, value):
+ self._value = value
+
+ def __await__(self):
+ async def _co():
+ return self._value
+
+ return _co().__await__()
+
+
+class _Args(dict):
+ def get(self, key, default=None, type=None):
+ value = super().get(key, default)
+ if type is None:
+ return value
+ try:
+ return type(value)
+ except (TypeError, ValueError):
+ return default
+
+ def to_dict(self, flat=True):
+ return dict(self)
+
+
+class _FakeResponse:
+ def __init__(self, body, status_code):
+ self.body = body
+ self.status_code = status_code
+ self.headers = {}
+
+
+class _FakeConnectorRecord:
+ def __init__(self, payload):
+ self._payload = payload
+
+ def to_dict(self):
+ return dict(self._payload)
+
+
+class _FakeCredentials:
+ def __init__(self, raw='{"refresh_token":"rt","access_token":"at"}'):
+ self._raw = raw
+
+ def to_json(self):
+ return self._raw
+
+
+class _FakeFlow:
+ def __init__(self, client_config, scopes):
+ self.client_config = client_config
+ self.scopes = scopes
+ self.redirect_uri = None
+ self.credentials = _FakeCredentials()
+ self.auth_kwargs = None
+ self.token_code = None
+ self.token_code_verifier = None
+ self.code_verifier = "fake-code-verifier"
+
+ def authorization_url(self, **kwargs):
+ self.auth_kwargs = dict(kwargs)
+ return f"https://oauth.example/{kwargs['state']}", kwargs["state"]
+
+ def fetch_token(self, code, code_verifier=None):
+ self.token_code = code
+ self.token_code_verifier = code_verifier
+
+
+class _FakeBoxToken:
+ def __init__(self, access_token, refresh_token):
+ self.access_token = access_token
+ self.refresh_token = refresh_token
+
+
+class _FakeBoxOAuth:
+ def __init__(self, config):
+ self.config = config
+ self.exchange_code = None
+
+ def get_authorize_url(self, options):
+ return f"https://box.example/auth?state={options.state}&redirect={options.redirect_uri}"
+
+ def get_tokens_authorization_code_grant(self, code):
+ self.exchange_code = code
+
+ def retrieve_token(self):
+ return _FakeBoxToken("box-access", "box-refresh")
+
+
+class _FakeRedis:
+ def __init__(self):
+ self.store = {}
+ self.set_calls = []
+ self.deleted = []
+
+ def get(self, key):
+ return self.store.get(key)
+
+ def set_obj(self, key, obj, ttl):
+ self.set_calls.append((key, obj, ttl))
+ self.store[key] = json.dumps(obj)
+
+ def delete(self, key):
+ self.deleted.append(key)
+ self.store.pop(key, None)
+
+
+def _run(coro):
+ return asyncio.run(coro)
+
+
+def _set_request(module, *, args=None, json_body=None):
+ module.request = SimpleNamespace(
+ args=_Args(args or {}),
+ json=_AwaitableValue({} if json_body is None else json_body),
+ )
+
+
+@pytest.fixture(scope="session")
+def auth():
+ return "unit-auth"
+
+
+@pytest.fixture(scope="session", autouse=True)
+def set_tenant_info():
+ return None
+
+
+def _load_connector_app(monkeypatch):
+ repo_root = Path(__file__).resolve().parents[3]
+
+ api_pkg = ModuleType("api")
+ api_pkg.__path__ = [str(repo_root / "api")]
+ monkeypatch.setitem(sys.modules, "api", api_pkg)
+
+ apps_mod = ModuleType("api.apps")
+ apps_mod.__path__ = [str(repo_root / "api" / "apps")]
+ apps_mod.current_user = SimpleNamespace(id="tenant-1")
+ apps_mod.login_required = lambda fn: fn
+ monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
+
+ db_mod = ModuleType("api.db")
+ db_mod.InputType = SimpleNamespace(POLL="POLL")
+ monkeypatch.setitem(sys.modules, "api.db", db_mod)
+
+ services_pkg = ModuleType("api.db.services")
+ services_pkg.__path__ = []
+ monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
+
+ connector_service_mod = ModuleType("api.db.services.connector_service")
+
+ class _StubConnectorService:
+ @staticmethod
+ def update_by_id(*_args, **_kwargs):
+ return True
+
+ @staticmethod
+ def save(**_kwargs):
+ return True
+
+ @staticmethod
+ def get_by_id(_connector_id):
+ return True, _FakeConnectorRecord({"id": _connector_id})
+
+ @staticmethod
+ def list(_tenant_id):
+ return []
+
+ @staticmethod
+ def resume(*_args, **_kwargs):
+ return True
+
+ @staticmethod
+ def rebuild(*_args, **_kwargs):
+ return None
+
+ @staticmethod
+ def delete_by_id(*_args, **_kwargs):
+ return True
+
+ class _StubSyncLogsService:
+ @staticmethod
+ def list_sync_tasks(*_args, **_kwargs):
+ return [], 0
+
+ connector_service_mod.ConnectorService = _StubConnectorService
+ connector_service_mod.SyncLogsService = _StubSyncLogsService
+ monkeypatch.setitem(sys.modules, "api.db.services.connector_service", connector_service_mod)
+
+ api_utils_mod = ModuleType("api.utils.api_utils")
+
+ async def _get_request_json():
+ return {}
+
+ api_utils_mod.get_request_json = _get_request_json
+ api_utils_mod.get_json_result = lambda data=None, message="", code=0: {
+ "code": code,
+ "message": message,
+ "data": data,
+ }
+ api_utils_mod.get_data_error_result = lambda message="", code=400, data=None: {
+ "code": code,
+ "message": message,
+ "data": data,
+ }
+ api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn)
+ monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
+
+ constants_mod = ModuleType("common.constants")
+ constants_mod.RetCode = SimpleNamespace(
+ ARGUMENT_ERROR=101,
+ SERVER_ERROR=500,
+ RUNNING=102,
+ PERMISSION_ERROR=403,
+ )
+ constants_mod.TaskStatus = SimpleNamespace(SCHEDULE="schedule", CANCEL="cancel")
+ monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
+
+ config_mod = ModuleType("common.data_source.config")
+ config_mod.GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = "https://example.com/drive"
+ config_mod.GMAIL_WEB_OAUTH_REDIRECT_URI = "https://example.com/gmail"
+ config_mod.BOX_WEB_OAUTH_REDIRECT_URI = "https://example.com/box"
+ config_mod.DocumentSource = SimpleNamespace(GMAIL="gmail", GOOGLE_DRIVE="google-drive")
+ monkeypatch.setitem(sys.modules, "common.data_source.config", config_mod)
+
+ google_constants_mod = ModuleType("common.data_source.google_util.constant")
+ google_constants_mod.WEB_OAUTH_POPUP_TEMPLATE = (
+ "
{title}"
+ "{heading}
{message}
"
+ )
+ google_constants_mod.GOOGLE_SCOPES = {
+ config_mod.DocumentSource.GMAIL: ["scope-gmail"],
+ config_mod.DocumentSource.GOOGLE_DRIVE: ["scope-drive"],
+ }
+ monkeypatch.setitem(sys.modules, "common.data_source.google_util.constant", google_constants_mod)
+
+ misc_mod = ModuleType("common.misc_utils")
+ misc_mod.get_uuid = lambda: "uuid-from-helper"
+ monkeypatch.setitem(sys.modules, "common.misc_utils", misc_mod)
+
+ rag_pkg = ModuleType("rag")
+ rag_pkg.__path__ = [str(repo_root / "rag")]
+ monkeypatch.setitem(sys.modules, "rag", rag_pkg)
+
+ rag_utils_pkg = ModuleType("rag.utils")
+ rag_utils_pkg.__path__ = [str(repo_root / "rag" / "utils")]
+ monkeypatch.setitem(sys.modules, "rag.utils", rag_utils_pkg)
+
+ redis_mod = ModuleType("rag.utils.redis_conn")
+ redis_mod.REDIS_CONN = _FakeRedis()
+ monkeypatch.setitem(sys.modules, "rag.utils.redis_conn", redis_mod)
+
+ quart_mod = ModuleType("quart")
+ quart_mod.request = SimpleNamespace(args=_Args(), json=_AwaitableValue({}))
+
+ async def _make_response(body, status_code):
+ return _FakeResponse(body, status_code)
+
+ quart_mod.make_response = _make_response
+ monkeypatch.setitem(sys.modules, "quart", quart_mod)
+
+ google_pkg = ModuleType("google_auth_oauthlib")
+ google_pkg.__path__ = []
+ monkeypatch.setitem(sys.modules, "google_auth_oauthlib", google_pkg)
+
+ google_flow_mod = ModuleType("google_auth_oauthlib.flow")
+
+ class _StubFlow:
+ @classmethod
+ def from_client_config(cls, client_config, scopes):
+ return _FakeFlow(client_config, scopes)
+
+ google_flow_mod.Flow = _StubFlow
+ monkeypatch.setitem(sys.modules, "google_auth_oauthlib.flow", google_flow_mod)
+
+ box_mod = ModuleType("box_sdk_gen")
+
+ class _OAuthConfig:
+ def __init__(self, client_id, client_secret):
+ self.client_id = client_id
+ self.client_secret = client_secret
+
+ class _GetAuthorizeUrlOptions:
+ def __init__(self, redirect_uri, state):
+ self.redirect_uri = redirect_uri
+ self.state = state
+
+ box_mod.BoxOAuth = _FakeBoxOAuth
+ box_mod.OAuthConfig = _OAuthConfig
+ box_mod.GetAuthorizeUrlOptions = _GetAuthorizeUrlOptions
+ monkeypatch.setitem(sys.modules, "box_sdk_gen", box_mod)
+
+ module_path = repo_root / "api" / "apps" / "restful_apis" / "connector_api.py"
+ spec = importlib.util.spec_from_file_location("test_connector_routes_unit", module_path)
+ module = importlib.util.module_from_spec(spec)
+ module.manager = _DummyManager()
+ spec.loader.exec_module(module)
+ return module
+
+
+@pytest.mark.p2
+def test_connector_basic_routes_and_task_controls(monkeypatch):
+ module = _load_connector_app(monkeypatch)
+
+ async def _no_sleep(_secs):
+ return None
+
+ monkeypatch.setattr(module.asyncio, "sleep", _no_sleep)
+
+ records = {"conn-1": _FakeConnectorRecord({"id": "conn-1", "source": "drive"})}
+ update_calls = []
+ save_calls = []
+ resume_calls = []
+ delete_calls = []
+
+ monkeypatch.setattr(module.ConnectorService, "update_by_id", lambda cid, payload: update_calls.append((cid, payload)))
+
+ def _save(**payload):
+ save_calls.append(payload)
+ records[payload["id"]] = _FakeConnectorRecord(payload)
+
+ monkeypatch.setattr(module.ConnectorService, "save", _save)
+ monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda cid: (True, records[cid]))
+ monkeypatch.setattr(module.ConnectorService, "list", lambda tenant_id: [{"id": "listed", "tenant": tenant_id}])
+ monkeypatch.setattr(module.SyncLogsService, "list_sync_tasks", lambda cid, page, page_size: ([{"id": "log-1"}], 9))
+ monkeypatch.setattr(module.ConnectorService, "resume", lambda cid, status: resume_calls.append((cid, status)))
+ monkeypatch.setattr(module.ConnectorService, "delete_by_id", lambda cid: delete_calls.append(cid))
+ monkeypatch.setattr(module, "get_uuid", lambda: "generated-id")
+
+ monkeypatch.setattr(
+ module,
+ "get_request_json",
+ lambda: _AwaitableValue({"id": "conn-1", "refresh_freq": 7, "config": {"x": 1}}),
+ )
+ res = _run(module.update_connector("conn-1"))
+ assert update_calls == [("conn-1", {'id': 'conn-1', "refresh_freq": 7, "config": {"x": 1}})]
+ assert res["data"]["id"] == "conn-1"
+
+ monkeypatch.setattr(
+ module,
+ "get_request_json",
+ lambda: _AwaitableValue({"name": "new", "source": "gmail", "config": {"y": 2}}),
+ )
+ res = _run(module.create_connector())
+ assert save_calls[-1]["id"] == "generated-id"
+ assert save_calls[-1]["tenant_id"] == "tenant-1"
+ assert save_calls[-1]["input_type"] == module.InputType.POLL
+ assert res["data"]["id"] == "generated-id"
+
+ list_res = module.list_connector()
+ assert list_res["data"] == [{"id": "listed", "tenant": "tenant-1"}]
+
+ monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda _cid: (False, None))
+ missing_res = module.get_connector("missing")
+ assert missing_res["message"] == "Can't find this Connector!"
+
+ monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda cid: (True, _FakeConnectorRecord({"id": cid})))
+ found_res = module.get_connector("conn-2")
+ assert found_res["data"]["id"] == "conn-2"
+
+ _set_request(module, args={"page": "2", "page_size": "7"})
+ logs_res = module.list_logs("conn-log")
+ assert logs_res["data"] == {"total": 9, "logs": [{"id": "log-1"}]}
+
+ monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": True}))
+ assert _run(module.resume("conn-r1"))["data"] is True
+
+ monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": False}))
+ assert _run(module.resume("conn-r2"))["data"] is True
+ assert ("conn-r1", module.TaskStatus.SCHEDULE) in resume_calls
+ assert ("conn-r2", module.TaskStatus.CANCEL) in resume_calls
+
+ monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"kb_id": "kb-1"}))
+ monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: "rebuild-failed")
+ failed_rebuild = _run(module.rebuild("conn-rb"))
+ assert failed_rebuild["code"] == module.RetCode.SERVER_ERROR
+ assert failed_rebuild["data"] is False
+
+ monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: None)
+ ok_rebuild = _run(module.rebuild("conn-rb"))
+ assert ok_rebuild["data"] is True
+
+ rm_res = module.rm_connector("conn-rm")
+ assert rm_res["data"] is True
+ assert ("conn-rm", module.TaskStatus.CANCEL) in resume_calls
+ assert delete_calls == ["conn-rm"]
+
+
+@pytest.mark.p2
+def test_connector_oauth_helper_functions(monkeypatch):
+ module = _load_connector_app(monkeypatch)
+
+ assert module._web_state_cache_key("flow-a", "gmail") == "gmail_web_flow_state:flow-a"
+ assert module._web_result_cache_key("flow-b", "google-drive") == "google-drive_web_flow_result:flow-b"
+
+ creds_dict = {"web": {"client_id": "id"}}
+ assert module._load_credentials(creds_dict) == creds_dict
+ assert module._load_credentials(json.dumps(creds_dict)) == creds_dict
+
+ with pytest.raises(ValueError, match="Invalid Google credentials JSON"):
+ module._load_credentials("{not-json")
+
+ assert module._get_web_client_config(creds_dict) == {"web": {"client_id": "id"}}
+ with pytest.raises(ValueError, match="must include a 'web'"):
+ module._get_web_client_config({"installed": {"client_id": "id"}})
+
+ popup_ok = _run(module._render_web_oauth_popup("flow-1", True, "done", "gmail"))
+ assert popup_ok.status_code == 200
+ assert popup_ok.headers["Content-Type"] == "text/html; charset=utf-8"
+ assert "Authorization complete" in popup_ok.body
+ assert "ragflow-gmail-oauth" in popup_ok.body
+
+ popup_error = _run(module._render_web_oauth_popup("flow-2", False, "", "google-drive"))
+ assert popup_error.status_code == 200
+ assert "Authorization failed" in popup_error.body
+ assert "<denied>" in popup_error.body
+
+
+@pytest.mark.p2
+def test_start_google_web_oauth_matrix(monkeypatch):
+ module = _load_connector_app(monkeypatch)
+
+ redis = _FakeRedis()
+ monkeypatch.setattr(module, "REDIS_CONN", redis)
+ monkeypatch.setattr(module.time, "time", lambda: 1700000000)
+
+ flow_calls = []
+
+ def _from_client_config(client_config, scopes):
+ flow = _FakeFlow(client_config, scopes)
+ flow_calls.append(flow)
+ return flow
+
+ monkeypatch.setattr(module.Flow, "from_client_config", staticmethod(_from_client_config))
+
+ _set_request(module, args={"type": "invalid"})
+ monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": "{}"}))
+ invalid_type = _run(module.start_google_web_oauth())
+ assert invalid_type["code"] == module.RetCode.ARGUMENT_ERROR
+
+ monkeypatch.setattr(module, "GMAIL_WEB_OAUTH_REDIRECT_URI", "")
+ _set_request(module, args={"type": "gmail"})
+ missing_redirect = _run(module.start_google_web_oauth())
+ assert missing_redirect["code"] == module.RetCode.SERVER_ERROR
+
+ monkeypatch.setattr(module, "GMAIL_WEB_OAUTH_REDIRECT_URI", "https://example.com/gmail")
+ monkeypatch.setattr(module, "GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "https://example.com/drive")
+
+ _set_request(module, args={"type": "google-drive"})
+ monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": "{invalid-json"}))
+ invalid_credentials = _run(module.start_google_web_oauth())
+ assert invalid_credentials["code"] == module.RetCode.ARGUMENT_ERROR
+
+ monkeypatch.setattr(
+ module,
+ "get_request_json",
+ lambda: _AwaitableValue({"credentials": json.dumps({"web": {"client_id": "id"}, "refresh_token": "rt"})}),
+ )
+ has_refresh_token = _run(module.start_google_web_oauth())
+ assert has_refresh_token["code"] == module.RetCode.ARGUMENT_ERROR
+
+ monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": json.dumps({"installed": {"x": 1}})}))
+ missing_web = _run(module.start_google_web_oauth())
+ assert missing_web["code"] == module.RetCode.ARGUMENT_ERROR
+
+ ids = iter(["flow-gmail", "flow-drive"])
+ monkeypatch.setattr(module.uuid, "uuid4", lambda: next(ids))
+
+ monkeypatch.setattr(
+ module,
+ "get_request_json",
+ lambda: _AwaitableValue({"credentials": json.dumps({"web": {"client_id": "id", "client_secret": "secret"}})}),
+ )
+
+ _set_request(module, args={"type": "gmail"})
+ gmail_ok = _run(module.start_google_web_oauth())
+ assert gmail_ok["code"] == 0
+ assert gmail_ok["data"]["flow_id"] == "flow-gmail"
+ assert gmail_ok["data"]["authorization_url"].endswith("flow-gmail")
+
+ _set_request(module, args={})
+ drive_ok = _run(module.start_google_web_oauth())
+ assert drive_ok["code"] == 0
+ assert drive_ok["data"]["flow_id"] == "flow-drive"
+ assert drive_ok["data"]["authorization_url"].endswith("flow-drive")
+
+ assert any(call.scopes == module.GOOGLE_SCOPES[module.DocumentSource.GMAIL] for call in flow_calls)
+ assert any(call.scopes == module.GOOGLE_SCOPES[module.DocumentSource.GOOGLE_DRIVE] for call in flow_calls)
+ assert "gmail_web_flow_state:flow-gmail" in redis.store
+ assert "google-drive_web_flow_state:flow-drive" in redis.store
+ assert json.loads(redis.store["gmail_web_flow_state:flow-gmail"])["code_verifier"] == "fake-code-verifier"
+ assert json.loads(redis.store["google-drive_web_flow_state:flow-drive"])["code_verifier"] == "fake-code-verifier"
+
+
+@pytest.mark.p2
+def test_google_web_oauth_callbacks_matrix(monkeypatch):
+ module = _load_connector_app(monkeypatch)
+
+ flow_calls = []
+
+ def _from_client_config(client_config, scopes):
+ flow = _FakeFlow(client_config, scopes)
+ flow_calls.append(flow)
+ return flow
+
+ monkeypatch.setattr(module.Flow, "from_client_config", staticmethod(_from_client_config))
+
+ callback_specs = [
+ (
+ module.google_gmail_web_oauth_callback,
+ "gmail",
+ module.GMAIL_WEB_OAUTH_REDIRECT_URI,
+ module.GOOGLE_SCOPES[module.DocumentSource.GMAIL],
+ ),
+ (
+ module.google_drive_web_oauth_callback,
+ "google-drive",
+ module.GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI,
+ module.GOOGLE_SCOPES[module.DocumentSource.GOOGLE_DRIVE],
+ ),
+ ]
+
+ for callback, source, expected_redirect, expected_scopes in callback_specs:
+ redis = _FakeRedis()
+ monkeypatch.setattr(module, "REDIS_CONN", redis)
+
+ _set_request(module, args={})
+ missing_state = _run(callback())
+ assert "Missing OAuth state parameter." in missing_state.body
+
+ _set_request(module, args={"state": "sid"})
+ expired_state = _run(callback())
+ assert "Authorization session expired" in expired_state.body
+
+ redis.store[module._web_state_cache_key("sid", source)] = json.dumps({"user_id": "tenant-1"})
+ _set_request(module, args={"state": "sid"})
+ invalid_state = _run(callback())
+ assert "Authorization session was invalid" in invalid_state.body
+ assert module._web_state_cache_key("sid", source) in redis.deleted
+
+ redis.store[module._web_state_cache_key("sid", source)] = json.dumps({
+ "user_id": "tenant-1",
+ "client_config": {"web": {"client_id": "cid"}},
+ })
+ _set_request(module, args={"state": "sid", "error": "denied", "error_description": "permission denied"})
+ oauth_error = _run(callback())
+ assert "permission denied" in oauth_error.body
+
+ redis.store[module._web_state_cache_key("sid", source)] = json.dumps({
+ "user_id": "tenant-1",
+ "client_config": {"web": {"client_id": "cid"}},
+ })
+ _set_request(module, args={"state": "sid"})
+ missing_code = _run(callback())
+ assert "Missing authorization code" in missing_code.body
+
+ redis.store[module._web_state_cache_key("sid", source)] = json.dumps({
+ "user_id": "tenant-1",
+ "client_config": {"web": {"client_id": "cid"}},
+ "code_verifier": "state-code-verifier",
+ })
+ _set_request(module, args={"state": "sid", "code": "code-123"})
+ success = _run(callback())
+ assert "Authorization completed successfully." in success.body
+
+ result_key = module._web_result_cache_key("sid", source)
+ assert result_key in redis.store
+ assert module._web_state_cache_key("sid", source) in redis.deleted
+
+ assert flow_calls[-1].redirect_uri == expected_redirect
+ assert flow_calls[-1].scopes == expected_scopes
+ assert flow_calls[-1].token_code == "code-123"
+ assert flow_calls[-1].token_code_verifier == "state-code-verifier"
+
+
+@pytest.mark.p2
+def test_poll_google_web_result_matrix(monkeypatch):
+ module = _load_connector_app(monkeypatch)
+ redis = _FakeRedis()
+ monkeypatch.setattr(module, "REDIS_CONN", redis)
+
+ _set_request(module, args={"type": "invalid"}, json_body={"flow_id": "flow-1"})
+ invalid_type = _run(module.poll_google_web_result())
+ assert invalid_type["code"] == module.RetCode.ARGUMENT_ERROR
+
+ _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"})
+ pending = _run(module.poll_google_web_result())
+ assert pending["code"] == module.RetCode.RUNNING
+
+ redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps(
+ {"user_id": "another-user", "credentials": "token-x"}
+ )
+ _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"})
+ permission_error = _run(module.poll_google_web_result())
+ assert permission_error["code"] == module.RetCode.PERMISSION_ERROR
+
+ redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps(
+ {"user_id": "tenant-1", "credentials": "token-ok"}
+ )
+ _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"})
+ success = _run(module.poll_google_web_result())
+ assert success["code"] == 0
+ assert success["data"] == {"credentials": "token-ok"}
+ assert module._web_result_cache_key("flow-1", "gmail") in redis.deleted
+
+
+@pytest.mark.p2
+def test_box_oauth_start_callback_and_poll_matrix(monkeypatch):
+ module = _load_connector_app(monkeypatch)
+ redis = _FakeRedis()
+ monkeypatch.setattr(module, "REDIS_CONN", redis)
+
+ created_auth = []
+
+ class _TrackingBoxOAuth(_FakeBoxOAuth):
+ def __init__(self, config):
+ super().__init__(config)
+ created_auth.append(self)
+
+ monkeypatch.setattr(module, "BoxOAuth", _TrackingBoxOAuth)
+ monkeypatch.setattr(module.uuid, "uuid4", lambda: "flow-box")
+ monkeypatch.setattr(module.time, "time", lambda: 1800000000)
+
+ monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({}))
+ missing_params = _run(module.start_box_web_oauth())
+ assert missing_params["code"] == module.RetCode.ARGUMENT_ERROR
+
+ monkeypatch.setattr(
+ module,
+ "get_request_json",
+ lambda: _AwaitableValue({"client_id": "cid", "client_secret": "sec", "redirect_uri": "https://box.local/callback"}),
+ )
+ start_ok = _run(module.start_box_web_oauth())
+ assert start_ok["code"] == 0
+ assert start_ok["data"]["flow_id"] == "flow-box"
+ assert "authorization_url" in start_ok["data"]
+ assert module._web_state_cache_key("flow-box", "box") in redis.store
+
+ _set_request(module, args={})
+ missing_state = _run(module.box_web_oauth_callback())
+ assert "Missing OAuth parameters." in missing_state.body
+
+ _set_request(module, args={"state": "flow-box"})
+ missing_code = _run(module.box_web_oauth_callback())
+ assert "Missing authorization code from Box." in missing_code.body
+
+ redis.store[module._web_state_cache_key("flow-null", "box")] = "null"
+ _set_request(module, args={"state": "flow-null", "code": "abc"})
+ invalid_session = _run(module.box_web_oauth_callback())
+ assert invalid_session["code"] == module.RetCode.ARGUMENT_ERROR
+
+ redis.store[module._web_state_cache_key("flow-box", "box")] = json.dumps(
+ {"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"}
+ )
+ _set_request(module, args={"state": "flow-box", "code": "abc", "error": "access_denied", "error_description": "denied"})
+ callback_error = _run(module.box_web_oauth_callback())
+ assert "denied" in callback_error.body
+
+ redis.store[module._web_state_cache_key("flow-ok", "box")] = json.dumps(
+ {"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"}
+ )
+ _set_request(module, args={"state": "flow-ok", "code": "code-ok"})
+ callback_success = _run(module.box_web_oauth_callback())
+ assert "Authorization completed successfully." in callback_success.body
+ assert created_auth[-1].exchange_code == "code-ok"
+ assert module._web_result_cache_key("flow-ok", "box") in redis.store
+ assert module._web_state_cache_key("flow-ok", "box") in redis.deleted
+
+ monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"flow_id": "flow-ok"}))
+ redis.store.pop(module._web_result_cache_key("flow-ok", "box"), None)
+ pending = _run(module.poll_box_web_result())
+ assert pending["code"] == module.RetCode.RUNNING
+
+ redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps({"user_id": "another-user"})
+ permission_error = _run(module.poll_box_web_result())
+ assert permission_error["code"] == module.RetCode.PERMISSION_ERROR
+
+ redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps(
+ {"user_id": "tenant-1", "access_token": "at", "refresh_token": "rt"}
+ )
+ poll_success = _run(module.poll_box_web_result())
+ assert poll_success["code"] == 0
+ assert poll_success["data"]["credentials"]["access_token"] == "at"
+ assert module._web_result_cache_key("flow-ok", "box") in redis.deleted
diff --git a/test/testcases/restful_api/test_datasets.py b/test/testcases/restful_api/test_datasets.py
new file mode 100644
index 0000000000..a0f1726102
--- /dev/null
+++ b/test/testcases/restful_api/test_datasets.py
@@ -0,0 +1,335 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+from configs import DATASET_NAME_LIMIT
+
+
+@pytest.mark.p1
+class TestDatasetsAuthorization:
+ def test_create_requires_auth(self, rest_client_noauth):
+ res = rest_client_noauth.post("/datasets", json={"name": "auth_test"})
+ assert res.status_code == 401
+ payload = res.json()
+ assert payload["code"] == 401, payload
+
+
+@pytest.mark.p1
+def test_dataset_crud_cycle(rest_client, clear_datasets):
+ create_res = rest_client.post("/datasets", json={"name": "restful_dataset_crud"})
+ assert create_res.status_code == 200
+ create_payload = create_res.json()
+ assert create_payload["code"] == 0, create_payload
+ dataset_id = create_payload["data"]["id"]
+
+ get_res = rest_client.get(f"/datasets/{dataset_id}")
+ assert get_res.status_code == 200
+ get_payload = get_res.json()
+ assert get_payload["code"] == 0, get_payload
+ assert get_payload["data"]["id"] == dataset_id, get_payload
+
+ update_res = rest_client.put(
+ f"/datasets/{dataset_id}",
+ json={"name": "restful_dataset_crud_updated"},
+ )
+ assert update_res.status_code == 200
+ update_payload = update_res.json()
+ assert update_payload["code"] == 0, update_payload
+ assert update_payload["data"]["name"] == "restful_dataset_crud_updated", update_payload
+
+ list_res = rest_client.get("/datasets", params={"id": dataset_id})
+ assert list_res.status_code == 200
+ list_payload = list_res.json()
+ assert list_payload["code"] == 0, list_payload
+ assert len(list_payload["data"]) == 1, list_payload
+ assert list_payload["data"][0]["id"] == dataset_id, list_payload
+ assert list_payload.get("total_datasets", 0) >= 1, list_payload
+
+ delete_res = rest_client.delete("/datasets", json={"ids": [dataset_id]})
+ assert delete_res.status_code == 200
+ delete_payload = delete_res.json()
+ assert delete_payload["code"] == 0, delete_payload
+
+ list_after_delete = rest_client.get("/datasets")
+ assert list_after_delete.status_code == 200
+ list_after_delete_payload = list_after_delete.json()
+ assert list_after_delete_payload["code"] == 0, list_after_delete_payload
+ assert all(dataset["id"] != dataset_id for dataset in list_after_delete_payload["data"]), list_after_delete_payload
+
+
+@pytest.mark.p2
+@pytest.mark.parametrize(
+ "name, expected_fragment",
+ [
+ ("", "String should have at least 1 character"),
+ (" ", "String should have at least 1 character"),
+ ("a" * (DATASET_NAME_LIMIT + 1), f"String should have at most {DATASET_NAME_LIMIT} characters"),
+ ],
+ ids=["empty", "spaces", "too_long"],
+)
+def test_dataset_create_name_validation(rest_client, clear_datasets, name, expected_fragment):
+ res = rest_client.post("/datasets", json={"name": name})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 101, payload
+ assert expected_fragment in payload["message"], payload
+
+
+@pytest.mark.p2
+def test_dataset_list_ordering_and_pagination(rest_client, clear_datasets):
+ for i in range(3):
+ res = rest_client.post("/datasets", json={"name": f"dataset_page_{i}"})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+
+ list_res = rest_client.get(
+ "/datasets",
+ params={"page": 1, "page_size": 2, "orderby": "create_time", "desc": "true"},
+ )
+ assert list_res.status_code == 200
+ list_payload = list_res.json()
+ assert list_payload["code"] == 0, list_payload
+ assert len(list_payload["data"]) == 2, list_payload
+ assert list_payload.get("total_datasets", 0) >= 3, list_payload
+
+
+@pytest.mark.p2
+def test_dataset_search_endpoint(rest_client, ensure_parsed_document):
+ dataset_id, _ = ensure_parsed_document()
+ res = rest_client.post(
+ f"/datasets/{dataset_id}/search",
+ json={"question": "test TXT file", "page": 1, "size": 10},
+ )
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ assert "chunks" in payload["data"], payload
+
+
+@pytest.mark.p2
+def test_dataset_search_requires_question(rest_client, create_dataset):
+ dataset_id = create_dataset("dataset_search_missing_question")
+ res = rest_client.post(f"/datasets/{dataset_id}/search", json={})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 101, payload
+ assert "question" in payload["message"], payload
+
+
+@pytest.mark.p2
+def test_dataset_tags_and_aggregation(rest_client, create_dataset):
+ dataset_id = create_dataset("dataset_tags")
+ second_dataset_id = create_dataset("dataset_tags_second")
+
+ list_tags_res = rest_client.get(f"/datasets/{dataset_id}/tags")
+ assert list_tags_res.status_code == 200
+ list_tags_payload = list_tags_res.json()
+ # Known env/runtime behavior: this route can return 102 when retriever tag
+ # backend is unavailable for an empty dataset. Keep route-contract coverage.
+ assert list_tags_payload["code"] in (0, 102), list_tags_payload
+
+ aggregate_res = rest_client.get(
+ "/datasets/tags/aggregation",
+ params={"dataset_ids": f"{dataset_id},{second_dataset_id}"},
+ )
+ assert aggregate_res.status_code == 200
+ aggregate_payload = aggregate_res.json()
+ assert aggregate_payload["code"] in (0, 102), aggregate_payload
+
+ empty_aggregate_res = rest_client.get("/datasets/tags/aggregation")
+ assert empty_aggregate_res.status_code == 200
+ empty_aggregate_payload = empty_aggregate_res.json()
+ assert empty_aggregate_payload["code"] != 0, empty_aggregate_payload
+
+
+@pytest.mark.p2
+def test_dataset_tags_delete_and_rename_validation(rest_client, create_dataset):
+ dataset_id = create_dataset("dataset_tag_mutation")
+
+ delete_missing_tags = rest_client.delete(f"/datasets/{dataset_id}/tags", json={})
+ assert delete_missing_tags.status_code == 200
+ delete_missing_tags_payload = delete_missing_tags.json()
+ assert delete_missing_tags_payload["code"] != 0, delete_missing_tags_payload
+
+ delete_invalid_tags_type = rest_client.delete(f"/datasets/{dataset_id}/tags", json={"tags": "wrong"})
+ assert delete_invalid_tags_type.status_code == 200
+ delete_invalid_tags_type_payload = delete_invalid_tags_type.json()
+ assert delete_invalid_tags_type_payload["code"] != 0, delete_invalid_tags_type_payload
+
+ rename_empty = rest_client.put(
+ f"/datasets/{dataset_id}/tags",
+ json={"from_tag": "", "to_tag": ""},
+ )
+ assert rename_empty.status_code == 200
+ rename_empty_payload = rename_empty.json()
+ assert rename_empty_payload["code"] != 0, rename_empty_payload
+
+ rename_invalid_dataset = rest_client.put(
+ "/datasets/invalid_id/tags",
+ json={"from_tag": "old", "to_tag": "new"},
+ )
+ assert rename_invalid_dataset.status_code == 200
+ rename_invalid_dataset_payload = rename_invalid_dataset.json()
+ assert rename_invalid_dataset_payload["code"] != 0, rename_invalid_dataset_payload
+
+
+@pytest.mark.p2
+def test_dataset_flattened_metadata(rest_client, create_dataset):
+ first_dataset_id = create_dataset("flattened_meta_1")
+ second_dataset_id = create_dataset("flattened_meta_2")
+
+ flattened_res = rest_client.get(
+ "/datasets/metadata/flattened",
+ params={"dataset_ids": f"{first_dataset_id},{second_dataset_id}"},
+ )
+ assert flattened_res.status_code == 200
+ flattened_payload = flattened_res.json()
+ assert flattened_payload["code"] == 0, flattened_payload
+
+ empty_ids_res = rest_client.get("/datasets/metadata/flattened")
+ assert empty_ids_res.status_code == 200
+ empty_ids_payload = empty_ids_res.json()
+ assert empty_ids_payload["code"] != 0, empty_ids_payload
+
+ invalid_dataset_res = rest_client.get(
+ "/datasets/metadata/flattened",
+ params={"dataset_ids": "invalid_id"},
+ )
+ assert invalid_dataset_res.status_code == 200
+ invalid_dataset_payload = invalid_dataset_res.json()
+ assert invalid_dataset_payload["code"] != 0, invalid_dataset_payload
+
+
+@pytest.mark.p2
+def test_dataset_ingestion_summary_and_logs(rest_client, create_dataset):
+ dataset_id = create_dataset("dataset_ingestions")
+
+ summary_res = rest_client.get(f"/datasets/{dataset_id}/ingestions/summary")
+ assert summary_res.status_code == 200
+ summary_payload = summary_res.json()
+ assert summary_payload["code"] == 0, summary_payload
+ assert "doc_num" in summary_payload["data"], summary_payload
+ assert "chunk_num" in summary_payload["data"], summary_payload
+ assert "token_num" in summary_payload["data"], summary_payload
+ assert "status" in summary_payload["data"], summary_payload
+
+ logs_res = rest_client.get(
+ f"/datasets/{dataset_id}/ingestions",
+ params={"page": 1, "page_size": 10},
+ )
+ assert logs_res.status_code == 200
+ logs_payload = logs_res.json()
+ assert logs_payload["code"] == 0, logs_payload
+ assert "total" in logs_payload["data"], logs_payload
+ assert "logs" in logs_payload["data"], logs_payload
+
+ not_found_log_res = rest_client.get(f"/datasets/{dataset_id}/ingestions/nonexistent_log")
+ assert not_found_log_res.status_code == 200
+ not_found_log_payload = not_found_log_res.json()
+ assert not_found_log_payload["code"] != 0, not_found_log_payload
+
+
+@pytest.mark.p2
+def test_dataset_ingestion_invalid_dataset(rest_client):
+ summary_res = rest_client.get("/datasets/invalid_id/ingestions/summary")
+ assert summary_res.status_code == 200
+ summary_payload = summary_res.json()
+ assert summary_payload["code"] != 0, summary_payload
+
+ logs_res = rest_client.get("/datasets/invalid_id/ingestions")
+ assert logs_res.status_code == 200
+ logs_payload = logs_res.json()
+ assert logs_payload["code"] != 0, logs_payload
+
+ log_res = rest_client.get("/datasets/invalid_id/ingestions/some_log_id")
+ assert log_res.status_code == 200
+ log_payload = log_res.json()
+ assert log_payload["code"] != 0, log_payload
+
+
+@pytest.mark.p2
+def test_dataset_index_endpoints(rest_client, create_dataset):
+ dataset_id = create_dataset("dataset_index_endpoints")
+
+ run_invalid_type = rest_client.post(
+ f"/datasets/{dataset_id}/index",
+ params={"type": "invalid_type"},
+ )
+ assert run_invalid_type.status_code == 200
+ run_invalid_type_payload = run_invalid_type.json()
+ assert run_invalid_type_payload["code"] != 0, run_invalid_type_payload
+
+ run_no_docs = rest_client.post(
+ f"/datasets/{dataset_id}/index",
+ params={"type": "graph"},
+ )
+ assert run_no_docs.status_code == 200
+ run_no_docs_payload = run_no_docs.json()
+ assert run_no_docs_payload["code"] == 102, run_no_docs_payload
+
+ trace_no_task = rest_client.get(
+ f"/datasets/{dataset_id}/index",
+ params={"type": "graph"},
+ )
+ assert trace_no_task.status_code == 200
+ trace_no_task_payload = trace_no_task.json()
+ assert trace_no_task_payload["code"] == 0, trace_no_task_payload
+ assert trace_no_task_payload["data"] == {}, trace_no_task_payload
+
+ delete_graph = rest_client.delete(f"/datasets/{dataset_id}/graph")
+ assert delete_graph.status_code == 200
+ delete_graph_payload = delete_graph.json()
+ assert delete_graph_payload["code"] == 0, delete_graph_payload
+
+ delete_invalid_type = rest_client.delete(f"/datasets/{dataset_id}/invalid_type")
+ assert delete_invalid_type.status_code == 200
+ delete_invalid_type_payload = delete_invalid_type.json()
+ assert delete_invalid_type_payload["code"] != 0, delete_invalid_type_payload
+
+
+@pytest.mark.p2
+@pytest.mark.parametrize("index_type", ["graph", "raptor", "mindmap"])
+def test_dataset_index_run_with_document_creates_task(rest_client, create_document, index_type):
+ dataset_id, _ = create_document("dataset_index_graph_source.txt")
+ run_graph = rest_client.post(
+ f"/datasets/{dataset_id}/index",
+ params={"type": index_type},
+ )
+ assert run_graph.status_code == 200
+ run_graph_payload = run_graph.json()
+ assert run_graph_payload["code"] == 0, run_graph_payload
+ assert run_graph_payload["data"].get("task_id"), run_graph_payload
+
+
+@pytest.mark.p2
+def test_dataset_embedding_endpoints(rest_client, create_dataset):
+ dataset_id = create_dataset("dataset_embedding_endpoints")
+
+ run_no_docs_res = rest_client.post(f"/datasets/{dataset_id}/embedding")
+ assert run_no_docs_res.status_code == 200
+ run_no_docs_payload = run_no_docs_res.json()
+ assert run_no_docs_payload["code"] == 102, run_no_docs_payload
+
+ missing_embd_id_res = rest_client.post(f"/datasets/{dataset_id}/embedding/check", json={})
+ assert missing_embd_id_res.status_code == 200
+ missing_embd_id_payload = missing_embd_id_res.json()
+ assert missing_embd_id_payload["code"] != 0, missing_embd_id_payload
+
+ invalid_dataset_res = rest_client.post("/datasets/invalid_id/embedding")
+ assert invalid_dataset_res.status_code == 200
+ invalid_dataset_payload = invalid_dataset_res.json()
+ assert invalid_dataset_payload["code"] != 0, invalid_dataset_payload
diff --git a/test/testcases/restful_api/test_document_raw_routes.py b/test/testcases/restful_api/test_document_raw_routes.py
new file mode 100644
index 0000000000..07f65230df
--- /dev/null
+++ b/test/testcases/restful_api/test_document_raw_routes.py
@@ -0,0 +1,43 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+
+
+@pytest.mark.p2
+def test_document_image_invalid_id_contract(rest_client_noauth):
+ res = rest_client_noauth.get("/documents/images/not-a-valid-image-id")
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 102, payload
+ assert payload["message"] == "Image not found.", payload
+
+
+@pytest.mark.p2
+def test_document_artifact_requires_auth(rest_client_noauth):
+ res = rest_client_noauth.get("/documents/artifact/not-an-artifact.txt")
+ assert res.status_code == 401
+ payload = res.json()
+ assert payload["code"] == 401, payload
+
+
+@pytest.mark.p2
+def test_document_artifact_rejects_unsafe_filename(rest_client):
+ res = rest_client.get("/documents/artifact/not-an-artifact.exe")
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 102, payload
+ assert payload["message"] == "Invalid file type.", payload
diff --git a/test/testcases/restful_api/test_documents.py b/test/testcases/restful_api/test_documents.py
new file mode 100644
index 0000000000..59575fc0fc
--- /dev/null
+++ b/test/testcases/restful_api/test_documents.py
@@ -0,0 +1,122 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+from utils.file_utils import create_txt_file
+
+
+@pytest.mark.p1
+def test_documents_upload_and_list(rest_client, create_dataset, tmp_path):
+ dataset_id = create_dataset("dataset_upload_list")
+ fp = create_txt_file(tmp_path / "upload_and_list.txt")
+ with fp.open("rb") as file_obj:
+ res = rest_client.post(
+ f"/datasets/{dataset_id}/documents",
+ files=[("file", (fp.name, file_obj))],
+ )
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ assert payload["data"][0]["dataset_id"] == dataset_id, payload
+
+ list_res = rest_client.get(f"/datasets/{dataset_id}/documents")
+ assert list_res.status_code == 200
+ list_payload = list_res.json()
+ assert list_payload["code"] == 0, list_payload
+ assert list_payload["data"]["total"] >= 1, list_payload
+ assert any(doc["name"] == fp.name for doc in list_payload["data"]["docs"]), list_payload
+
+
+@pytest.mark.p2
+def test_documents_upload_missing_file(rest_client, create_dataset):
+ dataset_id = create_dataset("dataset_upload_missing")
+ res = rest_client.post(f"/datasets/{dataset_id}/documents")
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 101, payload
+ assert payload["message"] == "No file part!", payload
+
+
+@pytest.mark.p2
+def test_documents_update_patch_and_delete(rest_client, create_document):
+ dataset_id, document_id = create_document("update_target.txt")
+
+ patch_res = rest_client.patch(
+ f"/datasets/{dataset_id}/documents/{document_id}",
+ json={"name": "updated_target.txt"},
+ )
+ assert patch_res.status_code == 200
+ patch_payload = patch_res.json()
+ assert patch_payload["code"] == 0, patch_payload
+ assert patch_payload["data"]["name"] == "updated_target.txt", patch_payload
+
+ delete_res = rest_client.delete(
+ f"/datasets/{dataset_id}/documents",
+ json={"ids": [document_id]},
+ )
+ assert delete_res.status_code == 200
+ delete_payload = delete_res.json()
+ assert delete_payload["code"] == 0, delete_payload
+ assert delete_payload["data"]["deleted"] == 1, delete_payload
+
+ list_res = rest_client.get(f"/datasets/{dataset_id}/documents")
+ assert list_res.status_code == 200
+ list_payload = list_res.json()
+ assert list_payload["code"] == 0, list_payload
+ assert all(doc["id"] != document_id for doc in list_payload["data"]["docs"]), list_payload
+
+
+@pytest.mark.p2
+def test_documents_parse_and_stop(rest_client, create_document):
+ dataset_id, document_id = create_document("parse_target.txt")
+
+ parse_res = rest_client.post(
+ f"/datasets/{dataset_id}/documents/parse",
+ json={"document_ids": [document_id]},
+ )
+ assert parse_res.status_code == 200
+ parse_payload = parse_res.json()
+ assert parse_payload["code"] == 0, parse_payload
+
+ stop_res = rest_client.post(
+ f"/datasets/{dataset_id}/documents/stop",
+ json={"document_ids": [document_id]},
+ )
+ assert stop_res.status_code == 200
+ stop_payload = stop_res.json()
+ # Depending on timing this can be immediate stop success or "already completed".
+ assert stop_payload["code"] in (0, 102), stop_payload
+ if stop_payload["code"] == 102:
+ assert "already completed" in stop_payload["message"], stop_payload
+
+
+@pytest.mark.p2
+def test_documents_metadata_update_path(rest_client, create_document):
+ dataset_id, document_id = create_document("metadata_target.txt")
+
+ res = rest_client.patch(
+ f"/datasets/{dataset_id}/documents/metadatas",
+ json={
+ "selector": {"document_ids": [document_id]},
+ "updates": [{"key": "author", "value": "qa"}],
+ "deletes": [],
+ },
+ )
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ assert payload["data"]["matched_docs"] == 1, payload
+ assert payload["data"]["updated"] >= 1, payload
diff --git a/test/testcases/restful_api/test_file_routes_unit.py b/test/testcases/restful_api/test_file_routes_unit.py
new file mode 100644
index 0000000000..39246e97a0
--- /dev/null
+++ b/test/testcases/restful_api/test_file_routes_unit.py
@@ -0,0 +1,632 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import asyncio
+import importlib.util
+import sys
+from enum import Enum
+from pathlib import Path
+from types import ModuleType, SimpleNamespace
+
+import pytest
+
+
+class _DummyManager:
+ def route(self, *_args, **_kwargs):
+ def decorator(func):
+ return func
+
+ return decorator
+
+
+class _AwaitableValue:
+ def __init__(self, value):
+ self._value = value
+
+ def __await__(self):
+ async def _co():
+ return self._value
+
+ return _co().__await__()
+
+
+class _DummyFiles(dict):
+ def __init__(self, file_objs=None):
+ super().__init__()
+ self._file_objs = list(file_objs or [])
+ if file_objs is not None:
+ self["file"] = self._file_objs
+
+ def getlist(self, key):
+ if key == "file":
+ return list(self._file_objs)
+ return []
+
+
+class _DummyUploadFile:
+ def __init__(self, filename, blob=b"blob"):
+ self.filename = filename
+ self._blob = blob
+
+ def read(self):
+ return self._blob
+
+
+class _DummyRequest:
+ def __init__(self, *, content_type="", form=None, files=None, args=None):
+ self.content_type = content_type
+ self.form = _AwaitableValue(form or {})
+ self.files = _AwaitableValue(files if files is not None else _DummyFiles())
+ self.args = args or {}
+
+
+class _DummyResponse:
+ def __init__(self, data):
+ self.data = data
+ self.headers = {}
+
+
+def _run(coro):
+ return asyncio.run(coro)
+
+
+def _load_file_api_module(monkeypatch):
+ repo_root = Path(__file__).resolve().parents[3]
+
+ quart_mod = ModuleType("quart")
+ quart_mod.request = _DummyRequest()
+
+ async def _make_response(data):
+ return _DummyResponse(data)
+
+ quart_mod.make_response = _make_response
+ monkeypatch.setitem(sys.modules, "quart", quart_mod)
+
+ api_pkg = ModuleType("api")
+ api_pkg.__path__ = [str(repo_root / "api")]
+ monkeypatch.setitem(sys.modules, "api", api_pkg)
+
+ apps_pkg = ModuleType("api.apps")
+ apps_pkg.__path__ = [str(repo_root / "api" / "apps")]
+ apps_pkg.login_required = lambda func: func
+ monkeypatch.setitem(sys.modules, "api.apps", apps_pkg)
+ api_pkg.apps = apps_pkg
+
+ services_pkg = ModuleType("api.apps.services")
+ services_pkg.__path__ = [str(repo_root / "api" / "apps" / "services")]
+ monkeypatch.setitem(sys.modules, "api.apps.services", services_pkg)
+ apps_pkg.services = services_pkg
+
+ file_api_service_mod = ModuleType("api.apps.services.file_api_service")
+
+ async def _upload_file(_tenant_id, _pf_id, _file_objs):
+ return True, [{"id": "f1"}]
+
+ async def _create_folder(_tenant_id, _name, _parent_id=None, _file_type=None):
+ return True, {"id": "folder1"}
+
+ async def _delete_files(_tenant_id, _ids):
+ return True, True
+
+ async def _move_files(_tenant_id, _src_file_ids, _dest_file_id=None, _new_name=None):
+ return True, True
+
+ file_api_service_mod.upload_file = _upload_file
+ file_api_service_mod.create_folder = _create_folder
+ file_api_service_mod.list_files = lambda _tenant_id, _args: (True, {"files": [], "total": 0})
+ file_api_service_mod.delete_files = _delete_files
+ file_api_service_mod.move_files = _move_files
+ file_api_service_mod.get_file_content = lambda _tenant_id, _file_id: (
+ True,
+ SimpleNamespace(parent_id="bucket1", location="path1", name="doc.txt", type="doc"),
+ )
+ file_api_service_mod.get_parent_folder = lambda _file_id, user_id=None: (True, {"parent_folder": {"id": "parent1"}})
+ file_api_service_mod.get_all_parent_folders = lambda _file_id, user_id=None: (True, {"parent_folders": [{"id": "root"}]})
+ monkeypatch.setitem(sys.modules, "api.apps.services.file_api_service", file_api_service_mod)
+ services_pkg.file_api_service = file_api_service_mod
+
+ db_pkg = ModuleType("api.db")
+ db_pkg.__path__ = []
+
+ class _FileType(Enum):
+ DOC = "doc"
+ VISUAL = "visual"
+
+ db_pkg.FileType = _FileType
+ monkeypatch.setitem(sys.modules, "api.db", db_pkg)
+ api_pkg.db = db_pkg
+
+ file2doc_mod = ModuleType("api.db.services.file2document_service")
+ file2doc_mod.File2DocumentService = SimpleNamespace(get_storage_address=lambda **_kwargs: ("bucket2", "path2"))
+ monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2doc_mod)
+
+ api_utils_mod = ModuleType("api.utils.api_utils")
+ api_utils_mod.add_tenant_id_to_kwargs = lambda func: func
+ api_utils_mod.get_error_argument_result = lambda message: {"code": 400, "data": None, "message": message}
+ api_utils_mod.get_error_data_result = lambda message: {"code": 500, "data": None, "message": message}
+ api_utils_mod.get_result = lambda data=None: {"code": 0, "data": data, "message": ""}
+ api_utils_mod.get_json_result = lambda code=0, message="success", data=None: {"code": code, "data": data, "message": message}
+ monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
+
+ validation_mod = ModuleType("api.utils.validation_utils")
+ validation_mod.CreateFolderReq = object
+ validation_mod.DeleteFileReq = object
+ validation_mod.ListFileReq = object
+ validation_mod.MoveFileReq = object
+
+ async def _validate_json_request(_request, _schema):
+ return {}, None
+
+ validation_mod.validate_and_parse_json_request = _validate_json_request
+ validation_mod.validate_and_parse_request_args = lambda _request, _schema: ({}, None)
+ monkeypatch.setitem(sys.modules, "api.utils.validation_utils", validation_mod)
+
+ web_utils_mod = ModuleType("api.utils.web_utils")
+ web_utils_mod.CONTENT_TYPE_MAP = {"txt": "text/plain"}
+ web_utils_mod.apply_safe_file_response_headers = lambda response, content_type, ext: response.headers.update({"content_type": content_type, "ext": ext})
+ monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod)
+
+ common_pkg = ModuleType("common")
+ common_pkg.__path__ = [str(repo_root / "common")]
+ common_pkg.settings = SimpleNamespace(
+ STORAGE_IMPL=SimpleNamespace(
+ get=lambda *_args, **_kwargs: b"blob",
+ )
+ )
+ monkeypatch.setitem(sys.modules, "common", common_pkg)
+
+ misc_utils_mod = ModuleType("common.misc_utils")
+
+ async def thread_pool_exec(func, *args, **kwargs):
+ return func(*args, **kwargs)
+
+ misc_utils_mod.thread_pool_exec = thread_pool_exec
+ monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
+
+ module_path = repo_root / "api" / "apps" / "restful_apis" / "file_api.py"
+ spec = importlib.util.spec_from_file_location("api.apps.restful_apis.file_api", module_path)
+ module = importlib.util.module_from_spec(spec)
+ module.manager = _DummyManager()
+ monkeypatch.setitem(sys.modules, "api.apps.restful_apis.file_api", module)
+ spec.loader.exec_module(module)
+ return module
+
+
+@pytest.mark.p2
+def test_create_or_upload_multipart_requires_file(monkeypatch):
+ module = _load_file_api_module(monkeypatch)
+ monkeypatch.setattr(module, "request", _DummyRequest(content_type="multipart/form-data", form={}, files=_DummyFiles()))
+
+ res = _run(module.create_or_upload("tenant1"))
+ assert res["code"] == 400
+ assert res["message"] == "No file part!"
+
+
+@pytest.mark.p2
+def test_create_or_upload_uploads_via_new_service(monkeypatch):
+ module = _load_file_api_module(monkeypatch)
+ files = _DummyFiles([_DummyUploadFile("a.txt")])
+ monkeypatch.setattr(module, "request", _DummyRequest(content_type="multipart/form-data", form={"parent_id": "pf1"}, files=files))
+
+ seen = {}
+
+ async def _upload_file(tenant_id, pf_id, file_objs):
+ seen["args"] = (tenant_id, pf_id, [f.filename for f in file_objs])
+ return True, [{"id": "f1"}]
+
+ monkeypatch.setattr(module.file_api_service, "upload_file", _upload_file)
+ res = _run(module.create_or_upload("tenant1"))
+
+ assert seen["args"] == ("tenant1", "pf1", ["a.txt"])
+ assert res["code"] == 0
+ assert res["data"] == [{"id": "f1"}]
+
+
+@pytest.mark.p2
+def test_create_or_upload_creates_folder_from_json(monkeypatch):
+ module = _load_file_api_module(monkeypatch)
+ monkeypatch.setattr(module, "request", _DummyRequest(content_type="application/json"))
+
+ async def _validate(_request, _schema):
+ return {"name": "folder-a", "parent_id": "pf1", "type": "folder"}, None
+
+ async def _create_folder(tenant_id, name, parent_id=None, file_type=None):
+ return True, {"tenant_id": tenant_id, "name": name, "parent_id": parent_id, "type": file_type}
+
+ monkeypatch.setattr(module, "validate_and_parse_json_request", _validate)
+ monkeypatch.setattr(module.file_api_service, "create_folder", _create_folder)
+
+ res = _run(module.create_or_upload("tenant1"))
+ assert res["code"] == 0
+ assert res["data"]["tenant_id"] == "tenant1"
+ assert res["data"]["name"] == "folder-a"
+
+
+@pytest.mark.p2
+def test_list_files_validation_error(monkeypatch):
+ module = _load_file_api_module(monkeypatch)
+ monkeypatch.setattr(module, "validate_and_parse_request_args", lambda _request, _schema: (None, "bad args"))
+
+ res = _run(module.list_files("tenant1"))
+ assert res["code"] == 400
+ assert res["message"] == "bad args"
+
+
+@pytest.mark.p2
+def test_move_uses_new_payload_shape(monkeypatch):
+ module = _load_file_api_module(monkeypatch)
+
+ async def _validate(_request, _schema):
+ return {"src_file_ids": ["f1"], "dest_file_id": "pf2"}, None
+
+ seen = {}
+
+ async def _move_files(tenant_id, src_file_ids, dest_file_id=None, new_name=None):
+ seen["args"] = (tenant_id, src_file_ids, dest_file_id, new_name)
+ return True, True
+
+ monkeypatch.setattr(module, "validate_and_parse_json_request", _validate)
+ monkeypatch.setattr(module.file_api_service, "move_files", _move_files)
+
+ res = _run(module.move("tenant1"))
+ assert seen["args"] == ("tenant1", ["f1"], "pf2", None)
+ assert res["code"] == 0
+ assert res["data"] is True
+
+
+@pytest.mark.p2
+def test_rename_via_move_route(monkeypatch):
+ module = _load_file_api_module(monkeypatch)
+
+ async def _validate(_request, _schema):
+ return {"src_file_ids": ["file1"], "new_name": "renamed.txt"}, None
+
+ seen = {}
+
+ async def _move_files(tenant_id, src_file_ids, dest_file_id=None, new_name=None):
+ seen["args"] = (tenant_id, src_file_ids, dest_file_id, new_name)
+ return True, True
+
+ monkeypatch.setattr(module, "validate_and_parse_json_request", _validate)
+ monkeypatch.setattr(module.file_api_service, "move_files", _move_files)
+
+ res = _run(module.move("tenant1"))
+ assert seen["args"] == ("tenant1", ["file1"], None, "renamed.txt")
+ assert res["code"] == 0
+ assert res["data"] is True
+
+
+@pytest.mark.p2
+def test_download_falls_back_to_document_storage(monkeypatch):
+ module = _load_file_api_module(monkeypatch)
+ storage_calls = []
+
+ def _get(bucket, location):
+ storage_calls.append((bucket, location))
+ return b"" if len(storage_calls) == 1 else b"fallback-blob"
+
+ monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(get=_get))
+ res = _run(module.download("tenant1", "file1"))
+
+ assert storage_calls == [("bucket1", "path1"), ("bucket2", "path2")]
+ assert res.data == b"fallback-blob"
+ assert res.headers["content_type"] == "text/plain"
+ assert res.headers["ext"] == "txt"
+
+
+@pytest.mark.p2
+def test_parent_and_ancestors_use_new_routes(monkeypatch):
+ module = _load_file_api_module(monkeypatch)
+
+ parent_res = _run(module.parent_folder("tenant1", "file1"))
+ ancestors_res = _run(module.ancestors("tenant1", "file1"))
+
+ assert parent_res["code"] == 0
+ assert parent_res["data"]["parent_folder"]["id"] == "parent1"
+ assert ancestors_res["code"] == 0
+ assert ancestors_res["data"]["parent_folders"][0]["id"] == "root"
+
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import functools
+from copy import deepcopy
+
+import pytest
+
+
+class _DummyManager:
+ def route(self, *_args, **_kwargs):
+ def decorator(func):
+ return func
+
+ return decorator
+
+
+class _DummyFile:
+ def __init__(self, file_id, file_type, *, name="file.txt", location="loc", size=1):
+ self.id = file_id
+ self.type = file_type
+ self.name = name
+ self.location = location
+ self.size = size
+
+
+class _FalsyFile(_DummyFile):
+ def __bool__(self):
+ return False
+
+
+def _run(coro):
+ return asyncio.run(coro)
+
+
+def _set_request_json(monkeypatch, module, payload_state):
+ async def _req_json():
+ return deepcopy(payload_state)
+
+ monkeypatch.setattr(module, "get_request_json", _req_json)
+
+
+@pytest.fixture(scope="session")
+def auth():
+ return "unit-auth"
+
+
+@pytest.fixture(scope="session", autouse=True)
+def set_tenant_info():
+ return None
+
+
+def _load_file2document_module(monkeypatch):
+ repo_root = Path(__file__).resolve().parents[3]
+
+ api_pkg = ModuleType("api")
+ api_pkg.__path__ = [str(repo_root / "api")]
+ monkeypatch.setitem(sys.modules, "api", api_pkg)
+
+ apps_mod = ModuleType("api.apps")
+ apps_mod.__path__ = [str(repo_root / "api" / "apps")]
+ apps_mod.current_user = SimpleNamespace(id="user-1")
+ apps_mod.login_required = lambda func: func
+ monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
+ api_pkg.apps = apps_mod
+
+ db_pkg = ModuleType("api.db")
+ db_pkg.__path__ = []
+
+ class _FileType(Enum):
+ FOLDER = "folder"
+ DOC = "doc"
+
+ db_pkg.FileType = _FileType
+ monkeypatch.setitem(sys.modules, "api.db", db_pkg)
+ api_pkg.db = db_pkg
+
+ services_pkg = ModuleType("api.db.services")
+ services_pkg.__path__ = []
+ monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
+
+ common_pkg = ModuleType("api.common")
+ common_pkg.__path__ = []
+ monkeypatch.setitem(sys.modules, "api.common", common_pkg)
+
+ permission_mod = ModuleType("api.common.check_team_permission")
+ permission_mod.check_file_team_permission = lambda *_args, **_kwargs: True
+ permission_mod.check_kb_team_permission = lambda *_args, **_kwargs: True
+ monkeypatch.setitem(sys.modules, "api.common.check_team_permission", permission_mod)
+ common_pkg.check_team_permission = permission_mod
+
+ file2document_mod = ModuleType("api.db.services.file2document_service")
+
+ class _StubFile2DocumentService:
+ @staticmethod
+ def get_by_file_id(_file_id):
+ return []
+
+ @staticmethod
+ def delete_by_file_id(*_args, **_kwargs):
+ return None
+
+ @staticmethod
+ def insert(_payload):
+ return SimpleNamespace(to_json=lambda: {})
+
+ file2document_mod.File2DocumentService = _StubFile2DocumentService
+ monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2document_mod)
+ services_pkg.file2document_service = file2document_mod
+
+ file_service_mod = ModuleType("api.db.services.file_service")
+
+ class _StubFileService:
+ @staticmethod
+ def get_by_ids(_file_ids):
+ return []
+
+ @staticmethod
+ def get_all_innermost_file_ids(_file_id, _acc):
+ return []
+
+ @staticmethod
+ def get_by_id(_file_id):
+ return True, _DummyFile(_file_id, _FileType.DOC.value)
+
+ @staticmethod
+ def get_parser(_file_type, _file_name, parser_id):
+ return parser_id
+
+ file_service_mod.FileService = _StubFileService
+ monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service_mod)
+ services_pkg.file_service = file_service_mod
+
+ kb_service_mod = ModuleType("api.db.services.knowledgebase_service")
+
+ class _StubKnowledgebaseService:
+ @staticmethod
+ def get_by_id(_kb_id):
+ return False, None
+
+ kb_service_mod.KnowledgebaseService = _StubKnowledgebaseService
+ monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod)
+ services_pkg.knowledgebase_service = kb_service_mod
+
+ document_service_mod = ModuleType("api.db.services.document_service")
+
+ class _StubDocumentService:
+ @staticmethod
+ def get_by_id(doc_id):
+ return True, SimpleNamespace(id=doc_id)
+
+ @staticmethod
+ def get_tenant_id(_doc_id):
+ return "tenant-1"
+
+ @staticmethod
+ def remove_document(*_args, **_kwargs):
+ return True
+
+ @staticmethod
+ def insert(_payload):
+ return SimpleNamespace(id="doc-1")
+
+ document_service_mod.DocumentService = _StubDocumentService
+ monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod)
+ services_pkg.document_service = document_service_mod
+
+ api_utils_mod = ModuleType("api.utils.api_utils")
+
+ def get_json_result(data=None, message="", code=0):
+ return {"code": code, "data": data, "message": message}
+
+ def get_data_error_result(message=""):
+ return {"code": 102, "data": None, "message": message}
+
+ async def get_request_json():
+ return {}
+
+ def server_error_response(err):
+ return {"code": 500, "data": None, "message": str(err)}
+
+ def validate_request(*_keys):
+ def _decorator(func):
+ @functools.wraps(func)
+ async def _wrapper(*args, **kwargs):
+ return await func(*args, **kwargs)
+
+ return _wrapper
+
+ return _decorator
+
+ api_utils_mod.get_json_result = get_json_result
+ api_utils_mod.get_data_error_result = get_data_error_result
+ api_utils_mod.get_request_json = get_request_json
+ api_utils_mod.server_error_response = server_error_response
+ api_utils_mod.validate_request = validate_request
+ monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
+
+ misc_utils_mod = ModuleType("common.misc_utils")
+ misc_utils_mod.get_uuid = lambda: "uuid"
+ monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
+
+ constants_mod = ModuleType("common.constants")
+
+ class _RetCode:
+ ARGUMENT_ERROR = 101
+
+ constants_mod.RetCode = _RetCode
+ monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
+
+ module_name = "test_file2document_routes_unit_module"
+ module_path = repo_root / "api" / "apps" / "restful_apis" / "file2document_api.py"
+ spec = importlib.util.spec_from_file_location(module_name, module_path)
+ module = importlib.util.module_from_spec(spec)
+ module.manager = _DummyManager()
+ monkeypatch.setitem(sys.modules, module_name, module)
+ spec.loader.exec_module(module)
+ return module
+
+
+@pytest.mark.p2
+def test_convert_branch_matrix_unit(monkeypatch):
+ module = _load_file2document_module(monkeypatch)
+ req_state = {"kb_ids": ["kb-1"], "file_ids": ["f1"]}
+ _set_request_json(monkeypatch, module, req_state)
+
+ # Falsy file returns "File not found!" during synchronous validation.
+ monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_FalsyFile("f1", module.FileType.DOC.value)])
+ res = _run(module.convert())
+ assert res["code"] == 102
+ assert res["message"] == "File not found!"
+
+ # Valid file but invalid kb returns "Can't find this dataset!" during synchronous validation.
+ monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_DummyFile("f1", module.FileType.DOC.value)])
+ res = _run(module.convert())
+ assert res["code"] == 102
+ assert res["message"] == "Can't find this dataset!"
+
+ kb = SimpleNamespace(id="kb-1", parser_id="naive", pipeline_id="p1", parser_config={})
+ monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
+
+ # Unauthorized file access is rejected before scheduling background work.
+ monkeypatch.setattr(module, "check_file_team_permission", lambda *_args, **_kwargs: False)
+ res = _run(module.convert())
+ assert res["code"] == 102
+ assert res["message"] == "No authorization."
+
+ # Unauthorized dataset access is rejected before scheduling background work.
+ monkeypatch.setattr(module, "check_file_team_permission", lambda *_args, **_kwargs: True)
+ monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: False)
+ res = _run(module.convert())
+ assert res["code"] == 102
+ assert res["message"] == "No authorization."
+
+ # Valid file and kb schedule background work and return data=True immediately.
+ monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: True)
+ res = _run(module.convert())
+ assert res["code"] == 0
+ assert res["data"] is True
+
+ # Folder expansion schedules background work and returns data=True immediately.
+ req_state["file_ids"] = ["folder-1"]
+ monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_DummyFile("folder-1", module.FileType.FOLDER.value, name="folder")])
+ monkeypatch.setattr(module.FileService, "get_all_innermost_file_ids", lambda _file_id, _acc: ["inner-1"])
+ res = _run(module.convert())
+ assert res["code"] == 0
+ assert res["data"] is True
+
+ # Exception in file lookup returns 500.
+ req_state["file_ids"] = ["f1"]
+ monkeypatch.setattr(
+ module.FileService,
+ "get_by_ids",
+ lambda _ids: (_ for _ in ()).throw(RuntimeError("convert boom")),
+ )
+ res = _run(module.convert())
+ assert res["code"] == 500
+ assert "convert boom" in res["message"]
diff --git a/test/testcases/restful_api/test_langfuse_routes.py b/test/testcases/restful_api/test_langfuse_routes.py
new file mode 100644
index 0000000000..deda7fbe3e
--- /dev/null
+++ b/test/testcases/restful_api/test_langfuse_routes.py
@@ -0,0 +1,37 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+
+
+@pytest.mark.p2
+def test_langfuse_api_key_routes_require_auth(rest_client_noauth):
+ for method in ("get", "post", "put", "delete"):
+ requester = getattr(rest_client_noauth, method)
+ kwargs = {"json": {"secret_key": "s", "public_key": "p", "host": "http://example.com"}} if method in {"post", "put"} else {}
+ res = requester("/langfuse/api-key", **kwargs)
+ assert res.status_code == 401
+ payload = res.json()
+ assert payload["code"] == 401, (method, payload)
+
+
+@pytest.mark.p2
+def test_langfuse_api_key_missing_required_fields(rest_client):
+ res = rest_client.post("/langfuse/api-key", json={"secret_key": "", "public_key": "pub", "host": "http://host"})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] in (101, 102), payload
+ assert "required" in payload["message"].lower() or "missing" in payload["message"].lower(), payload
diff --git a/test/testcases/restful_api/test_mcp_routes_unit.py b/test/testcases/restful_api/test_mcp_routes_unit.py
new file mode 100644
index 0000000000..ccd628f0fd
--- /dev/null
+++ b/test/testcases/restful_api/test_mcp_routes_unit.py
@@ -0,0 +1,745 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import asyncio
+import importlib.util
+import inspect
+import json
+import sys
+from functools import wraps
+from pathlib import Path
+from types import ModuleType, SimpleNamespace
+
+import pytest
+
+
+class _DummyManager:
+ def route(self, *_args, **_kwargs):
+ def decorator(func):
+ return func
+
+ return decorator
+
+
+class _Args(dict):
+ def getlist(self, key):
+ value = self.get(key, [])
+ if isinstance(value, list):
+ return value
+ return [value]
+
+
+class _Field:
+ def __init__(self, name):
+ self.name = name
+
+ def __eq__(self, other):
+ return (self.name, other)
+
+
+class _DummyMCPServer:
+ id = _Field("id")
+ tenant_id = _Field("tenant_id")
+
+ def __init__(self, **kwargs):
+ self.id = kwargs.get("id", "")
+ self.name = kwargs.get("name", "")
+ self.url = kwargs.get("url", "")
+ self.server_type = kwargs.get("server_type", "sse")
+ self.tenant_id = kwargs.get("tenant_id", "tenant_1")
+ self.variables = kwargs.get("variables", {})
+ self.headers = kwargs.get("headers", {})
+
+ def to_dict(self):
+ return {
+ "id": self.id,
+ "name": self.name,
+ "url": self.url,
+ "server_type": self.server_type,
+ "tenant_id": self.tenant_id,
+ "variables": self.variables,
+ "headers": self.headers,
+ }
+
+
+class _DummyMCPServerService:
+ @staticmethod
+ def get_servers(*_args, **_kwargs):
+ return []
+
+ @staticmethod
+ def get_or_none(*_args, **_kwargs):
+ return None
+
+ @staticmethod
+ def get_by_id(*_args, **_kwargs):
+ return False, None
+
+ @staticmethod
+ def get_by_name_and_tenant(*_args, **_kwargs):
+ return False, None
+
+ @staticmethod
+ def insert(**_kwargs):
+ return True
+
+ @staticmethod
+ def filter_update(*_args, **_kwargs):
+ return True
+
+ @staticmethod
+ def delete_by_ids(*_args, **_kwargs):
+ return True
+
+
+class _DummyTenantService:
+ @staticmethod
+ def get_by_id(*_args, **_kwargs):
+ return True, SimpleNamespace(id="tenant_1")
+
+
+class _DummyTool:
+ def __init__(self, name):
+ self._name = name
+
+ def model_dump(self):
+ return {"name": self._name}
+
+
+class _DummyMCPToolCallSession:
+ def __init__(self, _mcp_server, _variables):
+ self._tools = [_DummyTool("tool_a"), _DummyTool("tool_b")]
+
+ def get_tools(self, _timeout):
+ return self._tools
+
+ def tool_call(self, _name, _arguments, _timeout):
+ return "ok"
+
+
+def _run(coro):
+ return asyncio.run(coro)
+
+
+def _set_request_json(monkeypatch, module, payload):
+ async def _request_json():
+ return payload
+
+ monkeypatch.setattr(module, "get_request_json", _request_json)
+
+
+@pytest.fixture(scope="session")
+def auth():
+ return "unit-auth"
+
+
+@pytest.fixture(scope="session", autouse=True)
+def set_tenant_info():
+ return None
+
+
+def _load_mcp_api(monkeypatch):
+ repo_root = Path(__file__).resolve().parents[3]
+
+ quart_mod = ModuleType("quart")
+ quart_mod.Response = object
+ quart_mod.request = SimpleNamespace(args=_Args({}))
+ monkeypatch.setitem(sys.modules, "quart", quart_mod)
+
+ common_pkg = ModuleType("common")
+ common_pkg.__path__ = [str(repo_root / "common")]
+ monkeypatch.setitem(sys.modules, "common", common_pkg)
+
+ constants_mod = ModuleType("common.constants")
+ constants_mod.VALID_MCP_SERVER_TYPES = {"sse", "streamable-http"}
+ monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
+
+ apps_mod = ModuleType("api.apps")
+ apps_mod.current_user = SimpleNamespace(id="tenant_1")
+ apps_mod.login_required = lambda func: func
+ monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
+
+ db_models_mod = ModuleType("api.db.db_models")
+ db_models_mod.MCPServer = _DummyMCPServer
+ monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
+
+ mcp_service_mod = ModuleType("api.db.services.mcp_server_service")
+ mcp_service_mod.MCPServerService = _DummyMCPServerService
+ monkeypatch.setitem(sys.modules, "api.db.services.mcp_server_service", mcp_service_mod)
+
+ user_service_mod = ModuleType("api.db.services.user_service")
+ user_service_mod.TenantService = _DummyTenantService
+ monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
+
+ mcp_conn_mod = ModuleType("common.mcp_tool_call_conn")
+ mcp_conn_mod.MCPToolCallSession = _DummyMCPToolCallSession
+ mcp_conn_mod.close_multiple_mcp_toolcall_sessions = lambda _sessions: None
+ monkeypatch.setitem(sys.modules, "common.mcp_tool_call_conn", mcp_conn_mod)
+
+ api_utils_mod = ModuleType("api.utils.api_utils")
+
+ async def _default_request_json():
+ return {}
+
+ def _get_json_result(code=0, message="success", data=None):
+ return {"code": code, "message": message, "data": data}
+
+ def _get_data_error_result(code=102, message="Sorry! Data missing!"):
+ return {"code": code, "message": message}
+
+ def _server_error_response(error):
+ return {"code": 100, "message": repr(error)}
+
+ async def _get_mcp_tools(*_args, **_kwargs):
+ return {}
+
+ def _validate_request(*_args, **_kwargs):
+ def _decorator(func):
+ @wraps(func)
+ async def _wrapped(*func_args, **func_kwargs):
+ if inspect.iscoroutinefunction(func):
+ return await func(*func_args, **func_kwargs)
+ return func(*func_args, **func_kwargs)
+
+ return _wrapped
+
+ return _decorator
+
+ api_utils_mod.get_request_json = _default_request_json
+ api_utils_mod.get_json_result = _get_json_result
+ api_utils_mod.get_data_error_result = _get_data_error_result
+ api_utils_mod.server_error_response = _server_error_response
+ api_utils_mod.validate_request = _validate_request
+ api_utils_mod.get_mcp_tools = _get_mcp_tools
+ monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
+
+ web_utils_mod = ModuleType("api.utils.web_utils")
+
+ def _get_float(data, key, default):
+ try:
+ return float(data.get(key, default))
+ except (TypeError, ValueError):
+ return default
+
+ def _safe_json_parse(value):
+ if isinstance(value, (dict, list)):
+ return value
+ if value in (None, ""):
+ return {}
+ try:
+ return json.loads(value)
+ except (TypeError, ValueError):
+ return {}
+
+ web_utils_mod.get_float = _get_float
+ web_utils_mod.safe_json_parse = _safe_json_parse
+ monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod)
+
+ module_name = "test_mcp_api_unit_module"
+ module_path = repo_root / "api" / "apps" / "restful_apis" / "mcp_api.py"
+ spec = importlib.util.spec_from_file_location(module_name, module_path)
+ module = importlib.util.module_from_spec(spec)
+ module.manager = _DummyManager()
+ monkeypatch.setitem(sys.modules, module_name, module)
+ spec.loader.exec_module(module)
+ return module
+
+
+@pytest.mark.p2
+def test_list_mcp_desc_pagination_and_exception(monkeypatch):
+ module = _load_mcp_api(monkeypatch)
+
+ monkeypatch.setattr(
+ module,
+ "request",
+ SimpleNamespace(args=_Args({"keywords": "k", "page": "2", "page_size": "1", "orderby": "create_time", "desc": "false"})),
+ )
+ _set_request_json(monkeypatch, module, {"mcp_ids": []})
+ monkeypatch.setattr(module.MCPServerService, "get_servers", lambda *_args, **_kwargs: [{"id": "a"}, {"id": "b"}])
+
+ res = _run(module.list_mcp())
+ assert res["code"] == 0
+ assert res["data"]["total"] == 2
+ assert res["data"]["mcp_servers"] == [{"id": "b"}]
+
+ monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({})))
+ _set_request_json(monkeypatch, module, {"mcp_ids": []})
+
+ def _raise_list(*_args, **_kwargs):
+ raise RuntimeError("list explode")
+
+ monkeypatch.setattr(module.MCPServerService, "get_servers", _raise_list)
+ res = _run(module.list_mcp())
+ assert res["code"] == 100
+ assert "list explode" in res["message"]
+
+
+@pytest.mark.p2
+def test_detail_not_found_success_and_exception(monkeypatch):
+ module = _load_mcp_api(monkeypatch)
+ monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({})))
+
+ monkeypatch.setattr(module.MCPServerService, "get_or_none", lambda **_kwargs: None)
+ res = module.detail("mcp-1")
+ assert res["code"] == 102
+ assert "Cannot find MCP server mcp-1 for user tenant_1" in res["message"]
+
+ monkeypatch.setattr(
+ module.MCPServerService,
+ "get_or_none",
+ lambda **_kwargs: _DummyMCPServer(id="mcp-1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1"),
+ )
+ res = module.detail("mcp-1")
+ assert res["code"] == 0
+ assert res["data"]["id"] == "mcp-1"
+
+ def _raise_detail(**_kwargs):
+ raise RuntimeError("detail explode")
+
+ monkeypatch.setattr(module.MCPServerService, "get_or_none", _raise_detail)
+ res = module.detail("mcp-1")
+ assert res["code"] == 100
+ assert "detail explode" in res["message"]
+
+
+@pytest.mark.p2
+def test_create_validation_guards(monkeypatch):
+ module = _load_mcp_api(monkeypatch)
+
+ monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None))
+
+ _set_request_json(monkeypatch, module, {"name": "srv", "url": "http://a", "server_type": "invalid"})
+ res = _run(module.create.__wrapped__())
+ assert "Unsupported MCP server type" in res["message"]
+
+ _set_request_json(monkeypatch, module, {"name": "", "url": "http://a", "server_type": "sse"})
+ res = _run(module.create.__wrapped__())
+ assert "Invalid MCP name" in res["message"]
+
+ monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (True, object()))
+ _set_request_json(monkeypatch, module, {"name": "srv", "url": "http://a", "server_type": "sse"})
+ res = _run(module.create.__wrapped__())
+ assert "Duplicated MCP server name" in res["message"]
+
+ monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None))
+ _set_request_json(monkeypatch, module, {"name": "srv", "url": "", "server_type": "sse"})
+ res = _run(module.create.__wrapped__())
+ assert "Invalid url" in res["message"]
+
+
+@pytest.mark.p2
+def test_create_service_paths(monkeypatch):
+ module = _load_mcp_api(monkeypatch)
+
+ base_payload = {
+ "name": "srv",
+ "url": "http://server",
+ "server_type": "sse",
+ "headers": '{"Authorization": "x"}',
+ "variables": '{"tools": {"old": 1}, "token": "abc"}',
+ "timeout": "2.5",
+ }
+
+ monkeypatch.setattr(module, "get_uuid", lambda: "uuid-create")
+ monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None))
+
+ _set_request_json(monkeypatch, module, dict(base_payload))
+ monkeypatch.setattr(module.TenantService, "get_by_id", lambda *_args, **_kwargs: (False, None))
+ res = _run(module.create.__wrapped__())
+ assert "Tenant not found" in res["message"]
+
+ _set_request_json(monkeypatch, module, dict(base_payload))
+ monkeypatch.setattr(module.TenantService, "get_by_id", lambda *_args, **_kwargs: (True, object()))
+
+ async def _thread_pool_tools_error(_func, _servers, _timeout):
+ return None, "tools error"
+
+ monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error)
+ res = _run(module.create.__wrapped__())
+ assert res["code"] == 102
+ assert "tools error" in res["message"]
+
+ _set_request_json(monkeypatch, module, dict(base_payload))
+
+ async def _thread_pool_ok(_func, servers, _timeout):
+ return {servers[0].name: [{"name": "tool_a"}, {"invalid": True}]}, None
+
+ monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok)
+ monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: False)
+ res = _run(module.create.__wrapped__())
+ assert res["code"] == 102
+ assert "Failed to create MCP server" in res["message"]
+
+ _set_request_json(monkeypatch, module, dict(base_payload))
+ monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: True)
+ res = _run(module.create.__wrapped__())
+ assert res["code"] == 0
+ assert res["data"]["id"] == "uuid-create"
+ assert res["data"]["tenant_id"] == "tenant_1"
+ assert res["data"]["variables"]["tools"] == {"tool_a": {"name": "tool_a"}}
+
+ _set_request_json(monkeypatch, module, dict(base_payload))
+
+ async def _thread_pool_raises(_func, _servers, _timeout):
+ raise RuntimeError("create explode")
+
+ monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_raises)
+ res = _run(module.create.__wrapped__())
+ assert res["code"] == 100
+ assert "create explode" in res["message"]
+
+
+@pytest.mark.p2
+def test_update_validation_guards(monkeypatch):
+ module = _load_mcp_api(monkeypatch)
+
+ existing = _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="tenant_1", variables={}, headers={})
+
+ _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"})
+ monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
+ res = _run(module.update("mcp-1"))
+ assert "Cannot find MCP server" in res["message"]
+
+ _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"})
+ monkeypatch.setattr(
+ module.MCPServerService,
+ "get_by_id",
+ lambda _mcp_id: (True, _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="other", variables={}, headers={})),
+ )
+ res = _run(module.update("mcp-1"))
+ assert "Cannot find MCP server" in res["message"]
+
+ _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "server_type": "invalid"})
+ monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))
+ res = _run(module.update("mcp-1"))
+ assert "Unsupported MCP server type" in res["message"]
+
+ _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "name": "a" * 256})
+ res = _run(module.update("mcp-1"))
+ assert "Invalid MCP name" in res["message"]
+
+ _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "url": ""})
+ res = _run(module.update("mcp-1"))
+ assert "Invalid url" in res["message"]
+
+
+@pytest.mark.p2
+def test_update_service_paths(monkeypatch):
+ module = _load_mcp_api(monkeypatch)
+
+ existing = _DummyMCPServer(
+ id="mcp-1",
+ name="srv",
+ url="http://server",
+ server_type="sse",
+ tenant_id="tenant_1",
+ variables={"tools": {"old": {"enabled": True}}, "token": "abc"},
+ headers={"Authorization": "old"},
+ )
+ updated = _DummyMCPServer(
+ id="mcp-1",
+ name="srv-new",
+ url="http://server-new",
+ server_type="sse",
+ tenant_id="tenant_1",
+ variables={"tools": {"tool_a": {"name": "tool_a"}}},
+ headers={"Authorization": "new"},
+ )
+
+ base_payload = {
+ "mcp_id": "mcp-1",
+ "name": "srv-new",
+ "url": "http://server-new",
+ "server_type": "sse",
+ "headers": '{"Authorization": "new"}',
+ "variables": '{"tools": {"ignore": 1}, "token": "new"}',
+ "timeout": "3.0",
+ }
+
+ _set_request_json(monkeypatch, module, dict(base_payload))
+ monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))
+
+ async def _thread_pool_tools_error(_func, _servers, _timeout):
+ return None, "update tools error"
+
+ monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error)
+ res = _run(module.update("mcp-1"))
+ assert res["code"] == 102
+ assert "update tools error" in res["message"]
+
+ _set_request_json(monkeypatch, module, dict(base_payload))
+
+ async def _thread_pool_ok(_func, servers, _timeout):
+ return {servers[0].name: [{"name": "tool_a"}, {"bad": True}]}, None
+
+ monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok)
+ monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: False)
+ res = _run(module.update("mcp-1"))
+ assert "Failed to updated MCP server" in res["message"]
+
+ _set_request_json(monkeypatch, module, dict(base_payload))
+ monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: True)
+
+ def _get_by_id_fetch_fail(_mcp_id):
+ if _get_by_id_fetch_fail.calls == 0:
+ _get_by_id_fetch_fail.calls += 1
+ return True, existing
+ return False, None
+
+ _get_by_id_fetch_fail.calls = 0
+ monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_fetch_fail)
+ res = _run(module.update("mcp-1"))
+ assert "Failed to fetch updated MCP server" in res["message"]
+
+ _set_request_json(monkeypatch, module, dict(base_payload))
+
+ def _get_by_id_success(_mcp_id):
+ if _get_by_id_success.calls == 0:
+ _get_by_id_success.calls += 1
+ return True, existing
+ return True, updated
+
+ _get_by_id_success.calls = 0
+ monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_success)
+ res = _run(module.update("mcp-1"))
+ assert res["code"] == 0
+ assert res["data"]["id"] == "mcp-1"
+
+ _set_request_json(monkeypatch, module, dict(base_payload))
+ monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))
+
+ async def _thread_pool_raises(_func, _servers, _timeout):
+ raise RuntimeError("update explode")
+
+ monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_raises)
+ res = _run(module.update("mcp-1"))
+ assert res["code"] == 100
+ assert "update explode" in res["message"]
+
+
+@pytest.mark.p2
+def test_rm_failure_success_and_exception(monkeypatch):
+ module = _load_mcp_api(monkeypatch)
+ server = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1", variables={})
+ monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server))
+
+ _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})
+ monkeypatch.setattr(module.MCPServerService, "delete_by_ids", lambda _ids: False)
+ res = _run(module.rm("id1"))
+ assert "Failed to delete MCP servers" in res["message"]
+
+ _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})
+ monkeypatch.setattr(module.MCPServerService, "delete_by_ids", lambda _ids: True)
+ res = _run(module.rm("id1"))
+ assert res["code"] == 0
+ assert res["data"] is True
+
+ _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})
+
+ def _raise_rm(_ids):
+ raise RuntimeError("rm explode")
+
+ monkeypatch.setattr(module.MCPServerService, "delete_by_ids", _raise_rm)
+ res = _run(module.rm("id1"))
+ assert res["code"] == 100
+ assert "rm explode" in res["message"]
+
+
+@pytest.mark.p2
+def test_import_multiple_missing_servers_and_exception(monkeypatch):
+ module = _load_mcp_api(monkeypatch)
+
+ _set_request_json(monkeypatch, module, {"mcpServers": {}})
+ res = _run(module.import_multiple.__wrapped__())
+ assert "No MCP servers provided" in res["message"]
+
+ _set_request_json(monkeypatch, module, {"mcpServers": {"srv": {"type": "sse", "url": "http://x"}}, "timeout": "1"})
+
+ def _raise_import(**_kwargs):
+ raise RuntimeError("import explode")
+
+ monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", _raise_import)
+ res = _run(module.import_multiple.__wrapped__())
+ assert res["code"] == 100
+ assert "import explode" in res["message"]
+
+
+@pytest.mark.p2
+def test_import_multiple_mixed_results(monkeypatch):
+ module = _load_mcp_api(monkeypatch)
+
+ payload = {
+ "mcpServers": {
+ "missing_fields": {"type": "sse"},
+ "": {"type": "sse", "url": "http://empty"},
+ "dup": {"type": "sse", "url": "http://dup", "authorization_token": "dup-token"},
+ "tool_err": {"type": "sse", "url": "http://err"},
+ "insert_fail": {"type": "sse", "url": "http://fail"},
+ },
+ "timeout": "3",
+ }
+ _set_request_json(monkeypatch, module, payload)
+
+ monkeypatch.setattr(module, "get_uuid", lambda: "uuid-import")
+
+ def _get_by_name_and_tenant(name, tenant_id):
+ if name == "dup" and not _get_by_name_and_tenant.first_dup_seen:
+ _get_by_name_and_tenant.first_dup_seen = True
+ return True, object()
+ return False, None
+
+ _get_by_name_and_tenant.first_dup_seen = False
+ monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", _get_by_name_and_tenant)
+
+ async def _thread_pool_exec(func, servers, _timeout):
+ mcp_server = servers[0]
+ if mcp_server.name == "tool_err":
+ return None, "tool call failed"
+ return {mcp_server.name: [{"name": "tool_a"}, {"invalid": True}]}, None
+
+ monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec)
+
+ def _insert(**kwargs):
+ return kwargs["name"] != "insert_fail"
+
+ monkeypatch.setattr(module.MCPServerService, "insert", _insert)
+
+ res = _run(module.import_multiple.__wrapped__())
+ assert res["code"] == 0
+
+ results = {item["server"]: item for item in res["data"]["results"]}
+ assert results["missing_fields"]["success"] is False
+ assert "Missing required fields" in results["missing_fields"]["message"]
+ assert results[""]["success"] is False
+ assert "Invalid MCP name" in results[""]["message"]
+ assert results["tool_err"]["success"] is False
+ assert "tool call failed" in results["tool_err"]["message"]
+ assert results["insert_fail"]["success"] is False
+ assert "Failed to create MCP server" in results["insert_fail"]["message"]
+ assert results["dup"]["success"] is True
+ assert results["dup"]["new_name"] == "dup_0"
+ assert "Renamed from 'dup' to 'dup_0' avoid duplication" == results["dup"]["message"]
+
+
+@pytest.mark.p2
+def test_detail_download_success_and_exception(monkeypatch):
+ module = _load_mcp_api(monkeypatch)
+ monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({"mode": "download"})))
+
+ monkeypatch.setattr(
+ module.MCPServerService,
+ "get_by_id",
+ lambda _mcp_id: (
+ True,
+ _DummyMCPServer(
+ id="id1",
+ name="srv-one",
+ url="http://one",
+ server_type="sse",
+ tenant_id="tenant_1",
+ variables={"authorization_token": "tok", "tools": {"tool_a": {"enabled": True}}},
+ ),
+ ),
+ )
+ res = module.detail("id1")
+ assert res["code"] == 0
+ assert list(res["data"]["mcpServers"].keys()) == ["srv-one"]
+
+ monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
+ res = module.detail("missing")
+ assert res["code"] == 102
+ assert "Cannot find MCP server missing for user tenant_1" in res["message"]
+
+ monkeypatch.setattr(
+ module.MCPServerService,
+ "get_by_id",
+ lambda _mcp_id: (
+ True,
+ _DummyMCPServer(
+ id="id2",
+ name="srv-two",
+ url="http://two",
+ server_type="sse",
+ tenant_id="other",
+ variables={},
+ ),
+ ),
+ )
+ res = module.detail("id2")
+ assert res["code"] == 102
+ assert "Cannot find MCP server id2 for user tenant_1" in res["message"]
+
+ def _raise_export(_mcp_id):
+ raise RuntimeError("export explode")
+
+ monkeypatch.setattr(module.MCPServerService, "get_by_id", _raise_export)
+ res = module.detail("id1")
+ assert res["code"] == 100
+ assert "export explode" in res["message"]
+
+
+@pytest.mark.p2
+def test_test_mcp_route_matrix_unit(monkeypatch):
+ module = _load_mcp_api(monkeypatch)
+
+ _set_request_json(monkeypatch, module, {"url": "", "server_type": "sse"})
+ res = _run(module.test_mcp("mcp-1"))
+ assert "Invalid MCP url" in res["message"]
+
+ _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "invalid"})
+ res = _run(module.test_mcp("mcp-1"))
+ assert "Unsupported MCP server type" in res["message"]
+
+ close_calls = []
+
+ async def _thread_pool_exec_inner_error(func, *args):
+ if func is module.close_multiple_mcp_toolcall_sessions:
+ close_calls.append(args[0])
+ return None
+ if getattr(func, "__name__", "") == "get_tools":
+ raise RuntimeError("get tools explode")
+ return func(*args)
+
+ monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_inner_error)
+ _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
+ res = _run(module.test_mcp("mcp-1"))
+ assert res["code"] == 102
+ assert "Test MCP error: get tools explode" in res["message"]
+ assert close_calls and len(close_calls[-1]) == 1
+
+ close_calls_success = []
+
+ async def _thread_pool_exec_success(func, *args):
+ if func is module.close_multiple_mcp_toolcall_sessions:
+ close_calls_success.append(args[0])
+ return None
+ return func(*args)
+
+ monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_success)
+ _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
+ res = _run(module.test_mcp("mcp-1"))
+ assert res["code"] == 0
+ assert res["data"][0]["name"] == "tool_a"
+ assert all(tool["enabled"] is True for tool in res["data"])
+ assert close_calls_success and len(close_calls_success[-1]) == 1
+
+ def _raise_session(*_args, **_kwargs):
+ raise RuntimeError("session explode")
+
+ monkeypatch.setattr(module, "MCPToolCallSession", _raise_session)
+ _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
+ res = _run(module.test_mcp("mcp-1"))
+ assert res["code"] == 100
+ assert "session explode" in res["message"]
diff --git a/test/testcases/restful_api/test_memories_messages.py b/test/testcases/restful_api/test_memories_messages.py
new file mode 100644
index 0000000000..12fcdc0df2
--- /dev/null
+++ b/test/testcases/restful_api/test_memories_messages.py
@@ -0,0 +1,210 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+import uuid
+
+import pytest
+
+
+@pytest.fixture
+def memory_cleanup(rest_client):
+ created_ids: list[str] = []
+
+ def _cleanup():
+ cleanup_errors = []
+ for memory_id in created_ids:
+ delete_res = rest_client.delete(f"/memories/{memory_id}")
+ if delete_res.status_code != 200:
+ cleanup_errors.append((memory_id, delete_res.status_code, delete_res.text))
+ continue
+ delete_payload = delete_res.json()
+ if delete_payload["code"] not in (0, 404):
+ cleanup_errors.append((memory_id, delete_res.status_code, delete_payload))
+ assert not cleanup_errors, f"Memory cleanup failed: {cleanup_errors}"
+
+ yield created_ids
+ _cleanup()
+
+
+@pytest.fixture
+def create_memory_resource(rest_client, memory_cleanup):
+ def _create(name_prefix: str = "restful_memory") -> str:
+ payload = {
+ "name": f"{name_prefix}_{uuid.uuid4().hex[:8]}",
+ "memory_type": ["raw"],
+ "embd_id": "BAAI/bge-small-en-v1.5@Builtin",
+ "llm_id": "glm-4-flash@ZHIPU-AI",
+ }
+ res = rest_client.post("/memories", json=payload)
+ assert res.status_code == 200
+ res_payload = res.json()
+ assert res_payload["code"] == 0, res_payload
+ memory_id = res_payload["data"]["id"]
+ memory_cleanup.append(memory_id)
+ return memory_id
+
+ yield _create
+
+
+def _add_message(rest_client, memory_id: str, user_input: str, agent_response: str) -> None:
+ add_res = rest_client.post(
+ "/messages",
+ json={
+ "memory_id": [memory_id],
+ "agent_id": uuid.uuid4().hex,
+ "session_id": uuid.uuid4().hex,
+ "user_id": uuid.uuid4().hex,
+ "user_input": user_input,
+ "agent_response": agent_response,
+ },
+ )
+ assert add_res.status_code == 200
+ add_payload = add_res.json()
+ assert add_payload["code"] == 0, add_payload
+
+
+def _wait_for_memory_messages(rest_client, memory_id: str, timeout: float = 10, interval: float = 0.2) -> list[dict]:
+ deadline = time.time() + timeout
+ last_payload = None
+ while time.time() < deadline:
+ res = rest_client.get(f"/memories/{memory_id}")
+ if res.status_code == 200:
+ payload = res.json()
+ last_payload = payload
+ if payload.get("code") == 0:
+ message_list = payload.get("data", {}).get("messages", {}).get("message_list", [])
+ if message_list:
+ return message_list
+ time.sleep(interval)
+ pytest.fail(f"Timed out waiting for memory messages: {last_payload}")
+
+
+@pytest.mark.p1
+def test_memory_crud_cycle(rest_client, create_memory_resource):
+ memory_id = create_memory_resource("restful_memory_crud")
+
+ list_res = rest_client.get("/memories")
+ assert list_res.status_code == 200
+ list_payload = list_res.json()
+ assert list_payload["code"] == 0, list_payload
+ assert any(item["id"] == memory_id for item in list_payload["data"]["memory_list"]), list_payload
+
+ config_res = rest_client.get(f"/memories/{memory_id}/config")
+ assert config_res.status_code == 200
+ config_payload = config_res.json()
+ assert config_payload["code"] == 0, config_payload
+ assert config_payload["data"]["id"] == memory_id, config_payload
+
+ update_res = rest_client.put(
+ f"/memories/{memory_id}",
+ json={"name": f"updated_{uuid.uuid4().hex[:6]}", "permissions": "me"},
+ )
+ assert update_res.status_code == 200
+ update_payload = update_res.json()
+ assert update_payload["code"] == 0, update_payload
+
+ delete_res = rest_client.delete(f"/memories/{memory_id}")
+ assert delete_res.status_code == 200
+ delete_payload = delete_res.json()
+ assert delete_payload["code"] == 0, delete_payload
+
+
+@pytest.mark.p2
+def test_memory_create_missing_required_fields(rest_client):
+ res = rest_client.post("/memories", json={"name": "missing_models", "memory_type": ["raw"]})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 101, payload
+
+
+@pytest.mark.p1
+def test_messages_add_list_recent_content_update_forget(rest_client, create_memory_resource):
+ memory_id = create_memory_resource("restful_message_memory")
+ _add_message(
+ rest_client,
+ memory_id,
+ user_input="what is coriander?",
+ agent_response="coriander can refer to leaves or seeds",
+ )
+
+ message_list = _wait_for_memory_messages(rest_client, memory_id)
+
+ message_id = message_list[0]["message_id"]
+
+ recent_res = rest_client.get("/messages", params={"memory_id": memory_id, "limit": 10})
+ assert recent_res.status_code == 200
+ recent_payload = recent_res.json()
+ assert recent_payload["code"] == 0, recent_payload
+ assert any(item["message_id"] == message_id for item in recent_payload["data"]), recent_payload
+
+ content_res = rest_client.get(f"/messages/{memory_id}:{message_id}/content")
+ assert content_res.status_code == 200
+ content_payload = content_res.json()
+ assert content_payload["code"] == 0, content_payload
+ assert content_payload["data"]["content"], content_payload
+
+ update_res = rest_client.put(f"/messages/{memory_id}:{message_id}", json={"status": False})
+ assert update_res.status_code == 200
+ update_payload = update_res.json()
+ assert update_payload["code"] == 0, update_payload
+
+ forget_res = rest_client.delete(f"/messages/{memory_id}:{message_id}")
+ assert forget_res.status_code == 200
+ forget_payload = forget_res.json()
+ assert forget_payload["code"] == 0, forget_payload
+
+
+@pytest.mark.p2
+def test_message_status_validation_requires_boolean(rest_client, create_memory_resource):
+ memory_id = create_memory_resource("restful_message_status_validation")
+ _add_message(rest_client, memory_id, user_input="hello", agent_response="hello")
+
+ message_id = _wait_for_memory_messages(rest_client, memory_id)[0]["message_id"]
+
+ invalid_update = rest_client.put(f"/messages/{memory_id}:{message_id}", json={"status": "false"})
+ assert invalid_update.status_code == 200
+ invalid_payload = invalid_update.json()
+ assert invalid_payload["code"] == 101, invalid_payload
+ assert "Status must be a boolean." in invalid_payload["message"], invalid_payload
+
+
+@pytest.mark.p2
+def test_messages_recent_requires_memory_ids(rest_client):
+ res = rest_client.get("/messages")
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 101, payload
+ assert "memory_ids is required" in payload["message"], payload
+
+
+@pytest.mark.p2
+def test_message_search_route_contract(rest_client, create_memory_resource):
+ memory_id = create_memory_resource("restful_message_search")
+ _add_message(
+ rest_client,
+ memory_id,
+ user_input="what is pineapple?",
+ agent_response="pineapple is a tropical fruit",
+ )
+
+ _wait_for_memory_messages(rest_client, memory_id)
+
+ res = rest_client.get("/messages/search", params={"memory_id": memory_id, "query": "pineapple", "top_n": 3})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ assert isinstance(payload["data"], list), payload
diff --git a/test/testcases/restful_api/test_memory_messages.py b/test/testcases/restful_api/test_memory_messages.py
new file mode 100644
index 0000000000..dcf5a3704f
--- /dev/null
+++ b/test/testcases/restful_api/test_memory_messages.py
@@ -0,0 +1,165 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import uuid
+
+import pytest
+
+
+def _memory_payload(name: str) -> dict:
+ return {
+ "name": name,
+ "memory_type": ["raw"],
+ "embd_id": "BAAI/bge-small-en-v1.5@Builtin",
+ "llm_id": "glm-4-flash@ZHIPU-AI",
+ }
+
+
+def _create_memory(rest_client, name: str) -> dict:
+ res = rest_client.post("/memories", json=_memory_payload(name))
+ assert res.status_code == 200
+ payload = res.json()
+ if payload["code"] == 0:
+ return payload["data"]
+
+ pytest.fail(f"Failed to create memory: {payload}")
+
+
+@pytest.fixture
+def memory_resource(rest_client):
+ memory = _create_memory(rest_client, f"restful_memory_{uuid.uuid4().hex[:8]}")
+ memory_id = memory["id"]
+ try:
+ yield memory
+ finally:
+ delete_res = rest_client.delete(f"/memories/{memory_id}")
+ assert delete_res.status_code == 200, delete_res.text
+ delete_payload = delete_res.json()
+ assert delete_payload["code"] in (0, 404), delete_payload
+
+
+@pytest.mark.p2
+def test_memory_and_message_routes_require_auth(rest_client_noauth):
+ memory_res = rest_client_noauth.get("/memories")
+ assert memory_res.status_code == 401
+ memory_payload = memory_res.json()
+ assert memory_payload["code"] == 401, memory_payload
+
+ msg_list_res = rest_client_noauth.get("/messages")
+ assert msg_list_res.status_code == 401
+ msg_list_payload = msg_list_res.json()
+ assert msg_list_payload["code"] == 401, msg_list_payload
+
+ msg_search_res = rest_client_noauth.get("/messages/search")
+ assert msg_search_res.status_code == 401
+ msg_search_payload = msg_search_res.json()
+ assert msg_search_payload["code"] == 401, msg_search_payload
+
+
+@pytest.mark.p2
+def test_memory_crud_and_config(rest_client):
+ memory = _create_memory(rest_client, f"restful_memory_crud_{uuid.uuid4().hex[:8]}")
+ memory_id = memory["id"]
+ try:
+ config_res = rest_client.get(f"/memories/{memory_id}/config")
+ assert config_res.status_code == 200
+ config_payload = config_res.json()
+ assert config_payload["code"] == 0, config_payload
+ assert config_payload["data"]["id"] == memory_id, config_payload
+
+ list_res = rest_client.get("/memories", params={"keywords": memory["name"]})
+ assert list_res.status_code == 200
+ list_payload = list_res.json()
+ assert list_payload["code"] == 0, list_payload
+ assert any(item["id"] == memory_id for item in list_payload["data"]["memory_list"]), list_payload
+
+ update_res = rest_client.put(f"/memories/{memory_id}", json={"name": "restful_memory_updated"})
+ assert update_res.status_code == 200
+ update_payload = update_res.json()
+ assert update_payload["code"] == 0, update_payload
+ finally:
+ delete_res = rest_client.delete(f"/memories/{memory_id}")
+ assert delete_res.status_code == 200, delete_res.text
+ delete_payload = delete_res.json()
+ assert delete_payload["code"] in (0, 404), delete_payload
+
+
+@pytest.mark.p2
+def test_memory_update_invalid_name(rest_client, memory_resource):
+ memory_id = memory_resource["id"]
+ res = rest_client.put(f"/memories/{memory_id}", json={"name": " "})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 101, payload
+ assert "cannot be empty" in payload["message"], payload
+
+
+@pytest.mark.p2
+def test_messages_list_and_search_validation_contracts(rest_client, memory_resource):
+ memory_id = memory_resource["id"]
+
+ list_res = rest_client.get("/messages", params={"memory_id": memory_id, "limit": 10})
+ assert list_res.status_code == 200
+ list_payload = list_res.json()
+ assert list_payload["code"] == 0, list_payload
+ assert isinstance(list_payload["data"], list), list_payload
+
+ missing_memory_res = rest_client.get("/messages")
+ assert missing_memory_res.status_code == 200
+ missing_memory_payload = missing_memory_res.json()
+ assert missing_memory_payload["code"] == 101, missing_memory_payload
+ assert "memory_ids is required" in missing_memory_payload["message"], missing_memory_payload
+
+ search_res = rest_client.get("/messages/search", params={"memory_id": memory_id, "query": "coriander"})
+ assert search_res.status_code == 200
+ search_payload = search_res.json()
+ assert search_payload["code"] == 0, search_payload
+ assert isinstance(search_payload["data"], list), search_payload
+
+ search_no_memory = rest_client.get("/messages/search", params={"query": "coriander"})
+ assert search_no_memory.status_code == 200
+ search_no_memory_payload = search_no_memory.json()
+ assert search_no_memory_payload["code"] == 0, search_no_memory_payload
+ assert isinstance(search_no_memory_payload["data"], list), search_no_memory_payload
+
+
+@pytest.mark.p2
+def test_message_update_forget_and_content_error_contracts(rest_client, memory_resource):
+ memory_id = memory_resource["id"]
+
+ invalid_status_res = rest_client.put(
+ f"/messages/{memory_id}:1",
+ json={"status": "false"},
+ )
+ assert invalid_status_res.status_code == 200
+ invalid_status_payload = invalid_status_res.json()
+ assert invalid_status_payload["code"] == 101, invalid_status_payload
+ assert "Status must be a boolean" in invalid_status_payload["message"], invalid_status_payload
+
+ missing_content_res = rest_client.get(f"/messages/{memory_id}:999999/content")
+ assert missing_content_res.status_code == 200
+ missing_content_payload = missing_content_res.json()
+ assert missing_content_payload["code"] == 404, missing_content_payload
+
+ invalid_memory_forget = rest_client.delete("/messages/missing_memory_id:1")
+ assert invalid_memory_forget.status_code == 200
+ invalid_memory_forget_payload = invalid_memory_forget.json()
+ assert invalid_memory_forget_payload["code"] == 404, invalid_memory_forget_payload
+
+ invalid_memory_update = rest_client.put("/messages/missing_memory_id:1", json={"status": False})
+ assert invalid_memory_update.status_code == 200
+ invalid_memory_update_payload = invalid_memory_update.json()
+ assert invalid_memory_update_payload["code"] == 404, invalid_memory_update_payload
diff --git a/test/testcases/restful_api/test_openai_compatible.py b/test/testcases/restful_api/test_openai_compatible.py
new file mode 100644
index 0000000000..d858576725
--- /dev/null
+++ b/test/testcases/restful_api/test_openai_compatible.py
@@ -0,0 +1,212 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+
+import pytest
+
+
+def _sse_events(response_text: str) -> list[str]:
+ return [line[5:] for line in response_text.splitlines() if line.startswith("data:")]
+
+
+@pytest.mark.p2
+@pytest.mark.parametrize(
+ "payload, expected_message",
+ [
+ (
+ {
+ "model": "model",
+ "messages": [{"role": "user", "content": "hello"}],
+ "extra_body": "invalid_extra_body",
+ },
+ "extra_body must be an object.",
+ ),
+ (
+ {
+ "model": "model",
+ "messages": [{"role": "user", "content": "hello"}],
+ "extra_body": {"reference_metadata": "invalid_reference_metadata"},
+ },
+ "reference_metadata must be an object.",
+ ),
+ (
+ {
+ "model": "model",
+ "messages": [{"role": "user", "content": "hello"}],
+ "extra_body": {"reference_metadata": {"fields": "author"}},
+ },
+ "reference_metadata.fields must be an array.",
+ ),
+ (
+ {
+ "model": "model",
+ "messages": [],
+ },
+ "You have to provide messages.",
+ ),
+ (
+ {
+ "model": "model",
+ "messages": [{"role": "assistant", "content": "hello"}],
+ },
+ "The last content of this conversation is not from user.",
+ ),
+ ],
+)
+def test_openai_compatible_validation_payloads(rest_client, create_chat, payload, expected_message):
+ chat_id = create_chat("restful_openai_validation_chat")
+ res = rest_client.post(f"/openai/{chat_id}/chat/completions", json=payload)
+ assert res.status_code == 200
+ data = res.json()
+ assert data["code"] != 0, data
+ assert expected_message in data.get("message", ""), data
+
+
+@pytest.mark.p2
+def test_openai_compatible_metadata_condition_requires_object(rest_client, create_chat):
+ chat_id = create_chat("restful_openai_metadata_condition_chat")
+ res = rest_client.post(
+ f"/openai/{chat_id}/chat/completions",
+ json={
+ "model": "model",
+ "messages": [{"role": "user", "content": "hello"}],
+ "extra_body": {"metadata_condition": "invalid"},
+ },
+ )
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 102, payload
+ assert "metadata_condition must be an object." in payload["message"], payload
+
+
+@pytest.mark.p2
+def test_openai_compatible_invalid_chat(rest_client):
+ res = rest_client.post(
+ "/openai/invalid_chat_id/chat/completions",
+ json={
+ "model": "model",
+ "messages": [{"role": "user", "content": "hello"}],
+ "stream": False,
+ },
+ )
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] != 0, payload
+ assert "don't own the chat" in payload["message"], payload
+
+
+@pytest.mark.p2
+def test_openai_compatible_nonstream_shape(rest_client, create_chat):
+ chat_id = create_chat("restful_openai_nonstream_chat")
+ res = rest_client.post(
+ f"/openai/{chat_id}/chat/completions",
+ json={
+ "model": "model",
+ "messages": [{"role": "user", "content": "hello"}],
+ "stream": False,
+ },
+ timeout=60,
+ )
+ assert res.status_code == 200
+ payload = res.json()
+
+ assert payload["object"] == "chat.completion", payload
+ assert isinstance(payload["choices"], list) and payload["choices"], payload
+ first_choice = payload["choices"][0]
+ assert first_choice.get("finish_reason") == "stop", payload
+ assert first_choice.get("message", {}).get("role") == "assistant", payload
+ assert "content" in first_choice.get("message", {}), payload
+
+ usage = payload.get("usage", {})
+ assert "prompt_tokens" in usage, usage
+ assert "completion_tokens" in usage, usage
+ assert "total_tokens" in usage, usage
+ assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"], usage
+
+
+@pytest.mark.p2
+def test_openai_compatible_nonstream_with_reference_output_shape(rest_client, create_chat):
+ chat_id = create_chat("restful_openai_reference_chat")
+ res = rest_client.post(
+ f"/openai/{chat_id}/chat/completions",
+ json={
+ "model": "model",
+ "messages": [{"role": "user", "content": "hello"}],
+ "stream": False,
+ "extra_body": {
+ "reference": True,
+ "reference_metadata": {"include": True, "fields": ["author"]},
+ },
+ },
+ timeout=60,
+ )
+ assert res.status_code == 200
+ payload = res.json()
+ choice_msg = payload["choices"][0]["message"]
+ assert "reference" in choice_msg, payload
+ assert isinstance(choice_msg["reference"], list), payload
+
+
+@pytest.mark.p2
+def test_openai_compatible_stream_shape_and_done_semantics(rest_client, create_chat):
+ chat_id = create_chat("restful_openai_stream_chat")
+ res = rest_client.post(
+ f"/openai/{chat_id}/chat/completions",
+ json={
+ "model": "model",
+ "messages": [{"role": "user", "content": "hello"}],
+ "stream": True,
+ "extra_body": {"reference": True},
+ },
+ timeout=60,
+ )
+ assert res.status_code == 200
+ content_type = res.headers.get("Content-Type", "")
+ assert "text/event-stream" in content_type, content_type
+
+ events = _sse_events(res.text)
+ assert events, res.text
+ assert events[-1].strip() == "[DONE]", events[-1]
+
+ json_events = [json.loads(evt) for evt in events if evt.strip() != "[DONE]"]
+ assert json_events, events
+ assert any(evt.get("object") == "chat.completion.chunk" for evt in json_events), json_events
+ assert any(evt.get("choices", [{}])[0].get("finish_reason") == "stop" for evt in json_events), json_events
+
+
+@pytest.mark.p2
+def test_openai_compatible_reference_metadata_fields_filter_accepts_array(rest_client, create_chat):
+ chat_id = create_chat("restful_openai_reference_fields_array_chat")
+ res = rest_client.post(
+ f"/openai/{chat_id}/chat/completions",
+ json={
+ "model": "model",
+ "messages": [{"role": "user", "content": "hello"}],
+ "stream": False,
+ "extra_body": {
+ "reference": True,
+ "reference_metadata": {"include": True, "fields": ["author", "year"]},
+ },
+ },
+ timeout=60,
+ )
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload.get("choices"), payload
+ choice_msg = payload["choices"][0]["message"]
+ assert "reference" in choice_msg, payload
+ assert isinstance(choice_msg["reference"], list), payload
diff --git a/test/testcases/restful_api/test_plugin_tools.py b/test/testcases/restful_api/test_plugin_tools.py
new file mode 100644
index 0000000000..c151394c29
--- /dev/null
+++ b/test/testcases/restful_api/test_plugin_tools.py
@@ -0,0 +1,92 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+import pytest
+
+
+@pytest.mark.p2
+def test_plugin_tools_requires_auth(rest_client_noauth):
+ res = rest_client_noauth.get("/plugin/tools")
+ assert res.status_code == 401
+ payload = res.json()
+ assert payload["code"] == 401, payload
+
+
+@pytest.mark.p2
+def test_plugin_tools_contract(rest_client):
+ res = rest_client.get("/plugin/tools")
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ assert isinstance(payload["data"], list), payload
+
+
+class _DummyManager:
+ def route(self, *_args, **_kwargs):
+ def decorator(func):
+ return func
+
+ return decorator
+
+
+def _load_plugin_module(monkeypatch):
+ repo_root = Path(__file__).resolve().parents[3]
+
+ common_pkg = ModuleType("common")
+ common_pkg.__path__ = [str(repo_root / "common")]
+ monkeypatch.setitem(sys.modules, "common", common_pkg)
+
+ stub_apps = ModuleType("api.apps")
+ stub_apps.login_required = lambda func: func
+ monkeypatch.setitem(sys.modules, "api.apps", stub_apps)
+
+ stub_plugin = ModuleType("agent.plugin")
+
+ class _StubGlobalPluginManager:
+ @staticmethod
+ def get_llm_tools():
+ return []
+
+ stub_plugin.GlobalPluginManager = _StubGlobalPluginManager
+ monkeypatch.setitem(sys.modules, "agent.plugin", stub_plugin)
+
+ module_path = repo_root / "api" / "apps" / "restful_apis" / "plugin_api.py"
+ spec = importlib.util.spec_from_file_location("restful_plugin_api_unit", module_path)
+ module = importlib.util.module_from_spec(spec)
+ module.manager = _DummyManager()
+ spec.loader.exec_module(module)
+ return module
+
+
+@pytest.mark.p2
+def test_plugin_tools_metadata_shape_unit(monkeypatch):
+ module = _load_plugin_module(monkeypatch)
+
+ class _DummyTool:
+ def get_metadata(self):
+ return {"name": "dummy", "description": "test"}
+
+ monkeypatch.setattr(module.GlobalPluginManager, "get_llm_tools", staticmethod(lambda: [_DummyTool()]))
+ res = module.llm_tools()
+ assert res["code"] == 0
+ assert isinstance(res["data"], list)
+ assert res["data"][0]["name"] == "dummy"
+ assert res["data"][0]["description"] == "test"
diff --git a/test/testcases/restful_api/test_retrieval.py b/test/testcases/restful_api/test_retrieval.py
new file mode 100644
index 0000000000..bce37c4cdb
--- /dev/null
+++ b/test/testcases/restful_api/test_retrieval.py
@@ -0,0 +1,109 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+
+
+@pytest.mark.p1
+def test_dataset_search_rest_endpoint(rest_client, ensure_parsed_document):
+ dataset_id, _ = ensure_parsed_document()
+ res = rest_client.post(
+ f"/datasets/{dataset_id}/search",
+ json={"question": "test TXT file", "top_k": 5},
+ )
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ assert "chunks" in payload["data"], payload
+
+
+@pytest.mark.p2
+def test_multi_dataset_search_rest_endpoint(rest_client, ensure_parsed_document):
+ dataset_id, _ = ensure_parsed_document()
+ res = rest_client.post(
+ "/datasets/search",
+ json={"dataset_ids": [dataset_id], "question": "test TXT file", "top_k": 5},
+ )
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ assert "chunks" in payload["data"], payload
+
+
+@pytest.mark.p2
+def test_multi_dataset_search_with_metadata_filter(rest_client, ensure_parsed_document):
+ dataset_id, document_id = ensure_parsed_document()
+ meta_res = rest_client.patch(
+ f"/datasets/{dataset_id}/documents/metadatas",
+ json={
+ "selector": {"document_ids": [document_id]},
+ "updates": [{"key": "author", "value": "qa_batch2"}],
+ "deletes": [],
+ },
+ )
+ assert meta_res.status_code == 200
+ meta_payload = meta_res.json()
+ assert meta_payload["code"] == 0, meta_payload
+
+ res = rest_client.post(
+ "/datasets/search",
+ json={
+ "dataset_ids": [dataset_id],
+ "question": "test TXT file",
+ "meta_data_filter": {
+ "method": "manual",
+ "logic": "and",
+ "manual": [{"key": "author", "op": "=", "value": "qa_batch2"}],
+ },
+ },
+ )
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ assert "chunks" in payload["data"], payload
+
+
+@pytest.mark.p2
+def test_retrieval_compatibility_endpoint(rest_client, ensure_parsed_document):
+ dataset_id, _ = ensure_parsed_document()
+ # /api/v1/retrieval is SDK compatibility endpoint from api/apps/sdk/doc.py.
+ res = rest_client.post(
+ "/retrieval",
+ json={"dataset_ids": [dataset_id], "question": "test TXT file", "top_k": 5},
+ )
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ assert "chunks" in payload["data"], payload
+
+
+@pytest.mark.p2
+def test_retrieval_compatibility_requires_dataset_ids(rest_client):
+ res = rest_client.post("/retrieval", json={"question": "test"})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 102, payload
+ assert payload["message"] == "`dataset_ids` is required.", payload
+
+
+@pytest.mark.p2
+def test_retrieval_compatibility_requires_auth(rest_client_noauth):
+ res = rest_client_noauth.post("/retrieval", json={"question": "test", "dataset_ids": ["x"]})
+ assert res.status_code == 401
+ payload = res.json()
+ # token_required preserves legacy payload code/message while returning HTTP 401.
+ assert payload["code"] == 0, payload
+ assert payload["message"] == "`Authorization` can't be empty", payload
diff --git a/test/testcases/restful_api/test_router_contracts.py b/test/testcases/restful_api/test_router_contracts.py
new file mode 100644
index 0000000000..72683bad8f
--- /dev/null
+++ b/test/testcases/restful_api/test_router_contracts.py
@@ -0,0 +1,28 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+from configs import VERSION
+
+
+@pytest.mark.p1
+def test_route_not_found_returns_json(rest_client_noauth):
+ res = rest_client_noauth.get("/__missing_route__")
+ assert res.status_code == 404
+ payload = res.json()
+ assert payload["code"] == 404, payload
+ assert payload["error"] == "Not Found", payload
+ assert payload["message"] == f"Not Found: /api/{VERSION}/__missing_route__", payload
diff --git a/test/testcases/restful_api/test_searches.py b/test/testcases/restful_api/test_searches.py
new file mode 100644
index 0000000000..1a6923fb50
--- /dev/null
+++ b/test/testcases/restful_api/test_searches.py
@@ -0,0 +1,155 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import uuid
+
+import pytest
+
+
+@pytest.fixture
+def search_resource(rest_client):
+ name = f"restful_search_{uuid.uuid4().hex[:8]}"
+ create_res = rest_client.post("/searches", json={"name": name, "description": "restful search"})
+ assert create_res.status_code == 200
+ create_payload = create_res.json()
+ assert create_payload["code"] == 0, create_payload
+ search_id = create_payload["data"]["search_id"]
+
+ try:
+ yield search_id
+ finally:
+ delete_res = rest_client.delete(f"/searches/{search_id}")
+ assert delete_res.status_code == 200, delete_res.text
+ delete_payload = delete_res.json()
+ assert delete_payload["code"] in (0, 109), delete_payload
+
+
+def _sse_events(response_text: str) -> list[str]:
+ return [line[5:] for line in response_text.splitlines() if line.startswith("data:")]
+
+
+@pytest.mark.p2
+def test_search_routes_require_auth(rest_client_noauth):
+ create_res = rest_client_noauth.post("/searches", json={"name": "search_noauth"})
+ assert create_res.status_code == 401
+ create_payload = create_res.json()
+ assert create_payload["code"] == 401, create_payload
+
+ list_res = rest_client_noauth.get("/searches")
+ assert list_res.status_code == 401
+ list_payload = list_res.json()
+ assert list_payload["code"] == 401, list_payload
+
+
+@pytest.mark.p2
+def test_search_crud_contract(rest_client, search_resource):
+ search_id = search_resource
+
+ list_res = rest_client.get("/searches")
+ assert list_res.status_code == 200
+ list_payload = list_res.json()
+ assert list_payload["code"] == 0, list_payload
+ assert any(item.get("id") == search_id for item in list_payload["data"]["search_apps"]), list_payload
+
+ detail_res = rest_client.get(f"/searches/{search_id}")
+ assert detail_res.status_code == 200
+ detail_payload = detail_res.json()
+ assert detail_payload["code"] == 0, detail_payload
+ assert detail_payload["data"]["id"] == search_id, detail_payload
+
+ new_name = f"search_updated_{uuid.uuid4().hex[:6]}"
+ update_res = rest_client.put(
+ f"/searches/{search_id}",
+ json={"name": new_name, "search_config": {"top_k": 3}},
+ )
+ assert update_res.status_code == 200
+ update_payload = update_res.json()
+ assert update_payload["code"] == 0, update_payload
+ assert update_payload["data"]["name"] == new_name, update_payload
+
+
+@pytest.mark.p2
+def test_search_create_invalid_name(rest_client):
+ res = rest_client.post("/searches", json={"name": ""})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 102, payload
+ assert "empty" in payload["message"], payload
+
+
+@pytest.mark.p2
+def test_search_update_invalid_search_id(rest_client):
+ res = rest_client.put(
+ "/searches/invalid_search_id",
+ json={"name": "invalid", "search_config": {}},
+ )
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 109, payload
+ assert "No authorization" in payload["message"], payload
+
+
+@pytest.mark.p2
+def test_search_completion_requires_question(rest_client, search_resource):
+ search_id = search_resource
+
+ completion_res = rest_client.post(f"/searches/{search_id}/completion", json={})
+ assert completion_res.status_code == 200
+ completion_payload = completion_res.json()
+ assert completion_payload["code"] == 101, completion_payload
+ assert "required argument are missing: question" in completion_payload["message"], completion_payload
+
+ completions_res = rest_client.post(f"/searches/{search_id}/completions", json={})
+ assert completions_res.status_code == 200
+ completions_payload = completions_res.json()
+ assert completions_payload["code"] == 101, completions_payload
+ assert "required argument are missing: question" in completions_payload["message"], completions_payload
+
+
+@pytest.mark.p2
+def test_search_completion_requires_kb_ids(rest_client, search_resource):
+ search_id = search_resource
+ for path in ("completion", "completions"):
+ res = rest_client.post(
+ f"/searches/{search_id}/{path}",
+ json={"question": "what is coriander?"},
+ )
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 102, payload
+ assert "`kb_ids` is required" in payload["message"], payload
+
+
+@pytest.mark.p2
+def test_search_completion_sse_shape_when_kb_ids_provided(rest_client, search_resource):
+ search_id = search_resource
+ # Even with kb_ids provided, runtime may return an error event in-stream, but
+ # contract remains SSE with JSON data lines and terminal boolean event.
+ res = rest_client.post(
+ f"/searches/{search_id}/completion",
+ json={"question": "what is coriander?", "kb_ids": ["nonexistent_dataset"]},
+ timeout=60,
+ )
+ assert res.status_code == 200
+ content_type = res.headers.get("Content-Type", "")
+ assert "text/event-stream" in content_type, content_type
+
+ events = _sse_events(res.text)
+ assert events, res.text
+ parsed = [json.loads(evt) for evt in events]
+ assert isinstance(parsed[0], dict), parsed
+ assert parsed[-1].get("data") is True, parsed[-1]
diff --git a/test/testcases/restful_api/test_sessions.py b/test/testcases/restful_api/test_sessions.py
new file mode 100644
index 0000000000..6a7a9de82c
--- /dev/null
+++ b/test/testcases/restful_api/test_sessions.py
@@ -0,0 +1,219 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+
+import pytest
+
+
+def _sse_events(response_text: str) -> list[str]:
+ return [line[5:] for line in response_text.splitlines() if line.startswith("data:")]
+
+
+@pytest.mark.p1
+def test_session_crud_cycle(rest_client, create_chat):
+ chat_id = create_chat("restful_session_crud_chat")
+
+ create_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_a"})
+ assert create_res.status_code == 200
+ create_payload = create_res.json()
+ assert create_payload["code"] == 0, create_payload
+ session_id = create_payload["data"]["id"]
+ assert create_payload["data"]["chat_id"] == chat_id, create_payload
+
+ list_res = rest_client.get(f"/chats/{chat_id}/sessions")
+ assert list_res.status_code == 200
+ list_payload = list_res.json()
+ assert list_payload["code"] == 0, list_payload
+ assert any(item["id"] == session_id for item in list_payload["data"]), list_payload
+
+ get_res = rest_client.get(f"/chats/{chat_id}/sessions/{session_id}")
+ assert get_res.status_code == 200
+ get_payload = get_res.json()
+ assert get_payload["code"] == 0, get_payload
+ assert get_payload["data"]["id"] == session_id, get_payload
+
+ patch_res = rest_client.patch(
+ f"/chats/{chat_id}/sessions/{session_id}",
+ json={"name": "session_a_updated"},
+ )
+ assert patch_res.status_code == 200
+ patch_payload = patch_res.json()
+ assert patch_payload["code"] == 0, patch_payload
+ assert patch_payload["data"]["name"] == "session_a_updated", patch_payload
+
+ delete_res = rest_client.delete(f"/chats/{chat_id}/sessions", json={"ids": [session_id]})
+ assert delete_res.status_code == 200
+ delete_payload = delete_res.json()
+ assert delete_payload["code"] == 0, delete_payload
+
+ list_after_delete = rest_client.get(f"/chats/{chat_id}/sessions")
+ assert list_after_delete.status_code == 200
+ list_after_delete_payload = list_after_delete.json()
+ assert list_after_delete_payload["code"] == 0, list_after_delete_payload
+ assert all(item["id"] != session_id for item in list_after_delete_payload["data"]), list_after_delete_payload
+
+
+@pytest.mark.p2
+def test_session_create_name_validation(rest_client, create_chat):
+ chat_id = create_chat("restful_session_name_validation_chat")
+
+ res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": " "})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 102, payload
+ assert "`name` can not be empty." in payload["message"], payload
+
+
+@pytest.mark.p2
+def test_session_update_blocks_messages_and_reference(rest_client, create_chat):
+ chat_id = create_chat("restful_session_guard_chat")
+ create_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_guard"})
+ assert create_res.status_code == 200
+ create_payload = create_res.json()
+ assert create_payload["code"] == 0, create_payload
+ session_id = create_payload["data"]["id"]
+
+ msg_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json={"messages": []})
+ assert msg_res.status_code == 200
+ msg_payload = msg_res.json()
+ assert msg_payload["code"] == 102, msg_payload
+ assert "`messages` cannot be changed." in msg_payload["message"], msg_payload
+
+ ref_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json={"reference": []})
+ assert ref_res.status_code == 200
+ ref_payload = ref_res.json()
+ assert ref_payload["code"] == 102, ref_payload
+ assert "`reference` cannot be changed." in ref_payload["message"], ref_payload
+
+
+@pytest.mark.p2
+def test_chat_recommendation_requires_question(rest_client):
+ res = rest_client.post("/chat/recommendation", json={})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 101, payload
+ assert "required argument are missing: question" in payload["message"], payload
+
+
+@pytest.mark.p2
+def test_related_questions_compatibility_requires_auth(rest_client_noauth):
+ # /api/v1/searchbots/related_questions is an SDK compatibility endpoint.
+ res = rest_client_noauth.post(
+ "/searchbots/related_questions",
+ json={"question": "ragflow"},
+ headers={"Authorization": "invalid"},
+ )
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 102, payload
+ assert "Authorization is not valid!" in payload["message"], payload
+
+
+@pytest.mark.p2
+def test_chat_completion_nonstream_with_session(rest_client, create_chat):
+ chat_id = create_chat("restful_completion_nonstream_chat")
+ create_session_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_for_completion"})
+ assert create_session_res.status_code == 200
+ create_session_payload = create_session_res.json()
+ assert create_session_payload["code"] == 0, create_session_payload
+ session_id = create_session_payload["data"]["id"]
+
+ completion_res = rest_client.post(
+ "/chat/completions",
+ json={
+ "chat_id": chat_id,
+ "session_id": session_id,
+ "messages": [{"role": "user", "content": "hello"}],
+ "stream": False,
+ },
+ timeout=60,
+ )
+ assert completion_res.status_code == 200
+ completion_payload = completion_res.json()
+ assert completion_payload["code"] == 0, completion_payload
+ assert isinstance(completion_payload["data"], dict), completion_payload
+ assert completion_payload["data"]["session_id"] == session_id, completion_payload
+ assert "answer" in completion_payload["data"], completion_payload
+ assert "reference" in completion_payload["data"], completion_payload
+
+
+@pytest.mark.p2
+def test_chat_completion_stream_events(rest_client, create_chat):
+ chat_id = create_chat("restful_completion_stream_chat")
+ stream_res = rest_client.post(
+ "/chat/completions",
+ json={
+ "chat_id": chat_id,
+ "messages": [{"role": "user", "content": "hello"}],
+ "stream": True,
+ },
+ timeout=60,
+ )
+ assert stream_res.status_code == 200
+ content_type = stream_res.headers.get("Content-Type", "")
+ assert "text/event-stream" in content_type, content_type
+
+ events = _sse_events(stream_res.text)
+ assert events, stream_res.text
+ parsed_events = []
+ for event in events:
+ parsed = json.loads(event)
+ parsed_events.append(parsed)
+
+ assert any(evt.get("code") == 0 and isinstance(evt.get("data"), dict) for evt in parsed_events), parsed_events
+ assert parsed_events[-1].get("data") is True, parsed_events[-1]
+
+
+@pytest.mark.p2
+def test_chat_completion_validation_errors(rest_client, create_chat):
+ chat_id = create_chat("restful_completion_validation_chat")
+
+ missing_messages = rest_client.post(
+ "/chat/completions",
+ json={"chat_id": chat_id, "stream": False},
+ )
+ assert missing_messages.status_code == 200
+ missing_messages_payload = missing_messages.json()
+ assert missing_messages_payload["code"] == 101, missing_messages_payload
+ assert "required argument are missing: messages" in missing_messages_payload["message"], missing_messages_payload
+
+ missing_chat_for_session = rest_client.post(
+ "/chat/completions",
+ json={
+ "session_id": "some_session",
+ "messages": [{"role": "user", "content": "hello"}],
+ "stream": False,
+ },
+ )
+ assert missing_chat_for_session.status_code == 200
+ missing_chat_for_session_payload = missing_chat_for_session.json()
+ assert missing_chat_for_session_payload["code"] == 102, missing_chat_for_session_payload
+ assert "`chat_id` is required when `session_id` is provided." in missing_chat_for_session_payload["message"], missing_chat_for_session_payload
+
+ invalid_chat = rest_client.post(
+ "/chat/completions",
+ json={
+ "chat_id": "invalid_chat_id",
+ "session_id": "invalid_session",
+ "messages": [{"role": "user", "content": "hello"}],
+ "stream": False,
+ },
+ )
+ assert invalid_chat.status_code == 200
+ invalid_chat_payload = invalid_chat.json()
+ assert invalid_chat_payload["code"] == 109, invalid_chat_payload
+ assert "No authorization." in invalid_chat_payload["message"], invalid_chat_payload
diff --git a/test/testcases/restful_api/test_system.py b/test/testcases/restful_api/test_system.py
new file mode 100644
index 0000000000..e5022f4861
--- /dev/null
+++ b/test/testcases/restful_api/test_system.py
@@ -0,0 +1,159 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+
+
+@pytest.mark.p1
+def test_system_ping(rest_client):
+ res = rest_client.get("/system/ping")
+ assert res.status_code == 200
+ assert res.text == "pong"
+
+
+@pytest.mark.p1
+def test_system_version(rest_client):
+ res = rest_client.get("/system/version")
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ assert payload["data"], payload
+
+
+@pytest.mark.p2
+def test_system_status_requires_auth(rest_client_noauth):
+ res = rest_client_noauth.get("/system/status")
+ assert res.status_code == 401
+ payload = res.json()
+ assert payload["code"] == 401, payload
+ assert "Unauthorized" in payload["message"], payload
+
+
+@pytest.mark.p2
+def test_system_status_contract(rest_client):
+ res = rest_client.get("/system/status")
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ for key in ("doc_engine", "storage", "database", "redis"):
+ assert key in payload["data"], payload
+
+
+@pytest.mark.p2
+def test_system_config_no_auth_required(rest_client_noauth):
+ res = rest_client_noauth.get("/system/config")
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ assert "registerEnabled" in payload["data"], payload
+ assert "disablePasswordLogin" in payload["data"], payload
+
+
+@pytest.mark.p2
+def test_system_healthz_contract(rest_client_noauth):
+ res = rest_client_noauth.get("/system/healthz")
+ assert res.status_code in (200, 500)
+ payload = res.json()
+ assert isinstance(payload, dict), payload
+ assert payload, payload
+
+
+@pytest.mark.p2
+def test_system_tokens_auth_and_crud(rest_client, rest_client_noauth):
+ unauth_list = rest_client_noauth.get("/system/tokens")
+ assert unauth_list.status_code == 401
+ unauth_list_payload = unauth_list.json()
+ assert unauth_list_payload["code"] == 401, unauth_list_payload
+
+ create_res = rest_client.post("/system/tokens")
+ assert create_res.status_code == 200
+ create_payload = create_res.json()
+ assert create_payload["code"] == 0, create_payload
+ token = create_payload["data"]["token"]
+
+ list_res = rest_client.get("/system/tokens")
+ assert list_res.status_code == 200
+ list_payload = list_res.json()
+ assert list_payload["code"] == 0, list_payload
+ assert isinstance(list_payload["data"], list), list_payload
+ assert any(item.get("token") == token for item in list_payload["data"]), list_payload
+
+ delete_res = rest_client.delete(f"/system/tokens/{token}")
+ assert delete_res.status_code == 200
+ delete_payload = delete_res.json()
+ assert delete_payload["code"] == 0, delete_payload
+ assert delete_payload["data"] is True, delete_payload
+
+ delete_missing = rest_client.delete("/system/tokens/missing_token")
+ assert delete_missing.status_code == 200
+ delete_missing_payload = delete_missing.json()
+ assert delete_missing_payload["code"] == 0, delete_missing_payload
+ assert delete_missing_payload["data"] is True, delete_missing_payload
+
+
+@pytest.mark.p2
+def test_system_stats_auth_and_shape(rest_client, rest_client_noauth):
+ unauth_res = rest_client_noauth.get("/system/stats")
+ assert unauth_res.status_code == 401
+ unauth_payload = unauth_res.json()
+ assert unauth_payload["code"] == 401, unauth_payload
+
+ res = rest_client.get("/system/stats")
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ data = payload["data"]
+ for key in ("pv", "uv", "speed", "tokens", "round", "thumb_up"):
+ assert key in data, payload
+ assert isinstance(data[key], list), payload
+
+
+@pytest.mark.p2
+def test_system_oceanbase_status_auth_contract(rest_client, rest_client_noauth):
+ unauth = rest_client_noauth.get("/system/oceanbase/status")
+ assert unauth.status_code == 401
+ assert unauth.json()["code"] == 401
+
+ res = rest_client.get("/system/oceanbase/status")
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] in (0, 500), payload
+ assert "data" in payload, payload
+
+
+@pytest.mark.p2
+def test_system_log_config_routes_auth_and_validation(rest_client, rest_client_noauth):
+ unauth = rest_client_noauth.get("/system/config/log")
+ assert unauth.status_code == 401
+ assert unauth.json()["code"] == 401
+
+ levels = rest_client.get("/system/config/log")
+ assert levels.status_code == 200
+ levels_payload = levels.json()
+ assert levels_payload["code"] == 0, levels_payload
+ assert isinstance(levels_payload["data"], dict), levels_payload
+
+ missing_body = rest_client.put("/system/config/log", json={})
+ assert missing_body.status_code == 200
+ missing_payload = missing_body.json()
+ assert missing_payload["code"] == 102, missing_payload
+ assert "pkg_name and level are required" in missing_payload["message"], missing_payload
+
+ invalid_level = rest_client.put("/system/config/log", json={"pkg_name": "rag", "level": "NOT_A_LEVEL"})
+ assert invalid_level.status_code == 200
+ invalid_payload = invalid_level.json()
+ assert invalid_payload["code"] == 102, invalid_payload
+ assert "Invalid log level" in invalid_payload["message"], invalid_payload
diff --git a/test/testcases/restful_api/test_task_routes.py b/test/testcases/restful_api/test_task_routes.py
new file mode 100644
index 0000000000..13a0fc8a9d
--- /dev/null
+++ b/test/testcases/restful_api/test_task_routes.py
@@ -0,0 +1,48 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+
+
+@pytest.mark.p2
+def test_task_routes_require_auth(rest_client_noauth):
+ cancel_res = rest_client_noauth.post("/tasks/missing_task/cancel")
+ assert cancel_res.status_code == 401
+ cancel_payload = cancel_res.json()
+ assert cancel_payload["code"] == 401, cancel_payload
+
+ patch_res = rest_client_noauth.patch("/tasks/missing_task", json={"action": "stop"})
+ assert patch_res.status_code == 401
+ patch_payload = patch_res.json()
+ assert patch_payload["code"] == 401, patch_payload
+
+
+@pytest.mark.p2
+def test_patch_task_rejects_unsupported_action(rest_client):
+ res = rest_client.patch("/tasks/missing_task", json={"action": "pause"})
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 101, payload
+ assert "Only 'stop' is supported" in payload["message"], payload
+
+
+@pytest.mark.p2
+def test_cancel_missing_task_sets_cancel_contract(rest_client):
+ res = rest_client.post("/tasks/missing_task/cancel")
+ assert res.status_code == 200
+ payload = res.json()
+ assert payload["code"] == 0, payload
+ assert payload["data"] is True, payload
diff --git a/test/testcases/restful_api/test_user_tenant_routes_unit.py b/test/testcases/restful_api/test_user_tenant_routes_unit.py
new file mode 100644
index 0000000000..811a40654c
--- /dev/null
+++ b/test/testcases/restful_api/test_user_tenant_routes_unit.py
@@ -0,0 +1,1382 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import asyncio
+import base64
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType, SimpleNamespace
+
+import pytest
+
+
+class _DummyManager:
+ def route(self, *_args, **_kwargs):
+ def decorator(func):
+ return func
+
+ return decorator
+
+
+class _AwaitableValue:
+ def __init__(self, value):
+ self._value = value
+
+ def __await__(self):
+ async def _co():
+ return self._value
+
+ return _co().__await__()
+
+
+class _Field:
+ def __init__(self, name):
+ self.name = name
+
+ def __eq__(self, other):
+ return (self.name, other)
+
+
+class _Invitee:
+ def __init__(self, user_id="invitee-1", email="invitee@example.com"):
+ self.id = user_id
+ self.email = email
+
+ def to_dict(self):
+ return {
+ "id": self.id,
+ "avatar": "avatar-url",
+ "email": self.email,
+ "nickname": "Invitee",
+ "password": "ignored",
+ }
+
+
+def _run(coro):
+ return asyncio.run(coro)
+
+
+def _set_request_json(monkeypatch, module, payload):
+ async def _request_json():
+ return payload
+
+ monkeypatch.setattr(module, "get_request_json", _request_json)
+
+
+def _load_tenant_module(monkeypatch):
+ repo_root = Path(__file__).resolve().parents[3]
+
+ api_pkg = ModuleType("api")
+ api_pkg.__path__ = [str(repo_root / "api")]
+ monkeypatch.setitem(sys.modules, "api", api_pkg)
+
+ apps_mod = ModuleType("api.apps")
+ apps_mod.__path__ = [str(repo_root / "api" / "apps")]
+ apps_mod.current_user = SimpleNamespace(id="tenant-1", email="owner@example.com")
+ apps_mod.login_required = lambda fn: fn
+ monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
+
+ db_mod = ModuleType("api.db")
+ db_mod.UserTenantRole = SimpleNamespace(NORMAL="normal", OWNER="owner", INVITE="invite")
+ monkeypatch.setitem(sys.modules, "api.db", db_mod)
+
+ db_models_mod = ModuleType("api.db.db_models")
+ db_models_mod.UserTenant = type(
+ "UserTenant",
+ (),
+ {
+ "tenant_id": _Field("tenant_id"),
+ "user_id": _Field("user_id"),
+ },
+ )
+ monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
+
+ services_pkg = ModuleType("api.db.services")
+ services_pkg.__path__ = []
+ monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
+
+ user_service_mod = ModuleType("api.db.services.user_service")
+
+ class _UserTenantService:
+ @staticmethod
+ def get_by_tenant_id(_tenant_id):
+ return []
+
+ @staticmethod
+ def query(**_kwargs):
+ return []
+
+ @staticmethod
+ def save(**_kwargs):
+ return True
+
+ @staticmethod
+ def filter_delete(_conditions):
+ return True
+
+ @staticmethod
+ def get_tenants_by_user_id(_user_id):
+ return []
+
+ @staticmethod
+ def filter_update(_conditions, _payload):
+ return True
+
+ class _UserService:
+ @staticmethod
+ def query(**_kwargs):
+ return []
+
+ @staticmethod
+ def get_by_id(_user_id):
+ return False, None
+
+ user_service_mod.UserTenantService = _UserTenantService
+ user_service_mod.UserService = _UserService
+ monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
+
+ api_utils_mod = ModuleType("api.utils.api_utils")
+ api_utils_mod.get_json_result = lambda data=None, message="", code=0: {"code": code, "message": message, "data": data}
+ api_utils_mod.get_data_error_result = lambda message="": {"code": 102, "message": message, "data": False}
+ api_utils_mod.server_error_response = lambda exc: {"code": 100, "message": repr(exc), "data": False}
+ api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn)
+ api_utils_mod.get_request_json = lambda: _AwaitableValue({})
+ monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
+
+ web_utils_mod = ModuleType("api.utils.web_utils")
+ web_utils_mod.send_invite_email = lambda **_kwargs: {"ok": True}
+ monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod)
+
+ common_pkg = ModuleType("common")
+ common_pkg.__path__ = [str(repo_root / "common")]
+ monkeypatch.setitem(sys.modules, "common", common_pkg)
+
+ constants_mod = ModuleType("common.constants")
+ constants_mod.RetCode = SimpleNamespace(AUTHENTICATION_ERROR=401, SERVER_ERROR=500, DATA_ERROR=102)
+ constants_mod.StatusEnum = SimpleNamespace(VALID=SimpleNamespace(value=1))
+ monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
+
+ misc_utils_mod = ModuleType("common.misc_utils")
+ misc_utils_mod.get_uuid = lambda: "uuid-1"
+ monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
+
+ time_utils_mod = ModuleType("common.time_utils")
+ time_utils_mod.delta_seconds = lambda _value: 0
+ monkeypatch.setitem(sys.modules, "common.time_utils", time_utils_mod)
+
+ settings_mod = ModuleType("common.settings")
+ settings_mod.MAIL_FRONTEND_URL = "https://frontend.example/invite"
+ monkeypatch.setitem(sys.modules, "common.settings", settings_mod)
+ common_pkg.settings = settings_mod
+
+ sys.modules.pop("test_tenant_app_unit_module", None)
+ module_path = repo_root / "api" / "apps" / "restful_apis" / "tenant_api.py"
+ spec = importlib.util.spec_from_file_location("test_tenant_app_unit_module", module_path)
+ module = importlib.util.module_from_spec(spec)
+ module.manager = _DummyManager()
+ monkeypatch.setitem(sys.modules, "test_tenant_app_unit_module", module)
+ spec.loader.exec_module(module)
+ return module
+
+
+@pytest.mark.p2
+def test_user_list_auth_success_exception_matrix_unit(monkeypatch):
+ module = _load_tenant_module(monkeypatch)
+
+ module.current_user.id = "other-user"
+ res = module.user_list("tenant-1")
+ assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
+ assert res["message"] == "No authorization.", res
+
+ module.current_user.id = "tenant-1"
+ monkeypatch.setattr(
+ module.UserTenantService,
+ "get_by_tenant_id",
+ lambda _tenant_id: [{"id": "u1", "update_date": "2024-01-01 00:00:00"}],
+ )
+ monkeypatch.setattr(module, "delta_seconds", lambda _value: 42)
+ res = module.user_list("tenant-1")
+ assert res["code"] == 0, res
+ assert res["data"][0]["delta_seconds"] == 42, res
+
+ monkeypatch.setattr(module.UserTenantService, "get_by_tenant_id", lambda _tenant_id: (_ for _ in ()).throw(RuntimeError("list boom")))
+ res = module.user_list("tenant-1")
+ assert res["code"] == 100, res
+ assert "list boom" in res["message"], res
+
+
+@pytest.mark.p2
+def test_create_invite_role_and_email_failure_matrix_unit(monkeypatch):
+ module = _load_tenant_module(monkeypatch)
+
+ module.current_user.id = "other-user"
+ _set_request_json(monkeypatch, module, {"email": "invitee@example.com"})
+ res = _run(module.create("tenant-1"))
+ assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
+ assert res["message"] == "No authorization.", res
+
+ module.current_user.id = "tenant-1"
+ monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [])
+ res = _run(module.create("tenant-1"))
+ assert res["message"] == "User not found.", res
+
+ invitee = _Invitee()
+ monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [invitee])
+ monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role=module.UserTenantRole.NORMAL)])
+ res = _run(module.create("tenant-1"))
+ assert "already in the team." in res["message"], res
+
+ monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role=module.UserTenantRole.OWNER)])
+ res = _run(module.create("tenant-1"))
+ assert "owner of the team." in res["message"], res
+
+ monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role="strange-role")])
+ res = _run(module.create("tenant-1"))
+ assert "role: strange-role is invalid." in res["message"], res
+
+ saved = []
+ scheduled = []
+ monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [])
+ monkeypatch.setattr(module.UserTenantService, "save", lambda **kwargs: saved.append(kwargs) or True)
+ monkeypatch.setattr(module.UserService, "get_by_id", lambda _user_id: (True, SimpleNamespace(nickname="Inviter Nick")))
+ monkeypatch.setattr(module, "send_invite_email", lambda **kwargs: kwargs)
+ monkeypatch.setattr(module.asyncio, "create_task", lambda payload: scheduled.append(payload) or SimpleNamespace())
+ res = _run(module.create("tenant-1"))
+ assert res["code"] == 0, res
+ assert saved and saved[-1]["role"] == module.UserTenantRole.INVITE, saved
+ assert scheduled and scheduled[-1]["inviter"] == "Inviter Nick", scheduled
+ assert sorted(res["data"].keys()) == ["avatar", "email", "id", "nickname"], res
+
+ monkeypatch.setattr(module.asyncio, "create_task", lambda _payload: (_ for _ in ()).throw(RuntimeError("send boom")))
+ res = _run(module.create("tenant-1"))
+ assert res["code"] == module.RetCode.SERVER_ERROR, res
+ assert "Failed to send invite email." in res["message"], res
+
+
+@pytest.mark.p2
+def test_rm_and_tenant_list_matrix_unit(monkeypatch):
+ module = _load_tenant_module(monkeypatch)
+
+ module.current_user.id = "outsider"
+ _set_request_json(monkeypatch, module, {"user_id": "user-2"})
+ res = _run(module.rm("tenant-1"))
+ assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
+ assert res["message"] == "No authorization.", res
+
+ module.current_user.id = "tenant-1"
+ deleted = []
+ monkeypatch.setattr(module.UserTenantService, "filter_delete", lambda conditions: deleted.append(conditions) or True)
+ res = _run(module.rm("tenant-1"))
+ assert res["code"] == 0, res
+ assert res["data"] is True, res
+ assert deleted, "filter_delete should be called"
+
+ monkeypatch.setattr(module.UserTenantService, "filter_delete", lambda _conditions: (_ for _ in ()).throw(RuntimeError("rm boom")))
+ res = _run(module.rm("tenant-1"))
+ assert res["code"] == 100, res
+ assert "rm boom" in res["message"], res
+
+ monkeypatch.setattr(
+ module.UserTenantService,
+ "get_tenants_by_user_id",
+ lambda _user_id: [{"id": "tenant-1", "update_date": "2024-01-01 00:00:00"}],
+ )
+ monkeypatch.setattr(module, "delta_seconds", lambda _value: 9)
+ res = module.tenant_list()
+ assert res["code"] == 0, res
+ assert res["data"][0]["delta_seconds"] == 9, res
+
+ monkeypatch.setattr(module.UserTenantService, "get_tenants_by_user_id", lambda _user_id: (_ for _ in ()).throw(RuntimeError("tenant boom")))
+ res = module.tenant_list()
+ assert res["code"] == 100, res
+ assert "tenant boom" in res["message"], res
+
+
+@pytest.mark.p2
+def test_agree_success_and_exception_unit(monkeypatch):
+ module = _load_tenant_module(monkeypatch)
+
+ calls = []
+ monkeypatch.setattr(module.UserTenantService, "filter_update", lambda conditions, payload: calls.append((conditions, payload)) or True)
+ res = module.agree("tenant-1")
+ assert res["code"] == 0, res
+ assert res["data"] is True, res
+ assert calls and calls[-1][1]["role"] == module.UserTenantRole.NORMAL
+
+ monkeypatch.setattr(module.UserTenantService, "filter_update", lambda _conditions, _payload: (_ for _ in ()).throw(RuntimeError("agree boom")))
+ res = module.agree("tenant-1")
+ assert res["code"] == 100, res
+ assert "agree boom" in res["message"], res
+
+
+class _Args(dict):
+ def get(self, key, default=None, type=None):
+ value = super().get(key, default)
+ if type is None:
+ return value
+ try:
+ return type(value)
+ except (TypeError, ValueError):
+ return default
+
+
+class _DummyResponse:
+ def __init__(self, data):
+ self.data = data
+ self.headers = {}
+
+
+class _DummyHTTPResponse:
+ def __init__(self, payload):
+ self._payload = payload
+
+ def json(self):
+ return self._payload
+
+
+class _DummyRedis:
+ def __init__(self):
+ self.store = {}
+
+ def get(self, key):
+ return self.store.get(key)
+
+ def set(self, key, value, _ttl=None):
+ self.store[key] = value
+
+ def delete(self, key):
+ self.store.pop(key, None)
+
+
+class _DummyUser:
+ def __init__(self, user_id, email, *, password="stored-password", is_active="1", nickname="nick"):
+ self.id = user_id
+ self.email = email
+ self.password = password
+ self.is_active = is_active
+ self.nickname = nickname
+ self.access_token = ""
+ self.save_calls = 0
+
+ def save(self):
+ self.save_calls += 1
+
+ def get_id(self):
+ return self.id
+
+ def to_json(self):
+ return {"id": self.id, "email": self.email, "nickname": self.nickname}
+
+ def to_dict(self):
+ return {"id": self.id, "email": self.email}
+
+
+def _set_request_args(monkeypatch, module, args=None):
+ monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args(args or {})))
+
+
+@pytest.fixture(scope="session")
+def auth():
+ return "unit-auth"
+
+
+@pytest.fixture(scope="session", autouse=True)
+def set_tenant_info():
+ return None
+
+
+def _load_user_app(monkeypatch):
+ repo_root = Path(__file__).resolve().parents[3]
+
+ quart_mod = ModuleType("quart")
+ quart_mod.session = {}
+ quart_mod.request = SimpleNamespace(args=_Args({}))
+
+ async def _make_response(data):
+ return _DummyResponse(data)
+
+ quart_mod.make_response = _make_response
+ quart_mod.redirect = lambda url: {"redirect": url}
+ monkeypatch.setitem(sys.modules, "quart", quart_mod)
+
+ api_pkg = ModuleType("api")
+ api_pkg.__path__ = [str(repo_root / "api")]
+ monkeypatch.setitem(sys.modules, "api", api_pkg)
+
+ apps_mod = ModuleType("api.apps")
+ apps_mod.__path__ = [str(repo_root / "api" / "apps")]
+ apps_mod.current_user = _DummyUser("current-user", "current@example.com")
+ apps_mod.login_required = lambda fn: fn
+ apps_mod.login_user = lambda _user: True
+ apps_mod.logout_user = lambda: True
+ monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
+ api_pkg.apps = apps_mod
+
+ apps_auth_mod = ModuleType("api.apps.auth")
+ apps_auth_mod.get_auth_client = lambda _config: SimpleNamespace(
+ get_authorization_url=lambda state: f"https://oauth.example/{state}"
+ )
+ monkeypatch.setitem(sys.modules, "api.apps.auth", apps_auth_mod)
+
+ db_mod = ModuleType("api.db")
+ db_mod.FileType = SimpleNamespace(FOLDER=SimpleNamespace(value="folder"))
+ db_mod.UserTenantRole = SimpleNamespace(OWNER="owner")
+ monkeypatch.setitem(sys.modules, "api.db", db_mod)
+ api_pkg.db = db_mod
+
+ db_models_mod = ModuleType("api.db.db_models")
+
+ class _DummyTenantLLMModel:
+ tenant_id = _Field("tenant_id")
+
+ @staticmethod
+ def delete():
+ class _DeleteQuery:
+ def where(self, *_args, **_kwargs):
+ return self
+
+ def execute(self):
+ return 1
+
+ return _DeleteQuery()
+
+ db_models_mod.TenantLLM = _DummyTenantLLMModel
+ monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
+
+ services_pkg = ModuleType("api.db.services")
+ services_pkg.__path__ = []
+ monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
+
+ file_service_mod = ModuleType("api.db.services.file_service")
+
+ class _StubFileService:
+ @staticmethod
+ def insert(_data):
+ return True
+
+ file_service_mod.FileService = _StubFileService
+ monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service_mod)
+
+ llm_service_mod = ModuleType("api.db.services.llm_service")
+ llm_service_mod.get_init_tenant_llm = lambda _user_id: []
+ monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
+
+ tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")
+
+ class _MockTableObject:
+ def __init__(self, **kwargs):
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+
+ def to_dict(self):
+ return {k: v for k, v in self.__dict__.items()}
+
+ class _StubTenantLLMService:
+ @staticmethod
+ def insert_many(_payload):
+ return True
+
+ @staticmethod
+ def get_api_key(tenant_id, model_name, model_type=None):
+ return _MockTableObject(
+ id=1,
+ tenant_id=tenant_id,
+ llm_factory="",
+ model_type="chat",
+ llm_name=model_name,
+ api_key="fake-api-key",
+ api_base="https://api.example.com",
+ max_tokens=8192,
+ used_tokens=0,
+ status=1
+ )
+
+ tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService
+ monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
+
+ user_service_mod = ModuleType("api.db.services.user_service")
+
+ class _StubTenantService:
+ @staticmethod
+ def insert(**_kwargs):
+ return True
+
+ @staticmethod
+ def delete_by_id(_tenant_id):
+ return True
+
+ @staticmethod
+ def get_by_id(_tenant_id):
+ return True, SimpleNamespace(id=_tenant_id)
+
+ @staticmethod
+ def get_info_by(_user_id):
+ return []
+
+ @staticmethod
+ def update_by_id(_tenant_id, _payload):
+ return True
+
+ class _StubUserService:
+ @staticmethod
+ def query(**_kwargs):
+ return []
+
+ @staticmethod
+ def query_user(_email, _password):
+ return None
+
+ @staticmethod
+ def query_user_by_email(**_kwargs):
+ return []
+
+ @staticmethod
+ def save(**_kwargs):
+ return True
+
+ @staticmethod
+ def delete_by_id(_user_id):
+ return True
+
+ @staticmethod
+ def update_by_id(_user_id, _payload):
+ return True
+
+ @staticmethod
+ def update_user_password(_user_id, _new_password):
+ return True
+
+ class _StubUserTenantService:
+ @staticmethod
+ def insert(**_kwargs):
+ return True
+
+ @staticmethod
+ def query(**_kwargs):
+ return []
+
+ @staticmethod
+ def delete_by_id(_user_tenant_id):
+ return True
+
+ user_service_mod.TenantService = _StubTenantService
+ user_service_mod.UserService = _StubUserService
+ user_service_mod.UserTenantService = _StubUserTenantService
+ monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
+
+ api_utils_mod = ModuleType("api.utils.api_utils")
+
+ async def _default_request_json():
+ return {}
+
+ def _get_json_result(code=0, message="success", data=None):
+ return {"code": code, "message": message, "data": data}
+
+ def _get_data_error_result(code=102, message="Sorry! Data missing!", data=None):
+ return {"code": code, "message": message, "data": data}
+
+ def _server_error_response(error):
+ return {"code": 100, "message": repr(error)}
+
+ def _validate_request(*_args, **_kwargs):
+ def _decorator(func):
+ return func
+
+ return _decorator
+
+ api_utils_mod.get_request_json = _default_request_json
+ api_utils_mod.get_json_result = _get_json_result
+ api_utils_mod.get_data_error_result = _get_data_error_result
+ api_utils_mod.server_error_response = _server_error_response
+ api_utils_mod.validate_request = _validate_request
+ monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
+
+ tenant_utils_mod = ModuleType("api.utils.tenant_utils")
+ tenant_utils_mod.ensure_tenant_model_id_for_params = lambda _tenant_id, params: params
+ monkeypatch.setitem(sys.modules, "api.utils.tenant_utils", tenant_utils_mod)
+
+ crypt_mod = ModuleType("api.utils.crypt")
+ crypt_mod.decrypt = lambda value: value
+ monkeypatch.setitem(sys.modules, "api.utils.crypt", crypt_mod)
+
+ web_utils_mod = ModuleType("api.utils.web_utils")
+ web_utils_mod.send_email_html = lambda *_args, **_kwargs: _AwaitableValue(True)
+ web_utils_mod.OTP_LENGTH = 6
+ web_utils_mod.OTP_TTL_SECONDS = 600
+ web_utils_mod.ATTEMPT_LIMIT = 5
+ web_utils_mod.ATTEMPT_LOCK_SECONDS = 600
+ web_utils_mod.RESEND_COOLDOWN_SECONDS = 60
+ web_utils_mod.otp_keys = lambda email: (
+ f"otp:{email}:code",
+ f"otp:{email}:attempts",
+ f"otp:{email}:last",
+ f"otp:{email}:lock",
+ )
+ web_utils_mod.hash_code = lambda code, _salt: f"hash:{code}"
+ web_utils_mod.captcha_key = lambda email: f"captcha:{email}"
+ monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod)
+
+ common_pkg = ModuleType("common")
+ common_pkg.__path__ = [str(repo_root / "common")]
+ monkeypatch.setitem(sys.modules, "common", common_pkg)
+
+ settings_mod = ModuleType("common.settings")
+ settings_mod.OAUTH_CONFIG = {
+ "github": {"display_name": "GitHub", "icon": "gh"},
+ "feishu": {"display_name": "Feishu", "icon": "fs"},
+ }
+ settings_mod.GITHUB_OAUTH = {"url": "https://github.example/oauth", "client_id": "cid", "secret_key": "sk"}
+ settings_mod.FEISHU_OAUTH = {
+ "app_access_token_url": "https://feishu.example/app_token",
+ "user_access_token_url": "https://feishu.example/user_token",
+ "app_id": "app-id",
+ "app_secret": "app-secret",
+ "grant_type": "authorization_code",
+ }
+ settings_mod.CHAT_MDL = "chat-mdl"
+ settings_mod.EMBEDDING_MDL = "embd-mdl"
+ settings_mod.ASR_MDL = "asr-mdl"
+ settings_mod.PARSERS = []
+ settings_mod.IMAGE2TEXT_MDL = "img-mdl"
+ settings_mod.RERANK_MDL = "rerank-mdl"
+ settings_mod.REGISTER_ENABLED = True
+ monkeypatch.setitem(sys.modules, "common.settings", settings_mod)
+ common_pkg.settings = settings_mod
+
+ constants_mod = ModuleType("common.constants")
+ constants_mod.RetCode = SimpleNamespace(
+ AUTHENTICATION_ERROR=401,
+ SERVER_ERROR=500,
+ FORBIDDEN=403,
+ EXCEPTION_ERROR=100,
+ OPERATING_ERROR=300,
+ ARGUMENT_ERROR=101,
+ DATA_ERROR=102,
+ NOT_EFFECTIVE=103,
+ SUCCESS=0,
+ )
+ monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
+
+ connection_utils_mod = ModuleType("common.connection_utils")
+
+ async def _construct_response(data=None, auth=None, message=""):
+ return {"code": 0, "message": message, "data": data, "auth": auth}
+
+ connection_utils_mod.construct_response = _construct_response
+ monkeypatch.setitem(sys.modules, "common.connection_utils", connection_utils_mod)
+
+ time_utils_mod = ModuleType("common.time_utils")
+ time_utils_mod.current_timestamp = lambda: 111
+ time_utils_mod.datetime_format = lambda _dt: "2024-01-01 00:00:00"
+ time_utils_mod.get_format_time = lambda: "2024-01-01 00:00:00"
+ monkeypatch.setitem(sys.modules, "common.time_utils", time_utils_mod)
+
+ misc_utils_mod = ModuleType("common.misc_utils")
+ misc_utils_mod.download_img = lambda _url: "avatar"
+ misc_utils_mod.get_uuid = lambda: "uuid-default"
+ monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
+
+ http_client_mod = ModuleType("common.http_client")
+
+ async def _async_request(_method, _url, **_kwargs):
+ return _DummyHTTPResponse({})
+
+ http_client_mod.async_request = _async_request
+ monkeypatch.setitem(sys.modules, "common.http_client", http_client_mod)
+
+ rag_pkg = ModuleType("rag")
+ rag_pkg.__path__ = [str(repo_root / "rag")]
+ monkeypatch.setitem(sys.modules, "rag", rag_pkg)
+
+ rag_utils_pkg = ModuleType("rag.utils")
+ rag_utils_pkg.__path__ = [str(repo_root / "rag" / "utils")]
+ monkeypatch.setitem(sys.modules, "rag.utils", rag_utils_pkg)
+
+ redis_mod = ModuleType("rag.utils.redis_conn")
+ redis_mod.REDIS_CONN = _DummyRedis()
+ monkeypatch.setitem(sys.modules, "rag.utils.redis_conn", redis_mod)
+
+ module_name = "test_user_app_unit_module"
+ module_path = repo_root / "api" / "apps" / "restful_apis" / "user_api.py"
+ spec = importlib.util.spec_from_file_location(module_name, module_path)
+ module = importlib.util.module_from_spec(spec)
+ module.manager = _DummyManager()
+ monkeypatch.setitem(sys.modules, module_name, module)
+ spec.loader.exec_module(module)
+ return module
+
+
+@pytest.mark.p2
+def test_login_route_branch_matrix_unit(monkeypatch):
+ module = _load_user_app(monkeypatch)
+
+ _set_request_json(monkeypatch, module, {})
+ res = _run(module.login())
+ assert res["code"] == module.RetCode.AUTHENTICATION_ERROR
+ assert "Unauthorized" in res["message"]
+
+ _set_request_json(monkeypatch, module, {"email": "unknown@example.com", "password": "enc"})
+ monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [])
+ res = _run(module.login())
+ assert res["code"] == module.RetCode.AUTHENTICATION_ERROR
+ assert "not registered" in res["message"]
+
+ _set_request_json(monkeypatch, module, {"email": "known@example.com", "password": "enc"})
+ monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [SimpleNamespace(email="known@example.com")])
+
+ def _raise_decrypt(_value):
+ raise RuntimeError("decrypt explode")
+
+ monkeypatch.setattr(module, "decrypt", _raise_decrypt)
+ res = _run(module.login())
+ assert res["code"] == module.RetCode.SERVER_ERROR
+ assert "Fail to crypt password" in res["message"]
+
+ user_inactive = _DummyUser("u-inactive", "known@example.com", is_active="0")
+ monkeypatch.setattr(module, "decrypt", lambda value: value)
+ monkeypatch.setattr(module.UserService, "query_user", lambda _email, _password: user_inactive)
+ res = _run(module.login())
+ assert res["code"] == module.RetCode.FORBIDDEN
+ assert "disabled" in res["message"]
+
+ monkeypatch.setattr(module.UserService, "query_user", lambda _email, _password: None)
+ res = _run(module.login())
+ assert res["code"] == module.RetCode.AUTHENTICATION_ERROR
+ assert "do not match" in res["message"]
+
+
+@pytest.mark.p2
+def test_login_channels_and_oauth_login_matrix_unit(monkeypatch):
+ module = _load_user_app(monkeypatch)
+
+ module.settings.OAUTH_CONFIG = {"github": {"display_name": "GitHub", "icon": "gh"}}
+ res = _run(module.get_login_channels())
+ assert res["code"] == 0
+ assert res["data"][0]["channel"] == "github"
+
+ class _BrokenOAuthConfig:
+ @staticmethod
+ def items():
+ raise RuntimeError("broken oauth config")
+
+ module.settings.OAUTH_CONFIG = _BrokenOAuthConfig()
+ res = _run(module.get_login_channels())
+ assert res["code"] == module.RetCode.EXCEPTION_ERROR
+ assert "Load channels failure" in res["message"]
+
+ module.settings.OAUTH_CONFIG = {"github": {"display_name": "GitHub", "icon": "gh"}}
+ with pytest.raises(ValueError, match="Invalid channel name: missing"):
+ _run(module.oauth_login("missing"))
+
+ module.session.clear()
+ monkeypatch.setattr(module, "get_uuid", lambda: "state-123")
+
+ class _AuthClient:
+ @staticmethod
+ def get_authorization_url(state):
+ return f"https://oauth.example/{state}"
+
+ monkeypatch.setattr(module, "get_auth_client", lambda _config: _AuthClient())
+ res = _run(module.oauth_login("github"))
+ assert res["redirect"] == "https://oauth.example/state-123"
+ assert module.session["oauth_state"] == "state-123"
+
+
+@pytest.mark.p2
+def test_oauth_callback_matrix_unit(monkeypatch):
+ module = _load_user_app(monkeypatch)
+ module.settings.OAUTH_CONFIG = {"github": {"display_name": "GitHub", "icon": "gh"}}
+
+ class _SyncAuthClient:
+ def __init__(self, token_info, user_info):
+ self._token_info = token_info
+ self._user_info = user_info
+
+ def exchange_code_for_token(self, _code):
+ return self._token_info
+
+ def fetch_user_info(self, _token, id_token=None):
+ _ = id_token
+ return self._user_info
+
+ class _AsyncAuthClient:
+ def __init__(self, token_info, user_info):
+ self._token_info = token_info
+ self._user_info = user_info
+
+ async def async_exchange_code_for_token(self, _code):
+ return self._token_info
+
+ async def async_fetch_user_info(self, _token, id_token=None):
+ _ = id_token
+ return self._user_info
+
+ _set_request_args(monkeypatch, module, {"state": "x", "code": "c"})
+ module.session.clear()
+ res = _run(module.oauth_callback("missing"))
+ assert "Invalid channel name: missing" in res["redirect"]
+
+ sync_ok = _SyncAuthClient(
+ token_info={"access_token": "token-sync", "id_token": "id-sync"},
+ user_info=SimpleNamespace(email="sync@example.com", avatar_url="http://img", nickname="sync"),
+ )
+ monkeypatch.setattr(module, "get_auth_client", lambda _config: sync_ok)
+
+ module.session.clear()
+ module.session["oauth_state"] = "expected"
+ _set_request_args(monkeypatch, module, {"state": "wrong", "code": "code"})
+ res = _run(module.oauth_callback("github"))
+ assert res["redirect"] == "/?error=invalid_state"
+
+ module.session.clear()
+ module.session["oauth_state"] = "ok-state"
+ _set_request_args(monkeypatch, module, {"state": "ok-state"})
+ res = _run(module.oauth_callback("github"))
+ assert res["redirect"] == "/?error=missing_code"
+
+ sync_missing_token = _SyncAuthClient(
+ token_info={"id_token": "id-only"},
+ user_info=SimpleNamespace(email="sync@example.com", avatar_url="http://img", nickname="sync"),
+ )
+ monkeypatch.setattr(module, "get_auth_client", lambda _config: sync_missing_token)
+ module.session.clear()
+ module.session["oauth_state"] = "token-state"
+ _set_request_args(monkeypatch, module, {"state": "token-state", "code": "code"})
+ res = _run(module.oauth_callback("github"))
+ assert res["redirect"] == "/?error=token_failed"
+
+ sync_missing_email = _SyncAuthClient(
+ token_info={"access_token": "token-sync", "id_token": "id-sync"},
+ user_info=SimpleNamespace(email=None, avatar_url="http://img", nickname="sync"),
+ )
+ monkeypatch.setattr(module, "get_auth_client", lambda _config: sync_missing_email)
+ module.session.clear()
+ module.session["oauth_state"] = "email-state"
+ _set_request_args(monkeypatch, module, {"state": "email-state", "code": "code"})
+ res = _run(module.oauth_callback("github"))
+ assert res["redirect"] == "/?error=email_missing"
+
+ async_new_user = _AsyncAuthClient(
+ token_info={"access_token": "token-async", "id_token": "id-async"},
+ user_info=SimpleNamespace(email="new@example.com", avatar_url="http://img", nickname="new-user"),
+ )
+ monkeypatch.setattr(module, "get_auth_client", lambda _config: async_new_user)
+ monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [])
+
+ def _raise_download(_url):
+ raise RuntimeError("download explode")
+
+ monkeypatch.setattr(module, "download_img", _raise_download)
+ monkeypatch.setattr(module, "user_register", lambda _user_id, _user: None)
+ rollback_calls = []
+ monkeypatch.setattr(module, "rollback_user_registration", lambda user_id: rollback_calls.append(user_id))
+ monkeypatch.setattr(module, "get_uuid", lambda: "new-user-id")
+ module.session.clear()
+ module.session["oauth_state"] = "new-user-state"
+ _set_request_args(monkeypatch, module, {"state": "new-user-state", "code": "code"})
+ res = _run(module.oauth_callback("github"))
+ assert "Failed to register new@example.com" in res["redirect"]
+ assert rollback_calls == ["new-user-id"]
+
+ monkeypatch.setattr(module, "download_img", lambda _url: "avatar")
+ monkeypatch.setattr(
+ module,
+ "user_register",
+ lambda _user_id, _user: [_DummyUser("dup-1", "new@example.com"), _DummyUser("dup-2", "new@example.com")],
+ )
+ rollback_calls.clear()
+ module.session.clear()
+ module.session["oauth_state"] = "dup-user-state"
+ _set_request_args(monkeypatch, module, {"state": "dup-user-state", "code": "code"})
+ res = _run(module.oauth_callback("github"))
+ assert "Same email: new@example.com exists!" in res["redirect"]
+ assert rollback_calls == ["new-user-id"]
+
+ new_user = _DummyUser("new-user", "new@example.com")
+ login_calls = []
+ monkeypatch.setattr(module, "login_user", lambda user: login_calls.append(user))
+ monkeypatch.setattr(module, "user_register", lambda _user_id, _user: [new_user])
+ module.session.clear()
+ module.session["oauth_state"] = "create-user-state"
+ _set_request_args(monkeypatch, module, {"state": "create-user-state", "code": "code"})
+ res = _run(module.oauth_callback("github"))
+ assert res["redirect"] == "/?auth=new-user"
+ assert login_calls and login_calls[-1] is new_user
+
+ async_existing_inactive = _AsyncAuthClient(
+ token_info={"access_token": "token-existing", "id_token": "id-existing"},
+ user_info=SimpleNamespace(email="existing@example.com", avatar_url="http://img", nickname="existing"),
+ )
+ monkeypatch.setattr(module, "get_auth_client", lambda _config: async_existing_inactive)
+ inactive_user = _DummyUser("existing-user", "existing@example.com", is_active="0")
+ monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [inactive_user])
+ module.session.clear()
+ module.session["oauth_state"] = "inactive-state"
+ _set_request_args(monkeypatch, module, {"state": "inactive-state", "code": "code"})
+ res = _run(module.oauth_callback("github"))
+ assert res["redirect"] == "/?error=user_inactive"
+
+ async_existing_ok = _AsyncAuthClient(
+ token_info={"access_token": "token-existing", "id_token": "id-existing"},
+ user_info=SimpleNamespace(email="existing@example.com", avatar_url="http://img", nickname="existing"),
+ )
+ monkeypatch.setattr(module, "get_auth_client", lambda _config: async_existing_ok)
+ existing_user = _DummyUser("existing-user", "existing@example.com")
+ monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [existing_user])
+ login_calls.clear()
+ monkeypatch.setattr(module, "login_user", lambda user: login_calls.append(user))
+ monkeypatch.setattr(module, "get_uuid", lambda: "existing-token")
+ module.session.clear()
+ module.session["oauth_state"] = "existing-state"
+ _set_request_args(monkeypatch, module, {"state": "existing-state", "code": "code"})
+ res = _run(module.oauth_callback("github"))
+ assert res["redirect"] == "/?auth=existing-user"
+ assert existing_user.access_token == "existing-token"
+ assert existing_user.save_calls == 1
+ assert login_calls and login_calls[-1] is existing_user
+
+
+@pytest.mark.p2
+def test_logout_setting_profile_matrix_unit(monkeypatch):
+ module = _load_user_app(monkeypatch)
+
+ current_user = _DummyUser("current-user", "current@example.com", password="stored-password")
+ monkeypatch.setattr(module, "current_user", current_user)
+ monkeypatch.setattr(module.secrets, "token_hex", lambda _n: "abcdef")
+ logout_calls = []
+ monkeypatch.setattr(module, "logout_user", lambda: logout_calls.append(True))
+
+ res = _run(module.log_out())
+ assert res["code"] == 0
+ assert current_user.access_token == "INVALID_abcdef"
+ assert current_user.save_calls == 1
+ assert logout_calls == [True]
+
+ _set_request_json(monkeypatch, module, {"password": "old-password", "new_password": "new-password"})
+ monkeypatch.setattr(module, "decrypt", lambda value: value)
+ monkeypatch.setattr(module, "check_password_hash", lambda _hashed, _plain: False)
+ res = _run(module.setting_user())
+ assert res["code"] == module.RetCode.AUTHENTICATION_ERROR
+ assert "Password error" in res["message"]
+
+ _set_request_json(
+ monkeypatch,
+ module,
+ {
+ "password": "old-password",
+ "new_password": "new-password",
+ "nickname": "neo",
+ "email": "blocked@example.com",
+ "status": "disabled",
+ "theme": "dark",
+ },
+ )
+ monkeypatch.setattr(module, "check_password_hash", lambda _hashed, _plain: True)
+ monkeypatch.setattr(module, "decrypt", lambda value: f"dec:{value}")
+ monkeypatch.setattr(module, "generate_password_hash", lambda value: f"hash:{value}")
+ update_calls = {}
+
+ def _update_by_id(user_id, payload):
+ update_calls["user_id"] = user_id
+ update_calls["payload"] = payload
+ return True
+
+ monkeypatch.setattr(module.UserService, "update_by_id", _update_by_id)
+ res = _run(module.setting_user())
+ assert res["code"] == 0
+ assert res["data"] is True
+ assert update_calls["user_id"] == "current-user"
+ assert update_calls["payload"]["password"] == "hash:dec:new-password"
+ assert update_calls["payload"]["nickname"] == "neo"
+ assert update_calls["payload"]["theme"] == "dark"
+ assert "email" not in update_calls["payload"]
+ assert "status" not in update_calls["payload"]
+
+ _set_request_json(monkeypatch, module, {"nickname": "neo"})
+
+ def _raise_update(_user_id, _payload):
+ raise RuntimeError("update explode")
+
+ monkeypatch.setattr(module.UserService, "update_by_id", _raise_update)
+ res = _run(module.setting_user())
+ assert res["code"] == module.RetCode.EXCEPTION_ERROR
+ assert "Update failure" in res["message"]
+
+ res = _run(module.user_profile())
+ assert res["code"] == 0
+ assert res["data"] == current_user.to_dict()
+
+
+@pytest.mark.p2
+def test_registration_helpers_and_register_route_matrix_unit(monkeypatch):
+ module = _load_user_app(monkeypatch)
+
+ deleted = {"user": 0, "tenant": 0, "user_tenant": 0, "tenant_llm": 0}
+ monkeypatch.setattr(module.UserService, "delete_by_id", lambda _user_id: deleted.__setitem__("user", deleted["user"] + 1))
+ monkeypatch.setattr(module.TenantService, "delete_by_id", lambda _tenant_id: deleted.__setitem__("tenant", deleted["tenant"] + 1))
+ monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(id="ut-1")])
+ monkeypatch.setattr(module.UserTenantService, "delete_by_id", lambda _ut_id: deleted.__setitem__("user_tenant", deleted["user_tenant"] + 1))
+
+ class _DeleteQuery:
+ def where(self, *_args, **_kwargs):
+ return self
+
+ def execute(self):
+ deleted["tenant_llm"] += 1
+ return 1
+
+ monkeypatch.setattr(module.TenantLLM, "delete", lambda: _DeleteQuery())
+ module.rollback_user_registration("user-1")
+ assert deleted == {"user": 1, "tenant": 1, "user_tenant": 1, "tenant_llm": 1}, deleted
+
+ monkeypatch.setattr(module.UserService, "delete_by_id", lambda _user_id: (_ for _ in ()).throw(RuntimeError("u boom")))
+ monkeypatch.setattr(module.TenantService, "delete_by_id", lambda _tenant_id: (_ for _ in ()).throw(RuntimeError("t boom")))
+ monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("ut boom")))
+
+ class _RaisingDeleteQuery:
+ def where(self, *_args, **_kwargs):
+ raise RuntimeError("llm boom")
+
+ monkeypatch.setattr(module.TenantLLM, "delete", lambda: _RaisingDeleteQuery())
+ module.rollback_user_registration("user-2")
+
+ monkeypatch.setattr(module.UserService, "save", lambda **_kwargs: False)
+ res = module.user_register(
+ "new-user",
+ {
+ "nickname": "new",
+ "email": "new@example.com",
+ "password": "pw",
+ "access_token": "tk",
+ "login_channel": "password",
+ "last_login_time": "2024-01-01 00:00:00",
+ "is_superuser": False,
+ },
+ )
+ assert res is None
+
+ monkeypatch.setattr(module.settings, "REGISTER_ENABLED", False)
+ _set_request_json(monkeypatch, module, {"nickname": "neo", "email": "neo@example.com", "password": "enc"})
+ res = _run(module.user_add())
+ assert res["code"] == module.RetCode.OPERATING_ERROR, res
+ assert "disabled" in res["message"], res
+
+ monkeypatch.setattr(module.settings, "REGISTER_ENABLED", True)
+ _set_request_json(monkeypatch, module, {"nickname": "neo", "email": "bad-email", "password": "enc"})
+ res = _run(module.user_add())
+ assert res["code"] == module.RetCode.OPERATING_ERROR, res
+ assert "Invalid email address" in res["message"], res
+
+ monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [])
+ monkeypatch.setattr(module, "decrypt", lambda value: value)
+ monkeypatch.setattr(module, "get_uuid", lambda: "new-user-id")
+ rollback_calls = []
+ monkeypatch.setattr(module, "rollback_user_registration", lambda user_id: rollback_calls.append(user_id))
+
+ _set_request_json(monkeypatch, module, {"nickname": "neo", "email": "neo@example.com", "password": "enc"})
+ monkeypatch.setattr(module, "user_register", lambda _user_id, _payload: None)
+ res = _run(module.user_add())
+ assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
+ assert "Fail to register neo@example.com." in res["message"], res
+ assert rollback_calls == ["new-user-id"], rollback_calls
+
+ rollback_calls.clear()
+ monkeypatch.setattr(
+ module,
+ "user_register",
+ lambda _user_id, _payload: [_DummyUser("dup-1", "neo@example.com"), _DummyUser("dup-2", "neo@example.com")],
+ )
+ _set_request_json(monkeypatch, module, {"nickname": "neo", "email": "neo@example.com", "password": "enc"})
+ res = _run(module.user_add())
+ assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
+ assert "Same email: neo@example.com exists!" in res["message"], res
+ assert rollback_calls == ["new-user-id"], rollback_calls
+
+
+@pytest.mark.p2
+def test_tenant_info_and_set_tenant_info_exception_matrix_unit(monkeypatch):
+ module = _load_user_app(monkeypatch)
+
+ monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [])
+ res = _run(module.tenant_info())
+ assert res["code"] == module.RetCode.DATA_ERROR, res
+ assert "Tenant not found" in res["message"], res
+
+ def _raise_tenant_info(_uid):
+ raise RuntimeError("tenant info boom")
+
+ monkeypatch.setattr(module.TenantService, "get_info_by", _raise_tenant_info)
+ res = _run(module.tenant_info())
+ assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
+ assert "tenant info boom" in res["message"], res
+
+ _set_request_json(
+ monkeypatch,
+ module,
+ {"tenant_id": "tenant-1", "llm_id": "l", "embd_id": "e", "asr_id": "a", "img2txt_id": "i"},
+ )
+
+ def _raise_update(_tenant_id, _payload):
+ raise RuntimeError("tenant update boom")
+
+ monkeypatch.setattr(module.TenantService, "update_by_id", _raise_update)
+ res = _run(module.set_tenant_info())
+ assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
+ assert "tenant update boom" in res["message"], res
+
+
+@pytest.mark.p2
+def test_forget_captcha_and_send_otp_matrix_unit(monkeypatch):
+ module = _load_user_app(monkeypatch)
+
+ class _Headers(dict):
+ def set(self, key, value):
+ self[key] = value
+
+ async def _make_response(data):
+ return SimpleNamespace(data=data, headers=_Headers())
+
+ monkeypatch.setattr(module, "make_response", _make_response)
+
+ captcha_pkg = ModuleType("captcha")
+ captcha_image_mod = ModuleType("captcha.image")
+
+ class _ImageCaptcha:
+ def __init__(self, **_kwargs):
+ pass
+
+ def generate(self, text):
+ return SimpleNamespace(read=lambda: f"img:{text}".encode())
+
+ captcha_image_mod.ImageCaptcha = _ImageCaptcha
+ monkeypatch.setitem(sys.modules, "captcha", captcha_pkg)
+ monkeypatch.setitem(sys.modules, "captcha.image", captcha_image_mod)
+
+ _set_request_args(monkeypatch, module, {"email": ""})
+ res = _run(module.forget_get_captcha())
+ assert res["code"] == module.RetCode.ARGUMENT_ERROR, res
+
+ monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [])
+ _set_request_args(monkeypatch, module, {"email": "nobody@example.com"})
+ res = _run(module.forget_get_captcha())
+ assert res["code"] == module.RetCode.DATA_ERROR, res
+
+ monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [_DummyUser("u1", "ok@example.com")])
+ monkeypatch.setattr(module.secrets, "choice", lambda _allowed: "A")
+ _set_request_args(monkeypatch, module, {"email": "ok@example.com"})
+ res = _run(module.forget_get_captcha())
+ assert res.data.startswith(b"img:"), res
+ assert res.headers["Content-Type"] == "image/JPEG", res.headers
+ assert module.REDIS_CONN.get(module.captcha_key("ok@example.com")), module.REDIS_CONN.store
+
+ _set_request_json(monkeypatch, module, {"email": "", "captcha": ""})
+ res = _run(module.forget_send_otp())
+ assert res["code"] == module.RetCode.ARGUMENT_ERROR, res
+
+ monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [])
+ _set_request_json(monkeypatch, module, {"email": "none@example.com", "captcha": "AAAA"})
+ res = _run(module.forget_send_otp())
+ assert res["code"] == module.RetCode.DATA_ERROR, res
+
+ monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [_DummyUser("u1", "ok@example.com")])
+ _set_request_json(monkeypatch, module, {"email": "ok@example.com", "captcha": "AAAA"})
+ module.REDIS_CONN.store.pop(module.captcha_key("ok@example.com"), None)
+ res = _run(module.forget_send_otp())
+ assert res["code"] == module.RetCode.NOT_EFFECTIVE, res
+
+ module.REDIS_CONN.store[module.captcha_key("ok@example.com")] = "ABCD"
+ _set_request_json(monkeypatch, module, {"email": "ok@example.com", "captcha": "ZZZZ"})
+ res = _run(module.forget_send_otp())
+ assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
+
+ monkeypatch.setattr(module.time, "time", lambda: 1000)
+ k_code, k_attempts, k_last, k_lock = module.otp_keys("ok@example.com")
+ module.REDIS_CONN.store[module.captcha_key("ok@example.com")] = "ABCD"
+ module.REDIS_CONN.store[k_last] = "990"
+ _set_request_json(monkeypatch, module, {"email": "ok@example.com", "captcha": "ABCD"})
+ res = _run(module.forget_send_otp())
+ assert res["code"] == module.RetCode.NOT_EFFECTIVE, res
+ assert "wait" in res["message"], res
+
+ module.REDIS_CONN.store[module.captcha_key("ok@example.com")] = "ABCD"
+ module.REDIS_CONN.store[k_last] = "bad-timestamp"
+ monkeypatch.setattr(module.secrets, "choice", lambda _allowed: "B")
+ monkeypatch.setattr(module.os, "urandom", lambda _n: b"\x00" * 16)
+ monkeypatch.setattr(module, "hash_code", lambda code, _salt: f"HASH_{code}")
+
+ async def _raise_send_email(*_args, **_kwargs):
+ raise RuntimeError("send email boom")
+
+ monkeypatch.setattr(module, "send_email_html", _raise_send_email)
+ _set_request_json(monkeypatch, module, {"email": "ok@example.com", "captcha": "ABCD"})
+ res = _run(module.forget_send_otp())
+ assert res["code"] == module.RetCode.SERVER_ERROR, res
+ assert "failed to send email" in res["message"], res
+
+ async def _ok_send_email(*_args, **_kwargs):
+ return True
+
+ module.REDIS_CONN.store[module.captcha_key("ok@example.com")] = "ABCD"
+ module.REDIS_CONN.store.pop(k_last, None)
+ monkeypatch.setattr(module, "send_email_html", _ok_send_email)
+ _set_request_json(monkeypatch, module, {"email": "ok@example.com", "captcha": "ABCD"})
+ res = _run(module.forget_send_otp())
+ assert res["code"] == module.RetCode.SUCCESS, res
+ assert res["data"] is True, res
+ assert module.REDIS_CONN.get(k_code), module.REDIS_CONN.store
+ assert module.REDIS_CONN.get(k_attempts) == 0, module.REDIS_CONN.store
+ assert module.REDIS_CONN.get(k_lock) is None, module.REDIS_CONN.store
+
+
+@pytest.mark.p2
+def test_forget_verify_otp_matrix_unit(monkeypatch):
+ module = _load_user_app(monkeypatch)
+ email = "ok@example.com"
+ k_code, k_attempts, k_last, k_lock = module.otp_keys(email)
+ salt = b"\x01" * 16
+ monkeypatch.setattr(module, "hash_code", lambda code, _salt: f"HASH_{code}")
+
+ _set_request_json(monkeypatch, module, {})
+ res = _run(module.forget_verify_otp())
+ assert res["code"] == module.RetCode.ARGUMENT_ERROR, res
+
+ monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [])
+ _set_request_json(monkeypatch, module, {"email": email, "otp": "ABCDEF"})
+ res = _run(module.forget_verify_otp())
+ assert res["code"] == module.RetCode.DATA_ERROR, res
+
+ monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [_DummyUser("u1", email)])
+ module.REDIS_CONN.store[k_lock] = "1"
+ _set_request_json(monkeypatch, module, {"email": email, "otp": "ABCDEF"})
+ res = _run(module.forget_verify_otp())
+ assert res["code"] == module.RetCode.NOT_EFFECTIVE, res
+ module.REDIS_CONN.store.pop(k_lock, None)
+
+ module.REDIS_CONN.store.pop(k_code, None)
+ _set_request_json(monkeypatch, module, {"email": email, "otp": "ABCDEF"})
+ res = _run(module.forget_verify_otp())
+ assert res["code"] == module.RetCode.NOT_EFFECTIVE, res
+
+ module.REDIS_CONN.store[k_code] = "broken"
+ _set_request_json(monkeypatch, module, {"email": email, "otp": "ABCDEF"})
+ res = _run(module.forget_verify_otp())
+ assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
+
+ module.REDIS_CONN.store[k_code] = f"HASH_CORRECT:{salt.hex()}"
+ module.REDIS_CONN.store[k_attempts] = "bad-int"
+ _set_request_json(monkeypatch, module, {"email": email, "otp": "wrong"})
+ res = _run(module.forget_verify_otp())
+ assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
+ assert module.REDIS_CONN.get(k_attempts) == 1, module.REDIS_CONN.store
+
+ module.REDIS_CONN.store[k_code] = f"HASH_CORRECT:{salt.hex()}"
+ module.REDIS_CONN.store[k_attempts] = str(module.ATTEMPT_LIMIT - 1)
+ _set_request_json(monkeypatch, module, {"email": email, "otp": "wrong"})
+ res = _run(module.forget_verify_otp())
+ assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
+ assert module.REDIS_CONN.get(k_lock) is not None, module.REDIS_CONN.store
+ module.REDIS_CONN.store.pop(k_lock, None)
+
+ module.REDIS_CONN.store[k_code] = f"HASH_ABCDEF:{salt.hex()}"
+ module.REDIS_CONN.store[k_attempts] = "0"
+ module.REDIS_CONN.store[k_last] = "1000"
+
+ def _set_with_verified_fail(key, value, _ttl=None):
+ if key == module._verified_key(email):
+ raise RuntimeError("verified set boom")
+ module.REDIS_CONN.store[key] = value
+
+ monkeypatch.setattr(module.REDIS_CONN, "set", _set_with_verified_fail)
+ _set_request_json(monkeypatch, module, {"email": email, "otp": "abcdef"})
+ res = _run(module.forget_verify_otp())
+ assert res["code"] == module.RetCode.SERVER_ERROR, res
+
+ monkeypatch.setattr(module.REDIS_CONN, "set", lambda key, value, _ttl=None: module.REDIS_CONN.store.__setitem__(key, value))
+ module.REDIS_CONN.store[k_code] = f"HASH_ABCDEF:{salt.hex()}"
+ module.REDIS_CONN.store[k_attempts] = "0"
+ module.REDIS_CONN.store[k_last] = "1000"
+ _set_request_json(monkeypatch, module, {"email": email, "otp": "abcdef"})
+ res = _run(module.forget_verify_otp())
+ assert res["code"] == module.RetCode.SUCCESS, res
+ assert module.REDIS_CONN.get(k_code) is None, module.REDIS_CONN.store
+ assert module.REDIS_CONN.get(k_attempts) is None, module.REDIS_CONN.store
+ assert module.REDIS_CONN.get(k_last) is None, module.REDIS_CONN.store
+ assert module.REDIS_CONN.get(k_lock) is None, module.REDIS_CONN.store
+ assert module.REDIS_CONN.get(module._verified_key(email)) == "1", module.REDIS_CONN.store
+
+
+@pytest.mark.p2
+def test_forget_reset_password_matrix_unit(monkeypatch):
+ module = _load_user_app(monkeypatch)
+ email = "reset@example.com"
+ v_key = module._verified_key(email)
+ user = _DummyUser("u-reset", email, nickname="reset-user")
+ pwd_a = base64.b64encode(b"new-password").decode()
+ pwd_b = base64.b64encode(b"confirm-password").decode()
+ pwd_same = base64.b64encode(b"same-password").decode()
+ monkeypatch.setattr(module, "decrypt", lambda value: value)
+
+ _set_request_json(monkeypatch, module, {"email": email, "new_password": pwd_same, "confirm_new_password": pwd_same})
+ module.REDIS_CONN.store.pop(v_key, None)
+ res = _run(module.forget_reset_password())
+ assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
+
+ module.REDIS_CONN.store[v_key] = "1"
+ monkeypatch.setattr(module, "decrypt", lambda _value: "")
+ _set_request_json(monkeypatch, module, {"email": email, "new_password": "", "confirm_new_password": ""})
+ res = _run(module.forget_reset_password())
+ assert res["code"] == module.RetCode.ARGUMENT_ERROR, res
+
+ monkeypatch.setattr(module, "decrypt", lambda value: value)
+ module.REDIS_CONN.store[v_key] = "1"
+ _set_request_json(monkeypatch, module, {"email": email, "new_password": pwd_a, "confirm_new_password": pwd_b})
+ res = _run(module.forget_reset_password())
+ assert res["code"] == module.RetCode.ARGUMENT_ERROR, res
+ assert "do not match" in res["message"], res
+
+ module.REDIS_CONN.store[v_key] = "1"
+ monkeypatch.setattr(module.UserService, "query_user_by_email", lambda **_kwargs: [])
+ _set_request_json(monkeypatch, module, {"email": email, "new_password": pwd_same, "confirm_new_password": pwd_same})
+ res = _run(module.forget_reset_password())
+ assert res["code"] == module.RetCode.DATA_ERROR, res
+
+ module.REDIS_CONN.store[v_key] = "1"
+ monkeypatch.setattr(module.UserService, "query_user_by_email", lambda **_kwargs: [user])
+
+ def _raise_update_password(_user_id, _new_pwd):
+ raise RuntimeError("reset boom")
+
+ monkeypatch.setattr(module.UserService, "update_user_password", _raise_update_password)
+ _set_request_json(monkeypatch, module, {"email": email, "new_password": pwd_same, "confirm_new_password": pwd_same})
+ res = _run(module.forget_reset_password())
+ assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
+
+ module.REDIS_CONN.store[v_key] = "1"
+ monkeypatch.setattr(module.UserService, "update_user_password", lambda _user_id, _new_pwd: True)
+ monkeypatch.setattr(module.REDIS_CONN, "delete", lambda _key: (_ for _ in ()).throw(RuntimeError("delete boom")))
+ _set_request_json(monkeypatch, module, {"email": email, "new_password": pwd_same, "confirm_new_password": pwd_same})
+ res = _run(module.forget_reset_password())
+ assert res["code"] == module.RetCode.SUCCESS, res
+ assert res["auth"] == user.get_id(), res
+
+ monkeypatch.setattr(module.REDIS_CONN, "delete", lambda key: module.REDIS_CONN.store.pop(key, None))
+ module.REDIS_CONN.store[v_key] = "1"
+ _set_request_json(monkeypatch, module, {"email": email, "new_password": pwd_same, "confirm_new_password": pwd_same})
+ res = _run(module.forget_reset_password())
+ assert res["code"] == module.RetCode.SUCCESS, res
+ assert res["auth"] == user.get_id(), res
+ assert module.REDIS_CONN.get(v_key) is None, module.REDIS_CONN.store