From 6de33bfd572045ffaef822b4c9eee4d35dfcec9c Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Tue, 28 Sep 2021 14:18:58 -0400 Subject: [PATCH] Deployment: Import ATT&CK data into mongo --- deployment_scripts/attack_mitigations.py | 65 +++++++++ deployment_scripts/dump_attack_mitigations.py | 125 ++++++++++++++---- 2 files changed, 167 insertions(+), 23 deletions(-) create mode 100644 deployment_scripts/attack_mitigations.py diff --git a/deployment_scripts/attack_mitigations.py b/deployment_scripts/attack_mitigations.py new file mode 100644 index 000000000..95e3a09e6 --- /dev/null +++ b/deployment_scripts/attack_mitigations.py @@ -0,0 +1,65 @@ +from typing import Dict + +from mongoengine import Document, EmbeddedDocument, EmbeddedDocumentField, ListField, StringField +from stix2 import AttackPattern, CourseOfAction + + +class Mitigation(EmbeddedDocument): + name = StringField(required=True) + description = StringField(required=True) + url = StringField() + + @staticmethod + def get_from_stix2_data(mitigation: CourseOfAction): + name = mitigation["name"] + description = mitigation["description"] + url = get_stix2_external_reference_url(mitigation) + return Mitigation(name=name, description=description, url=url) + + +class AttackMitigations(Document): + technique_id = StringField(required=True, primary_key=True) + mitigations = ListField(EmbeddedDocumentField("Mitigation")) + + def add_mitigation(self, mitigation: CourseOfAction): + mitigation_external_ref_id = get_stix2_external_reference_id(mitigation) + if mitigation_external_ref_id.startswith("M"): + self.mitigations.append(Mitigation.get_from_stix2_data(mitigation)) + + def add_no_mitigations_info(self, mitigation: CourseOfAction): + mitigation_external_ref_id = get_stix2_external_reference_id(mitigation) + if mitigation_external_ref_id.startswith("T") and len(self.mitigations) == 0: + mitigation_mongo_object = Mitigation.get_from_stix2_data(mitigation) + mitigation_mongo_object["description"] = mitigation_mongo_object[ + "description" + ].splitlines()[0] + mitigation_mongo_object["url"] = "" + self.mitigations.append(mitigation_mongo_object) + + @staticmethod + def dict_from_stix2_attack_patterns(stix2_dict: Dict[str, AttackPattern]): + return { + key: AttackMitigations.mitigations_from_attack_pattern(attack_pattern) + for key, attack_pattern in stix2_dict.items() + } + + @staticmethod + def mitigations_from_attack_pattern(attack_pattern: AttackPattern): + return AttackMitigations( + technique_id=get_stix2_external_reference_id(attack_pattern), + mitigations=[], + ) + + +def get_stix2_external_reference_url(stix2_data) -> str: + for reference in stix2_data["external_references"]: + if "url" in reference: + return reference["url"] + return "" + + +def get_stix2_external_reference_id(stix2_data) -> str: + for reference in stix2_data["external_references"]: + if reference["source_name"] == "mitre-attack" and "external_id" in reference: + return reference["external_id"] + return "" diff --git a/deployment_scripts/dump_attack_mitigations.py b/deployment_scripts/dump_attack_mitigations.py index 573d54d9d..a8c164ca5 100755 --- a/deployment_scripts/dump_attack_mitigations.py +++ b/deployment_scripts/dump_attack_mitigations.py @@ -1,62 +1,141 @@ import argparse +from pathlib import Path +from typing import Dict, List +import mongoengine import pymongo +from attack_mitigations import AttackMitigations +from bson import json_util +from stix2 import AttackPattern, CourseOfAction, FileSystemSource, Filter + +COLLECTION_NAME = "attack_mitigations" def main(): args = parse_args() - mongodb = connect_to_mongo(args.mongo_host, args.mongo_port, args.database_name) - clean_collection(mongodb, args.collection_name) + set_default_mongo_connection(args.database_name, args.mongo_host, args.mongo_port) + + mongo_client = pymongo.MongoClient(host=args.mongo_host, port=args.mongo_port) + database = mongo_client.get_database(args.database_name) + + clean_collection(database) + populate_attack_mitigations(database, Path(args.cti_repo)) + dump_attack_mitigations(database, Path(args.dump_file_path)) def parse_args(): - parser = argparse.ArgumentParser(description="Export attack mitigations from a database") - parser.add_argument( - "-host", "--mongo_host", default="localhost", help="URL for mongo database.", required=False + parser = argparse.ArgumentParser( + description="Export attack mitigations from a database", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( - "-port", - "--mongo_port", + "--mongo_host", default="localhost", help="URL for mongo database.", required=False + ) + parser.add_argument( + "--mongo-port", action="store", default=27017, type=int, - help="Port for mongo database. Default 27017", + help="Port for mongo database.", required=False, ) parser.add_argument( - "-db", - "--database_name", + "--database-name", action="store", default="monkeyisland", help="Database name inside of mongo.", required=False, ) parser.add_argument( - "-cn", - "--collection_name", + "--cti-repo", action="store", default="attack_mitigations", - help="Which collection are we going to export", + help="The path to the Cyber Threat Intelligence Repository.", + required=True, + ) + parser.add_argument( + "--dump-file-path", + action="store", + default="./attack_mitigations.json", + help="A file path where the database dump will be saved.", required=False, ) + return parser.parse_args() -def connect_to_mongo(mongo_host: str, mongo_port: int, database_name: str) -> pymongo.MongoClient: - client = pymongo.MongoClient(host=mongo_host, port=mongo_port) - database = client.get_database(database_name) - return database +def set_default_mongo_connection(database_name: str, host: str, port: int): + mongoengine.connect(db=database_name, host=host, port=port) -def clean_collection(mongodb: pymongo.MongoClient, collection_name: str): - if collection_exists(mongodb, collection_name): - mongodb.drop_collection(collection_name) +def clean_collection(database: pymongo.database.Database): + if collection_exists(database, COLLECTION_NAME): + database.drop_collection(COLLECTION_NAME) -def collection_exists(mongodb: pymongo.MongoClient, collection_name: str) -> bool: - collections = mongodb.list_collection_names() - return collection_name in collections +def collection_exists(database: pymongo.database.Database, collection_name: str) -> bool: + return collection_name in database.list_collection_names() + + +def populate_attack_mitigations(database: pymongo.database.Database, cti_repo: Path): + database.create_collection(COLLECTION_NAME) + attack_data_path = cti_repo / "enterprise-attack" + + stix2_mitigations = get_all_mitigations(attack_data_path) + mongo_mitigations = AttackMitigations.dict_from_stix2_attack_patterns( + get_all_attack_techniques(attack_data_path) + ) + mitigation_technique_relationships = get_technique_and_mitigation_relationships( + attack_data_path + ) + for relationship in mitigation_technique_relationships: + mongo_mitigations[relationship["target_ref"]].add_mitigation( + stix2_mitigations[relationship["source_ref"]] + ) + for relationship in mitigation_technique_relationships: + mongo_mitigations[relationship["target_ref"]].add_no_mitigations_info( + stix2_mitigations[relationship["source_ref"]] + ) + for key, mongo_object in mongo_mitigations.items(): + mongo_object.save() + + +def get_all_mitigations(attack_data_path: Path) -> Dict[str, CourseOfAction]: + file_system = FileSystemSource(attack_data_path) + mitigation_filter = [Filter("type", "=", "course-of-action")] + all_mitigations = file_system.query(mitigation_filter) + all_mitigations = {mitigation["id"]: mitigation for mitigation in all_mitigations} + return all_mitigations + + +def get_all_attack_techniques(attack_data_path: Path) -> Dict[str, AttackPattern]: + file_system = FileSystemSource(attack_data_path) + technique_filter = [Filter("type", "=", "attack-pattern")] + all_techniques = file_system.query(technique_filter) + all_techniques = {technique["id"]: technique for technique in all_techniques} + return all_techniques + + +def get_technique_and_mitigation_relationships(attack_data_path: Path) -> List[CourseOfAction]: + file_system = FileSystemSource(attack_data_path) + technique_filter = [ + Filter("type", "=", "relationship"), + Filter("relationship_type", "=", "mitigates"), + ] + all_techniques = file_system.query(technique_filter) + return all_techniques + + +def dump_attack_mitigations(database: pymongo.database.Database, dump_file_path: Path): + if not collection_exists(database, COLLECTION_NAME): + raise Exception(f"Could not find collection: {COLLECTION_NAME}") + + collection = database.get_collection(COLLECTION_NAME) + collection_contents = collection.find() + + with open(dump_file_path, "wb") as jsonfile: + jsonfile.write(json_util.dumps(collection_contents).encode()) if __name__ == "__main__":