disable complete verification when CUA engine (#2728)

Co-authored-by: lawyzheng <lawyzheng1106@gmail.com>
This commit is contained in:
Shuchang Zheng
2025-06-17 00:25:58 -07:00
committed by GitHub
parent b241185aae
commit f1bc1a03db
9 changed files with 107 additions and 13 deletions

View File

@@ -0,0 +1,33 @@
"""add_engine_to_workflow_run_block
Revision ID: 2be3e0ba85ff
Revises: 2c6b27e8e961
Create Date: 2025-06-17 07:23:13.753617+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "2be3e0ba85ff"
down_revision: Union[str, None] = "2c6b27e8e961"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("workflow_run_blocks", sa.Column("engine", sa.String(), nullable=True))
op.create_index(op.f("ix_workflow_run_blocks_task_id"), "workflow_run_blocks", ["task_id"], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_workflow_run_blocks_task_id"), table_name="workflow_run_blocks")
op.drop_column("workflow_run_blocks", "engine")
# ### end Alembic commands ###

View File

@@ -73,8 +73,9 @@ from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, Tas
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.models.block import ActionBlock, BaseTaskBlock, ValidationBlock from skyvern.forge.sdk.workflow.models.block import ActionBlock, BaseTaskBlock, ValidationBlock
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus
from skyvern.schemas.runs import CUA_ENGINES, CUA_RUN_TYPES, RunEngine from skyvern.schemas.runs import CUA_ENGINES, RunEngine
from skyvern.services import run_service from skyvern.services import run_service
from skyvern.services.task_v1_service import is_cua_task
from skyvern.utils.image_resizer import Resolution from skyvern.utils.image_resizer import Resolution
from skyvern.utils.prompt_engine import load_prompt_with_elements from skyvern.utils.prompt_engine import load_prompt_with_elements
from skyvern.webeye.actions.action_types import ActionType from skyvern.webeye.actions.action_types import ActionType
@@ -268,6 +269,12 @@ class ForgeAgent:
cua_response: OpenAIResponse | None = None, cua_response: OpenAIResponse | None = None,
llm_caller: LLMCaller | None = None, llm_caller: LLMCaller | None = None,
) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]: ) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]:
# do not need to do complete verification when it's a CUA task
# 1. CUA executes only one action step by step -- it's pretty less likely to have a hallucination for completion or forget to return a complete
# 2. It will significantly slow down CUA tasks
if engine in CUA_ENGINES:
complete_verification = False
workflow_run: WorkflowRun | None = None workflow_run: WorkflowRun | None = None
if task.workflow_run_id: if task.workflow_run_id:
workflow_run = await app.DATABASE.get_workflow_run( workflow_run = await app.DATABASE.get_workflow_run(
@@ -1575,10 +1582,9 @@ class ForgeAgent:
step_id=step.step_id, step_id=step.step_id,
workflow_run_id=task.workflow_run_id, workflow_run_id=task.workflow_run_id,
) )
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
scroll = True scroll = True
llm_key_override = task.llm_key llm_key_override = task.llm_key
if run_obj and run_obj.task_run_type in CUA_RUN_TYPES: if await is_cua_task(task=task):
scroll = False scroll = False
llm_key_override = None llm_key_override = None
@@ -2628,9 +2634,8 @@ class ForgeAgent:
step_result["actions_result"] = action_result_summary step_result["actions_result"] = action_result_summary
steps_results.append(step_result) steps_results.append(step_result)
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
scroll = True scroll = True
if run_obj and run_obj.task_run_type in CUA_RUN_TYPES: if await is_cua_task(task=task):
scroll = False scroll = False
screenshots: list[bytes] = [] screenshots: list[bytes] = []
@@ -2880,8 +2885,7 @@ class ForgeAgent:
expire_verification_code=True, expire_verification_code=True,
) )
llm_key_override = task.llm_key llm_key_override = task.llm_key
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id) if await is_cua_task(task=task):
if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
llm_key_override = None llm_key_override = None
llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler( llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
llm_key_override, default=app.LLM_API_HANDLER llm_key_override, default=app.LLM_API_HANDLER

View File

@@ -95,7 +95,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRunStatus, WorkflowRunStatus,
WorkflowStatus, WorkflowStatus,
) )
from skyvern.schemas.runs import ProxyLocation, RunType from skyvern.schemas.runs import ProxyLocation, RunEngine, RunType
from skyvern.webeye.actions.actions import Action from skyvern.webeye.actions.actions import Action
from skyvern.webeye.actions.models import AgentStepOutput from skyvern.webeye.actions.models import AgentStepOutput
@@ -2707,6 +2707,7 @@ class AgentDB:
status: BlockStatus = BlockStatus.running, status: BlockStatus = BlockStatus.running,
output: dict | list | str | None = None, output: dict | list | str | None = None,
continue_on_failure: bool = False, continue_on_failure: bool = False,
engine: RunEngine | None = None,
) -> WorkflowRunBlock: ) -> WorkflowRunBlock:
async with self.Session() as session: async with self.Session() as session:
new_workflow_run_block = WorkflowRunBlockModel( new_workflow_run_block = WorkflowRunBlockModel(
@@ -2719,6 +2720,7 @@ class AgentDB:
status=status, status=status,
output=output, output=output,
continue_on_failure=continue_on_failure, continue_on_failure=continue_on_failure,
engine=engine,
) )
session.add(new_workflow_run_block) session.add(new_workflow_run_block)
await session.commit() await session.commit()
@@ -2759,6 +2761,7 @@ class AgentDB:
wait_sec: int | None = None, wait_sec: int | None = None,
description: str | None = None, description: str | None = None,
block_workflow_run_id: str | None = None, block_workflow_run_id: str | None = None,
engine: str | None = None,
) -> WorkflowRunBlock: ) -> WorkflowRunBlock:
async with self.Session() as session: async with self.Session() as session:
workflow_run_block = ( workflow_run_block = (
@@ -2799,6 +2802,8 @@ class AgentDB:
workflow_run_block.description = description workflow_run_block.description = description
if block_workflow_run_id: if block_workflow_run_id:
workflow_run_block.block_workflow_run_id = block_workflow_run_id workflow_run_block.block_workflow_run_id = block_workflow_run_id
if engine:
workflow_run_block.engine = engine
await session.commit() await session.commit()
await session.refresh(workflow_run_block) await session.refresh(workflow_run_block)
else: else:
@@ -2830,6 +2835,25 @@ class AgentDB:
return convert_to_workflow_run_block(workflow_run_block, task=task) return convert_to_workflow_run_block(workflow_run_block, task=task)
raise NotFoundError(f"WorkflowRunBlock {workflow_run_block_id} not found") raise NotFoundError(f"WorkflowRunBlock {workflow_run_block_id} not found")
async def get_workflow_run_block_by_task_id(
self,
task_id: str,
organization_id: str | None = None,
) -> WorkflowRunBlock:
async with self.Session() as session:
workflow_run_block = (
await session.scalars(
select(WorkflowRunBlockModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
)
).first()
if workflow_run_block:
task = None
task_id = workflow_run_block.task_id
if task_id:
task = await self.get_task(task_id, organization_id=organization_id)
return convert_to_workflow_run_block(workflow_run_block, task=task)
raise NotFoundError(f"WorkflowRunBlock not found by {task_id}")
async def get_workflow_run_blocks( async def get_workflow_run_blocks(
self, self,
workflow_run_id: str, workflow_run_id: str,

View File

@@ -573,13 +573,14 @@ class WorkflowRunBlockModel(Base):
parent_workflow_run_block_id = Column(String, nullable=True) parent_workflow_run_block_id = Column(String, nullable=True)
organization_id = Column(String, nullable=True) organization_id = Column(String, nullable=True)
description = Column(String, nullable=True) description = Column(String, nullable=True)
task_id = Column(String, nullable=True) task_id = Column(String, index=True, nullable=True)
label = Column(String, nullable=True) label = Column(String, nullable=True)
block_type = Column(String, nullable=False) block_type = Column(String, nullable=False)
status = Column(String, nullable=False) status = Column(String, nullable=False)
output = Column(JSON, nullable=True) output = Column(JSON, nullable=True)
continue_on_failure = Column(Boolean, nullable=False, default=False) continue_on_failure = Column(Boolean, nullable=False, default=False)
failure_reason = Column(String, nullable=True) failure_reason = Column(String, nullable=True)
engine = Column(String, nullable=True)
# for loop block # for loop block
loop_values = Column(JSON, nullable=True) loop_values = Column(JSON, nullable=True)

View File

@@ -460,6 +460,7 @@ def convert_to_workflow_run_block(
output=workflow_run_block_model.output, output=workflow_run_block_model.output,
continue_on_failure=workflow_run_block_model.continue_on_failure, continue_on_failure=workflow_run_block_model.continue_on_failure,
failure_reason=workflow_run_block_model.failure_reason, failure_reason=workflow_run_block_model.failure_reason,
engine=workflow_run_block_model.engine,
task_id=workflow_run_block_model.task_id, task_id=workflow_run_block_model.task_id,
loop_values=workflow_run_block_model.loop_values, loop_values=workflow_run_block_model.loop_values,
current_value=workflow_run_block_model.current_value, current_value=workflow_run_block_model.current_value,

View File

@@ -8,6 +8,7 @@ from pydantic import BaseModel
from skyvern.forge.sdk.schemas.task_v2 import Thought from skyvern.forge.sdk.schemas.task_v2 import Thought
from skyvern.forge.sdk.workflow.models.block import BlockType from skyvern.forge.sdk.workflow.models.block import BlockType
from skyvern.schemas.runs import RunEngine
from skyvern.webeye.actions.actions import Action from skyvern.webeye.actions.actions import Action
@@ -24,6 +25,7 @@ class WorkflowRunBlock(BaseModel):
output: dict | list | str | None = None output: dict | list | str | None = None
continue_on_failure: bool = False continue_on_failure: bool = False
failure_reason: str | None = None failure_reason: str | None = None
engine: RunEngine | None = None
task_id: str | None = None task_id: str | None = None
url: str | None = None url: str | None = None
navigation_goal: str | None = None navigation_goal: str | None = None

View File

@@ -288,7 +288,11 @@ class Block(BaseModel, abc.ABC):
**kwargs: dict, **kwargs: dict,
) -> BlockResult: ) -> BlockResult:
workflow_run_block_id = None workflow_run_block_id = None
engine: RunEngine | None = None
try: try:
if isinstance(self, BaseTaskBlock):
engine = self.engine
workflow_run_block = await app.DATABASE.create_workflow_run_block( workflow_run_block = await app.DATABASE.create_workflow_run_block(
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
organization_id=organization_id, organization_id=organization_id,
@@ -296,6 +300,7 @@ class Block(BaseModel, abc.ABC):
label=self.label, label=self.label,
block_type=self.block_type, block_type=self.block_type,
continue_on_failure=self.continue_on_failure, continue_on_failure=self.continue_on_failure,
engine=engine,
) )
workflow_run_block_id = workflow_run_block.workflow_run_block_id workflow_run_block_id = workflow_run_block.workflow_run_block_id

View File

@@ -14,7 +14,7 @@ from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration, TaskGenerationBase from skyvern.forge.sdk.schemas.task_generations import TaskGeneration, TaskGenerationBase
from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, TaskStatus from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, TaskStatus
from skyvern.schemas.runs import RunEngine, RunType from skyvern.schemas.runs import CUA_ENGINES, CUA_RUN_TYPES, RunEngine, RunType
LOG = structlog.get_logger() LOG = structlog.get_logger()
@@ -148,3 +148,28 @@ async def get_task_v1_response(task_id: str, organization_id: str | None = None)
return await app.agent.build_task_response( return await app.agent.build_task_response(
task=task_obj, last_step=latest_step, failure_reason=failure_reason, need_browser_log=True task=task_obj, last_step=latest_step, failure_reason=failure_reason, need_browser_log=True
) )
async def is_cua_task(
*,
task: Task,
) -> bool:
"""Return True if the run, engine, or task indicates a CUA task."""
if task.workflow_run_id:
# it's a task based block, should look up the block run to see if it's a CUA task
block = await app.DATABASE.get_workflow_run_block_by_task_id(
task_id=task.task_id,
organization_id=task.organization_id,
)
if block.engine is not None and block.engine in CUA_ENGINES:
return True
run = await app.DATABASE.get_run(
run_id=task.task_id,
organization_id=task.organization_id,
)
if run and run.task_run_type in CUA_RUN_TYPES:
return True
return False

View File

@@ -71,7 +71,7 @@ from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.tasks import Task from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.services.bitwarden import BitwardenConstants from skyvern.forge.sdk.services.bitwarden import BitwardenConstants
from skyvern.forge.sdk.services.credentials import OnePasswordConstants from skyvern.forge.sdk.services.credentials import OnePasswordConstants
from skyvern.schemas.runs import CUA_RUN_TYPES from skyvern.services.task_v1_service import is_cua_task
from skyvern.utils.prompt_engine import CheckPhoneNumberFormatResponse, load_prompt_with_elements from skyvern.utils.prompt_engine import CheckPhoneNumberFormatResponse, load_prompt_with_elements
from skyvern.webeye.actions import actions, handler_utils from skyvern.webeye.actions import actions, handler_utils
from skyvern.webeye.actions.action_types import ActionType from skyvern.webeye.actions.action_types import ActionType
@@ -3377,9 +3377,8 @@ async def extract_information_for_navigation_goal(
local_datetime=datetime.now(context.tz_info).isoformat(), local_datetime=datetime.now(context.tz_info).isoformat(),
) )
task_run = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
llm_key_override = task.llm_key llm_key_override = task.llm_key
if task_run and task_run.task_run_type in CUA_RUN_TYPES: if await is_cua_task(task=task):
# CUA tasks should use the default data extraction llm key # CUA tasks should use the default data extraction llm key
llm_key_override = None llm_key_override = None