Refactored utility method and added unit test
This commit is contained in:
parent
eaf9b6a8d1
commit
28601d97ed
|
@ -1,9 +1,6 @@
|
|||
import re
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from infection_monkey.config import WormConfiguration
|
||||
from infection_monkey.network.tools import is_running_on_server
|
||||
|
||||
|
||||
def get_host_from_network_location(network_location: str) -> str:
|
||||
"""
|
||||
|
@ -21,9 +18,3 @@ def remove_port(url):
|
|||
with_port = f'{parsed.scheme}://{parsed.netloc}'
|
||||
without_port = re.sub(':[0-9]+(?=$|\/)', '', with_port)
|
||||
return without_port
|
||||
|
||||
|
||||
def is_running_on_island():
|
||||
current_server_without_port = get_host_from_network_location(WormConfiguration.current_server)
|
||||
running_on_island = is_running_on_server(current_server_without_port)
|
||||
return running_on_island and WormConfiguration.depth == WormConfiguration.max_depth
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# noinspection PyPep8Naming
|
||||
import operator
|
||||
from functools import reduce
|
||||
from typing import List
|
||||
from typing import List, Union, Any
|
||||
|
||||
|
||||
class abstractstatic(staticmethod):
|
||||
|
@ -15,13 +15,5 @@ class abstractstatic(staticmethod):
|
|||
__isabstractmethod__ = True
|
||||
|
||||
|
||||
def _get_value_by_path(data, path: List[str]):
|
||||
def get_dict_value_by_path(data: dict, path: List[str]) -> Any:
|
||||
return reduce(operator.getitem, path, data)
|
||||
|
||||
|
||||
def get_object_value_by_path(data_object: object, path: List[str]):
|
||||
return _get_value_by_path(data_object, path)
|
||||
|
||||
|
||||
def get_dict_value_by_path(data_dict: dict, path: List[str]):
|
||||
return _get_value_by_path(data_dict, path)
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
import unittest
|
||||
|
||||
from common.utils.code_utils import get_dict_value_by_path
|
||||
|
||||
|
||||
class TestCodeUtils(unittest.TestCase):
|
||||
def test_get_dict_value_by_path(self):
|
||||
dict_for_test = {'a': {'b': {'c': 'result'}}}
|
||||
self.assertEqual(get_dict_value_by_path(dict_for_test, ['a', 'b', 'c']), 'result')
|
|
@ -7,7 +7,7 @@ import time
|
|||
from threading import Thread
|
||||
|
||||
import infection_monkey.tunnel as tunnel
|
||||
from common.network.network_utils import is_running_on_island
|
||||
from infection_monkey.network.tools import is_running_on_island
|
||||
from common.utils.attack_utils import ScanStatus, UsageEnum
|
||||
from common.utils.exceptions import ExploitingVulnerableMachineError, FailedExploitationError
|
||||
from common.version import get_version
|
||||
|
|
|
@ -7,6 +7,8 @@ import subprocess
|
|||
import sys
|
||||
import time
|
||||
|
||||
from common.network.network_utils import get_host_from_network_location
|
||||
from infection_monkey.config import WormConfiguration
|
||||
from infection_monkey.network.info import get_routes, local_ips
|
||||
from infection_monkey.pyinstaller_utils import get_binary_file_path
|
||||
from infection_monkey.utils.environment import is_64bit_python
|
||||
|
@ -311,5 +313,11 @@ def get_interface_to_target(dst):
|
|||
return ret[1]
|
||||
|
||||
|
||||
def is_running_on_island():
|
||||
current_server_without_port = get_host_from_network_location(WormConfiguration.current_server)
|
||||
running_on_island = is_running_on_server(current_server_without_port)
|
||||
return running_on_island and WormConfiguration.depth == WormConfiguration.max_depth
|
||||
|
||||
|
||||
def is_running_on_server(ip: str) -> bool:
|
||||
return ip in local_ips()
|
||||
|
|
|
@ -3,7 +3,7 @@ import logging
|
|||
from common.cloud.aws.aws_instance import AwsInstance
|
||||
from common.cloud.scoutsuite_consts import CloudProviders
|
||||
from common.common_consts.system_info_collectors_names import AWS_COLLECTOR
|
||||
from common.network.network_utils import is_running_on_island
|
||||
from infection_monkey.network.tools import is_running_on_island
|
||||
from infection_monkey.system_info.collectors.scoutsuite_collector.scoutsuite_collector import scan_cloud_security
|
||||
from infection_monkey.system_info.system_info_collector import SystemInfoCollector
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from common.utils.code_utils import get_object_value_by_path
|
||||
from common.utils.code_utils import get_dict_value_by_path
|
||||
from common.utils.exceptions import RulePathCreatorNotFound
|
||||
from monkey_island.cc.services.zero_trust.scoutsuite.data_parsing.rule_path_building.rule_path_creators_list import \
|
||||
RULE_PATH_CREATORS_LIST
|
||||
|
@ -9,7 +9,7 @@ class RuleParser:
|
|||
@staticmethod
|
||||
def get_rule_data(scoutsuite_data, rule_name):
|
||||
rule_path = RuleParser.get_rule_path(rule_name)
|
||||
return get_object_value_by_path(scoutsuite_data, rule_path)
|
||||
return get_dict_value_by_path(data=scoutsuite_data, path=rule_path)
|
||||
|
||||
@staticmethod
|
||||
def get_rule_path(rule_name):
|
||||
|
|
Loading…
Reference in New Issue