Fixed #25534, Fixed #31639 -- Added support for transform references in expressions.

Thanks Mariusz Felisiak and Simon Charette for reviews.
This commit is contained in:
Ian Foote 2020-11-15 22:43:47 +00:00 committed by Mariusz Felisiak
parent e46ca51c24
commit 8b040e3cbb
14 changed files with 354 additions and 82 deletions

View File

@ -1610,6 +1610,8 @@ class Query(BaseExpression):
# fields to the appropriate wrapped version. # fields to the appropriate wrapped version.
def final_transformer(field, alias): def final_transformer(field, alias):
if not self.alias_cols:
alias = None
return field.get_col(alias) return field.get_col(alias)
# Try resolving all the names as fields first. If there's an error, # Try resolving all the names as fields first. If there's an error,
@ -1714,8 +1716,6 @@ class Query(BaseExpression):
yield from (expr.alias for expr in cls._gen_cols(exprs)) yield from (expr.alias for expr in cls._gen_cols(exprs))
def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False): def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False):
if not allow_joins and LOOKUP_SEP in name:
raise FieldError("Joined field references are not permitted in this query")
annotation = self.annotations.get(name) annotation = self.annotations.get(name)
if annotation is not None: if annotation is not None:
if not allow_joins: if not allow_joins:
@ -1740,6 +1740,11 @@ class Query(BaseExpression):
return annotation return annotation
else: else:
field_list = name.split(LOOKUP_SEP) field_list = name.split(LOOKUP_SEP)
annotation = self.annotations.get(field_list[0])
if annotation is not None:
for transform in field_list[1:]:
annotation = self.try_transform(annotation, transform)
return annotation
join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), can_reuse=reuse) join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), can_reuse=reuse)
targets, final_alias, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path) targets, final_alias, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path)
if not allow_joins and len(join_list) > 1: if not allow_joins and len(join_list) > 1:
@ -1749,10 +1754,10 @@ class Query(BaseExpression):
"isn't supported") "isn't supported")
# Verify that the last lookup in name is a field or a transform: # Verify that the last lookup in name is a field or a transform:
# transform_function() raises FieldError if not. # transform_function() raises FieldError if not.
join_info.transform_function(targets[0], final_alias) transform = join_info.transform_function(targets[0], final_alias)
if reuse is not None: if reuse is not None:
reuse.update(join_list) reuse.update(join_list)
return self._get_col(targets[0], join_info.targets[0], join_list[-1]) return transform
def split_exclude(self, filter_expr, can_reuse, names_with_path): def split_exclude(self, filter_expr, can_reuse, names_with_path):
""" """

View File

@ -90,10 +90,10 @@ Built-in Expressions
.. class:: F .. class:: F
An ``F()`` object represents the value of a model field or annotated column. It An ``F()`` object represents the value of a model field, transformed value of a
makes it possible to refer to model field values and perform database model field, or annotated column. It makes it possible to refer to model field
operations using them without actually having to pull them out of the database values and perform database operations using them without actually having to
into Python memory. pull them out of the database into Python memory.
Instead, Django uses the ``F()`` object to generate an SQL expression that Instead, Django uses the ``F()`` object to generate an SQL expression that
describes the required operation at the database level. describes the required operation at the database level.
@ -155,6 +155,10 @@ the field value of each one, and saving each one back to the database::
* getting the database, rather than Python, to do work * getting the database, rather than Python, to do work
* reducing the number of queries some operations require * reducing the number of queries some operations require
.. versionchanged:: 3.2
Support for transforms of the field was added.
.. _avoiding-race-conditions-using-f: .. _avoiding-race-conditions-using-f:
Avoiding race conditions using ``F()`` Avoiding race conditions using ``F()``
@ -406,9 +410,9 @@ The ``Aggregate`` API is as follows:
allows passing a ``distinct`` keyword argument. If set to ``False`` allows passing a ``distinct`` keyword argument. If set to ``False``
(default), ``TypeError`` is raised if ``distinct=True`` is passed. (default), ``TypeError`` is raised if ``distinct=True`` is passed.
The ``expressions`` positional arguments can include expressions or the names The ``expressions`` positional arguments can include expressions, transforms of
of model fields. They will be converted to a string and used as the the model field, or the names of model fields. They will be converted to a
``expressions`` placeholder within the ``template``. string and used as the ``expressions`` placeholder within the ``template``.
The ``output_field`` argument requires a model field instance, like The ``output_field`` argument requires a model field instance, like
``IntegerField()`` or ``BooleanField()``, into which Django will load the value ``IntegerField()`` or ``BooleanField()``, into which Django will load the value
@ -435,6 +439,10 @@ and :ref:`filtering-on-annotations` for example usage.
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
into the ``template`` attribute. into the ``template`` attribute.
.. versionchanged:: 3.2
Support for transforms of the field was added.
Creating your own Aggregate Functions Creating your own Aggregate Functions
------------------------------------- -------------------------------------
@ -551,9 +559,9 @@ Referencing columns from the outer queryset
.. class:: OuterRef(field) .. class:: OuterRef(field)
Use ``OuterRef`` when a queryset in a ``Subquery`` needs to refer to a field Use ``OuterRef`` when a queryset in a ``Subquery`` needs to refer to a field
from the outer query. It acts like an :class:`F` expression except that the from the outer query or its transform. It acts like an :class:`F` expression
check to see if it refers to a valid field isn't made until the outer queryset except that the check to see if it refers to a valid field isn't made until the
is resolved. outer queryset is resolved.
Instances of ``OuterRef`` may be used in conjunction with nested instances Instances of ``OuterRef`` may be used in conjunction with nested instances
of ``Subquery`` to refer to a containing queryset that isn't the immediate of ``Subquery`` to refer to a containing queryset that isn't the immediate
@ -562,6 +570,10 @@ parent. For example, this queryset would need to be within a nested pair of
>>> Book.objects.filter(author=OuterRef(OuterRef('pk'))) >>> Book.objects.filter(author=OuterRef(OuterRef('pk')))
.. versionchanged:: 3.2
Support for transforms of the field was added.
Limiting a subquery to a single column Limiting a subquery to a single column
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -3525,8 +3525,12 @@ All aggregates have the following parameters in common:
``expressions`` ``expressions``
~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~
Strings that reference fields on the model, or :doc:`query expressions Strings that reference fields on the model, transforms of the field, or
</ref/models/expressions>`. :doc:`query expressions </ref/models/expressions>`.
.. versionchanged:: 3.2
Support for transforms of the field was added.
``output_field`` ``output_field``
~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~

View File

@ -351,6 +351,11 @@ Models
* Added the :class:`~django.db.models.functions.Random` database function. * Added the :class:`~django.db.models.functions.Random` database function.
* :ref:`aggregation-functions`, :class:`F() <django.db.models.F>`,
:class:`OuterRef() <django.db.models.OuterRef>`, and other expressions now
allow using transforms. See :ref:`using-transforms-in-expressions` for
details.
Pagination Pagination
~~~~~~~~~~ ~~~~~~~~~~

View File

@ -669,6 +669,36 @@ The ``F()`` objects support bitwise operations by ``.bitand()``, ``.bitor()``,
Support for ``.bitxor()`` was added. Support for ``.bitxor()`` was added.
.. _using-transforms-in-expressions:
Expressions can reference transforms
------------------------------------
.. versionadded: 3.2
Django supports using transforms in expressions.
For example, to find all ``Entry`` objects published in the same year as they
were last modified::
>>> Entry.objects.filter(pub_date__year=F('mod_date__year'))
To find the earliest year an entry was published, we can issue the query::
>>> Entry.objects.aggregate(first_published_year=Min('pub_date__year'))
This example finds the value of the highest rated entry and the total number
of comments on all entries for each year::
>>> Entry.objects.values('pub_date__year').annotate(
... top_rating=Subquery(
... Entry.objects.filter(
... pub_date__year=OuterRef('pub_date__year'),
... ).order_by('-rating').values('rating')[:1]
... ),
... total_comments=Sum('number_of_comments'),
... )
The ``pk`` lookup shortcut The ``pk`` lookup shortcut
-------------------------- --------------------------

View File

@ -151,6 +151,14 @@ class AggregateTestCase(TestCase):
vals = Store.objects.filter(name="Amazon.com").aggregate(amazon_mean=Avg("books__rating")) vals = Store.objects.filter(name="Amazon.com").aggregate(amazon_mean=Avg("books__rating"))
self.assertEqual(vals, {'amazon_mean': Approximate(4.08, places=2)}) self.assertEqual(vals, {'amazon_mean': Approximate(4.08, places=2)})
def test_aggregate_transform(self):
vals = Store.objects.aggregate(min_month=Min('original_opening__month'))
self.assertEqual(vals, {'min_month': 3})
def test_aggregate_join_transform(self):
vals = Publisher.objects.aggregate(min_year=Min('book__pubdate__year'))
self.assertEqual(vals, {'min_year': 1991})
def test_annotate_basic(self): def test_annotate_basic(self):
self.assertQuerysetEqual( self.assertQuerysetEqual(
Book.objects.annotate().order_by('pk'), [ Book.objects.annotate().order_by('pk'), [

View File

@ -4,13 +4,16 @@ from decimal import Decimal
from django.core.exceptions import FieldDoesNotExist, FieldError from django.core.exceptions import FieldDoesNotExist, FieldError
from django.db import connection from django.db import connection
from django.db.models import ( from django.db.models import (
BooleanField, Case, Count, DateTimeField, Exists, ExpressionWrapper, F, BooleanField, Case, CharField, Count, DateTimeField, DecimalField, Exists,
FloatField, Func, IntegerField, Max, NullBooleanField, OuterRef, Q, ExpressionWrapper, F, FloatField, Func, IntegerField, Max,
Subquery, Sum, Value, When, NullBooleanField, OuterRef, Q, Subquery, Sum, Value, When,
) )
from django.db.models.expressions import RawSQL from django.db.models.expressions import RawSQL
from django.db.models.functions import Coalesce, ExtractYear, Length, Lower from django.db.models.functions import (
Coalesce, ExtractYear, Floor, Length, Lower, Trim,
)
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from django.test.utils import register_lookup
from .models import ( from .models import (
Author, Book, Company, DepartmentStore, Employee, Publisher, Store, Ticket, Author, Book, Company, DepartmentStore, Employee, Publisher, Store, Ticket,
@ -95,24 +98,24 @@ class NonAggregateAnnotationTestCase(TestCase):
cls.b5.authors.add(cls.a8, cls.a9) cls.b5.authors.add(cls.a8, cls.a9)
cls.b6.authors.add(cls.a8) cls.b6.authors.add(cls.a8)
s1 = Store.objects.create( cls.s1 = Store.objects.create(
name='Amazon.com', name='Amazon.com',
original_opening=datetime.datetime(1994, 4, 23, 9, 17, 42), original_opening=datetime.datetime(1994, 4, 23, 9, 17, 42),
friday_night_closing=datetime.time(23, 59, 59) friday_night_closing=datetime.time(23, 59, 59)
) )
s2 = Store.objects.create( cls.s2 = Store.objects.create(
name='Books.com', name='Books.com',
original_opening=datetime.datetime(2001, 3, 15, 11, 23, 37), original_opening=datetime.datetime(2001, 3, 15, 11, 23, 37),
friday_night_closing=datetime.time(23, 59, 59) friday_night_closing=datetime.time(23, 59, 59)
) )
s3 = Store.objects.create( cls.s3 = Store.objects.create(
name="Mamma and Pappa's Books", name="Mamma and Pappa's Books",
original_opening=datetime.datetime(1945, 4, 25, 16, 24, 14), original_opening=datetime.datetime(1945, 4, 25, 16, 24, 14),
friday_night_closing=datetime.time(21, 30) friday_night_closing=datetime.time(21, 30)
) )
s1.books.add(cls.b1, cls.b2, cls.b3, cls.b4, cls.b5, cls.b6) cls.s1.books.add(cls.b1, cls.b2, cls.b3, cls.b4, cls.b5, cls.b6)
s2.books.add(cls.b1, cls.b3, cls.b5, cls.b6) cls.s2.books.add(cls.b1, cls.b3, cls.b5, cls.b6)
s3.books.add(cls.b3, cls.b4, cls.b6) cls.s3.books.add(cls.b3, cls.b4, cls.b6)
def test_basic_annotation(self): def test_basic_annotation(self):
books = Book.objects.annotate(is_book=Value(1)) books = Book.objects.annotate(is_book=Value(1))
@ -130,6 +133,66 @@ class NonAggregateAnnotationTestCase(TestCase):
for book in books: for book in books:
self.assertEqual(book.num_awards, book.publisher.num_awards) self.assertEqual(book.num_awards, book.publisher.num_awards)
def test_joined_transformed_annotation(self):
Employee.objects.bulk_create([
Employee(
first_name='John',
last_name='Doe',
age=18,
store=self.s1,
salary=15000,
),
Employee(
first_name='Jane',
last_name='Jones',
age=30,
store=self.s2,
salary=30000,
),
Employee(
first_name='Jo',
last_name='Smith',
age=55,
store=self.s3,
salary=50000,
),
])
employees = Employee.objects.annotate(
store_opened_year=F('store__original_opening__year'),
)
for employee in employees:
self.assertEqual(
employee.store_opened_year,
employee.store.original_opening.year,
)
def test_custom_transform_annotation(self):
with register_lookup(DecimalField, Floor):
books = Book.objects.annotate(floor_price=F('price__floor'))
self.assertSequenceEqual(books.values_list('pk', 'floor_price'), [
(self.b1.pk, 30),
(self.b2.pk, 23),
(self.b3.pk, 29),
(self.b4.pk, 29),
(self.b5.pk, 82),
(self.b6.pk, 75),
])
def test_chaining_transforms(self):
Company.objects.create(name=' Django Software Foundation ')
Company.objects.create(name='Yahoo')
with register_lookup(CharField, Trim), register_lookup(CharField, Length):
for expr in [Length('name__trim'), F('name__trim__length')]:
with self.subTest(expr=expr):
self.assertCountEqual(
Company.objects.annotate(length=expr).values('name', 'length'),
[
{'name': ' Django Software Foundation ', 'length': 26},
{'name': 'Yahoo', 'length': 5},
],
)
def test_mixed_type_annotation_date_interval(self): def test_mixed_type_annotation_date_interval(self):
active = datetime.datetime(2015, 3, 20, 14, 0, 0) active = datetime.datetime(2015, 3, 20, 14, 0, 0)
duration = datetime.timedelta(hours=1) duration = datetime.timedelta(hours=1)
@ -689,6 +752,23 @@ class NonAggregateAnnotationTestCase(TestCase):
{'pub_year': 2008, 'top_rating': 4.0, 'total_pages': 1178}, {'pub_year': 2008, 'top_rating': 4.0, 'total_pages': 1178},
]) ])
def test_annotation_subquery_outerref_transform(self):
qs = Book.objects.annotate(
top_rating_year=Subquery(
Book.objects.filter(
pubdate__year=OuterRef('pubdate__year')
).order_by('-rating').values('rating')[:1]
),
).values('pubdate__year', 'top_rating_year')
self.assertCountEqual(qs, [
{'pubdate__year': 1991, 'top_rating_year': 5.0},
{'pubdate__year': 1995, 'top_rating_year': 4.0},
{'pubdate__year': 2007, 'top_rating_year': 4.5},
{'pubdate__year': 2008, 'top_rating_year': 4.0},
{'pubdate__year': 2008, 'top_rating_year': 4.0},
{'pubdate__year': 2008, 'top_rating_year': 4.0},
])
def test_annotation_aggregate_with_m2o(self): def test_annotation_aggregate_with_m2o(self):
if connection.vendor == 'mysql' and 'ONLY_FULL_GROUP_BY' in connection.sql_mode: if connection.vendor == 'mysql' and 'ONLY_FULL_GROUP_BY' in connection.sql_mode:
self.skipTest( self.skipTest(
@ -776,6 +856,15 @@ class AliasTests(TestCase):
with self.subTest(book=book): with self.subTest(book=book):
self.assertEqual(book.another_rating, book.rating) self.assertEqual(book.another_rating, book.rating)
def test_basic_alias_f_transform_annotation(self):
qs = Book.objects.alias(
pubdate_alias=F('pubdate'),
).annotate(pubdate_year=F('pubdate_alias__year'))
self.assertIs(hasattr(qs.first(), 'pubdate_alias'), False)
for book in qs:
with self.subTest(book=book):
self.assertEqual(book.pubdate_year, book.pubdate.year)
def test_alias_after_annotation(self): def test_alias_after_annotation(self):
qs = Book.objects.annotate( qs = Book.objects.annotate(
is_book=Value(1), is_book=Value(1),

View File

@ -25,7 +25,9 @@ from django.db.models.functions import (
from django.db.models.sql import constants from django.db.models.sql import constants
from django.db.models.sql.datastructures import Join from django.db.models.sql.datastructures import Join
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from django.test.utils import Approximate, CaptureQueriesContext, isolate_apps from django.test.utils import (
Approximate, CaptureQueriesContext, isolate_apps, register_lookup,
)
from django.utils.functional import SimpleLazyObject from django.utils.functional import SimpleLazyObject
from .models import ( from .models import (
@ -1216,6 +1218,12 @@ class ExpressionOperatorTests(TestCase):
self.assertEqual(Number.objects.get(pk=self.n.pk).integer, 58) self.assertEqual(Number.objects.get(pk=self.n.pk).integer, 58)
self.assertEqual(Number.objects.get(pk=self.n1.pk).integer, -10) self.assertEqual(Number.objects.get(pk=self.n1.pk).integer, -10)
def test_lefthand_transformed_field_bitwise_or(self):
Employee.objects.create(firstname='Max', lastname='Mustermann')
with register_lookup(CharField, Length):
qs = Employee.objects.annotate(bitor=F('lastname__length').bitor(48))
self.assertEqual(qs.get().bitor, 58)
def test_lefthand_power(self): def test_lefthand_power(self):
# LH Power arithmetic operation on floats and integers # LH Power arithmetic operation on floats and integers
Number.objects.filter(pk=self.n.pk).update(integer=F('integer') ** 2, float=F('float') ** 1.5) Number.objects.filter(pk=self.n.pk).update(integer=F('integer') ** 2, float=F('float') ** 1.5)

View File

@ -48,24 +48,35 @@ class WindowFunctionTests(TestCase):
]) ])
def test_dense_rank(self): def test_dense_rank(self):
qs = Employee.objects.annotate(rank=Window( tests = [
expression=DenseRank(), ExtractYear(F('hire_date')).asc(),
order_by=ExtractYear(F('hire_date')).asc(), F('hire_date__year').asc(),
)) ]
self.assertQuerysetEqual(qs, [ for order_by in tests:
('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 1), with self.subTest(order_by=order_by):
('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 1), qs = Employee.objects.annotate(
('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 1), rank=Window(expression=DenseRank(), order_by=order_by),
('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 2), )
('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 3), self.assertQuerysetEqual(qs, [
('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 4), ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 1),
('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 4), ('Miller', 100000, 'Management', datetime.date(2005, 6, 1), 1),
('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 4), ('Johnson', 80000, 'Management', datetime.date(2005, 7, 1), 1),
('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 5), ('Smith', 55000, 'Sales', datetime.date(2007, 6, 1), 2),
('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 6), ('Jenson', 45000, 'Accounting', datetime.date(2008, 4, 1), 3),
('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 7), ('Smith', 38000, 'Marketing', datetime.date(2009, 10, 1), 4),
('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 7), ('Brown', 53000, 'Sales', datetime.date(2009, 9, 1), 4),
], lambda entry: (entry.name, entry.salary, entry.department, entry.hire_date, entry.rank), ordered=False) ('Williams', 37000, 'Accounting', datetime.date(2009, 6, 1), 4),
('Wilkinson', 60000, 'IT', datetime.date(2011, 3, 1), 5),
('Johnson', 40000, 'Marketing', datetime.date(2012, 3, 1), 6),
('Moore', 34000, 'IT', datetime.date(2013, 8, 1), 7),
('Adams', 50000, 'Accounting', datetime.date(2013, 7, 1), 7),
], lambda entry: (
entry.name,
entry.salary,
entry.department,
entry.hire_date,
entry.rank,
), ordered=False)
def test_department_salary(self): def test_department_salary(self):
qs = Employee.objects.annotate(department_sum=Window( qs = Employee.objects.annotate(department_sum=Window(
@ -96,7 +107,7 @@ class WindowFunctionTests(TestCase):
""" """
qs = Employee.objects.annotate(rank=Window( qs = Employee.objects.annotate(rank=Window(
expression=Rank(), expression=Rank(),
order_by=ExtractYear(F('hire_date')).asc(), order_by=F('hire_date__year').asc(),
)) ))
self.assertQuerysetEqual(qs, [ self.assertQuerysetEqual(qs, [
('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 1), ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 1),
@ -523,7 +534,7 @@ class WindowFunctionTests(TestCase):
""" """
qs = Employee.objects.annotate(max=Window( qs = Employee.objects.annotate(max=Window(
expression=Max('salary'), expression=Max('salary'),
partition_by=[F('department'), ExtractYear(F('hire_date'))], partition_by=[F('department'), F('hire_date__year')],
)).order_by('department', 'hire_date', 'name') )).order_by('department', 'hire_date', 'name')
self.assertQuerysetEqual(qs, [ self.assertQuerysetEqual(qs, [
('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 45000), ('Jones', 45000, 'Accounting', datetime.date(2005, 11, 1), 45000),
@ -753,26 +764,32 @@ class WindowFunctionTests(TestCase):
Detail(value={'department': 'HR', 'name': 'Smith', 'salary': 55000}), Detail(value={'department': 'HR', 'name': 'Smith', 'salary': 55000}),
Detail(value={'department': 'PR', 'name': 'Moore', 'salary': 90000}), Detail(value={'department': 'PR', 'name': 'Moore', 'salary': 90000}),
]) ])
qs = Detail.objects.annotate(department_sum=Window( tests = [
expression=Sum(Cast( (KeyTransform('department', 'value'), KeyTransform('name', 'value')),
KeyTextTransform('salary', 'value'), (F('value__department'), F('value__name')),
output_field=IntegerField(), ]
)), for partition_by, order_by in tests:
partition_by=[KeyTransform('department', 'value')], with self.subTest(partition_by=partition_by, order_by=order_by):
order_by=[KeyTransform('name', 'value')], qs = Detail.objects.annotate(department_sum=Window(
)).order_by('value__department', 'department_sum') expression=Sum(Cast(
self.assertQuerysetEqual(qs, [ KeyTextTransform('salary', 'value'),
('Brown', 'HR', 50000, 50000), output_field=IntegerField(),
('Smith', 'HR', 55000, 105000), )),
('Nowak', 'IT', 32000, 32000), partition_by=[partition_by],
('Smith', 'IT', 37000, 69000), order_by=[order_by],
('Moore', 'PR', 90000, 90000), )).order_by('value__department', 'department_sum')
], lambda entry: ( self.assertQuerysetEqual(qs, [
entry.value['name'], ('Brown', 'HR', 50000, 50000),
entry.value['department'], ('Smith', 'HR', 55000, 105000),
entry.value['salary'], ('Nowak', 'IT', 32000, 32000),
entry.department_sum, ('Smith', 'IT', 37000, 69000),
)) ('Moore', 'PR', 90000, 90000),
], lambda entry: (
entry.value['name'],
entry.value['department'],
entry.value['salary'],
entry.department_sum,
))
def test_invalid_start_value_range(self): def test_invalid_start_value_range(self):
msg = "start argument must be a negative integer, zero, or None, but got '3'." msg = "start argument must be a negative integer, zero, or None, but got '3'."

View File

@ -363,6 +363,14 @@ class NullableJSONModel(models.Model):
required_db_features = {'supports_json_field'} required_db_features = {'supports_json_field'}
class RelatedJSONModel(models.Model):
value = models.JSONField()
json_model = models.ForeignKey(NullableJSONModel, models.CASCADE)
class Meta:
required_db_features = {'supports_json_field'}
class AllFieldsModel(models.Model): class AllFieldsModel(models.Model):
big_integer = models.BigIntegerField() big_integer = models.BigIntegerField()
binary = models.BinaryField() binary = models.BinaryField()

View File

@ -25,7 +25,9 @@ from django.test import (
) )
from django.test.utils import CaptureQueriesContext from django.test.utils import CaptureQueriesContext
from .models import CustomJSONDecoder, JSONModel, NullableJSONModel from .models import (
CustomJSONDecoder, JSONModel, NullableJSONModel, RelatedJSONModel,
)
@skipUnlessDBFeature('supports_json_field') @skipUnlessDBFeature('supports_json_field')
@ -357,12 +359,11 @@ class TestQuerying(TestCase):
operator.itemgetter('key', 'count'), operator.itemgetter('key', 'count'),
) )
@skipUnlessDBFeature('allows_group_by_lob')
def test_ordering_grouping_by_count(self): def test_ordering_grouping_by_count(self):
qs = NullableJSONModel.objects.filter( qs = NullableJSONModel.objects.filter(
value__isnull=False, value__isnull=False,
).values('value__d__0').annotate(count=Count('value__d__0')).order_by('count') ).values('value__d__0').annotate(count=Count('value__d__0')).order_by('count')
self.assertQuerysetEqual(qs, [1, 11], operator.itemgetter('count')) self.assertQuerysetEqual(qs, [0, 1], operator.itemgetter('count'))
def test_order_grouping_custom_decoder(self): def test_order_grouping_custom_decoder(self):
NullableJSONModel.objects.create(value_custom={'a': 'b'}) NullableJSONModel.objects.create(value_custom={'a': 'b'})
@ -400,6 +401,17 @@ class TestQuerying(TestCase):
[self.objs[4]], [self.objs[4]],
) )
def test_key_transform_annotation_expression(self):
obj = NullableJSONModel.objects.create(value={'d': ['e', 'e']})
self.assertSequenceEqual(
NullableJSONModel.objects.filter(value__d__0__isnull=False).annotate(
key=F('value__d'),
chain=F('key__0'),
expr=Cast('key', models.JSONField()),
).filter(chain=F('expr__1')),
[obj],
)
def test_nested_key_transform_expression(self): def test_nested_key_transform_expression(self):
self.assertSequenceEqual( self.assertSequenceEqual(
NullableJSONModel.objects.filter(value__d__0__isnull=False).annotate( NullableJSONModel.objects.filter(value__d__0__isnull=False).annotate(
@ -410,6 +422,19 @@ class TestQuerying(TestCase):
[self.objs[4]], [self.objs[4]],
) )
def test_nested_key_transform_annotation_expression(self):
obj = NullableJSONModel.objects.create(
value={'d': ['e', {'f': 'g'}, {'f': 'g'}]},
)
self.assertSequenceEqual(
NullableJSONModel.objects.filter(value__d__0__isnull=False).annotate(
key=F('value__d'),
chain=F('key__1__f'),
expr=Cast('key', models.JSONField()),
).filter(chain=F('expr__2__f')),
[obj],
)
def test_nested_key_transform_on_subquery(self): def test_nested_key_transform_on_subquery(self):
self.assertSequenceEqual( self.assertSequenceEqual(
NullableJSONModel.objects.filter(value__d__0__isnull=False).annotate( NullableJSONModel.objects.filter(value__d__0__isnull=False).annotate(
@ -449,12 +474,15 @@ class TestQuerying(TestCase):
tests = [ tests = [
(Q(value__baz__has_key='a'), self.objs[7]), (Q(value__baz__has_key='a'), self.objs[7]),
(Q(value__has_key=KeyTransform('a', KeyTransform('baz', 'value'))), self.objs[7]), (Q(value__has_key=KeyTransform('a', KeyTransform('baz', 'value'))), self.objs[7]),
(Q(value__has_key=F('value__baz__a')), self.objs[7]),
(Q(value__has_key=KeyTransform('c', KeyTransform('baz', 'value'))), self.objs[7]), (Q(value__has_key=KeyTransform('c', KeyTransform('baz', 'value'))), self.objs[7]),
(Q(value__has_key=F('value__baz__c')), self.objs[7]),
(Q(value__d__1__has_key='f'), self.objs[4]), (Q(value__d__1__has_key='f'), self.objs[4]),
( (
Q(value__has_key=KeyTransform('f', KeyTransform('1', KeyTransform('d', 'value')))), Q(value__has_key=KeyTransform('f', KeyTransform('1', KeyTransform('d', 'value')))),
self.objs[4], self.objs[4],
) ),
(Q(value__has_key=F('value__d__1__f')), self.objs[4]),
] ]
for condition, expected in tests: for condition, expected in tests:
with self.subTest(condition=condition): with self.subTest(condition=condition):
@ -469,6 +497,7 @@ class TestQuerying(TestCase):
Q(value__1__has_key='b'), Q(value__1__has_key='b'),
Q(value__has_key=KeyTransform('b', KeyTransform(1, 'value'))), Q(value__has_key=KeyTransform('b', KeyTransform(1, 'value'))),
Q(value__has_key=KeyTransform('b', KeyTransform('1', 'value'))), Q(value__has_key=KeyTransform('b', KeyTransform('1', 'value'))),
Q(value__has_key=F('value__1__b')),
] ]
for condition in tests: for condition in tests:
with self.subTest(condition=condition): with self.subTest(condition=condition):
@ -733,11 +762,13 @@ class TestQuerying(TestCase):
[KeyTransform('foo', KeyTransform('bax', 'value'))], [KeyTransform('foo', KeyTransform('bax', 'value'))],
[self.objs[7]], [self.objs[7]],
), ),
('value__foo__in', [F('value__bax__foo')], [self.objs[7]]),
( (
'value__foo__in', 'value__foo__in',
[KeyTransform('foo', KeyTransform('bax', 'value')), 'baz'], [KeyTransform('foo', KeyTransform('bax', 'value')), 'baz'],
[self.objs[7]], [self.objs[7]],
), ),
('value__foo__in', [F('value__bax__foo'), 'baz'], [self.objs[7]]),
('value__foo__in', ['bar', 'baz'], [self.objs[7]]), ('value__foo__in', ['bar', 'baz'], [self.objs[7]]),
('value__bar__in', [['foo', 'bar']], [self.objs[7]]), ('value__bar__in', [['foo', 'bar']], [self.objs[7]]),
('value__bar__in', [['foo', 'bar'], ['a']], [self.objs[7]]), ('value__bar__in', [['foo', 'bar'], ['a']], [self.objs[7]]),
@ -850,6 +881,7 @@ class TestQuerying(TestCase):
('value__d__contains', 'e'), ('value__d__contains', 'e'),
('value__d__contains', [{'f': 'g'}]), ('value__d__contains', [{'f': 'g'}]),
('value__contains', KeyTransform('bax', 'value')), ('value__contains', KeyTransform('bax', 'value')),
('value__contains', F('value__bax')),
('value__baz__contains', {'a': 'b'}), ('value__baz__contains', {'a': 'b'}),
('value__baz__contained_by', {'a': 'b', 'c': 'd', 'e': 'f'}), ('value__baz__contained_by', {'a': 'b', 'c': 'd', 'e': 'f'}),
( (
@ -869,3 +901,22 @@ class TestQuerying(TestCase):
self.assertIs(NullableJSONModel.objects.filter( self.assertIs(NullableJSONModel.objects.filter(
**{lookup: value}, **{lookup: value},
).exists(), True) ).exists(), True)
def test_join_key_transform_annotation_expression(self):
related_obj = RelatedJSONModel.objects.create(
value={'d': ['f', 'e']},
json_model=self.objs[4],
)
RelatedJSONModel.objects.create(
value={'d': ['e', 'f']},
json_model=self.objs[4],
)
self.assertSequenceEqual(
RelatedJSONModel.objects.annotate(
key=F('value__d'),
related_key=F('json_model__value__d'),
chain=F('key__1'),
expr=Cast('key', models.JSONField()),
).filter(chain=F('related_key__0')),
[related_obj],
)

View File

@ -434,6 +434,13 @@ class TestQuerying(PostgreSQLTestCase):
self.objs[:1], self.objs[:1],
) )
def test_index_annotation(self):
qs = NullableIntegerArrayModel.objects.annotate(second=models.F('field__1'))
self.assertCountEqual(
qs.values_list('second', flat=True),
[None, None, None, 3, 30],
)
def test_overlap(self): def test_overlap(self):
self.assertSequenceEqual( self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]), NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]),
@ -495,6 +502,15 @@ class TestQuerying(PostgreSQLTestCase):
self.objs[2:3], self.objs[2:3],
) )
def test_slice_annotation(self):
qs = NullableIntegerArrayModel.objects.annotate(
first_two=models.F('field__0_2'),
)
self.assertCountEqual(
qs.values_list('first_two', flat=True),
[None, [1], [2], [2, 3], [20, 30]],
)
def test_usage_in_subquery(self): def test_usage_in_subquery(self):
self.assertSequenceEqual( self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter( NullableIntegerArrayModel.objects.filter(

View File

@ -2,7 +2,7 @@ import json
from django.core import checks, exceptions, serializers from django.core import checks, exceptions, serializers
from django.db import connection from django.db import connection
from django.db.models import OuterRef, Subquery from django.db.models import F, OuterRef, Subquery
from django.db.models.expressions import RawSQL from django.db.models.expressions import RawSQL
from django.forms import Form from django.forms import Form
from django.test.utils import CaptureQueriesContext, isolate_apps from django.test.utils import CaptureQueriesContext, isolate_apps
@ -137,6 +137,13 @@ class TestQuerying(PostgreSQLTestCase):
self.objs[:2] self.objs[:2]
) )
def test_key_transform_annotation(self):
qs = HStoreModel.objects.annotate(a=F('field__a'))
self.assertCountEqual(
qs.values_list('a', flat=True),
['b', 'b', None, None, None],
)
def test_keys(self): def test_keys(self):
self.assertSequenceEqual( self.assertSequenceEqual(
HStoreModel.objects.filter(field__keys=['a']), HStoreModel.objects.filter(field__keys=['a']),

View File

@ -2,9 +2,10 @@ import unittest
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import IntegrityError, connection, transaction from django.db import IntegrityError, connection, transaction
from django.db.models import Count, F, Max from django.db.models import CharField, Count, F, IntegerField, Max
from django.db.models.functions import Concat, Lower from django.db.models.functions import Abs, Concat, Lower
from django.test import TestCase from django.test import TestCase
from django.test.utils import register_lookup
from .models import A, B, Bar, D, DataPoint, Foo, RelatedPoint, UniqueNumber from .models import A, B, Bar, D, DataPoint, Foo, RelatedPoint, UniqueNumber
@ -154,6 +155,13 @@ class AdvancedTests(TestCase):
with self.assertRaisesMessage(FieldError, msg): with self.assertRaisesMessage(FieldError, msg):
Bar.objects.update(m2m_foo='whatever') Bar.objects.update(m2m_foo='whatever')
def test_update_transformed_field(self):
A.objects.create(x=5)
A.objects.create(x=-6)
with register_lookup(IntegerField, Abs):
A.objects.update(x=F('x__abs'))
self.assertCountEqual(A.objects.values_list('x', flat=True), [5, 6])
def test_update_annotated_queryset(self): def test_update_annotated_queryset(self):
""" """
Update of a queryset that's been annotated. Update of a queryset that's been annotated.
@ -194,14 +202,18 @@ class AdvancedTests(TestCase):
def test_update_with_joined_field_annotation(self): def test_update_with_joined_field_annotation(self):
msg = 'Joined field references are not permitted in this query' msg = 'Joined field references are not permitted in this query'
for annotation in ( with register_lookup(CharField, Lower):
F('data__name'), for annotation in (
Lower('data__name'), F('data__name'),
Concat('data__name', 'data__value'), F('data__name__lower'),
): Lower('data__name'),
with self.subTest(annotation=annotation): Concat('data__name', 'data__value'),
with self.assertRaisesMessage(FieldError, msg): ):
RelatedPoint.objects.annotate(new_name=annotation).update(name=F('new_name')) with self.subTest(annotation=annotation):
with self.assertRaisesMessage(FieldError, msg):
RelatedPoint.objects.annotate(
new_name=annotation,
).update(name=F('new_name'))
@unittest.skipUnless( @unittest.skipUnless(