Agent: Modify AutomatedMaster to handle propagation config options

This commit is contained in:
Mike Salvatore 2021-12-10 09:32:29 -05:00
parent 80707dac8e
commit 75cfa252c9
3 changed files with 52 additions and 69 deletions

View File

@ -121,7 +121,9 @@ class AutomatedMaster(IMaster):
# system_info_collector_thread.join() # system_info_collector_thread.join()
if self._can_propagate(): if self._can_propagate():
propagation_thread = create_daemon_thread(target=self._propagate, args=(config,)) propagation_thread = create_daemon_thread(
target=self._propagate, args=(config["propagation"],)
)
propagation_thread.start() propagation_thread.start()
propagation_thread.join() propagation_thread.join()
@ -160,14 +162,12 @@ class AutomatedMaster(IMaster):
return True return True
# TODO: Refactor propagation into its own class # TODO: Refactor propagation into its own class
def _propagate(self, config: Dict): def _propagate(self, propagation_config: Dict):
logger.info("Attempting to propagate") logger.info("Attempting to propagate")
self._hosts_to_exploit = Queue() self._hosts_to_exploit = Queue()
scan_thread = create_daemon_thread( scan_thread = create_daemon_thread(target=self._scan_network, args=(propagation_config,))
target=self._scan_network, args=(config["network_scan"],)
)
exploit_thread = create_daemon_thread(target=self._exploit_targets, args=(scan_thread,)) exploit_thread = create_daemon_thread(target=self._exploit_targets, args=(scan_thread,))
scan_thread.start() scan_thread.start()
@ -178,19 +178,14 @@ class AutomatedMaster(IMaster):
logger.info("Finished attempting to propagate") logger.info("Finished attempting to propagate")
def _scan_network(self, scan_config: Dict): def _scan_network(self, propagation_config: Dict):
logger.info("Starting network scan") logger.info("Starting network scan")
# TODO: Generate list of IPs to scan # TODO: Generate list of IPs to scan
ips_to_scan = [f"10.0.0.{i}" for i in range(1, 255)] ips_to_scan = [f"10.0.0.{i}" for i in range(1, 255)]
self._ip_scanner.scan( scan_config = propagation_config["network_scan"]
ips_to_scan, self._ip_scanner.scan(ips_to_scan, scan_config, self._handle_scanned_host, self._stop)
scan_config["icmp"],
scan_config["tcp"],
self._handle_scanned_host,
self._stop,
)
logger.info("Finished network scan") logger.info("Finished network scan")

View File

@ -20,26 +20,14 @@ class IPScanner:
self._puppet = puppet self._puppet = puppet
self._num_workers = num_workers self._num_workers = num_workers
def scan( def scan(self, ips_to_scan: List[str], options: Dict, results_callback: Callback, stop: Event):
self, # Pre-fill a Queue with all IPs to scan so that threads know they can safely exit when the
ips: List[str], # queue is empty.
icmp_config: Dict, ips = Queue()
tcp_config: Dict, for ip in ips_to_scan:
report_results_callback: Callback, ips.put(ip)
stop: Event,
):
# Pre-fill a Queue with all IPs so that threads can safely exit when the queue is empty.
ips_to_scan = Queue()
for ip in ips:
ips_to_scan.put(ip)
scan_ips_args = ( scan_ips_args = (ips, options, results_callback, stop)
ips_to_scan,
icmp_config,
tcp_config,
report_results_callback,
stop,
)
scan_threads = [] scan_threads = []
for i in range(0, self._num_workers): for i in range(0, self._num_workers):
t = create_daemon_thread(target=self._scan_ips, args=scan_ips_args) t = create_daemon_thread(target=self._scan_ips, args=scan_ips_args)
@ -49,27 +37,20 @@ class IPScanner:
for t in scan_threads: for t in scan_threads:
t.join() t.join()
def _scan_ips( def _scan_ips(self, ips: Queue, options: Dict, results_callback: Callback, stop: Event):
self,
ips_to_scan: Queue,
icmp_config: Dict,
tcp_config: Dict,
report_results_callback: Callback,
stop: Event,
):
logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}") logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}")
try: try:
while not stop.is_set(): while not stop.is_set():
ip = ips_to_scan.get_nowait() ip = ips.get_nowait()
logger.info(f"Scanning {ip}") logger.info(f"Scanning {ip}")
victim_host = VictimHost(ip) victim_host = VictimHost(ip)
self._ping_ip(ip, victim_host, icmp_config) self._ping_ip(ip, victim_host, options["icmp"])
self._scan_tcp_ports(ip, victim_host, tcp_config, stop) self._scan_tcp_ports(ip, victim_host, options["tcp"], stop)
report_results_callback(victim_host) results_callback(victim_host)
except queue.Empty: except queue.Empty:
logger.debug( logger.debug(

View File

@ -18,8 +18,9 @@ class MockPuppet(MockPuppet):
@pytest.fixture @pytest.fixture
def tcp_scan_config(): def scan_config():
return { return {
"tcp": {
"timeout_ms": 3000, "timeout_ms": 3000,
"ports": [ "ports": [
22, 22,
@ -29,13 +30,10 @@ def tcp_scan_config():
8008, 8008,
3306, 3306,
], ],
} },
"icmp": {
@pytest.fixture
def icmp_scan_config():
return {
"timeout_ms": 1000, "timeout_ms": 1000,
},
} }
@ -80,22 +78,22 @@ def assert_host_down(victim_host):
assert len(victim_host.services.keys()) == 0 assert len(victim_host.services.keys()) == 0
def test_scan_single_ip(callback, icmp_scan_config, tcp_scan_config, stop): def test_scan_single_ip(callback, scan_config, stop):
ips = ["10.0.0.1"] ips = ["10.0.0.1"]
ns = IPScanner(MockPuppet(), num_workers=1) ns = IPScanner(MockPuppet(), num_workers=1)
ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop) ns.scan(ips, scan_config, callback, stop)
callback.assert_called_once() callback.assert_called_once()
assert_dot_1(callback.call_args_list[0][0][0]) assert_dot_1(callback.call_args_list[0][0][0])
def test_scan_multiple_ips(callback, icmp_scan_config, tcp_scan_config, stop): def test_scan_multiple_ips(callback, scan_config, stop):
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"] ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"]
ns = IPScanner(MockPuppet(), num_workers=4) ns = IPScanner(MockPuppet(), num_workers=4)
ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop) ns.scan(ips, scan_config, callback, stop)
assert callback.call_count == 4 assert callback.call_count == 4
@ -105,7 +103,16 @@ def test_scan_multiple_ips(callback, icmp_scan_config, tcp_scan_config, stop):
assert_host_down(callback.call_args_list[3][0][0]) assert_host_down(callback.call_args_list[3][0][0])
def test_stop_after_callback(icmp_scan_config, tcp_scan_config, stop): def test_scan_lots_of_ips(callback, scan_config, stop):
ips = [f"10.0.0.{i}" for i in range(0, 255)]
ns = IPScanner(MockPuppet(), num_workers=4)
ns.scan(ips, scan_config, callback, stop)
assert callback.call_count == 255
def test_stop_after_callback(scan_config, stop):
def _callback(_): def _callback(_):
# Block all threads here until 2 threads reach this barrier, then set stop # Block all threads here until 2 threads reach this barrier, then set stop
# and test that niether thread continues to scan. # and test that niether thread continues to scan.
@ -119,12 +126,12 @@ def test_stop_after_callback(icmp_scan_config, tcp_scan_config, stop):
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"] ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"]
ns = IPScanner(MockPuppet(), num_workers=2) ns = IPScanner(MockPuppet(), num_workers=2)
ns.scan(ips, icmp_scan_config, tcp_scan_config, stopable_callback, stop) ns.scan(ips, scan_config, stopable_callback, stop)
assert stopable_callback.call_count == 2 assert stopable_callback.call_count == 2
def test_interrupt_port_scanning(callback, icmp_scan_config, tcp_scan_config, stop): def test_interrupt_port_scanning(callback, scan_config, stop):
def stopable_scan_tcp_port(port, _, __): def stopable_scan_tcp_port(port, _, __):
# Block all threads here until 2 threads reach this barrier, then set stop # Block all threads here until 2 threads reach this barrier, then set stop
# and test that niether thread scans any more ports # and test that niether thread scans any more ports
@ -141,6 +148,6 @@ def test_interrupt_port_scanning(callback, icmp_scan_config, tcp_scan_config, st
ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"] ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"]
ns = IPScanner(puppet, num_workers=2) ns = IPScanner(puppet, num_workers=2)
ns.scan(ips, icmp_scan_config, tcp_scan_config, callback, stop) ns.scan(ips, scan_config, callback, stop)
assert puppet.scan_tcp_port.call_count == 2 assert puppet.scan_tcp_port.call_count == 2