From e51fcf8ec3f4d3f68c4cb03f3dcf598c1a06cd4d Mon Sep 17 00:00:00 2001 From: Adrian Holovaty Date: Tue, 31 Jan 2006 02:21:31 +0000 Subject: [PATCH] magic-removal: Created InBulkQuerySet, ValuesQuerySet and DateQuerySet, subclasses of QuerySet that provide custom iterator(). This lets you use iterator() with in_bulk(), values() and dates(). Also added unit tests. git-svn-id: http://code.djangoproject.com/svn/django/branches/magic-removal@2200 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/manager.py | 3 + django/db/models/query.py | 115 ++++++++++++++++++------------ tests/modeltests/basic/models.py | 39 ++++++---- tests/modeltests/lookup/models.py | 17 ++++- 4 files changed, 113 insertions(+), 61 deletions(-) diff --git a/django/db/models/manager.py b/django/db/models/manager.py index 4b7bc3b288..b2ed971467 100644 --- a/django/db/models/manager.py +++ b/django/db/models/manager.py @@ -85,6 +85,9 @@ class Manager(object): def in_bulk(self, *args, **kwargs): return QuerySet(self.model).in_bulk(*args, **kwargs) + def iterator(self, *args, **kwargs): + return QuerySet(self.model).iterator(*args, **kwargs) + def order_by(self, *args, **kwargs): return QuerySet(self.model).order_by(*args, **kwargs) diff --git a/django/db/models/query.py b/django/db/models/query.py index 5636d92975..18493dac34 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -190,67 +190,33 @@ class QuerySet(object): _, sql, params = del_query._get_sql_clause(False) cursor.execute("DELETE " + sql, params) + ################################################## + # PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS # + ################################################## + def in_bulk(self, id_list): assert isinstance(id_list, list), "in_bulk() must be provided with a list of IDs." assert id_list != [], "in_bulk() cannot be passed an empty ID list." - bulk_query = self._clone() - bulk_query._where.append("%s.%s IN (%s)" % (backend.quote_name(self.model._meta.db_table), backend.quote_name(self.model._meta.pk.column), ",".join(['%s'] * len(id_list)))) - bulk_query._params.extend(id_list) - return dict([(obj._get_pk_val(), obj) for obj in bulk_query.iterator()]) + return self._clone(klass=InBulkQuerySet, _id_list=id_list) def values(self, *fields): - # select_related and select aren't supported in values(). - values_query = self._clone(_select_related=False, _select={}) - - # 'fields' is a list of field names to fetch. - if fields: - columns = [self.model._meta.get_field(f, many_to_many=False).column for f in fields] - else: # Default to all fields. - columns = [f.column for f in self.model._meta.fields] - - cursor = connection.cursor() - select, sql, params = values_query._get_sql_clause(True) - select = ['%s.%s' % (backend.quote_name(self.model._meta.db_table), backend.quote_name(c)) for c in columns] - cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params) - while 1: - rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) - if not rows: - raise StopIteration - for row in rows: - yield dict(zip(columns, row)) + return self._clone(klass=ValuesQuerySet, _fields=fields) def dates(self, field_name, kind, order='ASC'): """ Returns a list of datetime objects representing all available dates for the given field_name, scoped to 'kind'. """ - from django.db.backends.util import typecast_timestamp - assert kind in ("month", "year", "day"), "'kind' must be one of 'year', 'month' or 'day'." assert order in ('ASC', 'DESC'), "'order' must be either 'ASC' or 'DESC'." # Let the FieldDoesNotExist exception propogate. field = self.model._meta.get_field(field_name, many_to_many=False) assert isinstance(field, DateField), "%r isn't a DateField." % field_name + return self._clone(klass=DateQuerySet, _field=field, _kind=kind, _order=order) - date_query = self._clone() - date_query._order_by = () # Clear this because it'll mess things up otherwise. - if field.null: - date_query._where.append('%s.%s IS NOT NULL' % \ - (backend.quote_name(self.model._meta.db_table), backend.quote_name(field.column))) - select, sql, params = date_query._get_sql_clause(True) - sql = 'SELECT %s %s GROUP BY 1 ORDER BY 1 %s' % \ - (backend.get_date_trunc_sql(kind, '%s.%s' % (backend.quote_name(self.model._meta.db_table), - backend.quote_name(field.column))), sql, order) - cursor = connection.cursor() - cursor.execute(sql, params) - # We have to manually run typecast_timestamp(str()) on the results, because - # MySQL doesn't automatically cast the result of date functions as datetime - # objects -- MySQL returns the values as strings, instead. - return [typecast_timestamp(str(row[0])) for row in cursor.fetchall()] - - ############################################# - # PUBLIC METHODS THAT RETURN A NEW QUERYSET # - ############################################# + ################################################################## + # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET # + ################################################################## def filter(self, *args, **kwargs): "Returns a new QuerySet instance with the args ANDed to the existing set." @@ -285,8 +251,10 @@ class QuerySet(object): # PRIVATE METHODS # ################### - def _clone(self, **kwargs): - c = QuerySet() + def _clone(self, klass=None, **kwargs): + if klass is None: + klass = self.__class__ + c = klass() c.model = self.model c._filters = self._filters c._order_by = self._order_by @@ -402,6 +370,61 @@ class QuerySet(object): return select, " ".join(sql), params +class InBulkQuerySet(QuerySet): + def iterator(self): + self._where.append("%s.%s IN (%s)" % (backend.quote_name(self.model._meta.db_table), backend.quote_name(self.model._meta.pk.column), ",".join(['%s'] * len(self._id_list)))) + self._params.extend(self._id_list) + yield dict([(obj._get_pk_val(), obj) for obj in QuerySet.iterator(self)]) + + def _get_data(self): + if self._result_cache is None: + for i in self.iterator(): + self._result_cache = i + return self._result_cache + +class ValuesQuerySet(QuerySet): + def iterator(self): + # select_related and select aren't supported in values(). + self._select_related = False + self._select = {} + + # self._fields is a list of field names to fetch. + if self._fields: + columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields] + field_names = [f.attname for f in self._fields] + else: # Default to all fields. + columns = [f.column for f in self.model._meta.fields] + field_names = [f.attname for f in self.model._meta.fields] + + cursor = connection.cursor() + select, sql, params = self._get_sql_clause(True) + select = ['%s.%s' % (backend.quote_name(self.model._meta.db_table), backend.quote_name(c)) for c in columns] + cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params) + while 1: + rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) + if not rows: + raise StopIteration + for row in rows: + yield dict(zip(field_names, row)) + +class DateQuerySet(QuerySet): + def iterator(self): + from django.db.backends.util import typecast_timestamp + self._order_by = () # Clear this because it'll mess things up otherwise. + if self._field.null: + date_query._where.append('%s.%s IS NOT NULL' % \ + (backend.quote_name(self.model._meta.db_table), backend.quote_name(self._field.column))) + select, sql, params = self._get_sql_clause(True) + sql = 'SELECT %s %s GROUP BY 1 ORDER BY 1 %s' % \ + (backend.get_date_trunc_sql(self._kind, '%s.%s' % (backend.quote_name(self.model._meta.db_table), + backend.quote_name(self._field.column))), sql, self._order) + cursor = connection.cursor() + cursor.execute(sql, params) + # We have to manually run typecast_timestamp(str()) on the results, because + # MySQL doesn't automatically cast the result of date functions as datetime + # objects -- MySQL returns the values as strings, instead. + return [typecast_timestamp(str(row[0])) for row in cursor.fetchall()] + class QOperator: "Base class for QAnd and QOr" def __init__(self, *args): diff --git a/tests/modeltests/basic/models.py b/tests/modeltests/basic/models.py index 3e01fc22a4..79cd5fd52f 100644 --- a/tests/modeltests/basic/models.py +++ b/tests/modeltests/basic/models.py @@ -37,14 +37,13 @@ datetime.datetime(2005, 7, 28, 0, 0) >>> a.headline = 'Area woman programs in Python' >>> a.save() -# Listing objects displays all the articles in the database. Note that the article -# is represented by "
", because we haven't given the Article -# model a __repr__() method. +# Article.objects.all() returns all the articles in the database. Note that +# the article is represented by "
", because we haven't given +# the Article model a __repr__() method. >>> Article.objects.all() [
] -# Django provides a rich database lookup API that's entirely driven by -# keyword arguments. +# Django provides a rich database lookup API. >>> Article.objects.get(id__exact=1)
>>> Article.objects.get(headline__startswith='Area woman') @@ -56,7 +55,7 @@ datetime.datetime(2005, 7, 28, 0, 0) >>> Article.objects.get(pub_date__year=2005, pub_date__month=7, pub_date__day=28)
-# You can omit __exact if you want +# The "__exact" lookup type can be omitted, as a shortcut. >>> Article.objects.get(id=1)
>>> Article.objects.get(headline='Area woman programs in Python') @@ -69,7 +68,8 @@ datetime.datetime(2005, 7, 28, 0, 0) >>> Article.objects.filter(pub_date__year=2005, pub_date__month=7) [
] -# Django raises an ArticleDoesNotExist exception for get() +# Django raises an Article.DoesNotExist exception for get() if the parameters +# don't match any object. >>> Article.objects.get(id__exact=2) Traceback (most recent call last): ... @@ -82,7 +82,7 @@ DoesNotExist: Article does not exist for ... # Lookup by a primary key is the most common case, so Django provides a # shortcut for primary-key exact lookups. -# The following is identical to articles.get(id__exact=1). +# The following is identical to articles.get(id=1). >>> Article.objects.get(pk=1)
@@ -93,7 +93,7 @@ DoesNotExist: Article does not exist for ... True # You can initialize a model instance using positional arguments, which should -# match the field order as defined in the model... +# match the field order as defined in the model. >>> a2 = Article(None, 'Second article', datetime(2005, 7, 29)) >>> a2.save() >>> a2.id @@ -126,7 +126,8 @@ Traceback (most recent call last): ... TypeError: 'foo' is an invalid keyword argument for this function -# You can leave off the ID. +# You can leave off the value for an AutoField when creating an object, because +# it'll get filled in automatically when you save(). >>> a5 = Article(headline='Article 6', pub_date=datetime(2005, 7, 31)) >>> a5.save() >>> a5.id @@ -154,7 +155,7 @@ datetime.datetime(2005, 7, 31, 12, 30, 45) >>> a8.id 8L -# Saving an object again shouldn't create a new object -- it just saves the old one. +# Saving an object again doesn't create a new object -- it just saves the old one. >>> a8.save() >>> a8.id 8L @@ -174,6 +175,7 @@ True >>> Article.objects.get(id__exact=8) == Article.objects.get(id__exact=7) False +# dates() returns a list of available dates of the given scope for the given field. >>> Article.objects.dates('pub_date', 'year') [datetime.datetime(2005, 1, 1, 0, 0)] >>> Article.objects.dates('pub_date', 'month') @@ -185,7 +187,7 @@ False >>> Article.objects.dates('pub_date', 'day', order='DESC') [datetime.datetime(2005, 7, 31, 0, 0), datetime.datetime(2005, 7, 30, 0, 0), datetime.datetime(2005, 7, 29, 0, 0), datetime.datetime(2005, 7, 28, 0, 0)] -# Try some bad arguments to dates(). +# dates() requires valid arguments. >>> Article.objects.dates() Traceback (most recent call last): @@ -207,7 +209,16 @@ Traceback (most recent call last): ... AssertionError: 'order' must be either 'ASC' or 'DESC'. -# You can combine queries with & and | +# Use iterator() with dates() to return a generator that lazily requests each +# result one at a time, to save memory. +>>> for a in Article.objects.dates('pub_date', 'day', order='DESC').iterator(): +... print repr(a) +datetime.datetime(2005, 7, 31, 0, 0) +datetime.datetime(2005, 7, 30, 0, 0) +datetime.datetime(2005, 7, 29, 0, 0) +datetime.datetime(2005, 7, 28, 0, 0) + +# You can combine queries with & and |. >>> s1 = Article.objects.filter(id__exact=1) >>> s2 = Article.objects.filter(id__exact=2) >>> tmp = [a.id for a in list(s1 | s2)] @@ -231,7 +242,7 @@ AssertionError: 'order' must be either 'ASC' or 'DESC'. [
,
] # An Article instance doesn't have access to the "objects" attribute. -# That is only available as a class method. +# That's only available on the class. >>> a7.objects.all() Traceback (most recent call last): ... diff --git a/tests/modeltests/lookup/models.py b/tests/modeltests/lookup/models.py index 6364ffe6fd..cc089493d5 100644 --- a/tests/modeltests/lookup/models.py +++ b/tests/modeltests/lookup/models.py @@ -33,7 +33,8 @@ API_TESTS = """ >>> a7 = Article(headline='Article 7', pub_date=datetime(2005, 7, 27)) >>> a7.save() -# iterator() is a generator. +# Each QuerySet gets iterator(), which is a generator that "lazily" returns +# results using database-level iteration. >>> for a in Article.objects.iterator(): ... print a.headline Article 5 @@ -103,6 +104,20 @@ True [('headline', 'Article 7'), ('id', 7)] [('headline', 'Article 1'), ('id', 1)] +# You can use values() with iterator() for memory savings, because iterator() +# uses database-level iteration. +>>> for d in Article.objects.values('id', 'headline').iterator(): +... i = d.items() +... i.sort() +... i +[('headline', 'Article 5'), ('id', 5)] +[('headline', 'Article 6'), ('id', 6)] +[('headline', 'Article 4'), ('id', 4)] +[('headline', 'Article 2'), ('id', 2)] +[('headline', 'Article 3'), ('id', 3)] +[('headline', 'Article 7'), ('id', 7)] +[('headline', 'Article 1'), ('id', 1)] + # if you don't specify which fields, all are returned >>> list(Article.objects.filter(id=5).values()) == [{'id': 5, 'headline': 'Article 5', 'pub_date': datetime(2005, 8, 1, 9, 0)}] True