Fix API inconsistency, where getText would expect a tuple named interval, leading to lots of confusion

This commit is contained in:
Eric Vergnaud 2019-01-01 13:20:23 +01:00
parent 108854f986
commit cccf6e87da
5 changed files with 41 additions and 52 deletions

View File

@ -269,36 +269,32 @@ class BufferedTokenStream(TokenStream):
def getSourceName(self): def getSourceName(self):
return self.tokenSource.getSourceName() return self.tokenSource.getSourceName()
def getText(self, interval=None): def getText(self, start=None, stop=None):
""" """
Get the text of all tokens in this buffer. Get the text of all tokens in this buffer.
:param interval:
:type interval: antlr4.IntervalSet.Interval
:return: string :return: string
""" """
self.lazyInit() self.lazyInit()
self.fill() self.fill()
if interval is None: if start is None:
interval = (0, len(self.tokens)-1) start = 0
start = interval[0] elif isinstance(start, Token):
if isinstance(start, Token):
start = start.tokenIndex start = start.tokenIndex
stop = interval[1] if stop is None or stop >= len(self.tokens):
if isinstance(stop, Token):
stop = stop.tokenIndex
if start is None or stop is None or start<0 or stop<0:
return ""
if stop >= len(self.tokens):
stop = len(self.tokens) - 1 stop = len(self.tokens) - 1
elif isinstance(stop, Token):
stop = stop.tokenIndex
if start < 0 or stop < 0 or stop<start:
return ""
with StringIO() as buf: with StringIO() as buf:
for i in range(start, stop+1): for i in xrange(start, stop+1):
t = self.tokens[i] t = self.tokens[i]
if t.type==Token.EOF: if t.type==Token.EOF:
break break
buf.write(t.text) buf.write(t.text)
return buf.getvalue() return buf.getvalue()
def fill(self): def fill(self):
""" """
Get all tokens from lexer until EOF Get all tokens from lexer until EOF

View File

@ -100,9 +100,9 @@ class TokenStreamRewriter(object):
return self.programs.setdefault(program_name, []) return self.programs.setdefault(program_name, [])
def getDefaultText(self): def getDefaultText(self):
return self.getText(self.DEFAULT_PROGRAM_NAME, Interval(0, len(self.tokens.tokens))) return self.getText(self.DEFAULT_PROGRAM_NAME, 0, len(self.tokens.tokens))
def getText(self, program_name, interval): def getText(self, program_name, start, stop):
""" """
:type interval: Interval.Interval :type interval: Interval.Interval
:param program_name: :param program_name:
@ -110,15 +110,15 @@ class TokenStreamRewriter(object):
:return: :return:
""" """
rewrites = self.programs.get(program_name) rewrites = self.programs.get(program_name)
start = interval.start
stop = interval.stop
# ensure start/end are in range # ensure start/end are in range
if stop > len(self.tokens.tokens) - 1: stop = len(self.tokens.tokens) - 1 if stop > len(self.tokens.tokens) - 1:
if start < 0: start = 0 stop = len(self.tokens.tokens) - 1
if start < 0:
start = 0
# if no instructions to execute # if no instructions to execute
if not rewrites: return self.tokens.getText(interval) if not rewrites: return self.tokens.getText(start, stop)
buf = StringIO() buf = StringIO()
indexToOp = self._reduceToSingleOperationPerIndex(rewrites) indexToOp = self._reduceToSingleOperationPerIndex(rewrites)
i = start i = start

View File

@ -4,8 +4,6 @@
import unittest import unittest
from antlr4.IntervalSet import Interval
from mocks.TestLexer import TestLexer, TestLexer2 from mocks.TestLexer import TestLexer, TestLexer2
from antlr4.TokenStreamRewriter import TokenStreamRewriter from antlr4.TokenStreamRewriter import TokenStreamRewriter
from antlr4.InputStream import InputStream from antlr4.InputStream import InputStream
@ -88,8 +86,8 @@ class TestTokenStreamRewriter(unittest.TestCase):
rewriter.replaceRange(4, 8, '0') rewriter.replaceRange(4, 8, '0')
self.assertEquals(rewriter.getDefaultText(), 'x = 0;') self.assertEquals(rewriter.getDefaultText(), 'x = 0;')
self.assertEquals(rewriter.getText('default', Interval(0, 9)), 'x = 0;') self.assertEquals(rewriter.getText('default', 0, 9), 'x = 0;')
self.assertEquals(rewriter.getText('default', Interval(4, 8)), '0') self.assertEquals(rewriter.getText('default', 4, 8), '0')
def testToStringStartStop2(self): def testToStringStartStop2(self):
input = InputStream('x = 3 * 0 + 2 * 0;') input = InputStream('x = 3 * 0 + 2 * 0;')
@ -103,15 +101,15 @@ class TestTokenStreamRewriter(unittest.TestCase):
# replace 3 * 0 with 0 # replace 3 * 0 with 0
rewriter.replaceRange(4, 8, '0') rewriter.replaceRange(4, 8, '0')
self.assertEquals('x = 0 + 2 * 0;', rewriter.getDefaultText()) self.assertEquals('x = 0 + 2 * 0;', rewriter.getDefaultText())
self.assertEquals('x = 0 + 2 * 0;', rewriter.getText('default', Interval(0, 17))) self.assertEquals('x = 0 + 2 * 0;', rewriter.getText('default', 0, 17))
self.assertEquals('0', rewriter.getText('default', Interval(4, 8))) self.assertEquals('0', rewriter.getText('default', 4, 8))
self.assertEquals('x = 0', rewriter.getText('default', Interval(0, 8))) self.assertEquals('x = 0', rewriter.getText('default', 0, 8))
self.assertEquals('2 * 0', rewriter.getText('default', Interval(12, 16))) self.assertEquals('2 * 0', rewriter.getText('default', 12, 16))
rewriter.insertAfter(17, "// comment") rewriter.insertAfter(17, "// comment")
self.assertEquals('2 * 0;// comment', rewriter.getText('default', Interval(12, 18))) self.assertEquals('2 * 0;// comment', rewriter.getText('default', 12, 18))
self.assertEquals('x = 0', rewriter.getText('default', Interval(0, 8))) self.assertEquals('x = 0', rewriter.getText('default', 0, 8))
def test2ReplaceMiddleIndex(self): def test2ReplaceMiddleIndex(self):
input = InputStream('abc') input = InputStream('abc')

View File

@ -272,21 +272,19 @@ class BufferedTokenStream(TokenStream):
return self.tokenSource.getSourceName() return self.tokenSource.getSourceName()
# Get the text of all tokens in this buffer.#/ # Get the text of all tokens in this buffer.#/
def getText(self, interval:tuple=None): def getText(self, start:int=None, stop:int=None):
self.lazyInit() self.lazyInit()
self.fill() self.fill()
if interval is None: if start is None:
interval = (0, len(self.tokens)-1) start = 0
start = interval[0] elif isinstance(start, Token):
if isinstance(start, Token):
start = start.tokenIndex start = start.tokenIndex
stop = interval[1] if stop is None or stop >= len(self.tokens):
if isinstance(stop, Token):
stop = stop.tokenIndex
if start is None or stop is None or start<0 or stop<0:
return ""
if stop >= len(self.tokens):
stop = len(self.tokens) - 1 stop = len(self.tokens) - 1
elif isinstance(stop, Token):
stop = stop.tokenIndex
if start < 0 or stop < 0 or stop < start:
return ""
with StringIO() as buf: with StringIO() as buf:
for i in range(start, stop+1): for i in range(start, stop+1):
t = self.tokens[i] t = self.tokens[i]

View File

@ -96,23 +96,20 @@ class TokenStreamRewriter(object):
def getProgram(self, program_name): def getProgram(self, program_name):
return self.programs.setdefault(program_name, []) return self.programs.setdefault(program_name, [])
def getText(self, program_name, interval): def getText(self, program_name, start:int, stop:int):
""" """
:type interval: Interval.Interval
:param program_name:
:param interval:
:return: :return:
""" """
rewrites = self.programs.get(program_name) rewrites = self.programs.get(program_name)
start = interval.start
stop = interval.stop
# ensure start/end are in range # ensure start/end are in range
if stop > len(self.tokens.tokens) - 1: stop = len(self.tokens.tokens) - 1 if stop > len(self.tokens.tokens) - 1:
if start < 0: start = 0 stop = len(self.tokens.tokens) - 1
if start < 0:
start = 0
# if no instructions to execute # if no instructions to execute
if not rewrites: return self.tokens.getText(interval) if not rewrites: return self.tokens.getText(start, stop)
buf = StringIO() buf = StringIO()
indexToOp = self._reduceToSingleOperationPerIndex(rewrites) indexToOp = self._reduceToSingleOperationPerIndex(rewrites)
i = start i = start
@ -149,12 +146,12 @@ class TokenStreamRewriter(object):
prevReplaces = [op for op in rewrites[:i] if isinstance(op, TokenStreamRewriter.ReplaceOp)] prevReplaces = [op for op in rewrites[:i] if isinstance(op, TokenStreamRewriter.ReplaceOp)]
for prevRop in prevReplaces: for prevRop in prevReplaces:
if all((prevRop.index >= rop.index, prevRop.last_index <= rop.last_index)): if all((prevRop.index >= rop.index, prevRop.last_index <= rop.last_index)):
rewrites[prevRop.instructioIndex] = None rewrites[prevRop.instructionIndex] = None
continue continue
isDisjoint = any((prevRop.last_index<rop.index, prevRop.index>rop)) isDisjoint = any((prevRop.last_index<rop.index, prevRop.index>rop))
isSame = all((prevRop.index == rop.index, prevRop.last_index == rop.last_index)) isSame = all((prevRop.index == rop.index, prevRop.last_index == rop.last_index))
if all((prevRop.text is None, rop.text is None, not isDisjoint)): if all((prevRop.text is None, rop.text is None, not isDisjoint)):
rewrites[prevRop.instructioIndex] = None rewrites[prevRop.instructionIndex] = None
rop.index = min(prevRop.index, rop.index) rop.index = min(prevRop.index, rop.index)
rop.last_index = min(prevRop.last_index, rop.last_index) rop.last_index = min(prevRop.last_index, rop.last_index)
print('New rop {}'.format(rop)) print('New rop {}'.format(rop))