Fix API inconsistency, where getText would expect a tuple named interval, leading to lots of confusion

This commit is contained in:
Eric Vergnaud 2019-01-01 13:20:23 +01:00
parent 108854f986
commit cccf6e87da
5 changed files with 41 additions and 52 deletions

View File

@ -269,36 +269,32 @@ class BufferedTokenStream(TokenStream):
def getSourceName(self):
return self.tokenSource.getSourceName()
def getText(self, interval=None):
def getText(self, start=None, stop=None):
"""
Get the text of all tokens in this buffer.
:param interval:
:type interval: antlr4.IntervalSet.Interval
:return: string
"""
self.lazyInit()
self.fill()
if interval is None:
interval = (0, len(self.tokens)-1)
start = interval[0]
if isinstance(start, Token):
if start is None:
start = 0
elif isinstance(start, Token):
start = start.tokenIndex
stop = interval[1]
if isinstance(stop, Token):
if stop is None or stop >= len(self.tokens):
stop = len(self.tokens) - 1
elif isinstance(stop, Token):
stop = stop.tokenIndex
if start is None or stop is None or start<0 or stop<0:
if start < 0 or stop < 0 or stop<start:
return ""
if stop >= len(self.tokens):
stop = len(self.tokens)-1
with StringIO() as buf:
for i in range(start, stop+1):
for i in xrange(start, stop+1):
t = self.tokens[i]
if t.type==Token.EOF:
break
buf.write(t.text)
return buf.getvalue()
def fill(self):
"""
Get all tokens from lexer until EOF

View File

@ -100,9 +100,9 @@ class TokenStreamRewriter(object):
return self.programs.setdefault(program_name, [])
def getDefaultText(self):
return self.getText(self.DEFAULT_PROGRAM_NAME, Interval(0, len(self.tokens.tokens)))
return self.getText(self.DEFAULT_PROGRAM_NAME, 0, len(self.tokens.tokens))
def getText(self, program_name, interval):
def getText(self, program_name, start, stop):
"""
:type interval: Interval.Interval
:param program_name:
@ -110,15 +110,15 @@ class TokenStreamRewriter(object):
:return:
"""
rewrites = self.programs.get(program_name)
start = interval.start
stop = interval.stop
# ensure start/end are in range
if stop > len(self.tokens.tokens) - 1: stop = len(self.tokens.tokens) - 1
if start < 0: start = 0
if stop > len(self.tokens.tokens) - 1:
stop = len(self.tokens.tokens) - 1
if start < 0:
start = 0
# if no instructions to execute
if not rewrites: return self.tokens.getText(interval)
if not rewrites: return self.tokens.getText(start, stop)
buf = StringIO()
indexToOp = self._reduceToSingleOperationPerIndex(rewrites)
i = start

View File

@ -4,8 +4,6 @@
import unittest
from antlr4.IntervalSet import Interval
from mocks.TestLexer import TestLexer, TestLexer2
from antlr4.TokenStreamRewriter import TokenStreamRewriter
from antlr4.InputStream import InputStream
@ -88,8 +86,8 @@ class TestTokenStreamRewriter(unittest.TestCase):
rewriter.replaceRange(4, 8, '0')
self.assertEquals(rewriter.getDefaultText(), 'x = 0;')
self.assertEquals(rewriter.getText('default', Interval(0, 9)), 'x = 0;')
self.assertEquals(rewriter.getText('default', Interval(4, 8)), '0')
self.assertEquals(rewriter.getText('default', 0, 9), 'x = 0;')
self.assertEquals(rewriter.getText('default', 4, 8), '0')
def testToStringStartStop2(self):
input = InputStream('x = 3 * 0 + 2 * 0;')
@ -103,15 +101,15 @@ class TestTokenStreamRewriter(unittest.TestCase):
# replace 3 * 0 with 0
rewriter.replaceRange(4, 8, '0')
self.assertEquals('x = 0 + 2 * 0;', rewriter.getDefaultText())
self.assertEquals('x = 0 + 2 * 0;', rewriter.getText('default', Interval(0, 17)))
self.assertEquals('0', rewriter.getText('default', Interval(4, 8)))
self.assertEquals('x = 0', rewriter.getText('default', Interval(0, 8)))
self.assertEquals('2 * 0', rewriter.getText('default', Interval(12, 16)))
self.assertEquals('x = 0 + 2 * 0;', rewriter.getText('default', 0, 17))
self.assertEquals('0', rewriter.getText('default', 4, 8))
self.assertEquals('x = 0', rewriter.getText('default', 0, 8))
self.assertEquals('2 * 0', rewriter.getText('default', 12, 16))
rewriter.insertAfter(17, "// comment")
self.assertEquals('2 * 0;// comment', rewriter.getText('default', Interval(12, 18)))
self.assertEquals('2 * 0;// comment', rewriter.getText('default', 12, 18))
self.assertEquals('x = 0', rewriter.getText('default', Interval(0, 8)))
self.assertEquals('x = 0', rewriter.getText('default', 0, 8))
def test2ReplaceMiddleIndex(self):
input = InputStream('abc')

View File

@ -272,21 +272,19 @@ class BufferedTokenStream(TokenStream):
return self.tokenSource.getSourceName()
# Get the text of all tokens in this buffer.#/
def getText(self, interval:tuple=None):
def getText(self, start:int=None, stop:int=None):
self.lazyInit()
self.fill()
if interval is None:
interval = (0, len(self.tokens)-1)
start = interval[0]
if isinstance(start, Token):
if start is None:
start = 0
elif isinstance(start, Token):
start = start.tokenIndex
stop = interval[1]
if isinstance(stop, Token):
if stop is None or stop >= len(self.tokens):
stop = len(self.tokens) - 1
elif isinstance(stop, Token):
stop = stop.tokenIndex
if start is None or stop is None or start<0 or stop<0:
if start < 0 or stop < 0 or stop < start:
return ""
if stop >= len(self.tokens):
stop = len(self.tokens)-1
with StringIO() as buf:
for i in range(start, stop+1):
t = self.tokens[i]

View File

@ -96,23 +96,20 @@ class TokenStreamRewriter(object):
def getProgram(self, program_name):
return self.programs.setdefault(program_name, [])
def getText(self, program_name, interval):
def getText(self, program_name, start:int, stop:int):
"""
:type interval: Interval.Interval
:param program_name:
:param interval:
:return:
"""
rewrites = self.programs.get(program_name)
start = interval.start
stop = interval.stop
# ensure start/end are in range
if stop > len(self.tokens.tokens) - 1: stop = len(self.tokens.tokens) - 1
if start < 0: start = 0
if stop > len(self.tokens.tokens) - 1:
stop = len(self.tokens.tokens) - 1
if start < 0:
start = 0
# if no instructions to execute
if not rewrites: return self.tokens.getText(interval)
if not rewrites: return self.tokens.getText(start, stop)
buf = StringIO()
indexToOp = self._reduceToSingleOperationPerIndex(rewrites)
i = start
@ -149,12 +146,12 @@ class TokenStreamRewriter(object):
prevReplaces = [op for op in rewrites[:i] if isinstance(op, TokenStreamRewriter.ReplaceOp)]
for prevRop in prevReplaces:
if all((prevRop.index >= rop.index, prevRop.last_index <= rop.last_index)):
rewrites[prevRop.instructioIndex] = None
rewrites[prevRop.instructionIndex] = None
continue
isDisjoint = any((prevRop.last_index<rop.index, prevRop.index>rop))
isSame = all((prevRop.index == rop.index, prevRop.last_index == rop.last_index))
if all((prevRop.text is None, rop.text is None, not isDisjoint)):
rewrites[prevRop.instructioIndex] = None
rewrites[prevRop.instructionIndex] = None
rop.index = min(prevRop.index, rop.index)
rop.last_index = min(prevRop.last_index, rop.last_index)
print('New rop {}'.format(rop))