Fixed #27332 -- Added FilteredRelation API for conditional join (ON clause) support.

Thanks Anssi Kääriäinen for contributing to the patch.
This commit is contained in:
Nicolas Delaby 2017-09-22 17:53:17 +02:00 committed by Tim Graham
parent 3f9d85d95c
commit 01d440fa1e
17 changed files with 916 additions and 83 deletions

View File

@ -348,7 +348,7 @@ class GenericRelation(ForeignObject):
self.to_fields = [self.model._meta.pk.name] self.to_fields = [self.model._meta.pk.name]
return [(self.remote_field.model._meta.get_field(self.object_id_field_name), self.model._meta.pk)] return [(self.remote_field.model._meta.get_field(self.object_id_field_name), self.model._meta.pk)]
def _get_path_info_with_parent(self): def _get_path_info_with_parent(self, filtered_relation):
""" """
Return the path that joins the current model through any parent models. Return the path that joins the current model through any parent models.
The idea is that if you have a GFK defined on a parent model then we The idea is that if you have a GFK defined on a parent model then we
@ -365,7 +365,15 @@ class GenericRelation(ForeignObject):
opts = self.remote_field.model._meta.concrete_model._meta opts = self.remote_field.model._meta.concrete_model._meta
parent_opts = opts.get_field(self.object_id_field_name).model._meta parent_opts = opts.get_field(self.object_id_field_name).model._meta
target = parent_opts.pk target = parent_opts.pk
path.append(PathInfo(self.model._meta, parent_opts, (target,), self.remote_field, True, False)) path.append(PathInfo(
from_opts=self.model._meta,
to_opts=parent_opts,
target_fields=(target,),
join_field=self.remote_field,
m2m=True,
direct=False,
filtered_relation=filtered_relation,
))
# Collect joins needed for the parent -> child chain. This is easiest # Collect joins needed for the parent -> child chain. This is easiest
# to do if we collect joins for the child -> parent chain and then # to do if we collect joins for the child -> parent chain and then
# reverse the direction (call to reverse() and use of # reverse the direction (call to reverse() and use of
@ -380,19 +388,35 @@ class GenericRelation(ForeignObject):
path.extend(field.remote_field.get_path_info()) path.extend(field.remote_field.get_path_info())
return path return path
def get_path_info(self): def get_path_info(self, filtered_relation=None):
opts = self.remote_field.model._meta opts = self.remote_field.model._meta
object_id_field = opts.get_field(self.object_id_field_name) object_id_field = opts.get_field(self.object_id_field_name)
if object_id_field.model != opts.model: if object_id_field.model != opts.model:
return self._get_path_info_with_parent() return self._get_path_info_with_parent(filtered_relation)
else: else:
target = opts.pk target = opts.pk
return [PathInfo(self.model._meta, opts, (target,), self.remote_field, True, False)] return [PathInfo(
from_opts=self.model._meta,
to_opts=opts,
target_fields=(target,),
join_field=self.remote_field,
m2m=True,
direct=False,
filtered_relation=filtered_relation,
)]
def get_reverse_path_info(self): def get_reverse_path_info(self, filtered_relation=None):
opts = self.model._meta opts = self.model._meta
from_opts = self.remote_field.model._meta from_opts = self.remote_field.model._meta
return [PathInfo(from_opts, opts, (opts.pk,), self, not self.unique, False)] return [PathInfo(
from_opts=from_opts,
to_opts=opts,
target_fields=(opts.pk,),
join_field=self,
m2m=not self.unique,
direct=False,
filtered_relation=filtered_relation,
)]
def value_to_string(self, obj): def value_to_string(self, obj):
qs = getattr(obj, self.name).all() qs = getattr(obj, self.name).all()

View File

@ -20,6 +20,7 @@ from django.db.models.manager import Manager
from django.db.models.query import ( from django.db.models.query import (
Prefetch, Q, QuerySet, prefetch_related_objects, Prefetch, Q, QuerySet, prefetch_related_objects,
) )
from django.db.models.query_utils import FilteredRelation
# Imports that would create circular imports if sorted # Imports that would create circular imports if sorted
from django.db.models.base import DEFERRED, Model # isort:skip from django.db.models.base import DEFERRED, Model # isort:skip
@ -69,6 +70,7 @@ __all__ += [
'Window', 'WindowFrame', 'Window', 'WindowFrame',
'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager', 'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager',
'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model', 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model',
'FilteredRelation',
'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField', 'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField',
'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel', 'permalink', 'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel', 'permalink',
] ]

View File

@ -697,18 +697,33 @@ class ForeignObject(RelatedField):
""" """
return None return None
def get_path_info(self): def get_path_info(self, filtered_relation=None):
"""Get path from this field to the related model.""" """Get path from this field to the related model."""
opts = self.remote_field.model._meta opts = self.remote_field.model._meta
from_opts = self.model._meta from_opts = self.model._meta
return [PathInfo(from_opts, opts, self.foreign_related_fields, self, False, True)] return [PathInfo(
from_opts=from_opts,
to_opts=opts,
target_fields=self.foreign_related_fields,
join_field=self,
m2m=False,
direct=True,
filtered_relation=filtered_relation,
)]
def get_reverse_path_info(self): def get_reverse_path_info(self, filtered_relation=None):
"""Get path from the related model to this field's model.""" """Get path from the related model to this field's model."""
opts = self.model._meta opts = self.model._meta
from_opts = self.remote_field.model._meta from_opts = self.remote_field.model._meta
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.remote_field, not self.unique, False)] return [PathInfo(
return pathinfos from_opts=from_opts,
to_opts=opts,
target_fields=(opts.pk,),
join_field=self.remote_field,
m2m=not self.unique,
direct=False,
filtered_relation=filtered_relation,
)]
@classmethod @classmethod
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
@ -861,12 +876,19 @@ class ForeignKey(ForeignObject):
def target_field(self): def target_field(self):
return self.foreign_related_fields[0] return self.foreign_related_fields[0]
def get_reverse_path_info(self): def get_reverse_path_info(self, filtered_relation=None):
"""Get path from the related model to this field's model.""" """Get path from the related model to this field's model."""
opts = self.model._meta opts = self.model._meta
from_opts = self.remote_field.model._meta from_opts = self.remote_field.model._meta
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.remote_field, not self.unique, False)] return [PathInfo(
return pathinfos from_opts=from_opts,
to_opts=opts,
target_fields=(opts.pk,),
join_field=self.remote_field,
m2m=not self.unique,
direct=False,
filtered_relation=filtered_relation,
)]
def validate(self, value, model_instance): def validate(self, value, model_instance):
if self.remote_field.parent_link: if self.remote_field.parent_link:
@ -1435,7 +1457,7 @@ class ManyToManyField(RelatedField):
) )
return name, path, args, kwargs return name, path, args, kwargs
def _get_path_info(self, direct=False): def _get_path_info(self, direct=False, filtered_relation=None):
"""Called by both direct and indirect m2m traversal.""" """Called by both direct and indirect m2m traversal."""
pathinfos = [] pathinfos = []
int_model = self.remote_field.through int_model = self.remote_field.through
@ -1443,10 +1465,10 @@ class ManyToManyField(RelatedField):
linkfield2 = int_model._meta.get_field(self.m2m_reverse_field_name()) linkfield2 = int_model._meta.get_field(self.m2m_reverse_field_name())
if direct: if direct:
join1infos = linkfield1.get_reverse_path_info() join1infos = linkfield1.get_reverse_path_info()
join2infos = linkfield2.get_path_info() join2infos = linkfield2.get_path_info(filtered_relation)
else: else:
join1infos = linkfield2.get_reverse_path_info() join1infos = linkfield2.get_reverse_path_info()
join2infos = linkfield1.get_path_info() join2infos = linkfield1.get_path_info(filtered_relation)
# Get join infos between the last model of join 1 and the first model # Get join infos between the last model of join 1 and the first model
# of join 2. Assume the only reason these may differ is due to model # of join 2. Assume the only reason these may differ is due to model
@ -1465,11 +1487,11 @@ class ManyToManyField(RelatedField):
pathinfos.extend(join2infos) pathinfos.extend(join2infos)
return pathinfos return pathinfos
def get_path_info(self): def get_path_info(self, filtered_relation=None):
return self._get_path_info(direct=True) return self._get_path_info(direct=True, filtered_relation=filtered_relation)
def get_reverse_path_info(self): def get_reverse_path_info(self, filtered_relation=None):
return self._get_path_info(direct=False) return self._get_path_info(direct=False, filtered_relation=filtered_relation)
def _get_m2m_db_table(self, opts): def _get_m2m_db_table(self, opts):
""" """

