Sync cloud skyvern to oss skyvern (#55)

This commit is contained in:
Kerem Yilmaz
2024-03-12 22:28:16 -07:00
committed by GitHub
parent 647ea2ac0f
commit 15d78d7b08
25 changed files with 554 additions and 163 deletions

View File

@@ -0,0 +1,38 @@
"""Add title, error_code_mapping, and errors to tasks

Revision ID: 82a0c686152d
Revises: 99423c1dec60
Create Date: 2024-03-13 05:18:52.674264+00:00
"""
from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "82a0c686152d"
down_revision: Union[str, None] = "99423c1dec60"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("tasks", sa.Column("title", sa.String(), nullable=True))
    op.add_column("tasks", sa.Column("error_code_mapping", sa.JSON(), nullable=True))
    # In order to add a NOT NULL column with a default to an existing table,
    # we add the column as nullable first, backfill every existing row with
    # the default value ('[]'), then tighten the column to NOT NULL.
    op.add_column("tasks", sa.Column("errors", sa.JSON(), nullable=True))
    # NOTE: the '::jsonb' cast is PostgreSQL-specific; this migration assumes
    # a Postgres backend.
    op.execute("UPDATE tasks SET errors = '[]'::jsonb")
    op.alter_column("tasks", "errors", nullable=False)
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    # Drop in reverse order of creation.
    op.drop_column("tasks", "errors")
    op.drop_column("tasks", "error_code_mapping")
    op.drop_column("tasks", "title")
    # ### end Alembic commands ###

View File

@@ -13,6 +13,7 @@ class Settings(BaseSettings):
VIDEO_PATH: str | None = None VIDEO_PATH: str | None = None
HAR_PATH: str | None = "./har" HAR_PATH: str | None = "./har"
BROWSER_ACTION_TIMEOUT_MS: int = 5000 BROWSER_ACTION_TIMEOUT_MS: int = 5000
BROWSER_SCREENSHOT_TIMEOUT_MS: int = 10000
MAX_STEPS_PER_RUN: int = 75 MAX_STEPS_PER_RUN: int = 75
MAX_NUM_SCREENSHOTS: int = 6 MAX_NUM_SCREENSHOTS: int = 6
# Ratio should be between 0 and 1. # Ratio should be between 0 and 1.

View File

@@ -179,3 +179,13 @@ class OrganizationNotFound(SkyvernException):
class StepNotFound(SkyvernException): class StepNotFound(SkyvernException):
def __init__(self, organization_id: str, task_id: str, step_id: str | None = None) -> None: def __init__(self, organization_id: str, task_id: str, step_id: str | None = None) -> None:
super().__init__(f"Step {step_id or 'latest'} not found. organization_id={organization_id} task_id={task_id}") super().__init__(f"Step {step_id or 'latest'} not found. organization_id={organization_id} task_id={task_id}")
class FailedToTakeScreenshot(SkyvernException):
    """Raised when the browser fails to capture a screenshot."""

    def __init__(self, error_message: str) -> None:
        message = f"Failed to take screenshot. Error message: {error_message}"
        super().__init__(message)
class WorkflowRunContextNotInitialized(SkyvernException):
    """Raised when a WorkflowRunContext is requested before it has been created."""

    def __init__(self, workflow_run_id: str) -> None:
        # Bug fix: the original message string was missing the f-prefix, so the
        # literal text "{workflow_run_id}" appeared in the exception message
        # instead of the actual workflow run id.
        super().__init__(f"WorkflowRunContext not initialized for workflow run {workflow_run_id}")

View File

@@ -25,10 +25,17 @@ from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.models import Organization, Step, StepStatus from skyvern.forge.sdk.models import Organization, Step, StepStatus
from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskStatus from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskStatus
from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.context_manager import ContextManager from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.models.block import TaskBlock from skyvern.forge.sdk.workflow.models.block import TaskBlock
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun
from skyvern.webeye.actions.actions import Action, ActionType, CompleteAction, WebAction, parse_actions from skyvern.webeye.actions.actions import (
Action,
ActionType,
CompleteAction,
UserDefinedError,
WebAction,
parse_actions,
)
from skyvern.webeye.actions.handler import ActionHandler from skyvern.webeye.actions.handler import ActionHandler
from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput
from skyvern.webeye.actions.responses import ActionResult from skyvern.webeye.actions.responses import ActionResult
@@ -40,6 +47,11 @@ LOG = structlog.get_logger()
class ForgeAgent(Agent): class ForgeAgent(Agent):
def __init__(self) -> None: def __init__(self) -> None:
if SettingsManager.get_settings().ADDITIONAL_MODULES:
for module in SettingsManager.get_settings().ADDITIONAL_MODULES:
LOG.info("Loading additional module", module=module)
__import__(module)
LOG.info("Additional modules loaded", modules=SettingsManager.get_settings().ADDITIONAL_MODULES)
LOG.info( LOG.info(
"Initializing ForgeAgent", "Initializing ForgeAgent",
env=SettingsManager.get_settings().ENV, env=SettingsManager.get_settings().ENV,
@@ -52,11 +64,6 @@ class ForgeAgent(Agent):
long_running_task_warning_ratio=SettingsManager.get_settings().LONG_RUNNING_TASK_WARNING_RATIO, long_running_task_warning_ratio=SettingsManager.get_settings().LONG_RUNNING_TASK_WARNING_RATIO,
debug_mode=SettingsManager.get_settings().DEBUG_MODE, debug_mode=SettingsManager.get_settings().DEBUG_MODE,
) )
if SettingsManager.get_settings().ADDITIONAL_MODULES:
for module in SettingsManager.get_settings().ADDITIONAL_MODULES:
LOG.info("Loading additional module", module=module)
__import__(module)
LOG.info("Additional modules loaded", modules=SettingsManager.get_settings().ADDITIONAL_MODULES)
async def validate_step_execution( async def validate_step_execution(
self, self,
@@ -91,14 +98,14 @@ class ForgeAgent(Agent):
task_block: TaskBlock, task_block: TaskBlock,
workflow: Workflow, workflow: Workflow,
workflow_run: WorkflowRun, workflow_run: WorkflowRun,
context_manager: ContextManager, workflow_run_context: WorkflowRunContext,
task_order: int, task_order: int,
task_retry: int, task_retry: int,
) -> tuple[Task, Step]: ) -> tuple[Task, Step]:
task_block_parameters = task_block.parameters task_block_parameters = task_block.parameters
navigation_payload = {} navigation_payload = {}
for parameter in task_block_parameters: for parameter in task_block_parameters:
navigation_payload[parameter.key] = context_manager.get_value(parameter.key) navigation_payload[parameter.key] = workflow_run_context.get_value(parameter.key)
task_url = task_block.url task_url = task_block.url
if task_url is None: if task_url is None:
@@ -114,6 +121,7 @@ class ForgeAgent(Agent):
task = await app.DATABASE.create_task( task = await app.DATABASE.create_task(
url=task_url, url=task_url,
title=task_block.title,
webhook_callback_url=None, webhook_callback_url=None,
navigation_goal=task_block.navigation_goal, navigation_goal=task_block.navigation_goal,
data_extraction_goal=task_block.data_extraction_goal, data_extraction_goal=task_block.data_extraction_goal,
@@ -124,6 +132,7 @@ class ForgeAgent(Agent):
workflow_run_id=workflow_run.workflow_run_id, workflow_run_id=workflow_run.workflow_run_id,
order=task_order, order=task_order,
retry=task_retry, retry=task_retry,
error_code_mapping=task_block.error_code_mapping,
) )
LOG.info( LOG.info(
"Created new task for workflow run", "Created new task for workflow run",
@@ -131,8 +140,10 @@ class ForgeAgent(Agent):
workflow_run_id=workflow_run.workflow_run_id, workflow_run_id=workflow_run.workflow_run_id,
task_id=task.task_id, task_id=task.task_id,
url=task.url, url=task.url,
title=task.title,
nav_goal=task.navigation_goal, nav_goal=task.navigation_goal,
data_goal=task.data_extraction_goal, data_goal=task.data_extraction_goal,
error_code_mapping=task.error_code_mapping,
proxy_location=task.proxy_location, proxy_location=task.proxy_location,
task_order=task_order, task_order=task_order,
task_retry=task_retry, task_retry=task_retry,
@@ -161,6 +172,7 @@ class ForgeAgent(Agent):
async def create_task(self, task_request: TaskRequest, organization_id: str | None = None) -> Task: async def create_task(self, task_request: TaskRequest, organization_id: str | None = None) -> Task:
task = await app.DATABASE.create_task( task = await app.DATABASE.create_task(
url=task_request.url, url=task_request.url,
title=task_request.title,
webhook_callback_url=task_request.webhook_callback_url, webhook_callback_url=task_request.webhook_callback_url,
navigation_goal=task_request.navigation_goal, navigation_goal=task_request.navigation_goal,
data_extraction_goal=task_request.data_extraction_goal, data_extraction_goal=task_request.data_extraction_goal,
@@ -168,10 +180,12 @@ class ForgeAgent(Agent):
organization_id=organization_id, organization_id=organization_id,
proxy_location=task_request.proxy_location, proxy_location=task_request.proxy_location,
extracted_information_schema=task_request.extracted_information_schema, extracted_information_schema=task_request.extracted_information_schema,
error_code_mapping=task_request.error_code_mapping,
) )
LOG.info( LOG.info(
"Created new task", "Created new task",
task_id=task.task_id, task_id=task.task_id,
title=task.title,
url=task.url, url=task.url,
nav_goal=task.navigation_goal, nav_goal=task.navigation_goal,
data_goal=task.data_extraction_goal, data_goal=task.data_extraction_goal,
@@ -195,7 +209,7 @@ class ForgeAgent(Agent):
await self.validate_step_execution(task, step) await self.validate_step_execution(task, step)
step, browser_state, detailed_output = await self._initialize_execution_state(task, step, workflow_run) step, browser_state, detailed_output = await self._initialize_execution_state(task, step, workflow_run)
step, detailed_output = await self.agent_step(task, step, browser_state, organization=organization) step, detailed_output = await self.agent_step(task, step, browser_state, organization=organization)
analytics.capture("skyvern-oss-agent-step-status", {"status": step.status}) task = await self.update_task_errors_from_detailed_output(task, detailed_output)
retry = False retry = False
# If the step failed, mark the step as failed and retry # If the step failed, mark the step as failed and retry
@@ -466,7 +480,7 @@ class ForgeAgent(Agent):
if not browser_state.page: if not browser_state.page:
raise BrowserStateMissingPage() raise BrowserStateMissingPage()
try: try:
screenshot = await browser_state.page.screenshot(full_page=True) screenshot = await browser_state.take_screenshot(full_page=True)
await app.ARTIFACT_MANAGER.create_artifact( await app.ARTIFACT_MANAGER.create_artifact(
step=step, step=step,
artifact_type=ArtifactType.SCREENSHOT_ACTION, artifact_type=ArtifactType.SCREENSHOT_ACTION,
@@ -582,6 +596,7 @@ class ForgeAgent(Agent):
elements=scraped_page.element_tree_trimmed, # scraped_page.element_tree, elements=scraped_page.element_tree_trimmed, # scraped_page.element_tree,
data_extraction_goal=task.data_extraction_goal, data_extraction_goal=task.data_extraction_goal,
action_history=action_results_str, action_history=action_results_str,
error_code_mapping_str=json.dumps(task.error_code_mapping) if task.error_code_mapping else None,
utc_datetime=datetime.utcnow(), utc_datetime=datetime.utcnow(),
) )
@@ -686,9 +701,9 @@ class ForgeAgent(Agent):
analytics.capture("skyvern-oss-agent-task-status", {"status": task.status}) analytics.capture("skyvern-oss-agent-task-status", {"status": task.status})
# Take one last screenshot and create an artifact before closing the browser to see the final state # Take one last screenshot and create an artifact before closing the browser to see the final state
browser_state: BrowserState = await app.BROWSER_MANAGER.get_or_create_for_task(task) browser_state: BrowserState = await app.BROWSER_MANAGER.get_or_create_for_task(task)
page = await browser_state.get_or_create_page() await browser_state.get_or_create_page()
try: try:
screenshot = await page.screenshot(full_page=True) screenshot = await browser_state.take_screenshot(full_page=True)
await app.ARTIFACT_MANAGER.create_artifact( await app.ARTIFACT_MANAGER.create_artifact(
step=last_step, step=last_step,
artifact_type=ArtifactType.SCREENSHOT_FINAL, artifact_type=ArtifactType.SCREENSHOT_FINAL,
@@ -829,6 +844,14 @@ class ForgeAgent(Agent):
artifact_type=ArtifactType.HAR, artifact_type=ArtifactType.HAR,
data=har_data, data=har_data,
) )
if browser_state.browser_context and browser_state.browser_artifacts.traces_dir:
trace_path = f"{browser_state.browser_artifacts.traces_dir}/{task.task_id}.zip"
await app.ARTIFACT_MANAGER.create_artifact(
step=last_step,
artifact_type=ArtifactType.TRACE,
path=trace_path,
)
else: else:
LOG.warning( LOG.warning(
"BrowserState is missing before sending response to webhook_callback_url", "BrowserState is missing before sending response to webhook_callback_url",
@@ -1010,3 +1033,25 @@ class ForgeAgent(Agent):
warning_ratio=SettingsManager.get_settings().LONG_RUNNING_TASK_WARNING_RATIO, warning_ratio=SettingsManager.get_settings().LONG_RUNNING_TASK_WARNING_RATIO,
) )
return None, None, next_step return None, None, next_step
@staticmethod
async def get_task_errors(task: Task) -> list[UserDefinedError]:
    """Collect every user-defined error recorded on the task's steps, in step order."""
    steps = await app.DATABASE.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
    return [
        error
        for step in steps
        if step.output and step.output.errors
        for error in step.output.errors
    ]
@staticmethod
async def update_task_errors_from_detailed_output(
    task: Task, detailed_step_output: DetailedAgentStepOutput
) -> Task:
    """Append the step's user-defined errors to the task's error list and persist it.

    Returns the refreshed Task as stored in the database.
    """
    step_errors = detailed_step_output.extract_errors() or []
    # Build a new list instead of extending task.errors in place: the original
    # mutated the in-memory Task before the DB call, leaving it inconsistent
    # with the database if update_task raised.
    task_errors = task.errors + [error.model_dump() for error in step_errors]
    return await app.DATABASE.update_task(
        task_id=task.task_id, organization_id=task.organization_id, errors=task_errors
    )

View File

@@ -6,9 +6,9 @@ from skyvern.forge.sdk.api.open_ai import OpenAIClientManager
from skyvern.forge.sdk.artifact.manager import ArtifactManager from skyvern.forge.sdk.artifact.manager import ArtifactManager
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
from skyvern.forge.sdk.db.client import AgentDB from skyvern.forge.sdk.db.client import AgentDB
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
from skyvern.forge.sdk.forge_log import setup_logger from skyvern.forge.sdk.forge_log import setup_logger
from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.context_manager import WorkflowContextManager
from skyvern.forge.sdk.workflow.service import WorkflowService from skyvern.forge.sdk.workflow.service import WorkflowService
from skyvern.webeye.browser_manager import BrowserManager from skyvern.webeye.browser_manager import BrowserManager
@@ -26,10 +26,10 @@ DATABASE = AgentDB(
SettingsManager.get_settings().DATABASE_STRING, debug_enabled=SettingsManager.get_settings().DEBUG_MODE SettingsManager.get_settings().DATABASE_STRING, debug_enabled=SettingsManager.get_settings().DEBUG_MODE
) )
STORAGE = StorageFactory.get_storage() STORAGE = StorageFactory.get_storage()
ASYNC_EXECUTOR = AsyncExecutorFactory.get_executor()
ARTIFACT_MANAGER = ArtifactManager() ARTIFACT_MANAGER = ArtifactManager()
BROWSER_MANAGER = BrowserManager() BROWSER_MANAGER = BrowserManager()
OPENAI_CLIENT = OpenAIClientManager() OPENAI_CLIENT = OpenAIClientManager()
WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
WORKFLOW_SERVICE = WorkflowService() WORKFLOW_SERVICE = WorkflowService()
agent = ForgeAgent() agent = ForgeAgent()

View File

@@ -24,7 +24,15 @@ Reply in JSON format with the following keys:
"label": str, // the label of the option if any. MAKE SURE YOU USE THIS LABEL TO SELECT THE OPTION. DO NOT PUT ANYTHING OTHER THAN A VALID OPTION LABEL HERE "label": str, // the label of the option if any. MAKE SURE YOU USE THIS LABEL TO SELECT THE OPTION. DO NOT PUT ANYTHING OTHER THAN A VALID OPTION LABEL HERE
"index": int, // the id corresponding to the optionIndex under the select element. "index": int, // the id corresponding to the optionIndex under the select element.
"value": str // the value of the option. MAKE SURE YOU USE THIS VALUE TO SELECT THE OPTION. DO NOT PUT ANYTHING OTHER THAN A VALID OPTION VALUE HERE "value": str // the value of the option. MAKE SURE YOU USE THIS VALUE TO SELECT THE OPTION. DO NOT PUT ANYTHING OTHER THAN A VALID OPTION VALUE HERE
} },
{% if error_code_mapping_str %}
"errors": array // A list of errors. This is used to surface any errors that matches the current situation for COMPLETE and TERMINATE actions. For other actions or if no error description suits the current situation on the screenshots, return an empty list. You are allowed to return multiple errors if there are multiple errors on the page.
[{
"error_code": str, // The error code from the user's error code list
"reasoning": str, // The reasoning behind the error. Be specific, referencing any user information and their fields in your reasoning. Keep the reasoning short and to the point.
"confidence_float": float // The confidence of the error. Pick a number between 0.0 and 1.0. 0.0 means no confidence, 1.0 means full confidence
}]
{% endif %}
}], }],
} }
@@ -41,6 +49,12 @@ User goal:
``` ```
{{ navigation_goal }} {{ navigation_goal }}
``` ```
{% if error_code_mapping_str %}
Use the error codes and their descriptions to surface user-defined errors. Do not return any error that's not defined by the user. User defined errors:
{{ error_code_mapping_str }}
{% endif %}
{% if data_extraction_goal %} {% if data_extraction_goal %}
User Data Extraction Goal: User Data Extraction Goal:

View File

@@ -11,6 +11,12 @@ If you are unable to extract the requested information for a specific field in t
User Data Extraction Goal: {{ data_extraction_goal }} User Data Extraction Goal: {{ data_extraction_goal }}
{% if error_code_mapping_str %}
Use the error codes and their descriptions to return errors in the output, do not return any error that's not defined by the user. Don't return any outputs if the schema doesn't specify an error related field. Here are the descriptions defined by the user: {{ error_code_mapping_str }}
{% endif %}
Current URL: {{ current_url }} Current URL: {{ current_url }}
Text extracted from the webpage: {{ extracted_text }} Text extracted from the webpage: {{ extracted_text }}
User Navigation Payload: {{ navigation_payload }}

View File

@@ -13,6 +13,7 @@ from starlette_context.plugins.base import Plugin
from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.routes.agent_protocol import base_router from skyvern.forge.sdk.routes.agent_protocol import base_router
from skyvern.scheduler import SCHEDULER
LOG = structlog.get_logger() LOG = structlog.get_logger()
@@ -58,6 +59,12 @@ class Agent:
), ),
) )
# Register the scheduler on startup so that we can schedule jobs dynamically
@app.on_event("startup")
def start_scheduler() -> None:
LOG.info("Starting the skyvern scheduler.")
SCHEDULER.start()
@app.exception_handler(Exception) @app.exception_handler(Exception)
async def unexpected_exception(request: Request, exc: Exception) -> JSONResponse: async def unexpected_exception(request: Request, exc: Exception) -> JSONResponse:
LOG.exception("Unexpected error in agent server.", exc_info=exc) LOG.exception("Unexpected error in agent server.", exc_info=exc)

View File

@@ -55,6 +55,7 @@ class AgentDB:
async def create_task( async def create_task(
self, self,
url: str, url: str,
title: str | None,
navigation_goal: str | None, navigation_goal: str | None,
data_extraction_goal: str | None, data_extraction_goal: str | None,
navigation_payload: dict[str, Any] | list | str | None, navigation_payload: dict[str, Any] | list | str | None,
@@ -65,12 +66,14 @@ class AgentDB:
workflow_run_id: str | None = None, workflow_run_id: str | None = None,
order: int | None = None, order: int | None = None,
retry: int | None = None, retry: int | None = None,
error_code_mapping: dict[str, str] | None = None,
) -> Task: ) -> Task:
try: try:
with self.Session() as session: with self.Session() as session:
new_task = TaskModel( new_task = TaskModel(
status="created", status="created",
url=url, url=url,
title=title,
webhook_callback_url=webhook_callback_url, webhook_callback_url=webhook_callback_url,
navigation_goal=navigation_goal, navigation_goal=navigation_goal,
data_extraction_goal=data_extraction_goal, data_extraction_goal=data_extraction_goal,
@@ -81,6 +84,7 @@ class AgentDB:
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
order=order, order=order,
retry=retry, retry=retry,
error_code_mapping=error_code_mapping,
) )
session.add(new_task) session.add(new_task)
session.commit() session.commit()
@@ -312,11 +316,16 @@ class AgentDB:
async def update_task( async def update_task(
self, self,
task_id: str, task_id: str,
status: TaskStatus, status: TaskStatus | None = None,
extracted_information: dict[str, Any] | list | str | None = None, extracted_information: dict[str, Any] | list | str | None = None,
failure_reason: str | None = None, failure_reason: str | None = None,
errors: list[dict[str, Any]] | None = None,
organization_id: str | None = None, organization_id: str | None = None,
) -> Task: ) -> Task:
if status is None and extracted_information is None and failure_reason is None and errors is None:
raise ValueError(
"At least one of status, extracted_information, failure_reason, or errors must be provided to update the task"
)
try: try:
with self.Session() as session: with self.Session() as session:
if ( if (
@@ -325,11 +334,14 @@ class AgentDB:
.filter_by(organization_id=organization_id) .filter_by(organization_id=organization_id)
.first() .first()
): ):
task.status = status if status is not None:
task.status = status
if extracted_information is not None: if extracted_information is not None:
task.extracted_information = extracted_information task.extracted_information = extracted_information
if failure_reason is not None: if failure_reason is not None:
task.failure_reason = failure_reason task.failure_reason = failure_reason
if errors is not None:
task.errors = errors
session.commit() session.commit()
updated_task = await self.get_task(task_id, organization_id=organization_id) updated_task = await self.get_task(task_id, organization_id=organization_id)
if not updated_task: if not updated_task:

View File

@@ -29,6 +29,7 @@ class TaskModel(Base):
organization_id = Column(String, ForeignKey("organizations.organization_id")) organization_id = Column(String, ForeignKey("organizations.organization_id"))
status = Column(String) status = Column(String)
webhook_callback_url = Column(String) webhook_callback_url = Column(String)
title = Column(String)
url = Column(String) url = Column(String)
navigation_goal = Column(String) navigation_goal = Column(String)
data_extraction_goal = Column(String) data_extraction_goal = Column(String)
@@ -40,6 +41,8 @@ class TaskModel(Base):
workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id")) workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"))
order = Column(Integer, nullable=True) order = Column(Integer, nullable=True)
retry = Column(Integer, nullable=True) retry = Column(Integer, nullable=True)
error_code_mapping = Column(JSON, nullable=True)
errors = Column(JSON, default=[], nullable=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)

View File

@@ -48,6 +48,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
status=TaskStatus(task_obj.status), status=TaskStatus(task_obj.status),
created_at=task_obj.created_at, created_at=task_obj.created_at,
modified_at=task_obj.modified_at, modified_at=task_obj.modified_at,
title=task_obj.title,
url=task_obj.url, url=task_obj.url,
webhook_callback_url=task_obj.webhook_callback_url, webhook_callback_url=task_obj.webhook_callback_url,
navigation_goal=task_obj.navigation_goal, navigation_goal=task_obj.navigation_goal,
@@ -61,6 +62,8 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
workflow_run_id=task_obj.workflow_run_id, workflow_run_id=task_obj.workflow_run_id,
order=task_obj.order, order=task_obj.order,
retry=task_obj.retry, retry=task_obj.retry,
error_code_mapping=task_obj.error_code_mapping,
errors=task_obj.errors,
) )
return task return task

View File

@@ -1,5 +1,6 @@
import abc import abc
import structlog
from fastapi import BackgroundTasks from fastapi import BackgroundTasks
from skyvern.forge import app from skyvern.forge import app
@@ -8,6 +9,8 @@ from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.models import Organization from skyvern.forge.sdk.models import Organization
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
LOG = structlog.get_logger()
class AsyncExecutor(abc.ABC): class AsyncExecutor(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
@@ -43,6 +46,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
max_steps_override: int | None, max_steps_override: int | None,
api_key: str | None, api_key: str | None,
) -> None: ) -> None:
LOG.info("Executing task using background task executor", task_id=task.task_id)
step = await app.DATABASE.create_step( step = await app.DATABASE.create_step(
task.task_id, task.task_id,
order=0, order=0,
@@ -52,7 +56,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
task = await app.DATABASE.update_task( task = await app.DATABASE.update_task(
task.task_id, task.task_id,
TaskStatus.running, status=TaskStatus.running,
organization_id=organization.organization_id, organization_id=organization.organization_id,
) )
@@ -78,6 +82,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
max_steps_override: int | None, max_steps_override: int | None,
api_key: str | None, api_key: str | None,
) -> None: ) -> None:
LOG.info("Executing workflow using background task executor", workflow_run_id=workflow_run_id)
background_tasks.add_task( background_tasks.add_task(
app.WORKFLOW_SERVICE.execute_workflow, app.WORKFLOW_SERVICE.execute_workflow,
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,

View File

@@ -11,8 +11,16 @@ from skyvern.forge import app
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
from skyvern.forge.sdk.models import Organization, Step from skyvern.forge.sdk.models import Organization, Step
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus from skyvern.forge.sdk.schemas.tasks import (
CreateTaskResponse,
ProxyLocation,
Task,
TaskRequest,
TaskResponse,
TaskStatus,
)
from skyvern.forge.sdk.services import org_auth_service from skyvern.forge.sdk.services import org_auth_service
from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.models.workflow import ( from skyvern.forge.sdk.workflow.models.workflow import (
@@ -80,10 +88,13 @@ async def create_agent_task(
analytics.capture("skyvern-oss-agent-task-create", data={"url": task.url}) analytics.capture("skyvern-oss-agent-task-create", data={"url": task.url})
agent = request["agent"] agent = request["agent"]
if current_org and current_org.organization_name == "CoverageCat":
task.proxy_location = ProxyLocation.RESIDENTIAL
created_task = await agent.create_task(task, current_org.organization_id) created_task = await agent.create_task(task, current_org.organization_id)
if x_max_steps_override: if x_max_steps_override:
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override) LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
await app.ASYNC_EXECUTOR.execute_task( await AsyncExecutorFactory.get_executor().execute_task(
background_tasks=background_tasks, background_tasks=background_tasks,
task=created_task, task=created_task,
organization=current_org, organization=current_org,
@@ -398,10 +409,6 @@ async def execute_workflow(
x_max_steps_override: Annotated[int | None, Header()] = None, x_max_steps_override: Annotated[int | None, Header()] = None,
) -> RunWorkflowResponse: ) -> RunWorkflowResponse:
analytics.capture("skyvern-oss-agent-workflow-execute") analytics.capture("skyvern-oss-agent-workflow-execute")
LOG.info(
f"Running workflow {workflow_id}",
workflow_id=workflow_id,
)
context = skyvern_context.ensure_context() context = skyvern_context.ensure_context()
request_id = context.request_id request_id = context.request_id
workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run( workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
@@ -413,7 +420,7 @@ async def execute_workflow(
) )
if x_max_steps_override: if x_max_steps_override:
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override) LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
await app.ASYNC_EXECUTOR.execute_workflow( await AsyncExecutorFactory.get_executor().execute_workflow(
background_tasks=background_tasks, background_tasks=background_tasks,
organization=current_org, organization=current_org,
workflow_id=workflow_id, workflow_id=workflow_id,

View File

@@ -18,6 +18,11 @@ class ProxyLocation(StrEnum):
class TaskRequest(BaseModel): class TaskRequest(BaseModel):
title: str | None = Field(
default=None,
description="The title of the task.",
examples=["Get a quote for car insurance"],
)
url: str = Field( url: str = Field(
..., ...,
min_length=1, min_length=1,
@@ -41,17 +46,27 @@ class TaskRequest(BaseModel):
examples=["Extract the quote price"], examples=["Extract the quote price"],
) )
navigation_payload: dict[str, Any] | list | str | None = Field( navigation_payload: dict[str, Any] | list | str | None = Field(
None, default=None,
description="The user's details needed to achieve the task.", description="The user's details needed to achieve the task.",
examples=[{"name": "John Doe", "email": "john@doe.com"}], examples=[{"name": "John Doe", "email": "john@doe.com"}],
) )
error_code_mapping: dict[str, str] | None = Field(
default=None,
description="The mapping of error codes and their descriptions.",
examples=[
{
"out_of_stock": "Return this error when the product is out of stock",
"not_found": "Return this error when the product is not found",
}
],
)
proxy_location: ProxyLocation | None = Field( proxy_location: ProxyLocation | None = Field(
None, default=None,
description="The location of the proxy to use for the task.", description="The location of the proxy to use for the task.",
examples=["US-WA", "US-CA", "US-FL", "US-NY", "US-TX"], examples=["US-WA", "US-CA", "US-FL", "US-NY", "US-TX"],
) )
extracted_information_schema: dict[str, Any] | list | str | None = Field( extracted_information_schema: dict[str, Any] | list | str | None = Field(
None, default=None,
description="The requested schema of the extracted information.", description="The requested schema of the extracted information.",
) )
@@ -122,6 +137,7 @@ class Task(TaskRequest):
workflow_run_id: str | None = None workflow_run_id: str | None = None
order: int | None = None order: int | None = None
retry: int | None = None retry: int | None = None
errors: list[dict[str, Any]] = []
def validate_update( def validate_update(
self, self,
@@ -162,6 +178,7 @@ class Task(TaskRequest):
failure_reason=failure_reason or self.failure_reason, failure_reason=failure_reason or self.failure_reason,
screenshot_url=screenshot_url, screenshot_url=screenshot_url,
recording_url=recording_url, recording_url=recording_url,
errors=self.errors,
) )
@@ -175,6 +192,7 @@ class TaskResponse(BaseModel):
screenshot_url: str | None = None screenshot_url: str | None = None
recording_url: str | None = None recording_url: str | None = None
failure_reason: str | None = None failure_reason: str | None = None
errors: list[dict[str, Any]] = []
class CreateTaskResponse(BaseModel): class CreateTaskResponse(BaseModel):

View File

@@ -1,7 +1,9 @@
import uuid
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
import structlog import structlog
from skyvern.exceptions import WorkflowRunContextNotInitialized
from skyvern.forge.sdk.api.aws import AsyncAWSClient from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, Parameter, ParameterType, WorkflowParameter from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, Parameter, ParameterType, WorkflowParameter
@@ -12,15 +14,15 @@ if TYPE_CHECKING:
LOG = structlog.get_logger() LOG = structlog.get_logger()
class ContextManager: class WorkflowRunContext:
aws_client: AsyncAWSClient
parameters: dict[str, PARAMETER_TYPE] parameters: dict[str, PARAMETER_TYPE]
values: dict[str, Any] values: dict[str, Any]
secrets: dict[str, Any]
def __init__(self, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]]) -> None: def __init__(self, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]]) -> None:
self.aws_client = AsyncAWSClient()
self.parameters = {} self.parameters = {}
self.values = {} self.values = {}
self.secrets = {}
for parameter, run_parameter in workflow_parameter_tuples: for parameter, run_parameter in workflow_parameter_tuples:
if parameter.key in self.parameters: if parameter.key in self.parameters:
prev_value = self.parameters[parameter.key] prev_value = self.parameters[parameter.key]
@@ -32,8 +34,33 @@ class ContextManager:
self.parameters[parameter.key] = parameter self.parameters[parameter.key] = parameter
self.values[parameter.key] = run_parameter.value self.values[parameter.key] = run_parameter.value
def get_parameter(self, key: str) -> Parameter:
    """Return the Parameter object registered under ``key``.

    Raises:
        KeyError: if no parameter with this key has been registered.
    """
    return self.parameters[key]
def get_value(self, key: str) -> Any:
    """
    Get the value of a parameter. If the parameter is an AWS secret, the value will be the random secret id, not
    the actual secret value. This will be used when building the navigation payload since we don't want to expose
    the actual secret value in the payload.

    Raises:
        KeyError: if no value has been registered under ``key``.
    """
    return self.values[key]
def set_value(self, key: str, value: Any) -> None:
    """Store ``value`` under ``key``, overwriting any previously registered value."""
    self.values[key] = value
def get_original_secret_value_or_none(self, secret_id: str) -> Any:
    """
    Get the original secret value from the secrets dict. If the secret id is not found, return None.

    ``secret_id`` is the opaque placeholder (format ``secret_<uuid>``, see
    generate_random_secret_id) that stands in for the real secret value
    elsewhere in the context.
    """
    return self.secrets.get(secret_id)
@staticmethod
def generate_random_secret_id() -> str:
    """Return a fresh opaque placeholder id of the form ``secret_<uuid4>``."""
    return f"secret_{uuid.uuid4()}"
async def register_parameter_value( async def register_parameter_value(
self, self,
aws_client: AsyncAWSClient,
parameter: PARAMETER_TYPE, parameter: PARAMETER_TYPE,
) -> None: ) -> None:
if parameter.parameter_type == ParameterType.WORKFLOW: if parameter.parameter_type == ParameterType.WORKFLOW:
@@ -42,15 +69,21 @@ class ContextManager:
f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}" f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}"
) )
elif parameter.parameter_type == ParameterType.AWS_SECRET: elif parameter.parameter_type == ParameterType.AWS_SECRET:
secret_value = await self.aws_client.get_secret(parameter.aws_key) # If the parameter is an AWS secret, fetch the secret value and store it in the secrets dict
# The value of the parameter will be the random secret id with format `secret_<uuid>`.
# We'll replace the random secret id with the actual secret value when we need to use it.
secret_value = await aws_client.get_secret(parameter.aws_key)
if secret_value is not None: if secret_value is not None:
self.values[parameter.key] = secret_value random_secret_id = self.generate_random_secret_id()
self.secrets[random_secret_id] = secret_value
self.values[parameter.key] = random_secret_id
else: else:
# ContextParameter values will be set within the blocks # ContextParameter values will be set within the blocks
return None return None
async def register_block_parameters( async def register_block_parameters(
self, self,
aws_client: AsyncAWSClient,
parameters: list[PARAMETER_TYPE], parameters: list[PARAMETER_TYPE],
) -> None: ) -> None:
for parameter in parameters: for parameter in parameters:
@@ -67,13 +100,41 @@ class ContextManager:
) )
self.parameters[parameter.key] = parameter self.parameters[parameter.key] = parameter
await self.register_parameter_value(parameter) await self.register_parameter_value(aws_client, parameter)
def get_parameter(self, key: str) -> Parameter:
return self.parameters[key]
def get_value(self, key: str) -> Any: class WorkflowContextManager:
return self.values[key] aws_client: AsyncAWSClient
workflow_run_contexts: dict[str, WorkflowRunContext]
def set_value(self, key: str, value: Any) -> None: parameters: dict[str, PARAMETER_TYPE]
self.values[key] = value values: dict[str, Any]
secrets: dict[str, Any]
def __init__(self) -> None:
    """Create the manager with a shared AWS client and no per-run contexts yet."""
    # One AWS client is shared across all workflow runs managed by this instance.
    self.aws_client = AsyncAWSClient()
    # Maps workflow_run_id -> WorkflowRunContext.
    self.workflow_run_contexts = {}
def _validate_workflow_run_context(self, workflow_run_id: str) -> None:
    """Ensure a context exists for ``workflow_run_id``.

    Raises:
        WorkflowRunContextNotInitialized: if the run was never initialized
            via initialize_workflow_run_context.
    """
    if workflow_run_id not in self.workflow_run_contexts:
        LOG.error(f"WorkflowRunContext not initialized for workflow run {workflow_run_id}")
        raise WorkflowRunContextNotInitialized(workflow_run_id=workflow_run_id)
def initialize_workflow_run_context(
    self, workflow_run_id: str, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]]
) -> WorkflowRunContext:
    """Create and register a new WorkflowRunContext for ``workflow_run_id``.

    NOTE(review): an existing context for the same id is silently replaced —
    confirm this is the intended behavior for re-runs.
    """
    workflow_run_context = WorkflowRunContext(workflow_parameter_tuples)
    self.workflow_run_contexts[workflow_run_id] = workflow_run_context
    return workflow_run_context
def get_workflow_run_context(self, workflow_run_id: str) -> WorkflowRunContext:
    """Return the context for ``workflow_run_id``.

    Raises:
        WorkflowRunContextNotInitialized: if the run was never initialized.
    """
    self._validate_workflow_run_context(workflow_run_id)
    return self.workflow_run_contexts[workflow_run_id]
async def register_block_parameters_for_workflow_run(
    self,
    workflow_run_id: str,
    parameters: list[PARAMETER_TYPE],
) -> None:
    """Register ``parameters`` on the run's context.

    Delegates to WorkflowRunContext.register_block_parameters, passing the
    manager's shared AWS client (used to resolve AWS-secret parameters).

    Raises:
        WorkflowRunContextNotInitialized: if the run was never initialized.
    """
    self._validate_workflow_run_context(workflow_run_id)
    await self.workflow_run_contexts[workflow_run_id].register_block_parameters(self.aws_client, parameters)

View File

@@ -13,7 +13,7 @@ from skyvern.exceptions import (
) )
from skyvern.forge import app from skyvern.forge import app
from skyvern.forge.sdk.schemas.tasks import TaskStatus from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.workflow.context_manager import ContextManager from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, ContextParameter, WorkflowParameter from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, ContextParameter, WorkflowParameter
LOG = structlog.get_logger() LOG = structlog.get_logger()
@@ -33,8 +33,12 @@ class Block(BaseModel, abc.ABC):
def get_subclasses(cls) -> tuple[type["Block"], ...]: def get_subclasses(cls) -> tuple[type["Block"], ...]:
return tuple(cls.__subclasses__()) return tuple(cls.__subclasses__())
@staticmethod
def get_workflow_run_context(workflow_run_id: str) -> WorkflowRunContext:
    """Look up this run's WorkflowRunContext via the app-global workflow context manager."""
    return app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
