magic-removal: Created InBulkQuerySet, ValuesQuerySet and DateQuerySet, subclasses of QuerySet that provide custom iterator(). This lets you use iterator() with in_bulk(), values() and dates(). Also added unit tests.

git-svn-id: http://code.djangoproject.com/svn/django/branches/magic-removal@2200 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Adrian Holovaty 2006-01-31 02:21:31 +00:00
parent 2983bd8989
commit e51fcf8ec3
4 changed files with 113 additions and 61 deletions

View File

@ -85,6 +85,9 @@ class Manager(object):
def in_bulk(self, *args, **kwargs):
return QuerySet(self.model).in_bulk(*args, **kwargs)
def iterator(self, *args, **kwargs):
return QuerySet(self.model).iterator(*args, **kwargs)
def order_by(self, *args, **kwargs):
return QuerySet(self.model).order_by(*args, **kwargs)

View File

@ -190,67 +190,33 @@ class QuerySet(object):
_, sql, params = del_query._get_sql_clause(False)
cursor.execute("DELETE " + sql, params)
##################################################
# PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS #
##################################################
def in_bulk(self, id_list):
assert isinstance(id_list, list), "in_bulk() must be provided with a list of IDs."
assert id_list != [], "in_bulk() cannot be passed an empty ID list."
bulk_query = self._clone()
bulk_query._where.append("%s.%s IN (%s)" % (backend.quote_name(self.model._meta.db_table), backend.quote_name(self.model._meta.pk.column), ",".join(['%s'] * len(id_list))))
bulk_query._params.extend(id_list)
return dict([(obj._get_pk_val(), obj) for obj in bulk_query.iterator()])
return self._clone(klass=InBulkQuerySet, _id_list=id_list)
def values(self, *fields):
# select_related and select aren't supported in values().
values_query = self._clone(_select_related=False, _select={})
# 'fields' is a list of field names to fetch.
if fields:
columns = [self.model._meta.get_field(f, many_to_many=False).column for f in fields]
else: # Default to all fields.
columns = [f.column for f in self.model._meta.fields]
cursor = connection.cursor()
select, sql, params = values_query._get_sql_clause(True)
select = ['%s.%s' % (backend.quote_name(self.model._meta.db_table), backend.quote_name(c)) for c in columns]
cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params)
while 1:
rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
if not rows:
raise StopIteration
for row in rows:
yield dict(zip(columns, row))
return self._clone(klass=ValuesQuerySet, _fields=fields)
def dates(self, field_name, kind, order='ASC'):
"""
Returns a list of datetime objects representing all available dates
for the given field_name, scoped to 'kind'.
"""
from django.db.backends.util import typecast_timestamp
assert kind in ("month", "year", "day"), "'kind' must be one of 'year', 'month' or 'day'."
assert order in ('ASC', 'DESC'), "'order' must be either 'ASC' or 'DESC'."
# Let the FieldDoesNotExist exception propogate.
field = self.model._meta.get_field(field_name, many_to_many=False)
assert isinstance(field, DateField), "%r isn't a DateField." % field_name
return self._clone(klass=DateQuerySet, _field=field, _kind=kind, _order=order)
date_query = self._clone()
date_query._order_by = () # Clear this because it'll mess things up otherwise.
if field.null:
date_query._where.append('%s.%s IS NOT NULL' % \
(backend.quote_name(self.model._meta.db_table), backend.quote_name(field.column)))
select, sql, params = date_query._get_sql_clause(True)
sql = 'SELECT %s %s GROUP BY 1 ORDER BY 1 %s' % \
(backend.get_date_trunc_sql(kind, '%s.%s' % (backend.quote_name(self.model._meta.db_table),
backend.quote_name(field.column))), sql, order)
cursor = connection.cursor()
cursor.execute(sql, params)
# We have to manually run typecast_timestamp(str()) on the results, because
# MySQL doesn't automatically cast the result of date functions as datetime
# objects -- MySQL returns the values as strings, instead.
return [typecast_timestamp(str(row[0])) for row in cursor.fetchall()]
#############################################
# PUBLIC METHODS THAT RETURN A NEW QUERYSET #
#############################################
##################################################################
# PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
##################################################################
def filter(self, *args, **kwargs):
"Returns a new QuerySet instance with the args ANDed to the existing set."
@ -285,8 +251,10 @@ class QuerySet(object):
# PRIVATE METHODS #
###################
def _clone(self, **kwargs):
c = QuerySet()
def _clone(self, klass=None, **kwargs):
if klass is None:
klass = self.__class__
c = klass()
c.model = self.model
c._filters = self._filters
c._order_by = self._order_by
@ -402,6 +370,61 @@ class QuerySet(object):
return select, " ".join(sql), params
class InBulkQuerySet(QuerySet):
def iterator(self):
self._where.append("%s.%s IN (%s)" % (backend.quote_name(self.model._meta.db_table), backend.quote_name(self.model._meta.pk.column), ",".join(['%s'] * len(self._id_list))))
self._params.extend(self._id_list)
yield dict([(obj._get_pk_val(), obj) for obj in QuerySet.iterator(self)])
def _get_data(self):
if self._result_cache is None:
for i in self.iterator():
self._result_cache = i
return self._result_cache
class ValuesQuerySet(QuerySet):
def iterator(self):
# select_related and select aren't supported in values().
self._select_related = False
self._select = {}
# self._fields is a list of field names to fetch.
if self._fields:
columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields]
field_names = [f.attname for f in self._fields]
else: # Default to all fields.
columns = [f.column for f in self.model._meta.fields]
field_names = [f.attname for f in self.model._meta.fields]
cursor = connection.cursor()
select, sql, params = self._get_sql_clause(True)
select = ['%s.%s' % (backend.quote_name(self.model._meta.db_table), backend.quote_name(c)) for c in columns]
cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params)
while 1:
rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
if not rows:
raise StopIteration
for row in rows:
yield dict(zip(field_names, row))
class DateQuerySet(QuerySet):
def iterator(self):
from django.db.backends.util import typecast_timestamp
self._order_by = () # Clear this because it'll mess things up otherwise.
if self._field.null:
date_query._where.append('%s.%s IS NOT NULL' % \
(backend.quote_name(self.model._meta.db_table), backend.quote_name(self._field.column)))
select, sql, params = self._get_sql_clause(True)
sql = 'SELECT %s %s GROUP BY 1 ORDER BY 1 %s' % \
(backend.get_date_trunc_sql(self._kind, '%s.%s' % (backend.quote_name(self.model._meta.db_table),
backend.quote_name(self._field.column))), sql, self._order)
cursor = connection.cursor()
cursor.execute(sql, params)
# We have to manually run typecast_timestamp(str()) on the results, because
# MySQL doesn't automatically cast the result of date functions as datetime
# objects -- MySQL returns the values as strings, instead.
return [typecast_timestamp(str(row[0])) for row in cursor.fetchall()]
class QOperator:
"Base class for QAnd and QOr"
def __init__(self, *args):

View File

@ -37,14 +37,13 @@ datetime.datetime(2005, 7, 28, 0, 0)
>>> a.headline = 'Area woman programs in Python'
>>> a.save()
# Listing objects displays all the articles in the database. Note that the article
# is represented by "<Article object>", because we haven't given the Article
# model a __repr__() method.
# Article.objects.all() returns all the articles in the database. Note that
# the article is represented by "<Article object>", because we haven't given
# the Article model a __repr__() method.
>>> Article.objects.all()
[<Article object>]
# Django provides a rich database lookup API that's entirely driven by
# keyword arguments.
# Django provides a rich database lookup API.
>>> Article.objects.get(id__exact=1)
<Article object>
>>> Article.objects.get(headline__startswith='Area woman')
@ -56,7 +55,7 @@ datetime.datetime(2005, 7, 28, 0, 0)
>>> Article.objects.get(pub_date__year=2005, pub_date__month=7, pub_date__day=28)
<Article object>
# You can omit __exact if you want
# The "__exact" lookup type can be omitted, as a shortcut.
>>> Article.objects.get(id=1)
<Article object>
>>> Article.objects.get(headline='Area woman programs in Python')
@ -69,7 +68,8 @@ datetime.datetime(2005, 7, 28, 0, 0)
>>> Article.objects.filter(pub_date__year=2005, pub_date__month=7)
[<Article object>]
# Django raises an ArticleDoesNotExist exception for get()
# Django raises an Article.DoesNotExist exception for get() if the parameters
# don't match any object.
>>> Article.objects.get(id__exact=2)
Traceback (most recent call last):
...
@ -82,7 +82,7 @@ DoesNotExist: Article does not exist for ...
# Lookup by a primary key is the most common case, so Django provides a
# shortcut for primary-key exact lookups.
# The following is identical to articles.get(id__exact=1).
# The following is identical to articles.get(id=1).
>>> Article.objects.get(pk=1)
<Article object>
@ -93,7 +93,7 @@ DoesNotExist: Article does not exist for ...
True
# You can initialize a model instance using positional arguments, which should
# match the field order as defined in the model...
# match the field order as defined in the model.
>>> a2 = Article(None, 'Second article', datetime(2005, 7, 29))
>>> a2.save()
>>> a2.id
@ -126,7 +126,8 @@ Traceback (most recent call last):
...
TypeError: 'foo' is an invalid keyword argument for this function
# You can leave off the ID.
# You can leave off the value for an AutoField when creating an object, because
# it'll get filled in automatically when you save().
>>> a5 = Article(headline='Article 6', pub_date=datetime(2005, 7, 31))
>>> a5.save()
>>> a5.id
@ -154,7 +155,7 @@ datetime.datetime(2005, 7, 31, 12, 30, 45)
>>> a8.id
8L
# Saving an object again shouldn't create a new object -- it just saves the old one.
# Saving an object again doesn't create a new object -- it just saves the old one.
>>> a8.save()
>>> a8.id
8L
@ -174,6 +175,7 @@ True
>>> Article.objects.get(id__exact=8) == Article.objects.get(id__exact=7)
False
# dates() returns a list of available dates of the given scope for the given field.
>>> Article.objects.dates('pub_date', 'year')
[datetime.datetime(2005, 1, 1, 0, 0)]
>>> Article.objects.dates('pub_date', 'month')
@ -185,7 +187,7 @@ False
>>> Article.objects.dates('pub_date', 'day', order='DESC')
[datetime.datetime(2005, 7, 31, 0, 0), datetime.datetime(2005, 7, 30, 0, 0), datetime.datetime(2005, 7, 29, 0, 0), datetime.datetime(2005, 7, 28, 0, 0)]
# Try some bad arguments to dates().
# dates() requires valid arguments.
>>> Article.objects.dates()
Traceback (most recent call last):
@ -207,7 +209,16 @@ Traceback (most recent call last):
...
AssertionError: 'order' must be either 'ASC' or 'DESC'.
# You can combine queries with & and |
# Use iterator() with dates() to return a generator that lazily requests each
# result one at a time, to save memory.
>>> for a in Article.objects.dates('pub_date', 'day', order='DESC').iterator():
... print repr(a)
datetime.datetime(2005, 7, 31, 0, 0)
datetime.datetime(2005, 7, 30, 0, 0)
datetime.datetime(2005, 7, 29, 0, 0)
datetime.datetime(2005, 7, 28, 0, 0)
# You can combine queries with & and |.
>>> s1 = Article.objects.filter(id__exact=1)
>>> s2 = Article.objects.filter(id__exact=2)
>>> tmp = [a.id for a in list(s1 | s2)]
@ -231,7 +242,7 @@ AssertionError: 'order' must be either 'ASC' or 'DESC'.
[<Article object>, <Article object>]
# An Article instance doesn't have access to the "objects" attribute.
# That is only available as a class method.
# That's only available on the class.
>>> a7.objects.all()
Traceback (most recent call last):
...

View File

@ -33,7 +33,8 @@ API_TESTS = """
>>> a7 = Article(headline='Article 7', pub_date=datetime(2005, 7, 27))
>>> a7.save()
# iterator() is a generator.
# Each QuerySet gets iterator(), which is a generator that "lazily" returns
# results using database-level iteration.
>>> for a in Article.objects.iterator():
... print a.headline
Article 5
@ -103,6 +104,20 @@ True
[('headline', 'Article 7'), ('id', 7)]
[('headline', 'Article 1'), ('id', 1)]
# You can use values() with iterator() for memory savings, because iterator()
# uses database-level iteration.
>>> for d in Article.objects.values('id', 'headline').iterator():
... i = d.items()
... i.sort()
... i
[('headline', 'Article 5'), ('id', 5)]
[('headline', 'Article 6'), ('id', 6)]
[('headline', 'Article 4'), ('id', 4)]
[('headline', 'Article 2'), ('id', 2)]
[('headline', 'Article 3'), ('id', 3)]
[('headline', 'Article 7'), ('id', 7)]
[('headline', 'Article 1'), ('id', 1)]
# if you don't specify which fields, all are returned
>>> list(Article.objects.filter(id=5).values()) == [{'id': 5, 'headline': 'Article 5', 'pub_date': datetime(2005, 8, 1, 9, 0)}]
True