View File

@ -163,8 +163,8 @@ class ForeignObjectRel(FieldCacheMixin):
return self.related_name return self.related_name
return opts.model_name + ('_set' if self.multiple else '') return opts.model_name + ('_set' if self.multiple else '')
def get_path_info(self): def get_path_info(self, filtered_relation=None):
return self.field.get_reverse_path_info() return self.field.get_reverse_path_info(filtered_relation)
def get_cache_name(self): def get_cache_name(self):
""" """

View File

@ -632,7 +632,15 @@ class Options:
final_field = opts.parents[int_model] final_field = opts.parents[int_model]
targets = (final_field.remote_field.get_related_field(),) targets = (final_field.remote_field.get_related_field(),)
opts = int_model._meta opts = int_model._meta
path.append(PathInfo(final_field.model._meta, opts, targets, final_field, False, True)) path.append(PathInfo(
from_opts=final_field.model._meta,
to_opts=opts,
target_fields=targets,
join_field=final_field,
m2m=False,
direct=True,
filtered_relation=None,
))
return path return path
def get_path_from_parent(self, parent): def get_path_from_parent(self, parent):

View File

@ -22,7 +22,7 @@ from django.db.models.deletion import Collector
from django.db.models.expressions import F from django.db.models.expressions import F
from django.db.models.fields import AutoField from django.db.models.fields import AutoField
from django.db.models.functions import Trunc from django.db.models.functions import Trunc
from django.db.models.query_utils import InvalidQuery, Q from django.db.models.query_utils import FilteredRelation, InvalidQuery, Q
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
from django.utils import timezone from django.utils import timezone
from django.utils.deprecation import RemovedInDjango30Warning from django.utils.deprecation import RemovedInDjango30Warning
@ -953,6 +953,12 @@ class QuerySet:
if lookups == (None,): if lookups == (None,):
clone._prefetch_related_lookups = () clone._prefetch_related_lookups = ()
else: else:
for lookup in lookups:
if isinstance(lookup, Prefetch):
lookup = lookup.prefetch_to
lookup = lookup.split(LOOKUP_SEP, 1)[0]
if lookup in self.query._filtered_relations:
raise ValueError('prefetch_related() is not supported with FilteredRelation.')
clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
return clone return clone
@ -984,7 +990,10 @@ class QuerySet:
if alias in names: if alias in names:
raise ValueError("The annotation '%s' conflicts with a field on " raise ValueError("The annotation '%s' conflicts with a field on "
"the model." % alias) "the model." % alias)
clone.query.add_annotation(annotation, alias, is_summary=False) if isinstance(annotation, FilteredRelation):
clone.query.add_filtered_relation(annotation, alias)
else:
clone.query.add_annotation(annotation, alias, is_summary=False)
for alias, annotation in clone.query.annotations.items(): for alias, annotation in clone.query.annotations.items():
if alias in annotations and annotation.contains_aggregate: if alias in annotations and annotation.contains_aggregate:
@ -1060,6 +1069,10 @@ class QuerySet:
# Can only pass None to defer(), not only(), as the rest option. # Can only pass None to defer(), not only(), as the rest option.
# That won't stop people trying to do this, so let's be explicit. # That won't stop people trying to do this, so let's be explicit.
raise TypeError("Cannot pass None as an argument to only().") raise TypeError("Cannot pass None as an argument to only().")
for field in fields:
field = field.split(LOOKUP_SEP, 1)[0]
if field in self.query._filtered_relations:
raise ValueError('only() is not supported with FilteredRelation.')
clone = self._chain() clone = self._chain()
clone.query.add_immediate_loading(fields) clone.query.add_immediate_loading(fields)
return clone return clone
@ -1730,9 +1743,9 @@ class RelatedPopulator:
# model's fields. # model's fields.
# - related_populators: a list of RelatedPopulator instances if # - related_populators: a list of RelatedPopulator instances if
# select_related() descends to related models from this model. # select_related() descends to related models from this model.
# - field, remote_field: the fields to use for populating the # - local_setter, remote_setter: Methods to set cached values on
# internal fields cache. If remote_field is set then we also # the object being populated and on the remote object. Usually
# set the reverse link. # these are Field.set_cached_value() methods.
select_fields = klass_info['select_fields'] select_fields = klass_info['select_fields']
from_parent = klass_info['from_parent'] from_parent = klass_info['from_parent']
if not from_parent: if not from_parent:
@ -1751,16 +1764,8 @@ class RelatedPopulator:
self.model_cls = klass_info['model'] self.model_cls = klass_info['model']
self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname) self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname)
self.related_populators = get_related_populators(klass_info, select, self.db) self.related_populators = get_related_populators(klass_info, select, self.db)
reverse = klass_info['reverse'] self.local_setter = klass_info['local_setter']
field = klass_info['field'] self.remote_setter = klass_info['remote_setter']
self.remote_field = None
if reverse:
self.field = field.remote_field
self.remote_field = field
else:
self.field = field
if field.unique:
self.remote_field = field.remote_field
def populate(self, row, from_obj): def populate(self, row, from_obj):
if self.reorder_for_init: if self.reorder_for_init:
@ -1774,9 +1779,9 @@ class RelatedPopulator:
if self.related_populators: if self.related_populators:
for rel_iter in self.related_populators: for rel_iter in self.related_populators:
rel_iter.populate(row, obj) rel_iter.populate(row, obj)
if self.remote_field: self.local_setter(from_obj, obj)
self.remote_field.set_cached_value(obj, from_obj) if obj is not None:
self.field.set_cached_value(from_obj, obj) self.remote_setter(obj, from_obj)
def get_related_populators(klass_info, select, db): def get_related_populators(klass_info, select, db):

