update SkyvernClient using generated client code (#2044)
This commit is contained in:
@@ -1,10 +1,6 @@
|
|||||||
from typing import Any
|
from skyvern.client.client import AsyncSkyvern
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
from skyvern.exceptions import SkyvernClientException
|
from skyvern.forge.sdk.workflow.models.workflow import RunWorkflowResponse
|
||||||
from skyvern.forge.sdk.workflow.models.workflow import RunWorkflowResponse, WorkflowRunResponse
|
|
||||||
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse
|
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse
|
||||||
|
|
||||||
|
|
||||||
@@ -16,25 +12,38 @@ class SkyvernClient:
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
self.client = AsyncSkyvern(base_url=base_url, api_key=api_key)
|
||||||
|
|
||||||
async def run_task(
|
async def run_task(
|
||||||
self,
|
self,
|
||||||
goal: str,
|
prompt: str,
|
||||||
engine: RunEngine = RunEngine.skyvern_v1,
|
|
||||||
url: str | None = None,
|
url: str | None = None,
|
||||||
|
title: str | None = None,
|
||||||
|
engine: RunEngine = RunEngine.skyvern_v2,
|
||||||
webhook_url: str | None = None,
|
webhook_url: str | None = None,
|
||||||
totp_identifier: str | None = None,
|
totp_identifier: str | None = None,
|
||||||
totp_url: str | None = None,
|
totp_url: str | None = None,
|
||||||
title: str | None = None,
|
|
||||||
error_code_mapping: dict[str, str] | None = None,
|
error_code_mapping: dict[str, str] | None = None,
|
||||||
proxy_location: ProxyLocation | None = None,
|
proxy_location: ProxyLocation | None = None,
|
||||||
max_steps: int | None = None,
|
max_steps: int | None = None,
|
||||||
|
browser_session_id: str | None = None,
|
||||||
|
publish_workflow: bool = False,
|
||||||
) -> RunResponse:
|
) -> RunResponse:
|
||||||
if engine == RunEngine.skyvern_v1:
|
task_run_obj = await self.client.agent.run_task(
|
||||||
return RunResponse()
|
goal=prompt,
|
||||||
elif engine == RunEngine.skyvern_v2:
|
url=url,
|
||||||
return RunResponse()
|
title=title,
|
||||||
raise ValueError(f"Invalid engine: {engine}")
|
engine=engine,
|
||||||
|
webhook_url=webhook_url,
|
||||||
|
totp_identifier=totp_identifier,
|
||||||
|
totp_url=totp_url,
|
||||||
|
error_code_mapping=error_code_mapping,
|
||||||
|
proxy_location=proxy_location,
|
||||||
|
max_steps=max_steps,
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
publish_workflow=publish_workflow,
|
||||||
|
)
|
||||||
|
return RunResponse.model_validate(task_run_obj)
|
||||||
|
|
||||||
async def run_workflow(
|
async def run_workflow(
|
||||||
self,
|
self,
|
||||||
@@ -44,47 +53,24 @@ class SkyvernClient:
|
|||||||
proxy_location: ProxyLocation | None = None,
|
proxy_location: ProxyLocation | None = None,
|
||||||
totp_identifier: str | None = None,
|
totp_identifier: str | None = None,
|
||||||
totp_url: str | None = None,
|
totp_url: str | None = None,
|
||||||
|
browser_session_id: str | None = None,
|
||||||
|
template: bool = False,
|
||||||
) -> RunWorkflowResponse:
|
) -> RunWorkflowResponse:
|
||||||
data: dict[str, Any] = {
|
workflow_run_obj = await self.client.agent.run_workflow(
|
||||||
"webhook_callback_url": webhook_url,
|
workflow_id=workflow_id,
|
||||||
"proxy_location": proxy_location,
|
data=workflow_input,
|
||||||
"totp_identifier": totp_identifier,
|
webhook_callback_url=webhook_url,
|
||||||
"totp_url": totp_url,
|
proxy_location=proxy_location,
|
||||||
}
|
totp_identifier=totp_identifier,
|
||||||
if workflow_input:
|
totp_url=totp_url,
|
||||||
data["data"] = workflow_input
|
browser_session_id=browser_session_id,
|
||||||
async with httpx.AsyncClient() as client:
|
template=template,
|
||||||
response = await client.post(
|
)
|
||||||
f"{self.base_url}/api/v1/workflows/{workflow_id}/run",
|
return RunWorkflowResponse.model_validate(workflow_run_obj)
|
||||||
headers={"x-api-key": self.api_key},
|
|
||||||
json=data,
|
|
||||||
)
|
|
||||||
if response.status_code != 200:
|
|
||||||
raise SkyvernClientException(
|
|
||||||
f"Failed to run workflow: {response.text}",
|
|
||||||
status_code=response.status_code,
|
|
||||||
)
|
|
||||||
return RunWorkflowResponse.model_validate(response.json())
|
|
||||||
|
|
||||||
async def get_run(
|
async def get_run(
|
||||||
self,
|
self,
|
||||||
run_id: str,
|
run_id: str,
|
||||||
) -> RunResponse:
|
) -> RunResponse:
|
||||||
return RunResponse()
|
run_obj = await self.client.agent.get_run(run_id=run_id)
|
||||||
|
return RunResponse.model_validate(run_obj)
|
||||||
async def get_workflow_run(
|
|
||||||
self,
|
|
||||||
workflow_run_id: str,
|
|
||||||
) -> WorkflowRunResponse:
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.get(
|
|
||||||
f"{self.base_url}/api/v1/workflows/runs/{workflow_run_id}",
|
|
||||||
headers={"x-api-key": self.api_key},
|
|
||||||
timeout=60,
|
|
||||||
)
|
|
||||||
if response.status_code != 200:
|
|
||||||
raise SkyvernClientException(
|
|
||||||
f"Failed to get workflow run: {response.text}",
|
|
||||||
status_code=response.status_code,
|
|
||||||
)
|
|
||||||
return WorkflowRunResponse.model_validate(response.json())
|
|
||||||
|
|||||||
Reference in New Issue
Block a user