diff --git a/monkey/common/configuration/__init__.py b/monkey/common/configuration/__init__.py index 107e5a491..06ce30b50 100644 --- a/monkey/common/configuration/__init__.py +++ b/monkey/common/configuration/__init__.py @@ -1,7 +1,4 @@ -from .agent_configuration import ( - AgentConfiguration, - AgentConfigurationSchema, -) +from .agent_configuration import AgentConfiguration, InvalidConfigurationError from .agent_sub_configurations import ( CustomPBAConfiguration, PluginConfiguration, diff --git a/monkey/common/configuration/agent_configuration.py b/monkey/common/configuration/agent_configuration.py index a636e3d95..097b86382 100644 --- a/monkey/common/configuration/agent_configuration.py +++ b/monkey/common/configuration/agent_configuration.py @@ -1,7 +1,10 @@ -from dataclasses import dataclass -from typing import List +from __future__ import annotations -from marshmallow import Schema, fields, post_load +from dataclasses import dataclass +from typing import Any, List, Mapping + +from marshmallow import Schema, fields +from marshmallow.exceptions import MarshmallowError from .agent_sub_configuration_schemas import ( CustomPBAConfigurationSchema, @@ -15,6 +18,15 @@ from .agent_sub_configurations import ( ) +class InvalidConfigurationError(Exception): + pass + + +INVALID_CONFIGURATION_ERROR_MESSAGE = ( + "Cannot construct an AgentConfiguration object with the supplied, invalid data:" +) + + @dataclass(frozen=True) class AgentConfiguration: keep_tunnel_open_time: float @@ -24,9 +36,56 @@ class AgentConfiguration: payloads: List[PluginConfiguration] propagation: PropagationConfiguration + def __post_init__(self): + # This will raise an exception if the object is invalid. Calling this in __post__init() + # makes it impossible to construct an invalid object + try: + AgentConfigurationSchema().dump(self) + except Exception as err: + raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}") + @staticmethod - def from_dict(_dict: dict): - return AgentConfigurationSchema().load(_dict) + def from_mapping(config_mapping: Mapping[str, Any]) -> AgentConfiguration: + """ + Construct an AgentConfiguration from a Mapping + + :param config_mapping: A Mapping that represents an AgentConfiguration + :return: An AgentConfiguration + :raises: InvalidConfigurationError if the provided Mapping does not represent a valid + AgentConfiguration + """ + + try: + config_dict = AgentConfigurationSchema().load(config_mapping) + return AgentConfiguration(**config_dict) + except MarshmallowError as err: + raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}") + + @staticmethod + def from_json(config_json: str) -> AgentConfiguration: + """ + Construct an AgentConfiguration from a JSON string + + :param config_json: A JSON string that represents an AgentConfiguration + :return: An AgentConfiguration + :raises: InvalidConfigurationError if the provided JSON does not represent a valid + AgentConfiguration + """ + try: + config_dict = AgentConfigurationSchema().loads(config_json) + return AgentConfiguration(**config_dict) + except MarshmallowError as err: + raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}") + + @staticmethod + def to_json(config: AgentConfiguration) -> str: + """ + Serialize an AgentConfiguration to JSON + + :param config: An AgentConfiguration + :return: A JSON string representing the AgentConfiguration + """ + return AgentConfigurationSchema().dumps(config) class AgentConfigurationSchema(Schema): @@ -36,7 +95,3 @@ class AgentConfigurationSchema(Schema): credential_collectors = fields.List(fields.Nested(PluginConfigurationSchema)) payloads = fields.List(fields.Nested(PluginConfigurationSchema)) propagation = fields.Nested(PropagationConfigurationSchema) - - @post_load - def _make_agent_configuration(self, data, **kwargs): - return AgentConfiguration(**data) diff --git a/monkey/common/configuration/default_agent_configuration.py b/monkey/common/configuration/default_agent_configuration.py index c83169566..f3295a4b6 100644 --- a/monkey/common/configuration/default_agent_configuration.py +++ b/monkey/common/configuration/default_agent_configuration.py @@ -1,4 +1,4 @@ -from . import AgentConfiguration, AgentConfigurationSchema +from . import AgentConfiguration DEFAULT_AGENT_CONFIGURATION_JSON = """{ "keep_tunnel_open_time": 30, @@ -204,5 +204,4 @@ DEFAULT_AGENT_CONFIGURATION_JSON = """{ def build_default_agent_configuration() -> AgentConfiguration: - schema = AgentConfigurationSchema() - return schema.loads(DEFAULT_AGENT_CONFIGURATION_JSON) + return AgentConfiguration.from_json(DEFAULT_AGENT_CONFIGURATION_JSON) diff --git a/monkey/common/utils/exceptions.py b/monkey/common/utils/exceptions.py index 5935145e7..31cebca32 100644 --- a/monkey/common/utils/exceptions.py +++ b/monkey/common/utils/exceptions.py @@ -38,5 +38,7 @@ class DomainControllerNameFetchError(FailedExploitationError): """Raise on failed attempt to extract domain controller's name""" +# TODO: This has been replaced by common.configuration.InvalidConfigurationError. Use that error +# instead and remove this one. class InvalidConfigurationError(Exception): """Raise when configuration is invalid""" diff --git a/monkey/infection_monkey/master/control_channel.py b/monkey/infection_monkey/master/control_channel.py index 8567a301f..c83942f5d 100644 --- a/monkey/infection_monkey/master/control_channel.py +++ b/monkey/infection_monkey/master/control_channel.py @@ -58,7 +58,7 @@ class ControlChannel(IControlChannel): ) response.raise_for_status() - return AgentConfiguration.from_dict(json.loads(response.text)["config"]) + return AgentConfiguration.from_mapping(json.loads(response.text)["config"]) except ( json.JSONDecodeError, requests.exceptions.ConnectionError, diff --git a/monkey/monkey_island/cc/repository/file_agent_configuration_repository.py b/monkey/monkey_island/cc/repository/file_agent_configuration_repository.py index b63ee817c..312e3921e 100644 --- a/monkey/monkey_island/cc/repository/file_agent_configuration_repository.py +++ b/monkey/monkey_island/cc/repository/file_agent_configuration_repository.py @@ -1,6 +1,6 @@ import io -from common.configuration import AgentConfiguration, AgentConfigurationSchema +from common.configuration import AgentConfiguration from monkey_island.cc import repository from monkey_island.cc.repository import ( IAgentConfigurationRepository, @@ -17,21 +17,20 @@ class FileAgentConfigurationRepository(IAgentConfigurationRepository): ): self._default_agent_configuration = default_agent_configuration self._file_repository = file_repository - self._schema = AgentConfigurationSchema() def get_configuration(self) -> AgentConfiguration: try: with self._file_repository.open_file(AGENT_CONFIGURATION_FILE_NAME) as f: configuration_json = f.read().decode() - return self._schema.loads(configuration_json) + return AgentConfiguration.from_json(configuration_json) except repository.FileNotFoundError: return self._default_agent_configuration except Exception as err: raise RetrievalError(f"Error retrieving the agent configuration: {err}") def store_configuration(self, agent_configuration: AgentConfiguration): - configuration_json = self._schema.dumps(agent_configuration) + configuration_json = AgentConfiguration.to_json(agent_configuration) self._file_repository.save_file( AGENT_CONFIGURATION_FILE_NAME, io.BytesIO(configuration_json.encode()) diff --git a/monkey/monkey_island/cc/resources/agent_configuration.py b/monkey/monkey_island/cc/resources/agent_configuration.py index f0ad73cb8..1bf311564 100644 --- a/monkey/monkey_island/cc/resources/agent_configuration.py +++ b/monkey/monkey_island/cc/resources/agent_configuration.py @@ -1,9 +1,9 @@ import json -import marshmallow from flask import make_response, request -from common.configuration.agent_configuration import AgentConfigurationSchema +from common.configuration.agent_configuration import AgentConfiguration as AgentConfigurationObject +from common.configuration.agent_configuration import InvalidConfigurationError from monkey_island.cc.repository import IAgentConfigurationRepository from monkey_island.cc.resources.AbstractResource import AbstractResource from monkey_island.cc.resources.request_authentication import jwt_required @@ -14,22 +14,21 @@ class AgentConfiguration(AbstractResource): def __init__(self, agent_configuration_repository: IAgentConfigurationRepository): self._agent_configuration_repository = agent_configuration_repository - self._schema = AgentConfigurationSchema() @jwt_required def get(self): configuration = self._agent_configuration_repository.get_configuration() - configuration_json = self._schema.dumps(configuration) + configuration_json = AgentConfigurationObject.to_json(configuration) return make_response(configuration_json, 200) @jwt_required def post(self): try: - configuration_object = self._schema.loads(request.data) + configuration_object = AgentConfigurationObject.from_json(request.data) self._agent_configuration_repository.store_configuration(configuration_object) return make_response({}, 200) - except (marshmallow.exceptions.ValidationError, json.JSONDecodeError) as err: + except (InvalidConfigurationError, json.JSONDecodeError) as err: return make_response( {"message": f"Invalid configuration supplied: {err}"}, 400, diff --git a/monkey/tests/data_for_tests/monkey_configs/default_config.py b/monkey/tests/data_for_tests/monkey_configs/default_config.py index e232d122d..c0eca7c43 100644 --- a/monkey/tests/data_for_tests/monkey_configs/default_config.py +++ b/monkey/tests/data_for_tests/monkey_configs/default_config.py @@ -1,4 +1,4 @@ -from common.configuration import AgentConfigurationSchema +from common.configuration import AgentConfiguration flat_config = { "keep_tunnel_open_time": 30, @@ -116,4 +116,4 @@ flat_config = { }, } -DEFAULT_CONFIG = AgentConfigurationSchema().load(flat_config) +DEFAULT_CONFIG = AgentConfiguration.from_mapping(flat_config) diff --git a/monkey/tests/monkey_island/in_memory_agent_configuration_repository.py b/monkey/tests/monkey_island/in_memory_agent_configuration_repository.py index e737d645c..e9bcbae62 100644 --- a/monkey/tests/monkey_island/in_memory_agent_configuration_repository.py +++ b/monkey/tests/monkey_island/in_memory_agent_configuration_repository.py @@ -1,12 +1,12 @@ from tests.common.example_agent_configuration import AGENT_CONFIGURATION -from common.configuration.agent_configuration import AgentConfigurationSchema +from common.configuration.agent_configuration import AgentConfiguration from monkey_island.cc.repository import IAgentConfigurationRepository class InMemoryAgentConfigurationRepository(IAgentConfigurationRepository): def __init__(self): - self._configuration = AgentConfigurationSchema().load(AGENT_CONFIGURATION) + self._configuration = AgentConfiguration.from_mapping(AGENT_CONFIGURATION) def get_configuration(self): return self._configuration diff --git a/monkey/tests/unit_tests/common/test_agent_configuration.py b/monkey/tests/unit_tests/common/configuration/test_agent_configuration.py similarity index 74% rename from monkey/tests/unit_tests/common/test_agent_configuration.py rename to monkey/tests/unit_tests/common/configuration/test_agent_configuration.py index 7ea80cfc5..4b264c8cb 100644 --- a/monkey/tests/unit_tests/common/test_agent_configuration.py +++ b/monkey/tests/unit_tests/common/configuration/test_agent_configuration.py @@ -1,3 +1,6 @@ +import json + +import pytest from tests.common.example_agent_configuration import ( AGENT_CONFIGURATION, BLOCKED_IPS, @@ -26,8 +29,9 @@ from tests.common.example_agent_configuration import ( from common.configuration import ( DEFAULT_AGENT_CONFIGURATION_JSON, AgentConfiguration, - AgentConfigurationSchema, + InvalidConfigurationError, ) +from common.configuration.agent_configuration import AgentConfigurationSchema from common.configuration.agent_sub_configuration_schemas import ( CustomPBAConfigurationSchema, ExploitationConfigurationSchema, @@ -157,10 +161,8 @@ def test_propagation_configuration(): def test_agent_configuration(): - schema = AgentConfigurationSchema() - - config = schema.load(AGENT_CONFIGURATION) - config_dict = schema.dump(config) + config = AgentConfiguration.from_mapping(AGENT_CONFIGURATION) + config_json = AgentConfiguration.to_json(config) assert isinstance(config, AgentConfiguration) assert config.keep_tunnel_open_time == 30 @@ -169,12 +171,60 @@ def test_agent_configuration(): assert isinstance(config.credential_collectors[0], PluginConfiguration) assert isinstance(config.payloads[0], PluginConfiguration) assert isinstance(config.propagation, PropagationConfiguration) - assert config_dict == AGENT_CONFIGURATION + assert json.loads(config_json) == AGENT_CONFIGURATION + + +def test_incorrect_type(): + valid_config = AgentConfiguration.from_mapping(AGENT_CONFIGURATION) + with pytest.raises(InvalidConfigurationError): + valid_config_dict = valid_config.__dict__ + valid_config_dict["keep_tunnel_open_time"] = "not_a_float" + AgentConfiguration(**valid_config_dict) def test_default_agent_configuration(): - schema = AgentConfigurationSchema() - - config = schema.loads(DEFAULT_AGENT_CONFIGURATION_JSON) + config = AgentConfiguration.from_json(DEFAULT_AGENT_CONFIGURATION_JSON) assert isinstance(config, AgentConfiguration) + + +def test_from_dict(): + schema = AgentConfigurationSchema() + dict_ = json.loads(DEFAULT_AGENT_CONFIGURATION_JSON) + + config = AgentConfiguration.from_mapping(dict_) + + assert schema.dump(config) == dict_ + + +def test_from_dict__invalid_data(): + dict_ = json.loads(DEFAULT_AGENT_CONFIGURATION_JSON) + dict_["payloads"] = "payloads" + + with pytest.raises(InvalidConfigurationError): + AgentConfiguration.from_mapping(dict_) + + +def test_from_json(): + schema = AgentConfigurationSchema() + dict_ = json.loads(DEFAULT_AGENT_CONFIGURATION_JSON) + + config = AgentConfiguration.from_json(DEFAULT_AGENT_CONFIGURATION_JSON) + + assert schema.dump(config) == dict_ + + +def test_from_json__invalid_data(): + invalid_dict = json.loads(DEFAULT_AGENT_CONFIGURATION_JSON) + invalid_dict["payloads"] = "payloads" + + with pytest.raises(InvalidConfigurationError): + AgentConfiguration.from_json(json.dumps(invalid_dict)) + + +def test_to_json(): + config = AgentConfiguration.from_json(DEFAULT_AGENT_CONFIGURATION_JSON) + + assert json.loads(AgentConfiguration.to_json(config)) == json.loads( + DEFAULT_AGENT_CONFIGURATION_JSON + ) diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_agent_configuration_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_agent_configuration_repository.py index 4ab111606..fb7863dc3 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_agent_configuration_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_agent_configuration_repository.py @@ -2,7 +2,7 @@ import pytest from tests.common.example_agent_configuration import AGENT_CONFIGURATION from tests.monkey_island import OpenErrorFileRepository, SingleFileRepository -from common.configuration import AgentConfigurationSchema +from common.configuration import AgentConfiguration from monkey_island.cc.repository import FileAgentConfigurationRepository, RetrievalError @@ -12,8 +12,7 @@ def repository(default_agent_configuration): def test_store_agent_config(repository): - schema = AgentConfigurationSchema() - agent_configuration = schema.load(AGENT_CONFIGURATION) + agent_configuration = AgentConfiguration.from_mapping(AGENT_CONFIGURATION) repository.store_configuration(agent_configuration) retrieved_agent_configuration = repository.get_configuration()