Fixed #29049 -- Added slicing notation to F expressions.

Co-authored-by: Priyansh Saxena <askpriyansh@gmail.com>
Co-authored-by: Niclas Olofsson <n@niclasolofsson.se>
Co-authored-by: David Smith <smithdc@gmail.com>
Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
Co-authored-by: Abhinav Yadav <abhinav.sny.2002@gmail.com>
This commit is contained in:
Nick Pope 2023-12-30 07:24:30 +00:00 committed by GitHub
parent 561e16d6a7
commit 94b6f101f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 256 additions and 4 deletions

View File

@ -234,6 +234,12 @@ class ArrayField(CheckFieldDefaultMixin, Field):
}
)
def slice_expression(self, expression, start, length):
    """
    Return a SliceTransform selecting ``length`` items of ``expression``
    beginning at ``start``.
    """
    if length is None:
        # No length given: leave the upper bound open so that the slice
        # runs to the end of the array.
        end = None
    else:
        # Convert (start, length) to an inclusive upper bound.
        end = start + length - 1
    return SliceTransform(start, end, expression)
class ArrayRHSMixin:
def __init__(self, lhs, rhs):
@ -351,9 +357,11 @@ class SliceTransform(Transform):
def as_sql(self, compiler, connection):
    """
    Return the SQL and params for slicing the compiled left-hand side.

    The left-hand side is always wrapped in parentheses before the
    subscript is appended. When ``self.end`` is None the upper bound is
    omitted so that the slice extends to the end of the array.
    """
    lhs, params = compiler.compile(self.lhs)
    # self.start is set to 1 if slice start is not provided.
    if self.end is None:
        return f"({lhs})[%s:]", (*params, self.start)
    return f"({lhs})[%s:%s]", (*params, self.start, self.end)
class SliceTransformFactory:

View File

@ -851,6 +851,9 @@ class F(Combinable):
def __repr__(self):
    """Return ``<class name>(<name>)`` for this expression."""
    class_name = self.__class__.__name__
    return f"{class_name}({self.name})"
def __getitem__(self, subscript):
    """Return a Sliced expression wrapping this one with ``subscript``."""
    sliced = Sliced(self, subscript)
    return sliced
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
@ -925,6 +928,63 @@ class OuterRef(F):
return self
class Sliced(F):
    """
    An object that contains a slice of an F expression.

    Object resolves the column on which the slicing is applied, and then
    applies the slicing if possible.

    The subscript is stored as a 1-based ``start`` and a ``length`` (``None``
    for an open-ended slice); both are passed to the resolved output field's
    ``slice_expression()``.
    """

    def __init__(self, obj, subscript):
        """
        Validate ``subscript`` and convert it to ``start`` and ``length``.

        Raise ValueError for a negative index, a step, or a stop smaller than
        the start. Raise TypeError if ``subscript`` is neither an int nor a
        slice.
        """
        super().__init__(obj.name)
        self.obj = obj
        if isinstance(subscript, int):
            if subscript < 0:
                raise ValueError("Negative indexing is not supported.")
            self.start = subscript + 1
            self.length = 1
        elif isinstance(subscript, slice):
            if (subscript.start is not None and subscript.start < 0) or (
                subscript.stop is not None and subscript.stop < 0
            ):
                raise ValueError("Negative indexing is not supported.")
            if subscript.step is not None:
                raise ValueError("Step argument is not supported.")
            # Compare against None explicitly: a truthiness test would skip
            # this check when stop is 0 (e.g. [5:0]) and a negative length
            # would be computed below.
            if (
                subscript.start is not None
                and subscript.stop is not None
                and subscript.stop < subscript.start
            ):
                raise ValueError("Slice stop must be greater than slice start.")
            self.start = 1 if subscript.start is None else subscript.start + 1
            if subscript.stop is None:
                self.length = None
            else:
                self.length = subscript.stop - (subscript.start or 0)
        else:
            raise TypeError("Argument to slice must be either int or slice instance.")

    def __repr__(self):
        # Convert the stored 1-based start and length back to a 0-based
        # slice for display.
        start = self.start - 1
        stop = None if self.length is None else start + self.length
        subscript = slice(start, stop)
        return f"{self.__class__.__qualname__}({self.obj!r}, {subscript!r})"

    def resolve_expression(
        self,
        query=None,
        allow_joins=True,
        reuse=None,
        summarize=False,
        for_save=False,
    ):
        """
        Resolve the referenced name and return the output field's slice of it.
        """
        resolved = query.resolve_ref(self.name, allow_joins, reuse, summarize)
        if isinstance(self.obj, (OuterRef, self.__class__)):
            # The wrapped object is an OuterRef or another Sliced: slice the
            # result of resolving that object instead of the plain reference.
            expr = self.obj.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        else:
            expr = resolved
        return resolved.output_field.slice_expression(expr, self.start, self.length)
@deconstructible(path="django.db.models.Func")
class Func(SQLiteNumericMixin, Expression):
"""An SQL function call."""

View File

@ -15,6 +15,7 @@ from django.core import checks, exceptions, validators
from django.db import connection, connections, router
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import DeferredAttribute, RegisterLookupMixin
from django.db.utils import NotSupportedError
from django.utils import timezone
from django.utils.choices import (
BlankChoiceIterator,
@ -1143,6 +1144,10 @@ class Field(RegisterLookupMixin):
"""Return the value of this field in the given model instance."""
return getattr(obj, self.attname)
def slice_expression(self, expression, start, length):
    """
    Return a slice of this field.

    This implementation always raises NotSupportedError.
    """
    message = "This field does not support slicing."
    raise NotSupportedError(message)
class BooleanField(Field):
empty_strings_allowed = False
@ -1303,6 +1308,11 @@ class CharField(Field):
kwargs["db_collation"] = self.db_collation
return name, path, args, kwargs
def slice_expression(self, expression, start, length):
    """Return a Substr() of ``expression`` built from ``start`` and ``length``."""
    from django.db.models.functions import Substr

    substring = Substr(expression, start, length)
    return substring
class CommaSeparatedIntegerField(CharField):
default_validators = [validators.validate_comma_separated_integer_list]
@ -2497,6 +2507,11 @@ class TextField(Field):
kwargs["db_collation"] = self.db_collation
return name, path, args, kwargs
def slice_expression(self, expression, start, length):
    """Return a Substr() of ``expression`` built from ``start`` and ``length``."""
    from django.db.models.functions import Substr

    substring = Substr(expression, start, length)
    return substring
class TimeField(DateTimeCheckMixin, Field):
empty_strings_allowed = False

View File

@ -183,6 +183,28 @@ the field value of each one, and saving each one back to the database::
* getting the database, rather than Python, to do work
* reducing the number of queries some operations require
.. _slicing-using-f:
Slicing ``F()`` expressions
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. versionadded:: 5.1
For string-based fields, text-based fields, and
:class:`~django.contrib.postgres.fields.ArrayField`, you can use Python's
array-slicing syntax. The indices are 0-based. Negative indices and the
``step`` argument to ``slice`` are not supported. For example:
.. code-block:: pycon
>>> # Replacing a name with a substring of itself.
>>> writer = Writers.objects.get(name="Priyansh")
>>> writer.name = F("name")[1:5]
>>> writer.save()
>>> writer.refresh_from_db()
>>> writer.name
'riya'
.. _avoiding-race-conditions-using-f:
Avoiding race conditions using ``F()``

View File

@ -184,6 +184,14 @@ Models
* :meth:`.QuerySet.order_by` now supports ordering by annotation transforms
such as ``JSONObject`` keys and ``ArrayAgg`` indices.
* :class:`F() <django.db.models.F>` and :class:`OuterRef()
<django.db.models.OuterRef>` expressions that output
:class:`~django.db.models.CharField`, :class:`~django.db.models.EmailField`,
:class:`~django.db.models.SlugField`, :class:`~django.db.models.URLField`,
:class:`~django.db.models.TextField`, or
:class:`~django.contrib.postgres.fields.ArrayField` can now be :ref:`sliced
<slicing-using-f>`.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -106,3 +106,7 @@ class UUIDPK(models.Model):
# Test model with a nullable UUID column and a nullable foreign key to UUIDPK.
class UUID(models.Model):
uuid = models.UUIDField(null=True)
uuid_fk = models.ForeignKey(UUIDPK, models.CASCADE, null=True)
# Test model with a single TextField called "name".
class Text(models.Model):
name = models.TextField()

View File

@ -84,6 +84,7 @@ from .models import (
RemoteEmployee,
Result,
SimulationRun,
Text,
Time,
)
@ -205,6 +206,100 @@ class BasicExpressionsTests(TestCase):
],
)
def _test_slicing_of_f_expressions(self, model):
    # Each case looks a row up by its current name, assigns the sliced
    # expression to it, saves, and checks the stored value. Several cases
    # operate on the value written by an earlier case, so order matters.
    cases = [
        (F("name")[:], "Example Inc.", "Example Inc."),
        (F("name")[:7], "Example Inc.", "Example"),
        (F("name")[:6][:5], "Example", "Examp"),  # Nested slicing.
        (F("name")[0], "Examp", "E"),
        (F("name")[5], "E", ""),
        (F("name")[7:], "Foobar Ltd.", "Ltd."),
        (F("name")[0:10], "Ltd.", "Ltd."),
        (F("name")[2:7], "Test GmbH", "st Gm"),
        (F("name")[1:][:3], "st Gm", "t G"),
        (F("name")[2:2], "t G", ""),
    ]
    for sliced, current_name, result in cases:
        with self.subTest(expression=sliced, name=current_name, expected=result):
            instance = model.objects.get(name=current_name)
            instance.name = sliced
            instance.save()
            instance.refresh_from_db()
            self.assertEqual(instance.name, result)