View File

@ -16,7 +16,7 @@ from django.utils import tree
# PathInfo is used when converting lookups (fk__somecol). The contents # PathInfo is used when converting lookups (fk__somecol). The contents
# describe the relation in Model terms (model Options and Fields for both # describe the relation in Model terms (model Options and Fields for both
# sides of the relation. The join_field is the field backing the relation. # sides of the relation. The join_field is the field backing the relation.
PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct') PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct filtered_relation')
class InvalidQuery(Exception): class InvalidQuery(Exception):
@ -291,3 +291,44 @@ def check_rel_lookup_compatibility(model, target_opts, field):
check(target_opts) or check(target_opts) or
(getattr(field, 'primary_key', False) and check(field.model._meta)) (getattr(field, 'primary_key', False) and check(field.model._meta))
) )
class FilteredRelation:
"""Specify custom filtering in the ON clause of SQL joins."""
def __init__(self, relation_name, *, condition=Q()):
if not relation_name:
raise ValueError('relation_name cannot be empty.')
self.relation_name = relation_name
self.alias = None
if not isinstance(condition, Q):
raise ValueError('condition argument must be a Q() instance.')
self.condition = condition
self.path = []
def __eq__(self, other):
return (
isinstance(other, self.__class__) and
self.relation_name == other.relation_name and
self.alias == other.alias and
self.condition == other.condition
)
def clone(self):
clone = FilteredRelation(self.relation_name, condition=self.condition)
clone.alias = self.alias
clone.path = self.path[:]
return clone
def resolve_expression(self, *args, **kwargs):
"""
QuerySet.annotate() only accepts expression-like arguments
(with a resolve_expression() method).
"""
raise NotImplementedError('FilteredRelation.resolve_expression() is unused.')
def as_sql(self, compiler, connection):
# Resolve the condition in Join.filtered_relation.
query = compiler.query
where = query.build_filtered_relation_q(self.condition, reuse=set(self.path))
return compiler.compile(where)

View File

@ -702,7 +702,7 @@ class SQLCompiler:
""" """
result = [] result = []
params = [] params = []
for alias in self.query.alias_map: for alias in tuple(self.query.alias_map):
if not self.query.alias_refcount[alias]: if not self.query.alias_refcount[alias]:
continue continue
try: try:
@ -737,7 +737,7 @@ class SQLCompiler:
f.field.related_query_name() f.field.related_query_name()
for f in opts.related_objects if f.field.unique for f in opts.related_objects if f.field.unique
) )
return chain(direct_choices, reverse_choices) return chain(direct_choices, reverse_choices, self.query._filtered_relations)
related_klass_infos = [] related_klass_infos = []
if not restricted and cur_depth > self.query.max_depth: if not restricted and cur_depth > self.query.max_depth:
@ -788,7 +788,8 @@ class SQLCompiler:
klass_info = { klass_info = {
'model': f.remote_field.model, 'model': f.remote_field.model,
'field': f, 'field': f,
'reverse': False, 'local_setter': f.set_cached_value,
'remote_setter': f.remote_field.set_cached_value if f.unique else lambda x, y: None,
'from_parent': False, 'from_parent': False,
} }
related_klass_infos.append(klass_info) related_klass_infos.append(klass_info)
@ -825,7 +826,8 @@ class SQLCompiler:
klass_info = { klass_info = {
'model': model, 'model': model,
'field': f, 'field': f,
'reverse': True, 'local_setter': f.remote_field.set_cached_value,
'remote_setter': f.set_cached_value,
'from_parent': from_parent, 'from_parent': from_parent,
} }
related_klass_infos.append(klass_info) related_klass_infos.append(klass_info)
@ -842,6 +844,47 @@ class SQLCompiler:
next, restricted) next, restricted)
get_related_klass_infos(klass_info, next_klass_infos) get_related_klass_infos(klass_info, next_klass_infos)
fields_not_found = set(requested).difference(fields_found) fields_not_found = set(requested).difference(fields_found)
for name in list(requested):
# Filtered relations work only on the topmost level.
if cur_depth > 1:
break
if name in self.query._filtered_relations:
fields_found.add(name)
f, _, join_opts, joins, _ = self.query.setup_joins([name], opts, root_alias)
model = join_opts.model
alias = joins[-1]
from_parent = issubclass(model, opts.model) and model is not opts.model
def local_setter(obj, from_obj):
f.remote_field.set_cached_value(from_obj, obj)
def remote_setter(obj, from_obj):
setattr(from_obj, name, obj)
klass_info = {
'model': model,
'field': f,
'local_setter': local_setter,
'remote_setter': remote_setter,
'from_parent': from_parent,
}
related_klass_infos.append(klass_info)
select_fields = []
columns = self.get_default_columns(
start_alias=alias, opts=model._meta,
from_parent=opts.model,
)
for col in columns:
select_fields.append(len(select))
select.append((col, None))
klass_info['select_fields'] = select_fields
next_requested = requested.get(name, {})
next_klass_infos = self.get_related_selections(
select, opts=model._meta, root_alias=alias,
cur_depth=cur_depth + 1, requested=next_requested,
restricted=restricted,
)
get_related_klass_infos(klass_info, next_klass_infos)
fields_not_found = set(requested).difference(fields_found)
if fields_not_found: if fields_not_found:
invalid_fields = ("'%s'" % s for s in fields_not_found) invalid_fields = ("'%s'" % s for s in fields_not_found)
raise FieldError( raise FieldError(

View File

@ -41,7 +41,7 @@ class Join:
- relabeled_clone() - relabeled_clone()
""" """
def __init__(self, table_name, parent_alias, table_alias, join_type, def __init__(self, table_name, parent_alias, table_alias, join_type,
join_field, nullable): join_field, nullable, filtered_relation=None):
# Join table # Join table
self.table_name = table_name self.table_name = table_name
self.parent_alias = parent_alias self.parent_alias = parent_alias
@ -56,6 +56,7 @@ class Join:
self.join_field = join_field self.join_field = join_field
# Is this join nullabled? # Is this join nullabled?
self.nullable = nullable self.nullable = nullable
self.filtered_relation = filtered_relation
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
""" """
@ -85,7 +86,11 @@ class Join:
extra_sql, extra_params = compiler.compile(extra_cond) extra_sql, extra_params = compiler.compile(extra_cond)
join_conditions.append('(%s)' % extra_sql) join_conditions.append('(%s)' % extra_sql)
params.extend(extra_params) params.extend(extra_params)
if self.filtered_relation:
extra_sql, extra_params = compiler.compile(self.filtered_relation)
if extra_sql:
join_conditions.append('(%s)' % extra_sql)
params.extend(extra_params)
if not join_conditions: if not join_conditions:
# This might be a rel on the other end of an actual declared field. # This might be a rel on the other end of an actual declared field.
declared_field = getattr(self.join_field, 'field', self.join_field) declared_field = getattr(self.join_field, 'field', self.join_field)
@ -101,18 +106,27 @@ class Join:
def relabeled_clone(self, change_map): def relabeled_clone(self, change_map):
new_parent_alias = change_map.get(self.parent_alias, self.parent_alias) new_parent_alias = change_map.get(self.parent_alias, self.parent_alias)
new_table_alias = change_map.get(self.table_alias, self.table_alias) new_table_alias = change_map.get(self.table_alias, self.table_alias)
if self.filtered_relation is not None:
filtered_relation = self.filtered_relation.clone()
filtered_relation.path = [change_map.get(p, p) for p in self.filtered_relation.path]
else:
filtered_relation = None
return self.__class__( return self.__class__(
self.table_name, new_parent_alias, new_table_alias, self.join_type, self.table_name, new_parent_alias, new_table_alias, self.join_type,
self.join_field, self.nullable) self.join_field, self.nullable, filtered_relation=filtered_relation,
)
def equals(self, other, with_filtered_relation):
return (
isinstance(other, self.__class__) and
self.table_name == other.table_name and
self.parent_alias == other.parent_alias and
self.join_field == other.join_field and
(not with_filtered_relation or self.filtered_relation == other.filtered_relation)
)
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, self.__class__): return self.equals(other, with_filtered_relation=True)
return (
self.table_name == other.table_name and
self.parent_alias == other.parent_alias and
self.join_field == other.join_field
)
return False
def demote(self): def demote(self):
new = self.relabeled_clone({}) new = self.relabeled_clone({})
@ -134,6 +148,7 @@ class BaseTable:
""" """
join_type = None join_type = None
parent_alias = None parent_alias = None
filtered_relation = None
def __init__(self, table_name, alias): def __init__(self, table_name, alias):
self.table_name = table_name self.table_name = table_name
@ -146,3 +161,10 @@ class BaseTable:
def relabeled_clone(self, change_map): def relabeled_clone(self, change_map):
return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias)) return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias))
def equals(self, other, with_filtered_relation):
return (
isinstance(self, other.__class__) and
self.table_name == other.table_name and
self.table_alias == other.table_alias
)

