Refactored unittests to pytest on island code. Cleaned up test infrasctructure: moved common test files to /test_common

This commit is contained in:
VakarisZ 2021-01-20 15:31:42 +02:00
parent d31e9064c8
commit 2df889ee31
21 changed files with 173 additions and 198 deletions

View File

@ -1,30 +1,28 @@
from common.network.network_range import CidrRange from common.network.network_range import CidrRange
from common.network.segmentation_utils import get_ip_in_src_and_not_in_dst from common.network.segmentation_utils import get_ip_in_src_and_not_in_dst
from monkey_island.cc.testing.IslandTestCase import IslandTestCase
class TestSegmentationUtils(IslandTestCase): class TestSegmentationUtils:
def test_get_ip_in_src_and_not_in_dst(self): def test_get_ip_in_src_and_not_in_dst(self):
self.fail_if_not_testing_env()
source = CidrRange("1.1.1.0/24") source = CidrRange("1.1.1.0/24")
target = CidrRange("2.2.2.0/24") target = CidrRange("2.2.2.0/24")
# IP not in both # IP not in both
self.assertIsNone(get_ip_in_src_and_not_in_dst( assert get_ip_in_src_and_not_in_dst(
["3.3.3.3", "4.4.4.4"], source, target ["3.3.3.3", "4.4.4.4"], source, target
)) ) is None
# IP not in source, in target # IP not in source, in target
self.assertIsNone(get_ip_in_src_and_not_in_dst( assert (get_ip_in_src_and_not_in_dst(
["2.2.2.2"], source, target ["2.2.2.2"], source, target
)) )) is None
# IP in source, not in target # IP in source, not in target
self.assertIsNotNone(get_ip_in_src_and_not_in_dst( assert (get_ip_in_src_and_not_in_dst(
["8.8.8.8", "1.1.1.1"], source, target ["8.8.8.8", "1.1.1.1"], source, target
)) ))
# IP in both subnets # IP in both subnets
self.assertIsNone(get_ip_in_src_and_not_in_dst( assert (get_ip_in_src_and_not_in_dst(
["8.8.8.8", "1.1.1.1"], source, source ["8.8.8.8", "1.1.1.1"], source, source
)) )) is None

View File

@ -0,0 +1 @@
from .test_common.mongomock_fixtures import *

View File

@ -4,7 +4,7 @@ from typing import Dict
from unittest import TestCase from unittest import TestCase
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import monkey_island.cc.testing.environment.server_config_mocks as config_mocks import monkey_island.cc.test_common.environment.server_config_mocks as config_mocks
from common.utils.exceptions import (AlreadyRegisteredError, CredentialsNotRequiredError, from common.utils.exceptions import (AlreadyRegisteredError, CredentialsNotRequiredError,
InvalidRegistrationCredentialsError, RegistrationNotNeededError) InvalidRegistrationCredentialsError, RegistrationNotNeededError)
from monkey_island.cc.environment import Environment, EnvironmentConfig, UserCreds from monkey_island.cc.environment import Environment, EnvironmentConfig, UserCreds

View File

@ -5,7 +5,7 @@ from typing import Dict
from unittest import TestCase from unittest import TestCase
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import monkey_island.cc.testing.environment.server_config_mocks as config_mocks import monkey_island.cc.test_common.environment.server_config_mocks as config_mocks
from monkey_island.cc.server_utils.consts import MONKEY_ISLAND_ABS_PATH from monkey_island.cc.server_utils.consts import MONKEY_ISLAND_ABS_PATH
from monkey_island.cc.environment.environment_config import EnvironmentConfig from monkey_island.cc.environment.environment_config import EnvironmentConfig
from monkey_island.cc.environment.user_creds import UserCreds from monkey_island.cc.environment.user_creds import UserCreds

View File

@ -5,22 +5,15 @@ from time import sleep
import pytest import pytest
from monkey_island.cc.models.monkey import Monkey, MonkeyNotFoundError from monkey_island.cc.models.monkey import Monkey, MonkeyNotFoundError
from monkey_island.cc.testing.IslandTestCase import IslandTestCase
from .monkey_ttl import MonkeyTtl from .monkey_ttl import MonkeyTtl
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TestMonkey(IslandTestCase): class TestMonkey:
"""
Make sure to set server environment to `testing` in server_config.json!
Otherwise this will mess up your mongo instance and won't work.
Also, the working directory needs to be the working directory from which you usually run the island so the
server_config.json file is found and loaded.
"""
@pytest.mark.usefixtures('uses_database')
def test_is_dead(self): def test_is_dead(self):
# Arrange # Arrange
alive_monkey_ttl = MonkeyTtl.create_ttl_expire_in(30) alive_monkey_ttl = MonkeyTtl.create_ttl_expire_in(30)
@ -44,20 +37,22 @@ class TestMonkey(IslandTestCase):
dead_monkey.save() dead_monkey.save()
# act + assert # act + assert
self.assertTrue(dead_monkey.is_dead()) assert dead_monkey.is_dead()
self.assertTrue(mia_monkey.is_dead()) assert mia_monkey.is_dead()
self.assertFalse(alive_monkey.is_dead()) assert not alive_monkey.is_dead()
@pytest.mark.usefixtures('uses_database')
def test_ttl_renewal(self): def test_ttl_renewal(self):
# Arrange # Arrange
monkey = Monkey(guid=str(uuid.uuid4())) monkey = Monkey(guid=str(uuid.uuid4()))
monkey.save() monkey.save()
self.assertIsNone(monkey.ttl_ref) assert monkey.ttl_ref is None
# act + assert # act + assert
monkey.renew_ttl() monkey.renew_ttl()
self.assertIsNotNone(monkey.ttl_ref) assert monkey.ttl_ref
@pytest.mark.usefixtures('uses_database')
def test_get_single_monkey_by_id(self): def test_get_single_monkey_by_id(self):
# Arrange # Arrange
a_monkey = Monkey(guid=str(uuid.uuid4())) a_monkey = Monkey(guid=str(uuid.uuid4()))
@ -65,12 +60,13 @@ class TestMonkey(IslandTestCase):
# Act + assert # Act + assert
# Find the existing one # Find the existing one
self.assertIsNotNone(Monkey.get_single_monkey_by_id(a_monkey.id)) assert Monkey.get_single_monkey_by_id(a_monkey.id) is not None
# Raise on non-existent monkey # Raise on non-existent monkey
with pytest.raises(MonkeyNotFoundError) as _: with pytest.raises(MonkeyNotFoundError) as _:
_ = Monkey.get_single_monkey_by_id("abcdefabcdefabcdefabcdef") _ = Monkey.get_single_monkey_by_id("abcdefabcdefabcdefabcdef")
@pytest.mark.usefixtures('uses_database')
def test_get_os(self): def test_get_os(self):
linux_monkey = Monkey(guid=str(uuid.uuid4()), linux_monkey = Monkey(guid=str(uuid.uuid4()),
description="Linux shay-Virtual-Machine 4.15.0-50-generic #54-Ubuntu") description="Linux shay-Virtual-Machine 4.15.0-50-generic #54-Ubuntu")
@ -82,10 +78,11 @@ class TestMonkey(IslandTestCase):
windows_monkey.save() windows_monkey.save()
unknown_monkey.save() unknown_monkey.save()
self.assertEqual(1, len([m for m in Monkey.objects() if m.get_os() == "windows"])) assert 1 == len([m for m in Monkey.objects() if m.get_os() == "windows"])
self.assertEqual(1, len([m for m in Monkey.objects() if m.get_os() == "linux"])) assert 1 == len([m for m in Monkey.objects() if m.get_os() == "linux"])
self.assertEqual(1, len([m for m in Monkey.objects() if m.get_os() == "unknown"])) assert 1 == len([m for m in Monkey.objects() if m.get_os() == "unknown"])
@pytest.mark.usefixtures('uses_database')
def test_get_tunneled_monkeys(self): def test_get_tunneled_monkeys(self):
linux_monkey = Monkey(guid=str(uuid.uuid4()), linux_monkey = Monkey(guid=str(uuid.uuid4()),
description="Linux shay-Virtual-Machine") description="Linux shay-Virtual-Machine")
@ -103,8 +100,9 @@ class TestMonkey(IslandTestCase):
and unknown_monkey in tunneled_monkeys and unknown_monkey in tunneled_monkeys
and linux_monkey not in tunneled_monkeys and linux_monkey not in tunneled_monkeys
and len(tunneled_monkeys) == 2) and len(tunneled_monkeys) == 2)
self.assertTrue(test, "Tunneling test") assert test == "Tunneling test"
@pytest.mark.usefixtures('uses_database')
def test_get_label_by_id(self): def test_get_label_by_id(self):
hostname_example = "a_hostname" hostname_example = "a_hostname"
ip_example = "1.1.1.1" ip_example = "1.1.1.1"
@ -117,26 +115,26 @@ class TestMonkey(IslandTestCase):
logger.debug(id(Monkey.get_label_by_id)) logger.debug(id(Monkey.get_label_by_id))
cache_info_before_query = Monkey.get_label_by_id.storage.backend.cache_info() cache_info_before_query = Monkey.get_label_by_id.storage.backend.cache_info()
self.assertEqual(cache_info_before_query.hits, 0) assert cache_info_before_query.hits == 0
self.assertEqual(cache_info_before_query.misses, 0) assert cache_info_before_query.misses == 0
# not cached # not cached
label = Monkey.get_label_by_id(linux_monkey.id) label = Monkey.get_label_by_id(linux_monkey.id)
cache_info_after_query_1 = Monkey.get_label_by_id.storage.backend.cache_info() cache_info_after_query_1 = Monkey.get_label_by_id.storage.backend.cache_info()
self.assertEqual(cache_info_after_query_1.hits, 0) assert cache_info_after_query_1.hits == 0
self.assertEqual(cache_info_after_query_1.misses, 1) assert cache_info_after_query_1.misses == 1
logger.debug("1) ID: {} label: {}".format(linux_monkey.id, label)) logger.debug("1) ID: {} label: {}".format(linux_monkey.id, label))
self.assertIsNotNone(label) assert label is not None
self.assertIn(hostname_example, label) assert hostname_example in label
self.assertIn(ip_example, label) assert ip_example in label
# should be cached # should be cached
label = Monkey.get_label_by_id(linux_monkey.id) label = Monkey.get_label_by_id(linux_monkey.id)
logger.debug("2) ID: {} label: {}".format(linux_monkey.id, label)) logger.debug("2) ID: {} label: {}".format(linux_monkey.id, label))
cache_info_after_query_2 = Monkey.get_label_by_id.storage.backend.cache_info() cache_info_after_query_2 = Monkey.get_label_by_id.storage.backend.cache_info()
self.assertEqual(cache_info_after_query_2.hits, 1) assert cache_info_after_query_2.hits == 1
self.assertEqual(cache_info_after_query_2.misses, 1) assert cache_info_after_query_2.misses == 1
# set hostname deletes the id from the cache. # set hostname deletes the id from the cache.
linux_monkey.set_hostname("Another hostname") linux_monkey.set_hostname("Another hostname")
@ -147,24 +145,25 @@ class TestMonkey(IslandTestCase):
cache_info_after_query_3 = Monkey.get_label_by_id.storage.backend.cache_info() cache_info_after_query_3 = Monkey.get_label_by_id.storage.backend.cache_info()
logger.debug("Cache info: {}".format(str(cache_info_after_query_3))) logger.debug("Cache info: {}".format(str(cache_info_after_query_3)))
# still 1 hit only # still 1 hit only
self.assertEqual(cache_info_after_query_3.hits, 1) assert cache_info_after_query_3.hits == 1
self.assertEqual(cache_info_after_query_3.misses, 2) assert cache_info_after_query_3.misses == 2
@pytest.mark.usefixtures('uses_database')
def test_is_monkey(self): def test_is_monkey(self):
a_monkey = Monkey(guid=str(uuid.uuid4())) a_monkey = Monkey(guid=str(uuid.uuid4()))
a_monkey.save() a_monkey.save()
cache_info_before_query = Monkey.is_monkey.storage.backend.cache_info() cache_info_before_query = Monkey.is_monkey.storage.backend.cache_info()
self.assertEqual(cache_info_before_query.hits, 0) assert cache_info_before_query.hits == 0
# not cached # not cached
self.assertTrue(Monkey.is_monkey(a_monkey.id)) assert Monkey.is_monkey(a_monkey.id)
fake_id = "123456789012" fake_id = "123456789012"
self.assertFalse(Monkey.is_monkey(fake_id)) assert not Monkey.is_monkey(fake_id)
# should be cached # should be cached
self.assertTrue(Monkey.is_monkey(a_monkey.id)) assert Monkey.is_monkey(a_monkey.id)
self.assertFalse(Monkey.is_monkey(fake_id)) assert not Monkey.is_monkey(fake_id)
cache_info_after_query = Monkey.is_monkey.storage.backend.cache_info() cache_info_after_query = Monkey.is_monkey.storage.backend.cache_info()
self.assertEqual(cache_info_after_query.hits, 2) assert cache_info_after_query.hits == 2

