Refactored try_add_user and needs_registration to avoid code duplication

This commit is contained in:
VakarisZ 2020-06-22 11:59:02 +03:00
parent 8c428aa44d
commit 78cf0b5791
6 changed files with 149 additions and 129 deletions

View File

@ -6,5 +6,17 @@ class FailedExploitationError(Exception):
""" Raise when exploiter fails instead of returning False """ """ Raise when exploiter fails instead of returning False """
class InvalidRegistrationCredentials(Exception): class InvalidRegistrationCredentialsError(Exception):
""" Raise when server config file changed and island needs to restart """ """ Raise when server config file changed and island needs to restart """
class RegistrationNotNeededError(Exception):
""" Raise to indicate the reason why registration is not required """
class CredentialsNotRequiredError(RegistrationNotNeededError):
""" Raise to indicate the reason why registration is not required """
class AlreadyRegisteredError(RegistrationNotNeededError):
""" Raise to indicate the reason why registration is not required """

View File

@ -1,15 +1,18 @@
import hashlib import hashlib
import logging
import os import os
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from datetime import timedelta from datetime import timedelta
__author__ = 'itay.mizeretz' __author__ = 'itay.mizeretz'
from common.utils.exceptions import InvalidRegistrationCredentialsError, RegistrationNotNeededError, \
from common.utils.exceptions import InvalidRegistrationCredentials CredentialsNotRequiredError, AlreadyRegisteredError
from monkey_island.cc.environment.environment_config import EnvironmentConfig from monkey_island.cc.environment.environment_config import EnvironmentConfig
from monkey_island.cc.environment.user_creds import UserCreds from monkey_island.cc.environment.user_creds import UserCreds
logger = logging.getLogger(__name__)
class Environment(object, metaclass=ABCMeta): class Environment(object, metaclass=ABCMeta):
_ISLAND_PORT = 5000 _ISLAND_PORT = 5000
@ -36,24 +39,29 @@ class Environment(object, metaclass=ABCMeta):
def get_auth_users(self): def get_auth_users(self):
pass pass
def try_add_user(self, credentials: UserCreds):
if self._credentials_required:
if credentials:
if self._is_registered():
raise InvalidRegistrationCredentials("User has already been registered. "
"Reset credentials or login.")
self._config.add_user(credentials)
else:
raise InvalidRegistrationCredentials("Missing part of credentials.")
else:
raise InvalidRegistrationCredentials("Can't add user because credentials are not required "
"for current environment.")
def needs_registration(self) -> bool: def needs_registration(self) -> bool:
if not self._credentials_required: try:
needs_registration = self._try_needs_registration()
return needs_registration
except (CredentialsNotRequiredError, AlreadyRegisteredError) as e:
logger.info(e)
return False return False
def try_add_user(self, credentials: UserCreds):
if not credentials:
raise InvalidRegistrationCredentialsError("Missing part of credentials.")
if self._try_needs_registration():
self._config.add_user(credentials)
def _try_needs_registration(self) -> bool:
if not self._credentials_required:
raise CredentialsNotRequiredError("Credentials are not required "
"for current environment.")
else: else:
return not self._is_registered() if self._is_registered():
raise AlreadyRegisteredError("User has already been registered. "
"Reset credentials or login.")
return True
def _is_registered(self) -> bool: def _is_registered(self) -> bool:
return self._credentials_required and self._is_credentials_set_up() return self._credentials_required and self._is_credentials_set_up()

View File

