ai_adapt_value for text input (#3354)

This commit is contained in:
Shuchang Zheng
2025-09-03 16:44:52 -07:00
committed by GitHub
parent 32771bdd19
commit 55d847461e
7 changed files with 120 additions and 82 deletions

View File

@@ -126,12 +126,12 @@ def _generate_text_call(text_value: str, intention: str, parameter_key: str) ->
last_line=cst.SimpleWhitespace(DOUBLE_INDENT), last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
), ),
args=[ args=[
# First positional argument: context.generated_parameters['parameter_key'] # First positional argument: context.parameters['parameter_key']
cst.Arg( cst.Arg(
value=cst.Subscript( value=cst.Subscript(
value=cst.Attribute( value=cst.Attribute(
value=cst.Name("context"), value=cst.Name("context"),
attr=cst.Name("generated_parameters"), attr=cst.Name("parameters"),
), ),
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(parameter_key)))], slice=[cst.SubscriptElement(slice=cst.Index(value=_value(parameter_key)))],
), ),
@@ -247,20 +247,21 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.
) )
if method in ["type", "fill"]: if method in ["type", "fill"]:
# Get intention from action # Use context.parameters if field_name is available, otherwise fallback to direct value
intention = act.get("intention") or act.get("reasoning") or ""
# Use generate_text call if field_name is available, otherwise fallback to direct value
if act.get("field_name"): if act.get("field_name"):
text_value = _generate_text_call( text_value = cst.Subscript(
text_value=act["text"], intention=intention, parameter_key=act["field_name"] value=cst.Attribute(
value=cst.Name("context"),
attr=cst.Name("parameters"),
),
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(act["field_name"])))],
) )
else: else:
text_value = _value(act["text"]) text_value = _value(act["text"])
args.append( args.append(
cst.Arg( cst.Arg(
keyword=cst.Name("text"), keyword=cst.Name("value"),
value=text_value, value=text_value,
whitespace_after_arg=cst.ParenthesizedWhitespace( whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True, indent=True,
@@ -268,6 +269,16 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.
), ),
) )
) )
args.append(
cst.Arg(
keyword=cst.Name("ai_adapt_value"),
value=cst.Name("True"),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
),
)
)
elif method == "select_option": elif method == "select_option":
args.append( args.append(
cst.Arg( cst.Arg(

View File

@@ -2,8 +2,7 @@ from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager from skyvern.core.script_generations.skyvern_page import RunContext, SkyvernPage, script_run_context_manager
from skyvern.core.script_generations.skyvern_page import RunContext, SkyvernPage
from skyvern.forge import app from skyvern.forge import app
from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.workflow.models.parameter import WorkflowParameterType from skyvern.forge.sdk.workflow.models.parameter import WorkflowParameterType

View File

@@ -1,34 +0,0 @@
from typing import Callable
from skyvern.core.script_generations.skyvern_page import RunContext
class ScriptRunContextManager:
    """
    Manages the run context for code runs.

    Holds at most one active run context at a time, plus a cache of callables
    that can be looked up by cache key.
    """

    def __init__(self) -> None:
        # The single active run context; None until set_run_context() is called.
        self.run_context: RunContext | None = None
        # Cached callables keyed by cache key.
        self.cached_fns: dict[str, Callable] = {}

    def get_run_context(self) -> RunContext | None:
        """Return the current run context, or None if none has been set."""
        return self.run_context

    def set_run_context(self, run_context: RunContext) -> None:
        """Replace the current run context with ``run_context``."""
        self.run_context = run_context

    def ensure_run_context(self) -> RunContext:
        """Return the current run context, failing loudly if there is none.

        Raises:
            RuntimeError: If no run context has been set. ``RuntimeError`` is
                a subclass of ``Exception``, so existing ``except Exception``
                handlers keep working.
        """
        # Compare against None explicitly: a context object that happens to be
        # falsy has still been set and must not be reported as missing.
        if self.run_context is None:
            raise RuntimeError("Run context not found")
        return self.run_context

    def set_cached_fn(self, cache_key: str, fn: Callable) -> None:
        """Store ``fn`` under ``cache_key``, overwriting any previous entry."""
        self.cached_fns[cache_key] = fn

    def get_cached_fn(self, cache_key: str) -> Callable | None:
        """Return the callable stored under ``cache_key``, or None if absent."""
        return self.cached_fns.get(cache_key)


# Module-level instance that other modules import and use directly.
script_run_context_manager = ScriptRunContextManager()

View File

@@ -8,6 +8,7 @@ from datetime import datetime, timezone
from enum import StrEnum from enum import StrEnum
from typing import Any, Callable, Literal from typing import Any, Callable, Literal
import structlog
from playwright.async_api import Page from playwright.async_api import Page
from skyvern.config import settings from skyvern.config import settings
@@ -24,6 +25,8 @@ from skyvern.webeye.actions.actions import Action, ActionStatus, ExtractAction,
from skyvern.webeye.browser_factory import BrowserState from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.scraper.scraper import ScrapedPage, scrape_website from skyvern.webeye.scraper.scraper import ScrapedPage, scrape_website
LOG = structlog.get_logger()
class Driver(StrEnum): class Driver(StrEnum):
PLAYWRIGHT = "playwright" PLAYWRIGHT = "playwright"
@@ -196,7 +199,8 @@ class SkyvernPage:
# Create action record. TODO: store more action fields # Create action record. TODO: store more action fields
kwargs = kwargs or {} kwargs = kwargs or {}
text = kwargs.get("text") # we're using "value" instead of "text" for input text actions interface
text = kwargs.get("value", "")
option_value = kwargs.get("option") option_value = kwargs.get("option")
select_option = SelectOption(value=option_value) if option_value else None select_option = SelectOption(value=option_value) if option_value else None
response: str | None = kwargs.get("response") response: str | None = kwargs.get("response")
@@ -314,7 +318,7 @@ class SkyvernPage:
current_url=self.page.url, current_url=self.page.url,
elements=element_tree, elements=element_tree,
local_datetime=datetime.now(context.tz_info or datetime.now().astimezone().tzinfo).isoformat(), local_datetime=datetime.now(context.tz_info or datetime.now().astimezone().tzinfo).isoformat(),
user_context=getattr(context, "prompt", None), # user_context=getattr(context, "prompt", None),
) )
json_response = await app.SINGLE_CLICK_AGENT_LLM_API_HANDLER( json_response = await app.SINGLE_CLICK_AGENT_LLM_API_HANDLER(
prompt=single_click_prompt, prompt=single_click_prompt,
@@ -334,28 +338,31 @@ class SkyvernPage:
async def fill( async def fill(
self, self,
xpath: str, xpath: str,
text: str, value: str,
ai_adapt_value: bool = False,
intention: str | None = None, intention: str | None = None,
data: str | dict[str, Any] | None = None, data: str | dict[str, Any] | None = None,
timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS, timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS,
) -> None: ) -> None:
await self._input_text(xpath, text, intention, data, timeout) await self._input_text(xpath, value, ai_adapt_value, intention, data, timeout)
@action_wrap(ActionType.INPUT_TEXT) @action_wrap(ActionType.INPUT_TEXT)
async def type( async def type(
self, self,
xpath: str, xpath: str,
text: str, value: str,
ai_adapt_value: bool = False,
intention: str | None = None, intention: str | None = None,
data: str | dict[str, Any] | None = None, data: str | dict[str, Any] | None = None,
timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS, timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS,
) -> None: ) -> None:
await self._input_text(xpath, text, intention, data, timeout) await self._input_text(xpath, value, ai_adapt_value, intention, data, timeout)
async def _input_text( async def _input_text(
self, self,
xpath: str, xpath: str,
text: str, value: str,
ai_adapt_value: bool = False,
intention: str | None = None, intention: str | None = None,
data: str | dict[str, Any] | None = None, data: str | dict[str, Any] | None = None,
timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS, timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS,
@@ -372,11 +379,33 @@ class SkyvernPage:
""" """
# format the text with the actual value of the parameter if it's a secret when running a workflow # format the text with the actual value of the parameter if it's a secret when running a workflow
context = skyvern_context.current() context = skyvern_context.current()
value = value or ""
if context and context.workflow_run_id: if context and context.workflow_run_id:
text = await _get_actual_value_of_parameter_if_secret(context.workflow_run_id, text) value = await _get_actual_value_of_parameter_if_secret(context.workflow_run_id, value)
if ai_adapt_value and intention:
try:
prompt = context.prompt if context else None
# Build the element tree of the current page for the prompt
# clean up empty data values
data = {k: v for k, v in data.items() if v} if isinstance(data, dict) else (data or "")
payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "")
script_generation_input_text_prompt = prompt_engine.load_prompt(
template="script-generation-input-text-generatiion",
intention=intention,
data=payload_str,
goal=prompt,
)
json_response = await app.SINGLE_INPUT_AGENT_LLM_API_HANDLER(
prompt=script_generation_input_text_prompt,
prompt_name="script-generation-input-text-generatiion",
)
value = json_response.get("answer", value)
except Exception:
LOG.exception(f"Failed to adapt value for input text action on xpath={xpath}, value={value}")
locator = self.page.locator(f"xpath={xpath}") locator = self.page.locator(f"xpath={xpath}")
await handler_utils.input_sequentially(locator, text, timeout=timeout) await handler_utils.input_sequentially(locator, value, timeout=timeout)
@action_wrap(ActionType.UPLOAD_FILE) @action_wrap(ActionType.UPLOAD_FILE)
async def upload_file( async def upload_file(
@@ -542,11 +571,13 @@ class RunContext:
self.original_parameters = parameters self.original_parameters = parameters
self.generated_parameters = generated_parameters self.generated_parameters = generated_parameters
self.parameters = copy.deepcopy(parameters) self.parameters = copy.deepcopy(parameters)
# if generated_parameters: if generated_parameters:
# self.parameters.update(generated_parameters) # hydrate the generated parameter fields in the run context parameters
for key, value in generated_parameters.items():
if key not in self.parameters:
self.parameters[key] = value
self.page = page self.page = page
self.trace: list[ActionCall] = [] self.trace: list[ActionCall] = []
self.prompt: str | None = None
async def _get_actual_value_of_parameter_if_secret(workflow_run_id: str, parameter: str) -> Any: async def _get_actual_value_of_parameter_if_secret(workflow_run_id: str, parameter: str) -> Any:
@@ -560,3 +591,34 @@ async def _get_actual_value_of_parameter_if_secret(workflow_run_id: str, paramet
workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id) workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
secret_value = workflow_run_context.get_original_secret_value_or_none(parameter) secret_value = workflow_run_context.get_original_secret_value_or_none(parameter)
return secret_value if secret_value is not None else parameter return secret_value if secret_value is not None else parameter
class ScriptRunContextManager:
    """
    Manages the run context for code runs.

    Holds at most one active run context at a time, plus a cache of callables
    that can be looked up by cache key.
    """

    def __init__(self) -> None:
        # The single active run context; None until set_run_context() is called.
        self.run_context: RunContext | None = None
        # Cached callables keyed by cache key.
        self.cached_fns: dict[str, Callable] = {}

    def get_run_context(self) -> RunContext | None:
        """Return the current run context, or None if none has been set."""
        return self.run_context

    def set_run_context(self, run_context: RunContext) -> None:
        """Replace the current run context with ``run_context``."""
        self.run_context = run_context

    def ensure_run_context(self) -> RunContext:
        """Return the current run context, failing loudly if there is none.

        Raises:
            RuntimeError: If no run context has been set. ``RuntimeError`` is
                a subclass of ``Exception``, so existing ``except Exception``
                handlers keep working.
        """
        # Compare against None explicitly: a context object that happens to be
        # falsy has still been set and must not be reported as missing.
        if self.run_context is None:
            raise RuntimeError("Run context not found")
        return self.run_context

    def set_cached_fn(self, cache_key: str, fn: Callable) -> None:
        """Store ``fn`` under ``cache_key``, overwriting any previous entry."""
        self.cached_fns[cache_key] = fn

    def get_cached_fn(self, cache_key: str) -> Callable | None:
        """Return the callable stored under ``cache_key``, or None if absent."""
        return self.cached_fns.get(cache_key)


# Module-level instance that other modules import and use directly.
script_run_context_manager = ScriptRunContextManager()

View File

@@ -1,7 +1,6 @@
from typing import Any, Callable from typing import Any, Callable
from skyvern import RunContext, SkyvernPage from skyvern.core.script_generations.skyvern_page import RunContext, SkyvernPage, script_run_context_manager
from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager
# Build a dummy workflow decorator # Build a dummy workflow decorator

View File

@@ -30,6 +30,7 @@ class SkyvernContext:
script_id: str | None = None script_id: str | None = None
script_revision_id: str | None = None script_revision_id: str | None = None
action_order: int = 0 action_order: int = 0
prompt: str | None = None
def __repr__(self) -> str: def __repr__(self) -> str:
return f"SkyvernContext(request_id={self.request_id}, organization_id={self.organization_id}, task_id={self.task_id}, step_id={self.step_id}, workflow_id={self.workflow_id}, workflow_run_id={self.workflow_run_id}, task_v2_id={self.task_v2_id}, max_steps_override={self.max_steps_override}, run_id={self.run_id})" return f"SkyvernContext(request_id={self.request_id}, organization_id={self.organization_id}, task_id={self.task_id}, step_id={self.step_id}, workflow_id={self.workflow_id}, workflow_run_id={self.workflow_run_id}, task_v2_id={self.task_v2_id}, max_steps_override={self.max_steps_override}, run_id={self.run_id})"

View File

@@ -16,7 +16,7 @@ from skyvern.config import settings
from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT
from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
from skyvern.core.script_generations.generate_script import _build_block_fn, create_script_block from skyvern.core.script_generations.generate_script import _build_block_fn, create_script_block
from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager from skyvern.core.script_generations.skyvern_page import script_run_context_manager
from skyvern.exceptions import ScriptNotFound, WorkflowRunNotFound from skyvern.exceptions import ScriptNotFound, WorkflowRunNotFound
from skyvern.forge import app from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine from skyvern.forge.prompts import prompt_engine
@@ -942,8 +942,8 @@ async def run_task(
url=url, url=url,
) )
# set the prompt in the RunContext # set the prompt in the RunContext
run_context = script_run_context_manager.ensure_run_context() context = skyvern_context.ensure_context()
run_context.prompt = prompt context.prompt = prompt
if cache_key: if cache_key:
try: try:
@@ -972,7 +972,7 @@ async def run_task(
) )
finally: finally:
# clear the prompt in the RunContext # clear the prompt in the RunContext
run_context.prompt = None context.prompt = None
else: else:
if workflow_run_block_id: if workflow_run_block_id:
await _update_workflow_block( await _update_workflow_block(
@@ -984,7 +984,7 @@ async def run_task(
step_status=StepStatus.failed, step_status=StepStatus.failed,
failure_reason="Cache key is required", failure_reason="Cache key is required",
) )
run_context.prompt = None context.prompt = None
raise Exception("Cache key is required to run task block in a script") raise Exception("Cache key is required to run task block in a script")
@@ -1001,8 +1001,8 @@ async def download(
url=url, url=url,
) )
# set the prompt in the RunContext # set the prompt in the RunContext
run_context = script_run_context_manager.ensure_run_context() context = skyvern_context.ensure_context()
run_context.prompt = prompt context.prompt = prompt
if cache_key: if cache_key:
try: try:
@@ -1031,7 +1031,7 @@ async def download(
workflow_run_block_id=workflow_run_block_id, workflow_run_block_id=workflow_run_block_id,
) )
finally: finally:
run_context.prompt = None context.prompt = None
else: else:
if workflow_run_block_id: if workflow_run_block_id:
await _update_workflow_block( await _update_workflow_block(
@@ -1043,7 +1043,7 @@ async def download(
step_status=StepStatus.failed, step_status=StepStatus.failed,
failure_reason="Cache key is required", failure_reason="Cache key is required",
) )
run_context.prompt = None context.prompt = None
raise Exception("Cache key is required to run task block in a script") raise Exception("Cache key is required to run task block in a script")
@@ -1060,8 +1060,8 @@ async def action(
url=url, url=url,
) )
# set the prompt in the RunContext # set the prompt in the RunContext
run_context = script_run_context_manager.ensure_run_context() context = skyvern_context.ensure_context()
run_context.prompt = prompt context.prompt = prompt
if cache_key: if cache_key:
try: try:
@@ -1089,7 +1089,7 @@ async def action(
workflow_run_block_id=workflow_run_block_id, workflow_run_block_id=workflow_run_block_id,
) )
finally: finally:
run_context.prompt = None context.prompt = None
else: else:
if workflow_run_block_id: if workflow_run_block_id:
await _update_workflow_block( await _update_workflow_block(
@@ -1101,7 +1101,7 @@ async def action(
step_status=StepStatus.failed, step_status=StepStatus.failed,
failure_reason="Cache key is required", failure_reason="Cache key is required",
) )
run_context.prompt = None context.prompt = None
raise Exception("Cache key is required to run task block in a script") raise Exception("Cache key is required to run task block in a script")
@@ -1118,8 +1118,8 @@ async def login(
url=url, url=url,
) )
# set the prompt in the RunContext # set the prompt in the RunContext
run_context = script_run_context_manager.ensure_run_context() context = skyvern_context.ensure_context()
run_context.prompt = prompt context.prompt = prompt
if cache_key: if cache_key:
try: try:
@@ -1147,7 +1147,7 @@ async def login(
workflow_run_block_id=workflow_run_block_id, workflow_run_block_id=workflow_run_block_id,
) )
finally: finally:
run_context.prompt = None context.prompt = None
else: else:
if workflow_run_block_id: if workflow_run_block_id:
await _update_workflow_block( await _update_workflow_block(
@@ -1159,7 +1159,7 @@ async def login(
step_status=StepStatus.failed, step_status=StepStatus.failed,
failure_reason="Cache key is required", failure_reason="Cache key is required",
) )
run_context.prompt = None context.prompt = None
raise Exception("Cache key is required to run task block in a script") raise Exception("Cache key is required to run task block in a script")
@@ -1178,8 +1178,8 @@ async def extract(
url=url, url=url,
) )
# set the prompt in the RunContext # set the prompt in the RunContext
run_context = script_run_context_manager.ensure_run_context() context = skyvern_context.ensure_context()
run_context.prompt = prompt context.prompt = prompt
output: dict[str, Any] | list | str | None = None output: dict[str, Any] | list | str | None = None
if cache_key: if cache_key:
@@ -1213,7 +1213,7 @@ async def extract(
) )
raise raise
finally: finally:
run_context.prompt = None context.prompt = None
else: else:
if workflow_run_block_id: if workflow_run_block_id:
await _update_workflow_block( await _update_workflow_block(
@@ -1225,7 +1225,7 @@ async def extract(
step_status=StepStatus.failed, step_status=StepStatus.failed,
failure_reason="Cache key is required", failure_reason="Cache key is required",
) )
run_context.prompt = None context.prompt = None
raise Exception("Cache key is required to run task block in a script") raise Exception("Cache key is required to run task block in a script")
@@ -1296,8 +1296,8 @@ async def generate_text(
new_text = text or "" new_text = text or ""
if intention and data: if intention and data:
try: try:
run_context = script_run_context_manager.ensure_run_context() context = skyvern_context.ensure_context()
prompt = run_context.prompt prompt = context.prompt
# Build the element tree of the current page for the prompt # Build the element tree of the current page for the prompt
payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "") payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "")
script_generation_input_text_prompt = prompt_engine.load_prompt( script_generation_input_text_prompt = prompt_engine.load_prompt(