Fixed #31211 -- Added SearchConfig expression.
Thanks Simon Charette for the review.
This commit is contained in:
parent
958977f662
commit
a69b6e006b
|
@ -1,5 +1,7 @@
|
||||||
from django.db.models import CharField, Field, FloatField, TextField
|
from django.db.models import CharField, Field, FloatField, TextField
|
||||||
from django.db.models.expressions import CombinedExpression, Func, Value
|
from django.db.models.expressions import (
|
||||||
|
CombinedExpression, Expression, Func, Value,
|
||||||
|
)
|
||||||
from django.db.models.functions import Cast, Coalesce
|
from django.db.models.functions import Cast, Coalesce
|
||||||
from django.db.models.lookups import Lookup
|
from django.db.models.lookups import Lookup
|
||||||
|
|
||||||
|
@ -33,6 +35,24 @@ class SearchQueryField(Field):
|
||||||
return 'tsquery'
|
return 'tsquery'
|
||||||
|
|
||||||
|
|
||||||
|
class SearchConfig(Expression):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
if not hasattr(config, 'resolve_expression'):
|
||||||
|
config = Value(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def get_source_expressions(self):
|
||||||
|
return [self.config]
|
||||||
|
|
||||||
|
def set_source_expressions(self, exprs):
|
||||||
|
self.config, = exprs
|
||||||
|
|
||||||
|
def as_sql(self, compiler, connection):
|
||||||
|
sql, params = compiler.compile(self.config)
|
||||||
|
return '%s::regconfig' % sql, params
|
||||||
|
|
||||||
|
|
||||||
class SearchVectorCombinable:
|
class SearchVectorCombinable:
|
||||||
ADD = '||'
|
ADD = '||'
|
||||||
|
|
||||||
|
@ -55,7 +75,8 @@ class SearchVector(SearchVectorCombinable, Func):
|
||||||
|
|
||||||
def __init__(self, *expressions, **extra):
|
def __init__(self, *expressions, **extra):
|
||||||
super().__init__(*expressions, **extra)
|
super().__init__(*expressions, **extra)
|
||||||
self.config = self.extra.get('config', self.config)
|
config = self.extra.get('config', self.config)
|
||||||
|
self.config = SearchConfig(config) if config else None
|
||||||
weight = self.extra.get('weight')
|
weight = self.extra.get('weight')
|
||||||
if weight is not None and not hasattr(weight, 'resolve_expression'):
|
if weight is not None and not hasattr(weight, 'resolve_expression'):
|
||||||
weight = Value(weight)
|
weight = Value(weight)
|
||||||
|
@ -64,10 +85,7 @@ class SearchVector(SearchVectorCombinable, Func):
|
||||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||||
resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||||
if self.config:
|
if self.config:
|
||||||
if not hasattr(self.config, 'resolve_expression'):
|
resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||||
resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
|
||||||
else:
|
|
||||||
resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
|
||||||
return resolved
|
return resolved
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, function=None, template=None):
|
def as_sql(self, compiler, connection, function=None, template=None):
|
||||||
|
@ -80,14 +98,18 @@ class SearchVector(SearchVectorCombinable, Func):
|
||||||
Value('')
|
Value('')
|
||||||
) for expression in clone.get_source_expressions()
|
) for expression in clone.get_source_expressions()
|
||||||
])
|
])
|
||||||
|
config_sql = None
|
||||||
config_params = []
|
config_params = []
|
||||||
if template is None:
|
if template is None:
|
||||||
if clone.config:
|
if clone.config:
|
||||||
config_sql, config_params = compiler.compile(clone.config)
|
config_sql, config_params = compiler.compile(clone.config)
|
||||||
template = '%(function)s({}::regconfig, %(expressions)s)'.format(config_sql.replace('%', '%%'))
|
template = '%(function)s(%(config)s, %(expressions)s)'
|
||||||
else:
|
else:
|
||||||
template = clone.template
|
template = clone.template
|
||||||
sql, params = super(SearchVector, clone).as_sql(compiler, connection, function=function, template=template)
|
sql, params = super(SearchVector, clone).as_sql(
|
||||||
|
compiler, connection, function=function, template=template,
|
||||||
|
config=config_sql,
|
||||||
|
)
|
||||||
extra_params = []
|
extra_params = []
|
||||||
if clone.weight:
|
if clone.weight:
|
||||||
weight_sql, extra_params = compiler.compile(clone.weight)
|
weight_sql, extra_params = compiler.compile(clone.weight)
|
||||||
|
@ -141,7 +163,7 @@ class SearchQuery(SearchQueryCombinable, Value):
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'):
|
def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'):
|
||||||
self.config = config
|
self.config = SearchConfig(config) if config else None
|
||||||
self.invert = invert
|
self.invert = invert
|
||||||
if search_type not in self.SEARCH_TYPES:
|
if search_type not in self.SEARCH_TYPES:
|
||||||
raise ValueError("Unknown search_type argument '%s'." % search_type)
|
raise ValueError("Unknown search_type argument '%s'." % search_type)
|
||||||
|
@ -151,10 +173,7 @@ class SearchQuery(SearchQueryCombinable, Value):
|
||||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||||
resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||||
if self.config:
|
if self.config:
|
||||||
if not hasattr(self.config, 'resolve_expression'):
|
resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||||
resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
|
||||||
else:
|
|
||||||
resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
|
||||||
return resolved
|
return resolved
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
|
@ -162,7 +181,7 @@ class SearchQuery(SearchQueryCombinable, Value):
|
||||||
function = self.SEARCH_TYPES[self.search_type]
|
function = self.SEARCH_TYPES[self.search_type]
|
||||||
if self.config:
|
if self.config:
|
||||||
config_sql, config_params = compiler.compile(self.config)
|
config_sql, config_params = compiler.compile(self.config)
|
||||||
template = '{}({}::regconfig, %s)'.format(function, config_sql)
|
template = '{}({}, %s)'.format(function, config_sql)
|
||||||
params = config_params + [self.value]
|
params = config_params + [self.value]
|
||||||
else:
|
else:
|
||||||
template = '{}(%s)'.format(function)
|
template = '{}(%s)'.format(function)
|
||||||
|
|
Loading…
Reference in New Issue