django1/django/db/backends/utils.py

226 lines
7.1 KiB
Python

import datetime
import decimal
import hashlib
import logging
import re
from time import monotonic, time

from django.conf import settings
from django.utils.encoding import force_bytes
from django.utils.timezone import utc

# Module-level logger; CursorDebugWrapper emits one DEBUG record on it for
# every execute()/executemany() call.
logger = logging.getLogger('django.db.backends')
class CursorWrapper:
    """
    Proxy around a backend cursor.

    Attribute access falls through to the wrapped ``cursor``. The statement
    methods (callproc/execute/executemany) first call
    ``db.validate_no_broken_transaction()`` and then run the underlying call
    inside ``db.wrap_database_errors``.
    """

    # Cursor methods that are handed back wrapped by db.wrap_database_errors.
    WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset'])

    def __init__(self, cursor, db):
        self.cursor = cursor
        self.db = db

    def __getattr__(self, attr):
        """Delegate attributes not found on the wrapper to the cursor."""
        target = getattr(self.cursor, attr)
        if attr not in CursorWrapper.WRAP_ERROR_ATTRS:
            return target
        return self.db.wrap_database_errors(target)

    def __iter__(self):
        """Iterate over the wrapped cursor inside db.wrap_database_errors."""
        with self.db.wrap_database_errors:
            yield from self.cursor

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Close instead of passing through to avoid backend-specific behavior
        # (#17671). Catch errors liberally because errors in cleanup code
        # aren't useful.
        try:
            self.close()
        except self.db.Database.Error:
            pass

    # The following methods cannot be implemented in __getattr__, because the
    # code must run when the method is invoked, not just when it is accessed.

    def callproc(self, procname, params=None):
        """Call a stored procedure, passing ``params`` only when given."""
        self.db.validate_no_broken_transaction()
        call_args = (procname,) if params is None else (procname, params)
        with self.db.wrap_database_errors:
            return self.cursor.callproc(*call_args)

    def execute(self, sql, params=None):
        """Execute ``sql``, passing ``params`` only when given."""
        self.db.validate_no_broken_transaction()
        call_args = (sql,) if params is None else (sql, params)
        with self.db.wrap_database_errors:
            return self.cursor.execute(*call_args)

    def executemany(self, sql, param_list):
        """Execute ``sql`` against the wrapped cursor with ``param_list``."""
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            return self.cursor.executemany(sql, param_list)
class CursorDebugWrapper(CursorWrapper):
    """
    Cursor wrapper that records each execute()/executemany() call.

    After every call (whether it succeeded or raised) an entry holding the
    SQL and the elapsed time is appended to ``db.queries_log`` and a DEBUG
    record is sent to the module logger. Durations are measured with a
    monotonic clock so that system clock adjustments cannot skew them or
    make them negative.
    """

    # XXX callproc isn't instrumented at this time.

    def execute(self, sql, params=None):
        """Run the query via CursorWrapper.execute() and record it."""
        start = monotonic()
        try:
            return super().execute(sql, params)
        finally:
            # Runs on both success and failure so failing queries are logged.
            duration = monotonic() - start
            sql = self.db.ops.last_executed_query(self.cursor, sql, params)
            self.db.queries_log.append({
                'sql': sql,
                'time': "%.3f" % duration,
            })
            logger.debug(
                '(%.3f) %s; args=%s', duration, sql, params,
                extra={'duration': duration, 'sql': sql, 'params': params}
            )

    def executemany(self, sql, param_list):
        """Run the batch via CursorWrapper.executemany() and record it."""
        start = monotonic()
        try:
            return super().executemany(sql, param_list)
        finally:
            # Runs on both success and failure so failing batches are logged.
            duration = monotonic() - start
            try:
                times = len(param_list)
            except TypeError:  # param_list could be an iterator
                times = '?'
            self.db.queries_log.append({
                'sql': '%s times: %s' % (times, sql),
                'time': "%.3f" % duration,
            })
            logger.debug(
                '(%.3f) %s; args=%s', duration, sql, param_list,
                extra={'duration': duration, 'sql': sql, 'params': param_list}
            )
###############################################
# Converters from database (string) to Python #
###############################################
def typecast_date(s):
    """
    Convert a 'YYYY-MM-DD' string to a datetime.date.

    Return None when s is empty or None (a NULL value).
    """
    if not s:
        return None
    fields = [int(piece) for piece in s.split('-')]
    return datetime.date(*fields)
