replace flask-jwt with flask-jwt-extended

This commit is contained in:
Shay Nehmad 2020-07-21 17:30:21 +03:00
parent 1827cfab93
commit 1f26d7ffb9
26 changed files with 93 additions and 61 deletions

View File

@ -10,7 +10,7 @@ from monkey_island.cc.consts import MONKEY_ISLAND_ABS_PATH
from monkey_island.cc.database import database, mongo
from monkey_island.cc.resources.attack.attack_config import AttackConfiguration
from monkey_island.cc.resources.attack.attack_report import AttackReport
from monkey_island.cc.resources.auth.auth import init_jwt
from monkey_island.cc.resources.auth.auth import init_jwt, Authenticate
from monkey_island.cc.resources.bootloader import Bootloader
from monkey_island.cc.resources.client_run import ClientRun
from monkey_island.cc.resources.edge import Edge
@ -31,7 +31,7 @@ from monkey_island.cc.resources.node import Node
from monkey_island.cc.resources.node_states import NodeStates
from monkey_island.cc.resources.pba_file_download import PBAFileDownload
from monkey_island.cc.resources.pba_file_upload import FileUpload
from monkey_island.cc.resources.registration import Registration
from monkey_island.cc.resources.auth.registration import Registration
from monkey_island.cc.resources.remote_run import RemoteRun
from monkey_island.cc.resources.reporting.report import Report
from monkey_island.cc.resources.root import Root
@ -71,9 +71,12 @@ def serve_home():
def init_app_config(app, mongo_url):
app.config['MONGO_URI'] = mongo_url
app.config['SECRET_KEY'] = str(uuid.getnode())
app.config['JWT_AUTH_URL_RULE'] = '/api/auth'
app.config['JWT_EXPIRATION_DELTA'] = env_singleton.env.get_auth_expiration_time()
# See https://flask-jwt-extended.readthedocs.io/en/stable/options
app.config['JWT_TOKEN_LOCATION'] = ['headers']
app.config['JWT_ACCESS_TOKEN_EXPIRES'] = env_singleton.env.get_auth_expiration_time()
# Invalidate the signature of JWTs between server resets.
app.config['JWT_SECRET_KEY'] = str(uuid.uuid4())
def init_app_services(app):
@ -96,6 +99,7 @@ def init_app_url_rules(app):
def init_api_resources(api):
api.add_resource(Root, '/api')
api.add_resource(Registration, '/api/registration')
api.add_resource(Authenticate, '/api/auth')
api.add_resource(Environment, '/api/environment')
api.add_resource(Monkey, '/api/monkey', '/api/monkey/', '/api/monkey/<string:guid>')
api.add_resource(Bootloader, '/api/bootloader/<string:os>')

View File

@ -23,7 +23,7 @@ class Environment(object, metaclass=ABCMeta):
_MONGO_URL = os.environ.get("MONKEY_MONGO_URL",
"mongodb://{0}:{1}/{2}".format(_MONGO_DB_HOST, _MONGO_DB_PORT, str(_MONGO_DB_NAME)))
_DEBUG_SERVER = False
_AUTH_EXPIRATION_TIME = timedelta(hours=1)
_AUTH_EXPIRATION_TIME = timedelta(minutes=30)
_testing = False

View File

