Added a ManyToManyField(db_constraint=False) option, this allows not creating constraints on the intermediary models.

This commit is contained in:
Alex Gaynor 2013-03-07 11:24:51 -08:00
parent 4cccb85e29
commit bbbd698c7a
5 changed files with 65 additions and 22 deletions

View File

@ -955,7 +955,9 @@ class OneToOneRel(ManyToOneRel):
class ManyToManyRel(object): class ManyToManyRel(object):
def __init__(self, to, related_name=None, limit_choices_to=None, def __init__(self, to, related_name=None, limit_choices_to=None,
symmetrical=True, through=None): symmetrical=True, through=None, db_constraint=True):
if through and not db_constraint:
raise ValueError("Can't supply a through model and db_constraint=False")
self.to = to self.to = to
self.related_name = related_name self.related_name = related_name
if limit_choices_to is None: if limit_choices_to is None:
@ -964,6 +966,7 @@ class ManyToManyRel(object):
self.symmetrical = symmetrical self.symmetrical = symmetrical
self.multiple = True self.multiple = True
self.through = through self.through = through
self.db_constraint = db_constraint
def is_hidden(self): def is_hidden(self):
"Should the related object be hidden?" "Should the related object be hidden?"
@ -1196,15 +1199,15 @@ def create_many_to_many_intermediary_model(field, klass):
return type(name, (models.Model,), { return type(name, (models.Model,), {
'Meta': meta, 'Meta': meta,
'__module__': klass.__module__, '__module__': klass.__module__,
from_: models.ForeignKey(klass, related_name='%s+' % name, db_tablespace=field.db_tablespace), from_: models.ForeignKey(klass, related_name='%s+' % name, db_tablespace=field.db_tablespace, db_constraint=field.rel.db_constraint),
to: models.ForeignKey(to_model, related_name='%s+' % name, db_tablespace=field.db_tablespace) to: models.ForeignKey(to_model, related_name='%s+' % name, db_tablespace=field.db_tablespace, db_constraint=field.rel.db_constraint)
}) })
class ManyToManyField(RelatedField, Field): class ManyToManyField(RelatedField, Field):
description = _("Many-to-many relationship") description = _("Many-to-many relationship")
def __init__(self, to, **kwargs): def __init__(self, to, db_constraint=True, **kwargs):
try: try:
assert not to._meta.abstract, "%s cannot define a relation with abstract class %s" % (self.__class__.__name__, to._meta.object_name) assert not to._meta.abstract, "%s cannot define a relation with abstract class %s" % (self.__class__.__name__, to._meta.object_name)
except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT
@ -1219,13 +1222,15 @@ class ManyToManyField(RelatedField, Field):
related_name=kwargs.pop('related_name', None), related_name=kwargs.pop('related_name', None),
limit_choices_to=kwargs.pop('limit_choices_to', None), limit_choices_to=kwargs.pop('limit_choices_to', None),
symmetrical=kwargs.pop('symmetrical', to == RECURSIVE_RELATIONSHIP_CONSTANT), symmetrical=kwargs.pop('symmetrical', to == RECURSIVE_RELATIONSHIP_CONSTANT),
through=kwargs.pop('through', None)) through=kwargs.pop('through', None),
db_constraint=db_constraint,
)
self.db_table = kwargs.pop('db_table', None) self.db_table = kwargs.pop('db_table', None)
if kwargs['rel'].through is not None: if kwargs['rel'].through is not None:
assert self.db_table is None, "Cannot specify a db_table if an intermediary model is used." assert self.db_table is None, "Cannot specify a db_table if an intermediary model is used."
Field.__init__(self, **kwargs) super(ManyToManyField, self).__init__(**kwargs)
msg = _('Hold down "Control", or "Command" on a Mac, to select more than one.') msg = _('Hold down "Control", or "Command" on a Mac, to select more than one.')
self.help_text = string_concat(self.help_text, ' ', msg) self.help_text = string_concat(self.help_text, ' ', msg)

View File

@ -1227,6 +1227,20 @@ that control how the relationship functions.
the table for the model defining the relationship and the name of the field the table for the model defining the relationship and the name of the field
itself. itself.
.. attribute:: ManyToManyField.db_constraint
Controls whether or not constraints should be created in the database for
the foreign keys in the intermediary table. The default is ``True``, and
that's almost certainly what you want; setting this to ``False`` can be
very bad for data integrity. That said, here are some scenarios where you
might want to do this:
* You have legacy data that is not valid.
* You're sharding your database.
It is an error to pass both ``db_constraint`` and ``through``.
.. _ref-onetoone: .. _ref-onetoone:
``OneToOneField`` ``OneToOneField``

