Fixed #20939 -- Simplified query generation by converting QuerySet to Query.

Thanks Anssi Kääriäinen for the initial patch and Anssi, Simon Charette,
and Josh Smeaton for review.
This commit is contained in:
Tim Graham 2016-10-28 11:20:23 -04:00 committed by GitHub
parent 80e742d991
commit 1bc249c2a6
9 changed files with 71 additions and 140 deletions

View File

@ -6,6 +6,7 @@ from django.core.exceptions import FieldDoesNotExist
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Col, Expression from django.db.models.expressions import Col, Expression
from django.db.models.lookups import BuiltinLookup, Lookup, Transform from django.db.models.lookups import BuiltinLookup, Lookup, Transform
from django.db.models.sql.query import Query
from django.utils import six from django.utils import six
gis_lookups = {} gis_lookups = {}
@ -100,8 +101,8 @@ class GISLookup(Lookup):
return ('%s', params) return ('%s', params)
def process_rhs(self, compiler, connection): def process_rhs(self, compiler, connection):
if hasattr(self.rhs, '_as_sql'): if isinstance(self.rhs, Query):
# If rhs is some QuerySet, don't touch it # If rhs is some Query, don't touch it.
return super(GISLookup, self).process_rhs(compiler, connection) return super(GISLookup, self).process_rhs(compiler, connection)
geom = self.rhs geom = self.rhs

View File

@ -2,7 +2,7 @@ from django.db.models.sql import compiler
class SQLCompiler(compiler.SQLCompiler): class SQLCompiler(compiler.SQLCompiler):
def as_sql(self, with_limits=True, with_col_aliases=False, subquery=False): def as_sql(self, with_limits=True, with_col_aliases=False):
""" """
Creates the SQL for this query. Returns the SQL string and list Creates the SQL for this query. Returns the SQL string and list
of parameters. This is overridden from the original Query class of parameters. This is overridden from the original Query class
@ -19,13 +19,11 @@ class SQLCompiler(compiler.SQLCompiler):
sql, params = super(SQLCompiler, self).as_sql( sql, params = super(SQLCompiler, self).as_sql(
with_limits=False, with_limits=False,
with_col_aliases=with_col_aliases, with_col_aliases=with_col_aliases,
subquery=subquery,
) )
else: else:
sql, params = super(SQLCompiler, self).as_sql( sql, params = super(SQLCompiler, self).as_sql(
with_limits=False, with_limits=False,
with_col_aliases=True, with_col_aliases=True,
subquery=subquery,
) )
# Wrap the base query in an outer SELECT * with boundaries on # Wrap the base query in an outer SELECT * with boundaries on
# the "_RN" column. This is the canonical way to emulate LIMIT # the "_RN" column. This is the canonical way to emulate LIMIT

View File

@ -81,27 +81,18 @@ class RelatedIn(In):
AND) AND)
return root_constraint.as_sql(compiler, connection) return root_constraint.as_sql(compiler, connection)
else: else:
if getattr(self.rhs, '_forced_pk', False):
self.rhs.clear_select_clause()
if getattr(self.lhs.output_field, 'primary_key', False):
# A case like Restaurant.objects.filter(place__in=restaurant_qs),
# where place is a OneToOneField and the primary key of
# Restaurant.
target_field = self.lhs.field.name
else:
target_field = self.lhs.field.target_field.name
self.rhs.add_fields([target_field], True)
return super(RelatedIn, self).as_sql(compiler, connection) return super(RelatedIn, self).as_sql(compiler, connection)
def __getstate__(self):
"""
Prevent pickling a query with an __in=inner_qs lookup from evaluating
inner_qs.
"""
from django.db.models.query import QuerySet # Avoid circular import
state = self.__dict__.copy()
if isinstance(self.rhs, QuerySet):
state['rhs'] = (self.rhs.__class__, self.rhs.query)
return state
def __setstate__(self, state):
self.__dict__.update(state)
if isinstance(self.rhs, tuple):
queryset_class, query = self.rhs
queryset = queryset_class()
queryset.query = query
self.rhs = queryset
class RelatedLookupMixin(object): class RelatedLookupMixin(object):
def get_prep_lookup(self): def get_prep_lookup(self):

View File

@ -26,9 +26,8 @@ class Lookup(object):
if bilateral_transforms: if bilateral_transforms:
# Warn the user as soon as possible if they are trying to apply # Warn the user as soon as possible if they are trying to apply
# a bilateral transformation on a nested QuerySet: that won't work. # a bilateral transformation on a nested QuerySet: that won't work.
# We need to import QuerySet here so as to avoid circular from django.db.models.sql.query import Query # avoid circular import
from django.db.models.query import QuerySet if isinstance(rhs, Query):
if isinstance(rhs, QuerySet):
raise NotImplementedError("Bilateral transformations on nested querysets are not supported.") raise NotImplementedError("Bilateral transformations on nested querysets are not supported.")
self.bilateral_transforms = bilateral_transforms self.bilateral_transforms = bilateral_transforms
@ -79,24 +78,19 @@ class Lookup(object):
value = value.resolve_expression(compiler.query) value = value.resolve_expression(compiler.query)
# Due to historical reasons there are a couple of different # Due to historical reasons there are a couple of different
# ways to produce sql here. get_compiler is likely a Query # ways to produce sql here. get_compiler is likely a Query
# instance, _as_sql QuerySet and as_sql just something with # instance and as_sql just something with as_sql. Finally the value
# as_sql. Finally the value can of course be just plain # can of course be just plain Python value.
# Python value.
if hasattr(value, 'get_compiler'): if hasattr(value, 'get_compiler'):
value = value.get_compiler(connection=connection) value = value.get_compiler(connection=connection)
if hasattr(value, 'as_sql'): if hasattr(value, 'as_sql'):
sql, params = compiler.compile(value) sql, params = compiler.compile(value)
return '(' + sql + ')', params return '(' + sql + ')', params
if hasattr(value, '_as_sql'):
sql, params = value._as_sql(connection=connection)
return '(' + sql + ')', params
else: else:
return self.get_db_prep_lookup(value, connection) return self.get_db_prep_lookup(value, connection)
def rhs_is_direct_value(self): def rhs_is_direct_value(self):
return not( return not(
hasattr(self.rhs, 'as_sql') or hasattr(self.rhs, 'as_sql') or
hasattr(self.rhs, '_as_sql') or
hasattr(self.rhs, 'get_compiler')) hasattr(self.rhs, 'get_compiler'))
def relabeled_clone(self, relabels): def relabeled_clone(self, relabels):
@ -371,8 +365,7 @@ class PatternLookup(BuiltinLookup):
# So, for Python values we don't need any special pattern, but for # So, for Python values we don't need any special pattern, but for
# SQL reference values or SQL transformations we need the correct # SQL reference values or SQL transformations we need the correct
# pattern added. # pattern added.
if (hasattr(self.rhs, 'get_compiler') or hasattr(self.rhs, 'as_sql') or if hasattr(self.rhs, 'get_compiler') or hasattr(self.rhs, 'as_sql') or self.bilateral_transforms:
hasattr(self.rhs, '_as_sql') or self.bilateral_transforms):
pattern = connection.pattern_ops[self.lookup_name].format(connection.pattern_esc) pattern = connection.pattern_ops[self.lookup_name].format(connection.pattern_esc)
return pattern.format(rhs) return pattern.format(rhs)
else: else:

View File

@ -19,9 +19,7 @@ from django.db.models.deletion import Collector
from django.db.models.expressions import F from django.db.models.expressions import F
from django.db.models.fields import AutoField from django.db.models.fields import AutoField
from django.db.models.functions import Trunc from django.db.models.functions import Trunc
from django.db.models.query_utils import ( from django.db.models.query_utils import InvalidQuery, Q
InvalidQuery, Q, check_rel_lookup_compatibility,
)
from django.db.models.sql.constants import CURSOR from django.db.models.sql.constants import CURSOR
from django.utils import six, timezone from django.utils import six, timezone
from django.utils.deprecation import RemovedInDjango20Warning from django.utils.deprecation import RemovedInDjango20Warning
@ -1114,36 +1112,17 @@ class QuerySet(object):
for field, objects in other._known_related_objects.items(): for field, objects in other._known_related_objects.items():
self._known_related_objects.setdefault(field, {}).update(objects) self._known_related_objects.setdefault(field, {}).update(objects)
def _prepare(self, field): def _prepare_as_filter_value(self):
if self._fields is not None: if self._fields is None:
# values() queryset can only be used as nested queries queryset = self.values('pk')
# if they are set up to select only a single field. queryset.query._forced_pk = True
if len(self._fields or self.model._meta.concrete_fields) > 1:
raise TypeError('Cannot use multi-field values as a filter value.')
elif self.model != field.model:
# If the query is used as a subquery for a ForeignKey with non-pk
# target field, make sure to select the target field in the subquery.
foreign_fields = getattr(field, 'foreign_related_fields', ())
if len(foreign_fields) == 1 and not foreign_fields[0].primary_key:
return self.values(foreign_fields[0].name)
return self
def _as_sql(self, connection):
"""
Returns the internal query's SQL and parameters (as a tuple).
"""
if self._fields is not None:
# values() queryset can only be used as nested queries
# if they are set up to select only a single field.
if len(self._fields or self.model._meta.concrete_fields) > 1:
raise TypeError('Cannot use multi-field values as a filter value.')
clone = self._clone()
else: else:
clone = self.values('pk') # values() queryset can only be used as nested queries
# if they are set up to select only a single field.
if clone._db is None or connection == connections[clone._db]: if len(self._fields) > 1:
return clone.query.get_compiler(connection=connection).as_nested_sql() raise TypeError('Cannot use multi-field values as a filter value.')
raise ValueError("Can't do subqueries with queries on different DBs.") queryset = self._clone()
return queryset.query.as_subquery_filter(queryset._db)
def _add_hints(self, **hints): def _add_hints(self, **hints):
""" """
@ -1161,21 +1140,6 @@ class QuerySet(object):
""" """
return self.query.has_filters() return self.query.has_filters()
def is_compatible_query_object_type(self, opts, field):
"""
Check that using this queryset as the rhs value for a lookup is
allowed. The opts are the options of the relation's target we are
querying against. For example in .filter(author__in=Author.objects.all())
the opts would be Author's (from the author field) and self.model would
be Author.objects.all() queryset's .model (Author also). The field is
the related field on the lhs side.
"""
# We trust that users of values() know what they are doing.
if self._fields is not None:
return True
return check_rel_lookup_compatibility(self.model, opts, field)
is_compatible_query_object_type.queryset_only = True
class InstanceCheckMeta(type): class InstanceCheckMeta(type):
def __instancecheck__(self, instance): def __instancecheck__(self, instance):

View File

@ -13,6 +13,8 @@ from django.db.transaction import TransactionManagementError
from django.db.utils import DatabaseError from django.db.utils import DatabaseError
from django.utils.six.moves import zip from django.utils.six.moves import zip
FORCE = object()
class SQLCompiler(object): class SQLCompiler(object):
def __init__(self, query, connection, using): def __init__(self, query, connection, using):
@ -28,7 +30,6 @@ class SQLCompiler(object):
self.annotation_col_map = None self.annotation_col_map = None
self.klass_info = None self.klass_info = None
self.ordering_parts = re.compile(r'(.*)\s(ASC|DESC)(.*)') self.ordering_parts = re.compile(r'(.*)\s(ASC|DESC)(.*)')
self.subquery = False
def setup_query(self): def setup_query(self):
if all(self.query.alias_refcount[a] == 0 for a in self.query.tables): if all(self.query.alias_refcount[a] == 0 for a in self.query.tables):
@ -355,11 +356,11 @@ class SQLCompiler(object):
sql, params = vendor_impl(self, self.connection) sql, params = vendor_impl(self, self.connection)
else: else:
sql, params = node.as_sql(self, self.connection) sql, params = node.as_sql(self, self.connection)
if select_format and not self.subquery: if select_format is FORCE or (select_format and not self.query.subquery):
return node.output_field.select_format(self, sql, params) return node.output_field.select_format(self, sql, params)
return sql, params return sql, params
def as_sql(self, with_limits=True, with_col_aliases=False, subquery=False): def as_sql(self, with_limits=True, with_col_aliases=False):
""" """
Creates the SQL for this query. Returns the SQL string and list of Creates the SQL for this query. Returns the SQL string and list of
parameters. parameters.
@ -367,7 +368,6 @@ class SQLCompiler(object):
If 'with_limits' is False, any limit/offset information is not included If 'with_limits' is False, any limit/offset information is not included
in the query. in the query.
""" """
self.subquery = subquery
refcounts_before = self.query.alias_refcount.copy() refcounts_before = self.query.alias_refcount.copy()
try: try:
extra_select, order_by, group_by = self.pre_sql_setup() extra_select, order_by, group_by = self.pre_sql_setup()
@ -466,24 +466,6 @@ class SQLCompiler(object):
# Finally do cleanup - get rid of the joins we created above. # Finally do cleanup - get rid of the joins we created above.
self.query.reset_refcounts(refcounts_before) self.query.reset_refcounts(refcounts_before)
def as_nested_sql(self):
"""
Perform the same functionality as the as_sql() method, returning an
SQL string and parameters. However, the alias prefixes are bumped
beforehand (in a copy -- the current query isn't changed), and any
ordering is removed if the query is unsliced.
Used when nesting this query inside another.
"""
obj = self.query.clone()
# It's safe to drop ordering if the queryset isn't using slicing,
# distinct(*fields) or select_for_update().
if (obj.low_mark == 0 and obj.high_mark is None and
not self.query.distinct_fields and
not self.query.select_for_update):
obj.clear_ordering(True)
return obj.get_compiler(connection=self.connection).as_sql(subquery=True)
def get_default_columns(self, start_alias=None, opts=None, from_parent=None): def get_default_columns(self, start_alias=None, opts=None, from_parent=None):
""" """
Computes the default columns for selecting every field in the base Computes the default columns for selecting every field in the base
@ -1218,7 +1200,7 @@ class SQLAggregateCompiler(SQLCompiler):
""" """
sql, params = [], [] sql, params = [], []
for annotation in self.query.annotation_select.values(): for annotation in self.query.annotation_select.values():
ann_sql, ann_params = self.compile(annotation, select_format=True) ann_sql, ann_params = self.compile(annotation, select_format=FORCE)
sql.append(ann_sql) sql.append(ann_sql)
params.extend(ann_params) params.extend(ann_params)
self.col_count = len(self.query.annotation_select) self.col_count = len(self.query.annotation_select)

View File

