Fixed #11964 -- Added support for database check constraints.

This commit is contained in:
Ian Foote 2016-11-05 13:12:12 +00:00 committed by Tim Graham
parent 6fbfb5cb96
commit 952f05a6db
29 changed files with 799 additions and 39 deletions

View File

@ -172,6 +172,7 @@ class BaseDatabaseFeatures:
# Does it support CHECK constraints?
supports_column_check_constraints = True
supports_table_check_constraints = True
# Does the backend support 'pyformat' style ("... %(name)s ...", {'name': value})
# parameter passing? Note this can be provided by the backend even if not

View File

@ -63,7 +63,8 @@ class BaseDatabaseSchemaEditor:
sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
sql_update_with_default = "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)"
sql_check = "CONSTRAINT %(name)s CHECK (%(check)s)"
sql_create_check = "ALTER TABLE %(table)s ADD %(check)s"
sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
sql_create_unique = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s UNIQUE (%(columns)s)"
@ -299,10 +300,11 @@ class BaseDatabaseSchemaEditor:
for fields in model._meta.unique_together:
columns = [model._meta.get_field(field).column for field in fields]
self.deferred_sql.append(self._create_unique_sql(model, columns))
constraints = [check.constraint_sql(model, self) for check in model._meta.constraints]
# Make the table
sql = self.sql_create_table % {
"table": self.quote_name(model._meta.db_table),
"definition": ", ".join(column_sqls)
"definition": ", ".join((*column_sqls, *constraints)),
}
if model._meta.db_tablespace:
tablespace_sql = self.connection.ops.tablespace_sql(model._meta.db_tablespace)
@ -343,6 +345,14 @@ class BaseDatabaseSchemaEditor:
"""Remove an index from a model."""
self.execute(index.remove_sql(model, self))
def add_constraint(self, model, constraint):
"""Add a check constraint to a model."""
self.execute(constraint.create_sql(model, self))
def remove_constraint(self, model, constraint):
"""Remove a check constraint from a model."""
self.execute(constraint.remove_sql(model, self))
def alter_unique_together(self, model, old_unique_together, new_unique_together):
"""
Deal with a model changing its unique_together. The input
@ -752,11 +762,12 @@ class BaseDatabaseSchemaEditor:
self.execute(
self.sql_create_check % {
"table": self.quote_name(model._meta.db_table),
"name": self.quote_name(
self._create_index_name(model._meta.db_table, [new_field.column], suffix="_check")
"check": self.sql_check % {
'name': self.quote_name(
self._create_index_name(model._meta.db_table, [new_field.column], suffix='_check'),
),
"column": self.quote_name(new_field.column),
"check": new_db_params['check'],
'check': new_db_params['check'],
},
}
)
# Drop the default if we need to

View File

@ -26,6 +26,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_release_savepoints = True
atomic_transactions = False
supports_column_check_constraints = False
supports_table_check_constraints = False
can_clone_databases = True
supports_temporal_subtraction = True
supports_select_intersection = False

View File

@ -126,7 +126,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
else:
super().alter_field(model, old_field, new_field, strict=strict)
def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None):
def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None,
add_constraint=None, remove_constraint=None):
"""
Shortcut to transform a model from old_model into new_model
@ -222,6 +223,15 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
if delete_field.name not in index.fields
]
constraints = list(model._meta.constraints)
if add_constraint:
constraints.append(add_constraint)
if remove_constraint:
constraints = [
constraint for constraint in constraints
if remove_constraint.name != constraint.name
]
# Construct a new model for the new state
meta_contents = {
'app_label': model._meta.app_label,
@ -229,6 +239,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
'unique_together': unique_together,
'index_together': index_together,
'indexes': indexes,
'constraints': constraints,
'apps': apps,
}
meta = type("Meta", (), meta_contents)
@ -362,3 +373,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
))
# Delete the old through table
self.delete_model(old_field.remote_field.through)
def add_constraint(self, model, constraint):
self._remake_table(model, add_constraint=constraint)
def remove_constraint(self, model, constraint):
self._remake_table(model, remove_constraint=constraint)

View File

@ -122,6 +122,7 @@ class MigrationAutodetector:
# resolve dependencies caused by M2Ms and FKs.
self.generated_operations = {}
self.altered_indexes = {}
self.altered_constraints = {}
# Prepare some old/new state and model lists, separating
# proxy models and ignoring unmigrated apps.
@ -175,7 +176,9 @@ class MigrationAutodetector:
# This avoids the same computation in generate_removed_indexes()
# and generate_added_indexes().
self.create_altered_indexes()
self.create_altered_constraints()
# Generate index removal operations before field is removed
self.generate_removed_constraints()
self.generate_removed_indexes()
# Generate field operations
self.generate_renamed_fields()
@ -185,6 +188,7 @@ class MigrationAutodetector:
self.generate_altered_unique_together()
self.generate_altered_index_together()
self.generate_added_indexes()
self.generate_added_constraints()
self.generate_altered_db_table()
self.generate_altered_order_with_respect_to()
@ -533,6 +537,7 @@ class MigrationAutodetector:
related_fields[field.name] = field
# Are there indexes/unique|index_together to defer?
indexes = model_state.options.pop('indexes')
constraints = model_state.options.pop('constraints')
unique_together = model_state.options.pop('unique_together', None)
index_together = model_state.options.pop('index_together', None)
order_with_respect_to = model_state.options.pop('order_with_respect_to', None)
@ -601,6 +606,15 @@ class MigrationAutodetector:
),
dependencies=related_dependencies,
)
for constraint in constraints:
self.add_operation(
app_label,
operations.AddConstraint(
model_name=model_name,
constraint=constraint,
),
dependencies=related_dependencies,
)
if unique_together:
self.add_operation(
app_label,
@ -997,6 +1011,46 @@ class MigrationAutodetector:
)
)
def create_altered_constraints(self):
option_name = operations.AddConstraint.option_name
for app_label, model_name in sorted(self.kept_model_keys):
old_model_name = self.renamed_models.get((app_label, model_name), model_name)
old_model_state = self.from_state.models[app_label, old_model_name]
new_model_state = self.to_state.models[app_label, model_name]
old_constraints = old_model_state.options[option_name]
new_constraints = new_model_state.options[option_name]
add_constraints = [c for c in new_constraints if c not in old_constraints]
rem_constraints = [c for c in old_constraints if c not in new_constraints]
self.altered_constraints.update({
(app_label, model_name): {
'added_constraints': add_constraints, 'removed_constraints': rem_constraints,
}
})
def generate_added_constraints(self):
for (app_label, model_name), alt_constraints in self.altered_constraints.items():
for constraint in alt_constraints['added_constraints']:
self.add_operation(
app_label,
operations.AddConstraint(
model_name=model_name,
constraint=constraint,
)
)
def generate_removed_constraints(self):
for (app_label, model_name), alt_constraints in self.altered_constraints.items():
for constraint in alt_constraints['removed_constraints']:
self.add_operation(
app_label,
operations.RemoveConstraint(
model_name=model_name,
name=constraint.name,
)
)
def _get_dependencies_for_foreign_key(self, field):
# Account for FKs to swappable models
swappable_setting = getattr(field, 'swappable_setting', None)

