Fixed #31300 -- Added GeneratedField model field.

Thanks Adam Johnson and Paolo Melchiorre for reviews.

Co-Authored-By: Lily Foote <code@lilyf.org>
Co-Authored-By: Mariusz Felisiak <felisiak.mariusz@gmail.com>
This commit is contained in:
Jeremy Nauta 2023-07-06 20:36:48 -06:00 committed by Mariusz Felisiak
parent cafe7266ee
commit f333e3513e
22 changed files with 807 additions and 11 deletions

View File

@ -353,6 +353,11 @@ class BaseDatabaseFeatures:
# Does the backend support column comments in ADD COLUMN statements?
supports_comments_inline = False
# Does the backend support stored generated columns?
supports_stored_generated_columns = False
# Does the backend support virtual generated columns?
supports_virtual_generated_columns = False
# Does the backend support the logical XOR operator?
supports_logical_xor = False

View File

@ -332,7 +332,9 @@ class BaseDatabaseSchemaEditor:
and self.connection.features.interprets_empty_strings_as_nulls
):
null = True
if not null:
if field.generated:
yield self._column_generated_sql(field)
elif not null:
yield "NOT NULL"
elif not self.connection.features.implied_column_null:
yield "NULL"
@ -422,11 +424,21 @@ class BaseDatabaseSchemaEditor:
params = []
return sql % default_sql, params
def _column_generated_sql(self, field):
"""Return the SQL to use in a GENERATED ALWAYS clause."""
expression_sql, params = field.generated_sql(self.connection)
persistency_sql = "STORED" if field.db_persist else "VIRTUAL"
if params:
expression_sql = expression_sql % tuple(self.quote_value(p) for p in params)
return f"GENERATED ALWAYS AS ({expression_sql}) {persistency_sql}"
@staticmethod
def _effective_default(field):
# This method allows testing its logic without a connection.
if field.has_default():
default = field.get_default()
elif field.generated:
default = None
elif not field.null and field.blank and field.empty_strings_allowed:
if field.get_internal_type() == "BinaryField":
default = b""
@ -848,6 +860,18 @@ class BaseDatabaseSchemaEditor:
"(you cannot alter to or from M2M fields, or add or remove "
"through= on M2M fields)" % (old_field, new_field)
)
elif old_field.generated != new_field.generated or (
new_field.generated
and (
old_field.db_persist != new_field.db_persist
or old_field.generated_sql(self.connection)
!= new_field.generated_sql(self.connection)
)
):
raise ValueError(
f"Modifying GeneratedFields is not supported - the field {new_field} "
"must be removed and re-added with the new definition."
)
self._alter_field(
model,

View File

@ -60,6 +60,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
order_by_nulls_first = True
supports_logical_xor = True
supports_stored_generated_columns = True
supports_virtual_generated_columns = True
@cached_property
def minimum_database_version(self):
if self.connection.mysql_is_mariadb:

View File

@ -70,6 +70,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_ignore_conflicts = False
max_query_params = 2**16 - 1
supports_partial_indexes = False
supports_stored_generated_columns = False
supports_virtual_generated_columns = True
can_rename_index = True
supports_slicing_ordering_in_compound = True
requires_compound_order_by_subquery = True

View File

@ -70,6 +70,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_update_conflicts = True
supports_update_conflicts_with_target = True
supports_covering_indexes = True
supports_stored_generated_columns = True
supports_virtual_generated_columns = False
can_rename_index = True
test_collations = {
"non_default": "sv-x-icu",

View File

@ -40,6 +40,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_json_field_contains = False
supports_update_conflicts = True
supports_update_conflicts_with_target = True
supports_stored_generated_columns = Database.sqlite_version_info >= (3, 31, 0)
supports_virtual_generated_columns = Database.sqlite_version_info >= (3, 31, 0)
test_collations = {
"ci": "nocase",
"cs": "binary",

View File

@ -91,7 +91,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
interface.
"""
cursor.execute(
"PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
"PRAGMA table_xinfo(%s)" % self.connection.ops.quote_name(table_name)
)
table_info = cursor.fetchall()
if not table_info:
@ -129,7 +129,13 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
pk == 1,
name in json_columns,
)
for cid, name, data_type, notnull, default, pk in table_info
for cid, name, data_type, notnull, default, pk, hidden in table_info
if hidden
in [
0, # Normal column.
2, # Virtual generated column.
3, # Stored generated column.
]
]
def get_sequences(self, cursor, table_name, table_fields=()):

View File

@ -135,7 +135,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# Choose a default and insert it into the copy map
if (
create_field.db_default is NOT_PROVIDED
and not create_field.many_to_many
and not (create_field.many_to_many or create_field.generated)
and create_field.concrete
):
mapping[create_field.column] = self.prepare_default(

View File

@ -38,6 +38,7 @@ from django.db.models.expressions import (
from django.db.models.fields import * # NOQA
from django.db.models.fields import __all__ as fields_all
from django.db.models.fields.files import FileField, ImageField
from django.db.models.fields.generated import GeneratedField
from django.db.models.fields.json import JSONField
from django.db.models.fields.proxy import OrderWrt
from django.db.models.indexes import * # NOQA
@ -92,6 +93,7 @@ __all__ += [
"WindowFrame",
"FileField",
"ImageField",
"GeneratedField",
"JSONField",
"OrderWrt",
"Lookup",

View File

@ -508,7 +508,7 @@ class Model(AltersData, metaclass=ModelBase):
for field in fields_iter:
is_related_object = False
# Virtual field
if field.attname not in kwargs and field.column is None:
if field.attname not in kwargs and field.column is None or field.generated:
continue
if kwargs:
if isinstance(field.remote_field, ForeignObjectRel):
@ -1050,10 +1050,11 @@ class Model(AltersData, metaclass=ModelBase):
),
)["_order__max"]
)
fields = meta.local_concrete_fields
if not pk_set:
fields = [f for f in fields if f is not meta.auto_field]
fields = [
f
for f in meta.local_concrete_fields
if not f.generated and (pk_set or f is not meta.auto_field)
]
returning_fields = meta.db_returning_fields
results = self._do_insert(
cls._base_manager, using, fields, returning_fields, raw

View File

@ -165,6 +165,7 @@ class Field(RegisterLookupMixin):
one_to_many = None
one_to_one = None
related_model = None
generated = False
descriptor_class = DeferredAttribute
@ -646,6 +647,8 @@ class Field(RegisterLookupMixin):
path = path.replace("django.db.models.fields.related", "django.db.models")
elif path.startswith("django.db.models.fields.files"):
path = path.replace("django.db.models.fields.files", "django.db.models")
elif path.startswith("django.db.models.fields.generated"):
path = path.replace("django.db.models.fields.generated", "django.db.models")
elif path.startswith("django.db.models.fields.json"):
path = path.replace("django.db.models.fields.json", "django.db.models")
elif path.startswith("django.db.models.fields.proxy"):

View File

@ -0,0 +1,151 @@
from django.core import checks
from django.db import connections, router
from django.db.models.sql import Query
from . import NOT_PROVIDED, Field
__all__ = ["GeneratedField"]
class GeneratedField(Field):
generated = True
db_returning = True
_query = None
_resolved_expression = None
output_field = None
def __init__(self, *, expression, db_persist=None, output_field=None, **kwargs):
if kwargs.setdefault("editable", False):
raise ValueError("GeneratedField cannot be editable.")
if not kwargs.setdefault("blank", True):
raise ValueError("GeneratedField must be blank.")
if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED:
raise ValueError("GeneratedField cannot have a default.")
if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED:
raise ValueError("GeneratedField cannot have a database default.")
if db_persist not in (True, False):
raise ValueError("GeneratedField.db_persist must be True or False.")
self.expression = expression
self._output_field = output_field
self.db_persist = db_persist
super().__init__(**kwargs)
def contribute_to_class(self, *args, **kwargs):
super().contribute_to_class(*args, **kwargs)
self._query = Query(model=self.model, alias_cols=False)
self._resolved_expression = self.expression.resolve_expression(
self._query, allow_joins=False
)
self.output_field = (
self._output_field
if self._output_field is not None
else self._resolved_expression.output_field
)
# Register lookups from the output_field class.
for lookup_name, lookup in self.output_field.get_class_lookups().items():
self.register_lookup(lookup, lookup_name=lookup_name)
def generated_sql(self, connection):
return self._resolved_expression.as_sql(
compiler=connection.ops.compiler("SQLCompiler")(
self._query, connection=connection, using=None
),
connection=connection,
)
def check(self, **kwargs):
databases = kwargs.get("databases") or []
return [
*super().check(**kwargs),
*self._check_supported(databases),
*self._check_persistence(databases),
]
def _check_supported(self, databases):
errors = []
for db in databases:
if not router.allow_migrate_model(db, self.model):
continue
connection = connections[db]
if (
self.model._meta.required_db_vendor
and self.model._meta.required_db_vendor != connection.vendor
):
continue
if not (
connection.features.supports_virtual_generated_columns
or "supports_stored_generated_columns"
in self.model._meta.required_db_features
) and not (
connection.features.supports_stored_generated_columns
or "supports_virtual_generated_columns"
in self.model._meta.required_db_features
):
errors.append(
checks.Error(
f"{connection.display_name} does not support GeneratedFields.",
obj=self,
id="fields.E220",
)
)
return errors
def _check_persistence(self, databases):
errors = []
for db in databases:
if not router.allow_migrate_model(db, self.model):
continue
connection = connections[db]
if (
self.model._meta.required_db_vendor
and self.model._meta.required_db_vendor != connection.vendor
):
continue
if not self.db_persist and not (
connection.features.supports_virtual_generated_columns
or "supports_virtual_generated_columns"
in self.model._meta.required_db_features
):
errors.append(
checks.Error(
f"{connection.display_name} does not support non-persisted "
"GeneratedFields.",
obj=self,
id="fields.E221",
hint="Set db_persist=True on the field.",
)
)
if self.db_persist and not (
connection.features.supports_stored_generated_columns
or "supports_stored_generated_columns"
in self.model._meta.required_db_features
):
errors.append(
checks.Error(
f"{connection.display_name} does not support persisted "
"GeneratedFields.",
obj=self,
id="fields.E222",
hint="Set db_persist=False on the field.",
)
)
return errors
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
del kwargs["blank"]
del kwargs["editable"]
kwargs["db_persist"] = self.db_persist
kwargs["expression"] = self.expression
if self._output_field is not None:
kwargs["output_field"] = self._output_field
return name, path, args, kwargs
def get_internal_type(self):
return self.output_field.get_internal_type()
def db_parameters(self, connection):
return self.output_field.db_parameters(connection)

View File

@ -689,6 +689,8 @@ class QuerySet(AltersData):
obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
if not connection.features.supports_default_keyword_in_bulk_insert:
for field in obj._meta.fields:
if field.generated:
continue
value = getattr(obj, field.attname)
if isinstance(value, DatabaseDefault):
setattr(obj, field.attname, field.db_default)
@ -804,7 +806,7 @@ class QuerySet(AltersData):
unique_fields,
)
self._for_write = True
fields = opts.concrete_fields
fields = [f for f in opts.concrete_fields if not f.generated]
objs = list(objs)
self._prepare_for_bulk_create(objs)
with transaction.atomic(using=self.db, savepoint=False):

View File

@ -198,6 +198,10 @@ class DeferredAttribute:
# might be able to reuse the already loaded value. Refs #18343.
val = self._check_parent_chain(instance)
if val is None:
if instance.pk is None and self.field.generated:
raise FieldError(
"Cannot read a generated field from an unsaved model."
)
instance.refresh_from_db(fields=[field_name])
else:
data[field_name] = val

View File

@ -108,6 +108,9 @@ class UpdateQuery(Query):
called add_update_targets() to hint at the extra information here.
"""
for field, model, val in values_seq:
# Omit generated fields.
if field.generated:
continue
if hasattr(val, "resolve_expression"):
# Resolve expressions here so that annotations are no longer needed
val = val.resolve_expression(self, allow_joins=False, for_save=True)

View File

@ -208,6 +208,11 @@ Model fields
* **fields.E180**: ``<database>`` does not support ``JSONField``\s.
* **fields.E190**: ``<database>`` does not support a database collation on
``<field_type>``\s.
* **fields.E220**: ``<database>`` does not support ``GeneratedField``\s.
* **fields.E221**: ``<database>`` does not support non-persisted
``GeneratedField``\s.
* **fields.E222**: ``<database>`` does not support persisted
``GeneratedField``\s.
* **fields.E900**: ``IPAddressField`` has been removed except for support in
historical migrations.
* **fields.W900**: ``IPAddressField`` has been deprecated. Support for it

View File

@ -1215,6 +1215,71 @@ when :attr:`~django.forms.Field.localize` is ``False`` or
information on the difference between the two, see Python's documentation
for the :mod:`decimal` module.
``GeneratedField``
------------------
.. versionadded:: 5.0
.. class:: GeneratedField(expression, db_persist=None, output_field=None, **kwargs)
A field that is always computed based on other fields in the model. This field
is managed and updated by the database itself. Uses the ``GENERATED ALWAYS``
SQL syntax.
There are two kinds of generated columns: stored and virtual. A stored
generated column is computed when it is written (inserted or updated) and
occupies storage as if it were a regular column. A virtual generated column
occupies no storage and is computed when it is read. Thus, a virtual generated
column is similar to a view and a stored generated column is similar to a
materialized view.
.. attribute:: GeneratedField.expression
An :class:`Expression` used by the database to automatically set the field
value each time the model is changed.
The expressions should be deterministic and only reference fields within
the model (in the same database table). Generated fields cannot reference
other generated fields. Database backends can impose further restrictions.
.. attribute:: GeneratedField.db_persist
Determines if the database column should occupy storage as if it were a
real column. If ``False``, the column acts as a virtual column and does
not occupy database storage space.
PostgreSQL only supports persisted columns. Oracle only supports virtual
columns.
.. attribute:: GeneratedField.output_field
An optional model field instance to define the field's data type. This can
be used to customize attributes like the field's collation. By default, the
output field is derived from ``expression``.
.. admonition:: Refresh the data
Since the database always computed the value, the object must be reloaded
to access the new value after :meth:`~Model.save()`, for example, by using
:meth:`~Model.refresh_from_db()`.
.. admonition:: Database limitations
There are many database-specific restrictions on generated fields that
Django doesn't validate and the database may raise an error e.g. PostgreSQL
requires functions and operators referenced in a generated columns to be
marked as ``IMMUTABLE`` .
You should always check that ``expression`` is supported on your database.
Check out `MariaDB`_, `MySQL`_, `Oracle`_, `PostgreSQL`_, or `SQLite`_
docs.
.. _MariaDB: https://mariadb.com/kb/en/generated-columns/#expression-support
.. _MySQL: https://dev.mysql.com/doc/refman/en/create-table-generated-columns.html
.. _Oracle: https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/CREATE-TABLE.html#GUID-F9CE0CC3-13AE-4744-A43C-EAC7A71AAAB6__BABIIGBD
.. _PostgreSQL: https://www.postgresql.org/docs/current/ddl-generated-columns.html
.. _SQLite: https://www.sqlite.org/gencol.html#limitations
``GenericIPAddressField``
-------------------------

View File

@ -129,6 +129,13 @@ sets a database-computed default value. For example::
created = models.DateTimeField(db_default=Now())
circumference = models.FloatField(db_default=2 * Pi())
Database generated model field
------------------------------
The new :class:`~django.db.models.GeneratedField` allows creation of database
generated columns. This field can be used on all supported database backends
to create a field that is always computed from other fields.
More options for declaring field choices
----------------------------------------

View File

@ -1226,3 +1226,115 @@ class InvalidDBDefaultTests(TestCase):
msg = f"{expression} cannot be used in db_default."
expected_error = Error(msg=msg, obj=field, id="fields.E012")
self.assertEqual(errors, [expected_error])
@isolate_apps("invalid_models_tests")
class GeneratedFieldTests(TestCase):
def test_not_supported(self):
db_persist = connection.features.supports_stored_generated_columns
class Model(models.Model):
name = models.IntegerField()
field = models.GeneratedField(
expression=models.F("name"), db_persist=db_persist
)
expected_errors = []
if (
not connection.features.supports_stored_generated_columns
and not connection.features.supports_virtual_generated_columns
):
expected_errors.append(
Error(
f"{connection.display_name} does not support GeneratedFields.",
obj=Model._meta.get_field("field"),
id="fields.E220",
)
)
if (
not db_persist
and not connection.features.supports_virtual_generated_columns
):
expected_errors.append(
Error(
f"{connection.display_name} does not support non-persisted "
"GeneratedFields.",
obj=Model._meta.get_field("field"),
id="fields.E221",
hint="Set db_persist=True on the field.",
),
)
self.assertEqual(
Model._meta.get_field("field").check(databases={"default"}),
expected_errors,
)
def test_not_supported_stored_required_db_features(self):
class Model(models.Model):
name = models.IntegerField()
field = models.GeneratedField(expression=models.F("name"), db_persist=True)
class Meta:
required_db_features = {"supports_stored_generated_columns"}
self.assertEqual(Model.check(databases=self.databases), [])
def test_not_supported_virtual_required_db_features(self):
class Model(models.Model):
name = models.IntegerField()
field = models.GeneratedField(expression=models.F("name"), db_persist=False)
class Meta:
required_db_features = {"supports_virtual_generated_columns"}
self.assertEqual(Model.check(databases=self.databases), [])
@skipUnlessDBFeature("supports_stored_generated_columns")
def test_not_supported_virtual(self):
class Model(models.Model):
name = models.IntegerField()
field = models.GeneratedField(expression=models.F("name"), db_persist=False)
a = models.TextField()
excepted_errors = (
[]
if connection.features.supports_virtual_generated_columns
else [
Error(
f"{connection.display_name} does not support non-persisted "
"GeneratedFields.",
obj=Model._meta.get_field("field"),
id="fields.E221",
hint="Set db_persist=True on the field.",
),
]
)
self.assertEqual(
Model._meta.get_field("field").check(databases={"default"}),
excepted_errors,
)
@skipUnlessDBFeature("supports_virtual_generated_columns")
def test_not_supported_stored(self):
class Model(models.Model):
name = models.IntegerField()
field = models.GeneratedField(expression=models.F("name"), db_persist=True)
a = models.TextField()
expected_errors = (
[]
if connection.features.supports_stored_generated_columns
else [
Error(
f"{connection.display_name} does not support persisted "
"GeneratedFields.",
obj=Model._meta.get_field("field"),
id="fields.E222",
hint="Set db_persist=False on the field.",
),
]
)
self.assertEqual(
Model._meta.get_field("field").check(databases={"default"}),
expected_errors,
)

View File

@ -5,6 +5,7 @@ from django.db import IntegrityError, connection, migrations, models, transactio
from django.db.migrations.migration import Migration
from django.db.migrations.operations.fields import FieldOperation
from django.db.migrations.state import ModelState, ProjectState
from django.db.models import F
from django.db.models.expressions import Value
from django.db.models.functions import Abs, Pi
from django.db.transaction import atomic
@ -5741,6 +5742,130 @@ class OperationTests(OperationTestBase):
operation.database_backwards(app_label, editor, new_state, project_state)
assertModelsAndTables(after_db=False)
def _test_invalid_generated_field_changes(self, db_persist):
regular = models.IntegerField(default=1)
generated_1 = models.GeneratedField(
expression=F("pink") + F("pink"), db_persist=db_persist
)
generated_2 = models.GeneratedField(
expression=F("pink") + F("pink") + F("pink"), db_persist=db_persist
)
tests = [
("test_igfc_1", regular, generated_1),
("test_igfc_2", generated_1, regular),
("test_igfc_3", generated_1, generated_2),
]
for app_label, add_field, alter_field in tests:
project_state = self.set_up_test_model(app_label)
operations = [
migrations.AddField("Pony", "modified_pink", add_field),
migrations.AlterField("Pony", "modified_pink", alter_field),
]
msg = (
"Modifying GeneratedFields is not supported - the field "
f"{app_label}.Pony.modified_pink must be removed and re-added with the "
"new definition."
)
with self.assertRaisesMessage(ValueError, msg):
self.apply_operations(app_label, project_state, operations)
@skipUnlessDBFeature("supports_stored_generated_columns")
def test_invalid_generated_field_changes_stored(self):
self._test_invalid_generated_field_changes(db_persist=True)
@skipUnlessDBFeature("supports_virtual_generated_columns")
def test_invalid_generated_field_changes_virtual(self):
self._test_invalid_generated_field_changes(db_persist=False)
@skipUnlessDBFeature(
"supports_stored_generated_columns",
"supports_virtual_generated_columns",
)
def test_invalid_generated_field_persistency_change(self):
app_label = "test_igfpc"
project_state = self.set_up_test_model(app_label)
operations = [
migrations.AddField(
"Pony",
"modified_pink",
models.GeneratedField(expression=F("pink"), db_persist=True),
),
migrations.AlterField(
"Pony",
"modified_pink",
models.GeneratedField(expression=F("pink"), db_persist=False),
),
]
msg = (
"Modifying GeneratedFields is not supported - the field "
f"{app_label}.Pony.modified_pink must be removed and re-added with the "
"new definition."
)
with self.assertRaisesMessage(ValueError, msg):
self.apply_operations(app_label, project_state, operations)
def _test_add_generated_field(self, db_persist):
app_label = "test_agf"
operation = migrations.AddField(
"Pony",
"modified_pink",
models.GeneratedField(
expression=F("pink") + F("pink"), db_persist=db_persist
),
)
project_state, new_state = self.make_test_state(app_label, operation)
self.assertEqual(len(new_state.models[app_label, "pony"].fields), 6)
# Add generated column.
with connection.schema_editor() as editor:
operation.database_forwards(app_label, editor, project_state, new_state)
self.assertColumnExists(f"{app_label}_pony", "modified_pink")
Pony = new_state.apps.get_model(app_label, "Pony")
obj = Pony.objects.create(pink=5, weight=3.23)
self.assertEqual(obj.modified_pink, 10)
# Reversal.
with connection.schema_editor() as editor:
operation.database_backwards(app_label, editor, new_state, project_state)
self.assertColumnNotExists(f"{app_label}_pony", "modified_pink")
@skipUnlessDBFeature("supports_stored_generated_columns")
def test_add_generated_field_stored(self):
self._test_add_generated_field(db_persist=True)
@skipUnlessDBFeature("supports_virtual_generated_columns")
def test_add_generated_field_virtual(self):
self._test_add_generated_field(db_persist=False)
def _test_remove_generated_field(self, db_persist):
app_label = "test_rgf"
operation = migrations.AddField(
"Pony",
"modified_pink",
models.GeneratedField(
expression=F("pink") + F("pink"), db_persist=db_persist
),
)
project_state, new_state = self.make_test_state(app_label, operation)
self.assertEqual(len(new_state.models[app_label, "pony"].fields), 6)
# Add generated column.
with connection.schema_editor() as editor:
operation.database_forwards(app_label, editor, project_state, new_state)
project_state = new_state
new_state = project_state.clone()
operation = migrations.RemoveField("Pony", "modified_pink")
operation.state_forwards(app_label, new_state)
# Remove generated column.
with connection.schema_editor() as editor:
operation.database_forwards(app_label, editor, project_state, new_state)
self.assertColumnNotExists(f"{app_label}_pony", "modified_pink")
@skipUnlessDBFeature("supports_stored_generated_columns")
def test_remove_generated_field_stored(self):
self._test_remove_generated_field(db_persist=True)
@skipUnlessDBFeature("supports_virtual_generated_columns")
def test_remove_generated_field_virtual(self):
self._test_remove_generated_field(db_persist=False)
class SwappableOperationTests(OperationTestBase):
"""

View File

@ -6,8 +6,11 @@ from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelatio
from django.contrib.contenttypes.models import ContentType
from django.core.files.storage import FileSystemStorage
from django.core.serializers.json import DjangoJSONEncoder
from django.db import models
from django.db import connection, models
from django.db.models import F, Value
from django.db.models.fields.files import ImageFieldFile
from django.db.models.functions import Lower
from django.utils.functional import SimpleLazyObject
from django.utils.translation import gettext_lazy as _
try:
@ -16,6 +19,11 @@ except ImportError:
Image = None
test_collation = SimpleLazyObject(
lambda: connection.features.test_collations.get("non_default")
)
class Foo(models.Model):
a = models.CharField(max_length=10)
d = models.DecimalField(max_digits=5, decimal_places=3)
@ -468,3 +476,91 @@ class UUIDChild(PrimaryKeyUUIDModel):
class UUIDGrandchild(UUIDChild):
pass
class GeneratedModel(models.Model):
a = models.IntegerField()
b = models.IntegerField()
field = models.GeneratedField(expression=F("a") + F("b"), db_persist=True)
class Meta:
required_db_features = {"supports_stored_generated_columns"}
class GeneratedModelVirtual(models.Model):
a = models.IntegerField()
b = models.IntegerField()
field = models.GeneratedField(expression=F("a") + F("b"), db_persist=False)
class Meta:
required_db_features = {"supports_virtual_generated_columns"}
class GeneratedModelParams(models.Model):
field = models.GeneratedField(
expression=Value("Constant", output_field=models.CharField(max_length=10)),
db_persist=True,
)
class Meta:
required_db_features = {"supports_stored_generated_columns"}
class GeneratedModelParamsVirtual(models.Model):
field = models.GeneratedField(
expression=Value("Constant", output_field=models.CharField(max_length=10)),
db_persist=False,
)
class Meta:
required_db_features = {"supports_virtual_generated_columns"}
class GeneratedModelOutputField(models.Model):
name = models.CharField(max_length=10)
lower_name = models.GeneratedField(
expression=Lower("name"),
output_field=models.CharField(db_collation=test_collation, max_length=11),
db_persist=True,
)
class Meta:
required_db_features = {
"supports_stored_generated_columns",
"supports_collation_on_charfield",
}
class GeneratedModelOutputFieldVirtual(models.Model):
name = models.CharField(max_length=10)
lower_name = models.GeneratedField(
expression=Lower("name"),
db_persist=False,
output_field=models.CharField(db_collation=test_collation, max_length=11),
)
class Meta:
required_db_features = {
"supports_virtual_generated_columns",
"supports_collation_on_charfield",
}
class GeneratedModelNull(models.Model):
name = models.CharField(max_length=10, null=True)
lower_name = models.GeneratedField(
expression=Lower("name"), db_persist=True, null=True
)
class Meta:
required_db_features = {"supports_stored_generated_columns"}
class GeneratedModelNullVirtual(models.Model):
name = models.CharField(max_length=10, null=True)
lower_name = models.GeneratedField(
expression=Lower("name"), db_persist=False, null=True
)
class Meta:
required_db_features = {"supports_virtual_generated_columns"}

View File

@ -0,0 +1,176 @@
from django.core.exceptions import FieldError
from django.db import IntegrityError, connection
from django.db.models import F, GeneratedField, IntegerField
from django.db.models.functions import Lower
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from .models import (
GeneratedModel,
GeneratedModelNull,
GeneratedModelNullVirtual,
GeneratedModelOutputField,
GeneratedModelOutputFieldVirtual,
GeneratedModelParams,
GeneratedModelParamsVirtual,
GeneratedModelVirtual,
)
class BaseGeneratedFieldTests(SimpleTestCase):
def test_editable_unsupported(self):
with self.assertRaisesMessage(ValueError, "GeneratedField cannot be editable."):
GeneratedField(expression=Lower("name"), editable=True, db_persist=False)
def test_blank_unsupported(self):
with self.assertRaisesMessage(ValueError, "GeneratedField must be blank."):
GeneratedField(expression=Lower("name"), blank=False, db_persist=False)
def test_default_unsupported(self):
msg = "GeneratedField cannot have a default."
with self.assertRaisesMessage(ValueError, msg):
GeneratedField(expression=Lower("name"), default="", db_persist=False)
def test_database_default_unsupported(self):
msg = "GeneratedField cannot have a database default."
with self.assertRaisesMessage(ValueError, msg):
GeneratedField(expression=Lower("name"), db_default="", db_persist=False)
def test_db_persist_required(self):
msg = "GeneratedField.db_persist must be True or False."
with self.assertRaisesMessage(ValueError, msg):
GeneratedField(expression=Lower("name"))
with self.assertRaisesMessage(ValueError, msg):
GeneratedField(expression=Lower("name"), db_persist=None)
def test_deconstruct(self):
field = GeneratedField(expression=F("a") + F("b"), db_persist=True)
_, path, args, kwargs = field.deconstruct()
self.assertEqual(path, "django.db.models.GeneratedField")
self.assertEqual(args, [])
self.assertEqual(kwargs, {"db_persist": True, "expression": F("a") + F("b")})
class GeneratedFieldTestMixin:
def _refresh_if_needed(self, m):
if not connection.features.can_return_columns_from_insert:
m.refresh_from_db()
return m
def test_unsaved_error(self):
m = self.base_model(a=1, b=2)
msg = "Cannot read a generated field from an unsaved model."
with self.assertRaisesMessage(FieldError, msg):
m.field
def test_create(self):
m = self.base_model.objects.create(a=1, b=2)
m = self._refresh_if_needed(m)
self.assertEqual(m.field, 3)
def test_non_nullable_create(self):
with self.assertRaises(IntegrityError):
self.base_model.objects.create()
def test_save(self):
# Insert.
m = self.base_model(a=2, b=4)
m.save()
m = self._refresh_if_needed(m)
self.assertEqual(m.field, 6)
# Update.
m.a = 4
m.save()
m.refresh_from_db()
self.assertEqual(m.field, 8)
def test_update(self):
m = self.base_model.objects.create(a=1, b=2)
self.base_model.objects.update(b=3)
m = self.base_model.objects.get(pk=m.pk)
self.assertEqual(m.field, 4)
def test_bulk_create(self):
m = self.base_model(a=3, b=4)
(m,) = self.base_model.objects.bulk_create([m])
if not connection.features.can_return_rows_from_bulk_insert:
m = self.base_model.objects.get()
self.assertEqual(m.field, 7)
def test_bulk_update(self):
m = self.base_model.objects.create(a=1, b=2)
m.a = 3
self.base_model.objects.bulk_update([m], fields=["a"])
m = self.base_model.objects.get(pk=m.pk)
self.assertEqual(m.field, 5)
def test_output_field_lookups(self):
"""Lookups from the output_field are available on GeneratedFields."""
internal_type = IntegerField().get_internal_type()
min_value, max_value = connection.ops.integer_field_range(internal_type)
if min_value is None:
self.skipTest("Backend doesn't define an integer min value.")
if max_value is None:
self.skipTest("Backend doesn't define an integer max value.")
does_not_exist = self.base_model.DoesNotExist
underflow_value = min_value - 1
with self.assertNumQueries(0), self.assertRaises(does_not_exist):
self.base_model.objects.get(field=underflow_value)
with self.assertNumQueries(0), self.assertRaises(does_not_exist):
self.base_model.objects.get(field__lt=underflow_value)
with self.assertNumQueries(0), self.assertRaises(does_not_exist):
self.base_model.objects.get(field__lte=underflow_value)
overflow_value = max_value + 1
with self.assertNumQueries(0), self.assertRaises(does_not_exist):
self.base_model.objects.get(field=overflow_value)
with self.assertNumQueries(0), self.assertRaises(does_not_exist):
self.base_model.objects.get(field__gt=overflow_value)
with self.assertNumQueries(0), self.assertRaises(does_not_exist):
self.base_model.objects.get(field__gte=overflow_value)
@skipUnlessDBFeature("supports_collation_on_charfield")
def test_output_field(self):
collation = connection.features.test_collations.get("non_default")
if not collation:
self.skipTest("Language collations are not supported.")
m = self.output_field_model.objects.create(name="NAME")
field = m._meta.get_field("lower_name")
db_parameters = field.db_parameters(connection)
self.assertEqual(db_parameters["collation"], collation)
self.assertEqual(db_parameters["type"], field.output_field.db_type(connection))
self.assertNotEqual(
db_parameters["type"],
field._resolved_expression.output_field.db_type(connection),
)
def test_model_with_params(self):
m = self.params_model.objects.create()
m = self._refresh_if_needed(m)
self.assertEqual(m.field, "Constant")
def test_nullable(self):
m1 = self.nullable_model.objects.create()
m1 = self._refresh_if_needed(m1)
none_val = "" if connection.features.interprets_empty_strings_as_nulls else None
self.assertEqual(m1.lower_name, none_val)
m2 = self.nullable_model.objects.create(name="NaMe")
m2 = self._refresh_if_needed(m2)
self.assertEqual(m2.lower_name, "name")
@skipUnlessDBFeature("supports_stored_generated_columns")
class StoredGeneratedFieldTests(GeneratedFieldTestMixin, TestCase):
base_model = GeneratedModel
nullable_model = GeneratedModelNull
output_field_model = GeneratedModelOutputField
params_model = GeneratedModelParams
@skipUnlessDBFeature("supports_virtual_generated_columns")
class VirtualGeneratedFieldTests(GeneratedFieldTestMixin, TestCase):
base_model = GeneratedModelVirtual
nullable_model = GeneratedModelNullVirtual
output_field_model = GeneratedModelOutputFieldVirtual
params_model = GeneratedModelParamsVirtual