Refs #29396, #30494 -- Reduced code duplication in year lookups.

This commit is contained in:
Simon Charette 2019-05-20 19:58:11 -04:00 committed by Mariusz Felisiak
parent 2b582a7b84
commit 514104cf23
2 changed files with 27 additions and 38 deletions

View File

@ -483,8 +483,6 @@ class YearLookup(Lookup):
bounds = connection.ops.year_lookup_bounds_for_date_field(year) bounds = connection.ops.year_lookup_bounds_for_date_field(year)
return bounds return bounds
class YearComparisonLookup(YearLookup):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
# Avoid the extract operation if the rhs is a direct value to allow # Avoid the extract operation if the rhs is a direct value to allow
# indexes to be used. # indexes to be used.
@ -493,53 +491,44 @@ class YearComparisonLookup(YearLookup):
# that is self.lhs.lhs. # that is self.lhs.lhs.
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs) lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
rhs_sql, _ = self.process_rhs(compiler, connection) rhs_sql, _ = self.process_rhs(compiler, connection)
rhs_sql = self.get_rhs_op(connection, rhs_sql) rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
start, finish = self.year_lookup_bounds(connection, self.rhs) start, finish = self.year_lookup_bounds(connection, self.rhs)
params.append(self.get_bound(start, finish)) params.extend(self.get_bound_params(start, finish))
return '%s %s' % (lhs_sql, rhs_sql), params return '%s %s' % (lhs_sql, rhs_sql), params
return super().as_sql(compiler, connection) return super().as_sql(compiler, connection)
def get_rhs_op(self, connection, rhs): def get_direct_rhs_sql(self, connection, rhs):
return connection.operators[self.lookup_name] % rhs return connection.operators[self.lookup_name] % rhs
def get_bound(self, start, finish): def get_bound_params(self, start, finish):
raise NotImplementedError( raise NotImplementedError(
'subclasses of YearComparisonLookup must provide a get_bound() method' 'subclasses of YearLookup must provide a get_bound_params() method'
) )
class YearExact(YearLookup, Exact): class YearExact(YearLookup, Exact):
lookup_name = 'exact' def get_direct_rhs_sql(self, connection, rhs):
return 'BETWEEN %s AND %s'
def as_sql(self, compiler, connection): def get_bound_params(self, start, finish):
# Avoid the extract operation if the rhs is a direct value to allow return (start, finish)
# indexes to be used.
if self.rhs_is_direct_value():
# Skip the extract part by directly using the originating field,
# that is self.lhs.lhs.
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
bounds = self.year_lookup_bounds(connection, self.rhs)
params.extend(bounds)
return '%s BETWEEN %%s AND %%s' % lhs_sql, params
return super().as_sql(compiler, connection)
class YearGt(YearLookup, GreaterThan):
class YearGt(YearComparisonLookup, GreaterThan): def get_bound_params(self, start, finish):
def get_bound(self, start, finish): return (finish,)
return finish
class YearGte(YearComparisonLookup, GreaterThanOrEqual): class YearGte(YearLookup, GreaterThanOrEqual):
def get_bound(self, start, finish): def get_bound_params(self, start, finish):
return start return (start,)
class YearLt(YearComparisonLookup, LessThan): class YearLt(YearLookup, LessThan):
def get_bound(self, start, finish): def get_bound_params(self, start, finish):
return start return (start,)
class YearLte(YearComparisonLookup, LessThanOrEqual): class YearLte(YearLookup, LessThanOrEqual):
def get_bound(self, start, finish): def get_bound_params(self, start, finish):
return finish return (finish,)

View File

@ -2,16 +2,16 @@ from datetime import datetime
from django.db.models import Value from django.db.models import Value
from django.db.models.fields import DateTimeField from django.db.models.fields import DateTimeField
from django.db.models.lookups import YearComparisonLookup from django.db.models.lookups import YearLookup
from django.test import SimpleTestCase from django.test import SimpleTestCase
class YearComparisonLookupTests(SimpleTestCase): class YearLookupTests(SimpleTestCase):
def test_get_bound(self): def test_get_bound_params(self):
look_up = YearComparisonLookup( look_up = YearLookup(
lhs=Value(datetime(2010, 1, 1, 0, 0, 0), output_field=DateTimeField()), lhs=Value(datetime(2010, 1, 1, 0, 0, 0), output_field=DateTimeField()),
rhs=Value(datetime(2010, 1, 1, 23, 59, 59), output_field=DateTimeField()), rhs=Value(datetime(2010, 1, 1, 23, 59, 59), output_field=DateTimeField()),
) )
msg = 'subclasses of YearComparisonLookup must provide a get_bound() method' msg = 'subclasses of YearLookup must provide a get_bound_params() method'
with self.assertRaisesMessage(NotImplementedError, msg): with self.assertRaisesMessage(NotImplementedError, msg):
look_up.get_bound(datetime(2010, 1, 1, 0, 0, 0), datetime(2010, 1, 1, 23, 59, 59)) look_up.get_bound_params(datetime(2010, 1, 1, 0, 0, 0), datetime(2010, 1, 1, 23, 59, 59))