From bef6f7584280f1cc80e5e2d80b7ad073a93d26ec Mon Sep 17 00:00:00 2001 From: Illia Volochii Date: Sat, 16 Jan 2021 14:52:11 +0200 Subject: [PATCH] Fixed #32358 -- Fixed queryset crash when grouping by annotation with Distance()/Area(). Made MeasureBase hashable. --- django/contrib/gis/measure.py | 3 +++ tests/gis_tests/distapp/tests.py | 22 +++++++++++++++++++++- tests/gis_tests/test_measure.py | 16 ++++++++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/django/contrib/gis/measure.py b/django/contrib/gis/measure.py index 9cb2db83a2..7e571f14dd 100644 --- a/django/contrib/gis/measure.py +++ b/django/contrib/gis/measure.py @@ -89,6 +89,9 @@ class MeasureBase: else: return NotImplemented + def __hash__(self): + return hash(self.standard) + def __lt__(self, other): if isinstance(other, self.__class__): return self.standard < other.standard diff --git a/tests/gis_tests/distapp/tests.py b/tests/gis_tests/distapp/tests.py index 30f64e97cb..571d091cb8 100644 --- a/tests/gis_tests/distapp/tests.py +++ b/tests/gis_tests/distapp/tests.py @@ -4,7 +4,9 @@ from django.contrib.gis.db.models.functions import ( from django.contrib.gis.geos import GEOSGeometry, LineString, Point from django.contrib.gis.measure import D # alias for Distance from django.db import NotSupportedError, connection -from django.db.models import Exists, F, OuterRef, Q +from django.db.models import ( + Case, Count, Exists, F, IntegerField, OuterRef, Q, Value, When, +) from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature from ..utils import FuncTestMixin @@ -214,6 +216,24 @@ class DistanceTest(TestCase): SouthTexasCity.objects.count(), ) + @skipUnlessDBFeature('supports_distances_lookups') + def test_distance_annotation_group_by(self): + stx_pnt = self.stx_pnt.transform( + SouthTexasCity._meta.get_field('point').srid, + clone=True, + ) + qs = SouthTexasCity.objects.annotate( + relative_distance=Case( + When(point__distance_lte=(stx_pnt, D(km=20)), then=Value(20)), + default=Value(100), + output_field=IntegerField(), + ), + ).values('relative_distance').annotate(count=Count('pk')) + self.assertCountEqual(qs, [ + {'relative_distance': 20, 'count': 5}, + {'relative_distance': 100, 'count': 4}, + ]) + def test_mysql_geodetic_distance_error(self): if not connection.ops.mysql: self.skipTest('This is a MySQL-specific test.') diff --git a/tests/gis_tests/test_measure.py b/tests/gis_tests/test_measure.py index 8c468139e7..bc2249a816 100644 --- a/tests/gis_tests/test_measure.py +++ b/tests/gis_tests/test_measure.py @@ -151,6 +151,14 @@ class DistanceTest(unittest.TestCase): with self.subTest(nm=nm): self.assertEqual(att, D.unit_attname(nm)) + def test_hash(self): + d1 = D(m=99) + d2 = D(m=100) + d3 = D(km=0.1) + self.assertEqual(hash(d2), hash(d3)) + self.assertNotEqual(hash(d1), hash(d2)) + self.assertNotEqual(hash(d1), hash(d3)) + class AreaTest(unittest.TestCase): "Testing the Area object" @@ -272,6 +280,14 @@ class AreaTest(unittest.TestCase): self.assertEqual(repr(a1), 'Area(sq_m=100.0)') self.assertEqual(repr(a2), 'Area(sq_km=3.5)') + def test_hash(self): + a1 = A(sq_m=100) + a2 = A(sq_m=1000000) + a3 = A(sq_km=1) + self.assertEqual(hash(a2), hash(a3)) + self.assertNotEqual(hash(a1), hash(a2)) + self.assertNotEqual(hash(a1), hash(a3)) + def suite(): s = unittest.TestSuite()