Refs #24267 -- Implemented lookups for related fields

Previously related fields didn't implement get_lookup, instead
related fields were treated specially. This commit removed some of
the special handling. In particular, related fields return Lookup
instances now, too.

Other notable changes in this commit is removal of support for
annotations in names_to_path().
This commit is contained in:
Anssi Kääriäinen 2015-02-02 13:48:30 +02:00 committed by Tim Graham
parent 8654c6a732
commit b68212f539
6 changed files with 219 additions and 94 deletions

View File

@ -15,7 +15,10 @@ from django.db.models.fields import (
BLANK_CHOICE_DASH, AutoField, Field, IntegerField, PositiveIntegerField, BLANK_CHOICE_DASH, AutoField, Field, IntegerField, PositiveIntegerField,
PositiveSmallIntegerField, PositiveSmallIntegerField,
) )
from django.db.models.lookups import IsNull from django.db.models.fields.related_lookups import (
RelatedExact, RelatedGreaterThan, RelatedGreaterThanOrEqual, RelatedIn,
RelatedLessThan, RelatedLessThanOrEqual,
)
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.db.models.query_utils import PathInfo from django.db.models.query_utils import PathInfo
from django.utils import six from django.utils import six
@ -1336,6 +1339,16 @@ class ForeignObjectRel(object):
def one_to_one(self): def one_to_one(self):
return self.field.one_to_one return self.field.one_to_one
def get_prep_lookup(self, lookup_name, value):
return self.field.get_prep_lookup(lookup_name, value)
def get_internal_type(self):
return self.field.get_internal_type()
@property
def db_type(self):
return self.field.db_type
def __repr__(self): def __repr__(self):
return '<%s: %s.%s>' % ( return '<%s: %s.%s>' % (
type(self).__name__, type(self).__name__,
@ -1760,67 +1773,25 @@ class ForeignObject(RelatedField):
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)] pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)]
return pathinfos return pathinfos
def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookups, def get_lookup(self, lookup_name):
raw_value): if lookup_name == 'in':
from django.db.models.sql.where import SubqueryConstraint, AND, OR return RelatedIn
root_constraint = constraint_class() elif lookup_name == 'exact':
assert len(targets) == len(sources) return RelatedExact
if len(lookups) > 1: elif lookup_name == 'gt':
raise exceptions.FieldError( return RelatedGreaterThan
"Cannot resolve keyword %r into field. Choices are: %s" % ( elif lookup_name == 'gte':
lookups[0], return RelatedGreaterThanOrEqual
", ".join(f.name for f in self.model._meta.get_fields()), elif lookup_name == 'lt':
) return RelatedLessThan
) elif lookup_name == 'lte':
lookup_type = lookups[0] return RelatedLessThanOrEqual
elif lookup_name != 'isnull':
raise TypeError('Related Field got invalid lookup: %s' % lookup_name)
return super(ForeignObject, self).get_lookup(lookup_name)
def get_normalized_value(value): def get_transform(self, *args, **kwargs):
from django.db.models import Model raise NotImplementedError('Relational fields do not support transforms.')
if isinstance(value, Model):
value_list = []
for source in sources:
# Account for one-to-one relations when sent a different model
while not isinstance(value, source.model) and source.rel:
source = source.rel.to._meta.get_field(source.rel.field_name)
value_list.append(getattr(value, source.attname))
return tuple(value_list)
elif not isinstance(value, tuple):
return (value,)
return value
is_multicolumn = len(self.related_fields) > 1
if (hasattr(raw_value, '_as_sql') or
hasattr(raw_value, 'get_compiler')):
root_constraint.add(SubqueryConstraint(alias, [target.column for target in targets],
[source.name for source in sources], raw_value),
AND)
elif lookup_type == 'isnull':
root_constraint.add(IsNull(targets[0].get_col(alias, sources[0]), raw_value), AND)
elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte']
and not is_multicolumn)):
value = get_normalized_value(raw_value)
for target, source, val in zip(targets, sources, value):
lookup_class = target.get_lookup(lookup_type)
root_constraint.add(
lookup_class(target.get_col(alias, source), val), AND)
elif lookup_type in ['range', 'in'] and not is_multicolumn:
values = [get_normalized_value(value) for value in raw_value]
value = [val[0] for val in values]
lookup_class = targets[0].get_lookup(lookup_type)
root_constraint.add(lookup_class(targets[0].get_col(alias, sources[0]), value), AND)
elif lookup_type == 'in':
values = [get_normalized_value(value) for value in raw_value]
root_constraint.connector = OR
for value in values:
value_constraint = constraint_class()
for source, target, val in zip(sources, targets, value):
lookup_class = target.get_lookup('exact')
lookup = lookup_class(target.get_col(alias, source), val)
value_constraint.add(lookup, AND)
root_constraint.add(value_constraint, OR)
else:
raise TypeError('Related Field got invalid lookup: %s' % lookup_type)
return root_constraint
@property @property
def attnames(self): def attnames(self):
@ -2017,6 +1988,9 @@ class ForeignKey(ForeignObject):
else: else:
return self.related_field.get_db_prep_save(value, connection=connection) return self.related_field.get_db_prep_save(value, connection=connection)
def get_db_prep_value(self, value, connection, prepared=False):
return self.related_field.get_db_prep_value(value, connection, prepared)
def value_to_string(self, obj): def value_to_string(self, obj):
if not obj: if not obj:
# In required many-to-one fields with only one available choice, # In required many-to-one fields with only one available choice,