View File

@ -45,6 +45,14 @@ def get_field_names_from_opts(opts):
)) ))
def get_children_from_q(q):
for child in q.children:
if isinstance(child, Node):
yield from get_children_from_q(child)
else:
yield child
JoinInfo = namedtuple( JoinInfo = namedtuple(
'JoinInfo', 'JoinInfo',
('final_field', 'targets', 'opts', 'joins', 'path') ('final_field', 'targets', 'opts', 'joins', 'path')
@ -210,6 +218,8 @@ class Query:
# load. # load.
self.deferred_loading = (frozenset(), True) self.deferred_loading = (frozenset(), True)
self._filtered_relations = {}
@property @property
def extra(self): def extra(self):
if self._extra is None: if self._extra is None:
@ -311,6 +321,7 @@ class Query:
if 'subq_aliases' in self.__dict__: if 'subq_aliases' in self.__dict__:
obj.subq_aliases = self.subq_aliases.copy() obj.subq_aliases = self.subq_aliases.copy()
obj.used_aliases = self.used_aliases.copy() obj.used_aliases = self.used_aliases.copy()
obj._filtered_relations = self._filtered_relations.copy()
# Clear the cached_property # Clear the cached_property
try: try:
del obj.base_table del obj.base_table
@ -624,6 +635,8 @@ class Query:
opts = orig_opts opts = orig_opts
for name in parts[:-1]: for name in parts[:-1]:
old_model = cur_model old_model = cur_model
if name in self._filtered_relations:
name = self._filtered_relations[name].relation_name
source = opts.get_field(name) source = opts.get_field(name)
if is_reverse_o2o(source): if is_reverse_o2o(source):
cur_model = source.related_model cur_model = source.related_model
@ -684,7 +697,7 @@ class Query:
for model, values in seen.items(): for model, values in seen.items():
callback(target, model, values) callback(target, model, values)
def table_alias(self, table_name, create=False): def table_alias(self, table_name, create=False, filtered_relation=None):
""" """
Return a table alias for the given table_name and whether this is a Return a table alias for the given table_name and whether this is a
new alias or not. new alias or not.
@ -704,8 +717,8 @@ class Query:
alias_list.append(alias) alias_list.append(alias)
else: else:
# The first occurrence of a table uses the table name directly. # The first occurrence of a table uses the table name directly.
alias = table_name alias = filtered_relation.alias if filtered_relation is not None else table_name
self.table_map[alias] = [alias] self.table_map[table_name] = [alias]
self.alias_refcount[alias] = 1 self.alias_refcount[alias] = 1
return alias, True return alias, True
@ -881,7 +894,7 @@ class Query:
""" """
return len([1 for count in self.alias_refcount.values() if count]) return len([1 for count in self.alias_refcount.values() if count])
def join(self, join, reuse=None): def join(self, join, reuse=None, reuse_with_filtered_relation=False):
""" """
Return an alias for the 'join', either reusing an existing alias for Return an alias for the 'join', either reusing an existing alias for
that join or creating a new one. 'join' is either a that join or creating a new one. 'join' is either a
@ -890,18 +903,29 @@ class Query:
The 'reuse' parameter can be either None which means all joins are The 'reuse' parameter can be either None which means all joins are
reusable, or it can be a set containing the aliases that can be reused. reusable, or it can be a set containing the aliases that can be reused.
The 'reuse_with_filtered_relation' parameter is used when computing
FilteredRelation instances.
A join is always created as LOUTER if the lhs alias is LOUTER to make A join is always created as LOUTER if the lhs alias is LOUTER to make
sure chains like t1 LOUTER t2 INNER t3 aren't generated. All new sure chains like t1 LOUTER t2 INNER t3 aren't generated. All new
joins are created as LOUTER if the join is nullable. joins are created as LOUTER if the join is nullable.
""" """
reuse = [a for a, j in self.alias_map.items() if reuse_with_filtered_relation and reuse:
if (reuse is None or a in reuse) and j == join] reuse_aliases = [
if reuse: a for a, j in self.alias_map.items()
self.ref_alias(reuse[0]) if a in reuse and j.equals(join, with_filtered_relation=False)
return reuse[0] ]
else:
reuse_aliases = [
a for a, j in self.alias_map.items()
if (reuse is None or a in reuse) and j == join
]
if reuse_aliases:
self.ref_alias(reuse_aliases[0])
return reuse_aliases[0]
# No reuse is possible, so we need a new alias. # No reuse is possible, so we need a new alias.
alias, _ = self.table_alias(join.table_name, create=True) alias, _ = self.table_alias(join.table_name, create=True, filtered_relation=join.filtered_relation)
if join.join_type: if join.join_type:
if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable: if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable:
join_type = LOUTER join_type = LOUTER
@ -1090,7 +1114,8 @@ class Query:
(name, lhs.output_field.__class__.__name__)) (name, lhs.output_field.__class__.__name__))
def build_filter(self, filter_expr, branch_negated=False, current_negated=False, def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
can_reuse=None, allow_joins=True, split_subq=True): can_reuse=None, allow_joins=True, split_subq=True,
reuse_with_filtered_relation=False):
""" """
Build a WhereNode for a single filter clause but don't add it Build a WhereNode for a single filter clause but don't add it
to this Query. Query.add_q() will then add this filter to the where to this Query. Query.add_q() will then add this filter to the where
@ -1112,6 +1137,9 @@ class Query:
The 'can_reuse' is a set of reusable joins for multijoins. The 'can_reuse' is a set of reusable joins for multijoins.
If 'reuse_with_filtered_relation' is True, then only joins in can_reuse
will be reused.
The method will create a filter clause that can be added to the current The method will create a filter clause that can be added to the current
query. However, if the filter isn't added to the query then the caller query. However, if the filter isn't added to the query then the caller
is responsible for unreffing the joins used. is responsible for unreffing the joins used.
@ -1147,7 +1175,10 @@ class Query:
allow_many = not branch_negated or not split_subq allow_many = not branch_negated or not split_subq
try: try:
join_info = self.setup_joins(parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many) join_info = self.setup_joins(
parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many,
reuse_with_filtered_relation=reuse_with_filtered_relation,
)
# Prevent iterator from being consumed by check_related_objects() # Prevent iterator from being consumed by check_related_objects()
if isinstance(value, Iterator): if isinstance(value, Iterator):
@ -1250,6 +1281,41 @@ class Query:
needed_inner = joinpromoter.update_join_types(self) needed_inner = joinpromoter.update_join_types(self)
return target_clause, needed_inner return target_clause, needed_inner
def build_filtered_relation_q(self, q_object, reuse, branch_negated=False, current_negated=False):
"""Add a FilteredRelation object to the current filter."""
connector = q_object.connector
current_negated ^= q_object.negated
branch_negated = branch_negated or q_object.negated
target_clause = self.where_class(connector=connector, negated=q_object.negated)
for child in q_object.children:
if isinstance(child, Node):
child_clause = self.build_filtered_relation_q(
child, reuse=reuse, branch_negated=branch_negated,
current_negated=current_negated,
)
else:
child_clause, _ = self.build_filter(
child, can_reuse=reuse, branch_negated=branch_negated,
current_negated=current_negated,
allow_joins=True, split_subq=False,
reuse_with_filtered_relation=True,
)
target_clause.add(child_clause, connector)
return target_clause
def add_filtered_relation(self, filtered_relation, alias):
filtered_relation.alias = alias
lookups = dict(get_children_from_q(filtered_relation.condition))
for lookup in chain((filtered_relation.relation_name,), lookups):
lookup_parts, field_parts, _ = self.solve_lookup_type(lookup)
shift = 2 if not lookup_parts else 1
if len(field_parts) > (shift + len(lookup_parts)):
raise ValueError(
"FilteredRelation's condition doesn't support nested "
"relations (got %r)." % lookup
)
self._filtered_relations[filtered_relation.alias] = filtered_relation
def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False): def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False):
""" """
Walk the list of names and turns them into PathInfo tuples. A single Walk the list of names and turns them into PathInfo tuples. A single
@ -1272,12 +1338,15 @@ class Query:
name = opts.pk.name name = opts.pk.name
field = None field = None
filtered_relation = None
try: try:
field = opts.get_field(name) field = opts.get_field(name)
except FieldDoesNotExist: except FieldDoesNotExist:
if name in self.annotation_select: if name in self.annotation_select:
field = self.annotation_select[name].output_field field = self.annotation_select[name].output_field
elif name in self._filtered_relations and pos == 0:
filtered_relation = self._filtered_relations[name]
field = opts.get_field(filtered_relation.relation_name)
if field is not None: if field is not None:
# Fields that contain one-to-many relations with a generic # Fields that contain one-to-many relations with a generic
# model (like a GenericForeignKey) cannot generate reverse # model (like a GenericForeignKey) cannot generate reverse
@ -1301,7 +1370,10 @@ class Query:
pos -= 1 pos -= 1
if pos == -1 or fail_on_missing: if pos == -1 or fail_on_missing:
field_names = list(get_field_names_from_opts(opts)) field_names = list(get_field_names_from_opts(opts))
available = sorted(field_names + list(self.annotation_select)) available = sorted(
field_names + list(self.annotation_select) +
list(self._filtered_relations)
)
raise FieldError("Cannot resolve keyword '%s' into field. " raise FieldError("Cannot resolve keyword '%s' into field. "
"Choices are: %s" % (name, ", ".join(available))) "Choices are: %s" % (name, ", ".join(available)))
break break
@ -1315,7 +1387,7 @@ class Query:
cur_names_with_path[1].extend(path_to_parent) cur_names_with_path[1].extend(path_to_parent)
opts = path_to_parent[-1].to_opts opts = path_to_parent[-1].to_opts
if hasattr(field, 'get_path_info'): if hasattr(field, 'get_path_info'):
pathinfos = field.get_path_info() pathinfos = field.get_path_info(filtered_relation)
if not allow_many: if not allow_many:
for inner_pos, p in enumerate(pathinfos): for inner_pos, p in enumerate(pathinfos):
if p.m2m: if p.m2m:
@ -1340,7 +1412,8 @@ class Query:
break break
return path, final_field, targets, names[pos + 1:] return path, final_field, targets, names[pos + 1:]
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True,
reuse_with_filtered_relation=False):
""" """
Compute the necessary table joins for the passage through the fields Compute the necessary table joins for the passage through the fields
given in 'names'. 'opts' is the Options class for the current model given in 'names'. 'opts' is the Options class for the current model
@ -1352,6 +1425,9 @@ class Query:
that can be reused. Note that non-reverse foreign keys are always that can be reused. Note that non-reverse foreign keys are always
reusable when using setup_joins(). reusable when using setup_joins().
The 'reuse_with_filtered_relation' can be used to force 'can_reuse'
parameter and force the relation on the given connections.
If 'allow_many' is False, then any reverse foreign key seen will If 'allow_many' is False, then any reverse foreign key seen will
generate a MultiJoin exception. generate a MultiJoin exception.
@ -1374,15 +1450,29 @@ class Query:
# joins at this stage - we will need the information about join type # joins at this stage - we will need the information about join type
# of the trimmed joins. # of the trimmed joins.
for join in path: for join in path:
if join.filtered_relation:
filtered_relation = join.filtered_relation.clone()
table_alias = filtered_relation.alias
else:
filtered_relation = None
table_alias = None
opts = join.to_opts opts = join.to_opts
if join.direct: if join.direct:
nullable = self.is_nullable(join.join_field) nullable = self.is_nullable(join.join_field)
else: else:
nullable = True nullable = True
connection = Join(opts.db_table, alias, None, INNER, join.join_field, nullable) connection = Join(
reuse = can_reuse if join.m2m else None opts.db_table, alias, table_alias, INNER, join.join_field,
alias = self.join(connection, reuse=reuse) nullable, filtered_relation=filtered_relation,
)
reuse = can_reuse if join.m2m or reuse_with_filtered_relation else None
alias = self.join(
connection, reuse=reuse,
reuse_with_filtered_relation=reuse_with_filtered_relation,
)
joins.append(alias) joins.append(alias)
if filtered_relation:
filtered_relation.path = joins[:]
return JoinInfo(final_field, targets, opts, joins, path) return JoinInfo(final_field, targets, opts, joins, path)
def trim_joins(self, targets, joins, path): def trim_joins(self, targets, joins, path):
@ -1402,6 +1492,8 @@ class Query:
for pos, info in enumerate(reversed(path)): for pos, info in enumerate(reversed(path)):
if len(joins) == 1 or not info.direct: if len(joins) == 1 or not info.direct:
break break
if info.filtered_relation:
break
join_targets = {t.column for t in info.join_field.foreign_related_fields} join_targets = {t.column for t in info.join_field.foreign_related_fields}
cur_targets = {t.column for t in targets} cur_targets = {t.column for t in targets}
if not cur_targets.issubset(join_targets): if not cur_targets.issubset(join_targets):
@ -1425,7 +1517,7 @@ class Query:
return self.annotation_select[name] return self.annotation_select[name]
else: else:
field_list = name.split(LOOKUP_SEP) field_list = name.split(LOOKUP_SEP)
join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), reuse) join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), can_reuse=reuse)
targets, _, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path) targets, _, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path)
if len(targets) > 1: if len(targets) > 1:
raise FieldError("Referencing multicolumn fields with F() objects " raise FieldError("Referencing multicolumn fields with F() objects "
@ -1602,7 +1694,10 @@ class Query:
# from the model on which the lookup failed. # from the model on which the lookup failed.
raise raise
else: else:
names = sorted(list(get_field_names_from_opts(opts)) + list(self.extra) + list(self.annotation_select)) names = sorted(
list(get_field_names_from_opts(opts)) + list(self.extra) +
list(self.annotation_select) + list(self._filtered_relations)
)
raise FieldError("Cannot resolve keyword %r into field. " raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names))) "Choices are: %s" % (name, ", ".join(names)))

View File

@ -3318,3 +3318,60 @@ lookups or :class:`Prefetch` objects you want to prefetch for. For example::
>>> from django.db.models import prefetch_related_objects >>> from django.db.models import prefetch_related_objects
>>> restaurants = fetch_top_restaurants_from_cache() # A list of Restaurants >>> restaurants = fetch_top_restaurants_from_cache() # A list of Restaurants
>>> prefetch_related_objects(restaurants, 'pizzas__toppings') >>> prefetch_related_objects(restaurants, 'pizzas__toppings')
``FilteredRelation()`` objects
------------------------------
.. versionadded:: 2.0
.. class:: FilteredRelation(relation_name, *, condition=Q())
.. attribute:: FilteredRelation.relation_name
The name of the field on which you'd like to filter the relation.
.. attribute:: FilteredRelation.condition
A :class:`~django.db.models.Q` object to control the filtering.
``FilteredRelation`` is used with :meth:`~.QuerySet.annotate()` to create an
``ON`` clause when a ``JOIN`` is performed. It doesn't act on the default
relationship but on the annotation name (``pizzas_vegetarian`` in example
below).
For example, to find restaurants that have vegetarian pizzas with
``'mozzarella'`` in the name::
>>> from django.db.models import FilteredRelation, Q
>>> Restaurant.objects.annotate(
... pizzas_vegetarian=FilteredRelation(
... 'pizzas', condition=Q(pizzas__vegetarian=True),
... ),
... ).filter(pizzas_vegetarian__name__icontains='mozzarella')
If there are a large number of pizzas, this queryset performs better than::
>>> Restaurant.objects.filter(
... pizzas__vegetarian=True,
... pizzas__name__icontains='mozzarella',
... )
because the filtering in the ``WHERE`` clause of the first queryset will only
operate on vegetarian pizzas.
``FilteredRelation`` doesn't support:
* Conditions that span relational fields. For example::
>>> Restaurant.objects.annotate(
... pizzas_with_toppings_startswith_n=FilteredRelation(
... 'pizzas__toppings',
... condition=Q(pizzas__toppings__name__startswith='n'),
... ),
... )
Traceback (most recent call last):
...
ValueError: FilteredRelation's condition doesn't support nested relations (got 'pizzas__toppings__name__startswith').
* :meth:`.QuerySet.only` and :meth:`~.QuerySet.prefetch_related`.
* A :class:`~django.contrib.contenttypes.fields.GenericForeignKey`
inherited from a parent model.

View File

@ -354,6 +354,9 @@ Models
* The new ``named`` parameter of :meth:`.QuerySet.values_list` allows fetching * The new ``named`` parameter of :meth:`.QuerySet.values_list` allows fetching
results as named tuples. results as named tuples.
* The new :class:`.FilteredRelation` class allows adding an ``ON`` clause to
querysets.
Pagination Pagination
~~~~~~~~~~ ~~~~~~~~~~

View File

View File

@ -0,0 +1,108 @@
from django.contrib.contenttypes.fields import (
GenericForeignKey, GenericRelation,
)
from django.contrib.contenttypes.models import ContentType
from django.db import models
class Author(models.Model):
name = models.CharField(max_length=50, unique=True)
favorite_books = models.ManyToManyField(
'Book',
related_name='preferred_by_authors',
related_query_name='preferred_by_authors',
)
content_type = models.ForeignKey(ContentType, models.CASCADE, null=True)
object_id = models.PositiveIntegerField(null=True)
content_object = GenericForeignKey()
def __str__(self):
return self.name
class Editor(models.Model):
name = models.CharField(max_length=255)
def __str__(self):
return self.name
class Book(models.Model):
AVAILABLE = 'available'
RESERVED = 'reserved'
RENTED = 'rented'
STATES = (
(AVAILABLE, 'Available'),
(RESERVED, 'reserved'),
(RENTED, 'Rented'),
)
title = models.CharField(max_length=255)
author = models.ForeignKey(
Author,
models.CASCADE,
related_name='books',
related_query_name='book',
)
editor = models.ForeignKey(Editor, models.CASCADE)
generic_author = GenericRelation(Author)
state = models.CharField(max_length=9, choices=STATES, default=AVAILABLE)
def __str__(self):
return self.title
class Borrower(models.Model):
name = models.CharField(max_length=50, unique=True)
def __str__(self):
return self.name
class Reservation(models.Model):
NEW = 'new'
STOPPED = 'stopped'
STATES = (
(NEW, 'New'),
(STOPPED, 'Stopped'),
)
borrower = models.ForeignKey(
Borrower,
models.CASCADE,
related_name='reservations',
related_query_name='reservation',
)
book = models.ForeignKey(
Book,
models.CASCADE,
related_name='reservations',
related_query_name='reservation',
)
state = models.CharField(max_length=7, choices=STATES, default=NEW)
def __str__(self):
return '-'.join((self.book.name, self.borrower.name, self.state))
class RentalSession(models.Model):
NEW = 'new'
STOPPED = 'stopped'
STATES = (
(NEW, 'New'),
(STOPPED, 'Stopped'),
)
borrower = models.ForeignKey(
Borrower,
models.CASCADE,
related_name='rental_sessions',
related_query_name='rental_session',
)
book = models.ForeignKey(
Book,
models.CASCADE,
related_name='rental_sessions',
related_query_name='rental_session',
)
state = models.CharField(max_length=7, choices=STATES, default=NEW)
def __str__(self):
return '-'.join((self.book.name, self.borrower.name, self.state))

