Batch LLM artifacts creation (#4322)

This commit is contained in:
Stanislav Novosad
2025-12-17 20:15:26 -07:00
committed by GitHub
parent f594474b9e
commit 1eca20b78a
3 changed files with 908 additions and 828 deletions

View File

@@ -1880,41 +1880,55 @@ class ForgeAgent:
LOG.debug("Persisting speculative LLM metadata")
artifacts = []
if metadata.prompt:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=metadata.prompt.encode("utf-8"),
artifact_type=ArtifactType.LLM_PROMPT,
screenshots=screenshots,
step=step,
)
)
if metadata.llm_request_json:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=metadata.llm_request_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_REQUEST,
step=step,
)
)
if metadata.llm_response_json:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=metadata.llm_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE,
step=step,
)
)
if metadata.parsed_response_json:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=metadata.parsed_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
step=step,
)
)
if metadata.rendered_response_json:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=metadata.rendered_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
step=step,
)
)
if artifacts:
await app.ARTIFACT_MANAGER.bulk_create_artifacts(artifacts)
incremental_cost = metadata.llm_cost if metadata.llm_cost and metadata.llm_cost > 0 else None
incremental_input_tokens = (

View File

@@ -35,6 +35,7 @@ from skyvern.forge.sdk.api.llm.models import (
)
from skyvern.forge.sdk.api.llm.ui_tars_response import UITarsResponse
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, llm_messages_builder_with_history, parse_api_response
from skyvern.forge.sdk.artifact.manager import BulkArtifactCreationRequest
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
@@ -93,6 +94,7 @@ def _get_artifact_targets_and_persist_flag(
async def _log_hashed_href_map_artifacts_if_needed(
artifacts: list[BulkArtifactCreationRequest | None],
context: SkyvernContext | None,
step: Step | None,
task_v2: TaskV2 | None,
@@ -105,11 +107,13 @@ async def _log_hashed_href_map_artifacts_if_needed(
step, is_speculative_step, task_v2, thought, ai_suggestion
)
if context and context.hashed_href_map and should_persist_llm_artifacts:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"),
artifact_type=ArtifactType.HASHED_HREF_MAP,
**artifact_targets,
)
)
def _log_vertex_cache_hit_if_needed(
@@ -446,7 +450,11 @@ class LLMAPIHandlerFactory:
should_persist_llm_artifacts, artifact_targets = _get_artifact_targets_and_persist_flag(
step, is_speculative_step, task_v2, thought, ai_suggestion
)
artifacts: list[BulkArtifactCreationRequest | None] = []
try:
await _log_hashed_href_map_artifacts_if_needed(
artifacts,
context,
step,
task_v2,
@@ -457,12 +465,14 @@ class LLMAPIHandlerFactory:
llm_prompt_value = prompt
if should_persist_llm_artifacts:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=llm_prompt_value.encode("utf-8"),
artifact_type=ArtifactType.LLM_PROMPT,
screenshots=screenshots,
**artifact_targets,
)
)
# Build messages and apply caching in one step
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
@@ -475,11 +485,13 @@ class LLMAPIHandlerFactory:
}
llm_request_json = json.dumps(llm_request_payload)
if should_persist_llm_artifacts:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=llm_request_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_REQUEST,
**artifact_targets,
)
)
return llm_request_json
# Inject context caching system message when available
@@ -661,11 +673,13 @@ class LLMAPIHandlerFactory:
llm_response_json = response.model_dump_json(indent=2)
if should_persist_llm_artifacts:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=llm_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE,
**artifact_targets,
)
)
prompt_tokens = 0
completion_tokens = 0
reasoning_tokens = 0
@@ -724,11 +738,13 @@ class LLMAPIHandlerFactory:
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix, force_dict)
parsed_response_json = json.dumps(parsed_response, indent=2)
if should_persist_llm_artifacts:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=parsed_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
**artifact_targets,
)
)
rendered_response_json = None
if context and len(context.hashed_href_map) > 0:
@@ -737,11 +753,13 @@ class LLMAPIHandlerFactory:
parsed_response = json.loads(rendered_content)
rendered_response_json = json.dumps(parsed_response, indent=2)
if should_persist_llm_artifacts:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=rendered_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
**artifact_targets,
)
)
# Track LLM API handler duration, token counts, and cost
organization_id = organization_id or (
@@ -782,6 +800,11 @@ class LLMAPIHandlerFactory:
)
return parsed_response
finally:
try:
await app.ARTIFACT_MANAGER.bulk_create_artifacts(artifacts)
except Exception:
LOG.error("Failed to persist artifacts", exc_info=True)
llm_api_handler_with_router_and_fallback.llm_key = llm_key # type: ignore[attr-defined]
return llm_api_handler_with_router_and_fallback
@@ -855,7 +878,11 @@ class LLMAPIHandlerFactory:
should_persist_llm_artifacts, artifact_targets = _get_artifact_targets_and_persist_flag(
step, is_speculative_step, task_v2, thought, ai_suggestion
)
artifacts: list[BulkArtifactCreationRequest | None] = []
try:
await _log_hashed_href_map_artifacts_if_needed(
artifacts,
context,
step,
task_v2,
@@ -866,12 +893,14 @@ class LLMAPIHandlerFactory:
llm_prompt_value = prompt
if should_persist_llm_artifacts:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=llm_prompt_value.encode("utf-8"),
artifact_type=ArtifactType.LLM_PROMPT,
screenshots=screenshots,
**artifact_targets,
)
)
if not llm_config.supports_vision:
screenshots = None
@@ -957,11 +986,13 @@ class LLMAPIHandlerFactory:
}
llm_request_json = json.dumps(llm_request_payload)
if should_persist_llm_artifacts:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=llm_request_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_REQUEST,
**artifact_targets,
)
)
# Strip static prompt from the request messages because it's already in the cache
# Sending it again causes double-billing (once cached, once uncached)
@@ -1034,11 +1065,13 @@ class LLMAPIHandlerFactory:
llm_response_json = response.model_dump_json(indent=2)
if should_persist_llm_artifacts:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=llm_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE,
**artifact_targets,
)
)
prompt_tokens = 0
completion_tokens = 0
@@ -1101,11 +1134,13 @@ class LLMAPIHandlerFactory:
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix, force_dict)
parsed_response_json = json.dumps(parsed_response, indent=2)
if should_persist_llm_artifacts:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=parsed_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
**artifact_targets,
)
)
rendered_response_json = None
if context and len(context.hashed_href_map) > 0:
@@ -1114,11 +1149,13 @@ class LLMAPIHandlerFactory:
parsed_response = json.loads(rendered_content)
rendered_response_json = json.dumps(parsed_response, indent=2)
if should_persist_llm_artifacts:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=rendered_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
**artifact_targets,
)
)
# Track LLM API handler duration, token counts, and cost
organization_id = organization_id or (
@@ -1159,6 +1196,11 @@ class LLMAPIHandlerFactory:
)
return parsed_response
finally:
try:
await app.ARTIFACT_MANAGER.bulk_create_artifacts(artifacts)
except Exception:
LOG.error("Failed to persist artifacts", exc_info=True)
llm_api_handler.llm_key = llm_key # type: ignore[attr-defined]
return llm_api_handler
@@ -1278,7 +1320,11 @@ class LLMCaller:
should_persist_llm_artifacts, artifact_targets = _get_artifact_targets_and_persist_flag(
step, is_speculative_step, task_v2, thought, ai_suggestion
)
artifacts: list[BulkArtifactCreationRequest | None] = []
try:
await _log_hashed_href_map_artifacts_if_needed(
artifacts,
context,
step,
task_v2,
@@ -1306,12 +1352,14 @@ class LLMCaller:
llm_prompt_value = prompt or ""
if prompt and should_persist_llm_artifacts:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=prompt.encode("utf-8"),
artifact_type=ArtifactType.LLM_PROMPT,
screenshots=screenshots,
**artifact_targets,
)
)
if not self.llm_config.supports_vision:
screenshots = None
@@ -1342,11 +1390,14 @@ class LLMCaller:
}
llm_request_json = json.dumps(llm_request_payload)
if should_persist_llm_artifacts:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=llm_request_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_REQUEST,
**artifact_targets,
)
)
t_llm_request = time.perf_counter()
try:
response = await self._dispatch_llm_call(
@@ -1393,11 +1444,13 @@ class LLMCaller:
llm_response_json = response.model_dump_json(indent=2)
if should_persist_llm_artifacts:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=llm_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE,
**artifact_targets,
)
)
call_stats = await self.get_call_stats(response)
if step and not is_speculative_step:
@@ -1450,11 +1503,13 @@ class LLMCaller:
parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix, force_dict)
parsed_response_json = json.dumps(parsed_response, indent=2)
if should_persist_llm_artifacts:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=parsed_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
**artifact_targets,
)
)
rendered_response_json = None
if context and len(context.hashed_href_map) > 0:
@@ -1463,11 +1518,13 @@ class LLMCaller:
parsed_response = json.loads(rendered_content)
rendered_response_json = json.dumps(parsed_response, indent=2)
if should_persist_llm_artifacts:
await app.ARTIFACT_MANAGER.create_llm_artifact(
artifacts.append(
await app.ARTIFACT_MANAGER.prepare_llm_artifact(
data=rendered_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
**artifact_targets,
)
)
if step and is_speculative_step:
step.speculative_llm_metadata = SpeculativeLLMMetadata(
@@ -1487,6 +1544,11 @@ class LLMCaller:
)
return parsed_response
finally:
try:
await app.ARTIFACT_MANAGER.bulk_create_artifacts(artifacts)
except Exception:
LOG.error("Failed to persist artifacts", exc_info=True)
def get_screenshot_resize_target_dimension(self, window_dimension: Resolution | None) -> Resolution:
if window_dimension and window_dimension != self.browser_window_dimension:

View File

@@ -367,6 +367,24 @@ class ArtifactManager:
data=data,
)
async def bulk_create_artifacts(
self,
requests: list[BulkArtifactCreationRequest | None],
) -> list[str]:
artifacts: list[ArtifactBatchData] = []
primary_key: str | None = None
for request in requests:
if request:
artifacts.extend(request.artifacts)
primary_key = request.primary_key
if primary_key is None or not artifacts:
return []
return await self._bulk_create_artifacts(
BulkArtifactCreationRequest(artifacts=artifacts, primary_key=primary_key)
)
async def _bulk_create_artifacts(
self,
request: BulkArtifactCreationRequest,
@@ -636,7 +654,7 @@ class ArtifactManager:
return BulkArtifactCreationRequest(artifacts=artifacts, primary_key=ai_suggestion.ai_suggestion_id)
async def create_llm_artifact(
async def prepare_llm_artifact(
self,
data: bytes,
artifact_type: ArtifactType,
@@ -645,54 +663,40 @@ class ArtifactManager:
thought: Thought | None = None,
task_v2: TaskV2 | None = None,
ai_suggestion: AISuggestion | None = None,
) -> None:
"""
Create LLM artifact with optional screenshots using bulk insert.
Args:
data: Main artifact data
artifact_type: Type of the main artifact
screenshots: Optional list of screenshot data
step: Optional Step entity
thought: Optional Thought entity
task_v2: Optional TaskV2 entity
ai_suggestion: Optional AISuggestion entity
"""
) -> BulkArtifactCreationRequest | None:
if step:
request = self._prepare_step_artifacts(
return self._prepare_step_artifacts(
step=step,
artifact_type=artifact_type,
data=data,
screenshots=screenshots,
)
await self._bulk_create_artifacts(request)
elif task_v2:
request = self._prepare_task_v2_artifacts(
return self._prepare_task_v2_artifacts(
task_v2=task_v2,
artifact_type=artifact_type,
data=data,
screenshots=screenshots,
)
await self._bulk_create_artifacts(request)
elif thought:
request = self._prepare_thought_artifacts(
return self._prepare_thought_artifacts(
thought=thought,
artifact_type=artifact_type,
data=data,
screenshots=screenshots,
)
await self._bulk_create_artifacts(request)
elif ai_suggestion:
request = self._prepare_ai_suggestion_artifacts(
return self._prepare_ai_suggestion_artifacts(
ai_suggestion=ai_suggestion,
artifact_type=artifact_type,
data=data,
screenshots=screenshots,
)
await self._bulk_create_artifacts(request)
else:
return None
async def update_artifact_data(
self,