mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
## Summary

Fixes #14985: clicking the **Thinking** button in a shared/embedded chat returns 401 and bounces the user to the login page, even though the same share page can chat with the agent just fine.

## Root cause

In shared chat, `useGetSharedChatSearchParams` binds `conversationId` to the URL's `shared_id` query param, which is the **beta APIToken**, not the real agent id. That `conversationId` propagates through the component tree:

```tsx
<WorkFlowTimeline canvasId={conversationId}>
  → useFetchMessageTrace(canvasId)
    → GET /api/v1/agents/<sharedId>/logs/<messageId>
```

But `/agents/<agent_id>/logs/<message_id>` is decorated with `@login_required` (`api/apps/restful_apis/agent_api.py:842-846`). The share page only holds the beta token (there is no session JWT), so the request 401s and quart-auth redirects to the login page. The reporter's server log matches exactly:

```
load_user from jwt got exception No b'.' found in value
load_user: No APIToken found for token=ULG10SWG3E...
Unauthorized request (quart_auth)
GET /api/v1/agents/394013f8d42211f0bad6123fa55e8ed9/logs/96fd72e2-... 1.1 401
```

The `394013f8...` segment in the URL is the `shared_id` (beta token), not an actual agent id.

`_load_user` already accepts the regular `APIToken.token` field, but not `APIToken.beta`, by design: beta is a much weaker share-link credential than a personal API key. The sibling endpoints `/agentbots/<id>/completions` and `/agentbots/<id>/inputs` already use the right auth pattern for this scope (beta token via `_get_sdk_authorization_token` → `APIToken.query(beta=token)`). Trace just didn't have a parallel.

## Fix

### Backend (`api/apps/restful_apis/bot_api.py`)

Added a beta-token sibling endpoint:

```
GET /api/v1/agentbots/<shared_id>/logs/<message_id>
```

- Same auth shape as the existing `agentbots` endpoints.
- The `<shared_id>` path segment is a client-supplied label only. The real `agent_id` used to build the Redis key (`<agent_id>-<message_id>-logs`) is taken from `APIToken.dialog_id` on the looked-up token, so the endpoint never trusts client-supplied identifiers for the data lookup.
- Returns the same `{data: ...}` shape as the existing `/agents/<id>/logs/<message_id>` endpoint, so the frontend doesn't need to reshape the response.

### Frontend

- `web/src/utils/api.ts`: added `sharedTrace(sharedId, messageId)` URL builder.
- `web/src/services/agent-service.ts`: added `fetchSharedTrace({ shared_id, message_id })`.
- `web/src/hooks/use-agent-request.ts`: `useFetchMessageTrace` takes an optional `isShare` argument. When set, it calls `fetchSharedTrace`; `isShare` is also folded into the `queryKey` so the two modes never share cached results.
- `web/src/pages/agent/log-sheet/workflow-timeline.tsx`: forwards the already-existing `isShare` prop into the hook.

All other existing call sites of `useFetchMessageTrace` (webhook timeline, pipeline log, dataflow result) pass no `isShare` argument, so it is `undefined` (falsy) and their behavior is unchanged.

## Test plan

- [ ] In the regular Agent UI (logged-in user): open the trace / log sheet for any message and click into "Thinking". The timeline should still load via `/agents/<id>/logs/<msg>`, same as before.
- [ ] From the Agent page, click **Chat in new tab** to open `/chat/share?shared_id=<token>&from=agent`. Send a message, wait for a response, then click **Thinking** on the assistant turn. The trace panel should load instead of redirecting to the login page.
- [ ] Same flow but with the agent embedded in an iframe ("Embed into webpage"): confirm there is no login redirect.
- [ ] In DevTools → Network, confirm the share-chat trace request goes to `/api/v1/agentbots/<sharedId>/logs/<msgId>` and returns 200 with the same JSON shape as the logged-in path.
- [ ] Confirm the chat completions, inputs, and upload flows in the share page still work; they were not touched.
- [ ] Send a bogus / expired beta token to the new endpoint and confirm it returns the standard "Authentication error: API key is invalid!" response (no traceback, no 500).
- [ ] Run `uv run pytest` to make sure no existing tests regress.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [ ] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):

---------

Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
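The backend lookup order described above can be sketched without Quart, the database, or Redis. This is a minimal, hypothetical sketch: `TokenRecord` stands in for an `APIToken` row, a plain list replaces `APIToken.query(beta=token)`, a dict replaces `REDIS_CONN`, and `(status, body)` tuples replace the real response helpers. It illustrates the property the PR relies on: the Redis key is built from the token's `dialog_id`, so the `shared_id` path label never influences which log is read.

```python
import hashlib
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class TokenRecord:
    # Hypothetical stand-in for an APIToken row; only the fields the endpoint reads.
    beta: str
    dialog_id: Optional[str]
    tenant_id: str


def fingerprint(token: str) -> str:
    # Same idea as token_fp: first 16 hex chars of SHA-256, safe to put in logs.
    return hashlib.sha256(token.encode("utf-8")).hexdigest()[:16]


def resolve_trace(shared_id, message_id, auth_token, tokens, store):
    """Return (status, body) for a trace request (hypothetical helper).

    shared_id is accepted to mirror the route, but it is never used for the lookup.
    """
    if not auth_token:
        return "error", "Authorization is not valid!"
    matches = [t for t in tokens if t.beta == auth_token]
    if not matches:
        return "error", "Authentication error: API key is invalid!"
    agent_id = matches[0].dialog_id  # trusted: comes from the token, not the URL
    if not agent_id:
        return "error", "API token is not bound to an agent."
    raw = store.get(f"{agent_id}-{message_id}-logs")
    if not raw:
        return "ok", {}
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8")
    return "ok", json.loads(raw)


tokens = [TokenRecord(beta="beta-abc", dialog_id="agent-1", tenant_id="t-1")]
store = {"agent-1-msg-9-logs": b'[{"component_id": "begin"}]'}

# The path label is ignored: any shared_id resolves through the token's dialog_id.
print(resolve_trace("anything", "msg-9", "beta-abc", tokens, store))
# ('ok', [{'component_id': 'begin'}])
```

Because `shared_id` is ignored, a client that puts another agent's id in the path still only reads logs for the agent its own beta token is bound to.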
560 lines
24 KiB
Python
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import hashlib
import json
import re

import logging

from quart import Response, request

from agent.canvas import Canvas
from api.apps import AUTH_BETA, login_required
from api.db.services.api_service import API4ConversationService
from api.db.services.canvas_service import UserCanvasService
from api.db.services.canvas_service import completion as agent_completion
from api.db.services.conversation_service import async_iframe_completion as iframe_completion
from api.db.services.dialog_service import DialogService, async_ask, gen_mindmap
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from common.metadata_utils import apply_meta_data_filter
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_from_provider_instance
from common.misc_utils import thread_pool_exec
from api.utils.api_utils import get_error_data_result, get_json_result, \
    add_tenant_id_to_kwargs, get_result, get_request_json, server_error_response, validate_request
from rag.app.tag import label_question
from rag.prompts.template import load_prompt
from rag.prompts.generator import cross_languages, keyword_extraction
from common.constants import RetCode, LLMType, StatusEnum
from common import settings
from api.utils.reference_metadata_utils import (
    enrich_chunks_with_document_metadata,
    resolve_reference_metadata_preferences,
)

logger = logging.getLogger(__name__)


@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
async def chatbot_completions(dialog_id, tenant_id=None):
    req = await get_request_json()

    exists, dialog = DialogService.get_by_id(dialog_id)
    if (not exists
            or getattr(dialog, "tenant_id", None) != tenant_id
            or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value):
        logger.warning(
            "Denied chatbot access: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s",
            "no access to this chatbot",
            tenant_id,
            dialog_id,
            req.get("user_id"),
            req.get("session_id"),
        )
        return get_error_data_result(message="Authentication error: no access to this chatbot!")

    if "quote" not in req:
        req["quote"] = False

    def _validate_iframe_access():
        if req.get("session_id"):
            exists, conv = API4ConversationService.get_by_id(req.get("session_id"))
            if not exists:
                raise AssertionError("Session not found!")
            if conv.dialog_id != dialog_id:
                raise AssertionError("Session does not belong to this dialog")
            if tenant_id and conv.user_id and conv.user_id != tenant_id:
                raise AssertionError("Session does not belong to this tenant")

    if req.get("stream", True):
        try:
            _validate_iframe_access()
        except AssertionError:
            logger.warning(
                "Denied chatbot completion stream: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s",
                "no access to this chatbot",
                tenant_id,
                dialog_id,
                req.get("user_id"),
                req.get("session_id"),
            )
            return get_error_data_result(message="Authentication error: no access to this chatbot!")

        resp = Response(iframe_completion(dialog_id, tenant_id=tenant_id, **req), mimetype="text/event-stream")
        resp.headers.add_header("Cache-control", "no-cache")
        resp.headers.add_header("Connection", "keep-alive")
        resp.headers.add_header("X-Accel-Buffering", "no")
        resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
        return resp

    try:
        _validate_iframe_access()
        async for answer in iframe_completion(dialog_id, tenant_id=tenant_id, **req):
            return get_result(data=answer)
    except AssertionError:
        logger.warning(
            "Denied chatbot completion: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s",
            "no access to this chatbot",
            tenant_id,
            dialog_id,
            req.get("user_id"),
            req.get("session_id"),
        )
        return get_error_data_result(message="Authentication error: no access to this chatbot!")

    return None


@manager.route("/chatbots/<dialog_id>/info", methods=["GET"]) # noqa: F821
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
async def chatbots_inputs(dialog_id, tenant_id=None):
    exists, dialog = await thread_pool_exec(DialogService.get_by_id, dialog_id)
    if (not exists
            or getattr(dialog, "tenant_id", None) != tenant_id
            or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value):
        request_args = getattr(request, "args", {}) or {}
        request_user_id = request_args.get("user_id") if hasattr(request_args, "get") else None
        request_session_id = request_args.get("session_id") if hasattr(request_args, "get") else None
        logger.warning(
            "Denied chatbot access: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s",
            "no access to this chatbot",
            tenant_id,
            dialog_id,
            request_user_id,
            request_session_id,
        )
        return get_error_data_result(message="Authentication error: no access to this chatbot!")
    return get_result(
        data={
            "title": dialog.name,
            "avatar": dialog.icon,
            "prologue": dialog.prompt_config.get("prologue", ""),
            "has_tavily_key": bool(dialog.prompt_config.get("tavily_api_key", "").strip()),
            "llm_id": dialog.llm_id or "",
        }
    )


@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
async def agent_bot_completions(agent_id, tenant_id=None):
    req = await get_request_json()

    if req.get("stream", True):
        async def stream():
            try:
                async for answer in agent_completion(tenant_id, agent_id, **req):
                    yield answer
            except Exception as e:
                logging.exception(e)
                error_result = get_error_data_result(message=str(e) or "Unknown error")
                yield "data:" + json.dumps(
                    {
                        "event": "message",
                        "data": {"content": f"Error {error_result['code']}: {error_result['message']}\n\n"},
                        **error_result,
                    },
                    ensure_ascii=False,
                ) + "\n\n"

        resp = Response(stream(), mimetype="text/event-stream")
        resp.headers.add_header("Cache-control", "no-cache")
        resp.headers.add_header("Connection", "keep-alive")
        resp.headers.add_header("X-Accel-Buffering", "no")
        resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
        return resp

    try:
        full_content = ""
        reference = {}
        structured_output = {}
        final_ans = {}
        async for answer in agent_completion(tenant_id, agent_id, **req):
            # agent_completion yields SSE-formatted strings. A single yielded
            # chunk can contain multiple "data:..." frames separated by "\n\n"
            # plus blank or comment lines, so parse line-by-line rather than
            # assuming one frame per chunk.
            if not isinstance(answer, str):
                continue
            for line in answer.splitlines():
                line = line.strip()
                if not line.startswith("data:"):
                    continue
                payload = line[len("data:"):].strip()
                if not payload:
                    continue
                try:
                    ans = json.loads(payload)
                except Exception as e:
                    logging.debug("agent_bot_completions: skipping malformed SSE frame: %s", e)
                    continue
                event = ans.get("event")
                if event == "message":
                    full_content += ans.get("data", {}).get("content", "") or ""
                    if ans.get("data", {}).get("reference"):
                        reference.update(ans["data"]["reference"])
                if event == "node_finished":
                    data = ans.get("data", {})
                    node_out = data.get("outputs") or {}
                    component_id = data.get("component_id")
                    if component_id is not None and "structured" in node_out:
                        structured_output[component_id] = copy.deepcopy(node_out["structured"])
                final_ans = ans

        if not final_ans:
            return get_result(data={})

        if "data" not in final_ans or not isinstance(final_ans["data"], dict):
            final_ans["data"] = {}
        final_ans["data"]["content"] = full_content
        final_ans["data"]["reference"] = reference
        if structured_output:
            final_ans["data"]["structured"] = structured_output
        return get_result(data=final_ans)
    except Exception as e:
        logging.exception(e)
        return get_error_data_result(message=str(e) or "Unknown error")


@manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
async def begin_inputs(agent_id, tenant_id=None):
    e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, agent_id)
    if not e:
        return get_error_data_result(f"Can't find agent by ID: {agent_id}")

    canvas = Canvas(json.dumps(cvs.dsl), tenant_id, canvas_id=cvs.id)
    return get_result(
        data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"),
              "prologue": canvas.get_prologue(), "mode": canvas.get_mode()})


@manager.route("/agentbots/<shared_id>/logs/<message_id>", methods=["GET"]) # noqa: F821
async def agent_bot_logs(shared_id, message_id):
    # Beta-token sibling of /agents/<agent_id>/logs/<message_id>.
    # Used by the shared/embedded chat page's "Thinking" button (fixes #14985).
    # The <shared_id> path segment is just the value the client passed in the
    # URL (it equals the beta token in the share flow); authentication comes
    # from the Authorization header and the real agent_id is read from the
    # looked-up APIToken so we never trust client-supplied identifiers.
    from rag.utils.redis_conn import REDIS_CONN
    # APIToken and _get_sdk_authorization_token are referenced below but are not
    # imported at module level in this file. APIToken is defined in
    # api.db.db_models; the api.apps location of _get_sdk_authorization_token
    # is assumed here.
    from api.db.db_models import APIToken
    from api.apps import _get_sdk_authorization_token

    token = _get_sdk_authorization_token()
    if not token:
        logger.warning(
            "agent_bot_logs: missing Authorization header (shared_id=%s message_id=%s)",
            shared_id, message_id,
        )
        return get_error_data_result(message='Authorization is not valid!')
    # Non-reversible fingerprint of the share token: lets operators correlate
    # auth-failure log lines for the same token without leaking a guessable
    # substring of the secret itself.
    token_fp = hashlib.sha256(token.encode("utf-8")).hexdigest()[:16]
    objs = await thread_pool_exec(APIToken.query, beta=token)
    if not objs:
        logger.warning(
            "agent_bot_logs: invalid beta token (fingerprint=%s shared_id=%s)",
            token_fp, shared_id,
        )
        return get_error_data_result(message='Authentication error: API key is invalid!')

    agent_id = objs[0].dialog_id
    if not agent_id:
        logger.warning(
            "agent_bot_logs: APIToken has no dialog_id (tenant_id=%s fingerprint=%s)",
            objs[0].tenant_id, token_fp,
        )
        return get_error_data_result(message='API token is not bound to an agent.')

    try:
        binary = await thread_pool_exec(REDIS_CONN.get, f"{agent_id}-{message_id}-logs")
        if not binary:
            return get_json_result(data={})
        payload = binary.decode("utf-8") if isinstance(binary, bytes) else binary
        return get_json_result(data=json.loads(payload))
    except Exception as exc:
        logging.exception(exc)
        return server_error_response(exc)


@manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
@validate_request("question", "kb_ids")
async def ask_about_embedded(tenant_id=None):
    req = await get_request_json()
    uid = tenant_id

    search_id = req.get("search_id", "")
    search_config = {}
    if search_id:
        if search_app := await thread_pool_exec(SearchService.get_detail, search_id):
            search_config = search_app.get("search_config", {})

    chat_llm_name = ""
    if not search_config or not search_config.get("chat_id"):
        _, tenant_info = TenantService.get_by_id(uid)
        chat_llm_name = tenant_info.llm_id

    async def stream():
        nonlocal req, uid
        try:
            async for ans in async_ask(req["question"], req["kb_ids"], uid, chat_llm_name=chat_llm_name, search_config=search_config):
                yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
        except Exception as e:
            yield "data:" + json.dumps(
                {"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
                ensure_ascii=False) + "\n\n"
        yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

    resp = Response(stream(), mimetype="text/event-stream")
    resp.headers.add_header("Cache-control", "no-cache")
    resp.headers.add_header("Connection", "keep-alive")
    resp.headers.add_header("X-Accel-Buffering", "no")
    resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
    return resp


@manager.route("/searchbots/retrieval_test", methods=["POST"]) # noqa: F821
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
@validate_request("kb_id", "question")
async def retrieval_test_embedded(tenant_id=None):
    req = await get_request_json()
    page = int(req.get("page", 1))
    size = int(req.get("size", 30))
    question = req["question"]
    kb_ids = req["kb_id"]
    if isinstance(kb_ids, str):
        kb_ids = [kb_ids]
    if not kb_ids:
        return get_json_result(data=False, message='Please specify dataset firstly.',
                               code=RetCode.DATA_ERROR)
    doc_ids = req.get("doc_ids", [])
    similarity_threshold = float(req.get("similarity_threshold", 0.0))
    vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
    use_kg = req.get("use_kg", False)
    top = int(req.get("top_k", 1024))
    if top <= 0:
        return get_error_data_result("`top_k` must be greater than 0")
    langs = req.get("cross_languages", [])
    rerank_id = req.get("rerank_id", "")
    if not tenant_id:
        return get_error_data_result(message="permission denied.")
    search_config = {}

    async def _retrieval():
        nonlocal similarity_threshold, vector_similarity_weight, top, rerank_id, search_config
        local_doc_ids = list(doc_ids) if doc_ids else []
        tenant_ids = []
        _question = question

        meta_data_filter = {}
        chat_mdl = None
        if req.get("search_id", ""):
            detail = await thread_pool_exec(SearchService.get_detail, req.get("search_id", ""))
            if detail:
                search_config = detail.get("search_config", {})
                meta_data_filter = search_config.get("meta_data_filter", {})
                if meta_data_filter.get("method") in ["auto", "semi_auto"]:
                    chat_id = search_config.get("chat_id", "")
                    if chat_id:
                        chat_model_config = await thread_pool_exec(get_model_config_from_provider_instance, tenant_id, LLMType.CHAT, chat_id)
                    else:
                        chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT)
                    chat_mdl = LLMBundle(tenant_id, chat_model_config)
                # Apply search_config settings if not explicitly provided in request
                if not req.get("similarity_threshold"):
                    similarity_threshold = float(search_config.get("similarity_threshold", similarity_threshold))
                if not req.get("vector_similarity_weight"):
                    vector_similarity_weight = float(search_config.get("vector_similarity_weight", vector_similarity_weight))
                if not req.get("top_k"):
                    top = int(search_config.get("top_k", top))
                if not req.get("rerank_id"):
                    rerank_id = search_config.get("rerank_id", "")
        else:
            meta_data_filter = req.get("meta_data_filter") or {}
            if meta_data_filter.get("method") in ["auto", "semi_auto"]:
                chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT)
                chat_mdl = LLMBundle(tenant_id, chat_model_config)

        if meta_data_filter:
            local_doc_ids = await apply_meta_data_filter(
                meta_data_filter,
                None,
                _question,
                chat_mdl,
                local_doc_ids,
                kb_ids=kb_ids,
                metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids),
            )

        tenants = await thread_pool_exec(UserTenantService.query, user_id=tenant_id)
        for kb_id in kb_ids:
            for tenant in tenants:
                if await thread_pool_exec(KnowledgebaseService.query, tenant_id=tenant.tenant_id, id=kb_id):
                    tenant_ids.append(tenant.tenant_id)
                    break
            else:
                return get_json_result(data=False, message="Only owner of dataset authorized for this operation.",
                                       code=RetCode.OPERATING_ERROR)

        e, kb = await thread_pool_exec(KnowledgebaseService.get_by_id, kb_ids[0])
        if not e:
            return get_error_data_result(message="Knowledgebase not found!")

        if langs:
            _question = await cross_languages(kb.tenant_id, None, _question, langs)
        embd_model_config = await thread_pool_exec(get_model_config_from_provider_instance, kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
        embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)

        rerank_mdl = None
        if rerank_id:
            rerank_model_config = await thread_pool_exec(get_model_config_from_provider_instance, tenant_id, LLMType.RERANK, rerank_id)
            rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)

        if req.get("keyword", False):
            default_chat_model = await thread_pool_exec(get_tenant_default_model_by_type, kb.tenant_id, LLMType.CHAT)
            chat_mdl = LLMBundle(kb.tenant_id, default_chat_model)
            _question += await keyword_extraction(chat_mdl, _question)

        labels = label_question(_question, [kb])
        ranks = await settings.retriever.retrieval(
            _question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
            local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
        )
        if use_kg:
            default_chat_model = await thread_pool_exec(get_tenant_default_model_by_type, kb.tenant_id, LLMType.CHAT)
            ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
                                                       LLMBundle(kb.tenant_id, default_chat_model))
            if ck["content_with_weight"]:
                ranks["chunks"].insert(0, ck)

        for c in ranks["chunks"]:
            c.pop("vector", None)

        include_metadata, metadata_fields = _resolve_reference_metadata(req, search_config)
        if include_metadata:
            enrich_chunks_with_document_metadata(ranks["chunks"], metadata_fields)

        ranks["labels"] = labels

        return get_json_result(data=ranks)

    try:
        return await _retrieval()
    except Exception as e:
        if "not_found" in str(e):
            return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
                                   code=RetCode.DATA_ERROR)
        return server_error_response(e)


@manager.route("/searchbots/related_questions", methods=["POST"]) # noqa: F821
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
@validate_request("question")
async def related_questions_embedded(tenant_id=None):
    req = await get_request_json()
    if not tenant_id:
        return get_error_data_result(message="permission denied.")

    search_id = req.get("search_id", "")
    search_config = {}
    if search_id:
        if search_app := await thread_pool_exec(SearchService.get_detail, search_id):
            search_config = search_app.get("search_config", {})

    question = req["question"]

    chat_id = search_config.get("chat_id", "")
    if chat_id:
        chat_model_config = await thread_pool_exec(get_model_config_from_provider_instance, tenant_id, LLMType.CHAT, chat_id)
    else:
        chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT)
    chat_mdl = LLMBundle(tenant_id, chat_model_config)

    gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
    prompt = load_prompt("related_question")
    ans = await chat_mdl.async_chat(
        prompt,
        [
            {
                "role": "user",
                "content": f"""
Keywords: {question}
Related search terms:
    """,
            }
        ],
        gen_conf,
    )
    return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])


@manager.route("/searchbots/detail", methods=["GET"]) # noqa: F821
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
async def detail_share_embedded(tenant_id=None):
    search_id = request.args["search_id"]
    if not tenant_id:
        return get_error_data_result(message="permission denied.")
    try:
        tenants = await thread_pool_exec(UserTenantService.query, user_id=tenant_id)
        for tenant in tenants:
            if await thread_pool_exec(SearchService.query, tenant_id=tenant.tenant_id, id=search_id):
                break
        else:
            return get_json_result(data=False, message="Has no permission for this operation.",
                                   code=RetCode.OPERATING_ERROR)

        search = await thread_pool_exec(SearchService.get_detail, search_id)
        if not search:
            return get_error_data_result(message="Can't find this Search App!")
        return get_json_result(data=search)
    except Exception as e:
        return server_error_response(e)


@manager.route("/searchbots/mindmap", methods=["POST"]) # noqa: F821
@login_required(auth_types=AUTH_BETA)
@add_tenant_id_to_kwargs
@validate_request("question", "kb_ids")
async def mindmap(tenant_id=None):
    req = await get_request_json()

    search_id = req.get("search_id", "")
    search_app = {}
    if search_id:
        # get_detail can return a falsy value for an unknown id (other handlers
        # check for this), so fall back to {} before calling .get() below.
        search_app = await thread_pool_exec(SearchService.get_detail, search_id) or {}

    mind_map = await gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {}))
    if "error" in mind_map:
        return server_error_response(Exception(mind_map["error"]))
    return get_json_result(data=mind_map)


def _resolve_reference_metadata(req, search_config=None):
    return resolve_reference_metadata_preferences(req, search_config)