Merge pull request #1255 from renatahodovan/pythonic-python

Make Python targets more pythonic.
This commit is contained in:
Terence Parr 2016-09-25 11:24:07 -07:00 committed by GitHub
commit 5ea09b30f2
22 changed files with 88 additions and 249 deletions

View File

@ -97,16 +97,10 @@ class IntervalSet(object):
if self.intervals is None: if self.intervals is None:
return False return False
else: else:
for i in self.intervals: return any(item in i for i in self.intervals)
if item in i:
return True
return False
def __len__(self): def __len__(self):
xlen = 0 return sum(len(i) for i in self.intervals)
for i in self.intervals:
xlen += len(i)
return xlen
def removeRange(self, v): def removeRange(self, v):
if v.start==v.stop-1: if v.start==v.stop-1:
@ -126,7 +120,7 @@ class IntervalSet(object):
# check for included range, remove it # check for included range, remove it
elif v.start<=i.start and v.stop>=i.stop: elif v.start<=i.start and v.stop>=i.stop:
self.intervals.pop(k) self.intervals.pop(k)
k = k - 1 # need another pass k -= 1 # need another pass
# check for lower boundary # check for lower boundary
elif v.start<i.stop: elif v.start<i.stop:
self.intervals[k] = Interval(i.start, v.start) self.intervals[k] = Interval(i.start, v.start)

View File

@ -101,9 +101,7 @@ class ListTokenSource(TokenSource):
line = lastToken.line line = lastToken.line
tokenText = lastToken.text tokenText = lastToken.text
if tokenText is not None: if tokenText is not None:
for c in tokenText: line += tokenText.count('\n')
if c == '\n':
line += 1
# if no text is available, assume the token did not contain any newline characters. # if no text is available, assume the token did not contain any newline characters.
return line return line

View File

@ -585,15 +585,14 @@ def getCachedPredictionContext(context, contextCache, visited):
parent = getCachedPredictionContext(context.getParent(i), contextCache, visited) parent = getCachedPredictionContext(context.getParent(i), contextCache, visited)
if changed or parent is not context.getParent(i): if changed or parent is not context.getParent(i):
if not changed: if not changed:
parents = [None] * len(context) parents = [context.getParent(j) for j in range(len(context))]
for j in range(0, len(context)):
parents[j] = context.getParent(j)
changed = True changed = True
parents[i] = parent parents[i] = parent
if not changed: if not changed:
contextCache.add(context) contextCache.add(context)
visited[context] = context visited[context] = context
return context return context
updated = None updated = None
if len(parents) == 0: if len(parents) == 0:
updated = PredictionContext.EMPTY updated = PredictionContext.EMPTY

View File

@ -33,6 +33,7 @@
# info about the set, with support for combining similar configurations using a # info about the set, with support for combining similar configurations using a
# graph-structured stack. # graph-structured stack.
#/ #/
from functools import reduce
from io import StringIO from io import StringIO
from antlr4.PredictionContext import merge from antlr4.PredictionContext import merge
from antlr4.Utils import str_list from antlr4.Utils import str_list
@ -118,9 +119,9 @@ class ATNConfigSet(object):
h = config.hashCodeForConfigSet() h = config.hashCodeForConfigSet()
l = self.configLookup.get(h, None) l = self.configLookup.get(h, None)
if l is not None: if l is not None:
for c in l: r = next((c for c in l if config.equalsForConfigSet(c)), None)
if config.equalsForConfigSet(c): if r is not None:
return c return r
if l is None: if l is None:
l = [config] l = [config]
self.configLookup[h] = l self.configLookup[h] = l
@ -129,17 +130,10 @@ class ATNConfigSet(object):
return config return config
def getStates(self): def getStates(self):
states = set() return set(cfg.state for cfg in self.configs)
for c in self.configs:
states.add(c.state)
return states
def getPredicates(self): def getPredicates(self):
preds = list() return [cfg.semanticContext for cfg in self.configs if cfg.semanticContext!=SemanticContext.NONE]
for c in self.configs:
if c.semanticContext!=SemanticContext.NONE:
preds.append(c.semanticContext)
return preds
def get(self, i): def get(self, i):
return self.configs[i] return self.configs[i]
@ -181,10 +175,7 @@ class ATNConfigSet(object):
return self.hashConfigs() return self.hashConfigs()
def hashConfigs(self): def hashConfigs(self):
h = 0 return reduce(lambda h, cfg: hash((h, cfg)), self.configs, 0)
for cfg in self.configs:
h = hash((h, cfg))
return h
def __len__(self): def __len__(self):
return len(self.configs) return len(self.configs)

