Chore: migrate tests to restful api (#14871)

### What problem does this PR solve?

Add a new test suite for the new RESTful API endpoints, meant to replace
the existing HTTP API and web API tests.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Other (please describe): test
This commit is contained in:
Idriss Sbaaoui
2026-05-13 15:07:23 +08:00
committed by GitHub
parent d63d3bb7d2
commit 09e1fd290a
26 changed files with 6249 additions and 24 deletions

View File

@@ -0,0 +1,163 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from libs.auth import RAGFlowHttpApiAuth
from test.testcases.restful_api.helpers.client import RestClient
from utils.file_utils import create_txt_file
from utils import wait_for
@pytest.fixture(scope="session")
def RestApiAuth(token):
    """Session-scoped ``RAGFlowHttpApiAuth`` built from the ``token`` fixture."""
    auth = RAGFlowHttpApiAuth(token)
    return auth
@pytest.fixture(scope="session")
def rest_client(token):
    """Session-scoped ``RestClient`` constructed with the ``token`` fixture."""
    client = RestClient(token=token)
    return client
@pytest.fixture(scope="session")
def rest_client_noauth():
    """Session-scoped ``RestClient`` created without a token (``token=None``)."""
    anonymous_client = RestClient(token=None)
    return anonymous_client
@pytest.fixture
def clear_datasets(rest_client):
    """Send a delete-all request to ``/datasets`` after the requesting test finishes."""

    def _purge():
        response = rest_client.delete("/datasets", json={"ids": None, "delete_all": True})
        assert response.status_code == 200, response.text
        body = response.json()
        # Both result codes 0 and 102 are accepted for the bulk delete.
        assert body["code"] in (0, 102), body

    yield
    _purge()
@pytest.fixture
def clear_chats(rest_client):
    """Send a delete-all request to ``/chats`` after the requesting test finishes."""

    def _purge():
        response = rest_client.delete("/chats", json={"ids": None, "delete_all": True})
        assert response.status_code == 200, response.text
        body = response.json()
        # Both result codes 0 and 102 are accepted for the bulk delete.
        assert body["code"] in (0, 102), body

    yield
    _purge()
@pytest.fixture
def create_dataset(rest_client, clear_datasets):
    """Yield a factory that creates datasets and deletes the ones it created on teardown.

    The factory takes an optional ``name`` and returns the new dataset id.
    """
    tracked_ids: list[str] = []

    def _create(name: str = "restful_dataset") -> str:
        response = rest_client.post("/datasets", json={"name": name})
        assert response.status_code == 200
        body = response.json()
        assert body["code"] == 0, body
        new_id = body["data"]["id"]
        tracked_ids.append(new_id)
        return new_id

    yield _create

    # Teardown: only talk to the server when the factory was actually used.
    if tracked_ids:
        response = rest_client.delete("/datasets", json={"ids": tracked_ids})
        assert response.status_code == 200
        body = response.json()
        # The dataset may already have been removed by the test or other cleanup.
        assert body["code"] in (0, 102), body
@pytest.fixture
def create_chat(rest_client, clear_chats):
    """Yield a factory that creates chats and deletes the ones it created on teardown.

    The factory takes an optional ``name`` and returns the new chat id.
    """
    tracked_ids: list[str] = []

    def _create(name: str = "restful_chat") -> str:
        response = rest_client.post("/chats", json={"name": name, "dataset_ids": []})
        assert response.status_code == 200
        body = response.json()
        assert body["code"] == 0, body
        new_id = body["data"]["id"]
        tracked_ids.append(new_id)
        return new_id

    yield _create

    # Teardown: only talk to the server when the factory was actually used.
    if tracked_ids:
        response = rest_client.delete("/chats", json={"ids": tracked_ids})
        assert response.status_code == 200, response.text
        body = response.json()
        assert body["code"] in (0, 102), body
@pytest.fixture
def create_document(rest_client, create_dataset, tmp_path):
    """Yield a factory that uploads a text document into a freshly created dataset.

    The factory returns ``(dataset_id, document_id)``; every uploaded document is
    deleted again during teardown.
    """
    uploaded: list[tuple[str, str]] = []

    def _create(name: str = "restful_doc.txt") -> tuple[str, str]:
        dataset_id = create_dataset("dataset_for_doc")
        source_file = create_txt_file(tmp_path / name)
        # Keep the handle open for the duration of the upload request.
        with source_file.open("rb") as handle:
            response = rest_client.post(
                f"/datasets/{dataset_id}/documents",
                files=[("file", (source_file.name, handle))],
            )
        assert response.status_code == 200
        body = response.json()
        assert body["code"] == 0, body
        document_id = body["data"][0]["id"]
        uploaded.append((dataset_id, document_id))
        return dataset_id, document_id

    yield _create

    for owner_dataset_id, doc_id in uploaded:
        response = rest_client.delete(
            f"/datasets/{owner_dataset_id}/documents",
            json={"ids": [doc_id]},
        )
        assert response.status_code == 200, response.text
        body = response.json()
        assert body["code"] in (0, 102), body
@wait_for(60, 1, "Document parsing timeout in RESTful batch2 tests")
def _parsed(rest_client: RestClient, dataset_id: str, document_id: str):
    """Report whether the document's ``run`` field currently equals ``"DONE"``.

    Returns ``False`` for a non-200 response, a non-zero result code or an empty
    document list. NOTE(review): the ``wait_for`` wrapper presumably re-invokes
    this until it succeeds or the 60 s budget runs out - confirm in ``utils``.
    """
    response = rest_client.get(f"/datasets/{dataset_id}/documents", params={"id": document_id})
    if response.status_code != 200:
        return False
    body = response.json()
    if body["code"] != 0:
        return False
    matching_docs = body["data"]["docs"]
    return bool(matching_docs) and matching_docs[0].get("run") == "DONE"
@pytest.fixture
def ensure_parsed_document(rest_client, create_document):
    """Return a factory that uploads a document, requests parsing and waits via ``_parsed``.

    The factory returns ``(dataset_id, document_id)``.
    """

    def _ensure() -> tuple[str, str]:
        ids = create_document()
        dataset_id, document_id = ids
        parse_response = rest_client.post(
            f"/datasets/{dataset_id}/documents/parse",
            json={"document_ids": [document_id]},
        )
        assert parse_response.status_code == 200
        parse_body = parse_response.json()
        assert parse_body["code"] == 0, parse_body
        # Wait on the decorated poller before handing the ids to the test.
        _parsed(rest_client, dataset_id, document_id)
        return dataset_id, document_id

    return _ensure

View File

@@ -0,0 +1 @@
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.

View File

@@ -0,0 +1,85 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from dataclasses import dataclass
from typing import Any
import requests
from configs import HOST_ADDRESS, VERSION
@dataclass
class RestClient:
    """Small ``requests`` wrapper that targets ``{HOST_ADDRESS}/api/{VERSION}``.

    Every request gets a JSON ``Content-Type`` header (dropped for multipart
    uploads) and, when a token is configured, a ``Bearer`` authorization header.
    """

    # Bearer token; ``None`` means no Authorization header is added.
    token: str | None = None
    # Default per-request timeout handed to ``requests``.
    timeout: int = 30

    @property
    def api_root(self) -> str:
        """Base URL that every request path is appended to."""
        return "{}/api/{}".format(HOST_ADDRESS, VERSION)

    def _headers(self, headers: dict[str, str] | None = None) -> dict[str, str]:
        """Merge caller headers over the defaults and add the bearer token if absent."""
        combined: dict[str, str] = {"Content-Type": "application/json"}
        combined.update(headers or {})
        if self.token and "Authorization" not in combined:
            combined["Authorization"] = f"Bearer {self.token}"
        return combined

    def request(
        self,
        method: str,
        path: str,
        *,
        headers: dict[str, str] | None = None,
        params: dict[str, Any] | None = None,
        json: dict[str, Any] | None = None,
        data: Any = None,
        files: Any = None,
        **request_kwargs: Any,
    ) -> requests.Response:
        """Send ``method`` to ``api_root`` + ``path`` and return the raw response.

        ``path`` is normalised to start with exactly one ``/``. A ``timeout``
        keyword overrides the instance default; any other extra keyword is
        forwarded to ``requests.request`` unchanged.
        """
        outgoing_headers = self._headers(headers)
        if files is not None:
            # requests generates the multipart boundary header itself.
            outgoing_headers.pop("Content-Type", None)
        effective_timeout = request_kwargs.pop("timeout", self.timeout)
        suffix = ("/" + path.lstrip("/")) if path else "/"
        target_url = f"{self.api_root}{suffix}"
        return requests.request(
            method=method,
            url=target_url,
            headers=outgoing_headers,
            params=params,
            json=json,
            data=data,
            files=files,
            timeout=effective_timeout,
            **request_kwargs,
        )

    def get(self, path: str, **kwargs) -> requests.Response:
        """Shortcut for ``request("GET", ...)``."""
        return self.request("GET", path, **kwargs)

    def post(self, path: str, **kwargs) -> requests.Response:
        """Shortcut for ``request("POST", ...)``."""
        return self.request("POST", path, **kwargs)

    def delete(self, path: str, **kwargs) -> requests.Response:
        """Shortcut for ``request("DELETE", ...)``."""
        return self.request("DELETE", path, **kwargs)

    def put(self, path: str, **kwargs) -> requests.Response:
        """Shortcut for ``request("PUT", ...)``."""
        return self.request("PUT", path, **kwargs)

    def patch(self, path: str, **kwargs) -> requests.Response:
        """Shortcut for ``request("PATCH", ...)``."""
        return self.request("PATCH", path, **kwargs)

View File

@@ -0,0 +1,333 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import pytest
# Smallest agent DSL used by these tests: a Begin component wired to a single
# Message component whose content is the "{sys.query}" placeholder.
MINIMAL_DSL = dict(
    components=dict(
        begin=dict(
            obj=dict(component_name="Begin", params={}),
            downstream=["message"],
            upstream=[],
        ),
        message=dict(
            obj=dict(component_name="Message", params={"content": ["{sys.query}"]}),
            downstream=[],
            upstream=["begin"],
        ),
    ),
    history=[],
    retrieval=[],
    path=[],
    globals={
        "sys.query": "",
        "sys.user_id": "",
        "sys.conversation_turns": 0,
        "sys.files": [],
    },
    variables={},
)
def _sse_events(response_text: str) -> list[str]:
return [line[5:] for line in response_text.splitlines() if line.startswith("data:")]
@pytest.fixture
def create_agent_resource(rest_client):
    """Yield a factory that creates agents from ``MINIMAL_DSL`` and removes them afterwards.

    The factory takes an optional ``title`` and returns the new agent id.
    Teardown attempts to delete every created agent and reports all failures
    together in a single assertion.
    """
    tracked_ids: list[str] = []

    def _create(title: str = "restful_agent") -> str:
        response = rest_client.post("/agents", json={"title": title, "dsl": MINIMAL_DSL})
        assert response.status_code == 200
        body = response.json()
        assert body["code"] == 0, body
        new_id = body["data"]["id"]
        tracked_ids.append(new_id)
        return new_id

    yield _create

    cleanup_errors = []
    for agent_id in tracked_ids:
        response = rest_client.delete(f"/agents/{agent_id}")
        if response.status_code == 200:
            body = response.json()
            # Result codes 0 and 103 are both accepted for the delete.
            if body["code"] in (0, 103):
                continue
            cleanup_errors.append((agent_id, response.status_code, body))
        else:
            cleanup_errors.append((agent_id, response.status_code, response.text))
    assert not cleanup_errors, f"Agent cleanup failed: {cleanup_errors}"
@pytest.mark.p2
def test_agents_crud_validation_contract(rest_client, create_agent_resource):
    """Exercise agent list/create/get/update/delete plus create-time validation errors."""
    # Listing by a title the suite does not create still returns the list envelope.
    empty_res = rest_client.get("/agents", params={"title": "missing_restful_agent"})
    assert empty_res.status_code == 200
    empty_body = empty_res.json()
    assert empty_body["code"] == 0, empty_body
    assert "canvas" in empty_body["data"], empty_body
    assert "total" in empty_body["data"], empty_body

    # Create without a DSL.
    no_dsl_res = rest_client.post("/agents", json={"title": "missing_dsl_agent"})
    assert no_dsl_res.status_code == 200
    no_dsl_body = no_dsl_res.json()
    assert no_dsl_body["code"] == 101, no_dsl_body
    assert "No DSL data in request" in no_dsl_body["message"], no_dsl_body

    # Create without a title.
    no_title_res = rest_client.post("/agents", json={"dsl": MINIMAL_DSL})
    assert no_title_res.status_code == 200
    no_title_body = no_title_res.json()
    assert no_title_body["code"] == 101, no_title_body
    assert "No title in request" in no_title_body["message"], no_title_body

    agent_id = create_agent_resource("restful_agent_crud")

    # Re-using the same title is rejected.
    dup_res = rest_client.post("/agents", json={"title": "restful_agent_crud", "dsl": MINIMAL_DSL})
    assert dup_res.status_code == 200
    dup_body = dup_res.json()
    assert dup_body["code"] == 102, dup_body
    assert "already exists" in dup_body["message"], dup_body

    fetched_res = rest_client.get(f"/agents/{agent_id}")
    assert fetched_res.status_code == 200
    fetched_body = fetched_res.json()
    assert fetched_body["code"] == 0, fetched_body
    assert fetched_body["data"]["id"] == agent_id, fetched_body

    renamed_res = rest_client.put(
        f"/agents/{agent_id}",
        json={"title": "restful_agent_crud_updated", "dsl": MINIMAL_DSL},
    )
    assert renamed_res.status_code == 200
    renamed_body = renamed_res.json()
    assert renamed_body["code"] == 0, renamed_body

    relisted_res = rest_client.get("/agents", params={"title": "restful_agent_crud_updated"})
    assert relisted_res.status_code == 200
    relisted_body = relisted_res.json()
    assert relisted_body["code"] == 0, relisted_body
    assert relisted_body["data"]["total"] >= 1, relisted_body

    removed_res = rest_client.delete(f"/agents/{agent_id}")
    assert removed_res.status_code == 200
    removed_body = removed_res.json()
    assert removed_body["code"] == 0, removed_body
    assert removed_body["data"] is True, removed_body
@pytest.mark.p2
def test_agent_sessions_crud(rest_client, create_agent_resource):
    """Create, list, fetch and delete a session that belongs to a fresh agent."""
    agent_id = create_agent_resource("restful_agent_sessions")
    sessions_path = f"/agents/{agent_id}/sessions"

    created_res = rest_client.post(sessions_path, json={"name": "agent_session_1"})
    assert created_res.status_code == 200
    created_body = created_res.json()
    assert created_body["code"] == 0, created_body
    session_id = created_body["data"]["id"]

    listed_res = rest_client.get(sessions_path)
    assert listed_res.status_code == 200
    listed_body = listed_res.json()
    assert listed_body["code"] == 0, listed_body
    assert isinstance(listed_body["data"], list), listed_body
    assert any(entry["id"] == session_id for entry in listed_body["data"]), listed_body

    fetched_res = rest_client.get(f"/agents/{agent_id}/sessions/{session_id}")
    assert fetched_res.status_code == 200
    fetched_body = fetched_res.json()
    assert fetched_body["code"] == 0, fetched_body
    assert fetched_body["data"]["id"] == session_id, fetched_body

    removed_res = rest_client.delete(f"/agents/{agent_id}/sessions/{session_id}")
    assert removed_res.status_code == 200
    removed_body = removed_res.json()
    assert removed_body["code"] == 0, removed_body
@pytest.mark.p2
def test_agent_chat_completion_validation(rest_client):
    """A completion request without ``agent_id`` is rejected with code 101."""
    request_body = {"query": "hello", "stream": False}
    response = rest_client.post("/agents/chat/completions", json=request_body)
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 101, body
    assert "`agent_id` is required." in body["message"], body
@pytest.mark.p2
def test_agent_chat_completion_nonstream(rest_client, create_agent_resource):
    """A non-streaming completion returns a nested ``data`` dict that has ``content``."""
    agent_id = create_agent_resource("restful_agent_nonstream")

    session_res = rest_client.post(
        f"/agents/{agent_id}/sessions",
        json={"name": "agent_completion_session"},
    )
    assert session_res.status_code == 200
    session_body = session_res.json()
    assert session_body["code"] == 0, session_body
    session_id = session_body["data"]["id"]

    completion_res = rest_client.post(
        "/agents/chat/completions",
        json={"agent_id": agent_id, "query": "hello", "stream": False, "session_id": session_id},
        timeout=60,
    )
    assert completion_res.status_code == 200
    completion_body = completion_res.json()
    assert completion_body["code"] == 0, completion_body
    assert isinstance(completion_body["data"], dict), completion_body
    assert isinstance(completion_body["data"].get("data"), dict), completion_body
    assert "content" in completion_body["data"]["data"], completion_body
@pytest.mark.p2
def test_agent_chat_completion_stream_structure_and_done(rest_client, create_agent_resource):
    """A streaming completion is an event stream of JSON events terminated by ``[DONE]``."""
    agent_id = create_agent_resource("restful_agent_stream")

    session_res = rest_client.post(
        f"/agents/{agent_id}/sessions",
        json={"name": "agent_stream_session"},
    )
    assert session_res.status_code == 200
    session_body = session_res.json()
    assert session_body["code"] == 0, session_body
    session_id = session_body["data"]["id"]

    stream_request = {
        "agent_id": agent_id,
        "query": "hello",
        "stream": True,
        "session_id": session_id,
        "return_trace": True,
    }
    stream_res = rest_client.post("/agents/chat/completions", json=stream_request, timeout=60)
    assert stream_res.status_code == 200
    content_type = stream_res.headers.get("Content-Type", "")
    assert "text/event-stream" in content_type, content_type

    events = _sse_events(stream_res.text)
    assert events, stream_res.text
    assert events[-1].strip() == "[DONE]", events[-1]

    # Every event except the terminator must parse as JSON.
    json_events = []
    for raw_event in events:
        if raw_event.strip() != "[DONE]":
            json_events.append(json.loads(raw_event))
    assert json_events, events
    assert any(isinstance(parsed, dict) for parsed in json_events), json_events
@pytest.mark.p2
def test_agent_openai_compatible_mode(rest_client, create_agent_resource):
    """Check the ``openai-compatible`` completion mode: validation, non-stream and stream."""
    agent_id = create_agent_resource("restful_agent_openai_compat")
    completions_path = "/agents/chat/completions"

    # An empty message list is rejected.
    empty_messages_res = rest_client.post(
        completions_path,
        json={"agent_id": agent_id, "openai-compatible": True, "model": "model", "messages": []},
    )
    assert empty_messages_res.status_code == 200
    empty_messages_body = empty_messages_res.json()
    assert empty_messages_body["code"] == 102, empty_messages_body
    assert "at least one message" in empty_messages_body["message"], empty_messages_body

    # Non-streaming reply carries a "choices" key.
    plain_res = rest_client.post(
        completions_path,
        json={
            "agent_id": agent_id,
            "openai-compatible": True,
            "model": "model",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": False,
        },
        timeout=60,
    )
    assert plain_res.status_code == 200
    plain_body = plain_res.json()
    assert isinstance(plain_body, dict), plain_body
    assert "choices" in plain_body, plain_body

    # Streaming reply is served as an event stream.
    streamed_res = rest_client.post(
        completions_path,
        json={
            "agent_id": agent_id,
            "openai-compatible": True,
            "model": "model",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": True,
        },
        timeout=60,
    )
    assert streamed_res.status_code == 200
    streamed_content_type = streamed_res.headers.get("Content-Type", "")
    assert "text/event-stream" in streamed_content_type, streamed_content_type
@pytest.mark.p2
def test_agent_support_routes_auth_and_contracts(rest_client, rest_client_noauth, create_agent_resource):
    """Check auth on ``/agents/prompts`` and response shapes of the auxiliary agent routes."""
    anonymous_res = rest_client_noauth.get("/agents/prompts")
    assert anonymous_res.status_code == 401
    assert anonymous_res.json()["code"] == 401

    prompts_res = rest_client.get("/agents/prompts")
    assert prompts_res.status_code == 200
    prompts_body = prompts_res.json()
    assert prompts_body["code"] == 0, prompts_body
    assert "task_analysis" in prompts_body["data"], prompts_body
    assert "citation_guidelines" in prompts_body["data"], prompts_body

    templates_res = rest_client.get("/agents/templates")
    assert templates_res.status_code == 200
    templates_body = templates_res.json()
    assert templates_body["code"] == 0, templates_body
    assert isinstance(templates_body["data"], list), templates_body

    agent_id = create_agent_resource("restful_agent_support")

    versions_res = rest_client.get(f"/agents/{agent_id}/versions")
    assert versions_res.status_code == 200
    versions_body = versions_res.json()
    assert versions_body["code"] == 0, versions_body
    assert isinstance(versions_body["data"], list), versions_body

    logs_res = rest_client.get(f"/agents/{agent_id}/logs/missing_message")
    assert logs_res.status_code == 200
    logs_body = logs_res.json()
    assert logs_body["code"] == 0, logs_body
    assert isinstance(logs_body["data"], dict), logs_body
@pytest.mark.p2
def test_agent_webhook_logs_empty_poll_contract(rest_client, create_agent_resource):
    """Polling webhook logs of a new agent yields no events and an unfinished state."""
    agent_id = create_agent_resource("restful_agent_webhook_logs")
    poll_res = rest_client.get(f"/agents/{agent_id}/webhook/logs", params={"since_ts": 0})
    assert poll_res.status_code == 200
    poll_body = poll_res.json()
    assert poll_body["code"] == 0, poll_body
    assert poll_body["data"]["events"] == [], poll_body
    assert poll_body["data"]["finished"] is False, poll_body
    assert "next_since_ts" in poll_body["data"], poll_body
@pytest.mark.p2
def test_agent_db_connection_validates_required_fields(rest_client):
    """Posting only ``db_type`` to the DB connection check is rejected with code 101."""
    response = rest_client.post("/agents/test_db_connection", json={"db_type": "mysql"})
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 101, body
    assert "required argument are missing" in body["message"], body
@pytest.mark.p2
def test_agent_rerun_requires_required_fields(rest_client):
    """Posting only ``id`` to the rerun route is rejected with code 101."""
    response = rest_client.post("/agents/rerun", json={"id": "flow-1"})
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 101, body
    assert "required argument are missing" in body["message"], body

View File

@@ -0,0 +1,123 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
@pytest.mark.p1
class TestChatsAuthorization:
    """Authorization checks for the ``/chats`` collection."""

    def test_create_requires_auth(self, rest_client_noauth):
        """Creating a chat without a token answers with HTTP 401."""
        request_body = {"name": "chat_auth", "dataset_ids": []}
        response = rest_client_noauth.post("/chats", json=request_body)
        assert response.status_code == 401
@pytest.mark.p1
def test_chat_crud_cycle(rest_client, clear_chats):
    """Create a chat, then list, get, PUT, PATCH and delete it, verifying each step."""
    created_res = rest_client.post(
        "/chats",
        json={"name": "restful_chat_crud", "dataset_ids": []},
    )
    assert created_res.status_code == 200
    created_body = created_res.json()
    assert created_body["code"] == 0, created_body
    chat_id = created_body["data"]["id"]

    # Filtering the list by id returns exactly the new chat.
    listed_res = rest_client.get("/chats", params={"id": chat_id})
    assert listed_res.status_code == 200
    listed_body = listed_res.json()
    assert listed_body["code"] == 0, listed_body
    listed_chats = listed_body["data"]["chats"]
    assert len(listed_chats) == 1, listed_body
    assert listed_chats[0]["id"] == chat_id, listed_body

    fetched_res = rest_client.get(f"/chats/{chat_id}")
    assert fetched_res.status_code == 200
    fetched_body = fetched_res.json()
    assert fetched_body["code"] == 0, fetched_body
    assert fetched_body["data"]["id"] == chat_id, fetched_body

    put_res = rest_client.put(
        f"/chats/{chat_id}",
        json={"name": "restful_chat_crud_updated", "dataset_ids": []},
    )
    assert put_res.status_code == 200
    put_body = put_res.json()
    assert put_body["code"] == 0, put_body
    assert put_body["data"]["name"] == "restful_chat_crud_updated", put_body

    patched_res = rest_client.patch(f"/chats/{chat_id}", json={"name": "restful_chat_crud_patched"})
    assert patched_res.status_code == 200
    patched_body = patched_res.json()
    assert patched_body["code"] == 0, patched_body
    assert patched_body["data"]["name"] == "restful_chat_crud_patched", patched_body

    removed_res = rest_client.delete("/chats", json={"ids": [chat_id]})
    assert removed_res.status_code == 200
    removed_body = removed_res.json()
    assert removed_body["code"] == 0, removed_body
    assert removed_body["data"]["success_count"] == 1, removed_body

    # After deletion the id filter matches nothing.
    relisted_res = rest_client.get("/chats", params={"id": chat_id})
    assert relisted_res.status_code == 200
    relisted_body = relisted_res.json()
    assert relisted_body["code"] == 0, relisted_body
    assert relisted_body["data"]["chats"] == [], relisted_body
@pytest.mark.p2
@pytest.mark.parametrize(
    "name, expected_fragment",
    [
        ("", "`name` is required."),
        (" ", "`name` is required."),
    ],
)
def test_chat_create_name_validation(rest_client, clear_chats, name, expected_fragment):
    """Empty or blank chat names are rejected with code 102 and the expected message."""
    request_body = {"name": name, "dataset_ids": []}
    response = rest_client.post("/chats", json=request_body)
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 102, body
    assert expected_fragment in body["message"], body
@pytest.mark.p2
def test_chat_duplicate_name_validation(rest_client, clear_chats):
    """Creating a second chat with an already used name is rejected with code 102."""
    request_body = {"name": "duplicate_chat_name", "dataset_ids": []}

    original_res = rest_client.post("/chats", json=request_body)
    assert original_res.status_code == 200
    original_body = original_res.json()
    assert original_body["code"] == 0, original_body

    repeat_res = rest_client.post("/chats", json=request_body)
    assert repeat_res.status_code == 200
    repeat_body = repeat_res.json()
    assert repeat_body["code"] == 102, repeat_body
    assert "Duplicated chat name" in repeat_body["message"], repeat_body
@pytest.mark.p2
def test_chat_list_pagination(rest_client, clear_chats):
    """With three chats present, a page of size two holds two chats and ``total`` >= 3."""
    for index in range(3):
        created_res = rest_client.post("/chats", json={"name": f"chat_page_{index}", "dataset_ids": []})
        assert created_res.status_code == 200
        created_body = created_res.json()
        assert created_body["code"] == 0, created_body

    paging = {"page": 1, "page_size": 2, "orderby": "create_time", "desc": "true"}
    page_res = rest_client.get("/chats", params=paging)
    assert page_res.status_code == 200
    page_body = page_res.json()
    assert page_body["code"] == 0, page_body
    assert len(page_body["data"]["chats"]) == 2, page_body
    assert page_body["data"]["total"] >= 3, page_body

View File

@@ -0,0 +1,124 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
def _assert_created_chunk_id(payload):
chunk_id = payload["data"]["chunk"].get("id")
assert chunk_id, payload
assert isinstance(chunk_id, str), payload
assert chunk_id.strip(), payload
return chunk_id
@pytest.mark.p1
def test_chunks_add_list_get_update_delete_cycle(rest_client, create_document):
    """Add a chunk, list/get/patch it, then add and delete a second chunk."""
    dataset_id, document_id = create_document("chunk_cycle.txt")
    chunks_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"

    new_chunk = {
        "content": "batch2 chunk content",
        "important_keywords": ["batch2"],
        "questions": ["what is batch2?"],
    }
    added_res = rest_client.post(chunks_path, json=new_chunk)
    assert added_res.status_code == 200
    added_body = added_res.json()
    assert added_body["code"] == 0, added_body
    chunk_id = _assert_created_chunk_id(added_body)
    chunk_path = f"{chunks_path}/{chunk_id}"

    listed_res = rest_client.get(chunks_path, params={"id": chunk_id})
    assert listed_res.status_code == 200
    listed_body = listed_res.json()
    assert listed_body["code"] == 0, listed_body
    assert listed_body["data"]["total"] == 1, listed_body
    assert listed_body["data"]["chunks"][0]["id"] == chunk_id, listed_body

    fetched_res = rest_client.get(chunk_path)
    assert fetched_res.status_code == 200
    fetched_body = fetched_res.json()
    assert fetched_body["code"] == 0, fetched_body
    assert fetched_body["data"]["id"] == chunk_id, fetched_body

    patched_res = rest_client.patch(
        chunk_path,
        json={"content": "batch2 chunk content updated"},
    )
    assert patched_res.status_code == 200
    patched_body = patched_res.json()
    assert patched_body["code"] == 0, patched_body

    # Read the chunk back to confirm the new content was stored.
    refetched_res = rest_client.get(chunk_path)
    assert refetched_res.status_code == 200
    refetched_body = refetched_res.json()
    assert refetched_body["code"] == 0, refetched_body
    assert refetched_body["data"]["content_with_weight"] == "batch2 chunk content updated", refetched_body

    # A second chunk serves as the deletion target.
    doomed_res = rest_client.post(chunks_path, json={"content": "batch2 chunk content to delete"})
    assert doomed_res.status_code == 200
    doomed_body = doomed_res.json()
    assert doomed_body["code"] == 0, doomed_body
    doomed_id = _assert_created_chunk_id(doomed_body)

    removed_res = rest_client.delete(chunks_path, json={"chunk_ids": [doomed_id]})
    assert removed_res.status_code == 200
    removed_body = removed_res.json()
    assert removed_body["code"] == 0, removed_body

    # Both the id-filtered list and the direct get report a non-zero code afterwards.
    gone_list_res = rest_client.get(chunks_path, params={"id": doomed_id})
    assert gone_list_res.status_code == 200
    gone_list_body = gone_list_res.json()
    assert gone_list_body["code"] != 0, gone_list_body

    gone_get_res = rest_client.get(f"{chunks_path}/{doomed_id}")
    assert gone_get_res.status_code == 200
    gone_get_body = gone_get_res.json()
    assert gone_get_body["code"] != 0, gone_get_body
@pytest.mark.p2
def test_chunks_add_requires_content(rest_client, create_document):
    """Adding a chunk whose content is blank is rejected with code 102."""
    dataset_id, document_id = create_document("chunk_requires_content.txt")
    chunks_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
    response = rest_client.post(chunks_path, json={"content": " "})
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 102, body
    assert body["message"] == "`content` is required", body
@pytest.mark.p2
def test_chunks_list_empty_document(rest_client, create_document):
    """Listing chunks of a just-uploaded document returns ``chunks`` and ``doc`` keys."""
    dataset_id, document_id = create_document("chunk_list_empty.txt")
    chunks_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
    listed_res = rest_client.get(chunks_path)
    assert listed_res.status_code == 200
    listed_body = listed_res.json()
    assert listed_body["code"] == 0, listed_body
    for required_key in ("chunks", "doc"):
        assert required_key in listed_body["data"], listed_body
@pytest.mark.p2
def test_chunks_delete_partial_invalid(rest_client, create_document):
    """Deleting an unknown chunk id is rejected with code 102 and an ``expect 1`` message."""
    dataset_id, document_id = create_document("chunk_delete_partial.txt")
    chunks_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
    removed_res = rest_client.delete(chunks_path, json={"chunk_ids": ["invalid_chunk_id"]})
    assert removed_res.status_code == 200
    removed_body = removed_res.json()
    assert removed_body["code"] == 102, removed_body
    assert "expect 1" in removed_body["message"], removed_body

View File

@@ -0,0 +1,718 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import json
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _AwaitableValue:
def __init__(self, value):
self._value = value
def __await__(self):
async def _co():
return self._value
return _co().__await__()
class _Args(dict):
def get(self, key, default=None, type=None):
value = super().get(key, default)
if type is None:
return value
try:
return type(value)
except (TypeError, ValueError):
return default
def to_dict(self, flat=True):
return dict(self)
class _FakeResponse:
def __init__(self, body, status_code):
self.body = body
self.status_code = status_code
self.headers = {}
class _FakeConnectorRecord:
def __init__(self, payload):
self._payload = payload
def to_dict(self):
return dict(self._payload)
class _FakeCredentials:
def __init__(self, raw='{"refresh_token":"rt","access_token":"at"}'):
self._raw = raw
def to_json(self):
return self._raw
class _FakeFlow:
def __init__(self, client_config, scopes):
self.client_config = client_config
self.scopes = scopes
self.redirect_uri = None
self.credentials = _FakeCredentials()
self.auth_kwargs = None
self.token_code = None
self.token_code_verifier = None
self.code_verifier = "fake-code-verifier"
def authorization_url(self, **kwargs):
self.auth_kwargs = dict(kwargs)
return f"https://oauth.example/{kwargs['state']}", kwargs["state"]
def fetch_token(self, code, code_verifier=None):
self.token_code = code
self.token_code_verifier = code_verifier
class _FakeBoxToken:
def __init__(self, access_token, refresh_token):
self.access_token = access_token
self.refresh_token = refresh_token
class _FakeBoxOAuth:
def __init__(self, config):
self.config = config
self.exchange_code = None
def get_authorize_url(self, options):
return f"https://box.example/auth?state={options.state}&redirect={options.redirect_uri}"
def get_tokens_authorization_code_grant(self, code):
self.exchange_code = code
def retrieve_token(self):
return _FakeBoxToken("box-access", "box-refresh")
class _FakeRedis:
def __init__(self):
self.store = {}
self.set_calls = []
self.deleted = []
def get(self, key):
return self.store.get(key)
def set_obj(self, key, obj, ttl):
self.set_calls.append((key, obj, ttl))
self.store[key] = json.dumps(obj)
def delete(self, key):
self.deleted.append(key)
self.store.pop(key, None)
def _run(coro):
return asyncio.run(coro)
def _set_request(module, *, args=None, json_body=None):
    """Install a fake ``request`` object on *module*.

    Args:
        module: Module whose ``request`` attribute is replaced.
        args: Query parameters; a falsy value becomes an empty mapping.
        json_body: Body exposed through the awaitable ``json`` attribute;
            ``None`` becomes an empty dict.
    """
    query = _Args(args or {})
    body = json_body if json_body is not None else {}
    module.request = SimpleNamespace(args=query, json=_AwaitableValue(body))
@pytest.fixture(scope="session")
def auth():
    """Session-scoped fixture yielding a fixed placeholder auth value."""
    placeholder = "unit-auth"
    return placeholder
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    """Session-scoped autouse fixture that does nothing and yields ``None``."""
    return None
def _load_connector_app(monkeypatch):
    """Load ``connector_api.py`` from its file path with its imports stubbed.

    Placeholder modules are registered in ``sys.modules`` via
    ``monkeypatch.setitem`` for the ``api``, ``common``, ``rag``, ``quart``,
    ``google_auth_oauthlib`` and ``box_sdk_gen`` names configured below, and
    the route module is then executed under the name
    ``test_connector_routes_unit``.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture, used for every
            ``sys.modules`` entry so the stubs are scoped to the calling test.

    Returns:
        The executed route module object.
    """
    # parents[3] of this file is treated as the repository root (it is joined
    # with "api/apps/restful_apis/connector_api.py" below).
    repo_root = Path(__file__).resolve().parents[3]

    # --- api / api.apps: package shells plus a fixed user and no-op auth ---
    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)

    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.current_user = SimpleNamespace(id="tenant-1")
    # login_required becomes an identity decorator.
    apps_mod.login_required = lambda fn: fn
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)

    # --- api.db and its services ---
    db_mod = ModuleType("api.db")
    db_mod.InputType = SimpleNamespace(POLL="POLL")
    monkeypatch.setitem(sys.modules, "api.db", db_mod)

    services_pkg = ModuleType("api.db.services")
    services_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)

    connector_service_mod = ModuleType("api.db.services.connector_service")

    # Default service behavior; individual tests override methods as needed.
    class _StubConnectorService:
        @staticmethod
        def update_by_id(*_args, **_kwargs):
            return True

        @staticmethod
        def save(**_kwargs):
            return True

        @staticmethod
        def get_by_id(_connector_id):
            return True, _FakeConnectorRecord({"id": _connector_id})

        @staticmethod
        def list(_tenant_id):
            return []

        @staticmethod
        def resume(*_args, **_kwargs):
            return True

        @staticmethod
        def rebuild(*_args, **_kwargs):
            return None

        @staticmethod
        def delete_by_id(*_args, **_kwargs):
            return True

    class _StubSyncLogsService:
        @staticmethod
        def list_sync_tasks(*_args, **_kwargs):
            return [], 0

    connector_service_mod.ConnectorService = _StubConnectorService
    connector_service_mod.SyncLogsService = _StubSyncLogsService
    monkeypatch.setitem(sys.modules, "api.db.services.connector_service", connector_service_mod)

    # --- api.utils.api_utils: result helpers return plain dicts ---
    api_utils_mod = ModuleType("api.utils.api_utils")

    async def _get_request_json():
        return {}

    api_utils_mod.get_request_json = _get_request_json
    api_utils_mod.get_json_result = lambda data=None, message="", code=0: {
        "code": code,
        "message": message,
        "data": data,
    }
    api_utils_mod.get_data_error_result = lambda message="", code=400, data=None: {
        "code": code,
        "message": message,
        "data": data,
    }
    # validate_request(...) becomes a decorator factory that returns fn unchanged.
    api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn)
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)

    # --- common.*: constants and OAuth configuration ---
    constants_mod = ModuleType("common.constants")
    constants_mod.RetCode = SimpleNamespace(
        ARGUMENT_ERROR=101,
        SERVER_ERROR=500,
        RUNNING=102,
        PERMISSION_ERROR=403,
    )
    constants_mod.TaskStatus = SimpleNamespace(SCHEDULE="schedule", CANCEL="cancel")
    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)

    config_mod = ModuleType("common.data_source.config")
    config_mod.GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = "https://example.com/drive"
    config_mod.GMAIL_WEB_OAUTH_REDIRECT_URI = "https://example.com/gmail"
    config_mod.BOX_WEB_OAUTH_REDIRECT_URI = "https://example.com/box"
    config_mod.DocumentSource = SimpleNamespace(GMAIL="gmail", GOOGLE_DRIVE="google-drive")
    monkeypatch.setitem(sys.modules, "common.data_source.config", config_mod)

    google_constants_mod = ModuleType("common.data_source.google_util.constant")
    # Reduced popup template with the placeholders title, heading, message,
    # payload_json and auto_close.
    google_constants_mod.WEB_OAUTH_POPUP_TEMPLATE = (
        "<html><head><title>{title}</title></head>"
        "<body><h1>{heading}</h1><p>{message}</p><script>{payload_json}</script><script>{auto_close}</script></body></html>"
    )
    google_constants_mod.GOOGLE_SCOPES = {
        config_mod.DocumentSource.GMAIL: ["scope-gmail"],
        config_mod.DocumentSource.GOOGLE_DRIVE: ["scope-drive"],
    }
    monkeypatch.setitem(sys.modules, "common.data_source.google_util.constant", google_constants_mod)

    misc_mod = ModuleType("common.misc_utils")
    misc_mod.get_uuid = lambda: "uuid-from-helper"
    monkeypatch.setitem(sys.modules, "common.misc_utils", misc_mod)

    # --- rag.utils.redis_conn: in-memory Redis ---
    rag_pkg = ModuleType("rag")
    rag_pkg.__path__ = [str(repo_root / "rag")]
    monkeypatch.setitem(sys.modules, "rag", rag_pkg)

    rag_utils_pkg = ModuleType("rag.utils")
    rag_utils_pkg.__path__ = [str(repo_root / "rag" / "utils")]
    monkeypatch.setitem(sys.modules, "rag.utils", rag_utils_pkg)

    redis_mod = ModuleType("rag.utils.redis_conn")
    redis_mod.REDIS_CONN = _FakeRedis()
    monkeypatch.setitem(sys.modules, "rag.utils.redis_conn", redis_mod)

    # --- quart: empty request and a make_response returning _FakeResponse ---
    quart_mod = ModuleType("quart")
    quart_mod.request = SimpleNamespace(args=_Args(), json=_AwaitableValue({}))

    async def _make_response(body, status_code):
        return _FakeResponse(body, status_code)

    quart_mod.make_response = _make_response
    monkeypatch.setitem(sys.modules, "quart", quart_mod)

    # --- google_auth_oauthlib.flow: Flow.from_client_config yields _FakeFlow ---
    google_pkg = ModuleType("google_auth_oauthlib")
    google_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "google_auth_oauthlib", google_pkg)

    google_flow_mod = ModuleType("google_auth_oauthlib.flow")

    class _StubFlow:
        @classmethod
        def from_client_config(cls, client_config, scopes):
            return _FakeFlow(client_config, scopes)

    google_flow_mod.Flow = _StubFlow
    monkeypatch.setitem(sys.modules, "google_auth_oauthlib.flow", google_flow_mod)

    # --- box_sdk_gen: plain value holders plus the fake OAuth client ---
    box_mod = ModuleType("box_sdk_gen")

    class _OAuthConfig:
        def __init__(self, client_id, client_secret):
            self.client_id = client_id
            self.client_secret = client_secret

    class _GetAuthorizeUrlOptions:
        def __init__(self, redirect_uri, state):
            self.redirect_uri = redirect_uri
            self.state = state

    box_mod.BoxOAuth = _FakeBoxOAuth
    box_mod.OAuthConfig = _OAuthConfig
    box_mod.GetAuthorizeUrlOptions = _GetAuthorizeUrlOptions
    monkeypatch.setitem(sys.modules, "box_sdk_gen", box_mod)

    # Execute the real route module against the stubs above.
    module_path = repo_root / "api" / "apps" / "restful_apis" / "connector_api.py"
    spec = importlib.util.spec_from_file_location("test_connector_routes_unit", module_path)
    module = importlib.util.module_from_spec(spec)
    # `manager` is set before execution so it exists while the module body runs
    # (presumably for route decorators -- confirm against connector_api.py).
    module.manager = _DummyManager()
    spec.loader.exec_module(module)
    return module
