From 22bb548900146832459deaefa880660c17a51516 Mon Sep 17 00:00:00 2001 From: Sergey Kolosov Date: Fri, 16 May 2014 18:18:34 +0200 Subject: [PATCH] Fixed #22634 -- Made the database-backed session backends more extensible. Introduced an AbstractBaseSession model and hooks providing the option of overriding the model class used by the session store and the session store class used by the model. --- AUTHORS | 1 + django/contrib/sessions/backends/cache.py | 8 +- django/contrib/sessions/backends/cached_db.py | 15 +- django/contrib/sessions/backends/db.py | 49 +++-- django/contrib/sessions/base_session.py | 51 +++++ .../sessions/migrations/0001_initial.py | 2 +- django/contrib/sessions/models.py | 48 ++--- docs/releases/1.9.txt | 5 +- docs/topics/http/sessions.txt | 177 +++++++++++++++++- tests/runtests.py | 1 + tests/sessions_tests/custom_db_backend.py | 43 +++++ tests/sessions_tests/tests.py | 48 ++++- 12 files changed, 370 insertions(+), 78 deletions(-) create mode 100644 django/contrib/sessions/base_session.py create mode 100644 tests/sessions_tests/custom_db_backend.py diff --git a/AUTHORS b/AUTHORS index bc3576f93b..73d701920d 100644 --- a/AUTHORS +++ b/AUTHORS @@ -640,6 +640,7 @@ answer newbie questions, and generally made Django that much better: Sengtha Chay Senko Rašić serbaut@gmail.com + Sergey Kolosov Seth Hill Shai Berger Shannon -jj Behrens diff --git a/django/contrib/sessions/backends/cache.py b/django/contrib/sessions/backends/cache.py index 9be47e00a1..268e06576a 100644 --- a/django/contrib/sessions/backends/cache.py +++ b/django/contrib/sessions/backends/cache.py @@ -10,13 +10,15 @@ class SessionStore(SessionBase): """ A cache-based session store. """ + cache_key_prefix = KEY_PREFIX + def __init__(self, session_key=None): self._cache = caches[settings.SESSION_CACHE_ALIAS] super(SessionStore, self).__init__(session_key) @property def cache_key(self): - return KEY_PREFIX + self._get_or_create_session_key() + return self.cache_key_prefix + self._get_or_create_session_key() def load(self): try: @@ -62,14 +64,14 @@ class SessionStore(SessionBase): raise CreateError def exists(self, session_key): - return session_key and (KEY_PREFIX + session_key) in self._cache + return session_key and (self.cache_key_prefix + session_key) in self._cache def delete(self, session_key=None): if session_key is None: if self.session_key is None: return session_key = self.session_key - self._cache.delete(KEY_PREFIX + session_key) + self._cache.delete(self.cache_key_prefix + session_key) @classmethod def clear_expired(cls): diff --git a/django/contrib/sessions/backends/cached_db.py b/django/contrib/sessions/backends/cached_db.py index bc9a55fd9d..fda3a76a2e 100644 --- a/django/contrib/sessions/backends/cached_db.py +++ b/django/contrib/sessions/backends/cached_db.py @@ -18,6 +18,7 @@ class SessionStore(DBStore): """ Implements cached, database backed sessions. """ + cache_key_prefix = KEY_PREFIX def __init__(self, session_key=None): self._cache = caches[settings.SESSION_CACHE_ALIAS] @@ -25,7 +26,7 @@ class SessionStore(DBStore): @property def cache_key(self): - return KEY_PREFIX + self._get_or_create_session_key() + return self.cache_key_prefix + self._get_or_create_session_key() def load(self): try: @@ -39,14 +40,14 @@ class SessionStore(DBStore): # Duplicate DBStore.load, because we need to keep track # of the expiry date to set it properly in the cache. try: - s = Session.objects.get( + s = self.model.objects.get( session_key=self.session_key, expire_date__gt=timezone.now() ) data = self.decode(s.session_data) self._cache.set(self.cache_key, data, self.get_expiry_age(expiry=s.expire_date)) - except (Session.DoesNotExist, SuspiciousOperation) as e: + except (self.model.DoesNotExist, SuspiciousOperation) as e: if isinstance(e, SuspiciousOperation): logger = logging.getLogger('django.security.%s' % e.__class__.__name__) @@ -56,7 +57,7 @@ class SessionStore(DBStore): return data def exists(self, session_key): - if session_key and (KEY_PREFIX + session_key) in self._cache: + if session_key and (self.cache_key_prefix + session_key) in self._cache: return True return super(SessionStore, self).exists(session_key) @@ -70,7 +71,7 @@ class SessionStore(DBStore): if self.session_key is None: return session_key = self.session_key - self._cache.delete(KEY_PREFIX + session_key) + self._cache.delete(self.cache_key_prefix + session_key) def flush(self): """ @@ -80,7 +81,3 @@ class SessionStore(DBStore): self.clear() self.delete(self.session_key) self._session_key = None - - -# At bottom to avoid circular import -from django.contrib.sessions.models import Session # isort:skip diff --git a/django/contrib/sessions/backends/db.py b/django/contrib/sessions/backends/db.py index 0fba3ec178..ae2470e120 100644 --- a/django/contrib/sessions/backends/db.py +++ b/django/contrib/sessions/backends/db.py @@ -5,6 +5,7 @@ from django.core.exceptions import SuspiciousOperation from django.db import IntegrityError, router, transaction from django.utils import timezone from django.utils.encoding import force_text +from django.utils.functional import cached_property class SessionStore(SessionBase): @@ -14,14 +15,25 @@ class SessionStore(SessionBase): def __init__(self, session_key=None): super(SessionStore, self).__init__(session_key) + @classmethod + def get_model_class(cls): + # Avoids a circular import and allows importing SessionStore when + # django.contrib.sessions is not in INSTALLED_APPS. + from django.contrib.sessions.models import Session + return Session + + @cached_property + def model(self): + return self.get_model_class() + def load(self): try: - s = Session.objects.get( + s = self.model.objects.get( session_key=self.session_key, expire_date__gt=timezone.now() ) return self.decode(s.session_data) - except (Session.DoesNotExist, SuspiciousOperation) as e: + except (self.model.DoesNotExist, SuspiciousOperation) as e: if isinstance(e, SuspiciousOperation): logger = logging.getLogger('django.security.%s' % e.__class__.__name__) @@ -30,7 +42,7 @@ class SessionStore(SessionBase): return {} def exists(self, session_key): - return Session.objects.filter(session_key=session_key).exists() + return self.model.objects.filter(session_key=session_key).exists() def create(self): while True: @@ -45,6 +57,18 @@ class SessionStore(SessionBase): self.modified = True return + def create_model_instance(self, data): + """ + Return a new instance of the session model object, which represents the + current session state. Intended to be used for saving the session data + to the database. + """ + return self.model( + session_key=self._get_or_create_session_key(), + session_data=self.encode(data), + expire_date=self.get_expiry_date(), + ) + def save(self, must_create=False): """ Saves the current session data to the database. If 'must_create' is @@ -54,12 +78,9 @@ class SessionStore(SessionBase): """ if self.session_key is None: return self.create() - obj = Session( - session_key=self._get_or_create_session_key(), - session_data=self.encode(self._get_session(no_load=must_create)), - expire_date=self.get_expiry_date() - ) - using = router.db_for_write(Session, instance=obj) + data = self._get_session(no_load=must_create) + obj = self.create_model_instance(data) + using = router.db_for_write(self.model, instance=obj) try: with transaction.atomic(using=using): obj.save(force_insert=must_create, using=using) @@ -74,14 +95,10 @@ class SessionStore(SessionBase): return session_key = self.session_key try: - Session.objects.get(session_key=session_key).delete() - except Session.DoesNotExist: + self.model.objects.get(session_key=session_key).delete() + except self.model.DoesNotExist: pass @classmethod def clear_expired(cls): - Session.objects.filter(expire_date__lt=timezone.now()).delete() - - -# At bottom to avoid circular import -from django.contrib.sessions.models import Session # isort:skip + cls.get_model_class().objects.filter(expire_date__lt=timezone.now()).delete() diff --git a/django/contrib/sessions/base_session.py b/django/contrib/sessions/base_session.py new file mode 100644 index 0000000000..5fed0cd2d1 --- /dev/null +++ b/django/contrib/sessions/base_session.py @@ -0,0 +1,51 @@ +""" +This module allows importing AbstractBaseSession even +when django.contrib.sessions is not in INSTALLED_APPS. +""" +from __future__ import unicode_literals + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible +from django.utils.translation import ugettext_lazy as _ + + +class BaseSessionManager(models.Manager): + def encode(self, session_dict): + """ + Return the given session dictionary serialized and encoded as a string. + """ + session_store_class = self.model.get_session_store_class() + return session_store_class().encode(session_dict) + + def save(self, session_key, session_dict, expire_date): + s = self.model(session_key, self.encode(session_dict), expire_date) + if session_dict: + s.save() + else: + s.delete() # Clear sessions with no data. + return s + + +@python_2_unicode_compatible +class AbstractBaseSession(models.Model): + session_key = models.CharField(_('session key'), max_length=40, primary_key=True) + session_data = models.TextField(_('session data')) + expire_date = models.DateTimeField(_('expire date'), db_index=True) + + objects = BaseSessionManager() + + class Meta: + abstract = True + verbose_name = _('session') + verbose_name_plural = _('sessions') + + def __str__(self): + return self.session_key + + @classmethod + def get_session_store_class(cls): + raise NotImplementedError + + def get_decoded(self): + session_store_class = self.get_session_store_class() + return session_store_class().decode(self.session_data) diff --git a/django/contrib/sessions/migrations/0001_initial.py b/django/contrib/sessions/migrations/0001_initial.py index 82b856ae62..0a4acc9382 100644 --- a/django/contrib/sessions/migrations/0001_initial.py +++ b/django/contrib/sessions/migrations/0001_initial.py @@ -19,11 +19,11 @@ class Migration(migrations.Migration): ('expire_date', models.DateTimeField(verbose_name='expire date', db_index=True)), ], options={ + 'abstract': False, 'db_table': 'django_session', 'verbose_name': 'session', 'verbose_name_plural': 'sessions', }, - bases=(models.Model,), managers=[ ('objects', django.contrib.sessions.models.SessionManager()), ], diff --git a/django/contrib/sessions/models.py b/django/contrib/sessions/models.py index ce9d609a9b..3ee3ce73d3 100644 --- a/django/contrib/sessions/models.py +++ b/django/contrib/sessions/models.py @@ -1,30 +1,15 @@ from __future__ import unicode_literals -from django.db import models -from django.utils.encoding import python_2_unicode_compatible -from django.utils.translation import ugettext_lazy as _ +from django.contrib.sessions.base_session import ( + AbstractBaseSession, BaseSessionManager, +) -class SessionManager(models.Manager): +class SessionManager(BaseSessionManager): use_in_migrations = True - def encode(self, session_dict): - """ - Returns the given session dictionary serialized and encoded as a string. - """ - return SessionStore().encode(session_dict) - def save(self, session_key, session_dict, expire_date): - s = self.model(session_key, self.encode(session_dict), expire_date) - if session_dict: - s.save() - else: - s.delete() # Clear sessions with no data. - return s - - -@python_2_unicode_compatible -class Session(models.Model): +class Session(AbstractBaseSession): """ Django provides full support for anonymous sessions. The session framework lets you store and retrieve arbitrary data on a @@ -41,23 +26,12 @@ class Session(models.Model): the sessions documentation that is shipped with Django (also available on the Django Web site). """ - session_key = models.CharField(_('session key'), max_length=40, - primary_key=True) - session_data = models.TextField(_('session data')) - expire_date = models.DateTimeField(_('expire date'), db_index=True) objects = SessionManager() - class Meta: + @classmethod + def get_session_store_class(cls): + from django.contrib.sessions.backends.db import SessionStore + return SessionStore + + class Meta(AbstractBaseSession.Meta): db_table = 'django_session' - verbose_name = _('session') - verbose_name_plural = _('sessions') - - def __str__(self): - return self.session_key - - def get_decoded(self): - return SessionStore().decode(self.session_data) - - -# At bottom to avoid circular import -from django.contrib.sessions.backends.db import SessionStore # isort:skip diff --git a/docs/releases/1.9.txt b/docs/releases/1.9.txt index 1a5c488d0f..35b45ead43 100644 --- a/docs/releases/1.9.txt +++ b/docs/releases/1.9.txt @@ -254,7 +254,10 @@ Minor features :mod:`django.contrib.sessions` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -* ... +* The session model and ``SessionStore`` classes for the ``db`` and + ``cached_db`` backends are refactored to allow a custom database session + backend to build upon them. See + :ref:`extending-database-backed-session-engines` for more details. :mod:`django.contrib.sitemaps` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/topics/http/sessions.txt b/docs/topics/http/sessions.txt index 3a825f28cb..da9e5ef567 100644 --- a/docs/topics/http/sessions.txt +++ b/docs/topics/http/sessions.txt @@ -514,8 +514,10 @@ access sessions using the normal Django database API:: >>> s.expire_date datetime.datetime(2005, 8, 20, 13, 35, 12) -Note that you'll need to call ``get_decoded()`` to get the session dictionary. -This is necessary because the dictionary is stored in an encoded format:: +Note that you'll need to call +:meth:`~base_session.AbstractBaseSession.get_decoded()` to get the session +dictionary. This is necessary because the dictionary is stored in an encoded +format:: >>> s.session_data 'KGRwMQpTJ19hdXRoX3VzZXJfaWQnCnAyCkkxCnMuMTExY2ZjODI2Yj...' @@ -670,6 +672,177 @@ Technical details * Django only sends a cookie if it needs to. If you don't set any session data, it won't send a session cookie. +The ``SessionStore`` object +--------------------------- + +When working with sessions internally, Django uses a session store object from +the corresponding session engine. By convention, the session store object class +is named ``SessionStore`` and is located in the module designated by +:setting:`SESSION_ENGINE`. + +All ``SessionStore`` classes available in Django inherit from +:class:`~backends.base.SessionBase` and implement data manipulation methods, +namely: + +* ``exists()`` +* ``create()`` +* ``save()`` +* ``delete()`` +* ``load()`` +* :meth:`~backends.base.SessionBase.clear_expired` + +In order to build a custom session engine or to customize an existing one, you +may create a new class inheriting from :class:`~backends.base.SessionBase` or +any other existing ``SessionStore`` class. + +Extending most of the session engines is quite straightforward, but doing so +with database-backed session engines generally requires some extra effort (see +the next section for details). + +.. _extending-database-backed-session-engines: + +Extending database-backed session engines +========================================= + +.. versionadded:: 1.9 + +Creating a custom database-backed session engine built upon those included in +Django (namely ``db`` and ``cached_db``) may be done by inheriting +:class:`~base_session.AbstractBaseSession` and either ``SessionStore`` class. + +``AbstractBaseSession`` and ``BaseSessionManager`` are importable from +``django.contrib.sessions.base_session`` so that they can be imported without +including ``django.contrib.sessions`` in :setting:`INSTALLED_APPS`. + +.. class:: base_session.AbstractBaseSession + + .. versionadded:: 1.9 + + The abstract base session model. + + .. attribute:: session_key + + Primary key. The field itself may contain up to 40 characters. The + current implementation generates a 32-character string (a random + sequence of digits and lowercase ASCII letters). + + .. attribute:: session_data + + A string containing an encoded and serialized session dictionary. + + .. attribute:: expire_date + + A datetime designating when the session expires. + + Expired sessions are not available to a user, however, they may still + be stored in the database until the :djadmin:`clearsessions` management + command is run. + + .. classmethod:: get_session_store_class() + + Returns a session store class to be used with this session model. + + .. method:: get_decoded() + + Returns decoded session data. + + Decoding is performed by the session store class. + +You can also customize the model manager by subclassing +:class:`~django.contrib.sessions.base_session.BaseSessionManager`: + +.. class:: base_session.BaseSessionManager + + .. versionadded:: 1.9 + + .. method:: encode(session_dict) + + Returns the given session dictionary serialized and encoded as a string. + + Encoding is performed by the session store class tied to a model class. + + .. method:: save(session_key, session_dict, expire_date) + + Saves session data for a provided session key, or deletes the session + in case the data is empty. + +Customization of ``SessionStore`` classes is achieved by overriding methods +and properties described below: + +.. class:: backends.db.SessionStore + + Implements database-backed session store. + + .. classmethod:: get_model_class() + + .. versionadded:: 1.9 + + Override this method to return a custom session model if you need one. + + .. method:: create_model_instance(data) + + .. versionadded:: 1.9 + + Returns a new instance of the session model object, which represents + the current session state. + + Overriding this method provides the ability to modify session model + data before it's saved to database. + +.. class:: backends.cached_db.SessionStore + + Implements cached database-backed session store. + + .. attribute:: cache_key_prefix + + .. versionadded:: 1.9 + + A prefix added to a session key to build a cache key string. + +Example +------- + +The example below shows a custom database-backed session engine that includes +an additional database column to store an account ID (thus providing an option +to query the database for all active sessions for an account):: + + from django.contrib.sessions.backends.db import SessionStore as DBStore + from django.contrib.sessions.base_session import AbstractBaseSession + from django.db import models + + class CustomSession(AbstractBaseSession): + account_id = models.IntegerField(null=True, db_index=True) + + class Meta: + app_label = 'mysessions' + + @classmethod + def get_session_store_class(cls): + return SessionStore + + class SessionStore(DBStore): + @classmethod + def get_model_class(cls): + return CustomSession + + def create_model_instance(self, data): + obj = super(SessionStore, self).create_model_instance(data) + try: + account_id = int(data.get('_auth_user_id')) + except (ValueError, TypeError): + account_id = None + obj.account_id = account_id + return obj + +If you are migrating from the Django's built-in ``cached_db`` session store to +a custom one based on ``cached_db``, you should override the cache key prefix +in order to prevent a namespace clash:: + + class SessionStore(CachedDBStore): + cache_key_prefix = 'mysessions.custom_cached_db_backend' + + # ... + Session IDs in URLs =================== diff --git a/tests/runtests.py b/tests/runtests.py index 7367cd5d66..97c09c7372 100755 --- a/tests/runtests.py +++ b/tests/runtests.py @@ -140,6 +140,7 @@ def setup(verbosity, test_labels): # us skip creating migrations for the test models. 'auth': 'django.contrib.auth.tests.migrations', 'contenttypes': 'contenttypes_tests.migrations', + 'sessions': 'sessions_tests.migrations', } log_config = DEFAULT_LOGGING # Filter out non-error logging so we don't have to capture it in lots of diff --git a/tests/sessions_tests/custom_db_backend.py b/tests/sessions_tests/custom_db_backend.py new file mode 100644 index 0000000000..5d4857e5bb --- /dev/null +++ b/tests/sessions_tests/custom_db_backend.py @@ -0,0 +1,43 @@ +""" +This custom Session model adds an extra column to store an account ID. In +real-world applications, it gives you the option of querying the database for +all active sessions for a particular account. +""" +from django.contrib.sessions.backends.db import SessionStore as DBStore +from django.contrib.sessions.base_session import AbstractBaseSession +from django.db import models + + +class CustomSession(AbstractBaseSession): + """ + A session model with a column for an account ID. + """ + account_id = models.IntegerField(null=True, db_index=True) + + class Meta: + app_label = 'sessions' + + @classmethod + def get_session_store_class(cls): + return SessionStore + + +class SessionStore(DBStore): + """ + A database session store, that handles updating the account ID column + inside the custom session model. + """ + @classmethod + def get_model_class(cls): + return CustomSession + + def create_model_instance(self, data): + obj = super(SessionStore, self).create_model_instance(data) + + try: + account_id = int(data.get('_auth_user_id')) + except (ValueError, TypeError): + account_id = None + obj.account_id = account_id + + return obj diff --git a/tests/sessions_tests/tests.py b/tests/sessions_tests/tests.py index 76e625aa76..1a50720ffa 100644 --- a/tests/sessions_tests/tests.py +++ b/tests/sessions_tests/tests.py @@ -34,6 +34,8 @@ from django.utils import six, timezone from django.utils.encoding import force_text from django.utils.six.moves import http_cookies +from .custom_db_backend import SessionStore as CustomDatabaseSession + class SessionTestsMixin(object): # This does not inherit from TestCase to avoid any tests being run with this @@ -355,6 +357,11 @@ class SessionTestsMixin(object): class DatabaseSessionTests(SessionTestsMixin, TestCase): backend = DatabaseSession + session_engine = 'django.contrib.sessions.backends.db' + + @property + def model(self): + return self.backend.get_model_class() def test_session_str(self): "Session repr should be the session key." @@ -362,7 +369,7 @@ class DatabaseSessionTests(SessionTestsMixin, TestCase): self.session.save() session_key = self.session.session_key - s = Session.objects.get(session_key=session_key) + s = self.model.objects.get(session_key=session_key) self.assertEqual(force_text(s), session_key) @@ -374,7 +381,7 @@ class DatabaseSessionTests(SessionTestsMixin, TestCase): self.session['x'] = 1 self.session.save() - s = Session.objects.get(session_key=self.session.session_key) + s = self.model.objects.get(session_key=self.session.session_key) self.assertEqual(s.get_decoded(), {'x': 1}) @@ -386,19 +393,18 @@ class DatabaseSessionTests(SessionTestsMixin, TestCase): self.session['y'] = 1 self.session.save() - s = Session.objects.get(session_key=self.session.session_key) + s = self.model.objects.get(session_key=self.session.session_key) # Change it - Session.objects.save(s.session_key, {'y': 2}, s.expire_date) + self.model.objects.save(s.session_key, {'y': 2}, s.expire_date) # Clear cache, so that it will be retrieved from DB del self.session._session_cache self.assertEqual(self.session['y'], 2) - @override_settings(SESSION_ENGINE="django.contrib.sessions.backends.db") def test_clearsessions_command(self): """ Test clearsessions command for clearing expired sessions. """ - self.assertEqual(0, Session.objects.count()) + self.assertEqual(0, self.model.objects.count()) # One object in the future self.session['foo'] = 'bar' @@ -412,10 +418,11 @@ class DatabaseSessionTests(SessionTestsMixin, TestCase): other_session.save() # Two sessions are in the database before clearsessions... - self.assertEqual(2, Session.objects.count()) - management.call_command('clearsessions') + self.assertEqual(2, self.model.objects.count()) + with override_settings(SESSION_ENGINE=self.session_engine): + management.call_command('clearsessions') # ... and one is deleted. - self.assertEqual(1, Session.objects.count()) + self.assertEqual(1, self.model.objects.count()) @override_settings(USE_TZ=True) @@ -423,6 +430,29 @@ class DatabaseSessionWithTimeZoneTests(DatabaseSessionTests): pass +class CustomDatabaseSessionTests(DatabaseSessionTests): + backend = CustomDatabaseSession + session_engine = 'sessions_tests.custom_db_backend' + + def test_extra_session_field(self): + # Set the account ID to be picked up by a custom session storage + # and saved to a custom session model database column. + self.session['_auth_user_id'] = 42 + self.session.save() + + # Make sure that the customized create_model_instance() was called. + s = self.model.objects.get(session_key=self.session.session_key) + self.assertEqual(s.account_id, 42) + + # Make the session "anonymous". + self.session.pop('_auth_user_id') + self.session.save() + + # Make sure that save() on an existing session did the right job. + s = self.model.objects.get(session_key=self.session.session_key) + self.assertEqual(s.account_id, None) + + class CacheDBSessionTests(SessionTestsMixin, TestCase): backend = CacheDBSession