Refs #373 -- Added tuple lookups.

This commit is contained in:
Bendeguz Csirmaz 2024-07-23 21:17:34 +08:00 committed by Sarah Boyce
parent 3dac3271d2
commit 1eac690d25
6 changed files with 555 additions and 69 deletions

View File

@ -152,6 +152,7 @@ answer newbie questions, and generally made Django that much better:
Ben Lomax <lomax.on.the.run@gmail.com>
Ben Slavin <benjamin.slavin@gmail.com>
Ben Sturmfels <ben@sturm.com.au>
Bendegúz Csirmaz <csirmazbendeguz@gmail.com>
Berker Peksag <berker.peksag@gmail.com>
Bernd Schlapsi
Bernhard Essl <me@bernhardessl.com>

View File

@ -1295,6 +1295,52 @@ class Col(Expression):
) + self.target.get_db_converters(connection)
class ColPairs(Expression):
def __init__(self, alias, targets, sources, output_field):
super().__init__(output_field=output_field)
self.alias, self.targets, self.sources = alias, targets, sources
def __len__(self):
return len(self.targets)
def __iter__(self):
return iter(self.get_cols())
def get_cols(self):
return [
Col(self.alias, target, source)
for target, source in zip(self.targets, self.sources)
]
def get_source_expressions(self):
return self.get_cols()
def set_source_expressions(self, exprs):
assert all(isinstance(expr, Col) and expr.alias == self.alias for expr in exprs)
self.targets = [col.target for col in exprs]
self.sources = [col.field for col in exprs]
def as_sql(self, compiler, connection):
cols_sql = []
cols_params = []
cols = self.get_cols()
for col in cols:
sql, params = col.as_sql(compiler, connection)
cols_sql.append(sql)
cols_params.extend(params)
return ", ".join(cols_sql), cols_params
def relabeled_clone(self, relabels):
return self.__class__(
relabels.get(self.alias, self.alias), self.targets, self.sources, self.field
)
def resolve_expression(self, *args, **kwargs):
return self
class Ref(Expression):
"""
Reference to column alias of the query. For example, Ref('sum_cost') in

View File

@ -1,3 +1,6 @@
from django.db import NotSupportedError
from django.db.models.expressions import ColPairs
from django.db.models.fields.tuple_lookups import TupleIn, tuple_lookups
from django.db.models.lookups import (
Exact,
GreaterThan,
@ -9,34 +12,6 @@ from django.db.models.lookups import (
)
class MultiColSource:
contains_aggregate = False
contains_over_clause = False
def __init__(self, alias, targets, sources, field):
self.targets, self.sources, self.field, self.alias = (
targets,
sources,
field,
alias,
)
self.output_field = self.field
def __repr__(self):
return "{}({}, {})".format(self.__class__.__name__, self.alias, self.field)
def relabeled_clone(self, relabels):
return self.__class__(
relabels.get(self.alias, self.alias), self.targets, self.sources, self.field
)
def get_lookup(self, lookup):
return self.output_field.get_lookup(lookup)
def resolve_expression(self, *args, **kwargs):
return self
def get_normalized_value(value, lhs):
from django.db.models import Model
@ -64,7 +39,7 @@ def get_normalized_value(value, lhs):
class RelatedIn(In):
def get_prep_lookup(self):
if not isinstance(self.lhs, MultiColSource):
if not isinstance(self.lhs, ColPairs):
if self.rhs_is_direct_value():
# If we get here, we are dealing with single-column relations.
self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
@ -98,49 +73,33 @@ class RelatedIn(In):
return super().get_prep_lookup()
def as_sql(self, compiler, connection):
if isinstance(self.lhs, MultiColSource):
if isinstance(self.lhs, ColPairs):
# For multicolumn lookups we need to build a multicolumn where clause.
# This clause is either a SubqueryConstraint (for values that need
# to be compiled to SQL) or an OR-combined list of
# (col1 = val1 AND col2 = val2 AND ...) clauses.
from django.db.models.sql.where import (
AND,
OR,
SubqueryConstraint,
WhereNode,
)
from django.db.models.sql.where import SubqueryConstraint
root_constraint = WhereNode(connector=OR)
if self.rhs_is_direct_value():
values = [get_normalized_value(value, self.lhs) for value in self.rhs]
for value in values:
value_constraint = WhereNode()
for source, target, val in zip(
self.lhs.sources, self.lhs.targets, value
):
lookup_class = target.get_lookup("exact")
lookup = lookup_class(
target.get_col(self.lhs.alias, source), val
)
value_constraint.add(lookup, AND)
root_constraint.add(value_constraint, OR)
lookup = TupleIn(self.lhs, values)
return compiler.compile(lookup)
else:
root_constraint.add(
return compiler.compile(
SubqueryConstraint(
self.lhs.alias,
[target.column for target in self.lhs.targets],
[source.name for source in self.lhs.sources],
self.rhs,
),
AND,
)
return root_constraint.as_sql(compiler, connection)
return super().as_sql(compiler, connection)
class RelatedLookupMixin:
def get_prep_lookup(self):
if not isinstance(self.lhs, MultiColSource) and not hasattr(
if not isinstance(self.lhs, ColPairs) and not hasattr(
self.rhs, "resolve_expression"
):
# If we get here, we are dealing with single-column relations.
@ -158,20 +117,16 @@ class RelatedLookupMixin:
return super().get_prep_lookup()
def as_sql(self, compiler, connection):
if isinstance(self.lhs, MultiColSource):
assert self.rhs_is_direct_value()
self.rhs = get_normalized_value(self.rhs, self.lhs)
from django.db.models.sql.where import AND, WhereNode
root_constraint = WhereNode()
for target, source, val in zip(
self.lhs.targets, self.lhs.sources, self.rhs
):
lookup_class = target.get_lookup(self.lookup_name)
root_constraint.add(
lookup_class(target.get_col(self.lhs.alias, source), val), AND
if isinstance(self.lhs, ColPairs):
if not self.rhs_is_direct_value():
raise NotSupportedError(
f"'{self.lookup_name}' doesn't support multi-column subqueries."
)
return root_constraint.as_sql(compiler, connection)
self.rhs = get_normalized_value(self.rhs, self.lhs)
lookup_class = tuple_lookups[self.lookup_name]
lookup = lookup_class(self.lhs, self.rhs)
return compiler.compile(lookup)
return super().as_sql(compiler, connection)

View File

@ -0,0 +1,244 @@
import itertools
from django.core.exceptions import EmptyResultSet
from django.db.models.expressions import ColPairs, Func, Value
from django.db.models.lookups import (
Exact,
GreaterThan,
GreaterThanOrEqual,
In,
IsNull,
LessThan,
LessThanOrEqual,
)
from django.db.models.sql.where import AND, OR, WhereNode
class Tuple(Func):
function = ""
class TupleLookupMixin:
def get_prep_lookup(self):
self.check_tuple_lookup()
return super().get_prep_lookup()
def check_tuple_lookup(self):
assert isinstance(self.lhs, ColPairs)
self.check_rhs_is_tuple_or_list()
self.check_rhs_length_equals_lhs_length()
def check_rhs_is_tuple_or_list(self):
if not isinstance(self.rhs, (tuple, list)):
raise ValueError(
f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
"must be a tuple or a list"
)
def check_rhs_length_equals_lhs_length(self):
if len(self.lhs) != len(self.rhs):
raise ValueError(
f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
f"must have {len(self.lhs)} elements"
)
def check_rhs_is_collection_of_tuples_or_lists(self):
if not all(isinstance(vals, (tuple, list)) for vals in self.rhs):
raise ValueError(
f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
f"must be a collection of tuples or lists"
)
def check_rhs_elements_length_equals_lhs_length(self):
if not all(len(self.lhs) == len(vals) for vals in self.rhs):
raise ValueError(
f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
f"must have {len(self.lhs)} elements each"
)
def as_sql(self, compiler, connection):
# e.g.: (a, b, c) == (x, y, z) as SQL:
# WHERE (a, b, c) = (x, y, z)
vals = [
Value(val, output_field=col.output_field)
for col, val in zip(self.lhs, self.rhs)
]
lookup_class = self.__class__.__bases__[-1]
lookup = lookup_class(Tuple(self.lhs), Tuple(*vals))
return lookup.as_sql(compiler, connection)
class TupleExact(TupleLookupMixin, Exact):
def as_oracle(self, compiler, connection):
# e.g.: (a, b, c) == (x, y, z) as SQL:
# WHERE a = x AND b = y AND c = z
cols = self.lhs.get_cols()
lookups = [Exact(col, val) for col, val in zip(cols, self.rhs)]
root = WhereNode(lookups, connector=AND)
return root.as_sql(compiler, connection)
class TupleIsNull(IsNull):
def as_sql(self, compiler, connection):
# e.g.: (a, b, c) is None as SQL:
# WHERE a IS NULL AND b IS NULL AND c IS NULL
vals = self.rhs
if isinstance(vals, bool):
vals = [vals] * len(self.lhs)
cols = self.lhs.get_cols()
lookups = [IsNull(col, val) for col, val in zip(cols, vals)]
root = WhereNode(lookups, connector=AND)
return root.as_sql(compiler, connection)
class TupleGreaterThan(TupleLookupMixin, GreaterThan):
def as_oracle(self, compiler, connection):
# e.g.: (a, b, c) > (x, y, z) as SQL:
# WHERE a > x OR (a = x AND (b > y OR (b = y AND c > z)))
cols = self.lhs.get_cols()
lookups = itertools.cycle([GreaterThan, Exact])
connectors = itertools.cycle([OR, AND])
cols_list = [col for col in cols for _ in range(2)]
vals_list = [val for val in self.rhs for _ in range(2)]
cols_iter = iter(cols_list[:-1])
vals_iter = iter(vals_list[:-1])
col, val = next(cols_iter), next(vals_iter)
lookup, connector = next(lookups), next(connectors)
root = node = WhereNode([lookup(col, val)], connector=connector)
for col, val in zip(cols_iter, vals_iter):
lookup, connector = next(lookups), next(connectors)
child = WhereNode([lookup(col, val)], connector=connector)
node.children.append(child)
node = child
return root.as_sql(compiler, connection)
class TupleGreaterThanOrEqual(TupleLookupMixin, GreaterThanOrEqual):
def as_oracle(self, compiler, connection):
# e.g.: (a, b, c) >= (x, y, z) as SQL:
# WHERE a > x OR (a = x AND (b > y OR (b = y AND (c > z OR c = z))))
cols = self.lhs.get_cols()
lookups = itertools.cycle([GreaterThan, Exact])
connectors = itertools.cycle([OR, AND])
cols_list = [col for col in cols for _ in range(2)]
vals_list = [val for val in self.rhs for _ in range(2)]
cols_iter = iter(cols_list)
vals_iter = iter(vals_list)
col, val = next(cols_iter), next(vals_iter)
lookup, connector = next(lookups), next(connectors)
root = node = WhereNode([lookup(col, val)], connector=connector)
for col, val in zip(cols_iter, vals_iter):
lookup, connector = next(lookups), next(connectors)
child = WhereNode([lookup(col, val)], connector=connector)
node.children.append(child)
node = child
return root.as_sql(compiler, connection)
class TupleLessThan(TupleLookupMixin, LessThan):
def as_oracle(self, compiler, connection):
# e.g.: (a, b, c) < (x, y, z) as SQL:
# WHERE a < x OR (a = x AND (b < y OR (b = y AND c < z)))
cols = self.lhs.get_cols()
lookups = itertools.cycle([LessThan, Exact])
connectors = itertools.cycle([OR, AND])
cols_list = [col for col in cols for _ in range(2)]
vals_list = [val for val in self.rhs for _ in range(2)]
cols_iter = iter(cols_list[:-1])
vals_iter = iter(vals_list[:-1])
col, val = next(cols_iter), next(vals_iter)
lookup, connector = next(lookups), next(connectors)
root = node = WhereNode([lookup(col, val)], connector=connector)
for col, val in zip(cols_iter, vals_iter):
lookup, connector = next(lookups), next(connectors)
child = WhereNode([lookup(col, val)], connector=connector)
node.children.append(child)
node = child
return root.as_sql(compiler, connection)
class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
def as_oracle(self, compiler, connection):
# e.g.: (a, b, c) <= (x, y, z) as SQL:
# WHERE a < x OR (a = x AND (b < y OR (b = y AND (c < z OR c = z))))
cols = self.lhs.get_cols()
lookups = itertools.cycle([LessThan, Exact])
connectors = itertools.cycle([OR, AND])
cols_list = [col for col in cols for _ in range(2)]
vals_list = [val for val in self.rhs for _ in range(2)]
cols_iter = iter(cols_list)
vals_iter = iter(vals_list)
col, val = next(cols_iter), next(vals_iter)
lookup, connector = next(lookups), next(connectors)
root = node = WhereNode([lookup(col, val)], connector=connector)
for col, val in zip(cols_iter, vals_iter):
lookup, connector = next(lookups), next(connectors)
child = WhereNode([lookup(col, val)], connector=connector)
node.children.append(child)
node = child
return root.as_sql(compiler, connection)
class TupleIn(TupleLookupMixin, In):
def check_tuple_lookup(self):
assert isinstance(self.lhs, ColPairs)
self.check_rhs_is_tuple_or_list()
self.check_rhs_is_collection_of_tuples_or_lists()
self.check_rhs_elements_length_equals_lhs_length()
def as_sql(self, compiler, connection):
if not self.rhs:
raise EmptyResultSet
# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
# WHERE (a, b, c) IN ((x1, y1, z1), (x2, y2, z2))
rhs = []
for vals in self.rhs:
rhs.append(
Tuple(
*[
Value(val, output_field=col.output_field)
for col, val in zip(self.lhs, vals)
]
)
)
lookup = In(Tuple(self.lhs), Tuple(*rhs))
return lookup.as_sql(compiler, connection)
def as_sqlite(self, compiler, connection):
if not self.rhs:
raise EmptyResultSet
# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
# WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
root = WhereNode([], connector=OR)
cols = self.lhs.get_cols()
for vals in self.rhs:
lookups = [Exact(col, val) for col, val in zip(cols, vals)]
root.children.append(WhereNode(lookups, connector=AND))
return root.as_sql(compiler, connection)
tuple_lookups = {
"exact": TupleExact,
"gt": TupleGreaterThan,
"gte": TupleGreaterThanOrEqual,
"lt": TupleLessThan,
"lte": TupleLessThanOrEqual,
"in": TupleIn,
"isnull": TupleIsNull,
}

View File

@ -23,6 +23,7 @@ from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import (
BaseExpression,
Col,
ColPairs,
Exists,
F,
OuterRef,
@ -32,7 +33,6 @@ from django.db.models.expressions import (
Value,
)
from django.db.models.fields import Field
from django.db.models.fields.related_lookups import MultiColSource
from django.db.models.lookups import Lookup
from django.db.models.query_utils import (
Q,
@ -1549,9 +1549,7 @@ class Query(BaseExpression):
if len(targets) == 1:
col = self._get_col(targets[0], join_info.final_field, alias)
else:
col = MultiColSource(
alias, targets, join_info.targets, join_info.final_field
)
col = ColPairs(alias, targets, join_info.targets, join_info.final_field)
else:
col = self._get_col(targets[0], join_info.final_field, alias)

View File

@ -0,0 +1,242 @@
import unittest
from django.db import NotSupportedError, connection
from django.test import TestCase
from .models import Contact, Customer
class TupleLookupsTests(TestCase):
@classmethod
def setUpTestData(cls):
super().setUpTestData()
cls.customer_1 = Customer.objects.create(customer_id=1, company="a")
cls.customer_2 = Customer.objects.create(customer_id=1, company="b")
cls.customer_3 = Customer.objects.create(customer_id=2, company="c")
cls.customer_4 = Customer.objects.create(customer_id=3, company="d")
cls.customer_5 = Customer.objects.create(customer_id=1, company="e")
cls.contact_1 = Contact.objects.create(customer=cls.customer_1)
cls.contact_2 = Contact.objects.create(customer=cls.customer_1)
cls.contact_3 = Contact.objects.create(customer=cls.customer_2)
cls.contact_4 = Contact.objects.create(customer=cls.customer_3)
cls.contact_5 = Contact.objects.create(customer=cls.customer_1)
cls.contact_6 = Contact.objects.create(customer=cls.customer_5)
def test_exact(self):
test_cases = (
(self.customer_1, (self.contact_1, self.contact_2, self.contact_5)),
(self.customer_2, (self.contact_3,)),
(self.customer_3, (self.contact_4,)),
(self.customer_4, ()),
(self.customer_5, (self.contact_6,)),
)
for customer, contacts in test_cases:
with self.subTest(customer=customer, contacts=contacts):
self.assertSequenceEqual(
Contact.objects.filter(customer=customer).order_by("id"), contacts
)
def test_exact_subquery(self):
with self.assertRaisesMessage(
NotSupportedError, "'exact' doesn't support multi-column subqueries."
):
subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
self.assertSequenceEqual(
Contact.objects.filter(customer=subquery).order_by("id"), ()
)
def test_in(self):
cust_1, cust_2, cust_3, cust_4, cust_5 = (
self.customer_1,
self.customer_2,
self.customer_3,
self.customer_4,
self.customer_5,
)
c1, c2, c3, c4, c5, c6 = (
self.contact_1,
self.contact_2,
self.contact_3,
self.contact_4,
self.contact_5,
self.contact_6,
)
test_cases = (
((), ()),
((cust_1,), (c1, c2, c5)),
((cust_1, cust_2), (c1, c2, c3, c5)),
((cust_1, cust_2, cust_3), (c1, c2, c3, c4, c5)),
((cust_1, cust_2, cust_3, cust_4), (c1, c2, c3, c4, c5)),
((cust_1, cust_2, cust_3, cust_4, cust_5), (c1, c2, c3, c4, c5, c6)),
)
for contacts, customers in test_cases:
with self.subTest(contacts=contacts, customers=customers):
self.assertSequenceEqual(
Contact.objects.filter(customer__in=contacts).order_by("id"),
customers,
)
@unittest.skipIf(
connection.vendor == "mysql",
"MySQL doesn't support LIMIT & IN/ALL/ANY/SOME subquery",
)
def test_in_subquery(self):
subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
self.assertSequenceEqual(
Contact.objects.filter(customer__in=subquery).order_by("id"),
(self.contact_1, self.contact_2, self.contact_5),
)
def test_lt(self):
c1, c2, c3, c4, c5, c6 = (
self.contact_1,
self.contact_2,
self.contact_3,
self.contact_4,
self.contact_5,
self.contact_6,
)
test_cases = (
(self.customer_1, ()),
(self.customer_2, (c1, c2, c5)),
(self.customer_5, (c1, c2, c3, c5)),
(self.customer_3, (c1, c2, c3, c5, c6)),
(self.customer_4, (c1, c2, c3, c4, c5, c6)),
)
for customer, contacts in test_cases:
with self.subTest(customer=customer, contacts=contacts):
self.assertSequenceEqual(
Contact.objects.filter(customer__lt=customer).order_by("id"),
contacts,
)
def test_lt_subquery(self):
with self.assertRaisesMessage(
NotSupportedError, "'lt' doesn't support multi-column subqueries."
):
subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
self.assertSequenceEqual(
Contact.objects.filter(customer__lt=subquery).order_by("id"), ()
)
def test_lte(self):
c1, c2, c3, c4, c5, c6 = (
self.contact_1,
self.contact_2,
self.contact_3,
self.contact_4,
self.contact_5,
self.contact_6,
)
test_cases = (
(self.customer_1, (c1, c2, c5)),
(self.customer_2, (c1, c2, c3, c5)),
(self.customer_5, (c1, c2, c3, c5, c6)),
(self.customer_3, (c1, c2, c3, c4, c5, c6)),
(self.customer_4, (c1, c2, c3, c4, c5, c6)),
)
for customer, contacts in test_cases:
with self.subTest(customer=customer, contacts=contacts):
self.assertSequenceEqual(
Contact.objects.filter(customer__lte=customer).order_by("id"),
contacts,
)
def test_lte_subquery(self):
with self.assertRaisesMessage(
NotSupportedError, "'lte' doesn't support multi-column subqueries."
):
subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
self.assertSequenceEqual(
Contact.objects.filter(customer__lte=subquery).order_by("id"), ()
)
def test_gt(self):
test_cases = (
(self.customer_1, (self.contact_3, self.contact_4, self.contact_6)),
(self.customer_2, (self.contact_4, self.contact_6)),
(self.customer_5, (self.contact_4,)),
(self.customer_3, ()),
(self.customer_4, ()),
)
for customer, contacts in test_cases:
with self.subTest(customer=customer, contacts=contacts):
self.assertSequenceEqual(
Contact.objects.filter(customer__gt=customer).order_by("id"),
contacts,
)
def test_gt_subquery(self):
with self.assertRaisesMessage(
NotSupportedError, "'gt' doesn't support multi-column subqueries."
):
subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
self.assertSequenceEqual(
Contact.objects.filter(customer__gt=subquery).order_by("id"), ()
)
def test_gte(self):
c1, c2, c3, c4, c5, c6 = (
self.contact_1,
self.contact_2,
self.contact_3,
self.contact_4,
self.contact_5,
self.contact_6,
)
test_cases = (
(self.customer_1, (c1, c2, c3, c4, c5, c6)),
(self.customer_2, (c3, c4, c6)),
(self.customer_5, (c4, c6)),
(self.customer_3, (c4,)),
(self.customer_4, ()),
)
for customer, contacts in test_cases:
with self.subTest(customer=customer, contacts=contacts):
self.assertSequenceEqual(
Contact.objects.filter(customer__gte=customer).order_by("pk"),
contacts,
)
def test_gte_subquery(self):
with self.assertRaisesMessage(
NotSupportedError, "'gte' doesn't support multi-column subqueries."
):
subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
self.assertSequenceEqual(
Contact.objects.filter(customer__gte=subquery).order_by("id"), ()
)
def test_isnull(self):
with self.subTest("customer__isnull=True"):
self.assertSequenceEqual(
Contact.objects.filter(customer__isnull=True).order_by("id"),
(),
)
with self.subTest("customer__isnull=False"):
self.assertSequenceEqual(
Contact.objects.filter(customer__isnull=False).order_by("id"),
(
self.contact_1,
self.contact_2,
self.contact_3,
self.contact_4,
self.contact_5,
self.contact_6,
),
)
def test_isnull_subquery(self):
with self.assertRaisesMessage(
NotSupportedError, "'isnull' doesn't support multi-column subqueries."
):
subquery = Customer.objects.filter(id=0)[:1]
self.assertSequenceEqual(
Contact.objects.filter(customer__isnull=subquery).order_by("id"), ()
)