diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 01b8a550d1..399d16600c 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -955,7 +955,9 @@ class OneToOneRel(ManyToOneRel): class ManyToManyRel(object): def __init__(self, to, related_name=None, limit_choices_to=None, - symmetrical=True, through=None): + symmetrical=True, through=None, db_constraint=True): + if through and not db_constraint: + raise ValueError("Can't supply a through model and db_constraint=False") self.to = to self.related_name = related_name if limit_choices_to is None: @@ -964,6 +966,7 @@ class ManyToManyRel(object): self.symmetrical = symmetrical self.multiple = True self.through = through + self.db_constraint = db_constraint def is_hidden(self): "Should the related object be hidden?" @@ -1196,15 +1199,15 @@ def create_many_to_many_intermediary_model(field, klass): return type(name, (models.Model,), { 'Meta': meta, '__module__': klass.__module__, - from_: models.ForeignKey(klass, related_name='%s+' % name, db_tablespace=field.db_tablespace), - to: models.ForeignKey(to_model, related_name='%s+' % name, db_tablespace=field.db_tablespace) + from_: models.ForeignKey(klass, related_name='%s+' % name, db_tablespace=field.db_tablespace, db_constraint=field.rel.db_constraint), + to: models.ForeignKey(to_model, related_name='%s+' % name, db_tablespace=field.db_tablespace, db_constraint=field.rel.db_constraint) }) class ManyToManyField(RelatedField, Field): description = _("Many-to-many relationship") - def __init__(self, to, **kwargs): + def __init__(self, to, db_constraint=True, **kwargs): try: assert not to._meta.abstract, "%s cannot define a relation with abstract class %s" % (self.__class__.__name__, to._meta.object_name) except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT @@ -1219,13 +1222,15 @@ class ManyToManyField(RelatedField, Field): related_name=kwargs.pop('related_name', None), limit_choices_to=kwargs.pop('limit_choices_to', None), symmetrical=kwargs.pop('symmetrical', to == RECURSIVE_RELATIONSHIP_CONSTANT), - through=kwargs.pop('through', None)) + through=kwargs.pop('through', None), + db_constraint=db_constraint, + ) self.db_table = kwargs.pop('db_table', None) if kwargs['rel'].through is not None: assert self.db_table is None, "Cannot specify a db_table if an intermediary model is used." - Field.__init__(self, **kwargs) + super(ManyToManyField, self).__init__(**kwargs) msg = _('Hold down "Control", or "Command" on a Mac, to select more than one.') self.help_text = string_concat(self.help_text, ' ', msg) diff --git a/docs/ref/models/fields.txt b/docs/ref/models/fields.txt index 1dbc8c3998..1b80196183 100644 --- a/docs/ref/models/fields.txt +++ b/docs/ref/models/fields.txt @@ -1227,6 +1227,20 @@ that control how the relationship functions. the table for the model defining the relationship and the name of the field itself. +.. attribute:: ManyToManyField.db_constraint + + Controls whether or not constraints should be created in the database for + the foreign keys in the intermediary table. The default is ``True``, and + that's almost certainly what you want; setting this to ``False`` can be + very bad for data integrity. That said, here are some scenarios where you + might want to do this: + + * You have legacy data that is not valid. + * You're sharding your database. + + It is an error to pass both ``db_constraint`` and ``through``. + + .. _ref-onetoone: ``OneToOneField`` diff --git a/docs/releases/1.6.txt b/docs/releases/1.6.txt index 599da6847c..81b1e48d25 100644 --- a/docs/releases/1.6.txt +++ b/docs/releases/1.6.txt @@ -113,8 +113,8 @@ Minor features * The ``MemcachedCache`` cache backend now uses the latest :mod:`pickle` protocol available. -* Added the :attr:`django.db.models.ForeignKey.db_constraint` - option. +* Added the :attr:`django.db.models.ForeignKey.db_constraint` and + :attr:`django.db.models.ManyToManyField.db_constraint` options. * The jQuery library embedded in the admin has been upgraded to version 1.9.1. diff --git a/tests/backends/models.py b/tests/backends/models.py index 5876cbe52d..94be36cfaf 100644 --- a/tests/backends/models.py +++ b/tests/backends/models.py @@ -90,7 +90,10 @@ class Item(models.Model): @python_2_unicode_compatible class Object(models.Model): - pass + related_objects = models.ManyToManyField("self", db_constraint=False, symmetrical=False) + + def __str__(self): + return str(self.id) @python_2_unicode_compatible diff --git a/tests/backends/tests.py b/tests/backends/tests.py index 7c68863f0b..103a44684e 100644 --- a/tests/backends/tests.py +++ b/tests/backends/tests.py @@ -12,13 +12,12 @@ from django.db import (backend, connection, connections, DEFAULT_DB_ALIAS, from django.db.backends.signals import connection_created from django.db.backends.postgresql_psycopg2 import version as pg_version from django.db.models import Sum, Avg, Variance, StdDev -from django.db.utils import ConnectionHandler, DatabaseError +from django.db.utils import ConnectionHandler from django.test import (TestCase, skipUnlessDBFeature, skipIfDBFeature, TransactionTestCase) from django.test.utils import override_settings, str_prefix -from django.utils import six +from django.utils import six, unittest from django.utils.six.moves import xrange -from django.utils import unittest from . import models @@ -52,7 +51,7 @@ class OracleChecks(unittest.TestCase): convert_unicode = backend.convert_unicode cursor = connection.cursor() cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'), - [convert_unicode('_django_testing!'),]) + [convert_unicode('_django_testing!')]) @unittest.skipUnless(connection.vendor == 'oracle', "No need to check Oracle cursor semantics") @@ -72,7 +71,7 @@ class OracleChecks(unittest.TestCase): c = connection.cursor() c.execute('CREATE TABLE ltext ("TEXT" NCLOB)') long_str = ''.join([six.text_type(x) for x in xrange(4000)]) - c.execute('INSERT INTO ltext VALUES (%s)',[long_str]) + c.execute('INSERT INTO ltext VALUES (%s)', [long_str]) c.execute('SELECT text FROM ltext') row = c.fetchone() self.assertEqual(long_str, row[0].read()) @@ -99,6 +98,7 @@ class OracleChecks(unittest.TestCase): c.execute(query) self.assertEqual(c.fetchone()[0], 1) + class MySQLTests(TestCase): @unittest.skipUnless(connection.vendor == 'mysql', "Test valid only for MySQL") @@ -117,7 +117,7 @@ class MySQLTests(TestCase): found_reset = False for sql in statements: found_reset = found_reset or 'ALTER TABLE' in sql - if connection.mysql_version < (5,0,13): + if connection.mysql_version < (5, 0, 13): self.assertTrue(found_reset) else: self.assertFalse(found_reset) @@ -182,6 +182,7 @@ class LastExecutedQueryTest(TestCase): self.assertEqual(connection.queries[-1]['sql'], str_prefix("QUERY = %(_)s\"SELECT strftime('%%Y', 'now');\" - PARAMS = ()")) + class ParameterHandlingTest(TestCase): def test_bad_parameter_count(self): "An executemany call with too many/not enough parameters will raise an exception (Refs #12612)" @@ -191,8 +192,9 @@ class ParameterHandlingTest(TestCase): connection.ops.quote_name('root'), connection.ops.quote_name('square') )) - self.assertRaises(Exception, cursor.executemany, query, [(1,2,3),]) - self.assertRaises(Exception, cursor.executemany, query, [(1,),]) + self.assertRaises(Exception, cursor.executemany, query, [(1, 2, 3)]) + self.assertRaises(Exception, cursor.executemany, query, [(1,)]) + # Unfortunately, the following tests would be a good test to run on all # backends, but it breaks MySQL hard. Until #13711 is fixed, it can't be run @@ -240,6 +242,7 @@ class LongNameTest(TestCase): for statement in connection.ops.sql_flush(no_style(), tables, sequences): cursor.execute(statement) + class SequenceResetTest(TestCase): def test_generic_relation(self): "Sequence names are correct when resetting generic relations (Ref #13941)" @@ -257,6 +260,7 @@ class SequenceResetTest(TestCase): obj = models.Post.objects.create(name='New post', text='goodbye world') self.assertTrue(obj.pk > 10) + class PostgresVersionTest(TestCase): def assert_parses(self, version_string, version): self.assertEqual(pg_version._parse_version(version_string), version) @@ -291,6 +295,7 @@ class PostgresVersionTest(TestCase): conn = OlderConnectionMock() self.assertEqual(pg_version.get_version(conn), 80300) + class PostgresNewConnectionTest(TestCase): """ #17062: PostgreSQL shouldn't roll back SET TIME ZONE, even if the first @@ -338,17 +343,18 @@ class ConnectionCreatedSignalTest(TestCase): @skipUnlessDBFeature('test_db_allows_multiple_connections') def test_signal(self): data = {} + def receiver(sender, connection, **kwargs): data["connection"] = connection connection_created.connect(receiver) connection.close() - cursor = connection.cursor() + connection.cursor() self.assertTrue(data["connection"].connection is connection.connection) connection_created.disconnect(receiver) data.clear() - cursor = connection.cursor() + connection.cursor() self.assertTrue(data == {}) @@ -443,7 +449,7 @@ class BackendTestCase(TestCase): old_password = connection.settings_dict['PASSWORD'] connection.settings_dict['PASSWORD'] = "françois" try: - cursor = connection.cursor() + connection.cursor() except DatabaseError: # As password is probably wrong, a database exception is expected pass @@ -470,6 +476,7 @@ class BackendTestCase(TestCase): with self.assertRaises(DatabaseError): cursor.execute(query) + # We don't make these tests conditional because that means we would need to # check and differentiate between: # * MySQL+InnoDB, MySQL+MYISAM (something we currently can't do). @@ -477,7 +484,6 @@ class BackendTestCase(TestCase): # on or not, something that would be controlled by runtime support and user # preference. # verify if its type is django.database.db.IntegrityError. - class FkConstraintsTests(TransactionTestCase): def setUp(self): @@ -581,6 +587,7 @@ class ThreadTests(TestCase): connections_dict = {} connection.cursor() connections_dict[id(connection)] = connection + def runner(): # Passing django.db.connection between threads doesn't work while # connections[DEFAULT_DB_ALIAS] does. @@ -602,7 +609,7 @@ class ThreadTests(TestCase): # Finish by closing the connections opened by the other threads (the # connection opened in the main thread will automatically be closed on # teardown). - for conn in connections_dict.values() : + for conn in connections_dict.values(): if conn is not connection: conn.close() @@ -616,6 +623,7 @@ class ThreadTests(TestCase): connections_dict = {} for conn in connections.all(): connections_dict[id(conn)] = conn + def runner(): from django.db import connections for conn in connections.all(): @@ -682,6 +690,7 @@ class ThreadTests(TestCase): """ # First, without explicitly enabling the connection for sharing. exceptions = set() + def runner1(): def runner2(other_thread_connection): try: @@ -699,6 +708,7 @@ class ThreadTests(TestCase): # Then, with explicitly enabling the connection for sharing. exceptions = set() + def runner1(): def runner2(other_thread_connection): try: @@ -746,3 +756,14 @@ class DBConstraintTestCase(TransactionTestCase): with self.assertRaises(models.Object.DoesNotExist): ref.obj + + def test_many_to_many(self): + obj = models.Object.objects.create() + obj.related_objects.create() + self.assertEqual(models.Object.objects.count(), 2) + self.assertEqual(obj.related_objects.count(), 1) + + intermediary_model = models.Object._meta.get_field_by_name("related_objects")[0].rel.through + intermediary_model.objects.create(from_object_id=obj.id, to_object_id=12345) + self.assertEqual(obj.related_objects.count(), 1) + self.assertEqual(intermediary_model.objects.count(), 2)