diff --git a/deployment_scripts/dump_attack_mitigations.py b/deployment_scripts/dump_attack_mitigations.py index 6ca92f8b7..573d54d9d 100755 --- a/deployment_scripts/dump_attack_mitigations.py +++ b/deployment_scripts/dump_attack_mitigations.py @@ -3,6 +3,13 @@ import argparse import pymongo +def main(): + args = parse_args() + mongodb = connect_to_mongo(args.mongo_host, args.mongo_port, args.database_name) + + clean_collection(mongodb, args.collection_name) + + def parse_args(): parser = argparse.ArgumentParser(description="Export attack mitigations from a database") parser.add_argument( @@ -42,21 +49,14 @@ def connect_to_mongo(mongo_host: str, mongo_port: int, database_name: str) -> py return database -def collection_exists(mongodb: pymongo.MongoClient, collection_name: str) -> bool: - collections = mongodb.list_collection_names() - return collection_name in collections - - def clean_collection(mongodb: pymongo.MongoClient, collection_name: str): if collection_exists(mongodb, collection_name): mongodb.drop_collection(collection_name) -def main(): - args = parse_args() - mongodb = connect_to_mongo(args.mongo_host, args.mongo_port, args.database_name) - - clean_collection(mongodb, args.collection_name) +def collection_exists(mongodb: pymongo.MongoClient, collection_name: str) -> bool: + collections = mongodb.list_collection_names() + return collection_name in collections if __name__ == "__main__":