Refs #26511 -- Fixed json.KeyTextTransform() on MySQL/MariaDB.

This commit is contained in:
Mariusz Felisiak 2022-08-18 21:02:29 +02:00 committed by GitHub
parent bd36023100
commit e9fd2b5724
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 1 deletions

View File

@ -4,6 +4,7 @@ from django import forms
from django.core import checks, exceptions from django.core import checks, exceptions
from django.db import NotSupportedError, connections, router from django.db import NotSupportedError, connections, router
from django.db.models import lookups from django.db.models import lookups
from django.db.models.fields import TextField
from django.db.models.lookups import PostgresOperatorLookup, Transform from django.db.models.lookups import PostgresOperatorLookup, Transform
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -366,6 +367,17 @@ class KeyTransform(Transform):
class KeyTextTransform(KeyTransform): class KeyTextTransform(KeyTransform):
postgres_operator = "->>" postgres_operator = "->>"
postgres_nested_operator = "#>>" postgres_nested_operator = "#>>"
output_field = TextField()
def as_mysql(self, compiler, connection):
if connection.mysql_is_mariadb:
# MariaDB doesn't support -> and ->> operators (see MDEV-13594).
sql, params = super().as_mysql(compiler, connection)
return "JSON_UNQUOTE(%s)" % sql, params
else:
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
return "(%s ->> %%s)" % lhs, tuple(params) + (json_path,)
class KeyTransformTextLookupMixin: class KeyTransformTextLookupMixin:

View File

@ -370,6 +370,7 @@ class TestQuerying(TestCase):
).order_by("key"), ).order_by("key"),
): ):
self.assertSequenceEqual(qs, [self.objs[4]]) self.assertSequenceEqual(qs, [self.objs[4]])
none_val = "" if connection.features.interprets_empty_strings_as_nulls else None
qs = NullableJSONModel.objects.filter(value__isnull=False) qs = NullableJSONModel.objects.filter(value__isnull=False)
self.assertQuerysetEqual( self.assertQuerysetEqual(
qs.filter(value__isnull=False) qs.filter(value__isnull=False)
@ -381,7 +382,7 @@ class TestQuerying(TestCase):
.values("key") .values("key")
.annotate(count=Count("key")) .annotate(count=Count("key"))
.order_by("count"), .order_by("count"),
[(None, 0), ("g", 1)], [(none_val, 0), ("g", 1)],
operator.itemgetter("key", "count"), operator.itemgetter("key", "count"),
) )
@ -494,6 +495,17 @@ class TestQuerying(TestCase):
[self.objs[4]], [self.objs[4]],
) )
def test_key_text_transform_char_lookup(self):
qs = NullableJSONModel.objects.annotate(
char_value=KeyTextTransform("foo", "value"),
).filter(char_value__startswith="bar")
self.assertSequenceEqual(qs, [self.objs[7]])
qs = NullableJSONModel.objects.annotate(
char_value=KeyTextTransform(1, KeyTextTransform("bar", "value")),
).filter(char_value__startswith="bar")
self.assertSequenceEqual(qs, [self.objs[7]])
def test_expression_wrapper_key_transform(self): def test_expression_wrapper_key_transform(self):
self.assertCountEqual( self.assertCountEqual(
NullableJSONModel.objects.annotate( NullableJSONModel.objects.annotate(