@ -8,7 +8,7 @@ __author__ = "VakarisZ"
class AttackConfiguration(flask_restful.Resource):
@jwt_required()
@jwt_required
def get(self):
return current_app.response_class(json.dumps({"configuration": AttackConfig.get_config()},
indent=None,
@ -16,7 +16,7 @@ class AttackConfiguration(flask_restful.Resource):
sort_keys=False) + "\n",
mimetype=current_app.config['JSONIFY_MIMETYPE'])
@jwt_required()
@jwt_required
def post(self):
"""
Based on request content this endpoint either resets ATT&CK configuration or updates it.

View File

@ -10,7 +10,7 @@ __author__ = "VakarisZ"
class AttackReport(flask_restful.Resource):
@jwt_required()
@jwt_required
def get(self):
response_content = {'techniques': AttackReportService.get_latest_report()['techniques'], 'schema': SCHEMA}
return current_app.response_class(json.dumps(response_content,

View File

@ -1,40 +1,67 @@
import json
import logging
from functools import wraps
from flask import abort, current_app
from flask_jwt import JWT, JWTError, _jwt_required
import flask_restful
import flask_jwt_extended
from flask import make_response, request
from flask_jwt_extended.exceptions import JWTExtendedException
from jwt import PyJWTError
from werkzeug.security import safe_str_cmp
import monkey_island.cc.environment.environment_singleton as env_singleton
import monkey_island.cc.resources.auth.user_store as user_store
__author__ = 'itay.mizeretz'
logger = logging.getLogger(__name__)
def init_jwt(app):
user_store.UserStore.set_users(env_singleton.env.get_auth_users())
_ = flask_jwt_extended.JWTManager(app)
logger.debug("Initialized JWT with secret key that started with " + app.config["JWT_SECRET_KEY"][:4])
def authenticate(username, secret):
class Authenticate(flask_restful.Resource):
"""
Resource for user authentication. The user provides the username and hashed password and we give them a JWT.
See `AuthService.js` file for the frontend counterpart for this code.
"""
@staticmethod
def _authenticate(username, secret):
user = user_store.UserStore.username_table.get(username, None)
if user and safe_str_cmp(user.secret.encode('utf-8'), secret.encode('utf-8')):
return user
def identity(payload):
user_id = payload['identity']
return user_store.UserStore.user_id_table.get(user_id, None)
JWT(app, authenticate, identity)
def post(self):
"""
Example request:
{
"username": "my_user",
"password": "343bb87e553b05430e5c44baf99569d4b66..."
}
"""
credentials = json.loads(request.data)
# Unpack auth info from request
username = credentials["username"]
secret = credentials["password"]
# If the user and password have been previously registered
if self._authenticate(username, secret):
access_token = flask_jwt_extended.create_access_token(identity=user_store.UserStore.username_table[username].id)
logger.debug(f"Created access token for user {username}: {access_token}")
return make_response({"access_token": access_token, "error": ""}, 200)
else:
return make_response({"error": "Invalid credentials"}, 401)
def jwt_required(realm=None):
def wrapper(fn):
# See https://flask-jwt-extended.readthedocs.io/en/stable/custom_decorators/
def jwt_required(fn):
@wraps(fn)
def decorator(*args, **kwargs):
def wrapper(*args, **kwargs):
try:
_jwt_required(realm or current_app.config['JWT_DEFAULT_REALM'])
flask_jwt_extended.verify_jwt_in_request()
return fn(*args, **kwargs)
except JWTError:
abort(401)
return decorator
# Catch authentication related errors in the verification or inside the called function. All other exceptions propagate
except (JWTExtendedException, PyJWTError) as e:
return make_response({"error": f"Authentication error: {str(e)}"}, 401)
return wrapper

View File

@ -8,12 +8,12 @@ from monkey_island.cc.services.config import ConfigService
class IslandConfiguration(flask_restful.Resource):
@jwt_required()
@jwt_required
def get(self):
return jsonify(schema=ConfigService.get_config_schema(),
configuration=ConfigService.get_config(False, True, True))
@jwt_required()
@jwt_required
def post(self):
config_json = json.loads(request.data)
if 'reset' in config_json:

View File

@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
class IslandLog(flask_restful.Resource):
@jwt_required()
@jwt_required
def get(self):
try:
return IslandLogService.get_log_file()

View File

@ -14,7 +14,7 @@ __author__ = "itay.mizeretz"
class Log(flask_restful.Resource):
@jwt_required()
@jwt_required
def get(self):
monkey_id = request.args.get('id')
exists_monkey_id = request.args.get('exists')

View File

@ -10,11 +10,11 @@ __author__ = 'Barak'
class MonkeyConfiguration(flask_restful.Resource):
@jwt_required()
@jwt_required
def get(self):
return jsonify(schema=ConfigService.get_config_schema(), configuration=ConfigService.get_config(False, True))
@jwt_required()
@jwt_required
def post(self):
config_json = json.loads(request.data)
if 'reset' in config_json:

View File

@ -8,7 +8,7 @@ __author__ = 'Barak'
class NetMap(flask_restful.Resource):
@jwt_required()
@jwt_required
def get(self, **kw):
net_nodes = NetNodeService.get_all_net_nodes()
net_edges = NetEdgeService.get_all_net_edges()

View File

@ -8,7 +8,7 @@ __author__ = 'Barak'
class Node(flask_restful.Resource):
@jwt_required()
@jwt_required
def get(self):
node_id = request.args.get('id')
if node_id:

View File

@ -6,6 +6,6 @@ from monkey_island.cc.services.utils.node_states import \
class NodeStates(flask_restful.Resource):
@jwt_required()
@jwt_required
def get(self):
return {'node_states': [state.value for state in NodeStateList]}

View File

@ -27,7 +27,7 @@ class FileUpload(flask_restful.Resource):
# Create all directories on the way if they don't exist
UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
@jwt_required()
@jwt_required
def get(self, file_type):
"""
Sends file to filepond
@ -41,7 +41,7 @@ class FileUpload(flask_restful.Resource):
filename = ConfigService.get_config_value(copy.deepcopy(PBA_WINDOWS_FILENAME_PATH))
return send_from_directory(UPLOADS_DIR, filename)
@jwt_required()
@jwt_required
def post(self, file_type):
"""
Receives user's uploaded file from filepond
@ -55,7 +55,7 @@ class FileUpload(flask_restful.Resource):
status=200, mimetype='text/plain')
return response
@jwt_required()
@jwt_required
def delete(self, file_type):
"""
Deletes file that has been deleted on the front end

View File

@ -24,7 +24,7 @@ class RemoteRun(flask_restful.Resource):
island_ip = request_body.get('island_ip')
return RemoteRunAwsService.run_aws_monkeys(instances, island_ip)
@jwt_required()
@jwt_required
def get(self):
action = request.args.get('action')
if action == 'list_aws':
@ -43,7 +43,7 @@ class RemoteRun(flask_restful.Resource):
return {}
@jwt_required()
@jwt_required
def post(self):
body = json.loads(request.data)
resp = {}

