Add loop-scoped DAG execution for conditionals inside for-loops - backend (#4302)

This commit is contained in:
Celal Zamanoglu
2025-12-15 23:20:21 +03:00
committed by GitHub
parent 0572746608
commit 781bd13c5a

View File

@@ -10,7 +10,7 @@ import re
import smtplib import smtplib
import textwrap import textwrap
import uuid import uuid
from collections import defaultdict from collections import defaultdict, deque
from datetime import datetime from datetime import datetime
from email.message import EmailMessage from email.message import EmailMessage
from pathlib import Path from pathlib import Path
@@ -77,6 +77,7 @@ from skyvern.forge.sdk.workflow.exceptions import (
InsecureCodeDetected, InsecureCodeDetected,
InvalidEmailClientConfiguration, InvalidEmailClientConfiguration,
InvalidFileType, InvalidFileType,
InvalidWorkflowDefinition,
MissingJinjaVariables, MissingJinjaVariables,
NoIterableValueFound, NoIterableValueFound,
NoValidEmailRecipient, NoValidEmailRecipient,
@@ -1336,6 +1337,71 @@ class ForLoopBlock(Block):
output_parameter=output_param, output_parameter=output_param,
) )
def _build_loop_graph(
    self, blocks: list[BlockTypeVar]
) -> tuple[str, dict[str, BlockTypeVar], dict[str, str | None]]:
    """Validate the blocks of this loop as a single-entry DAG and index them.

    Args:
        blocks: The blocks that make up the loop body, in declaration order.

    Returns:
        A 3-tuple of ``(entry_label, label_to_block, default_next_map)``:
        the label of the only block with no incoming edge, a mapping from
        label to block, and a mapping from label to the label that follows
        it by default (``None`` when nothing follows).

    Raises:
        InvalidWorkflowDefinition: If two blocks share a label, an edge
            points at a label that is not in the loop, the loop has zero or
            more than one entry block, or not every block can be visited by
            the topological walk (reported as a cycle).
    """
    blocks_by_label: dict[str, BlockTypeVar] = {}
    next_label_by_block: dict[str, str | None] = {}
    for candidate in blocks:
        if candidate.label in blocks_by_label:
            raise InvalidWorkflowDefinition(f"Duplicate block label detected in loop: {candidate.label}")
        blocks_by_label[candidate.label] = candidate
        next_label_by_block[candidate.label] = candidate.next_block_label

    # Without any conditional block, the loop is treated as a plain sequence:
    # a block with no explicit successor falls through to the block declared
    # right after it. The last block is never given an implicit successor.
    is_sequential = not any(item.block_type == BlockType.CONDITIONAL for item in blocks)
    if is_sequential:
        for current, following in zip(blocks, blocks[1:]):
            if next_label_by_block[current.label] is None:
                next_label_by_block[current.label] = following.label

    successors: dict[str, set[str]] = {}
    indegree: dict[str, int] = {}
    for known_label in blocks_by_label:
        successors[known_label] = set()
        indegree[known_label] = 0

    for source_label, source_block in blocks_by_label.items():
        # Conditional blocks fan out through their branches; every other
        # block has at most one (default) successor.
        if source_block.block_type == BlockType.CONDITIONAL:
            targets = (branch.next_block_label for branch in source_block.ordered_branches)
        else:
            targets = (next_label_by_block.get(source_label),)

        for target_label in targets:
            if not target_label:
                continue
            if target_label not in blocks_by_label:
                raise InvalidWorkflowDefinition(
                    f"Block {source_label} references unknown next_block_label {target_label} inside loop {self.label}"
                )
            # Several branches of one conditional may share a target; count
            # that edge once so the in-degree stays accurate.
            if target_label in successors[source_label]:
                continue
            successors[source_label].add(target_label)
            indegree[target_label] += 1

    entry_labels = [name for name, degree in indegree.items() if degree == 0]
    if len(entry_labels) > 1:
        raise InvalidWorkflowDefinition(
            f"Multiple entry blocks detected in loop {self.label} ({', '.join(sorted(entry_labels))}); only one entry block is supported."
        )
    if not entry_labels:
        raise InvalidWorkflowDefinition(f"No entry block found for loop {self.label}")
    entry_label = entry_labels[0]

    # Topological walk from the entry block on a copy of the in-degrees; a
    # block is released only once all of its predecessors have been visited.
    remaining = dict(indegree)
    frontier: deque[str] = deque((entry_label,))
    reached = 0
    while frontier:
        current_label = frontier.popleft()
        reached += 1
        for successor in successors[current_label]:
            remaining[successor] -= 1
            if remaining[successor] == 0:
                frontier.append(successor)

    if reached != len(blocks_by_label):
        raise InvalidWorkflowDefinition(f"Loop {self.label} contains a cycle; DAG traversal is required.")

    return entry_label, blocks_by_label, next_label_by_block
async def execute_loop_helper( async def execute_loop_helper(
self, self,
workflow_run_id: str, workflow_run_id: str,
@@ -1349,6 +1415,8 @@ class ForLoopBlock(Block):
block_outputs: list[BlockResult] = [] block_outputs: list[BlockResult] = []
current_block: BlockTypeVar | None = None current_block: BlockTypeVar | None = None
start_label, label_to_block, default_next_map = self._build_loop_graph(self.loop_blocks)
for loop_idx, loop_over_value in enumerate(loop_over_values): for loop_idx, loop_over_value in enumerate(loop_over_values):
# Check max_iterations limit # Check max_iterations limit
if loop_idx >= DEFAULT_MAX_LOOP_ITERATIONS: if loop_idx >= DEFAULT_MAX_LOOP_ITERATIONS:
@@ -1379,7 +1447,6 @@ class ForLoopBlock(Block):
each_loop_output_values: list[dict[str, Any]] = [] each_loop_output_values: list[dict[str, Any]] = []
# Track steps for current iteration
iteration_step_count = 0 iteration_step_count = 0
LOG.info( LOG.info(
f"ForLoopBlock: Starting iteration {loop_idx} with max_steps_per_iteration={DEFAULT_MAX_STEPS_PER_ITERATION}", f"ForLoopBlock: Starting iteration {loop_idx} with max_steps_per_iteration={DEFAULT_MAX_STEPS_PER_ITERATION}",
@@ -1388,7 +1455,32 @@ class ForLoopBlock(Block):
max_steps_per_iteration=DEFAULT_MAX_STEPS_PER_ITERATION, max_steps_per_iteration=DEFAULT_MAX_STEPS_PER_ITERATION,
) )
for block_idx, loop_block in enumerate(self.loop_blocks): block_idx = 0
current_label: str | None = start_label
while current_label:
loop_block = label_to_block.get(current_label)
if not loop_block:
LOG.error(
"Unable to find loop block with label in loop graph",
workflow_run_id=workflow_run_id,
loop_label=self.label,
current_label=current_label,
)
failure_block_result = await self.build_block_result(
success=False,
status=BlockStatus.failed,
failure_reason=f"Unable to find block with label {current_label} inside loop {self.label}",
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
block_outputs.append(failure_block_result)
outputs_with_loop_values.append(each_loop_output_values)
return LoopBlockExecutedResult(
outputs_with_loop_values=outputs_with_loop_values,
block_outputs=block_outputs,
last_block=current_block,
)
metadata: BlockMetadata = { metadata: BlockMetadata = {
"current_index": loop_idx, "current_index": loop_idx,
"current_value": loop_over_value, "current_value": loop_over_value,
@@ -1515,6 +1607,38 @@ class ForLoopBlock(Block):
) )
if block_output.success or loop_block.continue_on_failure: if block_output.success or loop_block.continue_on_failure:
next_label: str | None = None
if loop_block.block_type == BlockType.CONDITIONAL:
branch_metadata = (
block_output.output_parameter_value
if isinstance(block_output.output_parameter_value, dict)
else None
)
next_label = (branch_metadata or {}).get("next_block_label")
else:
next_label = default_next_map.get(loop_block.label)
if not next_label:
break
if next_label not in label_to_block:
failure_block_result = await self.build_block_result(
success=False,
status=BlockStatus.failed,
failure_reason=f"Next block label {next_label} not found inside loop {self.label}",
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
block_outputs.append(failure_block_result)
outputs_with_loop_values.append(each_loop_output_values)
return LoopBlockExecutedResult(
outputs_with_loop_values=outputs_with_loop_values,
block_outputs=block_outputs,
last_block=current_block,
)
current_label = next_label
block_idx += 1
continue continue
if loop_block.next_loop_on_failure or self.next_loop_on_failure: if loop_block.next_loop_on_failure or self.next_loop_on_failure:
@@ -1528,6 +1652,8 @@ class ForLoopBlock(Block):
) )
break break
break
outputs_with_loop_values.append(each_loop_output_values) outputs_with_loop_values.append(each_loop_output_values)
return LoopBlockExecutedResult( return LoopBlockExecutedResult(
@@ -1616,14 +1742,29 @@ class ForLoopBlock(Block):
organization_id=organization_id, organization_id=organization_id,
) )
loop_executed_result = await self.execute_loop_helper( try:
workflow_run_id=workflow_run_id, loop_executed_result = await self.execute_loop_helper(
workflow_run_block_id=workflow_run_block_id, workflow_run_id=workflow_run_id,
workflow_run_context=workflow_run_context, workflow_run_block_id=workflow_run_block_id,
loop_over_values=loop_over_values, workflow_run_context=workflow_run_context,
organization_id=organization_id, loop_over_values=loop_over_values,
browser_session_id=browser_session_id, organization_id=organization_id,
) browser_session_id=browser_session_id,
)
except InvalidWorkflowDefinition as exc:
LOG.error(
"Loop graph validation failed",
error=str(exc),
workflow_run_id=workflow_run_id,
loop_label=self.label,
)
return await self.build_block_result(
success=False,
failure_reason=str(exc),
status=BlockStatus.failed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
await self.record_output_parameter_value( await self.record_output_parameter_value(
workflow_run_context, workflow_run_id, loop_executed_result.outputs_with_loop_values workflow_run_context, workflow_run_id, loop_executed_result.outputs_with_loop_values
) )