add actions db model and caching V0 (#980)

This commit is contained in:
Shuchang Zheng
2024-10-15 12:06:50 -07:00
committed by GitHub
parent e7583ac878
commit 9048cdfa73
19 changed files with 731 additions and 90 deletions

View File

@@ -490,3 +490,8 @@ class IllegitComplete(SkyvernException):
def __init__(self, data: dict | None = None) -> None:
data_str = f", data={data}" if data else ""
super().__init__(f"Illegit complete{data_str}")
class CachedActionPlanError(SkyvernException):
    """Raised when a cached action plan cannot be reused as-is.

    The message is passed straight through to ``SkyvernException``.
    """

    def __init__(self, message: str) -> None:
        super().__init__(message)

View File

@@ -51,9 +51,14 @@ from skyvern.webeye.actions.actions import (
WebAction,
parse_actions,
)
from skyvern.webeye.actions.handler import ActionHandler, handle_complete_action, poll_verification_code
from skyvern.webeye.actions.caching import retrieve_action_plan
from skyvern.webeye.actions.handler import (
ActionHandler,
extract_information_for_navigation_goal,
poll_verification_code,
)
from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput
from skyvern.webeye.actions.responses import ActionResult
from skyvern.webeye.actions.responses import ActionResult, ActionSuccess
from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website
from skyvern.webeye.utils.page import SkyvernFrame
@@ -553,7 +558,22 @@ class ForgeAgent:
detailed_agent_step_output.extract_action_prompt = extract_action_prompt
json_response = None
actions: list[Action]
if task.navigation_goal:
using_cached_action_plan = False
if not task.navigation_goal:
actions = [
CompleteAction(
reasoning="Task has no navigation goal.",
data_extraction_goal=task.data_extraction_goal,
)
]
elif (
task_block
and task_block.cache_actions
and (actions := await retrieve_action_plan(task, step, scraped_page))
):
using_cached_action_plan = True
else:
self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm)
json_response = await app.LLM_API_HANDLER(
prompt=extract_action_prompt,
@@ -569,14 +589,8 @@ class ForgeAgent:
)
detailed_agent_step_output.llm_response = json_response
actions = parse_actions(task, json_response["actions"])
else:
actions = [
CompleteAction(
reasoning="Task has no navigation goal.",
data_extraction_goal=task.data_extraction_goal,
)
]
actions = parse_actions(task, step.step_id, step.order, scraped_page, json_response["actions"])
detailed_agent_step_output.actions = actions
if len(actions) == 0:
LOG.info(
@@ -621,7 +635,8 @@ class ForgeAgent:
wait_actions_to_skip = [action for action in actions if action.action_type == ActionType.WAIT]
wait_actions_len = len(wait_actions_to_skip)
# if there are wait actions and there are other actions in the list, skip wait actions
if wait_actions_len > 0 and wait_actions_len < len(actions):
# if we are using cached action plan, we don't skip wait actions
if wait_actions_len > 0 and wait_actions_len < len(actions) and not using_cached_action_plan:
actions = [action for action in actions if action.action_type != ActionType.WAIT]
LOG.info(
"Skipping wait actions",
@@ -871,12 +886,10 @@ class ForgeAgent:
navigation_payload=task.navigation_payload,
elements=scraped_page.build_element_tree(ElementTreeFormat.HTML),
)
screenshots = await SkyvernFrame.take_split_screenshots(page=page, url=page.url)
verification_llm_api_handler = app.SECONDARY_LLM_API_HANDLER
verification_response = await verification_llm_api_handler(
prompt=verification_prompt, step=step, screenshots=screenshots
prompt=verification_prompt, step=step, screenshots=None
)
if "user_goal_achieved" not in verification_response or "reasoning" not in verification_response:
LOG.error(
@@ -895,9 +908,16 @@ class ForgeAgent:
return None
LOG.info("User goal achieved, executing complete action")
action_results = await handle_complete_action(complete_action, page, scraped_page, task, step)
extracted_data = None
if complete_action.data_extraction_goal:
scrape_action_result = await extract_information_for_navigation_goal(
scraped_page=scraped_page,
task=task,
step=step,
)
extracted_data = scrape_action_result.scraped_data
return complete_action, action_results
return complete_action, [ActionSuccess(data=extracted_data)]
except Exception:
LOG.error("LLM verification failed for complete action, skipping LLM verification", exc_info=True)

View File

@@ -0,0 +1,25 @@
You will be given information about a user's goal and details.
Your job is to answer the user's questions based on the information provided.
The user's questions will be provided in JSON format.
Your answers should be direct and to the point. No need to explain the answer.
Your response should be in JSON format: fill in the answer for each question and return the completed JSON.
User's goal: {{ navigation_goal }}
User's details: {{ navigation_payload }}
User's questions: {{ queries_and_answers }}
YOUR RESPONSE HAS TO BE IN JSON FORMAT. DO NOT RETURN ANYTHING ELSE.
THESE ANSWERS WILL BE USED TO FILL OUT INFORMATION ON A WEBPAGE. DO NOT INCLUDE ANY UNRELATED INFORMATION OR UNNECESSARY DETAILS IN YOUR ANSWERS.
EXAMPLE RESPONSE FORMAT:
{
"question_1": "answer_1",
"question_2": "answer_2",
"question_3": "answer_3"
}

View File

@@ -1,4 +1,4 @@
Based on the content of the screenshot and the elements on the page, determine whether the user goal has been successfully completed or not.
Based on the content of the elements on the page, determine whether the user goal has been successfully completed or not.
The JSON object should be in this format:
```json
@@ -7,15 +7,15 @@ The JSON object should be in this format:
"user_goal_achieved": bool // True if the user goal has been completed, False otherwise.
}
Make sure to ONLY return the JSON object, with no additional text before or after it. Do not make any assumptions based on the screenshot, return a response solely based on what you observe in the screenshot and nothing else.
Make sure to ONLY return the JSON object, with no additional text before or after it. Do not make any assumptions; base your response solely on the elements on the page.
Examples:
{
"reasoning": "The screenshot shows a success message for a file upload field. Since the user's goal is to upload a file, it has been successfully completed.",
"reasoning": "There is a success message for a file upload field. Since the user's goal is to upload a file, it has been successfully completed.",
"user_goal_achieved": true
}
{
"reasoning": "The screenshot shows a job application form with fields. Since the user's goal is to submit a job application, it has not been successfully completed.",
"reasoning": "This is a job application form with fields. Since the user's goal is to submit a job application, it has not been successfully completed.",
"user_goal_achieved": false
}

View File

@@ -14,7 +14,9 @@ Reply in JSON format with the following keys:
"action_plan": str, // A string that describes the plan of actions you're going to take. Be specific and to the point. Use this as a quick summary of the actions you're going to take, and what order you're going to take them in, and how that moves you towards your overall goal. Output "COMPLETE" action in the "actions" if user_goal_achieved is True.
"actions": array // An array of actions. Here's the format of each action:
[{
"reasoning": str, // The reasoning behind the action. Be specific, referencing any user information and their fields and element ids in your reasoning. Mention why you chose the action type, and why you chose the element id. Keep the reasoning short and to the point.
"reasoning": str, // The reasoning behind the action. This reasoning must be user information agnostic. Mention why you chose the action type, and why you chose the element id. Keep the reasoning short and to the point.
"user_detail_query": str, // Think of this value as a Jeopardy question. Ask the user for the details you need for executing this action. Ask the question even if the details are disclosed in user goal or user details. If it's a text field, ask for the text. If it's a file upload, ask for the file. If it's a dropdown, ask for the relevant information. If you are clicking on something specific, ask about what to click on. If you're downloading a file and you have multiple options, ask the user which one to download. Otherwise, use null. Examples are: "What product ID should I input into the search bar?", "What file should I upload?", "What is the previous insurance provider of the user?", "Which invoice should I download?", "Does the user have any pets?". If the action doesn't require any user details, use null.
"user_detail_answer": str, // The answer to the `user_detail_query`. The source of this answer can be user goal or user details.
"confidence_float": float, // The confidence of the action. Pick a number between 0.0 and 1.0. 0.0 means no confidence, 1.0 means full confidence
"action_type": str, // It's a string enum: "CLICK", "INPUT_TEXT", "UPLOAD_FILE", "SELECT_OPTION", "WAIT", "SOLVE_CAPTCHA", "COMPLETE", "TERMINATE". "CLICK" is an element you'd like to click. "INPUT_TEXT" is an element you'd like to input text into. "UPLOAD_FILE" is an element you'd like to upload a file into. "SELECT_OPTION" is an element you'd like to select an option from. "WAIT" action should be used if there are no actions to take and there is some indication on screen that waiting could yield more actions. "WAIT" should not be used if there are actions to take. "SOLVE_CAPTCHA" should be used if there's a captcha to solve on the screen. "COMPLETE" is used when the user goal has been achieved AND if there's any data extraction goal, you should be able to get data from the page. Never return a COMPLETE action unless the user goal is achieved. "TERMINATE" is used to terminate the whole task with a failure when it doesn't seem like the user goal can be achieved. Do not use "TERMINATE" if waiting could lead the user towards the goal. Only return "TERMINATE" if you are on a page where the user goal cannot be achieved. All other actions are ignored when "TERMINATE" is returned.
"id": str, // The id of the element to take action on. The id has to be one from the elements list

View File

@@ -0,0 +1,8 @@
import hashlib
def calculate_sha256(data: str) -> str:
    """Return the hexadecimal SHA-256 digest of a string.

    The string is encoded with ``str.encode()`` defaults before hashing.
    """
    return hashlib.sha256(data.encode()).hexdigest()

View File

@@ -113,7 +113,7 @@ def rename_file(file_path: str, new_file_name: str) -> str:
return file_path
def calculate_sha256(file_path: str) -> str:
def calculate_sha256_for_file(file_path: str) -> str:
"""Helper function to calculate SHA256 hash of a file."""
sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f:

View File

@@ -13,6 +13,7 @@ from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import (
ActionModel,
ArtifactModel,
AWSSecretParameterModel,
BitwardenCreditCardDataParameterModel,
@@ -68,6 +69,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRunParameter,
WorkflowRunStatus,
)
from skyvern.webeye.actions.actions import Action
from skyvern.webeye.actions.models import AgentStepOutput
LOG = structlog.get_logger()
@@ -1571,3 +1573,59 @@ class AgentDB:
)
totp_code = (await session.scalars(query)).all()
return [TOTPCode.model_validate(totp_code) for totp_code in totp_code]
async def create_action(self, action: Action) -> Action:
    """Insert ``action`` as a new ``ActionModel`` row and return the stored row as an ``Action``.

    The individual columns are copied from the matching attributes of ``action``;
    the complete ``action.model_dump()`` is additionally stored in ``action_json``.
    """
    # Attributes copied one-to-one from the Action onto identically named columns.
    copied_columns = (
        "action_type",
        "source_action_id",
        "organization_id",
        "workflow_run_id",
        "task_id",
        "step_id",
        "step_order",
        "action_order",
        "status",
        "reasoning",
        "intention",
        "response",
        "element_id",
        "skyvern_element_hash",
        "skyvern_element_data",
    )
    async with self.Session() as session:
        column_values = {column: getattr(action, column) for column in copied_columns}
        action_row = ActionModel(**column_values, action_json=action.model_dump())
        session.add(action_row)
        await session.commit()
        # Reload the row so defaults generated on insert are populated.
        await session.refresh(action_row)
        return Action.model_validate(action_row)
async def retrieve_action_plan(self, task: Task) -> list[Action]:
    """Return the recorded actions of the latest completed task equivalent to ``task``.

    A task is considered equivalent when it belongs to the same organization and has
    the same ``url`` and ``navigation_goal``. Only the most recently created completed
    task is used.

    Args:
        task: The task for which a previously recorded action plan is wanted.

    Returns:
        The matching task's actions ordered by ``step_order``, ``action_order`` and
        ``created_at``; an empty list when no matching task (or no action) exists.
    """
    async with self.Session() as session:
        subquery = (
            select(TaskModel.task_id)
            # Scope the lookup to the caller's organization so that an action plan
            # recorded by one organization is never replayed for another one.
            .filter(TaskModel.organization_id == task.organization_id)
            .filter(TaskModel.url == task.url)
            .filter(TaskModel.navigation_goal == task.navigation_goal)
            .filter(TaskModel.status == TaskStatus.completed)
            .order_by(TaskModel.created_at.desc())
            .limit(1)
            .subquery()
        )
        query = (
            select(ActionModel)
            .filter(ActionModel.task_id == subquery.c.task_id)
            .order_by(ActionModel.step_order, ActionModel.action_order, ActionModel.created_at)
        )
        actions = (await session.scalars(query)).all()
        return [Action.model_validate(action) for action in actions]
async def get_previous_actions_for_task(self, task_id: str) -> list[Action]:
    """Return every stored action of ``task_id`` ordered by step order, action order and creation time."""
    ordering = (ActionModel.step_order, ActionModel.action_order, ActionModel.created_at)
    statement = select(ActionModel).filter(ActionModel.task_id == task_id).order_by(*ordering)
    async with self.Session() as session:
        result = await session.scalars(statement)
        previous_actions: list[Action] = []
        for row in result.all():
            previous_actions.append(Action.model_validate(row))
        return previous_actions

View File

@@ -130,6 +130,11 @@ def generate_totp_code_id() -> str:
return f"totp_{int_id}"
def generate_action_id() -> str:
    """Return a new action identifier: the prefix ``a_`` followed by a freshly generated id."""
    return f"a_{generate_id()}"
def generate_id() -> int:
"""
generate a 64-bit int ID

View File

@@ -19,6 +19,7 @@ from sqlalchemy.orm import DeclarativeBase
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.id import (
generate_action_id,
generate_artifact_id,
generate_aws_secret_parameter_id,
generate_bitwarden_credit_card_data_parameter_id,
@@ -437,3 +438,29 @@ class TOTPCodeModel(Base):
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False, index=True)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
expired_at = Column(DateTime, index=True)
class ActionModel(Base):
    # ORM mapping for the "actions" table: one row per action.
    __tablename__ = "actions"
    # Composite index over (organization_id, task_id, step_id).
    __table_args__ = (Index("action_org_task_step_index", "organization_id", "task_id", "step_id"),)

    # Primary key; generated by generate_action_id when not supplied.
    action_id = Column(String, primary_key=True, index=True, default=generate_action_id)
    action_type = Column(String, nullable=False)
    # Self-referential foreign key to another row of this table.
    source_action_id = Column(String, ForeignKey("actions.action_id"), nullable=True, index=True)
    organization_id = Column(String, ForeignKey("organizations.organization_id"), nullable=True)
    workflow_run_id = Column(String, ForeignKey("workflow_runs.workflow_run_id"), nullable=True)
    task_id = Column(String, ForeignKey("tasks.task_id"), nullable=False, index=True)
    step_id = Column(String, ForeignKey("steps.step_id"), nullable=False)
    # step_order / action_order are the columns used to sort actions when they are read back.
    step_order = Column(Integer, nullable=False)
    action_order = Column(Integer, nullable=False)
    status = Column(String, nullable=False)
    reasoning = Column(String, nullable=True)
    intention = Column(String, nullable=True)
    response = Column(String, nullable=True)
    element_id = Column(String, nullable=True)
    skyvern_element_hash = Column(String, nullable=True)
    skyvern_element_data = Column(JSON, nullable=True)
    # JSON payload of the action (nullable).
    action_json = Column(JSON, nullable=True)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)

View File

@@ -32,7 +32,7 @@ from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.forge.sdk.api.files import (
calculate_sha256,
calculate_sha256_for_file,
download_file,
download_from_s3,
get_path_for_workflow_download_directory,
@@ -181,6 +181,7 @@ class TaskBlock(Block):
download_suffix: str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None
cache_actions: bool = False
def get_all_parameters(
self,
@@ -1057,7 +1058,7 @@ class SendEmailBlock(Block):
subtype=subtype,
filename=attachment_filename,
)
file_hash = calculate_sha256(path)
file_hash = calculate_sha256_for_file(path)
file_names_by_hash[file_hash].append(path)
finally:
if path:

View File

@@ -129,6 +129,7 @@ class TaskBlockYAML(BlockYAML):
download_suffix: str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None
cache_actions: bool = False
class ForLoopBlockYAML(BlockYAML):

View File

@@ -985,7 +985,8 @@ class WorkflowService:
bitwarden_client_id_aws_secret_key=parameter.bitwarden_client_id_aws_secret_key,
bitwarden_client_secret_aws_secret_key=parameter.bitwarden_client_secret_aws_secret_key,
bitwarden_master_password_aws_secret_key=parameter.bitwarden_master_password_aws_secret_key,
bitwarden_collection_id=parameter.bitwarden_collection_id,
# TODO: remove "# type: ignore" after ensuring bitwarden_collection_id is always set
bitwarden_collection_id=parameter.bitwarden_collection_id, # type: ignore
bitwarden_item_id=parameter.bitwarden_item_id,
key=parameter.key,
description=parameter.description,
@@ -1128,6 +1129,7 @@ class WorkflowService:
continue_on_failure=block_yaml.continue_on_failure,
totp_verification_url=block_yaml.totp_verification_url,
totp_identifier=block_yaml.totp_identifier,
cache_actions=block_yaml.cache_actions,
)
elif block_yaml.block_type == BlockType.FOR_LOOP:
loop_blocks = [

View File

@@ -1,14 +1,17 @@
from enum import StrEnum
from typing import Annotated, Any, Dict
from typing import Annotated, Any, Dict, Type, TypeVar
import structlog
from deprecation import deprecated
from litellm import ConfigDict
from pydantic import BaseModel, Field, ValidationError
from skyvern.exceptions import UnsupportedActionType
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.webeye.scraper.scraper import ScrapedPage
LOG = structlog.get_logger()
T = TypeVar("T", bound="Action")
class ActionType(StrEnum):
@@ -27,6 +30,23 @@ class ActionType(StrEnum):
TERMINATE = "terminate"
COMPLETE = "complete"
def is_web_action(self) -> bool:
    """Return True when this type is CLICK, INPUT_TEXT, UPLOAD_FILE, DOWNLOAD_FILE, SELECT_OPTION or CHECKBOX."""
    web_action_types = (
        ActionType.CLICK,
        ActionType.INPUT_TEXT,
        ActionType.UPLOAD_FILE,
        ActionType.DOWNLOAD_FILE,
        ActionType.SELECT_OPTION,
        ActionType.CHECKBOX,
    )
    return self in web_action_types
class ActionStatus(StrEnum):
    """String-valued enum of the statuses an action can have."""

    pending = "pending"
    skipped = "skipped"
    failed = "failed"
    completed = "completed"
class UserDefinedError(BaseModel):
error_code: str
@@ -53,11 +73,26 @@ class InputOrSelectContext(BaseModel):
class Action(BaseModel):
model_config = ConfigDict(from_attributes=True)
action_type: ActionType
status: ActionStatus = ActionStatus.pending
action_id: str | None = None
source_action_id: str | None = None
organization_id: str | None = None
workflow_run_id: str | None = None
task_id: str | None = None
step_id: str | None = None
step_order: int | None = None
action_order: int | None = None
confidence_float: float | None = None
description: str | None = None
reasoning: str | None = None
intention: str | None = None
response: str | None = None
element_id: Annotated[str, Field(coerce_numbers_to_str=True)] | None = None
skyvern_element_hash: str | None = None
skyvern_element_data: dict[str, Any] | None = None
# DecisiveAction (CompleteAction, TerminateAction) fields
errors: list[UserDefinedError] | None = None
@@ -72,6 +107,38 @@ class Action(BaseModel):
option: SelectOption | None = None
is_checked: bool | None = None
@classmethod
def validate(cls: Type[T], value: Any) -> T:
    """Build the concrete ``Action`` subclass that matches ``value["action_type"]``.

    The action type may be given either as an ``ActionType`` member or as its plain
    string value (as found in JSON payloads); both are resolved through
    ``ActionType(...)`` rather than by object identity, so string inputs are
    dispatched correctly.

    Args:
        value: A dict describing the action. It must contain an ``action_type`` key.

    Returns:
        The validated instance of the subclass registered for the action type.

    Raises:
        ValueError: If ``value`` is not a dict or its action type is not supported.
        KeyError: If ``value`` has no ``action_type`` key.
    """
    if not isinstance(value, dict):
        raise ValueError("Invalid action data")

    raw_action_type = value["action_type"]
    try:
        action_type = ActionType(raw_action_type)
    except ValueError:
        raise ValueError(f"Unsupported action type: {raw_action_type}") from None

    # Built at call time because the subclasses are defined after this class.
    action_classes = {
        ActionType.CLICK: ClickAction,
        ActionType.INPUT_TEXT: InputTextAction,
        ActionType.UPLOAD_FILE: UploadFileAction,
        ActionType.DOWNLOAD_FILE: DownloadFileAction,
        ActionType.NULL_ACTION: NullAction,
        ActionType.TERMINATE: TerminateAction,
        ActionType.COMPLETE: CompleteAction,
        ActionType.SELECT_OPTION: SelectOptionAction,
        ActionType.CHECKBOX: CheckboxAction,
        ActionType.WAIT: WaitAction,
        ActionType.SOLVE_CAPTCHA: SolveCaptchaAction,
    }
    action_class = action_classes.get(action_type)
    if action_class is None:
        raise ValueError(f"Unsupported action type: {raw_action_type}")
    return action_class.model_validate(value)
class WebAction(Action):
element_id: Annotated[str, Field(coerce_numbers_to_str=True)]
@@ -159,7 +226,7 @@ class CompleteAction(DecisiveAction):
data_extraction_goal: str | None = None
def parse_action(action: Dict[str, Any], data_extraction_goal: str | None = None) -> Action:
def parse_action(action: Dict[str, Any], scraped_page: ScrapedPage, data_extraction_goal: str | None = None) -> Action:
if "id" in action:
element_id = action["id"]
elif "element_id" in action:
@@ -167,57 +234,58 @@ def parse_action(action: Dict[str, Any], data_extraction_goal: str | None = None
else:
element_id = None
skyvern_element_hash = scraped_page.id_to_element_hash.get(element_id) if element_id else None
skyvern_element_data = scraped_page.id_to_element_dict.get(element_id) if element_id else None
reasoning = action["reasoning"] if "reasoning" in action else None
confidence_float = action["confidence_float"] if "confidence_float" in action else None
# TODO: currently action intention and response are only used for Q&A actions, like input_text
# When we start supporting click action, intention will be the reasoning for the click action (why take the action)
intention = action["user_detail_query"] if "user_detail_query" in action else None
response = action["user_detail_answer"] if "user_detail_answer" in action else None
base_action_dict = {
"element_id": element_id,
"skyvern_element_hash": skyvern_element_hash,
"skyvern_element_data": skyvern_element_data,
"reasoning": reasoning,
"confidence_float": confidence_float,
"intention": intention,
"response": response,
}
if "action_type" not in action or action["action_type"] is None:
return NullAction(reasoning=reasoning, confidence_float=confidence_float)
return NullAction(**base_action_dict)
# `.upper()` handles the case where the LLM returns a lowercase action type (e.g. "click" instead of "CLICK")
action_type = ActionType[action["action_type"].upper()]
if not action_type.is_web_action():
# LLM sometimes hallucinates and returns element id for non-web actions such as WAIT, TERMINATE, COMPLETE etc.
# That can sometimes cause cached action plan to be invalidated. This way we're making sure the element id is not
# set for non-web actions.
base_action_dict["element_id"] = None
if action_type == ActionType.TERMINATE:
return TerminateAction(
reasoning=reasoning,
confidence_float=confidence_float,
errors=action["errors"] if "errors" in action else [],
)
return TerminateAction(**base_action_dict, errors=action["errors"] if "errors" in action else [])
if action_type == ActionType.CLICK:
file_url = action["file_url"] if "file_url" in action else None
return ClickAction(
element_id=element_id,
reasoning=reasoning,
confidence_float=confidence_float,
file_url=file_url,
download=action.get("download", False),
)
return ClickAction(**base_action_dict, file_url=file_url, download=action.get("download", False))
if action_type == ActionType.INPUT_TEXT:
return InputTextAction(
element_id=element_id,
text=action["text"],
reasoning=reasoning,
confidence_float=confidence_float,
)
return InputTextAction(**base_action_dict, text=action["text"])
if action_type == ActionType.UPLOAD_FILE:
# TODO: see if the element is a file input element. if it's not, convert this action into a click action
return UploadFileAction(
element_id=element_id,
confidence_float=confidence_float,
**base_action_dict,
file_url=action["file_url"],
reasoning=reasoning,
)
# This action is not used in the current implementation. Click actions are used instead.
if action_type == ActionType.DOWNLOAD_FILE:
return DownloadFileAction(
element_id=element_id,
file_name=action["file_name"],
reasoning=reasoning,
confidence_float=confidence_float,
)
return DownloadFileAction(**base_action_dict, file_name=action["file_name"])
if action_type == ActionType.SELECT_OPTION:
option = action["option"]
@@ -229,49 +297,54 @@ def parse_action(action: Dict[str, Any], data_extraction_goal: str | None = None
if label is None and value is None and index is None:
raise ValueError("At least one of 'label', 'value', or 'index' must be provided for a SelectOption")
return SelectOptionAction(
element_id=element_id,
**base_action_dict,
option=SelectOption(
label=label,
value=value,
index=index,
),
reasoning=reasoning,
confidence_float=confidence_float,
)
if action_type == ActionType.CHECKBOX:
return CheckboxAction(
element_id=element_id,
**base_action_dict,
is_checked=action["is_checked"],
reasoning=reasoning,
confidence_float=confidence_float,
)
if action_type == ActionType.WAIT:
return WaitAction(reasoning=reasoning, confidence_float=confidence_float)
return WaitAction(**base_action_dict)
if action_type == ActionType.COMPLETE:
return CompleteAction(
reasoning=reasoning,
confidence_float=confidence_float,
**base_action_dict,
data_extraction_goal=data_extraction_goal,
errors=action["errors"] if "errors" in action else [],
)
if action_type == "null":
return NullAction(reasoning=reasoning, confidence_float=confidence_float)
return NullAction(**base_action_dict)
if action_type == ActionType.SOLVE_CAPTCHA:
return SolveCaptchaAction(reasoning=reasoning, confidence_float=confidence_float)
return SolveCaptchaAction(**base_action_dict)
raise UnsupportedActionType(action_type=action_type)
def parse_actions(task: Task, json_response: list[Dict[str, Any]]) -> list[Action]:
def parse_actions(
task: Task, step_id: str, step_order: int, scraped_page: ScrapedPage, json_response: list[Dict[str, Any]]
) -> list[Action]:
actions: list[Action] = []
for action in json_response:
for idx, action in enumerate(json_response):
try:
action_instance = parse_action(action=action, data_extraction_goal=task.data_extraction_goal)
action_instance = parse_action(
action=action, scraped_page=scraped_page, data_extraction_goal=task.data_extraction_goal
)
action_instance.organization_id = task.organization_id
action_instance.workflow_run_id = task.workflow_run_id
action_instance.task_id = task.task_id
action_instance.step_id = step_id
action_instance.step_order = step_order
action_instance.action_order = idx
if isinstance(action_instance, TerminateAction):
LOG.warning(
"Agent decided to terminate",
@@ -303,6 +376,23 @@ def parse_actions(task: Task, json_response: list[Dict[str, Any]]) -> list[Actio
raw_action=action,
exc_info=True,
)
############################ This part of code might not be needed ############################
# Reason #1. validation can be done in action handler but not in parser
# Reason #2. no need to validate whether the element_id has a hash.
# If there's no hash, we can fall back to normal operation
all_element_ids = [action.element_id for action in actions if action.element_id]
missing_element_ids = [
element_id for element_id in all_element_ids if element_id not in scraped_page.id_to_element_hash
]
if missing_element_ids:
LOG.warning(
"Missing elements in scraped page",
task_id=task.task_id,
missing_element_ids=missing_element_ids,
all_element_ids=all_element_ids,
)
############################ This part of code might not be needed ############################
return actions

View File

@@ -0,0 +1,226 @@
import structlog
from skyvern.exceptions import CachedActionPlanError
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.webeye.actions.actions import Action, ActionStatus, ActionType
from skyvern.webeye.scraper.scraper import ScrapedPage
LOG = structlog.get_logger()
async def retrieve_action_plan(task: Task, step: Step, scraped_page: ScrapedPage) -> list[Action]:
    """Best-effort lookup of a cached action plan.

    Delegates to ``_retrieve_action_plan``. Any exception is logged and turned into
    an empty list instead of being propagated.
    """
    cached_plan: list[Action] = []
    try:
        cached_plan = await _retrieve_action_plan(task, step, scraped_page)
    except Exception as e:
        LOG.exception("Failed to retrieve action plan", exception=e)
    return cached_plan
async def _retrieve_action_plan(task: Task, step: Step, scraped_page: ScrapedPage) -> list[Action]:
    """Build the list of cached actions that can be replayed in the current step.

    Steps, in order:
      1. Load the cached actions for ``task`` from the database.
      2. Skip the cached actions that this task has already replayed, by comparing
         the ``source_action_id`` of this task's stored actions with the cached ones.
      3. From the remaining cached actions, keep a leading run whose element hash
         resolves to exactly one element of ``scraped_page`` (actions without an
         element hash are handled by the rules commented in the loop below).
      4. Copy the kept actions, re-bind them to this task/step and to the current
         element ids, then personalize them via ``personalize_actions``.

    Returns:
        The personalized actions to execute, or an empty list when the cache cannot
        be used.

    Raises:
        CachedActionPlanError: If a kept action's hash no longer resolves to exactly
            one element while re-binding. Exceptions from
            ``check_for_unsupported_actions`` and ``personalize_actions`` are not
            caught here.
    """
    # V0: use the previous action plan if there is a completed task with the same url and navigation goal
    # get completed task with the same url and navigation goal
    # TODO(kerem): don't use step_order, get all the previous actions instead
    cached_actions = await app.DATABASE.retrieve_action_plan(task=task)
    if not cached_actions:
        LOG.info("No cached actions found for the task, fallback to no-cache mode")
        return []
    # Get the existing actions for this task from the database. Then find the actions that are already executed by looking at
    # the source_action_id field for this task's actions.
    previous_actions = await app.DATABASE.get_previous_actions_for_task(task_id=task.task_id)
    # NOTE(review): executed_cached_actions is populated below but never read in this function.
    executed_cached_actions = []
    remaining_cached_actions = []
    action_matching_complete = False
    if previous_actions:
        for idx, cached_action in enumerate(cached_actions):
            if not action_matching_complete:
                # idx stays within previous_actions here: matching is marked complete as soon as
                # the last previous action has been compared or a mismatch is found.
                should_be_matching_action = previous_actions[idx]
                if not should_be_matching_action.source_action_id:
                    # If there is an action without a source_action_id, it means we already went back to no-cache mode
                    # and we should not try to reuse the previous action plan since it's not possible to determine which
                    # action we should execute next
                    return []
                action_id_to_match = (
                    cached_action.source_action_id if cached_action.source_action_id else cached_action.action_id
                )
                if should_be_matching_action.source_action_id == action_id_to_match:
                    executed_cached_actions.append(cached_action)
                    if idx == len(previous_actions) - 1:
                        # If we've reached the end of the previous actions, we've completed matching.
                        action_matching_complete = True
                else:
                    # If we've reached an action that doesn't match the source_action_id of the previous actions,
                    # we've completed matching.
                    action_matching_complete = True
                    remaining_cached_actions.append(cached_action)
            else:
                remaining_cached_actions.append(cached_action)
    else:
        # Nothing has been executed for this task yet: the whole cached plan is still to do.
        remaining_cached_actions = cached_actions
        action_matching_complete = True
    # For any remaining cached action,
    # check if the element hash exists in the current scraped page. Add them to a list until we can't find a match. Always keep the
    # actions without an element hash.
    cached_actions_to_execute: list[Action] = []
    found_element_with_no_hash = False
    for cached_action in remaining_cached_actions:
        # The actions without an element hash: TerminateAction CompleteAction NullAction SolveCaptchaAction WaitAction
        # For these actions, we can't check if the element hash exists in the current scraped page.
        # For that reason, we're going to make sure they're executed always as the first action in each step.
        if not cached_action.skyvern_element_hash:
            if not found_element_with_no_hash and len(cached_actions_to_execute) > 0:
                # If we've already added actions with element hashes to the list before we encounter an action without an element hash,
                # we need to execute the actions we already added first. We want the actions without an element hash
                # to be executed as the first actions in each step. We're ok with executing multiple actions without an element hash
                # in a row, but we want them to be executed in a new step after we wait & scrape the page again.
                break
            cached_actions_to_execute.append(cached_action)
            found_element_with_no_hash = True
            continue
        matching_element_ids = scraped_page.hash_to_element_ids.get(cached_action.skyvern_element_hash)
        if matching_element_ids and len(matching_element_ids) == 1:
            cached_actions_to_execute.append(cached_action)
            continue
            # After this point, we can't continue adding actions to the plan, so we break and continue with what we have.
            # Because this action has either no hash-match or multiple hash-matches, we can't continue.
        elif matching_element_ids and len(matching_element_ids) > 1:
            LOG.warning(
                "Found multiple elements with the same hash, stop matching",
                element_hash=cached_action.skyvern_element_hash,
                element_ids=matching_element_ids,
            )
            break
        else:
            LOG.warning("No element found with the hash", element_hash=cached_action.skyvern_element_hash)
            break
    # If there are no items in the list we just built, we need to revert back to no-cache mode. Return empty list.
    if not cached_actions_to_execute:
        return []
    LOG.info("Found cached actions to execute", actions=cached_actions_to_execute)
    # Each entry pairs a re-bound copy of a cached action with its query (the action's intention, possibly None).
    actions_queries: list[tuple[Action, str | None]] = []
    for idx, cached_action in enumerate(cached_actions_to_execute):
        updated_action = cached_action.model_copy()
        updated_action.status = ActionStatus.pending
        # Always point at the original recorded action, even when the cached action is itself a replay.
        updated_action.source_action_id = (
            cached_action.source_action_id if cached_action.source_action_id else cached_action.action_id
        )
        updated_action.workflow_run_id = task.workflow_run_id
        updated_action.task_id = task.task_id
        updated_action.step_id = step.step_id
        updated_action.step_order = step.order
        updated_action.action_order = idx
        # Reset the action response to None so we don't use the previous answers
        updated_action.response = None
        # Update the element id with the element id from the current scraped page, matched by element hash
        if cached_action.skyvern_element_hash:
            matching_element_ids = scraped_page.hash_to_element_ids.get(cached_action.skyvern_element_hash)
            if matching_element_ids and len(matching_element_ids) == 1:
                matching_element_id = matching_element_ids[0]
                updated_action.element_id = matching_element_id
                updated_action.skyvern_element_data = scraped_page.id_to_element_dict.get(matching_element_id)
            else:
                raise CachedActionPlanError(
                    "All elements with either no hash or multiple hashes should have been already filtered out"
                )
        actions_queries.append((updated_action, updated_action.intention))
    # Check for unsupported actions before personalizing the actions
    # Classify the supported actions into two groups:
    # 1. Actions that can be cached with a query
    # 2. Actions that can be cached without a query
    # We'll use this classification to determine if we should continue with caching or fallback to no-cache mode
    check_for_unsupported_actions(actions_queries)
    personalized_actions = await personalize_actions(
        task=task, step=step, scraped_page=scraped_page, actions_queries=actions_queries
    )
    LOG.info("Personalized cached actions are ready", actions=personalized_actions)
    return personalized_actions
async def personalize_actions(
    task: Task,
    step: Step,
    actions_queries: list[tuple[Action, str | None]],
    scraped_page: ScrapedPage,
) -> list[Action]:
    """Fill cached actions with answers that match the current task.

    Every distinct non-empty query found in ``actions_queries`` is sent to
    ``get_user_detail_answers`` in a single call. Each action whose query
    received a truthy answer is passed through ``personalize_action``; all
    other actions (no query, or no answer for the query) are returned as-is.

    Args:
        task: Task forwarded to ``get_user_detail_answers``.
        step: Step forwarded to ``get_user_detail_answers``.
        actions_queries: ``(action, query)`` pairs, in execution order.
        scraped_page: Scraped page forwarded to ``get_user_detail_answers``.

    Returns:
        The actions in the same order as ``actions_queries``.
    """
    # Distinct queries, in first-seen order, each mapped to a placeholder answer.
    pending_queries: dict[str, str | None] = dict.fromkeys(query for _, query in actions_queries if query)

    answers: dict[str, str] = {}
    if pending_queries:
        # Only call the LLM when there is at least one query to answer.
        answers = await get_user_detail_answers(
            task=task,
            step=step,
            scraped_page=scraped_page,
            queries_and_answers=pending_queries,
        )

    result = []
    for action, query in actions_queries:
        if not query:
            result.append(action)
            continue
        answer = answers.get(query)
        if not answer:
            # No usable answer came back for this query; keep the action untouched.
            result.append(action)
            continue
        result.append(personalize_action(action, query, answer))
    return result
async def get_user_detail_answers(
    task: Task, step: Step, scraped_page: ScrapedPage, queries_and_answers: dict[str, str | None]
) -> dict[str, str]:
    """Ask the secondary LLM to answer the given user-detail queries.

    Renders the ``answer-user-detail-questions`` prompt from the task's
    navigation goal, navigation payload and the queries, sends it to
    ``app.SECONDARY_LLM_API_HANDLER`` without screenshots, and returns the
    handler's response unchanged.

    ``scraped_page`` is accepted but not used by this function.

    Raises:
        Exception: any error raised while building the prompt or calling the
            LLM is logged and re-raised.
    """
    try:
        prompt = prompt_engine.load_prompt(
            "answer-user-detail-questions",
            navigation_goal=task.navigation_goal,
            navigation_payload=task.navigation_payload,
            queries_and_answers=queries_and_answers,
        )
        return await app.SECONDARY_LLM_API_HANDLER(prompt=prompt, step=step, screenshots=None)
    except Exception as exc:
        LOG.exception("Failed to get user detail answers", exception=exc)
        # TODO: custom exception so we can fallback to no-cache mode by catching it
        raise exc
def personalize_action(action: Action, query: str, answer: str) -> Action:
    """Write a query/answer pair onto ``action`` and return the same object.

    The action's ``intention`` and ``response`` are always overwritten first.
    For ``ActionType.INPUT_TEXT`` actions the answer also becomes the text to
    type.

    Raises:
        CachedActionPlanError: for any other action type. ``intention`` and
            ``response`` have already been set on ``action`` at that point.
    """
    action.intention = query
    action.response = answer

    if action.action_type != ActionType.INPUT_TEXT:
        raise CachedActionPlanError(
            f"Unsupported action type for personalization, fallback to no-cache mode: {action.action_type}"
        )

    action.text = answer
    return action
def check_for_unsupported_actions(actions_queries: list[tuple[Action, str | None]]) -> None:
    """Validate that every cached action can be replayed from the cache.

    Two rules are checked per ``(action, query)`` pair, in this order:

    1. The action type must be one of INPUT_TEXT, WAIT, CLICK or COMPLETE.
    2. If the pair carries a non-empty query, the action type must be
       INPUT_TEXT.

    Raises:
        CachedActionPlanError: on the first pair that violates either rule.
    """
    cacheable_types = (ActionType.INPUT_TEXT, ActionType.WAIT, ActionType.CLICK, ActionType.COMPLETE)
    cacheable_types_with_query = (ActionType.INPUT_TEXT,)

    for action, query in actions_queries:
        action_type = action.action_type
        if action_type not in cacheable_types:
            raise CachedActionPlanError(
                f"This action type does not support caching: {action.action_type}, fallback to no-cache mode"
            )
        if not query:
            continue
        if action_type not in cacheable_types_with_query:
            raise CachedActionPlanError(
                f"This action type does not support caching with a query: {action.action_type}, fallback to no-cache mode"
            )

View File

@@ -23,6 +23,7 @@ from skyvern.exceptions import (
FailToSelectByIndex,
FailToSelectByLabel,
FailToSelectByValue,
IllegitComplete,
ImaginaryFileUrl,
InvalidElementForTextInput,
MissingElement,
@@ -54,6 +55,7 @@ from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.webeye.actions import actions
from skyvern.webeye.actions.actions import (
Action,
ActionStatus,
ActionType,
CheckboxAction,
ClickAction,
@@ -64,7 +66,7 @@ from skyvern.webeye.actions.actions import (
UploadFileAction,
WebAction,
)
from skyvern.webeye.actions.responses import ActionFailure, ActionResult, ActionSuccess
from skyvern.webeye.actions.responses import ActionAbort, ActionFailure, ActionResult, ActionSuccess
from skyvern.webeye.browser_factory import BrowserState, get_download_dir
from skyvern.webeye.scraper.scraper import (
CleanupElementTreeFunc,
@@ -227,12 +229,13 @@ class ActionHandler:
) -> list[ActionResult]:
LOG.info("Handling action", action=action)
page = await browser_state.get_or_create_page()
actions_result: list[ActionResult] = []
try:
if action.action_type in ActionHandler._handled_action_types:
actions_result: list[ActionResult] = []
if invalid_web_action_check := check_for_invalid_web_action(action, page, scraped_page, task, step):
return invalid_web_action_check
invalid_web_action_check = check_for_invalid_web_action(action, page, scraped_page, task, step)
if invalid_web_action_check:
actions_result.extend(invalid_web_action_check)
return actions_result
# do setup before action handler
if setup := ActionHandler._setup_action_types.get(action.action_type):
@@ -250,11 +253,10 @@ class ActionHandler:
# do the teardown
teardown = ActionHandler._teardown_action_types.get(action.action_type)
if not teardown:
return actions_result
if teardown:
results = await teardown(action, page, scraped_page, task, step)
actions_result.extend(results)
results = await teardown(action, page, scraped_page, task, step)
actions_result.extend(results)
return actions_result
else:
@@ -263,7 +265,8 @@ class ActionHandler:
action=action,
type=type(action),
)
return [ActionFailure(Exception(f"Unsupported action type: {type(action)}"))]
actions_result.append(ActionFailure(Exception(f"Unsupported action type: {type(action)}")))
return actions_result
except MissingElement as e:
LOG.info(
"Known exceptions",
@@ -271,16 +274,29 @@ class ActionHandler:
exception_type=type(e),
exception_message=str(e),
)
return [ActionFailure(e)]
actions_result.append(ActionFailure(e))
except MultipleElementsFound as e:
LOG.exception(
"Cannot handle multiple elements with the same selector in one action.",
action=action,
)
return [ActionFailure(e)]
actions_result.append(ActionFailure(e))
except Exception as e:
LOG.exception("Unhandled exception in action handler", action=action)
return [ActionFailure(e)]
actions_result.append(ActionFailure(e))
finally:
if actions_result and isinstance(actions_result[-1], ActionSuccess):
action.status = ActionStatus.completed
elif actions_result and isinstance(actions_result[-1], ActionAbort):
action.status = ActionStatus.skipped
else:
# either actions_result is empty or the last action is a failure
if not actions_result:
LOG.warning("Action failed to execute, setting status to failed", action=action)
action.status = ActionStatus.failed
await app.DATABASE.create_action(action=action)
return actions_result
def check_for_invalid_web_action(
@@ -874,7 +890,7 @@ async def handle_wait_action(
task: Task,
step: Step,
) -> list[ActionResult]:
await asyncio.sleep(10)
await asyncio.sleep(20)
return [ActionFailure(exception=Exception("Wait action is treated as a failure"))]
@@ -895,6 +911,25 @@ async def handle_complete_action(
task: Task,
step: Step,
) -> list[ActionResult]:
# If this action has a source_action_id, then we need to make sure if the goal is actually completed.
if action.source_action_id:
LOG.info("CompleteAction has source_action_id, checking if goal is completed")
complete_action_and_results = await app.agent.check_user_goal_success(page, scraped_page, task, step)
if complete_action_and_results is None:
return [
ActionFailure(
exception=IllegitComplete(
data={
"error": "Cached complete action wasn't verified by LLM, fallback to default execution mode"
}
)
)
]
_, action_results = complete_action_and_results
return action_results
# If there's no source_action_id, then we just handle it as a normal complete action
extracted_data = None
if action.data_extraction_goal:
scrape_action_result = await extract_information_for_navigation_goal(
@@ -951,6 +986,15 @@ async def chain_click(
# File choosers are impossible to close if you don't expect one. Instead of dealing with it, close it!
locator = skyvern_element.locator
try:
await locator.hover(timeout=timeout)
except Exception:
LOG.warning(
"Failed to hover over element in chain_click",
action=action,
locator=locator,
exc_info=True,
)
# TODO (suchintan): This should likely result in an ActionFailure -- we can figure out how to do this later!
LOG.info("Chain click starts", action=action, locator=locator)
file: list[str] | str = []
@@ -1015,6 +1059,7 @@ async def chain_click(
parent_javascript_triggered = await is_javascript_triggered(scraped_page, page, parent_locator)
javascript_triggered = javascript_triggered or parent_javascript_triggered
await parent_locator.hover(timeout=timeout)
await parent_locator.click(timeout=timeout)
LOG.info(
@@ -2101,6 +2146,10 @@ async def click_sibling_of_input(
input_id = await input_element.get_attribute("id")
sibling_label_css = f'label[for="{input_id}"]'
label_locator = parent_locator.locator(sibling_label_css)
try:
await locator.hover(timeout=timeout)
except Exception:
LOG.warning("Failed to hover over input element in click_sibling_of_input", exc_info=True)
await label_locator.click(timeout=timeout)
LOG.info(
"Successfully clicked sibling label of input element",

View File

@@ -11,6 +11,7 @@ from pydantic import BaseModel
from skyvern.constants import SKYVERN_DIR, SKYVERN_ID_ATTR
from skyvern.exceptions import FailedToTakeScreenshot, UnknownElementTreeFormat
from skyvern.forge.sdk.api.crypto import calculate_sha256
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.utils.page import SkyvernFrame
@@ -127,10 +128,34 @@ def json_to_html(element: dict, need_skyvern_attrs: bool = True) -> str:
return f'<{tag}{attributes_html if not attributes_html else " "+attributes_html}>{text}{children_html+option_html}</{tag}>'
def build_element_dict(elements: list[dict]) -> tuple[dict[str, str], dict[str, dict], dict[str, str]]:
def clean_element_before_hashing(element: dict) -> dict:
    """Return a copy of ``element`` with the fields excluded from hashing removed.

    The ``id`` and ``rect`` keys, plus the ``SKYVERN_ID_ATTR`` entry of
    ``attributes``, are removed from the element and, recursively, from every
    entry of ``children``. The input is never mutated.

    The tree is deep-copied exactly once and then cleaned in place; the
    previous implementation deep-copied every subtree again at each recursion
    level even though the parent's copy already covered it.

    Args:
        element: Element dict, optionally holding nested ``children`` dicts.

    Returns:
        A new, independent dict without the removed fields.
    """
    element_copy = copy.deepcopy(element)
    _strip_fields_excluded_from_hash(element_copy)
    return element_copy


def _strip_fields_excluded_from_hash(element: dict) -> None:
    """Remove, in place, the id/rect/skyvern-id fields from ``element`` and its children."""
    element.pop("id", None)
    element.pop("rect", None)

    attributes = element.get("attributes")
    if attributes:
        attributes.pop(SKYVERN_ID_ATTR, None)

    # `or ()` tolerates a missing or None children entry.
    for child in element.get("children") or ():
        _strip_fields_excluded_from_hash(child)
def hash_element(element: dict) -> str:
    """Compute the SHA-256 hash of an element after cleaning it for hashing.

    The element is first passed through ``clean_element_before_hashing`` and
    then serialized to JSON before being handed to ``calculate_sha256``.
    """
    # sort_keys=True makes the serialization independent of dict key order.
    serialized = json.dumps(clean_element_before_hashing(element), sort_keys=True)
    return calculate_sha256(serialized)
def build_element_dict(
elements: list[dict],
) -> tuple[dict[str, str], dict[str, dict], dict[str, str], dict[str, str], dict[str, list[str]]]:
id_to_css_dict: dict[str, str] = {}
id_to_element_dict: dict[str, dict] = {}
id_to_frame_dict: dict[str, str] = {}
id_to_element_hash: dict[str, str] = {}
hash_to_element_ids: dict[str, list[str]] = {}
for element in elements:
element_id: str = element.get("id", "")
@@ -138,8 +163,11 @@ def build_element_dict(elements: list[dict]) -> tuple[dict[str, str], dict[str,
id_to_css_dict[element_id] = f"[{SKYVERN_ID_ATTR}='{element_id}']"
id_to_element_dict[element_id] = element
id_to_frame_dict[element_id] = element["frame"]
element_hash = hash_element(element)
id_to_element_hash[element_id] = element_hash
hash_to_element_ids[element_hash] = hash_to_element_ids.get(element_hash, []) + [element_id]
return id_to_css_dict, id_to_element_dict, id_to_frame_dict
return id_to_css_dict, id_to_element_dict, id_to_frame_dict, id_to_element_hash, hash_to_element_ids
class ElementTreeFormat(StrEnum):
@@ -163,6 +191,8 @@ class ScrapedPage(BaseModel):
id_to_element_dict: dict[str, dict] = {}
id_to_frame_dict: dict[str, str] = {}
id_to_css_dict: dict[str, str]
id_to_element_hash: dict[str, str]
hash_to_element_ids: dict[str, list[str]]
element_tree: list[dict]
element_tree_trimmed: list[dict]
screenshots: list[bytes]
@@ -309,7 +339,13 @@ async def scrape_web_unsafe(
elements, element_tree = await get_interactable_element_tree(page, scrape_exclude)
element_tree = await cleanup_element_tree(url, copy.deepcopy(element_tree))
id_to_css_dict, id_to_element_dict, id_to_frame_dict = build_element_dict(elements)
id_to_css_dict, id_to_element_dict, id_to_frame_dict, id_to_element_hash, hash_to_element_ids = build_element_dict(
elements
)
# if there are no elements, fail the scraping
if not elements:
raise Exception("No elements found on the page")
text_content = await get_frame_text(page.main_frame)
@@ -329,6 +365,8 @@ async def scrape_web_unsafe(
id_to_css_dict=id_to_css_dict,
id_to_element_dict=id_to_element_dict,
id_to_frame_dict=id_to_frame_dict,
id_to_element_hash=id_to_element_hash,
hash_to_element_ids=hash_to_element_ids,
element_tree=element_tree,
element_tree_trimmed=trim_element_tree(copy.deepcopy(element_tree)),
screenshots=screenshots,
@@ -434,7 +472,7 @@ class IncrementalScrapePage:
js_script = "() => getIncrementElements()"
incremental_elements, incremental_tree = await frame.evaluate(js_script)
# we listen to the incremental elements separated by frames, so all elements will be in the same SkyvernFrame
self.id_to_css_dict, self.id_to_element_dict, _ = build_element_dict(incremental_elements)
self.id_to_css_dict, self.id_to_element_dict, _, _, _ = build_element_dict(incremental_elements)
self.elements = incremental_elements

View File

@@ -120,10 +120,11 @@ class SkyvernElement:
return cls(locator, frame, element_dict)
def __init__(self, locator: Locator, frame: Page | Frame, static_element: dict) -> None:
def __init__(self, locator: Locator, frame: Page | Frame, static_element: dict, hash_value: str = "") -> None:
    """Store the locator, owning frame, static element dict and element hash.

    Args:
        locator: Locator for the element; exposed as ``self.locator``.
        frame: Page or frame the element belongs to; kept private.
        static_element: Element dict for this element; kept private.
        hash_value: Hash of the element; defaults to an empty string when
            the caller does not supply one.
    """
    # Public attributes.
    self.locator = locator
    self.hash_value = hash_value
    # Private (name-mangled) attributes.
    self.__frame = frame
    self.__static_element = static_element
def build_HTML(self, need_trim_element: bool = True, need_skyvern_attrs: bool = True) -> str:
element_dict = self.get_element_dict()
@@ -486,4 +487,6 @@ class DomUtil:
)
raise MultipleElementsFound(num=num_elements, selector=css, element_id=element_id)
return SkyvernElement(locator, frame_content, element)
hash_value = self.scraped_page.id_to_element_hash.get(element_id, "")
return SkyvernElement(locator, frame_content, element, hash_value)