A large number of stylistic cleanups across django/db/

This commit is contained in:
Alex Gaynor 2013-07-08 10:39:54 +10:00
parent 0b69a75502
commit 03d9566e0d
48 changed files with 383 additions and 195 deletions

View File

@ -1,20 +1,25 @@
import warnings import warnings
from django.core import signals from django.core import signals
from django.db.utils import (DEFAULT_DB_ALIAS, from django.db.utils import (DEFAULT_DB_ALIAS, DataError, OperationalError,
DataError, OperationalError, IntegrityError, InternalError, IntegrityError, InternalError, ProgrammingError, NotSupportedError,
ProgrammingError, NotSupportedError, DatabaseError, DatabaseError, InterfaceError, Error, load_backend,
InterfaceError, Error, ConnectionHandler, ConnectionRouter)
load_backend, ConnectionHandler, ConnectionRouter)
from django.utils.functional import cached_property from django.utils.functional import cached_property
__all__ = ('backend', 'connection', 'connections', 'router', 'DatabaseError',
'IntegrityError', 'DEFAULT_DB_ALIAS') __all__ = [
'backend', 'connection', 'connections', 'router', 'DatabaseError',
'IntegrityError', 'InternalError', 'ProgrammingError', 'DataError',
'NotSupportedError', 'Error', 'InterfaceError', 'OperationalError',
'DEFAULT_DB_ALIAS'
]
connections = ConnectionHandler() connections = ConnectionHandler()
router = ConnectionRouter() router = ConnectionRouter()
# `connection`, `DatabaseError` and `IntegrityError` are convenient aliases # `connection`, `DatabaseError` and `IntegrityError` are convenient aliases
# for backend bits. # for backend bits.
@ -70,6 +75,7 @@ class DefaultBackendProxy(object):
backend = DefaultBackendProxy() backend = DefaultBackendProxy()
def close_connection(**kwargs): def close_connection(**kwargs):
warnings.warn( warnings.warn(
"close_connection is superseded by close_old_connections.", "close_connection is superseded by close_old_connections.",
@ -83,12 +89,14 @@ def close_connection(**kwargs):
transaction.abort(conn) transaction.abort(conn)
connections[conn].close() connections[conn].close()
# Register an event to reset saved queries when a Django request is started. # Register an event to reset saved queries when a Django request is started.
def reset_queries(**kwargs): def reset_queries(**kwargs):
for conn in connections.all(): for conn in connections.all():
conn.queries = [] conn.queries = []
signals.request_started.connect(reset_queries) signals.request_started.connect(reset_queries)
# Register an event to reset transaction state and close connections past # Register an event to reset transaction state and close connections past
# their lifetime. NB: abort() doesn't do anything outside of a transaction. # their lifetime. NB: abort() doesn't do anything outside of a transaction.
def close_old_connections(**kwargs): def close_old_connections(**kwargs):

View File

@ -1167,6 +1167,7 @@ FieldInfo = namedtuple('FieldInfo',
'name type_code display_size internal_size precision scale null_ok' 'name type_code display_size internal_size precision scale null_ok'
) )
class BaseDatabaseIntrospection(object): class BaseDatabaseIntrospection(object):
""" """
This class encapsulates all backend-specific introspection utilities This class encapsulates all backend-specific introspection utilities

View File

@ -251,12 +251,13 @@ class BaseDatabaseCreation(object):
r_col = model._meta.get_field(f.rel.field_name).column r_col = model._meta.get_field(f.rel.field_name).column
r_name = '%s_refs_%s_%s' % ( r_name = '%s_refs_%s_%s' % (
col, r_col, self._digest(table, r_table)) col, r_col, self._digest(table, r_table))
output.append('%s %s %s %s;' % \ output.append('%s %s %s %s;' % (
(style.SQL_KEYWORD('ALTER TABLE'), style.SQL_KEYWORD('ALTER TABLE'),
style.SQL_TABLE(qn(table)), style.SQL_TABLE(qn(table)),
style.SQL_KEYWORD(self.connection.ops.drop_foreignkey_sql()), style.SQL_KEYWORD(self.connection.ops.drop_foreignkey_sql()),
style.SQL_FIELD(qn(truncate_name( style.SQL_FIELD(qn(truncate_name(
r_name, self.connection.ops.max_name_length()))))) r_name, self.connection.ops.max_name_length())))
))
del references_to_delete[model] del references_to_delete[model]
return output return output

View File

@ -8,33 +8,43 @@ ImproperlyConfigured.
""" """
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db.backends import * from django.db.backends import (BaseDatabaseOperations, BaseDatabaseClient,
BaseDatabaseIntrospection, BaseDatabaseWrapper, BaseDatabaseFeatures,
BaseDatabaseValidation)
from django.db.backends.creation import BaseDatabaseCreation from django.db.backends.creation import BaseDatabaseCreation
def complain(*args, **kwargs): def complain(*args, **kwargs):
raise ImproperlyConfigured("settings.DATABASES is improperly configured. " raise ImproperlyConfigured("settings.DATABASES is improperly configured. "
"Please supply the ENGINE value. Check " "Please supply the ENGINE value. Check "
"settings documentation for more details.") "settings documentation for more details.")
def ignore(*args, **kwargs): def ignore(*args, **kwargs):
pass pass
class DatabaseError(Exception): class DatabaseError(Exception):
pass pass
class IntegrityError(DatabaseError): class IntegrityError(DatabaseError):
pass pass
class DatabaseOperations(BaseDatabaseOperations): class DatabaseOperations(BaseDatabaseOperations):
quote_name = complain quote_name = complain
class DatabaseClient(BaseDatabaseClient): class DatabaseClient(BaseDatabaseClient):
runshell = complain runshell = complain
class DatabaseCreation(BaseDatabaseCreation): class DatabaseCreation(BaseDatabaseCreation):
create_test_db = ignore create_test_db = ignore
destroy_test_db = ignore destroy_test_db = ignore
class DatabaseIntrospection(BaseDatabaseIntrospection): class DatabaseIntrospection(BaseDatabaseIntrospection):
get_table_list = complain get_table_list = complain
get_table_description = complain get_table_description = complain
@ -42,6 +52,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
get_indexes = complain get_indexes = complain
get_key_columns = complain get_key_columns = complain
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):
operators = {} operators = {}
# Override the base class implementations with null # Override the base class implementations with null

View File

@ -36,8 +36,9 @@ except ImportError:
pytz = None pytz = None
from django.conf import settings from django.conf import settings
from django.db import utils from django.db import (utils, BaseDatabaseFeatures, BaseDatabaseOperations,
from django.db.backends import * BaseDatabaseWrapper)
from django.db.backends import util
from django.db.backends.mysql.client import DatabaseClient from django.db.backends.mysql.client import DatabaseClient
from django.db.backends.mysql.creation import DatabaseCreation from django.db.backends.mysql.creation import DatabaseCreation
from django.db.backends.mysql.introspection import DatabaseIntrospection from django.db.backends.mysql.introspection import DatabaseIntrospection
@ -57,6 +58,7 @@ IntegrityError = Database.IntegrityError
# It's impossible to import datetime_or_None directly from MySQLdb.times # It's impossible to import datetime_or_None directly from MySQLdb.times
parse_datetime = conversions[FIELD_TYPE.DATETIME] parse_datetime = conversions[FIELD_TYPE.DATETIME]
def parse_datetime_with_timezone_support(value): def parse_datetime_with_timezone_support(value):
dt = parse_datetime(value) dt = parse_datetime(value)
# Confirm that dt is naive before overwriting its tzinfo. # Confirm that dt is naive before overwriting its tzinfo.
@ -64,6 +66,7 @@ def parse_datetime_with_timezone_support(value):
dt = dt.replace(tzinfo=timezone.utc) dt = dt.replace(tzinfo=timezone.utc)
return dt return dt
def adapt_datetime_with_timezone_support(value, conv): def adapt_datetime_with_timezone_support(value, conv):
# Equivalent to DateTimeField.get_db_prep_value. Used only by raw SQL. # Equivalent to DateTimeField.get_db_prep_value. Used only by raw SQL.
if settings.USE_TZ: if settings.USE_TZ:
@ -98,6 +101,7 @@ django_conversions.update({
# http://dev.mysql.com/doc/refman/5.0/en/news.html . # http://dev.mysql.com/doc/refman/5.0/en/news.html .
server_version_re = re.compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})') server_version_re = re.compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})')
# MySQLdb-1.2.1 and newer automatically makes use of SHOW WARNINGS on # MySQLdb-1.2.1 and newer automatically makes use of SHOW WARNINGS on
# MySQL-4.1 and newer, so the MysqlDebugWrapper is unnecessary. Since the # MySQL-4.1 and newer, so the MysqlDebugWrapper is unnecessary. Since the
# point is to raise Warnings as exceptions, this can be done with the Python # point is to raise Warnings as exceptions, this can be done with the Python
@ -148,6 +152,7 @@ class CursorWrapper(object):
def __iter__(self): def __iter__(self):
return iter(self.cursor) return iter(self.cursor)
class DatabaseFeatures(BaseDatabaseFeatures): class DatabaseFeatures(BaseDatabaseFeatures):
empty_fetchmany_value = () empty_fetchmany_value = ()
update_can_self_select = False update_can_self_select = False
@ -204,6 +209,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1") cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1")
return cursor.fetchone() is not None return cursor.fetchone() is not None
class DatabaseOperations(BaseDatabaseOperations): class DatabaseOperations(BaseDatabaseOperations):
compiler_module = "django.db.backends.mysql.compiler" compiler_module = "django.db.backends.mysql.compiler"
@ -319,7 +325,7 @@ class DatabaseOperations(BaseDatabaseOperations):
# Truncate already resets the AUTO_INCREMENT field from # Truncate already resets the AUTO_INCREMENT field from
# MySQL version 5.0.13 onwards. Refs #16961. # MySQL version 5.0.13 onwards. Refs #16961.
if self.connection.mysql_version < (5, 0, 13): if self.connection.mysql_version < (5, 0, 13):
return ["%s %s %s %s %s;" % \ return ["%s %s %s %s %s;" %
(style.SQL_KEYWORD('ALTER'), (style.SQL_KEYWORD('ALTER'),
style.SQL_KEYWORD('TABLE'), style.SQL_KEYWORD('TABLE'),
style.SQL_TABLE(self.quote_name(sequence['table'])), style.SQL_TABLE(self.quote_name(sequence['table'])),
@ -373,6 +379,7 @@ class DatabaseOperations(BaseDatabaseOperations):
items_sql = "(%s)" % ", ".join(["%s"] * len(fields)) items_sql = "(%s)" % ", ".join(["%s"] * len(fields))
return "VALUES " + ", ".join([items_sql] * num_values) return "VALUES " + ", ".join([items_sql] * num_values)
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'mysql' vendor = 'mysql'
operators = { operators = {

View File

@ -3,6 +3,7 @@ import sys
from django.db.backends import BaseDatabaseClient from django.db.backends import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient): class DatabaseClient(BaseDatabaseClient):
executable_name = 'mysql' executable_name = 'mysql'
@ -37,4 +38,3 @@ class DatabaseClient(BaseDatabaseClient):
sys.exit(os.system(" ".join(args))) sys.exit(os.system(" ".join(args)))
else: else:
os.execvp(self.executable_name, args) os.execvp(self.executable_name, args)

View File

@ -22,20 +22,26 @@ class SQLCompiler(compiler.SQLCompiler):
sql, params = self.as_sql() sql, params = self.as_sql()
return '(%s) IN (%s)' % (', '.join(['%s.%s' % (qn(alias), qn2(column)) for column in columns]), sql), params return '(%s) IN (%s)' % (', '.join(['%s.%s' % (qn(alias), qn2(column)) for column in columns]), sql), params
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler): class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
pass pass
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler): class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
pass pass
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
pass pass
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler): class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
pass pass
class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler): class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler):
pass pass
class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, SQLCompiler): class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, SQLCompiler):
pass pass