@pytest.mark.p2
def test_connector_basic_routes_and_task_controls(monkeypatch):
    """Exercise the CRUD, log, resume, rebuild and remove connector routes.

    Every service call is replaced by a recording lambda so the test can
    assert on what each route handler passed through.
    """
    module = _load_connector_app(monkeypatch)

    # Replace asyncio.sleep with an immediate no-op coroutine.
    async def _no_sleep(_secs):
        return None

    monkeypatch.setattr(module.asyncio, "sleep", _no_sleep)

    records = {"conn-1": _FakeConnectorRecord({"id": "conn-1", "source": "drive"})}
    update_calls = []
    save_calls = []
    resume_calls = []
    delete_calls = []

    monkeypatch.setattr(module.ConnectorService, "update_by_id", lambda cid, payload: update_calls.append((cid, payload)))

    def _save(**payload):
        # Record the call and make the saved connector retrievable by id.
        save_calls.append(payload)
        records[payload["id"]] = _FakeConnectorRecord(payload)

    monkeypatch.setattr(module.ConnectorService, "save", _save)
    monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda cid: (True, records[cid]))
    monkeypatch.setattr(module.ConnectorService, "list", lambda tenant_id: [{"id": "listed", "tenant": tenant_id}])
    monkeypatch.setattr(module.SyncLogsService, "list_sync_tasks", lambda cid, page, page_size: ([{"id": "log-1"}], 9))
    monkeypatch.setattr(module.ConnectorService, "resume", lambda cid, status: resume_calls.append((cid, status)))
    monkeypatch.setattr(module.ConnectorService, "delete_by_id", lambda cid: delete_calls.append(cid))
    monkeypatch.setattr(module, "get_uuid", lambda: "generated-id")

    # Update: the request JSON is forwarded to update_by_id unchanged.
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"id": "conn-1", "refresh_freq": 7, "config": {"x": 1}}),
    )
    res = _run(module.update_connector("conn-1"))
    assert update_calls == [("conn-1", {'id': 'conn-1', "refresh_freq": 7, "config": {"x": 1}})]
    assert res["data"]["id"] == "conn-1"

    # Create: the id comes from get_uuid, the tenant from current_user.
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"name": "new", "source": "gmail", "config": {"y": 2}}),
    )
    res = _run(module.create_connector())
    assert save_calls[-1]["id"] == "generated-id"
    assert save_calls[-1]["tenant_id"] == "tenant-1"
    assert save_calls[-1]["input_type"] == module.InputType.POLL
    assert res["data"]["id"] == "generated-id"

    # List and get (missing, then found).
    list_res = module.list_connector()
    assert list_res["data"] == [{"id": "listed", "tenant": "tenant-1"}]

    monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda _cid: (False, None))
    missing_res = module.get_connector("missing")
    assert missing_res["message"] == "Can't find this Connector!"

    monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda cid: (True, _FakeConnectorRecord({"id": cid})))
    found_res = module.get_connector("conn-2")
    assert found_res["data"]["id"] == "conn-2"

    # Logs: paging values arrive as query-string text.
    _set_request(module, args={"page": "2", "page_size": "7"})
    logs_res = module.list_logs("conn-log")
    assert logs_res["data"] == {"total": 9, "logs": [{"id": "log-1"}]}

    # Resume: True is expected to schedule, False to cancel.
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": True}))
    assert _run(module.resume("conn-r1"))["data"] is True
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": False}))
    assert _run(module.resume("conn-r2"))["data"] is True
    assert ("conn-r1", module.TaskStatus.SCHEDULE) in resume_calls
    assert ("conn-r2", module.TaskStatus.CANCEL) in resume_calls

    # Rebuild: a non-None service result is reported as a server error.
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"kb_id": "kb-1"}))
    monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: "rebuild-failed")
    failed_rebuild = _run(module.rebuild("conn-rb"))
    assert failed_rebuild["code"] == module.RetCode.SERVER_ERROR
    assert failed_rebuild["data"] is False

    monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: None)
    ok_rebuild = _run(module.rebuild("conn-rb"))
    assert ok_rebuild["data"] is True

    # Remove: a CANCEL resume call and a delete are both expected.
    rm_res = module.rm_connector("conn-rm")
    assert rm_res["data"] is True
    assert ("conn-rm", module.TaskStatus.CANCEL) in resume_calls
    assert delete_calls == ["conn-rm"]
