diff --git a/django/db/models/fields/generated.py b/django/db/models/fields/generated.py index 0980be98af4..deb5875638c 100644 --- a/django/db/models/fields/generated.py +++ b/django/db/models/fields/generated.py @@ -1,6 +1,7 @@ from django.core import checks from django.db import connections, router from django.db.models.sql import Query +from django.utils.functional import cached_property from . import NOT_PROVIDED, Field @@ -32,6 +33,17 @@ class GeneratedField(Field): self.db_persist = db_persist super().__init__(**kwargs) + @cached_property + def cached_col(self): + from django.db.models.expressions import Col + + return Col(self.model._meta.db_table, self, self.output_field) + + def get_col(self, alias, output_field=None): + if alias != self.model._meta.db_table and output_field is None: + output_field = self.output_field + return super().get_col(alias, output_field) + def contribute_to_class(self, *args, **kwargs): super().contribute_to_class(*args, **kwargs) diff --git a/tests/model_fields/test_generatedfield.py b/tests/model_fields/test_generatedfield.py index e2746bdd0cd..dec1f3a31fd 100644 --- a/tests/model_fields/test_generatedfield.py +++ b/tests/model_fields/test_generatedfield.py @@ -1,6 +1,6 @@ from django.core.exceptions import FieldError from django.db import IntegrityError, connection -from django.db.models import F, GeneratedField, IntegerField +from django.db.models import F, FloatField, GeneratedField, IntegerField, Model from django.db.models.functions import Lower from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature @@ -49,6 +49,40 @@ class BaseGeneratedFieldTests(SimpleTestCase): self.assertEqual(args, []) self.assertEqual(kwargs, {"db_persist": True, "expression": F("a") + F("b")}) + def test_get_col(self): + class Square(Model): + side = IntegerField() + area = GeneratedField(expression=F("side") * F("side"), db_persist=True) + + col = Square._meta.get_field("area").get_col("alias") + self.assertIsInstance(col.output_field, IntegerField) + + class FloatSquare(Model): + side = IntegerField() + area = GeneratedField( + expression=F("side") * F("side"), + db_persist=True, + output_field=FloatField(), + ) + + col = FloatSquare._meta.get_field("area").get_col("alias") + self.assertIsInstance(col.output_field, FloatField) + + def test_cached_col(self): + class Sum(Model): + a = IntegerField() + b = IntegerField() + total = GeneratedField(expression=F("a") + F("b"), db_persist=True) + + field = Sum._meta.get_field("total") + cached_col = field.cached_col + self.assertIs(field.get_col(Sum._meta.db_table), cached_col) + self.assertIs(field.get_col(Sum._meta.db_table, field), cached_col) + self.assertIsNot(field.get_col("alias"), cached_col) + self.assertIsNot(field.get_col(Sum._meta.db_table, IntegerField()), cached_col) + self.assertIs(cached_col.target, field) + self.assertIsInstance(cached_col.output_field, IntegerField) + class GeneratedFieldTestMixin: def _refresh_if_needed(self, m):