Refactored the empty/full result logic in WhereNode.as_sql()

Made sure the WhereNode.as_sql() handles various EmptyResultSet and
FullResultSet conditions correctly. Also, got rid of the FullResultSet
exception class. It is now represented by '', [] return value in the
as_sql() methods.
This commit is contained in:
Anssi Kääriäinen 2012-05-26 05:55:33 +03:00
parent 2b9fb2e644
commit bd283aa844
3 changed files with 135 additions and 39 deletions

View File

@ -6,9 +6,6 @@ the SQL domain.
class EmptyResultSet(Exception): class EmptyResultSet(Exception):
pass pass
class FullResultSet(Exception):
pass
class MultiJoin(Exception): class MultiJoin(Exception):
""" """
Used by join construction code to indicate the point at which a Used by join construction code to indicate the point at which a

View File

@ -10,7 +10,7 @@ from itertools import repeat
from django.utils import tree from django.utils import tree
from django.db.models.fields import Field from django.db.models.fields import Field
from django.db.models.sql.datastructures import EmptyResultSet, FullResultSet from django.db.models.sql.datastructures import EmptyResultSet
from django.db.models.sql.aggregates import Aggregate from django.db.models.sql.aggregates import Aggregate
# Connection types # Connection types
@ -75,17 +75,21 @@ class WhereNode(tree.Node):
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
""" """
Returns the SQL version of the where clause and the value to be Returns the SQL version of the where clause and the value to be
substituted in. Returns None, None if this node is empty. substituted in. Returns '', [] if this node matches everything,
None, [] if this node is empty, and raises EmptyResultSet if this
If 'node' is provided, that is the root of the SQL generation node can't match anything.
(generally not needed except by the internal implementation for
recursion).
""" """
if not self.children: # Note that the logic here is made slightly more complex than
return None, [] # necessary because there are two kind of empty nodes: Nodes
# containing 0 children, and nodes that are known to match everything.
# A match-everything node is different than empty node (which also
# technically matches everything) for backwards compatibility reasons.
# Refs #5261.
result = [] result = []
result_params = [] result_params = []
empty = True everything_childs, nothing_childs = 0, 0
non_empty_childs = len(self.children)
for child in self.children: for child in self.children:
try: try:
if hasattr(child, 'as_sql'): if hasattr(child, 'as_sql'):
@ -93,39 +97,48 @@ class WhereNode(tree.Node):
else: else:
# A leaf node in the tree. # A leaf node in the tree.
sql, params = self.make_atom(child, qn, connection) sql, params = self.make_atom(child, qn, connection)
except EmptyResultSet: except EmptyResultSet:
if self.connector == AND and not self.negated: nothing_childs += 1
# We can bail out early in this particular case (only). else:
raise
elif self.negated:
empty = False
continue
except FullResultSet:
if self.connector == OR:
if self.negated:
empty = True
break
# We match everything. No need for any constraints.
return '', []
if self.negated:
empty = True
continue
empty = False
if sql: if sql:
result.append(sql) result.append(sql)
result_params.extend(params) result_params.extend(params)
if empty: else:
if sql is None:
# Skip empty childs totally.
non_empty_childs -= 1
continue
everything_childs += 1
# Check if this node matches nothing or everything.
# First check the amount of full nodes and empty nodes
# to make this node empty/full.
if self.connector == AND:
full_needed, empty_needed = non_empty_childs, 1
else:
full_needed, empty_needed = 1, non_empty_childs
# Now, check if this node is full/empty using the
# counts.
if empty_needed - nothing_childs <= 0:
if self.negated:
return '', []
else:
raise EmptyResultSet raise EmptyResultSet
if full_needed - everything_childs <= 0:
if self.negated:
raise EmptyResultSet
else:
return '', []
if non_empty_childs == 0:
# All the child nodes were empty, so this one is empty, too.
return None, []
conn = ' %s ' % self.connector conn = ' %s ' % self.connector
sql_string = conn.join(result) sql_string = conn.join(result)
if sql_string: if sql_string:
if self.negated: if len(result) > 1:
sql_string = 'NOT (%s)' % sql_string
elif len(self.children) != 1:
sql_string = '(%s)' % sql_string sql_string = '(%s)' % sql_string
if self.negated:
sql_string = 'NOT %s' % sql_string
return sql_string, result_params return sql_string, result_params
def make_atom(self, child, qn, connection): def make_atom(self, child, qn, connection):
@ -261,7 +274,7 @@ class EverythingNode(object):
""" """
def as_sql(self, qn=None, connection=None): def as_sql(self, qn=None, connection=None):
raise FullResultSet return '', []
def relabel_aliases(self, change_map, node=None): def relabel_aliases(self, change_map, node=None):
return return

View File

@ -10,6 +10,8 @@ from django.core.exceptions import FieldError
from django.db import DatabaseError, connection, connections, DEFAULT_DB_ALIAS from django.db import DatabaseError, connection, connections, DEFAULT_DB_ALIAS
from django.db.models import Count from django.db.models import Count
from django.db.models.query import Q, ITER_CHUNK_SIZE, EmptyQuerySet from django.db.models.query import Q, ITER_CHUNK_SIZE, EmptyQuerySet
from django.db.models.sql.where import WhereNode, EverythingNode, NothingNode
from django.db.models.sql.datastructures import EmptyResultSet
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from django.test.utils import str_prefix from django.test.utils import str_prefix
from django.utils import unittest from django.utils import unittest
@ -1316,10 +1318,23 @@ class Queries5Tests(TestCase):
) )
def test_ticket5261(self): def test_ticket5261(self):
# Test different empty excludes.
self.assertQuerysetEqual( self.assertQuerysetEqual(
Note.objects.exclude(Q()), Note.objects.exclude(Q()),
['<Note: n1>', '<Note: n2>'] ['<Note: n1>', '<Note: n2>']
) )
self.assertQuerysetEqual(
Note.objects.filter(~Q()),
['<Note: n1>', '<Note: n2>']
)
self.assertQuerysetEqual(
Note.objects.filter(~Q()|~Q()),
['<Note: n1>', '<Note: n2>']
)
self.assertQuerysetEqual(
Note.objects.exclude(~Q()&~Q()),
['<Note: n1>', '<Note: n2>']
)
class SelectRelatedTests(TestCase): class SelectRelatedTests(TestCase):
@ -2020,3 +2035,74 @@ class ProxyQueryCleanupTest(TestCase):
self.assertEqual(qs.count(), 1) self.assertEqual(qs.count(), 1)
str(qs.query) str(qs.query)
self.assertEqual(qs.count(), 1) self.assertEqual(qs.count(), 1)
class WhereNodeTest(TestCase):
class DummyNode(object):
def as_sql(self, qn, connection):
return 'dummy', []
def test_empty_full_handling_conjunction(self):
qn = connection.ops.quote_name
w = WhereNode(children=[EverythingNode()])
self.assertEquals(w.as_sql(qn, connection), ('', []))
w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w = WhereNode(children=[NothingNode()])
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('', []))
w = WhereNode(children=[EverythingNode(), EverythingNode()])
self.assertEquals(w.as_sql(qn, connection), ('', []))
w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w = WhereNode(children=[EverythingNode(), self.DummyNode()])
self.assertEquals(w.as_sql(qn, connection), ('dummy', []))
w = WhereNode(children=[self.DummyNode(), self.DummyNode()])
self.assertEquals(w.as_sql(qn, connection), ('(dummy AND dummy)', []))
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('NOT (dummy AND dummy)', []))
w = WhereNode(children=[NothingNode(), self.DummyNode()])
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('', []))
def test_empty_full_handling_disjunction(self):
qn = connection.ops.quote_name
w = WhereNode(children=[EverythingNode()], connector='OR')
self.assertEquals(w.as_sql(qn, connection), ('', []))
w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w = WhereNode(children=[NothingNode()], connector='OR')
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('', []))
w = WhereNode(children=[EverythingNode(), EverythingNode()], connector='OR')
self.assertEquals(w.as_sql(qn, connection), ('', []))
w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w = WhereNode(children=[EverythingNode(), self.DummyNode()], connector='OR')
self.assertEquals(w.as_sql(qn, connection), ('', []))
w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)
w = WhereNode(children=[self.DummyNode(), self.DummyNode()], connector='OR')
self.assertEquals(w.as_sql(qn, connection), ('(dummy OR dummy)', []))
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('NOT (dummy OR dummy)', []))
w = WhereNode(children=[NothingNode(), self.DummyNode()], connector='OR')
self.assertEquals(w.as_sql(qn, connection), ('dummy', []))
w.negate()
self.assertEquals(w.as_sql(qn, connection), ('NOT dummy', []))
def test_empty_nodes(self):
qn = connection.ops.quote_name
empty_w = WhereNode()
w = WhereNode(children=[empty_w, empty_w])
self.assertEquals(w.as_sql(qn, connection), (None, []))
w.negate()
self.assertEquals(w.as_sql(qn, connection), (None, []))
w.connector = 'OR'
self.assertEquals(w.as_sql(qn, connection), (None, []))
w.negate()
self.assertEquals(w.as_sql(qn, connection), (None, []))
w = WhereNode(children=[empty_w, NothingNode()], connector='OR')
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection)