workflow runtime API (#1421)

This commit is contained in:
Shuchang Zheng
2024-12-22 20:54:53 -08:00
committed by GitHub
parent 2e37542218
commit 94a3779bd7
5 changed files with 137 additions and 79 deletions

View File

@@ -250,6 +250,28 @@ class AgentDB:
LOG.error("UnexpectedError", exc_info=True) LOG.error("UnexpectedError", exc_info=True)
raise raise
async def get_tasks_by_ids(
self,
task_ids: list[str],
organization_id: str | None = None,
) -> list[Task]:
try:
async with self.Session() as session:
tasks = (
await session.scalars(
select(TaskModel)
.filter(TaskModel.task_id.in_(task_ids))
.filter_by(organization_id=organization_id)
)
).all()
return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_step(self, task_id: str, step_id: str, organization_id: str | None = None) -> Step | None: async def get_step(self, task_id: str, step_id: str, organization_id: str | None = None) -> Step | None:
try: try:
async with self.Session() as session: async with self.Session() as session:
@@ -1883,7 +1905,7 @@ class AgentDB:
return ObserverThought.model_validate(observer_thought) return ObserverThought.model_validate(observer_thought)
return None return None
async def get_observer_cruise_thoughts( async def get_observer_thoughts(
self, self,
observer_cruise_id: str, observer_cruise_id: str,
organization_id: str | None = None, organization_id: str | None = None,
@@ -2079,3 +2101,24 @@ class AgentDB:
task = await self.get_task(task_id, organization_id=organization_id) task = await self.get_task(task_id, organization_id=organization_id)
return convert_to_workflow_run_block(workflow_run_block, task=task) return convert_to_workflow_run_block(workflow_run_block, task=task)
raise NotFoundError(f"WorkflowRunBlock {workflow_run_block_id} not found") raise NotFoundError(f"WorkflowRunBlock {workflow_run_block_id} not found")
async def get_workflow_run_blocks(
self,
workflow_run_id: str,
organization_id: str | None = None,
) -> list[WorkflowRunBlock]:
async with self.Session() as session:
workflow_run_blocks = (
await session.scalars(
select(WorkflowRunBlockModel)
.filter_by(workflow_run_id=workflow_run_id)
.filter_by(organization_id=organization_id)
.order_by(WorkflowRunBlockModel.created_at)
)
).all()
tasks = await self.get_tasks_by_workflow_run_id(workflow_run_id)
tasks_dict = {task.task_id: task for task in tasks}
return [
convert_to_workflow_run_block(workflow_run_block, task=tasks_dict.get(workflow_run_block.task_id))
for workflow_run_block in workflow_run_blocks
]

View File

@@ -33,7 +33,7 @@ from skyvern.forge.sdk.artifact.models import Artifact
from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory
from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.observers import CruiseRequest, ObserverCruise from skyvern.forge.sdk.schemas.observers import CruiseRequest, ObserverCruise
@@ -53,14 +53,13 @@ from skyvern.forge.sdk.schemas.tasks import (
TaskResponse, TaskResponse,
TaskStatus, TaskStatus,
) )
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock, WorkflowRunEvent, WorkflowRunEventType from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunTimeline
from skyvern.forge.sdk.services import observer_service, org_auth_service from skyvern.forge.sdk.services import observer_service, org_auth_service
from skyvern.forge.sdk.workflow.exceptions import ( from skyvern.forge.sdk.workflow.exceptions import (
FailedToCreateWorkflow, FailedToCreateWorkflow,
FailedToUpdateWorkflow, FailedToUpdateWorkflow,
WorkflowParameterMissingRequiredValue, WorkflowParameterMissingRequiredValue,
) )
from skyvern.forge.sdk.workflow.models.block import BlockType
from skyvern.forge.sdk.workflow.models.workflow import ( from skyvern.forge.sdk.workflow.models.workflow import (
RunWorkflowResponse, RunWorkflowResponse,
Workflow, Workflow,
@@ -727,88 +726,32 @@ async def get_workflow_run(
@base_router.get( @base_router.get(
"/workflows/{workflow_id}/runs/{workflow_run_id}/events", "/workflows/{workflow_id}/runs/{workflow_run_id}/timeline",
) )
@base_router.get( @base_router.get(
"/workflows/{workflow_id}/runs/{workflow_run_id}/events/", "/workflows/{workflow_id}/runs/{workflow_run_id}/timeline/",
) )
async def get_workflow_run_events( async def get_workflow_run_timeline(
workflow_id: str, workflow_id: str,
workflow_run_id: str, workflow_run_id: str,
observer_cruise_id: str | None = None, observer_cruise_id: str | None = None,
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1), page_size: int = Query(20, ge=1),
current_org: Organization = Depends(org_auth_service.get_current_org), current_org: Organization = Depends(org_auth_service.get_current_org),
) -> list[WorkflowRunEvent]: ) -> list[WorkflowRunTimeline]:
# get all the tasks for the workflow run # get all the workflow run blocks
tasks = await app.DATABASE.get_tasks( workflow_run_block_timeline = await app.WORKFLOW_SERVICE.get_workflow_run_timeline(
page,
page_size,
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
organization_id=current_org.organization_id, organization_id=current_org.organization_id,
) )
workflow_run_events: list[WorkflowRunEvent] = []
for task in tasks:
block_type = BlockType.TASK
if task.task_type == TaskType.general:
if not task.navigation_goal and task.data_extraction_goal:
block_type = BlockType.EXTRACTION
elif task.navigation_goal and not task.data_extraction_goal:
block_type = BlockType.NAVIGATION
elif task.task_type == TaskType.validation:
block_type = BlockType.VALIDATION
elif task.task_type == TaskType.action:
block_type = BlockType.ACTION
event = WorkflowRunEvent(
type=WorkflowRunEventType.block,
block=WorkflowRunBlock(
workflow_run_id=workflow_run_id,
block_type=block_type,
label=task.title,
title=task.title,
url=task.url,
status=task.status,
navigation_goal=task.navigation_goal,
data_extraction_goal=task.data_extraction_goal,
data_schema=task.extracted_information_schema,
terminate_criterion=task.terminate_criterion,
complete_criterion=task.complete_criterion,
created_at=task.created_at,
modified_at=task.modified_at,
),
created_at=task.created_at,
modified_at=task.modified_at,
)
workflow_run_events.append(event)
# get all the actions for all the tasks
actions = await app.DATABASE.get_tasks_actions(
[task.task_id for task in tasks], organization_id=current_org.organization_id
)
for action in actions:
workflow_run_events.append(
WorkflowRunEvent(
type=WorkflowRunEventType.action,
action=action,
created_at=action.created_at or datetime.datetime.utcnow(),
modified_at=action.modified_at or datetime.datetime.utcnow(),
)
)
# get all the thoughts for the cruise
if observer_cruise_id: if observer_cruise_id:
thoughts = await app.DATABASE.get_observer_cruise_thoughts( observer_thought_timeline = await observer_service.get_observer_thought_timelines(
observer_cruise_id, organization_id=current_org.organization_id observer_cruise_id=observer_cruise_id,
organization_id=current_org.organization_id,
) )
for thought in thoughts: workflow_run_block_timeline.extend(observer_thought_timeline)
workflow_run_events.append( workflow_run_block_timeline.sort(key=lambda x: x.created_at)
WorkflowRunEvent( return workflow_run_block_timeline
type=WorkflowRunEventType.thought,
thought=thought,
created_at=thought.created_at,
modified_at=thought.modified_at,
)
)
workflow_run_events.sort(key=lambda x: x.created_at)
return workflow_run_events
@base_router.get( @base_router.get(

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
from datetime import datetime from datetime import datetime
from enum import StrEnum from enum import StrEnum
from typing import Any from typing import Any
@@ -10,7 +12,7 @@ from skyvern.webeye.actions.actions import Action
class WorkflowRunBlock(BaseModel): class WorkflowRunBlock(BaseModel):
workflow_run_block_id: str = "placeholder" workflow_run_block_id: str
workflow_run_id: str workflow_run_id: str
parent_workflow_run_block_id: str | None = None parent_workflow_run_block_id: str | None = None
block_type: BlockType block_type: BlockType
@@ -26,20 +28,27 @@ class WorkflowRunBlock(BaseModel):
data_schema: dict[str, Any] | list | str | None = None data_schema: dict[str, Any] | list | str | None = None
terminate_criterion: str | None = None terminate_criterion: str | None = None
complete_criterion: str | None = None complete_criterion: str | None = None
actions: list[Action] = []
created_at: datetime created_at: datetime
modified_at: datetime modified_at: datetime
# for loop block
loop_values: list[Any] | None = None
class WorkflowRunEventType(StrEnum): # block inside a loop block
action = "action" current_item: Any | None = None
current_index: int | None = None
class WorkflowRunTimelineType(StrEnum):
thought = "thought" thought = "thought"
block = "block" block = "block"
class WorkflowRunEvent(BaseModel): class WorkflowRunTimeline(BaseModel):
type: WorkflowRunEventType type: WorkflowRunTimelineType
action: Action | None = None
thought: ObserverThought | None = None
block: WorkflowRunBlock | None = None block: WorkflowRunBlock | None = None
thought: ObserverThought | None = None
children: list[WorkflowRunTimeline] = []
created_at: datetime created_at: datetime
modified_at: datetime modified_at: datetime

View File

@@ -15,6 +15,7 @@ from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverCruiseStatus, ObserverMetadata from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverCruiseStatus, ObserverMetadata
from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.tasks import ProxyLocation from skyvern.forge.sdk.schemas.tasks import ProxyLocation
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunTimeline, WorkflowRunTimelineType
from skyvern.forge.sdk.workflow.models.block import ( from skyvern.forge.sdk.workflow.models.block import (
BlockResult, BlockResult,
BlockStatus, BlockStatus,
@@ -776,3 +777,19 @@ def _generate_random_string(length: int = 5) -> str:
# Use the current timestamp as the seed # Use the current timestamp as the seed
random.seed(os.urandom(16)) random.seed(os.urandom(16))
return "".join(random.choices(RANDOM_STRING_POOL, k=length)) return "".join(random.choices(RANDOM_STRING_POOL, k=length))
async def get_observer_thought_timelines(
observer_cruise_id: str,
organization_id: str | None = None,
) -> list[WorkflowRunTimeline]:
observer_thoughts = await app.DATABASE.get_observer_thoughts(observer_cruise_id, organization_id=organization_id)
return [
WorkflowRunTimeline(
type=WorkflowRunTimelineType.thought,
thought=thought,
created_at=thought.created_at,
modified_at=thought.modified_at,
)
for thought in observer_thoughts
]

View File

@@ -25,6 +25,7 @@ from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock, WorkflowRunTimeline, WorkflowRunTimelineType
from skyvern.forge.sdk.workflow.exceptions import ( from skyvern.forge.sdk.workflow.exceptions import (
ContextParameterSourceNotDefined, ContextParameterSourceNotDefined,
InvalidWaitBlockTime, InvalidWaitBlockTime,
@@ -1602,3 +1603,48 @@ class WorkflowService:
organization=organization, organization=organization,
request=workflow_create_request, request=workflow_create_request,
) )
async def get_workflow_run_timeline(
self,
workflow_run_id: str,
organization_id: str | None = None,
) -> list[WorkflowRunTimeline]:
"""
build the tree structure of the workflow run timeline
"""
workflow_run_blocks = await app.DATABASE.get_workflow_run_blocks(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
# get all the actions for all workflow run blocks
task_ids = [block.task_id for block in workflow_run_blocks if block.task_id]
task_id_to_block: dict[str, WorkflowRunBlock] = {
block.task_id: block for block in workflow_run_blocks if block.task_id
}
actions = await app.DATABASE.get_tasks_actions(task_ids=task_ids, organization_id=organization_id)
for action in actions:
if not action.task_id:
continue
task_block = task_id_to_block[action.task_id]
task_block.actions.append(action)
result = []
block_map: dict[str, WorkflowRunTimeline] = {}
while workflow_run_blocks:
block = workflow_run_blocks.pop(0)
workflow_run_timeline = WorkflowRunTimeline(
type=WorkflowRunTimelineType.block,
block=block,
created_at=block.created_at,
modified_at=block.modified_at,
)
if block.parent_workflow_run_block_id:
if block.parent_workflow_run_block_id in block_map:
block_map[block.parent_workflow_run_block_id].children.append(workflow_run_timeline)
else:
# put the block back to the queue
workflow_run_blocks.append(block)
else:
result.append(workflow_run_timeline)
return result