diff --git a/django/contrib/formtools/tests.py b/django/contrib/formtools/tests.py index bc65a60fbe..7816c15bf5 100644 --- a/django/contrib/formtools/tests.py +++ b/django/contrib/formtools/tests.py @@ -1,5 +1,6 @@ import unittest from django import forms +from django.conf import settings from django.contrib.formtools import preview, wizard, utils from django import http from django.test import TestCase @@ -145,6 +146,9 @@ class WizardPageOneForm(forms.Form): class WizardPageTwoForm(forms.Form): field = forms.CharField() +class WizardPageThreeForm(forms.Form): + field = forms.CharField() + class WizardClass(wizard.FormWizard): def render_template(self, *args, **kw): return http.HttpResponse("") @@ -161,6 +165,15 @@ class DummyRequest(http.HttpRequest): self._dont_enforce_csrf_checks = True class WizardTests(TestCase): + + def setUp(self): + # Use a known SECRET_KEY to make security_hash tests deterministic + self.old_SECRET_KEY = settings.SECRET_KEY + settings.SECRET_KEY = "123" + + def tearDown(self): + settings.SECRET_KEY = self.old_SECRET_KEY + def test_step_starts_at_zero(self): """ step should be zero for the first form @@ -179,3 +192,25 @@ class WizardTests(TestCase): response = wizard(request) self.assertEquals(1, wizard.step) + def test_14498(self): + """ + Regression test for ticket #14498. + """ + that = self + reached = [False] + + class WizardWithProcessStep(WizardClass): + def process_step(self, request, form, step): + reached[0] = True + that.assertTrue(hasattr(form, 'cleaned_data')) + + wizard = WizardWithProcessStep([WizardPageOneForm, + WizardPageTwoForm, + WizardPageThreeForm]) + data = {"0-field": "test", + "1-field": "test2", + "hash_0": "2fdbefd4c0cad51509478fbacddf8b13", + "wizard_step": "1"} + wizard(DummyRequest(POST=data)) + self.assertTrue(reached[0]) + diff --git a/django/contrib/formtools/wizard.py b/django/contrib/formtools/wizard.py index 32e27df574..97d2fb8e34 100644 --- a/django/contrib/formtools/wizard.py +++ b/django/contrib/formtools/wizard.py @@ -68,39 +68,50 @@ class FormWizard(object): if current_step >= self.num_steps(): raise Http404('Step %s does not exist' % current_step) - # For each previous step, verify the hash and process. - # TODO: Move "hash_%d" to a method to make it configurable. - for i in range(current_step): - form = self.get_form(i, request.POST) - if request.POST.get("hash_%d" % i, '') != self.security_hash(request, form): - return self.render_hash_failure(request, i) - self.process_step(request, form, i) - # Process the current step. If it's valid, go to the next step or call # done(), depending on whether any steps remain. if request.method == 'POST': form = self.get_form(current_step, request.POST) else: form = self.get_form(current_step) + if form.is_valid(): + # Validate all the forms. If any of them fail validation, that + # must mean the validator relied on some other input, such as + # an external Web site. + + # It is also possible that validation might fail under certain + # attack situations: an attacker might be able to bypass previous + # stages, and generate correct security hashes for all the + # skipped stages by virtue of: + # 1) having filled out an identical form which doesn't have the + # validation (and does something different at the end), + # 2) or having filled out a previous version of the same form + # which had some validation missing, + # 3) or previously having filled out the form when they had + # more privileges than they do now. + # + # Since the hashes only take into account values, and not other + # other validation the form might do, we must re-do validation + # now for security reasons. + current_form_list = [self.get_form(i, request.POST) for i in range(current_step)] + + for i, f in enumerate(current_form_list): + if request.POST.get("hash_%d" % i, '') != self.security_hash(request, f): + return self.render_hash_failure(request, i) + + if not f.is_valid(): + return self.render_revalidation_failure(request, i, f) + else: + self.process_step(request, f, i) + + # Now progress to processing this step: self.process_step(request, form, current_step) next_step = current_step + 1 - # If this was the last step, validate all of the forms one more - # time, as a sanity check, and call done(). - num = self.num_steps() - if next_step == num: - final_form_list = [self.get_form(i, request.POST) for i in range(num)] - # Validate all the forms. If any of them fail validation, that - # must mean the validator relied on some other input, such as - # an external Web site. - for i, f in enumerate(final_form_list): - if not f.is_valid(): - return self.render_revalidation_failure(request, i, f) + if current_step == self.num_steps(): return self.done(request, final_form_list) - - # Otherwise, move along to the next step. else: form = self.get_form(next_step) self.step = current_step = next_step