From ff5971448b8ea2bead108732f55e20c3e9a42679 Mon Sep 17 00:00:00 2001 From: VictorECDSA Date: Wed, 3 Jun 2026 10:49:28 +0800 Subject: [PATCH] [Fix] naive: force-merge short markdown headers to prevent separate chunks (#15488) ## Problem When uploading `.md` files with `parser=naive` and `delimiter="\n"`, markdown headers (e.g., `## Quick Travel`) become separate chunks with very short content (16-18 characters). This causes retrieval issues: when the header is matched, the corresponding body text is not included in the chunk. ## Related Issues Closes #15487 ## Checklist - [x] Code changes are minimal and focused - [x] Unit tests added (12/12 passed) - [x] No breaking changes --- rag/app/naive.py | 26 +++++++- tests/test_naive_markdown_merge.py | 102 +++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 tests/test_naive_markdown_merge.py diff --git a/rag/app/naive.py b/rag/app/naive.py index 1dac71e107..18f790003a 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -56,6 +56,28 @@ from rag.nlp import ( ) # noqa: F401 +def _is_short_header(text, max_tokens=50): + """ + Check if text is a short markdown header. 
+ + Args: + text: The text to check + max_tokens: Maximum tokens for a header to be considered "short" + + Returns: + bool: True if text is a short markdown header, False otherwise + """ + if not text or not text.strip(): + return False + + # Check if it matches markdown header pattern: 1-6 # followed by space + if not re.match(r"^#{1,6}\s+", text.strip()): + return False + + # Check if token count is below threshold + return num_tokens_from_string(text) < max_tokens + + def _normalize_section_text_for_rtl_presentation_forms(sections): if not sections: return sections @@ -1067,6 +1089,7 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, lang= st = timer() overlapped_percent = normalize_overlapped_percent(parser_config.get("overlapped_percent", 0)) + if is_markdown: merged_chunks = [] merged_images = [] @@ -1081,7 +1104,8 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, lang= sec_tokens = num_tokens_from_string(text) sec_image = section_images[idx] if section_images and idx < len(section_images) else None - if current_text and current_tokens + sec_tokens > chunk_limit: + # Don't finalize chunk if current_text is a short header (force merge with next section) + if current_text and not _is_short_header(current_text) and current_tokens + sec_tokens > chunk_limit: merged_chunks.append(current_text) merged_images.append(current_image) overlap_part = "" diff --git a/tests/test_naive_markdown_merge.py b/tests/test_naive_markdown_merge.py new file mode 100644 index 0000000000..10aa3c3edb --- /dev/null +++ b/tests/test_naive_markdown_merge.py @@ -0,0 +1,102 @@ +""" +Unit tests for markdown chunk merging logic in rag/app/naive.py. + +Tests the _is_short_header() helper function to ensure short markdown headers +are correctly identified and will be force-merged with the next section. 
+"""
+
+import sys
+import os
+
+# Add project root to path for imports
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+
+from rag.app.naive import _is_short_header
+
+
+class TestIsShortHeader:
+    """Test cases for _is_short_header() function."""
+
+    def test_short_header_h1(self):
+        """Short level-1 header should return True."""
+        text = "# Quick Start"
+        result = _is_short_header(text)
+        assert result is True
+
+    def test_short_header_h2(self):
+        """Short level-2 header should return True."""
+        text = "## Quick Travel"
+        result = _is_short_header(text)
+        assert result is True
+
+    def test_short_header_h3(self):
+        """Short level-3 header should return True."""
+        text = "### Setup"
+        result = _is_short_header(text)
+        assert result is True
+
+    def test_long_header(self):
+        """Long header (> 50 tokens) should return False."""
+        text = "# " + "Very long header " * 20  # ~100 tokens
+        result = _is_short_header(text)
+        assert result is False
+
+    def test_non_header_short_text(self):
+        """Short text without header pattern should return False."""
+        text = "This is short"
+        result = _is_short_header(text)
+        assert result is False
+
+    def test_empty_text(self):
+        """Empty text should return False."""
+        text = ""
+        result = _is_short_header(text)
+        assert result is False
+
+    def test_whitespace_only(self):
+        """Whitespace-only text should return False."""
+        text = " "
+        result = _is_short_header(text)
+        assert result is False
+
+    def test_header_exactly_50_tokens(self):
+        """Header with exactly 50 tokens should return False (strict <)."""
+        # Construct a header with exactly 50 tokens
+        words = ["word"] * 49  # 49 words = 49 tokens, plus "# " = 1 token
+        text = "# " + " ".join(words)
+        result = _is_short_header(text, max_tokens=50)
+        # 50 tokens = not < 50, so should return False
+        assert result is False
+
+    def test_header_49_tokens(self):
+        """Header with 49 tokens should return True (< 50)."""
+        words = ["word"] * 48  # 48 words = 48 tokens, plus "# " = 1 token = 49 tokens
+        text = "# " + " ".join(words)
+        result = _is_short_header(text, max_tokens=50)
+        assert result is True
+
+    def test_custom_max_tokens(self):
+        """Should respect custom max_tokens parameter."""
+        text = "# Short"  # only a couple of tokens
+        result = _is_short_header(text, max_tokens=1)
+        assert result is False  # a real header has >= 1 token, so strict "< 1" rejects it
+
+        result = _is_short_header(text, max_tokens=10)
+        assert result is True
+
+    def test_header_with_special_chars(self):
+        """Header with special characters should still be recognized."""
+        text = "## API Endpoint: /api/v1/users"
+        result = _is_short_header(text)
+        assert result is True
+
+    def test_header_with_cjk_chars(self):
+        """Header with CJK characters should be recognized."""
+        text = "## 快速旅行"
+        result = _is_short_header(text)
+        assert result is True
+
+
+if __name__ == "__main__":
+    import pytest
+    pytest.main([__file__, "-v"])