Agent: Fix mypy errors related to puppet

This commit is contained in:
vakarisz 2022-09-23 10:45:23 +03:00
parent 0d08ce467e
commit 978daf973b
3 changed files with 11 additions and 9 deletions

View File

@ -3,7 +3,7 @@ import threading
from collections import namedtuple from collections import namedtuple
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Dict, Iterable, List, Mapping, Sequence from typing import Dict, Iterable, Mapping, Optional, Sequence
from common.credentials import Credentials from common.credentials import Credentials
from infection_monkey.model import VictimHost from infection_monkey.model import VictimHost
@ -26,8 +26,8 @@ class ExploiterResultData:
propagation_success: bool = False propagation_success: bool = False
interrupted: bool = False interrupted: bool = False
os: str = "" os: str = ""
info: Mapping = None info: Optional[Mapping] = None
attempts: Iterable = None attempts: Optional[Iterable] = None
error_message: str = "" error_message: str = ""
@ -83,7 +83,7 @@ class IPuppet(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
def scan_tcp_ports( def scan_tcp_ports(
self, host: str, ports: List[int], timeout: float = 3 self, host: str, ports: Sequence[int], timeout: float = 3
) -> Dict[int, PortScanData]: ) -> Dict[int, PortScanData]:
""" """
Scans a list of TCP ports on a remote host Scans a list of TCP ports on a remote host
@ -125,6 +125,7 @@ class IPuppet(metaclass=abc.ABCMeta):
name: str, name: str,
host: VictimHost, host: VictimHost,
current_depth: int, current_depth: int,
servers: Sequence[str],
options: Dict, options: Dict,
interrupt: threading.Event, interrupt: threading.Event,
) -> ExploiterResultData: ) -> ExploiterResultData:
@ -134,6 +135,7 @@ class IPuppet(metaclass=abc.ABCMeta):
:param str name: The name of the exploiter to run :param str name: The name of the exploiter to run
:param VictimHost host: A VictimHost object representing the target to exploit :param VictimHost host: A VictimHost object representing the target to exploit
:param int current_depth: The current propagation depth :param int current_depth: The current propagation depth
:param servers: List of socket addresses for victim to connect back to
:param Dict options: A dictionary containing options that modify the behavior of the :param Dict options: A dictionary containing options that modify the behavior of the
exploiter exploiter
:param threading.Event interrupt: A threading.Event object that signals the exploit to stop :param threading.Event interrupt: A threading.Event object that signals the exploit to stop

View File

@ -3,7 +3,7 @@ import select
import socket import socket
import time import time
from pprint import pformat from pprint import pformat
from typing import Collection, Iterable, Mapping, Tuple from typing import Collection, Dict, Iterable, Mapping, Tuple
from common.utils import Timer from common.utils import Timer
from infection_monkey.i_puppet import PortScanData, PortStatus from infection_monkey.i_puppet import PortScanData, PortStatus
@ -17,7 +17,7 @@ EMPTY_PORT_SCAN = {-1: PortScanData(-1, PortStatus.CLOSED, None, None)}
def scan_tcp_ports( def scan_tcp_ports(
host: str, ports_to_scan: Collection[int], timeout: float host: str, ports_to_scan: Collection[int], timeout: float
) -> Mapping[int, PortScanData]: ) -> Dict[int, PortScanData]:
try: try:
return _scan_tcp_ports(host, ports_to_scan, timeout) return _scan_tcp_ports(host, ports_to_scan, timeout)
except Exception: except Exception:

View File

@ -1,6 +1,6 @@
import logging import logging
import threading import threading
from typing import Dict, Iterable, List, Sequence from typing import Dict, Iterable, Sequence
from common.common_consts.timeouts import CONNECTION_TIMEOUT from common.common_consts.timeouts import CONNECTION_TIMEOUT
from common.credentials import Credentials from common.credentials import Credentials
@ -18,7 +18,7 @@ from infection_monkey.model import VictimHost
from .plugin_registry import PluginRegistry from .plugin_registry import PluginRegistry
EMPTY_FINGERPRINT = PingScanData(False, None) EMPTY_FINGERPRINT = FingerprintData(None, None, [])
logger = logging.getLogger() logger = logging.getLogger()
@ -44,7 +44,7 @@ class Puppet(IPuppet):
return network_scanning.ping(host, timeout) return network_scanning.ping(host, timeout)
def scan_tcp_ports( def scan_tcp_ports(
self, host: str, ports: List[int], timeout: float = CONNECTION_TIMEOUT self, host: str, ports: Sequence[int], timeout: float = CONNECTION_TIMEOUT
) -> Dict[int, PortScanData]: ) -> Dict[int, PortScanData]:
return network_scanning.scan_tcp_ports(host, ports, timeout) return network_scanning.scan_tcp_ports(host, ports, timeout)