Made envs an enum

This commit is contained in:
Shay Nehmad 2020-01-21 16:19:10 +02:00
parent db5c0f4786
commit 6f289915fc
7 changed files with 30 additions and 24 deletions

View File

@ -6,7 +6,7 @@ import logging
__author__ = 'itay.mizeretz' __author__ = 'itay.mizeretz'
from common.cloud.environment_names import AWS from common.cloud.environment_names import Environment
from common.cloud.instance import CloudInstance from common.cloud.instance import CloudInstance
AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS = "169.254.169.254" AWS_INSTANCE_METADATA_LOCAL_IP_ADDRESS = "169.254.169.254"
@ -23,8 +23,8 @@ class AwsInstance(CloudInstance):
def is_instance(self): def is_instance(self):
return self.instance_id is not None return self.instance_id is not None
def get_cloud_provider_name(self) -> str: def get_cloud_provider_name(self) -> Environment:
return AWS return Environment.AWS
def __init__(self): def __init__(self):
self.instance_id = None self.instance_id = None

View File

@ -1,7 +1,7 @@
import logging import logging
import requests import requests
from common.cloud.environment_names import AZURE from common.cloud.environment_names import Environment
from common.cloud.instance import CloudInstance from common.cloud.instance import CloudInstance
LATEST_AZURE_METADATA_API_VERSION = "2019-04-30" LATEST_AZURE_METADATA_API_VERSION = "2019-04-30"
@ -18,8 +18,8 @@ class AzureInstance(CloudInstance):
def is_instance(self): def is_instance(self):
return self.on_azure return self.on_azure
def get_cloud_provider_name(self) -> str: def get_cloud_provider_name(self) -> Environment:
return AZURE return Environment.AZURE
def __init__(self): def __init__(self):
""" """

View File

@ -1,5 +1,7 @@
# When adding a new environment to this file, make sure to add it to ALL_ENV_NAMES as well! from enum import Enum
class Environment(Enum):
UNKNOWN = "Unknown" UNKNOWN = "Unknown"
ON_PREMISE = "On Premise" ON_PREMISE = "On Premise"
AZURE = "Azure" AZURE = "Azure"
@ -9,4 +11,5 @@ ALIBABA = "Alibaba Cloud"
IBM = "IBM Cloud" IBM = "IBM Cloud"
DigitalOcean = "Digital Ocean" DigitalOcean = "Digital Ocean"
ALL_ENV_NAMES = [UNKNOWN, ON_PREMISE, AZURE, AWS, GCP, ALIBABA, IBM, DigitalOcean]
ALL_ENVIRONMENTS_NAMES = [x.value for x in Environment]

View File

@ -1,7 +1,7 @@
import logging import logging
import requests import requests
from common.cloud.environment_names import GCP from common.cloud.environment_names import Environment
from common.cloud.instance import CloudInstance from common.cloud.instance import CloudInstance
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,8 +17,8 @@ class GcpInstance(CloudInstance):
def is_instance(self): def is_instance(self):
return self.on_gcp return self.on_gcp
def get_cloud_provider_name(self) -> str: def get_cloud_provider_name(self) -> Environment:
return GCP return Environment.GCP
def __init__(self): def __init__(self):
self.on_gcp = False self.on_gcp = False

View File

@ -1,3 +1,6 @@
from common.cloud.environment_names import Environment
class CloudInstance(object): class CloudInstance(object):
""" """
This is an abstract class which represents a cloud instance. This is an abstract class which represents a cloud instance.
@ -7,5 +10,5 @@ class CloudInstance(object):
def is_instance(self) -> bool: def is_instance(self) -> bool:
raise NotImplementedError() raise NotImplementedError()
def get_cloud_provider_name(self) -> str: def get_cloud_provider_name(self) -> Environment:
raise NotImplementedError() raise NotImplementedError()

View File

@ -1,10 +1,10 @@
from common.cloud.all_instances import get_all_cloud_instances from common.cloud.all_instances import get_all_cloud_instances
from common.cloud.environment_names import ON_PREMISE from common.cloud.environment_names import Environment
from common.data.system_info_collectors_names import ENVIRONMENT_COLLECTOR from common.data.system_info_collectors_names import ENVIRONMENT_COLLECTOR
from infection_monkey.system_info.system_info_collector import SystemInfoCollector from infection_monkey.system_info.system_info_collector import SystemInfoCollector
def get_monkey_environment() -> str: def get_monkey_environment() -> Environment:
""" """
Get the Monkey's running environment. Get the Monkey's running environment.
:return: One of the cloud providers if on cloud; otherwise, assumes "on premise". :return: One of the cloud providers if on cloud; otherwise, assumes "on premise".
@ -13,7 +13,7 @@ def get_monkey_environment() -> str:
if instance.is_instance(): if instance.is_instance():
return instance.get_cloud_provider_name() return instance.get_cloud_provider_name()
return ON_PREMISE return Environment.ON_PREMISE
class EnvironmentCollector(SystemInfoCollector): class EnvironmentCollector(SystemInfoCollector):

View File

@ -45,7 +45,7 @@ class Monkey(Document):
command_control_channel = EmbeddedDocumentField(CommandControlChannel) command_control_channel = EmbeddedDocumentField(CommandControlChannel)
# Environment related fields # Environment related fields
environment = StringField(default=environment_names.UNKNOWN, choices=environment_names.ALL_ENV_NAMES) environment = StringField(default=environment_names.Environment.UNKNOWN, choices=environment_names.ALL_ENVIRONMENTS_NAMES)
aws_instance_id = StringField(required=False) # This field only exists when the monkey is running on an AWS aws_instance_id = StringField(required=False) # This field only exists when the monkey is running on an AWS
# instance. See https://github.com/guardicore/monkey/issues/426. # instance. See https://github.com/guardicore/monkey/issues/426.