Extracted api representations hooks to separate file, added UT, and fixed linter issue

use `x not in y` instead of `not x in y`.
This commit is contained in:
Shay Nehmad 2019-10-28 10:28:40 +02:00
parent 59a779822b
commit 348a743619
3 changed files with 86 additions and 30 deletions

View File

@ -1,11 +1,8 @@
import os
import uuid
from datetime import datetime
import bson
import flask_restful
from bson.json_util import dumps
from flask import Flask, send_from_directory, make_response, Response
from flask import Flask, send_from_directory, Response
from werkzeug.exceptions import NotFound
from monkey_island.cc.auth import init_jwt
@ -24,6 +21,7 @@ from monkey_island.cc.resources.netmap import NetMap
from monkey_island.cc.resources.node import Node
from monkey_island.cc.resources.remote_run import RemoteRun
from monkey_island.cc.resources.reporting.report import Report
from monkey_island.cc.resources.representations import output_json
from monkey_island.cc.resources.root import Root
from monkey_island.cc.resources.telemetry import Telemetry
from monkey_island.cc.resources.telemetry_feed import TelemetryFeed
@ -62,32 +60,6 @@ def serve_home():
return serve_static_file(HOME_FILE)
def normalize_obj(obj):
if '_id' in obj and not 'id' in obj:
obj['id'] = obj['_id']
del obj['_id']
for key, value in list(obj.items()):
if isinstance(value, bson.objectid.ObjectId):
obj[key] = str(value)
if isinstance(value, datetime):
obj[key] = str(value)
if isinstance(value, dict):
obj[key] = normalize_obj(value)
if isinstance(value, list):
for i in range(0, len(value)):
if isinstance(value[i], dict):
value[i] = normalize_obj(value[i])
return obj
def output_json(obj, code, headers=None):
obj = normalize_obj(obj)
resp = make_response(dumps(obj), code)
resp.headers.extend(headers or {})
return resp
def init_app_config(app, mongo_url):
app.config['MONGO_URI'] = mongo_url
app.config['SECRET_KEY'] = str(uuid.getnode())

View File

@ -0,0 +1,31 @@
from datetime import datetime
import bson
from bson.json_util import dumps
from flask import make_response
def normalize_obj(obj):
if ('_id' in obj) and ('id' not in obj):
obj['id'] = obj['_id']
del obj['_id']
for key, value in list(obj.items()):
if isinstance(value, bson.objectid.ObjectId):
obj[key] = str(value)
if isinstance(value, datetime):
obj[key] = str(value)
if isinstance(value, dict):
obj[key] = normalize_obj(value)
if isinstance(value, list):
for i in range(0, len(value)):
if isinstance(value[i], dict):
value[i] = normalize_obj(value[i])
return obj
def output_json(obj, code, headers=None):
obj = normalize_obj(obj)
resp = make_response(dumps(obj), code)
resp.headers.extend(headers or {})
return resp

View File

@ -0,0 +1,53 @@
from unittest import TestCase
from datetime import datetime
from .representations import normalize_obj
import bson
class TestJsonRepresentations(TestCase):
def test_normalize_obj(self):
# empty
self.assertEqual({}, normalize_obj({}))
# no special content
self.assertEqual(
{"a": "a"},
normalize_obj({"a": "a"})
)
# _id field -> id field
self.assertEqual(
{"id": 12345},
normalize_obj({"_id": 12345})
)
# obj id field -> str
obj_id_str = "123456789012345678901234"
self.assertEqual(
{"id": obj_id_str},
normalize_obj({"_id": bson.objectid.ObjectId(obj_id_str)})
)
# datetime -> str
dt = datetime.now()
expected = {"a": str(dt)}
result = normalize_obj({"a": dt})
self.assertEqual(expected, result)
# dicts and lists
self.assertEqual({
"a": [
{"ba": obj_id_str,
"bb": obj_id_str}
],
"b": {"id": obj_id_str}
},
normalize_obj({
"a": [
{"ba": bson.objectid.ObjectId(obj_id_str),
"bb": bson.objectid.ObjectId(obj_id_str)}
],
"b": {"_id": bson.objectid.ObjectId(obj_id_str)}
})
)