fix loopblock continue on failure issue (#1283)
This commit is contained in:
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import asyncio
|
import asyncio
|
||||||
import csv
|
import csv
|
||||||
@@ -508,11 +510,51 @@ class TaskBlock(BaseTaskBlock):
|
|||||||
block_type: Literal[BlockType.TASK] = BlockType.TASK
|
block_type: Literal[BlockType.TASK] = BlockType.TASK
|
||||||
|
|
||||||
|
|
||||||
|
class LoopBlockExecutedResult(BaseModel):
|
||||||
|
outputs_with_loop_values: list[list[dict[str, Any]]]
|
||||||
|
block_outputs: list[BlockResult]
|
||||||
|
last_block: BlockTypeVar | None
|
||||||
|
|
||||||
|
def is_canceled(self) -> bool:
|
||||||
|
return len(self.block_outputs) > 0 and self.block_outputs[-1].status == BlockStatus.canceled
|
||||||
|
|
||||||
|
def is_completed(self) -> bool:
|
||||||
|
if len(self.block_outputs) == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.last_block is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.is_canceled():
|
||||||
|
return False
|
||||||
|
|
||||||
|
last_ouput = self.block_outputs[-1]
|
||||||
|
if last_ouput.success:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if self.last_block.continue_on_failure:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def is_terminated(self) -> bool:
|
||||||
|
return len(self.block_outputs) > 0 and self.block_outputs[-1].status == BlockStatus.terminated
|
||||||
|
|
||||||
|
def get_failure_reason(self) -> str | None:
|
||||||
|
if self.is_completed():
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self.is_canceled():
|
||||||
|
return f"Block({self.last_block.label if self.last_block else ''}) with type {self.last_block.block_type if self.last_block else ''} was canceled, canceling for loop"
|
||||||
|
|
||||||
|
return self.block_outputs[-1].failure_reason if len(self.block_outputs) > 0 else "No block has been executed"
|
||||||
|
|
||||||
|
|
||||||
class ForLoopBlock(Block):
|
class ForLoopBlock(Block):
|
||||||
block_type: Literal[BlockType.FOR_LOOP] = BlockType.FOR_LOOP
|
block_type: Literal[BlockType.FOR_LOOP] = BlockType.FOR_LOOP
|
||||||
|
|
||||||
loop_over: PARAMETER_TYPE
|
loop_over: PARAMETER_TYPE
|
||||||
loop_blocks: list["BlockTypeVar"]
|
loop_blocks: list[BlockTypeVar]
|
||||||
|
|
||||||
def get_all_parameters(
|
def get_all_parameters(
|
||||||
self,
|
self,
|
||||||
@@ -588,9 +630,73 @@ class ForLoopBlock(Block):
|
|||||||
# TODO (kerem): Should we raise an error here?
|
# TODO (kerem): Should we raise an error here?
|
||||||
return [parameter_value]
|
return [parameter_value]
|
||||||
|
|
||||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
|
async def execute_loop_helper(
|
||||||
|
self, workflow_run_id: str, workflow_run_context: WorkflowRunContext, loop_over_values: list[Any]
|
||||||
|
) -> LoopBlockExecutedResult:
|
||||||
outputs_with_loop_values: list[list[dict[str, Any]]] = []
|
outputs_with_loop_values: list[list[dict[str, Any]]] = []
|
||||||
success = False
|
block_outputs: list[BlockResult] = []
|
||||||
|
current_block: BlockTypeVar | None = None
|
||||||
|
|
||||||
|
for loop_idx, loop_over_value in enumerate(loop_over_values):
|
||||||
|
context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
|
||||||
|
for context_parameter in context_parameters_with_value:
|
||||||
|
workflow_run_context.set_value(context_parameter.key, context_parameter.value)
|
||||||
|
each_loop_output_values: list[dict[str, Any]] = []
|
||||||
|
for block_idx, loop_block in enumerate(self.loop_blocks):
|
||||||
|
original_loop_block = loop_block
|
||||||
|
loop_block = loop_block.copy()
|
||||||
|
current_block = loop_block
|
||||||
|
|
||||||
|
block_output = await loop_block.execute_safe(workflow_run_id=workflow_run_id)
|
||||||
|
each_loop_output_values.append(
|
||||||
|
{
|
||||||
|
"loop_value": loop_over_value,
|
||||||
|
"output_parameter": block_output.output_parameter,
|
||||||
|
"output_value": workflow_run_context.get_value(block_output.output_parameter.key),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
loop_block = original_loop_block
|
||||||
|
block_outputs.append(block_output)
|
||||||
|
if block_output.status == BlockStatus.canceled:
|
||||||
|
LOG.info(
|
||||||
|
f"ForLoopBlock: Block with type {loop_block.block_type} at index {block_idx} during loop {loop_idx} was canceled for workflow run {workflow_run_id}, canceling for loop",
|
||||||
|
block_type=loop_block.block_type,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
block_idx=block_idx,
|
||||||
|
block_result=block_outputs,
|
||||||
|
)
|
||||||
|
outputs_with_loop_values.append(each_loop_output_values)
|
||||||
|
return LoopBlockExecutedResult(
|
||||||
|
outputs_with_loop_values=outputs_with_loop_values,
|
||||||
|
block_outputs=block_outputs,
|
||||||
|
last_block=current_block,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not block_output.success and not loop_block.continue_on_failure:
|
||||||
|
LOG.info(
|
||||||
|
f"ForLoopBlock: Encountered an failure processing block {block_idx} during loop {loop_idx}, terminating early",
|
||||||
|
block_outputs=block_outputs,
|
||||||
|
loop_idx=loop_idx,
|
||||||
|
block_idx=block_idx,
|
||||||
|
loop_over_value=loop_over_value,
|
||||||
|
loop_block_continue_on_failure=loop_block.continue_on_failure,
|
||||||
|
)
|
||||||
|
outputs_with_loop_values.append(each_loop_output_values)
|
||||||
|
return LoopBlockExecutedResult(
|
||||||
|
outputs_with_loop_values=outputs_with_loop_values,
|
||||||
|
block_outputs=block_outputs,
|
||||||
|
last_block=current_block,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs_with_loop_values.append(each_loop_output_values)
|
||||||
|
|
||||||
|
return LoopBlockExecutedResult(
|
||||||
|
outputs_with_loop_values=outputs_with_loop_values,
|
||||||
|
block_outputs=block_outputs,
|
||||||
|
last_block=current_block,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
|
||||||
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
|
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
|
||||||
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
|
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
|
||||||
LOG.info(
|
LOG.info(
|
||||||
@@ -625,84 +731,32 @@ class ForLoopBlock(Block):
|
|||||||
success=False, failure_reason="No defined blocks to loop", status=BlockStatus.terminated
|
success=False, failure_reason="No defined blocks to loop", status=BlockStatus.terminated
|
||||||
)
|
)
|
||||||
|
|
||||||
block_outputs: list[BlockResult] = []
|
loop_executed_result = await self.execute_loop_helper(
|
||||||
for loop_idx, loop_over_value in enumerate(loop_over_values):
|
workflow_run_id=workflow_run_id,
|
||||||
context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
|
workflow_run_context=workflow_run_context,
|
||||||
for context_parameter in context_parameters_with_value:
|
loop_over_values=loop_over_values,
|
||||||
workflow_run_context.set_value(context_parameter.key, context_parameter.value)
|
)
|
||||||
for block_idx, loop_block in enumerate(self.loop_blocks):
|
await self.record_output_parameter_value(
|
||||||
original_loop_block = loop_block
|
workflow_run_context, workflow_run_id, loop_executed_result.outputs_with_loop_values
|
||||||
loop_block = loop_block.copy()
|
)
|
||||||
block_output = await loop_block.execute_safe(workflow_run_id=workflow_run_id)
|
block_status = BlockStatus.failed
|
||||||
if block_output.status == BlockStatus.canceled:
|
success = False
|
||||||
failure_message = f"ForLoopBlock: Block with type {loop_block.block_type} at index {block_idx} during loop {loop_idx} was canceled for workflow run {workflow_run_id}, canceling for loop"
|
|
||||||
LOG.info(
|
|
||||||
failure_message,
|
|
||||||
block_type=loop_block.block_type,
|
|
||||||
workflow_run_id=workflow_run_id,
|
|
||||||
block_idx=block_idx,
|
|
||||||
block_result=block_output,
|
|
||||||
)
|
|
||||||
await self.record_output_parameter_value(
|
|
||||||
workflow_run_context, workflow_run_id, outputs_with_loop_values
|
|
||||||
)
|
|
||||||
return self.build_block_result(
|
|
||||||
success=False,
|
|
||||||
failure_reason=failure_message,
|
|
||||||
output_parameter_value=outputs_with_loop_values,
|
|
||||||
status=BlockStatus.canceled,
|
|
||||||
)
|
|
||||||
|
|
||||||
loop_block = original_loop_block
|
if loop_executed_result.is_canceled():
|
||||||
block_outputs.append(block_output)
|
block_status = BlockStatus.canceled
|
||||||
if not block_output.success and not loop_block.continue_on_failure:
|
elif loop_executed_result.is_completed():
|
||||||
LOG.info(
|
block_status = BlockStatus.completed
|
||||||
f"ForLoopBlock: Encountered an failure processing block {block_idx} during loop {loop_idx}, terminating early",
|
success = True
|
||||||
block_outputs=block_outputs,
|
elif loop_executed_result.is_terminated():
|
||||||
loop_idx=loop_idx,
|
block_status = BlockStatus.terminated
|
||||||
block_idx=block_idx,
|
else:
|
||||||
loop_over_value=loop_over_value,
|
block_status = BlockStatus.failed
|
||||||
loop_block_continue_on_failure=loop_block.continue_on_failure,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
outputs_with_loop_values.append(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"loop_value": loop_over_value,
|
|
||||||
"output_parameter": block_output.output_parameter,
|
|
||||||
"output_value": workflow_run_context.get_value(block_output.output_parameter.key),
|
|
||||||
}
|
|
||||||
for block_output in block_outputs
|
|
||||||
if block_output.output_parameter
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# If all block outputs are successful, the loop is successful. If self.continue_on_failure is True, we will
|
|
||||||
# continue to the next loop iteration even if there are failures.
|
|
||||||
success = all([block_output.success for block_output in block_outputs])
|
|
||||||
if not success and not self.continue_on_failure:
|
|
||||||
LOG.info(
|
|
||||||
f"ForLoopBlock: Encountered an failure processing loop {loop_idx}, won't continue to the next loop. Total number of loops: {len(loop_over_values)}",
|
|
||||||
for_loop_continue_on_failure=self.continue_on_failure,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
# at least one block must be executed in the loop
|
|
||||||
assert len(block_outputs) != 0
|
|
||||||
|
|
||||||
is_any_block_terminated = any([block_output.status == BlockStatus.terminated for block_output in block_outputs])
|
|
||||||
for_loop_block_status = BlockStatus.completed
|
|
||||||
if is_any_block_terminated:
|
|
||||||
for_loop_block_status = BlockStatus.terminated
|
|
||||||
elif not success:
|
|
||||||
for_loop_block_status = BlockStatus.failed
|
|
||||||
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, outputs_with_loop_values)
|
|
||||||
return self.build_block_result(
|
return self.build_block_result(
|
||||||
success=success,
|
success=success,
|
||||||
failure_reason=block_outputs[-1].failure_reason,
|
failure_reason=loop_executed_result.get_failure_reason(),
|
||||||
output_parameter_value=outputs_with_loop_values,
|
output_parameter_value=loop_executed_result.outputs_with_loop_values,
|
||||||
status=for_loop_block_status,
|
status=block_status,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user