mirror of https://github.com/django/django.git
Refs #33374 -- Adjusted full match condition handling.
Adjusting WhereNode.as_sql() to raise an exception when encoutering a full match just like with empty matches ensures that all case are explicitly handled.
This commit is contained in:
parent
4b702c832c
commit
76e37513e2
|
@ -233,6 +233,12 @@ class EmptyResultSet(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class FullResultSet(Exception):
|
||||
"""A database query predicate is matches everything."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SynchronousOnlyOperation(Exception):
|
||||
"""The user tried to call a sync-only function from an async context."""
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from django.core.exceptions import FieldError
|
||||
from django.core.exceptions import FieldError, FullResultSet
|
||||
from django.db.models.expressions import Col
|
||||
from django.db.models.sql import compiler
|
||||
|
||||
|
@ -40,12 +40,16 @@ class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
|
|||
"DELETE %s FROM"
|
||||
% self.quote_name_unless_alias(self.query.get_initial_alias())
|
||||
]
|
||||
from_sql, from_params = self.get_from_clause()
|
||||
from_sql, params = self.get_from_clause()
|
||||
result.extend(from_sql)
|
||||
where_sql, where_params = self.compile(where)
|
||||
if where_sql:
|
||||
try:
|
||||
where_sql, where_params = self.compile(where)
|
||||
except FullResultSet:
|
||||
pass
|
||||
else:
|
||||
result.append("WHERE %s" % where_sql)
|
||||
return " ".join(result), tuple(from_params) + tuple(where_params)
|
||||
params.extend(where_params)
|
||||
return " ".join(result), tuple(params)
|
||||
|
||||
|
||||
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""
|
||||
Classes to represent the definitions of aggregate functions.
|
||||
"""
|
||||
from django.core.exceptions import FieldError
|
||||
from django.core.exceptions import FieldError, FullResultSet
|
||||
from django.db.models.expressions import Case, Func, Star, When
|
||||
from django.db.models.fields import IntegerField
|
||||
from django.db.models.functions.comparison import Coalesce
|
||||
|
@ -104,8 +104,11 @@ class Aggregate(Func):
|
|||
extra_context["distinct"] = "DISTINCT " if self.distinct else ""
|
||||
if self.filter:
|
||||
if connection.features.supports_aggregate_filter_clause:
|
||||
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
|
||||
if filter_sql:
|
||||
try:
|
||||
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
|
||||
except FullResultSet:
|
||||
pass
|
||||
else:
|
||||
template = self.filter_template % extra_context.get(
|
||||
"template", self.template
|
||||
)
|
||||
|
|
|
@ -7,7 +7,7 @@ from collections import defaultdict
|
|||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
from django.core.exceptions import EmptyResultSet, FieldError
|
||||
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
|
||||
from django.db import DatabaseError, NotSupportedError, connection
|
||||
from django.db.models import fields
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
|
@ -955,6 +955,8 @@ class Func(SQLiteNumericMixin, Expression):
|
|||
if empty_result_set_value is NotImplemented:
|
||||
raise
|
||||
arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
|
||||
except FullResultSet:
|
||||
arg_sql, arg_params = compiler.compile(Value(True))
|
||||
sql_parts.append(arg_sql)
|
||||
params.extend(arg_params)
|
||||
data = {**self.extra, **extra_context}
|
||||
|
@ -1367,14 +1369,6 @@ class When(Expression):
|
|||
template_params = extra_context
|
||||
sql_params = []
|
||||
condition_sql, condition_params = compiler.compile(self.condition)
|
||||
# Filters that match everything are handled as empty strings in the
|
||||
# WHERE clause, but in a CASE WHEN expression they must use a predicate
|
||||
# that's always True.
|
||||
if condition_sql == "":
|
||||
if connection.features.supports_boolean_expr_in_select_clause:
|
||||
condition_sql, condition_params = compiler.compile(Value(True))
|
||||
else:
|
||||
condition_sql, condition_params = "1=1", ()
|
||||
template_params["condition"] = condition_sql
|
||||
result_sql, result_params = compiler.compile(self.result)
|
||||
template_params["result"] = result_sql
|
||||
|
@ -1461,14 +1455,17 @@ class Case(SQLiteNumericMixin, Expression):
|
|||
template_params = {**self.extra, **extra_context}
|
||||
case_parts = []
|
||||
sql_params = []
|
||||
default_sql, default_params = compiler.compile(self.default)
|
||||
for case in self.cases:
|
||||
try:
|
||||
case_sql, case_params = compiler.compile(case)
|
||||
except EmptyResultSet:
|
||||
continue
|
||||
except FullResultSet:
|
||||
default_sql, default_params = compiler.compile(case.result)
|
||||
break
|
||||
case_parts.append(case_sql)
|
||||
sql_params.extend(case_params)
|
||||
default_sql, default_params = compiler.compile(self.default)
|
||||
if not case_parts:
|
||||
return default_sql, default_params
|
||||
case_joiner = case_joiner or self.case_joiner
|
||||
|
|
|
@ -1103,15 +1103,6 @@ class BooleanField(Field):
|
|||
defaults = {"form_class": form_class, "required": False}
|
||||
return super().formfield(**{**defaults, **kwargs})
|
||||
|
||||
def select_format(self, compiler, sql, params):
|
||||
sql, params = super().select_format(compiler, sql, params)
|
||||
# Filters that match everything are handled as empty strings in the
|
||||
# WHERE clause, but in SELECT or GROUP BY list they must use a
|
||||
# predicate that's always True.
|
||||
if sql == "":
|
||||
sql = "1"
|
||||
return sql, params
|
||||
|
||||
|
||||
class CharField(Field):
|
||||
description = _("String (up to %(max_length)s)")
|
||||
|
|
|
@ -4,7 +4,7 @@ import re
|
|||
from functools import partial
|
||||
from itertools import chain
|
||||
|
||||
from django.core.exceptions import EmptyResultSet, FieldError
|
||||
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
|
||||
from django.db import DatabaseError, NotSupportedError
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value
|
||||
|
@ -169,7 +169,7 @@ class SQLCompiler:
|
|||
expr = Ref(alias, expr)
|
||||
try:
|
||||
sql, params = self.compile(expr)
|
||||
except EmptyResultSet:
|
||||
except (EmptyResultSet, FullResultSet):
|
||||
continue
|
||||
sql, params = expr.select_format(self, sql, params)
|
||||
params_hash = make_hashable(params)
|
||||
|
@ -287,6 +287,8 @@ class SQLCompiler:
|
|||
sql, params = "0", ()
|
||||
else:
|
||||
sql, params = self.compile(Value(empty_result_set_value))
|
||||
except FullResultSet:
|
||||
sql, params = self.compile(Value(True))
|
||||
else:
|
||||
sql, params = col.select_format(self, sql, params)
|
||||
if alias is None and with_col_aliases:
|
||||
|
@ -721,9 +723,16 @@ class SQLCompiler:
|
|||
raise
|
||||
# Use a predicate that's always False.
|
||||
where, w_params = "0 = 1", []
|
||||
having, h_params = (
|
||||
self.compile(self.having) if self.having is not None else ("", [])
|
||||
)
|
||||
except FullResultSet:
|
||||
where, w_params = "", []
|
||||
try:
|
||||
having, h_params = (
|
||||
self.compile(self.having)
|
||||
if self.having is not None
|
||||
else ("", [])
|
||||
)
|
||||
except FullResultSet:
|
||||
having, h_params = "", []
|
||||
result = ["SELECT"]
|
||||
params = []
|
||||
|
||||
|
@ -1817,11 +1826,12 @@ class SQLDeleteCompiler(SQLCompiler):
|
|||
)
|
||||
|
||||
def _as_sql(self, query):
|
||||
result = ["DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)]
|
||||
where, params = self.compile(query.where)
|
||||
if where:
|
||||
result.append("WHERE %s" % where)
|
||||
return " ".join(result), tuple(params)
|
||||
delete = "DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)
|
||||
try:
|
||||
where, params = self.compile(query.where)
|
||||
except FullResultSet:
|
||||
return delete, ()
|
||||
return f"{delete} WHERE {where}", tuple(params)
|
||||
|
||||
def as_sql(self):
|
||||
"""
|
||||
|
@ -1906,8 +1916,11 @@ class SQLUpdateCompiler(SQLCompiler):
|
|||
"UPDATE %s SET" % qn(table),
|
||||
", ".join(values),
|
||||
]
|
||||
where, params = self.compile(self.query.where)
|
||||
if where:
|
||||
try:
|
||||
where, params = self.compile(self.query.where)
|
||||
except FullResultSet:
|
||||
params = []
|
||||
else:
|
||||
result.append("WHERE %s" % where)
|
||||
return " ".join(result), tuple(update_params + params)
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
Useful auxiliary data structures for query construction. Not useful outside
|
||||
the SQL domain.
|
||||
"""
|
||||
from django.core.exceptions import FullResultSet
|
||||
from django.db.models.sql.constants import INNER, LOUTER
|
||||
|
||||
|
||||
|
@ -100,8 +101,11 @@ class Join:
|
|||
join_conditions.append("(%s)" % extra_sql)
|
||||
params.extend(extra_params)
|
||||
if self.filtered_relation:
|
||||
extra_sql, extra_params = compiler.compile(self.filtered_relation)
|
||||
if extra_sql:
|
||||
try:
|
||||
extra_sql, extra_params = compiler.compile(self.filtered_relation)
|
||||
except FullResultSet:
|
||||
pass
|
||||
else:
|
||||
join_conditions.append("(%s)" % extra_sql)
|
||||
params.extend(extra_params)
|
||||
if not join_conditions:
|
||||
|
|
|
@ -4,7 +4,7 @@ Code to manage the creation and SQL rendering of 'where' constraints.
|
|||
import operator
|
||||
from functools import reduce
|
||||
|
||||
from django.core.exceptions import EmptyResultSet
|
||||
from django.core.exceptions import EmptyResultSet, FullResultSet
|
||||
from django.db.models.expressions import Case, When
|
||||
from django.db.models.lookups import Exact
|
||||
from django.utils import tree
|
||||
|
@ -145,6 +145,8 @@ class WhereNode(tree.Node):
|
|||
sql, params = compiler.compile(child)
|
||||
except EmptyResultSet:
|
||||
empty_needed -= 1
|
||||
except FullResultSet:
|
||||
full_needed -= 1
|
||||
else:
|
||||
if sql:
|
||||
result.append(sql)
|
||||
|
@ -158,24 +160,25 @@ class WhereNode(tree.Node):
|
|||
# counts.
|
||||
if empty_needed == 0:
|
||||
if self.negated:
|
||||
return "", []
|
||||
raise FullResultSet
|
||||
else:
|
||||
raise EmptyResultSet
|
||||
if full_needed == 0:
|
||||
if self.negated:
|
||||
raise EmptyResultSet
|
||||
else:
|
||||
return "", []
|
||||
raise FullResultSet
|
||||
conn = " %s " % self.connector
|
||||
sql_string = conn.join(result)
|
||||
if sql_string:
|
||||
if self.negated:
|
||||
# Some backends (Oracle at least) need parentheses
|
||||
# around the inner SQL in the negated case, even if the
|
||||
# inner SQL contains just a single expression.
|
||||
sql_string = "NOT (%s)" % sql_string
|
||||
elif len(result) > 1 or self.resolved:
|
||||
sql_string = "(%s)" % sql_string
|
||||
if not sql_string:
|
||||
raise FullResultSet
|
||||
if self.negated:
|
||||
# Some backends (Oracle at least) need parentheses around the inner
|
||||
# SQL in the negated case, even if the inner SQL contains just a
|
||||
# single expression.
|
||||
sql_string = "NOT (%s)" % sql_string
|
||||
elif len(result) > 1 or self.resolved:
|
||||
sql_string = "(%s)" % sql_string
|
||||
return sql_string, result_params
|
||||
|
||||
def get_group_by_cols(self):
|
||||
|
|
|
@ -42,6 +42,17 @@ Django core exception classes are defined in ``django.core.exceptions``.
|
|||
return any results. Most Django projects won't encounter this exception,
|
||||
but it might be useful for implementing custom lookups and expressions.
|
||||
|
||||
``FullResultSet``
|
||||
-----------------
|
||||
|
||||
.. exception:: FullResultSet
|
||||
|
||||
.. versionadded:: 4.2
|
||||
|
||||
``FullResultSet`` may be raised during query generation if a query will
|
||||
match everything. Most Django projects won't encounter this exception, but
|
||||
it might be useful for implementing custom lookups and expressions.
|
||||
|
||||
``FieldDoesNotExist``
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -24,7 +24,15 @@ from django.db.models import (
|
|||
When,
|
||||
)
|
||||
from django.db.models.expressions import RawSQL
|
||||
from django.db.models.functions import Coalesce, ExtractYear, Floor, Length, Lower, Trim
|
||||
from django.db.models.functions import (
|
||||
Cast,
|
||||
Coalesce,
|
||||
ExtractYear,
|
||||
Floor,
|
||||
Length,
|
||||
Lower,
|
||||
Trim,
|
||||
)
|
||||
from django.test import TestCase, skipUnlessDBFeature
|
||||
from django.test.utils import register_lookup
|
||||
|
||||
|
@ -282,6 +290,13 @@ class NonAggregateAnnotationTestCase(TestCase):
|
|||
self.assertEqual(len(books), Book.objects.count())
|
||||
self.assertTrue(all(book.selected for book in books))
|
||||
|
||||
def test_full_expression_wrapped_annotation(self):
|
||||
books = Book.objects.annotate(
|
||||
selected=Coalesce(~Q(pk__in=[]), True),
|
||||
)
|
||||
self.assertEqual(len(books), Book.objects.count())
|
||||
self.assertTrue(all(book.selected for book in books))
|
||||
|
||||
def test_full_expression_annotation_with_aggregation(self):
|
||||
qs = Book.objects.filter(isbn="159059725").annotate(
|
||||
selected=ExpressionWrapper(~Q(pk__in=[]), output_field=BooleanField()),
|
||||
|
@ -292,7 +307,7 @@ class NonAggregateAnnotationTestCase(TestCase):
|
|||
def test_aggregate_over_full_expression_annotation(self):
|
||||
qs = Book.objects.annotate(
|
||||
selected=ExpressionWrapper(~Q(pk__in=[]), output_field=BooleanField()),
|
||||
).aggregate(Sum("selected"))
|
||||
).aggregate(selected__sum=Sum(Cast("selected", IntegerField())))
|
||||
self.assertEqual(qs["selected__sum"], Book.objects.count())
|
||||
|
||||
def test_empty_queryset_annotation(self):
|
||||
|
|
|
@ -5,7 +5,7 @@ import unittest
|
|||
from operator import attrgetter
|
||||
from threading import Lock
|
||||
|
||||
from django.core.exceptions import EmptyResultSet, FieldError
|
||||
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
|
||||
from django.db import DEFAULT_DB_ALIAS, connection
|
||||
from django.db.models import CharField, Count, Exists, F, Max, OuterRef, Q
|
||||
from django.db.models.expressions import RawSQL
|
||||
|
@ -3588,7 +3588,8 @@ class WhereNodeTest(SimpleTestCase):
|
|||
with self.assertRaises(EmptyResultSet):
|
||||
w.as_sql(compiler, connection)
|
||||
w.negate()
|
||||
self.assertEqual(w.as_sql(compiler, connection), ("", []))
|
||||
with self.assertRaises(FullResultSet):
|
||||
w.as_sql(compiler, connection)
|
||||
w = WhereNode(children=[self.DummyNode(), self.DummyNode()])
|
||||
self.assertEqual(w.as_sql(compiler, connection), ("(dummy AND dummy)", []))
|
||||
w.negate()
|
||||
|
@ -3597,7 +3598,8 @@ class WhereNodeTest(SimpleTestCase):
|
|||
with self.assertRaises(EmptyResultSet):
|
||||
w.as_sql(compiler, connection)
|
||||
w.negate()
|
||||
self.assertEqual(w.as_sql(compiler, connection), ("", []))
|
||||
with self.assertRaises(FullResultSet):
|
||||
w.as_sql(compiler, connection)
|
||||
|
||||
def test_empty_full_handling_disjunction(self):
|
||||
compiler = WhereNodeTest.MockCompiler()
|
||||
|
@ -3605,7 +3607,8 @@ class WhereNodeTest(SimpleTestCase):
|
|||
with self.assertRaises(EmptyResultSet):
|
||||
w.as_sql(compiler, connection)
|
||||
w.negate()
|
||||
self.assertEqual(w.as_sql(compiler, connection), ("", []))
|
||||
with self.assertRaises(FullResultSet):
|
||||
w.as_sql(compiler, connection)
|
||||
w = WhereNode(children=[self.DummyNode(), self.DummyNode()], connector=OR)
|
||||
self.assertEqual(w.as_sql(compiler, connection), ("(dummy OR dummy)", []))
|
||||
w.negate()
|
||||
|
@ -3619,7 +3622,8 @@ class WhereNodeTest(SimpleTestCase):
|
|||
compiler = WhereNodeTest.MockCompiler()
|
||||
empty_w = WhereNode()
|
||||
w = WhereNode(children=[empty_w, empty_w])
|
||||
self.assertEqual(w.as_sql(compiler, connection), ("", []))
|
||||
with self.assertRaises(FullResultSet):
|
||||
w.as_sql(compiler, connection)
|
||||
w.negate()
|
||||
with self.assertRaises(EmptyResultSet):
|
||||
w.as_sql(compiler, connection)
|
||||
|
@ -3627,9 +3631,11 @@ class WhereNodeTest(SimpleTestCase):
|
|||
with self.assertRaises(EmptyResultSet):
|
||||
w.as_sql(compiler, connection)
|
||||
w.negate()
|
||||
self.assertEqual(w.as_sql(compiler, connection), ("", []))
|
||||
with self.assertRaises(FullResultSet):
|
||||
w.as_sql(compiler, connection)
|
||||
w = WhereNode(children=[empty_w, NothingNode()], connector=OR)
|
||||
self.assertEqual(w.as_sql(compiler, connection), ("", []))
|
||||
with self.assertRaises(FullResultSet):
|
||||
w.as_sql(compiler, connection)
|
||||
w = WhereNode(children=[empty_w, NothingNode()], connector=AND)
|
||||
with self.assertRaises(EmptyResultSet):
|
||||
w.as_sql(compiler, connection)
|
||||
|
|
Loading…
Reference in New Issue