Fixed #11442 -- Postgresql backend casts all inet types to text

This commit is contained in:
Erik Romijn 2013-05-19 13:28:09 +02:00
parent 56d6fdbbf5
commit 60d94c2a80
5 changed files with 16 additions and 12 deletions

View File

@ -765,12 +765,12 @@ class BaseDatabaseOperations(object):
""" """
return cursor.fetchone()[0] return cursor.fetchone()[0]
def field_cast_sql(self, db_type): def field_cast_sql(self, db_type, internal_type):
""" """
Given a column type (e.g. 'BLOB', 'VARCHAR'), returns the SQL necessary Given a column type (e.g. 'BLOB', 'VARCHAR'), and an internal type
to cast it before using it in a WHERE statement. Note that the (e.g. 'GenericIPAddressField'), returns the SQL necessary to cast it
resulting string should contain a '%s' placeholder for the column being before using it in a WHERE statement. Note that the resulting string
searched against. should contain a '%s' placeholder for the column being searched against.
""" """
return '%s' return '%s'

View File

@ -254,7 +254,7 @@ WHEN (new.%(col_name)s IS NULL)
def fetch_returned_insert_id(self, cursor): def fetch_returned_insert_id(self, cursor):
return int(cursor._insert_id_var.getvalue()) return int(cursor._insert_id_var.getvalue())
def field_cast_sql(self, db_type): def field_cast_sql(self, db_type, internal_type):
if db_type and db_type.endswith('LOB'): if db_type and db_type.endswith('LOB'):
return "DBMS_LOB.SUBSTR(%s)" return "DBMS_LOB.SUBSTR(%s)"
else: else:

View File

@ -78,8 +78,8 @@ class DatabaseOperations(BaseDatabaseOperations):
return lookup return lookup
def field_cast_sql(self, db_type): def field_cast_sql(self, db_type, internal_type):
if db_type == 'inet': if internal_type == "GenericIPAddressField" or internal_type == "IPAddressField":
return 'HOST(%s)' return 'HOST(%s)'
return '%s' return '%s'

View File

@ -174,6 +174,8 @@ class WhereNode(tree.Node):
it. it.
""" """
lvalue, lookup_type, value_annotation, params_or_value = child lvalue, lookup_type, value_annotation, params_or_value = child
field_internal_type = lvalue.field.get_internal_type() if lvalue.field else None
if isinstance(lvalue, Constraint): if isinstance(lvalue, Constraint):
try: try:
lvalue, params = lvalue.process(lookup_type, params_or_value, connection) lvalue, params = lvalue.process(lookup_type, params_or_value, connection)
@ -187,7 +189,7 @@ class WhereNode(tree.Node):
if isinstance(lvalue, tuple): if isinstance(lvalue, tuple):
# A direct database column lookup. # A direct database column lookup.
field_sql, field_params = self.sql_for_columns(lvalue, qn, connection), [] field_sql, field_params = self.sql_for_columns(lvalue, qn, connection, field_internal_type), []
else: else:
# A smart object with an as_sql() method. # A smart object with an as_sql() method.
field_sql, field_params = lvalue.as_sql(qn, connection) field_sql, field_params = lvalue.as_sql(qn, connection)
@ -257,7 +259,7 @@ class WhereNode(tree.Node):
raise TypeError('Invalid lookup_type: %r' % lookup_type) raise TypeError('Invalid lookup_type: %r' % lookup_type)
def sql_for_columns(self, data, qn, connection): def sql_for_columns(self, data, qn, connection, internal_type=None):
""" """
Returns the SQL fragment used for the left-hand side of a column Returns the SQL fragment used for the left-hand side of a column
constraint (for example, the "T1.foo" portion in the clause constraint (for example, the "T1.foo" portion in the clause
@ -268,7 +270,7 @@ class WhereNode(tree.Node):
lhs = '%s.%s' % (qn(table_alias), qn(name)) lhs = '%s.%s' % (qn(table_alias), qn(name))
else: else:
lhs = qn(name) lhs = qn(name)
return connection.ops.field_cast_sql(db_type) % lhs return connection.ops.field_cast_sql(db_type, internal_type) % lhs
def relabel_aliases(self, change_map): def relabel_aliases(self, change_map):
""" """

View File

@ -73,9 +73,11 @@ class StringLookupTests(TestCase):
""" """
Regression test for #708 Regression test for #708
"like" queries on IP address fields require casting to text (on PostgreSQL). "like" queries on IP address fields require casting with HOST() (on PostgreSQL).
""" """
a = Article(name='IP test', text='The body', submitted_from='192.0.2.100') a = Article(name='IP test', text='The body', submitted_from='192.0.2.100')
a.save() a.save()
self.assertEqual(repr(Article.objects.filter(submitted_from__contains='192.0.2')), self.assertEqual(repr(Article.objects.filter(submitted_from__contains='192.0.2')),
repr([a])) repr([a]))
# Test that the searches do not match the subnet mask (/32 in this case)
self.assertEqual(Article.objects.filter(submitted_from__contains='32').count(), 0)