diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index a1fd923faf7..28a4a48a9f8 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -70,7 +70,9 @@ class BaseDatabaseWrapper(object): self._thread_ident = thread.get_ident() def __eq__(self, other): - return self.alias == other.alias + if isinstance(other, BaseDatabaseWrapper): + return self.alias == other.alias + return NotImplemented def __ne__(self, other): return not self == other diff --git a/tests/db_backends/__init__.py b/tests/db_backends/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/db_backends/tests.py b/tests/db_backends/tests.py new file mode 100644 index 00000000000..052ee872549 --- /dev/null +++ b/tests/db_backends/tests.py @@ -0,0 +1,36 @@ +from django.test import TestCase +from django.db.backends import BaseDatabaseWrapper + + +class DummyDatabaseWrapper(BaseDatabaseWrapper): + pass + + +class DummyObject(object): + alias = None + + +class DbBackendTests(TestCase): + def test_compare_db_wrapper_with_another_object(self): + wrapper = BaseDatabaseWrapper({}) + self.assertFalse(wrapper == 'not-a-db-wrapper') + + def test_compare_db_wrapper_with_another_object_with_alias(self): + wrapper = BaseDatabaseWrapper({}) + obj = DummyObject() + obj.alias = wrapper.alias = 'foobar' + self.assertFalse(wrapper == obj) + + def test_negate_compare_db_wrapper_with_another_object(self): + wrapper = BaseDatabaseWrapper({}) + self.assertTrue(wrapper != 'not-a-db-wrapper') + + def test_compare_db_wrappers(self): + wrapper1 = DummyDatabaseWrapper({}) + wrapper2 = BaseDatabaseWrapper({}) + + wrapper1.alias = wrapper2.alias = 'foo' + self.assertTrue(wrapper1 == wrapper2) + + wrapper1.alias = 'bar' + self.assertFalse(wrapper1 == wrapper2)