Bitwarden Security Upgrade (#900)

This commit is contained in:
Kerem Yilmaz
2024-10-02 15:16:08 -07:00
committed by GitHub
parent 4f6feae03e
commit 36135a613b
9 changed files with 124 additions and 6 deletions

View File

@@ -0,0 +1,33 @@
"""Add bitwarden details to organizations
Revision ID: 6c90d565076b
Revises: c5848cc524b1
Create Date: 2024-10-02 22:12:34.959165+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "6c90d565076b"
down_revision: Union[str, None] = "c5848cc524b1"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("organizations", sa.Column("bw_organization_id", sa.String(), nullable=True))
op.add_column("organizations", sa.Column("bw_collection_ids", sa.JSON(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("organizations", "bw_collection_ids")
op.drop_column("organizations", "bw_organization_id")
# ### end Alembic commands ###

View File

@@ -278,6 +278,15 @@ class BitwardenSyncError(BitwardenBaseError):
super().__init__(f"Error syncing Bitwarden: {message}") super().__init__(f"Error syncing Bitwarden: {message}")
class BitwardenAccessDeniedError(BitwardenBaseError):
def __init__(self) -> None:
super().__init__(
"Current organization does not have access to the specified Bitwarden collection. \
Contact Skyvern support to enable access. This is a security layer on top of Bitwarden, \
Skyvern team needs to let your Skyvern account access the Bitwarden collection."
)
class UnknownElementTreeFormat(SkyvernException): class UnknownElementTreeFormat(SkyvernException):
def __init__(self, fmt: str) -> None: def __init__(self, fmt: str) -> None:
super().__init__(f"Unknown element tree format {fmt}") super().__init__(f"Unknown element tree format {fmt}")

View File

@@ -109,6 +109,8 @@ class OrganizationModel(Base):
max_steps_per_run = Column(Integer, nullable=True) max_steps_per_run = Column(Integer, nullable=True)
max_retries_per_step = Column(Integer, nullable=True) max_retries_per_step = Column(Integer, nullable=True)
domain = Column(String, nullable=True, index=True) domain = Column(String, nullable=True, index=True)
bw_organization_id = Column(String, nullable=True, default=None)
bw_collection_ids = Column(JSON, nullable=True, default=None)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column( modified_at = Column(
DateTime, DateTime,

View File

@@ -111,6 +111,8 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization:
max_steps_per_run=org_model.max_steps_per_run, max_steps_per_run=org_model.max_steps_per_run,
max_retries_per_step=org_model.max_retries_per_step, max_retries_per_step=org_model.max_retries_per_step,
domain=org_model.domain, domain=org_model.domain,
bw_organization_id=org_model.bw_organization_id,
bw_collection_ids=org_model.bw_collection_ids,
created_at=org_model.created_at, created_at=org_model.created_at,
modified_at=org_model.modified_at, modified_at=org_model.modified_at,
) )

View File

@@ -99,8 +99,14 @@ class BackgroundTaskExecutor(AsyncExecutor):
"Executing workflow using background task executor", "Executing workflow using background task executor",
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
) )
organization = await app.DATABASE.get_organization(organization_id)
if organization is None:
raise OrganizationNotFound(organization_id)
background_tasks.add_task( background_tasks.add_task(
app.WORKFLOW_SERVICE.execute_workflow, app.WORKFLOW_SERVICE.execute_workflow,
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
api_key=api_key, api_key=api_key,
organization=organization,
) )

View File

@@ -131,6 +131,8 @@ class Organization(BaseModel):
max_steps_per_run: int | None = None max_steps_per_run: int | None = None
max_retries_per_step: int | None = None max_retries_per_step: int | None = None
domain: str | None = None domain: str | None = None
bw_organization_id: str | None = None
bw_collection_ids: list[str] | None = None
created_at: datetime created_at: datetime
modified_at: datetime modified_at: datetime

View File

@@ -10,6 +10,7 @@ import tldextract
from skyvern.config import settings from skyvern.config import settings
from skyvern.exceptions import ( from skyvern.exceptions import (
BitwardenAccessDeniedError,
BitwardenListItemsError, BitwardenListItemsError,
BitwardenLoginError, BitwardenLoginError,
BitwardenLogoutError, BitwardenLogoutError,
@@ -29,6 +30,9 @@ def is_valid_email(email: str | None) -> bool:
class BitwardenConstants(StrEnum): class BitwardenConstants(StrEnum):
BW_ORGANIZATION_ID = "BW_ORGANIZATION_ID"
BW_COLLECTION_IDS = "BW_COLLECTION_IDS"
CLIENT_ID = "BW_CLIENT_ID" CLIENT_ID = "BW_CLIENT_ID"
CLIENT_SECRET = "BW_CLIENT_SECRET" CLIENT_SECRET = "BW_CLIENT_SECRET"
MASTER_PASSWORD = "BW_MASTER_PASSWORD" MASTER_PASSWORD = "BW_MASTER_PASSWORD"
@@ -79,6 +83,8 @@ class BitwardenService:
client_id: str, client_id: str,
client_secret: str, client_secret: str,
master_password: str, master_password: str,
bw_organization_id: str | None,
bw_collection_ids: list[str] | None,
url: str, url: str,
collection_id: str | None = None, collection_id: str | None = None,
remaining_retries: int = settings.BITWARDEN_MAX_RETRIES, remaining_retries: int = settings.BITWARDEN_MAX_RETRIES,
@@ -94,6 +100,8 @@ class BitwardenService:
client_id=client_id, client_id=client_id,
client_secret=client_secret, client_secret=client_secret,
master_password=master_password, master_password=master_password,
bw_organization_id=bw_organization_id,
bw_collection_ids=bw_collection_ids,
url=url, url=url,
collection_id=collection_id, collection_id=collection_id,
) )
@@ -109,6 +117,8 @@ class BitwardenService:
client_id=client_id, client_id=client_id,
client_secret=client_secret, client_secret=client_secret,
master_password=master_password, master_password=master_password,
bw_organization_id=bw_organization_id,
bw_collection_ids=bw_collection_ids,
url=url, url=url,
collection_id=collection_id, collection_id=collection_id,
remaining_retries=remaining_retries, remaining_retries=remaining_retries,
@@ -122,12 +132,16 @@ class BitwardenService:
client_id: str, client_id: str,
client_secret: str, client_secret: str,
master_password: str, master_password: str,
bw_organization_id: str | None,
bw_collection_ids: list[str] | None,
url: str, url: str,
collection_id: str | None = None, collection_id: str | None = None,
) -> dict[str, str]: ) -> dict[str, str]:
""" """
Get the secret value from the Bitwarden CLI. Get the secret value from the Bitwarden CLI.
""" """
if not bw_organization_id and bw_collection_ids and collection_id not in bw_collection_ids:
raise BitwardenAccessDeniedError()
try: try:
BitwardenService.login(client_id, client_secret) BitwardenService.login(client_id, client_secret)
BitwardenService.sync() BitwardenService.sync()
@@ -144,7 +158,13 @@ class BitwardenService:
"--session", "--session",
session_key, session_key,
] ]
if collection_id: if bw_organization_id:
LOG.info(
"Organization ID is provided, filtering items by organization ID",
bw_organization_id=bw_organization_id,
)
list_command.extend(["--organizationid", bw_organization_id])
elif collection_id:
LOG.info("Collection ID is provided, filtering items by collection ID", collection_id=collection_id) LOG.info("Collection ID is provided, filtering items by collection ID", collection_id=collection_id)
list_command.extend(["--collectionid", collection_id]) list_command.extend(["--collectionid", collection_id])
items_result = BitwardenService.run_command(list_command) items_result = BitwardenService.run_command(list_command)
@@ -158,11 +178,26 @@ class BitwardenService:
except json.JSONDecodeError: except json.JSONDecodeError:
raise BitwardenListItemsError("Failed to parse items JSON. Output: " + items_result.stdout) raise BitwardenListItemsError("Failed to parse items JSON. Output: " + items_result.stdout)
# Since Bitwarden can't AND multiple filters, we only use organization id in the list command
# but we still need to filter the items by collection id here
if bw_organization_id and collection_id:
filtered_items = []
for item in items:
if "collectionIds" in item and collection_id in item["collectionIds"]:
filtered_items.append(item)
items = filtered_items
if not items: if not items:
collection_id_str = f" in collection with ID: {collection_id}" if collection_id else "" collection_id_str = f" in collection with ID: {collection_id}" if collection_id else ""
raise BitwardenListItemsError(f"No items found in Bitwarden for URL: {url}{collection_id_str}") raise BitwardenListItemsError(f"No items found in Bitwarden for URL: {url}{collection_id_str}")
# TODO (kerem): To make this more robust, we need to store the item id of the totp login item
# and use it here to get the TOTP code for that specific item
totp_command = ["bw", "get", "totp", url, "--session", session_key] totp_command = ["bw", "get", "totp", url, "--session", session_key]
if bw_organization_id:
# We need to add this filter because the TOTP command fails if there are multiple results
# For now, we require that the bitwarden organization id has only one totp login item for the domain
totp_command.extend(["--organizationid", bw_organization_id])
totp_result = BitwardenService.run_command(totp_command) totp_result = BitwardenService.run_command(totp_command)
if totp_result.stderr and "Event post failed" not in totp_result.stderr: if totp_result.stderr and "Event post failed" not in totp_result.stderr:
@@ -208,6 +243,8 @@ class BitwardenService:
client_id: str, client_id: str,
client_secret: str, client_secret: str,
master_password: str, master_password: str,
bw_organization_id: str | None,
bw_collection_ids: list[str] | None,
collection_id: str, collection_id: str,
identity_key: str, identity_key: str,
identity_fields: list[str], identity_fields: list[str],
@@ -224,6 +261,8 @@ class BitwardenService:
client_id=client_id, client_id=client_id,
client_secret=client_secret, client_secret=client_secret,
master_password=master_password, master_password=master_password,
bw_organization_id=bw_organization_id,
bw_collection_ids=bw_collection_ids,
collection_id=collection_id, collection_id=collection_id,
identity_key=identity_key, identity_key=identity_key,
identity_fields=identity_fields, identity_fields=identity_fields,
@@ -240,6 +279,8 @@ class BitwardenService:
client_id=client_id, client_id=client_id,
client_secret=client_secret, client_secret=client_secret,
master_password=master_password, master_password=master_password,
bw_organization_id=bw_organization_id,
bw_collection_ids=bw_collection_ids,
collection_id=collection_id, collection_id=collection_id,
identity_key=identity_key, identity_key=identity_key,
identity_fields=identity_fields, identity_fields=identity_fields,
@@ -257,10 +298,14 @@ class BitwardenService:
collection_id: str, collection_id: str,
identity_key: str, identity_key: str,
identity_fields: list[str], identity_fields: list[str],
bw_organization_id: str | None,
bw_collection_ids: list[str] | None,
) -> dict[str, str]: ) -> dict[str, str]:
""" """
Get the sensitive information from the Bitwarden CLI. Get the sensitive information from the Bitwarden CLI.
""" """
if not bw_organization_id and bw_collection_ids and collection_id not in bw_collection_ids:
raise BitwardenAccessDeniedError()
try: try:
BitwardenService.login(client_id, client_secret) BitwardenService.login(client_id, client_secret)
BitwardenService.sync() BitwardenService.sync()
@@ -278,6 +323,8 @@ class BitwardenService:
"--collectionid", "--collectionid",
collection_id, collection_id,
] ]
if bw_organization_id:
list_command.extend(["--organizationid", bw_organization_id])
items_result = BitwardenService.run_command(list_command) items_result = BitwardenService.run_command(list_command)
# Parse the items and extract sensitive information # Parse the items and extract sensitive information

View File

@@ -5,6 +5,7 @@ import structlog
from skyvern.exceptions import BitwardenBaseError, WorkflowRunContextNotInitialized from skyvern.exceptions import BitwardenBaseError, WorkflowRunContextNotInitialized
from skyvern.forge.sdk.api.aws import AsyncAWSClient from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.forge.sdk.models import Organization
from skyvern.forge.sdk.services.bitwarden import BitwardenConstants, BitwardenService from skyvern.forge.sdk.services.bitwarden import BitwardenConstants, BitwardenService
from skyvern.forge.sdk.workflow.exceptions import OutputParameterKeyCollisionError from skyvern.forge.sdk.workflow.exceptions import OutputParameterKeyCollisionError
from skyvern.forge.sdk.workflow.models.parameter import ( from skyvern.forge.sdk.workflow.models.parameter import (
@@ -106,6 +107,8 @@ class WorkflowRunContext:
client_secret=self.secrets[BitwardenConstants.CLIENT_SECRET], client_secret=self.secrets[BitwardenConstants.CLIENT_SECRET],
client_id=self.secrets[BitwardenConstants.CLIENT_ID], client_id=self.secrets[BitwardenConstants.CLIENT_ID],
master_password=self.secrets[BitwardenConstants.MASTER_PASSWORD], master_password=self.secrets[BitwardenConstants.MASTER_PASSWORD],
bw_organization_id=self.secrets[BitwardenConstants.BW_ORGANIZATION_ID],
bw_collection_ids=self.secrets[BitwardenConstants.BW_COLLECTION_IDS],
) )
return secret_credentials return secret_credentials
@@ -117,6 +120,7 @@ class WorkflowRunContext:
self, self,
aws_client: AsyncAWSClient, aws_client: AsyncAWSClient,
parameter: PARAMETER_TYPE, parameter: PARAMETER_TYPE,
organization: Organization,
) -> None: ) -> None:
if parameter.parameter_type == ParameterType.WORKFLOW: if parameter.parameter_type == ParameterType.WORKFLOW:
LOG.error(f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}") LOG.error(f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}")
@@ -165,10 +169,14 @@ class WorkflowRunContext:
client_id, client_id,
client_secret, client_secret,
master_password, master_password,
organization.bw_organization_id,
organization.bw_collection_ids,
url, url,
collection_id=collection_id, collection_id=collection_id,
) )
if secret_credentials: if secret_credentials:
self.secrets[BitwardenConstants.BW_ORGANIZATION_ID] = organization.bw_organization_id
self.secrets[BitwardenConstants.BW_COLLECTION_IDS] = organization.bw_collection_ids
self.secrets[BitwardenConstants.URL] = url self.secrets[BitwardenConstants.URL] = url
self.secrets[BitwardenConstants.CLIENT_SECRET] = client_secret self.secrets[BitwardenConstants.CLIENT_SECRET] = client_secret
self.secrets[BitwardenConstants.CLIENT_ID] = client_id self.secrets[BitwardenConstants.CLIENT_ID] = client_id
@@ -223,11 +231,15 @@ class WorkflowRunContext:
client_id, client_id,
client_secret, client_secret,
master_password, master_password,
organization.bw_organization_id,
organization.bw_collection_ids,
collection_id, collection_id,
bitwarden_identity_key, bitwarden_identity_key,
parameter.bitwarden_identity_fields, parameter.bitwarden_identity_fields,
) )
if sensitive_values: if sensitive_values:
self.secrets[BitwardenConstants.BW_ORGANIZATION_ID] = organization.bw_organization_id
self.secrets[BitwardenConstants.BW_COLLECTION_IDS] = organization.bw_collection_ids
self.secrets[BitwardenConstants.IDENTITY_KEY] = bitwarden_identity_key self.secrets[BitwardenConstants.IDENTITY_KEY] = bitwarden_identity_key
self.secrets[BitwardenConstants.CLIENT_SECRET] = client_secret self.secrets[BitwardenConstants.CLIENT_SECRET] = client_secret
self.secrets[BitwardenConstants.CLIENT_ID] = client_id self.secrets[BitwardenConstants.CLIENT_ID] = client_id
@@ -333,6 +345,7 @@ class WorkflowRunContext:
self, self,
aws_client: AsyncAWSClient, aws_client: AsyncAWSClient,
parameters: list[PARAMETER_TYPE], parameters: list[PARAMETER_TYPE],
organization: Organization,
) -> None: ) -> None:
# Sort the parameters so that ContextParameter and BitwardenLoginCredentialParameter are processed last # Sort the parameters so that ContextParameter and BitwardenLoginCredentialParameter are processed last
# ContextParameter should be processed at the end since it requires the source parameter to be set # ContextParameter should be processed at the end since it requires the source parameter to be set
@@ -369,7 +382,7 @@ class WorkflowRunContext:
) )
self.parameters[parameter.key] = parameter self.parameters[parameter.key] = parameter
await self.register_parameter_value(aws_client, parameter) await self.register_parameter_value(aws_client, parameter, organization)
class WorkflowContextManager: class WorkflowContextManager:
@@ -410,6 +423,9 @@ class WorkflowContextManager:
self, self,
workflow_run_id: str, workflow_run_id: str,
parameters: list[PARAMETER_TYPE], parameters: list[PARAMETER_TYPE],
organization: Organization,
) -> None: ) -> None:
self._validate_workflow_run_context(workflow_run_id) self._validate_workflow_run_context(workflow_run_id)
await self.workflow_run_contexts[workflow_run_id].register_block_parameters(self.aws_client, parameters) await self.workflow_run_contexts[workflow_run_id].register_block_parameters(
self.aws_client, parameters, organization
)

View File

@@ -12,7 +12,7 @@ from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.models import Organization, Step
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
from skyvern.forge.sdk.workflow.exceptions import ( from skyvern.forge.sdk.workflow.exceptions import (
ContextParameterSourceNotDefined, ContextParameterSourceNotDefined,
@@ -150,9 +150,10 @@ class WorkflowService:
self, self,
workflow_run_id: str, workflow_run_id: str,
api_key: str, api_key: str,
organization_id: str | None = None, organization: Organization,
) -> WorkflowRun: ) -> WorkflowRun:
"""Execute a workflow.""" """Execute a workflow."""
organization_id = organization.organization_id
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id) workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id, organization_id=organization_id) workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id, organization_id=organization_id)
@@ -181,7 +182,7 @@ class WorkflowService:
try: try:
parameters = block.get_all_parameters(workflow_run_id) parameters = block.get_all_parameters(workflow_run_id)
await app.WORKFLOW_CONTEXT_MANAGER.register_block_parameters_for_workflow_run( await app.WORKFLOW_CONTEXT_MANAGER.register_block_parameters_for_workflow_run(
workflow_run_id, parameters workflow_run_id, parameters, organization
) )
LOG.info( LOG.info(
f"Executing root block {block.block_type} at index {block_idx} for workflow run {workflow_run_id}", f"Executing root block {block.block_type} at index {block_idx} for workflow run {workflow_run_id}",