View File

@ -141,10 +141,7 @@ class ATNState(object):
return self.stateNumber return self.stateNumber
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, ATNState): return isinstance(other, ATNState) and self.stateNumber==other.stateNumber
return self.stateNumber==other.stateNumber
else:
return False
def onlyHasEpsilonTransitions(self): def onlyHasEpsilonTransitions(self):
return self.epsilonOnlyTransitions return self.epsilonOnlyTransitions

View File

@ -535,11 +535,7 @@ class LexerATNSimulator(ATNSimulator):
def addDFAState(self, configs): def addDFAState(self, configs):
proposed = DFAState(configs=configs) proposed = DFAState(configs=configs)
firstConfigWithRuleStopState = None firstConfigWithRuleStopState = next((cfg for cfg in configs if isinstance(cfg.state, RuleStopState)), None)
for c in configs:
if isinstance(c.state, RuleStopState):
firstConfigWithRuleStopState = c
break
if firstConfigWithRuleStopState is not None: if firstConfigWithRuleStopState is not None:
proposed.isAcceptState = True proposed.isAcceptState = True

View File

@ -232,10 +232,7 @@ class PredictionMode(object):
# {@link RuleStopState}, otherwise {@code false} # {@link RuleStopState}, otherwise {@code false}
@classmethod @classmethod
def hasConfigInRuleStopState(cls, configs): def hasConfigInRuleStopState(cls, configs):
for c in configs: return any(isinstance(cfg.state, RuleStopState) for cfg in configs)
if isinstance(c.state, RuleStopState):
return True
return False
# Checks if all configurations in {@code configs} are in a # Checks if all configurations in {@code configs} are in a
# {@link RuleStopState}. Configurations meeting this condition have reached # {@link RuleStopState}. Configurations meeting this condition have reached
@ -247,10 +244,7 @@ class PredictionMode(object):
# {@link RuleStopState}, otherwise {@code false} # {@link RuleStopState}, otherwise {@code false}
@classmethod @classmethod
def allConfigsInRuleStopStates(cls, configs): def allConfigsInRuleStopStates(cls, configs):
for config in configs: return all(isinstance(cfg.state, RuleStopState) for cfg in configs)
if not isinstance(config.state, RuleStopState):
return False
return True
# #
# Full LL prediction termination. # Full LL prediction termination.
@ -419,10 +413,7 @@ class PredictionMode(object):
# #
@classmethod @classmethod
def hasNonConflictingAltSet(cls, altsets): def hasNonConflictingAltSet(cls, altsets):
for alts in altsets: return any(len(alts) == 1 for alts in altsets)
if len(alts)==1:
return True
return False
# #
# Determines if any single alternative subset in {@code altsets} contains # Determines if any single alternative subset in {@code altsets} contains
@ -434,10 +425,7 @@ class PredictionMode(object):
# #
@classmethod @classmethod
def hasConflictingAltSet(cls, altsets): def hasConflictingAltSet(cls, altsets):
for alts in altsets: return any(len(alts) > 1 for alts in altsets)
if len(alts)>1:
return True
return False
# #
# Determines if every alternative subset in {@code altsets} is equivalent. # Determines if every alternative subset in {@code altsets} is equivalent.
@ -448,13 +436,9 @@ class PredictionMode(object):
# #
@classmethod @classmethod
def allSubsetsEqual(cls, altsets): def allSubsetsEqual(cls, altsets):
first = None if not altsets:
for alts in altsets: return True
if first is None: return all(alts == altsets[0] for alts in altsets[1:])
first = alts
elif not alts==first:
return False
return True
# #
# Returns the unique alternative predicted by all alternative subsets in # Returns the unique alternative predicted by all alternative subsets in
@ -467,10 +451,8 @@ class PredictionMode(object):
def getUniqueAlt(cls, altsets): def getUniqueAlt(cls, altsets):
all = cls.getAlts(altsets) all = cls.getAlts(altsets)
if len(all)==1: if len(all)==1:
for one in all: return all.pop()
return one return ATN.INVALID_ALT_NUMBER
else:
return ATN.INVALID_ALT_NUMBER
# Gets the complete set of represented alternatives for a collection of # Gets the complete set of represented alternatives for a collection of
# alternative subsets. This method returns the union of each {@link BitSet} # alternative subsets. This method returns the union of each {@link BitSet}
@ -481,10 +463,7 @@ class PredictionMode(object):
# #
@classmethod @classmethod
def getAlts(cls, altsets): def getAlts(cls, altsets):
all = set() return set.union(*altsets)
for alts in altsets:
all = all | alts
return all
# #
# This function gets the conflicting alt subsets from a configuration set. # This function gets the conflicting alt subsets from a configuration set.
@ -528,11 +507,7 @@ class PredictionMode(object):
@classmethod @classmethod
def hasStateAssociatedWithOneAlt(cls, configs): def hasStateAssociatedWithOneAlt(cls, configs):
x = cls.getStateToAltMap(configs) return any(len(alts) == 1 for alts in cls.getStateToAltMap(configs).values())
for alts in x.values():
if len(alts)==1:
return True
return False
@classmethod @classmethod
def getSingleViableAlt(cls, altsets): def getSingleViableAlt(cls, altsets):

View File

@ -115,14 +115,7 @@ def orContext(a, b):
return result return result
def filterPrecedencePredicates(collection): def filterPrecedencePredicates(collection):
result = [] return [context for context in collection if isinstance(context, PrecedencePredicate)]
for context in collection:
if isinstance(context, PrecedencePredicate):
if result is None:
result = []
result.append(context)
return result
class Predicate(SemanticContext): class Predicate(SemanticContext):
@ -187,13 +180,11 @@ class AND(SemanticContext):
def __init__(self, a, b): def __init__(self, a, b):
operands = set() operands = set()
if isinstance( a, AND): if isinstance( a, AND):
for o in a.opnds: operands.update(a.opnds)
operands.add(o)
else: else:
operands.add(a) operands.add(a)
if isinstance( b, AND): if isinstance( b, AND):
for o in b.opnds: operands.update(b.opnds)
operands.add(o)
else: else:
operands.add(b) operands.add(b)
@ -203,7 +194,7 @@ class AND(SemanticContext):
reduced = min(precedencePredicates) reduced = min(precedencePredicates)
operands.add(reduced) operands.add(reduced)
self.opnds = [ o for o in operands ] self.opnds = list(operands)
def __eq__(self, other): def __eq__(self, other):
if self is other: if self is other:
@ -227,10 +218,7 @@ class AND(SemanticContext):
# unordered.</p> # unordered.</p>
# #
def eval(self, parser, outerContext): def eval(self, parser, outerContext):
for opnd in self.opnds: return all(opnd.eval(parser, outerContext) for opnd in self.opnds)
if not opnd.eval(parser, outerContext):
return False
return True
def evalPrecedence(self, parser, outerContext): def evalPrecedence(self, parser, outerContext):
differs = False differs = False
@ -277,13 +265,11 @@ class OR (SemanticContext):
def __init__(self, a, b): def __init__(self, a, b):
operands = set() operands = set()
if isinstance( a, OR): if isinstance( a, OR):
for o in a.opnds: operands.update(a.opnds)
operands.add(o)
else: else:
operands.add(a) operands.add(a)
if isinstance( b, OR): if isinstance( b, OR):
for o in b.opnds: operands.update(b.opnds)
operands.add(o)
else: else:
operands.add(b) operands.add(b)
@ -291,10 +277,10 @@ class OR (SemanticContext):
if len(precedencePredicates)>0: if len(precedencePredicates)>0:
# interested in the transition with the highest precedence # interested in the transition with the highest precedence
s = sorted(precedencePredicates) s = sorted(precedencePredicates)
reduced = s[len(s)-1] reduced = s[-1]
operands.add(reduced) operands.add(reduced)
self.opnds = [ o for o in operands ] self.opnds = list(operands)
def __eq__(self, other): def __eq__(self, other):
if self is other: if self is other:
@ -315,10 +301,7 @@ class OR (SemanticContext):
# unordered.</p> # unordered.</p>
# #
def eval(self, parser, outerContext): def eval(self, parser, outerContext):
for opnd in self.opnds: return any(opnd.eval(parser, outerContext) for opnd in self.opnds)
if opnd.eval(parser, outerContext):
return True
return False
def evalPrecedence(self, parser, outerContext): def evalPrecedence(self, parser, outerContext):
differs = False differs = False