View File

@ -0,0 +1,381 @@
from django.db import connection
from django.db.models import Case, Count, F, FilteredRelation, Q, When
from django.test import TestCase
from django.test.testcases import skipUnlessDBFeature
from .models import Author, Book, Borrower, Editor, RentalSession, Reservation
class FilteredRelationTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.author1 = Author.objects.create(name='Alice')
cls.author2 = Author.objects.create(name='Jane')
cls.editor_a = Editor.objects.create(name='a')
cls.editor_b = Editor.objects.create(name='b')
cls.book1 = Book.objects.create(
title='Poem by Alice',
editor=cls.editor_a,
author=cls.author1,
)
cls.book1.generic_author.set([cls.author2])
cls.book2 = Book.objects.create(
title='The book by Jane A',
editor=cls.editor_b,
author=cls.author2,
)
cls.book3 = Book.objects.create(
title='The book by Jane B',
editor=cls.editor_b,
author=cls.author2,
)
cls.book4 = Book.objects.create(
title='The book by Alice',
editor=cls.editor_a,
author=cls.author1,
)
cls.author1.favorite_books.add(cls.book2)
cls.author1.favorite_books.add(cls.book3)
def test_select_related(self):
qs = Author.objects.annotate(
book_join=FilteredRelation('book'),
).select_related('book_join__editor').order_by('pk', 'book_join__pk')
with self.assertNumQueries(1):
self.assertQuerysetEqual(qs, [
(self.author1, self.book1, self.editor_a, self.author1),
(self.author1, self.book4, self.editor_a, self.author1),
(self.author2, self.book2, self.editor_b, self.author2),
(self.author2, self.book3, self.editor_b, self.author2),
], lambda x: (x, x.book_join, x.book_join.editor, x.book_join.author))
def test_select_related_foreign_key(self):
qs = Book.objects.annotate(
author_join=FilteredRelation('author'),
).select_related('author_join').order_by('pk')
with self.assertNumQueries(1):
self.assertQuerysetEqual(qs, [
(self.book1, self.author1),
(self.book2, self.author2),
(self.book3, self.author2),
(self.book4, self.author1),
], lambda x: (x, x.author_join))
def test_without_join(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
),
[self.author1, self.author2]
)
def test_with_join(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False),
[self.author1]
)
def test_with_join_and_complex_condition(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_alice=FilteredRelation(
'book', condition=Q(
Q(book__title__iexact='poem by alice') |
Q(book__state=Book.RENTED)
),
),
).filter(book_alice__isnull=False),
[self.author1]
)
def test_internal_queryset_alias_mapping(self):
queryset = Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False)
self.assertIn(
'INNER JOIN {} book_alice ON'.format(connection.ops.quote_name('filtered_relation_book')),
str(queryset.query)
)
def test_with_multiple_filter(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_editor_a=FilteredRelation(
'book',
condition=Q(book__title__icontains='book', book__editor_id=self.editor_a.pk),
),
).filter(book_editor_a__isnull=False),
[self.author1]
)
def test_multiple_times(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_title_alice=FilteredRelation('book', condition=Q(book__title__icontains='alice')),
).filter(book_title_alice__isnull=False).filter(book_title_alice__isnull=False).distinct(),
[self.author1]
)
def test_exclude_relation_with_join(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_alice=FilteredRelation('book', condition=~Q(book__title__icontains='alice')),
).filter(book_alice__isnull=False).distinct(),
[self.author2]
)
def test_with_m2m(self):
qs = Author.objects.annotate(
favorite_books_written_by_jane=FilteredRelation(
'favorite_books', condition=Q(favorite_books__in=[self.book2]),
),
).filter(favorite_books_written_by_jane__isnull=False)
self.assertSequenceEqual(qs, [self.author1])
def test_with_m2m_deep(self):
qs = Author.objects.annotate(
favorite_books_written_by_jane=FilteredRelation(
'favorite_books', condition=Q(favorite_books__author=self.author2),
),
).filter(favorite_books_written_by_jane__title='The book by Jane B')
self.assertSequenceEqual(qs, [self.author1])
def test_with_m2m_multijoin(self):
qs = Author.objects.annotate(
favorite_books_written_by_jane=FilteredRelation(
'favorite_books', condition=Q(favorite_books__author=self.author2),
)
).filter(favorite_books_written_by_jane__editor__name='b').distinct()
self.assertSequenceEqual(qs, [self.author1])
def test_values_list(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False).values_list('book_alice__title', flat=True),
['Poem by Alice']
)
def test_values(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False).values(),
[{'id': self.author1.pk, 'name': 'Alice', 'content_type_id': None, 'object_id': None}]
)
def test_extra(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False).extra(where=['1 = 1']),
[self.author1]
)
@skipUnlessDBFeature('supports_select_union')
def test_union(self):
qs1 = Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False)
qs2 = Author.objects.annotate(
book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')),
).filter(book_jane__isnull=False)
self.assertSequenceEqual(qs1.union(qs2), [self.author1, self.author2])
@skipUnlessDBFeature('supports_select_intersection')
def test_intersection(self):
qs1 = Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False)
qs2 = Author.objects.annotate(
book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')),
).filter(book_jane__isnull=False)
self.assertSequenceEqual(qs1.intersection(qs2), [])
@skipUnlessDBFeature('supports_select_difference')
def test_difference(self):
qs1 = Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False)
qs2 = Author.objects.annotate(
book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')),
).filter(book_jane__isnull=False)
self.assertSequenceEqual(qs1.difference(qs2), [self.author1])
def test_select_for_update(self):
self.assertSequenceEqual(
Author.objects.annotate(
book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')),
).filter(book_jane__isnull=False).select_for_update(),
[self.author2]
)
def test_defer(self):
# One query for the list and one query for the deferred title.
with self.assertNumQueries(2):
self.assertQuerysetEqual(
Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False).select_related('book_alice').defer('book_alice__title'),
['Poem by Alice'], lambda author: author.book_alice.title
)
def test_only_not_supported(self):
msg = 'only() is not supported with FilteredRelation.'
with self.assertRaisesMessage(ValueError, msg):
Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False).select_related('book_alice').only('book_alice__state')
def test_as_subquery(self):
inner_qs = Author.objects.annotate(
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
).filter(book_alice__isnull=False)
qs = Author.objects.filter(id__in=inner_qs)
self.assertSequenceEqual(qs, [self.author1])
def test_with_foreign_key_error(self):
msg = (
"FilteredRelation's condition doesn't support nested relations "
"(got 'author__favorite_books__author')."
)
with self.assertRaisesMessage(ValueError, msg):
list(Book.objects.annotate(
alice_favorite_books=FilteredRelation(
'author__favorite_books',
condition=Q(author__favorite_books__author=self.author1),
)
))
def test_with_foreign_key_on_condition_error(self):
msg = (
"FilteredRelation's condition doesn't support nested relations "
"(got 'book__editor__name__icontains')."
)
with self.assertRaisesMessage(ValueError, msg):
list(Author.objects.annotate(
book_edited_by_b=FilteredRelation('book', condition=Q(book__editor__name__icontains='b')),
))
def test_with_empty_relation_name_error(self):
with self.assertRaisesMessage(ValueError, 'relation_name cannot be empty.'):
FilteredRelation('', condition=Q(blank=''))
def test_with_condition_as_expression_error(self):
msg = 'condition argument must be a Q() instance.'
expression = Case(
When(book__title__iexact='poem by alice', then=True), default=False,
)
with self.assertRaisesMessage(ValueError, msg):
FilteredRelation('book', condition=expression)
def test_with_prefetch_related(self):
msg = 'prefetch_related() is not supported with FilteredRelation.'
qs = Author.objects.annotate(
book_title_contains_b=FilteredRelation('book', condition=Q(book__title__icontains='b')),
).filter(
book_title_contains_b__isnull=False,
)
with self.assertRaisesMessage(ValueError, msg):
qs.prefetch_related('book_title_contains_b')
with self.assertRaisesMessage(ValueError, msg):
qs.prefetch_related('book_title_contains_b__editor')
def test_with_generic_foreign_key(self):
self.assertSequenceEqual(
Book.objects.annotate(
generic_authored_book=FilteredRelation(
'generic_author',
condition=Q(generic_author__isnull=False)
),
).filter(generic_authored_book__isnull=False),
[self.book1]
)
class FilteredRelationAggregationTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.author1 = Author.objects.create(name='Alice')
cls.editor_a = Editor.objects.create(name='a')
cls.book1 = Book.objects.create(
title='Poem by Alice',
editor=cls.editor_a,
author=cls.author1,
)
cls.borrower1 = Borrower.objects.create(name='Jenny')
cls.borrower2 = Borrower.objects.create(name='Kevin')
# borrower 1 reserves, rents, and returns book1.
Reservation.objects.create(
borrower=cls.borrower1,
book=cls.book1,
state=Reservation.STOPPED,
)
RentalSession.objects.create(
borrower=cls.borrower1,
book=cls.book1,
state=RentalSession.STOPPED,
)
# borrower2 reserves, rents, and returns book1.
Reservation.objects.create(
borrower=cls.borrower2,
book=cls.book1,
state=Reservation.STOPPED,
)
RentalSession.objects.create(
borrower=cls.borrower2,
book=cls.book1,
state=RentalSession.STOPPED,
)
def test_aggregate(self):
"""
filtered_relation() not only improves performance but also creates
correct results when aggregating with multiple LEFT JOINs.
Books can be reserved then rented by a borrower. Each reservation and
rental session are recorded with Reservation and RentalSession models.
Every time a reservation or a rental session is over, their state is
changed to 'stopped'.
Goal: Count number of books that are either currently reserved or
rented by borrower1 or available.
"""
qs = Book.objects.annotate(
is_reserved_or_rented_by=Case(
When(reservation__state=Reservation.NEW, then=F('reservation__borrower__pk')),
When(rental_session__state=RentalSession.NEW, then=F('rental_session__borrower__pk')),
default=None,
)
).filter(
Q(is_reserved_or_rented_by=self.borrower1.pk) | Q(state=Book.AVAILABLE)
).distinct()
self.assertEqual(qs.count(), 1)
# If count is equal to 1, the same aggregation should return in the
# same result but it returns 4.
self.assertSequenceEqual(qs.annotate(total=Count('pk')).values('total'), [{'total': 4}])
# With FilteredRelation, the result is as expected (1).
qs = Book.objects.annotate(
active_reservations=FilteredRelation(
'reservation', condition=Q(
reservation__state=Reservation.NEW,
reservation__borrower=self.borrower1,
)
),
).annotate(
active_rental_sessions=FilteredRelation(
'rental_session', condition=Q(
rental_session__state=RentalSession.NEW,
rental_session__borrower=self.borrower1,
)
),
).filter(
(Q(active_reservations__isnull=False) | Q(active_rental_sessions__isnull=False)) |
Q(state=Book.AVAILABLE)
).distinct()
self.assertEqual(qs.count(), 1)
self.assertSequenceEqual(qs.annotate(total=Count('pk')).values('total'), [{'total': 1}])

