Formalized SearchVector and SearchRank signatures.

This commit is contained in:
Simon Charette 2020-02-25 00:08:55 -05:00 committed by Mariusz Felisiak
parent d0f1c03331
commit 1138ca4c57
1 changed files with 8 additions and 13 deletions

View File

@ -76,13 +76,10 @@ class SearchVector(SearchVectorCombinable, Func):
function = 'to_tsvector' function = 'to_tsvector'
arg_joiner = " || ' ' || " arg_joiner = " || ' ' || "
output_field = SearchVectorField() output_field = SearchVectorField()
config = None
def __init__(self, *expressions, **extra): def __init__(self, *expressions, config=None, weight=None):
super().__init__(*expressions, **extra) super().__init__(*expressions)
config = self.extra.get('config', self.config)
self.config = SearchConfig.from_parameter(config) self.config = SearchConfig.from_parameter(config)
weight = self.extra.get('weight')
if weight is not None and not hasattr(weight, 'resolve_expression'): if weight is not None and not hasattr(weight, 'resolve_expression'):
weight = Value(weight) weight = Value(weight)
self.weight = weight self.weight = weight
@ -220,25 +217,23 @@ class SearchRank(Func):
function = 'ts_rank' function = 'ts_rank'
output_field = FloatField() output_field = FloatField()
def __init__(self, vector, query, **extra): def __init__(self, vector, query, weights=None):
if not hasattr(vector, 'resolve_expression'): if not hasattr(vector, 'resolve_expression'):
vector = SearchVector(vector) vector = SearchVector(vector)
if not hasattr(query, 'resolve_expression'): if not hasattr(query, 'resolve_expression'):
query = SearchQuery(query) query = SearchQuery(query)
weights = extra.get('weights')
if weights is not None and not hasattr(weights, 'resolve_expression'): if weights is not None and not hasattr(weights, 'resolve_expression'):
weights = Value(weights) weights = Value(weights)
self.weights = weights self.weights = weights
super().__init__(vector, query, **extra) super().__init__(vector, query)
def as_sql(self, compiler, connection, function=None, template=None): def as_sql(self, compiler, connection, function=None, template=None):
extra_params = [] extra_params = []
extra_context = {} extra_context = {}
if template is None and self.extra.get('weights'): if template is None and self.weights:
if self.weights: template = '%(function)s(%(weights)s, %(expressions)s)'
template = '%(function)s(%(weights)s, %(expressions)s)' weight_sql, extra_params = compiler.compile(self.weights)
weight_sql, extra_params = compiler.compile(self.weights) extra_context['weights'] = weight_sql
extra_context['weights'] = weight_sql
sql, params = super().as_sql( sql, params = super().as_sql(
compiler, connection, compiler, connection,
function=function, template=template, **extra_context function=function, template=template, **extra_context