View File

@ -0,0 +1,130 @@
from django.db.models.lookups import (
Exact, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual,
)
class MultiColSource(object):
contains_aggregate = False
def __init__(self, alias, targets, sources, field):
self.targets, self.sources, self.field, self.alias = targets, sources, field, alias
self.output_field = self.field
def __repr__(self):
return "{}({}, {})".format(
self.__class__.__name__, self.alias, self.field)
def relabeled_clone(self, relabels):
return self.__class__(relabels.get(self.alias, self.alias),
self.targets, self.sources, self.field)
def get_normalized_value(value, lhs):
from django.db.models import Model
if isinstance(value, Model):
value_list = []
# Account for one-to-one relations when sent a different model
sources = lhs.output_field.get_path_info()[-1].target_fields
for source in sources:
while not isinstance(value, source.model) and source.rel:
source = source.rel.to._meta.get_field(source.rel.field_name)
value_list.append(getattr(value, source.attname))
return tuple(value_list)
if not isinstance(value, tuple):
return (value,)
return value
class RelatedIn(In):
def get_prep_lookup(self):
if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value():
# If we get here, we are dealing with single-column relations.
self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
# We need to run the related field's get_prep_lookup(). Consider case
# ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
# doesn't have validation for non-integers, so we must run validation
# using the target field.
if hasattr(self.lhs.output_field, 'get_path_info'):
# Run the target field's get_prep_lookup. We can safely assume there is
# only one as we don't get to the direct value branch otherwise.
self.rhs = self.lhs.output_field.get_path_info()[-1].target_fields[-1].get_prep_lookup(
self.lookup_name, self.rhs)
return super(RelatedIn, self).get_prep_lookup()
def as_sql(self, compiler, connection):
if isinstance(self.lhs, MultiColSource):
# For multicolumn lookups we need to build a multicolumn where clause.
# This clause is either a SubqueryConstraint (for values that need to be compiled to
# SQL) or a OR-combined list of (col1 = val1 AND col2 = val2 AND ...) clauses.
from django.db.models.sql.where import WhereNode, SubqueryConstraint, AND, OR
root_constraint = WhereNode(connector=OR)
if self.rhs_is_direct_value():
values = [get_normalized_value(value, self.lhs) for value in self.rhs]
for value in values:
value_constraint = WhereNode()
for source, target, val in zip(self.lhs.sources, self.lhs.targets, value):
lookup_class = target.get_lookup('exact')
lookup = lookup_class(target.get_col(self.lhs.alias, source), val)
value_constraint.add(lookup, AND)
root_constraint.add(value_constraint, OR)
else:
root_constraint.add(
SubqueryConstraint(
self.lhs.alias, [target.column for target in self.lhs.targets],
[source.name for source in self.lhs.sources], self.rhs),
AND)
return root_constraint.as_sql(compiler, connection)
else:
return super(RelatedIn, self).as_sql(compiler, connection)
class RelatedLookupMixin(object):
def get_prep_lookup(self):
if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value():
# If we get here, we are dealing with single-column relations.
self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
# We need to run the related field's get_prep_lookup(). Consider case
# ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
# doesn't have validation for non-integers, so we must run validation
# using the target field.
if hasattr(self.lhs.output_field, 'get_path_info'):
# Get the target field. We can safely assume there is only one
# as we don't get to the direct value branch otherwise.
self.rhs = self.lhs.output_field.get_path_info()[-1].target_fields[-1].get_prep_lookup(
self.lookup_name, self.rhs)
return super(RelatedLookupMixin, self).get_prep_lookup()
def as_sql(self, compiler, connection):
if isinstance(self.lhs, MultiColSource):
assert self.rhs_is_direct_value()
self.rhs = get_normalized_value(self.rhs, self.lhs)
from django.db.models.sql.where import WhereNode, AND
root_constraint = WhereNode()
for target, source, val in zip(self.lhs.targets, self.lhs.sources, self.rhs):
lookup_class = target.get_lookup(self.lookup_name)
root_constraint.add(
lookup_class(target.get_col(self.lhs.alias, source), val), AND)
return root_constraint.as_sql(compiler, connection)
return super(RelatedLookupMixin, self).as_sql(compiler, connection)
class RelatedExact(RelatedLookupMixin, Exact):
pass
class RelatedLessThan(RelatedLookupMixin, LessThan):
pass
class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
pass
class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
pass
class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
pass

