Deployment: Import ATT&CK data into mongo

This commit is contained in:
Mike Salvatore 2021-09-28 14:18:58 -04:00
parent 82c8385863
commit 6de33bfd57
2 changed files with 167 additions and 23 deletions

View File

@ -0,0 +1,65 @@
from typing import Dict
from mongoengine import Document, EmbeddedDocument, EmbeddedDocumentField, ListField, StringField
from stix2 import AttackPattern, CourseOfAction
class Mitigation(EmbeddedDocument):
name = StringField(required=True)
description = StringField(required=True)
url = StringField()
@staticmethod
def get_from_stix2_data(mitigation: CourseOfAction):
name = mitigation["name"]
description = mitigation["description"]
url = get_stix2_external_reference_url(mitigation)
return Mitigation(name=name, description=description, url=url)
class AttackMitigations(Document):
technique_id = StringField(required=True, primary_key=True)
mitigations = ListField(EmbeddedDocumentField("Mitigation"))
def add_mitigation(self, mitigation: CourseOfAction):
mitigation_external_ref_id = get_stix2_external_reference_id(mitigation)
if mitigation_external_ref_id.startswith("M"):
self.mitigations.append(Mitigation.get_from_stix2_data(mitigation))
def add_no_mitigations_info(self, mitigation: CourseOfAction):
mitigation_external_ref_id = get_stix2_external_reference_id(mitigation)
if mitigation_external_ref_id.startswith("T") and len(self.mitigations) == 0:
mitigation_mongo_object = Mitigation.get_from_stix2_data(mitigation)
mitigation_mongo_object["description"] = mitigation_mongo_object[
"description"
].splitlines()[0]
mitigation_mongo_object["url"] = ""
self.mitigations.append(mitigation_mongo_object)
@staticmethod
def dict_from_stix2_attack_patterns(stix2_dict: Dict[str, AttackPattern]):
return {
key: AttackMitigations.mitigations_from_attack_pattern(attack_pattern)
for key, attack_pattern in stix2_dict.items()
}
@staticmethod
def mitigations_from_attack_pattern(attack_pattern: AttackPattern):
return AttackMitigations(
technique_id=get_stix2_external_reference_id(attack_pattern),
mitigations=[],
)
def get_stix2_external_reference_url(stix2_data) -> str:
for reference in stix2_data["external_references"]:
if "url" in reference:
return reference["url"]
return ""
def get_stix2_external_reference_id(stix2_data) -> str:
for reference in stix2_data["external_references"]:
if reference["source_name"] == "mitre-attack" and "external_id" in reference:
return reference["external_id"]
return ""

View File

