mirror of https://github.com/django/django.git
Fixed #34955 -- Made Concat() use || operator on PostgreSQL.
This also avoids casting string based expressions in Concat() on PostgreSQL.
This commit is contained in:
parent
bdf30b952c
commit
6364b6ee10
|
@ -73,7 +73,7 @@ class ConcatPair(Func):
|
||||||
|
|
||||||
function = "CONCAT"
|
function = "CONCAT"
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection, **extra_context):
|
def pipes_concat_sql(self, compiler, connection, **extra_context):
|
||||||
coalesced = self.coalesce()
|
coalesced = self.coalesce()
|
||||||
return super(ConcatPair, coalesced).as_sql(
|
return super(ConcatPair, coalesced).as_sql(
|
||||||
compiler,
|
compiler,
|
||||||
|
@ -83,19 +83,19 @@ class ConcatPair(Func):
|
||||||
**extra_context,
|
**extra_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
as_sqlite = pipes_concat_sql
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection, **extra_context):
|
def as_postgresql(self, compiler, connection, **extra_context):
|
||||||
copy = self.copy()
|
c = self.copy()
|
||||||
copy.set_source_expressions(
|
c.set_source_expressions(
|
||||||
[
|
[
|
||||||
Cast(expression, TextField())
|
expression
|
||||||
for expression in copy.get_source_expressions()
|
if isinstance(expression.output_field, (CharField, TextField))
|
||||||
|
else Cast(expression, TextField())
|
||||||
|
for expression in c.get_source_expressions()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
return super(ConcatPair, copy).as_sql(
|
return c.pipes_concat_sql(compiler, connection, **extra_context)
|
||||||
compiler,
|
|
||||||
connection,
|
|
||||||
**extra_context,
|
|
||||||
)
|
|
||||||
|
|
||||||
def as_mysql(self, compiler, connection, **extra_context):
|
def as_mysql(self, compiler, connection, **extra_context):
|
||||||
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
|
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
|
||||||
|
@ -132,16 +132,20 @@ class Concat(Func):
|
||||||
def __init__(self, *expressions, **extra):
|
def __init__(self, *expressions, **extra):
|
||||||
if len(expressions) < 2:
|
if len(expressions) < 2:
|
||||||
raise ValueError("Concat must take at least two expressions")
|
raise ValueError("Concat must take at least two expressions")
|
||||||
paired = self._paired(expressions)
|
paired = self._paired(expressions, output_field=extra.get("output_field"))
|
||||||
super().__init__(paired, **extra)
|
super().__init__(paired, **extra)
|
||||||
|
|
||||||
def _paired(self, expressions):
|
def _paired(self, expressions, output_field):
|
||||||
# wrap pairs of expressions in successive concat functions
|
# wrap pairs of expressions in successive concat functions
|
||||||
# exp = [a, b, c, d]
|
# exp = [a, b, c, d]
|
||||||
# -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))))
|
# -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))))
|
||||||
if len(expressions) == 2:
|
if len(expressions) == 2:
|
||||||
return ConcatPair(*expressions)
|
return ConcatPair(*expressions, output_field=output_field)
|
||||||
return ConcatPair(expressions[0], self._paired(expressions[1:]))
|
return ConcatPair(
|
||||||
|
expressions[0],
|
||||||
|
self._paired(expressions[1:], output_field=output_field),
|
||||||
|
output_field=output_field,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Left(Func):
|
class Left(Func):
|
||||||
|
|
|
@ -75,7 +75,10 @@ class ConcatTests(TestCase):
|
||||||
expected = article.title + " - " + article.text
|
expected = article.title + " - " + article.text
|
||||||
self.assertEqual(expected.upper(), article.title_text)
|
self.assertEqual(expected.upper(), article.title_text)
|
||||||
|
|
||||||
@skipUnless(connection.vendor == "sqlite", "sqlite specific implementation detail.")
|
@skipUnless(
|
||||||
|
connection.vendor in ("sqlite", "postgresql"),
|
||||||
|
"SQLite and PostgreSQL specific implementation detail.",
|
||||||
|
)
|
||||||
def test_coalesce_idempotent(self):
|
def test_coalesce_idempotent(self):
|
||||||
pair = ConcatPair(V("a"), V("b"))
|
pair = ConcatPair(V("a"), V("b"))
|
||||||
# Check nodes counts
|
# Check nodes counts
|
||||||
|
@ -89,3 +92,18 @@ class ConcatTests(TestCase):
|
||||||
qs = Article.objects.annotate(description=Concat("title", V(": "), "summary"))
|
qs = Article.objects.annotate(description=Concat("title", V(": "), "summary"))
|
||||||
# Multiple compilations should not alter the generated query.
|
# Multiple compilations should not alter the generated query.
|
||||||
self.assertEqual(str(qs.query), str(qs.all().query))
|
self.assertEqual(str(qs.query), str(qs.all().query))
|
||||||
|
|
||||||
|
def test_concat_non_str(self):
|
||||||
|
Author.objects.create(name="The Name", age=42)
|
||||||
|
with self.assertNumQueries(1) as ctx:
|
||||||
|
author = Author.objects.annotate(
|
||||||
|
name_text=Concat(
|
||||||
|
"name", V(":"), "alias", V(":"), "age", output_field=TextField()
|
||||||
|
),
|
||||||
|
).get()
|
||||||
|
self.assertEqual(author.name_text, "The Name::42")
|
||||||
|
# Only non-string columns are casted on PostgreSQL.
|
||||||
|
self.assertEqual(
|
||||||
|
ctx.captured_queries[0]["sql"].count("::text"),
|
||||||
|
1 if connection.vendor == "postgresql" else 0,
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue