Addressed most CR comments, refactored mitigations to include names and urls

This commit is contained in:
VakarisZ 2020-03-31 17:40:36 +03:00
parent 5d827d1f44
commit 2b0820f74a
11 changed files with 160 additions and 49 deletions

View File

@ -0,0 +1,35 @@
from typing import Dict
from mongoengine import Document, StringField, DoesNotExist, EmbeddedDocumentField, ListField
from monkey_island.cc.models.attack.mitigation import Mitigation
from stix2 import AttackPattern, CourseOfAction
from monkey_island.cc.services.attack.test_mitre_api_interface import MitreApiInterface
class AttackMitigations(Document):
technique_id = StringField(required=True, primary_key=True)
mitigations = ListField(EmbeddedDocumentField('Mitigation'))
@staticmethod
def get_mitigation_by_technique_id(technique_id: str) -> Document:
try:
return AttackMitigations.objects.get(technique_id=technique_id)
except DoesNotExist:
raise Exception("Attack technique with id {} does not exist!".format(technique_id))
def add_mitigation(self, mitigation: CourseOfAction):
mitigation_external_ref_id = MitreApiInterface.get_stix2_external_reference_id(mitigation)
if mitigation_external_ref_id.startswith('M'):
self.mitigations.append(Mitigation.get_from_stix2_data(mitigation))
@staticmethod
def mitigations_from_attack_pattern(attack_pattern: AttackPattern):
return AttackMitigations(technique_id=MitreApiInterface.get_stix2_external_reference_id(attack_pattern),
mitigations=[])
@staticmethod
def dict_from_stix2_attack_patterns(stix2_dict: Dict[str, AttackPattern]):
return {key: AttackMitigations.mitigations_from_attack_pattern(attack_pattern)
for key, attack_pattern in stix2_dict.items()}

View File

@ -0,0 +1,19 @@
from mongoengine import StringField, EmbeddedDocument
from stix2 import CourseOfAction
from monkey_island.cc.services.attack.test_mitre_api_interface import MitreApiInterface
class Mitigation(EmbeddedDocument):
name = StringField(required=True)
description = StringField(required=True)
url = StringField()
@staticmethod
def get_from_stix2_data(mitigation: CourseOfAction):
name = mitigation['name']
description = mitigation['description']
url = MitreApiInterface.get_stix2_external_reference_url(mitigation)
return Mitigation(name=name, description=description, url=url)

View File

@ -1,23 +0,0 @@
from mongoengine import Document, StringField, DoesNotExist
class AttackMitigation(Document):
technique_id = StringField(required=True, primary_key=True)
name = StringField(required=True)
description = StringField(required=True)
@staticmethod
def get_mitigation_by_technique_id(technique_id: str) -> Document:
try:
return AttackMitigation.objects.get(technique_id=technique_id)
except DoesNotExist:
raise Exception("Attack technique with id {} does not exist!".format(technique_id))
@staticmethod
def add_mitigation_from_stix2(mitigation_stix2_data):
mitigation_model = AttackMitigation(technique_id=mitigation_stix2_data['external_references'][0]['external_id'],
name=mitigation_stix2_data['name'],
description=mitigation_stix2_data['description'])
if mitigation_model.technique_id.startswith('T'):
mitigation_model.save()

View File

