Refs #31340 -- Simplified SearchQuery by making it subclass Func.

This commit is contained in:
Baptiste Mispelon 2020-03-13 10:20:34 +01:00 committed by Mariusz Felisiak
parent b62c58d5fc
commit dd704c6705
2 changed files with 26 additions and 29 deletions

View File

@ -157,7 +157,7 @@ class SearchQueryCombinable:
return self._combine(other, self.BITAND, True)
class SearchQuery(SearchQueryCombinable, Value):
class SearchQuery(SearchQueryCombinable, Func):
output_field = SearchQueryField()
SEARCH_TYPES = {
'plain': 'plainto_tsquery',
@ -167,34 +167,27 @@ class SearchQuery(SearchQueryCombinable, Value):
}
def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'):
self.config = SearchConfig.from_parameter(config)
self.invert = invert
if search_type not in self.SEARCH_TYPES:
self.function = self.SEARCH_TYPES.get(search_type)
if self.function is None:
raise ValueError("Unknown search_type argument '%s'." % search_type)
self.search_type = search_type
super().__init__(value, output_field=output_field)
value = Value(value)
expressions = (value,)
self.config = SearchConfig.from_parameter(config)
if self.config is not None:
expressions = (self.config,) + expressions
self.invert = invert
super().__init__(*expressions, output_field=output_field)
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
if self.config:
resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
return resolved
def as_sql(self, compiler, connection):
params = [self.value]
function = self.SEARCH_TYPES[self.search_type]
if self.config:
config_sql, config_params = compiler.compile(self.config)
template = '{}({}, %s)'.format(function, config_sql)
params = config_params + [self.value]
else:
template = '{}(%s)'.format(function)
def as_sql(self, compiler, connection, function=None, template=None):
sql, params = super().as_sql(compiler, connection, function, template)
if self.invert:
template = '!!({})'.format(template)
return template, params
sql = '!!(%s)' % sql
return sql, params
def __invert__(self):
return type(self)(self.value, config=self.config, invert=not self.invert)
clone = self.copy()
clone.invert = not self.invert
return clone
def __str__(self):
result = super().__str__()

View File

@ -449,22 +449,26 @@ class SearchVectorIndexTests(PostgreSQLTestCase):
class SearchQueryTests(PostgreSQLSimpleTestCase):
def test_str(self):
tests = (
(~SearchQuery('a'), '~SearchQuery(a)'),
(~SearchQuery('a'), '~SearchQuery(Value(a))'),
(
(SearchQuery('a') | SearchQuery('b')) & (SearchQuery('c') | SearchQuery('d')),
'((SearchQuery(a) || SearchQuery(b)) && (SearchQuery(c) || SearchQuery(d)))',
'((SearchQuery(Value(a)) || SearchQuery(Value(b))) && '
'(SearchQuery(Value(c)) || SearchQuery(Value(d))))',
),
(
SearchQuery('a') & (SearchQuery('b') | SearchQuery('c')),
'(SearchQuery(a) && (SearchQuery(b) || SearchQuery(c)))',
'(SearchQuery(Value(a)) && (SearchQuery(Value(b)) || '
'SearchQuery(Value(c))))',
),
(
(SearchQuery('a') | SearchQuery('b')) & SearchQuery('c'),
'((SearchQuery(a) || SearchQuery(b)) && SearchQuery(c))'
'((SearchQuery(Value(a)) || SearchQuery(Value(b))) && '
'SearchQuery(Value(c)))'
),
(
SearchQuery('a') & (SearchQuery('b') & (SearchQuery('c') | SearchQuery('d'))),
'(SearchQuery(a) && (SearchQuery(b) && (SearchQuery(c) || SearchQuery(d))))',
'(SearchQuery(Value(a)) && (SearchQuery(Value(b)) && '
'(SearchQuery(Value(c)) || SearchQuery(Value(d)))))',
),
)
for query, expected_str in tests: