Merge pull request #1557 from mike-lischke/LR-loop-fix

Lr loop fix
This commit is contained in:
Terence Parr 2016-12-27 12:57:26 -08:00 committed by GitHub
commit f79a34aa3a
18 changed files with 390 additions and 145 deletions

View File

@ -113,7 +113,7 @@ public class PerformanceDescriptors {
@Override
public boolean ignore(String targetName) {
return !Arrays.asList("Java", "CSharp", "Python2", "Python3", "Node").contains(targetName);
return !Arrays.asList("Java", "CSharp", "Python2", "Python3", "Node", "Cpp").contains(targetName);
}
}
@ -199,7 +199,7 @@ public class PerformanceDescriptors {
@Override
public boolean ignore(String targetName) {
// passes, but still too slow in Python and JavaScript
return !Arrays.asList("Java", "CSharp").contains(targetName);
return !Arrays.asList("Java", "CSharp", "Cpp").contains(targetName);
}
}

View File

@ -77,8 +77,8 @@ void ATNConfig::setPrecedenceFilterSuppressed(bool value) {
bool ATNConfig::operator == (const ATNConfig &other) const {
return state->stateNumber == other.state->stateNumber && alt == other.alt &&
(context == other.context || (context != nullptr && *context == *(other.context))) &&
*semanticContext == *(other.semanticContext) &&
((context == other.context) || (*context == *other.context)) &&
*semanticContext == *other.semanticContext &&
isPrecedenceFilterSuppressed() == other.isPrecedenceFilterSuppressed();
}

View File

@ -27,7 +27,7 @@ namespace atn {
struct Comparer {
bool operator()(ATNConfig const& lhs, ATNConfig const& rhs) const {
return lhs == rhs;
return (&lhs == &rhs) || (lhs == rhs);
}
};

View File

@ -34,11 +34,9 @@ namespace atn {
bool hasSemanticContext;
bool dipsIntoOuterContext;
/// <summary>
/// Indicates that this configuration set is part of a full context
/// LL prediction. It will be used to determine how to merge $. With SLL
/// it's a wildcard whereas it is not for LL context merge.
/// </summary>
const bool fullCtx;
ATNConfigSet(bool fullCtx = true);
@ -89,12 +87,6 @@ namespace atn {
virtual std::string toString();
protected:
/// All configs but hashed by (s, i, _, pi) not including context. Wiped out
/// when we go readonly as this set becomes a DFA state.
// ml: no need for a comparer here as by definition there can be no hash clashes.
// (same hashes always mean same object).
std::unordered_map<size_t, ATNConfig *> _configLookup;
/// Indicates that the set of configurations is read-only. Do not
/// allow any code to manipulate the set; DFA states will point at
/// the sets and they must not change. This does not protect the other
@ -107,6 +99,10 @@ namespace atn {
private:
size_t _cachedHashCode;
/// All configs but hashed by (s, i, _, pi) not including context. Wiped out
/// when we go readonly as this set becomes a DFA state.
std::unordered_map<size_t, ATNConfig *> _configLookup;
void InitializeInstanceFields();
};

View File

@ -14,9 +14,9 @@ ArrayPredictionContext::ArrayPredictionContext(Ref<SingletonPredictionContext> c
: ArrayPredictionContext({ a->parent }, { a->returnState }) {
}
ArrayPredictionContext::ArrayPredictionContext(std::vector<std::weak_ptr<PredictionContext>> parents_,
ArrayPredictionContext::ArrayPredictionContext(std::vector<Ref<PredictionContext>> const& parents_,
std::vector<size_t> const& returnStates)
: PredictionContext(calculateHashCode(parents_, returnStates)), parents(makeRef(parents_)), returnStates(returnStates) {
: PredictionContext(calculateHashCode(parents_, returnStates)), parents(parents_), returnStates(returnStates) {
assert(parents.size() > 0);
assert(returnStates.size() > 0);
}
@ -30,7 +30,7 @@ size_t ArrayPredictionContext::size() const {
return returnStates.size();
}
std::weak_ptr<PredictionContext> ArrayPredictionContext::getParent(size_t index) const {
Ref<PredictionContext> ArrayPredictionContext::getParent(size_t index) const {
return parents[index];
}
@ -77,11 +77,3 @@ std::string ArrayPredictionContext::toString() const {
ss << "]";
return ss.str();
}
std::vector<Ref<PredictionContext>> ArrayPredictionContext::makeRef(const std::vector<std::weak_ptr<PredictionContext> > &input) {
std::vector<Ref<PredictionContext>> result;
for (auto element : input) {
result.push_back(element.lock());
}
return result;
}

View File

@ -26,19 +26,16 @@ namespace atn {
const std::vector<size_t> returnStates;
ArrayPredictionContext(Ref<SingletonPredictionContext> const& a);
ArrayPredictionContext(std::vector<std::weak_ptr<PredictionContext>> parents_,
std::vector<size_t> const& returnStates);
ArrayPredictionContext(std::vector<Ref<PredictionContext>> const& parents_, std::vector<size_t> const& returnStates);
virtual ~ArrayPredictionContext() {};
virtual bool isEmpty() const override;
virtual size_t size() const override;
virtual std::weak_ptr<PredictionContext> getParent(size_t index) const override;
virtual Ref<PredictionContext> getParent(size_t index) const override;
virtual size_t getReturnState(size_t index) const override;
bool operator == (const PredictionContext &o) const override;
virtual std::string toString() const override;
private:
std::vector<Ref<PredictionContext>> makeRef(const std::vector<std::weak_ptr<PredictionContext>> &input);
};
} // namespace atn

View File

@ -7,7 +7,7 @@
using namespace antlr4::atn;
EmptyPredictionContext::EmptyPredictionContext() : SingletonPredictionContext(std::weak_ptr<PredictionContext>(), EMPTY_RETURN_STATE) {
EmptyPredictionContext::EmptyPredictionContext() : SingletonPredictionContext(nullptr, EMPTY_RETURN_STATE) {
}
bool EmptyPredictionContext::isEmpty() const {
@ -18,8 +18,8 @@ size_t EmptyPredictionContext::size() const {
return 1;
}
std::weak_ptr<PredictionContext> EmptyPredictionContext::getParent(size_t /*index*/) const {
return std::weak_ptr<PredictionContext>();
Ref<PredictionContext> EmptyPredictionContext::getParent(size_t /*index*/) const {
return nullptr;
}
size_t EmptyPredictionContext::getReturnState(size_t /*index*/) const {

View File

@ -16,7 +16,7 @@ namespace atn {
virtual bool isEmpty() const override;
virtual size_t size() const override;
virtual std::weak_ptr<PredictionContext> getParent(size_t index) const override;
virtual Ref<PredictionContext> getParent(size_t index) const override;
virtual size_t getReturnState(size_t index) const override;
virtual std::string toString() const override;

View File

@ -109,7 +109,7 @@ void LL1Analyzer::_LOOK(ATNState *s, ATNState *stopState, Ref<PredictionContext>
});
calledRuleStack[returnState->ruleIndex] = false;
_LOOK(returnState, stopState, ctx->getParent(i).lock(), look, lookBusy, calledRuleStack, seeThruPreds, addEOF);
_LOOK(returnState, stopState, ctx->getParent(i), look, lookBusy, calledRuleStack, seeThruPreds, addEOF);
}
return;
}

View File

@ -55,7 +55,7 @@ bool LexerIndexedCustomAction::operator == (const LexerAction &obj) const {
return false;
}
return _offset == action->_offset && _action == action->_action;
return _offset == action->_offset && *_action == *action->_action;
}
std::string LexerIndexedCustomAction::toString() const {

View File

@ -21,6 +21,11 @@
#include "atn/RuleStopState.h"
#include "atn/ATNConfigSet.h"
#include "atn/ATNConfig.h"
#include "atn/StarLoopEntryState.h"
#include "atn/BlockStartState.h"
#include "atn/BlockEndState.h"
#include "misc/Interval.h"
#include "ANTLRErrorListener.h"
@ -39,6 +44,8 @@ using namespace antlr4::atn;
using namespace antlrcpp;
const bool ParserATNSimulator::TURN_OFF_LR_LOOP_ENTRY_BRANCH_OPT = ParserATNSimulator::getLrLoopSetting();
ParserATNSimulator::ParserATNSimulator(const ATN &atn, std::vector<dfa::DFA> &decisionToDFA,
PredictionContextCache &sharedContextCache)
: ParserATNSimulator(nullptr, atn, decisionToDFA, sharedContextCache) {
@ -91,8 +98,7 @@ size_t ParserATNSimulator::adaptivePredict(TokenStream *input, size_t decision,
// the start state for a precedence DFA depends on the current
// parser precedence, and is provided by a DFA method.
s0 = dfa.getPrecedenceStartState(parser->getPrecedence());
}
else {
} else {
// the start state for a "regular" DFA is just s0
s0 = dfa.s0;
}
@ -875,6 +881,9 @@ void ParserATNSimulator::closure_(Ref<ATNConfig> const& config, ATNConfigSet *co
}
for (size_t i = 0; i < p->transitions.size(); i++) {
if (i == 0 && canDropLoopEntryEdgeInLeftRecursiveRule(config.get()))
continue;
Transition *t = p->transitions[i];
bool continueCollecting = !is<ActionTransition*>(t) && collectPredicates;
Ref<ATNConfig> c = getEpsilonTarget(config, t, continueCollecting, depth == 0, fullCtx, treatEofAsEpsilon);
@ -932,6 +941,84 @@ void ParserATNSimulator::closure_(Ref<ATNConfig> const& config, ATNConfigSet *co
}
}
bool ParserATNSimulator::canDropLoopEntryEdgeInLeftRecursiveRule(ATNConfig *config) const {
if (TURN_OFF_LR_LOOP_ENTRY_BRANCH_OPT)
return false;
ATNState *p = config->state;
// First check to see if we are in StarLoopEntryState generated during
// left-recursion elimination. For efficiency, also check if
// the context has an empty stack case. If so, it would mean
// global FOLLOW so we can't perform optimization
if ( p->getStateType() != ATNState::STAR_LOOP_ENTRY ||
!((StarLoopEntryState *)p)->isPrecedenceDecision || // Are we the special loop entry/exit state?
config->context->isEmpty() || // If SLL wildcard
config->context->hasEmptyPath())
{
return false;
}
// Require all return states to return back to the same rule
// that p is in.
size_t numCtxs = config->context->size();
for (size_t i = 0; i < numCtxs; i++) { // for each stack context
ATNState *returnState = atn.states[config->context->getReturnState(i)];
if (returnState->ruleIndex != p->ruleIndex)
return false;
}
BlockStartState *decisionStartState = (BlockStartState *)p->transitions[0]->target;
size_t blockEndStateNum = decisionStartState->endState->stateNumber;
BlockEndState *blockEndState = (BlockEndState *)atn.states[blockEndStateNum];
// Verify that the top of each stack context leads to loop entry/exit
// state through epsilon edges and w/o leaving rule.
for (size_t i = 0; i < numCtxs; i++) { // for each stack context
size_t returnStateNumber = config->context->getReturnState(i);
ATNState *returnState = atn.states[returnStateNumber];
// All states must have single outgoing epsilon edge.
if (returnState->transitions.size() != 1 || !returnState->transitions[0]->isEpsilon())
{
return false;
}
// Look for prefix op case like 'not expr', (' type ')' expr
ATNState *returnStateTarget = returnState->transitions[0]->target;
if (returnState->getStateType() == ATNState::BLOCK_END && returnStateTarget == p) {
continue;
}
// Look for 'expr op expr' or case where expr's return state is block end
// of (...)* internal block; the block end points to loop back
// which points to p but we don't need to check that
if (returnState == blockEndState) {
continue;
}
// Look for ternary expr ? expr : expr. The return state points at block end,
// which points at loop entry state
if (returnStateTarget == blockEndState) {
continue;
}
// Look for complex prefix 'between expr and expr' case where 2nd expr's
// return state points at block end state of (...)* internal block
if (returnStateTarget->getStateType() == ATNState::BLOCK_END &&
returnStateTarget->transitions.size() == 1 &&
returnStateTarget->transitions[0]->isEpsilon() &&
returnStateTarget->transitions[0]->target == p)
{
continue;
}
// Anything else ain't conforming.
return false;
}
return true;
}
std::string ParserATNSimulator::getRuleName(size_t index) {
if (parser != nullptr) {
return parser->getRuleNames()[index];
@ -1253,6 +1340,14 @@ Parser* ParserATNSimulator::getParser() {
return parser;
}
bool ParserATNSimulator::getLrLoopSetting() {
char *var = std::getenv("TURN_OFF_LR_LOOP_ENTRY_BRANCH_OPT");
if (var == nullptr)
return false;
std::string value(var);
return value == "true" || value == "1";
}
void ParserATNSimulator::InitializeInstanceFields() {
mode = PredictionMode::LL;
_startIndex = 0;

View File

@ -247,6 +247,8 @@ namespace atn {
Parser *const parser;
public:
static const bool TURN_OFF_LR_LOOP_ENTRY_BRANCH_OPT;
std::vector<dfa::DFA> &decisionToDFA;
/// <summary>
@ -676,6 +678,93 @@ namespace atn {
bool collectPredicates, bool fullCtx, int depth, bool treatEofAsEpsilon);
public:
/** Implements first-edge (loop entry) elimination as an optimization
* during closure operations. See antlr/antlr4#1398.
*
* The optimization is to avoid adding the loop entry config when
* the exit path can only lead back to the same
* StarLoopEntryState after popping context at the rule end state
* (traversing only epsilon edges, so we're still in closure, in
* this same rule).
*
* We need to detect any state that can reach loop entry on
* epsilon w/o exiting rule. We don't have to look at FOLLOW
* links, just ensure that all stack tops for config refer to key
* states in LR rule.
*
* To verify we are in the right situation we must first check
* closure is at a StarLoopEntryState generated during LR removal.
* Then we check that each stack top of context is a return state
* from one of these cases:
*
* 1. 'not' expr, '(' type ')' expr. The return state points at loop entry state
* 2. expr op expr. The return state is the block end of internal block of (...)*
* 3. 'between' expr 'and' expr. The return state of 2nd expr reference.
* That state points at block end of internal block of (...)*.
* 4. expr '?' expr ':' expr. The return state points at block end,
* which points at loop entry state.
*
* If any is true for each stack top, then closure does not add a
* config to the current config set for edge[0], the loop entry branch.
*
* Conditions fail if any context for the current config is:
*
* a. empty (we'd fall out of expr to do a global FOLLOW which could
* even be to some weird spot in expr) or,
* b. lies outside of expr or,
* c. lies within expr but at a state not the BlockEndState
* generated during LR removal
*
* Do we need to evaluate predicates ever in closure for this case?
*
* No. Predicates, including precedence predicates, are only
* evaluated when computing a DFA start state. I.e., only before
* the lookahead (but not parser) consumes a token.
*
* There are no epsilon edges allowed in LR rule alt blocks or in
* the "primary" part (ID here). If closure is in
* StarLoopEntryState any lookahead operation will have consumed a
* token as there are no epsilon-paths that lead to
* StarLoopEntryState. We do not have to evaluate predicates
* therefore if we are in the generated StarLoopEntryState of a LR
* rule. Note that when making a prediction starting at that
* decision point, decision d=2, compute-start-state performs
* closure starting at edges[0], edges[1] emanating from
* StarLoopEntryState. That means it is not performing closure on
* StarLoopEntryState during compute-start-state.
*
* How do we know this always gives same prediction answer?
*
* Without predicates, loop entry and exit paths are ambiguous
* upon remaining input +b (in, say, a+b). Either paths lead to
* valid parses. Closure can lead to consuming + immediately or by
* falling out of this call to expr back into expr and loop back
* again to StarLoopEntryState to match +b. In this special case,
* we choose the more efficient path, which is to take the bypass
* path.
*
* The lookahead language has not changed because closure chooses
* one path over the other. Both paths lead to consuming the same
* remaining input during a lookahead operation. If the next token
* is an operator, lookahead will enter the choice block with
* operators. If it is not, lookahead will exit expr. Same as if
* closure had chosen to enter the choice block immediately.
*
* Closure is examining one config (some loopentrystate, some alt,
* context) which means it is considering exactly one alt. Closure
* always copies the same alt to any derived configs.
*
* How do we know this optimization doesn't mess up precedence in
* our parse trees?
*
* Looking through expr from left edge of stat only has to confirm
* that an input, say, a+b+c; begins with any valid interpretation
* of an expression. The precedence actually doesn't matter when
* making a decision in stat seeing through expr. It is only when
* parsing rule expr that we must use the precedence to get the
* right interpretation and, hence, parse tree.
*/
bool canDropLoopEntryEdgeInLeftRecursiveRule(ATNConfig *config) const;
virtual std::string getRuleName(size_t index);
protected:
@ -819,6 +908,7 @@ namespace atn {
Parser* getParser();
private:
static bool getLrLoopSetting();
void InitializeInstanceFields();
};

View File

@ -23,6 +23,8 @@ using namespace antlrcpp;
size_t PredictionContext::globalNodeCount = 0;
const Ref<PredictionContext> PredictionContext::EMPTY = std::make_shared<EmptyPredictionContext>();
//----------------- PredictionContext ----------------------------------------------------------------------------------
PredictionContext::PredictionContext(size_t cachedHashCode) : id(globalNodeCount++), cachedHashCode(cachedHashCode) {
}
@ -48,15 +50,12 @@ Ref<PredictionContext> PredictionContext::fromRuleContext(const ATN &atn, RuleCo
return SingletonPredictionContext::create(parent, transition->followState->stateNumber);
}
bool PredictionContext::operator != (const PredictionContext &o) const {
return !(*this == o);
}
bool PredictionContext::isEmpty() const {
return this == EMPTY.get();
}
bool PredictionContext::hasEmptyPath() const {
// since EMPTY_RETURN_STATE can only appear in the last position, we check last one
return getReturnState(size() - 1) == EMPTY_RETURN_STATE;
}
@ -70,23 +69,20 @@ size_t PredictionContext::calculateEmptyHashCode() {
return hash;
}
size_t PredictionContext::calculateHashCode(std::weak_ptr<PredictionContext> parent, size_t returnState) {
size_t PredictionContext::calculateHashCode(Ref<PredictionContext> parent, size_t returnState) {
size_t hash = MurmurHash::initialize(INITIAL_HASH);
hash = MurmurHash::update(hash, parent.lock());
hash = MurmurHash::update(hash, parent);
hash = MurmurHash::update(hash, returnState);
hash = MurmurHash::finish(hash, 2);
return hash;
}
size_t PredictionContext::calculateHashCode(const std::vector<std::weak_ptr<PredictionContext>> &parents,
size_t PredictionContext::calculateHashCode(const std::vector<Ref<PredictionContext>> &parents,
const std::vector<size_t> &returnStates) {
size_t hash = MurmurHash::initialize(INITIAL_HASH);
for (auto parent : parents) {
if (parent.expired())
hash = MurmurHash::update(hash, 0);
else
hash = MurmurHash::update(hash, parent.lock());
hash = MurmurHash::update(hash, parent);
}
for (auto returnState : returnStates) {
@ -98,11 +94,10 @@ size_t PredictionContext::calculateHashCode(const std::vector<std::weak_ptr<Pred
Ref<PredictionContext> PredictionContext::merge(const Ref<PredictionContext> &a,
const Ref<PredictionContext> &b, bool rootIsWildcard, PredictionContextMergeCache *mergeCache) {
assert(a && b);
// share same graph if both same
if (a == b) {
if (a == b || *a == *b) {
return a;
}
@ -111,8 +106,8 @@ Ref<PredictionContext> PredictionContext::merge(const Ref<PredictionContext> &a,
std::dynamic_pointer_cast<SingletonPredictionContext>(b), rootIsWildcard, mergeCache);
}
// At least one of a or b is array
// If one is $ and rootIsWildcard, return $ as * wildcard
// At least one of a or b is array.
// If one is $ and rootIsWildcard, return $ as * wildcard.
if (rootIsWildcard) {
if (is<EmptyPredictionContext>(a)) {
return a;
@ -142,20 +137,20 @@ Ref<PredictionContext> PredictionContext::mergeSingletons(const Ref<SingletonPre
const Ref<SingletonPredictionContext> &b, bool rootIsWildcard, PredictionContextMergeCache *mergeCache) {
if (mergeCache != nullptr) { // Can be null if not given to the ATNState from which this call originates.
auto iterator = mergeCache->find({ a, b });
if (iterator != mergeCache->end()) {
return iterator->second;
auto existing = mergeCache->get(a, b);
if (existing) {
return existing;
}
iterator = mergeCache->find({ b, a });
if (iterator != mergeCache->end()) {
return iterator->second;
existing = mergeCache->get(b, a);
if (existing) {
return existing;
}
}
Ref<PredictionContext> rootMerge = mergeRoot(a, b, rootIsWildcard);
if (rootMerge) {
if (mergeCache != nullptr) {
(*mergeCache)[{ a, b }] = rootMerge;
mergeCache->put(a, b, rootMerge);
}
return rootMerge;
}
@ -179,26 +174,26 @@ Ref<PredictionContext> PredictionContext::mergeSingletons(const Ref<SingletonPre
// new joined parent so create new singleton pointing to it, a'
Ref<PredictionContext> a_ = SingletonPredictionContext::create(parent, a->returnState);
if (mergeCache != nullptr) {
(*mergeCache)[{ a, b }] = a_;
mergeCache->put(a, b, a_);
}
return a_;
} else {
// a != b payloads differ
// see if we can collapse parents due to $+x parents if local ctx
std::weak_ptr<PredictionContext> singleParent;
if (a == b || (parentA && parentA == parentB)) { // ax + bx = [a,b]x
singleParent = a->parent;
Ref<PredictionContext> singleParent;
if (a == b || (*parentA == *parentB)) { // ax + bx = [a,b]x
singleParent = parentA;
}
if (!singleParent.expired()) { // parents are same, sort payloads and use same parent
if (singleParent) { // parents are same, sort payloads and use same parent
std::vector<size_t> payloads = { a->returnState, b->returnState };
if (a->returnState > b->returnState) {
payloads[0] = b->returnState;
payloads[1] = a->returnState;
}
std::vector<std::weak_ptr<PredictionContext>> parents = { singleParent, singleParent };
std::vector<Ref<PredictionContext>> parents = { singleParent, singleParent };
Ref<PredictionContext> a_ = std::make_shared<ArrayPredictionContext>(parents, payloads);
if (mergeCache != nullptr) {
(*mergeCache)[{ a, b }] = a_;
mergeCache->put(a, b, a_);
}
return a_;
}
@ -209,16 +204,16 @@ Ref<PredictionContext> PredictionContext::mergeSingletons(const Ref<SingletonPre
Ref<PredictionContext> a_;
if (a->returnState > b->returnState) { // sort by payload
std::vector<size_t> payloads = { b->returnState, a->returnState };
std::vector<std::weak_ptr<PredictionContext>> parents = { b->parent, a->parent };
std::vector<Ref<PredictionContext>> parents = { b->parent, a->parent };
a_ = std::make_shared<ArrayPredictionContext>(parents, payloads);
} else {
std::vector<size_t> payloads = {a->returnState, b->returnState};
std::vector<std::weak_ptr<PredictionContext>> parents = { a->parent, b->parent };
std::vector<Ref<PredictionContext>> parents = { a->parent, b->parent };
a_ = std::make_shared<ArrayPredictionContext>(parents, payloads);
}
if (mergeCache != nullptr) {
(*mergeCache)[{ a, b }] = a_;
mergeCache->put(a, b, a_);
}
return a_;
}
@ -239,13 +234,13 @@ Ref<PredictionContext> PredictionContext::mergeRoot(const Ref<SingletonPredictio
}
if (a == EMPTY) { // $ + x = [$,x]
std::vector<size_t> payloads = { b->returnState, EMPTY_RETURN_STATE };
std::vector<std::weak_ptr<PredictionContext>> parents = { b->parent, EMPTY };
std::vector<Ref<PredictionContext>> parents = { b->parent, nullptr };
Ref<PredictionContext> joined = std::make_shared<ArrayPredictionContext>(parents, payloads);
return joined;
}
if (b == EMPTY) { // x + $ = [$,x] ($ is always first if present)
std::vector<size_t> payloads = { a->returnState, EMPTY_RETURN_STATE };
std::vector<std::weak_ptr<PredictionContext>> parents = { a->parent, EMPTY };
std::vector<Ref<PredictionContext>> parents = { a->parent, nullptr };
Ref<PredictionContext> joined = std::make_shared<ArrayPredictionContext>(parents, payloads);
return joined;
}
@ -257,13 +252,13 @@ Ref<PredictionContext> PredictionContext::mergeArrays(const Ref<ArrayPredictionC
const Ref<ArrayPredictionContext> &b, bool rootIsWildcard, PredictionContextMergeCache *mergeCache) {
if (mergeCache != nullptr) {
auto iterator = mergeCache->find({ a, b });
if (iterator != mergeCache->end()) {
return iterator->second;
auto existing = mergeCache->get(a, b);
if (existing) {
return existing;
}
iterator = mergeCache->find({ b, a });
if (iterator != mergeCache->end()) {
return iterator->second;
existing = mergeCache->get(b, a);
if (existing) {
return existing;
}
}
@ -273,7 +268,7 @@ Ref<PredictionContext> PredictionContext::mergeArrays(const Ref<ArrayPredictionC
size_t k = 0; // walks target M array
std::vector<size_t> mergedReturnStates(a->returnStates.size() + b->returnStates.size());
std::vector<std::weak_ptr<PredictionContext>> mergedParents(a->returnStates.size() + b->returnStates.size());
std::vector<Ref<PredictionContext>> mergedParents(a->returnStates.size() + b->returnStates.size());
// walk and merge to yield mergedParents, mergedReturnStates
while (i < a->returnStates.size() && j < b->returnStates.size()) {
@ -284,7 +279,7 @@ Ref<PredictionContext> PredictionContext::mergeArrays(const Ref<ArrayPredictionC
size_t payload = a->returnStates[i];
// $+$ = $
bool both$ = payload == EMPTY_RETURN_STATE && a_parent && b_parent;
bool ax_ax = (a_parent && b_parent) && a_parent == b_parent; // ax+ax -> ax
bool ax_ax = (a_parent && b_parent) && *a_parent == *b_parent; // ax+ax -> ax
if (both$ || ax_ax) {
mergedParents[k] = a_parent; // choose left
mergedReturnStates[k] = payload;
@ -327,15 +322,13 @@ Ref<PredictionContext> PredictionContext::mergeArrays(const Ref<ArrayPredictionC
// trim merged if we combined a few that had same stack tops
if (k < mergedParents.size()) { // write index < last position; trim
if (k == 1) { // for just one merged element, return singleton top
Ref<PredictionContext> a_ = SingletonPredictionContext::create(mergedParents[0].lock(), mergedReturnStates[0]);
Ref<PredictionContext> a_ = SingletonPredictionContext::create(mergedParents[0], mergedReturnStates[0]);
if (mergeCache != nullptr) {
(*mergeCache)[{ a, b }] = a_;
mergeCache->put(a, b, a_);
}
return a_;
}
//mergedParents = Arrays::copyOf(mergedParents, k);
mergedParents.resize(k);
//mergedReturnStates = Arrays::copyOf(mergedReturnStates, k);
mergedReturnStates.resize(k);
}
@ -343,36 +336,38 @@ Ref<PredictionContext> PredictionContext::mergeArrays(const Ref<ArrayPredictionC
// if we created same array as a or b, return that instead
// TO_DO: track whether this is possible above during merge sort for speed
if (M == a) {
if (*M == *a) {
if (mergeCache != nullptr) {
(*mergeCache)[{ a, b }] = a;
mergeCache->put(a, b, a);
}
return a;
}
if (M == b) {
if (*M == *b) {
if (mergeCache != nullptr) {
(*mergeCache)[{ a, b }] = b;
mergeCache->put(a, b, b);
}
return b;
}
// This part differs from Java code. We have to recreate the context as the parents array is copied on creation.
if (combineCommonParents(mergedParents))
// ml: this part differs from Java code. We have to recreate the context as the parents array is copied on creation.
if (combineCommonParents(mergedParents)) {
mergedReturnStates.resize(mergedParents.size());
M = std::make_shared<ArrayPredictionContext>(mergedParents, mergedReturnStates);
}
if (mergeCache != nullptr) {
(*mergeCache)[{ a, b }] = M;
mergeCache->put(a, b, M);
}
return M;
}
bool PredictionContext::combineCommonParents(std::vector<std::weak_ptr<PredictionContext>> &parents) {
bool PredictionContext::combineCommonParents(std::vector<Ref<PredictionContext>> &parents) {
std::unordered_set<Ref<PredictionContext>, PredictionContextHasher, PredictionContextComparer> uniqueParents;
for (size_t p = 0; p < parents.size(); ++p) {
if (parents[p].expired())
if (!parents[p])
continue;
Ref<PredictionContext> parent = parents[p].lock();
Ref<PredictionContext> parent = parents[p];
if (uniqueParents.find(parent) == uniqueParents.end()) { // don't replace
uniqueParents.insert(parent);
}
@ -381,12 +376,13 @@ bool PredictionContext::combineCommonParents(std::vector<std::weak_ptr<Predictio
if (uniqueParents.size() == parents.size())
return false;
// Don't resize the parents array, just update the content.
for (size_t p = 0; p < parents.size(); ++p) {
if (parents[p].expired())
for (size_t p = 0; p < uniqueParents.size(); ++p) {
if (!parents[p])
continue;
parents[p] = *uniqueParents.find(parents[p].lock());
parents[p] = *uniqueParents.find(parents[p]);
}
parents.resize(uniqueParents.size());
return true;
}
@ -437,10 +433,10 @@ std::string PredictionContext::toDOTString(const Ref<PredictionContext> &context
continue;
}
for (size_t i = 0; i < current->size(); i++) {
if (current->getParent(i).expired()) {
if (!current->getParent(i)) {
continue;
}
ss << " s" << current->id << "->" << "s" << current->getParent(i).lock()->id;
ss << " s" << current->id << "->" << "s" << current->getParent(i)->id;
if (current->size() > 1) {
ss << " [label=\"parent[" << i << "]\"];\n";
} else {
@ -475,10 +471,10 @@ Ref<PredictionContext> PredictionContext::getCachedContext(const Ref<PredictionC
bool changed = false;
std::vector<std::weak_ptr<PredictionContext>> parents(context->size());
std::vector<Ref<PredictionContext>> parents(context->size());
for (size_t i = 0; i < parents.size(); i++) {
std::weak_ptr<PredictionContext> parent = getCachedContext(context->getParent(i).lock(), contextCache, visited);
if (changed || parent.lock() != context->getParent(i).lock()) {
Ref<PredictionContext> parent = getCachedContext(context->getParent(i), contextCache, visited);
if (changed || parent != context->getParent(i)) {
if (!changed) {
parents.clear();
for (size_t j = 0; j < context->size(); j++) {
@ -504,11 +500,12 @@ Ref<PredictionContext> PredictionContext::getCachedContext(const Ref<PredictionC
updated = EMPTY;
} else if (parents.size() == 1) {
updated = SingletonPredictionContext::create(parents[0], context->getReturnState(0));
contextCache.insert(updated);
} else {
updated = std::make_shared<ArrayPredictionContext>(parents, std::dynamic_pointer_cast<ArrayPredictionContext>(context)->returnStates);
contextCache.insert(updated);
}
contextCache.insert(updated);
visited[updated] = updated;
visited[context] = updated;
@ -534,7 +531,7 @@ void PredictionContext::getAllContextNodes_(const Ref<PredictionContext> &contex
nodes.push_back(context);
for (size_t i = 0; i < context->size(); i++) {
getAllContextNodes_(context->getParent(i).lock(), nodes, visited);
getAllContextNodes_(context->getParent(i), nodes, visited);
}
}
@ -603,7 +600,7 @@ std::vector<std::string> PredictionContext::toStrings(Recognizer *recognizer, co
}
}
stateNumber = p->getReturnState(index);
p = p->getParent(index).lock().get();
p = p->getParent(index).get();
}
if (outerContinue)
@ -619,3 +616,55 @@ std::vector<std::string> PredictionContext::toStrings(Recognizer *recognizer, co
return result;
}
//----------------- PredictionContextMergeCache ------------------------------------------------------------------------
Ref<PredictionContext> PredictionContextMergeCache::put(Ref<PredictionContext> const& key1, Ref<PredictionContext> const& key2,
Ref<PredictionContext> const& value) {
Ref<PredictionContext> previous;
auto iterator = _data.find(key1);
if (iterator == _data.end())
_data[key1][key2] = value;
else {
auto iterator2 = iterator->second.find(key2);
if (iterator2 != iterator->second.end())
previous = iterator2->second;
iterator->second[key2] = value;
}
return previous;
}
Ref<PredictionContext> PredictionContextMergeCache::get(Ref<PredictionContext> const& key1, Ref<PredictionContext> const& key2) {
auto iterator = _data.find(key1);
if (iterator == _data.end())
return nullptr;
auto iterator2 = iterator->second.find(key2);
if (iterator2 == iterator->second.end())
return nullptr;
return iterator2->second;
}
void PredictionContextMergeCache::clear() {
_data.clear();
}
std::string PredictionContextMergeCache::toString() const {
std::string result;
for (auto pair : _data)
for (auto pair2 : pair.second)
result += pair2.second->toString() + "\n";
return result;
}
size_t PredictionContextMergeCache::count() const {
size_t result = 0;
for (auto entry : _data)
result += entry.second.size();
return result;
}

View File

@ -14,9 +14,10 @@ namespace atn {
struct PredictionContextHasher;
struct PredictionContextComparer;
class PredictionContextMergeCache;
typedef std::unordered_set<Ref<PredictionContext>, PredictionContextHasher, PredictionContextComparer> PredictionContextCache;
typedef std::map<std::pair<Ref<PredictionContext>, Ref<PredictionContext>>, Ref<PredictionContext>> PredictionContextMergeCache;
//typedef std::map<std::pair<Ref<PredictionContext>, Ref<PredictionContext>>, Ref<PredictionContext>> PredictionContextMergeCache;
class ANTLR4CPP_PUBLIC PredictionContext {
public:
@ -72,21 +73,20 @@ namespace atn {
static Ref<PredictionContext> fromRuleContext(const ATN &atn, RuleContext *outerContext);
virtual size_t size() const = 0;
virtual std::weak_ptr<PredictionContext> getParent(size_t index) const = 0;
virtual Ref<PredictionContext> getParent(size_t index) const = 0;
virtual size_t getReturnState(size_t index) const = 0;
virtual bool operator == (const PredictionContext &o) const = 0;
virtual bool operator != (const PredictionContext &o) const;
/// This means only the EMPTY context is in set.
/// This means only the EMPTY (wildcard? not sure) context is in set.
virtual bool isEmpty() const;
virtual bool hasEmptyPath() const;
virtual size_t hashCode() const;
protected:
static size_t calculateEmptyHashCode();
static size_t calculateHashCode(std::weak_ptr<PredictionContext> parent, size_t returnState);
static size_t calculateHashCode(const std::vector<std::weak_ptr<PredictionContext>> &parents,
static size_t calculateHashCode(Ref<PredictionContext> parent, size_t returnState);
static size_t calculateHashCode(const std::vector<Ref<PredictionContext>> &parents,
const std::vector<size_t> &returnStates);
public:
@ -197,7 +197,7 @@ namespace atn {
protected:
/// Make pass over all M parents; merge any equal() ones.
/// @returns true if the list has been changed (i.e. duplicates where found).
static bool combineCommonParents(std::vector<std::weak_ptr<PredictionContext>> &parents);
static bool combineCommonParents(std::vector<Ref<PredictionContext>> &parents);
public:
static std::string toDOTString(const Ref<PredictionContext> &context);
@ -227,10 +227,29 @@ namespace atn {
struct PredictionContextComparer {
bool operator () (const Ref<PredictionContext> &lhs, const Ref<PredictionContext> &rhs) const
{
return *lhs == *rhs;
if (lhs == rhs) // Object identity.
return true;
return (lhs->hashCode() == rhs->hashCode()) && (*lhs == *rhs);
}
};
class PredictionContextMergeCache {
public:
Ref<PredictionContext> put(Ref<PredictionContext> const& key1, Ref<PredictionContext> const& key2,
Ref<PredictionContext> const& value);
Ref<PredictionContext> get(Ref<PredictionContext> const& key1, Ref<PredictionContext> const& key2);
void clear();
std::string toString() const;
size_t count() const;
private:
std::unordered_map<Ref<PredictionContext>,
std::unordered_map<Ref<PredictionContext>, Ref<PredictionContext>, PredictionContextHasher, PredictionContextComparer>,
PredictionContextHasher, PredictionContextComparer> _data;
};
} // namespace atn
} // namespace antlr4

View File

@ -28,7 +28,9 @@ namespace atn {
struct Comparer {
bool operator()(Ref<SemanticContext> const& lhs, Ref<SemanticContext> const& rhs) const {
return *lhs == *rhs;
if (lhs == rhs)
return true;
return (lhs->hashCode() == rhs->hashCode()) && (*lhs == *rhs);
}
};

View File

@ -9,16 +9,15 @@
using namespace antlr4::atn;
SingletonPredictionContext::SingletonPredictionContext(std::weak_ptr<PredictionContext> parent, size_t returnState)
: PredictionContext(!parent.expired() ? calculateHashCode(parent, returnState) : calculateEmptyHashCode()),
parent(parent.lock()), returnState(returnState) {
SingletonPredictionContext::SingletonPredictionContext(Ref<PredictionContext> const& parent, size_t returnState)
: PredictionContext(parent ? calculateHashCode(parent, returnState) : calculateEmptyHashCode()),
parent(parent), returnState(returnState) {
assert(returnState != ATNState::INVALID_STATE_NUMBER);
}
Ref<SingletonPredictionContext> SingletonPredictionContext::create(std::weak_ptr<PredictionContext> parent,
size_t returnState) {
Ref<SingletonPredictionContext> SingletonPredictionContext::create(Ref<PredictionContext> const& parent, size_t returnState) {
if (returnState == EMPTY_RETURN_STATE && parent.expired()) {
if (returnState == EMPTY_RETURN_STATE && parent) {
// someone can pass in the bits of an array ctx that mean $
return std::dynamic_pointer_cast<SingletonPredictionContext>(EMPTY);
}
@ -29,7 +28,7 @@ size_t SingletonPredictionContext::size() const {
return 1;
}
std::weak_ptr<PredictionContext> SingletonPredictionContext::getParent(size_t index) const {
Ref<PredictionContext> SingletonPredictionContext::getParent(size_t index) const {
assert(index == 0);
((void)(index)); // Make Release build happy.
return parent;
@ -55,8 +54,15 @@ bool SingletonPredictionContext::operator == (const PredictionContext &o) const
return false; // can't be same if hash is different
}
//return returnState == other->returnState && (!parent.expired() && parent.lock() == other->parent.lock());
return returnState == other->returnState && (parent != nullptr && *parent == *other->parent);
if (returnState != other->returnState)
return false;
if (!parent && !other->parent)
return true;
if (!parent || !other->parent)
return false;
return *parent == *other->parent;
}
std::string SingletonPredictionContext::toString() const {

View File

@ -20,13 +20,13 @@ namespace atn {
const Ref<PredictionContext> parent;
const size_t returnState;
SingletonPredictionContext(std::weak_ptr<PredictionContext> parent, size_t returnState);
SingletonPredictionContext(Ref<PredictionContext> const& parent, size_t returnState);
virtual ~SingletonPredictionContext() {};
static Ref<SingletonPredictionContext> create(std::weak_ptr<PredictionContext> parent, size_t returnState);
static Ref<SingletonPredictionContext> create(Ref<PredictionContext> const& parent, size_t returnState);
virtual size_t size() const override;
virtual std::weak_ptr<PredictionContext> getParent(size_t index) const override;
virtual Ref<PredictionContext> getParent(size_t index) const override;
virtual size_t getReturnState(size_t index) const override;
virtual bool operator == (const PredictionContext &o) const override;
virtual std::string toString() const override;

View File

@ -20,7 +20,7 @@ namespace antlrcpp {
return false;
for (size_t i = 0; i < a.size(); ++i)
if (a[i] != b[i]) // Requires that the != operator is supported by the template type.
if (!(a[i] == b[i]))
return false;
return true;
@ -31,22 +31,13 @@ namespace antlrcpp {
if (a.size() != b.size())
return false;
for (size_t i = 0; i < a.size(); ++i)
if (*a[i] != *b[i])
for (size_t i = 0; i < a.size(); ++i) {
if (a[i] == b[i])
continue;
if (!(*a[i] == *b[i]))
return false;
return true;
}
template <typename T>
static bool equals(const std::vector<std::weak_ptr<T>> &a, const std::vector<std::weak_ptr<T>> &b) {
if (a.size() != b.size())
return false;
for (size_t i = 0; i < a.size(); ++i)
if (*a[i].lock() != *b[i].lock())
return false;
return true;
}
@ -55,9 +46,17 @@ namespace antlrcpp {
if (a.size() != b.size())
return false;
for (size_t i = 0; i < a.size(); ++i)
if (*a[i] != *b[i])
for (size_t i = 0; i < a.size(); ++i) {
if (!a[i] && !b[i])
continue;
if (!a[i] || !b[i])
return false;
if (a[i] == b[i])
continue;
if (!(*a[i] == *b[i]))
return false;
}
return true;
}