Fixed #31902 -- Fixed crash of ExclusionConstraint on expressions with params.

This commit is contained in:
Maxim Petrov 2020-08-19 05:22:12 +03:00 committed by Mariusz Felisiak
parent e2e34f4de3
commit bf6d07730c
2 changed files with 17 additions and 4 deletions

View File

@ -66,20 +66,21 @@ class ExclusionConstraint(BaseConstraint):
self.opclasses = opclasses self.opclasses = opclasses
super().__init__(name=name) super().__init__(name=name)
def _get_expression_sql(self, compiler, connection, query): def _get_expression_sql(self, compiler, schema_editor, query):
expressions = [] expressions = []
for idx, (expression, operator) in enumerate(self.expressions): for idx, (expression, operator) in enumerate(self.expressions):
if isinstance(expression, str): if isinstance(expression, str):
expression = F(expression) expression = F(expression)
expression = expression.resolve_expression(query=query) expression = expression.resolve_expression(query=query)
sql, params = expression.as_sql(compiler, connection) sql, params = expression.as_sql(compiler, schema_editor.connection)
try: try:
opclass = self.opclasses[idx] opclass = self.opclasses[idx]
if opclass: if opclass:
sql = '%s %s' % (sql, opclass) sql = '%s %s' % (sql, opclass)
except IndexError: except IndexError:
pass pass
expressions.append('%s WITH %s' % (sql % params, operator)) sql = sql % tuple(schema_editor.quote_value(p) for p in params)
expressions.append('%s WITH %s' % (sql, operator))
return expressions return expressions
def _get_condition_sql(self, compiler, schema_editor, query): def _get_condition_sql(self, compiler, schema_editor, query):
@ -92,7 +93,7 @@ class ExclusionConstraint(BaseConstraint):
def constraint_sql(self, model, schema_editor): def constraint_sql(self, model, schema_editor):
query = Query(model, alias_cols=False) query = Query(model, alias_cols=False)
compiler = query.get_compiler(connection=schema_editor.connection) compiler = query.get_compiler(connection=schema_editor.connection)
expressions = self._get_expression_sql(compiler, schema_editor.connection, query) expressions = self._get_expression_sql(compiler, schema_editor, query)
condition = self._get_condition_sql(compiler, schema_editor, query) condition = self._get_condition_sql(compiler, schema_editor, query)
include = [model._meta.get_field(field_name).column for field_name in self.include] include = [model._meta.get_field(field_name).column for field_name in self.include]
return self.template % { return self.template % {

View File

@ -7,6 +7,7 @@ from django.db import (
from django.db.models import ( from django.db.models import (
CheckConstraint, Deferrable, F, Func, Q, UniqueConstraint, CheckConstraint, Deferrable, F, Func, Q, UniqueConstraint,
) )
from django.db.models.functions import Left
from django.test import skipUnlessDBFeature from django.test import skipUnlessDBFeature
from django.utils import timezone from django.utils import timezone
@ -608,6 +609,17 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
editor.remove_constraint(RangesModel, constraint) editor.remove_constraint(RangesModel, constraint)
self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
def test_expressions_with_params(self):
constraint_name = 'scene_left_equal'
self.assertNotIn(constraint_name, self.get_constraints(Scene._meta.db_table))
constraint = ExclusionConstraint(
name=constraint_name,
expressions=[(Left('scene', 4), RangeOperators.EQUAL)],
)
with connection.schema_editor() as editor:
editor.add_constraint(Scene, constraint)
self.assertIn(constraint_name, self.get_constraints(Scene._meta.db_table))
def test_range_adjacent_initially_deferred(self): def test_range_adjacent_initially_deferred(self):
constraint_name = 'ints_adjacent_deferred' constraint_name = 'ints_adjacent_deferred'
self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))