diff --git a/monkey/common/credentials/credentials.py b/monkey/common/credentials/credentials.py index 9fdf78eab..eed392b6c 100644 --- a/monkey/common/credentials/credentials.py +++ b/monkey/common/credentials/credentials.py @@ -7,6 +7,7 @@ from typing import Any, Mapping, MutableMapping, Sequence, Tuple from marshmallow import Schema, fields, post_load, pre_dump from marshmallow.exceptions import MarshmallowError +from ..utils import IJSONSerializable from . import ( CredentialComponentType, InvalidCredentialComponentError, @@ -116,7 +117,7 @@ class CredentialsSchema(Schema): @dataclass(frozen=True) -class Credentials: +class Credentials(IJSONSerializable): identities: Tuple[ICredentialComponent] secrets: Tuple[ICredentialComponent] @@ -141,8 +142,8 @@ class Credentials: except MarshmallowError as err: raise InvalidCredentialsError(str(err)) - @staticmethod - def from_json(credentials: str) -> Credentials: + @classmethod + def from_json(cls, credentials: str) -> Credentials: """ Construct a Credentials object from a JSON string @@ -180,8 +181,8 @@ class Credentials: credentials_list = json.loads(credentials_array_json) return [Credentials.from_mapping(c) for c in credentials_list] - @staticmethod - def to_json(credentials: Credentials) -> str: + @classmethod + def to_json(cls, credentials: Credentials) -> str: """ Serialize a Credentials object to JSON diff --git a/monkey/common/utils/IJSONSerializable.py b/monkey/common/utils/IJSONSerializable.py new file mode 100644 index 000000000..39eefbf90 --- /dev/null +++ b/monkey/common/utils/IJSONSerializable.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + + +class IJSONSerializable(ABC): + @classmethod + @abstractmethod + def from_json(cls, json_string: str) -> IJSONSerializable: + pass + + @classmethod + @abstractmethod + def to_json(cls, class_object: IJSONSerializable) -> str: + pass diff --git a/monkey/common/utils/__init__.py b/monkey/common/utils/__init__.py index 57725fa62..1b11c54e3 100644 --- a/monkey/common/utils/__init__.py +++ b/monkey/common/utils/__init__.py @@ -1 +1,2 @@ from .timer import Timer +from .IJSONSerializable import IJSONSerializable diff --git a/monkey/monkey_island/cc/resources/propagation_credentials.py b/monkey/monkey_island/cc/resources/propagation_credentials.py index 9fb59ad0e..c5dfc42fe 100644 --- a/monkey/monkey_island/cc/resources/propagation_credentials.py +++ b/monkey/monkey_island/cc/resources/propagation_credentials.py @@ -1,6 +1,6 @@ from http import HTTPStatus -from flask import make_response, request +from flask import request from common.credentials import Credentials from monkey_island.cc.repository import ICredentialsRepository @@ -26,7 +26,7 @@ class PropagationCredentials(AbstractResource): else: return {}, HTTPStatus.NOT_FOUND - return make_response(Credentials.to_json_array(propagation_credentials), HTTPStatus.OK) + return propagation_credentials, HTTPStatus.OK def post(self, collection=None): credentials = [Credentials.from_json(c) for c in request.json] diff --git a/monkey/monkey_island/cc/services/representations.py b/monkey/monkey_island/cc/services/representations.py index e066764aa..e21fcdb25 100644 --- a/monkey/monkey_island/cc/services/representations.py +++ b/monkey/monkey_island/cc/services/representations.py @@ -1,47 +1,34 @@ from datetime import datetime from enum import Enum +from json import loads +from typing import Any import bson -from bson.json_util import dumps from flask import make_response +from flask.json import JSONEncoder, dumps + +from common.utils import IJSONSerializable -def _normalize_obj(obj): - if ("_id" in obj) and ("id" not in obj): - obj["id"] = obj["_id"] - del obj["_id"] - - for key, value in list(obj.items()): - obj[key] = _normalize_value(value) - return obj - - -def _normalize_value(value): - # ObjectId is serializible by default, but returns a dict - # So serialize it first into a plain string - if isinstance(value, bson.objectid.ObjectId): - return str(value) - - if isinstance(value, list): - return [_normalize_value(_value) for _value in value] - if isinstance(value, tuple): - return tuple((_normalize_value(_value) for _value in value)) - if type(value) == dict: - return _normalize_obj(value) - if isinstance(value, datetime): - return str(value) - if issubclass(type(value), Enum): - return value.name - - try: - dumps(value) - return value - except TypeError: - return value.__dict__ +class APIEncoder(JSONEncoder): + def default(self, value: Any) -> Any: + # ObjectId is serializible by default, but returns a dict + # So serialize it first into a plain string + if isinstance(value, bson.objectid.ObjectId): + return str(value) + if isinstance(value, datetime): + return str(value) + if issubclass(type(value), Enum): + return value.name + if issubclass(type(value), IJSONSerializable): + return loads(value.__class__.to_json(value)) + try: + return JSONEncoder.default(self, value) + except TypeError: + return value.__dict__ def output_json(value, code, headers=None): - value = _normalize_value(value) - resp = make_response(dumps(value), code) + resp = make_response(dumps(value, cls=APIEncoder), code) resp.headers.extend(headers or {}) return resp diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/test_propagation_credentials.py b/monkey/tests/unit_tests/monkey_island/cc/resources/test_propagation_credentials.py index 111ad7e30..93e82a223 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/resources/test_propagation_credentials.py +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/test_propagation_credentials.py @@ -50,7 +50,7 @@ def test_propagation_credentials_endpoint_get(flask_client, credentials_reposito ) resp = flask_client.get(ALL_CREDENTIALS_URL) - actual_propagation_credentials = Credentials.from_json_array(resp.text) + actual_propagation_credentials = [Credentials.from_mapping(creds) for creds in resp.json] assert resp.status_code == HTTPStatus.OK assert len(actual_propagation_credentials) == 4 @@ -76,7 +76,7 @@ def test_propagation_credentials_endpoint__get_stolen(flask_client, credentials_ ) resp = flask_client.get(url) - actual_propagation_credentials = Credentials.from_json_array(resp.text) + actual_propagation_credentials = [Credentials.from_mapping(creds) for creds in resp.json] assert resp.status_code == HTTPStatus.OK assert len(actual_propagation_credentials) == 2 @@ -98,7 +98,7 @@ def test_propagation_credentials_endpoint__post_stolen(flask_client, credentials assert resp.status_code == HTTPStatus.NO_CONTENT resp = flask_client.get(url) - retrieved_propagation_credentials = Credentials.from_json_array(resp.text) + retrieved_propagation_credentials = [Credentials.from_mapping(creds) for creds in resp.json] assert resp.status_code == HTTPStatus.OK assert len(retrieved_propagation_credentials) == 3 diff --git a/monkey/tests/unit_tests/monkey_island/cc/services/test_representations.py b/monkey/tests/unit_tests/monkey_island/cc/services/test_representations.py index d2eef03c5..541a64c18 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/services/test_representations.py +++ b/monkey/tests/unit_tests/monkey_island/cc/services/test_representations.py @@ -1,10 +1,12 @@ +import json from dataclasses import dataclass from datetime import datetime from enum import Enum import bson -from monkey_island.cc.services.representations import _normalize_value +from common.utils import IJSONSerializable +from monkey_island.cc.services.representations import APIEncoder @dataclass @@ -17,22 +19,24 @@ bogus_object1 = MockClass(1) bogus_object2 = MockClass(2) -def test_normalize_dicts(): - assert {} == _normalize_value({}) +def test_api_encoder_dicts(): + assert json.dumps({}) == json.dumps({}, cls=APIEncoder) - assert {"a": "a"} == _normalize_value({"a": "a"}) + assert json.dumps({"a": "a"}) == json.dumps({"a": "a"}, cls=APIEncoder) - assert {"id": 12345} == _normalize_value({"id": 12345}) + assert json.dumps({"id": 12345}) == json.dumps({"id": 12345}, cls=APIEncoder) - assert {"id": obj_id_str} == _normalize_value({"id": bson.objectid.ObjectId(obj_id_str)}) + assert json.dumps({"id": obj_id_str}) == json.dumps( + {"id": bson.objectid.ObjectId(obj_id_str)}, cls=APIEncoder + ) dt = datetime.now() expected = {"a": str(dt)} - result = _normalize_value({"a": dt}) - assert expected == result + result = json.dumps({"a": dt}, cls=APIEncoder) + assert json.dumps(expected) == result -def test_normalize_complex(): +def test_api_encoder_complex(): bogus_dict = { "a": [ { @@ -44,26 +48,47 @@ def test_normalize_complex(): } expected_dict = {"a": [{"ba": obj_id_str, "bb": obj_id_str}], "b": {"id": obj_id_str}} - assert expected_dict == _normalize_value(bogus_dict) + assert json.dumps(expected_dict) == json.dumps(bogus_dict, cls=APIEncoder) -def test_normalize_list(): +def test_api_encoder_list(): bogus_list = [bson.objectid.ObjectId(obj_id_str), {"a": "b"}, {"object": [bogus_object1]}] expected_list = [obj_id_str, {"a": "b"}, {"object": [{"a": 1}]}] - assert expected_list == _normalize_value(bogus_list) + assert json.dumps(expected_list) == json.dumps(bogus_list, cls=APIEncoder) -def test_normalize_enum(): +def test_api_encoder_enum(): class BogusEnum(Enum): bogus_val = "Bogus" my_obj = {"something": "something", "my_enum": BogusEnum.bogus_val} - assert {"something": "something", "my_enum": "bogus_val"} == _normalize_value(my_obj) + assert json.dumps({"something": "something", "my_enum": "bogus_val"}) == json.dumps( + my_obj, cls=APIEncoder + ) -def test_normalize_tuple(): - bogus_tuple = [{"my_tuple": (bogus_object1, bogus_object2, b"one_two")}] - expected_tuple = [{"my_tuple": ({"a": 1}, {"a": 2}, b"one_two")}] - assert expected_tuple == _normalize_value(bogus_tuple) +def test_api_encoder_tuple(): + bogus_tuple = [{"my_tuple": (bogus_object1, bogus_object2, "string")}] + expected_tuple = [{"my_tuple": ({"a": 1}, {"a": 2}, "string")}] + assert json.dumps(expected_tuple) == json.dumps(bogus_tuple, cls=APIEncoder) + + +class BogusSerializableClass(IJSONSerializable): + def __init__(self, a): + self.a = a + + @classmethod + def to_json(cls, class_object: IJSONSerializable) -> str: + return json.dumps({"wacky": class_object.a}) + + @classmethod + def from_json(cls, json_string: str) -> IJSONSerializable: + pass + + +def test_api_encoder_json_serializable(): + bogus_data = {"target": [BogusSerializableClass("macky")]} + expected_result = {"target": [{"wacky": "macky"}]} + assert json.dumps(expected_result) == json.dumps(bogus_data, cls=APIEncoder) diff --git a/vulture_allowlist.py b/vulture_allowlist.py index 049e487fc..4b4262b47 100644 --- a/vulture_allowlist.py +++ b/vulture_allowlist.py @@ -4,7 +4,6 @@ dead or is kept deliberately. Referencing these in a file like this makes sure t Vulture doesn't mark these as dead again. """ from infection_monkey.exploit.log4shell_utils.ldap_server import LDAPServerFactory -from monkey_island.cc import app from monkey_island.cc.models import Report from monkey_island.cc.models.networkmap import Arc, NetworkMap from monkey_island.cc.repository.attack.IMitigationsRepository import IMitigationsRepository @@ -207,6 +206,12 @@ _make_simulation # unused method (monkey/monkey_island/cc/models/simulation.py: # TODO DELETE AFTER RESOURCE REFACTORING + +# https://github.com/jendrikseipp/vulture/issues/287 +# Both happen in common\utils\IJSONSerializable.py +json_string +class_object + NetworkMap Arc.dst_machine IMitigationsRepository.get_mitigations