Fixed #22648 -- Transform.output_type should respect overridden custom_lookup and custom_transform.
Previously, class lookups from the output_type would be used, but any changes to custom_lookup or custom_transform would be ignored.
This commit is contained in:
parent
11932e978f
commit
a2dd618e3b
|
@ -22,18 +22,20 @@ class RegisterLookupMixin(object):
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# This class didn't have any class_lookups
|
# This class didn't have any class_lookups
|
||||||
pass
|
pass
|
||||||
if hasattr(self, 'output_type'):
|
|
||||||
return self.output_type.get_lookup(lookup_name)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_lookup(self, lookup_name):
|
def get_lookup(self, lookup_name):
|
||||||
found = self._get_lookup(lookup_name)
|
found = self._get_lookup(lookup_name)
|
||||||
|
if found is None and hasattr(self, 'output_type'):
|
||||||
|
return self.output_type.get_lookup(lookup_name)
|
||||||
if found is not None and not issubclass(found, Lookup):
|
if found is not None and not issubclass(found, Lookup):
|
||||||
return None
|
return None
|
||||||
return found
|
return found
|
||||||
|
|
||||||
def get_transform(self, lookup_name):
|
def get_transform(self, lookup_name):
|
||||||
found = self._get_lookup(lookup_name)
|
found = self._get_lookup(lookup_name)
|
||||||
|
if found is None and hasattr(self, 'output_type'):
|
||||||
|
return self.output_type.get_transform(lookup_name)
|
||||||
if found is not None and not issubclass(found, Transform):
|
if found is not None and not issubclass(found, Transform):
|
||||||
return None
|
return None
|
||||||
return found
|
return found
|
||||||
|
|
|
@ -89,6 +89,47 @@ class YearLte(models.lookups.LessThanOrEqual):
|
||||||
YearTransform.register_lookup(YearLte)
|
YearTransform.register_lookup(YearLte)
|
||||||
|
|
||||||
|
|
||||||
|
class SQLFunc(models.Lookup):
|
||||||
|
def __init__(self, name, *args, **kwargs):
|
||||||
|
super(SQLFunc, self).__init__(*args, **kwargs)
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def as_sql(self, qn, connection):
|
||||||
|
return '%s()', [self.name]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_type(self):
|
||||||
|
return CustomField()
|
||||||
|
|
||||||
|
|
||||||
|
class SQLFuncFactory(object):
|
||||||
|
|
||||||
|
def __init__(self, name):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return SQLFunc(self.name, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomField(models.Field):
|
||||||
|
|
||||||
|
def get_lookup(self, lookup_name):
|
||||||
|
if lookup_name.startswith('lookupfunc_'):
|
||||||
|
key, name = lookup_name.split('_', 1)
|
||||||
|
return SQLFuncFactory(name)
|
||||||
|
return super(CustomField, self).get_lookup(lookup_name)
|
||||||
|
|
||||||
|
def get_transform(self, lookup_name):
|
||||||
|
if lookup_name.startswith('transformfunc_'):
|
||||||
|
key, name = lookup_name.split('_', 1)
|
||||||
|
return SQLFuncFactory(name)
|
||||||
|
return super(CustomField, self).get_transform(lookup_name)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomModel(models.Model):
|
||||||
|
field = CustomField()
|
||||||
|
|
||||||
|
|
||||||
# We will register this class temporarily in the test method.
|
# We will register this class temporarily in the test method.
|
||||||
|
|
||||||
|
|
||||||
|
@ -341,3 +382,22 @@ class LookupTransformCallOrderTests(TestCase):
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
models.DateField._unregister_lookup(TrackCallsYearTransform)
|
models.DateField._unregister_lookup(TrackCallsYearTransform)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomisedMethodsTests(TestCase):
|
||||||
|
|
||||||
|
def test_overridden_get_lookup(self):
|
||||||
|
q = CustomModel.objects.filter(field__lookupfunc_monkeys=3)
|
||||||
|
self.assertIn('monkeys()', str(q.query))
|
||||||
|
|
||||||
|
def test_overridden_get_transform(self):
|
||||||
|
q = CustomModel.objects.filter(field__transformfunc_banana=3)
|
||||||
|
self.assertIn('banana()', str(q.query))
|
||||||
|
|
||||||
|
def test_overridden_get_lookup_chain(self):
|
||||||
|
q = CustomModel.objects.filter(field__transformfunc_banana__lookupfunc_elephants=3)
|
||||||
|
self.assertIn('elephants()', str(q.query))
|
||||||
|
|
||||||
|
def test_overridden_get_transform_chain(self):
|
||||||
|
q = CustomModel.objects.filter(field__transformfunc_banana__transformfunc_pear=3)
|
||||||
|
self.assertIn('pear()', str(q.query))
|
||||||
|
|
Loading…
Reference in New Issue