Refs #373 -- Added additional validations to tuple lookups.

This commit is contained in:
Bendeguz Csirmaz 2024-09-27 00:04:33 +08:00 committed by Sarah Boyce
parent 263f731919
commit 97c05a64ca
2 changed files with 139 additions and 7 deletions

View File

@ -2,7 +2,7 @@ import itertools
from django.core.exceptions import EmptyResultSet from django.core.exceptions import EmptyResultSet
from django.db.models import Field from django.db.models import Field
from django.db.models.expressions import Func, Value from django.db.models.expressions import ColPairs, Func, Value
from django.db.models.lookups import ( from django.db.models.lookups import (
Exact, Exact,
GreaterThan, GreaterThan,
@ -28,17 +28,32 @@ class Tuple(Func):
class TupleLookupMixin: class TupleLookupMixin:
def get_prep_lookup(self): def get_prep_lookup(self):
self.check_rhs_is_tuple_or_list()
self.check_rhs_length_equals_lhs_length() self.check_rhs_length_equals_lhs_length()
return self.rhs return self.rhs
def check_rhs_is_tuple_or_list(self):
if not isinstance(self.rhs, (tuple, list)):
lhs_str = self.get_lhs_str()
raise ValueError(
f"{self.lookup_name!r} lookup of {lhs_str} must be a tuple or a list"
)
def check_rhs_length_equals_lhs_length(self): def check_rhs_length_equals_lhs_length(self):
len_lhs = len(self.lhs) len_lhs = len(self.lhs)
if len_lhs != len(self.rhs): if len_lhs != len(self.rhs):
lhs_str = self.get_lhs_str()
raise ValueError( raise ValueError(
f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " f"{self.lookup_name!r} lookup of {lhs_str} must have {len_lhs} elements"
f"must have {len_lhs} elements"
) )
def get_lhs_str(self):
if isinstance(self.lhs, ColPairs):
return repr(self.lhs.field.name)
else:
names = ", ".join(repr(f.name) for f in self.lhs)
return f"({names})"
def get_prep_lhs(self): def get_prep_lhs(self):
if isinstance(self.lhs, (tuple, list)): if isinstance(self.lhs, (tuple, list)):
return Tuple(*self.lhs) return Tuple(*self.lhs)
@ -196,14 +211,25 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
class TupleIn(TupleLookupMixin, In): class TupleIn(TupleLookupMixin, In):
def get_prep_lookup(self): def get_prep_lookup(self):
self.check_rhs_is_tuple_or_list()
self.check_rhs_is_collection_of_tuples_or_lists()
self.check_rhs_elements_length_equals_lhs_length() self.check_rhs_elements_length_equals_lhs_length()
return super(TupleLookupMixin, self).get_prep_lookup() return self.rhs # skip checks from mixin
def check_rhs_is_collection_of_tuples_or_lists(self):
if not all(isinstance(vals, (tuple, list)) for vals in self.rhs):
lhs_str = self.get_lhs_str()
raise ValueError(
f"{self.lookup_name!r} lookup of {lhs_str} "
"must be a collection of tuples or lists"
)
def check_rhs_elements_length_equals_lhs_length(self): def check_rhs_elements_length_equals_lhs_length(self):
len_lhs = len(self.lhs) len_lhs = len(self.lhs)
if not all(len_lhs == len(vals) for vals in self.rhs): if not all(len_lhs == len(vals) for vals in self.rhs):
lhs_str = self.get_lhs_str()
raise ValueError( raise ValueError(
f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " f"{self.lookup_name!r} lookup of {lhs_str} "
f"must have {len_lhs} elements each" f"must have {len_lhs} elements each"
) )

View File

@ -1,3 +1,4 @@
import itertools
import unittest import unittest
from django.db import NotSupportedError, connection from django.db import NotSupportedError, connection
@ -129,6 +130,37 @@ class TupleLookupsTests(TestCase):
(self.contact_1, self.contact_2, self.contact_5), (self.contact_1, self.contact_2, self.contact_5),
) )
def test_tuple_in_rhs_must_be_collection_of_tuples_or_lists(self):
test_cases = (
(1, 2, 3),
((1, 2), (3, 4), None),
)
for rhs in test_cases:
with self.subTest(rhs=rhs):
with self.assertRaisesMessage(
ValueError,
"'in' lookup of ('customer_code', 'company_code') "
"must be a collection of tuples or lists",
):
TupleIn((F("customer_code"), F("company_code")), rhs)
def test_tuple_in_rhs_must_have_2_elements_each(self):
test_cases = (
((),),
((1,),),
((1, 2, 3),),
)
for rhs in test_cases:
with self.subTest(rhs=rhs):
with self.assertRaisesMessage(
ValueError,
"'in' lookup of ('customer_code', 'company_code') "
"must have 2 elements each",
):
TupleIn((F("customer_code"), F("company_code")), rhs)
def test_lt(self): def test_lt(self):
c1, c2, c3, c4, c5, c6 = ( c1, c2, c3, c4, c5, c6 = (
self.contact_1, self.contact_1,
@ -358,8 +390,8 @@ class TupleLookupsTests(TestCase):
) )
def test_lookup_errors(self): def test_lookup_errors(self):
m_2_elements = "'%s' lookup of 'customer' field must have 2 elements" m_2_elements = "'%s' lookup of 'customer' must have 2 elements"
m_2_elements_each = "'in' lookup of 'customer' field must have 2 elements each" m_2_elements_each = "'in' lookup of 'customer' must have 2 elements each"
test_cases = ( test_cases = (
({"customer": 1}, m_2_elements % "exact"), ({"customer": 1}, m_2_elements % "exact"),
({"customer": (1, 2, 3)}, m_2_elements % "exact"), ({"customer": (1, 2, 3)}, m_2_elements % "exact"),
@ -381,3 +413,77 @@ class TupleLookupsTests(TestCase):
self.assertRaisesMessage(ValueError, message), self.assertRaisesMessage(ValueError, message),
): ):
Contact.objects.get(**kwargs) Contact.objects.get(**kwargs)
def test_tuple_lookup_names(self):
test_cases = (
(TupleExact, "exact"),
(TupleGreaterThan, "gt"),
(TupleGreaterThanOrEqual, "gte"),
(TupleLessThan, "lt"),
(TupleLessThanOrEqual, "lte"),
(TupleIn, "in"),
(TupleIsNull, "isnull"),
)
for lookup_class, lookup_name in test_cases:
with self.subTest(lookup_name):
self.assertEqual(lookup_class.lookup_name, lookup_name)
def test_tuple_lookup_rhs_must_be_tuple_or_list(self):
test_cases = itertools.product(
(
TupleExact,
TupleGreaterThan,
TupleGreaterThanOrEqual,
TupleLessThan,
TupleLessThanOrEqual,
TupleIn,
),
(
0,
1,
None,
True,
False,
{"foo": "bar"},
),
)
for lookup_cls, rhs in test_cases:
lookup_name = lookup_cls.lookup_name
with self.subTest(lookup_name=lookup_name, rhs=rhs):
with self.assertRaisesMessage(
ValueError,
f"'{lookup_name}' lookup of ('customer_code', 'company_code') "
"must be a tuple or a list",
):
lookup_cls((F("customer_code"), F("company_code")), rhs)
def test_tuple_lookup_rhs_must_have_2_elements(self):
test_cases = itertools.product(
(
TupleExact,
TupleGreaterThan,
TupleGreaterThanOrEqual,
TupleLessThan,
TupleLessThanOrEqual,
),
(
[],
[1],
[1, 2, 3],
(),
(1,),
(1, 2, 3),
),
)
for lookup_cls, rhs in test_cases:
lookup_name = lookup_cls.lookup_name
with self.subTest(lookup_name=lookup_name, rhs=rhs):
with self.assertRaisesMessage(
ValueError,
f"'{lookup_name}' lookup of ('customer_code', 'company_code') "
"must have 2 elements",
):
lookup_cls((F("customer_code"), F("company_code")), rhs)