@ -1,62 +1,141 @@
import argparse import argparse
from pathlib import Path
from typing import Dict, List
import mongoengine
import pymongo import pymongo
from attack_mitigations import AttackMitigations
from bson import json_util
from stix2 import AttackPattern, CourseOfAction, FileSystemSource, Filter
COLLECTION_NAME = "attack_mitigations"
def main(): def main():
args = parse_args() args = parse_args()
mongodb = connect_to_mongo(args.mongo_host, args.mongo_port, args.database_name)
clean_collection(mongodb, args.collection_name) set_default_mongo_connection(args.database_name, args.mongo_host, args.mongo_port)
mongo_client = pymongo.MongoClient(host=args.mongo_host, port=args.mongo_port)
database = mongo_client.get_database(args.database_name)
clean_collection(database)
populate_attack_mitigations(database, Path(args.cti_repo))
dump_attack_mitigations(database, Path(args.dump_file_path))
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="Export attack mitigations from a database") parser = argparse.ArgumentParser(
parser.add_argument( description="Export attack mitigations from a database",
"-host", "--mongo_host", default="localhost", help="URL for mongo database.", required=False formatter_class=argparse.ArgumentDefaultsHelpFormatter,
) )
parser.add_argument( parser.add_argument(
"-port", "--mongo_host", default="localhost", help="URL for mongo database.", required=False
"--mongo_port", )
parser.add_argument(
"--mongo-port",
action="store", action="store",
default=27017, default=27017,
type=int, type=int,
help="Port for mongo database. Default 27017", help="Port for mongo database.",
required=False, required=False,
) )
parser.add_argument( parser.add_argument(
"-db", "--database-name",
"--database_name",
action="store", action="store",
default="monkeyisland", default="monkeyisland",
help="Database name inside of mongo.", help="Database name inside of mongo.",
required=False, required=False,
) )
parser.add_argument( parser.add_argument(
"-cn", "--cti-repo",
"--collection_name",
action="store", action="store",
default="attack_mitigations", default="attack_mitigations",
help="Which collection are we going to export", help="The path to the Cyber Threat Intelligence Repository.",
required=True,
)
parser.add_argument(
"--dump-file-path",
action="store",
default="./attack_mitigations.json",
help="A file path where the database dump will be saved.",
required=False, required=False,
) )
return parser.parse_args() return parser.parse_args()
def connect_to_mongo(mongo_host: str, mongo_port: int, database_name: str) -> pymongo.MongoClient: def set_default_mongo_connection(database_name: str, host: str, port: int):
client = pymongo.MongoClient(host=mongo_host, port=mongo_port) mongoengine.connect(db=database_name, host=host, port=port)
database = client.get_database(database_name)
return database
def clean_collection(mongodb: pymongo.MongoClient, collection_name: str): def clean_collection(database: pymongo.database.Database):
if collection_exists(mongodb, collection_name): if collection_exists(database, COLLECTION_NAME):
mongodb.drop_collection(collection_name) database.drop_collection(COLLECTION_NAME)
def collection_exists(mongodb: pymongo.MongoClient, collection_name: str) -> bool: def collection_exists(database: pymongo.database.Database, collection_name: str) -> bool:
collections = mongodb.list_collection_names() return collection_name in database.list_collection_names()
return collection_name in collections
def populate_attack_mitigations(database: pymongo.database.Database, cti_repo: Path):
database.create_collection(COLLECTION_NAME)
attack_data_path = cti_repo / "enterprise-attack"
stix2_mitigations = get_all_mitigations(attack_data_path)
mongo_mitigations = AttackMitigations.dict_from_stix2_attack_patterns(
get_all_attack_techniques(attack_data_path)
)
mitigation_technique_relationships = get_technique_and_mitigation_relationships(
attack_data_path
)
for relationship in mitigation_technique_relationships:
mongo_mitigations[relationship["target_ref"]].add_mitigation(
stix2_mitigations[relationship["source_ref"]]
)
for relationship in mitigation_technique_relationships:
mongo_mitigations[relationship["target_ref"]].add_no_mitigations_info(
stix2_mitigations[relationship["source_ref"]]
)
for key, mongo_object in mongo_mitigations.items():
mongo_object.save()
def get_all_mitigations(attack_data_path: Path) -> Dict[str, CourseOfAction]:
file_system = FileSystemSource(attack_data_path)
mitigation_filter = [Filter("type", "=", "course-of-action")]
all_mitigations = file_system.query(mitigation_filter)
all_mitigations = {mitigation["id"]: mitigation for mitigation in all_mitigations}
return all_mitigations
def get_all_attack_techniques(attack_data_path: Path) -> Dict[str, AttackPattern]:
file_system = FileSystemSource(attack_data_path)
technique_filter = [Filter("type", "=", "attack-pattern")]
all_techniques = file_system.query(technique_filter)
all_techniques = {technique["id"]: technique for technique in all_techniques}
return all_techniques
def get_technique_and_mitigation_relationships(attack_data_path: Path) -> List[CourseOfAction]:
file_system = FileSystemSource(attack_data_path)
technique_filter = [
Filter("type", "=", "relationship"),
Filter("relationship_type", "=", "mitigates"),
]
all_techniques = file_system.query(technique_filter)
return all_techniques
def dump_attack_mitigations(database: pymongo.database.Database, dump_file_path: Path):
if not collection_exists(database, COLLECTION_NAME):
raise Exception(f"Could not find collection: {COLLECTION_NAME}")
collection = database.get_collection(COLLECTION_NAME)
collection_contents = collection.find()
with open(dump_file_path, "wb") as jsonfile:
jsonfile.write(json_util.dumps(collection_contents).encode())
if __name__ == "__main__": if __name__ == "__main__":