Fixed #31211 -- Added SearchConfig expression.
Thanks Simon Charette for the review.
This commit is contained in:
parent
958977f662
commit
a69b6e006b
|
@ -1,5 +1,7 @@
|
|||
from django.db.models import CharField, Field, FloatField, TextField
|
||||
from django.db.models.expressions import CombinedExpression, Func, Value
|
||||
from django.db.models.expressions import (
|
||||
CombinedExpression, Expression, Func, Value,
|
||||
)
|
||||
from django.db.models.functions import Cast, Coalesce
|
||||
from django.db.models.lookups import Lookup
|
||||
|
||||
|
@ -33,6 +35,24 @@ class SearchQueryField(Field):
|
|||
return 'tsquery'
|
||||
|
||||
|
||||
class SearchConfig(Expression):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
if not hasattr(config, 'resolve_expression'):
|
||||
config = Value(config)
|
||||
self.config = config
|
||||
|
||||
def get_source_expressions(self):
|
||||
return [self.config]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.config, = exprs
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
sql, params = compiler.compile(self.config)
|
||||
return '%s::regconfig' % sql, params
|
||||
|
||||
|
||||
class SearchVectorCombinable:
|
||||
ADD = '||'
|
||||
|
||||
|
@ -55,7 +75,8 @@ class SearchVector(SearchVectorCombinable, Func):
|
|||
|
||||
def __init__(self, *expressions, **extra):
|
||||
super().__init__(*expressions, **extra)
|
||||
self.config = self.extra.get('config', self.config)
|
||||
config = self.extra.get('config', self.config)
|
||||
self.config = SearchConfig(config) if config else None
|
||||
weight = self.extra.get('weight')
|
||||
if weight is not None and not hasattr(weight, 'resolve_expression'):
|
||||
weight = Value(weight)
|
||||
|
@ -64,10 +85,7 @@ class SearchVector(SearchVectorCombinable, Func):
|
|||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
if self.config:
|
||||
if not hasattr(self.config, 'resolve_expression'):
|
||||
resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
else:
|
||||
resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
return resolved
|
||||
|
||||
def as_sql(self, compiler, connection, function=None, template=None):
|
||||
|
@ -80,14 +98,18 @@ class SearchVector(SearchVectorCombinable, Func):
|
|||
Value('')
|
||||
) for expression in clone.get_source_expressions()
|
||||
])
|
||||
config_sql = None
|
||||
config_params = []
|
||||
if template is None:
|
||||
if clone.config:
|
||||
config_sql, config_params = compiler.compile(clone.config)
|
||||
template = '%(function)s({}::regconfig, %(expressions)s)'.format(config_sql.replace('%', '%%'))
|
||||
template = '%(function)s(%(config)s, %(expressions)s)'
|
||||
else:
|
||||
template = clone.template
|
||||
sql, params = super(SearchVector, clone).as_sql(compiler, connection, function=function, template=template)
|
||||
sql, params = super(SearchVector, clone).as_sql(
|
||||
compiler, connection, function=function, template=template,
|
||||
config=config_sql,
|
||||
)
|
||||
extra_params = []
|
||||
if clone.weight:
|
||||
weight_sql, extra_params = compiler.compile(clone.weight)
|
||||
|
@ -141,7 +163,7 @@ class SearchQuery(SearchQueryCombinable, Value):
|
|||
}
|
||||
|
||||
def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'):
|
||||
self.config = config
|
||||
self.config = SearchConfig(config) if config else None
|
||||
self.invert = invert
|
||||
if search_type not in self.SEARCH_TYPES:
|
||||
raise ValueError("Unknown search_type argument '%s'." % search_type)
|
||||
|
@ -151,10 +173,7 @@ class SearchQuery(SearchQueryCombinable, Value):
|
|||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
if self.config:
|
||||
if not hasattr(self.config, 'resolve_expression'):
|
||||
resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
else:
|
||||
resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
return resolved
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
|
@ -162,7 +181,7 @@ class SearchQuery(SearchQueryCombinable, Value):
|
|||
function = self.SEARCH_TYPES[self.search_type]
|
||||
if self.config:
|
||||
config_sql, config_params = compiler.compile(self.config)
|
||||
template = '{}({}::regconfig, %s)'.format(function, config_sql)
|
||||
template = '{}({}, %s)'.format(function, config_sql)
|
||||
params = config_params + [self.value]
|
||||
else:
|
||||
template = '{}(%s)'.format(function)
|
||||
|
|
Loading…
Reference in New Issue