Agent, UT: Implement VictimHostFactory

Implements and unit tests the VictimHostFactory. The factory allows creation of victims based on current network situation of the agent
This commit is contained in:
vakarisz 2021-12-16 18:09:00 +02:00 committed by Mike Salvatore
parent ddd8a0e53a
commit 29d3cc2aaf
5 changed files with 152 additions and 33 deletions

View File

@ -89,7 +89,7 @@ class Propagator:
)
def _process_scan_results(self, address: NetworkAddress, scan_results: IPScanResults):
victim_host = self._victim_host_factory.build_victim_host(address.ip, address.domain)
victim_host = self._victim_host_factory.build_victim_host(address)
Propagator._process_ping_scan_results(victim_host, scan_results.ping_scan_data)
Propagator._process_tcp_scan_results(victim_host, scan_results.port_scan_data)

View File

@ -1,28 +1,45 @@
import logging
from typing import Optional
from infection_monkey.model import VictimHost
from infection_monkey.network import NetworkAddress
from infection_monkey.network.tools import get_interface_to_target
from infection_monkey.tunnel import MonkeyTunnel
logger = logging.getLogger(__name__)
class VictimHostFactory:
def __init__(self):
pass
def __init__(
self,
tunnel: Optional[MonkeyTunnel],
default_server: Optional[str],
default_port: Optional[str],
on_island: bool,
):
self.tunnel = tunnel
self.default_server = default_server
self.default_port = default_port
self.on_island = on_island
def build_victim_host(self, ip: str, domain: str):
victim_host = VictimHost(ip, domain)
def build_victim_host(self, network_address: NetworkAddress) -> VictimHost:
victim_host = VictimHost(network_address.ip, network_address.domain)
# TODO: Reimplement the below logic from the old monkey.py
"""
if self._monkey_tunnel:
self._monkey_tunnel.set_tunnel_for_host(machine)
if self._default_server:
if self._network.on_island(self._default_server):
machine.set_default_server(
get_interface_to_target(machine.ip_addr)
+ (":" + self._default_server_port if self._default_server_port else "")
)
else:
machine.set_default_server(self._default_server)
logger.debug(
f"Default server for machine: {machine} set to {machine.default_server}"
)
"""
if self.tunnel:
victim_host.default_tunnel = self.tunnel.get_tunnel_for_ip(victim_host.ip_addr)
if self.default_server:
if self.on_island:
victim_host.set_default_server(
get_interface_to_target(victim_host.ip_addr)
+ (":" + self.default_port if self.default_port else "")
)
else:
victim_host.set_default_server(self.default_server)
logger.debug(
f"Default server for machine: {victim_host} set to {victim_host.default_server}"
)
logger.debug(
f"Default tunnel for machine: {victim_host} set to {victim_host.default_tunnel}"
)
return victim_host

View File

@ -4,7 +4,6 @@ import struct
import time
from threading import Thread
from infection_monkey.model import VictimHost
from infection_monkey.network.firewall import app as firewall
from infection_monkey.network.info import get_free_tcp_port, local_ips
from infection_monkey.network.tools import check_tcp_port, get_interface_to_target
@ -188,14 +187,13 @@ class MonkeyTunnel(Thread):
proxy.stop()
proxy.join()
def set_tunnel_for_host(self, host):
assert isinstance(host, VictimHost)
def get_tunnel_for_ip(self, ip: str):
if not self.local_port:
return
ip_match = get_interface_to_target(host.ip_addr)
host.default_tunnel = "%s:%d" % (ip_match, self.local_port)
ip_match = get_interface_to_target(ip)
return "%s:%d" % (ip_match, self.local_port)
def stop(self):
self._stopped = True

View File

