import datetime
import decimal
import hashlib
import logging
import re
# ``time`` is no longer used below; it is retained so that the module keeps
# exposing the same names as before.
from time import monotonic, time  # noqa: F401

from django.conf import settings
from django.db.utils import NotSupportedError
from django.utils.encoding import force_bytes
from django.utils.timezone import utc

logger = logging.getLogger('django.db.backends')


class CursorWrapper:
    """
    Thin wrapper around a database driver cursor.

    Attribute access is delegated to the wrapped cursor. Statement execution
    (``callproc``, ``execute``, ``executemany``), iteration, and the fetch
    methods named in ``WRAP_ERROR_ATTRS`` are run under the connection's
    ``wrap_database_errors``.
    """

    def __init__(self, cursor, db):
        # cursor: the underlying driver cursor; db: the owning connection
        # wrapper (used for error wrapping, features, and validation).
        self.cursor = cursor
        self.db = db

    # Cursor methods whose invocation (not mere lookup) must be wrapped.
    WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset'])

    def __getattr__(self, attr):
        """
        Delegate unknown attributes to the wrapped cursor, wrapping the
        callables listed in ``WRAP_ERROR_ATTRS`` with
        ``self.db.wrap_database_errors``.
        """
        cursor_attr = getattr(self.cursor, attr)
        if attr in CursorWrapper.WRAP_ERROR_ATTRS:
            return self.db.wrap_database_errors(cursor_attr)
        else:
            return cursor_attr

    def __iter__(self):
        """Yield rows from the wrapped cursor under ``wrap_database_errors``."""
        with self.db.wrap_database_errors:
            yield from self.cursor

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Close instead of passing through to avoid backend-specific behavior
        # (#17671). Catch errors liberally because errors in cleanup code
        # aren't useful.
        try:
            self.close()
        except self.db.Database.Error:
            pass

    # The following methods cannot be implemented in __getattr__, because the
    # code must run when the method is invoked, not just when it is accessed.

    def callproc(self, procname, params=None, kparams=None):
        """
        Call the stored procedure ``procname`` on the wrapped cursor.

        ``params`` and ``kparams`` are forwarded only when given, so drivers
        whose ``callproc`` takes fewer arguments are called with a matching
        signature. Raise NotSupportedError if ``kparams`` is given but the
        backend's features don't declare support for keyword parameters.
        """
        # Keyword parameters for callproc aren't supported in PEP 249, but the
        # database driver may support them (e.g. cx_Oracle).
        if kparams is not None and not self.db.features.supports_callproc_kwargs:
            raise NotSupportedError(
                'Keyword parameters for callproc are not supported on this '
                'database backend.'
            )
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            if params is None and kparams is None:
                return self.cursor.callproc(procname)
            elif kparams is None:
                return self.cursor.callproc(procname, params)
            else:
                # kparams is passed positionally as the third argument, so
                # params needs a placeholder value when it was omitted.
                params = params or ()
                return self.cursor.callproc(procname, params, kparams)

    def execute(self, sql, params=None):
        """
        Execute ``sql`` on the wrapped cursor and return the driver's result.

        ``params`` is forwarded only when it is not None, so the driver's
        one-argument form is used for parameterless statements.
        """
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            if params is None:
                return self.cursor.execute(sql)
            else:
                return self.cursor.execute(sql, params)

    def executemany(self, sql, param_list):
        """Execute ``sql`` once per entry of ``param_list`` via the driver."""
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            return self.cursor.executemany(sql, param_list)


class CursorDebugWrapper(CursorWrapper):
    """
    CursorWrapper that records each statement and its duration.

    Every ``execute``/``executemany`` call appends a ``{'sql', 'time'}`` entry
    to ``self.db.queries_log`` and emits a debug message on the
    'django.db.backends' logger, whether or not the statement succeeded.
    """

    # XXX callproc isn't instrumented at this time.

    def execute(self, sql, params=None):
        """Execute like CursorWrapper.execute(), then log SQL and duration."""
        # A monotonic clock is used so that wall-clock adjustments during the
        # query cannot produce wrong or negative durations.
        start = monotonic()
        try:
            return super().execute(sql, params)
        finally:
            stop = monotonic()
            duration = stop - start
            # Log the query in the form reported by the backend's operations
            # object rather than the raw template.
            sql = self.db.ops.last_executed_query(self.cursor, sql, params)
            self.db.queries_log.append({
                'sql': sql,
                'time': "%.3f" % duration,
            })
            logger.debug(
                '(%.3f) %s; args=%s', duration, sql, params,
                extra={'duration': duration, 'sql': sql, 'params': params}
            )

    def executemany(self, sql, param_list):
        """Execute like CursorWrapper.executemany(), then log SQL and duration."""
        # Monotonic clock: see execute().
        start = monotonic()
        try:
            return super().executemany(sql, param_list)
        finally:
            stop = monotonic()
            duration = stop - start
            try:
                times = len(param_list)
            except TypeError:           # param_list could be an iterator
                times = '?'
            self.db.queries_log.append({
                'sql': '%s times: %s' % (times, sql),
                'time': "%.3f" % duration,
            })
            logger.debug(
                '(%.3f) %s; args=%s', duration, sql, param_list,
                extra={'duration': duration, 'sql': sql, 'params': param_list}
            )


###############################################
# Converters from database (string) to Python #
###############################################

def typecast_date(s):
    """Convert a 'YYYY-MM-DD' string to a datetime.date; falsy input gives None."""
    return datetime.date(*map(int, s.split('-'))) if s else None  # return None if s is null


def typecast_time(s):  # does NOT store time zone information
    """
    Convert an 'HH:MM:SS[.ffffff]' string to a datetime.time.

    Return None for falsy input. The fractional part is right-padded with
    zeros or truncated to six digits (microseconds).
    """
    if not s:
        return None
    hour, minutes, seconds = s.split(':')
    if '.' in seconds:  # check whether seconds have a fractional part
        seconds, microseconds = seconds.split('.')
    else:
        microseconds = '0'
    return datetime.time(int(hour), int(minutes), int(seconds), int((microseconds + '000000')[:6]))


def typecast_timestamp(s):  # does NOT store time zone information
    """
    Convert a timestamp string to a datetime.datetime.

    Return None for falsy input, and a datetime.date (via typecast_date) when
    the string has no space, i.e. no time part. Any UTC offset suffix on the
    time part is discarded; the result's tzinfo is ``utc`` when
    settings.USE_TZ is true and None otherwise.
    """
    # "2005-07-29 15:48:00.590358-05"
    # "2005-07-29 09:56:00-05"
    if not s:
        return None
    if ' ' not in s:
        return typecast_date(s)
    d, t = s.split()
    # Strip timezone information, if it exists. Currently it's ignored.
    if '-' in t:
        t = t.split('-', 1)[0]
    elif '+' in t:
        t = t.split('+', 1)[0]
    dates = d.split('-')
    times = t.split(':')
    seconds = times[2]
    if '.' in seconds:  # check whether seconds have a fractional part
        seconds, microseconds = seconds.split('.')
    else:
        microseconds = '0'
    tzinfo = utc if settings.USE_TZ else None
    return datetime.datetime(
        int(dates[0]), int(dates[1]), int(dates[2]),
        int(times[0]), int(times[1]), int(seconds),
        int((microseconds + '000000')[:6]), tzinfo
    )


###############################################
# Converters from Python to database (string) #
###############################################

def rev_typecast_decimal(d):
    """Return ``str(d)``, or None when ``d`` is None."""
    if d is None:
        return None
    return str(d)


def truncate_name(name, length=None, hash_len=4):
    """
    Shorten a string to a repeatable mangled version with the given length.
    If a quote stripped name contains a username, e.g. USERNAME"."TABLE,
    truncate the table portion only.

    The name is returned unchanged when ``length`` is None or the table
    portion already fits. Otherwise the table portion is cut to
    ``length - hash_len`` characters and the first ``hash_len`` hex digits of
    its MD5 digest are appended.
    """
    match = re.match(r'([^"]+)"\."([^"]+)', name)
    table_name = match.group(2) if match else name

    if length is None or len(table_name) <= length:
        return name

    hsh = hashlib.md5(force_bytes(table_name)).hexdigest()[:hash_len]
    return '%s%s%s' % (match.group(1) + '"."' if match else '', table_name[:length - hash_len], hsh)


def format_number(value, max_digits, decimal_places):
    """
    Format a number into a string with the requisite number of digits and
    decimal places.

    Return None for a None value. Decimal values are processed in a copy of
    the current decimal context whose precision is ``max_digits`` (when
    given): they are quantized to ``decimal_places`` when that is given;
    otherwise the Rounded trap is enabled, so a value that doesn't fit the
    precision raises instead of being silently rounded. Other numbers are
    formatted with ``decimal_places`` fractional digits, or with the default
    ``f`` format when ``decimal_places`` is None.
    """
    if value is None:
        return None
    if isinstance(value, decimal.Decimal):
        context = decimal.getcontext().copy()
        if max_digits is not None:
            context.prec = max_digits
        if decimal_places is not None:
            value = value.quantize(decimal.Decimal(".1") ** decimal_places, context=context)
        else:
            context.traps[decimal.Rounded] = 1
            value = context.create_decimal(value)
        return "{:f}".format(value)
    if decimal_places is not None:
        return "%.*f" % (decimal_places, value)
    return "{:f}".format(value)


def strip_quotes(table_name):
    """
    Strip quotes off of quoted table names to make them safe for use in index
    names, sequence names, etc. For example '"USER"."TABLE"' (an Oracle naming
    scheme) becomes 'USER"."TABLE'.
    """
    has_quotes = table_name.startswith('"') and table_name.endswith('"')
    return table_name[1:-1] if has_quotes else table_name