diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index c1d90ebe4c..6f981efb5e 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -18,7 +18,10 @@ import binascii
import logging
import re
import time
+import uuid
from copy import deepcopy
+
+logger = logging.getLogger(__name__)
from datetime import datetime
from functools import partial
from timeit import default_timer as timer
@@ -45,8 +48,7 @@ from rag.graphrag.general.mind_map_extractor import MindMapExtractor
from rag.advanced_rag import DeepResearcher
from rag.app.tag import label_question
from rag.nlp.search import index_name
-from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
- PROMPT_JINJA_ENV, ASK_SUMMARY
+from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, PROMPT_JINJA_ENV, ASK_SUMMARY
from common.token_utils import num_tokens_from_string
from rag.utils.tavily_conn import Tavily
from common.string_utils import remove_redundant_spaces
@@ -191,8 +193,7 @@ class DialogService(CommonService):
cls.model.select(*fields)
.join(User, on=(cls.model.tenant_id == User.id))
.where(
- (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id))
- & (cls.model.status == StatusEnum.VALID.value),
+ (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value),
)
)
if id:
@@ -233,22 +234,14 @@ class DialogService(CommonService):
@classmethod
@DB.connection_context()
def get_null_tenant_llm_id_row(cls):
- fields = [
- cls.model.id,
- cls.model.tenant_id,
- cls.model.llm_id
- ]
+ fields = [cls.model.id, cls.model.tenant_id, cls.model.llm_id]
objs = cls.model.select(*fields).where(cls.model.tenant_llm_id.is_null())
return list(objs)
@classmethod
@DB.connection_context()
def get_null_tenant_rerank_id_row(cls):
- fields = [
- cls.model.id,
- cls.model.tenant_id,
- cls.model.rerank_id
- ]
+ fields = [cls.model.id, cls.model.tenant_id, cls.model.rerank_id]
objs = cls.model.select(*fields).where(cls.model.tenant_rerank_id.is_null())
return list(objs)
@@ -264,7 +257,7 @@ async def async_chat_solo(dialog, messages, stream=True):
else:
text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True)
attachments = "\n\n".join(text_attachments)
-
+
if dialog.llm_id:
model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
elif dialog.tenant_llm_id:
@@ -483,11 +476,11 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
parts = []
last_idx = 0
for match in matches:
- parts.append(answer[last_idx:match.start()])
+ parts.append(answer[last_idx : match.start()])
try:
i = int(match.group(group_index))
except Exception:
- parts.append(answer[match.start():match.end()])
+ parts.append(answer[match.start() : match.end()])
last_idx = match.end()
continue
@@ -496,7 +489,7 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
digits_original = answer[digit_start:digit_end]
parts.append(f"[{repl(digits_original)}]")
else:
- parts.append(answer[match.start():match.end()])
+ parts.append(answer[match.start() : match.end()])
last_idx = match.end()
parts.append(answer[last_idx:])
@@ -557,7 +550,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
attachments = None
if "doc_ids" in kwargs:
attachments = [doc_id for doc_id in kwargs["doc_ids"].split(",") if doc_id]
- attachments_= ""
+ attachments_ = ""
image_attachments = []
image_files = []
if "doc_ids" in messages[-1]:
@@ -656,7 +649,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
internet_enabled=use_web_search,
)
queue = asyncio.Queue()
- async def callback(msg:str):
+
+ async def callback(msg: str):
nonlocal queue
await queue.put(msg + "
")
@@ -703,8 +697,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
if prompt_config.get("use_kg"):
default_chat_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
- ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
- LLMBundle(dialog.tenant_id, default_chat_model))
+ ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, default_chat_model))
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
@@ -722,14 +715,13 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
retrieval_ts = timer()
if not knowledges and prompt_config.get("empty_response"):
empty_res = prompt_config["empty_response"]
- yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
- "audio_binary": tts(tts_mdl, empty_res), "final": True}
+ yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res), "final": True}
return
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
gen_conf = dialog.llm_setting
- msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)+attachments_}]
+ msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs) + attachments_}]
prompt4citation = ""
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
prompt4citation = citation_prompt()
@@ -823,9 +815,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
if langfuse_tracer:
- langfuse_generation = langfuse_tracer.start_observation(as_type="generation",
- trace_context=trace_context, name="chat", model=llm_model_config["llm_name"],
- input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
+ langfuse_generation = langfuse_tracer.start_generation(
+ trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
)
if stream:
@@ -862,6 +853,25 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
+ """Answer a natural-language question by generating and executing SQL against the document index.
+
+ Detects the active document engine (Infinity, OceanBase, or Elasticsearch), asks the
+ chat model to produce the appropriate SQL, injects a validated kb_id filter, executes
+ the query, and returns formatted results with optional source citations.
+
+ Args:
+ question: Natural-language question from the user.
+ field_map: Mapping of field names to types describing the indexed document schema.
+ tenant_id: Tenant identifier used to derive the target index/table name.
+ chat_mdl: LLM bundle used to generate SQL from the question.
+ quota: Whether to enforce token-quota checks (default True).
+ kb_ids: Optional list of knowledge-base UUIDs to restrict the query scope.
+
+ Returns:
+ A dict with keys ``answer`` (formatted response string), ``reference``
+ (dict of supporting document chunks and doc_aggs), and ``prompt``
+ (the system prompt used), or ``None`` if SQL generation or execution fails.
+ """
logging.debug(f"use_sql: Question: {question}")
# Determine which document engine we're using
@@ -872,12 +882,20 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N
else:
doc_engine = "es"
+ def _assert_valid_uuid(value: str, label: str = "id") -> None:
+ try:
+ uuid.UUID(str(value))
+ except (ValueError, AttributeError, TypeError):
+ logger.warning("SQL injection guard rejected invalid %s value (length=%d)", label, len(str(value)))
+ raise ValueError(f"Invalid {label} format: {value!r}")
+
# Construct the full table name
# For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause)
# For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table)
base_table = index_name(tenant_id)
if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1:
- # Infinity: append kb_id to table name
+ # Infinity: append kb_id to table name — validate before interpolating
+ _assert_valid_uuid(kb_ids[0], "kb_id")
table_name = f"{base_table}_{kb_ids[0]}"
logging.debug(f"use_sql: Using Infinity table name: {table_name}")
else:
@@ -888,13 +906,20 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N
expected_doc_name_column = "docnm" if doc_engine == "infinity" else "docnm_kwd"
def has_source_columns(columns):
+ """Return True if the result set contains the columns needed to build source citations."""
normalized_names = {str(col.get("name", "")).lower() for col in columns}
return "doc_id" in normalized_names and bool({"docnm_kwd", "docnm"} & normalized_names)
def is_aggregate_sql(sql_text):
+ """Return True if *sql_text* contains an aggregate function (COUNT, SUM, AVG, MAX, MIN, DISTINCT)."""
return bool(re.search(r"(count|sum|avg|max|min|distinct)\s*\(", (sql_text or "").lower()))
def normalize_sql(sql):
+ """Strip LLM artefacts from *sql* and return a clean, executable SQL string.
+
+ Removes ```` reasoning blocks, Chinese reasoning markers, markdown
+ code fences, and trailing semicolons that some engines reject.
+ """
logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}")
# Remove think blocks if present (format: ...)
sql = re.sub(r"\n.*?\n\s*", "", sql, flags=re.DOTALL)
@@ -903,18 +928,28 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N
sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE)
sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE)
# Remove trailing semicolon that ES SQL parser doesn't like
- return sql.rstrip().rstrip(';').strip()
+ return sql.rstrip().rstrip(";").strip()
def add_kb_filter(sql):
+ """Inject a validated kb_id WHERE filter into *sql* for ES/OceanBase engines.
+
+ Infinity encodes the knowledge-base scope in the table name, so this
+ function is a no-op for that engine. All kb_id values are validated as
+ canonical UUIDs before interpolation to prevent SQL injection.
+ """
# Add kb_id filter for ES/OS only (Infinity already has it in table name)
if doc_engine == "infinity" or not kb_ids:
return sql
+ # Validate all kb_ids are UUIDs before interpolating into SQL
+ for kid in kb_ids:
+ _assert_valid_uuid(kid, "kb_id")
+
# Build kb_filter: single KB or multiple KBs with OR
if len(kb_ids) == 1:
kb_filter = f"kb_id = '{kb_ids[0]}'"
else:
- kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
+ kb_filter = "(" + " OR ".join([f"kb_id = '{kid}'" for kid in kb_ids]) + ")"
if "where " not in sql.lower():
o = sql.lower().split("order by")
@@ -927,6 +962,7 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N
return sql
def is_row_count_question(q: str) -> bool:
+ """Return True if *q* is asking for a total row count of a dataset or table."""
q = (q or "").lower()
if not re.search(r"\bhow many rows\b|\bnumber of rows\b|\brow count\b", q):
return False
@@ -936,11 +972,7 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N
if doc_engine == "infinity":
# Build Infinity prompts with JSON extraction context
json_field_names = list(field_map.keys())
- row_count_override = (
- f"SELECT COUNT(*) AS rows FROM {table_name}"
- if is_row_count_question(question)
- else None
- )
+ row_count_override = f"SELECT COUNT(*) AS rows FROM {table_name}" if is_row_count_question(question) else None
sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.
JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
@@ -964,19 +996,12 @@ Fields (EXACT case): {}
{}
Question: {}
Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.""".format(
- table_name,
- ", ".join(json_field_names),
- "\n".join([f" - {field}" for field in json_field_names]),
- question
+ table_name, ", ".join(json_field_names), "\n".join([f" - {field}" for field in json_field_names]), question
)
elif doc_engine == "oceanbase":
# Build OceanBase prompts with JSON extraction context
json_field_names = list(field_map.keys())
- row_count_override = (
- f"SELECT COUNT(*) AS rows FROM {table_name}"
- if is_row_count_question(question)
- else None
- )
+ row_count_override = f"SELECT COUNT(*) AS rows FROM {table_name}" if is_row_count_question(question) else None
sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.
JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
@@ -1000,10 +1025,7 @@ Fields (EXACT case): {}
{}
Question: {}
Write SQL using json_extract_string() with exact field names. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
- table_name,
- ", ".join(json_field_names),
- "\n".join([f" - {field}" for field in json_field_names]),
- question
+ table_name, ", ".join(json_field_names), "\n".join([f" - {field}" for field in json_field_names]), question
)
else:
# Build ES/OS prompts with direct field access
@@ -1021,11 +1043,7 @@ RULES:
Available fields:
{}
Question: {}
-Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
- table_name,
- "\n".join([f" - {k} ({v})" for k, v in field_map.items()]),
- question
- )
+Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(table_name, "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), question)
tried_times = 0
@@ -1063,13 +1081,7 @@ Previous SQL:
The previous SQL result is missing required source columns for citations.
Rewrite SQL to keep the same query intent and include doc_id and {} in the SELECT list.
For extracted JSON fields, use json_extract_string(chunk_data, '$.field_name').
-Return ONLY SQL.""".format(
- table_name,
- "\n".join([f" - {field}" for field in json_field_names]),
- question,
- previous_sql,
- expected_doc_name_column
- )
+Return ONLY SQL.""".format(table_name, "\n".join([f" - {field}" for field in json_field_names]), question, previous_sql, expected_doc_name_column)
else:
repair_prompt = """Table name: {}
Available fields:
@@ -1081,12 +1093,7 @@ Previous SQL:
The previous SQL result is missing required source columns for citations.
Rewrite SQL to keep the same query intent and include doc_id and docnm_kwd in the SELECT list.
-Return ONLY SQL.""".format(
- table_name,
- "\n".join([f" - {k} ({v})" for k, v in field_map.items()]),
- question,
- previous_sql
- )
+Return ONLY SQL.""".format(table_name, "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), question, previous_sql)
return await get_table(custom_user_prompt=repair_prompt)
try:
@@ -1146,11 +1153,7 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat
logging.warning(f"use_sql: Non-aggregate SQL missing required source columns; retrying once. SQL: {sql}")
try:
repaired_tbl, repaired_sql = await repair_table_for_missing_source_columns(sql)
- if (
- repaired_tbl
- and len(repaired_tbl.get("rows", [])) > 0
- and has_source_columns(repaired_tbl.get("columns", []))
- ):
+ if repaired_tbl and len(repaired_tbl.get("rows", [])) > 0 and has_source_columns(repaired_tbl.get("columns", [])):
tbl, sql = repaired_tbl, repaired_sql
logging.info(f"use_sql: Source-column SQL repair succeeded. SQL: {sql}")
else:
@@ -1179,9 +1182,9 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat
# First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.)
# Pattern: anything AS alias_name
- as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE)
+ as_match = re.search(r"\s+AS\s+([^\s,)]+)", col_name, re.IGNORECASE)
if as_match:
- alias = as_match.group(1).strip('"\'')
+ alias = as_match.group(1).strip("\"'")
# Use the alias for display name lookup
if alias in field_map:
@@ -1218,11 +1221,7 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat
return result
# compose Markdown table
- columns = (
- "|" + "|".join(
- [map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + (
- "|Source|" if docid_idx and doc_name_idx else "|")
- )
+ columns = "|" + "|".join([map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + ("|Source|" if docid_idx and doc_name_idx else "|")
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
@@ -1342,6 +1341,7 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat
logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents")
return result
+
def clean_tts_text(text: str) -> str:
if not text:
return ""
@@ -1351,15 +1351,7 @@ def clean_tts_text(text: str) -> str:
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
emoji_pattern = re.compile(
- "[\U0001F600-\U0001F64F"
- "\U0001F300-\U0001F5FF"
- "\U0001F680-\U0001F6FF"
- "\U0001F1E0-\U0001F1FF"
- "\U00002700-\U000027BF"
- "\U0001F900-\U0001F9FF"
- "\U0001FA70-\U0001FAFF"
- "\U0001FAD0-\U0001FAFF]+",
- flags=re.UNICODE
+ "[\U0001f600-\U0001f64f\U0001f300-\U0001f5ff\U0001f680-\U0001f6ff\U0001f1e0-\U0001f1ff\U00002700-\U000027bf\U0001f900-\U0001f9ff\U0001fa70-\U0001faff\U0001fad0-\U0001faff]+", flags=re.UNICODE
)
text = emoji_pattern.sub("", text)
@@ -1371,6 +1363,7 @@ def clean_tts_text(text: str) -> str:
return text
+
def tts(tts_mdl, text):
if not tts_mdl or not text:
return None
@@ -1416,13 +1409,13 @@ def _next_think_delta(state: _ThinkStreamState) -> str:
if full_text == state.last_full:
return ""
state.last_full = full_text
- delta_ans = full_text[state.last_idx:]
+ delta_ans = full_text[state.last_idx :]
if delta_ans.find("") == 0:
state.last_idx += len("")
return ""
if delta_ans.find("") > 0:
- delta_text = full_text[state.last_idx:state.last_idx + delta_ans.find("")]
+ delta_text = full_text[state.last_idx : state.last_idx + delta_ans.find("")]
state.last_idx += delta_ans.find("")
return delta_text
if delta_ans.endswith(""):
@@ -1443,7 +1436,7 @@ async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
if not chunk:
continue
if chunk.startswith(state.last_model_full):
- new_part = chunk[len(state.last_model_full):]
+ new_part = chunk[len(state.last_model_full) :]
state.last_model_full = chunk
else:
new_part = chunk
@@ -1477,6 +1470,7 @@ async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
if state.endswith_think:
yield ("marker", "", state)
+
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
doc_ids = search_config.get("doc_ids", [])
rerank_mdl = None
@@ -1526,7 +1520,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
doc_ids=doc_ids,
aggs=True,
rerank_mdl=rerank_mdl,
- rank_feature=label_question(question, kbs)
+ rank_feature=label_question(question, kbs),
)
if include_reference_metadata:
logging.debug(
@@ -1543,8 +1537,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
def decorate_answer(answer):
nonlocal knowledges, kbinfos, sys_prompt
- answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]],
- embd_mdl, tkweight=0.7, vtweight=0.3)
+ answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
if not recall_docs:
diff --git a/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py
index 71941e3874..5910781be4 100644
--- a/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py
+++ b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py
@@ -33,6 +33,7 @@ warnings.filterwarnings(
def _install_cv2_stub_if_unavailable():
try:
import cv2 # noqa: F401
+
return
except Exception:
pass