From bd283aa844b04651b7c8b4e85f48c6dced1678f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anssi=20K=C3=A4=C3=A4ri=C3=A4inen?= Date: Sat, 26 May 2012 05:55:33 +0300 Subject: [PATCH] Refactored the empty/full result logic in WhereNode.as_sql() Made sure the WhereNode.as_sql() handles various EmptyResultSet and FullResultSet conditions correctly. Also, got rid of the FullResultSet exception class. It is now represented by '', [] return value in the as_sql() methods. --- django/db/models/sql/datastructures.py | 3 - django/db/models/sql/where.py | 85 ++++++++++++++----------- tests/regressiontests/queries/tests.py | 86 ++++++++++++++++++++++++++ 3 files changed, 135 insertions(+), 39 deletions(-) diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 92d64e15ddd..b8e06daf01a 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -6,9 +6,6 @@ the SQL domain. class EmptyResultSet(Exception): pass -class FullResultSet(Exception): - pass - class MultiJoin(Exception): """ Used by join construction code to indicate the point at which a diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 5515bc4f839..70ff5310f7b 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -10,7 +10,7 @@ from itertools import repeat from django.utils import tree from django.db.models.fields import Field -from django.db.models.sql.datastructures import EmptyResultSet, FullResultSet +from django.db.models.sql.datastructures import EmptyResultSet from django.db.models.sql.aggregates import Aggregate # Connection types @@ -75,17 +75,21 @@ class WhereNode(tree.Node): def as_sql(self, qn, connection): """ Returns the SQL version of the where clause and the value to be - substituted in. Returns None, None if this node is empty. - - If 'node' is provided, that is the root of the SQL generation - (generally not needed except by the internal implementation for - recursion). + substituted in. Returns '', [] if this node matches everything, + None, [] if this node is empty, and raises EmptyResultSet if this + node can't match anything. """ - if not self.children: - return None, [] + # Note that the logic here is made slightly more complex than + # necessary because there are two kind of empty nodes: Nodes + # containing 0 children, and nodes that are known to match everything. + # A match-everything node is different than empty node (which also + # technically matches everything) for backwards compatibility reasons. + # Refs #5261. result = [] result_params = [] - empty = True + everything_childs, nothing_childs = 0, 0 + non_empty_childs = len(self.children) + for child in self.children: try: if hasattr(child, 'as_sql'): @@ -93,39 +97,48 @@ class WhereNode(tree.Node): else: # A leaf node in the tree. sql, params = self.make_atom(child, qn, connection) - except EmptyResultSet: - if self.connector == AND and not self.negated: - # We can bail out early in this particular case (only). - raise - elif self.negated: - empty = False - continue - except FullResultSet: - if self.connector == OR: - if self.negated: - empty = True - break - # We match everything. No need for any constraints. - return '', [] + nothing_childs += 1 + else: + if sql: + result.append(sql) + result_params.extend(params) + else: + if sql is None: + # Skip empty childs totally. + non_empty_childs -= 1 + continue + everything_childs += 1 + # Check if this node matches nothing or everything. + # First check the amount of full nodes and empty nodes + # to make this node empty/full. + if self.connector == AND: + full_needed, empty_needed = non_empty_childs, 1 + else: + full_needed, empty_needed = 1, non_empty_childs + # Now, check if this node is full/empty using the + # counts. + if empty_needed - nothing_childs <= 0: if self.negated: - empty = True - continue - - empty = False - if sql: - result.append(sql) - result_params.extend(params) - if empty: - raise EmptyResultSet + return '', [] + else: + raise EmptyResultSet + if full_needed - everything_childs <= 0: + if self.negated: + raise EmptyResultSet + else: + return '', [] + if non_empty_childs == 0: + # All the child nodes were empty, so this one is empty, too. + return None, [] conn = ' %s ' % self.connector sql_string = conn.join(result) if sql_string: - if self.negated: - sql_string = 'NOT (%s)' % sql_string - elif len(self.children) != 1: + if len(result) > 1: sql_string = '(%s)' % sql_string + if self.negated: + sql_string = 'NOT %s' % sql_string return sql_string, result_params def make_atom(self, child, qn, connection): @@ -261,7 +274,7 @@ class EverythingNode(object): """ def as_sql(self, qn=None, connection=None): - raise FullResultSet + return '', [] def relabel_aliases(self, change_map, node=None): return diff --git a/tests/regressiontests/queries/tests.py b/tests/regressiontests/queries/tests.py index 4cc7208a96a..3d2aafd2d96 100644 --- a/tests/regressiontests/queries/tests.py +++ b/tests/regressiontests/queries/tests.py @@ -10,6 +10,8 @@ from django.core.exceptions import FieldError from django.db import DatabaseError, connection, connections, DEFAULT_DB_ALIAS from django.db.models import Count from django.db.models.query import Q, ITER_CHUNK_SIZE, EmptyQuerySet +from django.db.models.sql.where import WhereNode, EverythingNode, NothingNode +from django.db.models.sql.datastructures import EmptyResultSet from django.test import TestCase, skipUnlessDBFeature from django.test.utils import str_prefix from django.utils import unittest @@ -1316,10 +1318,23 @@ class Queries5Tests(TestCase): ) def test_ticket5261(self): + # Test different empty excludes. self.assertQuerysetEqual( Note.objects.exclude(Q()), ['', ''] ) + self.assertQuerysetEqual( + Note.objects.filter(~Q()), + ['', ''] + ) + self.assertQuerysetEqual( + Note.objects.filter(~Q()|~Q()), + ['', ''] + ) + self.assertQuerysetEqual( + Note.objects.exclude(~Q()&~Q()), + ['', ''] + ) class SelectRelatedTests(TestCase): @@ -2020,3 +2035,74 @@ class ProxyQueryCleanupTest(TestCase): self.assertEqual(qs.count(), 1) str(qs.query) self.assertEqual(qs.count(), 1) + +class WhereNodeTest(TestCase): + class DummyNode(object): + def as_sql(self, qn, connection): + return 'dummy', [] + + def test_empty_full_handling_conjunction(self): + qn = connection.ops.quote_name + w = WhereNode(children=[EverythingNode()]) + self.assertEquals(w.as_sql(qn, connection), ('', [])) + w.negate() + self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) + w = WhereNode(children=[NothingNode()]) + self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) + w.negate() + self.assertEquals(w.as_sql(qn, connection), ('', [])) + w = WhereNode(children=[EverythingNode(), EverythingNode()]) + self.assertEquals(w.as_sql(qn, connection), ('', [])) + w.negate() + self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) + w = WhereNode(children=[EverythingNode(), self.DummyNode()]) + self.assertEquals(w.as_sql(qn, connection), ('dummy', [])) + w = WhereNode(children=[self.DummyNode(), self.DummyNode()]) + self.assertEquals(w.as_sql(qn, connection), ('(dummy AND dummy)', [])) + w.negate() + self.assertEquals(w.as_sql(qn, connection), ('NOT (dummy AND dummy)', [])) + w = WhereNode(children=[NothingNode(), self.DummyNode()]) + self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) + w.negate() + self.assertEquals(w.as_sql(qn, connection), ('', [])) + + def test_empty_full_handling_disjunction(self): + qn = connection.ops.quote_name + w = WhereNode(children=[EverythingNode()], connector='OR') + self.assertEquals(w.as_sql(qn, connection), ('', [])) + w.negate() + self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) + w = WhereNode(children=[NothingNode()], connector='OR') + self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) + w.negate() + self.assertEquals(w.as_sql(qn, connection), ('', [])) + w = WhereNode(children=[EverythingNode(), EverythingNode()], connector='OR') + self.assertEquals(w.as_sql(qn, connection), ('', [])) + w.negate() + self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) + w = WhereNode(children=[EverythingNode(), self.DummyNode()], connector='OR') + self.assertEquals(w.as_sql(qn, connection), ('', [])) + w.negate() + self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) + w = WhereNode(children=[self.DummyNode(), self.DummyNode()], connector='OR') + self.assertEquals(w.as_sql(qn, connection), ('(dummy OR dummy)', [])) + w.negate() + self.assertEquals(w.as_sql(qn, connection), ('NOT (dummy OR dummy)', [])) + w = WhereNode(children=[NothingNode(), self.DummyNode()], connector='OR') + self.assertEquals(w.as_sql(qn, connection), ('dummy', [])) + w.negate() + self.assertEquals(w.as_sql(qn, connection), ('NOT dummy', [])) + + def test_empty_nodes(self): + qn = connection.ops.quote_name + empty_w = WhereNode() + w = WhereNode(children=[empty_w, empty_w]) + self.assertEquals(w.as_sql(qn, connection), (None, [])) + w.negate() + self.assertEquals(w.as_sql(qn, connection), (None, [])) + w.connector = 'OR' + self.assertEquals(w.as_sql(qn, connection), (None, [])) + w.negate() + self.assertEquals(w.as_sql(qn, connection), (None, [])) + w = WhereNode(children=[empty_w, NothingNode()], connector='OR') + self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)