add api key expired message to the 403 when an api key is expired/invalid (#532)
This commit is contained in:
@@ -566,18 +566,19 @@ class AgentDB:
|
|||||||
organization_id: str,
|
organization_id: str,
|
||||||
token_type: OrganizationAuthTokenType,
|
token_type: OrganizationAuthTokenType,
|
||||||
token: str,
|
token: str,
|
||||||
|
valid: bool | None = True,
|
||||||
) -> OrganizationAuthToken | None:
|
) -> OrganizationAuthToken | None:
|
||||||
try:
|
try:
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
if token_obj := (
|
query = (
|
||||||
await session.scalars(
|
|
||||||
select(OrganizationAuthTokenModel)
|
select(OrganizationAuthTokenModel)
|
||||||
.filter_by(organization_id=organization_id)
|
.filter_by(organization_id=organization_id)
|
||||||
.filter_by(token_type=token_type)
|
.filter_by(token_type=token_type)
|
||||||
.filter_by(token=token)
|
.filter_by(token=token)
|
||||||
.filter_by(valid=True)
|
|
||||||
)
|
)
|
||||||
).first():
|
if valid is not None:
|
||||||
|
query = query.filter_by(valid=valid)
|
||||||
|
if token_obj := (await session.scalars(query)).first():
|
||||||
return convert_to_organization_auth_token(token_obj)
|
return convert_to_organization_auth_token(token_obj)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -109,6 +109,7 @@ async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
|
|||||||
organization_id=organization.organization_id,
|
organization_id=organization.organization_id,
|
||||||
token_type=OrganizationAuthTokenType.api,
|
token_type=OrganizationAuthTokenType.api,
|
||||||
token=x_api_key,
|
token=x_api_key,
|
||||||
|
valid=None,
|
||||||
)
|
)
|
||||||
if not api_key_db_obj:
|
if not api_key_db_obj:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -116,6 +117,12 @@ async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
|
|||||||
detail="Invalid credentials",
|
detail="Invalid credentials",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if api_key_db_obj.valid is False:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Your API key has expired. Please retrieve the latest one from https://app.skyvern.com/settings",
|
||||||
|
)
|
||||||
|
|
||||||
# set organization_id in skyvern context and log context
|
# set organization_id in skyvern context and log context
|
||||||
context = skyvern_context.current()
|
context = skyvern_context.current()
|
||||||
if context:
|
if context:
|
||||||
|
|||||||
Reference in New Issue
Block a user