Agent: Fix return types for run_pba in puppets and master

This commit is contained in:
Shreya Malviya 2022-03-29 18:38:25 +05:30
parent 314bc49d1c
commit 8d4c29fc06
3 changed files with 8 additions and 8 deletions

View File

@ -59,12 +59,12 @@ class IPuppet(metaclass=abc.ABCMeta):
""" """
@abc.abstractmethod @abc.abstractmethod
def run_pba(self, name: str, options: Dict) -> PostBreachData: def run_pba(self, name: str, options: Dict) -> Iterable[PostBreachData]:
""" """
Runs a post-breach action (PBA) Runs a post-breach action (PBA)
:param str name: The name of the post-breach action to run :param str name: The name of the post-breach action to run
:param Dict options: A dictionary containing options that modify the behavior of the PBA :param Dict options: A dictionary containing options that modify the behavior of the PBA
:rtype: PostBreachData :rtype: Iterable[PostBreachData]
""" """
@abc.abstractmethod @abc.abstractmethod

View File

@ -1,6 +1,6 @@
import logging import logging
import threading import threading
from typing import Dict, List, Sequence from typing import Dict, Iterable, List, Sequence
from infection_monkey.credential_collectors import LMHash, Password, SSHKeypair, Username from infection_monkey.credential_collectors import LMHash, Password, SSHKeypair, Username
from infection_monkey.i_puppet import ( from infection_monkey.i_puppet import (
@ -49,13 +49,13 @@ class MockPuppet(IPuppet):
return [] return []
def run_pba(self, name: str, options: Dict) -> PostBreachData: def run_pba(self, name: str, options: Dict) -> Iterable[PostBreachData]:
logger.debug(f"run_pba({name}, {options})") logger.debug(f"run_pba({name}, {options})")
if name == "AccountDiscovery": if name == "AccountDiscovery":
yield PostBreachData(name, "pba command 1", ["pba result 1", True]) return [PostBreachData(name, "pba command 1", ["pba result 1", True])]
else: else:
yield PostBreachData(name, "pba command 2", ["pba result 2", False]) return [PostBreachData(name, "pba command 2", ["pba result 2", False])]
def ping(self, host: str, timeout: float = 1) -> PingScanData: def ping(self, host: str, timeout: float = 1) -> PingScanData:
logger.debug(f"run_ping({host}, {timeout})") logger.debug(f"run_ping({host}, {timeout})")

View File

@ -1,6 +1,6 @@
import logging import logging
import threading import threading
from typing import Dict, List, Sequence from typing import Dict, Iterable, List, Sequence
from common.common_consts.timeouts import CONNECTION_TIMEOUT from common.common_consts.timeouts import CONNECTION_TIMEOUT
from infection_monkey import network_scanning from infection_monkey import network_scanning
@ -36,7 +36,7 @@ class Puppet(IPuppet):
) )
return credential_collector.collect_credentials(options) return credential_collector.collect_credentials(options)
def run_pba(self, name: str, options: Dict) -> PostBreachData: def run_pba(self, name: str, options: Dict) -> Iterable[PostBreachData]:
return self._mock_puppet.run_pba(name, options) return self._mock_puppet.run_pba(name, options)
def ping(self, host: str, timeout: float = CONNECTION_TIMEOUT) -> PingScanData: def ping(self, host: str, timeout: float = CONNECTION_TIMEOUT) -> PingScanData: