Fixed #34944 -- Made GeneratedField.output_field required.

Regression in f333e3513e.
This commit is contained in:
Mariusz Felisiak 2023-11-13 05:33:25 +01:00
parent de4884b114
commit 5875f03ce6
9 changed files with 150 additions and 54 deletions

View File

@ -16,7 +16,7 @@ class GeneratedField(Field):
_resolved_expression = None
output_field = None
def __init__(self, *, expression, db_persist=None, output_field=None, **kwargs):
def __init__(self, *, expression, output_field, db_persist=None, **kwargs):
if kwargs.setdefault("editable", False):
raise ValueError("GeneratedField cannot be editable.")
if not kwargs.setdefault("blank", True):
@ -29,7 +29,7 @@ class GeneratedField(Field):
raise ValueError("GeneratedField.db_persist must be True or False.")
self.expression = expression
self._output_field = output_field
self.output_field = output_field
self.db_persist = db_persist
super().__init__(**kwargs)
@ -51,11 +51,6 @@ class GeneratedField(Field):
self._resolved_expression = self.expression.resolve_expression(
self._query, allow_joins=False
)
self.output_field = (
self._output_field
if self._output_field is not None
else self._resolved_expression.output_field
)
# Register lookups from the output_field class.
for lookup_name, lookup in self.output_field.get_class_lookups().items():
self.register_lookup(lookup, lookup_name=lookup_name)
@ -150,8 +145,7 @@ class GeneratedField(Field):
del kwargs["editable"]
kwargs["db_persist"] = self.db_persist
kwargs["expression"] = self.expression
if self._output_field is not None:
kwargs["output_field"] = self._output_field
kwargs["output_field"] = self.output_field
return name, path, args, kwargs
def get_internal_type(self):

View File

@ -1237,7 +1237,7 @@ when :attr:`~django.forms.Field.localize` is ``False`` or
.. versionadded:: 5.0
.. class:: GeneratedField(expression, db_persist=None, output_field=None, **kwargs)
.. class:: GeneratedField(expression, output_field, db_persist=None, **kwargs)
A field that is always computed based on other fields in the model. This field
is managed and updated by the database itself. Uses the ``GENERATED ALWAYS``
@ -1259,6 +1259,10 @@ materialized view.
the model (in the same database table). Generated fields cannot reference
other generated fields. Database backends can impose further restrictions.
.. attribute:: GeneratedField.output_field
A model field instance to define the field's data type.
.. attribute:: GeneratedField.db_persist
Determines if the database column should occupy storage as if it were a
@ -1268,12 +1272,6 @@ materialized view.
PostgreSQL only supports persisted columns. Oracle only supports virtual
columns.
.. attribute:: GeneratedField.output_field
An optional model field instance to define the field's data type. This can
be used to customize attributes like the field's collation. By default, the
output field is derived from ``expression``.
.. admonition:: Refresh the data
Since the database always computed the value, the object must be reloaded

View File

@ -142,7 +142,11 @@ to create a field that is always computed from other fields. For example::
class Square(models.Model):
side = models.IntegerField()
area = models.GeneratedField(expression=F("side") * F("side"), db_persist=True)
area = models.GeneratedField(
expression=F("side") * F("side"),
output_field=models.BigIntegerField(),
db_persist=True,
)
More options for declaring field choices
----------------------------------------

View File

@ -1147,6 +1147,7 @@ class Square(models.Model):
area = models.GeneratedField(
db_persist=True,
expression=models.F("side") * models.F("side"),
output_field=models.BigIntegerField(),
)
class Meta:

View File

