From 18983f0ee73c9b3708b10b90e0a37bd17a8a1729 Mon Sep 17 00:00:00 2001 From: Russell Keith-Magee Date: Sun, 7 Mar 2010 07:13:55 +0000 Subject: [PATCH] Fixed #13003 -- Ensured that ._state.db is set correctly for select_related() queries. Thanks to Alex Gaynor for the report. git-svn-id: http://code.djangoproject.com/svn/django/trunk@12701 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/query.py | 24 ++++++++++++------- .../multiple_database/tests.py | 14 +++++++++++ 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index ab6a14e48a..fea1144200 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -267,7 +267,7 @@ class QuerySet(object): for row in compiler.results_iter(): if fill_cache: obj, _ = get_cached_row(self.model, row, - index_start, max_depth, + index_start, using=self.db, max_depth=max_depth, requested=requested, offset=len(aggregate_select), only_load=only_load) else: @@ -279,6 +279,9 @@ class QuerySet(object): # Omit aggregates in object creation. obj = self.model(*row[index_start:aggregate_start]) + # Store the source database of the object + obj._state.db = self.db + for i, k in enumerate(extra_select): setattr(obj, k, row[i]) @@ -286,9 +289,6 @@ class QuerySet(object): for i, aggregate in enumerate(aggregate_select): setattr(obj, aggregate, row[i+aggregate_start]) - # Store the source database of the object - obj._state.db = self.db - yield obj def aggregate(self, *args, **kwargs): @@ -1112,7 +1112,7 @@ class EmptyQuerySet(QuerySet): value_annotation = False -def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, +def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, requested=None, offset=0, only_load=None): """ Helper function that recursively returns an object with the specified @@ -1126,6 +1126,7 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, * row - the row of data returned by the database cursor * index_start - the index of the row at which data for this object is known to start + * using - the database alias on which the query is being executed. * max_depth - the maximum depth to which a select_related() relationship should be explored. * cur_depth - the current depth in the select_related() tree. @@ -1170,6 +1171,7 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, obj = klass(**dict(zip(init_list, fields))) else: obj = klass(*fields) + else: # Load all fields on klass field_count = len(klass._meta.fields) @@ -1182,6 +1184,10 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, else: obj = klass(*fields) + # If an object was retrieved, set the database state. + if obj: + obj._state.db = using + index_end = index_start + field_count + offset # Iterate over each related object, populating any # select_related() fields @@ -1193,8 +1199,8 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, else: next = None # Recursively retrieve the data for the related object - cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, - cur_depth+1, next) + cached_row = get_cached_row(f.rel.to, row, index_end, using, + max_depth, cur_depth+1, next) # If the recursive descent found an object, populate the # descriptor caches relevant to the object if cached_row: @@ -1222,8 +1228,8 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, continue next = requested[f.related_query_name()] # Recursively retrieve the data for the related object - cached_row = get_cached_row(model, row, index_end, max_depth, - cur_depth+1, next) + cached_row = get_cached_row(model, row, index_end, using, + max_depth, cur_depth+1, next) # If the recursive descent found an object, populate the # descriptor caches relevant to the object if cached_row: diff --git a/tests/regressiontests/multiple_database/tests.py b/tests/regressiontests/multiple_database/tests.py index 47ea010045..81b6a5ffb9 100644 --- a/tests/regressiontests/multiple_database/tests.py +++ b/tests/regressiontests/multiple_database/tests.py @@ -641,6 +641,20 @@ class QueryTestCase(TestCase): val = Book.objects.raw('SELECT id FROM "multiple_database_book"').using('other') self.assertEqual(map(lambda o: o.pk, val), [dive.pk]) + def test_select_related(self): + "Database assignment is retained if an object is retrieved with select_related()" + # Create a book and author on the other database + mark = Person.objects.using('other').create(name="Mark Pilgrim") + dive = Book.objects.using('other').create(title="Dive into Python", + published=datetime.date(2009, 5, 4), + editor=mark) + + # Retrieve the Person using select_related() + book = Book.objects.using('other').select_related('editor').get(title="Dive into Python") + + # The editor instance should have a db state + self.assertEqual(book.editor._state.db, 'other') + class TestRouter(object): # A test router. The behaviour is vaguely master/slave, but the # databases aren't assumed to propagate changes.