Added lock on report generation and improved the get_completed_steps method

This commit is contained in:
Shay Nehmad 2019-10-02 13:04:58 +03:00
parent 32e98fa418
commit 006c177abd
1 changed files with 33 additions and 14 deletions

View File

@ -1,5 +1,6 @@
from datetime import datetime from datetime import datetime
import logging import logging
import threading
import flask_restful import flask_restful
from flask import request, make_response, jsonify from flask import request, make_response, jsonify
@ -18,13 +19,15 @@ logger = logging.getLogger(__name__)
class Root(flask_restful.Resource): class Root(flask_restful.Resource):
def __init__(self):
self.report_generating_lock = threading.Event()
def get(self, action=None): def get(self, action=None):
if not action: if not action:
action = request.args.get('action') action = request.args.get('action')
if not action: if not action:
return Root.get_server_info() return self.get_server_info()
elif action == "reset": elif action == "reset":
return jwt_required()(Database.reset_db)() return jwt_required()(Database.reset_db)()
elif action == "killall": elif action == "killall":
@ -34,11 +37,10 @@ class Root(flask_restful.Resource):
else: else:
return make_response(400, {'error': 'unknown action'}) return make_response(400, {'error': 'unknown action'})
@staticmethod
@jwt_required() @jwt_required()
def get_server_info(): def get_server_info(self):
return jsonify(ip_addresses=local_ip_addresses(), mongo=str(mongo.db), return jsonify(ip_addresses=local_ip_addresses(), mongo=str(mongo.db),
completed_steps=Root.get_completed_steps()) completed_steps=self.get_completed_steps())
@staticmethod @staticmethod
@jwt_required() @jwt_required()
@ -49,17 +51,34 @@ class Root(flask_restful.Resource):
logger.info('Kill all monkeys was called') logger.info('Kill all monkeys was called')
return jsonify(status='OK') return jsonify(status='OK')
@staticmethod
@jwt_required() @jwt_required()
def get_completed_steps(): def get_completed_steps(self):
is_any_exists = NodeService.is_any_monkey_exists() is_any_exists = NodeService.is_any_monkey_exists()
infection_done = NodeService.is_monkey_finished_running() infection_done = NodeService.is_monkey_finished_running()
if not infection_done:
report_done = False if infection_done:
else: if self.should_generate_report():
if is_any_exists: self.generate_report()
ReportService.get_report()
AttackReportService.get_latest_report()
report_done = ReportService.is_report_generated() report_done = ReportService.is_report_generated()
return dict(run_server=True, run_monkey=is_any_exists, infection_done=infection_done, else: # Infection is not done
report_done = False
return dict(
run_server=True,
run_monkey=is_any_exists,
infection_done=infection_done,
report_done=report_done) report_done=report_done)
def generate_report(self):
# Set the event - enter the critical section
self.report_generating_lock.set()
# Not using the return value, as the get_report function also saves the report in the DB for later.
_ = ReportService.get_report()
_ = AttackReportService.get_latest_report()
# Clear the event - exit the critical section
self.report_generating_lock.clear()
def should_generate_report(self):
# If the lock is not set, that means no one is generating a report right now.
is_any_thread_generating_a_report_right_now = not self.report_generating_lock.is_set()
return is_any_thread_generating_a_report_right_now and not ReportService.is_latest_report_exists()