switch out recursive dfs for stack based approach, to avoid possibly hitting the recursion limit

This commit is contained in:
Ben Reilly 2014-09-05 15:26:05 -07:00
parent 45768e6b72
commit b878c73fc3
2 changed files with 34 additions and 24 deletions

View File

@ -94,31 +94,26 @@ class MigrationGraph(object):
""" """
Dynamic programming based depth first search, for finding dependencies. Dynamic programming based depth first search, for finding dependencies.
""" """
cache = {} visited = []
visited.append(start)
path = [start]
stack = sorted(get_children(start))
while stack:
node = stack.pop(0)
def _dfs(start, get_children, path): if node in path:
# If we already computed this, use that (dynamic programming) raise CircularDependencyError()
if (start, get_children) in cache: path.append(node)
return cache[(start, get_children)]
# If we've traversed here before, that's a circular dep visited.insert(0, node)
if start in path: children = sorted(get_children(node))
raise CircularDependencyError(path[path.index(start):] + [start])
# Build our own results list, starting with us if not children:
results = [] path = []
results.append(start)
# We need to add to results all the migrations this one depends on stack = children + stack
children = sorted(get_children(start))
path.append(start) return list(OrderedSet(visited))
for n in children:
results = _dfs(n, get_children, path) + results
path.pop()
# Use OrderedSet to ensure only one instance of each result
results = list(OrderedSet(results))
# Populate DP cache
cache[(start, get_children)] = results
# Done!
return results
return _dfs(start, get_children, [])
def __str__(self): def __str__(self):
return "Graph: %s nodes, %s edges" % (len(self.nodes), sum(len(x) for x in self.dependencies.values())) return "Graph: %s nodes, %s edges" % (len(self.nodes), sum(len(x) for x in self.dependencies.values()))

View File

@ -134,6 +134,21 @@ class GraphTests(TestCase):
graph.forwards_plan, ("app_a", "0003"), graph.forwards_plan, ("app_a", "0003"),
) )
def test_dfs(self):
graph = MigrationGraph()
root = ("app_a", "1")
graph.add_node(root, None)
expected = [root]
for i in xrange(2, 1000):
parent = ("app_a", str(i - 1))
child = ("app_a", str(i))
graph.add_node(child, None)
graph.add_dependency(str(i), child, parent)
expected.append(child)
actual = graph.dfs(root, lambda x: graph.dependents.get(x, set()))
self.assertEqual(expected[::-1], actual)
def test_plan_invalid_node(self): def test_plan_invalid_node(self):
""" """
Tests for forwards/backwards_plan of nonexistent node. Tests for forwards/backwards_plan of nonexistent node.