Merge pull request #2041 from guardicore/agent-configuration-construction

Agent configuration construction
This commit is contained in:
Mike Salvatore 2022-06-24 14:17:13 -04:00 committed by GitHub
commit 93ed7cf428
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 143 additions and 43 deletions

View File

@ -1,7 +1,4 @@
from .agent_configuration import ( from .agent_configuration import AgentConfiguration, InvalidConfigurationError
AgentConfiguration,
AgentConfigurationSchema,
)
from .agent_sub_configurations import ( from .agent_sub_configurations import (
CustomPBAConfiguration, CustomPBAConfiguration,
PluginConfiguration, PluginConfiguration,

View File

@ -1,7 +1,10 @@
from dataclasses import dataclass from __future__ import annotations
from typing import List
from marshmallow import Schema, fields, post_load from dataclasses import dataclass
from typing import Any, List, Mapping
from marshmallow import Schema, fields
from marshmallow.exceptions import MarshmallowError
from .agent_sub_configuration_schemas import ( from .agent_sub_configuration_schemas import (
CustomPBAConfigurationSchema, CustomPBAConfigurationSchema,
@ -15,6 +18,15 @@ from .agent_sub_configurations import (
) )
class InvalidConfigurationError(Exception):
pass
INVALID_CONFIGURATION_ERROR_MESSAGE = (
"Cannot construct an AgentConfiguration object with the supplied, invalid data:"
)
@dataclass(frozen=True) @dataclass(frozen=True)
class AgentConfiguration: class AgentConfiguration:
keep_tunnel_open_time: float keep_tunnel_open_time: float
@ -24,9 +36,56 @@ class AgentConfiguration:
payloads: List[PluginConfiguration] payloads: List[PluginConfiguration]
propagation: PropagationConfiguration propagation: PropagationConfiguration
def __post_init__(self):
# This will raise an exception if the object is invalid. Calling this in __post__init()
# makes it impossible to construct an invalid object
try:
AgentConfigurationSchema().dump(self)
except Exception as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}")
@staticmethod @staticmethod
def from_dict(_dict: dict): def from_mapping(config_mapping: Mapping[str, Any]) -> AgentConfiguration:
return AgentConfigurationSchema().load(_dict) """
Construct an AgentConfiguration from a Mapping
:param config_mapping: A Mapping that represents an AgentConfiguration
:return: An AgentConfiguration
:raises: InvalidConfigurationError if the provided Mapping does not represent a valid
AgentConfiguration
"""
try:
config_dict = AgentConfigurationSchema().load(config_mapping)
return AgentConfiguration(**config_dict)
except MarshmallowError as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}")
@staticmethod
def from_json(config_json: str) -> AgentConfiguration:
"""
Construct an AgentConfiguration from a JSON string
:param config_json: A JSON string that represents an AgentConfiguration
:return: An AgentConfiguration
:raises: InvalidConfigurationError if the provided JSON does not represent a valid
AgentConfiguration
"""
try:
config_dict = AgentConfigurationSchema().loads(config_json)
return AgentConfiguration(**config_dict)
except MarshmallowError as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}")
@staticmethod
def to_json(config: AgentConfiguration) -> str:
"""
Serialize an AgentConfiguration to JSON
:param config: An AgentConfiguration
:return: A JSON string representing the AgentConfiguration
"""
return AgentConfigurationSchema().dumps(config)
class AgentConfigurationSchema(Schema): class AgentConfigurationSchema(Schema):
@ -36,7 +95,3 @@ class AgentConfigurationSchema(Schema):
credential_collectors = fields.List(fields.Nested(PluginConfigurationSchema)) credential_collectors = fields.List(fields.Nested(PluginConfigurationSchema))
payloads = fields.List(fields.Nested(PluginConfigurationSchema)) payloads = fields.List(fields.Nested(PluginConfigurationSchema))
propagation = fields.Nested(PropagationConfigurationSchema) propagation = fields.Nested(PropagationConfigurationSchema)
@post_load
def _make_agent_configuration(self, data, **kwargs):
return AgentConfiguration(**data)

View File

@ -1,4 +1,4 @@
from . import AgentConfiguration, AgentConfigurationSchema from . import AgentConfiguration
DEFAULT_AGENT_CONFIGURATION_JSON = """{ DEFAULT_AGENT_CONFIGURATION_JSON = """{
"keep_tunnel_open_time": 30, "keep_tunnel_open_time": 30,
@ -204,5 +204,4 @@ DEFAULT_AGENT_CONFIGURATION_JSON = """{
def build_default_agent_configuration() -> AgentConfiguration: def build_default_agent_configuration() -> AgentConfiguration:
schema = AgentConfigurationSchema() return AgentConfiguration.from_json(DEFAULT_AGENT_CONFIGURATION_JSON)
return schema.loads(DEFAULT_AGENT_CONFIGURATION_JSON)

View File

@ -38,5 +38,7 @@ class DomainControllerNameFetchError(FailedExploitationError):
"""Raise on failed attempt to extract domain controller's name""" """Raise on failed attempt to extract domain controller's name"""
# TODO: This has been replaced by common.configuration.InvalidConfigurationError. Use that error
# instead and remove this one.
class InvalidConfigurationError(Exception): class InvalidConfigurationError(Exception):
"""Raise when configuration is invalid""" """Raise when configuration is invalid"""

View File

@ -58,7 +58,7 @@ class ControlChannel(IControlChannel):
) )
response.raise_for_status() response.raise_for_status()
return AgentConfiguration.from_dict(json.loads(response.text)["config"]) return AgentConfiguration.from_mapping(json.loads(response.text)["config"])
except ( except (
json.JSONDecodeError, json.JSONDecodeError,
requests.exceptions.ConnectionError, requests.exceptions.ConnectionError,

View File

@ -1,6 +1,6 @@
import io import io
from common.configuration import AgentConfiguration, AgentConfigurationSchema from common.configuration import AgentConfiguration
from monkey_island.cc import repository from monkey_island.cc import repository
from monkey_island.cc.repository import ( from monkey_island.cc.repository import (
IAgentConfigurationRepository, IAgentConfigurationRepository,
@ -17,21 +17,20 @@ class FileAgentConfigurationRepository(IAgentConfigurationRepository):
): ):
self._default_agent_configuration = default_agent_configuration self._default_agent_configuration = default_agent_configuration
self._file_repository = file_repository self._file_repository = file_repository
self._schema = AgentConfigurationSchema()
def get_configuration(self) -> AgentConfiguration: def get_configuration(self) -> AgentConfiguration:
try: try:
with self._file_repository.open_file(AGENT_CONFIGURATION_FILE_NAME) as f: with self._file_repository.open_file(AGENT_CONFIGURATION_FILE_NAME) as f:
configuration_json = f.read().decode() configuration_json = f.read().decode()
return self._schema.loads(configuration_json) return AgentConfiguration.from_json(configuration_json)
except repository.FileNotFoundError: except repository.FileNotFoundError:
return self._default_agent_configuration return self._default_agent_configuration
except Exception as err: except Exception as err:
raise RetrievalError(f"Error retrieving the agent configuration: {err}") raise RetrievalError(f"Error retrieving the agent configuration: {err}")
def store_configuration(self, agent_configuration: AgentConfiguration): def store_configuration(self, agent_configuration: AgentConfiguration):
configuration_json = self._schema.dumps(agent_configuration) configuration_json = AgentConfiguration.to_json(agent_configuration)
self._file_repository.save_file( self._file_repository.save_file(
AGENT_CONFIGURATION_FILE_NAME, io.BytesIO(configuration_json.encode()) AGENT_CONFIGURATION_FILE_NAME, io.BytesIO(configuration_json.encode())

View File

@ -1,9 +1,9 @@
import json import json
import marshmallow
from flask import make_response, request from flask import make_response, request
from common.configuration.agent_configuration import AgentConfigurationSchema from common.configuration.agent_configuration import AgentConfiguration as AgentConfigurationObject
from common.configuration.agent_configuration import InvalidConfigurationError
from monkey_island.cc.repository import IAgentConfigurationRepository from monkey_island.cc.repository import IAgentConfigurationRepository
from monkey_island.cc.resources.AbstractResource import AbstractResource from monkey_island.cc.resources.AbstractResource import AbstractResource
from monkey_island.cc.resources.request_authentication import jwt_required from monkey_island.cc.resources.request_authentication import jwt_required
@ -14,22 +14,21 @@ class AgentConfiguration(AbstractResource):
def __init__(self, agent_configuration_repository: IAgentConfigurationRepository): def __init__(self, agent_configuration_repository: IAgentConfigurationRepository):
self._agent_configuration_repository = agent_configuration_repository self._agent_configuration_repository = agent_configuration_repository
self._schema = AgentConfigurationSchema()
@jwt_required @jwt_required
def get(self): def get(self):
configuration = self._agent_configuration_repository.get_configuration() configuration = self._agent_configuration_repository.get_configuration()
configuration_json = self._schema.dumps(configuration) configuration_json = AgentConfigurationObject.to_json(configuration)
return make_response(configuration_json, 200) return make_response(configuration_json, 200)
@jwt_required @jwt_required
def post(self): def post(self):
try: try:
configuration_object = self._schema.loads(request.data) configuration_object = AgentConfigurationObject.from_json(request.data)
self._agent_configuration_repository.store_configuration(configuration_object) self._agent_configuration_repository.store_configuration(configuration_object)
return make_response({}, 200) return make_response({}, 200)
except (marshmallow.exceptions.ValidationError, json.JSONDecodeError) as err: except (InvalidConfigurationError, json.JSONDecodeError) as err:
return make_response( return make_response(
{"message": f"Invalid configuration supplied: {err}"}, {"message": f"Invalid configuration supplied: {err}"},
400, 400,

View File

@ -1,4 +1,4 @@
from common.configuration import AgentConfigurationSchema from common.configuration import AgentConfiguration
flat_config = { flat_config = {
"keep_tunnel_open_time": 30, "keep_tunnel_open_time": 30,
@ -116,4 +116,4 @@ flat_config = {
}, },
} }
DEFAULT_CONFIG = AgentConfigurationSchema().load(flat_config) DEFAULT_CONFIG = AgentConfiguration.from_mapping(flat_config)

View File

@ -1,12 +1,12 @@
from tests.common.example_agent_configuration import AGENT_CONFIGURATION from tests.common.example_agent_configuration import AGENT_CONFIGURATION
from common.configuration.agent_configuration import AgentConfigurationSchema from common.configuration.agent_configuration import AgentConfiguration
from monkey_island.cc.repository import IAgentConfigurationRepository from monkey_island.cc.repository import IAgentConfigurationRepository
class InMemoryAgentConfigurationRepository(IAgentConfigurationRepository): class InMemoryAgentConfigurationRepository(IAgentConfigurationRepository):
def __init__(self): def __init__(self):
self._configuration = AgentConfigurationSchema().load(AGENT_CONFIGURATION) self._configuration = AgentConfiguration.from_mapping(AGENT_CONFIGURATION)
def get_configuration(self): def get_configuration(self):
return self._configuration return self._configuration

View File

@ -1,3 +1,6 @@
import json
import pytest
from tests.common.example_agent_configuration import ( from tests.common.example_agent_configuration import (
AGENT_CONFIGURATION, AGENT_CONFIGURATION,
BLOCKED_IPS, BLOCKED_IPS,
@ -26,8 +29,9 @@ from tests.common.example_agent_configuration import (
from common.configuration import ( from common.configuration import (
DEFAULT_AGENT_CONFIGURATION_JSON, DEFAULT_AGENT_CONFIGURATION_JSON,
AgentConfiguration, AgentConfiguration,
AgentConfigurationSchema, InvalidConfigurationError,
) )
from common.configuration.agent_configuration import AgentConfigurationSchema
from common.configuration.agent_sub_configuration_schemas import ( from common.configuration.agent_sub_configuration_schemas import (
CustomPBAConfigurationSchema, CustomPBAConfigurationSchema,
ExploitationConfigurationSchema, ExploitationConfigurationSchema,
@ -157,10 +161,8 @@ def test_propagation_configuration():
def test_agent_configuration(): def test_agent_configuration():
schema = AgentConfigurationSchema() config = AgentConfiguration.from_mapping(AGENT_CONFIGURATION)
config_json = AgentConfiguration.to_json(config)
config = schema.load(AGENT_CONFIGURATION)
config_dict = schema.dump(config)
assert isinstance(config, AgentConfiguration) assert isinstance(config, AgentConfiguration)
assert config.keep_tunnel_open_time == 30 assert config.keep_tunnel_open_time == 30
@ -169,12 +171,60 @@ def test_agent_configuration():
assert isinstance(config.credential_collectors[0], PluginConfiguration) assert isinstance(config.credential_collectors[0], PluginConfiguration)
assert isinstance(config.payloads[0], PluginConfiguration) assert isinstance(config.payloads[0], PluginConfiguration)
assert isinstance(config.propagation, PropagationConfiguration) assert isinstance(config.propagation, PropagationConfiguration)
assert config_dict == AGENT_CONFIGURATION assert json.loads(config_json) == AGENT_CONFIGURATION
def test_incorrect_type():
valid_config = AgentConfiguration.from_mapping(AGENT_CONFIGURATION)
with pytest.raises(InvalidConfigurationError):
valid_config_dict = valid_config.__dict__
valid_config_dict["keep_tunnel_open_time"] = "not_a_float"
AgentConfiguration(**valid_config_dict)
def test_default_agent_configuration(): def test_default_agent_configuration():
schema = AgentConfigurationSchema() config = AgentConfiguration.from_json(DEFAULT_AGENT_CONFIGURATION_JSON)
config = schema.loads(DEFAULT_AGENT_CONFIGURATION_JSON)
assert isinstance(config, AgentConfiguration) assert isinstance(config, AgentConfiguration)
def test_from_dict():
schema = AgentConfigurationSchema()
dict_ = json.loads(DEFAULT_AGENT_CONFIGURATION_JSON)
config = AgentConfiguration.from_mapping(dict_)
assert schema.dump(config) == dict_
def test_from_dict__invalid_data():
dict_ = json.loads(DEFAULT_AGENT_CONFIGURATION_JSON)
dict_["payloads"] = "payloads"
with pytest.raises(InvalidConfigurationError):
AgentConfiguration.from_mapping(dict_)
def test_from_json():
schema = AgentConfigurationSchema()
dict_ = json.loads(DEFAULT_AGENT_CONFIGURATION_JSON)
config = AgentConfiguration.from_json(DEFAULT_AGENT_CONFIGURATION_JSON)
assert schema.dump(config) == dict_
def test_from_json__invalid_data():
invalid_dict = json.loads(DEFAULT_AGENT_CONFIGURATION_JSON)
invalid_dict["payloads"] = "payloads"
with pytest.raises(InvalidConfigurationError):
AgentConfiguration.from_json(json.dumps(invalid_dict))
def test_to_json():
config = AgentConfiguration.from_json(DEFAULT_AGENT_CONFIGURATION_JSON)
assert json.loads(AgentConfiguration.to_json(config)) == json.loads(
DEFAULT_AGENT_CONFIGURATION_JSON
)

View File

@ -2,7 +2,7 @@ import pytest
from tests.common.example_agent_configuration import AGENT_CONFIGURATION from tests.common.example_agent_configuration import AGENT_CONFIGURATION
from tests.monkey_island import OpenErrorFileRepository, SingleFileRepository from tests.monkey_island import OpenErrorFileRepository, SingleFileRepository
from common.configuration import AgentConfigurationSchema from common.configuration import AgentConfiguration
from monkey_island.cc.repository import FileAgentConfigurationRepository, RetrievalError from monkey_island.cc.repository import FileAgentConfigurationRepository, RetrievalError
@ -12,8 +12,7 @@ def repository(default_agent_configuration):
def test_store_agent_config(repository): def test_store_agent_config(repository):
schema = AgentConfigurationSchema() agent_configuration = AgentConfiguration.from_mapping(AGENT_CONFIGURATION)
agent_configuration = schema.load(AGENT_CONFIGURATION)
repository.store_configuration(agent_configuration) repository.store_configuration(agent_configuration)
retrieved_agent_configuration = repository.get_configuration() retrieved_agent_configuration = repository.get_configuration()