Fixed #16211 -- Added logical NOT support to F expressions.

This commit is contained in:
David Wobrock 2022-09-26 22:59:25 +02:00 committed by Mariusz Felisiak
parent c01e76c95c
commit a320aab512
7 changed files with 164 additions and 29 deletions

View File

@ -162,6 +162,9 @@ class Combinable:
"Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
)
def __invert__(self):
return NegatedExpression(self)
class BaseExpression:
"""Base class for all query expressions."""
@ -827,6 +830,9 @@ class F(Combinable):
def __hash__(self):
return hash(self.name)
def copy(self):
return copy.copy(self)
class ResolvedOuterRef(F):
"""
@ -1252,6 +1258,57 @@ class ExpressionWrapper(SQLiteNumericMixin, Expression):
return "{}({})".format(self.__class__.__name__, self.expression)
class NegatedExpression(ExpressionWrapper):
"""The logical negation of a conditional expression."""
def __init__(self, expression):
super().__init__(expression, output_field=fields.BooleanField())
def __invert__(self):
return self.expression.copy()
def as_sql(self, compiler, connection):
try:
sql, params = super().as_sql(compiler, connection)
except EmptyResultSet:
features = compiler.connection.features
if not features.supports_boolean_expr_in_select_clause:
return "1=1", ()
return compiler.compile(Value(True))
ops = compiler.connection.ops
# Some database backends (e.g. Oracle) don't allow EXISTS() and filters
# to be compared to another expression unless they're wrapped in a CASE
# WHEN.
if not ops.conditional_expression_supported_in_where_clause(self.expression):
return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params
return f"NOT {sql}", params
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
resolved = super().resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
if not getattr(resolved.expression, "conditional", False):
raise TypeError("Cannot negate non-conditional expressions.")
return resolved
def select_format(self, compiler, sql, params):
# Wrap boolean expressions with a CASE WHEN expression if a database
# backend (e.g. Oracle) doesn't support boolean expression in SELECT or
# GROUP BY list.
expression_supported_in_where_clause = (
compiler.connection.ops.conditional_expression_supported_in_where_clause
)
if (
not compiler.connection.features.supports_boolean_expr_in_select_clause
# Avoid double wrapping.
and expression_supported_in_where_clause(self.expression)
):
sql = "CASE WHEN {} THEN 1 ELSE 0 END".format(sql)
return sql, params
@deconstructible(path="django.db.models.When")
class When(Expression):
template = "WHEN %(condition)s THEN %(result)s"
@ -1486,34 +1543,10 @@ class Exists(Subquery):
template = "EXISTS(%(subquery)s)"
output_field = fields.BooleanField()
def __init__(self, queryset, negated=False, **kwargs):
self.negated = negated
def __init__(self, queryset, **kwargs):
super().__init__(queryset, **kwargs)
self.query = self.query.exists()
def __invert__(self):
clone = self.copy()
clone.negated = not self.negated
return clone
def as_sql(self, compiler, connection, **extra_context):
try:
sql, params = super().as_sql(
compiler,
connection,
**extra_context,
)
except EmptyResultSet:
if self.negated:
features = compiler.connection.features
if not features.supports_boolean_expr_in_select_clause:
return "1=1", ()
return compiler.compile(Value(True))
raise
if self.negated:
sql = "NOT {}".format(sql)
return sql, params
def select_format(self, compiler, sql, params):
# Wrap EXISTS() with a CASE WHEN expression if a database backend
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP

View File

@ -255,6 +255,19 @@ is null) after companies that have been contacted::
from django.db.models import F
Company.objects.order_by(F('last_contacted').desc(nulls_last=True))
Using ``F()`` with logical operations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. versionadded:: 4.2
``F()`` expressions that output ``BooleanField`` can be logically negated with
the inversion operator ``~F()``. For example, to swap the activation status of
companies::
from django.db.models import F
Company.objects.update(is_active=~F('is_active'))
.. _func-expressions:
``Func()`` expressions

View File

@ -236,6 +236,9 @@ Models
* :class:`~django.db.models.functions.Now` now supports microsecond precision
on MySQL and millisecond precision on SQLite.
* :class:`F() <django.db.models.F>` expressions that output ``BooleanField``
can now be negated using ``~F()`` (inversion operator).
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~
@ -345,6 +348,9 @@ Miscellaneous
* The minimum supported version of ``sqlparse`` is increased from 0.2.2 to
0.2.3.
* The undocumented ``negated`` parameter of the
:class:`~django.db.models.Exists` expression is removed.
.. _deprecated-features-4.2:
Features deprecated in 4.2

View File

@ -48,6 +48,7 @@ from django.db.models.expressions import (
Col,
Combinable,
CombinedExpression,
NegatedExpression,
RawSQL,
Ref,
)
@ -2536,6 +2537,61 @@ class ExpressionWrapperTests(SimpleTestCase):
self.assertEqual(group_by_cols[0].output_field, expr.output_field)
class NegatedExpressionTests(TestCase):
@classmethod
def setUpTestData(cls):
ceo = Employee.objects.create(firstname="Joe", lastname="Smith", salary=10)
cls.eu_company = Company.objects.create(
name="Example Inc.",
num_employees=2300,
num_chairs=5,
ceo=ceo,
based_in_eu=True,
)
cls.non_eu_company = Company.objects.create(
name="Foobar Ltd.",
num_employees=3,
num_chairs=4,
ceo=ceo,
based_in_eu=False,
)
def test_invert(self):
f = F("field")
self.assertEqual(~f, NegatedExpression(f))
self.assertIsNot(~~f, f)
self.assertEqual(~~f, f)
def test_filter(self):
self.assertSequenceEqual(
Company.objects.filter(~F("based_in_eu")),
[self.non_eu_company],
)
qs = Company.objects.annotate(eu_required=~Value(False))
self.assertSequenceEqual(
qs.filter(based_in_eu=F("eu_required")).order_by("eu_required"),
[self.eu_company],
)
self.assertSequenceEqual(
qs.filter(based_in_eu=~~F("eu_required")),
[self.eu_company],
)
self.assertSequenceEqual(
qs.filter(based_in_eu=~F("eu_required")),
[self.non_eu_company],
)
self.assertSequenceEqual(qs.filter(based_in_eu=~F("based_in_eu")), [])
def test_values(self):
self.assertSequenceEqual(
Company.objects.annotate(negated=~F("based_in_eu"))
.values_list("name", "negated")
.order_by("name"),
[("Example Inc.", False), ("Foobar Ltd.", True)],
)
class OrderByTests(SimpleTestCase):
def test_equal(self):
self.assertEqual(

View File

@ -8,7 +8,7 @@ from django.db.models import (
Q,
Value,
)
from django.db.models.expressions import RawSQL
from django.db.models.expressions import NegatedExpression, RawSQL
from django.db.models.functions import Lower
from django.db.models.sql.where import NothingNode
from django.test import SimpleTestCase, TestCase
@ -87,7 +87,7 @@ class QTests(SimpleTestCase):
]
for q in tests:
with self.subTest(q=q):
self.assertIs(q.negated, True)
self.assertIsInstance(q, NegatedExpression)
def test_deconstruct(self):
q = Q(price__gt=F("discounted_price"))

View File

@ -10,6 +10,7 @@ class DataPoint(models.Model):
name = models.CharField(max_length=20)
value = models.CharField(max_length=20)
another_value = models.CharField(max_length=20, blank=True)
is_active = models.BooleanField(default=True)
class RelatedPoint(models.Model):

View File

@ -2,7 +2,7 @@ import unittest
from django.core.exceptions import FieldError
from django.db import IntegrityError, connection, transaction
from django.db.models import CharField, Count, F, IntegerField, Max
from django.db.models import Case, CharField, Count, F, IntegerField, Max, When
from django.db.models.functions import Abs, Concat, Lower
from django.test import TestCase
from django.test.utils import register_lookup
@ -81,7 +81,7 @@ class AdvancedTests(TestCase):
def setUpTestData(cls):
cls.d0 = DataPoint.objects.create(name="d0", value="apple")
cls.d2 = DataPoint.objects.create(name="d2", value="banana")
cls.d3 = DataPoint.objects.create(name="d3", value="banana")
cls.d3 = DataPoint.objects.create(name="d3", value="banana", is_active=False)
cls.r1 = RelatedPoint.objects.create(name="r1", data=cls.d3)
def test_update(self):
@ -249,6 +249,32 @@ class AdvancedTests(TestCase):
Bar.objects.annotate(abs_id=Abs("m2m_foo")).order_by("abs_id").update(x=3)
self.assertEqual(Bar.objects.get().x, 3)
def test_update_negated_f(self):
DataPoint.objects.update(is_active=~F("is_active"))
self.assertCountEqual(
DataPoint.objects.values_list("name", "is_active"),
[("d0", False), ("d2", False), ("d3", True)],
)
DataPoint.objects.update(is_active=~F("is_active"))
self.assertCountEqual(
DataPoint.objects.values_list("name", "is_active"),
[("d0", True), ("d2", True), ("d3", False)],
)
def test_update_negated_f_conditional_annotation(self):
DataPoint.objects.annotate(
is_d2=Case(When(name="d2", then=True), default=False)
).update(is_active=~F("is_d2"))
self.assertCountEqual(
DataPoint.objects.values_list("name", "is_active"),
[("d0", True), ("d2", False), ("d3", True)],
)
def test_updating_non_conditional_field(self):
msg = "Cannot negate non-conditional expressions."
with self.assertRaisesMessage(TypeError, msg):
DataPoint.objects.update(is_active=~F("name"))
@unittest.skipUnless(
connection.vendor == "mysql",