forked from p34709852/monkey
Merge pull request #1561 from guardicore/1554-authentication-service-refactor
Authentication service refactor
This commit is contained in:
commit
8ee918b5a2
|
@ -14,14 +14,14 @@ class RegistrationNotNeededError(Exception):
|
||||||
""" Raise to indicate the reason why registration is not required """
|
""" Raise to indicate the reason why registration is not required """
|
||||||
|
|
||||||
|
|
||||||
class CredentialsNotRequiredError(RegistrationNotNeededError):
|
|
||||||
""" Raise to indicate the reason why registration is not required """
|
|
||||||
|
|
||||||
|
|
||||||
class AlreadyRegisteredError(RegistrationNotNeededError):
|
class AlreadyRegisteredError(RegistrationNotNeededError):
|
||||||
""" Raise to indicate the reason why registration is not required """
|
""" Raise to indicate the reason why registration is not required """
|
||||||
|
|
||||||
|
|
||||||
|
class IncorrectCredentialsError(Exception):
|
||||||
|
""" Raise to indicate that authentication failed """
|
||||||
|
|
||||||
|
|
||||||
class RulePathCreatorNotFound(Exception):
|
class RulePathCreatorNotFound(Exception):
|
||||||
""" Raise to indicate that ScoutSuite rule doesn't have a path creator"""
|
""" Raise to indicate that ScoutSuite rule doesn't have a path creator"""
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,6 @@ from datetime import timedelta
|
||||||
|
|
||||||
from common.utils.exceptions import (
|
from common.utils.exceptions import (
|
||||||
AlreadyRegisteredError,
|
AlreadyRegisteredError,
|
||||||
CredentialsNotRequiredError,
|
|
||||||
InvalidRegistrationCredentialsError,
|
InvalidRegistrationCredentialsError,
|
||||||
)
|
)
|
||||||
from monkey_island.cc.environment.environment_config import EnvironmentConfig
|
from monkey_island.cc.environment.environment_config import EnvironmentConfig
|
||||||
|
@ -24,20 +23,14 @@ class Environment(object, metaclass=ABCMeta):
|
||||||
self._config = config
|
self._config = config
|
||||||
self._testing = False # Assume env is not for unit testing.
|
self._testing = False # Assume env is not for unit testing.
|
||||||
|
|
||||||
@property
|
def get_user(self):
|
||||||
@abstractmethod
|
return self._config.user_creds
|
||||||
def _credentials_required(self) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_auth_users(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def needs_registration(self) -> bool:
|
def needs_registration(self) -> bool:
|
||||||
try:
|
try:
|
||||||
needs_registration = self._try_needs_registration()
|
needs_registration = self._try_needs_registration()
|
||||||
return needs_registration
|
return needs_registration
|
||||||
except (CredentialsNotRequiredError, AlreadyRegisteredError) as e:
|
except (AlreadyRegisteredError) as e:
|
||||||
logger.info(e)
|
logger.info(e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -49,11 +42,6 @@ class Environment(object, metaclass=ABCMeta):
|
||||||
logger.info(f"New user {credentials.username} registered!")
|
logger.info(f"New user {credentials.username} registered!")
|
||||||
|
|
||||||
def _try_needs_registration(self) -> bool:
|
def _try_needs_registration(self) -> bool:
|
||||||
if not self._credentials_required:
|
|
||||||
raise CredentialsNotRequiredError(
|
|
||||||
"Credentials are not required " "for current environment."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if self._is_registered():
|
if self._is_registered():
|
||||||
raise AlreadyRegisteredError(
|
raise AlreadyRegisteredError(
|
||||||
"User has already been registered. " "Reset credentials or login."
|
"User has already been registered. " "Reset credentials or login."
|
||||||
|
@ -61,13 +49,7 @@ class Environment(object, metaclass=ABCMeta):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _is_registered(self) -> bool:
|
def _is_registered(self) -> bool:
|
||||||
return self._credentials_required and self._is_credentials_set_up()
|
return self._config and self._config.user_creds
|
||||||
|
|
||||||
def _is_credentials_set_up(self) -> bool:
|
|
||||||
if self._config and self._config.user_creds:
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def testing(self):
|
def testing(self):
|
||||||
|
|
|
@ -3,15 +3,7 @@ from monkey_island.cc.environment import Environment
|
||||||
|
|
||||||
|
|
||||||
class AwsEnvironment(Environment):
|
class AwsEnvironment(Environment):
|
||||||
_credentials_required = True
|
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(AwsEnvironment, self).__init__(config)
|
super(AwsEnvironment, self).__init__(config)
|
||||||
# Not suppressing error here on purpose. This is critical if we're on AWS env.
|
# Not suppressing error here on purpose. This is critical if we're on AWS env.
|
||||||
self.aws_info = AwsInstance()
|
self.aws_info = AwsInstance()
|
||||||
|
|
||||||
def get_auth_users(self):
|
|
||||||
if self._is_registered():
|
|
||||||
return self._config.get_users()
|
|
||||||
else:
|
|
||||||
return []
|
|
||||||
|
|
|
@ -2,11 +2,9 @@ from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Dict, List
|
from typing import Dict
|
||||||
|
|
||||||
from monkey_island.cc.environment.user_creds import UserCreds
|
from monkey_island.cc.environment.user_creds import UserCreds
|
||||||
from monkey_island.cc.resources.auth.auth_user import User
|
|
||||||
from monkey_island.cc.resources.auth.user_store import UserStore
|
|
||||||
|
|
||||||
|
|
||||||
class EnvironmentConfig:
|
class EnvironmentConfig:
|
||||||
|
@ -58,11 +56,6 @@ class EnvironmentConfig:
|
||||||
def add_user(self, credentials: UserCreds):
|
def add_user(self, credentials: UserCreds):
|
||||||
self.user_creds = credentials
|
self.user_creds = credentials
|
||||||
self.save_to_file()
|
self.save_to_file()
|
||||||
UserStore.set_users(self.get_users())
|
|
||||||
|
|
||||||
def get_users(self) -> List[User]:
|
|
||||||
auth_user = self.user_creds.to_auth_user()
|
|
||||||
return [auth_user] if auth_user else []
|
|
||||||
|
|
||||||
|
|
||||||
def _get_user_credentials_from_config(dict_data: Dict):
|
def _get_user_credentials_from_config(dict_data: Dict):
|
||||||
|
|
|
@ -1,11 +1,7 @@
|
||||||
from monkey_island.cc.environment import Environment
|
from monkey_island.cc.environment import Environment
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: We can probably remove these Environment subclasses, but the
|
||||||
|
# AwsEnvironment class still does something unique in its constructor.
|
||||||
class PasswordEnvironment(Environment):
|
class PasswordEnvironment(Environment):
|
||||||
_credentials_required = True
|
pass
|
||||||
|
|
||||||
def get_auth_users(self):
|
|
||||||
if self._is_registered():
|
|
||||||
return self._config.get_users()
|
|
||||||
else:
|
|
||||||
return []
|
|
||||||
|
|
|
@ -1,18 +0,0 @@
|
||||||
from monkey_island.cc.environment import Environment, EnvironmentConfig
|
|
||||||
|
|
||||||
|
|
||||||
class TestingEnvironment(Environment):
|
|
||||||
"""
|
|
||||||
Use this environment for running Unit Tests.
|
|
||||||
This will cause all mongo connections to happen via `mongomock` instead of using an actual
|
|
||||||
mongodb instance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_credentials_required = True
|
|
||||||
|
|
||||||
def __init__(self, config: EnvironmentConfig):
|
|
||||||
super(TestingEnvironment, self).__init__(config)
|
|
||||||
self.testing = True
|
|
||||||
|
|
||||||
def get_auth_users(self):
|
|
||||||
return []
|
|
|
@ -2,8 +2,6 @@ from __future__ import annotations
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from monkey_island.cc.resources.auth.auth_user import User
|
|
||||||
|
|
||||||
|
|
||||||
class UserCreds:
|
class UserCreds:
|
||||||
def __init__(self, username, password_hash):
|
def __init__(self, username, password_hash):
|
||||||
|
@ -20,6 +18,3 @@ class UserCreds:
|
||||||
if self.password_hash:
|
if self.password_hash:
|
||||||
cred_dict.update({"password_hash": self.password_hash})
|
cred_dict.update({"password_hash": self.password_hash})
|
||||||
return cred_dict
|
return cred_dict
|
||||||
|
|
||||||
def to_auth_user(self) -> User:
|
|
||||||
return User(1, self.username, self.password_hash)
|
|
||||||
|
|
|
@ -7,19 +7,14 @@ from flask import make_response, request
|
||||||
from flask_jwt_extended.exceptions import JWTExtendedException
|
from flask_jwt_extended.exceptions import JWTExtendedException
|
||||||
from jwt import PyJWTError
|
from jwt import PyJWTError
|
||||||
|
|
||||||
import monkey_island.cc.environment.environment_singleton as env_singleton
|
from common.utils.exceptions import IncorrectCredentialsError
|
||||||
import monkey_island.cc.resources.auth.user_store as user_store
|
from monkey_island.cc.resources.auth.credential_utils import get_username_password_from_request
|
||||||
from monkey_island.cc.resources.auth.credential_utils import (
|
|
||||||
get_username_password_from_request,
|
|
||||||
password_matches_hash,
|
|
||||||
)
|
|
||||||
from monkey_island.cc.services.authentication import AuthenticationService
|
from monkey_island.cc.services.authentication import AuthenticationService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def init_jwt(app):
|
def init_jwt(app):
|
||||||
user_store.UserStore.set_users(env_singleton.env.get_auth_users())
|
|
||||||
_ = flask_jwt_extended.JWTManager(app)
|
_ = flask_jwt_extended.JWTManager(app)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Initialized JWT with secret key that started with " + app.config["JWT_SECRET_KEY"][:4]
|
"Initialized JWT with secret key that started with " + app.config["JWT_SECRET_KEY"][:4]
|
||||||
|
@ -43,27 +38,17 @@ class Authenticate(flask_restful.Resource):
|
||||||
"""
|
"""
|
||||||
username, password = get_username_password_from_request(request)
|
username, password = get_username_password_from_request(request)
|
||||||
|
|
||||||
if _credentials_match_registered_user(username, password):
|
try:
|
||||||
AuthenticationService.unlock_datastore_encryptor(username, password)
|
AuthenticationService.authenticate(username, password)
|
||||||
access_token = _create_access_token(username)
|
access_token = _create_access_token(username)
|
||||||
return make_response({"access_token": access_token, "error": ""}, 200)
|
except IncorrectCredentialsError:
|
||||||
else:
|
|
||||||
return make_response({"error": "Invalid credentials"}, 401)
|
return make_response({"error": "Invalid credentials"}, 401)
|
||||||
|
|
||||||
|
return make_response({"access_token": access_token, "error": ""}, 200)
|
||||||
def _credentials_match_registered_user(username: str, password: str) -> bool:
|
|
||||||
user = user_store.UserStore.username_table.get(username, None)
|
|
||||||
|
|
||||||
if user and password_matches_hash(password, user.secret):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _create_access_token(username):
|
def _create_access_token(username):
|
||||||
access_token = flask_jwt_extended.create_access_token(
|
access_token = flask_jwt_extended.create_access_token(identity=username)
|
||||||
identity=user_store.UserStore.username_table[username].id
|
|
||||||
)
|
|
||||||
logger.debug(f"Created access token for user {username} that begins with {access_token[:4]}")
|
logger.debug(f"Created access token for user {username} that begins with {access_token[:4]}")
|
||||||
|
|
||||||
return access_token
|
return access_token
|
||||||
|
|
|
@ -1,8 +0,0 @@
|
||||||
class User(object):
|
|
||||||
def __init__(self, user_id, username, secret):
|
|
||||||
self.id = user_id
|
|
||||||
self.username = username
|
|
||||||
self.secret = secret
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return "User(id='%s')" % self.id
|
|
|
@ -1,29 +1,8 @@
|
||||||
import json
|
import json
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import bcrypt
|
|
||||||
from flask import Request, request
|
from flask import Request, request
|
||||||
|
|
||||||
from monkey_island.cc.environment.user_creds import UserCreds
|
|
||||||
|
|
||||||
|
|
||||||
def hash_password(plaintext_password):
|
|
||||||
salt = bcrypt.gensalt()
|
|
||||||
password_hash = bcrypt.hashpw(plaintext_password.encode("utf-8"), salt)
|
|
||||||
|
|
||||||
return password_hash.decode()
|
|
||||||
|
|
||||||
|
|
||||||
def password_matches_hash(plaintext_password, password_hash):
|
|
||||||
return bcrypt.checkpw(plaintext_password.encode("utf-8"), password_hash.encode("utf-8"))
|
|
||||||
|
|
||||||
|
|
||||||
def get_user_credentials_from_request(_request) -> UserCreds:
|
|
||||||
username, password = get_username_password_from_request(_request)
|
|
||||||
password_hash = hash_password(password)
|
|
||||||
|
|
||||||
return UserCreds(username, password_hash)
|
|
||||||
|
|
||||||
|
|
||||||
def get_username_password_from_request(_request: Request) -> Tuple[str, str]:
|
def get_username_password_from_request(_request: Request) -> Tuple[str, str]:
|
||||||
cred_dict = json.loads(request.data)
|
cred_dict = json.loads(request.data)
|
||||||
|
|
|
@ -3,31 +3,22 @@ import logging
|
||||||
import flask_restful
|
import flask_restful
|
||||||
from flask import make_response, request
|
from flask import make_response, request
|
||||||
|
|
||||||
import monkey_island.cc.environment.environment_singleton as env_singleton
|
|
||||||
from common.utils.exceptions import InvalidRegistrationCredentialsError, RegistrationNotNeededError
|
from common.utils.exceptions import InvalidRegistrationCredentialsError, RegistrationNotNeededError
|
||||||
from monkey_island.cc.resources.auth.credential_utils import (
|
from monkey_island.cc.resources.auth.credential_utils import get_username_password_from_request
|
||||||
get_user_credentials_from_request,
|
|
||||||
get_username_password_from_request,
|
|
||||||
)
|
|
||||||
from monkey_island.cc.services.authentication import AuthenticationService
|
from monkey_island.cc.services.authentication import AuthenticationService
|
||||||
from monkey_island.cc.setup.mongo.database_initializer import reset_database
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Registration(flask_restful.Resource):
|
class Registration(flask_restful.Resource):
|
||||||
def get(self):
|
def get(self):
|
||||||
is_registration_needed = env_singleton.env.needs_registration()
|
return {"needs_registration": AuthenticationService.needs_registration()}
|
||||||
return {"needs_registration": is_registration_needed}
|
|
||||||
|
|
||||||
def post(self):
|
def post(self):
|
||||||
credentials = get_user_credentials_from_request(request)
|
username, password = get_username_password_from_request(request)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
env_singleton.env.try_add_user(credentials)
|
AuthenticationService.register_new_user(username, password)
|
||||||
username, password = get_username_password_from_request(request)
|
|
||||||
AuthenticationService.reset_datastore_encryptor(username, password)
|
|
||||||
reset_database()
|
|
||||||
return make_response({"error": ""}, 200)
|
return make_response({"error": ""}, 200)
|
||||||
except (InvalidRegistrationCredentialsError, RegistrationNotNeededError) as e:
|
except (InvalidRegistrationCredentialsError, RegistrationNotNeededError) as e:
|
||||||
return make_response({"error": str(e)}, 400)
|
return make_response({"error": str(e)}, 400)
|
||||||
|
|
|
@ -1,13 +0,0 @@
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from monkey_island.cc.resources.auth.auth_user import User
|
|
||||||
|
|
||||||
|
|
||||||
class UserStore:
|
|
||||||
users = []
|
|
||||||
username_table = {}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def set_users(users: List[User]):
|
|
||||||
UserStore.users = users
|
|
||||||
UserStore.username_table = {u.username: u for u in UserStore.users}
|
|
|
@ -1,29 +1,75 @@
|
||||||
|
import bcrypt
|
||||||
|
|
||||||
|
import monkey_island.cc.environment.environment_singleton as env_singleton
|
||||||
|
from common.utils.exceptions import IncorrectCredentialsError
|
||||||
|
from monkey_island.cc.environment.user_creds import UserCreds
|
||||||
from monkey_island.cc.server_utils.encryption import (
|
from monkey_island.cc.server_utils.encryption import (
|
||||||
reset_datastore_encryptor,
|
reset_datastore_encryptor,
|
||||||
unlock_datastore_encryptor,
|
unlock_datastore_encryptor,
|
||||||
)
|
)
|
||||||
|
from monkey_island.cc.setup.mongo.database_initializer import reset_database
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationService:
|
class AuthenticationService:
|
||||||
KEY_FILE_DIRECTORY = None
|
DATA_DIR = None
|
||||||
|
|
||||||
# TODO: A number of these services should be instance objects instead of
|
# TODO: A number of these services should be instance objects instead of
|
||||||
# static/singleton hybrids. At the moment, this requires invasive refactoring that's
|
# static/singleton hybrids. At the moment, this requires invasive refactoring that's
|
||||||
# not a priority.
|
# not a priority.
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize(cls, key_file_directory):
|
def initialize(cls, data_dir: str):
|
||||||
cls.KEY_FILE_DIRECTORY = key_file_directory
|
cls.DATA_DIR = data_dir
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def unlock_datastore_encryptor(username: str, password: str):
|
def needs_registration() -> bool:
|
||||||
secret = AuthenticationService._get_secret_from_credentials(username, password)
|
return env_singleton.env.needs_registration()
|
||||||
unlock_datastore_encryptor(AuthenticationService.KEY_FILE_DIRECTORY, secret)
|
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def reset_datastore_encryptor(username: str, password: str):
|
def register_new_user(cls, username: str, password: str):
|
||||||
secret = AuthenticationService._get_secret_from_credentials(username, password)
|
credentials = UserCreds(username, _hash_password(password))
|
||||||
reset_datastore_encryptor(AuthenticationService.KEY_FILE_DIRECTORY, secret)
|
env_singleton.env.try_add_user(credentials)
|
||||||
|
cls._reset_datastore_encryptor(username, password)
|
||||||
|
reset_database()
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def _get_secret_from_credentials(username: str, password: str) -> str:
|
def authenticate(cls, username: str, password: str):
|
||||||
|
if not _credentials_match_registered_user(username, password):
|
||||||
|
raise IncorrectCredentialsError()
|
||||||
|
|
||||||
|
cls._unlock_datastore_encryptor(username, password)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _unlock_datastore_encryptor(cls, username: str, password: str):
|
||||||
|
secret = _get_secret_from_credentials(username, password)
|
||||||
|
unlock_datastore_encryptor(cls.DATA_DIR, secret)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _reset_datastore_encryptor(cls, username: str, password: str):
|
||||||
|
secret = _get_secret_from_credentials(username, password)
|
||||||
|
reset_datastore_encryptor(cls.DATA_DIR, secret)
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_password(plaintext_password: str) -> str:
|
||||||
|
salt = bcrypt.gensalt()
|
||||||
|
password_hash = bcrypt.hashpw(plaintext_password.encode("utf-8"), salt)
|
||||||
|
|
||||||
|
return password_hash.decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _credentials_match_registered_user(username: str, password: str) -> bool:
|
||||||
|
registered_user = env_singleton.env.get_user()
|
||||||
|
|
||||||
|
if not registered_user:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return (registered_user.username == username) and _password_matches_hash(
|
||||||
|
password, registered_user.password_hash
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _password_matches_hash(plaintext_password: str, password_hash: str) -> bool:
|
||||||
|
return bcrypt.checkpw(plaintext_password.encode("utf-8"), password_hash.encode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_secret_from_credentials(username: str, password: str) -> str:
|
||||||
return f"{username}:{password}"
|
return f"{username}:{password}"
|
||||||
|
|
|
@ -6,4 +6,4 @@ from monkey_island.cc.services.run_local_monkey import LocalMonkeyRunService
|
||||||
def initialize_services(data_dir):
|
def initialize_services(data_dir):
|
||||||
PostBreachFilesService.initialize(data_dir)
|
PostBreachFilesService.initialize(data_dir)
|
||||||
LocalMonkeyRunService.initialize(data_dir)
|
LocalMonkeyRunService.initialize(data_dir)
|
||||||
AuthenticationService.initialize(key_file_directory=data_dir)
|
AuthenticationService.initialize(data_dir)
|
||||||
|
|
|
@ -6,20 +6,14 @@ from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from common.utils.exceptions import (
|
from common.utils.exceptions import AlreadyRegisteredError, InvalidRegistrationCredentialsError
|
||||||
AlreadyRegisteredError,
|
|
||||||
CredentialsNotRequiredError,
|
|
||||||
InvalidRegistrationCredentialsError,
|
|
||||||
RegistrationNotNeededError,
|
|
||||||
)
|
|
||||||
from monkey_island.cc.environment import Environment, EnvironmentConfig, UserCreds
|
from monkey_island.cc.environment import Environment, EnvironmentConfig, UserCreds
|
||||||
|
|
||||||
WITH_CREDENTIALS = None
|
WITH_CREDENTIALS = None
|
||||||
NO_CREDENTIALS = None
|
NO_CREDENTIALS = None
|
||||||
PARTIAL_CREDENTIALS = None
|
PARTIAL_CREDENTIALS = None
|
||||||
|
|
||||||
EMPTY_USER_CREDENTIALS = UserCreds("", "")
|
USER_CREDENTIALS = UserCreds(username="test", password_hash="1231234")
|
||||||
FULL_USER_CREDENTIALS = UserCreds(username="test", password_hash="1231234")
|
|
||||||
|
|
||||||
|
|
||||||
# This fixture is a dirty hack that can be removed once these tests are converted from
|
# This fixture is a dirty hack that can be removed once these tests are converted from
|
||||||
|
@ -52,60 +46,31 @@ class StubEnvironmentConfig(EnvironmentConfig):
|
||||||
|
|
||||||
|
|
||||||
class TestEnvironment(TestCase):
|
class TestEnvironment(TestCase):
|
||||||
class EnvironmentCredentialsNotRequired(Environment):
|
|
||||||
def __init__(self):
|
|
||||||
config = StubEnvironmentConfig("test", "test", EMPTY_USER_CREDENTIALS)
|
|
||||||
super().__init__(config)
|
|
||||||
|
|
||||||
_credentials_required = False
|
|
||||||
|
|
||||||
def get_auth_users(self):
|
|
||||||
return []
|
|
||||||
|
|
||||||
class EnvironmentCredentialsRequired(Environment):
|
class EnvironmentCredentialsRequired(Environment):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
config = StubEnvironmentConfig("test", "test", EMPTY_USER_CREDENTIALS)
|
config = StubEnvironmentConfig("test", "test", None)
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
_credentials_required = True
|
|
||||||
|
|
||||||
def get_auth_users(self):
|
|
||||||
return []
|
|
||||||
|
|
||||||
class EnvironmentAlreadyRegistered(Environment):
|
class EnvironmentAlreadyRegistered(Environment):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
config = StubEnvironmentConfig("test", "test", UserCreds("test_user", "test_secret"))
|
config = StubEnvironmentConfig("test", "test", UserCreds("test_user", "test_secret"))
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
_credentials_required = True
|
|
||||||
|
|
||||||
def get_auth_users(self):
|
|
||||||
return [1, "Test_username", "Test_secret"]
|
|
||||||
|
|
||||||
@patch.object(target=EnvironmentConfig, attribute="save_to_file", new=MagicMock())
|
@patch.object(target=EnvironmentConfig, attribute="save_to_file", new=MagicMock())
|
||||||
def test_try_add_user(self):
|
def test_try_add_user(self):
|
||||||
env = TestEnvironment.EnvironmentCredentialsRequired()
|
env = TestEnvironment.EnvironmentCredentialsRequired()
|
||||||
credentials = FULL_USER_CREDENTIALS
|
credentials = USER_CREDENTIALS
|
||||||
env.try_add_user(credentials)
|
env.try_add_user(credentials)
|
||||||
|
|
||||||
credentials = UserCreds(username="test", password_hash="")
|
credentials = UserCreds(username="test", password_hash="")
|
||||||
with self.assertRaises(InvalidRegistrationCredentialsError):
|
with self.assertRaises(InvalidRegistrationCredentialsError):
|
||||||
env.try_add_user(credentials)
|
env.try_add_user(credentials)
|
||||||
|
|
||||||
env = TestEnvironment.EnvironmentCredentialsNotRequired()
|
|
||||||
credentials = FULL_USER_CREDENTIALS
|
|
||||||
with self.assertRaises(RegistrationNotNeededError):
|
|
||||||
env.try_add_user(credentials)
|
|
||||||
|
|
||||||
def test_try_needs_registration(self):
|
def test_try_needs_registration(self):
|
||||||
env = TestEnvironment.EnvironmentAlreadyRegistered()
|
env = TestEnvironment.EnvironmentAlreadyRegistered()
|
||||||
with self.assertRaises(AlreadyRegisteredError):
|
with self.assertRaises(AlreadyRegisteredError):
|
||||||
env._try_needs_registration()
|
env._try_needs_registration()
|
||||||
|
|
||||||
env = TestEnvironment.EnvironmentCredentialsNotRequired()
|
|
||||||
with self.assertRaises(CredentialsNotRequiredError):
|
|
||||||
env._try_needs_registration()
|
|
||||||
|
|
||||||
env = TestEnvironment.EnvironmentCredentialsRequired()
|
env = TestEnvironment.EnvironmentCredentialsRequired()
|
||||||
self.assertTrue(env._try_needs_registration())
|
self.assertTrue(env._try_needs_registration())
|
||||||
|
|
||||||
|
@ -121,12 +86,6 @@ class TestEnvironment(TestCase):
|
||||||
self._test_bool_env_method("_is_registered", env, NO_CREDENTIALS, False)
|
self._test_bool_env_method("_is_registered", env, NO_CREDENTIALS, False)
|
||||||
self._test_bool_env_method("_is_registered", env, PARTIAL_CREDENTIALS, False)
|
self._test_bool_env_method("_is_registered", env, PARTIAL_CREDENTIALS, False)
|
||||||
|
|
||||||
def test_is_credentials_set_up(self):
|
|
||||||
env = TestEnvironment.EnvironmentCredentialsRequired()
|
|
||||||
self._test_bool_env_method("_is_credentials_set_up", env, NO_CREDENTIALS, False)
|
|
||||||
self._test_bool_env_method("_is_credentials_set_up", env, WITH_CREDENTIALS, True)
|
|
||||||
self._test_bool_env_method("_is_credentials_set_up", env, PARTIAL_CREDENTIALS, False)
|
|
||||||
|
|
||||||
def _test_bool_env_method(
|
def _test_bool_env_method(
|
||||||
self, method_name: str, env: Environment, config: Dict, expected_result: bool
|
self, method_name: str, env: Environment, config: Dict, expected_result: bool
|
||||||
):
|
):
|
||||||
|
|
|
@ -82,11 +82,9 @@ def test_add_user(config_file, with_credentials):
|
||||||
assert from_file["environment"]["password_hash"] == new_password_hash
|
assert from_file["environment"]["password_hash"] == new_password_hash
|
||||||
|
|
||||||
|
|
||||||
def test_get_users(with_credentials):
|
def test_user(with_credentials):
|
||||||
environment_config = EnvironmentConfig(with_credentials)
|
environment_config = EnvironmentConfig(with_credentials)
|
||||||
users = environment_config.get_users()
|
user = environment_config.user_creds
|
||||||
|
|
||||||
assert len(users) == 1
|
assert user.username == "test"
|
||||||
assert users[0].id == 1
|
assert user.password_hash == "abcdef"
|
||||||
assert users[0].username == "test"
|
|
||||||
assert users[0].secret == "abcdef"
|
|
||||||
|
|
|
@ -30,14 +30,6 @@ def test_to_dict_full_creds():
|
||||||
assert user_creds.to_dict() == {"user": TEST_USER, "password_hash": TEST_HASH}
|
assert user_creds.to_dict() == {"user": TEST_USER, "password_hash": TEST_HASH}
|
||||||
|
|
||||||
|
|
||||||
def test_to_auth_user_full_credentials():
|
|
||||||
user_creds = UserCreds(TEST_USER, TEST_HASH)
|
|
||||||
auth_user = user_creds.to_auth_user()
|
|
||||||
assert auth_user.id == 1
|
|
||||||
assert auth_user.username == TEST_USER
|
|
||||||
assert auth_user.secret == TEST_HASH
|
|
||||||
|
|
||||||
|
|
||||||
def test_member_values(monkeypatch):
|
def test_member_values(monkeypatch):
|
||||||
creds = UserCreds(TEST_USER, TEST_HASH)
|
creds = UserCreds(TEST_USER, TEST_HASH)
|
||||||
assert creds.username == TEST_USER
|
assert creds.username == TEST_USER
|
||||||
|
|
|
@ -0,0 +1,71 @@
|
||||||
|
import re
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from common.utils.exceptions import IncorrectCredentialsError
|
||||||
|
|
||||||
|
USERNAME = "test_user"
|
||||||
|
PASSWORD = "test_password"
|
||||||
|
TEST_REQUEST = f'{{"username": "{USERNAME}", "password": "{PASSWORD}"}}'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_authentication_service(monkeypatch):
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_service.authenticate = MagicMock()
|
||||||
|
|
||||||
|
monkeypatch.setattr("monkey_island.cc.resources.auth.auth.AuthenticationService", mock_service)
|
||||||
|
|
||||||
|
return mock_service
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def make_auth_request(flask_client):
|
||||||
|
url = "/api/auth"
|
||||||
|
|
||||||
|
def inner(request_body):
|
||||||
|
return flask_client.post(url, data=request_body, follow_redirects=True)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
def test_credential_parsing(make_auth_request, mock_authentication_service):
|
||||||
|
make_auth_request(TEST_REQUEST)
|
||||||
|
mock_authentication_service.authenticate.assert_called_with(USERNAME, PASSWORD)
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_credentials(make_auth_request, mock_authentication_service):
|
||||||
|
make_auth_request("{}")
|
||||||
|
mock_authentication_service.authenticate.assert_called_with("", "")
|
||||||
|
|
||||||
|
|
||||||
|
def test_authentication_successful(make_auth_request, mock_authentication_service):
|
||||||
|
mock_authentication_service.authenticate = MagicMock(return_value=True)
|
||||||
|
|
||||||
|
response = make_auth_request(TEST_REQUEST)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json["error"] == ""
|
||||||
|
assert re.match(
|
||||||
|
r"^[a-zA-Z0-9+/=]+\.[a-zA-Z0-9+/=]+\.[a-zA-Z0-9+/=\-_]+$", response.json["access_token"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_authentication_failure(make_auth_request, mock_authentication_service):
|
||||||
|
mock_authentication_service.authenticate = MagicMock(side_effect=IncorrectCredentialsError())
|
||||||
|
|
||||||
|
response = make_auth_request(TEST_REQUEST)
|
||||||
|
|
||||||
|
assert "access_token" not in response.json
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json["error"] == "Invalid credentials"
|
||||||
|
|
||||||
|
|
||||||
|
def test_authentication_error(make_auth_request, mock_authentication_service):
|
||||||
|
mock_authentication_service.authenticate = MagicMock(side_effect=Exception())
|
||||||
|
|
||||||
|
response = make_auth_request(TEST_REQUEST)
|
||||||
|
|
||||||
|
assert "access_token" not in response.json
|
||||||
|
assert response.status_code == 500
|
|
@ -0,0 +1,87 @@
|
||||||
|
import json
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from common.utils.exceptions import InvalidRegistrationCredentialsError, RegistrationNotNeededError
|
||||||
|
|
||||||
|
REGISTRATION_URL = "/api/registration"
|
||||||
|
|
||||||
|
USERNAME = "test_user"
|
||||||
|
PASSWORD = "test_password"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_authentication_service(monkeypatch):
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_service.register_new_user = MagicMock()
|
||||||
|
mock_service.needs_registration = MagicMock()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"monkey_island.cc.resources.auth.registration.AuthenticationService", mock_service
|
||||||
|
)
|
||||||
|
|
||||||
|
return mock_service
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def make_registration_request(flask_client):
|
||||||
|
def inner(request_body):
|
||||||
|
return flask_client.post(REGISTRATION_URL, data=request_body, follow_redirects=True)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
def test_registration(make_registration_request, mock_authentication_service):
|
||||||
|
registration_request_body = f'{{"username": "{USERNAME}", "password": "{PASSWORD}"}}'
|
||||||
|
response = make_registration_request(registration_request_body)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
mock_authentication_service.register_new_user.assert_called_with(USERNAME, PASSWORD)
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_credentials(make_registration_request, mock_authentication_service):
|
||||||
|
registration_request_body = "{}"
|
||||||
|
make_registration_request(registration_request_body)
|
||||||
|
|
||||||
|
mock_authentication_service.register_new_user.assert_called_with("", "")
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_credentials(make_registration_request, mock_authentication_service):
|
||||||
|
mock_authentication_service.register_new_user = MagicMock(
|
||||||
|
side_effect=InvalidRegistrationCredentialsError()
|
||||||
|
)
|
||||||
|
|
||||||
|
registration_request_body = "{}"
|
||||||
|
response = make_registration_request(registration_request_body)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_registration_not_needed(make_registration_request, mock_authentication_service):
|
||||||
|
mock_authentication_service.register_new_user = MagicMock(
|
||||||
|
side_effect=RegistrationNotNeededError()
|
||||||
|
)
|
||||||
|
|
||||||
|
registration_request_body = "{}"
|
||||||
|
response = make_registration_request(registration_request_body)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_internal_error(make_registration_request, mock_authentication_service):
|
||||||
|
mock_authentication_service.register_new_user = MagicMock(side_effect=Exception())
|
||||||
|
|
||||||
|
registration_request_body = json.dumps({})
|
||||||
|
response = make_registration_request(registration_request_body)
|
||||||
|
|
||||||
|
assert response.status_code == 500
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("needs_registration", [True, False])
|
||||||
|
def test_needs_registration(flask_client, mock_authentication_service, needs_registration):
|
||||||
|
mock_authentication_service.needs_registration = MagicMock(return_value=needs_registration)
|
||||||
|
response = flask_client.get(REGISTRATION_URL, follow_redirects=True)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json["needs_registration"] is needs_registration
|
|
@ -19,6 +19,7 @@ def flask_client(monkeypatch_session):
|
||||||
|
|
||||||
def mock_init_app():
|
def mock_init_app():
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
app.config["SECRET_KEY"] = "test_key"
|
||||||
|
|
||||||
api = flask_restful.Api(app)
|
api = flask_restful.Api(app)
|
||||||
api.representations = {"application/json": output_json}
|
api.representations = {"application/json": output_json}
|
||||||
|
@ -26,4 +27,6 @@ def mock_init_app():
|
||||||
monkey_island.cc.app.init_app_url_rules(app)
|
monkey_island.cc.app.init_app_url_rules(app)
|
||||||
monkey_island.cc.app.init_api_resources(api)
|
monkey_island.cc.app.init_api_resources(api)
|
||||||
|
|
||||||
|
flask_jwt_extended.JWTManager(app)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
Loading…
Reference in New Issue