Refs #33374 -- Adjusted full match condition handling.

Adjusting WhereNode.as_sql() to raise an exception when encoutering a
full match just like with empty matches ensures that all case are
explicitly handled.
This commit is contained in:
Simon Charette 2022-11-06 11:19:33 -05:00 committed by Mariusz Felisiak
parent 4b702c832c
commit 76e37513e2
11 changed files with 114 additions and 61 deletions

View File

@ -233,6 +233,12 @@ class EmptyResultSet(Exception):
pass pass
class FullResultSet(Exception):
"""A database query predicate is matches everything."""
pass
class SynchronousOnlyOperation(Exception): class SynchronousOnlyOperation(Exception):
"""The user tried to call a sync-only function from an async context.""" """The user tried to call a sync-only function from an async context."""

View File

@ -1,4 +1,4 @@
from django.core.exceptions import FieldError from django.core.exceptions import FieldError, FullResultSet
from django.db.models.expressions import Col from django.db.models.expressions import Col
from django.db.models.sql import compiler from django.db.models.sql import compiler
@ -40,12 +40,16 @@ class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
"DELETE %s FROM" "DELETE %s FROM"
% self.quote_name_unless_alias(self.query.get_initial_alias()) % self.quote_name_unless_alias(self.query.get_initial_alias())
] ]
from_sql, from_params = self.get_from_clause() from_sql, params = self.get_from_clause()
result.extend(from_sql) result.extend(from_sql)
try:
where_sql, where_params = self.compile(where) where_sql, where_params = self.compile(where)
if where_sql: except FullResultSet:
pass
else:
result.append("WHERE %s" % where_sql) result.append("WHERE %s" % where_sql)
return " ".join(result), tuple(from_params) + tuple(where_params) params.extend(where_params)
return " ".join(result), tuple(params)
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):

View File

@ -1,7 +1,7 @@
""" """
Classes to represent the definitions of aggregate functions. Classes to represent the definitions of aggregate functions.
""" """
from django.core.exceptions import FieldError from django.core.exceptions import FieldError, FullResultSet
from django.db.models.expressions import Case, Func, Star, When from django.db.models.expressions import Case, Func, Star, When
from django.db.models.fields import IntegerField from django.db.models.fields import IntegerField
from django.db.models.functions.comparison import Coalesce from django.db.models.functions.comparison import Coalesce
@ -104,8 +104,11 @@ class Aggregate(Func):
extra_context["distinct"] = "DISTINCT " if self.distinct else "" extra_context["distinct"] = "DISTINCT " if self.distinct else ""
if self.filter: if self.filter:
if connection.features.supports_aggregate_filter_clause: if connection.features.supports_aggregate_filter_clause:
try:
filter_sql, filter_params = self.filter.as_sql(compiler, connection) filter_sql, filter_params = self.filter.as_sql(compiler, connection)
if filter_sql: except FullResultSet:
pass
else:
template = self.filter_template % extra_context.get( template = self.filter_template % extra_context.get(
"template", self.template "template", self.template
) )

View File

@ -7,7 +7,7 @@ from collections import defaultdict
from decimal import Decimal from decimal import Decimal
from uuid import UUID from uuid import UUID
from django.core.exceptions import EmptyResultSet, FieldError from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
from django.db import DatabaseError, NotSupportedError, connection from django.db import DatabaseError, NotSupportedError, connection
from django.db.models import fields from django.db.models import fields
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
@ -955,6 +955,8 @@ class Func(SQLiteNumericMixin, Expression):
if empty_result_set_value is NotImplemented: if empty_result_set_value is NotImplemented:
raise raise
arg_sql, arg_params = compiler.compile(Value(empty_result_set_value)) arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
except FullResultSet:
arg_sql, arg_params = compiler.compile(Value(True))
sql_parts.append(arg_sql) sql_parts.append(arg_sql)
params.extend(arg_params) params.extend(arg_params)
data = {**self.extra, **extra_context} data = {**self.extra, **extra_context}
@ -1367,14 +1369,6 @@ class When(Expression):
template_params = extra_context template_params = extra_context
sql_params = [] sql_params = []
condition_sql, condition_params = compiler.compile(self.condition) condition_sql, condition_params = compiler.compile(self.condition)
# Filters that match everything are handled as empty strings in the
# WHERE clause, but in a CASE WHEN expression they must use a predicate
# that's always True.
if condition_sql == "":
if connection.features.supports_boolean_expr_in_select_clause:
condition_sql, condition_params = compiler.compile(Value(True))
else:
condition_sql, condition_params = "1=1", ()
template_params["condition"] = condition_sql template_params["condition"] = condition_sql
result_sql, result_params = compiler.compile(self.result) result_sql, result_params = compiler.compile(self.result)
template_params["result"] = result_sql template_params["result"] = result_sql
@ -1461,14 +1455,17 @@ class Case(SQLiteNumericMixin, Expression):
template_params = {**self.extra, **extra_context} template_params = {**self.extra, **extra_context}
case_parts = [] case_parts = []
sql_params = [] sql_params = []
default_sql, default_params = compiler.compile(self.default)
for case in self.cases: for case in self.cases:
try: try:
case_sql, case_params = compiler.compile(case) case_sql, case_params = compiler.compile(case)
except EmptyResultSet: except EmptyResultSet:
continue continue
except FullResultSet:
default_sql, default_params = compiler.compile(case.result)
break
case_parts.append(case_sql) case_parts.append(case_sql)
sql_params.extend(case_params) sql_params.extend(case_params)
default_sql, default_params = compiler.compile(self.default)
if not case_parts: if not case_parts:
return default_sql, default_params return default_sql, default_params
case_joiner = case_joiner or self.case_joiner case_joiner = case_joiner or self.case_joiner

View File

@ -1103,15 +1103,6 @@ class BooleanField(Field):
defaults = {"form_class": form_class, "required": False} defaults = {"form_class": form_class, "required": False}
return super().formfield(**{**defaults, **kwargs}) return super().formfield(**{**defaults, **kwargs})
def select_format(self, compiler, sql, params):
sql, params = super().select_format(compiler, sql, params)
# Filters that match everything are handled as empty strings in the
# WHERE clause, but in SELECT or GROUP BY list they must use a
# predicate that's always True.
if sql == "":
sql = "1"
return sql, params
class CharField(Field): class CharField(Field):
description = _("String (up to %(max_length)s)") description = _("String (up to %(max_length)s)")

View File

@ -4,7 +4,7 @@ import re
from functools import partial from functools import partial
from itertools import chain from itertools import chain
from django.core.exceptions import EmptyResultSet, FieldError from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
from django.db import DatabaseError, NotSupportedError from django.db import DatabaseError, NotSupportedError
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value
@ -169,7 +169,7 @@ class SQLCompiler:
expr = Ref(alias, expr) expr = Ref(alias, expr)
try: try:
sql, params = self.compile(expr) sql, params = self.compile(expr)
except EmptyResultSet: except (EmptyResultSet, FullResultSet):
continue continue
sql, params = expr.select_format(self, sql, params) sql, params = expr.select_format(self, sql, params)
params_hash = make_hashable(params) params_hash = make_hashable(params)
@ -287,6 +287,8 @@ class SQLCompiler:
sql, params = "0", () sql, params = "0", ()
else: else:
sql, params = self.compile(Value(empty_result_set_value)) sql, params = self.compile(Value(empty_result_set_value))
except FullResultSet:
sql, params = self.compile(Value(True))
else: else:
sql, params = col.select_format(self, sql, params) sql, params = col.select_format(self, sql, params)
if alias is None and with_col_aliases: if alias is None and with_col_aliases:
@ -721,9 +723,16 @@ class SQLCompiler:
raise raise
# Use a predicate that's always False. # Use a predicate that's always False.
where, w_params = "0 = 1", [] where, w_params = "0 = 1", []
except FullResultSet:
where, w_params = "", []
try:
having, h_params = ( having, h_params = (
self.compile(self.having) if self.having is not None else ("", []) self.compile(self.having)
if self.having is not None
else ("", [])
) )
except FullResultSet:
having, h_params = "", []
result = ["SELECT"] result = ["SELECT"]
params = [] params = []
@ -1817,11 +1826,12 @@ class SQLDeleteCompiler(SQLCompiler):
) )
def _as_sql(self, query): def _as_sql(self, query):
result = ["DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)] delete = "DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)
try:
where, params = self.compile(query.where) where, params = self.compile(query.where)
if where: except FullResultSet:
result.append("WHERE %s" % where) return delete, ()
return " ".join(result), tuple(params) return f"{delete} WHERE {where}", tuple(params)
def as_sql(self): def as_sql(self):
""" """
@ -1906,8 +1916,11 @@ class SQLUpdateCompiler(SQLCompiler):
"UPDATE %s SET" % qn(table), "UPDATE %s SET" % qn(table),
", ".join(values), ", ".join(values),
] ]
try:
where, params = self.compile(self.query.where) where, params = self.compile(self.query.where)
if where: except FullResultSet:
params = []
else:
result.append("WHERE %s" % where) result.append("WHERE %s" % where)
return " ".join(result), tuple(update_params + params) return " ".join(result), tuple(update_params + params)

View File

@ -2,6 +2,7 @@
Useful auxiliary data structures for query construction. Not useful outside Useful auxiliary data structures for query construction. Not useful outside
the SQL domain. the SQL domain.
""" """
from django.core.exceptions import FullResultSet
from django.db.models.sql.constants import INNER, LOUTER from django.db.models.sql.constants import INNER, LOUTER
@ -100,8 +101,11 @@ class Join:
join_conditions.append("(%s)" % extra_sql) join_conditions.append("(%s)" % extra_sql)
params.extend(extra_params) params.extend(extra_params)
if self.filtered_relation: if self.filtered_relation:
try:
extra_sql, extra_params = compiler.compile(self.filtered_relation) extra_sql, extra_params = compiler.compile(self.filtered_relation)
if extra_sql: except FullResultSet:
pass
else:
join_conditions.append("(%s)" % extra_sql) join_conditions.append("(%s)" % extra_sql)
params.extend(extra_params) params.extend(extra_params)
if not join_conditions: if not join_conditions:

View File

@ -4,7 +4,7 @@ Code to manage the creation and SQL rendering of 'where' constraints.
import operator import operator
from functools import reduce from functools import reduce
from django.core.exceptions import EmptyResultSet from django.core.exceptions import EmptyResultSet, FullResultSet
from django.db.models.expressions import Case, When from django.db.models.expressions import Case, When
from django.db.models.lookups import Exact from django.db.models.lookups import Exact
from django.utils import tree from django.utils import tree
@ -145,6 +145,8 @@ class WhereNode(tree.Node):
sql, params = compiler.compile(child) sql, params = compiler.compile(child)
except EmptyResultSet: except EmptyResultSet:
empty_needed -= 1 empty_needed -= 1
except FullResultSet:
full_needed -= 1
else: else:
if sql: if sql:
result.append(sql) result.append(sql)
@ -158,21 +160,22 @@ class WhereNode(tree.Node):
# counts. # counts.
if empty_needed == 0: if empty_needed == 0:
if self.negated: if self.negated:
return "", [] raise FullResultSet
else: else:
raise EmptyResultSet raise EmptyResultSet
if full_needed == 0: if full_needed == 0:
if self.negated: if self.negated:
raise EmptyResultSet raise EmptyResultSet
else: else:
return "", [] raise FullResultSet
conn = " %s " % self.connector conn = " %s " % self.connector
sql_string = conn.join(result) sql_string = conn.join(result)
if sql_string: if not sql_string:
raise FullResultSet
if self.negated: if self.negated:
# Some backends (Oracle at least) need parentheses # Some backends (Oracle at least) need parentheses around the inner
# around the inner SQL in the negated case, even if the # SQL in the negated case, even if the inner SQL contains just a
# inner SQL contains just a single expression. # single expression.
sql_string = "NOT (%s)" % sql_string sql_string = "NOT (%s)" % sql_string
elif len(result) > 1 or self.resolved: elif len(result) > 1 or self.resolved:
sql_string = "(%s)" % sql_string sql_string = "(%s)" % sql_string

View File

@ -42,6 +42,17 @@ Django core exception classes are defined in ``django.core.exceptions``.
return any results. Most Django projects won't encounter this exception, return any results. Most Django projects won't encounter this exception,
but it might be useful for implementing custom lookups and expressions. but it might be useful for implementing custom lookups and expressions.
``FullResultSet``
-----------------
.. exception:: FullResultSet
.. versionadded:: 4.2
``FullResultSet`` may be raised during query generation if a query will
match everything. Most Django projects won't encounter this exception, but
it might be useful for implementing custom lookups and expressions.
``FieldDoesNotExist`` ``FieldDoesNotExist``
--------------------- ---------------------

View File

@ -24,7 +24,15 @@ from django.db.models import (
When, When,
) )
from django.db.models.expressions import RawSQL from django.db.models.expressions import RawSQL
from django.db.models.functions import Coalesce, ExtractYear, Floor, Length, Lower, Trim from django.db.models.functions import (
Cast,
Coalesce,
ExtractYear,
Floor,
Length,
Lower,
Trim,
)
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from django.test.utils import register_lookup from django.test.utils import register_lookup
@ -282,6 +290,13 @@ class NonAggregateAnnotationTestCase(TestCase):
self.assertEqual(len(books), Book.objects.count()) self.assertEqual(len(books), Book.objects.count())
self.assertTrue(all(book.selected for book in books)) self.assertTrue(all(book.selected for book in books))
def test_full_expression_wrapped_annotation(self):
books = Book.objects.annotate(
selected=Coalesce(~Q(pk__in=[]), True),
)
self.assertEqual(len(books), Book.objects.count())
self.assertTrue(all(book.selected for book in books))
def test_full_expression_annotation_with_aggregation(self): def test_full_expression_annotation_with_aggregation(self):
qs = Book.objects.filter(isbn="159059725").annotate( qs = Book.objects.filter(isbn="159059725").annotate(
selected=ExpressionWrapper(~Q(pk__in=[]), output_field=BooleanField()), selected=ExpressionWrapper(~Q(pk__in=[]), output_field=BooleanField()),
@ -292,7 +307,7 @@ class NonAggregateAnnotationTestCase(TestCase):
def test_aggregate_over_full_expression_annotation(self): def test_aggregate_over_full_expression_annotation(self):
qs = Book.objects.annotate( qs = Book.objects.annotate(
selected=ExpressionWrapper(~Q(pk__in=[]), output_field=BooleanField()), selected=ExpressionWrapper(~Q(pk__in=[]), output_field=BooleanField()),
).aggregate(Sum("selected")) ).aggregate(selected__sum=Sum(Cast("selected", IntegerField())))
self.assertEqual(qs["selected__sum"], Book.objects.count()) self.assertEqual(qs["selected__sum"], Book.objects.count())
def test_empty_queryset_annotation(self): def test_empty_queryset_annotation(self):

View File

@ -5,7 +5,7 @@ import unittest
from operator import attrgetter from operator import attrgetter
from threading import Lock from threading import Lock
from django.core.exceptions import EmptyResultSet, FieldError from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
from django.db import DEFAULT_DB_ALIAS, connection from django.db import DEFAULT_DB_ALIAS, connection
from django.db.models import CharField, Count, Exists, F, Max, OuterRef, Q from django.db.models import CharField, Count, Exists, F, Max, OuterRef, Q
from django.db.models.expressions import RawSQL from django.db.models.expressions import RawSQL
@ -3588,7 +3588,8 @@ class WhereNodeTest(SimpleTestCase):
with self.assertRaises(EmptyResultSet): with self.assertRaises(EmptyResultSet):
w.as_sql(compiler, connection) w.as_sql(compiler, connection)
w.negate() w.negate()
self.assertEqual(w.as_sql(compiler, connection), ("", [])) with self.assertRaises(FullResultSet):
w.as_sql(compiler, connection)
w = WhereNode(children=[self.DummyNode(), self.DummyNode()]) w = WhereNode(children=[self.DummyNode(), self.DummyNode()])
self.assertEqual(w.as_sql(compiler, connection), ("(dummy AND dummy)", [])) self.assertEqual(w.as_sql(compiler, connection), ("(dummy AND dummy)", []))
w.negate() w.negate()
@ -3597,7 +3598,8 @@ class WhereNodeTest(SimpleTestCase):
with self.assertRaises(EmptyResultSet): with self.assertRaises(EmptyResultSet):
w.as_sql(compiler, connection) w.as_sql(compiler, connection)
w.negate() w.negate()
self.assertEqual(w.as_sql(compiler, connection), ("", [])) with self.assertRaises(FullResultSet):
w.as_sql(compiler, connection)
def test_empty_full_handling_disjunction(self): def test_empty_full_handling_disjunction(self):
compiler = WhereNodeTest.MockCompiler() compiler = WhereNodeTest.MockCompiler()
@ -3605,7 +3607,8 @@ class WhereNodeTest(SimpleTestCase):
with self.assertRaises(EmptyResultSet): with self.assertRaises(EmptyResultSet):
w.as_sql(compiler, connection) w.as_sql(compiler, connection)
w.negate() w.negate()
self.assertEqual(w.as_sql(compiler, connection), ("", [])) with self.assertRaises(FullResultSet):
w.as_sql(compiler, connection)
w = WhereNode(children=[self.DummyNode(), self.DummyNode()], connector=OR) w = WhereNode(children=[self.DummyNode(), self.DummyNode()], connector=OR)
self.assertEqual(w.as_sql(compiler, connection), ("(dummy OR dummy)", [])) self.assertEqual(w.as_sql(compiler, connection), ("(dummy OR dummy)", []))
w.negate() w.negate()
@ -3619,7 +3622,8 @@ class WhereNodeTest(SimpleTestCase):
compiler = WhereNodeTest.MockCompiler() compiler = WhereNodeTest.MockCompiler()
empty_w = WhereNode() empty_w = WhereNode()
w = WhereNode(children=[empty_w, empty_w]) w = WhereNode(children=[empty_w, empty_w])
self.assertEqual(w.as_sql(compiler, connection), ("", [])) with self.assertRaises(FullResultSet):
w.as_sql(compiler, connection)
w.negate() w.negate()
with self.assertRaises(EmptyResultSet): with self.assertRaises(EmptyResultSet):
w.as_sql(compiler, connection) w.as_sql(compiler, connection)
@ -3627,9 +3631,11 @@ class WhereNodeTest(SimpleTestCase):
with self.assertRaises(EmptyResultSet): with self.assertRaises(EmptyResultSet):
w.as_sql(compiler, connection) w.as_sql(compiler, connection)
w.negate() w.negate()
self.assertEqual(w.as_sql(compiler, connection), ("", [])) with self.assertRaises(FullResultSet):
w.as_sql(compiler, connection)
w = WhereNode(children=[empty_w, NothingNode()], connector=OR) w = WhereNode(children=[empty_w, NothingNode()], connector=OR)
self.assertEqual(w.as_sql(compiler, connection), ("", [])) with self.assertRaises(FullResultSet):
w.as_sql(compiler, connection)
w = WhereNode(children=[empty_w, NothingNode()], connector=AND) w = WhereNode(children=[empty_w, NothingNode()], connector=AND)
with self.assertRaises(EmptyResultSet): with self.assertRaises(EmptyResultSet):
w.as_sql(compiler, connection) w.as_sql(compiler, connection)