mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-06 03:18:36 +08:00
### What problem does this PR solve? Currently, RAGFlow's Search and Chat interfaces display only raw vectorized text chunks during retrieval, without contextual information about their source documents. Users cannot see document titles, page numbers, upload dates, or custom metadata fields that would help them understand and trust the retrieved results. This PR introduces an **optional metadata display feature** that enriches retrieved chunks with document-level metadata in both the Search tab and Chatbot interface. **Key improvements:** - **Search results**: Display document metadata as styled badges beneath chunk snippets - **Chat citations**: Show metadata in citation popovers and reference lists for better source context - **LLM context**: Metadata is injected into the LLM prompt to enable more accurate, citation-aware responses - **External API support**: Applications using RAGFlow's SDK retrieval endpoints (`/v1/retrieval`, `/v1/searchbots/retrieval_test`) can opt-in via request parameters - **User control**: Multi-select dropdown UI allows users to choose which metadata fields to display **Implementation approach:** - ✅ Reuses existing `DocMetadataService` infrastructure (no new database tables or indices) - ✅ Settings stored in existing JSON configuration fields (`search_config.reference_metadata`, `prompt_config.reference_metadata`) - ✅ No database migrations required - ✅ Disabled by default (fully opt-in and backward-compatible) - ✅ Dynamic metadata field selection populated from actual document metadata keys - ✅ Fixed critical bug where Python's builtin `set()` was shadowed by a route handler function **Modified endpoints (all backward-compatible):** - `POST /v1/retrieval` (Public SDK) - `POST /v1/searchbots/retrieval_test` (Searchbots) - `POST /v1/chunk/retrieval_test` (UI/Internal) - Chat completions endpoints (via `extra_body.reference_metadata` or `prompt_config`) ### Type of change - [x] New Feature (non-breaking change which adds functionality) 
###Images - <img width="879" height="1275" alt="image" src="https://github.com/user-attachments/assets/95b2d731-31ae-45a1-b081-bf5893f52aeb" /> <br><br> <br><br> <img width="1532" height="362" alt="image" src="https://github.com/user-attachments/assets/9cebc65b-b7a7-459f-b25e-3b13fa9b638e" /> <br><br> <br><br> <img width="2586" height="1320" alt="image" src="https://github.com/user-attachments/assets/2153d493-d899-461f-a7a9-041391e07776" /> --------- Co-authored-by: Cursor Agent <cursoragent@cursor.com> Co-authored-by: Attili-sys <Attili-sys@users.noreply.github.com> Co-authored-by: Ahmad Intisar <ahmadintisar@Ahmads-MacBook-M4-Pro.local>
1619 lines
69 KiB
Python
1619 lines
69 KiB
Python
#
|
||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
#
|
||
import asyncio
|
||
import binascii
|
||
import logging
|
||
import re
|
||
import time
|
||
from copy import deepcopy
|
||
from datetime import datetime
|
||
from functools import partial
|
||
from timeit import default_timer as timer
|
||
from langfuse import Langfuse
|
||
from peewee import fn
|
||
from api.db.services.file_service import FileService
|
||
from common.constants import LLMType, ParserType, StatusEnum
|
||
from api.db.db_models import DB, Dialog
|
||
from api.db.services.common_service import CommonService
|
||
from api.db.services.doc_metadata_service import DocMetadataService
|
||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||
from api.db.services.langfuse_service import TenantLangfuseService
|
||
from api.db.services.llm_service import LLMBundle
|
||
from common.metadata_utils import apply_meta_data_filter
|
||
from api.utils.reference_metadata_utils import (
|
||
enrich_chunks_with_document_metadata,
|
||
resolve_reference_metadata_preferences,
|
||
)
|
||
from api.db.services.tenant_llm_service import TenantLLMService
|
||
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type
|
||
from common.time_utils import current_timestamp, datetime_format
|
||
from common.text_utils import normalize_arabic_digits
|
||
from rag.graphrag.general.mind_map_extractor import MindMapExtractor
|
||
from rag.advanced_rag import DeepResearcher
|
||
from rag.app.tag import label_question
|
||
from rag.nlp.search import index_name
|
||
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
|
||
PROMPT_JINJA_ENV, ASK_SUMMARY
|
||
from common.token_utils import num_tokens_from_string
|
||
from rag.utils.tavily_conn import Tavily
|
||
from common.string_utils import remove_redundant_spaces
|
||
from common import settings
|
||
|
||
def _resolve_reference_metadata(config=None, request_payload=None):
    """Resolve the reference-metadata preferences for a request.

    Thin wrapper around ``resolve_reference_metadata_preferences``.

    NOTE: this module defines ``_resolve_reference_metadata`` a second time
    further down, and that later definition is the one left bound once the
    module has finished importing. The parameter order here is therefore kept
    identical to that definition and to the call in ``async_chat``
    (``_resolve_reference_metadata(prompt_config, request_payload=kwargs)``),
    so the call works no matter which copy is bound. Both parameters keep a
    default, so keyword-style calls remain valid.

    Args:
        config: Stored configuration (the visible caller passes the dialog's
            ``prompt_config``).
        request_payload: Per-request options; a falsy value is replaced by an
            empty dict before delegating.

    Returns:
        Whatever ``resolve_reference_metadata_preferences`` returns; the
        visible caller unpacks it into
        ``(include_reference_metadata, metadata_fields)``.
    """
    return resolve_reference_metadata_preferences(request_payload or {}, config)
|
||
|
||
def _enrich_chunks_with_document_metadata(chunks, metadata_fields=None):
    """Delegate chunk enrichment to ``enrich_chunks_with_document_metadata``.

    The helper's return value is discarded, so this wrapper always returns
    ``None``.
    NOTE(review): callers ignore the result, so the helper presumably updates
    ``chunks`` in place -- confirm against ``api.utils.reference_metadata_utils``.

    Args:
        chunks: Retrieved chunk dicts to enrich.
        metadata_fields: Optional selection of metadata fields, forwarded
            unchanged.
    """
    enrich_chunks_with_document_metadata(chunks, metadata_fields)
    return None
|
||
|
||
def _chunk_kb_id_for_doc(row_dict, kb_ids, doc_id):
    """Pick the knowledge-base id that a result row belongs to.

    Args:
        row_dict: Mapping for one result row; ``kb_id`` and ``kb_id_kwd`` are
            consulted when the knowledge base cannot be inferred from
            ``kb_ids``.
        kb_ids: Knowledge-base ids in scope (may be ``None`` or empty).
        doc_id: Accepted for interface compatibility; not used.

    Returns:
        ``kb_ids[0]`` when exactly one id is in scope, otherwise the row's
        ``kb_id`` or, if that is falsy, its ``kb_id_kwd`` (which may be
        ``None``).
    """
    scoped_ids = kb_ids or []
    if len(scoped_ids) != 1:
        # Ambiguous (or missing) scope: trust what the row itself reports.
        return row_dict.get("kb_id") or row_dict.get("kb_id_kwd")
    return kb_ids[0]
|
||
|
||
def _normalize_internet_flag(value):
    """Coerce a loosely-typed "internet" switch into ``True``/``False``/``None``.

    Args:
        value: A bool, the numbers 0/1, or a string such as ``"true"``,
            ``"off"`` or ``""`` (case and surrounding whitespace ignored).

    Returns:
        ``True`` or ``False`` for recognised values, ``None`` for anything
        that cannot be interpreted (other numbers, unknown strings, other
        types).
    """
    if isinstance(value, bool):
        return value

    if isinstance(value, str):
        keyword_to_flag = {
            "true": True,
            "1": True,
            "yes": True,
            "on": True,
            "false": False,
            "0": False,
            "no": False,
            "off": False,
            "": False,
        }
        # Unknown keywords fall through to None via dict.get.
        return keyword_to_flag.get(value.strip().lower())

    if isinstance(value, (int, float)) and value in (0, 1):
        return bool(value)

    return None
|
||
|
||
|
||
def _should_use_web_search(prompt_config, internet=None):
    """Decide whether web search should be used for this request.

    Web search requires both a truthy ``tavily_api_key`` in ``prompt_config``
    and an ``internet`` flag that normalises to exactly ``True``; a missing or
    unrecognised flag disables it.

    Args:
        prompt_config: Dialog prompt configuration mapping.
        internet: Raw per-request flag, interpreted by
            ``_normalize_internet_flag``.

    Returns:
        bool: ``True`` only when both conditions hold.
    """
    has_api_key = bool(prompt_config.get("tavily_api_key"))
    # Short-circuit keeps the flag from being evaluated when no key is set.
    return has_api_key and _normalize_internet_flag(internet) is True
|
||
|
||
|
||
def _resolve_reference_metadata(config, request_payload=None):
    """Resolve the reference-metadata preferences for a request.

    Thin wrapper around ``resolve_reference_metadata_preferences`` that
    guarantees the helper receives a dict as its first argument.

    Args:
        config: Stored configuration (the visible caller passes the dialog's
            ``prompt_config``).
        request_payload: Per-request options; any falsy value is replaced by
            an empty dict.

    Returns:
        Whatever ``resolve_reference_metadata_preferences`` returns; the
        visible caller unpacks it into
        ``(include_reference_metadata, metadata_fields)``.
    """
    payload = request_payload if request_payload else {}
    return resolve_reference_metadata_preferences(payload, config)
|
||
|
||
|
||
def _enrich_chunks_with_document_metadata(chunks, metadata_fields=None):
    """Delegate chunk enrichment to ``enrich_chunks_with_document_metadata``.

    The helper's return value is discarded, so this wrapper always returns
    ``None``.
    NOTE(review): callers ignore the result, so the helper presumably updates
    ``chunks`` in place -- confirm against ``api.utils.reference_metadata_utils``.

    Args:
        chunks: Retrieved chunk dicts to enrich.
        metadata_fields: Optional selection of metadata fields, forwarded
            unchanged.
    """
    enrich_chunks_with_document_metadata(chunks, metadata_fields)
    return None
|
||
|
||
|
||
|
||
class DialogService(CommonService):
    """Query and persistence helpers for ``Dialog`` records."""

    model = Dialog

    @classmethod
    def save(cls, **kwargs):
        """Insert a new record built from ``kwargs``.

        The insert is forced (``force_insert=True``) rather than letting the
        ORM choose between insert and update.

        Args:
            **kwargs: Record field values as keyword arguments.

        Returns:
            The value returned by ``Model.save(force_insert=True)``. Per
            peewee's documentation this is the number of rows modified, not
            the model instance.
        """
        sample_obj = cls.model(**kwargs).save(force_insert=True)
        return sample_obj

    @classmethod
    def update_many_by_id(cls, data_list):
        """Update several records, each identified by its ``id``.

        All updates run inside a single ``DB.atomic()`` block. Each dict in
        ``data_list`` is modified in place: ``update_time`` and
        ``update_date`` are stamped before the update is executed.

        Args:
            data_list (list): Dictionaries of field values to write. Each
                dictionary must include an ``id`` key.
        """
        with DB.atomic():
            for data in data_list:
                data["update_time"] = current_timestamp()
                data["update_date"] = datetime_format(datetime.now())
                cls.model.update(data).where(cls.model.id == data["id"]).execute()

    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
        """Return one page of a tenant's valid dialogs as dicts.

        Args:
            tenant_id: Owner tenant; only its dialogs are returned.
            page_number: Page index passed to ``paginate``.
            items_per_page: Page size passed to ``paginate``.
            orderby: Field name resolved through ``cls.model.getter_by``.
            desc: Sort descending when truthy, ascending otherwise.
            id: Optional exact dialog id filter (skipped when falsy).
            name: Optional exact dialog name filter (skipped when falsy).

        Returns:
            list[dict]: The rows of the requested page.
        """
        chats = cls.model.select()
        if id:
            chats = chats.where(cls.model.id == id)
        if name:
            chats = chats.where(cls.model.name == name)
        chats = chats.where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value))
        if desc:
            chats = chats.order_by(cls.model.getter_by(orderby).desc())
        else:
            chats = chats.order_by(cls.model.getter_by(orderby).asc())

        chats = chats.paginate(page_number, items_per_page)

        return list(chats.dicts())

    @classmethod
    @DB.connection_context()
    def get_by_tenant_ids(
        cls,
        joined_tenant_ids,
        user_id,
        page_number,
        items_per_page,
        orderby,
        desc,
        keywords,
        id=None,
        name=None,
    ):
        """List valid dialogs visible to a user, joined with owner info.

        A dialog is included when its ``tenant_id`` is in
        ``joined_tenant_ids`` or equals ``user_id``, and its status is valid.

        Args:
            joined_tenant_ids: Tenant ids the user has joined.
            user_id: The requesting user's id.
            page_number: Page index; pagination is applied only when both
                this and ``items_per_page`` are truthy.
            items_per_page: Page size.
            orderby: Field name resolved through ``cls.model.getter_by``.
            desc: Sort descending when truthy, ascending otherwise.
            keywords: Optional case-insensitive substring filter on the name.
            id: Optional exact dialog id filter.
            name: Optional exact dialog name filter.

        Returns:
            tuple[list[dict], int]: The (possibly paginated) rows and the
            total number of matching rows counted before pagination.
        """
        from api.db.db_models import User

        fields = [
            cls.model.id,
            cls.model.tenant_id,
            cls.model.name,
            cls.model.description,
            cls.model.language,
            cls.model.llm_id,
            cls.model.llm_setting,
            cls.model.prompt_type,
            cls.model.prompt_config,
            cls.model.similarity_threshold,
            cls.model.vector_similarity_weight,
            cls.model.top_n,
            cls.model.top_k,
            cls.model.do_refer,
            cls.model.rerank_id,
            cls.model.kb_ids,
            cls.model.icon,
            cls.model.status,
            User.nickname,
            User.avatar.alias("tenant_avatar"),
            cls.model.update_time,
            cls.model.create_time,
        ]
        dialogs = (
            cls.model.select(*fields)
            .join(User, on=(cls.model.tenant_id == User.id))
            .where(
                (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id))
                & (cls.model.status == StatusEnum.VALID.value),
            )
        )
        if id:
            dialogs = dialogs.where(cls.model.id == id)
        if name:
            dialogs = dialogs.where(cls.model.name == name)
        if keywords:
            dialogs = dialogs.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
        if desc:
            dialogs = dialogs.order_by(cls.model.getter_by(orderby).desc())
        else:
            dialogs = dialogs.order_by(cls.model.getter_by(orderby).asc())

        # Count before paginating so callers get the total number of matches.
        count = dialogs.count()

        if page_number and items_per_page:
            dialogs = dialogs.paginate(page_number, items_per_page)

        return list(dialogs.dicts()), count

    @classmethod
    @DB.connection_context()
    def get_all_dialogs_by_tenant_id(cls, tenant_id):
        """Return the ids of every dialog owned by ``tenant_id``.

        Rows are fetched in batches of 100, ordered by ``create_time``
        ascending.

        Args:
            tenant_id: Owner tenant id.

        Returns:
            list[dict]: One ``{"id": ...}`` dict per dialog.
        """
        fields = [cls.model.id]
        dialogs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
        # Query-builder methods return a new query, so the result must be
        # kept. Without the assignment the ORDER BY is dropped and
        # offset/limit paging has no stable order.
        dialogs = dialogs.order_by(cls.model.create_time.asc())
        offset, limit = 0, 100
        res = []
        while True:
            d_batch = dialogs.offset(offset).limit(limit)
            _temp = list(d_batch.dicts())
            if not _temp:
                break
            res.extend(_temp)
            offset += limit
        return res

    @classmethod
    @DB.connection_context()
    def get_null_tenant_llm_id_row(cls):
        """Return dialogs whose ``tenant_llm_id`` is NULL.

        Returns:
            list: Model rows with only ``id``, ``tenant_id`` and ``llm_id``
            selected.
        """
        fields = [
            cls.model.id,
            cls.model.tenant_id,
            cls.model.llm_id
        ]
        objs = cls.model.select(*fields).where(cls.model.tenant_llm_id.is_null())
        return list(objs)

    @classmethod
    @DB.connection_context()
    def get_null_tenant_rerank_id_row(cls):
        """Return dialogs whose ``tenant_rerank_id`` is NULL.

        Returns:
            list: Model rows with only ``id``, ``tenant_id`` and
            ``rerank_id`` selected.
        """
        fields = [
            cls.model.id,
            cls.model.tenant_id,
            cls.model.rerank_id
        ]
        objs = cls.model.select(*fields).where(cls.model.tenant_rerank_id.is_null())
        return list(objs)
|
||
|
||
|
||
async def async_chat_solo(dialog, messages, stream=True):
    """Answer with the dialog's LLM alone, without any retrieval step.

    Files attached to the last message are split into text (appended to the
    last outgoing message) and images (sent as multimodal content for
    ``chat`` models, or passed through the ``images`` keyword otherwise).

    Args:
        dialog: Dialog record providing ``llm_id``, ``tenant_llm_id``,
            ``tenant_id``, ``prompt_config`` and ``llm_setting``.
        messages: Conversation history; ``system`` messages are dropped and
            ``##N$$`` markers are stripped from the content.
        stream: Stream incremental deltas when truthy, otherwise yield a
            single complete answer.

    Yields:
        dict: Answer payloads with ``answer``, ``reference``,
        ``audio_binary``, ``prompt`` and ``created_at``. Streaming payloads
        also carry ``final: False`` and, for think markers,
        ``start_to_think`` / ``end_to_think``.
    """
    llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id)
    is_chat_model = llm_type == "chat"

    attachments = ""
    image_attachments = []
    image_files = []
    last_message = messages[-1]
    if "files" in last_message:
        if is_chat_model:
            text_attachments, image_attachments = split_file_attachments(last_message["files"])
        else:
            text_attachments, image_files = split_file_attachments(last_message["files"], raw=True)
        attachments = "\n\n".join(text_attachments)

    # Model resolution order: explicit name, then tenant model id, then the
    # tenant's default chat model.
    if dialog.llm_id:
        model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
    elif dialog.tenant_llm_id:
        model_config = get_model_config_by_id(dialog.tenant_llm_id)
    else:
        model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)

    chat_mdl = LLMBundle(dialog.tenant_id, model_config)
    factory = model_config.get("llm_factory", "") if model_config else ""

    prompt_config = dialog.prompt_config
    tts_mdl = None
    if prompt_config.get("tts"):
        default_tts_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS)
        tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model)

    msg = []
    for m in messages:
        if m["role"] == "system":
            continue
        msg.append({"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])})

    if attachments and msg:
        msg[-1]["content"] += attachments
    if is_chat_model and image_attachments:
        convert_last_user_msg_to_multimodal(msg, image_attachments, factory)

    system_prompt = prompt_config.get("system", "")
    # Non-chat models receive the raw image files through the ``images`` keyword.
    image_kwargs = {} if is_chat_model else {"images": image_files}

    if not stream:
        answer = await chat_mdl.async_chat(system_prompt, msg, dialog.llm_setting, **image_kwargs)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
        return

    stream_iter = chat_mdl.async_chat_streamly_delta(system_prompt, msg, dialog.llm_setting, **image_kwargs)
    async for kind, value, state in _stream_with_think_delta(stream_iter):
        if kind == "marker":
            flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
            yield {"answer": "", "reference": {}, "audio_binary": None, "prompt": "", "created_at": time.time(), "final": False, **flags}
            continue
        yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "prompt": "", "created_at": time.time(), "final": False}
|
||
|
||
|
||
def get_models(dialog):
    """Instantiate the model bundles a dialog needs.

    Args:
        dialog: Dialog record providing ``kb_ids``, ``tenant_id``,
            ``llm_id``, ``tenant_llm_id``, ``rerank_id`` and
            ``prompt_config``.

    Returns:
        tuple: ``(kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl)``. The
        embedding, rerank and TTS bundles are ``None`` when not applicable.

    Raises:
        Exception: If the knowledge bases use more than one embedding model.
        LookupError: If the embedding bundle turns out falsy.
    """
    embd_mdl = rerank_mdl = tts_mdl = None

    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
    embedding_list = list({kb.embd_id for kb in kbs})
    if len(embedding_list) > 1:
        raise Exception("**ERROR**: Knowledge bases use different embedding models.")

    if embedding_list:
        # The embedding model is resolved under the first knowledge base's tenant.
        embd_owner_tenant_id = kbs[0].tenant_id
        embd_model_config = get_model_config_by_type_and_name(embd_owner_tenant_id, LLMType.EMBEDDING, embedding_list[0])
        embd_mdl = LLMBundle(embd_owner_tenant_id, embd_model_config)
        if not embd_mdl:
            raise LookupError("Embedding model(%s) not found" % embedding_list[0])

    # Chat model resolution order: explicit name, tenant model id, tenant default.
    if dialog.llm_id:
        chat_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
    elif dialog.tenant_llm_id:
        chat_model_config = get_model_config_by_id(dialog.tenant_llm_id)
    else:
        chat_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
    chat_mdl = LLMBundle(dialog.tenant_id, chat_model_config)

    if dialog.rerank_id:
        rerank_model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
        rerank_mdl = LLMBundle(dialog.tenant_id, rerank_model_config)

    if dialog.prompt_config.get("tts"):
        default_tts_model_config = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.TTS)
        tts_mdl = LLMBundle(dialog.tenant_id, default_tts_model_config)

    return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl
|
||
|
||
|
||
def split_file_attachments(files: list[dict] | None, raw: bool = False) -> tuple[list[str], list[str] | list[dict]]:
|
||
if not files:
|
||
return [], []
|
||
|
||
text_attachments = []
|
||
if raw:
|
||
file_contents, image_files = FileService.get_files(files, raw=True)
|
||
for content in file_contents:
|
||
if not isinstance(content, str):
|
||
content = str(content)
|
||
text_attachments.append(content)
|
||
return text_attachments, image_files
|
||
|
||
image_attachments = []
|
||
for content in FileService.get_files(files, raw=False):
|
||
if not isinstance(content, str):
|
||
content = str(content)
|
||
if content.strip().startswith("data:"):
|
||
image_attachments.append(content.strip())
|
||
continue
|
||
text_attachments.append(content)
|
||
return text_attachments, image_attachments
|
||
|
||
|
||
# Fast path: "data:<mime>;base64,<payload>" with a standard-alphabet payload.
_DATA_URI_RE = re.compile(r"^data:(?P<mime>[^;]+);base64,(?P<b64>[A-Za-z0-9+/=\s]+)$")


def _parse_data_uri_or_b64(s: str, default_mime: str = "image/png") -> tuple[str, str]:
    """Split a base64 data URI (or bare base64 string) into MIME type and payload.

    Args:
        s: Either a ``data:<mime>[;param=value...];base64,<payload>`` URI or a
            bare base64 string. ``None`` is treated as an empty string.
        default_mime: MIME type reported when the input carries none.

    Returns:
        ``(mime, payload)``. For bare base64 input the payload is the stripped
        input and the MIME type is ``default_mime``.
    """
    s = (s or "").strip()
    match = _DATA_URI_RE.match(s)
    if match:
        return match.group("mime").strip(), match.group("b64").strip()

    # The strict pattern rejects data URIs that carry media-type parameters
    # (e.g. ";name=a.jpg" or ";charset=utf-8" before ";base64,"), an empty
    # media type, or a URL-safe base64 alphabet. Parse the header by hand so
    # the whole URI is not handed back as if it were the payload.
    header, separator, payload = s.partition(";base64,")
    if separator and header.startswith("data:"):
        mime = header[len("data:"):].split(";", 1)[0].strip()
        return mime or default_mime, payload.strip()

    return default_mime, s
