Fixed #27332 -- Added FilteredRelation API for conditional join (ON clause) support.
Thanks Anssi Kääriäinen for contributing to the patch.
This commit is contained in:
parent
3f9d85d95c
commit
01d440fa1e
|
@ -348,7 +348,7 @@ class GenericRelation(ForeignObject):
|
|||
self.to_fields = [self.model._meta.pk.name]
|
||||
return [(self.remote_field.model._meta.get_field(self.object_id_field_name), self.model._meta.pk)]
|
||||
|
||||
def _get_path_info_with_parent(self):
|
||||
def _get_path_info_with_parent(self, filtered_relation):
|
||||
"""
|
||||
Return the path that joins the current model through any parent models.
|
||||
The idea is that if you have a GFK defined on a parent model then we
|
||||
|
@ -365,7 +365,15 @@ class GenericRelation(ForeignObject):
|
|||
opts = self.remote_field.model._meta.concrete_model._meta
|
||||
parent_opts = opts.get_field(self.object_id_field_name).model._meta
|
||||
target = parent_opts.pk
|
||||
path.append(PathInfo(self.model._meta, parent_opts, (target,), self.remote_field, True, False))
|
||||
path.append(PathInfo(
|
||||
from_opts=self.model._meta,
|
||||
to_opts=parent_opts,
|
||||
target_fields=(target,),
|
||||
join_field=self.remote_field,
|
||||
m2m=True,
|
||||
direct=False,
|
||||
filtered_relation=filtered_relation,
|
||||
))
|
||||
# Collect joins needed for the parent -> child chain. This is easiest
|
||||
# to do if we collect joins for the child -> parent chain and then
|
||||
# reverse the direction (call to reverse() and use of
|
||||
|
@ -380,19 +388,35 @@ class GenericRelation(ForeignObject):
|
|||
path.extend(field.remote_field.get_path_info())
|
||||
return path
|
||||
|
||||
def get_path_info(self):
|
||||
def get_path_info(self, filtered_relation=None):
|
||||
opts = self.remote_field.model._meta
|
||||
object_id_field = opts.get_field(self.object_id_field_name)
|
||||
if object_id_field.model != opts.model:
|
||||
return self._get_path_info_with_parent()
|
||||
return self._get_path_info_with_parent(filtered_relation)
|
||||
else:
|
||||
target = opts.pk
|
||||
return [PathInfo(self.model._meta, opts, (target,), self.remote_field, True, False)]
|
||||
return [PathInfo(
|
||||
from_opts=self.model._meta,
|
||||
to_opts=opts,
|
||||
target_fields=(target,),
|
||||
join_field=self.remote_field,
|
||||
m2m=True,
|
||||
direct=False,
|
||||
filtered_relation=filtered_relation,
|
||||
)]
|
||||
|
||||
def get_reverse_path_info(self):
|
||||
def get_reverse_path_info(self, filtered_relation=None):
|
||||
opts = self.model._meta
|
||||
from_opts = self.remote_field.model._meta
|
||||
return [PathInfo(from_opts, opts, (opts.pk,), self, not self.unique, False)]
|
||||
return [PathInfo(
|
||||
from_opts=from_opts,
|
||||
to_opts=opts,
|
||||
target_fields=(opts.pk,),
|
||||
join_field=self,
|
||||
m2m=not self.unique,
|
||||
direct=False,
|
||||
filtered_relation=filtered_relation,
|
||||
)]
|
||||
|
||||
def value_to_string(self, obj):
|
||||
qs = getattr(obj, self.name).all()
|
||||
|
|
|
@ -20,6 +20,7 @@ from django.db.models.manager import Manager
|
|||
from django.db.models.query import (
|
||||
Prefetch, Q, QuerySet, prefetch_related_objects,
|
||||
)
|
||||
from django.db.models.query_utils import FilteredRelation
|
||||
|
||||
# Imports that would create circular imports if sorted
|
||||
from django.db.models.base import DEFERRED, Model # isort:skip
|
||||
|
@ -69,6 +70,7 @@ __all__ += [
|
|||
'Window', 'WindowFrame',
|
||||
'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager',
|
||||
'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model',
|
||||
'FilteredRelation',
|
||||
'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField',
|
||||
'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel', 'permalink',
|
||||
]
|
||||
|
|
|
@ -697,18 +697,33 @@ class ForeignObject(RelatedField):
|
|||
"""
|
||||
return None
|
||||
|
||||
def get_path_info(self):
|
||||
def get_path_info(self, filtered_relation=None):
|
||||
"""Get path from this field to the related model."""
|
||||
opts = self.remote_field.model._meta
|
||||
from_opts = self.model._meta
|
||||
return [PathInfo(from_opts, opts, self.foreign_related_fields, self, False, True)]
|
||||
return [PathInfo(
|
||||
from_opts=from_opts,
|
||||
to_opts=opts,
|
||||
target_fields=self.foreign_related_fields,
|
||||
join_field=self,
|
||||
m2m=False,
|
||||
direct=True,
|
||||
filtered_relation=filtered_relation,
|
||||
)]
|
||||
|
||||
def get_reverse_path_info(self):
|
||||
def get_reverse_path_info(self, filtered_relation=None):
|
||||
"""Get path from the related model to this field's model."""
|
||||
opts = self.model._meta
|
||||
from_opts = self.remote_field.model._meta
|
||||
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.remote_field, not self.unique, False)]
|
||||
return pathinfos
|
||||
return [PathInfo(
|
||||
from_opts=from_opts,
|
||||
to_opts=opts,
|
||||
target_fields=(opts.pk,),
|
||||
join_field=self.remote_field,
|
||||
m2m=not self.unique,
|
||||
direct=False,
|
||||
filtered_relation=filtered_relation,
|
||||
)]
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache(maxsize=None)
|
||||
|
@ -861,12 +876,19 @@ class ForeignKey(ForeignObject):
|
|||
def target_field(self):
|
||||
return self.foreign_related_fields[0]
|
||||
|
||||
def get_reverse_path_info(self):
|
||||
def get_reverse_path_info(self, filtered_relation=None):
|
||||
"""Get path from the related model to this field's model."""
|
||||
opts = self.model._meta
|
||||
from_opts = self.remote_field.model._meta
|
||||
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.remote_field, not self.unique, False)]
|
||||
return pathinfos
|
||||
return [PathInfo(
|
||||
from_opts=from_opts,
|
||||
to_opts=opts,
|
||||
target_fields=(opts.pk,),
|
||||
join_field=self.remote_field,
|
||||
m2m=not self.unique,
|
||||
direct=False,
|
||||
filtered_relation=filtered_relation,
|
||||
)]
|
||||
|
||||
def validate(self, value, model_instance):
|
||||
if self.remote_field.parent_link:
|
||||
|
@ -1435,7 +1457,7 @@ class ManyToManyField(RelatedField):
|
|||
)
|
||||
return name, path, args, kwargs
|
||||
|
||||
def _get_path_info(self, direct=False):
|
||||
def _get_path_info(self, direct=False, filtered_relation=None):
|
||||
"""Called by both direct and indirect m2m traversal."""
|
||||
pathinfos = []
|
||||
int_model = self.remote_field.through
|
||||
|
@ -1443,10 +1465,10 @@ class ManyToManyField(RelatedField):
|
|||
linkfield2 = int_model._meta.get_field(self.m2m_reverse_field_name())
|
||||
if direct:
|
||||
join1infos = linkfield1.get_reverse_path_info()
|
||||
join2infos = linkfield2.get_path_info()
|
||||
join2infos = linkfield2.get_path_info(filtered_relation)
|
||||
else:
|
||||
join1infos = linkfield2.get_reverse_path_info()
|
||||
join2infos = linkfield1.get_path_info()
|
||||
join2infos = linkfield1.get_path_info(filtered_relation)
|
||||
|
||||
# Get join infos between the last model of join 1 and the first model
|
||||
# of join 2. Assume the only reason these may differ is due to model
|
||||
|
@ -1465,11 +1487,11 @@ class ManyToManyField(RelatedField):
|
|||
pathinfos.extend(join2infos)
|
||||
return pathinfos
|
||||
|
||||
def get_path_info(self):
|
||||
return self._get_path_info(direct=True)
|
||||
def get_path_info(self, filtered_relation=None):
|
||||
return self._get_path_info(direct=True, filtered_relation=filtered_relation)
|
||||
|
||||
def get_reverse_path_info(self):
|
||||
return self._get_path_info(direct=False)
|
||||
def get_reverse_path_info(self, filtered_relation=None):
|
||||
return self._get_path_info(direct=False, filtered_relation=filtered_relation)
|
||||
|
||||
def _get_m2m_db_table(self, opts):
|
||||
"""
|
||||
|
|
|
@ -163,8 +163,8 @@ class ForeignObjectRel(FieldCacheMixin):
|
|||
return self.related_name
|
||||
return opts.model_name + ('_set' if self.multiple else '')
|
||||
|
||||
def get_path_info(self):
|
||||
return self.field.get_reverse_path_info()
|
||||
def get_path_info(self, filtered_relation=None):
|
||||
return self.field.get_reverse_path_info(filtered_relation)
|
||||
|
||||
def get_cache_name(self):
|
||||
"""
|
||||
|
|
|
@ -632,7 +632,15 @@ class Options:
|
|||
final_field = opts.parents[int_model]
|
||||
targets = (final_field.remote_field.get_related_field(),)
|
||||
opts = int_model._meta
|
||||
path.append(PathInfo(final_field.model._meta, opts, targets, final_field, False, True))
|
||||
path.append(PathInfo(
|
||||
from_opts=final_field.model._meta,
|
||||
to_opts=opts,
|
||||
target_fields=targets,
|
||||
join_field=final_field,
|
||||
m2m=False,
|
||||
direct=True,
|
||||
filtered_relation=None,
|
||||
))
|
||||
return path
|
||||
|
||||
def get_path_from_parent(self, parent):
|
||||
|
|
|
@ -22,7 +22,7 @@ from django.db.models.deletion import Collector
|
|||
from django.db.models.expressions import F
|
||||
from django.db.models.fields import AutoField
|
||||
from django.db.models.functions import Trunc
|
||||
from django.db.models.query_utils import InvalidQuery, Q
|
||||
from django.db.models.query_utils import FilteredRelation, InvalidQuery, Q
|
||||
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
|
||||
from django.utils import timezone
|
||||
from django.utils.deprecation import RemovedInDjango30Warning
|
||||
|
@ -953,6 +953,12 @@ class QuerySet:
|
|||
if lookups == (None,):
|
||||
clone._prefetch_related_lookups = ()
|
||||
else:
|
||||
for lookup in lookups:
|
||||
if isinstance(lookup, Prefetch):
|
||||
lookup = lookup.prefetch_to
|
||||
lookup = lookup.split(LOOKUP_SEP, 1)[0]
|
||||
if lookup in self.query._filtered_relations:
|
||||
raise ValueError('prefetch_related() is not supported with FilteredRelation.')
|
||||
clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
|
||||
return clone
|
||||
|
||||
|
@ -984,7 +990,10 @@ class QuerySet:
|
|||
if alias in names:
|
||||
raise ValueError("The annotation '%s' conflicts with a field on "
|
||||
"the model." % alias)
|
||||
clone.query.add_annotation(annotation, alias, is_summary=False)
|
||||
if isinstance(annotation, FilteredRelation):
|
||||
clone.query.add_filtered_relation(annotation, alias)
|
||||
else:
|
||||
clone.query.add_annotation(annotation, alias, is_summary=False)
|
||||
|
||||
for alias, annotation in clone.query.annotations.items():
|
||||
if alias in annotations and annotation.contains_aggregate:
|
||||
|
@ -1060,6 +1069,10 @@ class QuerySet:
|
|||
# Can only pass None to defer(), not only(), as the rest option.
|
||||
# That won't stop people trying to do this, so let's be explicit.
|
||||
raise TypeError("Cannot pass None as an argument to only().")
|
||||
for field in fields:
|
||||
field = field.split(LOOKUP_SEP, 1)[0]
|
||||
if field in self.query._filtered_relations:
|
||||
raise ValueError('only() is not supported with FilteredRelation.')
|
||||
clone = self._chain()
|
||||
clone.query.add_immediate_loading(fields)
|
||||
return clone
|
||||
|
@ -1730,9 +1743,9 @@ class RelatedPopulator:
|
|||
# model's fields.
|
||||
# - related_populators: a list of RelatedPopulator instances if
|
||||
# select_related() descends to related models from this model.
|
||||
# - field, remote_field: the fields to use for populating the
|
||||
# internal fields cache. If remote_field is set then we also
|
||||
# set the reverse link.
|
||||
# - local_setter, remote_setter: Methods to set cached values on
|
||||
# the object being populated and on the remote object. Usually
|
||||
# these are Field.set_cached_value() methods.
|
||||
select_fields = klass_info['select_fields']
|
||||
from_parent = klass_info['from_parent']
|
||||
if not from_parent:
|
||||
|
@ -1751,16 +1764,8 @@ class RelatedPopulator:
|
|||
self.model_cls = klass_info['model']
|
||||
self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname)
|
||||
self.related_populators = get_related_populators(klass_info, select, self.db)
|
||||
reverse = klass_info['reverse']
|
||||
field = klass_info['field']
|
||||
self.remote_field = None
|
||||
if reverse:
|
||||
self.field = field.remote_field
|
||||
self.remote_field = field
|
||||
else:
|
||||
self.field = field
|
||||
if field.unique:
|
||||
self.remote_field = field.remote_field
|
||||
self.local_setter = klass_info['local_setter']
|
||||
self.remote_setter = klass_info['remote_setter']
|
||||
|
||||
def populate(self, row, from_obj):
|
||||
if self.reorder_for_init:
|
||||
|
@ -1774,9 +1779,9 @@ class RelatedPopulator:
|
|||
if self.related_populators:
|
||||
for rel_iter in self.related_populators:
|
||||
rel_iter.populate(row, obj)
|
||||
if self.remote_field:
|
||||
self.remote_field.set_cached_value(obj, from_obj)
|
||||
self.field.set_cached_value(from_obj, obj)
|
||||
self.local_setter(from_obj, obj)
|
||||
if obj is not None:
|
||||
self.remote_setter(obj, from_obj)
|
||||
|
||||
|
||||
def get_related_populators(klass_info, select, db):
|
||||
|
|
|
@ -16,7 +16,7 @@ from django.utils import tree
|
|||
# PathInfo is used when converting lookups (fk__somecol). The contents
|
||||
# describe the relation in Model terms (model Options and Fields for both
|
||||
# sides of the relation. The join_field is the field backing the relation.
|
||||
PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct')
|
||||
PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct filtered_relation')
|
||||
|
||||
|
||||
class InvalidQuery(Exception):
|
||||
|
@ -291,3 +291,44 @@ def check_rel_lookup_compatibility(model, target_opts, field):
|
|||
check(target_opts) or
|
||||
(getattr(field, 'primary_key', False) and check(field.model._meta))
|
||||
)
|
||||
|
||||
|
||||
class FilteredRelation:
|
||||
"""Specify custom filtering in the ON clause of SQL joins."""
|
||||
|
||||
def __init__(self, relation_name, *, condition=Q()):
|
||||
if not relation_name:
|
||||
raise ValueError('relation_name cannot be empty.')
|
||||
self.relation_name = relation_name
|
||||
self.alias = None
|
||||
if not isinstance(condition, Q):
|
||||
raise ValueError('condition argument must be a Q() instance.')
|
||||
self.condition = condition
|
||||
self.path = []
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
isinstance(other, self.__class__) and
|
||||
self.relation_name == other.relation_name and
|
||||
self.alias == other.alias and
|
||||
self.condition == other.condition
|
||||
)
|
||||
|
||||
def clone(self):
|
||||
clone = FilteredRelation(self.relation_name, condition=self.condition)
|
||||
clone.alias = self.alias
|
||||
clone.path = self.path[:]
|
||||
return clone
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
"""
|
||||
QuerySet.annotate() only accepts expression-like arguments
|
||||
(with a resolve_expression() method).
|
||||
"""
|
||||
raise NotImplementedError('FilteredRelation.resolve_expression() is unused.')
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Resolve the condition in Join.filtered_relation.
|
||||
query = compiler.query
|
||||
where = query.build_filtered_relation_q(self.condition, reuse=set(self.path))
|
||||
return compiler.compile(where)
|
||||
|
|
|
@ -702,7 +702,7 @@ class SQLCompiler:
|
|||
"""
|
||||
result = []
|
||||
params = []
|
||||
for alias in self.query.alias_map:
|
||||
for alias in tuple(self.query.alias_map):
|
||||
if not self.query.alias_refcount[alias]:
|
||||
continue
|
||||
try:
|
||||
|
@ -737,7 +737,7 @@ class SQLCompiler:
|
|||
f.field.related_query_name()
|
||||
for f in opts.related_objects if f.field.unique
|
||||
)
|
||||
return chain(direct_choices, reverse_choices)
|
||||
return chain(direct_choices, reverse_choices, self.query._filtered_relations)
|
||||
|
||||
related_klass_infos = []
|
||||
if not restricted and cur_depth > self.query.max_depth:
|
||||
|
@ -788,7 +788,8 @@ class SQLCompiler:
|
|||
klass_info = {
|
||||
'model': f.remote_field.model,
|
||||
'field': f,
|
||||
'reverse': False,
|
||||
'local_setter': f.set_cached_value,
|
||||
'remote_setter': f.remote_field.set_cached_value if f.unique else lambda x, y: None,
|
||||
'from_parent': False,
|
||||
}
|
||||
related_klass_infos.append(klass_info)
|
||||
|
@ -825,7 +826,8 @@ class SQLCompiler:
|
|||
klass_info = {
|
||||
'model': model,
|
||||
'field': f,
|
||||
'reverse': True,
|
||||
'local_setter': f.remote_field.set_cached_value,
|
||||
'remote_setter': f.set_cached_value,
|
||||
'from_parent': from_parent,
|
||||
}
|
||||
related_klass_infos.append(klass_info)
|
||||
|
@ -842,6 +844,47 @@ class SQLCompiler:
|
|||
next, restricted)
|
||||
get_related_klass_infos(klass_info, next_klass_infos)
|
||||
fields_not_found = set(requested).difference(fields_found)
|
||||
for name in list(requested):
|
||||
# Filtered relations work only on the topmost level.
|
||||
if cur_depth > 1:
|
||||
break
|
||||
if name in self.query._filtered_relations:
|
||||
fields_found.add(name)
|
||||
f, _, join_opts, joins, _ = self.query.setup_joins([name], opts, root_alias)
|
||||
model = join_opts.model
|
||||
alias = joins[-1]
|
||||
from_parent = issubclass(model, opts.model) and model is not opts.model
|
||||
|
||||
def local_setter(obj, from_obj):
|
||||
f.remote_field.set_cached_value(from_obj, obj)
|
||||
|
||||
def remote_setter(obj, from_obj):
|
||||
setattr(from_obj, name, obj)
|
||||
klass_info = {
|
||||
'model': model,
|
||||
'field': f,
|
||||
'local_setter': local_setter,
|
||||
'remote_setter': remote_setter,
|
||||
'from_parent': from_parent,
|
||||
}
|
||||
related_klass_infos.append(klass_info)
|
||||
select_fields = []
|
||||
columns = self.get_default_columns(
|
||||
start_alias=alias, opts=model._meta,
|
||||
from_parent=opts.model,
|
||||
)
|
||||
for col in columns:
|
||||
select_fields.append(len(select))
|
||||
select.append((col, None))
|
||||
klass_info['select_fields'] = select_fields
|
||||
next_requested = requested.get(name, {})
|
||||
next_klass_infos = self.get_related_selections(
|
||||
select, opts=model._meta, root_alias=alias,
|
||||
cur_depth=cur_depth + 1, requested=next_requested,
|
||||
restricted=restricted,
|
||||
)
|
||||
get_related_klass_infos(klass_info, next_klass_infos)
|
||||
fields_not_found = set(requested).difference(fields_found)
|
||||
if fields_not_found:
|
||||
invalid_fields = ("'%s'" % s for s in fields_not_found)
|
||||
raise FieldError(
|
||||
|
|
|
@ -41,7 +41,7 @@ class Join:
|
|||
- relabeled_clone()
|
||||
"""
|
||||
def __init__(self, table_name, parent_alias, table_alias, join_type,
|
||||
join_field, nullable):
|
||||
join_field, nullable, filtered_relation=None):
|
||||
# Join table
|
||||
self.table_name = table_name
|
||||
self.parent_alias = parent_alias
|
||||
|
@ -56,6 +56,7 @@ class Join:
|
|||
self.join_field = join_field
|
||||
# Is this join nullabled?
|
||||
self.nullable = nullable
|
||||
self.filtered_relation = filtered_relation
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
"""
|
||||
|
@ -85,7 +86,11 @@ class Join:
|
|||
extra_sql, extra_params = compiler.compile(extra_cond)
|
||||
join_conditions.append('(%s)' % extra_sql)
|
||||
params.extend(extra_params)
|
||||
|
||||
if self.filtered_relation:
|
||||
extra_sql, extra_params = compiler.compile(self.filtered_relation)
|
||||
if extra_sql:
|
||||
join_conditions.append('(%s)' % extra_sql)
|
||||
params.extend(extra_params)
|
||||
if not join_conditions:
|
||||
# This might be a rel on the other end of an actual declared field.
|
||||
declared_field = getattr(self.join_field, 'field', self.join_field)
|
||||
|
@ -101,18 +106,27 @@ class Join:
|
|||
def relabeled_clone(self, change_map):
|
||||
new_parent_alias = change_map.get(self.parent_alias, self.parent_alias)
|
||||
new_table_alias = change_map.get(self.table_alias, self.table_alias)
|
||||
if self.filtered_relation is not None:
|
||||
filtered_relation = self.filtered_relation.clone()
|
||||
filtered_relation.path = [change_map.get(p, p) for p in self.filtered_relation.path]
|
||||
else:
|
||||
filtered_relation = None
|
||||
return self.__class__(
|
||||
self.table_name, new_parent_alias, new_table_alias, self.join_type,
|
||||
self.join_field, self.nullable)
|
||||
self.join_field, self.nullable, filtered_relation=filtered_relation,
|
||||
)
|
||||
|
||||
def equals(self, other, with_filtered_relation):
|
||||
return (
|
||||
isinstance(other, self.__class__) and
|
||||
self.table_name == other.table_name and
|
||||
self.parent_alias == other.parent_alias and
|
||||
self.join_field == other.join_field and
|
||||
(not with_filtered_relation or self.filtered_relation == other.filtered_relation)
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, self.__class__):
|
||||
return (
|
||||
self.table_name == other.table_name and
|
||||
self.parent_alias == other.parent_alias and
|
||||
self.join_field == other.join_field
|
||||
)
|
||||
return False
|
||||
return self.equals(other, with_filtered_relation=True)
|
||||
|
||||
def demote(self):
|
||||
new = self.relabeled_clone({})
|
||||
|
@ -134,6 +148,7 @@ class BaseTable:
|
|||
"""
|
||||
join_type = None
|
||||
parent_alias = None
|
||||
filtered_relation = None
|
||||
|
||||
def __init__(self, table_name, alias):
|
||||
self.table_name = table_name
|
||||
|
@ -146,3 +161,10 @@ class BaseTable:
|
|||
|
||||
def relabeled_clone(self, change_map):
|
||||
return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias))
|
||||
|
||||
def equals(self, other, with_filtered_relation):
|
||||
return (
|
||||
isinstance(self, other.__class__) and
|
||||
self.table_name == other.table_name and
|
||||
self.table_alias == other.table_alias
|
||||
)
|
||||
|
|
|
@ -45,6 +45,14 @@ def get_field_names_from_opts(opts):
|
|||
))
|
||||
|
||||
|
||||
def get_children_from_q(q):
|
||||
for child in q.children:
|
||||
if isinstance(child, Node):
|
||||
yield from get_children_from_q(child)
|
||||
else:
|
||||
yield child
|
||||
|
||||
|
||||
JoinInfo = namedtuple(
|
||||
'JoinInfo',
|
||||
('final_field', 'targets', 'opts', 'joins', 'path')
|
||||
|
@ -210,6 +218,8 @@ class Query:
|
|||
# load.
|
||||
self.deferred_loading = (frozenset(), True)
|
||||
|
||||
self._filtered_relations = {}
|
||||
|
||||
@property
|
||||
def extra(self):
|
||||
if self._extra is None:
|
||||
|
@ -311,6 +321,7 @@ class Query:
|
|||
if 'subq_aliases' in self.__dict__:
|
||||
obj.subq_aliases = self.subq_aliases.copy()
|
||||
obj.used_aliases = self.used_aliases.copy()
|
||||
obj._filtered_relations = self._filtered_relations.copy()
|
||||
# Clear the cached_property
|
||||
try:
|
||||
del obj.base_table
|
||||
|
@ -624,6 +635,8 @@ class Query:
|
|||
opts = orig_opts
|
||||
for name in parts[:-1]:
|
||||
old_model = cur_model
|
||||
if name in self._filtered_relations:
|
||||
name = self._filtered_relations[name].relation_name
|
||||
source = opts.get_field(name)
|
||||
if is_reverse_o2o(source):
|
||||
cur_model = source.related_model
|
||||
|
@ -684,7 +697,7 @@ class Query:
|
|||
for model, values in seen.items():
|
||||
callback(target, model, values)
|
||||
|
||||
def table_alias(self, table_name, create=False):
|
||||
def table_alias(self, table_name, create=False, filtered_relation=None):
|
||||
"""
|
||||
Return a table alias for the given table_name and whether this is a
|
||||
new alias or not.
|
||||
|
@ -704,8 +717,8 @@ class Query:
|
|||
alias_list.append(alias)
|
||||
else:
|
||||
# The first occurrence of a table uses the table name directly.
|
||||
alias = table_name
|
||||
self.table_map[alias] = [alias]
|
||||
alias = filtered_relation.alias if filtered_relation is not None else table_name
|
||||
self.table_map[table_name] = [alias]
|
||||
self.alias_refcount[alias] = 1
|
||||
return alias, True
|
||||
|
||||
|
@ -881,7 +894,7 @@ class Query:
|
|||
"""
|
||||
return len([1 for count in self.alias_refcount.values() if count])
|
||||
|
||||
def join(self, join, reuse=None):
|
||||
def join(self, join, reuse=None, reuse_with_filtered_relation=False):
|
||||
"""
|
||||
Return an alias for the 'join', either reusing an existing alias for
|
||||
that join or creating a new one. 'join' is either a
|
||||
|
@ -890,18 +903,29 @@ class Query:
|
|||
The 'reuse' parameter can be either None which means all joins are
|
||||
reusable, or it can be a set containing the aliases that can be reused.
|
||||
|
||||
The 'reuse_with_filtered_relation' parameter is used when computing
|
||||
FilteredRelation instances.
|
||||
|
||||
A join is always created as LOUTER if the lhs alias is LOUTER to make
|
||||
sure chains like t1 LOUTER t2 INNER t3 aren't generated. All new
|
||||
joins are created as LOUTER if the join is nullable.
|
||||
"""
|
||||
reuse = [a for a, j in self.alias_map.items()
|
||||
if (reuse is None or a in reuse) and j == join]
|
||||
if reuse:
|
||||
self.ref_alias(reuse[0])
|
||||
return reuse[0]
|
||||
if reuse_with_filtered_relation and reuse:
|
||||
reuse_aliases = [
|
||||
a for a, j in self.alias_map.items()
|
||||
if a in reuse and j.equals(join, with_filtered_relation=False)
|
||||
]
|
||||
else:
|
||||
reuse_aliases = [
|
||||
a for a, j in self.alias_map.items()
|
||||
if (reuse is None or a in reuse) and j == join
|
||||
]
|
||||
if reuse_aliases:
|
||||
self.ref_alias(reuse_aliases[0])
|
||||
return reuse_aliases[0]
|
||||
|
||||
# No reuse is possible, so we need a new alias.
|
||||
alias, _ = self.table_alias(join.table_name, create=True)
|
||||
alias, _ = self.table_alias(join.table_name, create=True, filtered_relation=join.filtered_relation)
|
||||
if join.join_type:
|
||||
if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable:
|
||||
join_type = LOUTER
|
||||
|
@ -1090,7 +1114,8 @@ class Query:
|
|||
(name, lhs.output_field.__class__.__name__))
|
||||
|
||||
def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
|
||||
can_reuse=None, allow_joins=True, split_subq=True):
|
||||
can_reuse=None, allow_joins=True, split_subq=True,
|
||||
reuse_with_filtered_relation=False):
|
||||
"""
|
||||
Build a WhereNode for a single filter clause but don't add it
|
||||
to this Query. Query.add_q() will then add this filter to the where
|
||||
|
@ -1112,6 +1137,9 @@ class Query:
|
|||
|
||||
The 'can_reuse' is a set of reusable joins for multijoins.
|
||||
|
||||
If 'reuse_with_filtered_relation' is True, then only joins in can_reuse
|
||||
will be reused.
|
||||
|
||||
The method will create a filter clause that can be added to the current
|
||||
query. However, if the filter isn't added to the query then the caller
|
||||
is responsible for unreffing the joins used.
|
||||
|
@ -1147,7 +1175,10 @@ class Query:
|
|||
allow_many = not branch_negated or not split_subq
|
||||
|
||||
try:
|
||||
join_info = self.setup_joins(parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many)
|
||||
join_info = self.setup_joins(
|
||||
parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many,
|
||||
reuse_with_filtered_relation=reuse_with_filtered_relation,
|
||||
)
|
||||
|
||||
# Prevent iterator from being consumed by check_related_objects()
|
||||
if isinstance(value, Iterator):
|
||||
|
@ -1250,6 +1281,41 @@ class Query:
|
|||
needed_inner = joinpromoter.update_join_types(self)
|
||||
return target_clause, needed_inner
|
||||
|
||||
def build_filtered_relation_q(self, q_object, reuse, branch_negated=False, current_negated=False):
|
||||
"""Add a FilteredRelation object to the current filter."""
|
||||
connector = q_object.connector
|
||||
current_negated ^= q_object.negated
|
||||
branch_negated = branch_negated or q_object.negated
|
||||
target_clause = self.where_class(connector=connector, negated=q_object.negated)
|
||||
for child in q_object.children:
|
||||
if isinstance(child, Node):
|
||||
child_clause = self.build_filtered_relation_q(
|
||||
child, reuse=reuse, branch_negated=branch_negated,
|
||||
current_negated=current_negated,
|
||||
)
|
||||
else:
|
||||
child_clause, _ = self.build_filter(
|
||||
child, can_reuse=reuse, branch_negated=branch_negated,
|
||||
current_negated=current_negated,
|
||||
allow_joins=True, split_subq=False,
|
||||
reuse_with_filtered_relation=True,
|
||||
)
|
||||
target_clause.add(child_clause, connector)
|
||||
return target_clause
|
||||
|
||||
def add_filtered_relation(self, filtered_relation, alias):
|
||||
filtered_relation.alias = alias
|
||||
lookups = dict(get_children_from_q(filtered_relation.condition))
|
||||
for lookup in chain((filtered_relation.relation_name,), lookups):
|
||||
lookup_parts, field_parts, _ = self.solve_lookup_type(lookup)
|
||||
shift = 2 if not lookup_parts else 1
|
||||
if len(field_parts) > (shift + len(lookup_parts)):
|
||||
raise ValueError(
|
||||
"FilteredRelation's condition doesn't support nested "
|
||||
"relations (got %r)." % lookup
|
||||
)
|
||||
self._filtered_relations[filtered_relation.alias] = filtered_relation
|
||||
|
||||
def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False):
|
||||
"""
|
||||
Walk the list of names and turns them into PathInfo tuples. A single
|
||||
|
@ -1272,12 +1338,15 @@ class Query:
|
|||
name = opts.pk.name
|
||||
|
||||
field = None
|
||||
filtered_relation = None
|
||||
try:
|
||||
field = opts.get_field(name)
|
||||
except FieldDoesNotExist:
|
||||
if name in self.annotation_select:
|
||||
field = self.annotation_select[name].output_field
|
||||
|
||||
elif name in self._filtered_relations and pos == 0:
|
||||
filtered_relation = self._filtered_relations[name]
|
||||
field = opts.get_field(filtered_relation.relation_name)
|
||||
if field is not None:
|
||||
# Fields that contain one-to-many relations with a generic
|
||||
# model (like a GenericForeignKey) cannot generate reverse
|
||||
|
@ -1301,7 +1370,10 @@ class Query:
|
|||
pos -= 1
|
||||
if pos == -1 or fail_on_missing:
|
||||
field_names = list(get_field_names_from_opts(opts))
|
||||
available = sorted(field_names + list(self.annotation_select))
|
||||
available = sorted(
|
||||
field_names + list(self.annotation_select) +
|
||||
list(self._filtered_relations)
|
||||
)
|
||||
raise FieldError("Cannot resolve keyword '%s' into field. "
|
||||
"Choices are: %s" % (name, ", ".join(available)))
|
||||
break
|
||||
|
@ -1315,7 +1387,7 @@ class Query:
|
|||
cur_names_with_path[1].extend(path_to_parent)
|
||||
opts = path_to_parent[-1].to_opts
|
||||
if hasattr(field, 'get_path_info'):
|
||||
pathinfos = field.get_path_info()
|
||||
pathinfos = field.get_path_info(filtered_relation)
|
||||
if not allow_many:
|
||||
for inner_pos, p in enumerate(pathinfos):
|
||||
if p.m2m:
|
||||
|
@ -1340,7 +1412,8 @@ class Query:
|
|||
break
|
||||
return path, final_field, targets, names[pos + 1:]
|
||||
|
||||
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
|
||||
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True,
|
||||
reuse_with_filtered_relation=False):
|
||||
"""
|
||||
Compute the necessary table joins for the passage through the fields
|
||||
given in 'names'. 'opts' is the Options class for the current model
|
||||
|
@ -1352,6 +1425,9 @@ class Query:
|
|||
that can be reused. Note that non-reverse foreign keys are always
|
||||
reusable when using setup_joins().
|
||||
|
||||
The 'reuse_with_filtered_relation' can be used to force 'can_reuse'
|
||||
parameter and force the relation on the given connections.
|
||||
|
||||
If 'allow_many' is False, then any reverse foreign key seen will
|
||||
generate a MultiJoin exception.
|
||||
|
||||
|
@ -1374,15 +1450,29 @@ class Query:
|
|||
# joins at this stage - we will need the information about join type
|
||||
# of the trimmed joins.
|
||||
for join in path:
|
||||
if join.filtered_relation:
|
||||
filtered_relation = join.filtered_relation.clone()
|
||||
table_alias = filtered_relation.alias
|
||||
else:
|
||||
filtered_relation = None
|
||||
table_alias = None
|
||||
opts = join.to_opts
|
||||
if join.direct:
|
||||
nullable = self.is_nullable(join.join_field)
|
||||
else:
|
||||
nullable = True
|
||||
connection = Join(opts.db_table, alias, None, INNER, join.join_field, nullable)
|
||||
reuse = can_reuse if join.m2m else None
|
||||
alias = self.join(connection, reuse=reuse)
|
||||
connection = Join(
|
||||
opts.db_table, alias, table_alias, INNER, join.join_field,
|
||||
nullable, filtered_relation=filtered_relation,
|
||||
)
|
||||
reuse = can_reuse if join.m2m or reuse_with_filtered_relation else None
|
||||
alias = self.join(
|
||||
connection, reuse=reuse,
|
||||
reuse_with_filtered_relation=reuse_with_filtered_relation,
|
||||
)
|
||||
joins.append(alias)
|
||||
if filtered_relation:
|
||||
filtered_relation.path = joins[:]
|
||||
return JoinInfo(final_field, targets, opts, joins, path)
|
||||
|
||||
def trim_joins(self, targets, joins, path):
|
||||
|
@ -1402,6 +1492,8 @@ class Query:
|
|||
for pos, info in enumerate(reversed(path)):
|
||||
if len(joins) == 1 or not info.direct:
|
||||
break
|
||||
if info.filtered_relation:
|
||||
break
|
||||
join_targets = {t.column for t in info.join_field.foreign_related_fields}
|
||||
cur_targets = {t.column for t in targets}
|
||||
if not cur_targets.issubset(join_targets):
|
||||
|
@ -1425,7 +1517,7 @@ class Query:
|
|||
return self.annotation_select[name]
|
||||
else:
|
||||
field_list = name.split(LOOKUP_SEP)
|
||||
join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), reuse)
|
||||
join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), can_reuse=reuse)
|
||||
targets, _, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path)
|
||||
if len(targets) > 1:
|
||||
raise FieldError("Referencing multicolumn fields with F() objects "
|
||||
|
@ -1602,7 +1694,10 @@ class Query:
|
|||
# from the model on which the lookup failed.
|
||||
raise
|
||||
else:
|
||||
names = sorted(list(get_field_names_from_opts(opts)) + list(self.extra) + list(self.annotation_select))
|
||||
names = sorted(
|
||||
list(get_field_names_from_opts(opts)) + list(self.extra) +
|
||||
list(self.annotation_select) + list(self._filtered_relations)
|
||||
)
|
||||
raise FieldError("Cannot resolve keyword %r into field. "
|
||||
"Choices are: %s" % (name, ", ".join(names)))
|
||||
|
||||
|
|
|
@ -3318,3 +3318,60 @@ lookups or :class:`Prefetch` objects you want to prefetch for. For example::
|
|||
>>> from django.db.models import prefetch_related_objects
|
||||
>>> restaurants = fetch_top_restaurants_from_cache() # A list of Restaurants
|
||||
>>> prefetch_related_objects(restaurants, 'pizzas__toppings')
|
||||
|
||||
``FilteredRelation()`` objects
|
||||
------------------------------
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. class:: FilteredRelation(relation_name, *, condition=Q())
|
||||
|
||||
.. attribute:: FilteredRelation.relation_name
|
||||
|
||||
The name of the field on which you'd like to filter the relation.
|
||||
|
||||
.. attribute:: FilteredRelation.condition
|
||||
|
||||
A :class:`~django.db.models.Q` object to control the filtering.
|
||||
|
||||
``FilteredRelation`` is used with :meth:`~.QuerySet.annotate()` to create an
|
||||
``ON`` clause when a ``JOIN`` is performed. It doesn't act on the default
|
||||
relationship but on the annotation name (``pizzas_vegetarian`` in example
|
||||
below).
|
||||
|
||||
For example, to find restaurants that have vegetarian pizzas with
|
||||
``'mozzarella'`` in the name::
|
||||
|
||||
>>> from django.db.models import FilteredRelation, Q
|
||||
>>> Restaurant.objects.annotate(
|
||||
... pizzas_vegetarian=FilteredRelation(
|
||||
... 'pizzas', condition=Q(pizzas__vegetarian=True),
|
||||
... ),
|
||||
... ).filter(pizzas_vegetarian__name__icontains='mozzarella')
|
||||
|
||||
If there are a large number of pizzas, this queryset performs better than::
|
||||
|
||||
>>> Restaurant.objects.filter(
|
||||
... pizzas__vegetarian=True,
|
||||
... pizzas__name__icontains='mozzarella',
|
||||
... )
|
||||
|
||||
because the filtering in the ``WHERE`` clause of the first queryset will only
|
||||
operate on vegetarian pizzas.
|
||||
|
||||
``FilteredRelation`` doesn't support:
|
||||
|
||||
* Conditions that span relational fields. For example::
|
||||
|
||||
>>> Restaurant.objects.annotate(
|
||||
... pizzas_with_toppings_startswith_n=FilteredRelation(
|
||||
... 'pizzas__toppings',
|
||||
... condition=Q(pizzas__toppings__name__startswith='n'),
|
||||
... ),
|
||||
... )
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: FilteredRelation's condition doesn't support nested relations (got 'pizzas__toppings__name__startswith').
|
||||
* :meth:`.QuerySet.only` and :meth:`~.QuerySet.prefetch_related`.
|
||||
* A :class:`~django.contrib.contenttypes.fields.GenericForeignKey`
|
||||
inherited from a parent model.
|
||||
|
|
|
@ -354,6 +354,9 @@ Models
|
|||
* The new ``named`` parameter of :meth:`.QuerySet.values_list` allows fetching
|
||||
results as named tuples.
|
||||
|
||||
* The new :class:`.FilteredRelation` class allows adding an ``ON`` clause to
|
||||
querysets.
|
||||
|
||||
Pagination
|
||||
~~~~~~~~~~
|
||||
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
from django.contrib.contenttypes.fields import (
|
||||
GenericForeignKey, GenericRelation,
|
||||
)
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db import models
|
||||
|
||||
|
||||
class Author(models.Model):
|
||||
name = models.CharField(max_length=50, unique=True)
|
||||
favorite_books = models.ManyToManyField(
|
||||
'Book',
|
||||
related_name='preferred_by_authors',
|
||||
related_query_name='preferred_by_authors',
|
||||
)
|
||||
content_type = models.ForeignKey(ContentType, models.CASCADE, null=True)
|
||||
object_id = models.PositiveIntegerField(null=True)
|
||||
content_object = GenericForeignKey()
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class Editor(models.Model):
|
||||
name = models.CharField(max_length=255)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class Book(models.Model):
|
||||
AVAILABLE = 'available'
|
||||
RESERVED = 'reserved'
|
||||
RENTED = 'rented'
|
||||
STATES = (
|
||||
(AVAILABLE, 'Available'),
|
||||
(RESERVED, 'reserved'),
|
||||
(RENTED, 'Rented'),
|
||||
)
|
||||
title = models.CharField(max_length=255)
|
||||
author = models.ForeignKey(
|
||||
Author,
|
||||
models.CASCADE,
|
||||
related_name='books',
|
||||
related_query_name='book',
|
||||
)
|
||||
editor = models.ForeignKey(Editor, models.CASCADE)
|
||||
generic_author = GenericRelation(Author)
|
||||
state = models.CharField(max_length=9, choices=STATES, default=AVAILABLE)
|
||||
|
||||
def __str__(self):
|
||||
return self.title
|
||||
|
||||
|
||||
class Borrower(models.Model):
|
||||
name = models.CharField(max_length=50, unique=True)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class Reservation(models.Model):
|
||||
NEW = 'new'
|
||||
STOPPED = 'stopped'
|
||||
STATES = (
|
||||
(NEW, 'New'),
|
||||
(STOPPED, 'Stopped'),
|
||||
)
|
||||
borrower = models.ForeignKey(
|
||||
Borrower,
|
||||
models.CASCADE,
|
||||
related_name='reservations',
|
||||
related_query_name='reservation',
|
||||
)
|
||||
book = models.ForeignKey(
|
||||
Book,
|
||||
models.CASCADE,
|
||||
related_name='reservations',
|
||||
related_query_name='reservation',
|
||||
)
|
||||
state = models.CharField(max_length=7, choices=STATES, default=NEW)
|
||||
|
||||
def __str__(self):
|
||||
return '-'.join((self.book.name, self.borrower.name, self.state))
|
||||
|
||||
|
||||
class RentalSession(models.Model):
|
||||
NEW = 'new'
|
||||
STOPPED = 'stopped'
|
||||
STATES = (
|
||||
(NEW, 'New'),
|
||||
(STOPPED, 'Stopped'),
|
||||
)
|
||||
borrower = models.ForeignKey(
|
||||
Borrower,
|
||||
models.CASCADE,
|
||||
related_name='rental_sessions',
|
||||
related_query_name='rental_session',
|
||||
)
|
||||
book = models.ForeignKey(
|
||||
Book,
|
||||
models.CASCADE,
|
||||
related_name='rental_sessions',
|
||||
related_query_name='rental_session',
|
||||
)
|
||||
state = models.CharField(max_length=7, choices=STATES, default=NEW)
|
||||
|
||||
def __str__(self):
|
||||
return '-'.join((self.book.name, self.borrower.name, self.state))
|
|
@ -0,0 +1,381 @@
|
|||
from django.db import connection
|
||||
from django.db.models import Case, Count, F, FilteredRelation, Q, When
|
||||
from django.test import TestCase
|
||||
from django.test.testcases import skipUnlessDBFeature
|
||||
|
||||
from .models import Author, Book, Borrower, Editor, RentalSession, Reservation
|
||||
|
||||
|
||||
class FilteredRelationTests(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
cls.author1 = Author.objects.create(name='Alice')
|
||||
cls.author2 = Author.objects.create(name='Jane')
|
||||
cls.editor_a = Editor.objects.create(name='a')
|
||||
cls.editor_b = Editor.objects.create(name='b')
|
||||
cls.book1 = Book.objects.create(
|
||||
title='Poem by Alice',
|
||||
editor=cls.editor_a,
|
||||
author=cls.author1,
|
||||
)
|
||||
cls.book1.generic_author.set([cls.author2])
|
||||
cls.book2 = Book.objects.create(
|
||||
title='The book by Jane A',
|
||||
editor=cls.editor_b,
|
||||
author=cls.author2,
|
||||
)
|
||||
cls.book3 = Book.objects.create(
|
||||
title='The book by Jane B',
|
||||
editor=cls.editor_b,
|
||||
author=cls.author2,
|
||||
)
|
||||
cls.book4 = Book.objects.create(
|
||||
title='The book by Alice',
|
||||
editor=cls.editor_a,
|
||||
author=cls.author1,
|
||||
)
|
||||
cls.author1.favorite_books.add(cls.book2)
|
||||
cls.author1.favorite_books.add(cls.book3)
|
||||
|
||||
def test_select_related(self):
|
||||
qs = Author.objects.annotate(
|
||||
book_join=FilteredRelation('book'),
|
||||
).select_related('book_join__editor').order_by('pk', 'book_join__pk')
|
||||
with self.assertNumQueries(1):
|
||||
self.assertQuerysetEqual(qs, [
|
||||
(self.author1, self.book1, self.editor_a, self.author1),
|
||||
(self.author1, self.book4, self.editor_a, self.author1),
|
||||
(self.author2, self.book2, self.editor_b, self.author2),
|
||||
(self.author2, self.book3, self.editor_b, self.author2),
|
||||
], lambda x: (x, x.book_join, x.book_join.editor, x.book_join.author))
|
||||
|
||||
def test_select_related_foreign_key(self):
|
||||
qs = Book.objects.annotate(
|
||||
author_join=FilteredRelation('author'),
|
||||
).select_related('author_join').order_by('pk')
|
||||
with self.assertNumQueries(1):
|
||||
self.assertQuerysetEqual(qs, [
|
||||
(self.book1, self.author1),
|
||||
(self.book2, self.author2),
|
||||
(self.book3, self.author2),
|
||||
(self.book4, self.author1),
|
||||
], lambda x: (x, x.author_join))
|
||||
|
||||
def test_without_join(self):
|
||||
self.assertSequenceEqual(
|
||||
Author.objects.annotate(
|
||||
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
|
||||
),
|
||||
[self.author1, self.author2]
|
||||
)
|
||||
|
||||
def test_with_join(self):
|
||||
self.assertSequenceEqual(
|
||||
Author.objects.annotate(
|
||||
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
|
||||
).filter(book_alice__isnull=False),
|
||||
[self.author1]
|
||||
)
|
||||
|
||||
def test_with_join_and_complex_condition(self):
|
||||
self.assertSequenceEqual(
|
||||
Author.objects.annotate(
|
||||
book_alice=FilteredRelation(
|
||||
'book', condition=Q(
|
||||
Q(book__title__iexact='poem by alice') |
|
||||
Q(book__state=Book.RENTED)
|
||||
),
|
||||
),
|
||||
).filter(book_alice__isnull=False),
|
||||
[self.author1]
|
||||
)
|
||||
|
||||
def test_internal_queryset_alias_mapping(self):
|
||||
queryset = Author.objects.annotate(
|
||||
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
|
||||
).filter(book_alice__isnull=False)
|
||||
self.assertIn(
|
||||
'INNER JOIN {} book_alice ON'.format(connection.ops.quote_name('filtered_relation_book')),
|
||||
str(queryset.query)
|
||||
)
|
||||
|
||||
def test_with_multiple_filter(self):
|
||||
self.assertSequenceEqual(
|
||||
Author.objects.annotate(
|
||||
book_editor_a=FilteredRelation(
|
||||
'book',
|
||||
condition=Q(book__title__icontains='book', book__editor_id=self.editor_a.pk),
|
||||
),
|
||||
).filter(book_editor_a__isnull=False),
|
||||
[self.author1]
|
||||
)
|
||||
|
||||
def test_multiple_times(self):
|
||||
self.assertSequenceEqual(
|
||||
Author.objects.annotate(
|
||||
book_title_alice=FilteredRelation('book', condition=Q(book__title__icontains='alice')),
|
||||
).filter(book_title_alice__isnull=False).filter(book_title_alice__isnull=False).distinct(),
|
||||
[self.author1]
|
||||
)
|
||||
|
||||
def test_exclude_relation_with_join(self):
|
||||
self.assertSequenceEqual(
|
||||
Author.objects.annotate(
|
||||
book_alice=FilteredRelation('book', condition=~Q(book__title__icontains='alice')),
|
||||
).filter(book_alice__isnull=False).distinct(),
|
||||
[self.author2]
|
||||
)
|
||||
|
||||
def test_with_m2m(self):
|
||||
qs = Author.objects.annotate(
|
||||
favorite_books_written_by_jane=FilteredRelation(
|
||||
'favorite_books', condition=Q(favorite_books__in=[self.book2]),
|
||||
),
|
||||
).filter(favorite_books_written_by_jane__isnull=False)
|
||||
self.assertSequenceEqual(qs, [self.author1])
|
||||
|
||||
def test_with_m2m_deep(self):
|
||||
qs = Author.objects.annotate(
|
||||
favorite_books_written_by_jane=FilteredRelation(
|
||||
'favorite_books', condition=Q(favorite_books__author=self.author2),
|
||||
),
|
||||
).filter(favorite_books_written_by_jane__title='The book by Jane B')
|
||||
self.assertSequenceEqual(qs, [self.author1])
|
||||
|
||||
def test_with_m2m_multijoin(self):
|
||||
qs = Author.objects.annotate(
|
||||
favorite_books_written_by_jane=FilteredRelation(
|
||||
'favorite_books', condition=Q(favorite_books__author=self.author2),
|
||||
)
|
||||
).filter(favorite_books_written_by_jane__editor__name='b').distinct()
|
||||
self.assertSequenceEqual(qs, [self.author1])
|
||||
|
||||
def test_values_list(self):
|
||||
self.assertSequenceEqual(
|
||||
Author.objects.annotate(
|
||||
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
|
||||
).filter(book_alice__isnull=False).values_list('book_alice__title', flat=True),
|
||||
['Poem by Alice']
|
||||
)
|
||||
|
||||
def test_values(self):
|
||||
self.assertSequenceEqual(
|
||||
Author.objects.annotate(
|
||||
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
|
||||
).filter(book_alice__isnull=False).values(),
|
||||
[{'id': self.author1.pk, 'name': 'Alice', 'content_type_id': None, 'object_id': None}]
|
||||
)
|
||||
|
||||
def test_extra(self):
|
||||
self.assertSequenceEqual(
|
||||
Author.objects.annotate(
|
||||
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
|
||||
).filter(book_alice__isnull=False).extra(where=['1 = 1']),
|
||||
[self.author1]
|
||||
)
|
||||
|
||||
@skipUnlessDBFeature('supports_select_union')
|
||||
def test_union(self):
|
||||
qs1 = Author.objects.annotate(
|
||||
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
|
||||
).filter(book_alice__isnull=False)
|
||||
qs2 = Author.objects.annotate(
|
||||
book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')),
|
||||
).filter(book_jane__isnull=False)
|
||||
self.assertSequenceEqual(qs1.union(qs2), [self.author1, self.author2])
|
||||
|
||||
@skipUnlessDBFeature('supports_select_intersection')
|
||||
def test_intersection(self):
|
||||
qs1 = Author.objects.annotate(
|
||||
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
|
||||
).filter(book_alice__isnull=False)
|
||||
qs2 = Author.objects.annotate(
|
||||
book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')),
|
||||
).filter(book_jane__isnull=False)
|
||||
self.assertSequenceEqual(qs1.intersection(qs2), [])
|
||||
|
||||
@skipUnlessDBFeature('supports_select_difference')
|
||||
def test_difference(self):
|
||||
qs1 = Author.objects.annotate(
|
||||
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
|
||||
).filter(book_alice__isnull=False)
|
||||
qs2 = Author.objects.annotate(
|
||||
book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')),
|
||||
).filter(book_jane__isnull=False)
|
||||
self.assertSequenceEqual(qs1.difference(qs2), [self.author1])
|
||||
|
||||
def test_select_for_update(self):
|
||||
self.assertSequenceEqual(
|
||||
Author.objects.annotate(
|
||||
book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')),
|
||||
).filter(book_jane__isnull=False).select_for_update(),
|
||||
[self.author2]
|
||||
)
|
||||
|
||||
def test_defer(self):
|
||||
# One query for the list and one query for the deferred title.
|
||||
with self.assertNumQueries(2):
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.annotate(
|
||||
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
|
||||
).filter(book_alice__isnull=False).select_related('book_alice').defer('book_alice__title'),
|
||||
['Poem by Alice'], lambda author: author.book_alice.title
|
||||
)
|
||||
|
||||
def test_only_not_supported(self):
|
||||
msg = 'only() is not supported with FilteredRelation.'
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
Author.objects.annotate(
|
||||
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
|
||||
).filter(book_alice__isnull=False).select_related('book_alice').only('book_alice__state')
|
||||
|
||||
def test_as_subquery(self):
|
||||
inner_qs = Author.objects.annotate(
|
||||
book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')),
|
||||
).filter(book_alice__isnull=False)
|
||||
qs = Author.objects.filter(id__in=inner_qs)
|
||||
self.assertSequenceEqual(qs, [self.author1])
|
||||
|
||||
def test_with_foreign_key_error(self):
|
||||
msg = (
|
||||
"FilteredRelation's condition doesn't support nested relations "
|
||||
"(got 'author__favorite_books__author')."
|
||||
)
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
list(Book.objects.annotate(
|
||||
alice_favorite_books=FilteredRelation(
|
||||
'author__favorite_books',
|
||||
condition=Q(author__favorite_books__author=self.author1),
|
||||
)
|
||||
))
|
||||
|
||||
def test_with_foreign_key_on_condition_error(self):
|
||||
msg = (
|
||||
"FilteredRelation's condition doesn't support nested relations "
|
||||
"(got 'book__editor__name__icontains')."
|
||||
)
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
list(Author.objects.annotate(
|
||||
book_edited_by_b=FilteredRelation('book', condition=Q(book__editor__name__icontains='b')),
|
||||
))
|
||||
|
||||
def test_with_empty_relation_name_error(self):
|
||||
with self.assertRaisesMessage(ValueError, 'relation_name cannot be empty.'):
|
||||
FilteredRelation('', condition=Q(blank=''))
|
||||
|
||||
def test_with_condition_as_expression_error(self):
|
||||
msg = 'condition argument must be a Q() instance.'
|
||||
expression = Case(
|
||||
When(book__title__iexact='poem by alice', then=True), default=False,
|
||||
)
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
FilteredRelation('book', condition=expression)
|
||||
|
||||
def test_with_prefetch_related(self):
|
||||
msg = 'prefetch_related() is not supported with FilteredRelation.'
|
||||
qs = Author.objects.annotate(
|
||||
book_title_contains_b=FilteredRelation('book', condition=Q(book__title__icontains='b')),
|
||||
).filter(
|
||||
book_title_contains_b__isnull=False,
|
||||
)
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
qs.prefetch_related('book_title_contains_b')
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
qs.prefetch_related('book_title_contains_b__editor')
|
||||
|
||||
def test_with_generic_foreign_key(self):
|
||||
self.assertSequenceEqual(
|
||||
Book.objects.annotate(
|
||||
generic_authored_book=FilteredRelation(
|
||||
'generic_author',
|
||||
condition=Q(generic_author__isnull=False)
|
||||
),
|
||||
).filter(generic_authored_book__isnull=False),
|
||||
[self.book1]
|
||||
)
|
||||
|
||||
|
||||
class FilteredRelationAggregationTests(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
cls.author1 = Author.objects.create(name='Alice')
|
||||
cls.editor_a = Editor.objects.create(name='a')
|
||||
cls.book1 = Book.objects.create(
|
||||
title='Poem by Alice',
|
||||
editor=cls.editor_a,
|
||||
author=cls.author1,
|
||||
)
|
||||
cls.borrower1 = Borrower.objects.create(name='Jenny')
|
||||
cls.borrower2 = Borrower.objects.create(name='Kevin')
|
||||
# borrower 1 reserves, rents, and returns book1.
|
||||
Reservation.objects.create(
|
||||
borrower=cls.borrower1,
|
||||
book=cls.book1,
|
||||
state=Reservation.STOPPED,
|
||||
)
|
||||
RentalSession.objects.create(
|
||||
borrower=cls.borrower1,
|
||||
book=cls.book1,
|
||||
state=RentalSession.STOPPED,
|
||||
)
|
||||
# borrower2 reserves, rents, and returns book1.
|
||||
Reservation.objects.create(
|
||||
borrower=cls.borrower2,
|
||||
book=cls.book1,
|
||||
state=Reservation.STOPPED,
|
||||
)
|
||||
RentalSession.objects.create(
|
||||
borrower=cls.borrower2,
|
||||
book=cls.book1,
|
||||
state=RentalSession.STOPPED,
|
||||
)
|
||||
|
||||
def test_aggregate(self):
|
||||
"""
|
||||
filtered_relation() not only improves performance but also creates
|
||||
correct results when aggregating with multiple LEFT JOINs.
|
||||
|
||||
Books can be reserved then rented by a borrower. Each reservation and
|
||||
rental session are recorded with Reservation and RentalSession models.
|
||||
Every time a reservation or a rental session is over, their state is
|
||||
changed to 'stopped'.
|
||||
|
||||
Goal: Count number of books that are either currently reserved or
|
||||
rented by borrower1 or available.
|
||||
"""
|
||||
qs = Book.objects.annotate(
|
||||
is_reserved_or_rented_by=Case(
|
||||
When(reservation__state=Reservation.NEW, then=F('reservation__borrower__pk')),
|
||||
When(rental_session__state=RentalSession.NEW, then=F('rental_session__borrower__pk')),
|
||||
default=None,
|
||||
)
|
||||
).filter(
|
||||
Q(is_reserved_or_rented_by=self.borrower1.pk) | Q(state=Book.AVAILABLE)
|
||||
).distinct()
|
||||
self.assertEqual(qs.count(), 1)
|
||||
# If count is equal to 1, the same aggregation should return in the
|
||||
# same result but it returns 4.
|
||||
self.assertSequenceEqual(qs.annotate(total=Count('pk')).values('total'), [{'total': 4}])
|
||||
# With FilteredRelation, the result is as expected (1).
|
||||
qs = Book.objects.annotate(
|
||||
active_reservations=FilteredRelation(
|
||||
'reservation', condition=Q(
|
||||
reservation__state=Reservation.NEW,
|
||||
reservation__borrower=self.borrower1,
|
||||
)
|
||||
),
|
||||
).annotate(
|
||||
active_rental_sessions=FilteredRelation(
|
||||
'rental_session', condition=Q(
|
||||
rental_session__state=RentalSession.NEW,
|
||||
rental_session__borrower=self.borrower1,
|
||||
)
|
||||
),
|
||||
).filter(
|
||||
(Q(active_reservations__isnull=False) | Q(active_rental_sessions__isnull=False)) |
|
||||
Q(state=Book.AVAILABLE)
|
||||
).distinct()
|
||||
self.assertEqual(qs.count(), 1)
|
||||
self.assertSequenceEqual(qs.annotate(total=Count('pk')).values('total'), [{'total': 1}])
|
|
@ -53,15 +53,31 @@ class StartsWithRelation(models.ForeignObject):
|
|||
def get_joining_columns(self, reverse_join=False):
|
||||
return ()
|
||||
|
||||
def get_path_info(self):
|
||||
def get_path_info(self, filtered_relation=None):
|
||||
to_opts = self.remote_field.model._meta
|
||||
from_opts = self.model._meta
|
||||
return [PathInfo(from_opts, to_opts, (to_opts.pk,), self, False, False)]
|
||||
return [PathInfo(
|
||||
from_opts=from_opts,
|
||||
to_opts=to_opts,
|
||||
target_fields=(to_opts.pk,),
|
||||
join_field=self,
|
||||
m2m=False,
|
||||
direct=False,
|
||||
filtered_relation=filtered_relation,
|
||||
)]
|
||||
|
||||
def get_reverse_path_info(self):
|
||||
def get_reverse_path_info(self, filtered_relation=None):
|
||||
to_opts = self.model._meta
|
||||
from_opts = self.remote_field.model._meta
|
||||
return [PathInfo(from_opts, to_opts, (to_opts.pk,), self.remote_field, False, False)]
|
||||
return [PathInfo(
|
||||
from_opts=from_opts,
|
||||
to_opts=to_opts,
|
||||
target_fields=(to_opts.pk,),
|
||||
join_field=self.remote_field,
|
||||
m2m=False,
|
||||
direct=False,
|
||||
filtered_relation=filtered_relation,
|
||||
)]
|
||||
|
||||
def contribute_to_class(self, cls, name, private_only=False):
|
||||
super().contribute_to_class(cls, name, private_only)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from django.core.exceptions import FieldError
|
||||
from django.db.models import FilteredRelation
|
||||
from django.test import SimpleTestCase, TestCase
|
||||
|
||||
from .models import (
|
||||
|
@ -230,3 +231,8 @@ class ReverseSelectRelatedValidationTests(SimpleTestCase):
|
|||
|
||||
with self.assertRaisesMessage(FieldError, self.non_relational_error % ('username', fields)):
|
||||
list(User.objects.select_related('username'))
|
||||
|
||||
def test_reverse_related_validation_with_filtered_relation(self):
|
||||
fields = 'userprofile, userstat, relation'
|
||||
with self.assertRaisesMessage(FieldError, self.invalid_error % ('foobar', fields)):
|
||||
list(User.objects.annotate(relation=FilteredRelation('userprofile')).select_related('foobar'))
|
||||
|
|
Loading…
Reference in New Issue