diff --git a/django/contrib/formtools/tests/wizard/__init__.py b/django/contrib/formtools/tests/wizard/__init__.py index 732c02f940..a2a9692ac6 100644 --- a/django/contrib/formtools/tests/wizard/__init__.py +++ b/django/contrib/formtools/tests/wizard/__init__.py @@ -15,4 +15,5 @@ from django.contrib.formtools.tests.wizard.wizardtests.tests import ( CookieWizardTests, WizardTestKwargs, WizardTestGenericViewInterface, + WizardFormKwargsOverrideTests, ) diff --git a/django/contrib/formtools/tests/wizard/wizardtests/forms.py b/django/contrib/formtools/tests/wizard/wizardtests/forms.py index 9013e89aef..6a8132971b 100644 --- a/django/contrib/formtools/tests/wizard/wizardtests/forms.py +++ b/django/contrib/formtools/tests/wizard/wizardtests/forms.py @@ -2,8 +2,10 @@ import os import tempfile from django import forms +from django.contrib.auth.models import User from django.core.files.storage import FileSystemStorage from django.forms.formsets import formset_factory +from django.forms.models import modelformset_factory from django.http import HttpResponse from django.template import Template, Context @@ -50,6 +52,13 @@ class ContactWizard(WizardView): context.update({'another_var': True}) return context +class UserForm(forms.ModelForm): + class Meta: + model = User + fields = ('username', 'email') + +UserFormSet = modelformset_factory(User, form=UserForm) + class SessionContactWizard(ContactWizard): storage_name = 'django.contrib.formtools.wizard.storage.session.SessionStorage' diff --git a/django/contrib/formtools/tests/wizard/wizardtests/tests.py b/django/contrib/formtools/tests/wizard/wizardtests/tests.py index 592065e51a..2ef39dc7a3 100644 --- a/django/contrib/formtools/tests/wizard/wizardtests/tests.py +++ b/django/contrib/formtools/tests/wizard/wizardtests/tests.py @@ -7,6 +7,7 @@ from django.test.client import RequestFactory from django.conf import settings from django.contrib.auth.models import User from django.contrib.formtools.wizard.views import CookieWizardView +from django.contrib.formtools.tests.wizard.forms import UserForm, UserFormSet class WizardTests(object): @@ -331,3 +332,48 @@ class WizardTestGenericViewInterface(TestCase): response = view(factory.get('/')) self.assertEquals(response.context_data['test_key'], 'test_value') self.assertEquals(response.context_data['another_key'], 'another_value') + + +class WizardFormKwargsOverrideTests(TestCase): + def setUp(self): + super(WizardFormKwargsOverrideTests, self).setUp() + self.rf = RequestFactory() + + # Create two users so we can filter by is_staff when handing our + # wizard a queryset keyword argument. + self.normal_user = User.objects.create(username='test1', email='normal@example.com') + self.staff_user = User.objects.create(username='test2', email='staff@example.com', is_staff=True) + + def test_instance_is_maintained(self): + self.assertEqual(2, User.objects.count()) + queryset = User.objects.get(pk=self.staff_user.pk) + + class InstanceOverrideWizard(CookieWizardView): + def get_form_kwargs(self, step): + return {'instance': queryset} + + view = InstanceOverrideWizard.as_view([UserForm]) + response = view(self.rf.get('/')) + + form = response.context_data['wizard']['form'] + + self.assertNotEqual(form.instance.pk, None) + self.assertEqual(form.instance.pk, self.staff_user.pk) + self.assertEqual('staff@example.com', form.initial.get('email', None)) + + def test_queryset_is_maintained(self): + queryset = User.objects.filter(pk=self.staff_user.pk) + + class QuerySetOverrideWizard(CookieWizardView): + def get_form_kwargs(self, step): + return {'queryset': queryset} + + view = QuerySetOverrideWizard.as_view([UserFormSet]) + response = view(self.rf.get('/')) + + formset = response.context_data['wizard']['form'] + + self.assertNotEqual(formset.queryset, None) + self.assertEqual(formset.initial_form_count(), 1) + self.assertEqual(['staff@example.com'], + list(formset.queryset.values_list('email', flat=True))) diff --git a/django/contrib/formtools/wizard/views.py b/django/contrib/formtools/wizard/views.py index 06a03984a7..4372c8aa6a 100644 --- a/django/contrib/formtools/wizard/views.py +++ b/django/contrib/formtools/wizard/views.py @@ -385,11 +385,13 @@ class WizardView(TemplateView): 'initial': self.get_form_initial(step), }) if issubclass(self.form_list[step], forms.ModelForm): - # If the form is based on ModelForm, add instance if available. - kwargs.update({'instance': self.get_form_instance(step)}) + # If the form is based on ModelForm, add instance if available + # and not previously set. + kwargs.setdefault('instance', self.get_form_instance(step)) elif issubclass(self.form_list[step], forms.models.BaseModelFormSet): - # If the form is based on ModelFormSet, add queryset if available. - kwargs.update({'queryset': self.get_form_instance(step)}) + # If the form is based on ModelFormSet, add queryset if available + # and not previous set. + kwargs.setdefault('queryset', self.get_form_instance(step)) return self.form_list[step](**kwargs) def process_step(self, form):