Refs #29641 -- Extracted reusable CheckConstraint logic into a base class.

This commit is contained in:
Simon Charette 2018-08-05 22:12:27 -04:00 committed by Tim Graham
parent 9142bebff2
commit 24dc7d8940
2 changed files with 43 additions and 21 deletions

View File

@ -3,22 +3,12 @@ from django.db.models.sql.query import Query
__all__ = ['CheckConstraint'] __all__ = ['CheckConstraint']
class CheckConstraint: class BaseConstraint:
def __init__(self, *, check, name): def __init__(self, name):
self.check = check
self.name = name self.name = name
def constraint_sql(self, model, schema_editor): def constraint_sql(self, model, schema_editor):
query = Query(model) raise NotImplementedError('This method must be implemented by a subclass.')
where = query.build_where(self.check)
connection = schema_editor.connection
compiler = connection.ops.compiler('SQLCompiler')(query, connection, 'default')
sql, params = where.as_sql(compiler, connection)
params = tuple(schema_editor.quote_value(p) for p in params)
return schema_editor.sql_check % {
'name': schema_editor.quote_name(self.name),
'check': sql % params,
}
def create_sql(self, model, schema_editor): def create_sql(self, model, schema_editor):
sql = self.constraint_sql(model, schema_editor) sql = self.constraint_sql(model, schema_editor)
@ -34,6 +24,33 @@ class CheckConstraint:
'name': quote_name(self.name), 'name': quote_name(self.name),
} }
def deconstruct(self):
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
path = path.replace('django.db.models.constraints', 'django.db.models')
return (path, (), {'name': self.name})
def clone(self):
_, args, kwargs = self.deconstruct()
return self.__class__(*args, **kwargs)
class CheckConstraint(BaseConstraint):
def __init__(self, *, check, name):
self.check = check
super().__init__(name)
def constraint_sql(self, model, schema_editor):
query = Query(model)
where = query.build_where(self.check)
connection = schema_editor.connection
compiler = connection.ops.compiler('SQLCompiler')(query, connection, 'default')
sql, params = where.as_sql(compiler, connection)
params = tuple(schema_editor.quote_value(p) for p in params)
return schema_editor.sql_check % {
'name': schema_editor.quote_name(self.name),
'check': sql % params,
}
def __repr__(self): def __repr__(self):
return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name) return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name)
@ -45,10 +62,6 @@ class CheckConstraint:
) )
def deconstruct(self): def deconstruct(self):
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__) path, args, kwargs = super().deconstruct()
path = path.replace('django.db.models.constraints', 'django.db.models') kwargs['check'] = self.check
return (path, (), {'check': self.check, 'name': self.name}) return path, args, kwargs
def clone(self):
_, args, kwargs = self.deconstruct()
return self.__class__(*args, **kwargs)

View File

@ -1,9 +1,18 @@
from django.db import IntegrityError, models from django.db import IntegrityError, models
from django.test import TestCase, skipUnlessDBFeature from django.db.models.constraints import BaseConstraint
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from .models import Product from .models import Product
class BaseConstraintTests(SimpleTestCase):
def test_constraint_sql(self):
c = BaseConstraint('name')
msg = 'This method must be implemented by a subclass.'
with self.assertRaisesMessage(NotImplementedError, msg):
c.constraint_sql(None, None)
class CheckConstraintTests(TestCase): class CheckConstraintTests(TestCase):
def test_repr(self): def test_repr(self):
check = models.Q(price__gt=models.F('discounted_price')) check = models.Q(price__gt=models.F('discounted_price'))