monkey/deployment_scripts/dump_attack_mitigations/dump_attack_mitigations.py

143 lines
4.8 KiB
Python
Raw Normal View History

import argparse
from pathlib import Path
from typing import Dict, List

import mongoengine
import pymongo
from bson import json_util
from stix2 import AttackPattern, CourseOfAction, FileSystemSource, Filter, Relationship

from attack_mitigations import AttackMitigations
# Name of the MongoDB collection this script drops, re-creates from the CTI data,
# and finally dumps to the JSON output file.
COLLECTION_NAME = "attack_mitigations"
def main():
    """Rebuild the attack mitigations collection from a CTI repository and dump it to JSON.

    Steps: parse the command line, open the mongoengine default connection and a
    pymongo client for the same database, drop any existing collection, repopulate it
    from the STIX data under ``--cti-repo``, and write its contents to
    ``--dump-file-path``.
    """
    args = parse_args()
    # The mongoengine connection is separate from the pymongo client below;
    # NOTE(review): presumably it backs AttackMitigations.save() — confirm.
    set_default_mongo_connection(args.database_name, args.mongo_host, args.mongo_port)

    mongo_client = pymongo.MongoClient(host=args.mongo_host, port=args.mongo_port)
    try:
        database = mongo_client.get_database(args.database_name)
        clean_collection(database)
        populate_attack_mitigations(database, Path(args.cti_repo))
        dump_attack_mitigations(database, Path(args.dump_file_path))
    finally:
        # Release the client's connection pool even when one of the steps fails.
        mongo_client.close()
