From c42818735096d08e6dece7b272eb235de5ae2d62 Mon Sep 17 00:00:00 2001 From: Xing Hong <39619359+xingxing21@users.noreply.github.com> Date: Sat, 9 May 2026 11:52:06 +0900 Subject: [PATCH] Fix: validate kb_ids as UUIDs before SQL interpolation in use_sql (#14087) ### What problem does this PR solve? The use_sql() function in dialog_service.py constructed SQL WHERE clauses and Infinity table names by directly interpolating kb_id values using Python f-strings, with no validation of the input values. A malformed or maliciously crafted kb_id (introduced via a compromised admin account or a separate injection vector) could alter the structure of the generated SQL query, potentially leading to unauthorized data access or data manipulation. This PR adds strict UUID format validation for all kb_id values before they are interpolated into any SQL string, causing requests with invalid IDs to fail fast with a ValueError rather than executing a tampered query. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- api/db/services/dialog_service.py | 181 +++++++++--------- ...t_dialog_service_use_sql_source_columns.py | 1 + 2 files changed, 88 insertions(+), 94 deletions(-) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index c1d90ebe4c..6f981efb5e 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -18,7 +18,10 @@ import binascii import logging import re import time +import uuid from copy import deepcopy + +logger = logging.getLogger(__name__) from datetime import datetime from functools import partial from timeit import default_timer as timer @@ -45,8 +48,7 @@ from rag.graphrag.general.mind_map_extractor import MindMapExtractor from rag.advanced_rag import DeepResearcher from rag.app.tag import label_question from rag.nlp.search import index_name -from rag.prompts.generator import 
chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \ - PROMPT_JINJA_ENV, ASK_SUMMARY +from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, PROMPT_JINJA_ENV, ASK_SUMMARY from common.token_utils import num_tokens_from_string from rag.utils.tavily_conn import Tavily from common.string_utils import remove_redundant_spaces @@ -191,8 +193,7 @@ class DialogService(CommonService): cls.model.select(*fields) .join(User, on=(cls.model.tenant_id == User.id)) .where( - (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) - & (cls.model.status == StatusEnum.VALID.value), + (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value), ) ) if id: @@ -233,22 +234,14 @@ class DialogService(CommonService): @classmethod @DB.connection_context() def get_null_tenant_llm_id_row(cls): - fields = [ - cls.model.id, - cls.model.tenant_id, - cls.model.llm_id - ] + fields = [cls.model.id, cls.model.tenant_id, cls.model.llm_id] objs = cls.model.select(*fields).where(cls.model.tenant_llm_id.is_null()) return list(objs) @classmethod @DB.connection_context() def get_null_tenant_rerank_id_row(cls): - fields = [ - cls.model.id, - cls.model.tenant_id, - cls.model.rerank_id - ] + fields = [cls.model.id, cls.model.tenant_id, cls.model.rerank_id] objs = cls.model.select(*fields).where(cls.model.tenant_rerank_id.is_null()) return list(objs) @@ -264,7 +257,7 @@ async def async_chat_solo(dialog, messages, stream=True): else: text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True) attachments = "\n\n".join(text_attachments) - + if dialog.llm_id: model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) elif dialog.tenant_llm_id: @@ -483,11 +476,11 @@ def repair_bad_citation_formats(answer: str, 
kbinfos: dict, idx: set): parts = [] last_idx = 0 for match in matches: - parts.append(answer[last_idx:match.start()]) + parts.append(answer[last_idx : match.start()]) try: i = int(match.group(group_index)) except Exception: - parts.append(answer[match.start():match.end()]) + parts.append(answer[match.start() : match.end()]) last_idx = match.end() continue @@ -496,7 +489,7 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): digits_original = answer[digit_start:digit_end] parts.append(f"[{repl(digits_original)}]") else: - parts.append(answer[match.start():match.end()]) + parts.append(answer[match.start() : match.end()]) last_idx = match.end() parts.append(answer[last_idx:]) @@ -557,7 +550,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): attachments = None if "doc_ids" in kwargs: attachments = [doc_id for doc_id in kwargs["doc_ids"].split(",") if doc_id] - attachments_= "" + attachments_ = "" image_attachments = [] image_files = [] if "doc_ids" in messages[-1]: @@ -656,7 +649,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs): internet_enabled=use_web_search, ) queue = asyncio.Queue() - async def callback(msg:str): + + async def callback(msg: str): nonlocal queue await queue.put(msg + "
") @@ -703,8 +697,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) if prompt_config.get("use_kg"): default_chat_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT) - ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, - LLMBundle(dialog.tenant_id, default_chat_model)) + ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, default_chat_model)) if ck["content_with_weight"]: kbinfos["chunks"].insert(0, ck) @@ -722,14 +715,13 @@ async def async_chat(dialog, messages, stream=True, **kwargs): retrieval_ts = timer() if not knowledges and prompt_config.get("empty_response"): empty_res = prompt_config["empty_response"] - yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), - "audio_binary": tts(tts_mdl, empty_res), "final": True} + yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res), "final": True} return kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges) gen_conf = dialog.llm_setting - msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)+attachments_}] + msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs) + attachments_}] prompt4citation = "" if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): prompt4citation = citation_prompt() @@ -823,9 +815,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs): return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()} if langfuse_tracer: - langfuse_generation = langfuse_tracer.start_observation(as_type="generation", - trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], - input={"prompt": 
prompt, "prompt4citation": prompt4citation, "messages": msg} + langfuse_generation = langfuse_tracer.start_generation( + trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg} ) if stream: @@ -862,6 +853,25 @@ async def async_chat(dialog, messages, stream=True, **kwargs): async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None): + """Answer a natural-language question by generating and executing SQL against the document index. + + Detects the active document engine (Infinity, OceanBase, or Elasticsearch), asks the + chat model to produce the appropriate SQL, injects a validated kb_id filter, executes + the query, and returns formatted results with optional source citations. + + Args: + question: Natural-language question from the user. + field_map: Mapping of field names to types describing the indexed document schema. + tenant_id: Tenant identifier used to derive the target index/table name. + chat_mdl: LLM bundle used to generate SQL from the question. + quota: Whether to enforce token-quota checks (default True). + kb_ids: Optional list of knowledge-base UUIDs to restrict the query scope. + + Returns: + A dict with keys ``answer`` (formatted response string), ``reference`` + (dict of supporting document chunks and doc_aggs), and ``prompt`` + (the system prompt used), or ``None`` if SQL generation or execution fails. 
+ """ logging.debug(f"use_sql: Question: {question}") # Determine which document engine we're using @@ -872,12 +882,20 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N else: doc_engine = "es" + def _assert_valid_uuid(value: str, label: str = "id") -> None: + try: + uuid.UUID(str(value)) + except (ValueError, AttributeError, TypeError): + logger.warning("SQL injection guard rejected invalid %s value (length=%d)", label, len(str(value))) + raise ValueError(f"Invalid {label} format: {value!r}") + # Construct the full table name # For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause) # For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table) base_table = index_name(tenant_id) if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1: - # Infinity: append kb_id to table name + # Infinity: append kb_id to table name — validate before interpolating + _assert_valid_uuid(kb_ids[0], "kb_id") table_name = f"{base_table}_{kb_ids[0]}" logging.debug(f"use_sql: Using Infinity table name: {table_name}") else: @@ -888,13 +906,20 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N expected_doc_name_column = "docnm" if doc_engine == "infinity" else "docnm_kwd" def has_source_columns(columns): + """Return True if the result set contains the columns needed to build source citations.""" normalized_names = {str(col.get("name", "")).lower() for col in columns} return "doc_id" in normalized_names and bool({"docnm_kwd", "docnm"} & normalized_names) def is_aggregate_sql(sql_text): + """Return True if *sql_text* contains an aggregate function (COUNT, SUM, AVG, MAX, MIN, DISTINCT).""" return bool(re.search(r"(count|sum|avg|max|min|distinct)\s*\(", (sql_text or "").lower())) def normalize_sql(sql): + """Strip LLM artefacts from *sql* and return a clean, executable SQL string. 
+ + Removes think-tag reasoning blocks, Chinese reasoning markers, markdown + code fences, and trailing semicolons that some engines reject. + """ logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}") # Remove think blocks if present (format: ...) sql = re.sub(r"\n.*?\n\s*", "", sql, flags=re.DOTALL) @@ -903,18 +928,28 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE) sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE) # Remove trailing semicolon that ES SQL parser doesn't like - return sql.rstrip().rstrip(';').strip() + return sql.rstrip().rstrip(";").strip() def add_kb_filter(sql): + """Inject a validated kb_id WHERE filter into *sql* for ES/OceanBase engines. + + Infinity encodes the knowledge-base scope in the table name, so this + function is a no-op for that engine. Every kb_id must parse via + ``uuid.UUID`` (hyphen-less, braced and URN forms also pass) before interpolation, to prevent SQL injection. + """ # Add kb_id filter for ES/OS only (Infinity already has it in table name) if doc_engine == "infinity" or not kb_ids: return sql + # Validate all kb_ids are UUIDs before interpolating into SQL + for kid in kb_ids: + _assert_valid_uuid(kid, "kb_id") + # Build kb_filter: single KB or multiple KBs with OR if len(kb_ids) == 1: kb_filter = f"kb_id = '{kb_ids[0]}'" else: - kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")" + kb_filter = "(" + " OR ".join([f"kb_id = '{kid}'" for kid in kb_ids]) + ")" if "where " not in sql.lower(): o = sql.lower().split("order by") @@ -927,6 +962,7 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N return sql def is_row_count_question(q: str) -> bool: + """Return True if *q* is asking for a total row count of a dataset or table.""" q = (q or "").lower() if not re.search(r"\bhow many rows\b|\bnumber of rows\b|\brow count\b", q): return False @@ -936,11 +972,7 @@ async def use_sql(question, field_map, 
tenant_id, chat_mdl, quota=True, kb_ids=N if doc_engine == "infinity": # Build Infinity prompts with JSON extraction context json_field_names = list(field_map.keys()) - row_count_override = ( - f"SELECT COUNT(*) AS rows FROM {table_name}" - if is_row_count_question(question) - else None - ) + row_count_override = f"SELECT COUNT(*) AS rows FROM {table_name}" if is_row_count_question(question) else None sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column. JSON Extraction: json_extract_string(chunk_data, '$.FieldName') @@ -964,19 +996,12 @@ Fields (EXACT case): {} {} Question: {} Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.""".format( - table_name, - ", ".join(json_field_names), - "\n".join([f" - {field}" for field in json_field_names]), - question + table_name, ", ".join(json_field_names), "\n".join([f" - {field}" for field in json_field_names]), question ) elif doc_engine == "oceanbase": # Build OceanBase prompts with JSON extraction context json_field_names = list(field_map.keys()) - row_count_override = ( - f"SELECT COUNT(*) AS rows FROM {table_name}" - if is_row_count_question(question) - else None - ) + row_count_override = f"SELECT COUNT(*) AS rows FROM {table_name}" if is_row_count_question(question) else None sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column. JSON Extraction: json_extract_string(chunk_data, '$.FieldName') @@ -1000,10 +1025,7 @@ Fields (EXACT case): {} {} Question: {} Write SQL using json_extract_string() with exact field names. Include doc_id, docnm_kwd for data queries. 
Only SQL.""".format( - table_name, - ", ".join(json_field_names), - "\n".join([f" - {field}" for field in json_field_names]), - question + table_name, ", ".join(json_field_names), "\n".join([f" - {field}" for field in json_field_names]), question ) else: # Build ES/OS prompts with direct field access @@ -1021,11 +1043,7 @@ RULES: Available fields: {} Question: {} -Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format( - table_name, - "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), - question - ) +Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(table_name, "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), question) tried_times = 0 @@ -1063,13 +1081,7 @@ Previous SQL: The previous SQL result is missing required source columns for citations. Rewrite SQL to keep the same query intent and include doc_id and {} in the SELECT list. For extracted JSON fields, use json_extract_string(chunk_data, '$.field_name'). -Return ONLY SQL.""".format( - table_name, - "\n".join([f" - {field}" for field in json_field_names]), - question, - previous_sql, - expected_doc_name_column - ) +Return ONLY SQL.""".format(table_name, "\n".join([f" - {field}" for field in json_field_names]), question, previous_sql, expected_doc_name_column) else: repair_prompt = """Table name: {} Available fields: @@ -1081,12 +1093,7 @@ Previous SQL: The previous SQL result is missing required source columns for citations. Rewrite SQL to keep the same query intent and include doc_id and docnm_kwd in the SELECT list. 
-Return ONLY SQL.""".format( - table_name, - "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), - question, - previous_sql - ) +Return ONLY SQL.""".format(table_name, "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), question, previous_sql) return await get_table(custom_user_prompt=repair_prompt) try: @@ -1146,11 +1153,7 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat logging.warning(f"use_sql: Non-aggregate SQL missing required source columns; retrying once. SQL: {sql}") try: repaired_tbl, repaired_sql = await repair_table_for_missing_source_columns(sql) - if ( - repaired_tbl - and len(repaired_tbl.get("rows", [])) > 0 - and has_source_columns(repaired_tbl.get("columns", [])) - ): + if repaired_tbl and len(repaired_tbl.get("rows", [])) > 0 and has_source_columns(repaired_tbl.get("columns", [])): tbl, sql = repaired_tbl, repaired_sql logging.info(f"use_sql: Source-column SQL repair succeeded. SQL: {sql}") else: @@ -1179,9 +1182,9 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat # First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.) 
# Pattern: anything AS alias_name - as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE) + as_match = re.search(r"\s+AS\s+([^\s,)]+)", col_name, re.IGNORECASE) if as_match: - alias = as_match.group(1).strip('"\'') + alias = as_match.group(1).strip("\"'") # Use the alias for display name lookup if alias in field_map: @@ -1218,11 +1221,7 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat return result # compose Markdown table - columns = ( - "|" + "|".join( - [map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + ( - "|Source|" if docid_idx and doc_name_idx else "|") - ) + columns = "|" + "|".join([map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + ("|Source|" if docid_idx and doc_name_idx else "|") line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "") @@ -1342,6 +1341,7 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents") return result + def clean_tts_text(text: str) -> str: if not text: return "" @@ -1351,15 +1351,7 @@ def clean_tts_text(text: str) -> str: text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text) emoji_pattern = re.compile( - "[\U0001F600-\U0001F64F" - "\U0001F300-\U0001F5FF" - "\U0001F680-\U0001F6FF" - "\U0001F1E0-\U0001F1FF" - "\U00002700-\U000027BF" - "\U0001F900-\U0001F9FF" - "\U0001FA70-\U0001FAFF" - "\U0001FAD0-\U0001FAFF]+", - flags=re.UNICODE + "[\U0001f600-\U0001f64f\U0001f300-\U0001f5ff\U0001f680-\U0001f6ff\U0001f1e0-\U0001f1ff\U00002700-\U000027bf\U0001f900-\U0001f9ff\U0001fa70-\U0001faff\U0001fad0-\U0001faff]+", flags=re.UNICODE ) text = emoji_pattern.sub("", text) @@ -1371,6 +1363,7 @@ def clean_tts_text(text: str) -> str: return text + def tts(tts_mdl, text): if not tts_mdl or not text: return None @@ -1416,13 +1409,13 @@ def 
_next_think_delta(state: _ThinkStreamState) -> str: if full_text == state.last_full: return "" state.last_full = full_text - delta_ans = full_text[state.last_idx:] + delta_ans = full_text[state.last_idx :] if delta_ans.find("") == 0: state.last_idx += len("") return "" if delta_ans.find("") > 0: - delta_text = full_text[state.last_idx:state.last_idx + delta_ans.find("")] + delta_text = full_text[state.last_idx : state.last_idx + delta_ans.find("")] state.last_idx += delta_ans.find("") return delta_text if delta_ans.endswith(""): @@ -1443,7 +1436,7 @@ async def _stream_with_think_delta(stream_iter, min_tokens: int = 16): if not chunk: continue if chunk.startswith(state.last_model_full): - new_part = chunk[len(state.last_model_full):] + new_part = chunk[len(state.last_model_full) :] state.last_model_full = chunk else: new_part = chunk @@ -1477,6 +1470,7 @@ async def _stream_with_think_delta(stream_iter, min_tokens: int = 16): if state.endswith_think: yield ("marker", "", state) + async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): doc_ids = search_config.get("doc_ids", []) rerank_mdl = None @@ -1526,7 +1520,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf doc_ids=doc_ids, aggs=True, rerank_mdl=rerank_mdl, - rank_feature=label_question(question, kbs) + rank_feature=label_question(question, kbs), ) if include_reference_metadata: logging.debug( @@ -1543,8 +1537,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf def decorate_answer(answer): nonlocal knowledges, kbinfos, sys_prompt - answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], - embd_mdl, tkweight=0.7, vtweight=0.3) + answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3) idx = 
set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] if not recall_docs: diff --git a/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py index 71941e3874..5910781be4 100644 --- a/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py +++ b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py @@ -33,6 +33,7 @@ warnings.filterwarnings( def _install_cv2_stub_if_unavailable(): try: import cv2 # noqa: F401 + return except Exception: pass