mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Chore: migrate tests to restful api (#14871)
### What problem does this PR solve? add new testing suite for the new restful api endpoints meant to replace http and web api tests ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Other (please describe): test
This commit is contained in:
163
test/testcases/restful_api/conftest.py
Normal file
163
test/testcases/restful_api/conftest.py
Normal file
@@ -0,0 +1,163 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.auth import RAGFlowHttpApiAuth
|
||||
from test.testcases.restful_api.helpers.client import RestClient
|
||||
from utils.file_utils import create_txt_file
|
||||
from utils import wait_for
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def RestApiAuth(token):
|
||||
return RAGFlowHttpApiAuth(token)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def rest_client(token):
|
||||
return RestClient(token=token)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def rest_client_noauth():
|
||||
return RestClient(token=None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def clear_datasets(rest_client):
|
||||
def _cleanup():
|
||||
res = rest_client.delete("/datasets", json={"ids": None, "delete_all": True})
|
||||
assert res.status_code == 200, res.text
|
||||
payload = res.json()
|
||||
assert payload["code"] in (0, 102), payload
|
||||
|
||||
yield
|
||||
_cleanup()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def clear_chats(rest_client):
|
||||
def _cleanup():
|
||||
res = rest_client.delete("/chats", json={"ids": None, "delete_all": True})
|
||||
assert res.status_code == 200, res.text
|
||||
payload = res.json()
|
||||
assert payload["code"] in (0, 102), payload
|
||||
|
||||
yield
|
||||
_cleanup()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_dataset(rest_client, clear_datasets):
|
||||
created_ids: list[str] = []
|
||||
|
||||
def _create(name: str = "restful_dataset") -> str:
|
||||
res = rest_client.post("/datasets", json={"name": name})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
dataset_id = payload["data"]["id"]
|
||||
created_ids.append(dataset_id)
|
||||
return dataset_id
|
||||
|
||||
yield _create
|
||||
|
||||
if created_ids:
|
||||
res = rest_client.delete("/datasets", json={"ids": created_ids})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
# Dataset may already be removed by test logic/cleanup.
|
||||
assert payload["code"] in (0, 102), payload
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_chat(rest_client, clear_chats):
|
||||
created_ids: list[str] = []
|
||||
|
||||
def _create(name: str = "restful_chat") -> str:
|
||||
res = rest_client.post("/chats", json={"name": name, "dataset_ids": []})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
chat_id = payload["data"]["id"]
|
||||
created_ids.append(chat_id)
|
||||
return chat_id
|
||||
|
||||
yield _create
|
||||
|
||||
if created_ids:
|
||||
res = rest_client.delete("/chats", json={"ids": created_ids})
|
||||
assert res.status_code == 200, res.text
|
||||
payload = res.json()
|
||||
assert payload["code"] in (0, 102), payload
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_document(rest_client, create_dataset, tmp_path):
|
||||
created_docs: list[tuple[str, str]] = []
|
||||
|
||||
def _create(name: str = "restful_doc.txt") -> tuple[str, str]:
|
||||
dataset_id = create_dataset("dataset_for_doc")
|
||||
fp = create_txt_file(tmp_path / name)
|
||||
with fp.open("rb") as file_obj:
|
||||
files = [("file", (fp.name, file_obj))]
|
||||
res = rest_client.post(f"/datasets/{dataset_id}/documents", files=files)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
document_id = payload["data"][0]["id"]
|
||||
created_docs.append((dataset_id, document_id))
|
||||
return dataset_id, document_id
|
||||
|
||||
yield _create
|
||||
|
||||
for dataset_id, document_id in created_docs:
|
||||
res = rest_client.delete(f"/datasets/{dataset_id}/documents", json={"ids": [document_id]})
|
||||
assert res.status_code == 200, res.text
|
||||
payload = res.json()
|
||||
assert payload["code"] in (0, 102), payload
|
||||
|
||||
|
||||
@wait_for(60, 1, "Document parsing timeout in RESTful batch2 tests")
|
||||
def _parsed(rest_client: RestClient, dataset_id: str, document_id: str):
|
||||
res = rest_client.get(f"/datasets/{dataset_id}/documents", params={"id": document_id})
|
||||
if res.status_code != 200:
|
||||
return False
|
||||
payload = res.json()
|
||||
if payload["code"] != 0:
|
||||
return False
|
||||
docs = payload["data"]["docs"]
|
||||
if not docs:
|
||||
return False
|
||||
return docs[0].get("run") == "DONE"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ensure_parsed_document(rest_client, create_document):
|
||||
def _ensure() -> tuple[str, str]:
|
||||
dataset_id, document_id = create_document()
|
||||
res = rest_client.post(
|
||||
f"/datasets/{dataset_id}/documents/parse",
|
||||
json={"document_ids": [document_id]},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
_parsed(rest_client, dataset_id, document_id)
|
||||
return dataset_id, document_id
|
||||
|
||||
return _ensure
|
||||
1
test/testcases/restful_api/helpers/__init__.py
Normal file
1
test/testcases/restful_api/helpers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
85
test/testcases/restful_api/helpers/client.py
Normal file
85
test/testcases/restful_api/helpers/client.py
Normal file
@@ -0,0 +1,85 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from configs import HOST_ADDRESS, VERSION
|
||||
|
||||
|
||||
@dataclass
|
||||
class RestClient:
|
||||
token: str | None = None
|
||||
timeout: int = 30
|
||||
|
||||
@property
|
||||
def api_root(self) -> str:
|
||||
return f"{HOST_ADDRESS}/api/{VERSION}"
|
||||
|
||||
def _headers(self, headers: dict[str, str] | None = None) -> dict[str, str]:
|
||||
merged: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if headers:
|
||||
merged.update(headers)
|
||||
if self.token and "Authorization" not in merged:
|
||||
merged["Authorization"] = f"Bearer {self.token}"
|
||||
return merged
|
||||
|
||||
def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
headers: dict[str, str] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
json: dict[str, Any] | None = None,
|
||||
data: Any = None,
|
||||
files: Any = None,
|
||||
**request_kwargs: Any,
|
||||
) -> requests.Response:
|
||||
req_headers = self._headers(headers)
|
||||
if files is not None:
|
||||
# requests sets multipart boundary automatically.
|
||||
req_headers.pop("Content-Type", None)
|
||||
|
||||
timeout = request_kwargs.pop("timeout", self.timeout)
|
||||
normalized_path = f"/{path.lstrip('/')}" if path else "/"
|
||||
return requests.request(
|
||||
method=method,
|
||||
url=f"{self.api_root}{normalized_path}",
|
||||
headers=req_headers,
|
||||
params=params,
|
||||
json=json,
|
||||
data=data,
|
||||
files=files,
|
||||
timeout=timeout,
|
||||
**request_kwargs,
|
||||
)
|
||||
|
||||
def get(self, path: str, **kwargs) -> requests.Response:
|
||||
return self.request("GET", path, **kwargs)
|
||||
|
||||
def post(self, path: str, **kwargs) -> requests.Response:
|
||||
return self.request("POST", path, **kwargs)
|
||||
|
||||
def delete(self, path: str, **kwargs) -> requests.Response:
|
||||
return self.request("DELETE", path, **kwargs)
|
||||
|
||||
def put(self, path: str, **kwargs) -> requests.Response:
|
||||
return self.request("PUT", path, **kwargs)
|
||||
|
||||
def patch(self, path: str, **kwargs) -> requests.Response:
|
||||
return self.request("PATCH", path, **kwargs)
|
||||
333
test/testcases/restful_api/test_agents.py
Normal file
333
test/testcases/restful_api/test_agents.py
Normal file
@@ -0,0 +1,333 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
MINIMAL_DSL = {
|
||||
"components": {
|
||||
"begin": {
|
||||
"obj": {"component_name": "Begin", "params": {}},
|
||||
"downstream": ["message"],
|
||||
"upstream": [],
|
||||
},
|
||||
"message": {
|
||||
"obj": {"component_name": "Message", "params": {"content": ["{sys.query}"]}},
|
||||
"downstream": [],
|
||||
"upstream": ["begin"],
|
||||
},
|
||||
},
|
||||
"history": [],
|
||||
"retrieval": [],
|
||||
"path": [],
|
||||
"globals": {
|
||||
"sys.query": "",
|
||||
"sys.user_id": "",
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": [],
|
||||
},
|
||||
"variables": {},
|
||||
}
|
||||
|
||||
|
||||
def _sse_events(response_text: str) -> list[str]:
|
||||
return [line[5:] for line in response_text.splitlines() if line.startswith("data:")]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_agent_resource(rest_client):
|
||||
created_agent_ids: list[str] = []
|
||||
|
||||
def _create(title: str = "restful_agent") -> str:
|
||||
res = rest_client.post("/agents", json={"title": title, "dsl": MINIMAL_DSL})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
agent_id = payload["data"]["id"]
|
||||
created_agent_ids.append(agent_id)
|
||||
return agent_id
|
||||
|
||||
yield _create
|
||||
|
||||
cleanup_errors = []
|
||||
for agent_id in created_agent_ids:
|
||||
res = rest_client.delete(f"/agents/{agent_id}")
|
||||
if res.status_code != 200:
|
||||
cleanup_errors.append((agent_id, res.status_code, res.text))
|
||||
continue
|
||||
payload = res.json()
|
||||
if payload["code"] not in (0, 103):
|
||||
cleanup_errors.append((agent_id, res.status_code, payload))
|
||||
assert not cleanup_errors, f"Agent cleanup failed: {cleanup_errors}"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_agents_crud_validation_contract(rest_client, create_agent_resource):
|
||||
list_empty = rest_client.get("/agents", params={"title": "missing_restful_agent"})
|
||||
assert list_empty.status_code == 200
|
||||
list_empty_payload = list_empty.json()
|
||||
assert list_empty_payload["code"] == 0, list_empty_payload
|
||||
assert "canvas" in list_empty_payload["data"], list_empty_payload
|
||||
assert "total" in list_empty_payload["data"], list_empty_payload
|
||||
|
||||
missing_dsl = rest_client.post("/agents", json={"title": "missing_dsl_agent"})
|
||||
assert missing_dsl.status_code == 200
|
||||
missing_dsl_payload = missing_dsl.json()
|
||||
assert missing_dsl_payload["code"] == 101, missing_dsl_payload
|
||||
assert "No DSL data in request" in missing_dsl_payload["message"], missing_dsl_payload
|
||||
|
||||
missing_title = rest_client.post("/agents", json={"dsl": MINIMAL_DSL})
|
||||
assert missing_title.status_code == 200
|
||||
missing_title_payload = missing_title.json()
|
||||
assert missing_title_payload["code"] == 101, missing_title_payload
|
||||
assert "No title in request" in missing_title_payload["message"], missing_title_payload
|
||||
|
||||
agent_id = create_agent_resource("restful_agent_crud")
|
||||
|
||||
duplicate = rest_client.post("/agents", json={"title": "restful_agent_crud", "dsl": MINIMAL_DSL})
|
||||
assert duplicate.status_code == 200
|
||||
duplicate_payload = duplicate.json()
|
||||
assert duplicate_payload["code"] == 102, duplicate_payload
|
||||
assert "already exists" in duplicate_payload["message"], duplicate_payload
|
||||
|
||||
get_res = rest_client.get(f"/agents/{agent_id}")
|
||||
assert get_res.status_code == 200
|
||||
get_payload = get_res.json()
|
||||
assert get_payload["code"] == 0, get_payload
|
||||
assert get_payload["data"]["id"] == agent_id, get_payload
|
||||
|
||||
update_res = rest_client.put(f"/agents/{agent_id}", json={"title": "restful_agent_crud_updated", "dsl": MINIMAL_DSL})
|
||||
assert update_res.status_code == 200
|
||||
update_payload = update_res.json()
|
||||
assert update_payload["code"] == 0, update_payload
|
||||
|
||||
list_after_update = rest_client.get("/agents", params={"title": "restful_agent_crud_updated"})
|
||||
assert list_after_update.status_code == 200
|
||||
list_after_update_payload = list_after_update.json()
|
||||
assert list_after_update_payload["code"] == 0, list_after_update_payload
|
||||
assert list_after_update_payload["data"]["total"] >= 1, list_after_update_payload
|
||||
|
||||
delete_res = rest_client.delete(f"/agents/{agent_id}")
|
||||
assert delete_res.status_code == 200
|
||||
delete_payload = delete_res.json()
|
||||
assert delete_payload["code"] == 0, delete_payload
|
||||
assert delete_payload["data"] is True, delete_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_agent_sessions_crud(rest_client, create_agent_resource):
|
||||
agent_id = create_agent_resource("restful_agent_sessions")
|
||||
|
||||
create_session = rest_client.post(f"/agents/{agent_id}/sessions", json={"name": "agent_session_1"})
|
||||
assert create_session.status_code == 200
|
||||
create_session_payload = create_session.json()
|
||||
assert create_session_payload["code"] == 0, create_session_payload
|
||||
session_id = create_session_payload["data"]["id"]
|
||||
|
||||
list_sessions = rest_client.get(f"/agents/{agent_id}/sessions")
|
||||
assert list_sessions.status_code == 200
|
||||
list_sessions_payload = list_sessions.json()
|
||||
assert list_sessions_payload["code"] == 0, list_sessions_payload
|
||||
assert isinstance(list_sessions_payload["data"], list), list_sessions_payload
|
||||
assert any(item["id"] == session_id for item in list_sessions_payload["data"]), list_sessions_payload
|
||||
|
||||
get_session = rest_client.get(f"/agents/{agent_id}/sessions/{session_id}")
|
||||
assert get_session.status_code == 200
|
||||
get_session_payload = get_session.json()
|
||||
assert get_session_payload["code"] == 0, get_session_payload
|
||||
assert get_session_payload["data"]["id"] == session_id, get_session_payload
|
||||
|
||||
delete_session = rest_client.delete(f"/agents/{agent_id}/sessions/{session_id}")
|
||||
assert delete_session.status_code == 200
|
||||
delete_session_payload = delete_session.json()
|
||||
assert delete_session_payload["code"] == 0, delete_session_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_agent_chat_completion_validation(rest_client):
|
||||
missing_agent_id = rest_client.post(
|
||||
"/agents/chat/completions",
|
||||
json={"query": "hello", "stream": False},
|
||||
)
|
||||
assert missing_agent_id.status_code == 200
|
||||
missing_agent_id_payload = missing_agent_id.json()
|
||||
assert missing_agent_id_payload["code"] == 101, missing_agent_id_payload
|
||||
assert "`agent_id` is required." in missing_agent_id_payload["message"], missing_agent_id_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_agent_chat_completion_nonstream(rest_client, create_agent_resource):
|
||||
agent_id = create_agent_resource("restful_agent_nonstream")
|
||||
create_session = rest_client.post(f"/agents/{agent_id}/sessions", json={"name": "agent_completion_session"})
|
||||
assert create_session.status_code == 200
|
||||
create_session_payload = create_session.json()
|
||||
assert create_session_payload["code"] == 0, create_session_payload
|
||||
session_id = create_session_payload["data"]["id"]
|
||||
|
||||
res = rest_client.post(
|
||||
"/agents/chat/completions",
|
||||
json={"agent_id": agent_id, "query": "hello", "stream": False, "session_id": session_id},
|
||||
timeout=60,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
assert isinstance(payload["data"], dict), payload
|
||||
assert isinstance(payload["data"].get("data"), dict), payload
|
||||
assert "content" in payload["data"]["data"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_agent_chat_completion_stream_structure_and_done(rest_client, create_agent_resource):
|
||||
agent_id = create_agent_resource("restful_agent_stream")
|
||||
create_session = rest_client.post(f"/agents/{agent_id}/sessions", json={"name": "agent_stream_session"})
|
||||
assert create_session.status_code == 200
|
||||
create_session_payload = create_session.json()
|
||||
assert create_session_payload["code"] == 0, create_session_payload
|
||||
session_id = create_session_payload["data"]["id"]
|
||||
|
||||
res = rest_client.post(
|
||||
"/agents/chat/completions",
|
||||
json={
|
||||
"agent_id": agent_id,
|
||||
"query": "hello",
|
||||
"stream": True,
|
||||
"session_id": session_id,
|
||||
"return_trace": True,
|
||||
},
|
||||
timeout=60,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
content_type = res.headers.get("Content-Type", "")
|
||||
assert "text/event-stream" in content_type, content_type
|
||||
|
||||
events = _sse_events(res.text)
|
||||
assert events, res.text
|
||||
assert events[-1].strip() == "[DONE]", events[-1]
|
||||
|
||||
json_events = [json.loads(evt) for evt in events if evt.strip() != "[DONE]"]
|
||||
assert json_events, events
|
||||
assert any(isinstance(evt, dict) for evt in json_events), json_events
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_agent_openai_compatible_mode(rest_client, create_agent_resource):
|
||||
agent_id = create_agent_resource("restful_agent_openai_compat")
|
||||
|
||||
missing_messages = rest_client.post(
|
||||
"/agents/chat/completions",
|
||||
json={"agent_id": agent_id, "openai-compatible": True, "model": "model", "messages": []},
|
||||
)
|
||||
assert missing_messages.status_code == 200
|
||||
missing_messages_payload = missing_messages.json()
|
||||
assert missing_messages_payload["code"] == 102, missing_messages_payload
|
||||
assert "at least one message" in missing_messages_payload["message"], missing_messages_payload
|
||||
|
||||
nonstream = rest_client.post(
|
||||
"/agents/chat/completions",
|
||||
json={
|
||||
"agent_id": agent_id,
|
||||
"openai-compatible": True,
|
||||
"model": "model",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"stream": False,
|
||||
},
|
||||
timeout=60,
|
||||
)
|
||||
assert nonstream.status_code == 200
|
||||
nonstream_payload = nonstream.json()
|
||||
assert isinstance(nonstream_payload, dict), nonstream_payload
|
||||
assert "choices" in nonstream_payload, nonstream_payload
|
||||
|
||||
stream = rest_client.post(
|
||||
"/agents/chat/completions",
|
||||
json={
|
||||
"agent_id": agent_id,
|
||||
"openai-compatible": True,
|
||||
"model": "model",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"stream": True,
|
||||
},
|
||||
timeout=60,
|
||||
)
|
||||
assert stream.status_code == 200
|
||||
stream_content_type = stream.headers.get("Content-Type", "")
|
||||
assert "text/event-stream" in stream_content_type, stream_content_type
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_agent_support_routes_auth_and_contracts(rest_client, rest_client_noauth, create_agent_resource):
|
||||
prompts_unauth = rest_client_noauth.get("/agents/prompts")
|
||||
assert prompts_unauth.status_code == 401
|
||||
assert prompts_unauth.json()["code"] == 401
|
||||
|
||||
prompts = rest_client.get("/agents/prompts")
|
||||
assert prompts.status_code == 200
|
||||
prompts_payload = prompts.json()
|
||||
assert prompts_payload["code"] == 0, prompts_payload
|
||||
assert "task_analysis" in prompts_payload["data"], prompts_payload
|
||||
assert "citation_guidelines" in prompts_payload["data"], prompts_payload
|
||||
|
||||
templates = rest_client.get("/agents/templates")
|
||||
assert templates.status_code == 200
|
||||
templates_payload = templates.json()
|
||||
assert templates_payload["code"] == 0, templates_payload
|
||||
assert isinstance(templates_payload["data"], list), templates_payload
|
||||
|
||||
agent_id = create_agent_resource("restful_agent_support")
|
||||
versions = rest_client.get(f"/agents/{agent_id}/versions")
|
||||
assert versions.status_code == 200
|
||||
versions_payload = versions.json()
|
||||
assert versions_payload["code"] == 0, versions_payload
|
||||
assert isinstance(versions_payload["data"], list), versions_payload
|
||||
|
||||
logs = rest_client.get(f"/agents/{agent_id}/logs/missing_message")
|
||||
assert logs.status_code == 200
|
||||
logs_payload = logs.json()
|
||||
assert logs_payload["code"] == 0, logs_payload
|
||||
assert isinstance(logs_payload["data"], dict), logs_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_agent_webhook_logs_empty_poll_contract(rest_client, create_agent_resource):
|
||||
agent_id = create_agent_resource("restful_agent_webhook_logs")
|
||||
res = rest_client.get(f"/agents/{agent_id}/webhook/logs", params={"since_ts": 0})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
assert payload["data"]["events"] == [], payload
|
||||
assert payload["data"]["finished"] is False, payload
|
||||
assert "next_since_ts" in payload["data"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_agent_db_connection_validates_required_fields(rest_client):
|
||||
res = rest_client.post("/agents/test_db_connection", json={"db_type": "mysql"})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 101, payload
|
||||
assert "required argument are missing" in payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_agent_rerun_requires_required_fields(rest_client):
|
||||
res = rest_client.post("/agents/rerun", json={"id": "flow-1"})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 101, payload
|
||||
assert "required argument are missing" in payload["message"], payload
|
||||
123
test/testcases/restful_api/test_chats.py
Normal file
123
test/testcases/restful_api/test_chats.py
Normal file
@@ -0,0 +1,123 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
class TestChatsAuthorization:
|
||||
def test_create_requires_auth(self, rest_client_noauth):
|
||||
res = rest_client_noauth.post("/chats", json={"name": "chat_auth", "dataset_ids": []})
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_chat_crud_cycle(rest_client, clear_chats):
|
||||
create_res = rest_client.post(
|
||||
"/chats",
|
||||
json={"name": "restful_chat_crud", "dataset_ids": []},
|
||||
)
|
||||
assert create_res.status_code == 200
|
||||
create_payload = create_res.json()
|
||||
assert create_payload["code"] == 0, create_payload
|
||||
chat_id = create_payload["data"]["id"]
|
||||
|
||||
list_res = rest_client.get("/chats", params={"id": chat_id})
|
||||
assert list_res.status_code == 200
|
||||
list_payload = list_res.json()
|
||||
assert list_payload["code"] == 0, list_payload
|
||||
chats = list_payload["data"]["chats"]
|
||||
assert len(chats) == 1, list_payload
|
||||
assert chats[0]["id"] == chat_id, list_payload
|
||||
|
||||
get_res = rest_client.get(f"/chats/{chat_id}")
|
||||
assert get_res.status_code == 200
|
||||
get_payload = get_res.json()
|
||||
assert get_payload["code"] == 0, get_payload
|
||||
assert get_payload["data"]["id"] == chat_id, get_payload
|
||||
|
||||
update_res = rest_client.put(
|
||||
f"/chats/{chat_id}",
|
||||
json={"name": "restful_chat_crud_updated", "dataset_ids": []},
|
||||
)
|
||||
assert update_res.status_code == 200
|
||||
update_payload = update_res.json()
|
||||
assert update_payload["code"] == 0, update_payload
|
||||
assert update_payload["data"]["name"] == "restful_chat_crud_updated", update_payload
|
||||
|
||||
patch_res = rest_client.patch(f"/chats/{chat_id}", json={"name": "restful_chat_crud_patched"})
|
||||
assert patch_res.status_code == 200
|
||||
patch_payload = patch_res.json()
|
||||
assert patch_payload["code"] == 0, patch_payload
|
||||
assert patch_payload["data"]["name"] == "restful_chat_crud_patched", patch_payload
|
||||
|
||||
delete_res = rest_client.delete("/chats", json={"ids": [chat_id]})
|
||||
assert delete_res.status_code == 200
|
||||
delete_payload = delete_res.json()
|
||||
assert delete_payload["code"] == 0, delete_payload
|
||||
assert delete_payload["data"]["success_count"] == 1, delete_payload
|
||||
|
||||
list_after_delete = rest_client.get("/chats", params={"id": chat_id})
|
||||
assert list_after_delete.status_code == 200
|
||||
list_after_delete_payload = list_after_delete.json()
|
||||
assert list_after_delete_payload["code"] == 0, list_after_delete_payload
|
||||
assert list_after_delete_payload["data"]["chats"] == [], list_after_delete_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"name, expected_fragment",
|
||||
[
|
||||
("", "`name` is required."),
|
||||
(" ", "`name` is required."),
|
||||
],
|
||||
)
|
||||
def test_chat_create_name_validation(rest_client, clear_chats, name, expected_fragment):
|
||||
res = rest_client.post("/chats", json={"name": name, "dataset_ids": []})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 102, payload
|
||||
assert expected_fragment in payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chat_duplicate_name_validation(rest_client, clear_chats):
|
||||
first = rest_client.post("/chats", json={"name": "duplicate_chat_name", "dataset_ids": []})
|
||||
assert first.status_code == 200
|
||||
first_payload = first.json()
|
||||
assert first_payload["code"] == 0, first_payload
|
||||
|
||||
second = rest_client.post("/chats", json={"name": "duplicate_chat_name", "dataset_ids": []})
|
||||
assert second.status_code == 200
|
||||
second_payload = second.json()
|
||||
assert second_payload["code"] == 102, second_payload
|
||||
assert "Duplicated chat name" in second_payload["message"], second_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chat_list_pagination(rest_client, clear_chats):
|
||||
for i in range(3):
|
||||
res = rest_client.post("/chats", json={"name": f"chat_page_{i}", "dataset_ids": []})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
|
||||
page_res = rest_client.get("/chats", params={"page": 1, "page_size": 2, "orderby": "create_time", "desc": "true"})
|
||||
assert page_res.status_code == 200
|
||||
page_payload = page_res.json()
|
||||
assert page_payload["code"] == 0, page_payload
|
||||
assert len(page_payload["data"]["chats"]) == 2, page_payload
|
||||
assert page_payload["data"]["total"] >= 3, page_payload
|
||||
124
test/testcases/restful_api/test_chunks.py
Normal file
124
test/testcases/restful_api/test_chunks.py
Normal file
@@ -0,0 +1,124 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _assert_created_chunk_id(payload):
|
||||
chunk_id = payload["data"]["chunk"].get("id")
|
||||
assert chunk_id, payload
|
||||
assert isinstance(chunk_id, str), payload
|
||||
assert chunk_id.strip(), payload
|
||||
return chunk_id
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_chunks_add_list_get_update_delete_cycle(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_cycle.txt")
|
||||
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
|
||||
add_res = rest_client.post(
|
||||
base_path,
|
||||
json={"content": "batch2 chunk content", "important_keywords": ["batch2"], "questions": ["what is batch2?"]},
|
||||
)
|
||||
assert add_res.status_code == 200
|
||||
add_payload = add_res.json()
|
||||
assert add_payload["code"] == 0, add_payload
|
||||
chunk_id = _assert_created_chunk_id(add_payload)
|
||||
|
||||
list_res = rest_client.get(base_path, params={"id": chunk_id})
|
||||
assert list_res.status_code == 200
|
||||
list_payload = list_res.json()
|
||||
assert list_payload["code"] == 0, list_payload
|
||||
assert list_payload["data"]["total"] == 1, list_payload
|
||||
assert list_payload["data"]["chunks"][0]["id"] == chunk_id, list_payload
|
||||
|
||||
get_res = rest_client.get(f"{base_path}/{chunk_id}")
|
||||
assert get_res.status_code == 200
|
||||
get_payload = get_res.json()
|
||||
assert get_payload["code"] == 0, get_payload
|
||||
assert get_payload["data"]["id"] == chunk_id, get_payload
|
||||
|
||||
update_res = rest_client.patch(
|
||||
f"{base_path}/{chunk_id}",
|
||||
json={"content": "batch2 chunk content updated"},
|
||||
)
|
||||
assert update_res.status_code == 200
|
||||
update_payload = update_res.json()
|
||||
assert update_payload["code"] == 0, update_payload
|
||||
|
||||
get_updated_res = rest_client.get(f"{base_path}/{chunk_id}")
|
||||
assert get_updated_res.status_code == 200
|
||||
get_updated_payload = get_updated_res.json()
|
||||
assert get_updated_payload["code"] == 0, get_updated_payload
|
||||
assert get_updated_payload["data"]["content_with_weight"] == "batch2 chunk content updated", get_updated_payload
|
||||
|
||||
delete_candidate_res = rest_client.post(base_path, json={"content": "batch2 chunk content to delete"})
|
||||
assert delete_candidate_res.status_code == 200
|
||||
delete_candidate_payload = delete_candidate_res.json()
|
||||
assert delete_candidate_payload["code"] == 0, delete_candidate_payload
|
||||
delete_candidate_id = _assert_created_chunk_id(delete_candidate_payload)
|
||||
|
||||
delete_res = rest_client.delete(base_path, json={"chunk_ids": [delete_candidate_id]})
|
||||
assert delete_res.status_code == 200
|
||||
delete_payload = delete_res.json()
|
||||
assert delete_payload["code"] == 0, delete_payload
|
||||
|
||||
deleted_list_res = rest_client.get(base_path, params={"id": delete_candidate_id})
|
||||
assert deleted_list_res.status_code == 200
|
||||
deleted_list_payload = deleted_list_res.json()
|
||||
assert deleted_list_payload["code"] != 0, deleted_list_payload
|
||||
|
||||
deleted_get_res = rest_client.get(f"{base_path}/{delete_candidate_id}")
|
||||
assert deleted_get_res.status_code == 200
|
||||
deleted_get_payload = deleted_get_res.json()
|
||||
assert deleted_get_payload["code"] != 0, deleted_get_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunks_add_requires_content(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_requires_content.txt")
|
||||
res = rest_client.post(
|
||||
f"/datasets/{dataset_id}/documents/{document_id}/chunks",
|
||||
json={"content": " "},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 102, payload
|
||||
assert payload["message"] == "`content` is required", payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunks_list_empty_document(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_list_empty.txt")
|
||||
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
list_res = rest_client.get(base_path)
|
||||
assert list_res.status_code == 200
|
||||
list_payload = list_res.json()
|
||||
assert list_payload["code"] == 0, list_payload
|
||||
assert "chunks" in list_payload["data"], list_payload
|
||||
assert "doc" in list_payload["data"], list_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chunks_delete_partial_invalid(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("chunk_delete_partial.txt")
|
||||
base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
delete_res = rest_client.delete(base_path, json={"chunk_ids": ["invalid_chunk_id"]})
|
||||
assert delete_res.status_code == 200
|
||||
delete_payload = delete_res.json()
|
||||
assert delete_payload["code"] == 102, delete_payload
|
||||
assert "expect 1" in delete_payload["message"], delete_payload
|
||||
718
test/testcases/restful_api/test_connector_routes_unit.py
Normal file
718
test/testcases/restful_api/test_connector_routes_unit.py
Normal file
@@ -0,0 +1,718 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _AwaitableValue:
|
||||
def __init__(self, value):
|
||||
self._value = value
|
||||
|
||||
def __await__(self):
|
||||
async def _co():
|
||||
return self._value
|
||||
|
||||
return _co().__await__()
|
||||
|
||||
|
||||
class _Args(dict):
|
||||
def get(self, key, default=None, type=None):
|
||||
value = super().get(key, default)
|
||||
if type is None:
|
||||
return value
|
||||
try:
|
||||
return type(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
def to_dict(self, flat=True):
|
||||
return dict(self)
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, body, status_code):
|
||||
self.body = body
|
||||
self.status_code = status_code
|
||||
self.headers = {}
|
||||
|
||||
|
||||
class _FakeConnectorRecord:
|
||||
def __init__(self, payload):
|
||||
self._payload = payload
|
||||
|
||||
def to_dict(self):
|
||||
return dict(self._payload)
|
||||
|
||||
|
||||
class _FakeCredentials:
|
||||
def __init__(self, raw='{"refresh_token":"rt","access_token":"at"}'):
|
||||
self._raw = raw
|
||||
|
||||
def to_json(self):
|
||||
return self._raw
|
||||
|
||||
|
||||
class _FakeFlow:
|
||||
def __init__(self, client_config, scopes):
|
||||
self.client_config = client_config
|
||||
self.scopes = scopes
|
||||
self.redirect_uri = None
|
||||
self.credentials = _FakeCredentials()
|
||||
self.auth_kwargs = None
|
||||
self.token_code = None
|
||||
self.token_code_verifier = None
|
||||
self.code_verifier = "fake-code-verifier"
|
||||
|
||||
def authorization_url(self, **kwargs):
|
||||
self.auth_kwargs = dict(kwargs)
|
||||
return f"https://oauth.example/{kwargs['state']}", kwargs["state"]
|
||||
|
||||
def fetch_token(self, code, code_verifier=None):
|
||||
self.token_code = code
|
||||
self.token_code_verifier = code_verifier
|
||||
|
||||
|
||||
class _FakeBoxToken:
|
||||
def __init__(self, access_token, refresh_token):
|
||||
self.access_token = access_token
|
||||
self.refresh_token = refresh_token
|
||||
|
||||
|
||||
class _FakeBoxOAuth:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.exchange_code = None
|
||||
|
||||
def get_authorize_url(self, options):
|
||||
return f"https://box.example/auth?state={options.state}&redirect={options.redirect_uri}"
|
||||
|
||||
def get_tokens_authorization_code_grant(self, code):
|
||||
self.exchange_code = code
|
||||
|
||||
def retrieve_token(self):
|
||||
return _FakeBoxToken("box-access", "box-refresh")
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
def __init__(self):
|
||||
self.store = {}
|
||||
self.set_calls = []
|
||||
self.deleted = []
|
||||
|
||||
def get(self, key):
|
||||
return self.store.get(key)
|
||||
|
||||
def set_obj(self, key, obj, ttl):
|
||||
self.set_calls.append((key, obj, ttl))
|
||||
self.store[key] = json.dumps(obj)
|
||||
|
||||
def delete(self, key):
|
||||
self.deleted.append(key)
|
||||
self.store.pop(key, None)
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _set_request(module, *, args=None, json_body=None):
|
||||
module.request = SimpleNamespace(
|
||||
args=_Args(args or {}),
|
||||
json=_AwaitableValue({} if json_body is None else json_body),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def auth():
|
||||
return "unit-auth"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_tenant_info():
|
||||
return None
|
||||
|
||||
|
||||
def _load_connector_app(monkeypatch):
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
|
||||
api_pkg = ModuleType("api")
|
||||
api_pkg.__path__ = [str(repo_root / "api")]
|
||||
monkeypatch.setitem(sys.modules, "api", api_pkg)
|
||||
|
||||
apps_mod = ModuleType("api.apps")
|
||||
apps_mod.__path__ = [str(repo_root / "api" / "apps")]
|
||||
apps_mod.current_user = SimpleNamespace(id="tenant-1")
|
||||
apps_mod.login_required = lambda fn: fn
|
||||
monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
|
||||
|
||||
db_mod = ModuleType("api.db")
|
||||
db_mod.InputType = SimpleNamespace(POLL="POLL")
|
||||
monkeypatch.setitem(sys.modules, "api.db", db_mod)
|
||||
|
||||
services_pkg = ModuleType("api.db.services")
|
||||
services_pkg.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
|
||||
|
||||
connector_service_mod = ModuleType("api.db.services.connector_service")
|
||||
|
||||
class _StubConnectorService:
|
||||
@staticmethod
|
||||
def update_by_id(*_args, **_kwargs):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def save(**_kwargs):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(_connector_id):
|
||||
return True, _FakeConnectorRecord({"id": _connector_id})
|
||||
|
||||
@staticmethod
|
||||
def list(_tenant_id):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def resume(*_args, **_kwargs):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def rebuild(*_args, **_kwargs):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def delete_by_id(*_args, **_kwargs):
|
||||
return True
|
||||
|
||||
class _StubSyncLogsService:
|
||||
@staticmethod
|
||||
def list_sync_tasks(*_args, **_kwargs):
|
||||
return [], 0
|
||||
|
||||
connector_service_mod.ConnectorService = _StubConnectorService
|
||||
connector_service_mod.SyncLogsService = _StubSyncLogsService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.connector_service", connector_service_mod)
|
||||
|
||||
api_utils_mod = ModuleType("api.utils.api_utils")
|
||||
|
||||
async def _get_request_json():
|
||||
return {}
|
||||
|
||||
api_utils_mod.get_request_json = _get_request_json
|
||||
api_utils_mod.get_json_result = lambda data=None, message="", code=0: {
|
||||
"code": code,
|
||||
"message": message,
|
||||
"data": data,
|
||||
}
|
||||
api_utils_mod.get_data_error_result = lambda message="", code=400, data=None: {
|
||||
"code": code,
|
||||
"message": message,
|
||||
"data": data,
|
||||
}
|
||||
api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn)
|
||||
monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
|
||||
|
||||
constants_mod = ModuleType("common.constants")
|
||||
constants_mod.RetCode = SimpleNamespace(
|
||||
ARGUMENT_ERROR=101,
|
||||
SERVER_ERROR=500,
|
||||
RUNNING=102,
|
||||
PERMISSION_ERROR=403,
|
||||
)
|
||||
constants_mod.TaskStatus = SimpleNamespace(SCHEDULE="schedule", CANCEL="cancel")
|
||||
monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
|
||||
|
||||
config_mod = ModuleType("common.data_source.config")
|
||||
config_mod.GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = "https://example.com/drive"
|
||||
config_mod.GMAIL_WEB_OAUTH_REDIRECT_URI = "https://example.com/gmail"
|
||||
config_mod.BOX_WEB_OAUTH_REDIRECT_URI = "https://example.com/box"
|
||||
config_mod.DocumentSource = SimpleNamespace(GMAIL="gmail", GOOGLE_DRIVE="google-drive")
|
||||
monkeypatch.setitem(sys.modules, "common.data_source.config", config_mod)
|
||||
|
||||
google_constants_mod = ModuleType("common.data_source.google_util.constant")
|
||||
google_constants_mod.WEB_OAUTH_POPUP_TEMPLATE = (
|
||||
"<html><head><title>{title}</title></head>"
|
||||
"<body><h1>{heading}</h1><p>{message}</p><script>{payload_json}</script><script>{auto_close}</script></body></html>"
|
||||
)
|
||||
google_constants_mod.GOOGLE_SCOPES = {
|
||||
config_mod.DocumentSource.GMAIL: ["scope-gmail"],
|
||||
config_mod.DocumentSource.GOOGLE_DRIVE: ["scope-drive"],
|
||||
}
|
||||
monkeypatch.setitem(sys.modules, "common.data_source.google_util.constant", google_constants_mod)
|
||||
|
||||
misc_mod = ModuleType("common.misc_utils")
|
||||
misc_mod.get_uuid = lambda: "uuid-from-helper"
|
||||
monkeypatch.setitem(sys.modules, "common.misc_utils", misc_mod)
|
||||
|
||||
rag_pkg = ModuleType("rag")
|
||||
rag_pkg.__path__ = [str(repo_root / "rag")]
|
||||
monkeypatch.setitem(sys.modules, "rag", rag_pkg)
|
||||
|
||||
rag_utils_pkg = ModuleType("rag.utils")
|
||||
rag_utils_pkg.__path__ = [str(repo_root / "rag" / "utils")]
|
||||
monkeypatch.setitem(sys.modules, "rag.utils", rag_utils_pkg)
|
||||
|
||||
redis_mod = ModuleType("rag.utils.redis_conn")
|
||||
redis_mod.REDIS_CONN = _FakeRedis()
|
||||
monkeypatch.setitem(sys.modules, "rag.utils.redis_conn", redis_mod)
|
||||
|
||||
quart_mod = ModuleType("quart")
|
||||
quart_mod.request = SimpleNamespace(args=_Args(), json=_AwaitableValue({}))
|
||||
|
||||
async def _make_response(body, status_code):
|
||||
return _FakeResponse(body, status_code)
|
||||
|
||||
quart_mod.make_response = _make_response
|
||||
monkeypatch.setitem(sys.modules, "quart", quart_mod)
|
||||
|
||||
google_pkg = ModuleType("google_auth_oauthlib")
|
||||
google_pkg.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, "google_auth_oauthlib", google_pkg)
|
||||
|
||||
google_flow_mod = ModuleType("google_auth_oauthlib.flow")
|
||||
|
||||
class _StubFlow:
|
||||
@classmethod
|
||||
def from_client_config(cls, client_config, scopes):
|
||||
return _FakeFlow(client_config, scopes)
|
||||
|
||||
google_flow_mod.Flow = _StubFlow
|
||||
monkeypatch.setitem(sys.modules, "google_auth_oauthlib.flow", google_flow_mod)
|
||||
|
||||
box_mod = ModuleType("box_sdk_gen")
|
||||
|
||||
class _OAuthConfig:
|
||||
def __init__(self, client_id, client_secret):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
|
||||
class _GetAuthorizeUrlOptions:
|
||||
def __init__(self, redirect_uri, state):
|
||||
self.redirect_uri = redirect_uri
|
||||
self.state = state
|
||||
|
||||
box_mod.BoxOAuth = _FakeBoxOAuth
|
||||
box_mod.OAuthConfig = _OAuthConfig
|
||||
box_mod.GetAuthorizeUrlOptions = _GetAuthorizeUrlOptions
|
||||
monkeypatch.setitem(sys.modules, "box_sdk_gen", box_mod)
|
||||
|
||||
module_path = repo_root / "api" / "apps" / "restful_apis" / "connector_api.py"
|
||||
spec = importlib.util.spec_from_file_location("test_connector_routes_unit", module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
module.manager = _DummyManager()
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_connector_basic_routes_and_task_controls(monkeypatch):
|
||||
module = _load_connector_app(monkeypatch)
|
||||
|
||||
async def _no_sleep(_secs):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(module.asyncio, "sleep", _no_sleep)
|
||||
|
||||
records = {"conn-1": _FakeConnectorRecord({"id": "conn-1", "source": "drive"})}
|
||||
update_calls = []
|
||||
save_calls = []
|
||||
resume_calls = []
|
||||
delete_calls = []
|
||||
|
||||
monkeypatch.setattr(module.ConnectorService, "update_by_id", lambda cid, payload: update_calls.append((cid, payload)))
|
||||
|
||||
def _save(**payload):
|
||||
save_calls.append(payload)
|
||||
records[payload["id"]] = _FakeConnectorRecord(payload)
|
||||
|
||||
monkeypatch.setattr(module.ConnectorService, "save", _save)
|
||||
monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda cid: (True, records[cid]))
|
||||
monkeypatch.setattr(module.ConnectorService, "list", lambda tenant_id: [{"id": "listed", "tenant": tenant_id}])
|
||||
monkeypatch.setattr(module.SyncLogsService, "list_sync_tasks", lambda cid, page, page_size: ([{"id": "log-1"}], 9))
|
||||
monkeypatch.setattr(module.ConnectorService, "resume", lambda cid, status: resume_calls.append((cid, status)))
|
||||
monkeypatch.setattr(module.ConnectorService, "delete_by_id", lambda cid: delete_calls.append(cid))
|
||||
monkeypatch.setattr(module, "get_uuid", lambda: "generated-id")
|
||||
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_request_json",
|
||||
lambda: _AwaitableValue({"id": "conn-1", "refresh_freq": 7, "config": {"x": 1}}),
|
||||
)
|
||||
res = _run(module.update_connector("conn-1"))
|
||||
assert update_calls == [("conn-1", {'id': 'conn-1', "refresh_freq": 7, "config": {"x": 1}})]
|
||||
assert res["data"]["id"] == "conn-1"
|
||||
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_request_json",
|
||||
lambda: _AwaitableValue({"name": "new", "source": "gmail", "config": {"y": 2}}),
|
||||
)
|
||||
res = _run(module.create_connector())
|
||||
assert save_calls[-1]["id"] == "generated-id"
|
||||
assert save_calls[-1]["tenant_id"] == "tenant-1"
|
||||
assert save_calls[-1]["input_type"] == module.InputType.POLL
|
||||
assert res["data"]["id"] == "generated-id"
|
||||
|
||||
list_res = module.list_connector()
|
||||
assert list_res["data"] == [{"id": "listed", "tenant": "tenant-1"}]
|
||||
|
||||
monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda _cid: (False, None))
|
||||
missing_res = module.get_connector("missing")
|
||||
assert missing_res["message"] == "Can't find this Connector!"
|
||||
|
||||
monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda cid: (True, _FakeConnectorRecord({"id": cid})))
|
||||
found_res = module.get_connector("conn-2")
|
||||
assert found_res["data"]["id"] == "conn-2"
|
||||
|
||||
_set_request(module, args={"page": "2", "page_size": "7"})
|
||||
logs_res = module.list_logs("conn-log")
|
||||
assert logs_res["data"] == {"total": 9, "logs": [{"id": "log-1"}]}
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": True}))
|
||||
assert _run(module.resume("conn-r1"))["data"] is True
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": False}))
|
||||
assert _run(module.resume("conn-r2"))["data"] is True
|
||||
assert ("conn-r1", module.TaskStatus.SCHEDULE) in resume_calls
|
||||
assert ("conn-r2", module.TaskStatus.CANCEL) in resume_calls
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"kb_id": "kb-1"}))
|
||||
monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: "rebuild-failed")
|
||||
failed_rebuild = _run(module.rebuild("conn-rb"))
|
||||
assert failed_rebuild["code"] == module.RetCode.SERVER_ERROR
|
||||
assert failed_rebuild["data"] is False
|
||||
|
||||
monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: None)
|
||||
ok_rebuild = _run(module.rebuild("conn-rb"))
|
||||
assert ok_rebuild["data"] is True
|
||||
|
||||
rm_res = module.rm_connector("conn-rm")
|
||||
assert rm_res["data"] is True
|
||||
assert ("conn-rm", module.TaskStatus.CANCEL) in resume_calls
|
||||
assert delete_calls == ["conn-rm"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_connector_oauth_helper_functions(monkeypatch):
|
||||
module = _load_connector_app(monkeypatch)
|
||||
|
||||
assert module._web_state_cache_key("flow-a", "gmail") == "gmail_web_flow_state:flow-a"
|
||||
assert module._web_result_cache_key("flow-b", "google-drive") == "google-drive_web_flow_result:flow-b"
|
||||
|
||||
creds_dict = {"web": {"client_id": "id"}}
|
||||
assert module._load_credentials(creds_dict) == creds_dict
|
||||
assert module._load_credentials(json.dumps(creds_dict)) == creds_dict
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid Google credentials JSON"):
|
||||
module._load_credentials("{not-json")
|
||||
|
||||
assert module._get_web_client_config(creds_dict) == {"web": {"client_id": "id"}}
|
||||
with pytest.raises(ValueError, match="must include a 'web'"):
|
||||
module._get_web_client_config({"installed": {"client_id": "id"}})
|
||||
|
||||
popup_ok = _run(module._render_web_oauth_popup("flow-1", True, "done", "gmail"))
|
||||
assert popup_ok.status_code == 200
|
||||
assert popup_ok.headers["Content-Type"] == "text/html; charset=utf-8"
|
||||
assert "Authorization complete" in popup_ok.body
|
||||
assert "ragflow-gmail-oauth" in popup_ok.body
|
||||
|
||||
popup_error = _run(module._render_web_oauth_popup("flow-2", False, "<denied>", "google-drive"))
|
||||
assert popup_error.status_code == 200
|
||||
assert "Authorization failed" in popup_error.body
|
||||
assert "<denied>" in popup_error.body
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_start_google_web_oauth_matrix(monkeypatch):
|
||||
module = _load_connector_app(monkeypatch)
|
||||
|
||||
redis = _FakeRedis()
|
||||
monkeypatch.setattr(module, "REDIS_CONN", redis)
|
||||
monkeypatch.setattr(module.time, "time", lambda: 1700000000)
|
||||
|
||||
flow_calls = []
|
||||
|
||||
def _from_client_config(client_config, scopes):
|
||||
flow = _FakeFlow(client_config, scopes)
|
||||
flow_calls.append(flow)
|
||||
return flow
|
||||
|
||||
monkeypatch.setattr(module.Flow, "from_client_config", staticmethod(_from_client_config))
|
||||
|
||||
_set_request(module, args={"type": "invalid"})
|
||||
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": "{}"}))
|
||||
invalid_type = _run(module.start_google_web_oauth())
|
||||
assert invalid_type["code"] == module.RetCode.ARGUMENT_ERROR
|
||||
|
||||
monkeypatch.setattr(module, "GMAIL_WEB_OAUTH_REDIRECT_URI", "")
|
||||
_set_request(module, args={"type": "gmail"})
|
||||
missing_redirect = _run(module.start_google_web_oauth())
|
||||
assert missing_redirect["code"] == module.RetCode.SERVER_ERROR
|
||||
|
||||
monkeypatch.setattr(module, "GMAIL_WEB_OAUTH_REDIRECT_URI", "https://example.com/gmail")
|
||||
monkeypatch.setattr(module, "GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "https://example.com/drive")
|
||||
|
||||
_set_request(module, args={"type": "google-drive"})
|
||||
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": "{invalid-json"}))
|
||||
invalid_credentials = _run(module.start_google_web_oauth())
|
||||
assert invalid_credentials["code"] == module.RetCode.ARGUMENT_ERROR
|
||||
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_request_json",
|
||||
lambda: _AwaitableValue({"credentials": json.dumps({"web": {"client_id": "id"}, "refresh_token": "rt"})}),
|
||||
)
|
||||
has_refresh_token = _run(module.start_google_web_oauth())
|
||||
assert has_refresh_token["code"] == module.RetCode.ARGUMENT_ERROR
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": json.dumps({"installed": {"x": 1}})}))
|
||||
missing_web = _run(module.start_google_web_oauth())
|
||||
assert missing_web["code"] == module.RetCode.ARGUMENT_ERROR
|
||||
|
||||
ids = iter(["flow-gmail", "flow-drive"])
|
||||
monkeypatch.setattr(module.uuid, "uuid4", lambda: next(ids))
|
||||
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_request_json",
|
||||
lambda: _AwaitableValue({"credentials": json.dumps({"web": {"client_id": "id", "client_secret": "secret"}})}),
|
||||
)
|
||||
|
||||
_set_request(module, args={"type": "gmail"})
|
||||
gmail_ok = _run(module.start_google_web_oauth())
|
||||
assert gmail_ok["code"] == 0
|
||||
assert gmail_ok["data"]["flow_id"] == "flow-gmail"
|
||||
assert gmail_ok["data"]["authorization_url"].endswith("flow-gmail")
|
||||
|
||||
_set_request(module, args={})
|
||||
drive_ok = _run(module.start_google_web_oauth())
|
||||
assert drive_ok["code"] == 0
|
||||
assert drive_ok["data"]["flow_id"] == "flow-drive"
|
||||
assert drive_ok["data"]["authorization_url"].endswith("flow-drive")
|
||||
|
||||
assert any(call.scopes == module.GOOGLE_SCOPES[module.DocumentSource.GMAIL] for call in flow_calls)
|
||||
assert any(call.scopes == module.GOOGLE_SCOPES[module.DocumentSource.GOOGLE_DRIVE] for call in flow_calls)
|
||||
assert "gmail_web_flow_state:flow-gmail" in redis.store
|
||||
assert "google-drive_web_flow_state:flow-drive" in redis.store
|
||||
assert json.loads(redis.store["gmail_web_flow_state:flow-gmail"])["code_verifier"] == "fake-code-verifier"
|
||||
assert json.loads(redis.store["google-drive_web_flow_state:flow-drive"])["code_verifier"] == "fake-code-verifier"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_google_web_oauth_callbacks_matrix(monkeypatch):
|
||||
module = _load_connector_app(monkeypatch)
|
||||
|
||||
flow_calls = []
|
||||
|
||||
def _from_client_config(client_config, scopes):
|
||||
flow = _FakeFlow(client_config, scopes)
|
||||
flow_calls.append(flow)
|
||||
return flow
|
||||
|
||||
monkeypatch.setattr(module.Flow, "from_client_config", staticmethod(_from_client_config))
|
||||
|
||||
callback_specs = [
|
||||
(
|
||||
module.google_gmail_web_oauth_callback,
|
||||
"gmail",
|
||||
module.GMAIL_WEB_OAUTH_REDIRECT_URI,
|
||||
module.GOOGLE_SCOPES[module.DocumentSource.GMAIL],
|
||||
),
|
||||
(
|
||||
module.google_drive_web_oauth_callback,
|
||||
"google-drive",
|
||||
module.GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI,
|
||||
module.GOOGLE_SCOPES[module.DocumentSource.GOOGLE_DRIVE],
|
||||
),
|
||||
]
|
||||
|
||||
for callback, source, expected_redirect, expected_scopes in callback_specs:
|
||||
redis = _FakeRedis()
|
||||
monkeypatch.setattr(module, "REDIS_CONN", redis)
|
||||
|
||||
_set_request(module, args={})
|
||||
missing_state = _run(callback())
|
||||
assert "Missing OAuth state parameter." in missing_state.body
|
||||
|
||||
_set_request(module, args={"state": "sid"})
|
||||
expired_state = _run(callback())
|
||||
assert "Authorization session expired" in expired_state.body
|
||||
|
||||
redis.store[module._web_state_cache_key("sid", source)] = json.dumps({"user_id": "tenant-1"})
|
||||
_set_request(module, args={"state": "sid"})
|
||||
invalid_state = _run(callback())
|
||||
assert "Authorization session was invalid" in invalid_state.body
|
||||
assert module._web_state_cache_key("sid", source) in redis.deleted
|
||||
|
||||
redis.store[module._web_state_cache_key("sid", source)] = json.dumps({
|
||||
"user_id": "tenant-1",
|
||||
"client_config": {"web": {"client_id": "cid"}},
|
||||
})
|
||||
_set_request(module, args={"state": "sid", "error": "denied", "error_description": "permission denied"})
|
||||
oauth_error = _run(callback())
|
||||
assert "permission denied" in oauth_error.body
|
||||
|
||||
redis.store[module._web_state_cache_key("sid", source)] = json.dumps({
|
||||
"user_id": "tenant-1",
|
||||
"client_config": {"web": {"client_id": "cid"}},
|
||||
})
|
||||
_set_request(module, args={"state": "sid"})
|
||||
missing_code = _run(callback())
|
||||
assert "Missing authorization code" in missing_code.body
|
||||
|
||||
redis.store[module._web_state_cache_key("sid", source)] = json.dumps({
|
||||
"user_id": "tenant-1",
|
||||
"client_config": {"web": {"client_id": "cid"}},
|
||||
"code_verifier": "state-code-verifier",
|
||||
})
|
||||
_set_request(module, args={"state": "sid", "code": "code-123"})
|
||||
success = _run(callback())
|
||||
assert "Authorization completed successfully." in success.body
|
||||
|
||||
result_key = module._web_result_cache_key("sid", source)
|
||||
assert result_key in redis.store
|
||||
assert module._web_state_cache_key("sid", source) in redis.deleted
|
||||
|
||||
assert flow_calls[-1].redirect_uri == expected_redirect
|
||||
assert flow_calls[-1].scopes == expected_scopes
|
||||
assert flow_calls[-1].token_code == "code-123"
|
||||
assert flow_calls[-1].token_code_verifier == "state-code-verifier"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_poll_google_web_result_matrix(monkeypatch):
|
||||
module = _load_connector_app(monkeypatch)
|
||||
redis = _FakeRedis()
|
||||
monkeypatch.setattr(module, "REDIS_CONN", redis)
|
||||
|
||||
_set_request(module, args={"type": "invalid"}, json_body={"flow_id": "flow-1"})
|
||||
invalid_type = _run(module.poll_google_web_result())
|
||||
assert invalid_type["code"] == module.RetCode.ARGUMENT_ERROR
|
||||
|
||||
_set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"})
|
||||
pending = _run(module.poll_google_web_result())
|
||||
assert pending["code"] == module.RetCode.RUNNING
|
||||
|
||||
redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps(
|
||||
{"user_id": "another-user", "credentials": "token-x"}
|
||||
)
|
||||
_set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"})
|
||||
permission_error = _run(module.poll_google_web_result())
|
||||
assert permission_error["code"] == module.RetCode.PERMISSION_ERROR
|
||||
|
||||
redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps(
|
||||
{"user_id": "tenant-1", "credentials": "token-ok"}
|
||||
)
|
||||
_set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"})
|
||||
success = _run(module.poll_google_web_result())
|
||||
assert success["code"] == 0
|
||||
assert success["data"] == {"credentials": "token-ok"}
|
||||
assert module._web_result_cache_key("flow-1", "gmail") in redis.deleted
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_box_oauth_start_callback_and_poll_matrix(monkeypatch):
|
||||
module = _load_connector_app(monkeypatch)
|
||||
redis = _FakeRedis()
|
||||
monkeypatch.setattr(module, "REDIS_CONN", redis)
|
||||
|
||||
created_auth = []
|
||||
|
||||
class _TrackingBoxOAuth(_FakeBoxOAuth):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
created_auth.append(self)
|
||||
|
||||
monkeypatch.setattr(module, "BoxOAuth", _TrackingBoxOAuth)
|
||||
monkeypatch.setattr(module.uuid, "uuid4", lambda: "flow-box")
|
||||
monkeypatch.setattr(module.time, "time", lambda: 1800000000)
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({}))
|
||||
missing_params = _run(module.start_box_web_oauth())
|
||||
assert missing_params["code"] == module.RetCode.ARGUMENT_ERROR
|
||||
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_request_json",
|
||||
lambda: _AwaitableValue({"client_id": "cid", "client_secret": "sec", "redirect_uri": "https://box.local/callback"}),
|
||||
)
|
||||
start_ok = _run(module.start_box_web_oauth())
|
||||
assert start_ok["code"] == 0
|
||||
assert start_ok["data"]["flow_id"] == "flow-box"
|
||||
assert "authorization_url" in start_ok["data"]
|
||||
assert module._web_state_cache_key("flow-box", "box") in redis.store
|
||||
|
||||
_set_request(module, args={})
|
||||
missing_state = _run(module.box_web_oauth_callback())
|
||||
assert "Missing OAuth parameters." in missing_state.body
|
||||
|
||||
_set_request(module, args={"state": "flow-box"})
|
||||
missing_code = _run(module.box_web_oauth_callback())
|
||||
assert "Missing authorization code from Box." in missing_code.body
|
||||
|
||||
redis.store[module._web_state_cache_key("flow-null", "box")] = "null"
|
||||
_set_request(module, args={"state": "flow-null", "code": "abc"})
|
||||
invalid_session = _run(module.box_web_oauth_callback())
|
||||
assert invalid_session["code"] == module.RetCode.ARGUMENT_ERROR
|
||||
|
||||
redis.store[module._web_state_cache_key("flow-box", "box")] = json.dumps(
|
||||
{"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"}
|
||||
)
|
||||
_set_request(module, args={"state": "flow-box", "code": "abc", "error": "access_denied", "error_description": "denied"})
|
||||
callback_error = _run(module.box_web_oauth_callback())
|
||||
assert "denied" in callback_error.body
|
||||
|
||||
redis.store[module._web_state_cache_key("flow-ok", "box")] = json.dumps(
|
||||
{"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"}
|
||||
)
|
||||
_set_request(module, args={"state": "flow-ok", "code": "code-ok"})
|
||||
callback_success = _run(module.box_web_oauth_callback())
|
||||
assert "Authorization completed successfully." in callback_success.body
|
||||
assert created_auth[-1].exchange_code == "code-ok"
|
||||
assert module._web_result_cache_key("flow-ok", "box") in redis.store
|
||||
assert module._web_state_cache_key("flow-ok", "box") in redis.deleted
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"flow_id": "flow-ok"}))
|
||||
redis.store.pop(module._web_result_cache_key("flow-ok", "box"), None)
|
||||
pending = _run(module.poll_box_web_result())
|
||||
assert pending["code"] == module.RetCode.RUNNING
|
||||
|
||||
redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps({"user_id": "another-user"})
|
||||
permission_error = _run(module.poll_box_web_result())
|
||||
assert permission_error["code"] == module.RetCode.PERMISSION_ERROR
|
||||
|
||||
redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps(
|
||||
{"user_id": "tenant-1", "access_token": "at", "refresh_token": "rt"}
|
||||
)
|
||||
poll_success = _run(module.poll_box_web_result())
|
||||
assert poll_success["code"] == 0
|
||||
assert poll_success["data"]["credentials"]["access_token"] == "at"
|
||||
assert module._web_result_cache_key("flow-ok", "box") in redis.deleted
|
||||
335
test/testcases/restful_api/test_datasets.py
Normal file
335
test/testcases/restful_api/test_datasets.py
Normal file
@@ -0,0 +1,335 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pytest
|
||||
from configs import DATASET_NAME_LIMIT
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
class TestDatasetsAuthorization:
|
||||
def test_create_requires_auth(self, rest_client_noauth):
|
||||
res = rest_client_noauth.post("/datasets", json={"name": "auth_test"})
|
||||
assert res.status_code == 401
|
||||
payload = res.json()
|
||||
assert payload["code"] == 401, payload
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_dataset_crud_cycle(rest_client, clear_datasets):
|
||||
create_res = rest_client.post("/datasets", json={"name": "restful_dataset_crud"})
|
||||
assert create_res.status_code == 200
|
||||
create_payload = create_res.json()
|
||||
assert create_payload["code"] == 0, create_payload
|
||||
dataset_id = create_payload["data"]["id"]
|
||||
|
||||
get_res = rest_client.get(f"/datasets/{dataset_id}")
|
||||
assert get_res.status_code == 200
|
||||
get_payload = get_res.json()
|
||||
assert get_payload["code"] == 0, get_payload
|
||||
assert get_payload["data"]["id"] == dataset_id, get_payload
|
||||
|
||||
update_res = rest_client.put(
|
||||
f"/datasets/{dataset_id}",
|
||||
json={"name": "restful_dataset_crud_updated"},
|
||||
)
|
||||
assert update_res.status_code == 200
|
||||
update_payload = update_res.json()
|
||||
assert update_payload["code"] == 0, update_payload
|
||||
assert update_payload["data"]["name"] == "restful_dataset_crud_updated", update_payload
|
||||
|
||||
list_res = rest_client.get("/datasets", params={"id": dataset_id})
|
||||
assert list_res.status_code == 200
|
||||
list_payload = list_res.json()
|
||||
assert list_payload["code"] == 0, list_payload
|
||||
assert len(list_payload["data"]) == 1, list_payload
|
||||
assert list_payload["data"][0]["id"] == dataset_id, list_payload
|
||||
assert list_payload.get("total_datasets", 0) >= 1, list_payload
|
||||
|
||||
delete_res = rest_client.delete("/datasets", json={"ids": [dataset_id]})
|
||||
assert delete_res.status_code == 200
|
||||
delete_payload = delete_res.json()
|
||||
assert delete_payload["code"] == 0, delete_payload
|
||||
|
||||
list_after_delete = rest_client.get("/datasets")
|
||||
assert list_after_delete.status_code == 200
|
||||
list_after_delete_payload = list_after_delete.json()
|
||||
assert list_after_delete_payload["code"] == 0, list_after_delete_payload
|
||||
assert all(dataset["id"] != dataset_id for dataset in list_after_delete_payload["data"]), list_after_delete_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"name, expected_fragment",
|
||||
[
|
||||
("", "String should have at least 1 character"),
|
||||
(" ", "String should have at least 1 character"),
|
||||
("a" * (DATASET_NAME_LIMIT + 1), f"String should have at most {DATASET_NAME_LIMIT} characters"),
|
||||
],
|
||||
ids=["empty", "spaces", "too_long"],
|
||||
)
|
||||
def test_dataset_create_name_validation(rest_client, clear_datasets, name, expected_fragment):
|
||||
res = rest_client.post("/datasets", json={"name": name})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 101, payload
|
||||
assert expected_fragment in payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_dataset_list_ordering_and_pagination(rest_client, clear_datasets):
|
||||
for i in range(3):
|
||||
res = rest_client.post("/datasets", json={"name": f"dataset_page_{i}"})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
|
||||
list_res = rest_client.get(
|
||||
"/datasets",
|
||||
params={"page": 1, "page_size": 2, "orderby": "create_time", "desc": "true"},
|
||||
)
|
||||
assert list_res.status_code == 200
|
||||
list_payload = list_res.json()
|
||||
assert list_payload["code"] == 0, list_payload
|
||||
assert len(list_payload["data"]) == 2, list_payload
|
||||
assert list_payload.get("total_datasets", 0) >= 3, list_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_dataset_search_endpoint(rest_client, ensure_parsed_document):
|
||||
dataset_id, _ = ensure_parsed_document()
|
||||
res = rest_client.post(
|
||||
f"/datasets/{dataset_id}/search",
|
||||
json={"question": "test TXT file", "page": 1, "size": 10},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
assert "chunks" in payload["data"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_dataset_search_requires_question(rest_client, create_dataset):
|
||||
dataset_id = create_dataset("dataset_search_missing_question")
|
||||
res = rest_client.post(f"/datasets/{dataset_id}/search", json={})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 101, payload
|
||||
assert "question" in payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_dataset_tags_and_aggregation(rest_client, create_dataset):
|
||||
dataset_id = create_dataset("dataset_tags")
|
||||
second_dataset_id = create_dataset("dataset_tags_second")
|
||||
|
||||
list_tags_res = rest_client.get(f"/datasets/{dataset_id}/tags")
|
||||
assert list_tags_res.status_code == 200
|
||||
list_tags_payload = list_tags_res.json()
|
||||
# Known env/runtime behavior: this route can return 102 when retriever tag
|
||||
# backend is unavailable for an empty dataset. Keep route-contract coverage.
|
||||
assert list_tags_payload["code"] in (0, 102), list_tags_payload
|
||||
|
||||
aggregate_res = rest_client.get(
|
||||
"/datasets/tags/aggregation",
|
||||
params={"dataset_ids": f"{dataset_id},{second_dataset_id}"},
|
||||
)
|
||||
assert aggregate_res.status_code == 200
|
||||
aggregate_payload = aggregate_res.json()
|
||||
assert aggregate_payload["code"] in (0, 102), aggregate_payload
|
||||
|
||||
empty_aggregate_res = rest_client.get("/datasets/tags/aggregation")
|
||||
assert empty_aggregate_res.status_code == 200
|
||||
empty_aggregate_payload = empty_aggregate_res.json()
|
||||
assert empty_aggregate_payload["code"] != 0, empty_aggregate_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_dataset_tags_delete_and_rename_validation(rest_client, create_dataset):
|
||||
dataset_id = create_dataset("dataset_tag_mutation")
|
||||
|
||||
delete_missing_tags = rest_client.delete(f"/datasets/{dataset_id}/tags", json={})
|
||||
assert delete_missing_tags.status_code == 200
|
||||
delete_missing_tags_payload = delete_missing_tags.json()
|
||||
assert delete_missing_tags_payload["code"] != 0, delete_missing_tags_payload
|
||||
|
||||
delete_invalid_tags_type = rest_client.delete(f"/datasets/{dataset_id}/tags", json={"tags": "wrong"})
|
||||
assert delete_invalid_tags_type.status_code == 200
|
||||
delete_invalid_tags_type_payload = delete_invalid_tags_type.json()
|
||||
assert delete_invalid_tags_type_payload["code"] != 0, delete_invalid_tags_type_payload
|
||||
|
||||
rename_empty = rest_client.put(
|
||||
f"/datasets/{dataset_id}/tags",
|
||||
json={"from_tag": "", "to_tag": ""},
|
||||
)
|
||||
assert rename_empty.status_code == 200
|
||||
rename_empty_payload = rename_empty.json()
|
||||
assert rename_empty_payload["code"] != 0, rename_empty_payload
|
||||
|
||||
rename_invalid_dataset = rest_client.put(
|
||||
"/datasets/invalid_id/tags",
|
||||
json={"from_tag": "old", "to_tag": "new"},
|
||||
)
|
||||
assert rename_invalid_dataset.status_code == 200
|
||||
rename_invalid_dataset_payload = rename_invalid_dataset.json()
|
||||
assert rename_invalid_dataset_payload["code"] != 0, rename_invalid_dataset_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_dataset_flattened_metadata(rest_client, create_dataset):
|
||||
first_dataset_id = create_dataset("flattened_meta_1")
|
||||
second_dataset_id = create_dataset("flattened_meta_2")
|
||||
|
||||
flattened_res = rest_client.get(
|
||||
"/datasets/metadata/flattened",
|
||||
params={"dataset_ids": f"{first_dataset_id},{second_dataset_id}"},
|
||||
)
|
||||
assert flattened_res.status_code == 200
|
||||
flattened_payload = flattened_res.json()
|
||||
assert flattened_payload["code"] == 0, flattened_payload
|
||||
|
||||
empty_ids_res = rest_client.get("/datasets/metadata/flattened")
|
||||
assert empty_ids_res.status_code == 200
|
||||
empty_ids_payload = empty_ids_res.json()
|
||||
assert empty_ids_payload["code"] != 0, empty_ids_payload
|
||||
|
||||
invalid_dataset_res = rest_client.get(
|
||||
"/datasets/metadata/flattened",
|
||||
params={"dataset_ids": "invalid_id"},
|
||||
)
|
||||
assert invalid_dataset_res.status_code == 200
|
||||
invalid_dataset_payload = invalid_dataset_res.json()
|
||||
assert invalid_dataset_payload["code"] != 0, invalid_dataset_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_dataset_ingestion_summary_and_logs(rest_client, create_dataset):
|
||||
dataset_id = create_dataset("dataset_ingestions")
|
||||
|
||||
summary_res = rest_client.get(f"/datasets/{dataset_id}/ingestions/summary")
|
||||
assert summary_res.status_code == 200
|
||||
summary_payload = summary_res.json()
|
||||
assert summary_payload["code"] == 0, summary_payload
|
||||
assert "doc_num" in summary_payload["data"], summary_payload
|
||||
assert "chunk_num" in summary_payload["data"], summary_payload
|
||||
assert "token_num" in summary_payload["data"], summary_payload
|
||||
assert "status" in summary_payload["data"], summary_payload
|
||||
|
||||
logs_res = rest_client.get(
|
||||
f"/datasets/{dataset_id}/ingestions",
|
||||
params={"page": 1, "page_size": 10},
|
||||
)
|
||||
assert logs_res.status_code == 200
|
||||
logs_payload = logs_res.json()
|
||||
assert logs_payload["code"] == 0, logs_payload
|
||||
assert "total" in logs_payload["data"], logs_payload
|
||||
assert "logs" in logs_payload["data"], logs_payload
|
||||
|
||||
not_found_log_res = rest_client.get(f"/datasets/{dataset_id}/ingestions/nonexistent_log")
|
||||
assert not_found_log_res.status_code == 200
|
||||
not_found_log_payload = not_found_log_res.json()
|
||||
assert not_found_log_payload["code"] != 0, not_found_log_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_dataset_ingestion_invalid_dataset(rest_client):
|
||||
summary_res = rest_client.get("/datasets/invalid_id/ingestions/summary")
|
||||
assert summary_res.status_code == 200
|
||||
summary_payload = summary_res.json()
|
||||
assert summary_payload["code"] != 0, summary_payload
|
||||
|
||||
logs_res = rest_client.get("/datasets/invalid_id/ingestions")
|
||||
assert logs_res.status_code == 200
|
||||
logs_payload = logs_res.json()
|
||||
assert logs_payload["code"] != 0, logs_payload
|
||||
|
||||
log_res = rest_client.get("/datasets/invalid_id/ingestions/some_log_id")
|
||||
assert log_res.status_code == 200
|
||||
log_payload = log_res.json()
|
||||
assert log_payload["code"] != 0, log_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_dataset_index_endpoints(rest_client, create_dataset):
|
||||
dataset_id = create_dataset("dataset_index_endpoints")
|
||||
|
||||
run_invalid_type = rest_client.post(
|
||||
f"/datasets/{dataset_id}/index",
|
||||
params={"type": "invalid_type"},
|
||||
)
|
||||
assert run_invalid_type.status_code == 200
|
||||
run_invalid_type_payload = run_invalid_type.json()
|
||||
assert run_invalid_type_payload["code"] != 0, run_invalid_type_payload
|
||||
|
||||
run_no_docs = rest_client.post(
|
||||
f"/datasets/{dataset_id}/index",
|
||||
params={"type": "graph"},
|
||||
)
|
||||
assert run_no_docs.status_code == 200
|
||||
run_no_docs_payload = run_no_docs.json()
|
||||
assert run_no_docs_payload["code"] == 102, run_no_docs_payload
|
||||
|
||||
trace_no_task = rest_client.get(
|
||||
f"/datasets/{dataset_id}/index",
|
||||
params={"type": "graph"},
|
||||
)
|
||||
assert trace_no_task.status_code == 200
|
||||
trace_no_task_payload = trace_no_task.json()
|
||||
assert trace_no_task_payload["code"] == 0, trace_no_task_payload
|
||||
assert trace_no_task_payload["data"] == {}, trace_no_task_payload
|
||||
|
||||
delete_graph = rest_client.delete(f"/datasets/{dataset_id}/graph")
|
||||
assert delete_graph.status_code == 200
|
||||
delete_graph_payload = delete_graph.json()
|
||||
assert delete_graph_payload["code"] == 0, delete_graph_payload
|
||||
|
||||
delete_invalid_type = rest_client.delete(f"/datasets/{dataset_id}/invalid_type")
|
||||
assert delete_invalid_type.status_code == 200
|
||||
delete_invalid_type_payload = delete_invalid_type.json()
|
||||
assert delete_invalid_type_payload["code"] != 0, delete_invalid_type_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize("index_type", ["graph", "raptor", "mindmap"])
|
||||
def test_dataset_index_run_with_document_creates_task(rest_client, create_document, index_type):
|
||||
dataset_id, _ = create_document("dataset_index_graph_source.txt")
|
||||
run_graph = rest_client.post(
|
||||
f"/datasets/{dataset_id}/index",
|
||||
params={"type": index_type},
|
||||
)
|
||||
assert run_graph.status_code == 200
|
||||
run_graph_payload = run_graph.json()
|
||||
assert run_graph_payload["code"] == 0, run_graph_payload
|
||||
assert run_graph_payload["data"].get("task_id"), run_graph_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_dataset_embedding_endpoints(rest_client, create_dataset):
|
||||
dataset_id = create_dataset("dataset_embedding_endpoints")
|
||||
|
||||
run_no_docs_res = rest_client.post(f"/datasets/{dataset_id}/embedding")
|
||||
assert run_no_docs_res.status_code == 200
|
||||
run_no_docs_payload = run_no_docs_res.json()
|
||||
assert run_no_docs_payload["code"] == 102, run_no_docs_payload
|
||||
|
||||
missing_embd_id_res = rest_client.post(f"/datasets/{dataset_id}/embedding/check", json={})
|
||||
assert missing_embd_id_res.status_code == 200
|
||||
missing_embd_id_payload = missing_embd_id_res.json()
|
||||
assert missing_embd_id_payload["code"] != 0, missing_embd_id_payload
|
||||
|
||||
invalid_dataset_res = rest_client.post("/datasets/invalid_id/embedding")
|
||||
assert invalid_dataset_res.status_code == 200
|
||||
invalid_dataset_payload = invalid_dataset_res.json()
|
||||
assert invalid_dataset_payload["code"] != 0, invalid_dataset_payload
|
||||
43
test/testcases/restful_api/test_document_raw_routes.py
Normal file
43
test/testcases/restful_api/test_document_raw_routes.py
Normal file
@@ -0,0 +1,43 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_document_image_invalid_id_contract(rest_client_noauth):
|
||||
res = rest_client_noauth.get("/documents/images/not-a-valid-image-id")
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 102, payload
|
||||
assert payload["message"] == "Image not found.", payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_document_artifact_requires_auth(rest_client_noauth):
|
||||
res = rest_client_noauth.get("/documents/artifact/not-an-artifact.txt")
|
||||
assert res.status_code == 401
|
||||
payload = res.json()
|
||||
assert payload["code"] == 401, payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_document_artifact_rejects_unsafe_filename(rest_client):
|
||||
res = rest_client.get("/documents/artifact/not-an-artifact.exe")
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 102, payload
|
||||
assert payload["message"] == "Invalid file type.", payload
|
||||
122
test/testcases/restful_api/test_documents.py
Normal file
122
test/testcases/restful_api/test_documents.py
Normal file
@@ -0,0 +1,122 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pytest
|
||||
from utils.file_utils import create_txt_file
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_documents_upload_and_list(rest_client, create_dataset, tmp_path):
|
||||
dataset_id = create_dataset("dataset_upload_list")
|
||||
fp = create_txt_file(tmp_path / "upload_and_list.txt")
|
||||
with fp.open("rb") as file_obj:
|
||||
res = rest_client.post(
|
||||
f"/datasets/{dataset_id}/documents",
|
||||
files=[("file", (fp.name, file_obj))],
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
assert payload["data"][0]["dataset_id"] == dataset_id, payload
|
||||
|
||||
list_res = rest_client.get(f"/datasets/{dataset_id}/documents")
|
||||
assert list_res.status_code == 200
|
||||
list_payload = list_res.json()
|
||||
assert list_payload["code"] == 0, list_payload
|
||||
assert list_payload["data"]["total"] >= 1, list_payload
|
||||
assert any(doc["name"] == fp.name for doc in list_payload["data"]["docs"]), list_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_documents_upload_missing_file(rest_client, create_dataset):
|
||||
dataset_id = create_dataset("dataset_upload_missing")
|
||||
res = rest_client.post(f"/datasets/{dataset_id}/documents")
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 101, payload
|
||||
assert payload["message"] == "No file part!", payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_documents_update_patch_and_delete(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("update_target.txt")
|
||||
|
||||
patch_res = rest_client.patch(
|
||||
f"/datasets/{dataset_id}/documents/{document_id}",
|
||||
json={"name": "updated_target.txt"},
|
||||
)
|
||||
assert patch_res.status_code == 200
|
||||
patch_payload = patch_res.json()
|
||||
assert patch_payload["code"] == 0, patch_payload
|
||||
assert patch_payload["data"]["name"] == "updated_target.txt", patch_payload
|
||||
|
||||
delete_res = rest_client.delete(
|
||||
f"/datasets/{dataset_id}/documents",
|
||||
json={"ids": [document_id]},
|
||||
)
|
||||
assert delete_res.status_code == 200
|
||||
delete_payload = delete_res.json()
|
||||
assert delete_payload["code"] == 0, delete_payload
|
||||
assert delete_payload["data"]["deleted"] == 1, delete_payload
|
||||
|
||||
list_res = rest_client.get(f"/datasets/{dataset_id}/documents")
|
||||
assert list_res.status_code == 200
|
||||
list_payload = list_res.json()
|
||||
assert list_payload["code"] == 0, list_payload
|
||||
assert all(doc["id"] != document_id for doc in list_payload["data"]["docs"]), list_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_documents_parse_and_stop(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("parse_target.txt")
|
||||
|
||||
parse_res = rest_client.post(
|
||||
f"/datasets/{dataset_id}/documents/parse",
|
||||
json={"document_ids": [document_id]},
|
||||
)
|
||||
assert parse_res.status_code == 200
|
||||
parse_payload = parse_res.json()
|
||||
assert parse_payload["code"] == 0, parse_payload
|
||||
|
||||
stop_res = rest_client.post(
|
||||
f"/datasets/{dataset_id}/documents/stop",
|
||||
json={"document_ids": [document_id]},
|
||||
)
|
||||
assert stop_res.status_code == 200
|
||||
stop_payload = stop_res.json()
|
||||
# Depending on timing this can be immediate stop success or "already completed".
|
||||
assert stop_payload["code"] in (0, 102), stop_payload
|
||||
if stop_payload["code"] == 102:
|
||||
assert "already completed" in stop_payload["message"], stop_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_documents_metadata_update_path(rest_client, create_document):
|
||||
dataset_id, document_id = create_document("metadata_target.txt")
|
||||
|
||||
res = rest_client.patch(
|
||||
f"/datasets/{dataset_id}/documents/metadatas",
|
||||
json={
|
||||
"selector": {"document_ids": [document_id]},
|
||||
"updates": [{"key": "author", "value": "qa"}],
|
||||
"deletes": [],
|
||||
},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
assert payload["data"]["matched_docs"] == 1, payload
|
||||
assert payload["data"]["updated"] >= 1, payload
|
||||
632
test/testcases/restful_api/test_file_routes_unit.py
Normal file
632
test/testcases/restful_api/test_file_routes_unit.py
Normal file
@@ -0,0 +1,632 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import sys
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _AwaitableValue:
|
||||
def __init__(self, value):
|
||||
self._value = value
|
||||
|
||||
def __await__(self):
|
||||
async def _co():
|
||||
return self._value
|
||||
|
||||
return _co().__await__()
|
||||
|
||||
|
||||
class _DummyFiles(dict):
|
||||
def __init__(self, file_objs=None):
|
||||
super().__init__()
|
||||
self._file_objs = list(file_objs or [])
|
||||
if file_objs is not None:
|
||||
self["file"] = self._file_objs
|
||||
|
||||
def getlist(self, key):
|
||||
if key == "file":
|
||||
return list(self._file_objs)
|
||||
return []
|
||||
|
||||
|
||||
class _DummyUploadFile:
|
||||
def __init__(self, filename, blob=b"blob"):
|
||||
self.filename = filename
|
||||
self._blob = blob
|
||||
|
||||
def read(self):
|
||||
return self._blob
|
||||
|
||||
|
||||
class _DummyRequest:
|
||||
def __init__(self, *, content_type="", form=None, files=None, args=None):
|
||||
self.content_type = content_type
|
||||
self.form = _AwaitableValue(form or {})
|
||||
self.files = _AwaitableValue(files if files is not None else _DummyFiles())
|
||||
self.args = args or {}
|
||||
|
||||
|
||||
class _DummyResponse:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
self.headers = {}
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _load_file_api_module(monkeypatch):
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
|
||||
quart_mod = ModuleType("quart")
|
||||
quart_mod.request = _DummyRequest()
|
||||
|
||||
async def _make_response(data):
|
||||
return _DummyResponse(data)
|
||||
|
||||
quart_mod.make_response = _make_response
|
||||
monkeypatch.setitem(sys.modules, "quart", quart_mod)
|
||||
|
||||
api_pkg = ModuleType("api")
|
||||
api_pkg.__path__ = [str(repo_root / "api")]
|
||||
monkeypatch.setitem(sys.modules, "api", api_pkg)
|
||||
|
||||
apps_pkg = ModuleType("api.apps")
|
||||
apps_pkg.__path__ = [str(repo_root / "api" / "apps")]
|
||||
apps_pkg.login_required = lambda func: func
|
||||
monkeypatch.setitem(sys.modules, "api.apps", apps_pkg)
|
||||
api_pkg.apps = apps_pkg
|
||||
|
||||
services_pkg = ModuleType("api.apps.services")
|
||||
services_pkg.__path__ = [str(repo_root / "api" / "apps" / "services")]
|
||||
monkeypatch.setitem(sys.modules, "api.apps.services", services_pkg)
|
||||
apps_pkg.services = services_pkg
|
||||
|
||||
file_api_service_mod = ModuleType("api.apps.services.file_api_service")
|
||||
|
||||
async def _upload_file(_tenant_id, _pf_id, _file_objs):
|
||||
return True, [{"id": "f1"}]
|
||||
|
||||
async def _create_folder(_tenant_id, _name, _parent_id=None, _file_type=None):
|
||||
return True, {"id": "folder1"}
|
||||
|
||||
async def _delete_files(_tenant_id, _ids):
|
||||
return True, True
|
||||
|
||||
async def _move_files(_tenant_id, _src_file_ids, _dest_file_id=None, _new_name=None):
|
||||
return True, True
|
||||
|
||||
file_api_service_mod.upload_file = _upload_file
|
||||
file_api_service_mod.create_folder = _create_folder
|
||||
file_api_service_mod.list_files = lambda _tenant_id, _args: (True, {"files": [], "total": 0})
|
||||
file_api_service_mod.delete_files = _delete_files
|
||||
file_api_service_mod.move_files = _move_files
|
||||
file_api_service_mod.get_file_content = lambda _tenant_id, _file_id: (
|
||||
True,
|
||||
SimpleNamespace(parent_id="bucket1", location="path1", name="doc.txt", type="doc"),
|
||||
)
|
||||
file_api_service_mod.get_parent_folder = lambda _file_id, user_id=None: (True, {"parent_folder": {"id": "parent1"}})
|
||||
file_api_service_mod.get_all_parent_folders = lambda _file_id, user_id=None: (True, {"parent_folders": [{"id": "root"}]})
|
||||
monkeypatch.setitem(sys.modules, "api.apps.services.file_api_service", file_api_service_mod)
|
||||
services_pkg.file_api_service = file_api_service_mod
|
||||
|
||||
db_pkg = ModuleType("api.db")
|
||||
db_pkg.__path__ = []
|
||||
|
||||
class _FileType(Enum):
|
||||
DOC = "doc"
|
||||
VISUAL = "visual"
|
||||
|
||||
db_pkg.FileType = _FileType
|
||||
monkeypatch.setitem(sys.modules, "api.db", db_pkg)
|
||||
api_pkg.db = db_pkg
|
||||
|
||||
file2doc_mod = ModuleType("api.db.services.file2document_service")
|
||||
file2doc_mod.File2DocumentService = SimpleNamespace(get_storage_address=lambda **_kwargs: ("bucket2", "path2"))
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2doc_mod)
|
||||
|
||||
api_utils_mod = ModuleType("api.utils.api_utils")
|
||||
api_utils_mod.add_tenant_id_to_kwargs = lambda func: func
|
||||
api_utils_mod.get_error_argument_result = lambda message: {"code": 400, "data": None, "message": message}
|
||||
api_utils_mod.get_error_data_result = lambda message: {"code": 500, "data": None, "message": message}
|
||||
api_utils_mod.get_result = lambda data=None: {"code": 0, "data": data, "message": ""}
|
||||
api_utils_mod.get_json_result = lambda code=0, message="success", data=None: {"code": code, "data": data, "message": message}
|
||||
monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
|
||||
|
||||
validation_mod = ModuleType("api.utils.validation_utils")
|
||||
validation_mod.CreateFolderReq = object
|
||||
validation_mod.DeleteFileReq = object
|
||||
validation_mod.ListFileReq = object
|
||||
validation_mod.MoveFileReq = object
|
||||
|
||||
async def _validate_json_request(_request, _schema):
|
||||
return {}, None
|
||||
|
||||
validation_mod.validate_and_parse_json_request = _validate_json_request
|
||||
validation_mod.validate_and_parse_request_args = lambda _request, _schema: ({}, None)
|
||||
monkeypatch.setitem(sys.modules, "api.utils.validation_utils", validation_mod)
|
||||
|
||||
web_utils_mod = ModuleType("api.utils.web_utils")
|
||||
web_utils_mod.CONTENT_TYPE_MAP = {"txt": "text/plain"}
|
||||
web_utils_mod.apply_safe_file_response_headers = lambda response, content_type, ext: response.headers.update({"content_type": content_type, "ext": ext})
|
||||
monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod)
|
||||
|
||||
common_pkg = ModuleType("common")
|
||||
common_pkg.__path__ = [str(repo_root / "common")]
|
||||
common_pkg.settings = SimpleNamespace(
|
||||
STORAGE_IMPL=SimpleNamespace(
|
||||
get=lambda *_args, **_kwargs: b"blob",
|
||||
)
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "common", common_pkg)
|
||||
|
||||
misc_utils_mod = ModuleType("common.misc_utils")
|
||||
|
||||
async def thread_pool_exec(func, *args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
misc_utils_mod.thread_pool_exec = thread_pool_exec
|
||||
monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
|
||||
|
||||
module_path = repo_root / "api" / "apps" / "restful_apis" / "file_api.py"
|
||||
spec = importlib.util.spec_from_file_location("api.apps.restful_apis.file_api", module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
module.manager = _DummyManager()
|
||||
monkeypatch.setitem(sys.modules, "api.apps.restful_apis.file_api", module)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_create_or_upload_multipart_requires_file(monkeypatch):
|
||||
module = _load_file_api_module(monkeypatch)
|
||||
monkeypatch.setattr(module, "request", _DummyRequest(content_type="multipart/form-data", form={}, files=_DummyFiles()))
|
||||
|
||||
res = _run(module.create_or_upload("tenant1"))
|
||||
assert res["code"] == 400
|
||||
assert res["message"] == "No file part!"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_create_or_upload_uploads_via_new_service(monkeypatch):
|
||||
module = _load_file_api_module(monkeypatch)
|
||||
files = _DummyFiles([_DummyUploadFile("a.txt")])
|
||||
monkeypatch.setattr(module, "request", _DummyRequest(content_type="multipart/form-data", form={"parent_id": "pf1"}, files=files))
|
||||
|
||||
seen = {}
|
||||
|
||||
async def _upload_file(tenant_id, pf_id, file_objs):
|
||||
seen["args"] = (tenant_id, pf_id, [f.filename for f in file_objs])
|
||||
return True, [{"id": "f1"}]
|
||||
|
||||
monkeypatch.setattr(module.file_api_service, "upload_file", _upload_file)
|
||||
res = _run(module.create_or_upload("tenant1"))
|
||||
|
||||
assert seen["args"] == ("tenant1", "pf1", ["a.txt"])
|
||||
assert res["code"] == 0
|
||||
assert res["data"] == [{"id": "f1"}]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_create_or_upload_creates_folder_from_json(monkeypatch):
|
||||
module = _load_file_api_module(monkeypatch)
|
||||
monkeypatch.setattr(module, "request", _DummyRequest(content_type="application/json"))
|
||||
|
||||
async def _validate(_request, _schema):
|
||||
return {"name": "folder-a", "parent_id": "pf1", "type": "folder"}, None
|
||||
|
||||
async def _create_folder(tenant_id, name, parent_id=None, file_type=None):
|
||||
return True, {"tenant_id": tenant_id, "name": name, "parent_id": parent_id, "type": file_type}
|
||||
|
||||
monkeypatch.setattr(module, "validate_and_parse_json_request", _validate)
|
||||
monkeypatch.setattr(module.file_api_service, "create_folder", _create_folder)
|
||||
|
||||
res = _run(module.create_or_upload("tenant1"))
|
||||
assert res["code"] == 0
|
||||
assert res["data"]["tenant_id"] == "tenant1"
|
||||
assert res["data"]["name"] == "folder-a"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_list_files_validation_error(monkeypatch):
|
||||
module = _load_file_api_module(monkeypatch)
|
||||
monkeypatch.setattr(module, "validate_and_parse_request_args", lambda _request, _schema: (None, "bad args"))
|
||||
|
||||
res = _run(module.list_files("tenant1"))
|
||||
assert res["code"] == 400
|
||||
assert res["message"] == "bad args"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_move_uses_new_payload_shape(monkeypatch):
|
||||
module = _load_file_api_module(monkeypatch)
|
||||
|
||||
async def _validate(_request, _schema):
|
||||
return {"src_file_ids": ["f1"], "dest_file_id": "pf2"}, None
|
||||
|
||||
seen = {}
|
||||
|
||||
async def _move_files(tenant_id, src_file_ids, dest_file_id=None, new_name=None):
|
||||
seen["args"] = (tenant_id, src_file_ids, dest_file_id, new_name)
|
||||
return True, True
|
||||
|
||||
monkeypatch.setattr(module, "validate_and_parse_json_request", _validate)
|
||||
monkeypatch.setattr(module.file_api_service, "move_files", _move_files)
|
||||
|
||||
res = _run(module.move("tenant1"))
|
||||
assert seen["args"] == ("tenant1", ["f1"], "pf2", None)
|
||||
assert res["code"] == 0
|
||||
assert res["data"] is True
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_rename_via_move_route(monkeypatch):
|
||||
module = _load_file_api_module(monkeypatch)
|
||||
|
||||
async def _validate(_request, _schema):
|
||||
return {"src_file_ids": ["file1"], "new_name": "renamed.txt"}, None
|
||||
|
||||
seen = {}
|
||||
|
||||
async def _move_files(tenant_id, src_file_ids, dest_file_id=None, new_name=None):
|
||||
seen["args"] = (tenant_id, src_file_ids, dest_file_id, new_name)
|
||||
return True, True
|
||||
|
||||
monkeypatch.setattr(module, "validate_and_parse_json_request", _validate)
|
||||
monkeypatch.setattr(module.file_api_service, "move_files", _move_files)
|
||||
|
||||
res = _run(module.move("tenant1"))
|
||||
assert seen["args"] == ("tenant1", ["file1"], None, "renamed.txt")
|
||||
assert res["code"] == 0
|
||||
assert res["data"] is True
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_download_falls_back_to_document_storage(monkeypatch):
|
||||
module = _load_file_api_module(monkeypatch)
|
||||
storage_calls = []
|
||||
|
||||
def _get(bucket, location):
|
||||
storage_calls.append((bucket, location))
|
||||
return b"" if len(storage_calls) == 1 else b"fallback-blob"
|
||||
|
||||
monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(get=_get))
|
||||
res = _run(module.download("tenant1", "file1"))
|
||||
|
||||
assert storage_calls == [("bucket1", "path1"), ("bucket2", "path2")]
|
||||
assert res.data == b"fallback-blob"
|
||||
assert res.headers["content_type"] == "text/plain"
|
||||
assert res.headers["ext"] == "txt"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_parent_and_ancestors_use_new_routes(monkeypatch):
|
||||
module = _load_file_api_module(monkeypatch)
|
||||
|
||||
parent_res = _run(module.parent_folder("tenant1", "file1"))
|
||||
ancestors_res = _run(module.ancestors("tenant1", "file1"))
|
||||
|
||||
assert parent_res["code"] == 0
|
||||
assert parent_res["data"]["parent_folder"]["id"] == "parent1"
|
||||
assert ancestors_res["code"] == 0
|
||||
assert ancestors_res["data"]["parent_folders"][0]["id"] == "root"
|
||||
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import functools
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _DummyFile:
|
||||
def __init__(self, file_id, file_type, *, name="file.txt", location="loc", size=1):
|
||||
self.id = file_id
|
||||
self.type = file_type
|
||||
self.name = name
|
||||
self.location = location
|
||||
self.size = size
|
||||
|
||||
|
||||
class _FalsyFile(_DummyFile):
|
||||
def __bool__(self):
|
||||
return False
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _set_request_json(monkeypatch, module, payload_state):
|
||||
async def _req_json():
|
||||
return deepcopy(payload_state)
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", _req_json)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def auth():
|
||||
return "unit-auth"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_tenant_info():
|
||||
return None
|
||||
|
||||
|
||||
def _load_file2document_module(monkeypatch):
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
|
||||
api_pkg = ModuleType("api")
|
||||
api_pkg.__path__ = [str(repo_root / "api")]
|
||||
monkeypatch.setitem(sys.modules, "api", api_pkg)
|
||||
|
||||
apps_mod = ModuleType("api.apps")
|
||||
apps_mod.__path__ = [str(repo_root / "api" / "apps")]
|
||||
apps_mod.current_user = SimpleNamespace(id="user-1")
|
||||
apps_mod.login_required = lambda func: func
|
||||
monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
|
||||
api_pkg.apps = apps_mod
|
||||
|
||||
db_pkg = ModuleType("api.db")
|
||||
db_pkg.__path__ = []
|
||||
|
||||
class _FileType(Enum):
|
||||
FOLDER = "folder"
|
||||
DOC = "doc"
|
||||
|
||||
db_pkg.FileType = _FileType
|
||||
monkeypatch.setitem(sys.modules, "api.db", db_pkg)
|
||||
api_pkg.db = db_pkg
|
||||
|
||||
services_pkg = ModuleType("api.db.services")
|
||||
services_pkg.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
|
||||
|
||||
common_pkg = ModuleType("api.common")
|
||||
common_pkg.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, "api.common", common_pkg)
|
||||
|
||||
permission_mod = ModuleType("api.common.check_team_permission")
|
||||
permission_mod.check_file_team_permission = lambda *_args, **_kwargs: True
|
||||
permission_mod.check_kb_team_permission = lambda *_args, **_kwargs: True
|
||||
monkeypatch.setitem(sys.modules, "api.common.check_team_permission", permission_mod)
|
||||
common_pkg.check_team_permission = permission_mod
|
||||
|
||||
file2document_mod = ModuleType("api.db.services.file2document_service")
|
||||
|
||||
class _StubFile2DocumentService:
|
||||
@staticmethod
|
||||
def get_by_file_id(_file_id):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def delete_by_file_id(*_args, **_kwargs):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def insert(_payload):
|
||||
return SimpleNamespace(to_json=lambda: {})
|
||||
|
||||
file2document_mod.File2DocumentService = _StubFile2DocumentService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2document_mod)
|
||||
services_pkg.file2document_service = file2document_mod
|
||||
|
||||
file_service_mod = ModuleType("api.db.services.file_service")
|
||||
|
||||
class _StubFileService:
|
||||
@staticmethod
|
||||
def get_by_ids(_file_ids):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def get_all_innermost_file_ids(_file_id, _acc):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(_file_id):
|
||||
return True, _DummyFile(_file_id, _FileType.DOC.value)
|
||||
|
||||
@staticmethod
|
||||
def get_parser(_file_type, _file_name, parser_id):
|
||||
return parser_id
|
||||
|
||||
file_service_mod.FileService = _StubFileService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service_mod)
|
||||
services_pkg.file_service = file_service_mod
|
||||
|
||||
kb_service_mod = ModuleType("api.db.services.knowledgebase_service")
|
||||
|
||||
class _StubKnowledgebaseService:
|
||||
@staticmethod
|
||||
def get_by_id(_kb_id):
|
||||
return False, None
|
||||
|
||||
kb_service_mod.KnowledgebaseService = _StubKnowledgebaseService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod)
|
||||
services_pkg.knowledgebase_service = kb_service_mod
|
||||
|
||||
document_service_mod = ModuleType("api.db.services.document_service")
|
||||
|
||||
class _StubDocumentService:
|
||||
@staticmethod
|
||||
def get_by_id(doc_id):
|
||||
return True, SimpleNamespace(id=doc_id)
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_id(_doc_id):
|
||||
return "tenant-1"
|
||||
|
||||
@staticmethod
|
||||
def remove_document(*_args, **_kwargs):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def insert(_payload):
|
||||
return SimpleNamespace(id="doc-1")
|
||||
|
||||
document_service_mod.DocumentService = _StubDocumentService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod)
|
||||
services_pkg.document_service = document_service_mod
|
||||
|
||||
api_utils_mod = ModuleType("api.utils.api_utils")
|
||||
|
||||
def get_json_result(data=None, message="", code=0):
|
||||
return {"code": code, "data": data, "message": message}
|
||||
|
||||
def get_data_error_result(message=""):
|
||||
return {"code": 102, "data": None, "message": message}
|
||||
|
||||
async def get_request_json():
|
||||
return {}
|
||||
|
||||
def server_error_response(err):
|
||||
return {"code": 500, "data": None, "message": str(err)}
|
||||
|
||||
def validate_request(*_keys):
|
||||
def _decorator(func):
|
||||
@functools.wraps(func)
|
||||
async def _wrapper(*args, **kwargs):
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return _wrapper
|
||||
|
||||
return _decorator
|
||||
|
||||
api_utils_mod.get_json_result = get_json_result
|
||||
api_utils_mod.get_data_error_result = get_data_error_result
|
||||
api_utils_mod.get_request_json = get_request_json
|
||||
api_utils_mod.server_error_response = server_error_response
|
||||
api_utils_mod.validate_request = validate_request
|
||||
monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
|
||||
|
||||
misc_utils_mod = ModuleType("common.misc_utils")
|
||||
misc_utils_mod.get_uuid = lambda: "uuid"
|
||||
monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
|
||||
|
||||
constants_mod = ModuleType("common.constants")
|
||||
|
||||
class _RetCode:
|
||||
ARGUMENT_ERROR = 101
|
||||
|
||||
constants_mod.RetCode = _RetCode
|
||||
monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
|
||||
|
||||
module_name = "test_file2document_routes_unit_module"
|
||||
module_path = repo_root / "api" / "apps" / "restful_apis" / "file2document_api.py"
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
module.manager = _DummyManager()
|
||||
monkeypatch.setitem(sys.modules, module_name, module)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_convert_branch_matrix_unit(monkeypatch):
|
||||
module = _load_file2document_module(monkeypatch)
|
||||
req_state = {"kb_ids": ["kb-1"], "file_ids": ["f1"]}
|
||||
_set_request_json(monkeypatch, module, req_state)
|
||||
|
||||
# Falsy file returns "File not found!" during synchronous validation.
|
||||
monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_FalsyFile("f1", module.FileType.DOC.value)])
|
||||
res = _run(module.convert())
|
||||
assert res["code"] == 102
|
||||
assert res["message"] == "File not found!"
|
||||
|
||||
# Valid file but invalid kb returns "Can't find this dataset!" during synchronous validation.
|
||||
monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_DummyFile("f1", module.FileType.DOC.value)])
|
||||
res = _run(module.convert())
|
||||
assert res["code"] == 102
|
||||
assert res["message"] == "Can't find this dataset!"
|
||||
|
||||
kb = SimpleNamespace(id="kb-1", parser_id="naive", pipeline_id="p1", parser_config={})
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
|
||||
|
||||
# Unauthorized file access is rejected before scheduling background work.
|
||||
monkeypatch.setattr(module, "check_file_team_permission", lambda *_args, **_kwargs: False)
|
||||
res = _run(module.convert())
|
||||
assert res["code"] == 102
|
||||
assert res["message"] == "No authorization."
|
||||
|
||||
# Unauthorized dataset access is rejected before scheduling background work.
|
||||
monkeypatch.setattr(module, "check_file_team_permission", lambda *_args, **_kwargs: True)
|
||||
monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: False)
|
||||
res = _run(module.convert())
|
||||
assert res["code"] == 102
|
||||
assert res["message"] == "No authorization."
|
||||
|
||||
# Valid file and kb schedule background work and return data=True immediately.
|
||||
monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: True)
|
||||
res = _run(module.convert())
|
||||
assert res["code"] == 0
|
||||
assert res["data"] is True
|
||||
|
||||
# Folder expansion schedules background work and returns data=True immediately.
|
||||
req_state["file_ids"] = ["folder-1"]
|
||||
monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_DummyFile("folder-1", module.FileType.FOLDER.value, name="folder")])
|
||||
monkeypatch.setattr(module.FileService, "get_all_innermost_file_ids", lambda _file_id, _acc: ["inner-1"])
|
||||
res = _run(module.convert())
|
||||
assert res["code"] == 0
|
||||
assert res["data"] is True
|
||||
|
||||
# Exception in file lookup returns 500.
|
||||
req_state["file_ids"] = ["f1"]
|
||||
monkeypatch.setattr(
|
||||
module.FileService,
|
||||
"get_by_ids",
|
||||
lambda _ids: (_ for _ in ()).throw(RuntimeError("convert boom")),
|
||||
)
|
||||
res = _run(module.convert())
|
||||
assert res["code"] == 500
|
||||
assert "convert boom" in res["message"]
|
||||
37
test/testcases/restful_api/test_langfuse_routes.py
Normal file
37
test/testcases/restful_api/test_langfuse_routes.py
Normal file
@@ -0,0 +1,37 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_langfuse_api_key_routes_require_auth(rest_client_noauth):
|
||||
for method in ("get", "post", "put", "delete"):
|
||||
requester = getattr(rest_client_noauth, method)
|
||||
kwargs = {"json": {"secret_key": "s", "public_key": "p", "host": "http://example.com"}} if method in {"post", "put"} else {}
|
||||
res = requester("/langfuse/api-key", **kwargs)
|
||||
assert res.status_code == 401
|
||||
payload = res.json()
|
||||
assert payload["code"] == 401, (method, payload)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_langfuse_api_key_missing_required_fields(rest_client):
|
||||
res = rest_client.post("/langfuse/api-key", json={"secret_key": "", "public_key": "pub", "host": "http://host"})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] in (101, 102), payload
|
||||
assert "required" in payload["message"].lower() or "missing" in payload["message"].lower(), payload
|
||||
745
test/testcases/restful_api/test_mcp_routes_unit.py
Normal file
745
test/testcases/restful_api/test_mcp_routes_unit.py
Normal file
@@ -0,0 +1,745 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import inspect
|
||||
import json
|
||||
import sys
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _Args(dict):
|
||||
def getlist(self, key):
|
||||
value = self.get(key, [])
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
return [value]
|
||||
|
||||
|
||||
class _Field:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.name, other)
|
||||
|
||||
|
||||
class _DummyMCPServer:
|
||||
id = _Field("id")
|
||||
tenant_id = _Field("tenant_id")
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.id = kwargs.get("id", "")
|
||||
self.name = kwargs.get("name", "")
|
||||
self.url = kwargs.get("url", "")
|
||||
self.server_type = kwargs.get("server_type", "sse")
|
||||
self.tenant_id = kwargs.get("tenant_id", "tenant_1")
|
||||
self.variables = kwargs.get("variables", {})
|
||||
self.headers = kwargs.get("headers", {})
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"url": self.url,
|
||||
"server_type": self.server_type,
|
||||
"tenant_id": self.tenant_id,
|
||||
"variables": self.variables,
|
||||
"headers": self.headers,
|
||||
}
|
||||
|
||||
|
||||
class _DummyMCPServerService:
|
||||
@staticmethod
|
||||
def get_servers(*_args, **_kwargs):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def get_or_none(*_args, **_kwargs):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(*_args, **_kwargs):
|
||||
return False, None
|
||||
|
||||
@staticmethod
|
||||
def get_by_name_and_tenant(*_args, **_kwargs):
|
||||
return False, None
|
||||
|
||||
@staticmethod
|
||||
def insert(**_kwargs):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def filter_update(*_args, **_kwargs):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def delete_by_ids(*_args, **_kwargs):
|
||||
return True
|
||||
|
||||
|
||||
class _DummyTenantService:
|
||||
@staticmethod
|
||||
def get_by_id(*_args, **_kwargs):
|
||||
return True, SimpleNamespace(id="tenant_1")
|
||||
|
||||
|
||||
class _DummyTool:
|
||||
def __init__(self, name):
|
||||
self._name = name
|
||||
|
||||
def model_dump(self):
|
||||
return {"name": self._name}
|
||||
|
||||
|
||||
class _DummyMCPToolCallSession:
|
||||
def __init__(self, _mcp_server, _variables):
|
||||
self._tools = [_DummyTool("tool_a"), _DummyTool("tool_b")]
|
||||
|
||||
def get_tools(self, _timeout):
|
||||
return self._tools
|
||||
|
||||
def tool_call(self, _name, _arguments, _timeout):
|
||||
return "ok"
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _set_request_json(monkeypatch, module, payload):
|
||||
async def _request_json():
|
||||
return payload
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", _request_json)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def auth():
|
||||
return "unit-auth"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_tenant_info():
|
||||
return None
|
||||
|
||||
|
||||
def _load_mcp_api(monkeypatch):
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
|
||||
quart_mod = ModuleType("quart")
|
||||
quart_mod.Response = object
|
||||
quart_mod.request = SimpleNamespace(args=_Args({}))
|
||||
monkeypatch.setitem(sys.modules, "quart", quart_mod)
|
||||
|
||||
common_pkg = ModuleType("common")
|
||||
common_pkg.__path__ = [str(repo_root / "common")]
|
||||
monkeypatch.setitem(sys.modules, "common", common_pkg)
|
||||
|
||||
constants_mod = ModuleType("common.constants")
|
||||
constants_mod.VALID_MCP_SERVER_TYPES = {"sse", "streamable-http"}
|
||||
monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
|
||||
|
||||
apps_mod = ModuleType("api.apps")
|
||||
apps_mod.current_user = SimpleNamespace(id="tenant_1")
|
||||
apps_mod.login_required = lambda func: func
|
||||
monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
|
||||
|
||||
db_models_mod = ModuleType("api.db.db_models")
|
||||
db_models_mod.MCPServer = _DummyMCPServer
|
||||
monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
|
||||
|
||||
mcp_service_mod = ModuleType("api.db.services.mcp_server_service")
|
||||
mcp_service_mod.MCPServerService = _DummyMCPServerService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.mcp_server_service", mcp_service_mod)
|
||||
|
||||
user_service_mod = ModuleType("api.db.services.user_service")
|
||||
user_service_mod.TenantService = _DummyTenantService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
|
||||
|
||||
mcp_conn_mod = ModuleType("common.mcp_tool_call_conn")
|
||||
mcp_conn_mod.MCPToolCallSession = _DummyMCPToolCallSession
|
||||
mcp_conn_mod.close_multiple_mcp_toolcall_sessions = lambda _sessions: None
|
||||
monkeypatch.setitem(sys.modules, "common.mcp_tool_call_conn", mcp_conn_mod)
|
||||
|
||||
api_utils_mod = ModuleType("api.utils.api_utils")
|
||||
|
||||
async def _default_request_json():
|
||||
return {}
|
||||
|
||||
def _get_json_result(code=0, message="success", data=None):
|
||||
return {"code": code, "message": message, "data": data}
|
||||
|
||||
def _get_data_error_result(code=102, message="Sorry! Data missing!"):
|
||||
return {"code": code, "message": message}
|
||||
|
||||
def _server_error_response(error):
|
||||
return {"code": 100, "message": repr(error)}
|
||||
|
||||
async def _get_mcp_tools(*_args, **_kwargs):
|
||||
return {}
|
||||
|
||||
def _validate_request(*_args, **_kwargs):
|
||||
def _decorator(func):
|
||||
@wraps(func)
|
||||
async def _wrapped(*func_args, **func_kwargs):
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*func_args, **func_kwargs)
|
||||
return func(*func_args, **func_kwargs)
|
||||
|
||||
return _wrapped
|
||||
|
||||
return _decorator
|
||||
|
||||
api_utils_mod.get_request_json = _default_request_json
|
||||
api_utils_mod.get_json_result = _get_json_result
|
||||
api_utils_mod.get_data_error_result = _get_data_error_result
|
||||
api_utils_mod.server_error_response = _server_error_response
|
||||
api_utils_mod.validate_request = _validate_request
|
||||
api_utils_mod.get_mcp_tools = _get_mcp_tools
|
||||
monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
|
||||
|
||||
web_utils_mod = ModuleType("api.utils.web_utils")
|
||||
|
||||
def _get_float(data, key, default):
|
||||
try:
|
||||
return float(data.get(key, default))
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
def _safe_json_parse(value):
|
||||
if isinstance(value, (dict, list)):
|
||||
return value
|
||||
if value in (None, ""):
|
||||
return {}
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (TypeError, ValueError):
|
||||
return {}
|
||||
|
||||
web_utils_mod.get_float = _get_float
|
||||
web_utils_mod.safe_json_parse = _safe_json_parse
|
||||
monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod)
|
||||
|
||||
module_name = "test_mcp_api_unit_module"
|
||||
module_path = repo_root / "api" / "apps" / "restful_apis" / "mcp_api.py"
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
module.manager = _DummyManager()
|
||||
monkeypatch.setitem(sys.modules, module_name, module)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_list_mcp_desc_pagination_and_exception(monkeypatch):
|
||||
module = _load_mcp_api(monkeypatch)
|
||||
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"request",
|
||||
SimpleNamespace(args=_Args({"keywords": "k", "page": "2", "page_size": "1", "orderby": "create_time", "desc": "false"})),
|
||||
)
|
||||
_set_request_json(monkeypatch, module, {"mcp_ids": []})
|
||||
monkeypatch.setattr(module.MCPServerService, "get_servers", lambda *_args, **_kwargs: [{"id": "a"}, {"id": "b"}])
|
||||
|
||||
res = _run(module.list_mcp())
|
||||
assert res["code"] == 0
|
||||
assert res["data"]["total"] == 2
|
||||
assert res["data"]["mcp_servers"] == [{"id": "b"}]
|
||||
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({})))
|
||||
_set_request_json(monkeypatch, module, {"mcp_ids": []})
|
||||
|
||||
def _raise_list(*_args, **_kwargs):
|
||||
raise RuntimeError("list explode")
|
||||
|
||||
monkeypatch.setattr(module.MCPServerService, "get_servers", _raise_list)
|
||||
res = _run(module.list_mcp())
|
||||
assert res["code"] == 100
|
||||
assert "list explode" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_detail_not_found_success_and_exception(monkeypatch):
|
||||
module = _load_mcp_api(monkeypatch)
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({})))
|
||||
|
||||
monkeypatch.setattr(module.MCPServerService, "get_or_none", lambda **_kwargs: None)
|
||||
res = module.detail("mcp-1")
|
||||
assert res["code"] == 102
|
||||
assert "Cannot find MCP server mcp-1 for user tenant_1" in res["message"]
|
||||
|
||||
monkeypatch.setattr(
|
||||
module.MCPServerService,
|
||||
"get_or_none",
|
||||
lambda **_kwargs: _DummyMCPServer(id="mcp-1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1"),
|
||||
)
|
||||
res = module.detail("mcp-1")
|
||||
assert res["code"] == 0
|
||||
assert res["data"]["id"] == "mcp-1"
|
||||
|
||||
def _raise_detail(**_kwargs):
|
||||
raise RuntimeError("detail explode")
|
||||
|
||||
monkeypatch.setattr(module.MCPServerService, "get_or_none", _raise_detail)
|
||||
res = module.detail("mcp-1")
|
||||
assert res["code"] == 100
|
||||
assert "detail explode" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_create_validation_guards(monkeypatch):
|
||||
module = _load_mcp_api(monkeypatch)
|
||||
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None))
|
||||
|
||||
_set_request_json(monkeypatch, module, {"name": "srv", "url": "http://a", "server_type": "invalid"})
|
||||
res = _run(module.create.__wrapped__())
|
||||
assert "Unsupported MCP server type" in res["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, {"name": "", "url": "http://a", "server_type": "sse"})
|
||||
res = _run(module.create.__wrapped__())
|
||||
assert "Invalid MCP name" in res["message"]
|
||||
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (True, object()))
|
||||
_set_request_json(monkeypatch, module, {"name": "srv", "url": "http://a", "server_type": "sse"})
|
||||
res = _run(module.create.__wrapped__())
|
||||
assert "Duplicated MCP server name" in res["message"]
|
||||
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None))
|
||||
_set_request_json(monkeypatch, module, {"name": "srv", "url": "", "server_type": "sse"})
|
||||
res = _run(module.create.__wrapped__())
|
||||
assert "Invalid url" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_create_service_paths(monkeypatch):
|
||||
module = _load_mcp_api(monkeypatch)
|
||||
|
||||
base_payload = {
|
||||
"name": "srv",
|
||||
"url": "http://server",
|
||||
"server_type": "sse",
|
||||
"headers": '{"Authorization": "x"}',
|
||||
"variables": '{"tools": {"old": 1}, "token": "abc"}',
|
||||
"timeout": "2.5",
|
||||
}
|
||||
|
||||
monkeypatch.setattr(module, "get_uuid", lambda: "uuid-create")
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None))
|
||||
|
||||
_set_request_json(monkeypatch, module, dict(base_payload))
|
||||
monkeypatch.setattr(module.TenantService, "get_by_id", lambda *_args, **_kwargs: (False, None))
|
||||
res = _run(module.create.__wrapped__())
|
||||
assert "Tenant not found" in res["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, dict(base_payload))
|
||||
monkeypatch.setattr(module.TenantService, "get_by_id", lambda *_args, **_kwargs: (True, object()))
|
||||
|
||||
async def _thread_pool_tools_error(_func, _servers, _timeout):
|
||||
return None, "tools error"
|
||||
|
||||
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error)
|
||||
res = _run(module.create.__wrapped__())
|
||||
assert res["code"] == 102
|
||||
assert "tools error" in res["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, dict(base_payload))
|
||||
|
||||
async def _thread_pool_ok(_func, servers, _timeout):
|
||||
return {servers[0].name: [{"name": "tool_a"}, {"invalid": True}]}, None
|
||||
|
||||
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok)
|
||||
monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: False)
|
||||
res = _run(module.create.__wrapped__())
|
||||
assert res["code"] == 102
|
||||
assert "Failed to create MCP server" in res["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, dict(base_payload))
|
||||
monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: True)
|
||||
res = _run(module.create.__wrapped__())
|
||||
assert res["code"] == 0
|
||||
assert res["data"]["id"] == "uuid-create"
|
||||
assert res["data"]["tenant_id"] == "tenant_1"
|
||||
assert res["data"]["variables"]["tools"] == {"tool_a": {"name": "tool_a"}}
|
||||
|
||||
_set_request_json(monkeypatch, module, dict(base_payload))
|
||||
|
||||
async def _thread_pool_raises(_func, _servers, _timeout):
|
||||
raise RuntimeError("create explode")
|
||||
|
||||
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_raises)
|
||||
res = _run(module.create.__wrapped__())
|
||||
assert res["code"] == 100
|
||||
assert "create explode" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_update_validation_guards(monkeypatch):
|
||||
module = _load_mcp_api(monkeypatch)
|
||||
|
||||
existing = _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="tenant_1", variables={}, headers={})
|
||||
|
||||
_set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"})
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
|
||||
res = _run(module.update("mcp-1"))
|
||||
assert "Cannot find MCP server" in res["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"})
|
||||
monkeypatch.setattr(
|
||||
module.MCPServerService,
|
||||
"get_by_id",
|
||||
lambda _mcp_id: (True, _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="other", variables={}, headers={})),
|
||||
)
|
||||
res = _run(module.update("mcp-1"))
|
||||
assert "Cannot find MCP server" in res["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "server_type": "invalid"})
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))
|
||||
res = _run(module.update("mcp-1"))
|
||||
assert "Unsupported MCP server type" in res["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "name": "a" * 256})
|
||||
res = _run(module.update("mcp-1"))
|
||||
assert "Invalid MCP name" in res["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "url": ""})
|
||||
res = _run(module.update("mcp-1"))
|
||||
assert "Invalid url" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_update_service_paths(monkeypatch):
|
||||
module = _load_mcp_api(monkeypatch)
|
||||
|
||||
existing = _DummyMCPServer(
|
||||
id="mcp-1",
|
||||
name="srv",
|
||||
url="http://server",
|
||||
server_type="sse",
|
||||
tenant_id="tenant_1",
|
||||
variables={"tools": {"old": {"enabled": True}}, "token": "abc"},
|
||||
headers={"Authorization": "old"},
|
||||
)
|
||||
updated = _DummyMCPServer(
|
||||
id="mcp-1",
|
||||
name="srv-new",
|
||||
url="http://server-new",
|
||||
server_type="sse",
|
||||
tenant_id="tenant_1",
|
||||
variables={"tools": {"tool_a": {"name": "tool_a"}}},
|
||||
headers={"Authorization": "new"},
|
||||
)
|
||||
|
||||
base_payload = {
|
||||
"mcp_id": "mcp-1",
|
||||
"name": "srv-new",
|
||||
"url": "http://server-new",
|
||||
"server_type": "sse",
|
||||
"headers": '{"Authorization": "new"}',
|
||||
"variables": '{"tools": {"ignore": 1}, "token": "new"}',
|
||||
"timeout": "3.0",
|
||||
}
|
||||
|
||||
_set_request_json(monkeypatch, module, dict(base_payload))
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))
|
||||
|
||||
async def _thread_pool_tools_error(_func, _servers, _timeout):
|
||||
return None, "update tools error"
|
||||
|
||||
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error)
|
||||
res = _run(module.update("mcp-1"))
|
||||
assert res["code"] == 102
|
||||
assert "update tools error" in res["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, dict(base_payload))
|
||||
|
||||
async def _thread_pool_ok(_func, servers, _timeout):
|
||||
return {servers[0].name: [{"name": "tool_a"}, {"bad": True}]}, None
|
||||
|
||||
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok)
|
||||
monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: False)
|
||||
res = _run(module.update("mcp-1"))
|
||||
assert "Failed to updated MCP server" in res["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, dict(base_payload))
|
||||
monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: True)
|
||||
|
||||
def _get_by_id_fetch_fail(_mcp_id):
|
||||
if _get_by_id_fetch_fail.calls == 0:
|
||||
_get_by_id_fetch_fail.calls += 1
|
||||
return True, existing
|
||||
return False, None
|
||||
|
||||
_get_by_id_fetch_fail.calls = 0
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_fetch_fail)
|
||||
res = _run(module.update("mcp-1"))
|
||||
assert "Failed to fetch updated MCP server" in res["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, dict(base_payload))
|
||||
|
||||
def _get_by_id_success(_mcp_id):
|
||||
if _get_by_id_success.calls == 0:
|
||||
_get_by_id_success.calls += 1
|
||||
return True, existing
|
||||
return True, updated
|
||||
|
||||
_get_by_id_success.calls = 0
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_success)
|
||||
res = _run(module.update("mcp-1"))
|
||||
assert res["code"] == 0
|
||||
assert res["data"]["id"] == "mcp-1"
|
||||
|
||||
_set_request_json(monkeypatch, module, dict(base_payload))
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))
|
||||
|
||||
async def _thread_pool_raises(_func, _servers, _timeout):
|
||||
raise RuntimeError("update explode")
|
||||
|
||||
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_raises)
|
||||
res = _run(module.update("mcp-1"))
|
||||
assert res["code"] == 100
|
||||
assert "update explode" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_rm_failure_success_and_exception(monkeypatch):
|
||||
module = _load_mcp_api(monkeypatch)
|
||||
server = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1", variables={})
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server))
|
||||
|
||||
_set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})
|
||||
monkeypatch.setattr(module.MCPServerService, "delete_by_ids", lambda _ids: False)
|
||||
res = _run(module.rm("id1"))
|
||||
assert "Failed to delete MCP servers" in res["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})
|
||||
monkeypatch.setattr(module.MCPServerService, "delete_by_ids", lambda _ids: True)
|
||||
res = _run(module.rm("id1"))
|
||||
assert res["code"] == 0
|
||||
assert res["data"] is True
|
||||
|
||||
_set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})
|
||||
|
||||
def _raise_rm(_ids):
|
||||
raise RuntimeError("rm explode")
|
||||
|
||||
monkeypatch.setattr(module.MCPServerService, "delete_by_ids", _raise_rm)
|
||||
res = _run(module.rm("id1"))
|
||||
assert res["code"] == 100
|
||||
assert "rm explode" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_import_multiple_missing_servers_and_exception(monkeypatch):
|
||||
module = _load_mcp_api(monkeypatch)
|
||||
|
||||
_set_request_json(monkeypatch, module, {"mcpServers": {}})
|
||||
res = _run(module.import_multiple.__wrapped__())
|
||||
assert "No MCP servers provided" in res["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, {"mcpServers": {"srv": {"type": "sse", "url": "http://x"}}, "timeout": "1"})
|
||||
|
||||
def _raise_import(**_kwargs):
|
||||
raise RuntimeError("import explode")
|
||||
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", _raise_import)
|
||||
res = _run(module.import_multiple.__wrapped__())
|
||||
assert res["code"] == 100
|
||||
assert "import explode" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_import_multiple_mixed_results(monkeypatch):
|
||||
module = _load_mcp_api(monkeypatch)
|
||||
|
||||
payload = {
|
||||
"mcpServers": {
|
||||
"missing_fields": {"type": "sse"},
|
||||
"": {"type": "sse", "url": "http://empty"},
|
||||
"dup": {"type": "sse", "url": "http://dup", "authorization_token": "dup-token"},
|
||||
"tool_err": {"type": "sse", "url": "http://err"},
|
||||
"insert_fail": {"type": "sse", "url": "http://fail"},
|
||||
},
|
||||
"timeout": "3",
|
||||
}
|
||||
_set_request_json(monkeypatch, module, payload)
|
||||
|
||||
monkeypatch.setattr(module, "get_uuid", lambda: "uuid-import")
|
||||
|
||||
def _get_by_name_and_tenant(name, tenant_id):
|
||||
if name == "dup" and not _get_by_name_and_tenant.first_dup_seen:
|
||||
_get_by_name_and_tenant.first_dup_seen = True
|
||||
return True, object()
|
||||
return False, None
|
||||
|
||||
_get_by_name_and_tenant.first_dup_seen = False
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", _get_by_name_and_tenant)
|
||||
|
||||
async def _thread_pool_exec(func, servers, _timeout):
|
||||
mcp_server = servers[0]
|
||||
if mcp_server.name == "tool_err":
|
||||
return None, "tool call failed"
|
||||
return {mcp_server.name: [{"name": "tool_a"}, {"invalid": True}]}, None
|
||||
|
||||
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec)
|
||||
|
||||
def _insert(**kwargs):
|
||||
return kwargs["name"] != "insert_fail"
|
||||
|
||||
monkeypatch.setattr(module.MCPServerService, "insert", _insert)
|
||||
|
||||
res = _run(module.import_multiple.__wrapped__())
|
||||
assert res["code"] == 0
|
||||
|
||||
results = {item["server"]: item for item in res["data"]["results"]}
|
||||
assert results["missing_fields"]["success"] is False
|
||||
assert "Missing required fields" in results["missing_fields"]["message"]
|
||||
assert results[""]["success"] is False
|
||||
assert "Invalid MCP name" in results[""]["message"]
|
||||
assert results["tool_err"]["success"] is False
|
||||
assert "tool call failed" in results["tool_err"]["message"]
|
||||
assert results["insert_fail"]["success"] is False
|
||||
assert "Failed to create MCP server" in results["insert_fail"]["message"]
|
||||
assert results["dup"]["success"] is True
|
||||
assert results["dup"]["new_name"] == "dup_0"
|
||||
assert "Renamed from 'dup' to 'dup_0' avoid duplication" == results["dup"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_detail_download_success_and_exception(monkeypatch):
|
||||
module = _load_mcp_api(monkeypatch)
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({"mode": "download"})))
|
||||
|
||||
monkeypatch.setattr(
|
||||
module.MCPServerService,
|
||||
"get_by_id",
|
||||
lambda _mcp_id: (
|
||||
True,
|
||||
_DummyMCPServer(
|
||||
id="id1",
|
||||
name="srv-one",
|
||||
url="http://one",
|
||||
server_type="sse",
|
||||
tenant_id="tenant_1",
|
||||
variables={"authorization_token": "tok", "tools": {"tool_a": {"enabled": True}}},
|
||||
),
|
||||
),
|
||||
)
|
||||
res = module.detail("id1")
|
||||
assert res["code"] == 0
|
||||
assert list(res["data"]["mcpServers"].keys()) == ["srv-one"]
|
||||
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
|
||||
res = module.detail("missing")
|
||||
assert res["code"] == 102
|
||||
assert "Cannot find MCP server missing for user tenant_1" in res["message"]
|
||||
|
||||
monkeypatch.setattr(
|
||||
module.MCPServerService,
|
||||
"get_by_id",
|
||||
lambda _mcp_id: (
|
||||
True,
|
||||
_DummyMCPServer(
|
||||
id="id2",
|
||||
name="srv-two",
|
||||
url="http://two",
|
||||
server_type="sse",
|
||||
tenant_id="other",
|
||||
variables={},
|
||||
),
|
||||
),
|
||||
)
|
||||
res = module.detail("id2")
|
||||
assert res["code"] == 102
|
||||
assert "Cannot find MCP server id2 for user tenant_1" in res["message"]
|
||||
|
||||
def _raise_export(_mcp_id):
|
||||
raise RuntimeError("export explode")
|
||||
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_id", _raise_export)
|
||||
res = module.detail("id1")
|
||||
assert res["code"] == 100
|
||||
assert "export explode" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_test_mcp_route_matrix_unit(monkeypatch):
|
||||
module = _load_mcp_api(monkeypatch)
|
||||
|
||||
_set_request_json(monkeypatch, module, {"url": "", "server_type": "sse"})
|
||||
res = _run(module.test_mcp("mcp-1"))
|
||||
assert "Invalid MCP url" in res["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "invalid"})
|
||||
res = _run(module.test_mcp("mcp-1"))
|
||||
assert "Unsupported MCP server type" in res["message"]
|
||||
|
||||
close_calls = []
|
||||
|
||||
async def _thread_pool_exec_inner_error(func, *args):
|
||||
if func is module.close_multiple_mcp_toolcall_sessions:
|
||||
close_calls.append(args[0])
|
||||
return None
|
||||
if getattr(func, "__name__", "") == "get_tools":
|
||||
raise RuntimeError("get tools explode")
|
||||
return func(*args)
|
||||
|
||||
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_inner_error)
|
||||
_set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
|
||||
res = _run(module.test_mcp("mcp-1"))
|
||||
assert res["code"] == 102
|
||||
assert "Test MCP error: get tools explode" in res["message"]
|
||||
assert close_calls and len(close_calls[-1]) == 1
|
||||
|
||||
close_calls_success = []
|
||||
|
||||
async def _thread_pool_exec_success(func, *args):
|
||||
if func is module.close_multiple_mcp_toolcall_sessions:
|
||||
close_calls_success.append(args[0])
|
||||
return None
|
||||
return func(*args)
|
||||
|
||||
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_success)
|
||||
_set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
|
||||
res = _run(module.test_mcp("mcp-1"))
|
||||
assert res["code"] == 0
|
||||
assert res["data"][0]["name"] == "tool_a"
|
||||
assert all(tool["enabled"] is True for tool in res["data"])
|
||||
assert close_calls_success and len(close_calls_success[-1]) == 1
|
||||
|
||||
def _raise_session(*_args, **_kwargs):
|
||||
raise RuntimeError("session explode")
|
||||
|
||||
monkeypatch.setattr(module, "MCPToolCallSession", _raise_session)
|
||||
_set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
|
||||
res = _run(module.test_mcp("mcp-1"))
|
||||
assert res["code"] == 100
|
||||
assert "session explode" in res["message"]
|
||||
210
test/testcases/restful_api/test_memories_messages.py
Normal file
210
test/testcases/restful_api/test_memories_messages.py
Normal file
@@ -0,0 +1,210 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory_cleanup(rest_client):
|
||||
created_ids: list[str] = []
|
||||
|
||||
def _cleanup():
|
||||
cleanup_errors = []
|
||||
for memory_id in created_ids:
|
||||
delete_res = rest_client.delete(f"/memories/{memory_id}")
|
||||
if delete_res.status_code != 200:
|
||||
cleanup_errors.append((memory_id, delete_res.status_code, delete_res.text))
|
||||
continue
|
||||
delete_payload = delete_res.json()
|
||||
if delete_payload["code"] not in (0, 404):
|
||||
cleanup_errors.append((memory_id, delete_res.status_code, delete_payload))
|
||||
assert not cleanup_errors, f"Memory cleanup failed: {cleanup_errors}"
|
||||
|
||||
yield created_ids
|
||||
_cleanup()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_memory_resource(rest_client, memory_cleanup):
|
||||
def _create(name_prefix: str = "restful_memory") -> str:
|
||||
payload = {
|
||||
"name": f"{name_prefix}_{uuid.uuid4().hex[:8]}",
|
||||
"memory_type": ["raw"],
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI",
|
||||
}
|
||||
res = rest_client.post("/memories", json=payload)
|
||||
assert res.status_code == 200
|
||||
res_payload = res.json()
|
||||
assert res_payload["code"] == 0, res_payload
|
||||
memory_id = res_payload["data"]["id"]
|
||||
memory_cleanup.append(memory_id)
|
||||
return memory_id
|
||||
|
||||
yield _create
|
||||
|
||||
|
||||
def _add_message(rest_client, memory_id: str, user_input: str, agent_response: str) -> None:
|
||||
add_res = rest_client.post(
|
||||
"/messages",
|
||||
json={
|
||||
"memory_id": [memory_id],
|
||||
"agent_id": uuid.uuid4().hex,
|
||||
"session_id": uuid.uuid4().hex,
|
||||
"user_id": uuid.uuid4().hex,
|
||||
"user_input": user_input,
|
||||
"agent_response": agent_response,
|
||||
},
|
||||
)
|
||||
assert add_res.status_code == 200
|
||||
add_payload = add_res.json()
|
||||
assert add_payload["code"] == 0, add_payload
|
||||
|
||||
|
||||
def _wait_for_memory_messages(rest_client, memory_id: str, timeout: float = 10, interval: float = 0.2) -> list[dict]:
|
||||
deadline = time.time() + timeout
|
||||
last_payload = None
|
||||
while time.time() < deadline:
|
||||
res = rest_client.get(f"/memories/{memory_id}")
|
||||
if res.status_code == 200:
|
||||
payload = res.json()
|
||||
last_payload = payload
|
||||
if payload.get("code") == 0:
|
||||
message_list = payload.get("data", {}).get("messages", {}).get("message_list", [])
|
||||
if message_list:
|
||||
return message_list
|
||||
time.sleep(interval)
|
||||
pytest.fail(f"Timed out waiting for memory messages: {last_payload}")
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_memory_crud_cycle(rest_client, create_memory_resource):
|
||||
memory_id = create_memory_resource("restful_memory_crud")
|
||||
|
||||
list_res = rest_client.get("/memories")
|
||||
assert list_res.status_code == 200
|
||||
list_payload = list_res.json()
|
||||
assert list_payload["code"] == 0, list_payload
|
||||
assert any(item["id"] == memory_id for item in list_payload["data"]["memory_list"]), list_payload
|
||||
|
||||
config_res = rest_client.get(f"/memories/{memory_id}/config")
|
||||
assert config_res.status_code == 200
|
||||
config_payload = config_res.json()
|
||||
assert config_payload["code"] == 0, config_payload
|
||||
assert config_payload["data"]["id"] == memory_id, config_payload
|
||||
|
||||
update_res = rest_client.put(
|
||||
f"/memories/{memory_id}",
|
||||
json={"name": f"updated_{uuid.uuid4().hex[:6]}", "permissions": "me"},
|
||||
)
|
||||
assert update_res.status_code == 200
|
||||
update_payload = update_res.json()
|
||||
assert update_payload["code"] == 0, update_payload
|
||||
|
||||
delete_res = rest_client.delete(f"/memories/{memory_id}")
|
||||
assert delete_res.status_code == 200
|
||||
delete_payload = delete_res.json()
|
||||
assert delete_payload["code"] == 0, delete_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_memory_create_missing_required_fields(rest_client):
|
||||
res = rest_client.post("/memories", json={"name": "missing_models", "memory_type": ["raw"]})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 101, payload
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_messages_add_list_recent_content_update_forget(rest_client, create_memory_resource):
|
||||
memory_id = create_memory_resource("restful_message_memory")
|
||||
_add_message(
|
||||
rest_client,
|
||||
memory_id,
|
||||
user_input="what is coriander?",
|
||||
agent_response="coriander can refer to leaves or seeds",
|
||||
)
|
||||
|
||||
message_list = _wait_for_memory_messages(rest_client, memory_id)
|
||||
|
||||
message_id = message_list[0]["message_id"]
|
||||
|
||||
recent_res = rest_client.get("/messages", params={"memory_id": memory_id, "limit": 10})
|
||||
assert recent_res.status_code == 200
|
||||
recent_payload = recent_res.json()
|
||||
assert recent_payload["code"] == 0, recent_payload
|
||||
assert any(item["message_id"] == message_id for item in recent_payload["data"]), recent_payload
|
||||
|
||||
content_res = rest_client.get(f"/messages/{memory_id}:{message_id}/content")
|
||||
assert content_res.status_code == 200
|
||||
content_payload = content_res.json()
|
||||
assert content_payload["code"] == 0, content_payload
|
||||
assert content_payload["data"]["content"], content_payload
|
||||
|
||||
update_res = rest_client.put(f"/messages/{memory_id}:{message_id}", json={"status": False})
|
||||
assert update_res.status_code == 200
|
||||
update_payload = update_res.json()
|
||||
assert update_payload["code"] == 0, update_payload
|
||||
|
||||
forget_res = rest_client.delete(f"/messages/{memory_id}:{message_id}")
|
||||
assert forget_res.status_code == 200
|
||||
forget_payload = forget_res.json()
|
||||
assert forget_payload["code"] == 0, forget_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_message_status_validation_requires_boolean(rest_client, create_memory_resource):
|
||||
memory_id = create_memory_resource("restful_message_status_validation")
|
||||
_add_message(rest_client, memory_id, user_input="hello", agent_response="hello")
|
||||
|
||||
message_id = _wait_for_memory_messages(rest_client, memory_id)[0]["message_id"]
|
||||
|
||||
invalid_update = rest_client.put(f"/messages/{memory_id}:{message_id}", json={"status": "false"})
|
||||
assert invalid_update.status_code == 200
|
||||
invalid_payload = invalid_update.json()
|
||||
assert invalid_payload["code"] == 101, invalid_payload
|
||||
assert "Status must be a boolean." in invalid_payload["message"], invalid_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_messages_recent_requires_memory_ids(rest_client):
|
||||
res = rest_client.get("/messages")
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 101, payload
|
||||
assert "memory_ids is required" in payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_message_search_route_contract(rest_client, create_memory_resource):
|
||||
memory_id = create_memory_resource("restful_message_search")
|
||||
_add_message(
|
||||
rest_client,
|
||||
memory_id,
|
||||
user_input="what is pineapple?",
|
||||
agent_response="pineapple is a tropical fruit",
|
||||
)
|
||||
|
||||
_wait_for_memory_messages(rest_client, memory_id)
|
||||
|
||||
res = rest_client.get("/messages/search", params={"memory_id": memory_id, "query": "pineapple", "top_n": 3})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
assert isinstance(payload["data"], list), payload
|
||||
165
test/testcases/restful_api/test_memory_messages.py
Normal file
165
test/testcases/restful_api/test_memory_messages.py
Normal file
@@ -0,0 +1,165 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _memory_payload(name: str) -> dict:
|
||||
return {
|
||||
"name": name,
|
||||
"memory_type": ["raw"],
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI",
|
||||
}
|
||||
|
||||
|
||||
def _create_memory(rest_client, name: str) -> dict:
|
||||
res = rest_client.post("/memories", json=_memory_payload(name))
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
if payload["code"] == 0:
|
||||
return payload["data"]
|
||||
|
||||
pytest.fail(f"Failed to create memory: {payload}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory_resource(rest_client):
|
||||
memory = _create_memory(rest_client, f"restful_memory_{uuid.uuid4().hex[:8]}")
|
||||
memory_id = memory["id"]
|
||||
try:
|
||||
yield memory
|
||||
finally:
|
||||
delete_res = rest_client.delete(f"/memories/{memory_id}")
|
||||
assert delete_res.status_code == 200, delete_res.text
|
||||
delete_payload = delete_res.json()
|
||||
assert delete_payload["code"] in (0, 404), delete_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_memory_and_message_routes_require_auth(rest_client_noauth):
|
||||
memory_res = rest_client_noauth.get("/memories")
|
||||
assert memory_res.status_code == 401
|
||||
memory_payload = memory_res.json()
|
||||
assert memory_payload["code"] == 401, memory_payload
|
||||
|
||||
msg_list_res = rest_client_noauth.get("/messages")
|
||||
assert msg_list_res.status_code == 401
|
||||
msg_list_payload = msg_list_res.json()
|
||||
assert msg_list_payload["code"] == 401, msg_list_payload
|
||||
|
||||
msg_search_res = rest_client_noauth.get("/messages/search")
|
||||
assert msg_search_res.status_code == 401
|
||||
msg_search_payload = msg_search_res.json()
|
||||
assert msg_search_payload["code"] == 401, msg_search_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_memory_crud_and_config(rest_client):
|
||||
memory = _create_memory(rest_client, f"restful_memory_crud_{uuid.uuid4().hex[:8]}")
|
||||
memory_id = memory["id"]
|
||||
try:
|
||||
config_res = rest_client.get(f"/memories/{memory_id}/config")
|
||||
assert config_res.status_code == 200
|
||||
config_payload = config_res.json()
|
||||
assert config_payload["code"] == 0, config_payload
|
||||
assert config_payload["data"]["id"] == memory_id, config_payload
|
||||
|
||||
list_res = rest_client.get("/memories", params={"keywords": memory["name"]})
|
||||
assert list_res.status_code == 200
|
||||
list_payload = list_res.json()
|
||||
assert list_payload["code"] == 0, list_payload
|
||||
assert any(item["id"] == memory_id for item in list_payload["data"]["memory_list"]), list_payload
|
||||
|
||||
update_res = rest_client.put(f"/memories/{memory_id}", json={"name": "restful_memory_updated"})
|
||||
assert update_res.status_code == 200
|
||||
update_payload = update_res.json()
|
||||
assert update_payload["code"] == 0, update_payload
|
||||
finally:
|
||||
delete_res = rest_client.delete(f"/memories/{memory_id}")
|
||||
assert delete_res.status_code == 200, delete_res.text
|
||||
delete_payload = delete_res.json()
|
||||
assert delete_payload["code"] in (0, 404), delete_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_memory_update_invalid_name(rest_client, memory_resource):
|
||||
memory_id = memory_resource["id"]
|
||||
res = rest_client.put(f"/memories/{memory_id}", json={"name": " "})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 101, payload
|
||||
assert "cannot be empty" in payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_messages_list_and_search_validation_contracts(rest_client, memory_resource):
|
||||
memory_id = memory_resource["id"]
|
||||
|
||||
list_res = rest_client.get("/messages", params={"memory_id": memory_id, "limit": 10})
|
||||
assert list_res.status_code == 200
|
||||
list_payload = list_res.json()
|
||||
assert list_payload["code"] == 0, list_payload
|
||||
assert isinstance(list_payload["data"], list), list_payload
|
||||
|
||||
missing_memory_res = rest_client.get("/messages")
|
||||
assert missing_memory_res.status_code == 200
|
||||
missing_memory_payload = missing_memory_res.json()
|
||||
assert missing_memory_payload["code"] == 101, missing_memory_payload
|
||||
assert "memory_ids is required" in missing_memory_payload["message"], missing_memory_payload
|
||||
|
||||
search_res = rest_client.get("/messages/search", params={"memory_id": memory_id, "query": "coriander"})
|
||||
assert search_res.status_code == 200
|
||||
search_payload = search_res.json()
|
||||
assert search_payload["code"] == 0, search_payload
|
||||
assert isinstance(search_payload["data"], list), search_payload
|
||||
|
||||
search_no_memory = rest_client.get("/messages/search", params={"query": "coriander"})
|
||||
assert search_no_memory.status_code == 200
|
||||
search_no_memory_payload = search_no_memory.json()
|
||||
assert search_no_memory_payload["code"] == 0, search_no_memory_payload
|
||||
assert isinstance(search_no_memory_payload["data"], list), search_no_memory_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_message_update_forget_and_content_error_contracts(rest_client, memory_resource):
|
||||
memory_id = memory_resource["id"]
|
||||
|
||||
invalid_status_res = rest_client.put(
|
||||
f"/messages/{memory_id}:1",
|
||||
json={"status": "false"},
|
||||
)
|
||||
assert invalid_status_res.status_code == 200
|
||||
invalid_status_payload = invalid_status_res.json()
|
||||
assert invalid_status_payload["code"] == 101, invalid_status_payload
|
||||
assert "Status must be a boolean" in invalid_status_payload["message"], invalid_status_payload
|
||||
|
||||
missing_content_res = rest_client.get(f"/messages/{memory_id}:999999/content")
|
||||
assert missing_content_res.status_code == 200
|
||||
missing_content_payload = missing_content_res.json()
|
||||
assert missing_content_payload["code"] == 404, missing_content_payload
|
||||
|
||||
invalid_memory_forget = rest_client.delete("/messages/missing_memory_id:1")
|
||||
assert invalid_memory_forget.status_code == 200
|
||||
invalid_memory_forget_payload = invalid_memory_forget.json()
|
||||
assert invalid_memory_forget_payload["code"] == 404, invalid_memory_forget_payload
|
||||
|
||||
invalid_memory_update = rest_client.put("/messages/missing_memory_id:1", json={"status": False})
|
||||
assert invalid_memory_update.status_code == 200
|
||||
invalid_memory_update_payload = invalid_memory_update.json()
|
||||
assert invalid_memory_update_payload["code"] == 404, invalid_memory_update_payload
|
||||
212
test/testcases/restful_api/test_openai_compatible.py
Normal file
212
test/testcases/restful_api/test_openai_compatible.py
Normal file
@@ -0,0 +1,212 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _sse_events(response_text: str) -> list[str]:
|
||||
return [line[5:] for line in response_text.splitlines() if line.startswith("data:")]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"payload, expected_message",
|
||||
[
|
||||
(
|
||||
{
|
||||
"model": "model",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"extra_body": "invalid_extra_body",
|
||||
},
|
||||
"extra_body must be an object.",
|
||||
),
|
||||
(
|
||||
{
|
||||
"model": "model",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"extra_body": {"reference_metadata": "invalid_reference_metadata"},
|
||||
},
|
||||
"reference_metadata must be an object.",
|
||||
),
|
||||
(
|
||||
{
|
||||
"model": "model",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"extra_body": {"reference_metadata": {"fields": "author"}},
|
||||
},
|
||||
"reference_metadata.fields must be an array.",
|
||||
),
|
||||
(
|
||||
{
|
||||
"model": "model",
|
||||
"messages": [],
|
||||
},
|
||||
"You have to provide messages.",
|
||||
),
|
||||
(
|
||||
{
|
||||
"model": "model",
|
||||
"messages": [{"role": "assistant", "content": "hello"}],
|
||||
},
|
||||
"The last content of this conversation is not from user.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_openai_compatible_validation_payloads(rest_client, create_chat, payload, expected_message):
|
||||
chat_id = create_chat("restful_openai_validation_chat")
|
||||
res = rest_client.post(f"/openai/{chat_id}/chat/completions", json=payload)
|
||||
assert res.status_code == 200
|
||||
data = res.json()
|
||||
assert data["code"] != 0, data
|
||||
assert expected_message in data.get("message", ""), data
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_openai_compatible_metadata_condition_requires_object(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_openai_metadata_condition_chat")
|
||||
res = rest_client.post(
|
||||
f"/openai/{chat_id}/chat/completions",
|
||||
json={
|
||||
"model": "model",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"extra_body": {"metadata_condition": "invalid"},
|
||||
},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 102, payload
|
||||
assert "metadata_condition must be an object." in payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_openai_compatible_invalid_chat(rest_client):
|
||||
res = rest_client.post(
|
||||
"/openai/invalid_chat_id/chat/completions",
|
||||
json={
|
||||
"model": "model",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] != 0, payload
|
||||
assert "don't own the chat" in payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_openai_compatible_nonstream_shape(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_openai_nonstream_chat")
|
||||
res = rest_client.post(
|
||||
f"/openai/{chat_id}/chat/completions",
|
||||
json={
|
||||
"model": "model",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"stream": False,
|
||||
},
|
||||
timeout=60,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
|
||||
assert payload["object"] == "chat.completion", payload
|
||||
assert isinstance(payload["choices"], list) and payload["choices"], payload
|
||||
first_choice = payload["choices"][0]
|
||||
assert first_choice.get("finish_reason") == "stop", payload
|
||||
assert first_choice.get("message", {}).get("role") == "assistant", payload
|
||||
assert "content" in first_choice.get("message", {}), payload
|
||||
|
||||
usage = payload.get("usage", {})
|
||||
assert "prompt_tokens" in usage, usage
|
||||
assert "completion_tokens" in usage, usage
|
||||
assert "total_tokens" in usage, usage
|
||||
assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"], usage
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_openai_compatible_nonstream_with_reference_output_shape(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_openai_reference_chat")
|
||||
res = rest_client.post(
|
||||
f"/openai/{chat_id}/chat/completions",
|
||||
json={
|
||||
"model": "model",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"stream": False,
|
||||
"extra_body": {
|
||||
"reference": True,
|
||||
"reference_metadata": {"include": True, "fields": ["author"]},
|
||||
},
|
||||
},
|
||||
timeout=60,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
choice_msg = payload["choices"][0]["message"]
|
||||
assert "reference" in choice_msg, payload
|
||||
assert isinstance(choice_msg["reference"], list), payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_openai_compatible_stream_shape_and_done_semantics(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_openai_stream_chat")
|
||||
res = rest_client.post(
|
||||
f"/openai/{chat_id}/chat/completions",
|
||||
json={
|
||||
"model": "model",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"stream": True,
|
||||
"extra_body": {"reference": True},
|
||||
},
|
||||
timeout=60,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
content_type = res.headers.get("Content-Type", "")
|
||||
assert "text/event-stream" in content_type, content_type
|
||||
|
||||
events = _sse_events(res.text)
|
||||
assert events, res.text
|
||||
assert events[-1].strip() == "[DONE]", events[-1]
|
||||
|
||||
json_events = [json.loads(evt) for evt in events if evt.strip() != "[DONE]"]
|
||||
assert json_events, events
|
||||
assert any(evt.get("object") == "chat.completion.chunk" for evt in json_events), json_events
|
||||
assert any(evt.get("choices", [{}])[0].get("finish_reason") == "stop" for evt in json_events), json_events
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_openai_compatible_reference_metadata_fields_filter_accepts_array(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_openai_reference_fields_array_chat")
|
||||
res = rest_client.post(
|
||||
f"/openai/{chat_id}/chat/completions",
|
||||
json={
|
||||
"model": "model",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"stream": False,
|
||||
"extra_body": {
|
||||
"reference": True,
|
||||
"reference_metadata": {"include": True, "fields": ["author", "year"]},
|
||||
},
|
||||
},
|
||||
timeout=60,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload.get("choices"), payload
|
||||
choice_msg = payload["choices"][0]["message"]
|
||||
assert "reference" in choice_msg, payload
|
||||
assert isinstance(choice_msg["reference"], list), payload
|
||||
92
test/testcases/restful_api/test_plugin_tools.py
Normal file
92
test/testcases/restful_api/test_plugin_tools.py
Normal file
@@ -0,0 +1,92 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_plugin_tools_requires_auth(rest_client_noauth):
|
||||
res = rest_client_noauth.get("/plugin/tools")
|
||||
assert res.status_code == 401
|
||||
payload = res.json()
|
||||
assert payload["code"] == 401, payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_plugin_tools_contract(rest_client):
|
||||
res = rest_client.get("/plugin/tools")
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
assert isinstance(payload["data"], list), payload
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _load_plugin_module(monkeypatch):
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
|
||||
common_pkg = ModuleType("common")
|
||||
common_pkg.__path__ = [str(repo_root / "common")]
|
||||
monkeypatch.setitem(sys.modules, "common", common_pkg)
|
||||
|
||||
stub_apps = ModuleType("api.apps")
|
||||
stub_apps.login_required = lambda func: func
|
||||
monkeypatch.setitem(sys.modules, "api.apps", stub_apps)
|
||||
|
||||
stub_plugin = ModuleType("agent.plugin")
|
||||
|
||||
class _StubGlobalPluginManager:
|
||||
@staticmethod
|
||||
def get_llm_tools():
|
||||
return []
|
||||
|
||||
stub_plugin.GlobalPluginManager = _StubGlobalPluginManager
|
||||
monkeypatch.setitem(sys.modules, "agent.plugin", stub_plugin)
|
||||
|
||||
module_path = repo_root / "api" / "apps" / "restful_apis" / "plugin_api.py"
|
||||
spec = importlib.util.spec_from_file_location("restful_plugin_api_unit", module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
module.manager = _DummyManager()
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_plugin_tools_metadata_shape_unit(monkeypatch):
|
||||
module = _load_plugin_module(monkeypatch)
|
||||
|
||||
class _DummyTool:
|
||||
def get_metadata(self):
|
||||
return {"name": "dummy", "description": "test"}
|
||||
|
||||
monkeypatch.setattr(module.GlobalPluginManager, "get_llm_tools", staticmethod(lambda: [_DummyTool()]))
|
||||
res = module.llm_tools()
|
||||
assert res["code"] == 0
|
||||
assert isinstance(res["data"], list)
|
||||
assert res["data"][0]["name"] == "dummy"
|
||||
assert res["data"][0]["description"] == "test"
|
||||
109
test/testcases/restful_api/test_retrieval.py
Normal file
109
test/testcases/restful_api/test_retrieval.py
Normal file
@@ -0,0 +1,109 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_dataset_search_rest_endpoint(rest_client, ensure_parsed_document):
|
||||
dataset_id, _ = ensure_parsed_document()
|
||||
res = rest_client.post(
|
||||
f"/datasets/{dataset_id}/search",
|
||||
json={"question": "test TXT file", "top_k": 5},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
assert "chunks" in payload["data"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_multi_dataset_search_rest_endpoint(rest_client, ensure_parsed_document):
|
||||
dataset_id, _ = ensure_parsed_document()
|
||||
res = rest_client.post(
|
||||
"/datasets/search",
|
||||
json={"dataset_ids": [dataset_id], "question": "test TXT file", "top_k": 5},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
assert "chunks" in payload["data"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_multi_dataset_search_with_metadata_filter(rest_client, ensure_parsed_document):
|
||||
dataset_id, document_id = ensure_parsed_document()
|
||||
meta_res = rest_client.patch(
|
||||
f"/datasets/{dataset_id}/documents/metadatas",
|
||||
json={
|
||||
"selector": {"document_ids": [document_id]},
|
||||
"updates": [{"key": "author", "value": "qa_batch2"}],
|
||||
"deletes": [],
|
||||
},
|
||||
)
|
||||
assert meta_res.status_code == 200
|
||||
meta_payload = meta_res.json()
|
||||
assert meta_payload["code"] == 0, meta_payload
|
||||
|
||||
res = rest_client.post(
|
||||
"/datasets/search",
|
||||
json={
|
||||
"dataset_ids": [dataset_id],
|
||||
"question": "test TXT file",
|
||||
"meta_data_filter": {
|
||||
"method": "manual",
|
||||
"logic": "and",
|
||||
"manual": [{"key": "author", "op": "=", "value": "qa_batch2"}],
|
||||
},
|
||||
},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
assert "chunks" in payload["data"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_retrieval_compatibility_endpoint(rest_client, ensure_parsed_document):
|
||||
dataset_id, _ = ensure_parsed_document()
|
||||
# /api/v1/retrieval is SDK compatibility endpoint from api/apps/sdk/doc.py.
|
||||
res = rest_client.post(
|
||||
"/retrieval",
|
||||
json={"dataset_ids": [dataset_id], "question": "test TXT file", "top_k": 5},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
assert "chunks" in payload["data"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_retrieval_compatibility_requires_dataset_ids(rest_client):
|
||||
res = rest_client.post("/retrieval", json={"question": "test"})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 102, payload
|
||||
assert payload["message"] == "`dataset_ids` is required.", payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_retrieval_compatibility_requires_auth(rest_client_noauth):
|
||||
res = rest_client_noauth.post("/retrieval", json={"question": "test", "dataset_ids": ["x"]})
|
||||
assert res.status_code == 401
|
||||
payload = res.json()
|
||||
# token_required preserves legacy payload code/message while returning HTTP 401.
|
||||
assert payload["code"] == 0, payload
|
||||
assert payload["message"] == "`Authorization` can't be empty", payload
|
||||
28
test/testcases/restful_api/test_router_contracts.py
Normal file
28
test/testcases/restful_api/test_router_contracts.py
Normal file
@@ -0,0 +1,28 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pytest
|
||||
from configs import VERSION
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_route_not_found_returns_json(rest_client_noauth):
|
||||
res = rest_client_noauth.get("/__missing_route__")
|
||||
assert res.status_code == 404
|
||||
payload = res.json()
|
||||
assert payload["code"] == 404, payload
|
||||
assert payload["error"] == "Not Found", payload
|
||||
assert payload["message"] == f"Not Found: /api/{VERSION}/__missing_route__", payload
|
||||
155
test/testcases/restful_api/test_searches.py
Normal file
155
test/testcases/restful_api/test_searches.py
Normal file
@@ -0,0 +1,155 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def search_resource(rest_client):
|
||||
name = f"restful_search_{uuid.uuid4().hex[:8]}"
|
||||
create_res = rest_client.post("/searches", json={"name": name, "description": "restful search"})
|
||||
assert create_res.status_code == 200
|
||||
create_payload = create_res.json()
|
||||
assert create_payload["code"] == 0, create_payload
|
||||
search_id = create_payload["data"]["search_id"]
|
||||
|
||||
try:
|
||||
yield search_id
|
||||
finally:
|
||||
delete_res = rest_client.delete(f"/searches/{search_id}")
|
||||
assert delete_res.status_code == 200, delete_res.text
|
||||
delete_payload = delete_res.json()
|
||||
assert delete_payload["code"] in (0, 109), delete_payload
|
||||
|
||||
|
||||
def _sse_events(response_text: str) -> list[str]:
|
||||
return [line[5:] for line in response_text.splitlines() if line.startswith("data:")]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_search_routes_require_auth(rest_client_noauth):
|
||||
create_res = rest_client_noauth.post("/searches", json={"name": "search_noauth"})
|
||||
assert create_res.status_code == 401
|
||||
create_payload = create_res.json()
|
||||
assert create_payload["code"] == 401, create_payload
|
||||
|
||||
list_res = rest_client_noauth.get("/searches")
|
||||
assert list_res.status_code == 401
|
||||
list_payload = list_res.json()
|
||||
assert list_payload["code"] == 401, list_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_search_crud_contract(rest_client, search_resource):
|
||||
search_id = search_resource
|
||||
|
||||
list_res = rest_client.get("/searches")
|
||||
assert list_res.status_code == 200
|
||||
list_payload = list_res.json()
|
||||
assert list_payload["code"] == 0, list_payload
|
||||
assert any(item.get("id") == search_id for item in list_payload["data"]["search_apps"]), list_payload
|
||||
|
||||
detail_res = rest_client.get(f"/searches/{search_id}")
|
||||
assert detail_res.status_code == 200
|
||||
detail_payload = detail_res.json()
|
||||
assert detail_payload["code"] == 0, detail_payload
|
||||
assert detail_payload["data"]["id"] == search_id, detail_payload
|
||||
|
||||
new_name = f"search_updated_{uuid.uuid4().hex[:6]}"
|
||||
update_res = rest_client.put(
|
||||
f"/searches/{search_id}",
|
||||
json={"name": new_name, "search_config": {"top_k": 3}},
|
||||
)
|
||||
assert update_res.status_code == 200
|
||||
update_payload = update_res.json()
|
||||
assert update_payload["code"] == 0, update_payload
|
||||
assert update_payload["data"]["name"] == new_name, update_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_search_create_invalid_name(rest_client):
|
||||
res = rest_client.post("/searches", json={"name": ""})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 102, payload
|
||||
assert "empty" in payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_search_update_invalid_search_id(rest_client):
|
||||
res = rest_client.put(
|
||||
"/searches/invalid_search_id",
|
||||
json={"name": "invalid", "search_config": {}},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 109, payload
|
||||
assert "No authorization" in payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_search_completion_requires_question(rest_client, search_resource):
|
||||
search_id = search_resource
|
||||
|
||||
completion_res = rest_client.post(f"/searches/{search_id}/completion", json={})
|
||||
assert completion_res.status_code == 200
|
||||
completion_payload = completion_res.json()
|
||||
assert completion_payload["code"] == 101, completion_payload
|
||||
assert "required argument are missing: question" in completion_payload["message"], completion_payload
|
||||
|
||||
completions_res = rest_client.post(f"/searches/{search_id}/completions", json={})
|
||||
assert completions_res.status_code == 200
|
||||
completions_payload = completions_res.json()
|
||||
assert completions_payload["code"] == 101, completions_payload
|
||||
assert "required argument are missing: question" in completions_payload["message"], completions_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_search_completion_requires_kb_ids(rest_client, search_resource):
|
||||
search_id = search_resource
|
||||
for path in ("completion", "completions"):
|
||||
res = rest_client.post(
|
||||
f"/searches/{search_id}/{path}",
|
||||
json={"question": "what is coriander?"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 102, payload
|
||||
assert "`kb_ids` is required" in payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_search_completion_sse_shape_when_kb_ids_provided(rest_client, search_resource):
|
||||
search_id = search_resource
|
||||
# Even with kb_ids provided, runtime may return an error event in-stream, but
|
||||
# contract remains SSE with JSON data lines and terminal boolean event.
|
||||
res = rest_client.post(
|
||||
f"/searches/{search_id}/completion",
|
||||
json={"question": "what is coriander?", "kb_ids": ["nonexistent_dataset"]},
|
||||
timeout=60,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
content_type = res.headers.get("Content-Type", "")
|
||||
assert "text/event-stream" in content_type, content_type
|
||||
|
||||
events = _sse_events(res.text)
|
||||
assert events, res.text
|
||||
parsed = [json.loads(evt) for evt in events]
|
||||
assert isinstance(parsed[0], dict), parsed
|
||||
assert parsed[-1].get("data") is True, parsed[-1]
|
||||
219
test/testcases/restful_api/test_sessions.py
Normal file
219
test/testcases/restful_api/test_sessions.py
Normal file
@@ -0,0 +1,219 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _sse_events(response_text: str) -> list[str]:
|
||||
return [line[5:] for line in response_text.splitlines() if line.startswith("data:")]
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_session_crud_cycle(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_session_crud_chat")
|
||||
|
||||
create_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_a"})
|
||||
assert create_res.status_code == 200
|
||||
create_payload = create_res.json()
|
||||
assert create_payload["code"] == 0, create_payload
|
||||
session_id = create_payload["data"]["id"]
|
||||
assert create_payload["data"]["chat_id"] == chat_id, create_payload
|
||||
|
||||
list_res = rest_client.get(f"/chats/{chat_id}/sessions")
|
||||
assert list_res.status_code == 200
|
||||
list_payload = list_res.json()
|
||||
assert list_payload["code"] == 0, list_payload
|
||||
assert any(item["id"] == session_id for item in list_payload["data"]), list_payload
|
||||
|
||||
get_res = rest_client.get(f"/chats/{chat_id}/sessions/{session_id}")
|
||||
assert get_res.status_code == 200
|
||||
get_payload = get_res.json()
|
||||
assert get_payload["code"] == 0, get_payload
|
||||
assert get_payload["data"]["id"] == session_id, get_payload
|
||||
|
||||
patch_res = rest_client.patch(
|
||||
f"/chats/{chat_id}/sessions/{session_id}",
|
||||
json={"name": "session_a_updated"},
|
||||
)
|
||||
assert patch_res.status_code == 200
|
||||
patch_payload = patch_res.json()
|
||||
assert patch_payload["code"] == 0, patch_payload
|
||||
assert patch_payload["data"]["name"] == "session_a_updated", patch_payload
|
||||
|
||||
delete_res = rest_client.delete(f"/chats/{chat_id}/sessions", json={"ids": [session_id]})
|
||||
assert delete_res.status_code == 200
|
||||
delete_payload = delete_res.json()
|
||||
assert delete_payload["code"] == 0, delete_payload
|
||||
|
||||
list_after_delete = rest_client.get(f"/chats/{chat_id}/sessions")
|
||||
assert list_after_delete.status_code == 200
|
||||
list_after_delete_payload = list_after_delete.json()
|
||||
assert list_after_delete_payload["code"] == 0, list_after_delete_payload
|
||||
assert all(item["id"] != session_id for item in list_after_delete_payload["data"]), list_after_delete_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_session_create_name_validation(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_session_name_validation_chat")
|
||||
|
||||
res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": " "})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 102, payload
|
||||
assert "`name` can not be empty." in payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_session_update_blocks_messages_and_reference(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_session_guard_chat")
|
||||
create_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_guard"})
|
||||
assert create_res.status_code == 200
|
||||
create_payload = create_res.json()
|
||||
assert create_payload["code"] == 0, create_payload
|
||||
session_id = create_payload["data"]["id"]
|
||||
|
||||
msg_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json={"messages": []})
|
||||
assert msg_res.status_code == 200
|
||||
msg_payload = msg_res.json()
|
||||
assert msg_payload["code"] == 102, msg_payload
|
||||
assert "`messages` cannot be changed." in msg_payload["message"], msg_payload
|
||||
|
||||
ref_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json={"reference": []})
|
||||
assert ref_res.status_code == 200
|
||||
ref_payload = ref_res.json()
|
||||
assert ref_payload["code"] == 102, ref_payload
|
||||
assert "`reference` cannot be changed." in ref_payload["message"], ref_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chat_recommendation_requires_question(rest_client):
|
||||
res = rest_client.post("/chat/recommendation", json={})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 101, payload
|
||||
assert "required argument are missing: question" in payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_related_questions_compatibility_requires_auth(rest_client_noauth):
|
||||
# /api/v1/searchbots/related_questions is an SDK compatibility endpoint.
|
||||
res = rest_client_noauth.post(
|
||||
"/searchbots/related_questions",
|
||||
json={"question": "ragflow"},
|
||||
headers={"Authorization": "invalid"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 102, payload
|
||||
assert "Authorization is not valid!" in payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chat_completion_nonstream_with_session(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_completion_nonstream_chat")
|
||||
create_session_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_for_completion"})
|
||||
assert create_session_res.status_code == 200
|
||||
create_session_payload = create_session_res.json()
|
||||
assert create_session_payload["code"] == 0, create_session_payload
|
||||
session_id = create_session_payload["data"]["id"]
|
||||
|
||||
completion_res = rest_client.post(
|
||||
"/chat/completions",
|
||||
json={
|
||||
"chat_id": chat_id,
|
||||
"session_id": session_id,
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"stream": False,
|
||||
},
|
||||
timeout=60,
|
||||
)
|
||||
assert completion_res.status_code == 200
|
||||
completion_payload = completion_res.json()
|
||||
assert completion_payload["code"] == 0, completion_payload
|
||||
assert isinstance(completion_payload["data"], dict), completion_payload
|
||||
assert completion_payload["data"]["session_id"] == session_id, completion_payload
|
||||
assert "answer" in completion_payload["data"], completion_payload
|
||||
assert "reference" in completion_payload["data"], completion_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chat_completion_stream_events(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_completion_stream_chat")
|
||||
stream_res = rest_client.post(
|
||||
"/chat/completions",
|
||||
json={
|
||||
"chat_id": chat_id,
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"stream": True,
|
||||
},
|
||||
timeout=60,
|
||||
)
|
||||
assert stream_res.status_code == 200
|
||||
content_type = stream_res.headers.get("Content-Type", "")
|
||||
assert "text/event-stream" in content_type, content_type
|
||||
|
||||
events = _sse_events(stream_res.text)
|
||||
assert events, stream_res.text
|
||||
parsed_events = []
|
||||
for event in events:
|
||||
parsed = json.loads(event)
|
||||
parsed_events.append(parsed)
|
||||
|
||||
assert any(evt.get("code") == 0 and isinstance(evt.get("data"), dict) for evt in parsed_events), parsed_events
|
||||
assert parsed_events[-1].get("data") is True, parsed_events[-1]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_chat_completion_validation_errors(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_completion_validation_chat")
|
||||
|
||||
missing_messages = rest_client.post(
|
||||
"/chat/completions",
|
||||
json={"chat_id": chat_id, "stream": False},
|
||||
)
|
||||
assert missing_messages.status_code == 200
|
||||
missing_messages_payload = missing_messages.json()
|
||||
assert missing_messages_payload["code"] == 101, missing_messages_payload
|
||||
assert "required argument are missing: messages" in missing_messages_payload["message"], missing_messages_payload
|
||||
|
||||
missing_chat_for_session = rest_client.post(
|
||||
"/chat/completions",
|
||||
json={
|
||||
"session_id": "some_session",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
assert missing_chat_for_session.status_code == 200
|
||||
missing_chat_for_session_payload = missing_chat_for_session.json()
|
||||
assert missing_chat_for_session_payload["code"] == 102, missing_chat_for_session_payload
|
||||
assert "`chat_id` is required when `session_id` is provided." in missing_chat_for_session_payload["message"], missing_chat_for_session_payload
|
||||
|
||||
invalid_chat = rest_client.post(
|
||||
"/chat/completions",
|
||||
json={
|
||||
"chat_id": "invalid_chat_id",
|
||||
"session_id": "invalid_session",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
assert invalid_chat.status_code == 200
|
||||
invalid_chat_payload = invalid_chat.json()
|
||||
assert invalid_chat_payload["code"] == 109, invalid_chat_payload
|
||||
assert "No authorization." in invalid_chat_payload["message"], invalid_chat_payload
|
||||
159
test/testcases/restful_api/test_system.py
Normal file
159
test/testcases/restful_api/test_system.py
Normal file
@@ -0,0 +1,159 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_system_ping(rest_client):
|
||||
res = rest_client.get("/system/ping")
|
||||
assert res.status_code == 200
|
||||
assert res.text == "pong"
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_system_version(rest_client):
|
||||
res = rest_client.get("/system/version")
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
assert payload["data"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_system_status_requires_auth(rest_client_noauth):
|
||||
res = rest_client_noauth.get("/system/status")
|
||||
assert res.status_code == 401
|
||||
payload = res.json()
|
||||
assert payload["code"] == 401, payload
|
||||
assert "Unauthorized" in payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_system_status_contract(rest_client):
|
||||
res = rest_client.get("/system/status")
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
for key in ("doc_engine", "storage", "database", "redis"):
|
||||
assert key in payload["data"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_system_config_no_auth_required(rest_client_noauth):
|
||||
res = rest_client_noauth.get("/system/config")
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
assert "registerEnabled" in payload["data"], payload
|
||||
assert "disablePasswordLogin" in payload["data"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_system_healthz_contract(rest_client_noauth):
|
||||
res = rest_client_noauth.get("/system/healthz")
|
||||
assert res.status_code in (200, 500)
|
||||
payload = res.json()
|
||||
assert isinstance(payload, dict), payload
|
||||
assert payload, payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_system_tokens_auth_and_crud(rest_client, rest_client_noauth):
|
||||
unauth_list = rest_client_noauth.get("/system/tokens")
|
||||
assert unauth_list.status_code == 401
|
||||
unauth_list_payload = unauth_list.json()
|
||||
assert unauth_list_payload["code"] == 401, unauth_list_payload
|
||||
|
||||
create_res = rest_client.post("/system/tokens")
|
||||
assert create_res.status_code == 200
|
||||
create_payload = create_res.json()
|
||||
assert create_payload["code"] == 0, create_payload
|
||||
token = create_payload["data"]["token"]
|
||||
|
||||
list_res = rest_client.get("/system/tokens")
|
||||
assert list_res.status_code == 200
|
||||
list_payload = list_res.json()
|
||||
assert list_payload["code"] == 0, list_payload
|
||||
assert isinstance(list_payload["data"], list), list_payload
|
||||
assert any(item.get("token") == token for item in list_payload["data"]), list_payload
|
||||
|
||||
delete_res = rest_client.delete(f"/system/tokens/{token}")
|
||||
assert delete_res.status_code == 200
|
||||
delete_payload = delete_res.json()
|
||||
assert delete_payload["code"] == 0, delete_payload
|
||||
assert delete_payload["data"] is True, delete_payload
|
||||
|
||||
delete_missing = rest_client.delete("/system/tokens/missing_token")
|
||||
assert delete_missing.status_code == 200
|
||||
delete_missing_payload = delete_missing.json()
|
||||
assert delete_missing_payload["code"] == 0, delete_missing_payload
|
||||
assert delete_missing_payload["data"] is True, delete_missing_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_system_stats_auth_and_shape(rest_client, rest_client_noauth):
|
||||
unauth_res = rest_client_noauth.get("/system/stats")
|
||||
assert unauth_res.status_code == 401
|
||||
unauth_payload = unauth_res.json()
|
||||
assert unauth_payload["code"] == 401, unauth_payload
|
||||
|
||||
res = rest_client.get("/system/stats")
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
data = payload["data"]
|
||||
for key in ("pv", "uv", "speed", "tokens", "round", "thumb_up"):
|
||||
assert key in data, payload
|
||||
assert isinstance(data[key], list), payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_system_oceanbase_status_auth_contract(rest_client, rest_client_noauth):
|
||||
unauth = rest_client_noauth.get("/system/oceanbase/status")
|
||||
assert unauth.status_code == 401
|
||||
assert unauth.json()["code"] == 401
|
||||
|
||||
res = rest_client.get("/system/oceanbase/status")
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] in (0, 500), payload
|
||||
assert "data" in payload, payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_system_log_config_routes_auth_and_validation(rest_client, rest_client_noauth):
|
||||
unauth = rest_client_noauth.get("/system/config/log")
|
||||
assert unauth.status_code == 401
|
||||
assert unauth.json()["code"] == 401
|
||||
|
||||
levels = rest_client.get("/system/config/log")
|
||||
assert levels.status_code == 200
|
||||
levels_payload = levels.json()
|
||||
assert levels_payload["code"] == 0, levels_payload
|
||||
assert isinstance(levels_payload["data"], dict), levels_payload
|
||||
|
||||
missing_body = rest_client.put("/system/config/log", json={})
|
||||
assert missing_body.status_code == 200
|
||||
missing_payload = missing_body.json()
|
||||
assert missing_payload["code"] == 102, missing_payload
|
||||
assert "pkg_name and level are required" in missing_payload["message"], missing_payload
|
||||
|
||||
invalid_level = rest_client.put("/system/config/log", json={"pkg_name": "rag", "level": "NOT_A_LEVEL"})
|
||||
assert invalid_level.status_code == 200
|
||||
invalid_payload = invalid_level.json()
|
||||
assert invalid_payload["code"] == 102, invalid_payload
|
||||
assert "Invalid log level" in invalid_payload["message"], invalid_payload
|
||||
48
test/testcases/restful_api/test_task_routes.py
Normal file
48
test/testcases/restful_api/test_task_routes.py
Normal file
@@ -0,0 +1,48 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_task_routes_require_auth(rest_client_noauth):
|
||||
cancel_res = rest_client_noauth.post("/tasks/missing_task/cancel")
|
||||
assert cancel_res.status_code == 401
|
||||
cancel_payload = cancel_res.json()
|
||||
assert cancel_payload["code"] == 401, cancel_payload
|
||||
|
||||
patch_res = rest_client_noauth.patch("/tasks/missing_task", json={"action": "stop"})
|
||||
assert patch_res.status_code == 401
|
||||
patch_payload = patch_res.json()
|
||||
assert patch_payload["code"] == 401, patch_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_patch_task_rejects_unsupported_action(rest_client):
|
||||
res = rest_client.patch("/tasks/missing_task", json={"action": "pause"})
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 101, payload
|
||||
assert "Only 'stop' is supported" in payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_cancel_missing_task_sets_cancel_contract(rest_client):
|
||||
res = rest_client.post("/tasks/missing_task/cancel")
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 0, payload
|
||||
assert payload["data"] is True, payload
|
||||
1382
test/testcases/restful_api/test_user_tenant_routes_unit.py
Normal file
1382
test/testcases/restful_api/test_user_tenant_routes_unit.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user