def parse_args(argv=None):
    """Parse this script's command-line options.

    :param argv: Arguments to parse, excluding the program name. ``None`` (the
        default) makes argparse read ``sys.argv[1:]``, matching the old behavior.
    :return: Namespace with ``mongo_host``, ``mongo_port``, ``database_name``,
        ``cti_repo`` and ``dump_file_path`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Export attack mitigations from a database",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # "--mongo-host" matches the hyphenated spelling of every other option; the
    # historical "--mongo_host" is kept as an alias so existing invocations still work.
    parser.add_argument(
        "--mongo-host",
        "--mongo_host",
        dest="mongo_host",
        default="localhost",
        help="URL for mongo database.",
        required=False,
    )
    parser.add_argument(
        "--mongo-port",
        action="store",
        default=27017,
        type=int,
        help="Port for mongo database.",
        required=False,
    )
    parser.add_argument(
        "--database-name",
        action="store",
        default="monkeyisland",
        help="Database name inside of mongo.",
        required=False,
    )
    # No default: the option is required, so a default could never take effect and
    # would only be shown misleadingly by ArgumentDefaultsHelpFormatter in --help.
    parser.add_argument(
        "--cti-repo",
        action="store",
        help="The path to the Cyber Threat Intelligence Repository.",
        required=True,
    )
    parser.add_argument(
        "--dump-file-path",
        action="store",
        default="./attack_mitigations.json",
        help="A file path where the database dump will be saved.",
        required=False,
    )
    return parser.parse_args(argv)
def set_default_mongo_connection(database_name: str, host: str, port: int):
    """Open mongoengine's default connection to *database_name* on *host*:*port*.

    NOTE(review): presumably this is the connection used by AttackMitigations.save().
    """
    connection_options = {"db": database_name, "host": host, "port": port}
    mongoengine.connect(**connection_options)
def clean_collection(database: pymongo.database.Database):
    """Drop COLLECTION_NAME from *database* so it can be rebuilt from scratch.

    Does nothing when the collection is not present.
    """
    if not collection_exists(database, COLLECTION_NAME):
        return
    database.drop_collection(COLLECTION_NAME)
def collection_exists(database: pymongo.database.Database, collection_name: str) -> bool:
    """Return True if *database* lists a collection named *collection_name*."""
    existing_names = database.list_collection_names()
    return collection_name in existing_names
def populate_attack_mitigations(database: pymongo.database.Database, cti_repo: Path):
    """Create COLLECTION_NAME in *database* and fill it from the CTI repository.

    Reads attack-pattern, course-of-action and "mitigates" relationship objects from
    ``<cti_repo>/enterprise-attack``, attaches each mitigation to the technique it
    targets, and saves every object produced by
    ``AttackMitigations.dict_from_stix2_attack_patterns``.

    :param database: Database in which the collection is created.
    :param cti_repo: Root directory of a local CTI repository checkout.
    """
    database.create_collection(COLLECTION_NAME)
    attack_data_path = cti_repo / "enterprise-attack"

    stix2_mitigations = get_all_mitigations(attack_data_path)
    mongo_mitigations = AttackMitigations.dict_from_stix2_attack_patterns(
        get_all_attack_techniques(attack_data_path)
    )
    mitigation_technique_relationships = get_technique_and_mitigation_relationships(
        attack_data_path
    )

    # Two separate passes on purpose: every add_mitigation() call completes before any
    # add_no_mitigations_info() call. NOTE(review): presumably the latter inspects what
    # the first pass attached — confirm before merging these loops.
    for relationship in mitigation_technique_relationships:
        mongo_mitigations[relationship["target_ref"]].add_mitigation(
            stix2_mitigations[relationship["source_ref"]]
        )
    for relationship in mitigation_technique_relationships:
        mongo_mitigations[relationship["target_ref"]].add_no_mitigations_info(
            stix2_mitigations[relationship["source_ref"]]
        )

    # Only the documents are persisted; their STIX-id keys are not needed here.
    for mongo_object in mongo_mitigations.values():
        mongo_object.save()
def get_all_mitigations(attack_data_path: Path) -> Dict[str, CourseOfAction]:
    """Load every course-of-action object under *attack_data_path*, keyed by its STIX id."""
    stix_source = FileSystemSource(attack_data_path)
    courses_of_action = stix_source.query([Filter("type", "=", "course-of-action")])
    return {course["id"]: course for course in courses_of_action}
def get_all_attack_techniques(attack_data_path: Path) -> Dict[str, AttackPattern]:
    """Load every attack-pattern object under *attack_data_path*, keyed by its STIX id."""
    attack_patterns = FileSystemSource(attack_data_path).query(
        [Filter("type", "=", "attack-pattern")]
    )
    techniques_by_id = {}
    for attack_pattern in attack_patterns:
        techniques_by_id[attack_pattern["id"]] = attack_pattern
    return techniques_by_id
def get_technique_and_mitigation_relationships(attack_data_path: Path) -> List[Relationship]:
    """Return every "mitigates" relationship object found under *attack_data_path*.

    Callers use ``source_ref`` as the id of a course-of-action (mitigation) and
    ``target_ref`` as the id of the attack pattern it mitigates.
    """
    file_system = FileSystemSource(attack_data_path)
    relationship_filter = [
        Filter("type", "=", "relationship"),
        Filter("relationship_type", "=", "mitigates"),
    ]
    return file_system.query(relationship_filter)
def dump_attack_mitigations(database: pymongo.database.Database, dump_file_path: Path):
    """Write the whole COLLECTION_NAME collection to *dump_file_path* as JSON.

    Documents are serialized with ``bson.json_util`` and written UTF-8 encoded.

    :raises Exception: if COLLECTION_NAME does not exist in *database*.
    """
    if not collection_exists(database, COLLECTION_NAME):
        raise Exception(f"Could not find collection: {COLLECTION_NAME}")

    collection = database.get_collection(COLLECTION_NAME)
    # Serialize before opening the output file: find() is only consumed by dumps(),
    # so a database error here would otherwise leave a truncated/empty dump behind.
    serialized_contents = json_util.dumps(collection.find())
    with open(dump_file_path, "wb") as jsonfile:
        jsonfile.write(serialized_contents.encode())
# Run the rebuild-and-dump pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()