mirror of https://github.com/django/django.git
Fixed #470 -- Added support for database defaults on fields.
Special thanks to Hannes Ljungberg for finding multiple implementation gaps. Thanks also to Simon Charette, Adam Johnson, and Mariusz Felisiak for reviews.
This commit is contained in:
parent
599f3e2cda
commit
7414704e88
1
AUTHORS
1
AUTHORS
|
@ -587,6 +587,7 @@ answer newbie questions, and generally made Django that much better:
|
||||||
lerouxb@gmail.com
|
lerouxb@gmail.com
|
||||||
Lex Berezhny <lex@damoti.com>
|
Lex Berezhny <lex@damoti.com>
|
||||||
Liang Feng <hutuworm@gmail.com>
|
Liang Feng <hutuworm@gmail.com>
|
||||||
|
Lily Foote
|
||||||
limodou
|
limodou
|
||||||
Lincoln Smith <lincoln.smith@anu.edu.au>
|
Lincoln Smith <lincoln.smith@anu.edu.au>
|
||||||
Liu Yijie <007gzs@gmail.com>
|
Liu Yijie <007gzs@gmail.com>
|
||||||
|
|
|
@ -201,6 +201,15 @@ class BaseDatabaseFeatures:
|
||||||
# Does the backend require literal defaults, rather than parameterized ones?
|
# Does the backend require literal defaults, rather than parameterized ones?
|
||||||
requires_literal_defaults = False
|
requires_literal_defaults = False
|
||||||
|
|
||||||
|
# Does the backend support functions in defaults?
|
||||||
|
supports_expression_defaults = True
|
||||||
|
|
||||||
|
# Does the backend support the DEFAULT keyword in insert queries?
|
||||||
|
supports_default_keyword_in_insert = True
|
||||||
|
|
||||||
|
# Does the backend support the DEFAULT keyword in bulk insert queries?
|
||||||
|
supports_default_keyword_in_bulk_insert = True
|
||||||
|
|
||||||
# Does the backend require a connection reset after each material schema change?
|
# Does the backend require a connection reset after each material schema change?
|
||||||
connection_persists_old_columns = False
|
connection_persists_old_columns = False
|
||||||
|
|
||||||
|
@ -361,6 +370,9 @@ class BaseDatabaseFeatures:
|
||||||
# SQL template override for tests.aggregation.tests.NowUTC
|
# SQL template override for tests.aggregation.tests.NowUTC
|
||||||
test_now_utc_template = None
|
test_now_utc_template = None
|
||||||
|
|
||||||
|
# SQL to create a model instance using the database defaults.
|
||||||
|
insert_test_table_with_defaults = None
|
||||||
|
|
||||||
# A set of dotted paths to tests in Django's test suite that are expected
|
# A set of dotted paths to tests in Django's test suite that are expected
|
||||||
# to fail on this database.
|
# to fail on this database.
|
||||||
django_test_expected_failures = set()
|
django_test_expected_failures = set()
|
||||||
|
|
|
@ -12,7 +12,7 @@ from django.db.backends.ddl_references import (
|
||||||
Table,
|
Table,
|
||||||
)
|
)
|
||||||
from django.db.backends.utils import names_digest, split_identifier, truncate_name
|
from django.db.backends.utils import names_digest, split_identifier, truncate_name
|
||||||
from django.db.models import Deferrable, Index
|
from django.db.models import NOT_PROVIDED, Deferrable, Index
|
||||||
from django.db.models.sql import Query
|
from django.db.models.sql import Query
|
||||||
from django.db.transaction import TransactionManagementError, atomic
|
from django.db.transaction import TransactionManagementError, atomic
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
@ -296,6 +296,12 @@ class BaseDatabaseSchemaEditor:
|
||||||
yield self._comment_sql(field.db_comment)
|
yield self._comment_sql(field.db_comment)
|
||||||
# Work out nullability.
|
# Work out nullability.
|
||||||
null = field.null
|
null = field.null
|
||||||
|
# Add database default.
|
||||||
|
if field.db_default is not NOT_PROVIDED:
|
||||||
|
default_sql, default_params = self.db_default_sql(field)
|
||||||
|
yield f"DEFAULT {default_sql}"
|
||||||
|
params.extend(default_params)
|
||||||
|
include_default = False
|
||||||
# Include a default value, if requested.
|
# Include a default value, if requested.
|
||||||
include_default = (
|
include_default = (
|
||||||
include_default
|
include_default
|
||||||
|
@ -400,6 +406,22 @@ class BaseDatabaseSchemaEditor:
|
||||||
"""
|
"""
|
||||||
return "%s"
|
return "%s"
|
||||||
|
|
||||||
|
def db_default_sql(self, field):
|
||||||
|
"""Return the sql and params for the field's database default."""
|
||||||
|
from django.db.models.expressions import Value
|
||||||
|
|
||||||
|
sql = "%s" if isinstance(field.db_default, Value) else "(%s)"
|
||||||
|
query = Query(model=field.model)
|
||||||
|
compiler = query.get_compiler(connection=self.connection)
|
||||||
|
default_sql, params = compiler.compile(field.db_default)
|
||||||
|
if self.connection.features.requires_literal_defaults:
|
||||||
|
# Some databases doesn't support parameterized defaults (Oracle,
|
||||||
|
# SQLite). If this is the case, the individual schema backend
|
||||||
|
# should implement prepare_default().
|
||||||
|
default_sql %= tuple(self.prepare_default(p) for p in params)
|
||||||
|
params = []
|
||||||
|
return sql % default_sql, params
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _effective_default(field):
|
def _effective_default(field):
|
||||||
# This method allows testing its logic without a connection.
|
# This method allows testing its logic without a connection.
|
||||||
|
@ -1025,6 +1047,21 @@ class BaseDatabaseSchemaEditor:
|
||||||
)
|
)
|
||||||
actions.append(fragment)
|
actions.append(fragment)
|
||||||
post_actions.extend(other_actions)
|
post_actions.extend(other_actions)
|
||||||
|
|
||||||
|
if new_field.db_default is not NOT_PROVIDED:
|
||||||
|
if (
|
||||||
|
old_field.db_default is NOT_PROVIDED
|
||||||
|
or new_field.db_default != old_field.db_default
|
||||||
|
):
|
||||||
|
actions.append(
|
||||||
|
self._alter_column_database_default_sql(model, old_field, new_field)
|
||||||
|
)
|
||||||
|
elif old_field.db_default is not NOT_PROVIDED:
|
||||||
|
actions.append(
|
||||||
|
self._alter_column_database_default_sql(
|
||||||
|
model, old_field, new_field, drop=True
|
||||||
|
)
|
||||||
|
)
|
||||||
# When changing a column NULL constraint to NOT NULL with a given
|
# When changing a column NULL constraint to NOT NULL with a given
|
||||||
# default value, we need to perform 4 steps:
|
# default value, we need to perform 4 steps:
|
||||||
# 1. Add a default for new incoming writes
|
# 1. Add a default for new incoming writes
|
||||||
|
@ -1033,7 +1070,11 @@ class BaseDatabaseSchemaEditor:
|
||||||
# 4. Drop the default again.
|
# 4. Drop the default again.
|
||||||
# Default change?
|
# Default change?
|
||||||
needs_database_default = False
|
needs_database_default = False
|
||||||
if old_field.null and not new_field.null:
|
if (
|
||||||
|
old_field.null
|
||||||
|
and not new_field.null
|
||||||
|
and new_field.db_default is NOT_PROVIDED
|
||||||
|
):
|
||||||
old_default = self.effective_default(old_field)
|
old_default = self.effective_default(old_field)
|
||||||
new_default = self.effective_default(new_field)
|
new_default = self.effective_default(new_field)
|
||||||
if (
|
if (
|
||||||
|
@ -1051,9 +1092,9 @@ class BaseDatabaseSchemaEditor:
|
||||||
if fragment:
|
if fragment:
|
||||||
null_actions.append(fragment)
|
null_actions.append(fragment)
|
||||||
# Only if we have a default and there is a change from NULL to NOT NULL
|
# Only if we have a default and there is a change from NULL to NOT NULL
|
||||||
four_way_default_alteration = new_field.has_default() and (
|
four_way_default_alteration = (
|
||||||
old_field.null and not new_field.null
|
new_field.has_default() or new_field.db_default is not NOT_PROVIDED
|
||||||
)
|
) and (old_field.null and not new_field.null)
|
||||||
if actions or null_actions:
|
if actions or null_actions:
|
||||||
if not four_way_default_alteration:
|
if not four_way_default_alteration:
|
||||||
# If we don't have to do a 4-way default alteration we can
|
# If we don't have to do a 4-way default alteration we can
|
||||||
|
@ -1074,15 +1115,20 @@ class BaseDatabaseSchemaEditor:
|
||||||
params,
|
params,
|
||||||
)
|
)
|
||||||
if four_way_default_alteration:
|
if four_way_default_alteration:
|
||||||
|
if new_field.db_default is NOT_PROVIDED:
|
||||||
|
default_sql = "%s"
|
||||||
|
params = [new_default]
|
||||||
|
else:
|
||||||
|
default_sql, params = self.db_default_sql(new_field)
|
||||||
# Update existing rows with default value
|
# Update existing rows with default value
|
||||||
self.execute(
|
self.execute(
|
||||||
self.sql_update_with_default
|
self.sql_update_with_default
|
||||||
% {
|
% {
|
||||||
"table": self.quote_name(model._meta.db_table),
|
"table": self.quote_name(model._meta.db_table),
|
||||||
"column": self.quote_name(new_field.column),
|
"column": self.quote_name(new_field.column),
|
||||||
"default": "%s",
|
"default": default_sql,
|
||||||
},
|
},
|
||||||
[new_default],
|
params,
|
||||||
)
|
)
|
||||||
# Since we didn't run a NOT NULL change before we need to do it
|
# Since we didn't run a NOT NULL change before we need to do it
|
||||||
# now
|
# now
|
||||||
|
@ -1264,6 +1310,34 @@ class BaseDatabaseSchemaEditor:
|
||||||
params,
|
params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _alter_column_database_default_sql(
|
||||||
|
self, model, old_field, new_field, drop=False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Hook to specialize column database default alteration.
|
||||||
|
|
||||||
|
Return a (sql, params) fragment to add or drop (depending on the drop
|
||||||
|
argument) a default to new_field's column.
|
||||||
|
"""
|
||||||
|
if drop:
|
||||||
|
sql = self.sql_alter_column_no_default
|
||||||
|
default_sql = ""
|
||||||
|
params = []
|
||||||
|
else:
|
||||||
|
sql = self.sql_alter_column_default
|
||||||
|
default_sql, params = self.db_default_sql(new_field)
|
||||||
|
|
||||||
|
new_db_params = new_field.db_parameters(connection=self.connection)
|
||||||
|
return (
|
||||||
|
sql
|
||||||
|
% {
|
||||||
|
"column": self.quote_name(new_field.column),
|
||||||
|
"type": new_db_params["type"],
|
||||||
|
"default": default_sql,
|
||||||
|
},
|
||||||
|
params,
|
||||||
|
)
|
||||||
|
|
||||||
def _alter_column_type_sql(
|
def _alter_column_type_sql(
|
||||||
self, model, old_field, new_field, new_type, old_collation, new_collation
|
self, model, old_field, new_field, new_type, old_collation, new_collation
|
||||||
):
|
):
|
||||||
|
|
|
@ -51,6 +51,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||||
# COLLATE must be wrapped in parentheses because MySQL treats COLLATE as an
|
# COLLATE must be wrapped in parentheses because MySQL treats COLLATE as an
|
||||||
# indexed expression.
|
# indexed expression.
|
||||||
collate_as_index_expression = True
|
collate_as_index_expression = True
|
||||||
|
insert_test_table_with_defaults = "INSERT INTO {} () VALUES ()"
|
||||||
|
|
||||||
supports_order_by_nulls_modifier = False
|
supports_order_by_nulls_modifier = False
|
||||||
order_by_nulls_first = True
|
order_by_nulls_first = True
|
||||||
|
@ -342,3 +343,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||||
if self.connection.mysql_is_mariadb:
|
if self.connection.mysql_is_mariadb:
|
||||||
return self.connection.mysql_version >= (10, 5, 2)
|
return self.connection.mysql_version >= (10, 5, 2)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def supports_expression_defaults(self):
|
||||||
|
if self.connection.mysql_is_mariadb:
|
||||||
|
return True
|
||||||
|
return self.connection.mysql_version >= (8, 0, 13)
|
||||||
|
|
|
@ -209,11 +209,15 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||||
self._create_missing_fk_index(model, fields=fields)
|
self._create_missing_fk_index(model, fields=fields)
|
||||||
return super()._delete_composed_index(model, fields, *args)
|
return super()._delete_composed_index(model, fields, *args)
|
||||||
|
|
||||||
def _set_field_new_type_null_status(self, field, new_type):
|
def _set_field_new_type(self, field, new_type):
|
||||||
"""
|
"""
|
||||||
Keep the null property of the old field. If it has changed, it will be
|
Keep the NULL and DEFAULT properties of the old field. If it has
|
||||||
handled separately.
|
changed, it will be handled separately.
|
||||||
"""
|
"""
|
||||||
|
if field.db_default is not NOT_PROVIDED:
|
||||||
|
default_sql, params = self.db_default_sql(field)
|
||||||
|
default_sql %= tuple(self.quote_value(p) for p in params)
|
||||||
|
new_type += f" DEFAULT {default_sql}"
|
||||||
if field.null:
|
if field.null:
|
||||||
new_type += " NULL"
|
new_type += " NULL"
|
||||||
else:
|
else:
|
||||||
|
@ -223,7 +227,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||||
def _alter_column_type_sql(
|
def _alter_column_type_sql(
|
||||||
self, model, old_field, new_field, new_type, old_collation, new_collation
|
self, model, old_field, new_field, new_type, old_collation, new_collation
|
||||||
):
|
):
|
||||||
new_type = self._set_field_new_type_null_status(old_field, new_type)
|
new_type = self._set_field_new_type(old_field, new_type)
|
||||||
return super()._alter_column_type_sql(
|
return super()._alter_column_type_sql(
|
||||||
model, old_field, new_field, new_type, old_collation, new_collation
|
model, old_field, new_field, new_type, old_collation, new_collation
|
||||||
)
|
)
|
||||||
|
@ -242,7 +246,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||||
return field_db_params["check"]
|
return field_db_params["check"]
|
||||||
|
|
||||||
def _rename_field_sql(self, table, old_field, new_field, new_type):
|
def _rename_field_sql(self, table, old_field, new_field, new_type):
|
||||||
new_type = self._set_field_new_type_null_status(old_field, new_type)
|
new_type = self._set_field_new_type(old_field, new_type)
|
||||||
return super()._rename_field_sql(table, old_field, new_field, new_type)
|
return super()._rename_field_sql(table, old_field, new_field, new_type)
|
||||||
|
|
||||||
def _alter_column_comment_sql(self, model, new_field, new_type, new_db_comment):
|
def _alter_column_comment_sql(self, model, new_field, new_type, new_db_comment):
|
||||||
|
@ -252,3 +256,18 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||||
def _comment_sql(self, comment):
|
def _comment_sql(self, comment):
|
||||||
comment_sql = super()._comment_sql(comment)
|
comment_sql = super()._comment_sql(comment)
|
||||||
return f" COMMENT {comment_sql}"
|
return f" COMMENT {comment_sql}"
|
||||||
|
|
||||||
|
def _alter_column_null_sql(self, model, old_field, new_field):
|
||||||
|
if new_field.db_default is NOT_PROVIDED:
|
||||||
|
return super()._alter_column_null_sql(model, old_field, new_field)
|
||||||
|
|
||||||
|
new_db_params = new_field.db_parameters(connection=self.connection)
|
||||||
|
type_sql = self._set_field_new_type(new_field, new_db_params["type"])
|
||||||
|
return (
|
||||||
|
"MODIFY %(column)s %(type)s"
|
||||||
|
% {
|
||||||
|
"column": self.quote_name(new_field.column),
|
||||||
|
"type": type_sql,
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
|
|
@ -32,6 +32,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||||
atomic_transactions = False
|
atomic_transactions = False
|
||||||
nulls_order_largest = True
|
nulls_order_largest = True
|
||||||
requires_literal_defaults = True
|
requires_literal_defaults = True
|
||||||
|
supports_default_keyword_in_bulk_insert = False
|
||||||
closed_cursor_error_class = InterfaceError
|
closed_cursor_error_class = InterfaceError
|
||||||
bare_select_suffix = " FROM DUAL"
|
bare_select_suffix = " FROM DUAL"
|
||||||
# Select for update with limit can be achieved on Oracle, but not with the
|
# Select for update with limit can be achieved on Oracle, but not with the
|
||||||
|
@ -130,6 +131,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||||
"annotations.tests.NonAggregateAnnotationTestCase."
|
"annotations.tests.NonAggregateAnnotationTestCase."
|
||||||
"test_custom_functions_can_ref_other_functions",
|
"test_custom_functions_can_ref_other_functions",
|
||||||
}
|
}
|
||||||
|
insert_test_table_with_defaults = (
|
||||||
|
"INSERT INTO {} VALUES (DEFAULT, DEFAULT, DEFAULT)"
|
||||||
|
)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def introspected_field_types(self):
|
def introspected_field_types(self):
|
||||||
|
|
|
@ -156,7 +156,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
|
||||||
field_map = {
|
field_map = {
|
||||||
column: (
|
column: (
|
||||||
display_size,
|
display_size,
|
||||||
default if default != "NULL" else None,
|
default.rstrip() if default and default != "NULL" else None,
|
||||||
collation,
|
collation,
|
||||||
is_autofield,
|
is_autofield,
|
||||||
is_json,
|
is_json,
|
||||||
|
|
|
@ -198,7 +198,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||||
return self.normalize_name(for_name + "_" + suffix)
|
return self.normalize_name(for_name + "_" + suffix)
|
||||||
|
|
||||||
def prepare_default(self, value):
|
def prepare_default(self, value):
|
||||||
return self.quote_value(value)
|
# Replace % with %% as %-formatting is applied in
|
||||||
|
# FormatStylePlaceholderCursor._fix_for_params().
|
||||||
|
return self.quote_value(value).replace("%", "%%")
|
||||||
|
|
||||||
def _field_should_be_indexed(self, model, field):
|
def _field_should_be_indexed(self, model, field):
|
||||||
create_index = super()._field_should_be_indexed(model, field)
|
create_index = super()._field_should_be_indexed(model, field)
|
||||||
|
|
|
@ -76,6 +76,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||||
"swedish_ci": "sv-x-icu",
|
"swedish_ci": "sv-x-icu",
|
||||||
}
|
}
|
||||||
test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'"
|
test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'"
|
||||||
|
insert_test_table_with_defaults = "INSERT INTO {} DEFAULT VALUES"
|
||||||
|
|
||||||
django_test_skips = {
|
django_test_skips = {
|
||||||
"opclasses are PostgreSQL only.": {
|
"opclasses are PostgreSQL only.": {
|
||||||
|
|
|
@ -59,6 +59,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||||
PRIMARY KEY(column_1, column_2)
|
PRIMARY KEY(column_1, column_2)
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
insert_test_table_with_defaults = 'INSERT INTO {} ("null") VALUES (1)'
|
||||||
|
supports_default_keyword_in_insert = False
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def django_test_skips(self):
|
def django_test_skips(self):
|
||||||
|
|
|
@ -6,7 +6,7 @@ from django.db import NotSupportedError
|
||||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||||
from django.db.backends.ddl_references import Statement
|
from django.db.backends.ddl_references import Statement
|
||||||
from django.db.backends.utils import strip_quotes
|
from django.db.backends.utils import strip_quotes
|
||||||
from django.db.models import UniqueConstraint
|
from django.db.models import NOT_PROVIDED, UniqueConstraint
|
||||||
from django.db.transaction import atomic
|
from django.db.transaction import atomic
|
||||||
|
|
||||||
|
|
||||||
|
@ -233,9 +233,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||||
if create_field:
|
if create_field:
|
||||||
body[create_field.name] = create_field
|
body[create_field.name] = create_field
|
||||||
# Choose a default and insert it into the copy map
|
# Choose a default and insert it into the copy map
|
||||||
if not create_field.many_to_many and create_field.concrete:
|
if (
|
||||||
|
create_field.db_default is NOT_PROVIDED
|
||||||
|
and not create_field.many_to_many
|
||||||
|
and create_field.concrete
|
||||||
|
):
|
||||||
mapping[create_field.column] = self.prepare_default(
|
mapping[create_field.column] = self.prepare_default(
|
||||||
self.effective_default(create_field),
|
self.effective_default(create_field)
|
||||||
)
|
)
|
||||||
# Add in any altered fields
|
# Add in any altered fields
|
||||||
for alter_field in alter_fields:
|
for alter_field in alter_fields:
|
||||||
|
@ -244,9 +248,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||||
mapping.pop(old_field.column, None)
|
mapping.pop(old_field.column, None)
|
||||||
body[new_field.name] = new_field
|
body[new_field.name] = new_field
|
||||||
if old_field.null and not new_field.null:
|
if old_field.null and not new_field.null:
|
||||||
|
if new_field.db_default is NOT_PROVIDED:
|
||||||
|
default = self.prepare_default(self.effective_default(new_field))
|
||||||
|
else:
|
||||||
|
default, _ = self.db_default_sql(new_field)
|
||||||
case_sql = "coalesce(%(col)s, %(default)s)" % {
|
case_sql = "coalesce(%(col)s, %(default)s)" % {
|
||||||
"col": self.quote_name(old_field.column),
|
"col": self.quote_name(old_field.column),
|
||||||
"default": self.prepare_default(self.effective_default(new_field)),
|
"default": default,
|
||||||
}
|
}
|
||||||
mapping[new_field.column] = case_sql
|
mapping[new_field.column] = case_sql
|
||||||
else:
|
else:
|
||||||
|
@ -381,6 +389,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||||
|
|
||||||
def add_field(self, model, field):
|
def add_field(self, model, field):
|
||||||
"""Create a field on a model."""
|
"""Create a field on a model."""
|
||||||
|
from django.db.models.expressions import Value
|
||||||
|
|
||||||
# Special-case implicit M2M tables.
|
# Special-case implicit M2M tables.
|
||||||
if field.many_to_many and field.remote_field.through._meta.auto_created:
|
if field.many_to_many and field.remote_field.through._meta.auto_created:
|
||||||
self.create_model(field.remote_field.through)
|
self.create_model(field.remote_field.through)
|
||||||
|
@ -394,6 +404,12 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||||
# COLUMN statement because DROP DEFAULT is not supported in
|
# COLUMN statement because DROP DEFAULT is not supported in
|
||||||
# ALTER TABLE.
|
# ALTER TABLE.
|
||||||
or self.effective_default(field) is not None
|
or self.effective_default(field) is not None
|
||||||
|
# Fields with non-constant defaults cannot by handled by ALTER
|
||||||
|
# TABLE ADD COLUMN statement.
|
||||||
|
or (
|
||||||
|
field.db_default is not NOT_PROVIDED
|
||||||
|
and not isinstance(field.db_default, Value)
|
||||||
|
)
|
||||||
):
|
):
|
||||||
self._remake_table(model, create_field=field)
|
self._remake_table(model, create_field=field)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1040,6 +1040,7 @@ class MigrationAutodetector:
|
||||||
preserve_default = (
|
preserve_default = (
|
||||||
field.null
|
field.null
|
||||||
or field.has_default()
|
or field.has_default()
|
||||||
|
or field.db_default is not models.NOT_PROVIDED
|
||||||
or field.many_to_many
|
or field.many_to_many
|
||||||
or (field.blank and field.empty_strings_allowed)
|
or (field.blank and field.empty_strings_allowed)
|
||||||
or (isinstance(field, time_fields) and field.auto_now)
|
or (isinstance(field, time_fields) and field.auto_now)
|
||||||
|
@ -1187,6 +1188,7 @@ class MigrationAutodetector:
|
||||||
old_field.null
|
old_field.null
|
||||||
and not new_field.null
|
and not new_field.null
|
||||||
and not new_field.has_default()
|
and not new_field.has_default()
|
||||||
|
and new_field.db_default is models.NOT_PROVIDED
|
||||||
and not new_field.many_to_many
|
and not new_field.many_to_many
|
||||||
):
|
):
|
||||||
field = new_field.clone()
|
field = new_field.clone()
|
||||||
|
|
|
@ -971,8 +971,10 @@ class Model(AltersData, metaclass=ModelBase):
|
||||||
not raw
|
not raw
|
||||||
and not force_insert
|
and not force_insert
|
||||||
and self._state.adding
|
and self._state.adding
|
||||||
and meta.pk.default
|
and (
|
||||||
and meta.pk.default is not NOT_PROVIDED
|
(meta.pk.default and meta.pk.default is not NOT_PROVIDED)
|
||||||
|
or (meta.pk.db_default and meta.pk.db_default is not NOT_PROVIDED)
|
||||||
|
)
|
||||||
):
|
):
|
||||||
force_insert = True
|
force_insert = True
|
||||||
# If possible, try an UPDATE. If that doesn't update anything, do an INSERT.
|
# If possible, try an UPDATE. If that doesn't update anything, do an INSERT.
|
||||||
|
|
|
@ -176,6 +176,8 @@ class BaseExpression:
|
||||||
filterable = True
|
filterable = True
|
||||||
# Can the expression can be used as a source expression in Window?
|
# Can the expression can be used as a source expression in Window?
|
||||||
window_compatible = False
|
window_compatible = False
|
||||||
|
# Can the expression be used as a database default value?
|
||||||
|
allowed_default = False
|
||||||
|
|
||||||
def __init__(self, output_field=None):
|
def __init__(self, output_field=None):
|
||||||
if output_field is not None:
|
if output_field is not None:
|
||||||
|
@ -733,6 +735,10 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
|
||||||
c.rhs = rhs
|
c.rhs = rhs
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def allowed_default(self):
|
||||||
|
return self.lhs.allowed_default and self.rhs.allowed_default
|
||||||
|
|
||||||
|
|
||||||
class DurationExpression(CombinedExpression):
|
class DurationExpression(CombinedExpression):
|
||||||
def compile(self, side, compiler, connection):
|
def compile(self, side, compiler, connection):
|
||||||
|
@ -804,6 +810,8 @@ class TemporalSubtraction(CombinedExpression):
|
||||||
class F(Combinable):
|
class F(Combinable):
|
||||||
"""An object capable of resolving references to existing query objects."""
|
"""An object capable of resolving references to existing query objects."""
|
||||||
|
|
||||||
|
allowed_default = False
|
||||||
|
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
"""
|
"""
|
||||||
Arguments:
|
Arguments:
|
||||||
|
@ -987,6 +995,10 @@ class Func(SQLiteNumericMixin, Expression):
|
||||||
copy.extra = self.extra.copy()
|
copy.extra = self.extra.copy()
|
||||||
return copy
|
return copy
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def allowed_default(self):
|
||||||
|
return all(expression.allowed_default for expression in self.source_expressions)
|
||||||
|
|
||||||
|
|
||||||
@deconstructible(path="django.db.models.Value")
|
@deconstructible(path="django.db.models.Value")
|
||||||
class Value(SQLiteNumericMixin, Expression):
|
class Value(SQLiteNumericMixin, Expression):
|
||||||
|
@ -995,6 +1007,7 @@ class Value(SQLiteNumericMixin, Expression):
|
||||||
# Provide a default value for `for_save` in order to allow unresolved
|
# Provide a default value for `for_save` in order to allow unresolved
|
||||||
# instances to be compiled until a decision is taken in #25425.
|
# instances to be compiled until a decision is taken in #25425.
|
||||||
for_save = False
|
for_save = False
|
||||||
|
allowed_default = True
|
||||||
|
|
||||||
def __init__(self, value, output_field=None):
|
def __init__(self, value, output_field=None):
|
||||||
"""
|
"""
|
||||||
|
@ -1069,6 +1082,8 @@ class Value(SQLiteNumericMixin, Expression):
|
||||||
|
|
||||||
|
|
||||||
class RawSQL(Expression):
|
class RawSQL(Expression):
|
||||||
|
allowed_default = True
|
||||||
|
|
||||||
def __init__(self, sql, params, output_field=None):
|
def __init__(self, sql, params, output_field=None):
|
||||||
if output_field is None:
|
if output_field is None:
|
||||||
output_field = fields.Field()
|
output_field = fields.Field()
|
||||||
|
@ -1110,6 +1125,13 @@ class Star(Expression):
|
||||||
return "*", []
|
return "*", []
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseDefault(Expression):
|
||||||
|
"""Placeholder expression for the database default in an insert query."""
|
||||||
|
|
||||||
|
def as_sql(self, compiler, connection):
|
||||||
|
return "DEFAULT", []
|
||||||
|
|
||||||
|
|
||||||
class Col(Expression):
|
class Col(Expression):
|
||||||
contains_column_references = True
|
contains_column_references = True
|
||||||
possibly_multivalued = False
|
possibly_multivalued = False
|
||||||
|
@ -1213,6 +1235,7 @@ class ExpressionList(Func):
|
||||||
|
|
||||||
|
|
||||||
class OrderByList(Func):
|
class OrderByList(Func):
|
||||||
|
allowed_default = False
|
||||||
template = "ORDER BY %(expressions)s"
|
template = "ORDER BY %(expressions)s"
|
||||||
|
|
||||||
def __init__(self, *expressions, **extra):
|
def __init__(self, *expressions, **extra):
|
||||||
|
@ -1270,6 +1293,10 @@ class ExpressionWrapper(SQLiteNumericMixin, Expression):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "{}({})".format(self.__class__.__name__, self.expression)
|
return "{}({})".format(self.__class__.__name__, self.expression)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def allowed_default(self):
|
||||||
|
return self.expression.allowed_default
|
||||||
|
|
||||||
|
|
||||||
class NegatedExpression(ExpressionWrapper):
|
class NegatedExpression(ExpressionWrapper):
|
||||||
"""The logical negation of a conditional expression."""
|
"""The logical negation of a conditional expression."""
|
||||||
|
@ -1397,6 +1424,10 @@ class When(Expression):
|
||||||
cols.extend(source.get_group_by_cols())
|
cols.extend(source.get_group_by_cols())
|
||||||
return cols
|
return cols
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def allowed_default(self):
|
||||||
|
return self.condition.allowed_default and self.result.allowed_default
|
||||||
|
|
||||||
|
|
||||||
@deconstructible(path="django.db.models.Case")
|
@deconstructible(path="django.db.models.Case")
|
||||||
class Case(SQLiteNumericMixin, Expression):
|
class Case(SQLiteNumericMixin, Expression):
|
||||||
|
@ -1494,6 +1525,12 @@ class Case(SQLiteNumericMixin, Expression):
|
||||||
return self.default.get_group_by_cols()
|
return self.default.get_group_by_cols()
|
||||||
return super().get_group_by_cols()
|
return super().get_group_by_cols()
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def allowed_default(self):
|
||||||
|
return self.default.allowed_default and all(
|
||||||
|
case_.allowed_default for case_ in self.cases
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Subquery(BaseExpression, Combinable):
|
class Subquery(BaseExpression, Combinable):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -202,6 +202,7 @@ class Field(RegisterLookupMixin):
|
||||||
validators=(),
|
validators=(),
|
||||||
error_messages=None,
|
error_messages=None,
|
||||||
db_comment=None,
|
db_comment=None,
|
||||||
|
db_default=NOT_PROVIDED,
|
||||||
):
|
):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.verbose_name = verbose_name # May be set by set_attributes_from_name
|
self.verbose_name = verbose_name # May be set by set_attributes_from_name
|
||||||
|
@ -212,6 +213,13 @@ class Field(RegisterLookupMixin):
|
||||||
self.remote_field = rel
|
self.remote_field = rel
|
||||||
self.is_relation = self.remote_field is not None
|
self.is_relation = self.remote_field is not None
|
||||||
self.default = default
|
self.default = default
|
||||||
|
if db_default is not NOT_PROVIDED and not hasattr(
|
||||||
|
db_default, "resolve_expression"
|
||||||
|
):
|
||||||
|
from django.db.models.expressions import Value
|
||||||
|
|
||||||
|
db_default = Value(db_default)
|
||||||
|
self.db_default = db_default
|
||||||
self.editable = editable
|
self.editable = editable
|
||||||
self.serialize = serialize
|
self.serialize = serialize
|
||||||
self.unique_for_date = unique_for_date
|
self.unique_for_date = unique_for_date
|
||||||
|
@ -263,6 +271,7 @@ class Field(RegisterLookupMixin):
|
||||||
return [
|
return [
|
||||||
*self._check_field_name(),
|
*self._check_field_name(),
|
||||||
*self._check_choices(),
|
*self._check_choices(),
|
||||||
|
*self._check_db_default(**kwargs),
|
||||||
*self._check_db_index(),
|
*self._check_db_index(),
|
||||||
*self._check_db_comment(**kwargs),
|
*self._check_db_comment(**kwargs),
|
||||||
*self._check_null_allowed_for_primary_keys(),
|
*self._check_null_allowed_for_primary_keys(),
|
||||||
|
@ -379,6 +388,39 @@ class Field(RegisterLookupMixin):
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def _check_db_default(self, databases=None, **kwargs):
|
||||||
|
from django.db.models.expressions import Value
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.db_default is NOT_PROVIDED
|
||||||
|
or isinstance(self.db_default, Value)
|
||||||
|
or databases is None
|
||||||
|
):
|
||||||
|
return []
|
||||||
|
errors = []
|
||||||
|
for db in databases:
|
||||||
|
if not router.allow_migrate_model(db, self.model):
|
||||||
|
continue
|
||||||
|
connection = connections[db]
|
||||||
|
|
||||||
|
if not getattr(self.db_default, "allowed_default", False) and (
|
||||||
|
connection.features.supports_expression_defaults
|
||||||
|
):
|
||||||
|
msg = f"{self.db_default} cannot be used in db_default."
|
||||||
|
errors.append(checks.Error(msg, obj=self, id="fields.E012"))
|
||||||
|
|
||||||
|
if not (
|
||||||
|
connection.features.supports_expression_defaults
|
||||||
|
or "supports_expression_defaults"
|
||||||
|
in self.model._meta.required_db_features
|
||||||
|
):
|
||||||
|
msg = (
|
||||||
|
f"{connection.display_name} does not support default database "
|
||||||
|
"values with expressions (db_default)."
|
||||||
|
)
|
||||||
|
errors.append(checks.Error(msg, obj=self, id="fields.E011"))
|
||||||
|
return errors
|
||||||
|
|
||||||
def _check_db_index(self):
|
def _check_db_index(self):
|
||||||
if self.db_index not in (None, True, False):
|
if self.db_index not in (None, True, False):
|
||||||
return [
|
return [
|
||||||
|
@ -558,6 +600,7 @@ class Field(RegisterLookupMixin):
|
||||||
"null": False,
|
"null": False,
|
||||||
"db_index": False,
|
"db_index": False,
|
||||||
"default": NOT_PROVIDED,
|
"default": NOT_PROVIDED,
|
||||||
|
"db_default": NOT_PROVIDED,
|
||||||
"editable": True,
|
"editable": True,
|
||||||
"serialize": True,
|
"serialize": True,
|
||||||
"unique_for_date": None,
|
"unique_for_date": None,
|
||||||
|
@ -876,7 +919,10 @@ class Field(RegisterLookupMixin):
|
||||||
@property
|
@property
|
||||||
def db_returning(self):
|
def db_returning(self):
|
||||||
"""Private API intended only to be used by Django itself."""
|
"""Private API intended only to be used by Django itself."""
|
||||||
return False
|
return (
|
||||||
|
self.db_default is not NOT_PROVIDED
|
||||||
|
and connection.features.can_return_columns_from_insert
|
||||||
|
)
|
||||||
|
|
||||||
def set_attributes_from_name(self, name):
|
def set_attributes_from_name(self, name):
|
||||||
self.name = self.name or name
|
self.name = self.name or name
|
||||||
|
@ -929,7 +975,13 @@ class Field(RegisterLookupMixin):
|
||||||
|
|
||||||
def pre_save(self, model_instance, add):
|
def pre_save(self, model_instance, add):
|
||||||
"""Return field's value just before saving."""
|
"""Return field's value just before saving."""
|
||||||
return getattr(model_instance, self.attname)
|
value = getattr(model_instance, self.attname)
|
||||||
|
if not connection.features.supports_default_keyword_in_insert:
|
||||||
|
from django.db.models.expressions import DatabaseDefault
|
||||||
|
|
||||||
|
if isinstance(value, DatabaseDefault):
|
||||||
|
return self.db_default
|
||||||
|
return value
|
||||||
|
|
||||||
def get_prep_value(self, value):
|
def get_prep_value(self, value):
|
||||||
"""Perform preliminary non-db specific value checks and conversions."""
|
"""Perform preliminary non-db specific value checks and conversions."""
|
||||||
|
@ -968,6 +1020,11 @@ class Field(RegisterLookupMixin):
|
||||||
return self.default
|
return self.default
|
||||||
return lambda: self.default
|
return lambda: self.default
|
||||||
|
|
||||||
|
if self.db_default is not NOT_PROVIDED:
|
||||||
|
from django.db.models.expressions import DatabaseDefault
|
||||||
|
|
||||||
|
return DatabaseDefault
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not self.empty_strings_allowed
|
not self.empty_strings_allowed
|
||||||
or self.null
|
or self.null
|
||||||
|
|
|
@ -105,6 +105,7 @@ class Coalesce(Func):
|
||||||
class Collate(Func):
|
class Collate(Func):
|
||||||
function = "COLLATE"
|
function = "COLLATE"
|
||||||
template = "%(expressions)s %(function)s %(collation)s"
|
template = "%(expressions)s %(function)s %(collation)s"
|
||||||
|
allowed_default = False
|
||||||
# Inspired from
|
# Inspired from
|
||||||
# https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
|
# https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
|
||||||
collation_re = _lazy_re_compile(r"^[\w\-]+$")
|
collation_re = _lazy_re_compile(r"^[\w\-]+$")
|
||||||
|
|
|
@ -185,6 +185,10 @@ class Lookup(Expression):
|
||||||
sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
|
sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
|
||||||
return sql, params
|
return sql, params
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def allowed_default(self):
|
||||||
|
return self.lhs.allowed_default and self.rhs.allowed_default
|
||||||
|
|
||||||
|
|
||||||
class Transform(RegisterLookupMixin, Func):
|
class Transform(RegisterLookupMixin, Func):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -654,10 +654,19 @@ class QuerySet(AltersData):
|
||||||
return await sync_to_async(self.create)(**kwargs)
|
return await sync_to_async(self.create)(**kwargs)
|
||||||
|
|
||||||
def _prepare_for_bulk_create(self, objs):
|
def _prepare_for_bulk_create(self, objs):
|
||||||
|
from django.db.models.expressions import DatabaseDefault
|
||||||
|
|
||||||
|
connection = connections[self.db]
|
||||||
for obj in objs:
|
for obj in objs:
|
||||||
if obj.pk is None:
|
if obj.pk is None:
|
||||||
# Populate new PK values.
|
# Populate new PK values.
|
||||||
obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
|
obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
|
||||||
|
if not connection.features.supports_default_keyword_in_bulk_insert:
|
||||||
|
for field in obj._meta.fields:
|
||||||
|
value = getattr(obj, field.attname)
|
||||||
|
if isinstance(value, DatabaseDefault):
|
||||||
|
setattr(obj, field.attname, field.db_default)
|
||||||
|
|
||||||
obj._prepare_related_fields_for_save(operation_name="bulk_create")
|
obj._prepare_related_fields_for_save(operation_name="bulk_create")
|
||||||
|
|
||||||
def _check_bulk_create_options(
|
def _check_bulk_create_options(
|
||||||
|
|
|
@ -175,6 +175,9 @@ Model fields
|
||||||
``choices`` (``<count>`` characters).
|
``choices`` (``<count>`` characters).
|
||||||
* **fields.E010**: ``<field>`` default should be a callable instead of an
|
* **fields.E010**: ``<field>`` default should be a callable instead of an
|
||||||
instance so that it's not shared between all field instances.
|
instance so that it's not shared between all field instances.
|
||||||
|
* **fields.E011**: ``<database>`` does not support default database values with
|
||||||
|
expressions (``db_default``).
|
||||||
|
* **fields.E012**: ``<expression>`` cannot be used in ``db_default``.
|
||||||
* **fields.E100**: ``AutoField``\s must set primary_key=True.
|
* **fields.E100**: ``AutoField``\s must set primary_key=True.
|
||||||
* **fields.E110**: ``BooleanField``\s do not accept null values. *This check
|
* **fields.E110**: ``BooleanField``\s do not accept null values. *This check
|
||||||
appeared before support for null values was added in Django 2.1.*
|
appeared before support for null values was added in Django 2.1.*
|
||||||
|
|
|
@ -996,6 +996,13 @@ calling the appropriate methods on the wrapped expression.
|
||||||
|
|
||||||
.. class:: Expression
|
.. class:: Expression
|
||||||
|
|
||||||
|
.. attribute:: allowed_default
|
||||||
|
|
||||||
|
.. versionadded:: 5.0
|
||||||
|
|
||||||
|
Tells Django that this expression can be used in
|
||||||
|
:attr:`Field.db_default`. Defaults to ``False``.
|
||||||
|
|
||||||
.. attribute:: contains_aggregate
|
.. attribute:: contains_aggregate
|
||||||
|
|
||||||
Tells Django that this expression contains an aggregate and that a
|
Tells Django that this expression contains an aggregate and that a
|
||||||
|
|
|
@ -351,6 +351,38 @@ looking at your Django code. For example::
|
||||||
db_comment="Date and time when the article was published",
|
db_comment="Date and time when the article was published",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
``db_default``
|
||||||
|
--------------
|
||||||
|
|
||||||
|
.. versionadded:: 5.0
|
||||||
|
|
||||||
|
.. attribute:: Field.db_default
|
||||||
|
|
||||||
|
The database-computed default value for this field. This can be a literal value
|
||||||
|
or a database function, such as :class:`~django.db.models.functions.Now`::
|
||||||
|
|
||||||
|
created = models.DateTimeField(db_default=Now())
|
||||||
|
|
||||||
|
More complex expressions can be used, as long as they are made from literals
|
||||||
|
and database functions::
|
||||||
|
|
||||||
|
month_due = models.DateField(
|
||||||
|
db_default=TruncMonth(
|
||||||
|
Now() + timedelta(days=90),
|
||||||
|
output_field=models.DateField(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
Database defaults cannot reference other fields or models. For example, this is
|
||||||
|
invalid::
|
||||||
|
|
||||||
|
end = models.IntegerField(db_default=F("start") + 50)
|
||||||
|
|
||||||
|
If both ``db_default`` and :attr:`Field.default` are set, ``default`` will take
|
||||||
|
precedence when creating instances in Python code. ``db_default`` will still be
|
||||||
|
set at the database level and will be used when inserting rows outside of the
|
||||||
|
ORM or when adding a new field in a migration.
|
||||||
|
|
||||||
``db_index``
|
``db_index``
|
||||||
------------
|
------------
|
||||||
|
|
||||||
|
@ -408,6 +440,9 @@ The default value is used when new model instances are created and a value
|
||||||
isn't provided for the field. When the field is a primary key, the default is
|
isn't provided for the field. When the field is a primary key, the default is
|
||||||
also used when the field is set to ``None``.
|
also used when the field is set to ``None``.
|
||||||
|
|
||||||
|
The default value can also be set at the database level with
|
||||||
|
:attr:`Field.db_default`.
|
||||||
|
|
||||||
``editable``
|
``editable``
|
||||||
------------
|
------------
|
||||||
|
|
||||||
|
|
|
@ -541,7 +541,8 @@ You may have noticed Django database objects use the same ``save()`` method
|
||||||
for creating and changing objects. Django abstracts the need to use ``INSERT``
|
for creating and changing objects. Django abstracts the need to use ``INSERT``
|
||||||
or ``UPDATE`` SQL statements. Specifically, when you call ``save()`` and the
|
or ``UPDATE`` SQL statements. Specifically, when you call ``save()`` and the
|
||||||
object's primary key attribute does **not** define a
|
object's primary key attribute does **not** define a
|
||||||
:attr:`~django.db.models.Field.default`, Django follows this algorithm:
|
:attr:`~django.db.models.Field.default` or
|
||||||
|
:attr:`~django.db.models.Field.db_default`, Django follows this algorithm:
|
||||||
|
|
||||||
* If the object's primary key attribute is set to a value that evaluates to
|
* If the object's primary key attribute is set to a value that evaluates to
|
||||||
``True`` (i.e., a value other than ``None`` or the empty string), Django
|
``True`` (i.e., a value other than ``None`` or the empty string), Django
|
||||||
|
@ -551,9 +552,10 @@ object's primary key attribute does **not** define a
|
||||||
exist in the database), Django executes an ``INSERT``.
|
exist in the database), Django executes an ``INSERT``.
|
||||||
|
|
||||||
If the object's primary key attribute defines a
|
If the object's primary key attribute defines a
|
||||||
:attr:`~django.db.models.Field.default` then Django executes an ``UPDATE`` if
|
:attr:`~django.db.models.Field.default` or
|
||||||
it is an existing model instance and primary key is set to a value that exists
|
:attr:`~django.db.models.Field.db_default` then Django executes an ``UPDATE``
|
||||||
in the database. Otherwise, Django executes an ``INSERT``.
|
if it is an existing model instance and primary key is set to a value that
|
||||||
|
exists in the database. Otherwise, Django executes an ``INSERT``.
|
||||||
|
|
||||||
The one gotcha here is that you should be careful not to specify a primary-key
|
The one gotcha here is that you should be careful not to specify a primary-key
|
||||||
value explicitly when saving new objects, if you cannot guarantee the
|
value explicitly when saving new objects, if you cannot guarantee the
|
||||||
|
@ -570,6 +572,10 @@ which returns ``NULL``. In such cases it is possible to revert to the old
|
||||||
algorithm by setting the :attr:`~django.db.models.Options.select_on_save`
|
algorithm by setting the :attr:`~django.db.models.Options.select_on_save`
|
||||||
option to ``True``.
|
option to ``True``.
|
||||||
|
|
||||||
|
.. versionchanged:: 5.0
|
||||||
|
|
||||||
|
The ``Field.db_default`` parameter was added.
|
||||||
|
|
||||||
.. _ref-models-force-insert:
|
.. _ref-models-force-insert:
|
||||||
|
|
||||||
Forcing an INSERT or UPDATE
|
Forcing an INSERT or UPDATE
|
||||||
|
|
|
@ -108,6 +108,21 @@ Can now be simplified to:
|
||||||
per-project, per-field, or per-request basis. See
|
per-project, per-field, or per-request basis. See
|
||||||
:ref:`reusable-field-group-templates`.
|
:ref:`reusable-field-group-templates`.
|
||||||
|
|
||||||
|
Database-computed default values
|
||||||
|
--------------------------------
|
||||||
|
|
||||||
|
The new :attr:`Field.db_default <django.db.models.Field.db_default>` parameter
|
||||||
|
sets a database-computed default value. For example::
|
||||||
|
|
||||||
|
from django.db import models
|
||||||
|
from django.db.models.functions import Now, Pi
|
||||||
|
|
||||||
|
|
||||||
|
class MyModel(models.Model):
|
||||||
|
age = models.IntegerField(db_default=18)
|
||||||
|
created = models.DateTimeField(db_default=Now())
|
||||||
|
circumference = models.FloatField(db_default=2 * Pi())
|
||||||
|
|
||||||
Minor features
|
Minor features
|
||||||
--------------
|
--------------
|
||||||
|
|
||||||
|
@ -355,7 +370,16 @@ Database backend API
|
||||||
This section describes changes that may be needed in third-party database
|
This section describes changes that may be needed in third-party database
|
||||||
backends.
|
backends.
|
||||||
|
|
||||||
* ...
|
* ``DatabaseFeatures.supports_expression_defaults`` should be set to ``False``
|
||||||
|
if the database doesn't support using database functions as defaults.
|
||||||
|
|
||||||
|
* ``DatabaseFeatures.supports_default_keyword_in_insert`` should be set to
|
||||||
|
``False`` if the database doesn't support the ``DEFAULT`` keyword in
|
||||||
|
``INSERT`` queries.
|
||||||
|
|
||||||
|
* ``DatabaseFeatures.supports_default_keyword_in_bulk insert`` should be set to
|
||||||
|
``False`` if the database doesn't support the ``DEFAULT`` keyword in bulk
|
||||||
|
``INSERT`` queries.
|
||||||
|
|
||||||
Using ``create_defaults__exact`` may now be required with ``QuerySet.update_or_create()``
|
Using ``create_defaults__exact`` may now be required with ``QuerySet.update_or_create()``
|
||||||
-----------------------------------------------------------------------------------------
|
-----------------------------------------------------------------------------------------
|
||||||
|
|
|
@ -49,5 +49,9 @@ class PrimaryKeyWithDefault(models.Model):
|
||||||
uuid = models.UUIDField(primary_key=True, default=uuid.uuid4)
|
uuid = models.UUIDField(primary_key=True, default=uuid.uuid4)
|
||||||
|
|
||||||
|
|
||||||
|
class PrimaryKeyWithDbDefault(models.Model):
|
||||||
|
uuid = models.IntegerField(primary_key=True, db_default=1)
|
||||||
|
|
||||||
|
|
||||||
class ChildPrimaryKeyWithDefault(PrimaryKeyWithDefault):
|
class ChildPrimaryKeyWithDefault(PrimaryKeyWithDefault):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -20,6 +20,7 @@ from .models import (
|
||||||
ArticleSelectOnSave,
|
ArticleSelectOnSave,
|
||||||
ChildPrimaryKeyWithDefault,
|
ChildPrimaryKeyWithDefault,
|
||||||
FeaturedArticle,
|
FeaturedArticle,
|
||||||
|
PrimaryKeyWithDbDefault,
|
||||||
PrimaryKeyWithDefault,
|
PrimaryKeyWithDefault,
|
||||||
SelfRef,
|
SelfRef,
|
||||||
)
|
)
|
||||||
|
@ -175,6 +176,11 @@ class ModelInstanceCreationTests(TestCase):
|
||||||
with self.assertNumQueries(1):
|
with self.assertNumQueries(1):
|
||||||
PrimaryKeyWithDefault().save()
|
PrimaryKeyWithDefault().save()
|
||||||
|
|
||||||
|
def test_save_primary_with_db_default(self):
|
||||||
|
# An UPDATE attempt is skipped when a primary key has db_default.
|
||||||
|
with self.assertNumQueries(1):
|
||||||
|
PrimaryKeyWithDbDefault().save()
|
||||||
|
|
||||||
def test_save_parent_primary_with_default(self):
|
def test_save_parent_primary_with_default(self):
|
||||||
# An UPDATE attempt is skipped when an inherited primary key has
|
# An UPDATE attempt is skipped when an inherited primary key has
|
||||||
# default.
|
# default.
|
||||||
|
|
|
@ -12,6 +12,8 @@ field.
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
from django.db.models.functions import Coalesce, ExtractYear, Now, Pi
|
||||||
|
from django.db.models.lookups import GreaterThan
|
||||||
|
|
||||||
|
|
||||||
class Article(models.Model):
|
class Article(models.Model):
|
||||||
|
@ -20,3 +22,45 @@ class Article(models.Model):
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.headline
|
return self.headline
|
||||||
|
|
||||||
|
|
||||||
|
class DBArticle(models.Model):
|
||||||
|
"""
|
||||||
|
Values or expressions can be passed as the db_default parameter to a field.
|
||||||
|
When the object is created without an explicit value passed in, the
|
||||||
|
database will insert the default value automatically.
|
||||||
|
"""
|
||||||
|
|
||||||
|
headline = models.CharField(max_length=100, db_default="Default headline")
|
||||||
|
pub_date = models.DateTimeField(db_default=Now())
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
required_db_features = {"supports_expression_defaults"}
|
||||||
|
|
||||||
|
|
||||||
|
class DBDefaults(models.Model):
|
||||||
|
both = models.IntegerField(default=1, db_default=2)
|
||||||
|
null = models.FloatField(null=True, db_default=1.1)
|
||||||
|
|
||||||
|
|
||||||
|
class DBDefaultsFunction(models.Model):
|
||||||
|
number = models.FloatField(db_default=Pi())
|
||||||
|
year = models.IntegerField(db_default=ExtractYear(Now()))
|
||||||
|
added = models.FloatField(db_default=Pi() + 4.5)
|
||||||
|
multiple_subfunctions = models.FloatField(db_default=Coalesce(4.5, Pi()))
|
||||||
|
case_when = models.IntegerField(
|
||||||
|
db_default=models.Case(models.When(GreaterThan(2, 1), then=3), default=4)
|
||||||
|
)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
required_db_features = {"supports_expression_defaults"}
|
||||||
|
|
||||||
|
|
||||||
|
class DBDefaultsPK(models.Model):
|
||||||
|
language_code = models.CharField(primary_key=True, max_length=2, db_default="en")
|
||||||
|
|
||||||
|
|
||||||
|
class DBDefaultsFK(models.Model):
|
||||||
|
language_code = models.ForeignKey(
|
||||||
|
DBDefaultsPK, db_default="fr", on_delete=models.CASCADE
|
||||||
|
)
|
||||||
|
|
|
@ -1,8 +1,28 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from math import pi
|
||||||
|
|
||||||
from django.test import TestCase
|
from django.db import connection
|
||||||
|
from django.db.models import Case, F, FloatField, Value, When
|
||||||
|
from django.db.models.expressions import (
|
||||||
|
Expression,
|
||||||
|
ExpressionList,
|
||||||
|
ExpressionWrapper,
|
||||||
|
Func,
|
||||||
|
OrderByList,
|
||||||
|
RawSQL,
|
||||||
|
)
|
||||||
|
from django.db.models.functions import Collate
|
||||||
|
from django.db.models.lookups import GreaterThan
|
||||||
|
from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
|
||||||
|
|
||||||
from .models import Article
|
from .models import (
|
||||||
|
Article,
|
||||||
|
DBArticle,
|
||||||
|
DBDefaults,
|
||||||
|
DBDefaultsFK,
|
||||||
|
DBDefaultsFunction,
|
||||||
|
DBDefaultsPK,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DefaultTests(TestCase):
|
class DefaultTests(TestCase):
|
||||||
|
@ -14,3 +34,171 @@ class DefaultTests(TestCase):
|
||||||
self.assertIsInstance(a.id, int)
|
self.assertIsInstance(a.id, int)
|
||||||
self.assertEqual(a.headline, "Default headline")
|
self.assertEqual(a.headline, "Default headline")
|
||||||
self.assertLess((now - a.pub_date).seconds, 5)
|
self.assertLess((now - a.pub_date).seconds, 5)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature(
|
||||||
|
"can_return_columns_from_insert", "supports_expression_defaults"
|
||||||
|
)
|
||||||
|
def test_field_db_defaults_returning(self):
|
||||||
|
a = DBArticle()
|
||||||
|
a.save()
|
||||||
|
self.assertIsInstance(a.id, int)
|
||||||
|
self.assertEqual(a.headline, "Default headline")
|
||||||
|
self.assertIsInstance(a.pub_date, datetime)
|
||||||
|
|
||||||
|
@skipIfDBFeature("can_return_columns_from_insert")
|
||||||
|
@skipUnlessDBFeature("supports_expression_defaults")
|
||||||
|
def test_field_db_defaults_refresh(self):
|
||||||
|
a = DBArticle()
|
||||||
|
a.save()
|
||||||
|
a.refresh_from_db()
|
||||||
|
self.assertIsInstance(a.id, int)
|
||||||
|
self.assertEqual(a.headline, "Default headline")
|
||||||
|
self.assertIsInstance(a.pub_date, datetime)
|
||||||
|
|
||||||
|
def test_null_db_default(self):
|
||||||
|
obj1 = DBDefaults.objects.create()
|
||||||
|
if not connection.features.can_return_columns_from_insert:
|
||||||
|
obj1.refresh_from_db()
|
||||||
|
self.assertEqual(obj1.null, 1.1)
|
||||||
|
|
||||||
|
obj2 = DBDefaults.objects.create(null=None)
|
||||||
|
self.assertIsNone(obj2.null)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_expression_defaults")
|
||||||
|
def test_db_default_function(self):
|
||||||
|
m = DBDefaultsFunction.objects.create()
|
||||||
|
if not connection.features.can_return_columns_from_insert:
|
||||||
|
m.refresh_from_db()
|
||||||
|
self.assertAlmostEqual(m.number, pi)
|
||||||
|
self.assertEqual(m.year, datetime.now().year)
|
||||||
|
self.assertAlmostEqual(m.added, pi + 4.5)
|
||||||
|
self.assertEqual(m.multiple_subfunctions, 4.5)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("insert_test_table_with_defaults")
|
||||||
|
def test_both_default(self):
|
||||||
|
create_sql = connection.features.insert_test_table_with_defaults
|
||||||
|
with connection.cursor() as cursor:
|
||||||
|
cursor.execute(create_sql.format(DBDefaults._meta.db_table))
|
||||||
|
obj1 = DBDefaults.objects.get()
|
||||||
|
self.assertEqual(obj1.both, 2)
|
||||||
|
|
||||||
|
obj2 = DBDefaults.objects.create()
|
||||||
|
self.assertEqual(obj2.both, 1)
|
||||||
|
|
||||||
|
def test_pk_db_default(self):
|
||||||
|
obj1 = DBDefaultsPK.objects.create()
|
||||||
|
if not connection.features.can_return_columns_from_insert:
|
||||||
|
# refresh_from_db() cannot be used because that needs the pk to
|
||||||
|
# already be known to Django.
|
||||||
|
obj1 = DBDefaultsPK.objects.get(pk="en")
|
||||||
|
self.assertEqual(obj1.pk, "en")
|
||||||
|
self.assertEqual(obj1.language_code, "en")
|
||||||
|
|
||||||
|
obj2 = DBDefaultsPK.objects.create(language_code="de")
|
||||||
|
self.assertEqual(obj2.pk, "de")
|
||||||
|
self.assertEqual(obj2.language_code, "de")
|
||||||
|
|
||||||
|
def test_foreign_key_db_default(self):
|
||||||
|
parent1 = DBDefaultsPK.objects.create(language_code="fr")
|
||||||
|
child1 = DBDefaultsFK.objects.create()
|
||||||
|
if not connection.features.can_return_columns_from_insert:
|
||||||
|
child1.refresh_from_db()
|
||||||
|
self.assertEqual(child1.language_code, parent1)
|
||||||
|
|
||||||
|
parent2 = DBDefaultsPK.objects.create()
|
||||||
|
if not connection.features.can_return_columns_from_insert:
|
||||||
|
# refresh_from_db() cannot be used because that needs the pk to
|
||||||
|
# already be known to Django.
|
||||||
|
parent2 = DBDefaultsPK.objects.get(pk="en")
|
||||||
|
child2 = DBDefaultsFK.objects.create(language_code=parent2)
|
||||||
|
self.assertEqual(child2.language_code, parent2)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature(
|
||||||
|
"can_return_columns_from_insert", "supports_expression_defaults"
|
||||||
|
)
|
||||||
|
def test_case_when_db_default_returning(self):
|
||||||
|
m = DBDefaultsFunction.objects.create()
|
||||||
|
self.assertEqual(m.case_when, 3)
|
||||||
|
|
||||||
|
@skipIfDBFeature("can_return_columns_from_insert")
|
||||||
|
@skipUnlessDBFeature("supports_expression_defaults")
|
||||||
|
def test_case_when_db_default_no_returning(self):
|
||||||
|
m = DBDefaultsFunction.objects.create()
|
||||||
|
m.refresh_from_db()
|
||||||
|
self.assertEqual(m.case_when, 3)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_expression_defaults")
|
||||||
|
def test_bulk_create_all_db_defaults(self):
|
||||||
|
articles = [DBArticle(), DBArticle()]
|
||||||
|
DBArticle.objects.bulk_create(articles)
|
||||||
|
|
||||||
|
headlines = DBArticle.objects.values_list("headline", flat=True)
|
||||||
|
self.assertSequenceEqual(headlines, ["Default headline", "Default headline"])
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_expression_defaults")
|
||||||
|
def test_bulk_create_all_db_defaults_one_field(self):
|
||||||
|
pub_date = datetime.now()
|
||||||
|
articles = [DBArticle(pub_date=pub_date), DBArticle(pub_date=pub_date)]
|
||||||
|
DBArticle.objects.bulk_create(articles)
|
||||||
|
|
||||||
|
headlines = DBArticle.objects.values_list("headline", "pub_date")
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
headlines,
|
||||||
|
[
|
||||||
|
("Default headline", pub_date),
|
||||||
|
("Default headline", pub_date),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_expression_defaults")
|
||||||
|
def test_bulk_create_mixed_db_defaults(self):
|
||||||
|
articles = [DBArticle(), DBArticle(headline="Something else")]
|
||||||
|
DBArticle.objects.bulk_create(articles)
|
||||||
|
|
||||||
|
headlines = DBArticle.objects.values_list("headline", flat=True)
|
||||||
|
self.assertCountEqual(headlines, ["Default headline", "Something else"])
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_expression_defaults")
|
||||||
|
def test_bulk_create_mixed_db_defaults_function(self):
|
||||||
|
instances = [DBDefaultsFunction(), DBDefaultsFunction(year=2000)]
|
||||||
|
DBDefaultsFunction.objects.bulk_create(instances)
|
||||||
|
|
||||||
|
years = DBDefaultsFunction.objects.values_list("year", flat=True)
|
||||||
|
self.assertCountEqual(years, [2000, datetime.now().year])
|
||||||
|
|
||||||
|
|
||||||
|
class AllowedDefaultTests(SimpleTestCase):
|
||||||
|
def test_allowed(self):
|
||||||
|
class Max(Func):
|
||||||
|
function = "MAX"
|
||||||
|
|
||||||
|
tests = [
|
||||||
|
Value(10),
|
||||||
|
Max(1, 2),
|
||||||
|
RawSQL("Now()", ()),
|
||||||
|
Value(10) + Value(7), # Combined expression.
|
||||||
|
ExpressionList(Value(1), Value(2)),
|
||||||
|
ExpressionWrapper(Value(1), output_field=FloatField()),
|
||||||
|
Case(When(GreaterThan(2, 1), then=3), default=4),
|
||||||
|
]
|
||||||
|
for expression in tests:
|
||||||
|
with self.subTest(expression=expression):
|
||||||
|
self.assertIs(expression.allowed_default, True)
|
||||||
|
|
||||||
|
def test_disallowed(self):
|
||||||
|
class Max(Func):
|
||||||
|
function = "MAX"
|
||||||
|
|
||||||
|
tests = [
|
||||||
|
Expression(),
|
||||||
|
F("field"),
|
||||||
|
Max(F("count"), 1),
|
||||||
|
Value(10) + F("count"), # Combined expression.
|
||||||
|
ExpressionList(F("count"), Value(2)),
|
||||||
|
ExpressionWrapper(F("count"), output_field=FloatField()),
|
||||||
|
Collate(Value("John"), "nocase"),
|
||||||
|
OrderByList("field"),
|
||||||
|
]
|
||||||
|
for expression in tests:
|
||||||
|
with self.subTest(expression=expression):
|
||||||
|
self.assertIs(expression.allowed_default, False)
|
||||||
|
|
|
@ -4,6 +4,7 @@ import uuid
|
||||||
from django.core.checks import Error
|
from django.core.checks import Error
|
||||||
from django.core.checks import Warning as DjangoWarning
|
from django.core.checks import Warning as DjangoWarning
|
||||||
from django.db import connection, models
|
from django.db import connection, models
|
||||||
|
from django.db.models.functions import Coalesce, Pi
|
||||||
from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
|
from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
|
||||||
from django.test.utils import isolate_apps, override_settings
|
from django.test.utils import isolate_apps, override_settings
|
||||||
from django.utils.functional import lazy
|
from django.utils.functional import lazy
|
||||||
|
@ -1057,3 +1058,109 @@ class DbCommentTests(TestCase):
|
||||||
|
|
||||||
errors = Model._meta.get_field("field").check(databases=self.databases)
|
errors = Model._meta.get_field("field").check(databases=self.databases)
|
||||||
self.assertEqual(errors, [])
|
self.assertEqual(errors, [])
|
||||||
|
|
||||||
|
|
||||||
|
@isolate_apps("invalid_models_tests")
|
||||||
|
class InvalidDBDefaultTests(TestCase):
|
||||||
|
def test_db_default(self):
|
||||||
|
class Model(models.Model):
|
||||||
|
field = models.FloatField(db_default=Pi())
|
||||||
|
|
||||||
|
field = Model._meta.get_field("field")
|
||||||
|
errors = field.check(databases=self.databases)
|
||||||
|
|
||||||
|
if connection.features.supports_expression_defaults:
|
||||||
|
expected_errors = []
|
||||||
|
else:
|
||||||
|
msg = (
|
||||||
|
f"{connection.display_name} does not support default database values "
|
||||||
|
"with expressions (db_default)."
|
||||||
|
)
|
||||||
|
expected_errors = [Error(msg=msg, obj=field, id="fields.E011")]
|
||||||
|
self.assertEqual(errors, expected_errors)
|
||||||
|
|
||||||
|
def test_db_default_literal(self):
|
||||||
|
class Model(models.Model):
|
||||||
|
field = models.IntegerField(db_default=1)
|
||||||
|
|
||||||
|
field = Model._meta.get_field("field")
|
||||||
|
errors = field.check(databases=self.databases)
|
||||||
|
self.assertEqual(errors, [])
|
||||||
|
|
||||||
|
def test_db_default_required_db_features(self):
|
||||||
|
class Model(models.Model):
|
||||||
|
field = models.FloatField(db_default=Pi())
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
required_db_features = {"supports_expression_defaults"}
|
||||||
|
|
||||||
|
field = Model._meta.get_field("field")
|
||||||
|
errors = field.check(databases=self.databases)
|
||||||
|
self.assertEqual(errors, [])
|
||||||
|
|
||||||
|
def test_db_default_expression_invalid(self):
|
||||||
|
expression = models.F("field_name")
|
||||||
|
|
||||||
|
class Model(models.Model):
|
||||||
|
field = models.FloatField(db_default=expression)
|
||||||
|
|
||||||
|
field = Model._meta.get_field("field")
|
||||||
|
errors = field.check(databases=self.databases)
|
||||||
|
|
||||||
|
if connection.features.supports_expression_defaults:
|
||||||
|
msg = f"{expression} cannot be used in db_default."
|
||||||
|
expected_errors = [Error(msg=msg, obj=field, id="fields.E012")]
|
||||||
|
else:
|
||||||
|
msg = (
|
||||||
|
f"{connection.display_name} does not support default database values "
|
||||||
|
"with expressions (db_default)."
|
||||||
|
)
|
||||||
|
expected_errors = [Error(msg=msg, obj=field, id="fields.E011")]
|
||||||
|
self.assertEqual(errors, expected_errors)
|
||||||
|
|
||||||
|
def test_db_default_expression_required_db_features(self):
|
||||||
|
expression = models.F("field_name")
|
||||||
|
|
||||||
|
class Model(models.Model):
|
||||||
|
field = models.FloatField(db_default=expression)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
required_db_features = {"supports_expression_defaults"}
|
||||||
|
|
||||||
|
field = Model._meta.get_field("field")
|
||||||
|
errors = field.check(databases=self.databases)
|
||||||
|
|
||||||
|
if connection.features.supports_expression_defaults:
|
||||||
|
msg = f"{expression} cannot be used in db_default."
|
||||||
|
expected_errors = [Error(msg=msg, obj=field, id="fields.E012")]
|
||||||
|
else:
|
||||||
|
expected_errors = []
|
||||||
|
self.assertEqual(errors, expected_errors)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_expression_defaults")
|
||||||
|
def test_db_default_combined_invalid(self):
|
||||||
|
expression = models.Value(4.5) + models.F("field_name")
|
||||||
|
|
||||||
|
class Model(models.Model):
|
||||||
|
field = models.FloatField(db_default=expression)
|
||||||
|
|
||||||
|
field = Model._meta.get_field("field")
|
||||||
|
errors = field.check(databases=self.databases)
|
||||||
|
|
||||||
|
msg = f"{expression} cannot be used in db_default."
|
||||||
|
expected_error = Error(msg=msg, obj=field, id="fields.E012")
|
||||||
|
self.assertEqual(errors, [expected_error])
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_expression_defaults")
|
||||||
|
def test_db_default_function_arguments_invalid(self):
|
||||||
|
expression = Coalesce(models.Value(4.5), models.F("field_name"))
|
||||||
|
|
||||||
|
class Model(models.Model):
|
||||||
|
field = models.FloatField(db_default=expression)
|
||||||
|
|
||||||
|
field = Model._meta.get_field("field")
|
||||||
|
errors = field.check(databases=self.databases)
|
||||||
|
|
||||||
|
msg = f"{expression} cannot be used in db_default."
|
||||||
|
expected_error = Error(msg=msg, obj=field, id="fields.E012")
|
||||||
|
self.assertEqual(errors, [expected_error])
|
||||||
|
|
|
@ -269,6 +269,14 @@ class AutodetectorTests(BaseAutodetectorTests):
|
||||||
("name", models.CharField(max_length=200, default="Ada Lovelace")),
|
("name", models.CharField(max_length=200, default="Ada Lovelace")),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
author_name_db_default = ModelState(
|
||||||
|
"testapp",
|
||||||
|
"Author",
|
||||||
|
[
|
||||||
|
("id", models.AutoField(primary_key=True)),
|
||||||
|
("name", models.CharField(max_length=200, db_default="Ada Lovelace")),
|
||||||
|
],
|
||||||
|
)
|
||||||
author_name_check_constraint = ModelState(
|
author_name_check_constraint = ModelState(
|
||||||
"testapp",
|
"testapp",
|
||||||
"Author",
|
"Author",
|
||||||
|
@ -1289,6 +1297,21 @@ class AutodetectorTests(BaseAutodetectorTests):
|
||||||
self.assertOperationTypes(changes, "testapp", 0, ["AddField"])
|
self.assertOperationTypes(changes, "testapp", 0, ["AddField"])
|
||||||
self.assertOperationAttributes(changes, "testapp", 0, 0, name="name")
|
self.assertOperationAttributes(changes, "testapp", 0, 0, name="name")
|
||||||
|
|
||||||
|
@mock.patch(
|
||||||
|
"django.db.migrations.questioner.MigrationQuestioner.ask_not_null_addition",
|
||||||
|
side_effect=AssertionError("Should not have prompted for not null addition"),
|
||||||
|
)
|
||||||
|
def test_add_not_null_field_with_db_default(self, mocked_ask_method):
|
||||||
|
changes = self.get_changes([self.author_empty], [self.author_name_db_default])
|
||||||
|
self.assertNumberMigrations(changes, "testapp", 1)
|
||||||
|
self.assertOperationTypes(changes, "testapp", 0, ["AddField"])
|
||||||
|
self.assertOperationAttributes(
|
||||||
|
changes, "testapp", 0, 0, name="name", preserve_default=True
|
||||||
|
)
|
||||||
|
self.assertOperationFieldAttributes(
|
||||||
|
changes, "testapp", 0, 0, db_default=models.Value("Ada Lovelace")
|
||||||
|
)
|
||||||
|
|
||||||
@mock.patch(
|
@mock.patch(
|
||||||
"django.db.migrations.questioner.MigrationQuestioner.ask_not_null_addition",
|
"django.db.migrations.questioner.MigrationQuestioner.ask_not_null_addition",
|
||||||
side_effect=AssertionError("Should not have prompted for not null addition"),
|
side_effect=AssertionError("Should not have prompted for not null addition"),
|
||||||
|
@ -1478,6 +1501,23 @@ class AutodetectorTests(BaseAutodetectorTests):
|
||||||
changes, "testapp", 0, 0, default="Ada Lovelace"
|
changes, "testapp", 0, 0, default="Ada Lovelace"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@mock.patch(
|
||||||
|
"django.db.migrations.questioner.MigrationQuestioner.ask_not_null_alteration",
|
||||||
|
side_effect=AssertionError("Should not have prompted for not null alteration"),
|
||||||
|
)
|
||||||
|
def test_alter_field_to_not_null_with_db_default(self, mocked_ask_method):
|
||||||
|
changes = self.get_changes(
|
||||||
|
[self.author_name_null], [self.author_name_db_default]
|
||||||
|
)
|
||||||
|
self.assertNumberMigrations(changes, "testapp", 1)
|
||||||
|
self.assertOperationTypes(changes, "testapp", 0, ["AlterField"])
|
||||||
|
self.assertOperationAttributes(
|
||||||
|
changes, "testapp", 0, 0, name="name", preserve_default=True
|
||||||
|
)
|
||||||
|
self.assertOperationFieldAttributes(
|
||||||
|
changes, "testapp", 0, 0, db_default=models.Value("Ada Lovelace")
|
||||||
|
)
|
||||||
|
|
||||||
@mock.patch(
|
@mock.patch(
|
||||||
"django.db.migrations.questioner.MigrationQuestioner.ask_not_null_alteration",
|
"django.db.migrations.questioner.MigrationQuestioner.ask_not_null_alteration",
|
||||||
return_value=models.NOT_PROVIDED,
|
return_value=models.NOT_PROVIDED,
|
||||||
|
|
|
@ -292,6 +292,13 @@ class OperationTestBase(MigrationTestBase):
|
||||||
("id", models.AutoField(primary_key=True)),
|
("id", models.AutoField(primary_key=True)),
|
||||||
("pink", models.IntegerField(default=3)),
|
("pink", models.IntegerField(default=3)),
|
||||||
("weight", models.FloatField()),
|
("weight", models.FloatField()),
|
||||||
|
("green", models.IntegerField(null=True)),
|
||||||
|
(
|
||||||
|
"yellow",
|
||||||
|
models.CharField(
|
||||||
|
blank=True, null=True, db_default="Yellow", max_length=20
|
||||||
|
),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
options=model_options,
|
options=model_options,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,14 +1,18 @@
|
||||||
|
import math
|
||||||
|
|
||||||
from django.core.exceptions import FieldDoesNotExist
|
from django.core.exceptions import FieldDoesNotExist
|
||||||
from django.db import IntegrityError, connection, migrations, models, transaction
|
from django.db import IntegrityError, connection, migrations, models, transaction
|
||||||
from django.db.migrations.migration import Migration
|
from django.db.migrations.migration import Migration
|
||||||
from django.db.migrations.operations.fields import FieldOperation
|
from django.db.migrations.operations.fields import FieldOperation
|
||||||
from django.db.migrations.state import ModelState, ProjectState
|
from django.db.migrations.state import ModelState, ProjectState
|
||||||
from django.db.models.functions import Abs
|
from django.db.models.expressions import Value
|
||||||
|
from django.db.models.functions import Abs, Pi
|
||||||
from django.db.transaction import atomic
|
from django.db.transaction import atomic
|
||||||
from django.test import (
|
from django.test import (
|
||||||
SimpleTestCase,
|
SimpleTestCase,
|
||||||
ignore_warnings,
|
ignore_warnings,
|
||||||
override_settings,
|
override_settings,
|
||||||
|
skipIfDBFeature,
|
||||||
skipUnlessDBFeature,
|
skipUnlessDBFeature,
|
||||||
)
|
)
|
||||||
from django.test.utils import CaptureQueriesContext
|
from django.test.utils import CaptureQueriesContext
|
||||||
|
@ -1340,7 +1344,7 @@ class OperationTests(OperationTestBase):
|
||||||
self.assertEqual(operation.describe(), "Add field height to Pony")
|
self.assertEqual(operation.describe(), "Add field height to Pony")
|
||||||
self.assertEqual(operation.migration_name_fragment, "pony_height")
|
self.assertEqual(operation.migration_name_fragment, "pony_height")
|
||||||
project_state, new_state = self.make_test_state("test_adfl", operation)
|
project_state, new_state = self.make_test_state("test_adfl", operation)
|
||||||
self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 4)
|
self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 6)
|
||||||
field = new_state.models["test_adfl", "pony"].fields["height"]
|
field = new_state.models["test_adfl", "pony"].fields["height"]
|
||||||
self.assertEqual(field.default, 5)
|
self.assertEqual(field.default, 5)
|
||||||
# Test the database alteration
|
# Test the database alteration
|
||||||
|
@ -1528,7 +1532,7 @@ class OperationTests(OperationTestBase):
|
||||||
)
|
)
|
||||||
new_state = project_state.clone()
|
new_state = project_state.clone()
|
||||||
operation.state_forwards("test_adflpd", new_state)
|
operation.state_forwards("test_adflpd", new_state)
|
||||||
self.assertEqual(len(new_state.models["test_adflpd", "pony"].fields), 4)
|
self.assertEqual(len(new_state.models["test_adflpd", "pony"].fields), 6)
|
||||||
field = new_state.models["test_adflpd", "pony"].fields["height"]
|
field = new_state.models["test_adflpd", "pony"].fields["height"]
|
||||||
self.assertEqual(field.default, models.NOT_PROVIDED)
|
self.assertEqual(field.default, models.NOT_PROVIDED)
|
||||||
# Test the database alteration
|
# Test the database alteration
|
||||||
|
@ -1547,6 +1551,169 @@ class OperationTests(OperationTestBase):
|
||||||
sorted(definition[2]), ["field", "model_name", "name", "preserve_default"]
|
sorted(definition[2]), ["field", "model_name", "name", "preserve_default"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_add_field_database_default(self):
|
||||||
|
"""The AddField operation can set and unset a database default."""
|
||||||
|
app_label = "test_adfldd"
|
||||||
|
table_name = f"{app_label}_pony"
|
||||||
|
project_state = self.set_up_test_model(app_label)
|
||||||
|
operation = migrations.AddField(
|
||||||
|
"Pony", "height", models.FloatField(null=True, db_default=4)
|
||||||
|
)
|
||||||
|
new_state = project_state.clone()
|
||||||
|
operation.state_forwards(app_label, new_state)
|
||||||
|
self.assertEqual(len(new_state.models[app_label, "pony"].fields), 6)
|
||||||
|
field = new_state.models[app_label, "pony"].fields["height"]
|
||||||
|
self.assertEqual(field.default, models.NOT_PROVIDED)
|
||||||
|
self.assertEqual(field.db_default, Value(4))
|
||||||
|
project_state.apps.get_model(app_label, "pony").objects.create(weight=4)
|
||||||
|
self.assertColumnNotExists(table_name, "height")
|
||||||
|
# Add field.
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_forwards(app_label, editor, project_state, new_state)
|
||||||
|
self.assertColumnExists(table_name, "height")
|
||||||
|
new_model = new_state.apps.get_model(app_label, "pony")
|
||||||
|
old_pony = new_model.objects.get()
|
||||||
|
self.assertEqual(old_pony.height, 4)
|
||||||
|
new_pony = new_model.objects.create(weight=5)
|
||||||
|
if not connection.features.can_return_columns_from_insert:
|
||||||
|
new_pony.refresh_from_db()
|
||||||
|
self.assertEqual(new_pony.height, 4)
|
||||||
|
# Reversal.
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_backwards(app_label, editor, new_state, project_state)
|
||||||
|
self.assertColumnNotExists(table_name, "height")
|
||||||
|
# Deconstruction.
|
||||||
|
definition = operation.deconstruct()
|
||||||
|
self.assertEqual(definition[0], "AddField")
|
||||||
|
self.assertEqual(definition[1], [])
|
||||||
|
self.assertEqual(
|
||||||
|
definition[2],
|
||||||
|
{
|
||||||
|
"field": field,
|
||||||
|
"model_name": "Pony",
|
||||||
|
"name": "height",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_add_field_database_default_special_char_escaping(self):
|
||||||
|
app_label = "test_adflddsce"
|
||||||
|
table_name = f"{app_label}_pony"
|
||||||
|
project_state = self.set_up_test_model(app_label)
|
||||||
|
old_pony_pk = (
|
||||||
|
project_state.apps.get_model(app_label, "pony").objects.create(weight=4).pk
|
||||||
|
)
|
||||||
|
tests = ["%", "'", '"']
|
||||||
|
for db_default in tests:
|
||||||
|
with self.subTest(db_default=db_default):
|
||||||
|
operation = migrations.AddField(
|
||||||
|
"Pony",
|
||||||
|
"special_char",
|
||||||
|
models.CharField(max_length=1, db_default=db_default),
|
||||||
|
)
|
||||||
|
new_state = project_state.clone()
|
||||||
|
operation.state_forwards(app_label, new_state)
|
||||||
|
self.assertEqual(len(new_state.models[app_label, "pony"].fields), 6)
|
||||||
|
field = new_state.models[app_label, "pony"].fields["special_char"]
|
||||||
|
self.assertEqual(field.default, models.NOT_PROVIDED)
|
||||||
|
self.assertEqual(field.db_default, Value(db_default))
|
||||||
|
self.assertColumnNotExists(table_name, "special_char")
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_forwards(
|
||||||
|
app_label, editor, project_state, new_state
|
||||||
|
)
|
||||||
|
self.assertColumnExists(table_name, "special_char")
|
||||||
|
new_model = new_state.apps.get_model(app_label, "pony")
|
||||||
|
try:
|
||||||
|
new_pony = new_model.objects.create(weight=5)
|
||||||
|
if not connection.features.can_return_columns_from_insert:
|
||||||
|
new_pony.refresh_from_db()
|
||||||
|
self.assertEqual(new_pony.special_char, db_default)
|
||||||
|
|
||||||
|
old_pony = new_model.objects.get(pk=old_pony_pk)
|
||||||
|
if connection.vendor != "oracle" or db_default != "'":
|
||||||
|
# The single quotation mark ' is properly quoted and is
|
||||||
|
# set for new rows on Oracle, however it is not set on
|
||||||
|
# existing rows. Skip the assertion as it's probably a
|
||||||
|
# bug in Oracle.
|
||||||
|
self.assertEqual(old_pony.special_char, db_default)
|
||||||
|
finally:
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_backwards(
|
||||||
|
app_label, editor, new_state, project_state
|
||||||
|
)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_expression_defaults")
|
||||||
|
def test_add_field_database_default_function(self):
|
||||||
|
app_label = "test_adflddf"
|
||||||
|
table_name = f"{app_label}_pony"
|
||||||
|
project_state = self.set_up_test_model(app_label)
|
||||||
|
operation = migrations.AddField(
|
||||||
|
"Pony", "height", models.FloatField(db_default=Pi())
|
||||||
|
)
|
||||||
|
new_state = project_state.clone()
|
||||||
|
operation.state_forwards(app_label, new_state)
|
||||||
|
self.assertEqual(len(new_state.models[app_label, "pony"].fields), 6)
|
||||||
|
field = new_state.models[app_label, "pony"].fields["height"]
|
||||||
|
self.assertEqual(field.default, models.NOT_PROVIDED)
|
||||||
|
self.assertEqual(field.db_default, Pi())
|
||||||
|
project_state.apps.get_model(app_label, "pony").objects.create(weight=4)
|
||||||
|
self.assertColumnNotExists(table_name, "height")
|
||||||
|
# Add field.
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_forwards(app_label, editor, project_state, new_state)
|
||||||
|
self.assertColumnExists(table_name, "height")
|
||||||
|
new_model = new_state.apps.get_model(app_label, "pony")
|
||||||
|
old_pony = new_model.objects.get()
|
||||||
|
self.assertAlmostEqual(old_pony.height, math.pi)
|
||||||
|
new_pony = new_model.objects.create(weight=5)
|
||||||
|
if not connection.features.can_return_columns_from_insert:
|
||||||
|
new_pony.refresh_from_db()
|
||||||
|
self.assertAlmostEqual(old_pony.height, math.pi)
|
||||||
|
|
||||||
|
def test_add_field_both_defaults(self):
|
||||||
|
"""The AddField operation with both default and db_default."""
|
||||||
|
app_label = "test_adflbddd"
|
||||||
|
table_name = f"{app_label}_pony"
|
||||||
|
project_state = self.set_up_test_model(app_label)
|
||||||
|
operation = migrations.AddField(
|
||||||
|
"Pony", "height", models.FloatField(default=3, db_default=4)
|
||||||
|
)
|
||||||
|
new_state = project_state.clone()
|
||||||
|
operation.state_forwards(app_label, new_state)
|
||||||
|
self.assertEqual(len(new_state.models[app_label, "pony"].fields), 6)
|
||||||
|
field = new_state.models[app_label, "pony"].fields["height"]
|
||||||
|
self.assertEqual(field.default, 3)
|
||||||
|
self.assertEqual(field.db_default, Value(4))
|
||||||
|
project_state.apps.get_model(app_label, "pony").objects.create(weight=4)
|
||||||
|
self.assertColumnNotExists(table_name, "height")
|
||||||
|
# Add field.
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_forwards(app_label, editor, project_state, new_state)
|
||||||
|
self.assertColumnExists(table_name, "height")
|
||||||
|
new_model = new_state.apps.get_model(app_label, "pony")
|
||||||
|
old_pony = new_model.objects.get()
|
||||||
|
self.assertEqual(old_pony.height, 4)
|
||||||
|
new_pony = new_model.objects.create(weight=5)
|
||||||
|
if not connection.features.can_return_columns_from_insert:
|
||||||
|
new_pony.refresh_from_db()
|
||||||
|
self.assertEqual(new_pony.height, 3)
|
||||||
|
# Reversal.
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_backwards(app_label, editor, new_state, project_state)
|
||||||
|
self.assertColumnNotExists(table_name, "height")
|
||||||
|
# Deconstruction.
|
||||||
|
definition = operation.deconstruct()
|
||||||
|
self.assertEqual(definition[0], "AddField")
|
||||||
|
self.assertEqual(definition[1], [])
|
||||||
|
self.assertEqual(
|
||||||
|
definition[2],
|
||||||
|
{
|
||||||
|
"field": field,
|
||||||
|
"model_name": "Pony",
|
||||||
|
"name": "height",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_add_field_m2m(self):
|
def test_add_field_m2m(self):
|
||||||
"""
|
"""
|
||||||
Tests the AddField operation with a ManyToManyField.
|
Tests the AddField operation with a ManyToManyField.
|
||||||
|
@ -1558,7 +1725,7 @@ class OperationTests(OperationTestBase):
|
||||||
)
|
)
|
||||||
new_state = project_state.clone()
|
new_state = project_state.clone()
|
||||||
operation.state_forwards("test_adflmm", new_state)
|
operation.state_forwards("test_adflmm", new_state)
|
||||||
self.assertEqual(len(new_state.models["test_adflmm", "pony"].fields), 4)
|
self.assertEqual(len(new_state.models["test_adflmm", "pony"].fields), 6)
|
||||||
# Test the database alteration
|
# Test the database alteration
|
||||||
self.assertTableNotExists("test_adflmm_pony_stables")
|
self.assertTableNotExists("test_adflmm_pony_stables")
|
||||||
with connection.schema_editor() as editor:
|
with connection.schema_editor() as editor:
|
||||||
|
@ -1727,7 +1894,7 @@ class OperationTests(OperationTestBase):
|
||||||
self.assertEqual(operation.migration_name_fragment, "remove_pony_pink")
|
self.assertEqual(operation.migration_name_fragment, "remove_pony_pink")
|
||||||
new_state = project_state.clone()
|
new_state = project_state.clone()
|
||||||
operation.state_forwards("test_rmfl", new_state)
|
operation.state_forwards("test_rmfl", new_state)
|
||||||
self.assertEqual(len(new_state.models["test_rmfl", "pony"].fields), 2)
|
self.assertEqual(len(new_state.models["test_rmfl", "pony"].fields), 4)
|
||||||
# Test the database alteration
|
# Test the database alteration
|
||||||
self.assertColumnExists("test_rmfl_pony", "pink")
|
self.assertColumnExists("test_rmfl_pony", "pink")
|
||||||
with connection.schema_editor() as editor:
|
with connection.schema_editor() as editor:
|
||||||
|
@ -1934,6 +2101,146 @@ class OperationTests(OperationTestBase):
|
||||||
self.assertEqual(definition[1], [])
|
self.assertEqual(definition[1], [])
|
||||||
self.assertEqual(sorted(definition[2]), ["field", "model_name", "name"])
|
self.assertEqual(sorted(definition[2]), ["field", "model_name", "name"])
|
||||||
|
|
||||||
|
def test_alter_field_add_database_default(self):
|
||||||
|
app_label = "test_alfladd"
|
||||||
|
project_state = self.set_up_test_model(app_label)
|
||||||
|
operation = migrations.AlterField(
|
||||||
|
"Pony", "weight", models.FloatField(db_default=4.5)
|
||||||
|
)
|
||||||
|
new_state = project_state.clone()
|
||||||
|
operation.state_forwards(app_label, new_state)
|
||||||
|
old_weight = project_state.models[app_label, "pony"].fields["weight"]
|
||||||
|
self.assertIs(old_weight.db_default, models.NOT_PROVIDED)
|
||||||
|
new_weight = new_state.models[app_label, "pony"].fields["weight"]
|
||||||
|
self.assertEqual(new_weight.db_default, Value(4.5))
|
||||||
|
with self.assertRaises(IntegrityError), transaction.atomic():
|
||||||
|
project_state.apps.get_model(app_label, "pony").objects.create()
|
||||||
|
# Alter field.
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_forwards(app_label, editor, project_state, new_state)
|
||||||
|
pony = new_state.apps.get_model(app_label, "pony").objects.create()
|
||||||
|
if not connection.features.can_return_columns_from_insert:
|
||||||
|
pony.refresh_from_db()
|
||||||
|
self.assertEqual(pony.weight, 4.5)
|
||||||
|
# Reversal.
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_backwards(app_label, editor, new_state, project_state)
|
||||||
|
with self.assertRaises(IntegrityError), transaction.atomic():
|
||||||
|
project_state.apps.get_model(app_label, "pony").objects.create()
|
||||||
|
# Deconstruction.
|
||||||
|
definition = operation.deconstruct()
|
||||||
|
self.assertEqual(definition[0], "AlterField")
|
||||||
|
self.assertEqual(definition[1], [])
|
||||||
|
self.assertEqual(
|
||||||
|
definition[2],
|
||||||
|
{
|
||||||
|
"field": new_weight,
|
||||||
|
"model_name": "Pony",
|
||||||
|
"name": "weight",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_alter_field_change_default_to_database_default(self):
|
||||||
|
"""The AlterField operation changing default to db_default."""
|
||||||
|
app_label = "test_alflcdtdd"
|
||||||
|
project_state = self.set_up_test_model(app_label)
|
||||||
|
operation = migrations.AlterField(
|
||||||
|
"Pony", "pink", models.IntegerField(db_default=4)
|
||||||
|
)
|
||||||
|
new_state = project_state.clone()
|
||||||
|
operation.state_forwards(app_label, new_state)
|
||||||
|
old_pink = project_state.models[app_label, "pony"].fields["pink"]
|
||||||
|
self.assertEqual(old_pink.default, 3)
|
||||||
|
self.assertIs(old_pink.db_default, models.NOT_PROVIDED)
|
||||||
|
new_pink = new_state.models[app_label, "pony"].fields["pink"]
|
||||||
|
self.assertIs(new_pink.default, models.NOT_PROVIDED)
|
||||||
|
self.assertEqual(new_pink.db_default, Value(4))
|
||||||
|
pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1)
|
||||||
|
self.assertEqual(pony.pink, 3)
|
||||||
|
# Alter field.
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_forwards(app_label, editor, project_state, new_state)
|
||||||
|
pony = new_state.apps.get_model(app_label, "pony").objects.create(weight=1)
|
||||||
|
if not connection.features.can_return_columns_from_insert:
|
||||||
|
pony.refresh_from_db()
|
||||||
|
self.assertEqual(pony.pink, 4)
|
||||||
|
# Reversal.
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_backwards(app_label, editor, new_state, project_state)
|
||||||
|
pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1)
|
||||||
|
self.assertEqual(pony.pink, 3)
|
||||||
|
|
||||||
|
def test_alter_field_change_nullable_to_database_default_not_null(self):
|
||||||
|
"""
|
||||||
|
The AlterField operation changing a null field to db_default.
|
||||||
|
"""
|
||||||
|
app_label = "test_alflcntddnn"
|
||||||
|
project_state = self.set_up_test_model(app_label)
|
||||||
|
operation = migrations.AlterField(
|
||||||
|
"Pony", "green", models.IntegerField(db_default=4)
|
||||||
|
)
|
||||||
|
new_state = project_state.clone()
|
||||||
|
operation.state_forwards(app_label, new_state)
|
||||||
|
old_green = project_state.models[app_label, "pony"].fields["green"]
|
||||||
|
self.assertIs(old_green.db_default, models.NOT_PROVIDED)
|
||||||
|
new_green = new_state.models[app_label, "pony"].fields["green"]
|
||||||
|
self.assertEqual(new_green.db_default, Value(4))
|
||||||
|
old_pony = project_state.apps.get_model(app_label, "pony").objects.create(
|
||||||
|
weight=1
|
||||||
|
)
|
||||||
|
self.assertIsNone(old_pony.green)
|
||||||
|
# Alter field.
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_forwards(app_label, editor, project_state, new_state)
|
||||||
|
old_pony.refresh_from_db()
|
||||||
|
self.assertEqual(old_pony.green, 4)
|
||||||
|
pony = new_state.apps.get_model(app_label, "pony").objects.create(weight=1)
|
||||||
|
if not connection.features.can_return_columns_from_insert:
|
||||||
|
pony.refresh_from_db()
|
||||||
|
self.assertEqual(pony.green, 4)
|
||||||
|
# Reversal.
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_backwards(app_label, editor, new_state, project_state)
|
||||||
|
pony = project_state.apps.get_model(app_label, "pony").objects.create(weight=1)
|
||||||
|
self.assertIsNone(pony.green)
|
||||||
|
|
||||||
|
@skipIfDBFeature("interprets_empty_strings_as_nulls")
|
||||||
|
def test_alter_field_change_blank_nullable_database_default_to_not_null(self):
|
||||||
|
app_label = "test_alflcbnddnn"
|
||||||
|
table_name = f"{app_label}_pony"
|
||||||
|
project_state = self.set_up_test_model(app_label)
|
||||||
|
default = "Yellow"
|
||||||
|
operation = migrations.AlterField(
|
||||||
|
"Pony",
|
||||||
|
"yellow",
|
||||||
|
models.CharField(blank=True, db_default=default, max_length=20),
|
||||||
|
)
|
||||||
|
new_state = project_state.clone()
|
||||||
|
operation.state_forwards(app_label, new_state)
|
||||||
|
self.assertColumnNull(table_name, "yellow")
|
||||||
|
pony = project_state.apps.get_model(app_label, "pony").objects.create(
|
||||||
|
weight=1, yellow=None
|
||||||
|
)
|
||||||
|
self.assertIsNone(pony.yellow)
|
||||||
|
# Alter field.
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_forwards(app_label, editor, project_state, new_state)
|
||||||
|
self.assertColumnNotNull(table_name, "yellow")
|
||||||
|
pony.refresh_from_db()
|
||||||
|
self.assertEqual(pony.yellow, default)
|
||||||
|
pony = new_state.apps.get_model(app_label, "pony").objects.create(weight=1)
|
||||||
|
if not connection.features.can_return_columns_from_insert:
|
||||||
|
pony.refresh_from_db()
|
||||||
|
self.assertEqual(pony.yellow, default)
|
||||||
|
# Reversal.
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_backwards(app_label, editor, new_state, project_state)
|
||||||
|
self.assertColumnNull(table_name, "yellow")
|
||||||
|
pony = project_state.apps.get_model(app_label, "pony").objects.create(
|
||||||
|
weight=1, yellow=None
|
||||||
|
)
|
||||||
|
self.assertIsNone(pony.yellow)
|
||||||
|
|
||||||
def test_alter_field_add_db_column_noop(self):
|
def test_alter_field_add_db_column_noop(self):
|
||||||
"""
|
"""
|
||||||
AlterField operation is a noop when adding only a db_column and the
|
AlterField operation is a noop when adding only a db_column and the
|
||||||
|
|
|
@ -2102,6 +2102,33 @@ class SchemaTests(TransactionTestCase):
|
||||||
with self.assertRaises(IntegrityError):
|
with self.assertRaises(IntegrityError):
|
||||||
NoteRename.objects.create(detail_info=None)
|
NoteRename.objects.create(detail_info=None)
|
||||||
|
|
||||||
|
@isolate_apps("schema")
|
||||||
|
def test_rename_keep_db_default(self):
|
||||||
|
"""Renaming a field shouldn't affect a database default."""
|
||||||
|
|
||||||
|
class AuthorDbDefault(Model):
|
||||||
|
birth_year = IntegerField(db_default=1985)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
app_label = "schema"
|
||||||
|
|
||||||
|
self.isolated_local_models = [AuthorDbDefault]
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
editor.create_model(AuthorDbDefault)
|
||||||
|
columns = self.column_classes(AuthorDbDefault)
|
||||||
|
self.assertEqual(columns["birth_year"][1].default, "1985")
|
||||||
|
|
||||||
|
old_field = AuthorDbDefault._meta.get_field("birth_year")
|
||||||
|
new_field = IntegerField(db_default=1985)
|
||||||
|
new_field.set_attributes_from_name("renamed_year")
|
||||||
|
new_field.model = AuthorDbDefault
|
||||||
|
with connection.schema_editor(
|
||||||
|
atomic=connection.features.supports_atomic_references_rename
|
||||||
|
) as editor:
|
||||||
|
editor.alter_field(AuthorDbDefault, old_field, new_field, strict=True)
|
||||||
|
columns = self.column_classes(AuthorDbDefault)
|
||||||
|
self.assertEqual(columns["renamed_year"][1].default, "1985")
|
||||||
|
|
||||||
@skipUnlessDBFeature(
|
@skipUnlessDBFeature(
|
||||||
"supports_column_check_constraints", "can_introspect_check_constraints"
|
"supports_column_check_constraints", "can_introspect_check_constraints"
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue