diff --git a/chaos_monkey/network/firewall.py b/chaos_monkey/network/firewall.py index 950fdea7c..003da8613 100644 --- a/chaos_monkey/network/firewall.py +++ b/chaos_monkey/network/firewall.py @@ -15,16 +15,20 @@ class FirewallApp(object): def listen_allowed(self, **kwargs): return True - def __exit__(self): - self.close() + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() def close(self): return def _run_netsh_cmd(command, args): - cmd = subprocess.Popen("netsh %s %s" % (command, " ".join(['%s="%s"'%(key,value) for key,value in args.items()])), stdout=subprocess.PIPE) + cmd = subprocess.Popen("netsh %s %s" % (command, " ".join(['%s="%s"'%(key,value) for key,value in args.items() if value])), stdout=subprocess.PIPE) return cmd.stdout.read().strip().lower().endswith('ok.') + class WinAdvFirewall(FirewallApp): def __init__(self): self._rules = {} @@ -67,7 +71,7 @@ class WinAdvFirewall(FirewallApp): del self._rules[name] return True else: - return False + return False except: return None @@ -117,7 +121,8 @@ class WinFirewall(FirewallApp): netsh_args.update(kwargs) try: - if _run_netsh_cmd('firewall add', netsh_args): + if _run_netsh_cmd('firewall add %s' % rule, netsh_args): + netsh_args['rule'] = rule self._rules[name] = netsh_args return True else: @@ -125,13 +130,11 @@ class WinFirewall(FirewallApp): except: return None - def remove_firewall_rule(self, rule='allowedprogram', name="Firewall", **kwargs): - netsh_args = {'name': name, - 'mode' : mode, - 'program' : program} + def remove_firewall_rule(self, rule='allowedprogram', name="Firewall", mode="ENABLE", program=sys.executable, **kwargs): + netsh_args = {'program' : program} netsh_args.update(kwargs) try: - if _run_netsh_cmd('firewall delete', netsh_args): + if _run_netsh_cmd('firewall delete %s' % rule, netsh_args): if self._rules.has_key(name): del self._rules[name] return True @@ -146,16 +149,14 @@ class WinFirewall(FirewallApp): for rule in self._rules.values(): if rule.get('program') == sys.executable and \ - 'allowedprogram' == rule.get('rule') and \ - 'ENABLE' == rule.get('mode') and \ - 4 == len(rule.keys()): + 'ENABLE' == rule.get('mode'): return True return False def close(self): try: - for rule in self._rules.keys(): - _run_netsh_cmd('firewall delete', {'name' : rule}) + for rule in self._rules.values(): + self.remove_firewall_rule(**rule) except: pass