Refs #33308 -- Used get_db_prep_value() to adapt JSONFields.

This commit is contained in:
Simon Charette 2022-10-31 22:28:17 -04:00 committed by Mariusz Felisiak
parent d87a7b9f4b
commit 5c23d9f0c3
3 changed files with 32 additions and 5 deletions

View File

@ -1,5 +1,6 @@
import datetime import datetime
import decimal import decimal
import json
from importlib import import_module from importlib import import_module
import sqlparse import sqlparse
@ -575,6 +576,9 @@ class BaseDatabaseOperations:
""" """
return value or None return value or None
def adapt_json_value(self, value, encoder):
return json.dumps(value, cls=encoder)
def year_lookup_bounds_for_date_field(self, value, iso_year=False): def year_lookup_bounds_for_date_field(self, value, iso_year=False):
""" """
Return a two-elements list with the lower and upper bound to be used Return a two-elements list with the lower and upper bound to be used

View File

@ -1,4 +1,8 @@
import json
from functools import lru_cache, partial
from psycopg2.extras import Inet from psycopg2.extras import Inet
from psycopg2.extras import Json as Jsonb
from django.conf import settings from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations from django.db.backends.base.operations import BaseDatabaseOperations
@ -6,6 +10,13 @@ from django.db.backends.utils import split_tzname_delta
from django.db.models.constants import OnConflict from django.db.models.constants import OnConflict
@lru_cache
def get_json_dumps(encoder):
if encoder is None:
return json.dumps
return partial(json.dumps, cls=encoder)
class DatabaseOperations(BaseDatabaseOperations): class DatabaseOperations(BaseDatabaseOperations):
cast_char_field_without_max_length = "varchar" cast_char_field_without_max_length = "varchar"
explain_prefix = "EXPLAIN" explain_prefix = "EXPLAIN"
@ -308,6 +319,9 @@ class DatabaseOperations(BaseDatabaseOperations):
return Inet(value) return Inet(value)
return None return None
def adapt_json_value(self, value, encoder):
return Jsonb(value, dumps=get_json_dumps(encoder))
def subtract_temporals(self, internal_type, lhs, rhs): def subtract_temporals(self, internal_type, lhs, rhs):
if internal_type == "DateField": if internal_type == "DateField":
lhs_sql, lhs_params = lhs lhs_sql, lhs_params = lhs

View File

@ -6,7 +6,11 @@ from django.db import NotSupportedError, connections, router
from django.db.models import lookups from django.db.models import lookups
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import TextField from django.db.models.fields import TextField
from django.db.models.lookups import PostgresOperatorLookup, Transform from django.db.models.lookups import (
FieldGetDbPrepValueMixin,
PostgresOperatorLookup,
Transform,
)
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from . import Field from . import Field
@ -92,10 +96,15 @@ class JSONField(CheckFieldDefaultMixin, Field):
def get_internal_type(self): def get_internal_type(self):
return "JSONField" return "JSONField"
def get_prep_value(self, value): def get_db_prep_value(self, value, connection, prepared=False):
if hasattr(value, "as_sql"):
return value
return connection.ops.adapt_json_value(value, self.encoder)
def get_db_prep_save(self, value, connection):
if value is None: if value is None:
return value return value
return json.dumps(value, cls=self.encoder) return self.get_db_prep_value(value, connection)
def get_transform(self, name): def get_transform(self, name):
transform = super().get_transform(name) transform = super().get_transform(name)
@ -141,7 +150,7 @@ def compile_json_path(key_transforms, include_root=True):
return "".join(path) return "".join(path)
class DataContains(PostgresOperatorLookup): class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
lookup_name = "contains" lookup_name = "contains"
postgres_operator = "@>" postgres_operator = "@>"
@ -156,7 +165,7 @@ class DataContains(PostgresOperatorLookup):
return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params
class ContainedBy(PostgresOperatorLookup): class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
lookup_name = "contained_by" lookup_name = "contained_by"
postgres_operator = "<@" postgres_operator = "<@"