diff --git a/django/contrib/postgres/search.py b/django/contrib/postgres/search.py index ff0e9cf3b5..1b96b26513 100644 --- a/django/contrib/postgres/search.py +++ b/django/contrib/postgres/search.py @@ -1,5 +1,7 @@ from django.db.models import CharField, Field, FloatField, TextField -from django.db.models.expressions import CombinedExpression, Func, Value +from django.db.models.expressions import ( + CombinedExpression, Expression, Func, Value, +) from django.db.models.functions import Cast, Coalesce from django.db.models.lookups import Lookup @@ -33,6 +35,24 @@ class SearchQueryField(Field): return 'tsquery' +class SearchConfig(Expression): + def __init__(self, config): + super().__init__() + if not hasattr(config, 'resolve_expression'): + config = Value(config) + self.config = config + + def get_source_expressions(self): + return [self.config] + + def set_source_expressions(self, exprs): + self.config, = exprs + + def as_sql(self, compiler, connection): + sql, params = compiler.compile(self.config) + return '%s::regconfig' % sql, params + + class SearchVectorCombinable: ADD = '||' @@ -55,7 +75,8 @@ class SearchVector(SearchVectorCombinable, Func): def __init__(self, *expressions, **extra): super().__init__(*expressions, **extra) - self.config = self.extra.get('config', self.config) + config = self.extra.get('config', self.config) + self.config = SearchConfig(config) if config else None weight = self.extra.get('weight') if weight is not None and not hasattr(weight, 'resolve_expression'): weight = Value(weight) @@ -64,10 +85,7 @@ class SearchVector(SearchVectorCombinable, Func): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) if self.config: - if not hasattr(self.config, 'resolve_expression'): - resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save) - else: - resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save) + resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save) return resolved def as_sql(self, compiler, connection, function=None, template=None): @@ -80,14 +98,18 @@ class SearchVector(SearchVectorCombinable, Func): Value('') ) for expression in clone.get_source_expressions() ]) + config_sql = None config_params = [] if template is None: if clone.config: config_sql, config_params = compiler.compile(clone.config) - template = '%(function)s({}::regconfig, %(expressions)s)'.format(config_sql.replace('%', '%%')) + template = '%(function)s(%(config)s, %(expressions)s)' else: template = clone.template - sql, params = super(SearchVector, clone).as_sql(compiler, connection, function=function, template=template) + sql, params = super(SearchVector, clone).as_sql( + compiler, connection, function=function, template=template, + config=config_sql, + ) extra_params = [] if clone.weight: weight_sql, extra_params = compiler.compile(clone.weight) @@ -141,7 +163,7 @@ class SearchQuery(SearchQueryCombinable, Value): } def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'): - self.config = config + self.config = SearchConfig(config) if config else None self.invert = invert if search_type not in self.SEARCH_TYPES: raise ValueError("Unknown search_type argument '%s'." % search_type) @@ -151,10 +173,7 @@ class SearchQuery(SearchQueryCombinable, Value): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) if self.config: - if not hasattr(self.config, 'resolve_expression'): - resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save) - else: - resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save) + resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save) return resolved def as_sql(self, compiler, connection): @@ -162,7 +181,7 @@ class SearchQuery(SearchQueryCombinable, Value): function = self.SEARCH_TYPES[self.search_type] if self.config: config_sql, config_params = compiler.compile(self.config) - template = '{}({}::regconfig, %s)'.format(function, config_sql) + template = '{}({}, %s)'.format(function, config_sql) params = config_params + [self.value] else: template = '{}(%s)'.format(function)