Keep track of token counts in steps table (#412)
This commit is contained in:
@@ -115,11 +115,15 @@ class LLMAPIHandlerFactory:
|
|||||||
data=response.model_dump_json(indent=2).encode("utf-8"),
|
data=response.model_dump_json(indent=2).encode("utf-8"),
|
||||||
)
|
)
|
||||||
llm_cost = litellm.completion_cost(completion_response=response)
|
llm_cost = litellm.completion_cost(completion_response=response)
|
||||||
|
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
|
||||||
|
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
|
||||||
await app.DATABASE.update_step(
|
await app.DATABASE.update_step(
|
||||||
task_id=step.task_id,
|
task_id=step.task_id,
|
||||||
step_id=step.step_id,
|
step_id=step.step_id,
|
||||||
organization_id=step.organization_id,
|
organization_id=step.organization_id,
|
||||||
incremental_cost=llm_cost,
|
incremental_cost=llm_cost,
|
||||||
|
incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None,
|
||||||
|
incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
|
||||||
)
|
)
|
||||||
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
|
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
|
||||||
if step:
|
if step:
|
||||||
@@ -206,11 +210,15 @@ class LLMAPIHandlerFactory:
|
|||||||
data=response.model_dump_json(indent=2).encode("utf-8"),
|
data=response.model_dump_json(indent=2).encode("utf-8"),
|
||||||
)
|
)
|
||||||
llm_cost = litellm.completion_cost(completion_response=response)
|
llm_cost = litellm.completion_cost(completion_response=response)
|
||||||
|
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
|
||||||
|
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
|
||||||
await app.DATABASE.update_step(
|
await app.DATABASE.update_step(
|
||||||
task_id=step.task_id,
|
task_id=step.task_id,
|
||||||
step_id=step.step_id,
|
step_id=step.step_id,
|
||||||
organization_id=step.organization_id,
|
organization_id=step.organization_id,
|
||||||
incremental_cost=llm_cost,
|
incremental_cost=llm_cost,
|
||||||
|
incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None,
|
||||||
|
incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
|
||||||
)
|
)
|
||||||
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
|
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
|
||||||
if step:
|
if step:
|
||||||
|
|||||||
@@ -293,6 +293,8 @@ class AgentDB:
|
|||||||
retry_index: int | None = None,
|
retry_index: int | None = None,
|
||||||
organization_id: str | None = None,
|
organization_id: str | None = None,
|
||||||
incremental_cost: float | None = None,
|
incremental_cost: float | None = None,
|
||||||
|
incremental_input_tokens: int | None = None,
|
||||||
|
incremental_output_tokens: int | None = None,
|
||||||
) -> Step:
|
) -> Step:
|
||||||
try:
|
try:
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
@@ -314,6 +316,10 @@ class AgentDB:
|
|||||||
step.retry_index = retry_index
|
step.retry_index = retry_index
|
||||||
if incremental_cost is not None:
|
if incremental_cost is not None:
|
||||||
step.step_cost = incremental_cost + float(step.step_cost or 0)
|
step.step_cost = incremental_cost + float(step.step_cost or 0)
|
||||||
|
if incremental_input_tokens is not None:
|
||||||
|
step.input_token_count = incremental_input_tokens + (step.input_token_count or 0)
|
||||||
|
if incremental_output_tokens is not None:
|
||||||
|
step.output_token_count = incremental_output_tokens + (step.output_token_count or 0)
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
updated_step = await self.get_step(task_id, step_id, organization_id)
|
updated_step = await self.get_step(task_id, step_id, organization_id)
|
||||||
|
|||||||
Reference in New Issue
Block a user