Merge pull request #758 from shreyamalviya/pba-threading

Run post-breach phase in separate thread
This commit is contained in:
VakarisZ 2020-08-11 17:05:39 +03:00 committed by GitHub
commit 62c4eeb3fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 9 deletions

View File

@ -4,6 +4,7 @@ import os
import subprocess
import sys
import time
from threading import Thread
import infection_monkey.tunnel as tunnel
from common.network.network_utils import get_host_from_network_location
@ -138,9 +139,9 @@ class InfectionMonkey(object):
StateTelem(is_done=False, version=get_version()).send()
TunnelTelem().send()
LOG.debug("Starting the post-breach phase.")
self.collect_system_info_if_configured()
PostBreach().execute_all_configured()
LOG.debug("Starting the post-breach phase asynchronously.")
post_breach_phase = Thread(target=self.start_post_breach_phase)
post_breach_phase.start()
LOG.debug("Starting the propagation phase.")
self.shutdown_by_max_depth_reached()
@ -230,10 +231,17 @@ class InfectionMonkey(object):
if monkey_tunnel:
monkey_tunnel.stop()
monkey_tunnel.join()
post_breach_phase.join()
except PlannedShutdownException:
LOG.info("A planned shutdown of the Monkey occurred. Logging the reason and finishing execution.")
LOG.exception("Planned shutdown, reason:")
def start_post_breach_phase(self):
self.collect_system_info_if_configured()
PostBreach().execute_all_configured()
def shutdown_by_max_depth_reached(self):
if 0 == WormConfiguration.depth:
TraceTelem(MAX_DEPTH_REACHED_MESSAGE).send()

View File

@ -1,4 +1,5 @@
import logging
from multiprocessing.dummy import Pool
from typing import Sequence
from infection_monkey.post_breach.pba import PBA
@ -24,12 +25,8 @@ class PostBreach(object):
"""
Executes all post breach actions.
"""
for pba in self.pba_list:
try:
LOG.debug("Executing PBA: '{}'".format(pba.name))
pba.run()
except Exception as e:
LOG.error("PBA {} failed. Error info: {}".format(pba.name, e))
pool = Pool(5)
pool.map(self.run_pba, self.pba_list)
LOG.info("All PBAs executed. Total {} executed.".format(len(self.pba_list)))
@staticmethod
@ -38,3 +35,10 @@ class PostBreach(object):
:return: A list of PBA objects.
"""
return PBA.get_instances()
def run_pba(self, pba):
try:
LOG.debug("Executing PBA: '{}'".format(pba.name))
pba.run()
except Exception as e:
LOG.error("PBA {} failed. Error info: {}".format(pba.name, e))