Refs #31340 -- Simplified SearchQuery by making it subclass Func.

This commit is contained in:
Baptiste Mispelon 2020-03-13 10:20:34 +01:00 committed by Mariusz Felisiak
parent b62c58d5fc
commit dd704c6705
2 changed files with 26 additions and 29 deletions

View File

@ -157,7 +157,7 @@ class SearchQueryCombinable:
return self._combine(other, self.BITAND, True) return self._combine(other, self.BITAND, True)
class SearchQuery(SearchQueryCombinable, Value): class SearchQuery(SearchQueryCombinable, Func):
output_field = SearchQueryField() output_field = SearchQueryField()
SEARCH_TYPES = { SEARCH_TYPES = {
'plain': 'plainto_tsquery', 'plain': 'plainto_tsquery',
@ -167,34 +167,27 @@ class SearchQuery(SearchQueryCombinable, Value):
} }
def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'): def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'):
self.config = SearchConfig.from_parameter(config) self.function = self.SEARCH_TYPES.get(search_type)
self.invert = invert if self.function is None:
if search_type not in self.SEARCH_TYPES:
raise ValueError("Unknown search_type argument '%s'." % search_type) raise ValueError("Unknown search_type argument '%s'." % search_type)
self.search_type = search_type value = Value(value)
super().__init__(value, output_field=output_field) expressions = (value,)
self.config = SearchConfig.from_parameter(config)
if self.config is not None:
expressions = (self.config,) + expressions
self.invert = invert
super().__init__(*expressions, output_field=output_field)
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): def as_sql(self, compiler, connection, function=None, template=None):
resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) sql, params = super().as_sql(compiler, connection, function, template)
if self.config:
resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
return resolved
def as_sql(self, compiler, connection):
params = [self.value]
function = self.SEARCH_TYPES[self.search_type]
if self.config:
config_sql, config_params = compiler.compile(self.config)
template = '{}({}, %s)'.format(function, config_sql)
params = config_params + [self.value]
else:
template = '{}(%s)'.format(function)
if self.invert: if self.invert:
template = '!!({})'.format(template) sql = '!!(%s)' % sql
return template, params return sql, params
def __invert__(self): def __invert__(self):
return type(self)(self.value, config=self.config, invert=not self.invert) clone = self.copy()
clone.invert = not self.invert
return clone
def __str__(self): def __str__(self):
result = super().__str__() result = super().__str__()

View File

@ -449,22 +449,26 @@ class SearchVectorIndexTests(PostgreSQLTestCase):
class SearchQueryTests(PostgreSQLSimpleTestCase): class SearchQueryTests(PostgreSQLSimpleTestCase):
def test_str(self): def test_str(self):
tests = ( tests = (
(~SearchQuery('a'), '~SearchQuery(a)'), (~SearchQuery('a'), '~SearchQuery(Value(a))'),
( (
(SearchQuery('a') | SearchQuery('b')) & (SearchQuery('c') | SearchQuery('d')), (SearchQuery('a') | SearchQuery('b')) & (SearchQuery('c') | SearchQuery('d')),
'((SearchQuery(a) || SearchQuery(b)) && (SearchQuery(c) || SearchQuery(d)))', '((SearchQuery(Value(a)) || SearchQuery(Value(b))) && '
'(SearchQuery(Value(c)) || SearchQuery(Value(d))))',
), ),
( (
SearchQuery('a') & (SearchQuery('b') | SearchQuery('c')), SearchQuery('a') & (SearchQuery('b') | SearchQuery('c')),
'(SearchQuery(a) && (SearchQuery(b) || SearchQuery(c)))', '(SearchQuery(Value(a)) && (SearchQuery(Value(b)) || '
'SearchQuery(Value(c))))',
), ),
( (
(SearchQuery('a') | SearchQuery('b')) & SearchQuery('c'), (SearchQuery('a') | SearchQuery('b')) & SearchQuery('c'),
'((SearchQuery(a) || SearchQuery(b)) && SearchQuery(c))' '((SearchQuery(Value(a)) || SearchQuery(Value(b))) && '
'SearchQuery(Value(c)))'
), ),
( (
SearchQuery('a') & (SearchQuery('b') & (SearchQuery('c') | SearchQuery('d'))), SearchQuery('a') & (SearchQuery('b') & (SearchQuery('c') | SearchQuery('d'))),
'(SearchQuery(a) && (SearchQuery(b) && (SearchQuery(c) || SearchQuery(d))))', '(SearchQuery(Value(a)) && (SearchQuery(Value(b)) && '
'(SearchQuery(Value(c)) || SearchQuery(Value(d)))))',
), ),
) )
for query, expected_str in tests: for query, expected_str in tests: