diff --git a/django/contrib/contenttypes/fields.py b/django/contrib/contenttypes/fields.py index 21367c4195..2227707203 100644 --- a/django/contrib/contenttypes/fields.py +++ b/django/contrib/contenttypes/fields.py @@ -348,7 +348,7 @@ class GenericRelation(ForeignObject): self.to_fields = [self.model._meta.pk.name] return [(self.remote_field.model._meta.get_field(self.object_id_field_name), self.model._meta.pk)] - def _get_path_info_with_parent(self): + def _get_path_info_with_parent(self, filtered_relation): """ Return the path that joins the current model through any parent models. The idea is that if you have a GFK defined on a parent model then we @@ -365,7 +365,15 @@ class GenericRelation(ForeignObject): opts = self.remote_field.model._meta.concrete_model._meta parent_opts = opts.get_field(self.object_id_field_name).model._meta target = parent_opts.pk - path.append(PathInfo(self.model._meta, parent_opts, (target,), self.remote_field, True, False)) + path.append(PathInfo( + from_opts=self.model._meta, + to_opts=parent_opts, + target_fields=(target,), + join_field=self.remote_field, + m2m=True, + direct=False, + filtered_relation=filtered_relation, + )) # Collect joins needed for the parent -> child chain. This is easiest # to do if we collect joins for the child -> parent chain and then # reverse the direction (call to reverse() and use of @@ -380,19 +388,35 @@ class GenericRelation(ForeignObject): path.extend(field.remote_field.get_path_info()) return path - def get_path_info(self): + def get_path_info(self, filtered_relation=None): opts = self.remote_field.model._meta object_id_field = opts.get_field(self.object_id_field_name) if object_id_field.model != opts.model: - return self._get_path_info_with_parent() + return self._get_path_info_with_parent(filtered_relation) else: target = opts.pk - return [PathInfo(self.model._meta, opts, (target,), self.remote_field, True, False)] + return [PathInfo( + from_opts=self.model._meta, + to_opts=opts, + target_fields=(target,), + join_field=self.remote_field, + m2m=True, + direct=False, + filtered_relation=filtered_relation, + )] - def get_reverse_path_info(self): + def get_reverse_path_info(self, filtered_relation=None): opts = self.model._meta from_opts = self.remote_field.model._meta - return [PathInfo(from_opts, opts, (opts.pk,), self, not self.unique, False)] + return [PathInfo( + from_opts=from_opts, + to_opts=opts, + target_fields=(opts.pk,), + join_field=self, + m2m=not self.unique, + direct=False, + filtered_relation=filtered_relation, + )] def value_to_string(self, obj): qs = getattr(obj, self.name).all() diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index d29addd1f7..628f92db3c 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -20,6 +20,7 @@ from django.db.models.manager import Manager from django.db.models.query import ( Prefetch, Q, QuerySet, prefetch_related_objects, ) +from django.db.models.query_utils import FilteredRelation # Imports that would create circular imports if sorted from django.db.models.base import DEFERRED, Model # isort:skip @@ -69,6 +70,7 @@ __all__ += [ 'Window', 'WindowFrame', 'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager', 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model', + 'FilteredRelation', 'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField', 'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel', 'permalink', ] diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 5cf540d385..34123fd4de 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -697,18 +697,33 @@ class ForeignObject(RelatedField): """ return None - def get_path_info(self): + def get_path_info(self, filtered_relation=None): """Get path from this field to the related model.""" opts = self.remote_field.model._meta from_opts = self.model._meta - return [PathInfo(from_opts, opts, self.foreign_related_fields, self, False, True)] + return [PathInfo( + from_opts=from_opts, + to_opts=opts, + target_fields=self.foreign_related_fields, + join_field=self, + m2m=False, + direct=True, + filtered_relation=filtered_relation, + )] - def get_reverse_path_info(self): + def get_reverse_path_info(self, filtered_relation=None): """Get path from the related model to this field's model.""" opts = self.model._meta from_opts = self.remote_field.model._meta - pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.remote_field, not self.unique, False)] - return pathinfos + return [PathInfo( + from_opts=from_opts, + to_opts=opts, + target_fields=(opts.pk,), + join_field=self.remote_field, + m2m=not self.unique, + direct=False, + filtered_relation=filtered_relation, + )] @classmethod @functools.lru_cache(maxsize=None) @@ -861,12 +876,19 @@ class ForeignKey(ForeignObject): def target_field(self): return self.foreign_related_fields[0] - def get_reverse_path_info(self): + def get_reverse_path_info(self, filtered_relation=None): """Get path from the related model to this field's model.""" opts = self.model._meta from_opts = self.remote_field.model._meta - pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.remote_field, not self.unique, False)] - return pathinfos + return [PathInfo( + from_opts=from_opts, + to_opts=opts, + target_fields=(opts.pk,), + join_field=self.remote_field, + m2m=not self.unique, + direct=False, + filtered_relation=filtered_relation, + )] def validate(self, value, model_instance): if self.remote_field.parent_link: @@ -1435,7 +1457,7 @@ class ManyToManyField(RelatedField): ) return name, path, args, kwargs - def _get_path_info(self, direct=False): + def _get_path_info(self, direct=False, filtered_relation=None): """Called by both direct and indirect m2m traversal.""" pathinfos = [] int_model = self.remote_field.through @@ -1443,10 +1465,10 @@ class ManyToManyField(RelatedField): linkfield2 = int_model._meta.get_field(self.m2m_reverse_field_name()) if direct: join1infos = linkfield1.get_reverse_path_info() - join2infos = linkfield2.get_path_info() + join2infos = linkfield2.get_path_info(filtered_relation) else: join1infos = linkfield2.get_reverse_path_info() - join2infos = linkfield1.get_path_info() + join2infos = linkfield1.get_path_info(filtered_relation) # Get join infos between the last model of join 1 and the first model # of join 2. Assume the only reason these may differ is due to model @@ -1465,11 +1487,11 @@ class ManyToManyField(RelatedField): pathinfos.extend(join2infos) return pathinfos - def get_path_info(self): - return self._get_path_info(direct=True) + def get_path_info(self, filtered_relation=None): + return self._get_path_info(direct=True, filtered_relation=filtered_relation) - def get_reverse_path_info(self): - return self._get_path_info(direct=False) + def get_reverse_path_info(self, filtered_relation=None): + return self._get_path_info(direct=False, filtered_relation=filtered_relation) def _get_m2m_db_table(self, opts): """ diff --git a/django/db/models/fields/reverse_related.py b/django/db/models/fields/reverse_related.py index 1f42375566..dddb869513 100644 --- a/django/db/models/fields/reverse_related.py +++ b/django/db/models/fields/reverse_related.py @@ -163,8 +163,8 @@ class ForeignObjectRel(FieldCacheMixin): return self.related_name return opts.model_name + ('_set' if self.multiple else '') - def get_path_info(self): - return self.field.get_reverse_path_info() + def get_path_info(self, filtered_relation=None): + return self.field.get_reverse_path_info(filtered_relation) def get_cache_name(self): """ diff --git a/django/db/models/options.py b/django/db/models/options.py index 9f0746bd58..0786e525b3 100644 --- a/django/db/models/options.py +++ b/django/db/models/options.py @@ -632,7 +632,15 @@ class Options: final_field = opts.parents[int_model] targets = (final_field.remote_field.get_related_field(),) opts = int_model._meta - path.append(PathInfo(final_field.model._meta, opts, targets, final_field, False, True)) + path.append(PathInfo( + from_opts=final_field.model._meta, + to_opts=opts, + target_fields=targets, + join_field=final_field, + m2m=False, + direct=True, + filtered_relation=None, + )) return path def get_path_from_parent(self, parent): diff --git a/django/db/models/query.py b/django/db/models/query.py index 42fb728190..3bfe0a6fb4 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -22,7 +22,7 @@ from django.db.models.deletion import Collector from django.db.models.expressions import F from django.db.models.fields import AutoField from django.db.models.functions import Trunc -from django.db.models.query_utils import InvalidQuery, Q +from django.db.models.query_utils import FilteredRelation, InvalidQuery, Q from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE from django.utils import timezone from django.utils.deprecation import RemovedInDjango30Warning @@ -953,6 +953,12 @@ class QuerySet: if lookups == (None,): clone._prefetch_related_lookups = () else: + for lookup in lookups: + if isinstance(lookup, Prefetch): + lookup = lookup.prefetch_to + lookup = lookup.split(LOOKUP_SEP, 1)[0] + if lookup in self.query._filtered_relations: + raise ValueError('prefetch_related() is not supported with FilteredRelation.') clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups return clone @@ -984,7 +990,10 @@ class QuerySet: if alias in names: raise ValueError("The annotation '%s' conflicts with a field on " "the model." % alias) - clone.query.add_annotation(annotation, alias, is_summary=False) + if isinstance(annotation, FilteredRelation): + clone.query.add_filtered_relation(annotation, alias) + else: + clone.query.add_annotation(annotation, alias, is_summary=False) for alias, annotation in clone.query.annotations.items(): if alias in annotations and annotation.contains_aggregate: @@ -1060,6 +1069,10 @@ class QuerySet: # Can only pass None to defer(), not only(), as the rest option. # That won't stop people trying to do this, so let's be explicit. raise TypeError("Cannot pass None as an argument to only().") + for field in fields: + field = field.split(LOOKUP_SEP, 1)[0] + if field in self.query._filtered_relations: + raise ValueError('only() is not supported with FilteredRelation.') clone = self._chain() clone.query.add_immediate_loading(fields) return clone @@ -1730,9 +1743,9 @@ class RelatedPopulator: # model's fields. # - related_populators: a list of RelatedPopulator instances if # select_related() descends to related models from this model. - # - field, remote_field: the fields to use for populating the - # internal fields cache. If remote_field is set then we also - # set the reverse link. + # - local_setter, remote_setter: Methods to set cached values on + # the object being populated and on the remote object. Usually + # these are Field.set_cached_value() methods. select_fields = klass_info['select_fields'] from_parent = klass_info['from_parent'] if not from_parent: @@ -1751,16 +1764,8 @@ class RelatedPopulator: self.model_cls = klass_info['model'] self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname) self.related_populators = get_related_populators(klass_info, select, self.db) - reverse = klass_info['reverse'] - field = klass_info['field'] - self.remote_field = None - if reverse: - self.field = field.remote_field - self.remote_field = field - else: - self.field = field - if field.unique: - self.remote_field = field.remote_field + self.local_setter = klass_info['local_setter'] + self.remote_setter = klass_info['remote_setter'] def populate(self, row, from_obj): if self.reorder_for_init: @@ -1774,9 +1779,9 @@ class RelatedPopulator: if self.related_populators: for rel_iter in self.related_populators: rel_iter.populate(row, obj) - if self.remote_field: - self.remote_field.set_cached_value(obj, from_obj) - self.field.set_cached_value(from_obj, obj) + self.local_setter(from_obj, obj) + if obj is not None: + self.remote_setter(obj, from_obj) def get_related_populators(klass_info, select, db): diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index e3f6a730d5..8a889264e5 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -16,7 +16,7 @@ from django.utils import tree # PathInfo is used when converting lookups (fk__somecol). The contents # describe the relation in Model terms (model Options and Fields for both # sides of the relation. The join_field is the field backing the relation. -PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct') +PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct filtered_relation') class InvalidQuery(Exception): @@ -291,3 +291,44 @@ def check_rel_lookup_compatibility(model, target_opts, field): check(target_opts) or (getattr(field, 'primary_key', False) and check(field.model._meta)) ) + + +class FilteredRelation: + """Specify custom filtering in the ON clause of SQL joins.""" + + def __init__(self, relation_name, *, condition=Q()): + if not relation_name: + raise ValueError('relation_name cannot be empty.') + self.relation_name = relation_name + self.alias = None + if not isinstance(condition, Q): + raise ValueError('condition argument must be a Q() instance.') + self.condition = condition + self.path = [] + + def __eq__(self, other): + return ( + isinstance(other, self.__class__) and + self.relation_name == other.relation_name and + self.alias == other.alias and + self.condition == other.condition + ) + + def clone(self): + clone = FilteredRelation(self.relation_name, condition=self.condition) + clone.alias = self.alias + clone.path = self.path[:] + return clone + + def resolve_expression(self, *args, **kwargs): + """ + QuerySet.annotate() only accepts expression-like arguments + (with a resolve_expression() method). + """ + raise NotImplementedError('FilteredRelation.resolve_expression() is unused.') + + def as_sql(self, compiler, connection): + # Resolve the condition in Join.filtered_relation. + query = compiler.query + where = query.build_filtered_relation_q(self.condition, reuse=set(self.path)) + return compiler.compile(where) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 11ff51f60f..14d44d3eef 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -702,7 +702,7 @@ class SQLCompiler: """ result = [] params = [] - for alias in self.query.alias_map: + for alias in tuple(self.query.alias_map): if not self.query.alias_refcount[alias]: continue try: @@ -737,7 +737,7 @@ class SQLCompiler: f.field.related_query_name() for f in opts.related_objects if f.field.unique ) - return chain(direct_choices, reverse_choices) + return chain(direct_choices, reverse_choices, self.query._filtered_relations) related_klass_infos = [] if not restricted and cur_depth > self.query.max_depth: @@ -788,7 +788,8 @@ class SQLCompiler: klass_info = { 'model': f.remote_field.model, 'field': f, - 'reverse': False, + 'local_setter': f.set_cached_value, + 'remote_setter': f.remote_field.set_cached_value if f.unique else lambda x, y: None, 'from_parent': False, } related_klass_infos.append(klass_info) @@ -825,7 +826,8 @@ class SQLCompiler: klass_info = { 'model': model, 'field': f, - 'reverse': True, + 'local_setter': f.remote_field.set_cached_value, + 'remote_setter': f.set_cached_value, 'from_parent': from_parent, } related_klass_infos.append(klass_info) @@ -842,6 +844,47 @@ class SQLCompiler: next, restricted) get_related_klass_infos(klass_info, next_klass_infos) fields_not_found = set(requested).difference(fields_found) + for name in list(requested): + # Filtered relations work only on the topmost level. + if cur_depth > 1: + break + if name in self.query._filtered_relations: + fields_found.add(name) + f, _, join_opts, joins, _ = self.query.setup_joins([name], opts, root_alias) + model = join_opts.model + alias = joins[-1] + from_parent = issubclass(model, opts.model) and model is not opts.model + + def local_setter(obj, from_obj): + f.remote_field.set_cached_value(from_obj, obj) + + def remote_setter(obj, from_obj): + setattr(from_obj, name, obj) + klass_info = { + 'model': model, + 'field': f, + 'local_setter': local_setter, + 'remote_setter': remote_setter, + 'from_parent': from_parent, + } + related_klass_infos.append(klass_info) + select_fields = [] + columns = self.get_default_columns( + start_alias=alias, opts=model._meta, + from_parent=opts.model, + ) + for col in columns: + select_fields.append(len(select)) + select.append((col, None)) + klass_info['select_fields'] = select_fields + next_requested = requested.get(name, {}) + next_klass_infos = self.get_related_selections( + select, opts=model._meta, root_alias=alias, + cur_depth=cur_depth + 1, requested=next_requested, + restricted=restricted, + ) + get_related_klass_infos(klass_info, next_klass_infos) + fields_not_found = set(requested).difference(fields_found) if fields_not_found: invalid_fields = ("'%s'" % s for s in fields_not_found) raise FieldError( diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 788c2dd669..ab02f65042 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -41,7 +41,7 @@ class Join: - relabeled_clone() """ def __init__(self, table_name, parent_alias, table_alias, join_type, - join_field, nullable): + join_field, nullable, filtered_relation=None): # Join table self.table_name = table_name self.parent_alias = parent_alias @@ -56,6 +56,7 @@ class Join: self.join_field = join_field # Is this join nullabled? self.nullable = nullable + self.filtered_relation = filtered_relation def as_sql(self, compiler, connection): """ @@ -85,7 +86,11 @@ class Join: extra_sql, extra_params = compiler.compile(extra_cond) join_conditions.append('(%s)' % extra_sql) params.extend(extra_params) - + if self.filtered_relation: + extra_sql, extra_params = compiler.compile(self.filtered_relation) + if extra_sql: + join_conditions.append('(%s)' % extra_sql) + params.extend(extra_params) if not join_conditions: # This might be a rel on the other end of an actual declared field. declared_field = getattr(self.join_field, 'field', self.join_field) @@ -101,18 +106,27 @@ class Join: def relabeled_clone(self, change_map): new_parent_alias = change_map.get(self.parent_alias, self.parent_alias) new_table_alias = change_map.get(self.table_alias, self.table_alias) + if self.filtered_relation is not None: + filtered_relation = self.filtered_relation.clone() + filtered_relation.path = [change_map.get(p, p) for p in self.filtered_relation.path] + else: + filtered_relation = None return self.__class__( self.table_name, new_parent_alias, new_table_alias, self.join_type, - self.join_field, self.nullable) + self.join_field, self.nullable, filtered_relation=filtered_relation, + ) + + def equals(self, other, with_filtered_relation): + return ( + isinstance(other, self.__class__) and + self.table_name == other.table_name and + self.parent_alias == other.parent_alias and + self.join_field == other.join_field and + (not with_filtered_relation or self.filtered_relation == other.filtered_relation) + ) def __eq__(self, other): - if isinstance(other, self.__class__): - return ( - self.table_name == other.table_name and - self.parent_alias == other.parent_alias and - self.join_field == other.join_field - ) - return False + return self.equals(other, with_filtered_relation=True) def demote(self): new = self.relabeled_clone({}) @@ -134,6 +148,7 @@ class BaseTable: """ join_type = None parent_alias = None + filtered_relation = None def __init__(self, table_name, alias): self.table_name = table_name @@ -146,3 +161,10 @@ class BaseTable: def relabeled_clone(self, change_map): return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias)) + + def equals(self, other, with_filtered_relation): + return ( + isinstance(self, other.__class__) and + self.table_name == other.table_name and + self.table_alias == other.table_alias + ) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index dfa369513b..a962aabdf1 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -45,6 +45,14 @@ def get_field_names_from_opts(opts): )) +def get_children_from_q(q): + for child in q.children: + if isinstance(child, Node): + yield from get_children_from_q(child) + else: + yield child + + JoinInfo = namedtuple( 'JoinInfo', ('final_field', 'targets', 'opts', 'joins', 'path') @@ -210,6 +218,8 @@ class Query: # load. self.deferred_loading = (frozenset(), True) + self._filtered_relations = {} + @property def extra(self): if self._extra is None: @@ -311,6 +321,7 @@ class Query: if 'subq_aliases' in self.__dict__: obj.subq_aliases = self.subq_aliases.copy() obj.used_aliases = self.used_aliases.copy() + obj._filtered_relations = self._filtered_relations.copy() # Clear the cached_property try: del obj.base_table @@ -624,6 +635,8 @@ class Query: opts = orig_opts for name in parts[:-1]: old_model = cur_model + if name in self._filtered_relations: + name = self._filtered_relations[name].relation_name source = opts.get_field(name) if is_reverse_o2o(source): cur_model = source.related_model @@ -684,7 +697,7 @@ class Query: for model, values in seen.items(): callback(target, model, values) - def table_alias(self, table_name, create=False): + def table_alias(self, table_name, create=False, filtered_relation=None): """ Return a table alias for the given table_name and whether this is a new alias or not. @@ -704,8 +717,8 @@ class Query: alias_list.append(alias) else: # The first occurrence of a table uses the table name directly. - alias = table_name - self.table_map[alias] = [alias] + alias = filtered_relation.alias if filtered_relation is not None else table_name + self.table_map[table_name] = [alias] self.alias_refcount[alias] = 1 return alias, True @@ -881,7 +894,7 @@ class Query: """ return len([1 for count in self.alias_refcount.values() if count]) - def join(self, join, reuse=None): + def join(self, join, reuse=None, reuse_with_filtered_relation=False): """ Return an alias for the 'join', either reusing an existing alias for that join or creating a new one. 'join' is either a @@ -890,18 +903,29 @@ class Query: The 'reuse' parameter can be either None which means all joins are reusable, or it can be a set containing the aliases that can be reused. + The 'reuse_with_filtered_relation' parameter is used when computing + FilteredRelation instances. + A join is always created as LOUTER if the lhs alias is LOUTER to make sure chains like t1 LOUTER t2 INNER t3 aren't generated. All new joins are created as LOUTER if the join is nullable. """ - reuse = [a for a, j in self.alias_map.items() - if (reuse is None or a in reuse) and j == join] - if reuse: - self.ref_alias(reuse[0]) - return reuse[0] + if reuse_with_filtered_relation and reuse: + reuse_aliases = [ + a for a, j in self.alias_map.items() + if a in reuse and j.equals(join, with_filtered_relation=False) + ] + else: + reuse_aliases = [ + a for a, j in self.alias_map.items() + if (reuse is None or a in reuse) and j == join + ] + if reuse_aliases: + self.ref_alias(reuse_aliases[0]) + return reuse_aliases[0] # No reuse is possible, so we need a new alias. - alias, _ = self.table_alias(join.table_name, create=True) + alias, _ = self.table_alias(join.table_name, create=True, filtered_relation=join.filtered_relation) if join.join_type: if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable: join_type = LOUTER @@ -1090,7 +1114,8 @@ class Query: (name, lhs.output_field.__class__.__name__)) def build_filter(self, filter_expr, branch_negated=False, current_negated=False, - can_reuse=None, allow_joins=True, split_subq=True): + can_reuse=None, allow_joins=True, split_subq=True, + reuse_with_filtered_relation=False): """ Build a WhereNode for a single filter clause but don't add it to this Query. Query.add_q() will then add this filter to the where @@ -1112,6 +1137,9 @@ class Query: The 'can_reuse' is a set of reusable joins for multijoins. + If 'reuse_with_filtered_relation' is True, then only joins in can_reuse + will be reused. + The method will create a filter clause that can be added to the current query. However, if the filter isn't added to the query then the caller is responsible for unreffing the joins used. @@ -1147,7 +1175,10 @@ class Query: allow_many = not branch_negated or not split_subq try: - join_info = self.setup_joins(parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many) + join_info = self.setup_joins( + parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many, + reuse_with_filtered_relation=reuse_with_filtered_relation, + ) # Prevent iterator from being consumed by check_related_objects() if isinstance(value, Iterator): @@ -1250,6 +1281,41 @@ class Query: needed_inner = joinpromoter.update_join_types(self) return target_clause, needed_inner + def build_filtered_relation_q(self, q_object, reuse, branch_negated=False, current_negated=False): + """Add a FilteredRelation object to the current filter.""" + connector = q_object.connector + current_negated ^= q_object.negated + branch_negated = branch_negated or q_object.negated + target_clause = self.where_class(connector=connector, negated=q_object.negated) + for child in q_object.children: + if isinstance(child, Node): + child_clause = self.build_filtered_relation_q( + child, reuse=reuse, branch_negated=branch_negated, + current_negated=current_negated, + ) + else: + child_clause, _ = self.build_filter( + child, can_reuse=reuse, branch_negated=branch_negated, + current_negated=current_negated, + allow_joins=True, split_subq=False, + reuse_with_filtered_relation=True, + ) + target_clause.add(child_clause, connector) + return target_clause + + def add_filtered_relation(self, filtered_relation, alias): + filtered_relation.alias = alias + lookups = dict(get_children_from_q(filtered_relation.condition)) + for lookup in chain((filtered_relation.relation_name,), lookups): + lookup_parts, field_parts, _ = self.solve_lookup_type(lookup) + shift = 2 if not lookup_parts else 1 + if len(field_parts) > (shift + len(lookup_parts)): + raise ValueError( + "FilteredRelation's condition doesn't support nested " + "relations (got %r)." % lookup + ) + self._filtered_relations[filtered_relation.alias] = filtered_relation + def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False): """ Walk the list of names and turns them into PathInfo tuples. A single @@ -1272,12 +1338,15 @@ class Query: name = opts.pk.name field = None + filtered_relation = None try: field = opts.get_field(name) except FieldDoesNotExist: if name in self.annotation_select: field = self.annotation_select[name].output_field - + elif name in self._filtered_relations and pos == 0: + filtered_relation = self._filtered_relations[name] + field = opts.get_field(filtered_relation.relation_name) if field is not None: # Fields that contain one-to-many relations with a generic # model (like a GenericForeignKey) cannot generate reverse @@ -1301,7 +1370,10 @@ class Query: pos -= 1 if pos == -1 or fail_on_missing: field_names = list(get_field_names_from_opts(opts)) - available = sorted(field_names + list(self.annotation_select)) + available = sorted( + field_names + list(self.annotation_select) + + list(self._filtered_relations) + ) raise FieldError("Cannot resolve keyword '%s' into field. " "Choices are: %s" % (name, ", ".join(available))) break @@ -1315,7 +1387,7 @@ class Query: cur_names_with_path[1].extend(path_to_parent) opts = path_to_parent[-1].to_opts if hasattr(field, 'get_path_info'): - pathinfos = field.get_path_info() + pathinfos = field.get_path_info(filtered_relation) if not allow_many: for inner_pos, p in enumerate(pathinfos): if p.m2m: @@ -1340,7 +1412,8 @@ class Query: break return path, final_field, targets, names[pos + 1:] - def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): + def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True, + reuse_with_filtered_relation=False): """ Compute the necessary table joins for the passage through the fields given in 'names'. 'opts' is the Options class for the current model @@ -1352,6 +1425,9 @@ class Query: that can be reused. Note that non-reverse foreign keys are always reusable when using setup_joins(). + The 'reuse_with_filtered_relation' can be used to force 'can_reuse' + parameter and force the relation on the given connections. + If 'allow_many' is False, then any reverse foreign key seen will generate a MultiJoin exception. @@ -1374,15 +1450,29 @@ class Query: # joins at this stage - we will need the information about join type # of the trimmed joins. for join in path: + if join.filtered_relation: + filtered_relation = join.filtered_relation.clone() + table_alias = filtered_relation.alias + else: + filtered_relation = None + table_alias = None opts = join.to_opts if join.direct: nullable = self.is_nullable(join.join_field) else: nullable = True - connection = Join(opts.db_table, alias, None, INNER, join.join_field, nullable) - reuse = can_reuse if join.m2m else None - alias = self.join(connection, reuse=reuse) + connection = Join( + opts.db_table, alias, table_alias, INNER, join.join_field, + nullable, filtered_relation=filtered_relation, + ) + reuse = can_reuse if join.m2m or reuse_with_filtered_relation else None + alias = self.join( + connection, reuse=reuse, + reuse_with_filtered_relation=reuse_with_filtered_relation, + ) joins.append(alias) + if filtered_relation: + filtered_relation.path = joins[:] return JoinInfo(final_field, targets, opts, joins, path) def trim_joins(self, targets, joins, path): @@ -1402,6 +1492,8 @@ class Query: for pos, info in enumerate(reversed(path)): if len(joins) == 1 or not info.direct: break + if info.filtered_relation: + break join_targets = {t.column for t in info.join_field.foreign_related_fields} cur_targets = {t.column for t in targets} if not cur_targets.issubset(join_targets): @@ -1425,7 +1517,7 @@ class Query: return self.annotation_select[name] else: field_list = name.split(LOOKUP_SEP) - join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), reuse) + join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), can_reuse=reuse) targets, _, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path) if len(targets) > 1: raise FieldError("Referencing multicolumn fields with F() objects " @@ -1602,7 +1694,10 @@ class Query: # from the model on which the lookup failed. raise else: - names = sorted(list(get_field_names_from_opts(opts)) + list(self.extra) + list(self.annotation_select)) + names = sorted( + list(get_field_names_from_opts(opts)) + list(self.extra) + + list(self.annotation_select) + list(self._filtered_relations) + ) raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(names))) diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index f62050d818..f85ce1e441 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -3318,3 +3318,60 @@ lookups or :class:`Prefetch` objects you want to prefetch for. For example:: >>> from django.db.models import prefetch_related_objects >>> restaurants = fetch_top_restaurants_from_cache() # A list of Restaurants >>> prefetch_related_objects(restaurants, 'pizzas__toppings') + +``FilteredRelation()`` objects +------------------------------ + +.. versionadded:: 2.0 + +.. class:: FilteredRelation(relation_name, *, condition=Q()) + + .. attribute:: FilteredRelation.relation_name + + The name of the field on which you'd like to filter the relation. + + .. attribute:: FilteredRelation.condition + + A :class:`~django.db.models.Q` object to control the filtering. + +``FilteredRelation`` is used with :meth:`~.QuerySet.annotate()` to create an +``ON`` clause when a ``JOIN`` is performed. It doesn't act on the default +relationship but on the annotation name (``pizzas_vegetarian`` in example +below). + +For example, to find restaurants that have vegetarian pizzas with +``'mozzarella'`` in the name:: + + >>> from django.db.models import FilteredRelation, Q + >>> Restaurant.objects.annotate( + ... pizzas_vegetarian=FilteredRelation( + ... 'pizzas', condition=Q(pizzas__vegetarian=True), + ... ), + ... ).filter(pizzas_vegetarian__name__icontains='mozzarella') + +If there are a large number of pizzas, this queryset performs better than:: + + >>> Restaurant.objects.filter( + ... pizzas__vegetarian=True, + ... pizzas__name__icontains='mozzarella', + ... ) + +because the filtering in the ``WHERE`` clause of the first queryset will only +operate on vegetarian pizzas. + +``FilteredRelation`` doesn't support: + +* Conditions that span relational fields. For example:: + + >>> Restaurant.objects.annotate( + ... pizzas_with_toppings_startswith_n=FilteredRelation( + ... 'pizzas__toppings', + ... condition=Q(pizzas__toppings__name__startswith='n'), + ... ), + ... ) + Traceback (most recent call last): + ... + ValueError: FilteredRelation's condition doesn't support nested relations (got 'pizzas__toppings__name__startswith'). +* :meth:`.QuerySet.only` and :meth:`~.QuerySet.prefetch_related`. +* A :class:`~django.contrib.contenttypes.fields.GenericForeignKey` + inherited from a parent model. diff --git a/docs/releases/2.0.txt b/docs/releases/2.0.txt index 77eee0c206..bff7063df8 100644 --- a/docs/releases/2.0.txt +++ b/docs/releases/2.0.txt @@ -354,6 +354,9 @@ Models * The new ``named`` parameter of :meth:`.QuerySet.values_list` allows fetching results as named tuples. +* The new :class:`.FilteredRelation` class allows adding an ``ON`` clause to + querysets. + Pagination ~~~~~~~~~~ diff --git a/tests/filtered_relation/__init__.py b/tests/filtered_relation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/filtered_relation/models.py b/tests/filtered_relation/models.py new file mode 100644 index 0000000000..501e731de7 --- /dev/null +++ b/tests/filtered_relation/models.py @@ -0,0 +1,108 @@ +from django.contrib.contenttypes.fields import ( + GenericForeignKey, GenericRelation, +) +from django.contrib.contenttypes.models import ContentType +from django.db import models + + +class Author(models.Model): + name = models.CharField(max_length=50, unique=True) + favorite_books = models.ManyToManyField( + 'Book', + related_name='preferred_by_authors', + related_query_name='preferred_by_authors', + ) + content_type = models.ForeignKey(ContentType, models.CASCADE, null=True) + object_id = models.PositiveIntegerField(null=True) + content_object = GenericForeignKey() + + def __str__(self): + return self.name + + +class Editor(models.Model): + name = models.CharField(max_length=255) + + def __str__(self): + return self.name + + +class Book(models.Model): + AVAILABLE = 'available' + RESERVED = 'reserved' + RENTED = 'rented' + STATES = ( + (AVAILABLE, 'Available'), + (RESERVED, 'reserved'), + (RENTED, 'Rented'), + ) + title = models.CharField(max_length=255) + author = models.ForeignKey( + Author, + models.CASCADE, + related_name='books', + related_query_name='book', + ) + editor = models.ForeignKey(Editor, models.CASCADE) + generic_author = GenericRelation(Author) + state = models.CharField(max_length=9, choices=STATES, default=AVAILABLE) + + def __str__(self): + return self.title + + +class Borrower(models.Model): + name = models.CharField(max_length=50, unique=True) + + def __str__(self): + return self.name + + +class Reservation(models.Model): + NEW = 'new' + STOPPED = 'stopped' + STATES = ( + (NEW, 'New'), + (STOPPED, 'Stopped'), + ) + borrower = models.ForeignKey( + Borrower, + models.CASCADE, + related_name='reservations', + related_query_name='reservation', + ) + book = models.ForeignKey( + Book, + models.CASCADE, + related_name='reservations', + related_query_name='reservation', + ) + state = models.CharField(max_length=7, choices=STATES, default=NEW) + + def __str__(self): + return '-'.join((self.book.name, self.borrower.name, self.state)) + + +class RentalSession(models.Model): + NEW = 'new' + STOPPED = 'stopped' + STATES = ( + (NEW, 'New'), + (STOPPED, 'Stopped'), + ) + borrower = models.ForeignKey( + Borrower, + models.CASCADE, + related_name='rental_sessions', + related_query_name='rental_session', + ) + book = models.ForeignKey( + Book, + models.CASCADE, + related_name='rental_sessions', + related_query_name='rental_session', + ) + state = models.CharField(max_length=7, choices=STATES, default=NEW) + + def __str__(self): + return '-'.join((self.book.name, self.borrower.name, self.state)) diff --git a/tests/filtered_relation/tests.py b/tests/filtered_relation/tests.py new file mode 100644 index 0000000000..4bae2216bf --- /dev/null +++ b/tests/filtered_relation/tests.py @@ -0,0 +1,381 @@ +from django.db import connection +from django.db.models import Case, Count, F, FilteredRelation, Q, When +from django.test import TestCase +from django.test.testcases import skipUnlessDBFeature + +from .models import Author, Book, Borrower, Editor, RentalSession, Reservation + + +class FilteredRelationTests(TestCase): + + @classmethod + def setUpTestData(cls): + cls.author1 = Author.objects.create(name='Alice') + cls.author2 = Author.objects.create(name='Jane') + cls.editor_a = Editor.objects.create(name='a') + cls.editor_b = Editor.objects.create(name='b') + cls.book1 = Book.objects.create( + title='Poem by Alice', + editor=cls.editor_a, + author=cls.author1, + ) + cls.book1.generic_author.set([cls.author2]) + cls.book2 = Book.objects.create( + title='The book by Jane A', + editor=cls.editor_b, + author=cls.author2, + ) + cls.book3 = Book.objects.create( + title='The book by Jane B', + editor=cls.editor_b, + author=cls.author2, + ) + cls.book4 = Book.objects.create( + title='The book by Alice', + editor=cls.editor_a, + author=cls.author1, + ) + cls.author1.favorite_books.add(cls.book2) + cls.author1.favorite_books.add(cls.book3) + + def test_select_related(self): + qs = Author.objects.annotate( + book_join=FilteredRelation('book'), + ).select_related('book_join__editor').order_by('pk', 'book_join__pk') + with self.assertNumQueries(1): + self.assertQuerysetEqual(qs, [ + (self.author1, self.book1, self.editor_a, self.author1), + (self.author1, self.book4, self.editor_a, self.author1), + (self.author2, self.book2, self.editor_b, self.author2), + (self.author2, self.book3, self.editor_b, self.author2), + ], lambda x: (x, x.book_join, x.book_join.editor, x.book_join.author)) + + def test_select_related_foreign_key(self): + qs = Book.objects.annotate( + author_join=FilteredRelation('author'), + ).select_related('author_join').order_by('pk') + with self.assertNumQueries(1): + self.assertQuerysetEqual(qs, [ + (self.book1, self.author1), + (self.book2, self.author2), + (self.book3, self.author2), + (self.book4, self.author1), + ], lambda x: (x, x.author_join)) + + def test_without_join(self): + self.assertSequenceEqual( + Author.objects.annotate( + book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), + ), + [self.author1, self.author2] + ) + + def test_with_join(self): + self.assertSequenceEqual( + Author.objects.annotate( + book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), + ).filter(book_alice__isnull=False), + [self.author1] + ) + + def test_with_join_and_complex_condition(self): + self.assertSequenceEqual( + Author.objects.annotate( + book_alice=FilteredRelation( + 'book', condition=Q( + Q(book__title__iexact='poem by alice') | + Q(book__state=Book.RENTED) + ), + ), + ).filter(book_alice__isnull=False), + [self.author1] + ) + + def test_internal_queryset_alias_mapping(self): + queryset = Author.objects.annotate( + book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), + ).filter(book_alice__isnull=False) + self.assertIn( + 'INNER JOIN {} book_alice ON'.format(connection.ops.quote_name('filtered_relation_book')), + str(queryset.query) + ) + + def test_with_multiple_filter(self): + self.assertSequenceEqual( + Author.objects.annotate( + book_editor_a=FilteredRelation( + 'book', + condition=Q(book__title__icontains='book', book__editor_id=self.editor_a.pk), + ), + ).filter(book_editor_a__isnull=False), + [self.author1] + ) + + def test_multiple_times(self): + self.assertSequenceEqual( + Author.objects.annotate( + book_title_alice=FilteredRelation('book', condition=Q(book__title__icontains='alice')), + ).filter(book_title_alice__isnull=False).filter(book_title_alice__isnull=False).distinct(), + [self.author1] + ) + + def test_exclude_relation_with_join(self): + self.assertSequenceEqual( + Author.objects.annotate( + book_alice=FilteredRelation('book', condition=~Q(book__title__icontains='alice')), + ).filter(book_alice__isnull=False).distinct(), + [self.author2] + ) + + def test_with_m2m(self): + qs = Author.objects.annotate( + favorite_books_written_by_jane=FilteredRelation( + 'favorite_books', condition=Q(favorite_books__in=[self.book2]), + ), + ).filter(favorite_books_written_by_jane__isnull=False) + self.assertSequenceEqual(qs, [self.author1]) + + def test_with_m2m_deep(self): + qs = Author.objects.annotate( + favorite_books_written_by_jane=FilteredRelation( + 'favorite_books', condition=Q(favorite_books__author=self.author2), + ), + ).filter(favorite_books_written_by_jane__title='The book by Jane B') + self.assertSequenceEqual(qs, [self.author1]) + + def test_with_m2m_multijoin(self): + qs = Author.objects.annotate( + favorite_books_written_by_jane=FilteredRelation( + 'favorite_books', condition=Q(favorite_books__author=self.author2), + ) + ).filter(favorite_books_written_by_jane__editor__name='b').distinct() + self.assertSequenceEqual(qs, [self.author1]) + + def test_values_list(self): + self.assertSequenceEqual( + Author.objects.annotate( + book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), + ).filter(book_alice__isnull=False).values_list('book_alice__title', flat=True), + ['Poem by Alice'] + ) + + def test_values(self): + self.assertSequenceEqual( + Author.objects.annotate( + book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), + ).filter(book_alice__isnull=False).values(), + [{'id': self.author1.pk, 'name': 'Alice', 'content_type_id': None, 'object_id': None}] + ) + + def test_extra(self): + self.assertSequenceEqual( + Author.objects.annotate( + book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), + ).filter(book_alice__isnull=False).extra(where=['1 = 1']), + [self.author1] + ) + + @skipUnlessDBFeature('supports_select_union') + def test_union(self): + qs1 = Author.objects.annotate( + book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), + ).filter(book_alice__isnull=False) + qs2 = Author.objects.annotate( + book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')), + ).filter(book_jane__isnull=False) + self.assertSequenceEqual(qs1.union(qs2), [self.author1, self.author2]) + + @skipUnlessDBFeature('supports_select_intersection') + def test_intersection(self): + qs1 = Author.objects.annotate( + book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), + ).filter(book_alice__isnull=False) + qs2 = Author.objects.annotate( + book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')), + ).filter(book_jane__isnull=False) + self.assertSequenceEqual(qs1.intersection(qs2), []) + + @skipUnlessDBFeature('supports_select_difference') + def test_difference(self): + qs1 = Author.objects.annotate( + book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), + ).filter(book_alice__isnull=False) + qs2 = Author.objects.annotate( + book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')), + ).filter(book_jane__isnull=False) + self.assertSequenceEqual(qs1.difference(qs2), [self.author1]) + + def test_select_for_update(self): + self.assertSequenceEqual( + Author.objects.annotate( + book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')), + ).filter(book_jane__isnull=False).select_for_update(), + [self.author2] + ) + + def test_defer(self): + # One query for the list and one query for the deferred title. + with self.assertNumQueries(2): + self.assertQuerysetEqual( + Author.objects.annotate( + book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), + ).filter(book_alice__isnull=False).select_related('book_alice').defer('book_alice__title'), + ['Poem by Alice'], lambda author: author.book_alice.title + ) + + def test_only_not_supported(self): + msg = 'only() is not supported with FilteredRelation.' + with self.assertRaisesMessage(ValueError, msg): + Author.objects.annotate( + book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), + ).filter(book_alice__isnull=False).select_related('book_alice').only('book_alice__state') + + def test_as_subquery(self): + inner_qs = Author.objects.annotate( + book_alice=FilteredRelation('book', condition=Q(book__title__iexact='poem by alice')), + ).filter(book_alice__isnull=False) + qs = Author.objects.filter(id__in=inner_qs) + self.assertSequenceEqual(qs, [self.author1]) + + def test_with_foreign_key_error(self): + msg = ( + "FilteredRelation's condition doesn't support nested relations " + "(got 'author__favorite_books__author')." + ) + with self.assertRaisesMessage(ValueError, msg): + list(Book.objects.annotate( + alice_favorite_books=FilteredRelation( + 'author__favorite_books', + condition=Q(author__favorite_books__author=self.author1), + ) + )) + + def test_with_foreign_key_on_condition_error(self): + msg = ( + "FilteredRelation's condition doesn't support nested relations " + "(got 'book__editor__name__icontains')." + ) + with self.assertRaisesMessage(ValueError, msg): + list(Author.objects.annotate( + book_edited_by_b=FilteredRelation('book', condition=Q(book__editor__name__icontains='b')), + )) + + def test_with_empty_relation_name_error(self): + with self.assertRaisesMessage(ValueError, 'relation_name cannot be empty.'): + FilteredRelation('', condition=Q(blank='')) + + def test_with_condition_as_expression_error(self): + msg = 'condition argument must be a Q() instance.' + expression = Case( + When(book__title__iexact='poem by alice', then=True), default=False, + ) + with self.assertRaisesMessage(ValueError, msg): + FilteredRelation('book', condition=expression) + + def test_with_prefetch_related(self): + msg = 'prefetch_related() is not supported with FilteredRelation.' + qs = Author.objects.annotate( + book_title_contains_b=FilteredRelation('book', condition=Q(book__title__icontains='b')), + ).filter( + book_title_contains_b__isnull=False, + ) + with self.assertRaisesMessage(ValueError, msg): + qs.prefetch_related('book_title_contains_b') + with self.assertRaisesMessage(ValueError, msg): + qs.prefetch_related('book_title_contains_b__editor') + + def test_with_generic_foreign_key(self): + self.assertSequenceEqual( + Book.objects.annotate( + generic_authored_book=FilteredRelation( + 'generic_author', + condition=Q(generic_author__isnull=False) + ), + ).filter(generic_authored_book__isnull=False), + [self.book1] + ) + + +class FilteredRelationAggregationTests(TestCase): + + @classmethod + def setUpTestData(cls): + cls.author1 = Author.objects.create(name='Alice') + cls.editor_a = Editor.objects.create(name='a') + cls.book1 = Book.objects.create( + title='Poem by Alice', + editor=cls.editor_a, + author=cls.author1, + ) + cls.borrower1 = Borrower.objects.create(name='Jenny') + cls.borrower2 = Borrower.objects.create(name='Kevin') + # borrower 1 reserves, rents, and returns book1. + Reservation.objects.create( + borrower=cls.borrower1, + book=cls.book1, + state=Reservation.STOPPED, + ) + RentalSession.objects.create( + borrower=cls.borrower1, + book=cls.book1, + state=RentalSession.STOPPED, + ) + # borrower2 reserves, rents, and returns book1. + Reservation.objects.create( + borrower=cls.borrower2, + book=cls.book1, + state=Reservation.STOPPED, + ) + RentalSession.objects.create( + borrower=cls.borrower2, + book=cls.book1, + state=RentalSession.STOPPED, + ) + + def test_aggregate(self): + """ + filtered_relation() not only improves performance but also creates + correct results when aggregating with multiple LEFT JOINs. + + Books can be reserved then rented by a borrower. Each reservation and + rental session are recorded with Reservation and RentalSession models. + Every time a reservation or a rental session is over, their state is + changed to 'stopped'. + + Goal: Count number of books that are either currently reserved or + rented by borrower1 or available. + """ + qs = Book.objects.annotate( + is_reserved_or_rented_by=Case( + When(reservation__state=Reservation.NEW, then=F('reservation__borrower__pk')), + When(rental_session__state=RentalSession.NEW, then=F('rental_session__borrower__pk')), + default=None, + ) + ).filter( + Q(is_reserved_or_rented_by=self.borrower1.pk) | Q(state=Book.AVAILABLE) + ).distinct() + self.assertEqual(qs.count(), 1) + # If count is equal to 1, the same aggregation should return in the + # same result but it returns 4. + self.assertSequenceEqual(qs.annotate(total=Count('pk')).values('total'), [{'total': 4}]) + # With FilteredRelation, the result is as expected (1). + qs = Book.objects.annotate( + active_reservations=FilteredRelation( + 'reservation', condition=Q( + reservation__state=Reservation.NEW, + reservation__borrower=self.borrower1, + ) + ), + ).annotate( + active_rental_sessions=FilteredRelation( + 'rental_session', condition=Q( + rental_session__state=RentalSession.NEW, + rental_session__borrower=self.borrower1, + ) + ), + ).filter( + (Q(active_reservations__isnull=False) | Q(active_rental_sessions__isnull=False)) | + Q(state=Book.AVAILABLE) + ).distinct() + self.assertEqual(qs.count(), 1) + self.assertSequenceEqual(qs.annotate(total=Count('pk')).values('total'), [{'total': 1}]) diff --git a/tests/foreign_object/models/empty_join.py b/tests/foreign_object/models/empty_join.py index 202676b075..08d1edb18a 100644 --- a/tests/foreign_object/models/empty_join.py +++ b/tests/foreign_object/models/empty_join.py @@ -53,15 +53,31 @@ class StartsWithRelation(models.ForeignObject): def get_joining_columns(self, reverse_join=False): return () - def get_path_info(self): + def get_path_info(self, filtered_relation=None): to_opts = self.remote_field.model._meta from_opts = self.model._meta - return [PathInfo(from_opts, to_opts, (to_opts.pk,), self, False, False)] + return [PathInfo( + from_opts=from_opts, + to_opts=to_opts, + target_fields=(to_opts.pk,), + join_field=self, + m2m=False, + direct=False, + filtered_relation=filtered_relation, + )] - def get_reverse_path_info(self): + def get_reverse_path_info(self, filtered_relation=None): to_opts = self.model._meta from_opts = self.remote_field.model._meta - return [PathInfo(from_opts, to_opts, (to_opts.pk,), self.remote_field, False, False)] + return [PathInfo( + from_opts=from_opts, + to_opts=to_opts, + target_fields=(to_opts.pk,), + join_field=self.remote_field, + m2m=False, + direct=False, + filtered_relation=filtered_relation, + )] def contribute_to_class(self, cls, name, private_only=False): super().contribute_to_class(cls, name, private_only) diff --git a/tests/select_related_onetoone/tests.py b/tests/select_related_onetoone/tests.py index 5868dc807c..089a1062d9 100644 --- a/tests/select_related_onetoone/tests.py +++ b/tests/select_related_onetoone/tests.py @@ -1,4 +1,5 @@ from django.core.exceptions import FieldError +from django.db.models import FilteredRelation from django.test import SimpleTestCase, TestCase from .models import ( @@ -230,3 +231,8 @@ class ReverseSelectRelatedValidationTests(SimpleTestCase): with self.assertRaisesMessage(FieldError, self.non_relational_error % ('username', fields)): list(User.objects.select_related('username')) + + def test_reverse_related_validation_with_filtered_relation(self): + fields = 'userprofile, userstat, relation' + with self.assertRaisesMessage(FieldError, self.invalid_error % ('foobar', fields)): + list(User.objects.annotate(relation=FilteredRelation('userprofile')).select_related('foobar'))