Merge pull request #2079 from guardicore/1965-credential-serialization

1965 credential serialization
This commit is contained in:
Mike Salvatore 2022-07-12 09:40:24 -04:00 committed by GitHub
commit 5fe232aaa1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 98 additions and 64 deletions

View File

@ -7,6 +7,7 @@ from typing import Any, Mapping, MutableMapping, Sequence, Tuple
from marshmallow import Schema, fields, post_load, pre_dump from marshmallow import Schema, fields, post_load, pre_dump
from marshmallow.exceptions import MarshmallowError from marshmallow.exceptions import MarshmallowError
from ..utils import IJSONSerializable
from . import ( from . import (
CredentialComponentType, CredentialComponentType,
InvalidCredentialComponentError, InvalidCredentialComponentError,
@ -116,7 +117,7 @@ class CredentialsSchema(Schema):
@dataclass(frozen=True) @dataclass(frozen=True)
class Credentials: class Credentials(IJSONSerializable):
identities: Tuple[ICredentialComponent] identities: Tuple[ICredentialComponent]
secrets: Tuple[ICredentialComponent] secrets: Tuple[ICredentialComponent]
@ -141,8 +142,8 @@ class Credentials:
except MarshmallowError as err: except MarshmallowError as err:
raise InvalidCredentialsError(str(err)) raise InvalidCredentialsError(str(err))
@staticmethod @classmethod
def from_json(credentials: str) -> Credentials: def from_json(cls, credentials: str) -> Credentials:
""" """
Construct a Credentials object from a JSON string Construct a Credentials object from a JSON string
@ -180,8 +181,8 @@ class Credentials:
credentials_list = json.loads(credentials_array_json) credentials_list = json.loads(credentials_array_json)
return [Credentials.from_mapping(c) for c in credentials_list] return [Credentials.from_mapping(c) for c in credentials_list]
@staticmethod @classmethod
def to_json(credentials: Credentials) -> str: def to_json(cls, credentials: Credentials) -> str:
""" """
Serialize a Credentials object to JSON Serialize a Credentials object to JSON

View File

@ -0,0 +1,15 @@
from __future__ import annotations
from abc import ABC, abstractmethod
class IJSONSerializable(ABC):
@classmethod
@abstractmethod
def from_json(cls, json_string: str) -> IJSONSerializable:
pass
@classmethod
@abstractmethod
def to_json(cls, class_object: IJSONSerializable) -> str:
pass

View File

@ -1 +1,2 @@
from .timer import Timer from .timer import Timer
from .IJSONSerializable import IJSONSerializable

View File

@ -1,6 +1,6 @@
from http import HTTPStatus from http import HTTPStatus
from flask import make_response, request from flask import request
from common.credentials import Credentials from common.credentials import Credentials
from monkey_island.cc.repository import ICredentialsRepository from monkey_island.cc.repository import ICredentialsRepository
@ -26,7 +26,7 @@ class PropagationCredentials(AbstractResource):
else: else:
return {}, HTTPStatus.NOT_FOUND return {}, HTTPStatus.NOT_FOUND
return make_response(Credentials.to_json_array(propagation_credentials), HTTPStatus.OK) return propagation_credentials, HTTPStatus.OK
def post(self, collection=None): def post(self, collection=None):
credentials = [Credentials.from_json(c) for c in request.json] credentials = [Credentials.from_json(c) for c in request.json]

View File

@ -1,47 +1,34 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from json import loads
from typing import Any
import bson import bson
from bson.json_util import dumps
from flask import make_response from flask import make_response
from flask.json import JSONEncoder, dumps
from common.utils import IJSONSerializable
def _normalize_obj(obj): class APIEncoder(JSONEncoder):
if ("_id" in obj) and ("id" not in obj): def default(self, value: Any) -> Any:
obj["id"] = obj["_id"]
del obj["_id"]
for key, value in list(obj.items()):
obj[key] = _normalize_value(value)
return obj
def _normalize_value(value):
# ObjectId is serializible by default, but returns a dict # ObjectId is serializible by default, but returns a dict
# So serialize it first into a plain string # So serialize it first into a plain string
if isinstance(value, bson.objectid.ObjectId): if isinstance(value, bson.objectid.ObjectId):
return str(value) return str(value)
if isinstance(value, list):
return [_normalize_value(_value) for _value in value]
if isinstance(value, tuple):
return tuple((_normalize_value(_value) for _value in value))
if type(value) == dict:
return _normalize_obj(value)
if isinstance(value, datetime): if isinstance(value, datetime):
return str(value) return str(value)
if issubclass(type(value), Enum): if issubclass(type(value), Enum):
return value.name return value.name
if issubclass(type(value), IJSONSerializable):
return loads(value.__class__.to_json(value))
try: try:
dumps(value) return JSONEncoder.default(self, value)
return value
except TypeError: except TypeError:
return value.__dict__ return value.__dict__
def output_json(value, code, headers=None): def output_json(value, code, headers=None):
value = _normalize_value(value) resp = make_response(dumps(value, cls=APIEncoder), code)
resp = make_response(dumps(value), code)
resp.headers.extend(headers or {}) resp.headers.extend(headers or {})
return resp return resp

View File

@ -50,7 +50,7 @@ def test_propagation_credentials_endpoint_get(flask_client, credentials_reposito
) )
resp = flask_client.get(ALL_CREDENTIALS_URL) resp = flask_client.get(ALL_CREDENTIALS_URL)
actual_propagation_credentials = Credentials.from_json_array(resp.text) actual_propagation_credentials = [Credentials.from_mapping(creds) for creds in resp.json]
assert resp.status_code == HTTPStatus.OK assert resp.status_code == HTTPStatus.OK
assert len(actual_propagation_credentials) == 4 assert len(actual_propagation_credentials) == 4
@ -76,7 +76,7 @@ def test_propagation_credentials_endpoint__get_stolen(flask_client, credentials_
) )
resp = flask_client.get(url) resp = flask_client.get(url)
actual_propagation_credentials = Credentials.from_json_array(resp.text) actual_propagation_credentials = [Credentials.from_mapping(creds) for creds in resp.json]
assert resp.status_code == HTTPStatus.OK assert resp.status_code == HTTPStatus.OK
assert len(actual_propagation_credentials) == 2 assert len(actual_propagation_credentials) == 2
@ -98,7 +98,7 @@ def test_propagation_credentials_endpoint__post_stolen(flask_client, credentials
assert resp.status_code == HTTPStatus.NO_CONTENT assert resp.status_code == HTTPStatus.NO_CONTENT
resp = flask_client.get(url) resp = flask_client.get(url)
retrieved_propagation_credentials = Credentials.from_json_array(resp.text) retrieved_propagation_credentials = [Credentials.from_mapping(creds) for creds in resp.json]
assert resp.status_code == HTTPStatus.OK assert resp.status_code == HTTPStatus.OK
assert len(retrieved_propagation_credentials) == 3 assert len(retrieved_propagation_credentials) == 3

View File

@ -1,10 +1,12 @@
import json
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
import bson import bson
from monkey_island.cc.services.representations import _normalize_value from common.utils import IJSONSerializable
from monkey_island.cc.services.representations import APIEncoder
@dataclass @dataclass
@ -17,22 +19,24 @@ bogus_object1 = MockClass(1)
bogus_object2 = MockClass(2) bogus_object2 = MockClass(2)
def test_normalize_dicts(): def test_api_encoder_dicts():
assert {} == _normalize_value({}) assert json.dumps({}) == json.dumps({}, cls=APIEncoder)
assert {"a": "a"} == _normalize_value({"a": "a"}) assert json.dumps({"a": "a"}) == json.dumps({"a": "a"}, cls=APIEncoder)
assert {"id": 12345} == _normalize_value({"id": 12345}) assert json.dumps({"id": 12345}) == json.dumps({"id": 12345}, cls=APIEncoder)
assert {"id": obj_id_str} == _normalize_value({"id": bson.objectid.ObjectId(obj_id_str)}) assert json.dumps({"id": obj_id_str}) == json.dumps(
{"id": bson.objectid.ObjectId(obj_id_str)}, cls=APIEncoder
)
dt = datetime.now() dt = datetime.now()
expected = {"a": str(dt)} expected = {"a": str(dt)}
result = _normalize_value({"a": dt}) result = json.dumps({"a": dt}, cls=APIEncoder)
assert expected == result assert json.dumps(expected) == result
def test_normalize_complex(): def test_api_encoder_complex():
bogus_dict = { bogus_dict = {
"a": [ "a": [
{ {
@ -44,26 +48,47 @@ def test_normalize_complex():
} }
expected_dict = {"a": [{"ba": obj_id_str, "bb": obj_id_str}], "b": {"id": obj_id_str}} expected_dict = {"a": [{"ba": obj_id_str, "bb": obj_id_str}], "b": {"id": obj_id_str}}
assert expected_dict == _normalize_value(bogus_dict) assert json.dumps(expected_dict) == json.dumps(bogus_dict, cls=APIEncoder)
def test_normalize_list(): def test_api_encoder_list():
bogus_list = [bson.objectid.ObjectId(obj_id_str), {"a": "b"}, {"object": [bogus_object1]}] bogus_list = [bson.objectid.ObjectId(obj_id_str), {"a": "b"}, {"object": [bogus_object1]}]
expected_list = [obj_id_str, {"a": "b"}, {"object": [{"a": 1}]}] expected_list = [obj_id_str, {"a": "b"}, {"object": [{"a": 1}]}]
assert expected_list == _normalize_value(bogus_list) assert json.dumps(expected_list) == json.dumps(bogus_list, cls=APIEncoder)
def test_normalize_enum(): def test_api_encoder_enum():
class BogusEnum(Enum): class BogusEnum(Enum):
bogus_val = "Bogus" bogus_val = "Bogus"
my_obj = {"something": "something", "my_enum": BogusEnum.bogus_val} my_obj = {"something": "something", "my_enum": BogusEnum.bogus_val}
assert {"something": "something", "my_enum": "bogus_val"} == _normalize_value(my_obj) assert json.dumps({"something": "something", "my_enum": "bogus_val"}) == json.dumps(
my_obj, cls=APIEncoder
)
def test_normalize_tuple(): def test_api_encoder_tuple():
bogus_tuple = [{"my_tuple": (bogus_object1, bogus_object2, b"one_two")}] bogus_tuple = [{"my_tuple": (bogus_object1, bogus_object2, "string")}]
expected_tuple = [{"my_tuple": ({"a": 1}, {"a": 2}, b"one_two")}] expected_tuple = [{"my_tuple": ({"a": 1}, {"a": 2}, "string")}]
assert expected_tuple == _normalize_value(bogus_tuple) assert json.dumps(expected_tuple) == json.dumps(bogus_tuple, cls=APIEncoder)
class BogusSerializableClass(IJSONSerializable):
def __init__(self, a):
self.a = a
@classmethod
def to_json(cls, class_object: IJSONSerializable) -> str:
return json.dumps({"wacky": class_object.a})
@classmethod
def from_json(cls, json_string: str) -> IJSONSerializable:
pass
def test_api_encoder_json_serializable():
bogus_data = {"target": [BogusSerializableClass("macky")]}
expected_result = {"target": [{"wacky": "macky"}]}
assert json.dumps(expected_result) == json.dumps(bogus_data, cls=APIEncoder)

View File

@ -4,7 +4,6 @@ dead or is kept deliberately. Referencing these in a file like this makes sure t
Vulture doesn't mark these as dead again. Vulture doesn't mark these as dead again.
""" """
from infection_monkey.exploit.log4shell_utils.ldap_server import LDAPServerFactory from infection_monkey.exploit.log4shell_utils.ldap_server import LDAPServerFactory
from monkey_island.cc import app
from monkey_island.cc.models import Report from monkey_island.cc.models import Report
from monkey_island.cc.models.networkmap import Arc, NetworkMap from monkey_island.cc.models.networkmap import Arc, NetworkMap
from monkey_island.cc.repository.attack.IMitigationsRepository import IMitigationsRepository from monkey_island.cc.repository.attack.IMitigationsRepository import IMitigationsRepository
@ -207,6 +206,12 @@ _make_simulation # unused method (monkey/monkey_island/cc/models/simulation.py:
# TODO DELETE AFTER RESOURCE REFACTORING # TODO DELETE AFTER RESOURCE REFACTORING
# https://github.com/jendrikseipp/vulture/issues/287
# Both happen in common\utils\IJSONSerializable.py
json_string
class_object
NetworkMap NetworkMap
Arc.dst_machine Arc.dst_machine
IMitigationsRepository.get_mitigations IMitigationsRepository.get_mitigations