diff --git a/django/contrib/formtools/preview.py b/django/contrib/formtools/preview.py index b85c6ef7d1..3fa61ba6bf 100644 --- a/django/contrib/formtools/preview.py +++ b/django/contrib/formtools/preview.py @@ -51,17 +51,17 @@ class FormPreview(object): def preview_get(self, request): "Displays the form" - f = self.form(auto_id=AUTO_ID) + f = self.form(auto_id=self.get_auto_id(), initial=self.get_initial(request)) return render_to_response(self.form_template, - {'form': f, 'stage_field': self.unused_name('stage'), 'state': self.state}, + self.get_context(request, f), context_instance=RequestContext(request)) def preview_post(self, request): "Validates the POST data. If valid, displays the preview page. Else, redisplays form." - f = self.form(request.POST, auto_id=AUTO_ID) - context = {'form': f, 'stage_field': self.unused_name('stage'), 'state': self.state} + f = self.form(request.POST, auto_id=self.get_auto_id()) + context = self.get_context(request, f) if f.is_valid(): - self.process_preview(request, f, context) + self.process_preview(request, f, context) context['hash_field'] = self.unused_name('hash') context['hash_value'] = self.security_hash(request, f) return render_to_response(self.preview_template, context, context_instance=RequestContext(request)) @@ -91,7 +91,7 @@ class FormPreview(object): def post_post(self, request): "Validates the POST data. If valid, calls done(). Else, redisplays form." - f = self.form(request.POST, auto_id=AUTO_ID) + f = self.form(request.POST, auto_id=self.get_auto_id()) if f.is_valid(): if not self._check_security_hash(request.POST.get(self.unused_name('hash'), ''), request, f): @@ -99,11 +99,30 @@ class FormPreview(object): return self.done(request, f.cleaned_data) else: return render_to_response(self.form_template, - {'form': f, 'stage_field': self.unused_name('stage'), 'state': self.state}, + self.get_context(request, f), context_instance=RequestContext(request)) # METHODS SUBCLASSES MIGHT OVERRIDE IF APPROPRIATE ######################## + def get_auto_id(self): + """ + Hook to override the ``auto_id`` kwarg for the form. Needed when + rendering two form previews in the same template. + """ + return AUTO_ID + + def get_initial(self, request): + """ + Takes a request argument and returns a dictionary to pass to the form's + ``initial`` kwarg when the form is being created from an HTTP get. + """ + return {} + + def get_context(self, request, form): + "Context for template rendering." + return {'form': form, 'stage_field': self.unused_name('stage'), 'state': self.state} + + def parse_params(self, *args, **kwargs): """ Given captured args and kwargs from the URLconf, saves something in diff --git a/django/contrib/formtools/tests/__init__.py b/django/contrib/formtools/tests/__init__.py index 4c71c50321..6aeeaf5a3d 100644 --- a/django/contrib/formtools/tests/__init__.py +++ b/django/contrib/formtools/tests/__init__.py @@ -11,6 +11,13 @@ success_string = "Done was called!" class TestFormPreview(preview.FormPreview): + def get_context(self, request, form): + context = super(TestFormPreview, self).get_context(request, form) + context.update({'custom_context': True}) + return context + + def get_initial(self, request): + return {'field1': 'Works!'} def done(self, request, cleaned_data): return http.HttpResponse(success_string) @@ -59,6 +66,8 @@ class PreviewTests(TestCase): response = self.client.get('/test1/') stage = self.input % 1 self.assertContains(response, stage, 1) + self.assertEquals(response.context['custom_context'], True) + self.assertEquals(response.context['form'].initial, {'field1': 'Works!'}) def test_form_preview(self): """