@abc.abstractmethod @abc.abstractmethod
async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any: async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
pass pass
@abc.abstractmethod @abc.abstractmethod
@@ -48,9 +52,12 @@ class TaskBlock(Block):
block_type: Literal[BlockType.TASK] = BlockType.TASK block_type: Literal[BlockType.TASK] = BlockType.TASK
url: str | None = None url: str | None = None
title: str = "Untitled Task"
navigation_goal: str | None = None navigation_goal: str | None = None
data_extraction_goal: str | None = None data_extraction_goal: str | None = None
data_schema: dict[str, Any] | None = None data_schema: dict[str, Any] | None = None
# error code to error description for the LLM
error_code_mapping: dict[str, str] | None = None
max_retries: int = 0 max_retries: int = 0
parameters: list[PARAMETER_TYPE] = [] parameters: list[PARAMETER_TYPE] = []
@@ -89,8 +96,8 @@ class TaskBlock(Block):
return order, retry + 1 return order, retry + 1
async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any: async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
task = None workflow_run_context = self.get_workflow_run_context(workflow_run_id)
current_retry = 0 current_retry = 0
# initial value for will_retry is True, so that the loop runs at least once # initial value for will_retry is True, so that the loop runs at least once
will_retry = True will_retry = True
@@ -104,7 +111,7 @@ class TaskBlock(Block):
task_block=self, task_block=self,
workflow=workflow, workflow=workflow,
workflow_run=workflow_run, workflow_run=workflow_run,
context_manager=context_manager, workflow_run_context=workflow_run_context,
task_order=task_order, task_order=task_order,
task_retry=task_retry, task_retry=task_retry,
) )
@@ -131,7 +138,18 @@ class TaskBlock(Block):
if self.url: if self.url:
await browser_state.page.goto(self.url) await browser_state.page.goto(self.url)
await app.agent.execute_step(organization=organization, task=task, step=step, workflow_run=workflow_run) try:
await app.agent.execute_step(organization=organization, task=task, step=step, workflow_run=workflow_run)
except Exception as e:
# Make sure the task is marked as failed in the database before raising the exception
await app.DATABASE.update_task(
task.task_id,
status=TaskStatus.failed,
organization_id=workflow.organization_id,
failure_reason=str(e),
)
raise e
# Check task status # Check task status
updated_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=workflow.organization_id) updated_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=workflow.organization_id)
if not updated_task: if not updated_task:
@@ -188,9 +206,9 @@ class ForLoopBlock(Block):
return context_parameters return context_parameters
def get_loop_over_parameter_values(self, context_manager: ContextManager) -> list[Any]: def get_loop_over_parameter_values(self, workflow_run_context: WorkflowRunContext) -> list[Any]:
if isinstance(self.loop_over, WorkflowParameter): if isinstance(self.loop_over, WorkflowParameter):
parameter_value = context_manager.get_value(self.loop_over.key) parameter_value = workflow_run_context.get_value(self.loop_over.key)
if isinstance(parameter_value, list): if isinstance(parameter_value, list):
return parameter_value return parameter_value
else: else:
@@ -200,8 +218,9 @@ class ForLoopBlock(Block):
# TODO (kerem): Implement this for context parameters # TODO (kerem): Implement this for context parameters
raise NotImplementedError raise NotImplementedError
async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any: async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
loop_over_values = self.get_loop_over_parameter_values(context_manager) workflow_run_context = self.get_workflow_run_context(workflow_run_id)
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
LOG.info( LOG.info(
f"Number of loop_over values: {len(loop_over_values)}", f"Number of loop_over values: {len(loop_over_values)}",
block_type=self.block_type, block_type=self.block_type,
@@ -211,8 +230,8 @@ class ForLoopBlock(Block):
for loop_over_value in loop_over_values: for loop_over_value in loop_over_values:
context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value) context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
for context_parameter in context_parameters_with_value: for context_parameter in context_parameters_with_value:
context_manager.set_value(context_parameter.key, context_parameter.value) workflow_run_context.set_value(context_parameter.key, context_parameter.value)
await self.loop_block.execute(workflow_run_id=workflow_run_id, context_manager=context_manager) await self.loop_block.execute(workflow_run_id=workflow_run_id)
return None return None

View File

@@ -72,3 +72,4 @@ class WorkflowRunStatusResponse(BaseModel):
parameters: dict[str, Any] parameters: dict[str, Any]
screenshot_urls: list[str] | None = None screenshot_urls: list[str] | None = None
recording_url: str | None = None recording_url: str | None = None
payload: dict[str, Any] | None = None

View File

@@ -1,6 +1,5 @@
import asyncio
import json import json
import time from collections import Counter
from datetime import datetime from datetime import datetime
import requests import requests
@@ -19,8 +18,8 @@ from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.schemas.tasks import Task from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.workflow.context_manager import ContextManager from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType from skyvern.forge.sdk.workflow.models.parameter import AWSSecretParameter, WorkflowParameter, WorkflowParameterType
from skyvern.forge.sdk.workflow.models.workflow import ( from skyvern.forge.sdk.workflow.models.workflow import (
Workflow, Workflow,
@@ -55,7 +54,6 @@ class WorkflowService:
:param max_steps_override: The max steps override for the workflow run, if any. :param max_steps_override: The max steps override for the workflow run, if any.
:return: The created workflow run. :return: The created workflow run.
""" """
LOG.info(f"Setting up workflow run for workflow {workflow_id}", workflow_id=workflow_id)
# Validate the workflow and the organization # Validate the workflow and the organization
workflow = await self.get_workflow(workflow_id=workflow_id) workflow = await self.get_workflow(workflow_id=workflow_id)
if workflow is None: if workflow is None:
@@ -83,9 +81,6 @@ class WorkflowService:
) )
) )
# Set workflow run status to running, create workflow run parameters
await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)
# Create all the workflow run parameters, AWSSecretParameter won't have workflow run parameters created. # Create all the workflow run parameters, AWSSecretParameter won't have workflow run parameters created.
all_workflow_parameters = await self.get_workflow_parameters(workflow_id=workflow.workflow_id) all_workflow_parameters = await self.get_workflow_parameters(workflow_id=workflow.workflow_id)
workflow_run_parameters = [] workflow_run_parameters = []
@@ -113,11 +108,6 @@ class WorkflowService:
workflow_run_parameters.append(workflow_run_parameter) workflow_run_parameters.append(workflow_run_parameter)
LOG.info(
f"Created workflow run parameters for workflow run {workflow_run.workflow_run_id}",
workflow_run_id=workflow_run.workflow_run_id,
)
return workflow_run return workflow_run
async def execute_workflow( async def execute_workflow(
@@ -129,59 +119,92 @@ class WorkflowService:
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id) workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id) workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id)
await app.BROWSER_MANAGER.get_or_create_for_workflow_run(workflow_run=workflow_run) # Set workflow run status to running, create workflow run parameters
await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)
# Get all <workflow parameter, workflow run parameter> tuples # Get all <workflow parameter, workflow run parameter> tuples
wp_wps_tuples = await self.get_workflow_run_parameter_tuples(workflow_run_id=workflow_run.workflow_run_id) wp_wps_tuples = await self.get_workflow_run_parameter_tuples(workflow_run_id=workflow_run.workflow_run_id)
# todo(kerem): do this in a better way (a shared context manager? (not really shared because we use batch job)) app.WORKFLOW_CONTEXT_MANAGER.initialize_workflow_run_context(workflow_run_id, wp_wps_tuples)
context_manager = ContextManager(wp_wps_tuples)
# Execute workflow blocks # Execute workflow blocks
blocks = workflow.workflow_definition.blocks blocks = workflow.workflow_definition.blocks
for block_idx, block in enumerate(blocks): try:
parameters = block.get_all_parameters() for block_idx, block in enumerate(blocks):
await context_manager.register_block_parameters(parameters) parameters = block.get_all_parameters()
LOG.info( await app.WORKFLOW_CONTEXT_MANAGER.register_block_parameters_for_workflow_run(
f"Executing root block {block.block_type} at index {block_idx} for workflow run {workflow_run.workflow_run_id}", workflow_run_id, parameters
block_type=block.block_type, )
LOG.info(
f"Executing root block {block.block_type} at index {block_idx} for workflow run {workflow_run_id}",
block_type=block.block_type,
workflow_run_id=workflow_run.workflow_run_id,
block_idx=block_idx,
)
await block.execute(workflow_run_id=workflow_run_id)
except Exception:
LOG.exception(
f"Error while executing workflow run {workflow_run.workflow_run_id}",
workflow_run_id=workflow_run.workflow_run_id, workflow_run_id=workflow_run.workflow_run_id,
block_idx=block_idx, exc_info=True,
) )
await block.execute(workflow_run_id=workflow_run.workflow_run_id, context_manager=context_manager)
# Get last task for workflow run tasks = await self.get_tasks_by_workflow_run_id(workflow_run.workflow_run_id)
task = await self.get_last_task_for_workflow_run(workflow_run_id=workflow_run.workflow_run_id) if not tasks:
if not task:
LOG.warning( LOG.warning(
f"No tasks found for workflow run {workflow_run.workflow_run_id}, not sending webhook", f"No tasks found for workflow run {workflow_run.workflow_run_id}, not sending webhook, marking as failed",
workflow_run_id=workflow_run.workflow_run_id, workflow_run_id=workflow_run.workflow_run_id,
) )
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
return workflow_run return workflow_run
# Update workflow status workflow_run = await self.handle_workflow_status(workflow_run=workflow_run, tasks=tasks)
if task.status == "completed":
await self.mark_workflow_run_as_completed(workflow_run_id=workflow_run.workflow_run_id)
elif task.status == "failed":
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
elif task.status == "terminated":
await self.mark_workflow_run_as_terminated(workflow_run_id=workflow_run.workflow_run_id)
else:
LOG.warning(
f"Task {task.task_id} has an incomplete status {task.status}, not updating workflow run status",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
task_id=task.task_id,
status=task.status,
workflow_run_status=workflow_run.status,
)
await self.send_workflow_response( await self.send_workflow_response(
workflow=workflow, workflow=workflow,
workflow_run=workflow_run, workflow_run=workflow_run,
tasks=tasks,
api_key=api_key, api_key=api_key,
last_task=task,
) )
return workflow_run return workflow_run
async def handle_workflow_status(self, workflow_run: WorkflowRun, tasks: list[Task]) -> WorkflowRun:
    """Derive and persist the workflow run's final status from its tasks' statuses.

    The loop below returns on the FIRST status in status_action_mapping with a
    non-zero task count, so dict insertion order encodes precedence:
    running > terminated > failed > completed. A run with any task in a higher-
    precedence status is classified by that status regardless of other tasks.

    Args:
        workflow_run: The workflow run whose status is being resolved.
        tasks: All tasks that belong to this workflow run.

    Returns:
        The (possibly status-updated) workflow run.
    """
    task_counts_by_status = Counter(task.status for task in tasks)

    # Create a mapping of status to (action, log_func, log_message)
    # NOTE: insertion order matters — see docstring. `running` maps to action
    # None: it is logged as an error but the run's status is left untouched.
    status_action_mapping = {
        TaskStatus.running: (None, LOG.error, "has running tasks, this should not happen"),
        TaskStatus.terminated: (
            self.mark_workflow_run_as_terminated,
            LOG.warning,
            "has terminated tasks, marking as terminated",
        ),
        TaskStatus.failed: (self.mark_workflow_run_as_failed, LOG.warning, "has failed tasks, marking as failed"),
        TaskStatus.completed: (
            self.mark_workflow_run_as_completed,
            LOG.info,
            "tasks are completed, marking as completed",
        ),
    }

    for status, (action, log_func, log_message) in status_action_mapping.items():
        if task_counts_by_status.get(status, 0) > 0:
            if action is not None:
                await action(workflow_run_id=workflow_run.workflow_run_id)
            if log_func and log_message:
                log_func(
                    f"Workflow run {workflow_run.workflow_run_id} {log_message}",
                    workflow_run_id=workflow_run.workflow_run_id,
                    task_counts_by_status=task_counts_by_status,
                )
            return workflow_run

    # Handle unexpected state
    # Reached only when no task is running/terminated/failed/completed
    # (e.g. all tasks are in some other, non-terminal status).
    LOG.error(
        f"Workflow run {workflow_run.workflow_run_id} has tasks in an unexpected state, marking as failed",
        workflow_run_id=workflow_run.workflow_run_id,
        task_counts_by_status=task_counts_by_status,
    )
    await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
    return workflow_run
