overhaul llm key override (#2677)
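
Replace the per-call llm_key_override parameter on LLMAPIHandler with a
factory-level override. Call sites now resolve the handler once via
LLMAPIHandlerFactory.get_override_llm_api_handler(override_llm_key, default=...)
and invoke it as before; the helper returns the default handler when no
override key is set, and logs a warning and falls back to the default when
the override key cannot be resolved.

A minimal sketch of the new call pattern, assembled from the call sites in
this diff (prompt, step, task, and app.LLM_API_HANDLER stand in for whatever
the call site already has in scope):

    llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
        task.llm_key, default=app.LLM_API_HANDLER
    )
    json_response = await llm_api_handler(prompt=prompt, step=step, prompt_name="custom-select")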
--- a/skyvern/forge/agent.py
+++ b/skyvern/forge/agent.py
@@ -58,7 +58,7 @@ from skyvern.forge.sdk.api.files import (
     rename_file,
     wait_for_download_finished,
 )
-from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller, LLMCallerManager
+from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCaller, LLMCallerManager
 from skyvern.forge.sdk.artifact.models import ArtifactType
 from skyvern.forge.sdk.core import skyvern_context
 from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
@@ -865,15 +865,20 @@ class ForgeAgent:
         ):
             using_cached_action_plan = True
         else:
+            llm_key_override = task.llm_key
+            # FIXME: Redundant engine check?
             if engine in CUA_ENGINES:
                 self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm)
+                llm_key_override = None
 
-            json_response = await app.LLM_API_HANDLER(
+            llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
+                llm_key_override, default=app.LLM_API_HANDLER
+            )
+            json_response = await llm_api_handler(
                 prompt=extract_action_prompt,
                 prompt_name="extract-actions",
                 step=step,
                 screenshots=scraped_page.screenshots,
-                llm_key_override=task.llm_key,
             )
             try:
                 json_response = await self.handle_potential_verification_code(
@@ -1513,12 +1518,14 @@ class ForgeAgent:
 
         # this prompt is critical to our agent so let's use the primary LLM API handler
 
-        verification_result = await app.LLM_API_HANDLER(
+        llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
+            llm_key_override, default=app.LLM_API_HANDLER
+        )
+        verification_result = await llm_api_handler(
             prompt=verification_prompt,
             step=step,
             screenshots=scraped_page_refreshed.screenshots,
             prompt_name="check-user-goal",
-            llm_key_override=llm_key_override,
         )
         return CompleteVerifyResult.model_validate(verification_result)
 
@@ -1833,7 +1840,10 @@ class ForgeAgent:
         prompt = prompt_engine.load_prompt(
             "infer-action-type", navigation_goal=navigation_goal, prompt_name="infer-action-type"
         )
-        json_response = await app.LLM_API_HANDLER(prompt=prompt, step=step, prompt_name="infer-action-type")
+        llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
+            task.llm_key, default=app.LLM_API_HANDLER
+        )
+        json_response = await llm_api_handler(prompt=prompt, step=step, prompt_name="infer-action-type")
         if json_response.get("error"):
             raise FailedToParseActionInstruction(
                 reason=json_response.get("thought"), error_type=json_response.get("error")
@@ -2772,12 +2782,14 @@ class ForgeAgent:
         run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
         if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
             llm_key_override = None
-            return await app.LLM_API_HANDLER(
+            llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
+                llm_key_override, default=app.LLM_API_HANDLER
+            )
+            return await llm_api_handler(
                 prompt=extract_action_prompt,
                 step=step,
                 screenshots=scraped_page.screenshots,
                 prompt_name="extract-actions",
-                llm_key_override=llm_key_override,
             )
         return json_response
 
--- a/skyvern/forge/sdk/api/llm/api_handler_factory.py
+++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py
@@ -45,6 +45,20 @@ class LLMCallStats(BaseModel):
 class LLMAPIHandlerFactory:
     _custom_handlers: dict[str, LLMAPIHandler] = {}
 
+    @staticmethod
+    def get_override_llm_api_handler(override_llm_key: str | None, *, default: LLMAPIHandler) -> LLMAPIHandler:
+        if not override_llm_key:
+            return default
+        try:
+            return LLMAPIHandlerFactory.get_llm_api_handler(override_llm_key)
+        except Exception:
+            LOG.warning(
+                "Failed to get override LLM API handler, going to use the default.",
+                override_llm_key=override_llm_key,
+                exc_info=True,
+            )
+            return default
+
     @staticmethod
     def get_llm_api_handler_with_router(llm_key: str) -> LLMAPIHandler:
         llm_config = LLMConfigRegistry.get_config(llm_key)
@@ -82,7 +96,6 @@ class LLMAPIHandlerFactory:
             ai_suggestion: AISuggestion | None = None,
             screenshots: list[bytes] | None = None,
             parameters: dict[str, Any] | None = None,
-            llm_key_override: str | None = None,
         ) -> dict[str, Any]:
             """
             Custom LLM API handler that utilizes the LiteLLM router and fallbacks to OpenAI GPT-4 Vision.
@@ -96,18 +109,10 @@ class LLMAPIHandlerFactory:
             Returns:
                 The response from the LLM router.
             """
-            nonlocal llm_config
-            nonlocal llm_key
-
-            local_llm_config: LLMConfig | LLMRouterConfig = llm_config
-            if llm_key_override:
-                local_llm_config = LLMConfigRegistry.get_config(llm_key_override)
-
-            local_llm_key = llm_key_override or llm_key
             start_time = time.time()
 
             if parameters is None:
-                parameters = LLMAPIHandlerFactory.get_api_parameters(local_llm_config)
+                parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config)
 
             context = skyvern_context.current()
             if context and len(context.hashed_href_map) > 0:
@@ -128,12 +133,12 @@ class LLMAPIHandlerFactory:
                 task_v2=task_v2,
                 thought=thought,
             )
-            messages = await llm_messages_builder(prompt, screenshots, local_llm_config.add_assistant_prefix)
+            messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
 
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=json.dumps(
                     {
-                        "model": local_llm_key,
+                        "model": llm_key,
                         "messages": messages,
                         **parameters,
                     }
@@ -149,12 +154,12 @@ class LLMAPIHandlerFactory:
                     model=main_model_group, messages=messages, timeout=settings.LLM_CONFIG_TIMEOUT, **parameters
                 )
             except litellm.exceptions.APIError as e:
-                raise LLMProviderErrorRetryableTask(local_llm_key) from e
+                raise LLMProviderErrorRetryableTask(llm_key) from e
             except litellm.exceptions.ContextWindowExceededError as e:
                 duration_seconds = time.time() - start_time
                 LOG.exception(
                     "Context window exceeded",
-                    llm_key=local_llm_key,
+                    llm_key=llm_key,
                     model=main_model_group,
                     prompt_name=prompt_name,
                     duration_seconds=duration_seconds,
@@ -164,22 +169,22 @@ class LLMAPIHandlerFactory:
                 duration_seconds = time.time() - start_time
                 LOG.exception(
                     "LLM token limit exceeded",
-                    llm_key=local_llm_key,
+                    llm_key=llm_key,
                     model=main_model_group,
                     prompt_name=prompt_name,
                     duration_seconds=duration_seconds,
                 )
-                raise LLMProviderErrorRetryableTask(local_llm_key) from e
+                raise LLMProviderErrorRetryableTask(llm_key) from e
             except Exception as e:
                 duration_seconds = time.time() - start_time
                 LOG.exception(
                     "LLM request failed unexpectedly",
-                    llm_key=local_llm_key,
+                    llm_key=llm_key,
                     model=main_model_group,
                     prompt_name=prompt_name,
                     duration_seconds=duration_seconds,
                 )
-                raise LLMProviderError(local_llm_key) from e
+                raise LLMProviderError(llm_key) from e
 
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=response.model_dump_json(indent=2).encode("utf-8"),
@@ -226,7 +231,7 @@ class LLMAPIHandlerFactory:
                 reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None,
                 cached_token_count=cached_tokens if cached_tokens > 0 else None,
             )
-            parsed_response = parse_api_response(response, local_llm_config.add_assistant_prefix)
+            parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=json.dumps(parsed_response, indent=2).encode("utf-8"),
                 artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
@@ -253,7 +258,7 @@ class LLMAPIHandlerFactory:
             duration_seconds = time.time() - start_time
             LOG.info(
                 "LLM API handler duration metrics",
-                llm_key=local_llm_key,
+                llm_key=llm_key,
                 model=main_model_group,
                 prompt_name=prompt_name,
                 duration_seconds=duration_seconds,
@@ -287,25 +292,15 @@ class LLMAPIHandlerFactory:
             ai_suggestion: AISuggestion | None = None,
             screenshots: list[bytes] | None = None,
             parameters: dict[str, Any] | None = None,
-            llm_key_override: str | None = None,
         ) -> dict[str, Any]:
-            nonlocal llm_config
-            nonlocal llm_key
-
-            local_llm_config: LLMConfig | LLMRouterConfig = llm_config
-            if llm_key_override:
-                local_llm_config = LLMConfigRegistry.get_config(llm_key_override)
-
-            local_llm_key = llm_key_override or llm_key
-
             start_time = time.time()
             active_parameters = base_parameters or {}
             if parameters is None:
-                parameters = LLMAPIHandlerFactory.get_api_parameters(local_llm_config)
+                parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config)
 
             active_parameters.update(parameters)
-            if local_llm_config.litellm_params:  # type: ignore
-                active_parameters.update(local_llm_config.litellm_params)  # type: ignore
+            if llm_config.litellm_params:  # type: ignore
+                active_parameters.update(llm_config.litellm_params)  # type: ignore
 
             context = skyvern_context.current()
             if context and len(context.hashed_href_map) > 0:
@@ -328,12 +323,12 @@ class LLMAPIHandlerFactory:
                 ai_suggestion=ai_suggestion,
             )
 
-            if not local_llm_config.supports_vision:
+            if not llm_config.supports_vision:
                 screenshots = None
 
-            model_name = local_llm_config.model_name
+            model_name = llm_config.model_name
 
-            messages = await llm_messages_builder(prompt, screenshots, local_llm_config.add_assistant_prefix)
+            messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=json.dumps(
                     {
@@ -361,12 +356,12 @@ class LLMAPIHandlerFactory:
                     **active_parameters,
                 )
             except litellm.exceptions.APIError as e:
-                raise LLMProviderErrorRetryableTask(local_llm_key) from e
+                raise LLMProviderErrorRetryableTask(llm_key) from e
             except litellm.exceptions.ContextWindowExceededError as e:
                 duration_seconds = time.time() - start_time
                 LOG.exception(
                     "Context window exceeded",
-                    llm_key=local_llm_key,
+                    llm_key=llm_key,
                     model=model_name,
                     prompt_name=prompt_name,
                     duration_seconds=duration_seconds,
@@ -376,22 +371,22 @@ class LLMAPIHandlerFactory:
                 t_llm_cancelled = time.perf_counter()
                 LOG.error(
                     "LLM request got cancelled",
-                    llm_key=local_llm_key,
+                    llm_key=llm_key,
                     model=model_name,
                     prompt_name=prompt_name,
                     duration=t_llm_cancelled - t_llm_request,
                 )
-                raise LLMProviderError(local_llm_key)
+                raise LLMProviderError(llm_key)
             except Exception as e:
                 duration_seconds = time.time() - start_time
                 LOG.exception(
                     "LLM request failed unexpectedly",
-                    llm_key=local_llm_key,
+                    llm_key=llm_key,
                     model=model_name,
                     prompt_name=prompt_name,
                     duration_seconds=duration_seconds,
                 )
-                raise LLMProviderError(local_llm_key) from e
+                raise LLMProviderError(llm_key) from e
 
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=response.model_dump_json(indent=2).encode("utf-8"),
@@ -439,7 +434,7 @@ class LLMAPIHandlerFactory:
                 cached_token_count=cached_tokens if cached_tokens > 0 else None,
                 thought_cost=llm_cost,
             )
-            parsed_response = parse_api_response(response, local_llm_config.add_assistant_prefix)
+            parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=json.dumps(parsed_response, indent=2).encode("utf-8"),
                 artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
@@ -466,9 +461,9 @@ class LLMAPIHandlerFactory:
             duration_seconds = time.time() - start_time
             LOG.info(
                 "LLM API handler duration metrics",
-                llm_key=local_llm_key,
+                llm_key=llm_key,
                 prompt_name=prompt_name,
-                model=local_llm_config.model_name,
+                model=llm_config.model_name,
                 duration_seconds=duration_seconds,
                 step_id=step.step_id if step else None,
                 thought_id=thought.observer_thought_id if thought else None,
--- a/skyvern/forge/sdk/api/llm/models.py
+++ b/skyvern/forge/sdk/api/llm/models.py
@@ -94,7 +94,6 @@ class LLMAPIHandler(Protocol):
         ai_suggestion: AISuggestion | None = None,
         screenshots: list[bytes] | None = None,
         parameters: dict[str, Any] | None = None,
-        llm_key_override: str | None = None,
     ) -> Awaitable[dict[str, Any]]: ...
 
 
@@ -107,6 +106,5 @@ async def dummy_llm_api_handler(
     ai_suggestion: AISuggestion | None = None,
     screenshots: list[bytes] | None = None,
     parameters: dict[str, Any] | None = None,
-    llm_key_override: str | None = None,
 ) -> dict[str, Any]:
     raise NotImplementedError("Your LLM provider is not configured. Please configure it in the .env file.")
--- a/skyvern/services/task_v2_service.py
+++ b/skyvern/services/task_v2_service.py
@@ -20,6 +20,7 @@ from skyvern.exceptions import (
 )
 from skyvern.forge import app
 from skyvern.forge.prompts import prompt_engine
+from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
 from skyvern.forge.sdk.artifact.models import ArtifactType
 from skyvern.forge.sdk.core import skyvern_context
 from skyvern.forge.sdk.core.hashing import generate_url_hash
@@ -617,12 +618,14 @@ async def run_task_v2_helper(
         thought_type=ThoughtType.plan,
         thought_scenario=ThoughtScenario.generate_plan,
     )
-    task_v2_response = await app.LLM_API_HANDLER(
+    llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
+        task_v2.llm_key, default=app.LLM_API_HANDLER
+    )
+    task_v2_response = await llm_api_handler(
         prompt=task_v2_prompt,
         screenshots=scraped_page.screenshots,
         thought=thought,
         prompt_name="task_v2",
-        llm_key_override=task_v2.llm_key,
     )
     LOG.info(
         "Task v2 response",
--- a/skyvern/webeye/actions/handler.py
+++ b/skyvern/webeye/actions/handler.py
@@ -60,7 +60,7 @@ from skyvern.forge.sdk.api.files import (
     list_files_in_directory,
     wait_for_download_finished,
 )
-from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCallerManager
+from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCallerManager
 from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
 from skyvern.forge.sdk.core import skyvern_context
 from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post
@@ -2557,7 +2557,8 @@ async def sequentially_select_from_dropdown(
         select_history=json.dumps(build_sequential_select_history(select_history)),
         local_datetime=datetime.now(ensure_context().tz_info).isoformat(),
     )
-    json_response = await app.LLM_API_HANDLER(
+    llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(task.llm_key, default=app.LLM_API_HANDLER)
+    json_response = await llm_api_handler(
         prompt=prompt, screenshots=[screenshot], step=step, prompt_name="confirm-multi-selection-finish"
     )
     if json_response.get("is_mini_goal_finished", False):
@@ -2641,7 +2642,8 @@ async def select_from_emerging_elements(
         task_id=task.task_id,
     )
 
-    json_response = await app.LLM_API_HANDLER(prompt=prompt, step=step, prompt_name="custom-select")
+    llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(task.llm_key, default=app.LLM_API_HANDLER)
+    json_response = await llm_api_handler(prompt=prompt, step=step, prompt_name="custom-select")
     value: str | None = json_response.get("value", None)
     LOG.info(
         "LLM response for the matched element",
@@ -3385,12 +3387,12 @@ async def extract_information_for_navigation_goal(
         # CUA tasks should use the default data extraction llm key
         llm_key_override = None
 
-    json_response = await app.LLM_API_HANDLER(
+    llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(llm_key_override, default=app.LLM_API_HANDLER)
+    json_response = await llm_api_handler(
         prompt=extract_information_prompt,
         step=step,
         screenshots=scraped_page.screenshots,
         prompt_name="extract-information",
-        llm_key_override=llm_key_override,
     )
 
     return ScrapeResult(