Fixed #34901 -- Added async-compatible interface to session engines.

Thanks Andrew-Chen-Wang for the initial implementation which was posted
to the Django forum thread about asyncifying contrib modules.
This commit is contained in:
Jon Janzen 2023-10-16 18:50:20 -07:00 committed by Mariusz Felisiak
parent 33c06ca0da
commit f5c340684b
12 changed files with 975 additions and 9 deletions

View File

@ -269,4 +269,6 @@ def update_session_auth_hash(request, user):
async def aupdate_session_auth_hash(request, user): async def aupdate_session_auth_hash(request, user):
"""See update_session_auth_hash().""" """See update_session_auth_hash()."""
return await sync_to_async(update_session_auth_hash)(request, user) await request.session.acycle_key()
if hasattr(user, "get_session_auth_hash") and request.user == user:
await request.session.aset(HASH_SESSION_KEY, user.get_session_auth_hash())

View File

@ -2,6 +2,8 @@ import logging
import string import string
from datetime import datetime, timedelta from datetime import datetime, timedelta
from asgiref.sync import sync_to_async
from django.conf import settings from django.conf import settings
from django.core import signing from django.core import signing
from django.utils import timezone from django.utils import timezone
@ -56,6 +58,10 @@ class SessionBase:
self._session[key] = value self._session[key] = value
self.modified = True self.modified = True
async def aset(self, key, value):
(await self._aget_session())[key] = value
self.modified = True
def __delitem__(self, key): def __delitem__(self, key):
del self._session[key] del self._session[key]
self.modified = True self.modified = True
@ -67,11 +73,19 @@ class SessionBase:
def get(self, key, default=None): def get(self, key, default=None):
return self._session.get(key, default) return self._session.get(key, default)
async def aget(self, key, default=None):
return (await self._aget_session()).get(key, default)
def pop(self, key, default=__not_given): def pop(self, key, default=__not_given):
self.modified = self.modified or key in self._session self.modified = self.modified or key in self._session
args = () if default is self.__not_given else (default,) args = () if default is self.__not_given else (default,)
return self._session.pop(key, *args) return self._session.pop(key, *args)
async def apop(self, key, default=__not_given):
self.modified = self.modified or key in (await self._aget_session())
args = () if default is self.__not_given else (default,)
return (await self._aget_session()).pop(key, *args)
def setdefault(self, key, value): def setdefault(self, key, value):
if key in self._session: if key in self._session:
return self._session[key] return self._session[key]
@ -79,15 +93,32 @@ class SessionBase:
self[key] = value self[key] = value
return value return value
async def asetdefault(self, key, value):
session = await self._aget_session()
if key in session:
return session[key]
else:
await self.aset(key, value)
return value
def set_test_cookie(self): def set_test_cookie(self):
self[self.TEST_COOKIE_NAME] = self.TEST_COOKIE_VALUE self[self.TEST_COOKIE_NAME] = self.TEST_COOKIE_VALUE
async def aset_test_cookie(self):
await self.aset(self.TEST_COOKIE_NAME, self.TEST_COOKIE_VALUE)
def test_cookie_worked(self): def test_cookie_worked(self):
return self.get(self.TEST_COOKIE_NAME) == self.TEST_COOKIE_VALUE return self.get(self.TEST_COOKIE_NAME) == self.TEST_COOKIE_VALUE
async def atest_cookie_worked(self):
return (await self.aget(self.TEST_COOKIE_NAME)) == self.TEST_COOKIE_VALUE
def delete_test_cookie(self): def delete_test_cookie(self):
del self[self.TEST_COOKIE_NAME] del self[self.TEST_COOKIE_NAME]
async def adelete_test_cookie(self):
del (await self._aget_session())[self.TEST_COOKIE_NAME]
def encode(self, session_dict): def encode(self, session_dict):
"Return the given session dictionary serialized and encoded as a string." "Return the given session dictionary serialized and encoded as a string."
return signing.dumps( return signing.dumps(
@ -115,18 +146,34 @@ class SessionBase:
self._session.update(dict_) self._session.update(dict_)
self.modified = True self.modified = True
async def aupdate(self, dict_):
(await self._aget_session()).update(dict_)
self.modified = True
def has_key(self, key): def has_key(self, key):
return key in self._session return key in self._session
async def ahas_key(self, key):
return key in (await self._aget_session())
def keys(self): def keys(self):
return self._session.keys() return self._session.keys()
async def akeys(self):
return (await self._aget_session()).keys()
def values(self): def values(self):
return self._session.values() return self._session.values()
async def avalues(self):
return (await self._aget_session()).values()
def items(self): def items(self):
return self._session.items() return self._session.items()
async def aitems(self):
return (await self._aget_session()).items()
def clear(self): def clear(self):
# To avoid unnecessary persistent storage accesses, we set up the # To avoid unnecessary persistent storage accesses, we set up the
# internals directly (loading data wastes time, since we are going to # internals directly (loading data wastes time, since we are going to
@ -149,11 +196,22 @@ class SessionBase:
if not self.exists(session_key): if not self.exists(session_key):
return session_key return session_key
async def _aget_new_session_key(self):
while True:
session_key = get_random_string(32, VALID_KEY_CHARS)
if not await self.aexists(session_key):
return session_key
def _get_or_create_session_key(self): def _get_or_create_session_key(self):
if self._session_key is None: if self._session_key is None:
self._session_key = self._get_new_session_key() self._session_key = self._get_new_session_key()
return self._session_key return self._session_key
async def _aget_or_create_session_key(self):
if self._session_key is None:
self._session_key = await self._aget_new_session_key()
return self._session_key
def _validate_session_key(self, key): def _validate_session_key(self, key):
""" """
Key must be truthy and at least 8 characters long. 8 characters is an Key must be truthy and at least 8 characters long. 8 characters is an
@ -191,6 +249,17 @@ class SessionBase:
self._session_cache = self.load() self._session_cache = self.load()
return self._session_cache return self._session_cache
async def _aget_session(self, no_load=False):
self.accessed = True
try:
return self._session_cache
except AttributeError:
if self.session_key is None or no_load:
self._session_cache = {}
else:
self._session_cache = await self.aload()
return self._session_cache
_session = property(_get_session) _session = property(_get_session)
def get_session_cookie_age(self): def get_session_cookie_age(self):
@ -223,6 +292,25 @@ class SessionBase:
delta = expiry - modification delta = expiry - modification
return delta.days * 86400 + delta.seconds return delta.days * 86400 + delta.seconds
async def aget_expiry_age(self, **kwargs):
try:
modification = kwargs["modification"]
except KeyError:
modification = timezone.now()
try:
expiry = kwargs["expiry"]
except KeyError:
expiry = await self.aget("_session_expiry")
if not expiry: # Checks both None and 0 cases
return self.get_session_cookie_age()
if not isinstance(expiry, (datetime, str)):
return expiry
if isinstance(expiry, str):
expiry = datetime.fromisoformat(expiry)
delta = expiry - modification
return delta.days * 86400 + delta.seconds
def get_expiry_date(self, **kwargs): def get_expiry_date(self, **kwargs):
"""Get session the expiry date (as a datetime object). """Get session the expiry date (as a datetime object).
@ -246,6 +334,23 @@ class SessionBase:
expiry = expiry or self.get_session_cookie_age() expiry = expiry or self.get_session_cookie_age()
return modification + timedelta(seconds=expiry) return modification + timedelta(seconds=expiry)
async def aget_expiry_date(self, **kwargs):
try:
modification = kwargs["modification"]
except KeyError:
modification = timezone.now()
try:
expiry = kwargs["expiry"]
except KeyError:
expiry = await self.aget("_session_expiry")
if isinstance(expiry, datetime):
return expiry
elif isinstance(expiry, str):
return datetime.fromisoformat(expiry)
expiry = expiry or self.get_session_cookie_age()
return modification + timedelta(seconds=expiry)
def set_expiry(self, value): def set_expiry(self, value):
""" """
Set a custom expiration for the session. ``value`` can be an integer, Set a custom expiration for the session. ``value`` can be an integer,
@ -274,6 +379,20 @@ class SessionBase:
value = value.isoformat() value = value.isoformat()
self["_session_expiry"] = value self["_session_expiry"] = value
async def aset_expiry(self, value):
if value is None:
# Remove any custom expiration for this session.
try:
await self.apop("_session_expiry")
except KeyError:
pass
return
if isinstance(value, timedelta):
value = timezone.now() + value
if isinstance(value, datetime):
value = value.isoformat()
await self.aset("_session_expiry", value)
def get_expire_at_browser_close(self): def get_expire_at_browser_close(self):
""" """
Return ``True`` if the session is set to expire when the browser Return ``True`` if the session is set to expire when the browser
@ -285,6 +404,11 @@ class SessionBase:
return settings.SESSION_EXPIRE_AT_BROWSER_CLOSE return settings.SESSION_EXPIRE_AT_BROWSER_CLOSE
return expiry == 0 return expiry == 0
async def aget_expire_at_browser_close(self):
if (expiry := await self.aget("_session_expiry")) is None:
return settings.SESSION_EXPIRE_AT_BROWSER_CLOSE
return expiry == 0
def flush(self): def flush(self):
""" """
Remove the current session data from the database and regenerate the Remove the current session data from the database and regenerate the
@ -294,6 +418,11 @@ class SessionBase:
self.delete() self.delete()
self._session_key = None self._session_key = None
async def aflush(self):
self.clear()
await self.adelete()
self._session_key = None
def cycle_key(self): def cycle_key(self):
""" """
Create a new session key, while retaining the current session data. Create a new session key, while retaining the current session data.
@ -305,6 +434,17 @@ class SessionBase:
if key: if key:
self.delete(key) self.delete(key)
async def acycle_key(self):
"""
Create a new session key, while retaining the current session data.
"""
data = await self._aget_session()
key = self.session_key
await self.acreate()
self._session_cache = data
if key:
await self.adelete(key)
# Methods that child classes must implement. # Methods that child classes must implement.
def exists(self, session_key): def exists(self, session_key):
@ -315,6 +455,9 @@ class SessionBase:
"subclasses of SessionBase must provide an exists() method" "subclasses of SessionBase must provide an exists() method"
) )
async def aexists(self, session_key):
return await sync_to_async(self.exists)(session_key)
def create(self): def create(self):
""" """
Create a new session instance. Guaranteed to create a new object with Create a new session instance. Guaranteed to create a new object with
@ -325,6 +468,9 @@ class SessionBase:
"subclasses of SessionBase must provide a create() method" "subclasses of SessionBase must provide a create() method"
) )
async def acreate(self):
return await sync_to_async(self.create)()
def save(self, must_create=False): def save(self, must_create=False):
""" """
Save the session data. If 'must_create' is True, create a new session Save the session data. If 'must_create' is True, create a new session
@ -335,6 +481,9 @@ class SessionBase:
"subclasses of SessionBase must provide a save() method" "subclasses of SessionBase must provide a save() method"
) )
async def asave(self, must_create=False):
return await sync_to_async(self.save)(must_create)
def delete(self, session_key=None): def delete(self, session_key=None):
""" """
Delete the session data under this key. If the key is None, use the Delete the session data under this key. If the key is None, use the
@ -344,6 +493,9 @@ class SessionBase:
"subclasses of SessionBase must provide a delete() method" "subclasses of SessionBase must provide a delete() method"
) )
async def adelete(self, session_key=None):
return await sync_to_async(self.delete)(session_key)
def load(self): def load(self):
""" """
Load the session data and return a dictionary. Load the session data and return a dictionary.
@ -352,6 +504,9 @@ class SessionBase:
"subclasses of SessionBase must provide a load() method" "subclasses of SessionBase must provide a load() method"
) )
async def aload(self):
return await sync_to_async(self.load)()
@classmethod @classmethod
def clear_expired(cls): def clear_expired(cls):
""" """
@ -362,3 +517,7 @@ class SessionBase:
a built-in expiration mechanism, it should be a no-op. a built-in expiration mechanism, it should be a no-op.
""" """
raise NotImplementedError("This backend does not support clear_expired().") raise NotImplementedError("This backend does not support clear_expired().")
@classmethod
async def aclear_expired(cls):
return await sync_to_async(cls.clear_expired)()

