Refactored utility method and added unit test

This commit is contained in:
VakarisZ 2021-01-04 17:04:31 +02:00
parent eaf9b6a8d1
commit 28601d97ed
7 changed files with 23 additions and 23 deletions

View File

@ -1,9 +1,6 @@
import re import re
from urllib.parse import urlparse from urllib.parse import urlparse
from infection_monkey.config import WormConfiguration
from infection_monkey.network.tools import is_running_on_server
def get_host_from_network_location(network_location: str) -> str: def get_host_from_network_location(network_location: str) -> str:
""" """
@ -21,9 +18,3 @@ def remove_port(url):
with_port = f'{parsed.scheme}://{parsed.netloc}' with_port = f'{parsed.scheme}://{parsed.netloc}'
without_port = re.sub(':[0-9]+(?=$|\/)', '', with_port) without_port = re.sub(':[0-9]+(?=$|\/)', '', with_port)
return without_port return without_port
def is_running_on_island():
current_server_without_port = get_host_from_network_location(WormConfiguration.current_server)
running_on_island = is_running_on_server(current_server_without_port)
return running_on_island and WormConfiguration.depth == WormConfiguration.max_depth

View File

@ -2,7 +2,7 @@
# noinspection PyPep8Naming # noinspection PyPep8Naming
import operator import operator
from functools import reduce from functools import reduce
from typing import List from typing import List, Union, Any
class abstractstatic(staticmethod): class abstractstatic(staticmethod):
@ -15,13 +15,5 @@ class abstractstatic(staticmethod):
__isabstractmethod__ = True __isabstractmethod__ = True
def _get_value_by_path(data, path: List[str]): def get_dict_value_by_path(data: dict, path: List[str]) -> Any:
return reduce(operator.getitem, path, data) return reduce(operator.getitem, path, data)
def get_object_value_by_path(data_object: object, path: List[str]):
return _get_value_by_path(data_object, path)
def get_dict_value_by_path(data_dict: dict, path: List[str]):
return _get_value_by_path(data_dict, path)

View File

@ -0,0 +1,9 @@
import unittest
from common.utils.code_utils import get_dict_value_by_path
class TestCodeUtils(unittest.TestCase):
def test_get_dict_value_by_path(self):
dict_for_test = {'a': {'b': {'c': 'result'}}}
self.assertEqual(get_dict_value_by_path(dict_for_test, ['a', 'b', 'c']), 'result')

View File

@ -7,7 +7,7 @@ import time
from threading import Thread from threading import Thread
import infection_monkey.tunnel as tunnel import infection_monkey.tunnel as tunnel
from common.network.network_utils import is_running_on_island from infection_monkey.network.tools import is_running_on_island
from common.utils.attack_utils import ScanStatus, UsageEnum from common.utils.attack_utils import ScanStatus, UsageEnum
from common.utils.exceptions import ExploitingVulnerableMachineError, FailedExploitationError from common.utils.exceptions import ExploitingVulnerableMachineError, FailedExploitationError
from common.version import get_version from common.version import get_version

View File

@ -7,6 +7,8 @@ import subprocess
import sys import sys
import time import time
from common.network.network_utils import get_host_from_network_location
from infection_monkey.config import WormConfiguration
from infection_monkey.network.info import get_routes, local_ips from infection_monkey.network.info import get_routes, local_ips
from infection_monkey.pyinstaller_utils import get_binary_file_path from infection_monkey.pyinstaller_utils import get_binary_file_path
from infection_monkey.utils.environment import is_64bit_python from infection_monkey.utils.environment import is_64bit_python
@ -311,5 +313,11 @@ def get_interface_to_target(dst):
return ret[1] return ret[1]
def is_running_on_island():
current_server_without_port = get_host_from_network_location(WormConfiguration.current_server)
running_on_island = is_running_on_server(current_server_without_port)
return running_on_island and WormConfiguration.depth == WormConfiguration.max_depth
def is_running_on_server(ip: str) -> bool: def is_running_on_server(ip: str) -> bool:
return ip in local_ips() return ip in local_ips()

View File

@ -3,7 +3,7 @@ import logging
from common.cloud.aws.aws_instance import AwsInstance from common.cloud.aws.aws_instance import AwsInstance
from common.cloud.scoutsuite_consts import CloudProviders from common.cloud.scoutsuite_consts import CloudProviders
from common.common_consts.system_info_collectors_names import AWS_COLLECTOR from common.common_consts.system_info_collectors_names import AWS_COLLECTOR
from common.network.network_utils import is_running_on_island from infection_monkey.network.tools import is_running_on_island
from infection_monkey.system_info.collectors.scoutsuite_collector.scoutsuite_collector import scan_cloud_security from infection_monkey.system_info.collectors.scoutsuite_collector.scoutsuite_collector import scan_cloud_security
from infection_monkey.system_info.system_info_collector import SystemInfoCollector from infection_monkey.system_info.system_info_collector import SystemInfoCollector

View File

@ -1,4 +1,4 @@
from common.utils.code_utils import get_object_value_by_path from common.utils.code_utils import get_dict_value_by_path
from common.utils.exceptions import RulePathCreatorNotFound from common.utils.exceptions import RulePathCreatorNotFound
from monkey_island.cc.services.zero_trust.scoutsuite.data_parsing.rule_path_building.rule_path_creators_list import \ from monkey_island.cc.services.zero_trust.scoutsuite.data_parsing.rule_path_building.rule_path_creators_list import \
RULE_PATH_CREATORS_LIST RULE_PATH_CREATORS_LIST
@ -9,7 +9,7 @@ class RuleParser:
@staticmethod @staticmethod
def get_rule_data(scoutsuite_data, rule_name): def get_rule_data(scoutsuite_data, rule_name):
rule_path = RuleParser.get_rule_path(rule_name) rule_path = RuleParser.get_rule_path(rule_name)
return get_object_value_by_path(scoutsuite_data, rule_path) return get_dict_value_by_path(data=scoutsuite_data, path=rule_path)
@staticmethod @staticmethod
def get_rule_path(rule_name): def get_rule_path(rule_name):