View File

@ -53,15 +53,31 @@ class StartsWithRelation(models.ForeignObject):
def get_joining_columns(self, reverse_join=False): def get_joining_columns(self, reverse_join=False):
return () return ()
def get_path_info(self): def get_path_info(self, filtered_relation=None):
to_opts = self.remote_field.model._meta to_opts = self.remote_field.model._meta
from_opts = self.model._meta from_opts = self.model._meta
return [PathInfo(from_opts, to_opts, (to_opts.pk,), self, False, False)] return [PathInfo(
from_opts=from_opts,
to_opts=to_opts,
target_fields=(to_opts.pk,),
join_field=self,
m2m=False,
direct=False,
filtered_relation=filtered_relation,
)]
def get_reverse_path_info(self): def get_reverse_path_info(self, filtered_relation=None):
to_opts = self.model._meta to_opts = self.model._meta
from_opts = self.remote_field.model._meta from_opts = self.remote_field.model._meta
return [PathInfo(from_opts, to_opts, (to_opts.pk,), self.remote_field, False, False)] return [PathInfo(
from_opts=from_opts,
to_opts=to_opts,
target_fields=(to_opts.pk,),
join_field=self.remote_field,
m2m=False,
direct=False,
filtered_relation=filtered_relation,
)]
def contribute_to_class(self, cls, name, private_only=False): def contribute_to_class(self, cls, name, private_only=False):
super().contribute_to_class(cls, name, private_only) super().contribute_to_class(cls, name, private_only)

View File

@ -1,4 +1,5 @@
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models import FilteredRelation
from django.test import SimpleTestCase, TestCase from django.test import SimpleTestCase, TestCase
from .models import ( from .models import (
@ -230,3 +231,8 @@ class ReverseSelectRelatedValidationTests(SimpleTestCase):
with self.assertRaisesMessage(FieldError, self.non_relational_error % ('username', fields)): with self.assertRaisesMessage(FieldError, self.non_relational_error % ('username', fields)):
list(User.objects.select_related('username')) list(User.objects.select_related('username'))
def test_reverse_related_validation_with_filtered_relation(self):
fields = 'userprofile, userstat, relation'
with self.assertRaisesMessage(FieldError, self.invalid_error % ('foobar', fields)):
list(User.objects.annotate(relation=FilteredRelation('userprofile')).select_related('foobar'))