""" MySQL database backend for Django. Requires mysqlclient: https://pypi.python.org/pypi/mysqlclient/ """ import re from django.core.exceptions import ImproperlyConfigured from django.db import utils from django.db.backends import utils as backend_utils from django.db.backends.base.base import BaseDatabaseWrapper from django.utils.functional import cached_property try: import MySQLdb as Database except ImportError as err: raise ImproperlyConfigured( 'Error loading MySQLdb module.\n' 'Did you install mysqlclient?' ) from err from MySQLdb.constants import CLIENT, FIELD_TYPE # isort:skip from MySQLdb.converters import conversions # isort:skip # Some of these import MySQLdb, so import them after checking if it's installed. from .client import DatabaseClient # isort:skip from .creation import DatabaseCreation # isort:skip from .features import DatabaseFeatures # isort:skip from .introspection import DatabaseIntrospection # isort:skip from .operations import DatabaseOperations # isort:skip from .schema import DatabaseSchemaEditor # isort:skip from .validation import DatabaseValidation # isort:skip version = Database.version_info if version < (1, 3, 7): raise ImproperlyConfigured('mysqlclient 1.3.7 or newer is required; you have %s.' % Database.__version__) # MySQLdb returns TIME columns as timedelta -- they are more like timedelta in # terms of actual behavior as they are signed and include days -- and Django # expects time. django_conversions = { **conversions, **{FIELD_TYPE.TIME: backend_utils.typecast_time}, } # This should match the numerical portion of the version numbers (we can treat # versions like 5.0.24 and 5.0.24a as the same). server_version_re = re.compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})') class CursorWrapper: """ A thin wrapper around MySQLdb's normal cursor class that catches particular exception instances and reraises them with the correct types. Implemented as a wrapper, rather than a subclass, so that it isn't stuck to the particular underlying representation returned by Connection.cursor(). """ codes_for_integrityerror = ( 1048, # Column cannot be null 1690, # BIGINT UNSIGNED value is out of range ) def __init__(self, cursor): self.cursor = cursor def execute(self, query, args=None): try: # args is None means no string interpolation return self.cursor.execute(query, args) except Database.OperationalError as e: # Map some error codes to IntegrityError, since they seem to be # misclassified and Django would prefer the more logical place. if e.args[0] in self.codes_for_integrityerror: raise utils.IntegrityError(*tuple(e.args)) raise def executemany(self, query, args): try: return self.cursor.executemany(query, args) except Database.OperationalError as e: # Map some error codes to IntegrityError, since they seem to be # misclassified and Django would prefer the more logical place. if e.args[0] in self.codes_for_integrityerror: raise utils.IntegrityError(*tuple(e.args)) raise def __getattr__(self, attr): return getattr(self.cursor, attr) def __iter__(self): return iter(self.cursor) class DatabaseWrapper(BaseDatabaseWrapper): vendor = 'mysql' display_name = 'MySQL' # This dictionary maps Field objects to their associated MySQL column # types, as strings. Column-type strings can contain format strings; they'll # be interpolated against the values of Field.__dict__ before being output. # If a column type is set to None, it won't be included in the output. data_types = { 'AutoField': 'integer AUTO_INCREMENT', 'BigAutoField': 'bigint AUTO_INCREMENT', 'BinaryField': 'longblob', 'BooleanField': 'bool', 'CharField': 'varchar(%(max_length)s)', 'DateField': 'date', 'DateTimeField': 'datetime(6)', 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', 'DurationField': 'bigint', 'FileField': 'varchar(%(max_length)s)', 'FilePathField': 'varchar(%(max_length)s)', 'FloatField': 'double precision', 'IntegerField': 'integer', 'BigIntegerField': 'bigint', 'IPAddressField': 'char(15)', 'GenericIPAddressField': 'char(39)', 'NullBooleanField': 'bool', 'OneToOneField': 'integer', 'PositiveIntegerField': 'integer UNSIGNED', 'PositiveSmallIntegerField': 'smallint UNSIGNED', 'SlugField': 'varchar(%(max_length)s)', 'SmallIntegerField': 'smallint', 'TextField': 'longtext', 'TimeField': 'time(6)', 'UUIDField': 'char(32)', } # For these columns, MySQL doesn't: # - accept default values and implicitly treats these columns as nullable # - support a database index _limited_data_types = ( 'tinyblob', 'blob', 'mediumblob', 'longblob', 'tinytext', 'text', 'mediumtext', 'longtext', 'json', ) operators = { 'exact': '= %s', 'iexact': 'LIKE %s', 'contains': 'LIKE BINARY %s', 'icontains': 'LIKE %s', 'regex': 'REGEXP BINARY %s', 'iregex': 'REGEXP %s', 'gt': '> %s', 'gte': '>= %s', 'lt': '< %s', 'lte': '<= %s', 'startswith': 'LIKE BINARY %s', 'endswith': 'LIKE BINARY %s', 'istartswith': 'LIKE %s', 'iendswith': 'LIKE %s', } # The patterns below are used to generate SQL pattern lookup clauses when # the right-hand side of the lookup isn't a raw string (it might be an expression # or the result of a bilateral transformation). # In those cases, special characters for LIKE operators (e.g. \, *, _) should be # escaped on database side. # # Note: we use str.format() here for readability as '%' is used as a wildcard for # the LIKE operator. pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')" pattern_ops = { 'contains': "LIKE BINARY CONCAT('%%', {}, '%%')", 'icontains': "LIKE CONCAT('%%', {}, '%%')", 'startswith': "LIKE BINARY CONCAT({}, '%%')", 'istartswith': "LIKE CONCAT({}, '%%')", 'endswith': "LIKE BINARY CONCAT('%%', {})", 'iendswith': "LIKE CONCAT('%%', {})", } isolation_levels = { 'read uncommitted', 'read committed', 'repeatable read', 'serializable', } Database = Database SchemaEditorClass = DatabaseSchemaEditor # Classes instantiated in __init__(). client_class = DatabaseClient creation_class = DatabaseCreation features_class = DatabaseFeatures introspection_class = DatabaseIntrospection ops_class = DatabaseOperations validation_class = DatabaseValidation def get_connection_params(self): kwargs = { 'conv': django_conversions, 'charset': 'utf8', } settings_dict = self.settings_dict if settings_dict['USER']: kwargs['user'] = settings_dict['USER'] if settings_dict['NAME']: kwargs['db'] = settings_dict['NAME'] if settings_dict['PASSWORD']: kwargs['passwd'] = settings_dict['PASSWORD'] if settings_dict['HOST'].startswith('/'): kwargs['unix_socket'] = settings_dict['HOST'] elif settings_dict['HOST']: kwargs['host'] = settings_dict['HOST'] if settings_dict['PORT']: kwargs['port'] = int(settings_dict['PORT']) # We need the number of potentially affected rows after an # "UPDATE", not the number of changed rows. kwargs['client_flag'] = CLIENT.FOUND_ROWS # Validate the transaction isolation level, if specified. options = settings_dict['OPTIONS'].copy() isolation_level = options.pop('isolation_level', 'read committed') if isolation_level: isolation_level = isolation_level.lower() if isolation_level not in self.isolation_levels: raise ImproperlyConfigured( "Invalid transaction isolation level '%s' specified.\n" "Use one of %s, or None." % ( isolation_level, ', '.join("'%s'" % s for s in sorted(self.isolation_levels)) )) self.isolation_level = isolation_level kwargs.update(options) return kwargs def get_new_connection(self, conn_params): return Database.connect(**conn_params) def init_connection_state(self): assignments = [] if self.features.is_sql_auto_is_null_enabled: # SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on # a recently inserted row will return when the field is tested # for NULL. Disabling this brings this aspect of MySQL in line # with SQL standards. assignments.append('SET SQL_AUTO_IS_NULL = 0') if self.isolation_level: assignments.append('SET SESSION TRANSACTION ISOLATION LEVEL %s' % self.isolation_level.upper()) if assignments: with self.cursor() as cursor: cursor.execute('; '.join(assignments)) def create_cursor(self, name=None): cursor = self.connection.cursor() return CursorWrapper(cursor) def _rollback(self): try: BaseDatabaseWrapper._rollback(self) except Database.NotSupportedError: pass def _set_autocommit(self, autocommit): with self.wrap_database_errors: self.connection.autocommit(autocommit) def disable_constraint_checking(self): """ Disable foreign key checks, primarily for use in adding rows with forward references. Always return True to indicate constraint checks need to be re-enabled. """ self.cursor().execute('SET foreign_key_checks=0') return True def enable_constraint_checking(self): """ Re-enable foreign key checks after they have been disabled. """ # Override needs_rollback in case constraint_checks_disabled is # nested inside transaction.atomic. self.needs_rollback, needs_rollback = False, self.needs_rollback try: self.cursor().execute('SET foreign_key_checks=1') finally: self.needs_rollback = needs_rollback def check_constraints(self, table_names=None): """ Check each table name in `table_names` for rows with invalid foreign key references. This method is intended to be used in conjunction with `disable_constraint_checking()` and `enable_constraint_checking()`, to determine if rows with invalid references were entered while constraint checks were off. """ with self.cursor() as cursor: if table_names is None: table_names = self.introspection.table_names(cursor) for table_name in table_names: primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name) if not primary_key_column_name: continue key_columns = self.introspection.get_key_columns(cursor, table_name) for column_name, referenced_table_name, referenced_column_name in key_columns: cursor.execute( """ SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING LEFT JOIN `%s` as REFERRED ON (REFERRING.`%s` = REFERRED.`%s`) WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL """ % ( primary_key_column_name, column_name, table_name, referenced_table_name, column_name, referenced_column_name, column_name, referenced_column_name, ) ) for bad_row in cursor.fetchall(): raise utils.IntegrityError( "The row in table '%s' with primary key '%s' has an invalid " "foreign key: %s.%s contains a value '%s' that does not " "have a corresponding value in %s.%s." % ( table_name, bad_row[0], table_name, column_name, bad_row[1], referenced_table_name, referenced_column_name, ) ) def is_usable(self): try: self.connection.ping() except Database.Error: return False else: return True @cached_property def mysql_version(self): with self.temporary_connection() as cursor: cursor.execute('SELECT VERSION()') server_info = cursor.fetchone()[0] match = server_version_re.match(server_info) if not match: raise Exception('Unable to determine MySQL version from version string %r' % server_info) return tuple(int(x) for x in match.groups())