def typecast_time(s):  # does NOT store time zone information
    """
    Convert an 'HH:MM:SS[.ffffff]' string to a naive datetime.time.

    Return None when s is empty or None. The fractional part is padded or
    truncated to six digits (microseconds).
    """
    if not s:
        return None
    hour, minutes, seconds = s.split(':')
    # Separate the fractional part of the seconds, if there is one.
    whole, fraction = seconds.split('.') if '.' in seconds else (seconds, '0')
    micro = int(fraction.ljust(6, '0')[:6])
    return datetime.time(int(hour), int(minutes), int(whole), micro)
def typecast_timestamp(s):  # does NOT store time zone information
    """
    Convert a database timestamp string to a datetime.datetime.

    Return None when s is empty or None, and defer to typecast_date() when
    s contains no space (date only). A trailing UTC offset such as "-05" is
    discarded without adjusting the time. The result carries ``utc`` as
    tzinfo when settings.USE_TZ is true and is naive otherwise.
    """
    # "2005-07-29 15:48:00.590358-05"
    # "2005-07-29 09:56:00-05"
    if not s:
        return None
    if ' ' not in s:
        return typecast_date(s)
    d, t = s.split()
    # Remove timezone information, if it exists; it is not used.
    if '-' in t:
        t = t.split('-', 1)[0]
    elif '+' in t:
        t = t.split('+', 1)[0]
    dates = d.split('-')
    times = t.split(':')
    seconds = times[2]
    if '.' in seconds:  # check whether seconds have a fractional part
        seconds, microseconds = seconds.split('.')
    else:
        microseconds = '0'
    tzinfo = utc if settings.USE_TZ else None
    return datetime.datetime(
        int(dates[0]), int(dates[1]), int(dates[2]),
        int(times[0]), int(times[1]), int(seconds),
        # Pad/truncate the fraction to exactly six digits (microseconds).
        int((microseconds + '000000')[:6]), tzinfo
    )
def typecast_decimal(s):
    """Convert s to a decimal.Decimal; return None for None or ''."""
    is_null = s is None or s == ''
    return None if is_null else decimal.Decimal(s)
###############################################
# Converters from Python to database (string) #
###############################################
def rev_typecast_decimal(d):
    """Return str(d), or None when d is None."""
    return None if d is None else str(d)
def truncate_name(name, length=None, hash_len=4):
    """
    Shorten a string to a repeatable mangled version with the given length.
    If a quote stripped name contains a username, e.g. USERNAME"."TABLE,
    truncate the table portion only.

    Return name unchanged when length is None or the table portion already
    fits. Otherwise the table portion is cut to ``length - hash_len``
    characters and followed by the first ``hash_len`` hex digits of its MD5
    digest.
    """
    # Raw string: "\." is a regex escape, not a Python string escape.
    match = re.match(r'([^"]+)"\."([^"]+)', name)
    table_name = match.group(2) if match else name

    if length is None or len(table_name) <= length:
        return name

    hsh = hashlib.md5(force_bytes(table_name)).hexdigest()[:hash_len]
    # Keep the USERNAME"." prefix intact; only the table portion is mangled.
    prefix = match.group(1) + '"."' if match else ''
    return '%s%s%s' % (prefix, table_name[:length - hash_len], hsh)
def format_number(value, max_digits, decimal_places):
    """
    Render a number as a fixed-point string limited to the requested number
    of digits and decimal places.

    None is passed through unchanged. Decimal values are quantized (or, when
    decimal_places is None, re-created under a context that traps rounding);
    any other value goes through plain float formatting.
    """
    if value is None:
        return None

    if not isinstance(value, decimal.Decimal):
        if decimal_places is None:
            return "{:f}".format(value)
        return "%.*f" % (decimal_places, value)

    ctx = decimal.getcontext().copy()
    if max_digits is not None:
        ctx.prec = max_digits
    if decimal_places is None:
        # No target scale: refuse to silently lose digits.
        ctx.traps[decimal.Rounded] = 1
        value = ctx.create_decimal(value)
    else:
        exponent = decimal.Decimal(".1") ** decimal_places
        value = value.quantize(exponent, context=ctx)
    return "{:f}".format(value)
def strip_quotes(table_name):
    """
    Remove the enclosing double quotes from a quoted table name so it can be
    used in index names, sequence names, etc. For example '"USER"."TABLE"'
    (an Oracle naming scheme) becomes 'USER"."TABLE'. Names that are not
    both prefixed and suffixed with a double quote are returned unchanged.
    """
    if table_name.startswith('"') and table_name.endswith('"'):
        return table_name[1:-1]
    return table_name