From a320aab5129f4019b3c1d28b7a3b509582bc56f9 Mon Sep 17 00:00:00 2001 From: David Wobrock Date: Mon, 26 Sep 2022 22:59:25 +0200 Subject: [PATCH] Fixed #16211 -- Added logical NOT support to F expressions. --- django/db/models/expressions.py | 83 +++++++++++++++++++++++---------- docs/ref/models/expressions.txt | 13 ++++++ docs/releases/4.2.txt | 6 +++ tests/expressions/tests.py | 56 ++++++++++++++++++++++ tests/queries/test_q.py | 4 +- tests/update/models.py | 1 + tests/update/tests.py | 30 +++++++++++- 7 files changed, 164 insertions(+), 29 deletions(-) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 6c3bc5c4ded..8b04e1f11bc 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -162,6 +162,9 @@ class Combinable: "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations." ) + def __invert__(self): + return NegatedExpression(self) + class BaseExpression: """Base class for all query expressions.""" @@ -827,6 +830,9 @@ class F(Combinable): def __hash__(self): return hash(self.name) + def copy(self): + return copy.copy(self) + class ResolvedOuterRef(F): """ @@ -1252,6 +1258,57 @@ class ExpressionWrapper(SQLiteNumericMixin, Expression): return "{}({})".format(self.__class__.__name__, self.expression) +class NegatedExpression(ExpressionWrapper): + """The logical negation of a conditional expression.""" + + def __init__(self, expression): + super().__init__(expression, output_field=fields.BooleanField()) + + def __invert__(self): + return self.expression.copy() + + def as_sql(self, compiler, connection): + try: + sql, params = super().as_sql(compiler, connection) + except EmptyResultSet: + features = compiler.connection.features + if not features.supports_boolean_expr_in_select_clause: + return "1=1", () + return compiler.compile(Value(True)) + ops = compiler.connection.ops + # Some database backends (e.g. Oracle) don't allow EXISTS() and filters + # to be compared to another expression unless they're wrapped in a CASE + # WHEN. + if not ops.conditional_expression_supported_in_where_clause(self.expression): + return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params + return f"NOT {sql}", params + + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): + resolved = super().resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) + if not getattr(resolved.expression, "conditional", False): + raise TypeError("Cannot negate non-conditional expressions.") + return resolved + + def select_format(self, compiler, sql, params): + # Wrap boolean expressions with a CASE WHEN expression if a database + # backend (e.g. Oracle) doesn't support boolean expression in SELECT or + # GROUP BY list. + expression_supported_in_where_clause = ( + compiler.connection.ops.conditional_expression_supported_in_where_clause + ) + if ( + not compiler.connection.features.supports_boolean_expr_in_select_clause + # Avoid double wrapping. + and expression_supported_in_where_clause(self.expression) + ): + sql = "CASE WHEN {} THEN 1 ELSE 0 END".format(sql) + return sql, params + + @deconstructible(path="django.db.models.When") class When(Expression): template = "WHEN %(condition)s THEN %(result)s" @@ -1486,34 +1543,10 @@ class Exists(Subquery): template = "EXISTS(%(subquery)s)" output_field = fields.BooleanField() - def __init__(self, queryset, negated=False, **kwargs): - self.negated = negated + def __init__(self, queryset, **kwargs): super().__init__(queryset, **kwargs) self.query = self.query.exists() - def __invert__(self): - clone = self.copy() - clone.negated = not self.negated - return clone - - def as_sql(self, compiler, connection, **extra_context): - try: - sql, params = super().as_sql( - compiler, - connection, - **extra_context, - ) - except EmptyResultSet: - if self.negated: - features = compiler.connection.features - if not features.supports_boolean_expr_in_select_clause: - return "1=1", () - return compiler.compile(Value(True)) - raise - if self.negated: - sql = "NOT {}".format(sql) - return sql, params - def select_format(self, compiler, sql, params): # Wrap EXISTS() with a CASE WHEN expression if a database backend # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index ccd18670b1a..e1edad2e4dd 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -255,6 +255,19 @@ is null) after companies that have been contacted:: from django.db.models import F Company.objects.order_by(F('last_contacted').desc(nulls_last=True)) +Using ``F()`` with logical operations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. versionadded:: 4.2 + +``F()`` expressions that output ``BooleanField`` can be logically negated with +the inversion operator ``~F()``. For example, to swap the activation status of +companies:: + + from django.db.models import F + + Company.objects.update(is_active=~F('is_active')) + .. _func-expressions: ``Func()`` expressions diff --git a/docs/releases/4.2.txt b/docs/releases/4.2.txt index 58e4bd12ac1..4f7bda70807 100644 --- a/docs/releases/4.2.txt +++ b/docs/releases/4.2.txt @@ -236,6 +236,9 @@ Models * :class:`~django.db.models.functions.Now` now supports microsecond precision on MySQL and millisecond precision on SQLite. +* :class:`F() ` expressions that output ``BooleanField`` + can now be negated using ``~F()`` (inversion operator). + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ @@ -345,6 +348,9 @@ Miscellaneous * The minimum supported version of ``sqlparse`` is increased from 0.2.2 to 0.2.3. +* The undocumented ``negated`` parameter of the + :class:`~django.db.models.Exists` expression is removed. + .. _deprecated-features-4.2: Features deprecated in 4.2 diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index fd0094db63c..465edc54b56 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -48,6 +48,7 @@ from django.db.models.expressions import ( Col, Combinable, CombinedExpression, + NegatedExpression, RawSQL, Ref, ) @@ -2536,6 +2537,61 @@ class ExpressionWrapperTests(SimpleTestCase): self.assertEqual(group_by_cols[0].output_field, expr.output_field) +class NegatedExpressionTests(TestCase): + @classmethod + def setUpTestData(cls): + ceo = Employee.objects.create(firstname="Joe", lastname="Smith", salary=10) + cls.eu_company = Company.objects.create( + name="Example Inc.", + num_employees=2300, + num_chairs=5, + ceo=ceo, + based_in_eu=True, + ) + cls.non_eu_company = Company.objects.create( + name="Foobar Ltd.", + num_employees=3, + num_chairs=4, + ceo=ceo, + based_in_eu=False, + ) + + def test_invert(self): + f = F("field") + self.assertEqual(~f, NegatedExpression(f)) + self.assertIsNot(~~f, f) + self.assertEqual(~~f, f) + + def test_filter(self): + self.assertSequenceEqual( + Company.objects.filter(~F("based_in_eu")), + [self.non_eu_company], + ) + + qs = Company.objects.annotate(eu_required=~Value(False)) + self.assertSequenceEqual( + qs.filter(based_in_eu=F("eu_required")).order_by("eu_required"), + [self.eu_company], + ) + self.assertSequenceEqual( + qs.filter(based_in_eu=~~F("eu_required")), + [self.eu_company], + ) + self.assertSequenceEqual( + qs.filter(based_in_eu=~F("eu_required")), + [self.non_eu_company], + ) + self.assertSequenceEqual(qs.filter(based_in_eu=~F("based_in_eu")), []) + + def test_values(self): + self.assertSequenceEqual( + Company.objects.annotate(negated=~F("based_in_eu")) + .values_list("name", "negated") + .order_by("name"), + [("Example Inc.", False), ("Foobar Ltd.", True)], + ) + + class OrderByTests(SimpleTestCase): def test_equal(self): self.assertEqual( diff --git a/tests/queries/test_q.py b/tests/queries/test_q.py index 923846b5a3d..cdf40292b06 100644 --- a/tests/queries/test_q.py +++ b/tests/queries/test_q.py @@ -8,7 +8,7 @@ from django.db.models import ( Q, Value, ) -from django.db.models.expressions import RawSQL +from django.db.models.expressions import NegatedExpression, RawSQL from django.db.models.functions import Lower from django.db.models.sql.where import NothingNode from django.test import SimpleTestCase, TestCase @@ -87,7 +87,7 @@ class QTests(SimpleTestCase): ] for q in tests: with self.subTest(q=q): - self.assertIs(q.negated, True) + self.assertIsInstance(q, NegatedExpression) def test_deconstruct(self): q = Q(price__gt=F("discounted_price")) diff --git a/tests/update/models.py b/tests/update/models.py index d7452dc3021..d71fc887c7d 100644 --- a/tests/update/models.py +++ b/tests/update/models.py @@ -10,6 +10,7 @@ class DataPoint(models.Model): name = models.CharField(max_length=20) value = models.CharField(max_length=20) another_value = models.CharField(max_length=20, blank=True) + is_active = models.BooleanField(default=True) class RelatedPoint(models.Model): diff --git a/tests/update/tests.py b/tests/update/tests.py index 2162f5164d4..e88eeda96dd 100644 --- a/tests/update/tests.py +++ b/tests/update/tests.py @@ -2,7 +2,7 @@ import unittest from django.core.exceptions import FieldError from django.db import IntegrityError, connection, transaction -from django.db.models import CharField, Count, F, IntegerField, Max +from django.db.models import Case, CharField, Count, F, IntegerField, Max, When from django.db.models.functions import Abs, Concat, Lower from django.test import TestCase from django.test.utils import register_lookup @@ -81,7 +81,7 @@ class AdvancedTests(TestCase): def setUpTestData(cls): cls.d0 = DataPoint.objects.create(name="d0", value="apple") cls.d2 = DataPoint.objects.create(name="d2", value="banana") - cls.d3 = DataPoint.objects.create(name="d3", value="banana") + cls.d3 = DataPoint.objects.create(name="d3", value="banana", is_active=False) cls.r1 = RelatedPoint.objects.create(name="r1", data=cls.d3) def test_update(self): @@ -249,6 +249,32 @@ class AdvancedTests(TestCase): Bar.objects.annotate(abs_id=Abs("m2m_foo")).order_by("abs_id").update(x=3) self.assertEqual(Bar.objects.get().x, 3) + def test_update_negated_f(self): + DataPoint.objects.update(is_active=~F("is_active")) + self.assertCountEqual( + DataPoint.objects.values_list("name", "is_active"), + [("d0", False), ("d2", False), ("d3", True)], + ) + DataPoint.objects.update(is_active=~F("is_active")) + self.assertCountEqual( + DataPoint.objects.values_list("name", "is_active"), + [("d0", True), ("d2", True), ("d3", False)], + ) + + def test_update_negated_f_conditional_annotation(self): + DataPoint.objects.annotate( + is_d2=Case(When(name="d2", then=True), default=False) + ).update(is_active=~F("is_d2")) + self.assertCountEqual( + DataPoint.objects.values_list("name", "is_active"), + [("d0", True), ("d2", False), ("d3", True)], + ) + + def test_updating_non_conditional_field(self): + msg = "Cannot negate non-conditional expressions." + with self.assertRaisesMessage(TypeError, msg): + DataPoint.objects.update(is_active=~F("name")) + @unittest.skipUnless( connection.vendor == "mysql",