Fixed #34944 -- Made GeneratedField.output_field required.

Regression in f333e3513e.
This commit is contained in:
Mariusz Felisiak 2023-11-13 05:33:25 +01:00
parent de4884b114
commit 5875f03ce6
9 changed files with 150 additions and 54 deletions

View File

@ -16,7 +16,7 @@ class GeneratedField(Field):
_resolved_expression = None _resolved_expression = None
output_field = None output_field = None
def __init__(self, *, expression, db_persist=None, output_field=None, **kwargs): def __init__(self, *, expression, output_field, db_persist=None, **kwargs):
if kwargs.setdefault("editable", False): if kwargs.setdefault("editable", False):
raise ValueError("GeneratedField cannot be editable.") raise ValueError("GeneratedField cannot be editable.")
if not kwargs.setdefault("blank", True): if not kwargs.setdefault("blank", True):
@ -29,7 +29,7 @@ class GeneratedField(Field):
raise ValueError("GeneratedField.db_persist must be True or False.") raise ValueError("GeneratedField.db_persist must be True or False.")
self.expression = expression self.expression = expression
self._output_field = output_field self.output_field = output_field
self.db_persist = db_persist self.db_persist = db_persist
super().__init__(**kwargs) super().__init__(**kwargs)
@ -51,11 +51,6 @@ class GeneratedField(Field):
self._resolved_expression = self.expression.resolve_expression( self._resolved_expression = self.expression.resolve_expression(
self._query, allow_joins=False self._query, allow_joins=False
) )
self.output_field = (
self._output_field
if self._output_field is not None
else self._resolved_expression.output_field
)
# Register lookups from the output_field class. # Register lookups from the output_field class.
for lookup_name, lookup in self.output_field.get_class_lookups().items(): for lookup_name, lookup in self.output_field.get_class_lookups().items():
self.register_lookup(lookup, lookup_name=lookup_name) self.register_lookup(lookup, lookup_name=lookup_name)
@ -150,8 +145,7 @@ class GeneratedField(Field):
del kwargs["editable"] del kwargs["editable"]
kwargs["db_persist"] = self.db_persist kwargs["db_persist"] = self.db_persist
kwargs["expression"] = self.expression kwargs["expression"] = self.expression
if self._output_field is not None: kwargs["output_field"] = self.output_field
kwargs["output_field"] = self._output_field
return name, path, args, kwargs return name, path, args, kwargs
def get_internal_type(self): def get_internal_type(self):

View File

@ -1237,7 +1237,7 @@ when :attr:`~django.forms.Field.localize` is ``False`` or
.. versionadded:: 5.0 .. versionadded:: 5.0
.. class:: GeneratedField(expression, db_persist=None, output_field=None, **kwargs) .. class:: GeneratedField(expression, output_field, db_persist=None, **kwargs)
A field that is always computed based on other fields in the model. This field A field that is always computed based on other fields in the model. This field
is managed and updated by the database itself. Uses the ``GENERATED ALWAYS`` is managed and updated by the database itself. Uses the ``GENERATED ALWAYS``
@ -1259,6 +1259,10 @@ materialized view.
the model (in the same database table). Generated fields cannot reference the model (in the same database table). Generated fields cannot reference
other generated fields. Database backends can impose further restrictions. other generated fields. Database backends can impose further restrictions.
.. attribute:: GeneratedField.output_field
A model field instance to define the field's data type.
.. attribute:: GeneratedField.db_persist .. attribute:: GeneratedField.db_persist
Determines if the database column should occupy storage as if it were a Determines if the database column should occupy storage as if it were a
@ -1268,12 +1272,6 @@ materialized view.
PostgreSQL only supports persisted columns. Oracle only supports virtual PostgreSQL only supports persisted columns. Oracle only supports virtual
columns. columns.
.. attribute:: GeneratedField.output_field
An optional model field instance to define the field's data type. This can
be used to customize attributes like the field's collation. By default, the
output field is derived from ``expression``.
.. admonition:: Refresh the data .. admonition:: Refresh the data
Since the database always computed the value, the object must be reloaded Since the database always computed the value, the object must be reloaded

View File

@ -142,7 +142,11 @@ to create a field that is always computed from other fields. For example::
class Square(models.Model): class Square(models.Model):
side = models.IntegerField() side = models.IntegerField()
area = models.GeneratedField(expression=F("side") * F("side"), db_persist=True) area = models.GeneratedField(
expression=F("side") * F("side"),
output_field=models.BigIntegerField(),
db_persist=True,
)
More options for declaring field choices More options for declaring field choices
---------------------------------------- ----------------------------------------

View File

@ -1147,6 +1147,7 @@ class Square(models.Model):
area = models.GeneratedField( area = models.GeneratedField(
db_persist=True, db_persist=True,
expression=models.F("side") * models.F("side"), expression=models.F("side") * models.F("side"),
output_field=models.BigIntegerField(),
) )
class Meta: class Meta:

View File

@ -1216,7 +1216,9 @@ class GeneratedFieldTests(TestCase):
class Model(models.Model): class Model(models.Model):
name = models.IntegerField() name = models.IntegerField()
field = models.GeneratedField( field = models.GeneratedField(
expression=models.F("name"), db_persist=db_persist expression=models.F("name"),
output_field=models.IntegerField(),
db_persist=db_persist,
) )
expected_errors = [] expected_errors = []
@ -1252,7 +1254,11 @@ class GeneratedFieldTests(TestCase):
def test_not_supported_stored_required_db_features(self): def test_not_supported_stored_required_db_features(self):
class Model(models.Model): class Model(models.Model):
name = models.IntegerField() name = models.IntegerField()
field = models.GeneratedField(expression=models.F("name"), db_persist=True) field = models.GeneratedField(
expression=models.F("name"),
output_field=models.IntegerField(),
db_persist=True,
)
class Meta: class Meta:
required_db_features = {"supports_stored_generated_columns"} required_db_features = {"supports_stored_generated_columns"}
@ -1262,7 +1268,11 @@ class GeneratedFieldTests(TestCase):
def test_not_supported_virtual_required_db_features(self): def test_not_supported_virtual_required_db_features(self):
class Model(models.Model): class Model(models.Model):
name = models.IntegerField() name = models.IntegerField()
field = models.GeneratedField(expression=models.F("name"), db_persist=False) field = models.GeneratedField(
expression=models.F("name"),
output_field=models.IntegerField(),
db_persist=False,
)
class Meta: class Meta:
required_db_features = {"supports_virtual_generated_columns"} required_db_features = {"supports_virtual_generated_columns"}
@ -1273,7 +1283,11 @@ class GeneratedFieldTests(TestCase):
def test_not_supported_virtual(self): def test_not_supported_virtual(self):
class Model(models.Model): class Model(models.Model):
name = models.IntegerField() name = models.IntegerField()
field = models.GeneratedField(expression=models.F("name"), db_persist=False) field = models.GeneratedField(
expression=models.F("name"),
output_field=models.IntegerField(),
db_persist=False,
)
a = models.TextField() a = models.TextField()
excepted_errors = ( excepted_errors = (
@ -1298,7 +1312,11 @@ class GeneratedFieldTests(TestCase):
def test_not_supported_stored(self): def test_not_supported_stored(self):
class Model(models.Model): class Model(models.Model):
name = models.IntegerField() name = models.IntegerField()
field = models.GeneratedField(expression=models.F("name"), db_persist=True) field = models.GeneratedField(
expression=models.F("name"),
output_field=models.IntegerField(),
db_persist=True,
)
a = models.TextField() a = models.TextField()
expected_errors = ( expected_errors = (

View File

@ -5664,10 +5664,14 @@ class OperationTests(OperationTestBase):
def _test_invalid_generated_field_changes(self, db_persist): def _test_invalid_generated_field_changes(self, db_persist):
regular = models.IntegerField(default=1) regular = models.IntegerField(default=1)
generated_1 = models.GeneratedField( generated_1 = models.GeneratedField(
expression=F("pink") + F("pink"), db_persist=db_persist expression=F("pink") + F("pink"),
output_field=models.IntegerField(),
db_persist=db_persist,
) )
generated_2 = models.GeneratedField( generated_2 = models.GeneratedField(
expression=F("pink") + F("pink") + F("pink"), db_persist=db_persist expression=F("pink") + F("pink") + F("pink"),
output_field=models.IntegerField(),
db_persist=db_persist,
) )
tests = [ tests = [
("test_igfc_1", regular, generated_1), ("test_igfc_1", regular, generated_1),
@ -5707,12 +5711,20 @@ class OperationTests(OperationTestBase):
migrations.AddField( migrations.AddField(
"Pony", "Pony",
"modified_pink", "modified_pink",
models.GeneratedField(expression=F("pink"), db_persist=True), models.GeneratedField(
expression=F("pink"),
output_field=models.IntegerField(),
db_persist=True,
),
), ),
migrations.AlterField( migrations.AlterField(
"Pony", "Pony",
"modified_pink", "modified_pink",
models.GeneratedField(expression=F("pink"), db_persist=False), models.GeneratedField(
expression=F("pink"),
output_field=models.IntegerField(),
db_persist=False,
),
), ),
] ]
msg = ( msg = (
@ -5729,7 +5741,9 @@ class OperationTests(OperationTestBase):
"Pony", "Pony",
"modified_pink", "modified_pink",
models.GeneratedField( models.GeneratedField(
expression=F("pink") + F("pink"), db_persist=db_persist expression=F("pink") + F("pink"),
output_field=models.IntegerField(),
db_persist=db_persist,
), ),
) )
project_state, new_state = self.make_test_state(app_label, operation) project_state, new_state = self.make_test_state(app_label, operation)
@ -5760,7 +5774,9 @@ class OperationTests(OperationTestBase):
"Pony", "Pony",
"modified_pink", "modified_pink",
models.GeneratedField( models.GeneratedField(
expression=F("pink") + F("pink"), db_persist=db_persist expression=F("pink") + F("pink"),
output_field=models.IntegerField(),
db_persist=db_persist,
), ),
) )
project_state, new_state = self.make_test_state(app_label, operation) project_state, new_state = self.make_test_state(app_label, operation)

View File

@ -485,7 +485,11 @@ class UUIDGrandchild(UUIDChild):
class GeneratedModel(models.Model): class GeneratedModel(models.Model):
a = models.IntegerField() a = models.IntegerField()
b = models.IntegerField() b = models.IntegerField()
field = models.GeneratedField(expression=F("a") + F("b"), db_persist=True) field = models.GeneratedField(
expression=F("a") + F("b"),
output_field=models.IntegerField(),
db_persist=True,
)
class Meta: class Meta:
required_db_features = {"supports_stored_generated_columns"} required_db_features = {"supports_stored_generated_columns"}
@ -494,7 +498,11 @@ class GeneratedModel(models.Model):
class GeneratedModelVirtual(models.Model): class GeneratedModelVirtual(models.Model):
a = models.IntegerField() a = models.IntegerField()
b = models.IntegerField() b = models.IntegerField()
field = models.GeneratedField(expression=F("a") + F("b"), db_persist=False) field = models.GeneratedField(
expression=F("a") + F("b"),
output_field=models.IntegerField(),
db_persist=False,
)
class Meta: class Meta:
required_db_features = {"supports_virtual_generated_columns"} required_db_features = {"supports_virtual_generated_columns"}
@ -503,6 +511,7 @@ class GeneratedModelVirtual(models.Model):
class GeneratedModelParams(models.Model): class GeneratedModelParams(models.Model):
field = models.GeneratedField( field = models.GeneratedField(
expression=Value("Constant", output_field=models.CharField(max_length=10)), expression=Value("Constant", output_field=models.CharField(max_length=10)),
output_field=models.CharField(max_length=10),
db_persist=True, db_persist=True,
) )
@ -513,6 +522,7 @@ class GeneratedModelParams(models.Model):
class GeneratedModelParamsVirtual(models.Model): class GeneratedModelParamsVirtual(models.Model):
field = models.GeneratedField( field = models.GeneratedField(
expression=Value("Constant", output_field=models.CharField(max_length=10)), expression=Value("Constant", output_field=models.CharField(max_length=10)),
output_field=models.CharField(max_length=10),
db_persist=False, db_persist=False,
) )
@ -520,7 +530,7 @@ class GeneratedModelParamsVirtual(models.Model):
required_db_features = {"supports_virtual_generated_columns"} required_db_features = {"supports_virtual_generated_columns"}
class GeneratedModelOutputField(models.Model): class GeneratedModelOutputFieldDbCollation(models.Model):
name = models.CharField(max_length=10) name = models.CharField(max_length=10)
lower_name = models.GeneratedField( lower_name = models.GeneratedField(
expression=Lower("name"), expression=Lower("name"),
@ -532,7 +542,7 @@ class GeneratedModelOutputField(models.Model):
required_db_features = {"supports_stored_generated_columns"} required_db_features = {"supports_stored_generated_columns"}
class GeneratedModelOutputFieldVirtual(models.Model): class GeneratedModelOutputFieldDbCollationVirtual(models.Model):
name = models.CharField(max_length=10) name = models.CharField(max_length=10)
lower_name = models.GeneratedField( lower_name = models.GeneratedField(
expression=Lower("name"), expression=Lower("name"),
@ -547,7 +557,10 @@ class GeneratedModelOutputFieldVirtual(models.Model):
class GeneratedModelNull(models.Model): class GeneratedModelNull(models.Model):
name = models.CharField(max_length=10, null=True) name = models.CharField(max_length=10, null=True)
lower_name = models.GeneratedField( lower_name = models.GeneratedField(
expression=Lower("name"), db_persist=True, null=True expression=Lower("name"),
output_field=models.CharField(max_length=10),
db_persist=True,
null=True,
) )
class Meta: class Meta:
@ -557,7 +570,10 @@ class GeneratedModelNull(models.Model):
class GeneratedModelNullVirtual(models.Model): class GeneratedModelNullVirtual(models.Model):
name = models.CharField(max_length=10, null=True) name = models.CharField(max_length=10, null=True)
lower_name = models.GeneratedField( lower_name = models.GeneratedField(
expression=Lower("name"), db_persist=False, null=True expression=Lower("name"),
output_field=models.CharField(max_length=10),
db_persist=False,
null=True,
) )
class Meta: class Meta:

View File

@ -1,5 +1,12 @@
from django.db import IntegrityError, connection from django.db import IntegrityError, connection
from django.db.models import F, FloatField, GeneratedField, IntegerField, Model from django.db.models import (
CharField,
F,
FloatField,
GeneratedField,
IntegerField,
Model,
)
from django.db.models.functions import Lower from django.db.models.functions import Lower
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from django.test.utils import isolate_apps from django.test.utils import isolate_apps
@ -8,8 +15,8 @@ from .models import (
GeneratedModel, GeneratedModel,
GeneratedModelNull, GeneratedModelNull,
GeneratedModelNullVirtual, GeneratedModelNullVirtual,
GeneratedModelOutputField, GeneratedModelOutputFieldDbCollation,
GeneratedModelOutputFieldVirtual, GeneratedModelOutputFieldDbCollationVirtual,
GeneratedModelParams, GeneratedModelParams,
GeneratedModelParamsVirtual, GeneratedModelParamsVirtual,
GeneratedModelVirtual, GeneratedModelVirtual,
@ -19,41 +26,77 @@ from .models import (
class BaseGeneratedFieldTests(SimpleTestCase): class BaseGeneratedFieldTests(SimpleTestCase):
def test_editable_unsupported(self): def test_editable_unsupported(self):
with self.assertRaisesMessage(ValueError, "GeneratedField cannot be editable."): with self.assertRaisesMessage(ValueError, "GeneratedField cannot be editable."):
GeneratedField(expression=Lower("name"), editable=True, db_persist=False) GeneratedField(
expression=Lower("name"),
output_field=CharField(max_length=255),
editable=True,
db_persist=False,
)
def test_blank_unsupported(self): def test_blank_unsupported(self):
with self.assertRaisesMessage(ValueError, "GeneratedField must be blank."): with self.assertRaisesMessage(ValueError, "GeneratedField must be blank."):
GeneratedField(expression=Lower("name"), blank=False, db_persist=False) GeneratedField(
expression=Lower("name"),
output_field=CharField(max_length=255),
blank=False,
db_persist=False,
)
def test_default_unsupported(self): def test_default_unsupported(self):
msg = "GeneratedField cannot have a default." msg = "GeneratedField cannot have a default."
with self.assertRaisesMessage(ValueError, msg): with self.assertRaisesMessage(ValueError, msg):
GeneratedField(expression=Lower("name"), default="", db_persist=False) GeneratedField(
expression=Lower("name"),
output_field=CharField(max_length=255),
default="",
db_persist=False,
)
def test_database_default_unsupported(self): def test_database_default_unsupported(self):
msg = "GeneratedField cannot have a database default." msg = "GeneratedField cannot have a database default."
with self.assertRaisesMessage(ValueError, msg): with self.assertRaisesMessage(ValueError, msg):
GeneratedField(expression=Lower("name"), db_default="", db_persist=False) GeneratedField(
expression=Lower("name"),
output_field=CharField(max_length=255),
db_default="",
db_persist=False,
)
def test_db_persist_required(self): def test_db_persist_required(self):
msg = "GeneratedField.db_persist must be True or False." msg = "GeneratedField.db_persist must be True or False."
with self.assertRaisesMessage(ValueError, msg): with self.assertRaisesMessage(ValueError, msg):
GeneratedField(expression=Lower("name")) GeneratedField(
expression=Lower("name"), output_field=CharField(max_length=255)
)
with self.assertRaisesMessage(ValueError, msg): with self.assertRaisesMessage(ValueError, msg):
GeneratedField(expression=Lower("name"), db_persist=None) GeneratedField(
expression=Lower("name"),
output_field=CharField(max_length=255),
db_persist=None,
)
def test_deconstruct(self): def test_deconstruct(self):
field = GeneratedField(expression=F("a") + F("b"), db_persist=True) field = GeneratedField(
expression=F("a") + F("b"), output_field=IntegerField(), db_persist=True
)
_, path, args, kwargs = field.deconstruct() _, path, args, kwargs = field.deconstruct()
self.assertEqual(path, "django.db.models.GeneratedField") self.assertEqual(path, "django.db.models.GeneratedField")
self.assertEqual(args, []) self.assertEqual(args, [])
self.assertEqual(kwargs, {"db_persist": True, "expression": F("a") + F("b")}) self.assertEqual(kwargs["db_persist"], True)
self.assertEqual(kwargs["expression"], F("a") + F("b"))
self.assertEqual(
kwargs["output_field"].deconstruct(), IntegerField().deconstruct()
)
@isolate_apps("model_fields") @isolate_apps("model_fields")
def test_get_col(self): def test_get_col(self):
class Square(Model): class Square(Model):
side = IntegerField() side = IntegerField()
area = GeneratedField(expression=F("side") * F("side"), db_persist=True) area = GeneratedField(
expression=F("side") * F("side"),
output_field=IntegerField(),
db_persist=True,
)
col = Square._meta.get_field("area").get_col("alias") col = Square._meta.get_field("area").get_col("alias")
self.assertIsInstance(col.output_field, IntegerField) self.assertIsInstance(col.output_field, IntegerField)
@ -74,7 +117,9 @@ class BaseGeneratedFieldTests(SimpleTestCase):
class Sum(Model): class Sum(Model):
a = IntegerField() a = IntegerField()
b = IntegerField() b = IntegerField()
total = GeneratedField(expression=F("a") + F("b"), db_persist=True) total = GeneratedField(
expression=F("a") + F("b"), output_field=IntegerField(), db_persist=True
)
field = Sum._meta.get_field("total") field = Sum._meta.get_field("total")
cached_col = field.cached_col cached_col = field.cached_col
@ -165,9 +210,9 @@ class GeneratedFieldTestMixin:
with self.assertNumQueries(0), self.assertRaises(does_not_exist): with self.assertNumQueries(0), self.assertRaises(does_not_exist):
self.base_model.objects.get(field__gte=overflow_value) self.base_model.objects.get(field__gte=overflow_value)
def test_output_field(self): def test_output_field_db_collation(self):
collation = connection.features.test_collations["virtual"] collation = connection.features.test_collations["virtual"]
m = self.output_field_model.objects.create(name="NAME") m = self.output_field_db_collation_model.objects.create(name="NAME")
field = m._meta.get_field("lower_name") field = m._meta.get_field("lower_name")
db_parameters = field.db_parameters(connection) db_parameters = field.db_parameters(connection)
self.assertEqual(db_parameters["collation"], collation) self.assertEqual(db_parameters["collation"], collation)
@ -178,7 +223,7 @@ class GeneratedFieldTestMixin:
) )
def test_db_type_parameters(self): def test_db_type_parameters(self):
db_type_parameters = self.output_field_model._meta.get_field( db_type_parameters = self.output_field_db_collation_model._meta.get_field(
"lower_name" "lower_name"
).db_type_parameters(connection) ).db_type_parameters(connection)
self.assertEqual(db_type_parameters["max_length"], 11) self.assertEqual(db_type_parameters["max_length"], 11)
@ -202,7 +247,7 @@ class GeneratedFieldTestMixin:
class StoredGeneratedFieldTests(GeneratedFieldTestMixin, TestCase): class StoredGeneratedFieldTests(GeneratedFieldTestMixin, TestCase):
base_model = GeneratedModel base_model = GeneratedModel
nullable_model = GeneratedModelNull nullable_model = GeneratedModelNull
output_field_model = GeneratedModelOutputField output_field_db_collation_model = GeneratedModelOutputFieldDbCollation
params_model = GeneratedModelParams params_model = GeneratedModelParams
@ -210,5 +255,5 @@ class StoredGeneratedFieldTests(GeneratedFieldTestMixin, TestCase):
class VirtualGeneratedFieldTests(GeneratedFieldTestMixin, TestCase): class VirtualGeneratedFieldTests(GeneratedFieldTestMixin, TestCase):
base_model = GeneratedModelVirtual base_model = GeneratedModelVirtual
nullable_model = GeneratedModelNullVirtual nullable_model = GeneratedModelNullVirtual
output_field_model = GeneratedModelOutputFieldVirtual output_field_db_collation_model = GeneratedModelOutputFieldDbCollationVirtual
params_model = GeneratedModelParamsVirtual params_model = GeneratedModelParamsVirtual

View File

@ -829,7 +829,11 @@ class SchemaTests(TransactionTestCase):
def test_add_generated_field_with_kt_model(self): def test_add_generated_field_with_kt_model(self):
class GeneratedFieldKTModel(Model): class GeneratedFieldKTModel(Model):
data = JSONField() data = JSONField()
status = GeneratedField(expression=KT("data__status"), db_persist=True) status = GeneratedField(
expression=KT("data__status"),
output_field=TextField(),
db_persist=True,
)
class Meta: class Meta:
app_label = "schema" app_label = "schema"
@ -844,7 +848,7 @@ class SchemaTests(TransactionTestCase):
@isolate_apps("schema") @isolate_apps("schema")
@skipUnlessDBFeature("supports_stored_generated_columns") @skipUnlessDBFeature("supports_stored_generated_columns")
def test_add_generated_field_with_output_field(self): def test_add_generated_field(self):
class GeneratedFieldOutputFieldModel(Model): class GeneratedFieldOutputFieldModel(Model):
price = DecimalField(max_digits=7, decimal_places=2) price = DecimalField(max_digits=7, decimal_places=2)
vat_price = GeneratedField( vat_price = GeneratedField(