async def create_workflow( async def create_workflow(
self, self,
organization_id: str, organization_id: str,
@@ -354,6 +377,15 @@ class WorkflowService:
workflow_parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id) workflow_parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
parameters_with_value = {wfp.key: wfrp.value for wfp, wfrp in workflow_parameter_tuples} parameters_with_value = {wfp.key: wfrp.value for wfp, wfrp in workflow_parameter_tuples}
payload = {
task.task_id: {
"title": task.title,
"extracted_information": task.extracted_information,
"navigation_payload": task.navigation_payload,
"errors": await app.agent.get_task_errors(task=task),
}
for task in workflow_run_tasks
}
return WorkflowRunStatusResponse( return WorkflowRunStatusResponse(
workflow_id=workflow_id, workflow_id=workflow_id,
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
@@ -365,50 +397,28 @@ class WorkflowService:
parameters=parameters_with_value, parameters=parameters_with_value,
screenshot_urls=screenshot_urls, screenshot_urls=screenshot_urls,
recording_url=recording_url, recording_url=recording_url,
payload=payload,
) )
async def send_workflow_response( async def send_workflow_response(
self, self,
workflow: Workflow, workflow: Workflow,
workflow_run: WorkflowRun, workflow_run: WorkflowRun,
last_task: Task, tasks: list[Task],
api_key: str | None = None, api_key: str | None = None,
close_browser_on_completion: bool = True, close_browser_on_completion: bool = True,
) -> None: ) -> None:
analytics.capture("skyvern-oss-agent-workflow-status", {"status": workflow_run.status}) analytics.capture("skyvern-oss-agent-workflow-status", {"status": workflow_run.status})
all_workflow_task_ids = [task.task_id for task in tasks]
browser_state = await app.BROWSER_MANAGER.cleanup_for_workflow_run( browser_state = await app.BROWSER_MANAGER.cleanup_for_workflow_run(
workflow_run.workflow_run_id, close_browser_on_completion workflow_run.workflow_run_id, all_workflow_task_ids, close_browser_on_completion
) )
if browser_state: if browser_state:
await self.persist_video_data(browser_state, workflow, workflow_run) await self.persist_video_data(browser_state, workflow, workflow_run)
await self.persist_har_data(browser_state, last_task, workflow, workflow_run) await self.persist_debug_artifacts(browser_state, tasks[-1], workflow, workflow_run)
# Wait for all tasks to complete before generating the links for the artifacts
all_workflow_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(
workflow_run_id=workflow_run.workflow_run_id
)
all_workflow_task_ids = [task.task_id for task in all_workflow_tasks]
await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_tasks(all_workflow_task_ids) await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_tasks(all_workflow_task_ids)
try:
# Wait for all tasks to complete. Currently we're using asyncio.create_task() only for uploading artifacts to S3.
# We're excluding the current task from the list of tasks to wait for to prevent a deadlock.
st = time.time()
async with asyncio.timeout(30):
await asyncio.gather(
*[aio_task for aio_task in (asyncio.all_tasks() - {asyncio.current_task()}) if not aio_task.done()]
)
LOG.info(
f"Waiting for all S3 uploads to complete took {time.time() - st} seconds",
duration=time.time() - st,
)
except asyncio.TimeoutError:
LOG.warning(
"Timed out waiting for all S3 uploads to complete, not all artifacts may be uploaded. Waited 30 seconds.",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
)
if not workflow_run.webhook_callback_url: if not workflow_run.webhook_callback_url:
LOG.warning( LOG.warning(
"Workflow has no webhook callback url. Not sending workflow response", "Workflow has no webhook callback url. Not sending workflow response",
@@ -493,19 +503,35 @@ class WorkflowService:
) )
async def persist_har_data( async def persist_har_data(
self, browser_state: BrowserState, last_task: Task, workflow: Workflow, workflow_run: WorkflowRun self, browser_state: BrowserState, last_step: Step, workflow: Workflow, workflow_run: WorkflowRun
) -> None: ) -> None:
har_data = await app.BROWSER_MANAGER.get_har_data( har_data = await app.BROWSER_MANAGER.get_har_data(
workflow_id=workflow.workflow_id, workflow_run_id=workflow_run.workflow_run_id, browser_state=browser_state workflow_id=workflow.workflow_id, workflow_run_id=workflow_run.workflow_run_id, browser_state=browser_state
) )
if har_data: if har_data:
last_step = await app.DATABASE.get_latest_step( await app.ARTIFACT_MANAGER.create_artifact(
task_id=last_task.task_id, organization_id=last_task.organization_id step=last_step,
artifact_type=ArtifactType.HAR,
data=har_data,
) )
if last_step: async def persist_tracing_data(
await app.ARTIFACT_MANAGER.create_artifact( self, browser_state: BrowserState, last_step: Step, workflow_run: WorkflowRun
step=last_step, ) -> None:
artifact_type=ArtifactType.HAR, if browser_state.browser_context is None or browser_state.browser_artifacts.traces_dir is None:
data=har_data, return
)
trace_path = f"{browser_state.browser_artifacts.traces_dir}/{workflow_run.workflow_run_id}.zip"
await app.ARTIFACT_MANAGER.create_artifact(step=last_step, artifact_type=ArtifactType.TRACE, path=trace_path)
async def persist_debug_artifacts(
self, browser_state: BrowserState, last_task: Task, workflow: Workflow, workflow_run: WorkflowRun
) -> None:
last_step = await app.DATABASE.get_latest_step(
task_id=last_task.task_id, organization_id=last_task.organization_id
)
if not last_step:
return
await self.persist_har_data(browser_state, last_step, workflow, workflow_run)
await self.persist_tracing_data(browser_state, last_step, workflow_run)

3
skyvern/scheduler.py Normal file
View File

@@ -0,0 +1,3 @@
from apscheduler.schedulers.asyncio import AsyncIOScheduler

# Module-level shared asyncio-based scheduler instance for the application.
SCHEDULER = AsyncIOScheduler()

View File

@@ -3,7 +3,7 @@ from enum import StrEnum
from typing import Any, Dict, List from typing import Any, Dict, List
import structlog import structlog
from pydantic import BaseModel from pydantic import BaseModel, Field
from skyvern.forge.sdk.schemas.tasks import Task from skyvern.forge.sdk.schemas.tasks import Task
@@ -34,6 +34,16 @@ class WebAction(Action, abc.ABC):
element_id: int element_id: int
class UserDefinedError(BaseModel):
    """An error reported by the LLM, identified by a task-defined error code."""

    # Error code string; presumably a key from the task's error_code_mapping — confirm with callers.
    error_code: str
    # The LLM's explanation of why this error applies.
    reasoning: str
    # Confidence in the reported error, constrained to the range [0, 1].
    confidence_float: float = Field(..., ge=0, le=1)
class DecisiveAction(Action, abc.ABC):
    """Abstract base for actions that conclude a task and may carry user-defined errors."""

    # NOTE: pydantic copies field defaults per instance, so the mutable [] default is safe here.
    errors: List[UserDefinedError] = []
class ClickAction(WebAction): class ClickAction(WebAction):
action_type: ActionType = ActionType.CLICK action_type: ActionType = ActionType.CLICK
file_url: str | None = None file_url: str | None = None
@@ -102,11 +112,11 @@ class WaitAction(Action):
action_type: ActionType = ActionType.WAIT action_type: ActionType = ActionType.WAIT
class TerminateAction(Action): class TerminateAction(DecisiveAction):
action_type: ActionType = ActionType.TERMINATE action_type: ActionType = ActionType.TERMINATE
class CompleteAction(Action): class CompleteAction(DecisiveAction):
action_type: ActionType = ActionType.COMPLETE action_type: ActionType = ActionType.COMPLETE
data_extraction_goal: str | None = None data_extraction_goal: str | None = None
@@ -129,7 +139,7 @@ def parse_actions(task: Task, json_response: List[Dict[str, Any]]) -> List[Actio
reasoning=reasoning, reasoning=reasoning,
actions=actions, actions=actions,
) )
actions.append(TerminateAction(reasoning=reasoning)) actions.append(TerminateAction(reasoning=reasoning, errors=action["errors"] if "errors" in action else []))
elif action_type == ActionType.CLICK: elif action_type == ActionType.CLICK:
file_url = action["file_url"] if "file_url" in action else None file_url = action["file_url"] if "file_url" in action else None
actions.append(ClickAction(element_id=element_id, reasoning=reasoning, file_url=file_url)) actions.append(ClickAction(element_id=element_id, reasoning=reasoning, file_url=file_url))
@@ -165,7 +175,13 @@ def parse_actions(task: Task, json_response: List[Dict[str, Any]]) -> List[Actio
actions=actions, actions=actions,
llm_response=json_response, llm_response=json_response,
) )
return [CompleteAction(reasoning=reasoning, data_extraction_goal=task.data_extraction_goal)] return [
CompleteAction(
reasoning=reasoning,
data_extraction_goal=task.data_extraction_goal,
errors=action["errors"] if "errors" in action else [],
)
]
elif action_type == "null": elif action_type == "null":
actions.append(NullAction(reasoning=reasoning)) actions.append(NullAction(reasoning=reasoning))
elif action_type == ActionType.SOLVE_CAPTCHA: elif action_type == ActionType.SOLVE_CAPTCHA:

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import json
import re import re
from typing import Awaitable, Callable, List from typing import Any, Awaitable, Callable, List
import structlog import structlog
from playwright.async_api import Locator, Page from playwright.async_api import Locator, Page
@@ -82,7 +83,9 @@ async def handle_click_action(
) -> list[ActionResult]: ) -> list[ActionResult]:
xpath = await validate_actions_in_dom(action, page, scraped_page) xpath = await validate_actions_in_dom(action, page, scraped_page)
await asyncio.sleep(0.3) await asyncio.sleep(0.3)
return await chain_click(page, action, xpath, timeout=SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS) return await chain_click(
task, page, action, xpath, timeout=SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS
)
async def handle_input_text_action( async def handle_input_text_action(
@@ -91,7 +94,8 @@ async def handle_input_text_action(
xpath = await validate_actions_in_dom(action, page, scraped_page) xpath = await validate_actions_in_dom(action, page, scraped_page)
locator = page.locator(f"xpath={xpath}") locator = page.locator(f"xpath={xpath}")
await locator.clear() await locator.clear()
await locator.fill(action.text, timeout=SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS) text = get_actual_value_of_parameter_if_secret(task, action.text)
await locator.fill(text, timeout=SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS)
# This is a hack that gets dropdowns to select the "best" option based on what's typed # This is a hack that gets dropdowns to select the "best" option based on what's typed
# Fixes situations like tsk_228671423990405776 where the location isn't being autocompleted # Fixes situations like tsk_228671423990405776 where the location isn't being autocompleted
@@ -100,7 +104,7 @@ async def handle_input_text_action(
if not input_value: if not input_value:
LOG.info("Failed to input the text, trying to press sequentially with an enter click", action=action) LOG.info("Failed to input the text, trying to press sequentially with an enter click", action=action)
await locator.clear() await locator.clear()
await locator.press_sequentially(action.text, timeout=SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS) await locator.press_sequentially(text, timeout=SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS)
await locator.press("Enter", timeout=SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS) await locator.press("Enter", timeout=SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS)
input_value = await locator.input_value(timeout=SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS) input_value = await locator.input_value(timeout=SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS)
LOG.info("Input value", input_value=input_value, action=action) LOG.info("Input value", input_value=input_value, action=action)
@@ -114,7 +118,12 @@ async def handle_upload_file_action(
if not action.file_url: if not action.file_url:
LOG.warning("InputFileAction has no file_url", action=action) LOG.warning("InputFileAction has no file_url", action=action)
return [ActionFailure(MissingFileUrl())] return [ActionFailure(MissingFileUrl())]
if action.file_url not in str(task.navigation_payload): # ************************************************************************************************************** #
# After this point if the file_url is a secret, it will be replaced with the actual value
# In order to make sure we don't log the secret value, we log the action with the original value action.file_url
# ************************************************************************************************************** #
file_url = get_actual_value_of_parameter_if_secret(task, action.file_url)
if file_url not in str(task.navigation_payload):
LOG.warning( LOG.warning(
"LLM might be imagining the file url, which is not in navigation payload", "LLM might be imagining the file url, which is not in navigation payload",
action=action, action=action,
@@ -122,7 +131,7 @@ async def handle_upload_file_action(
) )
return [ActionFailure(ImaginaryFileUrl(action.file_url))] return [ActionFailure(ImaginaryFileUrl(action.file_url))]
xpath = await validate_actions_in_dom(action, page, scraped_page) xpath = await validate_actions_in_dom(action, page, scraped_page)
file_path = download_file(action.file_url) file_path = download_file(file_url)
locator = page.locator(f"xpath={xpath}") locator = page.locator(f"xpath={xpath}")
is_file_input = await is_file_input_element(locator) is_file_input = await is_file_input_element(locator)
if is_file_input: if is_file_input:
@@ -141,7 +150,9 @@ async def handle_upload_file_action(
LOG.info("Taking UploadFileAction. Found non file input tag", action=action) LOG.info("Taking UploadFileAction. Found non file input tag", action=action)
# treat it as a click action # treat it as a click action
action.is_upload_file_tag = False action.is_upload_file_tag = False
return await chain_click(page, action, xpath, timeout=SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS) return await chain_click(
task, page, action, xpath, timeout=SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS
)
async def handle_null_action( async def handle_null_action(
@@ -189,7 +200,7 @@ async def handle_select_option_action(
child_anchor_xpath=child_anchor_xpath, child_anchor_xpath=child_anchor_xpath,
) )
click_action = ClickAction(element_id=action.element_id) click_action = ClickAction(element_id=action.element_id)
return await chain_click(page, click_action, child_anchor_xpath) return await chain_click(task, page, click_action, child_anchor_xpath)
return [ActionFailure(Exception("No anchor tag found for the label for SelectOptionAction"))] return [ActionFailure(Exception("No anchor tag found for the label for SelectOptionAction"))]
elif tag_name == "a": elif tag_name == "a":
# turn the SelectOptionAction into a ClickAction # turn the SelectOptionAction into a ClickAction
@@ -198,7 +209,7 @@ async def handle_select_option_action(
action=action, action=action,
) )
click_action = ClickAction(element_id=action.element_id) click_action = ClickAction(element_id=action.element_id)
action_result = await chain_click(page, click_action, xpath) action_result = await chain_click(task, page, click_action, xpath)
return action_result return action_result
elif tag_name == "ul" or tag_name == "div" or tag_name == "li": elif tag_name == "ul" or tag_name == "div" or tag_name == "li":
# if the role is listbox, find the option with the "label" or "value" and click that option element # if the role is listbox, find the option with the "label" or "value" and click that option element
@@ -234,7 +245,7 @@ async def handle_select_option_action(
) )
# click the option element # click the option element
click_action = ClickAction(element_id=action.element_id) click_action = ClickAction(element_id=action.element_id)
return await chain_click(page, click_action, xpath) return await chain_click(task, page, click_action, xpath)
else: else:
LOG.error( LOG.error(
"SelectOptionAction on a non-listbox element. Cannot handle this action", "SelectOptionAction on a non-listbox element. Cannot handle this action",
@@ -349,6 +360,22 @@ ActionHandler.register_action_type(ActionType.TERMINATE, handle_terminate_action
ActionHandler.register_action_type(ActionType.COMPLETE, handle_complete_action) ActionHandler.register_action_type(ActionType.COMPLETE, handle_complete_action)
def get_actual_value_of_parameter_if_secret(task: Task, parameter: str) -> Any:
"""
Get the actual value of a parameter if it's a secret. If it's not a secret, return the parameter value as is.
Just return the parameter value if the task isn't a workflow's task.
This is only used for InputTextAction, UploadFileAction, and ClickAction (if it has a file_url).
"""
if task.workflow_run_id is None:
return parameter
workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(task.workflow_run_id)
secret_value = workflow_run_context.get_original_secret_value_or_none(parameter)
return secret_value if secret_value is not None else parameter
async def validate_actions_in_dom(action: WebAction, page: Page, scraped_page: ScrapedPage) -> str: async def validate_actions_in_dom(action: WebAction, page: Page, scraped_page: ScrapedPage) -> str:
xpath = scraped_page.id_to_xpath_dict[action.element_id] xpath = scraped_page.id_to_xpath_dict[action.element_id]
locator = page.locator(xpath) locator = page.locator(xpath)
@@ -371,6 +398,7 @@ async def validate_actions_in_dom(action: WebAction, page: Page, scraped_page: S
async def chain_click( async def chain_click(
task: Task,
page: Page, page: Page,
action: ClickAction | UploadFileAction, action: ClickAction | UploadFileAction,
xpath: str, xpath: str,
@@ -384,7 +412,8 @@ async def chain_click(
LOG.info("Chain click starts", action=action, xpath=xpath) LOG.info("Chain click starts", action=action, xpath=xpath)
file: list[str] | str = [] file: list[str] | str = []
if action.file_url: if action.file_url:
file = download_file(action.file_url) or [] file_url = get_actual_value_of_parameter_if_secret(task, action.file_url)
file = download_file(file_url) or []
fc_func = lambda fc: fc.set_files(files=file) fc_func = lambda fc: fc.set_files(files=file)
page.on("filechooser", fc_func) page.on("filechooser", fc_func)
@@ -535,11 +564,13 @@ async def extract_information_for_navigation_goal(
extract_information_prompt = prompt_engine.load_prompt( extract_information_prompt = prompt_engine.load_prompt(
prompt_template, prompt_template,
navigation_goal=task.navigation_goal, navigation_goal=task.navigation_goal,
navigation_payload=task.navigation_payload,
elements=scraped_page.element_tree, elements=scraped_page.element_tree,
data_extraction_goal=task.data_extraction_goal, data_extraction_goal=task.data_extraction_goal,
extracted_information_schema=task.extracted_information_schema, extracted_information_schema=task.extracted_information_schema,
current_url=scraped_page.url, current_url=scraped_page.url,
extracted_text=scraped_page.extracted_text, extracted_text=scraped_page.extracted_text,
error_code_mapping_str=json.dumps(task.error_code_mapping) if task.error_code_mapping else None,
) )
json_response = await app.OPENAI_CLIENT.chat_completion( json_response = await app.OPENAI_CLIENT.chat_completion(

View File

@@ -5,7 +5,7 @@ from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.webeye.actions.actions import Action, ActionTypeUnion from skyvern.webeye.actions.actions import Action, ActionTypeUnion, DecisiveAction, UserDefinedError
from skyvern.webeye.actions.responses import ActionResult from skyvern.webeye.actions.responses import ActionResult
from skyvern.webeye.scraper.scraper import ScrapedPage from skyvern.webeye.scraper.scraper import ScrapedPage
@@ -19,6 +19,7 @@ class AgentStepOutput(BaseModel):
action_results: list[ActionResult] | None = None action_results: list[ActionResult] | None = None
# Nullable for backwards compatibility, once backfill is done, this won't be nullable anymore # Nullable for backwards compatibility, once backfill is done, this won't be nullable anymore
actions_and_results: list[tuple[ActionTypeUnion, list[ActionResult]]] | None = None actions_and_results: list[tuple[ActionTypeUnion, list[ActionResult]]] | None = None
errors: list[UserDefinedError] = []
def __repr__(self) -> str: def __repr__(self) -> str:
return f"AgentStepOutput({self.model_dump()})" return f"AgentStepOutput({self.model_dump()})"
@@ -51,8 +52,17 @@ class DetailedAgentStepOutput(BaseModel):
def __str__(self) -> str: def __str__(self) -> str:
return self.__repr__() return self.__repr__()
def extract_errors(self) -> list[UserDefinedError]:
errors = []
if self.actions_and_results:
for action, action_results in self.actions_and_results:
if isinstance(action, DecisiveAction):
errors.extend(action.errors)
return errors
def to_agent_step_output(self) -> AgentStepOutput: def to_agent_step_output(self) -> AgentStepOutput:
return AgentStepOutput( return AgentStepOutput(
action_results=self.action_results if self.action_results else [], action_results=self.action_results if self.action_results else [],
actions_and_results=self.actions_and_results if self.actions_and_results else [], actions_and_results=self.actions_and_results if self.actions_and_results else [],
errors=self.extract_errors(),
) )

View File

@@ -6,10 +6,17 @@ from datetime import datetime
from typing import Any, Awaitable, Protocol from typing import Any, Awaitable, Protocol
import structlog import structlog
from playwright._impl._errors import TimeoutError
from playwright.async_api import BrowserContext, Error, Page, Playwright, async_playwright from playwright.async_api import BrowserContext, Error, Page, Playwright, async_playwright
from pydantic import BaseModel from pydantic import BaseModel
from skyvern.exceptions import FailedToNavigateToUrl, UnknownBrowserType, UnknownErrorWhileCreatingBrowserContext from skyvern.exceptions import (
FailedToNavigateToUrl,
FailedToTakeScreenshot,
MissingBrowserStatePage,
UnknownBrowserType,
UnknownErrorWhileCreatingBrowserContext,
)
from skyvern.forge.sdk.core.skyvern_context import current from skyvern.forge.sdk.core.skyvern_context import current
from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.settings_manager import SettingsManager
@@ -58,9 +65,14 @@ class BrowserContextFactory:
@staticmethod @staticmethod
def build_browser_artifacts( def build_browser_artifacts(
video_path: str | None = None, har_path: str | None = None, video_artifact_id: str | None = None video_path: str | None = None,
har_path: str | None = None,
video_artifact_id: str | None = None,
traces_dir: str | None = None,
) -> BrowserArtifacts: ) -> BrowserArtifacts:
return BrowserArtifacts(video_path=video_path, har_path=har_path, video_artifact_id=video_artifact_id) return BrowserArtifacts(
video_path=video_path, har_path=har_path, video_artifact_id=video_artifact_id, traces_dir=traces_dir
)
@classmethod @classmethod
def register_type(cls, browser_type: str, creator: BrowserContextCreator) -> None: def register_type(cls, browser_type: str, creator: BrowserContextCreator) -> None:
@@ -86,6 +98,7 @@ class BrowserArtifacts(BaseModel):
video_path: str | None = None video_path: str | None = None
video_artifact_id: str | None = None video_artifact_id: str | None = None
har_path: str | None = None har_path: str | None = None
traces_dir: str | None = None
async def _create_headless_chromium(playwright: Playwright, **kwargs: dict) -> tuple[BrowserContext, BrowserArtifacts]: async def _create_headless_chromium(playwright: Playwright, **kwargs: dict) -> tuple[BrowserContext, BrowserArtifacts]:
@@ -180,3 +193,26 @@ class BrowserState:
LOG.info("Stopping playwright") LOG.info("Stopping playwright")
await self.pw.stop() await self.pw.stop()
LOG.info("Playwright is stopped") LOG.info("Playwright is stopped")
async def take_screenshot(self, full_page: bool = False, file_path: str | None = None) -> bytes:
if not self.page:
LOG.error("BrowserState has no page")
raise MissingBrowserStatePage()
try:
if file_path:
return await self.page.screenshot(
path=file_path,
full_page=full_page,
timeout=SettingsManager.get_settings().BROWSER_SCREENSHOT_TIMEOUT_MS,
)
return await self.page.screenshot(
full_page=full_page,
timeout=SettingsManager.get_settings().BROWSER_SCREENSHOT_TIMEOUT_MS,
animations="disabled",
)
except TimeoutError as e:
LOG.exception(f"Timeout error while taking screenshot: {str(e)}", exc_info=True)
raise FailedToTakeScreenshot(error_message=str(e)) from e
except Exception as e:
LOG.exception(f"Unknown error while taking screenshot: {str(e)}", exc_info=True)
raise FailedToTakeScreenshot(error_message=str(e)) from e

View File

@@ -50,6 +50,8 @@ class BrowserManager:
await browser_state.get_or_create_page(task.url) await browser_state.get_or_create_page(task.url)
self.pages[task.task_id] = browser_state self.pages[task.task_id] = browser_state
if task.workflow_run_id:
self.pages[task.workflow_run_id] = browser_state
return browser_state return browser_state
async def get_or_create_for_workflow_run(self, workflow_run: WorkflowRun, url: str | None = None) -> BrowserState: async def get_or_create_for_workflow_run(self, workflow_run: WorkflowRun, url: str | None = None) -> BrowserState:
@@ -95,8 +97,11 @@ class BrowserManager:
if browser_state: if browser_state:
path = browser_state.browser_artifacts.video_path path = browser_state.browser_artifacts.video_path
if path: if path:
with open(path, "rb") as f: try:
return f.read() with open(path, "rb") as f:
return f.read()
except FileNotFoundError:
pass
LOG.warning( LOG.warning(
"Video data not found for task", task_id=task_id, workflow_id=workflow_id, workflow_run_id=workflow_run_id "Video data not found for task", task_id=task_id, workflow_id=workflow_id, workflow_run_id=workflow_run_id
) )
@@ -135,18 +140,32 @@ class BrowserManager:
LOG.info("Cleaning up for task") LOG.info("Cleaning up for task")
browser_state_to_close = self.pages.pop(task_id, None) browser_state_to_close = self.pages.pop(task_id, None)
if browser_state_to_close: if browser_state_to_close:
# Stop tracing before closing the browser if tracing is enabled
if browser_state_to_close.browser_context and browser_state_to_close.browser_artifacts.traces_dir:
trace_path = f"{browser_state_to_close.browser_artifacts.traces_dir}/{task_id}.zip"
await browser_state_to_close.browser_context.tracing.stop(path=trace_path)
LOG.info("Stopped tracing", trace_path=trace_path)
await browser_state_to_close.close(close_browser_on_completion=close_browser_on_completion) await browser_state_to_close.close(close_browser_on_completion=close_browser_on_completion)
LOG.info("Task is cleaned up") LOG.info("Task is cleaned up")
return browser_state_to_close return browser_state_to_close
async def cleanup_for_workflow_run( async def cleanup_for_workflow_run(
self, workflow_run_id: str, close_browser_on_completion: bool = True self, workflow_run_id: str, task_ids: list[str], close_browser_on_completion: bool = True
) -> BrowserState | None: ) -> BrowserState | None:
LOG.info("Cleaning up for workflow run") LOG.info("Cleaning up for workflow run")
browser_state_to_close = self.pages.pop(workflow_run_id, None) browser_state_to_close = self.pages.pop(workflow_run_id, None)
if browser_state_to_close: if browser_state_to_close:
# Stop tracing before closing the browser if tracing is enabled
if browser_state_to_close.browser_context and browser_state_to_close.browser_artifacts.traces_dir:
trace_path = f"{browser_state_to_close.browser_artifacts.traces_dir}/{workflow_run_id}.zip"
await browser_state_to_close.browser_context.tracing.stop(path=trace_path)
LOG.info("Stopped tracing", trace_path=trace_path)
await browser_state_to_close.close(close_browser_on_completion=close_browser_on_completion) await browser_state_to_close.close(close_browser_on_completion=close_browser_on_completion)
for task_id in task_ids:
self.pages.pop(task_id, None)
LOG.info("Workflow run is cleaned up") LOG.info("Workflow run is cleaned up")
return browser_state_to_close return browser_state_to_close

View File

@@ -170,7 +170,7 @@ async def scrape_web_unsafe(
scroll_y_px = await scroll_to_top(page, drow_boxes=True) scroll_y_px = await scroll_to_top(page, drow_boxes=True)
# Checking max number of screenshots to prevent infinite loop # Checking max number of screenshots to prevent infinite loop
while scroll_y_px_old != scroll_y_px and len(screenshots) < SettingsManager.get_settings().MAX_NUM_SCREENSHOTS: while scroll_y_px_old != scroll_y_px and len(screenshots) < SettingsManager.get_settings().MAX_NUM_SCREENSHOTS:
screenshot = await page.screenshot(full_page=False) screenshot = await browser_state.take_screenshot(full_page=False)
screenshots.append(screenshot) screenshots.append(screenshot)
scroll_y_px_old = scroll_y_px scroll_y_px_old = scroll_y_px
LOG.info("Scrolling to next page", url=url, num_screenshots=len(screenshots)) LOG.info("Scrolling to next page", url=url, num_screenshots=len(screenshots))
@@ -348,9 +348,10 @@ def _build_element_links(elements: list[dict]) -> None:
listbox_text = element["text"] if "text" in element else "" listbox_text = element["text"] if "text" in element else ""
# WARNING: If a listbox has really little commont content (yes/no, etc.), # WARNING: If a listbox has really little commont content (yes/no, etc.),
# it might have conflict and will connect to wrong element. If so, code should be added to prevent that: # it might have conflict and will connect to wrong element
# if len(listbox_text) < 10: # if len(listbox_text) < 10:
# # do not support small listbox text as it's error proning. larger text match is more reliable # # do not support small listbox text for now as it's error proning. larger text match is more reliable
# LOG.info("Skip because too short listbox text", listbox_text=listbox_text)
# continue # continue
for text, linked_elements in text_to_elements_map.items(): for text, linked_elements in text_to_elements_map.items():
@@ -369,7 +370,6 @@ def _build_element_links(elements: list[dict]) -> None:
for context, linked_elements in context_to_elements_map.items(): for context, linked_elements in context_to_elements_map.items():
if listbox_text in context: if listbox_text in context:
for linked_element in linked_elements: for linked_element in linked_elements:
# if _ensure_nearby_rects(element["rect"], linked_element["rect"]):
if linked_element["id"] != element["id"]: if linked_element["id"] != element["id"]:
LOG.info( LOG.info(
"Match listbox to target element context", "Match listbox to target element context",