|
||
|
||
|
||
def _normalize_text_from_content(content) -> str:
    """Flatten message content into plain text.

    Args:
        content: ``None``, a string, a list of content blocks, or any other
            object.

    Returns:
        ``""`` for ``None``; the string itself for strings; for lists, the
        text of every dict block joined with newlines and stripped; otherwise
        ``str(content)``. Blocks typed ``text``/``input_text`` contribute
        their truthy ``text`` value; other dict blocks contribute ``text``
        only when it is a str, int or float. Non-dict list items are ignored.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    collected = []
    for block in content:
        if not isinstance(block, dict):
            continue
        if block.get("type") in {"text", "input_text"}:
            value = block.get("text")
            if value:
                collected.append(str(value))
            continue
        if "text" in block and isinstance(block.get("text"), (str, int, float)):
            collected.append(str(block["text"]))
    return "\n".join(collected).strip()
|
||
|
||
|
||
def convert_last_user_msg_to_multimodal(msg: list[dict], image_data_uris: list[str], factory: str) -> None:
    """Rewrite the most recent user message so it carries the given images.

    The message list is modified in place. Only the last message whose role
    is ``user`` is touched; nothing happens when there is no such message or
    when either argument is empty.

    Args:
        msg: Chat messages (dicts with ``role`` and ``content``).
        image_data_uris: Images as data URIs or bare base64 strings.
        factory: Provider name, compared case-insensitively after stripping.
            ``gemini`` produces ``text``/``inline_data`` parts, ``anthropic``
            produces ``text``/``image`` blocks, and anything else produces
            ``text``/``image_url`` blocks.
    """
    if not msg or not image_data_uris:
        return

    provider = (factory or "").strip().lower()

    # Walk backwards to find the latest user turn.
    target = None
    for candidate in reversed(msg):
        if candidate.get("role") == "user":
            target = candidate
            break
    if target is None:
        return

    original_content = target.get("content", "")
    text = _normalize_text_from_content(original_content)

    if provider == "gemini":
        parts = [{"text": text}] if text else []
        for image in image_data_uris:
            mime, b64 = _parse_data_uri_or_b64(str(image), default_mime="image/png")
            parts.append({"inline_data": {"mime_type": mime, "data": b64}})
        target["content"] = parts
        return

    if provider == "anthropic":
        blocks = [{"type": "text", "text": text}] if text else []
        for image in image_data_uris:
            mime, b64 = _parse_data_uri_or_b64(str(image), default_mime="image/png")
            blocks.append(
                {
                    "type": "image",
                    "source": {"type": "base64", "media_type": mime, "data": b64},
                }
            )
        target["content"] = blocks
        return

    # Default format: keep existing list content (copied) or wrap the text,
    # then append one image_url block per image.
    if isinstance(original_content, list):
        multimodal_content = deepcopy(original_content)
    else:
        multimodal_content = []
        text_content = "" if original_content is None else str(original_content)
        if text_content:
            multimodal_content.append({"type": "text", "text": text_content})

    for data_uri in image_data_uris:
        image_url = data_uri if isinstance(data_uri, str) else str(data_uri)
        if not image_url.startswith("data:"):
            image_url = f"data:image/png;base64,{image_url}"
        multimodal_content.append({"type": "image_url", "image_url": {"url": image_url}})

    target["content"] = multimodal_content
    return
|
||
|
||
|
||
# Malformed citation shapes that repair_bad_citation_formats rewrites to
# "[ID:n]". Every pattern captures the numeric index in group 1.
BAD_CITATION_PATTERNS = [
    re.compile(r"\(\s*ID\s*[: ]*\s*(\d+)\s*\)"),  # (ID: 12)
    re.compile(r"\[\s*ID\s*[: ]*\s*(\d+)\s*\]"),  # [ID: 12]
    re.compile(r"【\s*ID\s*[: ]*\s*(\d+)\s*】"),  # 【ID: 12】
    re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE),  # ref12, REF 12
]
# Well-formed marker: "[12]" or "[ID:12]". Besides ASCII digits the character
# class accepts Arabic-Indic (U+0660-U+0669) and Extended Arabic-Indic
# (U+06F0-U+06F9) digits.
CITATION_MARKER_PATTERN = re.compile(r"\[(?:ID:)?([0-9\u0660-\u0669\u06F0-\u06F9]+)\]")
|
||
|
||
|
||
def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
    """Rewrite malformed citation markers in ``answer`` to ``[ID:n]``.

    Each pattern in ``BAD_CITATION_PATTERNS`` is applied in turn. Matching is
    done on a digit-normalised copy of the answer (see
    ``normalize_arabic_digits``), while the replacement text is sliced from
    the original answer so the digits are kept as written.
    NOTE(review): this relies on normalisation preserving character offsets
    -- confirm against ``common.text_utils``.

    Args:
        answer: Model output that may contain malformed citations.
        kbinfos: Retrieval result; only ``len(kbinfos["chunks"])`` is used,
            as the exclusive upper bound for valid indices.
        idx: Set of cited chunk indices; valid indices found here are added
            to it in place.

    Returns:
        tuple: ``(repaired_answer, idx)``. Markers whose index is out of
        range are left untouched.
    """
    chunk_total = len(kbinfos["chunks"])
    normalized_answer = normalize_arabic_digits(answer) or ""

    for pattern in BAD_CITATION_PATTERNS:
        hits = list(pattern.finditer(normalized_answer))
        if not hits:
            continue

        pieces = []
        cursor = 0
        for hit in hits:
            start, end = hit.span()
            pieces.append(answer[cursor:start])
            cursor = end
            try:
                chunk_no = int(hit.group(1))
            except Exception:
                # Unparseable index: keep the original text as is.
                pieces.append(answer[start:end])
                continue

            if 0 <= chunk_no < chunk_total:
                idx.add(chunk_no)
                digit_start, digit_end = hit.span(1)
                pieces.append(f"[ID:{answer[digit_start:digit_end]}]")
            else:
                pieces.append(answer[start:end])

        pieces.append(answer[cursor:])
        answer = "".join(pieces)
        # Re-normalise so the next pattern sees the rewritten text.
        normalized_answer = normalize_arabic_digits(answer) or ""

    return answer, idx
|
||
|
||
|
||
async def async_chat(dialog, messages, stream=True, **kwargs):
|
||
logging.debug("Begin async_chat")
|
||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||
use_web_search = _should_use_web_search(dialog.prompt_config, kwargs.get("internet"))
|
||
logging.debug("web_search kb=%s tavily=%s internet=%r enabled=%s", bool(dialog.kb_ids), bool(dialog.prompt_config.get("tavily_api_key")), kwargs.get("internet"), use_web_search)
|
||
if not dialog.kb_ids and not use_web_search:
|
||
async for ans in async_chat_solo(dialog, messages, stream):
|
||
yield ans
|
||
return
|
||
|
||
chat_start_ts = timer()
|
||
llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id)
|
||
if llm_type == "image2text":
|
||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||
else:
|
||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||
|
||
factory = llm_model_config.get("llm_factory", "") if llm_model_config else ""
|
||
max_tokens = llm_model_config.get("max_tokens", 8192)
|
||
|
||
check_llm_ts = timer()
|
||
|
||
langfuse_tracer = None
|
||
trace_context = {}
|
||
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=dialog.tenant_id)
|
||
if langfuse_keys:
|
||
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
|
||
try:
|
||
if langfuse.auth_check():
|
||
langfuse_tracer = langfuse
|
||
trace_id = langfuse_tracer.create_trace_id()
|
||
trace_context = {"trace_id": trace_id}
|
||
except Exception:
|
||
# Skip langfuse tracing if connection fails
|
||
pass
|
||
|
||
check_langfuse_tracer_ts = timer()
|
||
kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
|
||
toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
|
||
if toolcall_session and tools:
|
||
chat_mdl.bind_tools(toolcall_session, tools)
|
||
bind_models_ts = timer()
|
||
|
||
retriever = settings.retriever
|
||
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
||
attachments = None
|
||
if "doc_ids" in kwargs:
|
||
attachments = [doc_id for doc_id in kwargs["doc_ids"].split(",") if doc_id]
|
||
attachments_= ""
|
||
image_attachments = []
|
||
image_files = []
|
||
if "doc_ids" in messages[-1]:
|
||
attachments = [doc_id for doc_id in messages[-1]["doc_ids"] if doc_id]
|
||
if "files" in messages[-1]:
|
||
if llm_type == "chat":
|
||
text_attachments, image_attachments = split_file_attachments(messages[-1]["files"])
|
||
else:
|
||
text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True)
|
||
attachments_ = "\n\n".join(text_attachments)
|
||
|
||
prompt_config = dialog.prompt_config
|
||
include_reference_metadata, metadata_fields = _resolve_reference_metadata(prompt_config, request_payload=kwargs)
|
||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||
logging.debug(f"field_map retrieved: {field_map}")
|
||
# try to use sql if field mapping is good to go
|
||
if field_map:
|
||
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
|
||
ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
|
||
# For aggregate queries (COUNT, SUM, etc.), chunks may be empty but answer is still valid
|
||
if ans and (ans.get("reference", {}).get("chunks") or ans.get("answer")):
|
||
if include_reference_metadata and ans.get("reference", {}).get("chunks"):
|
||
if len(dialog.kb_ids) != 1 and any(not c.get("kb_id") for c in ans["reference"]["chunks"]):
|
||
logging.warning(
|
||
"Skipping some _enrich_chunks_with_document_metadata results because "
|
||
"dialog.kb_ids has %d entries and use_sql returned chunks without kb_id.",
|
||
len(dialog.kb_ids),
|
||
)
|
||
_enrich_chunks_with_document_metadata(ans["reference"]["chunks"], metadata_fields)
|
||
yield ans
|
||
return
|
||
else:
|
||
logging.debug("SQL failed or returned no results, falling back to vector search")
|
||
|
||
param_keys = [p["key"] for p in prompt_config.get("parameters", [])]
|
||
if dialog.kb_ids and "knowledge" not in param_keys and "{knowledge}" in prompt_config.get("system", ""):
|
||
logging.warning("prompt_config['parameters'] is missing 'knowledge' entry despite kb_ids being set; auto-fixing.")
|
||
prompt_config.setdefault("parameters", []).append({"key": "knowledge", "optional": False})
|
||
param_keys.append("knowledge")
|
||
logging.debug(f"attachments={attachments}, param_keys={param_keys}, embd_mdl={embd_mdl}")
|
||
|
||
for p in prompt_config.get("parameters", []):
|
||
if p["key"] == "knowledge":
|
||
continue
|
||
if p["key"] not in kwargs and not p["optional"]:
|
||
raise KeyError("Miss parameter: " + p["key"])
|
||
if p["key"] not in kwargs:
|
||
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
|
||
|
||
if len(questions) > 1 and prompt_config.get("refine_multiturn"):
|
||
questions = [await full_question(dialog.tenant_id, dialog.llm_id, messages)]
|
||
else:
|
||
questions = questions[-1:]
|
||
|
||
if prompt_config.get("cross_languages"):
|
||
questions = [await cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
|
||
|
||
if dialog.meta_data_filter:
|
||
metas = DocMetadataService.get_flatted_meta_by_kbs(dialog.kb_ids)
|
||
attachments = await apply_meta_data_filter(
|
||
dialog.meta_data_filter,
|
||
metas,
|
||
questions[-1],
|
||
chat_mdl,
|
||
attachments,
|
||
)
|
||
|
||
if prompt_config.get("keyword", False):
|
||
questions[-1] = questions[-1] + "," + await keyword_extraction(chat_mdl, questions[-1])
|
||
refine_question_ts = timer()
|
||
|
||
thought = ""
|
||
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
|
||
knowledges = []
|
||
|
||
if "knowledge" in param_keys:
|
||
logging.debug("Proceeding with retrieval")
|
||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||
knowledges = []
|
||
if prompt_config.get("reasoning", False) or kwargs.get("reasoning"):
|
||
reasoner = DeepResearcher(
|
||
chat_mdl,
|
||
prompt_config,
|
||
partial(
|
||
retriever.retrieval,
|
||
embd_mdl=embd_mdl,
|
||
tenant_ids=tenant_ids,
|
||
kb_ids=dialog.kb_ids,
|
||
page=1,
|
||
page_size=dialog.top_n,
|
||
similarity_threshold=0.2,
|
||
vector_similarity_weight=0.3,
|
||
doc_ids=attachments,
|
||
),
|
||
internet_enabled=use_web_search,
|
||
)
|
||
queue = asyncio.Queue()
|
||
async def callback(msg:str):
|
||
nonlocal queue
|
||
await queue.put(msg + "<br/>")
|
||
|
||
await callback("<START_DEEP_RESEARCH>")
|
||
task = asyncio.create_task(reasoner.research(kbinfos, questions[-1], questions[-1], callback=callback))
|
||
while True:
|
||
msg = await queue.get()
|
||
if msg.find("<START_DEEP_RESEARCH>") == 0:
|
||
yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, "start_to_think": True}
|
||
elif msg.find("<END_DEEP_RESEARCH>") == 0:
|
||
yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, "end_to_think": True}
|
||
break
|
||
else:
|
||
yield {"answer": msg, "reference": {}, "audio_binary": None, "final": False}
|
||
|
||
await task
|
||
|
||
else:
|
||
if embd_mdl:
|
||
kbinfos = await retriever.retrieval(
|
||
" ".join(questions),
|
||
embd_mdl,
|
||
tenant_ids,
|
||
dialog.kb_ids,
|
||
1,
|
||
dialog.top_n,
|
||
dialog.similarity_threshold,
|
||
dialog.vector_similarity_weight,
|
||
doc_ids=attachments,
|
||
top=dialog.top_k,
|
||
aggs=True,
|
||
rerank_mdl=rerank_mdl,
|
||
rank_feature=label_question(" ".join(questions), kbs),
|
||
)
|
||
if prompt_config.get("toc_enhance"):
|
||
cks = await retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
|
||
if cks:
|
||
kbinfos["chunks"] = cks
|
||
kbinfos["chunks"] = retriever.retrieval_by_children(kbinfos["chunks"], tenant_ids)
|
||
if use_web_search:
|
||
tav = Tavily(prompt_config["tavily_api_key"])
|
||
tav_res = tav.retrieve_chunks(" ".join(questions))
|
||
kbinfos["chunks"].extend(tav_res["chunks"])
|
||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||
if prompt_config.get("use_kg"):
|
||
default_chat_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
|
||
ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
|
||
LLMBundle(dialog.tenant_id, default_chat_model))
|
||
if ck["content_with_weight"]:
|
||
kbinfos["chunks"].insert(0, ck)
|
||
|
||
if include_reference_metadata:
|
||
logging.debug(
|
||
"reference_metadata enrichment enabled for async_chat: chunk_count=%d metadata_fields=%s",
|
||
len(kbinfos.get("chunks", [])),
|
||
metadata_fields,
|
||
)
|
||
_enrich_chunks_with_document_metadata(kbinfos.get("chunks", []), metadata_fields)
|
||
|
||
knowledges = kb_prompt(kbinfos, max_tokens)
|
||
logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||
|
||
retrieval_ts = timer()
|
||
if not knowledges and prompt_config.get("empty_response"):
|
||
empty_res = prompt_config["empty_response"]
|
||
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
|
||
"audio_binary": tts(tts_mdl, empty_res), "final": True}
|
||
return
|
||
|
||
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
|
||
gen_conf = dialog.llm_setting
|
||
|
||
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)+attachments_}]
|
||
prompt4citation = ""
|
||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||
prompt4citation = citation_prompt()
|
||
msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"])
|
||
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
|
||
if llm_type == "chat" and image_attachments:
|
||
convert_last_user_msg_to_multimodal(msg, image_attachments, factory)
|
||
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
|
||
prompt = msg[0]["content"]
|
||
|
||
if "max_tokens" in gen_conf:
|
||
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)
|
||
|
||
def decorate_answer(answer):
|
||
nonlocal embd_mdl, prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer
|
||
|
||
refs = []
|
||
ans = answer.split("</think>")
|
||
think = ""
|
||
if len(ans) == 2:
|
||
think = ans[0] + "</think>"
|
||
answer = ans[1]
|
||
|
||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||
idx = set([])
|
||
normalized_answer = normalize_arabic_digits(answer) or ""
|
||
if embd_mdl and not CITATION_MARKER_PATTERN.search(normalized_answer):
|
||
answer, idx = retriever.insert_citations(
|
||
answer,
|
||
[ck["content_ltks"] for ck in kbinfos["chunks"]],
|
||
[ck["vector"] for ck in kbinfos["chunks"]],
|
||
embd_mdl,
|
||
tkweight=1 - dialog.vector_similarity_weight,
|
||
vtweight=dialog.vector_similarity_weight,
|
||
)
|
||
else:
|
||
for match in CITATION_MARKER_PATTERN.finditer(normalized_answer):
|
||
i = int(match.group(1))
|
||
if i < len(kbinfos["chunks"]):
|
||
idx.add(i)
|
||
|
||
answer, idx = repair_bad_citation_formats(answer, kbinfos, idx)
|
||
|
||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||
if not recall_docs:
|
||
recall_docs = kbinfos["doc_aggs"]
|
||
kbinfos["doc_aggs"] = recall_docs
|
||
|
||
refs = deepcopy(kbinfos)
|
||
for c in refs["chunks"]:
|
||
if c.get("vector"):
|
||
del c["vector"]
|
||
|
||
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
||
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
|
||
finish_chat_ts = timer()
|
||
|
||
total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
|
||
check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
|
||
check_langfuse_tracer_cost = (check_langfuse_tracer_ts - check_llm_ts) * 1000
|
||
bind_embedding_time_cost = (bind_models_ts - check_langfuse_tracer_ts) * 1000
|
||
refine_question_time_cost = (refine_question_ts - bind_models_ts) * 1000
|
||
retrieval_time_cost = (retrieval_ts - refine_question_ts) * 1000
|
||
generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
|
||
|
||
tk_num = num_tokens_from_string(think + answer)
|
||
prompt += "\n\n### Query:\n%s" % " ".join(questions)
|
||
prompt = (
|
||
f"{prompt}\n\n"
|
||
"## Time elapsed:\n"
|
||
f" - Total: {total_time_cost:.1f}ms\n"
|
||
f" - Check LLM: {check_llm_time_cost:.1f}ms\n"
|
||
f" - Check Langfuse tracer: {check_langfuse_tracer_cost:.1f}ms\n"
|
||
f" - Bind models: {bind_embedding_time_cost:.1f}ms\n"
|
||
f" - Query refinement(LLM): {refine_question_time_cost:.1f}ms\n"
|
||
f" - Retrieval: {retrieval_time_cost:.1f}ms\n"
|
||
f" - Generate answer: {generate_result_time_cost:.1f}ms\n\n"
|
||
"## Token usage:\n"
|
||
f" - Generated tokens(approximately): {tk_num}\n"
|
||
f" - Token speed: {int(tk_num / (generate_result_time_cost / 1000.0))}/s"
|
||
)
|
||
|
||
# Add a condition check to call the end method only if langfuse_tracer exists
|
||
if langfuse_tracer and "langfuse_generation" in locals():
|
||
langfuse_output = "\n" + re.sub(r"^.*?(### Query:.*)", r"\1", prompt, flags=re.DOTALL)
|
||
langfuse_output = {"time_elapsed:": re.sub(r"\n", " \n", langfuse_output), "created_at": time.time()}
|
||
langfuse_generation.update(output=langfuse_output)
|
||
langfuse_generation.end()
|
||
|
||
return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
|
||
|
||
if langfuse_tracer:
|
||
langfuse_generation = langfuse_tracer.start_observation(as_type="generation",
|
||
trace_context=trace_context, name="chat", model=llm_model_config["llm_name"],
|
||
input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
|
||
)
|
||
|
||
if stream:
|
||
if llm_type == "chat":
|
||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf)
|
||
else:
|
||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf, images=image_files)
|
||
last_state = None
|
||
async for kind, value, state in _stream_with_think_delta(stream_iter):
|
||
last_state = state
|
||
if kind == "marker":
|
||
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
|
||
yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, **flags}
|
||
continue
|
||
yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "final": False}
|
||
full_answer = last_state.full_text if last_state else ""
|
||
if full_answer:
|
||
final = decorate_answer(_extract_visible_answer(thought + full_answer))
|
||
final["final"] = True
|
||
final["audio_binary"] = None
|
||
yield final
|
||
else:
|
||
if llm_type == "chat":
|
||
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
|
||
else:
|
||
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf, images=image_files)
|
||
user_content = msg[-1].get("content", "[content not available]")
|
||
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
||
res = decorate_answer(answer)
|
||
res["audio_binary"] = tts(tts_mdl, answer)
|
||
yield res
|
||
|
||
return
|
||
|
||
|
||
async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
    """Answer ``question`` by having the chat model write SQL and running it on the doc engine.

    The flow is: build an engine-specific prompt, ask ``chat_mdl`` for SQL, normalise it,
    execute it through ``settings.retriever.sql_retrieval`` (retrying once with the error
    text on failure), optionally ask for a rewrite when citation columns are missing, and
    finally render the result as a Markdown table plus a reference structure.

    Args:
        question: Natural-language question from the user.
        field_map: Mapping of field name -> display name used for prompts and table headers.
        tenant_id: Tenant whose index/table is queried (passed to ``index_name``).
        chat_mdl: Chat model exposing ``async_chat(system, messages, gen_conf)``.
        quota: Accepted for backward compatibility; the rendered rows are the same for
            either value.
        kb_ids: Optional list of knowledge-base ids used to pick the table (Infinity,
            single KB) or to add a ``kb_id`` filter (other engines).

    Returns:
        ``None`` when both SQL attempts fail or the query returns no rows; otherwise a dict
        with ``answer`` (Markdown table), ``reference`` (``chunks`` and ``doc_aggs``) and
        ``prompt`` (the system prompt that was used).
    """
    logging.debug(f"use_sql: Question: {question}")

    # Determine which document engine we're using
    if settings.DOC_ENGINE_INFINITY:
        doc_engine = "infinity"
    elif settings.DOC_ENGINE_OCEANBASE:
        doc_engine = "oceanbase"
    else:
        doc_engine = "es"

    # Construct the full table name
    # For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause)
    # For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table)
    base_table = index_name(tenant_id)
    if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1:
        # Infinity: append kb_id to table name
        table_name = f"{base_table}_{kb_ids[0]}"
        logging.debug(f"use_sql: Using Infinity table name: {table_name}")
    else:
        # Elasticsearch/OpenSearch: use base index name
        table_name = base_table
        logging.debug(f"use_sql: Using ES/OS table name: {table_name}")

    expected_doc_name_column = "docnm" if doc_engine == "infinity" else "docnm_kwd"

    def has_source_columns(columns):
        """Return True when the result has doc_id and a document-name column."""
        normalized_names = {str(col.get("name", "")).lower() for col in columns}
        return "doc_id" in normalized_names and bool({"docnm_kwd", "docnm"} & normalized_names)

    def is_aggregate_sql(sql_text):
        """Return True when the SQL text contains an aggregate/DISTINCT call."""
        return bool(re.search(r"(count|sum|avg|max|min|distinct)\s*\(", (sql_text or "").lower()))

    def normalize_sql(sql):
        """Strip think blocks, Markdown fences and a trailing semicolon from model output."""
        logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}")
        # Remove think blocks if present (format: </think>...)
        sql = re.sub(r"</think>\n.*?\n\s*", "", sql, flags=re.DOTALL)
        sql = re.sub(r"思考\n.*?\n", "", sql, flags=re.DOTALL)
        # Remove markdown code blocks (```sql ... ```)
        sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE)
        sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE)
        # Remove trailing semicolon that ES SQL parser doesn't like
        return sql.rstrip().rstrip(';').strip()

    def add_kb_filter(sql):
        """Restrict the SQL to ``kb_ids`` for engines that share one table across KBs."""
        # Infinity is skipped here: no kb_id filter is added for that engine.
        if doc_engine == "infinity" or not kb_ids:
            return sql

        # Build kb_filter: single KB or multiple KBs with OR
        if len(kb_ids) == 1:
            kb_filter = f"kb_id = '{kb_ids[0]}'"
        else:
            kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"

        if "where " not in sql.lower():
            # Split case-insensitively on the first ORDER BY only, so the rest of the
            # statement keeps its original casing and nothing after a later
            # "order by" is dropped.
            parts = re.split(r"order by", sql, maxsplit=1, flags=re.IGNORECASE)
            if len(parts) > 1:
                sql = parts[0] + f" WHERE {kb_filter} order by " + parts[1]
            else:
                sql += f" WHERE {kb_filter}"
        elif "kb_id =" not in sql.lower() and "kb_id=" not in sql.lower():
            sql = re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql, flags=re.IGNORECASE)
        return sql

    def is_row_count_question(q: str) -> bool:
        """Return True for questions asking how many rows a dataset/table has."""
        q = (q or "").lower()
        if not re.search(r"\bhow many rows\b|\bnumber of rows\b|\brow count\b", q):
            return False
        return bool(re.search(r"\bdataset\b|\btable\b|\bspreadsheet\b|\bexcel\b", q))

    # Generate engine-specific SQL prompts
    if doc_engine in ("infinity", "oceanbase"):
        # Infinity and OceanBase share the JSON-extraction prompts; the only difference
        # is the document-name column (docnm vs docnm_kwd), i.e. expected_doc_name_column.
        json_field_names = list(field_map.keys())
        row_count_override = (
            f"SELECT COUNT(*) AS rows FROM {table_name}"
            if is_row_count_question(question)
            else None
        )
        sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.

JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
Numeric Cast: CAST(json_extract_string(chunk_data, '$.FieldName') AS INTEGER/FLOAT)
NULL Check: json_extract_isnull(chunk_data, '$.FieldName') == false

RULES:
1. Use EXACT field names (case-sensitive) from the list below
2. For SELECT: include doc_id, {doc_name}, and json_extract_string() for requested fields
3. For COUNT: use COUNT(*) or COUNT(DISTINCT json_extract_string(...))
4. Add AS alias for extracted field names
5. DO NOT select 'content' field
6. Only add NULL check (json_extract_isnull() == false) in WHERE clause when:
- Question asks to "show me" or "display" specific columns
- Question mentions "not null" or "excluding null"
- Add NULL check for count specific column
- DO NOT add NULL check for COUNT(*) queries (COUNT(*) counts all rows including nulls)
7. Output ONLY the SQL, no explanations""".format(doc_name=expected_doc_name_column)
        user_prompt = """Table: {}
Fields (EXACT case): {}
{}
Question: {}
Write SQL using json_extract_string() with exact field names. Include doc_id, {doc_name} for data queries. Only SQL.""".format(
            table_name,
            ", ".join(json_field_names),
            "\n".join([f" - {field}" for field in json_field_names]),
            question,
            doc_name=expected_doc_name_column,
        )
    else:
        # Build ES/OS prompts with direct field access
        row_count_override = None
        sys_prompt = """You are a Database Administrator. Write SQL queries.

RULES:
1. Use EXACT field names from the schema below (e.g., product_tks, not product)
2. Quote field names starting with digit: "123_field"
3. Add IS NOT NULL in WHERE clause when:
- Question asks to "show me" or "display" specific columns
4. Include doc_id/docnm in non-aggregate statement
5. Output ONLY the SQL, no explanations"""
        user_prompt = """Table: {}
Available fields:
{}
Question: {}
Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
            table_name,
            "\n".join([f" - {k} ({v})" for k, v in field_map.items()]),
            question
        )

    tried_times = 0

    async def get_table(custom_user_prompt=None):
        """Produce SQL (override or model-generated), execute it, and return ``(table, sql)``."""
        nonlocal sys_prompt, user_prompt, question, tried_times, row_count_override
        if row_count_override and custom_user_prompt is None:
            # Row-count questions bypass the model entirely.
            sql = row_count_override
        else:
            prompt = custom_user_prompt if custom_user_prompt is not None else user_prompt
            sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": prompt}], {"temperature": 0.06})
            sql = normalize_sql(sql)
            sql = add_kb_filter(sql)

        logging.debug(f"{question} get SQL(refined): {sql}")
        tried_times += 1
        logging.debug(f"use_sql: Executing SQL retrieval (attempt {tried_times})")
        tbl = settings.retriever.sql_retrieval(sql, format="json")
        if tbl is None:
            logging.debug("use_sql: SQL retrieval returned None")
            return None, sql
        logging.debug(f"use_sql: SQL retrieval completed, got {len(tbl.get('rows', []))} rows")
        return tbl, sql

    async def repair_table_for_missing_source_columns(previous_sql):
        """Ask the model to rewrite ``previous_sql`` so it also selects the citation columns."""
        if doc_engine in ("infinity", "oceanbase"):
            json_field_names = list(field_map.keys())
            repair_prompt = """Table name: {};