@ -143,6 +143,7 @@ class Query(object):
self.standard_ordering = True self.standard_ordering = True
self.used_aliases = set() self.used_aliases = set()
self.filter_is_sticky = False self.filter_is_sticky = False
self.subquery = False
# SQL-related attributes # SQL-related attributes
# Select and related select clauses are expressions to use in the # Select and related select clauses are expressions to use in the
@ -319,6 +320,7 @@ class Query(object):
else: else:
obj.used_aliases = set() obj.used_aliases = set()
obj.filter_is_sticky = False obj.filter_is_sticky = False
obj.subquery = self.subquery
if 'alias_prefix' in self.__dict__: if 'alias_prefix' in self.__dict__:
obj.alias_prefix = self.alias_prefix obj.alias_prefix = self.alias_prefix
if 'subq_aliases' in self.__dict__: if 'subq_aliases' in self.__dict__:
@ -964,6 +966,9 @@ class Query(object):
self.append_annotation_mask([alias]) self.append_annotation_mask([alias])
self.annotations[alias] = annotation self.annotations[alias] = annotation
def _prepare_as_filter_value(self):
return self.clone()
def prepare_lookup_value(self, value, lookups, can_reuse, allow_joins=True): def prepare_lookup_value(self, value, lookups, can_reuse, allow_joins=True):
# Default lookup if none given is exact. # Default lookup if none given is exact.
used_joins = [] used_joins = []
@ -974,8 +979,7 @@ class Query(object):
if value is None: if value is None:
if lookups[-1] not in ('exact', 'iexact'): if lookups[-1] not in ('exact', 'iexact'):
raise ValueError("Cannot use None as a query value") raise ValueError("Cannot use None as a query value")
lookups[-1] = 'isnull' return True, ['isnull'], used_joins
value = True
elif hasattr(value, 'resolve_expression'): elif hasattr(value, 'resolve_expression'):
pre_joins = self.alias_refcount.copy() pre_joins = self.alias_refcount.copy()
value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins) value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
@ -997,11 +1001,8 @@ class Query(object):
# Subqueries need to use a different set of aliases than the # Subqueries need to use a different set of aliases than the
# outer query. Call bump_prefix to change aliases of the inner # outer query. Call bump_prefix to change aliases of the inner
# query (the value). # query (the value).
if hasattr(value, 'query') and hasattr(value.query, 'bump_prefix'): if hasattr(value, '_prepare_as_filter_value'):
value = value._clone() value = value._prepare_as_filter_value()
value.query.bump_prefix(self)
if hasattr(value, 'bump_prefix'):
value = value.clone()
value.bump_prefix(self) value.bump_prefix(self)
# For Oracle '' is equivalent to null. The check needs to be done # For Oracle '' is equivalent to null. The check needs to be done
# at this stage because join promotion can't be done at compiler # at this stage because join promotion can't be done at compiler
@ -1049,14 +1050,20 @@ class Query(object):
Checks the type of object passed to query relations. Checks the type of object passed to query relations.
""" """
if field.is_relation: if field.is_relation:
# QuerySets implement is_compatible_query_object_type() to # Check that the field and the queryset use the same model in a
# determine compatibility with the given field. # query like .filter(author=Author.objects.all()). For example, the
if hasattr(value, 'is_compatible_query_object_type'): # opts would be Author's (from the author field) and value.model
if not value.is_compatible_query_object_type(opts, field): # would be Author.objects.all() queryset's .model (Author also).
raise ValueError( # The field is the related field on the lhs side.
'Cannot use QuerySet for "%s": Use a QuerySet for "%s".' % # If _forced_pk isn't set, this isn't a queryset query or values()
(value.model._meta.object_name, opts.object_name) # or values_list() was specified by the developer in which case
) # that choice is trusted.
if (getattr(value, '_forced_pk', False) and
not check_rel_lookup_compatibility(value.model, opts, field)):
raise ValueError(
'Cannot use QuerySet for "%s": Use a QuerySet for "%s".' %
(value.model._meta.object_name, opts.object_name)
)
elif hasattr(value, '_meta'): elif hasattr(value, '_meta'):
self.check_query_object_type(value, opts, field) self.check_query_object_type(value, opts, field)
elif hasattr(value, '__iter__'): elif hasattr(value, '__iter__'):
@ -2005,6 +2012,17 @@ class Query(object):
else: else:
return field.null return field.null
def as_subquery_filter(self, db):
self._db = db
self.subquery = True
# It's safe to drop ordering if the queryset isn't using slicing,
# distinct(*fields) or select_for_update().
if (self.low_mark == 0 and self.high_mark is None and
not self.distinct_fields and
not self.select_for_update):
self.clear_ordering(True)
return self
def get_order_dir(field, default='ASC'): def get_order_dir(field, default='ASC'):
""" """

View File

@ -205,7 +205,5 @@ class AggregateQuery(Query):
compiler = 'SQLAggregateCompiler' compiler = 'SQLAggregateCompiler'
def add_subquery(self, query, using): def add_subquery(self, query, using):
self.subquery, self.sub_params = query.get_compiler(using).as_sql( query.subquery = True
with_col_aliases=True, self.subquery, self.sub_params = query.get_compiler(using).as_sql(with_col_aliases=True)
subquery=True,
)

View File

@ -197,20 +197,6 @@ class SubqueryConstraint(object):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
query = self.query_object query = self.query_object
query.set_values(self.targets)
# QuerySet was sent
if hasattr(query, 'values'):
if query._db and connection.alias != query._db:
raise ValueError("Can't do subqueries with queries on different DBs.")
# Do not override already existing values.
if query._fields is None:
query = query.values(*self.targets)
else:
query = query._clone()
query = query.query
if query.can_filter():
# If there is no slicing in use, then we can safely drop all ordering
query.clear_ordering(True)
query_compiler = query.get_compiler(connection=connection) query_compiler = query.get_compiler(connection=connection)
return query_compiler.as_subquery_condition(self.alias, self.columns, compiler) return query_compiler.as_subquery_condition(self.alias, self.columns, compiler)