From 006c177abdb7ffbe006ceb0ad1aa7d2346406318 Mon Sep 17 00:00:00 2001
From: Shay Nehmad <shay.nehmad@guardicore.com>
Date: Wed, 2 Oct 2019 13:04:58 +0300
Subject: [PATCH] Added lock on report generation and improved the
 get_completed_steps method

---
 monkey/monkey_island/cc/resources/root.py | 47 ++++++++++++++++-------
 1 file changed, 33 insertions(+), 14 deletions(-)

diff --git a/monkey/monkey_island/cc/resources/root.py b/monkey/monkey_island/cc/resources/root.py
index e3b3e9854..f1914b3a2 100644
--- a/monkey/monkey_island/cc/resources/root.py
+++ b/monkey/monkey_island/cc/resources/root.py
@@ -1,5 +1,6 @@
 from datetime import datetime
 import logging
+import threading
 
 import flask_restful
 from flask import request, make_response, jsonify
@@ -18,13 +19,15 @@ logger = logging.getLogger(__name__)
 
 
 class Root(flask_restful.Resource):
+    def __init__(self):
+        self.report_generating_lock = threading.Event()
 
     def get(self, action=None):
         if not action:
             action = request.args.get('action')
 
         if not action:
-            return Root.get_server_info()
+            return self.get_server_info()
         elif action == "reset":
             return jwt_required()(Database.reset_db)()
         elif action == "killall":
@@ -34,11 +37,10 @@ class Root(flask_restful.Resource):
         else:
             return make_response(400, {'error': 'unknown action'})
 
-    @staticmethod
     @jwt_required()
-    def get_server_info():
+    def get_server_info(self):
         return jsonify(ip_addresses=local_ip_addresses(), mongo=str(mongo.db),
-                       completed_steps=Root.get_completed_steps())
+                       completed_steps=self.get_completed_steps())
 
     @staticmethod
     @jwt_required()
@@ -49,17 +51,34 @@ class Root(flask_restful.Resource):
         logger.info('Kill all monkeys was called')
         return jsonify(status='OK')
 
-    @staticmethod
     @jwt_required()
-    def get_completed_steps():
+    def get_completed_steps(self):
         is_any_exists = NodeService.is_any_monkey_exists()
         infection_done = NodeService.is_monkey_finished_running()
-        if not infection_done:
-            report_done = False
-        else:
-            if is_any_exists:
-                ReportService.get_report()
-                AttackReportService.get_latest_report()
+
+        if infection_done:
+            if self.should_generate_report():
+                self.generate_report()
             report_done = ReportService.is_report_generated()
-        return dict(run_server=True, run_monkey=is_any_exists, infection_done=infection_done,
-                    report_done=report_done)
+        else:  # Infection is not done
+            report_done = False
+
+        return dict(
+            run_server=True,
+            run_monkey=is_any_exists,
+            infection_done=infection_done,
+            report_done=report_done)
+
+    def generate_report(self):
+        # Set the event - enter the critical section
+        self.report_generating_lock.set()
+        # Not using the return value, as the get_report function also saves the report in the DB for later.
+        _ = ReportService.get_report()
+        _ = AttackReportService.get_latest_report()
+        # Clear the event - exit the critical section
+        self.report_generating_lock.clear()
+
+    def should_generate_report(self):
+        # If the lock is not set, that means no one is generating a report right now.
+        is_any_thread_generating_a_report_right_now = not self.report_generating_lock.is_set()
+        return is_any_thread_generating_a_report_right_now and not ReportService.is_latest_report_exists()