diff --git a/monkey/monkey_island/cc/resources/remote_run.py b/monkey/monkey_island/cc/resources/remote_run.py index d1c1149b5..f872fc07b 100644 --- a/monkey/monkey_island/cc/resources/remote_run.py +++ b/monkey/monkey_island/cc/resources/remote_run.py @@ -1,4 +1,5 @@ import json +from typing import Sequence import flask_restful from botocore.exceptions import ClientError, NoCredentialsError @@ -6,6 +7,7 @@ from flask import jsonify, make_response, request from monkey_island.cc.resources.auth.auth import jwt_required from monkey_island.cc.services import AWSService +from monkey_island.cc.services.aws import AWSCommandResults CLIENT_ERROR_FORMAT = ( "ClientError, error message: '{}'. Probably, the IAM role that has been associated with the " @@ -21,7 +23,7 @@ class RemoteRun(flask_restful.Resource): def __init__(self, aws_service: AWSService): self._aws_service = aws_service - def run_aws_monkeys(self, request_body): + def run_aws_monkeys(self, request_body) -> Sequence[AWSCommandResults]: instances = request_body.get("instances") island_ip = request_body.get("island_ip") diff --git a/monkey/monkey_island/cc/services/aws/aws_service.py b/monkey/monkey_island/cc/services/aws/aws_service.py index cf4acd89d..8518f864d 100644 --- a/monkey/monkey_island/cc/services/aws/aws_service.py +++ b/monkey/monkey_island/cc/services/aws/aws_service.py @@ -1,10 +1,13 @@ import logging +from queue import Queue +from threading import Thread from typing import Any, Iterable, Mapping, Sequence import boto3 import botocore from common.aws.aws_instance import AWSInstance +from common.utils.code_utils import queue_to_list from .aws_command_runner import AWSCommandResults, start_infection_monkey_agent @@ -78,20 +81,28 @@ class AWSService: :return: A sequence of AWSCommandResults """ - results = [] - # TODO: Use threadpool or similar to run these in parallel (daemon threads) + results_queue = Queue() + command_threads = [] for i in instances: - results.append( - self._run_agent_on_managed_instance(i["instance_id"], i["os"], island_ip) + command_threads.append( + Thread( + target=self._run_agent_on_managed_instance, + args=(results_queue, i["instance_id"], i["os"], island_ip), + daemon=True, + ) ) - return results + for thread in command_threads: + thread.join() + + return queue_to_list(results_queue) def _run_agent_on_managed_instance( - self, instance_id: str, os: str, island_ip: str - ) -> AWSCommandResults: + self, results_queue: Queue, instance_id: str, os: str, island_ip: str + ): ssm_client = boto3.client("ssm", self.island_aws_instance.region) - return start_infection_monkey_agent(ssm_client, instance_id, os, island_ip) + command_results = start_infection_monkey_agent(ssm_client, instance_id, os, island_ip) + results_queue.put(command_results) def _filter_relevant_instance_info(raw_managed_instances_info: Sequence[Mapping[str, Any]]):