diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index b41314a686e..75a330f22aa 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -1,6 +1,7 @@ """ Classes to represent the default SQL aggregate functions """ +import copy from django.db.models.fields import IntegerField, FloatField @@ -62,6 +63,11 @@ class Aggregate(object): self.field = tmp + def clone(self): + # Different aggregates have different init methods, so use copy here + # deepcopy is not needed, as self.col is only changing variable. + return copy.copy(self) + def relabel_aliases(self, change_map): if isinstance(self.col, (list, tuple)): self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index f021d571e9b..613f4c4cfc7 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -279,13 +279,13 @@ class Query(object): obj.select = self.select[:] obj.related_select_cols = [] obj.tables = self.tables[:] - obj.where = copy.deepcopy(self.where, memo=memo) + obj.where = self.where.clone() obj.where_class = self.where_class if self.group_by is None: obj.group_by = None else: obj.group_by = self.group_by[:] - obj.having = copy.deepcopy(self.having, memo=memo) + obj.having = self.having.clone() obj.order_by = self.order_by[:] obj.low_mark, obj.high_mark = self.low_mark, self.high_mark obj.distinct = self.distinct @@ -293,7 +293,9 @@ class Query(object): obj.select_for_update = self.select_for_update obj.select_for_update_nowait = self.select_for_update_nowait obj.select_related = self.select_related - obj.aggregates = copy.deepcopy(self.aggregates, memo=memo) + obj.related_select_cols = [] + obj.aggregates = SortedDict((k, v.clone()) + for k, v in self.aggregates.items()) if self.aggregate_select_mask is None: obj.aggregate_select_mask = None else: @@ -316,7 +318,7 @@ class Query(object): obj._extra_select_cache = self._extra_select_cache.copy() obj.extra_tables = self.extra_tables obj.extra_order_by = self.extra_order_by - obj.deferred_loading = copy.deepcopy(self.deferred_loading, memo=memo) + obj.deferred_loading = copy.copy(self.deferred_loading[0]), self.deferred_loading[1] if self.filter_is_sticky and self.used_aliases: obj.used_aliases = self.used_aliases.copy() else: @@ -549,7 +551,7 @@ class Query(object): # Now relabel a copy of the rhs where-clause and add it to the current # one. if rhs.where: - w = copy.deepcopy(rhs.where) + w = rhs.where.clone() w.relabel_aliases(change_map) if not self.where: # Since 'self' matches everything, add an explicit "include @@ -571,7 +573,7 @@ class Query(object): new_col = change_map.get(col[0], col[0]), col[1] self.select.append(SelectInfo(new_col, field)) else: - item = copy.deepcopy(col) + item = col.clone() item.relabel_aliases(change_map) self.select.append(SelectInfo(item, field)) diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 02847b1f547..3e4b352f100 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -10,7 +10,7 @@ from itertools import repeat from django.utils import tree from django.db.models.fields import Field -from django.db.models.sql.datastructures import EmptyResultSet +from django.db.models.sql.datastructures import EmptyResultSet, Empty from django.db.models.sql.aggregates import Aggregate from django.utils.six.moves import xrange @@ -272,6 +272,23 @@ class WhereNode(tree.Node): if hasattr(child[3], 'relabel_aliases'): child[3].relabel_aliases(change_map) + def clone(self): + """ + Creates a clone of the tree. Must only be called on root nodes (nodes + with empty subtree_parents). Childs must be either (Contraint, lookup, + value) tuples, or objects supporting .clone(). + """ + assert not self.subtree_parents + clone = self.__class__._new_instance( + children=[], connector=self.connector, negated=self.negated) + for child in self.children: + if isinstance(child, tuple): + clone.children.append( + (child[0].clone(), child[1], child[2], child[3])) + else: + clone.children.append(child.clone()) + return clone + class EmptyWhere(WhereNode): def add(self, data, connector): @@ -291,6 +308,9 @@ class EverythingNode(object): def relabel_aliases(self, change_map, node=None): return + def clone(self): + return self + class NothingNode(object): """ A node that matches nothing. @@ -301,6 +321,9 @@ class NothingNode(object): def relabel_aliases(self, change_map, node=None): return + def clone(self): + return self + class ExtraWhere(object): def __init__(self, sqls, params): self.sqls = sqls @@ -310,6 +333,9 @@ class ExtraWhere(object): sqls = ["(%s)" % sql for sql in self.sqls] return " AND ".join(sqls), tuple(self.params or ()) + def clone(self): + return self + class Constraint(object): """ An object that can be passed to WhereNode.add() and knows how to @@ -374,3 +400,9 @@ class Constraint(object): def relabel_aliases(self, change_map): if self.alias in change_map: self.alias = change_map[self.alias] + + def clone(self): + new = Empty() + new.__class__ = self.__class__ + new.alias, new.col, new.field = self.alias, self.col, self.field + return new diff --git a/tests/regressiontests/queries/tests.py b/tests/regressiontests/queries/tests.py index 7d01c16255d..780af5a8b7b 100644 --- a/tests/regressiontests/queries/tests.py +++ b/tests/regressiontests/queries/tests.py @@ -1919,6 +1919,7 @@ class SubqueryTests(TestCase): class CloneTests(TestCase): + def test_evaluated_queryset_as_argument(self): "#13227 -- If a queryset is already evaluated, it can still be used as a query arg" n = Note(note='Test1', misc='misc') @@ -1933,6 +1934,39 @@ class CloneTests(TestCase): # that query in a way that involves cloning. self.assertEqual(ExtraInfo.objects.filter(note__in=n_list)[0].info, 'good') + def test_no_model_options_cloning(self): + """ + Test that cloning a queryset does not get out of hand. While complete + testing is impossible, this is a sanity check against invalid use of + deepcopy. refs #16759. + """ + opts_class = type(Note._meta) + note_deepcopy = getattr(opts_class, "__deepcopy__", None) + opts_class.__deepcopy__ = lambda obj, memo: self.fail("Model options shouldn't be cloned.") + try: + Note.objects.filter(pk__lte=F('pk') + 1).all() + finally: + if note_deepcopy is None: + delattr(opts_class, "__deepcopy__") + else: + opts_class.__deepcopy__ = note_deepcopy + + def test_no_fields_cloning(self): + """ + Test that cloning a queryset does not get out of hand. While complete + testing is impossible, this is a sanity check against invalid use of + deepcopy. refs #16759. + """ + opts_class = type(Note._meta.get_field_by_name("misc")[0]) + note_deepcopy = getattr(opts_class, "__deepcopy__", None) + opts_class.__deepcopy__ = lambda obj, memo: self.fail("Model fields shouldn't be cloned") + try: + Note.objects.filter(note=F('misc')).all() + finally: + if note_deepcopy is None: + delattr(opts_class, "__deepcopy__") + else: + opts_class.__deepcopy__ = note_deepcopy class EmptyQuerySetTests(TestCase): def test_emptyqueryset_values(self):