View File

@ -1,5 +1,6 @@
from django.db.backends.creation import BaseDatabaseCreation from django.db.backends.creation import BaseDatabaseCreation
class DatabaseCreation(BaseDatabaseCreation): class DatabaseCreation(BaseDatabaseCreation):
# This dictionary maps Field objects to their associated MySQL column # This dictionary maps Field objects to their associated MySQL column
# types, as strings. Column-type strings can contain format strings; they'll # types, as strings. Column-type strings can contain format strings; they'll

View File

@ -7,6 +7,7 @@ from django.utils.encoding import force_text
foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)") foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)")
class DatabaseIntrospection(BaseDatabaseIntrospection): class DatabaseIntrospection(BaseDatabaseIntrospection):
data_types_reverse = { data_types_reverse = {
FIELD_TYPE.BLOB: 'TextField', FIELD_TYPE.BLOB: 'TextField',
@ -116,4 +117,3 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
continue continue
indexes[row[4]] = {'primary_key': (row[2] == 'PRIMARY'), 'unique': not bool(row[1])} indexes[row[4]] = {'primary_key': (row[2] == 'PRIMARY'), 'unique': not bool(row[1])}
return indexes return indexes

View File

@ -1,5 +1,6 @@
from django.db.backends import BaseDatabaseValidation from django.db.backends import BaseDatabaseValidation
class DatabaseValidation(BaseDatabaseValidation): class DatabaseValidation(BaseDatabaseValidation):
def validate_field(self, errors, opts, f): def validate_field(self, errors, opts, f):
""" """

View File

@ -7,11 +7,12 @@ from __future__ import unicode_literals
import decimal import decimal
import re import re
import platform
import sys import sys
import warnings import warnings
def _setup_environment(environ): def _setup_environment(environ):
import platform
# Cygwin requires some special voodoo to set the environment variables # Cygwin requires some special voodoo to set the environment variables
# properly so that Oracle will see them. # properly so that Oracle will see them.
if platform.system().upper().startswith('CYGWIN'): if platform.system().upper().startswith('CYGWIN'):
@ -90,6 +91,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_tablespaces = True supports_tablespaces = True
supports_sequence_reset = False supports_sequence_reset = False
class DatabaseOperations(BaseDatabaseOperations): class DatabaseOperations(BaseDatabaseOperations):
compiler_module = "django.db.backends.oracle.compiler" compiler_module = "django.db.backends.oracle.compiler"

View File

@ -3,6 +3,7 @@ import sys
from django.db.backends import BaseDatabaseClient from django.db.backends import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient): class DatabaseClient(BaseDatabaseClient):
executable_name = 'sqlplus' executable_name = 'sqlplus'
@ -13,4 +14,3 @@ class DatabaseClient(BaseDatabaseClient):
sys.exit(os.system(" ".join(args))) sys.exit(os.system(" ".join(args)))
else: else:
os.execvp(self.executable_name, args) os.execvp(self.executable_name, args)

View File

@ -60,17 +60,22 @@ class SQLCompiler(compiler.SQLCompiler):
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler): class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
pass pass
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler): class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
pass pass
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
pass pass
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler): class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
pass pass
class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler): class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler):
pass pass
class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, SQLCompiler): class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, SQLCompiler):
pass pass

View File

@ -5,9 +5,11 @@ from django.conf import settings
from django.db.backends.creation import BaseDatabaseCreation from django.db.backends.creation import BaseDatabaseCreation
from django.utils.six.moves import input from django.utils.six.moves import input
TEST_DATABASE_PREFIX = 'test_' TEST_DATABASE_PREFIX = 'test_'
PASSWORD = 'Im_a_lumberjack' PASSWORD = 'Im_a_lumberjack'
class DatabaseCreation(BaseDatabaseCreation): class DatabaseCreation(BaseDatabaseCreation):
# This dictionary maps Field objects to their associated Oracle column # This dictionary maps Field objects to their associated Oracle column
# types, as strings. Column-type strings can contain format strings; they'll # types, as strings. Column-type strings can contain format strings; they'll

View File

@ -1,10 +1,13 @@
from django.db.backends import BaseDatabaseIntrospection, FieldInfo
from django.utils.encoding import force_text
import cx_Oracle
import re import re
import cx_Oracle
from django.db.backends import BaseDatabaseIntrospection, FieldInfo
from django.utils.encoding import force_text
foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)") foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)")
class DatabaseIntrospection(BaseDatabaseIntrospection): class DatabaseIntrospection(BaseDatabaseIntrospection):
# Maps type objects to Django Field types. # Maps type objects to Django Field types.
data_types_reverse = { data_types_reverse = {

View File

@ -6,7 +6,9 @@ Requires psycopg 2: http://initd.org/projects/psycopg2
import logging import logging
import sys import sys
from django.db.backends import * from django.conf import settings
from django.db.backends import (BaseDatabaseFeatures, BaseDatabaseWrapper,
BaseDatabaseValidation)
from django.db.backends.postgresql_psycopg2.operations import DatabaseOperations from django.db.backends.postgresql_psycopg2.operations import DatabaseOperations
from django.db.backends.postgresql_psycopg2.client import DatabaseClient from django.db.backends.postgresql_psycopg2.client import DatabaseClient
from django.db.backends.postgresql_psycopg2.creation import DatabaseCreation from django.db.backends.postgresql_psycopg2.creation import DatabaseCreation
@ -33,11 +35,13 @@ psycopg2.extensions.register_adapter(SafeText, psycopg2.extensions.QuotedString)
logger = logging.getLogger('django.db.backends') logger = logging.getLogger('django.db.backends')
def utc_tzinfo_factory(offset): def utc_tzinfo_factory(offset):
if offset != 0: if offset != 0:
raise AssertionError("database connection isn't set to UTC") raise AssertionError("database connection isn't set to UTC")
return utc return utc
class DatabaseFeatures(BaseDatabaseFeatures): class DatabaseFeatures(BaseDatabaseFeatures):
needs_datetime_string_cast = False needs_datetime_string_cast = False
can_return_id_from_insert = True can_return_id_from_insert = True
@ -52,6 +56,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_transactions = True supports_transactions = True
can_distinct_on_fields = True can_distinct_on_fields = True
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'postgresql' vendor = 'postgresql'
operators = { operators = {
@ -132,7 +137,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# Set the time zone in autocommit mode (see #17062) # Set the time zone in autocommit mode (see #17062)
self.set_autocommit(True) self.set_autocommit(True)
self.connection.cursor().execute( self.connection.cursor().execute(
self.ops.set_time_zone_sql(), [tz]) self.ops.set_time_zone_sql(), [tz]
)
self.connection.set_isolation_level(self.isolation_level) self.connection.set_isolation_level(self.isolation_level)
def create_cursor(self): def create_cursor(self):

View File

@ -3,6 +3,7 @@ import sys
from django.db.backends import BaseDatabaseClient from django.db.backends import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient): class DatabaseClient(BaseDatabaseClient):
executable_name = 'psql' executable_name = 'psql'
@ -20,4 +21,3 @@ class DatabaseClient(BaseDatabaseClient):
sys.exit(os.system(" ".join(args))) sys.exit(os.system(" ".join(args)))
else: else:
os.execvp(self.executable_name, args) os.execvp(self.executable_name, args)

View File

@ -135,7 +135,7 @@ class DatabaseOperations(BaseDatabaseOperations):
# This will be the case if it's an m2m using an autogenerated # This will be the case if it's an m2m using an autogenerated
# intermediate table (see BaseDatabaseIntrospection.sequence_list) # intermediate table (see BaseDatabaseIntrospection.sequence_list)
column_name = 'id' column_name = 'id'
sql.append("%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" % \ sql.append("%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" %
(style.SQL_KEYWORD('SELECT'), (style.SQL_KEYWORD('SELECT'),
style.SQL_TABLE(self.quote_name(table_name)), style.SQL_TABLE(self.quote_name(table_name)),
style.SQL_FIELD(column_name)) style.SQL_FIELD(column_name))
@ -161,7 +161,7 @@ class DatabaseOperations(BaseDatabaseOperations):
for f in model._meta.local_fields: for f in model._meta.local_fields:
if isinstance(f, models.AutoField): if isinstance(f, models.AutoField):
output.append("%s setval(pg_get_serial_sequence('%s','%s'), coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \ output.append("%s setval(pg_get_serial_sequence('%s','%s'), coalesce(max(%s), 1), max(%s) %s null) %s %s;" %
(style.SQL_KEYWORD('SELECT'), (style.SQL_KEYWORD('SELECT'),
style.SQL_TABLE(qn(model._meta.db_table)), style.SQL_TABLE(qn(model._meta.db_table)),
style.SQL_FIELD(f.column), style.SQL_FIELD(f.column),
@ -173,7 +173,7 @@ class DatabaseOperations(BaseDatabaseOperations):
break # Only one AutoField is allowed per model, so don't bother continuing. break # Only one AutoField is allowed per model, so don't bother continuing.
for f in model._meta.many_to_many: for f in model._meta.many_to_many:
if not f.rel.through: if not f.rel.through:
output.append("%s setval(pg_get_serial_sequence('%s','%s'), coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \ output.append("%s setval(pg_get_serial_sequence('%s','%s'), coalesce(max(%s), 1), max(%s) %s null) %s %s;" %
(style.SQL_KEYWORD('SELECT'), (style.SQL_KEYWORD('SELECT'),
style.SQL_TABLE(qn(f.m2m_db_table())), style.SQL_TABLE(qn(f.m2m_db_table())),
style.SQL_FIELD('id'), style.SQL_FIELD('id'),

View File

@ -21,6 +21,7 @@ def _parse_version(text):
except (ValueError, TypeError): except (ValueError, TypeError):
return int(major) * 10000 + int(major2) * 100 return int(major) * 10000 + int(major2) * 100
def get_version(connection): def get_version(connection):
""" """
Returns an integer representing the major, minor and revision number of the Returns an integer representing the major, minor and revision number of the

View File

@ -11,8 +11,10 @@ import decimal
import warnings import warnings
import re import re
from django.conf import settings
from django.db import utils from django.db import utils
from django.db.backends import * from django.db.backends import (util, BaseDatabaseFeatures,
BaseDatabaseOperations, BaseDatabaseWrapper, BaseDatabaseValidation)
from django.db.backends.sqlite3.client import DatabaseClient from django.db.backends.sqlite3.client import DatabaseClient
from django.db.backends.sqlite3.creation import DatabaseCreation from django.db.backends.sqlite3.creation import DatabaseCreation
from django.db.backends.sqlite3.introspection import DatabaseIntrospection from django.db.backends.sqlite3.introspection import DatabaseIntrospection
@ -42,6 +44,7 @@ except ImportError:
DatabaseError = Database.DatabaseError DatabaseError = Database.DatabaseError
IntegrityError = Database.IntegrityError IntegrityError = Database.IntegrityError
def parse_datetime_with_timezone_support(value): def parse_datetime_with_timezone_support(value):
dt = parse_datetime(value) dt = parse_datetime(value)
# Confirm that dt is naive before overwriting its tzinfo. # Confirm that dt is naive before overwriting its tzinfo.
@ -49,6 +52,7 @@ def parse_datetime_with_timezone_support(value):
dt = dt.replace(tzinfo=timezone.utc) dt = dt.replace(tzinfo=timezone.utc)
return dt return dt
def adapt_datetime_with_timezone_support(value): def adapt_datetime_with_timezone_support(value):
# Equivalent to DateTimeField.get_db_prep_value. Used only by raw SQL. # Equivalent to DateTimeField.get_db_prep_value. Used only by raw SQL.
if settings.USE_TZ: if settings.USE_TZ:
@ -61,6 +65,7 @@ def adapt_datetime_with_timezone_support(value):
value = value.astimezone(timezone.utc).replace(tzinfo=None) value = value.astimezone(timezone.utc).replace(tzinfo=None)
return value.isoformat(str(" ")) return value.isoformat(str(" "))
def decoder(conv_func): def decoder(conv_func):
""" The Python sqlite3 interface returns always byte strings. """ The Python sqlite3 interface returns always byte strings.
This function converts the received value to a regular string before This function converts the received value to a regular string before
@ -81,6 +86,7 @@ Database.register_adapter(decimal.Decimal, util.rev_typecast_decimal)
Database.register_adapter(str, lambda s: s.decode('utf-8')) Database.register_adapter(str, lambda s: s.decode('utf-8'))
Database.register_adapter(SafeBytes, lambda s: s.decode('utf-8')) Database.register_adapter(SafeBytes, lambda s: s.decode('utf-8'))
class DatabaseFeatures(BaseDatabaseFeatures): class DatabaseFeatures(BaseDatabaseFeatures):
# SQLite cannot handle us only partially reading from a cursor's result set # SQLite cannot handle us only partially reading from a cursor's result set
# and then writing the same rows to the database in another cursor. This # and then writing the same rows to the database in another cursor. This
@ -124,6 +130,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
def has_zoneinfo_database(self): def has_zoneinfo_database(self):
return pytz is not None return pytz is not None
class DatabaseOperations(BaseDatabaseOperations): class DatabaseOperations(BaseDatabaseOperations):
def bulk_batch_size(self, fields, objs): def bulk_batch_size(self, fields, objs):
""" """
@ -272,6 +279,7 @@ class DatabaseOperations(BaseDatabaseOperations):
res.extend(["UNION ALL SELECT %s" % ", ".join(["%s"] * len(fields))] * (num_values - 1)) res.extend(["UNION ALL SELECT %s" % ", ".join(["%s"] * len(fields))] * (num_values - 1))
return " ".join(res) return " ".join(res)
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'sqlite' vendor = 'sqlite'
# SQLite requires LIKE statements to include an ESCAPE clause if the value # SQLite requires LIKE statements to include an ESCAPE clause if the value
@ -426,6 +434,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
FORMAT_QMARK_REGEX = re.compile(r'(?<!%)%s') FORMAT_QMARK_REGEX = re.compile(r'(?<!%)%s')
class SQLiteCursorWrapper(Database.Cursor): class SQLiteCursorWrapper(Database.Cursor):
""" """
Django uses "format" style placeholders, but pysqlite2 uses "qmark" style. Django uses "format" style placeholders, but pysqlite2 uses "qmark" style.
@ -445,6 +454,7 @@ class SQLiteCursorWrapper(Database.Cursor):
def convert_query(self, query): def convert_query(self, query):
return FORMAT_QMARK_REGEX.sub('?', query).replace('%%', '%') return FORMAT_QMARK_REGEX.sub('?', query).replace('%%', '%')
def _sqlite_date_extract(lookup_type, dt): def _sqlite_date_extract(lookup_type, dt):
if dt is None: if dt is None:
return None return None
@ -457,6 +467,7 @@ def _sqlite_date_extract(lookup_type, dt):
else: else:
return getattr(dt, lookup_type) return getattr(dt, lookup_type)
def _sqlite_date_trunc(lookup_type, dt): def _sqlite_date_trunc(lookup_type, dt):
try: try:
dt = util.typecast_timestamp(dt) dt = util.typecast_timestamp(dt)
@ -469,6 +480,7 @@ def _sqlite_date_trunc(lookup_type, dt):
elif lookup_type == 'day': elif lookup_type == 'day':
return "%i-%02i-%02i" % (dt.year, dt.month, dt.day) return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)
def _sqlite_datetime_extract(lookup_type, dt, tzname): def _sqlite_datetime_extract(lookup_type, dt, tzname):
if dt is None: if dt is None:
return None return None
@ -483,6 +495,7 @@ def _sqlite_datetime_extract(lookup_type, dt, tzname):
else: else:
return getattr(dt, lookup_type) return getattr(dt, lookup_type)
def _sqlite_datetime_trunc(lookup_type, dt, tzname): def _sqlite_datetime_trunc(lookup_type, dt, tzname):
try: try:
dt = util.typecast_timestamp(dt) dt = util.typecast_timestamp(dt)
@ -503,6 +516,7 @@ def _sqlite_datetime_trunc(lookup_type, dt, tzname):
elif lookup_type == 'second': elif lookup_type == 'second':
return "%i-%02i-%02i %02i:%02i:%02i" % (dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second) return "%i-%02i-%02i %02i:%02i:%02i" % (dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second)
def _sqlite_format_dtdelta(dt, conn, days, secs, usecs): def _sqlite_format_dtdelta(dt, conn, days, secs, usecs):
try: try:
dt = util.typecast_timestamp(dt) dt = util.typecast_timestamp(dt)
@ -517,5 +531,6 @@ def _sqlite_format_dtdelta(dt, conn, days, secs, usecs):
# It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]" # It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
return str(dt) return str(dt)
def _sqlite_regexp(re_pattern, re_string): def _sqlite_regexp(re_pattern, re_string):
return bool(re.search(re_pattern, force_text(re_string))) if re_string is not None else False return bool(re.search(re_pattern, force_text(re_string))) if re_string is not None else False

View File

@ -3,6 +3,7 @@ import sys
from django.db.backends import BaseDatabaseClient from django.db.backends import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient): class DatabaseClient(BaseDatabaseClient):
executable_name = 'sqlite3' executable_name = 'sqlite3'
@ -13,4 +14,3 @@ class DatabaseClient(BaseDatabaseClient):
sys.exit(os.system(" ".join(args))) sys.exit(os.system(" ".join(args)))
else: else:
os.execvp(self.executable_name, args) os.execvp(self.executable_name, args)

View File

@ -1,8 +1,10 @@
import os import os
import sys import sys
from django.db.backends.creation import BaseDatabaseCreation from django.db.backends.creation import BaseDatabaseCreation
from django.utils.six.moves import input from django.utils.six.moves import input
class DatabaseCreation(BaseDatabaseCreation): class DatabaseCreation(BaseDatabaseCreation):
# SQLite doesn't actually support most of these types, but it "does the right # SQLite doesn't actually support most of these types, but it "does the right
# thing" given more verbose field definitions, so leave them as is so that # thing" given more verbose field definitions, so leave them as is so that
@ -80,7 +82,6 @@ class DatabaseCreation(BaseDatabaseCreation):
SQLite since the databases will be distinct despite having the same SQLite since the databases will be distinct despite having the same
TEST_NAME. See http://www.sqlite.org/inmemorydb.html TEST_NAME. See http://www.sqlite.org/inmemorydb.html
""" """
settings_dict = self.connection.settings_dict
test_dbname = self._get_test_db_name() test_dbname = self._get_test_db_name()
sig = [self.connection.settings_dict['NAME']] sig = [self.connection.settings_dict['NAME']]
if test_dbname == ':memory:': if test_dbname == ':memory:':

View File

@ -1,8 +1,11 @@
import re import re
from django.db.backends import BaseDatabaseIntrospection, FieldInfo from django.db.backends import BaseDatabaseIntrospection, FieldInfo
field_size_re = re.compile(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$') field_size_re = re.compile(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$')
def get_field_size(name): def get_field_size(name):
""" Extract the size number from a "varchar(11)" type name """ """ Extract the size number from a "varchar(11)" type name """
m = field_size_re.search(name) m = field_size_re.search(name)
@ -46,6 +49,7 @@ class FlexibleFieldLookupDict(object):
return ('CharField', {'max_length': size}) return ('CharField', {'max_length': size})
raise KeyError raise KeyError
class DatabaseIntrospection(BaseDatabaseIntrospection): class DatabaseIntrospection(BaseDatabaseIntrospection):
data_types_reverse = FlexibleFieldLookupDict() data_types_reverse = FlexibleFieldLookupDict()
@ -98,7 +102,6 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
li, ri = other_table_results.index('('), other_table_results.rindex(')') li, ri = other_table_results.index('('), other_table_results.rindex(')')
other_table_results = other_table_results[li + 1:ri] other_table_results = other_table_results[li + 1:ri]
for other_index, other_desc in enumerate(other_table_results.split(',')): for other_index, other_desc in enumerate(other_table_results.split(',')):
other_desc = other_desc.strip() other_desc = other_desc.strip()
if other_desc.startswith('UNIQUE'): if other_desc.startswith('UNIQUE'):

View File

@ -85,8 +85,10 @@ class CursorDebugWrapper(CursorWrapper):
def typecast_date(s): def typecast_date(s):
return datetime.date(*map(int, s.split('-'))) if s else None # returns None if s is null return datetime.date(*map(int, s.split('-'))) if s else None # returns None if s is null
def typecast_time(s): # does NOT store time zone information def typecast_time(s): # does NOT store time zone information
if not s: return None if not s:
return None
hour, minutes, seconds = s.split(':') hour, minutes, seconds = s.split(':')
if '.' in seconds: # check whether seconds have a fractional part if '.' in seconds: # check whether seconds have a fractional part
seconds, microseconds = seconds.split('.') seconds, microseconds = seconds.split('.')
@ -94,11 +96,14 @@ def typecast_time(s): # does NOT store time zone information
microseconds = '0' microseconds = '0'
return datetime.time(int(hour), int(minutes), int(seconds), int(float('.' + microseconds) * 1000000)) return datetime.time(int(hour), int(minutes), int(seconds), int(float('.' + microseconds) * 1000000))
def typecast_timestamp(s): # does NOT store time zone information def typecast_timestamp(s): # does NOT store time zone information
# "2005-07-29 15:48:00.590358-05" # "2005-07-29 15:48:00.590358-05"
# "2005-07-29 09:56:00-05" # "2005-07-29 09:56:00-05"
if not s: return None if not s:
if not ' ' in s: return typecast_date(s) return None
if not ' ' in s:
return typecast_date(s)
d, t = s.split() d, t = s.split()
# Extract timezone information, if it exists. Currently we just throw # Extract timezone information, if it exists. Currently we just throw
# it away, but in the future we may make use of it. # it away, but in the future we may make use of it.
@ -122,11 +127,13 @@ def typecast_timestamp(s): # does NOT store time zone information
int(times[0]), int(times[1]), int(seconds), int(times[0]), int(times[1]), int(seconds),
int((microseconds + '000000')[:6]), tzinfo) int((microseconds + '000000')[:6]), tzinfo)
def typecast_decimal(s): def typecast_decimal(s):
if s is None or s == '': if s is None or s == '':
return None return None
return decimal.Decimal(s) return decimal.Decimal(s)
############################################### ###############################################
# Converters from Python to database (string) # # Converters from Python to database (string) #
############################################### ###############################################
@ -136,6 +143,7 @@ def rev_typecast_decimal(d):
return None return None
return str(d) return str(d)
def truncate_name(name, length=None, hash_len=4): def truncate_name(name, length=None, hash_len=4):
"""Shortens a string to a repeatable mangled version with the given length. """Shortens a string to a repeatable mangled version with the given length.
""" """
@ -145,6 +153,7 @@ def truncate_name(name, length=None, hash_len=4):
hsh = hashlib.md5(force_bytes(name)).hexdigest()[:hash_len] hsh = hashlib.md5(force_bytes(name)).hexdigest()[:hash_len]
return '%s%s' % (name[:length - hash_len], hsh) return '%s%s' % (name[:length - hash_len], hsh)
def format_number(value, max_digits, decimal_places): def format_number(value, max_digits, decimal_places):
""" """
Formats a number into a string with the requisite number of digits and Formats a number into a string with the requisite number of digits and

View File

@ -26,6 +26,7 @@ def permalink(func):
(viewname, viewargs, viewkwargs) (viewname, viewargs, viewkwargs)
""" """
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
@wraps(func) @wraps(func)
def inner(*args, **kwargs): def inner(*args, **kwargs):
bits = func(*args, **kwargs) bits = func(*args, **kwargs)

View File

@ -3,6 +3,7 @@ Classes to represent the definitions of aggregate functions.
""" """
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
def refs_aggregate(lookup_parts, aggregates): def refs_aggregate(lookup_parts, aggregates):
""" """
A little helper method to check if the lookup_parts contains references A little helper method to check if the lookup_parts contains references
@ -15,6 +16,7 @@ def refs_aggregate(lookup_parts, aggregates):
return True return True
return False return False
class Aggregate(object): class Aggregate(object):
""" """
Default Aggregate definition. Default Aggregate definition.
@ -58,23 +60,30 @@ class Aggregate(object):
aggregate = klass(col, source=source, is_summary=is_summary, **self.extra) aggregate = klass(col, source=source, is_summary=is_summary, **self.extra)
query.aggregates[alias] = aggregate query.aggregates[alias] = aggregate
class Avg(Aggregate): class Avg(Aggregate):
name = 'Avg' name = 'Avg'
class Count(Aggregate): class Count(Aggregate):
name = 'Count' name = 'Count'
class Max(Aggregate): class Max(Aggregate):
name = 'Max' name = 'Max'
class Min(Aggregate): class Min(Aggregate):
name = 'Min' name = 'Min'
class StdDev(Aggregate): class StdDev(Aggregate):
name = 'StdDev' name = 'StdDev'
class Sum(Aggregate): class Sum(Aggregate):
name = 'Sum' name = 'Sum'
class Variance(Aggregate): class Variance(Aggregate):
name = 'Variance' name = 'Variance'

View File

@ -226,9 +226,9 @@ class ModelBase(type):
# class # class
for field in base._meta.virtual_fields: for field in base._meta.virtual_fields:
if base._meta.abstract and field.name in field_names: if base._meta.abstract and field.name in field_names:
raise FieldError('Local field %r in class %r clashes '\ raise FieldError('Local field %r in class %r clashes '
'with field of similar name from '\ 'with field of similar name from '
'abstract base class %r' % \ 'abstract base class %r' %
(field.name, name, base.__name__)) (field.name, name, base.__name__))
new_class.add_to_class(field.name, copy.deepcopy(field)) new_class.add_to_class(field.name, copy.deepcopy(field))
@ -1008,8 +1008,6 @@ def get_absolute_url(opts, func, self, *args, **kwargs):
# MISC # # MISC #
######## ########
class Empty(object):
pass
def simple_class_factory(model, attrs): def simple_class_factory(model, attrs):
""" """
@ -1017,6 +1015,7 @@ def simple_class_factory(model, attrs):
""" """
return model return model
def model_unpickle(model_id, attrs, factory): def model_unpickle(model_id, attrs, factory):
""" """
Used to unpickle Model subclasses with deferred fields. Used to unpickle Model subclasses with deferred fields.

View File

@ -4,6 +4,7 @@ from django.db.models.aggregates import refs_aggregate
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.utils import tree from django.utils import tree
class ExpressionNode(tree.Node): class ExpressionNode(tree.Node):
""" """
Base class for all query expressions. Base class for all query expressions.
@ -128,6 +129,7 @@ class ExpressionNode(tree.Node):
"Use .bitand() and .bitor() for bitwise logical operations." "Use .bitand() and .bitor() for bitwise logical operations."
) )
class F(ExpressionNode): class F(ExpressionNode):
""" """
An expression representing the value of the given field. An expression representing the value of the given field.
@ -147,6 +149,7 @@ class F(ExpressionNode):
def evaluate(self, evaluator, qn, connection): def evaluate(self, evaluator, qn, connection):
return evaluator.evaluate_leaf(self, qn, connection) return evaluator.evaluate_leaf(self, qn, connection)
class DateModifierNode(ExpressionNode): class DateModifierNode(ExpressionNode):
""" """
Node that implements the following syntax: Node that implements the following syntax:

View File

@ -25,9 +25,11 @@ from django.utils.encoding import smart_text, force_text, force_bytes
from django.utils.ipv6 import clean_ipv6_address from django.utils.ipv6 import clean_ipv6_address
from django.utils import six from django.utils import six
class Empty(object): class Empty(object):
pass pass
class NOT_PROVIDED: class NOT_PROVIDED:
pass pass
@ -35,12 +37,15 @@ class NOT_PROVIDED:
# of most "choices" lists. # of most "choices" lists.
BLANK_CHOICE_DASH = [("", "---------")] BLANK_CHOICE_DASH = [("", "---------")]
def _load_field(app_label, model_name, field_name): def _load_field(app_label, model_name, field_name):
return get_model(app_label, model_name)._meta.get_field_by_name(field_name)[0] return get_model(app_label, model_name)._meta.get_field_by_name(field_name)[0]
class FieldDoesNotExist(Exception): class FieldDoesNotExist(Exception):
pass pass
# A guide to Field parameters: # A guide to Field parameters:
# #
# * name: The name of the field specifed in the model. # * name: The name of the field specifed in the model.
@ -61,6 +66,7 @@ def _empty(of_cls):
new.__class__ = of_cls new.__class__ = of_cls
return new return new
@total_ordering @total_ordering
class Field(object): class Field(object):
"""Base class for all field types""" """Base class for all field types"""
@ -444,12 +450,12 @@ class Field(object):
if hasattr(value, '_prepare'): if hasattr(value, '_prepare'):
return value._prepare() return value._prepare()
if lookup_type in ( if lookup_type in {
'iexact', 'contains', 'icontains', 'iexact', 'contains', 'icontains',
'startswith', 'istartswith', 'endswith', 'iendswith', 'startswith', 'istartswith', 'endswith', 'iendswith',
'month', 'day', 'week_day', 'hour', 'minute', 'second', 'month', 'day', 'week_day', 'hour', 'minute', 'second',
'isnull', 'search', 'regex', 'iregex', 'isnull', 'search', 'regex', 'iregex',
): }:
return value return value
elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'): elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'):
return self.get_prep_value(value) return self.get_prep_value(value)
@ -712,6 +718,7 @@ class AutoField(Field):
def formfield(self, **kwargs): def formfield(self, **kwargs):
return None return None
class BooleanField(Field): class BooleanField(Field):
empty_strings_allowed = False empty_strings_allowed = False
default_error_messages = { default_error_messages = {
@ -766,13 +773,13 @@ class BooleanField(Field):
if self.choices: if self.choices:
include_blank = (self.null or include_blank = (self.null or
not (self.has_default() or 'initial' in kwargs)) not (self.has_default() or 'initial' in kwargs))
defaults = {'choices': self.get_choices( defaults = {'choices': self.get_choices(include_blank=include_blank)}
include_blank=include_blank)}
else: else:
defaults = {'form_class': forms.BooleanField} defaults = {'form_class': forms.BooleanField}
defaults.update(kwargs) defaults.update(kwargs)
return super(BooleanField, self).formfield(**defaults) return super(BooleanField, self).formfield(**defaults)
class CharField(Field): class CharField(Field):
description = _("String (up to %(max_length)s)") description = _("String (up to %(max_length)s)")
@ -799,6 +806,7 @@ class CharField(Field):
defaults.update(kwargs) defaults.update(kwargs)
return super(CharField, self).formfield(**defaults) return super(CharField, self).formfield(**defaults)
# TODO: Maybe move this into contrib, because it's specialized. # TODO: Maybe move this into contrib, because it's specialized.
class CommaSeparatedIntegerField(CharField): class CommaSeparatedIntegerField(CharField):
default_validators = [validators.validate_comma_separated_integer_list] default_validators = [validators.validate_comma_separated_integer_list]
@ -813,6 +821,7 @@ class CommaSeparatedIntegerField(CharField):
defaults.update(kwargs) defaults.update(kwargs)
return super(CommaSeparatedIntegerField, self).formfield(**defaults) return super(CommaSeparatedIntegerField, self).formfield(**defaults)
class DateField(Field): class DateField(Field):
empty_strings_allowed = False empty_strings_allowed = False
default_error_messages = { default_error_messages = {
@ -919,6 +928,7 @@ class DateField(Field):
defaults.update(kwargs) defaults.update(kwargs)
return super(DateField, self).formfield(**defaults) return super(DateField, self).formfield(**defaults)
class DateTimeField(DateField): class DateTimeField(DateField):
empty_strings_allowed = False empty_strings_allowed = False
default_error_messages = { default_error_messages = {
@ -1025,6 +1035,7 @@ class DateTimeField(DateField):
defaults.update(kwargs) defaults.update(kwargs)
return super(DateTimeField, self).formfield(**defaults) return super(DateTimeField, self).formfield(**defaults)
class DecimalField(Field): class DecimalField(Field):
empty_strings_allowed = False empty_strings_allowed = False
default_error_messages = { default_error_messages = {
@ -1096,6 +1107,7 @@ class DecimalField(Field):
defaults.update(kwargs) defaults.update(kwargs)
return super(DecimalField, self).formfield(**defaults) return super(DecimalField, self).formfield(**defaults)
class EmailField(CharField): class EmailField(CharField):
default_validators = [validators.validate_email] default_validators = [validators.validate_email]
description = _("Email address") description = _("Email address")
@ -1122,6 +1134,7 @@ class EmailField(CharField):
defaults.update(kwargs) defaults.update(kwargs)
return super(EmailField, self).formfield(**defaults) return super(EmailField, self).formfield(**defaults)
class FilePathField(Field): class FilePathField(Field):
description = _("File path") description = _("File path")
@ -1163,6 +1176,7 @@ class FilePathField(Field):
def get_internal_type(self): def get_internal_type(self):
return "FilePathField" return "FilePathField"
class FloatField(Field): class FloatField(Field):
empty_strings_allowed = False empty_strings_allowed = False
default_error_messages = { default_error_messages = {
@ -1195,6 +1209,7 @@ class FloatField(Field):
defaults.update(kwargs) defaults.update(kwargs)
return super(FloatField, self).formfield(**defaults) return super(FloatField, self).formfield(**defaults)
class IntegerField(Field): class IntegerField(Field):
empty_strings_allowed = False empty_strings_allowed = False
default_error_messages = { default_error_messages = {
@ -1233,6 +1248,7 @@ class IntegerField(Field):
defaults.update(kwargs) defaults.update(kwargs)
return super(IntegerField, self).formfield(**defaults) return super(IntegerField, self).formfield(**defaults)
class BigIntegerField(IntegerField): class BigIntegerField(IntegerField):
empty_strings_allowed = False empty_strings_allowed = False
description = _("Big (8 byte) integer") description = _("Big (8 byte) integer")
@ -1247,6 +1263,7 @@ class BigIntegerField(IntegerField):
defaults.update(kwargs) defaults.update(kwargs)
return super(BigIntegerField, self).formfield(**defaults) return super(BigIntegerField, self).formfield(**defaults)
class IPAddressField(Field): class IPAddressField(Field):
empty_strings_allowed = False empty_strings_allowed = False
description = _("IPv4 address") description = _("IPv4 address")
@ -1268,6 +1285,7 @@ class IPAddressField(Field):
defaults.update(kwargs) defaults.update(kwargs)
return super(IPAddressField, self).formfield(**defaults) return super(IPAddressField, self).formfield(**defaults)
class GenericIPAddressField(Field): class GenericIPAddressField(Field):
empty_strings_allowed = True empty_strings_allowed = True
description = _("IP address") description = _("IP address")
@ -1383,6 +1401,7 @@ class NullBooleanField(Field):
defaults.update(kwargs) defaults.update(kwargs)
return super(NullBooleanField, self).formfield(**defaults) return super(NullBooleanField, self).formfield(**defaults)
class PositiveIntegerField(IntegerField): class PositiveIntegerField(IntegerField):
description = _("Positive integer") description = _("Positive integer")
@ -1394,6 +1413,7 @@ class PositiveIntegerField(IntegerField):
defaults.update(kwargs) defaults.update(kwargs)
return super(PositiveIntegerField, self).formfield(**defaults) return super(PositiveIntegerField, self).formfield(**defaults)
class PositiveSmallIntegerField(IntegerField): class PositiveSmallIntegerField(IntegerField):
description = _("Positive small integer") description = _("Positive small integer")
@ -1405,6 +1425,7 @@ class PositiveSmallIntegerField(IntegerField):
defaults.update(kwargs) defaults.update(kwargs)
return super(PositiveSmallIntegerField, self).formfield(**defaults) return super(PositiveSmallIntegerField, self).formfield(**defaults)
class SlugField(CharField): class SlugField(CharField):
default_validators = [validators.validate_slug] default_validators = [validators.validate_slug]
description = _("Slug (up to %(max_length)s)") description = _("Slug (up to %(max_length)s)")
@ -1434,12 +1455,14 @@ class SlugField(CharField):
defaults.update(kwargs) defaults.update(kwargs)
return super(SlugField, self).formfield(**defaults) return super(SlugField, self).formfield(**defaults)
class SmallIntegerField(IntegerField): class SmallIntegerField(IntegerField):
description = _("Small integer") description = _("Small integer")
def get_internal_type(self): def get_internal_type(self):
return "SmallIntegerField" return "SmallIntegerField"
class TextField(Field): class TextField(Field):
description = _("Text") description = _("Text")
@ -1456,6 +1479,7 @@ class TextField(Field):
defaults.update(kwargs) defaults.update(kwargs)
return super(TextField, self).formfield(**defaults) return super(TextField, self).formfield(**defaults)
class TimeField(Field): class TimeField(Field):
empty_strings_allowed = False empty_strings_allowed = False
default_error_messages = { default_error_messages = {
@ -1539,6 +1563,7 @@ class TimeField(Field):
defaults.update(kwargs) defaults.update(kwargs)
return super(TimeField, self).formfield(**defaults) return super(TimeField, self).formfield(**defaults)
class URLField(CharField): class URLField(CharField):
default_validators = [validators.URLValidator()] default_validators = [validators.URLValidator()]
description = _("URL") description = _("URL")
@ -1562,6 +1587,7 @@ class URLField(CharField):
defaults.update(kwargs) defaults.update(kwargs)
return super(URLField, self).formfield(**defaults) return super(URLField, self).formfield(**defaults)
class BinaryField(Field): class BinaryField(Field):
description = _("Raw binary data") description = _("Raw binary data")
empty_values = [None, b''] empty_values = [None, b'']

View File

@ -11,6 +11,7 @@ from django.utils.encoding import force_str, force_text
from django.utils import six from django.utils import six
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
class FieldFile(File): class FieldFile(File):
def __init__(self, instance, field, name): def __init__(self, instance, field, name):
super(FieldFile, self).__init__(None, name) super(FieldFile, self).__init__(None, name)
@ -135,6 +136,7 @@ class FieldFile(File):
# be restored later, by FileDescriptor below. # be restored later, by FileDescriptor below.
return {'name': self.name, 'closed': False, '_committed': True, '_file': None} return {'name': self.name, 'closed': False, '_committed': True, '_file': None}
class FileDescriptor(object): class FileDescriptor(object):
""" """
The descriptor for the file attribute on the model instance. Returns a The descriptor for the file attribute on the model instance. Returns a
@ -205,6 +207,7 @@ class FileDescriptor(object):
def __set__(self, instance, value): def __set__(self, instance, value):
instance.__dict__[self.field.name] = value instance.__dict__[self.field.name] = value
class FileField(Field): class FileField(Field):
# The class to wrap instance attributes in. Accessing the file object off # The class to wrap instance attributes in. Accessing the file object off
@ -300,6 +303,7 @@ class FileField(Field):
defaults.update(kwargs) defaults.update(kwargs)
return super(FileField, self).formfield(**defaults) return super(FileField, self).formfield(**defaults)
class ImageFileDescriptor(FileDescriptor): class ImageFileDescriptor(FileDescriptor):
""" """
Just like the FileDescriptor, but for ImageFields. The only difference is Just like the FileDescriptor, but for ImageFields. The only difference is
@ -321,14 +325,15 @@ class ImageFileDescriptor(FileDescriptor):
if previous_file is not None: if previous_file is not None:
self.field.update_dimension_fields(instance, force=True) self.field.update_dimension_fields(instance, force=True)
class ImageFieldFile(ImageFile, FieldFile):
class ImageFieldFile(ImageFile, FieldFile):
def delete(self, save=True): def delete(self, save=True):
# Clear the image dimensions cache # Clear the image dimensions cache
if hasattr(self, '_dimensions_cache'): if hasattr(self, '_dimensions_cache'):
del self._dimensions_cache del self._dimensions_cache
super(ImageFieldFile, self).delete(save) super(ImageFieldFile, self).delete(save)
class ImageField(FileField): class ImageField(FileField):
attr_class = ImageFieldFile attr_class = ImageFieldFile
descriptor_class = ImageFileDescriptor descriptor_class = ImageFileDescriptor

View File

@ -5,6 +5,7 @@ have the same attributes as fields sometimes (avoids a lot of special casing).
from django.db.models import fields from django.db.models import fields
class OrderWrt(fields.IntegerField): class OrderWrt(fields.IntegerField):
""" """
A proxy for the _order database field that is used when A proxy for the _order database field that is used when

View File

@ -213,7 +213,7 @@ class SingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjectDescri
# If null=True, we can assign null here, but otherwise the value needs # If null=True, we can assign null here, but otherwise the value needs
# to be an instance of the related class. # to be an instance of the related class.
if value is None and self.related.field.null == False: if value is None and self.related.field.null is False:
raise ValueError('Cannot assign None: "%s.%s" does not allow null values.' % raise ValueError('Cannot assign None: "%s.%s" does not allow null values.' %
(instance._meta.object_name, self.related.get_accessor_name())) (instance._meta.object_name, self.related.get_accessor_name()))
elif value is not None and not isinstance(value, self.related.model): elif value is not None and not isinstance(value, self.related.model):
@ -312,7 +312,7 @@ class ReverseSingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjec
def __set__(self, instance, value): def __set__(self, instance, value):
# If null=True, we can assign null here, but otherwise the value needs # If null=True, we can assign null here, but otherwise the value needs
# to be an instance of the related class. # to be an instance of the related class.
if value is None and self.field.null == False: if value is None and self.field.null is False:
raise ValueError('Cannot assign None: "%s.%s" does not allow null values.' % raise ValueError('Cannot assign None: "%s.%s" does not allow null values.' %
(instance._meta.object_name, self.field.name)) (instance._meta.object_name, self.field.name))
elif value is not None and not isinstance(value, self.field.rel.to): elif value is not None and not isinstance(value, self.field.rel.to):
@ -512,7 +512,6 @@ def create_many_related_manager(superclass, rel):
"a many-to-many relationship can be used." % "a many-to-many relationship can be used." %
instance.__class__.__name__) instance.__class__.__name__)
def _get_fk_val(self, obj, field_name): def _get_fk_val(self, obj, field_name):
""" """
Returns the correct value for this relationship's foreign key. This Returns the correct value for this relationship's foreign key. This
@ -823,6 +822,7 @@ class ReverseManyRelatedObjectsDescriptor(object):
manager.clear() manager.clear()
manager.add(*value) manager.add(*value)
class ForeignObjectRel(object): class ForeignObjectRel(object):
def __init__(self, field, to, related_name=None, limit_choices_to=None, def __init__(self, field, to, related_name=None, limit_choices_to=None,
parent_link=False, on_delete=None, related_query_name=None): parent_link=False, on_delete=None, related_query_name=None):
@ -860,6 +860,7 @@ class ForeignObjectRel(object):
# example custom multicolumn joins currently have no remote field). # example custom multicolumn joins currently have no remote field).
self.field_name = None self.field_name = None
class ManyToOneRel(ForeignObjectRel): class ManyToOneRel(ForeignObjectRel):
def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None, def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None,
parent_link=False, on_delete=None, related_query_name=None): parent_link=False, on_delete=None, related_query_name=None):
@ -1125,7 +1126,7 @@ class ForeignKey(ForeignObject):
def __init__(self, to, to_field=None, rel_class=ManyToOneRel, def __init__(self, to, to_field=None, rel_class=ManyToOneRel,
db_constraint=True, **kwargs): db_constraint=True, **kwargs):
try: try:
to_name = to._meta.object_name.lower() to._meta.object_name.lower()
except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT
assert isinstance(to, six.string_types), "%s(%r) is invalid. First parameter to ForeignKey must be either a model, a model name, or the string %r" % (self.__class__.__name__, to, RECURSIVE_RELATIONSHIP_CONSTANT) assert isinstance(to, six.string_types), "%s(%r) is invalid. First parameter to ForeignKey must be either a model, a model name, or the string %r" % (self.__class__.__name__, to, RECURSIVE_RELATIONSHIP_CONSTANT)
else: else:
@ -1162,7 +1163,6 @@ class ForeignKey(ForeignObject):
if self.rel.on_delete is not CASCADE: if self.rel.on_delete is not CASCADE:
kwargs['on_delete'] = self.rel.on_delete kwargs['on_delete'] = self.rel.on_delete
# Rel needs more work. # Rel needs more work.
rel = self.rel
if self.rel.field_name: if self.rel.field_name:
kwargs['to_field'] = self.rel.field_name kwargs['to_field'] = self.rel.field_name
if isinstance(self.rel.to, six.string_types): if isinstance(self.rel.to, six.string_types):
@ -1222,7 +1222,7 @@ class ForeignKey(ForeignObject):
return field_default return field_default
def get_db_prep_save(self, value, connection): def get_db_prep_save(self, value, connection):
if value == '' or value == None: if value == '' or value is None:
return None return None
else: else:
return self.related_field.get_db_prep_save(value, return self.related_field.get_db_prep_save(value,
@ -1389,7 +1389,6 @@ class ManyToManyField(RelatedField):
if "help_text" in kwargs: if "help_text" in kwargs:
del kwargs['help_text'] del kwargs['help_text']
# Rel needs more work. # Rel needs more work.
rel = self.rel
if isinstance(self.rel.to, six.string_types): if isinstance(self.rel.to, six.string_types):
kwargs['to'] = self.rel.to kwargs['to'] = self.rel.to
else: else:

View File

@ -7,6 +7,7 @@ to_python() and the other necessary methods and everything will work
seamlessly. seamlessly.
""" """
class SubfieldBase(type): class SubfieldBase(type):
""" """
A metaclass for custom Field subclasses. This ensures the model's attribute A metaclass for custom Field subclasses. This ensures the model's attribute
@ -19,6 +20,7 @@ class SubfieldBase(type):
) )
return new_class return new_class
class Creator(object): class Creator(object):
""" """
A placeholder class that provides a way to set the attribute on the model. A placeholder class that provides a way to set the attribute on the model.
@ -34,6 +36,7 @@ class Creator(object):
def __set__(self, obj, value): def __set__(self, obj, value):
obj.__dict__[self.field.name] = self.field.to_python(value) obj.__dict__[self.field.name] = self.field.to_python(value)
def make_contrib(superclass, func=None): def make_contrib(superclass, func=None):
""" """
Returns a suitable contribute_to_class() method for the Field subclass. Returns a suitable contribute_to_class() method for the Field subclass.

View File

@ -15,9 +15,11 @@ import os
__all__ = ('get_apps', 'get_app', 'get_models', 'get_model', 'register_models', __all__ = ('get_apps', 'get_app', 'get_models', 'get_model', 'register_models',
'load_app', 'app_cache_ready') 'load_app', 'app_cache_ready')
class UnavailableApp(Exception): class UnavailableApp(Exception):
pass pass
class AppCache(object): class AppCache(object):
""" """
A cache that stores installed applications and their models. Used to A cache that stores installed applications and their models. Used to

View File

@ -6,6 +6,7 @@ from django.db.models.fields import FieldDoesNotExist
from django.utils import six from django.utils import six
from django.utils.deprecation import RenameMethodsBase from django.utils.deprecation import RenameMethodsBase
def ensure_default_manager(sender, **kwargs): def ensure_default_manager(sender, **kwargs):
""" """
Ensures that a Model subclass contains a default manager and sets the Ensures that a Model subclass contains a default manager and sets the
@ -245,7 +246,7 @@ class ManagerDescriptor(object):
self.manager = manager self.manager = manager
def __get__(self, instance, type=None): def __get__(self, instance, type=None):
if instance != None: if instance is not None:
raise AttributeError("Manager isn't accessible via %s instances" % type.__name__) raise AttributeError("Manager isn't accessible via %s instances" % type.__name__)
return self.manager return self.manager

View File

@ -899,10 +899,12 @@ class QuerySet(object):
# empty" result. # empty" result.
value_annotation = True value_annotation = True
class InstanceCheckMeta(type): class InstanceCheckMeta(type):
def __instancecheck__(self, instance): def __instancecheck__(self, instance):
return instance.query.is_empty() return instance.query.is_empty()
class EmptyQuerySet(six.with_metaclass(InstanceCheckMeta)): class EmptyQuerySet(six.with_metaclass(InstanceCheckMeta)):
""" """
Marker class usable for checking if a queryset is empty by .none(): Marker class usable for checking if a queryset is empty by .none():
@ -912,6 +914,7 @@ class EmptyQuerySet(six.with_metaclass(InstanceCheckMeta)):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
raise TypeError("EmptyQuerySet can't be instantiated") raise TypeError("EmptyQuerySet can't be instantiated")
class ValuesQuerySet(QuerySet): class ValuesQuerySet(QuerySet):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ValuesQuerySet, self).__init__(*args, **kwargs) super(ValuesQuerySet, self).__init__(*args, **kwargs)
@ -1276,11 +1279,10 @@ def get_cached_row(row, index_start, using, klass_info, offset=0,
return None return None
klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx = klass_info klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx = klass_info
fields = row[index_start:index_start + field_count] fields = row[index_start:index_start + field_count]
# If the pk column is None (or the Oracle equivalent ''), then the related # If the pk column is None (or the Oracle equivalent ''), then the related
# object must be non-existent - set the relation to None. # object must be non-existent - set the relation to None.
if fields[pk_idx] == None or fields[pk_idx] == '': if fields[pk_idx] is None or fields[pk_idx] == '':
obj = None obj = None
elif field_names: elif field_names:
fields = list(fields) fields = list(fields)
@ -1510,8 +1512,6 @@ def prefetch_related_objects(result_cache, related_lookups):
if len(result_cache) == 0: if len(result_cache) == 0:
return # nothing to do return # nothing to do
model = result_cache[0].__class__
# We need to be able to dynamically add to the list of prefetch_related # We need to be able to dynamically add to the list of prefetch_related
# lookups that we look up (see below). So we need some book keeping to # lookups that we look up (see below). So we need some book keeping to
# ensure we don't do duplicate work. # ensure we don't do duplicate work.
@ -1538,7 +1538,7 @@ def prefetch_related_objects(result_cache, related_lookups):
if len(obj_list) == 0: if len(obj_list) == 0:
break break
current_lookup = LOOKUP_SEP.join(attrs[0:level+1]) current_lookup = LOOKUP_SEP.join(attrs[:level + 1])
if current_lookup in done_queries: if current_lookup in done_queries:
# Skip any prefetching, and any object preparation # Skip any prefetching, and any object preparation
obj_list = done_queries[current_lookup] obj_list = done_queries[current_lookup]

View File

@ -30,6 +30,7 @@ class QueryWrapper(object):
def as_sql(self, qn=None, connection=None): def as_sql(self, qn=None, connection=None):
return self.data return self.data
class Q(tree.Node): class Q(tree.Node):
""" """
Encapsulates filters as objects that can then be combined logically (using Encapsulates filters as objects that can then be combined logically (using
@ -74,6 +75,7 @@ class Q(tree.Node):
clone.children.append(child) clone.children.append(child)
return clone return clone
class DeferredAttribute(object): class DeferredAttribute(object):
""" """
A wrapper for a deferred-loading field. When the value is read from this A wrapper for a deferred-loading field. When the value is read from this
@ -99,8 +101,7 @@ class DeferredAttribute(object):
try: try:
f = opts.get_field_by_name(self.field_name)[0] f = opts.get_field_by_name(self.field_name)[0]
except FieldDoesNotExist: except FieldDoesNotExist:
f = [f for f in opts.fields f = [f for f in opts.fields if f.attname == self.field_name][0]
if f.attname == self.field_name][0]
name = f.name name = f.name
# Let's see if the field is part of the parent chain. If so we # Let's see if the field is part of the parent chain. If so we
# might be able to reuse the already loaded value. Refs #18343. # might be able to reuse the already loaded value. Refs #18343.
@ -174,6 +175,7 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa
return False return False
return True return True
# This function is needed because data descriptors must be defined on a class # This function is needed because data descriptors must be defined on a class
# object, not an instance, to have any effect. # object, not an instance, to have any effect.

View File

@ -10,6 +10,7 @@ PathInfo = namedtuple('PathInfo',
'from_opts to_opts target_fields join_field ' 'from_opts to_opts target_fields join_field '
'm2m direct') 'm2m direct')
class RelatedObject(object): class RelatedObject(object):
def __init__(self, parent_model, model, field): def __init__(self, parent_model, model, field):
self.parent_model = parent_model self.parent_model = parent_model

View File

@ -9,6 +9,7 @@ from django.db.models.fields import IntegerField, FloatField
ordinal_aggregate_field = IntegerField() ordinal_aggregate_field = IntegerField()
computed_aggregate_field = FloatField() computed_aggregate_field = FloatField()
class Aggregate(object): class Aggregate(object):
""" """
Default SQL Aggregate. Default SQL Aggregate.
@ -93,6 +94,7 @@ class Avg(Aggregate):
is_computed = True is_computed = True
sql_function = 'AVG' sql_function = 'AVG'
class Count(Aggregate): class Count(Aggregate):
is_ordinal = True is_ordinal = True
sql_function = 'COUNT' sql_function = 'COUNT'
@ -101,12 +103,15 @@ class Count(Aggregate):
def __init__(self, col, distinct=False, **extra): def __init__(self, col, distinct=False, **extra):
super(Count, self).__init__(col, distinct='DISTINCT ' if distinct else '', **extra) super(Count, self).__init__(col, distinct='DISTINCT ' if distinct else '', **extra)
class Max(Aggregate): class Max(Aggregate):
sql_function = 'MAX' sql_function = 'MAX'
class Min(Aggregate): class Min(Aggregate):
sql_function = 'MIN' sql_function = 'MIN'
class StdDev(Aggregate): class StdDev(Aggregate):
is_computed = True is_computed = True
@ -114,9 +119,11 @@ class StdDev(Aggregate):
super(StdDev, self).__init__(col, **extra) super(StdDev, self).__init__(col, **extra)
self.sql_function = 'STDDEV_SAMP' if sample else 'STDDEV_POP' self.sql_function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
class Sum(Aggregate): class Sum(Aggregate):
sql_function = 'SUM' sql_function = 'SUM'
class Variance(Aggregate): class Variance(Aggregate):
is_computed = True is_computed = True

View File

@ -915,6 +915,7 @@ class SQLDeleteCompiler(SQLCompiler):
result.append('WHERE %s' % where) result.append('WHERE %s' % where)
return ' '.join(result), tuple(params) return ' '.join(result), tuple(params)
class SQLUpdateCompiler(SQLCompiler): class SQLUpdateCompiler(SQLCompiler):
def as_sql(self): def as_sql(self):
""" """
@ -1029,6 +1030,7 @@ class SQLUpdateCompiler(SQLCompiler):
for alias in self.query.tables[1:]: for alias in self.query.tables[1:]:
self.query.alias_refcount[alias] = 0 self.query.alias_refcount[alias] = 0
class SQLAggregateCompiler(SQLCompiler): class SQLAggregateCompiler(SQLCompiler):
def as_sql(self, qn=None): def as_sql(self, qn=None):
""" """
@ -1050,6 +1052,7 @@ class SQLAggregateCompiler(SQLCompiler):
params = params + self.query.sub_params params = params + self.query.sub_params
return sql, params return sql, params
class SQLDateCompiler(SQLCompiler): class SQLDateCompiler(SQLCompiler):
def results_iter(self): def results_iter(self):
""" """
@ -1075,6 +1078,7 @@ class SQLDateCompiler(SQLCompiler):
date = date.date() date = date.date()
yield date yield date
class SQLDateTimeCompiler(SQLCompiler): class SQLDateTimeCompiler(SQLCompiler):
def results_iter(self): def results_iter(self):
""" """
@ -1107,6 +1111,7 @@ class SQLDateTimeCompiler(SQLCompiler):
datetime = timezone.make_aware(datetime, self.query.tzinfo) datetime = timezone.make_aware(datetime, self.query.tzinfo)
yield datetime yield datetime
def order_modified_iter(cursor, trim, sentinel): def order_modified_iter(cursor, trim, sentinel):
""" """
Yields blocks of rows from a cursor. We use this iterator in the special Yields blocks of rows from a cursor. We use this iterator in the special

View File

@ -3,9 +3,11 @@ Useful auxilliary data structures for query construction. Not useful outside
the SQL domain. the SQL domain.
""" """
class EmptyResultSet(Exception): class EmptyResultSet(Exception):
pass pass
class MultiJoin(Exception): class MultiJoin(Exception):
""" """
Used by join construction code to indicate the point at which a Used by join construction code to indicate the point at which a
@ -17,12 +19,10 @@ class MultiJoin(Exception):
# The path travelled, this includes the path to the multijoin. # The path travelled, this includes the path to the multijoin.
self.names_with_path = path_with_names self.names_with_path = path_with_names
class Empty(object): class Empty(object):
pass pass
class RawValue(object):
def __init__(self, value):
self.value = value
class Date(object): class Date(object):
""" """
@ -42,6 +42,7 @@ class Date(object):
col = self.col col = self.col
return connection.ops.date_trunc_sql(self.lookup_type, col), [] return connection.ops.date_trunc_sql(self.lookup_type, col), []
class DateTime(object): class DateTime(object):
""" """
Add a datetime selection column. Add a datetime selection column.

View File

@ -1,7 +1,9 @@
import copy
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import FieldDoesNotExist from django.db.models.fields import FieldDoesNotExist
import copy
class SQLEvaluator(object): class SQLEvaluator(object):
def __init__(self, expression, query, allow_joins=True, reuse=None): def __init__(self, expression, query, allow_joins=True, reuse=None):

View File

@ -615,7 +615,6 @@ class Query(object):
for model, values in six.iteritems(seen): for model, values in six.iteritems(seen):
callback(target, model, values) callback(target, model, values)
def deferred_to_columns_cb(self, target, model, fields): def deferred_to_columns_cb(self, target, model, fields):
""" """
Callback used by deferred_to_columns(). The "target" parameter should Callback used by deferred_to_columns(). The "target" parameter should
@ -627,7 +626,6 @@ class Query(object):
for field in fields: for field in fields:
target[table].add(field.column) target[table].add(field.column)
def table_alias(self, table_name, create=False): def table_alias(self, table_name, create=False):
""" """
Returns a table alias for the given table_name and whether this is a Returns a table alias for the given table_name and whether this is a
@ -955,7 +953,6 @@ class Query(object):
self.unref_alias(alias) self.unref_alias(alias)
self.included_inherited_models = {} self.included_inherited_models = {}
def add_aggregate(self, aggregate, model, alias, is_summary): def add_aggregate(self, aggregate, model, alias, is_summary):
""" """
Adds a single aggregate expression to the Query Adds a single aggregate expression to the Query
@ -1876,6 +1873,7 @@ class Query(object):
else: else:
return field.null return field.null
def get_order_dir(field, default='ASC'): def get_order_dir(field, default='ASC'):
""" """
Returns the field name and direction for an order specification. For Returns the field name and direction for an order specification. For
@ -1900,6 +1898,7 @@ def add_to_dict(data, key, value):
else: else:
data[key] = set([value]) data[key] = set([value])
def is_reverse_o2o(field): def is_reverse_o2o(field):
""" """
A little helper to check if the given field is reverse-o2o. The field is A little helper to check if the given field is reverse-o2o. The field is
@ -1907,6 +1906,7 @@ def is_reverse_o2o(field):
""" """
return not hasattr(field, 'rel') and field.field.unique return not hasattr(field, 'rel') and field.field.unique
def alias_diff(refcounts_before, refcounts_after): def alias_diff(refcounts_before, refcounts_after):
""" """
Given the before and after copies of refcounts works out which aliases Given the before and after copies of refcounts works out which aliases

View File

@ -7,7 +7,7 @@ from django.core.exceptions import FieldError
from django.db import connections from django.db import connections
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist
from django.db.models.sql.constants import * from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, SelectInfo
from django.db.models.sql.datastructures import Date, DateTime from django.db.models.sql.datastructures import Date, DateTime
from django.db.models.sql.query import Query from django.db.models.sql.query import Query
from django.db.models.sql.where import AND, Constraint from django.db.models.sql.where import AND, Constraint
@ -20,6 +20,7 @@ from django.utils import timezone
__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery', __all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery',
'DateTimeQuery', 'AggregateQuery'] 'DateTimeQuery', 'AggregateQuery']
class DeleteQuery(Query): class DeleteQuery(Query):
""" """
Delete queries are done through this class, since they are more constrained Delete queries are done through this class, since they are more constrained
@ -77,7 +78,9 @@ class DeleteQuery(Query):
return return
else: else:
innerq.clear_select_clause() innerq.clear_select_clause()
innerq.select = [SelectInfo((self.get_initial_alias(), pk.column), None)] innerq.select = [
SelectInfo((self.get_initial_alias(), pk.column), None)
]
values = innerq values = innerq
where = self.where_class() where = self.where_class()
where.add((Constraint(None, pk.column, pk), 'in', values), AND) where.add((Constraint(None, pk.column, pk), 'in', values), AND)
@ -178,6 +181,7 @@ class UpdateQuery(Query):
result.append(query) result.append(query)
return result return result
class InsertQuery(Query): class InsertQuery(Query):
compiler = 'SQLInsertCompiler' compiler = 'SQLInsertCompiler'
@ -215,6 +219,7 @@ class InsertQuery(Query):
self.objs = objs self.objs = objs
self.raw = raw self.raw = raw
class DateQuery(Query): class DateQuery(Query):
""" """
A DateQuery is a normal query, except that it specifically selects a single A DateQuery is a normal query, except that it specifically selects a single
@ -260,6 +265,7 @@ class DateQuery(Query):
def _get_select(self, col, lookup_type): def _get_select(self, col, lookup_type):
return Date(col, lookup_type) return Date(col, lookup_type)
class DateTimeQuery(DateQuery): class DateTimeQuery(DateQuery):
""" """
A DateTimeQuery is like a DateQuery but for a datetime field. If time zone A DateTimeQuery is like a DateQuery but for a datetime field. If time zone
@ -280,6 +286,7 @@ class DateTimeQuery(DateQuery):
tzname = timezone._get_timezone_name(self.tzinfo) tzname = timezone._get_timezone_name(self.tzinfo)
return DateTime(col, lookup_type, tzname) return DateTime(col, lookup_type, tzname)
class AggregateQuery(Query): class AggregateQuery(Query):
""" """
An AggregateQuery takes another query as a parameter to the FROM An AggregateQuery takes another query as a parameter to the FROM

View File

@ -16,10 +16,12 @@ from django.utils.six.moves import xrange
from django.utils import timezone from django.utils import timezone
from django.utils import tree from django.utils import tree
# Connection types # Connection types
AND = 'AND' AND = 'AND'
OR = 'OR' OR = 'OR'
class EmptyShortCircuit(Exception): class EmptyShortCircuit(Exception):
""" """
Internal exception used to indicate that a "matches nothing" node should be Internal exception used to indicate that a "matches nothing" node should be
@ -27,6 +29,7 @@ class EmptyShortCircuit(Exception):
""" """
pass pass
class WhereNode(tree.Node): class WhereNode(tree.Node):
""" """
Used to represent the SQL where-clause. Used to represent the SQL where-clause.
@ -304,14 +307,15 @@ class WhereNode(tree.Node):
clone.children.append(child) clone.children.append(child)
return clone return clone
class EmptyWhere(WhereNode):
class EmptyWhere(WhereNode):
def add(self, data, connector): def add(self, data, connector):
return return
def as_sql(self, qn=None, connection=None): def as_sql(self, qn=None, connection=None):
raise EmptyResultSet raise EmptyResultSet
class EverythingNode(object): class EverythingNode(object):
""" """
A node that matches everything. A node that matches everything.
@ -385,6 +389,7 @@ class Constraint(object):
new.alias, new.col, new.field = change_map[self.alias], self.col, self.field new.alias, new.col, new.field = change_map[self.alias], self.col, self.field
return new return new
class SubqueryConstraint(object): class SubqueryConstraint(object):
def __init__(self, alias, columns, targets, query_object): def __init__(self, alias, columns, targets, query_object):
self.alias = alias self.alias = alias

View File

@ -27,6 +27,7 @@ class TransactionManagementError(Exception):
""" """
pass pass
################ ################
# Private APIs # # Private APIs #
################ ################
@ -40,6 +41,7 @@ def get_connection(using=None):
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
return connections[using] return connections[using]
########################### ###########################
# Deprecated private APIs # # Deprecated private APIs #
########################### ###########################
@ -56,6 +58,7 @@ def abort(using=None):
""" """
get_connection(using).abort() get_connection(using).abort()
def enter_transaction_management(managed=True, using=None, forced=False): def enter_transaction_management(managed=True, using=None, forced=False):
""" """
Enters transaction management for a running thread. It must be balanced with Enters transaction management for a running thread. It must be balanced with
@ -68,6 +71,7 @@ def enter_transaction_management(managed=True, using=None, forced=False):
""" """
get_connection(using).enter_transaction_management(managed, forced) get_connection(using).enter_transaction_management(managed, forced)
def leave_transaction_management(using=None): def leave_transaction_management(using=None):
""" """
Leaves transaction management for a running thread. A dirty flag is carried Leaves transaction management for a running thread. A dirty flag is carried
@ -76,6 +80,7 @@ def leave_transaction_management(using=None):
""" """
get_connection(using).leave_transaction_management() get_connection(using).leave_transaction_management()
def is_dirty(using=None): def is_dirty(using=None):
""" """
Returns True if the current transaction requires a commit for changes to Returns True if the current transaction requires a commit for changes to
@ -83,6 +88,7 @@ def is_dirty(using=None):
""" """
return get_connection(using).is_dirty() return get_connection(using).is_dirty()
def set_dirty(using=None): def set_dirty(using=None):
""" """
Sets a dirty flag for the current thread and code streak. This can be used Sets a dirty flag for the current thread and code streak. This can be used
@ -91,6 +97,7 @@ def set_dirty(using=None):
""" """
get_connection(using).set_dirty() get_connection(using).set_dirty()
def set_clean(using=None): def set_clean(using=None):
""" """
Resets a dirty flag for the current thread and code streak. This can be used Resets a dirty flag for the current thread and code streak. This can be used
@ -99,22 +106,27 @@ def set_clean(using=None):
""" """
get_connection(using).set_clean() get_connection(using).set_clean()
def is_managed(using=None): def is_managed(using=None):
warnings.warn("'is_managed' is deprecated.", warnings.warn("'is_managed' is deprecated.",
DeprecationWarning, stacklevel=2) DeprecationWarning, stacklevel=2)
def managed(flag=True, using=None): def managed(flag=True, using=None):
warnings.warn("'managed' no longer serves a purpose.", warnings.warn("'managed' no longer serves a purpose.",
DeprecationWarning, stacklevel=2) DeprecationWarning, stacklevel=2)
def commit_unless_managed(using=None): def commit_unless_managed(using=None):
warnings.warn("'commit_unless_managed' is now a no-op.", warnings.warn("'commit_unless_managed' is now a no-op.",
DeprecationWarning, stacklevel=2) DeprecationWarning, stacklevel=2)
def rollback_unless_managed(using=None): def rollback_unless_managed(using=None):
warnings.warn("'rollback_unless_managed' is now a no-op.", warnings.warn("'rollback_unless_managed' is now a no-op.",
DeprecationWarning, stacklevel=2) DeprecationWarning, stacklevel=2)
############### ###############
# Public APIs # # Public APIs #
############### ###############
@ -125,24 +137,28 @@ def get_autocommit(using=None):
""" """
return get_connection(using).get_autocommit() return get_connection(using).get_autocommit()
def set_autocommit(autocommit, using=None): def set_autocommit(autocommit, using=None):
""" """
Set the autocommit status of the connection. Set the autocommit status of the connection.
""" """
return get_connection(using).set_autocommit(autocommit) return get_connection(using).set_autocommit(autocommit)
def commit(using=None): def commit(using=None):
""" """
Commits a transaction and resets the dirty flag. Commits a transaction and resets the dirty flag.
""" """
get_connection(using).commit() get_connection(using).commit()
def rollback(using=None): def rollback(using=None):
""" """
Rolls back a transaction and resets the dirty flag. Rolls back a transaction and resets the dirty flag.
""" """
get_connection(using).rollback() get_connection(using).rollback()
def savepoint(using=None): def savepoint(using=None):
""" """
Creates a savepoint (if supported and required by the backend) inside the Creates a savepoint (if supported and required by the backend) inside the
@ -151,6 +167,7 @@ def savepoint(using=None):
""" """
return get_connection(using).savepoint() return get_connection(using).savepoint()
def savepoint_rollback(sid, using=None): def savepoint_rollback(sid, using=None):
""" """
Rolls back the most recent savepoint (if one exists). Does nothing if Rolls back the most recent savepoint (if one exists). Does nothing if
@ -158,6 +175,7 @@ def savepoint_rollback(sid, using=None):
""" """
get_connection(using).savepoint_rollback(sid) get_connection(using).savepoint_rollback(sid)
def savepoint_commit(sid, using=None): def savepoint_commit(sid, using=None):
""" """
Commits the most recent savepoint (if one exists). Does nothing if Commits the most recent savepoint (if one exists). Does nothing if
@ -165,18 +183,21 @@ def savepoint_commit(sid, using=None):
""" """
get_connection(using).savepoint_commit(sid) get_connection(using).savepoint_commit(sid)
def clean_savepoints(using=None): def clean_savepoints(using=None):
""" """
Resets the counter used to generate unique savepoint ids in this thread. Resets the counter used to generate unique savepoint ids in this thread.
""" """
get_connection(using).clean_savepoints() get_connection(using).clean_savepoints()
def get_rollback(using=None): def get_rollback(using=None):
""" """
Gets the "needs rollback" flag -- for *advanced use* only. Gets the "needs rollback" flag -- for *advanced use* only.
""" """
return get_connection(using).get_rollback() return get_connection(using).get_rollback()
def set_rollback(rollback, using=None): def set_rollback(rollback, using=None):
""" """
Sets or unsets the "needs rollback" flag -- for *advanced use* only. Sets or unsets the "needs rollback" flag -- for *advanced use* only.
@ -191,6 +212,7 @@ def set_rollback(rollback, using=None):
""" """
return get_connection(using).set_rollback(rollback) return get_connection(using).set_rollback(rollback)
################################# #################################
# Decorators / context managers # # Decorators / context managers #
################################# #################################
@ -398,6 +420,7 @@ class Transaction(object):
return func(*args, **kwargs) return func(*args, **kwargs)
return inner return inner
def _transaction_func(entering, exiting, using): def _transaction_func(entering, exiting, using):
""" """
Takes 3 things, an entering function (what to do to start this block of Takes 3 things, an entering function (what to do to start this block of
@ -436,6 +459,7 @@ def autocommit(using=None):
return _transaction_func(entering, exiting, using) return _transaction_func(entering, exiting, using)
def commit_on_success(using=None): def commit_on_success(using=None):
""" """
This decorator activates commit on response. This way, if the view function This decorator activates commit on response. This way, if the view function
@ -466,6 +490,7 @@ def commit_on_success(using=None):
return _transaction_func(entering, exiting, using) return _transaction_func(entering, exiting, using)
def commit_manually(using=None): def commit_manually(using=None):
""" """
Decorator that activates manual transaction control. It just disables Decorator that activates manual transaction control. It just disables
@ -484,6 +509,7 @@ def commit_manually(using=None):
return _transaction_func(entering, exiting, using) return _transaction_func(entering, exiting, using)
def commit_on_success_unless_managed(using=None, savepoint=False): def commit_on_success_unless_managed(using=None, savepoint=False):
""" """
Transitory API to preserve backwards-compatibility while refactoring. Transitory API to preserve backwards-compatibility while refactoring.