Merge pull request #2041 from guardicore/agent-configuration-construction

Agent configuration construction
This commit is contained in:
Mike Salvatore 2022-06-24 14:17:13 -04:00 committed by GitHub
commit 93ed7cf428
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 143 additions and 43 deletions

View File

@ -1,7 +1,4 @@
from .agent_configuration import (
AgentConfiguration,
AgentConfigurationSchema,
)
from .agent_configuration import AgentConfiguration, InvalidConfigurationError
from .agent_sub_configurations import (
CustomPBAConfiguration,
PluginConfiguration,

View File

@ -1,7 +1,10 @@
from dataclasses import dataclass
from typing import List
from __future__ import annotations
from marshmallow import Schema, fields, post_load
from dataclasses import dataclass
from typing import Any, List, Mapping
from marshmallow import Schema, fields
from marshmallow.exceptions import MarshmallowError
from .agent_sub_configuration_schemas import (
CustomPBAConfigurationSchema,
@ -15,6 +18,15 @@ from .agent_sub_configurations import (
)
class InvalidConfigurationError(Exception):
pass
INVALID_CONFIGURATION_ERROR_MESSAGE = (
"Cannot construct an AgentConfiguration object with the supplied, invalid data:"
)
@dataclass(frozen=True)
class AgentConfiguration:
keep_tunnel_open_time: float
@ -24,9 +36,56 @@ class AgentConfiguration:
payloads: List[PluginConfiguration]
propagation: PropagationConfiguration
def __post_init__(self):
# This will raise an exception if the object is invalid. Calling this in __post__init()
# makes it impossible to construct an invalid object
try:
AgentConfigurationSchema().dump(self)
except Exception as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}")
@staticmethod
def from_dict(_dict: dict):
return AgentConfigurationSchema().load(_dict)
def from_mapping(config_mapping: Mapping[str, Any]) -> AgentConfiguration:
"""
Construct an AgentConfiguration from a Mapping
:param config_mapping: A Mapping that represents an AgentConfiguration
:return: An AgentConfiguration
:raises: InvalidConfigurationError if the provided Mapping does not represent a valid
AgentConfiguration
"""
try:
config_dict = AgentConfigurationSchema().load(config_mapping)
return AgentConfiguration(**config_dict)
except MarshmallowError as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}")
@staticmethod
def from_json(config_json: str) -> AgentConfiguration:
"""
Construct an AgentConfiguration from a JSON string
:param config_json: A JSON string that represents an AgentConfiguration
:return: An AgentConfiguration
:raises: InvalidConfigurationError if the provided JSON does not represent a valid
AgentConfiguration
"""
try:
config_dict = AgentConfigurationSchema().loads(config_json)
return AgentConfiguration(**config_dict)
except MarshmallowError as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}")
@staticmethod
def to_json(config: AgentConfiguration) -> str:
"""
Serialize an AgentConfiguration to JSON
:param config: An AgentConfiguration
:return: A JSON string representing the AgentConfiguration
"""
return AgentConfigurationSchema().dumps(config)
class AgentConfigurationSchema(Schema):
@ -36,7 +95,3 @@ class AgentConfigurationSchema(Schema):
credential_collectors = fields.List(fields.Nested(PluginConfigurationSchema))
payloads = fields.List(fields.Nested(PluginConfigurationSchema))
propagation = fields.Nested(PropagationConfigurationSchema)
@post_load
def _make_agent_configuration(self, data, **kwargs):
return AgentConfiguration(**data)

View File

@ -1,4 +1,4 @@
from . import AgentConfiguration, AgentConfigurationSchema
from . import AgentConfiguration
DEFAULT_AGENT_CONFIGURATION_JSON = """{
"keep_tunnel_open_time": 30,
@ -204,5 +204,4 @@ DEFAULT_AGENT_CONFIGURATION_JSON = """{
def build_default_agent_configuration() -> AgentConfiguration:
schema = AgentConfigurationSchema()
return schema.loads(DEFAULT_AGENT_CONFIGURATION_JSON)
return AgentConfiguration.from_json(DEFAULT_AGENT_CONFIGURATION_JSON)

View File

@ -38,5 +38,7 @@ class DomainControllerNameFetchError(FailedExploitationError):
"""Raise on failed attempt to extract domain controller's name"""
# TODO: This has been replaced by common.configuration.InvalidConfigurationError. Use that error
# instead and remove this one.
class InvalidConfigurationError(Exception):
"""Raise when configuration is invalid"""

View File

@ -58,7 +58,7 @@ class ControlChannel(IControlChannel):
)
response.raise_for_status()
return AgentConfiguration.from_dict(json.loads(response.text)["config"])
return AgentConfiguration.from_mapping(json.loads(response.text)["config"])
except (
json.JSONDecodeError,
requests.exceptions.ConnectionError,

View File

@ -1,6 +1,6 @@
import io
from common.configuration import AgentConfiguration, AgentConfigurationSchema
from common.configuration import AgentConfiguration
from monkey_island.cc import repository
from monkey_island.cc.repository import (
IAgentConfigurationRepository,
@ -17,21 +17,20 @@ class FileAgentConfigurationRepository(IAgentConfigurationRepository):
):
self._default_agent_configuration = default_agent_configuration
self._file_repository = file_repository
self._schema = AgentConfigurationSchema()
def get_configuration(self) -> AgentConfiguration:
try:
with self._file_repository.open_file(AGENT_CONFIGURATION_FILE_NAME) as f:
configuration_json = f.read().decode()
return self._schema.loads(configuration_json)
return AgentConfiguration.from_json(configuration_json)
except repository.FileNotFoundError:
return self._default_agent_configuration
except Exception as err:
raise RetrievalError(f"Error retrieving the agent configuration: {err}")
def store_configuration(self, agent_configuration: AgentConfiguration):
configuration_json = self._schema.dumps(agent_configuration)
configuration_json = AgentConfiguration.to_json(agent_configuration)
self._file_repository.save_file(
AGENT_CONFIGURATION_FILE_NAME, io.BytesIO(configuration_json.encode())

View File

@ -1,9 +1,9 @@
import json
import marshmallow
from flask import make_response, request
from common.configuration.agent_configuration import AgentConfigurationSchema
from common.configuration.agent_configuration import AgentConfiguration as AgentConfigurationObject
from common.configuration.agent_configuration import InvalidConfigurationError
from monkey_island.cc.repository import IAgentConfigurationRepository
from monkey_island.cc.resources.AbstractResource import AbstractResource
from monkey_island.cc.resources.request_authentication import jwt_required
@ -14,22 +14,21 @@ class AgentConfiguration(AbstractResource):
def __init__(self, agent_configuration_repository: IAgentConfigurationRepository):
self._agent_configuration_repository = agent_configuration_repository
self._schema = AgentConfigurationSchema()
@jwt_required
def get(self):
configuration = self._agent_configuration_repository.get_configuration()
configuration_json = self._schema.dumps(configuration)
configuration_json = AgentConfigurationObject.to_json(configuration)
return make_response(configuration_json, 200)
@jwt_required
def post(self):
try:
configuration_object = self._schema.loads(request.data)
configuration_object = AgentConfigurationObject.from_json(request.data)
self._agent_configuration_repository.store_configuration(configuration_object)
return make_response({}, 200)
except (marshmallow.exceptions.ValidationError, json.JSONDecodeError) as err:
except (InvalidConfigurationError, json.JSONDecodeError) as err:
return make_response(
{"message": f"Invalid configuration supplied: {err}"},
400,

View File

@ -1,4 +1,4 @@
from common.configuration import AgentConfigurationSchema
from common.configuration import AgentConfiguration
flat_config = {
"keep_tunnel_open_time": 30,
@ -116,4 +116,4 @@ flat_config = {
},
}
DEFAULT_CONFIG = AgentConfigurationSchema().load(flat_config)
DEFAULT_CONFIG = AgentConfiguration.from_mapping(flat_config)

View File

@ -1,12 +1,12 @@
from tests.common.example_agent_configuration import AGENT_CONFIGURATION
from common.configuration.agent_configuration import AgentConfigurationSchema
from common.configuration.agent_configuration import AgentConfiguration
from monkey_island.cc.repository import IAgentConfigurationRepository
class InMemoryAgentConfigurationRepository(IAgentConfigurationRepository):
def __init__(self):
self._configuration = AgentConfigurationSchema().load(AGENT_CONFIGURATION)
self._configuration = AgentConfiguration.from_mapping(AGENT_CONFIGURATION)
def get_configuration(self):
return self._configuration

View File

@ -1,3 +1,6 @@
import json
import pytest
from tests.common.example_agent_configuration import (
AGENT_CONFIGURATION,
BLOCKED_IPS,
@ -26,8 +29,9 @@ from tests.common.example_agent_configuration import (
from common.configuration import (
DEFAULT_AGENT_CONFIGURATION_JSON,
AgentConfiguration,
AgentConfigurationSchema,
InvalidConfigurationError,
)
from common.configuration.agent_configuration import AgentConfigurationSchema
from common.configuration.agent_sub_configuration_schemas import (
CustomPBAConfigurationSchema,
ExploitationConfigurationSchema,
@ -157,10 +161,8 @@ def test_propagation_configuration():
def test_agent_configuration():
schema = AgentConfigurationSchema()
config = schema.load(AGENT_CONFIGURATION)
config_dict = schema.dump(config)
config = AgentConfiguration.from_mapping(AGENT_CONFIGURATION)
config_json = AgentConfiguration.to_json(config)
assert isinstance(config, AgentConfiguration)
assert config.keep_tunnel_open_time == 30
@ -169,12 +171,60 @@ def test_agent_configuration():
assert isinstance(config.credential_collectors[0], PluginConfiguration)
assert isinstance(config.payloads[0], PluginConfiguration)
assert isinstance(config.propagation, PropagationConfiguration)
assert config_dict == AGENT_CONFIGURATION
assert json.loads(config_json) == AGENT_CONFIGURATION
def test_incorrect_type():
valid_config = AgentConfiguration.from_mapping(AGENT_CONFIGURATION)
with pytest.raises(InvalidConfigurationError):
valid_config_dict = valid_config.__dict__
valid_config_dict["keep_tunnel_open_time"] = "not_a_float"
AgentConfiguration(**valid_config_dict)
def test_default_agent_configuration():
schema = AgentConfigurationSchema()
config = schema.loads(DEFAULT_AGENT_CONFIGURATION_JSON)
config = AgentConfiguration.from_json(DEFAULT_AGENT_CONFIGURATION_JSON)
assert isinstance(config, AgentConfiguration)
def test_from_dict():
schema = AgentConfigurationSchema()
dict_ = json.loads(DEFAULT_AGENT_CONFIGURATION_JSON)
config = AgentConfiguration.from_mapping(dict_)
assert schema.dump(config) == dict_
def test_from_dict__invalid_data():
dict_ = json.loads(DEFAULT_AGENT_CONFIGURATION_JSON)
dict_["payloads"] = "payloads"
with pytest.raises(InvalidConfigurationError):
AgentConfiguration.from_mapping(dict_)
def test_from_json():
schema = AgentConfigurationSchema()
dict_ = json.loads(DEFAULT_AGENT_CONFIGURATION_JSON)
config = AgentConfiguration.from_json(DEFAULT_AGENT_CONFIGURATION_JSON)
assert schema.dump(config) == dict_
def test_from_json__invalid_data():
invalid_dict = json.loads(DEFAULT_AGENT_CONFIGURATION_JSON)
invalid_dict["payloads"] = "payloads"
with pytest.raises(InvalidConfigurationError):
AgentConfiguration.from_json(json.dumps(invalid_dict))
def test_to_json():
config = AgentConfiguration.from_json(DEFAULT_AGENT_CONFIGURATION_JSON)
assert json.loads(AgentConfiguration.to_json(config)) == json.loads(
DEFAULT_AGENT_CONFIGURATION_JSON
)

View File

@ -2,7 +2,7 @@ import pytest
from tests.common.example_agent_configuration import AGENT_CONFIGURATION
from tests.monkey_island import OpenErrorFileRepository, SingleFileRepository
from common.configuration import AgentConfigurationSchema
from common.configuration import AgentConfiguration
from monkey_island.cc.repository import FileAgentConfigurationRepository, RetrievalError
@ -12,8 +12,7 @@ def repository(default_agent_configuration):
def test_store_agent_config(repository):
schema = AgentConfigurationSchema()
agent_configuration = schema.load(AGENT_CONFIGURATION)
agent_configuration = AgentConfiguration.from_mapping(AGENT_CONFIGURATION)
repository.store_configuration(agent_configuration)
retrieved_agent_configuration = repository.get_configuration()