From fb595319699d16fc70fa089302ee8f9afacc4592 Mon Sep 17 00:00:00 2001
From: VakarisZ <vakarisz@yahoo.com>
Date: Mon, 8 Jun 2020 10:29:04 +0300
Subject: [PATCH] Refactored EdgeService into a boundary object.

---
 monkey/monkey_island/cc/models/edge.py        |   3 +
 monkey/monkey_island/cc/resources/monkey.py   |   7 +-
 .../cc/services/edge/displayed_edge.py        |  28 +++--
 monkey/monkey_island/cc/services/edge/edge.py | 103 +++++++++++-------
 .../cc/services/edge/test_displayed_edge.py   |  20 ++--
 .../cc/services/edge/test_edge.py             |   8 +-
 .../cc/services/netmap/net_edge.py            |   6 +-
 monkey/monkey_island/cc/services/node.py      |   9 +-
 .../services/telemetry/processing/exploit.py  |   6 +-
 .../cc/services/telemetry/processing/scan.py  |   4 +-
 10 files changed, 111 insertions(+), 83 deletions(-)

diff --git a/monkey/monkey_island/cc/models/edge.py b/monkey/monkey_island/cc/models/edge.py
index 31a51598e..09af04680 100644
--- a/monkey/monkey_island/cc/models/edge.py
+++ b/monkey/monkey_island/cc/models/edge.py
@@ -2,6 +2,9 @@ from mongoengine import Document, ObjectIdField, ListField, DynamicField, Boolea
 
 
 class Edge(Document):
+
+    meta = {'allow_inheritance': True}
+
     # SCHEMA
     src_node_id = ObjectIdField(required=True)
     dst_node_id = ObjectIdField(required=True)
diff --git a/monkey/monkey_island/cc/resources/monkey.py b/monkey/monkey_island/cc/resources/monkey.py
index ae34c624d..a39aaf199 100644
--- a/monkey/monkey_island/cc/resources/monkey.py
+++ b/monkey/monkey_island/cc/resources/monkey.py
@@ -4,7 +4,6 @@ from datetime import datetime
 import dateutil.parser
 import flask_restful
 
-from monkey_island.cc.models.edge import Edge
 from monkey_island.cc.resources.test.utils.telem_store import TestTelemStore
 from flask import request
 
@@ -12,6 +11,7 @@ from monkey_island.cc.consts import DEFAULT_MONKEY_TTL_EXPIRY_DURATION_IN_SECOND
 from monkey_island.cc.database import mongo
 from monkey_island.cc.models.monkey_ttl import create_monkey_ttl_document
 from monkey_island.cc.services.config import ConfigService
+from monkey_island.cc.services.edge.edge import EdgeService
 from monkey_island.cc.services.node import NodeService
 
 __author__ = 'Barak'
@@ -133,9 +133,8 @@ class Monkey(flask_restful.Resource):
 
         if existing_node:
             node_id = existing_node["_id"]
-            for edge in Edge.objects(dst_node_id=node_id):
-                edge.dst_node_id = new_monkey_id
-                edge.save()
+            EdgeService.update_all_dst_nodes(old_dst_node_id=node_id,
+                                             new_dst_node_id=new_monkey_id)
             for creds in existing_node['creds']:
                 NodeService.add_credentials_to_monkey(new_monkey_id, creds)
             mongo.db.node.remove({"_id": node_id})
diff --git a/monkey/monkey_island/cc/services/edge/displayed_edge.py b/monkey/monkey_island/cc/services/edge/displayed_edge.py
index 8f94a6ffa..f7a0664bf 100644
--- a/monkey/monkey_island/cc/services/edge/displayed_edge.py
+++ b/monkey/monkey_island/cc/services/edge/displayed_edge.py
@@ -1,8 +1,8 @@
 from copy import deepcopy
+from typing import Dict
 
 from bson import ObjectId
 
-from monkey_island.cc.models.edge import Edge
 from monkey_island.cc.services.edge.edge import EdgeService
 
 __author__ = "itay.mizeretz"
@@ -11,18 +11,18 @@ __author__ = "itay.mizeretz"
 class DisplayedEdgeService:
 
     @staticmethod
-    def get_displayed_edges_by_dst(dst_id, for_report=False):
-        edges = Edge.objects(dst_node_id=ObjectId(dst_id))
+    def get_displayed_edges_by_dst(dst_id: str, for_report=False):
+        edges = EdgeService.get_by_dst_node(dst_node_id=ObjectId(dst_id))
         return [DisplayedEdgeService.edge_to_displayed_edge(edge, for_report) for edge in edges]
 
     @staticmethod
-    def get_displayed_edge_by_id(edge_id, for_report=False):
-        edge = Edge.objects.get(id=edge_id)
+    def get_displayed_edge_by_id(edge_id: str, for_report=False):
+        edge = EdgeService.get_edge_by_id(ObjectId(edge_id))
         displayed_edge = DisplayedEdgeService.edge_to_displayed_edge(edge, for_report)
         return displayed_edge
 
     @staticmethod
-    def edge_to_displayed_edge(edge: Edge, for_report=False):
+    def edge_to_displayed_edge(edge: EdgeService, for_report=False):
         services = []
         os = {}
 
@@ -33,13 +33,13 @@ class DisplayedEdgeService:
 
         displayed_edge = DisplayedEdgeService.edge_to_net_edge(edge)
 
-        displayed_edge["ip_address"] = edge['ip_address']
+        displayed_edge["ip_address"] = edge.ip_address
         displayed_edge["services"] = services
         displayed_edge["os"] = os
         # we need to deepcopy all mutable edge properties, because weak-reference link is made otherwise,
         # which is destroyed after method is exited and causes an error later.
-        displayed_edge["exploits"] = deepcopy(edge['exploits'])
-        displayed_edge["_label"] = EdgeService.get_edge_label(displayed_edge)
+        displayed_edge["exploits"] = deepcopy(edge.exploits)
+        displayed_edge["_label"] = edge.get_label()
         return displayed_edge
 
     @staticmethod
@@ -53,9 +53,13 @@ class DisplayedEdgeService:
                 "src_label": src_label,
                 "dst_label": dst_label
             }
-        edge["_label"] = EdgeService.get_edge_label(edge)
+        edge["_label"] = DisplayedEdgeService.get_pseudo_label(edge)
         return edge
 
+    @staticmethod
+    def get_pseudo_label(edge: Dict):
+        return f"{edge['src_label']} {RIGHT_ARROW} {edge['dst_label']}"
+
     @staticmethod
     def services_to_displayed_services(services, for_report=False):
         if for_report:
@@ -64,13 +68,13 @@ class DisplayedEdgeService:
             return [x + ": " + (services[x]['name'] if 'name' in services[x] else 'unknown') for x in services]
 
     @staticmethod
-    def edge_to_net_edge(edge: Edge):
+    def edge_to_net_edge(edge: EdgeService):
         return \
             {
                 "id": edge.id,
                 "from": edge.src_node_id,
                 "to": edge.dst_node_id,
-                "group": EdgeService.get_edge_group(edge),
+                "group": edge.get_group(),
                 "src_label": edge.src_label,
                 "dst_label": edge.dst_label
             }
diff --git a/monkey/monkey_island/cc/services/edge/edge.py b/monkey/monkey_island/cc/services/edge/edge.py
index ab4c3c114..4c9ef57d7 100644
--- a/monkey/monkey_island/cc/services/edge/edge.py
+++ b/monkey/monkey_island/cc/services/edge/edge.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import copy
-from typing import Dict
+from typing import Dict, List
 
 from bson import ObjectId
 from mongoengine import DoesNotExist
@@ -9,35 +11,60 @@ from monkey_island.cc.models.edge import Edge
 RIGHT_ARROW = "\u2192"
 
 
-class EdgeService:
+class EdgeService(Edge):
 
     @staticmethod
-    def get_or_create_edge(src_node_id, dst_node_id, src_label, dst_label):
-        edge = False
+    def get_all_edges() -> List[EdgeService]:
+        return EdgeService.objects()
+
+    @staticmethod
+    def get_or_create_edge(src_node_id, dst_node_id, src_label, dst_label) -> EdgeService:
+        edge = None
         try:
-            edge = Edge.objects.get(src_node_id=src_node_id, dst_node_id=dst_node_id)
+            edge = EdgeService.objects.get(src_node_id=src_node_id, dst_node_id=dst_node_id)
         except DoesNotExist:
-            edge = Edge(src_node_id=src_node_id, dst_node_id=dst_node_id)
+            edge = EdgeService(src_node_id=src_node_id, dst_node_id=dst_node_id)
         finally:
             if edge:
-                edge.src_label = src_label
-                edge.dst_label = dst_label
-                edge.save()
+                edge.update_label(node_id=src_node_id, label=src_label)
+                edge.update_label(node_id=dst_node_id, label=dst_label)
         return edge
 
     @staticmethod
-    def update_label(edge: Edge, node_id: ObjectId, label: str):
-        if edge.src_node_id == node_id:
-            edge.src_label = label
-        elif edge.dst_node_id == node_id:
-            edge.dst_label = label
-        else:
-            raise DoesNotExist("Node id provided does not match with any endpoint of an edge provided.")
-        edge.save()
-        pass
+    def get_by_dst_node(dst_node_id: ObjectId) -> List[EdgeService]:
+        return EdgeService.objects(dst_node_id=dst_node_id)
 
     @staticmethod
-    def update_based_on_scan_telemetry(edge: Edge, telemetry: Dict):
+    def get_edge_by_id(edge_id: ObjectId) -> EdgeService:
+        return EdgeService.objects.get(id=edge_id)
+
+    def update_label(self, node_id: ObjectId, label: str):
+        if self.src_node_id == node_id:
+            self.src_label = label
+        elif self.dst_node_id == node_id:
+            self.dst_label = label
+        else:
+            raise DoesNotExist("Node id provided does not match with any endpoint of an self provided.")
+        self.save()
+
+    @staticmethod
+    def update_all_dst_nodes(old_dst_node_id: ObjectId, new_dst_node_id: ObjectId):
+        for edge in EdgeService.objects(dst_node_id=old_dst_node_id):
+            edge.dst_node_id = new_dst_node_id
+            edge.save()
+
+    @staticmethod
+    def get_tunnel_edges_by_src(src_node_id) -> List[EdgeService]:
+        try:
+            return EdgeService.objects(src_node_id=src_node_id, tunnel=True)
+        except DoesNotExist:
+            return []
+
+    def disable_tunnel(self):
+        self.tunnel = False
+        self.save()
+
+    def update_based_on_scan_telemetry(self, telemetry: Dict):
         machine_info = copy.deepcopy(telemetry['data']['machine'])
         new_scan = \
             {
@@ -46,33 +73,29 @@ class EdgeService:
             }
         ip_address = machine_info.pop("ip_addr")
         domain_name = machine_info.pop("domain_name")
-        edge.scans.append(new_scan)
-        edge.ip_address = ip_address
-        edge.domain_name = domain_name
-        edge.save()
+        self.scans.append(new_scan)
+        self.ip_address = ip_address
+        self.domain_name = domain_name
+        self.save()
 
-    @staticmethod
-    def update_based_on_exploit(edge: Edge, exploit: Dict):
-        edge.exploits.append(exploit)
-        edge.save()
+    def update_based_on_exploit(self, exploit: Dict):
+        self.exploits.append(exploit)
+        self.save()
         if exploit['result']:
-            EdgeService.set_edge_exploited(edge)
+            self.set_exploited()
 
-    @staticmethod
-    def set_edge_exploited(edge: Edge):
-        edge.exploited = True
-        edge.save()
+    def set_exploited(self):
+        self.exploited = True
+        self.save()
 
-    @staticmethod
-    def get_edge_group(edge: Edge):
-        if edge.exploited:
+    def get_group(self) -> str:
+        if self.exploited:
             return "exploited"
-        if edge.tunnel:
+        if self.tunnel:
             return "tunnel"
-        if edge.scans or edge.exploits:
+        if self.scans or self.exploits:
             return "scan"
         return "empty"
 
-    @staticmethod
-    def get_edge_label(edge):
-        return "%s %s %s" % (edge['src_label'], RIGHT_ARROW, edge['dst_label'])
+    def get_label(self) -> str:
+        return f"{self.src_label} {RIGHT_ARROW} {self.dst_label}"
diff --git a/monkey/monkey_island/cc/services/edge/test_displayed_edge.py b/monkey/monkey_island/cc/services/edge/test_displayed_edge.py
index 340efe771..c134bce33 100644
--- a/monkey/monkey_island/cc/services/edge/test_displayed_edge.py
+++ b/monkey/monkey_island/cc/services/edge/test_displayed_edge.py
@@ -58,21 +58,21 @@ class TestDisplayedEdgeService(IslandTestCase):
         src_id2 = ObjectId()
         EdgeService.get_or_create_edge(src_id2, dst_id, "Ubuntu-4ubuntu3.2", "Ubuntu-4ubuntu2.8")
 
