update organization API (#480)

This commit is contained in:
Kerem Yilmaz
2024-06-16 19:42:20 -07:00
committed by GitHub
parent af81fb7206
commit 10612f02fd
6 changed files with 60 additions and 3 deletions

View File

@@ -15,6 +15,7 @@ from skyvern.exceptions import SkyvernHTTPException
from skyvern.forge import app as forge_app from skyvern.forge import app as forge_app
from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.routes.agent_protocol import base_router from skyvern.forge.sdk.routes.agent_protocol import base_router
from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.scheduler import SCHEDULER from skyvern.scheduler import SCHEDULER
@@ -75,6 +76,10 @@ def get_agent_app(router: APIRouter = base_router) -> FastAPI:
LOG.info("Server startup complete. Skyvern is now online") LOG.info("Server startup complete. Skyvern is now online")
@app.exception_handler(NotFoundError)
async def handle_not_found_error(request: Request, exc: NotFoundError) -> Response:
return Response(status_code=status.HTTP_404_NOT_FOUND)
@app.exception_handler(SkyvernHTTPException) @app.exception_handler(SkyvernHTTPException)
async def handle_skyvern_http_exception(request: Request, exc: SkyvernHTTPException) -> JSONResponse: async def handle_skyvern_http_exception(request: Request, exc: SkyvernHTTPException) -> JSONResponse:
return JSONResponse(status_code=exc.status_code, content={"detail": exc.message}) return JSONResponse(status_code=exc.status_code, content={"detail": exc.message})

View File

@@ -485,6 +485,32 @@ class AgentDB:
return convert_to_organization(org) return convert_to_organization(org)
async def update_organization(
self,
organization_id: str,
organization_name: str | None = None,
webhook_callback_url: str | None = None,
max_steps_per_run: int | None = None,
max_retries_per_step: int | None = None,
) -> Organization:
async with self.Session() as session:
organization = (
await session.scalars(select(OrganizationModel).filter_by(organization_id=organization_id))
).first()
if not organization:
raise NotFoundError
if organization_name:
organization.organization_name = organization_name
if webhook_callback_url:
organization.webhook_callback_url = webhook_callback_url
if max_steps_per_run:
organization.max_steps_per_run = max_steps_per_run
if max_retries_per_step:
organization.max_retries_per_step = max_retries_per_step
await session.commit()
await session.refresh(organization)
return Organization.model_validate(organization)
async def get_valid_org_auth_token( async def get_valid_org_auth_token(
self, self,
organization_id: str, organization_id: str,

View File

@@ -109,7 +109,7 @@ class OrganizationModel(Base):
modified_at = Column( modified_at = Column(
DateTime, DateTime,
default=datetime.datetime.utcnow, default=datetime.datetime.utcnow,
onupdate=datetime.datetime, onupdate=datetime.datetime.utcnow,
nullable=False, nullable=False,
) )
@@ -133,7 +133,7 @@ class OrganizationAuthTokenModel(Base):
modified_at = Column( modified_at = Column(
DateTime, DateTime,
default=datetime.datetime.utcnow, default=datetime.datetime.utcnow,
onupdate=datetime.datetime, onupdate=datetime.datetime.utcnow,
nullable=False, nullable=False,
) )
deleted_at = Column(DateTime, nullable=True) deleted_at = Column(DateTime, nullable=True)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from datetime import datetime from datetime import datetime
from enum import StrEnum from enum import StrEnum
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.webeye.actions.actions import ActionType from skyvern.webeye.actions.actions import ActionType
@@ -118,6 +118,8 @@ class Step(BaseModel):
class Organization(BaseModel): class Organization(BaseModel):
model_config = ConfigDict(from_attributes=True)
organization_id: str organization_id: str
organization_name: str organization_name: str
webhook_callback_url: str | None = None webhook_callback_url: str | None = None

View File

@@ -17,6 +17,7 @@ from skyvern.forge.sdk.core.permissions.permission_checker_factory import Permis
from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
from skyvern.forge.sdk.models import Organization, Step from skyvern.forge.sdk.models import Organization, Step
from skyvern.forge.sdk.schemas.organizations import OrganizationUpdate
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase
from skyvern.forge.sdk.schemas.tasks import ( from skyvern.forge.sdk.schemas.tasks import (
CreateTaskResponse, CreateTaskResponse,
@@ -693,3 +694,18 @@ async def generate_task(
except LLMProviderError: except LLMProviderError:
LOG.error("Failed to generate task", exc_info=True) LOG.error("Failed to generate task", exc_info=True)
raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.") raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.")
@base_router.put("/organizations/", include_in_schema=False)
@base_router.put("/organizations")
async def update_organization(
org_update: OrganizationUpdate,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Organization:
return await app.DATABASE.update_organization(
current_org.organization_id,
organization_name=org_update.organization_name,
webhook_callback_url=org_update.webhook_callback_url,
max_steps_per_run=org_update.max_steps_per_run,
max_retries_per_step=org_update.max_retries_per_step,
)

View File

@@ -0,0 +1,8 @@
from pydantic import BaseModel
class OrganizationUpdate(BaseModel):
organization_name: str | None = None
webhook_callback_url: str | None = None
max_steps_per_run: int | None = None
max_retries_per_step: int | None = None