@pytest.mark.p2
def test_connector_oauth_helper_functions(monkeypatch):
    """Check the cache-key, credential-loading and popup-rendering helpers."""
    module = _load_connector_app(monkeypatch)

    # Cache keys are prefixed with the source name.
    assert module._web_state_cache_key("flow-a", "gmail") == "gmail_web_flow_state:flow-a"
    assert module._web_result_cache_key("flow-b", "google-drive") == "google-drive_web_flow_result:flow-b"

    # Credentials may be passed as a dict or as JSON text.
    creds_dict = {"web": {"client_id": "id"}}
    assert module._load_credentials(creds_dict) == creds_dict
    assert module._load_credentials(json.dumps(creds_dict)) == creds_dict
    with pytest.raises(ValueError, match="Invalid Google credentials JSON"):
        module._load_credentials("{not-json")

    # Only a "web" client section is accepted.
    assert module._get_web_client_config(creds_dict) == {"web": {"client_id": "id"}}
    with pytest.raises(ValueError, match="must include a 'web'"):
        module._get_web_client_config({"installed": {"client_id": "id"}})

    popup_ok = _run(module._render_web_oauth_popup("flow-1", True, "done", "gmail"))
    assert popup_ok.status_code == 200
    assert popup_ok.headers["Content-Type"] == "text/html; charset=utf-8"
    assert "Authorization complete" in popup_ok.body
    assert "ragflow-gmail-oauth" in popup_ok.body

    # The failure popup still answers 200 and the message appears HTML-escaped.
    popup_error = _run(module._render_web_oauth_popup("flow-2", False, "<denied>", "google-drive"))
    assert popup_error.status_code == 200
    assert "Authorization failed" in popup_error.body
    assert "&lt;denied&gt;" in popup_error.body
@pytest.mark.p2
def test_start_google_web_oauth_matrix(monkeypatch):
    """Cover the rejection paths and both success paths of ``start_google_web_oauth``."""
    module = _load_connector_app(monkeypatch)
    redis = _FakeRedis()
    monkeypatch.setattr(module, "REDIS_CONN", redis)
    monkeypatch.setattr(module.time, "time", lambda: 1700000000)

    # Collect every flow the handler creates so its scopes can be checked.
    flow_calls = []

    def _from_client_config(client_config, scopes):
        flow = _FakeFlow(client_config, scopes)
        flow_calls.append(flow)
        return flow

    monkeypatch.setattr(module.Flow, "from_client_config", staticmethod(_from_client_config))

    # Unknown source type.
    _set_request(module, args={"type": "invalid"})
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": "{}"}))
    invalid_type = _run(module.start_google_web_oauth())
    assert invalid_type["code"] == module.RetCode.ARGUMENT_ERROR

    # Empty redirect URI for the requested source.
    monkeypatch.setattr(module, "GMAIL_WEB_OAUTH_REDIRECT_URI", "")
    _set_request(module, args={"type": "gmail"})
    missing_redirect = _run(module.start_google_web_oauth())
    assert missing_redirect["code"] == module.RetCode.SERVER_ERROR

    # Restore redirect URIs, then feed malformed credentials.
    monkeypatch.setattr(module, "GMAIL_WEB_OAUTH_REDIRECT_URI", "https://example.com/gmail")
    monkeypatch.setattr(module, "GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "https://example.com/drive")
    _set_request(module, args={"type": "google-drive"})
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": "{invalid-json"}))
    invalid_credentials = _run(module.start_google_web_oauth())
    assert invalid_credentials["code"] == module.RetCode.ARGUMENT_ERROR

    # Credentials that already carry a refresh token are rejected.
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"credentials": json.dumps({"web": {"client_id": "id"}, "refresh_token": "rt"})}),
    )
    has_refresh_token = _run(module.start_google_web_oauth())
    assert has_refresh_token["code"] == module.RetCode.ARGUMENT_ERROR

    # Credentials without a "web" section are rejected.
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": json.dumps({"installed": {"x": 1}})}))
    missing_web = _run(module.start_google_web_oauth())
    assert missing_web["code"] == module.RetCode.ARGUMENT_ERROR

    # Success paths: uuid4 is replaced so flow ids are predictable, in call order.
    ids = iter(["flow-gmail", "flow-drive"])
    monkeypatch.setattr(module.uuid, "uuid4", lambda: next(ids))
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"credentials": json.dumps({"web": {"client_id": "id", "client_secret": "secret"}})}),
    )

    _set_request(module, args={"type": "gmail"})
    gmail_ok = _run(module.start_google_web_oauth())
    assert gmail_ok["code"] == 0
    assert gmail_ok["data"]["flow_id"] == "flow-gmail"
    assert gmail_ok["data"]["authorization_url"].endswith("flow-gmail")

    # No "type" argument: the assertions below expect the google-drive flow.
    _set_request(module, args={})
    drive_ok = _run(module.start_google_web_oauth())
    assert drive_ok["code"] == 0
    assert drive_ok["data"]["flow_id"] == "flow-drive"
    assert drive_ok["data"]["authorization_url"].endswith("flow-drive")

    assert any(call.scopes == module.GOOGLE_SCOPES[module.DocumentSource.GMAIL] for call in flow_calls)
    assert any(call.scopes == module.GOOGLE_SCOPES[module.DocumentSource.GOOGLE_DRIVE] for call in flow_calls)

    # The flow state, including the code verifier, is written to Redis per source.
    assert "gmail_web_flow_state:flow-gmail" in redis.store
    assert "google-drive_web_flow_state:flow-drive" in redis.store
    assert json.loads(redis.store["gmail_web_flow_state:flow-gmail"])["code_verifier"] == "fake-code-verifier"
    assert json.loads(redis.store["google-drive_web_flow_state:flow-drive"])["code_verifier"] == "fake-code-verifier"
@pytest.mark.p2
def test_google_web_oauth_callbacks_matrix(monkeypatch):
    """Run the same failure/success sequence against the Gmail and Drive callbacks."""
    module = _load_connector_app(monkeypatch)

    # Collect created flows; the last one is inspected after the success case.
    flow_calls = []

    def _from_client_config(client_config, scopes):
        flow = _FakeFlow(client_config, scopes)
        flow_calls.append(flow)
        return flow

    monkeypatch.setattr(module.Flow, "from_client_config", staticmethod(_from_client_config))

    # (callback, source name, expected redirect URI, expected scopes)
    callback_specs = [
        (
            module.google_gmail_web_oauth_callback,
            "gmail",
            module.GMAIL_WEB_OAUTH_REDIRECT_URI,
            module.GOOGLE_SCOPES[module.DocumentSource.GMAIL],
        ),
        (
            module.google_drive_web_oauth_callback,
            "google-drive",
            module.GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI,
            module.GOOGLE_SCOPES[module.DocumentSource.GOOGLE_DRIVE],
        ),
    ]

    for callback, source, expected_redirect, expected_scopes in callback_specs:
        # Fresh Redis per source so state from one callback cannot leak.
        redis = _FakeRedis()
        monkeypatch.setattr(module, "REDIS_CONN", redis)

        # No state parameter at all.
        _set_request(module, args={})
        missing_state = _run(callback())
        assert "Missing OAuth state parameter." in missing_state.body

        # State given but nothing cached for it.
        _set_request(module, args={"state": "sid"})
        expired_state = _run(callback())
        assert "Authorization session expired" in expired_state.body

        # Cached state without client_config: reported invalid and deleted.
        redis.store[module._web_state_cache_key("sid", source)] = json.dumps({"user_id": "tenant-1"})
        _set_request(module, args={"state": "sid"})
        invalid_state = _run(callback())
        assert "Authorization session was invalid" in invalid_state.body
        assert module._web_state_cache_key("sid", source) in redis.deleted

        # Provider reported an error; its description is shown.
        redis.store[module._web_state_cache_key("sid", source)] = json.dumps({
            "user_id": "tenant-1",
            "client_config": {"web": {"client_id": "cid"}},
        })
        _set_request(module, args={"state": "sid", "error": "denied", "error_description": "permission denied"})
        oauth_error = _run(callback())
        assert "permission denied" in oauth_error.body

        # Valid state but no authorization code.
        redis.store[module._web_state_cache_key("sid", source)] = json.dumps({
            "user_id": "tenant-1",
            "client_config": {"web": {"client_id": "cid"}},
        })
        _set_request(module, args={"state": "sid"})
        missing_code = _run(callback())
        assert "Missing authorization code" in missing_code.body

        # Success: result cached, state deleted, flow driven with cached verifier.
        redis.store[module._web_state_cache_key("sid", source)] = json.dumps({
            "user_id": "tenant-1",
            "client_config": {"web": {"client_id": "cid"}},
            "code_verifier": "state-code-verifier",
        })
        _set_request(module, args={"state": "sid", "code": "code-123"})
        success = _run(callback())
        assert "Authorization completed successfully." in success.body
        result_key = module._web_result_cache_key("sid", source)
        assert result_key in redis.store
        assert module._web_state_cache_key("sid", source) in redis.deleted
        assert flow_calls[-1].redirect_uri == expected_redirect
        assert flow_calls[-1].scopes == expected_scopes
        assert flow_calls[-1].token_code == "code-123"
        assert flow_calls[-1].token_code_verifier == "state-code-verifier"
@pytest.mark.p2
def test_poll_google_web_result_matrix(monkeypatch):
    """Cover invalid type, pending, foreign-user and success polling outcomes."""
    module = _load_connector_app(monkeypatch)
    redis = _FakeRedis()
    monkeypatch.setattr(module, "REDIS_CONN", redis)

    # Unknown source type.
    _set_request(module, args={"type": "invalid"}, json_body={"flow_id": "flow-1"})
    invalid_type = _run(module.poll_google_web_result())
    assert invalid_type["code"] == module.RetCode.ARGUMENT_ERROR

    # Nothing cached yet: the flow is reported as still running.
    _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"})
    pending = _run(module.poll_google_web_result())
    assert pending["code"] == module.RetCode.RUNNING

    # Result belongs to a user other than the stubbed current_user ("tenant-1").
    redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps(
        {"user_id": "another-user", "credentials": "token-x"}
    )
    _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"})
    permission_error = _run(module.poll_google_web_result())
    assert permission_error["code"] == module.RetCode.PERMISSION_ERROR

    # Matching user: credentials are returned and the cached result is deleted.
    redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps(
        {"user_id": "tenant-1", "credentials": "token-ok"}
    )
    _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"})
    success = _run(module.poll_google_web_result())
    assert success["code"] == 0
    assert success["data"] == {"credentials": "token-ok"}
    assert module._web_result_cache_key("flow-1", "gmail") in redis.deleted
@pytest.mark.p2
def test_box_oauth_start_callback_and_poll_matrix(monkeypatch):
    """Walk the Box OAuth flow through start, callback and poll, including failures."""
    module = _load_connector_app(monkeypatch)
    redis = _FakeRedis()
    monkeypatch.setattr(module, "REDIS_CONN", redis)

    # Track every BoxOAuth instance the handlers create.
    created_auth = []

    class _TrackingBoxOAuth(_FakeBoxOAuth):
        def __init__(self, config):
            super().__init__(config)
            created_auth.append(self)

    monkeypatch.setattr(module, "BoxOAuth", _TrackingBoxOAuth)
    monkeypatch.setattr(module.uuid, "uuid4", lambda: "flow-box")
    monkeypatch.setattr(module.time, "time", lambda: 1800000000)

    # --- start ---
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({}))
    missing_params = _run(module.start_box_web_oauth())
    assert missing_params["code"] == module.RetCode.ARGUMENT_ERROR

    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"client_id": "cid", "client_secret": "sec", "redirect_uri": "https://box.local/callback"}),
    )
    start_ok = _run(module.start_box_web_oauth())
    assert start_ok["code"] == 0
    assert start_ok["data"]["flow_id"] == "flow-box"
    assert "authorization_url" in start_ok["data"]
    assert module._web_state_cache_key("flow-box", "box") in redis.store

    # --- callback ---
    _set_request(module, args={})
    missing_state = _run(module.box_web_oauth_callback())
    assert "Missing OAuth parameters." in missing_state.body

    _set_request(module, args={"state": "flow-box"})
    missing_code = _run(module.box_web_oauth_callback())
    assert "Missing authorization code from Box." in missing_code.body

    # A cached JSON "null" yields a result dict rather than a popup response.
    redis.store[module._web_state_cache_key("flow-null", "box")] = "null"
    _set_request(module, args={"state": "flow-null", "code": "abc"})
    invalid_session = _run(module.box_web_oauth_callback())
    assert invalid_session["code"] == module.RetCode.ARGUMENT_ERROR

    # Provider error: the description is shown even though a code is present.
    redis.store[module._web_state_cache_key("flow-box", "box")] = json.dumps(
        {"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"}
    )
    _set_request(module, args={"state": "flow-box", "code": "abc", "error": "access_denied", "error_description": "denied"})
    callback_error = _run(module.box_web_oauth_callback())
    assert "denied" in callback_error.body

    # Success: code exchanged, result cached, state deleted.
    redis.store[module._web_state_cache_key("flow-ok", "box")] = json.dumps(
        {"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"}
    )
    _set_request(module, args={"state": "flow-ok", "code": "code-ok"})
    callback_success = _run(module.box_web_oauth_callback())
    assert "Authorization completed successfully." in callback_success.body
    assert created_auth[-1].exchange_code == "code-ok"
    assert module._web_result_cache_key("flow-ok", "box") in redis.store
    assert module._web_state_cache_key("flow-ok", "box") in redis.deleted

    # --- poll ---
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"flow_id": "flow-ok"}))
    # Drop the result written by the callback to simulate a pending flow.
    redis.store.pop(module._web_result_cache_key("flow-ok", "box"), None)
    pending = _run(module.poll_box_web_result())
    assert pending["code"] == module.RetCode.RUNNING

    redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps({"user_id": "another-user"})
    permission_error = _run(module.poll_box_web_result())
    assert permission_error["code"] == module.RetCode.PERMISSION_ERROR

    redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps(
        {"user_id": "tenant-1", "access_token": "at", "refresh_token": "rt"}
    )
    poll_success = _run(module.poll_box_web_result())
    assert poll_success["code"] == 0
    assert poll_success["data"]["credentials"]["access_token"] == "at"
    assert module._web_result_cache_key("flow-ok", "box") in redis.deleted

View File

@@ -0,0 +1,335 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from configs import DATASET_NAME_LIMIT
@pytest.mark.p1
class TestDatasetsAuthorization:
    """Authorization checks for the ``/datasets`` collection."""

    def test_create_requires_auth(self, rest_client_noauth):
        """A create request sent without a token is answered with 401."""
        response = rest_client_noauth.post("/datasets", json={"name": "auth_test"})
        assert response.status_code == 401
        body = response.json()
        assert body["code"] == 401, body
@pytest.mark.p1
def test_dataset_crud_cycle(rest_client, clear_datasets):
    """Take one dataset through create, get, update, list and delete."""
    # Create.
    created = rest_client.post("/datasets", json={"name": "restful_dataset_crud"})
    assert created.status_code == 200
    created_body = created.json()
    assert created_body["code"] == 0, created_body
    dataset_id = created_body["data"]["id"]

    # Read back by id.
    fetched = rest_client.get(f"/datasets/{dataset_id}")
    assert fetched.status_code == 200
    fetched_body = fetched.json()
    assert fetched_body["code"] == 0, fetched_body
    assert fetched_body["data"]["id"] == dataset_id, fetched_body

    # Rename.
    renamed = rest_client.put(
        f"/datasets/{dataset_id}",
        json={"name": "restful_dataset_crud_updated"},
    )
    assert renamed.status_code == 200
    renamed_body = renamed.json()
    assert renamed_body["code"] == 0, renamed_body
    assert renamed_body["data"]["name"] == "restful_dataset_crud_updated", renamed_body

    # List filtered by id: exactly this dataset is returned.
    listed = rest_client.get("/datasets", params={"id": dataset_id})
    assert listed.status_code == 200
    listed_body = listed.json()
    assert listed_body["code"] == 0, listed_body
    assert len(listed_body["data"]) == 1, listed_body
    assert listed_body["data"][0]["id"] == dataset_id, listed_body
    assert listed_body.get("total_datasets", 0) >= 1, listed_body

    # Delete.
    removed = rest_client.delete("/datasets", json={"ids": [dataset_id]})
    assert removed.status_code == 200
    removed_body = removed.json()
    assert removed_body["code"] == 0, removed_body

    # The dataset no longer shows up in an unfiltered listing.
    remaining = rest_client.get("/datasets")
    assert remaining.status_code == 200
    remaining_body = remaining.json()
    assert remaining_body["code"] == 0, remaining_body
    assert all(entry["id"] != dataset_id for entry in remaining_body["data"]), remaining_body
@pytest.mark.p2
@pytest.mark.parametrize(
    "name, expected_fragment",
    [
        ("", "String should have at least 1 character"),
        (" ", "String should have at least 1 character"),
        ("a" * (DATASET_NAME_LIMIT + 1), f"String should have at most {DATASET_NAME_LIMIT} characters"),
    ],
    ids=["empty", "spaces", "too_long"],
)
def test_dataset_create_name_validation(rest_client, clear_datasets, name, expected_fragment):
    """Invalid dataset names are rejected with code 101 and a matching message."""
    response = rest_client.post("/datasets", json={"name": name})
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 101, body
    assert expected_fragment in body["message"], body
@pytest.mark.p2
def test_dataset_list_ordering_and_pagination(rest_client, clear_datasets):
    """A page size of 2 over three datasets returns two rows and a total of at least 3."""
    for index in range(3):
        created = rest_client.post("/datasets", json={"name": f"dataset_page_{index}"})
        assert created.status_code == 200
        created_body = created.json()
        assert created_body["code"] == 0, created_body

    query = {"page": 1, "page_size": 2, "orderby": "create_time", "desc": "true"}
    listing = rest_client.get("/datasets", params=query)
    assert listing.status_code == 200
    listing_body = listing.json()
    assert listing_body["code"] == 0, listing_body
    assert len(listing_body["data"]) == 2, listing_body
    assert listing_body.get("total_datasets", 0) >= 3, listing_body
@pytest.mark.p2
def test_dataset_search_endpoint(rest_client, ensure_parsed_document):
    """Searching a dataset that holds a parsed document returns a ``chunks`` field."""
    dataset_id, _ = ensure_parsed_document()
    search_body = {"question": "test TXT file", "page": 1, "size": 10}
    response = rest_client.post(f"/datasets/{dataset_id}/search", json=search_body)
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 0, body
    assert "chunks" in body["data"], body
@pytest.mark.p2
def test_dataset_search_requires_question(rest_client, create_dataset):
    """A search without ``question`` is rejected with code 101 naming the field."""
    dataset_id = create_dataset("dataset_search_missing_question")
    response = rest_client.post(f"/datasets/{dataset_id}/search", json={})
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 101, body
    assert "question" in body["message"], body
@pytest.mark.p2
def test_dataset_tags_and_aggregation(rest_client, create_dataset):
    """Tag listing and aggregation respond for empty datasets; aggregation needs ids."""
    dataset_id = create_dataset("dataset_tags")
    second_dataset_id = create_dataset("dataset_tags_second")

    tags = rest_client.get(f"/datasets/{dataset_id}/tags")
    assert tags.status_code == 200
    tags_body = tags.json()
    # Known env/runtime behavior: this route can return 102 when retriever tag
    # backend is unavailable for an empty dataset. Keep route-contract coverage.
    assert tags_body["code"] in (0, 102), tags_body

    aggregated = rest_client.get(
        "/datasets/tags/aggregation",
        params={"dataset_ids": f"{dataset_id},{second_dataset_id}"},
    )
    assert aggregated.status_code == 200
    aggregated_body = aggregated.json()
    assert aggregated_body["code"] in (0, 102), aggregated_body

    # Without dataset ids the aggregation route must report an error code.
    no_ids = rest_client.get("/datasets/tags/aggregation")
    assert no_ids.status_code == 200
    no_ids_body = no_ids.json()
    assert no_ids_body["code"] != 0, no_ids_body
@pytest.mark.p2
def test_dataset_tags_delete_and_rename_validation(rest_client, create_dataset):
    """Malformed tag delete and rename requests all come back with a non-zero code."""

    def _expect_rejected(response):
        # HTTP 200 with an application-level error code.
        assert response.status_code == 200
        body = response.json()
        assert body["code"] != 0, body

    dataset_id = create_dataset("dataset_tag_mutation")
    tags_url = f"/datasets/{dataset_id}/tags"

    # Delete without a "tags" field.
    _expect_rejected(rest_client.delete(tags_url, json={}))

    # Delete with "tags" of the wrong type.
    _expect_rejected(rest_client.delete(tags_url, json={"tags": "wrong"}))

    # Rename with empty tag names.
    _expect_rejected(
        rest_client.put(
            tags_url,
            json={"from_tag": "", "to_tag": ""},
        )
    )

    # Rename against a dataset id that does not exist.
    _expect_rejected(
        rest_client.put(
            "/datasets/invalid_id/tags",
            json={"from_tag": "old", "to_tag": "new"},
        )
    )
@pytest.mark.p2
def test_dataset_flattened_metadata(rest_client, create_dataset):
    """Flattened metadata succeeds for real ids and errors for missing or bad ids."""
    endpoint = "/datasets/metadata/flattened"
    first_dataset_id = create_dataset("flattened_meta_1")
    second_dataset_id = create_dataset("flattened_meta_2")

    ok = rest_client.get(
        endpoint,
        params={"dataset_ids": f"{first_dataset_id},{second_dataset_id}"},
    )
    assert ok.status_code == 200
    ok_body = ok.json()
    assert ok_body["code"] == 0, ok_body

    # No dataset_ids parameter at all.
    no_ids = rest_client.get(endpoint)
    assert no_ids.status_code == 200
    no_ids_body = no_ids.json()
    assert no_ids_body["code"] != 0, no_ids_body

    # An id that does not refer to a dataset.
    bad_id = rest_client.get(
        endpoint,
        params={"dataset_ids": "invalid_id"},
    )
    assert bad_id.status_code == 200
    bad_id_body = bad_id.json()
    assert bad_id_body["code"] != 0, bad_id_body
@pytest.mark.p2
def test_dataset_ingestion_summary_and_logs(rest_client, create_dataset):
    """Ingestion summary and log listing expose their fields; unknown logs error."""
    dataset_id = create_dataset("dataset_ingestions")
    ingestions_url = f"/datasets/{dataset_id}/ingestions"

    summary = rest_client.get(f"/datasets/{dataset_id}/ingestions/summary")
    assert summary.status_code == 200
    summary_body = summary.json()
    assert summary_body["code"] == 0, summary_body
    for field in ("doc_num", "chunk_num", "token_num", "status"):
        assert field in summary_body["data"], summary_body

    logs = rest_client.get(
        ingestions_url,
        params={"page": 1, "page_size": 10},
    )
    assert logs.status_code == 200
    logs_body = logs.json()
    assert logs_body["code"] == 0, logs_body
    for field in ("total", "logs"):
        assert field in logs_body["data"], logs_body

    # A log id that was never created must produce an error code.
    missing_log = rest_client.get(f"/datasets/{dataset_id}/ingestions/nonexistent_log")
    assert missing_log.status_code == 200
    missing_log_body = missing_log.json()
    assert missing_log_body["code"] != 0, missing_log_body
