Refs #16860 -- Moved password_changed() logic to AbstractBaseUser.
Thanks Carl Meyer for review.
This commit is contained in:
parent
d7848c11e0
commit
f5e9d67907
|
@ -4,6 +4,7 @@ not in INSTALLED_APPS.
|
|||
"""
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.contrib.auth import password_validation
|
||||
from django.contrib.auth.hashers import (
|
||||
check_password, is_password_usable, make_password,
|
||||
)
|
||||
|
@ -60,9 +61,21 @@ class AbstractBaseUser(models.Model):
|
|||
"Return the identifying username for this User"
|
||||
return getattr(self, self.USERNAME_FIELD)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AbstractBaseUser, self).__init__(*args, **kwargs)
|
||||
# Stores the raw password if set_password() is called so that it can
|
||||
# be passed to password_changed() after the model is saved.
|
||||
self._password = None
|
||||
|
||||
def __str__(self):
|
||||
return self.get_username()
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
super(AbstractBaseUser, self).save(*args, **kwargs)
|
||||
if self._password is not None:
|
||||
password_validation.password_changed(self._password, self)
|
||||
self._password = None
|
||||
|
||||
def natural_key(self):
|
||||
return (self.get_username(),)
|
||||
|
||||
|
@ -82,6 +95,7 @@ class AbstractBaseUser(models.Model):
|
|||
|
||||
def set_password(self, raw_password):
|
||||
self.password = make_password(raw_password)
|
||||
self._password = raw_password
|
||||
|
||||
def check_password(self, raw_password):
|
||||
"""
|
||||
|
@ -90,6 +104,8 @@ class AbstractBaseUser(models.Model):
|
|||
"""
|
||||
def setter(raw_password):
|
||||
self.set_password(raw_password)
|
||||
# Password hash upgrades shouldn't be considered password changes.
|
||||
self._password = None
|
||||
self.save(update_fields=["password"])
|
||||
return check_password(raw_password, self.password, setter)
|
||||
|
||||
|
|
|
@ -289,7 +289,6 @@ class SetPasswordForm(forms.Form):
|
|||
def save(self, commit=True):
|
||||
password = self.cleaned_data["new_password1"]
|
||||
self.user.set_password(password)
|
||||
password_validation.password_changed(password, self.user)
|
||||
if commit:
|
||||
self.user.save()
|
||||
return self.user
|
||||
|
@ -363,7 +362,6 @@ class AdminPasswordChangeForm(forms.Form):
|
|||
"""
|
||||
password = self.cleaned_data["password1"]
|
||||
self.user.set_password(password)
|
||||
password_validation.password_changed(password, self.user)
|
||||
if commit:
|
||||
self.user.save()
|
||||
return self.user
|
||||
|
|
|
@ -379,6 +379,11 @@ or if you have API calls that allow passwords to be set, for example.
|
|||
by validators such as one that prevents password reuse. This should be
|
||||
called once the password has been successfully changed.
|
||||
|
||||
For subclasses of :class:`~django.contrib.auth.models.AbstractBaseUser`,
|
||||
the password field will be marked as "dirty" when calling
|
||||
:meth:`~django.contrib.auth.models.AbstractBaseUser.set_password` which
|
||||
triggers a call to ``password_changed()`` after the user is saved.
|
||||
|
||||
.. function:: password_validators_help_texts(password_validators=None)
|
||||
|
||||
Returns a list of the help texts of all validators. These explain the
|
||||
|
|
|
@ -5,16 +5,16 @@ import re
|
|||
|
||||
from django import forms
|
||||
from django.contrib.auth.forms import (
|
||||
AuthenticationForm, PasswordChangeForm, PasswordResetForm,
|
||||
ReadOnlyPasswordHashField, ReadOnlyPasswordHashWidget, SetPasswordForm,
|
||||
UserChangeForm, UserCreationForm,
|
||||
AdminPasswordChangeForm, AuthenticationForm, PasswordChangeForm,
|
||||
PasswordResetForm, ReadOnlyPasswordHashField, ReadOnlyPasswordHashWidget,
|
||||
SetPasswordForm, UserChangeForm, UserCreationForm,
|
||||
)
|
||||
from django.contrib.auth.models import User
|
||||
from django.contrib.sites.models import Site
|
||||
from django.core import mail
|
||||
from django.core.mail import EmailMultiAlternatives
|
||||
from django.forms.fields import CharField, Field
|
||||
from django.test import SimpleTestCase, TestCase, override_settings
|
||||
from django.test import SimpleTestCase, TestCase, mock, override_settings
|
||||
from django.utils import translation
|
||||
from django.utils.encoding import force_text
|
||||
from django.utils.text import capfirst
|
||||
|
@ -116,7 +116,8 @@ class UserCreationFormTest(TestDataMixin, TestCase):
|
|||
self.assertEqual(form['password1'].errors, required_error)
|
||||
self.assertEqual(form['password2'].errors, [])
|
||||
|
||||
def test_success(self):
|
||||
@mock.patch('django.contrib.auth.password_validation.password_changed')
|
||||
def test_success(self, password_changed):
|
||||
# The success case.
|
||||
data = {
|
||||
'username': 'jsmith@example.com',
|
||||
|
@ -125,7 +126,10 @@ class UserCreationFormTest(TestDataMixin, TestCase):
|
|||
}
|
||||
form = UserCreationForm(data)
|
||||
self.assertTrue(form.is_valid())
|
||||
form.save(commit=False)
|
||||
self.assertEqual(password_changed.call_count, 0)
|
||||
u = form.save()
|
||||
self.assertEqual(password_changed.call_count, 1)
|
||||
self.assertEqual(repr(u), '<User: jsmith@example.com>')
|
||||
|
||||
|
||||
|
@ -254,7 +258,8 @@ class SetPasswordFormTest(TestDataMixin, TestCase):
|
|||
self.assertEqual(form["new_password2"].errors,
|
||||
[force_text(form.error_messages['password_mismatch'])])
|
||||
|
||||
def test_success(self):
|
||||
@mock.patch('django.contrib.auth.password_validation.password_changed')
|
||||
def test_success(self, password_changed):
|
||||
user = User.objects.get(username='testclient')
|
||||
data = {
|
||||
'new_password1': 'abc123',
|
||||
|
@ -262,6 +267,10 @@ class SetPasswordFormTest(TestDataMixin, TestCase):
|
|||
}
|
||||
form = SetPasswordForm(user, data)
|
||||
self.assertTrue(form.is_valid())
|
||||
form.save(commit=False)
|
||||
self.assertEqual(password_changed.call_count, 0)
|
||||
form.save()
|
||||
self.assertEqual(password_changed.call_count, 1)
|
||||
|
||||
@override_settings(AUTH_PASSWORD_VALIDATORS=[
|
||||
{'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator'},
|
||||
|
@ -313,7 +322,8 @@ class PasswordChangeFormTest(TestDataMixin, TestCase):
|
|||
self.assertEqual(form["new_password2"].errors,
|
||||
[force_text(form.error_messages['password_mismatch'])])
|
||||
|
||||
def test_success(self):
|
||||
@mock.patch('django.contrib.auth.password_validation.password_changed')
|
||||
def test_success(self, password_changed):
|
||||
# The success case.
|
||||
user = User.objects.get(username='testclient')
|
||||
data = {
|
||||
|
@ -323,6 +333,10 @@ class PasswordChangeFormTest(TestDataMixin, TestCase):
|
|||
}
|
||||
form = PasswordChangeForm(user, data)
|
||||
self.assertTrue(form.is_valid())
|
||||
form.save(commit=False)
|
||||
self.assertEqual(password_changed.call_count, 0)
|
||||
form.save()
|
||||
self.assertEqual(password_changed.call_count, 1)
|
||||
|
||||
def test_field_order(self):
|
||||
# Regression test - check the order of fields:
|
||||
|
@ -586,3 +600,21 @@ class ReadOnlyPasswordHashTest(SimpleTestCase):
|
|||
def test_readonly_field_has_changed(self):
|
||||
field = ReadOnlyPasswordHashField()
|
||||
self.assertFalse(field.has_changed('aaa', 'bbb'))
|
||||
|
||||
|
||||
@override_settings(USE_TZ=False, PASSWORD_HASHERS=['django.contrib.auth.hashers.SHA1PasswordHasher'])
|
||||
class AdminPasswordChangeFormTest(TestDataMixin, TestCase):
|
||||
|
||||
@mock.patch('django.contrib.auth.password_validation.password_changed')
|
||||
def test_success(self, password_changed):
|
||||
user = User.objects.get(username='testclient')
|
||||
data = {
|
||||
'password1': 'test123',
|
||||
'password2': 'test123',
|
||||
}
|
||||
form = AdminPasswordChangeForm(user, data)
|
||||
self.assertTrue(form.is_valid())
|
||||
form.save(commit=False)
|
||||
self.assertEqual(password_changed.call_count, 0)
|
||||
form.save()
|
||||
self.assertEqual(password_changed.call_count, 1)
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
import datetime
|
||||
|
||||
from django.conf.global_settings import PASSWORD_HASHERS
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.auth.hashers import get_hasher
|
||||
from django.contrib.auth.models import (
|
||||
AbstractUser, Group, Permission, User, UserManager,
|
||||
)
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.core import mail
|
||||
from django.db.models.signals import post_save
|
||||
from django.test import TestCase, override_settings
|
||||
from django.test import TestCase, mock, override_settings
|
||||
|
||||
|
||||
@override_settings(USE_TZ=False)
|
||||
|
@ -216,6 +218,41 @@ class AbstractUserTestCase(TestCase):
|
|||
user2 = User.objects.create_user(username='user2')
|
||||
self.assertIsNone(user2.last_login)
|
||||
|
||||
def test_user_double_save(self):
|
||||
"""
|
||||
Calling user.save() twice should trigger password_changed() once.
|
||||
"""
|
||||
user = User.objects.create_user(username='user', password='foo')
|
||||
user.set_password('bar')
|
||||
with mock.patch('django.contrib.auth.password_validation.password_changed') as pw_changed:
|
||||
user.save()
|
||||
self.assertEqual(pw_changed.call_count, 1)
|
||||
user.save()
|
||||
self.assertEqual(pw_changed.call_count, 1)
|
||||
|
||||
@override_settings(PASSWORD_HASHERS=PASSWORD_HASHERS)
|
||||
def test_check_password_upgrade(self):
|
||||
"""
|
||||
password_changed() shouldn't be called if User.check_password()
|
||||
triggers a hash iteration upgrade.
|
||||
"""
|
||||
user = User.objects.create_user(username='user', password='foo')
|
||||
initial_password = user.password
|
||||
self.assertTrue(user.check_password('foo'))
|
||||
hasher = get_hasher('default')
|
||||
self.assertEqual('pbkdf2_sha256', hasher.algorithm)
|
||||
|
||||
old_iterations = hasher.iterations
|
||||
try:
|
||||
# Upgrade the password iterations
|
||||
hasher.iterations = old_iterations + 1
|
||||
with mock.patch('django.contrib.auth.password_validation.password_changed') as pw_changed:
|
||||
user.check_password('foo')
|
||||
self.assertEqual(pw_changed.call_count, 0)
|
||||
self.assertNotEqual(initial_password, user.password)
|
||||
finally:
|
||||
hasher.iterations = old_iterations
|
||||
|
||||
|
||||
class IsActiveTestCase(TestCase):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue