Fixed #21204 -- Tracked field deferrals by field instead of models.
This ensures field deferral works properly when a model is involved more than once in the same query with a distinct deferral mask.
This commit is contained in:
parent
5d12650ed9
commit
b3db6c8dcb
|
@ -259,7 +259,7 @@ class RegisterLookupMixin:
|
|||
cls._clear_cached_lookups()
|
||||
|
||||
|
||||
def select_related_descend(field, restricted, requested, load_fields, reverse=False):
|
||||
def select_related_descend(field, restricted, requested, select_mask, reverse=False):
|
||||
"""
|
||||
Return True if this field should be used to descend deeper for
|
||||
select_related() purposes. Used by both the query construction code
|
||||
|
@ -271,7 +271,7 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa
|
|||
* restricted - a boolean field, indicating if the field list has been
|
||||
manually restricted using a requested clause)
|
||||
* requested - The select_related() dictionary.
|
||||
* load_fields - the set of fields to be loaded on this model
|
||||
* select_mask - the dictionary of selected fields.
|
||||
* reverse - boolean, True if we are checking a reverse select related
|
||||
"""
|
||||
if not field.remote_field:
|
||||
|
@ -287,9 +287,9 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa
|
|||
return False
|
||||
if (
|
||||
restricted
|
||||
and load_fields
|
||||
and select_mask
|
||||
and field.name in requested
|
||||
and field.attname not in load_fields
|
||||
and field not in select_mask
|
||||
):
|
||||
raise FieldError(
|
||||
f"Field {field.model._meta.object_name}.{field.name} cannot be both "
|
||||
|
|
|
@ -256,8 +256,9 @@ class SQLCompiler:
|
|||
select.append((RawSQL(sql, params), alias))
|
||||
select_idx += 1
|
||||
assert not (self.query.select and self.query.default_cols)
|
||||
select_mask = self.query.get_select_mask()
|
||||
if self.query.default_cols:
|
||||
cols = self.get_default_columns()
|
||||
cols = self.get_default_columns(select_mask)
|
||||
else:
|
||||
# self.query.select is a special case. These columns never go to
|
||||
# any model.
|
||||
|
@ -278,7 +279,7 @@ class SQLCompiler:
|
|||
select_idx += 1
|
||||
|
||||
if self.query.select_related:
|
||||
related_klass_infos = self.get_related_selections(select)
|
||||
related_klass_infos = self.get_related_selections(select, select_mask)
|
||||
klass_info["related_klass_infos"] = related_klass_infos
|
||||
|
||||
def get_select_from_parent(klass_info):
|
||||
|
@ -870,7 +871,9 @@ class SQLCompiler:
|
|||
# Finally do cleanup - get rid of the joins we created above.
|
||||
self.query.reset_refcounts(refcounts_before)
|
||||
|
||||
def get_default_columns(self, start_alias=None, opts=None, from_parent=None):
|
||||
def get_default_columns(
|
||||
self, select_mask, start_alias=None, opts=None, from_parent=None
|
||||
):
|
||||
"""
|
||||
Compute the default columns for selecting every field in the base
|
||||
model. Will sometimes be called to pull in related models (e.g. via
|
||||
|
@ -886,7 +889,6 @@ class SQLCompiler:
|
|||
if opts is None:
|
||||
if (opts := self.query.get_meta()) is None:
|
||||
return result
|
||||
only_load = self.deferred_to_columns()
|
||||
start_alias = start_alias or self.query.get_initial_alias()
|
||||
# The 'seen_models' is used to optimize checking the needed parent
|
||||
# alias for a given field. This also includes None -> start_alias to
|
||||
|
@ -912,7 +914,7 @@ class SQLCompiler:
|
|||
# parent model data is already present in the SELECT clause,
|
||||
# and we want to avoid reloading the same data again.
|
||||
continue
|
||||
if field.model in only_load and field.attname not in only_load[field.model]:
|
||||
if select_mask and field not in select_mask:
|
||||
continue
|
||||
alias = self.query.join_parent_model(opts, model, start_alias, seen_models)
|
||||
column = field.get_col(alias)
|
||||
|
@ -1063,6 +1065,7 @@ class SQLCompiler:
|
|||
def get_related_selections(
|
||||
self,
|
||||
select,
|
||||
select_mask,
|
||||
opts=None,
|
||||
root_alias=None,
|
||||
cur_depth=1,
|
||||
|
@ -1095,7 +1098,6 @@ class SQLCompiler:
|
|||
if not opts:
|
||||
opts = self.query.get_meta()
|
||||
root_alias = self.query.get_initial_alias()
|
||||
only_load = self.deferred_to_columns()
|
||||
|
||||
# Setup for the case when only particular related fields should be
|
||||
# included in the related selection.
|
||||
|
@ -1109,7 +1111,6 @@ class SQLCompiler:
|
|||
klass_info["related_klass_infos"] = related_klass_infos
|
||||
|
||||
for f in opts.fields:
|
||||
field_model = f.model._meta.concrete_model
|
||||
fields_found.add(f.name)
|
||||
|
||||
if restricted:
|
||||
|
@ -1129,10 +1130,9 @@ class SQLCompiler:
|
|||
else:
|
||||
next = False
|
||||
|
||||
if not select_related_descend(
|
||||
f, restricted, requested, only_load.get(field_model)
|
||||
):
|
||||
if not select_related_descend(f, restricted, requested, select_mask):
|
||||
continue
|
||||
related_select_mask = select_mask.get(f) or {}
|
||||
klass_info = {
|
||||
"model": f.remote_field.model,
|
||||
"field": f,
|
||||
|
@ -1148,7 +1148,7 @@ class SQLCompiler:
|
|||
_, _, _, joins, _, _ = self.query.setup_joins([f.name], opts, root_alias)
|
||||
alias = joins[-1]
|
||||
columns = self.get_default_columns(
|
||||
start_alias=alias, opts=f.remote_field.model._meta
|
||||
related_select_mask, start_alias=alias, opts=f.remote_field.model._meta
|
||||
)
|
||||
for col in columns:
|
||||
select_fields.append(len(select))
|
||||
|
@ -1156,6 +1156,7 @@ class SQLCompiler:
|
|||
klass_info["select_fields"] = select_fields
|
||||
next_klass_infos = self.get_related_selections(
|
||||
select,
|
||||
related_select_mask,
|
||||
f.remote_field.model._meta,
|
||||
alias,
|
||||
cur_depth + 1,
|
||||
|
@ -1171,8 +1172,9 @@ class SQLCompiler:
|
|||
if o.field.unique and not o.many_to_many
|
||||
]
|
||||
for f, model in related_fields:
|
||||
related_select_mask = select_mask.get(f) or {}
|
||||
if not select_related_descend(
|
||||
f, restricted, requested, only_load.get(model), reverse=True
|
||||
f, restricted, requested, related_select_mask, reverse=True
|
||||
):
|
||||
continue
|
||||
|
||||
|
@ -1195,7 +1197,10 @@ class SQLCompiler:
|
|||
related_klass_infos.append(klass_info)
|
||||
select_fields = []
|
||||
columns = self.get_default_columns(
|
||||
start_alias=alias, opts=model._meta, from_parent=opts.model
|
||||
related_select_mask,
|
||||
start_alias=alias,
|
||||
opts=model._meta,
|
||||
from_parent=opts.model,
|
||||
)
|
||||
for col in columns:
|
||||
select_fields.append(len(select))
|
||||
|
@ -1203,7 +1208,13 @@ class SQLCompiler:
|
|||
klass_info["select_fields"] = select_fields
|
||||
next = requested.get(f.related_query_name(), {})
|
||||
next_klass_infos = self.get_related_selections(
|
||||
select, model._meta, alias, cur_depth + 1, next, restricted
|
||||
select,
|
||||
related_select_mask,
|
||||
model._meta,
|
||||
alias,
|
||||
cur_depth + 1,
|
||||
next,
|
||||
restricted,
|
||||
)
|
||||
get_related_klass_infos(klass_info, next_klass_infos)
|
||||
|
||||
|
@ -1239,7 +1250,9 @@ class SQLCompiler:
|
|||
}
|
||||
related_klass_infos.append(klass_info)
|
||||
select_fields = []
|
||||
field_select_mask = select_mask.get((name, f)) or {}
|
||||
columns = self.get_default_columns(
|
||||
field_select_mask,
|
||||
start_alias=alias,
|
||||
opts=model._meta,
|
||||
from_parent=opts.model,
|
||||
|
@ -1251,6 +1264,7 @@ class SQLCompiler:
|
|||
next_requested = requested.get(name, {})
|
||||
next_klass_infos = self.get_related_selections(
|
||||
select,
|
||||
field_select_mask,
|
||||
opts=model._meta,
|
||||
root_alias=alias,
|
||||
cur_depth=cur_depth + 1,
|
||||
|
@ -1377,16 +1391,6 @@ class SQLCompiler:
|
|||
)
|
||||
return result
|
||||
|
||||
def deferred_to_columns(self):
|
||||
"""
|
||||
Convert the self.deferred_loading data structure to mapping of table
|
||||
names to sets of column names which are to be loaded. Return the
|
||||
dictionary.
|
||||
"""
|
||||
columns = {}
|
||||
self.query.deferred_to_data(columns)
|
||||
return columns
|
||||
|
||||
def get_converters(self, expressions):
|
||||
converters = {}
|
||||
for i, expression in enumerate(expressions):
|
||||
|
|
|
@ -718,7 +718,61 @@ class Query(BaseExpression):
|
|||
self.order_by = rhs.order_by or self.order_by
|
||||
self.extra_order_by = rhs.extra_order_by or self.extra_order_by
|
||||
|
||||
def deferred_to_data(self, target):
|
||||
def _get_defer_select_mask(self, opts, mask, select_mask=None):
|
||||
if select_mask is None:
|
||||
select_mask = {}
|
||||
select_mask[opts.pk] = {}
|
||||
# All concrete fields that are not part of the defer mask must be
|
||||
# loaded. If a relational field is encountered it gets added to the
|
||||
# mask for it be considered if `select_related` and the cycle continues
|
||||
# by recursively caling this function.
|
||||
for field in opts.concrete_fields:
|
||||
field_mask = mask.pop(field.name, None)
|
||||
if field_mask is None:
|
||||
select_mask.setdefault(field, {})
|
||||
elif field_mask:
|
||||
if not field.is_relation:
|
||||
raise FieldError(next(iter(field_mask)))
|
||||
field_select_mask = select_mask.setdefault(field, {})
|
||||
related_model = field.remote_field.model._meta.concrete_model
|
||||
self._get_defer_select_mask(
|
||||
related_model._meta, field_mask, field_select_mask
|
||||
)
|
||||
# Remaining defer entries must be references to reverse relationships.
|
||||
# The following code is expected to raise FieldError if it encounters
|
||||
# a malformed defer entry.
|
||||
for field_name, field_mask in mask.items():
|
||||
if filtered_relation := self._filtered_relations.get(field_name):
|
||||
relation = opts.get_field(filtered_relation.relation_name)
|
||||
field_select_mask = select_mask.setdefault((field_name, relation), {})
|
||||
field = relation.field
|
||||
else:
|
||||
field = opts.get_field(field_name).field
|
||||
field_select_mask = select_mask.setdefault(field, {})
|
||||
related_model = field.model._meta.concrete_model
|
||||
self._get_defer_select_mask(
|
||||
related_model._meta, field_mask, field_select_mask
|
||||
)
|
||||
return select_mask
|
||||
|
||||
def _get_only_select_mask(self, opts, mask, select_mask=None):
|
||||
if select_mask is None:
|
||||
select_mask = {}
|
||||
select_mask[opts.pk] = {}
|
||||
# Only include fields mentioned in the mask.
|
||||
for field_name, field_mask in mask.items():
|
||||
field = opts.get_field(field_name)
|
||||
field_select_mask = select_mask.setdefault(field, {})
|
||||
if field_mask:
|
||||
if not field.is_relation:
|
||||
raise FieldError(next(iter(field_mask)))
|
||||
related_model = field.remote_field.model._meta.concrete_model
|
||||
self._get_only_select_mask(
|
||||
related_model._meta, field_mask, field_select_mask
|
||||
)
|
||||
return select_mask
|
||||
|
||||
def get_select_mask(self):
|
||||
"""
|
||||
Convert the self.deferred_loading data structure to an alternate data
|
||||
structure, describing the field that *will* be loaded. This is used to
|
||||
|
@ -726,81 +780,19 @@ class Query(BaseExpression):
|
|||
QuerySet class to work out which fields are being initialized on each
|
||||
model. Models that have all their fields included aren't mentioned in
|
||||
the result, only those that have field restrictions in place.
|
||||
|
||||
The "target" parameter is the instance that is populated (in place).
|
||||
"""
|
||||
field_names, defer = self.deferred_loading
|
||||
if not field_names:
|
||||
return
|
||||
orig_opts = self.get_meta()
|
||||
seen = {}
|
||||
must_include = {orig_opts.concrete_model: {orig_opts.pk}}
|
||||
return {}
|
||||
mask = {}
|
||||
for field_name in field_names:
|
||||
parts = field_name.split(LOOKUP_SEP)
|
||||
cur_model = self.model._meta.concrete_model
|
||||
opts = orig_opts
|
||||
for name in parts[:-1]:
|
||||
old_model = cur_model
|
||||
if name in self._filtered_relations:
|
||||
name = self._filtered_relations[name].relation_name
|
||||
source = opts.get_field(name)
|
||||
if is_reverse_o2o(source):
|
||||
cur_model = source.related_model
|
||||
else:
|
||||
cur_model = source.remote_field.model
|
||||
cur_model = cur_model._meta.concrete_model
|
||||
opts = cur_model._meta
|
||||
# Even if we're "just passing through" this model, we must add
|
||||
# both the current model's pk and the related reference field
|
||||
# (if it's not a reverse relation) to the things we select.
|
||||
if not is_reverse_o2o(source):
|
||||
must_include[old_model].add(source)
|
||||
add_to_dict(must_include, cur_model, opts.pk)
|
||||
field = opts.get_field(parts[-1])
|
||||
is_reverse_object = field.auto_created and not field.concrete
|
||||
model = field.related_model if is_reverse_object else field.model
|
||||
model = model._meta.concrete_model
|
||||
if model == opts.model:
|
||||
model = cur_model
|
||||
if not is_reverse_o2o(field):
|
||||
add_to_dict(seen, model, field)
|
||||
|
||||
part_mask = mask
|
||||
for part in field_name.split(LOOKUP_SEP):
|
||||
part_mask = part_mask.setdefault(part, {})
|
||||
opts = self.get_meta()
|
||||
if defer:
|
||||
# We need to load all fields for each model, except those that
|
||||
# appear in "seen" (for all models that appear in "seen"). The only
|
||||
# slight complexity here is handling fields that exist on parent
|
||||
# models.
|
||||
workset = {}
|
||||
for model, values in seen.items():
|
||||
for field in model._meta.local_fields:
|
||||
if field not in values:
|
||||
m = field.model._meta.concrete_model
|
||||
add_to_dict(workset, m, field)
|
||||
for model, values in must_include.items():
|
||||
# If we haven't included a model in workset, we don't add the
|
||||
# corresponding must_include fields for that model, since an
|
||||
# empty set means "include all fields". That's why there's no
|
||||
# "else" branch here.
|
||||
if model in workset:
|
||||
workset[model].update(values)
|
||||
for model, fields in workset.items():
|
||||
target[model] = {f.attname for f in fields}
|
||||
else:
|
||||
for model, values in must_include.items():
|
||||
if model in seen:
|
||||
seen[model].update(values)
|
||||
else:
|
||||
# As we've passed through this model, but not explicitly
|
||||
# included any fields, we have to make sure it's mentioned
|
||||
# so that only the "must include" fields are pulled in.
|
||||
seen[model] = values
|
||||
# Now ensure that every model in the inheritance chain is mentioned
|
||||
# in the parent list. Again, it must be mentioned to ensure that
|
||||
# only "must include" fields are pulled in.
|
||||
for model in orig_opts.get_parent_list():
|
||||
seen.setdefault(model, set())
|
||||
for model, fields in seen.items():
|
||||
target[model] = {f.attname for f in fields}
|
||||
return self._get_defer_select_mask(opts, mask)
|
||||
return self._get_only_select_mask(opts, mask)
|
||||
|
||||
def table_alias(self, table_name, create=False, filtered_relation=None):
|
||||
"""
|
||||
|
@ -2583,25 +2575,6 @@ def get_order_dir(field, default="ASC"):
|
|||
return field, dirn[0]
|
||||
|
||||
|
||||
def add_to_dict(data, key, value):
|
||||
"""
|
||||
Add "value" to the set of values for "key", whether or not "key" already
|
||||
exists.
|
||||
"""
|
||||
if key in data:
|
||||
data[key].add(value)
|
||||
else:
|
||||
data[key] = {value}
|
||||
|
||||
|
||||
def is_reverse_o2o(field):
|
||||
"""
|
||||
Check if the given field is reverse-o2o. The field is expected to be some
|
||||
sort of relation field or related object.
|
||||
"""
|
||||
return field.is_relation and field.one_to_one and not field.concrete
|
||||
|
||||
|
||||
class JoinPromoter:
|
||||
"""
|
||||
A class to abstract away join promotion problems for complex filter
|
||||
|
|
|
@ -290,6 +290,8 @@ class InvalidDeferTests(SimpleTestCase):
|
|||
msg = "Primary has no field named 'missing'"
|
||||
with self.assertRaisesMessage(FieldDoesNotExist, msg):
|
||||
list(Primary.objects.defer("missing"))
|
||||
with self.assertRaisesMessage(FieldError, "missing"):
|
||||
list(Primary.objects.defer("value__missing"))
|
||||
msg = "Secondary has no field named 'missing'"
|
||||
with self.assertRaisesMessage(FieldDoesNotExist, msg):
|
||||
list(Primary.objects.defer("related__missing"))
|
||||
|
@ -298,6 +300,8 @@ class InvalidDeferTests(SimpleTestCase):
|
|||
msg = "Primary has no field named 'missing'"
|
||||
with self.assertRaisesMessage(FieldDoesNotExist, msg):
|
||||
list(Primary.objects.only("missing"))
|
||||
with self.assertRaisesMessage(FieldError, "missing"):
|
||||
list(Primary.objects.only("value__missing"))
|
||||
msg = "Secondary has no field named 'missing'"
|
||||
with self.assertRaisesMessage(FieldDoesNotExist, msg):
|
||||
list(Primary.objects.only("related__missing"))
|
||||
|
|
|
@ -246,8 +246,6 @@ class DeferRegressionTest(TestCase):
|
|||
)
|
||||
self.assertEqual(len(qs), 1)
|
||||
|
||||
|
||||
class DeferAnnotateSelectRelatedTest(TestCase):
|
||||
def test_defer_annotate_select_related(self):
|
||||
location = Location.objects.create()
|
||||
Request.objects.create(location=location)
|
||||
|
@ -276,6 +274,28 @@ class DeferAnnotateSelectRelatedTest(TestCase):
|
|||
list,
|
||||
)
|
||||
|
||||
def test_common_model_different_mask(self):
|
||||
child = Child.objects.create(name="Child", value=42)
|
||||
second_child = Child.objects.create(name="Second", value=64)
|
||||
Leaf.objects.create(child=child, second_child=second_child)
|
||||
with self.assertNumQueries(1):
|
||||
leaf = (
|
||||
Leaf.objects.select_related("child", "second_child")
|
||||
.defer("child__name", "second_child__value")
|
||||
.get()
|
||||
)
|
||||
self.assertEqual(leaf.child, child)
|
||||
self.assertEqual(leaf.second_child, second_child)
|
||||
self.assertEqual(leaf.child.get_deferred_fields(), {"name"})
|
||||
self.assertEqual(leaf.second_child.get_deferred_fields(), {"value"})
|
||||
with self.assertNumQueries(0):
|
||||
self.assertEqual(leaf.child.value, 42)
|
||||
self.assertEqual(leaf.second_child.name, "Second")
|
||||
with self.assertNumQueries(1):
|
||||
self.assertEqual(leaf.child.name, "Child")
|
||||
with self.assertNumQueries(1):
|
||||
self.assertEqual(leaf.second_child.value, 64)
|
||||
|
||||
|
||||
class DeferDeletionSignalsTests(TestCase):
|
||||
senders = [Item, Proxy]
|
||||
|
|
|
@ -3594,12 +3594,6 @@ class WhereNodeTest(SimpleTestCase):
|
|||
|
||||
|
||||
class QuerySetExceptionTests(SimpleTestCase):
|
||||
def test_iter_exceptions(self):
|
||||
qs = ExtraInfo.objects.only("author")
|
||||
msg = "'ManyToOneRel' object has no attribute 'attname'"
|
||||
with self.assertRaisesMessage(AttributeError, msg):
|
||||
list(qs)
|
||||
|
||||
def test_invalid_order_by(self):
|
||||
msg = "Cannot resolve keyword '*' into field. Choices are: created, id, name"
|
||||
with self.assertRaisesMessage(FieldError, msg):
|
||||
|
|
Loading…
Reference in New Issue