diff --git a/monkey/common/cmd/aws_cmd_runner.py b/monkey/common/cmd/aws_cmd_runner.py index 1b2d01e42..10927cb81 100644 --- a/monkey/common/cmd/aws_cmd_runner.py +++ b/monkey/common/cmd/aws_cmd_runner.py @@ -1,11 +1,11 @@ import time import logging -import json from common.cloud.aws_service import AwsService from common.cmd.aws_cmd_result import AwsCmdResult from common.cmd.cmd_result import CmdResult from common.cmd.cmd_runner import CmdRunner +from common.cmd.cmd_status import CmdStatus __author__ = 'itay.mizeretz' @@ -23,33 +23,21 @@ class AwsCmdRunner(CmdRunner): self.region = region self.ssm = AwsService.get_client('ssm', region) - def run_command(self, command, timeout=CmdRunner.DEFAULT_TIMEOUT): - # TODO: document - command_id = self._send_command(command) - init_time = time.time() - curr_time = init_time - command_info = None + def query_command(self, command_id): + return self.ssm.get_command_invocation(CommandId=command_id, InstanceId=self.instance_id) - try: - while curr_time - init_time < timeout: - command_info = self.ssm.get_command_invocation(CommandId=command_id, InstanceId=self.instance_id) - if AwsCmdResult.is_successful(command_info): - break - else: - time.sleep(0.5) - curr_time = time.time() + def get_command_result(self, command_info): + return AwsCmdResult(command_info) - cmd_res = AwsCmdResult(command_info) + def get_command_status(self, command_info): + if command_info[u'Status'] == u'InProgress': + return CmdStatus.IN_PROGRESS + elif command_info[u'Status'] == u'Success': + return CmdStatus.SUCCESS + else: + return CmdStatus.FAILURE - if not cmd_res.is_success: - logger.error('Failed running AWS command: `%s`. status code: %s', command, str(cmd_res.status_code)) - - return cmd_res - except Exception: - logger.exception('Exception while running AWS command: `%s`', command) - return CmdResult(False) - - def _send_command(self, command): + def run_command_async(self, command): doc_name = "AWS-RunShellScript" if self.is_linux else "AWS-RunPowerShellScript" command_res = self.ssm.send_command(DocumentName=doc_name, Parameters={'commands': [command]}, InstanceIds=[self.instance_id]) diff --git a/monkey/common/cmd/cmd_runner.py b/monkey/common/cmd/cmd_runner.py index 1875b7d4e..c0541cc0b 100644 --- a/monkey/common/cmd/cmd_runner.py +++ b/monkey/common/cmd/cmd_runner.py @@ -1,7 +1,14 @@ +import time +import logging from abc import abstractmethod +from common.cmd.cmd_result import CmdResult +from common.cmd.cmd_status import CmdStatus + __author__ = 'itay.mizeretz' +logger = logging.getLogger(__name__) + class CmdRunner(object): """ @@ -10,11 +17,12 @@ class CmdRunner(object): # Default command timeout in seconds DEFAULT_TIMEOUT = 5 + # Time to sleep when waiting on commands. + WAIT_SLEEP_TIME = 1 def __init__(self, is_linux): self.is_linux = is_linux - @abstractmethod def run_command(self, command, timeout=DEFAULT_TIMEOUT): """ Runs the given command on the remote machine @@ -22,20 +30,94 @@ class CmdRunner(object): :param timeout: Timeout in seconds for command. :return: Command result """ + c_id = self.run_command_async(command) + self.wait_commands([(self, c_id)], timeout) + + @abstractmethod + def run_command_async(self, command): + """ + Runs the given command on the remote machine asynchronously. + :param command: The command to run + :return: Command ID (in any format) + """ raise NotImplementedError() - def is_64bit(self): + @staticmethod + def wait_commands(commands, timeout=DEFAULT_TIMEOUT): """ - Runs a command to determine whether OS is 32 or 64 bit. - :return: True if 64bit, False if 32bit, None if failed. + Waits on all commands up to given timeout + :param commands: list of tuples of command IDs and command runners + :param timeout: Timeout in seconds for command. + :return: commands' results (tuple of """ - if self.is_linux: - cmd_result = self.run_command('uname -m') - if not cmd_result.is_success: - return None - return cmd_result.stdout.find('i686') == -1 # i686 means 32bit - else: - cmd_result = self.run_command('Get-ChildItem Env:') - if not cmd_result.is_success: - return None - return cmd_result.stdout.lower().find('programfiles(x86)') != -1 # if not found it means 32bit + init_time = time.time() + curr_time = init_time + + results = [] + + while (curr_time - init_time < timeout) and (len(commands) != 0): + for command in list(commands): + CmdRunner._process_command(command, commands, results, True) + + time.sleep(CmdRunner.WAIT_SLEEP_TIME) + curr_time = time.time() + + for command in list(commands): + CmdRunner._process_command(command, commands, results, False) + + for command, result in results: + if not result.is_success: + logger.error('The following command failed: `%s`. status code: %s', + str(command[1]), str(result.status_code)) + + return results + + @abstractmethod + def query_command(self, command_id): + """ + Queries the already run command for more info + :param command_id: The command ID to query + :return: Command info (in any format) + """ + raise NotImplementedError() + + @abstractmethod + def get_command_result(self, command_info): + """ + Gets the result of the already run command + :param command_info: The command info of the command to get the result of + :return: CmdResult + """ + raise NotImplementedError() + + @abstractmethod + def get_command_status(self, command_info): + """ + Gets the status of the already run command + :param command_info: The command info of the command to get the result of + :return: CmdStatus + """ + raise NotImplementedError() + + @staticmethod + def _process_command(command, commands, results, should_process_only_finished): + """ + Removes the command from the list, processes its result and appends to results + :param command: Command to process. Must be in commands. + :param commands: List of unprocessed commands. + :param results: List of command results. + :param should_process_only_finished: If True, processes only if command finished. + :return: None + """ + c_runner = command[0] + c_id = command[1] + try: + command_info = c_runner.query_command(c_id) + if (not should_process_only_finished) or c_runner.get_command_status(command_info) != CmdStatus.IN_PROGRESS: + commands.remove(command) + results.append((command, c_runner.get_command_result(command_info))) + except Exception: + logger.exception('Exception while querying command: `%s`', str(c_id)) + if not should_process_only_finished: + commands.remove(command) + results.append((command, CmdResult(False))) diff --git a/monkey/common/cmd/cmd_status.py b/monkey/common/cmd/cmd_status.py new file mode 100644 index 000000000..2fc9cc168 --- /dev/null +++ b/monkey/common/cmd/cmd_status.py @@ -0,0 +1,9 @@ +from enum import Enum + +__author__ = 'itay.mizeretz' + + +class CmdStatus(Enum): + IN_PROGRESS = 0 + SUCCESS = 1 + FAILURE = 2 diff --git a/monkey/monkey_island/cc/resources/remote_run.py b/monkey/monkey_island/cc/resources/remote_run.py index d4ebbed0b..b4da8eaf2 100644 --- a/monkey/monkey_island/cc/resources/remote_run.py +++ b/monkey/monkey_island/cc/resources/remote_run.py @@ -7,6 +7,7 @@ from cc.services.config import ConfigService from common.cloud.aws_instance import AwsInstance from common.cloud.aws_service import AwsService from common.cmd.aws_cmd_runner import AwsCmdRunner +from common.cmd.cmd_runner import CmdRunner class RemoteRun(flask_restful.Resource): @@ -18,50 +19,84 @@ class RemoteRun(flask_restful.Resource): self.init_aws_auth_params() instances = request_body.get('instances') island_ip = request_body.get('island_ip') + instances_bitness = self.get_bitness(instances) + return self.run_multiple_commands( + instances, + lambda instance: self.run_aws_monkey_cmd_async(instance['instance_id'], + instance['os'], island_ip, instances_bitness[instance['instance_id']]), + lambda _, result: result.is_success) - results = {} + def run_multiple_commands(self, instances, inst_to_cmd, inst_n_cmd_res_to_res): + command_instance_dict = {} for instance in instances: - is_success = self.run_aws_monkey_cmd(instance['instance_id'], instance['os'], island_ip) - results[instance['instance_id']] = is_success + command = inst_to_cmd(instance) + command_instance_dict[command] = instance - return results + instance_results = {} + results = CmdRunner.wait_commands(command_instance_dict.keys()) + for command, result in results: + instance = command_instance_dict[command] + instance_results[instance['instance_id']] = inst_n_cmd_res_to_res(instance, result) - def run_aws_monkey_cmd(self, instance_id, os, island_ip): + return instance_results + + def get_bitness(self, instances): + return self.run_multiple_commands( + instances, + lambda instance: RemoteRun.run_aws_bitness_cmd_async(instance['instance_id'], instance['os']), + lambda instance, result: self.get_bitness_by_result('linux' == instance['os'], result)) + + def get_bitness_by_result(self, is_linux, result): + if not result.is_success: + return None + elif is_linux: + return result.stdout.find('i686') == -1 # i686 means 32bit + else: + return result.stdout.lower().find('programfiles(x86)') != -1 # if not found it means 32bit + + @staticmethod + def run_aws_bitness_cmd_async(instance_id, os): + """ + Runs an AWS command to check bitness + :param instance_id: Instance ID of target + :param os: OS of target ('linux' or 'windows') + :return: Tuple of CmdRunner and command id + """ + is_linux = ('linux' == os) + cmd = AwsCmdRunner(instance_id, None, is_linux) + cmd_text = 'uname -m' if is_linux else 'Get-ChildItem Env:' + return cmd, cmd.run_command_async(cmd_text) + + def run_aws_monkey_cmd_async(self, instance_id, os, island_ip, is_64bit): """ Runs a monkey remotely using AWS :param instance_id: Instance ID of target :param os: OS of target ('linux' or 'windows') :param island_ip: IP of the island which the instance will try to connect to - :return: True if successfully ran monkey, False otherwise. + :param is_64bit: Whether the instance is 64bit + :return: Tuple of CmdRunner and command id """ is_linux = ('linux' == os) cmd = AwsCmdRunner(instance_id, None, is_linux) - is_64bit = cmd.is_64bit() - cmd_text = self._get_run_monkey_cmd(is_linux, is_64bit, island_ip) - return cmd.run_command(cmd_text).is_success + cmd_text = self._get_run_monkey_cmd_line(is_linux, is_64bit, island_ip) + return cmd, cmd.run_command_async(cmd_text) - def _get_run_monkey_cmd_linux(self, bit_text, island_ip): + def _get_run_monkey_cmd_linux_line(self, bit_text, island_ip): return r'wget --no-check-certificate https://' + island_ip + r':5000/api/monkey/download/monkey-linux-' + \ bit_text + r'; chmod +x monkey-linux-' + bit_text + r'; ./monkey-linux-' + bit_text + r' m0nk3y -s ' + \ island_ip + r':5000' - """ - return r'curl -O -k https://' + island_ip + r':5000/api/monkey/download/monkey-linux-' + bit_text + \ - r'; chmod +x monkey-linux-' + bit_text + \ - r'; ./monkey-linux-' + bit_text + r' m0nk3y -s ' + \ - island_ip + r':5000' - """ - def _get_run_monkey_cmd_windows(self, bit_text, island_ip): + def _get_run_monkey_cmd_windows_line(self, bit_text, island_ip): return r"[System.Net.ServicePointManager]::ServerCertificateValidationCallback = {" \ r"$true}; (New-Object System.Net.WebClient).DownloadFile('https://" + island_ip + \ r":5000/api/monkey/download/monkey-windows-" + bit_text + r".exe','.\\monkey.exe'); " \ r";Start-Process -FilePath '.\\monkey.exe' -ArgumentList 'm0nk3y -s " + island_ip + r":5000'; " - def _get_run_monkey_cmd(self, is_linux, is_64bit, island_ip): + def _get_run_monkey_cmd_line(self, is_linux, is_64bit, island_ip): bit_text = '64' if is_64bit else '32' - return self._get_run_monkey_cmd_linux(bit_text, island_ip) if is_linux \ - else self._get_run_monkey_cmd_windows(bit_text, island_ip) + return self._get_run_monkey_cmd_linux_line(bit_text, island_ip) if is_linux \ + else self._get_run_monkey_cmd_windows_line(bit_text, island_ip) def init_aws_auth_params(self): access_key_id = ConfigService.get_config_value(['cnc', 'aws_config', 'aws_access_key_id'], False, True)