View File

@ -1,20 +1,20 @@
import pytest
from mongoengine import ValidationError from mongoengine import ValidationError
import common.common_consts.zero_trust_consts as zero_trust_consts import common.common_consts.zero_trust_consts as zero_trust_consts
from monkey_island.cc.models.zero_trust.event import Event from monkey_island.cc.models.zero_trust.event import Event
from monkey_island.cc.testing.IslandTestCase import IslandTestCase
class TestEvent(IslandTestCase): class TestEvent:
def test_create_event(self): def test_create_event(self):
with self.assertRaises(ValidationError): with pytest.raises(ValidationError):
_ = Event.create_event( _ = Event.create_event(
title=None, # title required title=None, # title required
message="bla bla", message="bla bla",
event_type=zero_trust_consts.EVENT_TYPE_MONKEY_NETWORK event_type=zero_trust_consts.EVENT_TYPE_MONKEY_NETWORK
) )
with self.assertRaises(ValidationError): with pytest.raises(ValidationError):
_ = Event.create_event( _ = Event.create_event(
title="skjs", title="skjs",
message="bla bla", message="bla bla",

View File

@ -6,7 +6,6 @@ from monkey_island.cc.models.zero_trust.event import Event
from monkey_island.cc.models.zero_trust.finding import Finding from monkey_island.cc.models.zero_trust.finding import Finding
from monkey_island.cc.models.zero_trust.monkey_finding_details import MonkeyFindingDetails from monkey_island.cc.models.zero_trust.monkey_finding_details import MonkeyFindingDetails
from monkey_island.cc.models.zero_trust.scoutsuite_finding_details import ScoutSuiteFindingDetails from monkey_island.cc.models.zero_trust.scoutsuite_finding_details import ScoutSuiteFindingDetails
from monkey_island.cc.testing.IslandTestCase import IslandTestCase
MONKEY_FINDING_DETAIL_MOCK = MonkeyFindingDetails() MONKEY_FINDING_DETAIL_MOCK = MonkeyFindingDetails()
@ -15,21 +14,21 @@ SCOUTSUITE_FINDING_DETAIL_MOCK = ScoutSuiteFindingDetails()
SCOUTSUITE_FINDING_DETAIL_MOCK.scoutsuite_rules = [] SCOUTSUITE_FINDING_DETAIL_MOCK.scoutsuite_rules = []
class TestFinding(IslandTestCase): class TestFinding:
def test_save_finding_validation(self): def test_save_finding_validation(self):
with self.assertRaises(ValidationError): with pytest.raises(ValidationError):
_ = Finding.save_finding(test="bla bla", _ = Finding.save_finding(test="bla bla",
status=zero_trust_consts.STATUS_FAILED, status=zero_trust_consts.STATUS_FAILED,
detail_ref=MONKEY_FINDING_DETAIL_MOCK) detail_ref=MONKEY_FINDING_DETAIL_MOCK)
with self.assertRaises(ValidationError): with pytest.raises(ValidationError):
_ = Finding.save_finding(test=zero_trust_consts.TEST_SEGMENTATION, _ = Finding.save_finding(test=zero_trust_consts.TEST_SEGMENTATION,
status="bla bla", status="bla bla",
detail_ref=SCOUTSUITE_FINDING_DETAIL_MOCK) detail_ref=SCOUTSUITE_FINDING_DETAIL_MOCK)
def test_save_finding_sanity(self): def test_save_finding_sanity(self):
self.assertEqual(len(Finding.objects(test=zero_trust_consts.TEST_SEGMENTATION)), 0) assert len(Finding.objects(test=zero_trust_consts.TEST_SEGMENTATION)) == 0
event_example = Event.create_event( event_example = Event.create_event(
title="Event Title", message="event message", event_type=zero_trust_consts.EVENT_TYPE_MONKEY_NETWORK) title="Event Title", message="event message", event_type=zero_trust_consts.EVENT_TYPE_MONKEY_NETWORK)
@ -38,5 +37,5 @@ class TestFinding(IslandTestCase):
Finding.save_finding(test=zero_trust_consts.TEST_SEGMENTATION, Finding.save_finding(test=zero_trust_consts.TEST_SEGMENTATION,
status=zero_trust_consts.STATUS_FAILED, detail_ref=monkey_details_example) status=zero_trust_consts.STATUS_FAILED, detail_ref=monkey_details_example)
self.assertEqual(len(Finding.objects(test=zero_trust_consts.TEST_SEGMENTATION)), 1) assert len(Finding.objects(test=zero_trust_consts.TEST_SEGMENTATION)) == 1
self.assertEqual(len(Finding.objects(status=zero_trust_consts.STATUS_FAILED)), 1) assert len(Finding.objects(status=zero_trust_consts.STATUS_FAILED)) == 1

View File

@ -2,7 +2,6 @@ from bson import ObjectId
from monkey_island.cc.services.edge.displayed_edge import DisplayedEdgeService from monkey_island.cc.services.edge.displayed_edge import DisplayedEdgeService
from monkey_island.cc.services.edge.edge import RIGHT_ARROW, EdgeService from monkey_island.cc.services.edge.edge import RIGHT_ARROW, EdgeService
from monkey_island.cc.testing.IslandTestCase import IslandTestCase
SCAN_DATA_MOCK = [{ SCAN_DATA_MOCK = [{
"timestamp": "2020-05-27T14:59:28.944Z", "timestamp": "2020-05-27T14:59:28.944Z",
@ -45,7 +44,7 @@ EXPLOIT_DATA_MOCK = [{
}] }]
class TestDisplayedEdgeService(IslandTestCase): class TestDisplayedEdgeService:
def test_get_displayed_edges_by_to(self): def test_get_displayed_edges_by_to(self):
dst_id = ObjectId() dst_id = ObjectId()
@ -57,7 +56,7 @@ class TestDisplayedEdgeService(IslandTestCase):
EdgeService.get_or_create_edge(src_id2, dst_id, "Ubuntu-4ubuntu3.2", "Ubuntu-4ubuntu2.8") EdgeService.get_or_create_edge(src_id2, dst_id, "Ubuntu-4ubuntu3.2", "Ubuntu-4ubuntu2.8")
displayed_edges = DisplayedEdgeService.get_displayed_edges_by_dst(str(dst_id)) displayed_edges = DisplayedEdgeService.get_displayed_edges_by_dst(str(dst_id))
self.assertEqual(len(displayed_edges), 2) assert len(displayed_edges) == 2
def test_edge_to_displayed_edge(self): def test_edge_to_displayed_edge(self):
src_node_id = ObjectId() src_node_id = ObjectId()
@ -74,22 +73,21 @@ class TestDisplayedEdgeService(IslandTestCase):
displayed_edge = DisplayedEdgeService.edge_to_displayed_edge(edge) displayed_edge = DisplayedEdgeService.edge_to_displayed_edge(edge)
self.assertEqual(displayed_edge['to'], dst_node_id) assert displayed_edge['to'] == dst_node_id
self.assertEqual(displayed_edge['from'], src_node_id) assert displayed_edge['from'] == src_node_id
self.assertEqual(displayed_edge['ip_address'], "10.2.2.2") assert displayed_edge['ip_address'] == "10.2.2.2"
self.assertListEqual(displayed_edge['services'], ["tcp-8088: unknown", "tcp-22: ssh"]) assert displayed_edge['services'] == ["tcp-8088: unknown", "tcp-22: ssh"]
self.assertEqual(displayed_edge['os'], {"type": "linux", assert displayed_edge['os'] == {"type": "linux", "version": "Ubuntu-4ubuntu2.8"}
"version": "Ubuntu-4ubuntu2.8"}) assert displayed_edge['exploits'] == EXPLOIT_DATA_MOCK
self.assertEqual(displayed_edge['exploits'], EXPLOIT_DATA_MOCK) assert displayed_edge['_label'] == "Ubuntu-4ubuntu3.2 " + RIGHT_ARROW + " Ubuntu-4ubuntu2.8"
self.assertEqual(displayed_edge['_label'], "Ubuntu-4ubuntu3.2 " + RIGHT_ARROW + " Ubuntu-4ubuntu2.8") assert displayed_edge['group'] == "exploited"
self.assertEqual(displayed_edge['group'], "exploited")
return displayed_edge return displayed_edge
def test_services_to_displayed_services(self): def test_services_to_displayed_services(self):
services1 = DisplayedEdgeService.services_to_displayed_services(SCAN_DATA_MOCK[-1]["data"]["services"], services1 = DisplayedEdgeService.services_to_displayed_services(SCAN_DATA_MOCK[-1]["data"]["services"],
True) True)
self.assertEqual(services1, ["tcp-8088", "tcp-22"]) assert services1 == ["tcp-8088", "tcp-22"]
services2 = DisplayedEdgeService.services_to_displayed_services(SCAN_DATA_MOCK[-1]["data"]["services"], services2 = DisplayedEdgeService.services_to_displayed_services(SCAN_DATA_MOCK[-1]["data"]["services"],
False) False)
self.assertEqual(services2, ["tcp-8088: unknown", "tcp-22: ssh"]) assert services2 == ["tcp-8088: unknown", "tcp-22: ssh"]

View File

@ -1,57 +1,51 @@
import logging import logging
import pytest
from mongomock import ObjectId from mongomock import ObjectId
from monkey_island.cc.models.edge import Edge from monkey_island.cc.models.edge import Edge
from monkey_island.cc.services.edge.edge import EdgeService from monkey_island.cc.services.edge.edge import EdgeService
from monkey_island.cc.testing.IslandTestCase import IslandTestCase
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TestEdgeService(IslandTestCase): class TestEdgeService:
"""
Make sure to set server environment to `testing` in server_config.json!
Otherwise this will mess up your mongo instance and won't work.
Also, the working directory needs to be the working directory from which you usually run the island so the
server_config.json file is found and loaded.
"""
@pytest.mark.usefixtures('uses_database')
def test_get_or_create_edge(self): def test_get_or_create_edge(self):
src_id = ObjectId() src_id = ObjectId()
dst_id = ObjectId() dst_id = ObjectId()
test_edge1 = EdgeService.get_or_create_edge(src_id, dst_id, "Mock label 1", "Mock label 2") test_edge1 = EdgeService.get_or_create_edge(src_id, dst_id, "Mock label 1", "Mock label 2")
self.assertEqual(test_edge1.src_node_id, src_id) assert test_edge1.src_node_id == src_id
self.assertEqual(test_edge1.dst_node_id, dst_id) assert test_edge1.dst_node_id == dst_id
self.assertFalse(test_edge1.exploited) assert not test_edge1.exploited
self.assertFalse(test_edge1.tunnel) assert not test_edge1.tunnel
self.assertListEqual(test_edge1.scans, []) assert test_edge1.scans == []
self.assertListEqual(test_edge1.exploits, []) assert test_edge1.exploits == []
self.assertEqual(test_edge1.src_label, "Mock label 1") assert test_edge1.src_label == "Mock label 1"
self.assertEqual(test_edge1.dst_label, "Mock label 2") assert test_edge1.dst_label == "Mock label 2"
self.assertIsNone(test_edge1.group) assert test_edge1.group is None
self.assertIsNone(test_edge1.domain_name) assert test_edge1.domain_name is None
self.assertIsNone(test_edge1.ip_address) assert test_edge1.ip_address is None
EdgeService.get_or_create_edge(src_id, dst_id, "Mock label 1", "Mock label 2") EdgeService.get_or_create_edge(src_id, dst_id, "Mock label 1", "Mock label 2")
self.assertEqual(len(Edge.objects()), 1) assert len(Edge.objects()) == 1
def test_get_edge_group(self): def test_get_edge_group(self):
edge = Edge(src_node_id=ObjectId(), edge = Edge(src_node_id=ObjectId(),
dst_node_id=ObjectId(), dst_node_id=ObjectId(),
exploited=True) exploited=True)
self.assertEqual("exploited", EdgeService.get_group(edge)) assert "exploited" == EdgeService.get_group(edge)
edge.exploited = False edge.exploited = False
edge.tunnel = True edge.tunnel = True
self.assertEqual("tunnel", EdgeService.get_group(edge)) assert "tunnel" == EdgeService.get_group(edge)
edge.tunnel = False edge.tunnel = False
edge.exploits.append(["mock_exploit_data"]) edge.exploits.append(["mock_exploit_data"])
self.assertEqual("scan", EdgeService.get_group(edge)) assert "scan" == EdgeService.get_group(edge)
edge.exploits = [] edge.exploits = []
edge.scans = [] edge.scans = []
self.assertEqual("empty", EdgeService.get_group(edge)) assert "empty" == EdgeService.get_group(edge)

View File

@ -2,12 +2,11 @@ import uuid
from monkey_island.cc.models import Monkey from monkey_island.cc.models import Monkey
from monkey_island.cc.services.reporting.pth_report import PTHReportService from monkey_island.cc.services.reporting.pth_report import PTHReportService
from monkey_island.cc.testing.IslandTestCase import IslandTestCase
class TestPTHReportServiceGenerateMapNodes(IslandTestCase): class TestPTHReportServiceGenerateMapNodes():
def test_generate_map_nodes(self): def test_generate_map_nodes(self):
self.assertEqual(PTHReportService.generate_map_nodes(), []) assert PTHReportService.generate_map_nodes() == []
windows_monkey_with_services = Monkey( windows_monkey_with_services = Monkey(
guid=str(uuid.uuid4()), guid=str(uuid.uuid4()),
@ -37,7 +36,7 @@ class TestPTHReportServiceGenerateMapNodes(IslandTestCase):
map_nodes = PTHReportService.generate_map_nodes() map_nodes = PTHReportService.generate_map_nodes()
self.assertEqual(2, len(map_nodes)) assert 2 == len(map_nodes)
def test_generate_map_nodes_parsing(self): def test_generate_map_nodes_parsing(self):
monkey_id = str(uuid.uuid4()) monkey_id = str(uuid.uuid4())
@ -53,8 +52,8 @@ class TestPTHReportServiceGenerateMapNodes(IslandTestCase):
map_nodes = PTHReportService.generate_map_nodes() map_nodes = PTHReportService.generate_map_nodes()
self.assertEqual(map_nodes[0]["id"], monkey_id) assert map_nodes[0]["id"] == monkey_id
self.assertEqual(map_nodes[0]["label"], "A_Windows_PC_1 : 1.1.1.1") assert map_nodes[0]["label"] == "A_Windows_PC_1 : 1.1.1.1"
self.assertEqual(map_nodes[0]["group"], "critical") assert map_nodes[0]["group"] == "critical"
self.assertEqual(len(map_nodes[0]["services"]), 2) assert len(map_nodes[0]["services"]) == 2
self.assertEqual(map_nodes[0]["hostname"], hostname) assert map_nodes[0]["hostname"] == hostname

View File

@ -3,10 +3,9 @@ import uuid
from monkey_island.cc.models import Monkey from monkey_island.cc.models import Monkey
from monkey_island.cc.services.telemetry.processing.system_info_collectors.system_info_telemetry_dispatcher import \ from monkey_island.cc.services.telemetry.processing.system_info_collectors.system_info_telemetry_dispatcher import \
SystemInfoTelemetryDispatcher SystemInfoTelemetryDispatcher
from monkey_island.cc.testing.IslandTestCase import IslandTestCase
class TestEnvironmentTelemetryProcessing(IslandTestCase): class TestEnvironmentTelemetryProcessing:
def test_process_environment_telemetry(self): def test_process_environment_telemetry(self):
# Arrange # Arrange
monkey_guid = str(uuid.uuid4()) monkey_guid = str(uuid.uuid4())
@ -25,4 +24,4 @@ class TestEnvironmentTelemetryProcessing(IslandTestCase):
} }
dispatcher.dispatch_collector_results_to_relevant_processors(telem_json) dispatcher.dispatch_collector_results_to_relevant_processors(telem_json)
self.assertEqual(Monkey.get_single_monkey_by_guid(monkey_guid).environment, on_premise) assert Monkey.get_single_monkey_by_guid(monkey_guid).environment == on_premise

View File

@ -1,32 +1,33 @@
import uuid import uuid
import pytest
from monkey_island.cc.models import Monkey from monkey_island.cc.models import Monkey
from monkey_island.cc.services.telemetry.processing.system_info_collectors.system_info_telemetry_dispatcher import ( from monkey_island.cc.services.telemetry.processing.system_info_collectors.system_info_telemetry_dispatcher import (
SystemInfoTelemetryDispatcher, process_aws_telemetry) SystemInfoTelemetryDispatcher, process_aws_telemetry)
from monkey_island.cc.testing.IslandTestCase import IslandTestCase
TEST_SYS_INFO_TO_PROCESSING = { TEST_SYS_INFO_TO_PROCESSING = {
"AwsCollector": [process_aws_telemetry], "AwsCollector": [process_aws_telemetry],
} }
class SystemInfoTelemetryDispatcherTest(IslandTestCase): class TestSystemInfoTelemetryDispatcher:
def test_dispatch_to_relevant_collector_bad_inputs(self): def test_dispatch_to_relevant_collector_bad_inputs(self):
self.fail_if_not_testing_env()
dispatcher = SystemInfoTelemetryDispatcher(TEST_SYS_INFO_TO_PROCESSING) dispatcher = SystemInfoTelemetryDispatcher(TEST_SYS_INFO_TO_PROCESSING)
# Bad format telem JSONs - throws # Bad format telem JSONs - throws
bad_empty_telem_json = {} bad_empty_telem_json = {}
self.assertRaises(KeyError, dispatcher.dispatch_collector_results_to_relevant_processors, bad_empty_telem_json) with pytest.raises(KeyError):
dispatcher.dispatch_collector_results_to_relevant_processors(bad_empty_telem_json)
bad_no_data_telem_json = {"monkey_guid": "bla"} bad_no_data_telem_json = {"monkey_guid": "bla"}
self.assertRaises(KeyError, with pytest.raises(KeyError):
dispatcher.dispatch_collector_results_to_relevant_processors, dispatcher.dispatch_collector_results_to_relevant_processors(bad_no_data_telem_json)
bad_no_data_telem_json)
bad_no_monkey_telem_json = {"data": {"collectors": {"AwsCollector": "Bla"}}} bad_no_monkey_telem_json = {"data": {"collectors": {"AwsCollector": "Bla"}}}
self.assertRaises(KeyError, with pytest.raises(KeyError):
dispatcher.dispatch_collector_results_to_relevant_processors, dispatcher.dispatch_collector_results_to_relevant_processors(bad_no_monkey_telem_json)
bad_no_monkey_telem_json)
# Telem JSON with no collectors - nothing gets dispatched # Telem JSON with no collectors - nothing gets dispatched
good_telem_no_collectors = {"monkey_guid": "bla", "data": {"bla": "bla"}} good_telem_no_collectors = {"monkey_guid": "bla", "data": {"bla": "bla"}}
@ -53,4 +54,4 @@ class SystemInfoTelemetryDispatcherTest(IslandTestCase):
} }
dispatcher.dispatch_collector_results_to_relevant_processors(telem_json) dispatcher.dispatch_collector_results_to_relevant_processors(telem_json)
self.assertEquals(Monkey.get_single_monkey_by_guid(a_monkey.guid).aws_instance_id, instance_id) assert Monkey.get_single_monkey_by_guid(a_monkey.guid).aws_instance_id == instance_id

View File

@ -6,14 +6,13 @@ from monkey_island.cc.models.zero_trust.event import Event
from monkey_island.cc.models.zero_trust.finding import Finding from monkey_island.cc.models.zero_trust.finding import Finding
from monkey_island.cc.services.telemetry.zero_trust_checks.segmentation import create_or_add_findings_for_all_pairs from monkey_island.cc.services.telemetry.zero_trust_checks.segmentation import create_or_add_findings_for_all_pairs
from monkey_island.cc.services.zero_trust.monkey_findings.monkey_zt_finding_service import MonkeyZTFindingService from monkey_island.cc.services.zero_trust.monkey_findings.monkey_zt_finding_service import MonkeyZTFindingService
from monkey_island.cc.testing.IslandTestCase import IslandTestCase
FIRST_SUBNET = "1.1.1.1" FIRST_SUBNET = "1.1.1.1"
SECOND_SUBNET = "2.2.2.0/24" SECOND_SUBNET = "2.2.2.0/24"
THIRD_SUBNET = "3.3.3.3-3.3.3.200" THIRD_SUBNET = "3.3.3.3-3.3.3.200"
class TestSegmentationChecks(IslandTestCase): class TestSegmentationChecks:
def test_create_findings_for_all_done_pairs(self): def test_create_findings_for_all_done_pairs(self):
all_subnets = [FIRST_SUBNET, SECOND_SUBNET, THIRD_SUBNET] all_subnets = [FIRST_SUBNET, SECOND_SUBNET, THIRD_SUBNET]
@ -23,15 +22,15 @@ class TestSegmentationChecks(IslandTestCase):
ip_addresses=[FIRST_SUBNET]) ip_addresses=[FIRST_SUBNET])
# no findings # no findings
self.assertEqual(len(Finding.objects(test=zero_trust_consts.TEST_SEGMENTATION)), 0) assert len(Finding.objects(test=zero_trust_consts.TEST_SEGMENTATION)) == 0
# This is like the monkey is done and sent done telem # This is like the monkey is done and sent done telem
create_or_add_findings_for_all_pairs(all_subnets, monkey) create_or_add_findings_for_all_pairs(all_subnets, monkey)
# There are 2 subnets in which the monkey is NOT # There are 2 subnets in which the monkey is NOT
self.assertEqual( zt_seg_findings = Finding.objects(test=zero_trust_consts.TEST_SEGMENTATION,
len(Finding.objects(test=zero_trust_consts.TEST_SEGMENTATION, status=zero_trust_consts.STATUS_PASSED)), status=zero_trust_consts.STATUS_PASSED)
2) assert len(zt_seg_findings) == 2
# This is a monkey from 2nd subnet communicated with 1st subnet. # This is a monkey from 2nd subnet communicated with 1st subnet.
MonkeyZTFindingService.create_or_add_to_existing( MonkeyZTFindingService.create_or_add_to_existing(
@ -43,12 +42,13 @@ class TestSegmentationChecks(IslandTestCase):
) )
self.assertEqual( zt_seg_findings = Finding.objects(test=zero_trust_consts.TEST_SEGMENTATION,
len(Finding.objects(test=zero_trust_consts.TEST_SEGMENTATION, status=zero_trust_consts.STATUS_PASSED)), status=zero_trust_consts.STATUS_PASSED)
1) assert len(zt_seg_findings) == 1
self.assertEqual(
len(Finding.objects(test=zero_trust_consts.TEST_SEGMENTATION, status=zero_trust_consts.STATUS_FAILED)), zt_seg_findings = Finding.objects(test=zero_trust_consts.TEST_SEGMENTATION,
1) status=zero_trust_consts.STATUS_FAILED)
self.assertEqual( assert len(zt_seg_findings) == 1
len(Finding.objects(test=zero_trust_consts.TEST_SEGMENTATION)),
2) zt_seg_findings = Finding.objects(test=zero_trust_consts.TEST_SEGMENTATION)
assert len(zt_seg_findings) == 2