View File

@ -250,7 +250,7 @@ deferred_class_factory.__safe_for_unpickling__ = True
def refs_aggregate(lookup_parts, aggregates): def refs_aggregate(lookup_parts, aggregates):
""" """
A little helper method to check if the lookup_parts contains references A helper method to check if the lookup_parts contains references
to the given aggregates set. Because the LOOKUP_SEP is contained in the to the given aggregates set. Because the LOOKUP_SEP is contained in the
default annotation names we must check each prefix of the lookup_parts default annotation names we must check each prefix of the lookup_parts
for a match. for a match.
@ -260,3 +260,17 @@ def refs_aggregate(lookup_parts, aggregates):
if level_n_lookup in aggregates and aggregates[level_n_lookup].contains_aggregate: if level_n_lookup in aggregates and aggregates[level_n_lookup].contains_aggregate:
return aggregates[level_n_lookup], lookup_parts[n:] return aggregates[level_n_lookup], lookup_parts[n:]
return False, () return False, ()
def refs_expression(lookup_parts, annotations):
"""
A helper method to check if the lookup_parts contains references
to the given annotations set. Because the LOOKUP_SEP is contained in the
default annotation names we must check each prefix of the lookup_parts
for a match.
"""
for n in range(len(lookup_parts) + 1):
level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n])
if level_n_lookup in annotations and annotations[level_n_lookup]:
return annotations[level_n_lookup], lookup_parts[n:]
return False, ()

View File

@ -17,7 +17,8 @@ from django.db import DEFAULT_DB_ALIAS, connections
from django.db.models.aggregates import Count from django.db.models.aggregates import Count
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Col, Ref from django.db.models.expressions import Col, Ref
from django.db.models.query_utils import Q, PathInfo, refs_aggregate from django.db.models.fields.related_lookups import MultiColSource
from django.db.models.query_utils import Q, PathInfo, refs_expression
from django.db.models.sql.constants import ( from django.db.models.sql.constants import (
INNER, LOUTER, ORDER_DIR, ORDER_PATTERN, QUERY_TERMS, SINGLE, INNER, LOUTER, ORDER_DIR, ORDER_PATTERN, QUERY_TERMS, SINGLE,
) )
@ -1006,7 +1007,7 @@ class Query(object):
""" """
lookup_splitted = lookup.split(LOOKUP_SEP) lookup_splitted = lookup.split(LOOKUP_SEP)
if self._annotations: if self._annotations:
aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.annotations) aggregate, aggregate_lookups = refs_expression(lookup_splitted, self.annotations)
if aggregate: if aggregate:
return aggregate_lookups, (), aggregate return aggregate_lookups, (), aggregate
_, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
@ -1157,22 +1158,24 @@ class Query(object):
if can_reuse is not None: if can_reuse is not None:
can_reuse.update(join_list) can_reuse.update(join_list)
used_joins = set(used_joins).union(set(join_list)) used_joins = set(used_joins).union(set(join_list))
# Process the join list to see if we can remove any non-needed joins from
# the far end (fewer tables in a query is better).
targets, alias, join_list = self.trim_joins(sources, join_list, path) targets, alias, join_list = self.trim_joins(sources, join_list, path)
if hasattr(field, 'get_lookup_constraint'): if field.is_relation:
# For now foreign keys get special treatment. This should be # No support for transforms for relational fields
# refactored when composite fields lands. assert len(lookups) == 1
condition = field.get_lookup_constraint(self.where_class, alias, targets, sources, lookup_class = field.get_lookup(lookups[0])
lookups, value) # Undo the changes done in setup_joins() if hasattr(final_field, 'field') branch
lookup_type = lookups[-1] # This hack is needed as long as the field.rel isn't like a real field.
if field.get_path_info()[-1].target_fields != sources:
target_field = field.rel
else: else:
assert(len(targets) == 1) target_field = field
if hasattr(targets[0], 'as_sql'): if len(targets) == 1:
# handle Expressions as annotations lhs = targets[0].get_col(alias, target_field)
col = targets[0] else:
lhs = MultiColSource(alias, targets, sources, target_field)
condition = lookup_class(lhs, value)
lookup_type = lookup_class.lookup_name
else: else:
col = targets[0].get_col(alias, field) col = targets[0].get_col(alias, field)
condition = self.build_lookup(lookups, col, value) condition = self.build_lookup(lookups, col, value)
@ -1284,14 +1287,6 @@ class Query(object):
) )
model = field.model._meta.concrete_model model = field.model._meta.concrete_model
except FieldDoesNotExist: except FieldDoesNotExist:
# is it an annotation?
if self._annotations and name in self._annotations:
field, model = self._annotations[name], None
if not field.contains_aggregate:
# Local non-relational field.
final_field = field
targets = (field,)
break
# We didn't find the current field, so move position back # We didn't find the current field, so move position back
# one step. # one step.
pos -= 1 pos -= 1
@ -1985,7 +1980,7 @@ def is_reverse_o2o(field):
A little helper to check if the given field is reverse-o2o. The field is A little helper to check if the given field is reverse-o2o. The field is
expected to be some sort of relation field or related object. expected to be some sort of relation field or related object.
""" """
return not hasattr(field, 'rel') and field.field.unique return field.is_relation and field.one_to_one and not field.concrete
class JoinPromoter(object): class JoinPromoter(object):

