From 1aa0d1c5322613ab97d71b72aa9b7949c8007ded Mon Sep 17 00:00:00 2001 From: arkunzz <4873204@qq.com> Date: Tue, 4 Mar 2025 11:52:17 +0800 Subject: [PATCH] fix: loop node metadata --- api/core/workflow/nodes/loop/loop_node.py | 61 ++++++++++++----------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 8d97ab73b1..b0736c95f4 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -108,6 +108,10 @@ class LoopNode(BaseNode[LoopNodeData]): for i in range(loop_count): # Run workflow rst = graph_engine.run() + current_index_variable = variable_pool.get([self.node_id, "index"]) + if not isinstance(current_index_variable, IntegerSegment): + raise ValueError(f"loop {self.node_id} current index not found") + current_index = current_index_variable.value check_break_result = False @@ -123,30 +127,7 @@ class LoopNode(BaseNode[LoopNodeData]): continue if isinstance(event, NodeRunSucceededEvent): - if event.route_node_state.node_run_result: - metadata = event.route_node_state.node_run_result.metadata - if not metadata: - metadata = {} - if NodeRunMetadataKey.LOOP_ID not in metadata: - index_variable = variable_pool.get([self.node_id, "index"]) - if not isinstance(index_variable, IntegerSegment): - total_tokens = graph_engine.graph_runtime_state.total_tokens - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=f"Invalid index variable type: {type(index_variable)}", - metadata={NodeRunMetadataKey.TOTAL_TOKENS: total_tokens}, - ) - ) - return - metadata = { - **metadata, - NodeRunMetadataKey.LOOP_ID: self.node_id, - NodeRunMetadataKey.LOOP_INDEX: index_variable.value, - } - event.route_node_state.node_run_result.metadata = metadata - - yield event + yield self._handle_event_metadata(event=event, iter_run_index=current_index) # Check if all variables in break conditions exist exists_variable = False @@ -220,7 +201,7 @@ class LoopNode(BaseNode[LoopNodeData]): ) return else: - yield cast(InNodeEvent, event) + yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index) # Remove all nodes outputs from variable pool for node_id in loop_graph.node_ids: @@ -230,11 +211,7 @@ class LoopNode(BaseNode[LoopNodeData]): break # Move to next loop - current_index_variable = variable_pool.get([self.node_id, "index"]) - if not isinstance(current_index_variable, IntegerSegment): - raise ValueError(f"loop {self.node_id} current index not found") - - next_index = current_index_variable.value + 1 + next_index = current_index + 1 variable_pool.add([self.node_id, "index"], next_index) yield LoopRunNextEvent( @@ -298,6 +275,30 @@ class LoopNode(BaseNode[LoopNodeData]): # Clean up variable_pool.remove([self.node_id, "index"]) + def _handle_event_metadata( + self, + *, + event: BaseNodeEvent | InNodeEvent, + iter_run_index: int, + ) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent: + """ + add iteration metadata to event. + """ + if not isinstance(event, BaseNodeEvent): + return event + if event.route_node_state.node_run_result: + metadata = event.route_node_state.node_run_result.metadata + if not metadata: + metadata = {} + if NodeRunMetadataKey.LOOP_ID not in metadata: + metadata = { + **metadata, + NodeRunMetadataKey.LOOP_ID: self.node_id, + NodeRunMetadataKey.LOOP_INDEX: iter_run_index + } + event.route_node_state.node_run_result.metadata = metadata + return event + @classmethod def _extract_variable_selector_to_variable_mapping( cls,