From 512ee0f52889eb7f624f309cdc61fab57ab73a7b Mon Sep 17 00:00:00 2001
From: Russell Keith-Magee <russell@keith-magee.com>
Date: Sat, 6 Jun 2009 06:14:05 +0000
Subject: [PATCH] Fixed #10572 -- Corrected the operation of the defer() and
 only() clauses when used on inherited models.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@10926 bcc190cf-cafb-0310-a4f2-bffc1f526a37
---
 django/db/models/query.py        | 20 ++++++-
 django/db/models/sql/query.py    | 10 +++-
 tests/modeltests/defer/models.py | 92 +++++++++++++++++++++++++++++++-
 3 files changed, 118 insertions(+), 4 deletions(-)

diff --git a/django/db/models/query.py b/django/db/models/query.py
index 0d35b0ba16..0f34fb8a5a 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -190,7 +190,25 @@ class QuerySet(object):
         index_start = len(extra_select)
         aggregate_start = index_start + len(self.model._meta.fields)
 
-        load_fields = only_load.get(self.model)
+        load_fields = []
+        # If only/defer clauses have been specified,
+        # build the list of fields that are to be loaded.
+        if only_load:
+            for field, model in self.model._meta.get_fields_with_model():
+                if model is None:
+                    model = self.model
+                if field == self.model._meta.pk:
+                    # Record the index of the primary key when it is found
+                    pk_idx = len(load_fields)
+                try:
+                    if field.name in only_load[model]:
+                        # Add a field that has been explicitly included
+                        load_fields.append(field.name)
+                except KeyError:
+                    # Model wasn't explicitly listed in the only_load table
+                    # Therefore, we need to load all fields from this model
+                    load_fields.append(field.name)
+
         skip = None
         if load_fields and not fill_cache:
             # Some fields have been deferred, so we have to initialise
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index d290d60e63..15b9fd6366 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -635,10 +635,10 @@ class BaseQuery(object):
             # models.
             workset = {}
             for model, values in seen.iteritems():
-                for field, f_model in model._meta.get_fields_with_model():
+                for field in model._meta.local_fields:
                     if field in values:
                         continue
-                    add_to_dict(workset, f_model or model, field)
+                    add_to_dict(workset, model, field)
             for model, values in must_include.iteritems():
                 # If we haven't included a model in workset, we don't add the
                 # corresponding must_include fields for that model, since an
@@ -657,6 +657,12 @@ class BaseQuery(object):
                     # included any fields, we have to make sure it's mentioned
                     # so that only the "must include" fields are pulled in.
                     seen[model] = values
+            # Now ensure that every model in the inheritance chain is mentioned
+            # in the parent list. Again, it must be mentioned to ensure that
+            # only "must include" fields are pulled in.
+            for model in orig_opts.get_parent_list():
+                if model not in seen:
+                    seen[model] = set()
             for model, values in seen.iteritems():
                 callback(target, model, values)
 
diff --git a/tests/modeltests/defer/models.py b/tests/modeltests/defer/models.py
index ce65065d40..96eb427811 100644
--- a/tests/modeltests/defer/models.py
+++ b/tests/modeltests/defer/models.py
@@ -17,6 +17,12 @@ class Primary(models.Model):
     def __unicode__(self):
         return self.name
 
+class Child(Primary):
+    pass
+
+class BigChild(Primary):
+    other = models.CharField(max_length=50)
+
 def count_delayed_fields(obj, debug=False):
     """
     Returns the number of delayed attributes on the given model instance.
@@ -33,7 +39,7 @@ def count_delayed_fields(obj, debug=False):
 
 __test__ = {"API_TEST": """
 To all outward appearances, instances with deferred fields look the same as
-normal instances when we examine attribut values. Therefore we test for the
+normal instances when we examine attribute values. Therefore we test for the
 number of deferred fields on returned instances (by poking at the internals),
 as a way to observe what is going on.
 
@@ -98,5 +104,89 @@ Using defer() and only() with get() is also valid.
 >>> Primary.objects.all()
 [<Primary: a new name>]
 
+# Regression for #10572 - A subclass with no extra fields can defer fields from the base class
+>>> _ = Child.objects.create(name="c1", value="foo", related=s1)
+
+# You can defer a field on a baseclass when the subclass has no fields
+>>> obj = Child.objects.defer("value").get(name="c1")
+>>> count_delayed_fields(obj)
+1
+>>> obj.name
+u"c1"
+>>> obj.value
+u"foo"
+>>> obj.name = "c2"
+>>> obj.save()
+
+# You can retrive a single column on a base class with no fields
+>>> obj = Child.objects.only("name").get(name="c2")
+>>> count_delayed_fields(obj)
+3
+>>> obj.name
+u"c2"
+>>> obj.value
+u"foo"
+>>> obj.name = "cc"
+>>> obj.save()
+
+>>> _ = BigChild.objects.create(name="b1", value="foo", related=s1, other="bar")
+
+# You can defer a field on a baseclass
+>>> obj = BigChild.objects.defer("value").get(name="b1")
+>>> count_delayed_fields(obj)
+1
+>>> obj.name
+u"b1"
+>>> obj.value
+u"foo"
+>>> obj.other
+u"bar"
+>>> obj.name = "b2"
+>>> obj.save()
+
+# You can defer a field on a subclass
+>>> obj = BigChild.objects.defer("other").get(name="b2")
+>>> count_delayed_fields(obj)
+1
+>>> obj.name
+u"b2"
+>>> obj.value
+u"foo"
+>>> obj.other
+u"bar"
+>>> obj.name = "b3"
+>>> obj.save()
+
+# You can retrieve a single field on a baseclass
+>>> obj = BigChild.objects.only("name").get(name="b3")
+>>> count_delayed_fields(obj)
+4
+>>> obj.name
+u"b3"
+>>> obj.value
+u"foo"
+>>> obj.other
+u"bar"
+>>> obj.name = "b4"
+>>> obj.save()
+
+# You can retrieve a single field on a baseclass
+>>> obj = BigChild.objects.only("other").get(name="b4")
+>>> count_delayed_fields(obj)
+4
+>>> obj.name
+u"b4"
+>>> obj.value
+u"foo"
+>>> obj.other
+u"bar"
+>>> obj.name = "bb"
+>>> obj.save()
+
+# Finally, we need to flush the app cache for the defer module.
+# Using only/defer creates some artifical entries in the app cache
+# that messes up later tests. Purge all entries, just to be sure.
+>>> from django.db.models.loading import cache
+>>> cache.app_models['defer'] = {}
 
 """}