From a2dd618e3b4a7472fab852da450ca5eef92a922f Mon Sep 17 00:00:00 2001 From: Marc Tamlyn Date: Fri, 16 May 2014 19:56:44 +0200 Subject: [PATCH] Fixed #22648 -- Transform.output_type should respect overridden custom_lookup and custom_transform. Previously, class lookups from the output_type would be used, but any changes to custom_lookup or custom_transform would be ignored. --- django/db/models/lookups.py | 6 ++-- tests/custom_lookups/tests.py | 60 +++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 86ec9c2222..b94090ea1a 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -22,18 +22,20 @@ class RegisterLookupMixin(object): except AttributeError: # This class didn't have any class_lookups pass - if hasattr(self, 'output_type'): - return self.output_type.get_lookup(lookup_name) return None def get_lookup(self, lookup_name): found = self._get_lookup(lookup_name) + if found is None and hasattr(self, 'output_type'): + return self.output_type.get_lookup(lookup_name) if found is not None and not issubclass(found, Lookup): return None return found def get_transform(self, lookup_name): found = self._get_lookup(lookup_name) + if found is None and hasattr(self, 'output_type'): + return self.output_type.get_transform(lookup_name) if found is not None and not issubclass(found, Transform): return None return found diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py index 396974b4b1..c7af60b54d 100644 --- a/tests/custom_lookups/tests.py +++ b/tests/custom_lookups/tests.py @@ -89,6 +89,47 @@ class YearLte(models.lookups.LessThanOrEqual): YearTransform.register_lookup(YearLte) +class SQLFunc(models.Lookup): + def __init__(self, name, *args, **kwargs): + super(SQLFunc, self).__init__(*args, **kwargs) + self.name = name + + def as_sql(self, qn, connection): + return '%s()', [self.name] + + @property + def output_type(self): + return CustomField() + + +class SQLFuncFactory(object): + + def __init__(self, name): + self.name = name + + def __call__(self, *args, **kwargs): + return SQLFunc(self.name, *args, **kwargs) + + +class CustomField(models.Field): + + def get_lookup(self, lookup_name): + if lookup_name.startswith('lookupfunc_'): + key, name = lookup_name.split('_', 1) + return SQLFuncFactory(name) + return super(CustomField, self).get_lookup(lookup_name) + + def get_transform(self, lookup_name): + if lookup_name.startswith('transformfunc_'): + key, name = lookup_name.split('_', 1) + return SQLFuncFactory(name) + return super(CustomField, self).get_transform(lookup_name) + + +class CustomModel(models.Model): + field = CustomField() + + # We will register this class temporarily in the test method. @@ -341,3 +382,22 @@ class LookupTransformCallOrderTests(TestCase): finally: models.DateField._unregister_lookup(TrackCallsYearTransform) + + +class CustomisedMethodsTests(TestCase): + + def test_overridden_get_lookup(self): + q = CustomModel.objects.filter(field__lookupfunc_monkeys=3) + self.assertIn('monkeys()', str(q.query)) + + def test_overridden_get_transform(self): + q = CustomModel.objects.filter(field__transformfunc_banana=3) + self.assertIn('banana()', str(q.query)) + + def test_overridden_get_lookup_chain(self): + q = CustomModel.objects.filter(field__transformfunc_banana__lookupfunc_elephants=3) + self.assertIn('elephants()', str(q.query)) + + def test_overridden_get_transform_chain(self): + q = CustomModel.objects.filter(field__transformfunc_banana__transformfunc_pear=3) + self.assertIn('pear()', str(q.query))