View File

@ -113,8 +113,8 @@ Minor features
* The ``MemcachedCache`` cache backend now uses the latest :mod:`pickle` * The ``MemcachedCache`` cache backend now uses the latest :mod:`pickle`
protocol available. protocol available.
* Added the :attr:`django.db.models.ForeignKey.db_constraint` * Added the :attr:`django.db.models.ForeignKey.db_constraint` and
option. :attr:`django.db.models.ManyToManyField.db_constraint` options.
* The jQuery library embedded in the admin has been upgraded to version 1.9.1. * The jQuery library embedded in the admin has been upgraded to version 1.9.1.

View File

@ -90,7 +90,10 @@ class Item(models.Model):
@python_2_unicode_compatible @python_2_unicode_compatible
class Object(models.Model): class Object(models.Model):
pass related_objects = models.ManyToManyField("self", db_constraint=False, symmetrical=False)
def __str__(self):
return str(self.id)
@python_2_unicode_compatible @python_2_unicode_compatible

View File

@ -12,13 +12,12 @@ from django.db import (backend, connection, connections, DEFAULT_DB_ALIAS,
from django.db.backends.signals import connection_created from django.db.backends.signals import connection_created
from django.db.backends.postgresql_psycopg2 import version as pg_version from django.db.backends.postgresql_psycopg2 import version as pg_version
from django.db.models import Sum, Avg, Variance, StdDev from django.db.models import Sum, Avg, Variance, StdDev
from django.db.utils import ConnectionHandler, DatabaseError from django.db.utils import ConnectionHandler
from django.test import (TestCase, skipUnlessDBFeature, skipIfDBFeature, from django.test import (TestCase, skipUnlessDBFeature, skipIfDBFeature,
TransactionTestCase) TransactionTestCase)
from django.test.utils import override_settings, str_prefix from django.test.utils import override_settings, str_prefix
from django.utils import six from django.utils import six, unittest
from django.utils.six.moves import xrange from django.utils.six.moves import xrange
from django.utils import unittest
from . import models from . import models
@ -52,7 +51,7 @@ class OracleChecks(unittest.TestCase):
convert_unicode = backend.convert_unicode convert_unicode = backend.convert_unicode
cursor = connection.cursor() cursor = connection.cursor()
cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'), cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'),
[convert_unicode('_django_testing!'),]) [convert_unicode('_django_testing!')])
@unittest.skipUnless(connection.vendor == 'oracle', @unittest.skipUnless(connection.vendor == 'oracle',
"No need to check Oracle cursor semantics") "No need to check Oracle cursor semantics")
@ -72,7 +71,7 @@ class OracleChecks(unittest.TestCase):
c = connection.cursor() c = connection.cursor()
c.execute('CREATE TABLE ltext ("TEXT" NCLOB)') c.execute('CREATE TABLE ltext ("TEXT" NCLOB)')
long_str = ''.join([six.text_type(x) for x in xrange(4000)]) long_str = ''.join([six.text_type(x) for x in xrange(4000)])
c.execute('INSERT INTO ltext VALUES (%s)',[long_str]) c.execute('INSERT INTO ltext VALUES (%s)', [long_str])
c.execute('SELECT text FROM ltext') c.execute('SELECT text FROM ltext')
row = c.fetchone() row = c.fetchone()
self.assertEqual(long_str, row[0].read()) self.assertEqual(long_str, row[0].read())
@ -99,6 +98,7 @@ class OracleChecks(unittest.TestCase):
c.execute(query) c.execute(query)
self.assertEqual(c.fetchone()[0], 1) self.assertEqual(c.fetchone()[0], 1)
class MySQLTests(TestCase): class MySQLTests(TestCase):
@unittest.skipUnless(connection.vendor == 'mysql', @unittest.skipUnless(connection.vendor == 'mysql',
"Test valid only for MySQL") "Test valid only for MySQL")
@ -117,7 +117,7 @@ class MySQLTests(TestCase):
found_reset = False found_reset = False
for sql in statements: for sql in statements:
found_reset = found_reset or 'ALTER TABLE' in sql found_reset = found_reset or 'ALTER TABLE' in sql
if connection.mysql_version < (5,0,13): if connection.mysql_version < (5, 0, 13):
self.assertTrue(found_reset) self.assertTrue(found_reset)
else: else:
self.assertFalse(found_reset) self.assertFalse(found_reset)
@ -182,6 +182,7 @@ class LastExecutedQueryTest(TestCase):
self.assertEqual(connection.queries[-1]['sql'], self.assertEqual(connection.queries[-1]['sql'],
str_prefix("QUERY = %(_)s\"SELECT strftime('%%Y', 'now');\" - PARAMS = ()")) str_prefix("QUERY = %(_)s\"SELECT strftime('%%Y', 'now');\" - PARAMS = ()"))
class ParameterHandlingTest(TestCase): class ParameterHandlingTest(TestCase):
def test_bad_parameter_count(self): def test_bad_parameter_count(self):
"An executemany call with too many/not enough parameters will raise an exception (Refs #12612)" "An executemany call with too many/not enough parameters will raise an exception (Refs #12612)"
@ -191,8 +192,9 @@ class ParameterHandlingTest(TestCase):
connection.ops.quote_name('root'), connection.ops.quote_name('root'),
connection.ops.quote_name('square') connection.ops.quote_name('square')
)) ))
self.assertRaises(Exception, cursor.executemany, query, [(1,2,3),]) self.assertRaises(Exception, cursor.executemany, query, [(1, 2, 3)])
self.assertRaises(Exception, cursor.executemany, query, [(1,),]) self.assertRaises(Exception, cursor.executemany, query, [(1,)])
# Unfortunately, the following tests would be a good test to run on all # Unfortunately, the following tests would be a good test to run on all
# backends, but it breaks MySQL hard. Until #13711 is fixed, it can't be run # backends, but it breaks MySQL hard. Until #13711 is fixed, it can't be run
@ -240,6 +242,7 @@ class LongNameTest(TestCase):
for statement in connection.ops.sql_flush(no_style(), tables, sequences): for statement in connection.ops.sql_flush(no_style(), tables, sequences):
cursor.execute(statement) cursor.execute(statement)
class SequenceResetTest(TestCase): class SequenceResetTest(TestCase):
def test_generic_relation(self): def test_generic_relation(self):
"Sequence names are correct when resetting generic relations (Ref #13941)" "Sequence names are correct when resetting generic relations (Ref #13941)"
@ -257,6 +260,7 @@ class SequenceResetTest(TestCase):
obj = models.Post.objects.create(name='New post', text='goodbye world') obj = models.Post.objects.create(name='New post', text='goodbye world')
self.assertTrue(obj.pk > 10) self.assertTrue(obj.pk > 10)
class PostgresVersionTest(TestCase): class PostgresVersionTest(TestCase):
def assert_parses(self, version_string, version): def assert_parses(self, version_string, version):
self.assertEqual(pg_version._parse_version(version_string), version) self.assertEqual(pg_version._parse_version(version_string), version)
@ -291,6 +295,7 @@ class PostgresVersionTest(TestCase):
conn = OlderConnectionMock() conn = OlderConnectionMock()
self.assertEqual(pg_version.get_version(conn), 80300) self.assertEqual(pg_version.get_version(conn), 80300)
class PostgresNewConnectionTest(TestCase): class PostgresNewConnectionTest(TestCase):
""" """
#17062: PostgreSQL shouldn't roll back SET TIME ZONE, even if the first #17062: PostgreSQL shouldn't roll back SET TIME ZONE, even if the first
@ -338,17 +343,18 @@ class ConnectionCreatedSignalTest(TestCase):
@skipUnlessDBFeature('test_db_allows_multiple_connections') @skipUnlessDBFeature('test_db_allows_multiple_connections')
def test_signal(self): def test_signal(self):
data = {} data = {}
def receiver(sender, connection, **kwargs): def receiver(sender, connection, **kwargs):
data["connection"] = connection data["connection"] = connection
connection_created.connect(receiver) connection_created.connect(receiver)
connection.close() connection.close()
cursor = connection.cursor() connection.cursor()
self.assertTrue(data["connection"].connection is connection.connection) self.assertTrue(data["connection"].connection is connection.connection)
connection_created.disconnect(receiver) connection_created.disconnect(receiver)
data.clear() data.clear()
cursor = connection.cursor() connection.cursor()
self.assertTrue(data == {}) self.assertTrue(data == {})
@ -443,7 +449,7 @@ class BackendTestCase(TestCase):
old_password = connection.settings_dict['PASSWORD'] old_password = connection.settings_dict['PASSWORD']
connection.settings_dict['PASSWORD'] = "françois" connection.settings_dict['PASSWORD'] = "françois"
try: try:
cursor = connection.cursor() connection.cursor()
except DatabaseError: except DatabaseError:
# As password is probably wrong, a database exception is expected # As password is probably wrong, a database exception is expected
pass pass
@ -470,6 +476,7 @@ class BackendTestCase(TestCase):
with self.assertRaises(DatabaseError): with self.assertRaises(DatabaseError):
cursor.execute(query) cursor.execute(query)
# We don't make these tests conditional because that means we would need to # We don't make these tests conditional because that means we would need to
# check and differentiate between: # check and differentiate between:
# * MySQL+InnoDB, MySQL+MYISAM (something we currently can't do). # * MySQL+InnoDB, MySQL+MYISAM (something we currently can't do).
@ -477,7 +484,6 @@ class BackendTestCase(TestCase):
# on or not, something that would be controlled by runtime support and user # on or not, something that would be controlled by runtime support and user
# preference. # preference.
# verify if its type is django.database.db.IntegrityError. # verify if its type is django.database.db.IntegrityError.
class FkConstraintsTests(TransactionTestCase): class FkConstraintsTests(TransactionTestCase):
def setUp(self): def setUp(self):
@ -581,6 +587,7 @@ class ThreadTests(TestCase):
connections_dict = {} connections_dict = {}
connection.cursor() connection.cursor()
connections_dict[id(connection)] = connection connections_dict[id(connection)] = connection
def runner(): def runner():
# Passing django.db.connection between threads doesn't work while # Passing django.db.connection between threads doesn't work while
# connections[DEFAULT_DB_ALIAS] does. # connections[DEFAULT_DB_ALIAS] does.
@ -602,7 +609,7 @@ class ThreadTests(TestCase):
# Finish by closing the connections opened by the other threads (the # Finish by closing the connections opened by the other threads (the
# connection opened in the main thread will automatically be closed on # connection opened in the main thread will automatically be closed on
# teardown). # teardown).
for conn in connections_dict.values() : for conn in connections_dict.values():
if conn is not connection: if conn is not connection:
conn.close() conn.close()
@ -616,6 +623,7 @@ class ThreadTests(TestCase):
connections_dict = {} connections_dict = {}
for conn in connections.all(): for conn in connections.all():
connections_dict[id(conn)] = conn connections_dict[id(conn)] = conn
def runner(): def runner():
from django.db import connections from django.db import connections
for conn in connections.all(): for conn in connections.all():
@ -682,6 +690,7 @@ class ThreadTests(TestCase):
""" """
# First, without explicitly enabling the connection for sharing. # First, without explicitly enabling the connection for sharing.
exceptions = set() exceptions = set()
def runner1(): def runner1():
def runner2(other_thread_connection): def runner2(other_thread_connection):
try: try:
@ -699,6 +708,7 @@ class ThreadTests(TestCase):
# Then, with explicitly enabling the connection for sharing. # Then, with explicitly enabling the connection for sharing.
exceptions = set() exceptions = set()
def runner1(): def runner1():
def runner2(other_thread_connection): def runner2(other_thread_connection):
try: try:
@ -746,3 +756,14 @@ class DBConstraintTestCase(TransactionTestCase):
with self.assertRaises(models.Object.DoesNotExist): with self.assertRaises(models.Object.DoesNotExist):
ref.obj ref.obj
def test_many_to_many(self):
obj = models.Object.objects.create()
obj.related_objects.create()
self.assertEqual(models.Object.objects.count(), 2)
self.assertEqual(obj.related_objects.count(), 1)
intermediary_model = models.Object._meta.get_field_by_name("related_objects")[0].rel.through
intermediary_model.objects.create(from_object_id=obj.id, to_object_id=12345)
self.assertEqual(obj.related_objects.count(), 1)
self.assertEqual(intermediary_model.objects.count(), 2)