Fixed #13003 -- Ensured that ._state.db is set correctly for select_related() queries. Thanks to Alex Gaynor for the report.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@12701 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
parent
3508a86ddf
commit
18983f0ee7
|
@ -267,7 +267,7 @@ class QuerySet(object):
|
||||||
for row in compiler.results_iter():
|
for row in compiler.results_iter():
|
||||||
if fill_cache:
|
if fill_cache:
|
||||||
obj, _ = get_cached_row(self.model, row,
|
obj, _ = get_cached_row(self.model, row,
|
||||||
index_start, max_depth,
|
index_start, using=self.db, max_depth=max_depth,
|
||||||
requested=requested, offset=len(aggregate_select),
|
requested=requested, offset=len(aggregate_select),
|
||||||
only_load=only_load)
|
only_load=only_load)
|
||||||
else:
|
else:
|
||||||
|
@ -279,6 +279,9 @@ class QuerySet(object):
|
||||||
# Omit aggregates in object creation.
|
# Omit aggregates in object creation.
|
||||||
obj = self.model(*row[index_start:aggregate_start])
|
obj = self.model(*row[index_start:aggregate_start])
|
||||||
|
|
||||||
|
# Store the source database of the object
|
||||||
|
obj._state.db = self.db
|
||||||
|
|
||||||
for i, k in enumerate(extra_select):
|
for i, k in enumerate(extra_select):
|
||||||
setattr(obj, k, row[i])
|
setattr(obj, k, row[i])
|
||||||
|
|
||||||
|
@ -286,9 +289,6 @@ class QuerySet(object):
|
||||||
for i, aggregate in enumerate(aggregate_select):
|
for i, aggregate in enumerate(aggregate_select):
|
||||||
setattr(obj, aggregate, row[i+aggregate_start])
|
setattr(obj, aggregate, row[i+aggregate_start])
|
||||||
|
|
||||||
# Store the source database of the object
|
|
||||||
obj._state.db = self.db
|
|
||||||
|
|
||||||
yield obj
|
yield obj
|
||||||
|
|
||||||
def aggregate(self, *args, **kwargs):
|
def aggregate(self, *args, **kwargs):
|
||||||
|
@ -1112,7 +1112,7 @@ class EmptyQuerySet(QuerySet):
|
||||||
value_annotation = False
|
value_annotation = False
|
||||||
|
|
||||||
|
|
||||||
def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
|
def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
|
||||||
requested=None, offset=0, only_load=None):
|
requested=None, offset=0, only_load=None):
|
||||||
"""
|
"""
|
||||||
Helper function that recursively returns an object with the specified
|
Helper function that recursively returns an object with the specified
|
||||||
|
@ -1126,6 +1126,7 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
|
||||||
* row - the row of data returned by the database cursor
|
* row - the row of data returned by the database cursor
|
||||||
* index_start - the index of the row at which data for this
|
* index_start - the index of the row at which data for this
|
||||||
object is known to start
|
object is known to start
|
||||||
|
* using - the database alias on which the query is being executed.
|
||||||
* max_depth - the maximum depth to which a select_related()
|
* max_depth - the maximum depth to which a select_related()
|
||||||
relationship should be explored.
|
relationship should be explored.
|
||||||
* cur_depth - the current depth in the select_related() tree.
|
* cur_depth - the current depth in the select_related() tree.
|
||||||
|
@ -1170,6 +1171,7 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
|
||||||
obj = klass(**dict(zip(init_list, fields)))
|
obj = klass(**dict(zip(init_list, fields)))
|
||||||
else:
|
else:
|
||||||
obj = klass(*fields)
|
obj = klass(*fields)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Load all fields on klass
|
# Load all fields on klass
|
||||||
field_count = len(klass._meta.fields)
|
field_count = len(klass._meta.fields)
|
||||||
|
@ -1182,6 +1184,10 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
|
||||||
else:
|
else:
|
||||||
obj = klass(*fields)
|
obj = klass(*fields)
|
||||||
|
|
||||||
|
# If an object was retrieved, set the database state.
|
||||||
|
if obj:
|
||||||
|
obj._state.db = using
|
||||||
|
|
||||||
index_end = index_start + field_count + offset
|
index_end = index_start + field_count + offset
|
||||||
# Iterate over each related object, populating any
|
# Iterate over each related object, populating any
|
||||||
# select_related() fields
|
# select_related() fields
|
||||||
|
@ -1193,8 +1199,8 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
|
||||||
else:
|
else:
|
||||||
next = None
|
next = None
|
||||||
# Recursively retrieve the data for the related object
|
# Recursively retrieve the data for the related object
|
||||||
cached_row = get_cached_row(f.rel.to, row, index_end, max_depth,
|
cached_row = get_cached_row(f.rel.to, row, index_end, using,
|
||||||
cur_depth+1, next)
|
max_depth, cur_depth+1, next)
|
||||||
# If the recursive descent found an object, populate the
|
# If the recursive descent found an object, populate the
|
||||||
# descriptor caches relevant to the object
|
# descriptor caches relevant to the object
|
||||||
if cached_row:
|
if cached_row:
|
||||||
|
@ -1222,8 +1228,8 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
|
||||||
continue
|
continue
|
||||||
next = requested[f.related_query_name()]
|
next = requested[f.related_query_name()]
|
||||||
# Recursively retrieve the data for the related object
|
# Recursively retrieve the data for the related object
|
||||||
cached_row = get_cached_row(model, row, index_end, max_depth,
|
cached_row = get_cached_row(model, row, index_end, using,
|
||||||
cur_depth+1, next)
|
max_depth, cur_depth+1, next)
|
||||||
# If the recursive descent found an object, populate the
|
# If the recursive descent found an object, populate the
|
||||||
# descriptor caches relevant to the object
|
# descriptor caches relevant to the object
|
||||||
if cached_row:
|
if cached_row:
|
||||||
|
|
|
@ -641,6 +641,20 @@ class QueryTestCase(TestCase):
|
||||||
val = Book.objects.raw('SELECT id FROM "multiple_database_book"').using('other')
|
val = Book.objects.raw('SELECT id FROM "multiple_database_book"').using('other')
|
||||||
self.assertEqual(map(lambda o: o.pk, val), [dive.pk])
|
self.assertEqual(map(lambda o: o.pk, val), [dive.pk])
|
||||||
|
|
||||||
|
def test_select_related(self):
|
||||||
|
"Database assignment is retained if an object is retrieved with select_related()"
|
||||||
|
# Create a book and author on the other database
|
||||||
|
mark = Person.objects.using('other').create(name="Mark Pilgrim")
|
||||||
|
dive = Book.objects.using('other').create(title="Dive into Python",
|
||||||
|
published=datetime.date(2009, 5, 4),
|
||||||
|
editor=mark)
|
||||||
|
|
||||||
|
# Retrieve the Person using select_related()
|
||||||
|
book = Book.objects.using('other').select_related('editor').get(title="Dive into Python")
|
||||||
|
|
||||||
|
# The editor instance should have a db state
|
||||||
|
self.assertEqual(book.editor._state.db, 'other')
|
||||||
|
|
||||||
class TestRouter(object):
|
class TestRouter(object):
|
||||||
# A test router. The behaviour is vaguely master/slave, but the
|
# A test router. The behaviour is vaguely master/slave, but the
|
||||||
# databases aren't assumed to propagate changes.
|
# databases aren't assumed to propagate changes.
|
||||||
|
|
Loading…
Reference in New Issue