diff --git a/deployment_scripts/dump_attack_mitigations/dump_attack_mitigations.py b/deployment_scripts/dump_attack_mitigations/dump_attack_mitigations.py index a8c164ca5..c8e2b064a 100755 --- a/deployment_scripts/dump_attack_mitigations/dump_attack_mitigations.py +++ b/deployment_scripts/dump_attack_mitigations/dump_attack_mitigations.py @@ -1,4 +1,7 @@ import argparse +import json +import subprocess +import time from pathlib import Path from typing import Dict, List @@ -21,7 +24,7 @@ def main(): clean_collection(database) populate_attack_mitigations(database, Path(args.cti_repo)) - dump_attack_mitigations(database, Path(args.dump_file_path)) + dump_attack_mitigations(database, Path(args.cti_repo), Path(args.dump_file_path)) def parse_args(): @@ -127,15 +130,54 @@ def get_technique_and_mitigation_relationships(attack_data_path: Path) -> List[C return all_techniques -def dump_attack_mitigations(database: pymongo.database.Database, dump_file_path: Path): +def dump_attack_mitigations( + database: pymongo.database.Database, cti_repo: Path, dump_file_path: Path +): if not collection_exists(database, COLLECTION_NAME): raise Exception(f"Could not find collection: {COLLECTION_NAME}") + metadata = get_metadata(cti_repo) + data = get_data_from_database(database) + + json_output = f'{{"metadata":{json.dumps(metadata)},"data":{json_util.dumps(data)}}}' + + with open(dump_file_path, "wb") as jsonfile: + jsonfile.write(json_output.encode()) + + +def get_metadata(cti_repo: Path) -> dict: + timestamp = str(time.time()) + commit_hash = get_commit_hash(cti_repo) + origin_url = get_origin_url(cti_repo) + + return {"timestamp": timestamp, "commit_hash": commit_hash, "origin_url": origin_url} + + +def get_commit_hash(cti_repo: Path) -> str: + return run_command(["git", "rev-parse", "--short", "HEAD"], cti_repo).strip() + + +def get_origin_url(cti_repo: Path) -> str: + return run_command(["git", "remote", "get-url", "origin"], cti_repo).strip() + + +def run_command(cmd: List, cwd: Path = None) -> str: + cp = subprocess.run(cmd, capture_output=True, cwd=cwd, encoding="utf-8") + + if cp.returncode != 0: + raise Exception( + f"Error running command -- Command: {cmd} -- Return Code: {cp.returncode} -- stderr: " + f"{cp.stderr}" + ) + + return cp.stdout + + +def get_data_from_database(database: pymongo.database.Database) -> pymongo.cursor.Cursor: collection = database.get_collection(COLLECTION_NAME) collection_contents = collection.find() - with open(dump_file_path, "wb") as jsonfile: - jsonfile.write(json_util.dumps(collection_contents).encode()) + return collection_contents if __name__ == "__main__":