update organization API (#480)
This commit is contained in:
@@ -15,6 +15,7 @@ from skyvern.exceptions import SkyvernHTTPException
|
|||||||
from skyvern.forge import app as forge_app
|
from skyvern.forge import app as forge_app
|
||||||
from skyvern.forge.sdk.core import skyvern_context
|
from skyvern.forge.sdk.core import skyvern_context
|
||||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||||
|
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||||
from skyvern.forge.sdk.routes.agent_protocol import base_router
|
from skyvern.forge.sdk.routes.agent_protocol import base_router
|
||||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||||
from skyvern.scheduler import SCHEDULER
|
from skyvern.scheduler import SCHEDULER
|
||||||
@@ -75,6 +76,10 @@ def get_agent_app(router: APIRouter = base_router) -> FastAPI:
|
|||||||
|
|
||||||
LOG.info("Server startup complete. Skyvern is now online")
|
LOG.info("Server startup complete. Skyvern is now online")
|
||||||
|
|
||||||
|
@app.exception_handler(NotFoundError)
|
||||||
|
async def handle_not_found_error(request: Request, exc: NotFoundError) -> Response:
|
||||||
|
return Response(status_code=status.HTTP_404_NOT_FOUND)
|
||||||
|
|
||||||
@app.exception_handler(SkyvernHTTPException)
|
@app.exception_handler(SkyvernHTTPException)
|
||||||
async def handle_skyvern_http_exception(request: Request, exc: SkyvernHTTPException) -> JSONResponse:
|
async def handle_skyvern_http_exception(request: Request, exc: SkyvernHTTPException) -> JSONResponse:
|
||||||
return JSONResponse(status_code=exc.status_code, content={"detail": exc.message})
|
return JSONResponse(status_code=exc.status_code, content={"detail": exc.message})
|
||||||
|
|||||||
@@ -485,6 +485,32 @@ class AgentDB:
|
|||||||
|
|
||||||
return convert_to_organization(org)
|
return convert_to_organization(org)
|
||||||
|
|
||||||
|
async def update_organization(
|
||||||
|
self,
|
||||||
|
organization_id: str,
|
||||||
|
organization_name: str | None = None,
|
||||||
|
webhook_callback_url: str | None = None,
|
||||||
|
max_steps_per_run: int | None = None,
|
||||||
|
max_retries_per_step: int | None = None,
|
||||||
|
) -> Organization:
|
||||||
|
async with self.Session() as session:
|
||||||
|
organization = (
|
||||||
|
await session.scalars(select(OrganizationModel).filter_by(organization_id=organization_id))
|
||||||
|
).first()
|
||||||
|
if not organization:
|
||||||
|
raise NotFoundError
|
||||||
|
if organization_name:
|
||||||
|
organization.organization_name = organization_name
|
||||||
|
if webhook_callback_url:
|
||||||
|
organization.webhook_callback_url = webhook_callback_url
|
||||||
|
if max_steps_per_run:
|
||||||
|
organization.max_steps_per_run = max_steps_per_run
|
||||||
|
if max_retries_per_step:
|
||||||
|
organization.max_retries_per_step = max_retries_per_step
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(organization)
|
||||||
|
return Organization.model_validate(organization)
|
||||||
|
|
||||||
async def get_valid_org_auth_token(
|
async def get_valid_org_auth_token(
|
||||||
self,
|
self,
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ class OrganizationModel(Base):
|
|||||||
modified_at = Column(
|
modified_at = Column(
|
||||||
DateTime,
|
DateTime,
|
||||||
default=datetime.datetime.utcnow,
|
default=datetime.datetime.utcnow,
|
||||||
onupdate=datetime.datetime,
|
onupdate=datetime.datetime.utcnow,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -133,7 +133,7 @@ class OrganizationAuthTokenModel(Base):
|
|||||||
modified_at = Column(
|
modified_at = Column(
|
||||||
DateTime,
|
DateTime,
|
||||||
default=datetime.datetime.utcnow,
|
default=datetime.datetime.utcnow,
|
||||||
onupdate=datetime.datetime,
|
onupdate=datetime.datetime.utcnow,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
deleted_at = Column(DateTime, nullable=True)
|
deleted_at = Column(DateTime, nullable=True)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||||
from skyvern.webeye.actions.actions import ActionType
|
from skyvern.webeye.actions.actions import ActionType
|
||||||
@@ -118,6 +118,8 @@ class Step(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Organization(BaseModel):
|
class Organization(BaseModel):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
organization_id: str
|
organization_id: str
|
||||||
organization_name: str
|
organization_name: str
|
||||||
webhook_callback_url: str | None = None
|
webhook_callback_url: str | None = None
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from skyvern.forge.sdk.core.permissions.permission_checker_factory import Permis
|
|||||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||||
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
|
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
|
||||||
from skyvern.forge.sdk.models import Organization, Step
|
from skyvern.forge.sdk.models import Organization, Step
|
||||||
|
from skyvern.forge.sdk.schemas.organizations import OrganizationUpdate
|
||||||
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase
|
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase
|
||||||
from skyvern.forge.sdk.schemas.tasks import (
|
from skyvern.forge.sdk.schemas.tasks import (
|
||||||
CreateTaskResponse,
|
CreateTaskResponse,
|
||||||
@@ -693,3 +694,18 @@ async def generate_task(
|
|||||||
except LLMProviderError:
|
except LLMProviderError:
|
||||||
LOG.error("Failed to generate task", exc_info=True)
|
LOG.error("Failed to generate task", exc_info=True)
|
||||||
raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.")
|
raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.")
|
||||||
|
|
||||||
|
|
||||||
|
@base_router.put("/organizations/", include_in_schema=False)
|
||||||
|
@base_router.put("/organizations")
|
||||||
|
async def update_organization(
|
||||||
|
org_update: OrganizationUpdate,
|
||||||
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
|
) -> Organization:
|
||||||
|
return await app.DATABASE.update_organization(
|
||||||
|
current_org.organization_id,
|
||||||
|
organization_name=org_update.organization_name,
|
||||||
|
webhook_callback_url=org_update.webhook_callback_url,
|
||||||
|
max_steps_per_run=org_update.max_steps_per_run,
|
||||||
|
max_retries_per_step=org_update.max_retries_per_step,
|
||||||
|
)
|
||||||
|
|||||||
8
skyvern/forge/sdk/schemas/organizations.py
Normal file
8
skyvern/forge/sdk/schemas/organizations.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationUpdate(BaseModel):
|
||||||
|
organization_name: str | None = None
|
||||||
|
webhook_callback_url: str | None = None
|
||||||
|
max_steps_per_run: int | None = None
|
||||||
|
max_retries_per_step: int | None = None
|
||||||
Reference in New Issue
Block a user