View File

@ -20,6 +20,9 @@ class SessionStore(SessionBase):
def cache_key(self): def cache_key(self):
return self.cache_key_prefix + self._get_or_create_session_key() return self.cache_key_prefix + self._get_or_create_session_key()
async def acache_key(self):
return self.cache_key_prefix + await self._aget_or_create_session_key()
def load(self): def load(self):
try: try:
session_data = self._cache.get(self.cache_key) session_data = self._cache.get(self.cache_key)
@ -32,6 +35,16 @@ class SessionStore(SessionBase):
self._session_key = None self._session_key = None
return {} return {}
async def aload(self):
try:
session_data = await self._cache.aget(await self.acache_key())
except Exception:
session_data = None
if session_data is not None:
return session_data
self._session_key = None
return {}
def create(self): def create(self):
# Because a cache can fail silently (e.g. memcache), we don't know if # Because a cache can fail silently (e.g. memcache), we don't know if
# we are failing to create a new session because of a key collision or # we are failing to create a new session because of a key collision or
@ -51,6 +64,20 @@ class SessionStore(SessionBase):
"It is likely that the cache is unavailable." "It is likely that the cache is unavailable."
) )
async def acreate(self):
for i in range(10000):
self._session_key = await self._aget_new_session_key()
try:
await self.asave(must_create=True)
except CreateError:
continue
self.modified = True
return
raise RuntimeError(
"Unable to create a new session key. "
"It is likely that the cache is unavailable."
)
def save(self, must_create=False): def save(self, must_create=False):
if self.session_key is None: if self.session_key is None:
return self.create() return self.create()
@ -68,11 +95,33 @@ class SessionStore(SessionBase):
if must_create and not result: if must_create and not result:
raise CreateError raise CreateError
async def asave(self, must_create=False):
if self.session_key is None:
return await self.acreate()
if must_create:
func = self._cache.aadd
elif await self._cache.aget(await self.acache_key()) is not None:
func = self._cache.aset
else:
raise UpdateError
result = await func(
await self.acache_key(),
await self._aget_session(no_load=must_create),
await self.aget_expiry_age(),
)
if must_create and not result:
raise CreateError
def exists(self, session_key): def exists(self, session_key):
return ( return (
bool(session_key) and (self.cache_key_prefix + session_key) in self._cache bool(session_key) and (self.cache_key_prefix + session_key) in self._cache
) )
async def aexists(self, session_key):
return bool(session_key) and await self._cache.ahas_key(
self.cache_key_prefix + session_key
)
def delete(self, session_key=None): def delete(self, session_key=None):
if session_key is None: if session_key is None:
if self.session_key is None: if self.session_key is None:
@ -80,6 +129,17 @@ class SessionStore(SessionBase):
session_key = self.session_key session_key = self.session_key
self._cache.delete(self.cache_key_prefix + session_key) self._cache.delete(self.cache_key_prefix + session_key)
async def adelete(self, session_key=None):
if session_key is None:
if self.session_key is None:
return
session_key = self.session_key
await self._cache.adelete(self.cache_key_prefix + session_key)
@classmethod @classmethod
def clear_expired(cls): def clear_expired(cls):
pass pass
@classmethod
async def aclear_expired(cls):
pass

View File

@ -28,6 +28,9 @@ class SessionStore(DBStore):
def cache_key(self): def cache_key(self):
return self.cache_key_prefix + self._get_or_create_session_key() return self.cache_key_prefix + self._get_or_create_session_key()
async def acache_key(self):
return self.cache_key_prefix + await self._aget_or_create_session_key()
def load(self): def load(self):
try: try:
data = self._cache.get(self.cache_key) data = self._cache.get(self.cache_key)
@ -47,6 +50,27 @@ class SessionStore(DBStore):
data = {} data = {}
return data return data
async def aload(self):
try:
data = await self._cache.aget(await self.acache_key())
except Exception:
# Some backends (e.g. memcache) raise an exception on invalid
# cache keys. If this happens, reset the session. See #17810.
data = None
if data is None:
s = await self._aget_session_from_db()
if s:
data = self.decode(s.session_data)
await self._cache.aset(
await self.acache_key(),
data,
await self.aget_expiry_age(expiry=s.expire_date),
)
else:
data = {}
return data
def exists(self, session_key): def exists(self, session_key):
return ( return (
session_key session_key
@ -54,6 +78,13 @@ class SessionStore(DBStore):
or super().exists(session_key) or super().exists(session_key)
) )
async def aexists(self, session_key):
return (
session_key
and (self.cache_key_prefix + session_key) in self._cache
or await super().aexists(session_key)
)
def save(self, must_create=False): def save(self, must_create=False):
super().save(must_create) super().save(must_create)
try: try:
@ -61,6 +92,17 @@ class SessionStore(DBStore):
except Exception: except Exception:
logger.exception("Error saving to cache (%s)", self._cache) logger.exception("Error saving to cache (%s)", self._cache)
async def asave(self, must_create=False):
await super().asave(must_create)
try:
await self._cache.aset(
await self.acache_key(),
self._session,
await self.aget_expiry_age(),
)
except Exception:
logger.exception("Error saving to cache (%s)", self._cache)
def delete(self, session_key=None): def delete(self, session_key=None):
super().delete(session_key) super().delete(session_key)
if session_key is None: if session_key is None:
@ -69,6 +111,14 @@ class SessionStore(DBStore):
session_key = self.session_key session_key = self.session_key
self._cache.delete(self.cache_key_prefix + session_key) self._cache.delete(self.cache_key_prefix + session_key)
async def adelete(self, session_key=None):
await super().adelete(session_key)
if session_key is None:
if self.session_key is None:
return
session_key = self.session_key
await self._cache.adelete(self.cache_key_prefix + session_key)
def flush(self): def flush(self):
""" """
Remove the current session data from the database and regenerate the Remove the current session data from the database and regenerate the
@ -77,3 +127,9 @@ class SessionStore(DBStore):
self.clear() self.clear()
self.delete(self.session_key) self.delete(self.session_key)
self._session_key = None self._session_key = None
async def aflush(self):
"""See flush()."""
self.clear()
await self.adelete(self.session_key)
self._session_key = None

View File

@ -1,5 +1,7 @@
import logging import logging
from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.base import CreateError, SessionBase, UpdateError from django.contrib.sessions.backends.base import CreateError, SessionBase, UpdateError
from django.core.exceptions import SuspiciousOperation from django.core.exceptions import SuspiciousOperation
from django.db import DatabaseError, IntegrityError, router, transaction from django.db import DatabaseError, IntegrityError, router, transaction
@ -38,13 +40,31 @@ class SessionStore(SessionBase):
logger.warning(str(e)) logger.warning(str(e))
self._session_key = None self._session_key = None
async def _aget_session_from_db(self):
try:
return await self.model.objects.aget(
session_key=self.session_key, expire_date__gt=timezone.now()
)
except (self.model.DoesNotExist, SuspiciousOperation) as e:
if isinstance(e, SuspiciousOperation):
logger = logging.getLogger("django.security.%s" % e.__class__.__name__)
logger.warning(str(e))
self._session_key = None
def load(self): def load(self):
s = self._get_session_from_db() s = self._get_session_from_db()
return self.decode(s.session_data) if s else {} return self.decode(s.session_data) if s else {}
async def aload(self):
s = await self._aget_session_from_db()
return self.decode(s.session_data) if s else {}
def exists(self, session_key): def exists(self, session_key):
return self.model.objects.filter(session_key=session_key).exists() return self.model.objects.filter(session_key=session_key).exists()
async def aexists(self, session_key):
return await self.model.objects.filter(session_key=session_key).aexists()
def create(self): def create(self):
while True: while True:
self._session_key = self._get_new_session_key() self._session_key = self._get_new_session_key()
@ -58,6 +78,19 @@ class SessionStore(SessionBase):
self.modified = True self.modified = True
return return
async def acreate(self):
while True:
self._session_key = await self._aget_new_session_key()
try:
# Save immediately to ensure we have a unique entry in the
# database.
await self.asave(must_create=True)
except CreateError:
# Key wasn't unique. Try again.
continue
self.modified = True
return
def create_model_instance(self, data): def create_model_instance(self, data):
""" """
Return a new instance of the session model object, which represents the Return a new instance of the session model object, which represents the
@ -70,6 +103,14 @@ class SessionStore(SessionBase):
expire_date=self.get_expiry_date(), expire_date=self.get_expiry_date(),
) )
async def acreate_model_instance(self, data):
"""See create_model_instance()."""
return self.model(
session_key=await self._aget_or_create_session_key(),
session_data=self.encode(data),
expire_date=await self.aget_expiry_date(),
)
def save(self, must_create=False): def save(self, must_create=False):
""" """
Save the current session data to the database. If 'must_create' is Save the current session data to the database. If 'must_create' is
@ -95,6 +136,36 @@ class SessionStore(SessionBase):
raise UpdateError raise UpdateError
raise raise
async def asave(self, must_create=False):
"""See save()."""
if self.session_key is None:
return await self.acreate()
data = await self._aget_session(no_load=must_create)
obj = await self.acreate_model_instance(data)
using = router.db_for_write(self.model, instance=obj)
try:
# This code MOST run in a transaction, so it requires
# @sync_to_async wrapping until transaction.atomic() supports
# async.
@sync_to_async
def sync_transaction():
with transaction.atomic(using=using):
obj.save(
force_insert=must_create,
force_update=not must_create,
using=using,
)
await sync_transaction()
except IntegrityError:
if must_create:
raise CreateError
raise
except DatabaseError:
if not must_create:
raise UpdateError
raise
def delete(self, session_key=None): def delete(self, session_key=None):
if session_key is None: if session_key is None:
if self.session_key is None: if self.session_key is None:
@ -105,6 +176,23 @@ class SessionStore(SessionBase):
except self.model.DoesNotExist: except self.model.DoesNotExist:
pass pass
async def adelete(self, session_key=None):
if session_key is None:
if self.session_key is None:
return
session_key = self.session_key
try:
obj = await self.model.objects.aget(session_key=session_key)
await obj.adelete()
except self.model.DoesNotExist:
pass
@classmethod @classmethod
def clear_expired(cls): def clear_expired(cls):
cls.get_model_class().objects.filter(expire_date__lt=timezone.now()).delete() cls.get_model_class().objects.filter(expire_date__lt=timezone.now()).delete()
@classmethod
async def aclear_expired(cls):
await cls.get_model_class().objects.filter(
expire_date__lt=timezone.now()
).adelete()

View File

@ -104,6 +104,9 @@ class SessionStore(SessionBase):
self._session_key = None self._session_key = None
return session_data return session_data
async def aload(self):
return self.load()
def create(self): def create(self):
while True: while True:
self._session_key = self._get_new_session_key() self._session_key = self._get_new_session_key()
@ -114,6 +117,9 @@ class SessionStore(SessionBase):
self.modified = True self.modified = True
return return
async def acreate(self):
return self.create()
def save(self, must_create=False): def save(self, must_create=False):
if self.session_key is None: if self.session_key is None:
return self.create() return self.create()
@ -177,9 +183,15 @@ class SessionStore(SessionBase):
except (EOFError, OSError): except (EOFError, OSError):
pass pass
async def asave(self, must_create=False):
return self.save(must_create=must_create)
def exists(self, session_key): def exists(self, session_key):
return os.path.exists(self._key_to_file(session_key)) return os.path.exists(self._key_to_file(session_key))
async def aexists(self, session_key):
return self.exists(session_key)
def delete(self, session_key=None): def delete(self, session_key=None):
if session_key is None: if session_key is None:
if self.session_key is None: if self.session_key is None:
@ -190,6 +202,9 @@ class SessionStore(SessionBase):
except OSError: except OSError:
pass pass
async def adelete(self, session_key=None):
return self.delete(session_key=session_key)
@classmethod @classmethod
def clear_expired(cls): def clear_expired(cls):
storage_path = cls._get_storage_path() storage_path = cls._get_storage_path()
@ -205,3 +220,7 @@ class SessionStore(SessionBase):
# the create() method. # the create() method.
session.create = lambda: None session.create = lambda: None
session.load() session.load()
@classmethod
async def aclear_expired(cls):
cls.clear_expired()

View File

