Fixed #34955 -- Made Concat() use || operator on PostgreSQL.

This also avoids casting string based expressions in Concat() on
PostgreSQL.
This commit is contained in:
Simon Charette 2023-11-14 09:30:14 +01:00 committed by Mariusz Felisiak
parent bdf30b952c
commit 6364b6ee10
2 changed files with 37 additions and 15 deletions

View File

@ -73,7 +73,7 @@ class ConcatPair(Func):
function = "CONCAT" function = "CONCAT"
def as_sqlite(self, compiler, connection, **extra_context): def pipes_concat_sql(self, compiler, connection, **extra_context):
coalesced = self.coalesce() coalesced = self.coalesce()
return super(ConcatPair, coalesced).as_sql( return super(ConcatPair, coalesced).as_sql(
compiler, compiler,
@ -83,19 +83,19 @@ class ConcatPair(Func):
**extra_context, **extra_context,
) )
as_sqlite = pipes_concat_sql
def as_postgresql(self, compiler, connection, **extra_context): def as_postgresql(self, compiler, connection, **extra_context):
copy = self.copy() c = self.copy()
copy.set_source_expressions( c.set_source_expressions(
[ [
Cast(expression, TextField()) expression
for expression in copy.get_source_expressions() if isinstance(expression.output_field, (CharField, TextField))
else Cast(expression, TextField())
for expression in c.get_source_expressions()
] ]
) )
return super(ConcatPair, copy).as_sql( return c.pipes_concat_sql(compiler, connection, **extra_context)
compiler,
connection,
**extra_context,
)
def as_mysql(self, compiler, connection, **extra_context): def as_mysql(self, compiler, connection, **extra_context):
# Use CONCAT_WS with an empty separator so that NULLs are ignored. # Use CONCAT_WS with an empty separator so that NULLs are ignored.
@ -132,16 +132,20 @@ class Concat(Func):
def __init__(self, *expressions, **extra): def __init__(self, *expressions, **extra):
if len(expressions) < 2: if len(expressions) < 2:
raise ValueError("Concat must take at least two expressions") raise ValueError("Concat must take at least two expressions")
paired = self._paired(expressions) paired = self._paired(expressions, output_field=extra.get("output_field"))
super().__init__(paired, **extra) super().__init__(paired, **extra)
def _paired(self, expressions): def _paired(self, expressions, output_field):
# wrap pairs of expressions in successive concat functions # wrap pairs of expressions in successive concat functions
# exp = [a, b, c, d] # exp = [a, b, c, d]
# -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d)))) # -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))))
if len(expressions) == 2: if len(expressions) == 2:
return ConcatPair(*expressions) return ConcatPair(*expressions, output_field=output_field)
return ConcatPair(expressions[0], self._paired(expressions[1:])) return ConcatPair(
expressions[0],
self._paired(expressions[1:], output_field=output_field),
output_field=output_field,
)
class Left(Func): class Left(Func):

View File

@ -75,7 +75,10 @@ class ConcatTests(TestCase):
expected = article.title + " - " + article.text expected = article.title + " - " + article.text
self.assertEqual(expected.upper(), article.title_text) self.assertEqual(expected.upper(), article.title_text)
@skipUnless(connection.vendor == "sqlite", "sqlite specific implementation detail.") @skipUnless(
connection.vendor in ("sqlite", "postgresql"),
"SQLite and PostgreSQL specific implementation detail.",
)
def test_coalesce_idempotent(self): def test_coalesce_idempotent(self):
pair = ConcatPair(V("a"), V("b")) pair = ConcatPair(V("a"), V("b"))
# Check nodes counts # Check nodes counts
@ -89,3 +92,18 @@ class ConcatTests(TestCase):
qs = Article.objects.annotate(description=Concat("title", V(": "), "summary")) qs = Article.objects.annotate(description=Concat("title", V(": "), "summary"))
# Multiple compilations should not alter the generated query. # Multiple compilations should not alter the generated query.
self.assertEqual(str(qs.query), str(qs.all().query)) self.assertEqual(str(qs.query), str(qs.all().query))
def test_concat_non_str(self):
Author.objects.create(name="The Name", age=42)
with self.assertNumQueries(1) as ctx:
author = Author.objects.annotate(
name_text=Concat(
"name", V(":"), "alias", V(":"), "age", output_field=TextField()
),
).get()
self.assertEqual(author.name_text, "The Name::42")
# Only non-string columns are casted on PostgreSQL.
self.assertEqual(
ctx.captured_queries[0]["sql"].count("::text"),
1 if connection.vendor == "postgresql" else 0,
)