diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index 318afabd2c..b0850b92e7 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -70,9 +70,9 @@ class ArrayField(Field): size = self.size or '' return '%s[%s]' % (self.base_field.db_type(connection), size) - def get_prep_value(self, value): + def get_db_prep_value(self, value, connection, prepared=False): if isinstance(value, list) or isinstance(value, tuple): - return [self.base_field.get_prep_value(i) for i in value] + return [self.base_field.get_db_prep_value(i, connection, prepared) for i in value] return value def deconstruct(self): diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index fe13827b77..0f6ee0efe3 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -59,6 +59,9 @@ class BaseDatabaseFeatures(object): supports_subqueries_in_group_by = True supports_bitwise_or = True + # Is there a true datatype for uuid? + has_native_uuid_field = False + # Is there a true datatype for timedeltas? has_native_duration_field = False diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index c4e78e719e..24bcbb3d08 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -219,7 +219,7 @@ class BaseDatabaseOperations(object): """ return cursor.lastrowid - def lookup_cast(self, lookup_type): + def lookup_cast(self, lookup_type, internal_type=None): """ Returns the string to use in a query when performing lookups ("contains", "like", etc). The resulting string should contain a '%s' @@ -442,7 +442,7 @@ class BaseDatabaseOperations(object): def value_to_db_date(self, value): """ - Transform a date value to an object compatible with what is expected + Transforms a date value to an object compatible with what is expected by the backend driver for date columns. """ if value is None: @@ -451,7 +451,7 @@ class BaseDatabaseOperations(object): def value_to_db_datetime(self, value): """ - Transform a datetime value to an object compatible with what is expected + Transforms a datetime value to an object compatible with what is expected by the backend driver for datetime columns. """ if value is None: @@ -460,7 +460,7 @@ class BaseDatabaseOperations(object): def value_to_db_time(self, value): """ - Transform a time value to an object compatible with what is expected + Transforms a time value to an object compatible with what is expected by the backend driver for time columns. """ if value is None: @@ -471,11 +471,18 @@ class BaseDatabaseOperations(object): def value_to_db_decimal(self, value, max_digits, decimal_places): """ - Transform a decimal.Decimal value to an object compatible with what is + Transforms a decimal.Decimal value to an object compatible with what is expected by the backend driver for decimal (numeric) columns. """ return utils.format_number(value, max_digits, decimal_places) + def value_to_db_ipaddress(self, value): + """ + Transforms a string representation of an IP address into the expected + type for the backend driver. + """ + return value + def year_lookup_bounds_for_date_field(self, value): """ Returns a two-elements list with the lower and upper bound to be used diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index f00fd3fbea..fe9c93ba90 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -246,7 +246,7 @@ WHEN (new.%(col_name)s IS NULL) cursor.execute('SELECT "%s".currval FROM dual' % sq_name) return cursor.fetchone()[0] - def lookup_cast(self, lookup_type): + def lookup_cast(self, lookup_type, internal_type=None): if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'): return "UPPER(%s)" return "%s" diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index 34b2870773..37433c3987 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -38,6 +38,16 @@ psycopg2.extensions.register_adapter(SafeBytes, psycopg2.extensions.QuotedString psycopg2.extensions.register_adapter(SafeText, psycopg2.extensions.QuotedString) psycopg2.extras.register_uuid() +# Register support for inet[] manually so we don't have to handle the Inet() +# object on load all the time. +INETARRAY_OID = 1041 +INETARRAY = psycopg2.extensions.new_array_type( + (INETARRAY_OID,), + 'INETARRAY', + psycopg2.extensions.UNICODE, +) +psycopg2.extensions.register_type(INETARRAY) + class DatabaseWrapper(BaseDatabaseWrapper): vendor = 'postgresql' diff --git a/django/db/backends/postgresql_psycopg2/features.py b/django/db/backends/postgresql_psycopg2/features.py index 64acd0570a..6bb6de1a96 100644 --- a/django/db/backends/postgresql_psycopg2/features.py +++ b/django/db/backends/postgresql_psycopg2/features.py @@ -6,6 +6,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): needs_datetime_string_cast = False can_return_id_from_insert = True has_real_datatype = True + has_native_uuid_field = True has_native_duration_field = True driver_supports_timedelta_args = True can_defer_constraint_checks = True diff --git a/django/db/backends/postgresql_psycopg2/operations.py b/django/db/backends/postgresql_psycopg2/operations.py index 8e90a4020b..27b19db459 100644 --- a/django/db/backends/postgresql_psycopg2/operations.py +++ b/django/db/backends/postgresql_psycopg2/operations.py @@ -3,6 +3,8 @@ from __future__ import unicode_literals from django.conf import settings from django.db.backends.base.operations import BaseDatabaseOperations +from psycopg2.extras import Inet + class DatabaseOperations(BaseDatabaseOperations): def unification_cast_sql(self, output_field): @@ -57,13 +59,16 @@ class DatabaseOperations(BaseDatabaseOperations): def deferrable_sql(self): return " DEFERRABLE INITIALLY DEFERRED" - def lookup_cast(self, lookup_type): + def lookup_cast(self, lookup_type, internal_type=None): lookup = '%s' # Cast text lookups to text to allow things like filter(x__contains=4) if lookup_type in ('iexact', 'contains', 'icontains', 'startswith', 'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'): - lookup = "%s::text" + if internal_type in ('IPAddressField', 'GenericIPAddressField'): + lookup = "HOST(%s)" + else: + lookup = "%s::text" # Use UPPER(x) for case-insensitive lookups; it's faster. if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'): @@ -71,11 +76,6 @@ class DatabaseOperations(BaseDatabaseOperations): return lookup - def field_cast_sql(self, db_type, internal_type): - if internal_type == "GenericIPAddressField" or internal_type == "IPAddressField": - return 'HOST(%s)' - return '%s' - def last_insert_id(self, cursor, table_name, pk_name): # Use pg_get_serial_sequence to get the underlying sequence name # from the table name and column name (available since PostgreSQL 8) @@ -224,3 +224,17 @@ class DatabaseOperations(BaseDatabaseOperations): def bulk_insert_sql(self, fields, num_values): items_sql = "(%s)" % ", ".join(["%s"] * len(fields)) return "VALUES " + ", ".join([items_sql] * num_values) + + def value_to_db_date(self, value): + return value + + def value_to_db_datetime(self, value): + return value + + def value_to_db_time(self, value): + return value + + def value_to_db_ipaddress(self, value): + if value: + return Inet(value) + return None diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 03c7eafac6..d5dfac733f 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -1983,7 +1983,7 @@ class GenericIPAddressField(Field): def get_db_prep_value(self, value, connection, prepared=False): if not prepared: value = self.get_prep_value(value) - return value or None + return connection.ops.value_to_db_ipaddress(value) def get_prep_value(self, value): value = super(GenericIPAddressField, self).get_prep_value(value) @@ -2366,8 +2366,10 @@ class UUIDField(Field): def get_internal_type(self): return "UUIDField" - def get_prep_value(self, value): + def get_db_prep_value(self, value, connection, prepared=False): if isinstance(value, uuid.UUID): + if connection.features.has_native_uuid_field: + return value return value.hex if isinstance(value, six.string_types): return value.replace('-', '') diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index d7423762f3..7610c0dde4 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -198,7 +198,7 @@ class BuiltinLookup(Lookup): db_type = self.lhs.output_field.db_type(connection=connection) lhs_sql = connection.ops.field_cast_sql( db_type, field_internal_type) % lhs_sql - lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql + lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql return lhs_sql, params def as_sql(self, compiler, connection): diff --git a/tests/model_fields/tests.py b/tests/model_fields/tests.py index 897359c0cf..a9ce43cae6 100644 --- a/tests/model_fields/tests.py +++ b/tests/model_fields/tests.py @@ -695,6 +695,11 @@ class GenericIPAddressFieldTests(test.TestCase): o = GenericIPAddress.objects.get() self.assertIsNone(o.ip) + def test_save_load(self): + instance = GenericIPAddress.objects.create(ip='::1') + loaded = GenericIPAddress.objects.get() + self.assertEqual(loaded.ip, instance.ip) + class PromiseTest(test.TestCase): def test_AutoField(self): diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py index bdde4a9bf6..841953d351 100644 --- a/tests/postgres_tests/migrations/0002_create_test_models.py +++ b/tests/postgres_tests/migrations/0002_create_test_models.py @@ -27,7 +27,9 @@ class Migration(migrations.Migration): name='DateTimeArrayModel', fields=[ ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('field', django.contrib.postgres.fields.ArrayField(models.DateTimeField(), size=None)), + ('datetimes', django.contrib.postgres.fields.ArrayField(models.DateTimeField(), size=None)), + ('dates', django.contrib.postgres.fields.ArrayField(models.DateField(), size=None)), + ('times', django.contrib.postgres.fields.ArrayField(models.TimeField(), size=None)), ], options={ }, @@ -43,6 +45,18 @@ class Migration(migrations.Migration): }, bases=(models.Model,), ), + migrations.CreateModel( + name='OtherTypesArrayModel', + fields=[ + ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), + ('ips', django.contrib.postgres.fields.ArrayField(models.GenericIPAddressField(), size=None)), + ('uuids', django.contrib.postgres.fields.ArrayField(models.UUIDField(), size=None)), + ('decimals', django.contrib.postgres.fields.ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)), + ], + options={ + }, + bases=(models.Model,), + ), migrations.CreateModel( name='IntegerArrayModel', fields=[ diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py index 74af39dd04..0422aba6a0 100644 --- a/tests/postgres_tests/models.py +++ b/tests/postgres_tests/models.py @@ -18,13 +18,21 @@ class CharArrayModel(models.Model): class DateTimeArrayModel(models.Model): - field = ArrayField(models.DateTimeField()) + datetimes = ArrayField(models.DateTimeField()) + dates = ArrayField(models.DateField()) + times = ArrayField(models.TimeField()) class NestedIntegerArrayModel(models.Model): field = ArrayField(ArrayField(models.IntegerField())) +class OtherTypesArrayModel(models.Model): + ips = ArrayField(models.GenericIPAddressField()) + uuids = ArrayField(models.UUIDField()) + decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2)) + + class HStoreModel(models.Model): field = HStoreField(blank=True, null=True) diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py index 90f4c246c6..5c300f7ea3 100644 --- a/tests/postgres_tests/test_array.py +++ b/tests/postgres_tests/test_array.py @@ -1,5 +1,7 @@ +import decimal import json import unittest +import uuid from django.contrib.postgres.fields import ArrayField from django.contrib.postgres.forms import SimpleArrayField, SplitArrayField @@ -10,7 +12,11 @@ from django import forms from django.test import TestCase, override_settings from django.utils import timezone -from .models import IntegerArrayModel, NullableIntegerArrayModel, CharArrayModel, DateTimeArrayModel, NestedIntegerArrayModel, ArrayFieldSubclass +from .models import ( + IntegerArrayModel, NullableIntegerArrayModel, CharArrayModel, + DateTimeArrayModel, NestedIntegerArrayModel, OtherTypesArrayModel, + ArrayFieldSubclass, +) @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required') @@ -29,10 +35,16 @@ class TestSaveLoad(TestCase): self.assertEqual(instance.field, loaded.field) def test_dates(self): - instance = DateTimeArrayModel(field=[timezone.now()]) + instance = DateTimeArrayModel( + datetimes=[timezone.now()], + dates=[timezone.now().date()], + times=[timezone.now().time()], + ) instance.save() loaded = DateTimeArrayModel.objects.get() - self.assertEqual(instance.field, loaded.field) + self.assertEqual(instance.datetimes, loaded.datetimes) + self.assertEqual(instance.dates, loaded.dates) + self.assertEqual(instance.times, loaded.times) def test_tuples(self): instance = IntegerArrayModel(field=(1,)) @@ -70,6 +82,18 @@ class TestSaveLoad(TestCase): loaded = NestedIntegerArrayModel.objects.get() self.assertEqual(instance.field, loaded.field) + def test_other_array_types(self): + instance = OtherTypesArrayModel( + ips=['192.168.0.1', '::1'], + uuids=[uuid.uuid4()], + decimals=[decimal.Decimal(1.25), 1.75], + ) + instance.save() + loaded = OtherTypesArrayModel.objects.get() + self.assertEqual(instance.ips, loaded.ips) + self.assertEqual(instance.uuids, loaded.uuids) + self.assertEqual(instance.decimals, loaded.decimals) + @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required') class TestQuerying(TestCase):