from __future__ import unicode_literals import datetime import decimal import hashlib import logging from time import time from django.conf import settings from django.utils.encoding import force_bytes from django.utils.timezone import utc logger = logging.getLogger('django.db.backends') class CursorWrapper(object): def __init__(self, cursor, db): self.cursor = cursor self.db = db WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset']) def __getattr__(self, attr): cursor_attr = getattr(self.cursor, attr) if attr in CursorWrapper.WRAP_ERROR_ATTRS: return self.db.wrap_database_errors(cursor_attr) else: return cursor_attr def __iter__(self): return iter(self.cursor) def __enter__(self): return self def __exit__(self, type, value, traceback): # Ticket #17671 - Close instead of passing thru to avoid backend # specific behavior. self.close() # The following methods cannot be implemented in __getattr__, because the # code must run when the method is invoked, not just when it is accessed. def callproc(self, procname, params=None): self.db.validate_no_broken_transaction() self.db.set_dirty() with self.db.wrap_database_errors: if params is None: return self.cursor.callproc(procname) else: return self.cursor.callproc(procname, params) def execute(self, sql, params=None): self.db.validate_no_broken_transaction() self.db.set_dirty() with self.db.wrap_database_errors: if params is None: return self.cursor.execute(sql) else: return self.cursor.execute(sql, params) def executemany(self, sql, param_list): self.db.validate_no_broken_transaction() self.db.set_dirty() with self.db.wrap_database_errors: return self.cursor.executemany(sql, param_list) class CursorDebugWrapper(CursorWrapper): # XXX callproc isn't instrumented at this time. def execute(self, sql, params=None): start = time() try: return super(CursorDebugWrapper, self).execute(sql, params) finally: stop = time() duration = stop - start sql = self.db.ops.last_executed_query(self.cursor, sql, params) self.db.queries.append({ 'sql': sql, 'time': "%.3f" % duration, }) logger.debug('(%.3f) %s; args=%s' % (duration, sql, params), extra={'duration': duration, 'sql': sql, 'params': params} ) def executemany(self, sql, param_list): start = time() try: return super(CursorDebugWrapper, self).executemany(sql, param_list) finally: stop = time() duration = stop - start try: times = len(param_list) except TypeError: # param_list could be an iterator times = '?' self.db.queries.append({ 'sql': '%s times: %s' % (times, sql), 'time': "%.3f" % duration, }) logger.debug('(%.3f) %s; args=%s' % (duration, sql, param_list), extra={'duration': duration, 'sql': sql, 'params': param_list} ) ############################################### # Converters from database (string) to Python # ############################################### def typecast_date(s): return datetime.date(*map(int, s.split('-'))) if s else None # returns None if s is null def typecast_time(s): # does NOT store time zone information if not s: return None hour, minutes, seconds = s.split(':') if '.' in seconds: # check whether seconds have a fractional part seconds, microseconds = seconds.split('.') else: microseconds = '0' return datetime.time(int(hour), int(minutes), int(seconds), int(float('.' + microseconds) * 1000000)) def typecast_timestamp(s): # does NOT store time zone information # "2005-07-29 15:48:00.590358-05" # "2005-07-29 09:56:00-05" if not s: return None if not ' ' in s: return typecast_date(s) d, t = s.split() # Extract timezone information, if it exists. Currently we just throw # it away, but in the future we may make use of it. if '-' in t: t, tz = t.split('-', 1) tz = '-' + tz elif '+' in t: t, tz = t.split('+', 1) tz = '+' + tz else: tz = '' dates = d.split('-') times = t.split(':') seconds = times[2] if '.' in seconds: # check whether seconds have a fractional part seconds, microseconds = seconds.split('.') else: microseconds = '0' tzinfo = utc if settings.USE_TZ else None return datetime.datetime(int(dates[0]), int(dates[1]), int(dates[2]), int(times[0]), int(times[1]), int(seconds), int((microseconds + '000000')[:6]), tzinfo) def typecast_decimal(s): if s is None or s == '': return None return decimal.Decimal(s) ############################################### # Converters from Python to database (string) # ############################################### def rev_typecast_decimal(d): if d is None: return None return str(d) def truncate_name(name, length=None, hash_len=4): """Shortens a string to a repeatable mangled version with the given length. """ if length is None or len(name) <= length: return name hsh = hashlib.md5(force_bytes(name)).hexdigest()[:hash_len] return '%s%s' % (name[:length - hash_len], hsh) def format_number(value, max_digits, decimal_places): """ Formats a number into a string with the requisite number of digits and decimal places. """ if isinstance(value, decimal.Decimal): context = decimal.getcontext().copy() context.prec = max_digits return "{0:f}".format(value.quantize(decimal.Decimal(".1") ** decimal_places, context=context)) else: return "%.*f" % (decimal_places, value) # Map of vendor name -> map of query element class -> implementation function compile_implementations = {} def get_implementations(vendor): try: implementation = compile_implementations[vendor] except KeyError: # TODO: do we need thread safety here? We could easily use an lock... implementation = {} compile_implementations[vendor] = implementation return implementation class add_implementation(object): def __init__(self, klass, vendor): self.klass = klass self.vendor = vendor def __call__(self, func): implementations = get_implementations(self.vendor) implementations[self.klass] = func return func