django1/django/db/backends/sqlite3/operations.py

281 lines
12 KiB
Python

from __future__ import unicode_literals
import datetime
import uuid
from django.conf import settings
from django.core.exceptions import FieldError, ImproperlyConfigured
from django.db import utils
from django.db.backends import utils as backend_utils
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models import aggregates, fields
from django.utils import six, timezone
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.duration import duration_string
# pytz is an optional dependency: fall back to None when it is missing so
# code that needs it can test ``pytz is None`` and report a clear error.
try:
    import pytz
except ImportError:
    pytz = None
class DatabaseOperations(BaseDatabaseOperations):
    """
    SQLite implementation of the backend operations hooks.

    Several methods below emit calls to ``django_*`` SQL functions. Per the
    notes in this module, these are user-defined functions registered in
    connect() to stand in for features SQLite lacks (EXTRACT, DATE_TRUNC,
    a power function, ...).
    """

    def bulk_batch_size(self, fields, objs):
        """
        SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
        999 variables per query.

        If there is just single field to insert, then we can hit another
        limit, SQLITE_MAX_COMPOUND_SELECT which defaults to 500.

        Return the number of objects that may be inserted per batch. When
        ``fields`` is empty, all of ``objs`` go in a single batch.
        """
        limit = 999 if len(fields) > 1 else 500
        return (limit // len(fields)) if len(fields) > 0 else len(objs)

    def check_expression_support(self, expression):
        """
        Raise NotImplementedError if ``expression`` is a Sum, Avg, Variance,
        or StdDev aggregate over a date/time field.
        """
        bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField)
        bad_aggregates = (aggregates.Sum, aggregates.Avg, aggregates.Variance, aggregates.StdDev)
        if isinstance(expression, bad_aggregates):
            for expr in expression.get_source_expressions():
                try:
                    output_field = expr.output_field
                except FieldError:
                    # Not every subexpression has an output_field which is
                    # fine to ignore.
                    continue
                if isinstance(output_field, bad_fields):
                    raise NotImplementedError(
                        'You cannot use Sum, Avg, StdDev, and Variance '
                        'aggregations on date/time fields in sqlite3 '
                        'since date/time is saved as text.'
                    )

    def date_extract_sql(self, lookup_type, field_name):
        """Return SQL extracting ``lookup_type`` from the date ``field_name``."""
        # sqlite doesn't support extract, so we fake it with the user-defined
        # function django_date_extract that's registered in connect(). Note that
        # single quotes are used because this is a string (and could otherwise
        # cause a collision with a field name).
        return "django_date_extract('%s', %s)" % (lookup_type.lower(), field_name)

    def date_interval_sql(self, timedelta):
        """Return a ``(sql, params)`` pair with ``timedelta`` as a quoted duration string."""
        return "'%s'" % duration_string(timedelta), []

    def format_for_duration_arithmetic(self, sql):
        """Do nothing here, we will handle it in the custom function."""
        return sql

    def date_trunc_sql(self, lookup_type, field_name):
        """Return SQL truncating the date ``field_name`` to ``lookup_type``."""
        # sqlite doesn't support DATE_TRUNC, so we fake it with a user-defined
        # function django_date_trunc that's registered in connect(). Note that
        # single quotes are used because this is a string (and could otherwise
        # cause a collision with a field name).
        return "django_date_trunc('%s', %s)" % (lookup_type.lower(), field_name)

    def time_trunc_sql(self, lookup_type, field_name):
        """Return SQL truncating the time ``field_name`` to ``lookup_type``."""
        # sqlite has no built-in time truncation, so we fake it with a
        # user-defined function django_time_trunc that's registered in
        # connect(). Note that single quotes are used because this is a string
        # (and could otherwise cause a collision with a field name).
        return "django_time_trunc('%s', %s)" % (lookup_type.lower(), field_name)

    def _require_pytz(self):
        """Raise ImproperlyConfigured if USE_TZ is enabled but pytz is unavailable."""
        if settings.USE_TZ and pytz is None:
            raise ImproperlyConfigured("This query requires pytz, but it isn't installed.")

    def datetime_cast_date_sql(self, field_name, tzname):
        """Return ``(sql, params)`` casting the datetime ``field_name`` to a date in ``tzname``."""
        self._require_pytz()
        # %%s survives the interpolation as a %s placeholder bound to tzname.
        return "django_datetime_cast_date(%s, %%s)" % field_name, [tzname]

    def datetime_cast_time_sql(self, field_name, tzname):
        """Return ``(sql, params)`` casting the datetime ``field_name`` to a time in ``tzname``."""
        self._require_pytz()
        return "django_datetime_cast_time(%s, %%s)" % field_name, [tzname]

    def datetime_extract_sql(self, lookup_type, field_name, tzname):
        """Return ``(sql, params)`` extracting ``lookup_type`` from a datetime in ``tzname``."""
        # Same comment as in date_extract_sql.
        self._require_pytz()
        return "django_datetime_extract('%s', %s, %%s)" % (
            lookup_type.lower(), field_name), [tzname]

    def datetime_trunc_sql(self, lookup_type, field_name, tzname):
        """Return ``(sql, params)`` truncating a datetime to ``lookup_type`` in ``tzname``."""
        # Same comment as in date_trunc_sql.
        self._require_pytz()
        return "django_datetime_trunc('%s', %s, %%s)" % (
            lookup_type.lower(), field_name), [tzname]

    def time_extract_sql(self, lookup_type, field_name):
        """Return SQL extracting ``lookup_type`` from the time ``field_name``."""
        # sqlite doesn't support extract, so we fake it with the user-defined
        # function django_time_extract that's registered in connect(). Note that
        # single quotes are used because this is a string (and could otherwise
        # cause a collision with a field name).
        return "django_time_extract('%s', %s)" % (lookup_type.lower(), field_name)

    def pk_default_value(self):
        """Return the SQL literal used for a default primary key value."""
        return "NULL"

    def _quote_params_for_last_executed_query(self, params):
        """
        Only for last_executed_query! Don't use this to execute SQL queries!

        Return a tuple with each item of ``params`` quoted by SQLite's
        QUOTE() function.
        """
        # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
        # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
        # number of return values, default = 2000). Since Python's sqlite3
        # module doesn't expose the get_limit() C API, assume the default
        # limits are in effect and split the work in batches if needed.
        BATCH_SIZE = 999
        if len(params) > BATCH_SIZE:
            results = ()
            for index in range(0, len(params), BATCH_SIZE):
                chunk = params[index:index + BATCH_SIZE]
                results += self._quote_params_for_last_executed_query(chunk)
            return results

        sql = 'SELECT ' + ', '.join(['QUOTE(?)'] * len(params))
        # Bypass Django's wrappers and use the underlying sqlite3 connection
        # to avoid logging this query - it would trigger infinite recursion.
        cursor = self.connection.connection.cursor()
        # Native sqlite3 cursors cannot be used as context managers.
        try:
            return cursor.execute(sql, params).fetchone()
        finally:
            cursor.close()

    def last_executed_query(self, cursor, sql, params):
        """
        Return ``sql`` with ``params`` quoted and substituted in.

        ``params`` may be a list/tuple (positional) or a mapping (named).
        When ``params`` is empty, ``sql`` is returned unchanged.
        """
        # Python substitutes parameters in Modules/_sqlite/cursor.c with:
        # pysqlite_statement_bind_parameters(self->statement, parameters, allow_8bit_chars);
        # Unfortunately there is no way to reach self->statement from Python,
        # so we quote and substitute parameters manually.
        if params:
            if isinstance(params, (list, tuple)):
                params = self._quote_params_for_last_executed_query(params)
            else:
                keys = params.keys()
                values = tuple(params.values())
                values = self._quote_params_for_last_executed_query(values)
                params = dict(zip(keys, values))
            return sql % params
        # For consistency with SQLiteCursorWrapper.execute(), just return sql
        # when there are no parameters. See #13648 and #17158.
        else:
            return sql

    def quote_name(self, name):
        """Return ``name`` wrapped in double quotes unless it already is."""
        if name.startswith('"') and name.endswith('"'):
            return name  # Quoting once is enough.
        return '"%s"' % name

    def no_limit_value(self):
        """Return the LIMIT value that means "no limit"."""
        return -1

    def sql_flush(self, style, tables, sequences, allow_cascade=False):
        """
        Return a list of ``DELETE FROM`` statements, one per table in ``tables``.

        ``sequences`` and ``allow_cascade`` are accepted but not used here.
        """
        # NB: The generated SQL below is specific to SQLite
        # Note: The DELETE FROM... SQL generated below works for SQLite databases
        # because constraints don't exist
        sql = ['%s %s %s;' % (
            style.SQL_KEYWORD('DELETE'),
            style.SQL_KEYWORD('FROM'),
            style.SQL_FIELD(self.quote_name(table))
        ) for table in tables]
        # Note: No requirement for reset of auto-incremented indices (cf. other
        # sql_flush() implementations). Just return SQL at this point
        return sql

    def adapt_datetimefield_value(self, value):
        """
        Return ``value`` as text for storage, or None if ``value`` is None.

        Aware datetimes are made naive in the connection's time zone when
        USE_TZ is True; otherwise they raise ValueError.
        """
        if value is None:
            return None

        # SQLite doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            if settings.USE_TZ:
                value = timezone.make_naive(value, self.connection.timezone)
            else:
                raise ValueError("SQLite backend does not support timezone-aware datetimes when USE_TZ is False.")

        return six.text_type(value)

    def adapt_timefield_value(self, value):
        """
        Return ``value`` as text for storage, or None if ``value`` is None.

        Raise ValueError for timezone-aware times.
        """
        if value is None:
            return None

        # SQLite doesn't support tz-aware times
        if timezone.is_aware(value):
            raise ValueError("SQLite backend does not support timezone-aware times.")

        return six.text_type(value)

    def get_db_converters(self, expression):
        """
        Return the base converters plus the one matching the internal type
        of ``expression``'s output field, if any.
        """
        converters = super(DatabaseOperations, self).get_db_converters(expression)
        internal_type = expression.output_field.get_internal_type()
        if internal_type == 'DateTimeField':
            converters.append(self.convert_datetimefield_value)
        elif internal_type == 'DateField':
            converters.append(self.convert_datefield_value)
        elif internal_type == 'TimeField':
            converters.append(self.convert_timefield_value)
        elif internal_type == 'DecimalField':
            converters.append(self.convert_decimalfield_value)
        elif internal_type == 'UUIDField':
            converters.append(self.convert_uuidfield_value)
        return converters

    def convert_datetimefield_value(self, value, expression, connection, context):
        """
        Convert a database value to a datetime; None passes through.

        Non-datetime values are parsed with parse_datetime(). When USE_TZ is
        True, naive results are made aware in the connection's time zone.
        """
        if value is not None:
            if not isinstance(value, datetime.datetime):
                value = parse_datetime(value)
            # A stored string carrying a UTC offset parses to an aware
            # datetime; make_aware() rejects those, so only localize naive
            # values.
            if settings.USE_TZ and not timezone.is_aware(value):
                value = timezone.make_aware(value, self.connection.timezone)
        return value

    def convert_datefield_value(self, value, expression, connection, context):
        """Convert a database value to a date; None passes through."""
        if value is not None:
            if not isinstance(value, datetime.date):
                value = parse_date(value)
        return value

    def convert_timefield_value(self, value, expression, connection, context):
        """Convert a database value to a time; None passes through."""
        if value is not None:
            if not isinstance(value, datetime.time):
                value = parse_time(value)
        return value

    def convert_decimalfield_value(self, value, expression, connection, context):
        """
        Convert a database value to a decimal via the output field's
        format_number() and typecast_decimal(); None passes through.
        """
        if value is not None:
            value = expression.output_field.format_number(value)
            value = backend_utils.typecast_decimal(value)
        return value

    def convert_uuidfield_value(self, value, expression, connection, context):
        """Convert a database value to a uuid.UUID; None passes through."""
        if value is not None:
            value = uuid.UUID(value)
        return value

    def bulk_insert_sql(self, fields, placeholder_rows):
        """Return one ``SELECT ...`` per placeholder row, joined by ``UNION ALL``."""
        return " UNION ALL ".join(
            "SELECT %s" % ", ".join(row)
            for row in placeholder_rows
        )

    def combine_expression(self, connector, sub_expressions):
        """Combine ``sub_expressions`` with ``connector``, special-casing ``^``."""
        # SQLite doesn't have a power function, so we fake it with a
        # user-defined function django_power that's registered in connect().
        if connector == '^':
            return 'django_power(%s)' % ','.join(sub_expressions)
        return super(DatabaseOperations, self).combine_expression(connector, sub_expressions)

    def combine_duration_expression(self, connector, sub_expressions):
        """
        Return a ``django_format_dtdelta(...)`` call applying ``connector``
        to ``sub_expressions``.

        Raise DatabaseError for connectors other than ``+`` and ``-``, and
        ValueError when more than two sub-expressions are given.
        """
        if connector not in ['+', '-']:
            raise utils.DatabaseError('Invalid connector for timedelta: %s.' % connector)
        fn_params = ["'%s'" % connector] + sub_expressions
        if len(fn_params) > 3:
            raise ValueError('Too many params for timedelta operations.')
        return "django_format_dtdelta(%s)" % ', '.join(fn_params)

    def integer_field_range(self, internal_type):
        """Return ``(None, None)``: no range is reported for any integer type."""
        # SQLite doesn't enforce any integer constraints
        return (None, None)

    def subtract_temporals(self, internal_type, lhs, rhs):
        """
        Return ``(sql, params)`` for ``lhs - rhs``.

        ``lhs`` and ``rhs`` are ``(sql, params)`` pairs. TimeField uses
        django_time_diff; every other type uses django_timestamp_diff.
        """
        lhs_sql, lhs_params = lhs
        rhs_sql, rhs_params = rhs
        if internal_type == 'TimeField':
            return "django_time_diff(%s, %s)" % (lhs_sql, rhs_sql), lhs_params + rhs_params
        return "django_timestamp_diff(%s, %s)" % (lhs_sql, rhs_sql), lhs_params + rhs_params