forked from jasder/antlr
Fix API inconsistency, where getText would expect a tuple named interval, leading to lots of confusion
This commit is contained in:
parent
108854f986
commit
cccf6e87da
|
@ -269,36 +269,32 @@ class BufferedTokenStream(TokenStream):
|
||||||
def getSourceName(self):
|
def getSourceName(self):
|
||||||
return self.tokenSource.getSourceName()
|
return self.tokenSource.getSourceName()
|
||||||
|
|
||||||
def getText(self, interval=None):
|
def getText(self, start=None, stop=None):
|
||||||
"""
|
"""
|
||||||
Get the text of all tokens in this buffer.
|
Get the text of all tokens in this buffer.
|
||||||
|
|
||||||
:param interval:
|
|
||||||
:type interval: antlr4.IntervalSet.Interval
|
|
||||||
:return: string
|
:return: string
|
||||||
"""
|
"""
|
||||||
self.lazyInit()
|
self.lazyInit()
|
||||||
self.fill()
|
self.fill()
|
||||||
if interval is None:
|
if start is None:
|
||||||
interval = (0, len(self.tokens)-1)
|
start = 0
|
||||||
start = interval[0]
|
elif isinstance(start, Token):
|
||||||
if isinstance(start, Token):
|
|
||||||
start = start.tokenIndex
|
start = start.tokenIndex
|
||||||
stop = interval[1]
|
if stop is None or stop >= len(self.tokens):
|
||||||
if isinstance(stop, Token):
|
|
||||||
stop = stop.tokenIndex
|
|
||||||
if start is None or stop is None or start<0 or stop<0:
|
|
||||||
return ""
|
|
||||||
if stop >= len(self.tokens):
|
|
||||||
stop = len(self.tokens) - 1
|
stop = len(self.tokens) - 1
|
||||||
|
elif isinstance(stop, Token):
|
||||||
|
stop = stop.tokenIndex
|
||||||
|
if start < 0 or stop < 0 or stop<start:
|
||||||
|
return ""
|
||||||
with StringIO() as buf:
|
with StringIO() as buf:
|
||||||
for i in range(start, stop+1):
|
for i in xrange(start, stop+1):
|
||||||
t = self.tokens[i]
|
t = self.tokens[i]
|
||||||
if t.type==Token.EOF:
|
if t.type==Token.EOF:
|
||||||
break
|
break
|
||||||
buf.write(t.text)
|
buf.write(t.text)
|
||||||
return buf.getvalue()
|
return buf.getvalue()
|
||||||
|
|
||||||
|
|
||||||
def fill(self):
|
def fill(self):
|
||||||
"""
|
"""
|
||||||
Get all tokens from lexer until EOF
|
Get all tokens from lexer until EOF
|
||||||
|
|
|
@ -100,9 +100,9 @@ class TokenStreamRewriter(object):
|
||||||
return self.programs.setdefault(program_name, [])
|
return self.programs.setdefault(program_name, [])
|
||||||
|
|
||||||
def getDefaultText(self):
|
def getDefaultText(self):
|
||||||
return self.getText(self.DEFAULT_PROGRAM_NAME, Interval(0, len(self.tokens.tokens)))
|
return self.getText(self.DEFAULT_PROGRAM_NAME, 0, len(self.tokens.tokens))
|
||||||
|
|
||||||
def getText(self, program_name, interval):
|
def getText(self, program_name, start, stop):
|
||||||
"""
|
"""
|
||||||
:type interval: Interval.Interval
|
:type interval: Interval.Interval
|
||||||
:param program_name:
|
:param program_name:
|
||||||
|
@ -110,15 +110,15 @@ class TokenStreamRewriter(object):
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
rewrites = self.programs.get(program_name)
|
rewrites = self.programs.get(program_name)
|
||||||
start = interval.start
|
|
||||||
stop = interval.stop
|
|
||||||
|
|
||||||
# ensure start/end are in range
|
# ensure start/end are in range
|
||||||
if stop > len(self.tokens.tokens) - 1: stop = len(self.tokens.tokens) - 1
|
if stop > len(self.tokens.tokens) - 1:
|
||||||
if start < 0: start = 0
|
stop = len(self.tokens.tokens) - 1
|
||||||
|
if start < 0:
|
||||||
|
start = 0
|
||||||
|
|
||||||
# if no instructions to execute
|
# if no instructions to execute
|
||||||
if not rewrites: return self.tokens.getText(interval)
|
if not rewrites: return self.tokens.getText(start, stop)
|
||||||
buf = StringIO()
|
buf = StringIO()
|
||||||
indexToOp = self._reduceToSingleOperationPerIndex(rewrites)
|
indexToOp = self._reduceToSingleOperationPerIndex(rewrites)
|
||||||
i = start
|
i = start
|
||||||
|
|
|
@ -4,8 +4,6 @@
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from antlr4.IntervalSet import Interval
|
|
||||||
|
|
||||||
from mocks.TestLexer import TestLexer, TestLexer2
|
from mocks.TestLexer import TestLexer, TestLexer2
|
||||||
from antlr4.TokenStreamRewriter import TokenStreamRewriter
|
from antlr4.TokenStreamRewriter import TokenStreamRewriter
|
||||||
from antlr4.InputStream import InputStream
|
from antlr4.InputStream import InputStream
|
||||||
|
@ -88,8 +86,8 @@ class TestTokenStreamRewriter(unittest.TestCase):
|
||||||
rewriter.replaceRange(4, 8, '0')
|
rewriter.replaceRange(4, 8, '0')
|
||||||
|
|
||||||
self.assertEquals(rewriter.getDefaultText(), 'x = 0;')
|
self.assertEquals(rewriter.getDefaultText(), 'x = 0;')
|
||||||
self.assertEquals(rewriter.getText('default', Interval(0, 9)), 'x = 0;')
|
self.assertEquals(rewriter.getText('default', 0, 9), 'x = 0;')
|
||||||
self.assertEquals(rewriter.getText('default', Interval(4, 8)), '0')
|
self.assertEquals(rewriter.getText('default', 4, 8), '0')
|
||||||
|
|
||||||
def testToStringStartStop2(self):
|
def testToStringStartStop2(self):
|
||||||
input = InputStream('x = 3 * 0 + 2 * 0;')
|
input = InputStream('x = 3 * 0 + 2 * 0;')
|
||||||
|
@ -103,15 +101,15 @@ class TestTokenStreamRewriter(unittest.TestCase):
|
||||||
# replace 3 * 0 with 0
|
# replace 3 * 0 with 0
|
||||||
rewriter.replaceRange(4, 8, '0')
|
rewriter.replaceRange(4, 8, '0')
|
||||||
self.assertEquals('x = 0 + 2 * 0;', rewriter.getDefaultText())
|
self.assertEquals('x = 0 + 2 * 0;', rewriter.getDefaultText())
|
||||||
self.assertEquals('x = 0 + 2 * 0;', rewriter.getText('default', Interval(0, 17)))
|
self.assertEquals('x = 0 + 2 * 0;', rewriter.getText('default', 0, 17))
|
||||||
self.assertEquals('0', rewriter.getText('default', Interval(4, 8)))
|
self.assertEquals('0', rewriter.getText('default', 4, 8))
|
||||||
self.assertEquals('x = 0', rewriter.getText('default', Interval(0, 8)))
|
self.assertEquals('x = 0', rewriter.getText('default', 0, 8))
|
||||||
self.assertEquals('2 * 0', rewriter.getText('default', Interval(12, 16)))
|
self.assertEquals('2 * 0', rewriter.getText('default', 12, 16))
|
||||||
|
|
||||||
rewriter.insertAfter(17, "// comment")
|
rewriter.insertAfter(17, "// comment")
|
||||||
self.assertEquals('2 * 0;// comment', rewriter.getText('default', Interval(12, 18)))
|
self.assertEquals('2 * 0;// comment', rewriter.getText('default', 12, 18))
|
||||||
|
|
||||||
self.assertEquals('x = 0', rewriter.getText('default', Interval(0, 8)))
|
self.assertEquals('x = 0', rewriter.getText('default', 0, 8))
|
||||||
|
|
||||||
def test2ReplaceMiddleIndex(self):
|
def test2ReplaceMiddleIndex(self):
|
||||||
input = InputStream('abc')
|
input = InputStream('abc')
|
||||||
|
|
|
@ -272,21 +272,19 @@ class BufferedTokenStream(TokenStream):
|
||||||
return self.tokenSource.getSourceName()
|
return self.tokenSource.getSourceName()
|
||||||
|
|
||||||
# Get the text of all tokens in this buffer.#/
|
# Get the text of all tokens in this buffer.#/
|
||||||
def getText(self, interval:tuple=None):
|
def getText(self, start:int=None, stop:int=None):
|
||||||
self.lazyInit()
|
self.lazyInit()
|
||||||
self.fill()
|
self.fill()
|
||||||
if interval is None:
|
if start is None:
|
||||||
interval = (0, len(self.tokens)-1)
|
start = 0
|
||||||
start = interval[0]
|
elif isinstance(start, Token):
|
||||||
if isinstance(start, Token):
|
|
||||||
start = start.tokenIndex
|
start = start.tokenIndex
|
||||||
stop = interval[1]
|
if stop is None or stop >= len(self.tokens):
|
||||||
if isinstance(stop, Token):
|
|
||||||
stop = stop.tokenIndex
|
|
||||||
if start is None or stop is None or start<0 or stop<0:
|
|
||||||
return ""
|
|
||||||
if stop >= len(self.tokens):
|
|
||||||
stop = len(self.tokens) - 1
|
stop = len(self.tokens) - 1
|
||||||
|
elif isinstance(stop, Token):
|
||||||
|
stop = stop.tokenIndex
|
||||||
|
if start < 0 or stop < 0 or stop < start:
|
||||||
|
return ""
|
||||||
with StringIO() as buf:
|
with StringIO() as buf:
|
||||||
for i in range(start, stop+1):
|
for i in range(start, stop+1):
|
||||||
t = self.tokens[i]
|
t = self.tokens[i]
|
||||||
|
|
|
@ -96,23 +96,20 @@ class TokenStreamRewriter(object):
|
||||||
def getProgram(self, program_name):
|
def getProgram(self, program_name):
|
||||||
return self.programs.setdefault(program_name, [])
|
return self.programs.setdefault(program_name, [])
|
||||||
|
|
||||||
def getText(self, program_name, interval):
|
def getText(self, program_name, start:int, stop:int):
|
||||||
"""
|
"""
|
||||||
:type interval: Interval.Interval
|
|
||||||
:param program_name:
|
|
||||||
:param interval:
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
rewrites = self.programs.get(program_name)
|
rewrites = self.programs.get(program_name)
|
||||||
start = interval.start
|
|
||||||
stop = interval.stop
|
|
||||||
|
|
||||||
# ensure start/end are in range
|
# ensure start/end are in range
|
||||||
if stop > len(self.tokens.tokens) - 1: stop = len(self.tokens.tokens) - 1
|
if stop > len(self.tokens.tokens) - 1:
|
||||||
if start < 0: start = 0
|
stop = len(self.tokens.tokens) - 1
|
||||||
|
if start < 0:
|
||||||
|
start = 0
|
||||||
|
|
||||||
# if no instructions to execute
|
# if no instructions to execute
|
||||||
if not rewrites: return self.tokens.getText(interval)
|
if not rewrites: return self.tokens.getText(start, stop)
|
||||||
buf = StringIO()
|
buf = StringIO()
|
||||||
indexToOp = self._reduceToSingleOperationPerIndex(rewrites)
|
indexToOp = self._reduceToSingleOperationPerIndex(rewrites)
|
||||||
i = start
|
i = start
|
||||||
|
@ -149,12 +146,12 @@ class TokenStreamRewriter(object):
|
||||||
prevReplaces = [op for op in rewrites[:i] if isinstance(op, TokenStreamRewriter.ReplaceOp)]
|
prevReplaces = [op for op in rewrites[:i] if isinstance(op, TokenStreamRewriter.ReplaceOp)]
|
||||||
for prevRop in prevReplaces:
|
for prevRop in prevReplaces:
|
||||||
if all((prevRop.index >= rop.index, prevRop.last_index <= rop.last_index)):
|
if all((prevRop.index >= rop.index, prevRop.last_index <= rop.last_index)):
|
||||||
rewrites[prevRop.instructioIndex] = None
|
rewrites[prevRop.instructionIndex] = None
|
||||||
continue
|
continue
|
||||||
isDisjoint = any((prevRop.last_index<rop.index, prevRop.index>rop))
|
isDisjoint = any((prevRop.last_index<rop.index, prevRop.index>rop))
|
||||||
isSame = all((prevRop.index == rop.index, prevRop.last_index == rop.last_index))
|
isSame = all((prevRop.index == rop.index, prevRop.last_index == rop.last_index))
|
||||||
if all((prevRop.text is None, rop.text is None, not isDisjoint)):
|
if all((prevRop.text is None, rop.text is None, not isDisjoint)):
|
||||||
rewrites[prevRop.instructioIndex] = None
|
rewrites[prevRop.instructionIndex] = None
|
||||||
rop.index = min(prevRop.index, rop.index)
|
rop.index = min(prevRop.index, rop.index)
|
||||||
rop.last_index = min(prevRop.last_index, rop.last_index)
|
rop.last_index = min(prevRop.last_index, rop.last_index)
|
||||||
print('New rop {}'.format(rop))
|
print('New rop {}'.format(rop))
|
||||||
|
|
Loading…
Reference in New Issue