handle workflow run cancel for child workflow runs / task v2 + observer cancel handling (#1776)
This commit is contained in:
@@ -1511,6 +1511,24 @@ class AgentDB:
|
|||||||
LOG.error("SQLAlchemyError", exc_info=True)
|
LOG.error("SQLAlchemyError", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def get_workflow_runs_by_parent_workflow_run_id(
|
||||||
|
self,
|
||||||
|
organization_id: str,
|
||||||
|
parent_workflow_run_id: str,
|
||||||
|
) -> list[WorkflowRun]:
|
||||||
|
try:
|
||||||
|
async with self.Session() as session:
|
||||||
|
query = (
|
||||||
|
select(WorkflowRunModel)
|
||||||
|
.filter(WorkflowRunModel.organization_id == organization_id)
|
||||||
|
.filter(WorkflowRunModel.parent_workflow_run_id == parent_workflow_run_id)
|
||||||
|
)
|
||||||
|
workflow_runs = (await session.scalars(query)).all()
|
||||||
|
return [convert_to_workflow_run(run) for run in workflow_runs]
|
||||||
|
except SQLAlchemyError:
|
||||||
|
LOG.error("SQLAlchemyError", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
async def create_workflow_parameter(
|
async def create_workflow_parameter(
|
||||||
self,
|
self,
|
||||||
workflow_id: str,
|
workflow_id: str,
|
||||||
|
|||||||
@@ -339,6 +339,19 @@ async def cancel_workflow_run(
|
|||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail=f"Workflow run not found {workflow_run_id}",
|
detail=f"Workflow run not found {workflow_run_id}",
|
||||||
)
|
)
|
||||||
|
# get all the child workflow runs and cancel them
|
||||||
|
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
|
||||||
|
organization_id=current_org.organization_id,
|
||||||
|
parent_workflow_run_id=workflow_run_id,
|
||||||
|
)
|
||||||
|
for child_workflow_run in child_workflow_runs:
|
||||||
|
if child_workflow_run.status not in [
|
||||||
|
WorkflowRunStatus.running,
|
||||||
|
WorkflowRunStatus.created,
|
||||||
|
WorkflowRunStatus.queued,
|
||||||
|
]:
|
||||||
|
continue
|
||||||
|
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(child_workflow_run.workflow_run_id)
|
||||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(workflow_run_id)
|
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(workflow_run_id)
|
||||||
await app.WORKFLOW_SERVICE.execute_workflow_webhook(workflow_run, api_key=x_api_key)
|
await app.WORKFLOW_SERVICE.execute_workflow_webhook(workflow_run, api_key=x_api_key)
|
||||||
|
|
||||||
|
|||||||
@@ -163,9 +163,11 @@ async def initialize_observer_task(
|
|||||||
except Exception:
|
except Exception:
|
||||||
LOG.error("Failed to setup cruise workflow run", exc_info=True)
|
LOG.error("Failed to setup cruise workflow run", exc_info=True)
|
||||||
# fail the workflow run
|
# fail the workflow run
|
||||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_failed(
|
await mark_observer_task_as_failed(
|
||||||
|
observer_cruise_id=observer_task.observer_cruise_id,
|
||||||
workflow_run_id=workflow_run.workflow_run_id,
|
workflow_run_id=workflow_run.workflow_run_id,
|
||||||
failure_reason="Skyvern failed to setup the workflow run",
|
failure_reason="Skyvern failed to setup the workflow run",
|
||||||
|
organization_id=organization.organization_id,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -204,9 +206,11 @@ async def initialize_observer_task(
|
|||||||
except Exception:
|
except Exception:
|
||||||
LOG.warning("Failed to update task 2.0", exc_info=True)
|
LOG.warning("Failed to update task 2.0", exc_info=True)
|
||||||
# fail the workflow run
|
# fail the workflow run
|
||||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_failed(
|
await mark_observer_task_as_failed(
|
||||||
|
observer_cruise_id=observer_task.observer_cruise_id,
|
||||||
workflow_run_id=workflow_run.workflow_run_id,
|
workflow_run_id=workflow_run.workflow_run_id,
|
||||||
failure_reason="Skyvern failed to update the task 2.0 after initializing the workflow run",
|
failure_reason="Skyvern failed to update the task 2.0 after initializing the workflow run",
|
||||||
|
organization_id=organization.organization_id,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -225,14 +229,18 @@ async def run_observer_task(
|
|||||||
observer_task = await app.DATABASE.get_observer_cruise(observer_cruise_id, organization_id=organization_id)
|
observer_task = await app.DATABASE.get_observer_cruise(observer_cruise_id, organization_id=organization_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
LOG.error(
|
LOG.error(
|
||||||
"Failed to get observer cruise",
|
"Failed to get observer task",
|
||||||
observer_cruise_id=observer_cruise_id,
|
observer_cruise_id=observer_cruise_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
return await mark_observer_task_as_failed(observer_cruise_id, organization_id=organization_id)
|
return await mark_observer_task_as_failed(
|
||||||
|
observer_cruise_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
failure_reason="Failed to get task v2",
|
||||||
|
)
|
||||||
if not observer_task:
|
if not observer_task:
|
||||||
LOG.error("Observer cruise not found", observer_cruise_id=observer_cruise_id, organization_id=organization_id)
|
LOG.error("Task v2 not found", observer_cruise_id=observer_cruise_id, organization_id=organization_id)
|
||||||
raise ObserverCruiseNotFound(observer_cruise_id=observer_cruise_id)
|
raise ObserverCruiseNotFound(observer_cruise_id=observer_cruise_id)
|
||||||
|
|
||||||
workflow, workflow_run = None, None
|
workflow, workflow_run = None, None
|
||||||
@@ -365,6 +373,25 @@ async def run_observer_task_helper(
|
|||||||
|
|
||||||
max_iterations = int_max_iterations_override or DEFAULT_MAX_ITERATIONS
|
max_iterations = int_max_iterations_override or DEFAULT_MAX_ITERATIONS
|
||||||
for i in range(max_iterations):
|
for i in range(max_iterations):
|
||||||
|
# check the status of the workflow run
|
||||||
|
workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id, organization_id=organization_id)
|
||||||
|
if not workflow_run:
|
||||||
|
LOG.error("Workflow run not found", workflow_run_id=workflow_run_id)
|
||||||
|
break
|
||||||
|
|
||||||
|
if workflow_run.status == WorkflowRunStatus.canceled:
|
||||||
|
LOG.info(
|
||||||
|
"Task v2 is canceled. Stopping task v2",
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
observer_task_id=observer_cruise_id,
|
||||||
|
)
|
||||||
|
await mark_observer_task_as_canceled(
|
||||||
|
observer_cruise_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
return workflow, workflow_run, observer_task
|
||||||
|
|
||||||
LOG.info(f"Observer iteration i={i}", workflow_run_id=workflow_run_id, url=url)
|
LOG.info(f"Observer iteration i={i}", workflow_run_id=workflow_run_id, url=url)
|
||||||
task_type = ""
|
task_type = ""
|
||||||
plan = ""
|
plan = ""
|
||||||
@@ -472,7 +499,8 @@ async def run_observer_task_helper(
|
|||||||
# parse observer repsonse and run the next task
|
# parse observer repsonse and run the next task
|
||||||
if not task_type:
|
if not task_type:
|
||||||
LOG.error("No task type found in observer response", observer_response=observer_response)
|
LOG.error("No task type found in observer response", observer_response=observer_response)
|
||||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_failed(
|
await mark_observer_task_as_failed(
|
||||||
|
observer_cruise_id=observer_cruise_id,
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
failure_reason="Skyvern failed to generate a task. Please try again later.",
|
failure_reason="Skyvern failed to generate a task. Please try again later.",
|
||||||
)
|
)
|
||||||
@@ -523,14 +551,16 @@ async def run_observer_task_helper(
|
|||||||
}
|
}
|
||||||
except Exception:
|
except Exception:
|
||||||
LOG.exception("Failed to generate loop task")
|
LOG.exception("Failed to generate loop task")
|
||||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_failed(
|
await mark_observer_task_as_failed(
|
||||||
|
observer_cruise_id=observer_cruise_id,
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
failure_reason="Failed to generate the loop.",
|
failure_reason="Failed to generate the loop.",
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
LOG.info("Unsupported task type", task_type=task_type)
|
LOG.info("Unsupported task type", task_type=task_type)
|
||||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_failed(
|
await mark_observer_task_as_failed(
|
||||||
|
observer_cruise_id=observer_cruise_id,
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
failure_reason=f"Unsupported task block type gets generated: {task_type}",
|
failure_reason=f"Unsupported task block type gets generated: {task_type}",
|
||||||
)
|
)
|
||||||
@@ -580,6 +610,7 @@ async def run_observer_task_helper(
|
|||||||
|
|
||||||
# execute the extraction task
|
# execute the extraction task
|
||||||
workflow_run = await handle_block_result(
|
workflow_run = await handle_block_result(
|
||||||
|
observer_cruise_id,
|
||||||
block,
|
block,
|
||||||
block_result,
|
block_result,
|
||||||
workflow,
|
workflow,
|
||||||
@@ -680,6 +711,7 @@ async def run_observer_task_helper(
|
|||||||
|
|
||||||
|
|
||||||
async def handle_block_result(
|
async def handle_block_result(
|
||||||
|
observer_cruise_id: str,
|
||||||
block: BlockTypeVar,
|
block: BlockTypeVar,
|
||||||
block_result: BlockResult,
|
block_result: BlockResult,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
@@ -697,7 +729,11 @@ async def handle_block_result(
|
|||||||
block_type_var=block.block_type,
|
block_type_var=block.block_type,
|
||||||
block_label=block.label,
|
block_label=block.label,
|
||||||
)
|
)
|
||||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(workflow_run_id=workflow_run.workflow_run_id)
|
await mark_observer_task_as_canceled(
|
||||||
|
observer_cruise_id=observer_cruise_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=workflow_run.organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
elif block_result.status == BlockStatus.failed:
|
elif block_result.status == BlockStatus.failed:
|
||||||
LOG.error(
|
LOG.error(
|
||||||
@@ -826,11 +862,7 @@ async def _generate_loop_task(
|
|||||||
"Failed to execute the extraction block for the loop task",
|
"Failed to execute the extraction block for the loop task",
|
||||||
extraction_block_result=extraction_block_result,
|
extraction_block_result=extraction_block_result,
|
||||||
)
|
)
|
||||||
# TODO: fail the workflow run
|
# wofklow run and observer task status update is handled in the upper caller layer
|
||||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_failed(
|
|
||||||
workflow_run_id=workflow_run_id,
|
|
||||||
failure_reason="Failed to extract loop values for the loop. Please try again later.",
|
|
||||||
)
|
|
||||||
raise Exception("extraction_block failed")
|
raise Exception("extraction_block failed")
|
||||||
# validate output parameter
|
# validate output parameter
|
||||||
try:
|
try:
|
||||||
@@ -848,10 +880,6 @@ async def _generate_loop_task(
|
|||||||
"Failed to validate the output parameter of the extraction block for the loop task",
|
"Failed to validate the output parameter of the extraction block for the loop task",
|
||||||
extraction_block_result=extraction_block_result,
|
extraction_block_result=extraction_block_result,
|
||||||
)
|
)
|
||||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_failed(
|
|
||||||
workflow_run_id=workflow_run_id,
|
|
||||||
failure_reason="Invalid output parameter of the extraction block for the loop. Please try again later.",
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# update the observer thought
|
# update the observer thought
|
||||||
@@ -1207,6 +1235,22 @@ async def mark_observer_task_as_completed(
|
|||||||
return observer_task
|
return observer_task
|
||||||
|
|
||||||
|
|
||||||
|
async def mark_observer_task_as_canceled(
|
||||||
|
observer_cruise_id: str,
|
||||||
|
workflow_run_id: str | None = None,
|
||||||
|
organization_id: str | None = None,
|
||||||
|
) -> ObserverTask:
|
||||||
|
observer_task = await app.DATABASE.update_observer_cruise(
|
||||||
|
observer_cruise_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
status=ObserverTaskStatus.canceled,
|
||||||
|
)
|
||||||
|
if workflow_run_id:
|
||||||
|
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(workflow_run_id)
|
||||||
|
await send_observer_task_webhook(observer_task)
|
||||||
|
return observer_task
|
||||||
|
|
||||||
|
|
||||||
def _get_extracted_data_from_block_result(
|
def _get_extracted_data_from_block_result(
|
||||||
block_result: BlockResult,
|
block_result: BlockResult,
|
||||||
task_type: str,
|
task_type: str,
|
||||||
|
|||||||
Reference in New Issue
Block a user