Refs #28643 -- Added Repeat database function.

Thanks Tim Graham and Nick Pope for reviews.
This commit is contained in:
Mariusz Felisiak 2018-04-03 19:36:12 +02:00 committed by GitHub
parent 6141c752fe
commit 55cc26941a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 64 additions and 4 deletions

View File

@ -4,6 +4,7 @@ SQLite3 backend for the sqlite3 module in the standard library.
import datetime
import decimal
import math
import operator
import re
import warnings
from sqlite3 import dbapi2 as Database
@ -170,6 +171,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
conn.create_function("django_format_dtdelta", 3, _sqlite_format_dtdelta)
conn.create_function("django_power", 2, _sqlite_power)
conn.create_function('LPAD', 3, _sqlite_lpad)
conn.create_function('REPEAT', 2, operator.mul)
conn.create_function('RPAD', 3, _sqlite_rpad)
conn.execute('PRAGMA foreign_keys = ON')
return conn

View File

@ -6,8 +6,8 @@ from .datetime import (
TruncQuarter, TruncSecond, TruncTime, TruncWeek, TruncYear,
)
from .text import (
Chr, Concat, ConcatPair, Left, Length, Lower, LPad, LTrim, Ord, Replace,
Right, RPad, RTrim, StrIndex, Substr, Trim, Upper,
Chr, Concat, ConcatPair, Left, Length, Lower, LPad, LTrim, Ord, Repeat,
Replace, Right, RPad, RTrim, StrIndex, Substr, Trim, Upper,
)
from .window import (
CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, Ntile,
@ -25,8 +25,8 @@ __all__ = [
'TruncWeek', 'TruncYear',
# text
'Chr', 'Concat', 'ConcatPair', 'Left', 'Length', 'Lower', 'LPad', 'LTrim',
'Ord', 'Replace', 'Right', 'RPad', 'RTrim', 'StrIndex', 'Substr', 'Trim',
'Upper',
'Ord', 'Repeat', 'Replace', 'Right', 'RPad', 'RTrim', 'StrIndex', 'Substr',
'Trim', 'Upper',
# window
'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',

View File

@ -152,6 +152,20 @@ class Ord(Transform):
return super().as_sql(compiler, connection, function='UNICODE', **extra_context)
class Repeat(BytesToCharFieldConversionMixin, Func):
function = 'REPEAT'
def __init__(self, expression, number, **extra):
if not hasattr(number, 'resolve_expression') and number < 0:
raise ValueError("'number' must be greater or equal to 0.")
super().__init__(expression, number, **extra)
def as_oracle(self, compiler, connection, **extra_context):
expression, number = self.source_expressions
rpad = RPad(expression, Length(expression) * number, expression)
return rpad.as_sql(compiler, connection, **extra_context)
class Replace(Func):
function = 'REPLACE'

View File

@ -855,6 +855,25 @@ Usage example::
>>> print(author.name_code_point)
77
``Repeat``
----------
.. class:: Repeat(expression, number, **extra)
.. versionadded:: 2.1
Returns the value of the given text field or expression repeated ``number``
times.
Usage example::
>>> from django.db.models.functions import Repeat
>>> Author.objects.create(name='John', alias='j')
>>> Author.objects.update(name=Repeat('name', 3))
1
>>> print(Author.objects.get(alias='j').name)
JohnJohnJohn
``Replace``
-----------

View File

@ -210,6 +210,7 @@ Models
:class:`~django.db.models.functions.LPad`,
:class:`~django.db.models.functions.LTrim`,
:class:`~django.db.models.functions.Ord`,
:class:`~django.db.models.functions.Repeat`,
:class:`~django.db.models.functions.Replace`,
:class:`~django.db.models.functions.Right`,
:class:`~django.db.models.functions.RPad`,

View File

@ -0,0 +1,24 @@
from django.db.models import CharField, Value
from django.db.models.functions import Length, Repeat
from django.test import TestCase
from .models import Author
class RepeatTests(TestCase):
def test_basic(self):
Author.objects.create(name='John', alias='xyz')
tests = (
(Repeat('name', 0), ''),
(Repeat('name', 2), 'JohnJohn'),
(Repeat('name', Length('alias'), output_field=CharField()), 'JohnJohnJohn'),
(Repeat(Value('x'), 3, output_field=CharField()), 'xxx'),
)
for function, repeated_text in tests:
with self.subTest(function=function):
authors = Author.objects.annotate(repeated_text=function)
self.assertQuerysetEqual(authors, [repeated_text], lambda a: a.repeated_text, ordered=False)
def test_negative_number(self):
with self.assertRaisesMessage(ValueError, "'number' must be greater or equal to 0."):
Repeat('name', -1)