diff --git a/monkey/monkey_island/cc/models/monkey.py b/monkey/monkey_island/cc/models/monkey.py index 1a0e872f6..324903809 100644 --- a/monkey/monkey_island/cc/models/monkey.py +++ b/monkey/monkey_island/cc/models/monkey.py @@ -97,6 +97,25 @@ class Monkey(Document): os = "windows" return os + @ring.lru() + @staticmethod + def get_label_by_id(object_id): + current_monkey = Monkey.get_single_monkey_by_id(object_id) + label = Monkey.get_hostname_by_id(object_id) + " : " + current_monkey.ip_addresses[0] + if len(set(current_monkey.ip_addresses).intersection(local_ip_addresses())) > 0: + label = "MonkeyIsland - " + label + return label + + @ring.lru() + @staticmethod + def get_hostname_by_id(object_id): + """ + :param object_id: the object ID of a Monkey in the database. + :return: The hostname of that machine. + :note: Use this and not monkey.hostname for performance - this is lru-cached. + """ + return Monkey.get_single_monkey_by_id(object_id).hostname + def set_hostname(self, hostname): """ Sets a new hostname for a machine and clears the cache for getting it. @@ -104,8 +123,8 @@ class Monkey(Document): """ self.hostname = hostname self.save() - get_monkey_hostname_by_id.delete(self.id) - get_monkey_label_by_id.delete(self.id) + Monkey.get_hostname_by_id.delete(self.id) + Monkey.get_label_by_id.delete(self.id) def get_network_info(self): """ @@ -114,6 +133,17 @@ class Monkey(Document): """ return {'ips': self.ip_addresses, 'hostname': self.hostname} + @ring.lru( + expire=1 # data has TTL of 1 second. This is useful for rapid calls for report generation. + ) + @staticmethod + def is_monkey(object_id): + try: + _ = Monkey.get_single_monkey_by_id(object_id) + return True + except: + return False + @staticmethod def get_tunneled_monkeys(): return Monkey.objects(tunnel__exists=True) @@ -123,37 +153,5 @@ class Monkey(Document): self.save() -# TODO Can't make following methods static under Monkey class due to ring bug. When ring will support static methods, we -# should move to static methods in the Monkey class. -@ring.lru( - expire=1 # data has TTL of 1 second. This is useful for rapid calls for report generation. -) -def is_monkey(object_id): - try: - _ = Monkey.get_single_monkey_by_id(object_id) - return True - except: - return False - - -@ring.lru() -def get_monkey_label_by_id(object_id): - current_monkey = Monkey.get_single_monkey_by_id(object_id) - label = get_monkey_hostname_by_id(object_id) + " : " + current_monkey.ip_addresses[0] - if len(set(current_monkey.ip_addresses).intersection(local_ip_addresses())) > 0: - label = "MonkeyIsland - " + label - return label - - -@ring.lru() -def get_monkey_hostname_by_id(object_id): - """ - :param object_id: the object ID of a Monkey in the database. - :return: The hostname of that machine. - :note: Use this and not monkey.hostname for performance - this is lru-cached. - """ - return Monkey.get_single_monkey_by_id(object_id).hostname - - class MonkeyNotFoundError(Exception): pass diff --git a/monkey/monkey_island/cc/models/test_monkey.py b/monkey/monkey_island/cc/models/test_monkey.py index 472c5770b..d399355a3 100644 --- a/monkey/monkey_island/cc/models/test_monkey.py +++ b/monkey/monkey_island/cc/models/test_monkey.py @@ -4,7 +4,7 @@ from time import sleep import pytest -from monkey_island.cc.models.monkey import Monkey, MonkeyNotFoundError, is_monkey, get_monkey_label_by_id +from monkey_island.cc.models.monkey import Monkey, MonkeyNotFoundError from monkey_island.cc.testing.IslandTestCase import IslandTestCase from .monkey_ttl import MonkeyTtl @@ -131,13 +131,12 @@ class TestMonkey(IslandTestCase): ip_addresses=[ip_example]) linux_monkey.save() - logger.debug(id(get_monkey_label_by_id)) - - cache_info_before_query = get_monkey_label_by_id.storage.backend.cache_info() + cache_info_before_query = Monkey.get_label_by_id.storage.backend.cache_info() self.assertEqual(cache_info_before_query.hits, 0) self.assertEqual(cache_info_before_query.misses, 0) # not cached + label = Monkey.get_label_by_id(linux_monkey.id) label = get_monkey_label_by_id(linux_monkey.id) cache_info_after_query_1 = get_monkey_label_by_id.storage.backend.cache_info() self.assertEqual(cache_info_after_query_1.hits, 0) @@ -149,23 +148,23 @@ class TestMonkey(IslandTestCase): self.assertIn(ip_example, label) # should be cached + _ = Monkey.get_label_by_id(linux_monkey.id) + cache_info_after_query = Monkey.get_label_by_id.storage.backend.cache_info() + self.assertEqual(cache_info_after_query.hits, 1) label = get_monkey_label_by_id(linux_monkey.id) logger.info("2) ID: {} label: {}".format(linux_monkey.id, label)) cache_info_after_query_2 = get_monkey_label_by_id.storage.backend.cache_info() self.assertEqual(cache_info_after_query_2.hits, 1) self.assertEqual(cache_info_after_query_2.misses, 1) - # set hostname deletes the id from the cache. linux_monkey.set_hostname("Another hostname") # should be a miss - label = get_monkey_label_by_id(linux_monkey.id) - logger.info("3) ID: {} label: {}".format(linux_monkey.id, label)) - cache_info_after_query_3 = get_monkey_label_by_id.storage.backend.cache_info() - logger.debug("Cache info: {}".format(str(cache_info_after_query_3))) + label = Monkey.get_label_by_id(linux_monkey.id) + cache_info_after_second_query = Monkey.get_label_by_id.storage.backend.cache_info() # still 1 hit only - self.assertEqual(cache_info_after_query_3.hits, 1) - self.assertEqual(cache_info_after_query_3.misses, 2) + self.assertEqual(cache_info_after_second_query.hits, 1) + self.assertEqual(cache_info_after_second_query.misses, 2) def test_is_monkey(self): self.fail_if_not_testing_env() @@ -174,18 +173,18 @@ class TestMonkey(IslandTestCase): a_monkey = Monkey(guid=str(uuid.uuid4())) a_monkey.save() - cache_info_before_query = is_monkey.storage.backend.cache_info() + cache_info_before_query = Monkey.is_monkey.storage.backend.cache_info() self.assertEqual(cache_info_before_query.hits, 0) # not cached - self.assertTrue(is_monkey(a_monkey.id)) + self.assertTrue(Monkey.is_monkey(a_monkey.id)) fake_id = "123456789012" - self.assertFalse(is_monkey(fake_id)) + self.assertFalse(Monkey.is_monkey(fake_id)) # should be cached - self.assertTrue(is_monkey(a_monkey.id)) - self.assertFalse(is_monkey(fake_id)) + self.assertTrue(Monkey.is_monkey(a_monkey.id)) + self.assertFalse(Monkey.is_monkey(fake_id)) - cache_info_after_query = is_monkey.storage.backend.cache_info() + cache_info_after_query = Monkey.is_monkey.storage.backend.cache_info() self.assertEqual(cache_info_after_query.hits, 2) diff --git a/monkey/monkey_island/cc/services/edge.py b/monkey/monkey_island/cc/services/edge.py index bf9417309..ae3d2a2de 100644 --- a/monkey/monkey_island/cc/services/edge.py +++ b/monkey/monkey_island/cc/services/edge.py @@ -2,7 +2,7 @@ from bson import ObjectId from monkey_island.cc.database import mongo import monkey_island.cc.services.node -from monkey_island.cc.models.monkey import get_monkey_label_by_id, is_monkey +from monkey_island.cc.models import Monkey __author__ = "itay.mizeretz" @@ -145,13 +145,13 @@ class EdgeService: from_id = edge["from"] to_id = edge["to"] - from_label = get_monkey_label_by_id(from_id) + from_label = Monkey.get_label_by_id(from_id) if to_id == ObjectId("000000000000000000000000"): to_label = 'MonkeyIsland' else: - if is_monkey(to_id): - to_label = get_monkey_label_by_id(to_id) + if Monkey.is_monkey(to_id): + to_label = Monkey.get_label_by_id(to_id) else: to_label = NodeService.get_node_label(NodeService.get_node_by_id(to_id)) diff --git a/monkey/monkey_island/cc/services/node.py b/monkey/monkey_island/cc/services/node.py index 0c0a873e8..27d2d299a 100644 --- a/monkey/monkey_island/cc/services/node.py +++ b/monkey/monkey_island/cc/services/node.py @@ -4,7 +4,7 @@ from bson import ObjectId import monkey_island.cc.services.log from monkey_island.cc.database import mongo -from monkey_island.cc.models.monkey import Monkey, get_monkey_hostname_by_id, get_monkey_label_by_id +from monkey_island.cc.models import Monkey from monkey_island.cc.services.edge import EdgeService from monkey_island.cc.utils import local_ip_addresses import socket @@ -50,8 +50,8 @@ class NodeService: for edge in edges: from_node_id = edge["from"] - from_node_label = get_monkey_label_by_id(from_node_id) - from_node_hostname = get_monkey_hostname_by_id(from_node_id) + from_node_label = Monkey.get_label_by_id(from_node_id) + from_node_hostname = Monkey.get_hostname_by_id(from_node_id) accessible_from_nodes.append(from_node_label) accessible_from_nodes_hostnames.append(from_node_hostname) @@ -140,7 +140,7 @@ class NodeService: @staticmethod def monkey_to_net_node(monkey, for_report=False): monkey_id = monkey["_id"] - label = get_monkey_hostname_by_id(monkey_id) if for_report else get_monkey_label_by_id(monkey_id) + label = Monkey.get_hostname_by_id(monkey_id) if for_report else Monkey.get_label_by_id(monkey_id) monkey_group = NodeService.get_monkey_group(monkey) return \ {