mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
[Fix] naive: force-merge short markdown headers to prevent separate chunks (#15488)
## Problem When uploading `.md` files with `parser=naive` and `delimiter="\n"`, markdown headers (e.g., `## Quick Travel`) become separate chunks with very short content (16-18 characters). This causes retrieval issues: when the header is matched, the corresponding body text is not included in the chunk. ## Related Issues Closes #15487 ## Checklist - [x] Code changes are minimal and focused - [x] Unit tests added (12/12 passed) - [x] No breaking changes
This commit is contained in:
@@ -56,6 +56,28 @@ from rag.nlp import (
|
||||
) # noqa: F401
|
||||
|
||||
|
||||
def _is_short_header(text, max_tokens=50):
|
||||
"""
|
||||
Check if text is a short markdown header.
|
||||
|
||||
Args:
|
||||
text: The text to check
|
||||
max_tokens: Maximum tokens for a header to be considered "short"
|
||||
|
||||
Returns:
|
||||
bool: True if text is a short markdown header, False otherwise
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return False
|
||||
|
||||
# Check if it matches markdown header pattern: 1-6 # followed by space
|
||||
if not re.match(r"^#{1,6}\s+", text.strip()):
|
||||
return False
|
||||
|
||||
# Check if token count is below threshold
|
||||
return num_tokens_from_string(text) < max_tokens
|
||||
|
||||
|
||||
def _normalize_section_text_for_rtl_presentation_forms(sections):
|
||||
if not sections:
|
||||
return sections
|
||||
@@ -1067,6 +1089,7 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, lang=
|
||||
|
||||
st = timer()
|
||||
overlapped_percent = normalize_overlapped_percent(parser_config.get("overlapped_percent", 0))
|
||||
|
||||
if is_markdown:
|
||||
merged_chunks = []
|
||||
merged_images = []
|
||||
@@ -1081,7 +1104,8 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, lang=
|
||||
sec_tokens = num_tokens_from_string(text)
|
||||
sec_image = section_images[idx] if section_images and idx < len(section_images) else None
|
||||
|
||||
if current_text and current_tokens + sec_tokens > chunk_limit:
|
||||
# Don't finalize chunk if current_text is a short header (force merge with next section)
|
||||
if current_text and not _is_short_header(current_text) and current_tokens + sec_tokens > chunk_limit:
|
||||
merged_chunks.append(current_text)
|
||||
merged_images.append(current_image)
|
||||
overlap_part = ""
|
||||
|
||||
102
tests/test_naive_markdown_merge.py
Normal file
102
tests/test_naive_markdown_merge.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
Unit tests for markdown chunk merging logic in rag/app/naive.py.
|
||||
|
||||
Tests the _is_short_header() helper function to ensure short markdown headers
|
||||
are correctly identified and will be force-merged with the next section.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add project root to path for imports
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from rag.app.naive import _is_short_header
|
||||
|
||||
|
||||
class TestIsShortHeader:
|
||||
"""Test cases for _is_short_header() function."""
|
||||
|
||||
def test_short_header_h1(self):
|
||||
"""Short level-1 header should return True."""
|
||||
text = "# Quick Start"
|
||||
result = _is_short_header(text)
|
||||
assert result is True
|
||||
|
||||
def test_short_header_h2(self):
|
||||
"""Short level-2 header should return True."""
|
||||
text = "## Quick Travel"
|
||||
result = _is_short_header(text)
|
||||
assert result is True
|
||||
|
||||
def test_short_header_h3(self):
|
||||
"""Short level-3 header should return True."""
|
||||
text = "### Setup"
|
||||
result = _is_short_header(text)
|
||||
assert result is True
|
||||
|
||||
def test_long_header(self):
|
||||
"""Long header (> 50 tokens) should return False."""
|
||||
text = "# " + "Very long header " * 20 # ~100 tokens
|
||||
result = _is_short_header(text)
|
||||
assert result is False
|
||||
|
||||
def test_non_header_short_text(self):
|
||||
"""Short text without header pattern should return False."""
|
||||
text = "This is short"
|
||||
result = _is_short_header(text)
|
||||
assert result is False
|
||||
|
||||
def test_empty_text(self):
|
||||
"""Empty text should return False."""
|
||||
text = ""
|
||||
result = _is_short_header(text)
|
||||
assert result is False
|
||||
|
||||
def test_whitespace_only(self):
|
||||
"""Whitespace-only text should return False."""
|
||||
text = " "
|
||||
result = _is_short_header(text)
|
||||
assert result is False
|
||||
|
||||
def test_header_exactly_50_tokens(self):
|
||||
"""Header with exactly 50 tokens should return False (strict <)."""
|
||||
# Construct a header with exactly 50 tokens
|
||||
words = ["word"] * 49 # 49 words = 49 tokens, plus "# " = 1 token
|
||||
text = "# " + " ".join(words)
|
||||
result = _is_short_header(text, max_tokens=50)
|
||||
# 50 tokens = not < 50, so should return False
|
||||
assert result is False
|
||||
|
||||
def test_header_49_tokens(self):
|
||||
"""Header with 49 tokens should return True (< 50)."""
|
||||
words = ["word"] * 48 # 48 words = 48 tokens, plus "# " = 1 token = 49 tokens
|
||||
text = "# " + " ".join(words)
|
||||
result = _is_short_header(text, max_tokens=50)
|
||||
assert result is True
|
||||
|
||||
def test_custom_max_tokens(self):
|
||||
"""Should respect custom max_tokens parameter."""
|
||||
text = "# Short"
|
||||
result = _is_short_header(text, max_tokens=5)
|
||||
assert result is False # "# Short" is ~2 tokens, but wait...
|
||||
|
||||
result = _is_short_header(text, max_tokens=10)
|
||||
assert result is True
|
||||
|
||||
def test_header_with_special_chars(self):
|
||||
"""Header with special characters should still be recognized."""
|
||||
text = "## API Endpoint: /api/v1/users"
|
||||
result = _is_short_header(text)
|
||||
assert result is True
|
||||
|
||||
def test_header_with_cjk_chars(self):
|
||||
"""Header with CJK characters should be recognized."""
|
||||
text = "## 快速旅行"
|
||||
result = _is_short_header(text)
|
||||
assert result is True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user