Fix: validate kb_ids as UUIDs before SQL interpolation in use_sql (#14087)

### What problem does this PR solve?

The use_sql() function in dialog_service.py constructed SQL WHERE
clauses and Infinity table names by directly interpolating kb_id values
using Python f-strings, with no validation of the input values. A
malformed or maliciously crafted kb_id (introduced via a compromised
admin account or a separate injection vector) could alter the structure
of the generated SQL query, potentially leading to unauthorized data
access or data manipulation.

This PR adds strict UUID format validation for all kb_id values before
they are interpolated into any SQL string, causing requests with invalid
IDs to fail fast with a ValueError rather than executing a tampered
query.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
Xing Hong
2026-05-09 11:52:06 +09:00
committed by GitHub
parent c44dc85143
commit c428187350
2 changed files with 88 additions and 94 deletions

View File

@@ -18,7 +18,10 @@ import binascii
import logging
import re
import time
import uuid
from copy import deepcopy
logger = logging.getLogger(__name__)
from datetime import datetime
from functools import partial
from timeit import default_timer as timer
@@ -45,8 +48,7 @@ from rag.graphrag.general.mind_map_extractor import MindMapExtractor
from rag.advanced_rag import DeepResearcher
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
PROMPT_JINJA_ENV, ASK_SUMMARY
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, PROMPT_JINJA_ENV, ASK_SUMMARY
from common.token_utils import num_tokens_from_string
from rag.utils.tavily_conn import Tavily
from common.string_utils import remove_redundant_spaces
@@ -191,8 +193,7 @@ class DialogService(CommonService):
cls.model.select(*fields)
.join(User, on=(cls.model.tenant_id == User.id))
.where(
(cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value),
(cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value),
)
)
if id:
@@ -233,22 +234,14 @@ class DialogService(CommonService):
@classmethod
@DB.connection_context()
def get_null_tenant_llm_id_row(cls):
fields = [
cls.model.id,
cls.model.tenant_id,
cls.model.llm_id
]
fields = [cls.model.id, cls.model.tenant_id, cls.model.llm_id]
objs = cls.model.select(*fields).where(cls.model.tenant_llm_id.is_null())
return list(objs)
@classmethod
@DB.connection_context()
def get_null_tenant_rerank_id_row(cls):
fields = [
cls.model.id,
cls.model.tenant_id,
cls.model.rerank_id
]
fields = [cls.model.id, cls.model.tenant_id, cls.model.rerank_id]
objs = cls.model.select(*fields).where(cls.model.tenant_rerank_id.is_null())
return list(objs)
@@ -264,7 +257,7 @@ async def async_chat_solo(dialog, messages, stream=True):
else:
text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True)
attachments = "\n\n".join(text_attachments)
if dialog.llm_id:
model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
elif dialog.tenant_llm_id:
@@ -483,11 +476,11 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
parts = []
last_idx = 0
for match in matches:
parts.append(answer[last_idx:match.start()])
parts.append(answer[last_idx : match.start()])
try:
i = int(match.group(group_index))
except Exception:
parts.append(answer[match.start():match.end()])
parts.append(answer[match.start() : match.end()])
last_idx = match.end()
continue
@@ -496,7 +489,7 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
digits_original = answer[digit_start:digit_end]
parts.append(f"[{repl(digits_original)}]")
else:
parts.append(answer[match.start():match.end()])
parts.append(answer[match.start() : match.end()])
last_idx = match.end()
parts.append(answer[last_idx:])
@@ -557,7 +550,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
attachments = None
if "doc_ids" in kwargs:
attachments = [doc_id for doc_id in kwargs["doc_ids"].split(",") if doc_id]
attachments_= ""
attachments_ = ""
image_attachments = []
image_files = []
if "doc_ids" in messages[-1]:
@@ -656,7 +649,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
internet_enabled=use_web_search,
)
queue = asyncio.Queue()
async def callback(msg:str):
async def callback(msg: str):
nonlocal queue
await queue.put(msg + "<br/>")
@@ -703,8 +697,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
if prompt_config.get("use_kg"):
default_chat_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
LLMBundle(dialog.tenant_id, default_chat_model))
ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, default_chat_model))
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
@@ -722,14 +715,13 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
retrieval_ts = timer()
if not knowledges and prompt_config.get("empty_response"):
empty_res = prompt_config["empty_response"]
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
"audio_binary": tts(tts_mdl, empty_res), "final": True}
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res), "final": True}
return
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
gen_conf = dialog.llm_setting
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)+attachments_}]
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs) + attachments_}]
prompt4citation = ""
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
prompt4citation = citation_prompt()
@@ -823,9 +815,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
if langfuse_tracer:
langfuse_generation = langfuse_tracer.start_observation(as_type="generation",
trace_context=trace_context, name="chat", model=llm_model_config["llm_name"],
input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
langfuse_generation = langfuse_tracer.start_generation(
trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
)
if stream:
@@ -862,6 +853,25 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
"""Answer a natural-language question by generating and executing SQL against the document index.
Detects the active document engine (Infinity, OceanBase, or Elasticsearch), asks the
chat model to produce the appropriate SQL, injects a validated kb_id filter, executes
the query, and returns formatted results with optional source citations.
Args:
question: Natural-language question from the user.
field_map: Mapping of field names to types describing the indexed document schema.
tenant_id: Tenant identifier used to derive the target index/table name.
chat_mdl: LLM bundle used to generate SQL from the question.
quota: Whether to enforce token-quota checks (default True).
kb_ids: Optional list of knowledge-base UUIDs to restrict the query scope.
Returns:
A dict with keys ``answer`` (formatted response string), ``reference``
(dict of supporting document chunks and doc_aggs), and ``prompt``
(the system prompt used), or ``None`` if SQL generation or execution fails.
"""
logging.debug(f"use_sql: Question: {question}")
# Determine which document engine we're using
@@ -872,12 +882,20 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N
else:
doc_engine = "es"
def _assert_valid_uuid(value: str, label: str = "id") -> None:
try:
uuid.UUID(str(value))
except (ValueError, AttributeError, TypeError):
logger.warning("SQL injection guard rejected invalid %s value (length=%d)", label, len(str(value)))
raise ValueError(f"Invalid {label} format: {value!r}")
# Construct the full table name
# For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause)
# For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table)
base_table = index_name(tenant_id)
if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1:
# Infinity: append kb_id to table name
# Infinity: append kb_id to table name — validate before interpolating
_assert_valid_uuid(kb_ids[0], "kb_id")
table_name = f"{base_table}_{kb_ids[0]}"
logging.debug(f"use_sql: Using Infinity table name: {table_name}")
else:
@@ -888,13 +906,20 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N
expected_doc_name_column = "docnm" if doc_engine == "infinity" else "docnm_kwd"
def has_source_columns(columns):
    """Tell whether a result set carries the columns required for source citations.

    A citation needs the ``doc_id`` column plus a document-name column, which
    is accepted under either spelling (``docnm_kwd`` or ``docnm``). Column
    names are compared case-insensitively.
    """
    seen = set()
    for column in columns:
        seen.add(str(column.get("name", "")).lower())

    if "doc_id" not in seen:
        return False
    return not seen.isdisjoint(("docnm_kwd", "docnm"))
def is_aggregate_sql(sql_text):
    """Return True if *sql_text* calls COUNT, SUM, AVG, MAX, MIN or DISTINCT with parentheses.

    The match is case-insensitive and allows whitespace between the keyword
    and the opening parenthesis. The keyword must start at a word boundary so
    that identifiers or JSON paths that merely end in one of these words
    (for example a field called ``Discount (USD)``) are not mistaken for an
    aggregate call. ``None`` or an empty string yields False.
    """
    return bool(re.search(r"\b(count|sum|avg|max|min|distinct)\s*\(", (sql_text or "").lower()))
def normalize_sql(sql):
"""Strip LLM artefacts from *sql* and return a clean, executable SQL string.
Removes ``<think>`` reasoning blocks, Chinese reasoning markers, markdown
code fences, and trailing semicolons that some engines reject.
"""
logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}")
# Remove think blocks if present (format: </think>...)
sql = re.sub(r"</think>\n.*?\n\s*", "", sql, flags=re.DOTALL)
@@ -903,18 +928,28 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N
sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE)
sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE)
# Remove trailing semicolon that ES SQL parser doesn't like
return sql.rstrip().rstrip(';').strip()
return sql.rstrip().rstrip(";").strip()
def add_kb_filter(sql):
"""Inject a validated kb_id WHERE filter into *sql* for ES/OceanBase engines.
Infinity encodes the knowledge-base scope in the table name, so this
function is a no-op for that engine. All kb_id values are validated as
canonical UUIDs before interpolation to prevent SQL injection.
"""
# Add kb_id filter for ES/OS only (Infinity already has it in table name)
if doc_engine == "infinity" or not kb_ids:
return sql
# Validate all kb_ids are UUIDs before interpolating into SQL
for kid in kb_ids:
_assert_valid_uuid(kid, "kb_id")
# Build kb_filter: single KB or multiple KBs with OR
if len(kb_ids) == 1:
kb_filter = f"kb_id = '{kb_ids[0]}'"
else:
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
kb_filter = "(" + " OR ".join([f"kb_id = '{kid}'" for kid in kb_ids]) + ")"
if "where " not in sql.lower():
o = sql.lower().split("order by")
@@ -927,6 +962,7 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N
return sql
def is_row_count_question(q: str) -> bool:
"""Return True if *q* is asking for a total row count of a dataset or table."""
q = (q or "").lower()
if not re.search(r"\bhow many rows\b|\bnumber of rows\b|\brow count\b", q):
return False
@@ -936,11 +972,7 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N
if doc_engine == "infinity":
# Build Infinity prompts with JSON extraction context
json_field_names = list(field_map.keys())
row_count_override = (
f"SELECT COUNT(*) AS rows FROM {table_name}"
if is_row_count_question(question)
else None
)
row_count_override = f"SELECT COUNT(*) AS rows FROM {table_name}" if is_row_count_question(question) else None
sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.
JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
@@ -964,19 +996,12 @@ Fields (EXACT case): {}
{}
Question: {}
Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.""".format(
table_name,
", ".join(json_field_names),
"\n".join([f" - {field}" for field in json_field_names]),
question
table_name, ", ".join(json_field_names), "\n".join([f" - {field}" for field in json_field_names]), question
)
elif doc_engine == "oceanbase":
# Build OceanBase prompts with JSON extraction context
json_field_names = list(field_map.keys())
row_count_override = (
f"SELECT COUNT(*) AS rows FROM {table_name}"
if is_row_count_question(question)
else None
)
row_count_override = f"SELECT COUNT(*) AS rows FROM {table_name}" if is_row_count_question(question) else None
sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.
JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
@@ -1000,10 +1025,7 @@ Fields (EXACT case): {}
{}
Question: {}
Write SQL using json_extract_string() with exact field names. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
table_name,
", ".join(json_field_names),
"\n".join([f" - {field}" for field in json_field_names]),
question
table_name, ", ".join(json_field_names), "\n".join([f" - {field}" for field in json_field_names]), question
)
else:
# Build ES/OS prompts with direct field access
@@ -1021,11 +1043,7 @@ RULES:
Available fields:
{}
Question: {}
Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
table_name,
"\n".join([f" - {k} ({v})" for k, v in field_map.items()]),
question
)
Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(table_name, "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), question)
tried_times = 0
@@ -1063,13 +1081,7 @@ Previous SQL:
The previous SQL result is missing required source columns for citations.
Rewrite SQL to keep the same query intent and include doc_id and {} in the SELECT list.
For extracted JSON fields, use json_extract_string(chunk_data, '$.field_name').
Return ONLY SQL.""".format(
table_name,
"\n".join([f" - {field}" for field in json_field_names]),
question,
previous_sql,
expected_doc_name_column
)
Return ONLY SQL.""".format(table_name, "\n".join([f" - {field}" for field in json_field_names]), question, previous_sql, expected_doc_name_column)
else:
repair_prompt = """Table name: {}
Available fields:
@@ -1081,12 +1093,7 @@ Previous SQL:
The previous SQL result is missing required source columns for citations.
Rewrite SQL to keep the same query intent and include doc_id and docnm_kwd in the SELECT list.
Return ONLY SQL.""".format(
table_name,
"\n".join([f" - {k} ({v})" for k, v in field_map.items()]),
question,
previous_sql
)
Return ONLY SQL.""".format(table_name, "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), question, previous_sql)
return await get_table(custom_user_prompt=repair_prompt)
try:
@@ -1146,11 +1153,7 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat
logging.warning(f"use_sql: Non-aggregate SQL missing required source columns; retrying once. SQL: {sql}")
try:
repaired_tbl, repaired_sql = await repair_table_for_missing_source_columns(sql)
if (
repaired_tbl
and len(repaired_tbl.get("rows", [])) > 0
and has_source_columns(repaired_tbl.get("columns", []))
):
if repaired_tbl and len(repaired_tbl.get("rows", [])) > 0 and has_source_columns(repaired_tbl.get("columns", [])):
tbl, sql = repaired_tbl, repaired_sql
logging.info(f"use_sql: Source-column SQL repair succeeded. SQL: {sql}")
else:
@@ -1179,9 +1182,9 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat
# First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.)
# Pattern: anything AS alias_name
as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE)
as_match = re.search(r"\s+AS\s+([^\s,)]+)", col_name, re.IGNORECASE)
if as_match:
alias = as_match.group(1).strip('"\'')
alias = as_match.group(1).strip("\"'")
# Use the alias for display name lookup
if alias in field_map:
@@ -1218,11 +1221,7 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat
return result
# compose Markdown table
columns = (
"|" + "|".join(
[map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + (
"|Source|" if docid_idx and doc_name_idx else "|")
)
columns = "|" + "|".join([map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + ("|Source|" if docid_idx and doc_name_idx else "|")
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
@@ -1342,6 +1341,7 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat
logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents")
return result
def clean_tts_text(text: str) -> str:
if not text:
return ""
@@ -1351,15 +1351,7 @@ def clean_tts_text(text: str) -> str:
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
emoji_pattern = re.compile(
"[\U0001F600-\U0001F64F"
"\U0001F300-\U0001F5FF"
"\U0001F680-\U0001F6FF"
"\U0001F1E0-\U0001F1FF"
"\U00002700-\U000027BF"
"\U0001F900-\U0001F9FF"
"\U0001FA70-\U0001FAFF"
"\U0001FAD0-\U0001FAFF]+",
flags=re.UNICODE
"[\U0001f600-\U0001f64f\U0001f300-\U0001f5ff\U0001f680-\U0001f6ff\U0001f1e0-\U0001f1ff\U00002700-\U000027bf\U0001f900-\U0001f9ff\U0001fa70-\U0001faff\U0001fad0-\U0001faff]+", flags=re.UNICODE
)
text = emoji_pattern.sub("", text)
@@ -1371,6 +1363,7 @@ def clean_tts_text(text: str) -> str:
return text
def tts(tts_mdl, text):
if not tts_mdl or not text:
return None
@@ -1416,13 +1409,13 @@ def _next_think_delta(state: _ThinkStreamState) -> str:
if full_text == state.last_full:
return ""
state.last_full = full_text
delta_ans = full_text[state.last_idx:]
delta_ans = full_text[state.last_idx :]
if delta_ans.find("<think>") == 0:
state.last_idx += len("<think>")
return "<think>"
if delta_ans.find("<think>") > 0:
delta_text = full_text[state.last_idx:state.last_idx + delta_ans.find("<think>")]
delta_text = full_text[state.last_idx : state.last_idx + delta_ans.find("<think>")]
state.last_idx += delta_ans.find("<think>")
return delta_text
if delta_ans.endswith("</think>"):
@@ -1443,7 +1436,7 @@ async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
if not chunk:
continue
if chunk.startswith(state.last_model_full):
new_part = chunk[len(state.last_model_full):]
new_part = chunk[len(state.last_model_full) :]
state.last_model_full = chunk
else:
new_part = chunk
@@ -1477,6 +1470,7 @@ async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
if state.endswith_think:
yield ("marker", "</think>", state)
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
doc_ids = search_config.get("doc_ids", [])
rerank_mdl = None
@@ -1526,7 +1520,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
doc_ids=doc_ids,
aggs=True,
rerank_mdl=rerank_mdl,
rank_feature=label_question(question, kbs)
rank_feature=label_question(question, kbs),
)
if include_reference_metadata:
logging.debug(
@@ -1543,8 +1537,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
def decorate_answer(answer):
nonlocal knowledges, kbinfos, sys_prompt
answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]],
embd_mdl, tkweight=0.7, vtweight=0.3)
answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
if not recall_docs:

View File

@@ -33,6 +33,7 @@ warnings.filterwarnings(
def _install_cv2_stub_if_unavailable():
try:
import cv2 # noqa: F401
return
except Exception:
pass