diff --git a/django/core/cache/backends/base.py b/django/core/cache/backends/base.py index bb67399f3b..cd0d7bd103 100644 --- a/django/core/cache/backends/base.py +++ b/django/core/cache/backends/base.py @@ -14,6 +14,14 @@ class BaseCache(object): timeout = 300 self.default_timeout = timeout + def add(self, key, value, timeout=None): + """ + Set a value in the cache if the key does not already exist. If + timeout is given, that timeout will be used for the key; otherwise + the default cache timeout will be used. + """ + raise NotImplementedError + def get(self, key, default=None): """ Fetch a given key from the cache. If the key does not exist, return diff --git a/django/core/cache/backends/db.py b/django/core/cache/backends/db.py index 4a0d44a44e..8896b5974b 100644 --- a/django/core/cache/backends/db.py +++ b/django/core/cache/backends/db.py @@ -24,6 +24,9 @@ class CacheClass(BaseCache): except (ValueError, TypeError): self._cull_frequency = 3 + def add(self, key, value, timeout=None): + return self._base_set('add', key, value, timeout) + def get(self, key, default=None): cursor = connection.cursor() cursor.execute("SELECT cache_key, value, expires FROM %s WHERE cache_key = %%s" % self._table, [key]) @@ -38,6 +41,9 @@ class CacheClass(BaseCache): return pickle.loads(base64.decodestring(row[1])) def set(self, key, value, timeout=None): + return self._base_set('set', key, value, timeout) + + def _base_set(self, mode, key, value, timeout=None): if timeout is None: timeout = self.default_timeout cursor = connection.cursor() @@ -50,10 +56,11 @@ class CacheClass(BaseCache): encoded = base64.encodestring(pickle.dumps(value, 2)).strip() cursor.execute("SELECT cache_key FROM %s WHERE cache_key = %%s" % self._table, [key]) try: - if cursor.fetchone(): + if mode == 'set' and cursor.fetchone(): cursor.execute("UPDATE %s SET value = %%s, expires = %%s WHERE cache_key = %%s" % self._table, [encoded, str(exp), key]) else: - cursor.execute("INSERT INTO %s (cache_key, value, expires) VALUES (%%s, %%s, %%s)" % self._table, [key, encoded, str(exp)]) + if mode == 'add': + cursor.execute("INSERT INTO %s (cache_key, value, expires) VALUES (%%s, %%s, %%s)" % self._table, [key, encoded, str(exp)]) except DatabaseError: # To be threadsafe, updates/inserts are allowed to fail silently pass diff --git a/django/core/cache/backends/dummy.py b/django/core/cache/backends/dummy.py index 4c64161538..3ff7e1c1b9 100644 --- a/django/core/cache/backends/dummy.py +++ b/django/core/cache/backends/dummy.py @@ -6,6 +6,9 @@ class CacheClass(BaseCache): def __init__(self, *args, **kwargs): pass + def add(self, *args, **kwargs): + pass + def get(self, key, default=None): return default diff --git a/django/core/cache/backends/filebased.py b/django/core/cache/backends/filebased.py index d5415c8ace..690193ac81 100644 --- a/django/core/cache/backends/filebased.py +++ b/django/core/cache/backends/filebased.py @@ -17,6 +17,26 @@ class CacheClass(SimpleCacheClass): del self._cache del self._expire_info + def add(self, key, value, timeout=None): + fname = self._key_to_file(key) + if timeout is None: + timeout = self.default_timeout + try: + filelist = os.listdir(self._dir) + except (IOError, OSError): + self._createdir() + filelist = [] + if len(filelist) > self._max_entries: + self._cull(filelist) + if os.path.basename(fname) not in filelist: + try: + f = open(fname, 'wb') + now = time.time() + pickle.dump(now + timeout, f, 2) + pickle.dump(value, f, 2) + except (IOError, OSError): + pass + def get(self, key, default=None): fname = self._key_to_file(key) try: diff --git a/django/core/cache/backends/locmem.py b/django/core/cache/backends/locmem.py index 4c48c571b7..5998f7bfd5 100644 --- a/django/core/cache/backends/locmem.py +++ b/django/core/cache/backends/locmem.py @@ -14,6 +14,13 @@ class CacheClass(SimpleCacheClass): SimpleCacheClass.__init__(self, host, params) self._lock = RWLock() + def add(self, key, value, timeout=None): + self._lock.writer_enters() + try: + SimpleCacheClass.add(self, key, value, timeout) + finally: + self._lock.writer_leaves() + def get(self, key, default=None): should_delete = False self._lock.reader_enters() diff --git a/django/core/cache/backends/memcached.py b/django/core/cache/backends/memcached.py index 52610daef1..096cec0ee0 100644 --- a/django/core/cache/backends/memcached.py +++ b/django/core/cache/backends/memcached.py @@ -16,6 +16,9 @@ class CacheClass(BaseCache): BaseCache.__init__(self, params) self._cache = memcache.Client(server.split(';')) + def add(self, key, value, timeout=0): + self._cache.add(key.encode('ascii', 'ignore'), value, timeout or self.default_timeout) + def get(self, key, default=None): val = self._cache.get(smart_str(key)) if val is None: diff --git a/django/core/cache/backends/simple.py b/django/core/cache/backends/simple.py index 3fcad8c7ad..ff60d49066 100644 --- a/django/core/cache/backends/simple.py +++ b/django/core/cache/backends/simple.py @@ -21,6 +21,15 @@ class CacheClass(BaseCache): except (ValueError, TypeError): self._cull_frequency = 3 + def add(self, key, value, timeout=None): + if len(self._cache) >= self._max_entries: + self._cull() + if timeout is None: + timeout = self.default_timeout + if key not in self._cache.keys(): + self._cache[key] = value + self._expire_info[key] = time.time() + timeout + def get(self, key, default=None): now = time.time() exp = self._expire_info.get(key) diff --git a/docs/cache.txt b/docs/cache.txt index 8ba0383909..6fe0a22ae7 100644 --- a/docs/cache.txt +++ b/docs/cache.txt @@ -326,6 +326,15 @@ get() can take a ``default`` argument:: >>> cache.get('my_key', 'has expired') 'has expired' +To add a key only if it doesn't already exist, there is an add() method. It +takes the same parameters as set(), but will not attempt to update the cache +if the key specified is already present:: + + >>> cache.set('add_key', 'Initial value') + >>> cache.add('add_key', 'New value') + >>> cache.get('add_key') + 'Initial value' + There's also a get_many() interface that only hits the cache once. get_many() returns a dictionary with all the keys you asked for that actually exist in the cache (and haven't expired):: diff --git a/tests/regressiontests/cache/tests.py b/tests/regressiontests/cache/tests.py index 752083bd2f..3879da7703 100644 --- a/tests/regressiontests/cache/tests.py +++ b/tests/regressiontests/cache/tests.py @@ -19,6 +19,12 @@ class Cache(unittest.TestCase): cache.set("key", "value") self.assertEqual(cache.get("key"), "value") + def test_add(self): + # test add (only add if key isn't already in cache) + cache.add("addkey1", "value") + cache.add("addkey1", "newvalue") + self.assertEqual(cache.get("addkey1"), "value") + def test_non_existent(self): # get with non-existent keys self.assertEqual(cache.get("does_not_exist"), None)