Altered query string customization for backends vendors

The new way is trying to call first method 'as_' + connection.vendor.
If that doesn't exist, then call as_sql().

Also altered how lookup registration is done. There is now
RegisterLookupMixin class that is used by Field, Extract and
sql.Aggregate. This allows one to register lookups for extracts and
aggregates in the same way lookup registration is done for fields.
This commit is contained in:
Anssi Kääriäinen 2014-01-11 14:45:53 +02:00
parent 90e7004ec1
commit c7d5f8661b
7 changed files with 126 additions and 176 deletions

View File

@ -67,9 +67,6 @@ class BaseDatabaseWrapper(object):
self.allow_thread_sharing = allow_thread_sharing self.allow_thread_sharing = allow_thread_sharing
self._thread_ident = thread.get_ident() self._thread_ident = thread.get_ident()
# Compile implementations, used by compiler.compile(someelem)
self.compile_implementations = utils.get_implementations(self.vendor)
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, BaseDatabaseWrapper): if isinstance(other, BaseDatabaseWrapper):
return self.alias == other.alias return self.alias == other.alias

View File

@ -195,31 +195,3 @@ def format_number(value, max_digits, decimal_places):
return "{0:f}".format(value.quantize(decimal.Decimal(".1") ** decimal_places, context=context)) return "{0:f}".format(value.quantize(decimal.Decimal(".1") ** decimal_places, context=context))
else: else:
return "%.*f" % (decimal_places, value) return "%.*f" % (decimal_places, value)
# Map of vendor name -> map of query element class -> implementation function
compile_implementations = defaultdict(dict)
def get_implementations(vendor):
return compile_implementations[vendor]
class add_implementation(object):
"""
A decorator to allow customised implementations for query expressions.
For example:
@add_implementation(Exact, 'mysql')
def mysql_exact(node, qn, connection):
# Play with the node here.
return somesql, list_of_params
Now Exact nodes are compiled to SQL using mysql_exact instead of
Exact.as_sql() when using MySQL backend.
"""
def __init__(self, klass, vendor):
self.klass = klass
self.vendor = vendor
def __call__(self, func):
implementations = get_implementations(self.vendor)
implementations[self.klass] = func
return func

View File

@ -4,7 +4,6 @@ import collections
import copy import copy
import datetime import datetime
import decimal import decimal
import inspect
import math import math
import warnings import warnings
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
@ -12,7 +11,7 @@ from itertools import tee
from django.apps import apps from django.apps import apps
from django.db import connection from django.db import connection
from django.db.models.lookups import default_lookups from django.db.models.lookups import default_lookups, RegisterLookupMixin
from django.db.models.query_utils import QueryWrapper from django.db.models.query_utils import QueryWrapper
from django.conf import settings from django.conf import settings
from django import forms from django import forms
@ -82,7 +81,7 @@ def _empty(of_cls):
@total_ordering @total_ordering
class Field(object): class Field(RegisterLookupMixin):
"""Base class for all field types""" """Base class for all field types"""
# Designates whether empty strings fundamentally are allowed at the # Designates whether empty strings fundamentally are allowed at the
@ -459,30 +458,6 @@ class Field(object):
def get_internal_type(self): def get_internal_type(self):
return self.__class__.__name__ return self.__class__.__name__
def get_lookup(self, lookup_name):
try:
return self.class_lookups[lookup_name]
except KeyError:
for parent in inspect.getmro(self.__class__):
if not 'class_lookups' in parent.__dict__:
continue
if lookup_name in parent.class_lookups:
return parent.class_lookups[lookup_name]
@classmethod
def register_lookup(cls, lookup):
if not 'class_lookups' in cls.__dict__:
cls.class_lookups = {}
cls.class_lookups[lookup.lookup_name] = lookup
@classmethod
def _unregister_lookup(cls, lookup):
"""
Removes given lookup from cls lookups. Meant to be used in
tests only.
"""
del cls.class_lookups[lookup.lookup_name]
def pre_save(self, model_instance, add): def pre_save(self, model_instance, add):
""" """
Returns field's value just before saving. Returns field's value just before saving.

View File

@ -1,18 +1,49 @@
from copy import copy from copy import copy
import inspect
from django.conf import settings from django.conf import settings
from django.utils import timezone from django.utils import timezone
from django.utils.functional import cached_property from django.utils.functional import cached_property
class Extract(object): class RegisterLookupMixin(object):
def get_lookup(self, lookup_name):
try:
return self.class_lookups[lookup_name]
except KeyError:
# To allow for inheritance, check parent class class lookups.
for parent in inspect.getmro(self.__class__):
if not 'class_lookups' in parent.__dict__:
continue
if lookup_name in parent.class_lookups:
return parent.class_lookups[lookup_name]
except AttributeError:
# This class didn't have any class_lookups
pass
if hasattr(self, 'output_type'):
return self.output_type.get_lookup(lookup_name)
return None
@classmethod
def register_lookup(cls, lookup):
if not 'class_lookups' in cls.__dict__:
cls.class_lookups = {}
cls.class_lookups[lookup.lookup_name] = lookup
@classmethod
def _unregister_lookup(cls, lookup):
"""
Removes given lookup from cls lookups. Meant to be used in
tests only.
"""
del cls.class_lookups[lookup.lookup_name]
class Extract(RegisterLookupMixin):
def __init__(self, lhs, lookups): def __init__(self, lhs, lookups):
self.lhs = lhs self.lhs = lhs
self.init_lookups = lookups[:] self.init_lookups = lookups[:]
def get_lookup(self, lookup):
return self.output_type.get_lookup(lookup)
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
raise NotImplementedError raise NotImplementedError
@ -27,7 +58,7 @@ class Extract(object):
return self.lhs.get_cols() return self.lhs.get_cols()
class Lookup(object): class Lookup(RegisterLookupMixin):
lookup_name = None lookup_name = None
def __init__(self, lhs, rhs): def __init__(self, lhs, rhs):

View File

@ -4,6 +4,7 @@ Classes to represent the default SQL aggregate functions
import copy import copy
from django.db.models.fields import IntegerField, FloatField from django.db.models.fields import IntegerField, FloatField
from django.db.models.lookups import RegisterLookupMixin
__all__ = ['Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance'] __all__ = ['Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance']
@ -14,7 +15,7 @@ ordinal_aggregate_field = IntegerField()
computed_aggregate_field = FloatField() computed_aggregate_field = FloatField()
class Aggregate(object): class Aggregate(RegisterLookupMixin):
""" """
Default SQL Aggregate. Default SQL Aggregate.
""" """
@ -100,9 +101,6 @@ class Aggregate(object):
def output_type(self): def output_type(self):
return self.field return self.field
def get_lookup(self, lookup):
return self.output_type.get_lookup(lookup)
class Avg(Aggregate): class Avg(Aggregate):
is_computed = True is_computed = True

View File

@ -70,9 +70,10 @@ class SQLCompiler(object):
return self(name) return self(name)
def compile(self, node): def compile(self, node):
if node.__class__ in self.connection.compile_implementations: vendor_impl = getattr(
return self.connection.compile_implementations[node.__class__]( node, 'as_' + self.connection.vendor, None)
node, self, self.connection) if vendor_impl:
return vendor_impl(self, self.connection)
else: else:
return node.as_sql(self, self.connection) return node.as_sql(self, self.connection)

View File

