Files
ragflow/test/unit_test/rag/utils/test_raptor_utils.py
Jack f0cb7a544b Refactor: Task Executor (#15154)
### What problem does this PR solve?

1. Break the huge function into smaller pieces
2. Add unit tests for the smaller functions
3. Layered design
a. infra layer - task_context.py, recording_context.py,
write_operation_interceptor.py, ...
    b. service layer - *_service.py
    c. business layer - task_handler.py
4. Default behavior: use the refactored version; you can switch back to the
original version by changing an environment variable

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
- [x] Performance Improvement

---------

Co-authored-by: Liu An <asiro@qq.com>
Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
2026-05-27 21:54:17 +08:00

453 lines
17 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests for rag/utils/raptor_utils.py module.
"""
import pytest
from rag.utils.raptor_utils import (
RAPTOR_TREE_BUILDER,
PSI_TREE_BUILDER,
GMM_CLUSTERING_METHOD,
AHC_CLUSTERING_METHOD,
get_raptor_tree_builder,
get_raptor_clustering_method,
_as_extra_dict,
_has_raptor_marker,
_raptor_methods_from_fields,
collect_raptor_methods,
collect_raptor_chunk_ids,
make_raptor_summary_chunk_id,
is_structured_file_type,
is_tabular_pdf,
should_skip_raptor,
get_skip_reason,
)
class TestGetRaptorTreeBuilder:
    """Exercise get_raptor_tree_builder with missing, valid, and invalid configs."""

    def test_returns_default_raptor_tree_builder(self):
        """A None config resolves to RAPTOR_TREE_BUILDER."""
        assert get_raptor_tree_builder(None) == RAPTOR_TREE_BUILDER

    def test_returns_default_with_empty_config(self):
        """An empty config dict resolves to RAPTOR_TREE_BUILDER."""
        assert get_raptor_tree_builder({}) == RAPTOR_TREE_BUILDER

    def test_returns_configured_tree_builder(self):
        """A top-level 'tree_builder' entry is honoured."""
        raptor_config = {"tree_builder": PSI_TREE_BUILDER}
        builder = get_raptor_tree_builder(raptor_config)
        assert builder == PSI_TREE_BUILDER

    def test_returns_ext_tree_builder(self):
        """The value under 'ext' wins over the top-level 'tree_builder'."""
        ext_section = {"tree_builder": PSI_TREE_BUILDER}
        raptor_config = {"tree_builder": "old", "ext": ext_section}
        builder = get_raptor_tree_builder(raptor_config)
        assert builder == PSI_TREE_BUILDER

    def test_raises_error_for_unsupported_tree_builder(self):
        """An unrecognised tree builder name is rejected with ValueError."""
        raptor_config = {"tree_builder": "unknown"}
        expected_message = "Unsupported RAPTOR tree builder"
        with pytest.raises(ValueError, match=expected_message):
            get_raptor_tree_builder(raptor_config)
class TestGetRaptorClusteringMethod:
    """Exercise get_raptor_clustering_method with missing, valid, and invalid configs."""

    def test_returns_default_gmm(self):
        """A None config resolves to GMM_CLUSTERING_METHOD."""
        assert get_raptor_clustering_method(None) == GMM_CLUSTERING_METHOD

    def test_returns_configured_clustering_method(self):
        """A top-level 'clustering_method' entry is honoured."""
        raptor_config = {"clustering_method": AHC_CLUSTERING_METHOD}
        method = get_raptor_clustering_method(raptor_config)
        assert method == AHC_CLUSTERING_METHOD

    def test_returns_ext_clustering_method(self):
        """The value under 'ext' wins over the top-level 'clustering_method'."""
        ext_section = {"clustering_method": AHC_CLUSTERING_METHOD}
        raptor_config = {"clustering_method": "old", "ext": ext_section}
        method = get_raptor_clustering_method(raptor_config)
        assert method == AHC_CLUSTERING_METHOD

    def test_raises_error_for_unsupported_clustering_method(self):
        """An unrecognised clustering method name is rejected with ValueError."""
        raptor_config = {"clustering_method": "unknown"}
        expected_message = "Unsupported RAPTOR clustering method"
        with pytest.raises(ValueError, match=expected_message):
            get_raptor_clustering_method(raptor_config)
class TestAsExtraDict:
    """Check how _as_extra_dict handles dicts, None, and string payloads."""

    def test_returns_dict_as_is(self):
        """A dict argument comes back equal to the input."""
        payload = {"key": "value"}
        assert _as_extra_dict(payload) == payload

    def test_returns_empty_dict_for_none(self):
        """None is mapped to an empty dict."""
        assert _as_extra_dict(None) == {}

    def test_returns_empty_dict_for_empty_string(self):
        """The empty string is mapped to an empty dict."""
        assert _as_extra_dict("") == {}

    def test_parses_valid_json_string(self):
        """A JSON object string is decoded into the matching dict."""
        raw = '{"key": "value"}'
        parsed = _as_extra_dict(raw)
        assert parsed == {"key": "value"}

    def test_returns_empty_dict_for_non_dict_json(self):
        """JSON that decodes to something other than a dict gives an empty dict."""
        raw = '[1, 2, 3]'
        parsed = _as_extra_dict(raw)
        assert parsed == {}

    def test_parses_python_dict_literal(self):
        """A single-quoted Python dict literal is also accepted."""
        raw = "{'key': 'value'}"
        parsed = _as_extra_dict(raw)
        assert parsed == {"key": "value"}

    def test_returns_empty_dict_for_malformed_string(self):
        """A string that cannot be parsed gives an empty dict."""
        raw = "{invalid json}"
        parsed = _as_extra_dict(raw)
        assert parsed == {}
class TestHasRaptorMarker:
    """Check _has_raptor_marker against string and list inputs."""

    def test_returns_true_for_raptor_string(self):
        """The bare string 'raptor' counts as a marker."""
        found = _has_raptor_marker("raptor")
        assert found is True

    def test_returns_true_for_raptor_in_list(self):
        """A list that contains 'raptor' counts as a marker."""
        found = _has_raptor_marker(["raptor", "other"])
        assert found is True

    def test_returns_false_for_other_string(self):
        """A different string is not a marker."""
        found = _has_raptor_marker("other")
        assert found is False

    def test_returns_false_for_empty_list(self):
        """An empty list is not a marker."""
        found = _has_raptor_marker([])
        assert found is False

    def test_returns_false_for_list_without_raptor(self):
        """A list lacking 'raptor' is not a marker."""
        found = _has_raptor_marker(["psi", "other"])
        assert found is False
class TestRaptorMethodsFromFields:
    """Check which method set _raptor_methods_from_fields derives from chunk fields."""

    def test_returns_default_raptor_method(self):
        """Fields without any 'extra' entry give {RAPTOR_TREE_BUILDER}."""
        assert _raptor_methods_from_fields({}) == {RAPTOR_TREE_BUILDER}

    def test_returns_method_from_extra_dict(self):
        """An 'extra' dict carrying 'raptor_method' gives that method."""
        chunk_fields = {"extra": {"raptor_method": PSI_TREE_BUILDER}}
        methods = _raptor_methods_from_fields(chunk_fields)
        assert methods == {PSI_TREE_BUILDER}

    def test_returns_method_from_extra_field(self):
        """An 'extra' supplied as a dict-literal string still gives {PSI_TREE_BUILDER}."""
        chunk_fields = {"extra": "{'raptor_method': 'psi'}"}
        methods = _raptor_methods_from_fields(chunk_fields)
        assert methods == {PSI_TREE_BUILDER}

    def test_handles_list_method(self):
        """A list-valued 'raptor_method' becomes a set of its members."""
        chunk_fields = {"extra": {"raptor_method": ["raptor", "psi"]}}
        methods = _raptor_methods_from_fields(chunk_fields)
        assert methods == {RAPTOR_TREE_BUILDER, PSI_TREE_BUILDER}

    def test_handles_empty_method(self):
        """An empty-string 'raptor_method' falls back to {RAPTOR_TREE_BUILDER}."""
        chunk_fields = {"extra": {"raptor_method": ""}}
        methods = _raptor_methods_from_fields(chunk_fields)
        assert methods == {RAPTOR_TREE_BUILDER}
class TestCollectRaptorMethods:
    """Check how collect_raptor_methods aggregates methods over a chunk field map."""

    def test_returns_empty_set_for_empty_map(self):
        """No chunks means no methods."""
        assert collect_raptor_methods({}) == set()

    def test_collects_methods_from_raptor_chunks(self):
        """A chunk tagged 'raptor' contributes its 'raptor_method'."""
        chunk_fields = {
            "raptor_kwd": "raptor",
            "extra": {"raptor_method": PSI_TREE_BUILDER},
        }
        methods = collect_raptor_methods({"chunk_1": chunk_fields})
        assert methods == {PSI_TREE_BUILDER}

    def test_skips_non_raptor_chunks(self):
        """A chunk whose 'raptor_kwd' is not 'raptor' contributes nothing."""
        chunk_fields = {
            "raptor_kwd": "other",
            "extra": {"raptor_method": PSI_TREE_BUILDER},
        }
        methods = collect_raptor_methods({"chunk_1": chunk_fields})
        assert methods == set()

    def test_collects_multiple_methods(self):
        """Methods from several tagged chunks are merged into one set."""
        chunks = {
            "chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": "raptor"}},
            "chunk_2": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}},
        }
        methods = collect_raptor_methods(chunks)
        assert methods == {RAPTOR_TREE_BUILDER, PSI_TREE_BUILDER}
class TestCollectRaptorChunkIds:
    """Check which chunk IDs collect_raptor_chunk_ids selects from a field map."""

    def test_returns_empty_set_for_empty_map(self):
        """No chunks means no IDs."""
        assert collect_raptor_chunk_ids({}) == set()

    def test_collects_ids_of_raptor_chunks(self):
        """Every chunk tagged 'raptor' has its ID returned."""
        chunks = {
            "chunk_1": {"raptor_kwd": "raptor"},
            "chunk_2": {"raptor_kwd": "raptor"},
        }
        chunk_ids = collect_raptor_chunk_ids(chunks)
        assert chunk_ids == {"chunk_1", "chunk_2"}

    def test_excludes_specified_methods(self):
        """Chunks whose method is listed in exclude_methods are left out."""
        chunks = {
            "chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": "raptor"}},
            "chunk_2": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}},
        }
        chunk_ids = collect_raptor_chunk_ids(chunks, exclude_methods={"raptor"})
        assert chunk_ids == {"chunk_2"}

    def test_skips_non_raptor_chunks(self):
        """Chunks whose 'raptor_kwd' is not 'raptor' are ignored."""
        chunks = {
            "chunk_1": {"raptor_kwd": "raptor"},
            "chunk_2": {"raptor_kwd": "other"},
        }
        chunk_ids = collect_raptor_chunk_ids(chunks)
        assert chunk_ids == {"chunk_1"}
class TestMakeRaptorSummaryChunkId:
    """Check the determinism and input sensitivity of make_raptor_summary_chunk_id."""

    def test_generates_consistent_id(self):
        """Two calls with identical arguments give equal IDs."""
        first = make_raptor_summary_chunk_id("content", "doc_1")
        second = make_raptor_summary_chunk_id("content", "doc_1")
        assert first == second

    def test_generates_different_ids_for_different_content(self):
        """Changing only the content changes the ID."""
        first = make_raptor_summary_chunk_id("content1", "doc_1")
        second = make_raptor_summary_chunk_id("content2", "doc_1")
        assert first != second

    def test_generates_different_ids_for_different_doc(self):
        """Changing only the document ID changes the ID."""
        first = make_raptor_summary_chunk_id("content", "doc_1")
        second = make_raptor_summary_chunk_id("content", "doc_2")
        assert first != second

    def test_returns_string(self):
        """The generated ID is a str instance."""
        chunk_id = make_raptor_summary_chunk_id("content", "doc_1")
        assert isinstance(chunk_id, str)
class TestIsStructuredFileType:
    """Tests for is_structured_file_type function.

    Covers the spreadsheet/delimited extensions that must be recognised,
    ordinary document extensions that must not be, empty inputs, and the
    two input normalisations the tests rely on: letter case and an
    omitted leading dot.
    """

    def test_returns_true_for_xlsx(self):
        """Test that .xlsx is recognized as structured."""
        assert is_structured_file_type(".xlsx") is True

    def test_returns_true_for_xls(self):
        """Test that .xls is recognized as structured."""
        assert is_structured_file_type(".xls") is True

    def test_returns_true_for_csv(self):
        """Test that .csv is recognized as structured."""
        assert is_structured_file_type(".csv") is True

    def test_returns_true_for_tsv(self):
        """Test that .tsv is recognized as structured."""
        assert is_structured_file_type(".tsv") is True

    def test_returns_false_for_pdf(self):
        """Test that .pdf is not structured."""
        assert is_structured_file_type(".pdf") is False

    def test_returns_false_for_txt(self):
        """Test that .txt is not structured."""
        assert is_structured_file_type(".txt") is False

    def test_returns_false_for_none(self):
        """Test that None is not structured."""
        assert is_structured_file_type(None) is False

    def test_returns_false_for_empty_string(self):
        """Test that empty string is not structured."""
        assert is_structured_file_type("") is False

    def test_handles_case_insensitive(self):
        """Test that upper- and mixed-case extensions are recognized."""
        assert is_structured_file_type(".XLSX") is True
        # The second check previously used lowercase "xlsx", which exercised
        # the missing-dot path (already covered by test_handles_missing_dot)
        # rather than case handling. A mixed-case value keeps this test
        # focused on case-insensitivity.
        assert is_structured_file_type(".Xlsx") is True

    def test_handles_missing_dot(self):
        """Test that missing dot is handled."""
        assert is_structured_file_type("xlsx") is True
class TestIsTabularPdf:
    """Check is_tabular_pdf for different parser IDs and parser configs."""

    def test_returns_true_for_table_parser(self):
        """The 'table' parser is treated as tabular."""
        verdict = is_tabular_pdf("table", {})
        assert verdict is True

    def test_returns_true_for_html4excel(self):
        """A truthy 'html4excel' flag makes a 'naive' parser tabular."""
        verdict = is_tabular_pdf("naive", {"html4excel": True})
        assert verdict is True

    def test_returns_false_for_naive_parser(self):
        """The 'naive' parser with an empty config is not tabular."""
        verdict = is_tabular_pdf("naive", {})
        assert verdict is False

    def test_returns_false_for_empty_parser_id(self):
        """An empty parser ID is not tabular."""
        verdict = is_tabular_pdf("", {})
        assert verdict is False

    def test_returns_false_for_html4excel_false(self):
        """An explicit html4excel=False leaves a 'naive' parser non-tabular."""
        verdict = is_tabular_pdf("naive", {"html4excel": False})
        assert verdict is False

    def test_handles_case_insensitive_parser_id(self):
        """Upper- and title-case spellings of 'table' are accepted."""
        upper_verdict = is_tabular_pdf("TABLE", {})
        assert upper_verdict is True
        title_verdict = is_tabular_pdf("Table", {})
        assert title_verdict is True
class TestShouldSkipRaptor:
    """Check when should_skip_raptor decides to skip, by file type, parser, and config."""

    def test_skips_for_xlsx_file(self):
        """An .xlsx file is skipped."""
        skip = should_skip_raptor(file_type=".xlsx")
        assert skip is True

    def test_skips_for_csv_file(self):
        """A .csv file is skipped."""
        skip = should_skip_raptor(file_type=".csv")
        assert skip is True

    def test_skips_for_tabular_pdf(self):
        """A PDF handled by the 'table' parser is skipped."""
        skip = should_skip_raptor(file_type=".pdf", parser_id="table")
        assert skip is True

    def test_does_not_skip_for_normal_pdf(self):
        """A PDF handled by the 'naive' parser is not skipped."""
        skip = should_skip_raptor(file_type=".pdf", parser_id="naive")
        assert skip is False

    def test_does_not_skip_for_txt_file(self):
        """A .txt file is not skipped."""
        skip = should_skip_raptor(file_type=".txt")
        assert skip is False

    def test_respects_auto_disable_config_false(self):
        """auto_disable_for_structured_data=False keeps an .xlsx file from being skipped."""
        overrides = {"auto_disable_for_structured_data": False}
        skip = should_skip_raptor(file_type=".xlsx", raptor_config=overrides)
        assert skip is False

    def test_respects_auto_disable_config_true(self):
        """auto_disable_for_structured_data=True causes an .xlsx file to be skipped."""
        overrides = {"auto_disable_for_structured_data": True}
        skip = should_skip_raptor(file_type=".xlsx", raptor_config=overrides)
        assert skip is True

    def test_default_auto_disable_is_true(self):
        """Without any raptor_config an .xlsx file is skipped."""
        skip = should_skip_raptor(file_type=".xlsx")
        assert skip is True

    def test_returns_false_for_none_file_type(self):
        """A None file type is not skipped."""
        skip = should_skip_raptor(file_type=None)
        assert skip is False
class TestGetSkipReason:
    """Check the message get_skip_reason produces for skipped and non-skipped inputs."""

    def test_returns_reason_for_structured_file(self):
        """The message for an .xlsx file names the category and the extension."""
        message = get_skip_reason(file_type=".xlsx")
        assert "Structured data file" in message
        assert ".xlsx" in message

    def test_returns_reason_for_tabular_pdf(self):
        """The message for a 'table'-parsed PDF names the category and the parser."""
        message = get_skip_reason(file_type=".pdf", parser_id="table")
        assert "Tabular PDF" in message
        assert "table" in message

    def test_returns_empty_for_normal_pdf(self):
        """A 'naive'-parsed PDF gives an empty message."""
        message = get_skip_reason(file_type=".pdf", parser_id="naive")
        assert message == ""

    def test_returns_empty_for_txt_file(self):
        """A .txt file gives an empty message."""
        message = get_skip_reason(file_type=".txt")
        assert message == ""

    def test_returns_empty_for_none_file_type(self):
        """A None file type gives an empty message."""
        message = get_skip_reason(file_type=None)
        assert message == ""