From 6789ded0a6ab797f0dcdfa6ad5d1cfa46e23abcd Mon Sep 17 00:00:00 2001 From: sage Date: Sun, 9 Jun 2019 07:56:37 +0700 Subject: [PATCH] Fixed #12990, Refs #27694 -- Added JSONField model field. Thanks to Adam Johnson, Carlton Gibson, Mariusz Felisiak, and Raphael Michel for mentoring this Google Summer of Code 2019 project and everyone else who helped with the patch. Special thanks to Mads Jensen, Nick Pope, and Simon Charette for extensive reviews. Co-authored-by: Mariusz Felisiak --- AUTHORS | 1 + django/contrib/postgres/aggregates/general.py | 4 +- django/contrib/postgres/apps.py | 1 - django/contrib/postgres/fields/jsonb.py | 204 +----- django/contrib/postgres/forms/jsonb.py | 69 +- django/contrib/postgres/lookups.py | 11 +- django/db/backends/base/features.py | 9 + django/db/backends/base/operations.py | 7 + django/db/backends/mysql/base.py | 8 +- django/db/backends/mysql/features.py | 12 + django/db/backends/mysql/introspection.py | 21 +- django/db/backends/mysql/operations.py | 10 + django/db/backends/oracle/base.py | 2 + django/db/backends/oracle/features.py | 1 + django/db/backends/oracle/introspection.py | 25 +- django/db/backends/oracle/operations.py | 6 +- django/db/backends/postgresql/base.py | 1 + django/db/backends/postgresql/features.py | 1 + .../db/backends/postgresql/introspection.py | 1 + django/db/backends/postgresql/operations.py | 3 + django/db/backends/sqlite3/base.py | 12 + django/db/backends/sqlite3/features.py | 16 + django/db/backends/sqlite3/introspection.py | 24 +- django/db/models/__init__.py | 7 +- django/db/models/fields/__init__.py | 2 + django/db/models/fields/json.py | 525 ++++++++++++++ django/db/models/functions/comparison.py | 15 +- django/forms/fields.py | 71 +- docs/internals/deprecation.txt | 7 + docs/ref/checks.txt | 4 + docs/ref/contrib/postgres/fields.txt | 95 +-- docs/ref/contrib/postgres/forms.txt | 8 +- docs/ref/databases.txt | 16 + docs/ref/forms/fields.txt | 54 ++ docs/ref/models/fields.txt | 69 +- docs/releases/3.1.txt | 48 ++ docs/topics/db/queries.txt | 230 ++++++ tests/backends/base/test_operations.py | 5 + .../forms_tests/field_tests/test_jsonfield.py | 110 +++ tests/inspectdb/models.py | 11 + tests/inspectdb/tests.py | 9 + tests/invalid_models_tests/test_models.py | 36 +- .../test_ordinary_fields.py | 48 +- tests/model_fields/models.py | 31 + tests/model_fields/test_jsonfield.py | 667 ++++++++++++++++++ tests/postgres_tests/fields.py | 7 +- .../migrations/0002_create_test_models.py | 17 +- tests/postgres_tests/models.py | 10 +- tests/postgres_tests/test_bulk_update.py | 3 +- tests/postgres_tests/test_introspection.py | 6 - tests/postgres_tests/test_json.py | 583 --------------- tests/postgres_tests/test_json_deprecation.py | 54 ++ tests/queries/models.py | 7 + tests/queries/test_bulk_update.py | 17 +- 54 files changed, 2240 insertions(+), 981 deletions(-) create mode 100644 django/db/models/fields/json.py create mode 100644 tests/forms_tests/field_tests/test_jsonfield.py create mode 100644 tests/model_fields/test_jsonfield.py delete mode 100644 tests/postgres_tests/test_json.py create mode 100644 tests/postgres_tests/test_json_deprecation.py diff --git a/AUTHORS b/AUTHORS index 41758bb38d..3b91851ae6 100644 --- a/AUTHORS +++ b/AUTHORS @@ -792,6 +792,7 @@ answer newbie questions, and generally made Django that much better: Ryan Rubin Ryno Mathee Sachin Jat + Sage M. Abdullah Sam Newman Sander Dijkhuis Sanket Saurav diff --git a/django/contrib/postgres/aggregates/general.py b/django/contrib/postgres/aggregates/general.py index 31dd52773b..12cba62701 100644 --- a/django/contrib/postgres/aggregates/general.py +++ b/django/contrib/postgres/aggregates/general.py @@ -1,5 +1,5 @@ -from django.contrib.postgres.fields import ArrayField, JSONField -from django.db.models import Aggregate, Value +from django.contrib.postgres.fields import ArrayField +from django.db.models import Aggregate, JSONField, Value from .mixins import OrderableAggMixin diff --git a/django/contrib/postgres/apps.py b/django/contrib/postgres/apps.py index 97475de6f7..25cfa1a814 100644 --- a/django/contrib/postgres/apps.py +++ b/django/contrib/postgres/apps.py @@ -47,7 +47,6 @@ class PostgresConfig(AppConfig): for conn in connections.all(): if conn.vendor == 'postgresql': conn.introspection.data_types_reverse.update({ - 3802: 'django.contrib.postgres.fields.JSONField', 3904: 'django.contrib.postgres.fields.IntegerRangeField', 3906: 'django.contrib.postgres.fields.DecimalRangeField', 3910: 'django.contrib.postgres.fields.DateTimeRangeField', diff --git a/django/contrib/postgres/fields/jsonb.py b/django/contrib/postgres/fields/jsonb.py index c402dd19d8..7f76b29a13 100644 --- a/django/contrib/postgres/fields/jsonb.py +++ b/django/contrib/postgres/fields/jsonb.py @@ -1,185 +1,43 @@ -import json +import warnings -from psycopg2.extras import Json - -from django.contrib.postgres import forms, lookups -from django.core import exceptions -from django.db.models import ( - Field, TextField, Transform, lookups as builtin_lookups, +from django.db.models import JSONField as BuiltinJSONField +from django.db.models.fields.json import ( + KeyTextTransform as BuiltinKeyTextTransform, + KeyTransform as BuiltinKeyTransform, ) -from django.db.models.fields.mixins import CheckFieldDefaultMixin -from django.utils.translation import gettext_lazy as _ +from django.utils.deprecation import RemovedInDjango40Warning __all__ = ['JSONField'] -class JsonAdapter(Json): - """ - Customized psycopg2.extras.Json to allow for a custom encoder. - """ - def __init__(self, adapted, dumps=None, encoder=None): - self.encoder = encoder - super().__init__(adapted, dumps=dumps) - - def dumps(self, obj): - options = {'cls': self.encoder} if self.encoder else {} - return json.dumps(obj, **options) - - -class JSONField(CheckFieldDefaultMixin, Field): - empty_strings_allowed = False - description = _('A JSON object') - default_error_messages = { - 'invalid': _("Value must be valid JSON."), +class JSONField(BuiltinJSONField): + system_check_deprecated_details = { + 'msg': ( + 'django.contrib.postgres.fields.JSONField is deprecated. Support ' + 'for it (except in historical migrations) will be removed in ' + 'Django 4.0.' + ), + 'hint': 'Use django.db.models.JSONField instead.', + 'id': 'fields.W904', } - _default_hint = ('dict', '{}') - - def __init__(self, verbose_name=None, name=None, encoder=None, **kwargs): - if encoder and not callable(encoder): - raise ValueError("The encoder parameter must be a callable object.") - self.encoder = encoder - super().__init__(verbose_name, name, **kwargs) - - def db_type(self, connection): - return 'jsonb' - - def deconstruct(self): - name, path, args, kwargs = super().deconstruct() - if self.encoder is not None: - kwargs['encoder'] = self.encoder - return name, path, args, kwargs - - def get_transform(self, name): - transform = super().get_transform(name) - if transform: - return transform - return KeyTransformFactory(name) - - def get_prep_value(self, value): - if value is not None: - return JsonAdapter(value, encoder=self.encoder) - return value - - def validate(self, value, model_instance): - super().validate(value, model_instance) - options = {'cls': self.encoder} if self.encoder else {} - try: - json.dumps(value, **options) - except TypeError: - raise exceptions.ValidationError( - self.error_messages['invalid'], - code='invalid', - params={'value': value}, - ) - - def value_to_string(self, obj): - return self.value_from_object(obj) - - def formfield(self, **kwargs): - return super().formfield(**{ - 'form_class': forms.JSONField, - **kwargs, - }) -JSONField.register_lookup(lookups.DataContains) -JSONField.register_lookup(lookups.ContainedBy) -JSONField.register_lookup(lookups.HasKey) -JSONField.register_lookup(lookups.HasKeys) -JSONField.register_lookup(lookups.HasAnyKeys) -JSONField.register_lookup(lookups.JSONExact) - - -class KeyTransform(Transform): - operator = '->' - nested_operator = '#>' - - def __init__(self, key_name, *args, **kwargs): - super().__init__(*args, **kwargs) - self.key_name = key_name - - def as_sql(self, compiler, connection): - key_transforms = [self.key_name] - previous = self.lhs - while isinstance(previous, KeyTransform): - key_transforms.insert(0, previous.key_name) - previous = previous.lhs - lhs, params = compiler.compile(previous) - if len(key_transforms) > 1: - return '(%s %s %%s)' % (lhs, self.nested_operator), params + [key_transforms] - try: - lookup = int(self.key_name) - except ValueError: - lookup = self.key_name - return '(%s %s %%s)' % (lhs, self.operator), tuple(params) + (lookup,) - - -class KeyTextTransform(KeyTransform): - operator = '->>' - nested_operator = '#>>' - output_field = TextField() - - -class KeyTransformTextLookupMixin: - """ - Mixin for combining with a lookup expecting a text lhs from a JSONField - key lookup. Make use of the ->> operator instead of casting key values to - text and performing the lookup on the resulting representation. - """ - def __init__(self, key_transform, *args, **kwargs): - assert isinstance(key_transform, KeyTransform) - key_text_transform = KeyTextTransform( - key_transform.key_name, *key_transform.source_expressions, **key_transform.extra +class KeyTransform(BuiltinKeyTransform): + def __init__(self, *args, **kwargs): + warnings.warn( + 'django.contrib.postgres.fields.jsonb.KeyTransform is deprecated ' + 'in favor of django.db.models.fields.json.KeyTransform.', + RemovedInDjango40Warning, stacklevel=2, ) - super().__init__(key_text_transform, *args, **kwargs) + super().__init__(*args, **kwargs) -class KeyTransformIExact(KeyTransformTextLookupMixin, builtin_lookups.IExact): - pass - - -class KeyTransformIContains(KeyTransformTextLookupMixin, builtin_lookups.IContains): - pass - - -class KeyTransformStartsWith(KeyTransformTextLookupMixin, builtin_lookups.StartsWith): - pass - - -class KeyTransformIStartsWith(KeyTransformTextLookupMixin, builtin_lookups.IStartsWith): - pass - - -class KeyTransformEndsWith(KeyTransformTextLookupMixin, builtin_lookups.EndsWith): - pass - - -class KeyTransformIEndsWith(KeyTransformTextLookupMixin, builtin_lookups.IEndsWith): - pass - - -class KeyTransformRegex(KeyTransformTextLookupMixin, builtin_lookups.Regex): - pass - - -class KeyTransformIRegex(KeyTransformTextLookupMixin, builtin_lookups.IRegex): - pass - - -KeyTransform.register_lookup(KeyTransformIExact) -KeyTransform.register_lookup(KeyTransformIContains) -KeyTransform.register_lookup(KeyTransformStartsWith) -KeyTransform.register_lookup(KeyTransformIStartsWith) -KeyTransform.register_lookup(KeyTransformEndsWith) -KeyTransform.register_lookup(KeyTransformIEndsWith) -KeyTransform.register_lookup(KeyTransformRegex) -KeyTransform.register_lookup(KeyTransformIRegex) - - -class KeyTransformFactory: - - def __init__(self, key_name): - self.key_name = key_name - - def __call__(self, *args, **kwargs): - return KeyTransform(self.key_name, *args, **kwargs) +class KeyTextTransform(BuiltinKeyTextTransform): + def __init__(self, *args, **kwargs): + warnings.warn( + 'django.contrib.postgres.fields.jsonb.KeyTextTransform is ' + 'deprecated in favor of ' + 'django.db.models.fields.json.KeyTextTransform.', + RemovedInDjango40Warning, stacklevel=2, + ) + super().__init__(*args, **kwargs) diff --git a/django/contrib/postgres/forms/jsonb.py b/django/contrib/postgres/forms/jsonb.py index 196d2b9096..ebc85efa6f 100644 --- a/django/contrib/postgres/forms/jsonb.py +++ b/django/contrib/postgres/forms/jsonb.py @@ -1,63 +1,16 @@ -import json +import warnings -from django import forms -from django.core.exceptions import ValidationError -from django.utils.translation import gettext_lazy as _ +from django.forms import JSONField as BuiltinJSONField +from django.utils.deprecation import RemovedInDjango40Warning __all__ = ['JSONField'] -class InvalidJSONInput(str): - pass - - -class JSONString(str): - pass - - -class JSONField(forms.CharField): - default_error_messages = { - 'invalid': _('ā€œ%(value)sā€ value must be valid JSON.'), - } - widget = forms.Textarea - - def to_python(self, value): - if self.disabled: - return value - if value in self.empty_values: - return None - elif isinstance(value, (list, dict, int, float, JSONString)): - return value - try: - converted = json.loads(value) - except json.JSONDecodeError: - raise ValidationError( - self.error_messages['invalid'], - code='invalid', - params={'value': value}, - ) - if isinstance(converted, str): - return JSONString(converted) - else: - return converted - - def bound_data(self, data, initial): - if self.disabled: - return initial - try: - return json.loads(data) - except json.JSONDecodeError: - return InvalidJSONInput(data) - - def prepare_value(self, value): - if isinstance(value, InvalidJSONInput): - return value - return json.dumps(value) - - def has_changed(self, initial, data): - if super().has_changed(initial, data): - return True - # For purposes of seeing whether something has changed, True isn't the - # same as 1 and the order of keys doesn't matter. - data = self.to_python(data) - return json.dumps(initial, sort_keys=True) != json.dumps(data, sort_keys=True) +class JSONField(BuiltinJSONField): + def __init__(self, *args, **kwargs): + warnings.warn( + 'django.contrib.postgres.forms.JSONField is deprecated in favor ' + 'of django.forms.JSONField.', + RemovedInDjango40Warning, stacklevel=2, + ) + super().__init__(*args, **kwargs) diff --git a/django/contrib/postgres/lookups.py b/django/contrib/postgres/lookups.py index 360e0c6a31..28d8590e1d 100644 --- a/django/contrib/postgres/lookups.py +++ b/django/contrib/postgres/lookups.py @@ -1,5 +1,5 @@ from django.db.models import Transform -from django.db.models.lookups import Exact, PostgresOperatorLookup +from django.db.models.lookups import PostgresOperatorLookup from .search import SearchVector, SearchVectorExact, SearchVectorField @@ -58,12 +58,3 @@ class SearchLookup(SearchVectorExact): class TrigramSimilar(PostgresOperatorLookup): lookup_name = 'trigram_similar' postgres_operator = '%%' - - -class JSONExact(Exact): - can_use_none_as_rhs = True - - def process_rhs(self, compiler, connection): - result = super().process_rhs(compiler, connection) - # Treat None lookup values as null. - return ("'null'", []) if result == ('%s', [None]) else result diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index a8f55f966c..33eeff171d 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -300,6 +300,15 @@ class BaseDatabaseFeatures: # Does the backend support boolean expressions in the SELECT clause? supports_boolean_expr_in_select_clause = True + # Does the backend support JSONField? + supports_json_field = True + # Can the backend introspect a JSONField? + can_introspect_json_field = True + # Does the backend support primitives in JSONField? + supports_primitives_in_json_field = True + # Is there a true datatype for JSON? + has_native_json_field = False + def __init__(self, connection): self.connection = connection diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index 1dbcee4637..6d0f5c68b3 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -159,6 +159,13 @@ class BaseDatabaseOperations: """ return self.date_extract_sql(lookup_type, field_name) + def json_cast_text_sql(self, field_name): + """Return the SQL to cast a JSON value to text value.""" + raise NotImplementedError( + 'subclasses of BaseDatabaseOperations may require a ' + 'json_cast_text_sql() method' + ) + def deferrable_sql(self): """ Return the SQL to make a constraint "initially deferred" during a diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 44560ccdaf..8792f3c7c5 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -118,6 +118,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): 'BigIntegerField': 'bigint', 'IPAddressField': 'char(15)', 'GenericIPAddressField': 'char(39)', + 'JSONField': 'json', 'NullBooleanField': 'bool', 'OneToOneField': 'integer', 'PositiveBigIntegerField': 'bigint UNSIGNED', @@ -341,11 +342,16 @@ class DatabaseWrapper(BaseDatabaseWrapper): @cached_property def data_type_check_constraints(self): if self.features.supports_column_check_constraints: - return { + check_constraints = { 'PositiveBigIntegerField': '`%(column)s` >= 0', 'PositiveIntegerField': '`%(column)s` >= 0', 'PositiveSmallIntegerField': '`%(column)s` >= 0', } + if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3): + # MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as + # a check constraint. + check_constraints['JSONField'] = 'JSON_VALID(`%(column)s`)' + return check_constraints return {} @cached_property diff --git a/django/db/backends/mysql/features.py b/django/db/backends/mysql/features.py index 8a2a64c5e4..faa84f7d7c 100644 --- a/django/db/backends/mysql/features.py +++ b/django/db/backends/mysql/features.py @@ -160,3 +160,15 @@ class DatabaseFeatures(BaseDatabaseFeatures): def supports_default_in_lead_lag(self): # To be added in https://jira.mariadb.org/browse/MDEV-12981. return not self.connection.mysql_is_mariadb + + @cached_property + def supports_json_field(self): + if self.connection.mysql_is_mariadb: + return self.connection.mysql_version >= (10, 2, 7) + return self.connection.mysql_version >= (5, 7, 8) + + @cached_property + def can_introspect_json_field(self): + if self.connection.mysql_is_mariadb: + return self.supports_json_field and self.can_introspect_check_constraints + return self.supports_json_field diff --git a/django/db/backends/mysql/introspection.py b/django/db/backends/mysql/introspection.py index 50160ba590..1a104c7810 100644 --- a/django/db/backends/mysql/introspection.py +++ b/django/db/backends/mysql/introspection.py @@ -9,7 +9,7 @@ from django.db.backends.base.introspection import ( from django.db.models import Index from django.utils.datastructures import OrderedSet -FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('extra', 'is_unsigned')) +FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('extra', 'is_unsigned', 'has_json_constraint')) InfoLine = namedtuple('InfoLine', 'col_name data_type max_len num_prec num_scale extra column_default is_unsigned') @@ -24,6 +24,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): FIELD_TYPE.DOUBLE: 'FloatField', FIELD_TYPE.FLOAT: 'FloatField', FIELD_TYPE.INT24: 'IntegerField', + FIELD_TYPE.JSON: 'JSONField', FIELD_TYPE.LONG: 'IntegerField', FIELD_TYPE.LONGLONG: 'BigIntegerField', FIELD_TYPE.SHORT: 'SmallIntegerField', @@ -53,6 +54,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): return 'PositiveIntegerField' elif field_type == 'SmallIntegerField': return 'PositiveSmallIntegerField' + # JSON data type is an alias for LONGTEXT in MariaDB, use check + # constraints clauses to introspect JSONField. + if description.has_json_constraint: + return 'JSONField' return field_type def get_table_list(self, cursor): @@ -66,6 +71,19 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): Return a description of the table with the DB-API cursor.description interface." """ + json_constraints = {} + if self.connection.mysql_is_mariadb and self.connection.features.can_introspect_json_field: + # JSON data type is an alias for LONGTEXT in MariaDB, select + # JSON_VALID() constraints to introspect JSONField. + cursor.execute(""" + SELECT c.constraint_name AS column_name + FROM information_schema.check_constraints AS c + WHERE + c.table_name = %s AND + LOWER(c.check_clause) = 'json_valid(`' + LOWER(c.constraint_name) + '`)' AND + c.constraint_schema = DATABASE() + """, [table_name]) + json_constraints = {row[0] for row in cursor.fetchall()} # information_schema database gives more accurate results for some figures: # - varchar length returned by cursor.description is an internal length, # not visible length (#5725) @@ -100,6 +118,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): info.column_default, info.extra, info.is_unsigned, + line[0] in json_constraints, )) return fields diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py index d01e3bef6b..bc04739f0d 100644 --- a/django/db/backends/mysql/operations.py +++ b/django/db/backends/mysql/operations.py @@ -368,3 +368,13 @@ class DatabaseOperations(BaseDatabaseOperations): def insert_statement(self, ignore_conflicts=False): return 'INSERT IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts) + + def lookup_cast(self, lookup_type, internal_type=None): + lookup = '%s' + if internal_type == 'JSONField': + if self.connection.mysql_is_mariadb or lookup_type in ( + 'iexact', 'contains', 'icontains', 'startswith', 'istartswith', + 'endswith', 'iendswith', 'regex', 'iregex', + ): + lookup = 'JSON_UNQUOTE(%s)' + return lookup diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index e9ec2bac51..e104530228 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -123,6 +123,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): 'FilePathField': 'NVARCHAR2(%(max_length)s)', 'FloatField': 'DOUBLE PRECISION', 'IntegerField': 'NUMBER(11)', + 'JSONField': 'NCLOB', 'BigIntegerField': 'NUMBER(19)', 'IPAddressField': 'VARCHAR2(15)', 'GenericIPAddressField': 'VARCHAR2(39)', @@ -141,6 +142,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): } data_type_check_constraints = { 'BooleanField': '%(qn_column)s IN (0,1)', + 'JSONField': '%(qn_column)s IS JSON', 'NullBooleanField': '%(qn_column)s IN (0,1)', 'PositiveBigIntegerField': '%(qn_column)s >= 0', 'PositiveIntegerField': '%(qn_column)s >= 0', diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py index 3782874512..bae09559ce 100644 --- a/django/db/backends/oracle/features.py +++ b/django/db/backends/oracle/features.py @@ -60,3 +60,4 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_slicing_ordering_in_compound = True allows_multiple_constraints_on_same_fields = False supports_boolean_expr_in_select_clause = False + supports_primitives_in_json_field = False diff --git a/django/db/backends/oracle/introspection.py b/django/db/backends/oracle/introspection.py index 2322ae0b5d..3fab497b2a 100644 --- a/django/db/backends/oracle/introspection.py +++ b/django/db/backends/oracle/introspection.py @@ -7,7 +7,7 @@ from django.db.backends.base.introspection import ( BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo, ) -FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('is_autofield',)) +FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('is_autofield', 'is_json')) class DatabaseIntrospection(BaseDatabaseIntrospection): @@ -45,6 +45,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): return 'IntegerField' elif scale == -127: return 'FloatField' + elif data_type == cx_Oracle.NCLOB and description.is_json: + return 'JSONField' return super().get_field_type(data_type, description) @@ -83,12 +85,23 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): CASE WHEN identity_column = 'YES' THEN 1 ELSE 0 - END as is_autofield + END as is_autofield, + CASE + WHEN EXISTS ( + SELECT 1 + FROM user_json_columns + WHERE + user_json_columns.table_name = user_tab_cols.table_name AND + user_json_columns.column_name = user_tab_cols.column_name + ) + THEN 1 + ELSE 0 + END as is_json FROM user_tab_cols WHERE table_name = UPPER(%s)""", [table_name]) field_map = { - column: (internal_size, default if default != 'NULL' else None, is_autofield) - for column, default, internal_size, is_autofield in cursor.fetchall() + column: (internal_size, default if default != 'NULL' else None, is_autofield, is_json) + for column, default, internal_size, is_autofield, is_json in cursor.fetchall() } self.cache_bust_counter += 1 cursor.execute("SELECT * FROM {} WHERE ROWNUM < 2 AND {} > 0".format( @@ -97,11 +110,11 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): description = [] for desc in cursor.description: name = desc[0] - internal_size, default, is_autofield = field_map[name] + internal_size, default, is_autofield, is_json = field_map[name] name = name % {} # cx_Oracle, for some reason, doubles percent signs. description.append(FieldInfo( self.identifier_converter(name), *desc[1:3], internal_size, desc[4] or 0, - desc[5] or 0, *desc[6:], default, is_autofield, + desc[5] or 0, *desc[6:], default, is_autofield, is_json, )) return description diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index 6f4121425f..9dc28c84cd 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -176,7 +176,7 @@ END; def get_db_converters(self, expression): converters = super().get_db_converters(expression) internal_type = expression.output_field.get_internal_type() - if internal_type == 'TextField': + if internal_type in ['JSONField', 'TextField']: converters.append(self.convert_textfield_value) elif internal_type == 'BinaryField': converters.append(self.convert_binaryfield_value) @@ -269,7 +269,7 @@ END; return tuple(columns) def field_cast_sql(self, db_type, internal_type): - if db_type and db_type.endswith('LOB'): + if db_type and db_type.endswith('LOB') and internal_type != 'JSONField': return "DBMS_LOB.SUBSTR(%s)" else: return "%s" @@ -307,6 +307,8 @@ END; def lookup_cast(self, lookup_type, internal_type=None): if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'): return "UPPER(%s)" + if internal_type == 'JSONField' and lookup_type == 'exact': + return 'DBMS_LOB.SUBSTR(%s)' return "%s" def max_in_list_size(self): diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index 192316d7fb..ed911a91da 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -86,6 +86,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): 'BigIntegerField': 'bigint', 'IPAddressField': 'inet', 'GenericIPAddressField': 'inet', + 'JSONField': 'jsonb', 'NullBooleanField': 'boolean', 'OneToOneField': 'integer', 'PositiveBigIntegerField': 'bigint', diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index 3b4199fa78..00a8009cf2 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -12,6 +12,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): has_real_datatype = True has_native_uuid_field = True has_native_duration_field = True + has_native_json_field = True can_defer_constraint_checks = True has_select_for_update = True has_select_for_update_nowait = True diff --git a/django/db/backends/postgresql/introspection.py b/django/db/backends/postgresql/introspection.py index beec8619cc..dee305cc06 100644 --- a/django/db/backends/postgresql/introspection.py +++ b/django/db/backends/postgresql/introspection.py @@ -26,6 +26,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): 1266: 'TimeField', 1700: 'DecimalField', 2950: 'UUIDField', + 3802: 'JSONField', } ignored_tables = [] diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 70880d4179..c67062a4a7 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -74,6 +74,9 @@ class DatabaseOperations(BaseDatabaseOperations): def time_trunc_sql(self, lookup_type, field_name): return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name) + def json_cast_text_sql(self, field_name): + return '(%s)::text' % field_name + def deferrable_sql(self): return " DEFERRABLE INITIALLY DEFERRED" diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 26968475bf..31e8a55a43 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -5,6 +5,7 @@ import datetime import decimal import functools import hashlib +import json import math import operator import re @@ -101,6 +102,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): 'BigIntegerField': 'bigint', 'IPAddressField': 'char(15)', 'GenericIPAddressField': 'char(39)', + 'JSONField': 'text', 'NullBooleanField': 'bool', 'OneToOneField': 'integer', 'PositiveBigIntegerField': 'bigint unsigned', @@ -115,6 +117,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): } data_type_check_constraints = { 'PositiveBigIntegerField': '"%(column)s" >= 0', + 'JSONField': '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)', 'PositiveIntegerField': '"%(column)s" >= 0', 'PositiveSmallIntegerField': '"%(column)s" >= 0', } @@ -233,6 +236,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): create_deterministic_function('DEGREES', 1, none_guard(math.degrees)) create_deterministic_function('EXP', 1, none_guard(math.exp)) create_deterministic_function('FLOOR', 1, none_guard(math.floor)) + create_deterministic_function('JSON_CONTAINS', 2, _sqlite_json_contains) create_deterministic_function('LN', 1, none_guard(math.log)) create_deterministic_function('LOG', 2, none_guard(lambda x, y: math.log(y, x))) create_deterministic_function('LPAD', 3, _sqlite_lpad) @@ -598,3 +602,11 @@ def _sqlite_lpad(text, length, fill_text): @none_guard def _sqlite_rpad(text, length, fill_text): return (text + fill_text * length)[:length] + + +@none_guard +def _sqlite_json_contains(haystack, needle): + target, candidate = json.loads(haystack), json.loads(needle) + if isinstance(target, dict) and isinstance(candidate, dict): + return target.items() >= candidate.items() + return target == candidate diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py index 817b1067e3..1b6f99a58c 100644 --- a/django/db/backends/sqlite3/features.py +++ b/django/db/backends/sqlite3/features.py @@ -1,4 +1,9 @@ +import operator + +from django.db import transaction from django.db.backends.base.features import BaseDatabaseFeatures +from django.db.utils import OperationalError +from django.utils.functional import cached_property from .base import Database @@ -45,3 +50,14 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_aggregate_filter_clause = Database.sqlite_version_info >= (3, 30, 1) supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0) order_by_nulls_first = True + + @cached_property + def supports_json_field(self): + try: + with self.connection.cursor() as cursor, transaction.atomic(): + cursor.execute('SELECT JSON(\'{"a": "b"}\')') + except OperationalError: + return False + return True + + can_introspect_json_field = property(operator.attrgetter('supports_json_field')) diff --git a/django/db/backends/sqlite3/introspection.py b/django/db/backends/sqlite3/introspection.py index a203c454df..992e925e10 100644 --- a/django/db/backends/sqlite3/introspection.py +++ b/django/db/backends/sqlite3/introspection.py @@ -9,7 +9,7 @@ from django.db.backends.base.introspection import ( from django.db.models import Index from django.utils.regex_helper import _lazy_re_compile -FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('pk',)) +FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('pk', 'has_json_constraint')) field_size_re = _lazy_re_compile(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$') @@ -63,6 +63,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # No support for BigAutoField or SmallAutoField as SQLite treats # all integer primary keys as signed 64-bit integers. return 'AutoField' + if description.has_json_constraint: + return 'JSONField' return field_type def get_table_list(self, cursor): @@ -81,12 +83,28 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): interface. """ cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(table_name)) + table_info = cursor.fetchall() + json_columns = set() + if self.connection.features.can_introspect_json_field: + for line in table_info: + column = line[1] + json_constraint_sql = '%%json_valid("%s")%%' % column + has_json_constraint = cursor.execute(""" + SELECT sql + FROM sqlite_master + WHERE + type = 'table' AND + name = %s AND + sql LIKE %s + """, [table_name, json_constraint_sql]).fetchone() + if has_json_constraint: + json_columns.add(column) return [ FieldInfo( name, data_type, None, get_field_size(data_type), None, None, - not notnull, default, pk == 1, + not notnull, default, pk == 1, name in json_columns ) - for cid, name, data_type, notnull, default, pk in cursor.fetchall() + for cid, name, data_type, notnull, default, pk in table_info ] def get_sequences(self, cursor, table_name, table_fields=()): diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 7af6e60c51..a583af2aff 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -18,6 +18,7 @@ from django.db.models.expressions import ( from django.db.models.fields import * # NOQA from django.db.models.fields import __all__ as fields_all from django.db.models.fields.files import FileField, ImageField +from django.db.models.fields.json import JSONField from django.db.models.fields.proxy import OrderWrt from django.db.models.indexes import * # NOQA from django.db.models.indexes import __all__ as indexes_all @@ -43,9 +44,9 @@ __all__ += [ 'Func', 'OrderBy', 'OuterRef', 'RowRange', 'Subquery', 'Value', 'ValueRange', 'When', 'Window', 'WindowFrame', - 'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager', - 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model', - 'FilteredRelation', + 'FileField', 'ImageField', 'JSONField', 'OrderWrt', 'Lookup', 'Transform', + 'Manager', 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', + 'DEFERRED', 'Model', 'FilteredRelation', 'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField', 'ForeignObjectRel', 'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel', ] diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 08c2a18d94..0fd69059ee 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -496,6 +496,8 @@ class Field(RegisterLookupMixin): path = path.replace("django.db.models.fields.related", "django.db.models") elif path.startswith("django.db.models.fields.files"): path = path.replace("django.db.models.fields.files", "django.db.models") + elif path.startswith('django.db.models.fields.json'): + path = path.replace('django.db.models.fields.json', 'django.db.models') elif path.startswith("django.db.models.fields.proxy"): path = path.replace("django.db.models.fields.proxy", "django.db.models") elif path.startswith("django.db.models.fields"): diff --git a/django/db/models/fields/json.py b/django/db/models/fields/json.py new file mode 100644 index 0000000000..edc5441799 --- /dev/null +++ b/django/db/models/fields/json.py @@ -0,0 +1,525 @@ +import json + +from django import forms +from django.core import checks, exceptions +from django.db import NotSupportedError, connections, router +from django.db.models import lookups +from django.db.models.lookups import PostgresOperatorLookup, Transform +from django.utils.translation import gettext_lazy as _ + +from . import Field +from .mixins import CheckFieldDefaultMixin + +__all__ = ['JSONField'] + + +class JSONField(CheckFieldDefaultMixin, Field): + empty_strings_allowed = False + description = _('A JSON object') + default_error_messages = { + 'invalid': _('Value must be valid JSON.'), + } + _default_hint = ('dict', '{}') + + def __init__( + self, verbose_name=None, name=None, encoder=None, decoder=None, + **kwargs, + ): + if encoder and not callable(encoder): + raise ValueError('The encoder parameter must be a callable object.') + if decoder and not callable(decoder): + raise ValueError('The decoder parameter must be a callable object.') + self.encoder = encoder + self.decoder = decoder + super().__init__(verbose_name, name, **kwargs) + + def check(self, **kwargs): + errors = super().check(**kwargs) + databases = kwargs.get('databases') or [] + errors.extend(self._check_supported(databases)) + return errors + + def _check_supported(self, databases): + errors = [] + for db in databases: + if not router.allow_migrate_model(db, self.model): + continue + connection = connections[db] + if not ( + 'supports_json_field' in self.model._meta.required_db_features or + connection.features.supports_json_field + ): + errors.append( + checks.Error( + '%s does not support JSONFields.' + % connection.display_name, + obj=self.model, + id='fields.E180', + ) + ) + return errors + + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + if self.encoder is not None: + kwargs['encoder'] = self.encoder + if self.decoder is not None: + kwargs['decoder'] = self.decoder + return name, path, args, kwargs + + def from_db_value(self, value, expression, connection): + if value is None: + return value + if connection.features.has_native_json_field and self.decoder is None: + return value + try: + return json.loads(value, cls=self.decoder) + except json.JSONDecodeError: + return value + + def get_internal_type(self): + return 'JSONField' + + def get_prep_value(self, value): + if value is None: + return value + return json.dumps(value, cls=self.encoder) + + def get_transform(self, name): + transform = super().get_transform(name) + if transform: + return transform + return KeyTransformFactory(name) + + def select_format(self, compiler, sql, params): + if ( + compiler.connection.features.has_native_json_field and + self.decoder is not None + ): + return compiler.connection.ops.json_cast_text_sql(sql), params + return super().select_format(compiler, sql, params) + + def validate(self, value, model_instance): + super().validate(value, model_instance) + try: + json.dumps(value, cls=self.encoder) + except TypeError: + raise exceptions.ValidationError( + self.error_messages['invalid'], + code='invalid', + params={'value': value}, + ) + + def value_to_string(self, obj): + return self.value_from_object(obj) + + def formfield(self, **kwargs): + return super().formfield(**{ + 'form_class': forms.JSONField, + 'encoder': self.encoder, + 'decoder': self.decoder, + **kwargs, + }) + + +def compile_json_path(key_transforms, include_root=True): + path = ['$'] if include_root else [] + for key_transform in key_transforms: + try: + num = int(key_transform) + except ValueError: # non-integer + path.append('.') + path.append(json.dumps(key_transform)) + else: + path.append('[%s]' % num) + return ''.join(path) + + +class DataContains(PostgresOperatorLookup): + lookup_name = 'contains' + postgres_operator = '@>' + + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) + params = tuple(lhs_params) + tuple(rhs_params) + return 'JSON_CONTAINS(%s, %s)' % (lhs, rhs), params + + def as_oracle(self, compiler, connection): + if isinstance(self.rhs, KeyTransform): + return HasKey(self.lhs, self.rhs).as_oracle(compiler, connection) + lhs, lhs_params = self.process_lhs(compiler, connection) + params = tuple(lhs_params) + sql = ( + "JSON_QUERY(%s, '$%s' WITH WRAPPER) = " + "JSON_QUERY('%s', '$.value' WITH WRAPPER)" + ) + rhs = json.loads(self.rhs) + if isinstance(rhs, dict): + if not rhs: + return "DBMS_LOB.SUBSTR(%s) LIKE '{%%%%}'" % lhs, params + return ' AND '.join([ + sql % ( + lhs, '.%s' % json.dumps(key), json.dumps({'value': value}), + ) for key, value in rhs.items() + ]), params + return sql % (lhs, '', json.dumps({'value': rhs})), params + + +class ContainedBy(PostgresOperatorLookup): + lookup_name = 'contained_by' + postgres_operator = '<@' + + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) + params = tuple(rhs_params) + tuple(lhs_params) + return 'JSON_CONTAINS(%s, %s)' % (rhs, lhs), params + + def as_oracle(self, compiler, connection): + raise NotSupportedError('contained_by lookup is not supported on Oracle.') + + +class HasKeyLookup(PostgresOperatorLookup): + logical_operator = None + + def as_sql(self, compiler, connection, template=None): + # Process JSON path from the left-hand side. + if isinstance(self.lhs, KeyTransform): + lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(compiler, connection) + lhs_json_path = compile_json_path(lhs_key_transforms) + else: + lhs, lhs_params = self.process_lhs(compiler, connection) + lhs_json_path = '$' + sql = template % lhs + # Process JSON path from the right-hand side. + rhs = self.rhs + rhs_params = [] + if not isinstance(rhs, (list, tuple)): + rhs = [rhs] + for key in rhs: + if isinstance(key, KeyTransform): + *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection) + else: + rhs_key_transforms = [key] + rhs_params.append('%s%s' % ( + lhs_json_path, + compile_json_path(rhs_key_transforms, include_root=False), + )) + # Add condition for each key. + if self.logical_operator: + sql = '(%s)' % self.logical_operator.join([sql] * len(rhs_params)) + return sql, tuple(lhs_params) + tuple(rhs_params) + + def as_mysql(self, compiler, connection): + return self.as_sql(compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)") + + def as_oracle(self, compiler, connection): + sql, params = self.as_sql(compiler, connection, template="JSON_EXISTS(%s, '%%s')") + # Add paths directly into SQL because path expressions cannot be passed + # as bind variables on Oracle. + return sql % tuple(params), [] + + def as_postgresql(self, compiler, connection): + if isinstance(self.rhs, KeyTransform): + *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection) + for key in rhs_key_transforms[:-1]: + self.lhs = KeyTransform(key, self.lhs) + self.rhs = rhs_key_transforms[-1] + return super().as_postgresql(compiler, connection) + + def as_sqlite(self, compiler, connection): + return self.as_sql(compiler, connection, template='JSON_TYPE(%s, %%s) IS NOT NULL') + + +class HasKey(HasKeyLookup): + lookup_name = 'has_key' + postgres_operator = '?' + prepare_rhs = False + + +class HasKeys(HasKeyLookup): + lookup_name = 'has_keys' + postgres_operator = '?&' + logical_operator = ' AND ' + + def get_prep_lookup(self): + return [str(item) for item in self.rhs] + + +class HasAnyKeys(HasKeys): + lookup_name = 'has_any_keys' + postgres_operator = '?|' + logical_operator = ' OR ' + + +class JSONExact(lookups.Exact): + can_use_none_as_rhs = True + + def process_lhs(self, compiler, connection): + lhs, lhs_params = super().process_lhs(compiler, connection) + if connection.vendor == 'sqlite': + rhs, rhs_params = super().process_rhs(compiler, connection) + if rhs == '%s' and rhs_params == [None]: + # Use JSON_TYPE instead of JSON_EXTRACT for NULLs. + lhs = "JSON_TYPE(%s, '$')" % lhs + return lhs, lhs_params + + def process_rhs(self, compiler, connection): + rhs, rhs_params = super().process_rhs(compiler, connection) + # Treat None lookup values as null. + if rhs == '%s' and rhs_params == [None]: + rhs_params = ['null'] + if connection.vendor == 'mysql': + func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params) + rhs = rhs % tuple(func) + return rhs, rhs_params + + +JSONField.register_lookup(DataContains) +JSONField.register_lookup(ContainedBy) +JSONField.register_lookup(HasKey) +JSONField.register_lookup(HasKeys) +JSONField.register_lookup(HasAnyKeys) +JSONField.register_lookup(JSONExact) + + +class KeyTransform(Transform): + postgres_operator = '->' + postgres_nested_operator = '#>' + + def __init__(self, key_name, *args, **kwargs): + super().__init__(*args, **kwargs) + self.key_name = str(key_name) + + def preprocess_lhs(self, compiler, connection, lhs_only=False): + if not lhs_only: + key_transforms = [self.key_name] + previous = self.lhs + while isinstance(previous, KeyTransform): + if not lhs_only: + key_transforms.insert(0, previous.key_name) + previous = previous.lhs + lhs, params = compiler.compile(previous) + if connection.vendor == 'oracle': + # Escape string-formatting. + key_transforms = [key.replace('%', '%%') for key in key_transforms] + return (lhs, params, key_transforms) if not lhs_only else (lhs, params) + + def as_mysql(self, compiler, connection): + lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) + json_path = compile_json_path(key_transforms) + return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,) + + def as_oracle(self, compiler, connection): + lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) + json_path = compile_json_path(key_transforms) + return ( + "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))" % + ((lhs, json_path) * 2) + ), tuple(params) * 2 + + def as_postgresql(self, compiler, connection): + lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) + if len(key_transforms) > 1: + return '(%s %s %%s)' % (lhs, self.postgres_nested_operator), params + [key_transforms] + try: + lookup = int(self.key_name) + except ValueError: + lookup = self.key_name + return '(%s %s %%s)' % (lhs, self.postgres_operator), tuple(params) + (lookup,) + + def as_sqlite(self, compiler, connection): + lhs, params, key_transforms = self.preprocess_lhs(compiler, connection) + json_path = compile_json_path(key_transforms) + return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,) + + +class KeyTextTransform(KeyTransform): + postgres_operator = '->>' + postgres_nested_operator = '#>>' + + +class KeyTransformTextLookupMixin: + """ + Mixin for combining with a lookup expecting a text lhs from a JSONField + key lookup. On PostgreSQL, make use of the ->> operator instead of casting + key values to text and performing the lookup on the resulting + representation. + """ + def __init__(self, key_transform, *args, **kwargs): + if not isinstance(key_transform, KeyTransform): + raise TypeError( + 'Transform should be an instance of KeyTransform in order to ' + 'use this lookup.' + ) + key_text_transform = KeyTextTransform( + key_transform.key_name, *key_transform.source_expressions, + **key_transform.extra, + ) + super().__init__(key_text_transform, *args, **kwargs) + + +class CaseInsensitiveMixin: + """ + Mixin to allow case-insensitive comparison of JSON values on MySQL. + MySQL handles strings used in JSON context using the utf8mb4_bin collation. + Because utf8mb4_bin is a binary collation, comparison of JSON values is + case-sensitive. + """ + def process_lhs(self, compiler, connection): + lhs, lhs_params = super().process_lhs(compiler, connection) + if connection.vendor == 'mysql': + return 'LOWER(%s)' % lhs, lhs_params + return lhs, lhs_params + + def process_rhs(self, compiler, connection): + rhs, rhs_params = super().process_rhs(compiler, connection) + if connection.vendor == 'mysql': + return 'LOWER(%s)' % rhs, rhs_params + return rhs, rhs_params + + +class KeyTransformIsNull(lookups.IsNull): + # key__isnull=False is the same as has_key='key' + def as_oracle(self, compiler, connection): + if not self.rhs: + return HasKey(self.lhs.lhs, self.lhs.key_name).as_oracle(compiler, connection) + return super().as_sql(compiler, connection) + + def as_sqlite(self, compiler, connection): + if not self.rhs: + return HasKey(self.lhs.lhs, self.lhs.key_name).as_sqlite(compiler, connection) + return super().as_sql(compiler, connection) + + +class KeyTransformExact(JSONExact): + def process_lhs(self, compiler, connection): + lhs, lhs_params = super().process_lhs(compiler, connection) + if connection.vendor == 'sqlite': + rhs, rhs_params = super().process_rhs(compiler, connection) + if rhs == '%s' and rhs_params == ['null']: + lhs, _ = self.lhs.preprocess_lhs(compiler, connection, lhs_only=True) + lhs = 'JSON_TYPE(%s, %%s)' % lhs + return lhs, lhs_params + + def process_rhs(self, compiler, connection): + if isinstance(self.rhs, KeyTransform): + return super(lookups.Exact, self).process_rhs(compiler, connection) + rhs, rhs_params = super().process_rhs(compiler, connection) + if connection.vendor == 'oracle': + func = [] + for value in rhs_params: + value = json.loads(value) + function = 'JSON_QUERY' if isinstance(value, (list, dict)) else 'JSON_VALUE' + func.append("%s('%s', '$.value')" % ( + function, + json.dumps({'value': value}), + )) + rhs = rhs % tuple(func) + rhs_params = [] + elif connection.vendor == 'sqlite': + func = ["JSON_EXTRACT(%s, '$')" if value != 'null' else '%s' for value in rhs_params] + rhs = rhs % tuple(func) + return rhs, rhs_params + + def as_oracle(self, compiler, connection): + rhs, rhs_params = super().process_rhs(compiler, connection) + if rhs_params == ['null']: + # Field has key and it's NULL. + has_key_expr = HasKey(self.lhs.lhs, self.lhs.key_name) + has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection) + is_null_expr = self.lhs.get_lookup('isnull')(self.lhs, True) + is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection) + return ( + '%s AND %s' % (has_key_sql, is_null_sql), + tuple(has_key_params) + tuple(is_null_params), + ) + return super().as_sql(compiler, connection) + + +class KeyTransformIExact(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact): + pass + + +class KeyTransformIContains(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains): + pass + + +class KeyTransformContains(KeyTransformTextLookupMixin, lookups.Contains): + pass + + +class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith): + pass + + +class KeyTransformIStartsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith): + pass + + +class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith): + pass + + +class KeyTransformIEndsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith): + pass + + +class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex): + pass + + +class KeyTransformIRegex(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex): + pass + + +class KeyTransformNumericLookupMixin: + def process_rhs(self, compiler, connection): + rhs, rhs_params = super().process_rhs(compiler, connection) + if not connection.features.has_native_json_field: + rhs_params = [json.loads(value) for value in rhs_params] + return rhs, rhs_params + + +class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan): + pass + + +class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual): + pass + + +class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan): + pass + + +class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual): + pass + + +KeyTransform.register_lookup(KeyTransformExact) +KeyTransform.register_lookup(KeyTransformIExact) +KeyTransform.register_lookup(KeyTransformIsNull) +KeyTransform.register_lookup(KeyTransformContains) +KeyTransform.register_lookup(KeyTransformIContains) +KeyTransform.register_lookup(KeyTransformStartsWith) +KeyTransform.register_lookup(KeyTransformIStartsWith) +KeyTransform.register_lookup(KeyTransformEndsWith) +KeyTransform.register_lookup(KeyTransformIEndsWith) +KeyTransform.register_lookup(KeyTransformRegex) +KeyTransform.register_lookup(KeyTransformIRegex) + +KeyTransform.register_lookup(KeyTransformLt) +KeyTransform.register_lookup(KeyTransformLte) +KeyTransform.register_lookup(KeyTransformGt) +KeyTransform.register_lookup(KeyTransformGte) + + +class KeyTransformFactory: + + def __init__(self, key_name): + self.key_name = key_name + + def __call__(self, *args, **kwargs): + return KeyTransform(self.key_name, *args, **kwargs) diff --git a/django/db/models/functions/comparison.py b/django/db/models/functions/comparison.py index 24c3c4b4b8..6dc235bffb 100644 --- a/django/db/models/functions/comparison.py +++ b/django/db/models/functions/comparison.py @@ -29,8 +29,14 @@ class Cast(Func): return self.as_sql(compiler, connection, **extra_context) def as_mysql(self, compiler, connection, **extra_context): + template = None + output_type = self.output_field.get_internal_type() # MySQL doesn't support explicit cast to float. - template = '(%(expressions)s + 0.0)' if self.output_field.get_internal_type() == 'FloatField' else None + if output_type == 'FloatField': + template = '(%(expressions)s + 0.0)' + # MariaDB doesn't support explicit cast to JSON. + elif output_type == 'JSONField' and connection.mysql_is_mariadb: + template = "JSON_EXTRACT(%(expressions)s, '$')" return self.as_sql(compiler, connection, template=template, **extra_context) def as_postgresql(self, compiler, connection, **extra_context): @@ -39,6 +45,13 @@ class Cast(Func): # expression. return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s', **extra_context) + def as_oracle(self, compiler, connection, **extra_context): + if self.output_field.get_internal_type() == 'JSONField': + # Oracle doesn't support explicit cast to JSON. + template = "JSON_QUERY(%(expressions)s, '$')" + return super().as_sql(compiler, connection, template=template, **extra_context) + return self.as_sql(compiler, connection, **extra_context) + class Coalesce(Func): """Return, from left to right, the first non-null expression.""" diff --git a/django/forms/fields.py b/django/forms/fields.py index c5374c7e9d..36dad72704 100644 --- a/django/forms/fields.py +++ b/django/forms/fields.py @@ -4,6 +4,7 @@ Field classes. import copy import datetime +import json import math import operator import os @@ -21,8 +22,8 @@ from django.forms.widgets import ( FILE_INPUT_CONTRADICTION, CheckboxInput, ClearableFileInput, DateInput, DateTimeInput, EmailInput, FileInput, HiddenInput, MultipleHiddenInput, NullBooleanSelect, NumberInput, Select, SelectMultiple, - SplitDateTimeWidget, SplitHiddenDateTimeWidget, TextInput, TimeInput, - URLInput, + SplitDateTimeWidget, SplitHiddenDateTimeWidget, Textarea, TextInput, + TimeInput, URLInput, ) from django.utils import formats from django.utils.dateparse import parse_datetime, parse_duration @@ -38,7 +39,8 @@ __all__ = ( 'BooleanField', 'NullBooleanField', 'ChoiceField', 'MultipleChoiceField', 'ComboField', 'MultiValueField', 'FloatField', 'DecimalField', 'SplitDateTimeField', 'GenericIPAddressField', 'FilePathField', - 'SlugField', 'TypedChoiceField', 'TypedMultipleChoiceField', 'UUIDField', + 'JSONField', 'SlugField', 'TypedChoiceField', 'TypedMultipleChoiceField', + 'UUIDField', ) @@ -1211,3 +1213,66 @@ class UUIDField(CharField): except ValueError: raise ValidationError(self.error_messages['invalid'], code='invalid') return value + + +class InvalidJSONInput(str): + pass + + +class JSONString(str): + pass + + +class JSONField(CharField): + default_error_messages = { + 'invalid': _('Enter a valid JSON.'), + } + widget = Textarea + + def __init__(self, encoder=None, decoder=None, **kwargs): + self.encoder = encoder + self.decoder = decoder + super().__init__(**kwargs) + + def to_python(self, value): + if self.disabled: + return value + if value in self.empty_values: + return None + elif isinstance(value, (list, dict, int, float, JSONString)): + return value + try: + converted = json.loads(value, cls=self.decoder) + except json.JSONDecodeError: + raise ValidationError( + self.error_messages['invalid'], + code='invalid', + params={'value': value}, + ) + if isinstance(converted, str): + return JSONString(converted) + else: + return converted + + def bound_data(self, data, initial): + if self.disabled: + return initial + try: + return json.loads(data, cls=self.decoder) + except json.JSONDecodeError: + return InvalidJSONInput(data) + + def prepare_value(self, value): + if isinstance(value, InvalidJSONInput): + return value + return json.dumps(value, cls=self.encoder) + + def has_changed(self, initial, data): + if super().has_changed(initial, data): + return True + # For purposes of seeing whether something has changed, True isn't the + # same as 1 and the order of keys doesn't matter. + return ( + json.dumps(initial, sort_keys=True, cls=self.encoder) != + json.dumps(self.to_python(data), sort_keys=True, cls=self.encoder) + ) diff --git a/docs/internals/deprecation.txt b/docs/internals/deprecation.txt index 1d89238ede..183ce23408 100644 --- a/docs/internals/deprecation.txt +++ b/docs/internals/deprecation.txt @@ -83,6 +83,13 @@ details on these changes. * ``django.conf.urls.url()`` will be removed. +* The model ``django.contrib.postgres.fields.JSONField`` will be removed. A + stub field will remain for compatibility with historical migrations. + +* ``django.contrib.postgres.forms.JSONField``, + ``django.contrib.postgres.fields.jsonb.KeyTransform``, and + ``django.contrib.postgres.fields.jsonb.KeyTextTransform`` will be removed. + See the :ref:`Django 3.1 release notes ` for more details on these changes. diff --git a/docs/ref/checks.txt b/docs/ref/checks.txt index daf651392f..37a3a572c9 100644 --- a/docs/ref/checks.txt +++ b/docs/ref/checks.txt @@ -190,6 +190,7 @@ Model fields ```` columns. * **fields.E170**: ``BinaryField``ā€™s ``default`` cannot be a string. Use bytes content instead. +* **fields.E180**: ```` does not support ``JSONField``\s. * **fields.E900**: ``IPAddressField`` has been removed except for support in historical migrations. * **fields.W900**: ``IPAddressField`` has been deprecated. Support for it @@ -204,6 +205,9 @@ Model fields Django 3.1. *This check appeared in Django 2.2 and 3.0*. * **fields.W903**: ``NullBooleanField`` is deprecated. Support for it (except in historical migrations) will be removed in Django 4.0. +* **fields.W904**: ``django.contrib.postgres.fields.JSONField`` is deprecated. + Support for it (except in historical migrations) will be removed in Django + 4.0. File fields ~~~~~~~~~~~ diff --git a/docs/ref/contrib/postgres/fields.txt b/docs/ref/contrib/postgres/fields.txt index baebba9c50..aeacc72e7c 100644 --- a/docs/ref/contrib/postgres/fields.txt +++ b/docs/ref/contrib/postgres/fields.txt @@ -16,8 +16,7 @@ Indexes such as :class:`~django.contrib.postgres.indexes.GinIndex` and :class:`~django.contrib.postgres.indexes.GistIndex` are better suited, though the index choice is dependent on the queries that you're using. Generally, GiST may be a good choice for the :ref:`range fields ` and -:class:`HStoreField`, and GIN may be helpful for :class:`ArrayField` and -:class:`JSONField`. +:class:`HStoreField`, and GIN may be helpful for :class:`ArrayField`. ``ArrayField`` ============== @@ -517,96 +516,14 @@ using in conjunction with lookups on of the JSON which allows indexing. The trade-off is a small additional cost on writing to the ``jsonb`` field. ``JSONField`` uses ``jsonb``. +.. deprecated:: 3.1 + + Use :class:`django.db.models.JSONField` instead. + Querying ``JSONField`` ---------------------- -We will use the following example model:: - - from django.contrib.postgres.fields import JSONField - from django.db import models - - class Dog(models.Model): - name = models.CharField(max_length=200) - data = JSONField() - - def __str__(self): - return self.name - -.. fieldlookup:: jsonfield.key - -Key, index, and path lookups -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -To query based on a given dictionary key, use that key as the lookup name:: - - >>> Dog.objects.create(name='Rufus', data={ - ... 'breed': 'labrador', - ... 'owner': { - ... 'name': 'Bob', - ... 'other_pets': [{ - ... 'name': 'Fishy', - ... }], - ... }, - ... }) - >>> Dog.objects.create(name='Meg', data={'breed': 'collie', 'owner': None}) - - >>> Dog.objects.filter(data__breed='collie') - ]> - -Multiple keys can be chained together to form a path lookup:: - - >>> Dog.objects.filter(data__owner__name='Bob') - ]> - -If the key is an integer, it will be interpreted as an index lookup in an -array:: - - >>> Dog.objects.filter(data__owner__other_pets__0__name='Fishy') - ]> - -If the key you wish to query by clashes with the name of another lookup, use -the :lookup:`jsonfield.contains` lookup instead. - -If only one key or index is used, the SQL operator ``->`` is used. If multiple -operators are used then the ``#>`` operator is used. - -To query for ``null`` in JSON data, use ``None`` as a value:: - - >>> Dog.objects.filter(data__owner=None) - ]> - -To query for missing keys, use the ``isnull`` lookup:: - - >>> Dog.objects.create(name='Shep', data={'breed': 'collie'}) - >>> Dog.objects.filter(data__owner__isnull=True) - ]> - -.. warning:: - - Since any string could be a key in a JSON object, any lookup other than - those listed below will be interpreted as a key lookup. No errors are - raised. Be extra careful for typing mistakes, and always check your queries - work as you intend. - -Containment and key operations -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. fieldlookup:: jsonfield.contains -.. fieldlookup:: jsonfield.contained_by -.. fieldlookup:: jsonfield.has_key -.. fieldlookup:: jsonfield.has_any_keys -.. fieldlookup:: jsonfield.has_keys - -:class:`~django.contrib.postgres.fields.JSONField` shares lookups relating to -containment and keys with :class:`~django.contrib.postgres.fields.HStoreField`. - -- :lookup:`contains ` (accepts any JSON rather than - just a dictionary of strings) -- :lookup:`contained_by ` (accepts any JSON - rather than just a dictionary of strings) -- :lookup:`has_key ` -- :lookup:`has_any_keys ` -- :lookup:`has_keys ` +See :ref:`querying-jsonfield` for details. .. _range-fields: diff --git a/docs/ref/contrib/postgres/forms.txt b/docs/ref/contrib/postgres/forms.txt index f559ac75cb..14a3ad61de 100644 --- a/docs/ref/contrib/postgres/forms.txt +++ b/docs/ref/contrib/postgres/forms.txt @@ -164,8 +164,8 @@ Fields .. class:: JSONField A field which accepts JSON encoded data for a - :class:`~django.contrib.postgres.fields.JSONField`. It is represented by an - HTML ``', form.as_p()) + + def test_redisplay_wrong_input(self): + """ + Displaying a bound form (typically due to invalid input). The form + should not overquote JSONField inputs. + """ + class JSONForm(Form): + name = CharField(max_length=2) + json_field = JSONField() + + # JSONField input is valid, name is too long. + form = JSONForm({'name': 'xyz', 'json_field': '["foo"]'}) + self.assertNotIn('json_field', form.errors) + self.assertIn('["foo"]', form.as_p()) + # Invalid JSONField. + form = JSONForm({'name': 'xy', 'json_field': '{"foo"}'}) + self.assertEqual(form.errors['json_field'], ['Enter a valid JSON.']) + self.assertIn('{"foo"}', form.as_p()) diff --git a/tests/inspectdb/models.py b/tests/inspectdb/models.py index 8a48031b24..d0076ce94f 100644 --- a/tests/inspectdb/models.py +++ b/tests/inspectdb/models.py @@ -68,6 +68,17 @@ class ColumnTypes(models.Model): uuid_field = models.UUIDField() +class JSONFieldColumnType(models.Model): + json_field = models.JSONField() + null_json_field = models.JSONField(blank=True, null=True) + + class Meta: + required_db_features = { + 'can_introspect_json_field', + 'supports_json_field', + } + + class UniqueTogether(models.Model): field1 = models.IntegerField() field2 = models.CharField(max_length=10) diff --git a/tests/inspectdb/tests.py b/tests/inspectdb/tests.py index 6e3f4b8aa6..afe89e0dda 100644 --- a/tests/inspectdb/tests.py +++ b/tests/inspectdb/tests.py @@ -85,6 +85,15 @@ class InspectDBTestCase(TestCase): elif not connection.features.interprets_empty_strings_as_nulls: assertFieldType('uuid_field', "models.CharField(max_length=32)") + @skipUnlessDBFeature('can_introspect_json_field', 'supports_json_field') + def test_json_field(self): + out = StringIO() + call_command('inspectdb', 'inspectdb_jsonfieldcolumntype', stdout=out) + output = out.getvalue() + if not connection.features.interprets_empty_strings_as_nulls: + self.assertIn('json_field = models.JSONField()', output) + self.assertIn('null_json_field = models.JSONField(blank=True, null=True)', output) + def test_number_field_types(self): """Test introspection of various Django field types""" assertFieldType = self.make_field_type_asserter() diff --git a/tests/invalid_models_tests/test_models.py b/tests/invalid_models_tests/test_models.py index 5a1bb4cc7a..6c062b2990 100644 --- a/tests/invalid_models_tests/test_models.py +++ b/tests/invalid_models_tests/test_models.py @@ -5,7 +5,7 @@ from django.core.checks.model_checks import _check_lazy_references from django.db import connection, connections, models from django.db.models.functions import Lower from django.db.models.signals import post_init -from django.test import SimpleTestCase, TestCase +from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.test.utils import isolate_apps, override_settings, register_lookup @@ -1349,6 +1349,40 @@ class OtherModelTests(SimpleTestCase): ]) +@isolate_apps('invalid_models_tests') +class JSONFieldTests(TestCase): + @skipUnlessDBFeature('supports_json_field') + def test_ordering_pointing_to_json_field_value(self): + class Model(models.Model): + field = models.JSONField() + + class Meta: + ordering = ['field__value'] + + self.assertEqual(Model.check(databases=self.databases), []) + + def test_check_jsonfield(self): + class Model(models.Model): + field = models.JSONField() + + error = Error( + '%s does not support JSONFields.' % connection.display_name, + obj=Model, + id='fields.E180', + ) + expected = [] if connection.features.supports_json_field else [error] + self.assertEqual(Model.check(databases=self.databases), expected) + + def test_check_jsonfield_required_db_features(self): + class Model(models.Model): + field = models.JSONField() + + class Meta: + required_db_features = {'supports_json_field'} + + self.assertEqual(Model.check(databases=self.databases), []) + + @isolate_apps('invalid_models_tests') class ConstraintsTests(TestCase): def test_check_constraints(self): diff --git a/tests/invalid_models_tests/test_ordinary_fields.py b/tests/invalid_models_tests/test_ordinary_fields.py index d263dc5cc9..a81f9eed90 100644 --- a/tests/invalid_models_tests/test_ordinary_fields.py +++ b/tests/invalid_models_tests/test_ordinary_fields.py @@ -3,7 +3,9 @@ import uuid from django.core.checks import Error, Warning as DjangoWarning from django.db import connection, models -from django.test import SimpleTestCase, TestCase, skipIfDBFeature +from django.test import ( + SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature, +) from django.test.utils import isolate_apps, override_settings from django.utils.functional import lazy from django.utils.timezone import now @@ -793,3 +795,47 @@ class UUIDFieldTests(TestCase): ) self.assertEqual(Model._meta.get_field('field').check(), []) + + +@isolate_apps('invalid_models_tests') +@skipUnlessDBFeature('supports_json_field') +class JSONFieldTests(TestCase): + def test_invalid_default(self): + class Model(models.Model): + field = models.JSONField(default={}) + + self.assertEqual(Model._meta.get_field('field').check(), [ + DjangoWarning( + msg=( + "JSONField default should be a callable instead of an " + "instance so that it's not shared between all field " + "instances." + ), + hint=( + 'Use a callable instead, e.g., use `dict` instead of `{}`.' + ), + obj=Model._meta.get_field('field'), + id='fields.E010', + ) + ]) + + def test_valid_default(self): + class Model(models.Model): + field = models.JSONField(default=dict) + + self.assertEqual(Model._meta.get_field('field').check(), []) + + def test_valid_default_none(self): + class Model(models.Model): + field = models.JSONField(default=None) + + self.assertEqual(Model._meta.get_field('field').check(), []) + + def test_valid_callable_default(self): + def callable_default(): + return {'it': 'works'} + + class Model(models.Model): + field = models.JSONField(default=callable_default) + + self.assertEqual(Model._meta.get_field('field').check(), []) diff --git a/tests/model_fields/models.py b/tests/model_fields/models.py index a7efe199ab..a11eb0ba44 100644 --- a/tests/model_fields/models.py +++ b/tests/model_fields/models.py @@ -1,3 +1,4 @@ +import json import os import tempfile import uuid @@ -7,6 +8,7 @@ from django.contrib.contenttypes.fields import ( ) from django.contrib.contenttypes.models import ContentType from django.core.files.storage import FileSystemStorage +from django.core.serializers.json import DjangoJSONEncoder from django.db import models from django.db.models.fields.files import ImageFieldFile from django.utils.translation import gettext_lazy as _ @@ -332,6 +334,35 @@ if Image: width_field='headshot_width') +class CustomJSONDecoder(json.JSONDecoder): + def __init__(self, object_hook=None, *args, **kwargs): + return super().__init__(object_hook=self.as_uuid, *args, **kwargs) + + def as_uuid(self, dct): + if 'uuid' in dct: + dct['uuid'] = uuid.UUID(dct['uuid']) + return dct + + +class JSONModel(models.Model): + value = models.JSONField() + + class Meta: + required_db_features = {'supports_json_field'} + + +class NullableJSONModel(models.Model): + value = models.JSONField(blank=True, null=True) + value_custom = models.JSONField( + encoder=DjangoJSONEncoder, + decoder=CustomJSONDecoder, + null=True, + ) + + class Meta: + required_db_features = {'supports_json_field'} + + class AllFieldsModel(models.Model): big_integer = models.BigIntegerField() binary = models.BinaryField() diff --git a/tests/model_fields/test_jsonfield.py b/tests/model_fields/test_jsonfield.py new file mode 100644 index 0000000000..464cf163d4 --- /dev/null +++ b/tests/model_fields/test_jsonfield.py @@ -0,0 +1,667 @@ +import operator +import uuid +from unittest import mock, skipIf, skipUnless + +from django import forms +from django.core import serializers +from django.core.exceptions import ValidationError +from django.core.serializers.json import DjangoJSONEncoder +from django.db import ( + DataError, IntegrityError, NotSupportedError, OperationalError, connection, + models, +) +from django.db.models import Count, F, OuterRef, Q, Subquery, Transform, Value +from django.db.models.expressions import RawSQL +from django.db.models.fields.json import ( + KeyTextTransform, KeyTransform, KeyTransformFactory, + KeyTransformTextLookupMixin, +) +from django.db.models.functions import Cast +from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature +from django.test.utils import CaptureQueriesContext + +from .models import CustomJSONDecoder, JSONModel, NullableJSONModel + + +@skipUnlessDBFeature('supports_json_field') +class JSONFieldTests(TestCase): + def test_invalid_value(self): + msg = 'is not JSON serializable' + with self.assertRaisesMessage(TypeError, msg): + NullableJSONModel.objects.create(value={ + 'uuid': uuid.UUID('d85e2076-b67c-4ee7-8c3a-2bf5a2cc2475'), + }) + + def test_custom_encoder_decoder(self): + value = {'uuid': uuid.UUID('{d85e2076-b67c-4ee7-8c3a-2bf5a2cc2475}')} + obj = NullableJSONModel(value_custom=value) + obj.clean_fields() + obj.save() + obj.refresh_from_db() + self.assertEqual(obj.value_custom, value) + + def test_db_check_constraints(self): + value = '{@!invalid json value 123 $!@#' + with mock.patch.object(DjangoJSONEncoder, 'encode', return_value=value): + with self.assertRaises((IntegrityError, DataError, OperationalError)): + NullableJSONModel.objects.create(value_custom=value) + + +class TestMethods(SimpleTestCase): + def test_deconstruct(self): + field = models.JSONField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, 'django.db.models.JSONField') + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + + def test_deconstruct_custom_encoder_decoder(self): + field = models.JSONField(encoder=DjangoJSONEncoder, decoder=CustomJSONDecoder) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(kwargs['encoder'], DjangoJSONEncoder) + self.assertEqual(kwargs['decoder'], CustomJSONDecoder) + + def test_get_transforms(self): + @models.JSONField.register_lookup + class MyTransform(Transform): + lookup_name = 'my_transform' + field = models.JSONField() + transform = field.get_transform('my_transform') + self.assertIs(transform, MyTransform) + models.JSONField._unregister_lookup(MyTransform) + models.JSONField._clear_cached_lookups() + transform = field.get_transform('my_transform') + self.assertIsInstance(transform, KeyTransformFactory) + + def test_key_transform_text_lookup_mixin_non_key_transform(self): + transform = Transform('test') + msg = ( + 'Transform should be an instance of KeyTransform in order to use ' + 'this lookup.' + ) + with self.assertRaisesMessage(TypeError, msg): + KeyTransformTextLookupMixin(transform) + + +class TestValidation(SimpleTestCase): + def test_invalid_encoder(self): + msg = 'The encoder parameter must be a callable object.' + with self.assertRaisesMessage(ValueError, msg): + models.JSONField(encoder=DjangoJSONEncoder()) + + def test_invalid_decoder(self): + msg = 'The decoder parameter must be a callable object.' + with self.assertRaisesMessage(ValueError, msg): + models.JSONField(decoder=CustomJSONDecoder()) + + def test_validation_error(self): + field = models.JSONField() + msg = 'Value must be valid JSON.' + value = uuid.UUID('{d85e2076-b67c-4ee7-8c3a-2bf5a2cc2475}') + with self.assertRaisesMessage(ValidationError, msg): + field.clean({'uuid': value}, None) + + def test_custom_encoder(self): + field = models.JSONField(encoder=DjangoJSONEncoder) + value = uuid.UUID('{d85e2076-b67c-4ee7-8c3a-2bf5a2cc2475}') + field.clean({'uuid': value}, None) + + +class TestFormField(SimpleTestCase): + def test_formfield(self): + model_field = models.JSONField() + form_field = model_field.formfield() + self.assertIsInstance(form_field, forms.JSONField) + + def test_formfield_custom_encoder_decoder(self): + model_field = models.JSONField(encoder=DjangoJSONEncoder, decoder=CustomJSONDecoder) + form_field = model_field.formfield() + self.assertIs(form_field.encoder, DjangoJSONEncoder) + self.assertIs(form_field.decoder, CustomJSONDecoder) + + +class TestSerialization(SimpleTestCase): + test_data = ( + '[{"fields": {"value": %s}, ' + '"model": "model_fields.jsonmodel", "pk": null}]' + ) + test_values = ( + # (Python value, serialized value), + ({'a': 'b', 'c': None}, '{"a": "b", "c": null}'), + ('abc', '"abc"'), + ('{"a": "a"}', '"{\\"a\\": \\"a\\"}"'), + ) + + def test_dumping(self): + for value, serialized in self.test_values: + with self.subTest(value=value): + instance = JSONModel(value=value) + data = serializers.serialize('json', [instance]) + self.assertJSONEqual(data, self.test_data % serialized) + + def test_loading(self): + for value, serialized in self.test_values: + with self.subTest(value=value): + instance = list( + serializers.deserialize('json', self.test_data % serialized) + )[0].object + self.assertEqual(instance.value, value) + + +@skipUnlessDBFeature('supports_json_field') +class TestSaveLoad(TestCase): + def test_null(self): + obj = NullableJSONModel(value=None) + obj.save() + obj.refresh_from_db() + self.assertIsNone(obj.value) + + @skipUnlessDBFeature('supports_primitives_in_json_field') + def test_json_null_different_from_sql_null(self): + json_null = NullableJSONModel.objects.create(value=Value('null')) + json_null.refresh_from_db() + sql_null = NullableJSONModel.objects.create(value=None) + sql_null.refresh_from_db() + # 'null' is not equal to NULL in the database. + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value=Value('null')), + [json_null], + ) + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value=None), + [json_null], + ) + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__isnull=True), + [sql_null], + ) + # 'null' is equal to NULL in Python (None). + self.assertEqual(json_null.value, sql_null.value) + + @skipUnlessDBFeature('supports_primitives_in_json_field') + def test_primitives(self): + values = [ + True, + 1, + 1.45, + 'String', + '', + ] + for value in values: + with self.subTest(value=value): + obj = JSONModel(value=value) + obj.save() + obj.refresh_from_db() + self.assertEqual(obj.value, value) + + def test_dict(self): + values = [ + {}, + {'name': 'John', 'age': 20, 'height': 180.3}, + {'a': True, 'b': {'b1': False, 'b2': None}}, + ] + for value in values: + with self.subTest(value=value): + obj = JSONModel.objects.create(value=value) + obj.refresh_from_db() + self.assertEqual(obj.value, value) + + def test_list(self): + values = [ + [], + ['John', 20, 180.3], + [True, [False, None]], + ] + for value in values: + with self.subTest(value=value): + obj = JSONModel.objects.create(value=value) + obj.refresh_from_db() + self.assertEqual(obj.value, value) + + def test_realistic_object(self): + value = { + 'name': 'John', + 'age': 20, + 'pets': [ + {'name': 'Kit', 'type': 'cat', 'age': 2}, + {'name': 'Max', 'type': 'dog', 'age': 1}, + ], + 'courses': [ + ['A1', 'A2', 'A3'], + ['B1', 'B2'], + ['C1'], + ], + } + obj = JSONModel.objects.create(value=value) + obj.refresh_from_db() + self.assertEqual(obj.value, value) + + +@skipUnlessDBFeature('supports_json_field') +class TestQuerying(TestCase): + @classmethod + def setUpTestData(cls): + cls.primitives = [True, False, 'yes', 7, 9.6] + values = [ + None, + [], + {}, + {'a': 'b', 'c': 14}, + { + 'a': 'b', + 'c': 14, + 'd': ['e', {'f': 'g'}], + 'h': True, + 'i': False, + 'j': None, + 'k': {'l': 'm'}, + 'n': [None], + }, + [1, [2]], + {'k': True, 'l': False}, + { + 'foo': 'bar', + 'baz': {'a': 'b', 'c': 'd'}, + 'bar': ['foo', 'bar'], + 'bax': {'foo': 'bar'}, + }, + ] + cls.objs = [ + NullableJSONModel.objects.create(value=value) + for value in values + ] + if connection.features.supports_primitives_in_json_field: + cls.objs.extend([ + NullableJSONModel.objects.create(value=value) + for value in cls.primitives + ]) + cls.raw_sql = '%s::jsonb' if connection.vendor == 'postgresql' else '%s' + + def test_exact(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__exact={}), + [self.objs[2]], + ) + + def test_exact_complex(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__exact={'a': 'b', 'c': 14}), + [self.objs[3]], + ) + + def test_isnull(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__isnull=True), + [self.objs[0]], + ) + + def test_ordering_by_transform(self): + objs = [ + NullableJSONModel.objects.create(value={'ord': 93, 'name': 'bar'}), + NullableJSONModel.objects.create(value={'ord': 22.1, 'name': 'foo'}), + NullableJSONModel.objects.create(value={'ord': -1, 'name': 'baz'}), + NullableJSONModel.objects.create(value={'ord': 21.931902, 'name': 'spam'}), + NullableJSONModel.objects.create(value={'ord': -100291029, 'name': 'eggs'}), + ] + query = NullableJSONModel.objects.filter(value__name__isnull=False).order_by('value__ord') + expected = [objs[4], objs[2], objs[3], objs[1], objs[0]] + mariadb = connection.vendor == 'mysql' and connection.mysql_is_mariadb + if mariadb or connection.vendor == 'oracle': + # MariaDB and Oracle return JSON values as strings. + expected = [objs[2], objs[4], objs[3], objs[1], objs[0]] + self.assertSequenceEqual(query, expected) + + def test_ordering_grouping_by_key_transform(self): + base_qs = NullableJSONModel.objects.filter(value__d__0__isnull=False) + for qs in ( + base_qs.order_by('value__d__0'), + base_qs.annotate(key=KeyTransform('0', KeyTransform('d', 'value'))).order_by('key'), + ): + self.assertSequenceEqual(qs, [self.objs[4]]) + qs = NullableJSONModel.objects.filter(value__isnull=False) + self.assertQuerysetEqual( + qs.filter(value__isnull=False).annotate( + key=KeyTextTransform('f', KeyTransform('1', KeyTransform('d', 'value'))), + ).values('key').annotate(count=Count('key')).order_by('count'), + [(None, 0), ('g', 1)], + operator.itemgetter('key', 'count'), + ) + + @skipIf(connection.vendor == 'oracle', "Oracle doesn't support grouping by LOBs, see #24096.") + def test_ordering_grouping_by_count(self): + qs = NullableJSONModel.objects.filter( + value__isnull=False, + ).values('value__d__0').annotate(count=Count('value__d__0')).order_by('count') + self.assertQuerysetEqual(qs, [1, 11], operator.itemgetter('count')) + + def test_key_transform_raw_expression(self): + expr = RawSQL(self.raw_sql, ['{"x": "bar"}']) + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__foo=KeyTransform('x', expr)), + [self.objs[7]], + ) + + def test_nested_key_transform_raw_expression(self): + expr = RawSQL(self.raw_sql, ['{"x": {"y": "bar"}}']) + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__foo=KeyTransform('y', KeyTransform('x', expr))), + [self.objs[7]], + ) + + def test_key_transform_expression(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__d__0__isnull=False).annotate( + key=KeyTransform('d', 'value'), + chain=KeyTransform('0', 'key'), + expr=KeyTransform('0', Cast('key', models.JSONField())), + ).filter(chain=F('expr')), + [self.objs[4]], + ) + + def test_nested_key_transform_expression(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__d__0__isnull=False).annotate( + key=KeyTransform('d', 'value'), + chain=KeyTransform('f', KeyTransform('1', 'key')), + expr=KeyTransform('f', KeyTransform('1', Cast('key', models.JSONField()))), + ).filter(chain=F('expr')), + [self.objs[4]], + ) + + def test_has_key(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__has_key='a'), + [self.objs[3], self.objs[4]], + ) + + def test_has_key_null_value(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__has_key='j'), + [self.objs[4]], + ) + + def test_has_key_deep(self): + tests = [ + (Q(value__baz__has_key='a'), self.objs[7]), + (Q(value__has_key=KeyTransform('a', KeyTransform('baz', 'value'))), self.objs[7]), + (Q(value__has_key=KeyTransform('c', KeyTransform('baz', 'value'))), self.objs[7]), + (Q(value__d__1__has_key='f'), self.objs[4]), + ( + Q(value__has_key=KeyTransform('f', KeyTransform('1', KeyTransform('d', 'value')))), + self.objs[4], + ) + ] + for condition, expected in tests: + with self.subTest(condition=condition): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(condition), + [expected], + ) + + def test_has_key_list(self): + obj = NullableJSONModel.objects.create(value=[{'a': 1}, {'b': 'x'}]) + tests = [ + Q(value__1__has_key='b'), + Q(value__has_key=KeyTransform('b', KeyTransform(1, 'value'))), + Q(value__has_key=KeyTransform('b', KeyTransform('1', 'value'))), + ] + for condition in tests: + with self.subTest(condition=condition): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(condition), + [obj], + ) + + def test_has_keys(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__has_keys=['a', 'c', 'h']), + [self.objs[4]], + ) + + def test_has_any_keys(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__has_any_keys=['c', 'l']), + [self.objs[3], self.objs[4], self.objs[6]], + ) + + def test_contains(self): + tests = [ + ({}, self.objs[2:5] + self.objs[6:8]), + ({'baz': {'a': 'b', 'c': 'd'}}, [self.objs[7]]), + ({'k': True, 'l': False}, [self.objs[6]]), + ({'d': ['e', {'f': 'g'}]}, [self.objs[4]]), + ([1, [2]], [self.objs[5]]), + ({'n': [None]}, [self.objs[4]]), + ({'j': None}, [self.objs[4]]), + ] + for value, expected in tests: + with self.subTest(value=value): + qs = NullableJSONModel.objects.filter(value__contains=value) + self.assertSequenceEqual(qs, expected) + + @skipUnlessDBFeature('supports_primitives_in_json_field') + def test_contains_primitives(self): + for value in self.primitives: + with self.subTest(value=value): + qs = NullableJSONModel.objects.filter(value__contains=value) + self.assertIs(qs.exists(), True) + + @skipIf( + connection.vendor == 'oracle', + "Oracle doesn't support contained_by lookup.", + ) + def test_contained_by(self): + qs = NullableJSONModel.objects.filter(value__contained_by={'a': 'b', 'c': 14, 'h': True}) + self.assertSequenceEqual(qs, self.objs[2:4]) + + @skipUnless( + connection.vendor == 'oracle', + "Oracle doesn't support contained_by lookup.", + ) + def test_contained_by_unsupported(self): + msg = 'contained_by lookup is not supported on Oracle.' + with self.assertRaisesMessage(NotSupportedError, msg): + NullableJSONModel.objects.filter(value__contained_by={'a': 'b'}).get() + + def test_deep_values(self): + qs = NullableJSONModel.objects.values_list('value__k__l') + expected_objs = [(None,)] * len(self.objs) + expected_objs[4] = ('m',) + self.assertSequenceEqual(qs, expected_objs) + + @skipUnlessDBFeature('can_distinct_on_fields') + def test_deep_distinct(self): + query = NullableJSONModel.objects.distinct('value__k__l').values_list('value__k__l') + self.assertSequenceEqual(query, [('m',), (None,)]) + + def test_isnull_key(self): + # key__isnull=False works the same as has_key='key'. + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__a__isnull=True), + self.objs[:3] + self.objs[5:], + ) + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__a__isnull=False), + [self.objs[3], self.objs[4]], + ) + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__j__isnull=False), + [self.objs[4]], + ) + + def test_isnull_key_or_none(self): + obj = NullableJSONModel.objects.create(value={'a': None}) + self.assertSequenceEqual( + NullableJSONModel.objects.filter(Q(value__a__isnull=True) | Q(value__a=None)), + self.objs[:3] + self.objs[5:] + [obj], + ) + + def test_none_key(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__j=None), + [self.objs[4]], + ) + + def test_none_key_exclude(self): + obj = NullableJSONModel.objects.create(value={'j': 1}) + if connection.vendor == 'oracle': + # Oracle supports filtering JSON objects with NULL keys, but the + # current implementation doesn't support it. + self.assertSequenceEqual( + NullableJSONModel.objects.exclude(value__j=None), + self.objs[1:4] + self.objs[5:] + [obj], + ) + else: + self.assertSequenceEqual(NullableJSONModel.objects.exclude(value__j=None), [obj]) + + def test_shallow_list_lookup(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__0=1), + [self.objs[5]], + ) + + def test_shallow_obj_lookup(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__a='b'), + [self.objs[3], self.objs[4]], + ) + + def test_obj_subquery_lookup(self): + qs = NullableJSONModel.objects.annotate( + field=Subquery(NullableJSONModel.objects.filter(pk=OuterRef('pk')).values('value')), + ).filter(field__a='b') + self.assertSequenceEqual(qs, [self.objs[3], self.objs[4]]) + + def test_deep_lookup_objs(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__k__l='m'), + [self.objs[4]], + ) + + def test_shallow_lookup_obj_target(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__k={'l': 'm'}), + [self.objs[4]], + ) + + def test_deep_lookup_array(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__1__0=2), + [self.objs[5]], + ) + + def test_deep_lookup_mixed(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__d__1__f='g'), + [self.objs[4]], + ) + + def test_deep_lookup_transform(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__c__gt=2), + [self.objs[3], self.objs[4]], + ) + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__c__gt=2.33), + [self.objs[3], self.objs[4]], + ) + self.assertIs(NullableJSONModel.objects.filter(value__c__lt=5).exists(), False) + + @skipIf( + connection.vendor == 'oracle', + 'Raises ORA-00600: internal error code on Oracle 18.', + ) + def test_usage_in_subquery(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter( + id__in=NullableJSONModel.objects.filter(value__c=14), + ), + self.objs[3:5], + ) + + def test_key_iexact(self): + self.assertIs(NullableJSONModel.objects.filter(value__foo__iexact='BaR').exists(), True) + self.assertIs(NullableJSONModel.objects.filter(value__foo__iexact='"BaR"').exists(), False) + + def test_key_contains(self): + self.assertIs(NullableJSONModel.objects.filter(value__foo__contains='ar').exists(), True) + + def test_key_icontains(self): + self.assertIs(NullableJSONModel.objects.filter(value__foo__icontains='Ar').exists(), True) + + def test_key_startswith(self): + self.assertIs(NullableJSONModel.objects.filter(value__foo__startswith='b').exists(), True) + + def test_key_istartswith(self): + self.assertIs(NullableJSONModel.objects.filter(value__foo__istartswith='B').exists(), True) + + def test_key_endswith(self): + self.assertIs(NullableJSONModel.objects.filter(value__foo__endswith='r').exists(), True) + + def test_key_iendswith(self): + self.assertIs(NullableJSONModel.objects.filter(value__foo__iendswith='R').exists(), True) + + def test_key_regex(self): + self.assertIs(NullableJSONModel.objects.filter(value__foo__regex=r'^bar$').exists(), True) + + def test_key_iregex(self): + self.assertIs(NullableJSONModel.objects.filter(value__foo__iregex=r'^bAr$').exists(), True) + + @skipUnless(connection.vendor == 'postgresql', 'kwargs are crafted for PostgreSQL.') + def test_key_sql_injection(self): + with CaptureQueriesContext(connection) as queries: + self.assertIs( + NullableJSONModel.objects.filter(**{ + """value__test' = '"a"') OR 1 = 1 OR ('d""": 'x', + }).exists(), + False, + ) + self.assertIn( + """."value" -> 'test'' = ''"a"'') OR 1 = 1 OR (''d') = '"x"' """, + queries[0]['sql'], + ) + + @skipIf(connection.vendor == 'postgresql', 'PostgreSQL uses operators not functions.') + def test_key_sql_injection_escape(self): + query = str(JSONModel.objects.filter(**{ + """value__test") = '"a"' OR 1 = 1 OR ("d""": 'x', + }).query) + self.assertIn('"test\\"', query) + self.assertIn('\\"d', query) + + def test_key_escape(self): + obj = NullableJSONModel.objects.create(value={'%total': 10}) + self.assertEqual(NullableJSONModel.objects.filter(**{'value__%total': 10}).get(), obj) + + def test_none_key_and_exact_lookup(self): + self.assertSequenceEqual( + NullableJSONModel.objects.filter(value__a='b', value__j=None), + [self.objs[4]], + ) + + def test_lookups_with_key_transform(self): + tests = ( + ('value__d__contains', 'e'), + ('value__baz__has_key', 'c'), + ('value__baz__has_keys', ['a', 'c']), + ('value__baz__has_any_keys', ['a', 'x']), + ('value__contains', KeyTransform('bax', 'value')), + ('value__has_key', KeyTextTransform('foo', 'value')), + ) + # contained_by lookup is not supported on Oracle. + if connection.vendor != 'oracle': + tests += ( + ('value__baz__contained_by', {'a': 'b', 'c': 'd', 'e': 'f'}), + ( + 'value__contained_by', + KeyTransform('x', RawSQL( + self.raw_sql, + ['{"x": {"a": "b", "c": 1, "d": "e"}}'], + )), + ), + ) + for lookup, value in tests: + with self.subTest(lookup=lookup): + self.assertIs(NullableJSONModel.objects.filter( + **{lookup: value}, + ).exists(), True) diff --git a/tests/postgres_tests/fields.py b/tests/postgres_tests/fields.py index 4ebc0ce7dc..a36c10c750 100644 --- a/tests/postgres_tests/fields.py +++ b/tests/postgres_tests/fields.py @@ -10,7 +10,7 @@ try: from django.contrib.postgres.fields import ( ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField, DateRangeField, DateTimeRangeField, DecimalRangeField, - HStoreField, IntegerRangeField, JSONField, + HStoreField, IntegerRangeField, ) from django.contrib.postgres.search import SearchVectorField except ImportError: @@ -26,10 +26,6 @@ except ImportError: }) return name, path, args, kwargs - class DummyJSONField(models.Field): - def __init__(self, encoder=None, **kwargs): - super().__init__(**kwargs) - ArrayField = DummyArrayField BigIntegerRangeField = models.Field CICharField = models.Field @@ -40,7 +36,6 @@ except ImportError: DecimalRangeField = models.Field HStoreField = models.Field IntegerRangeField = models.Field - JSONField = DummyJSONField SearchVectorField = models.Field diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py index ee1463e1eb..cb5f4c6d3e 100644 --- a/tests/postgres_tests/migrations/0002_create_test_models.py +++ b/tests/postgres_tests/migrations/0002_create_test_models.py @@ -1,10 +1,9 @@ -from django.core.serializers.json import DjangoJSONEncoder from django.db import migrations, models from ..fields import ( ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField, DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField, - HStoreField, IntegerRangeField, JSONField, SearchVectorField, + HStoreField, IntegerRangeField, SearchVectorField, ) from ..models import TagField @@ -60,7 +59,7 @@ class Migration(migrations.Migration): ('uuids', ArrayField(models.UUIDField(), size=None, default=list)), ('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None, default=list)), ('tags', ArrayField(TagField(), blank=True, null=True, size=None)), - ('json', ArrayField(JSONField(default={}), default=[])), + ('json', ArrayField(models.JSONField(default={}), default=[])), ('int_ranges', ArrayField(IntegerRangeField(), null=True, blank=True)), ('bigint_ranges', ArrayField(BigIntegerRangeField(), null=True, blank=True)), ], @@ -270,18 +269,6 @@ class Migration(migrations.Migration): }, bases=(models.Model,), ), - migrations.CreateModel( - name='JSONModel', - fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('field', JSONField(null=True, blank=True)), - ('field_custom', JSONField(null=True, blank=True, encoder=DjangoJSONEncoder)), - ], - options={ - 'required_db_vendor': 'postgresql', - }, - bases=(models.Model,), - ), migrations.CreateModel( name='ArrayEnumModel', fields=[ diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py index 24605954b2..464245fbab 100644 --- a/tests/postgres_tests/models.py +++ b/tests/postgres_tests/models.py @@ -1,10 +1,9 @@ -from django.core.serializers.json import DjangoJSONEncoder from django.db import models from .fields import ( ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField, DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField, - HStoreField, IntegerRangeField, JSONField, SearchVectorField, + HStoreField, IntegerRangeField, SearchVectorField, ) @@ -68,7 +67,7 @@ class OtherTypesArrayModel(PostgreSQLModel): uuids = ArrayField(models.UUIDField(), default=list) decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2), default=list) tags = ArrayField(TagField(), blank=True, null=True) - json = ArrayField(JSONField(default=dict), default=list) + json = ArrayField(models.JSONField(default=dict), default=list) int_ranges = ArrayField(IntegerRangeField(), blank=True, null=True) bigint_ranges = ArrayField(BigIntegerRangeField(), blank=True, null=True) @@ -150,11 +149,6 @@ class RangeLookupsModel(PostgreSQLModel): decimal_field = models.DecimalField(max_digits=5, decimal_places=2, blank=True, null=True) -class JSONModel(PostgreSQLModel): - field = JSONField(blank=True, null=True) - field_custom = JSONField(blank=True, null=True, encoder=DjangoJSONEncoder) - - class ArrayFieldSubclass(ArrayField): def __init__(self, *args, **kwargs): super().__init__(models.IntegerField()) diff --git a/tests/postgres_tests/test_bulk_update.py b/tests/postgres_tests/test_bulk_update.py index 6dd7036a9b..7fa2a6a7db 100644 --- a/tests/postgres_tests/test_bulk_update.py +++ b/tests/postgres_tests/test_bulk_update.py @@ -2,7 +2,7 @@ from datetime import date from . import PostgreSQLTestCase from .models import ( - HStoreModel, IntegerArrayModel, JSONModel, NestedIntegerArrayModel, + HStoreModel, IntegerArrayModel, NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel, RangesModel, ) @@ -17,7 +17,6 @@ class BulkSaveTests(PostgreSQLTestCase): test_data = [ (IntegerArrayModel, 'field', [], [1, 2, 3]), (NullableIntegerArrayModel, 'field', [1, 2, 3], None), - (JSONModel, 'field', {'a': 'b'}, {'c': 'd'}), (NestedIntegerArrayModel, 'field', [], [[1, 2, 3]]), (HStoreModel, 'field', {}, {1: 2}), (RangesModel, 'ints', None, NumericRange(lower=1, upper=10)), diff --git a/tests/postgres_tests/test_introspection.py b/tests/postgres_tests/test_introspection.py index 8ae5b80da1..50cb9b2828 100644 --- a/tests/postgres_tests/test_introspection.py +++ b/tests/postgres_tests/test_introspection.py @@ -19,12 +19,6 @@ class InspectDBTests(PostgreSQLTestCase): for field_output in field_outputs: self.assertIn(field_output, output) - def test_json_field(self): - self.assertFieldsInModel( - 'postgres_tests_jsonmodel', - ['field = django.contrib.postgres.fields.JSONField(blank=True, null=True)'], - ) - def test_range_fields(self): self.assertFieldsInModel( 'postgres_tests_rangesmodel', diff --git a/tests/postgres_tests/test_json.py b/tests/postgres_tests/test_json.py deleted file mode 100644 index 2ff765e918..0000000000 --- a/tests/postgres_tests/test_json.py +++ /dev/null @@ -1,583 +0,0 @@ -import datetime -import operator -import uuid -from decimal import Decimal - -from django.core import checks, exceptions, serializers -from django.core.serializers.json import DjangoJSONEncoder -from django.db import connection -from django.db.models import Count, F, OuterRef, Q, Subquery -from django.db.models.expressions import RawSQL -from django.db.models.functions import Cast -from django.forms import CharField, Form, widgets -from django.test.utils import CaptureQueriesContext, isolate_apps -from django.utils.html import escape - -from . import PostgreSQLSimpleTestCase, PostgreSQLTestCase -from .models import JSONModel, PostgreSQLModel - -try: - from django.contrib.postgres import forms - from django.contrib.postgres.fields import JSONField - from django.contrib.postgres.fields.jsonb import KeyTextTransform, KeyTransform -except ImportError: - pass - - -class TestModelMetaOrdering(PostgreSQLSimpleTestCase): - def test_ordering_by_json_field_value(self): - class TestJSONModel(JSONModel): - class Meta: - ordering = ['field__value'] - - self.assertEqual(TestJSONModel.check(), []) - - -class TestSaveLoad(PostgreSQLTestCase): - def test_null(self): - instance = JSONModel() - instance.save() - loaded = JSONModel.objects.get() - self.assertIsNone(loaded.field) - - def test_empty_object(self): - instance = JSONModel(field={}) - instance.save() - loaded = JSONModel.objects.get() - self.assertEqual(loaded.field, {}) - - def test_empty_list(self): - instance = JSONModel(field=[]) - instance.save() - loaded = JSONModel.objects.get() - self.assertEqual(loaded.field, []) - - def test_boolean(self): - instance = JSONModel(field=True) - instance.save() - loaded = JSONModel.objects.get() - self.assertIs(loaded.field, True) - - def test_string(self): - instance = JSONModel(field='why?') - instance.save() - loaded = JSONModel.objects.get() - self.assertEqual(loaded.field, 'why?') - - def test_number(self): - instance = JSONModel(field=1) - instance.save() - loaded = JSONModel.objects.get() - self.assertEqual(loaded.field, 1) - - def test_realistic_object(self): - obj = { - 'a': 'b', - 'c': 1, - 'd': ['e', {'f': 'g'}], - 'h': True, - 'i': False, - 'j': None, - } - instance = JSONModel(field=obj) - instance.save() - loaded = JSONModel.objects.get() - self.assertEqual(loaded.field, obj) - - def test_custom_encoding(self): - """ - JSONModel.field_custom has a custom DjangoJSONEncoder. - """ - some_uuid = uuid.uuid4() - obj_before = { - 'date': datetime.date(2016, 8, 12), - 'datetime': datetime.datetime(2016, 8, 12, 13, 44, 47, 575981), - 'decimal': Decimal('10.54'), - 'uuid': some_uuid, - } - obj_after = { - 'date': '2016-08-12', - 'datetime': '2016-08-12T13:44:47.575', - 'decimal': '10.54', - 'uuid': str(some_uuid), - } - JSONModel.objects.create(field_custom=obj_before) - loaded = JSONModel.objects.get() - self.assertEqual(loaded.field_custom, obj_after) - - -class TestQuerying(PostgreSQLTestCase): - @classmethod - def setUpTestData(cls): - cls.objs = JSONModel.objects.bulk_create([ - JSONModel(field=None), - JSONModel(field=True), - JSONModel(field=False), - JSONModel(field='yes'), - JSONModel(field=7), - JSONModel(field=[]), - JSONModel(field={}), - JSONModel(field={ - 'a': 'b', - 'c': 1, - }), - JSONModel(field={ - 'a': 'b', - 'c': 1, - 'd': ['e', {'f': 'g'}], - 'h': True, - 'i': False, - 'j': None, - 'k': {'l': 'm'}, - }), - JSONModel(field=[1, [2]]), - JSONModel(field={ - 'k': True, - 'l': False, - }), - JSONModel(field={ - 'foo': 'bar', - 'baz': {'a': 'b', 'c': 'd'}, - 'bar': ['foo', 'bar'], - 'bax': {'foo': 'bar'}, - }), - ]) - - def test_exact(self): - self.assertSequenceEqual( - JSONModel.objects.filter(field__exact={}), - [self.objs[6]] - ) - - def test_exact_complex(self): - self.assertSequenceEqual( - JSONModel.objects.filter(field__exact={'a': 'b', 'c': 1}), - [self.objs[7]] - ) - - def test_isnull(self): - self.assertSequenceEqual( - JSONModel.objects.filter(field__isnull=True), - [self.objs[0]] - ) - - def test_ordering_by_transform(self): - objs = [ - JSONModel.objects.create(field={'ord': 93, 'name': 'bar'}), - JSONModel.objects.create(field={'ord': 22.1, 'name': 'foo'}), - JSONModel.objects.create(field={'ord': -1, 'name': 'baz'}), - JSONModel.objects.create(field={'ord': 21.931902, 'name': 'spam'}), - JSONModel.objects.create(field={'ord': -100291029, 'name': 'eggs'}), - ] - query = JSONModel.objects.filter(field__name__isnull=False).order_by('field__ord') - self.assertSequenceEqual(query, [objs[4], objs[2], objs[3], objs[1], objs[0]]) - - def test_ordering_grouping_by_key_transform(self): - base_qs = JSONModel.objects.filter(field__d__0__isnull=False) - for qs in ( - base_qs.order_by('field__d__0'), - base_qs.annotate(key=KeyTransform('0', KeyTransform('d', 'field'))).order_by('key'), - ): - self.assertSequenceEqual(qs, [self.objs[8]]) - qs = JSONModel.objects.filter(field__isnull=False) - self.assertQuerysetEqual( - qs.values('field__d__0').annotate(count=Count('field__d__0')).order_by('count'), - [1, 10], - operator.itemgetter('count'), - ) - self.assertQuerysetEqual( - qs.filter(field__isnull=False).annotate( - key=KeyTextTransform('f', KeyTransform('1', KeyTransform('d', 'field'))), - ).values('key').annotate(count=Count('key')).order_by('count'), - [(None, 0), ('g', 1)], - operator.itemgetter('key', 'count'), - ) - - def test_key_transform_raw_expression(self): - expr = RawSQL('%s::jsonb', ['{"x": "bar"}']) - self.assertSequenceEqual( - JSONModel.objects.filter(field__foo=KeyTransform('x', expr)), - [self.objs[-1]], - ) - - def test_key_transform_expression(self): - self.assertSequenceEqual( - JSONModel.objects.filter(field__d__0__isnull=False).annotate( - key=KeyTransform('d', 'field'), - chain=KeyTransform('0', 'key'), - expr=KeyTransform('0', Cast('key', JSONField())), - ).filter(chain=F('expr')), - [self.objs[8]], - ) - - def test_nested_key_transform_raw_expression(self): - expr = RawSQL('%s::jsonb', ['{"x": {"y": "bar"}}']) - self.assertSequenceEqual( - JSONModel.objects.filter(field__foo=KeyTransform('y', KeyTransform('x', expr))), - [self.objs[-1]], - ) - - def test_nested_key_transform_expression(self): - self.assertSequenceEqual( - JSONModel.objects.filter(field__d__0__isnull=False).annotate( - key=KeyTransform('d', 'field'), - chain=KeyTransform('f', KeyTransform('1', 'key')), - expr=KeyTransform('f', KeyTransform('1', Cast('key', JSONField()))), - ).filter(chain=F('expr')), - [self.objs[8]], - ) - - def test_deep_values(self): - query = JSONModel.objects.values_list('field__k__l') - self.assertSequenceEqual( - query, - [ - (None,), (None,), (None,), (None,), (None,), (None,), - (None,), (None,), ('m',), (None,), (None,), (None,), - ] - ) - - def test_deep_distinct(self): - query = JSONModel.objects.distinct('field__k__l').values_list('field__k__l') - self.assertSequenceEqual(query, [('m',), (None,)]) - - def test_isnull_key(self): - # key__isnull works the same as has_key='key'. - self.assertSequenceEqual( - JSONModel.objects.filter(field__a__isnull=True), - self.objs[:7] + self.objs[9:] - ) - self.assertSequenceEqual( - JSONModel.objects.filter(field__a__isnull=False), - [self.objs[7], self.objs[8]] - ) - - def test_none_key(self): - self.assertSequenceEqual(JSONModel.objects.filter(field__j=None), [self.objs[8]]) - - def test_none_key_exclude(self): - obj = JSONModel.objects.create(field={'j': 1}) - self.assertSequenceEqual(JSONModel.objects.exclude(field__j=None), [obj]) - - def test_isnull_key_or_none(self): - obj = JSONModel.objects.create(field={'a': None}) - self.assertSequenceEqual( - JSONModel.objects.filter(Q(field__a__isnull=True) | Q(field__a=None)), - self.objs[:7] + self.objs[9:] + [obj] - ) - - def test_contains(self): - self.assertSequenceEqual( - JSONModel.objects.filter(field__contains={'a': 'b'}), - [self.objs[7], self.objs[8]] - ) - - def test_contained_by(self): - self.assertSequenceEqual( - JSONModel.objects.filter(field__contained_by={'a': 'b', 'c': 1, 'h': True}), - [self.objs[6], self.objs[7]] - ) - - def test_has_key(self): - self.assertSequenceEqual( - JSONModel.objects.filter(field__has_key='a'), - [self.objs[7], self.objs[8]] - ) - - def test_has_keys(self): - self.assertSequenceEqual( - JSONModel.objects.filter(field__has_keys=['a', 'c', 'h']), - [self.objs[8]] - ) - - def test_has_any_keys(self): - self.assertSequenceEqual( - JSONModel.objects.filter(field__has_any_keys=['c', 'l']), - [self.objs[7], self.objs[8], self.objs[10]] - ) - - def test_shallow_list_lookup(self): - self.assertSequenceEqual( - JSONModel.objects.filter(field__0=1), - [self.objs[9]] - ) - - def test_shallow_obj_lookup(self): - self.assertSequenceEqual( - JSONModel.objects.filter(field__a='b'), - [self.objs[7], self.objs[8]] - ) - - def test_obj_subquery_lookup(self): - qs = JSONModel.objects.annotate( - value=Subquery(JSONModel.objects.filter(pk=OuterRef('pk')).values('field')), - ).filter(value__a='b') - self.assertSequenceEqual(qs, [self.objs[7], self.objs[8]]) - - def test_deep_lookup_objs(self): - self.assertSequenceEqual( - JSONModel.objects.filter(field__k__l='m'), - [self.objs[8]] - ) - - def test_shallow_lookup_obj_target(self): - self.assertSequenceEqual( - JSONModel.objects.filter(field__k={'l': 'm'}), - [self.objs[8]] - ) - - def test_deep_lookup_array(self): - self.assertSequenceEqual( - JSONModel.objects.filter(field__1__0=2), - [self.objs[9]] - ) - - def test_deep_lookup_mixed(self): - self.assertSequenceEqual( - JSONModel.objects.filter(field__d__1__f='g'), - [self.objs[8]] - ) - - def test_deep_lookup_transform(self): - self.assertSequenceEqual( - JSONModel.objects.filter(field__c__gt=1), - [] - ) - self.assertSequenceEqual( - JSONModel.objects.filter(field__c__lt=5), - [self.objs[7], self.objs[8]] - ) - - def test_usage_in_subquery(self): - self.assertSequenceEqual( - JSONModel.objects.filter(id__in=JSONModel.objects.filter(field__c=1)), - self.objs[7:9] - ) - - def test_iexact(self): - self.assertTrue(JSONModel.objects.filter(field__foo__iexact='BaR').exists()) - self.assertFalse(JSONModel.objects.filter(field__foo__iexact='"BaR"').exists()) - - def test_icontains(self): - self.assertFalse(JSONModel.objects.filter(field__foo__icontains='"bar"').exists()) - - def test_startswith(self): - self.assertTrue(JSONModel.objects.filter(field__foo__startswith='b').exists()) - - def test_istartswith(self): - self.assertTrue(JSONModel.objects.filter(field__foo__istartswith='B').exists()) - - def test_endswith(self): - self.assertTrue(JSONModel.objects.filter(field__foo__endswith='r').exists()) - - def test_iendswith(self): - self.assertTrue(JSONModel.objects.filter(field__foo__iendswith='R').exists()) - - def test_regex(self): - self.assertTrue(JSONModel.objects.filter(field__foo__regex=r'^bar$').exists()) - - def test_iregex(self): - self.assertTrue(JSONModel.objects.filter(field__foo__iregex=r'^bAr$').exists()) - - def test_key_sql_injection(self): - with CaptureQueriesContext(connection) as queries: - self.assertFalse( - JSONModel.objects.filter(**{ - """field__test' = '"a"') OR 1 = 1 OR ('d""": 'x', - }).exists() - ) - self.assertIn( - """."field" -> 'test'' = ''"a"'') OR 1 = 1 OR (''d') = '"x"' """, - queries[0]['sql'], - ) - - def test_lookups_with_key_transform(self): - tests = ( - ('field__d__contains', 'e'), - ('field__baz__contained_by', {'a': 'b', 'c': 'd', 'e': 'f'}), - ('field__baz__has_key', 'c'), - ('field__baz__has_keys', ['a', 'c']), - ('field__baz__has_any_keys', ['a', 'x']), - ('field__contains', KeyTransform('bax', 'field')), - ( - 'field__contained_by', - KeyTransform('x', RawSQL('%s::jsonb', ['{"x": {"a": "b", "c": 1, "d": "e"}}'])), - ), - ('field__has_key', KeyTextTransform('foo', 'field')), - ) - for lookup, value in tests: - with self.subTest(lookup=lookup): - self.assertTrue(JSONModel.objects.filter( - **{lookup: value}, - ).exists()) - - def test_key_escape(self): - obj = JSONModel.objects.create(field={'%total': 10}) - self.assertEqual(JSONModel.objects.filter(**{'field__%total': 10}).get(), obj) - - -@isolate_apps('postgres_tests') -class TestChecks(PostgreSQLSimpleTestCase): - - def test_invalid_default(self): - class MyModel(PostgreSQLModel): - field = JSONField(default={}) - - model = MyModel() - self.assertEqual(model.check(), [ - checks.Warning( - msg=( - "JSONField default should be a callable instead of an " - "instance so that it's not shared between all field " - "instances." - ), - hint='Use a callable instead, e.g., use `dict` instead of `{}`.', - obj=MyModel._meta.get_field('field'), - id='fields.E010', - ) - ]) - - def test_valid_default(self): - class MyModel(PostgreSQLModel): - field = JSONField(default=dict) - - model = MyModel() - self.assertEqual(model.check(), []) - - def test_valid_default_none(self): - class MyModel(PostgreSQLModel): - field = JSONField(default=None) - - model = MyModel() - self.assertEqual(model.check(), []) - - -class TestSerialization(PostgreSQLSimpleTestCase): - test_data = ( - '[{"fields": {"field": %s, "field_custom": null}, ' - '"model": "postgres_tests.jsonmodel", "pk": null}]' - ) - test_values = ( - # (Python value, serialized value), - ({'a': 'b', 'c': None}, '{"a": "b", "c": null}'), - ('abc', '"abc"'), - ('{"a": "a"}', '"{\\"a\\": \\"a\\"}"'), - ) - - def test_dumping(self): - for value, serialized in self.test_values: - with self.subTest(value=value): - instance = JSONModel(field=value) - data = serializers.serialize('json', [instance]) - self.assertJSONEqual(data, self.test_data % serialized) - - def test_loading(self): - for value, serialized in self.test_values: - with self.subTest(value=value): - instance = list(serializers.deserialize('json', self.test_data % serialized))[0].object - self.assertEqual(instance.field, value) - - -class TestValidation(PostgreSQLSimpleTestCase): - - def test_not_serializable(self): - field = JSONField() - with self.assertRaises(exceptions.ValidationError) as cm: - field.clean(datetime.timedelta(days=1), None) - self.assertEqual(cm.exception.code, 'invalid') - self.assertEqual(cm.exception.message % cm.exception.params, "Value must be valid JSON.") - - def test_custom_encoder(self): - with self.assertRaisesMessage(ValueError, "The encoder parameter must be a callable object."): - field = JSONField(encoder=DjangoJSONEncoder()) - field = JSONField(encoder=DjangoJSONEncoder) - self.assertEqual(field.clean(datetime.timedelta(days=1), None), datetime.timedelta(days=1)) - - -class TestFormField(PostgreSQLSimpleTestCase): - - def test_valid(self): - field = forms.JSONField() - value = field.clean('{"a": "b"}') - self.assertEqual(value, {'a': 'b'}) - - def test_valid_empty(self): - field = forms.JSONField(required=False) - value = field.clean('') - self.assertIsNone(value) - - def test_invalid(self): - field = forms.JSONField() - with self.assertRaises(exceptions.ValidationError) as cm: - field.clean('{some badly formed: json}') - self.assertEqual(cm.exception.messages[0], 'ā€œ{some badly formed: json}ā€ value must be valid JSON.') - - def test_formfield(self): - model_field = JSONField() - form_field = model_field.formfield() - self.assertIsInstance(form_field, forms.JSONField) - - def test_formfield_disabled(self): - class JsonForm(Form): - name = CharField() - jfield = forms.JSONField(disabled=True) - - form = JsonForm({'name': 'xyz', 'jfield': '["bar"]'}, initial={'jfield': ['foo']}) - self.assertIn('["foo"]', form.as_p()) - - def test_prepare_value(self): - field = forms.JSONField() - self.assertEqual(field.prepare_value({'a': 'b'}), '{"a": "b"}') - self.assertEqual(field.prepare_value(None), 'null') - self.assertEqual(field.prepare_value('foo'), '"foo"') - - def test_redisplay_wrong_input(self): - """ - When displaying a bound form (typically due to invalid input), the form - should not overquote JSONField inputs. - """ - class JsonForm(Form): - name = CharField(max_length=2) - jfield = forms.JSONField() - - # JSONField input is fine, name is too long - form = JsonForm({'name': 'xyz', 'jfield': '["foo"]'}) - self.assertIn('["foo"]', form.as_p()) - - # This time, the JSONField input is wrong - form = JsonForm({'name': 'xy', 'jfield': '{"foo"}'}) - # Appears once in the textarea and once in the error message - self.assertEqual(form.as_p().count(escape('{"foo"}')), 2) - - def test_widget(self): - """The default widget of a JSONField is a Textarea.""" - field = forms.JSONField() - self.assertIsInstance(field.widget, widgets.Textarea) - - def test_custom_widget_kwarg(self): - """The widget can be overridden with a kwarg.""" - field = forms.JSONField(widget=widgets.Input) - self.assertIsInstance(field.widget, widgets.Input) - - def test_custom_widget_attribute(self): - """The widget can be overridden with an attribute.""" - class CustomJSONField(forms.JSONField): - widget = widgets.Input - - field = CustomJSONField() - self.assertIsInstance(field.widget, widgets.Input) - - def test_already_converted_value(self): - field = forms.JSONField(required=False) - tests = [ - '["a", "b", "c"]', '{"a": 1, "b": 2}', '1', '1.5', '"foo"', - 'true', 'false', 'null', - ] - for json_string in tests: - val = field.clean(json_string) - self.assertEqual(field.clean(val), val) - - def test_has_changed(self): - field = forms.JSONField() - self.assertIs(field.has_changed({'a': True}, '{"a": 1}'), True) - self.assertIs(field.has_changed({'a': 1, 'b': 2}, '{"b": 2, "a": 1}'), False) diff --git a/tests/postgres_tests/test_json_deprecation.py b/tests/postgres_tests/test_json_deprecation.py new file mode 100644 index 0000000000..80deb0cb15 --- /dev/null +++ b/tests/postgres_tests/test_json_deprecation.py @@ -0,0 +1,54 @@ +try: + from django.contrib.postgres.fields import JSONField + from django.contrib.postgres.fields.jsonb import KeyTransform, KeyTextTransform + from django.contrib.postgres import forms +except ImportError: + pass + +from django.core.checks import Warning as DjangoWarning +from django.utils.deprecation import RemovedInDjango40Warning + +from . import PostgreSQLSimpleTestCase +from .models import PostgreSQLModel + + +class DeprecationTests(PostgreSQLSimpleTestCase): + def test_model_field_deprecation_message(self): + class PostgreSQLJSONModel(PostgreSQLModel): + field = JSONField() + + self.assertEqual(PostgreSQLJSONModel().check(), [ + DjangoWarning( + 'django.contrib.postgres.fields.JSONField is deprecated. ' + 'Support for it (except in historical migrations) will be ' + 'removed in Django 4.0.', + hint='Use django.db.models.JSONField instead.', + obj=PostgreSQLJSONModel._meta.get_field('field'), + id='fields.W904', + ), + ]) + + def test_form_field_deprecation_message(self): + msg = ( + 'django.contrib.postgres.forms.JSONField is deprecated in favor ' + 'of django.forms.JSONField.' + ) + with self.assertWarnsMessage(RemovedInDjango40Warning, msg): + forms.JSONField() + + def test_key_transform_deprecation_message(self): + msg = ( + 'django.contrib.postgres.fields.jsonb.KeyTransform is deprecated ' + 'in favor of django.db.models.fields.json.KeyTransform.' + ) + with self.assertWarnsMessage(RemovedInDjango40Warning, msg): + KeyTransform('foo', 'bar') + + def test_key_text_transform_deprecation_message(self): + msg = ( + 'django.contrib.postgres.fields.jsonb.KeyTextTransform is ' + 'deprecated in favor of ' + 'django.db.models.fields.json.KeyTextTransform.' + ) + with self.assertWarnsMessage(RemovedInDjango40Warning, msg): + KeyTextTransform('foo', 'bar') diff --git a/tests/queries/models.py b/tests/queries/models.py index fd994170dd..fc46205a79 100644 --- a/tests/queries/models.py +++ b/tests/queries/models.py @@ -747,3 +747,10 @@ class ReturningModel(models.Model): class NonIntegerPKReturningModel(models.Model): created = CreatedField(editable=False, primary_key=True) + + +class JSONFieldNullable(models.Model): + json_field = models.JSONField(blank=True, null=True) + + class Meta: + required_db_features = {'supports_json_field'} diff --git a/tests/queries/test_bulk_update.py b/tests/queries/test_bulk_update.py index e2e9a6147a..ec43c86691 100644 --- a/tests/queries/test_bulk_update.py +++ b/tests/queries/test_bulk_update.py @@ -3,11 +3,11 @@ import datetime from django.core.exceptions import FieldDoesNotExist from django.db.models import F from django.db.models.functions import Lower -from django.test import TestCase +from django.test import TestCase, skipUnlessDBFeature from .models import ( - Article, CustomDbColumn, CustomPk, Detail, Individual, Member, Note, - Number, Order, Paragraph, SpecialCategory, Tag, Valid, + Article, CustomDbColumn, CustomPk, Detail, Individual, JSONFieldNullable, + Member, Note, Number, Order, Paragraph, SpecialCategory, Tag, Valid, ) @@ -228,3 +228,14 @@ class BulkUpdateTests(TestCase): article.created = point_in_time Article.objects.bulk_update(articles, ['created']) self.assertCountEqual(Article.objects.filter(created=point_in_time), articles) + + @skipUnlessDBFeature('supports_json_field') + def test_json_field(self): + JSONFieldNullable.objects.bulk_create([ + JSONFieldNullable(json_field={'a': i}) for i in range(10) + ]) + objs = JSONFieldNullable.objects.all() + for obj in objs: + obj.json_field = {'c': obj.json_field['a'] + 1} + JSONFieldNullable.objects.bulk_update(objs, ['json_field']) + self.assertCountEqual(JSONFieldNullable.objects.filter(json_field__has_key='c'), objs)