Add cache to boost performance and a few more fixes

This commit is contained in:
Oran Nadler 2018-04-03 10:47:10 +03:00
parent 0383830719
commit 990e68fc4d
1 changed files with 185 additions and 43 deletions

View File

@ -43,6 +43,57 @@ def myntlm(x):
hash = hashlib.new('md4', x.encode('utf-16le')).digest() hash = hashlib.new('md4', x.encode('utf-16le')).digest()
return str(binascii.hexlify(hash)) return str(binascii.hexlify(hash))
def cache(foo):
def hash(o):
if type(o) in (int, float, str, unicode):
return o
elif type(o) in (list, tuple):
hashed = tuple([hash(x) for x in o])
if "NotHashable" in hashed:
return "NotHashable"
return hashed
elif type(o) == dict:
hashed_keys = tuple([hash(k) for k, v in o.iteritems()])
hashed_vals = tuple([hash(v) for k, v in o.iteritems()])
if "NotHashable" in hashed_keys or "NotHashable" in hashed_vals:
return "NotHashable"
return tuple(zip(hashed_keys, hashed_vals))
elif type(o) == Machine:
return o.monkey_guid
elif type(o) == PthMap:
return "PthMapSingleton"
elif type(o) == PassTheHashMap:
return "PassTheHashMapSingleton"
else:
return "NotHashable"
def wrapper(*args, **kwargs):
hashed = (hash(args), hash(kwargs))
if "NotHashable" in hashed:
print foo
return foo(*args, **kwargs)
if not hasattr(foo, "_mycache_"):
foo._mycache_ = dict()
if hashed not in foo._mycache_.keys():
foo._mycache_[hashed] = foo(*args, **kwargs)
return foo._mycache_[hashed]
return wrapper
class Machine(object): class Machine(object):
def __init__(self, monkey_guid): def __init__(self, monkey_guid):
self.monkey_guid = str(monkey_guid) self.monkey_guid = str(monkey_guid)
@ -52,6 +103,7 @@ class Machine(object):
if self.latest_system_info.count() > 0: if self.latest_system_info.count() > 0:
self.latest_system_info = self.latest_system_info[0] self.latest_system_info = self.latest_system_info[0]
@cache
def GetMimikatzOutput(self): def GetMimikatzOutput(self):
doc = self.latest_system_info doc = self.latest_system_info
@ -60,6 +112,7 @@ class Machine(object):
return doc["data"]["mimikatz"] return doc["data"]["mimikatz"]
@cache
def GetHostName(self): def GetHostName(self):
doc = self.latest_system_info doc = self.latest_system_info
@ -68,6 +121,7 @@ class Machine(object):
return None return None
@cache
def GetIp(self): def GetIp(self):
doc = self.latest_system_info doc = self.latest_system_info
@ -76,6 +130,7 @@ class Machine(object):
return None return None
@cache
def GetDomainName(self): def GetDomainName(self):
doc = self.latest_system_info doc = self.latest_system_info
@ -84,6 +139,7 @@ class Machine(object):
return None return None
@cache
def GetDomainRole(self): def GetDomainRole(self):
doc = self.latest_system_info doc = self.latest_system_info
@ -92,9 +148,11 @@ class Machine(object):
return None return None
@cache
def IsDomainController(self): def IsDomainController(self):
return self.GetDomainRole() in (DsRole_RolePrimaryDomainController, DsRole_RoleBackupDomainController) return self.GetDomainRole() in (DsRole_RolePrimaryDomainController, DsRole_RoleBackupDomainController)
@cache
def GetSidByUsername(self, username): def GetSidByUsername(self, username):
doc = self.latest_system_info doc = self.latest_system_info
@ -113,6 +171,7 @@ class Machine(object):
return None return None
@cache
def GetUsernameBySid(self, sid): def GetUsernameBySid(self, sid):
doc = self.latest_system_info doc = self.latest_system_info
@ -131,6 +190,7 @@ class Machine(object):
return None return None
@cache
def GetUsernamesBySecret(self, secret): def GetUsernamesBySecret(self, secret):
sam = self.GetLocalSecrets() sam = self.GetLocalSecrets()
@ -142,10 +202,12 @@ class Machine(object):
return names return names
@cache
def GetSidsBySecret(self, secret): def GetSidsBySecret(self, secret):
usernames = self.GetUsernamesBySecret(secret) usernames = self.GetUsernamesBySecret(secret)
return set(map(self.GetSidByUsername, usernames)) return set(map(self.GetSidByUsername, usernames))
@cache
def GetGroupSidByGroupName(self, group_name): def GetGroupSidByGroupName(self, group_name):
doc = self.latest_system_info doc = self.latest_system_info
@ -157,6 +219,7 @@ class Machine(object):
return None return None
@cache
def GetUsersByGroupSid(self, sid): def GetUsersByGroupSid(self, sid):
doc = self.latest_system_info doc = self.latest_system_info
@ -173,6 +236,7 @@ class Machine(object):
return users return users
@cache
def GetDomainControllersMonkeyGuidByDomainName(self, domain_name): def GetDomainControllersMonkeyGuidByDomainName(self, domain_name):
cur = mongo.db.telemetry.find({"telem_type":"system_info_collection", "data.Win32_ComputerSystem.Domain":"u'%s'" % (domain_name,)}) cur = mongo.db.telemetry.find({"telem_type":"system_info_collection", "data.Win32_ComputerSystem.Domain":"u'%s'" % (domain_name,)})
@ -186,9 +250,11 @@ class Machine(object):
return GUIDs return GUIDs
@cache
def GetLocalAdmins(self): def GetLocalAdmins(self):
return set(self.GetUsersByGroupSid(self.GetGroupSidByGroupName("Administrators")).keys()) return set(self.GetUsersByGroupSid(self.GetGroupSidByGroupName("Administrators")).keys())
@cache
def GetLocalSids(self): def GetLocalSids(self):
doc = self.latest_system_info doc = self.latest_system_info
@ -199,9 +265,11 @@ class Machine(object):
return SIDs return SIDs
@cache
def GetLocalAdminNames(self): def GetLocalAdminNames(self):
return set(self.GetUsersByGroupSid(self.GetGroupSidByGroupName("Administrators")).values()) return set(self.GetUsersByGroupSid(self.GetGroupSidByGroupName("Administrators")).values())
@cache
def GetSam(self): def GetSam(self):
if not self.GetMimikatzOutput(): if not self.GetMimikatzOutput():
return {} return {}
@ -211,16 +279,26 @@ class Machine(object):
if mimikatz.count("\n42.") != 2: if mimikatz.count("\n42.") != 2:
return {} return {}
try:
sam_users = mimikatz.split("\n42.")[1].split("\nSAMKey :")[1].split("\n\n")[1:] sam_users = mimikatz.split("\n42.")[1].split("\nSAMKey :")[1].split("\n\n")[1:]
sam = {} sam = {}
for sam_user_txt in sam_users: for sam_user_txt in sam_users:
sam_user = dict([map(unicode.strip, line.split(":")) for line in filter(lambda l: l.count(":") == 1, sam_user_txt.splitlines())]) sam_user = dict([map(unicode.strip, line.split(":")) for line in filter(lambda l: l.count(":") == 1, sam_user_txt.splitlines())])
sam[sam_user["User"]] = sam_user["NTLM"].replace("[hashed secret]", "").strip()
ntlm = sam_user["NTLM"]
if "[hashed secret]" not in ntlm:
continue
sam[sam_user["User"]] = ntlm.replace("[hashed secret]", "").strip()
return sam return sam
except:
return {}
@cache
def GetNtds(self): def GetNtds(self):
if not self.GetMimikatzOutput(): if not self.GetMimikatzOutput():
return {} return {}
@ -243,6 +321,7 @@ class Machine(object):
return ntds return ntds
@cache
def GetLocalSecrets(self): def GetLocalSecrets(self):
sam = self.GetSam() sam = self.GetSam()
ntds = self.GetNtds() ntds = self.GetNtds()
@ -252,9 +331,11 @@ class Machine(object):
return secrets return secrets
@cache
def GetLocalAdminSecrets(self): def GetLocalAdminSecrets(self):
return set(self.GetLocalAdminCreds().values()) return set(self.GetLocalAdminCreds().values())
@cache
def GetLocalAdminCreds(self): def GetLocalAdminCreds(self):
admin_names = self.GetLocalAdminNames() admin_names = self.GetLocalAdminNames()
sam = self.GetLocalSecrets() sam = self.GetLocalSecrets()
@ -269,35 +350,44 @@ class Machine(object):
return admin_creds return admin_creds
@cache
def GetCachedSecrets(self): def GetCachedSecrets(self):
return set(self.GetCachedCreds().values()) return set(self.GetCachedCreds().values())
@cache
def GetCachedCreds(self): def GetCachedCreds(self):
doc = self.latest_system_info doc = self.latest_system_info
creds = dict() creds = dict()
for username in doc["data"]["credentials"]: if not self.GetMimikatzOutput():
user = doc["data"]["credentials"][username] return {}
if "password" in user.keys(): mimikatz = self.GetMimikatzOutput()
ntlm = myntlm(str(user["password"]))
elif "ntlm_hash" in user.keys():
ntlm = str(user["ntlm_hash"])
else:
continue
secret = hashlib.md5(ntlm.decode("hex")).hexdigest() for user in mimikatz.split("\n42.")[0].split("Authentication Id")[1:]:
username = None
secret = None
for line in user.splitlines():
if "User Name" in line:
username = line.split(":")[1].strip()
if ("NTLM" in line or "Password" in line) and "[hashed secret]" in line:
secret = line.split(":")[1].replace("[hashed secret]", "").strip()
if username and secret:
creds[username] = secret creds[username] = secret
return creds return creds
@cache
def GetDomainControllers(self): def GetDomainControllers(self):
domain_name = self.GetDomainName() domain_name = self.GetDomainName()
DCs = self.GetDomainControllersMonkeyGuidByDomainName(domain_name) DCs = self.GetDomainControllersMonkeyGuidByDomainName(domain_name)
return map(Machine, DCs) return map(Machine, DCs)
@cache
def GetDomainAdminsOfMachine(self): def GetDomainAdminsOfMachine(self):
DCs = self.GetDomainControllers() DCs = self.GetDomainControllers()
@ -308,12 +398,15 @@ class Machine(object):
return domain_admins return domain_admins
@cache
def GetAdmins(self): def GetAdmins(self):
return self.GetLocalAdmins() | self.GetDomainAdminsOfMachine() return self.GetLocalAdmins() | self.GetDomainAdminsOfMachine()
@cache
def GetAdminNames(self): def GetAdminNames(self):
return set(map(lambda x: self.GetUsernameBySid(x), self.GetAdmins())) return set(map(lambda x: self.GetUsernameBySid(x), self.GetAdmins()))
@cache
def GetCachedSids(self): def GetCachedSids(self):
doc = self.latest_system_info doc = self.latest_system_info
@ -329,6 +422,7 @@ class Machine(object):
return SIDs return SIDs
@cache
def GetCachedUsernames(self): def GetCachedUsernames(self):
doc = self.latest_system_info doc = self.latest_system_info
@ -342,12 +436,14 @@ class Machine(object):
class PassTheHashMap(object): class PassTheHashMap(object):
def __init__(self): def __init__(self):
self.vertices = self.GetAllMachines() self.vertices = self.GetAllMachines()
self.edges = set() self.edges = set()
self.machines = map(Machine, self.vertices) self.machines = map(Machine, self.vertices)
self.GenerateEdgesBySid() # Useful for non-cached domain users self.GenerateEdgesBySid() # Useful for non-cached domain users
self.GenerateEdgesBySamHash() # This will add edges based only on password hash without caring about username self.GenerateEdgesBySamHash() # This will add edges based only on password hash without caring about username
@cache
def GetAllMachines(self): def GetAllMachines(self):
cur = mongo.db.telemetry.find({"telem_type":"system_info_collection"}) cur = mongo.db.telemetry.find({"telem_type":"system_info_collection"})
@ -358,6 +454,7 @@ class PassTheHashMap(object):
return GUIDs return GUIDs
@cache
def ReprSidList(self, sid_list, attacker, victim): def ReprSidList(self, sid_list, attacker, victim):
label = set() label = set()
@ -372,6 +469,7 @@ class PassTheHashMap(object):
return ",\n".join(label) return ",\n".join(label)
@cache
def ReprSecretList(self, secret_list, victim): def ReprSecretList(self, secret_list, victim):
label = set() label = set()
@ -380,6 +478,7 @@ class PassTheHashMap(object):
return ",\n".join(label) return ",\n".join(label)
@cache
def GenerateEdgesBySid(self): def GenerateEdgesBySid(self):
for attacker in self.vertices: for attacker in self.vertices:
cached = Machine(attacker).GetCachedSids() cached = Machine(attacker).GetCachedSids()
@ -394,6 +493,7 @@ class PassTheHashMap(object):
label = self.ReprSidList(cached & admins, attacker, victim) label = self.ReprSidList(cached & admins, attacker, victim)
self.edges.add((attacker, victim, label)) self.edges.add((attacker, victim, label))
@cache
def GenerateEdgesBySamHash(self): def GenerateEdgesBySamHash(self):
for attacker in self.vertices: for attacker in self.vertices:
cached_creds = set(Machine(attacker).GetCachedCreds().items()) cached_creds = set(Machine(attacker).GetCachedCreds().items())
@ -408,6 +508,7 @@ class PassTheHashMap(object):
label = self.ReprSecretList(set(dict(cached_creds & admin_creds).values()), victim) label = self.ReprSecretList(set(dict(cached_creds & admin_creds).values()), victim)
self.edges.add((attacker, victim, label)) self.edges.add((attacker, victim, label))
@cache
def GenerateEdgesByUsername(self): def GenerateEdgesByUsername(self):
for attacker in self.vertices: for attacker in self.vertices:
cached = Machine(attacker).GetCachedUsernames() cached = Machine(attacker).GetCachedUsernames()
@ -421,17 +522,19 @@ class PassTheHashMap(object):
if len(cached & admins) > 0: if len(cached & admins) > 0:
self.edges.add((attacker, victim)) self.edges.add((attacker, victim))
@cache
def Print(self): def Print(self):
print map(lambda x: Machine(x).GetIp(), self.vertices) print map(lambda x: Machine(x).GetIp(), self.vertices)
print map(lambda x: (Machine(x[0]).GetIp(), Machine(x[1]).GetIp()), self.edges) print map(lambda x: (Machine(x[0]).GetIp(), Machine(x[1]).GetIp()), self.edges)
@cache
def GetPossibleAttackCountBySid(self, sid): def GetPossibleAttackCountBySid(self, sid):
return len(self.GetPossibleAttacksBySid(sid)) return len(self.GetPossibleAttacksBySid(sid))
def GetPossibleAttacksBySid(self, sid): @cache
def GetPossibleAttacksByAttacker(self, attacker):
attacks = set() attacks = set()
for attacker in self.vertices:
cached_creds = set(Machine(attacker).GetCachedCreds().items()) cached_creds = set(Machine(attacker).GetCachedCreds().items())
for victim in self.vertices: for victim in self.vertices:
@ -442,13 +545,25 @@ class PassTheHashMap(object):
if len(cached_creds & admin_creds) > 0: if len(cached_creds & admin_creds) > 0:
curr_attacks = dict(cached_creds & admin_creds) curr_attacks = dict(cached_creds & admin_creds)
attacks.add((attacker, victim, curr_attacks))
return attacks
@cache
def GetPossibleAttacksBySid(self, sid):
attacks = set()
for attacker in self.vertices:
tmp = self.GetPossibleAttacksByAttacker(attacker)
for _, victim, curr_attacks in tmp:
for username, secret in curr_attacks.iteritems(): for username, secret in curr_attacks.iteritems():
if Machine(victim).GetSidByUsername(username) == sid: if Machine(victim).GetSidByUsername(username) == sid:
attacks.add((attacker, victim)) attacks.add((attacker, victim))
return attacks return attacks
@cache
def GetSecretBySid(self, sid): def GetSecretBySid(self, sid):
for m in self.machines: for m in self.machines:
for user, user_secret in m.GetLocalSecrets().iteritems(): for user, user_secret in m.GetLocalSecrets().iteritems():
@ -457,15 +572,19 @@ class PassTheHashMap(object):
return None return None
@cache
def GetVictimCountBySid(self, sid): def GetVictimCountBySid(self, sid):
return len(self.GetVictimsBySid(sid)) return len(self.GetVictimsBySid(sid))
@cache
def GetVictimCountByMachine(self, attacker): def GetVictimCountByMachine(self, attacker):
return len(self.GetVictimsByAttacker(attacker)) return len(self.GetVictimsByAttacker(attacker))
@cache
def GetAttackCountBySecret(self, secret): def GetAttackCountBySecret(self, secret):
return len(self.GetAttackersBySecret(secret)) return len(self.GetAttackersBySecret(secret))
@cache
def GetAllUsernames(self): def GetAllUsernames(self):
names = set() names = set()
@ -474,6 +593,7 @@ class PassTheHashMap(object):
return names return names
@cache
def GetAllSids(self): def GetAllSids(self):
SIDs = set() SIDs = set()
@ -482,6 +602,7 @@ class PassTheHashMap(object):
return SIDs return SIDs
@cache
def GetAllSecrets(self): def GetAllSecrets(self):
secrets = set() secrets = set()
@ -491,6 +612,7 @@ class PassTheHashMap(object):
return secrets return secrets
@cache
def GetUsernameBySid(self, sid): def GetUsernameBySid(self, sid):
for m in self.machines: for m in self.machines:
username = m.GetUsernameBySid(sid) username = m.GetUsernameBySid(sid)
@ -500,6 +622,7 @@ class PassTheHashMap(object):
return None return None
@cache
def GetSidsBySecret(self, secret): def GetSidsBySecret(self, secret):
SIDs = set() SIDs = set()
@ -508,6 +631,7 @@ class PassTheHashMap(object):
return SIDs return SIDs
@cache
def GetAllDomainControllers(self): def GetAllDomainControllers(self):
DCs = set() DCs = set()
@ -517,6 +641,7 @@ class PassTheHashMap(object):
return DCs return DCs
@cache
def GetSidsByUsername(self, username): def GetSidsByUsername(self, username):
SIDs = set() SIDs = set()
@ -527,6 +652,7 @@ class PassTheHashMap(object):
return SIDs return SIDs
@cache
def GetVictimsBySid(self, sid): def GetVictimsBySid(self, sid):
machines = set() machines = set()
@ -536,6 +662,7 @@ class PassTheHashMap(object):
return machines return machines
@cache
def GetVictimsBySecret(self, secret): def GetVictimsBySecret(self, secret):
machines = set() machines = set()
@ -547,6 +674,7 @@ class PassTheHashMap(object):
return machines return machines
@cache
def GetAttackersBySecret(self, secret): def GetAttackersBySecret(self, secret):
machines = set() machines = set()
@ -556,6 +684,7 @@ class PassTheHashMap(object):
return machines return machines
@cache
def GetAttackersByVictim(self, victim): def GetAttackersByVictim(self, victim):
attackers = set() attackers = set()
@ -565,6 +694,7 @@ class PassTheHashMap(object):
return attackers return attackers
@cache
def GetVictimsByAttacker(self, attacker): def GetVictimsByAttacker(self, attacker):
victims = set() victims = set()
@ -574,6 +704,7 @@ class PassTheHashMap(object):
return victims return victims
@cache
def GetInPathCountByVictim(self, victim, already_processed=None): def GetInPathCountByVictim(self, victim, already_processed=None):
if type(victim) != unicode: if type(victim) != unicode:
victim = victim.monkey_guid victim = victim.monkey_guid
@ -607,11 +738,14 @@ def main():
print "<h2>Duplicated Passwords</h2>" print "<h2>Duplicated Passwords</h2>"
print "<h3>How many users share each secret?</h3>" print "<h3>How many users share each secret?</h3>"
dups = dict(map(lambda x: (x, len(pth.GetSidsBySecret(x))), pth.GetAllSecrets())) dups = dict(map(lambda x: (x, len(pth.GetSidsBySecret(x))), pth.GetAllSecrets()))
print """<table>""" print """<table>"""
print """<tr><th>Secret</th><th>User Count</th></tr>""" print """<tr><th>Secret</th><th>User Count</th></tr>"""
for secret, count in sorted(dups.iteritems(), key=lambda (k,v): (v,k), reverse=True): for secret, count in sorted(dups.iteritems(), key=lambda (k,v): (v,k), reverse=True):
if count <= 1:
continue
print """<tr><td><a href="#{secret}">{secret}</a></td><td>{count}</td>""".format(secret=secret, count=count) print """<tr><td><a href="#{secret}">{secret}</a></td><td>{count}</td>""".format(secret=secret, count=count)
print """</table>""" print """</table>"""
@ -622,6 +756,8 @@ def main():
print """<table>""" print """<table>"""
print """<tr><th>Secret</th><th>Machine Count</th></tr>""" print """<tr><th>Secret</th><th>Machine Count</th></tr>"""
for secret, count in sorted(cache_counts.iteritems(), key=lambda (k,v): (v,k), reverse=True): for secret, count in sorted(cache_counts.iteritems(), key=lambda (k,v): (v,k), reverse=True):
if count <= 0:
continue
print """<tr><td><a href="#{secret}">{secret}</a></td><td>{count}</td>""".format(secret=secret, count=count) print """<tr><td><a href="#{secret}">{secret}</a></td><td>{count}</td>""".format(secret=secret, count=count)
print """</table>""" print """</table>"""
@ -632,6 +768,8 @@ def main():
print """<table>""" print """<table>"""
print """<tr><th>SID</th><th>Username</th><th>Machine Count</th></tr>""" print """<tr><th>SID</th><th>Username</th><th>Machine Count</th></tr>"""
for sid, count in sorted(attackable_counts.iteritems(), key=lambda (k,v): (v,k), reverse=True): for sid, count in sorted(attackable_counts.iteritems(), key=lambda (k,v): (v,k), reverse=True):
if count <= 1:
continue
print """<tr><td><a href="#{sid}">{sid}</a></td><td>{username}</td><td>{count}</td>""".format(sid=sid, username=pth.GetUsernameBySid(sid), count=count) print """<tr><td><a href="#{sid}">{sid}</a></td><td>{username}</td><td>{count}</td>""".format(sid=sid, username=pth.GetUsernameBySid(sid), count=count)
print """</table>""" print """</table>"""
@ -642,6 +780,8 @@ def main():
print """<table>""" print """<table>"""
print """<tr><th>SID</th><th>Username</th><th>Machine Count</th></tr>""" print """<tr><th>SID</th><th>Username</th><th>Machine Count</th></tr>"""
for sid, count in sorted(possible_attacks_by_sid.iteritems(), key=lambda (k,v): (v,k), reverse=True): for sid, count in sorted(possible_attacks_by_sid.iteritems(), key=lambda (k,v): (v,k), reverse=True):
if count <= 1:
continue
print """<tr><td><a href="#{sid}">{sid}</a></td><td>{username}</td><td>{count}</td>""".format(sid=sid, username=pth.GetUsernameBySid(sid), count=count) print """<tr><td><a href="#{sid}">{sid}</a></td><td>{username}</td><td>{count}</td>""".format(sid=sid, username=pth.GetUsernameBySid(sid), count=count)
print """</table>""" print """</table>"""
@ -652,6 +792,8 @@ def main():
print """<table>""" print """<table>"""
print """<tr><th>Attacker Ip</th><th>Attacker Hostname</th><th>Domain Name</th><th>Victim Machine Count</th></tr>""" print """<tr><th>Attacker Ip</th><th>Attacker Hostname</th><th>Domain Name</th><th>Victim Machine Count</th></tr>"""
for m, count in sorted(attackable_counts.iteritems(), key=lambda (k,v): (v,k), reverse=True): for m, count in sorted(attackable_counts.iteritems(), key=lambda (k,v): (v,k), reverse=True):
if count <= 1:
continue
print """<tr><td><a href="#{ip}">{ip}</a></td><td>{hostname}</td><td>{domain}</td><td>{count}</td>""".format(ip=m.GetIp(), hostname=m.GetHostName(), domain=m.GetDomainName(), count=count) print """<tr><td><a href="#{ip}">{ip}</a></td><td>{hostname}</td><td>{domain}</td><td>{count}</td>""".format(ip=m.GetIp(), hostname=m.GetHostName(), domain=m.GetDomainName(), count=count)
print """</table>""" print """</table>"""