trim svg elements when prompt exceeds context window (#2106)
This commit is contained in:
2
poetry.lock
generated
2
poetry.lock
generated
@@ -6521,4 +6521,4 @@ type = ["pytest-mypy"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.11,<3.12"
|
python-versions = "^3.11,<3.12"
|
||||||
content-hash = "b43cb55e0c18ac83f0e32444132fd7618ef5b8355b0a90dbed55599d068c2892"
|
content-hash = "84b211a2b313b852996823fc4105d809b990e34cecd400c61d541561c010afdf"
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ json-repair = "^0.34.0"
|
|||||||
pypdf = "^5.1.0"
|
pypdf = "^5.1.0"
|
||||||
fastmcp = "^0.4.1"
|
fastmcp = "^0.4.1"
|
||||||
psutil = ">=7.0.0"
|
psutil = ">=7.0.0"
|
||||||
|
tiktoken = ">=0.9.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
isort = "^5.13.2"
|
isort = "^5.13.2"
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, Tas
|
|||||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
||||||
from skyvern.forge.sdk.workflow.models.block import ActionBlock, BaseTaskBlock, ValidationBlock
|
from skyvern.forge.sdk.workflow.models.block import ActionBlock, BaseTaskBlock, ValidationBlock
|
||||||
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus
|
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus
|
||||||
|
from skyvern.utils.prompt_engine import load_prompt_with_elements
|
||||||
from skyvern.webeye.actions.actions import (
|
from skyvern.webeye.actions.actions import (
|
||||||
Action,
|
Action,
|
||||||
ActionStatus,
|
ActionStatus,
|
||||||
@@ -1196,11 +1197,12 @@ class ForgeAgent:
|
|||||||
)
|
)
|
||||||
scraped_page_refreshed = await scraped_page.refresh(draw_boxes=False)
|
scraped_page_refreshed = await scraped_page.refresh(draw_boxes=False)
|
||||||
|
|
||||||
verification_prompt = prompt_engine.load_prompt(
|
verification_prompt = load_prompt_with_elements(
|
||||||
"check-user-goal",
|
scraped_page=scraped_page_refreshed,
|
||||||
|
prompt_engine=prompt_engine,
|
||||||
|
template_name="check-user-goal",
|
||||||
navigation_goal=task.navigation_goal,
|
navigation_goal=task.navigation_goal,
|
||||||
navigation_payload=task.navigation_payload,
|
navigation_payload=task.navigation_payload,
|
||||||
elements=scraped_page_refreshed.build_element_tree(ElementTreeFormat.HTML),
|
|
||||||
complete_criterion=task.complete_criterion,
|
complete_criterion=task.complete_criterion,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1432,7 +1434,7 @@ class ForgeAgent:
|
|||||||
task,
|
task,
|
||||||
step,
|
step,
|
||||||
browser_state,
|
browser_state,
|
||||||
element_tree_in_prompt,
|
scraped_page,
|
||||||
verification_code_check=bool(task.totp_verification_url or task.totp_identifier),
|
verification_code_check=bool(task.totp_verification_url or task.totp_identifier),
|
||||||
expire_verification_code=True,
|
expire_verification_code=True,
|
||||||
)
|
)
|
||||||
@@ -1470,7 +1472,7 @@ class ForgeAgent:
|
|||||||
task: Task,
|
task: Task,
|
||||||
step: Step,
|
step: Step,
|
||||||
browser_state: BrowserState,
|
browser_state: BrowserState,
|
||||||
element_tree_in_prompt: str,
|
scraped_page: ScrapedPage,
|
||||||
verification_code_check: bool = False,
|
verification_code_check: bool = False,
|
||||||
expire_verification_code: bool = False,
|
expire_verification_code: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
@@ -1525,13 +1527,14 @@ class ForgeAgent:
|
|||||||
raise UnsupportedTaskType(task_type=task_type)
|
raise UnsupportedTaskType(task_type=task_type)
|
||||||
|
|
||||||
context = skyvern_context.ensure_context()
|
context = skyvern_context.ensure_context()
|
||||||
return prompt_engine.load_prompt(
|
return load_prompt_with_elements(
|
||||||
template=template,
|
scraped_page=scraped_page,
|
||||||
|
prompt_engine=prompt_engine,
|
||||||
|
template_name=template,
|
||||||
navigation_goal=navigation_goal,
|
navigation_goal=navigation_goal,
|
||||||
navigation_payload_str=json.dumps(final_navigation_payload),
|
navigation_payload_str=json.dumps(final_navigation_payload),
|
||||||
starting_url=starting_url,
|
starting_url=starting_url,
|
||||||
current_url=current_url,
|
current_url=current_url,
|
||||||
elements=element_tree_in_prompt,
|
|
||||||
data_extraction_goal=task.data_extraction_goal,
|
data_extraction_goal=task.data_extraction_goal,
|
||||||
action_history=actions_and_results_str,
|
action_history=actions_and_results_str,
|
||||||
error_code_mapping_str=(json.dumps(task.error_code_mapping) if task.error_code_mapping else None),
|
error_code_mapping_str=(json.dumps(task.error_code_mapping) if task.error_code_mapping else None),
|
||||||
@@ -2300,12 +2303,11 @@ class ForgeAgent:
|
|||||||
current_context = skyvern_context.ensure_context()
|
current_context = skyvern_context.ensure_context()
|
||||||
current_context.totp_codes[task.task_id] = verification_code
|
current_context.totp_codes[task.task_id] = verification_code
|
||||||
|
|
||||||
element_tree_in_prompt: str = scraped_page.build_element_tree(ElementTreeFormat.HTML)
|
|
||||||
extract_action_prompt = await self._build_extract_action_prompt(
|
extract_action_prompt = await self._build_extract_action_prompt(
|
||||||
task,
|
task,
|
||||||
step,
|
step,
|
||||||
browser_state,
|
browser_state,
|
||||||
element_tree_in_prompt,
|
scraped_page,
|
||||||
verification_code_check=False,
|
verification_code_check=False,
|
||||||
expire_verification_code=True,
|
expire_verification_code=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -139,7 +139,9 @@ async def _convert_svg_to_string(
|
|||||||
|
|
||||||
skyvern_element = SkyvernElement(locator=locater, frame=skyvern_frame.get_frame(), static_element=element)
|
skyvern_element = SkyvernElement(locator=locater, frame=skyvern_frame.get_frame(), static_element=element)
|
||||||
|
|
||||||
_, blocked = await skyvern_frame.get_blocking_element_id(await skyvern_element.get_element_handler())
|
_, blocked = await skyvern_frame.get_blocking_element_id(
|
||||||
|
await skyvern_element.get_element_handler(timeout=1000)
|
||||||
|
)
|
||||||
if not skyvern_element.is_interactable() and blocked:
|
if not skyvern_element.is_interactable() and blocked:
|
||||||
_mark_element_as_dropped(element)
|
_mark_element_as_dropped(element)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ from skyvern.forge.sdk.workflow.models.yaml import (
|
|||||||
WorkflowDefinitionYAML,
|
WorkflowDefinitionYAML,
|
||||||
)
|
)
|
||||||
from skyvern.schemas.runs import ProxyLocation, RunType
|
from skyvern.schemas.runs import ProxyLocation, RunType
|
||||||
|
from skyvern.utils.prompt_engine import load_prompt_with_elements
|
||||||
from skyvern.webeye.browser_factory import BrowserState
|
from skyvern.webeye.browser_factory import BrowserState
|
||||||
from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website
|
from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website
|
||||||
from skyvern.webeye.utils.page import SkyvernFrame
|
from skyvern.webeye.utils.page import SkyvernFrame
|
||||||
@@ -462,10 +463,11 @@ async def run_task_v2_helper(
|
|||||||
continue
|
continue
|
||||||
current_url = current_url if current_url else str(await SkyvernFrame.get_url(frame=page) if page else url)
|
current_url = current_url if current_url else str(await SkyvernFrame.get_url(frame=page) if page else url)
|
||||||
|
|
||||||
task_v2_prompt = prompt_engine.load_prompt(
|
task_v2_prompt = load_prompt_with_elements(
|
||||||
|
scraped_page,
|
||||||
|
prompt_engine,
|
||||||
"task_v2",
|
"task_v2",
|
||||||
current_url=current_url,
|
current_url=current_url,
|
||||||
elements=element_tree_in_prompt,
|
|
||||||
user_goal=user_prompt,
|
user_goal=user_prompt,
|
||||||
task_history=task_history,
|
task_history=task_history,
|
||||||
local_datetime=datetime.now(context.tz_info).isoformat(),
|
local_datetime=datetime.now(context.tz_info).isoformat(),
|
||||||
|
|||||||
47
skyvern/utils/prompt_engine.py
Normal file
47
skyvern/utils/prompt_engine.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.prompting import PromptEngine
|
||||||
|
from skyvern.utils.token_counter import count_tokens
|
||||||
|
from skyvern.webeye.scraper.scraper import ScrapedPage
|
||||||
|
|
||||||
|
DEFAULT_MAX_TOKENS = 100000
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def load_prompt_with_elements(
    scraped_page: ScrapedPage,
    prompt_engine: PromptEngine,
    template_name: str,
    **kwargs: Any,
) -> str:
    """
    Render ``template_name`` with the page's elements, shrinking the element tree when the
    rendered prompt is over DEFAULT_MAX_TOKENS.

    Up to three renderings are tried, in order, and the first one within the limit is returned:
      1. the tree from ``scraped_page.build_element_tree()``
      2. the economy tree (no secondary elements such as SVG)
      3. the first 2/3 of the economy tree

    The third rendering is returned as-is; it is not re-checked against the limit, so the
    result can still be over DEFAULT_MAX_TOKENS.

    :param scraped_page: page whose element tree is injected as the ``elements`` template variable
    :param prompt_engine: engine used to render the template
    :param template_name: name of the template to render
    :param kwargs: remaining template variables, forwarded unchanged to ``load_prompt``
    :return: the rendered prompt
    """
    full_prompt = prompt_engine.load_prompt(template_name, elements=scraped_page.build_element_tree(), **kwargs)
    token_count = count_tokens(full_prompt)
    if token_count <= DEFAULT_MAX_TOKENS:
        return full_prompt

    # First fallback: get rid of all the secondary elements like SVG, etc
    economy_prompt = prompt_engine.load_prompt(
        template_name,
        elements=scraped_page.build_economy_elements_tree(),
        **kwargs,
    )
    economy_token_count = count_tokens(economy_prompt)
    LOG.warning(
        "Prompt is longer than the max tokens. Going to use the economy elements tree.",
        template_name=template_name,
        token_count=token_count,
        economy_token_count=economy_token_count,
        max_tokens=DEFAULT_MAX_TOKENS,
    )
    if economy_token_count <= DEFAULT_MAX_TOKENS:
        return economy_prompt

    # HACK: second fallback drops the last 1/3 of the html context and keeps the first 2/3
    truncated_prompt = prompt_engine.load_prompt(
        template_name,
        elements=scraped_page.build_economy_elements_tree(percent_to_keep=2 / 3),
        **kwargs,
    )
    token_count_after_dump = count_tokens(truncated_prompt)
    LOG.warning(
        "Prompt is still longer than the max tokens. Will only keep the first 2/3 of the html context.",
        template_name=template_name,
        token_count=token_count,
        economy_token_count=economy_token_count,
        token_count_after_dump=token_count_after_dump,
        max_tokens=DEFAULT_MAX_TOKENS,
    )
    return truncated_prompt
|
||||||
5
skyvern/utils/token_counter.py
Normal file
5
skyvern/utils/token_counter.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
import tiktoken
|
||||||
|
|
||||||
|
|
||||||
|
def count_tokens(text: str) -> int:
    """
    Return the number of tokens in ``text`` under the encoding tiktoken selects for "gpt-4o".

    :param text: the text to tokenize (e.g. a fully rendered prompt)
    :return: the length of the encoded token list
    """
    encoding = tiktoken.encoding_for_model("gpt-4o")
    # With tiktoken's defaults, encode() raises ValueError when the text contains a
    # special-token marker such as "<|endoftext|>". Counted text can include arbitrary
    # page content, so treat such markers as ordinary text instead of failing the count.
    return len(encoding.encode(text, disallowed_special=()))
|
||||||
@@ -229,6 +229,7 @@ class ScrapedPage(BaseModel):
|
|||||||
hash_to_element_ids: dict[str, list[str]]
|
hash_to_element_ids: dict[str, list[str]]
|
||||||
element_tree: list[dict]
|
element_tree: list[dict]
|
||||||
element_tree_trimmed: list[dict]
|
element_tree_trimmed: list[dict]
|
||||||
|
economy_element_tree: list[dict] | None = None
|
||||||
screenshots: list[bytes]
|
screenshots: list[bytes]
|
||||||
url: str
|
url: str
|
||||||
html: str
|
html: str
|
||||||
@@ -268,6 +269,58 @@ class ScrapedPage(BaseModel):
|
|||||||
|
|
||||||
raise UnknownElementTreeFormat(fmt=fmt)
|
raise UnknownElementTreeFormat(fmt=fmt)
|
||||||
|
|
||||||
|
def build_economy_elements_tree(
    self,
    fmt: ElementTreeFormat = ElementTreeFormat.HTML,
    html_need_skyvern_attrs: bool = True,
    percent_to_keep: float = 1,
) -> str:
    """
    Serialize the economy elements tree, i.e. the trimmed element tree without secondary
    elements like SVG.

    The economy tree is built once from ``self.element_tree_trimmed`` and cached on
    ``self.economy_element_tree``; later calls reuse the cached list.

    :param fmt: output format, JSON or HTML
    :param html_need_skyvern_attrs: forwarded to ``json_to_html`` as ``need_skyvern_attrs`` (HTML only)
    :param percent_to_keep: fraction of the top-level elements to keep, counted from the start
        of the list. Values outside [0, 1] are clamped. Truncation is by number of root
        elements, not by serialized size, so a tree with a single root keeps nothing for any
        value below 1.
    :return: the serialized tree
    :raises UnknownElementTreeFormat: if ``fmt`` is neither JSON nor HTML
    """
    # Compare against None rather than testing truthiness: an economy tree that was already
    # computed and turned out empty is a valid cached result and must not be rebuilt.
    if self.economy_element_tree is None:
        # Work on a deep copy because _process_element_for_economy_tree rewrites each
        # element's "children" list in place.
        copied_element_tree_trimmed = copy.deepcopy(self.element_tree_trimmed)
        processed_roots = (
            self._process_element_for_economy_tree(root_element) for root_element in copied_element_tree_trimmed
        )
        self.economy_element_tree = [root for root in processed_roots if root]

    # Clamp so a negative value cannot become a negative slice bound (which would keep an
    # arbitrary prefix instead of nothing).
    keep_ratio = min(max(percent_to_keep, 0), 1)
    keep_count = int(len(self.economy_element_tree) * keep_ratio)
    final_element_tree = self.economy_element_tree[:keep_count]

    if fmt == ElementTreeFormat.JSON:
        return json.dumps(final_element_tree)

    if fmt == ElementTreeFormat.HTML:
        return "".join(
            json_to_html(element, need_skyvern_attrs=html_need_skyvern_attrs) for element in final_element_tree
        )

    raise UnknownElementTreeFormat(fmt=fmt)
|
||||||
|
|
||||||
|
def _process_element_for_economy_tree(self, element: dict) -> dict | None:
|
||||||
|
"""
|
||||||
|
Helper method to process an element for the economy tree using BFS.
|
||||||
|
Removes SVG elements and their children.
|
||||||
|
"""
|
||||||
|
# Skip SVG elements entirely
|
||||||
|
if element.get("tagName", "").lower() == "svg":
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Process children using BFS
|
||||||
|
if "children" in element:
|
||||||
|
new_children = []
|
||||||
|
for child in element["children"]:
|
||||||
|
processed_child = self._process_element_for_economy_tree(child)
|
||||||
|
if processed_child:
|
||||||
|
new_children.append(processed_child)
|
||||||
|
element["children"] = new_children
|
||||||
|
return element
|
||||||
|
|
||||||
async def refresh(self, draw_boxes: bool = True) -> Self:
|
async def refresh(self, draw_boxes: bool = True) -> Self:
|
||||||
refreshed_page = await scrape_website(
|
refreshed_page = await scrape_website(
|
||||||
browser_state=self._browser_state,
|
browser_state=self._browser_state,
|
||||||
|
|||||||
Reference in New Issue
Block a user