diff --git a/monkey/monkey_island/cc/services/edge/edge.py b/monkey/monkey_island/cc/services/edge/edge.py index ca0c34731..dae19ccd8 100644 --- a/monkey/monkey_island/cc/services/edge/edge.py +++ b/monkey/monkey_island/cc/services/edge/edge.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import threading from typing import Dict, List from bson import ObjectId @@ -10,6 +11,8 @@ from monkey_island.cc.models.edge import Edge RIGHT_ARROW = "\u2192" +lock = threading.Lock() + class EdgeService(Edge): @staticmethod @@ -18,16 +21,17 @@ class EdgeService(Edge): @staticmethod def get_or_create_edge(src_node_id, dst_node_id, src_label, dst_label) -> EdgeService: - edge = None - try: - edge = EdgeService.objects.get(src_node_id=src_node_id, dst_node_id=dst_node_id) - except DoesNotExist: - edge = EdgeService(src_node_id=src_node_id, dst_node_id=dst_node_id) - finally: - if edge: - edge.update_label(node_id=src_node_id, label=src_label) - edge.update_label(node_id=dst_node_id, label=dst_label) - return edge + with lock: + edge = None + try: + edge = EdgeService.objects.get(src_node_id=src_node_id, dst_node_id=dst_node_id) + except DoesNotExist: + edge = EdgeService(src_node_id=src_node_id, dst_node_id=dst_node_id) + finally: + if edge: + edge.update_label(node_id=src_node_id, label=src_label) + edge.update_label(node_id=dst_node_id, label=dst_label) + return edge @staticmethod def get_by_dst_node(dst_node_id: ObjectId) -> List[EdgeService]: