mirror of https://github.com/django/django.git
Fixed #16055 -- Fixed crash when filtering against char/text GenericRelation relation on PostgreSQL.
This commit is contained in:
parent
594fcc2b74
commit
9bbf97bcdb
|
@ -8,6 +8,7 @@ import sqlparse
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.db import NotSupportedError, transaction
|
from django.db import NotSupportedError, transaction
|
||||||
from django.db.backends import utils
|
from django.db.backends import utils
|
||||||
|
from django.db.models.expressions import Col
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.utils.encoding import force_str
|
from django.utils.encoding import force_str
|
||||||
|
|
||||||
|
@ -776,3 +777,9 @@ class BaseDatabaseOperations:
|
||||||
|
|
||||||
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
|
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
|
||||||
|
lhs_expr = Col(lhs_table, lhs_field)
|
||||||
|
rhs_expr = Col(rhs_table, rhs_field)
|
||||||
|
|
||||||
|
return lhs_expr, rhs_expr
|
||||||
|
|
|
@ -120,6 +120,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||||
"migrations.test_operations.OperationTests."
|
"migrations.test_operations.OperationTests."
|
||||||
"test_alter_field_pk_fk_db_collation",
|
"test_alter_field_pk_fk_db_collation",
|
||||||
},
|
},
|
||||||
|
"Oracle doesn't support comparing NCLOB to NUMBER.": {
|
||||||
|
"generic_relations_regress.tests.GenericRelationTests.test_textlink_filter",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
django_test_expected_failures = {
|
django_test_expected_failures = {
|
||||||
# A bug in Django/cx_Oracle with respect to string handling (#23843).
|
# A bug in Django/cx_Oracle with respect to string handling (#23843).
|
||||||
|
|
|
@ -12,6 +12,7 @@ from django.db.backends.postgresql.psycopg_any import (
|
||||||
)
|
)
|
||||||
from django.db.backends.utils import split_tzname_delta
|
from django.db.backends.utils import split_tzname_delta
|
||||||
from django.db.models.constants import OnConflict
|
from django.db.models.constants import OnConflict
|
||||||
|
from django.db.models.functions import Cast
|
||||||
from django.utils.regex_helper import _lazy_re_compile
|
from django.utils.regex_helper import _lazy_re_compile
|
||||||
|
|
||||||
|
|
||||||
|
@ -413,3 +414,13 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||||
update_fields,
|
update_fields,
|
||||||
unique_fields,
|
unique_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
|
||||||
|
lhs_expr, rhs_expr = super().prepare_join_on_clause(
|
||||||
|
lhs_table, lhs_field, rhs_table, rhs_field
|
||||||
|
)
|
||||||
|
|
||||||
|
if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection):
|
||||||
|
rhs_expr = Cast(rhs_expr, lhs_field)
|
||||||
|
|
||||||
|
return lhs_expr, rhs_expr
|
||||||
|
|
|
@ -785,6 +785,14 @@ class ForeignObject(RelatedField):
|
||||||
def get_reverse_joining_columns(self):
|
def get_reverse_joining_columns(self):
|
||||||
return self.get_joining_columns(reverse_join=True)
|
return self.get_joining_columns(reverse_join=True)
|
||||||
|
|
||||||
|
def get_joining_fields(self, reverse_join=False):
|
||||||
|
return tuple(
|
||||||
|
self.reverse_related_fields if reverse_join else self.related_fields
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_reverse_joining_fields(self):
|
||||||
|
return self.get_joining_fields(reverse_join=True)
|
||||||
|
|
||||||
def get_extra_descriptor_filter(self, instance):
|
def get_extra_descriptor_filter(self, instance):
|
||||||
"""
|
"""
|
||||||
Return an extra filter condition for related object fetching when
|
Return an extra filter condition for related object fetching when
|
||||||
|
|
|
@ -195,6 +195,9 @@ class ForeignObjectRel(FieldCacheMixin):
|
||||||
def get_joining_columns(self):
|
def get_joining_columns(self):
|
||||||
return self.field.get_reverse_joining_columns()
|
return self.field.get_reverse_joining_columns()
|
||||||
|
|
||||||
|
def get_joining_fields(self):
|
||||||
|
return self.field.get_reverse_joining_fields()
|
||||||
|
|
||||||
def get_extra_restriction(self, alias, related_alias):
|
def get_extra_restriction(self, alias, related_alias):
|
||||||
return self.field.get_extra_restriction(related_alias, alias)
|
return self.field.get_extra_restriction(related_alias, alias)
|
||||||
|
|
||||||
|
|
|
@ -61,6 +61,14 @@ class Join:
|
||||||
self.join_type = join_type
|
self.join_type = join_type
|
||||||
# A list of 2-tuples to use in the ON clause of the JOIN.
|
# A list of 2-tuples to use in the ON clause of the JOIN.
|
||||||
# Each 2-tuple will create one join condition in the ON clause.
|
# Each 2-tuple will create one join condition in the ON clause.
|
||||||
|
if hasattr(join_field, "get_joining_fields"):
|
||||||
|
self.join_fields = join_field.get_joining_fields()
|
||||||
|
self.join_cols = tuple(
|
||||||
|
(lhs_field.column, rhs_field.column)
|
||||||
|
for lhs_field, rhs_field in self.join_fields
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.join_fields = None
|
||||||
self.join_cols = join_field.get_joining_columns()
|
self.join_cols = join_field.get_joining_columns()
|
||||||
# Along which field (or ForeignObjectRel in the reverse join case)
|
# Along which field (or ForeignObjectRel in the reverse join case)
|
||||||
self.join_field = join_field
|
self.join_field = join_field
|
||||||
|
@ -78,18 +86,21 @@ class Join:
|
||||||
params = []
|
params = []
|
||||||
qn = compiler.quote_name_unless_alias
|
qn = compiler.quote_name_unless_alias
|
||||||
qn2 = connection.ops.quote_name
|
qn2 = connection.ops.quote_name
|
||||||
|
|
||||||
# Add a join condition for each pair of joining columns.
|
# Add a join condition for each pair of joining columns.
|
||||||
for lhs_col, rhs_col in self.join_cols:
|
join_fields = self.join_fields or self.join_cols
|
||||||
join_conditions.append(
|
for lhs, rhs in join_fields:
|
||||||
"%s.%s = %s.%s"
|
if isinstance(lhs, str):
|
||||||
% (
|
lhs_full_name = "%s.%s" % (qn(self.parent_alias), qn2(lhs))
|
||||||
qn(self.parent_alias),
|
rhs_full_name = "%s.%s" % (qn(self.table_alias), qn2(rhs))
|
||||||
qn2(lhs_col),
|
else:
|
||||||
qn(self.table_alias),
|
lhs, rhs = connection.ops.prepare_join_on_clause(
|
||||||
qn2(rhs_col),
|
self.parent_alias, lhs, self.table_alias, rhs
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
lhs_sql, lhs_params = compiler.compile(lhs)
|
||||||
|
lhs_full_name = lhs_sql % lhs_params
|
||||||
|
rhs_sql, rhs_params = compiler.compile(rhs)
|
||||||
|
rhs_full_name = rhs_sql % rhs_params
|
||||||
|
join_conditions.append(f"{lhs_full_name} = {rhs_full_name}")
|
||||||
|
|
||||||
# Add a single condition inside parentheses for whatever
|
# Add a single condition inside parentheses for whatever
|
||||||
# get_extra_restriction() returns.
|
# get_extra_restriction() returns.
|
||||||
|
|
|
@ -4,6 +4,7 @@ from django.core.management.color import no_style
|
||||||
from django.db import NotSupportedError, connection, transaction
|
from django.db import NotSupportedError, connection, transaction
|
||||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||||
from django.db.models import DurationField, Value
|
from django.db.models import DurationField, Value
|
||||||
|
from django.db.models.expressions import Col
|
||||||
from django.test import (
|
from django.test import (
|
||||||
SimpleTestCase,
|
SimpleTestCase,
|
||||||
TestCase,
|
TestCase,
|
||||||
|
@ -159,6 +160,20 @@ class SimpleDatabaseOperationTests(SimpleTestCase):
|
||||||
):
|
):
|
||||||
self.ops.datetime_extract_sql(None, None, None, None)
|
self.ops.datetime_extract_sql(None, None, None, None)
|
||||||
|
|
||||||
|
def test_prepare_join_on_clause(self):
|
||||||
|
author_table = Author._meta.db_table
|
||||||
|
author_id_field = Author._meta.get_field("id")
|
||||||
|
book_table = Book._meta.db_table
|
||||||
|
book_fk_field = Book._meta.get_field("author")
|
||||||
|
lhs_expr, rhs_expr = self.ops.prepare_join_on_clause(
|
||||||
|
author_table,
|
||||||
|
author_id_field,
|
||||||
|
book_table,
|
||||||
|
book_fk_field,
|
||||||
|
)
|
||||||
|
self.assertEqual(lhs_expr, Col(author_table, author_id_field))
|
||||||
|
self.assertEqual(rhs_expr, Col(book_table, book_fk_field))
|
||||||
|
|
||||||
|
|
||||||
class DatabaseOperationTests(TestCase):
|
class DatabaseOperationTests(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|
|
@ -2,9 +2,11 @@ import unittest
|
||||||
|
|
||||||
from django.core.management.color import no_style
|
from django.core.management.color import no_style
|
||||||
from django.db import connection
|
from django.db import connection
|
||||||
|
from django.db.models.expressions import Col
|
||||||
|
from django.db.models.functions import Cast
|
||||||
from django.test import SimpleTestCase
|
from django.test import SimpleTestCase
|
||||||
|
|
||||||
from ..models import Person, Tag
|
from ..models import Author, Book, Person, Tag
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests.")
|
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests.")
|
||||||
|
@ -48,3 +50,31 @@ class PostgreSQLOperationsTests(SimpleTestCase):
|
||||||
),
|
),
|
||||||
['TRUNCATE "backends_person", "backends_tag" RESTART IDENTITY CASCADE;'],
|
['TRUNCATE "backends_person", "backends_tag" RESTART IDENTITY CASCADE;'],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_prepare_join_on_clause_same_type(self):
|
||||||
|
author_table = Author._meta.db_table
|
||||||
|
author_id_field = Author._meta.get_field("id")
|
||||||
|
lhs_expr, rhs_expr = connection.ops.prepare_join_on_clause(
|
||||||
|
author_table,
|
||||||
|
author_id_field,
|
||||||
|
author_table,
|
||||||
|
author_id_field,
|
||||||
|
)
|
||||||
|
self.assertEqual(lhs_expr, Col(author_table, author_id_field))
|
||||||
|
self.assertEqual(rhs_expr, Col(author_table, author_id_field))
|
||||||
|
|
||||||
|
def test_prepare_join_on_clause_different_types(self):
|
||||||
|
author_table = Author._meta.db_table
|
||||||
|
author_id_field = Author._meta.get_field("id")
|
||||||
|
book_table = Book._meta.db_table
|
||||||
|
book_fk_field = Book._meta.get_field("author")
|
||||||
|
lhs_expr, rhs_expr = connection.ops.prepare_join_on_clause(
|
||||||
|
author_table,
|
||||||
|
author_id_field,
|
||||||
|
book_table,
|
||||||
|
book_fk_field,
|
||||||
|
)
|
||||||
|
self.assertEqual(lhs_expr, Col(author_table, author_id_field))
|
||||||
|
self.assertEqual(
|
||||||
|
rhs_expr, Cast(Col(book_table, book_fk_field), author_id_field)
|
||||||
|
)
|
||||||
|
|
|
@ -50,7 +50,7 @@ class StartsWithRelation(models.ForeignObject):
|
||||||
from_field = self.model._meta.get_field(self.from_fields[0])
|
from_field = self.model._meta.get_field(self.from_fields[0])
|
||||||
return StartsWith(to_field.get_col(alias), from_field.get_col(related_alias))
|
return StartsWith(to_field.get_col(alias), from_field.get_col(related_alias))
|
||||||
|
|
||||||
def get_joining_columns(self, reverse_join=False):
|
def get_joining_fields(self, reverse_join=False):
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
def get_path_info(self, filtered_relation=None):
|
def get_path_info(self, filtered_relation=None):
|
||||||
|
|
|
@ -64,12 +64,14 @@ class CharLink(models.Model):
|
||||||
content_type = models.ForeignKey(ContentType, models.CASCADE)
|
content_type = models.ForeignKey(ContentType, models.CASCADE)
|
||||||
object_id = models.CharField(max_length=100)
|
object_id = models.CharField(max_length=100)
|
||||||
content_object = GenericForeignKey()
|
content_object = GenericForeignKey()
|
||||||
|
value = models.CharField(max_length=250)
|
||||||
|
|
||||||
|
|
||||||
class TextLink(models.Model):
|
class TextLink(models.Model):
|
||||||
content_type = models.ForeignKey(ContentType, models.CASCADE)
|
content_type = models.ForeignKey(ContentType, models.CASCADE)
|
||||||
object_id = models.TextField()
|
object_id = models.TextField()
|
||||||
content_object = GenericForeignKey()
|
content_object = GenericForeignKey()
|
||||||
|
value = models.CharField(max_length=250)
|
||||||
|
|
||||||
|
|
||||||
class OddRelation1(models.Model):
|
class OddRelation1(models.Model):
|
||||||
|
|
|
@ -72,6 +72,20 @@ class GenericRelationTests(TestCase):
|
||||||
TextLink.objects.create(content_object=oddrel)
|
TextLink.objects.create(content_object=oddrel)
|
||||||
oddrel.delete()
|
oddrel.delete()
|
||||||
|
|
||||||
|
def test_charlink_filter(self):
|
||||||
|
oddrel = OddRelation1.objects.create(name="clink")
|
||||||
|
CharLink.objects.create(content_object=oddrel, value="value")
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
OddRelation1.objects.filter(clinks__value="value"), [oddrel]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_textlink_filter(self):
|
||||||
|
oddrel = OddRelation2.objects.create(name="clink")
|
||||||
|
TextLink.objects.create(content_object=oddrel, value="value")
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
OddRelation2.objects.filter(tlinks__value="value"), [oddrel]
|
||||||
|
)
|
||||||
|
|
||||||
def test_coerce_object_id_remote_field_cache_persistence(self):
|
def test_coerce_object_id_remote_field_cache_persistence(self):
|
||||||
restaurant = Restaurant.objects.create()
|
restaurant = Restaurant.objects.create()
|
||||||
CharLink.objects.create(content_object=restaurant)
|
CharLink.objects.create(content_object=restaurant)
|
||||||
|
|
Loading…
Reference in New Issue