-        displayed_edges = DisplayedEdgeService.get_displayed_edges_by_dst(dst_id)
+        displayed_edges = DisplayedEdgeService.get_displayed_edges_by_dst(str(dst_id))
         self.assertEqual(len(displayed_edges), 2)
 
     def test_edge_to_displayed_edge(self):
         src_node_id = ObjectId()
         dst_node_id = ObjectId()
-        edge = Edge(src_node_id=src_node_id,
-                    dst_node_id=dst_node_id,
-                    scans=SCAN_DATA_MOCK,
-                    exploits=EXPLOIT_DATA_MOCK,
-                    exploited=True,
-                    domain_name=None,
-                    ip_address="10.2.2.2",
-                    dst_label="Ubuntu-4ubuntu2.8",
-                    src_label="Ubuntu-4ubuntu3.2")
+        edge = EdgeService(src_node_id=src_node_id,
+                           dst_node_id=dst_node_id,
+                           scans=SCAN_DATA_MOCK,
+                           exploits=EXPLOIT_DATA_MOCK,
+                           exploited=True,
+                           domain_name=None,
+                           ip_address="10.2.2.2",
+                           dst_label="Ubuntu-4ubuntu2.8",
+                           src_label="Ubuntu-4ubuntu3.2")
 
         displayed_edge = DisplayedEdgeService.edge_to_displayed_edge(edge)
 
diff --git a/monkey/monkey_island/cc/services/edge/test_edge.py b/monkey/monkey_island/cc/services/edge/test_edge.py
index 8053f45e3..8ebf45343 100644
--- a/monkey/monkey_island/cc/services/edge/test_edge.py
+++ b/monkey/monkey_island/cc/services/edge/test_edge.py
@@ -45,16 +45,16 @@ class TestEdgeService(IslandTestCase):
         edge = Edge(src_node_id=ObjectId(),
                     dst_node_id=ObjectId(),
                     exploited=True)
-        self.assertEqual("exploited", EdgeService.get_edge_group(edge))
+        self.assertEqual("exploited", EdgeService.get_group(edge))
 
         edge.exploited = False
         edge.tunnel = True
-        self.assertEqual("tunnel", EdgeService.get_edge_group(edge))
+        self.assertEqual("tunnel", EdgeService.get_group(edge))
 
         edge.tunnel = False
         edge.exploits.append(["mock_exploit_data"])
-        self.assertEqual("scan", EdgeService.get_edge_group(edge))
+        self.assertEqual("scan", EdgeService.get_group(edge))
 
         edge.exploits = []
         edge.scans = []
-        self.assertEqual("empty", EdgeService.get_edge_group(edge))
+        self.assertEqual("empty", EdgeService.get_group(edge))
diff --git a/monkey/monkey_island/cc/services/netmap/net_edge.py b/monkey/monkey_island/cc/services/netmap/net_edge.py
index f9d5e1932..44e097630 100644
--- a/monkey/monkey_island/cc/services/netmap/net_edge.py
+++ b/monkey/monkey_island/cc/services/netmap/net_edge.py
@@ -3,6 +3,7 @@ from bson import ObjectId
 from monkey_island.cc.models import Monkey
 from monkey_island.cc.models.edge import Edge
 from monkey_island.cc.services.edge.displayed_edge import DisplayedEdgeService
+from monkey_island.cc.services.edge.edge import EdgeService
 from monkey_island.cc.services.node import NodeService
 
 
@@ -20,7 +21,7 @@ class NetEdgeService:
 
     @staticmethod
     def _get_standard_net_edges():
-        return [DisplayedEdgeService.edge_to_net_edge(x) for x in Edge.objects()]
+        return [DisplayedEdgeService.edge_to_net_edge(x) for x in EdgeService.get_all_edges()]
 
     @staticmethod
     def _get_uninfected_island_net_edges():
@@ -44,7 +45,8 @@ class NetEdgeService:
 
     @staticmethod
     def _get_infected_island_net_edges(monkey_island_monkey):
-        existing_ids = [x.src_node_id for x in Edge.objects(dst_node_id=monkey_island_monkey["_id"])]
+        existing_ids = [x.src_node_id for x
+                        in EdgeService.get_by_dst_node(dst_node_id=monkey_island_monkey["_id"])]
         monkey_ids = [x.id for x in Monkey.objects()
                       if ("tunnel" not in x) and
                       (x.id not in existing_ids) and
diff --git a/monkey/monkey_island/cc/services/node.py b/monkey/monkey_island/cc/services/node.py
index ecce06c85..f6f5362c3 100644
--- a/monkey/monkey_island/cc/services/node.py
+++ b/monkey/monkey_island/cc/services/node.py
@@ -189,12 +189,9 @@ class NodeService:
             {'$unset': {'tunnel': ''}},
             upsert=False)
 
-        try:
-            edge = Edge.objects.get(src_node_id=monkey_id, tunnel=True)
-            edge.tunnel = False
-            edge.save()
-        except DoesNotExist:
-            pass
+        edges = EdgeService.get_tunnel_edges_by_src(monkey_id)
+        for edge in edges:
+            edge.disable_tunnel()
 
     @staticmethod
     def set_monkey_tunnel(monkey_id, tunnel_host_ip):
diff --git a/monkey/monkey_island/cc/services/telemetry/processing/exploit.py b/monkey/monkey_island/cc/services/telemetry/processing/exploit.py
index 8cb5db0d6..f8ea52b6e 100644
--- a/monkey/monkey_island/cc/services/telemetry/processing/exploit.py
+++ b/monkey/monkey_island/cc/services/telemetry/processing/exploit.py
@@ -25,7 +25,7 @@ def process_exploit_telemetry(telemetry_json):
         timestamp=telemetry_json['timestamp'])
 
 
-def update_node_credentials_from_successful_attempts(edge: Edge, telemetry_json):
+def update_node_credentials_from_successful_attempts(edge: EdgeService, telemetry_json):
     for attempt in telemetry_json['data']['attempts']:
         if attempt['result']:
             found_creds = {'user': attempt['user']}
@@ -35,13 +35,13 @@ def update_node_credentials_from_successful_attempts(edge: Edge, telemetry_json)
             NodeService.add_credentials_to_node(edge.dst_node_id, found_creds)
 
 
-def update_network_with_exploit(edge: Edge, telemetry_json):
+def update_network_with_exploit(edge: EdgeService, telemetry_json):
     telemetry_json['data']['info']['started'] = dateutil.parser.parse(telemetry_json['data']['info']['started'])
     telemetry_json['data']['info']['finished'] = dateutil.parser.parse(telemetry_json['data']['info']['finished'])
     new_exploit = copy.deepcopy(telemetry_json['data'])
     new_exploit.pop('machine')
     new_exploit['timestamp'] = telemetry_json['timestamp']
-    EdgeService.update_based_on_exploit(edge, new_exploit)
+    edge.update_based_on_exploit(new_exploit)
     if new_exploit['result']:
         NodeService.set_node_exploited(edge.dst_node_id)
 
diff --git a/monkey/monkey_island/cc/services/telemetry/processing/scan.py b/monkey/monkey_island/cc/services/telemetry/processing/scan.py
index 48c1f11c3..7a7f0b19c 100644
--- a/monkey/monkey_island/cc/services/telemetry/processing/scan.py
+++ b/monkey/monkey_island/cc/services/telemetry/processing/scan.py
@@ -18,7 +18,7 @@ def process_scan_telemetry(telemetry_json):
 
 def update_edges_and_nodes_based_on_scan_telemetry(telemetry_json):
     edge = get_edge_by_scan_or_exploit_telemetry(telemetry_json)
-    EdgeService.update_based_on_scan_telemetry(edge, telemetry_json)
+    edge.update_based_on_scan_telemetry(telemetry_json)
 
     node = mongo.db.node.find_one({"_id": edge.dst_node_id})
     if node is not None:
@@ -32,4 +32,4 @@ def update_edges_and_nodes_based_on_scan_telemetry(telemetry_json):
                                  {"$set": {"os.version": scan_os["version"]}},
                                  upsert=False)
         label = NodeService.get_label_for_endpoint(node["_id"])
-        EdgeService.update_label(edge, node["_id"], label)
+        edge.update_label(node["_id"], label)