mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
@@ -51,7 +51,9 @@ from api.db.services.user_service import TenantService, UserService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from api.utils.api_utils import (
|
||||
add_tenant_id_to_kwargs,
|
||||
check_duplicate_ids,
|
||||
get_data_error_result,
|
||||
get_error_data_result,
|
||||
get_json_result,
|
||||
get_result,
|
||||
get_request_json,
|
||||
@@ -441,6 +443,61 @@ def delete_agent_session_item(agent_id, session_id, tenant_id):
|
||||
return get_json_result(data=API4ConversationService.delete_by_id(session_id))
|
||||
|
||||
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"]) # noqa: F821
|
||||
@login_required
|
||||
@add_tenant_id_to_kwargs
|
||||
@_require_canvas_access_async
|
||||
async def delete_agent_session(tenant_id, agent_id):
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = await get_request_json()
|
||||
cvs = await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id)
|
||||
if not cvs:
|
||||
return get_error_data_result(f"You don't own the agent {agent_id}")
|
||||
|
||||
if not req:
|
||||
return get_result()
|
||||
|
||||
ids = req.get("ids")
|
||||
if not ids:
|
||||
if req.get("delete_all") is True:
|
||||
ids = [conv.id for conv in await thread_pool_exec(API4ConversationService.query, dialog_id=agent_id)]
|
||||
if not ids:
|
||||
return get_result()
|
||||
else:
|
||||
return get_result()
|
||||
|
||||
conv_list = ids
|
||||
|
||||
unique_conv_ids, duplicate_messages = check_duplicate_ids(conv_list, "session")
|
||||
conv_list = unique_conv_ids
|
||||
|
||||
for session_id in conv_list:
|
||||
conv = await thread_pool_exec(API4ConversationService.query, id=session_id, dialog_id=agent_id)
|
||||
if not conv:
|
||||
errors.append(f"The agent doesn't own the session {session_id}")
|
||||
continue
|
||||
await thread_pool_exec(API4ConversationService.delete_by_id, session_id)
|
||||
success_count += 1
|
||||
|
||||
if errors:
|
||||
if success_count > 0:
|
||||
return get_result(data={"success_count": success_count, "errors": errors},
|
||||
message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
|
||||
else:
|
||||
return get_error_data_result(message="; ".join(errors))
|
||||
|
||||
if duplicate_messages:
|
||||
if success_count > 0:
|
||||
return get_result(
|
||||
message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
|
||||
data={"success_count": success_count, "errors": duplicate_messages})
|
||||
else:
|
||||
return get_error_data_result(message=";".join(duplicate_messages))
|
||||
|
||||
return get_result()
|
||||
|
||||
|
||||
@manager.route("/agents/download", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
@add_tenant_id_to_kwargs
|
||||
|
||||
@@ -26,7 +26,6 @@ from api.db.db_models import APIToken
|
||||
from api.db.services.api_service import API4ConversationService
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.canvas_service import completion as agent_completion
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from api.db.services.conversation_service import async_iframe_completion as iframe_completion
|
||||
from api.db.services.dialog_service import DialogService, async_ask, gen_mindmap
|
||||
from api.db.services.doc_metadata_service import DocMetadataService
|
||||
@@ -36,9 +35,9 @@ from common.metadata_utils import apply_meta_data_filter
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance
|
||||
from common.misc_utils import get_uuid, thread_pool_exec
|
||||
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_json_result, \
|
||||
get_result, get_request_json, server_error_response, token_required, validate_request
|
||||
from common.misc_utils import thread_pool_exec
|
||||
from api.utils.api_utils import get_error_data_result, get_json_result, \
|
||||
get_result, get_request_json, server_error_response, validate_request
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.prompts.generator import cross_languages, keyword_extraction
|
||||
@@ -59,97 +58,6 @@ def _get_sdk_authorization_token():
|
||||
return token[1]
|
||||
|
||||
|
||||
@token_required
|
||||
async def create_agent_session(tenant_id, agent_id):
|
||||
req = await get_request_json()
|
||||
user_id = req.get("user_id") or request.args.get("user_id", tenant_id)
|
||||
release_mode = bool(req.get("release", request.args.get("release", False)))
|
||||
|
||||
if not await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id):
|
||||
return get_error_data_result("You cannot access the agent.")
|
||||
|
||||
try:
|
||||
cvs, dsl = await thread_pool_exec(UserCanvasService.get_agent_dsl_with_release, agent_id, release_mode, tenant_id)
|
||||
except LookupError:
|
||||
return get_error_data_result("Agent not found.")
|
||||
except PermissionError as e:
|
||||
return get_error_data_result(str(e))
|
||||
|
||||
session_id = get_uuid()
|
||||
canvas = Canvas(dsl, tenant_id, agent_id, canvas_id=cvs.id)
|
||||
canvas.reset()
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
# Get the version title based on release_mode
|
||||
version_title = await thread_pool_exec(UserCanvasVersionService.get_latest_version_title, cvs.id, release_mode=release_mode)
|
||||
conv = {
|
||||
"id": session_id,
|
||||
"dialog_id": cvs.id,
|
||||
"user_id": user_id,
|
||||
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
||||
"source": "agent",
|
||||
"dsl": cvs.dsl,
|
||||
"version_title": version_title
|
||||
}
|
||||
await thread_pool_exec(API4ConversationService.save, **conv)
|
||||
conv["agent_id"] = conv.pop("dialog_id")
|
||||
return get_result(data=conv)
|
||||
|
||||
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
async def delete_agent_session(tenant_id, agent_id):
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = await get_request_json()
|
||||
cvs = await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id)
|
||||
if not cvs:
|
||||
return get_error_data_result(f"You don't own the agent {agent_id}")
|
||||
|
||||
if not req:
|
||||
return get_result()
|
||||
|
||||
ids = req.get("ids")
|
||||
if not ids:
|
||||
if req.get("delete_all") is True:
|
||||
ids = [conv.id for conv in await thread_pool_exec(API4ConversationService.query, dialog_id=agent_id)]
|
||||
if not ids:
|
||||
return get_result()
|
||||
else:
|
||||
return get_result()
|
||||
|
||||
conv_list = ids
|
||||
|
||||
unique_conv_ids, duplicate_messages = check_duplicate_ids(conv_list, "session")
|
||||
conv_list = unique_conv_ids
|
||||
|
||||
for session_id in conv_list:
|
||||
conv = await thread_pool_exec(API4ConversationService.query, id=session_id, dialog_id=agent_id)
|
||||
if not conv:
|
||||
errors.append(f"The agent doesn't own the session {session_id}")
|
||||
continue
|
||||
await thread_pool_exec(API4ConversationService.delete_by_id, session_id)
|
||||
success_count += 1
|
||||
|
||||
if errors:
|
||||
if success_count > 0:
|
||||
return get_result(data={"success_count": success_count, "errors": errors},
|
||||
message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
|
||||
else:
|
||||
return get_error_data_result(message="; ".join(errors))
|
||||
|
||||
if duplicate_messages:
|
||||
if success_count > 0:
|
||||
return get_result(
|
||||
message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
|
||||
data={"success_count": success_count, "errors": duplicate_messages})
|
||||
else:
|
||||
return get_error_data_result(message=";".join(duplicate_messages))
|
||||
|
||||
return get_result()
|
||||
|
||||
|
||||
|
||||
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
|
||||
async def chatbot_completions(dialog_id):
|
||||
req = await get_request_json()
|
||||
|
||||
@@ -44,7 +44,6 @@ from api.utils.api_utils import (
|
||||
get_request_json,
|
||||
get_result,
|
||||
server_error_response,
|
||||
token_required,
|
||||
)
|
||||
from api.utils.pagination_utils import validate_rest_api_page_size
|
||||
from api.utils.image_utils import store_chunk_image
|
||||
@@ -157,7 +156,8 @@ def _enrich_chunks_with_document_metadata(chunks: list[dict], metadata_fields=No
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/chunks", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
@login_required
|
||||
@add_tenant_id_to_kwargs
|
||||
async def parse(tenant_id, dataset_id):
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
|
||||
@@ -212,7 +212,8 @@ async def parse(tenant_id, dataset_id):
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/chunks", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
@login_required
|
||||
@add_tenant_id_to_kwargs
|
||||
async def stop_parsing(tenant_id, dataset_id):
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
|
||||
@@ -252,7 +253,8 @@ async def stop_parsing(tenant_id, dataset_id):
|
||||
|
||||
|
||||
@manager.route("/retrieval", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
@login_required
|
||||
@add_tenant_id_to_kwargs
|
||||
async def retrieval_test(tenant_id):
|
||||
req = await get_request_json()
|
||||
if not req.get("dataset_ids"):
|
||||
|
||||
@@ -108,9 +108,8 @@ def test_retrieval_compatibility_requires_auth(rest_client_noauth):
|
||||
res = rest_client_noauth.post("/retrieval", json={"question": "test", "dataset_ids": ["x"]})
|
||||
assert res.status_code == 401
|
||||
payload = res.json()
|
||||
# token_required preserves legacy payload code/message while returning HTTP 401.
|
||||
assert payload["code"] == 0, payload
|
||||
assert payload["message"] == "`Authorization` can't be empty", payload
|
||||
assert payload["code"] == 401, payload
|
||||
assert payload["message"] == "<Unauthorized '401: Unauthorized'>", payload
|
||||
|
||||
|
||||
@wait_for(20, 1, "Retrieval indexing timeout in RESTful batch 10 tests")
|
||||
@@ -151,14 +150,13 @@ def _retrieval_lacks_chunks(rest_client, dataset_id, question, chunk_ids):
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_retrieval_requires_auth_contract(ensure_parsed_document):
|
||||
dataset_id, _ = ensure_parsed_document()
|
||||
def test_retrieval_requires_auth_contract():
|
||||
for scenario_name, token, expected_code, expected_message in (
|
||||
("missing token", None, 0, "`Authorization` can't be empty"),
|
||||
("invalid token", INVALID_API_TOKEN, 109, "Authentication error: API key is invalid!"),
|
||||
("missing token", None, 401, "<Unauthorized '401: Unauthorized'>"),
|
||||
("invalid token", INVALID_API_TOKEN, 401, "<Unauthorized '401: Unauthorized'>"),
|
||||
):
|
||||
client = RestClient(token=token)
|
||||
res = client.post("/retrieval", json={"question": "chunk", "dataset_ids": [dataset_id]})
|
||||
res = client.post("/retrieval", json={"question": "chunk", "dataset_ids": ["x"]})
|
||||
assert res.status_code == 401, (scenario_name, res.text)
|
||||
payload = res.json()
|
||||
assert payload["code"] == expected_code, (scenario_name, payload)
|
||||
|
||||
@@ -70,7 +70,9 @@ def _load_agent_api(monkeypatch, get_by_id_result):
|
||||
monkeypatch,
|
||||
"api.utils.api_utils",
|
||||
add_tenant_id_to_kwargs=lambda func: func,
|
||||
check_duplicate_ids=lambda ids, _kind="item": (ids, []),
|
||||
get_data_error_result=lambda message="Sorry": {"code": 102, "message": message, "data": None},
|
||||
get_error_data_result=lambda message="Sorry": {"code": 102, "message": message, "data": None},
|
||||
get_json_result=lambda code=0, message="", data=None: {"code": code, "message": message, "data": data},
|
||||
get_result=lambda **kwargs: kwargs,
|
||||
get_request_json=lambda: {},
|
||||
|
||||
Reference in New Issue
Block a user