Fixed #31956 -- Fixed crash of ordering by JSONField with a custom decoder on PostgreSQL.

Thanks Marc Debureaux for the report.
Thanks Simon Charette, Nick Pope, and Adam Johnson for reviews.
Mariusz Felisiak 2020-08-28 07:56:04 +02:00
parent 2210539142
commit 0be51d2226
8 changed files with 21 additions and 27 deletions
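
For context, this is the kind of setup the ticket describes: a JSONField with a custom decoder whose key is used for both grouping and ordering on PostgreSQL. The sketch below is hypothetical (the model, app label, and decoder are illustrative and not part of this commit); only the decoder argument and the values()/annotate()/order_by() shape mirror the regression test added below, and it assumes a configured Django project.

    import json

    from django.db import models
    from django.db.models import Count


    class CustomJSONDecoder(json.JSONDecoder):
        # Any json.JSONDecoder subclass is enough to take the custom-decoder path.
        pass


    class Report(models.Model):
        value = models.JSONField(null=True, decoder=CustomJSONDecoder)

        class Meta:
            app_label = 'example'  # hypothetical app


    # Grouping and ordering by a key of the field; this crashed on PostgreSQL
    # before this commit.
    queryset = Report.objects.filter(value__isnull=False).values(
        'value__a',
    ).annotate(count=Count('id')).order_by('value__a')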

django/contrib/postgres/aggregates/general.py

@@ -48,7 +48,7 @@ class JSONBAgg(OrderableAggMixin, Aggregate):
     def convert_value(self, value, expression, connection):
         if not value:
-            return []
+            return '[]'
         return value
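
Note on the JSONBAgg change: with psycopg2's jsonb parsing replaced by an identity loads() (see the postgresql backend hunk below), aggregate results reach convert_value() as raw JSON strings, so the empty-result fallback must be JSON text that a later json.loads() call, possibly with a custom decoder, can handle. A minimal illustration:

    import json

    # '[]' decodes to the same empty list the old code returned directly, but it
    # stays on the "undecoded string" path the rest of the stack now expects.
    assert json.loads('[]') == []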

django/db/backends/base/operations.py

@@ -153,13 +153,6 @@ class BaseDatabaseOperations:
         """
         return self.date_extract_sql(lookup_type, field_name)
 
-    def json_cast_text_sql(self, field_name):
-        """Return the SQL to cast a JSON value to text value."""
-        raise NotImplementedError(
-            'subclasses of BaseDatabaseOperations may require a '
-            'json_cast_text_sql() method'
-        )
-
     def deferrable_sql(self):
         """
         Return the SQL to make a constraint "initially deferred" during a

django/db/backends/postgresql/base.py

@@ -200,7 +200,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
             # Set the isolation level to the value from OPTIONS.
             if self.isolation_level != connection.isolation_level:
                 connection.set_session(isolation_level=self.isolation_level)
-
+        # Register dummy loads() to avoid a round trip from psycopg2's decode
+        # to json.dumps() to json.loads(), when using a custom decoder in
+        # JSONField.
+        psycopg2.extras.register_default_jsonb(conn_or_curs=connection, loads=lambda x: x)
         return connection
 
     def ensure_timezone(self):
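
By default psycopg2 decodes json/jsonb columns with json.loads() before Django ever sees them; registering an identity loads() makes the driver hand back the raw JSON string, so JSONField.from_db_value() can decode it exactly once with the field's (possibly custom) decoder. A standalone sketch of the driver-level effect, assuming psycopg2 and a reachable PostgreSQL server (the DSN is hypothetical):

    import psycopg2
    import psycopg2.extras

    conn = psycopg2.connect('dbname=example')  # hypothetical DSN

    with conn.cursor() as cursor:
        cursor.execute("""SELECT '{"a": 1}'::jsonb""")
        print(type(cursor.fetchone()[0]))  # <class 'dict'>: psycopg2 already decoded it

    psycopg2.extras.register_default_jsonb(conn_or_curs=conn, loads=lambda x: x)
    with conn.cursor() as cursor:
        cursor.execute("""SELECT '{"a": 1}'::jsonb""")
        print(type(cursor.fetchone()[0]))  # <class 'str'>: left for JSONField to decode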

django/db/backends/postgresql/operations.py

@@ -74,9 +74,6 @@ class DatabaseOperations(BaseDatabaseOperations):
     def time_trunc_sql(self, lookup_type, field_name):
         return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name)
 
-    def json_cast_text_sql(self, field_name):
-        return '(%s)::text' % field_name
-
     def deferrable_sql(self):
         return " DEFERRABLE INITIALLY DEFERRED"

django/db/models/fields/json.py

@@ -70,8 +70,6 @@ class JSONField(CheckFieldDefaultMixin, Field):
     def from_db_value(self, value, expression, connection):
         if value is None:
             return value
-        if connection.features.has_native_json_field and self.decoder is None:
-            return value
         try:
             return json.loads(value, cls=self.decoder)
         except json.JSONDecodeError:
@@ -91,14 +89,6 @@ class JSONField(CheckFieldDefaultMixin, Field):
             return transform
         return KeyTransformFactory(name)
 
-    def select_format(self, compiler, sql, params):
-        if (
-            compiler.connection.features.has_native_json_field and
-            self.decoder is not None
-        ):
-            return compiler.connection.ops.json_cast_text_sql(sql), params
-        return super().select_format(compiler, sql, params)
-
     def validate(self, value, model_instance):
         super().validate(value, model_instance)
         try:
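
With the driver returning strings, from_db_value() no longer needs the native-JSON shortcut and always decodes once via json.loads() with the field's decoder, and the select_format()/json_cast_text_sql() cast tied to the crash goes away entirely. A small sketch of that single decode step, using a hypothetical custom decoder to show the cls= hook:

    import json


    class UpperKeysDecoder(json.JSONDecoder):
        # Hypothetical decoder: upper-cases top-level object keys.
        def decode(self, s, **kwargs):
            obj = super().decode(s, **kwargs)
            return {k.upper(): v for k, v in obj.items()} if isinstance(obj, dict) else obj


    raw = '{"a": "b"}'                            # what psycopg2 now hands back
    print(json.loads(raw))                        # {'a': 'b'}
    print(json.loads(raw, cls=UpperKeysDecoder))  # {'A': 'b'}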

docs/releases/3.1.1.txt

@@ -51,3 +51,7 @@ Bugfixes
 
 * Fixed detecting an async ``get_response`` callable in various builtin
   middlewares (:ticket:`31928`).
+
+* Fixed a ``QuerySet.order_by()`` crash on PostgreSQL when ordering and
+  grouping by :class:`~django.db.models.JSONField` with a custom
+  :attr:`~django.db.models.JSONField.decoder` (:ticket:`31956`).

tests/backends/base/test_operations.py

@@ -117,11 +117,6 @@ class SimpleDatabaseOperationTests(SimpleTestCase):
         with self.assertRaisesMessage(NotImplementedError, self.may_require_msg % 'datetime_extract_sql'):
             self.ops.datetime_extract_sql(None, None, None)
 
-    def test_json_cast_text_sql(self):
-        msg = self.may_require_msg % 'json_cast_text_sql'
-        with self.assertRaisesMessage(NotImplementedError, msg):
-            self.ops.json_cast_text_sql(None)
-
 
 class DatabaseOperationTests(TestCase):
     def setUp(self):

tests/model_fields/test_jsonfield.py

@@ -359,6 +359,18 @@ class TestQuerying(TestCase):
         ).values('value__d__0').annotate(count=Count('value__d__0')).order_by('count')
         self.assertQuerysetEqual(qs, [1, 11], operator.itemgetter('count'))
 
+    def test_order_grouping_custom_decoder(self):
+        NullableJSONModel.objects.create(value_custom={'a': 'b'})
+        qs = NullableJSONModel.objects.filter(value_custom__isnull=False)
+        self.assertSequenceEqual(
+            qs.values(
+                'value_custom__a',
+            ).annotate(
+                count=Count('id'),
+            ).order_by('value_custom__a'),
+            [{'value_custom__a': 'b', 'count': 1}],
+        )
+
     def test_key_transform_raw_expression(self):
         expr = RawSQL(self.raw_sql, ['{"x": "bar"}'])
         self.assertSequenceEqual(