Fixed #10362 -- An update() that only affects a parent model no longer crashes.

This includes a fairly large refactor of the update() query path (and
the initial portions of constructing the SQL for any query). The
previous code appears to have been only working more or less by accident
and was very fragile.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@9967 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Malcolm Tredinnick 2009-03-04 05:34:01 +00:00
parent 14c8e5227a
commit 0e93f60c7f
3 changed files with 79 additions and 32 deletions

View File

@ -62,6 +62,7 @@ class BaseQuery(object):
self.dupe_avoidance = {} self.dupe_avoidance = {}
self.used_aliases = set() self.used_aliases = set()
self.filter_is_sticky = False self.filter_is_sticky = False
self.included_inherited_models = {}
# SQL-related attributes # SQL-related attributes
self.select = [] self.select = []
@ -171,6 +172,7 @@ class BaseQuery(object):
obj.default_cols = self.default_cols obj.default_cols = self.default_cols
obj.default_ordering = self.default_ordering obj.default_ordering = self.default_ordering
obj.standard_ordering = self.standard_ordering obj.standard_ordering = self.standard_ordering
obj.included_inherited_models = self.included_inherited_models.copy()
obj.ordering_aliases = [] obj.ordering_aliases = []
obj.select_fields = self.select_fields[:] obj.select_fields = self.select_fields[:]
obj.related_select_fields = self.related_select_fields[:] obj.related_select_fields = self.related_select_fields[:]
@ -304,6 +306,7 @@ class BaseQuery(object):
self.select = [] self.select = []
self.default_cols = False self.default_cols = False
self.extra_select = {} self.extra_select = {}
self.remove_inherited_models()
query.clear_ordering(True) query.clear_ordering(True)
query.clear_limits() query.clear_limits()
@ -458,6 +461,7 @@ class BaseQuery(object):
assert self.distinct == rhs.distinct, \ assert self.distinct == rhs.distinct, \
"Cannot combine a unique query with a non-unique query." "Cannot combine a unique query with a non-unique query."
self.remove_inherited_models()
# Work out how to relabel the rhs aliases, if necessary. # Work out how to relabel the rhs aliases, if necessary.
change_map = {} change_map = {}
used = set() used = set()
@ -540,6 +544,9 @@ class BaseQuery(object):
""" """
if not self.tables: if not self.tables:
self.join((None, self.model._meta.db_table, None, None)) self.join((None, self.model._meta.db_table, None, None))
if (not self.select and self.default_cols and not
self.included_inherited_models):
self.setup_inherited_models()
if self.select_related and not self.related_select_cols: if self.select_related and not self.related_select_cols:
self.fill_related_selections() self.fill_related_selections()
@ -619,7 +626,9 @@ class BaseQuery(object):
start_alias=None, opts=None, as_pairs=False): start_alias=None, opts=None, as_pairs=False):
""" """
Computes the default columns for selecting every field in the base Computes the default columns for selecting every field in the base
model. model. Will sometimes be called to pull in related models (e.g. via
select_related), in which case "opts" and "start_alias" will be given
to provide a starting point for the traversal.
Returns a list of strings, quoted appropriately for use in SQL Returns a list of strings, quoted appropriately for use in SQL
directly, as well as a set of aliases used in the select statement (if directly, as well as a set of aliases used in the select statement (if
@ -629,22 +638,25 @@ class BaseQuery(object):
result = [] result = []
if opts is None: if opts is None:
opts = self.model._meta opts = self.model._meta
if start_alias:
table_alias = start_alias
else:
table_alias = self.tables[0]
root_pk = opts.pk.column
seen = {None: table_alias}
qn = self.quote_name_unless_alias qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name qn2 = self.connection.ops.quote_name
aliases = set() aliases = set()
if start_alias:
seen = {None: start_alias}
root_pk = opts.pk.column
for field, model in opts.get_fields_with_model(): for field, model in opts.get_fields_with_model():
try: if start_alias:
alias = seen[model] try:
except KeyError: alias = seen[model]
alias = self.join((table_alias, model._meta.db_table, except KeyError:
root_pk, model._meta.pk.column)) alias = self.join((start_alias, model._meta.db_table,
seen[model] = alias root_pk, model._meta.pk.column))
seen[model] = alias
else:
# If we're starting from the base model of the queryset, the
# aliases will have already been set up in pre_sql_setup(), so
# we can save time here.
alias = self.included_inherited_models[model]
if as_pairs: if as_pairs:
result.append((alias, field.column)) result.append((alias, field.column))
continue continue
@ -996,6 +1008,9 @@ class BaseQuery(object):
if alias == old_alias: if alias == old_alias:
self.tables[pos] = new_alias self.tables[pos] = new_alias
break break
for key, alias in self.included_inherited_models.items():
if alias in change_map:
self.included_inherited_models[key] = change_map[alias]
# 3. Update any joins that refer to the old alias. # 3. Update any joins that refer to the old alias.
for alias, data in self.alias_map.iteritems(): for alias, data in self.alias_map.iteritems():
@ -1062,9 +1077,11 @@ class BaseQuery(object):
lhs.lhs_col = table.col lhs.lhs_col = table.col
If 'always_create' is True and 'reuse' is None, a new alias is always If 'always_create' is True and 'reuse' is None, a new alias is always
created, regardless of whether one already exists or not. Otherwise created, regardless of whether one already exists or not. If
'reuse' must be a set and a new join is created unless one of the 'always_create' is True and 'reuse' is a set, an alias in 'reuse' that
aliases in `reuse` can be used. matches the connection will be returned, if possible. If
'always_create' is False, the first existing alias that matches the
'connection' is returned, if any. Otherwise a new join is created.
If 'exclusions' is specified, it is something satisfying the container If 'exclusions' is specified, it is something satisfying the container
protocol ("foo in exclusions" must work) and specifies a list of protocol ("foo in exclusions" must work) and specifies a list of
@ -1126,6 +1143,38 @@ class BaseQuery(object):
self.rev_join_map[alias] = t_ident self.rev_join_map[alias] = t_ident
return alias return alias
def setup_inherited_models(self):
"""
If the model that is the basis for this QuerySet inherits other models,
we need to ensure that those other models have their tables included in
the query.
We do this as a separate step so that subclasses know which
tables are going to be active in the query, without needing to compute
all the select columns (this method is called from pre_sql_setup(),
whereas column determination is a later part, and side-effect, of
as_sql()).
"""
opts = self.model._meta
root_pk = opts.pk.column
root_alias = self.tables[0]
seen = {None: root_alias}
for field, model in opts.get_fields_with_model():
if model not in seen:
seen[model] = self.join((root_alias, model._meta.db_table,
root_pk, model._meta.pk.column))
self.included_inherited_models = seen
def remove_inherited_models(self):
"""
Undoes the effects of setup_inherited_models(). Should be called
whenever select columns (self.select) are set explicitly.
"""
for key, alias in self.included_inherited_models.items():
if key:
self.unref_alias(alias)
self.included_inherited_models = {}
def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1, def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
used=None, requested=None, restricted=None, nullable=None, used=None, requested=None, restricted=None, nullable=None,
dupe_set=None, avoid_set=None): dupe_set=None, avoid_set=None):
@ -1803,6 +1852,7 @@ class BaseQuery(object):
names.sort() names.sort()
raise FieldError("Cannot resolve keyword %r into field. " raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names))) "Choices are: %s" % (name, ", ".join(names)))
self.remove_inherited_models()
def add_ordering(self, *ordering): def add_ordering(self, *ordering):
""" """
@ -2004,6 +2054,7 @@ class BaseQuery(object):
select_alias = join_info[RHS_ALIAS] select_alias = join_info[RHS_ALIAS]
select_col = join_info[RHS_JOIN_COL] select_col = join_info[RHS_JOIN_COL]
self.select = [(select_alias, select_col)] self.select = [(select_alias, select_col)]
self.remove_inherited_models()
def execute_sql(self, result_type=MULTI): def execute_sql(self, result_type=MULTI):
""" """

View File

@ -179,21 +179,9 @@ class UpdateQuery(Query):
query = self.clone(klass=Query) query = self.clone(klass=Query)
query.bump_prefix() query.bump_prefix()
query.extra_select = {} query.extra_select = {}
first_table = query.tables[0] query.select = []
if query.alias_refcount[first_table] == 1: query.add_fields([query.model._meta.pk.name])
# We can remove one table from the inner query. must_pre_select = count > 1 and not self.connection.features.update_can_self_select
query.unref_alias(first_table)
for i in xrange(1, len(query.tables)):
table = query.tables[i]
if query.alias_refcount[table]:
break
join_info = query.alias_map[table]
query.select = [(join_info[RHS_ALIAS], join_info[RHS_JOIN_COL])]
must_pre_select = False
else:
query.select = []
query.add_fields([query.model._meta.pk.name])
must_pre_select = not self.connection.features.update_can_self_select
# Now we adjust the current query: reset the where clause and get rid # Now we adjust the current query: reset the where clause and get rid
# of all the tables we don't need (since they're in the sub-select). # of all the tables we don't need (since they're in the sub-select).

View File

@ -222,7 +222,7 @@ True
>>> obj = SelfRefChild.objects.create(child_data=37, parent_data=42) >>> obj = SelfRefChild.objects.create(child_data=37, parent_data=42)
>>> obj.delete() >>> obj.delete()
# Regression tests for #8076 - get_(next/previous)_by_date should # Regression tests for #8076 - get_(next/previous)_by_date should work.
>>> c1 = ArticleWithAuthor(headline='ArticleWithAuthor 1', author="Person 1", pub_date=datetime.datetime(2005, 8, 1, 3, 0)) >>> c1 = ArticleWithAuthor(headline='ArticleWithAuthor 1', author="Person 1", pub_date=datetime.datetime(2005, 8, 1, 3, 0))
>>> c1.save() >>> c1.save()
>>> c2 = ArticleWithAuthor(headline='ArticleWithAuthor 2', author="Person 2", pub_date=datetime.datetime(2005, 8, 1, 10, 0)) >>> c2 = ArticleWithAuthor(headline='ArticleWithAuthor 2', author="Person 2", pub_date=datetime.datetime(2005, 8, 1, 10, 0))
@ -267,4 +267,12 @@ DoesNotExist: ArticleWithAuthor matching query does not exist.
>>> fragment.find('pub_date', pos + 1) == -1 >>> fragment.find('pub_date', pos + 1) == -1
True True
# It is possible to call update() and only change a field in an ancestor model
# (regression test for #10362).
>>> article = ArticleWithAuthor.objects.create(author="fred", headline="Hey there!", pub_date = datetime.datetime(2009, 3, 1, 8, 0, 0))
>>> ArticleWithAuthor.objects.filter(author="fred").update(headline="Oh, no!")
1
>>> ArticleWithAuthor.objects.filter(pk=article.pk).update(headline="Oh, no!")
1
"""} """}