def test_slicing_of_f_expressions_charfield(self):
    # Run the shared slicing cases against the Company model.
    model = Company
    self._test_slicing_of_f_expressions(model)
def test_slicing_of_f_expressions_textfield(self):
    # Copy every Company name into a Text row so that the shared slicing
    # cases can look rows up by the same names.
    rows = [Text(name=company.name) for company in Company.objects.all()]
    Text.objects.bulk_create(rows)
    self._test_slicing_of_f_expressions(Text)
def test_slicing_of_f_expressions_with_annotate(self):
    # Annotation names state how many characters each slice yields.
    qs = Company.objects.annotate(
        first_three=F("name")[:3],
        after_three=F("name")[3:],
        random_three=F("name")[2:5],
        first_letter_slice=F("name")[:1],
        first_letter_index=F("name")[0],
    )
    tests = [
        ("first_three", ["Exa", "Foo", "Tes"]),
        ("after_three", ["mple Inc.", "bar Ltd.", "t GmbH"]),
        ("random_three", ["amp", "oba", "st "]),
        ("first_letter_slice", ["E", "F", "T"]),
        ("first_letter_index", ["E", "F", "T"]),
    ]
    for annotation, expected in tests:
        with self.subTest(annotation):
            self.assertCountEqual(qs.values_list(annotation, flat=True), expected)
def test_slicing_of_f_expression_with_annotated_expression(self):
    # Slice an annotation (a Case expression) rather than a model field.
    qs = Company.objects.annotate(
        new_name=Case(
            When(based_in_eu=True, then=Concat(Value("EU:"), F("name"))),
            default=F("name"),
        ),
        # Named for the three characters that [:3] returns.
        first_three=F("new_name")[:3],
    )
    self.assertCountEqual(
        qs.values_list("first_three", flat=True),
        ["Exa", "EU:", "Tes"],
    )
def test_slicing_of_f_expressions_with_negative_index(self):
    # A negative start, stop, or plain index must be rejected.
    expected_message = "Negative indexing is not supported."
    for subscript in (slice(0, -4), slice(-4, 0), slice(-4), -5):
        with self.subTest(i=subscript):
            with self.assertRaisesMessage(ValueError, expected_message):
                F("name")[subscript]
def test_slicing_of_f_expressions_with_slice_stop_less_than_slice_start(self):
    # A slice whose stop precedes its start must be rejected.
    expected_message = "Slice stop must be greater than slice start."
    subscript = slice(4, 2)
    with self.assertRaisesMessage(ValueError, expected_message):
        F("name")[subscript]
def test_slicing_of_f_expressions_with_invalid_type(self):
    # Only int and slice subscripts are accepted.
    expected_message = "Argument to slice must be either int or slice instance."
    subscript = "error"
    with self.assertRaisesMessage(TypeError, expected_message):
        F("name")[subscript]
def test_slicing_of_f_expressions_with_step(self):
    # A slice carrying a step must be rejected.
    expected_message = "Step argument is not supported."
    subscript = slice(None, None, 4)
    with self.assertRaisesMessage(ValueError, expected_message):
        F("name")[subscript]
def test_slicing_of_f_unsupported_field(self):
    # Slicing num_chairs is expected to raise NotSupportedError.
    expected_message = "This field does not support slicing."
    sliced = F("num_chairs")[:4]
    with self.assertRaisesMessage(NotSupportedError, expected_message):
        Company.objects.update(num_chairs=sliced)
def test_slicing_of_outerref(self):
    # Filter on the first character of an outer-query reference.
    first_letter = OuterRef("ceo__firstname")[0]
    inner = Company.objects.filter(name__startswith=first_letter)
    matching = Company.objects.filter(Exists(inner))
    outer = matching.values_list("name", flat=True)
    self.assertSequenceEqual(outer, ["Foobar Ltd."])
def test_arithmetic(self):
# We can perform arithmetic operations in expressions
# Make sure we have 2 spare chairs
@ -2359,6 +2454,12 @@ class ReprTests(SimpleTestCase):
repr(Func("published", function="TO_CHAR")),
"Func(F(published), function=TO_CHAR)",
)
self.assertEqual(
repr(F("published")[0:2]), "Sliced(F(published), slice(0, 2, None))"
)
self.assertEqual(
repr(OuterRef("name")[1:5]), "Sliced(OuterRef(name), slice(1, 5, None))"
)
self.assertEqual(repr(OrderBy(Value(1))), "OrderBy(Value(1), descending=False)")
self.assertEqual(repr(RawSQL("table.col", [])), "RawSQL(table.col, [])")
self.assertEqual(

View File

@ -10,7 +10,7 @@ from django.core import checks, exceptions, serializers, validators
from django.core.exceptions import FieldError
from django.core.management import call_command
from django.db import IntegrityError, connection, models
from django.db.models.expressions import Exists, OuterRef, RawSQL, Value
from django.db.models.expressions import Exists, F, OuterRef, RawSQL, Value
from django.db.models.functions import Cast, JSONObject, Upper
from django.test import TransactionTestCase, override_settings, skipUnlessDBFeature
from django.test.utils import isolate_apps
@ -594,6 +594,40 @@ class TestQuerying(PostgreSQLTestCase):
[None, [1], [2], [2, 3], [20, 30]],
)
def test_slicing_of_f_expressions(self):
    # Every case starts from a freshly created [1, 2, 3, 4] row, assigns the
    # sliced expression to the field, and checks the value read back.
    cases = [
        (F("field")[:2], [1, 2]),
        (F("field")[2:], [3, 4]),
        (F("field")[1:3], [2, 3]),
        (F("field")[3], [4]),
        (F("field")[:3][1:], [2, 3]),  # Nested slicing.
        (F("field")[:3][1], [2]),  # Slice then index.
    ]
    for sliced, result in cases:
        with self.subTest(expression=sliced, expected=result):
            row = IntegerArrayModel.objects.create(field=[1, 2, 3, 4])
            row.field = sliced
            row.save()
            row.refresh_from_db()
            self.assertEqual(row.field, result)
def test_slicing_of_f_expressions_with_annotate(self):
    # Annotate one row with three different slices of its array field.
    IntegerArrayModel.objects.create(field=[1, 2, 3])
    slices = {
        "first_two": F("field")[:2],
        "after_two": F("field")[2:],
        "random_two": F("field")[1:3],
    }
    annotated = IntegerArrayModel.objects.annotate(**slices).get()
    self.assertEqual(annotated.first_two, [1, 2])
    self.assertEqual(annotated.after_two, [3])
    self.assertEqual(annotated.random_two, [2, 3])
def test_slicing_of_f_expressions_with_len(self):
    # Keep the rows whose array is as long as its one-element slice.
    with_subarray = NullableIntegerArrayModel.objects.annotate(
        subarray=F("field")[:1]
    )
    queryset = with_subarray.filter(field__len=F("subarray__len"))
    self.assertSequenceEqual(queryset, self.objs[:2])
def test_usage_in_subquery(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(