From 27986994eea82645fb99ad1a373b7d99110eb3c0 Mon Sep 17 00:00:00 2001 From: Adrian Holovaty Date: Mon, 15 Aug 2005 16:00:28 +0000 Subject: [PATCH] Fixed #320 -- Changed save() code so that it doesn't rely on cursor.rowcount. MySQLdb sets cursor.rowcount to 0 in an UPDATE statement even if the record already exists. Now save() does a SELECT query to find out whether a record with the primary key already exists. git-svn-id: http://code.djangoproject.com/svn/django/trunk@507 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/core/meta/__init__.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/django/core/meta/__init__.py b/django/core/meta/__init__.py index 69f6538a4c..d28cb04f88 100644 --- a/django/core/meta/__init__.py +++ b/django/core/meta/__init__.py @@ -728,13 +728,21 @@ def method_save(opts, self): cursor = db.db.cursor() # First, try an UPDATE. If that doesn't update anything, do an INSERT. - pk_set = bool(getattr(self, opts.pk.name)) + pk_val = getattr(self, opts.pk.name) + pk_set = bool(pk_val) + record_exists = True if pk_set: - db_values = [f.get_db_prep_save(f.pre_save(getattr(self, f.name), False)) for f in non_pks] - cursor.execute("UPDATE %s SET %s WHERE %s=%%s" % (opts.db_table, - ','.join(['%s=%%s' % f.name for f in non_pks]), opts.pk.name), - db_values + [getattr(self, opts.pk.name)]) - if not pk_set or cursor.rowcount == 0: + # Determine whether a record with the primary key already exists. + cursor.execute("SELECT 1 FROM %s WHERE %s=%%s LIMIT 1" % (opts.db_table, opts.pk.name), [pk_val]) + # If it does already exist, do an UPDATE. + if cursor.rowcount > 0: + db_values = [f.get_db_prep_save(f.pre_save(getattr(self, f.name), False)) for f in non_pks] + cursor.execute("UPDATE %s SET %s WHERE %s=%%s" % (opts.db_table, + ','.join(['%s=%%s' % f.name for f in non_pks]), opts.pk.name), + db_values + [pk_val]) + else: + record_exists = False + if not pk_set or not record_exists: field_names = [f.name for f in opts.fields if not isinstance(f, AutoField)] placeholders = ['%s'] * len(field_names) db_values = [f.get_db_prep_save(f.pre_save(getattr(self, f.name), True)) for f in opts.fields if not isinstance(f, AutoField)]