replace flask-jwt with flask-jwt-extended

This commit is contained in:
Shay Nehmad 2020-07-21 17:30:21 +03:00
parent 1827cfab93
commit 1f26d7ffb9
26 changed files with 93 additions and 61 deletions

View File

@ -10,7 +10,7 @@ from monkey_island.cc.consts import MONKEY_ISLAND_ABS_PATH
from monkey_island.cc.database import database, mongo from monkey_island.cc.database import database, mongo
from monkey_island.cc.resources.attack.attack_config import AttackConfiguration from monkey_island.cc.resources.attack.attack_config import AttackConfiguration
from monkey_island.cc.resources.attack.attack_report import AttackReport from monkey_island.cc.resources.attack.attack_report import AttackReport
from monkey_island.cc.resources.auth.auth import init_jwt from monkey_island.cc.resources.auth.auth import init_jwt, Authenticate
from monkey_island.cc.resources.bootloader import Bootloader from monkey_island.cc.resources.bootloader import Bootloader
from monkey_island.cc.resources.client_run import ClientRun from monkey_island.cc.resources.client_run import ClientRun
from monkey_island.cc.resources.edge import Edge from monkey_island.cc.resources.edge import Edge
@ -31,7 +31,7 @@ from monkey_island.cc.resources.node import Node
from monkey_island.cc.resources.node_states import NodeStates from monkey_island.cc.resources.node_states import NodeStates
from monkey_island.cc.resources.pba_file_download import PBAFileDownload from monkey_island.cc.resources.pba_file_download import PBAFileDownload
from monkey_island.cc.resources.pba_file_upload import FileUpload from monkey_island.cc.resources.pba_file_upload import FileUpload
from monkey_island.cc.resources.registration import Registration from monkey_island.cc.resources.auth.registration import Registration
from monkey_island.cc.resources.remote_run import RemoteRun from monkey_island.cc.resources.remote_run import RemoteRun
from monkey_island.cc.resources.reporting.report import Report from monkey_island.cc.resources.reporting.report import Report
from monkey_island.cc.resources.root import Root from monkey_island.cc.resources.root import Root
@ -71,9 +71,12 @@ def serve_home():
def init_app_config(app, mongo_url): def init_app_config(app, mongo_url):
app.config['MONGO_URI'] = mongo_url app.config['MONGO_URI'] = mongo_url
app.config['SECRET_KEY'] = str(uuid.getnode())
app.config['JWT_AUTH_URL_RULE'] = '/api/auth' # See https://flask-jwt-extended.readthedocs.io/en/stable/options
app.config['JWT_EXPIRATION_DELTA'] = env_singleton.env.get_auth_expiration_time() app.config['JWT_TOKEN_LOCATION'] = ['headers']
app.config['JWT_ACCESS_TOKEN_EXPIRES'] = env_singleton.env.get_auth_expiration_time()
# Invalidate the signature of JWTs between server resets.
app.config['JWT_SECRET_KEY'] = str(uuid.uuid4())
def init_app_services(app): def init_app_services(app):
@ -96,6 +99,7 @@ def init_app_url_rules(app):
def init_api_resources(api): def init_api_resources(api):
api.add_resource(Root, '/api') api.add_resource(Root, '/api')
api.add_resource(Registration, '/api/registration') api.add_resource(Registration, '/api/registration')
api.add_resource(Authenticate, '/api/auth')
api.add_resource(Environment, '/api/environment') api.add_resource(Environment, '/api/environment')
api.add_resource(Monkey, '/api/monkey', '/api/monkey/', '/api/monkey/<string:guid>') api.add_resource(Monkey, '/api/monkey', '/api/monkey/', '/api/monkey/<string:guid>')
api.add_resource(Bootloader, '/api/bootloader/<string:os>') api.add_resource(Bootloader, '/api/bootloader/<string:os>')

View File

@ -23,7 +23,7 @@ class Environment(object, metaclass=ABCMeta):
_MONGO_URL = os.environ.get("MONKEY_MONGO_URL", _MONGO_URL = os.environ.get("MONKEY_MONGO_URL",
"mongodb://{0}:{1}/{2}".format(_MONGO_DB_HOST, _MONGO_DB_PORT, str(_MONGO_DB_NAME))) "mongodb://{0}:{1}/{2}".format(_MONGO_DB_HOST, _MONGO_DB_PORT, str(_MONGO_DB_NAME)))
_DEBUG_SERVER = False _DEBUG_SERVER = False
_AUTH_EXPIRATION_TIME = timedelta(hours=1) _AUTH_EXPIRATION_TIME = timedelta(minutes=30)
_testing = False _testing = False

View File

@ -8,7 +8,7 @@ __author__ = "VakarisZ"
class AttackConfiguration(flask_restful.Resource): class AttackConfiguration(flask_restful.Resource):
@jwt_required() @jwt_required
def get(self): def get(self):
return current_app.response_class(json.dumps({"configuration": AttackConfig.get_config()}, return current_app.response_class(json.dumps({"configuration": AttackConfig.get_config()},
indent=None, indent=None,
@ -16,7 +16,7 @@ class AttackConfiguration(flask_restful.Resource):
sort_keys=False) + "\n", sort_keys=False) + "\n",
mimetype=current_app.config['JSONIFY_MIMETYPE']) mimetype=current_app.config['JSONIFY_MIMETYPE'])
@jwt_required() @jwt_required
def post(self): def post(self):
""" """
Based on request content this endpoint either resets ATT&CK configuration or updates it. Based on request content this endpoint either resets ATT&CK configuration or updates it.

View File

@ -10,7 +10,7 @@ __author__ = "VakarisZ"
class AttackReport(flask_restful.Resource): class AttackReport(flask_restful.Resource):
@jwt_required() @jwt_required
def get(self): def get(self):
response_content = {'techniques': AttackReportService.get_latest_report()['techniques'], 'schema': SCHEMA} response_content = {'techniques': AttackReportService.get_latest_report()['techniques'], 'schema': SCHEMA}
return current_app.response_class(json.dumps(response_content, return current_app.response_class(json.dumps(response_content,

View File

@ -1,40 +1,67 @@
import json
import logging
from functools import wraps from functools import wraps
from flask import abort, current_app import flask_restful
from flask_jwt import JWT, JWTError, _jwt_required import flask_jwt_extended
from flask import make_response, request
from flask_jwt_extended.exceptions import JWTExtendedException
from jwt import PyJWTError
from werkzeug.security import safe_str_cmp from werkzeug.security import safe_str_cmp
import monkey_island.cc.environment.environment_singleton as env_singleton import monkey_island.cc.environment.environment_singleton as env_singleton
import monkey_island.cc.resources.auth.user_store as user_store import monkey_island.cc.resources.auth.user_store as user_store
__author__ = 'itay.mizeretz' logger = logging.getLogger(__name__)
def init_jwt(app): def init_jwt(app):
user_store.UserStore.set_users(env_singleton.env.get_auth_users()) user_store.UserStore.set_users(env_singleton.env.get_auth_users())
_ = flask_jwt_extended.JWTManager(app)
logger.debug("Initialized JWT with secret key that started with " + app.config["JWT_SECRET_KEY"][:4])
def authenticate(username, secret):
class Authenticate(flask_restful.Resource):
"""
Resource for user authentication. The user provides the username and hashed password and we give them a JWT.
See `AuthService.js` file for the frontend counterpart for this code.
"""
@staticmethod
def _authenticate(username, secret):
user = user_store.UserStore.username_table.get(username, None) user = user_store.UserStore.username_table.get(username, None)
if user and safe_str_cmp(user.secret.encode('utf-8'), secret.encode('utf-8')): if user and safe_str_cmp(user.secret.encode('utf-8'), secret.encode('utf-8')):
return user return user
def identity(payload): def post(self):
user_id = payload['identity'] """
return user_store.UserStore.user_id_table.get(user_id, None) Example request:
{
JWT(app, authenticate, identity) "username": "my_user",
"password": "343bb87e553b05430e5c44baf99569d4b66..."
}
"""
credentials = json.loads(request.data)
# Unpack auth info from request
username = credentials["username"]
secret = credentials["password"]
# If the user and password have been previously registered
if self._authenticate(username, secret):
access_token = flask_jwt_extended.create_access_token(identity=user_store.UserStore.username_table[username].id)
logger.debug(f"Created access token for user {username}: {access_token}")
return make_response({"access_token": access_token, "error": ""}, 200)
else:
return make_response({"error": "Invalid credentials"}, 401)
def jwt_required(realm=None): # See https://flask-jwt-extended.readthedocs.io/en/stable/custom_decorators/
def wrapper(fn): def jwt_required(fn):
@wraps(fn) @wraps(fn)
def decorator(*args, **kwargs): def wrapper(*args, **kwargs):
try: try:
_jwt_required(realm or current_app.config['JWT_DEFAULT_REALM']) flask_jwt_extended.verify_jwt_in_request()
return fn(*args, **kwargs) return fn(*args, **kwargs)
except JWTError: # Catch authentication related errors in the verification or inside the called function. All other exceptions propagate
abort(401) except (JWTExtendedException, PyJWTError) as e:
return make_response({"error": f"Authentication error: {str(e)}"}, 401)
return decorator
return wrapper return wrapper

View File

@ -8,12 +8,12 @@ from monkey_island.cc.services.config import ConfigService
class IslandConfiguration(flask_restful.Resource): class IslandConfiguration(flask_restful.Resource):
@jwt_required() @jwt_required
def get(self): def get(self):
return jsonify(schema=ConfigService.get_config_schema(), return jsonify(schema=ConfigService.get_config_schema(),
configuration=ConfigService.get_config(False, True, True)) configuration=ConfigService.get_config(False, True, True))
@jwt_required() @jwt_required
def post(self): def post(self):
config_json = json.loads(request.data) config_json = json.loads(request.data)
if 'reset' in config_json: if 'reset' in config_json:

View File

@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
class IslandLog(flask_restful.Resource): class IslandLog(flask_restful.Resource):
@jwt_required() @jwt_required
def get(self): def get(self):
try: try:
return IslandLogService.get_log_file() return IslandLogService.get_log_file()

View File

@ -14,7 +14,7 @@ __author__ = "itay.mizeretz"
class Log(flask_restful.Resource): class Log(flask_restful.Resource):
@jwt_required() @jwt_required
def get(self): def get(self):
monkey_id = request.args.get('id') monkey_id = request.args.get('id')
exists_monkey_id = request.args.get('exists') exists_monkey_id = request.args.get('exists')

View File

@ -10,11 +10,11 @@ __author__ = 'Barak'
class MonkeyConfiguration(flask_restful.Resource): class MonkeyConfiguration(flask_restful.Resource):
@jwt_required() @jwt_required
def get(self): def get(self):
return jsonify(schema=ConfigService.get_config_schema(), configuration=ConfigService.get_config(False, True)) return jsonify(schema=ConfigService.get_config_schema(), configuration=ConfigService.get_config(False, True))
@jwt_required() @jwt_required
def post(self): def post(self):
config_json = json.loads(request.data) config_json = json.loads(request.data)
if 'reset' in config_json: if 'reset' in config_json:

View File

@ -8,7 +8,7 @@ __author__ = 'Barak'
class NetMap(flask_restful.Resource): class NetMap(flask_restful.Resource):
@jwt_required() @jwt_required
def get(self, **kw): def get(self, **kw):
net_nodes = NetNodeService.get_all_net_nodes() net_nodes = NetNodeService.get_all_net_nodes()
net_edges = NetEdgeService.get_all_net_edges() net_edges = NetEdgeService.get_all_net_edges()

View File

@ -8,7 +8,7 @@ __author__ = 'Barak'
class Node(flask_restful.Resource): class Node(flask_restful.Resource):
@jwt_required() @jwt_required
def get(self): def get(self):
node_id = request.args.get('id') node_id = request.args.get('id')
if node_id: if node_id:

View File

@ -6,6 +6,6 @@ from monkey_island.cc.services.utils.node_states import \
class NodeStates(flask_restful.Resource): class NodeStates(flask_restful.Resource):
@jwt_required() @jwt_required
def get(self): def get(self):
return {'node_states': [state.value for state in NodeStateList]} return {'node_states': [state.value for state in NodeStateList]}

View File

@ -27,7 +27,7 @@ class FileUpload(flask_restful.Resource):
# Create all directories on the way if they don't exist # Create all directories on the way if they don't exist
UPLOADS_DIR.mkdir(parents=True, exist_ok=True) UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
@jwt_required() @jwt_required
def get(self, file_type): def get(self, file_type):
""" """
Sends file to filepond Sends file to filepond
@ -41,7 +41,7 @@ class FileUpload(flask_restful.Resource):
filename = ConfigService.get_config_value(copy.deepcopy(PBA_WINDOWS_FILENAME_PATH)) filename = ConfigService.get_config_value(copy.deepcopy(PBA_WINDOWS_FILENAME_PATH))
return send_from_directory(UPLOADS_DIR, filename) return send_from_directory(UPLOADS_DIR, filename)
@jwt_required() @jwt_required
def post(self, file_type): def post(self, file_type):
""" """
Receives user's uploaded file from filepond Receives user's uploaded file from filepond
@ -55,7 +55,7 @@ class FileUpload(flask_restful.Resource):
status=200, mimetype='text/plain') status=200, mimetype='text/plain')
return response return response
@jwt_required() @jwt_required
def delete(self, file_type): def delete(self, file_type):
""" """
Deletes file that has been deleted on the front end Deletes file that has been deleted on the front end

View File

@ -24,7 +24,7 @@ class RemoteRun(flask_restful.Resource):
island_ip = request_body.get('island_ip') island_ip = request_body.get('island_ip')
return RemoteRunAwsService.run_aws_monkeys(instances, island_ip) return RemoteRunAwsService.run_aws_monkeys(instances, island_ip)
@jwt_required() @jwt_required
def get(self): def get(self):
action = request.args.get('action') action = request.args.get('action')
if action == 'list_aws': if action == 'list_aws':
@ -43,7 +43,7 @@ class RemoteRun(flask_restful.Resource):
return {} return {}
@jwt_required() @jwt_required
def post(self): def post(self):
body = json.loads(request.data) body = json.loads(request.data)
resp = {} resp = {}

View File

@ -21,7 +21,7 @@ __author__ = ["itay.mizeretz", "shay.nehmad"]
class Report(flask_restful.Resource): class Report(flask_restful.Resource):
@jwt_required() @jwt_required
def get(self, report_type=SECURITY_REPORT_TYPE, report_data=None): def get(self, report_type=SECURITY_REPORT_TYPE, report_data=None):
if report_type == SECURITY_REPORT_TYPE: if report_type == SECURITY_REPORT_TYPE:
return ReportService.get_report() return ReportService.get_report()

View File

@ -26,15 +26,15 @@ class Root(flask_restful.Resource):
if not action: if not action:
return self.get_server_info() return self.get_server_info()
elif action == "reset": elif action == "reset":
return jwt_required()(Database.reset_db)() return jwt_required(Database.reset_db)()
elif action == "killall": elif action == "killall":
return jwt_required()(InfectionLifecycle.kill_all)() return jwt_required(InfectionLifecycle.kill_all)()
elif action == "is-up": elif action == "is-up":
return {'is-up': True} return {'is-up': True}
else: else:
return make_response(400, {'error': 'unknown action'}) return make_response(400, {'error': 'unknown action'})
@jwt_required() @jwt_required
def get_server_info(self): def get_server_info(self):
return jsonify( return jsonify(
ip_addresses=local_ip_addresses(), ip_addresses=local_ip_addresses(),

View File

@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
class Telemetry(flask_restful.Resource): class Telemetry(flask_restful.Resource):
@jwt_required() @jwt_required
def get(self, **kw): def get(self, **kw):
monkey_guid = request.args.get('monkey_guid') monkey_guid = request.args.get('monkey_guid')
telem_category = request.args.get('telem_category') telem_category = request.args.get('telem_category')

View File

@ -5,9 +5,9 @@ import dateutil
import flask_pymongo import flask_pymongo
import flask_restful import flask_restful
from flask import request from flask import request
from monkey_island.cc.resources.auth.auth import jwt_required
from monkey_island.cc.database import mongo from monkey_island.cc.database import mongo
from monkey_island.cc.resources.auth.auth import jwt_required
from monkey_island.cc.services.node import NodeService from monkey_island.cc.services.node import NodeService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -16,7 +16,7 @@ __author__ = 'itay.mizeretz'
class TelemetryFeed(flask_restful.Resource): class TelemetryFeed(flask_restful.Resource):
@jwt_required() @jwt_required
def get(self, **kw): def get(self, **kw):
timestamp = request.args.get('timestamp') timestamp = request.args.get('timestamp')
if "null" == timestamp or timestamp is None: # special case to avoid ugly JS code... if "null" == timestamp or timestamp is None: # special case to avoid ugly JS code...

View File

@ -17,7 +17,7 @@ class ClearCaches(flask_restful.Resource):
so we use this to clear the caches. so we use this to clear the caches.
:note: DO NOT CALL THIS IN PRODUCTION CODE as this will slow down the user experience. :note: DO NOT CALL THIS IN PRODUCTION CODE as this will slow down the user experience.
""" """
@jwt_required() @jwt_required
def get(self, **kw): def get(self, **kw):
try: try:
logger.warning("Trying to clear caches! Make sure this is not production") logger.warning("Trying to clear caches! Make sure this is not production")

View File

@ -7,7 +7,7 @@ from monkey_island.cc.resources.auth.auth import jwt_required
class LogTest(flask_restful.Resource): class LogTest(flask_restful.Resource):
@jwt_required() @jwt_required
def get(self): def get(self):
find_query = json_util.loads(request.args.get('find_query')) find_query = json_util.loads(request.args.get('find_query'))
log = mongo.db.log.find_one(find_query) log = mongo.db.log.find_one(find_query)

View File

@ -7,7 +7,7 @@ from monkey_island.cc.resources.auth.auth import jwt_required
class MonkeyTest(flask_restful.Resource): class MonkeyTest(flask_restful.Resource):
@jwt_required() @jwt_required
def get(self, **kw): def get(self, **kw):
find_query = json_util.loads(request.args.get('find_query')) find_query = json_util.loads(request.args.get('find_query'))
return {'results': list(mongo.db.monkey.find(find_query))} return {'results': list(mongo.db.monkey.find(find_query))}

View File

@ -9,6 +9,6 @@ from monkey_island.cc.services.reporting.zero_trust_service import \
class ZeroTrustFindingEvent(flask_restful.Resource): class ZeroTrustFindingEvent(flask_restful.Resource):
@jwt_required() @jwt_required
def get(self, finding_id: str): def get(self, finding_id: str):
return {'events_json': json.dumps(ZeroTrustService.get_events_by_finding(finding_id), default=str)} return {'events_json': json.dumps(ZeroTrustService.get_events_by_finding(finding_id), default=str)}

View File

@ -26,13 +26,13 @@ class RegisterPageComponent extends React.Component {
}; };
setNoAuth = () => { setNoAuth = () => {
let options = {} let options = {};
options['headers'] = { options['headers'] = {
'Accept': 'application/json', 'Accept': 'application/json',
'Content-Type': 'application/json' 'Content-Type': 'application/json'
}; };
options['method'] = 'PATCH' options['method'] = 'PATCH';
options['body'] = JSON.stringify({'server_config': 'standard'}) options['body'] = JSON.stringify({'server_config': 'standard'});
return fetch(this.NO_AUTH_API_ENDPOINT, options) return fetch(this.NO_AUTH_API_ENDPOINT, options)
.then(res => { .then(res => {

View File

@ -83,7 +83,7 @@ export default class AuthService {
}; };
if (this._loggedIn()) { if (this._loggedIn()) {
headers['Authorization'] = 'JWT ' + this._getToken(); headers['Authorization'] = 'Bearer ' + this._getToken();
} }
if (options.hasOwnProperty('headers')) { if (options.hasOwnProperty('headers')) {
@ -97,6 +97,9 @@ export default class AuthService {
return fetch(url, options) return fetch(url, options)
.then(res => { .then(res => {
if (res.status === 401) { if (res.status === 401) {
res.clone().json().then(res_json => {
console.log('Got 401 from server while trying to authFetch: ' + JSON.stringify(res_json));
});
this._removeToken(); this._removeToken();
} }
return res; return res;
@ -156,6 +159,4 @@ export default class AuthService {
_toHexStr(byteArr) { _toHexStr(byteArr) {
return byteArr.reduce((acc, x) => (acc + ('0' + x.toString(0x10)).slice(-2)), ''); return byteArr.reduce((acc, x) => (acc + ('0' + x.toString(0x10)).slice(-2)), '');
} }
} }

View File

@ -1,4 +1,4 @@
Flask-JWT>=0.3.2 Flask-JWT-Extended==3.24.1
Flask-Pymongo>=2.3.0 Flask-Pymongo>=2.3.0
Flask-Restful>=0.3.8 Flask-Restful>=0.3.8
PyInstaller==3.6 PyInstaller==3.6