Fixed #27718 -- Added QuerySet.union(), intersection(), difference().

Thanks Mariusz Felisiak for review and Oracle assistance.
Thanks Tim Graham for review and writing docs.
This commit is contained in:
Florian Apolloner 2017-01-14 14:32:07 +01:00 committed by Tim Graham
parent 611ef422b1
commit 84c1826ded
12 changed files with 323 additions and 51 deletions

View File

@ -221,6 +221,12 @@ class BaseDatabaseFeatures(object):
# Place FOR UPDATE right after FROM clause. Used on MSSQL.
for_update_after_from = False
# Combinatorial flags
supports_select_union = True
supports_select_intersection = True
supports_select_difference = True
supports_slicing_ordering_in_compound = False
def __init__(self, connection):
self.connection = connection

View File

@ -29,6 +29,11 @@ class BaseDatabaseOperations(object):
'PositiveSmallIntegerField': (0, 32767),
'PositiveIntegerField': (0, 2147483647),
}
set_operators = {
'union': 'UNION',
'intersection': 'INTERSECT',
'difference': 'EXCEPT',
}
def __init__(self, connection):
self.connection = connection

View File

@ -29,6 +29,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_column_check_constraints = False
can_clone_databases = True
supports_temporal_subtraction = True
supports_select_intersection = False
supports_select_difference = False
supports_slicing_ordering_in_compound = True
@cached_property
def _mysql_storage_engine(self):

View File

@ -41,6 +41,10 @@ BEGIN
END;
/"""
def __init__(self, *args, **kwargs):
super(DatabaseOperations, self).__init__(*args, **kwargs)
self.set_operators['difference'] = 'MINUS'
def autoinc_sql(self, table, column):
# To simulate auto-incrementing primary keys in Oracle, we have to
# create a sequence and a trigger.

View File

@ -31,6 +31,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
greatest_least_ignores_nulls = True
can_clone_databases = True
supports_temporal_subtraction = True
supports_slicing_ordering_in_compound = True
@cached_property
def has_select_for_update_skip_locked(self):

View File

@ -816,6 +816,33 @@ class QuerySet(object):
else:
return self._filter_or_exclude(None, **filter_obj)
def _combinator_query(self, combinator, *other_qs, **kwargs):
# Clone the query to inherit the select list and everything
clone = self._clone()
# Clear limits and ordering so they can be reapplied
clone.query.clear_ordering(True)
clone.query.clear_limits()
clone.query.combined_queries = (self.query,) + tuple(qs.query for qs in other_qs)
clone.query.combinator = combinator
clone.query.combinator_all = kwargs.pop('all', False)
return clone
def union(self, *other_qs, **kwargs):
if kwargs:
unexpected_kwarg = next((k for k in kwargs.keys() if k != 'all'), None)
if unexpected_kwarg:
raise TypeError(
"union() received an unexpected keyword argument '%s'" %
(unexpected_kwarg,)
)
return self._combinator_query('union', *other_qs, **kwargs)
def intersection(self, *other_qs):
return self._combinator_query('intersection', *other_qs)
def difference(self, *other_qs):
return self._combinator_query('difference', *other_qs)
def select_for_update(self, nowait=False, skip_locked=False):
"""
Returns a new QuerySet instance that will select objects with a

View File

@ -309,6 +309,21 @@ class SQLCompiler(object):
seen = set()
for expr, is_ref in order_by:
if self.query.combinator:
src = expr.get_source_expressions()[0]
# Relabel order by columns to raw numbers if this is a combined
# query; necessary since the columns can't be referenced by the
# fully qualified name and the simple column names may collide.
for idx, (sel_expr, _, col_alias) in enumerate(self.select):
if is_ref and col_alias == src.refs:
src = src.source
elif col_alias:
continue
if src == sel_expr:
expr.set_source_expressions([RawSQL('%d' % (idx + 1), ())])
break
else:
raise DatabaseError('ORDER BY term does not match any column in the result set.')
resolved = expr.resolve_expression(
self.query, allow_joins=True, reuse=None)
sql, params = self.compile(resolved)
@ -360,6 +375,30 @@ class SQLCompiler(object):
return node.output_field.select_format(self, sql, params)
return sql, params
def get_combinator_sql(self, combinator, all):
features = self.connection.features
compilers = [
query.get_compiler(self.using, self.connection)
for query in self.query.combined_queries
]
if not features.supports_slicing_ordering_in_compound:
for query, compiler in zip(self.query.combined_queries, compilers):
if query.low_mark or query.high_mark:
raise DatabaseError('LIMIT/OFFSET not allowed in subqueries of compound statements.')
if compiler.get_order_by():
raise DatabaseError('ORDER BY not allowed in subqueries of compound statements.')
parts = (compiler.as_sql() for compiler in compilers)
combinator_sql = self.connection.ops.set_operators[combinator]
if all and combinator == 'union':
combinator_sql += ' ALL'
braces = '({})' if features.supports_slicing_ordering_in_compound else '{}'
sql_parts, args_parts = zip(*((braces.format(sql), args) for sql, args in parts))
result = [' {} '.format(combinator_sql).join(sql_parts)]
params = []
for part in args_parts:
params.extend(part)
return result, params
def as_sql(self, with_limits=True, with_col_aliases=False):
"""
Creates the SQL for this query. Returns the SQL string and list of
@ -377,69 +416,76 @@ class SQLCompiler(object):
# docstring of get_from_clause() for details.
from_, f_params = self.get_from_clause()
for_update_part = None
where, w_params = self.compile(self.where) if self.where is not None else ("", [])
having, h_params = self.compile(self.having) if self.having is not None else ("", [])
params = []
result = ['SELECT']
if self.query.distinct:
result.append(self.connection.ops.distinct_sql(distinct_fields))
combinator = self.query.combinator
features = self.connection.features
if combinator:
if not getattr(features, 'supports_select_{}'.format(combinator)):
raise DatabaseError('{} not supported on this database backend.'.format(combinator))
result, params = self.get_combinator_sql(combinator, self.query.combinator_all)
else:
result = ['SELECT']
params = []
out_cols = []
col_idx = 1
for _, (s_sql, s_params), alias in self.select + extra_select:
if alias:
s_sql = '%s AS %s' % (s_sql, self.connection.ops.quote_name(alias))
elif with_col_aliases:
s_sql = '%s AS %s' % (s_sql, 'Col%d' % col_idx)
col_idx += 1
params.extend(s_params)
out_cols.append(s_sql)
if self.query.distinct:
result.append(self.connection.ops.distinct_sql(distinct_fields))
result.append(', '.join(out_cols))
out_cols = []
col_idx = 1
for _, (s_sql, s_params), alias in self.select + extra_select:
if alias:
s_sql = '%s AS %s' % (s_sql, self.connection.ops.quote_name(alias))
elif with_col_aliases:
s_sql = '%s AS %s' % (s_sql, 'Col%d' % col_idx)
col_idx += 1
params.extend(s_params)
out_cols.append(s_sql)
result.append('FROM')
result.extend(from_)
params.extend(f_params)
result.append(', '.join(out_cols))
for_update_part = None
if self.query.select_for_update and self.connection.features.has_select_for_update:
if self.connection.get_autocommit():
raise TransactionManagementError("select_for_update cannot be used outside of a transaction.")
result.append('FROM')
result.extend(from_)
params.extend(f_params)
nowait = self.query.select_for_update_nowait
skip_locked = self.query.select_for_update_skip_locked
# If it's a NOWAIT/SKIP LOCKED query but the backend doesn't
# support it, raise a DatabaseError to prevent a possible
# deadlock.
if nowait and not self.connection.features.has_select_for_update_nowait:
raise DatabaseError('NOWAIT is not supported on this database backend.')
elif skip_locked and not self.connection.features.has_select_for_update_skip_locked:
raise DatabaseError('SKIP LOCKED is not supported on this database backend.')
for_update_part = self.connection.ops.for_update_sql(nowait=nowait, skip_locked=skip_locked)
if self.query.select_for_update and self.connection.features.has_select_for_update:
if self.connection.get_autocommit():
raise TransactionManagementError('select_for_update cannot be used outside of a transaction.')
if for_update_part and self.connection.features.for_update_after_from:
result.append(for_update_part)
nowait = self.query.select_for_update_nowait
skip_locked = self.query.select_for_update_skip_locked
# If it's a NOWAIT/SKIP LOCKED query but the backend
# doesn't support it, raise a DatabaseError to prevent a
# possible deadlock.
if nowait and not self.connection.features.has_select_for_update_nowait:
raise DatabaseError('NOWAIT is not supported on this database backend.')
elif skip_locked and not self.connection.features.has_select_for_update_skip_locked:
raise DatabaseError('SKIP LOCKED is not supported on this database backend.')
for_update_part = self.connection.ops.for_update_sql(nowait=nowait, skip_locked=skip_locked)
if where:
result.append('WHERE %s' % where)
params.extend(w_params)
if for_update_part and self.connection.features.for_update_after_from:
result.append(for_update_part)
grouping = []
for g_sql, g_params in group_by:
grouping.append(g_sql)
params.extend(g_params)
if grouping:
if distinct_fields:
raise NotImplementedError(
"annotate() + distinct(fields) is not implemented.")
if not order_by:
order_by = self.connection.ops.force_no_ordering()
result.append('GROUP BY %s' % ', '.join(grouping))
if where:
result.append('WHERE %s' % where)
params.extend(w_params)
if having:
result.append('HAVING %s' % having)
params.extend(h_params)
grouping = []
for g_sql, g_params in group_by:
grouping.append(g_sql)
params.extend(g_params)
if grouping:
if distinct_fields:
raise NotImplementedError('annotate() + distinct(fields) is not implemented.')
if not order_by:
order_by = self.connection.ops.force_no_ordering()
result.append('GROUP BY %s' % ', '.join(grouping))
if having:
result.append('HAVING %s' % having)
params.extend(h_params)
if order_by:
ordering = []

View File

@ -186,6 +186,11 @@ class Query(object):
self.annotation_select_mask = None
self._annotation_select_cache = None
# Set combination attributes
self.combinator = None
self.combinator_all = False
self.combined_queries = ()
# These are for extensions. The contents are more or less appended
# verbatim to the appropriate clause.
# The _extra attribute is an OrderedDict, lazily created similarly to
@ -303,6 +308,9 @@ class Query(object):
# used.
obj._annotation_select_cache = None
obj.max_depth = self.max_depth
obj.combinator = self.combinator
obj.combinator_all = self.combinator_all
obj.combined_queries = self.combined_queries
obj._extra = self._extra.copy() if self._extra is not None else None
if self.extra_select_mask is None:
obj.extra_select_mask = None

View File

@ -801,6 +801,61 @@ typically caches its results. If the data in the database might have changed
since a ``QuerySet`` was evaluated, you can get updated results for the same
query by calling ``all()`` on a previously evaluated ``QuerySet``.
``union()``
~~~~~~~~~~~
.. method:: union(*other_qs, all=False)
.. versionadded:: 1.11
Uses SQL's ``UNION`` operator to combine the results of two or more
``QuerySet``\s. For example:
>>> qs1.union(qs2, qs3)
The ``UNION`` operator selects only distinct values by default. To allow
duplicate values, use the ``all=True`` argument.
``union()``, ``intersection()``, and ``difference()`` return model instances
of the type of the first ``QuerySet`` even if the arguments are ``QuerySet``\s
of other models. Passing different models works as long as the ``SELECT`` list
is the same in all ``QuerySet``\s (at least the types, the names don't matter
as long as the types in the same order).
In addition, only ``LIMIT``, ``OFFSET``, and ``ORDER BY`` (i.e. slicing and
:meth:`order_by`) are allowed on the resulting ``QuerySet``. Further, databases
place restrictions on what operations are allowed in the combined queries. For
example, most databases don't allow ``LIMIT`` or ``OFFSET`` in the combined
queries.
``intersection()``
~~~~~~~~~~~~~~~~~~
.. method:: intersection(*other_qs)
.. versionadded:: 1.11
Uses SQL's ``INTERSECT`` operator to return the shared elements of two or more
``QuerySet``\s. For example:
>>> qs1.itersect(qs2, qs3)
See :meth:`union` for some restrictions.
``difference()``
~~~~~~~~~~~~~~~~
.. method:: difference(*other_qs)
.. versionadded:: 1.11
Uses SQL's ``EXCEPT`` operator to keep only elements present in the
``QuerySet`` but not in some other ``QuerySet``\s. For example::
>>> qs1.difference(qs2, qs3)
See :meth:`union` for some restrictions.
``select_related()``
~~~~~~~~~~~~~~~~~~~~

View File

@ -386,6 +386,9 @@ Models
* The new ``F`` expression ``bitleftshift()`` and ``bitrightshift()`` methods
allow :ref:`bitwise shift operations <using-f-expressions-in-filters>`.
* Added :meth:`.QuerySet.union`, :meth:`~.QuerySet.intersection`, and
:meth:`~.QuerySet.difference`.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -589,6 +589,9 @@ class ManagerTest(SimpleTestCase):
'_insert',
'_update',
'raw',
'union',
'intersection',
'difference',
]
def test_manager_methods(self):

View File

@ -0,0 +1,111 @@
from __future__ import unicode_literals
from django.db.models import F, IntegerField, Value
from django.db.utils import DatabaseError
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.utils.six.moves import range
from .models import Number, ReservedName
@skipUnlessDBFeature('supports_select_union')
class QuerySetSetOperationTests(TestCase):
@classmethod
def setUpTestData(cls):
Number.objects.bulk_create(Number(num=i) for i in range(10))
def number_transform(self, value):
return value.num
def assertNumbersEqual(self, queryset, expected_numbers, ordered=True):
self.assertQuerysetEqual(queryset, expected_numbers, self.number_transform, ordered)
def test_simple_union(self):
qs1 = Number.objects.filter(num__lte=1)
qs2 = Number.objects.filter(num__gte=8)
qs3 = Number.objects.filter(num=5)
self.assertNumbersEqual(qs1.union(qs2, qs3), [0, 1, 5, 8, 9], ordered=False)
@skipUnlessDBFeature('supports_select_intersection')
def test_simple_intersection(self):
qs1 = Number.objects.filter(num__lte=5)
qs2 = Number.objects.filter(num__gte=5)
qs3 = Number.objects.filter(num__gte=4, num__lte=6)
self.assertNumbersEqual(qs1.intersection(qs2, qs3), [5], ordered=False)
@skipUnlessDBFeature('supports_select_difference')
def test_simple_difference(self):
qs1 = Number.objects.filter(num__lte=5)
qs2 = Number.objects.filter(num__lte=4)
self.assertNumbersEqual(qs1.difference(qs2), [5], ordered=False)
def test_union_distinct(self):
qs1 = Number.objects.all()
qs2 = Number.objects.all()
self.assertEqual(len(list(qs1.union(qs2, all=True))), 20)
self.assertEqual(len(list(qs1.union(qs2))), 10)
def test_union_bad_kwarg(self):
qs1 = Number.objects.all()
msg = "union() received an unexpected keyword argument 'bad'"
with self.assertRaisesMessage(TypeError, msg):
self.assertEqual(len(list(qs1.union(qs1, bad=True))), 20)
def test_limits(self):
qs1 = Number.objects.all()
qs2 = Number.objects.all()
self.assertEqual(len(list(qs1.union(qs2)[:2])), 2)
def test_ordering(self):
qs1 = Number.objects.filter(num__lte=1)
qs2 = Number.objects.filter(num__gte=2, num__lte=3)
self.assertNumbersEqual(qs1.union(qs2).order_by('-num'), [3, 2, 1, 0])
@skipUnlessDBFeature('supports_slicing_ordering_in_compound')
def test_ordering_subqueries(self):
qs1 = Number.objects.order_by('num')[:2]
qs2 = Number.objects.order_by('-num')[:2]
self.assertNumbersEqual(qs1.union(qs2).order_by('-num')[:4], [9, 8, 1, 0])
@skipIfDBFeature('supports_slicing_ordering_in_compound')
def test_unsupported_ordering_slicing_raises_db_error(self):
qs1 = Number.objects.all()
qs2 = Number.objects.all()
msg = 'LIMIT/OFFSET not allowed in subqueries of compound statements'
with self.assertRaisesMessage(DatabaseError, msg):
list(qs1.union(qs2[:10]))
msg = 'ORDER BY not allowed in subqueries of compound statements'
with self.assertRaisesMessage(DatabaseError, msg):
list(qs1.order_by('id').union(qs2))
@skipIfDBFeature('supports_select_intersection')
def test_unsupported_intersection_raises_db_error(self):
qs1 = Number.objects.all()
qs2 = Number.objects.all()
msg = 'intersection not supported on this database backend'
with self.assertRaisesMessage(DatabaseError, msg):
list(qs1.intersection(qs2))
def test_combining_multiple_models(self):
ReservedName.objects.create(name='99 little bugs', order=99)
qs1 = Number.objects.filter(num=1).values_list('num', flat=True)
qs2 = ReservedName.objects.values_list('order')
self.assertEqual(list(qs1.union(qs2).order_by('num')), [1, 99])
def test_order_raises_on_non_selected_column(self):
qs1 = Number.objects.filter().annotate(
annotation=Value(1, IntegerField()),
).values('annotation', num2=F('num'))
qs2 = Number.objects.filter().values('id', 'num')
# Should not raise
list(qs1.union(qs2).order_by('annotation'))
list(qs1.union(qs2).order_by('num2'))
msg = 'ORDER BY term does not match any column in the result set'
# 'id' is not part of the select
with self.assertRaisesMessage(DatabaseError, msg):
list(qs1.union(qs2).order_by('id'))
# 'num' got realiased to num2
with self.assertRaisesMessage(DatabaseError, msg):
list(qs1.union(qs2).order_by('num'))
# switched order, now 'exists' again:
list(qs2.union(qs1).order_by('num'))