From dd704c670543edc4672cc6114fe3d8da0dfbed7a Mon Sep 17 00:00:00 2001 From: Baptiste Mispelon Date: Fri, 13 Mar 2020 10:20:34 +0100 Subject: [PATCH] Refs #31340 -- Simplified SearchQuery by making it subclass Func. --- django/contrib/postgres/search.py | 41 ++++++++++++----------------- tests/postgres_tests/test_search.py | 14 ++++++---- 2 files changed, 26 insertions(+), 29 deletions(-) diff --git a/django/contrib/postgres/search.py b/django/contrib/postgres/search.py index c404797a60..90b6823575 100644 --- a/django/contrib/postgres/search.py +++ b/django/contrib/postgres/search.py @@ -157,7 +157,7 @@ class SearchQueryCombinable: return self._combine(other, self.BITAND, True) -class SearchQuery(SearchQueryCombinable, Value): +class SearchQuery(SearchQueryCombinable, Func): output_field = SearchQueryField() SEARCH_TYPES = { 'plain': 'plainto_tsquery', @@ -167,34 +167,27 @@ class SearchQuery(SearchQueryCombinable, Value): } def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'): - self.config = SearchConfig.from_parameter(config) - self.invert = invert - if search_type not in self.SEARCH_TYPES: + self.function = self.SEARCH_TYPES.get(search_type) + if self.function is None: raise ValueError("Unknown search_type argument '%s'." % search_type) - self.search_type = search_type - super().__init__(value, output_field=output_field) + value = Value(value) + expressions = (value,) + self.config = SearchConfig.from_parameter(config) + if self.config is not None: + expressions = (self.config,) + expressions + self.invert = invert + super().__init__(*expressions, output_field=output_field) - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): - resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) - if self.config: - resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save) - return resolved - - def as_sql(self, compiler, connection): - params = [self.value] - function = self.SEARCH_TYPES[self.search_type] - if self.config: - config_sql, config_params = compiler.compile(self.config) - template = '{}({}, %s)'.format(function, config_sql) - params = config_params + [self.value] - else: - template = '{}(%s)'.format(function) + def as_sql(self, compiler, connection, function=None, template=None): + sql, params = super().as_sql(compiler, connection, function, template) if self.invert: - template = '!!({})'.format(template) - return template, params + sql = '!!(%s)' % sql + return sql, params def __invert__(self): - return type(self)(self.value, config=self.config, invert=not self.invert) + clone = self.copy() + clone.invert = not self.invert + return clone def __str__(self): result = super().__str__() diff --git a/tests/postgres_tests/test_search.py b/tests/postgres_tests/test_search.py index 04adef4f17..86a409afc2 100644 --- a/tests/postgres_tests/test_search.py +++ b/tests/postgres_tests/test_search.py @@ -449,22 +449,26 @@ class SearchVectorIndexTests(PostgreSQLTestCase): class SearchQueryTests(PostgreSQLSimpleTestCase): def test_str(self): tests = ( - (~SearchQuery('a'), '~SearchQuery(a)'), + (~SearchQuery('a'), '~SearchQuery(Value(a))'), ( (SearchQuery('a') | SearchQuery('b')) & (SearchQuery('c') | SearchQuery('d')), - '((SearchQuery(a) || SearchQuery(b)) && (SearchQuery(c) || SearchQuery(d)))', + '((SearchQuery(Value(a)) || SearchQuery(Value(b))) && ' + '(SearchQuery(Value(c)) || SearchQuery(Value(d))))', ), ( SearchQuery('a') & (SearchQuery('b') | SearchQuery('c')), - '(SearchQuery(a) && (SearchQuery(b) || SearchQuery(c)))', + '(SearchQuery(Value(a)) && (SearchQuery(Value(b)) || ' + 'SearchQuery(Value(c))))', ), ( (SearchQuery('a') | SearchQuery('b')) & SearchQuery('c'), - '((SearchQuery(a) || SearchQuery(b)) && SearchQuery(c))' + '((SearchQuery(Value(a)) || SearchQuery(Value(b))) && ' + 'SearchQuery(Value(c)))' ), ( SearchQuery('a') & (SearchQuery('b') & (SearchQuery('c') | SearchQuery('d'))), - '(SearchQuery(a) && (SearchQuery(b) && (SearchQuery(c) || SearchQuery(d))))', + '(SearchQuery(Value(a)) && (SearchQuery(Value(b)) && ' + '(SearchQuery(Value(c)) || SearchQuery(Value(d)))))', ), ) for query, expected_str in tests: