Fixed #19259 -- Added group by selected primary keys support.

This commit is contained in:
Simon Charette 2015-03-26 16:54:43 -04:00
parent 8119876d4a
commit dc27f3ee0c
6 changed files with 34 additions and 12 deletions

View File

@ -6,6 +6,7 @@ from django.utils.functional import cached_property
class BaseDatabaseFeatures(object): class BaseDatabaseFeatures(object):
gis_enabled = False gis_enabled = False
allows_group_by_pk = False allows_group_by_pk = False
allows_group_by_selected_pks = False
# True if django.db.backends.utils.typecast_timestamp is used on values # True if django.db.backends.utils.typecast_timestamp is used on values
# returned from dates() calls. # returned from dates() calls.
needs_datetime_string_cast = True needs_datetime_string_cast = True

View File

@ -3,6 +3,7 @@ from django.db.utils import InterfaceError
class DatabaseFeatures(BaseDatabaseFeatures): class DatabaseFeatures(BaseDatabaseFeatures):
allows_group_by_selected_pks = True
needs_datetime_string_cast = False needs_datetime_string_cast = False
can_return_id_from_insert = True can_return_id_from_insert = True
has_real_datatype = True has_real_datatype = True

View File

@ -136,10 +136,7 @@ class SQLCompiler(object):
# If the DB can group by primary key, then group by the primary key of # If the DB can group by primary key, then group by the primary key of
# query's main model. Note that for PostgreSQL the GROUP BY clause must # query's main model. Note that for PostgreSQL the GROUP BY clause must
# include the primary key of every table, but for MySQL it is enough to # include the primary key of every table, but for MySQL it is enough to
# have the main table's primary key. Currently only the MySQL form is # have the main table's primary key.
# implemented.
# MySQLism: however, columns in HAVING clause must be added to the
# GROUP BY.
if self.connection.features.allows_group_by_pk: if self.connection.features.allows_group_by_pk:
# The logic here is: if the main model's primary key is in the # The logic here is: if the main model's primary key is in the
# query, then set new_expressions to that field. If that happens, # query, then set new_expressions to that field. If that happens,
@ -150,7 +147,18 @@ class SQLCompiler(object):
getattr(expr.output_field, 'model') == self.query.model): getattr(expr.output_field, 'model') == self.query.model):
pk = expr pk = expr
if pk: if pk:
# MySQLism: Columns in HAVING clause must be added to the GROUP BY.
expressions = [pk] + [expr for expr in expressions if expr in having] expressions = [pk] + [expr for expr in expressions if expr in having]
elif self.connection.features.allows_group_by_selected_pks:
# Filter out all expressions associated with a table's primary key
# present in the grouped columns. This is done by identifying all
# tables that have their primary key included in the grouped
# columns and removing non-primary key columns referring to them.
pks = {expr for expr in expressions if hasattr(expr, 'target') and expr.target.primary_key}
aliases = {expr.alias for expr in pks}
expressions = [
expr for expr in expressions if expr in pks or getattr(expr, 'alias', None) not in aliases
]
return expressions return expressions
def get_select(self): def get_select(self):

View File

@ -6,7 +6,7 @@ from django.test.client import Client, RequestFactory
from django.test.testcases import ( from django.test.testcases import (
TestCase, TransactionTestCase, TestCase, TransactionTestCase,
SimpleTestCase, LiveServerTestCase, skipIfDBFeature, SimpleTestCase, LiveServerTestCase, skipIfDBFeature,
skipUnlessDBFeature skipUnlessAnyDBFeature, skipUnlessDBFeature
) )
from django.test.utils import (ignore_warnings, modify_settings, from django.test.utils import (ignore_warnings, modify_settings,
override_settings, override_system_checks) override_settings, override_system_checks)
@ -14,8 +14,8 @@ from django.test.utils import (ignore_warnings, modify_settings,
__all__ = [ __all__ = [
'Client', 'RequestFactory', 'TestCase', 'TransactionTestCase', 'Client', 'RequestFactory', 'TestCase', 'TransactionTestCase',
'SimpleTestCase', 'LiveServerTestCase', 'skipIfDBFeature', 'SimpleTestCase', 'LiveServerTestCase', 'skipIfDBFeature',
'skipUnlessDBFeature', 'ignore_warnings', 'modify_settings', 'skipUnlessAnyDBFeature', 'skipUnlessDBFeature', 'ignore_warnings',
'override_settings', 'override_system_checks' 'modify_settings', 'override_settings', 'override_system_checks'
] ]
# To simplify Django's test suite; not meant as a public API # To simplify Django's test suite; not meant as a public API

View File

@ -1051,6 +1051,16 @@ def skipUnlessDBFeature(*features):
) )
def skipUnlessAnyDBFeature(*features):
    """
    Skip a test unless the database has at least one of the named features.

    A feature name that is missing from ``connection.features`` is treated
    as unsupported.
    """
    def lacks_every_feature():
        # Evaluated lazily by _deferredSkip; stops at the first feature
        # the database reports as supported.
        for feature in features:
            if getattr(connection.features, feature, False):
                return False
        return True

    reason = "Database doesn't support any of the feature(s): %s" % ", ".join(features)
    return _deferredSkip(lacks_every_feature, reason)
class QuietWSGIRequestHandler(WSGIRequestHandler): class QuietWSGIRequestHandler(WSGIRequestHandler):
""" """
Just a regular WSGIRequestHandler except it doesn't log to the standard Just a regular WSGIRequestHandler except it doesn't log to the standard

View File

@ -7,10 +7,11 @@ from operator import attrgetter
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connection
from django.db.models import ( from django.db.models import (
F, Q, Avg, Count, Max, StdDev, Sum, Value, Variance, F, Q, Avg, Count, Max, StdDev, Sum, Value, Variance,
) )
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessAnyDBFeature, skipUnlessDBFeature
from django.test.utils import Approximate from django.test.utils import Approximate
from django.utils import six from django.utils import six
@ -1011,7 +1012,7 @@ class AggregationTests(TestCase):
# Check that the query executes without problems. # Check that the query executes without problems.
self.assertEqual(len(qs.exclude(publisher=-1)), 6) self.assertEqual(len(qs.exclude(publisher=-1)), 6)
@skipUnlessDBFeature("allows_group_by_pk") @skipUnlessAnyDBFeature('allows_group_by_pk', 'allows_group_by_selected_pks')
def test_aggregate_duplicate_columns(self): def test_aggregate_duplicate_columns(self):
# Regression test for #17144 # Regression test for #17144
@ -1041,7 +1042,7 @@ class AggregationTests(TestCase):
] ]
) )
@skipUnlessDBFeature("allows_group_by_pk") @skipUnlessAnyDBFeature('allows_group_by_pk', 'allows_group_by_selected_pks')
def test_aggregate_duplicate_columns_only(self): def test_aggregate_duplicate_columns_only(self):
# Works with only() too. # Works with only() too.
results = Author.objects.only('id', 'name').annotate(num_contacts=Count('book_contact_set')) results = Author.objects.only('id', 'name').annotate(num_contacts=Count('book_contact_set'))
@ -1067,13 +1068,14 @@ class AggregationTests(TestCase):
] ]
) )
@skipUnlessDBFeature("allows_group_by_pk") @skipUnlessAnyDBFeature('allows_group_by_pk', 'allows_group_by_selected_pks')
def test_aggregate_duplicate_columns_select_related(self): def test_aggregate_duplicate_columns_select_related(self):
# And select_related() # And select_related()
results = Book.objects.select_related('contact').annotate( results = Book.objects.select_related('contact').annotate(
num_authors=Count('authors')) num_authors=Count('authors'))
_, _, grouping = results.query.get_compiler(using='default').pre_sql_setup() _, _, grouping = results.query.get_compiler(using='default').pre_sql_setup()
self.assertEqual(len(grouping), 1) # In the case of `allows_group_by_selected_pks` we also group by contact.id because of the select_related().
self.assertEqual(len(grouping), 1 if connection.features.allows_group_by_pk else 2)
self.assertIn('id', grouping[0][0]) self.assertIn('id', grouping[0][0])
self.assertNotIn('name', grouping[0][0]) self.assertNotIn('name', grouping[0][0])
self.assertNotIn('contact', grouping[0][0]) self.assertNotIn('contact', grouping[0][0])