Fixed #31211 -- Added SearchConfig expression.

Thanks Simon Charette for the review.
This commit is contained in:
Hannes Ljungberg 2020-01-10 22:17:09 +01:00 committed by Mariusz Felisiak
parent 958977f662
commit a69b6e006b
1 changed files with 33 additions and 14 deletions

View File

@ -1,5 +1,7 @@
from django.db.models import CharField, Field, FloatField, TextField
from django.db.models.expressions import CombinedExpression, Func, Value
from django.db.models.expressions import (
CombinedExpression, Expression, Func, Value,
)
from django.db.models.functions import Cast, Coalesce
from django.db.models.lookups import Lookup
@ -33,6 +35,24 @@ class SearchQueryField(Field):
return 'tsquery'
class SearchConfig(Expression):
def __init__(self, config):
super().__init__()
if not hasattr(config, 'resolve_expression'):
config = Value(config)
self.config = config
def get_source_expressions(self):
return [self.config]
def set_source_expressions(self, exprs):
self.config, = exprs
def as_sql(self, compiler, connection):
sql, params = compiler.compile(self.config)
return '%s::regconfig' % sql, params
class SearchVectorCombinable:
ADD = '||'
@ -55,7 +75,8 @@ class SearchVector(SearchVectorCombinable, Func):
def __init__(self, *expressions, **extra):
super().__init__(*expressions, **extra)
self.config = self.extra.get('config', self.config)
config = self.extra.get('config', self.config)
self.config = SearchConfig(config) if config else None
weight = self.extra.get('weight')
if weight is not None and not hasattr(weight, 'resolve_expression'):
weight = Value(weight)
@ -64,9 +85,6 @@ class SearchVector(SearchVectorCombinable, Func):
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
if self.config:
if not hasattr(self.config, 'resolve_expression'):
resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save)
else:
resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
return resolved
@ -80,14 +98,18 @@ class SearchVector(SearchVectorCombinable, Func):
Value('')
) for expression in clone.get_source_expressions()
])
config_sql = None
config_params = []
if template is None:
if clone.config:
config_sql, config_params = compiler.compile(clone.config)
template = '%(function)s({}::regconfig, %(expressions)s)'.format(config_sql.replace('%', '%%'))
template = '%(function)s(%(config)s, %(expressions)s)'
else:
template = clone.template
sql, params = super(SearchVector, clone).as_sql(compiler, connection, function=function, template=template)
sql, params = super(SearchVector, clone).as_sql(
compiler, connection, function=function, template=template,
config=config_sql,
)
extra_params = []
if clone.weight:
weight_sql, extra_params = compiler.compile(clone.weight)
@ -141,7 +163,7 @@ class SearchQuery(SearchQueryCombinable, Value):
}
def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'):
self.config = config
self.config = SearchConfig(config) if config else None
self.invert = invert
if search_type not in self.SEARCH_TYPES:
raise ValueError("Unknown search_type argument '%s'." % search_type)
@ -151,9 +173,6 @@ class SearchQuery(SearchQueryCombinable, Value):
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
if self.config:
if not hasattr(self.config, 'resolve_expression'):
resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save)
else:
resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
return resolved
@ -162,7 +181,7 @@ class SearchQuery(SearchQueryCombinable, Value):
function = self.SEARCH_TYPES[self.search_type]
if self.config:
config_sql, config_params = compiler.compile(self.config)
template = '{}({}::regconfig, %s)'.format(function, config_sql)
template = '{}({}, %s)'.format(function, config_sql)
params = config_params + [self.value]
else:
template = '{}(%s)'.format(function)