Fixed #16211 -- Implemented negating an F()-expression.

This commit is contained in:
David Wobrock 2022-09-26 22:59:25 +02:00
parent d795259ea9
commit fb3fccd196
No known key found for this signature in database
GPG Key ID: 4885899CFD92B563
7 changed files with 92 additions and 17 deletions

View File

@ -162,6 +162,9 @@ class Combinable:
"Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
)
def __invert__(self):
return NegatedExpression(self)
class BaseExpression:
"""Base class for all query expressions."""
@ -1252,6 +1255,38 @@ class ExpressionWrapper(SQLiteNumericMixin, Expression):
return "{}({})".format(self.__class__.__name__, self.expression)
class NegatedExpression(ExpressionWrapper):
"""
Wrapping an expression to negate its output.
"""
def __init__(self, expression):
super().__init__(expression, output_field=fields.BooleanField())
def __invert__(self):
return copy.copy(self.expression)
def as_sql(self, compiler, connection):
try:
sql, params = super().as_sql(compiler, connection)
except EmptyResultSet:
features = compiler.connection.features
if not features.supports_boolean_expr_in_select_clause:
return "1=1", ()
return compiler.compile(Value(True))
return f"NOT {sql}", params
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
resolved_expr = super().resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
if not resolved_expr.expression.conditional:
raise TypeError("Cannot negate non-conditional expression.")
return resolved_expr
@deconstructible(path="django.db.models.When")
class When(Expression):
template = "WHEN %(condition)s THEN %(result)s"
@ -1486,16 +1521,10 @@ class Exists(Subquery):
template = "EXISTS(%(subquery)s)"
output_field = fields.BooleanField()
def __init__(self, queryset, negated=False, **kwargs):
self.negated = negated
def __init__(self, queryset, **kwargs):
super().__init__(queryset, **kwargs)
self.query = self.query.exists()
def __invert__(self):
clone = self.copy()
clone.negated = not self.negated
return clone
def as_sql(self, compiler, connection, **extra_context):
try:
sql, params = super().as_sql(
@ -1504,14 +1533,7 @@ class Exists(Subquery):
**extra_context,
)
except EmptyResultSet:
if self.negated:
features = compiler.connection.features
if not features.supports_boolean_expr_in_select_clause:
return "1=1", ()
return compiler.compile(Value(True))
raise
if self.negated:
sql = "NOT {}".format(sql)
return sql, params
def select_format(self, compiler, sql, params):

View File

@ -255,6 +255,15 @@ is null) after companies that have been contacted::
from django.db.models import F
Company.objects.order_by(F('last_contacted').desc(nulls_last=True))
.. versionchanged:: 4.2
Support for inverting expression ``~F()`` was added.
For example, to deactivate active models::
from django.db.models import F
Company.objects.update(is_active=~F('is_active')
.. _func-expressions:
``Func()`` expressions

View File

@ -228,6 +228,8 @@ Models
* :class:`~django.db.models.functions.Now` now supports microsecond precision
on MySQL and millisecond precision on SQLite.
* :class:`~django.db.models.F` expressions can now be inverted using ``~`` in
order to negate the expression.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -48,6 +48,7 @@ from django.db.models.expressions import (
Col,
Combinable,
CombinedExpression,
NegatedExpression,
RawSQL,
Ref,
)
@ -2381,6 +2382,11 @@ class CombinableTests(SimpleTestCase):
c = Combinable()
self.assertEqual(-c, c * -1)
def test_invert(self):
c = Combinable()
self.assertEqual(~c, NegatedExpression(c))
self.assertIsNot(~~c, c)
def test_and(self):
with self.assertRaisesMessage(NotImplementedError, self.bitwise_msg):
Combinable() & Combinable()
@ -2536,6 +2542,12 @@ class ExpressionWrapperTests(SimpleTestCase):
self.assertEqual(group_by_cols[0].output_field, expr.output_field)
class NegatedExpressionTests(SimpleTestCase):
def test_invert_negated_expression(self):
expr = NegatedExpression(F("field"))
self.assertEqual(~expr, F("field"))
class OrderByTests(SimpleTestCase):
def test_equal(self):
self.assertEqual(

View File

@ -8,7 +8,7 @@ from django.db.models import (
Q,
Value,
)
from django.db.models.expressions import RawSQL
from django.db.models.expressions import NegatedExpression, RawSQL
from django.db.models.functions import Lower
from django.db.models.sql.where import NothingNode
from django.test import SimpleTestCase, TestCase
@ -87,7 +87,7 @@ class QTests(SimpleTestCase):
]
for q in tests:
with self.subTest(q=q):
self.assertIs(q.negated, True)
self.assertIsInstance(q, NegatedExpression)
def test_deconstruct(self):
q = Q(price__gt=F("discounted_price"))

View File

@ -50,3 +50,8 @@ class UniqueNumber(models.Model):
class UniqueNumberChild(UniqueNumber):
pass
class CanBeActivatedModel(models.Model):
is_active = models.BooleanField()
x = models.IntegerField(default=0)

View File

@ -2,7 +2,7 @@ import unittest
from django.core.exceptions import FieldError
from django.db import IntegrityError, connection, transaction
from django.db.models import CharField, Count, F, IntegerField, Max
from django.db.models import Case, CharField, Count, F, IntegerField, Max, When
from django.db.models.functions import Abs, Concat, Lower
from django.test import TestCase
from django.test.utils import register_lookup
@ -11,6 +11,7 @@ from .models import (
A,
B,
Bar,
CanBeActivatedModel,
D,
DataPoint,
Foo,
@ -249,6 +250,30 @@ class AdvancedTests(TestCase):
Bar.objects.annotate(abs_id=Abs("m2m_foo")).order_by("abs_id").update(x=3)
self.assertEqual(Bar.objects.get().x, 3)
def test_update_negated_f(self):
CanBeActivatedModel.objects.create(is_active=True)
CanBeActivatedModel.objects.update(is_active=~F("is_active"))
self.assertIs(CanBeActivatedModel.objects.get().is_active, False)
CanBeActivatedModel.objects.update(is_active=~F("is_active"))
self.assertIs(CanBeActivatedModel.objects.get().is_active, True)
def test_update_negated_f_conditional_annotation(self):
CanBeActivatedModel.objects.create(x=2, is_active=True)
CanBeActivatedModel.objects.annotate(
is_x_positive=Case(When(x__gt=0, then=True), default=False)
).update(is_active=~F("is_x_positive"))
self.assertIs(CanBeActivatedModel.objects.get().is_active, False)
def test_updating_non_conditional_field(self):
CanBeActivatedModel.objects.create(is_active=True)
with self.assertRaisesMessage(
TypeError, "Cannot negate non-conditional expression."
):
CanBeActivatedModel.objects.update(is_active=~F("x"))
@unittest.skipUnless(
connection.vendor == "mysql",