Fix API inconsistency, where getText would expect a tuple named interval, leading to lots of confusion
This commit is contained in:
parent
108854f986
commit
cccf6e87da
|
@ -269,36 +269,32 @@ class BufferedTokenStream(TokenStream):
|
|||
def getSourceName(self):
|
||||
return self.tokenSource.getSourceName()
|
||||
|
||||
def getText(self, interval=None):
|
||||
def getText(self, start=None, stop=None):
|
||||
"""
|
||||
Get the text of all tokens in this buffer.
|
||||
|
||||
:param interval:
|
||||
:type interval: antlr4.IntervalSet.Interval
|
||||
:return: string
|
||||
"""
|
||||
self.lazyInit()
|
||||
self.fill()
|
||||
if interval is None:
|
||||
interval = (0, len(self.tokens)-1)
|
||||
start = interval[0]
|
||||
if isinstance(start, Token):
|
||||
if start is None:
|
||||
start = 0
|
||||
elif isinstance(start, Token):
|
||||
start = start.tokenIndex
|
||||
stop = interval[1]
|
||||
if isinstance(stop, Token):
|
||||
if stop is None or stop >= len(self.tokens):
|
||||
stop = len(self.tokens) - 1
|
||||
elif isinstance(stop, Token):
|
||||
stop = stop.tokenIndex
|
||||
if start is None or stop is None or start<0 or stop<0:
|
||||
if start < 0 or stop < 0 or stop<start:
|
||||
return ""
|
||||
if stop >= len(self.tokens):
|
||||
stop = len(self.tokens)-1
|
||||
with StringIO() as buf:
|
||||
for i in range(start, stop+1):
|
||||
for i in xrange(start, stop+1):
|
||||
t = self.tokens[i]
|
||||
if t.type==Token.EOF:
|
||||
break
|
||||
buf.write(t.text)
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def fill(self):
|
||||
"""
|
||||
Get all tokens from lexer until EOF
|
||||
|
|
|
@ -100,9 +100,9 @@ class TokenStreamRewriter(object):
|
|||
return self.programs.setdefault(program_name, [])
|
||||
|
||||
def getDefaultText(self):
|
||||
return self.getText(self.DEFAULT_PROGRAM_NAME, Interval(0, len(self.tokens.tokens)))
|
||||
return self.getText(self.DEFAULT_PROGRAM_NAME, 0, len(self.tokens.tokens))
|
||||
|
||||
def getText(self, program_name, interval):
|
||||
def getText(self, program_name, start, stop):
|
||||
"""
|
||||
:type interval: Interval.Interval
|
||||
:param program_name:
|
||||
|
@ -110,15 +110,15 @@ class TokenStreamRewriter(object):
|
|||
:return:
|
||||
"""
|
||||
rewrites = self.programs.get(program_name)
|
||||
start = interval.start
|
||||
stop = interval.stop
|
||||
|
||||
# ensure start/end are in range
|
||||
if stop > len(self.tokens.tokens) - 1: stop = len(self.tokens.tokens) - 1
|
||||
if start < 0: start = 0
|
||||
if stop > len(self.tokens.tokens) - 1:
|
||||
stop = len(self.tokens.tokens) - 1
|
||||
if start < 0:
|
||||
start = 0
|
||||
|
||||
# if no instructions to execute
|
||||
if not rewrites: return self.tokens.getText(interval)
|
||||
if not rewrites: return self.tokens.getText(start, stop)
|
||||
buf = StringIO()
|
||||
indexToOp = self._reduceToSingleOperationPerIndex(rewrites)
|
||||
i = start
|
||||
|
|
|
@ -4,8 +4,6 @@
|
|||
|
||||
import unittest
|
||||
|
||||
from antlr4.IntervalSet import Interval
|
||||
|
||||
from mocks.TestLexer import TestLexer, TestLexer2
|
||||
from antlr4.TokenStreamRewriter import TokenStreamRewriter
|
||||
from antlr4.InputStream import InputStream
|
||||
|
@ -88,8 +86,8 @@ class TestTokenStreamRewriter(unittest.TestCase):
|
|||
rewriter.replaceRange(4, 8, '0')
|
||||
|
||||
self.assertEquals(rewriter.getDefaultText(), 'x = 0;')
|
||||
self.assertEquals(rewriter.getText('default', Interval(0, 9)), 'x = 0;')
|
||||
self.assertEquals(rewriter.getText('default', Interval(4, 8)), '0')
|
||||
self.assertEquals(rewriter.getText('default', 0, 9), 'x = 0;')
|
||||
self.assertEquals(rewriter.getText('default', 4, 8), '0')
|
||||
|
||||
def testToStringStartStop2(self):
|
||||
input = InputStream('x = 3 * 0 + 2 * 0;')
|
||||
|
@ -103,15 +101,15 @@ class TestTokenStreamRewriter(unittest.TestCase):
|
|||
# replace 3 * 0 with 0
|
||||
rewriter.replaceRange(4, 8, '0')
|
||||
self.assertEquals('x = 0 + 2 * 0;', rewriter.getDefaultText())
|
||||
self.assertEquals('x = 0 + 2 * 0;', rewriter.getText('default', Interval(0, 17)))
|
||||
self.assertEquals('0', rewriter.getText('default', Interval(4, 8)))
|
||||
self.assertEquals('x = 0', rewriter.getText('default', Interval(0, 8)))
|
||||
self.assertEquals('2 * 0', rewriter.getText('default', Interval(12, 16)))
|
||||
self.assertEquals('x = 0 + 2 * 0;', rewriter.getText('default', 0, 17))
|
||||
self.assertEquals('0', rewriter.getText('default', 4, 8))
|
||||
self.assertEquals('x = 0', rewriter.getText('default', 0, 8))
|
||||
self.assertEquals('2 * 0', rewriter.getText('default', 12, 16))
|
||||
|
||||
rewriter.insertAfter(17, "// comment")
|
||||
self.assertEquals('2 * 0;// comment', rewriter.getText('default', Interval(12, 18)))
|
||||
self.assertEquals('2 * 0;// comment', rewriter.getText('default', 12, 18))
|
||||
|
||||
self.assertEquals('x = 0', rewriter.getText('default', Interval(0, 8)))
|
||||
self.assertEquals('x = 0', rewriter.getText('default', 0, 8))
|
||||
|
||||
def test2ReplaceMiddleIndex(self):
|
||||
input = InputStream('abc')
|
||||
|
|
|
@ -272,21 +272,19 @@ class BufferedTokenStream(TokenStream):
|
|||
return self.tokenSource.getSourceName()
|
||||
|
||||
# Get the text of all tokens in this buffer.#/
|
||||
def getText(self, interval:tuple=None):
|
||||
def getText(self, start:int=None, stop:int=None):
|
||||
self.lazyInit()
|
||||
self.fill()
|
||||
if interval is None:
|
||||
interval = (0, len(self.tokens)-1)
|
||||
start = interval[0]
|
||||
if isinstance(start, Token):
|
||||
if start is None:
|
||||
start = 0
|
||||
elif isinstance(start, Token):
|
||||
start = start.tokenIndex
|
||||
stop = interval[1]
|
||||
if isinstance(stop, Token):
|
||||
if stop is None or stop >= len(self.tokens):
|
||||
stop = len(self.tokens) - 1
|
||||
elif isinstance(stop, Token):
|
||||
stop = stop.tokenIndex
|
||||
if start is None or stop is None or start<0 or stop<0:
|
||||
if start < 0 or stop < 0 or stop < start:
|
||||
return ""
|
||||
if stop >= len(self.tokens):
|
||||
stop = len(self.tokens)-1
|
||||
with StringIO() as buf:
|
||||
for i in range(start, stop+1):
|
||||
t = self.tokens[i]
|
||||
|
|
|
@ -96,23 +96,20 @@ class TokenStreamRewriter(object):
|
|||
def getProgram(self, program_name):
|
||||
return self.programs.setdefault(program_name, [])
|
||||
|
||||
def getText(self, program_name, interval):
|
||||
def getText(self, program_name, start:int, stop:int):
|
||||
"""
|
||||
:type interval: Interval.Interval
|
||||
:param program_name:
|
||||
:param interval:
|
||||
:return:
|
||||
"""
|
||||
rewrites = self.programs.get(program_name)
|
||||
start = interval.start
|
||||
stop = interval.stop
|
||||
|
||||
# ensure start/end are in range
|
||||
if stop > len(self.tokens.tokens) - 1: stop = len(self.tokens.tokens) - 1
|
||||
if start < 0: start = 0
|
||||
if stop > len(self.tokens.tokens) - 1:
|
||||
stop = len(self.tokens.tokens) - 1
|
||||
if start < 0:
|
||||
start = 0
|
||||
|
||||
# if no instructions to execute
|
||||
if not rewrites: return self.tokens.getText(interval)
|
||||
if not rewrites: return self.tokens.getText(start, stop)
|
||||
buf = StringIO()
|
||||
indexToOp = self._reduceToSingleOperationPerIndex(rewrites)
|
||||
i = start
|
||||
|
@ -149,12 +146,12 @@ class TokenStreamRewriter(object):
|
|||
prevReplaces = [op for op in rewrites[:i] if isinstance(op, TokenStreamRewriter.ReplaceOp)]
|
||||
for prevRop in prevReplaces:
|
||||
if all((prevRop.index >= rop.index, prevRop.last_index <= rop.last_index)):
|
||||
rewrites[prevRop.instructioIndex] = None
|
||||
rewrites[prevRop.instructionIndex] = None
|
||||
continue
|
||||
isDisjoint = any((prevRop.last_index<rop.index, prevRop.index>rop))
|
||||
isSame = all((prevRop.index == rop.index, prevRop.last_index == rop.last_index))
|
||||
if all((prevRop.text is None, rop.text is None, not isDisjoint)):
|
||||
rewrites[prevRop.instructioIndex] = None
|
||||
rewrites[prevRop.instructionIndex] = None
|
||||
rop.index = min(prevRop.index, rop.index)
|
||||
rop.last_index = min(prevRop.last_index, rop.last_index)
|
||||
print('New rop {}'.format(rop))
|
||||
|
|
Loading…
Reference in New Issue