Batch LLM artifacts creation (#4322)
commit 1eca20b78a
parent f594474b9e
committed by GitHub
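This change batches LLM artifact persistence: instead of awaiting one storage write per artifact through create_llm_artifact, call sites now collect BulkArtifactCreationRequest objects from the renamed prepare_llm_artifact and flush them with a single bulk_create_artifacts call. Below is a minimal sketch of the pattern every hunk in this commit applies; the wrapper function, the `step` argument, and the import path for the app singleton are assumptions for illustration, not part of the diff.

# Illustrative sketch of the prepare-then-flush pattern (not from this diff).
from skyvern.forge import app  # assumed location of the app singleton
from skyvern.forge.sdk.artifact.models import ArtifactType

async def persist_prompt_artifact(step, prompt: str) -> None:
    artifacts = []
    if prompt:
        artifacts.append(
            await app.ARTIFACT_MANAGER.prepare_llm_artifact(  # builds the request; no write yet
                data=prompt.encode("utf-8"),
                artifact_type=ArtifactType.LLM_PROMPT,
                step=step,
            )
        )
    if artifacts:
        # One storage round-trip for everything collected above.
        await app.ARTIFACT_MANAGER.bulk_create_artifacts(artifacts)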
@@ -1880,41 +1880,55 @@ class ForgeAgent:
         LOG.debug("Persisting speculative LLM metadata")
+        artifacts = []
         if metadata.prompt:
-            await app.ARTIFACT_MANAGER.create_llm_artifact(
+            artifacts.append(
+                await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                     data=metadata.prompt.encode("utf-8"),
                     artifact_type=ArtifactType.LLM_PROMPT,
                     screenshots=screenshots,
                     step=step,
                 )
+            )

         if metadata.llm_request_json:
-            await app.ARTIFACT_MANAGER.create_llm_artifact(
+            artifacts.append(
+                await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                     data=metadata.llm_request_json.encode("utf-8"),
                     artifact_type=ArtifactType.LLM_REQUEST,
                     step=step,
                 )
+            )

         if metadata.llm_response_json:
-            await app.ARTIFACT_MANAGER.create_llm_artifact(
+            artifacts.append(
+                await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                     data=metadata.llm_response_json.encode("utf-8"),
                     artifact_type=ArtifactType.LLM_RESPONSE,
                     step=step,
                 )
+            )

         if metadata.parsed_response_json:
-            await app.ARTIFACT_MANAGER.create_llm_artifact(
+            artifacts.append(
+                await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                     data=metadata.parsed_response_json.encode("utf-8"),
                     artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
                     step=step,
                 )
+            )

         if metadata.rendered_response_json:
-            await app.ARTIFACT_MANAGER.create_llm_artifact(
+            artifacts.append(
+                await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                     data=metadata.rendered_response_json.encode("utf-8"),
                     artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
                     step=step,
                 )
+            )

+        if artifacts:
+            await app.ARTIFACT_MANAGER.bulk_create_artifacts(artifacts)
+
         incremental_cost = metadata.llm_cost if metadata.llm_cost and metadata.llm_cost > 0 else None
         incremental_input_tokens = (
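Because step is always supplied on this ForgeAgent path, each prepare_llm_artifact call here returns a real request. In the handler code below, where the target entity may be absent, the collected list is typed list[BulkArtifactCreationRequest | None] and bulk_create_artifacts skips the None entries.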
@@ -35,6 +35,7 @@ from skyvern.forge.sdk.api.llm.models import (
 )
 from skyvern.forge.sdk.api.llm.ui_tars_response import UITarsResponse
 from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, llm_messages_builder_with_history, parse_api_response
+from skyvern.forge.sdk.artifact.manager import BulkArtifactCreationRequest
 from skyvern.forge.sdk.artifact.models import ArtifactType
 from skyvern.forge.sdk.core import skyvern_context
 from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
@@ -93,6 +94,7 @@ def _get_artifact_targets_and_persist_flag(


 async def _log_hashed_href_map_artifacts_if_needed(
+    artifacts: list[BulkArtifactCreationRequest | None],
     context: SkyvernContext | None,
     step: Step | None,
     task_v2: TaskV2 | None,
@@ -105,11 +107,13 @@ async def _log_hashed_href_map_artifacts_if_needed(
         step, is_speculative_step, task_v2, thought, ai_suggestion
     )
     if context and context.hashed_href_map and should_persist_llm_artifacts:
-        await app.ARTIFACT_MANAGER.create_llm_artifact(
+        artifacts.append(
+            await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                 data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"),
                 artifact_type=ArtifactType.HASHED_HREF_MAP,
                 **artifact_targets,
             )
+        )


 def _log_vertex_cache_hit_if_needed(
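_log_hashed_href_map_artifacts_if_needed no longer writes to storage itself: it appends the prepared request to the caller-owned artifacts list passed as its new first parameter, leaving the caller to decide when the batch is flushed.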
@@ -446,7 +450,11 @@ class LLMAPIHandlerFactory:
         should_persist_llm_artifacts, artifact_targets = _get_artifact_targets_and_persist_flag(
             step, is_speculative_step, task_v2, thought, ai_suggestion
         )

+        artifacts: list[BulkArtifactCreationRequest | None] = []
+        try:
             await _log_hashed_href_map_artifacts_if_needed(
+                artifacts,
                 context,
                 step,
                 task_v2,
@@ -457,12 +465,14 @@ class LLMAPIHandlerFactory:

             llm_prompt_value = prompt
             if should_persist_llm_artifacts:
-                await app.ARTIFACT_MANAGER.create_llm_artifact(
+                artifacts.append(
+                    await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                         data=llm_prompt_value.encode("utf-8"),
                         artifact_type=ArtifactType.LLM_PROMPT,
                         screenshots=screenshots,
                         **artifact_targets,
                     )
+                )
             # Build messages and apply caching in one step
             messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
@@ -475,11 +485,13 @@ class LLMAPIHandlerFactory:
             }
             llm_request_json = json.dumps(llm_request_payload)
             if should_persist_llm_artifacts:
-                await app.ARTIFACT_MANAGER.create_llm_artifact(
+                artifacts.append(
+                    await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                         data=llm_request_json.encode("utf-8"),
                         artifact_type=ArtifactType.LLM_REQUEST,
                         **artifact_targets,
                     )
+                )
             return llm_request_json

             # Inject context caching system message when available
@@ -661,11 +673,13 @@ class LLMAPIHandlerFactory:

             llm_response_json = response.model_dump_json(indent=2)
             if should_persist_llm_artifacts:
-                await app.ARTIFACT_MANAGER.create_llm_artifact(
+                artifacts.append(
+                    await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                         data=llm_response_json.encode("utf-8"),
                         artifact_type=ArtifactType.LLM_RESPONSE,
                         **artifact_targets,
                     )
+                )
             prompt_tokens = 0
             completion_tokens = 0
             reasoning_tokens = 0
@@ -724,11 +738,13 @@ class LLMAPIHandlerFactory:
             parsed_response = parse_api_response(response, llm_config.add_assistant_prefix, force_dict)
             parsed_response_json = json.dumps(parsed_response, indent=2)
             if should_persist_llm_artifacts:
-                await app.ARTIFACT_MANAGER.create_llm_artifact(
+                artifacts.append(
+                    await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                         data=parsed_response_json.encode("utf-8"),
                         artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
                         **artifact_targets,
                     )
+                )

             rendered_response_json = None
             if context and len(context.hashed_href_map) > 0:
@@ -737,11 +753,13 @@ class LLMAPIHandlerFactory:
                 parsed_response = json.loads(rendered_content)
                 rendered_response_json = json.dumps(parsed_response, indent=2)
                 if should_persist_llm_artifacts:
-                    await app.ARTIFACT_MANAGER.create_llm_artifact(
+                    artifacts.append(
+                        await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                             data=rendered_response_json.encode("utf-8"),
                             artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
                             **artifact_targets,
                         )
+                    )

             # Track LLM API handler duration, token counts, and cost
             organization_id = organization_id or (
@@ -782,6 +800,11 @@ class LLMAPIHandlerFactory:
             )

             return parsed_response
+        finally:
+            try:
+                await app.ARTIFACT_MANAGER.bulk_create_artifacts(artifacts)
+            except Exception:
+                LOG.error("Failed to persist artifacts", exc_info=True)

     llm_api_handler_with_router_and_fallback.llm_key = llm_key  # type: ignore[attr-defined]
     return llm_api_handler_with_router_and_fallback
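The flush is placed in a finally block, so artifacts collected before an exception are still persisted, and a failing flush is logged rather than masking the handler's own result or error. A control-flow sketch follows; the function name is hypothetical, and LOG and app stand for the module-level logger and app singleton used throughout this file.

# Control-flow sketch only: the batch is flushed exactly once, whether the
# handler returns normally or raises.
async def handler_skeleton() -> dict:
    artifacts: list = []
    try:
        ...  # append prepared artifact requests while making the LLM call
        return {}
    finally:
        try:
            # Runs on both return and raise; an empty batch is a cheap no-op.
            await app.ARTIFACT_MANAGER.bulk_create_artifacts(artifacts)
        except Exception:
            LOG.error("Failed to persist artifacts", exc_info=True)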
@@ -855,7 +878,11 @@ class LLMAPIHandlerFactory:
         should_persist_llm_artifacts, artifact_targets = _get_artifact_targets_and_persist_flag(
             step, is_speculative_step, task_v2, thought, ai_suggestion
         )

+        artifacts: list[BulkArtifactCreationRequest | None] = []
+        try:
             await _log_hashed_href_map_artifacts_if_needed(
+                artifacts,
                 context,
                 step,
                 task_v2,
@@ -866,12 +893,14 @@ class LLMAPIHandlerFactory:

             llm_prompt_value = prompt
             if should_persist_llm_artifacts:
-                await app.ARTIFACT_MANAGER.create_llm_artifact(
+                artifacts.append(
+                    await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                         data=llm_prompt_value.encode("utf-8"),
                         artifact_type=ArtifactType.LLM_PROMPT,
                         screenshots=screenshots,
                         **artifact_targets,
                     )
+                )

             if not llm_config.supports_vision:
                 screenshots = None
@@ -957,11 +986,13 @@ class LLMAPIHandlerFactory:
             }
             llm_request_json = json.dumps(llm_request_payload)
             if should_persist_llm_artifacts:
-                await app.ARTIFACT_MANAGER.create_llm_artifact(
+                artifacts.append(
+                    await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                         data=llm_request_json.encode("utf-8"),
                         artifact_type=ArtifactType.LLM_REQUEST,
                         **artifact_targets,
                     )
+                )

             # Strip static prompt from the request messages because it's already in the cache
             # Sending it again causes double-billing (once cached, once uncached)
@@ -1034,11 +1065,13 @@ class LLMAPIHandlerFactory:

             llm_response_json = response.model_dump_json(indent=2)
             if should_persist_llm_artifacts:
-                await app.ARTIFACT_MANAGER.create_llm_artifact(
+                artifacts.append(
+                    await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                         data=llm_response_json.encode("utf-8"),
                         artifact_type=ArtifactType.LLM_RESPONSE,
                         **artifact_targets,
                     )
+                )

             prompt_tokens = 0
             completion_tokens = 0
@@ -1101,11 +1134,13 @@ class LLMAPIHandlerFactory:
             parsed_response = parse_api_response(response, llm_config.add_assistant_prefix, force_dict)
             parsed_response_json = json.dumps(parsed_response, indent=2)
             if should_persist_llm_artifacts:
-                await app.ARTIFACT_MANAGER.create_llm_artifact(
+                artifacts.append(
+                    await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                         data=parsed_response_json.encode("utf-8"),
                         artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
                         **artifact_targets,
                     )
+                )

             rendered_response_json = None
             if context and len(context.hashed_href_map) > 0:
@@ -1114,11 +1149,13 @@ class LLMAPIHandlerFactory:
                 parsed_response = json.loads(rendered_content)
                 rendered_response_json = json.dumps(parsed_response, indent=2)
                 if should_persist_llm_artifacts:
-                    await app.ARTIFACT_MANAGER.create_llm_artifact(
+                    artifacts.append(
+                        await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                             data=rendered_response_json.encode("utf-8"),
                             artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
                             **artifact_targets,
                         )
+                    )

             # Track LLM API handler duration, token counts, and cost
             organization_id = organization_id or (
@@ -1159,6 +1196,11 @@ class LLMAPIHandlerFactory:
             )

             return parsed_response
+        finally:
+            try:
+                await app.ARTIFACT_MANAGER.bulk_create_artifacts(artifacts)
+            except Exception:
+                LOG.error("Failed to persist artifacts", exc_info=True)

     llm_api_handler.llm_key = llm_key  # type: ignore[attr-defined]
     return llm_api_handler
@@ -1278,7 +1320,11 @@ class LLMCaller:
         should_persist_llm_artifacts, artifact_targets = _get_artifact_targets_and_persist_flag(
             step, is_speculative_step, task_v2, thought, ai_suggestion
         )

+        artifacts: list[BulkArtifactCreationRequest | None] = []
+        try:
             await _log_hashed_href_map_artifacts_if_needed(
+                artifacts,
                 context,
                 step,
                 task_v2,
@@ -1306,12 +1352,14 @@ class LLMCaller:

             llm_prompt_value = prompt or ""
             if prompt and should_persist_llm_artifacts:
-                await app.ARTIFACT_MANAGER.create_llm_artifact(
+                artifacts.append(
+                    await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                         data=prompt.encode("utf-8"),
                         artifact_type=ArtifactType.LLM_PROMPT,
                         screenshots=screenshots,
                         **artifact_targets,
                     )
+                )

             if not self.llm_config.supports_vision:
                 screenshots = None
@@ -1342,11 +1390,14 @@ class LLMCaller:
             }
             llm_request_json = json.dumps(llm_request_payload)
             if should_persist_llm_artifacts:
-                await app.ARTIFACT_MANAGER.create_llm_artifact(
+                artifacts.append(
+                    await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                         data=llm_request_json.encode("utf-8"),
                         artifact_type=ArtifactType.LLM_REQUEST,
                         **artifact_targets,
                     )
+                )

             t_llm_request = time.perf_counter()
             try:
                 response = await self._dispatch_llm_call(
@@ -1393,11 +1444,13 @@ class LLMCaller:

             llm_response_json = response.model_dump_json(indent=2)
             if should_persist_llm_artifacts:
-                await app.ARTIFACT_MANAGER.create_llm_artifact(
+                artifacts.append(
+                    await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                         data=llm_response_json.encode("utf-8"),
                         artifact_type=ArtifactType.LLM_RESPONSE,
                         **artifact_targets,
                     )
+                )

             call_stats = await self.get_call_stats(response)
             if step and not is_speculative_step:
@@ -1450,11 +1503,13 @@ class LLMCaller:
             parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix, force_dict)
             parsed_response_json = json.dumps(parsed_response, indent=2)
             if should_persist_llm_artifacts:
-                await app.ARTIFACT_MANAGER.create_llm_artifact(
+                artifacts.append(
+                    await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                         data=parsed_response_json.encode("utf-8"),
                         artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
                         **artifact_targets,
                     )
+                )

             rendered_response_json = None
             if context and len(context.hashed_href_map) > 0:
@@ -1463,11 +1518,13 @@ class LLMCaller:
                 parsed_response = json.loads(rendered_content)
                 rendered_response_json = json.dumps(parsed_response, indent=2)
                 if should_persist_llm_artifacts:
-                    await app.ARTIFACT_MANAGER.create_llm_artifact(
+                    artifacts.append(
+                        await app.ARTIFACT_MANAGER.prepare_llm_artifact(
                             data=rendered_response_json.encode("utf-8"),
                             artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
                             **artifact_targets,
                         )
+                    )

             if step and is_speculative_step:
                 step.speculative_llm_metadata = SpeculativeLLMMetadata(
@@ -1487,6 +1544,11 @@ class LLMCaller:
             )

             return parsed_response
+        finally:
+            try:
+                await app.ARTIFACT_MANAGER.bulk_create_artifacts(artifacts)
+            except Exception:
+                LOG.error("Failed to persist artifacts", exc_info=True)

     def get_screenshot_resize_target_dimension(self, window_dimension: Resolution | None) -> Resolution:
         if window_dimension and window_dimension != self.browser_window_dimension:
@@ -367,6 +367,24 @@ class ArtifactManager:
             data=data,
         )

+    async def bulk_create_artifacts(
+        self,
+        requests: list[BulkArtifactCreationRequest | None],
+    ) -> list[str]:
+        artifacts: list[ArtifactBatchData] = []
+        primary_key: str | None = None
+        for request in requests:
+            if request:
+                artifacts.extend(request.artifacts)
+                primary_key = request.primary_key
+
+        if primary_key is None or not artifacts:
+            return []
+
+        return await self._bulk_create_artifacts(
+            BulkArtifactCreationRequest(artifacts=artifacts, primary_key=primary_key)
+        )
+
     async def _bulk_create_artifacts(
         self,
         request: BulkArtifactCreationRequest,
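Merge semantics worth noting: every non-None request is concatenated into one BulkArtifactCreationRequest, the primary key is taken from the last non-None request in the list, and an empty or all-None input short-circuits to [] without a storage call.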
@@ -636,7 +654,7 @@ class ArtifactManager:

         return BulkArtifactCreationRequest(artifacts=artifacts, primary_key=ai_suggestion.ai_suggestion_id)

-    async def create_llm_artifact(
+    async def prepare_llm_artifact(
         self,
         data: bytes,
         artifact_type: ArtifactType,
@@ -645,54 +663,40 @@ class ArtifactManager:
         thought: Thought | None = None,
         task_v2: TaskV2 | None = None,
         ai_suggestion: AISuggestion | None = None,
-    ) -> None:
-        """
-        Create LLM artifact with optional screenshots using bulk insert.
-
-        Args:
-            data: Main artifact data
-            artifact_type: Type of the main artifact
-            screenshots: Optional list of screenshot data
-            step: Optional Step entity
-            thought: Optional Thought entity
-            task_v2: Optional TaskV2 entity
-            ai_suggestion: Optional AISuggestion entity
-        """
+    ) -> BulkArtifactCreationRequest | None:
         if step:
-            request = self._prepare_step_artifacts(
+            return self._prepare_step_artifacts(
                 step=step,
                 artifact_type=artifact_type,
                 data=data,
                 screenshots=screenshots,
             )
-            await self._bulk_create_artifacts(request)
-
         elif task_v2:
-            request = self._prepare_task_v2_artifacts(
+            return self._prepare_task_v2_artifacts(
                 task_v2=task_v2,
                 artifact_type=artifact_type,
                 data=data,
                 screenshots=screenshots,
             )
-            await self._bulk_create_artifacts(request)
-
         elif thought:
-            request = self._prepare_thought_artifacts(
+            return self._prepare_thought_artifacts(
                 thought=thought,
                 artifact_type=artifact_type,
                 data=data,
                 screenshots=screenshots,
             )
-            await self._bulk_create_artifacts(request)
-
         elif ai_suggestion:
-            request = self._prepare_ai_suggestion_artifacts(
+            return self._prepare_ai_suggestion_artifacts(
                 ai_suggestion=ai_suggestion,
                 artifact_type=artifact_type,
                 data=data,
                 screenshots=screenshots,
             )
-            await self._bulk_create_artifacts(request)
+        else:
+            return None

     async def update_artifact_data(
         self,
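prepare_llm_artifact keeps the old create_llm_artifact signature but returns the prepared BulkArtifactCreationRequest, or None when no step, task_v2, thought, or ai_suggestion target is supplied, instead of writing immediately; the per-call await self._bulk_create_artifacts(request) writes move to the batched flush at the call sites.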