diff --git a/monkey/infection_monkey/master/propagator.py b/monkey/infection_monkey/master/propagator.py index ca6922b37..b3eb7faf9 100644 --- a/monkey/infection_monkey/master/propagator.py +++ b/monkey/infection_monkey/master/propagator.py @@ -89,7 +89,7 @@ class Propagator: ) def _process_scan_results(self, address: NetworkAddress, scan_results: IPScanResults): - victim_host = self._victim_host_factory.build_victim_host(address.ip, address.domain) + victim_host = self._victim_host_factory.build_victim_host(address) Propagator._process_ping_scan_results(victim_host, scan_results.ping_scan_data) Propagator._process_tcp_scan_results(victim_host, scan_results.port_scan_data) diff --git a/monkey/infection_monkey/model/victim_host_factory.py b/monkey/infection_monkey/model/victim_host_factory.py index 775bb8baf..09ef8e98e 100644 --- a/monkey/infection_monkey/model/victim_host_factory.py +++ b/monkey/infection_monkey/model/victim_host_factory.py @@ -1,28 +1,45 @@ +import logging +from typing import Optional + from infection_monkey.model import VictimHost +from infection_monkey.network import NetworkAddress +from infection_monkey.network.tools import get_interface_to_target +from infection_monkey.tunnel import MonkeyTunnel + +logger = logging.getLogger(__name__) class VictimHostFactory: - def __init__(self): - pass + def __init__( + self, + tunnel: Optional[MonkeyTunnel], + default_server: Optional[str], + default_port: Optional[str], + on_island: bool, + ): + self.tunnel = tunnel + self.default_server = default_server + self.default_port = default_port + self.on_island = on_island - def build_victim_host(self, ip: str, domain: str): - victim_host = VictimHost(ip, domain) + def build_victim_host(self, network_address: NetworkAddress) -> VictimHost: + victim_host = VictimHost(network_address.ip, network_address.domain) - # TODO: Reimplement the below logic from the old monkey.py - """ - if self._monkey_tunnel: - self._monkey_tunnel.set_tunnel_for_host(machine) - if self._default_server: - if self._network.on_island(self._default_server): - machine.set_default_server( - get_interface_to_target(machine.ip_addr) - + (":" + self._default_server_port if self._default_server_port else "") - ) - else: - machine.set_default_server(self._default_server) - logger.debug( - f"Default server for machine: {machine} set to {machine.default_server}" - ) - """ + if self.tunnel: + victim_host.default_tunnel = self.tunnel.get_tunnel_for_ip(victim_host.ip_addr) + if self.default_server: + if self.on_island: + victim_host.set_default_server( + get_interface_to_target(victim_host.ip_addr) + + (":" + self.default_port if self.default_port else "") + ) + else: + victim_host.set_default_server(self.default_server) + logger.debug( + f"Default server for machine: {victim_host} set to {victim_host.default_server}" + ) + logger.debug( + f"Default tunnel for machine: {victim_host} set to {victim_host.default_tunnel}" + ) return victim_host diff --git a/monkey/infection_monkey/tunnel.py b/monkey/infection_monkey/tunnel.py index f39069daf..4aa90e80f 100644 --- a/monkey/infection_monkey/tunnel.py +++ b/monkey/infection_monkey/tunnel.py @@ -4,7 +4,6 @@ import struct import time from threading import Thread -from infection_monkey.model import VictimHost from infection_monkey.network.firewall import app as firewall from infection_monkey.network.info import get_free_tcp_port, local_ips from infection_monkey.network.tools import check_tcp_port, get_interface_to_target @@ -188,14 +187,13 @@ class MonkeyTunnel(Thread): proxy.stop() proxy.join() - def set_tunnel_for_host(self, host): - assert isinstance(host, VictimHost) + def get_tunnel_for_ip(self, ip: str): if not self.local_port: return - ip_match = get_interface_to_target(host.ip_addr) - host.default_tunnel = "%s:%d" % (ip_match, self.local_port) + ip_match = get_interface_to_target(ip) + return "%s:%d" % (ip_match, self.local_port) def stop(self): self._stopped = True diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_propagator.py b/monkey/tests/unit_tests/infection_monkey/master/test_propagator.py index 8fa0204c2..745e075fa 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_propagator.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_propagator.py @@ -11,9 +11,26 @@ from infection_monkey.i_puppet import ( PortStatus, ) from infection_monkey.master import IPScanResults, Propagator -from infection_monkey.model import VictimHostFactory from infection_monkey.network import NetworkInterface from infection_monkey.telemetry.exploit_telem import ExploitTelem +from infection_monkey.model import VictimHost, VictimHostFactory +from infection_monkey.network import NetworkAddress + + + + +@pytest.fixture +def mock_victim_host_factory(): + class MockVictimHostFactory(VictimHostFactory): + def __init__(self): + pass + + def build_victim_host(self, network_address: NetworkAddress) -> VictimHost: + domain = network_address.domain or "" + return VictimHost(network_address.ip, domain) + + return MockVictimHostFactory() + empty_fingerprint_data = FingerprintData(None, None, {}) @@ -111,9 +128,9 @@ class StubExploiter: pass -def test_scan_result_processing(telemetry_messenger_spy, mock_ip_scanner): +def test_scan_result_processing(telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory): p = Propagator( - telemetry_messenger_spy, mock_ip_scanner, StubExploiter(), VictimHostFactory(), [] + telemetry_messenger_spy, mock_ip_scanner, StubExploiter(), mock_victim_host_factory, [] ) p.propagate( { @@ -201,9 +218,11 @@ class MockExploiter: ) -def test_exploiter_result_processing(telemetry_messenger_spy, mock_ip_scanner): +def test_exploiter_result_processing( + telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory +): p = Propagator( - telemetry_messenger_spy, mock_ip_scanner, MockExploiter(), VictimHostFactory(), [] + telemetry_messenger_spy, mock_ip_scanner, MockExploiter(), mock_victim_host_factory, [] ) p.propagate( { @@ -240,13 +259,13 @@ def test_exploiter_result_processing(telemetry_messenger_spy, mock_ip_scanner): assert data["result"] -def test_scan_target_generation(telemetry_messenger_spy, mock_ip_scanner): +def test_scan_target_generation(telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory): local_network_interfaces = [NetworkInterface("10.0.0.9", "/29")] p = Propagator( telemetry_messenger_spy, mock_ip_scanner, StubExploiter(), - VictimHostFactory(), + mock_victim_host_factory, local_network_interfaces, ) p.propagate( diff --git a/monkey/tests/unit_tests/infection_monkey/model/test_victim_host_factory.py b/monkey/tests/unit_tests/infection_monkey/model/test_victim_host_factory.py new file mode 100644 index 000000000..2b5250c8c --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/model/test_victim_host_factory.py @@ -0,0 +1,85 @@ +from unittest.mock import MagicMock + +import pytest + +from infection_monkey.model import VictimHostFactory +from infection_monkey.network.scan_target_generator import NetworkAddress + + +@pytest.fixture +def mock_tunnel(): + tunnel = MagicMock() + tunnel.get_tunnel_for_ip = lambda _: "1.2.3.4:1234" + return tunnel + + +@pytest.fixture(autouse=True) +def mock_get_interface_to_target(monkeypatch): + monkeypatch.setattr( + "infection_monkey.model.victim_host_factory.get_interface_to_target", lambda _: "1.1.1.1" + ) + + +def test_factory_no_tunnel(): + factory = VictimHostFactory( + tunnel=None, default_server="192.168.56.1", default_port="5000", on_island=False + ) + network_address = NetworkAddress("192.168.56.2", None) + + victim = factory.build_victim_host(network_address) + + assert victim.default_server == "192.168.56.1" + assert victim.ip_addr == "192.168.56.2" + assert victim.default_tunnel is None + assert victim.domain_name == "" + + +def test_factory_with_tunnel(mock_tunnel): + factory = VictimHostFactory( + tunnel=mock_tunnel, default_server="192.168.56.1", default_port="5000", on_island=False + ) + network_address = NetworkAddress("192.168.56.2", None) + + victim = factory.build_victim_host(network_address) + + assert victim.default_server == "192.168.56.1" + assert victim.ip_addr == "192.168.56.2" + assert victim.default_tunnel == "1.2.3.4:1234" + assert victim.domain_name == "" + + +def test_factory_on_island(mock_tunnel): + factory = VictimHostFactory( + tunnel=mock_tunnel, default_server="192.168.56.1", default_port="99", on_island=True + ) + network_address = NetworkAddress("192.168.56.2", "www.bogus.monkey") + + victim = factory.build_victim_host(network_address) + + assert victim.default_server == "1.1.1.1:99" + assert victim.domain_name == "www.bogus.monkey" + assert victim.ip_addr == "192.168.56.2" + assert victim.default_tunnel == "1.2.3.4:1234" + + +@pytest.mark.parametrize("default_port", ["", None]) +def test_factory_no_port(mock_tunnel, default_port): + factory = VictimHostFactory( + tunnel=mock_tunnel, default_server="192.168.56.1", default_port=default_port, on_island=True + ) + network_address = NetworkAddress("192.168.56.2", "www.bogus.monkey") + + victim = factory.build_victim_host(network_address) + + assert victim.default_server == "1.1.1.1" + + +def test_factory_no_default_server(mock_tunnel): + factory = VictimHostFactory( + tunnel=mock_tunnel, default_server=None, default_port="", on_island=True + ) + network_address = NetworkAddress("192.168.56.2", "www.bogus.monkey") + + victim = factory.build_victim_host(network_address) + + assert victim.default_server is None