Refs #33374 -- Adjusted full match condition handling.

Adjusting WhereNode.as_sql() to raise an exception when encoutering a
full match just like with empty matches ensures that all case are
explicitly handled.
This commit is contained in:
Simon Charette 2022-11-06 11:19:33 -05:00 committed by Mariusz Felisiak
parent 4b702c832c
commit 76e37513e2
11 changed files with 114 additions and 61 deletions

View File

@ -233,6 +233,12 @@ class EmptyResultSet(Exception):
pass
class FullResultSet(Exception):
"""A database query predicate is matches everything."""
pass
class SynchronousOnlyOperation(Exception):
"""The user tried to call a sync-only function from an async context."""

View File

@ -1,4 +1,4 @@
from django.core.exceptions import FieldError
from django.core.exceptions import FieldError, FullResultSet
from django.db.models.expressions import Col
from django.db.models.sql import compiler
@ -40,12 +40,16 @@ class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
"DELETE %s FROM"
% self.quote_name_unless_alias(self.query.get_initial_alias())
]
from_sql, from_params = self.get_from_clause()
from_sql, params = self.get_from_clause()
result.extend(from_sql)
try:
where_sql, where_params = self.compile(where)
if where_sql:
except FullResultSet:
pass
else:
result.append("WHERE %s" % where_sql)
return " ".join(result), tuple(from_params) + tuple(where_params)
params.extend(where_params)
return " ".join(result), tuple(params)
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):

View File

@ -1,7 +1,7 @@
"""
Classes to represent the definitions of aggregate functions.
"""
from django.core.exceptions import FieldError
from django.core.exceptions import FieldError, FullResultSet
from django.db.models.expressions import Case, Func, Star, When
from django.db.models.fields import IntegerField
from django.db.models.functions.comparison import Coalesce
@ -104,8 +104,11 @@ class Aggregate(Func):
extra_context["distinct"] = "DISTINCT " if self.distinct else ""
if self.filter:
if connection.features.supports_aggregate_filter_clause:
try:
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
if filter_sql:
except FullResultSet:
pass
else:
template = self.filter_template % extra_context.get(
"template", self.template
)

View File

@ -7,7 +7,7 @@ from collections import defaultdict
from decimal import Decimal
from uuid import UUID
from django.core.exceptions import EmptyResultSet, FieldError
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
from django.db import DatabaseError, NotSupportedError, connection
from django.db.models import fields
from django.db.models.constants import LOOKUP_SEP
@ -955,6 +955,8 @@ class Func(SQLiteNumericMixin, Expression):
if empty_result_set_value is NotImplemented:
raise
arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
except FullResultSet:
arg_sql, arg_params = compiler.compile(Value(True))
sql_parts.append(arg_sql)
params.extend(arg_params)
data = {**self.extra, **extra_context}
@ -1367,14 +1369,6 @@ class When(Expression):
template_params = extra_context
sql_params = []
condition_sql, condition_params = compiler.compile(self.condition)
# Filters that match everything are handled as empty strings in the
# WHERE clause, but in a CASE WHEN expression they must use a predicate
# that's always True.
if condition_sql == "":
if connection.features.supports_boolean_expr_in_select_clause:
condition_sql, condition_params = compiler.compile(Value(True))
else:
condition_sql, condition_params = "1=1", ()
template_params["condition"] = condition_sql
result_sql, result_params = compiler.compile(self.result)
template_params["result"] = result_sql
@ -1461,14 +1455,17 @@ class Case(SQLiteNumericMixin, Expression):
template_params = {**self.extra, **extra_context}
case_parts = []
sql_params = []
default_sql, default_params = compiler.compile(self.default)
for case in self.cases:
try:
case_sql, case_params = compiler.compile(case)
except EmptyResultSet:
continue
except FullResultSet:
default_sql, default_params = compiler.compile(case.result)
break
case_parts.append(case_sql)
sql_params.extend(case_params)
default_sql, default_params = compiler.compile(self.default)
if not case_parts:
return default_sql, default_params
case_joiner = case_joiner or self.case_joiner

View File

@ -1103,15 +1103,6 @@ class BooleanField(Field):
defaults = {"form_class": form_class, "required": False}
return super().formfield(**{**defaults, **kwargs})
def select_format(self, compiler, sql, params):
sql, params = super().select_format(compiler, sql, params)
# Filters that match everything are handled as empty strings in the
# WHERE clause, but in SELECT or GROUP BY list they must use a
# predicate that's always True.
if sql == "":
sql = "1"
return sql, params
class CharField(Field):
description = _("String (up to %(max_length)s)")

View File

@ -4,7 +4,7 @@ import re
from functools import partial
from itertools import chain
from django.core.exceptions import EmptyResultSet, FieldError
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
from django.db import DatabaseError, NotSupportedError
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value
@ -169,7 +169,7 @@ class SQLCompiler:
expr = Ref(alias, expr)
try:
sql, params = self.compile(expr)
except EmptyResultSet:
except (EmptyResultSet, FullResultSet):
continue
sql, params = expr.select_format(self, sql, params)
params_hash = make_hashable(params)
@ -287,6 +287,8 @@ class SQLCompiler:
sql, params = "0", ()
else:
sql, params = self.compile(Value(empty_result_set_value))
except FullResultSet:
sql, params = self.compile(Value(True))
else:
sql, params = col.select_format(self, sql, params)
if alias is None and with_col_aliases:
@ -721,9 +723,16 @@ class SQLCompiler:
raise
# Use a predicate that's always False.
where, w_params = "0 = 1", []
except FullResultSet:
where, w_params = "", []
try:
having, h_params = (
self.compile(self.having) if self.having is not None else ("", [])
self.compile(self.having)
if self.having is not None
else ("", [])
)
except FullResultSet:
having, h_params = "", []
result = ["SELECT"]
params = []
@ -1817,11 +1826,12 @@ class SQLDeleteCompiler(SQLCompiler):
)
def _as_sql(self, query):
result = ["DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)]
delete = "DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)
try:
where, params = self.compile(query.where)
if where:
result.append("WHERE %s" % where)
return " ".join(result), tuple(params)
except FullResultSet:
return delete, ()
return f"{delete} WHERE {where}", tuple(params)
def as_sql(self):
"""
@ -1906,8 +1916,11 @@ class SQLUpdateCompiler(SQLCompiler):
"UPDATE %s SET" % qn(table),
", ".join(values),
]
try:
where, params = self.compile(self.query.where)
if where:
except FullResultSet:
params = []
else:
result.append("WHERE %s" % where)
return " ".join(result), tuple(update_params + params)

View File

@ -2,6 +2,7 @@
Useful auxiliary data structures for query construction. Not useful outside
the SQL domain.
"""
from django.core.exceptions import FullResultSet
from django.db.models.sql.constants import INNER, LOUTER
@ -100,8 +101,11 @@ class Join:
join_conditions.append("(%s)" % extra_sql)
params.extend(extra_params)
if self.filtered_relation:
try:
extra_sql, extra_params = compiler.compile(self.filtered_relation)
if extra_sql:
except FullResultSet:
pass
else:
join_conditions.append("(%s)" % extra_sql)
params.extend(extra_params)
if not join_conditions:

View File

@ -4,7 +4,7 @@ Code to manage the creation and SQL rendering of 'where' constraints.
import operator
from functools import reduce
from django.core.exceptions import EmptyResultSet
from django.core.exceptions import EmptyResultSet, FullResultSet
from django.db.models.expressions import Case, When
from django.db.models.lookups import Exact
from django.utils import tree
@ -145,6 +145,8 @@ class WhereNode(tree.Node):
sql, params = compiler.compile(child)
except EmptyResultSet:
empty_needed -= 1
except FullResultSet:
full_needed -= 1
else:
if sql:
result.append(sql)
@ -158,21 +160,22 @@ class WhereNode(tree.Node):
# counts.
if empty_needed == 0:
if self.negated:
return "", []
raise FullResultSet
else:
raise EmptyResultSet
if full_needed == 0:
if self.negated:
raise EmptyResultSet
else:
return "", []
raise FullResultSet
conn = " %s " % self.connector
sql_string = conn.join(result)
if sql_string:
if not sql_string:
raise FullResultSet
if self.negated:
# Some backends (Oracle at least) need parentheses
# around the inner SQL in the negated case, even if the
# inner SQL contains just a single expression.
# Some backends (Oracle at least) need parentheses around the inner
# SQL in the negated case, even if the inner SQL contains just a
# single expression.
sql_string = "NOT (%s)" % sql_string
elif len(result) > 1 or self.resolved:
sql_string = "(%s)" % sql_string

View File

@ -42,6 +42,17 @@ Django core exception classes are defined in ``django.core.exceptions``.
return any results. Most Django projects won't encounter this exception,
but it might be useful for implementing custom lookups and expressions.
``FullResultSet``
-----------------
.. exception:: FullResultSet
.. versionadded:: 4.2
``FullResultSet`` may be raised during query generation if a query will
match everything. Most Django projects won't encounter this exception, but
it might be useful for implementing custom lookups and expressions.
``FieldDoesNotExist``
---------------------

View File

@ -24,7 +24,15 @@ from django.db.models import (
When,
)
from django.db.models.expressions import RawSQL
from django.db.models.functions import Coalesce, ExtractYear, Floor, Length, Lower, Trim
from django.db.models.functions import (
Cast,
Coalesce,
ExtractYear,
Floor,
Length,
Lower,
Trim,
)
from django.test import TestCase, skipUnlessDBFeature
from django.test.utils import register_lookup
@ -282,6 +290,13 @@ class NonAggregateAnnotationTestCase(TestCase):
self.assertEqual(len(books), Book.objects.count())
self.assertTrue(all(book.selected for book in books))
def test_full_expression_wrapped_annotation(self):
books = Book.objects.annotate(
selected=Coalesce(~Q(pk__in=[]), True),
)
self.assertEqual(len(books), Book.objects.count())
self.assertTrue(all(book.selected for book in books))
def test_full_expression_annotation_with_aggregation(self):
qs = Book.objects.filter(isbn="159059725").annotate(
selected=ExpressionWrapper(~Q(pk__in=[]), output_field=BooleanField()),
@ -292,7 +307,7 @@ class NonAggregateAnnotationTestCase(TestCase):
def test_aggregate_over_full_expression_annotation(self):
qs = Book.objects.annotate(
selected=ExpressionWrapper(~Q(pk__in=[]), output_field=BooleanField()),
).aggregate(Sum("selected"))
).aggregate(selected__sum=Sum(Cast("selected", IntegerField())))
self.assertEqual(qs["selected__sum"], Book.objects.count())
def test_empty_queryset_annotation(self):

View File

@ -5,7 +5,7 @@ import unittest
from operator import attrgetter
from threading import Lock
from django.core.exceptions import EmptyResultSet, FieldError
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
from django.db import DEFAULT_DB_ALIAS, connection
from django.db.models import CharField, Count, Exists, F, Max, OuterRef, Q
from django.db.models.expressions import RawSQL
@ -3588,7 +3588,8 @@ class WhereNodeTest(SimpleTestCase):
with self.assertRaises(EmptyResultSet):
w.as_sql(compiler, connection)
w.negate()
self.assertEqual(w.as_sql(compiler, connection), ("", []))
with self.assertRaises(FullResultSet):
w.as_sql(compiler, connection)
w = WhereNode(children=[self.DummyNode(), self.DummyNode()])
self.assertEqual(w.as_sql(compiler, connection), ("(dummy AND dummy)", []))
w.negate()
@ -3597,7 +3598,8 @@ class WhereNodeTest(SimpleTestCase):
with self.assertRaises(EmptyResultSet):
w.as_sql(compiler, connection)
w.negate()
self.assertEqual(w.as_sql(compiler, connection), ("", []))
with self.assertRaises(FullResultSet):
w.as_sql(compiler, connection)
def test_empty_full_handling_disjunction(self):
compiler = WhereNodeTest.MockCompiler()
@ -3605,7 +3607,8 @@ class WhereNodeTest(SimpleTestCase):
with self.assertRaises(EmptyResultSet):
w.as_sql(compiler, connection)
w.negate()
self.assertEqual(w.as_sql(compiler, connection), ("", []))
with self.assertRaises(FullResultSet):
w.as_sql(compiler, connection)
w = WhereNode(children=[self.DummyNode(), self.DummyNode()], connector=OR)
self.assertEqual(w.as_sql(compiler, connection), ("(dummy OR dummy)", []))
w.negate()
@ -3619,7 +3622,8 @@ class WhereNodeTest(SimpleTestCase):
compiler = WhereNodeTest.MockCompiler()
empty_w = WhereNode()
w = WhereNode(children=[empty_w, empty_w])
self.assertEqual(w.as_sql(compiler, connection), ("", []))
with self.assertRaises(FullResultSet):
w.as_sql(compiler, connection)
w.negate()
with self.assertRaises(EmptyResultSet):
w.as_sql(compiler, connection)
@ -3627,9 +3631,11 @@ class WhereNodeTest(SimpleTestCase):
with self.assertRaises(EmptyResultSet):
w.as_sql(compiler, connection)
w.negate()
self.assertEqual(w.as_sql(compiler, connection), ("", []))
with self.assertRaises(FullResultSet):
w.as_sql(compiler, connection)
w = WhereNode(children=[empty_w, NothingNode()], connector=OR)
self.assertEqual(w.as_sql(compiler, connection), ("", []))
with self.assertRaises(FullResultSet):
w.as_sql(compiler, connection)
w = WhereNode(children=[empty_w, NothingNode()], connector=AND)
with self.assertRaises(EmptyResultSet):
w.as_sql(compiler, connection)