JSON fields available in 'chunk_data' column (use exact names):
{}

Question: {}
Previous SQL:
{}

The previous SQL result is missing required source columns for citations.
Rewrite SQL to keep the same query intent and include doc_id and {} in the SELECT list.
For extracted JSON fields, use json_extract_string(chunk_data, '$.field_name').
Return ONLY SQL.""".format(
                table_name,
                "\n".join([f" - {field}" for field in json_field_names]),
                question,
                previous_sql,
                expected_doc_name_column
            )
        else:
            repair_prompt = """Table name: {}
Available fields:
{}

Question: {}
Previous SQL:
{}

The previous SQL result is missing required source columns for citations.
Rewrite SQL to keep the same query intent and include doc_id and docnm_kwd in the SELECT list.
Return ONLY SQL.""".format(
                table_name,
                "\n".join([f" - {k} ({v})" for k, v in field_map.items()]),
                question,
                previous_sql
            )
        return await get_table(custom_user_prompt=repair_prompt)

    try:
        tbl, sql = await get_table()
        logging.debug(f"use_sql: Initial SQL execution SUCCESS. SQL: {sql}")
        # NOTE: when get_table() returns a None table, tbl.get raises here and the
        # except branch below performs the retry; keep this access inside the try.
        logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows, columns: {[c['name'] for c in tbl.get('columns', [])]}")
    except Exception as e:
        logging.warning(f"use_sql: Initial SQL execution FAILED with error: {e}")
        # Build retry prompt with error information
        if doc_engine in ("infinity", "oceanbase"):
            # Build Infinity error retry prompt
            json_field_names = list(field_map.keys())
            user_prompt = """
Table name: {};
JSON fields available in 'chunk_data' column (use these exact names in json_extract_string):
{}

Question: {}
Please write the SQL using json_extract_string(chunk_data, '$.field_name') with the field names from the list above. Only SQL, no explanations.


The SQL error you provided last time is as follows:
{}

Please correct the error and write SQL again using json_extract_string(chunk_data, '$.field_name') syntax with the correct field names. Only SQL, no explanations.
""".format(table_name, "\n".join([f" - {field}" for field in json_field_names]), question, e)
        else:
            # Build ES/OS error retry prompt
            user_prompt = """
Table name: {};
Table of database fields are as follows (use the field names directly in SQL):
{}

Question are as follows:
{}
Please write the SQL using the exact field names above, only SQL, without any other explanations or text.


The SQL error you provided last time is as follows:
{}

Please correct the error and write SQL again using the exact field names above, only SQL, without any other explanations or text.
""".format(table_name, "\n".join([f"{k} ({v})" for k, v in field_map.items()]), question, e)
        try:
            tbl, sql = await get_table()
            logging.debug(f"use_sql: Retry SQL execution SUCCESS. SQL: {sql}")
            # As above: a None table raises here and is handled as a failed retry.
            logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows on retry")
        except Exception:
            logging.error("use_sql: Retry SQL execution also FAILED, returning None")
            return

    if len(tbl["rows"]) == 0:
        logging.warning(f"use_sql: No rows returned from SQL query, returning None. SQL: {sql}")
        return None

    if not is_aggregate_sql(sql) and not has_source_columns(tbl.get("columns", [])):
        logging.warning(f"use_sql: Non-aggregate SQL missing required source columns; retrying once. SQL: {sql}")
        try:
            repaired_tbl, repaired_sql = await repair_table_for_missing_source_columns(sql)
            if (
                repaired_tbl
                and len(repaired_tbl.get("rows", [])) > 0
                and has_source_columns(repaired_tbl.get("columns", []))
            ):
                tbl, sql = repaired_tbl, repaired_sql
                logging.info(f"use_sql: Source-column SQL repair succeeded. SQL: {sql}")
            else:
                logging.warning(f"use_sql: Source-column SQL repair did not provide required columns. Repaired SQL: {repaired_sql}")
        except Exception as e:
            # Best effort: keep the original table when the repair attempt fails.
            logging.warning(f"use_sql: Source-column SQL repair failed, returning best-effort answer. Error: {e}")

    logging.debug(f"use_sql: Proceeding with {len(tbl['rows'])} rows to build answer")

    docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"])
    doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]])
    kb_id_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["kb_id", "kb_id_kwd"]])

    logging.debug(f"use_sql: All columns: {[(i, c['name']) for i, c in enumerate(tbl['columns'])]}")
    logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}, kb_id_idx={kb_id_idx}")

    # Columns shown to the user: everything except the source/KB bookkeeping columns.
    column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx | kb_id_idx)]

    logging.debug(f"use_sql: column_idx={column_idx}")
    logging.debug(f"use_sql: field_map={field_map}")

    def clean_display_name(name):
        """Apply the suffix clean-up pattern to a display name."""
        # NOTE(review): as written with ASCII parentheses this pattern also removes every
        # run of non-parenthesis characters; confirm whether full-width parentheses were
        # intended before relying on the header text.
        return re.sub(r"(/.*|([^()]+))", "", name)

    # Helper function to map column names to display names
    def map_column_name(col_name):
        if col_name.lower() == "count(star)":
            return "COUNT(*)"

        # First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.)
        # Pattern: anything AS alias_name
        as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE)
        if as_match:
            alias = as_match.group(1).strip('"\'')

            # Use the alias for display name lookup
            if alias in field_map:
                return clean_display_name(field_map[alias])
            # If alias not in field_map, try to match case-insensitively
            for field_key, display_value in field_map.items():
                if field_key.lower() == alias.lower():
                    return clean_display_name(display_value)
            # Return alias as-is if no mapping found
            return alias

        # Try direct mapping first (for simple column names)
        if col_name in field_map:
            return clean_display_name(field_map[col_name])

        # Try case-insensitive match for simple column names
        col_lower = col_name.lower()
        for field_key, display_value in field_map.items():
            if field_key.lower() == col_lower:
                return clean_display_name(display_value)

        # For aggregate expressions or complex expressions without AS alias,
        # try to replace field names with display names
        result = col_name
        for field_name, display_name in field_map.items():
            result = result.replace(field_name, display_name)

        return clean_display_name(result)

    # A Source column is rendered only when both doc_id and a doc-name column exist.
    has_source = bool(docid_idx and doc_name_idx)

    # compose Markdown table
    columns = (
        "|" + "|".join(
            [map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + (
            "|Source|" if has_source else "|")
    )

    # The separator must have the same number of columns as the header, so it is gated
    # on the same condition as the Source header cell.
    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if has_source else "")

    # Build rows ensuring column names match values - create a dict for each row
    # keyed by column name to handle any SQL column order
    rows = []
    for row_idx, r in enumerate(tbl["rows"]):
        row_dict = {tbl["columns"][i]["name"]: r[i] for i in range(len(tbl["columns"])) if i < len(r)}
        if row_idx == 0:
            logging.debug(f"use_sql: First row data: {row_dict}")
        row_values = []
        for col_idx in column_idx:
            col_name = tbl["columns"][col_idx]["name"]
            value = row_dict.get(col_name, " ")
            row_values.append(remove_redundant_spaces(str(value)).replace("None", " "))
        # Add Source column with citation marker if Source column exists
        if has_source:
            row_values.append(f" ##{row_idx}$$")
        row_str = "|" + "|".join(row_values) + "|"
        # Skip rows that contain nothing but spaces and pipes.
        if re.sub(r"[ |]+", "", row_str):
            rows.append(row_str)
    # quota does not affect the rendering: both of its branches joined the rows identically.
    rows = "\n".join(rows)
    # Drop the time-of-day part of timestamps that end a cell.
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

    if not docid_idx or not doc_name_idx:
        logging.warning(f"use_sql: SQL missing required doc_id or docnm_kwd field. docid_idx={docid_idx}, doc_name_idx={doc_name_idx}. SQL: {sql}")
        # For aggregate queries (COUNT, SUM, AVG, MAX, MIN, DISTINCT), fetch doc_id, docnm_kwd separately
        # to provide source chunks, but keep the original table format answer
        if is_aggregate_sql(sql):
            # Keep original table format as answer
            answer = "\n".join([columns, line, rows])

            # Now fetch doc_id, docnm_kwd to provide source chunks
            # Extract WHERE clause from the original SQL
            where_match = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE)
            if where_match:
                where_clause = where_match.group(1).strip()
                # Build a query to get source fields with the same WHERE clause.
                # Single-KB queries can derive kb_id from the dialog, while multi-KB
                # ES/OS queries need the row value for metadata enrichment.
                chunks_kb_column = ", kb_id" if not (kb_ids and len(kb_ids) == 1) else ""
                chunks_sql = f"select doc_id, {expected_doc_name_column}{chunks_kb_column} from {table_name} where {where_clause}"
                # Add LIMIT to avoid fetching too many chunks
                if "limit" not in chunks_sql.lower():
                    chunks_sql += " limit 20"
                logging.debug(f"use_sql: Fetching chunks with SQL: {chunks_sql}")
                try:
                    chunks_tbl = settings.retriever.sql_retrieval(chunks_sql, format="json")
                    if chunks_tbl.get("rows") and len(chunks_tbl["rows"]) > 0:
                        # Build chunks reference - use case-insensitive matching
                        chunks_did_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() == "doc_id"), None)
                        chunks_dn_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]), None)
                        chunks_kb_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["kb_id", "kb_id_kwd"]), None)
                        if chunks_did_idx is not None and chunks_dn_idx is not None:
                            chunks = []
                            for r in chunks_tbl["rows"]:
                                chunk = {"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]}
                                row_dict = {chunks_tbl["columns"][i]["name"]: r[i] for i in range(len(chunks_tbl["columns"])) if i < len(r)}
                                kb_id = _chunk_kb_id_for_doc(row_dict, kb_ids, chunk["doc_id"])
                                if kb_id:
                                    chunk["kb_id"] = kb_id
                                elif chunks_kb_idx is not None:
                                    chunk["kb_id"] = r[chunks_kb_idx]
                                chunks.append(chunk)
                            # Build doc_aggs
                            doc_aggs = {}
                            for r in chunks_tbl["rows"]:
                                doc_id = r[chunks_did_idx]
                                doc_name = r[chunks_dn_idx]
                                if doc_id not in doc_aggs:
                                    doc_aggs[doc_id] = {"doc_name": doc_name, "count": 0}
                                doc_aggs[doc_id]["count"] += 1
                            doc_aggs_list = [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]
                            logging.debug(f"use_sql: Returning aggregate answer with {len(chunks)} chunks from {len(doc_aggs)} documents")
                            return {"answer": answer, "reference": {"chunks": chunks, "doc_aggs": doc_aggs_list}, "prompt": sys_prompt}
                except Exception as e:
                    # Best effort: the aggregate answer is still returned without sources.
                    logging.warning(f"use_sql: Failed to fetch chunks: {e}")
            # Fallback: return answer without chunks
            return {"answer": answer, "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
        # Fallback to table format for other cases
        return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}

    # Both index sets are non-empty here; use the first matching column of each.
    docid_idx = list(docid_idx)[0]
    doc_name_idx = list(doc_name_idx)[0]
    doc_aggs = {}
    for r in tbl["rows"]:
        if r[docid_idx] not in doc_aggs:
            doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
        doc_aggs[r[docid_idx]]["count"] += 1

    result = {
        "answer": "\n".join([columns, line, rows]),
        "reference": {
            "chunks": [
                {
                    # Falsy values (e.g. an unresolved kb_id) are left out of the chunk.
                    key: value
                    for key, value in {
                        "doc_id": r[docid_idx],
                        "docnm_kwd": r[doc_name_idx],
                        "kb_id": _chunk_kb_id_for_doc(
                            {tbl["columns"][i]["name"]: r[i] for i in range(len(tbl["columns"])) if i < len(r)},
                            kb_ids,
                            r[docid_idx],
                        ),
                    }.items()
                    if value
                }
                for r in tbl["rows"]
            ],
            "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()],
        },
        "prompt": sys_prompt,
    }
    logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents")
    return result
|
||
|
||
def clean_tts_text(text: str) -> str:
    """Sanitise ``text`` before it is handed to a text-to-speech model.

    Steps, in order: drop characters that cannot be encoded as UTF-8, remove ASCII
    control characters (tab, newline and carriage return are kept for the whitespace
    step), remove emoji in the listed code-point ranges, collapse whitespace runs to a
    single space, and cap the result at 500 characters.

    Returns an empty string for falsy input.
    """
    if not text:
        return ""

    max_chars = 500

    # Round-trip through UTF-8 so unencodable code points (e.g. lone surrogates) vanish.
    cleaned = text.encode("utf-8", "ignore").decode("utf-8", "ignore")

    # Control characters other than \t, \n and \r.
    cleaned = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", cleaned)

    emoji_ranges = (
        "\U0001F600-\U0001F64F",
        "\U0001F300-\U0001F5FF",
        "\U0001F680-\U0001F6FF",
        "\U0001F1E0-\U0001F1FF",
        "\U00002700-\U000027BF",
        "\U0001F900-\U0001F9FF",
        "\U0001FA70-\U0001FAFF",
        "\U0001FAD0-\U0001FAFF",
    )
    emoji_re = re.compile("[" + "".join(emoji_ranges) + "]+", flags=re.UNICODE)
    cleaned = emoji_re.sub("", cleaned)

    cleaned = re.sub(r"\s+", " ", cleaned).strip()

    # Slicing is a no-op for strings already within the limit.
    return cleaned[:max_chars]
|
||
|
||
def tts(tts_mdl, text):
    """Synthesise ``text`` with ``tts_mdl`` and return the audio as a hex string.

    Args:
        tts_mdl: Object whose ``tts(text)`` method yields chunks of audio bytes; a falsy
            value disables synthesis.
        text: Text to speak; it is passed through ``clean_tts_text`` first.

    Returns:
        The hex-encoded audio, or ``None`` when there is no model, no usable text after
        cleaning, or the model raises while producing audio.
    """
    if not tts_mdl or not text:
        return None
    text = clean_tts_text(text)
    if not text:
        return None
    try:
        # Join once instead of growing a bytes object chunk by chunk (and avoid
        # shadowing the builtin ``bin``).
        audio = b"".join(tts_mdl.tts(text))
    except Exception as e:
        # Best effort: speech is optional, so a TTS failure must not break the answer.
        logging.error(f"TTS failed: {e}, text={text!r}")
        return None
    return binascii.hexlify(audio).decode("utf-8")
|
||
|
||
|
||
class _ThinkStreamState:
|
||
def __init__(self) -> None:
|
||
self.full_text = ""
|
||
self.last_idx = 0
|
||
self.endswith_think = False
|
||
self.last_full = ""
|
||
self.last_model_full = ""
|
||
self.in_think = False
|
||
self.buffer = ""
|
||
|
||
|
||
def _extract_visible_answer(text: str) -> str:
|
||
text = text or ""
|
||
if "</think>" not in text:
|
||
return re.sub(r"</?think>", "", text)
|
||
|
||
thought, answer = text.rsplit("</think>", 1)
|
||
thought = re.sub(r"</?think>", "", thought).strip()
|
||
answer = re.sub(r"</?think>", "", answer)
|
||
if not thought:
|
||
return answer
|
||
return f"<think>{thought}</think>{answer}"
|
||
|
||
|
||
def _next_think_delta(state: _ThinkStreamState) -> str:
|
||
full_text = state.full_text
|
||
if full_text == state.last_full:
|
||
return ""
|
||
state.last_full = full_text
|
||
delta_ans = full_text[state.last_idx:]
|
||
|
||
if delta_ans.find("<think>") == 0:
|
||
state.last_idx += len("<think>")
|
||
return "<think>"
|
||
if delta_ans.find("<think>") > 0:
|
||
delta_text = full_text[state.last_idx:state.last_idx + delta_ans.find("<think>")]
|
||
state.last_idx += delta_ans.find("<think>")
|
||
return delta_text
|
||
if delta_ans.endswith("</think>"):
|
||
state.endswith_think = True
|
||
elif state.endswith_think:
|
||
state.endswith_think = False
|
||
return "</think>"
|
||
|
||
state.last_idx = len(full_text)
|
||
if full_text.endswith("</think>"):
|
||
state.last_idx -= len("</think>")
|
||
return re.sub(r"(<think>|</think>)", "", delta_ans)
|
||
|
||
|
||
async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
    """Turn a model text stream into ``(kind, value, state)`` events.

    ``kind`` is ``"marker"`` for ``"<think>"`` / ``"</think>"`` transitions and ``"text"``
    for content. Text is buffered until it reaches ``min_tokens`` tokens (as counted by
    ``num_tokens_from_string``) or a marker/end of stream forces a flush. The same
    ``_ThinkStreamState`` instance is yielded with every event.

    Args:
        stream_iter: Async iterable of string chunks; both cumulative chunks (each one
            extending the previous) and plain deltas are accepted.
        min_tokens: Minimum buffered token count before a text event is emitted.
    """
    state = _ThinkStreamState()
    async for piece in stream_iter:
        if not piece:
            continue

        seen = state.last_model_full
        if piece.startswith(seen):
            # Cumulative chunk: only the part beyond what was already seen is new.
            fresh = piece[len(seen):]
            state.last_model_full = piece
        else:
            # Plain delta chunk.
            fresh = piece
            state.last_model_full = seen + piece
        if not fresh:
            continue

        state.full_text += fresh
        delta = _next_think_delta(state)
        if not delta:
            continue

        if delta not in ("<think>", "</think>"):
            # Ordinary text: accumulate and emit once the buffer is large enough.
            state.buffer += delta
            if num_tokens_from_string(state.buffer) >= min_tokens:
                yield ("text", state.buffer, state)
                state.buffer = ""
            continue

        opening = delta == "<think>"
        if opening == state.in_think:
            # Redundant marker: an open while already thinking, or a close while not.
            continue
        if state.buffer:
            # Flush pending text so it stays on the correct side of the marker.
            yield ("text", state.buffer, state)
            state.buffer = ""
        state.in_think = opening
        yield ("marker", delta, state)

    if state.buffer:
        yield ("text", state.buffer, state)
        state.buffer = ""
    if state.endswith_think:
        # The stream ended right after a close tag that was never reported.
        yield ("marker", "</think>", state)
|
||
|
||
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config=None):
    """Answer ``question`` from knowledge bases and stream the reply.

    Retrieves chunks for the question, renders them into the ``ASK_SUMMARY``
    system prompt, streams the chat model's answer, and finally yields the
    complete answer with citations and references.

    Args:
        question: The user's question.
        kb_ids: Knowledge base ids to search; overridden by
            ``search_config["kb_ids"]`` when present.
        tenant_id: Tenant used to resolve the chat and rerank models.
        chat_llm_name: Chat model name; overridden by
            ``search_config["chat_id"]`` when present.
        search_config: Optional settings dict. Keys read here: ``doc_ids``,
            ``kb_ids``, ``chat_id``, ``rerank_id``, ``meta_data_filter``,
            ``similarity_threshold`` (default 0.1),
            ``vector_similarity_weight`` (default 0.3) and ``top_k``
            (default 1024). It is also passed to
            ``_resolve_reference_metadata``. ``None`` means no settings.

    Yields:
        dict: Intermediate items ``{"answer", "reference", "final": False}``,
        optionally carrying ``start_to_think`` / ``end_to_think`` flags, and
        one last item with the decorated answer, its references and
        ``"final": True``.

    Raises:
        IndexError: If no knowledge base is found for the requested ids.
    """
    # Use a fresh dict per call: the config is handed to helpers, so a shared
    # default instance could leak state from one call into the next.
    if search_config is None:
        search_config = {}

    doc_ids = search_config.get("doc_ids", [])
    rerank_mdl = None
    kb_ids = search_config.get("kb_ids", kb_ids)
    chat_llm_name = search_config.get("chat_id", chat_llm_name)
    rerank_id = search_config.get("rerank_id", "")
    meta_data_filter = search_config.get("meta_data_filter")
    include_reference_metadata, metadata_fields = _resolve_reference_metadata(search_config)

    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    if not kbs:
        # Without this guard the kbs[0] lookup below fails with a bare
        # "list index out of range". IndexError is kept so existing handlers
        # still match; only the message becomes meaningful.
        raise IndexError("No knowledge base found for the given kb_ids")
    embedding_list = list({kb.embd_id for kb in kbs})

    # The knowledge-graph retriever is used only when every KB is a KG.
    is_knowledge_graph = all(kb.parser_id == ParserType.KG for kb in kbs)
    retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever
    embd_owner_tenant_id = kbs[0].tenant_id
    embd_model_config = get_model_config_by_type_and_name(embd_owner_tenant_id, LLMType.EMBEDDING, embedding_list[0])
    embd_mdl = LLMBundle(embd_owner_tenant_id, embd_model_config)
    chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_llm_name)
    chat_mdl = LLMBundle(tenant_id, chat_model_config)
    if rerank_id:
        rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id)
        rerank_mdl = LLMBundle(tenant_id, rerank_model_config)
    max_tokens = chat_mdl.max_length
    tenant_ids = list({kb.tenant_id for kb in kbs})

    if meta_data_filter:
        metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
        doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)

    kbinfos = await retriever.retrieval(
        question=question,
        embd_mdl=embd_mdl,
        tenant_ids=tenant_ids,
        kb_ids=kb_ids,
        page=1,
        page_size=12,
        similarity_threshold=search_config.get("similarity_threshold", 0.1),
        vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
        top=search_config.get("top_k", 1024),
        doc_ids=doc_ids,
        aggs=True,
        rerank_mdl=rerank_mdl,
        rank_feature=label_question(question, kbs),
    )
    if include_reference_metadata:
        logging.debug(
            "reference_metadata enrichment enabled for async_ask: chunk_count=%d metadata_fields=%s",
            len(kbinfos.get("chunks", [])),
            metadata_fields,
        )
        _enrich_chunks_with_document_metadata(kbinfos.get("chunks", []), metadata_fields)

    knowledges = kb_prompt(kbinfos, max_tokens)
    sys_prompt = PROMPT_JINJA_ENV.from_string(ASK_SUMMARY).render(knowledge="\n".join(knowledges))

    msg = [{"role": "user", "content": question}]

    def decorate_answer(answer):
        """Insert citations into ``answer`` and build the reference payload."""
        chunks = kbinfos["chunks"]
        answer, idx = retriever.insert_citations(
            answer,
            [ck["content_ltks"] for ck in chunks],
            [ck["vector"] for ck in chunks],
            embd_mdl,
            tkweight=0.7,
            vtweight=0.3,
        )
        cited_doc_ids = {chunks[int(i)]["doc_id"] for i in idx}
        recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in cited_doc_ids]
        if not recall_docs:
            # Nothing was cited: fall back to every retrieved document.
            recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs

        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            # Drop the embedding vectors from the copied reference payload.
            if c.get("vector"):
                del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        refs["chunks"] = chunks_format(refs)
        return {"answer": answer, "reference": refs}

    stream_iter = chat_mdl.async_chat_streamly_delta(sys_prompt, msg, {"temperature": 0.1})
    last_state = None
    async for kind, value, state in _stream_with_think_delta(stream_iter):
        last_state = state
        if kind == "marker":
            flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
            yield {"answer": "", "reference": {}, "final": False, **flags}
            continue
        yield {"answer": value, "reference": {}, "final": False}

    # The last item is built from the full streamed text, passed through
    # _extract_visible_answer before citations are inserted.
    full_answer = last_state.full_text if last_state else ""
    final = decorate_answer(_extract_visible_answer(full_answer))
    final["final"] = True
    yield final
|
||
|
||
|
||
async def gen_mindmap(question, kb_ids, tenant_id, search_config=None):
    """Build a mind map from the chunks retrieved for ``question``.

    Args:
        question: The query used for retrieval.
        kb_ids: Knowledge base ids to search.
        tenant_id: Tenant used to resolve the chat and rerank models.
        search_config: Optional settings dict. Keys read here:
            ``meta_data_filter``, ``doc_ids``, ``rerank_id``, ``chat_id``,
            ``similarity_threshold`` (default 0.2),
            ``vector_similarity_weight`` (default 0.3) and ``top_k``
            (default 1024). ``None`` means no settings.

    Returns:
        ``{"error": "No KB selected"}`` when no knowledge base is found for
        ``kb_ids``; otherwise the ``output`` of the ``MindMapExtractor`` run
        over the retrieved chunks' ``content_with_weight``.
    """
    # Avoid a shared mutable default, and tolerate an explicit None.
    if search_config is None:
        search_config = {}

    meta_data_filter = search_config.get("meta_data_filter", {})
    doc_ids = search_config.get("doc_ids", [])
    rerank_id = search_config.get("rerank_id", "")
    rerank_mdl = None

    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    if not kbs:
        return {"error": "No KB selected"}

    tenant_embedding_list = list({kb.tenant_embd_id for kb in kbs})
    tenant_ids = list({kb.tenant_id for kb in kbs})

    # Both ways of resolving the embedding model use the first KB's tenant.
    embd_owner_tenant_id = kbs[0].tenant_id
    if tenant_embedding_list[0]:
        embd_model_config = get_model_config_by_id(tenant_embedding_list[0])
    else:
        embd_model_config = get_model_config_by_type_and_name(embd_owner_tenant_id, LLMType.EMBEDDING, kbs[0].embd_id)
    embd_mdl = LLMBundle(embd_owner_tenant_id, embd_model_config)

    # An explicit chat_id wins; otherwise use the tenant's default chat model.
    chat_id = search_config.get("chat_id", "")
    if chat_id:
        chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id)
    else:
        chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
    chat_mdl = LLMBundle(tenant_id, chat_model_config)

    if rerank_id:
        rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id)
        rerank_mdl = LLMBundle(tenant_id, rerank_model_config)

    if meta_data_filter:
        metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
        doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)

    ranks = await settings.retriever.retrieval(
        question=question,
        embd_mdl=embd_mdl,
        tenant_ids=tenant_ids,
        kb_ids=kb_ids,
        page=1,
        page_size=12,
        similarity_threshold=search_config.get("similarity_threshold", 0.2),
        vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
        top=search_config.get("top_k", 1024),
        doc_ids=doc_ids,
        aggs=False,
        rerank_mdl=rerank_mdl,
        rank_feature=label_question(question, kbs),
    )
    mindmap = MindMapExtractor(chat_mdl)
    mind_map = await mindmap([c["content_with_weight"] for c in ranks["chunks"]])
    return mind_map.output
|