View File

@ -1,14 +1,8 @@
from datetime import datetime from datetime import datetime
from typing import List
from bson import ObjectId
from common.common_consts import zero_trust_consts from common.common_consts import zero_trust_consts
from monkey_island.cc.models.zero_trust.event import Event from monkey_island.cc.models.zero_trust.event import Event
from monkey_island.cc.models.zero_trust.finding import Finding
from monkey_island.cc.models.zero_trust.monkey_finding_details import MonkeyFindingDetails
from monkey_island.cc.services.zero_trust.monkey_findings.monkey_zt_finding_service import MonkeyZTFindingService from monkey_island.cc.services.zero_trust.monkey_findings.monkey_zt_finding_service import MonkeyZTFindingService
from monkey_island.cc.testing.IslandTestCase import IslandTestCase
EVENTS = [ EVENTS = [
Event.create_event( Event.create_event(
@ -38,7 +32,7 @@ STATUS = [
] ]
class TestMonkeyZTFindingService(IslandTestCase): class TestMonkeyZTFindingService:
def test_create_or_add_to_existing(self): def test_create_or_add_to_existing(self):

View File

@ -1,10 +1,11 @@
import pytest
from common.common_consts import zero_trust_consts from common.common_consts import zero_trust_consts
from monkey_island.cc.models.zero_trust.finding import Finding from monkey_island.cc.models.zero_trust.finding import Finding
from monkey_island.cc.models.zero_trust.scoutsuite_rule import ScoutSuiteRule from monkey_island.cc.models.zero_trust.scoutsuite_rule import ScoutSuiteRule
from monkey_island.cc.services.zero_trust.scoutsuite.consts.findings import PermissiveFirewallRules, \ from monkey_island.cc.services.zero_trust.scoutsuite.consts.findings import PermissiveFirewallRules, \
UnencryptedData UnencryptedData
from monkey_island.cc.services.zero_trust.scoutsuite.scoutsuite_zt_finding_service import ScoutSuiteZTFindingService from monkey_island.cc.services.zero_trust.scoutsuite.scoutsuite_zt_finding_service import ScoutSuiteZTFindingService
from monkey_island.cc.testing.IslandTestCase import IslandTestCase
RULES = [ RULES = [
ScoutSuiteRule( ScoutSuiteRule(
@ -71,35 +72,36 @@ FINDINGS = [
] ]
class TestScoutSuiteZTFindingService(IslandTestCase): class TestScoutSuiteZTFindingService:
@pytest.mark.usefixtures('uses_database')
def test_process_rule(self): def test_process_rule(self):
# Creates new PermissiveFirewallRules finding with a rule # Creates new PermissiveFirewallRules finding with a rule
ScoutSuiteZTFindingService.process_rule(FINDINGS[0], RULES[0]) ScoutSuiteZTFindingService.process_rule(FINDINGS[0], RULES[0])
findings = list(Finding.objects()) findings = list(Finding.objects())
self.assertEqual(len(findings), 1) assert len(findings) == 1
self.assertEqual(findings[0].finding_type, zero_trust_consts.SCOUTSUITE_FINDING) assert findings[0].finding_type == zero_trust_consts.SCOUTSUITE_FINDING
# Assert that details were created properly # Assert that details were created properly
details = findings[0].details.fetch() details = findings[0].details.fetch()
self.assertEqual(len(details.scoutsuite_rules), 1) assert len(details.scoutsuite_rules) == 1
self.assertEqual(details.scoutsuite_rules[0], RULES[0]) assert details.scoutsuite_rules[0] == RULES[0]
# Rule processing should add rule to an already existing finding # Rule processing should add rule to an already existing finding
ScoutSuiteZTFindingService.process_rule(FINDINGS[0], RULES[1]) ScoutSuiteZTFindingService.process_rule(FINDINGS[0], RULES[1])
findings = list(Finding.objects()) findings = list(Finding.objects())
self.assertEqual(len(findings), 1) assert len(findings) == 1
self.assertEqual(findings[0].finding_type, zero_trust_consts.SCOUTSUITE_FINDING) assert findings[0].finding_type == zero_trust_consts.SCOUTSUITE_FINDING
# Assert that details were created properly # Assert that details were created properly
details = findings[0].details.fetch() details = findings[0].details.fetch()
self.assertEqual(len(details.scoutsuite_rules), 2) assert len(details.scoutsuite_rules) == 2
self.assertEqual(details.scoutsuite_rules[1], RULES[1]) assert details.scoutsuite_rules[1] == RULES[1]
# New finding created # New finding created
ScoutSuiteZTFindingService.process_rule(FINDINGS[1], RULES[1]) ScoutSuiteZTFindingService.process_rule(FINDINGS[1], RULES[1])
findings = list(Finding.objects()) findings = list(Finding.objects())
self.assertEqual(len(findings), 2) assert len(findings) == 2
self.assertEqual(findings[1].finding_type, zero_trust_consts.SCOUTSUITE_FINDING) assert findings[1].finding_type == zero_trust_consts.SCOUTSUITE_FINDING
# Assert that details were created properly # Assert that details were created properly
details = findings[1].details.fetch() details = findings[1].details.fetch()
self.assertEqual(len(details.scoutsuite_rules), 1) assert len(details.scoutsuite_rules) == 1
self.assertEqual(details.scoutsuite_rules[0], RULES[1]) assert details.scoutsuite_rules[0] == RULES[1]

View File

@ -0,0 +1,32 @@
import mongoengine
import pytest
from monkey_island.cc.models import Monkey
from monkey_island.cc.models.edge import Edge
from monkey_island.cc.models.zero_trust.finding import Finding
@pytest.fixture(scope='session', autouse=True)
def change_to_mongo_mock():
# Make sure tests are working with mongomock
mongoengine.disconnect()
mongoengine.connect('mongoenginetest', host='mongomock://localhost')
@pytest.fixture(scope='function')
def uses_database():
_clean_edge_db()
_clean_monkey_db()
_clean_finding_db()
def _clean_monkey_db():
Monkey.objects().delete()
def _clean_edge_db():
Edge.objects().delete()
def _clean_finding_db():
Finding.objects().delete()

View File

@ -1,41 +0,0 @@
import unittest
import mongoengine
import monkey_island.cc.environment.environment_singleton as env_singleton
from monkey_island.cc.models import Monkey
from monkey_island.cc.models.edge import Edge
from monkey_island.cc.models.zero_trust.finding import Finding
class IslandTestCase(unittest.TestCase):
def __init__(self, methodName):
# Make sure test is working with mongomock
if mongoengine.connection.get_connection().server_info()['sysInfo'] != 'Mock':
mongoengine.disconnect()
mongoengine.connect('mongoenginetest', host='mongomock://localhost')
else:
IslandTestCase.clean_db()
super().__init__(methodName)
def fail_if_not_testing_env(self):
self.assertFalse(not env_singleton.env.testing, "Change server_config.json to testing environment.")
@staticmethod
def clean_db():
IslandTestCase._clean_edge_db()
IslandTestCase._clean_monkey_db()
IslandTestCase._clean_finding_db()
@staticmethod
def _clean_monkey_db():
Monkey.objects().delete()
@staticmethod
def _clean_edge_db():
Edge.objects().delete()
@staticmethod
def _clean_finding_db():
Finding.objects().delete()