Fixed #31902 -- Fixed crash of ExclusionConstraint on expressions with params.
This commit is contained in:
parent
e2e34f4de3
commit
bf6d07730c
|
@ -66,20 +66,21 @@ class ExclusionConstraint(BaseConstraint):
|
||||||
self.opclasses = opclasses
|
self.opclasses = opclasses
|
||||||
super().__init__(name=name)
|
super().__init__(name=name)
|
||||||
|
|
||||||
def _get_expression_sql(self, compiler, connection, query):
|
def _get_expression_sql(self, compiler, schema_editor, query):
|
||||||
expressions = []
|
expressions = []
|
||||||
for idx, (expression, operator) in enumerate(self.expressions):
|
for idx, (expression, operator) in enumerate(self.expressions):
|
||||||
if isinstance(expression, str):
|
if isinstance(expression, str):
|
||||||
expression = F(expression)
|
expression = F(expression)
|
||||||
expression = expression.resolve_expression(query=query)
|
expression = expression.resolve_expression(query=query)
|
||||||
sql, params = expression.as_sql(compiler, connection)
|
sql, params = expression.as_sql(compiler, schema_editor.connection)
|
||||||
try:
|
try:
|
||||||
opclass = self.opclasses[idx]
|
opclass = self.opclasses[idx]
|
||||||
if opclass:
|
if opclass:
|
||||||
sql = '%s %s' % (sql, opclass)
|
sql = '%s %s' % (sql, opclass)
|
||||||
except IndexError:
|
except IndexError:
|
||||||
pass
|
pass
|
||||||
expressions.append('%s WITH %s' % (sql % params, operator))
|
sql = sql % tuple(schema_editor.quote_value(p) for p in params)
|
||||||
|
expressions.append('%s WITH %s' % (sql, operator))
|
||||||
return expressions
|
return expressions
|
||||||
|
|
||||||
def _get_condition_sql(self, compiler, schema_editor, query):
|
def _get_condition_sql(self, compiler, schema_editor, query):
|
||||||
|
@ -92,7 +93,7 @@ class ExclusionConstraint(BaseConstraint):
|
||||||
def constraint_sql(self, model, schema_editor):
|
def constraint_sql(self, model, schema_editor):
|
||||||
query = Query(model, alias_cols=False)
|
query = Query(model, alias_cols=False)
|
||||||
compiler = query.get_compiler(connection=schema_editor.connection)
|
compiler = query.get_compiler(connection=schema_editor.connection)
|
||||||
expressions = self._get_expression_sql(compiler, schema_editor.connection, query)
|
expressions = self._get_expression_sql(compiler, schema_editor, query)
|
||||||
condition = self._get_condition_sql(compiler, schema_editor, query)
|
condition = self._get_condition_sql(compiler, schema_editor, query)
|
||||||
include = [model._meta.get_field(field_name).column for field_name in self.include]
|
include = [model._meta.get_field(field_name).column for field_name in self.include]
|
||||||
return self.template % {
|
return self.template % {
|
||||||
|
|
|
@ -7,6 +7,7 @@ from django.db import (
|
||||||
from django.db.models import (
|
from django.db.models import (
|
||||||
CheckConstraint, Deferrable, F, Func, Q, UniqueConstraint,
|
CheckConstraint, Deferrable, F, Func, Q, UniqueConstraint,
|
||||||
)
|
)
|
||||||
|
from django.db.models.functions import Left
|
||||||
from django.test import skipUnlessDBFeature
|
from django.test import skipUnlessDBFeature
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
|
||||||
|
@ -608,6 +609,17 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
|
||||||
editor.remove_constraint(RangesModel, constraint)
|
editor.remove_constraint(RangesModel, constraint)
|
||||||
self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
|
self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
|
||||||
|
|
||||||
|
def test_expressions_with_params(self):
|
||||||
|
constraint_name = 'scene_left_equal'
|
||||||
|
self.assertNotIn(constraint_name, self.get_constraints(Scene._meta.db_table))
|
||||||
|
constraint = ExclusionConstraint(
|
||||||
|
name=constraint_name,
|
||||||
|
expressions=[(Left('scene', 4), RangeOperators.EQUAL)],
|
||||||
|
)
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
editor.add_constraint(Scene, constraint)
|
||||||
|
self.assertIn(constraint_name, self.get_constraints(Scene._meta.db_table))
|
||||||
|
|
||||||
def test_range_adjacent_initially_deferred(self):
|
def test_range_adjacent_initially_deferred(self):
|
||||||
constraint_name = 'ints_adjacent_deferred'
|
constraint_name = 'ints_adjacent_deferred'
|
||||||
self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
|
self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
|
||||||
|
|
Loading…
Reference in New Issue