@ -1,4 +1,7 @@
from stix2 import FileSystemSource, Filter
from typing import List, Dict
from stix2 import FileSystemSource, Filter, CourseOfAction, AttackPattern
from stix2.core import STIXDomainObject
class MitreApiInterface:
@ -6,8 +9,39 @@ class MitreApiInterface:
ATTACK_DATA_PATH = 'monkey_island/cc/services/attack/attack_data/enterprise-attack'
@staticmethod
def get_all_mitigations() -> list:
def get_all_mitigations() -> Dict[str, CourseOfAction]:
file_system = FileSystemSource(MitreApiInterface.ATTACK_DATA_PATH)
mitigation_filter = [Filter('type', '=', 'course-of-action')]
all_mitigations = file_system.query(mitigation_filter)
all_mitigations = {mitigation['id']: mitigation for mitigation in all_mitigations}
return all_mitigations
@staticmethod
def get_all_attack_techniques() -> Dict[str, AttackPattern]:
file_system = FileSystemSource(MitreApiInterface.ATTACK_DATA_PATH)
technique_filter = [Filter('type', '=', 'attack-pattern')]
all_techniques = file_system.query(technique_filter)
all_techniques = {technique['id']: technique for technique in all_techniques}
return all_techniques
@staticmethod
def get_technique_and_mitigation_relationships() -> List[CourseOfAction]:
file_system = FileSystemSource(MitreApiInterface.ATTACK_DATA_PATH)
technique_filter = [Filter('type', '=', 'relationship'),
Filter('relationship_type', '=', 'mitigates')]
all_techniques = file_system.query(technique_filter)
return all_techniques
@staticmethod
def get_stix2_external_reference_id(stix2_data: STIXDomainObject) -> str:
for reference in stix2_data['external_references']:
if reference['source_name'] == "mitre-attack" and 'external_id' in reference:
return reference['external_id']
return ''
@staticmethod
def get_stix2_external_reference_url(stix2_data: STIXDomainObject) -> str:
for reference in stix2_data['external_references']:
if 'url' in reference:
return reference['url']
return ''

View File

@ -5,7 +5,7 @@ from monkey_island.cc.database import mongo
from common.utils.attack_utils import ScanStatus
from monkey_island.cc.services.attack.attack_config import AttackConfig
from common.utils.code_utils import abstractstatic
from monkey_island.cc.models.attack_mitigation import AttackMitigation
from cc.models.attack.attack_mitigations import AttackMitigations
logger = logging.getLogger(__name__)
@ -125,9 +125,8 @@ class AttackTechnique(object, metaclass=abc.ABCMeta):
@classmethod
def get_mitigation_by_status(cls, status: ScanStatus) -> dict:
if status == ScanStatus.USED.value:
mitigation_document = AttackMitigation.get_mitigation_by_technique_id(str(cls.tech_id))
return {'mitigations': {'name': mitigation_document['name'],
'description': mitigation_document['description']}}
mitigation_document = AttackMitigations.get_mitigation_by_technique_id(str(cls.tech_id))
return {'mitigations': mitigation_document.to_mongo().to_dict()['mitigations']}
else:
return {}

View File

@ -0,0 +1,15 @@
from unittest import TestCase
from monkey_island.cc.services.attack.mitre_api_interface import MitreApiInterface
class TestMitreApiInterface(TestCase):
def test_get_all_mitigations(self):
mitigations = MitreApiInterface.get_all_mitigations()
self.assertTrue((len(mitigations) >= 282))
mitigation = mitigations[0]
self.assertEqual(mitigation['type'], "course-of-action")
self.assertTrue(mitigation['name'])
self.assertTrue(mitigation['description'])
self.assertTrue(mitigation['external_references'])

View File

@ -1,5 +1,5 @@
from monkey_island.cc.services.attack.mitre_api_interface import MitreApiInterface
from monkey_island.cc.models.attack_mitigation import AttackMitigation
from cc.models.attack.attack_mitigations import AttackMitigations
from monkey_island.cc.database import mongo
from pymongo import errors
@ -9,16 +9,23 @@ def setup():
def try_store_mitigations_on_mongo():
# import the 'errors' module from PyMongo
mitigation_collection_name = 'attack_mitigation'
mitigation_collection_name = 'attack_mitigations'
try:
mongo.db.validate_collection(mitigation_collection_name)
if mongo.db.attack_mitigations.count() == 0:
raise errors.OperationFailure("Mitigation collection empty")
except errors.OperationFailure:
mongo.db.create_collection(mitigation_collection_name)
store_mitigations_on_mongo()
try:
mongo.db.create_collection(mitigation_collection_name)
finally:
store_mitigations_on_mongo()
def store_mitigations_on_mongo():
all_mitigations = MitreApiInterface.get_all_mitigations()
for mitigation in all_mitigations:
AttackMitigation.add_mitigation_from_stix2(mitigation)
stix2_mitigations = MitreApiInterface.get_all_mitigations()
mongo_mitigations = AttackMitigations.dict_from_stix2_attack_patterns(MitreApiInterface.get_all_attack_techniques())
mitigation_technique_relationships = MitreApiInterface.get_technique_and_mitigation_relationships()
for relationship in mitigation_technique_relationships:
mongo_mitigations[relationship['target_ref']].add_mitigation(stix2_mitigations[relationship['source_ref']])
for key, mongo_object in mongo_mitigations.items():
mongo_object.save()