@ -11,9 +11,26 @@ from infection_monkey.i_puppet import (
PortStatus,
)
from infection_monkey.master import IPScanResults, Propagator
from infection_monkey.model import VictimHostFactory
from infection_monkey.network import NetworkInterface
from infection_monkey.telemetry.exploit_telem import ExploitTelem
from infection_monkey.model import VictimHost, VictimHostFactory
from infection_monkey.network import NetworkAddress
@pytest.fixture
def mock_victim_host_factory():
class MockVictimHostFactory(VictimHostFactory):
def __init__(self):
pass
def build_victim_host(self, network_address: NetworkAddress) -> VictimHost:
domain = network_address.domain or ""
return VictimHost(network_address.ip, domain)
return MockVictimHostFactory()
empty_fingerprint_data = FingerprintData(None, None, {})
@ -111,9 +128,9 @@ class StubExploiter:
pass
def test_scan_result_processing(telemetry_messenger_spy, mock_ip_scanner):
def test_scan_result_processing(telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory):
p = Propagator(
telemetry_messenger_spy, mock_ip_scanner, StubExploiter(), VictimHostFactory(), []
telemetry_messenger_spy, mock_ip_scanner, StubExploiter(), mock_victim_host_factory, []
)
p.propagate(
{
@ -201,9 +218,11 @@ class MockExploiter:
)
def test_exploiter_result_processing(telemetry_messenger_spy, mock_ip_scanner):
def test_exploiter_result_processing(
telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory
):
p = Propagator(
telemetry_messenger_spy, mock_ip_scanner, MockExploiter(), VictimHostFactory(), []
telemetry_messenger_spy, mock_ip_scanner, MockExploiter(), mock_victim_host_factory, []
)
p.propagate(
{
@ -240,13 +259,13 @@ def test_exploiter_result_processing(telemetry_messenger_spy, mock_ip_scanner):
assert data["result"]
def test_scan_target_generation(telemetry_messenger_spy, mock_ip_scanner):
def test_scan_target_generation(telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory):
local_network_interfaces = [NetworkInterface("10.0.0.9", "/29")]
p = Propagator(
telemetry_messenger_spy,
mock_ip_scanner,
StubExploiter(),
VictimHostFactory(),
mock_victim_host_factory,
local_network_interfaces,
)
p.propagate(

View File

@ -0,0 +1,85 @@
from unittest.mock import MagicMock
import pytest
from infection_monkey.model import VictimHostFactory
from infection_monkey.network.scan_target_generator import NetworkAddress
@pytest.fixture
def mock_tunnel():
tunnel = MagicMock()
tunnel.get_tunnel_for_ip = lambda _: "1.2.3.4:1234"
return tunnel
@pytest.fixture(autouse=True)
def mock_get_interface_to_target(monkeypatch):
monkeypatch.setattr(
"infection_monkey.model.victim_host_factory.get_interface_to_target", lambda _: "1.1.1.1"
)
def test_factory_no_tunnel():
factory = VictimHostFactory(
tunnel=None, default_server="192.168.56.1", default_port="5000", on_island=False
)
network_address = NetworkAddress("192.168.56.2", None)
victim = factory.build_victim_host(network_address)
assert victim.default_server == "192.168.56.1"
assert victim.ip_addr == "192.168.56.2"
assert victim.default_tunnel is None
assert victim.domain_name == ""
def test_factory_with_tunnel(mock_tunnel):
factory = VictimHostFactory(
tunnel=mock_tunnel, default_server="192.168.56.1", default_port="5000", on_island=False
)
network_address = NetworkAddress("192.168.56.2", None)
victim = factory.build_victim_host(network_address)
assert victim.default_server == "192.168.56.1"
assert victim.ip_addr == "192.168.56.2"
assert victim.default_tunnel == "1.2.3.4:1234"
assert victim.domain_name == ""
def test_factory_on_island(mock_tunnel):
factory = VictimHostFactory(
tunnel=mock_tunnel, default_server="192.168.56.1", default_port="99", on_island=True
)
network_address = NetworkAddress("192.168.56.2", "www.bogus.monkey")
victim = factory.build_victim_host(network_address)
assert victim.default_server == "1.1.1.1:99"
assert victim.domain_name == "www.bogus.monkey"
assert victim.ip_addr == "192.168.56.2"
assert victim.default_tunnel == "1.2.3.4:1234"
@pytest.mark.parametrize("default_port", ["", None])
def test_factory_no_port(mock_tunnel, default_port):
factory = VictimHostFactory(
tunnel=mock_tunnel, default_server="192.168.56.1", default_port=default_port, on_island=True
)
network_address = NetworkAddress("192.168.56.2", "www.bogus.monkey")
victim = factory.build_victim_host(network_address)
assert victim.default_server == "1.1.1.1"
def test_factory_no_default_server(mock_tunnel):
factory = VictimHostFactory(
tunnel=mock_tunnel, default_server=None, default_port="", on_island=True
)
network_address = NetworkAddress("192.168.56.2", "www.bogus.monkey")
victim = factory.build_victim_host(network_address)
assert victim.default_server is None