Fixed #32483 -- Fixed QuerySet.values()/values_list() on JSONField key transforms with booleans on SQLite.

Thanks Matthew Cornell for the report.
This commit is contained in:
Mariusz Felisiak 2021-02-26 07:52:16 +01:00
parent c4df8b86c7
commit 71ec102b01
4 changed files with 30 additions and 33 deletions

View File

@ -21,6 +21,9 @@ class DatabaseOperations(BaseDatabaseOperations):
'DateTimeField': 'TEXT', 'DateTimeField': 'TEXT',
} }
explain_prefix = 'EXPLAIN QUERY PLAN' explain_prefix = 'EXPLAIN QUERY PLAN'
# List of datatypes to that cannot be extracted with JSON_EXTRACT() on
# SQLite. Use JSON_TYPE() instead.
jsonfield_datatype_values = frozenset(['null', 'false', 'true'])
def bulk_batch_size(self, fields, objs): def bulk_batch_size(self, fields, objs):
""" """

View File

@ -260,15 +260,6 @@ class CaseInsensitiveMixin:
class JSONExact(lookups.Exact): class JSONExact(lookups.Exact):
can_use_none_as_rhs = True can_use_none_as_rhs = True
def process_lhs(self, compiler, connection):
lhs, lhs_params = super().process_lhs(compiler, connection)
if connection.vendor == 'sqlite':
rhs, rhs_params = super().process_rhs(compiler, connection)
if rhs == '%s' and rhs_params == [None]:
# Use JSON_TYPE instead of JSON_EXTRACT for NULLs.
lhs = "JSON_TYPE(%s, '$')" % lhs
return lhs, lhs_params
def process_rhs(self, compiler, connection): def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection) rhs, rhs_params = super().process_rhs(compiler, connection)
# Treat None lookup values as null. # Treat None lookup values as null.
@ -340,7 +331,13 @@ class KeyTransform(Transform):
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms) json_path = compile_json_path(key_transforms)
return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,) datatype_values = ','.join([
repr(datatype) for datatype in connection.ops.jsonfield_datatype_values
])
return (
"(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
"THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
class KeyTextTransform(KeyTransform): class KeyTextTransform(KeyTransform):
@ -408,7 +405,10 @@ class KeyTransformIn(lookups.In):
sql = sql % 'JSON_QUERY' sql = sql % 'JSON_QUERY'
else: else:
sql = sql % 'JSON_VALUE' sql = sql % 'JSON_VALUE'
elif connection.vendor in {'sqlite', 'mysql'}: elif connection.vendor == 'mysql' or (
connection.vendor == 'sqlite' and
params[0] not in connection.ops.jsonfield_datatype_values
):
sql = "JSON_EXTRACT(%s, '$')" sql = "JSON_EXTRACT(%s, '$')"
if connection.vendor == 'mysql' and connection.mysql_is_mariadb: if connection.vendor == 'mysql' and connection.mysql_is_mariadb:
sql = 'JSON_UNQUOTE(%s)' % sql sql = 'JSON_UNQUOTE(%s)' % sql
@ -416,15 +416,6 @@ class KeyTransformIn(lookups.In):
class KeyTransformExact(JSONExact): class KeyTransformExact(JSONExact):
def process_lhs(self, compiler, connection):
lhs, lhs_params = super().process_lhs(compiler, connection)
if connection.vendor == 'sqlite':
rhs, rhs_params = super().process_rhs(compiler, connection)
if rhs == '%s' and rhs_params == ['null']:
lhs, *_ = self.lhs.preprocess_lhs(compiler, connection)
lhs = 'JSON_TYPE(%s, %%s)' % lhs
return lhs, lhs_params
def process_rhs(self, compiler, connection): def process_rhs(self, compiler, connection):
if isinstance(self.rhs, KeyTransform): if isinstance(self.rhs, KeyTransform):
return super(lookups.Exact, self).process_rhs(compiler, connection) return super(lookups.Exact, self).process_rhs(compiler, connection)
@ -440,7 +431,12 @@ class KeyTransformExact(JSONExact):
func.append(sql % 'JSON_VALUE') func.append(sql % 'JSON_VALUE')
rhs = rhs % tuple(func) rhs = rhs % tuple(func)
elif connection.vendor == 'sqlite': elif connection.vendor == 'sqlite':
func = ["JSON_EXTRACT(%s, '$')" if value != 'null' else '%s' for value in rhs_params] func = []
for value in rhs_params:
if value in connection.ops.jsonfield_datatype_values:
func.append('%s')
else:
func.append("JSON_EXTRACT(%s, '$')")
rhs = rhs % tuple(func) rhs = rhs % tuple(func)
return rhs, rhs_params return rhs, rhs_params

View File

@ -695,12 +695,6 @@ You can also refer to fields on related models with reverse relations through
pronounced if you include multiple such fields in your ``values()`` query, pronounced if you include multiple such fields in your ``values()`` query,
in which case all possible combinations will be returned. in which case all possible combinations will be returned.
.. admonition:: Boolean values for ``JSONField`` on SQLite
Due to the way the ``JSON_EXTRACT`` SQL function is implemented on SQLite,
``values()`` will return ``1`` and ``0`` instead of ``True`` and ``False``
for :class:`~django.db.models.JSONField` key transforms.
``values_list()`` ``values_list()``
~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~
@ -771,12 +765,6 @@ not having any author::
>>> Entry.objects.values_list('authors') >>> Entry.objects.values_list('authors')
<QuerySet [('Noam Chomsky',), ('George Orwell',), (None,)]> <QuerySet [('Noam Chomsky',), ('George Orwell',), (None,)]>
.. admonition:: Boolean values for ``JSONField`` on SQLite
Due to the way the ``JSON_EXTRACT`` SQL function is implemented on SQLite,
``values_list()`` will return ``1`` and ``0`` instead of ``True`` and
``False`` for :class:`~django.db.models.JSONField` key transforms.
``dates()`` ``dates()``
~~~~~~~~~~~ ~~~~~~~~~~~

View File

@ -808,6 +808,16 @@ class TestQuerying(TestCase):
with self.subTest(lookup=lookup): with self.subTest(lookup=lookup):
self.assertEqual(qs.values_list(lookup, flat=True).get(), expected) self.assertEqual(qs.values_list(lookup, flat=True).get(), expected)
def test_key_values_boolean(self):
qs = NullableJSONModel.objects.filter(value__h=True, value__i=False)
tests = [
('value__h', True),
('value__i', False),
]
for lookup, expected in tests:
with self.subTest(lookup=lookup):
self.assertIs(qs.values_list(lookup, flat=True).get(), expected)
@skipUnlessDBFeature('supports_json_field_contains') @skipUnlessDBFeature('supports_json_field_contains')
def test_key_contains(self): def test_key_contains(self):
self.assertIs(NullableJSONModel.objects.filter(value__foo__contains='ar').exists(), False) self.assertIs(NullableJSONModel.objects.filter(value__foo__contains='ar').exists(), False)