mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat(raptor): add Psi tree builder with original-space ranking and safe migration (#14679)
### What problem does this PR solve? Closes #14674. This PR improves RAPTOR configuration and tree construction while preserving the existing RAPTOR behavior as the default. RAPTOR currently builds summary layers with the original UMAP + GMM clustering path. This PR keeps that default path, and adds: - A hidden backend tree-builder option: - `tree_builder="raptor"`: default, existing RAPTOR behavior. - `tree_builder="psi"`: rank-aware Psi-style tree builder using original embedding-space cosine ranking. - A user-facing clustering method option for the default RAPTOR builder: - `clustering_method="gmm"`: existing default. - `clustering_method="ahc"`: agglomerative hierarchical clustering path. - A RAPTOR UI setting for `Clustering method` and `Max cluster`. ### What changed #### Backend - Added `tree_builder` support for RAPTOR/Psi. - Added `clustering_method` support for GMM/AHC. - Kept existing RAPTOR + GMM as the default. - Added Psi tree building from original-space cosine similarity. - Added bucketed Psi building controls for large inputs: - `raptor.ext.psi_exact_max_leaves` - `raptor.ext.psi_bucket_size` - Added method-aware RAPTOR summary metadata using existing `extra.raptor_method`. - Avoided adding a dedicated DB schema field for experimental method tracking. - Added cleanup/migration logic to avoid mixing stale RAPTOR summary trees. - Added defensive checks for Psi tree construction and summary failures. #### Frontend/UI - Added `Clustering method` in RAPTOR settings with `GMM` and `AHC`. - Added/kept `Max cluster` in RAPTOR settings. - Enlarged max cluster UI limit to `1024`, matching backend validation. - Kept AHC editable even when a RAPTOR task has already finished. - Fixed the UI save payload so `clustering_method` and `tree_builder` are serialized through `parser_config.raptor.ext`, avoiding backend validation errors for extra top-level RAPTOR fields. 
Example saved RAPTOR config: ```json { "raptor": { "max_cluster": 317, "ext": { "clustering_method": "ahc", "tree_builder": "raptor" } } } ``` Co-authored-by: CaptainTimon <CaptainTimon@users.noreply.github.com>
This commit is contained in:
@@ -17,7 +17,7 @@ import { DocumentParserType, ParseType } from '@/constants/knowledge';
|
||||
import { useFetchKnowledgeBaseConfiguration } from '@/hooks/use-knowledge-request';
|
||||
import { IModalProps } from '@/interfaces/common';
|
||||
import { IParserConfig } from '@/interfaces/database/document';
|
||||
import { IChangeParserConfigRequestBody } from '@/interfaces/request/document';
|
||||
import { IChangeParserRequestBody } from '@/interfaces/request/document';
|
||||
import { MetadataType } from '@/pages/dataset/components/metedata/constant';
|
||||
import {
|
||||
AutoMetadata,
|
||||
@@ -28,7 +28,6 @@ import {
|
||||
} from '@/pages/dataset/dataset-setting/configuration/common-item';
|
||||
import { zodResolver } from '@hookform/resolvers/zod';
|
||||
import omit from 'lodash/omit';
|
||||
import {} from 'module';
|
||||
import { useEffect, useMemo } from 'react';
|
||||
import { useForm, useWatch } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -56,10 +55,7 @@ import {
|
||||
|
||||
const FormId = 'ChunkMethodDialogForm';
|
||||
|
||||
interface IProps extends IModalProps<{
|
||||
parserId: string;
|
||||
parserConfig: IChangeParserConfigRequestBody;
|
||||
}> {
|
||||
interface IProps extends IModalProps<IChangeParserRequestBody> {
|
||||
loading: boolean;
|
||||
parserId: string;
|
||||
pipelineId?: string;
|
||||
@@ -126,16 +122,19 @@ export function ChunkMethodDialog({
|
||||
mineru_formula_enable: z.boolean().optional(),
|
||||
mineru_table_enable: z.boolean().optional(),
|
||||
mineru_lang: z.string().optional(),
|
||||
// raptor: z
|
||||
// .object({
|
||||
// use_raptor: z.boolean().optional(),
|
||||
// prompt: z.string().optional().optional(),
|
||||
// max_token: z.coerce.number().optional(),
|
||||
// threshold: z.coerce.number().optional(),
|
||||
// max_cluster: z.coerce.number().optional(),
|
||||
// random_seed: z.coerce.number().optional(),
|
||||
// })
|
||||
// .optional(),
|
||||
raptor: z
|
||||
.object({
|
||||
use_raptor: z.boolean().optional(),
|
||||
prompt: z.string().optional(),
|
||||
max_token: z.coerce.number().optional(),
|
||||
threshold: z.coerce.number().optional(),
|
||||
max_cluster: z.coerce.number().optional(),
|
||||
random_seed: z.coerce.number().optional(),
|
||||
scope: z.string().optional(),
|
||||
clustering_method: z.enum(['gmm', 'ahc']).optional(),
|
||||
tree_builder: z.enum(['raptor', 'psi']).optional(),
|
||||
})
|
||||
.optional(),
|
||||
// graphrag: z.object({
|
||||
// use_graphrag: z.boolean().optional(),
|
||||
// }),
|
||||
|
||||
@@ -23,14 +23,17 @@ export function useDefaultParserValues() {
|
||||
mineru_formula_enable: true,
|
||||
mineru_table_enable: true,
|
||||
mineru_lang: 'English',
|
||||
// raptor: {
|
||||
// use_raptor: false,
|
||||
// prompt: t('knowledgeConfiguration.promptText'),
|
||||
// max_token: 256,
|
||||
// threshold: 0.1,
|
||||
// max_cluster: 64,
|
||||
// random_seed: 0,
|
||||
// },
|
||||
raptor: {
|
||||
use_raptor: false,
|
||||
prompt: t('knowledgeConfiguration.promptText'),
|
||||
max_token: 256,
|
||||
threshold: 0.1,
|
||||
max_cluster: 64,
|
||||
random_seed: 0,
|
||||
scope: 'file',
|
||||
clustering_method: 'gmm',
|
||||
tree_builder: 'raptor',
|
||||
},
|
||||
// graphrag: {
|
||||
// use_graphrag: false,
|
||||
// },
|
||||
|
||||
@@ -8,7 +8,7 @@ import {
|
||||
} from '@/pages/dataset/dataset/generate-button/generate';
|
||||
import random from 'lodash/random';
|
||||
import { Shuffle } from 'lucide-react';
|
||||
import { useCallback } from 'react';
|
||||
import { useCallback, useEffect, useMemo } from 'react';
|
||||
import { useFormContext, useWatch } from 'react-hook-form';
|
||||
import { SliderInputFormField } from '../slider-input-form-field';
|
||||
import {
|
||||
@@ -50,10 +50,10 @@ export const showTagItems = (parserId: DocumentParserType) => {
|
||||
|
||||
const UseRaptorField = 'parser_config.raptor.use_raptor';
|
||||
const RandomSeedField = 'parser_config.raptor.random_seed';
|
||||
const MaxTokenField = 'parser_config.raptor.max_token';
|
||||
const ThresholdField = 'parser_config.raptor.threshold';
|
||||
const MaxCluster = 'parser_config.raptor.max_cluster';
|
||||
const Prompt = 'parser_config.raptor.prompt';
|
||||
const ClusteringMethodField = 'parser_config.raptor.clustering_method';
|
||||
const ClusteringMethodExtField = 'parser_config.raptor.ext.clustering_method';
|
||||
const TreeBuilderField = 'parser_config.raptor.tree_builder';
|
||||
const MaxClusterMax = 1024;
|
||||
|
||||
// The three types "table", "resume" and "one" do not display this configuration.
|
||||
|
||||
@@ -67,17 +67,48 @@ const RaptorFormFields = ({
|
||||
const form = useFormContext();
|
||||
const { t } = useTranslate('knowledgeConfiguration');
|
||||
const useRaptor = useWatch({ name: UseRaptorField });
|
||||
const clusteringMethod = useWatch({ name: ClusteringMethodField });
|
||||
const extClusteringMethod = useWatch({ name: ClusteringMethodExtField });
|
||||
const selectedClusteringMethod = useMemo(
|
||||
() =>
|
||||
(clusteringMethod ??
|
||||
extClusteringMethod ??
|
||||
form.getValues(ClusteringMethodField) ??
|
||||
form.getValues(ClusteringMethodExtField) ??
|
||||
'gmm') as 'gmm' | 'ahc',
|
||||
[clusteringMethod, extClusteringMethod, form],
|
||||
);
|
||||
|
||||
const handleGenerate = useCallback(() => {
|
||||
form.setValue(RandomSeedField, random(10000));
|
||||
}, [form]);
|
||||
|
||||
const handleClusteringMethodChange = useCallback(
|
||||
(method: 'gmm' | 'ahc') => {
|
||||
form.setValue(ClusteringMethodField, method, {
|
||||
shouldDirty: true,
|
||||
shouldValidate: true,
|
||||
});
|
||||
form.setValue(TreeBuilderField, 'raptor', {
|
||||
shouldDirty: true,
|
||||
shouldValidate: true,
|
||||
});
|
||||
},
|
||||
[form],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (!clusteringMethod && !extClusteringMethod) {
|
||||
handleClusteringMethodChange('gmm');
|
||||
}
|
||||
}, [clusteringMethod, extClusteringMethod, handleClusteringMethodChange]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name={UseRaptorField}
|
||||
render={({ field }) => {
|
||||
render={() => {
|
||||
return (
|
||||
<FormItem
|
||||
defaultChecked={false}
|
||||
@@ -209,11 +240,61 @@ const RaptorFormFields = ({
|
||||
sliderTestId="ds-settings-raptor-threshold-slider"
|
||||
numberInputTestId="ds-settings-raptor-threshold-input"
|
||||
></SliderInputFormField>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name={ClusteringMethodField}
|
||||
render={({ field }) => {
|
||||
return (
|
||||
<FormItem className=" items-center space-y-0 ">
|
||||
<div className="flex items-start">
|
||||
<FormLabel
|
||||
tooltip={t('clusteringMethodTip')}
|
||||
className="text-sm whitespace-nowrap w-1/4"
|
||||
>
|
||||
{t('clusteringMethod')}
|
||||
</FormLabel>
|
||||
<div className="w-3/4">
|
||||
<FormControl>
|
||||
<Radio.Group
|
||||
{...field}
|
||||
value={selectedClusteringMethod}
|
||||
onChange={(value) =>
|
||||
handleClusteringMethodChange(value as 'gmm' | 'ahc')
|
||||
}
|
||||
>
|
||||
<div
|
||||
className={'flex gap-4 w-full text-text-secondary '}
|
||||
>
|
||||
<Radio
|
||||
value="gmm"
|
||||
testId="ds-settings-raptor-clustering-method-option-gmm"
|
||||
>
|
||||
{t('clusteringMethodGmm')}
|
||||
</Radio>
|
||||
<Radio
|
||||
value="ahc"
|
||||
testId="ds-settings-raptor-clustering-method-option-ahc"
|
||||
>
|
||||
{t('clusteringMethodAhc')}
|
||||
</Radio>
|
||||
</div>
|
||||
</Radio.Group>
|
||||
</FormControl>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex pt-1">
|
||||
<div className="w-1/4"></div>
|
||||
<FormMessage />
|
||||
</div>
|
||||
</FormItem>
|
||||
);
|
||||
}}
|
||||
/>
|
||||
<SliderInputFormField
|
||||
name={'parser_config.raptor.max_cluster'}
|
||||
label={t('maxCluster')}
|
||||
tooltip={t('maxClusterTip')}
|
||||
max={1024}
|
||||
max={MaxClusterMax}
|
||||
min={1}
|
||||
layout={FormLayout.Horizontal}
|
||||
sliderTestId="ds-settings-raptor-max-cluster-slider"
|
||||
|
||||
@@ -13,6 +13,7 @@ type RadioProps = {
|
||||
checked?: boolean;
|
||||
disabled?: boolean;
|
||||
onChange?: (checked: boolean) => void;
|
||||
testId?: string;
|
||||
children?: React.ReactNode;
|
||||
} & Omit<
|
||||
React.InputHTMLAttributes<HTMLInputElement>,
|
||||
@@ -25,6 +26,7 @@ function Radio({
|
||||
checked,
|
||||
disabled,
|
||||
onChange,
|
||||
testId,
|
||||
children,
|
||||
...props
|
||||
}: RadioProps) {
|
||||
@@ -65,6 +67,7 @@ function Radio({
|
||||
onChange={handleChange}
|
||||
disabled={mergedDisabled}
|
||||
className={cn('peer absolute size-[1px] opacity-0', className)}
|
||||
data-testid={testId}
|
||||
{...props}
|
||||
name={groupContext?.name}
|
||||
/>
|
||||
@@ -151,9 +154,11 @@ const Group = React.forwardRef<HTMLDivElement, RadioGroupProps>(
|
||||
)}
|
||||
>
|
||||
{React.Children.map(children, (child) => {
|
||||
if (!React.isValidElement<RadioProps>(child)) return child;
|
||||
if (!React.isValidElement<RadioProps>(child)) {
|
||||
return child;
|
||||
}
|
||||
return React.cloneElement(child, {
|
||||
disabled: disabled || child.props?.disabled,
|
||||
disabled: disabled || child.props.disabled,
|
||||
});
|
||||
})}
|
||||
</div>
|
||||
|
||||
@@ -21,10 +21,17 @@ export const extractRaptorConfigExt = (
|
||||
max_cluster,
|
||||
random_seed,
|
||||
scope,
|
||||
clustering_method,
|
||||
tree_builder,
|
||||
auto_disable_for_structured_data,
|
||||
ext,
|
||||
...raptorExt
|
||||
} = raptorConfig;
|
||||
const extClusteringMethod = ext?.clustering_method;
|
||||
const normalizedClusteringMethod =
|
||||
clustering_method ?? extClusteringMethod ?? 'gmm';
|
||||
const normalizedTreeBuilder = tree_builder ?? ext?.tree_builder ?? 'raptor';
|
||||
|
||||
return {
|
||||
use_raptor,
|
||||
prompt,
|
||||
@@ -34,7 +41,12 @@ export const extractRaptorConfigExt = (
|
||||
random_seed,
|
||||
scope,
|
||||
auto_disable_for_structured_data,
|
||||
ext: { ...ext, ...raptorExt },
|
||||
ext: {
|
||||
...ext,
|
||||
...raptorExt,
|
||||
clustering_method: normalizedClusteringMethod,
|
||||
tree_builder: normalizedTreeBuilder,
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
45
web/src/hooks/tests/parser-config-utils.test.ts
Normal file
45
web/src/hooks/tests/parser-config-utils.test.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
import { extractParserConfigExt } from '../parser-config-utils';
|
||||
|
||||
// Tests for how extractParserConfigExt lays out the RAPTOR section of a parser
// config: `clustering_method` and `tree_builder` are expected under
// `raptor.ext`, not as top-level `raptor` fields.
describe('extractParserConfigExt', () => {
|
||||
// Case 1: both fields are supplied at the top level of `raptor` (no `ext` given).
it('serializes RAPTOR clustering fields through ext for API compatibility', () => {
|
||||
const result = extractParserConfigExt({
|
||||
raptor: {
|
||||
use_raptor: true,
|
||||
prompt: 'Summarize {cluster_content}',
|
||||
max_token: 256,
|
||||
threshold: 0.1,
|
||||
max_cluster: 317,
|
||||
random_seed: 0,
|
||||
scope: 'file',
|
||||
clustering_method: 'ahc',
|
||||
tree_builder: 'raptor',
|
||||
},
|
||||
});
|
||||
|
||||
// The two fields must no longer appear directly on `raptor` ...
expect(result?.raptor).not.toHaveProperty('clustering_method');
|
||||
expect(result?.raptor).not.toHaveProperty('tree_builder');
|
||||
// ... and must show up under `raptor.ext` with the values that were passed in.
expect(result?.raptor?.ext).toMatchObject({
|
||||
clustering_method: 'ahc',
|
||||
tree_builder: 'raptor',
|
||||
});
|
||||
});
|
||||
|
||||
// Case 2: the fields are already stored under `raptor.ext` and are absent at the top level.
it('preserves existing RAPTOR ext clustering values when the top-level field is absent', () => {
|
||||
const result = extractParserConfigExt({
|
||||
raptor: {
|
||||
max_cluster: 512,
|
||||
ext: {
|
||||
clustering_method: 'ahc',
|
||||
tree_builder: 'raptor',
|
||||
psi_bucket_size: 1024,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// The existing ext values must survive unchanged, including the unrelated
// `psi_bucket_size` key (i.e. 'ahc' is not overwritten by a default).
expect(result?.raptor?.ext).toMatchObject({
|
||||
clustering_method: 'ahc',
|
||||
tree_builder: 'raptor',
|
||||
psi_bucket_size: 1024,
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -73,11 +73,13 @@ interface Parserconfig {
|
||||
}
|
||||
|
||||
/**
 * RAPTOR settings as declared for a parser configuration.
 *
 * `clustering_method` and `tree_builder` are optional; the other fields are
 * required by this declaration.
 */
interface Raptor {
|
||||
clustering_method?: 'gmm' | 'ahc'; // clustering path: 'gmm' or 'ahc' (agglomerative hierarchical clustering)
|
||||
max_cluster: number; // maximum number of clusters to create
|
||||
max_token: number; // NOTE(review): presumably the token limit per generated summary — confirm
|
||||
prompt: string; // NOTE(review): presumably the summarization prompt template — confirm
|
||||
random_seed: number;
|
||||
threshold: number; // minimum similarity required for chunks to be grouped together
|
||||
tree_builder?: 'raptor' | 'psi'; // tree builder: 'raptor' or the rank-aware 'psi' builder
|
||||
use_raptor: boolean; // NOTE(review): presumably toggles RAPTOR on/off — confirm
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,17 @@ export interface IChangeParserConfigRequestBody {
|
||||
image_table_context_window?: number;
|
||||
image_context_size?: number;
|
||||
table_context_size?: number;
|
||||
raptor?: {
|
||||
use_raptor?: boolean;
|
||||
prompt?: string;
|
||||
max_token?: number;
|
||||
threshold?: number;
|
||||
max_cluster?: number;
|
||||
random_seed?: number;
|
||||
scope?: string;
|
||||
clustering_method?: 'gmm' | 'ahc';
|
||||
tree_builder?: 'raptor' | 'psi';
|
||||
};
|
||||
// Metadata fields
|
||||
metadata?: Array<{
|
||||
key?: string;
|
||||
@@ -27,8 +38,8 @@ export interface IChangeParserConfigRequestBody {
|
||||
|
||||
export interface IChangeParserRequestBody {
|
||||
parser_id: string;
|
||||
pipeline_id: string;
|
||||
doc_id: string;
|
||||
pipeline_id?: string;
|
||||
doc_id?: string;
|
||||
parser_config: IChangeParserConfigRequestBody;
|
||||
}
|
||||
|
||||
|
||||
@@ -861,6 +861,11 @@ The above is the content you need to summarize.`,
|
||||
thresholdTip:
|
||||
'In RAPTOR, chunks are clustered by their semantic similarity. The Threshold parameter sets the minimum similarity required for chunks to be grouped together. A higher Threshold means fewer chunks in each cluster, while a lower one means more.',
|
||||
thresholdMessage: 'Threshold is required',
|
||||
clusteringMethod: 'Clustering method',
|
||||
clusteringMethodTip:
|
||||
'Select the RAPTOR clustering method. AHC can use a larger max cluster value, but may require more memory on large inputs.',
|
||||
clusteringMethodGmm: 'GMM',
|
||||
clusteringMethodAhc: 'AHC',
|
||||
maxCluster: 'Max cluster',
|
||||
maxClusterTip: 'The maximum number of clusters to create.',
|
||||
maxClusterMessage: 'Max cluster is required',
|
||||
|
||||
@@ -772,6 +772,11 @@ export default {
|
||||
maxTokenMessage: '最大token数是必填项',
|
||||
threshold: '阈值',
|
||||
thresholdMessage: '阈值是必填项',
|
||||
clusteringMethod: '聚类方法',
|
||||
clusteringMethodTip:
|
||||
'选择 RAPTOR 聚类方法。AHC 可以使用更大的最大聚类数,但在大规模输入时可能占用更多内存。',
|
||||
clusteringMethodGmm: 'GMM',
|
||||
clusteringMethodAhc: 'AHC',
|
||||
maxCluster: '最大聚类数',
|
||||
maxClusterMessage: '最大聚类数是必填项',
|
||||
randomSeed: '随机种子',
|
||||
|
||||
@@ -42,11 +42,14 @@ export const formSchema = z
|
||||
.object({
|
||||
use_raptor: z.boolean().optional(),
|
||||
prompt: z.string().optional(),
|
||||
max_token: z.number().optional(),
|
||||
threshold: z.number().optional(),
|
||||
max_cluster: z.number().optional(),
|
||||
random_seed: z.number().optional(),
|
||||
max_token: z.coerce.number().optional(),
|
||||
threshold: z.coerce.number().optional(),
|
||||
max_cluster: z.coerce.number().optional(),
|
||||
random_seed: z.coerce.number().optional(),
|
||||
scope: z.string().optional(),
|
||||
clustering_method: z.enum(['gmm', 'ahc']).optional(),
|
||||
tree_builder: z.enum(['raptor', 'psi']).optional(),
|
||||
ext: z.record(z.string(), z.any()).optional(),
|
||||
})
|
||||
.refine(
|
||||
(data) => {
|
||||
|
||||
@@ -95,6 +95,8 @@ export default function DatasetSettings() {
|
||||
max_cluster: 64,
|
||||
random_seed: 0,
|
||||
scope: 'file',
|
||||
clustering_method: 'gmm',
|
||||
tree_builder: 'raptor',
|
||||
prompt: t('knowledgeConfiguration.promptText'),
|
||||
},
|
||||
graphrag: {
|
||||
|
||||
@@ -19,7 +19,7 @@ export const useChangeDocumentParser = () => {
|
||||
if (record?.id && record?.dataset_id) {
|
||||
const ret = await setDocumentParser({
|
||||
parserId: parserConfigInfo.parser_id,
|
||||
pipelineId: parserConfigInfo.pipeline_id,
|
||||
pipelineId: parserConfigInfo.pipeline_id || '',
|
||||
documentId: record?.id,
|
||||
datasetId: record?.dataset_id,
|
||||
parserConfig: parserConfigInfo.parser_config,
|
||||
|
||||
Reference in New Issue
Block a user