@pytest.mark.p2
def test_dataset_ingestion_invalid_dataset(rest_client):
    """Every ingestion route reports an error code for an unknown dataset id."""
    # Same order as the routes are documented: summary, log list, single log.
    invalid_paths = (
        "/datasets/invalid_id/ingestions/summary",
        "/datasets/invalid_id/ingestions",
        "/datasets/invalid_id/ingestions/some_log_id",
    )
    for path in invalid_paths:
        response = rest_client.get(path)
        assert response.status_code == 200
        body = response.json()
        assert body["code"] != 0, body
@pytest.mark.p2
def test_dataset_index_endpoints(rest_client, create_dataset):
    """Index run, trace and delete routes on a dataset that has no documents."""
    dataset_id = create_dataset("dataset_index_endpoints")
    index_url = f"/datasets/{dataset_id}/index"

    # Unknown index type.
    bad_type = rest_client.post(
        index_url,
        params={"type": "invalid_type"},
    )
    assert bad_type.status_code == 200
    bad_type_body = bad_type.json()
    assert bad_type_body["code"] != 0, bad_type_body

    # Valid type, but the dataset is empty.
    empty_run = rest_client.post(
        index_url,
        params={"type": "graph"},
    )
    assert empty_run.status_code == 200
    empty_run_body = empty_run.json()
    assert empty_run_body["code"] == 102, empty_run_body

    # Tracing when no task exists yields an empty payload.
    trace = rest_client.get(
        index_url,
        params={"type": "graph"},
    )
    assert trace.status_code == 200
    trace_body = trace.json()
    assert trace_body["code"] == 0, trace_body
    assert trace_body["data"] == {}, trace_body

    # Deleting the graph index succeeds.
    graph_removed = rest_client.delete(f"/datasets/{dataset_id}/graph")
    assert graph_removed.status_code == 200
    graph_removed_body = graph_removed.json()
    assert graph_removed_body["code"] == 0, graph_removed_body

    # Deleting an unknown index type is rejected.
    bad_delete = rest_client.delete(f"/datasets/{dataset_id}/invalid_type")
    assert bad_delete.status_code == 200
    bad_delete_body = bad_delete.json()
    assert bad_delete_body["code"] != 0, bad_delete_body
@pytest.mark.p2
@pytest.mark.parametrize("index_type", ["graph", "raptor", "mindmap"])
def test_dataset_index_run_with_document_creates_task(rest_client, create_document, index_type):
    """Starting an index on a dataset that holds a document returns a task id."""
    dataset_id, _document_id = create_document("dataset_index_graph_source.txt")
    index_url = f"/datasets/{dataset_id}/index"
    run_res = rest_client.post(index_url, params={"type": index_type})
    assert run_res.status_code == 200
    run_body = run_res.json()
    assert run_body["code"] == 0, run_body
    # A truthy task_id is the only evidence of task creation checked here.
    assert run_body["data"].get("task_id"), run_body
@pytest.mark.p2
def test_dataset_embedding_endpoints(rest_client, create_dataset):
    """Cover the error paths of the embedding run and embedding check routes."""
    dataset_id = create_dataset("dataset_embedding_endpoints")
    embedding_url = f"/datasets/{dataset_id}/embedding"

    # Running embedding right after dataset creation is expected to give code 102.
    empty_run_res = rest_client.post(embedding_url)
    assert empty_run_res.status_code == 200
    empty_run_body = empty_run_res.json()
    assert empty_run_body["code"] == 102, empty_run_body

    # The check route rejects an empty JSON body.
    empty_check_res = rest_client.post(f"{embedding_url}/check", json={})
    assert empty_check_res.status_code == 200
    empty_check_body = empty_check_res.json()
    assert empty_check_body["code"] != 0, empty_check_body

    # An unknown dataset id is rejected as well.
    unknown_dataset_res = rest_client.post("/datasets/invalid_id/embedding")
    assert unknown_dataset_res.status_code == 200
    unknown_dataset_body = unknown_dataset_res.json()
    assert unknown_dataset_body["code"] != 0, unknown_dataset_body

View File

@@ -0,0 +1,43 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
@pytest.mark.p2
def test_document_image_invalid_id_contract(rest_client_noauth):
    """An unknown image id yields code 102 and "Image not found." for the unauthenticated client."""
    response = rest_client_noauth.get("/documents/images/not-a-valid-image-id")
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 102, body
    assert body["message"] == "Image not found.", body
@pytest.mark.p2
def test_document_artifact_requires_auth(rest_client_noauth):
    """The artifact route answers 401 (HTTP status and body code) without a token."""
    response = rest_client_noauth.get("/documents/artifact/not-an-artifact.txt")
    assert response.status_code == 401
    body = response.json()
    assert body["code"] == 401, body
@pytest.mark.p2
def test_document_artifact_rejects_unsafe_filename(rest_client):
    """Requesting an ``.exe`` artifact yields code 102 and "Invalid file type."."""
    response = rest_client.get("/documents/artifact/not-an-artifact.exe")
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 102, body
    assert body["message"] == "Invalid file type.", body

View File

@@ -0,0 +1,122 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from utils.file_utils import create_txt_file
@pytest.mark.p1
def test_documents_upload_and_list(rest_client, create_dataset, tmp_path):
    """Upload one text file to a new dataset and find it in the document listing."""
    dataset_id = create_dataset("dataset_upload_list")
    documents_url = f"/datasets/{dataset_id}/documents"
    source_file = create_txt_file(tmp_path / "upload_and_list.txt")

    # Keep the handle open only for the duration of the multipart upload.
    with source_file.open("rb") as handle:
        upload_res = rest_client.post(
            documents_url,
            files=[("file", (source_file.name, handle))],
        )
    assert upload_res.status_code == 200
    upload_body = upload_res.json()
    assert upload_body["code"] == 0, upload_body
    assert upload_body["data"][0]["dataset_id"] == dataset_id, upload_body

    listing_res = rest_client.get(documents_url)
    assert listing_res.status_code == 200
    listing_body = listing_res.json()
    assert listing_body["code"] == 0, listing_body
    assert listing_body["data"]["total"] >= 1, listing_body
    assert any(doc["name"] == source_file.name for doc in listing_body["data"]["docs"]), listing_body
@pytest.mark.p2
def test_documents_upload_missing_file(rest_client, create_dataset):
    """Posting to the upload route without a file part yields code 101 and "No file part!"."""
    dataset_id = create_dataset("dataset_upload_missing")
    documents_url = f"/datasets/{dataset_id}/documents"
    response = rest_client.post(documents_url)
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 101, body
    assert body["message"] == "No file part!", body
@pytest.mark.p2
def test_documents_update_patch_and_delete(rest_client, create_document):
    """Rename a document, delete it, then confirm it is gone from the listing."""
    dataset_id, document_id = create_document("update_target.txt")
    documents_url = f"/datasets/{dataset_id}/documents"

    # Rename through PATCH on the single-document route.
    rename_res = rest_client.patch(
        f"{documents_url}/{document_id}",
        json={"name": "updated_target.txt"},
    )
    assert rename_res.status_code == 200
    rename_body = rename_res.json()
    assert rename_body["code"] == 0, rename_body
    assert rename_body["data"]["name"] == "updated_target.txt", rename_body

    # Delete by id list on the collection route; exactly one deletion is expected.
    removal_res = rest_client.delete(documents_url, json={"ids": [document_id]})
    assert removal_res.status_code == 200
    removal_body = removal_res.json()
    assert removal_body["code"] == 0, removal_body
    assert removal_body["data"]["deleted"] == 1, removal_body

    # The deleted id must no longer appear in the listing.
    listing_res = rest_client.get(documents_url)
    assert listing_res.status_code == 200
    listing_body = listing_res.json()
    assert listing_body["code"] == 0, listing_body
    assert all(doc["id"] != document_id for doc in listing_body["data"]["docs"]), listing_body
@pytest.mark.p2
def test_documents_parse_and_stop(rest_client, create_document):
    """Start parsing a document and immediately ask for the parse to stop."""
    dataset_id, document_id = create_document("parse_target.txt")
    documents_url = f"/datasets/{dataset_id}/documents"
    selection = {"document_ids": [document_id]}

    start_res = rest_client.post(f"{documents_url}/parse", json=selection)
    assert start_res.status_code == 200
    start_body = start_res.json()
    assert start_body["code"] == 0, start_body

    halt_res = rest_client.post(f"{documents_url}/stop", json=selection)
    assert halt_res.status_code == 200
    halt_body = halt_res.json()
    # Timing dependent: either the stop succeeds at once (0) or parsing has
    # already finished (102 with an "already completed" message).
    assert halt_body["code"] in (0, 102), halt_body
    if halt_body["code"] == 102:
        assert "already completed" in halt_body["message"], halt_body
@pytest.mark.p2
def test_documents_metadata_update_path(rest_client, create_document):
    """Set one metadata key on one selected document through the bulk metadata route."""
    dataset_id, document_id = create_document("metadata_target.txt")
    request_body = {
        "selector": {"document_ids": [document_id]},
        "updates": [{"key": "author", "value": "qa"}],
        "deletes": [],
    }
    response = rest_client.patch(
        f"/datasets/{dataset_id}/documents/metadatas",
        json=request_body,
    )
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 0, body
    # Exactly the one selected document matches; at least one update is reported.
    assert body["data"]["matched_docs"] == 1, body
    assert body["data"]["updated"] >= 1, body

View File

@@ -0,0 +1,632 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import sys
from enum import Enum
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
    """Route-manager stand-in whose ``route`` decorator registers nothing."""

    def route(self, *_args, **_kwargs):
        """Accept any routing arguments and return an identity decorator."""
        return lambda func: func
class _AwaitableValue:
    """Wrap a plain value so that ``await wrapper`` evaluates to that value."""

    def __init__(self, value):
        self._value = value

    async def _resolve(self):
        # Coroutine that finishes immediately with the stored value.
        return self._value

    def __await__(self):
        """Delegate the await protocol to a fresh coroutine on every call."""
        return self._resolve().__await__()
class _DummyFiles(dict):
    """Dict-based stand-in for an uploaded-files mapping exposing ``getlist``."""

    def __init__(self, file_objs=None):
        super().__init__()
        self._file_objs = list(file_objs) if file_objs else []
        # The "file" key exists only when a value (even an empty one) was passed.
        if file_objs is not None:
            self["file"] = self._file_objs

    def getlist(self, key):
        """Return a copy of the stored files for ``"file"``, otherwise an empty list."""
        return list(self._file_objs) if key == "file" else []
class _DummyUploadFile:
    """Minimal uploaded-file stand-in: a filename plus fixed content."""

    def __init__(self, filename, blob=b"blob"):
        self._blob = blob
        self.filename = filename

    def read(self):
        """Return the content given at construction time."""
        return self._blob
class _DummyRequest:
    """Request stand-in whose ``form`` and ``files`` are awaitable and ``args`` is a plain mapping."""

    def __init__(self, *, content_type="", form=None, files=None, args=None):
        # An explicitly passed (even empty) files object is kept; only None
        # falls back to a fresh _DummyFiles().
        if files is None:
            files = _DummyFiles()
        self.content_type = content_type
        self.form = _AwaitableValue(form if form else {})
        self.files = _AwaitableValue(files)
        self.args = args if args else {}
class _DummyResponse:
    """Response stand-in holding the payload and a mutable ``headers`` dict."""

    def __init__(self, data):
        self.headers = {}
        self.data = data
def _run(coro):
    """Run *coro* to completion with ``asyncio.run`` and return its result."""
    outcome = asyncio.run(coro)
    return outcome
def _load_file_api_module(monkeypatch):
    """Execute ``api/apps/restful_apis/file_api.py`` against stubbed dependencies.

    Stand-ins are placed in ``sys.modules`` (through ``monkeypatch``) for
    ``quart``, ``api``, ``api.apps``, ``api.apps.services``,
    ``api.apps.services.file_api_service``, ``api.db``,
    ``api.db.services.file2document_service``, ``api.utils.api_utils``,
    ``api.utils.validation_utils``, ``api.utils.web_utils``, ``common`` and
    ``common.misc_utils``; the route module is then executed from its file path.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture used for every ``sys.modules`` entry.

    Returns:
        The executed module, registered as ``api.apps.restful_apis.file_api``.
    """
    project_root = Path(__file__).resolve().parents[3]

    def _install(dotted_name, stub):
        # Single place where a stub becomes importable under its dotted name.
        monkeypatch.setitem(sys.modules, dotted_name, stub)

    # quart: a default request object and an async response factory.
    quart_stub = ModuleType("quart")
    quart_stub.request = _DummyRequest()

    async def _make_response(data):
        return _DummyResponse(data)

    quart_stub.make_response = _make_response
    _install("quart", quart_stub)

    # Package skeletons: api -> api.apps -> api.apps.services.
    api_stub = ModuleType("api")
    api_stub.__path__ = [str(project_root / "api")]
    _install("api", api_stub)

    apps_stub = ModuleType("api.apps")
    apps_stub.__path__ = [str(project_root / "api" / "apps")]
    apps_stub.login_required = lambda func: func
    _install("api.apps", apps_stub)
    api_stub.apps = apps_stub

    app_services_stub = ModuleType("api.apps.services")
    app_services_stub.__path__ = [str(project_root / "api" / "apps" / "services")]
    _install("api.apps.services", app_services_stub)
    apps_stub.services = app_services_stub

    # Service layer: every operation reports success with a canned payload.
    service_stub = ModuleType("api.apps.services.file_api_service")

    async def _upload_file(_tenant_id, _pf_id, _file_objs):
        return True, [{"id": "f1"}]

    async def _create_folder(_tenant_id, _name, _parent_id=None, _file_type=None):
        return True, {"id": "folder1"}

    async def _delete_files(_tenant_id, _ids):
        return True, True

    async def _move_files(_tenant_id, _src_file_ids, _dest_file_id=None, _new_name=None):
        return True, True

    service_stub.upload_file = _upload_file
    service_stub.create_folder = _create_folder
    service_stub.list_files = lambda _tenant_id, _args: (True, {"files": [], "total": 0})
    service_stub.delete_files = _delete_files
    service_stub.move_files = _move_files
    service_stub.get_file_content = lambda _tenant_id, _file_id: (
        True,
        SimpleNamespace(parent_id="bucket1", location="path1", name="doc.txt", type="doc"),
    )
    service_stub.get_parent_folder = lambda _file_id, user_id=None: (True, {"parent_folder": {"id": "parent1"}})
    service_stub.get_all_parent_folders = lambda _file_id, user_id=None: (True, {"parent_folders": [{"id": "root"}]})
    _install("api.apps.services.file_api_service", service_stub)
    app_services_stub.file_api_service = service_stub

    # api.db: only the FileType enum is provided.
    db_stub = ModuleType("api.db")
    db_stub.__path__ = []

    class _FileType(Enum):
        DOC = "doc"
        VISUAL = "visual"

    db_stub.FileType = _FileType
    _install("api.db", db_stub)
    api_stub.db = db_stub

    # Second storage address returned by the file-to-document lookup.
    file2doc_stub = ModuleType("api.db.services.file2document_service")
    file2doc_stub.File2DocumentService = SimpleNamespace(get_storage_address=lambda **_kwargs: ("bucket2", "path2"))
    _install("api.db.services.file2document_service", file2doc_stub)

    # Result builders return plain dicts so tests can inspect them directly.
    api_utils_stub = ModuleType("api.utils.api_utils")
    api_utils_stub.add_tenant_id_to_kwargs = lambda func: func
    api_utils_stub.get_error_argument_result = lambda message: {"code": 400, "data": None, "message": message}
    api_utils_stub.get_error_data_result = lambda message: {"code": 500, "data": None, "message": message}
    api_utils_stub.get_result = lambda data=None: {"code": 0, "data": data, "message": ""}
    api_utils_stub.get_json_result = lambda code=0, message="success", data=None: {"code": code, "data": data, "message": message}
    _install("api.utils.api_utils", api_utils_stub)

    # Validation: schemas are placeholders; both validators accept everything.
    validation_stub = ModuleType("api.utils.validation_utils")
    validation_stub.CreateFolderReq = object
    validation_stub.DeleteFileReq = object
    validation_stub.ListFileReq = object
    validation_stub.MoveFileReq = object

    async def _validate_json_request(_request, _schema):
        return {}, None

    validation_stub.validate_and_parse_json_request = _validate_json_request
    validation_stub.validate_and_parse_request_args = lambda _request, _schema: ({}, None)
    _install("api.utils.validation_utils", validation_stub)

    # web_utils: the header helper records what it was asked to apply.
    web_utils_stub = ModuleType("api.utils.web_utils")
    web_utils_stub.CONTENT_TYPE_MAP = {"txt": "text/plain"}
    web_utils_stub.apply_safe_file_response_headers = lambda response, content_type, ext: response.headers.update({"content_type": content_type, "ext": ext})
    _install("api.utils.web_utils", web_utils_stub)

    # common: settings with a storage backend that always returns b"blob".
    common_stub = ModuleType("common")
    common_stub.__path__ = [str(project_root / "common")]
    common_stub.settings = SimpleNamespace(
        STORAGE_IMPL=SimpleNamespace(
            get=lambda *_args, **_kwargs: b"blob",
        )
    )
    _install("common", common_stub)

    # thread_pool_exec runs the callable inline instead of in a pool.
    misc_utils_stub = ModuleType("common.misc_utils")

    async def thread_pool_exec(func, *args, **kwargs):
        return func(*args, **kwargs)

    misc_utils_stub.thread_pool_exec = thread_pool_exec
    _install("common.misc_utils", misc_utils_stub)

    # Load the real route module from disk under its dotted name.
    qualified_name = "api.apps.restful_apis.file_api"
    source_path = project_root / "api" / "apps" / "restful_apis" / "file_api.py"
    spec = importlib.util.spec_from_file_location(qualified_name, source_path)
    loaded = importlib.util.module_from_spec(spec)
    # Set before execution; presumably the module body uses ``manager.route`` — confirm.
    loaded.manager = _DummyManager()
    _install(qualified_name, loaded)
    spec.loader.exec_module(loaded)
    return loaded
@pytest.mark.p2
def test_create_or_upload_multipart_requires_file(monkeypatch):
    """A multipart request without any file is answered with code 400 "No file part!"."""
    file_api = _load_file_api_module(monkeypatch)
    empty_upload = _DummyRequest(content_type="multipart/form-data", form={}, files=_DummyFiles())
    monkeypatch.setattr(file_api, "request", empty_upload)
    result = _run(file_api.create_or_upload("tenant1"))
    assert result["code"] == 400
    assert result["message"] == "No file part!"
@pytest.mark.p2
def test_create_or_upload_uploads_via_new_service(monkeypatch):
    """A multipart upload is forwarded to ``file_api_service.upload_file`` with tenant, parent id and files."""
    file_api = _load_file_api_module(monkeypatch)
    upload_request = _DummyRequest(
        content_type="multipart/form-data",
        form={"parent_id": "pf1"},
        files=_DummyFiles([_DummyUploadFile("a.txt")]),
    )
    monkeypatch.setattr(file_api, "request", upload_request)
    captured = {}

    async def _upload_file(tenant_id, pf_id, file_objs):
        # Record what the route passed so it can be asserted below.
        captured["args"] = (tenant_id, pf_id, [f.filename for f in file_objs])
        return True, [{"id": "f1"}]

    monkeypatch.setattr(file_api.file_api_service, "upload_file", _upload_file)
    result = _run(file_api.create_or_upload("tenant1"))
    assert captured["args"] == ("tenant1", "pf1", ["a.txt"])
    assert result["code"] == 0
    assert result["data"] == [{"id": "f1"}]
@pytest.mark.p2
def test_create_or_upload_creates_folder_from_json(monkeypatch):
    """A JSON request is routed to ``file_api_service.create_folder``."""
    file_api = _load_file_api_module(monkeypatch)
    monkeypatch.setattr(file_api, "request", _DummyRequest(content_type="application/json"))

    async def _validate(_request, _schema):
        return {"name": "folder-a", "parent_id": "pf1", "type": "folder"}, None

    async def _create_folder(tenant_id, name, parent_id=None, file_type=None):
        # Echo the arguments back so the response reveals what the route passed.
        return True, {"tenant_id": tenant_id, "name": name, "parent_id": parent_id, "type": file_type}

    monkeypatch.setattr(file_api, "validate_and_parse_json_request", _validate)
    monkeypatch.setattr(file_api.file_api_service, "create_folder", _create_folder)
    result = _run(file_api.create_or_upload("tenant1"))
    assert result["code"] == 0
    created = result["data"]
    assert created["tenant_id"] == "tenant1"
    assert created["name"] == "folder-a"
@pytest.mark.p2
def test_list_files_validation_error(monkeypatch):
    """A query-argument validation error surfaces as code 400 with the validator's message."""
    file_api = _load_file_api_module(monkeypatch)

    def _reject(_request, _schema):
        return None, "bad args"

    monkeypatch.setattr(file_api, "validate_and_parse_request_args", _reject)
    result = _run(file_api.list_files("tenant1"))
    assert result["code"] == 400
    assert result["message"] == "bad args"
@pytest.mark.p2
def test_move_uses_new_payload_shape(monkeypatch):
    """``move`` forwards ``src_file_ids`` and ``dest_file_id`` to the service; ``new_name`` stays None."""
    file_api = _load_file_api_module(monkeypatch)
    captured = {}

    async def _validate(_request, _schema):
        return {"src_file_ids": ["f1"], "dest_file_id": "pf2"}, None

    async def _move_files(tenant_id, src_file_ids, dest_file_id=None, new_name=None):
        captured["args"] = (tenant_id, src_file_ids, dest_file_id, new_name)
        return True, True

    monkeypatch.setattr(file_api, "validate_and_parse_json_request", _validate)
    monkeypatch.setattr(file_api.file_api_service, "move_files", _move_files)
    result = _run(file_api.move("tenant1"))
    assert captured["args"] == ("tenant1", ["f1"], "pf2", None)
    assert result["code"] == 0
    assert result["data"] is True
@pytest.mark.p2
def test_rename_via_move_route(monkeypatch):
    """A payload with only ``new_name`` reaches the service with ``dest_file_id`` as None."""
    file_api = _load_file_api_module(monkeypatch)
    captured = {}

    async def _validate(_request, _schema):
        return {"src_file_ids": ["file1"], "new_name": "renamed.txt"}, None

    async def _move_files(tenant_id, src_file_ids, dest_file_id=None, new_name=None):
        captured["args"] = (tenant_id, src_file_ids, dest_file_id, new_name)
        return True, True

    monkeypatch.setattr(file_api, "validate_and_parse_json_request", _validate)
    monkeypatch.setattr(file_api.file_api_service, "move_files", _move_files)
    result = _run(file_api.move("tenant1"))
    assert captured["args"] == ("tenant1", ["file1"], None, "renamed.txt")
    assert result["code"] == 0
    assert result["data"] is True
