Common: Encapsulate MarshmallowError

This commit is contained in:
Mike Salvatore 2022-06-24 13:21:39 -04:00
parent 94524d124c
commit dbd0d3e0dd
3 changed files with 44 additions and 9 deletions

View File

@ -4,6 +4,7 @@ from dataclasses import dataclass
from typing import List from typing import List
from marshmallow import Schema, fields from marshmallow import Schema, fields
from marshmallow.exceptions import MarshmallowError
from .agent_sub_configuration_schemas import ( from .agent_sub_configuration_schemas import (
CustomPBAConfigurationSchema, CustomPBAConfigurationSchema,
@ -21,6 +22,11 @@ class InvalidConfigurationError(Exception):
pass pass
INVALID_CONFIGURATION_ERROR_MESSAGE = (
"Cannot construct an AgentConfiguration object with the supplied, invalid data:"
)
@dataclass(frozen=True) @dataclass(frozen=True)
class AgentConfiguration: class AgentConfiguration:
keep_tunnel_open_time: float keep_tunnel_open_time: float
@ -33,17 +39,26 @@ class AgentConfiguration:
def __post_init__(self): def __post_init__(self):
# This will raise an exception if the object is invalid. Calling this in __post__init() # This will raise an exception if the object is invalid. Calling this in __post__init()
# makes it impossible to construct an invalid object # makes it impossible to construct an invalid object
AgentConfigurationSchema().dump(self) try:
AgentConfigurationSchema().dump(self)
except Exception as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}")
@staticmethod @staticmethod
def from_dict(dict_: dict): def from_dict(dict_: dict):
config_dict = AgentConfigurationSchema().load(dict_) try:
return AgentConfiguration(**config_dict) config_dict = AgentConfigurationSchema().load(dict_)
return AgentConfiguration(**config_dict)
except MarshmallowError as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}")
@staticmethod @staticmethod
def from_json(config_json: dict): def from_json(config_json: dict):
config_dict = AgentConfigurationSchema().loads(config_json) try:
return AgentConfiguration(**config_dict) config_dict = AgentConfigurationSchema().loads(config_json)
return AgentConfiguration(**config_dict)
except MarshmallowError as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}")
@staticmethod @staticmethod
def to_json(config: AgentConfiguration) -> str: def to_json(config: AgentConfiguration) -> str:

View File

@ -1,9 +1,9 @@
import json import json
import marshmallow
from flask import make_response, request from flask import make_response, request
from common.configuration.agent_configuration import AgentConfiguration as AgentConfigurationObject from common.configuration.agent_configuration import AgentConfiguration as AgentConfigurationObject
from common.configuration.agent_configuration import InvalidConfigurationError
from monkey_island.cc.repository import IAgentConfigurationRepository from monkey_island.cc.repository import IAgentConfigurationRepository
from monkey_island.cc.resources.AbstractResource import AbstractResource from monkey_island.cc.resources.AbstractResource import AbstractResource
from monkey_island.cc.resources.request_authentication import jwt_required from monkey_island.cc.resources.request_authentication import jwt_required
@ -28,7 +28,7 @@ class AgentConfiguration(AbstractResource):
configuration_object = AgentConfigurationObject.from_json(request.data) configuration_object = AgentConfigurationObject.from_json(request.data)
self._agent_configuration_repository.store_configuration(configuration_object) self._agent_configuration_repository.store_configuration(configuration_object)
return make_response({}, 200) return make_response({}, 200)
except (marshmallow.exceptions.ValidationError, json.JSONDecodeError) as err: except (InvalidConfigurationError, json.JSONDecodeError) as err:
return make_response( return make_response(
{"message": f"Invalid configuration supplied: {err}"}, {"message": f"Invalid configuration supplied: {err}"},
400, 400,

View File

@ -26,7 +26,11 @@ from tests.common.example_agent_configuration import (
WINDOWS_FILENAME, WINDOWS_FILENAME,
) )
from common.configuration import DEFAULT_AGENT_CONFIGURATION_JSON, AgentConfiguration from common.configuration import (
DEFAULT_AGENT_CONFIGURATION_JSON,
AgentConfiguration,
InvalidConfigurationError,
)
from common.configuration.agent_configuration import AgentConfigurationSchema from common.configuration.agent_configuration import AgentConfigurationSchema
from common.configuration.agent_sub_configuration_schemas import ( from common.configuration.agent_sub_configuration_schemas import (
CustomPBAConfigurationSchema, CustomPBAConfigurationSchema,
@ -172,7 +176,7 @@ def test_agent_configuration():
def test_incorrect_type(): def test_incorrect_type():
valid_config = AgentConfiguration.from_dict(AGENT_CONFIGURATION) valid_config = AgentConfiguration.from_dict(AGENT_CONFIGURATION)
with pytest.raises(Exception): with pytest.raises(InvalidConfigurationError):
valid_config_dict = valid_config.__dict__ valid_config_dict = valid_config.__dict__
valid_config_dict["keep_tunnel_open_time"] = "not_a_float" valid_config_dict["keep_tunnel_open_time"] = "not_a_float"
AgentConfiguration(**valid_config_dict) AgentConfiguration(**valid_config_dict)
@ -193,6 +197,14 @@ def test_from_dict():
assert schema.dump(config) == dict_ assert schema.dump(config) == dict_
def test_from_dict__invalid_data():
dict_ = json.loads(DEFAULT_AGENT_CONFIGURATION_JSON)
dict_["payloads"] = "payloads"
with pytest.raises(InvalidConfigurationError):
AgentConfiguration.from_dict(dict_)
def test_from_json(): def test_from_json():
schema = AgentConfigurationSchema() schema = AgentConfigurationSchema()
dict_ = json.loads(DEFAULT_AGENT_CONFIGURATION_JSON) dict_ = json.loads(DEFAULT_AGENT_CONFIGURATION_JSON)
@ -202,6 +214,14 @@ def test_from_json():
assert schema.dump(config) == dict_ assert schema.dump(config) == dict_
def test_from_json__invalid_data():
invalid_dict = json.loads(DEFAULT_AGENT_CONFIGURATION_JSON)
invalid_dict["payloads"] = "payloads"
with pytest.raises(InvalidConfigurationError):
AgentConfiguration.from_json(json.dumps(invalid_dict))
def test_to_json(): def test_to_json():
config = AgentConfiguration.from_json(DEFAULT_AGENT_CONFIGURATION_JSON) config = AgentConfiguration.from_json(DEFAULT_AGENT_CONFIGURATION_JSON)