View File

@ -105,14 +105,9 @@ class DFAState(object):
# Get the set of all alts mentioned by all ATN configurations in this # Get the set of all alts mentioned by all ATN configurations in this
# DFA state. # DFA state.
def getAltSet(self): def getAltSet(self):
alts = set()
if self.configs is not None: if self.configs is not None:
for c in self.configs: return set(cfg.alt for cfg in self.configs) or None
alts.add(c.alt) return None
if len(alts)==0:
return None
else:
return alts
def __hash__(self): def __hash__(self):
return hash(self.configs) return hash(self.configs)

View File

@ -129,8 +129,7 @@ class Trees(object):
@classmethod @classmethod
def descendants(cls, t): def descendants(cls, t):
nodes = [] nodes = [t]
nodes.append(t)
for i in range(0, t.getChildCount()): for i in range(0, t.getChildCount()):
nodes.extend(cls.descendants(t.getChild(i))) nodes.extend(cls.descendants(t.getChild(i)))
return nodes return nodes

View File

@ -289,12 +289,7 @@ class XPathRuleElement(XPathElement):
def evaluate(self, t): def evaluate(self, t):
# return all children of t that match nodeName # return all children of t that match nodeName
nodes = [] return [c for c in Trees.getChildren(t) if isinstance(c, ParserRuleContext) and (c.ruleIndex == self.ruleIndex) == (not self.invert)]
for c in Trees.getChildren(t):
if isinstance(c, ParserRuleContext ):
if (c.ruleIndex == self.ruleIndex ) == (not self.invert):
nodes.append(c)
return nodes
class XPathTokenAnywhereElement(XPathElement): class XPathTokenAnywhereElement(XPathElement):
@ -314,12 +309,8 @@ class XPathTokenElement(XPathElement):
def evaluate(self, t): def evaluate(self, t):
# return all children of t that match nodeName # return all children of t that match nodeName
nodes = [] return [c for c in Trees.getChildren(t) if isinstance(c, TerminalNode) and (c.symbol.type == self.tokenType) == (not self.invert)]
for c in Trees.getChildren(t):
if isinstance(c, TerminalNode):
if (c.symbol.type == self.tokenType ) == (not self.invert):
nodes.append(c)
return nodes
class XPathWildcardAnywhereElement(XPathElement): class XPathWildcardAnywhereElement(XPathElement):

View File

@ -84,16 +84,10 @@ class IntervalSet(object):
if self.intervals is None: if self.intervals is None:
return False return False
else: else:
for i in self.intervals: return any(item in i for i in self.intervals)
if item in i:
return True
return False
def __len__(self): def __len__(self):
xlen = 0 return sum(len(i) for i in self.intervals)
for i in self.intervals:
xlen += len(i)
return xlen
def removeRange(self, v): def removeRange(self, v):
if v.start==v.stop-1: if v.start==v.stop-1:
@ -113,7 +107,7 @@ class IntervalSet(object):
# check for included range, remove it # check for included range, remove it
elif v.start<=i.start and v.stop>=i.stop: elif v.start<=i.start and v.stop>=i.stop:
self.intervals.pop(k) self.intervals.pop(k)
k = k - 1 # need another pass k -= 1 # need another pass
# check for lower boundary # check for lower boundary
elif v.start<i.stop: elif v.start<i.stop:
self.intervals[k] = range(i.start, v.start) self.intervals[k] = range(i.start, v.start)

View File

@ -101,9 +101,7 @@ class ListTokenSource(TokenSource):
line = lastToken.line line = lastToken.line
tokenText = lastToken.text tokenText = lastToken.text
if tokenText is not None: if tokenText is not None:
for c in tokenText: line += tokenText.count('\n')
if c == '\n':
line += 1
# if no text is available, assume the token did not contain any newline characters. # if no text is available, assume the token did not contain any newline characters.
return line return line

View File

@ -581,15 +581,14 @@ def getCachedPredictionContext(context:PredictionContext, contextCache:Predictio
parent = getCachedPredictionContext(context.getParent(i), contextCache, visited) parent = getCachedPredictionContext(context.getParent(i), contextCache, visited)
if changed or parent is not context.getParent(i): if changed or parent is not context.getParent(i):
if not changed: if not changed:
parents = [None] * len(context) parents = [context.getParent(j) for j in range(len(context))]
for j in range(0, len(context)):
parents[j] = context.getParent(j)
changed = True changed = True
parents[i] = parent parents[i] = parent
if not changed: if not changed:
contextCache.add(context) contextCache.add(context)
visited[context] = context visited[context] = context
return context return context
updated = None updated = None
if len(parents) == 0: if len(parents) == 0:
updated = PredictionContext.EMPTY updated = PredictionContext.EMPTY

View File

@ -34,6 +34,7 @@
# graph-structured stack. # graph-structured stack.
#/ #/
from io import StringIO from io import StringIO
from functools import reduce
from antlr4.PredictionContext import PredictionContext, merge from antlr4.PredictionContext import PredictionContext, merge
from antlr4.Utils import str_list from antlr4.Utils import str_list
from antlr4.atn.ATN import ATN from antlr4.atn.ATN import ATN
@ -121,9 +122,9 @@ class ATNConfigSet(object):
h = config.hashCodeForConfigSet() h = config.hashCodeForConfigSet()
l = self.configLookup.get(h, None) l = self.configLookup.get(h, None)
if l is not None: if l is not None:
for c in l: r = next((cfg for cfg in l if config.equalsForConfigSet(cfg)), None)
if config.equalsForConfigSet(c): if r is not None:
return c return r
if l is None: if l is None:
l = [config] l = [config]
self.configLookup[h] = l self.configLookup[h] = l
@ -132,17 +133,10 @@ class ATNConfigSet(object):
return config return config
def getStates(self): def getStates(self):
states = set() return set(c.state for c in self.configs)
for c in self.configs:
states.add(c.state)
return states
def getPredicates(self): def getPredicates(self):
preds = list() return list(cfg.semanticContext for cfg in self.configs if cfg.semanticContext!=SemanticContext.NONE)
for c in self.configs:
if c.semanticContext!=SemanticContext.NONE:
preds.append(c.semanticContext)
return preds
def get(self, i:int): def get(self, i:int):
return self.configs[i] return self.configs[i]
@ -184,10 +178,7 @@ class ATNConfigSet(object):
return self.hashConfigs() return self.hashConfigs()
def hashConfigs(self): def hashConfigs(self):
h = 0 return reduce(lambda h, cfg: hash((h, cfg)), self.configs, 0)
for cfg in self.configs:
h = hash((h, cfg))
return h
def __len__(self): def __len__(self):
return len(self.configs) return len(self.configs)

View File

@ -143,10 +143,7 @@ class ATNState(object):
return self.stateNumber return self.stateNumber
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, ATNState): return isinstance(other, ATNState) and self.stateNumber==other.stateNumber
return self.stateNumber==other.stateNumber
else:
return False
def onlyHasEpsilonTransitions(self): def onlyHasEpsilonTransitions(self):
return self.epsilonOnlyTransitions return self.epsilonOnlyTransitions

View File

@ -542,11 +542,7 @@ class LexerATNSimulator(ATNSimulator):
def addDFAState(self, configs:ATNConfigSet) -> DFAState: def addDFAState(self, configs:ATNConfigSet) -> DFAState:
proposed = DFAState(configs=configs) proposed = DFAState(configs=configs)
firstConfigWithRuleStopState = None firstConfigWithRuleStopState = next((cfg for cfg in configs if isinstance(cfg.state, RuleStopState)), None)
for c in configs:
if isinstance(c.state, RuleStopState):
firstConfigWithRuleStopState = c
break
if firstConfigWithRuleStopState is not None: if firstConfigWithRuleStopState is not None:
proposed.isAcceptState = True proposed.isAcceptState = True

View File

@ -235,10 +235,7 @@ class PredictionMode(Enum):
# {@link RuleStopState}, otherwise {@code false} # {@link RuleStopState}, otherwise {@code false}
@classmethod @classmethod
def hasConfigInRuleStopState(cls, configs:ATNConfigSet): def hasConfigInRuleStopState(cls, configs:ATNConfigSet):
for c in configs: return any(isinstance(cfg.state, RuleStopState) for cfg in configs)
if isinstance(c.state, RuleStopState):
return True
return False
# Checks if all configurations in {@code configs} are in a # Checks if all configurations in {@code configs} are in a
# {@link RuleStopState}. Configurations meeting this condition have reached # {@link RuleStopState}. Configurations meeting this condition have reached
@ -250,10 +247,7 @@ class PredictionMode(Enum):
# {@link RuleStopState}, otherwise {@code false} # {@link RuleStopState}, otherwise {@code false}
@classmethod @classmethod
def allConfigsInRuleStopStates(cls, configs:ATNConfigSet): def allConfigsInRuleStopStates(cls, configs:ATNConfigSet):
for config in configs: return all(isinstance(cfg.state, RuleStopState) for cfg in configs)
if not isinstance(config.state, RuleStopState):
return False
return True
# #
# Full LL prediction termination. # Full LL prediction termination.
@ -422,10 +416,7 @@ class PredictionMode(Enum):
# #
@classmethod @classmethod
def hasNonConflictingAltSet(cls, altsets:list): def hasNonConflictingAltSet(cls, altsets:list):
for alts in altsets: return any(len(alts) == 1 for alts in altsets)
if len(alts)==1:
return True
return False
# #
# Determines if any single alternative subset in {@code altsets} contains # Determines if any single alternative subset in {@code altsets} contains
@ -437,10 +428,7 @@ class PredictionMode(Enum):
# #
@classmethod @classmethod
def hasConflictingAltSet(cls, altsets:list): def hasConflictingAltSet(cls, altsets:list):
for alts in altsets: return any(len(alts) > 1 for alts in altsets)
if len(alts)>1:
return True
return False
# #
# Determines if every alternative subset in {@code altsets} is equivalent. # Determines if every alternative subset in {@code altsets} is equivalent.
@ -451,13 +439,10 @@ class PredictionMode(Enum):
# #
@classmethod @classmethod
def allSubsetsEqual(cls, altsets:list): def allSubsetsEqual(cls, altsets:list):
first = None if not altsets:
for alts in altsets: return True
if first is None: first = next(iter(altsets))
first = alts return all(alts == first for alts in iter(altsets))
elif not alts==first:
return False
return True
# #
# Returns the unique alternative predicted by all alternative subsets in # Returns the unique alternative predicted by all alternative subsets in
@ -470,10 +455,8 @@ class PredictionMode(Enum):
def getUniqueAlt(cls, altsets:list): def getUniqueAlt(cls, altsets:list):
all = cls.getAlts(altsets) all = cls.getAlts(altsets)
if len(all)==1: if len(all)==1:
for one in all: return next(iter(all))
return one return ATN.INVALID_ALT_NUMBER
else:
return ATN.INVALID_ALT_NUMBER
# Gets the complete set of represented alternatives for a collection of # Gets the complete set of represented alternatives for a collection of
# alternative subsets. This method returns the union of each {@link BitSet} # alternative subsets. This method returns the union of each {@link BitSet}
@ -484,10 +467,7 @@ class PredictionMode(Enum):
# #
@classmethod @classmethod
def getAlts(cls, altsets:list): def getAlts(cls, altsets:list):
all = set() return set.union(*altsets)
for alts in altsets:
all = all | alts
return all
# #
# This function gets the conflicting alt subsets from a configuration set. # This function gets the conflicting alt subsets from a configuration set.
@ -531,11 +511,7 @@ class PredictionMode(Enum):
@classmethod @classmethod
def hasStateAssociatedWithOneAlt(cls, configs:ATNConfigSet): def hasStateAssociatedWithOneAlt(cls, configs:ATNConfigSet):
x = cls.getStateToAltMap(configs) return any(len(alts) == 1 for alts in cls.getStateToAltMap(configs).values())
for alts in x.values():
if len(alts)==1:
return True
return False
@classmethod @classmethod
def getSingleViableAlt(cls, altsets:list): def getSingleViableAlt(cls, altsets:list):

