Agent: Fix return types for run_pba in puppets and master

This commit is contained in:
Shreya Malviya 2022-03-29 18:38:25 +05:30
parent 314bc49d1c
commit 8d4c29fc06
3 changed files with 8 additions and 8 deletions

View File

@ -59,12 +59,12 @@ class IPuppet(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def run_pba(self, name: str, options: Dict) -> PostBreachData:
def run_pba(self, name: str, options: Dict) -> Iterable[PostBreachData]:
"""
Runs a post-breach action (PBA)
:param str name: The name of the post-breach action to run
:param Dict options: A dictionary containing options that modify the behavior of the PBA
:rtype: PostBreachData
:rtype: Iterable[PostBreachData]
"""
@abc.abstractmethod

View File

@ -1,6 +1,6 @@
import logging
import threading
from typing import Dict, List, Sequence
from typing import Dict, Iterable, List, Sequence
from infection_monkey.credential_collectors import LMHash, Password, SSHKeypair, Username
from infection_monkey.i_puppet import (
@ -49,13 +49,13 @@ class MockPuppet(IPuppet):
return []
def run_pba(self, name: str, options: Dict) -> PostBreachData:
def run_pba(self, name: str, options: Dict) -> Iterable[PostBreachData]:
logger.debug(f"run_pba({name}, {options})")
if name == "AccountDiscovery":
yield PostBreachData(name, "pba command 1", ["pba result 1", True])
return [PostBreachData(name, "pba command 1", ["pba result 1", True])]
else:
yield PostBreachData(name, "pba command 2", ["pba result 2", False])
return [PostBreachData(name, "pba command 2", ["pba result 2", False])]
def ping(self, host: str, timeout: float = 1) -> PingScanData:
logger.debug(f"run_ping({host}, {timeout})")

View File

@ -1,6 +1,6 @@
import logging
import threading
from typing import Dict, List, Sequence
from typing import Dict, Iterable, List, Sequence
from common.common_consts.timeouts import CONNECTION_TIMEOUT
from infection_monkey import network_scanning
@ -36,7 +36,7 @@ class Puppet(IPuppet):
)
return credential_collector.collect_credentials(options)
def run_pba(self, name: str, options: Dict) -> PostBreachData:
def run_pba(self, name: str, options: Dict) -> Iterable[PostBreachData]:
return self._mock_puppet.run_pba(name, options)
def ping(self, host: str, timeout: float = CONNECTION_TIMEOUT) -> PingScanData: