Refs #28643 -- Added Repeat database function.

Thanks Tim Graham and Nick Pope for reviews.
This commit is contained in:
Mariusz Felisiak 2018-04-03 19:36:12 +02:00 committed by GitHub
parent 6141c752fe
commit 55cc26941a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 64 additions and 4 deletions

View File

@ -4,6 +4,7 @@ SQLite3 backend for the sqlite3 module in the standard library.
import datetime import datetime
import decimal import decimal
import math import math
import operator
import re import re
import warnings import warnings
from sqlite3 import dbapi2 as Database from sqlite3 import dbapi2 as Database
@ -170,6 +171,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
conn.create_function("django_format_dtdelta", 3, _sqlite_format_dtdelta) conn.create_function("django_format_dtdelta", 3, _sqlite_format_dtdelta)
conn.create_function("django_power", 2, _sqlite_power) conn.create_function("django_power", 2, _sqlite_power)
conn.create_function('LPAD', 3, _sqlite_lpad) conn.create_function('LPAD', 3, _sqlite_lpad)
conn.create_function('REPEAT', 2, operator.mul)
conn.create_function('RPAD', 3, _sqlite_rpad) conn.create_function('RPAD', 3, _sqlite_rpad)
conn.execute('PRAGMA foreign_keys = ON') conn.execute('PRAGMA foreign_keys = ON')
return conn return conn

View File

@ -6,8 +6,8 @@ from .datetime import (
TruncQuarter, TruncSecond, TruncTime, TruncWeek, TruncYear, TruncQuarter, TruncSecond, TruncTime, TruncWeek, TruncYear,
) )
from .text import ( from .text import (
Chr, Concat, ConcatPair, Left, Length, Lower, LPad, LTrim, Ord, Replace, Chr, Concat, ConcatPair, Left, Length, Lower, LPad, LTrim, Ord, Repeat,
Right, RPad, RTrim, StrIndex, Substr, Trim, Upper, Replace, Right, RPad, RTrim, StrIndex, Substr, Trim, Upper,
) )
from .window import ( from .window import (
CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, Ntile, CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, Ntile,
@ -25,8 +25,8 @@ __all__ = [
'TruncWeek', 'TruncYear', 'TruncWeek', 'TruncYear',
# text # text
'Chr', 'Concat', 'ConcatPair', 'Left', 'Length', 'Lower', 'LPad', 'LTrim', 'Chr', 'Concat', 'ConcatPair', 'Left', 'Length', 'Lower', 'LPad', 'LTrim',
'Ord', 'Replace', 'Right', 'RPad', 'RTrim', 'StrIndex', 'Substr', 'Trim', 'Ord', 'Repeat', 'Replace', 'Right', 'RPad', 'RTrim', 'StrIndex', 'Substr',
'Upper', 'Trim', 'Upper',
# window # window
'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead', 'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber', 'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',

View File

@ -152,6 +152,20 @@ class Ord(Transform):
return super().as_sql(compiler, connection, function='UNICODE', **extra_context) return super().as_sql(compiler, connection, function='UNICODE', **extra_context)
class Repeat(BytesToCharFieldConversionMixin, Func):
function = 'REPEAT'
def __init__(self, expression, number, **extra):
if not hasattr(number, 'resolve_expression') and number < 0:
raise ValueError("'number' must be greater or equal to 0.")
super().__init__(expression, number, **extra)
def as_oracle(self, compiler, connection, **extra_context):
expression, number = self.source_expressions
rpad = RPad(expression, Length(expression) * number, expression)
return rpad.as_sql(compiler, connection, **extra_context)
class Replace(Func): class Replace(Func):
function = 'REPLACE' function = 'REPLACE'

View File

@ -855,6 +855,25 @@ Usage example::
>>> print(author.name_code_point) >>> print(author.name_code_point)
77 77
``Repeat``
----------
.. class:: Repeat(expression, number, **extra)
.. versionadded:: 2.1
Returns the value of the given text field or expression repeated ``number``
times.
Usage example::
>>> from django.db.models.functions import Repeat
>>> Author.objects.create(name='John', alias='j')
>>> Author.objects.update(name=Repeat('name', 3))
1
>>> print(Author.objects.get(alias='j').name)
JohnJohnJohn
``Replace`` ``Replace``
----------- -----------

View File

@ -210,6 +210,7 @@ Models
:class:`~django.db.models.functions.LPad`, :class:`~django.db.models.functions.LPad`,
:class:`~django.db.models.functions.LTrim`, :class:`~django.db.models.functions.LTrim`,
:class:`~django.db.models.functions.Ord`, :class:`~django.db.models.functions.Ord`,
:class:`~django.db.models.functions.Repeat`,
:class:`~django.db.models.functions.Replace`, :class:`~django.db.models.functions.Replace`,
:class:`~django.db.models.functions.Right`, :class:`~django.db.models.functions.Right`,
:class:`~django.db.models.functions.RPad`, :class:`~django.db.models.functions.RPad`,

View File

@ -0,0 +1,24 @@
from django.db.models import CharField, Value
from django.db.models.functions import Length, Repeat
from django.test import TestCase
from .models import Author
class RepeatTests(TestCase):
def test_basic(self):
Author.objects.create(name='John', alias='xyz')
tests = (
(Repeat('name', 0), ''),
(Repeat('name', 2), 'JohnJohn'),
(Repeat('name', Length('alias'), output_field=CharField()), 'JohnJohnJohn'),
(Repeat(Value('x'), 3, output_field=CharField()), 'xxx'),
)
for function, repeated_text in tests:
with self.subTest(function=function):
authors = Author.objects.annotate(repeated_text=function)
self.assertQuerysetEqual(authors, [repeated_text], lambda a: a.repeated_text, ordered=False)
def test_negative_number(self):
with self.assertRaisesMessage(ValueError, "'number' must be greater or equal to 0."):
Repeat('name', -1)