diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_exploiter.py b/monkey/tests/unit_tests/infection_monkey/master/test_exploiter.py index b2c42f1ec..26067ab22 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_exploiter.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_exploiter.py @@ -66,12 +66,20 @@ def get_credentials_for_propagation(): return CREDENTIALS_FOR_PROPAGATION -def test_exploiter(exploiter_config, callback, scan_completed, stop, hosts, hosts_to_exploit): - # Set this so that Exploiter() exits once it has processed all victims - scan_completed.set() +@pytest.fixture +def run_exploiters(exploiter_config, hosts_to_exploit, callback, scan_completed, stop): + def inner(puppet, num_workers): + # Set this so that Exploiter() exits once it has processed all victims + scan_completed.set() - e = Exploiter(MockPuppet(), 2, get_credentials_for_propagation) - e.exploit_hosts(exploiter_config, hosts_to_exploit, callback, scan_completed, stop) + e = Exploiter(puppet, num_workers, get_credentials_for_propagation) + e.exploit_hosts(exploiter_config, hosts_to_exploit, callback, scan_completed, stop) + + return inner + + +def test_exploiter(callback, hosts, hosts_to_exploit, run_exploiters): + run_exploiters(MockPuppet(), 2) assert callback.call_count == 5 host_exploit_combos = set() @@ -88,15 +96,9 @@ def test_exploiter(exploiter_config, callback, scan_completed, stop, hosts, host assert ("SSHExploiter", hosts[1]) in host_exploit_combos -def test_credentials_passed_to_exploiter( - exploiter_config, callback, scan_completed, stop, hosts, hosts_to_exploit -): +def test_credentials_passed_to_exploiter(run_exploiters): mock_puppet = MagicMock() - # Set this so that Exploiter() exits once it has processed all victims - scan_completed.set() - - e = Exploiter(mock_puppet, 2, get_credentials_for_propagation) - e.exploit_hosts(exploiter_config, hosts_to_exploit, callback, scan_completed, stop) + run_exploiters(mock_puppet, 1) for call_args in mock_puppet.exploit_host.call_args_list: assert call_args[0][2].get("credentials") == CREDENTIALS_FOR_PROPAGATION