Fixed #34955 -- Made Concat() use || operator on PostgreSQL.

This also avoids casting string based expressions in Concat() on
PostgreSQL.
This commit is contained in:
Simon Charette 2023-11-14 09:30:14 +01:00 committed by Mariusz Felisiak
parent bdf30b952c
commit 6364b6ee10
2 changed files with 37 additions and 15 deletions

View File

@ -73,7 +73,7 @@ class ConcatPair(Func):
function = "CONCAT"
def as_sqlite(self, compiler, connection, **extra_context):
def pipes_concat_sql(self, compiler, connection, **extra_context):
coalesced = self.coalesce()
return super(ConcatPair, coalesced).as_sql(
compiler,
@ -83,19 +83,19 @@ class ConcatPair(Func):
**extra_context,
)
as_sqlite = pipes_concat_sql
def as_postgresql(self, compiler, connection, **extra_context):
copy = self.copy()
copy.set_source_expressions(
c = self.copy()
c.set_source_expressions(
[
Cast(expression, TextField())
for expression in copy.get_source_expressions()
expression
if isinstance(expression.output_field, (CharField, TextField))
else Cast(expression, TextField())
for expression in c.get_source_expressions()
]
)
return super(ConcatPair, copy).as_sql(
compiler,
connection,
**extra_context,
)
return c.pipes_concat_sql(compiler, connection, **extra_context)
def as_mysql(self, compiler, connection, **extra_context):
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
@ -132,16 +132,20 @@ class Concat(Func):
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError("Concat must take at least two expressions")
paired = self._paired(expressions)
paired = self._paired(expressions, output_field=extra.get("output_field"))
super().__init__(paired, **extra)
def _paired(self, expressions):
def _paired(self, expressions, output_field):
# wrap pairs of expressions in successive concat functions
# exp = [a, b, c, d]
# -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))))
if len(expressions) == 2:
return ConcatPair(*expressions)
return ConcatPair(expressions[0], self._paired(expressions[1:]))
return ConcatPair(*expressions, output_field=output_field)
return ConcatPair(
expressions[0],
self._paired(expressions[1:], output_field=output_field),
output_field=output_field,
)
class Left(Func):

View File

@ -75,7 +75,10 @@ class ConcatTests(TestCase):
expected = article.title + " - " + article.text
self.assertEqual(expected.upper(), article.title_text)
@skipUnless(connection.vendor == "sqlite", "sqlite specific implementation detail.")
@skipUnless(
connection.vendor in ("sqlite", "postgresql"),
"SQLite and PostgreSQL specific implementation detail.",
)
def test_coalesce_idempotent(self):
pair = ConcatPair(V("a"), V("b"))
# Check nodes counts
@ -89,3 +92,18 @@ class ConcatTests(TestCase):
qs = Article.objects.annotate(description=Concat("title", V(": "), "summary"))
# Multiple compilations should not alter the generated query.
self.assertEqual(str(qs.query), str(qs.all().query))
def test_concat_non_str(self):
Author.objects.create(name="The Name", age=42)
with self.assertNumQueries(1) as ctx:
author = Author.objects.annotate(
name_text=Concat(
"name", V(":"), "alias", V(":"), "age", output_field=TextField()
),
).get()
self.assertEqual(author.name_text, "The Name::42")
# Only non-string columns are casted on PostgreSQL.
self.assertEqual(
ctx.captured_queries[0]["sql"].count("::text"),
1 if connection.vendor == "postgresql" else 0,
)