@pytest.mark.p2
def test_download_falls_back_to_document_storage(monkeypatch):
    """An empty blob from the first storage lookup triggers a second lookup at the document address."""
    file_api = _load_file_api_module(monkeypatch)
    lookups = []

    def _get(bucket, location):
        lookups.append((bucket, location))
        # First call returns nothing; any later call returns the fallback content.
        if len(lookups) == 1:
            return b""
        return b"fallback-blob"

    monkeypatch.setattr(file_api.settings, "STORAGE_IMPL", SimpleNamespace(get=_get))
    response = _run(file_api.download("tenant1", "file1"))
    assert lookups == [("bucket1", "path1"), ("bucket2", "path2")]
    assert response.data == b"fallback-blob"
    assert response.headers["content_type"] == "text/plain"
    assert response.headers["ext"] == "txt"
@pytest.mark.p2
def test_parent_and_ancestors_use_new_routes(monkeypatch):
    """``parent_folder`` and ``ancestors`` return the stubbed service payloads with code 0."""
    file_api = _load_file_api_module(monkeypatch)
    parent_result = _run(file_api.parent_folder("tenant1", "file1"))
    ancestors_result = _run(file_api.ancestors("tenant1", "file1"))

    assert parent_result["code"] == 0
    assert parent_result["data"]["parent_folder"]["id"] == "parent1"

    assert ancestors_result["code"] == 0
    assert ancestors_result["data"]["parent_folders"][0]["id"] == "root"
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
from copy import deepcopy
import pytest
class _DummyManager:
    """Route-manager stand-in: ``route(...)`` decorates without registering anything."""

    def route(self, *_args, **_kwargs):
        """Ignore all routing arguments and return a pass-through decorator."""
        def passthrough(handler):
            return handler

        return passthrough
class _DummyFile:
    """Plain record standing in for a file row: id, type, name, location and size."""

    def __init__(self, file_id, file_type, *, name="file.txt", location="loc", size=1):
        self.id, self.type = file_id, file_type
        self.name, self.location, self.size = name, location, size
class _FalsyFile(_DummyFile):
    """A ``_DummyFile`` that is always falsy in a boolean context."""

    def __bool__(self):
        # Always falsy, regardless of the attributes it carries.
        return False
def _run(coro):
    """Synchronously execute *coro* via ``asyncio.run`` and return what it returns."""
    result = asyncio.run(coro)
    return result
def _set_request_json(monkeypatch, module, payload_state):
    """Patch ``module.get_request_json`` to return a deep copy of *payload_state*.

    The copy is taken on every call, so later mutations of *payload_state* are
    visible to later calls while each caller gets its own object.
    """
    async def _snapshot():
        return deepcopy(payload_state)

    monkeypatch.setattr(module, "get_request_json", _snapshot)
@pytest.fixture(scope="session")
def auth():
    """Session-scoped fixture that yields the placeholder string ``"unit-auth"``.

    NOTE(review): presumably shadows a same-named shared fixture so these unit
    tests need no real credentials — confirm against conftest.
    """
    placeholder = "unit-auth"
    return placeholder
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    """Autouse session fixture that does nothing and provides ``None``.

    NOTE(review): presumably overrides a same-named shared fixture that would
    otherwise contact a server — confirm against conftest.
    """
    return None
def _load_file2document_module(monkeypatch):
    """Execute ``api/apps/restful_apis/file2document_api.py`` with stubbed collaborators.

    Stand-ins are placed in ``sys.modules`` (through ``monkeypatch``) for the
    ``api`` package tree, the permission helpers, the four database services,
    ``api.utils.api_utils``, ``common.misc_utils`` and ``common.constants``.
    The route module is then executed from its file path under a private name.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture used for every ``sys.modules`` entry.

    Returns:
        The executed module, registered as ``test_file2document_routes_unit_module``.
    """
    project_root = Path(__file__).resolve().parents[3]

    def _install(dotted_name, stub):
        # Single place where a stub becomes importable under its dotted name.
        monkeypatch.setitem(sys.modules, dotted_name, stub)

    # Package skeletons: api and api.apps (with a fixed current user).
    api_stub = ModuleType("api")
    api_stub.__path__ = [str(project_root / "api")]
    _install("api", api_stub)

    apps_stub = ModuleType("api.apps")
    apps_stub.__path__ = [str(project_root / "api" / "apps")]
    apps_stub.current_user = SimpleNamespace(id="user-1")
    apps_stub.login_required = lambda func: func
    _install("api.apps", apps_stub)
    api_stub.apps = apps_stub

    # api.db: only the FileType enum is provided.
    db_stub = ModuleType("api.db")
    db_stub.__path__ = []

    class _FileType(Enum):
        FOLDER = "folder"
        DOC = "doc"

    db_stub.FileType = _FileType
    _install("api.db", db_stub)
    api_stub.db = db_stub

    db_services_stub = ModuleType("api.db.services")
    db_services_stub.__path__ = []
    _install("api.db.services", db_services_stub)

    # Permission checks grant access by default; tests flip them as needed.
    api_common_stub = ModuleType("api.common")
    api_common_stub.__path__ = []
    _install("api.common", api_common_stub)

    permission_stub = ModuleType("api.common.check_team_permission")
    permission_stub.check_file_team_permission = lambda *_args, **_kwargs: True
    permission_stub.check_kb_team_permission = lambda *_args, **_kwargs: True
    _install("api.common.check_team_permission", permission_stub)
    api_common_stub.check_team_permission = permission_stub

    # File-to-document links: nothing linked, deletions and inserts are no-ops.
    file2document_stub = ModuleType("api.db.services.file2document_service")

    class _StubFile2DocumentService:
        @staticmethod
        def get_by_file_id(_file_id):
            return []

        @staticmethod
        def delete_by_file_id(*_args, **_kwargs):
            return None

        @staticmethod
        def insert(_payload):
            return SimpleNamespace(to_json=lambda: {})

    file2document_stub.File2DocumentService = _StubFile2DocumentService
    _install("api.db.services.file2document_service", file2document_stub)
    db_services_stub.file2document_service = file2document_stub

    # Files: no files by default; get_by_id always finds a DOC-typed file.
    file_service_stub = ModuleType("api.db.services.file_service")

    class _StubFileService:
        @staticmethod
        def get_by_ids(_file_ids):
            return []

        @staticmethod
        def get_all_innermost_file_ids(_file_id, _acc):
            return []

        @staticmethod
        def get_by_id(_file_id):
            return True, _DummyFile(_file_id, _FileType.DOC.value)

        @staticmethod
        def get_parser(_file_type, _file_name, parser_id):
            return parser_id

    file_service_stub.FileService = _StubFileService
    _install("api.db.services.file_service", file_service_stub)
    db_services_stub.file_service = file_service_stub

    # Datasets: lookups fail until a test patches get_by_id.
    kb_service_stub = ModuleType("api.db.services.knowledgebase_service")

    class _StubKnowledgebaseService:
        @staticmethod
        def get_by_id(_kb_id):
            return False, None

    kb_service_stub.KnowledgebaseService = _StubKnowledgebaseService
    _install("api.db.services.knowledgebase_service", kb_service_stub)
    db_services_stub.knowledgebase_service = kb_service_stub

    # Documents: every lookup, removal and insert succeeds.
    document_service_stub = ModuleType("api.db.services.document_service")

    class _StubDocumentService:
        @staticmethod
        def get_by_id(doc_id):
            return True, SimpleNamespace(id=doc_id)

        @staticmethod
        def get_tenant_id(_doc_id):
            return "tenant-1"

        @staticmethod
        def remove_document(*_args, **_kwargs):
            return True

        @staticmethod
        def insert(_payload):
            return SimpleNamespace(id="doc-1")

    document_service_stub.DocumentService = _StubDocumentService
    _install("api.db.services.document_service", document_service_stub)
    db_services_stub.document_service = document_service_stub

    # Result builders return plain dicts so tests can inspect them directly.
    api_utils_stub = ModuleType("api.utils.api_utils")

    def get_json_result(data=None, message="", code=0):
        return {"code": code, "data": data, "message": message}

    def get_data_error_result(message=""):
        return {"code": 102, "data": None, "message": message}

    async def get_request_json():
        return {}

    def server_error_response(err):
        return {"code": 500, "data": None, "message": str(err)}

    def validate_request(*_keys):
        # Required keys are ignored: the wrapper only awaits the handler.
        def _decorator(func):
            @functools.wraps(func)
            async def _wrapper(*args, **kwargs):
                return await func(*args, **kwargs)

            return _wrapper

        return _decorator

    api_utils_stub.get_json_result = get_json_result
    api_utils_stub.get_data_error_result = get_data_error_result
    api_utils_stub.get_request_json = get_request_json
    api_utils_stub.server_error_response = server_error_response
    api_utils_stub.validate_request = validate_request
    _install("api.utils.api_utils", api_utils_stub)

    misc_utils_stub = ModuleType("common.misc_utils")
    misc_utils_stub.get_uuid = lambda: "uuid"
    _install("common.misc_utils", misc_utils_stub)

    constants_stub = ModuleType("common.constants")

    class _RetCode:
        ARGUMENT_ERROR = 101

    constants_stub.RetCode = _RetCode
    _install("common.constants", constants_stub)

    # Load the real route module from disk under a private module name.
    private_name = "test_file2document_routes_unit_module"
    source_path = project_root / "api" / "apps" / "restful_apis" / "file2document_api.py"
    spec = importlib.util.spec_from_file_location(private_name, source_path)
    loaded = importlib.util.module_from_spec(spec)
    # Set before execution; presumably the module body uses ``manager.route`` — confirm.
    loaded.manager = _DummyManager()
    _install(private_name, loaded)
    spec.loader.exec_module(loaded)
    return loaded
@pytest.mark.p2
def test_convert_branch_matrix_unit(monkeypatch):
    """Drive ``convert`` through its validation, authorization, success and failure branches.

    The branches run in sequence on one loaded module; each step changes only
    the stub it needs, so the order of the steps matters.
    """
    api = _load_file2document_module(monkeypatch)
    request_body = {"kb_ids": ["kb-1"], "file_ids": ["f1"]}
    _set_request_json(monkeypatch, api, request_body)
    doc_type = api.FileType.DOC.value

    # 1. A falsy file object is reported as "File not found!".
    monkeypatch.setattr(api.FileService, "get_by_ids", lambda _ids: [_FalsyFile("f1", doc_type)])
    outcome = _run(api.convert())
    assert outcome["code"] == 102
    assert outcome["message"] == "File not found!"

    # 2. A real file with an unknown dataset is reported as "Can't find this dataset!".
    monkeypatch.setattr(api.FileService, "get_by_ids", lambda _ids: [_DummyFile("f1", doc_type)])
    outcome = _run(api.convert())
    assert outcome["code"] == 102
    assert outcome["message"] == "Can't find this dataset!"

    # From here on the dataset lookup succeeds.
    dataset = SimpleNamespace(id="kb-1", parser_id="naive", pipeline_id="p1", parser_config={})
    monkeypatch.setattr(api.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, dataset))

    # 3. Missing file permission is refused.
    monkeypatch.setattr(api, "check_file_team_permission", lambda *_args, **_kwargs: False)
    outcome = _run(api.convert())
    assert outcome["code"] == 102
    assert outcome["message"] == "No authorization."

    # 4. Missing dataset permission is refused as well.
    monkeypatch.setattr(api, "check_file_team_permission", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(api, "check_kb_team_permission", lambda *_args, **_kwargs: False)
    outcome = _run(api.convert())
    assert outcome["code"] == 102
    assert outcome["message"] == "No authorization."

    # 5. With both permissions granted the call succeeds with data=True.
    monkeypatch.setattr(api, "check_kb_team_permission", lambda *_args, **_kwargs: True)
    outcome = _run(api.convert())
    assert outcome["code"] == 0
    assert outcome["data"] is True

    # 6. A folder id that expands to an inner file also succeeds with data=True.
    request_body["file_ids"] = ["folder-1"]
    folder = _DummyFile("folder-1", api.FileType.FOLDER.value, name="folder")
    monkeypatch.setattr(api.FileService, "get_by_ids", lambda _ids: [folder])
    monkeypatch.setattr(api.FileService, "get_all_innermost_file_ids", lambda _file_id, _acc: ["inner-1"])
    outcome = _run(api.convert())
    assert outcome["code"] == 0
    assert outcome["data"] is True

    # 7. An exception raised by the file lookup becomes a code-500 response.
    request_body["file_ids"] = ["f1"]

    def _failing_lookup(_ids):
        raise RuntimeError("convert boom")

    monkeypatch.setattr(api.FileService, "get_by_ids", _failing_lookup)
    outcome = _run(api.convert())
    assert outcome["code"] == 500
    assert "convert boom" in outcome["message"]

View File

@@ -0,0 +1,37 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
@pytest.mark.p2
def test_langfuse_api_key_routes_require_auth(rest_client_noauth):
    """GET, POST, PUT and DELETE on ``/langfuse/api-key`` all answer 401 without a token.

    POST and PUT carry a syntactically complete credentials body, so a rejection
    cannot be blamed on a missing field. Both assertions name the HTTP method in
    their failure message, because the four verbs share one loop.
    """
    credentials = {"secret_key": "s", "public_key": "p", "host": "http://example.com"}
    for method in ("get", "post", "put", "delete"):
        requester = getattr(rest_client_noauth, method)
        # Fresh dict per request so no iteration can see another's body object.
        kwargs = {"json": dict(credentials)} if method in {"post", "put"} else {}
        res = requester("/langfuse/api-key", **kwargs)
        assert res.status_code == 401, (method, res.status_code, res.text)
        payload = res.json()
        assert payload["code"] == 401, (method, payload)
@pytest.mark.p2
def test_langfuse_api_key_missing_required_fields(rest_client):
    """An empty ``secret_key`` is rejected with code 101 or 102 and a "required"/"missing" message."""
    incomplete_credentials = {"secret_key": "", "public_key": "pub", "host": "http://host"}
    response = rest_client.post("/langfuse/api-key", json=incomplete_credentials)
    assert response.status_code == 200
    body = response.json()
    assert body["code"] in (101, 102), body
    lowered_message = body["message"].lower()
    assert "required" in lowered_message or "missing" in lowered_message, body

View File

@@ -0,0 +1,745 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import inspect
import json
import sys
from functools import wraps
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
    """Route-manager stand-in; ``route`` hands handlers back without registering them."""

    def route(self, *_args, **_kwargs):
        """Return an identity decorator, whatever routing arguments are given."""
        def keep(handler):
            return handler

        return keep
class _Args(dict):
    """Dict-based stand-in for query arguments that offers ``getlist``."""

    def getlist(self, key):
        """Return the value for *key* as a list.

        A stored list is returned as-is, any other stored value is wrapped in a
        one-element list, and a missing key yields an empty list.
        """
        stored = self.get(key, [])
        return stored if isinstance(stored, list) else [stored]
class _Field:
    """Named column placeholder whose ``==`` builds a ``(name, value)`` pair."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        # Deliberately not a boolean: the comparison itself is the recorded value.
        pair = (self.name, other)
        return pair
class _DummyMCPServer:
    """Stand-in for the MCP server model.

    On the class, ``id`` and ``tenant_id`` are ``_Field`` placeholders; on an
    instance they are replaced by the concrete values given to ``__init__``.
    """

    id = _Field("id")
    tenant_id = _Field("tenant_id")

    def __init__(self, **kwargs):
        # Built per call so the mutable defaults are never shared between instances.
        defaults = {
            "id": "",
            "name": "",
            "url": "",
            "server_type": "sse",
            "tenant_id": "tenant_1",
            "variables": {},
            "headers": {},
        }
        for attr, fallback in defaults.items():
            setattr(self, attr, kwargs.get(attr, fallback))

    def to_dict(self):
        """Return the seven attributes as a plain dict."""
        exported = ("id", "name", "url", "server_type", "tenant_id", "variables", "headers")
        return {attr: getattr(self, attr) for attr in exported}
class _DummyMCPServerService:
    """MCP server service stand-in: lookups find nothing, writes report success."""

    # --- lookups: nothing is ever found ---

    @staticmethod
    def get_servers(*_args, **_kwargs):
        return []

    @staticmethod
    def get_or_none(*_args, **_kwargs):
        return None

    @staticmethod
    def get_by_id(*_args, **_kwargs):
        return (False, None)

    @staticmethod
    def get_by_name_and_tenant(*_args, **_kwargs):
        return (False, None)

    # --- writes: always reported as successful ---

    @staticmethod
    def insert(**_kwargs):
        return True

    @staticmethod
    def filter_update(*_args, **_kwargs):
        return True

    @staticmethod
    def delete_by_ids(*_args, **_kwargs):
        return True
class _DummyTenantService:
    """Tenant service stand-in that always finds a tenant with id ``"tenant_1"``."""

    @staticmethod
    def get_by_id(*_args, **_kwargs):
        tenant = SimpleNamespace(id="tenant_1")
        return True, tenant
class _DummyTool:
    """Tool stand-in that serialises to a dict containing only its name."""

    def __init__(self, name):
        self._name = name

    def model_dump(self):
        """Return ``{"name": <name>}``."""
        return dict(name=self._name)
class _DummyMCPToolCallSession:
    """Tool-call session stand-in exposing two fixed tools and a constant call result."""

    def __init__(self, _mcp_server, _variables):
        # Server and variables are accepted but unused.
        self._tools = [_DummyTool(tool_name) for tool_name in ("tool_a", "tool_b")]

    def get_tools(self, _timeout):
        """Return the two fixed tools; the timeout is ignored."""
        return self._tools

    def tool_call(self, _name, _arguments, _timeout):
        """Return ``"ok"`` for every call."""
        return "ok"
def _run(coro):
    """Block until *coro* finishes under ``asyncio.run`` and return its result."""
    value = asyncio.run(coro)
    return value
def _set_request_json(monkeypatch, module, payload):
    """Patch ``module.get_request_json`` to return *payload* itself.

    No copy is made: every awaited call yields the very same object.
    """
    async def _canned_json():
        return payload

    monkeypatch.setattr(module, "get_request_json", _canned_json)
@pytest.fixture(scope="session")
def auth():
    """Session-scoped fixture providing the placeholder string ``"unit-auth"``.

    NOTE(review): presumably shadows a same-named shared fixture so these unit
    tests need no real credentials — confirm against conftest.
    """
    token_placeholder = "unit-auth"
    return token_placeholder
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    """Autouse session fixture that performs no setup and provides ``None``.

    NOTE(review): presumably overrides a same-named shared fixture that would
    otherwise contact a server — confirm against conftest.
    """
    return None
def _load_mcp_api(monkeypatch):
    """Load ``api/apps/restful_apis/mcp_api.py`` from disk with stubbed dependencies.

    Stub modules are registered in ``sys.modules`` via ``monkeypatch.setitem``
    (so they are undone at test teardown) before the target file is executed
    under the private module name ``test_mcp_api_unit_module``.

    Returns:
        The executed module object.
    """
    # NOTE(review): assumes this file sits three directories below the repo root — confirm.
    repo_root = Path(__file__).resolve().parents[3]

    # --- quart: only ``Response`` and ``request.args`` are provided.
    quart_mod = ModuleType("quart")
    quart_mod.Response = object
    quart_mod.request = SimpleNamespace(args=_Args({}))
    monkeypatch.setitem(sys.modules, "quart", quart_mod)

    # --- common: an empty package object whose __path__ points at the real directory;
    # presumably so ``common.*`` submodules not stubbed below still import from disk.
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    constants_mod = ModuleType("common.constants")
    constants_mod.VALID_MCP_SERVER_TYPES = {"sse", "streamable-http"}
    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)

    # --- api.apps: fixed current user and a pass-through ``login_required``.
    apps_mod = ModuleType("api.apps")
    apps_mod.current_user = SimpleNamespace(id="tenant_1")
    apps_mod.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)

    # --- database model and service layers replaced by the dummies defined in this file.
    db_models_mod = ModuleType("api.db.db_models")
    db_models_mod.MCPServer = _DummyMCPServer
    monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)

    mcp_service_mod = ModuleType("api.db.services.mcp_server_service")
    mcp_service_mod.MCPServerService = _DummyMCPServerService
    monkeypatch.setitem(sys.modules, "api.db.services.mcp_server_service", mcp_service_mod)

    user_service_mod = ModuleType("api.db.services.user_service")
    user_service_mod.TenantService = _DummyTenantService
    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)

    mcp_conn_mod = ModuleType("common.mcp_tool_call_conn")
    mcp_conn_mod.MCPToolCallSession = _DummyMCPToolCallSession
    mcp_conn_mod.close_multiple_mcp_toolcall_sessions = lambda _sessions: None
    monkeypatch.setitem(sys.modules, "common.mcp_tool_call_conn", mcp_conn_mod)

    # --- api.utils.api_utils: response helpers return plain dicts so tests can
    # assert on ``code`` / ``message`` / ``data`` directly.
    api_utils_mod = ModuleType("api.utils.api_utils")

    async def _default_request_json():
        return {}

    def _get_json_result(code=0, message="success", data=None):
        return {"code": code, "message": message, "data": data}

    def _get_data_error_result(code=102, message="Sorry! Data missing!"):
        return {"code": code, "message": message}

    def _server_error_response(error):
        return {"code": 100, "message": repr(error)}

    async def _get_mcp_tools(*_args, **_kwargs):
        return {}

    def _validate_request(*_args, **_kwargs):
        # Decorator factory that performs no validation: the wrapper awaits
        # coroutine handlers and calls plain functions directly.
        def _decorator(func):
            @wraps(func)
            async def _wrapped(*func_args, **func_kwargs):
                if inspect.iscoroutinefunction(func):
                    return await func(*func_args, **func_kwargs)
                return func(*func_args, **func_kwargs)
            return _wrapped
        return _decorator

    api_utils_mod.get_request_json = _default_request_json
    api_utils_mod.get_json_result = _get_json_result
    api_utils_mod.get_data_error_result = _get_data_error_result
    api_utils_mod.server_error_response = _server_error_response
    api_utils_mod.validate_request = _validate_request
    api_utils_mod.get_mcp_tools = _get_mcp_tools
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)

    # --- api.utils.web_utils: tolerant parsing helpers.
    web_utils_mod = ModuleType("api.utils.web_utils")

    def _get_float(data, key, default):
        # Falls back to ``default`` when the value cannot be converted to float.
        try:
            return float(data.get(key, default))
        except (TypeError, ValueError):
            return default

    def _safe_json_parse(value):
        # dict/list pass through; None, "" and unparsable input become {}.
        if isinstance(value, (dict, list)):
            return value
        if value in (None, ""):
            return {}
        try:
            return json.loads(value)
        except (TypeError, ValueError):
            return {}

    web_utils_mod.get_float = _get_float
    web_utils_mod.safe_json_parse = _safe_json_parse
    monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod)

    # --- execute the real mcp_api.py under a private name, after every stub is in place.
    module_name = "test_mcp_api_unit_module"
    module_path = repo_root / "api" / "apps" / "restful_apis" / "mcp_api.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    # ``manager`` is set before exec_module; presumably the target file uses it at import time.
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
@pytest.mark.p2
def test_list_mcp_desc_pagination_and_exception(monkeypatch):
    """Check ``list_mcp`` with paging query args, then with a raising service."""
    module = _load_mcp_api(monkeypatch)
    # Page 2 with page_size 1 over a two-item result: total stays 2, only item "b" is returned.
    monkeypatch.setattr(
        module,
        "request",
        SimpleNamespace(args=_Args({"keywords": "k", "page": "2", "page_size": "1", "orderby": "create_time", "desc": "false"})),
    )
    _set_request_json(monkeypatch, module, {"mcp_ids": []})
    monkeypatch.setattr(module.MCPServerService, "get_servers", lambda *_args, **_kwargs: [{"id": "a"}, {"id": "b"}])
    res = _run(module.list_mcp())
    assert res["code"] == 0
    assert res["data"]["total"] == 2
    assert res["data"]["mcp_servers"] == [{"id": "b"}]

    # Service raises: the stubbed server_error_response yields code 100 with repr(error).
    monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({})))
    _set_request_json(monkeypatch, module, {"mcp_ids": []})

    def _raise_list(*_args, **_kwargs):
        raise RuntimeError("list explode")

    monkeypatch.setattr(module.MCPServerService, "get_servers", _raise_list)
    res = _run(module.list_mcp())
    assert res["code"] == 100
    assert "list explode" in res["message"]
@pytest.mark.p2
def test_detail_not_found_success_and_exception(monkeypatch):
    """Check ``detail`` for a missing server, a found server, and a raising lookup."""
    module = _load_mcp_api(monkeypatch)
    monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({})))

    # Lookup returns None -> data error (code 102) naming the id and the current user.
    monkeypatch.setattr(module.MCPServerService, "get_or_none", lambda **_kwargs: None)
    res = module.detail("mcp-1")
    assert res["code"] == 102
    assert "Cannot find MCP server mcp-1 for user tenant_1" in res["message"]

    # Lookup returns a server -> success payload carries its id.
    monkeypatch.setattr(
        module.MCPServerService,
        "get_or_none",
        lambda **_kwargs: _DummyMCPServer(id="mcp-1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1"),
    )
    res = module.detail("mcp-1")
    assert res["code"] == 0
    assert res["data"]["id"] == "mcp-1"

    # Lookup raises -> server error (code 100).
    def _raise_detail(**_kwargs):
        raise RuntimeError("detail explode")

    monkeypatch.setattr(module.MCPServerService, "get_or_none", _raise_detail)
    res = module.detail("mcp-1")
    assert res["code"] == 100
    assert "detail explode" in res["message"]
@pytest.mark.p2
def test_create_validation_guards(monkeypatch):
    """Check the input-validation error messages returned by ``create``.

    ``create.__wrapped__`` is called directly; presumably this bypasses an
    outer decorator applied with ``functools.wraps`` — confirm against mcp_api.py.
    """
    module = _load_mcp_api(monkeypatch)
    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None))

    # Unknown server type.
    _set_request_json(monkeypatch, module, {"name": "srv", "url": "http://a", "server_type": "invalid"})
    res = _run(module.create.__wrapped__())
    assert "Unsupported MCP server type" in res["message"]

    # Empty name.
    _set_request_json(monkeypatch, module, {"name": "", "url": "http://a", "server_type": "sse"})
    res = _run(module.create.__wrapped__())
    assert "Invalid MCP name" in res["message"]

    # Name already exists for the tenant.
    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (True, object()))
    _set_request_json(monkeypatch, module, {"name": "srv", "url": "http://a", "server_type": "sse"})
    res = _run(module.create.__wrapped__())
    assert "Duplicated MCP server name" in res["message"]

    # Empty url.
    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None))
    _set_request_json(monkeypatch, module, {"name": "srv", "url": "", "server_type": "sse"})
    res = _run(module.create.__wrapped__())
    assert "Invalid url" in res["message"]
@pytest.mark.p2
def test_create_service_paths(monkeypatch):
    """Walk ``create`` through tenant, tool-fetch, insert and exception outcomes.

    The steps are order-dependent: each one re-sets the request payload and
    builds on the monkeypatches left in place by the previous step.
    """
    module = _load_mcp_api(monkeypatch)
    base_payload = {
        "name": "srv",
        "url": "http://server",
        "server_type": "sse",
        "headers": '{"Authorization": "x"}',
        "variables": '{"tools": {"old": 1}, "token": "abc"}',
        "timeout": "2.5",
    }
    # Fixed uuid so the created id can be asserted below.
    monkeypatch.setattr(module, "get_uuid", lambda: "uuid-create")
    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None))

    # 1) Tenant lookup fails.
    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda *_args, **_kwargs: (False, None))
    res = _run(module.create.__wrapped__())
    assert "Tenant not found" in res["message"]

    # 2) Tenant found, but tool discovery reports an error string.
    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda *_args, **_kwargs: (True, object()))

    async def _thread_pool_tools_error(_func, _servers, _timeout):
        return None, "tools error"

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error)
    res = _run(module.create.__wrapped__())
    assert res["code"] == 102
    assert "tools error" in res["message"]

    # 3) Tool discovery succeeds, insert fails.
    _set_request_json(monkeypatch, module, dict(base_payload))

    async def _thread_pool_ok(_func, servers, _timeout):
        # One well-formed tool and one entry without a "name" key.
        return {servers[0].name: [{"name": "tool_a"}, {"invalid": True}]}, None

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok)
    monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: False)
    res = _run(module.create.__wrapped__())
    assert res["code"] == 102
    assert "Failed to create MCP server" in res["message"]

    # 4) Insert succeeds: only the named tool ends up in variables["tools"].
    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: True)
    res = _run(module.create.__wrapped__())
    assert res["code"] == 0
    assert res["data"]["id"] == "uuid-create"
    assert res["data"]["tenant_id"] == "tenant_1"
    assert res["data"]["variables"]["tools"] == {"tool_a": {"name": "tool_a"}}

    # 5) Tool discovery raises -> server error (code 100).
    _set_request_json(monkeypatch, module, dict(base_payload))

    async def _thread_pool_raises(_func, _servers, _timeout):
        raise RuntimeError("create explode")

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_raises)
    res = _run(module.create.__wrapped__())
    assert res["code"] == 100
    assert "create explode" in res["message"]
@pytest.mark.p2
def test_update_validation_guards(monkeypatch):
    """Check the lookup and input-validation error messages returned by ``update``."""
    module = _load_mcp_api(monkeypatch)
    existing = _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="tenant_1", variables={}, headers={})

    # Server id not found.
    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
    res = _run(module.update("mcp-1"))
    assert "Cannot find MCP server" in res["message"]

    # Server exists but its tenant_id is "other", not the current user's "tenant_1".
    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"})
    monkeypatch.setattr(
        module.MCPServerService,
        "get_by_id",
        lambda _mcp_id: (True, _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="other", variables={}, headers={})),
    )
    res = _run(module.update("mcp-1"))
    assert "Cannot find MCP server" in res["message"]

    # Unknown server type.
    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "server_type": "invalid"})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))
    res = _run(module.update("mcp-1"))
    assert "Unsupported MCP server type" in res["message"]

    # 256-character name is rejected.
    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "name": "a" * 256})
    res = _run(module.update("mcp-1"))
    assert "Invalid MCP name" in res["message"]

    # Empty url.
    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "url": ""})
    res = _run(module.update("mcp-1"))
    assert "Invalid url" in res["message"]
@pytest.mark.p2
def test_update_service_paths(monkeypatch):
    """Walk ``update`` through tool-fetch, persist, re-fetch and exception outcomes.

    The steps are order-dependent: each one re-sets the request payload and
    builds on the monkeypatches left in place by the previous step.
    """
    module = _load_mcp_api(monkeypatch)
    # Record as stored before the update.
    existing = _DummyMCPServer(
        id="mcp-1",
        name="srv",
        url="http://server",
        server_type="sse",
        tenant_id="tenant_1",
        variables={"tools": {"old": {"enabled": True}}, "token": "abc"},
        headers={"Authorization": "old"},
    )
    # Record returned by the second lookup in the success step.
    updated = _DummyMCPServer(
        id="mcp-1",
        name="srv-new",
        url="http://server-new",
        server_type="sse",
        tenant_id="tenant_1",
        variables={"tools": {"tool_a": {"name": "tool_a"}}},
        headers={"Authorization": "new"},
    )
    base_payload = {
        "mcp_id": "mcp-1",
        "name": "srv-new",
        "url": "http://server-new",
        "server_type": "sse",
        "headers": '{"Authorization": "new"}',
        "variables": '{"tools": {"ignore": 1}, "token": "new"}',
        "timeout": "3.0",
    }

    # 1) Tool discovery reports an error string.
    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))

    async def _thread_pool_tools_error(_func, _servers, _timeout):
        return None, "update tools error"

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error)
    res = _run(module.update("mcp-1"))
    assert res["code"] == 102
    assert "update tools error" in res["message"]

    # 2) Tool discovery succeeds, filter_update fails.
    _set_request_json(monkeypatch, module, dict(base_payload))

    async def _thread_pool_ok(_func, servers, _timeout):
        return {servers[0].name: [{"name": "tool_a"}, {"bad": True}]}, None

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok)
    monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: False)
    res = _run(module.update("mcp-1"))
    assert "Failed to updated MCP server" in res["message"]

    # 3) Update persists, but the second get_by_id call (the re-fetch) fails.
    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: True)

    def _get_by_id_fetch_fail(_mcp_id):
        # Call counter kept on the function object: first call succeeds, later ones fail.
        if _get_by_id_fetch_fail.calls == 0:
            _get_by_id_fetch_fail.calls += 1
            return True, existing
        return False, None

    _get_by_id_fetch_fail.calls = 0
    monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_fetch_fail)
    res = _run(module.update("mcp-1"))
    assert "Failed to fetch updated MCP server" in res["message"]

    # 4) Full success: first lookup returns ``existing``, second returns ``updated``.
    _set_request_json(monkeypatch, module, dict(base_payload))

    def _get_by_id_success(_mcp_id):
        if _get_by_id_success.calls == 0:
            _get_by_id_success.calls += 1
            return True, existing
        return True, updated

    _get_by_id_success.calls = 0
    monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_success)
    res = _run(module.update("mcp-1"))
    assert res["code"] == 0
    assert res["data"]["id"] == "mcp-1"

    # 5) Tool discovery raises -> server error (code 100).
    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))

    async def _thread_pool_raises(_func, _servers, _timeout):
        raise RuntimeError("update explode")

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_raises)
    res = _run(module.update("mcp-1"))
    assert res["code"] == 100
    assert "update explode" in res["message"]
@pytest.mark.p2
def test_rm_failure_success_and_exception(monkeypatch):
    """Check ``rm`` when deletion fails, succeeds, and raises."""
    module = _load_mcp_api(monkeypatch)
    server = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1", variables={})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server))

    # delete_by_ids returns False -> failure message.
    _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})
    monkeypatch.setattr(module.MCPServerService, "delete_by_ids", lambda _ids: False)
    res = _run(module.rm("id1"))
    assert "Failed to delete MCP servers" in res["message"]

    # delete_by_ids returns True -> success with data True.
    _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})
    monkeypatch.setattr(module.MCPServerService, "delete_by_ids", lambda _ids: True)
    res = _run(module.rm("id1"))
    assert res["code"] == 0
    assert res["data"] is True

    # delete_by_ids raises -> server error (code 100).
    _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})

    def _raise_rm(_ids):
        raise RuntimeError("rm explode")

    monkeypatch.setattr(module.MCPServerService, "delete_by_ids", _raise_rm)
    res = _run(module.rm("id1"))
    assert res["code"] == 100
    assert "rm explode" in res["message"]
@pytest.mark.p2
def test_import_multiple_missing_servers_and_exception(monkeypatch):
    """Check ``import_multiple`` with an empty server map and with a raising service."""
    module = _load_mcp_api(monkeypatch)

    # Empty "mcpServers" mapping is rejected.
    _set_request_json(monkeypatch, module, {"mcpServers": {}})
    res = _run(module.import_multiple.__wrapped__())
    assert "No MCP servers provided" in res["message"]

    # Name lookup raises -> server error (code 100).
    _set_request_json(monkeypatch, module, {"mcpServers": {"srv": {"type": "sse", "url": "http://x"}}, "timeout": "1"})

    def _raise_import(**_kwargs):
        raise RuntimeError("import explode")

    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", _raise_import)
    res = _run(module.import_multiple.__wrapped__())
    assert res["code"] == 100
    assert "import explode" in res["message"]
@pytest.mark.p2
def test_import_multiple_mixed_results(monkeypatch):
    """Import five servers in one request and check each per-server result entry."""
    module = _load_mcp_api(monkeypatch)
    # One entry per outcome asserted below; the keys are the server names.
    payload = {
        "mcpServers": {
            "missing_fields": {"type": "sse"},
            "": {"type": "sse", "url": "http://empty"},
            "dup": {"type": "sse", "url": "http://dup", "authorization_token": "dup-token"},
            "tool_err": {"type": "sse", "url": "http://err"},
            "insert_fail": {"type": "sse", "url": "http://fail"},
        },
        "timeout": "3",
    }
    _set_request_json(monkeypatch, module, payload)
    monkeypatch.setattr(module, "get_uuid", lambda: "uuid-import")

    def _get_by_name_and_tenant(name, tenant_id):
        # Report "dup" as existing exactly once; every later lookup finds nothing.
        if name == "dup" and not _get_by_name_and_tenant.first_dup_seen:
            _get_by_name_and_tenant.first_dup_seen = True
            return True, object()
        return False, None

    _get_by_name_and_tenant.first_dup_seen = False
    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", _get_by_name_and_tenant)

    async def _thread_pool_exec(func, servers, _timeout):
        # Tool discovery fails only for the server named "tool_err".
        mcp_server = servers[0]
        if mcp_server.name == "tool_err":
            return None, "tool call failed"
        return {mcp_server.name: [{"name": "tool_a"}, {"invalid": True}]}, None

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec)

    def _insert(**kwargs):
        # Persisting fails only for the server named "insert_fail".
        return kwargs["name"] != "insert_fail"

    monkeypatch.setattr(module.MCPServerService, "insert", _insert)
    res = _run(module.import_multiple.__wrapped__())
    # The request as a whole succeeds; failures are reported per server.
    assert res["code"] == 0
    results = {item["server"]: item for item in res["data"]["results"]}
    assert results["missing_fields"]["success"] is False
    assert "Missing required fields" in results["missing_fields"]["message"]
    assert results[""]["success"] is False
    assert "Invalid MCP name" in results[""]["message"]
    assert results["tool_err"]["success"] is False
    assert "tool call failed" in results["tool_err"]["message"]
    assert results["insert_fail"]["success"] is False
    assert "Failed to create MCP server" in results["insert_fail"]["message"]
    # The duplicate is imported under the new name "dup_0".
    assert results["dup"]["success"] is True
    assert results["dup"]["new_name"] == "dup_0"
    assert "Renamed from 'dup' to 'dup_0' avoid duplication" == results["dup"]["message"]
@pytest.mark.p2
def test_detail_download_success_and_exception(monkeypatch):
    """Check ``detail`` with query arg ``mode=download`` across four lookup outcomes."""
    module = _load_mcp_api(monkeypatch)
    monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({"mode": "download"})))

    # Owned server -> export keyed by server name.
    monkeypatch.setattr(
        module.MCPServerService,
        "get_by_id",
        lambda _mcp_id: (
            True,
            _DummyMCPServer(
                id="id1",
                name="srv-one",
                url="http://one",
                server_type="sse",
                tenant_id="tenant_1",
                variables={"authorization_token": "tok", "tools": {"tool_a": {"enabled": True}}},
            ),
        ),
    )
    res = module.detail("id1")
    assert res["code"] == 0
    assert list(res["data"]["mcpServers"].keys()) == ["srv-one"]

    # Unknown id -> data error (code 102).
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
    res = module.detail("missing")
    assert res["code"] == 102
    assert "Cannot find MCP server missing for user tenant_1" in res["message"]

    # Server whose tenant_id is "other" -> same not-found error as an unknown id.
    monkeypatch.setattr(
        module.MCPServerService,
        "get_by_id",
        lambda _mcp_id: (
            True,
            _DummyMCPServer(
                id="id2",
                name="srv-two",
                url="http://two",
                server_type="sse",
                tenant_id="other",
                variables={},
            ),
        ),
    )
    res = module.detail("id2")
    assert res["code"] == 102
    assert "Cannot find MCP server id2 for user tenant_1" in res["message"]

    # Lookup raises -> server error (code 100).
    def _raise_export(_mcp_id):
        raise RuntimeError("export explode")

    monkeypatch.setattr(module.MCPServerService, "get_by_id", _raise_export)
    res = module.detail("id1")
    assert res["code"] == 100
    assert "export explode" in res["message"]
@pytest.mark.p2
def test_test_mcp_route_matrix_unit(monkeypatch):
    """Check ``test_mcp``: validation errors, tool-fetch failure, success, session failure."""
    module = _load_mcp_api(monkeypatch)

    # Empty url.
    _set_request_json(monkeypatch, module, {"url": "", "server_type": "sse"})
    res = _run(module.test_mcp("mcp-1"))
    assert "Invalid MCP url" in res["message"]

    # Unknown server type.
    _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "invalid"})
    res = _run(module.test_mcp("mcp-1"))
    assert "Unsupported MCP server type" in res["message"]

    # get_tools raises inside the executor: error code 102, and the close helper
    # must still have been called with exactly one session.
    close_calls = []

    async def _thread_pool_exec_inner_error(func, *args):
        # Record close calls, fail get_tools, run anything else inline.
        if func is module.close_multiple_mcp_toolcall_sessions:
            close_calls.append(args[0])
            return None
        if getattr(func, "__name__", "") == "get_tools":
            raise RuntimeError("get tools explode")
        return func(*args)

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_inner_error)
    _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
    res = _run(module.test_mcp("mcp-1"))
    assert res["code"] == 102
    assert "Test MCP error: get tools explode" in res["message"]
    assert close_calls and len(close_calls[-1]) == 1

    # Success: the dummy session's tools come back, each marked enabled, and the
    # close helper is again called with one session.
    close_calls_success = []

    async def _thread_pool_exec_success(func, *args):
        if func is module.close_multiple_mcp_toolcall_sessions:
            close_calls_success.append(args[0])
            return None
        return func(*args)

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_success)
    _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
    res = _run(module.test_mcp("mcp-1"))
    assert res["code"] == 0
    assert res["data"][0]["name"] == "tool_a"
    assert all(tool["enabled"] is True for tool in res["data"])
    assert close_calls_success and len(close_calls_success[-1]) == 1

    # Constructing the session raises -> server error (code 100).
    def _raise_session(*_args, **_kwargs):
        raise RuntimeError("session explode")

    monkeypatch.setattr(module, "MCPToolCallSession", _raise_session)
    _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
    res = _run(module.test_mcp("mcp-1"))
    assert res["code"] == 100
    assert "session explode" in res["message"]

View File

@@ -0,0 +1,210 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import uuid
import pytest
@pytest.fixture
def memory_cleanup(rest_client):
    """Yield a list of memory ids; delete each of them once the test finishes.

    A delete counts as failed when the HTTP status is not 200 or the payload
    code is neither 0 nor 404. All failures are collected and reported in a
    single assertion after every id has been attempted.
    """
    tracked_ids: list[str] = []
    yield tracked_ids

    failures = []
    for tracked_id in tracked_ids:
        response = rest_client.delete(f"/memories/{tracked_id}")
        if response.status_code == 200:
            body = response.json()
            if body["code"] not in (0, 404):
                failures.append((tracked_id, response.status_code, body))
        else:
            failures.append((tracked_id, response.status_code, response.text))
    assert not failures, f"Memory cleanup failed: {failures}"
@pytest.fixture
def create_memory_resource(rest_client, memory_cleanup):
    """Yield a factory that creates a memory and registers it for cleanup."""

    def _create(name_prefix: str = "restful_memory") -> str:
        # A random suffix keeps names unique across calls.
        unique_name = f"{name_prefix}_{uuid.uuid4().hex[:8]}"
        request_body = {
            "name": unique_name,
            "memory_type": ["raw"],
            "embd_id": "BAAI/bge-small-en-v1.5@Builtin",
            "llm_id": "glm-4-flash@ZHIPU-AI",
        }
        response = rest_client.post("/memories", json=request_body)
        assert response.status_code == 200
        body = response.json()
        assert body["code"] == 0, body
        new_id = body["data"]["id"]
        memory_cleanup.append(new_id)
        return new_id

    yield _create
def _add_message(rest_client, memory_id: str, user_input: str, agent_response: str) -> None:
add_res = rest_client.post(
"/messages",
json={
"memory_id": [memory_id],
"agent_id": uuid.uuid4().hex,
"session_id": uuid.uuid4().hex,
"user_id": uuid.uuid4().hex,
"user_input": user_input,
"agent_response": agent_response,
},
)
assert add_res.status_code == 200
add_payload = add_res.json()
assert add_payload["code"] == 0, add_payload
def _wait_for_memory_messages(rest_client, memory_id: str, timeout: float = 10, interval: float = 0.2) -> list[dict]:
deadline = time.time() + timeout
last_payload = None
while time.time() < deadline:
res = rest_client.get(f"/memories/{memory_id}")
if res.status_code == 200:
payload = res.json()
last_payload = payload
if payload.get("code") == 0:
message_list = payload.get("data", {}).get("messages", {}).get("message_list", [])
if message_list:
return message_list
time.sleep(interval)
pytest.fail(f"Timed out waiting for memory messages: {last_payload}")
@pytest.mark.p1
def test_memory_crud_cycle(rest_client, create_memory_resource):
    """Create a memory, then list it, read its config, update it and delete it."""
    mem_id = create_memory_resource("restful_memory_crud")
    memory_url = f"/memories/{mem_id}"

    # The new memory must appear in the listing.
    listing = rest_client.get("/memories")
    assert listing.status_code == 200
    listing_body = listing.json()
    assert listing_body["code"] == 0, listing_body
    assert any(entry["id"] == mem_id for entry in listing_body["data"]["memory_list"]), listing_body

    # Its config endpoint must echo the same id.
    config = rest_client.get(f"{memory_url}/config")
    assert config.status_code == 200
    config_body = config.json()
    assert config_body["code"] == 0, config_body
    assert config_body["data"]["id"] == mem_id, config_body

    # Rename and set permissions.
    changes = {"name": f"updated_{uuid.uuid4().hex[:6]}", "permissions": "me"}
    updated = rest_client.put(memory_url, json=changes)
    assert updated.status_code == 200
    updated_body = updated.json()
    assert updated_body["code"] == 0, updated_body

    # Explicit delete.
    removed = rest_client.delete(memory_url)
    assert removed.status_code == 200
    removed_body = removed.json()
    assert removed_body["code"] == 0, removed_body
@pytest.mark.p2
def test_memory_create_missing_required_fields(rest_client):
    """Creating a memory without ``embd_id``/``llm_id`` must return code 101."""
    incomplete_body = {"name": "missing_models", "memory_type": ["raw"]}
    response = rest_client.post("/memories", json=incomplete_body)
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 101, body
@pytest.mark.p1
def test_messages_add_list_recent_content_update_forget(rest_client, create_memory_resource):
    """Add a message, then list it, read its content, update its status and forget it."""
    mem_id = create_memory_resource("restful_message_memory")
    _add_message(
        rest_client,
        mem_id,
        user_input="what is coriander?",
        agent_response="coriander can refer to leaves or seeds",
    )
    stored = _wait_for_memory_messages(rest_client, mem_id)
    msg_id = stored[0]["message_id"]
    message_url = f"/messages/{mem_id}:{msg_id}"

    # Recent-messages listing must include the new message.
    recent = rest_client.get("/messages", params={"memory_id": mem_id, "limit": 10})
    assert recent.status_code == 200
    recent_body = recent.json()
    assert recent_body["code"] == 0, recent_body
    assert any(entry["message_id"] == msg_id for entry in recent_body["data"]), recent_body

    # Content endpoint must return non-empty content.
    content = rest_client.get(f"{message_url}/content")
    assert content.status_code == 200
    content_body = content.json()
    assert content_body["code"] == 0, content_body
    assert content_body["data"]["content"], content_body

    # Status update.
    updated = rest_client.put(message_url, json={"status": False})
    assert updated.status_code == 200
    updated_body = updated.json()
    assert updated_body["code"] == 0, updated_body

    # Forget (delete) the message.
    forgotten = rest_client.delete(message_url)
    assert forgotten.status_code == 200
    forgotten_body = forgotten.json()
    assert forgotten_body["code"] == 0, forgotten_body
@pytest.mark.p2
def test_message_status_validation_requires_boolean(rest_client, create_memory_resource):
    """Updating a message with a string ``status`` must be rejected with code 101."""
    mem_id = create_memory_resource("restful_message_status_validation")
    _add_message(rest_client, mem_id, user_input="hello", agent_response="hello")
    first_message = _wait_for_memory_messages(rest_client, mem_id)[0]
    msg_id = first_message["message_id"]

    response = rest_client.put(f"/messages/{mem_id}:{msg_id}", json={"status": "false"})
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 101, body
    assert "Status must be a boolean." in body["message"], body
@pytest.mark.p2
def test_messages_recent_requires_memory_ids(rest_client):
    """Listing messages without a memory id must return code 101."""
    response = rest_client.get("/messages")
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 101, body
    assert "memory_ids is required" in body["message"], body
@pytest.mark.p2
def test_message_search_route_contract(rest_client, create_memory_resource):
    """Searching messages in a populated memory must return code 0 and a list."""
    mem_id = create_memory_resource("restful_message_search")
    _add_message(
        rest_client,
        mem_id,
        user_input="what is pineapple?",
        agent_response="pineapple is a tropical fruit",
    )
    # Wait until the message is visible before searching.
    _wait_for_memory_messages(rest_client, mem_id)

    search_params = {"memory_id": mem_id, "query": "pineapple", "top_n": 3}
    response = rest_client.get("/messages/search", params=search_params)
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 0, body
    assert isinstance(body["data"], list), body

View File

@@ -0,0 +1,165 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import uuid
import pytest
def _memory_payload(name: str) -> dict:
return {
"name": name,
"memory_type": ["raw"],
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
"llm_id": "glm-4-flash@ZHIPU-AI",
}
def _create_memory(rest_client, name: str) -> dict:
    """Create a memory called *name* and return the ``data`` part of the response.

    Fails the test via ``pytest.fail`` when the response code is not 0.
    """
    response = rest_client.post("/memories", json=_memory_payload(name))
    assert response.status_code == 200
    body = response.json()
    if body["code"] != 0:
        pytest.fail(f"Failed to create memory: {body}")
    return body["data"]
@pytest.fixture
def memory_resource(rest_client):
    """Yield a freshly created memory and delete it after the test.

    The delete is accepted when the payload code is 0 or 404.
    """
    created = _create_memory(rest_client, f"restful_memory_{uuid.uuid4().hex[:8]}")
    created_id = created["id"]
    try:
        yield created
    finally:
        removal = rest_client.delete(f"/memories/{created_id}")
        assert removal.status_code == 200, removal.text
        removal_body = removal.json()
        assert removal_body["code"] in (0, 404), removal_body
@pytest.mark.p2
def test_memory_and_message_routes_require_auth(rest_client_noauth):
    """Unauthenticated GETs on memory and message routes must return 401."""
    # /memories
    memories = rest_client_noauth.get("/memories")
    assert memories.status_code == 401
    memories_body = memories.json()
    assert memories_body["code"] == 401, memories_body

    # /messages
    messages = rest_client_noauth.get("/messages")
    assert messages.status_code == 401
    messages_body = messages.json()
    assert messages_body["code"] == 401, messages_body

    # /messages/search
    search = rest_client_noauth.get("/messages/search")
    assert search.status_code == 401
    search_body = search.json()
    assert search_body["code"] == 401, search_body
@pytest.mark.p2
def test_memory_crud_and_config(rest_client):
    """Create a memory, read its config, find it by keyword, rename it, then delete it."""
    created = _create_memory(rest_client, f"restful_memory_crud_{uuid.uuid4().hex[:8]}")
    created_id = created["id"]
    memory_url = f"/memories/{created_id}"
    try:
        # Config endpoint echoes the id.
        config = rest_client.get(f"{memory_url}/config")
        assert config.status_code == 200
        config_body = config.json()
        assert config_body["code"] == 0, config_body
        assert config_body["data"]["id"] == created_id, config_body

        # Keyword listing finds the memory by its name.
        listing = rest_client.get("/memories", params={"keywords": created["name"]})
        assert listing.status_code == 200
        listing_body = listing.json()
        assert listing_body["code"] == 0, listing_body
        assert any(entry["id"] == created_id for entry in listing_body["data"]["memory_list"]), listing_body

        # Rename.
        renamed = rest_client.put(memory_url, json={"name": "restful_memory_updated"})
        assert renamed.status_code == 200
        renamed_body = renamed.json()
        assert renamed_body["code"] == 0, renamed_body
    finally:
        # Always remove the memory, whatever happened above.
        removal = rest_client.delete(memory_url)
        assert removal.status_code == 200, removal.text
        removal_body = removal.json()
        assert removal_body["code"] in (0, 404), removal_body
@pytest.mark.p2
def test_memory_update_invalid_name(rest_client, memory_resource):
    """Renaming a memory to a whitespace-only name must return code 101."""
    target_id = memory_resource["id"]
    response = rest_client.put(f"/memories/{target_id}", json={"name": " "})
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 101, body
    assert "cannot be empty" in body["message"], body
@pytest.mark.p2
def test_messages_list_and_search_validation_contracts(rest_client, memory_resource):
    """Check list and search responses with and without a memory id."""
    target_id = memory_resource["id"]

    # Listing with a memory id returns a list.
    listing = rest_client.get("/messages", params={"memory_id": target_id, "limit": 10})
    assert listing.status_code == 200
    listing_body = listing.json()
    assert listing_body["code"] == 0, listing_body
    assert isinstance(listing_body["data"], list), listing_body

    # Listing without a memory id is rejected with code 101.
    bare_listing = rest_client.get("/messages")
    assert bare_listing.status_code == 200
    bare_listing_body = bare_listing.json()
    assert bare_listing_body["code"] == 101, bare_listing_body
    assert "memory_ids is required" in bare_listing_body["message"], bare_listing_body

    # Search with a memory id returns a list.
    scoped_search = rest_client.get("/messages/search", params={"memory_id": target_id, "query": "coriander"})
    assert scoped_search.status_code == 200
    scoped_search_body = scoped_search.json()
    assert scoped_search_body["code"] == 0, scoped_search_body
    assert isinstance(scoped_search_body["data"], list), scoped_search_body

    # Search without a memory id also succeeds and returns a list.
    open_search = rest_client.get("/messages/search", params={"query": "coriander"})
    assert open_search.status_code == 200
    open_search_body = open_search.json()
    assert open_search_body["code"] == 0, open_search_body
    assert isinstance(open_search_body["data"], list), open_search_body
@pytest.mark.p2
def test_message_update_forget_and_content_error_contracts(rest_client, memory_resource):
    """Check the error codes of message update, content lookup and forget calls."""
    mem_id = memory_resource["id"]

    # A string instead of a boolean status is an argument error (101).
    bad_status = rest_client.put(f"/messages/{mem_id}:1", json={"status": "false"})
    assert bad_status.status_code == 200
    bad_status_body = bad_status.json()
    assert bad_status_body["code"] == 101, bad_status_body
    assert "Status must be a boolean" in bad_status_body["message"], bad_status_body

    # Content of a message id that does not exist -> 404 in the payload.
    no_content = rest_client.get(f"/messages/{mem_id}:999999/content")
    assert no_content.status_code == 200
    no_content_body = no_content.json()
    assert no_content_body["code"] == 404, no_content_body

    # Forgetting a message of an unknown memory -> 404 in the payload.
    unknown_memory_path = "/messages/missing_memory_id:1"
    forget = rest_client.delete(unknown_memory_path)
    assert forget.status_code == 200
    forget_body = forget.json()
    assert forget_body["code"] == 404, forget_body

    # Updating a message of an unknown memory -> 404 in the payload.
    update = rest_client.put(unknown_memory_path, json={"status": False})
    assert update.status_code == 200
    update_body = update.json()
    assert update_body["code"] == 404, update_body

View File

@@ -0,0 +1,212 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import pytest
def _sse_events(response_text: str) -> list[str]:
return [line[5:] for line in response_text.splitlines() if line.startswith("data:")]
@pytest.mark.p2
@pytest.mark.parametrize(
    "payload, expected_message",
    [
        # extra_body is not an object.
        (
            {
                "model": "model",
                "messages": [{"role": "user", "content": "hello"}],
                "extra_body": "invalid_extra_body",
            },
            "extra_body must be an object.",
        ),
        # reference_metadata is not an object.
        (
            {
                "model": "model",
                "messages": [{"role": "user", "content": "hello"}],
                "extra_body": {"reference_metadata": "invalid_reference_metadata"},
            },
            "reference_metadata must be an object.",
        ),
        # reference_metadata.fields is not an array.
        (
            {
                "model": "model",
                "messages": [{"role": "user", "content": "hello"}],
                "extra_body": {"reference_metadata": {"fields": "author"}},
            },
            "reference_metadata.fields must be an array.",
        ),
        # No messages at all.
        (
            {
                "model": "model",
                "messages": [],
            },
            "You have to provide messages.",
        ),
        # Last message does not come from the user.
        (
            {
                "model": "model",
                "messages": [{"role": "assistant", "content": "hello"}],
            },
            "The last content of this conversation is not from user.",
        ),
    ],
)
def test_openai_compatible_validation_payloads(rest_client, create_chat, payload, expected_message):
    """Each malformed completion payload yields a non-zero code and the expected message."""
    owner_chat = create_chat("restful_openai_validation_chat")

    response = rest_client.post(f"/openai/{owner_chat}/chat/completions", json=payload)
    assert response.status_code == 200

    body = response.json()
    assert body["code"] != 0, body
    assert expected_message in body.get("message", ""), body
@pytest.mark.p2
def test_openai_compatible_metadata_condition_requires_object(rest_client, create_chat):
    """A non-object ``metadata_condition`` is rejected with code 102."""
    owner_chat = create_chat("restful_openai_metadata_condition_chat")
    request_body = {
        "model": "model",
        "messages": [{"role": "user", "content": "hello"}],
        "extra_body": {"metadata_condition": "invalid"},
    }

    response = rest_client.post(f"/openai/{owner_chat}/chat/completions", json=request_body)
    assert response.status_code == 200

    body = response.json()
    assert body["code"] == 102, body
    assert "metadata_condition must be an object." in body["message"], body
@pytest.mark.p2
def test_openai_compatible_invalid_chat(rest_client):
    """Completions against an unknown chat id report a non-zero ownership error."""
    request_body = {
        "model": "model",
        "messages": [{"role": "user", "content": "hello"}],
        "stream": False,
    }

    response = rest_client.post("/openai/invalid_chat_id/chat/completions", json=request_body)
    assert response.status_code == 200

    body = response.json()
    assert body["code"] != 0, body
    assert "don't own the chat" in body["message"], body
@pytest.mark.p2
def test_openai_compatible_nonstream_shape(rest_client, create_chat):
    """A non-streaming completion has the ``chat.completion`` response shape."""
    owner_chat = create_chat("restful_openai_nonstream_chat")
    request_body = {
        "model": "model",
        "messages": [{"role": "user", "content": "hello"}],
        "stream": False,
    }

    response = rest_client.post(
        f"/openai/{owner_chat}/chat/completions",
        json=request_body,
        timeout=60,
    )
    assert response.status_code == 200

    body = response.json()
    assert body["object"] == "chat.completion", body

    # At least one choice, finished normally, authored by the assistant.
    choices = body["choices"]
    assert isinstance(choices, list) and choices, body
    head = choices[0]
    assert head.get("finish_reason") == "stop", body
    assert head.get("message", {}).get("role") == "assistant", body
    assert "content" in head.get("message", {}), body

    # Token accounting must be present and internally consistent.
    usage = body.get("usage", {})
    for counter in ("prompt_tokens", "completion_tokens", "total_tokens"):
        assert counter in usage, usage
    assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"], usage
@pytest.mark.p2
def test_openai_compatible_nonstream_with_reference_output_shape(rest_client, create_chat):
    """A non-streaming completion requested with references returns a ``reference`` list.

    The request enables ``reference`` and ``reference_metadata`` in
    ``extra_body``; the first choice's message must then carry a list under
    the ``reference`` key.
    """
    chat_id = create_chat("restful_openai_reference_chat")
    res = rest_client.post(
        f"/openai/{chat_id}/chat/completions",
        json={
            "model": "model",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": False,
            "extra_body": {
                "reference": True,
                "reference_metadata": {"include": True, "fields": ["author"]},
            },
        },
        timeout=60,
    )
    assert res.status_code == 200
    payload = res.json()
    # Guard before indexing so an error payload (no/empty ``choices``) fails
    # with the payload attached instead of a bare KeyError/IndexError.
    assert payload.get("choices"), payload
    choice_msg = payload["choices"][0]["message"]
    assert "reference" in choice_msg, payload
    assert isinstance(choice_msg["reference"], list), payload
@pytest.mark.p2
def test_openai_compatible_stream_shape_and_done_semantics(rest_client, create_chat):
    """A streaming completion is an SSE stream of JSON chunks ending in ``[DONE]``.

    Checks that the response is ``text/event-stream``, that the last ``data:``
    line is the ``[DONE]`` sentinel, that at least one chunk is a
    ``chat.completion.chunk`` object, and that some chunk finishes with
    ``finish_reason == "stop"``.
    """

    def _finish_reason(chunk):
        # ``dict.get("choices", [{}])`` only falls back when the key is absent;
        # a chunk carrying ``"choices": []`` would make ``[0]`` raise
        # IndexError. ``or`` also covers the empty-list / null cases.
        choices = chunk.get("choices") or [{}]
        return choices[0].get("finish_reason")

    chat_id = create_chat("restful_openai_stream_chat")
    res = rest_client.post(
        f"/openai/{chat_id}/chat/completions",
        json={
            "model": "model",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": True,
            "extra_body": {"reference": True},
        },
        timeout=60,
    )
    assert res.status_code == 200
    content_type = res.headers.get("Content-Type", "")
    assert "text/event-stream" in content_type, content_type

    events = _sse_events(res.text)
    assert events, res.text
    assert events[-1].strip() == "[DONE]", events[-1]

    # Everything except the sentinel must be valid JSON.
    json_events = [json.loads(evt) for evt in events if evt.strip() != "[DONE]"]
    assert json_events, events
    assert any(evt.get("object") == "chat.completion.chunk" for evt in json_events), json_events
    assert any(_finish_reason(evt) == "stop" for evt in json_events), json_events
@pytest.mark.p2
def test_openai_compatible_reference_metadata_fields_filter_accepts_array(rest_client, create_chat):
    """An array-valued ``reference_metadata.fields`` is accepted and references are returned."""
    owner_chat = create_chat("restful_openai_reference_fields_array_chat")
    reference_options = {
        "reference": True,
        "reference_metadata": {"include": True, "fields": ["author", "year"]},
    }
    request_body = {
        "model": "model",
        "messages": [{"role": "user", "content": "hello"}],
        "stream": False,
        "extra_body": reference_options,
    }

    response = rest_client.post(
        f"/openai/{owner_chat}/chat/completions",
        json=request_body,
        timeout=60,
    )
    assert response.status_code == 200

    body = response.json()
    assert body.get("choices"), body
    message = body["choices"][0]["message"]
    assert "reference" in message, body
    assert isinstance(message["reference"], list), body

View File

@@ -0,0 +1,92 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib.util
import sys
from pathlib import Path
from types import ModuleType
import pytest
@pytest.mark.p2
def test_plugin_tools_requires_auth(rest_client_noauth):
    """Listing plugin tools without credentials answers 401 in status and payload."""
    response = rest_client_noauth.get("/plugin/tools")
    assert response.status_code == 401

    body = response.json()
    assert body["code"] == 401, body
@pytest.mark.p2
def test_plugin_tools_contract(rest_client):
    """An authenticated plugin tool listing succeeds and returns a list."""
    response = rest_client.get("/plugin/tools")
    assert response.status_code == 200

    body = response.json()
    assert body["code"] == 0, body
    assert isinstance(body["data"], list), body
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
def _load_plugin_module(monkeypatch):
    """Execute ``api/apps/restful_apis/plugin_api.py`` from disk with stubbed dependencies.

    Stub modules named ``common``, ``api.apps`` and ``agent.plugin`` are
    registered in ``sys.modules`` through ``monkeypatch``. The target file is
    then loaded under the module name ``restful_plugin_api_unit`` with a
    ``_DummyManager`` pre-assigned to its ``manager`` attribute before its
    code runs.

    NOTE(review): presumably the pre-set ``manager`` lets the module's route
    decorators resolve without the real web app; confirm against plugin_api.py.
    """
    project_root = Path(__file__).resolve().parents[3]

    class _StubGlobalPluginManager:
        @staticmethod
        def get_llm_tools():
            return []

    def _register_stub(module_name, **attributes):
        # Build the fake module first, then publish it in sys.modules.
        fake = ModuleType(module_name)
        for attr_name, attr_value in attributes.items():
            setattr(fake, attr_name, attr_value)
        monkeypatch.setitem(sys.modules, module_name, fake)

    # ``common`` is exposed as a package rooted at the repository's directory.
    _register_stub("common", __path__=[str(project_root / "common")])
    # ``login_required`` becomes a pass-through decorator.
    _register_stub("api.apps", login_required=lambda func: func)
    _register_stub("agent.plugin", GlobalPluginManager=_StubGlobalPluginManager)

    source_file = project_root / "api" / "apps" / "restful_apis" / "plugin_api.py"
    spec = importlib.util.spec_from_file_location("restful_plugin_api_unit", source_file)
    loaded = importlib.util.module_from_spec(spec)
    # ``manager`` must exist on the module before its top-level code executes.
    loaded.manager = _DummyManager()
    spec.loader.exec_module(loaded)
    return loaded
@pytest.mark.p2
def test_plugin_tools_metadata_shape_unit(monkeypatch):
    """``llm_tools`` returns the metadata of every tool the plugin manager reports."""
    plugin_api = _load_plugin_module(monkeypatch)

    class _DummyTool:
        def get_metadata(self):
            return {"name": "dummy", "description": "test"}

    # Make the plugin manager report exactly one dummy tool.
    monkeypatch.setattr(
        plugin_api.GlobalPluginManager,
        "get_llm_tools",
        staticmethod(lambda: [_DummyTool()]),
    )

    result = plugin_api.llm_tools()
    assert result["code"] == 0
    assert isinstance(result["data"], list)
    first_tool = result["data"][0]
    assert first_tool["name"] == "dummy"
    assert first_tool["description"] == "test"

View File

@@ -0,0 +1,109 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
@pytest.mark.p1
def test_dataset_search_rest_endpoint(rest_client, ensure_parsed_document):
    """Searching a single dataset succeeds and the result carries ``chunks``."""
    kb_id, _ = ensure_parsed_document()
    query = {"question": "test TXT file", "top_k": 5}

    response = rest_client.post(f"/datasets/{kb_id}/search", json=query)
    assert response.status_code == 200

    body = response.json()
    assert body["code"] == 0, body
    assert "chunks" in body["data"], body
@pytest.mark.p2
def test_multi_dataset_search_rest_endpoint(rest_client, ensure_parsed_document):
    """The multi-dataset search endpoint succeeds and the result carries ``chunks``."""
    kb_id, _ = ensure_parsed_document()
    query = {"dataset_ids": [kb_id], "question": "test TXT file", "top_k": 5}

    response = rest_client.post("/datasets/search", json=query)
    assert response.status_code == 200

    body = response.json()
    assert body["code"] == 0, body
    assert "chunks" in body["data"], body
@pytest.mark.p2
def test_multi_dataset_search_with_metadata_filter(rest_client, ensure_parsed_document):
    """Tag a document with metadata, then search with a manual filter on that key."""
    kb_id, doc_id = ensure_parsed_document()

    # Step 1: set ``author`` on the parsed document.
    metadata_patch = {
        "selector": {"document_ids": [doc_id]},
        "updates": [{"key": "author", "value": "qa_batch2"}],
        "deletes": [],
    }
    tagged = rest_client.patch(f"/datasets/{kb_id}/documents/metadatas", json=metadata_patch)
    assert tagged.status_code == 200
    tagged_body = tagged.json()
    assert tagged_body["code"] == 0, tagged_body

    # Step 2: search with a manual filter matching the value just written.
    author_filter = {
        "method": "manual",
        "logic": "and",
        "manual": [{"key": "author", "op": "=", "value": "qa_batch2"}],
    }
    query = {
        "dataset_ids": [kb_id],
        "question": "test TXT file",
        "meta_data_filter": author_filter,
    }
    response = rest_client.post("/datasets/search", json=query)
    assert response.status_code == 200

    body = response.json()
    assert body["code"] == 0, body
    assert "chunks" in body["data"], body
@pytest.mark.p2
def test_retrieval_compatibility_endpoint(rest_client, ensure_parsed_document):
    """The ``/retrieval`` compatibility route succeeds and the result carries ``chunks``."""
    kb_id, _ = ensure_parsed_document()
    query = {"dataset_ids": [kb_id], "question": "test TXT file", "top_k": 5}

    # /api/v1/retrieval is the SDK compatibility endpoint from api/apps/sdk/doc.py.
    response = rest_client.post("/retrieval", json=query)
    assert response.status_code == 200

    body = response.json()
    assert body["code"] == 0, body
    assert "chunks" in body["data"], body
@pytest.mark.p2
def test_retrieval_compatibility_requires_dataset_ids(rest_client):
    """``/retrieval`` without ``dataset_ids`` is rejected with code 102."""
    response = rest_client.post("/retrieval", json={"question": "test"})
    assert response.status_code == 200

    body = response.json()
    assert body["code"] == 102, body
    assert body["message"] == "`dataset_ids` is required.", body
@pytest.mark.p2
def test_retrieval_compatibility_requires_auth(rest_client_noauth):
    """``/retrieval`` without credentials answers HTTP 401 with the legacy payload."""
    query = {"question": "test", "dataset_ids": ["x"]}

    response = rest_client_noauth.post("/retrieval", json=query)
    assert response.status_code == 401

    body = response.json()
    # token_required preserves legacy payload code/message while returning HTTP 401.
    assert body["code"] == 0, body
    assert body["message"] == "`Authorization` can't be empty", body

View File

@@ -0,0 +1,28 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from configs import VERSION
@pytest.mark.p1
def test_route_not_found_returns_json(rest_client_noauth):
    """An unknown route answers 404 with a JSON body naming the requested path."""
    response = rest_client_noauth.get("/__missing_route__")
    assert response.status_code == 404

    body = response.json()
    expected_message = f"Not Found: /api/{VERSION}/__missing_route__"
    assert body["code"] == 404, body
    assert body["error"] == "Not Found", body
    assert body["message"] == expected_message, body

View File

@@ -0,0 +1,155 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import uuid
import pytest
@pytest.fixture
def search_resource(rest_client):
    """Create a uniquely named search app, yield its id, and delete it afterwards.

    Teardown accepts payload code 0 or 109 from the delete call.
    """
    unique_name = f"restful_search_{uuid.uuid4().hex[:8]}"
    creation_body = {"name": unique_name, "description": "restful search"}

    created = rest_client.post("/searches", json=creation_body)
    assert created.status_code == 200
    created_body = created.json()
    assert created_body["code"] == 0, created_body
    new_search_id = created_body["data"]["search_id"]

    try:
        yield new_search_id
    finally:
        # Runs even when the test body failed.
        removed = rest_client.delete(f"/searches/{new_search_id}")
        assert removed.status_code == 200, removed.text
        removed_body = removed.json()
        assert removed_body["code"] in (0, 109), removed_body
def _sse_events(response_text: str) -> list[str]:
return [line[5:] for line in response_text.splitlines() if line.startswith("data:")]
@pytest.mark.p2
def test_search_routes_require_auth(rest_client_noauth):
    """Creating and listing searches without credentials both answer 401."""
    # Create without auth.
    created = rest_client_noauth.post("/searches", json={"name": "search_noauth"})
    assert created.status_code == 401
    created_body = created.json()
    assert created_body["code"] == 401, created_body

    # List without auth.
    listed = rest_client_noauth.get("/searches")
    assert listed.status_code == 401
    listed_body = listed.json()
    assert listed_body["code"] == 401, listed_body
@pytest.mark.p2
def test_search_crud_contract(rest_client, search_resource):
    """A created search app can be listed, fetched by id and renamed."""
    app_id = search_resource
    app_url = f"/searches/{app_id}"

    # The new app shows up in the listing.
    listed = rest_client.get("/searches")
    assert listed.status_code == 200
    listed_body = listed.json()
    assert listed_body["code"] == 0, listed_body
    assert any(entry.get("id") == app_id for entry in listed_body["data"]["search_apps"]), listed_body

    # Detail lookup returns the same id.
    detail = rest_client.get(app_url)
    assert detail.status_code == 200
    detail_body = detail.json()
    assert detail_body["code"] == 0, detail_body
    assert detail_body["data"]["id"] == app_id, detail_body

    # Renaming echoes the new name back.
    renamed_to = f"search_updated_{uuid.uuid4().hex[:6]}"
    updated = rest_client.put(
        app_url,
        json={"name": renamed_to, "search_config": {"top_k": 3}},
    )
    assert updated.status_code == 200
    updated_body = updated.json()
    assert updated_body["code"] == 0, updated_body
    assert updated_body["data"]["name"] == renamed_to, updated_body
@pytest.mark.p2
def test_search_create_invalid_name(rest_client):
    """Creating a search app with an empty name is rejected with code 102."""
    response = rest_client.post("/searches", json={"name": ""})
    assert response.status_code == 200

    body = response.json()
    assert body["code"] == 102, body
    assert "empty" in body["message"], body
@pytest.mark.p2
def test_search_update_invalid_search_id(rest_client):
    """Updating an unknown search id is refused with code 109."""
    update_body = {"name": "invalid", "search_config": {}}

    response = rest_client.put("/searches/invalid_search_id", json=update_body)
    assert response.status_code == 200

    body = response.json()
    assert body["code"] == 109, body
    assert "No authorization" in body["message"], body
@pytest.mark.p2
def test_search_completion_requires_question(rest_client, search_resource):
    """Both completion routes reject an empty body with code 101 (missing ``question``)."""
    app_id = search_resource

    # The singular and plural routes are expected to behave identically.
    for route in ("completion", "completions"):
        response = rest_client.post(f"/searches/{app_id}/{route}", json={})
        assert response.status_code == 200
        body = response.json()
        assert body["code"] == 101, body
        assert "required argument are missing: question" in body["message"], body
@pytest.mark.p2
def test_search_completion_requires_kb_ids(rest_client, search_resource):
    """Both completion routes reject a question sent without ``kb_ids`` (code 102)."""
    app_id = search_resource
    question_only = {"question": "what is coriander?"}
    endpoints = (
        f"/searches/{app_id}/completion",
        f"/searches/{app_id}/completions",
    )

    for endpoint in endpoints:
        response = rest_client.post(endpoint, json=question_only)
        assert response.status_code == 200
        body = response.json()
        assert body["code"] == 102, body
        assert "`kb_ids` is required" in body["message"], body
@pytest.mark.p2
def test_search_completion_sse_shape_when_kb_ids_provided(rest_client, search_resource):
    """With ``kb_ids`` given, the completion route answers as an SSE stream of JSON events."""
    app_id = search_resource
    request_body = {"question": "what is coriander?", "kb_ids": ["nonexistent_dataset"]}

    # Even with kb_ids provided, runtime may return an error event in-stream, but
    # contract remains SSE with JSON data lines and terminal boolean event.
    response = rest_client.post(
        f"/searches/{app_id}/completion",
        json=request_body,
        timeout=60,
    )
    assert response.status_code == 200
    media_type = response.headers.get("Content-Type", "")
    assert "text/event-stream" in media_type, media_type

    raw_events = _sse_events(response.text)
    assert raw_events, response.text

    decoded = []
    for raw in raw_events:
        decoded.append(json.loads(raw))
    assert isinstance(decoded[0], dict), decoded
    assert decoded[-1].get("data") is True, decoded[-1]

View File

@@ -0,0 +1,219 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import pytest
def _sse_events(response_text: str) -> list[str]:
return [line[5:] for line in response_text.splitlines() if line.startswith("data:")]
@pytest.mark.p1
def test_session_crud_cycle(rest_client, create_chat):
    """Create, list, fetch, rename and delete a chat session, checking each step."""
    owner_chat = create_chat("restful_session_crud_chat")
    sessions_url = f"/chats/{owner_chat}/sessions"

    # Create.
    created = rest_client.post(sessions_url, json={"name": "session_a"})
    assert created.status_code == 200
    created_body = created.json()
    assert created_body["code"] == 0, created_body
    new_session = created_body["data"]["id"]
    assert created_body["data"]["chat_id"] == owner_chat, created_body
    session_url = f"{sessions_url}/{new_session}"

    # List: the new session is present.
    listed = rest_client.get(sessions_url)
    assert listed.status_code == 200
    listed_body = listed.json()
    assert listed_body["code"] == 0, listed_body
    assert any(entry["id"] == new_session for entry in listed_body["data"]), listed_body

    # Fetch by id.
    fetched = rest_client.get(session_url)
    assert fetched.status_code == 200
    fetched_body = fetched.json()
    assert fetched_body["code"] == 0, fetched_body
    assert fetched_body["data"]["id"] == new_session, fetched_body

    # Rename.
    renamed = rest_client.patch(session_url, json={"name": "session_a_updated"})
    assert renamed.status_code == 200
    renamed_body = renamed.json()
    assert renamed_body["code"] == 0, renamed_body
    assert renamed_body["data"]["name"] == "session_a_updated", renamed_body

    # Delete.
    removed = rest_client.delete(sessions_url, json={"ids": [new_session]})
    assert removed.status_code == 200
    removed_body = removed.json()
    assert removed_body["code"] == 0, removed_body

    # List again: the session is gone.
    relisted = rest_client.get(sessions_url)
    assert relisted.status_code == 200
    relisted_body = relisted.json()
    assert relisted_body["code"] == 0, relisted_body
    assert all(entry["id"] != new_session for entry in relisted_body["data"]), relisted_body
@pytest.mark.p2
def test_session_create_name_validation(rest_client, create_chat):
    """Creating a session with a blank name is rejected with code 102."""
    owner_chat = create_chat("restful_session_name_validation_chat")

    response = rest_client.post(f"/chats/{owner_chat}/sessions", json={"name": " "})
    assert response.status_code == 200

    body = response.json()
    assert body["code"] == 102, body
    assert "`name` can not be empty." in body["message"], body
@pytest.mark.p2
def test_session_update_blocks_messages_and_reference(rest_client, create_chat):
    """PATCHing ``messages`` or ``reference`` on a session is refused with code 102."""
    owner_chat = create_chat("restful_session_guard_chat")

    created = rest_client.post(f"/chats/{owner_chat}/sessions", json={"name": "session_guard"})
    assert created.status_code == 200
    created_body = created.json()
    assert created_body["code"] == 0, created_body
    session_url = f"/chats/{owner_chat}/sessions/{created_body['data']['id']}"

    # ``messages`` is read-only.
    messages_attempt = rest_client.patch(session_url, json={"messages": []})
    assert messages_attempt.status_code == 200
    messages_body = messages_attempt.json()
    assert messages_body["code"] == 102, messages_body
    assert "`messages` cannot be changed." in messages_body["message"], messages_body

    # ``reference`` is read-only.
    reference_attempt = rest_client.patch(session_url, json={"reference": []})
    assert reference_attempt.status_code == 200
    reference_body = reference_attempt.json()
    assert reference_body["code"] == 102, reference_body
    assert "`reference` cannot be changed." in reference_body["message"], reference_body
@pytest.mark.p2
def test_chat_recommendation_requires_question(rest_client):
    """``/chat/recommendation`` with an empty body is rejected with code 101."""
    response = rest_client.post("/chat/recommendation", json={})
    assert response.status_code == 200

    body = response.json()
    assert body["code"] == 101, body
    assert "required argument are missing: question" in body["message"], body
@pytest.mark.p2
def test_related_questions_compatibility_requires_auth(rest_client_noauth):
    """A malformed ``Authorization`` header on related_questions yields code 102."""
    bogus_auth = {"Authorization": "invalid"}

    # /api/v1/searchbots/related_questions is an SDK compatibility endpoint.
    response = rest_client_noauth.post(
        "/searchbots/related_questions",
        json={"question": "ragflow"},
        headers=bogus_auth,
    )
    assert response.status_code == 200

    body = response.json()
    assert body["code"] == 102, body
    assert "Authorization is not valid!" in body["message"], body
@pytest.mark.p2
def test_chat_completion_nonstream_with_session(rest_client, create_chat):
    """A non-streaming completion bound to a session echoes that session id back."""
    owner_chat = create_chat("restful_completion_nonstream_chat")

    # Open a session to attach the completion to.
    opened = rest_client.post(f"/chats/{owner_chat}/sessions", json={"name": "session_for_completion"})
    assert opened.status_code == 200
    opened_body = opened.json()
    assert opened_body["code"] == 0, opened_body
    bound_session = opened_body["data"]["id"]

    request_body = {
        "chat_id": owner_chat,
        "session_id": bound_session,
        "messages": [{"role": "user", "content": "hello"}],
        "stream": False,
    }
    answered = rest_client.post("/chat/completions", json=request_body, timeout=60)
    assert answered.status_code == 200

    answered_body = answered.json()
    assert answered_body["code"] == 0, answered_body
    result = answered_body["data"]
    assert isinstance(result, dict), answered_body
    assert result["session_id"] == bound_session, answered_body
    for expected_key in ("answer", "reference"):
        assert expected_key in result, answered_body
@pytest.mark.p2
def test_chat_completion_stream_events(rest_client, create_chat):
    """A streaming completion is an SSE stream of JSON events ending with ``data: true``."""
    owner_chat = create_chat("restful_completion_stream_chat")
    request_body = {
        "chat_id": owner_chat,
        "messages": [{"role": "user", "content": "hello"}],
        "stream": True,
    }

    response = rest_client.post("/chat/completions", json=request_body, timeout=60)
    assert response.status_code == 200
    media_type = response.headers.get("Content-Type", "")
    assert "text/event-stream" in media_type, media_type

    raw_events = _sse_events(response.text)
    assert raw_events, response.text

    # Every data line must be valid JSON.
    decoded = [json.loads(raw) for raw in raw_events]

    def _is_successful_answer(event):
        return event.get("code") == 0 and isinstance(event.get("data"), dict)

    assert any(_is_successful_answer(event) for event in decoded), decoded
    assert decoded[-1].get("data") is True, decoded[-1]
@pytest.mark.p2
def test_chat_completion_validation_errors(rest_client, create_chat):
    """Check three rejected ``/chat/completions`` requests and their error codes."""
    owner_chat = create_chat("restful_completion_validation_chat")
    hello = [{"role": "user", "content": "hello"}]

    # 1) No ``messages`` -> 101.
    no_messages = rest_client.post("/chat/completions", json={"chat_id": owner_chat, "stream": False})
    assert no_messages.status_code == 200
    no_messages_body = no_messages.json()
    assert no_messages_body["code"] == 101, no_messages_body
    assert "required argument are missing: messages" in no_messages_body["message"], no_messages_body

    # 2) ``session_id`` without ``chat_id`` -> 102.
    orphan_session = rest_client.post(
        "/chat/completions",
        json={"session_id": "some_session", "messages": hello, "stream": False},
    )
    assert orphan_session.status_code == 200
    orphan_session_body = orphan_session.json()
    assert orphan_session_body["code"] == 102, orphan_session_body
    assert "`chat_id` is required when `session_id` is provided." in orphan_session_body["message"], orphan_session_body

    # 3) Unknown chat id -> 109.
    unknown_chat = rest_client.post(
        "/chat/completions",
        json={
            "chat_id": "invalid_chat_id",
            "session_id": "invalid_session",
            "messages": hello,
            "stream": False,
        },
    )
    assert unknown_chat.status_code == 200
    unknown_chat_body = unknown_chat.json()
    assert unknown_chat_body["code"] == 109, unknown_chat_body
    assert "No authorization." in unknown_chat_body["message"], unknown_chat_body

View File

@@ -0,0 +1,159 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
@pytest.mark.p1
def test_system_ping(rest_client):
    """``/system/ping`` answers 200 with the plain text ``pong``."""
    response = rest_client.get("/system/ping")
    assert response.status_code == 200
    assert response.text == "pong"
@pytest.mark.p1
def test_system_version(rest_client):
    """``/system/version`` succeeds and returns a non-empty ``data`` value."""
    response = rest_client.get("/system/version")
    assert response.status_code == 200

    body = response.json()
    assert body["code"] == 0, body
    assert body["data"], body
@pytest.mark.p2
def test_system_status_requires_auth(rest_client_noauth):
    """``/system/status`` without credentials answers 401 with an Unauthorized message."""
    response = rest_client_noauth.get("/system/status")
    assert response.status_code == 401

    body = response.json()
    assert body["code"] == 401, body
    assert "Unauthorized" in body["message"], body
@pytest.mark.p2
def test_system_status_contract(rest_client):
    """``/system/status`` reports an entry for each backing component."""
    response = rest_client.get("/system/status")
    assert response.status_code == 200

    body = response.json()
    assert body["code"] == 0, body
    expected_components = ("doc_engine", "storage", "database", "redis")
    for component in expected_components:
        assert component in body["data"], body
@pytest.mark.p2
def test_system_config_no_auth_required(rest_client_noauth):
    """``/system/config`` is readable without credentials and exposes the login flags."""
    response = rest_client_noauth.get("/system/config")
    assert response.status_code == 200

    body = response.json()
    assert body["code"] == 0, body
    for flag in ("registerEnabled", "disablePasswordLogin"):
        assert flag in body["data"], body
@pytest.mark.p2
def test_system_healthz_contract(rest_client_noauth):
    """``/system/healthz`` answers 200 or 500 with a non-empty JSON object."""
    response = rest_client_noauth.get("/system/healthz")
    # Both status codes are accepted by this contract.
    assert response.status_code in (200, 500)

    body = response.json()
    assert isinstance(body, dict), body
    assert body, body
@pytest.mark.p2
def test_system_tokens_auth_and_crud(rest_client, rest_client_noauth):
    """Check auth, create, list and delete behaviour of ``/system/tokens``.

    Steps: an unauthenticated listing is refused with 401; an authenticated
    create returns a token; that token appears in the listing; deleting it
    returns ``data is True``; deleting an unknown token also returns
    ``data is True``.
    """
    unauth_list = rest_client_noauth.get("/system/tokens")
    assert unauth_list.status_code == 401
    unauth_list_payload = unauth_list.json()
    assert unauth_list_payload["code"] == 401, unauth_list_payload

    create_res = rest_client.post("/system/tokens")
    assert create_res.status_code == 200
    create_payload = create_res.json()
    assert create_payload["code"] == 0, create_payload
    token = create_payload["data"]["token"]

    try:
        list_res = rest_client.get("/system/tokens")
        assert list_res.status_code == 200
        list_payload = list_res.json()
        assert list_payload["code"] == 0, list_payload
        assert isinstance(list_payload["data"], list), list_payload
        assert any(item.get("token") == token for item in list_payload["data"]), list_payload
    finally:
        # Always delete the token created above so a failed listing assertion
        # does not leave a stray API token behind on the server under test.
        delete_res = rest_client.delete(f"/system/tokens/{token}")
        assert delete_res.status_code == 200, delete_res.text
        delete_payload = delete_res.json()
        assert delete_payload["code"] == 0, delete_payload
        assert delete_payload["data"] is True, delete_payload

    # Deleting a token that does not exist is reported as success as well.
    delete_missing = rest_client.delete("/system/tokens/missing_token")
    assert delete_missing.status_code == 200
    delete_missing_payload = delete_missing.json()
    assert delete_missing_payload["code"] == 0, delete_missing_payload
    assert delete_missing_payload["data"] is True, delete_missing_payload
@pytest.mark.p2
def test_system_stats_auth_and_shape(rest_client, rest_client_noauth):
    """``/system/stats`` needs auth and returns a list under each expected series key."""
    # Without credentials -> 401.
    anonymous = rest_client_noauth.get("/system/stats")
    assert anonymous.status_code == 401
    anonymous_body = anonymous.json()
    assert anonymous_body["code"] == 401, anonymous_body

    # With credentials -> every series is present and is a list.
    response = rest_client.get("/system/stats")
    assert response.status_code == 200
    body = response.json()
    assert body["code"] == 0, body

    series = body["data"]
    expected_series = ("pv", "uv", "speed", "tokens", "round", "thumb_up")
    for series_name in expected_series:
        assert series_name in series, body
        assert isinstance(series[series_name], list), body
@pytest.mark.p2
def test_system_oceanbase_status_auth_contract(rest_client, rest_client_noauth):
    """Check the auth gate and envelope of ``GET /system/oceanbase/status``.

    Without credentials the route must answer HTTP 401 with body code 401.
    With credentials it must answer HTTP 200, carry a body code of either 0
    or 500, and always include a ``data`` field.
    """
    unauth = rest_client_noauth.get("/system/oceanbase/status")
    assert unauth.status_code == 401
    # Bind the payload so a failing code check reports the whole response body.
    unauth_payload = unauth.json()
    assert unauth_payload["code"] == 401, unauth_payload

    res = rest_client.get("/system/oceanbase/status")
    assert res.status_code == 200
    payload = res.json()
    # NOTE(review): code 500 is tolerated here, presumably for environments
    # where OceanBase is not available — confirm against the route handler.
    assert payload["code"] in (0, 500), payload
    assert "data" in payload, payload
@pytest.mark.p2
def test_system_log_config_routes_auth_and_validation(rest_client, rest_client_noauth):
    """Check auth and input validation on the ``/system/config/log`` routes.

    Covered cases:
      * unauthenticated GET -> HTTP 401, body code 401;
      * authenticated GET -> body code 0 with a dict in ``data``;
      * PUT with an empty JSON body -> body code 102 and a message saying
        that ``pkg_name`` and ``level`` are required;
      * PUT with an unknown level -> body code 102 and an
        "Invalid log level" message.
    """
    unauth = rest_client_noauth.get("/system/config/log")
    assert unauth.status_code == 401
    # Bind the payload so a failing code check reports the whole response body.
    unauth_payload = unauth.json()
    assert unauth_payload["code"] == 401, unauth_payload

    levels = rest_client.get("/system/config/log")
    assert levels.status_code == 200
    levels_payload = levels.json()
    assert levels_payload["code"] == 0, levels_payload
    assert isinstance(levels_payload["data"], dict), levels_payload

    # Validation errors are reported in the body; the HTTP status stays 200.
    missing_body = rest_client.put("/system/config/log", json={})
    assert missing_body.status_code == 200
    missing_payload = missing_body.json()
    assert missing_payload["code"] == 102, missing_payload
    assert "pkg_name and level are required" in missing_payload["message"], missing_payload

    invalid_level = rest_client.put("/system/config/log", json={"pkg_name": "rag", "level": "NOT_A_LEVEL"})
    assert invalid_level.status_code == 200
    invalid_payload = invalid_level.json()
    assert invalid_payload["code"] == 102, invalid_payload
    assert "Invalid log level" in invalid_payload["message"], invalid_payload

View File

@@ -0,0 +1,48 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
@pytest.mark.p2
def test_task_routes_require_auth(rest_client_noauth):
    """Both task routes must reject requests that carry no credentials.

    The cancel route (POST) and the task update route (PATCH) are each
    expected to answer HTTP 401 with body code 401.
    """
    # POST /tasks/<id>/cancel without a token.
    cancel_response = rest_client_noauth.post("/tasks/missing_task/cancel")
    assert cancel_response.status_code == 401
    cancel_body = cancel_response.json()
    assert cancel_body["code"] == 401, cancel_body

    # PATCH /tasks/<id> without a token.
    stop_request = {"action": "stop"}
    patch_response = rest_client_noauth.patch("/tasks/missing_task", json=stop_request)
    assert patch_response.status_code == 401
    patch_body = patch_response.json()
    assert patch_body["code"] == 401, patch_body
@pytest.mark.p2
def test_patch_task_rejects_unsupported_action(rest_client):
    """PATCH on a task with an action other than ``stop`` must be refused.

    The refusal is reported in the body (code 101 plus an explanatory
    message) while the HTTP status remains 200.
    """
    unsupported = {"action": "pause"}
    response = rest_client.patch("/tasks/missing_task", json=unsupported)
    assert response.status_code == 200

    body = response.json()
    assert body["code"] == 101, body
    assert "Only 'stop' is supported" in body["message"], body
@pytest.mark.p2
def test_cancel_missing_task_sets_cancel_contract(rest_client):
    """Cancelling the task id ``missing_task`` must still report success.

    Expected envelope: HTTP 200, body code 0, and ``data`` exactly ``True``.
    """
    response = rest_client.post("/tasks/missing_task/cancel")
    assert response.status_code == 200

    body = response.json()
    assert body["code"] == 0, body
    assert body["data"] is True, body

File diff suppressed because it is too large Load Diff