Extract shared core from MCP tools, add CLI browser commands (#4768)

This commit is contained in:
Marc Kelechava
2026-02-17 11:24:56 -08:00
committed by GitHub
parent aacc612365
commit 7c5be8fefe
14 changed files with 1304 additions and 113 deletions

View File

@@ -1,23 +1,51 @@
import logging
import typer import typer
from dotenv import load_dotenv from dotenv import load_dotenv
from skyvern.forge.sdk.forge_log import setup_logger as _setup_logger
from skyvern.utils.env_paths import resolve_backend_env_path from skyvern.utils.env_paths import resolve_backend_env_path
from .credentials import credentials_app from ..credentials import credentials_app
from .docs import docs_app from ..docs import docs_app
from .init_command import init_browser, init_env from ..init_command import init_browser, init_env
from .quickstart import quickstart_app from ..quickstart import quickstart_app
from .run_commands import run_app from ..run_commands import run_app
from .status import status_app from ..status import status_app
from .stop_commands import stop_app from ..stop_commands import stop_app
from .tasks import tasks_app from ..tasks import tasks_app
from .workflow import workflow_app from ..workflow import workflow_app
from .browser import browser_app
_cli_logging_configured = False
def configure_cli_logging() -> None:
"""Configure CLI log levels once at runtime (not at import time)."""
global _cli_logging_configured
if _cli_logging_configured:
return
_cli_logging_configured = True
# Suppress noisy SDK/third-party logs for CLI execution only.
for logger_name in ("skyvern", "httpx", "litellm", "playwright", "httpcore"):
logging.getLogger(logger_name).setLevel(logging.WARNING)
_setup_logger()
cli_app = typer.Typer( cli_app = typer.Typer(
help=("""[bold]Skyvern CLI[/bold]\nManage and run your local Skyvern environment."""), help=("""[bold]Skyvern CLI[/bold]\nManage and run your local Skyvern environment."""),
no_args_is_help=True, no_args_is_help=True,
rich_markup_mode="rich", rich_markup_mode="rich",
) )
@cli_app.callback()
def cli_callback() -> None:
"""Configure CLI logging before command execution."""
configure_cli_logging()
cli_app.add_typer( cli_app.add_typer(
run_app, run_app,
name="run", name="run",
@@ -40,6 +68,9 @@ cli_app.add_typer(
quickstart_app, name="quickstart", help="One-command setup and start for Skyvern (combines init and run)." quickstart_app, name="quickstart", help="One-command setup and start for Skyvern (combines init and run)."
) )
# Browser automation commands
cli_app.add_typer(browser_app, name="browser", help="Browser automation commands.")
@init_app.callback() @init_app.callback()
def init_callback( def init_callback(

View File

@@ -0,0 +1,9 @@
from dotenv import load_dotenv
from skyvern.utils.env_paths import resolve_backend_env_path
from . import cli_app
if __name__ == "__main__": # pragma: no cover - manual CLI invocation
load_dotenv(resolve_backend_env_path())
cli_app()

View File

@@ -0,0 +1,52 @@
from __future__ import annotations
import json
import sys
from typing import Any
from rich.console import Console
from rich.table import Table
console = Console()
def output(
data: Any,
*,
action: str = "",
json_mode: bool = False,
) -> None:
if json_mode:
envelope: dict[str, Any] = {"ok": True, "action": action, "data": data, "error": None}
json.dump(envelope, sys.stdout, indent=2, default=str)
sys.stdout.write("\n")
return
if isinstance(data, list) and data and isinstance(data[0], dict):
table = Table()
for key in data[0]:
table.add_column(key.replace("_", " ").title())
for row in data:
table.add_row(*[str(v) for v in row.values()])
console.print(table)
elif isinstance(data, dict):
for key, value in data.items():
console.print(f"[bold]{key}:[/bold] {value}")
else:
console.print(str(data))
def output_error(message: str, *, hint: str = "", json_mode: bool = False, exit_code: int = 1) -> None:
if json_mode:
envelope: dict[str, Any] = {
"ok": False,
"action": "",
"data": None,
"error": {"message": message, "hint": hint},
}
json.dump(envelope, sys.stdout, indent=2, default=str)
sys.stdout.write("\n")
raise SystemExit(exit_code)
console.print(f"[red]Error: {message}[/red]")
if hint:
console.print(f"[yellow]Hint: {hint}[/yellow]")
raise SystemExit(exit_code)

View File

@@ -0,0 +1,48 @@
from __future__ import annotations
import json
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from pathlib import Path
STATE_DIR = Path.home() / ".skyvern"
STATE_FILE = STATE_DIR / "state.json"
_TTL_SECONDS = 86400 # 24 hours
@dataclass
class CLIState:
session_id: str | None = None
cdp_url: str | None = None
mode: str | None = None # "cloud", "local", or "cdp"
created_at: str | None = None
def save_state(state: CLIState) -> None:
STATE_DIR.mkdir(parents=True, exist_ok=True)
STATE_DIR.chmod(0o700)
data = asdict(state)
data["created_at"] = datetime.now(timezone.utc).isoformat()
STATE_FILE.write_text(json.dumps(data))
STATE_FILE.chmod(0o600)
def load_state() -> CLIState | None:
if not STATE_FILE.exists():
return None
try:
data = json.loads(STATE_FILE.read_text())
created_at = data.get("created_at")
if created_at:
age = (datetime.now(timezone.utc) - datetime.fromisoformat(created_at)).total_seconds()
if age > _TTL_SECONDS:
return None
return CLIState(**{k: v for k, v in data.items() if k in CLIState.__dataclass_fields__})
except Exception:
return None
def clear_state() -> None:
if STATE_FILE.exists():
STATE_FILE.unlink()

View File

@@ -0,0 +1,325 @@
from __future__ import annotations
import asyncio
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Literal
import typer
from skyvern.cli.commands._output import output, output_error
from skyvern.cli.commands._state import CLIState, clear_state, load_state, save_state
from skyvern.cli.core.artifacts import save_artifact
from skyvern.cli.core.browser_ops import do_act, do_extract, do_navigate, do_screenshot
from skyvern.cli.core.client import get_skyvern
from skyvern.cli.core.guards import GuardError, check_password_prompt, validate_wait_until
from skyvern.cli.core.session_ops import do_session_close, do_session_create, do_session_list
browser_app = typer.Typer(help="Browser automation commands.", no_args_is_help=True)
session_app = typer.Typer(help="Manage browser sessions.", no_args_is_help=True)
browser_app.add_typer(session_app, name="session")
@dataclass(frozen=True)
class ConnectionTarget:
mode: Literal["cloud", "cdp"]
session_id: str | None = None
cdp_url: str | None = None
def _resolve_connection(session: str | None, cdp: str | None) -> ConnectionTarget:
if session and cdp:
raise typer.BadParameter("Pass only one of --session or --cdp.")
if session:
return ConnectionTarget(mode="cloud", session_id=session)
if cdp:
return ConnectionTarget(mode="cdp", cdp_url=cdp)
state = load_state()
if state:
if state.mode == "cdp" and state.cdp_url:
return ConnectionTarget(mode="cdp", cdp_url=state.cdp_url)
if state.session_id:
return ConnectionTarget(mode="cloud", session_id=state.session_id)
if state.cdp_url:
return ConnectionTarget(mode="cdp", cdp_url=state.cdp_url)
raise typer.BadParameter(
"No active browser connection. Create one with: skyvern browser session create\n"
"Or connect with: skyvern browser session connect --cdp ws://...\n"
"Or specify: --session pbs_... / --cdp ws://..."
)
async def _connect_browser(connection: ConnectionTarget) -> Any:
skyvern = get_skyvern()
if connection.mode == "cloud":
if not connection.session_id:
raise typer.BadParameter("Cloud mode requires --session or an active cloud session in state.")
return await skyvern.connect_to_cloud_browser_session(connection.session_id)
if not connection.cdp_url:
raise typer.BadParameter("CDP mode requires --cdp or an active CDP URL in state.")
return await skyvern.connect_to_browser_over_cdp(connection.cdp_url)
# ---------------------------------------------------------------------------
# Session commands
# ---------------------------------------------------------------------------
@session_app.command("create")
def session_create(
timeout: int = typer.Option(60, help="Session timeout in minutes."),
proxy: str | None = typer.Option(None, help="Proxy location (e.g. RESIDENTIAL)."),
local: bool = typer.Option(False, "--local", help="Launch a local browser instead of cloud."),
headless: bool = typer.Option(False, "--headless", help="Run local browser headless."),
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
"""Create a new browser session."""
if local:
output_error(
"Local browser sessions are not yet supported in CLI mode.",
hint="Use MCP (skyvern run mcp) for local browser sessions, or omit --local for cloud sessions.",
json_mode=json_output,
)
async def _run() -> dict:
skyvern = get_skyvern()
_browser, result = await do_session_create(
skyvern,
timeout=timeout,
proxy_location=proxy,
)
save_state(CLIState(session_id=result.session_id, cdp_url=None, mode="cloud"))
return {
"session_id": result.session_id,
"mode": "cloud",
"timeout_minutes": result.timeout_minutes,
}
try:
data = asyncio.run(_run())
output(data, action="session_create", json_mode=json_output)
except GuardError as e:
output_error(str(e), hint=e.hint, json_mode=json_output)
except typer.BadParameter:
raise
except Exception as e:
output_error(str(e), hint="Check your API key and network connection.", json_mode=json_output)
@session_app.command("close")
def session_close(
session: str | None = typer.Option(None, help="Browser session ID to close."),
cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL to detach from."),
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
"""Close a browser session."""
async def _run() -> dict:
connection = _resolve_connection(session, cdp)
if connection.mode == "cdp":
clear_state()
return {"cdp_url": connection.cdp_url, "closed": False, "detached": True}
if not connection.session_id:
raise typer.BadParameter("Cloud mode requires a browser session ID.")
skyvern = get_skyvern()
result = await do_session_close(skyvern, connection.session_id)
clear_state()
return {"session_id": result.session_id, "closed": result.closed}
try:
data = asyncio.run(_run())
output(data, action="session_close", json_mode=json_output)
except typer.BadParameter:
raise
except Exception as e:
output_error(str(e), hint="Verify the session ID or CDP URL is correct.", json_mode=json_output)
@session_app.command("connect")
def session_connect(
session: str | None = typer.Option(None, help="Cloud browser session ID."),
cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."),
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
"""Connect to an existing browser session (cloud or CDP) and persist it as active state."""
if not session and not cdp:
raise typer.BadParameter("Specify one of --session or --cdp.")
async def _run() -> dict:
connection = _resolve_connection(session, cdp)
browser = await _connect_browser(connection)
await browser.get_working_page()
if connection.mode == "cdp":
save_state(CLIState(session_id=None, cdp_url=connection.cdp_url, mode="cdp"))
return {"connected": True, "mode": "cdp", "cdp_url": connection.cdp_url}
save_state(CLIState(session_id=connection.session_id, cdp_url=None, mode="cloud"))
return {"connected": True, "mode": "cloud", "session_id": connection.session_id}
try:
data = asyncio.run(_run())
output(data, action="session_connect", json_mode=json_output)
except typer.BadParameter:
raise
except Exception as e:
output_error(str(e), hint="Verify the session ID or CDP URL is reachable.", json_mode=json_output)
@session_app.command("list")
def session_list(
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
"""List all browser sessions."""
async def _run() -> list[dict]:
skyvern = get_skyvern()
sessions = await do_session_list(skyvern)
return [asdict(s) for s in sessions]
try:
data = asyncio.run(_run())
output(data, action="session_list", json_mode=json_output)
except Exception as e:
output_error(str(e), hint="Check your API key and network connection.", json_mode=json_output)
# ---------------------------------------------------------------------------
# Browser commands
# ---------------------------------------------------------------------------
@browser_app.command("navigate")
def navigate(
url: str = typer.Option(..., help="URL to navigate to."),
session: str | None = typer.Option(None, help="Browser session ID."),
cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."),
timeout: int = typer.Option(30000, help="Navigation timeout in milliseconds."),
wait_until: str | None = typer.Option(None, help="Wait condition: load, domcontentloaded, networkidle, commit."),
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
"""Navigate to a URL in the browser session."""
async def _run() -> dict:
validate_wait_until(wait_until)
connection = _resolve_connection(session, cdp)
browser = await _connect_browser(connection)
page = await browser.get_working_page()
result = await do_navigate(page, url, timeout=timeout, wait_until=wait_until)
return {"url": result.url, "title": result.title}
try:
data = asyncio.run(_run())
output(data, action="navigate", json_mode=json_output)
except GuardError as e:
output_error(str(e), hint=e.hint, json_mode=json_output)
except typer.BadParameter:
raise
except Exception as e:
output_error(str(e), hint="Check the URL is valid and the session is active.", json_mode=json_output)
@browser_app.command("screenshot")
def screenshot(
session: str | None = typer.Option(None, help="Browser session ID."),
cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."),
full_page: bool = typer.Option(False, "--full-page", help="Capture the full scrollable page."),
selector: str | None = typer.Option(None, help="CSS selector to screenshot."),
output_path: str | None = typer.Option(None, "--output", help="Custom output file path."),
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
"""Take a screenshot of the current page."""
async def _run() -> dict:
connection = _resolve_connection(session, cdp)
browser = await _connect_browser(connection)
page = await browser.get_working_page()
result = await do_screenshot(page, full_page=full_page, selector=selector)
if output_path:
path = Path(output_path)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(result.data)
return {"path": str(path), "bytes": len(result.data), "full_page": result.full_page}
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
artifact = save_artifact(
content=result.data,
kind="screenshot",
filename=f"screenshot_{timestamp}.png",
mime="image/png",
session_id=connection.session_id,
)
return {"path": artifact.path, "bytes": artifact.bytes, "full_page": result.full_page}
try:
data = asyncio.run(_run())
output(data, action="screenshot", json_mode=json_output)
except GuardError as e:
output_error(str(e), hint=e.hint, json_mode=json_output)
except typer.BadParameter:
raise
except Exception as e:
output_error(str(e), hint="Ensure the session is active and the page has loaded.", json_mode=json_output)
@browser_app.command("act")
def act(
prompt: str = typer.Option(..., help="Natural language action to perform."),
session: str | None = typer.Option(None, help="Browser session ID."),
cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."),
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
"""Perform a natural language action on the current page."""
async def _run() -> dict:
check_password_prompt(prompt)
connection = _resolve_connection(session, cdp)
browser = await _connect_browser(connection)
page = await browser.get_working_page()
result = await do_act(page, prompt)
return {"prompt": result.prompt, "completed": result.completed}
try:
data = asyncio.run(_run())
output(data, action="act", json_mode=json_output)
except GuardError as e:
output_error(str(e), hint=e.hint, json_mode=json_output)
except typer.BadParameter:
raise
except Exception as e:
output_error(str(e), hint="Simplify the prompt or break into steps.", json_mode=json_output)
@browser_app.command("extract")
def extract(
prompt: str = typer.Option(..., help="What data to extract from the page."),
session: str | None = typer.Option(None, help="Browser session ID."),
cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."),
schema: str | None = typer.Option(None, help="JSON schema for structured extraction."),
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
"""Extract data from the current page using natural language."""
async def _run() -> dict:
connection = _resolve_connection(session, cdp)
browser = await _connect_browser(connection)
page = await browser.get_working_page()
result = await do_extract(page, prompt, schema=schema)
return {"prompt": prompt, "extracted": result.extracted}
try:
data = asyncio.run(_run())
output(data, action="extract", json_mode=json_output)
except GuardError as e:
output_error(str(e), hint=e.hint, json_mode=json_output)
except typer.BadParameter:
raise
except Exception as e:
output_error(str(e), hint="Simplify the prompt or provide a JSON schema.", json_mode=json_output)

View File

@@ -0,0 +1,87 @@
"""Shared browser operations for MCP tools and CLI commands.
Each function: validate inputs -> call SDK -> return typed result.
Session resolution and output formatting are caller responsibilities.
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any
from .guards import GuardError
@dataclass
class NavigateResult:
url: str
title: str
@dataclass
class ScreenshotResult:
data: bytes
full_page: bool = False
@dataclass
class ActResult:
prompt: str
completed: bool = True
@dataclass
class ExtractResult:
extracted: Any = None
def parse_extract_schema(schema: str | dict[str, Any] | None) -> dict[str, Any] | None:
"""Parse and validate an extraction schema payload."""
if schema is None:
return None
if isinstance(schema, dict):
return schema
try:
return json.loads(schema)
except (json.JSONDecodeError, TypeError) as e:
raise GuardError(f"Invalid JSON schema: {e}", "Provide schema as a valid JSON string")
async def do_navigate(
page: Any,
url: str,
timeout: int = 30000,
wait_until: str | None = None,
) -> NavigateResult:
await page.goto(url, timeout=timeout, wait_until=wait_until)
return NavigateResult(url=page.url, title=await page.title())
async def do_screenshot(
page: Any,
full_page: bool = False,
selector: str | None = None,
) -> ScreenshotResult:
if selector:
element = page.locator(selector)
data = await element.screenshot()
else:
data = await page.screenshot(full_page=full_page)
return ScreenshotResult(data=data, full_page=full_page)
async def do_act(page: Any, prompt: str) -> ActResult:
await page.act(prompt)
return ActResult(prompt=prompt, completed=True)
async def do_extract(
page: Any,
prompt: str,
schema: str | dict[str, Any] | None = None,
) -> ExtractResult:
parsed_schema = parse_extract_schema(schema)
extracted = await page.extract(prompt=prompt, schema=parsed_schema)
return ExtractResult(extracted=extracted)

View File

@@ -0,0 +1,81 @@
"""Shared input validation guards for both MCP and CLI surfaces."""
from __future__ import annotations
import re
PASSWORD_PATTERN = re.compile(
r"\bpass(?:word|phrase|code)s?\b|\bsecret\b|\bcredential\b|\bpin\s*(?:code)?\b|\bpwd\b|\bpasswd\b",
re.IGNORECASE,
)
JS_PASSWORD_PATTERN = re.compile(
r"""(?:type\s*=\s*['"]?password|\.type\s*===?\s*['"]password|input\[type=password\]).*?\.value\s*=""",
re.IGNORECASE,
)
CREDENTIAL_HINT = (
"Use skyvern_login with a stored credential to authenticate. "
"Create credentials via CLI: skyvern credentials add. "
"Never pass passwords through tool calls."
)
VALID_WAIT_UNTIL = ("load", "domcontentloaded", "networkidle", "commit")
VALID_BUTTONS = ("left", "right", "middle")
VALID_ELEMENT_STATES = ("visible", "hidden", "attached", "detached")
class GuardError(Exception):
"""Raised when an input guard blocks an operation."""
def __init__(self, message: str, hint: str = "") -> None:
super().__init__(message)
self.hint = hint
def check_password_prompt(text: str) -> None:
"""Block prompts containing password/credential terms."""
if PASSWORD_PATTERN.search(text):
raise GuardError(
"Cannot perform password/credential actions — credentials must not be passed through tool calls",
CREDENTIAL_HINT,
)
def check_js_password(expression: str) -> None:
"""Block JS expressions that set password field values."""
if JS_PASSWORD_PATTERN.search(expression):
raise GuardError(
"Cannot set password field values via JavaScript — credentials must not be passed through tool calls",
CREDENTIAL_HINT,
)
def validate_wait_until(value: str | None) -> None:
if value is not None and value not in VALID_WAIT_UNTIL:
raise GuardError(
f"Invalid wait_until: {value}",
"Use load, domcontentloaded, networkidle, or commit",
)
def validate_button(value: str | None) -> None:
if value is not None and value not in VALID_BUTTONS:
raise GuardError(f"Invalid button: {value}", "Use left, right, or middle")
def resolve_ai_mode(
selector: str | None,
intent: str | None,
) -> tuple[str | None, str | None]:
"""Determine AI mode from selector/intent combination.
Returns (ai_mode, error_code) -- if error_code is set, the call should fail.
"""
if intent and not selector:
return "proactive", None
if intent and selector:
return "fallback", None
if selector and not intent:
return None, None
return None, "INVALID_INPUT"

View File

@@ -0,0 +1,74 @@
"""Shared session operations for MCP tools and CLI commands."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from skyvern.schemas.runs import ProxyLocation
@dataclass
class SessionCreateResult:
session_id: str | None
local: bool = False
headless: bool = False
timeout_minutes: int | None = None
@dataclass
class SessionCloseResult:
session_id: str | None
closed: bool = True
@dataclass
class SessionInfo:
session_id: str
status: str | None
started_at: str | None
timeout: int | None
runnable_id: str | None = None
available: bool = False
async def do_session_create(
skyvern: Any,
timeout: int = 60,
proxy_location: str | None = None,
local: bool = False,
headless: bool = False,
) -> tuple[Any, SessionCreateResult]:
"""Create browser session. Returns (browser, result)."""
if local:
browser = await skyvern.launch_local_browser(headless=headless)
return browser, SessionCreateResult(session_id=None, local=True, headless=headless)
proxy = ProxyLocation(proxy_location) if proxy_location else None
browser = await skyvern.launch_cloud_browser(timeout=timeout, proxy_location=proxy)
return browser, SessionCreateResult(
session_id=browser.browser_session_id,
timeout_minutes=timeout,
)
async def do_session_close(skyvern: Any, session_id: str) -> SessionCloseResult:
"""Close a browser session by ID."""
await skyvern.close_browser_session(session_id)
return SessionCloseResult(session_id=session_id)
async def do_session_list(skyvern: Any) -> list[SessionInfo]:
"""List all browser sessions."""
sessions = await skyvern.get_browser_sessions()
return [
SessionInfo(
session_id=s.browser_session_id,
status=s.status,
started_at=s.started_at.isoformat() if s.started_at else None,
timeout=s.timeout,
runnable_id=s.runnable_id,
available=s.runnable_id is None and s.browser_address is not None,
)
for s in sessions
]

View File

@@ -4,13 +4,24 @@ import asyncio
import base64 import base64
import json import json
import logging import logging
import re
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Annotated, Any from typing import Annotated, Any
from playwright.async_api import TimeoutError as PlaywrightTimeoutError from playwright.async_api import TimeoutError as PlaywrightTimeoutError
from pydantic import Field from pydantic import Field
from skyvern.cli.core.browser_ops import do_act, do_extract, do_navigate, do_screenshot, parse_extract_schema
from skyvern.cli.core.guards import (
CREDENTIAL_HINT,
JS_PASSWORD_PATTERN,
PASSWORD_PATTERN,
GuardError,
check_password_prompt,
)
from skyvern.cli.core.guards import resolve_ai_mode as _resolve_ai_mode
from skyvern.cli.core.guards import (
validate_wait_until,
)
from skyvern.schemas.run_blocks import CredentialType from skyvern.schemas.run_blocks import CredentialType
from ._common import ( from ._common import (
@@ -24,39 +35,6 @@ from ._session import BrowserNotAvailableError, get_page, no_browser_error
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
_PASSWORD_PATTERN = re.compile(
r"\bpass(?:word|phrase|code)s?\b|\bsecret\b|\bcredential\b|\bpin\s*(?:code)?\b|\bpwd\b|\bpasswd\b",
re.IGNORECASE,
)
_CREDENTIAL_ERROR_HINT = (
"Use skyvern_login with a stored credential to authenticate. "
"Create credentials via CLI: skyvern credentials add. "
"Never pass passwords through tool calls."
)
_JS_PASSWORD_PATTERN = re.compile(
r"""(?:type\s*=\s*['"]?password|\.type\s*===?\s*['"]password|input\[type=password\]).*?\.value\s*=""",
re.IGNORECASE,
)
def _resolve_ai_mode(
selector: str | None,
intent: str | None,
) -> tuple[str | None, str | None]:
"""Determine AI mode from selector/intent combination.
Returns (ai_mode, error_code) — if error_code is set, the call should fail.
"""
if intent and not selector:
return "proactive", None
if intent and selector:
return "fallback", None
if selector and not intent:
return None, None
return None, "INVALID_INPUT"
async def skyvern_navigate( async def skyvern_navigate(
url: Annotated[str, "The URL to navigate to"], url: Annotated[str, "The URL to navigate to"],
@@ -80,15 +58,13 @@ async def skyvern_navigate(
Returns the final URL (after redirects) and page title. Returns the final URL (after redirects) and page title.
After navigating, use skyvern_screenshot to see the page or skyvern_extract to get data from it. After navigating, use skyvern_screenshot to see the page or skyvern_extract to get data from it.
""" """
if wait_until is not None and wait_until not in ("load", "domcontentloaded", "networkidle", "commit"): try:
validate_wait_until(wait_until)
except GuardError as e:
return make_result( return make_result(
"skyvern_navigate", "skyvern_navigate",
ok=False, ok=False,
error=make_error( error=make_error(ErrorCode.INVALID_INPUT, str(e), e.hint),
ErrorCode.INVALID_INPUT,
f"Invalid wait_until: {wait_until}",
"Use load, domcontentloaded, networkidle, or commit",
),
) )
try: try:
@@ -98,10 +74,16 @@ async def skyvern_navigate(
with Timer() as timer: with Timer() as timer:
try: try:
await page.goto(url, timeout=timeout, wait_until=wait_until) result = await do_navigate(page, url, timeout=timeout, wait_until=wait_until)
timer.mark("sdk") timer.mark("sdk")
final_url = page.url except GuardError as e:
title = await page.title() return make_result(
"skyvern_navigate",
ok=False,
browser_context=ctx,
timing_ms=timer.timing_ms,
error=make_error(ErrorCode.INVALID_INPUT, str(e), e.hint),
)
except Exception as e: except Exception as e:
return make_result( return make_result(
"skyvern_navigate", "skyvern_navigate",
@@ -114,7 +96,7 @@ async def skyvern_navigate(
return make_result( return make_result(
"skyvern_navigate", "skyvern_navigate",
browser_context=ctx, browser_context=ctx,
data={"url": final_url, "title": title, "sdk_equivalent": f'await page.goto("{url}")'}, data={"url": result.url, "title": result.title, "sdk_equivalent": f'await page.goto("{url}")'},
timing_ms=timer.timing_ms, timing_ms=timer.timing_ms,
) )
@@ -355,14 +337,14 @@ async def skyvern_type(
""" """
# Block password entry — redirect to skyvern_login # Block password entry — redirect to skyvern_login
target_text = f"{intent or ''} {selector or ''}" target_text = f"{intent or ''} {selector or ''}"
if _PASSWORD_PATTERN.search(target_text): if PASSWORD_PATTERN.search(target_text):
return make_result( return make_result(
"skyvern_type", "skyvern_type",
ok=False, ok=False,
error=make_error( error=make_error(
ErrorCode.INVALID_INPUT, ErrorCode.INVALID_INPUT,
"Cannot type into password fields — credentials must not be passed through tool calls", "Cannot type into password fields — credentials must not be passed through tool calls",
_CREDENTIAL_ERROR_HINT, CREDENTIAL_HINT,
), ),
) )
@@ -402,7 +384,7 @@ async def skyvern_type(
error=make_error( error=make_error(
ErrorCode.INVALID_INPUT, ErrorCode.INVALID_INPUT,
"Cannot type into password fields — credentials must not be passed through tool calls", "Cannot type into password fields — credentials must not be passed through tool calls",
_CREDENTIAL_ERROR_HINT, CREDENTIAL_HINT,
), ),
) )
@@ -491,11 +473,7 @@ async def skyvern_screenshot(
with Timer() as timer: with Timer() as timer:
try: try:
if selector: result = await do_screenshot(page, full_page=full_page, selector=selector)
element = page.locator(selector)
screenshot_bytes = await element.screenshot()
else:
screenshot_bytes = await page.screenshot(full_page=full_page)
timer.mark("sdk") timer.mark("sdk")
except Exception as e: except Exception as e:
return make_result( return make_result(
@@ -507,7 +485,7 @@ async def skyvern_screenshot(
) )
if inline: if inline:
data_b64 = base64.b64encode(screenshot_bytes).decode("utf-8") data_b64 = base64.b64encode(result.data).decode("utf-8")
return make_result( return make_result(
"skyvern_screenshot", "skyvern_screenshot",
browser_context=ctx, browser_context=ctx,
@@ -515,7 +493,7 @@ async def skyvern_screenshot(
"inline": True, "inline": True,
"data": data_b64, "data": data_b64,
"mime": "image/png", "mime": "image/png",
"bytes": len(screenshot_bytes), "bytes": len(result.data),
"sdk_equivalent": "await page.screenshot()", "sdk_equivalent": "await page.screenshot()",
}, },
timing_ms=timer.timing_ms, timing_ms=timer.timing_ms,
@@ -525,7 +503,7 @@ async def skyvern_screenshot(
ts = datetime.now(timezone.utc).strftime("%H%M%S_%f") ts = datetime.now(timezone.utc).strftime("%H%M%S_%f")
filename = f"screenshot_{ts}.png" filename = f"screenshot_{ts}.png"
artifact = save_artifact( artifact = save_artifact(
screenshot_bytes, result.data,
kind="screenshot", kind="screenshot",
filename=filename, filename=filename,
mime="image/png", mime="image/png",
@@ -896,14 +874,14 @@ async def skyvern_evaluate(
Security: This executes arbitrary JS in the page context. Only use with trusted expressions. Security: This executes arbitrary JS in the page context. Only use with trusted expressions.
""" """
# Block JS that sets password field values # Block JS that sets password field values
if _JS_PASSWORD_PATTERN.search(expression): if JS_PASSWORD_PATTERN.search(expression):
return make_result( return make_result(
"skyvern_evaluate", "skyvern_evaluate",
ok=False, ok=False,
error=make_error( error=make_error(
ErrorCode.INVALID_INPUT, ErrorCode.INVALID_INPUT,
"Cannot set password field values via JavaScript — credentials must not be passed through tool calls", "Cannot set password field values via JavaScript — credentials must not be passed through tool calls",
_CREDENTIAL_ERROR_HINT, CREDENTIAL_HINT,
), ),
) )
@@ -947,20 +925,17 @@ async def skyvern_extract(
For visual inspection instead of structured data, use skyvern_screenshot. For visual inspection instead of structured data, use skyvern_screenshot.
Optionally provide a JSON `schema` to enforce the output structure (pass as a JSON string). Optionally provide a JSON `schema` to enforce the output structure (pass as a JSON string).
""" """
parsed_schema: dict[str, Any] | None = None
if schema is not None: if schema is not None:
try: try:
parsed_schema = json.loads(schema) parsed_schema = parse_extract_schema(schema)
except (json.JSONDecodeError, TypeError) as e: except GuardError as e:
return make_result( return make_result(
"skyvern_extract", "skyvern_extract",
ok=False, ok=False,
error=make_error( error=make_error(ErrorCode.INVALID_INPUT, str(e), e.hint),
ErrorCode.INVALID_INPUT,
f"Invalid JSON schema: {e}",
"Provide schema as a valid JSON string",
),
) )
else:
parsed_schema = None
try: try:
page, ctx = await get_page(session_id=session_id, cdp_url=cdp_url) page, ctx = await get_page(session_id=session_id, cdp_url=cdp_url)
@@ -969,8 +944,16 @@ async def skyvern_extract(
with Timer() as timer: with Timer() as timer:
try: try:
extracted = await page.extract(prompt=prompt, schema=parsed_schema) result = await do_extract(page, prompt, schema=parsed_schema)
timer.mark("sdk") timer.mark("sdk")
except GuardError as e:
return make_result(
"skyvern_extract",
ok=False,
browser_context=ctx,
timing_ms=timer.timing_ms,
error=make_error(ErrorCode.INVALID_INPUT, str(e), e.hint),
)
except Exception as e: except Exception as e:
return make_result( return make_result(
"skyvern_extract", "skyvern_extract",
@@ -983,7 +966,10 @@ async def skyvern_extract(
return make_result( return make_result(
"skyvern_extract", "skyvern_extract",
browser_context=ctx, browser_context=ctx,
data={"extracted": extracted, "sdk_equivalent": f'await page.extract(prompt="{prompt}")'}, data={
"extracted": result.extracted,
"sdk_equivalent": f'await page.extract(prompt="{prompt}")',
},
timing_ms=timer.timing_ms, timing_ms=timer.timing_ms,
) )
@@ -1037,16 +1023,13 @@ async def skyvern_act(
For multi-step automations (4+ pages), use skyvern_workflow_create with one block per step. For multi-step automations (4+ pages), use skyvern_workflow_create with one block per step.
For quick one-off multi-page tasks, use skyvern_run_task. For quick one-off multi-page tasks, use skyvern_run_task.
""" """
# Block login/password actions — redirect to skyvern_login try:
if _PASSWORD_PATTERN.search(prompt): check_password_prompt(prompt)
except GuardError as e:
return make_result( return make_result(
"skyvern_act", "skyvern_act",
ok=False, ok=False,
error=make_error( error=make_error(ErrorCode.INVALID_INPUT, str(e), e.hint),
ErrorCode.INVALID_INPUT,
"Cannot perform password/credential actions — credentials must not be passed through tool calls",
_CREDENTIAL_ERROR_HINT,
),
) )
try: try:
@@ -1056,8 +1039,16 @@ async def skyvern_act(
with Timer() as timer: with Timer() as timer:
try: try:
await page.act(prompt) result = await do_act(page, prompt)
timer.mark("sdk") timer.mark("sdk")
except GuardError as e:
return make_result(
"skyvern_act",
ok=False,
browser_context=ctx,
timing_ms=timer.timing_ms,
error=make_error(ErrorCode.INVALID_INPUT, str(e), e.hint),
)
except Exception as e: except Exception as e:
return make_result( return make_result(
"skyvern_act", "skyvern_act",
@@ -1070,7 +1061,11 @@ async def skyvern_act(
return make_result( return make_result(
"skyvern_act", "skyvern_act",
browser_context=ctx, browser_context=ctx,
data={"prompt": prompt, "completed": True, "sdk_equivalent": f'await page.act("{prompt}")'}, data={
"prompt": result.prompt,
"completed": result.completed,
"sdk_equivalent": f'await page.act("{prompt}")',
},
timing_ms=timer.timing_ms, timing_ms=timer.timing_ms,
) )
@@ -1099,14 +1094,14 @@ async def skyvern_run_task(
For simple single-step actions on the current page, use skyvern_act instead. For simple single-step actions on the current page, use skyvern_act instead.
""" """
# Block password/credential actions — redirect to skyvern_login # Block password/credential actions — redirect to skyvern_login
if _PASSWORD_PATTERN.search(prompt): if PASSWORD_PATTERN.search(prompt):
return make_result( return make_result(
"skyvern_run_task", "skyvern_run_task",
ok=False, ok=False,
error=make_error( error=make_error(
ErrorCode.INVALID_INPUT, ErrorCode.INVALID_INPUT,
"Cannot perform password/credential actions — credentials must not be passed through tool calls", "Cannot perform password/credential actions — credentials must not be passed through tool calls",
_CREDENTIAL_ERROR_HINT, CREDENTIAL_HINT,
), ),
) )

View File

@@ -4,7 +4,7 @@ from typing import Annotated, Any
from pydantic import Field from pydantic import Field
from skyvern.schemas.runs import ProxyLocation from skyvern.cli.core.session_ops import do_session_close, do_session_create, do_session_list
from ._common import BrowserContext, ErrorCode, Timer, make_error, make_result from ._common import BrowserContext, ErrorCode, Timer, make_error, make_result
from ._session import ( from ._session import (
@@ -30,25 +30,21 @@ async def skyvern_session_create(
with Timer() as timer: with Timer() as timer:
try: try:
skyvern = get_skyvern() skyvern = get_skyvern()
browser, result = await do_session_create(
if local: skyvern,
browser = await skyvern.launch_local_browser(headless=headless) timeout=timeout or 60,
ctx = BrowserContext(mode="local") proxy_location=proxy_location,
set_current_session(SessionState(browser=browser, context=ctx)) local=local,
timer.mark("sdk") headless=headless,
return make_result( )
"skyvern_session_create",
browser_context=ctx,
data={"local": True, "headless": headless},
timing_ms=timer.timing_ms,
)
proxy = ProxyLocation(proxy_location) if proxy_location else None
browser = await skyvern.launch_cloud_browser(timeout=timeout, proxy_location=proxy)
ctx = BrowserContext(mode="cloud_session", session_id=browser.browser_session_id)
set_current_session(SessionState(browser=browser, context=ctx))
timer.mark("sdk") timer.mark("sdk")
if result.local:
ctx = BrowserContext(mode="local")
else:
ctx = BrowserContext(mode="cloud_session", session_id=result.session_id)
set_current_session(SessionState(browser=browser, context=ctx))
except ValueError as e: except ValueError as e:
return make_result( return make_result(
"skyvern_session_create", "skyvern_session_create",
@@ -68,12 +64,20 @@ async def skyvern_session_create(
error=make_error(ErrorCode.SDK_ERROR, str(e), "Failed to create browser session"), error=make_error(ErrorCode.SDK_ERROR, str(e), "Failed to create browser session"),
) )
if result.local:
return make_result(
"skyvern_session_create",
browser_context=ctx,
data={"local": True, "headless": result.headless},
timing_ms=timer.timing_ms,
)
return make_result( return make_result(
"skyvern_session_create", "skyvern_session_create",
browser_context=ctx, browser_context=ctx,
data={ data={
"session_id": browser.browser_session_id, "session_id": result.session_id,
"timeout_minutes": timeout, "timeout_minutes": result.timeout_minutes,
}, },
timing_ms=timer.timing_ms, timing_ms=timer.timing_ms,
) )
@@ -92,13 +96,13 @@ async def skyvern_session_close(
try: try:
if session_id: if session_id:
skyvern = get_skyvern() skyvern = get_skyvern()
await skyvern.close_browser_session(session_id) result = await do_session_close(skyvern, session_id)
if current.context and current.context.session_id == session_id: if current.context and current.context.session_id == session_id:
set_current_session(SessionState()) set_current_session(SessionState())
timer.mark("sdk") timer.mark("sdk")
return make_result( return make_result(
"skyvern_session_close", "skyvern_session_close",
data={"session_id": session_id, "closed": True}, data={"session_id": result.session_id, "closed": result.closed},
timing_ms=timer.timing_ms, timing_ms=timer.timing_ms,
) )
@@ -138,17 +142,17 @@ async def skyvern_session_list() -> dict[str, Any]:
with Timer() as timer: with Timer() as timer:
try: try:
skyvern = get_skyvern() skyvern = get_skyvern()
sessions = await skyvern.get_browser_sessions() sessions = await do_session_list(skyvern)
timer.mark("sdk") timer.mark("sdk")
session_data = [ session_data = [
{ {
"session_id": s.browser_session_id, "session_id": s.session_id,
"status": s.status, "status": s.status,
"started_at": s.started_at.isoformat() if s.started_at else None, "started_at": s.started_at,
"timeout": s.timeout, "timeout": s.timeout,
"runnable_id": s.runnable_id, "runnable_id": s.runnable_id,
"available": s.runnable_id is None and s.browser_address is not None, "available": s.available,
} }
for s in sessions for s in sessions
] ]

View File

@@ -0,0 +1,149 @@
"""Tests for CLI commands infrastructure: _state.py and _output.py."""
from __future__ import annotations
import json
from pathlib import Path
import pytest
import typer
from skyvern.cli.commands._state import CLIState, clear_state, load_state, save_state
# ---------------------------------------------------------------------------
# _state.py
# ---------------------------------------------------------------------------
def _patch_state_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
monkeypatch.setattr("skyvern.cli.commands._state.STATE_DIR", tmp_path)
monkeypatch.setattr("skyvern.cli.commands._state.STATE_FILE", tmp_path / "state.json")
class TestCLIState:
def test_save_load_roundtrip(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
_patch_state_dir(monkeypatch, tmp_path)
save_state(CLIState(session_id="pbs_123", cdp_url=None, mode="cloud"))
loaded = load_state()
assert loaded is not None
assert loaded.session_id == "pbs_123"
assert loaded.cdp_url is None
assert loaded.mode == "cloud"
def test_save_load_roundtrip_cdp(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
_patch_state_dir(monkeypatch, tmp_path)
save_state(CLIState(session_id=None, cdp_url="ws://localhost:9222/devtools/browser/abc", mode="cdp"))
loaded = load_state()
assert loaded is not None
assert loaded.session_id is None
assert loaded.cdp_url == "ws://localhost:9222/devtools/browser/abc"
assert loaded.mode == "cdp"
def test_load_returns_none_when_missing(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("skyvern.cli.commands._state.STATE_FILE", tmp_path / "nonexistent.json")
assert load_state() is None
def test_24h_ttl_expires(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
_patch_state_dir(monkeypatch, tmp_path)
state_file = tmp_path / "state.json"
state_file.write_text(
json.dumps(
{
"session_id": "pbs_old",
"mode": "cloud",
"created_at": "2020-01-01T00:00:00+00:00",
}
)
)
assert load_state() is None
def test_clear_state(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
_patch_state_dir(monkeypatch, tmp_path)
save_state(CLIState(session_id="pbs_123"))
clear_state()
assert not (tmp_path / "state.json").exists()
def test_load_ignores_corrupt_file(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
state_file = tmp_path / "state.json"
monkeypatch.setattr("skyvern.cli.commands._state.STATE_FILE", state_file)
state_file.write_text("not-json")
assert load_state() is None
# ---------------------------------------------------------------------------
# _output.py
# ---------------------------------------------------------------------------
class TestOutput:
def test_json_envelope(self, capsys: pytest.CaptureFixture) -> None:
from skyvern.cli.commands._output import output
output({"key": "value"}, action="test", json_mode=True)
parsed = json.loads(capsys.readouterr().out)
assert parsed["ok"] is True
assert parsed["action"] == "test"
assert parsed["data"]["key"] == "value"
def test_json_error(self, capsys: pytest.CaptureFixture) -> None:
from skyvern.cli.commands._output import output_error
with pytest.raises(SystemExit, match="1"):
output_error("bad thing", hint="fix it", json_mode=True)
parsed = json.loads(capsys.readouterr().out)
assert parsed["ok"] is False
assert parsed["error"]["message"] == "bad thing"
# ---------------------------------------------------------------------------
# Connection resolution
# ---------------------------------------------------------------------------
class TestResolveConnection:
def test_explicit_session_wins(self) -> None:
from skyvern.cli.commands.browser import _resolve_connection
result = _resolve_connection("pbs_explicit", None)
assert result.mode == "cloud"
assert result.session_id == "pbs_explicit"
assert result.cdp_url is None
def test_explicit_cdp_wins(self) -> None:
from skyvern.cli.commands.browser import _resolve_connection
result = _resolve_connection(None, "ws://localhost:9222/devtools/browser/abc")
assert result.mode == "cdp"
assert result.session_id is None
assert result.cdp_url == "ws://localhost:9222/devtools/browser/abc"
def test_rejects_both_connection_flags(self) -> None:
from skyvern.cli.commands.browser import _resolve_connection
with pytest.raises(typer.BadParameter, match="Pass only one of --session or --cdp"):
_resolve_connection("pbs_explicit", "ws://localhost:9222/devtools/browser/abc")
def test_state_fallback(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
from skyvern.cli.commands.browser import _resolve_connection
_patch_state_dir(monkeypatch, tmp_path)
save_state(CLIState(session_id="pbs_from_state", mode="cloud"))
result = _resolve_connection(None, None)
assert result.mode == "cloud"
assert result.session_id == "pbs_from_state"
def test_state_fallback_cdp(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
from skyvern.cli.commands.browser import _resolve_connection
_patch_state_dir(monkeypatch, tmp_path)
save_state(CLIState(session_id=None, cdp_url="ws://localhost:9222/devtools/browser/abc", mode="cdp"))
result = _resolve_connection(None, None)
assert result.mode == "cdp"
assert result.cdp_url == "ws://localhost:9222/devtools/browser/abc"
def test_no_session_raises(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
from skyvern.cli.commands.browser import _resolve_connection
monkeypatch.setattr("skyvern.cli.commands._state.STATE_FILE", tmp_path / "nonexistent.json")
with pytest.raises(typer.BadParameter, match="No active browser connection"):
_resolve_connection(None, None)

View File

@@ -0,0 +1,37 @@
from __future__ import annotations
import logging
import skyvern.cli.commands as cli_commands
def test_configure_cli_logging_is_idempotent(monkeypatch) -> None:
setup_calls: list[str] = []
monkeypatch.setattr(cli_commands, "_setup_logger", lambda: setup_calls.append("called"))
monkeypatch.setattr(cli_commands, "_cli_logging_configured", False)
logger_names = ("skyvern", "httpx", "litellm", "playwright", "httpcore")
previous_levels = {name: logging.getLogger(name).level for name in logger_names}
try:
cli_commands.configure_cli_logging()
assert setup_calls == ["called"]
for name in logger_names:
assert logging.getLogger(name).level == logging.WARNING
cli_commands.configure_cli_logging()
assert setup_calls == ["called"]
finally:
for name, level in previous_levels.items():
logging.getLogger(name).setLevel(level)
def test_cli_callback_configures_logging(monkeypatch) -> None:
called = False
def _fake_configure() -> None:
nonlocal called
called = True
monkeypatch.setattr(cli_commands, "configure_cli_logging", _fake_configure)
cli_commands.cli_callback()
assert called

View File

@@ -0,0 +1,234 @@
"""Tests for skyvern.cli.core shared modules (guards, browser_ops, session_ops)."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from skyvern.cli.core.browser_ops import do_act, do_extract, do_navigate, do_screenshot, parse_extract_schema
from skyvern.cli.core.guards import (
GuardError,
check_js_password,
check_password_prompt,
resolve_ai_mode,
validate_button,
validate_wait_until,
)
from skyvern.cli.core.session_ops import do_session_close, do_session_create, do_session_list
# ---------------------------------------------------------------------------
# guards.py
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"text",
[
"enter your password",
"use credential to login",
"type the secret",
"enter passphrase",
"enter passcode",
"enter your pin code",
"type pwd here",
"enter passwd",
],
)
def test_password_guard_blocks_sensitive_text(text: str) -> None:
with pytest.raises(GuardError) as exc_info:
check_password_prompt(text)
assert exc_info.value.hint # hint should always be populated
@pytest.mark.parametrize("text", ["click the submit button", "fill in the email field", ""])
def test_password_guard_allows_normal_text(text: str) -> None:
check_password_prompt(text) # should not raise
def test_js_password_guard() -> None:
with pytest.raises(GuardError):
check_js_password('input[type=password].value = "secret"')
with pytest.raises(GuardError):
check_js_password('.type === "password"; el.value = "x"')
check_js_password("document.title") # allowed
@pytest.mark.parametrize("value", ["load", "domcontentloaded", "networkidle", "commit", None])
def test_wait_until_accepts_valid(value: str | None) -> None:
validate_wait_until(value)
def test_wait_until_rejects_invalid() -> None:
with pytest.raises(GuardError, match="Invalid wait_until"):
validate_wait_until("badvalue")
@pytest.mark.parametrize("value", ["left", "right", "middle", None])
def test_button_accepts_valid(value: str | None) -> None:
validate_button(value)
def test_button_rejects_invalid() -> None:
with pytest.raises(GuardError, match="Invalid button"):
validate_button("double")
@pytest.mark.parametrize(
"selector,intent,expected",
[
(None, "click it", ("proactive", None)),
("#btn", "click it", ("fallback", None)),
("#btn", None, (None, None)),
(None, None, (None, "INVALID_INPUT")),
],
)
def test_resolve_ai_mode(selector: str | None, intent: str | None, expected: tuple) -> None:
assert resolve_ai_mode(selector, intent) == expected
# ---------------------------------------------------------------------------
# browser_ops.py
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_do_navigate_success() -> None:
page = MagicMock()
page.goto = AsyncMock()
page.url = "https://example.com/final"
page.title = AsyncMock(return_value="Example")
result = await do_navigate(page, "https://example.com")
assert result.url == "https://example.com/final"
assert result.title == "Example"
@pytest.mark.asyncio
async def test_do_navigate_passes_wait_until_through() -> None:
page = MagicMock()
page.goto = AsyncMock()
page.url = "https://example.com/final"
page.title = AsyncMock(return_value="Example")
result = await do_navigate(page, "https://example.com", wait_until="badvalue")
assert result.url == "https://example.com/final"
page.goto.assert_awaited_once_with("https://example.com", timeout=30000, wait_until="badvalue")
@pytest.mark.asyncio
async def test_do_screenshot_full_page() -> None:
page = MagicMock()
page.screenshot = AsyncMock(return_value=b"png-data")
result = await do_screenshot(page, full_page=True)
assert result.data == b"png-data"
assert result.full_page is True
@pytest.mark.asyncio
async def test_do_screenshot_with_selector() -> None:
page = MagicMock()
element = MagicMock()
element.screenshot = AsyncMock(return_value=b"element-data")
page.locator.return_value = element
result = await do_screenshot(page, selector="#header")
assert result.data == b"element-data"
@pytest.mark.asyncio
async def test_do_act_success() -> None:
page = MagicMock()
page.act = AsyncMock()
result = await do_act(page, "enter the password")
assert result.prompt == "enter the password"
assert result.completed is True
@pytest.mark.asyncio
async def test_do_extract_rejects_bad_schema() -> None:
with pytest.raises(GuardError, match="Invalid JSON schema"):
await do_extract(MagicMock(), "get data", schema="not-json")
@pytest.mark.asyncio
async def test_do_extract_success() -> None:
page = MagicMock()
page.extract = AsyncMock(return_value={"title": "Example"})
result = await do_extract(page, "get the title")
assert result.extracted == {"title": "Example"}
def test_parse_extract_schema_accepts_preparsed_dict() -> None:
schema = {"type": "object", "properties": {"title": {"type": "string"}}}
parsed = parse_extract_schema(schema)
assert parsed is schema
@pytest.mark.asyncio
async def test_do_extract_accepts_preparsed_dict() -> None:
page = MagicMock()
page.extract = AsyncMock(return_value={"title": "Example"})
schema = {"type": "object", "properties": {"title": {"type": "string"}}}
result = await do_extract(page, "get the title", schema=schema)
assert result.extracted == {"title": "Example"}
page.extract.assert_awaited_once_with(prompt="get the title", schema=schema)
# ---------------------------------------------------------------------------
# session_ops.py
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_do_session_create_local() -> None:
skyvern = MagicMock()
skyvern.launch_local_browser = AsyncMock(return_value=MagicMock())
browser, result = await do_session_create(skyvern, local=True, headless=True)
assert result.local is True
assert result.session_id is None
@pytest.mark.asyncio
async def test_do_session_create_cloud() -> None:
skyvern = MagicMock()
browser_mock = MagicMock()
browser_mock.browser_session_id = "pbs_123"
skyvern.launch_cloud_browser = AsyncMock(return_value=browser_mock)
browser, result = await do_session_create(skyvern, timeout=30)
assert result.session_id == "pbs_123"
assert result.timeout_minutes == 30
@pytest.mark.asyncio
async def test_do_session_close() -> None:
skyvern = MagicMock()
skyvern.close_browser_session = AsyncMock()
result = await do_session_close(skyvern, "pbs_123")
assert result.session_id == "pbs_123"
assert result.closed is True
@pytest.mark.asyncio
async def test_do_session_list() -> None:
session = MagicMock()
session.browser_session_id = "pbs_1"
session.status = "active"
session.started_at = None
session.timeout = 60
session.runnable_id = None
session.browser_address = "ws://localhost:1234"
skyvern = MagicMock()
skyvern.get_browser_sessions = AsyncMock(return_value=[session])
result = await do_session_list(skyvern)
assert len(result) == 1
assert result[0].session_id == "pbs_1"
assert result[0].available is True

View File

@@ -0,0 +1,65 @@
"""Tests for MCP browser tool preflight validation behavior."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from skyvern.cli.core.result import BrowserContext
from skyvern.cli.mcp_tools import browser as mcp_browser
@pytest.mark.asyncio
async def test_skyvern_extract_invalid_schema_preflight_before_session(monkeypatch: pytest.MonkeyPatch) -> None:
get_page = AsyncMock(side_effect=AssertionError("get_page should not be called for invalid schema"))
monkeypatch.setattr(mcp_browser, "get_page", get_page)
result = await mcp_browser.skyvern_extract(prompt="extract data", schema="{invalid")
assert result["ok"] is False
assert result["error"]["code"] == mcp_browser.ErrorCode.INVALID_INPUT
assert "Invalid JSON schema" in result["error"]["message"]
get_page.assert_not_awaited()
@pytest.mark.asyncio
async def test_skyvern_extract_preparsed_schema_passed_to_core(monkeypatch: pytest.MonkeyPatch) -> None:
page = object()
context = BrowserContext(mode="cloud_session", session_id="pbs_test")
monkeypatch.setattr(mcp_browser, "get_page", AsyncMock(return_value=(page, context)))
do_extract = AsyncMock(return_value=SimpleNamespace(extracted={"ok": True}))
monkeypatch.setattr(mcp_browser, "do_extract", do_extract)
result = await mcp_browser.skyvern_extract(prompt="extract data", schema='{"type":"object"}')
assert result["ok"] is True
await_args = do_extract.await_args
assert await_args is not None
assert isinstance(await_args.kwargs["schema"], dict)
@pytest.mark.asyncio
async def test_skyvern_navigate_invalid_wait_until_preflight_before_session(monkeypatch: pytest.MonkeyPatch) -> None:
get_page = AsyncMock(side_effect=AssertionError("get_page should not be called for invalid wait_until"))
monkeypatch.setattr(mcp_browser, "get_page", get_page)
result = await mcp_browser.skyvern_navigate(url="https://example.com", wait_until="not-a-real-wait-until")
assert result["ok"] is False
assert result["error"]["code"] == mcp_browser.ErrorCode.INVALID_INPUT
get_page.assert_not_awaited()
@pytest.mark.asyncio
async def test_skyvern_act_password_prompt_preflight_before_session(monkeypatch: pytest.MonkeyPatch) -> None:
get_page = AsyncMock(side_effect=AssertionError("get_page should not be called for password prompt"))
monkeypatch.setattr(mcp_browser, "get_page", get_page)
result = await mcp_browser.skyvern_act(prompt="enter the password and submit")
assert result["ok"] is False
assert result["error"]["code"] == mcp_browser.ErrorCode.INVALID_INPUT
get_page.assert_not_awaited()