View File

@ -1,8 +1,9 @@
from .fields import AddField, AlterField, RemoveField, RenameField
from .models import (
AddIndex, AlterIndexTogether, AlterModelManagers, AlterModelOptions,
AlterModelTable, AlterOrderWithRespectTo, AlterUniqueTogether, CreateModel,
DeleteModel, RemoveIndex, RenameModel,
AddConstraint, AddIndex, AlterIndexTogether, AlterModelManagers,
AlterModelOptions, AlterModelTable, AlterOrderWithRespectTo,
AlterUniqueTogether, CreateModel, DeleteModel, RemoveConstraint,
RemoveIndex, RenameModel,
)
from .special import RunPython, RunSQL, SeparateDatabaseAndState
@ -10,6 +11,7 @@ __all__ = [
'CreateModel', 'DeleteModel', 'AlterModelTable', 'AlterUniqueTogether',
'RenameModel', 'AlterIndexTogether', 'AlterModelOptions', 'AddIndex',
'RemoveIndex', 'AddField', 'RemoveField', 'AlterField', 'RenameField',
'AddConstraint', 'RemoveConstraint',
'SeparateDatabaseAndState', 'RunSQL', 'RunPython',
'AlterOrderWithRespectTo', 'AlterModelManagers',
]

View File

@ -822,3 +822,72 @@ class RemoveIndex(IndexOperation):
def describe(self):
return 'Remove index %s from %s' % (self.name, self.model_name)
class AddConstraint(IndexOperation):
option_name = 'constraints'
def __init__(self, model_name, constraint):
self.model_name = model_name
self.constraint = constraint
def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.model_name_lower]
constraints = list(model_state.options[self.option_name])
constraints.append(self.constraint)
model_state.options[self.option_name] = constraints
def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
schema_editor.add_constraint(model, self.constraint)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
schema_editor.remove_constraint(model, self.constraint)
def deconstruct(self):
return self.__class__.__name__, [], {
'model_name': self.model_name,
'constraint': self.constraint,
}
def describe(self):
return 'Create constraint %s on model %s' % (self.constraint.name, self.model_name)
class RemoveConstraint(IndexOperation):
option_name = 'constraints'
def __init__(self, model_name, name):
self.model_name = model_name
self.name = name
def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.model_name_lower]
constraints = model_state.options[self.option_name]
model_state.options[self.option_name] = [c for c in constraints if c.name != self.name]
def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = from_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
from_model_state = from_state.models[app_label, self.model_name_lower]
constraint = from_model_state.get_constraint_by_name(self.name)
schema_editor.remove_constraint(model, constraint)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
to_model_state = to_state.models[app_label, self.model_name_lower]
constraint = to_model_state.get_constraint_by_name(self.name)
schema_editor.add_constraint(model, constraint)
def deconstruct(self):
return self.__class__.__name__, [], {
'model_name': self.model_name,
'name': self.name,
}
def describe(self):
return 'Remove constraint %s from model %s' % (self.name, self.model_name)

View File

@ -362,6 +362,7 @@ class ModelState:
self.fields = fields
self.options = options or {}
self.options.setdefault('indexes', [])
self.options.setdefault('constraints', [])
self.bases = bases or (models.Model,)
self.managers = managers or []
# Sanity-check that fields is NOT a dict. It must be ordered.
@ -445,6 +446,8 @@ class ModelState:
if not index.name:
index.set_name_with_model(model)
options['indexes'] = indexes
elif name == 'constraints':
options['constraints'] = [con.clone() for con in model._meta.constraints]
else:
options[name] = model._meta.original_attrs[name]
# If we're ignoring relationships, remove all field-listing model
@ -585,6 +588,12 @@ class ModelState:
return index
raise ValueError("No index named %s on model %s" % (name, self.name))
def get_constraint_by_name(self, name):
for constraint in self.options['constraints']:
if constraint.name == name:
return constraint
raise ValueError('No constraint named %s on model %s' % (name, self.name))
def __repr__(self):
return "<%s: '%s.%s'>" % (self.__class__.__name__, self.app_label, self.name)

View File

@ -2,6 +2,8 @@ from django.core.exceptions import ObjectDoesNotExist
from django.db.models import signals
from django.db.models.aggregates import * # NOQA
from django.db.models.aggregates import __all__ as aggregates_all
from django.db.models.constraints import * # NOQA
from django.db.models.constraints import __all__ as constraints_all
from django.db.models.deletion import (
CASCADE, DO_NOTHING, PROTECT, SET, SET_DEFAULT, SET_NULL, ProtectedError,
)
@ -30,7 +32,7 @@ from django.db.models.fields.related import ( # isort:skip
)
__all__ = aggregates_all + fields_all + indexes_all
__all__ = aggregates_all + constraints_all + fields_all + indexes_all
__all__ += [
'ObjectDoesNotExist', 'signals',
'CASCADE', 'DO_NOTHING', 'PROTECT', 'SET', 'SET_DEFAULT', 'SET_NULL',

View File

@ -16,6 +16,7 @@ from django.db import (
connections, router, transaction,
)
from django.db.models.constants import LOOKUP_SEP
from django.db.models.constraints import CheckConstraint
from django.db.models.deletion import CASCADE, Collector
from django.db.models.fields.related import (
ForeignObjectRel, OneToOneField, lazy_related_operation, resolve_relation,
@ -1201,6 +1202,7 @@ class Model(metaclass=ModelBase):
*cls._check_unique_together(),
*cls._check_indexes(),
*cls._check_ordering(),
*cls._check_constraints(),
]
return errors
@ -1699,6 +1701,29 @@ class Model(metaclass=ModelBase):
return errors
@classmethod
def _check_constraints(cls):
errors = []
for db in settings.DATABASES:
if not router.allow_migrate_model(db, cls):
continue
connection = connections[db]
if connection.features.supports_table_check_constraints:
continue
if any(isinstance(constraint, CheckConstraint) for constraint in cls._meta.constraints):
errors.append(
checks.Warning(
'%s does not support check constraints.' % connection.display_name,
hint=(
"A constraint won't be created. Silence this "
"warning if you don't care about it."
),
obj=cls,
id='models.W027',
)
)
return errors
############################################
# HELPER FUNCTIONS (CURRIED MODEL METHODS) #

View File

@ -0,0 +1,54 @@
from django.db.models.sql.query import Query
__all__ = ['CheckConstraint']
class CheckConstraint:
def __init__(self, constraint, name):
self.constraint = constraint
self.name = name
def constraint_sql(self, model, schema_editor):
query = Query(model)
where = query.build_where(self.constraint)
connection = schema_editor.connection
compiler = connection.ops.compiler('SQLCompiler')(query, connection, 'default')
sql, params = where.as_sql(compiler, connection)
params = tuple(schema_editor.quote_value(p) for p in params)
return schema_editor.sql_check % {
'name': schema_editor.quote_name(self.name),
'check': sql % params,
}
def create_sql(self, model, schema_editor):
sql = self.constraint_sql(model, schema_editor)
return schema_editor.sql_create_check % {
'table': schema_editor.quote_name(model._meta.db_table),
'check': sql,
}
def remove_sql(self, model, schema_editor):
quote_name = schema_editor.quote_name
return schema_editor.sql_delete_check % {
'table': quote_name(model._meta.db_table),
'name': quote_name(self.name),
}
def __repr__(self):
return "<%s: constraint='%s' name='%s'>" % (self.__class__.__name__, self.constraint, self.name)
def __eq__(self, other):
return (
isinstance(other, CheckConstraint) and
self.name == other.name and
self.constraint == other.constraint
)
def deconstruct(self):
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
path = path.replace('django.db.models.constraints', 'django.db.models')
return (path, (), {'constraint': self.constraint, 'name': self.name})
def clone(self):
_, args, kwargs = self.deconstruct()
return self.__class__(*args, **kwargs)

View File

@ -505,8 +505,9 @@ class F(Combinable):
def __repr__(self):
return "{}({})".format(self.__class__.__name__, self.name)
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
return query.resolve_ref(self.name, allow_joins, reuse, summarize)
def resolve_expression(self, query=None, allow_joins=True, reuse=None,
summarize=False, for_save=False, simple_col=False):
return query.resolve_ref(self.name, allow_joins, reuse, summarize, simple_col)
def asc(self, **kwargs):
return OrderBy(self, **kwargs)
@ -542,7 +543,8 @@ class ResolvedOuterRef(F):
class OuterRef(F):
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
def resolve_expression(self, query=None, allow_joins=True, reuse=None,
summarize=False, for_save=False, simple_col=False):
if isinstance(self.name, self.__class__):
return self.name
return ResolvedOuterRef(self.name)
@ -746,6 +748,40 @@ class Col(Expression):
self.target.get_db_converters(connection))
class SimpleCol(Expression):
"""
Represents the SQL of a column name without the table name.
This variant of Col doesn't include the table name (or an alias) to
avoid a syntax error in check constraints.
"""
contains_column_references = True
def __init__(self, target, output_field=None):
if output_field is None:
output_field = target
super().__init__(output_field=output_field)
self.target = target
def __repr__(self):
return '{}({})'.format(self.__class__.__name__, self.target)
def as_sql(self, compiler, connection):
qn = compiler.quote_name_unless_alias
return qn(self.target.column), []
def get_group_by_cols(self):
return [self]
def get_db_converters(self, connection):
if self.target == self.output_field:
return self.output_field.get_db_converters(connection)
return (
self.output_field.get_db_converters(connection) +
self.target.get_db_converters(connection)
)
class Ref(Expression):
"""
Reference to column alias of the query. For example, Ref('sum_cost') in

View File

@ -32,7 +32,7 @@ DEFAULT_NAMES = (
'auto_created', 'index_together', 'apps', 'default_permissions',
'select_on_save', 'default_related_name', 'required_db_features',
'required_db_vendor', 'base_manager_name', 'default_manager_name',
'indexes',
'indexes', 'constraints',
# For backwards compatibility with Django 1.11. RemovedInDjango30Warning
'manager_inheritance_from_future',
)
@ -89,6 +89,7 @@ class Options:
self.ordering = []
self._ordering_clash = False
self.indexes = []
self.constraints = []
self.unique_together = []
self.index_together = []
self.select_on_save = False

View File

@ -18,7 +18,7 @@ from django.core.exceptions import (
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
from django.db.models.aggregates import Count
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Col, Ref
from django.db.models.expressions import Col, F, Ref, SimpleCol
from django.db.models.fields import Field
from django.db.models.fields.related_lookups import MultiColSource
from django.db.models.lookups import Lookup
@ -62,6 +62,12 @@ JoinInfo = namedtuple(
)
def _get_col(target, field, alias, simple_col):
if simple_col:
return SimpleCol(target, field)
return target.get_col(alias, field)
class RawQuery:
"""A single raw SQL query."""
@ -1011,14 +1017,23 @@ class Query:
def as_sql(self, compiler, connection):
return self.get_compiler(connection=connection).as_sql()
def resolve_lookup_value(self, value, can_reuse, allow_joins):
def resolve_lookup_value(self, value, can_reuse, allow_joins, simple_col):
if hasattr(value, 'resolve_expression'):
value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
kwargs = {'reuse': can_reuse, 'allow_joins': allow_joins}
if isinstance(value, F):
kwargs['simple_col'] = simple_col
value = value.resolve_expression(self, **kwargs)
elif isinstance(value, (list, tuple)):
# The items of the iterable may be expressions and therefore need
# to be resolved independently.
for sub_value in value:
if hasattr(sub_value, 'resolve_expression'):
if isinstance(sub_value, F):
sub_value.resolve_expression(
self, reuse=can_reuse, allow_joins=allow_joins,
simple_col=simple_col,
)
else:
sub_value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
return value
@ -1133,7 +1148,7 @@ class Query:
def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
can_reuse=None, allow_joins=True, split_subq=True,
reuse_with_filtered_relation=False):
reuse_with_filtered_relation=False, simple_col=False):
"""
Build a WhereNode for a single filter clause but don't add it
to this Query. Query.add_q() will then add this filter to the where
@ -1179,7 +1194,7 @@ class Query:
raise FieldError("Joined field references are not permitted in this query")
pre_joins = self.alias_refcount.copy()
value = self.resolve_lookup_value(value, can_reuse, allow_joins)
value = self.resolve_lookup_value(value, can_reuse, allow_joins, simple_col)
used_joins = {k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)}
clause = self.where_class()
@ -1222,11 +1237,11 @@ class Query:
if num_lookups > 1:
raise FieldError('Related Field got invalid lookup: {}'.format(lookups[0]))
if len(targets) == 1:
col = targets[0].get_col(alias, join_info.final_field)
col = _get_col(targets[0], join_info.final_field, alias, simple_col)
else:
col = MultiColSource(alias, targets, join_info.targets, join_info.final_field)
else:
col = targets[0].get_col(alias, join_info.final_field)
col = _get_col(targets[0], join_info.final_field, alias, simple_col)
condition = self.build_lookup(lookups, col, value)
lookup_type = condition.lookup_name
@ -1248,7 +1263,8 @@ class Query:
# <=>
# NOT (col IS NOT NULL AND col = someval).
lookup_class = targets[0].get_lookup('isnull')
clause.add(lookup_class(targets[0].get_col(alias, join_info.targets[0]), False), AND)
col = _get_col(targets[0], join_info.targets[0], alias, simple_col)
clause.add(lookup_class(col, False), AND)
return clause, used_joins if not require_outer else ()
def add_filter(self, filter_clause):
@ -1271,8 +1287,12 @@ class Query:
self.where.add(clause, AND)
self.demote_joins(existing_inner)
def build_where(self, q_object):
return self._add_q(q_object, used_aliases=set(), allow_joins=False, simple_col=True)[0]
def _add_q(self, q_object, used_aliases, branch_negated=False,
current_negated=False, allow_joins=True, split_subq=True):
current_negated=False, allow_joins=True, split_subq=True,
simple_col=False):
"""Add a Q-object to the current filter."""
connector = q_object.connector
current_negated = current_negated ^ q_object.negated
@ -1290,7 +1310,7 @@ class Query:
child_clause, needed_inner = self.build_filter(
child, can_reuse=used_aliases, branch_negated=branch_negated,
current_negated=current_negated, allow_joins=allow_joins,
split_subq=split_subq,
split_subq=split_subq, simple_col=simple_col,
)
joinpromoter.add_votes(needed_inner)
if child_clause:
@ -1559,7 +1579,7 @@ class Query:
self.unref_alias(joins.pop())
return targets, joins[-1], joins
def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False):
def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False, simple_col=False):
if not allow_joins and LOOKUP_SEP in name:
raise FieldError("Joined field references are not permitted in this query")
if name in self.annotations:
@ -1580,7 +1600,7 @@ class Query:
"isn't supported")
if reuse is not None:
reuse.update(join_list)
col = targets[0].get_col(join_list[-1], join_info.targets[0])
col = _get_col(targets[0], join_info.targets[0], join_list[-1], simple_col)
return col
def split_exclude(self, filter_expr, can_reuse, names_with_path):

View File

@ -297,6 +297,7 @@ Models
field accessor.
* **models.E026**: The model cannot have more than one field with
``primary_key=True``.
* **models.W027**: ``<database>`` does not support check constraints.
Security
--------

View File

@ -207,6 +207,25 @@ Creates an index in the database table for the model with ``model_name``.
Removes the index named ``name`` from the model with ``model_name``.
``AddConstraint``
-----------------
.. class:: AddConstraint(model_name, constraint)
.. versionadded:: 2.2
Creates a constraint in the database table for the model with ``model_name``.
``constraint`` is an instance of :class:`~django.db.models.CheckConstraint`.
``RemoveConstraint``
--------------------
.. class:: RemoveConstraint(model_name, name)
.. versionadded:: 2.2
Removes the constraint named ``name`` from the model with ``model_name``.
Special Operations
==================

View File

@ -0,0 +1,46 @@
===========================
Check constraints reference
===========================
.. module:: django.db.models.constraints
.. currentmodule:: django.db.models
.. versionadded:: 2.2
The ``CheckConstraint`` class creates database check constraints. They are
added in the model :attr:`Meta.constraints
<django.db.models.Options.constraints>` option. This document
explains the API references of :class:`CheckConstraint`.
.. admonition:: Referencing built-in constraints
Constraints are defined in ``django.db.models.constraints``, but for
convenience they're imported into :mod:`django.db.models`. The standard
convention is to use ``from django.db import models`` and refer to the
constraints as ``models.CheckConstraint``.
``CheckConstraint`` options
===========================
.. class:: CheckConstraint(constraint, name)
Creates a check constraint in the database.
``constraint``
--------------
.. attribute:: CheckConstraint.constraint
A :class:`Q` object that specifies the condition you want the constraint to
enforce.
For example ``CheckConstraint(Q(age__gte=18), 'age_gte_18')`` ensures the age
field is never less than 18.
``name``
--------
.. attribute:: CheckConstraint.name
The name of the constraint.

View File

@ -9,6 +9,7 @@ Model API reference. For introductory material, see :doc:`/topics/db/models`.
fields
indexes
check-constraints
meta
relations
class

View File

@ -451,6 +451,26 @@ Django quotes column and table names behind the scenes.
index_together = ["pub_date", "deadline"]
``constraints``
---------------
.. attribute:: Options.constraints
.. versionadded:: 2.2
A list of :doc:`constraints </ref/models/check-constraints>` that you want
to define on the model::
from django.db import models
class Customer(models.Model):
age = models.IntegerField()
class Meta:
constraints = [
models.CheckConstraint(models.Q(age__gte=18), 'age_gte_18'),
]
``verbose_name``
----------------

View File

@ -30,6 +30,13 @@ officially support the latest release of each series.
What's new in Django 2.2
========================
Check Constraints
-----------------
The new :class:`~django.db.models.CheckConstraint` class enables adding custom
database constraints. Constraints are added to models using the
:attr:`Meta.constraints <django.db.models.Options.constraints>` option.
Minor features
--------------
@ -213,7 +220,9 @@ Backwards incompatible changes in 2.2
Database backend API
--------------------
* ...
* Third-party database backends must implement support for table check
constraints or set ``DatabaseFeatures.supports_table_check_constraints`` to
``False``.
:mod:`django.contrib.gis`
-------------------------

View File

View File

@ -0,0 +1,15 @@
from django.db import models
class Product(models.Model):
name = models.CharField(max_length=255)
price = models.IntegerField()
discounted_price = models.IntegerField()
class Meta:
constraints = [
models.CheckConstraint(
models.Q(price__gt=models.F('discounted_price')),
'price_gt_discounted_price'
)
]

View File

@ -0,0 +1,30 @@
from django.db import IntegrityError, models
from django.test import TestCase, skipUnlessDBFeature
from .models import Product
class CheckConstraintTests(TestCase):
def test_repr(self):
constraint = models.Q(price__gt=models.F('discounted_price'))
name = 'price_gt_discounted_price'
check = models.CheckConstraint(constraint, name)
self.assertEqual(
repr(check),
"<CheckConstraint: constraint='{}' name='{}'>".format(constraint, name),
)
def test_deconstruction(self):
constraint = models.Q(price__gt=models.F('discounted_price'))
name = 'price_gt_discounted_price'
check = models.CheckConstraint(constraint, name)
path, args, kwargs = check.deconstruct()
self.assertEqual(path, 'django.db.models.CheckConstraint')
self.assertEqual(args, ())
self.assertEqual(kwargs, {'constraint': constraint, 'name': name})
@skipUnlessDBFeature('supports_table_check_constraints')
def test_model_constraint(self):
Product.objects.create(name='Valid', price=10, discounted_price=5)
with self.assertRaises(IntegrityError):
Product.objects.create(name='Invalid', price=10, discounted_price=20)

View File

@ -1,10 +1,10 @@
import unittest
from django.conf import settings
from django.core.checks import Error
from django.core.checks import Error, Warning
from django.core.checks.model_checks import _check_lazy_references
from django.core.exceptions import ImproperlyConfigured
from django.db import connections, models
from django.db import connection, connections, models
from django.db.models.signals import post_init
from django.test import SimpleTestCase
from django.test.utils import isolate_apps, override_settings
@ -972,3 +972,26 @@ class OtherModelTests(SimpleTestCase):
id='signals.E001',
),
])
@isolate_apps('invalid_models_tests')
class ConstraintsTests(SimpleTestCase):
def test_check_constraints(self):
class Model(models.Model):
age = models.IntegerField()
class Meta:
constraints = [models.CheckConstraint(models.Q(age__gte=18), 'is_adult')]
errors = Model.check()
warn = Warning(
'%s does not support check constraints.' % connection.display_name,
hint=(
"A constraint won't be created. Silence this warning if you "
"don't care about it."
),
obj=Model,
id='models.W027',
)
expected = [] if connection.features.supports_table_check_constraints else [warn, warn]
self.assertCountEqual(errors, expected)

View File

@ -61,6 +61,12 @@ class AutodetectorTests(TestCase):
("id", models.AutoField(primary_key=True)),
("name", models.CharField(max_length=200, default='Ada Lovelace')),
])
author_name_check_constraint = ModelState("testapp", "Author", [
("id", models.AutoField(primary_key=True)),
("name", models.CharField(max_length=200)),
],
{'constraints': [models.CheckConstraint(models.Q(name__contains='Bob'), 'name_contains_bob')]},
)
author_dates_of_birth_auto_now = ModelState("testapp", "Author", [
("id", models.AutoField(primary_key=True)),
("date_of_birth", models.DateField(auto_now=True)),
@ -1389,6 +1395,40 @@ class AutodetectorTests(TestCase):
added_index = models.Index(fields=['title', 'author'], name='book_author_title_idx')
self.assertOperationAttributes(changes, 'otherapp', 0, 1, model_name='book', index=added_index)
def test_create_model_with_check_constraint(self):
"""Test creation of new model with constraints already defined."""
author = ModelState('otherapp', 'Author', [
('id', models.AutoField(primary_key=True)),
('name', models.CharField(max_length=200)),
], {'constraints': [models.CheckConstraint(models.Q(name__contains='Bob'), 'name_contains_bob')]})
changes = self.get_changes([], [author])
added_constraint = models.CheckConstraint(models.Q(name__contains='Bob'), 'name_contains_bob')
# Right number of migrations?
self.assertEqual(len(changes['otherapp']), 1)
# Right number of actions?
migration = changes['otherapp'][0]
self.assertEqual(len(migration.operations), 2)
# Right actions order?
self.assertOperationTypes(changes, 'otherapp', 0, ['CreateModel', 'AddConstraint'])
self.assertOperationAttributes(changes, 'otherapp', 0, 0, name='Author')
self.assertOperationAttributes(changes, 'otherapp', 0, 1, model_name='author', constraint=added_constraint)
def test_add_constraints(self):
"""Test change detection of new constraints."""
changes = self.get_changes([self.author_name], [self.author_name_check_constraint])
self.assertNumberMigrations(changes, 'testapp', 1)
self.assertOperationTypes(changes, 'testapp', 0, ['AddConstraint'])
added_constraint = models.CheckConstraint(models.Q(name__contains='Bob'), 'name_contains_bob')
self.assertOperationAttributes(changes, 'testapp', 0, 0, model_name='author', constraint=added_constraint)
def test_remove_constraints(self):
"""Test change detection of removed constraints."""
changes = self.get_changes([self.author_name_check_constraint], [self.author_name])
# Right number/type of migrations?
self.assertNumberMigrations(changes, 'testapp', 1)
self.assertOperationTypes(changes, 'testapp', 0, ['RemoveConstraint'])
self.assertOperationAttributes(changes, 'testapp', 0, 0, model_name='author', name='name_contains_bob')
def test_add_foo_together(self):
"""Tests index/unique_together detection."""
changes = self.get_changes([self.author_empty, self.book], [self.author_empty, self.book_foo_together])
@ -1520,7 +1560,7 @@ class AutodetectorTests(TestCase):
self.assertNumberMigrations(changes, "testapp", 1)
self.assertOperationTypes(changes, "testapp", 0, ["CreateModel"])
self.assertOperationAttributes(
changes, "testapp", 0, 0, name="AuthorProxy", options={"proxy": True, "indexes": []}
changes, "testapp", 0, 0, name="AuthorProxy", options={"proxy": True, "indexes": [], "constraints": []}
)
# Now, we test turning a proxy model into a non-proxy model
# It should delete the proxy then make the real one

View File

@ -67,6 +67,17 @@ class MigrationTestBase(TransactionTestCase):
def assertIndexNotExists(self, table, columns):
return self.assertIndexExists(table, columns, False)
def assertConstraintExists(self, table, name, value=True, using='default'):
with connections[using].cursor() as cursor:
constraints = connections[using].introspection.get_constraints(cursor, table).items()
self.assertEqual(
value,
any(c['check'] for n, c in constraints if n == name),
)
def assertConstraintNotExists(self, table, name):
return self.assertConstraintExists(table, name, False)
def assertFKExists(self, table, columns, to, value=True, using='default'):
with connections[using].cursor() as cursor:
self.assertEqual(

View File

@ -53,7 +53,7 @@ class OperationTestBase(MigrationTestBase):
def set_up_test_model(
self, app_label, second_model=False, third_model=False, index=False, multicol_index=False,
related_model=False, mti_model=False, proxy_model=False, manager_model=False,
unique_together=False, options=False, db_table=None, index_together=False):
unique_together=False, options=False, db_table=None, index_together=False, check_constraint=False):
"""
Creates a test model state and database table.
"""
@ -106,6 +106,11 @@ class OperationTestBase(MigrationTestBase):
"Pony",
models.Index(fields=["pink", "weight"], name="pony_test_idx")
))
if check_constraint:
operations.append(migrations.AddConstraint(
"Pony",
models.CheckConstraint(models.Q(pink__gt=2), name="pony_test_constraint")
))
if second_model:
operations.append(migrations.CreateModel(
"Stable",
@ -462,6 +467,45 @@ class OperationTests(OperationTestBase):
self.assertTableNotExists("test_crummo_unmanagedpony")
self.assertTableExists("test_crummo_pony")
@skipUnlessDBFeature('supports_table_check_constraints')
def test_create_model_with_constraint(self):
where = models.Q(pink__gt=2)
check_constraint = models.CheckConstraint(where, name='test_constraint_pony_pink_gt_2')
operation = migrations.CreateModel(
"Pony",
[
("id", models.AutoField(primary_key=True)),
("pink", models.IntegerField(default=3)),
],
options={'constraints': [check_constraint]},
)
# Test the state alteration
project_state = ProjectState()
new_state = project_state.clone()
operation.state_forwards("test_crmo", new_state)
self.assertEqual(len(new_state.models['test_crmo', 'pony'].options['constraints']), 1)
# Test database alteration
self.assertTableNotExists("test_crmo_pony")
with connection.schema_editor() as editor:
operation.database_forwards("test_crmo", editor, project_state, new_state)
self.assertTableExists("test_crmo_pony")
with connection.cursor() as cursor:
with self.assertRaises(IntegrityError):
cursor.execute("INSERT INTO test_crmo_pony (id, pink) VALUES (1, 1)")
# Test reversal
with connection.schema_editor() as editor:
operation.database_backwards("test_crmo", editor, new_state, project_state)
self.assertTableNotExists("test_crmo_pony")
# Test deconstruction
definition = operation.deconstruct()
self.assertEqual(definition[0], "CreateModel")
self.assertEqual(definition[1], [])
self.assertEqual(definition[2]['options']['constraints'], [check_constraint])
def test_create_model_managers(self):
"""
The managers on a model are set.
@ -1708,6 +1752,87 @@ class OperationTests(OperationTestBase):
operation = migrations.AlterIndexTogether("Pony", None)
self.assertEqual(operation.describe(), "Alter index_together for Pony (0 constraint(s))")
@skipUnlessDBFeature('supports_table_check_constraints')
def test_add_constraint(self):
"""Test the AddConstraint operation."""
project_state = self.set_up_test_model('test_addconstraint')
where = models.Q(pink__gt=2)
check_constraint = models.CheckConstraint(where, name='test_constraint_pony_pink_gt_2')
operation = migrations.AddConstraint('Pony', check_constraint)
self.assertEqual(operation.describe(), 'Create constraint test_constraint_pony_pink_gt_2 on model Pony')
new_state = project_state.clone()
operation.state_forwards('test_addconstraint', new_state)
self.assertEqual(len(new_state.models['test_addconstraint', 'pony'].options['constraints']), 1)
# Test database alteration
with connection.cursor() as cursor:
with atomic():
cursor.execute("INSERT INTO test_addconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
cursor.execute("DELETE FROM test_addconstraint_pony")
with connection.schema_editor() as editor:
operation.database_forwards("test_addconstraint", editor, project_state, new_state)
with connection.cursor() as cursor:
with self.assertRaises(IntegrityError):
cursor.execute("INSERT INTO test_addconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
# Test reversal
with connection.schema_editor() as editor:
operation.database_backwards("test_addconstraint", editor, new_state, project_state)
with connection.cursor() as cursor:
with atomic():
cursor.execute("INSERT INTO test_addconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
cursor.execute("DELETE FROM test_addconstraint_pony")
# Test deconstruction
definition = operation.deconstruct()
self.assertEqual(definition[0], "AddConstraint")
self.assertEqual(definition[1], [])
self.assertEqual(definition[2], {'model_name': "Pony", 'constraint': check_constraint})
@skipUnlessDBFeature('supports_table_check_constraints')
def test_remove_constraint(self):
"""Test the RemoveConstraint operation."""
project_state = self.set_up_test_model("test_removeconstraint", check_constraint=True)
self.assertTableExists("test_removeconstraint_pony")
operation = migrations.RemoveConstraint("Pony", "pony_test_constraint")
self.assertEqual(operation.describe(), "Remove constraint pony_test_constraint from model Pony")
new_state = project_state.clone()
operation.state_forwards("test_removeconstraint", new_state)
# Test state alteration
self.assertEqual(len(new_state.models["test_removeconstraint", "pony"].options['constraints']), 0)
with connection.cursor() as cursor:
with self.assertRaises(IntegrityError):
cursor.execute("INSERT INTO test_removeconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
# Test database alteration
with connection.schema_editor() as editor:
operation.database_forwards("test_removeconstraint", editor, project_state, new_state)
with connection.cursor() as cursor:
with atomic():
cursor.execute("INSERT INTO test_removeconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
cursor.execute("DELETE FROM test_removeconstraint_pony")
# Test reversal
with connection.schema_editor() as editor:
operation.database_backwards("test_removeconstraint", editor, new_state, project_state)
with connection.cursor() as cursor:
with self.assertRaises(IntegrityError):
cursor.execute("INSERT INTO test_removeconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)")
# Test deconstruction
definition = operation.deconstruct()
self.assertEqual(definition[0], "RemoveConstraint")
self.assertEqual(definition[1], [])
self.assertEqual(definition[2], {'model_name': "Pony", 'name': "pony_test_constraint"})
def test_alter_model_options(self):
"""
Tests the AlterModelOptions operation.

View File

@ -127,7 +127,12 @@ class StateTests(SimpleTestCase):
self.assertIs(author_state.fields[3][1].null, True)
self.assertEqual(
author_state.options,
{"unique_together": {("name", "bio")}, "index_together": {("bio", "age")}, "indexes": []}
{
"unique_together": {("name", "bio")},
"index_together": {("bio", "age")},
"indexes": [],
"constraints": [],
}
)
self.assertEqual(author_state.bases, (models.Model,))
@ -139,14 +144,17 @@ class StateTests(SimpleTestCase):
self.assertEqual(book_state.fields[3][1].__class__.__name__, "ManyToManyField")
self.assertEqual(
book_state.options,
{"verbose_name": "tome", "db_table": "test_tome", "indexes": [book_index]},
{"verbose_name": "tome", "db_table": "test_tome", "indexes": [book_index], "constraints": []},
)
self.assertEqual(book_state.bases, (models.Model,))
self.assertEqual(author_proxy_state.app_label, "migrations")
self.assertEqual(author_proxy_state.name, "AuthorProxy")
self.assertEqual(author_proxy_state.fields, [])
self.assertEqual(author_proxy_state.options, {"proxy": True, "ordering": ["name"], "indexes": []})
self.assertEqual(
author_proxy_state.options,
{"proxy": True, "ordering": ["name"], "indexes": [], "constraints": []},
)
self.assertEqual(author_proxy_state.bases, ("migrations.author",))
self.assertEqual(sub_author_state.app_label, "migrations")
@ -1002,7 +1010,7 @@ class ModelStateTests(SimpleTestCase):
self.assertEqual(author_state.fields[1][1].max_length, 255)
self.assertIs(author_state.fields[2][1].null, False)
self.assertIs(author_state.fields[3][1].null, True)
self.assertEqual(author_state.options, {'swappable': 'TEST_SWAPPABLE_MODEL', 'indexes': []})
self.assertEqual(author_state.options, {'swappable': 'TEST_SWAPPABLE_MODEL', 'indexes': [], "constraints": []})
self.assertEqual(author_state.bases, (models.Model,))
self.assertEqual(author_state.managers, [])
@ -1047,7 +1055,7 @@ class ModelStateTests(SimpleTestCase):
self.assertEqual(station_state.fields[2][1].null, False)
self.assertEqual(
station_state.options,
{'abstract': False, 'swappable': 'TEST_SWAPPABLE_MODEL', 'indexes': []}
{'abstract': False, 'swappable': 'TEST_SWAPPABLE_MODEL', 'indexes': [], 'constraints': []}
)
self.assertEqual(station_state.bases, ('migrations.searchablelocation',))
self.assertEqual(station_state.managers, [])
@ -1129,6 +1137,21 @@ class ModelStateTests(SimpleTestCase):
index_names = [index.name for index in model_state.options['indexes']]
self.assertEqual(index_names, ['foo_idx'])
@isolate_apps('migrations')
def test_from_model_constraints(self):
class ModelWithConstraints(models.Model):
size = models.IntegerField()
class Meta:
constraints = [models.CheckConstraint(models.Q(size__gt=1), 'size_gt_1')]
state = ModelState.from_model(ModelWithConstraints)
model_constraints = ModelWithConstraints._meta.constraints
state_constraints = state.options['constraints']
self.assertEqual(model_constraints, state_constraints)
self.assertIsNot(model_constraints, state_constraints)
self.assertIsNot(model_constraints[0], state_constraints[0])
class RelatedModelsTests(SimpleTestCase):

View File

@ -0,0 +1,95 @@
from datetime import datetime
from django.core.exceptions import FieldError
from django.db.models import CharField, F, Q
from django.db.models.expressions import SimpleCol
from django.db.models.fields.related_lookups import RelatedIsNull
from django.db.models.functions import Lower
from django.db.models.lookups import Exact, GreaterThan, IsNull, LessThan
from django.db.models.sql.query import Query
from django.db.models.sql.where import OR
from django.test import TestCase
from .models import Author, Item, ObjectC, Ranking
class TestQuery(TestCase):
def test_simple_query(self):
query = Query(Author)
where = query.build_where(Q(num__gt=2))
lookup = where.children[0]
self.assertIsInstance(lookup, GreaterThan)
self.assertEqual(lookup.rhs, 2)
self.assertEqual(lookup.lhs.target, Author._meta.get_field('num'))
def test_complex_query(self):
query = Query(Author)
where = query.build_where(Q(num__gt=2) | Q(num__lt=0))
self.assertEqual(where.connector, OR)
lookup = where.children[0]
self.assertIsInstance(lookup, GreaterThan)
self.assertEqual(lookup.rhs, 2)
self.assertEqual(lookup.lhs.target, Author._meta.get_field('num'))
lookup = where.children[1]
self.assertIsInstance(lookup, LessThan)
self.assertEqual(lookup.rhs, 0)
self.assertEqual(lookup.lhs.target, Author._meta.get_field('num'))
def test_multiple_fields(self):
query = Query(Item)
where = query.build_where(Q(modified__gt=F('created')))
lookup = where.children[0]
self.assertIsInstance(lookup, GreaterThan)
self.assertIsInstance(lookup.rhs, SimpleCol)
self.assertIsInstance(lookup.lhs, SimpleCol)
self.assertEqual(lookup.rhs.target, Item._meta.get_field('created'))
self.assertEqual(lookup.lhs.target, Item._meta.get_field('modified'))
def test_transform(self):
query = Query(Author)
CharField.register_lookup(Lower, 'lower')
try:
where = query.build_where(~Q(name__lower='foo'))
finally:
CharField._unregister_lookup(Lower, 'lower')
lookup = where.children[0]
self.assertIsInstance(lookup, Exact)
self.assertIsInstance(lookup.lhs, Lower)
self.assertIsInstance(lookup.lhs.lhs, SimpleCol)
self.assertEqual(lookup.lhs.lhs.target, Author._meta.get_field('name'))
def test_negated_nullable(self):
query = Query(Item)
where = query.build_where(~Q(modified__lt=datetime(2017, 1, 1)))
self.assertTrue(where.negated)
lookup = where.children[0]
self.assertIsInstance(lookup, LessThan)
self.assertEqual(lookup.lhs.target, Item._meta.get_field('modified'))
lookup = where.children[1]
self.assertIsInstance(lookup, IsNull)
self.assertEqual(lookup.lhs.target, Item._meta.get_field('modified'))
def test_foreign_key(self):
query = Query(Item)
msg = 'Joined field references are not permitted in this query'
with self.assertRaisesMessage(FieldError, msg):
query.build_where(Q(creator__num__gt=2))
def test_foreign_key_f(self):
query = Query(Ranking)
with self.assertRaises(FieldError):
query.build_where(Q(rank__gt=F('author__num')))
def test_foreign_key_exclusive(self):
query = Query(ObjectC)
where = query.build_where(Q(objecta=None) | Q(objectb=None))
a_isnull = where.children[0]
self.assertIsInstance(a_isnull, RelatedIsNull)
self.assertIsInstance(a_isnull.lhs, SimpleCol)
self.assertEqual(a_isnull.lhs.target, ObjectC._meta.get_field('objecta'))
b_isnull = where.children[1]
self.assertIsInstance(b_isnull, RelatedIsNull)
self.assertIsInstance(b_isnull.lhs, SimpleCol)
self.assertEqual(b_isnull.lhs.target, ObjectC._meta.get_field('objectb'))