Fixed #30484 -- Added conditional expressions support to CheckConstraint.
This commit is contained in:
parent
37e6c5b79b
commit
e9a0e1d4f6
|
@ -30,6 +30,11 @@ class BaseConstraint:
|
||||||
class CheckConstraint(BaseConstraint):
|
class CheckConstraint(BaseConstraint):
|
||||||
def __init__(self, *, check, name):
|
def __init__(self, *, check, name):
|
||||||
self.check = check
|
self.check = check
|
||||||
|
if not getattr(check, 'conditional', False):
|
||||||
|
raise TypeError(
|
||||||
|
'CheckConstraint.check must be a Q instance or boolean '
|
||||||
|
'expression.'
|
||||||
|
)
|
||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
|
|
||||||
def _get_check_sql(self, model, schema_editor):
|
def _get_check_sql(self, model, schema_editor):
|
||||||
|
|
|
@ -1221,8 +1221,19 @@ class Query(BaseExpression):
|
||||||
"""
|
"""
|
||||||
if isinstance(filter_expr, dict):
|
if isinstance(filter_expr, dict):
|
||||||
raise FieldError("Cannot parse keyword query as dict")
|
raise FieldError("Cannot parse keyword query as dict")
|
||||||
|
if isinstance(filter_expr, Q):
|
||||||
|
return self._add_q(
|
||||||
|
filter_expr,
|
||||||
|
branch_negated=branch_negated,
|
||||||
|
current_negated=current_negated,
|
||||||
|
used_aliases=can_reuse,
|
||||||
|
allow_joins=allow_joins,
|
||||||
|
split_subq=split_subq,
|
||||||
|
)
|
||||||
if hasattr(filter_expr, 'resolve_expression') and getattr(filter_expr, 'conditional', False):
|
if hasattr(filter_expr, 'resolve_expression') and getattr(filter_expr, 'conditional', False):
|
||||||
condition = self.build_lookup(['exact'], filter_expr.resolve_expression(self), True)
|
condition = self.build_lookup(
|
||||||
|
['exact'], filter_expr.resolve_expression(self, allow_joins=allow_joins), True
|
||||||
|
)
|
||||||
clause = self.where_class()
|
clause = self.where_class()
|
||||||
clause.add(condition, AND)
|
clause.add(condition, AND)
|
||||||
return clause, []
|
return clause, []
|
||||||
|
@ -1332,8 +1343,8 @@ class Query(BaseExpression):
|
||||||
self.where.add(clause, AND)
|
self.where.add(clause, AND)
|
||||||
self.demote_joins(existing_inner)
|
self.demote_joins(existing_inner)
|
||||||
|
|
||||||
def build_where(self, q_object):
|
def build_where(self, filter_expr):
|
||||||
return self._add_q(q_object, used_aliases=set(), allow_joins=False)[0]
|
return self.build_filter(filter_expr, allow_joins=False)[0]
|
||||||
|
|
||||||
def _add_q(self, q_object, used_aliases, branch_negated=False,
|
def _add_q(self, q_object, used_aliases, branch_negated=False,
|
||||||
current_negated=False, allow_joins=True, split_subq=True):
|
current_negated=False, allow_joins=True, split_subq=True):
|
||||||
|
@ -1345,18 +1356,12 @@ class Query(BaseExpression):
|
||||||
negated=q_object.negated)
|
negated=q_object.negated)
|
||||||
joinpromoter = JoinPromoter(q_object.connector, len(q_object.children), current_negated)
|
joinpromoter = JoinPromoter(q_object.connector, len(q_object.children), current_negated)
|
||||||
for child in q_object.children:
|
for child in q_object.children:
|
||||||
if isinstance(child, Node):
|
child_clause, needed_inner = self.build_filter(
|
||||||
child_clause, needed_inner = self._add_q(
|
child, can_reuse=used_aliases, branch_negated=branch_negated,
|
||||||
child, used_aliases, branch_negated,
|
current_negated=current_negated, allow_joins=allow_joins,
|
||||||
current_negated, allow_joins, split_subq)
|
split_subq=split_subq,
|
||||||
joinpromoter.add_votes(needed_inner)
|
)
|
||||||
else:
|
joinpromoter.add_votes(needed_inner)
|
||||||
child_clause, needed_inner = self.build_filter(
|
|
||||||
child, can_reuse=used_aliases, branch_negated=branch_negated,
|
|
||||||
current_negated=current_negated, allow_joins=allow_joins,
|
|
||||||
split_subq=split_subq,
|
|
||||||
)
|
|
||||||
joinpromoter.add_votes(needed_inner)
|
|
||||||
if child_clause:
|
if child_clause:
|
||||||
target_clause.add(child_clause, connector)
|
target_clause.add(child_clause, connector)
|
||||||
needed_inner = joinpromoter.update_join_types(self)
|
needed_inner = joinpromoter.update_join_types(self)
|
||||||
|
|
|
@ -52,12 +52,16 @@ option.
|
||||||
|
|
||||||
.. attribute:: CheckConstraint.check
|
.. attribute:: CheckConstraint.check
|
||||||
|
|
||||||
A :class:`Q` object that specifies the check you want the constraint to
|
A :class:`Q` object or boolean :class:`~django.db.models.Expression` that
|
||||||
enforce.
|
specifies the check you want the constraint to enforce.
|
||||||
|
|
||||||
For example, ``CheckConstraint(check=Q(age__gte=18), name='age_gte_18')``
|
For example, ``CheckConstraint(check=Q(age__gte=18), name='age_gte_18')``
|
||||||
ensures the age field is never less than 18.
|
ensures the age field is never less than 18.
|
||||||
|
|
||||||
|
.. versionchanged:: 3.1
|
||||||
|
|
||||||
|
Support for boolean :class:`~django.db.models.Expression` was added.
|
||||||
|
|
||||||
``name``
|
``name``
|
||||||
--------
|
--------
|
||||||
|
|
||||||
|
|
|
@ -204,6 +204,8 @@ Models
|
||||||
``OneToOneField`` emulates the behavior of the SQL constraint ``ON DELETE
|
``OneToOneField`` emulates the behavior of the SQL constraint ``ON DELETE
|
||||||
RESTRICT``.
|
RESTRICT``.
|
||||||
|
|
||||||
|
* :attr:`.CheckConstraint.check` now supports boolean expressions.
|
||||||
|
|
||||||
Pagination
|
Pagination
|
||||||
~~~~~~~~~~
|
~~~~~~~~~~
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,19 @@ class Product(models.Model):
|
||||||
check=models.Q(price__gt=0),
|
check=models.Q(price__gt=0),
|
||||||
name='%(app_label)s_%(class)s_price_gt_0',
|
name='%(app_label)s_%(class)s_price_gt_0',
|
||||||
),
|
),
|
||||||
|
models.CheckConstraint(
|
||||||
|
check=models.expressions.RawSQL(
|
||||||
|
'price < %s', (1000,), output_field=models.BooleanField()
|
||||||
|
),
|
||||||
|
name='%(app_label)s_price_lt_1000_raw',
|
||||||
|
),
|
||||||
|
models.CheckConstraint(
|
||||||
|
check=models.expressions.ExpressionWrapper(
|
||||||
|
models.Q(price__gt=500) | models.Q(price__lt=500),
|
||||||
|
output_field=models.BooleanField()
|
||||||
|
),
|
||||||
|
name='%(app_label)s_price_neq_500_wrap',
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -61,6 +61,13 @@ class CheckConstraintTests(TestCase):
|
||||||
"<CheckConstraint: check='{}' name='{}'>".format(check, name),
|
"<CheckConstraint: check='{}' name='{}'>".format(check, name),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_invalid_check_types(self):
|
||||||
|
msg = (
|
||||||
|
'CheckConstraint.check must be a Q instance or boolean expression.'
|
||||||
|
)
|
||||||
|
with self.assertRaisesMessage(TypeError, msg):
|
||||||
|
models.CheckConstraint(check=models.F('discounted_price'), name='check')
|
||||||
|
|
||||||
def test_deconstruction(self):
|
def test_deconstruction(self):
|
||||||
check = models.Q(price__gt=models.F('discounted_price'))
|
check = models.Q(price__gt=models.F('discounted_price'))
|
||||||
name = 'price_gt_discounted_price'
|
name = 'price_gt_discounted_price'
|
||||||
|
@ -76,11 +83,25 @@ class CheckConstraintTests(TestCase):
|
||||||
with self.assertRaises(IntegrityError):
|
with self.assertRaises(IntegrityError):
|
||||||
Product.objects.create(price=10, discounted_price=20)
|
Product.objects.create(price=10, discounted_price=20)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('supports_table_check_constraints')
|
||||||
|
def test_database_constraint_expression(self):
|
||||||
|
Product.objects.create(price=999, discounted_price=5)
|
||||||
|
with self.assertRaises(IntegrityError):
|
||||||
|
Product.objects.create(price=1000, discounted_price=5)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('supports_table_check_constraints')
|
||||||
|
def test_database_constraint_expressionwrapper(self):
|
||||||
|
Product.objects.create(price=499, discounted_price=5)
|
||||||
|
with self.assertRaises(IntegrityError):
|
||||||
|
Product.objects.create(price=500, discounted_price=5)
|
||||||
|
|
||||||
@skipUnlessDBFeature('supports_table_check_constraints', 'can_introspect_check_constraints')
|
@skipUnlessDBFeature('supports_table_check_constraints', 'can_introspect_check_constraints')
|
||||||
def test_name(self):
|
def test_name(self):
|
||||||
constraints = get_constraints(Product._meta.db_table)
|
constraints = get_constraints(Product._meta.db_table)
|
||||||
for expected_name in (
|
for expected_name in (
|
||||||
'price_gt_discounted_price',
|
'price_gt_discounted_price',
|
||||||
|
'constraints_price_lt_1000_raw',
|
||||||
|
'constraints_price_neq_500_wrap',
|
||||||
'constraints_product_price_gt_0',
|
'constraints_product_price_gt_0',
|
||||||
):
|
):
|
||||||
with self.subTest(expected_name):
|
with self.subTest(expected_name):
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db.models import CharField, F, Q
|
from django.db.models import BooleanField, CharField, F, Q
|
||||||
from django.db.models.expressions import Col
|
from django.db.models.expressions import Col, Func
|
||||||
from django.db.models.fields.related_lookups import RelatedIsNull
|
from django.db.models.fields.related_lookups import RelatedIsNull
|
||||||
from django.db.models.functions import Lower
|
from django.db.models.functions import Lower
|
||||||
from django.db.models.lookups import Exact, GreaterThan, IsNull, LessThan
|
from django.db.models.lookups import Exact, GreaterThan, IsNull, LessThan
|
||||||
|
@ -129,3 +129,18 @@ class TestQuery(SimpleTestCase):
|
||||||
name_exact = where.children[0]
|
name_exact = where.children[0]
|
||||||
self.assertIsInstance(name_exact, Exact)
|
self.assertIsInstance(name_exact, Exact)
|
||||||
self.assertEqual(name_exact.rhs, "['a', 'b']")
|
self.assertEqual(name_exact.rhs, "['a', 'b']")
|
||||||
|
|
||||||
|
def test_filter_conditional(self):
|
||||||
|
query = Query(Item)
|
||||||
|
where = query.build_where(Func(output_field=BooleanField()))
|
||||||
|
exact = where.children[0]
|
||||||
|
self.assertIsInstance(exact, Exact)
|
||||||
|
self.assertIsInstance(exact.lhs, Func)
|
||||||
|
self.assertIs(exact.rhs, True)
|
||||||
|
|
||||||
|
def test_filter_conditional_join(self):
|
||||||
|
query = Query(Item)
|
||||||
|
filter_expr = Func('note__note', output_field=BooleanField())
|
||||||
|
msg = 'Joined field references are not permitted in this query'
|
||||||
|
with self.assertRaisesMessage(FieldError, msg):
|
||||||
|
query.build_where(filter_expr)
|
||||||
|
|
Loading…
Reference in New Issue