script gen post action (#3480)

This commit is contained in:
Shuchang Zheng
2025-09-19 08:50:21 -07:00
committed by GitHub
parent b4669f7477
commit c5280782b0
17 changed files with 536 additions and 264 deletions

View File

@@ -273,6 +273,7 @@ function RunWorkflowForm({
cacheKey,
cacheKeyValue,
workflowPermanentId,
status: "published",
});
const [runWithCodeIsEnabled, setRunWithCodeIsEnabled] = useState(false);

View File

@@ -214,6 +214,7 @@ function Workspace({
cacheKey,
cacheKeyValue,
workflowPermanentId,
status: "published",
});
const { data: cacheKeyValues, isLoading: cacheKeyValuesLoading } =

View File

@@ -8,6 +8,8 @@ type Props = {
cacheKeyValue?: string;
workflowPermanentId?: string;
pollIntervalMs?: number;
status?: string;
workflowRunId?: string;
};
function useBlockScriptsQuery({
@@ -15,6 +17,8 @@ function useBlockScriptsQuery({
cacheKeyValue,
workflowPermanentId,
pollIntervalMs,
status,
workflowRunId,
}: Props) {
const credentialGetter = useCredentialGetter();
@@ -25,6 +29,8 @@ function useBlockScriptsQuery({
cacheKey,
cacheKeyValue,
pollIntervalMs,
status,
workflowRunId,
],
queryFn: async () => {
const client = await getClient(credentialGetter, "sans-api-v1");
@@ -33,6 +39,8 @@ function useBlockScriptsQuery({
.post<ScriptBlocksResponse>(`/scripts/${workflowPermanentId}/blocks`, {
cache_key: cacheKey ?? "",
cache_key_value: cacheKeyValue ?? "",
status: status ?? "published",
workflow_run_id: workflowRunId ?? null,
})
.then((response) => response.data);

View File

@@ -52,6 +52,8 @@ function WorkflowRunCode(props?: Props) {
cacheKeyValue,
workflowPermanentId,
pollIntervalMs: !isFinalized ? 3000 : undefined,
status: "draft",
workflowRunId: workflowRun?.workflow_run_id,
});
const orderedBlockLabels = getOrderedBlockLabels(workflow);
const code = getCode(orderedBlockLabels, blockScripts).join("").trim();
@@ -79,7 +81,15 @@ function WorkflowRunCode(props?: Props) {
useEffect(() => {
queryClient.invalidateQueries({
queryKey: ["block-scripts", workflowPermanentId, cacheKey, cacheKeyValue],
queryKey: [
"block-scripts",
workflowPermanentId,
cacheKey,
cacheKeyValue,
undefined,
"draft",
workflowRun?.workflow_run_id,
],
});
}, [queryClient, workflowRun, workflowPermanentId, cacheKey, cacheKeyValue]);

View File

@@ -2,20 +2,11 @@
"""
Generate a runnable Skyvern workflow script.
Example
-------
generated_code = generate_workflow_script(
file_name="workflow.py",
workflow_run_request=workflow_run_request,
workflow=workflow,
tasks=tasks,
actions_by_task=actions_by_task,
)
Path("workflow.py").write_text(generated_code)
"""
from __future__ import annotations
import asyncio
import hashlib
import keyword
import re
@@ -1602,7 +1593,7 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct
# --------------------------------------------------------------------- #
async def generate_workflow_script(
async def generate_workflow_script_python_code(
*,
file_name: str,
workflow_run_request: dict[str, Any],
@@ -1614,6 +1605,7 @@ async def generate_workflow_script(
run_id: str | None = None,
script_id: str | None = None,
script_revision_id: str | None = None,
draft: bool = False,
) -> str:
"""
Build a LibCST Module and emit .code (PEP-8-formatted source).
@@ -1685,16 +1677,15 @@ async def generate_workflow_script(
if script_id and script_revision_id and organization_id:
try:
block_name = task.get("label") or task.get("title") or task.get("task_id") or f"task_{idx}"
block_description = f"Generated block for task: {block_name}"
temp_module = cst.Module(body=[block_fn_def])
block_code = temp_module.code
await create_script_block(
await create_or_update_script_block(
block_code=block_code,
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
block_name=block_name,
block_description=block_description,
block_label=block_name,
update=draft,
)
except Exception as e:
LOG.error("Failed to create script block", error=str(e), exc_info=True)
@@ -1737,15 +1728,13 @@ async def generate_workflow_script(
task_v2_block_code = temp_module.code
block_name = task_v2.get("label") or task_v2.get("title") or f"task_v2_{idx}"
block_description = f"Generated task_v2 block with child functions: {block_name}"
await create_script_block(
await create_or_update_script_block(
block_code=task_v2_block_code,
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
block_name=block_name,
block_description=block_description,
block_label=block_name,
)
except Exception as e:
LOG.error("Failed to create task_v2 script block", error=str(e), exc_info=True)
@@ -1805,13 +1794,13 @@ async def generate_workflow_script(
start_block_module = cst.Module(body=start_block_body)
start_block_code = start_block_module.code
await create_script_block(
await create_or_update_script_block(
block_code=start_block_code,
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
block_name=settings.WORKFLOW_START_BLOCK_LABEL,
block_description="Start block containing imports, model classes, and run function",
block_label=settings.WORKFLOW_START_BLOCK_LABEL,
update=draft,
)
except Exception as e:
LOG.error("Failed to create __start_block__", error=str(e), exc_info=True)
@@ -1830,69 +1819,92 @@ async def generate_workflow_script(
return module.code
async def create_script_block(
async def create_or_update_script_block(
block_code: str | bytes,
script_revision_id: str,
script_id: str,
organization_id: str,
block_name: str,
block_description: str | None = None,
block_label: str,
update: bool = False,
) -> None:
"""
Create a script block in the database and save the block code to a script file.
If update is True, the script block will be updated instead of created.
Args:
block_code: The code to save
script_revision_id: The script revision ID
script_id: The script ID
organization_id: The organization ID
block_name: Optional custom name for the block (defaults to function name)
block_description: Optional description for the block
block_label: Label identifying the block; also used as the block's file name
update: Whether to update the script block instead of creating a new one
"""
block_code_bytes = block_code if isinstance(block_code, bytes) else block_code.encode("utf-8")
try:
# Step 3: Create script block in database
script_block = await app.DATABASE.create_script_block(
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
script_block_label=block_name,
)
script_block = None
if update:
script_block = await app.DATABASE.get_script_block_by_label(
organization_id=organization_id,
script_revision_id=script_revision_id,
script_block_label=block_label,
)
if not script_block:
script_block = await app.DATABASE.create_script_block(
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
script_block_label=block_label,
)
# Step 4: Create script file for the block
# Generate a unique filename for the block
file_name = f"{block_name}.skyvern"
file_name = f"{block_label}.skyvern"
file_path = f"blocks/{file_name}"
# Create artifact and upload to S3
artifact_id = await app.ARTIFACT_MANAGER.create_script_file_artifact(
organization_id=organization_id,
script_id=script_id,
script_version=1, # Assuming version 1 for now
file_path=file_path,
data=block_code_bytes,
)
artifact_id = None
if update and script_block.script_file_id:
script_file = await app.DATABASE.get_script_file_by_id(
script_revision_id,
script_block.script_file_id,
organization_id,
)
if script_file and script_file.artifact_id:
artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id)
if artifact:
asyncio.create_task(app.STORAGE.store_artifact(artifact, block_code_bytes))
else:
LOG.error("Script file or artifact not found", script_file_id=script_block.script_file_id)
else:
artifact_id = await app.ARTIFACT_MANAGER.create_script_file_artifact(
organization_id=organization_id,
script_id=script_id,
script_version=1, # Assuming version 1 for now
file_path=file_path,
data=block_code_bytes,
)
# Create script file record
script_file = await app.DATABASE.create_script_file(
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
file_path=file_path,
file_name=file_name,
file_type="file",
content_hash=f"sha256:{hashlib.sha256(block_code_bytes).hexdigest()}",
file_size=len(block_code_bytes),
mime_type="text/x-python",
artifact_id=artifact_id,
)
# Create script file record
script_file = await app.DATABASE.create_script_file(
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
file_path=file_path,
file_name=file_name,
file_type="file",
content_hash=f"sha256:{hashlib.sha256(block_code_bytes).hexdigest()}",
file_size=len(block_code_bytes),
mime_type="text/x-python",
artifact_id=artifact_id,
)
# update script block with script file id
await app.DATABASE.update_script_block(
script_block_id=script_block.script_block_id,
organization_id=organization_id,
script_file_id=script_file.file_id,
)
# update script block with script file id
await app.DATABASE.update_script_block(
script_block_id=script_block.script_block_id,
organization_id=organization_id,
script_file_id=script_file.file_id,
)
except Exception as e:
# Log error but don't fail the entire generation process

View File

@@ -199,7 +199,7 @@ class SkyvernPage:
finally:
skyvern_page._record(call)
# Auto-create action after execution
await skyvern_page._create_action_before_execution(
await skyvern_page._create_action_after_execution(
action_type=action,
intention=intention,
status=action_status,
@@ -222,7 +222,7 @@ class SkyvernPage:
timeout=timeout,
)
async def _create_action_before_execution(
async def _create_action_after_execution(
self,
action_type: ActionType,
intention: str = "",
@@ -295,6 +295,7 @@ class SkyvernPage:
created_action = await app.DATABASE.create_action(action)
context.action_order += 1
return created_action
except Exception:

View File

@@ -1176,6 +1176,7 @@ class ForgeAgent:
# set verified to True will skip the completion verification
action.verified = True
results = await ActionHandler.handle_action(scraped_page, task, step, current_page, action)
await app.AGENT_FUNCTION.post_action_execution()
detailed_agent_step_output.actions_and_results[action_idx] = (
action,
results,
@@ -1318,6 +1319,7 @@ class ForgeAgent:
complete_results = await ActionHandler.handle_action(
scraped_page, task, step, working_page, complete_action
)
await app.AGENT_FUNCTION.post_action_execution()
detailed_agent_step_output.actions_and_results.append((complete_action, complete_results))
await self.record_artifacts_after_action(task, step, browser_state, engine)
@@ -1337,6 +1339,7 @@ class ForgeAgent:
extract_results = await ActionHandler.handle_action(
scraped_page, task, step, working_page, extract_action
)
await app.AGENT_FUNCTION.post_action_execution()
detailed_agent_step_output.actions_and_results.append((extract_action, extract_results))
# If no action errors return the agent state and output

View File

@@ -20,6 +20,7 @@ from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.trace import TraceManager
from skyvern.forge.sdk.workflow.models.block import BlockTypeVar
from skyvern.services import workflow_script_service
from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.scraper.scraper import ELEMENT_NODE_ATTRIBUTES, CleanupElementTreeFunc, json_to_html
from skyvern.webeye.utils.dom import SkyvernElement
@@ -615,3 +616,35 @@ class AgentFunction:
async def validate_code_block(self, organization_id: str | None = None) -> None:
if not settings.ENABLE_CODE_BLOCK:
raise DisabledBlockExecutionError("CodeBlock is disabled")
async def _post_action_execution(self) -> None:
    """
    Generate or refresh the draft workflow script after an action executes.

    Resolves the root workflow run from the current Skyvern context and, when
    both the run and its workflow can be loaded, delegates to
    workflow_script_service.generate_or_update_draft_workflow_script.
    Returns silently when there is no workflow-run context or a lookup
    comes back empty.
    """
    context = skyvern_context.current()
    # Only meaningful inside a workflow run with a known organization.
    if not context or not context.root_workflow_run_id or not context.organization_id:
        return
    # Use the ROOT run id: nested runs (e.g. a task_v2 inside a workflow)
    # should update the draft script of the outermost workflow run.
    root_workflow_run_id = context.root_workflow_run_id
    organization_id = context.organization_id
    workflow_run = await app.DATABASE.get_workflow_run(
        workflow_run_id=root_workflow_run_id, organization_id=organization_id
    )
    if not workflow_run:
        return
    workflow = await app.DATABASE.get_workflow(
        workflow_id=workflow_run.workflow_id, organization_id=organization_id
    )
    if not workflow:
        return
    LOG.info(
        "Post action execution",
        root_workflow_run_id=context.root_workflow_run_id,
        organization_id=context.organization_id,
    )
    await workflow_script_service.generate_or_update_draft_workflow_script(
        workflow_run=workflow_run,
        workflow=workflow,
    )
async def post_action_execution(self) -> None:
    # Fire-and-forget: schedule draft-script regeneration without blocking the
    # action loop. NOTE(review): the Task returned by create_task is not
    # retained; per the asyncio docs an unreferenced task may be
    # garbage-collected before it completes — confirm this is acceptable here.
    asyncio.create_task(self._post_action_execution())

View File

@@ -15,6 +15,7 @@ class SkyvernContext:
workflow_id: str | None = None
workflow_permanent_id: str | None = None
workflow_run_id: str | None = None
root_workflow_run_id: str | None = None
task_v2_id: str | None = None
max_steps_override: int | None = None
browser_session_id: str | None = None

View File

@@ -107,7 +107,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRunStatus,
)
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunType
from skyvern.schemas.scripts import Script, ScriptBlock, ScriptFile, ScriptStatus
from skyvern.schemas.scripts import Script, ScriptBlock, ScriptFile, ScriptStatus, WorkflowScript
from skyvern.schemas.steps import AgentStepOutput
from skyvern.schemas.workflows import BlockStatus, BlockType, WorkflowStatus
from skyvern.webeye.actions.actions import Action
@@ -4039,6 +4039,44 @@ class AgentDB:
return convert_to_script_file(script_file) if script_file else None
async def get_script_file_by_path(
    self,
    script_revision_id: str,
    file_path: str,
    organization_id: str,
) -> ScriptFile | None:
    """Look up a script file by its path within a script revision.

    Returns the converted ScriptFile, or None when no matching row exists.
    """
    query = select(ScriptFileModel).filter_by(
        script_revision_id=script_revision_id,
        file_path=file_path,
        organization_id=organization_id,
    )
    async with self.Session() as session:
        record = (await session.scalars(query)).first()
    if record is None:
        return None
    return convert_to_script_file(record)
async def update_script_file(
    self,
    script_file_id: str,
    organization_id: str,
    artifact_id: str | None = None,
) -> ScriptFile:
    """Update mutable fields of a script file row and return the fresh record.

    Currently only artifact_id is updatable (applied when truthy).
    Raises NotFoundError when the file does not exist for this organization.
    """
    query = select(ScriptFileModel).filter_by(file_id=script_file_id).filter_by(
        organization_id=organization_id
    )
    async with self.Session() as session:
        record = (await session.scalars(query)).first()
        # Guard clause: fail fast when the row is missing.
        if record is None:
            raise NotFoundError("Script file not found")
        if artifact_id:
            record.artifact_id = artifact_id
        await session.commit()
        await session.refresh(record)
        return convert_to_script_file(record)
async def get_script_block(
self,
script_block_id: str,
@@ -4054,6 +4092,23 @@ class AgentDB:
).first()
return convert_to_script_block(record) if record else None
async def get_script_block_by_label(
    self,
    organization_id: str,
    script_revision_id: str,
    script_block_label: str,
) -> ScriptBlock | None:
    """Find a script block by its label within a script revision.

    Returns the converted ScriptBlock, or None when no row matches.
    """
    query = select(ScriptBlockModel).filter_by(
        script_revision_id=script_revision_id,
        script_block_label=script_block_label,
        organization_id=organization_id,
    )
    async with self.Session() as session:
        record = (await session.scalars(query)).first()
    if record is None:
        return None
    return convert_to_script_block(record)
async def get_script_blocks_by_script_revision_id(
self,
script_revision_id: str,
@@ -4080,6 +4135,7 @@ class AgentDB:
cache_key_value: str,
workflow_id: str | None = None,
workflow_run_id: str | None = None,
status: ScriptStatus = ScriptStatus.published,
) -> None:
"""Create a workflow->script cache mapping entry."""
try:
@@ -4092,6 +4148,7 @@ class AgentDB:
workflow_run_id=workflow_run_id,
cache_key=cache_key,
cache_key_value=cache_key_value,
status=status,
)
session.add(record)
await session.commit()
@@ -4102,12 +4159,32 @@ class AgentDB:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_workflow_script(
    self,
    organization_id: str,
    workflow_permanent_id: str,
    workflow_run_id: str,
    statuses: list[ScriptStatus] | None = None,
) -> WorkflowScript | None:
    """Fetch the workflow->script mapping row for a specific workflow run.

    When statuses is provided and non-empty, only rows in those statuses are
    considered. Returns the first matching row as a WorkflowScript, or None.
    """
    query = select(WorkflowScriptModel).filter_by(
        organization_id=organization_id,
        workflow_permanent_id=workflow_permanent_id,
        workflow_run_id=workflow_run_id,
    )
    if statuses:
        query = query.filter(WorkflowScriptModel.status.in_(statuses))
    async with self.Session() as session:
        record = (await session.scalars(query)).first()
    if record is None:
        return None
    return WorkflowScript.model_validate(record)
async def get_workflow_scripts_by_cache_key_value(
self,
*,
organization_id: str,
workflow_permanent_id: str,
cache_key_value: str,
workflow_run_id: str | None = None,
cache_key: str | None = None,
statuses: list[ScriptStatus] | None = None,
) -> list[Script]:
@@ -4122,6 +4199,10 @@ class AgentDB:
.where(WorkflowScriptModel.cache_key_value == cache_key_value)
.where(WorkflowScriptModel.deleted_at.is_(None))
)
if workflow_run_id:
ws_script_ids_subquery = ws_script_ids_subquery.where(
WorkflowScriptModel.workflow_run_id == workflow_run_id
)
if cache_key is not None:
ws_script_ids_subquery = ws_script_ids_subquery.where(WorkflowScriptModel.cache_key == cache_key)
@@ -4174,6 +4255,7 @@ class AgentDB:
.filter_by(workflow_permanent_id=workflow_permanent_id)
.filter_by(cache_key=cache_key)
.filter_by(deleted_at=None)
.filter_by(status="published")
)
if filter:
@@ -4205,6 +4287,7 @@ class AgentDB:
.filter_by(workflow_permanent_id=workflow_permanent_id)
.filter_by(cache_key=cache_key)
.filter_by(deleted_at=None)
.filter_by(status="published")
.offset((page - 1) * page_size)
.limit(page_size)
)
@@ -4220,45 +4303,6 @@ class AgentDB:
LOG.error("UnexpectedError", exc_info=True)
raise
async def create_workflow_cache_key_value(
self,
organization_id: str,
workflow_permanent_id: str,
cache_key: str,
cache_key_value: str,
script_id: str,
workflow_id: str | None = None,
workflow_run_id: str | None = None,
) -> str:
"""
Insert a new cache key value for a workflow.
Returns the workflow_script_id of the created record.
"""
try:
async with self.Session() as session:
workflow_script = WorkflowScriptModel(
script_id=script_id,
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
cache_key=cache_key,
cache_key_value=cache_key_value,
)
session.add(workflow_script)
await session.commit()
await session.refresh(workflow_script)
return workflow_script.workflow_script_id
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def delete_workflow_cache_key_value(
self,
organization_id: str,

View File

@@ -300,6 +300,7 @@ async def get_workflow_script_blocks(
scripts = await app.DATABASE.get_workflow_scripts_by_cache_key_value(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
workflow_run_id=block_script_request.workflow_run_id,
cache_key_value=cache_key_value,
cache_key=cache_key,
statuses=[status] if status else None,

View File

@@ -3060,6 +3060,9 @@ class TaskV2Block(Block):
finally:
context: skyvern_context.SkyvernContext | None = skyvern_context.current()
current_run_id = context.run_id if context and context.run_id else workflow_run_id
root_workflow_run_id = (
context.root_workflow_run_id if context and context.root_workflow_run_id else workflow_run_id
)
skyvern_context.set(
skyvern_context.SkyvernContext(
organization_id=organization_id,
@@ -3067,6 +3070,7 @@ class TaskV2Block(Block):
workflow_id=workflow_run.workflow_id,
workflow_permanent_id=workflow_run.workflow_permanent_id,
workflow_run_id=workflow_run_id,
root_workflow_run_id=root_workflow_run_id,
run_id=current_run_id,
browser_session_id=browser_session_id,
max_screenshot_scrolls=workflow_run.max_screenshot_scrolls,

View File

@@ -1,5 +1,4 @@
import asyncio
import base64
import json
import uuid
from datetime import UTC, datetime
@@ -7,14 +6,11 @@ from typing import Any
import httpx
import structlog
from jinja2.sandbox import SandboxedEnvironment
from skyvern import analytics
from skyvern.client.types.output_parameter import OutputParameter as BlockOutputParameter
from skyvern.config import settings
from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT, SAVE_DOWNLOADED_FILES_TIMEOUT
from skyvern.core.script_generations.generate_script import generate_workflow_script as generate_python_workflow_script
from skyvern.core.script_generations.transform_workflow_run import transform_workflow_run_to_code_gen_input
from skyvern.exceptions import (
BlockNotFound,
BrowserSessionNotFound,
@@ -99,7 +95,6 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRunStatus,
)
from skyvern.schemas.runs import ProxyLocation, RunStatus, RunType, WorkflowRunRequest, WorkflowRunResponse
from skyvern.schemas.scripts import FileEncoding, Script, ScriptFileCreate
from skyvern.schemas.workflows import (
BLOCK_YAML_TYPES,
BlockStatus,
@@ -109,7 +104,7 @@ from skyvern.schemas.workflows import (
WorkflowDefinitionYAML,
WorkflowStatus,
)
from skyvern.services import script_service
from skyvern.services import script_service, workflow_script_service
from skyvern.webeye.browser_factory import BrowserState
LOG = structlog.get_logger()
@@ -205,6 +200,7 @@ class WorkflowService:
request_id=request_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
root_workflow_run_id=workflow_run.workflow_run_id,
run_id=current_run_id,
workflow_permanent_id=workflow_run.workflow_permanent_id,
max_steps_override=max_steps_override,
@@ -353,7 +349,7 @@ class WorkflowService:
return workflow_run
# Check if there's a related workflow script that should be used instead
workflow_script, _ = await self._get_workflow_script(workflow, workflow_run, block_labels)
workflow_script, _ = await workflow_script_service.get_workflow_script(workflow, workflow_run, block_labels)
is_script = workflow_script is not None
if workflow_script is not None:
LOG.info(
@@ -365,9 +361,7 @@ class WorkflowService:
)
workflow_run = await self._execute_workflow_script(
script_id=workflow_script.script_id,
workflow=workflow,
workflow_run=workflow_run,
api_key=api_key,
organization=organization,
browser_session_id=browser_session_id,
)
@@ -375,9 +369,7 @@ class WorkflowService:
workflow_run = await self._execute_workflow_blocks(
workflow=workflow,
workflow_run=workflow_run,
api_key=api_key,
organization=organization,
close_browser_on_completion=close_browser_on_completion,
browser_session_id=browser_session_id,
block_labels=block_labels,
block_outputs=block_outputs,
@@ -422,9 +414,7 @@ class WorkflowService:
self,
workflow: Workflow,
workflow_run: WorkflowRun,
api_key: str,
organization: Organization,
close_browser_on_completion: bool,
browser_session_id: str | None = None,
block_labels: list[str] | None = None,
block_outputs: dict[str, Any] | None = None,
@@ -2457,66 +2447,10 @@ class WorkflowService:
return result
async def _get_workflow_script(
self, workflow: Workflow, workflow_run: WorkflowRun, block_labels: list[str] | None = None
) -> tuple[Script | None, str]:
"""
Check if there's a related workflow script that should be used instead of running the workflow.
Returns the tuple of (script, rendered_cache_key_value).
"""
cache_key = workflow.cache_key or ""
rendered_cache_key_value = ""
if not workflow.generate_script:
return None, rendered_cache_key_value
if block_labels:
# Do not generate script or run script if block_labels is provided
return None, rendered_cache_key_value
try:
parameter_tuples = await app.DATABASE.get_workflow_run_parameters(
workflow_run_id=workflow_run.workflow_run_id,
)
parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples}
jinja_sandbox_env = SandboxedEnvironment()
rendered_cache_key_value = jinja_sandbox_env.from_string(cache_key).render(parameters)
# Check if there are existing cached scripts for this workflow + cache_key_value
existing_scripts = await app.DATABASE.get_workflow_scripts_by_cache_key_value(
organization_id=workflow.organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
cache_key_value=rendered_cache_key_value,
)
if existing_scripts:
LOG.info(
"Found cached script for workflow",
workflow_id=workflow.workflow_id,
cache_key_value=rendered_cache_key_value,
workflow_run_id=workflow_run.workflow_run_id,
script_count=len(existing_scripts),
)
return existing_scripts[0], rendered_cache_key_value
return None, rendered_cache_key_value
except Exception as e:
LOG.warning(
"Failed to check for workflow script, proceeding with normal workflow execution",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
error=str(e),
exc_info=True,
)
return None, rendered_cache_key_value
async def _execute_workflow_script(
self,
script_id: str,
workflow: Workflow,
workflow_run: WorkflowRun,
api_key: str,
organization: Organization,
browser_session_id: str | None = None,
) -> WorkflowRun:
@@ -2584,7 +2518,7 @@ class WorkflowService:
# Do not generate script if block_labels is provided
return None
existing_script, rendered_cache_key_value = await self._get_workflow_script(
existing_script, rendered_cache_key_value = await workflow_script_service.get_workflow_script(
workflow,
workflow_run,
block_labels,
@@ -2605,62 +2539,9 @@ class WorkflowService:
run_id=workflow_run.workflow_run_id,
)
# 3) Generate script code from workflow run
try:
LOG.info(
"Generating script for workflow",
workflow_run_id=workflow_run.workflow_run_id,
workflow_id=workflow.workflow_id,
workflow_name=workflow.title,
cache_key_value=rendered_cache_key_value,
)
codegen_input = await transform_workflow_run_to_code_gen_input(
workflow_run_id=workflow_run.workflow_run_id,
organization_id=workflow.organization_id,
)
python_src = await generate_python_workflow_script(
file_name=codegen_input.file_name,
workflow_run_request=codegen_input.workflow_run,
workflow=codegen_input.workflow,
blocks=codegen_input.workflow_blocks,
actions_by_task=codegen_input.actions_by_task,
task_v2_child_blocks=codegen_input.task_v2_child_blocks,
organization_id=workflow.organization_id,
script_id=created_script.script_id,
script_revision_id=created_script.script_revision_id,
)
except Exception:
LOG.error("Failed to generate workflow script source", exc_info=True)
return
# 4) Persist script and files, then record mapping
content_bytes = python_src.encode("utf-8")
content_b64 = base64.b64encode(content_bytes).decode("utf-8")
files = [
ScriptFileCreate(
path="main.py",
content=content_b64,
encoding=FileEncoding.BASE64,
mime_type="text/x-python",
)
]
# Upload script file(s) as artifacts and create rows
await script_service.build_file_tree(
files=files,
organization_id=workflow.organization_id,
script_id=created_script.script_id,
script_version=created_script.version,
script_revision_id=created_script.script_revision_id,
)
# Record the workflow->script mapping for cache lookup
await app.DATABASE.create_workflow_script(
organization_id=workflow.organization_id,
script_id=created_script.script_id,
workflow_permanent_id=workflow.workflow_permanent_id,
cache_key=workflow.cache_key or "",
cache_key_value=rendered_cache_key_value,
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
await workflow_script_service.generate_workflow_script(
workflow_run=workflow_run,
workflow=workflow,
script=created_script,
rendered_cache_key_value=rendered_cache_key_value,
)

View File

@@ -155,8 +155,26 @@ class ScriptBlocksRequest(BaseModel):
cache_key_value: str
cache_key: str | None = None
status: ScriptStatus | None = None
workflow_run_id: str | None = None
class ScriptStatus(StrEnum):
    # Lifecycle states for a generated workflow script.
    published = "published"  # finalized script; the default status used by cache lookups
    draft = "draft"  # incrementally generated while a workflow run is still in progress
    pending = "pending"  # NOTE(review): no usage visible here — confirm intended semantics
class WorkflowScript(BaseModel):
    """Workflow->script cache mapping row, validated from ORM attributes."""

    # from_attributes lets model_validate() accept a WorkflowScriptModel instance.
    model_config = ConfigDict(from_attributes=True)
    workflow_script_id: str
    organization_id: str
    script_id: str
    workflow_permanent_id: str
    workflow_id: str | None = None
    # Set when the script is tied to a specific run (e.g. draft generation).
    workflow_run_id: str | None = None
    cache_key: str
    # Rendered (Jinja-expanded) value of cache_key for this run's parameters.
    cache_key_value: str
    status: ScriptStatus
    created_at: datetime
    modified_at: datetime
    # Soft-delete marker; queries filter on deleted_at IS NULL.
    deleted_at: datetime | None = None

View File

@@ -17,7 +17,7 @@ from jinja2.sandbox import SandboxedEnvironment
from skyvern.config import settings
from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT
from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
from skyvern.core.script_generations.generate_script import _build_block_fn, create_script_block
from skyvern.core.script_generations.generate_script import _build_block_fn, create_or_update_script_block
from skyvern.core.script_generations.skyvern_page import script_run_context_manager
from skyvern.exceptions import ScriptNotFound, WorkflowRunNotFound
from skyvern.forge import app
@@ -45,10 +45,10 @@ from skyvern.forge.sdk.workflow.models.block import (
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, OutputParameter, ParameterType
from skyvern.forge.sdk.workflow.models.workflow import Workflow
from skyvern.schemas.runs import RunEngine
from skyvern.schemas.scripts import CreateScriptResponse, FileEncoding, FileNode, ScriptFileCreate
from skyvern.schemas.scripts import CreateScriptResponse, FileEncoding, FileNode, ScriptFileCreate, ScriptStatus
from skyvern.schemas.workflows import BlockStatus, BlockType, FileStorageType, FileType
LOG = structlog.get_logger(__name__)
LOG = structlog.get_logger()
jinja_sandbox_env = SandboxedEnvironment()
@@ -58,6 +58,7 @@ async def build_file_tree(
script_id: str,
script_version: int,
script_revision_id: str,
draft: bool = False,
) -> dict[str, FileNode]:
"""Build a hierarchical file tree from a list of files and upload the files to s3 with the same tree structure."""
file_tree: dict[str, FileNode] = {}
@@ -70,33 +71,94 @@ async def build_file_tree(
# Create artifact and upload to S3
try:
artifact_id = await app.ARTIFACT_MANAGER.create_script_file_artifact(
organization_id=organization_id,
script_id=script_id,
script_version=script_version,
file_path=file.path,
data=content_bytes,
)
LOG.debug(
"Created script file artifact",
artifact_id=artifact_id,
file_path=file.path,
script_id=script_id,
script_version=script_version,
)
# create a script file record
await app.DATABASE.create_script_file(
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
file_path=file.path,
file_name=file.path.split("/")[-1],
file_type="file",
content_hash=f"sha256:{content_hash}",
file_size=file_size,
mime_type=file.mime_type,
artifact_id=artifact_id,
)
if draft:
# get the script file object
script_file = await app.DATABASE.get_script_file_by_path(
script_revision_id=script_revision_id,
file_path=file.path,
organization_id=organization_id,
)
if script_file:
if not script_file.artifact_id:
LOG.error(
"Failed to update file. An existing script file has no artifact id",
script_file_id=script_file.file_id,
)
continue
artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id)
if artifact:
# override the actual file in the storage
asyncio.create_task(app.STORAGE.store_artifact(artifact, content_bytes))
else:
artifact_id = await app.ARTIFACT_MANAGER.create_script_file_artifact(
organization_id=organization_id,
script_id=script_id,
script_version=script_version,
file_path=file.path,
data=content_bytes,
)
# update the artifact_id in the script file
await app.DATABASE.update_script_file(
script_file_id=script_file.file_id,
organization_id=organization_id,
artifact_id=artifact_id,
)
else:
artifact_id = await app.ARTIFACT_MANAGER.create_script_file_artifact(
organization_id=organization_id,
script_id=script_id,
script_version=script_version,
file_path=file.path,
data=content_bytes,
)
LOG.debug(
"Created script file artifact",
artifact_id=artifact_id,
file_path=file.path,
script_id=script_id,
script_version=script_version,
)
# create a script file record
await app.DATABASE.create_script_file(
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
file_path=file.path,
file_name=file.path.split("/")[-1],
file_type="file",
content_hash=f"sha256:{content_hash}",
file_size=file_size,
mime_type=file.mime_type,
artifact_id=artifact_id,
)
else:
artifact_id = await app.ARTIFACT_MANAGER.create_script_file_artifact(
organization_id=organization_id,
script_id=script_id,
script_version=script_version,
file_path=file.path,
data=content_bytes,
)
LOG.debug(
"Created script file artifact",
artifact_id=artifact_id,
file_path=file.path,
script_id=script_id,
script_version=script_version,
)
# create a script file record
await app.DATABASE.create_script_file(
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
file_path=file.path,
file_name=file.path.split("/")[-1],
file_type="file",
content_hash=f"sha256:{content_hash}",
file_size=file_size,
mime_type=file.mime_type,
artifact_id=artifact_id,
)
except Exception:
LOG.exception(
"Failed to create script file artifact",
@@ -794,6 +856,7 @@ async def _regenerate_script_block_after_ai_fallback(
workflow_permanent_id=workflow.workflow_permanent_id,
cache_key_value=cache_key_value,
cache_key=workflow.cache_key,
statuses=[ScriptStatus.published],
)
if not existing_scripts:
@@ -898,12 +961,12 @@ async def _regenerate_script_block_after_ai_fallback(
)
continue
await create_script_block(
await create_or_update_script_block(
block_code=block_file_content,
script_revision_id=new_script.script_revision_id,
script_id=new_script.script_id,
organization_id=organization_id,
block_name=existing_block.script_block_label,
block_label=existing_block.script_block_label,
)
block_file_content_bytes = (
block_file_content if isinstance(block_file_content, bytes) else block_file_content.encode("utf-8")

View File

@@ -466,6 +466,8 @@ async def run_task_v2_helper(
context: skyvern_context.SkyvernContext | None = skyvern_context.current()
current_run_id = context.run_id if context and context.run_id else task_v2_id
# task v2 can be nested inside a workflow run, so we need to use the root workflow run id
root_workflow_run_id = context.root_workflow_run_id if context and context.root_workflow_run_id else workflow_run_id
enable_parse_select_in_extract = app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached(
"ENABLE_PARSE_SELECT_IN_EXTRACT",
current_run_id,
@@ -476,6 +478,7 @@ async def run_task_v2_helper(
organization_id=organization_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
root_workflow_run_id=root_workflow_run_id,
request_id=request_id,
task_v2_id=task_v2_id,
run_id=current_run_id,

View File

@@ -0,0 +1,188 @@
import base64
import structlog
from jinja2.sandbox import SandboxedEnvironment
from skyvern.core.script_generations.generate_script import generate_workflow_script_python_code
from skyvern.core.script_generations.transform_workflow_run import transform_workflow_run_to_code_gen_input
from skyvern.forge import app
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun
from skyvern.schemas.scripts import FileEncoding, Script, ScriptFileCreate, ScriptStatus
from skyvern.services import script_service
LOG = structlog.get_logger()
jinja_sandbox_env = SandboxedEnvironment()
async def generate_or_update_draft_workflow_script(
    workflow_run: WorkflowRun,
    workflow: Workflow,
) -> None:
    """
    Generate (or refresh) the draft script for a single workflow run.

    Reuses the script referenced by the current skyvern context when it still
    exists; otherwise creates a new script row and records its ids back onto
    the context so later calls in the same run keep appending to the same
    draft. No-op when there is no active skyvern context.
    """
    organization_id = workflow.organization_id
    context = skyvern_context.current()
    if not context:
        return

    # Reuse the script already associated with this run, if the context has one.
    script = None
    if context.script_id:
        script = await app.DATABASE.get_script(script_id=context.script_id, organization_id=organization_id)
    if not script:
        script = await app.DATABASE.create_script(organization_id=organization_id, run_id=workflow_run.workflow_run_id)
        # context is known non-None here (early return above), so record the new
        # script ids directly — no extra truthiness check needed.
        context.script_id = script.script_id
        context.script_revision_id = script.script_revision_id

    # Only the rendered cache key value is needed; any existing pending script
    # (the first tuple element) is intentionally discarded because the draft is
    # always regenerated from the latest run data below.
    _, rendered_cache_key_value = await get_workflow_script(
        workflow=workflow,
        workflow_run=workflow_run,
        status=ScriptStatus.pending,
    )
    await generate_workflow_script(
        workflow_run=workflow_run,
        workflow=workflow,
        script=script,
        rendered_cache_key_value=rendered_cache_key_value,
        draft=True,
    )
async def get_workflow_script(
    workflow: Workflow,
    workflow_run: WorkflowRun,
    block_labels: list[str] | None = None,
    status: ScriptStatus = ScriptStatus.published,
) -> tuple[Script | None, str]:
    """
    Look up a cached script that can be used in place of executing the workflow.

    Renders the workflow's cache key template against the run's parameters and
    queries for scripts matching the rendered value. Returns a tuple of
    (script, rendered_cache_key_value); the script is None when caching is
    disabled, a partial run was requested, nothing is cached, or lookup fails.
    """
    cache_key_template = workflow.cache_key or ""
    rendered_value = ""
    # Script caching disabled, or a block-scoped (partial) run was requested:
    # never generate or serve a script in either case.
    if not workflow.generate_script or block_labels:
        return None, rendered_value
    try:
        run_parameters = await app.DATABASE.get_workflow_run_parameters(
            workflow_run_id=workflow_run.workflow_run_id,
        )
        template_context = {definition.key: actual.value for definition, actual in run_parameters}
        rendered_value = jinja_sandbox_env.from_string(cache_key_template).render(template_context)
        # Look for already-cached scripts for this workflow + rendered cache key.
        cached_scripts = await app.DATABASE.get_workflow_scripts_by_cache_key_value(
            organization_id=workflow.organization_id,
            workflow_permanent_id=workflow.workflow_permanent_id,
            cache_key_value=rendered_value,
            statuses=[status],
        )
        if not cached_scripts:
            return None, rendered_value
        LOG.info(
            "Found cached script for workflow",
            workflow_id=workflow.workflow_id,
            cache_key_value=rendered_value,
            workflow_run_id=workflow_run.workflow_run_id,
            script_count=len(cached_scripts),
        )
        return cached_scripts[0], rendered_value
    except Exception as e:
        # Lookup problems must never block the run itself — log and fall back
        # to normal workflow execution.
        LOG.warning(
            "Failed to check for workflow script, proceeding with normal workflow execution",
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
            error=str(e),
            exc_info=True,
        )
        return None, rendered_value
async def generate_workflow_script(
    workflow_run: WorkflowRun,
    workflow: Workflow,
    script: Script,
    rendered_cache_key_value: str,
    draft: bool = False,
) -> None:
    """
    Generate Python source for a completed workflow run and persist it.

    Transforms the run into codegen input, renders the script source, uploads
    it as a script file (main.py), and records a workflow->script mapping so
    later runs can find it by cache key. When ``draft`` is True the mapping is
    stored with pending status instead of published. Generation failures are
    logged and swallowed — script generation is best-effort and must not fail
    the run.
    """
    try:
        LOG.info(
            "Generating script for workflow",
            workflow_run_id=workflow_run.workflow_run_id,
            workflow_id=workflow.workflow_id,
            workflow_name=workflow.title,
            cache_key_value=rendered_cache_key_value,
        )
        # Collect everything codegen needs (blocks, actions, nested task v2
        # children) from the recorded run.
        codegen_input = await transform_workflow_run_to_code_gen_input(
            workflow_run_id=workflow_run.workflow_run_id,
            organization_id=workflow.organization_id,
        )
        python_src = await generate_workflow_script_python_code(
            file_name=codegen_input.file_name,
            workflow_run_request=codegen_input.workflow_run,
            workflow=codegen_input.workflow,
            blocks=codegen_input.workflow_blocks,
            actions_by_task=codegen_input.actions_by_task,
            task_v2_child_blocks=codegen_input.task_v2_child_blocks,
            organization_id=workflow.organization_id,
            script_id=script.script_id,
            script_revision_id=script.script_revision_id,
            draft=draft,
        )
    except Exception:
        # Best-effort: a codegen failure must not break the workflow run.
        LOG.error("Failed to generate workflow script source", exc_info=True)
        return
    # 4) Persist script and files, then record mapping
    content_bytes = python_src.encode("utf-8")
    content_b64 = base64.b64encode(content_bytes).decode("utf-8")
    files = [
        ScriptFileCreate(
            path="main.py",
            content=content_b64,
            encoding=FileEncoding.BASE64,
            mime_type="text/x-python",
        )
    ]
    # Upload script file(s) as artifacts and create rows
    await script_service.build_file_tree(
        files=files,
        organization_id=workflow.organization_id,
        script_id=script.script_id,
        script_version=script.version,
        script_revision_id=script.script_revision_id,
        draft=draft,
    )
    # Check if an existing draft workflow script record already exists for this
    # workflow run; if so, skip creating a duplicate mapping row.
    existing_draft_workflow_script = None
    status = ScriptStatus.published
    if draft:
        status = ScriptStatus.pending
        existing_draft_workflow_script = await app.DATABASE.get_workflow_script(
            organization_id=workflow.organization_id,
            workflow_permanent_id=workflow.workflow_permanent_id,
            workflow_run_id=workflow_run.workflow_run_id,
            statuses=[status],
        )
    if not existing_draft_workflow_script:
        # Record the workflow->script mapping for cache lookup
        await app.DATABASE.create_workflow_script(
            organization_id=workflow.organization_id,
            script_id=script.script_id,
            workflow_permanent_id=workflow.workflow_permanent_id,
            cache_key=workflow.cache_key or "",
            cache_key_value=rendered_cache_key_value,
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
            status=status,
        )