diff --git a/monkey/common/utils/exceptions.py b/monkey/common/utils/exceptions.py index df40f3007..50dcb2d6b 100644 --- a/monkey/common/utils/exceptions.py +++ b/monkey/common/utils/exceptions.py @@ -14,14 +14,14 @@ class RegistrationNotNeededError(Exception): """ Raise to indicate the reason why registration is not required """ -class CredentialsNotRequiredError(RegistrationNotNeededError): - """ Raise to indicate the reason why registration is not required """ - - class AlreadyRegisteredError(RegistrationNotNeededError): """ Raise to indicate the reason why registration is not required """ +class IncorrectCredentialsError(Exception): + """ Raise to indicate that authentication failed """ + + class RulePathCreatorNotFound(Exception): """ Raise to indicate that ScoutSuite rule doesn't have a path creator""" diff --git a/monkey/monkey_island/cc/environment/__init__.py b/monkey/monkey_island/cc/environment/__init__.py index 281b08a3a..70fb775c5 100644 --- a/monkey/monkey_island/cc/environment/__init__.py +++ b/monkey/monkey_island/cc/environment/__init__.py @@ -4,7 +4,6 @@ from datetime import timedelta from common.utils.exceptions import ( AlreadyRegisteredError, - CredentialsNotRequiredError, InvalidRegistrationCredentialsError, ) from monkey_island.cc.environment.environment_config import EnvironmentConfig @@ -24,20 +23,14 @@ class Environment(object, metaclass=ABCMeta): self._config = config self._testing = False # Assume env is not for unit testing. - @property - @abstractmethod - def _credentials_required(self) -> bool: - pass - - @abstractmethod - def get_auth_users(self): - pass + def get_user(self): + return self._config.user_creds def needs_registration(self) -> bool: try: needs_registration = self._try_needs_registration() return needs_registration - except (CredentialsNotRequiredError, AlreadyRegisteredError) as e: + except (AlreadyRegisteredError) as e: logger.info(e) return False @@ -49,25 +42,14 @@ class Environment(object, metaclass=ABCMeta): logger.info(f"New user {credentials.username} registered!") def _try_needs_registration(self) -> bool: - if not self._credentials_required: - raise CredentialsNotRequiredError( - "Credentials are not required " "for current environment." + if self._is_registered(): + raise AlreadyRegisteredError( + "User has already been registered. " "Reset credentials or login." ) - else: - if self._is_registered(): - raise AlreadyRegisteredError( - "User has already been registered. " "Reset credentials or login." - ) - return True + return True def _is_registered(self) -> bool: - return self._credentials_required and self._is_credentials_set_up() - - def _is_credentials_set_up(self) -> bool: - if self._config and self._config.user_creds: - return True - else: - return False + return self._config and self._config.user_creds @property def testing(self): diff --git a/monkey/monkey_island/cc/environment/aws.py b/monkey/monkey_island/cc/environment/aws.py index fda359133..c367d3a98 100644 --- a/monkey/monkey_island/cc/environment/aws.py +++ b/monkey/monkey_island/cc/environment/aws.py @@ -3,15 +3,7 @@ from monkey_island.cc.environment import Environment class AwsEnvironment(Environment): - _credentials_required = True - def __init__(self, config): super(AwsEnvironment, self).__init__(config) # Not suppressing error here on purpose. This is critical if we're on AWS env. self.aws_info = AwsInstance() - - def get_auth_users(self): - if self._is_registered(): - return self._config.get_users() - else: - return [] diff --git a/monkey/monkey_island/cc/environment/environment_config.py b/monkey/monkey_island/cc/environment/environment_config.py index 804b7896f..42623369e 100644 --- a/monkey/monkey_island/cc/environment/environment_config.py +++ b/monkey/monkey_island/cc/environment/environment_config.py @@ -2,11 +2,9 @@ from __future__ import annotations import json import os -from typing import Dict, List +from typing import Dict from monkey_island.cc.environment.user_creds import UserCreds -from monkey_island.cc.resources.auth.auth_user import User -from monkey_island.cc.resources.auth.user_store import UserStore class EnvironmentConfig: @@ -58,11 +56,6 @@ class EnvironmentConfig: def add_user(self, credentials: UserCreds): self.user_creds = credentials self.save_to_file() - UserStore.set_users(self.get_users()) - - def get_users(self) -> List[User]: - auth_user = self.user_creds.to_auth_user() - return [auth_user] if auth_user else [] def _get_user_credentials_from_config(dict_data: Dict): diff --git a/monkey/monkey_island/cc/environment/password.py b/monkey/monkey_island/cc/environment/password.py index c79f2caba..6dc9eea09 100644 --- a/monkey/monkey_island/cc/environment/password.py +++ b/monkey/monkey_island/cc/environment/password.py @@ -1,11 +1,7 @@ from monkey_island.cc.environment import Environment +# TODO: We can probably remove these Environment subclasses, but the +# AwsEnvironment class still does something unique in its constructor. class PasswordEnvironment(Environment): - _credentials_required = True - - def get_auth_users(self): - if self._is_registered(): - return self._config.get_users() - else: - return [] + pass diff --git a/monkey/monkey_island/cc/environment/testing.py b/monkey/monkey_island/cc/environment/testing.py deleted file mode 100644 index efa323fe8..000000000 --- a/monkey/monkey_island/cc/environment/testing.py +++ /dev/null @@ -1,18 +0,0 @@ -from monkey_island.cc.environment import Environment, EnvironmentConfig - - -class TestingEnvironment(Environment): - """ - Use this environment for running Unit Tests. - This will cause all mongo connections to happen via `mongomock` instead of using an actual - mongodb instance. - """ - - _credentials_required = True - - def __init__(self, config: EnvironmentConfig): - super(TestingEnvironment, self).__init__(config) - self.testing = True - - def get_auth_users(self): - return [] diff --git a/monkey/monkey_island/cc/environment/user_creds.py b/monkey/monkey_island/cc/environment/user_creds.py index aba349f2d..a30edae5f 100644 --- a/monkey/monkey_island/cc/environment/user_creds.py +++ b/monkey/monkey_island/cc/environment/user_creds.py @@ -2,8 +2,6 @@ from __future__ import annotations from typing import Dict -from monkey_island.cc.resources.auth.auth_user import User - class UserCreds: def __init__(self, username, password_hash): @@ -20,6 +18,3 @@ class UserCreds: if self.password_hash: cred_dict.update({"password_hash": self.password_hash}) return cred_dict - - def to_auth_user(self) -> User: - return User(1, self.username, self.password_hash) diff --git a/monkey/monkey_island/cc/resources/auth/auth.py b/monkey/monkey_island/cc/resources/auth/auth.py index 4e31778bf..7baac22c9 100644 --- a/monkey/monkey_island/cc/resources/auth/auth.py +++ b/monkey/monkey_island/cc/resources/auth/auth.py @@ -7,19 +7,14 @@ from flask import make_response, request from flask_jwt_extended.exceptions import JWTExtendedException from jwt import PyJWTError -import monkey_island.cc.environment.environment_singleton as env_singleton -import monkey_island.cc.resources.auth.user_store as user_store -from monkey_island.cc.resources.auth.credential_utils import ( - get_username_password_from_request, - password_matches_hash, -) +from common.utils.exceptions import IncorrectCredentialsError +from monkey_island.cc.resources.auth.credential_utils import get_username_password_from_request from monkey_island.cc.services.authentication import AuthenticationService logger = logging.getLogger(__name__) def init_jwt(app): - user_store.UserStore.set_users(env_singleton.env.get_auth_users()) _ = flask_jwt_extended.JWTManager(app) logger.debug( "Initialized JWT with secret key that started with " + app.config["JWT_SECRET_KEY"][:4] @@ -43,27 +38,17 @@ class Authenticate(flask_restful.Resource): """ username, password = get_username_password_from_request(request) - if _credentials_match_registered_user(username, password): - AuthenticationService.unlock_datastore_encryptor(username, password) + try: + AuthenticationService.authenticate(username, password) access_token = _create_access_token(username) - return make_response({"access_token": access_token, "error": ""}, 200) - else: + except IncorrectCredentialsError: return make_response({"error": "Invalid credentials"}, 401) - -def _credentials_match_registered_user(username: str, password: str) -> bool: - user = user_store.UserStore.username_table.get(username, None) - - if user and password_matches_hash(password, user.secret): - return True - - return False + return make_response({"access_token": access_token, "error": ""}, 200) def _create_access_token(username): - access_token = flask_jwt_extended.create_access_token( - identity=user_store.UserStore.username_table[username].id - ) + access_token = flask_jwt_extended.create_access_token(identity=username) logger.debug(f"Created access token for user {username} that begins with {access_token[:4]}") return access_token diff --git a/monkey/monkey_island/cc/resources/auth/auth_user.py b/monkey/monkey_island/cc/resources/auth/auth_user.py deleted file mode 100644 index 547b6e5bc..000000000 --- a/monkey/monkey_island/cc/resources/auth/auth_user.py +++ /dev/null @@ -1,8 +0,0 @@ -class User(object): - def __init__(self, user_id, username, secret): - self.id = user_id - self.username = username - self.secret = secret - - def __str__(self): - return "User(id='%s')" % self.id diff --git a/monkey/monkey_island/cc/resources/auth/credential_utils.py b/monkey/monkey_island/cc/resources/auth/credential_utils.py index a0823d42b..57d5ebc70 100644 --- a/monkey/monkey_island/cc/resources/auth/credential_utils.py +++ b/monkey/monkey_island/cc/resources/auth/credential_utils.py @@ -1,29 +1,8 @@ import json from typing import Tuple -import bcrypt from flask import Request, request -from monkey_island.cc.environment.user_creds import UserCreds - - -def hash_password(plaintext_password): - salt = bcrypt.gensalt() - password_hash = bcrypt.hashpw(plaintext_password.encode("utf-8"), salt) - - return password_hash.decode() - - -def password_matches_hash(plaintext_password, password_hash): - return bcrypt.checkpw(plaintext_password.encode("utf-8"), password_hash.encode("utf-8")) - - -def get_user_credentials_from_request(_request) -> UserCreds: - username, password = get_username_password_from_request(_request) - password_hash = hash_password(password) - - return UserCreds(username, password_hash) - def get_username_password_from_request(_request: Request) -> Tuple[str, str]: cred_dict = json.loads(request.data) diff --git a/monkey/monkey_island/cc/resources/auth/registration.py b/monkey/monkey_island/cc/resources/auth/registration.py index 670fa4d19..fd9532456 100644 --- a/monkey/monkey_island/cc/resources/auth/registration.py +++ b/monkey/monkey_island/cc/resources/auth/registration.py @@ -3,31 +3,22 @@ import logging import flask_restful from flask import make_response, request -import monkey_island.cc.environment.environment_singleton as env_singleton from common.utils.exceptions import InvalidRegistrationCredentialsError, RegistrationNotNeededError -from monkey_island.cc.resources.auth.credential_utils import ( - get_user_credentials_from_request, - get_username_password_from_request, -) +from monkey_island.cc.resources.auth.credential_utils import get_username_password_from_request from monkey_island.cc.services.authentication import AuthenticationService -from monkey_island.cc.setup.mongo.database_initializer import reset_database logger = logging.getLogger(__name__) class Registration(flask_restful.Resource): def get(self): - is_registration_needed = env_singleton.env.needs_registration() - return {"needs_registration": is_registration_needed} + return {"needs_registration": AuthenticationService.needs_registration()} def post(self): - credentials = get_user_credentials_from_request(request) + username, password = get_username_password_from_request(request) try: - env_singleton.env.try_add_user(credentials) - username, password = get_username_password_from_request(request) - AuthenticationService.reset_datastore_encryptor(username, password) - reset_database() + AuthenticationService.register_new_user(username, password) return make_response({"error": ""}, 200) except (InvalidRegistrationCredentialsError, RegistrationNotNeededError) as e: return make_response({"error": str(e)}, 400) diff --git a/monkey/monkey_island/cc/resources/auth/user_store.py b/monkey/monkey_island/cc/resources/auth/user_store.py deleted file mode 100644 index 3c5217f57..000000000 --- a/monkey/monkey_island/cc/resources/auth/user_store.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import List - -from monkey_island.cc.resources.auth.auth_user import User - - -class UserStore: - users = [] - username_table = {} - - @staticmethod - def set_users(users: List[User]): - UserStore.users = users - UserStore.username_table = {u.username: u for u in UserStore.users} diff --git a/monkey/monkey_island/cc/services/authentication.py b/monkey/monkey_island/cc/services/authentication.py index 88b5f3fb0..52d6cfe44 100644 --- a/monkey/monkey_island/cc/services/authentication.py +++ b/monkey/monkey_island/cc/services/authentication.py @@ -1,29 +1,75 @@ +import bcrypt + +import monkey_island.cc.environment.environment_singleton as env_singleton +from common.utils.exceptions import IncorrectCredentialsError +from monkey_island.cc.environment.user_creds import UserCreds from monkey_island.cc.server_utils.encryption import ( reset_datastore_encryptor, unlock_datastore_encryptor, ) +from monkey_island.cc.setup.mongo.database_initializer import reset_database class AuthenticationService: - KEY_FILE_DIRECTORY = None + DATA_DIR = None # TODO: A number of these services should be instance objects instead of # static/singleton hybrids. At the moment, this requires invasive refactoring that's # not a priority. @classmethod - def initialize(cls, key_file_directory): - cls.KEY_FILE_DIRECTORY = key_file_directory + def initialize(cls, data_dir: str): + cls.DATA_DIR = data_dir @staticmethod - def unlock_datastore_encryptor(username: str, password: str): - secret = AuthenticationService._get_secret_from_credentials(username, password) - unlock_datastore_encryptor(AuthenticationService.KEY_FILE_DIRECTORY, secret) + def needs_registration() -> bool: + return env_singleton.env.needs_registration() - @staticmethod - def reset_datastore_encryptor(username: str, password: str): - secret = AuthenticationService._get_secret_from_credentials(username, password) - reset_datastore_encryptor(AuthenticationService.KEY_FILE_DIRECTORY, secret) + @classmethod + def register_new_user(cls, username: str, password: str): + credentials = UserCreds(username, _hash_password(password)) + env_singleton.env.try_add_user(credentials) + cls._reset_datastore_encryptor(username, password) + reset_database() - @staticmethod - def _get_secret_from_credentials(username: str, password: str) -> str: - return f"{username}:{password}" + @classmethod + def authenticate(cls, username: str, password: str): + if not _credentials_match_registered_user(username, password): + raise IncorrectCredentialsError() + + cls._unlock_datastore_encryptor(username, password) + + @classmethod + def _unlock_datastore_encryptor(cls, username: str, password: str): + secret = _get_secret_from_credentials(username, password) + unlock_datastore_encryptor(cls.DATA_DIR, secret) + + @classmethod + def _reset_datastore_encryptor(cls, username: str, password: str): + secret = _get_secret_from_credentials(username, password) + reset_datastore_encryptor(cls.DATA_DIR, secret) + + +def _hash_password(plaintext_password: str) -> str: + salt = bcrypt.gensalt() + password_hash = bcrypt.hashpw(plaintext_password.encode("utf-8"), salt) + + return password_hash.decode() + + +def _credentials_match_registered_user(username: str, password: str) -> bool: + registered_user = env_singleton.env.get_user() + + if not registered_user: + return False + + return (registered_user.username == username) and _password_matches_hash( + password, registered_user.password_hash + ) + + +def _password_matches_hash(plaintext_password: str, password_hash: str) -> bool: + return bcrypt.checkpw(plaintext_password.encode("utf-8"), password_hash.encode("utf-8")) + + +def _get_secret_from_credentials(username: str, password: str) -> str: + return f"{username}:{password}" diff --git a/monkey/monkey_island/cc/services/initialize.py b/monkey/monkey_island/cc/services/initialize.py index b6e37bbc7..caa599e00 100644 --- a/monkey/monkey_island/cc/services/initialize.py +++ b/monkey/monkey_island/cc/services/initialize.py @@ -6,4 +6,4 @@ from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService def initialize_services(data_dir): PostBreachFilesService.initialize(data_dir) LocalMonkeyRunService.initialize(data_dir) - AuthenticationService.initialize(key_file_directory=data_dir) + AuthenticationService.initialize(data_dir) diff --git a/monkey/tests/unit_tests/monkey_island/cc/environment/test_environment.py b/monkey/tests/unit_tests/monkey_island/cc/environment/test_environment.py index 10adea8b7..67f7db115 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/environment/test_environment.py +++ b/monkey/tests/unit_tests/monkey_island/cc/environment/test_environment.py @@ -6,20 +6,14 @@ from unittest.mock import MagicMock, patch import pytest -from common.utils.exceptions import ( - AlreadyRegisteredError, - CredentialsNotRequiredError, - InvalidRegistrationCredentialsError, - RegistrationNotNeededError, -) +from common.utils.exceptions import AlreadyRegisteredError, InvalidRegistrationCredentialsError from monkey_island.cc.environment import Environment, EnvironmentConfig, UserCreds WITH_CREDENTIALS = None NO_CREDENTIALS = None PARTIAL_CREDENTIALS = None -EMPTY_USER_CREDENTIALS = UserCreds("", "") -FULL_USER_CREDENTIALS = UserCreds(username="test", password_hash="1231234") +USER_CREDENTIALS = UserCreds(username="test", password_hash="1231234") # This fixture is a dirty hack that can be removed once these tests are converted from @@ -52,60 +46,31 @@ class StubEnvironmentConfig(EnvironmentConfig): class TestEnvironment(TestCase): - class EnvironmentCredentialsNotRequired(Environment): - def __init__(self): - config = StubEnvironmentConfig("test", "test", EMPTY_USER_CREDENTIALS) - super().__init__(config) - - _credentials_required = False - - def get_auth_users(self): - return [] - class EnvironmentCredentialsRequired(Environment): def __init__(self): - config = StubEnvironmentConfig("test", "test", EMPTY_USER_CREDENTIALS) + config = StubEnvironmentConfig("test", "test", None) super().__init__(config) - _credentials_required = True - - def get_auth_users(self): - return [] - class EnvironmentAlreadyRegistered(Environment): def __init__(self): config = StubEnvironmentConfig("test", "test", UserCreds("test_user", "test_secret")) super().__init__(config) - _credentials_required = True - - def get_auth_users(self): - return [1, "Test_username", "Test_secret"] - @patch.object(target=EnvironmentConfig, attribute="save_to_file", new=MagicMock()) def test_try_add_user(self): env = TestEnvironment.EnvironmentCredentialsRequired() - credentials = FULL_USER_CREDENTIALS + credentials = USER_CREDENTIALS env.try_add_user(credentials) credentials = UserCreds(username="test", password_hash="") with self.assertRaises(InvalidRegistrationCredentialsError): env.try_add_user(credentials) - env = TestEnvironment.EnvironmentCredentialsNotRequired() - credentials = FULL_USER_CREDENTIALS - with self.assertRaises(RegistrationNotNeededError): - env.try_add_user(credentials) - def test_try_needs_registration(self): env = TestEnvironment.EnvironmentAlreadyRegistered() with self.assertRaises(AlreadyRegisteredError): env._try_needs_registration() - env = TestEnvironment.EnvironmentCredentialsNotRequired() - with self.assertRaises(CredentialsNotRequiredError): - env._try_needs_registration() - env = TestEnvironment.EnvironmentCredentialsRequired() self.assertTrue(env._try_needs_registration()) @@ -121,12 +86,6 @@ class TestEnvironment(TestCase): self._test_bool_env_method("_is_registered", env, NO_CREDENTIALS, False) self._test_bool_env_method("_is_registered", env, PARTIAL_CREDENTIALS, False) - def test_is_credentials_set_up(self): - env = TestEnvironment.EnvironmentCredentialsRequired() - self._test_bool_env_method("_is_credentials_set_up", env, NO_CREDENTIALS, False) - self._test_bool_env_method("_is_credentials_set_up", env, WITH_CREDENTIALS, True) - self._test_bool_env_method("_is_credentials_set_up", env, PARTIAL_CREDENTIALS, False) - def _test_bool_env_method( self, method_name: str, env: Environment, config: Dict, expected_result: bool ): diff --git a/monkey/tests/unit_tests/monkey_island/cc/environment/test_environment_config.py b/monkey/tests/unit_tests/monkey_island/cc/environment/test_environment_config.py index 63ae123bf..13a50c62e 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/environment/test_environment_config.py +++ b/monkey/tests/unit_tests/monkey_island/cc/environment/test_environment_config.py @@ -82,11 +82,9 @@ def test_add_user(config_file, with_credentials): assert from_file["environment"]["password_hash"] == new_password_hash -def test_get_users(with_credentials): +def test_user(with_credentials): environment_config = EnvironmentConfig(with_credentials) - users = environment_config.get_users() + user = environment_config.user_creds - assert len(users) == 1 - assert users[0].id == 1 - assert users[0].username == "test" - assert users[0].secret == "abcdef" + assert user.username == "test" + assert user.password_hash == "abcdef" diff --git a/monkey/tests/unit_tests/monkey_island/cc/environment/test_user_creds.py b/monkey/tests/unit_tests/monkey_island/cc/environment/test_user_creds.py index 7d83ba59f..d629687d6 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/environment/test_user_creds.py +++ b/monkey/tests/unit_tests/monkey_island/cc/environment/test_user_creds.py @@ -30,14 +30,6 @@ def test_to_dict_full_creds(): assert user_creds.to_dict() == {"user": TEST_USER, "password_hash": TEST_HASH} -def test_to_auth_user_full_credentials(): - user_creds = UserCreds(TEST_USER, TEST_HASH) - auth_user = user_creds.to_auth_user() - assert auth_user.id == 1 - assert auth_user.username == TEST_USER - assert auth_user.secret == TEST_HASH - - def test_member_values(monkeypatch): creds = UserCreds(TEST_USER, TEST_HASH) assert creds.username == TEST_USER diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/auth/test_auth.py b/monkey/tests/unit_tests/monkey_island/cc/resources/auth/test_auth.py new file mode 100644 index 000000000..8bcc80690 --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/auth/test_auth.py @@ -0,0 +1,71 @@ +import re +from unittest.mock import MagicMock + +import pytest + +from common.utils.exceptions import IncorrectCredentialsError + +USERNAME = "test_user" +PASSWORD = "test_password" +TEST_REQUEST = f'{{"username": "{USERNAME}", "password": "{PASSWORD}"}}' + + +@pytest.fixture +def mock_authentication_service(monkeypatch): + mock_service = MagicMock() + mock_service.authenticate = MagicMock() + + monkeypatch.setattr("monkey_island.cc.resources.auth.auth.AuthenticationService", mock_service) + + return mock_service + + +@pytest.fixture +def make_auth_request(flask_client): + url = "/api/auth" + + def inner(request_body): + return flask_client.post(url, data=request_body, follow_redirects=True) + + return inner + + +def test_credential_parsing(make_auth_request, mock_authentication_service): + make_auth_request(TEST_REQUEST) + mock_authentication_service.authenticate.assert_called_with(USERNAME, PASSWORD) + + +def test_empty_credentials(make_auth_request, mock_authentication_service): + make_auth_request("{}") + mock_authentication_service.authenticate.assert_called_with("", "") + + +def test_authentication_successful(make_auth_request, mock_authentication_service): + mock_authentication_service.authenticate = MagicMock(return_value=True) + + response = make_auth_request(TEST_REQUEST) + + assert response.status_code == 200 + assert response.json["error"] == "" + assert re.match( + r"^[a-zA-Z0-9+/=]+\.[a-zA-Z0-9+/=]+\.[a-zA-Z0-9+/=\-_]+$", response.json["access_token"] + ) + + +def test_authentication_failure(make_auth_request, mock_authentication_service): + mock_authentication_service.authenticate = MagicMock(side_effect=IncorrectCredentialsError()) + + response = make_auth_request(TEST_REQUEST) + + assert "access_token" not in response.json + assert response.status_code == 401 + assert response.json["error"] == "Invalid credentials" + + +def test_authentication_error(make_auth_request, mock_authentication_service): + mock_authentication_service.authenticate = MagicMock(side_effect=Exception()) + + response = make_auth_request(TEST_REQUEST) + + assert "access_token" not in response.json + assert response.status_code == 500 diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/auth/test_registration.py b/monkey/tests/unit_tests/monkey_island/cc/resources/auth/test_registration.py new file mode 100644 index 000000000..0b1b60251 --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/auth/test_registration.py @@ -0,0 +1,87 @@ +import json +from unittest.mock import MagicMock + +import pytest + +from common.utils.exceptions import InvalidRegistrationCredentialsError, RegistrationNotNeededError + +REGISTRATION_URL = "/api/registration" + +USERNAME = "test_user" +PASSWORD = "test_password" + + +@pytest.fixture +def mock_authentication_service(monkeypatch): + mock_service = MagicMock() + mock_service.register_new_user = MagicMock() + mock_service.needs_registration = MagicMock() + + monkeypatch.setattr( + "monkey_island.cc.resources.auth.registration.AuthenticationService", mock_service + ) + + return mock_service + + +@pytest.fixture +def make_registration_request(flask_client): + def inner(request_body): + return flask_client.post(REGISTRATION_URL, data=request_body, follow_redirects=True) + + return inner + + +def test_registration(make_registration_request, mock_authentication_service): + registration_request_body = f'{{"username": "{USERNAME}", "password": "{PASSWORD}"}}' + response = make_registration_request(registration_request_body) + + assert response.status_code == 200 + mock_authentication_service.register_new_user.assert_called_with(USERNAME, PASSWORD) + + +def test_empty_credentials(make_registration_request, mock_authentication_service): + registration_request_body = "{}" + make_registration_request(registration_request_body) + + mock_authentication_service.register_new_user.assert_called_with("", "") + + +def test_invalid_credentials(make_registration_request, mock_authentication_service): + mock_authentication_service.register_new_user = MagicMock( + side_effect=InvalidRegistrationCredentialsError() + ) + + registration_request_body = "{}" + response = make_registration_request(registration_request_body) + + assert response.status_code == 400 + + +def test_registration_not_needed(make_registration_request, mock_authentication_service): + mock_authentication_service.register_new_user = MagicMock( + side_effect=RegistrationNotNeededError() + ) + + registration_request_body = "{}" + response = make_registration_request(registration_request_body) + + assert response.status_code == 400 + + +def test_internal_error(make_registration_request, mock_authentication_service): + mock_authentication_service.register_new_user = MagicMock(side_effect=Exception()) + + registration_request_body = json.dumps({}) + response = make_registration_request(registration_request_body) + + assert response.status_code == 500 + + +@pytest.mark.parametrize("needs_registration", [True, False]) +def test_needs_registration(flask_client, mock_authentication_service, needs_registration): + mock_authentication_service.needs_registration = MagicMock(return_value=needs_registration) + response = flask_client.get(REGISTRATION_URL, follow_redirects=True) + + assert response.status_code == 200 + assert response.json["needs_registration"] is needs_registration diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/conftest.py b/monkey/tests/unit_tests/monkey_island/cc/resources/conftest.py index 3ca40a11a..eeef5b383 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/resources/conftest.py +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/conftest.py @@ -19,6 +19,7 @@ def flask_client(monkeypatch_session): def mock_init_app(): app = Flask(__name__) + app.config["SECRET_KEY"] = "test_key" api = flask_restful.Api(app) api.representations = {"application/json": output_json} @@ -26,4 +27,6 @@ def mock_init_app(): monkey_island.cc.app.init_app_url_rules(app) monkey_island.cc.app.init_api_resources(api) + flask_jwt_extended.JWTManager(app) + return app