Fixed #32358 -- Fixed queryset crash when grouping by annotation with Distance()/Area().

Made MeasureBase hashable.
This commit is contained in:
Illia Volochii 2021-01-16 14:52:11 +02:00 committed by Mariusz Felisiak
parent 0aff3fd711
commit bef6f75842
3 changed files with 40 additions and 1 deletions

View File

@ -89,6 +89,9 @@ class MeasureBase:
else: else:
return NotImplemented return NotImplemented
def __hash__(self):
return hash(self.standard)
def __lt__(self, other): def __lt__(self, other):
if isinstance(other, self.__class__): if isinstance(other, self.__class__):
return self.standard < other.standard return self.standard < other.standard

View File

@ -4,7 +4,9 @@ from django.contrib.gis.db.models.functions import (
from django.contrib.gis.geos import GEOSGeometry, LineString, Point from django.contrib.gis.geos import GEOSGeometry, LineString, Point
from django.contrib.gis.measure import D # alias for Distance from django.contrib.gis.measure import D # alias for Distance
from django.db import NotSupportedError, connection from django.db import NotSupportedError, connection
from django.db.models import Exists, F, OuterRef, Q from django.db.models import (
Case, Count, Exists, F, IntegerField, OuterRef, Q, Value, When,
)
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from ..utils import FuncTestMixin from ..utils import FuncTestMixin
@ -214,6 +216,24 @@ class DistanceTest(TestCase):
SouthTexasCity.objects.count(), SouthTexasCity.objects.count(),
) )
@skipUnlessDBFeature('supports_distances_lookups')
def test_distance_annotation_group_by(self):
stx_pnt = self.stx_pnt.transform(
SouthTexasCity._meta.get_field('point').srid,
clone=True,
)
qs = SouthTexasCity.objects.annotate(
relative_distance=Case(
When(point__distance_lte=(stx_pnt, D(km=20)), then=Value(20)),
default=Value(100),
output_field=IntegerField(),
),
).values('relative_distance').annotate(count=Count('pk'))
self.assertCountEqual(qs, [
{'relative_distance': 20, 'count': 5},
{'relative_distance': 100, 'count': 4},
])
def test_mysql_geodetic_distance_error(self): def test_mysql_geodetic_distance_error(self):
if not connection.ops.mysql: if not connection.ops.mysql:
self.skipTest('This is a MySQL-specific test.') self.skipTest('This is a MySQL-specific test.')

View File

@ -151,6 +151,14 @@ class DistanceTest(unittest.TestCase):
with self.subTest(nm=nm): with self.subTest(nm=nm):
self.assertEqual(att, D.unit_attname(nm)) self.assertEqual(att, D.unit_attname(nm))
def test_hash(self):
d1 = D(m=99)
d2 = D(m=100)
d3 = D(km=0.1)
self.assertEqual(hash(d2), hash(d3))
self.assertNotEqual(hash(d1), hash(d2))
self.assertNotEqual(hash(d1), hash(d3))
class AreaTest(unittest.TestCase): class AreaTest(unittest.TestCase):
"Testing the Area object" "Testing the Area object"
@ -272,6 +280,14 @@ class AreaTest(unittest.TestCase):
self.assertEqual(repr(a1), 'Area(sq_m=100.0)') self.assertEqual(repr(a1), 'Area(sq_m=100.0)')
self.assertEqual(repr(a2), 'Area(sq_km=3.5)') self.assertEqual(repr(a2), 'Area(sq_km=3.5)')
def test_hash(self):
a1 = A(sq_m=100)
a2 = A(sq_m=1000000)
a3 = A(sq_km=1)
self.assertEqual(hash(a2), hash(a3))
self.assertNotEqual(hash(a1), hash(a2))
self.assertNotEqual(hash(a1), hash(a3))
def suite(): def suite():
s = unittest.TestSuite() s = unittest.TestSuite()