Fixed #31211 -- Added SearchConfig expression.

Thanks Simon Charette for the review.
This commit is contained in:
Hannes Ljungberg 2020-01-10 22:17:09 +01:00 committed by Mariusz Felisiak
parent 958977f662
commit a69b6e006b
1 changed files with 33 additions and 14 deletions

View File

@ -1,5 +1,7 @@
from django.db.models import CharField, Field, FloatField, TextField from django.db.models import CharField, Field, FloatField, TextField
from django.db.models.expressions import CombinedExpression, Func, Value from django.db.models.expressions import (
CombinedExpression, Expression, Func, Value,
)
from django.db.models.functions import Cast, Coalesce from django.db.models.functions import Cast, Coalesce
from django.db.models.lookups import Lookup from django.db.models.lookups import Lookup
@ -33,6 +35,24 @@ class SearchQueryField(Field):
return 'tsquery' return 'tsquery'
class SearchConfig(Expression):
def __init__(self, config):
super().__init__()
if not hasattr(config, 'resolve_expression'):
config = Value(config)
self.config = config
def get_source_expressions(self):
return [self.config]
def set_source_expressions(self, exprs):
self.config, = exprs
def as_sql(self, compiler, connection):
sql, params = compiler.compile(self.config)
return '%s::regconfig' % sql, params
class SearchVectorCombinable: class SearchVectorCombinable:
ADD = '||' ADD = '||'
@ -55,7 +75,8 @@ class SearchVector(SearchVectorCombinable, Func):
def __init__(self, *expressions, **extra): def __init__(self, *expressions, **extra):
super().__init__(*expressions, **extra) super().__init__(*expressions, **extra)
self.config = self.extra.get('config', self.config) config = self.extra.get('config', self.config)
self.config = SearchConfig(config) if config else None
weight = self.extra.get('weight') weight = self.extra.get('weight')
if weight is not None and not hasattr(weight, 'resolve_expression'): if weight is not None and not hasattr(weight, 'resolve_expression'):
weight = Value(weight) weight = Value(weight)
@ -64,10 +85,7 @@ class SearchVector(SearchVectorCombinable, Func):
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
if self.config: if self.config:
if not hasattr(self.config, 'resolve_expression'): resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save)
else:
resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
return resolved return resolved
def as_sql(self, compiler, connection, function=None, template=None): def as_sql(self, compiler, connection, function=None, template=None):
@ -80,14 +98,18 @@ class SearchVector(SearchVectorCombinable, Func):
Value('') Value('')
) for expression in clone.get_source_expressions() ) for expression in clone.get_source_expressions()
]) ])
config_sql = None
config_params = [] config_params = []
if template is None: if template is None:
if clone.config: if clone.config:
config_sql, config_params = compiler.compile(clone.config) config_sql, config_params = compiler.compile(clone.config)
template = '%(function)s({}::regconfig, %(expressions)s)'.format(config_sql.replace('%', '%%')) template = '%(function)s(%(config)s, %(expressions)s)'
else: else:
template = clone.template template = clone.template
sql, params = super(SearchVector, clone).as_sql(compiler, connection, function=function, template=template) sql, params = super(SearchVector, clone).as_sql(
compiler, connection, function=function, template=template,
config=config_sql,
)
extra_params = [] extra_params = []
if clone.weight: if clone.weight:
weight_sql, extra_params = compiler.compile(clone.weight) weight_sql, extra_params = compiler.compile(clone.weight)
@ -141,7 +163,7 @@ class SearchQuery(SearchQueryCombinable, Value):
} }
def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'): def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'):
self.config = config self.config = SearchConfig(config) if config else None
self.invert = invert self.invert = invert
if search_type not in self.SEARCH_TYPES: if search_type not in self.SEARCH_TYPES:
raise ValueError("Unknown search_type argument '%s'." % search_type) raise ValueError("Unknown search_type argument '%s'." % search_type)
@ -151,10 +173,7 @@ class SearchQuery(SearchQueryCombinable, Value):
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
if self.config: if self.config:
if not hasattr(self.config, 'resolve_expression'): resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save)
else:
resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
return resolved return resolved
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
@ -162,7 +181,7 @@ class SearchQuery(SearchQueryCombinable, Value):
function = self.SEARCH_TYPES[self.search_type] function = self.SEARCH_TYPES[self.search_type]
if self.config: if self.config:
config_sql, config_params = compiler.compile(self.config) config_sql, config_params = compiler.compile(self.config)
template = '{}({}::regconfig, %s)'.format(function, config_sql) template = '{}({}, %s)'.format(function, config_sql)
params = config_params + [self.value] params = config_params + [self.value]
else: else:
template = '{}(%s)'.format(function) template = '{}(%s)'.format(function)