Fixed #34604 -- Corrected fallback SQL for n-ary logical XOR.

An n-ary logical XOR Q(…) ^ Q(…) ^ … ^ Q(…) should evaluate to true
when an odd number of its operands evaluate to true, not when exactly
one operand evaluates to true.
This commit is contained in:
Anders Kaseorg 2023-05-29 21:59:22 -07:00 committed by Mariusz Felisiak
parent ee36e101e8
commit b81e974e9e
4 changed files with 44 additions and 4 deletions

View File

@ -6,6 +6,7 @@ from functools import reduce
from django.core.exceptions import EmptyResultSet, FullResultSet from django.core.exceptions import EmptyResultSet, FullResultSet
from django.db.models.expressions import Case, When from django.db.models.expressions import Case, When
from django.db.models.functions import Mod
from django.db.models.lookups import Exact from django.db.models.lookups import Exact
from django.utils import tree from django.utils import tree
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -129,12 +130,16 @@ class WhereNode(tree.Node):
# Convert if the database doesn't support XOR: # Convert if the database doesn't support XOR:
# a XOR b XOR c XOR ... # a XOR b XOR c XOR ...
# to: # to:
# (a OR b OR c OR ...) AND (a + b + c + ...) == 1 # (a OR b OR c OR ...) AND MOD(a + b + c + ..., 2) == 1
# The result of an n-ary XOR is true when an odd number of operands
# are true.
lhs = self.__class__(self.children, OR) lhs = self.__class__(self.children, OR)
rhs_sum = reduce( rhs_sum = reduce(
operator.add, operator.add,
(Case(When(c, then=1), default=0) for c in self.children), (Case(When(c, then=1), default=0) for c in self.children),
) )
if len(self.children) > 2:
rhs_sum = Mod(rhs_sum, 2)
rhs = Exact(1, rhs_sum) rhs = Exact(1, rhs_sum)
return self.__class__([lhs, rhs], AND, self.negated).as_sql( return self.__class__([lhs, rhs], AND, self.negated).as_sql(
compiler, connection compiler, connection

View File

@ -2021,7 +2021,8 @@ may be generated.
XOR (``^``) XOR (``^``)
~~~~~~~~~~~ ~~~~~~~~~~~
Combines two ``QuerySet``\s using the SQL ``XOR`` operator. Combines two ``QuerySet``\s using the SQL ``XOR`` operator. A ``XOR``
expression matches rows that are matched by an odd number of operands.
The following are equivalent:: The following are equivalent::
@ -2044,13 +2045,21 @@ SQL equivalent:
.. code-block:: sql .. code-block:: sql
(x OR y OR ... OR z) AND (x OR y OR ... OR z) AND
1=( 1=MOD(
(CASE WHEN x THEN 1 ELSE 0 END) + (CASE WHEN x THEN 1 ELSE 0 END) +
(CASE WHEN y THEN 1 ELSE 0 END) + (CASE WHEN y THEN 1 ELSE 0 END) +
... ...
(CASE WHEN z THEN 1 ELSE 0 END) + (CASE WHEN z THEN 1 ELSE 0 END),
2
) )
.. versionchanged:: 5.0
In older versions, on databases without native support for the SQL
``XOR`` operator, ``XOR`` returned rows that were matched by exactly
one operand. The previous behavior was not consistent with MySQL,
MariaDB, and Python behavior.
Methods that do not return ``QuerySet``\s Methods that do not return ``QuerySet``\s
----------------------------------------- -----------------------------------------

View File

@ -424,6 +424,11 @@ Miscellaneous
a page. Having two ``<h1>`` elements was confusing and the site header wasn't a page. Having two ``<h1>`` elements was confusing and the site header wasn't
helpful as it is repeated on all pages. helpful as it is repeated on all pages.
* On databases without native support for the SQL ``XOR`` operator, ``^`` as
the exclusive or (``XOR``) operator now returns rows that are matched by an
odd number of operands rather than exactly one operand. This is consistent
with the behavior of MySQL, MariaDB, and Python.
.. _deprecated-features-5.0: .. _deprecated-features-5.0:
Features deprecated in 5.0 Features deprecated in 5.0

View File

@ -19,6 +19,27 @@ class XorLookupsTests(TestCase):
self.numbers[:3] + self.numbers[8:], self.numbers[:3] + self.numbers[8:],
) )
def test_filter_multiple(self):
qs = Number.objects.filter(
Q(num__gte=1)
^ Q(num__gte=3)
^ Q(num__gte=5)
^ Q(num__gte=7)
^ Q(num__gte=9)
)
self.assertCountEqual(
qs,
self.numbers[1:3] + self.numbers[5:7] + self.numbers[9:],
)
self.assertCountEqual(
qs.values_list("num", flat=True),
[
i
for i in range(10)
if (i >= 1) ^ (i >= 3) ^ (i >= 5) ^ (i >= 7) ^ (i >= 9)
],
)
def test_filter_negated(self): def test_filter_negated(self):
self.assertCountEqual( self.assertCountEqual(
Number.objects.filter(Q(num__lte=7) ^ ~Q(num__lt=3)), Number.objects.filter(Q(num__lte=7) ^ ~Q(num__lt=3)),