From 348a74361977a67ebb8eca62fb1a90923302ca82 Mon Sep 17 00:00:00 2001 From: Shay Nehmad Date: Mon, 28 Oct 2019 10:28:40 +0200 Subject: [PATCH] Extracted api representations hooks to separate file, added UT, and fixed linter issue use `x not in y` instead of `not x in y`. --- monkey/monkey_island/cc/app.py | 32 +---------- .../cc/resources/representations.py | 31 +++++++++++ .../cc/resources/representations_test.py | 53 +++++++++++++++++++ 3 files changed, 86 insertions(+), 30 deletions(-) create mode 100644 monkey/monkey_island/cc/resources/representations.py create mode 100644 monkey/monkey_island/cc/resources/representations_test.py diff --git a/monkey/monkey_island/cc/app.py b/monkey/monkey_island/cc/app.py index 38af31bde..8ab61c895 100644 --- a/monkey/monkey_island/cc/app.py +++ b/monkey/monkey_island/cc/app.py @@ -1,11 +1,8 @@ import os import uuid -from datetime import datetime -import bson import flask_restful -from bson.json_util import dumps -from flask import Flask, send_from_directory, make_response, Response +from flask import Flask, send_from_directory, Response from werkzeug.exceptions import NotFound from monkey_island.cc.auth import init_jwt @@ -24,6 +21,7 @@ from monkey_island.cc.resources.netmap import NetMap from monkey_island.cc.resources.node import Node from monkey_island.cc.resources.remote_run import RemoteRun from monkey_island.cc.resources.reporting.report import Report +from monkey_island.cc.resources.representations import output_json from monkey_island.cc.resources.root import Root from monkey_island.cc.resources.telemetry import Telemetry from monkey_island.cc.resources.telemetry_feed import TelemetryFeed @@ -62,32 +60,6 @@ def serve_home(): return serve_static_file(HOME_FILE) -def normalize_obj(obj): - if '_id' in obj and not 'id' in obj: - obj['id'] = obj['_id'] - del obj['_id'] - - for key, value in list(obj.items()): - if isinstance(value, bson.objectid.ObjectId): - obj[key] = str(value) - if isinstance(value, datetime): - obj[key] = str(value) - if isinstance(value, dict): - obj[key] = normalize_obj(value) - if isinstance(value, list): - for i in range(0, len(value)): - if isinstance(value[i], dict): - value[i] = normalize_obj(value[i]) - return obj - - -def output_json(obj, code, headers=None): - obj = normalize_obj(obj) - resp = make_response(dumps(obj), code) - resp.headers.extend(headers or {}) - return resp - - def init_app_config(app, mongo_url): app.config['MONGO_URI'] = mongo_url app.config['SECRET_KEY'] = str(uuid.getnode()) diff --git a/monkey/monkey_island/cc/resources/representations.py b/monkey/monkey_island/cc/resources/representations.py new file mode 100644 index 000000000..cd804db50 --- /dev/null +++ b/monkey/monkey_island/cc/resources/representations.py @@ -0,0 +1,31 @@ +from datetime import datetime + +import bson +from bson.json_util import dumps +from flask import make_response + + +def normalize_obj(obj): + if ('_id' in obj) and ('id' not in obj): + obj['id'] = obj['_id'] + del obj['_id'] + + for key, value in list(obj.items()): + if isinstance(value, bson.objectid.ObjectId): + obj[key] = str(value) + if isinstance(value, datetime): + obj[key] = str(value) + if isinstance(value, dict): + obj[key] = normalize_obj(value) + if isinstance(value, list): + for i in range(0, len(value)): + if isinstance(value[i], dict): + value[i] = normalize_obj(value[i]) + return obj + + +def output_json(obj, code, headers=None): + obj = normalize_obj(obj) + resp = make_response(dumps(obj), code) + resp.headers.extend(headers or {}) + return resp diff --git a/monkey/monkey_island/cc/resources/representations_test.py b/monkey/monkey_island/cc/resources/representations_test.py new file mode 100644 index 000000000..714c70ed2 --- /dev/null +++ b/monkey/monkey_island/cc/resources/representations_test.py @@ -0,0 +1,53 @@ +from unittest import TestCase +from datetime import datetime +from .representations import normalize_obj + +import bson + + +class TestJsonRepresentations(TestCase): + def test_normalize_obj(self): + # empty + self.assertEqual({}, normalize_obj({})) + + # no special content + self.assertEqual( + {"a": "a"}, + normalize_obj({"a": "a"}) + ) + + # _id field -> id field + self.assertEqual( + {"id": 12345}, + normalize_obj({"_id": 12345}) + ) + + # obj id field -> str + obj_id_str = "123456789012345678901234" + self.assertEqual( + {"id": obj_id_str}, + normalize_obj({"_id": bson.objectid.ObjectId(obj_id_str)}) + ) + + # datetime -> str + dt = datetime.now() + expected = {"a": str(dt)} + result = normalize_obj({"a": dt}) + self.assertEqual(expected, result) + + # dicts and lists + self.assertEqual({ + "a": [ + {"ba": obj_id_str, + "bb": obj_id_str} + ], + "b": {"id": obj_id_str} + }, + normalize_obj({ + "a": [ + {"ba": bson.objectid.ObjectId(obj_id_str), + "bb": bson.objectid.ObjectId(obj_id_str)} + ], + "b": {"_id": bson.objectid.ObjectId(obj_id_str)} + }) + )