@ -1216,7 +1216,9 @@ class GeneratedFieldTests(TestCase):
class Model(models.Model):
name = models.IntegerField()
field = models.GeneratedField(
expression=models.F("name"), db_persist=db_persist
expression=models.F("name"),
output_field=models.IntegerField(),
db_persist=db_persist,
)
expected_errors = []
@ -1252,7 +1254,11 @@ class GeneratedFieldTests(TestCase):
def test_not_supported_stored_required_db_features(self):
class Model(models.Model):
name = models.IntegerField()
field = models.GeneratedField(expression=models.F("name"), db_persist=True)
field = models.GeneratedField(
expression=models.F("name"),
output_field=models.IntegerField(),
db_persist=True,
)
class Meta:
required_db_features = {"supports_stored_generated_columns"}
@ -1262,7 +1268,11 @@ class GeneratedFieldTests(TestCase):
def test_not_supported_virtual_required_db_features(self):
class Model(models.Model):
name = models.IntegerField()
field = models.GeneratedField(expression=models.F("name"), db_persist=False)
field = models.GeneratedField(
expression=models.F("name"),
output_field=models.IntegerField(),
db_persist=False,
)
class Meta:
required_db_features = {"supports_virtual_generated_columns"}
@ -1273,7 +1283,11 @@ class GeneratedFieldTests(TestCase):
def test_not_supported_virtual(self):
class Model(models.Model):
name = models.IntegerField()
field = models.GeneratedField(expression=models.F("name"), db_persist=False)
field = models.GeneratedField(
expression=models.F("name"),
output_field=models.IntegerField(),
db_persist=False,
)
a = models.TextField()
excepted_errors = (
@ -1298,7 +1312,11 @@ class GeneratedFieldTests(TestCase):
def test_not_supported_stored(self):
class Model(models.Model):
name = models.IntegerField()
field = models.GeneratedField(expression=models.F("name"), db_persist=True)
field = models.GeneratedField(
expression=models.F("name"),
output_field=models.IntegerField(),
db_persist=True,
)
a = models.TextField()
expected_errors = (

View File

@ -5664,10 +5664,14 @@ class OperationTests(OperationTestBase):
def _test_invalid_generated_field_changes(self, db_persist):
regular = models.IntegerField(default=1)
generated_1 = models.GeneratedField(
expression=F("pink") + F("pink"), db_persist=db_persist
expression=F("pink") + F("pink"),
output_field=models.IntegerField(),
db_persist=db_persist,
)
generated_2 = models.GeneratedField(
expression=F("pink") + F("pink") + F("pink"), db_persist=db_persist
expression=F("pink") + F("pink") + F("pink"),
output_field=models.IntegerField(),
db_persist=db_persist,
)
tests = [
("test_igfc_1", regular, generated_1),
@ -5707,12 +5711,20 @@ class OperationTests(OperationTestBase):
migrations.AddField(
"Pony",
"modified_pink",
models.GeneratedField(expression=F("pink"), db_persist=True),
models.GeneratedField(
expression=F("pink"),
output_field=models.IntegerField(),
db_persist=True,
),
),
migrations.AlterField(
"Pony",
"modified_pink",
models.GeneratedField(expression=F("pink"), db_persist=False),
models.GeneratedField(
expression=F("pink"),
output_field=models.IntegerField(),
db_persist=False,
),
),
]
msg = (
@ -5729,7 +5741,9 @@ class OperationTests(OperationTestBase):
"Pony",
"modified_pink",
models.GeneratedField(
expression=F("pink") + F("pink"), db_persist=db_persist
expression=F("pink") + F("pink"),
output_field=models.IntegerField(),
db_persist=db_persist,
),
)
project_state, new_state = self.make_test_state(app_label, operation)
@ -5760,7 +5774,9 @@ class OperationTests(OperationTestBase):
"Pony",
"modified_pink",
models.GeneratedField(
expression=F("pink") + F("pink"), db_persist=db_persist
expression=F("pink") + F("pink"),
output_field=models.IntegerField(),
db_persist=db_persist,
),
)
project_state, new_state = self.make_test_state(app_label, operation)

View File

@ -485,7 +485,11 @@ class UUIDGrandchild(UUIDChild):
class GeneratedModel(models.Model):
a = models.IntegerField()
b = models.IntegerField()
field = models.GeneratedField(expression=F("a") + F("b"), db_persist=True)
field = models.GeneratedField(
expression=F("a") + F("b"),
output_field=models.IntegerField(),
db_persist=True,
)
class Meta:
required_db_features = {"supports_stored_generated_columns"}
@ -494,7 +498,11 @@ class GeneratedModel(models.Model):
class GeneratedModelVirtual(models.Model):
a = models.IntegerField()
b = models.IntegerField()
field = models.GeneratedField(expression=F("a") + F("b"), db_persist=False)
field = models.GeneratedField(
expression=F("a") + F("b"),
output_field=models.IntegerField(),
db_persist=False,
)
class Meta:
required_db_features = {"supports_virtual_generated_columns"}
@ -503,6 +511,7 @@ class GeneratedModelVirtual(models.Model):
class GeneratedModelParams(models.Model):
field = models.GeneratedField(
expression=Value("Constant", output_field=models.CharField(max_length=10)),
output_field=models.CharField(max_length=10),
db_persist=True,
)
@ -513,6 +522,7 @@ class GeneratedModelParams(models.Model):
class GeneratedModelParamsVirtual(models.Model):
field = models.GeneratedField(
expression=Value("Constant", output_field=models.CharField(max_length=10)),
output_field=models.CharField(max_length=10),
db_persist=False,
)
@ -520,7 +530,7 @@ class GeneratedModelParamsVirtual(models.Model):
required_db_features = {"supports_virtual_generated_columns"}
class GeneratedModelOutputField(models.Model):
class GeneratedModelOutputFieldDbCollation(models.Model):
name = models.CharField(max_length=10)
lower_name = models.GeneratedField(
expression=Lower("name"),
@ -532,7 +542,7 @@ class GeneratedModelOutputField(models.Model):
required_db_features = {"supports_stored_generated_columns"}
class GeneratedModelOutputFieldVirtual(models.Model):
class GeneratedModelOutputFieldDbCollationVirtual(models.Model):
name = models.CharField(max_length=10)
lower_name = models.GeneratedField(
expression=Lower("name"),
@ -547,7 +557,10 @@ class GeneratedModelOutputFieldVirtual(models.Model):
class GeneratedModelNull(models.Model):
name = models.CharField(max_length=10, null=True)
lower_name = models.GeneratedField(
expression=Lower("name"), db_persist=True, null=True
expression=Lower("name"),
output_field=models.CharField(max_length=10),
db_persist=True,
null=True,
)
class Meta:
@ -557,7 +570,10 @@ class GeneratedModelNull(models.Model):
class GeneratedModelNullVirtual(models.Model):
name = models.CharField(max_length=10, null=True)
lower_name = models.GeneratedField(
expression=Lower("name"), db_persist=False, null=True
expression=Lower("name"),
output_field=models.CharField(max_length=10),
db_persist=False,
null=True,
)
class Meta:

View File

@ -1,5 +1,12 @@
from django.db import IntegrityError, connection
from django.db.models import F, FloatField, GeneratedField, IntegerField, Model
from django.db.models import (
CharField,
F,
FloatField,
GeneratedField,
IntegerField,
Model,
)
from django.db.models.functions import Lower
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from django.test.utils import isolate_apps
@ -8,8 +15,8 @@ from .models import (
GeneratedModel,
GeneratedModelNull,
GeneratedModelNullVirtual,
GeneratedModelOutputField,
GeneratedModelOutputFieldVirtual,
GeneratedModelOutputFieldDbCollation,
GeneratedModelOutputFieldDbCollationVirtual,
GeneratedModelParams,
GeneratedModelParamsVirtual,
GeneratedModelVirtual,
@ -19,41 +26,77 @@ from .models import (
class BaseGeneratedFieldTests(SimpleTestCase):
def test_editable_unsupported(self):
with self.assertRaisesMessage(ValueError, "GeneratedField cannot be editable."):
GeneratedField(expression=Lower("name"), editable=True, db_persist=False)
GeneratedField(
expression=Lower("name"),
output_field=CharField(max_length=255),
editable=True,
db_persist=False,
)
def test_blank_unsupported(self):
with self.assertRaisesMessage(ValueError, "GeneratedField must be blank."):
GeneratedField(expression=Lower("name"), blank=False, db_persist=False)
GeneratedField(
expression=Lower("name"),
output_field=CharField(max_length=255),
blank=False,
db_persist=False,
)
def test_default_unsupported(self):
msg = "GeneratedField cannot have a default."
with self.assertRaisesMessage(ValueError, msg):
GeneratedField(expression=Lower("name"), default="", db_persist=False)
GeneratedField(
expression=Lower("name"),
output_field=CharField(max_length=255),
default="",
db_persist=False,
)
def test_database_default_unsupported(self):
msg = "GeneratedField cannot have a database default."
with self.assertRaisesMessage(ValueError, msg):
GeneratedField(expression=Lower("name"), db_default="", db_persist=False)
GeneratedField(
expression=Lower("name"),
output_field=CharField(max_length=255),
db_default="",
db_persist=False,
)
def test_db_persist_required(self):
msg = "GeneratedField.db_persist must be True or False."
with self.assertRaisesMessage(ValueError, msg):
GeneratedField(expression=Lower("name"))
GeneratedField(
expression=Lower("name"), output_field=CharField(max_length=255)
)
with self.assertRaisesMessage(ValueError, msg):
GeneratedField(expression=Lower("name"), db_persist=None)
GeneratedField(
expression=Lower("name"),
output_field=CharField(max_length=255),
db_persist=None,
)
def test_deconstruct(self):
field = GeneratedField(expression=F("a") + F("b"), db_persist=True)
field = GeneratedField(
expression=F("a") + F("b"), output_field=IntegerField(), db_persist=True
)
_, path, args, kwargs = field.deconstruct()
self.assertEqual(path, "django.db.models.GeneratedField")
self.assertEqual(args, [])
self.assertEqual(kwargs, {"db_persist": True, "expression": F("a") + F("b")})
self.assertEqual(kwargs["db_persist"], True)
self.assertEqual(kwargs["expression"], F("a") + F("b"))
self.assertEqual(
kwargs["output_field"].deconstruct(), IntegerField().deconstruct()
)
@isolate_apps("model_fields")
def test_get_col(self):
class Square(Model):
side = IntegerField()
area = GeneratedField(expression=F("side") * F("side"), db_persist=True)
area = GeneratedField(
expression=F("side") * F("side"),
output_field=IntegerField(),
db_persist=True,
)
col = Square._meta.get_field("area").get_col("alias")
self.assertIsInstance(col.output_field, IntegerField)
@ -74,7 +117,9 @@ class BaseGeneratedFieldTests(SimpleTestCase):
class Sum(Model):
a = IntegerField()
b = IntegerField()
total = GeneratedField(expression=F("a") + F("b"), db_persist=True)
total = GeneratedField(
expression=F("a") + F("b"), output_field=IntegerField(), db_persist=True
)
field = Sum._meta.get_field("total")
cached_col = field.cached_col
@ -165,9 +210,9 @@ class GeneratedFieldTestMixin:
with self.assertNumQueries(0), self.assertRaises(does_not_exist):
self.base_model.objects.get(field__gte=overflow_value)
def test_output_field(self):
def test_output_field_db_collation(self):
collation = connection.features.test_collations["virtual"]
m = self.output_field_model.objects.create(name="NAME")
m = self.output_field_db_collation_model.objects.create(name="NAME")
field = m._meta.get_field("lower_name")
db_parameters = field.db_parameters(connection)
self.assertEqual(db_parameters["collation"], collation)
@ -178,7 +223,7 @@ class GeneratedFieldTestMixin:
)
def test_db_type_parameters(self):
db_type_parameters = self.output_field_model._meta.get_field(
db_type_parameters = self.output_field_db_collation_model._meta.get_field(
"lower_name"
).db_type_parameters(connection)
self.assertEqual(db_type_parameters["max_length"], 11)
@ -202,7 +247,7 @@ class GeneratedFieldTestMixin:
class StoredGeneratedFieldTests(GeneratedFieldTestMixin, TestCase):
base_model = GeneratedModel
nullable_model = GeneratedModelNull
output_field_model = GeneratedModelOutputField
output_field_db_collation_model = GeneratedModelOutputFieldDbCollation
params_model = GeneratedModelParams
@ -210,5 +255,5 @@ class StoredGeneratedFieldTests(GeneratedFieldTestMixin, TestCase):
class VirtualGeneratedFieldTests(GeneratedFieldTestMixin, TestCase):
base_model = GeneratedModelVirtual
nullable_model = GeneratedModelNullVirtual
output_field_model = GeneratedModelOutputFieldVirtual
output_field_db_collation_model = GeneratedModelOutputFieldDbCollationVirtual
params_model = GeneratedModelParamsVirtual

View File

@ -829,7 +829,11 @@ class SchemaTests(TransactionTestCase):
def test_add_generated_field_with_kt_model(self):
class GeneratedFieldKTModel(Model):
data = JSONField()
status = GeneratedField(expression=KT("data__status"), db_persist=True)
status = GeneratedField(
expression=KT("data__status"),
output_field=TextField(),
db_persist=True,
)
class Meta:
app_label = "schema"
@ -844,7 +848,7 @@ class SchemaTests(TransactionTestCase):
@isolate_apps("schema")
@skipUnlessDBFeature("supports_stored_generated_columns")
def test_add_generated_field_with_output_field(self):
def test_add_generated_field(self):
class GeneratedFieldOutputFieldModel(Model):
price = DecimalField(max_digits=7, decimal_places=2)
vat_price = GeneratedField(