diff --git a/runtime/Java/src/org/antlr/v4/runtime/Parser.java b/runtime/Java/src/org/antlr/v4/runtime/Parser.java index 0eb3a4df5..e478385f9 100644 --- a/runtime/Java/src/org/antlr/v4/runtime/Parser.java +++ b/runtime/Java/src/org/antlr/v4/runtime/Parser.java @@ -550,32 +550,16 @@ public abstract class Parser extends Recognizer { return false; } - /** Compute the set of valid tokens reachable from the current - * position in the parse. + /** + * Computes the set of input symbols which could follow the current parser + * state and context, as given by {@link #getState} and {@link #getContext}, + * respectively. + * + * @see ATN#getExpectedTokens(int, RuleContext) */ - public IntervalSet getExpectedTokens() { - ATN atn = getInterpreter().atn; - ParserRuleContext ctx = _ctx; - ATNState s = atn.states.get(getState()); - IntervalSet following = atn.nextTokens(s); -// System.out.println("following "+s+"="+following); - if ( !following.contains(Token.EPSILON) ) return following; - IntervalSet expected = new IntervalSet(); - expected.addAll(following); - expected.remove(Token.EPSILON); - while ( ctx!=null && ctx.invokingState>=0 && following.contains(Token.EPSILON) ) { - ATNState invokingState = atn.states.get(ctx.invokingState); - RuleTransition rt = (RuleTransition)invokingState.transition(0); - following = atn.nextTokens(rt.followState); - expected.addAll(following); - expected.remove(Token.EPSILON); - ctx = (ParserRuleContext)ctx.parent; - } - if ( following.contains(Token.EPSILON) ) { - expected.add(Token.EOF); - } - return expected; - } + public IntervalSet getExpectedTokens() { + return getATN().getExpectedTokens(getState(), getContext()); + } public IntervalSet getExpectedTokensWithinCurrentRule() { ATN atn = getInterpreter().atn; diff --git a/runtime/Java/src/org/antlr/v4/runtime/RecognitionException.java b/runtime/Java/src/org/antlr/v4/runtime/RecognitionException.java index 0268461b9..ca62f73e4 100644 --- a/runtime/Java/src/org/antlr/v4/runtime/RecognitionException.java +++ b/runtime/Java/src/org/antlr/v4/runtime/RecognitionException.java @@ -92,10 +92,10 @@ public class RecognitionException extends RuntimeException { } public IntervalSet getExpectedTokens() { - // TODO: do we really need this type check? - if ( recognizer!=null && recognizer instanceof Parser) { - return ((Parser) recognizer).getExpectedTokens(); + if (recognizer != null) { + return recognizer.getATN().getExpectedTokens(offendingState, ctx); } + return null; } diff --git a/runtime/Java/src/org/antlr/v4/runtime/atn/ATN.java b/runtime/Java/src/org/antlr/v4/runtime/atn/ATN.java index 6593f7646..27245dacf 100644 --- a/runtime/Java/src/org/antlr/v4/runtime/atn/ATN.java +++ b/runtime/Java/src/org/antlr/v4/runtime/atn/ATN.java @@ -155,4 +155,55 @@ public class ATN { public int getNumberOfDecisions() { return decisionToState.size(); } + + /** + * Computes the set of input symbols which could follow ATN state number + * {@code stateNumber} in the specified full {@code context}. This method + * considers the complete parser context, but does not evaluate semantic + * predicates (i.e. all predicates encountered during the calculation are + * assumed true). If a path in the ATN exists from the starting state to the + * {@link RuleStopState} of the outermost context without matching any + * symbols, {@link Token#EOF} is added to the returned set. + *

+ * If {@code context} is {@code null}, it is treated as + * {@link ParserRuleContext#EMPTY}. + * + * @param stateNumber the ATN state number + * @param context the full parse context + * @return The set of potentially valid input symbols which could follow the + * specified state in the specified context. + * @throws IllegalArgumentException if the ATN does not contain a state with + * number {@code stateNumber} + */ + @NotNull + public IntervalSet getExpectedTokens(int stateNumber, @Nullable RuleContext context) { + if (stateNumber < 0 || stateNumber >= states.size()) { + throw new IllegalArgumentException("Invalid state number."); + } + + RuleContext ctx = context; + ATNState s = states.get(stateNumber); + IntervalSet following = nextTokens(s); + if (!following.contains(Token.EPSILON)) { + return following; + } + + IntervalSet expected = new IntervalSet(); + expected.addAll(following); + expected.remove(Token.EPSILON); + while (ctx != null && ctx.invokingState >= 0 && following.contains(Token.EPSILON)) { + ATNState invokingState = states.get(ctx.invokingState); + RuleTransition rt = (RuleTransition)invokingState.transition(0); + following = nextTokens(rt.followState); + expected.addAll(following); + expected.remove(Token.EPSILON); + ctx = ctx.parent; + } + + if (following.contains(Token.EPSILON)) { + expected.add(Token.EOF); + } + + return expected; + } }