mirror of https://github.com/django/django.git
Fixed #16211 -- Added logical NOT support to F expressions.
This commit is contained in:
parent
c01e76c95c
commit
a320aab512
|
@ -162,6 +162,9 @@ class Combinable:
|
||||||
"Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
|
"Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __invert__(self):
|
||||||
|
return NegatedExpression(self)
|
||||||
|
|
||||||
|
|
||||||
class BaseExpression:
|
class BaseExpression:
|
||||||
"""Base class for all query expressions."""
|
"""Base class for all query expressions."""
|
||||||
|
@ -827,6 +830,9 @@ class F(Combinable):
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(self.name)
|
return hash(self.name)
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
return copy.copy(self)
|
||||||
|
|
||||||
|
|
||||||
class ResolvedOuterRef(F):
|
class ResolvedOuterRef(F):
|
||||||
"""
|
"""
|
||||||
|
@ -1252,6 +1258,57 @@ class ExpressionWrapper(SQLiteNumericMixin, Expression):
|
||||||
return "{}({})".format(self.__class__.__name__, self.expression)
|
return "{}({})".format(self.__class__.__name__, self.expression)
|
||||||
|
|
||||||
|
|
||||||
|
class NegatedExpression(ExpressionWrapper):
|
||||||
|
"""The logical negation of a conditional expression."""
|
||||||
|
|
||||||
|
def __init__(self, expression):
|
||||||
|
super().__init__(expression, output_field=fields.BooleanField())
|
||||||
|
|
||||||
|
def __invert__(self):
|
||||||
|
return self.expression.copy()
|
||||||
|
|
||||||
|
def as_sql(self, compiler, connection):
|
||||||
|
try:
|
||||||
|
sql, params = super().as_sql(compiler, connection)
|
||||||
|
except EmptyResultSet:
|
||||||
|
features = compiler.connection.features
|
||||||
|
if not features.supports_boolean_expr_in_select_clause:
|
||||||
|
return "1=1", ()
|
||||||
|
return compiler.compile(Value(True))
|
||||||
|
ops = compiler.connection.ops
|
||||||
|
# Some database backends (e.g. Oracle) don't allow EXISTS() and filters
|
||||||
|
# to be compared to another expression unless they're wrapped in a CASE
|
||||||
|
# WHEN.
|
||||||
|
if not ops.conditional_expression_supported_in_where_clause(self.expression):
|
||||||
|
return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params
|
||||||
|
return f"NOT {sql}", params
|
||||||
|
|
||||||
|
def resolve_expression(
|
||||||
|
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||||
|
):
|
||||||
|
resolved = super().resolve_expression(
|
||||||
|
query, allow_joins, reuse, summarize, for_save
|
||||||
|
)
|
||||||
|
if not getattr(resolved.expression, "conditional", False):
|
||||||
|
raise TypeError("Cannot negate non-conditional expressions.")
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
def select_format(self, compiler, sql, params):
|
||||||
|
# Wrap boolean expressions with a CASE WHEN expression if a database
|
||||||
|
# backend (e.g. Oracle) doesn't support boolean expression in SELECT or
|
||||||
|
# GROUP BY list.
|
||||||
|
expression_supported_in_where_clause = (
|
||||||
|
compiler.connection.ops.conditional_expression_supported_in_where_clause
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
not compiler.connection.features.supports_boolean_expr_in_select_clause
|
||||||
|
# Avoid double wrapping.
|
||||||
|
and expression_supported_in_where_clause(self.expression)
|
||||||
|
):
|
||||||
|
sql = "CASE WHEN {} THEN 1 ELSE 0 END".format(sql)
|
||||||
|
return sql, params
|
||||||
|
|
||||||
|
|
||||||
@deconstructible(path="django.db.models.When")
|
@deconstructible(path="django.db.models.When")
|
||||||
class When(Expression):
|
class When(Expression):
|
||||||
template = "WHEN %(condition)s THEN %(result)s"
|
template = "WHEN %(condition)s THEN %(result)s"
|
||||||
|
@ -1486,34 +1543,10 @@ class Exists(Subquery):
|
||||||
template = "EXISTS(%(subquery)s)"
|
template = "EXISTS(%(subquery)s)"
|
||||||
output_field = fields.BooleanField()
|
output_field = fields.BooleanField()
|
||||||
|
|
||||||
def __init__(self, queryset, negated=False, **kwargs):
|
def __init__(self, queryset, **kwargs):
|
||||||
self.negated = negated
|
|
||||||
super().__init__(queryset, **kwargs)
|
super().__init__(queryset, **kwargs)
|
||||||
self.query = self.query.exists()
|
self.query = self.query.exists()
|
||||||
|
|
||||||
def __invert__(self):
|
|
||||||
clone = self.copy()
|
|
||||||
clone.negated = not self.negated
|
|
||||||
return clone
|
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, **extra_context):
|
|
||||||
try:
|
|
||||||
sql, params = super().as_sql(
|
|
||||||
compiler,
|
|
||||||
connection,
|
|
||||||
**extra_context,
|
|
||||||
)
|
|
||||||
except EmptyResultSet:
|
|
||||||
if self.negated:
|
|
||||||
features = compiler.connection.features
|
|
||||||
if not features.supports_boolean_expr_in_select_clause:
|
|
||||||
return "1=1", ()
|
|
||||||
return compiler.compile(Value(True))
|
|
||||||
raise
|
|
||||||
if self.negated:
|
|
||||||
sql = "NOT {}".format(sql)
|
|
||||||
return sql, params
|
|
||||||
|
|
||||||
def select_format(self, compiler, sql, params):
|
def select_format(self, compiler, sql, params):
|
||||||
# Wrap EXISTS() with a CASE WHEN expression if a database backend
|
# Wrap EXISTS() with a CASE WHEN expression if a database backend
|
||||||
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
|
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
|
||||||
|
|
|
@ -255,6 +255,19 @@ is null) after companies that have been contacted::
|
||||||
from django.db.models import F
|
from django.db.models import F
|
||||||
Company.objects.order_by(F('last_contacted').desc(nulls_last=True))
|
Company.objects.order_by(F('last_contacted').desc(nulls_last=True))
|
||||||
|
|
||||||
|
Using ``F()`` with logical operations
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. versionadded:: 4.2
|
||||||
|
|
||||||
|
``F()`` expressions that output ``BooleanField`` can be logically negated with
|
||||||
|
the inversion operator ``~F()``. For example, to swap the activation status of
|
||||||
|
companies::
|
||||||
|
|
||||||
|
from django.db.models import F
|
||||||
|
|
||||||
|
Company.objects.update(is_active=~F('is_active'))
|
||||||
|
|
||||||
.. _func-expressions:
|
.. _func-expressions:
|
||||||
|
|
||||||
``Func()`` expressions
|
``Func()`` expressions
|
||||||
|
|
|
@ -236,6 +236,9 @@ Models
|
||||||
* :class:`~django.db.models.functions.Now` now supports microsecond precision
|
* :class:`~django.db.models.functions.Now` now supports microsecond precision
|
||||||
on MySQL and millisecond precision on SQLite.
|
on MySQL and millisecond precision on SQLite.
|
||||||
|
|
||||||
|
* :class:`F() <django.db.models.F>` expressions that output ``BooleanField``
|
||||||
|
can now be negated using ``~F()`` (inversion operator).
|
||||||
|
|
||||||
Requests and Responses
|
Requests and Responses
|
||||||
~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
@ -345,6 +348,9 @@ Miscellaneous
|
||||||
* The minimum supported version of ``sqlparse`` is increased from 0.2.2 to
|
* The minimum supported version of ``sqlparse`` is increased from 0.2.2 to
|
||||||
0.2.3.
|
0.2.3.
|
||||||
|
|
||||||
|
* The undocumented ``negated`` parameter of the
|
||||||
|
:class:`~django.db.models.Exists` expression is removed.
|
||||||
|
|
||||||
.. _deprecated-features-4.2:
|
.. _deprecated-features-4.2:
|
||||||
|
|
||||||
Features deprecated in 4.2
|
Features deprecated in 4.2
|
||||||
|
|
|
@ -48,6 +48,7 @@ from django.db.models.expressions import (
|
||||||
Col,
|
Col,
|
||||||
Combinable,
|
Combinable,
|
||||||
CombinedExpression,
|
CombinedExpression,
|
||||||
|
NegatedExpression,
|
||||||
RawSQL,
|
RawSQL,
|
||||||
Ref,
|
Ref,
|
||||||
)
|
)
|
||||||
|
@ -2536,6 +2537,61 @@ class ExpressionWrapperTests(SimpleTestCase):
|
||||||
self.assertEqual(group_by_cols[0].output_field, expr.output_field)
|
self.assertEqual(group_by_cols[0].output_field, expr.output_field)
|
||||||
|
|
||||||
|
|
||||||
|
class NegatedExpressionTests(TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpTestData(cls):
|
||||||
|
ceo = Employee.objects.create(firstname="Joe", lastname="Smith", salary=10)
|
||||||
|
cls.eu_company = Company.objects.create(
|
||||||
|
name="Example Inc.",
|
||||||
|
num_employees=2300,
|
||||||
|
num_chairs=5,
|
||||||
|
ceo=ceo,
|
||||||
|
based_in_eu=True,
|
||||||
|
)
|
||||||
|
cls.non_eu_company = Company.objects.create(
|
||||||
|
name="Foobar Ltd.",
|
||||||
|
num_employees=3,
|
||||||
|
num_chairs=4,
|
||||||
|
ceo=ceo,
|
||||||
|
based_in_eu=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_invert(self):
|
||||||
|
f = F("field")
|
||||||
|
self.assertEqual(~f, NegatedExpression(f))
|
||||||
|
self.assertIsNot(~~f, f)
|
||||||
|
self.assertEqual(~~f, f)
|
||||||
|
|
||||||
|
def test_filter(self):
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
Company.objects.filter(~F("based_in_eu")),
|
||||||
|
[self.non_eu_company],
|
||||||
|
)
|
||||||
|
|
||||||
|
qs = Company.objects.annotate(eu_required=~Value(False))
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
qs.filter(based_in_eu=F("eu_required")).order_by("eu_required"),
|
||||||
|
[self.eu_company],
|
||||||
|
)
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
qs.filter(based_in_eu=~~F("eu_required")),
|
||||||
|
[self.eu_company],
|
||||||
|
)
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
qs.filter(based_in_eu=~F("eu_required")),
|
||||||
|
[self.non_eu_company],
|
||||||
|
)
|
||||||
|
self.assertSequenceEqual(qs.filter(based_in_eu=~F("based_in_eu")), [])
|
||||||
|
|
||||||
|
def test_values(self):
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
Company.objects.annotate(negated=~F("based_in_eu"))
|
||||||
|
.values_list("name", "negated")
|
||||||
|
.order_by("name"),
|
||||||
|
[("Example Inc.", False), ("Foobar Ltd.", True)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class OrderByTests(SimpleTestCase):
|
class OrderByTests(SimpleTestCase):
|
||||||
def test_equal(self):
|
def test_equal(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
|
|
@ -8,7 +8,7 @@ from django.db.models import (
|
||||||
Q,
|
Q,
|
||||||
Value,
|
Value,
|
||||||
)
|
)
|
||||||
from django.db.models.expressions import RawSQL
|
from django.db.models.expressions import NegatedExpression, RawSQL
|
||||||
from django.db.models.functions import Lower
|
from django.db.models.functions import Lower
|
||||||
from django.db.models.sql.where import NothingNode
|
from django.db.models.sql.where import NothingNode
|
||||||
from django.test import SimpleTestCase, TestCase
|
from django.test import SimpleTestCase, TestCase
|
||||||
|
@ -87,7 +87,7 @@ class QTests(SimpleTestCase):
|
||||||
]
|
]
|
||||||
for q in tests:
|
for q in tests:
|
||||||
with self.subTest(q=q):
|
with self.subTest(q=q):
|
||||||
self.assertIs(q.negated, True)
|
self.assertIsInstance(q, NegatedExpression)
|
||||||
|
|
||||||
def test_deconstruct(self):
|
def test_deconstruct(self):
|
||||||
q = Q(price__gt=F("discounted_price"))
|
q = Q(price__gt=F("discounted_price"))
|
||||||
|
|
|
@ -10,6 +10,7 @@ class DataPoint(models.Model):
|
||||||
name = models.CharField(max_length=20)
|
name = models.CharField(max_length=20)
|
||||||
value = models.CharField(max_length=20)
|
value = models.CharField(max_length=20)
|
||||||
another_value = models.CharField(max_length=20, blank=True)
|
another_value = models.CharField(max_length=20, blank=True)
|
||||||
|
is_active = models.BooleanField(default=True)
|
||||||
|
|
||||||
|
|
||||||
class RelatedPoint(models.Model):
|
class RelatedPoint(models.Model):
|
||||||
|
|
|
@ -2,7 +2,7 @@ import unittest
|
||||||
|
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db import IntegrityError, connection, transaction
|
from django.db import IntegrityError, connection, transaction
|
||||||
from django.db.models import CharField, Count, F, IntegerField, Max
|
from django.db.models import Case, CharField, Count, F, IntegerField, Max, When
|
||||||
from django.db.models.functions import Abs, Concat, Lower
|
from django.db.models.functions import Abs, Concat, Lower
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.test.utils import register_lookup
|
from django.test.utils import register_lookup
|
||||||
|
@ -81,7 +81,7 @@ class AdvancedTests(TestCase):
|
||||||
def setUpTestData(cls):
|
def setUpTestData(cls):
|
||||||
cls.d0 = DataPoint.objects.create(name="d0", value="apple")
|
cls.d0 = DataPoint.objects.create(name="d0", value="apple")
|
||||||
cls.d2 = DataPoint.objects.create(name="d2", value="banana")
|
cls.d2 = DataPoint.objects.create(name="d2", value="banana")
|
||||||
cls.d3 = DataPoint.objects.create(name="d3", value="banana")
|
cls.d3 = DataPoint.objects.create(name="d3", value="banana", is_active=False)
|
||||||
cls.r1 = RelatedPoint.objects.create(name="r1", data=cls.d3)
|
cls.r1 = RelatedPoint.objects.create(name="r1", data=cls.d3)
|
||||||
|
|
||||||
def test_update(self):
|
def test_update(self):
|
||||||
|
@ -249,6 +249,32 @@ class AdvancedTests(TestCase):
|
||||||
Bar.objects.annotate(abs_id=Abs("m2m_foo")).order_by("abs_id").update(x=3)
|
Bar.objects.annotate(abs_id=Abs("m2m_foo")).order_by("abs_id").update(x=3)
|
||||||
self.assertEqual(Bar.objects.get().x, 3)
|
self.assertEqual(Bar.objects.get().x, 3)
|
||||||
|
|
||||||
|
def test_update_negated_f(self):
|
||||||
|
DataPoint.objects.update(is_active=~F("is_active"))
|
||||||
|
self.assertCountEqual(
|
||||||
|
DataPoint.objects.values_list("name", "is_active"),
|
||||||
|
[("d0", False), ("d2", False), ("d3", True)],
|
||||||
|
)
|
||||||
|
DataPoint.objects.update(is_active=~F("is_active"))
|
||||||
|
self.assertCountEqual(
|
||||||
|
DataPoint.objects.values_list("name", "is_active"),
|
||||||
|
[("d0", True), ("d2", True), ("d3", False)],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_update_negated_f_conditional_annotation(self):
|
||||||
|
DataPoint.objects.annotate(
|
||||||
|
is_d2=Case(When(name="d2", then=True), default=False)
|
||||||
|
).update(is_active=~F("is_d2"))
|
||||||
|
self.assertCountEqual(
|
||||||
|
DataPoint.objects.values_list("name", "is_active"),
|
||||||
|
[("d0", True), ("d2", False), ("d3", True)],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_updating_non_conditional_field(self):
|
||||||
|
msg = "Cannot negate non-conditional expressions."
|
||||||
|
with self.assertRaisesMessage(TypeError, msg):
|
||||||
|
DataPoint.objects.update(is_active=~F("name"))
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipUnless(
|
@unittest.skipUnless(
|
||||||
connection.vendor == "mysql",
|
connection.vendor == "mysql",
|
||||||
|
|
Loading…
Reference in New Issue