diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py index 96205b44ad..98912a6467 100644 --- a/django/db/models/constraints.py +++ b/django/db/models/constraints.py @@ -30,6 +30,11 @@ class BaseConstraint: class CheckConstraint(BaseConstraint): def __init__(self, *, check, name): self.check = check + if not getattr(check, 'conditional', False): + raise TypeError( + 'CheckConstraint.check must be a Q instance or boolean ' + 'expression.' + ) super().__init__(name) def _get_check_sql(self, model, schema_editor): diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index f14859fda1..dcf897c649 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -1221,8 +1221,19 @@ class Query(BaseExpression): """ if isinstance(filter_expr, dict): raise FieldError("Cannot parse keyword query as dict") + if isinstance(filter_expr, Q): + return self._add_q( + filter_expr, + branch_negated=branch_negated, + current_negated=current_negated, + used_aliases=can_reuse, + allow_joins=allow_joins, + split_subq=split_subq, + ) if hasattr(filter_expr, 'resolve_expression') and getattr(filter_expr, 'conditional', False): - condition = self.build_lookup(['exact'], filter_expr.resolve_expression(self), True) + condition = self.build_lookup( + ['exact'], filter_expr.resolve_expression(self, allow_joins=allow_joins), True + ) clause = self.where_class() clause.add(condition, AND) return clause, [] @@ -1332,8 +1343,8 @@ class Query(BaseExpression): self.where.add(clause, AND) self.demote_joins(existing_inner) - def build_where(self, q_object): - return self._add_q(q_object, used_aliases=set(), allow_joins=False)[0] + def build_where(self, filter_expr): + return self.build_filter(filter_expr, allow_joins=False)[0] def _add_q(self, q_object, used_aliases, branch_negated=False, current_negated=False, allow_joins=True, split_subq=True): @@ -1345,18 +1356,12 @@ class Query(BaseExpression): negated=q_object.negated) joinpromoter = JoinPromoter(q_object.connector, len(q_object.children), current_negated) for child in q_object.children: - if isinstance(child, Node): - child_clause, needed_inner = self._add_q( - child, used_aliases, branch_negated, - current_negated, allow_joins, split_subq) - joinpromoter.add_votes(needed_inner) - else: - child_clause, needed_inner = self.build_filter( - child, can_reuse=used_aliases, branch_negated=branch_negated, - current_negated=current_negated, allow_joins=allow_joins, - split_subq=split_subq, - ) - joinpromoter.add_votes(needed_inner) + child_clause, needed_inner = self.build_filter( + child, can_reuse=used_aliases, branch_negated=branch_negated, + current_negated=current_negated, allow_joins=allow_joins, + split_subq=split_subq, + ) + joinpromoter.add_votes(needed_inner) if child_clause: target_clause.add(child_clause, connector) needed_inner = joinpromoter.update_join_types(self) diff --git a/docs/ref/models/constraints.txt b/docs/ref/models/constraints.txt index 0ac54cf265..9e3119f600 100644 --- a/docs/ref/models/constraints.txt +++ b/docs/ref/models/constraints.txt @@ -52,12 +52,16 @@ option. .. attribute:: CheckConstraint.check -A :class:`Q` object that specifies the check you want the constraint to -enforce. +A :class:`Q` object or boolean :class:`~django.db.models.Expression` that +specifies the check you want the constraint to enforce. For example, ``CheckConstraint(check=Q(age__gte=18), name='age_gte_18')`` ensures the age field is never less than 18. +.. versionchanged:: 3.1 + + Support for boolean :class:`~django.db.models.Expression` was added. + ``name`` -------- diff --git a/docs/releases/3.1.txt b/docs/releases/3.1.txt index 16f9ed01be..bff16dc1ab 100644 --- a/docs/releases/3.1.txt +++ b/docs/releases/3.1.txt @@ -204,6 +204,8 @@ Models ``OneToOneField`` emulates the behavior of the SQL constraint ``ON DELETE RESTRICT``. +* :attr:`.CheckConstraint.check` now supports boolean expressions. + Pagination ~~~~~~~~~~ diff --git a/tests/constraints/models.py b/tests/constraints/models.py index d207de6e73..cadeea887e 100644 --- a/tests/constraints/models.py +++ b/tests/constraints/models.py @@ -18,6 +18,19 @@ class Product(models.Model): check=models.Q(price__gt=0), name='%(app_label)s_%(class)s_price_gt_0', ), + models.CheckConstraint( + check=models.expressions.RawSQL( + 'price < %s', (1000,), output_field=models.BooleanField() + ), + name='%(app_label)s_price_lt_1000_raw', + ), + models.CheckConstraint( + check=models.expressions.ExpressionWrapper( + models.Q(price__gt=500) | models.Q(price__lt=500), + output_field=models.BooleanField() + ), + name='%(app_label)s_price_neq_500_wrap', + ), ] diff --git a/tests/constraints/tests.py b/tests/constraints/tests.py index 8e2eb11e2a..136060eb07 100644 --- a/tests/constraints/tests.py +++ b/tests/constraints/tests.py @@ -61,6 +61,13 @@ class CheckConstraintTests(TestCase): "".format(check, name), ) + def test_invalid_check_types(self): + msg = ( + 'CheckConstraint.check must be a Q instance or boolean expression.' + ) + with self.assertRaisesMessage(TypeError, msg): + models.CheckConstraint(check=models.F('discounted_price'), name='check') + def test_deconstruction(self): check = models.Q(price__gt=models.F('discounted_price')) name = 'price_gt_discounted_price' @@ -76,11 +83,25 @@ class CheckConstraintTests(TestCase): with self.assertRaises(IntegrityError): Product.objects.create(price=10, discounted_price=20) + @skipUnlessDBFeature('supports_table_check_constraints') + def test_database_constraint_expression(self): + Product.objects.create(price=999, discounted_price=5) + with self.assertRaises(IntegrityError): + Product.objects.create(price=1000, discounted_price=5) + + @skipUnlessDBFeature('supports_table_check_constraints') + def test_database_constraint_expressionwrapper(self): + Product.objects.create(price=499, discounted_price=5) + with self.assertRaises(IntegrityError): + Product.objects.create(price=500, discounted_price=5) + @skipUnlessDBFeature('supports_table_check_constraints', 'can_introspect_check_constraints') def test_name(self): constraints = get_constraints(Product._meta.db_table) for expected_name in ( 'price_gt_discounted_price', + 'constraints_price_lt_1000_raw', + 'constraints_price_neq_500_wrap', 'constraints_product_price_gt_0', ): with self.subTest(expected_name): diff --git a/tests/queries/test_query.py b/tests/queries/test_query.py index ecd9c96d8c..9d18b15f3c 100644 --- a/tests/queries/test_query.py +++ b/tests/queries/test_query.py @@ -1,8 +1,8 @@ from datetime import datetime from django.core.exceptions import FieldError -from django.db.models import CharField, F, Q -from django.db.models.expressions import Col +from django.db.models import BooleanField, CharField, F, Q +from django.db.models.expressions import Col, Func from django.db.models.fields.related_lookups import RelatedIsNull from django.db.models.functions import Lower from django.db.models.lookups import Exact, GreaterThan, IsNull, LessThan @@ -129,3 +129,18 @@ class TestQuery(SimpleTestCase): name_exact = where.children[0] self.assertIsInstance(name_exact, Exact) self.assertEqual(name_exact.rhs, "['a', 'b']") + + def test_filter_conditional(self): + query = Query(Item) + where = query.build_where(Func(output_field=BooleanField())) + exact = where.children[0] + self.assertIsInstance(exact, Exact) + self.assertIsInstance(exact.lhs, Func) + self.assertIs(exact.rhs, True) + + def test_filter_conditional_join(self): + query = Query(Item) + filter_expr = Func('note__note', output_field=BooleanField()) + msg = 'Joined field references are not permitted in this query' + with self.assertRaisesMessage(FieldError, msg): + query.build_where(filter_expr)