View File

@ -144,22 +144,26 @@ class GenericRelationTests(TestCase):
tag.save() tag.save()
def test_ticket_20378(self): def test_ticket_20378(self):
# Create a couple of extra HasLinkThing so that the autopk value
# isn't the same for Link and HasLinkThing.
hs1 = HasLinkThing.objects.create() hs1 = HasLinkThing.objects.create()
hs2 = HasLinkThing.objects.create() hs2 = HasLinkThing.objects.create()
l1 = Link.objects.create(content_object=hs1) hs3 = HasLinkThing.objects.create()
l2 = Link.objects.create(content_object=hs2) hs4 = HasLinkThing.objects.create()
l1 = Link.objects.create(content_object=hs3)
l2 = Link.objects.create(content_object=hs4)
self.assertQuerysetEqual( self.assertQuerysetEqual(
HasLinkThing.objects.filter(links=l1), HasLinkThing.objects.filter(links=l1),
[hs1], lambda x: x) [hs3], lambda x: x)
self.assertQuerysetEqual( self.assertQuerysetEqual(
HasLinkThing.objects.filter(links=l2), HasLinkThing.objects.filter(links=l2),
[hs2], lambda x: x) [hs4], lambda x: x)
self.assertQuerysetEqual( self.assertQuerysetEqual(
HasLinkThing.objects.exclude(links=l2), HasLinkThing.objects.exclude(links=l2),
[hs1], lambda x: x) [hs1, hs2, hs3], lambda x: x, ordered=False)
self.assertQuerysetEqual( self.assertQuerysetEqual(
HasLinkThing.objects.exclude(links=l1), HasLinkThing.objects.exclude(links=l1),
[hs2], lambda x: x) [hs1, hs2, hs4], lambda x: x, ordered=False)
def test_ticket_20564(self): def test_ticket_20564(self):
b1 = B.objects.create() b1 = B.objects.create()

View File

@ -3678,3 +3678,11 @@ class TestTicket24279(TestCase):
School.objects.create() School.objects.create()
qs = School.objects.filter(Q(pk__in=()) | Q()) qs = School.objects.filter(Q(pk__in=()) | Q())
self.assertQuerysetEqual(qs, []) self.assertQuerysetEqual(qs, [])
class TestInvalidValuesRelation(TestCase):
def test_invalid_values(self):
with self.assertRaises(ValueError):
Annotation.objects.filter(tag='abc')
with self.assertRaises(ValueError):
Annotation.objects.filter(tag__in=[123, 'abc'])