Files
ragflow/rag/flow/chunker/token_chunker.py
Magicbook1108 87a87a7122 Feat: pipeline support ONE chunking method (#14024)
### What problem does this PR solve?

Feat: pipeline support ONE chunking method

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
2026-04-10 13:11:22 +08:00

369 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import re
from copy import deepcopy
from common.float_utils import normalize_overlapped_percent
from common.token_utils import num_tokens_from_string
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.chunker.schema import TokenChunkerFromUpstream
from rag.flow.parser.pdf_chunk_metadata import (
PDF_POSITIONS_KEY,
extract_pdf_positions,
finalize_pdf_chunk,
restore_pdf_text_previews,
)
from rag.nlp import naive_merge
class TokenChunkerParam(ProcessParamBase):
    # Parameter holder for the token chunker component.
    #
    # Defaults set in __init__:
    #   delimiter_mode       "token_size" (check() also accepts "delimiter" and "one")
    #   chunk_token_size     512
    #   delimiters           ["\n"]
    #   overlapped_percent   0
    #   children_delimiters  []
    #   table_context_size   0
    #   image_context_size   0

    def __init__(self):
        super().__init__()
        self.delimiter_mode = "token_size"
        self.chunk_token_size = 512
        self.delimiters = ["\n"]
        self.overlapped_percent = 0
        self.children_delimiters = []
        self.table_context_size = 0
        self.image_context_size = 0

    def check(self):
        # Validate the configured values and normalize both delimiter settings
        # in place. The check_* helpers are not defined in this class
        # (presumably inherited from ProcessParamBase - confirm there).
        self.check_valid_value(self.delimiter_mode, "Delimiter mode abnormal.", ["token_size", "delimiter", "one"])

        # Both delimiter attributes follow the same rule: None becomes an empty
        # list, a bare string becomes a one-element list, any other iterable
        # keeps only its str items; empty strings are dropped in every case.
        for attr_name in ("delimiters", "children_delimiters"):
            configured = getattr(self, attr_name)
            if configured is None:
                candidates = []
            elif isinstance(configured, str):
                candidates = [configured]
            else:
                candidates = [entry for entry in configured if isinstance(entry, str)]
            setattr(self, attr_name, [entry for entry in candidates if entry])

        self.check_positive_integer(self.chunk_token_size, "Chunk token size.")
        self.check_decimal_float(self.overlapped_percent, "Overlapped percentage: [0, 1)")
        self.check_nonnegative_number(self.table_context_size, "Table context size.")
        self.check_nonnegative_number(self.image_context_size, "Image context size.")

    def get_input_form(self) -> dict[str, dict]:
        # This component exposes no input form fields.
        return {}
def _compile_delimiter_pattern(delimiters):
    # Build the primary delimiter regex source from the configured delimiters.
    #
    # Only text wrapped in backticks (`like this`) is treated as an active
    # delimiter. All non-empty entries are concatenated before scanning, each
    # distinct backtick-wrapped run is regex-escaped, and the alternatives are
    # ordered longest first so a longer delimiter wins over its own prefix.
    # Returns "" when no backtick-wrapped delimiter is present.
    joined = "".join(filter(None, delimiters or []))
    wrapped = re.findall(r"`([^`]+)`", joined)
    if not wrapped:
        return ""
    longest_first = sorted(set(wrapped), key=len, reverse=True)
    return "|".join(map(re.escape, longest_first))
def _split_text_by_pattern(text, pattern):
    # Split text on the given regex source, keeping each delimiter attached to
    # the end of the piece that precedes it.
    #
    # With an empty pattern the text is returned unsplit as a one-element list
    # (None becomes ""). Otherwise pieces that are empty before their delimiter
    # is attached, or whitespace-only after it, are dropped.
    if not pattern:
        return [text or ""]

    # Wrapping the pattern in a capture group makes re.split interleave the
    # matched delimiters with the text: even slots are text, odd slots are the
    # delimiter that followed it.
    pieces = re.split(r"(%s)" % pattern, text or "", flags=re.DOTALL)
    bodies = pieces[0::2]
    # Pad so the final body (which has no trailing delimiter) still pairs up.
    trailing = pieces[1::2] + [""]

    result = []
    for body, delimiter in zip(bodies, trailing):
        if not body:
            continue
        combined = body + delimiter
        if combined.strip():
            result.append(combined)
    return result
def _build_json_chunks(json_result, delimiter_pattern):
    # Normalize upstream JSON items into the internal working-chunk dicts.
    #
    # Each item is classified as "table", "image" or (anything else) "text"
    # from its doc_type_kwd. Text items are optionally split by
    # delimiter_pattern, one working chunk per non-blank segment; table and
    # image items become exactly one chunk with empty context placeholders.
    working = []
    for entry in json_result:
        declared = str(entry.get("doc_type_kwd") or "").strip().lower()
        kind = declared if declared in ("table", "image") else "text"

        # Prefer "text", fall back to "content_with_weight", else empty.
        body = entry.get("text")
        if not isinstance(body, str):
            body = entry.get("content_with_weight")
            if not isinstance(body, str):
                body = ""

        # PDF coordinates travel as an internal preview field until the final
        # output is assembled, so only one coordinate format is ever exposed
        # downstream.
        positions = extract_pdf_positions(entry)
        img_id = entry.get("img_id")

        if kind != "text":
            working.append(
                {
                    "text": body or "",
                    "doc_type_kwd": kind,
                    "ck_type": kind,
                    "img_id": img_id,
                    PDF_POSITIONS_KEY: deepcopy(positions),
                    "tk_nums": num_tokens_from_string(body or ""),
                    "context_above": "",
                    "context_below": "",
                }
            )
            continue

        segments = _split_text_by_pattern(body, delimiter_pattern) if delimiter_pattern else [body]
        for segment in segments:
            if segment and segment.strip():
                working.append(
                    {
                        "text": segment,
                        "doc_type_kwd": "text",
                        "ck_type": "text",
                        # Every segment gets its own copy of the item's positions.
                        PDF_POSITIONS_KEY: deepcopy(positions),
                        "tk_nums": num_tokens_from_string(segment),
                    }
                )
    return working
def _take_sentences(text, need_tokens, from_end=False):
    # Take whole sentences from one side of text until the token budget is met.
    #
    # text:        source text (None is treated as "").
    # need_tokens: stop once the collected text has at least this many tokens.
    # from_end:    collect from the end of the text backwards instead of from
    #              the start; the result still reads in original order.
    # Returns the collected text. At least one sentence is always taken, so the
    # result may exceed the budget; the whole text is returned if the budget is
    # never reached.
    #
    # Sentence terminators: CJK full stop (U+3002), "!" and "?" in both ASCII
    # and fullwidth (U+FF01 / U+FF1F) forms, newline, or an ASCII period
    # followed by a space. Non-ASCII characters are written as \u escapes so
    # the pattern cannot be confused with, or silently normalized into,
    # look-alike characters.
    split_pat = r"([\u3002!?\uff01\uff1f\n]|\. )"
    pieces = re.split(split_pat, text or "", flags=re.DOTALL)

    # re.split interleaves text and terminators; glue each terminator back
    # onto the sentence it ends.
    sentences = [
        pieces[i] + (pieces[i + 1] if i + 1 < len(pieces) else "")
        for i in range(0, len(pieces), 2)
    ]

    collected = ""
    for sentence in (reversed(sentences) if from_end else sentences):
        collected = sentence + collected if from_end else collected + sentence
        if num_tokens_from_string(collected) >= need_tokens:
            break
    return collected
def _attach_context_to_media_chunks(chunks, table_context_size, image_context_size):
    # Fill context_above / context_below on table and image chunks, in place.
    #
    # For every media chunk whose context size is positive, neighbouring text
    # chunks are collected outwards in each direction (non-text neighbours are
    # skipped) until the token budget is used up. A neighbour that would meet
    # or exceed the remaining budget is trimmed to whole sentences on the side
    # facing the media chunk. Media chunks with a non-positive context size
    # are left untouched. Returns None.
    budgets = {"table": table_context_size, "image": image_context_size}

    for idx, media in enumerate(chunks):
        kind = media["ck_type"]
        if kind not in budgets:
            continue
        budget = budgets[kind]
        if budget <= 0:
            continue

        # Walk backwards; pieces are gathered nearest-first and reversed on join.
        above_nearest_first = []
        left = budget
        for cursor in range(idx - 1, -1, -1):
            if left <= 0:
                break
            neighbour = chunks[cursor]
            if neighbour["ck_type"] != "text":
                continue
            if neighbour["tk_nums"] >= left:
                above_nearest_first.append(_take_sentences(neighbour["text"], left, from_end=True))
                break
            above_nearest_first.append(neighbour["text"])
            left -= neighbour["tk_nums"]

        # Walk forwards; pieces are already in reading order.
        below = []
        left = budget
        for cursor in range(idx + 1, len(chunks)):
            if left <= 0:
                break
            neighbour = chunks[cursor]
            if neighbour["ck_type"] != "text":
                continue
            if neighbour["tk_nums"] >= left:
                below.append(_take_sentences(neighbour["text"], left))
                break
            below.append(neighbour["text"])
            left -= neighbour["tk_nums"]

        media["context_above"] = "".join(reversed(above_nearest_first))
        media["context_below"] = "".join(below)
def _merge_text_chunks_by_token_size(chunks, chunk_token_size, overlapped_percent):
    # Merge runs of adjacent text chunks up to a token threshold.
    #
    # Used when delimiter-based splitting is not active. The inputs are never
    # mutated: every output chunk is a deep copy. Non-text chunks are passed
    # through and also end the current run, so text is never merged across
    # them. overlapped_percent is used on a 0-100 scale here.
    result = []
    # The text chunk currently accepting merges; None right after a non-text
    # chunk or before the first text chunk.
    tail = None
    limit = chunk_token_size * (100 - overlapped_percent) / 100.0

    for source in chunks:
        if source["ck_type"] != "text":
            result.append(deepcopy(source))
            tail = None
            continue

        piece = deepcopy(source)
        if tail is None or tail["tk_nums"] > limit:
            # Start a new chunk. When the previous one is full (not merely
            # absent) and overlap is enabled, prefix the new chunk with the
            # trailing part of the previous text and recount its tokens.
            if tail is not None and overlapped_percent > 0 and tail["text"]:
                carry = tail["text"]
                cut = int(len(carry) * (100 - overlapped_percent) / 100.0)
                piece["text"] = carry[cut:] + piece["text"]
                piece["tk_nums"] = num_tokens_from_string(piece["text"])
            result.append(piece)
            tail = piece
            continue

        # Append to the open chunk; a newline separates two non-empty texts.
        joiner = "\n" if tail["text"] and piece["text"] else ""
        tail["text"] += joiner + piece["text"]
        tail[PDF_POSITIONS_KEY].extend(piece.get(PDF_POSITIONS_KEY) or [])
        # Token counts are summed rather than recounted on the joined text.
        tail["tk_nums"] += piece["tk_nums"]
    return result
def _finalize_json_chunks(chunks):
    # Convert internal working chunks into the token chunker's output docs.
    #
    # The output text is context_above + text + context_below (missing or None
    # parts count as ""); chunks whose combined text is blank are dropped.
    # Positions, "mom" and "img_id" are carried over only when truthy, and
    # each doc is passed through finalize_pdf_chunk before being returned.
    finalized = []
    for chunk in chunks:
        parts = (chunk.get("context_above"), chunk.get("text"), chunk.get("context_below"))
        text = "".join(part or "" for part in parts)
        if not text.strip():
            continue

        doc = {"text": text, "doc_type_kwd": chunk.get("doc_type_kwd", "text")}

        # The internal preview coordinates are handed to finalize_pdf_chunk
        # exactly once, on a copy, so the working chunk stays untouched.
        positions = chunk.get(PDF_POSITIONS_KEY)
        if positions:
            doc[PDF_POSITIONS_KEY] = deepcopy(positions)

        for optional_key in ("mom", "img_id"):
            value = chunk.get(optional_key)
            if value:
                doc[optional_key] = value

        finalized.append(finalize_pdf_chunk(doc))
    return finalized
def _split_chunk_docs_by_children(chunks, pattern):
    # Apply the secondary (children) delimiter split to text chunks only.
    #
    # With an empty pattern the input list itself is returned. Otherwise a new
    # list is built: non-text chunks are passed through as the same objects,
    # and each text chunk is replaced by one deep copy per non-blank fragment,
    # with "text" set to the fragment and "mom" set to the full parent text.
    if not pattern:
        return chunks

    expanded = []
    for parent in chunks:
        if parent.get("doc_type_kwd", "text") != "text":
            expanded.append(parent)
            continue

        parent_text = parent.get("text", "")
        for fragment in _split_text_by_pattern(parent_text, pattern):
            if fragment.strip():
                child = deepcopy(parent)
                child["mom"] = parent_text
                child["text"] = fragment
                expanded.append(child)
    return expanded
class TokenChunker(ProcessBase):
    # Pipeline component that turns upstream parser output into chunks.
    component_name = "TokenChunker"

    async def _invoke(self, **kwargs):
        # Validate the upstream payload, chunk it, and publish the result via
        # set_output("chunks", ...). On invalid input only "_ERROR" is set.
        #
        # Modes:
        #   delimiter_mode == "one"  -> the whole input becomes a single chunk.
        #   active primary delimiter -> split on it.
        #   otherwise                -> token-size based merging.
        # children_delimiters, when set, further split each text chunk and
        # record the parent text under "mom".
        try:
            from_upstream = TokenChunkerFromUpstream.model_validate(kwargs)
        except Exception as e:
            self.set_output("_ERROR", f"Input error: {str(e)}")
            return

        params = self._param

        # Primary delimiter regex; "" means no active custom delimiter, in
        # which case token-size based merging is used instead.
        delimiter_pattern = _compile_delimiter_pattern(params.delimiters)
        # Children delimiters are used verbatim (escaped), longest first.
        child_delimiters = sorted(set(params.children_delimiters), key=len, reverse=True)
        custom_pattern = "|".join(map(re.escape, child_delimiters))

        self.set_output("output_format", "chunks")
        self.callback(random.randint(1, 5) / 100.0, "Start to split into chunks.")
        overlapped_percent = normalize_overlapped_percent(params.overlapped_percent)
        single_chunk_mode = params.delimiter_mode == "one"

        # ---- plain text style payloads -------------------------------------
        if from_upstream.output_format in ["markdown", "text", "html"]:
            payload = getattr(from_upstream, f"{from_upstream.output_format}_result") or ""

            if single_chunk_mode:
                docs = [{"text": payload}] if payload.strip() else []
            else:
                if delimiter_pattern:
                    pieces = _split_text_by_pattern(payload, delimiter_pattern)
                else:
                    pieces = naive_merge(
                        payload,
                        params.chunk_token_size,
                        "",
                        overlapped_percent,
                    )

                if custom_pattern:
                    docs = []
                    for piece in pieces:
                        if not piece.strip():
                            continue
                        for child in _split_text_by_pattern(piece, custom_pattern):
                            if child.strip():
                                docs.append({"text": child, "mom": piece})
                else:
                    docs = [{"text": piece.strip()} for piece in pieces if piece.strip()]

            self.set_output("chunks", docs)
            self.callback(1, "Done.")
            return

        # ---- structured JSON payloads --------------------------------------
        json_result = from_upstream.json_result or []

        if single_chunk_mode:
            # Join every non-blank item text ("text", else
            # "content_with_weight") with newlines into one chunk.
            sections = []
            for item in json_result:
                candidate = item.get("text")
                if not isinstance(candidate, str):
                    candidate = item.get("content_with_weight")
                if isinstance(candidate, str) and candidate.strip():
                    sections.append(candidate)
            merged_text = "\n".join(sections)
            self.set_output("chunks", [{"text": merged_text}] if merged_text.strip() else [])
            self.callback(1, "Done.")
            return

        # Normalize first, then optionally enrich media chunks with context,
        # and merge by token size only when delimiter splitting is inactive.
        chunks = _build_json_chunks(json_result, delimiter_pattern)
        _attach_context_to_media_chunks(chunks, params.table_context_size, params.image_context_size)
        if not delimiter_pattern:
            chunks = _merge_text_chunks_by_token_size(chunks, params.chunk_token_size, overlapped_percent)
        if custom_pattern:
            chunks = _split_chunk_docs_by_children(chunks, custom_pattern)

        await restore_pdf_text_previews(chunks, from_upstream, self._canvas)
        self.set_output("chunks", _finalize_json_chunks(chunks))
        self.callback(1, "Done.")