View File

@ -10567,6 +10567,11 @@
"object-visit": "1.0.1"
}
},
"marked": {
"version": "0.8.2",
"resolved": "https://registry.npmjs.org/marked/-/marked-0.8.2.tgz",
"integrity": "sha512-EGwzEeCcLniFX51DhTpmTom+dSA/MG/OBUDjnWtHbEnjAH180VzUeAw+oE4+Zv+CoYBWyRlYOTR0N8SO9R1PVw=="
},
"md5.js": {
"version": "1.3.5",
"resolved": "https://registry.npmjs.org/md5.js/-/md5.js-1.3.5.tgz",

View File

@ -81,6 +81,7 @@
"filepond": "^4.7.3",
"json-loader": "^0.5.7",
"jwt-decode": "^2.2.0",
"marked": "^0.8.2",
"moment": "^2.24.0",
"node-sass": "^4.13.0",
"normalize.css": "^8.0.0",

View File

@ -8,15 +8,22 @@ class MitigationsComponent extends React.Component {
constructor(props) {
super(props);
if (typeof this.props.mitigations !== 'undefined'){
let descriptions = MitigationsComponent.parseDescription(this.props.mitigations.description);
this.state = {name: this.props.mitigations.name, descriptions: descriptions};
if (typeof this.props.mitigations !== 'undefined' && this.props.mitigations.length > 0){
this.state = {mitigations: this.props.mitigations};
} else {
this.state = {name: '', descriptions: []}
this.state = {mitigations: null}
}
}
static parseDescription(description){
static createRows(descriptions, references) {
let rows = [];
for(let i = 0; i < descriptions.length; i++){
rows[i] = {'description': descriptions[i], 'reference': references[i]};
}
return rows;
}
static parseDescription(description) {
const citationRegex = /\(Citation:.*\)/gi;
const emptyLineRegex = /^\s*[\r\n]/gm;
description = description.replace(citationRegex, '');
@ -26,28 +33,40 @@ class MitigationsComponent extends React.Component {
return descriptions;
}
static getMitigationDescriptions(name) {
static getMitigations() {
return ([{
Header: name,
Header: 'Mitigations',
style: {'text-align': 'left'},
columns: [
{ id: 'name',
accessor: x => this.getMitigationName(x.name, x.url),
width: 200},
{ id: 'description',
accessor: x => (<div dangerouslySetInnerHTML={{__html: x}} />),
accessor: x => (<div dangerouslySetInnerHTML={{__html: x.description}} />),
style: {'whiteSpace': 'unset'}}
]
}])
}
static getMitigationName(name, url) {
if(url){
return (<a href={url} target={'_blank'}>{name}</a>)
} else {
return (<p>{name}</p>)
}
}
render() {
return (
<div>
<br/>
{this.state.descriptions.length !== 0 ?
{this.state.mitigations ?
<ReactTable
columns={MitigationsComponent.getMitigationDescriptions(this.state.name)}
data={this.state.descriptions}
columns={MitigationsComponent.getMitigations()}
data={this.state.mitigations}
showPagination={false}
defaultPageSize={this.state.descriptions.length}
defaultPageSize={this.state.mitigations.length}
className={'attack-mitigation'}
/> : ''}
</div>