diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index a09854e59b..80bbfc401a 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -62,6 +62,7 @@ class BaseQuery(object): self.dupe_avoidance = {} self.used_aliases = set() self.filter_is_sticky = False + self.included_inherited_models = {} # SQL-related attributes self.select = [] @@ -171,6 +172,7 @@ class BaseQuery(object): obj.default_cols = self.default_cols obj.default_ordering = self.default_ordering obj.standard_ordering = self.standard_ordering + obj.included_inherited_models = self.included_inherited_models.copy() obj.ordering_aliases = [] obj.select_fields = self.select_fields[:] obj.related_select_fields = self.related_select_fields[:] @@ -304,6 +306,7 @@ class BaseQuery(object): self.select = [] self.default_cols = False self.extra_select = {} + self.remove_inherited_models() query.clear_ordering(True) query.clear_limits() @@ -458,6 +461,7 @@ class BaseQuery(object): assert self.distinct == rhs.distinct, \ "Cannot combine a unique query with a non-unique query." + self.remove_inherited_models() # Work out how to relabel the rhs aliases, if necessary. change_map = {} used = set() @@ -540,6 +544,9 @@ class BaseQuery(object): """ if not self.tables: self.join((None, self.model._meta.db_table, None, None)) + if (not self.select and self.default_cols and not + self.included_inherited_models): + self.setup_inherited_models() if self.select_related and not self.related_select_cols: self.fill_related_selections() @@ -619,7 +626,9 @@ class BaseQuery(object): start_alias=None, opts=None, as_pairs=False): """ Computes the default columns for selecting every field in the base - model. + model. Will sometimes be called to pull in related models (e.g. via + select_related), in which case "opts" and "start_alias" will be given + to provide a starting point for the traversal. Returns a list of strings, quoted appropriately for use in SQL directly, as well as a set of aliases used in the select statement (if @@ -629,22 +638,25 @@ class BaseQuery(object): result = [] if opts is None: opts = self.model._meta - if start_alias: - table_alias = start_alias - else: - table_alias = self.tables[0] - root_pk = opts.pk.column - seen = {None: table_alias} qn = self.quote_name_unless_alias qn2 = self.connection.ops.quote_name aliases = set() + if start_alias: + seen = {None: start_alias} + root_pk = opts.pk.column for field, model in opts.get_fields_with_model(): - try: - alias = seen[model] - except KeyError: - alias = self.join((table_alias, model._meta.db_table, - root_pk, model._meta.pk.column)) - seen[model] = alias + if start_alias: + try: + alias = seen[model] + except KeyError: + alias = self.join((start_alias, model._meta.db_table, + root_pk, model._meta.pk.column)) + seen[model] = alias + else: + # If we're starting from the base model of the queryset, the + # aliases will have already been set up in pre_sql_setup(), so + # we can save time here. + alias = self.included_inherited_models[model] if as_pairs: result.append((alias, field.column)) continue @@ -996,6 +1008,9 @@ class BaseQuery(object): if alias == old_alias: self.tables[pos] = new_alias break + for key, alias in self.included_inherited_models.items(): + if alias in change_map: + self.included_inherited_models[key] = change_map[alias] # 3. Update any joins that refer to the old alias. for alias, data in self.alias_map.iteritems(): @@ -1062,9 +1077,11 @@ class BaseQuery(object): lhs.lhs_col = table.col If 'always_create' is True and 'reuse' is None, a new alias is always - created, regardless of whether one already exists or not. Otherwise - 'reuse' must be a set and a new join is created unless one of the - aliases in `reuse` can be used. + created, regardless of whether one already exists or not. If + 'always_create' is True and 'reuse' is a set, an alias in 'reuse' that + matches the connection will be returned, if possible. If + 'always_create' is False, the first existing alias that matches the + 'connection' is returned, if any. Otherwise a new join is created. If 'exclusions' is specified, it is something satisfying the container protocol ("foo in exclusions" must work) and specifies a list of @@ -1126,6 +1143,38 @@ class BaseQuery(object): self.rev_join_map[alias] = t_ident return alias + def setup_inherited_models(self): + """ + If the model that is the basis for this QuerySet inherits other models, + we need to ensure that those other models have their tables included in + the query. + + We do this as a separate step so that subclasses know which + tables are going to be active in the query, without needing to compute + all the select columns (this method is called from pre_sql_setup(), + whereas column determination is a later part, and side-effect, of + as_sql()). + """ + opts = self.model._meta + root_pk = opts.pk.column + root_alias = self.tables[0] + seen = {None: root_alias} + for field, model in opts.get_fields_with_model(): + if model not in seen: + seen[model] = self.join((root_alias, model._meta.db_table, + root_pk, model._meta.pk.column)) + self.included_inherited_models = seen + + def remove_inherited_models(self): + """ + Undoes the effects of setup_inherited_models(). Should be called + whenever select columns (self.select) are set explicitly. + """ + for key, alias in self.included_inherited_models.items(): + if key: + self.unref_alias(alias) + self.included_inherited_models = {} + def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1, used=None, requested=None, restricted=None, nullable=None, dupe_set=None, avoid_set=None): @@ -1803,6 +1852,7 @@ class BaseQuery(object): names.sort() raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(names))) + self.remove_inherited_models() def add_ordering(self, *ordering): """ @@ -2004,6 +2054,7 @@ class BaseQuery(object): select_alias = join_info[RHS_ALIAS] select_col = join_info[RHS_JOIN_COL] self.select = [(select_alias, select_col)] + self.remove_inherited_models() def execute_sql(self, result_type=MULTI): """ diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 385df90569..d40004b1c1 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -179,21 +179,9 @@ class UpdateQuery(Query): query = self.clone(klass=Query) query.bump_prefix() query.extra_select = {} - first_table = query.tables[0] - if query.alias_refcount[first_table] == 1: - # We can remove one table from the inner query. - query.unref_alias(first_table) - for i in xrange(1, len(query.tables)): - table = query.tables[i] - if query.alias_refcount[table]: - break - join_info = query.alias_map[table] - query.select = [(join_info[RHS_ALIAS], join_info[RHS_JOIN_COL])] - must_pre_select = False - else: - query.select = [] - query.add_fields([query.model._meta.pk.name]) - must_pre_select = not self.connection.features.update_can_self_select + query.select = [] + query.add_fields([query.model._meta.pk.name]) + must_pre_select = count > 1 and not self.connection.features.update_can_self_select # Now we adjust the current query: reset the where clause and get rid # of all the tables we don't need (since they're in the sub-select). diff --git a/tests/regressiontests/model_inheritance_regress/models.py b/tests/regressiontests/model_inheritance_regress/models.py index 06a886e0be..b5c051d5ca 100644 --- a/tests/regressiontests/model_inheritance_regress/models.py +++ b/tests/regressiontests/model_inheritance_regress/models.py @@ -222,7 +222,7 @@ True >>> obj = SelfRefChild.objects.create(child_data=37, parent_data=42) >>> obj.delete() -# Regression tests for #8076 - get_(next/previous)_by_date should +# Regression tests for #8076 - get_(next/previous)_by_date should work. >>> c1 = ArticleWithAuthor(headline='ArticleWithAuthor 1', author="Person 1", pub_date=datetime.datetime(2005, 8, 1, 3, 0)) >>> c1.save() >>> c2 = ArticleWithAuthor(headline='ArticleWithAuthor 2', author="Person 2", pub_date=datetime.datetime(2005, 8, 1, 10, 0)) @@ -267,4 +267,12 @@ DoesNotExist: ArticleWithAuthor matching query does not exist. >>> fragment.find('pub_date', pos + 1) == -1 True +# It is possible to call update() and only change a field in an ancestor model +# (regression test for #10362). +>>> article = ArticleWithAuthor.objects.create(author="fred", headline="Hey there!", pub_date = datetime.datetime(2009, 3, 1, 8, 0, 0)) +>>> ArticleWithAuthor.objects.filter(author="fred").update(headline="Oh, no!") +1 +>>> ArticleWithAuthor.objects.filter(pk=article.pk).update(headline="Oh, no!") +1 + """}