@ -23,6 +23,9 @@ class SessionStore(SessionBase):
self.create() self.create()
return {} return {}
async def aload(self):
return self.load()
def create(self): def create(self):
""" """
To create a new key, set the modified flag so that the cookie is set To create a new key, set the modified flag so that the cookie is set
@ -30,6 +33,9 @@ class SessionStore(SessionBase):
""" """
self.modified = True self.modified = True
async def acreate(self):
return self.create()
def save(self, must_create=False): def save(self, must_create=False):
""" """
To save, get the session key as a securely signed string and then set To save, get the session key as a securely signed string and then set
@ -39,6 +45,9 @@ class SessionStore(SessionBase):
self._session_key = self._get_session_key() self._session_key = self._get_session_key()
self.modified = True self.modified = True
async def asave(self, must_create=False):
return self.save(must_create=must_create)
def exists(self, session_key=None): def exists(self, session_key=None):
""" """
This method makes sense when you're talking to a shared resource, but This method makes sense when you're talking to a shared resource, but
@ -47,6 +56,9 @@ class SessionStore(SessionBase):
""" """
return False return False
async def aexists(self, session_key=None):
return self.exists(session_key=session_key)
def delete(self, session_key=None): def delete(self, session_key=None):
""" """
To delete, clear the session key and the underlying data structure To delete, clear the session key and the underlying data structure
@ -57,6 +69,9 @@ class SessionStore(SessionBase):
self._session_cache = {} self._session_cache = {}
self.modified = True self.modified = True
async def adelete(self, session_key=None):
return self.delete(session_key=session_key)
def cycle_key(self): def cycle_key(self):
""" """
Keep the same data but with a new key. Call save() and it will Keep the same data but with a new key. Call save() and it will
@ -64,6 +79,9 @@ class SessionStore(SessionBase):
""" """
self.save() self.save()
async def acycle_key(self):
return self.cycle_key()
def _get_session_key(self): def _get_session_key(self):
""" """
Instead of generating a random string, generate a secure url-safe Instead of generating a random string, generate a secure url-safe
@ -79,3 +97,7 @@ class SessionStore(SessionBase):
@classmethod @classmethod
def clear_expired(cls): def clear_expired(cls):
pass pass
@classmethod
async def aclear_expired(cls):
pass

View File

@ -817,7 +817,14 @@ class ClientMixin:
return session return session
async def asession(self): async def asession(self):
return await sync_to_async(lambda: self.session)() engine = import_module(settings.SESSION_ENGINE)
cookie = self.cookies.get(settings.SESSION_COOKIE_NAME)
if cookie:
return engine.SessionStore(cookie.value)
session = engine.SessionStore()
await session.asave()
self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
return session
def login(self, **credentials): def login(self, **credentials):
""" """
@ -893,7 +900,7 @@ class ClientMixin:
await alogin(request, user, backend) await alogin(request, user, backend)
# Save the session values. # Save the session values.
await sync_to_async(request.session.save)() await request.session.asave()
self._set_login_cookies(request) self._set_login_cookies(request)
def _set_login_cookies(self, request): def _set_login_cookies(self, request):

View File

@ -125,6 +125,10 @@ Minor features
error messages with their traceback via the newly added error messages with their traceback via the newly added
:ref:`sessions logger <django-contrib-sessions-logger>`. :ref:`sessions logger <django-contrib-sessions-logger>`.
* :class:`django.contrib.sessions.backends.base.SessionBase` and all built-in
session engines now provide async API. The new asynchronous methods all have
``a`` prefixed names, e.g. ``aget()``, ``akeys()``, or ``acycle_key()``.
:mod:`django.contrib.sitemaps` :mod:`django.contrib.sitemaps`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -196,54 +196,156 @@ You can edit it multiple times.
Example: ``'fav_color' in request.session`` Example: ``'fav_color' in request.session``
.. method:: get(key, default=None) .. method:: get(key, default=None)
.. method:: aget(key, default=None)
*Asynchronous version*: ``aget()``
Example: ``fav_color = request.session.get('fav_color', 'red')`` Example: ``fav_color = request.session.get('fav_color', 'red')``
.. versionchanged:: 5.1
``aget()`` function was added.
.. method:: aset(key, value)
.. versionadded:: 5.1
Example: ``await request.session.aset('fav_color', 'red')``
.. method:: update(dict)
.. method:: aupdate(dict)
*Asynchronous version*: ``aupdate()``
Example: ``request.session.update({'fav_color': 'red'})``
.. versionchanged:: 5.1
``aupdate()`` function was added.
.. method:: pop(key, default=__not_given) .. method:: pop(key, default=__not_given)
.. method:: apop(key, default=__not_given)
*Asynchronous version*: ``apop()``
Example: ``fav_color = request.session.pop('fav_color', 'blue')`` Example: ``fav_color = request.session.pop('fav_color', 'blue')``
.. versionchanged:: 5.1
``apop()`` function was added.
.. method:: keys() .. method:: keys()
.. method:: akeys()
*Asynchronous version*: ``akeys()``
.. versionchanged:: 5.1
``akeys()`` function was added.
.. method:: values()
.. method:: avalues()
*Asynchronous version*: ``avalues()``
.. versionchanged:: 5.1
``avalues()`` function was added.
.. method:: has_key(key)
.. method:: ahas_key(key)
*Asynchronous version*: ``ahas_key()``
.. versionchanged:: 5.1
``ahas_key()`` function was added.
.. method:: items() .. method:: items()
.. method:: aitems()
*Asynchronous version*: ``aitems()``
.. versionchanged:: 5.1
``aitems()`` function was added.
.. method:: setdefault() .. method:: setdefault()
.. method:: asetdefault()
*Asynchronous version*: ``asetdefault()``
.. versionchanged:: 5.1
``asetdefault()`` function was added.
.. method:: clear() .. method:: clear()
It also has these methods: It also has these methods:
.. method:: flush() .. method:: flush()
.. method:: aflush()
*Asynchronous version*: ``aflush()``
Deletes the current session data from the session and deletes the session Deletes the current session data from the session and deletes the session
cookie. This is used if you want to ensure that the previous session data cookie. This is used if you want to ensure that the previous session data
can't be accessed again from the user's browser (for example, the can't be accessed again from the user's browser (for example, the
:func:`django.contrib.auth.logout()` function calls it). :func:`django.contrib.auth.logout()` function calls it).
.. versionchanged:: 5.1
``aflush()`` function was added.
.. method:: set_test_cookie() .. method:: set_test_cookie()
.. method:: aset_test_cookie()
*Asynchronous version*: ``aset_test_cookie()``
Sets a test cookie to determine whether the user's browser supports Sets a test cookie to determine whether the user's browser supports
cookies. Due to the way cookies work, you won't be able to test this cookies. Due to the way cookies work, you won't be able to test this
until the user's next page request. See `Setting test cookies`_ below for until the user's next page request. See `Setting test cookies`_ below for
more information. more information.
.. versionchanged:: 5.1
``aset_test_cookie()`` function was added.
.. method:: test_cookie_worked() .. method:: test_cookie_worked()
.. method:: atest_cookie_worked()
*Asynchronous version*: ``atest_cookie_worked()``
Returns either ``True`` or ``False``, depending on whether the user's Returns either ``True`` or ``False``, depending on whether the user's
browser accepted the test cookie. Due to the way cookies work, you'll browser accepted the test cookie. Due to the way cookies work, you'll
have to call ``set_test_cookie()`` on a previous, separate page request. have to call ``set_test_cookie()`` or ``aset_test_cookie()`` on a
previous, separate page request.
See `Setting test cookies`_ below for more information. See `Setting test cookies`_ below for more information.
.. versionchanged:: 5.1
``atest_cookie_worked()`` function was added.
.. method:: delete_test_cookie() .. method:: delete_test_cookie()
.. method:: adelete_test_cookie()
*Asynchronous version*: ``adelete_test_cookie()``
Deletes the test cookie. Use this to clean up after yourself. Deletes the test cookie. Use this to clean up after yourself.
.. versionchanged:: 5.1
``adelete_test_cookie()`` function was added.
.. method:: get_session_cookie_age() .. method:: get_session_cookie_age()
Returns the value of the setting :setting:`SESSION_COOKIE_AGE`. This can Returns the value of the setting :setting:`SESSION_COOKIE_AGE`. This can
be overridden in a custom session backend. be overridden in a custom session backend.
.. method:: set_expiry(value) .. method:: set_expiry(value)
.. method:: aset_expiry(value)
*Asynchronous version*: ``aset_expiry()``
Sets the expiration time for the session. You can pass a number of Sets the expiration time for the session. You can pass a number of
different values: different values:
@ -266,7 +368,14 @@ You can edit it multiple times.
purposes. Session expiration is computed from the last time the purposes. Session expiration is computed from the last time the
session was *modified*. session was *modified*.
.. versionchanged:: 5.1
``aset_expiry()`` function was added.
.. method:: get_expiry_age() .. method:: get_expiry_age()
.. method:: aget_expiry_age()
*Asynchronous version*: ``aget_expiry_age()``
Returns the number of seconds until this session expires. For sessions Returns the number of seconds until this session expires. For sessions
with no custom expiration (or those set to expire at browser close), this with no custom expiration (or those set to expire at browser close), this
@ -279,7 +388,7 @@ You can edit it multiple times.
- ``expiry``: expiry information for the session, as a - ``expiry``: expiry information for the session, as a
:class:`~datetime.datetime` object, an :class:`int` (in seconds), or :class:`~datetime.datetime` object, an :class:`int` (in seconds), or
``None``. Defaults to the value stored in the session by ``None``. Defaults to the value stored in the session by
:meth:`set_expiry`, if there is one, or ``None``. :meth:`set_expiry`/:meth:`aset_expiry`, if there is one, or ``None``.
.. note:: .. note::
@ -295,7 +404,14 @@ You can edit it multiple times.
expires_at = modification + timedelta(seconds=settings.SESSION_COOKIE_AGE) expires_at = modification + timedelta(seconds=settings.SESSION_COOKIE_AGE)
.. versionchanged:: 5.1
``aget_expiry_age()`` function was added.
.. method:: get_expiry_date() .. method:: get_expiry_date()
.. method:: aget_expiry_date()
*Asynchronous version*: ``aget_expiry_date()``
Returns the date this session will expire. For sessions with no custom Returns the date this session will expire. For sessions with no custom
expiration (or those set to expire at browser close), this will equal the expiration (or those set to expire at browser close), this will equal the
@ -304,22 +420,47 @@ You can edit it multiple times.
This function accepts the same keyword arguments as This function accepts the same keyword arguments as
:meth:`get_expiry_age`, and similar notes on usage apply. :meth:`get_expiry_age`, and similar notes on usage apply.
.. versionchanged:: 5.1
``aget_expiry_date()`` function was added.
.. method:: get_expire_at_browser_close() .. method:: get_expire_at_browser_close()
.. method:: aget_expire_at_browser_close()
*Asynchronous version*: ``aget_expire_at_browser_close()``
Returns either ``True`` or ``False``, depending on whether the user's Returns either ``True`` or ``False``, depending on whether the user's
session cookie will expire when the user's web browser is closed. session cookie will expire when the user's web browser is closed.
.. versionchanged:: 5.1
``aget_expire_at_browser_close()`` function was added.
.. method:: clear_expired() .. method:: clear_expired()
.. method:: aclear_expired()
*Asynchronous version*: ``aclear_expired()``
Removes expired sessions from the session store. This class method is Removes expired sessions from the session store. This class method is
called by :djadmin:`clearsessions`. called by :djadmin:`clearsessions`.
.. versionchanged:: 5.1
``aclear_expired()`` function was added.
.. method:: cycle_key() .. method:: cycle_key()
.. method:: acycle_key()
*Asynchronous version*: ``acycle_key()``
Creates a new session key while retaining the current session data. Creates a new session key while retaining the current session data.
:func:`django.contrib.auth.login()` calls this method to mitigate against :func:`django.contrib.auth.login()` calls this method to mitigate against
session fixation. session fixation.
.. versionchanged:: 5.1
``acycle_key()`` function was added.
.. _session_serialization: .. _session_serialization:
Session serialization Session serialization
@ -475,6 +616,10 @@ Here's a typical usage example::
request.session.set_test_cookie() request.session.set_test_cookie()
return render(request, "foo/login_form.html") return render(request, "foo/login_form.html")
.. versionchanged:: 5.1
Support for setting test cookies in asynchronous view functions was added.
Using sessions out of views Using sessions out of views
=========================== ===========================
@ -694,16 +839,26 @@ the corresponding session engine. By convention, the session store object class
is named ``SessionStore`` and is located in the module designated by is named ``SessionStore`` and is located in the module designated by
:setting:`SESSION_ENGINE`. :setting:`SESSION_ENGINE`.
All ``SessionStore`` classes available in Django inherit from All ``SessionStore`` subclasses available in Django implement the following
:class:`~backends.base.SessionBase` and implement data manipulation methods, data manipulation methods:
namely:
* ``exists()`` * ``exists()``
* ``create()`` * ``create()``
* ``save()`` * ``save()``
* ``delete()`` * ``delete()``
* ``load()`` * ``load()``
* :meth:`~backends.base.SessionBase.clear_expired` * :meth:`~.SessionBase.clear_expired`
An asynchronous interface for these methods is provided by wrapping them with
``sync_to_async()``. They can be implemented directly if an async-native
implementation is available:
* ``aexists()``
* ``acreate()``
* ``asave()``
* ``adelete()``
* ``aload()``
* :meth:`~.SessionBase.aclear_expired`
In order to build a custom session engine or to customize an existing one, you In order to build a custom session engine or to customize an existing one, you
may create a new class inheriting from :class:`~backends.base.SessionBase` or may create a new class inheriting from :class:`~backends.base.SessionBase` or
@ -713,6 +868,11 @@ You can extend the session engines, but doing so with database-backed session
engines generally requires some extra effort (see the next section for engines generally requires some extra effort (see the next section for
details). details).
.. versionchanged:: 5.1
``aexists()``, ``acreate()``, ``asave()``, ``adelete()``, ``aload()``, and
``aclear_expired()`` methods were added.
.. _extending-database-backed-session-engines: .. _extending-database-backed-session-engines:
Extending database-backed session engines Extending database-backed session engines

View File

@ -5,3 +5,6 @@ class CacheClass(LocMemCache):
def set(self, *args, **kwargs): def set(self, *args, **kwargs):
raise Exception("Faked exception saving to cache") raise Exception("Faked exception saving to cache")
async def aset(self, *args, **kwargs):
raise Exception("Faked exception saving to cache")

View File

@ -61,11 +61,19 @@ class SessionTestsMixin:
def test_get_empty(self): def test_get_empty(self):
self.assertIsNone(self.session.get("cat")) self.assertIsNone(self.session.get("cat"))
async def test_get_empty_async(self):
self.assertIsNone(await self.session.aget("cat"))
def test_store(self): def test_store(self):
self.session["cat"] = "dog" self.session["cat"] = "dog"
self.assertIs(self.session.modified, True) self.assertIs(self.session.modified, True)
self.assertEqual(self.session.pop("cat"), "dog") self.assertEqual(self.session.pop("cat"), "dog")
async def test_store_async(self):
await self.session.aset("cat", "dog")
self.assertIs(self.session.modified, True)
self.assertEqual(await self.session.apop("cat"), "dog")
def test_pop(self): def test_pop(self):
self.session["some key"] = "exists" self.session["some key"] = "exists"
# Need to reset these to pretend we haven't accessed it: # Need to reset these to pretend we haven't accessed it:
@ -77,6 +85,17 @@ class SessionTestsMixin:
self.assertIs(self.session.modified, True) self.assertIs(self.session.modified, True)
self.assertIsNone(self.session.get("some key")) self.assertIsNone(self.session.get("some key"))
async def test_pop_async(self):
await self.session.aset("some key", "exists")
# Need to reset these to pretend we haven't accessed it:
self.accessed = False
self.modified = False
self.assertEqual(await self.session.apop("some key"), "exists")
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, True)
self.assertIsNone(await self.session.aget("some key"))
def test_pop_default(self): def test_pop_default(self):
self.assertEqual( self.assertEqual(
self.session.pop("some key", "does not exist"), "does not exist" self.session.pop("some key", "does not exist"), "does not exist"
@ -84,6 +103,13 @@ class SessionTestsMixin:
self.assertIs(self.session.accessed, True) self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False) self.assertIs(self.session.modified, False)
async def test_pop_default_async(self):
self.assertEqual(
await self.session.apop("some key", "does not exist"), "does not exist"
)
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
def test_pop_default_named_argument(self): def test_pop_default_named_argument(self):
self.assertEqual( self.assertEqual(
self.session.pop("some key", default="does not exist"), "does not exist" self.session.pop("some key", default="does not exist"), "does not exist"
@ -91,22 +117,46 @@ class SessionTestsMixin:
self.assertIs(self.session.accessed, True) self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False) self.assertIs(self.session.modified, False)
async def test_pop_default_named_argument_async(self):
self.assertEqual(
await self.session.apop("some key", default="does not exist"),
"does not exist",
)
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
def test_pop_no_default_keyerror_raised(self): def test_pop_no_default_keyerror_raised(self):
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
self.session.pop("some key") self.session.pop("some key")
async def test_pop_no_default_keyerror_raised_async(self):
with self.assertRaises(KeyError):
await self.session.apop("some key")
def test_setdefault(self): def test_setdefault(self):
self.assertEqual(self.session.setdefault("foo", "bar"), "bar") self.assertEqual(self.session.setdefault("foo", "bar"), "bar")
self.assertEqual(self.session.setdefault("foo", "baz"), "bar") self.assertEqual(self.session.setdefault("foo", "baz"), "bar")
self.assertIs(self.session.accessed, True) self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, True) self.assertIs(self.session.modified, True)
async def test_setdefault_async(self):
self.assertEqual(await self.session.asetdefault("foo", "bar"), "bar")
self.assertEqual(await self.session.asetdefault("foo", "baz"), "bar")
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, True)
def test_update(self): def test_update(self):
self.session.update({"update key": 1}) self.session.update({"update key": 1})
self.assertIs(self.session.accessed, True) self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, True) self.assertIs(self.session.modified, True)
self.assertEqual(self.session.get("update key", None), 1) self.assertEqual(self.session.get("update key", None), 1)
async def test_update_async(self):
await self.session.aupdate({"update key": 1})
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, True)
self.assertEqual(await self.session.aget("update key", None), 1)
def test_has_key(self): def test_has_key(self):
self.session["some key"] = 1 self.session["some key"] = 1
self.session.modified = False self.session.modified = False
@ -115,6 +165,14 @@ class SessionTestsMixin:
self.assertIs(self.session.accessed, True) self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False) self.assertIs(self.session.modified, False)
async def test_has_key_async(self):
await self.session.aset("some key", 1)
self.session.modified = False
self.session.accessed = False
self.assertIs(await self.session.ahas_key("some key"), True)
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
def test_values(self): def test_values(self):
self.assertEqual(list(self.session.values()), []) self.assertEqual(list(self.session.values()), [])
self.assertIs(self.session.accessed, True) self.assertIs(self.session.accessed, True)
@ -125,6 +183,16 @@ class SessionTestsMixin:
self.assertIs(self.session.accessed, True) self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False) self.assertIs(self.session.modified, False)
async def test_values_async(self):
self.assertEqual(list(await self.session.avalues()), [])
self.assertIs(self.session.accessed, True)
await self.session.aset("some key", 1)
self.session.modified = False
self.session.accessed = False
self.assertEqual(list(await self.session.avalues()), [1])
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
def test_keys(self): def test_keys(self):
self.session["x"] = 1 self.session["x"] = 1
self.session.modified = False self.session.modified = False
@ -133,6 +201,14 @@ class SessionTestsMixin:
self.assertIs(self.session.accessed, True) self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False) self.assertIs(self.session.modified, False)
async def test_keys_async(self):
await self.session.aset("x", 1)
self.session.modified = False
self.session.accessed = False
self.assertEqual(list(await self.session.akeys()), ["x"])
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
def test_items(self): def test_items(self):
self.session["x"] = 1 self.session["x"] = 1
self.session.modified = False self.session.modified = False
@ -141,6 +217,14 @@ class SessionTestsMixin:
self.assertIs(self.session.accessed, True) self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False) self.assertIs(self.session.modified, False)
async def test_items_async(self):
await self.session.aset("x", 1)
self.session.modified = False
self.session.accessed = False
self.assertEqual(list(await self.session.aitems()), [("x", 1)])
self.assertIs(self.session.accessed, True)
self.assertIs(self.session.modified, False)
def test_clear(self): def test_clear(self):
self.session["x"] = 1 self.session["x"] = 1
self.session.modified = False self.session.modified = False
@ -155,11 +239,20 @@ class SessionTestsMixin:
self.session.save() self.session.save()
self.assertIs(self.session.exists(self.session.session_key), True) self.assertIs(self.session.exists(self.session.session_key), True)
async def test_save_async(self):
await self.session.asave()
self.assertIs(await self.session.aexists(self.session.session_key), True)
def test_delete(self): def test_delete(self):
self.session.save() self.session.save()
self.session.delete(self.session.session_key) self.session.delete(self.session.session_key)
self.assertIs(self.session.exists(self.session.session_key), False) self.assertIs(self.session.exists(self.session.session_key), False)
async def test_delete_async(self):
await self.session.asave()
await self.session.adelete(self.session.session_key)
self.assertIs(await self.session.aexists(self.session.session_key), False)
def test_flush(self): def test_flush(self):
self.session["foo"] = "bar" self.session["foo"] = "bar"
self.session.save() self.session.save()
@ -171,6 +264,17 @@ class SessionTestsMixin:
self.assertIs(self.session.modified, True) self.assertIs(self.session.modified, True)
self.assertIs(self.session.accessed, True) self.assertIs(self.session.accessed, True)
async def test_flush_async(self):
await self.session.aset("foo", "bar")
await self.session.asave()
prev_key = self.session.session_key
await self.session.aflush()
self.assertIs(await self.session.aexists(prev_key), False)
self.assertNotEqual(self.session.session_key, prev_key)
self.assertIsNone(self.session.session_key)
self.assertIs(self.session.modified, True)
self.assertIs(self.session.accessed, True)
def test_cycle(self): def test_cycle(self):
self.session["a"], self.session["b"] = "c", "d" self.session["a"], self.session["b"] = "c", "d"
self.session.save() self.session.save()
@ -181,6 +285,17 @@ class SessionTestsMixin:
self.assertNotEqual(self.session.session_key, prev_key) self.assertNotEqual(self.session.session_key, prev_key)
self.assertEqual(list(self.session.items()), prev_data) self.assertEqual(list(self.session.items()), prev_data)
async def test_cycle_async(self):
await self.session.aset("a", "c")
await self.session.aset("b", "d")
await self.session.asave()
prev_key = self.session.session_key
prev_data = list(await self.session.aitems())
await self.session.acycle_key()
self.assertIs(await self.session.aexists(prev_key), False)
self.assertNotEqual(self.session.session_key, prev_key)
self.assertEqual(list(await self.session.aitems()), prev_data)
def test_cycle_with_no_session_cache(self): def test_cycle_with_no_session_cache(self):
self.session["a"], self.session["b"] = "c", "d" self.session["a"], self.session["b"] = "c", "d"
self.session.save() self.session.save()
@ -190,11 +305,26 @@ class SessionTestsMixin:
self.session.cycle_key() self.session.cycle_key()
self.assertCountEqual(self.session.items(), prev_data) self.assertCountEqual(self.session.items(), prev_data)
async def test_cycle_with_no_session_cache_async(self):
await self.session.aset("a", "c")
await self.session.aset("b", "d")
await self.session.asave()
prev_data = await self.session.aitems()
self.session = self.backend(self.session.session_key)
self.assertIs(hasattr(self.session, "_session_cache"), False)
await self.session.acycle_key()
self.assertCountEqual(await self.session.aitems(), prev_data)
def test_save_doesnt_clear_data(self): def test_save_doesnt_clear_data(self):
self.session["a"] = "b" self.session["a"] = "b"
self.session.save() self.session.save()
self.assertEqual(self.session["a"], "b") self.assertEqual(self.session["a"], "b")
async def test_save_doesnt_clear_data_async(self):
await self.session.aset("a", "b")
await self.session.asave()
self.assertEqual(await self.session.aget("a"), "b")
def test_invalid_key(self): def test_invalid_key(self):
# Submitting an invalid session key (either by guessing, or if the db has # Submitting an invalid session key (either by guessing, or if the db has
# removed the key) results in a new key being generated. # removed the key) results in a new key being generated.
@ -209,6 +339,20 @@ class SessionTestsMixin:
# session key; make sure that entry is manually deleted # session key; make sure that entry is manually deleted
session.delete("1") session.delete("1")
async def test_invalid_key_async(self):
# Submitting an invalid session key (either by guessing, or if the db has
# removed the key) results in a new key being generated.
try:
session = self.backend("1")
await session.asave()
self.assertNotEqual(session.session_key, "1")
self.assertIsNone(await session.aget("cat"))
await session.adelete()
finally:
# Some backends leave a stale cache entry for the invalid
# session key; make sure that entry is manually deleted
await session.adelete("1")
def test_session_key_empty_string_invalid(self): def test_session_key_empty_string_invalid(self):
"""Falsey values (Such as an empty string) are rejected.""" """Falsey values (Such as an empty string) are rejected."""
self.session._session_key = "" self.session._session_key = ""
@ -241,6 +385,18 @@ class SessionTestsMixin:
self.session.set_expiry(0) self.session.set_expiry(0)
self.assertEqual(self.session.get_expiry_age(), settings.SESSION_COOKIE_AGE) self.assertEqual(self.session.get_expiry_age(), settings.SESSION_COOKIE_AGE)
async def test_default_expiry_async(self):
# A normal session has a max age equal to settings.
self.assertEqual(
await self.session.aget_expiry_age(), settings.SESSION_COOKIE_AGE
)
# So does a custom session with an idle expiration time of 0 (but it'll
# expire at browser close).
await self.session.aset_expiry(0)
self.assertEqual(
await self.session.aget_expiry_age(), settings.SESSION_COOKIE_AGE
)
def test_custom_expiry_seconds(self): def test_custom_expiry_seconds(self):
modification = timezone.now() modification = timezone.now()
@ -252,6 +408,17 @@ class SessionTestsMixin:
age = self.session.get_expiry_age(modification=modification) age = self.session.get_expiry_age(modification=modification)
self.assertEqual(age, 10) self.assertEqual(age, 10)
async def test_custom_expiry_seconds_async(self):
modification = timezone.now()
await self.session.aset_expiry(10)
date = await self.session.aget_expiry_date(modification=modification)
self.assertEqual(date, modification + timedelta(seconds=10))
age = await self.session.aget_expiry_age(modification=modification)
self.assertEqual(age, 10)
def test_custom_expiry_timedelta(self): def test_custom_expiry_timedelta(self):
modification = timezone.now() modification = timezone.now()
@ -269,6 +436,23 @@ class SessionTestsMixin:
age = self.session.get_expiry_age(modification=modification) age = self.session.get_expiry_age(modification=modification)
self.assertEqual(age, 10) self.assertEqual(age, 10)
async def test_custom_expiry_timedelta_async(self):
modification = timezone.now()
# Mock timezone.now, because set_expiry calls it on this code path.
original_now = timezone.now
try:
timezone.now = lambda: modification
await self.session.aset_expiry(timedelta(seconds=10))
finally:
timezone.now = original_now
date = await self.session.aget_expiry_date(modification=modification)
self.assertEqual(date, modification + timedelta(seconds=10))
age = await self.session.aget_expiry_age(modification=modification)
self.assertEqual(age, 10)
def test_custom_expiry_datetime(self): def test_custom_expiry_datetime(self):
modification = timezone.now() modification = timezone.now()
@ -280,12 +464,31 @@ class SessionTestsMixin:
age = self.session.get_expiry_age(modification=modification) age = self.session.get_expiry_age(modification=modification)
self.assertEqual(age, 10) self.assertEqual(age, 10)
async def test_custom_expiry_datetime_async(self):
modification = timezone.now()
await self.session.aset_expiry(modification + timedelta(seconds=10))
date = await self.session.aget_expiry_date(modification=modification)
self.assertEqual(date, modification + timedelta(seconds=10))
age = await self.session.aget_expiry_age(modification=modification)
self.assertEqual(age, 10)
def test_custom_expiry_reset(self): def test_custom_expiry_reset(self):
self.session.set_expiry(None) self.session.set_expiry(None)
self.session.set_expiry(10) self.session.set_expiry(10)
self.session.set_expiry(None) self.session.set_expiry(None)
self.assertEqual(self.session.get_expiry_age(), settings.SESSION_COOKIE_AGE) self.assertEqual(self.session.get_expiry_age(), settings.SESSION_COOKIE_AGE)
async def test_custom_expiry_reset_async(self):
await self.session.aset_expiry(None)
await self.session.aset_expiry(10)
await self.session.aset_expiry(None)
self.assertEqual(
await self.session.aget_expiry_age(), settings.SESSION_COOKIE_AGE
)
def test_get_expire_at_browser_close(self): def test_get_expire_at_browser_close(self):
# Tests get_expire_at_browser_close with different settings and different # Tests get_expire_at_browser_close with different settings and different
# set_expiry calls # set_expiry calls
@ -309,6 +512,29 @@ class SessionTestsMixin:
self.session.set_expiry(None) self.session.set_expiry(None)
self.assertIs(self.session.get_expire_at_browser_close(), True) self.assertIs(self.session.get_expire_at_browser_close(), True)
async def test_get_expire_at_browser_close_async(self):
# Tests get_expire_at_browser_close with different settings and different
# set_expiry calls
with override_settings(SESSION_EXPIRE_AT_BROWSER_CLOSE=False):
await self.session.aset_expiry(10)
self.assertIs(await self.session.aget_expire_at_browser_close(), False)
await self.session.aset_expiry(0)
self.assertIs(await self.session.aget_expire_at_browser_close(), True)
await self.session.aset_expiry(None)
self.assertIs(await self.session.aget_expire_at_browser_close(), False)
with override_settings(SESSION_EXPIRE_AT_BROWSER_CLOSE=True):
await self.session.aset_expiry(10)
self.assertIs(await self.session.aget_expire_at_browser_close(), False)
await self.session.aset_expiry(0)
self.assertIs(await self.session.aget_expire_at_browser_close(), True)
await self.session.aset_expiry(None)
self.assertIs(await self.session.aget_expire_at_browser_close(), True)
def test_decode(self): def test_decode(self):
# Ensure we can decode what we encode # Ensure we can decode what we encode
data = {"a test key": "a test value"} data = {"a test key": "a test value"}
@ -350,6 +576,22 @@ class SessionTestsMixin:
self.session.delete(old_session_key) self.session.delete(old_session_key)
self.session.delete(new_session_key) self.session.delete(new_session_key)
async def test_actual_expiry_async(self):
old_session_key = None
new_session_key = None
try:
await self.session.aset("foo", "bar")
await self.session.aset_expiry(-timedelta(seconds=10))
await self.session.asave()
old_session_key = self.session.session_key
# With an expiry date in the past, the session expires instantly.
new_session = self.backend(self.session.session_key)
new_session_key = new_session.session_key
self.assertIs(await new_session.ahas_key("foo"), False)
finally:
await self.session.adelete(old_session_key)
await self.session.adelete(new_session_key)
def test_session_load_does_not_create_record(self): def test_session_load_does_not_create_record(self):
""" """
Loading an unknown session key does not create a session record. Loading an unknown session key does not create a session record.
@ -364,6 +606,15 @@ class SessionTestsMixin:
# provided unknown key was cycled, not reused # provided unknown key was cycled, not reused
self.assertNotEqual(session.session_key, "someunknownkey") self.assertNotEqual(session.session_key, "someunknownkey")
async def test_session_load_does_not_create_record_async(self):
session = self.backend("someunknownkey")
await session.aload()
self.assertIsNone(session.session_key)
self.assertIs(await session.aexists(session.session_key), False)
# Provided unknown key was cycled, not reused.
self.assertNotEqual(session.session_key, "someunknownkey")
def test_session_save_does_not_resurrect_session_logged_out_in_other_context(self): def test_session_save_does_not_resurrect_session_logged_out_in_other_context(self):
""" """
Sessions shouldn't be resurrected by a concurrent request. Sessions shouldn't be resurrected by a concurrent request.
@ -386,6 +637,28 @@ class SessionTestsMixin:
self.assertEqual(s1.load(), {}) self.assertEqual(s1.load(), {})
async def test_session_asave_does_not_resurrect_session_logged_out_in_other_context(
self,
):
"""Sessions shouldn't be resurrected by a concurrent request."""
# Create new session.
s1 = self.backend()
await s1.aset("test_data", "value1")
await s1.asave(must_create=True)
# Logout in another context.
s2 = self.backend(s1.session_key)
await s2.adelete()
# Modify session in first context.
await s1.aset("test_data", "value2")
with self.assertRaises(UpdateError):
# This should throw an exception as the session is deleted, not
# resurrect the session.
await s1.asave()
self.assertEqual(await s1.aload(), {})
class DatabaseSessionTests(SessionTestsMixin, TestCase): class DatabaseSessionTests(SessionTestsMixin, TestCase):
backend = DatabaseSession backend = DatabaseSession
@ -456,6 +729,25 @@ class DatabaseSessionTests(SessionTestsMixin, TestCase):
# ... and one is deleted. # ... and one is deleted.
self.assertEqual(1, self.model.objects.count()) self.assertEqual(1, self.model.objects.count())
async def test_aclear_expired(self):
self.assertEqual(await self.model.objects.acount(), 0)
# Object in the future.
await self.session.aset("key", "value")
await self.session.aset_expiry(3600)
await self.session.asave()
# Object in the past.
other_session = self.backend()
await other_session.aset("key", "value")
await other_session.aset_expiry(-3600)
await other_session.asave()
# Two sessions are in the database before clearing expired.
self.assertEqual(await self.model.objects.acount(), 2)
await self.session.aclear_expired()
await other_session.aclear_expired()
self.assertEqual(await self.model.objects.acount(), 1)
@override_settings(USE_TZ=True) @override_settings(USE_TZ=True)
class DatabaseSessionWithTimeZoneTests(DatabaseSessionTests): class DatabaseSessionWithTimeZoneTests(DatabaseSessionTests):
@ -491,11 +783,28 @@ class CustomDatabaseSessionTests(DatabaseSessionTests):
self.session.set_expiry(None) self.session.set_expiry(None)
self.assertEqual(self.session.get_expiry_age(), self.custom_session_cookie_age) self.assertEqual(self.session.get_expiry_age(), self.custom_session_cookie_age)
async def test_custom_expiry_reset_async(self):
await self.session.aset_expiry(None)
await self.session.aset_expiry(10)
await self.session.aset_expiry(None)
self.assertEqual(
await self.session.aget_expiry_age(), self.custom_session_cookie_age
)
def test_default_expiry(self): def test_default_expiry(self):
self.assertEqual(self.session.get_expiry_age(), self.custom_session_cookie_age) self.assertEqual(self.session.get_expiry_age(), self.custom_session_cookie_age)
self.session.set_expiry(0) self.session.set_expiry(0)
self.assertEqual(self.session.get_expiry_age(), self.custom_session_cookie_age) self.assertEqual(self.session.get_expiry_age(), self.custom_session_cookie_age)
async def test_default_expiry_async(self):
self.assertEqual(
await self.session.aget_expiry_age(), self.custom_session_cookie_age
)
await self.session.aset_expiry(0)
self.assertEqual(
await self.session.aget_expiry_age(), self.custom_session_cookie_age
)
class CacheDBSessionTests(SessionTestsMixin, TestCase): class CacheDBSessionTests(SessionTestsMixin, TestCase):
backend = CacheDBSession backend = CacheDBSession
@ -533,6 +842,22 @@ class CacheDBSessionTests(SessionTestsMixin, TestCase):
self.assertEqual(log.message, f"Error saving to cache ({session._cache})") self.assertEqual(log.message, f"Error saving to cache ({session._cache})")
self.assertEqual(str(log.exc_info[1]), "Faked exception saving to cache") self.assertEqual(str(log.exc_info[1]), "Faked exception saving to cache")
@override_settings(
CACHES={"default": {"BACKEND": "cache.failing_cache.CacheClass"}}
)
async def test_cache_async_set_failure_non_fatal(self):
"""Failing to write to the cache does not raise errors."""
session = self.backend()
await session.aset("key", "val")
with self.assertLogs("django.contrib.sessions", "ERROR") as cm:
await session.asave()
# A proper ERROR log message was recorded.
log = cm.records[-1]
self.assertEqual(log.message, f"Error saving to cache ({session._cache})")
self.assertEqual(str(log.exc_info[1]), "Faked exception saving to cache")
@override_settings(USE_TZ=True) @override_settings(USE_TZ=True)
class CacheDBSessionWithTimeZoneTests(CacheDBSessionTests): class CacheDBSessionWithTimeZoneTests(CacheDBSessionTests):
@ -673,6 +998,12 @@ class CacheSessionTests(SessionTestsMixin, SimpleTestCase):
self.session.save() self.session.save()
self.assertIsNotNone(caches["default"].get(self.session.cache_key)) self.assertIsNotNone(caches["default"].get(self.session.cache_key))
async def test_create_and_save_async(self):
self.session = self.backend()
await self.session.acreate()
await self.session.asave()
self.assertIsNotNone(caches["default"].get(await self.session.acache_key()))
class SessionMiddlewareTests(TestCase): class SessionMiddlewareTests(TestCase):
request_factory = RequestFactory() request_factory = RequestFactory()
@ -899,6 +1230,9 @@ class CookieSessionTests(SessionTestsMixin, SimpleTestCase):
""" """
pass pass
async def test_save_async(self):
pass
def test_cycle(self): def test_cycle(self):
""" """
This test tested cycle_key() which would create a new session This test tested cycle_key() which would create a new session
@ -908,11 +1242,17 @@ class CookieSessionTests(SessionTestsMixin, SimpleTestCase):
""" """
pass pass
async def test_cycle_async(self):
pass
@unittest.expectedFailure @unittest.expectedFailure
def test_actual_expiry(self): def test_actual_expiry(self):
# The cookie backend doesn't handle non-default expiry dates, see #19201 # The cookie backend doesn't handle non-default expiry dates, see #19201
super().test_actual_expiry() super().test_actual_expiry()
async def test_actual_expiry_async(self):
pass
def test_unpickling_exception(self): def test_unpickling_exception(self):
# signed_cookies backend should handle unpickle exceptions gracefully # signed_cookies backend should handle unpickle exceptions gracefully
# by creating a new session # by creating a new session
@ -927,12 +1267,26 @@ class CookieSessionTests(SessionTestsMixin, SimpleTestCase):
def test_session_load_does_not_create_record(self): def test_session_load_does_not_create_record(self):
pass pass
@unittest.skip(
"Cookie backend doesn't have an external store to create records in."
)
async def test_session_load_does_not_create_record_async(self):
pass
@unittest.skip( @unittest.skip(
"CookieSession is stored in the client and there is no way to query it." "CookieSession is stored in the client and there is no way to query it."
) )
def test_session_save_does_not_resurrect_session_logged_out_in_other_context(self): def test_session_save_does_not_resurrect_session_logged_out_in_other_context(self):
pass pass
@unittest.skip(
"CookieSession is stored in the client and there is no way to query it."
)
async def test_session_asave_does_not_resurrect_session_logged_out_in_other_context(
self,
):
pass
class ClearSessionsCommandTests(SimpleTestCase): class ClearSessionsCommandTests(SimpleTestCase):
def test_clearsessions_unsupported(self): def test_clearsessions_unsupported(self):
@ -956,26 +1310,51 @@ class SessionBaseTests(SimpleTestCase):
with self.assertRaisesMessage(NotImplementedError, msg): with self.assertRaisesMessage(NotImplementedError, msg):
self.session.create() self.session.create()
async def test_acreate(self):
msg = self.not_implemented_msg % "a create"
with self.assertRaisesMessage(NotImplementedError, msg):
await self.session.acreate()
def test_delete(self): def test_delete(self):
msg = self.not_implemented_msg % "a delete" msg = self.not_implemented_msg % "a delete"
with self.assertRaisesMessage(NotImplementedError, msg): with self.assertRaisesMessage(NotImplementedError, msg):
self.session.delete() self.session.delete()
async def test_adelete(self):
msg = self.not_implemented_msg % "a delete"
with self.assertRaisesMessage(NotImplementedError, msg):
await self.session.adelete()
def test_exists(self): def test_exists(self):
msg = self.not_implemented_msg % "an exists" msg = self.not_implemented_msg % "an exists"
with self.assertRaisesMessage(NotImplementedError, msg): with self.assertRaisesMessage(NotImplementedError, msg):
self.session.exists(None) self.session.exists(None)
async def test_aexists(self):
msg = self.not_implemented_msg % "an exists"
with self.assertRaisesMessage(NotImplementedError, msg):
await self.session.aexists(None)
def test_load(self): def test_load(self):
msg = self.not_implemented_msg % "a load" msg = self.not_implemented_msg % "a load"
with self.assertRaisesMessage(NotImplementedError, msg): with self.assertRaisesMessage(NotImplementedError, msg):
self.session.load() self.session.load()
async def test_aload(self):
msg = self.not_implemented_msg % "a load"
with self.assertRaisesMessage(NotImplementedError, msg):
await self.session.aload()
def test_save(self): def test_save(self):
msg = self.not_implemented_msg % "a save" msg = self.not_implemented_msg % "a save"
with self.assertRaisesMessage(NotImplementedError, msg): with self.assertRaisesMessage(NotImplementedError, msg):
self.session.save() self.session.save()
async def test_asave(self):
msg = self.not_implemented_msg % "a save"
with self.assertRaisesMessage(NotImplementedError, msg):
await self.session.asave()
def test_test_cookie(self): def test_test_cookie(self):
self.assertIs(self.session.has_key(self.session.TEST_COOKIE_NAME), False) self.assertIs(self.session.has_key(self.session.TEST_COOKIE_NAME), False)
self.session.set_test_cookie() self.session.set_test_cookie()
@ -983,5 +1362,12 @@ class SessionBaseTests(SimpleTestCase):
self.session.delete_test_cookie() self.session.delete_test_cookie()
self.assertIs(self.session.has_key(self.session.TEST_COOKIE_NAME), False) self.assertIs(self.session.has_key(self.session.TEST_COOKIE_NAME), False)
async def test_atest_cookie(self):
self.assertIs(await self.session.ahas_key(self.session.TEST_COOKIE_NAME), False)
await self.session.aset_test_cookie()
self.assertIs(await self.session.atest_cookie_worked(), True)
await self.session.adelete_test_cookie()
self.assertIs(await self.session.ahas_key(self.session.TEST_COOKIE_NAME), False)
def test_is_empty(self): def test_is_empty(self):
self.assertIs(self.session.is_empty(), True) self.assertIs(self.session.is_empty(), True)