diff --git a/django/core/meta/__init__.py b/django/core/meta/__init__.py index c6bcb09e4b..99abbd728c 100644 --- a/django/core/meta/__init__.py +++ b/django/core/meta/__init__.py @@ -208,6 +208,14 @@ class Options: if self.pk is None: self.fields.insert(0, AutoField('id', 'ID', primary_key=True)) self.pk = self.fields[0] + # Cache whether this has an AutoField. + self.has_auto_field = False + for f in self.fields: + is_auto = isinstance(f, AutoField) + if is_auto and self.has_auto_field: + raise AssertionError, "A model can't have more than one AutoField." + elif is_auto: + self.has_auto_field = True def __repr__(self): return '' % self.module_name @@ -717,37 +725,27 @@ def method_save(opts, self): self._pre_save() non_pks = [f for f in opts.fields if not f.primary_key] cursor = db.db.cursor() - add = not bool(getattr(self, opts.pk.name)) - for f in non_pks: - f.pre_save(self, getattr(self, f.name), add) - db_values = [f.get_db_prep_save(getattr(self, f.name)) for f in non_pks] - # OneToOne objects are a special case because there's no AutoField, and the - # primary key field is set manually. - if isinstance(opts.pk.rel, OneToOne): - cursor.execute("UPDATE %s SET %s WHERE %s=%%s" % \ - (opts.db_table, ','.join(['%s=%%s' % f.name for f in non_pks]), - opts.pk.name), db_values + [getattr(self, opts.pk.name)]) - if cursor.rowcount == 0: # If nothing was updated, add the record. - field_names = [f.name for f in opts.fields] - placeholders = ['%s'] * len(field_names) - cursor.execute("INSERT INTO %s (%s) VALUES (%s)" % \ - (opts.db_table, ','.join(field_names), ','.join(placeholders)), - [f.get_db_prep_save(getattr(self, f.name)) for f in opts.fields]) - else: - if not add: - cursor.execute("UPDATE %s SET %s WHERE %s=%%s" % \ - (opts.db_table, ','.join(['%s=%%s' % f.name for f in non_pks]), - opts.pk.name), db_values + [getattr(self, opts.pk.name)]) - else: - field_names = [f.name for f in non_pks] - placeholders = ['%s'] * len(field_names) - if opts.order_with_respect_to: - field_names.append('_order') - placeholders.append('(SELECT COUNT(*) FROM %s WHERE %s = %%s)' % \ - (opts.db_table, opts.order_with_respect_to.name)) - db_values.append(getattr(self, opts.order_with_respect_to.name)) - cursor.execute("INSERT INTO %s (%s) VALUES (%s)" % \ - (opts.db_table, ','.join(field_names), ','.join(placeholders)), db_values) + + # First, try an UPDATE. If that doesn't update anything, do an INSERT. + pk_set = bool(getattr(self, opts.pk.name)) + if pk_set: + db_values = [f.get_db_prep_save(f.pre_save(getattr(self, f.name), False)) for f in non_pks] + cursor.execute("UPDATE %s SET %s WHERE %s=%%s" % (opts.db_table, + ','.join(['%s=%%s' % f.name for f in non_pks]), opts.pk.name), + db_values + [getattr(self, opts.pk.name)]) + if not pk_set or cursor.rowcount == 0: + field_names = [f.name for f in opts.fields if not isinstance(f, AutoField)] + placeholders = ['%s'] * len(field_names) + db_values = [f.get_db_prep_save(f.pre_save(getattr(self, f.name), True)) for f in opts.fields if not isinstance(f, AutoField)] + if opts.order_with_respect_to: + field_names.append('_order') + # TODO: This assumes the database supports subqueries. + placeholders.append('(SELECT COUNT(*) FROM %s WHERE %s = %%s)' % \ + (opts.db_table, opts.order_with_respect_to.name)) + db_values.append(getattr(self, opts.order_with_respect_to.name)) + cursor.execute("INSERT INTO %s (%s) VALUES (%s)" % (opts.db_table, + ','.join(field_names), ','.join(placeholders)), db_values) + if opts.has_auto_field: setattr(self, opts.pk.name, db.get_last_insert_id(cursor, opts.db_table, opts.pk.name)) db.db.commit() # Run any post-save hooks. diff --git a/django/core/meta/fields.py b/django/core/meta/fields.py index 34027a8d7a..46e317dca4 100644 --- a/django/core/meta/fields.py +++ b/django/core/meta/fields.py @@ -82,12 +82,9 @@ class Field(object): else: self.db_index = db_index - def pre_save(self, obj, value, add): - """ - Hook for altering the object obj based on the value of this field and - and on the add/change status. - """ - pass + def pre_save(self, value, add): + "Returns field's value just before saving." + return value def get_db_prep_save(self, value): "Returns field's value prepared for saving into a database." @@ -236,6 +233,10 @@ class Field(object): class AutoField(Field): empty_strings_allowed = False + def __init__(self, *args, **kwargs): + assert kwargs.get('primary_key', False) is True, "%ss must have primary_key=True." % self.__class__.__name__ + Field.__init__(self, *args, **kwargs) + def get_manipulator_fields(self, opts, manipulator, change, name_prefix='', rel=False): if not rel: return [] # Don't add a FormField unless it's in a related context. @@ -280,9 +281,10 @@ class DateField(Field): value = str(value) return Field.get_db_prep_lookup(self, lookup_type, value) - def pre_save(self, obj, value, add): + def pre_save(self, value, add): if self.auto_now or (self.auto_now_add and add): - setattr(obj, self.name, datetime.datetime.now()) + return datetime.datetime.now() + return value def get_db_prep_save(self, value): # Casts dates into string format for entry into database. @@ -483,9 +485,10 @@ class TimeField(Field): value = str(value) return Field.get_db_prep_lookup(self, lookup_type, value) - def pre_save(self, obj, value, add): + def pre_save(self, value, add): if self.auto_now or (self.auto_now_add and add): - setattr(obj, self.name, datetime.datetime.now().time()) + return datetime.datetime.now().time() + return value def get_db_prep_save(self, value): # Casts dates into string format for entry into database.