Refs #33308 -- Used get_db_prep_value() to adapt JSONFields.

This commit is contained in:
Simon Charette 2022-10-31 22:28:17 -04:00 committed by Mariusz Felisiak
parent d87a7b9f4b
commit 5c23d9f0c3
3 changed files with 32 additions and 5 deletions

View File

@ -1,5 +1,6 @@
import datetime
import decimal
import json
from importlib import import_module
import sqlparse
@ -575,6 +576,9 @@ class BaseDatabaseOperations:
"""
return value or None
def adapt_json_value(self, value, encoder):
return json.dumps(value, cls=encoder)
def year_lookup_bounds_for_date_field(self, value, iso_year=False):
"""
Return a two-elements list with the lower and upper bound to be used

View File

@ -1,4 +1,8 @@
import json
from functools import lru_cache, partial
from psycopg2.extras import Inet
from psycopg2.extras import Json as Jsonb
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
@ -6,6 +10,13 @@ from django.db.backends.utils import split_tzname_delta
from django.db.models.constants import OnConflict
@lru_cache
def get_json_dumps(encoder):
if encoder is None:
return json.dumps
return partial(json.dumps, cls=encoder)
class DatabaseOperations(BaseDatabaseOperations):
cast_char_field_without_max_length = "varchar"
explain_prefix = "EXPLAIN"
@ -308,6 +319,9 @@ class DatabaseOperations(BaseDatabaseOperations):
return Inet(value)
return None
def adapt_json_value(self, value, encoder):
return Jsonb(value, dumps=get_json_dumps(encoder))
def subtract_temporals(self, internal_type, lhs, rhs):
if internal_type == "DateField":
lhs_sql, lhs_params = lhs

View File

@ -6,7 +6,11 @@ from django.db import NotSupportedError, connections, router
from django.db.models import lookups
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import TextField
from django.db.models.lookups import PostgresOperatorLookup, Transform
from django.db.models.lookups import (
FieldGetDbPrepValueMixin,
PostgresOperatorLookup,
Transform,
)
from django.utils.translation import gettext_lazy as _
from . import Field
@ -92,10 +96,15 @@ class JSONField(CheckFieldDefaultMixin, Field):
def get_internal_type(self):
return "JSONField"
def get_prep_value(self, value):
def get_db_prep_value(self, value, connection, prepared=False):
if hasattr(value, "as_sql"):
return value
return connection.ops.adapt_json_value(value, self.encoder)
def get_db_prep_save(self, value, connection):
if value is None:
return value
return json.dumps(value, cls=self.encoder)
return self.get_db_prep_value(value, connection)
def get_transform(self, name):
transform = super().get_transform(name)
@ -141,7 +150,7 @@ def compile_json_path(key_transforms, include_root=True):
return "".join(path)
class DataContains(PostgresOperatorLookup):
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
lookup_name = "contains"
postgres_operator = "@>"
@ -156,7 +165,7 @@ class DataContains(PostgresOperatorLookup):
return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params
class ContainedBy(PostgresOperatorLookup):
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
lookup_name = "contained_by"
postgres_operator = "<@"