@ -4,8 +4,10 @@ from typing import Dict
from unittest import TestCase from unittest import TestCase
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
from common.utils.exceptions import InvalidRegistrationCredentials from common.utils.exceptions import InvalidRegistrationCredentialsError, AlreadyRegisteredError, \
CredentialsNotRequiredError, RegistrationNotNeededError
from monkey_island.cc.environment import Environment, EnvironmentConfig, UserCreds from monkey_island.cc.environment import Environment, EnvironmentConfig, UserCreds
from monkey_island.cc.testing.environment.server_config_mocks import *
def get_server_config_file_path_test_version(): def get_server_config_file_path_test_version():
@ -14,94 +16,91 @@ def get_server_config_file_path_test_version():
class TestEnvironment(TestCase): class TestEnvironment(TestCase):
class EnvironmentNoCredentials(Environment): class EnvironmentCredentialsNotRequired(Environment):
def __init__(self):
config = EnvironmentConfig('test', 'test', UserCreds())
super().__init__(config)
_credentials_required = False _credentials_required = False
def get_auth_users(self): def get_auth_users(self):
return [] return []
class EnvironmentWithCredentials(Environment): class EnvironmentCredentialsRequired(Environment):
def __init__(self):
config = EnvironmentConfig('test', 'test', UserCreds())
super().__init__(config)
_credentials_required = True _credentials_required = True
def get_auth_users(self): def get_auth_users(self):
return [] return []
# Username:test Password:test class EnvironmentAlreadyRegistered(Environment):
CONFIG_WITH_CREDENTIALS = { def __init__(self):
"server_config": "password", config = EnvironmentConfig('test', 'test', UserCreds('test_user', 'test_secret'))
"deployment": "develop", super().__init__(config)
"user": "test",
"password_hash": "9ece086e9bac491fac5c1d1046ca11d737b92a2b2ebd93f005d7b710110c0a678288166e7fbe796883a"
"4f2e9b3ca9f484f521d0ce464345cc1aec96779149c14"
}
CONFIG_NO_CREDENTIALS = { _credentials_required = True
"server_config": "password",
"deployment": "develop"
}
CONFIG_PARTIAL_CREDENTIALS = { def get_auth_users(self):
"server_config": "password", return [1, "Test_username", "Test_secret"]
"deployment": "develop",
"user": "test"
}
CONFIG_STANDARD_ENV = {
"server_config": "standard",
"deployment": "develop"
}
CONFIG_STANDARD_WITH_CREDENTIALS = {
"server_config": "standard",
"deployment": "develop",
"user": "test",
"password_hash": "9ece086e9bac491fac5c1d1046ca11d737b92a2b2ebd93f005d7b710110c0a678288166e7fbe796883a"
"4f2e9b3ca9f484f521d0ce464345cc1aec96779149c14"
}
@patch.object(target=EnvironmentConfig, attribute="save_to_file", new=MagicMock()) @patch.object(target=EnvironmentConfig, attribute="save_to_file", new=MagicMock())
def test_try_add_user(self): def test_try_add_user(self):
env = TestEnvironment.EnvironmentWithCredentials() env = TestEnvironment.EnvironmentCredentialsRequired()
credentials = UserCreds(username="test", password_hash="1231234") credentials = UserCreds(username="test", password_hash="1231234")
env.try_add_user(credentials) env.try_add_user(credentials)
credentials = UserCreds(username="test") credentials = UserCreds(username="test")
with self.assertRaises(InvalidRegistrationCredentials): with self.assertRaises(InvalidRegistrationCredentialsError):
env.try_add_user(credentials) env.try_add_user(credentials)
env = TestEnvironment.EnvironmentNoCredentials() env = TestEnvironment.EnvironmentCredentialsNotRequired()
credentials = UserCreds(username="test", password_hash="1231234") credentials = UserCreds(username="test", password_hash="1231234")
with self.assertRaises(InvalidRegistrationCredentials): with self.assertRaises(RegistrationNotNeededError):
env.try_add_user(credentials) env.try_add_user(credentials)
def test_try_needs_registration(self):
env = TestEnvironment.EnvironmentAlreadyRegistered()
with self.assertRaises(AlreadyRegisteredError):
env._try_needs_registration()
env = TestEnvironment.EnvironmentCredentialsNotRequired()
with self.assertRaises(CredentialsNotRequiredError):
env._try_needs_registration()
env = TestEnvironment.EnvironmentCredentialsRequired()
self.assertTrue(env._try_needs_registration())
def test_needs_registration(self): def test_needs_registration(self):
env = TestEnvironment.EnvironmentWithCredentials() env = TestEnvironment.EnvironmentCredentialsRequired()
self._test_bool_env_method("needs_registration", env, TestEnvironment.CONFIG_WITH_CREDENTIALS, False) self._test_bool_env_method("needs_registration", env, CONFIG_WITH_CREDENTIALS, False)
self._test_bool_env_method("needs_registration", env, TestEnvironment.CONFIG_NO_CREDENTIALS, True) self._test_bool_env_method("needs_registration", env, CONFIG_NO_CREDENTIALS, True)
self._test_bool_env_method("needs_registration", env, TestEnvironment.CONFIG_PARTIAL_CREDENTIALS, True) self._test_bool_env_method("needs_registration", env, CONFIG_PARTIAL_CREDENTIALS, True)
env = TestEnvironment.EnvironmentNoCredentials() env = TestEnvironment.EnvironmentCredentialsNotRequired()
self._test_bool_env_method("needs_registration", env, TestEnvironment.CONFIG_STANDARD_ENV, False) self._test_bool_env_method("needs_registration", env, CONFIG_STANDARD_ENV, False)
self._test_bool_env_method("needs_registration", env, TestEnvironment.CONFIG_STANDARD_WITH_CREDENTIALS, False) self._test_bool_env_method("needs_registration", env, CONFIG_STANDARD_WITH_CREDENTIALS, False)
def test_is_registered(self): def test_is_registered(self):
env = TestEnvironment.EnvironmentWithCredentials() env = TestEnvironment.EnvironmentCredentialsRequired()
self._test_bool_env_method("_is_registered", env, TestEnvironment.CONFIG_WITH_CREDENTIALS, True) self._test_bool_env_method("_is_registered", env, CONFIG_WITH_CREDENTIALS, True)
self._test_bool_env_method("_is_registered", env, TestEnvironment.CONFIG_NO_CREDENTIALS, False) self._test_bool_env_method("_is_registered", env, CONFIG_NO_CREDENTIALS, False)
self._test_bool_env_method("_is_registered", env, TestEnvironment.CONFIG_PARTIAL_CREDENTIALS, False) self._test_bool_env_method("_is_registered", env, CONFIG_PARTIAL_CREDENTIALS, False)
env = TestEnvironment.EnvironmentNoCredentials() env = TestEnvironment.EnvironmentCredentialsNotRequired()
self._test_bool_env_method("_is_registered", env, TestEnvironment.CONFIG_STANDARD_ENV, False) self._test_bool_env_method("_is_registered", env, CONFIG_STANDARD_ENV, False)
self._test_bool_env_method("_is_registered", env, TestEnvironment.CONFIG_STANDARD_WITH_CREDENTIALS, False) self._test_bool_env_method("_is_registered", env, CONFIG_STANDARD_WITH_CREDENTIALS, False)
def test_is_credentials_set_up(self): def test_is_credentials_set_up(self):
env = TestEnvironment.EnvironmentWithCredentials() env = TestEnvironment.EnvironmentCredentialsRequired()
self._test_bool_env_method("_is_credentials_set_up", env, TestEnvironment.CONFIG_NO_CREDENTIALS, False) self._test_bool_env_method("_is_credentials_set_up", env, CONFIG_NO_CREDENTIALS, False)
self._test_bool_env_method("_is_credentials_set_up", env, TestEnvironment.CONFIG_WITH_CREDENTIALS, True) self._test_bool_env_method("_is_credentials_set_up", env, CONFIG_WITH_CREDENTIALS, True)
self._test_bool_env_method("_is_credentials_set_up", env, TestEnvironment.CONFIG_PARTIAL_CREDENTIALS, False) self._test_bool_env_method("_is_credentials_set_up", env, CONFIG_PARTIAL_CREDENTIALS, False)
env = TestEnvironment.EnvironmentNoCredentials() env = TestEnvironment.EnvironmentCredentialsNotRequired()
self._test_bool_env_method("_is_credentials_set_up", env, TestEnvironment.CONFIG_STANDARD_ENV, False) self._test_bool_env_method("_is_credentials_set_up", env, CONFIG_STANDARD_ENV, False)
def _test_bool_env_method(self, method_name: str, env: Environment, config: Dict, expected_result: bool): def _test_bool_env_method(self, method_name: str, env: Environment, config: Dict, expected_result: bool):
env._config = EnvironmentConfig.get_from_json(json.dumps(config)) env._config = EnvironmentConfig.get_from_json(json.dumps(config))

View File

@ -8,6 +8,7 @@ from unittest.mock import patch, MagicMock
from monkey_island.cc.consts import MONKEY_ISLAND_ABS_PATH from monkey_island.cc.consts import MONKEY_ISLAND_ABS_PATH
from monkey_island.cc.environment.environment_config import EnvironmentConfig from monkey_island.cc.environment.environment_config import EnvironmentConfig
from monkey_island.cc.environment.user_creds import UserCreds from monkey_island.cc.environment.user_creds import UserCreds
from monkey_island.cc.testing.environment.server_config_mocks import *
def get_server_config_file_path_test_version(): def get_server_config_file_path_test_version():
@ -15,52 +16,11 @@ def get_server_config_file_path_test_version():
class TestEnvironmentConfig(TestCase): class TestEnvironmentConfig(TestCase):
# Username:test Password:test
CONFIG_WITH_CREDENTIALS = {
"server_config": "password",
"deployment": "develop",
"user": "test",
"password_hash": "9ece086e9bac491fac5c1d1046ca11d737b92a2b2ebd93f005d7b710110c0a678288166e7fbe796883a"
"4f2e9b3ca9f484f521d0ce464345cc1aec96779149c14"
}
CONFIG_NO_CREDENTIALS = {
"server_config": "password",
"deployment": "develop"
}
CONFIG_PARTIAL_CREDENTIALS = {
"server_config": "password",
"deployment": "develop",
"user": "test"
}
CONFIG_BOGUS_VALUES = {
"server_config": "password",
"deployment": "develop",
"user": "test",
"aws": "test",
"test": "test",
"test2": "test2"
}
CONFIG_STANDARD_ENV = {
"server_config": "standard",
"deployment": "develop"
}
CONFIG_STANDARD_WITH_CREDENTIALS = {
"server_config": "standard",
"deployment": "develop",
"user": "test",
"password_hash": "9ece086e9bac491fac5c1d1046ca11d737b92a2b2ebd93f005d7b710110c0a678288166e7fbe796883a"
"4f2e9b3ca9f484f521d0ce464345cc1aec96779149c14"
}
def test_get_from_json(self): def test_get_from_json(self):
self._test_get_from_json(TestEnvironmentConfig.CONFIG_WITH_CREDENTIALS) self._test_get_from_json(CONFIG_WITH_CREDENTIALS)
self._test_get_from_json(TestEnvironmentConfig.CONFIG_NO_CREDENTIALS) self._test_get_from_json(CONFIG_NO_CREDENTIALS)
self._test_get_from_json(TestEnvironmentConfig.CONFIG_PARTIAL_CREDENTIALS) self._test_get_from_json(CONFIG_PARTIAL_CREDENTIALS)
def _test_get_from_json(self, config: Dict): def _test_get_from_json(self, config: Dict):
config_json = json.dumps(config) config_json = json.dumps(config)
@ -75,9 +35,9 @@ class TestEnvironmentConfig(TestCase):
self.assertEqual(config['aws'], env_config_object.aws) self.assertEqual(config['aws'], env_config_object.aws)
def test_save_to_file(self): def test_save_to_file(self):
self._test_save_to_file(TestEnvironmentConfig.CONFIG_WITH_CREDENTIALS) self._test_save_to_file(CONFIG_WITH_CREDENTIALS)
self._test_save_to_file(TestEnvironmentConfig.CONFIG_NO_CREDENTIALS) self._test_save_to_file(CONFIG_NO_CREDENTIALS)
self._test_save_to_file(TestEnvironmentConfig.CONFIG_PARTIAL_CREDENTIALS) self._test_save_to_file(CONFIG_PARTIAL_CREDENTIALS)
@patch.object(target=EnvironmentConfig, attribute="get_config_file_path", @patch.object(target=EnvironmentConfig, attribute="get_config_file_path",
new=MagicMock(return_value=get_server_config_file_path_test_version())) new=MagicMock(return_value=get_server_config_file_path_test_version()))
@ -103,14 +63,14 @@ class TestEnvironmentConfig(TestCase):
self.assertEqual(EnvironmentConfig.get_config_file_path(), server_file_path) self.assertEqual(EnvironmentConfig.get_config_file_path(), server_file_path)
def test_get_from_dict(self): def test_get_from_dict(self):
config_dict = TestEnvironmentConfig.CONFIG_WITH_CREDENTIALS config_dict = CONFIG_WITH_CREDENTIALS
env_conf = EnvironmentConfig.get_from_dict(config_dict) env_conf = EnvironmentConfig.get_from_dict(config_dict)
self.assertEqual(env_conf.server_config, config_dict['server_config']) self.assertEqual(env_conf.server_config, config_dict['server_config'])
self.assertEqual(env_conf.deployment, config_dict['deployment']) self.assertEqual(env_conf.deployment, config_dict['deployment'])
self.assertEqual(env_conf.user_creds.username, config_dict['user']) self.assertEqual(env_conf.user_creds.username, config_dict['user'])
self.assertEqual(env_conf.aws, None) self.assertEqual(env_conf.aws, None)
config_dict = TestEnvironmentConfig.CONFIG_BOGUS_VALUES config_dict = CONFIG_BOGUS_VALUES
env_conf = EnvironmentConfig.get_from_dict(config_dict) env_conf = EnvironmentConfig.get_from_dict(config_dict)
self.assertEqual(env_conf.server_config, config_dict['server_config']) self.assertEqual(env_conf.server_config, config_dict['server_config'])
self.assertEqual(env_conf.deployment, config_dict['deployment']) self.assertEqual(env_conf.deployment, config_dict['deployment'])
@ -118,13 +78,13 @@ class TestEnvironmentConfig(TestCase):
self.assertEqual(env_conf.aws, config_dict['aws']) self.assertEqual(env_conf.aws, config_dict['aws'])
def test_to_dict(self): def test_to_dict(self):
conf_json1 = json.dumps(TestEnvironmentConfig.CONFIG_WITH_CREDENTIALS) conf_json1 = json.dumps(CONFIG_WITH_CREDENTIALS)
self._test_to_dict(EnvironmentConfig.get_from_json(conf_json1)) self._test_to_dict(EnvironmentConfig.get_from_json(conf_json1))
conf_json2 = json.dumps(TestEnvironmentConfig.CONFIG_NO_CREDENTIALS) conf_json2 = json.dumps(CONFIG_NO_CREDENTIALS)
self._test_to_dict(EnvironmentConfig.get_from_json(conf_json2)) self._test_to_dict(EnvironmentConfig.get_from_json(conf_json2))
conf_json3 = json.dumps(TestEnvironmentConfig.CONFIG_PARTIAL_CREDENTIALS) conf_json3 = json.dumps(CONFIG_PARTIAL_CREDENTIALS)
self._test_to_dict(EnvironmentConfig.get_from_json(conf_json3)) self._test_to_dict(EnvironmentConfig.get_from_json(conf_json3))
def _test_to_dict(self, env_config_object: EnvironmentConfig): def _test_to_dict(self, env_config_object: EnvironmentConfig):

View File

@ -1,7 +1,7 @@
import flask_restful import flask_restful
from flask import request, make_response from flask import request, make_response
from common.utils.exceptions import InvalidRegistrationCredentials from common.utils.exceptions import InvalidRegistrationCredentialsError, RegistrationNotNeededError
import monkey_island.cc.environment.environment_singleton as env_singleton import monkey_island.cc.environment.environment_singleton as env_singleton
from monkey_island.cc.environment.user_creds import UserCreds from monkey_island.cc.environment.user_creds import UserCreds
@ -15,6 +15,6 @@ class Registration(flask_restful.Resource):
try: try:
env_singleton.env.try_add_user(credentials) env_singleton.env.try_add_user(credentials)
return make_response({"error": ""}, 200) return make_response({"error": ""}, 200)
except InvalidRegistrationCredentials as e: except (InvalidRegistrationCredentialsError, RegistrationNotNeededError) as e:
return make_response({"error": str(e)}, 400) return make_response({"error": str(e)}, 400)

View File

@ -0,0 +1,41 @@
# Username:test Password:test
CONFIG_WITH_CREDENTIALS = {
"server_config": "password",
"deployment": "develop",
"user": "test",
"password_hash": "9ece086e9bac491fac5c1d1046ca11d737b92a2b2ebd93f005d7b710110c0a678288166e7fbe796883a"
"4f2e9b3ca9f484f521d0ce464345cc1aec96779149c14"
}
CONFIG_NO_CREDENTIALS = {
"server_config": "password",
"deployment": "develop"
}
CONFIG_PARTIAL_CREDENTIALS = {
"server_config": "password",
"deployment": "develop",
"user": "test"
}
CONFIG_BOGUS_VALUES = {
"server_config": "password",
"deployment": "develop",
"user": "test",
"aws": "test",
"test": "test",
"test2": "test2"
}
CONFIG_STANDARD_ENV = {
"server_config": "standard",
"deployment": "develop"
}
CONFIG_STANDARD_WITH_CREDENTIALS = {
"server_config": "standard",
"deployment": "develop",
"user": "test",
"password_hash": "9ece086e9bac491fac5c1d1046ca11d737b92a2b2ebd93f005d7b710110c0a678288166e7fbe796883a"
"4f2e9b3ca9f484f521d0ce464345cc1aec96779149c14"
}