diff --git a/web/src/pages/agent/hooks/use-before-delete.test.tsx b/web/src/pages/agent/hooks/use-before-delete.test.tsx new file mode 100644 index 0000000000..114d5f85f7 --- /dev/null +++ b/web/src/pages/agent/hooks/use-before-delete.test.tsx @@ -0,0 +1,205 @@ +import { act, renderHook } from '@testing-library/react'; +import { Edge } from '@xyflow/react'; +import { NodeHandleId, Operator } from '../constant'; +import useGraphStore from '../store'; +import { useBeforeDelete } from './use-before-delete'; + +const createNode = ( + id: string, + label: Operator, + options: Record = {}, +) => ({ + id, + type: 'ragNode', + position: { x: 0, y: 0 }, + data: { + label, + name: id, + form: {}, + }, + ...options, +}); + +const createEdge = ( + id: string, + source: string, + target: string, + options: Partial = {}, +): Edge => ({ + id, + source, + target, + ...options, +}); + +describe('useBeforeDelete', () => { + beforeEach(() => { + useGraphStore.setState({ + nodes: [], + edges: [], + selectedNodeIds: [], + selectedEdgeIds: [], + clickedNodeId: '', + clickedToolId: '', + }); + }); + + it('expands iteration deletion to descendants and all touching edges', async () => { + const nodes = [ + createNode('iteration:0', Operator.Iteration, { type: 'group' }), + createNode('iterationStart:0', Operator.IterationStart, { + parentId: 'iteration:0', + type: 'iterationStartNode', + }), + createNode('message:0', Operator.Message, { parentId: 'iteration:0' }), + createNode('message:1', Operator.Message, { parentId: 'message:0' }), + createNode('generate:0', Operator.Generate), + ]; + + const edges = [ + createEdge('e1', 'iterationStart:0', 'message:0'), + createEdge('e2', 'message:0', 'message:1'), + createEdge('e3', 'message:0', 'generate:0'), + createEdge('e4', 'generate:0', 'message:1'), + ]; + + useGraphStore.setState({ nodes, edges }); + + const { result } = renderHook(() => useBeforeDelete()); + let deletion; + await act(async () => { + deletion = await result.current.handleBeforeDelete({ + nodes: [nodes[0] as any], + edges: [], + }); + }); + + expect(deletion?.nodes.map((node) => node.id).sort()).toEqual( + ['iteration:0', 'iterationStart:0', 'message:0', 'message:1'].sort(), + ); + expect(deletion?.edges.map((edge) => edge.id).sort()).toEqual( + ['e1', 'e2', 'e3', 'e4'].sort(), + ); + }); + + it('keeps begin and detached iteration-start protected', async () => { + const beginNode = createNode('begin', Operator.Begin); + const iterationNode = createNode('iteration:0', Operator.Iteration, { + type: 'group', + }); + const iterationStartNode = createNode( + 'iterationStart:0', + Operator.IterationStart, + { + parentId: 'iteration:0', + type: 'iterationStartNode', + }, + ); + + useGraphStore.setState({ + nodes: [beginNode, iterationNode, iterationStartNode], + edges: [], + }); + + const { result } = renderHook(() => useBeforeDelete()); + let beginDeletion; + let startDeletion; + await act(async () => { + beginDeletion = await result.current.handleBeforeDelete({ + nodes: [beginNode as any], + edges: [], + }); + startDeletion = await result.current.handleBeforeDelete({ + nodes: [iterationStartNode as any], + edges: [], + }); + }); + + expect(beginDeletion?.nodes).toEqual([]); + expect(startDeletion?.nodes).toEqual([]); + }); + + it('preserves agent downstream cleanup', async () => { + const nodes = [ + createNode('agent:0', Operator.Agent), + createNode('tool:0', Operator.Tool), + createNode('message:0', Operator.Message), + ]; + + const edges = [ + createEdge('e1', 'agent:0', 'tool:0', { + sourceHandle: NodeHandleId.AgentBottom, + }), + createEdge('e2', 'tool:0', 'message:0', { + sourceHandle: NodeHandleId.Tool, + }), + ]; + + useGraphStore.setState({ nodes, edges }); + + const { result } = renderHook(() => useBeforeDelete()); + let deletion; + await act(async () => { + deletion = await result.current.handleBeforeDelete({ + nodes: [nodes[0] as any], + edges, + }); + }); + + expect(deletion?.nodes.map((node) => node.id).sort()).toEqual( + ['agent:0', 'tool:0', 'message:0'].sort(), + ); + expect(deletion?.edges.map((edge) => edge.id).sort()).toEqual( + ['e1', 'e2'].sort(), + ); + }); + + it('expands iteration deletion to nested agent tool chains', async () => { + const nodes = [ + createNode('iteration:0', Operator.Iteration, { type: 'group' }), + createNode('iterationStart:0', Operator.IterationStart, { + parentId: 'iteration:0', + type: 'iterationStartNode', + }), + createNode('agent:0', Operator.Agent, { parentId: 'iteration:0' }), + createNode('tool:0', Operator.Tool), + createNode('message:0', Operator.Message), + createNode('generate:0', Operator.Generate), + ]; + + const edges = [ + createEdge('e1', 'iterationStart:0', 'agent:0'), + createEdge('e2', 'agent:0', 'tool:0', { + sourceHandle: NodeHandleId.AgentBottom, + }), + createEdge('e3', 'tool:0', 'message:0', { + sourceHandle: NodeHandleId.Tool, + }), + createEdge('e4', 'generate:0', 'message:0'), + ]; + + useGraphStore.setState({ nodes, edges }); + + const { result } = renderHook(() => useBeforeDelete()); + let deletion; + await act(async () => { + deletion = await result.current.handleBeforeDelete({ + nodes: [nodes[0] as any], + edges: [], + }); + }); + + expect(deletion?.nodes.map((node) => node.id).sort()).toEqual( + [ + 'iteration:0', + 'iterationStart:0', + 'agent:0', + 'tool:0', + 'message:0', + ].sort(), + ); + expect(deletion?.edges.map((edge) => edge.id).sort()).toEqual( + ['e1', 'e2', 'e3', 'e4'].sort(), + ); + }); +}); diff --git a/web/src/pages/agent/hooks/use-before-delete.tsx b/web/src/pages/agent/hooks/use-before-delete.tsx index ef60b758b1..15f76675bf 100644 --- a/web/src/pages/agent/hooks/use-before-delete.tsx +++ b/web/src/pages/agent/hooks/use-before-delete.tsx @@ -1,13 +1,16 @@ import { RAGFlowNodeType } from '@/interfaces/database/agent'; import { Node, OnBeforeDelete } from '@xyflow/react'; import { Operator } from '../constant'; -import useGraphStore from '../store'; +import useGraphStore, { collectDeletionNodeIds } from '../store'; import { deleteAllDownstreamAgentsAndTool } from '../utils/delete-node'; -const UndeletableNodes = [Operator.Begin, Operator.IterationStart]; - export function useBeforeDelete() { - const { getOperatorTypeFromId, getNode } = useGraphStore((state) => state); + const { + getOperatorTypeFromId, + getNode, + nodes: graphNodes, + edges: graphEdges, + } = useGraphStore((state) => state); const agentPredicate = (node: Node) => { return getOperatorTypeFromId(node.id) === Operator.Agent; @@ -33,28 +36,23 @@ export function useBeforeDelete() { return true; }); - const toBeDeletedEdges = edges.filter((edge) => { - const sourceType = getOperatorTypeFromId(edge.source) as Operator; - const downStreamNodes = nodes.filter((x) => x.id === edge.target); - - // This edge does not need to be deleted, the range of edges that do not need to be deleted is smaller, so consider the case where it does not need to be deleted - if ( - UndeletableNodes.includes(sourceType) && // Upstream node is Begin or IterationStart - downStreamNodes.length === 0 // Downstream node does not exist in the nodes to be deleted - ) { - if (!nodes.some((x) => x.id === edge.source)) { - return true; // Can be deleted - } - return false; // Cannot be deleted - } - - return true; - }); + toBeDeletedNodes + .filter((node) => node.data?.label === Operator.Iteration) + .forEach((node) => { + collectDeletionNodeIds(graphNodes, graphEdges, node.id) + .filter((nodeId) => nodeId !== node.id) + .forEach((nodeId) => { + const currentNode = getNode(nodeId); + if (currentNode && toBeDeletedNodes.every((x) => x.id !== nodeId)) { + toBeDeletedNodes.push(currentNode); + } + }); + }); // Delete the agent and tool nodes downstream of the agent node if (nodes.some(agentPredicate)) { nodes.filter(agentPredicate).forEach((node) => { - const { downstreamAgentAndToolEdges, downstreamAgentAndToolNodeIds } = + const { downstreamAgentAndToolNodeIds } = deleteAllDownstreamAgentsAndTool(node.id, edges); downstreamAgentAndToolNodeIds.forEach((nodeId) => { @@ -63,15 +61,18 @@ export function useBeforeDelete() { toBeDeletedNodes.push(currentNode); } }); - - downstreamAgentAndToolEdges.forEach((edge) => { - if (toBeDeletedEdges.every((x) => x.id !== edge.id)) { - toBeDeletedEdges.push(edge); - } - }); }, []); } + const toBeDeletedNodeIdSet = new Set( + toBeDeletedNodes.map((node) => node.id), + ); + const toBeDeletedEdges = graphEdges.filter( + (edge) => + toBeDeletedNodeIdSet.has(edge.source) || + toBeDeletedNodeIdSet.has(edge.target), + ); + return { nodes: toBeDeletedNodes, edges: toBeDeletedEdges, diff --git a/web/src/pages/agent/store.test.ts b/web/src/pages/agent/store.test.ts new file mode 100644 index 0000000000..91496dd2ee --- /dev/null +++ b/web/src/pages/agent/store.test.ts @@ -0,0 +1,161 @@ +import { Edge } from '@xyflow/react'; +import { NodeHandleId, Operator } from './constant'; +import useGraphStore from './store'; + +function baseNode(id: string, label: Operator) { + return { + id, + type: 'ragNode', + position: { x: 0, y: 0 }, + data: { + label, + name: id, + form: {}, + }, + }; +} + +const createNode = ( + id: string, + label: Operator, + options: Partial> = {}, +) => ({ + ...baseNode(id, label), + ...options, +}); + +const createEdge = ( + id: string, + source: string, + target: string, + options: Partial = {}, +): Edge => ({ + id, + source, + target, + ...options, +}); + +describe('useGraphStore.deleteIterationNodeById', () => { + beforeEach(() => { + useGraphStore.setState({ + nodes: [], + edges: [], + selectedNodeIds: [], + selectedEdgeIds: [], + clickedNodeId: '', + clickedToolId: '', + }); + }); + + it('removes the iteration node, its descendants, and all incident edges', () => { + const nodes = [ + createNode('begin', Operator.Begin), + createNode('iteration:0', Operator.Iteration, { type: 'group' }), + createNode('iterationStart:0', Operator.IterationStart, { + parentId: 'iteration:0', + type: 'iterationStartNode', + }), + createNode('message:0', Operator.Message, { parentId: 'iteration:0' }), + createNode('message:1', Operator.Message, { parentId: 'message:0' }), + createNode('generate:0', Operator.Generate), + ]; + + const edges = [ + createEdge('e1', 'begin', 'iteration:0'), + createEdge('e2', 'iterationStart:0', 'message:0'), + createEdge('e3', 'message:0', 'message:1'), + createEdge('e4', 'message:0', 'generate:0'), + createEdge('e5', 'generate:0', 'message:1'), + ]; + + useGraphStore.setState({ + nodes, + edges, + selectedNodeIds: ['iteration:0', 'message:0'], + selectedEdgeIds: ['e2', 'e4'], + clickedNodeId: 'message:0', + }); + + useGraphStore.getState().deleteIterationNodeById('iteration:0'); + + const state = useGraphStore.getState(); + + expect(state.nodes.map((node) => node.id)).toEqual(['begin', 'generate:0']); + expect(state.edges.map((edge) => edge.id)).toEqual([]); + expect(state.selectedNodeIds).toEqual([]); + expect(state.selectedEdgeIds).toEqual([]); + expect(state.clickedNodeId).toBe(''); + }); + + it('preserves unrelated graph branches', () => { + const nodes = [ + createNode('iteration:0', Operator.Iteration, { type: 'group' }), + createNode('iterationStart:0', Operator.IterationStart, { + parentId: 'iteration:0', + type: 'iterationStartNode', + }), + createNode('message:0', Operator.Message, { parentId: 'iteration:0' }), + createNode('begin', Operator.Begin), + createNode('generate:0', Operator.Generate), + createNode('message:2', Operator.Message), + ]; + + const edges = [ + createEdge('iteration-edge', 'iterationStart:0', 'message:0'), + createEdge('branch-edge-a', 'begin', 'generate:0'), + createEdge('branch-edge-b', 'generate:0', 'message:2'), + ]; + + useGraphStore.setState({ nodes, edges }); + + useGraphStore.getState().deleteIterationNodeById('iteration:0'); + + const state = useGraphStore.getState(); + + expect(state.nodes.map((node) => node.id)).toEqual([ + 'begin', + 'generate:0', + 'message:2', + ]); + expect(state.edges.map((edge) => edge.id)).toEqual([ + 'branch-edge-a', + 'branch-edge-b', + ]); + }); + + it('removes agent tool chains nested inside an iteration subtree', () => { + const nodes = [ + createNode('iteration:0', Operator.Iteration, { type: 'group' }), + createNode('iterationStart:0', Operator.IterationStart, { + parentId: 'iteration:0', + type: 'iterationStartNode', + }), + createNode('agent:0', Operator.Agent, { parentId: 'iteration:0' }), + createNode('tool:0', Operator.Tool), + createNode('message:0', Operator.Message), + createNode('begin', Operator.Begin), + createNode('generate:0', Operator.Generate), + ]; + + const edges = [ + createEdge('iteration-edge', 'iterationStart:0', 'agent:0'), + createEdge('tool-edge', 'agent:0', 'tool:0', { + sourceHandle: NodeHandleId.AgentBottom, + }), + createEdge('tool-output-edge', 'tool:0', 'message:0', { + sourceHandle: NodeHandleId.Tool, + }), + createEdge('branch-edge', 'begin', 'generate:0'), + ]; + + useGraphStore.setState({ nodes, edges }); + + useGraphStore.getState().deleteIterationNodeById('iteration:0'); + + const state = useGraphStore.getState(); + + expect(state.nodes.map((node) => node.id)).toEqual(['begin', 'generate:0']); + expect(state.edges.map((edge) => edge.id)).toEqual(['branch-edge']); + }); +}); diff --git a/web/src/pages/agent/store.ts b/web/src/pages/agent/store.ts index 7afc9c7941..1030a00652 100644 --- a/web/src/pages/agent/store.ts +++ b/web/src/pages/agent/store.ts @@ -41,6 +41,91 @@ import { deleteAllDownstreamAgentsAndTool } from './utils/delete-node'; type IAgentTool = IAgentForm['tools'][number]; +const collectDescendantNodeIds = ( + nodes: RAGFlowNodeType[], + rootId: string, +): string[] => { + const descendantNodeIds: string[] = []; + const queue = [rootId]; + + while (queue.length > 0) { + const currentNodeId = queue.shift(); + if (!currentNodeId) { + continue; + } + + const childNodeIds = nodes + .filter((node) => node.parentId === currentNodeId) + .map((node) => node.id); + + childNodeIds.forEach((nodeId) => { + if (!descendantNodeIds.includes(nodeId)) { + descendantNodeIds.push(nodeId); + queue.push(nodeId); + } + }); + } + + return descendantNodeIds; +}; + +const collectAgentAttachmentNodeIds = ( + nodes: RAGFlowNodeType[], + edges: Edge[], + rootNodeIds: string[], +) => { + const attachedNodeIds: string[] = []; + + rootNodeIds.forEach((nodeId) => { + const node = nodes.find((item) => item.id === nodeId); + if (node?.data?.label !== Operator.Agent) { + return; + } + + const { downstreamAgentAndToolNodeIds } = deleteAllDownstreamAgentsAndTool( + nodeId, + edges, + ); + + downstreamAgentAndToolNodeIds.forEach((attachedNodeId) => { + if (!attachedNodeIds.includes(attachedNodeId)) { + attachedNodeIds.push(attachedNodeId); + } + }); + }); + + return attachedNodeIds; +}; + +export const collectDeletionNodeIds = ( + nodes: RAGFlowNodeType[], + edges: Edge[], + rootId: string, +): string[] => { + const deletedNodeIds = [rootId, ...collectDescendantNodeIds(nodes, rootId)]; + const attachedNodeIds = collectAgentAttachmentNodeIds( + nodes, + edges, + deletedNodeIds, + ); + + attachedNodeIds.forEach((nodeId) => { + if (!deletedNodeIds.includes(nodeId)) { + deletedNodeIds.push(nodeId); + } + }); + + return deletedNodeIds; +}; + +export const removeEdgesForNodeIds = (edges: Edge[], nodeIds: string[]) => { + const nodeIdSet = new Set(nodeIds); + + return edges.filter( + (edge) => !nodeIdSet.has(edge.source) && !nodeIdSet.has(edge.target), + ); +}; + interface GetAgentToolByIdFunc { (id: string): IAgentTool | undefined; (id: string, agentNode: RAGFlowNodeType): IAgentTool | undefined; @@ -406,18 +491,32 @@ const useGraphStore = create()( } }, deleteIterationNodeById: (id: string) => { - const { nodes, edges } = get(); - const children = nodes.filter((node) => node.parentId === id); + const { + nodes, + edges, + selectedNodeIds, + selectedEdgeIds, + clickedNodeId, + } = get(); + const deletedNodeIds = collectDeletionNodeIds(nodes, edges, id); + const deletedNodeIdSet = new Set(deletedNodeIds); + const remainingEdges = removeEdgesForNodeIds(edges, deletedNodeIds); + const remainingEdgeIdSet = new Set( + remainingEdges.map((edge) => edge.id), + ); + set({ - nodes: nodes.filter((node) => node.id !== id && node.parentId !== id), - edges: edges.filter( - (edge) => - edge.source !== id && - edge.target !== id && - !children.some( - (child) => edge.source === child.id && edge.target === child.id, - ), + nodes: nodes.filter((node) => !deletedNodeIdSet.has(node.id)), + edges: remainingEdges, + selectedNodeIds: selectedNodeIds.filter( + (nodeId) => !deletedNodeIdSet.has(nodeId), ), + selectedEdgeIds: selectedEdgeIds.filter((edgeId) => + remainingEdgeIdSet.has(edgeId), + ), + clickedNodeId: deletedNodeIdSet.has(clickedNodeId) + ? '' + : clickedNodeId, }); }, findNodeByName: (name: Operator) => {