parallelize goal check within task (#3997)

Author: pedrohsdb
Date: 2025-11-13 17:18:32 -08:00
Committed by: GitHub
Parent: a95837783a
Commit: b7e28b075c
5 changed files with 675 additions and 330 deletions

View File

@@ -82,6 +82,11 @@ class MissingElement(SkyvernException):
         )


+class MissingExtractActionsResponse(SkyvernException):
+    def __init__(self) -> None:
+        super().__init__("extract-actions response missing")
+
+
 class MultipleElementsFound(SkyvernException):
     def __init__(self, num: int, selector: str | None = None, element_id: str | None = None):
         super().__init__(

View File

@@ -6,6 +6,7 @@ import random
 import re
 import string
 from asyncio.exceptions import CancelledError
+from dataclasses import dataclass
 from datetime import UTC, datetime
 from pathlib import Path
 from typing import Any, Tuple, cast
@@ -48,6 +49,7 @@ from skyvern.exceptions import (
     InvalidTaskStatusTransition,
     InvalidWorkflowTaskURLState,
     MissingBrowserStatePage,
+    MissingExtractActionsResponse,
     NoTOTPVerificationCodeFound,
     ScrapingFailed,
     SkyvernException,
@@ -81,7 +83,7 @@ from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
 from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
 from skyvern.forge.sdk.db.enums import TaskType
 from skyvern.forge.sdk.log_artifacts import save_step_logs, save_task_logs
-from skyvern.forge.sdk.models import Step, StepStatus
+from skyvern.forge.sdk.models import SpeculativeLLMMetadata, Step, StepStatus
 from skyvern.forge.sdk.schemas.files import FileInfo
 from skyvern.forge.sdk.schemas.organizations import Organization
 from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, TaskStatus
@@ -136,6 +138,15 @@ EXTRACT_ACTION_PROMPT_NAME = "extract-actions"
 EXTRACT_ACTION_CACHE_KEY_PREFIX = f"{EXTRACT_ACTION_TEMPLATE}-static"


+@dataclass
+class SpeculativePlan:
+    scraped_page: ScrapedPage
+    extract_action_prompt: str
+    use_caching: bool
+    llm_json_response: dict[str, Any] | None
+    llm_metadata: SpeculativeLLMMetadata | None = None
+
+
 class ActionLinkedNode:
     def __init__(self, action: Action) -> None:
         self.action = action
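The new `SpeculativePlan` is keyed by the step it was built for and consumed with a destructive `pop`, so it is used at most once. A minimal standalone sketch of that contract, with stand-in types instead of the real Skyvern classes:

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class SpeculativePlan:
    # Stand-in mirroring the fields in the diff; ScrapedPage is replaced by a dict here.
    scraped_page: dict
    extract_action_prompt: str
    use_caching: bool
    llm_json_response: dict[str, Any] | None
    llm_metadata: Any | None = None


@dataclass
class FakeContext:
    # Mirrors SkyvernContext.speculative_plans: a plan keyed by the step id it was built for.
    speculative_plans: dict[str, SpeculativePlan] = field(default_factory=dict)


ctx = FakeContext()
ctx.speculative_plans["stp_next"] = SpeculativePlan(
    scraped_page={"elements": []},
    extract_action_prompt="...",
    use_caching=False,
    llm_json_response={"actions": []},
)

# agent_step pops exactly once; a retry of the same step gets None and falls
# back to the normal scrape-and-prompt path instead of reusing stale data.
plan = ctx.speculative_plans.pop("stp_next", None)
assert plan is not None and plan.llm_json_response is not None
assert ctx.speculative_plans.pop("stp_next", None) is None
```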
@@ -915,6 +926,21 @@ class ForgeAgent:
             organization=organization, task=task, step=step, browser_state=browser_state
         )

+        speculative_plan: SpeculativePlan | None = None
+        reuse_speculative_llm_response = False
+        speculative_llm_metadata: SpeculativeLLMMetadata | None = None
+        if context:
+            speculative_plan = context.speculative_plans.pop(step.step_id, None)
+        if speculative_plan:
+            step.is_speculative = False
+            scraped_page = speculative_plan.scraped_page
+            extract_action_prompt = speculative_plan.extract_action_prompt
+            use_caching = speculative_plan.use_caching
+            json_response = speculative_plan.llm_json_response
+            reuse_speculative_llm_response = json_response is not None
+            speculative_llm_metadata = speculative_plan.llm_metadata
+        else:
             (
                 scraped_page,
                 extract_action_prompt,
@@ -925,9 +951,10 @@ class ForgeAgent:
                 browser_state,
                 engine,
             )
+            json_response = None

         detailed_agent_step_output.scraped_page = scraped_page
         detailed_agent_step_output.extract_action_prompt = extract_action_prompt
-        json_response = None

         actions: list[Action]
         if engine == RunEngine.openai_cua:
@@ -986,12 +1013,20 @@ class ForgeAgent:
             if context:
                 context.use_prompt_caching = True

+            if not reuse_speculative_llm_response:
                 json_response = await llm_api_handler(
                     prompt=extract_action_prompt,
                     prompt_name="extract-actions",
                     step=step,
                     screenshots=scraped_page.screenshots,
                 )
+            else:
+                LOG.debug(
+                    "Using speculative extract-actions response",
+                    step_id=step.step_id,
+                )
+            if json_response is None:
+                raise MissingExtractActionsResponse()

             try:
                 otp_json_response, otp_actions = await self.handle_potential_OTP_actions(
                     task, step, scraped_page, browser_state, json_response
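The `MissingExtractActionsResponse` guard sits after both branches because neither path is guaranteed to have produced a response. A runnable sketch of the reuse-or-call decision, with a hypothetical stand-in for the LLM handler:

```python
import asyncio
from typing import Any


class MissingExtractActionsResponse(Exception):
    """Mirrors the new exception: neither path produced a response."""


async def fake_llm_api_handler(prompt: str) -> dict[str, Any]:
    # Illustrative stand-in for the real extract-actions LLM call.
    await asyncio.sleep(0)
    return {"actions": [{"type": "click"}]}


async def get_actions_response(
    extract_action_prompt: str,
    speculative_response: dict[str, Any] | None,
) -> dict[str, Any]:
    # Reuse the speculative response when one exists; otherwise pay for a fresh call.
    json_response = speculative_response
    if json_response is None:
        json_response = await fake_llm_api_handler(extract_action_prompt)
    if json_response is None:
        raise MissingExtractActionsResponse()
    return json_response


print(asyncio.run(get_actions_response("prompt", None)))          # fresh call
print(asyncio.run(get_actions_response("prompt", {"actions": []})))  # reused
```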
@@ -1035,6 +1070,14 @@ class ForgeAgent:
                 )
             ]

+        if reuse_speculative_llm_response and speculative_llm_metadata:
+            await self._persist_speculative_llm_metadata(
+                step,
+                speculative_llm_metadata,
+                screenshots=scraped_page.screenshots,
+            )
+            speculative_llm_metadata = None
+
         detailed_agent_step_output.actions = actions
         if len(actions) == 0:
             LOG.info(
@@ -1308,6 +1351,7 @@ class ForgeAgent:
                     break

         task_completes_on_download = task_block and task_block.complete_on_download and task.workflow_run_id
+        enable_parallel_verification = False
         if (
             not has_decisive_action
             and not task_completes_on_download
@@ -1385,6 +1429,8 @@ class ForgeAgent:
                 status=StepStatus.completed,
                 output=detailed_agent_step_output.to_agent_step_output(),
             )
+            if enable_parallel_verification:
+                completed_step.speculative_original_status = StepStatus.completed
             return completed_step, detailed_agent_step_output.get_clean_detailed_output()
         except CancelledError:
             LOG.exception(
@@ -1748,51 +1794,229 @@ class ForgeAgent:
         return draw_boxes

-    async def _pre_scrape_for_next_step(
+    async def _speculate_next_step_plan(
         self,
         task: Task,
-        step: Step,
+        current_step: Step,
+        next_step: Step,
         browser_state: BrowserState,
         engine: RunEngine,
-    ) -> ScrapedPage | None:
-        """
-        Pre-scrape the page for the next step while verification is running.
-        This is the expensive operation (5-10 seconds) that we want to run in parallel.
-        """
-        try:
-            max_screenshot_number = settings.MAX_NUM_SCREENSHOTS
-            draw_boxes = True
-            scroll = True
-            if engine in CUA_ENGINES:
-                max_screenshot_number = 1
-                draw_boxes = False
-                scroll = False
-            # Check PostHog feature flag to skip screenshot annotations
-            draw_boxes = await self._should_skip_screenshot_annotations(task, draw_boxes)
-            scraped_page = await scrape_website(
-                browser_state,
-                task.url,
-                app.AGENT_FUNCTION.cleanup_element_tree_factory(task=task, step=step),
-                scrape_exclude=app.scrape_exclude,
-                max_screenshot_number=max_screenshot_number,
-                draw_boxes=draw_boxes,
-                scroll=scroll,
-            )
-            LOG.info(
-                "Pre-scraped page for next step in parallel with verification",
-                step_id=step.step_id,
-                num_elements=len(scraped_page.elements) if scraped_page else 0,
-            )
-            return scraped_page
-        except Exception:
-            LOG.warning(
-                "Failed to pre-scrape for next step, will re-scrape on next step execution",
-                step_id=step.step_id,
-                exc_info=True,
-            )
-            return None
+    ) -> SpeculativePlan | None:
+        if engine in CUA_ENGINES:
+            LOG.info(
+                "Skipping speculative extract-actions for CUA engine",
+                step_id=current_step.step_id,
+                task_id=task.task_id,
+            )
+            return None
+
+        try:
+            next_step.is_speculative = True
+            scraped_page, extract_action_prompt, use_caching = await self.build_and_record_step_prompt(
+                task,
+                next_step,
+                browser_state,
+                engine,
+                persist_artifacts=False,
+            )
+            llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
+                task.llm_key,
+                default=app.LLM_API_HANDLER,
+            )
+            llm_json_response = await llm_api_handler(
+                prompt=extract_action_prompt,
+                prompt_name="extract-actions",
+                step=next_step,
+                screenshots=scraped_page.screenshots,
+            )
+            LOG.info(
+                "Speculative extract-actions completed",
+                current_step_id=current_step.step_id,
+                synthetic_step_id=next_step.step_id,
+            )
+            metadata_copy = None
+            if next_step.speculative_llm_metadata is not None:
+                metadata_copy = next_step.speculative_llm_metadata.model_copy()
+                next_step.speculative_llm_metadata = None
+            next_step.is_speculative = False
+            return SpeculativePlan(
+                scraped_page=scraped_page,
+                extract_action_prompt=extract_action_prompt,
+                use_caching=use_caching,
+                llm_json_response=llm_json_response,
+                llm_metadata=metadata_copy,
+            )
+        except Exception:
+            LOG.warning(
+                "Failed to run speculative extract-actions",
+                step_id=current_step.step_id,
+                exc_info=True,
+            )
+            next_step.is_speculative = False
+            return None
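Because the synthetic next step is flagged `is_speculative`, the LLM layer buffers its artifacts on `step.speculative_llm_metadata` instead of persisting them; the method above then detaches that buffer with `model_copy()` before clearing the flag. A small sketch of the handoff, using illustrative stand-ins rather than the real Skyvern types:

```python
from dataclasses import dataclass

from pydantic import BaseModel


class Metadata(BaseModel):
    # Illustrative stand-in for SpeculativeLLMMetadata.
    prompt: str


@dataclass
class FakeStep:
    step_id: str
    is_speculative: bool = False
    speculative_llm_metadata: Metadata | None = None


step = FakeStep(step_id="stp_next")
step.is_speculative = True
# ...speculative extract-actions runs and leaves its buffered artifacts here...
step.speculative_llm_metadata = Metadata(prompt="rendered extract-actions prompt")

# Detach the buffer before handing the plan back, so the step object carries
# no speculative state once the flag is cleared.
metadata_copy = None
if step.speculative_llm_metadata is not None:
    metadata_copy = step.speculative_llm_metadata.model_copy()
    step.speculative_llm_metadata = None
step.is_speculative = False

assert metadata_copy is not None and step.speculative_llm_metadata is None
```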
+    async def _persist_speculative_llm_metadata(
+        self,
+        step: Step,
+        metadata: SpeculativeLLMMetadata,
+        *,
+        screenshots: list[bytes] | None = None,
+    ) -> None:
+        if not metadata:
+            return
+        LOG.debug("Persisting speculative LLM metadata")
+        if metadata.prompt:
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=metadata.prompt.encode("utf-8"),
+                artifact_type=ArtifactType.LLM_PROMPT,
+                screenshots=screenshots,
+                step=step,
+            )
+        if metadata.llm_request_json:
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=metadata.llm_request_json.encode("utf-8"),
+                artifact_type=ArtifactType.LLM_REQUEST,
+                step=step,
+            )
+        if metadata.llm_response_json:
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=metadata.llm_response_json.encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE,
+                step=step,
+            )
+        if metadata.parsed_response_json:
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=metadata.parsed_response_json.encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
+                step=step,
+            )
+        if metadata.rendered_response_json:
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=metadata.rendered_response_json.encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
+                step=step,
+            )
+
+        incremental_cost = metadata.llm_cost if metadata.llm_cost and metadata.llm_cost > 0 else None
+        incremental_input_tokens = (
+            metadata.input_tokens if metadata.input_tokens and metadata.input_tokens > 0 else None
+        )
+        incremental_output_tokens = (
+            metadata.output_tokens if metadata.output_tokens and metadata.output_tokens > 0 else None
+        )
+        incremental_reasoning_tokens = (
+            metadata.reasoning_tokens if metadata.reasoning_tokens and metadata.reasoning_tokens > 0 else None
+        )
+        incremental_cached_tokens = (
+            metadata.cached_tokens if metadata.cached_tokens and metadata.cached_tokens > 0 else None
+        )
+        if (
+            incremental_cost is not None
+            or incremental_input_tokens is not None
+            or incremental_output_tokens is not None
+            or incremental_reasoning_tokens is not None
+            or incremental_cached_tokens is not None
+        ):
+            await app.DATABASE.update_step(
+                task_id=step.task_id,
+                step_id=step.step_id,
+                organization_id=step.organization_id,
+                incremental_cost=incremental_cost,
+                incremental_input_tokens=incremental_input_tokens,
+                incremental_output_tokens=incremental_output_tokens,
+                incremental_reasoning_tokens=incremental_reasoning_tokens,
+                incremental_cached_tokens=incremental_cached_tokens,
+            )
+            if incremental_input_tokens:
+                step.input_token_count += incremental_input_tokens
+            if incremental_output_tokens:
+                step.output_token_count += incremental_output_tokens
+            if incremental_reasoning_tokens:
+                step.reasoning_token_count = (step.reasoning_token_count or 0) + incremental_reasoning_tokens
+            if incremental_cached_tokens:
+                step.cached_token_count = (step.cached_token_count or 0) + incremental_cached_tokens
+            if incremental_cost:
+                step.step_cost += incremental_cost
+
+        step.speculative_llm_metadata = None
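The incremental accounting above deliberately maps zero or missing counters to None so the `update_step` call never overwrites previously recorded stats with zeros. The normalization can be sketched as:

```python
def clamp_positive(value: float | None) -> float | None:
    """Return value only if it is a positive number, else None.

    Mirrors how the incremental_* fields are normalized: zero or missing
    counters become None so they do not clobber existing step stats.
    """
    return value if value and value > 0 else None


assert clamp_positive(0) is None
assert clamp_positive(None) is None
assert clamp_positive(42) == 42

increments = {
    "incremental_input_tokens": clamp_positive(1200),
    "incremental_output_tokens": clamp_positive(0),
    "incremental_cost": clamp_positive(0.0),
}
# Only issue the DB update when at least one counter survives clamping.
if any(v is not None for v in increments.values()):
    print("would call app.DATABASE.update_step(...)", increments)
```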
+    async def _persist_speculative_metadata_for_discarded_plan(
+        self,
+        step: Step,
+        speculative_task: asyncio.Future[SpeculativePlan | None],
+        *,
+        cancel_step: bool = False,
+    ) -> None:
+        try:
+            plan = await asyncio.shield(speculative_task)
+        except CancelledError:
+            LOG.debug(
+                "Speculative extract-actions cancelled before metadata persistence",
+                step_id=step.step_id,
+            )
+            step.is_speculative = False
+            if cancel_step:
+                await self._cancel_speculative_step(step)
+            return
+        except Exception:
+            LOG.debug(
+                "Speculative extract-actions failed before metadata persistence",
+                step_id=step.step_id,
+                exc_info=True,
+            )
+            step.is_speculative = False
+            if cancel_step:
+                await self._cancel_speculative_step(step)
+            return
+
+        if not plan or not plan.llm_metadata:
+            step.is_speculative = False
+            if cancel_step:
+                await self._cancel_speculative_step(step)
+            return
+
+        try:
+            await self._persist_speculative_llm_metadata(
+                step,
+                plan.llm_metadata,
+            )
+            step.is_speculative = False
+            if cancel_step:
+                await self._cancel_speculative_step(step)
+        except Exception:
+            LOG.warning(
+                "Failed to persist speculative llm metadata for discarded plan",
+                step_id=step.step_id,
+                exc_info=True,
+            )
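The `asyncio.shield` here matters: this persistence coroutine runs as a fire-and-forget task, and shielding means that cancelling the wrapper does not cancel the underlying speculative task out from under it. A self-contained illustration of that semantics:

```python
import asyncio


async def speculative_work() -> str:
    await asyncio.sleep(0.2)
    return "plan"


async def persist_discarded(inner: asyncio.Task) -> None:
    # Shield: cancelling *this* coroutine does not cancel the inner task.
    plan = await asyncio.shield(inner)
    print(f"persisting metadata for discarded {plan!r}")


async def main() -> None:
    inner = asyncio.create_task(speculative_work())
    wrapper = asyncio.create_task(persist_discarded(inner))
    await asyncio.sleep(0.05)
    wrapper.cancel()  # the fire-and-forget wrapper goes away...
    await asyncio.gather(wrapper, return_exceptions=True)
    print(await inner)  # ...but the shielded speculative task still completes


asyncio.run(main())
```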
+    async def _cancel_speculative_step(self, step: Step) -> None:
+        if step.status == StepStatus.canceled:
+            return
+        try:
+            updated_step = await self.update_step(step, status=StepStatus.canceled)
+            step.status = updated_step.status
+            step.is_speculative = False
+        except Exception:
+            LOG.warning(
+                "Failed to cancel speculative step",
+                step_id=step.step_id,
+                exc_info=True,
+            )
     async def complete_verify(
         self, page: Page, scraped_page: ScrapedPage, task: Task, step: Step, task_block: BaseTaskBlock | None = None
@@ -2099,6 +2323,8 @@ class ForgeAgent:
         step: Step,
         browser_state: BrowserState,
         engine: RunEngine,
+        *,
+        persist_artifacts: bool = True,
     ) -> tuple[ScrapedPage, str, bool]:
         # Check if we have pre-scraped data from parallel verification optimization
         context = skyvern_context.current()
@@ -2178,6 +2404,7 @@ class ForgeAgent:
extract_action_prompt = "" extract_action_prompt = ""
use_caching = False use_caching = False
if persist_artifacts:
await app.ARTIFACT_MANAGER.create_artifact( await app.ARTIFACT_MANAGER.create_artifact(
step=step, step=step,
artifact_type=ArtifactType.HTML_SCRAPE, artifact_type=ArtifactType.HTML_SCRAPE,
@@ -2191,6 +2418,7 @@ class ForgeAgent:
             url=task.url,
         )

         # TODO: we only use HTML element for now, introduce a way to switch in the future
+        enable_speed_optimizations = getattr(context, "enable_speed_optimizations", False)
         element_tree_format = ElementTreeFormat.HTML

         # OPTIMIZATION: Use economy tree (skip SVGs) when ENABLE_SPEED_OPTIMIZATIONS is enabled
@@ -2248,6 +2476,7 @@ class ForgeAgent:
             expire_verification_code=True,
         )

+        if persist_artifacts:
             await app.ARTIFACT_MANAGER.create_artifact(
                 step=step,
                 artifact_type=ArtifactType.VISIBLE_ELEMENTS_ID_CSS_MAP,
@@ -2480,6 +2709,16 @@ class ForgeAgent:
             task_llm_key=task.llm_key,
             effective_llm_key=effective_llm_key,
         )

+        enable_speed_optimizations = context.enable_speed_optimizations
+        element_tree_format = ElementTreeFormat.HTML
+        if enable_speed_optimizations:
+            if step.retry_index == 0:
+                elements_for_prompt = scraped_page.build_economy_elements_tree(element_tree_format)
+            else:
+                elements_for_prompt = scraped_page.build_element_tree(element_tree_format)
+        else:
+            elements_for_prompt = scraped_page.build_element_tree(element_tree_format)
+
         if template == EXTRACT_ACTION_TEMPLATE and cache_enabled:
             try:
                 # Try to load split templates for caching
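The retry-aware tree selection is worth calling out: only the first attempt (`retry_index == 0`) under speed optimizations uses the slimmer economy tree; any retry sends the full tree so a failure cannot be caused by elements the LLM never saw. A sketch with stand-in builders:

```python
from enum import Enum


class ElementTreeFormat(str, Enum):
    HTML = "html"


class FakeScrapedPage:
    """Stand-in exposing the two builders named in the diff."""

    def build_element_tree(self, fmt: ElementTreeFormat) -> str:
        return "full tree (includes SVGs)"

    def build_economy_elements_tree(self, fmt: ElementTreeFormat) -> str:
        return "economy tree (SVGs skipped)"


def pick_elements_for_prompt(page: FakeScrapedPage, retry_index: int, speed_opts: bool) -> str:
    # First attempt under speed optimizations takes the cheaper tree;
    # any retry (or speed_opts off) falls back to the full tree.
    if speed_opts and retry_index == 0:
        return page.build_economy_elements_tree(ElementTreeFormat.HTML)
    return page.build_element_tree(ElementTreeFormat.HTML)


page = FakeScrapedPage()
assert pick_elements_for_prompt(page, 0, True).startswith("economy")
assert pick_elements_for_prompt(page, 1, True).startswith("full")
assert pick_elements_for_prompt(page, 0, False).startswith("full")
```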
@@ -2501,7 +2740,11 @@ class ForgeAgent:
"has_magic_link_page": context.has_magic_link_page(task.task_id), "has_magic_link_page": context.has_magic_link_page(task.task_id),
} }
static_prompt = prompt_engine.load_prompt(f"{template}-static", **prompt_kwargs) static_prompt = prompt_engine.load_prompt(f"{template}-static", **prompt_kwargs)
dynamic_prompt = prompt_engine.load_prompt(f"{template}-dynamic", **prompt_kwargs) dynamic_prompt = prompt_engine.load_prompt(
f"{template}-dynamic",
elements=elements_for_prompt,
**prompt_kwargs,
)
# Store static prompt for caching and continue sending it alongside the dynamic section. # Store static prompt for caching and continue sending it alongside the dynamic section.
# Vertex explicit caching expects the static content to still be present in the request so the # Vertex explicit caching expects the static content to still be present in the request so the
@@ -3250,12 +3493,11 @@ class ForgeAgent:
             the standard flow would have called check_user_goal_complete in agent_step).
         """
         LOG.info(
-            "Starting parallel user goal verification optimization",
+            "Starting parallel user goal verification with speculative extract-actions",
             step_id=step.step_id,
             task_id=task.task_id,
         )

-        # Task 1: Verify user goal (typically 2-5 seconds)
         verification_task = asyncio.create_task(
             self.check_user_goal_complete(
                 page=page,
@@ -3267,18 +3509,31 @@ class ForgeAgent:
name=f"verify_goal_{step.step_id}", name=f"verify_goal_{step.step_id}",
) )
# Task 2: Pre-scrape for next step (typically 5-10 seconds) next_step = await app.DATABASE.create_step(
pre_scrape_task = asyncio.create_task( task_id=task.task_id,
self._pre_scrape_for_next_step( order=step.order + 1,
retry_index=0,
organization_id=task.organization_id,
)
LOG.debug(
"Waiting before launching speculative plan",
step_id=step.step_id,
task_id=task.task_id,
)
await asyncio.sleep(1.0)
speculative_task = asyncio.create_task(
self._speculate_next_step_plan(
task=task, task=task,
step=step, current_step=step,
next_step=next_step,
browser_state=browser_state, browser_state=browser_state,
engine=engine, engine=engine,
), ),
name=f"pre_scrape_{step.step_id}", name=f"speculate_next_step_{step.step_id}",
) )
# Wait for verification to complete first (faster of the two)
try: try:
complete_action = await verification_task complete_action = await verification_task
except Exception: except Exception:
@@ -3290,25 +3545,15 @@ class ForgeAgent:
             complete_action = None

         if complete_action is not None:
-            # Goal achieved or should terminate! Cancel the pre-scraping task
-            is_terminate = isinstance(complete_action, TerminateAction)
-            LOG.info(
-                "Parallel verification: goal achieved or termination required, cancelling pre-scraping",
-                step_id=step.step_id,
-                task_id=task.task_id,
-                is_terminate=is_terminate,
+            asyncio.create_task(
+                self._persist_speculative_metadata_for_discarded_plan(
+                    next_step,
+                    speculative_task,
+                    cancel_step=True,
+                )
             )
-            pre_scrape_task.cancel()
-            try:
-                await pre_scrape_task  # Clean up the cancelled task
-            except asyncio.CancelledError:
-                LOG.debug("Pre-scraping cancelled successfully", step_id=step.step_id)
-            except Exception:
-                LOG.debug("Pre-scraping task cleanup failed", step_id=step.step_id, exc_info=True)

-            working_page = page
-            if working_page is None:
-                working_page = await browser_state.must_get_working_page()
+            working_page = page or await browser_state.must_get_working_page()

             if step.output is None:
                 step.output = AgentStepOutput(action_results=[], actions_and_results=[], errors=[])
@@ -3333,21 +3578,27 @@ class ForgeAgent:
                 if isinstance(persisted_action, DecisiveAction) and persisted_action.errors:
                     step.output.errors.extend(persisted_action.errors)

-            if is_terminate:
-                # Mark task as terminated/failed
-                # Note: This requires the USE_TERMINATION_AWARE_COMPLETE_VERIFICATION experiment to be enabled
+            if isinstance(persisted_action, TerminateAction):
                 LOG.warning(
-                    "Parallel verification: termination required, marking task as terminated (termination-aware experiment)",
+                    "Parallel verification: termination required, marking task as terminated",
                     step_id=step.step_id,
                     task_id=task.task_id,
                     reasoning=complete_action.reasoning,
                 )
-                last_step = await self.update_step(step, output=step.output, is_last=True)
+                final_status = step.speculative_original_status or StepStatus.completed
+                step.speculative_original_status = None
+                step.status = final_status
+                last_step = await self.update_step(
+                    step,
+                    status=final_status,
+                    output=step.output,
+                    is_last=True,
+                )
                 task_errors = None
-                if isinstance(persisted_action, TerminateAction) and persisted_action.errors:
+                if persisted_action.errors:
                     task_errors = [error.model_dump() for error in persisted_action.errors]
                 failure_reason = persisted_action.reasoning
-                if isinstance(persisted_action, TerminateAction) and persisted_action.errors:
+                if persisted_action.errors:
                     failure_reason = "; ".join(error.reasoning for error in persisted_action.errors)
                 await self.update_task(
                     task,
@@ -3356,16 +3607,21 @@ class ForgeAgent:
                     errors=task_errors,
                 )
                 return True, last_step, None
-            else:
-                # Mark task as complete
-                # Note: Step is already marked as completed by agent_step
-                # We don't add the complete action to the step output since the step is already finalized
-                LOG.info(
-                    "Parallel verification: goal achieved, marking task as complete",
-                    step_id=step.step_id,
-                    task_id=task.task_id,
-                )
-                last_step = await self.update_step(step, output=step.output, is_last=True)
+
+            LOG.info(
+                "Parallel verification: goal achieved, marking task as complete",
+                step_id=step.step_id,
+                task_id=task.task_id,
+            )
+            final_status = step.speculative_original_status or StepStatus.completed
+            step.speculative_original_status = None
+            step.status = final_status
+            last_step = await self.update_step(
+                step,
+                status=final_status,
+                output=step.output,
+                is_last=True,
+            )
             extracted_information = await self.get_extracted_information_for_task(task)
             await self.update_task(
                 task,
@@ -3373,25 +3629,26 @@ class ForgeAgent:
                 extracted_information=extracted_information,
             )
             return True, last_step, None
-        else:
-            # Goal not achieved - wait for pre-scraping to complete
-            LOG.info(
-                "Parallel verification: goal not achieved, using pre-scraped data for next step",
-                step_id=step.step_id,
-                task_id=task.task_id,
-            )
-            try:
-                pre_scraped_page = await pre_scrape_task
-            except Exception:
-                LOG.warning(
-                    "Pre-scraping failed, next step will re-scrape",
-                    step_id=step.step_id,
-                    exc_info=True,
-                )
-                pre_scraped_page = None
+
+        LOG.info(
+            "Parallel verification: goal not achieved, awaiting speculative extract-actions",
+            step_id=step.step_id,
+            task_id=task.task_id,
+        )
+        try:
+            speculative_plan = await speculative_task
+        except CancelledError:
+            LOG.debug("Speculative extract-actions cancelled after verification finished", step_id=step.step_id)
+            speculative_plan = None
+        except Exception:
+            LOG.warning(
+                "Speculative extract-actions failed, next step will run sequentially",
+                step_id=step.step_id,
+                exc_info=True,
+            )
+            speculative_plan = None

-        # Check max steps before creating next step
         context = skyvern_context.current()
         override_max_steps_per_run = context.max_steps_override if context else None
         max_steps_per_run = (
@@ -3408,7 +3665,15 @@ class ForgeAgent:
                 step_retry=step.retry_index,
                 max_steps=max_steps_per_run,
             )
-            last_step = await self.update_step(step, is_last=True)
+            final_status = step.speculative_original_status or StepStatus.completed
+            step.speculative_original_status = None
+            step.status = final_status
+            last_step = await self.update_step(
+                step,
+                status=final_status,
+                output=step.output,
+                is_last=True,
+            )

             generated_failure_reason = await self.summary_failure_reason_for_max_steps(
                 organization=organization,
@@ -3421,6 +3686,8 @@ class ForgeAgent:
                     error.model_dump() for error in generated_failure_reason.errors
                 ]

+            await self._cancel_speculative_step(next_step)
+
             await self.update_task(
                 task,
                 status=TaskStatus.failed,
@@ -3429,27 +3696,17 @@ class ForgeAgent:
             )
             return False, last_step, None

-        # Create next step
-        next_step = await app.DATABASE.create_step(
-            task_id=task.task_id,
-            order=step.order + 1,
-            retry_index=0,
-            organization_id=task.organization_id,
-        )
-
-        # Store pre-scraped data in context for next step to use
-        if pre_scraped_page:
+        if speculative_plan:
             context = skyvern_context.ensure_context()
-            context.next_step_pre_scraped_data = {
-                "step_id": next_step.step_id,
-                "scraped_page": pre_scraped_page,
-                "timestamp": datetime.now(UTC),
-            }
+            context.speculative_plans[next_step.step_id] = speculative_plan
             LOG.info(
-                "Stored pre-scraped data for next step",
-                step_id=next_step.step_id,
-                num_elements=len(pre_scraped_page.elements),
+                "Stored speculative extract-actions plan for next step",
+                current_step_id=step.step_id,
+                next_step_id=next_step.step_id,
             )

+        step.status = step.speculative_original_status or StepStatus.completed
+        step.speculative_original_status = None
+
         return None, None, next_step
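Putting the pieces together, the tail of a step now races two coroutines: goal verification and the speculative plan for the step after it. A condensed, runnable sketch of the control flow (names and timings are illustrative; the short sleep stands in for the real code's 1.0s head start for verification):

```python
import asyncio


async def verify_goal() -> bool:
    await asyncio.sleep(0.01)  # the goal-check LLM call
    return False  # goal not achieved, so the speculative plan will be used


async def speculate_next_plan() -> str:
    await asyncio.sleep(0.03)  # scrape + extract-actions for the *next* step
    return "speculative plan"


async def run_step_tail() -> None:
    verification = asyncio.create_task(verify_goal())
    await asyncio.sleep(0.005)  # stands in for the deliberate 1.0s delay
    speculation = asyncio.create_task(speculate_next_plan())

    if await verification:
        # Goal achieved: the plan is discarded; the real code persists its
        # buffered metadata in a background task and cancels the synthetic step.
        speculation.cancel()
        return
    # Goal not achieved: hand the finished plan to the next step so it skips
    # the scrape and the extract-actions call entirely.
    print(f"next step starts from {await speculation!r}")


asyncio.run(run_step_tail())
```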

View File

@@ -29,7 +29,7 @@ from skyvern.forge.sdk.api.llm.ui_tars_response import UITarsResponse
 from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, llm_messages_builder_with_history, parse_api_response
 from skyvern.forge.sdk.artifact.models import ArtifactType
 from skyvern.forge.sdk.core import skyvern_context
-from skyvern.forge.sdk.models import Step
+from skyvern.forge.sdk.models import SpeculativeLLMMetadata, Step
 from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
 from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought
 from skyvern.forge.sdk.trace import TraceManager
@@ -260,7 +260,8 @@ class LLMAPIHandlerFactory:
             )

             context = skyvern_context.current()
-            if context and len(context.hashed_href_map) > 0:
+            is_speculative_step = step.is_speculative if step else False
+            if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step:
                 await app.ARTIFACT_MANAGER.create_llm_artifact(
                     data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"),
                     artifact_type=ArtifactType.HASHED_HREF_MAP,
@@ -270,8 +271,10 @@ class LLMAPIHandlerFactory:
                 ai_suggestion=ai_suggestion,
             )

+            llm_prompt_value = prompt
+            if step and not is_speculative_step:
                 await app.ARTIFACT_MANAGER.create_llm_artifact(
-                    data=prompt.encode("utf-8"),
+                    data=llm_prompt_value.encode("utf-8"),
                     artifact_type=ArtifactType.LLM_PROMPT,
                     screenshots=screenshots,
                     step=step,
@@ -330,15 +333,16 @@ class LLMAPIHandlerFactory:
                 cache_attached=True,
             )

-            await app.ARTIFACT_MANAGER.create_llm_artifact(
-                data=json.dumps(
-                    {
+            llm_request_payload = {
                 "model": llm_key,
                 "messages": messages,
                 **parameters,
                 "vertex_cache_attached": vertex_cache_attached,
             }
+            llm_request_json = json.dumps(llm_request_payload)
+            if step and not is_speculative_step:
+                await app.ARTIFACT_MANAGER.create_llm_artifact(
+                    data=llm_request_json.encode("utf-8"),
                     artifact_type=ArtifactType.LLM_REQUEST,
                     step=step,
                     task_v2=task_v2,
@@ -382,8 +386,10 @@ class LLMAPIHandlerFactory:
             )
             raise LLMProviderError(llm_key) from e

+            llm_response_json = response.model_dump_json(indent=2)
+            if step and not is_speculative_step:
                 await app.ARTIFACT_MANAGER.create_llm_artifact(
-                    data=response.model_dump_json(indent=2).encode("utf-8"),
+                    data=llm_response_json.encode("utf-8"),
                     artifact_type=ArtifactType.LLM_RESPONSE,
                     step=step,
                     task_v2=task_v2,
@@ -424,7 +430,7 @@ class LLMAPIHandlerFactory:
             # Fallback for Vertex/Gemini: LiteLLM exposes cache_read_input_tokens on usage
             if cached_tokens == 0:
                 cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
-            if step:
+            if step and not is_speculative_step:
                 await app.DATABASE.update_step(
                     task_id=step.task_id,
                     step_id=step.step_id,
@@ -446,8 +452,10 @@ class LLMAPIHandlerFactory:
                     cached_token_count=cached_tokens if cached_tokens > 0 else None,
                 )

             parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
+            parsed_response_json = json.dumps(parsed_response, indent=2)
+            if step and not is_speculative_step:
                 await app.ARTIFACT_MANAGER.create_llm_artifact(
-                    data=json.dumps(parsed_response, indent=2).encode("utf-8"),
+                    data=parsed_response_json.encode("utf-8"),
                     artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
                     step=step,
                     task_v2=task_v2,
@@ -455,12 +463,15 @@ class LLMAPIHandlerFactory:
                 ai_suggestion=ai_suggestion,
             )

+            rendered_response_json = None
             if context and len(context.hashed_href_map) > 0:
                 llm_content = json.dumps(parsed_response)
                 rendered_content = Template(llm_content).render(context.hashed_href_map)
                 parsed_response = json.loads(rendered_content)
+                rendered_response_json = json.dumps(parsed_response, indent=2)
+            if step and not is_speculative_step:
                 await app.ARTIFACT_MANAGER.create_llm_artifact(
-                    data=json.dumps(parsed_response, indent=2).encode("utf-8"),
+                    data=rendered_response_json.encode("utf-8"),
                     artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
                     step=step,
                     task_v2=task_v2,
@@ -489,6 +500,23 @@ class LLMAPIHandlerFactory:
                 llm_cost=llm_cost if llm_cost > 0 else None,
             )

+            if step and is_speculative_step:
+                step.speculative_llm_metadata = SpeculativeLLMMetadata(
+                    prompt=llm_prompt_value,
+                    llm_request_json=llm_request_json,
+                    llm_response_json=llm_response_json,
+                    parsed_response_json=parsed_response_json,
+                    rendered_response_json=rendered_response_json,
+                    llm_key=llm_key,
+                    model=main_model_group,
+                    duration_seconds=duration_seconds,
+                    input_tokens=prompt_tokens if prompt_tokens > 0 else None,
+                    output_tokens=completion_tokens if completion_tokens > 0 else None,
+                    reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None,
+                    cached_tokens=cached_tokens if cached_tokens > 0 else None,
+                    llm_cost=llm_cost if llm_cost > 0 else None,
+                )
+
             return parsed_response

         llm_api_handler_with_router_and_fallback.llm_key = llm_key  # type: ignore[attr-defined]
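The same `step and not is_speculative_step` gate now guards every artifact write and the per-step usage update; a speculative call instead accumulates everything into one `SpeculativeLLMMetadata` at the end. A reduced sketch of the pattern, with stand-in types rather than the real artifact manager:

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class FakeStep:
    is_speculative: bool = False
    buffered: dict[str, str] = field(default_factory=dict)


async def create_llm_artifact(step: FakeStep, artifact_type: str, data: str) -> None:
    print(f"persisted {artifact_type}")


async def record(step: FakeStep | None, artifact_type: str, data: str) -> None:
    # Real steps persist immediately; speculative steps buffer the payload so
    # it can be replayed later if the plan is adopted.
    is_speculative = step.is_speculative if step else False
    if step and not is_speculative:
        await create_llm_artifact(step, artifact_type, data)
    elif step:
        step.buffered[artifact_type] = data


async def main() -> None:
    real, spec = FakeStep(), FakeStep(is_speculative=True)
    await record(real, "LLM_PROMPT", "...")  # persisted right away
    await record(spec, "LLM_PROMPT", "...")  # buffered for later replay
    print(sorted(spec.buffered))


asyncio.run(main())
```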
@@ -547,7 +575,8 @@ class LLMAPIHandlerFactory:
             )

             context = skyvern_context.current()
-            if context and len(context.hashed_href_map) > 0:
+            is_speculative_step = step.is_speculative if step else False
+            if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step:
                 await app.ARTIFACT_MANAGER.create_llm_artifact(
                     data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"),
                     artifact_type=ArtifactType.HASHED_HREF_MAP,
@@ -557,8 +586,10 @@ class LLMAPIHandlerFactory:
                 ai_suggestion=ai_suggestion,
             )

+            llm_prompt_value = prompt
+            if step and not is_speculative_step:
                 await app.ARTIFACT_MANAGER.create_llm_artifact(
-                    data=prompt.encode("utf-8"),
+                    data=llm_prompt_value.encode("utf-8"),
                     artifact_type=ArtifactType.LLM_PROMPT,
                     screenshots=screenshots,
                     step=step,
@@ -630,16 +661,17 @@ class LLMAPIHandlerFactory:
                 cache_attached=True,
             )

-            await app.ARTIFACT_MANAGER.create_llm_artifact(
-                data=json.dumps(
-                    {
+            llm_request_payload = {
                 "model": model_name,
                 "messages": messages,
                 # we're not using active_parameters here because it may contain sensitive information
                 **parameters,
                 "vertex_cache_attached": vertex_cache_attached,
             }
+            llm_request_json = json.dumps(llm_request_payload)
+            if step and not is_speculative_step:
+                await app.ARTIFACT_MANAGER.create_llm_artifact(
+                    data=llm_request_json.encode("utf-8"),
                     artifact_type=ArtifactType.LLM_REQUEST,
                     step=step,
                     task_v2=task_v2,
@@ -692,8 +724,10 @@ class LLMAPIHandlerFactory:
             )
             raise LLMProviderError(llm_key) from e

+            llm_response_json = response.model_dump_json(indent=2)
+            if step and not is_speculative_step:
                 await app.ARTIFACT_MANAGER.create_llm_artifact(
-                    data=response.model_dump_json(indent=2).encode("utf-8"),
+                    data=llm_response_json.encode("utf-8"),
                     artifact_type=ArtifactType.LLM_RESPONSE,
                     step=step,
                     task_v2=task_v2,
@@ -912,7 +946,8 @@ class LLMCaller:
             active_parameters.update(self.llm_config.litellm_params)  # type: ignore

         context = skyvern_context.current()
-        if context and len(context.hashed_href_map) > 0:
+        is_speculative_step = step.is_speculative if step else False
+        if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step:
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"),
                 artifact_type=ArtifactType.HASHED_HREF_MAP,
@@ -939,7 +974,8 @@ class LLMCaller:
tool["display_width_px"] = target_dimension["width"] tool["display_width_px"] = target_dimension["width"]
screenshots = resize_screenshots(screenshots, target_dimension) screenshots = resize_screenshots(screenshots, target_dimension)
if prompt: llm_prompt_value = prompt or ""
if prompt and step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact( await app.ARTIFACT_MANAGER.create_llm_artifact(
data=prompt.encode("utf-8"), data=prompt.encode("utf-8"),
artifact_type=ArtifactType.LLM_PROMPT, artifact_type=ArtifactType.LLM_PROMPT,
@@ -971,15 +1007,16 @@ class LLMCaller:
             screenshots,
             message_pattern=message_pattern,
         )

-        await app.ARTIFACT_MANAGER.create_llm_artifact(
-            data=json.dumps(
-                {
+        llm_request_payload = {
             "model": self.llm_config.model_name,
             "messages": messages,
             # we're not using active_parameters here because it may contain sensitive information
             **parameters,
         }
+        llm_request_json = json.dumps(llm_request_payload)
+        if step and not is_speculative_step:
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=llm_request_json.encode("utf-8"),
                 artifact_type=ArtifactType.LLM_REQUEST,
                 step=step,
                 task_v2=task_v2,
@@ -1019,8 +1056,10 @@ class LLMCaller:
LOG.exception("LLM request failed unexpectedly", llm_key=self.llm_key) LOG.exception("LLM request failed unexpectedly", llm_key=self.llm_key)
raise LLMProviderError(self.llm_key) from e raise LLMProviderError(self.llm_key) from e
llm_response_json = response.model_dump_json(indent=2)
if step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact( await app.ARTIFACT_MANAGER.create_llm_artifact(
data=response.model_dump_json(indent=2).encode("utf-8"), data=llm_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE, artifact_type=ArtifactType.LLM_RESPONSE,
step=step, step=step,
task_v2=task_v2, task_v2=task_v2,
@@ -1029,7 +1068,7 @@ class LLMCaller:
             )

         call_stats = await self.get_call_stats(response)
-        if step:
+        if step and not is_speculative_step:
             await app.DATABASE.update_step(
                 task_id=step.task_id,
                 step_id=step.step_id,
@@ -1051,6 +1090,34 @@ class LLMCaller:
                 thought_cost=call_stats.llm_cost,
             )

+        parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix)
+        parsed_response_json = json.dumps(parsed_response, indent=2)
+        if step and not is_speculative_step:
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=parsed_response_json.encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
+                step=step,
+                task_v2=task_v2,
+                thought=thought,
+                ai_suggestion=ai_suggestion,
+            )
+
+        rendered_response_json = None
+        if context and len(context.hashed_href_map) > 0:
+            llm_content = json.dumps(parsed_response)
+            rendered_content = Template(llm_content).render(context.hashed_href_map)
+            parsed_response = json.loads(rendered_content)
+            rendered_response_json = json.dumps(parsed_response, indent=2)
+        if step and not is_speculative_step:
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=rendered_response_json.encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
+                step=step,
+                task_v2=task_v2,
+                thought=thought,
+                ai_suggestion=ai_suggestion,
+            )
+
         organization_id = organization_id or (
             step.organization_id if step else (thought.organization_id if thought else None)
         )
@@ -1071,32 +1138,27 @@ class LLMCaller:
             cached_tokens=call_stats.cached_tokens if call_stats and call_stats.cached_tokens else None,
             llm_cost=call_stats.llm_cost if call_stats and call_stats.llm_cost else None,
         )

+        if step and is_speculative_step:
+            step.speculative_llm_metadata = SpeculativeLLMMetadata(
+                prompt=llm_prompt_value,
+                llm_request_json=llm_request_json,
+                llm_response_json=llm_response_json,
+                parsed_response_json=parsed_response_json,
+                rendered_response_json=rendered_response_json,
+                llm_key=self.llm_key,
+                model=self.llm_config.model_name,
+                duration_seconds=duration_seconds,
+                input_tokens=call_stats.input_tokens,
+                output_tokens=call_stats.output_tokens,
+                reasoning_tokens=call_stats.reasoning_tokens,
+                cached_tokens=call_stats.cached_tokens,
+                llm_cost=call_stats.llm_cost,
+            )
+
         if raw_response:
             return response.model_dump(exclude_none=True)

-        parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix)
-        await app.ARTIFACT_MANAGER.create_llm_artifact(
-            data=json.dumps(parsed_response, indent=2).encode("utf-8"),
-            artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
-            step=step,
-            task_v2=task_v2,
-            thought=thought,
-            ai_suggestion=ai_suggestion,
-        )
-
-        if context and len(context.hashed_href_map) > 0:
-            llm_content = json.dumps(parsed_response)
-            rendered_content = Template(llm_content).render(context.hashed_href_map)
-            parsed_response = json.loads(rendered_content)
-            await app.ARTIFACT_MANAGER.create_llm_artifact(
-                data=json.dumps(parsed_response, indent=2).encode("utf-8"),
-                artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
-                step=step,
-                task_v2=task_v2,
-                thought=thought,
-                ai_suggestion=ai_suggestion,
-            )
-
         return parsed_response

     def get_screenshot_resize_target_dimension(self, window_dimension: Resolution | None) -> Resolution:

View File

@@ -62,6 +62,7 @@ class SkyvernContext:
     # parallel verification optimization
     # stores pre-scraped data for next step to avoid re-scraping
     next_step_pre_scraped_data: dict[str, Any] | None = None
+    speculative_plans: dict[str, Any] = field(default_factory=dict)
     """
     Example output value:

View File

@@ -39,6 +39,22 @@ class StepStatus(StrEnum):
         return self in status_is_terminal


+class SpeculativeLLMMetadata(BaseModel):
+    prompt: str
+    llm_request_json: str
+    llm_response_json: str | None = None
+    parsed_response_json: str | None = None
+    rendered_response_json: str | None = None
+    llm_key: str | None = None
+    model: str | None = None
+    duration_seconds: float | None = None
+    input_tokens: int | None = None
+    output_tokens: int | None = None
+    reasoning_tokens: int | None = None
+    cached_tokens: int | None = None
+    llm_cost: float | None = None
+
+
 class Step(BaseModel):
     created_at: datetime
     modified_at: datetime
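Only `prompt` and `llm_request_json` are required on the new model; everything else defaults to None, which keeps partially populated speculative runs valid. A quick construction sketch (field values are made up):

```python
from pydantic import BaseModel


class SpeculativeLLMMetadata(BaseModel):
    # Trimmed copy of the model above, enough to show the required/optional split.
    prompt: str
    llm_request_json: str
    llm_response_json: str | None = None
    input_tokens: int | None = None
    llm_cost: float | None = None


meta = SpeculativeLLMMetadata(
    prompt="extract-actions prompt text",
    llm_request_json='{"model": "some-model", "messages": []}',
    input_tokens=1200,
)
# exclude_none keeps persisted payloads compact when most counters are unset.
print(meta.model_dump(exclude_none=True))
```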
@@ -55,6 +71,9 @@ class Step(BaseModel):
     reasoning_token_count: int | None = None
     cached_token_count: int | None = None
     step_cost: float = 0
+    is_speculative: bool = False
+    speculative_original_status: StepStatus | None = None
+    speculative_llm_metadata: SpeculativeLLMMetadata | None = None

     def validate_update(
         self,
@@ -64,7 +83,7 @@ class Step(BaseModel):
     ) -> None:
         old_status = self.status
-        if status and not old_status.can_update_to(status):
+        if status and status != old_status and not old_status.can_update_to(status):
             raise ValueError(f"invalid_status_transition({old_status},{status},{self.step_id})")

         if status == StepStatus.canceled:
@@ -83,6 +102,7 @@ class Step(BaseModel):
             old_status not in [StepStatus.running, StepStatus.created]
             and self.output is not None
             and output is not None
+            and not (status == old_status == StepStatus.completed)
         ):
             raise ValueError(f"cant_override_output({self.step_id})")