View File

@ -116,13 +116,7 @@ def orContext(a:SemanticContext, b:SemanticContext):
return result return result
def filterPrecedencePredicates(collection:list): def filterPrecedencePredicates(collection:list):
result = [] return [context for context in collection if isinstance(context, PrecedencePredicate)]
for context in collection:
if isinstance(context, PrecedencePredicate):
if result is None:
result = []
result.append(context)
return result
class Predicate(SemanticContext): class Predicate(SemanticContext):
@ -188,13 +182,11 @@ class AND(SemanticContext):
def __init__(self, a:SemanticContext, b:SemanticContext): def __init__(self, a:SemanticContext, b:SemanticContext):
operands = set() operands = set()
if isinstance( a, AND ): if isinstance( a, AND ):
for o in a.opnds: operands.update(a.opnds)
operands.add(o)
else: else:
operands.add(a) operands.add(a)
if isinstance( b, AND ): if isinstance( b, AND ):
for o in b.opnds: operands.update(b.opnds)
operands.add(o)
else: else:
operands.add(b) operands.add(b)
@ -204,7 +196,7 @@ class AND(SemanticContext):
reduced = min(precedencePredicates) reduced = min(precedencePredicates)
operands.add(reduced) operands.add(reduced)
self.opnds = [ o for o in operands ] self.opnds = list(operands)
def __eq__(self, other): def __eq__(self, other):
if self is other: if self is other:
@ -227,11 +219,8 @@ class AND(SemanticContext):
# The evaluation of predicates by this context is short-circuiting, but # The evaluation of predicates by this context is short-circuiting, but
# unordered.</p> # unordered.</p>
# #
def eval(self, parser:Recognizer , outerContext:RuleContext ): def eval(self, parser:Recognizer, outerContext:RuleContext):
for opnd in self.opnds: return all(opnd.eval(parser, outerContext) for opnd in self.opnds)
if not opnd.eval(parser, outerContext):
return False
return True
def evalPrecedence(self, parser:Recognizer, outerContext:RuleContext): def evalPrecedence(self, parser:Recognizer, outerContext:RuleContext):
differs = False differs = False
@ -278,13 +267,11 @@ class OR (SemanticContext):
def __init__(self, a:SemanticContext, b:SemanticContext): def __init__(self, a:SemanticContext, b:SemanticContext):
operands = set() operands = set()
if isinstance( a, OR ): if isinstance( a, OR ):
for o in a.opnds: operands.update(a.opnds)
operands.add(o)
else: else:
operands.add(a) operands.add(a)
if isinstance( b, OR ): if isinstance( b, OR ):
for o in b.opnds: operands.update(b.opnds)
operands.add(o)
else: else:
operands.add(b) operands.add(b)
@ -292,10 +279,10 @@ class OR (SemanticContext):
if len(precedencePredicates)>0: if len(precedencePredicates)>0:
# interested in the transition with the highest precedence # interested in the transition with the highest precedence
s = sorted(precedencePredicates) s = sorted(precedencePredicates)
reduced = s[len(s)-1] reduced = s[-1]
operands.add(reduced) operands.add(reduced)
self.opnds = [ o for o in operands ] self.opnds = list(operands)
def __eq__(self, other): def __eq__(self, other):
if self is other: if self is other:
@ -316,10 +303,7 @@ class OR (SemanticContext):
# unordered.</p> # unordered.</p>
# #
def eval(self, parser:Recognizer, outerContext:RuleContext): def eval(self, parser:Recognizer, outerContext:RuleContext):
for opnd in self.opnds: return any(opnd.eval(parser, outerContext) for opnd in self.opnds)
if opnd.eval(parser, outerContext):
return True
return False
def evalPrecedence(self, parser:Recognizer, outerContext:RuleContext): def evalPrecedence(self, parser:Recognizer, outerContext:RuleContext):
differs = False differs = False

View File

@ -104,14 +104,9 @@ class DFAState(object):
# Get the set of all alts mentioned by all ATN configurations in this # Get the set of all alts mentioned by all ATN configurations in this
# DFA state. # DFA state.
def getAltSet(self): def getAltSet(self):
alts = set()
if self.configs is not None: if self.configs is not None:
for c in self.configs: return set(cfg.alt for cfg in self.configs) or None
alts.add(c.alt) return None
if len(alts)==0:
return None
else:
return alts
def __hash__(self): def __hash__(self):
return hash(self.configs) return hash(self.configs)

View File

@ -130,8 +130,7 @@ class Trees(object):
@classmethod @classmethod
def descendants(cls, t:ParseTree): def descendants(cls, t:ParseTree):
nodes = [] nodes = [t]
nodes.append(t)
for i in range(0, t.getChildCount()): for i in range(0, t.getChildCount()):
nodes.extend(cls.descendants(t.getChild(i))) nodes.extend(cls.descendants(t.getChild(i)))
return nodes return nodes

View File

@ -290,12 +290,8 @@ class XPathRuleElement(XPathElement):
def evaluate(self, t:ParseTree): def evaluate(self, t:ParseTree):
# return all children of t that match nodeName # return all children of t that match nodeName
nodes = [] return [c for c in Trees.getChildren(t) if isinstance(c, ParserRuleContext) and (c.ruleIndex == self.ruleIndex) == (not self.invert)]
for c in Trees.getChildren(t):
if isinstance(c, ParserRuleContext ):
if (c.ruleIndex == self.ruleIndex ) == (not self.invert):
nodes.append(c)
return nodes
class XPathTokenAnywhereElement(XPathElement): class XPathTokenAnywhereElement(XPathElement):
@ -315,12 +311,8 @@ class XPathTokenElement(XPathElement):
def evaluate(self, t:ParseTree): def evaluate(self, t:ParseTree):
# return all children of t that match nodeName # return all children of t that match nodeName
nodes = [] return [c for c in Trees.getChildren(t) if isinstance(c, TerminalNode) and (c.symbol.type == self.tokenType) == (not self.invert)]
for c in Trees.getChildren(t):
if isinstance(c, TerminalNode):
if (c.symbol.type == self.tokenType ) == (not self.invert):
nodes.append(c)
return nodes
class XPathWildcardAnywhereElement(XPathElement): class XPathWildcardAnywhereElement(XPathElement):