From 39d95fb6ada99c59d47fa0eae6d3128abafe2d58 Mon Sep 17 00:00:00 2001 From: Marc Tamlyn Date: Sat, 10 Jan 2015 18:13:28 +0000 Subject: [PATCH] Fixed #24092 -- Widened base field support for ArrayField. Several issues resolved here, following from a report that a base_field of GenericIpAddressField was failing. We were using get_prep_value instead of get_db_prep_value in ArrayField which was bypassing any extra modifications to the value being made in the base field's get_db_prep_value. Changing this broke datetime support, so the postgres backend has gained the relevant operation methods to send dates/times/datetimes directly to the db backend instead of casting them to strings. Similarly, a new database feature has been added allowing the uuid to be passed directly to the backend, as we do with timedeltas. On the other side, psycopg2 expects an Inet() instance for IP address fields, so we add a value_to_db_ipaddress method to wrap the strings on postgres. We also have to manually add a database adapter to psycopg2, as we do not wish to use the built in adapter which would turn everything into Inet() instances. Thanks to smclenithan for the report. --- django/contrib/postgres/fields/array.py | 4 +-- django/db/backends/base/features.py | 3 ++ django/db/backends/base/operations.py | 17 +++++++---- django/db/backends/oracle/operations.py | 2 +- .../db/backends/postgresql_psycopg2/base.py | 10 +++++++ .../backends/postgresql_psycopg2/features.py | 1 + .../postgresql_psycopg2/operations.py | 28 ++++++++++++----- django/db/models/fields/__init__.py | 6 ++-- django/db/models/lookups.py | 2 +- tests/model_fields/tests.py | 5 ++++ .../migrations/0002_create_test_models.py | 16 +++++++++- tests/postgres_tests/models.py | 10 ++++++- tests/postgres_tests/test_array.py | 30 +++++++++++++++++-- 13 files changed, 111 insertions(+), 23 deletions(-) diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index 318afabd2c..b0850b92e7 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -70,9 +70,9 @@ class ArrayField(Field): size = self.size or '' return '%s[%s]' % (self.base_field.db_type(connection), size) - def get_prep_value(self, value): + def get_db_prep_value(self, value, connection, prepared=False): if isinstance(value, list) or isinstance(value, tuple): - return [self.base_field.get_prep_value(i) for i in value] + return [self.base_field.get_db_prep_value(i, connection, prepared) for i in value] return value def deconstruct(self): diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index fe13827b77..0f6ee0efe3 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -59,6 +59,9 @@ class BaseDatabaseFeatures(object): supports_subqueries_in_group_by = True supports_bitwise_or = True + # Is there a true datatype for uuid? + has_native_uuid_field = False + # Is there a true datatype for timedeltas? has_native_duration_field = False diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index c4e78e719e..24bcbb3d08 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -219,7 +219,7 @@ class BaseDatabaseOperations(object): """ return cursor.lastrowid - def lookup_cast(self, lookup_type): + def lookup_cast(self, lookup_type, internal_type=None): """ Returns the string to use in a query when performing lookups ("contains", "like", etc). The resulting string should contain a '%s' @@ -442,7 +442,7 @@ class BaseDatabaseOperations(object): def value_to_db_date(self, value): """ - Transform a date value to an object compatible with what is expected + Transforms a date value to an object compatible with what is expected by the backend driver for date columns. """ if value is None: @@ -451,7 +451,7 @@ class BaseDatabaseOperations(object): def value_to_db_datetime(self, value): """ - Transform a datetime value to an object compatible with what is expected + Transforms a datetime value to an object compatible with what is expected by the backend driver for datetime columns. """ if value is None: @@ -460,7 +460,7 @@ class BaseDatabaseOperations(object): def value_to_db_time(self, value): """ - Transform a time value to an object compatible with what is expected + Transforms a time value to an object compatible with what is expected by the backend driver for time columns. """ if value is None: @@ -471,11 +471,18 @@ class BaseDatabaseOperations(object): def value_to_db_decimal(self, value, max_digits, decimal_places): """ - Transform a decimal.Decimal value to an object compatible with what is + Transforms a decimal.Decimal value to an object compatible with what is expected by the backend driver for decimal (numeric) columns. """ return utils.format_number(value, max_digits, decimal_places) + def value_to_db_ipaddress(self, value): + """ + Transforms a string representation of an IP address into the expected + type for the backend driver. + """ + return value + def year_lookup_bounds_for_date_field(self, value): """ Returns a two-elements list with the lower and upper bound to be used diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index f00fd3fbea..fe9c93ba90 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -246,7 +246,7 @@ WHEN (new.%(col_name)s IS NULL) cursor.execute('SELECT "%s".currval FROM dual' % sq_name) return cursor.fetchone()[0] - def lookup_cast(self, lookup_type): + def lookup_cast(self, lookup_type, internal_type=None): if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'): return "UPPER(%s)" return "%s" diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index 34b2870773..37433c3987 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -38,6 +38,16 @@ psycopg2.extensions.register_adapter(SafeBytes, psycopg2.extensions.QuotedString psycopg2.extensions.register_adapter(SafeText, psycopg2.extensions.QuotedString) psycopg2.extras.register_uuid() +# Register support for inet[] manually so we don't have to handle the Inet() +# object on load all the time. +INETARRAY_OID = 1041 +INETARRAY = psycopg2.extensions.new_array_type( + (INETARRAY_OID,), + 'INETARRAY', + psycopg2.extensions.UNICODE, +) +psycopg2.extensions.register_type(INETARRAY) + class DatabaseWrapper(BaseDatabaseWrapper): vendor = 'postgresql' diff --git a/django/db/backends/postgresql_psycopg2/features.py b/django/db/backends/postgresql_psycopg2/features.py index 64acd0570a..6bb6de1a96 100644 --- a/django/db/backends/postgresql_psycopg2/features.py +++ b/django/db/backends/postgresql_psycopg2/features.py @@ -6,6 +6,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): needs_datetime_string_cast = False can_return_id_from_insert = True has_real_datatype = True + has_native_uuid_field = True has_native_duration_field = True driver_supports_timedelta_args = True can_defer_constraint_checks = True diff --git a/django/db/backends/postgresql_psycopg2/operations.py b/django/db/backends/postgresql_psycopg2/operations.py index 8e90a4020b..27b19db459 100644 --- a/django/db/backends/postgresql_psycopg2/operations.py +++ b/django/db/backends/postgresql_psycopg2/operations.py @@ -3,6 +3,8 @@ from __future__ import unicode_literals from django.conf import settings from django.db.backends.base.operations import BaseDatabaseOperations +from psycopg2.extras import Inet + class DatabaseOperations(BaseDatabaseOperations): def unification_cast_sql(self, output_field): @@ -57,13 +59,16 @@ class DatabaseOperations(BaseDatabaseOperations): def deferrable_sql(self): return " DEFERRABLE INITIALLY DEFERRED" - def lookup_cast(self, lookup_type): + def lookup_cast(self, lookup_type, internal_type=None): lookup = '%s' # Cast text lookups to text to allow things like filter(x__contains=4) if lookup_type in ('iexact', 'contains', 'icontains', 'startswith', 'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'): - lookup = "%s::text" + if internal_type in ('IPAddressField', 'GenericIPAddressField'): + lookup = "HOST(%s)" + else: + lookup = "%s::text" # Use UPPER(x) for case-insensitive lookups; it's faster. if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'): @@ -71,11 +76,6 @@ class DatabaseOperations(BaseDatabaseOperations): return lookup - def field_cast_sql(self, db_type, internal_type): - if internal_type == "GenericIPAddressField" or internal_type == "IPAddressField": - return 'HOST(%s)' - return '%s' - def last_insert_id(self, cursor, table_name, pk_name): # Use pg_get_serial_sequence to get the underlying sequence name # from the table name and column name (available since PostgreSQL 8) @@ -224,3 +224,17 @@ class DatabaseOperations(BaseDatabaseOperations): def bulk_insert_sql(self, fields, num_values): items_sql = "(%s)" % ", ".join(["%s"] * len(fields)) return "VALUES " + ", ".join([items_sql] * num_values) + + def value_to_db_date(self, value): + return value + + def value_to_db_datetime(self, value): + return value + + def value_to_db_time(self, value): + return value + + def value_to_db_ipaddress(self, value): + if value: + return Inet(value) + return None diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 03c7eafac6..d5dfac733f 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -1983,7 +1983,7 @@ class GenericIPAddressField(Field): def get_db_prep_value(self, value, connection, prepared=False): if not prepared: value = self.get_prep_value(value) - return value or None + return connection.ops.value_to_db_ipaddress(value) def get_prep_value(self, value): value = super(GenericIPAddressField, self).get_prep_value(value) @@ -2366,8 +2366,10 @@ class UUIDField(Field): def get_internal_type(self): return "UUIDField" - def get_prep_value(self, value): + def get_db_prep_value(self, value, connection, prepared=False): if isinstance(value, uuid.UUID): + if connection.features.has_native_uuid_field: + return value return value.hex if isinstance(value, six.string_types): return value.replace('-', '') diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index d7423762f3..7610c0dde4 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -198,7 +198,7 @@ class BuiltinLookup(Lookup): db_type = self.lhs.output_field.db_type(connection=connection) lhs_sql = connection.ops.field_cast_sql( db_type, field_internal_type) % lhs_sql - lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql + lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql return lhs_sql, params def as_sql(self, compiler, connection): diff --git a/tests/model_fields/tests.py b/tests/model_fields/tests.py index 897359c0cf..a9ce43cae6 100644 --- a/tests/model_fields/tests.py +++ b/tests/model_fields/tests.py @@ -695,6 +695,11 @@ class GenericIPAddressFieldTests(test.TestCase): o = GenericIPAddress.objects.get() self.assertIsNone(o.ip) + def test_save_load(self): + instance = GenericIPAddress.objects.create(ip='::1') + loaded = GenericIPAddress.objects.get() + self.assertEqual(loaded.ip, instance.ip) + class PromiseTest(test.TestCase): def test_AutoField(self): diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py index bdde4a9bf6..841953d351 100644 --- a/tests/postgres_tests/migrations/0002_create_test_models.py +++ b/tests/postgres_tests/migrations/0002_create_test_models.py @@ -27,7 +27,9 @@ class Migration(migrations.Migration): name='DateTimeArrayModel', fields=[ ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('field', django.contrib.postgres.fields.ArrayField(models.DateTimeField(), size=None)), + ('datetimes', django.contrib.postgres.fields.ArrayField(models.DateTimeField(), size=None)), + ('dates', django.contrib.postgres.fields.ArrayField(models.DateField(), size=None)), + ('times', django.contrib.postgres.fields.ArrayField(models.TimeField(), size=None)), ], options={ }, @@ -43,6 +45,18 @@ class Migration(migrations.Migration): }, bases=(models.Model,), ), + migrations.CreateModel( + name='OtherTypesArrayModel', + fields=[ + ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), + ('ips', django.contrib.postgres.fields.ArrayField(models.GenericIPAddressField(), size=None)), + ('uuids', django.contrib.postgres.fields.ArrayField(models.UUIDField(), size=None)), + ('decimals', django.contrib.postgres.fields.ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)), + ], + options={ + }, + bases=(models.Model,), + ), migrations.CreateModel( name='IntegerArrayModel', fields=[ diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py index 74af39dd04..0422aba6a0 100644 --- a/tests/postgres_tests/models.py +++ b/tests/postgres_tests/models.py @@ -18,13 +18,21 @@ class CharArrayModel(models.Model): class DateTimeArrayModel(models.Model): - field = ArrayField(models.DateTimeField()) + datetimes = ArrayField(models.DateTimeField()) + dates = ArrayField(models.DateField()) + times = ArrayField(models.TimeField()) class NestedIntegerArrayModel(models.Model): field = ArrayField(ArrayField(models.IntegerField())) +class OtherTypesArrayModel(models.Model): + ips = ArrayField(models.GenericIPAddressField()) + uuids = ArrayField(models.UUIDField()) + decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2)) + + class HStoreModel(models.Model): field = HStoreField(blank=True, null=True) diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py index 90f4c246c6..5c300f7ea3 100644 --- a/tests/postgres_tests/test_array.py +++ b/tests/postgres_tests/test_array.py @@ -1,5 +1,7 @@ +import decimal import json import unittest +import uuid from django.contrib.postgres.fields import ArrayField from django.contrib.postgres.forms import SimpleArrayField, SplitArrayField @@ -10,7 +12,11 @@ from django import forms from django.test import TestCase, override_settings from django.utils import timezone -from .models import IntegerArrayModel, NullableIntegerArrayModel, CharArrayModel, DateTimeArrayModel, NestedIntegerArrayModel, ArrayFieldSubclass +from .models import ( + IntegerArrayModel, NullableIntegerArrayModel, CharArrayModel, + DateTimeArrayModel, NestedIntegerArrayModel, OtherTypesArrayModel, + ArrayFieldSubclass, +) @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required') @@ -29,10 +35,16 @@ class TestSaveLoad(TestCase): self.assertEqual(instance.field, loaded.field) def test_dates(self): - instance = DateTimeArrayModel(field=[timezone.now()]) + instance = DateTimeArrayModel( + datetimes=[timezone.now()], + dates=[timezone.now().date()], + times=[timezone.now().time()], + ) instance.save() loaded = DateTimeArrayModel.objects.get() - self.assertEqual(instance.field, loaded.field) + self.assertEqual(instance.datetimes, loaded.datetimes) + self.assertEqual(instance.dates, loaded.dates) + self.assertEqual(instance.times, loaded.times) def test_tuples(self): instance = IntegerArrayModel(field=(1,)) @@ -70,6 +82,18 @@ class TestSaveLoad(TestCase): loaded = NestedIntegerArrayModel.objects.get() self.assertEqual(instance.field, loaded.field) + def test_other_array_types(self): + instance = OtherTypesArrayModel( + ips=['192.168.0.1', '::1'], + uuids=[uuid.uuid4()], + decimals=[decimal.Decimal(1.25), 1.75], + ) + instance.save() + loaded = OtherTypesArrayModel.objects.get() + self.assertEqual(instance.ips, loaded.ips) + self.assertEqual(instance.uuids, loaded.uuids) + self.assertEqual(instance.decimals, loaded.decimals) + @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required') class TestQuerying(TestCase):