View File

@ -21,7 +21,7 @@ __author__ = ["itay.mizeretz", "shay.nehmad"]
class Report(flask_restful.Resource):
@jwt_required()
@jwt_required
def get(self, report_type=SECURITY_REPORT_TYPE, report_data=None):
if report_type == SECURITY_REPORT_TYPE:
return ReportService.get_report()

View File

@ -26,15 +26,15 @@ class Root(flask_restful.Resource):
if not action:
return self.get_server_info()
elif action == "reset":
return jwt_required()(Database.reset_db)()
return jwt_required(Database.reset_db)()
elif action == "killall":
return jwt_required()(InfectionLifecycle.kill_all)()
return jwt_required(InfectionLifecycle.kill_all)()
elif action == "is-up":
return {'is-up': True}
else:
return make_response(400, {'error': 'unknown action'})
@jwt_required()
@jwt_required
def get_server_info(self):
return jsonify(
ip_addresses=local_ip_addresses(),

View File

@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
class Telemetry(flask_restful.Resource):
@jwt_required()
@jwt_required
def get(self, **kw):
monkey_guid = request.args.get('monkey_guid')
telem_category = request.args.get('telem_category')

View File

@ -5,9 +5,9 @@ import dateutil
import flask_pymongo
import flask_restful
from flask import request
from monkey_island.cc.resources.auth.auth import jwt_required
from monkey_island.cc.database import mongo
from monkey_island.cc.resources.auth.auth import jwt_required
from monkey_island.cc.services.node import NodeService
logger = logging.getLogger(__name__)
@ -16,7 +16,7 @@ __author__ = 'itay.mizeretz'
class TelemetryFeed(flask_restful.Resource):
@jwt_required()
@jwt_required
def get(self, **kw):
timestamp = request.args.get('timestamp')
if "null" == timestamp or timestamp is None: # special case to avoid ugly JS code...

View File

@ -17,7 +17,7 @@ class ClearCaches(flask_restful.Resource):
so we use this to clear the caches.
:note: DO NOT CALL THIS IN PRODUCTION CODE as this will slow down the user experience.
"""
@jwt_required()
@jwt_required
def get(self, **kw):
try:
logger.warning("Trying to clear caches! Make sure this is not production")

View File

@ -7,7 +7,7 @@ from monkey_island.cc.resources.auth.auth import jwt_required
class LogTest(flask_restful.Resource):
@jwt_required()
@jwt_required
def get(self):
find_query = json_util.loads(request.args.get('find_query'))
log = mongo.db.log.find_one(find_query)

View File

@ -7,7 +7,7 @@ from monkey_island.cc.resources.auth.auth import jwt_required
class MonkeyTest(flask_restful.Resource):
@jwt_required()
@jwt_required
def get(self, **kw):
find_query = json_util.loads(request.args.get('find_query'))
return {'results': list(mongo.db.monkey.find(find_query))}

View File

@ -9,6 +9,6 @@ from monkey_island.cc.services.reporting.zero_trust_service import \
class ZeroTrustFindingEvent(flask_restful.Resource):
@jwt_required()
@jwt_required
def get(self, finding_id: str):
return {'events_json': json.dumps(ZeroTrustService.get_events_by_finding(finding_id), default=str)}

View File

@ -26,13 +26,13 @@ class RegisterPageComponent extends React.Component {
};
setNoAuth = () => {
let options = {}
let options = {};
options['headers'] = {
'Accept': 'application/json',
'Content-Type': 'application/json'
};
options['method'] = 'PATCH'
options['body'] = JSON.stringify({'server_config': 'standard'})
options['method'] = 'PATCH';
options['body'] = JSON.stringify({'server_config': 'standard'});
return fetch(this.NO_AUTH_API_ENDPOINT, options)
.then(res => {

View File

@ -83,7 +83,7 @@ export default class AuthService {
};
if (this._loggedIn()) {
headers['Authorization'] = 'JWT ' + this._getToken();
headers['Authorization'] = 'Bearer ' + this._getToken();
}
if (options.hasOwnProperty('headers')) {
@ -97,6 +97,9 @@ export default class AuthService {
return fetch(url, options)
.then(res => {
if (res.status === 401) {
res.clone().json().then(res_json => {
console.log('Got 401 from server while trying to authFetch: ' + JSON.stringify(res_json));
});
this._removeToken();
}
return res;
@ -156,6 +159,4 @@ export default class AuthService {
_toHexStr(byteArr) {
return byteArr.reduce((acc, x) => (acc + ('0' + x.toString(0x10)).slice(-2)), '');
}
}

View File

@ -1,4 +1,4 @@
Flask-JWT>=0.3.2
Flask-JWT-Extended==3.24.1
Flask-Pymongo>=2.3.0
Flask-Restful>=0.3.8
PyInstaller==3.6