@ -1,4 +1,3 @@
from copy import copy
from datetime import date from datetime import date
import unittest import unittest
@ -6,7 +5,6 @@ from django.test import TestCase
from .models import Author from .models import Author
from django.db import models from django.db import models
from django.db import connection from django.db import connection
from django.db.backends.utils import add_implementation
class Div3Lookup(models.lookups.Lookup): class Div3Lookup(models.lookups.Lookup):
@ -27,6 +25,37 @@ class Div3Extract(models.lookups.Extract):
return '%s %%%% 3' % (lhs,), lhs_params return '%s %%%% 3' % (lhs,), lhs_params
class YearExtract(models.lookups.Extract):
lookup_name = 'year'
def as_sql(self, qn, connection):
lhs_sql, params = qn.compile(self.lhs)
return connection.ops.date_extract_sql('year', lhs_sql), params
@property
def output_type(self):
return models.IntegerField()
class YearExact(models.lookups.Lookup):
lookup_name = 'exact'
def as_sql(self, qn, connection):
# We will need to skip the extract part, and instead go
# directly with the originating field, that is self.lhs.lhs
lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(qn, connection)
# Note that we must be careful so that we have params in the
# same order as we have the parts in the SQL.
params = lhs_params + rhs_params + lhs_params + rhs_params
# We use PostgreSQL specific SQL here. Note that we must do the
# conversions in SQL instead of in Python to support F() references.
return ("%(lhs)s >= (%(rhs)s || '-01-01')::date "
"AND %(lhs)s <= (%(rhs)s || '-12-31')::date" %
{'lhs': lhs_sql, 'rhs': rhs_sql}, params)
YearExtract.register_lookup(YearExact)
class YearLte(models.lookups.LessThanOrEqual): class YearLte(models.lookups.LessThanOrEqual):
""" """
The purpose of this lookup is to efficiently compare the year of the field. The purpose of this lookup is to efficiently compare the year of the field.
@ -44,80 +73,27 @@ class YearLte(models.lookups.LessThanOrEqual):
# WHERE somecol <= '2013-12-31') # WHERE somecol <= '2013-12-31')
# but also make it work if the rhs_sql is field reference. # but also make it work if the rhs_sql is field reference.
return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params
YearExtract.register_lookup(YearLte)
class YearExtract(models.lookups.Extract): # We will register this class temporarily in the test method.
lookup_name = 'year'
def as_sql(self, qn, connection):
lhs_sql, params = qn.compile(self.lhs)
return connection.ops.date_extract_sql('year', lhs_sql), params
@property
def output_type(self):
return models.IntegerField()
def get_lookup(self, lookup):
if lookup == 'lte':
return YearLte
elif lookup == 'exact':
return YearExact
else:
return super(YearExtract, self).get_lookup(lookup)
class YearExact(models.lookups.Lookup):
def as_sql(self, qn, connection):
# We will need to skip the extract part, and instead go
# directly with the originating field, that is self.lhs.lhs
lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(qn, connection)
# Note that we must be careful so that we have params in the
# same order as we have the parts in the SQL.
params = []
params.extend(lhs_params)
params.extend(rhs_params)
params.extend(lhs_params)
params.extend(rhs_params)
# We use PostgreSQL specific SQL here. Note that we must do the
# conversions in SQL instead of in Python to support F() references.
return ("%(lhs)s >= (%(rhs)s || '-01-01')::date "
"AND %(lhs)s <= (%(rhs)s || '-12-31')::date" %
{'lhs': lhs_sql, 'rhs': rhs_sql}, params)
@add_implementation(YearExact, 'mysql')
def mysql_year_exact(node, qn, connection):
lhs_sql, lhs_params = node.process_lhs(qn, connection, node.lhs.lhs)
rhs_sql, rhs_params = node.process_rhs(qn, connection)
params = []
params.extend(lhs_params)
params.extend(rhs_params)
params.extend(lhs_params)
params.extend(rhs_params)
return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
"AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
{'lhs': lhs_sql, 'rhs': rhs_sql}, params)
class InMonth(models.lookups.Lookup): class InMonth(models.lookups.Lookup):
""" """
InMonth matches if the column's month is contained in the value's month. InMonth matches if the column's month is the same as value's month.
""" """
lookup_name = 'inmonth' lookup_name = 'inmonth'
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
lhs, params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(qn, connection)
# We need to be careful so that we get the params in right # We need to be careful so that we get the params in right
# places. # places.
full_params = params[:] params = lhs_params + rhs_params + lhs_params + rhs_params
full_params.extend(rhs_params)
full_params.extend(params)
full_params.extend(rhs_params)
return ("%s >= date_trunc('month', %s) and " return ("%s >= date_trunc('month', %s) and "
"%s < date_trunc('month', %s) + interval '1 months'" % "%s < date_trunc('month', %s) + interval '1 months'" %
(lhs, rhs, lhs, rhs), full_params) (lhs, rhs, lhs, rhs), params)
class LookupTests(TestCase): class LookupTests(TestCase):
@ -178,44 +154,6 @@ class LookupTests(TestCase):
finally: finally:
models.DateField._unregister_lookup(InMonth) models.DateField._unregister_lookup(InMonth)
def test_custom_compiles(self):
a1 = Author.objects.create(name='a1', age=1)
a2 = Author.objects.create(name='a2', age=2)
a3 = Author.objects.create(name='a3', age=3)
a4 = Author.objects.create(name='a4', age=4)
class AnotherEqual(models.lookups.Exact):
lookup_name = 'anotherequal'
models.Field.register_lookup(AnotherEqual)
try:
@add_implementation(AnotherEqual, connection.vendor)
def custom_eq_sql(node, qn, connection):
return '1 = 1', []
self.assertIn('1 = 1', str(Author.objects.filter(name__anotherequal='asdf').query))
self.assertQuerysetEqual(
Author.objects.filter(name__anotherequal='asdf').order_by('name'),
[a1, a2, a3, a4], lambda x: x)
@add_implementation(AnotherEqual, connection.vendor)
def another_custom_eq_sql(node, qn, connection):
# If you need to override one method, it seems this is the best
# option.
node = copy(node)
class OverriddenAnotherEqual(AnotherEqual):
def get_rhs_op(self, connection, rhs):
return ' <> %s'
node.__class__ = OverriddenAnotherEqual
return node.as_sql(qn, connection)
self.assertIn(' <> ', str(Author.objects.filter(name__anotherequal='a1').query))
self.assertQuerysetEqual(
Author.objects.filter(name__anotherequal='a1').order_by('name'),
[a2, a3, a4], lambda x: x
)
finally:
models.Field._unregister_lookup(AnotherEqual)
def test_div3_extract(self): def test_div3_extract(self):
models.IntegerField.register_lookup(Div3Extract) models.IntegerField.register_lookup(Div3Extract)
try: try:
@ -293,11 +231,49 @@ class YearLteTests(TestCase):
self.assertIn( self.assertIn(
'-12-31', str(baseqs.filter(birthdate__year__lte=2011).query)) '-12-31', str(baseqs.filter(birthdate__year__lte=2011).query))
@unittest.skipUnless(connection.vendor == 'mysql', 'MySQL specific SQL used') def test_postgres_year_exact(self):
def test_mysql_year_exact(self): baseqs = Author.objects.order_by('name')
self.assertQuerysetEqual(
Author.objects.filter(birthdate__year=2012).order_by('name'),
[self.a2, self.a3, self.a4], lambda x: x)
self.assertIn( self.assertIn(
'concat(', '= (2011 || ', str(baseqs.filter(birthdate__year=2011).query))
str(Author.objects.filter(birthdate__year=2012).query)) self.assertIn(
'-12-31', str(baseqs.filter(birthdate__year=2011).query))
def test_custom_implementation_year_exact(self):
try:
# Two ways to add a customized implementation for different backends:
# First is MonkeyPatch of the class.
def as_custom_sql(self, qn, connection):
lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(qn, connection)
params = lhs_params + rhs_params + lhs_params + rhs_params
return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
"AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
{'lhs': lhs_sql, 'rhs': rhs_sql}, params)
setattr(YearExact, 'as_' + connection.vendor, as_custom_sql)
self.assertIn(
'concat(',
str(Author.objects.filter(birthdate__year=2012).query))
finally:
delattr(YearExact, 'as_' + connection.vendor)
try:
# The other way is to subclass the original lookup and register the subclassed
# lookup instead of the original.
class CustomYearExact(YearExact):
# This method should be named "as_mysql" for MySQL, "as_postgresql" for postgres
# and so on, but as we don't know which DB we are running on, we need to use
# setattr.
def as_custom_sql(self, qn, connection):
lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(qn, connection)
params = lhs_params + rhs_params + lhs_params + rhs_params
return ("%(lhs)s >= str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
"AND %(lhs)s <= str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
{'lhs': lhs_sql, 'rhs': rhs_sql}, params)
setattr(CustomYearExact, 'as_' + connection.vendor, CustomYearExact.as_custom_sql)
YearExtract.register_lookup(CustomYearExact)
self.assertIn(
'CONCAT(',
str(Author.objects.filter(birthdate__year=2012).query))
finally:
YearExtract._unregister_lookup(CustomYearExact)
YearExtract.register_lookup(YearExact)