diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index c1d90ebe4c..6f981efb5e 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -18,7 +18,10 @@ import binascii import logging import re import time +import uuid from copy import deepcopy + +logger = logging.getLogger(__name__) from datetime import datetime from functools import partial from timeit import default_timer as timer @@ -45,8 +48,7 @@ from rag.graphrag.general.mind_map_extractor import MindMapExtractor from rag.advanced_rag import DeepResearcher from rag.app.tag import label_question from rag.nlp.search import index_name -from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \ - PROMPT_JINJA_ENV, ASK_SUMMARY +from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, PROMPT_JINJA_ENV, ASK_SUMMARY from common.token_utils import num_tokens_from_string from rag.utils.tavily_conn import Tavily from common.string_utils import remove_redundant_spaces @@ -191,8 +193,7 @@ class DialogService(CommonService): cls.model.select(*fields) .join(User, on=(cls.model.tenant_id == User.id)) .where( - (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) - & (cls.model.status == StatusEnum.VALID.value), + (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value), ) ) if id: @@ -233,22 +234,14 @@ class DialogService(CommonService): @classmethod @DB.connection_context() def get_null_tenant_llm_id_row(cls): - fields = [ - cls.model.id, - cls.model.tenant_id, - cls.model.llm_id - ] + fields = [cls.model.id, cls.model.tenant_id, cls.model.llm_id] objs = cls.model.select(*fields).where(cls.model.tenant_llm_id.is_null()) return list(objs) @classmethod @DB.connection_context() def get_null_tenant_rerank_id_row(cls): - fields = [ - cls.model.id, - cls.model.tenant_id, - cls.model.rerank_id - ] + fields = [cls.model.id, cls.model.tenant_id, cls.model.rerank_id] objs = cls.model.select(*fields).where(cls.model.tenant_rerank_id.is_null()) return list(objs) @@ -264,7 +257,7 @@ async def async_chat_solo(dialog, messages, stream=True): else: text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True) attachments = "\n\n".join(text_attachments) - + if dialog.llm_id: model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) elif dialog.tenant_llm_id: @@ -483,11 +476,11 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): parts = [] last_idx = 0 for match in matches: - parts.append(answer[last_idx:match.start()]) + parts.append(answer[last_idx : match.start()]) try: i = int(match.group(group_index)) except Exception: - parts.append(answer[match.start():match.end()]) + parts.append(answer[match.start() : match.end()]) last_idx = match.end() continue @@ -496,7 +489,7 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): digits_original = answer[digit_start:digit_end] parts.append(f"[{repl(digits_original)}]") else: - parts.append(answer[match.start():match.end()]) + parts.append(answer[match.start() : match.end()]) last_idx = match.end() parts.append(answer[last_idx:]) @@ -557,7 +550,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): attachments = None if "doc_ids" in kwargs: attachments = [doc_id for doc_id in kwargs["doc_ids"].split(",") if doc_id] - attachments_= "" + attachments_ = "" image_attachments = [] image_files = [] if "doc_ids" in messages[-1]: @@ -656,7 +649,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs): internet_enabled=use_web_search, ) queue = asyncio.Queue() - async def callback(msg:str): + + async def callback(msg: str): nonlocal queue await queue.put(msg + "
") @@ -703,8 +697,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) if prompt_config.get("use_kg"): default_chat_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT) - ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, - LLMBundle(dialog.tenant_id, default_chat_model)) + ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, default_chat_model)) if ck["content_with_weight"]: kbinfos["chunks"].insert(0, ck) @@ -722,14 +715,13 @@ async def async_chat(dialog, messages, stream=True, **kwargs): retrieval_ts = timer() if not knowledges and prompt_config.get("empty_response"): empty_res = prompt_config["empty_response"] - yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), - "audio_binary": tts(tts_mdl, empty_res), "final": True} + yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res), "final": True} return kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges) gen_conf = dialog.llm_setting - msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)+attachments_}] + msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs) + attachments_}] prompt4citation = "" if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): prompt4citation = citation_prompt() @@ -823,9 +815,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs): return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()} if langfuse_tracer: - langfuse_generation = langfuse_tracer.start_observation(as_type="generation", - trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], - input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg} + langfuse_generation = langfuse_tracer.start_generation( + trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg} ) if stream: @@ -862,6 +853,25 @@ async def async_chat(dialog, messages, stream=True, **kwargs): async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None): + """Answer a natural-language question by generating and executing SQL against the document index. + + Detects the active document engine (Infinity, OceanBase, or Elasticsearch), asks the + chat model to produce the appropriate SQL, injects a validated kb_id filter, executes + the query, and returns formatted results with optional source citations. + + Args: + question: Natural-language question from the user. + field_map: Mapping of field names to types describing the indexed document schema. + tenant_id: Tenant identifier used to derive the target index/table name. + chat_mdl: LLM bundle used to generate SQL from the question. + quota: Whether to enforce token-quota checks (default True). + kb_ids: Optional list of knowledge-base UUIDs to restrict the query scope. + + Returns: + A dict with keys ``answer`` (formatted response string), ``reference`` + (dict of supporting document chunks and doc_aggs), and ``prompt`` + (the system prompt used), or ``None`` if SQL generation or execution fails. + """ logging.debug(f"use_sql: Question: {question}") # Determine which document engine we're using @@ -872,12 +882,20 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N else: doc_engine = "es" + def _assert_valid_uuid(value: str, label: str = "id") -> None: + try: + uuid.UUID(str(value)) + except (ValueError, AttributeError, TypeError): + logger.warning("SQL injection guard rejected invalid %s value (length=%d)", label, len(str(value))) + raise ValueError(f"Invalid {label} format: {value!r}") + # Construct the full table name # For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause) # For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table) base_table = index_name(tenant_id) if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1: - # Infinity: append kb_id to table name + # Infinity: append kb_id to table name — validate before interpolating + _assert_valid_uuid(kb_ids[0], "kb_id") table_name = f"{base_table}_{kb_ids[0]}" logging.debug(f"use_sql: Using Infinity table name: {table_name}") else: @@ -888,13 +906,20 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N expected_doc_name_column = "docnm" if doc_engine == "infinity" else "docnm_kwd" def has_source_columns(columns): + """Return True if the result set contains the columns needed to build source citations.""" normalized_names = {str(col.get("name", "")).lower() for col in columns} return "doc_id" in normalized_names and bool({"docnm_kwd", "docnm"} & normalized_names) def is_aggregate_sql(sql_text): + """Return True if *sql_text* contains an aggregate function (COUNT, SUM, AVG, MAX, MIN, DISTINCT).""" return bool(re.search(r"(count|sum|avg|max|min|distinct)\s*\(", (sql_text or "").lower())) def normalize_sql(sql): + """Strip LLM artefacts from *sql* and return a clean, executable SQL string. + + Removes ```` reasoning blocks, Chinese reasoning markers, markdown + code fences, and trailing semicolons that some engines reject. + """ logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}") # Remove think blocks if present (format: ...) sql = re.sub(r"\n.*?\n\s*", "", sql, flags=re.DOTALL) @@ -903,18 +928,28 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE) sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE) # Remove trailing semicolon that ES SQL parser doesn't like - return sql.rstrip().rstrip(';').strip() + return sql.rstrip().rstrip(";").strip() def add_kb_filter(sql): + """Inject a validated kb_id WHERE filter into *sql* for ES/OceanBase engines. + + Infinity encodes the knowledge-base scope in the table name, so this + function is a no-op for that engine. All kb_id values are validated as + canonical UUIDs before interpolation to prevent SQL injection. + """ # Add kb_id filter for ES/OS only (Infinity already has it in table name) if doc_engine == "infinity" or not kb_ids: return sql + # Validate all kb_ids are UUIDs before interpolating into SQL + for kid in kb_ids: + _assert_valid_uuid(kid, "kb_id") + # Build kb_filter: single KB or multiple KBs with OR if len(kb_ids) == 1: kb_filter = f"kb_id = '{kb_ids[0]}'" else: - kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")" + kb_filter = "(" + " OR ".join([f"kb_id = '{kid}'" for kid in kb_ids]) + ")" if "where " not in sql.lower(): o = sql.lower().split("order by") @@ -927,6 +962,7 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N return sql def is_row_count_question(q: str) -> bool: + """Return True if *q* is asking for a total row count of a dataset or table.""" q = (q or "").lower() if not re.search(r"\bhow many rows\b|\bnumber of rows\b|\brow count\b", q): return False @@ -936,11 +972,7 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N if doc_engine == "infinity": # Build Infinity prompts with JSON extraction context json_field_names = list(field_map.keys()) - row_count_override = ( - f"SELECT COUNT(*) AS rows FROM {table_name}" - if is_row_count_question(question) - else None - ) + row_count_override = f"SELECT COUNT(*) AS rows FROM {table_name}" if is_row_count_question(question) else None sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column. JSON Extraction: json_extract_string(chunk_data, '$.FieldName') @@ -964,19 +996,12 @@ Fields (EXACT case): {} {} Question: {} Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.""".format( - table_name, - ", ".join(json_field_names), - "\n".join([f" - {field}" for field in json_field_names]), - question + table_name, ", ".join(json_field_names), "\n".join([f" - {field}" for field in json_field_names]), question ) elif doc_engine == "oceanbase": # Build OceanBase prompts with JSON extraction context json_field_names = list(field_map.keys()) - row_count_override = ( - f"SELECT COUNT(*) AS rows FROM {table_name}" - if is_row_count_question(question) - else None - ) + row_count_override = f"SELECT COUNT(*) AS rows FROM {table_name}" if is_row_count_question(question) else None sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column. JSON Extraction: json_extract_string(chunk_data, '$.FieldName') @@ -1000,10 +1025,7 @@ Fields (EXACT case): {} {} Question: {} Write SQL using json_extract_string() with exact field names. Include doc_id, docnm_kwd for data queries. Only SQL.""".format( - table_name, - ", ".join(json_field_names), - "\n".join([f" - {field}" for field in json_field_names]), - question + table_name, ", ".join(json_field_names), "\n".join([f" - {field}" for field in json_field_names]), question ) else: # Build ES/OS prompts with direct field access @@ -1021,11 +1043,7 @@ RULES: Available fields: {} Question: {} -Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format( - table_name, - "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), - question - ) +Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(table_name, "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), question) tried_times = 0 @@ -1063,13 +1081,7 @@ Previous SQL: The previous SQL result is missing required source columns for citations. Rewrite SQL to keep the same query intent and include doc_id and {} in the SELECT list. For extracted JSON fields, use json_extract_string(chunk_data, '$.field_name'). -Return ONLY SQL.""".format( - table_name, - "\n".join([f" - {field}" for field in json_field_names]), - question, - previous_sql, - expected_doc_name_column - ) +Return ONLY SQL.""".format(table_name, "\n".join([f" - {field}" for field in json_field_names]), question, previous_sql, expected_doc_name_column) else: repair_prompt = """Table name: {} Available fields: @@ -1081,12 +1093,7 @@ Previous SQL: The previous SQL result is missing required source columns for citations. Rewrite SQL to keep the same query intent and include doc_id and docnm_kwd in the SELECT list. -Return ONLY SQL.""".format( - table_name, - "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), - question, - previous_sql - ) +Return ONLY SQL.""".format(table_name, "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), question, previous_sql) return await get_table(custom_user_prompt=repair_prompt) try: @@ -1146,11 +1153,7 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat logging.warning(f"use_sql: Non-aggregate SQL missing required source columns; retrying once. SQL: {sql}") try: repaired_tbl, repaired_sql = await repair_table_for_missing_source_columns(sql) - if ( - repaired_tbl - and len(repaired_tbl.get("rows", [])) > 0 - and has_source_columns(repaired_tbl.get("columns", [])) - ): + if repaired_tbl and len(repaired_tbl.get("rows", [])) > 0 and has_source_columns(repaired_tbl.get("columns", [])): tbl, sql = repaired_tbl, repaired_sql logging.info(f"use_sql: Source-column SQL repair succeeded. SQL: {sql}") else: @@ -1179,9 +1182,9 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat # First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.) # Pattern: anything AS alias_name - as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE) + as_match = re.search(r"\s+AS\s+([^\s,)]+)", col_name, re.IGNORECASE) if as_match: - alias = as_match.group(1).strip('"\'') + alias = as_match.group(1).strip("\"'") # Use the alias for display name lookup if alias in field_map: @@ -1218,11 +1221,7 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat return result # compose Markdown table - columns = ( - "|" + "|".join( - [map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + ( - "|Source|" if docid_idx and doc_name_idx else "|") - ) + columns = "|" + "|".join([map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + ("|Source|" if docid_idx and doc_name_idx else "|") line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "") @@ -1342,6 +1341,7 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents") return result + def clean_tts_text(text: str) -> str: if not text: return "" @@ -1351,15 +1351,7 @@ def clean_tts_text(text: str) -> str: text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text) emoji_pattern = re.compile( - "[\U0001F600-\U0001F64F" - "\U0001F300-\U0001F5FF" - "\U0001F680-\U0001F6FF" - "\U0001F1E0-\U0001F1FF" - "\U00002700-\U000027BF" - "\U0001F900-\U0001F9FF" - "\U0001FA70-\U0001FAFF" - "\U0001FAD0-\U0001FAFF]+", - flags=re.UNICODE + "[\U0001f600-\U0001f64f\U0001f300-\U0001f5ff\U0001f680-\U0001f6ff\U0001f1e0-\U0001f1ff\U00002700-\U000027bf\U0001f900-\U0001f9ff\U0001fa70-\U0001faff\U0001fad0-\U0001faff]+", flags=re.UNICODE ) text = emoji_pattern.sub("", text) @@ -1371,6 +1363,7 @@ def clean_tts_text(text: str) -> str: return text + def tts(tts_mdl, text): if not tts_mdl or not text: return None @@ -1416,13 +1409,13 @@ def _next_think_delta(state: _ThinkStreamState) -> str: if full_text == state.last_full: return "" state.last_full = full_text - delta_ans = full_text[state.last_idx:] + delta_ans = full_text[state.last_idx :] if delta_ans.find("") == 0: state.last_idx += len("") return "" if delta_ans.find("") > 0: - delta_text = full_text[state.last_idx:state.last_idx + delta_ans.find("")] + delta_text = full_text[state.last_idx : state.last_idx + delta_ans.find("")] state.last_idx += delta_ans.find("") return delta_text if delta_ans.endswith(""): @@ -1443,7 +1436,7 @@ async def _stream_with_think_delta(stream_iter, min_tokens: int = 16): if not chunk: continue if chunk.startswith(state.last_model_full): - new_part = chunk[len(state.last_model_full):] + new_part = chunk[len(state.last_model_full) :] state.last_model_full = chunk else: new_part = chunk @@ -1477,6 +1470,7 @@ async def _stream_with_think_delta(stream_iter, min_tokens: int = 16): if state.endswith_think: yield ("marker", "", state) + async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): doc_ids = search_config.get("doc_ids", []) rerank_mdl = None @@ -1526,7 +1520,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf doc_ids=doc_ids, aggs=True, rerank_mdl=rerank_mdl, - rank_feature=label_question(question, kbs) + rank_feature=label_question(question, kbs), ) if include_reference_metadata: logging.debug( @@ -1543,8 +1537,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf def decorate_answer(answer): nonlocal knowledges, kbinfos, sys_prompt - answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], - embd_mdl, tkweight=0.7, vtweight=0.3) + answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] if not recall_docs: diff --git a/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py index 71941e3874..5910781be4 100644 --- a/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py +++ b/test/unit_test/api/db/services/test_dialog_service_use_sql_source_columns.py @@ -33,6 +33,7 @@ warnings.filterwarnings( def _install_cv2_stub_if_unavailable(): try: import cv2 # noqa: F401 + return except Exception: pass