diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index d5b142383a..582fa2dc1f 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -172,6 +172,7 @@ class BaseDatabaseFeatures: # Does it support CHECK constraints? supports_column_check_constraints = True + supports_table_check_constraints = True # Does the backend support 'pyformat' style ("... %(name)s ...", {'name': value}) # parameter passing? Note this can be provided by the backend even if not diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py index ec2cf0e5c7..9608e95afb 100644 --- a/django/db/backends/base/schema.py +++ b/django/db/backends/base/schema.py @@ -63,7 +63,8 @@ class BaseDatabaseSchemaEditor: sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s" sql_update_with_default = "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL" - sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)" + sql_check = "CONSTRAINT %(name)s CHECK (%(check)s)" + sql_create_check = "ALTER TABLE %(table)s ADD %(check)s" sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" sql_create_unique = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s UNIQUE (%(columns)s)" @@ -299,10 +300,11 @@ class BaseDatabaseSchemaEditor: for fields in model._meta.unique_together: columns = [model._meta.get_field(field).column for field in fields] self.deferred_sql.append(self._create_unique_sql(model, columns)) + constraints = [check.constraint_sql(model, self) for check in model._meta.constraints] # Make the table sql = self.sql_create_table % { "table": self.quote_name(model._meta.db_table), - "definition": ", ".join(column_sqls) + "definition": ", ".join((*column_sqls, *constraints)), } if model._meta.db_tablespace: tablespace_sql = self.connection.ops.tablespace_sql(model._meta.db_tablespace) @@ -343,6 +345,14 @@ class BaseDatabaseSchemaEditor: """Remove an index from a model.""" self.execute(index.remove_sql(model, self)) + def add_constraint(self, model, constraint): + """Add a check constraint to a model.""" + self.execute(constraint.create_sql(model, self)) + + def remove_constraint(self, model, constraint): + """Remove a check constraint from a model.""" + self.execute(constraint.remove_sql(model, self)) + def alter_unique_together(self, model, old_unique_together, new_unique_together): """ Deal with a model changing its unique_together. The input @@ -752,11 +762,12 @@ class BaseDatabaseSchemaEditor: self.execute( self.sql_create_check % { "table": self.quote_name(model._meta.db_table), - "name": self.quote_name( - self._create_index_name(model._meta.db_table, [new_field.column], suffix="_check") - ), - "column": self.quote_name(new_field.column), - "check": new_db_params['check'], + "check": self.sql_check % { + 'name': self.quote_name( + self._create_index_name(model._meta.db_table, [new_field.column], suffix='_check'), + ), + 'check': new_db_params['check'], + }, } ) # Drop the default if we need to diff --git a/django/db/backends/mysql/features.py b/django/db/backends/mysql/features.py index ec80c6e54e..b50513d779 100644 --- a/django/db/backends/mysql/features.py +++ b/django/db/backends/mysql/features.py @@ -26,6 +26,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): can_release_savepoints = True atomic_transactions = False supports_column_check_constraints = False + supports_table_check_constraints = False can_clone_databases = True supports_temporal_subtraction = True supports_select_intersection = False diff --git a/django/db/backends/sqlite3/schema.py b/django/db/backends/sqlite3/schema.py index b27d39d732..8710e9d0e2 100644 --- a/django/db/backends/sqlite3/schema.py +++ b/django/db/backends/sqlite3/schema.py @@ -126,7 +126,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): else: super().alter_field(model, old_field, new_field, strict=strict) - def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None): + def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None, + add_constraint=None, remove_constraint=None): """ Shortcut to transform a model from old_model into new_model @@ -222,6 +223,15 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): if delete_field.name not in index.fields ] + constraints = list(model._meta.constraints) + if add_constraint: + constraints.append(add_constraint) + if remove_constraint: + constraints = [ + constraint for constraint in constraints + if remove_constraint.name != constraint.name + ] + # Construct a new model for the new state meta_contents = { 'app_label': model._meta.app_label, @@ -229,6 +239,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): 'unique_together': unique_together, 'index_together': index_together, 'indexes': indexes, + 'constraints': constraints, 'apps': apps, } meta = type("Meta", (), meta_contents) @@ -362,3 +373,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): )) # Delete the old through table self.delete_model(old_field.remote_field.through) + + def add_constraint(self, model, constraint): + self._remake_table(model, add_constraint=constraint) + + def remove_constraint(self, model, constraint): + self._remake_table(model, remove_constraint=constraint) diff --git a/django/db/migrations/autodetector.py b/django/db/migrations/autodetector.py index f32d4af9eb..bf9a45530a 100644 --- a/django/db/migrations/autodetector.py +++ b/django/db/migrations/autodetector.py @@ -122,6 +122,7 @@ class MigrationAutodetector: # resolve dependencies caused by M2Ms and FKs. self.generated_operations = {} self.altered_indexes = {} + self.altered_constraints = {} # Prepare some old/new state and model lists, separating # proxy models and ignoring unmigrated apps. @@ -175,7 +176,9 @@ class MigrationAutodetector: # This avoids the same computation in generate_removed_indexes() # and generate_added_indexes(). self.create_altered_indexes() + self.create_altered_constraints() # Generate index removal operations before field is removed + self.generate_removed_constraints() self.generate_removed_indexes() # Generate field operations self.generate_renamed_fields() @@ -185,6 +188,7 @@ class MigrationAutodetector: self.generate_altered_unique_together() self.generate_altered_index_together() self.generate_added_indexes() + self.generate_added_constraints() self.generate_altered_db_table() self.generate_altered_order_with_respect_to() @@ -533,6 +537,7 @@ class MigrationAutodetector: related_fields[field.name] = field # Are there indexes/unique|index_together to defer? indexes = model_state.options.pop('indexes') + constraints = model_state.options.pop('constraints') unique_together = model_state.options.pop('unique_together', None) index_together = model_state.options.pop('index_together', None) order_with_respect_to = model_state.options.pop('order_with_respect_to', None) @@ -601,6 +606,15 @@ class MigrationAutodetector: ), dependencies=related_dependencies, ) + for constraint in constraints: + self.add_operation( + app_label, + operations.AddConstraint( + model_name=model_name, + constraint=constraint, + ), + dependencies=related_dependencies, + ) if unique_together: self.add_operation( app_label, @@ -997,6 +1011,46 @@ class MigrationAutodetector: ) ) + def create_altered_constraints(self): + option_name = operations.AddConstraint.option_name + for app_label, model_name in sorted(self.kept_model_keys): + old_model_name = self.renamed_models.get((app_label, model_name), model_name) + old_model_state = self.from_state.models[app_label, old_model_name] + new_model_state = self.to_state.models[app_label, model_name] + + old_constraints = old_model_state.options[option_name] + new_constraints = new_model_state.options[option_name] + add_constraints = [c for c in new_constraints if c not in old_constraints] + rem_constraints = [c for c in old_constraints if c not in new_constraints] + + self.altered_constraints.update({ + (app_label, model_name): { + 'added_constraints': add_constraints, 'removed_constraints': rem_constraints, + } + }) + + def generate_added_constraints(self): + for (app_label, model_name), alt_constraints in self.altered_constraints.items(): + for constraint in alt_constraints['added_constraints']: + self.add_operation( + app_label, + operations.AddConstraint( + model_name=model_name, + constraint=constraint, + ) + ) + + def generate_removed_constraints(self): + for (app_label, model_name), alt_constraints in self.altered_constraints.items(): + for constraint in alt_constraints['removed_constraints']: + self.add_operation( + app_label, + operations.RemoveConstraint( + model_name=model_name, + name=constraint.name, + ) + ) + def _get_dependencies_for_foreign_key(self, field): # Account for FKs to swappable models swappable_setting = getattr(field, 'swappable_setting', None) diff --git a/django/db/migrations/operations/__init__.py b/django/db/migrations/operations/__init__.py index 894f2ab9c5..119c955868 100644 --- a/django/db/migrations/operations/__init__.py +++ b/django/db/migrations/operations/__init__.py @@ -1,8 +1,9 @@ from .fields import AddField, AlterField, RemoveField, RenameField from .models import ( - AddIndex, AlterIndexTogether, AlterModelManagers, AlterModelOptions, - AlterModelTable, AlterOrderWithRespectTo, AlterUniqueTogether, CreateModel, - DeleteModel, RemoveIndex, RenameModel, + AddConstraint, AddIndex, AlterIndexTogether, AlterModelManagers, + AlterModelOptions, AlterModelTable, AlterOrderWithRespectTo, + AlterUniqueTogether, CreateModel, DeleteModel, RemoveConstraint, + RemoveIndex, RenameModel, ) from .special import RunPython, RunSQL, SeparateDatabaseAndState @@ -10,6 +11,7 @@ __all__ = [ 'CreateModel', 'DeleteModel', 'AlterModelTable', 'AlterUniqueTogether', 'RenameModel', 'AlterIndexTogether', 'AlterModelOptions', 'AddIndex', 'RemoveIndex', 'AddField', 'RemoveField', 'AlterField', 'RenameField', + 'AddConstraint', 'RemoveConstraint', 'SeparateDatabaseAndState', 'RunSQL', 'RunPython', 'AlterOrderWithRespectTo', 'AlterModelManagers', ] diff --git a/django/db/migrations/operations/models.py b/django/db/migrations/operations/models.py index 857981bcb8..bd1c66a01d 100644 --- a/django/db/migrations/operations/models.py +++ b/django/db/migrations/operations/models.py @@ -822,3 +822,72 @@ class RemoveIndex(IndexOperation): def describe(self): return 'Remove index %s from %s' % (self.name, self.model_name) + + +class AddConstraint(IndexOperation): + option_name = 'constraints' + + def __init__(self, model_name, constraint): + self.model_name = model_name + self.constraint = constraint + + def state_forwards(self, app_label, state): + model_state = state.models[app_label, self.model_name_lower] + constraints = list(model_state.options[self.option_name]) + constraints.append(self.constraint) + model_state.options[self.option_name] = constraints + + def database_forwards(self, app_label, schema_editor, from_state, to_state): + model = to_state.apps.get_model(app_label, self.model_name) + if self.allow_migrate_model(schema_editor.connection.alias, model): + schema_editor.add_constraint(model, self.constraint) + + def database_backwards(self, app_label, schema_editor, from_state, to_state): + model = to_state.apps.get_model(app_label, self.model_name) + if self.allow_migrate_model(schema_editor.connection.alias, model): + schema_editor.remove_constraint(model, self.constraint) + + def deconstruct(self): + return self.__class__.__name__, [], { + 'model_name': self.model_name, + 'constraint': self.constraint, + } + + def describe(self): + return 'Create constraint %s on model %s' % (self.constraint.name, self.model_name) + + +class RemoveConstraint(IndexOperation): + option_name = 'constraints' + + def __init__(self, model_name, name): + self.model_name = model_name + self.name = name + + def state_forwards(self, app_label, state): + model_state = state.models[app_label, self.model_name_lower] + constraints = model_state.options[self.option_name] + model_state.options[self.option_name] = [c for c in constraints if c.name != self.name] + + def database_forwards(self, app_label, schema_editor, from_state, to_state): + model = from_state.apps.get_model(app_label, self.model_name) + if self.allow_migrate_model(schema_editor.connection.alias, model): + from_model_state = from_state.models[app_label, self.model_name_lower] + constraint = from_model_state.get_constraint_by_name(self.name) + schema_editor.remove_constraint(model, constraint) + + def database_backwards(self, app_label, schema_editor, from_state, to_state): + model = to_state.apps.get_model(app_label, self.model_name) + if self.allow_migrate_model(schema_editor.connection.alias, model): + to_model_state = to_state.models[app_label, self.model_name_lower] + constraint = to_model_state.get_constraint_by_name(self.name) + schema_editor.add_constraint(model, constraint) + + def deconstruct(self): + return self.__class__.__name__, [], { + 'model_name': self.model_name, + 'name': self.name, + } + + def describe(self): + return 'Remove constraint %s from model %s' % (self.name, self.model_name) diff --git a/django/db/migrations/state.py b/django/db/migrations/state.py index f41d1edf2c..ea2db0e5af 100644 --- a/django/db/migrations/state.py +++ b/django/db/migrations/state.py @@ -362,6 +362,7 @@ class ModelState: self.fields = fields self.options = options or {} self.options.setdefault('indexes', []) + self.options.setdefault('constraints', []) self.bases = bases or (models.Model,) self.managers = managers or [] # Sanity-check that fields is NOT a dict. It must be ordered. @@ -445,6 +446,8 @@ class ModelState: if not index.name: index.set_name_with_model(model) options['indexes'] = indexes + elif name == 'constraints': + options['constraints'] = [con.clone() for con in model._meta.constraints] else: options[name] = model._meta.original_attrs[name] # If we're ignoring relationships, remove all field-listing model @@ -585,6 +588,12 @@ class ModelState: return index raise ValueError("No index named %s on model %s" % (name, self.name)) + def get_constraint_by_name(self, name): + for constraint in self.options['constraints']: + if constraint.name == name: + return constraint + raise ValueError('No constraint named %s on model %s' % (name, self.name)) + def __repr__(self): return "<%s: '%s.%s'>" % (self.__class__.__name__, self.app_label, self.name) diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 27c67c7fc4..79b175c1d5 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -2,6 +2,8 @@ from django.core.exceptions import ObjectDoesNotExist from django.db.models import signals from django.db.models.aggregates import * # NOQA from django.db.models.aggregates import __all__ as aggregates_all +from django.db.models.constraints import * # NOQA +from django.db.models.constraints import __all__ as constraints_all from django.db.models.deletion import ( CASCADE, DO_NOTHING, PROTECT, SET, SET_DEFAULT, SET_NULL, ProtectedError, ) @@ -30,7 +32,7 @@ from django.db.models.fields.related import ( # isort:skip ) -__all__ = aggregates_all + fields_all + indexes_all +__all__ = aggregates_all + constraints_all + fields_all + indexes_all __all__ += [ 'ObjectDoesNotExist', 'signals', 'CASCADE', 'DO_NOTHING', 'PROTECT', 'SET', 'SET_DEFAULT', 'SET_NULL', diff --git a/django/db/models/base.py b/django/db/models/base.py index 251825991f..3574f7f676 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -16,6 +16,7 @@ from django.db import ( connections, router, transaction, ) from django.db.models.constants import LOOKUP_SEP +from django.db.models.constraints import CheckConstraint from django.db.models.deletion import CASCADE, Collector from django.db.models.fields.related import ( ForeignObjectRel, OneToOneField, lazy_related_operation, resolve_relation, @@ -1201,6 +1202,7 @@ class Model(metaclass=ModelBase): *cls._check_unique_together(), *cls._check_indexes(), *cls._check_ordering(), + *cls._check_constraints(), ] return errors @@ -1699,6 +1701,29 @@ class Model(metaclass=ModelBase): return errors + @classmethod + def _check_constraints(cls): + errors = [] + for db in settings.DATABASES: + if not router.allow_migrate_model(db, cls): + continue + connection = connections[db] + if connection.features.supports_table_check_constraints: + continue + if any(isinstance(constraint, CheckConstraint) for constraint in cls._meta.constraints): + errors.append( + checks.Warning( + '%s does not support check constraints.' % connection.display_name, + hint=( + "A constraint won't be created. Silence this " + "warning if you don't care about it." + ), + obj=cls, + id='models.W027', + ) + ) + return errors + ############################################ # HELPER FUNCTIONS (CURRIED MODEL METHODS) # diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py new file mode 100644 index 0000000000..fe99f8310d --- /dev/null +++ b/django/db/models/constraints.py @@ -0,0 +1,54 @@ +from django.db.models.sql.query import Query + +__all__ = ['CheckConstraint'] + + +class CheckConstraint: + def __init__(self, constraint, name): + self.constraint = constraint + self.name = name + + def constraint_sql(self, model, schema_editor): + query = Query(model) + where = query.build_where(self.constraint) + connection = schema_editor.connection + compiler = connection.ops.compiler('SQLCompiler')(query, connection, 'default') + sql, params = where.as_sql(compiler, connection) + params = tuple(schema_editor.quote_value(p) for p in params) + return schema_editor.sql_check % { + 'name': schema_editor.quote_name(self.name), + 'check': sql % params, + } + + def create_sql(self, model, schema_editor): + sql = self.constraint_sql(model, schema_editor) + return schema_editor.sql_create_check % { + 'table': schema_editor.quote_name(model._meta.db_table), + 'check': sql, + } + + def remove_sql(self, model, schema_editor): + quote_name = schema_editor.quote_name + return schema_editor.sql_delete_check % { + 'table': quote_name(model._meta.db_table), + 'name': quote_name(self.name), + } + + def __repr__(self): + return "<%s: constraint='%s' name='%s'>" % (self.__class__.__name__, self.constraint, self.name) + + def __eq__(self, other): + return ( + isinstance(other, CheckConstraint) and + self.name == other.name and + self.constraint == other.constraint + ) + + def deconstruct(self): + path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__) + path = path.replace('django.db.models.constraints', 'django.db.models') + return (path, (), {'constraint': self.constraint, 'name': self.name}) + + def clone(self): + _, args, kwargs = self.deconstruct() + return self.__class__(*args, **kwargs) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 219485750f..86ce77daa2 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -505,8 +505,9 @@ class F(Combinable): def __repr__(self): return "{}({})".format(self.__class__.__name__, self.name) - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): - return query.resolve_ref(self.name, allow_joins, reuse, summarize) + def resolve_expression(self, query=None, allow_joins=True, reuse=None, + summarize=False, for_save=False, simple_col=False): + return query.resolve_ref(self.name, allow_joins, reuse, summarize, simple_col) def asc(self, **kwargs): return OrderBy(self, **kwargs) @@ -542,7 +543,8 @@ class ResolvedOuterRef(F): class OuterRef(F): - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression(self, query=None, allow_joins=True, reuse=None, + summarize=False, for_save=False, simple_col=False): if isinstance(self.name, self.__class__): return self.name return ResolvedOuterRef(self.name) @@ -746,6 +748,40 @@ class Col(Expression): self.target.get_db_converters(connection)) +class SimpleCol(Expression): + """ + Represents the SQL of a column name without the table name. + + This variant of Col doesn't include the table name (or an alias) to + avoid a syntax error in check constraints. + """ + contains_column_references = True + + def __init__(self, target, output_field=None): + if output_field is None: + output_field = target + super().__init__(output_field=output_field) + self.target = target + + def __repr__(self): + return '{}({})'.format(self.__class__.__name__, self.target) + + def as_sql(self, compiler, connection): + qn = compiler.quote_name_unless_alias + return qn(self.target.column), [] + + def get_group_by_cols(self): + return [self] + + def get_db_converters(self, connection): + if self.target == self.output_field: + return self.output_field.get_db_converters(connection) + return ( + self.output_field.get_db_converters(connection) + + self.target.get_db_converters(connection) + ) + + class Ref(Expression): """ Reference to column alias of the query. For example, Ref('sum_cost') in diff --git a/django/db/models/options.py b/django/db/models/options.py index c0c925375f..98bd2f0064 100644 --- a/django/db/models/options.py +++ b/django/db/models/options.py @@ -32,7 +32,7 @@ DEFAULT_NAMES = ( 'auto_created', 'index_together', 'apps', 'default_permissions', 'select_on_save', 'default_related_name', 'required_db_features', 'required_db_vendor', 'base_manager_name', 'default_manager_name', - 'indexes', + 'indexes', 'constraints', # For backwards compatibility with Django 1.11. RemovedInDjango30Warning 'manager_inheritance_from_future', ) @@ -89,6 +89,7 @@ class Options: self.ordering = [] self._ordering_clash = False self.indexes = [] + self.constraints = [] self.unique_together = [] self.index_together = [] self.select_on_save = False diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 01ff007eda..2a3510f992 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -18,7 +18,7 @@ from django.core.exceptions import ( from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections from django.db.models.aggregates import Count from django.db.models.constants import LOOKUP_SEP -from django.db.models.expressions import Col, Ref +from django.db.models.expressions import Col, F, Ref, SimpleCol from django.db.models.fields import Field from django.db.models.fields.related_lookups import MultiColSource from django.db.models.lookups import Lookup @@ -62,6 +62,12 @@ JoinInfo = namedtuple( ) +def _get_col(target, field, alias, simple_col): + if simple_col: + return SimpleCol(target, field) + return target.get_col(alias, field) + + class RawQuery: """A single raw SQL query.""" @@ -1011,15 +1017,24 @@ class Query: def as_sql(self, compiler, connection): return self.get_compiler(connection=connection).as_sql() - def resolve_lookup_value(self, value, can_reuse, allow_joins): + def resolve_lookup_value(self, value, can_reuse, allow_joins, simple_col): if hasattr(value, 'resolve_expression'): - value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins) + kwargs = {'reuse': can_reuse, 'allow_joins': allow_joins} + if isinstance(value, F): + kwargs['simple_col'] = simple_col + value = value.resolve_expression(self, **kwargs) elif isinstance(value, (list, tuple)): # The items of the iterable may be expressions and therefore need # to be resolved independently. for sub_value in value: if hasattr(sub_value, 'resolve_expression'): - sub_value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins) + if isinstance(sub_value, F): + sub_value.resolve_expression( + self, reuse=can_reuse, allow_joins=allow_joins, + simple_col=simple_col, + ) + else: + sub_value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins) return value def solve_lookup_type(self, lookup): @@ -1133,7 +1148,7 @@ class Query: def build_filter(self, filter_expr, branch_negated=False, current_negated=False, can_reuse=None, allow_joins=True, split_subq=True, - reuse_with_filtered_relation=False): + reuse_with_filtered_relation=False, simple_col=False): """ Build a WhereNode for a single filter clause but don't add it to this Query. Query.add_q() will then add this filter to the where @@ -1179,7 +1194,7 @@ class Query: raise FieldError("Joined field references are not permitted in this query") pre_joins = self.alias_refcount.copy() - value = self.resolve_lookup_value(value, can_reuse, allow_joins) + value = self.resolve_lookup_value(value, can_reuse, allow_joins, simple_col) used_joins = {k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)} clause = self.where_class() @@ -1222,11 +1237,11 @@ class Query: if num_lookups > 1: raise FieldError('Related Field got invalid lookup: {}'.format(lookups[0])) if len(targets) == 1: - col = targets[0].get_col(alias, join_info.final_field) + col = _get_col(targets[0], join_info.final_field, alias, simple_col) else: col = MultiColSource(alias, targets, join_info.targets, join_info.final_field) else: - col = targets[0].get_col(alias, join_info.final_field) + col = _get_col(targets[0], join_info.final_field, alias, simple_col) condition = self.build_lookup(lookups, col, value) lookup_type = condition.lookup_name @@ -1248,7 +1263,8 @@ class Query: # <=> # NOT (col IS NOT NULL AND col = someval). lookup_class = targets[0].get_lookup('isnull') - clause.add(lookup_class(targets[0].get_col(alias, join_info.targets[0]), False), AND) + col = _get_col(targets[0], join_info.targets[0], alias, simple_col) + clause.add(lookup_class(col, False), AND) return clause, used_joins if not require_outer else () def add_filter(self, filter_clause): @@ -1271,8 +1287,12 @@ class Query: self.where.add(clause, AND) self.demote_joins(existing_inner) + def build_where(self, q_object): + return self._add_q(q_object, used_aliases=set(), allow_joins=False, simple_col=True)[0] + def _add_q(self, q_object, used_aliases, branch_negated=False, - current_negated=False, allow_joins=True, split_subq=True): + current_negated=False, allow_joins=True, split_subq=True, + simple_col=False): """Add a Q-object to the current filter.""" connector = q_object.connector current_negated = current_negated ^ q_object.negated @@ -1290,7 +1310,7 @@ class Query: child_clause, needed_inner = self.build_filter( child, can_reuse=used_aliases, branch_negated=branch_negated, current_negated=current_negated, allow_joins=allow_joins, - split_subq=split_subq, + split_subq=split_subq, simple_col=simple_col, ) joinpromoter.add_votes(needed_inner) if child_clause: @@ -1559,7 +1579,7 @@ class Query: self.unref_alias(joins.pop()) return targets, joins[-1], joins - def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False): + def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False, simple_col=False): if not allow_joins and LOOKUP_SEP in name: raise FieldError("Joined field references are not permitted in this query") if name in self.annotations: @@ -1580,7 +1600,7 @@ class Query: "isn't supported") if reuse is not None: reuse.update(join_list) - col = targets[0].get_col(join_list[-1], join_info.targets[0]) + col = _get_col(targets[0], join_info.targets[0], join_list[-1], simple_col) return col def split_exclude(self, filter_expr, can_reuse, names_with_path): diff --git a/docs/ref/checks.txt b/docs/ref/checks.txt index b0d7dfc066..ee16ebde54 100644 --- a/docs/ref/checks.txt +++ b/docs/ref/checks.txt @@ -297,6 +297,7 @@ Models field accessor. * **models.E026**: The model cannot have more than one field with ``primary_key=True``. +* **models.W027**: ```` does not support check constraints. Security -------- diff --git a/docs/ref/migration-operations.txt b/docs/ref/migration-operations.txt index b45134b46d..c117145fde 100644 --- a/docs/ref/migration-operations.txt +++ b/docs/ref/migration-operations.txt @@ -207,6 +207,25 @@ Creates an index in the database table for the model with ``model_name``. Removes the index named ``name`` from the model with ``model_name``. +``AddConstraint`` +----------------- + +.. class:: AddConstraint(model_name, constraint) + +.. versionadded:: 2.2 + +Creates a constraint in the database table for the model with ``model_name``. +``constraint`` is an instance of :class:`~django.db.models.CheckConstraint`. + +``RemoveConstraint`` +-------------------- + +.. class:: RemoveConstraint(model_name, name) + +.. versionadded:: 2.2 + +Removes the constraint named ``name`` from the model with ``model_name``. + Special Operations ================== diff --git a/docs/ref/models/check-constraints.txt b/docs/ref/models/check-constraints.txt new file mode 100644 index 0000000000..29681d7ebc --- /dev/null +++ b/docs/ref/models/check-constraints.txt @@ -0,0 +1,46 @@ +=========================== +Check constraints reference +=========================== + +.. module:: django.db.models.constraints + +.. currentmodule:: django.db.models + +.. versionadded:: 2.2 + +The ``CheckConstraint`` class creates database check constraints. They are +added in the model :attr:`Meta.constraints +` option. This document +explains the API references of :class:`CheckConstraint`. + +.. admonition:: Referencing built-in constraints + + Constraints are defined in ``django.db.models.constraints``, but for + convenience they're imported into :mod:`django.db.models`. The standard + convention is to use ``from django.db import models`` and refer to the + constraints as ``models.CheckConstraint``. + +``CheckConstraint`` options +=========================== + +.. class:: CheckConstraint(constraint, name) + + Creates a check constraint in the database. + +``constraint`` +-------------- + +.. attribute:: CheckConstraint.constraint + +A :class:`Q` object that specifies the condition you want the constraint to +enforce. + +For example ``CheckConstraint(Q(age__gte=18), 'age_gte_18')`` ensures the age +field is never less than 18. + +``name`` +-------- + +.. attribute:: CheckConstraint.name + +The name of the constraint. diff --git a/docs/ref/models/index.txt b/docs/ref/models/index.txt index d731ee37dc..c3aa5a718a 100644 --- a/docs/ref/models/index.txt +++ b/docs/ref/models/index.txt @@ -9,6 +9,7 @@ Model API reference. For introductory material, see :doc:`/topics/db/models`. fields indexes + check-constraints meta relations class diff --git a/docs/ref/models/options.txt b/docs/ref/models/options.txt index cea73eb67f..246f3e7d9a 100644 --- a/docs/ref/models/options.txt +++ b/docs/ref/models/options.txt @@ -451,6 +451,26 @@ Django quotes column and table names behind the scenes. index_together = ["pub_date", "deadline"] +``constraints`` +--------------- + +.. attribute:: Options.constraints + + .. versionadded:: 2.2 + + A list of :doc:`constraints ` that you want + to define on the model:: + + from django.db import models + + class Customer(models.Model): + age = models.IntegerField() + + class Meta: + constraints = [ + models.CheckConstraint(models.Q(age__gte=18), 'age_gte_18'), + ] + ``verbose_name`` ---------------- diff --git a/docs/releases/2.2.txt b/docs/releases/2.2.txt index 0fdd57e340..29a6f54f9d 100644 --- a/docs/releases/2.2.txt +++ b/docs/releases/2.2.txt @@ -30,6 +30,13 @@ officially support the latest release of each series. What's new in Django 2.2 ======================== +Check Constraints +----------------- + +The new :class:`~django.db.models.CheckConstraint` class enables adding custom +database constraints. Constraints are added to models using the +:attr:`Meta.constraints ` option. + Minor features -------------- @@ -213,7 +220,9 @@ Backwards incompatible changes in 2.2 Database backend API -------------------- -* ... +* Third-party database backends must implement support for table check + constraints or set ``DatabaseFeatures.supports_table_check_constraints`` to + ``False``. :mod:`django.contrib.gis` ------------------------- diff --git a/tests/constraints/__init__.py b/tests/constraints/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/constraints/models.py b/tests/constraints/models.py new file mode 100644 index 0000000000..de49fa2765 --- /dev/null +++ b/tests/constraints/models.py @@ -0,0 +1,15 @@ +from django.db import models + + +class Product(models.Model): + name = models.CharField(max_length=255) + price = models.IntegerField() + discounted_price = models.IntegerField() + + class Meta: + constraints = [ + models.CheckConstraint( + models.Q(price__gt=models.F('discounted_price')), + 'price_gt_discounted_price' + ) + ] diff --git a/tests/constraints/tests.py b/tests/constraints/tests.py new file mode 100644 index 0000000000..19573dffa1 --- /dev/null +++ b/tests/constraints/tests.py @@ -0,0 +1,30 @@ +from django.db import IntegrityError, models +from django.test import TestCase, skipUnlessDBFeature + +from .models import Product + + +class CheckConstraintTests(TestCase): + def test_repr(self): + constraint = models.Q(price__gt=models.F('discounted_price')) + name = 'price_gt_discounted_price' + check = models.CheckConstraint(constraint, name) + self.assertEqual( + repr(check), + "".format(constraint, name), + ) + + def test_deconstruction(self): + constraint = models.Q(price__gt=models.F('discounted_price')) + name = 'price_gt_discounted_price' + check = models.CheckConstraint(constraint, name) + path, args, kwargs = check.deconstruct() + self.assertEqual(path, 'django.db.models.CheckConstraint') + self.assertEqual(args, ()) + self.assertEqual(kwargs, {'constraint': constraint, 'name': name}) + + @skipUnlessDBFeature('supports_table_check_constraints') + def test_model_constraint(self): + Product.objects.create(name='Valid', price=10, discounted_price=5) + with self.assertRaises(IntegrityError): + Product.objects.create(name='Invalid', price=10, discounted_price=20) diff --git a/tests/invalid_models_tests/test_models.py b/tests/invalid_models_tests/test_models.py index 19ec21c9ae..9dd2fd1f06 100644 --- a/tests/invalid_models_tests/test_models.py +++ b/tests/invalid_models_tests/test_models.py @@ -1,10 +1,10 @@ import unittest from django.conf import settings -from django.core.checks import Error +from django.core.checks import Error, Warning from django.core.checks.model_checks import _check_lazy_references from django.core.exceptions import ImproperlyConfigured -from django.db import connections, models +from django.db import connection, connections, models from django.db.models.signals import post_init from django.test import SimpleTestCase from django.test.utils import isolate_apps, override_settings @@ -972,3 +972,26 @@ class OtherModelTests(SimpleTestCase): id='signals.E001', ), ]) + + +@isolate_apps('invalid_models_tests') +class ConstraintsTests(SimpleTestCase): + def test_check_constraints(self): + class Model(models.Model): + age = models.IntegerField() + + class Meta: + constraints = [models.CheckConstraint(models.Q(age__gte=18), 'is_adult')] + + errors = Model.check() + warn = Warning( + '%s does not support check constraints.' % connection.display_name, + hint=( + "A constraint won't be created. Silence this warning if you " + "don't care about it." + ), + obj=Model, + id='models.W027', + ) + expected = [] if connection.features.supports_table_check_constraints else [warn, warn] + self.assertCountEqual(errors, expected) diff --git a/tests/migrations/test_autodetector.py b/tests/migrations/test_autodetector.py index fd1bc383b9..1fde1ba466 100644 --- a/tests/migrations/test_autodetector.py +++ b/tests/migrations/test_autodetector.py @@ -61,6 +61,12 @@ class AutodetectorTests(TestCase): ("id", models.AutoField(primary_key=True)), ("name", models.CharField(max_length=200, default='Ada Lovelace')), ]) + author_name_check_constraint = ModelState("testapp", "Author", [ + ("id", models.AutoField(primary_key=True)), + ("name", models.CharField(max_length=200)), + ], + {'constraints': [models.CheckConstraint(models.Q(name__contains='Bob'), 'name_contains_bob')]}, + ) author_dates_of_birth_auto_now = ModelState("testapp", "Author", [ ("id", models.AutoField(primary_key=True)), ("date_of_birth", models.DateField(auto_now=True)), @@ -1389,6 +1395,40 @@ class AutodetectorTests(TestCase): added_index = models.Index(fields=['title', 'author'], name='book_author_title_idx') self.assertOperationAttributes(changes, 'otherapp', 0, 1, model_name='book', index=added_index) + def test_create_model_with_check_constraint(self): + """Test creation of new model with constraints already defined.""" + author = ModelState('otherapp', 'Author', [ + ('id', models.AutoField(primary_key=True)), + ('name', models.CharField(max_length=200)), + ], {'constraints': [models.CheckConstraint(models.Q(name__contains='Bob'), 'name_contains_bob')]}) + changes = self.get_changes([], [author]) + added_constraint = models.CheckConstraint(models.Q(name__contains='Bob'), 'name_contains_bob') + # Right number of migrations? + self.assertEqual(len(changes['otherapp']), 1) + # Right number of actions? + migration = changes['otherapp'][0] + self.assertEqual(len(migration.operations), 2) + # Right actions order? + self.assertOperationTypes(changes, 'otherapp', 0, ['CreateModel', 'AddConstraint']) + self.assertOperationAttributes(changes, 'otherapp', 0, 0, name='Author') + self.assertOperationAttributes(changes, 'otherapp', 0, 1, model_name='author', constraint=added_constraint) + + def test_add_constraints(self): + """Test change detection of new constraints.""" + changes = self.get_changes([self.author_name], [self.author_name_check_constraint]) + self.assertNumberMigrations(changes, 'testapp', 1) + self.assertOperationTypes(changes, 'testapp', 0, ['AddConstraint']) + added_constraint = models.CheckConstraint(models.Q(name__contains='Bob'), 'name_contains_bob') + self.assertOperationAttributes(changes, 'testapp', 0, 0, model_name='author', constraint=added_constraint) + + def test_remove_constraints(self): + """Test change detection of removed constraints.""" + changes = self.get_changes([self.author_name_check_constraint], [self.author_name]) + # Right number/type of migrations? + self.assertNumberMigrations(changes, 'testapp', 1) + self.assertOperationTypes(changes, 'testapp', 0, ['RemoveConstraint']) + self.assertOperationAttributes(changes, 'testapp', 0, 0, model_name='author', name='name_contains_bob') + def test_add_foo_together(self): """Tests index/unique_together detection.""" changes = self.get_changes([self.author_empty, self.book], [self.author_empty, self.book_foo_together]) @@ -1520,7 +1560,7 @@ class AutodetectorTests(TestCase): self.assertNumberMigrations(changes, "testapp", 1) self.assertOperationTypes(changes, "testapp", 0, ["CreateModel"]) self.assertOperationAttributes( - changes, "testapp", 0, 0, name="AuthorProxy", options={"proxy": True, "indexes": []} + changes, "testapp", 0, 0, name="AuthorProxy", options={"proxy": True, "indexes": [], "constraints": []} ) # Now, we test turning a proxy model into a non-proxy model # It should delete the proxy then make the real one diff --git a/tests/migrations/test_base.py b/tests/migrations/test_base.py index 84a5117751..7fcbaffd24 100644 --- a/tests/migrations/test_base.py +++ b/tests/migrations/test_base.py @@ -67,6 +67,17 @@ class MigrationTestBase(TransactionTestCase): def assertIndexNotExists(self, table, columns): return self.assertIndexExists(table, columns, False) + def assertConstraintExists(self, table, name, value=True, using='default'): + with connections[using].cursor() as cursor: + constraints = connections[using].introspection.get_constraints(cursor, table).items() + self.assertEqual( + value, + any(c['check'] for n, c in constraints if n == name), + ) + + def assertConstraintNotExists(self, table, name): + return self.assertConstraintExists(table, name, False) + def assertFKExists(self, table, columns, to, value=True, using='default'): with connections[using].cursor() as cursor: self.assertEqual( diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index d70feaacdb..b1581042f7 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -53,7 +53,7 @@ class OperationTestBase(MigrationTestBase): def set_up_test_model( self, app_label, second_model=False, third_model=False, index=False, multicol_index=False, related_model=False, mti_model=False, proxy_model=False, manager_model=False, - unique_together=False, options=False, db_table=None, index_together=False): + unique_together=False, options=False, db_table=None, index_together=False, check_constraint=False): """ Creates a test model state and database table. """ @@ -106,6 +106,11 @@ class OperationTestBase(MigrationTestBase): "Pony", models.Index(fields=["pink", "weight"], name="pony_test_idx") )) + if check_constraint: + operations.append(migrations.AddConstraint( + "Pony", + models.CheckConstraint(models.Q(pink__gt=2), name="pony_test_constraint") + )) if second_model: operations.append(migrations.CreateModel( "Stable", @@ -462,6 +467,45 @@ class OperationTests(OperationTestBase): self.assertTableNotExists("test_crummo_unmanagedpony") self.assertTableExists("test_crummo_pony") + @skipUnlessDBFeature('supports_table_check_constraints') + def test_create_model_with_constraint(self): + where = models.Q(pink__gt=2) + check_constraint = models.CheckConstraint(where, name='test_constraint_pony_pink_gt_2') + operation = migrations.CreateModel( + "Pony", + [ + ("id", models.AutoField(primary_key=True)), + ("pink", models.IntegerField(default=3)), + ], + options={'constraints': [check_constraint]}, + ) + + # Test the state alteration + project_state = ProjectState() + new_state = project_state.clone() + operation.state_forwards("test_crmo", new_state) + self.assertEqual(len(new_state.models['test_crmo', 'pony'].options['constraints']), 1) + + # Test database alteration + self.assertTableNotExists("test_crmo_pony") + with connection.schema_editor() as editor: + operation.database_forwards("test_crmo", editor, project_state, new_state) + self.assertTableExists("test_crmo_pony") + with connection.cursor() as cursor: + with self.assertRaises(IntegrityError): + cursor.execute("INSERT INTO test_crmo_pony (id, pink) VALUES (1, 1)") + + # Test reversal + with connection.schema_editor() as editor: + operation.database_backwards("test_crmo", editor, new_state, project_state) + self.assertTableNotExists("test_crmo_pony") + + # Test deconstruction + definition = operation.deconstruct() + self.assertEqual(definition[0], "CreateModel") + self.assertEqual(definition[1], []) + self.assertEqual(definition[2]['options']['constraints'], [check_constraint]) + def test_create_model_managers(self): """ The managers on a model are set. @@ -1708,6 +1752,87 @@ class OperationTests(OperationTestBase): operation = migrations.AlterIndexTogether("Pony", None) self.assertEqual(operation.describe(), "Alter index_together for Pony (0 constraint(s))") + @skipUnlessDBFeature('supports_table_check_constraints') + def test_add_constraint(self): + """Test the AddConstraint operation.""" + project_state = self.set_up_test_model('test_addconstraint') + + where = models.Q(pink__gt=2) + check_constraint = models.CheckConstraint(where, name='test_constraint_pony_pink_gt_2') + operation = migrations.AddConstraint('Pony', check_constraint) + self.assertEqual(operation.describe(), 'Create constraint test_constraint_pony_pink_gt_2 on model Pony') + + new_state = project_state.clone() + operation.state_forwards('test_addconstraint', new_state) + self.assertEqual(len(new_state.models['test_addconstraint', 'pony'].options['constraints']), 1) + + # Test database alteration + with connection.cursor() as cursor: + with atomic(): + cursor.execute("INSERT INTO test_addconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)") + cursor.execute("DELETE FROM test_addconstraint_pony") + + with connection.schema_editor() as editor: + operation.database_forwards("test_addconstraint", editor, project_state, new_state) + + with connection.cursor() as cursor: + with self.assertRaises(IntegrityError): + cursor.execute("INSERT INTO test_addconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)") + + # Test reversal + with connection.schema_editor() as editor: + operation.database_backwards("test_addconstraint", editor, new_state, project_state) + + with connection.cursor() as cursor: + with atomic(): + cursor.execute("INSERT INTO test_addconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)") + cursor.execute("DELETE FROM test_addconstraint_pony") + + # Test deconstruction + definition = operation.deconstruct() + self.assertEqual(definition[0], "AddConstraint") + self.assertEqual(definition[1], []) + self.assertEqual(definition[2], {'model_name': "Pony", 'constraint': check_constraint}) + + @skipUnlessDBFeature('supports_table_check_constraints') + def test_remove_constraint(self): + """Test the RemoveConstraint operation.""" + project_state = self.set_up_test_model("test_removeconstraint", check_constraint=True) + self.assertTableExists("test_removeconstraint_pony") + operation = migrations.RemoveConstraint("Pony", "pony_test_constraint") + self.assertEqual(operation.describe(), "Remove constraint pony_test_constraint from model Pony") + new_state = project_state.clone() + operation.state_forwards("test_removeconstraint", new_state) + # Test state alteration + self.assertEqual(len(new_state.models["test_removeconstraint", "pony"].options['constraints']), 0) + + with connection.cursor() as cursor: + with self.assertRaises(IntegrityError): + cursor.execute("INSERT INTO test_removeconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)") + + # Test database alteration + with connection.schema_editor() as editor: + operation.database_forwards("test_removeconstraint", editor, project_state, new_state) + + with connection.cursor() as cursor: + with atomic(): + cursor.execute("INSERT INTO test_removeconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)") + cursor.execute("DELETE FROM test_removeconstraint_pony") + + # Test reversal + with connection.schema_editor() as editor: + operation.database_backwards("test_removeconstraint", editor, new_state, project_state) + + with connection.cursor() as cursor: + with self.assertRaises(IntegrityError): + cursor.execute("INSERT INTO test_removeconstraint_pony (id, pink, weight) VALUES (1, 1, 1.0)") + + # Test deconstruction + definition = operation.deconstruct() + self.assertEqual(definition[0], "RemoveConstraint") + self.assertEqual(definition[1], []) + self.assertEqual(definition[2], {'model_name': "Pony", 'name': "pony_test_constraint"}) + def test_alter_model_options(self): """ Tests the AlterModelOptions operation. diff --git a/tests/migrations/test_state.py b/tests/migrations/test_state.py index 255e14beff..6a7a087ac5 100644 --- a/tests/migrations/test_state.py +++ b/tests/migrations/test_state.py @@ -127,7 +127,12 @@ class StateTests(SimpleTestCase): self.assertIs(author_state.fields[3][1].null, True) self.assertEqual( author_state.options, - {"unique_together": {("name", "bio")}, "index_together": {("bio", "age")}, "indexes": []} + { + "unique_together": {("name", "bio")}, + "index_together": {("bio", "age")}, + "indexes": [], + "constraints": [], + } ) self.assertEqual(author_state.bases, (models.Model,)) @@ -139,14 +144,17 @@ class StateTests(SimpleTestCase): self.assertEqual(book_state.fields[3][1].__class__.__name__, "ManyToManyField") self.assertEqual( book_state.options, - {"verbose_name": "tome", "db_table": "test_tome", "indexes": [book_index]}, + {"verbose_name": "tome", "db_table": "test_tome", "indexes": [book_index], "constraints": []}, ) self.assertEqual(book_state.bases, (models.Model,)) self.assertEqual(author_proxy_state.app_label, "migrations") self.assertEqual(author_proxy_state.name, "AuthorProxy") self.assertEqual(author_proxy_state.fields, []) - self.assertEqual(author_proxy_state.options, {"proxy": True, "ordering": ["name"], "indexes": []}) + self.assertEqual( + author_proxy_state.options, + {"proxy": True, "ordering": ["name"], "indexes": [], "constraints": []}, + ) self.assertEqual(author_proxy_state.bases, ("migrations.author",)) self.assertEqual(sub_author_state.app_label, "migrations") @@ -1002,7 +1010,7 @@ class ModelStateTests(SimpleTestCase): self.assertEqual(author_state.fields[1][1].max_length, 255) self.assertIs(author_state.fields[2][1].null, False) self.assertIs(author_state.fields[3][1].null, True) - self.assertEqual(author_state.options, {'swappable': 'TEST_SWAPPABLE_MODEL', 'indexes': []}) + self.assertEqual(author_state.options, {'swappable': 'TEST_SWAPPABLE_MODEL', 'indexes': [], "constraints": []}) self.assertEqual(author_state.bases, (models.Model,)) self.assertEqual(author_state.managers, []) @@ -1047,7 +1055,7 @@ class ModelStateTests(SimpleTestCase): self.assertEqual(station_state.fields[2][1].null, False) self.assertEqual( station_state.options, - {'abstract': False, 'swappable': 'TEST_SWAPPABLE_MODEL', 'indexes': []} + {'abstract': False, 'swappable': 'TEST_SWAPPABLE_MODEL', 'indexes': [], 'constraints': []} ) self.assertEqual(station_state.bases, ('migrations.searchablelocation',)) self.assertEqual(station_state.managers, []) @@ -1129,6 +1137,21 @@ class ModelStateTests(SimpleTestCase): index_names = [index.name for index in model_state.options['indexes']] self.assertEqual(index_names, ['foo_idx']) + @isolate_apps('migrations') + def test_from_model_constraints(self): + class ModelWithConstraints(models.Model): + size = models.IntegerField() + + class Meta: + constraints = [models.CheckConstraint(models.Q(size__gt=1), 'size_gt_1')] + + state = ModelState.from_model(ModelWithConstraints) + model_constraints = ModelWithConstraints._meta.constraints + state_constraints = state.options['constraints'] + self.assertEqual(model_constraints, state_constraints) + self.assertIsNot(model_constraints, state_constraints) + self.assertIsNot(model_constraints[0], state_constraints[0]) + class RelatedModelsTests(SimpleTestCase): diff --git a/tests/queries/test_query.py b/tests/queries/test_query.py new file mode 100644 index 0000000000..10ea8eb0f2 --- /dev/null +++ b/tests/queries/test_query.py @@ -0,0 +1,95 @@ +from datetime import datetime + +from django.core.exceptions import FieldError +from django.db.models import CharField, F, Q +from django.db.models.expressions import SimpleCol +from django.db.models.fields.related_lookups import RelatedIsNull +from django.db.models.functions import Lower +from django.db.models.lookups import Exact, GreaterThan, IsNull, LessThan +from django.db.models.sql.query import Query +from django.db.models.sql.where import OR +from django.test import TestCase + +from .models import Author, Item, ObjectC, Ranking + + +class TestQuery(TestCase): + def test_simple_query(self): + query = Query(Author) + where = query.build_where(Q(num__gt=2)) + lookup = where.children[0] + self.assertIsInstance(lookup, GreaterThan) + self.assertEqual(lookup.rhs, 2) + self.assertEqual(lookup.lhs.target, Author._meta.get_field('num')) + + def test_complex_query(self): + query = Query(Author) + where = query.build_where(Q(num__gt=2) | Q(num__lt=0)) + self.assertEqual(where.connector, OR) + + lookup = where.children[0] + self.assertIsInstance(lookup, GreaterThan) + self.assertEqual(lookup.rhs, 2) + self.assertEqual(lookup.lhs.target, Author._meta.get_field('num')) + + lookup = where.children[1] + self.assertIsInstance(lookup, LessThan) + self.assertEqual(lookup.rhs, 0) + self.assertEqual(lookup.lhs.target, Author._meta.get_field('num')) + + def test_multiple_fields(self): + query = Query(Item) + where = query.build_where(Q(modified__gt=F('created'))) + lookup = where.children[0] + self.assertIsInstance(lookup, GreaterThan) + self.assertIsInstance(lookup.rhs, SimpleCol) + self.assertIsInstance(lookup.lhs, SimpleCol) + self.assertEqual(lookup.rhs.target, Item._meta.get_field('created')) + self.assertEqual(lookup.lhs.target, Item._meta.get_field('modified')) + + def test_transform(self): + query = Query(Author) + CharField.register_lookup(Lower, 'lower') + try: + where = query.build_where(~Q(name__lower='foo')) + finally: + CharField._unregister_lookup(Lower, 'lower') + lookup = where.children[0] + self.assertIsInstance(lookup, Exact) + self.assertIsInstance(lookup.lhs, Lower) + self.assertIsInstance(lookup.lhs.lhs, SimpleCol) + self.assertEqual(lookup.lhs.lhs.target, Author._meta.get_field('name')) + + def test_negated_nullable(self): + query = Query(Item) + where = query.build_where(~Q(modified__lt=datetime(2017, 1, 1))) + self.assertTrue(where.negated) + lookup = where.children[0] + self.assertIsInstance(lookup, LessThan) + self.assertEqual(lookup.lhs.target, Item._meta.get_field('modified')) + lookup = where.children[1] + self.assertIsInstance(lookup, IsNull) + self.assertEqual(lookup.lhs.target, Item._meta.get_field('modified')) + + def test_foreign_key(self): + query = Query(Item) + msg = 'Joined field references are not permitted in this query' + with self.assertRaisesMessage(FieldError, msg): + query.build_where(Q(creator__num__gt=2)) + + def test_foreign_key_f(self): + query = Query(Ranking) + with self.assertRaises(FieldError): + query.build_where(Q(rank__gt=F('author__num'))) + + def test_foreign_key_exclusive(self): + query = Query(ObjectC) + where = query.build_where(Q(objecta=None) | Q(objectb=None)) + a_isnull = where.children[0] + self.assertIsInstance(a_isnull, RelatedIsNull) + self.assertIsInstance(a_isnull.lhs, SimpleCol) + self.assertEqual(a_isnull.lhs.target, ObjectC._meta.get_field('objecta')) + b_isnull = where.children[1] + self.assertIsInstance(b_isnull, RelatedIsNull) + self.assertIsInstance(b_isnull.lhs, SimpleCol) + self.assertEqual(b_isnull.lhs.target, ObjectC._meta.get_field('objectb'))