Fixed #16055 -- Fixed crash when filtering against char/text GenericRelation relation on PostgreSQL.

This commit is contained in:
David Wobrock 2023-04-18 10:19:06 +02:00 committed by Mariusz Felisiak
parent 594fcc2b74
commit 9bbf97bcdb
11 changed files with 117 additions and 13 deletions

View File

@ -8,6 +8,7 @@ import sqlparse
from django.conf import settings from django.conf import settings
from django.db import NotSupportedError, transaction from django.db import NotSupportedError, transaction
from django.db.backends import utils from django.db.backends import utils
from django.db.models.expressions import Col
from django.utils import timezone from django.utils import timezone
from django.utils.encoding import force_str from django.utils.encoding import force_str
@ -776,3 +777,9 @@ class BaseDatabaseOperations:
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
return "" return ""
def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
lhs_expr = Col(lhs_table, lhs_field)
rhs_expr = Col(rhs_table, rhs_field)
return lhs_expr, rhs_expr

View File

@ -120,6 +120,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"migrations.test_operations.OperationTests." "migrations.test_operations.OperationTests."
"test_alter_field_pk_fk_db_collation", "test_alter_field_pk_fk_db_collation",
}, },
"Oracle doesn't support comparing NCLOB to NUMBER.": {
"generic_relations_regress.tests.GenericRelationTests.test_textlink_filter",
},
} }
django_test_expected_failures = { django_test_expected_failures = {
# A bug in Django/cx_Oracle with respect to string handling (#23843). # A bug in Django/cx_Oracle with respect to string handling (#23843).

View File

@ -12,6 +12,7 @@ from django.db.backends.postgresql.psycopg_any import (
) )
from django.db.backends.utils import split_tzname_delta from django.db.backends.utils import split_tzname_delta
from django.db.models.constants import OnConflict from django.db.models.constants import OnConflict
from django.db.models.functions import Cast
from django.utils.regex_helper import _lazy_re_compile from django.utils.regex_helper import _lazy_re_compile
@ -413,3 +414,13 @@ class DatabaseOperations(BaseDatabaseOperations):
update_fields, update_fields,
unique_fields, unique_fields,
) )
def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
lhs_expr, rhs_expr = super().prepare_join_on_clause(
lhs_table, lhs_field, rhs_table, rhs_field
)
if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection):
rhs_expr = Cast(rhs_expr, lhs_field)
return lhs_expr, rhs_expr

View File

@ -785,6 +785,14 @@ class ForeignObject(RelatedField):
def get_reverse_joining_columns(self): def get_reverse_joining_columns(self):
return self.get_joining_columns(reverse_join=True) return self.get_joining_columns(reverse_join=True)
def get_joining_fields(self, reverse_join=False):
return tuple(
self.reverse_related_fields if reverse_join else self.related_fields
)
def get_reverse_joining_fields(self):
return self.get_joining_fields(reverse_join=True)
def get_extra_descriptor_filter(self, instance): def get_extra_descriptor_filter(self, instance):
""" """
Return an extra filter condition for related object fetching when Return an extra filter condition for related object fetching when

View File

@ -195,6 +195,9 @@ class ForeignObjectRel(FieldCacheMixin):
def get_joining_columns(self): def get_joining_columns(self):
return self.field.get_reverse_joining_columns() return self.field.get_reverse_joining_columns()
def get_joining_fields(self):
return self.field.get_reverse_joining_fields()
def get_extra_restriction(self, alias, related_alias): def get_extra_restriction(self, alias, related_alias):
return self.field.get_extra_restriction(related_alias, alias) return self.field.get_extra_restriction(related_alias, alias)

View File

@ -61,6 +61,14 @@ class Join:
self.join_type = join_type self.join_type = join_type
# A list of 2-tuples to use in the ON clause of the JOIN. # A list of 2-tuples to use in the ON clause of the JOIN.
# Each 2-tuple will create one join condition in the ON clause. # Each 2-tuple will create one join condition in the ON clause.
if hasattr(join_field, "get_joining_fields"):
self.join_fields = join_field.get_joining_fields()
self.join_cols = tuple(
(lhs_field.column, rhs_field.column)
for lhs_field, rhs_field in self.join_fields
)
else:
self.join_fields = None
self.join_cols = join_field.get_joining_columns() self.join_cols = join_field.get_joining_columns()
# Along which field (or ForeignObjectRel in the reverse join case) # Along which field (or ForeignObjectRel in the reverse join case)
self.join_field = join_field self.join_field = join_field
@ -78,18 +86,21 @@ class Join:
params = [] params = []
qn = compiler.quote_name_unless_alias qn = compiler.quote_name_unless_alias
qn2 = connection.ops.quote_name qn2 = connection.ops.quote_name
# Add a join condition for each pair of joining columns. # Add a join condition for each pair of joining columns.
for lhs_col, rhs_col in self.join_cols: join_fields = self.join_fields or self.join_cols
join_conditions.append( for lhs, rhs in join_fields:
"%s.%s = %s.%s" if isinstance(lhs, str):
% ( lhs_full_name = "%s.%s" % (qn(self.parent_alias), qn2(lhs))
qn(self.parent_alias), rhs_full_name = "%s.%s" % (qn(self.table_alias), qn2(rhs))
qn2(lhs_col), else:
qn(self.table_alias), lhs, rhs = connection.ops.prepare_join_on_clause(
qn2(rhs_col), self.parent_alias, lhs, self.table_alias, rhs
)
) )
lhs_sql, lhs_params = compiler.compile(lhs)
lhs_full_name = lhs_sql % lhs_params
rhs_sql, rhs_params = compiler.compile(rhs)
rhs_full_name = rhs_sql % rhs_params
join_conditions.append(f"{lhs_full_name} = {rhs_full_name}")
# Add a single condition inside parentheses for whatever # Add a single condition inside parentheses for whatever
# get_extra_restriction() returns. # get_extra_restriction() returns.

View File

@ -4,6 +4,7 @@ from django.core.management.color import no_style
from django.db import NotSupportedError, connection, transaction from django.db import NotSupportedError, connection, transaction
from django.db.backends.base.operations import BaseDatabaseOperations from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models import DurationField, Value from django.db.models import DurationField, Value
from django.db.models.expressions import Col
from django.test import ( from django.test import (
SimpleTestCase, SimpleTestCase,
TestCase, TestCase,
@ -159,6 +160,20 @@ class SimpleDatabaseOperationTests(SimpleTestCase):
): ):
self.ops.datetime_extract_sql(None, None, None, None) self.ops.datetime_extract_sql(None, None, None, None)
def test_prepare_join_on_clause(self):
author_table = Author._meta.db_table
author_id_field = Author._meta.get_field("id")
book_table = Book._meta.db_table
book_fk_field = Book._meta.get_field("author")
lhs_expr, rhs_expr = self.ops.prepare_join_on_clause(
author_table,
author_id_field,
book_table,
book_fk_field,
)
self.assertEqual(lhs_expr, Col(author_table, author_id_field))
self.assertEqual(rhs_expr, Col(book_table, book_fk_field))
class DatabaseOperationTests(TestCase): class DatabaseOperationTests(TestCase):
def setUp(self): def setUp(self):

View File

@ -2,9 +2,11 @@ import unittest
from django.core.management.color import no_style from django.core.management.color import no_style
from django.db import connection from django.db import connection
from django.db.models.expressions import Col
from django.db.models.functions import Cast
from django.test import SimpleTestCase from django.test import SimpleTestCase
from ..models import Person, Tag from ..models import Author, Book, Person, Tag
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests.") @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests.")
@ -48,3 +50,31 @@ class PostgreSQLOperationsTests(SimpleTestCase):
), ),
['TRUNCATE "backends_person", "backends_tag" RESTART IDENTITY CASCADE;'], ['TRUNCATE "backends_person", "backends_tag" RESTART IDENTITY CASCADE;'],
) )
def test_prepare_join_on_clause_same_type(self):
author_table = Author._meta.db_table
author_id_field = Author._meta.get_field("id")
lhs_expr, rhs_expr = connection.ops.prepare_join_on_clause(
author_table,
author_id_field,
author_table,
author_id_field,
)
self.assertEqual(lhs_expr, Col(author_table, author_id_field))
self.assertEqual(rhs_expr, Col(author_table, author_id_field))
def test_prepare_join_on_clause_different_types(self):
author_table = Author._meta.db_table
author_id_field = Author._meta.get_field("id")
book_table = Book._meta.db_table
book_fk_field = Book._meta.get_field("author")
lhs_expr, rhs_expr = connection.ops.prepare_join_on_clause(
author_table,
author_id_field,
book_table,
book_fk_field,
)
self.assertEqual(lhs_expr, Col(author_table, author_id_field))
self.assertEqual(
rhs_expr, Cast(Col(book_table, book_fk_field), author_id_field)
)

View File

@ -50,7 +50,7 @@ class StartsWithRelation(models.ForeignObject):
from_field = self.model._meta.get_field(self.from_fields[0]) from_field = self.model._meta.get_field(self.from_fields[0])
return StartsWith(to_field.get_col(alias), from_field.get_col(related_alias)) return StartsWith(to_field.get_col(alias), from_field.get_col(related_alias))
def get_joining_columns(self, reverse_join=False): def get_joining_fields(self, reverse_join=False):
return () return ()
def get_path_info(self, filtered_relation=None): def get_path_info(self, filtered_relation=None):

View File

@ -64,12 +64,14 @@ class CharLink(models.Model):
content_type = models.ForeignKey(ContentType, models.CASCADE) content_type = models.ForeignKey(ContentType, models.CASCADE)
object_id = models.CharField(max_length=100) object_id = models.CharField(max_length=100)
content_object = GenericForeignKey() content_object = GenericForeignKey()
value = models.CharField(max_length=250)
class TextLink(models.Model): class TextLink(models.Model):
content_type = models.ForeignKey(ContentType, models.CASCADE) content_type = models.ForeignKey(ContentType, models.CASCADE)
object_id = models.TextField() object_id = models.TextField()
content_object = GenericForeignKey() content_object = GenericForeignKey()
value = models.CharField(max_length=250)
class OddRelation1(models.Model): class OddRelation1(models.Model):

View File

@ -72,6 +72,20 @@ class GenericRelationTests(TestCase):
TextLink.objects.create(content_object=oddrel) TextLink.objects.create(content_object=oddrel)
oddrel.delete() oddrel.delete()
def test_charlink_filter(self):
oddrel = OddRelation1.objects.create(name="clink")
CharLink.objects.create(content_object=oddrel, value="value")
self.assertSequenceEqual(
OddRelation1.objects.filter(clinks__value="value"), [oddrel]
)
def test_textlink_filter(self):
oddrel = OddRelation2.objects.create(name="clink")
TextLink.objects.create(content_object=oddrel, value="value")
self.assertSequenceEqual(
OddRelation2.objects.filter(tlinks__value="value"), [oddrel]
)
def test_coerce_object_id_remote_field_cache_persistence(self): def test_coerce_object_id_remote_field_cache_persistence(self):
restaurant = Restaurant.objects.create() restaurant = Restaurant.objects.create()
CharLink.objects.create(content_object=restaurant) CharLink.objects.create(content_object=restaurant)