From 7ffbadd3056273278c8208fe63864afa1ec92002 Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Tue, 1 Mar 2022 12:35:02 +0200 Subject: [PATCH] Add t_hash.c file. Add some compact_object tests for set,hset,zset interfaces. Enable memory leak checking for mimalloc allocator in compact_object_test. --- src/core/compact_object.cc | 17 +- src/core/compact_object.h | 4 +- src/core/compact_object_test.cc | 94 ++- src/core/op_status.h | 2 +- src/redis/CMakeLists.txt | 2 +- src/redis/object.c | 1 - src/redis/object.h | 46 +- src/redis/redis_aux.c | 14 + src/redis/redis_aux.h | 4 +- src/redis/t_hash.c | 1171 +++++++++++++++++++++++++++++++ src/redis/t_zset.c | 24 +- src/redis/zset.h | 3 - 12 files changed, 1329 insertions(+), 53 deletions(-) create mode 100644 src/redis/t_hash.c diff --git a/src/core/compact_object.cc b/src/core/compact_object.cc index ac2fa61..c2eecbe 100644 --- a/src/core/compact_object.cc +++ b/src/core/compact_object.cc @@ -9,6 +9,7 @@ extern "C" { #include "redis/intset.h" +#include "redis/listpack.h" #include "redis/object.h" #include "redis/util.h" #include "redis/zmalloc.h" // for non-string objects. @@ -242,7 +243,16 @@ void RobjWrapper::Free(std::pmr::memory_resource* mr) { LOG(FATAL) << "TBD"; break; case OBJ_HASH: - LOG(FATAL) << "Unsupported HASH type"; + switch (encoding) { + case OBJ_ENCODING_HT: + dictRelease((dict*)ptr); + break; + case OBJ_ENCODING_LISTPACK: + lpFree((uint8_t*)ptr); + break; + default: + LOG(FATAL) << "Unknown hset encoding type"; + } break; case OBJ_MODULE: LOG(FATAL) << "Unsupported OBJ_MODULE type"; @@ -520,11 +530,10 @@ robj* CompactObj::AsRObj() const { } void CompactObj::SyncRObj() { - CHECK_EQ(ROBJ_TAG, taglen_); - robj* obj = &tl.tmp_robj; - CHECK_EQ(u_.r_obj.type, obj->type); + DCHECK_EQ(ROBJ_TAG, taglen_); + DCHECK_EQ(u_.r_obj.type, obj->type); u_.r_obj.encoding = obj->encoding; u_.r_obj.blob.Set(obj->ptr, 0); diff --git a/src/core/compact_object.h b/src/core/compact_object.h index 6696c10..b11962b 100644 --- a/src/core/compact_object.h +++ b/src/core/compact_object.h @@ -204,6 +204,7 @@ class CompactObj { unsigned Encoding() const; unsigned ObjType() const; + quicklist* GetQL() const; // Takes ownership over o. @@ -215,11 +216,12 @@ class CompactObj { // Requires: AsRObj() has been called before in the same thread in fiber-atomic section. void SyncRObj(); + // For STR object. void SetInt(int64_t val); std::optional TryGetInt() const; + // For STR object. void SetString(std::string_view str); - void GetString(std::string* res) const; size_t MallocUsed() const; diff --git a/src/core/compact_object_test.cc b/src/core/compact_object_test.cc index 84393b4..9abcc92 100644 --- a/src/core/compact_object_test.cc +++ b/src/core/compact_object_test.cc @@ -3,12 +3,15 @@ // #include "core/compact_object.h" +#include #include #include "base/gtest.h" +#include "base/logging.h" extern "C" { #include "redis/object.h" +#include "redis/redis_aux.h" #include "redis/zmalloc.h" } @@ -25,18 +28,35 @@ void PrintTo(const CompactObj& cobj, std::ostream* os) { class CompactObjectTest : public ::testing::Test { protected: - static void SetUpTestCase() { + static void SetUpTestSuite() { + InitRedisTables(); // to initialize server struct. + init_zmalloc_threadlocal(); CompactObj::InitThreadLocal(pmr::get_default_resource()); } - CompactObj cs_; + static void TearDownTestSuite() { + mi_heap_collect(mi_heap_get_backing(), true); + + auto cb_visit = [](const mi_heap_t* heap, const mi_heap_area_t* area, void* block, + size_t block_size, void* arg) { + LOG(ERROR) << "Unfreed allocations: block_size " << block_size + << ", allocated: " << area->used * block_size; + return true; + }; + + mi_heap_visit_blocks(mi_heap_get_backing(), false /* do not visit all blocks*/, cb_visit, + nullptr); + } + + CompactObj cobj_; string tmp_; }; TEST_F(CompactObjectTest, Basic) { robj* rv = createRawStringObject("foo", 3); - cs_.ImportRObj(rv); + cobj_.ImportRObj(rv); + return; CompactObj a; @@ -67,22 +87,21 @@ TEST_F(CompactObjectTest, NonInline) { } TEST_F(CompactObjectTest, Int) { - cs_.SetString("0"); - EXPECT_EQ(0, cs_.TryGetInt()); - EXPECT_EQ(cs_, "0"); - EXPECT_EQ("0", cs_.GetSlice(&tmp_)); - EXPECT_EQ(OBJ_STRING, cs_.ObjType()); - cs_.SetString("42"); - EXPECT_EQ(8181779779123079347, cs_.HashCode()); - EXPECT_EQ(OBJ_ENCODING_INT, cs_.Encoding()); + cobj_.SetString("0"); + EXPECT_EQ(0, cobj_.TryGetInt()); + EXPECT_EQ(cobj_, "0"); + EXPECT_EQ("0", cobj_.GetSlice(&tmp_)); + EXPECT_EQ(OBJ_STRING, cobj_.ObjType()); + cobj_.SetString("42"); + EXPECT_EQ(8181779779123079347, cobj_.HashCode()); + EXPECT_EQ(OBJ_ENCODING_INT, cobj_.Encoding()); } TEST_F(CompactObjectTest, MediumString) { - CompactObj obj; string tmp(512, 'b'); - obj.SetString(tmp); - obj.SetString(tmp); - obj.Reset(); + cobj_.SetString(tmp); + cobj_.SetString(tmp); + cobj_.Reset(); } TEST_F(CompactObjectTest, AsciiUtil) { @@ -98,4 +117,49 @@ TEST_F(CompactObjectTest, AsciiUtil) { ASSERT_EQ(data.substr(0, 7), actual); } +TEST_F(CompactObjectTest, IntSet) { + robj* src = createIntsetObject(); + cobj_.ImportRObj(src); + EXPECT_EQ(OBJ_SET, cobj_.ObjType()); + EXPECT_EQ(OBJ_ENCODING_INTSET, cobj_.Encoding()); + + robj* os = cobj_.AsRObj(); + EXPECT_EQ(0, setTypeSize(os)); + sds val1 = sdsnew("10"); + sds val2 = sdsdup(val1); + + EXPECT_EQ(1, setTypeAdd(os, val1)); + EXPECT_EQ(0, setTypeAdd(os, val2)); + EXPECT_EQ(OBJ_ENCODING_INTSET, os->encoding); + sdsfree(val1); + sdsfree(val2); + cobj_.SyncRObj(); + EXPECT_GT(cobj_.MallocUsed(), 0); +} + +TEST_F(CompactObjectTest, HSet) { + robj* src = createHashObject(); + cobj_.ImportRObj(src); + + EXPECT_EQ(OBJ_HASH, cobj_.ObjType()); + EXPECT_EQ(OBJ_ENCODING_LISTPACK, cobj_.Encoding()); + + robj* os = cobj_.AsRObj(); + + sds key1 = sdsnew("key1"); + sds val1 = sdsnew("val1"); + + + // returns 0 on insert. + EXPECT_EQ(0, hashTypeSet(os, key1, val1, HASH_SET_TAKE_FIELD | HASH_SET_TAKE_VALUE)); + cobj_.SyncRObj(); +} + +TEST_F(CompactObjectTest, ZSet) { + // unrelated, checking sds static encoding used in zset special strings. + char kMinStrData[] = "\110" "minstring"; + EXPECT_EQ(9, sdslen(kMinStrData + 1)); + +} + } // namespace dfly diff --git a/src/core/op_status.h b/src/core/op_status.h index 6a28be5..7d4a1ea 100644 --- a/src/core/op_status.h +++ b/src/core/op_status.h @@ -72,7 +72,7 @@ template class OpResult : public OpResultBase { } private: - V v_; + V v_{}; }; template <> class OpResult : public OpResultBase { diff --git a/src/redis/CMakeLists.txt b/src/redis/CMakeLists.txt index 59e266b..e71c1c1 100644 --- a/src/redis/CMakeLists.txt +++ b/src/redis/CMakeLists.txt @@ -10,7 +10,7 @@ endif() add_library(redis_lib crc64.c crcspeed.c debug.c dict.c endianconv.c intset.c listpack.c mt19937-64.c object.c lzf_c.c lzf_d.c sds.c sha256.c - quicklist.c redis_aux.c siphash.c t_set.c t_zset.c util.c ${ZMALLOC_SRC}) + quicklist.c redis_aux.c siphash.c t_hash.c t_set.c t_zset.c util.c ${ZMALLOC_SRC}) cxx_link(redis_lib ${ZMALLOC_DEPS}) diff --git a/src/redis/object.c b/src/redis/object.c index 15dfeb8..c4f5322 100644 --- a/src/redis/object.c +++ b/src/redis/object.c @@ -58,7 +58,6 @@ * too strict. */ static void dismissMemory(void* ptr, size_t size_hint) { } -struct sharedObjectsStruct shared; /* ===================== Creation and parsing of objects ==================== */ diff --git a/src/redis/object.h b/src/redis/object.h index 9536eb1..7ee59c8 100644 --- a/src/redis/object.h +++ b/src/redis/object.h @@ -121,7 +121,6 @@ static inline int sdsEncodedObject(const robj *o) { return o->encoding == OBJ_ENCODING_RAW || o->encoding == OBJ_ENCODING_EMBSTR; } - /* Structure to hold set iteration abstraction. */ typedef struct { robj *subject; @@ -164,6 +163,35 @@ int setTypeNext(setTypeIterator *si, sds *sdsele, int64_t *llele); sds setTypeNextObject(setTypeIterator *si); + +/* hash set interface */ + + +/* Hash data type */ +#define HASH_SET_TAKE_FIELD (1<<0) +#define HASH_SET_TAKE_VALUE (1<<1) +#define HASH_SET_COPY 0 + +void hashTypeConvert(robj *o, int enc); +void hashTypeTryConversion(robj *subject, robj **argv, int start, int end); +int hashTypeExists(robj *o, sds key); +int hashTypeDelete(robj *o, sds key); +unsigned long hashTypeLength(const robj *o); +hashTypeIterator *hashTypeInitIterator(robj *subject); +void hashTypeReleaseIterator(hashTypeIterator *hi); +int hashTypeNext(hashTypeIterator *hi); +void hashTypeCurrentFromListpack(hashTypeIterator *hi, int what, + unsigned char **vstr, + unsigned int *vlen, + long long *vll); +sds hashTypeCurrentFromHashTable(hashTypeIterator *hi, int what); +void hashTypeCurrentObject(hashTypeIterator *hi, int what, unsigned char **vstr, unsigned int *vlen, long long *vll); +sds hashTypeCurrentObjectNewSds(hashTypeIterator *hi, int what); +robj *hashTypeGetValueObject(robj *o, sds field); +int hashTypeSet(robj *o, sds field, sds value, int flags); +robj *hashTypeDup(robj *o); + + /* Macro used to initialize a Redis object allocated on the stack. * Note that this macro is taken near the structure definition to make sure * we'll update it when the structure is changed, to avoid bugs like @@ -181,21 +209,5 @@ sds setTypeNextObject(setTypeIterator *si); #define PROTO_SHARED_SELECT_CMDS 10 #define OBJ_SHARED_BULKHDR_LEN 32 -struct sharedObjectsStruct { - robj *crlf, *ok, *err, *emptybulk, *czero, *cone, *pong, *space, - *colon, *queued, *null[4], *nullarray[4], *emptymap[4], *emptyset[4], - *emptyarray, *wrongtypeerr, *nokeyerr, *syntaxerr, *sameobjecterr, - *outofrangeerr, *noscripterr, *loadingerr, *slowscripterr, *bgsaveerr, - *masterdownerr, *roslaveerr, *execaborterr, *noautherr, *noreplicaserr, - *busykeyerr, *oomerr, *plus, *messagebulk, *pmessagebulk, *subscribebulk, - *unsubscribebulk, *psubscribebulk, *punsubscribebulk, *del, *unlink, - *rpop, *lpop, *lpush, *rpoplpush, *lmove, *blmove, *zpopmin, *zpopmax, - *emptyscan, *multi, *exec, *left, *right; - sds minstring, maxstring; -}; - -extern struct sharedObjectsStruct shared; - -void initSharedStruct(); #endif diff --git a/src/redis/redis_aux.c b/src/redis/redis_aux.c index 66e4cb9..e0cc53d 100644 --- a/src/redis/redis_aux.c +++ b/src/redis/redis_aux.c @@ -14,6 +14,9 @@ void InitRedisTables() { server.page_size = sysconf(_SC_PAGESIZE); server.zset_max_listpack_entries = 128; server.zset_max_listpack_value = 64; + server.set_max_intset_entries = 512; + server.hash_max_listpack_entries = 512; + server.hash_max_listpack_value = 64; } // These functions are moved here from server.c @@ -100,3 +103,14 @@ dictType zsetDictType = { NULL, /* val destructor */ NULL /* allow to expand */ }; + +/* Hash type hash table (note that small hashes are represented with listpacks) */ +dictType hashDictType = { + dictSdsHash, /* hash function */ + NULL, /* key dup */ + NULL, /* val dup */ + dictSdsKeyCompare, /* key compare */ + dictSdsDestructor, /* key destructor */ + dictSdsDestructor, /* val destructor */ + NULL /* allow to expand */ +}; diff --git a/src/redis/redis_aux.h b/src/redis/redis_aux.h index 761485b..48a603f 100644 --- a/src/redis/redis_aux.h +++ b/src/redis/redis_aux.h @@ -89,8 +89,8 @@ typedef struct ServerStub { int rdb_save_incremental_fsync; size_t stat_peak_memory; - size_t set_max_intset_entries, hash_max_ziplist_entries, - hash_max_ziplist_value; + size_t set_max_intset_entries, hash_max_listpack_entries, + hash_max_listpack_value; size_t zset_max_listpack_entries; size_t zset_max_listpack_value; int sanitize_dump_payload; /* Enables deep sanitization for ziplist and listpack in RDB and RESTORE. */ diff --git a/src/redis/t_hash.c b/src/redis/t_hash.c new file mode 100644 index 0000000..b24936f --- /dev/null +++ b/src/redis/t_hash.c @@ -0,0 +1,1171 @@ +/* + * Copyright (c) 2009-2012, Salvatore Sanfilippo + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of Redis nor the names of its contributors may be used + * to endorse or promote products derived from this software without + * specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include + +// ROMAN - taken from redis 7.0 branch. + +#include "listpack.h" +#include "object.h" +#include "redis_aux.h" +#include "util.h" +#include "zmalloc.h" + +/*----------------------------------------------------------------------------- + * Hash type API + *----------------------------------------------------------------------------*/ + +/* Check the length of a number of objects to see if we need to convert a + * listpack to a real hash. Note that we only check string encoded objects + * as their string length can be queried in constant time. */ +void hashTypeTryConversion(robj *o, robj **argv, int start, int end) { + int i; + size_t sum = 0; + + if (o->encoding != OBJ_ENCODING_LISTPACK) return; + + for (i = start; i <= end; i++) { + if (!sdsEncodedObject(argv[i])) + continue; + size_t len = sdslen(argv[i]->ptr); + if (len > server.hash_max_listpack_value) { + hashTypeConvert(o, OBJ_ENCODING_HT); + return; + } + sum += len; + } + if (!lpSafeToAdd(o->ptr, sum)) + hashTypeConvert(o, OBJ_ENCODING_HT); +} + +/* Get the value from a listpack encoded hash, identified by field. + * Returns -1 when the field cannot be found. */ +int hashTypeGetFromListpack(robj *o, sds field, + unsigned char **vstr, + unsigned int *vlen, + long long *vll) +{ + unsigned char *zl, *fptr = NULL, *vptr = NULL; + + serverAssert(o->encoding == OBJ_ENCODING_LISTPACK); + + zl = o->ptr; + fptr = lpFirst(zl); + if (fptr != NULL) { + fptr = lpFind(zl, fptr, (unsigned char*)field, sdslen(field), 1); + if (fptr != NULL) { + /* Grab pointer to the value (fptr points to the field) */ + vptr = lpNext(zl, fptr); + serverAssert(vptr != NULL); + } + } + + if (vptr != NULL) { + *vstr = lpGetValue(vptr, vlen, vll); + return 0; + } + + return -1; +} + +/* Get the value from a hash table encoded hash, identified by field. + * Returns NULL when the field cannot be found, otherwise the SDS value + * is returned. */ +sds hashTypeGetFromHashTable(robj *o, sds field) { + dictEntry *de; + + serverAssert(o->encoding == OBJ_ENCODING_HT); + + de = dictFind(o->ptr, field); + if (de == NULL) return NULL; + return dictGetVal(de); +} + +/* Higher level function of hashTypeGet*() that returns the hash value + * associated with the specified field. If the field is found C_OK + * is returned, otherwise C_ERR. The returned object is returned by + * reference in either *vstr and *vlen if it's returned in string form, + * or stored in *vll if it's returned as a number. + * + * If *vll is populated *vstr is set to NULL, so the caller + * can always check the function return by checking the return value + * for C_OK and checking if vll (or vstr) is NULL. */ +int hashTypeGetValue(robj *o, sds field, unsigned char **vstr, unsigned int *vlen, long long *vll) { + if (o->encoding == OBJ_ENCODING_LISTPACK) { + *vstr = NULL; + if (hashTypeGetFromListpack(o, field, vstr, vlen, vll) == 0) + return C_OK; + } else if (o->encoding == OBJ_ENCODING_HT) { + sds value; + if ((value = hashTypeGetFromHashTable(o, field)) != NULL) { + *vstr = (unsigned char*) value; + *vlen = sdslen(value); + return C_OK; + } + } else { + serverPanic("Unknown hash encoding"); + } + return C_ERR; +} + +/* Like hashTypeGetValue() but returns a Redis object, which is useful for + * interaction with the hash type outside t_hash.c. + * The function returns NULL if the field is not found in the hash. Otherwise + * a newly allocated string object with the value is returned. */ +robj *hashTypeGetValueObject(robj *o, sds field) { + unsigned char *vstr; + unsigned int vlen; + long long vll; + + if (hashTypeGetValue(o,field,&vstr,&vlen,&vll) == C_ERR) return NULL; + if (vstr) return createStringObject((char*)vstr,vlen); + else return createStringObjectFromLongLong(vll); +} + +/* Higher level function using hashTypeGet*() to return the length of the + * object associated with the requested field, or 0 if the field does not + * exist. */ +size_t hashTypeGetValueLength(robj *o, sds field) { + size_t len = 0; + if (o->encoding == OBJ_ENCODING_LISTPACK) { + unsigned char *vstr = NULL; + unsigned int vlen = UINT_MAX; + long long vll = LLONG_MAX; + + if (hashTypeGetFromListpack(o, field, &vstr, &vlen, &vll) == 0) + len = vstr ? vlen : sdigits10(vll); + } else if (o->encoding == OBJ_ENCODING_HT) { + sds aux; + + if ((aux = hashTypeGetFromHashTable(o, field)) != NULL) + len = sdslen(aux); + } else { + serverPanic("Unknown hash encoding"); + } + return len; +} + +/* Test if the specified field exists in the given hash. Returns 1 if the field + * exists, and 0 when it doesn't. */ +int hashTypeExists(robj *o, sds field) { + if (o->encoding == OBJ_ENCODING_LISTPACK) { + unsigned char *vstr = NULL; + unsigned int vlen = UINT_MAX; + long long vll = LLONG_MAX; + + if (hashTypeGetFromListpack(o, field, &vstr, &vlen, &vll) == 0) return 1; + } else if (o->encoding == OBJ_ENCODING_HT) { + if (hashTypeGetFromHashTable(o, field) != NULL) return 1; + } else { + serverPanic("Unknown hash encoding"); + } + return 0; +} + +/* Add a new field, overwrite the old with the new value if it already exists. + * Return 0 on insert and 1 on update. + * + * By default, the key and value SDS strings are copied if needed, so the + * caller retains ownership of the strings passed. However this behavior + * can be effected by passing appropriate flags (possibly bitwise OR-ed): + * + * HASH_SET_TAKE_FIELD -- The SDS field ownership passes to the function. + * HASH_SET_TAKE_VALUE -- The SDS value ownership passes to the function. + * + * When the flags are used the caller does not need to release the passed + * SDS string(s). It's up to the function to use the string to create a new + * entry or to free the SDS string before returning to the caller. + * + * HASH_SET_COPY corresponds to no flags passed, and means the default + * semantics of copying the values if needed. + * + */ +#define HASH_SET_TAKE_FIELD (1<<0) +#define HASH_SET_TAKE_VALUE (1<<1) +#define HASH_SET_COPY 0 +int hashTypeSet(robj *o, sds field, sds value, int flags) { + int update = 0; + + if (o->encoding == OBJ_ENCODING_LISTPACK) { + unsigned char *zl, *fptr, *vptr; + + zl = o->ptr; + fptr = lpFirst(zl); + if (fptr != NULL) { + fptr = lpFind(zl, fptr, (unsigned char*)field, sdslen(field), 1); + if (fptr != NULL) { + /* Grab pointer to the value (fptr points to the field) */ + vptr = lpNext(zl, fptr); + serverAssert(vptr != NULL); + update = 1; + + /* Replace value */ + zl = lpReplace(zl, &vptr, (unsigned char*)value, sdslen(value)); + } + } + + if (!update) { + /* Push new field/value pair onto the tail of the listpack */ + zl = lpAppend(zl, (unsigned char*)field, sdslen(field)); + zl = lpAppend(zl, (unsigned char*)value, sdslen(value)); + } + o->ptr = zl; + + /* Check if the listpack needs to be converted to a hash table */ + if (hashTypeLength(o) > server.hash_max_listpack_entries) + hashTypeConvert(o, OBJ_ENCODING_HT); + } else if (o->encoding == OBJ_ENCODING_HT) { + dictEntry *de = dictFind(o->ptr,field); + if (de) { + sdsfree(dictGetVal(de)); + if (flags & HASH_SET_TAKE_VALUE) { + dictGetVal(de) = value; + value = NULL; + } else { + dictGetVal(de) = sdsdup(value); + } + update = 1; + } else { + sds f,v; + if (flags & HASH_SET_TAKE_FIELD) { + f = field; + field = NULL; + } else { + f = sdsdup(field); + } + if (flags & HASH_SET_TAKE_VALUE) { + v = value; + value = NULL; + } else { + v = sdsdup(value); + } + dictAdd(o->ptr,f,v); + } + } else { + serverPanic("Unknown hash encoding"); + } + + /* Free SDS strings we did not referenced elsewhere if the flags + * want this function to be responsible. */ + if (flags & HASH_SET_TAKE_FIELD && field) sdsfree(field); + if (flags & HASH_SET_TAKE_VALUE && value) sdsfree(value); + return update; +} + +/* Delete an element from a hash. + * Return 1 on deleted and 0 on not found. */ +int hashTypeDelete(robj *o, sds field) { + int deleted = 0; + + if (o->encoding == OBJ_ENCODING_LISTPACK) { + unsigned char *zl, *fptr; + + zl = o->ptr; + fptr = lpFirst(zl); + if (fptr != NULL) { + fptr = lpFind(zl, fptr, (unsigned char*)field, sdslen(field), 1); + if (fptr != NULL) { + /* Delete both of the key and the value. */ + zl = lpDeleteRangeWithEntry(zl,&fptr,2); + o->ptr = zl; + deleted = 1; + } + } + } else if (o->encoding == OBJ_ENCODING_HT) { + if (dictDelete((dict*)o->ptr, field) == C_OK) { + deleted = 1; + + /* Always check if the dictionary needs a resize after a delete. */ + if (htNeedsResize(o->ptr)) dictResize(o->ptr); + } + + } else { + serverPanic("Unknown hash encoding"); + } + return deleted; +} + +/* Return the number of elements in a hash. */ +unsigned long hashTypeLength(const robj *o) { + unsigned long length = ULONG_MAX; + + if (o->encoding == OBJ_ENCODING_LISTPACK) { + length = lpLength(o->ptr) / 2; + } else if (o->encoding == OBJ_ENCODING_HT) { + length = dictSize((const dict*)o->ptr); + } else { + serverPanic("Unknown hash encoding"); + } + return length; +} + +hashTypeIterator *hashTypeInitIterator(robj *subject) { + hashTypeIterator *hi = zmalloc(sizeof(hashTypeIterator)); + hi->subject = subject; + hi->encoding = subject->encoding; + + if (hi->encoding == OBJ_ENCODING_LISTPACK) { + hi->fptr = NULL; + hi->vptr = NULL; + } else if (hi->encoding == OBJ_ENCODING_HT) { + hi->di = dictGetIterator(subject->ptr); + } else { + serverPanic("Unknown hash encoding"); + } + return hi; +} + +void hashTypeReleaseIterator(hashTypeIterator *hi) { + if (hi->encoding == OBJ_ENCODING_HT) + dictReleaseIterator(hi->di); + zfree(hi); +} + +/* Move to the next entry in the hash. Return C_OK when the next entry + * could be found and C_ERR when the iterator reaches the end. */ +int hashTypeNext(hashTypeIterator *hi) { + if (hi->encoding == OBJ_ENCODING_LISTPACK) { + unsigned char *zl; + unsigned char *fptr, *vptr; + + zl = hi->subject->ptr; + fptr = hi->fptr; + vptr = hi->vptr; + + if (fptr == NULL) { + /* Initialize cursor */ + serverAssert(vptr == NULL); + fptr = lpFirst(zl); + } else { + /* Advance cursor */ + serverAssert(vptr != NULL); + fptr = lpNext(zl, vptr); + } + if (fptr == NULL) return C_ERR; + + /* Grab pointer to the value (fptr points to the field) */ + vptr = lpNext(zl, fptr); + serverAssert(vptr != NULL); + + /* fptr, vptr now point to the first or next pair */ + hi->fptr = fptr; + hi->vptr = vptr; + } else if (hi->encoding == OBJ_ENCODING_HT) { + if ((hi->de = dictNext(hi->di)) == NULL) return C_ERR; + } else { + serverPanic("Unknown hash encoding"); + } + return C_OK; +} + +/* Get the field or value at iterator cursor, for an iterator on a hash value + * encoded as a listpack. Prototype is similar to `hashTypeGetFromListpack`. */ +void hashTypeCurrentFromListpack(hashTypeIterator *hi, int what, + unsigned char **vstr, + unsigned int *vlen, + long long *vll) +{ + serverAssert(hi->encoding == OBJ_ENCODING_LISTPACK); + + if (what & OBJ_HASH_KEY) { + *vstr = lpGetValue(hi->fptr, vlen, vll); + } else { + *vstr = lpGetValue(hi->vptr, vlen, vll); + } +} + +/* Get the field or value at iterator cursor, for an iterator on a hash value + * encoded as a hash table. Prototype is similar to + * `hashTypeGetFromHashTable`. */ +sds hashTypeCurrentFromHashTable(hashTypeIterator *hi, int what) { + serverAssert(hi->encoding == OBJ_ENCODING_HT); + + if (what & OBJ_HASH_KEY) { + return dictGetKey(hi->de); + } else { + return dictGetVal(hi->de); + } +} + +/* Higher level function of hashTypeCurrent*() that returns the hash value + * at current iterator position. + * + * The returned element is returned by reference in either *vstr and *vlen if + * it's returned in string form, or stored in *vll if it's returned as + * a number. + * + * If *vll is populated *vstr is set to NULL, so the caller + * can always check the function return by checking the return value + * type checking if vstr == NULL. */ +void hashTypeCurrentObject(hashTypeIterator *hi, int what, unsigned char **vstr, unsigned int *vlen, long long *vll) { + if (hi->encoding == OBJ_ENCODING_LISTPACK) { + *vstr = NULL; + hashTypeCurrentFromListpack(hi, what, vstr, vlen, vll); + } else if (hi->encoding == OBJ_ENCODING_HT) { + sds ele = hashTypeCurrentFromHashTable(hi, what); + *vstr = (unsigned char*) ele; + *vlen = sdslen(ele); + } else { + serverPanic("Unknown hash encoding"); + } +} + +/* Return the key or value at the current iterator position as a new + * SDS string. */ +sds hashTypeCurrentObjectNewSds(hashTypeIterator *hi, int what) { + unsigned char *vstr; + unsigned int vlen; + long long vll; + + hashTypeCurrentObject(hi,what,&vstr,&vlen,&vll); + if (vstr) return sdsnewlen(vstr,vlen); + return sdsfromlonglong(vll); +} + +#ifdef ROMAN_CLIENT_DISABLE +robj *hashTypeLookupWriteOrCreate(client *c, robj *key) { + robj *o = lookupKeyWrite(c->db,key); + if (checkType(c,o,OBJ_HASH)) return NULL; + + if (o == NULL) { + o = createHashObject(); + dbAdd(c->db,key,o); + } + return o; +} +#endif + +void hashTypeConvertListpack(robj *o, int enc) { + serverAssert(o->encoding == OBJ_ENCODING_LISTPACK); + + if (enc == OBJ_ENCODING_LISTPACK) { + /* Nothing to do... */ + + } else if (enc == OBJ_ENCODING_HT) { + hashTypeIterator *hi; + dict *dict; + int ret; + + hi = hashTypeInitIterator(o); + dict = dictCreate(&hashDictType); + + /* Presize the dict to avoid rehashing */ + dictExpand(dict,hashTypeLength(o)); + + while (hashTypeNext(hi) != C_ERR) { + sds key, value; + + key = hashTypeCurrentObjectNewSds(hi,OBJ_HASH_KEY); + value = hashTypeCurrentObjectNewSds(hi,OBJ_HASH_VALUE); + ret = dictAdd(dict, key, value); + if (ret != DICT_OK) { + sdsfree(key); sdsfree(value); /* Needed for gcc ASAN */ + hashTypeReleaseIterator(hi); /* Needed for gcc ASAN */ + serverLogHexDump(LL_WARNING,"listpack with dup elements dump", + o->ptr,lpBytes(o->ptr)); + serverPanic("Listpack corruption detected"); + } + } + hashTypeReleaseIterator(hi); + zfree(o->ptr); + o->encoding = OBJ_ENCODING_HT; + o->ptr = dict; + } else { + serverPanic("Unknown hash encoding"); + } +} + +void hashTypeConvert(robj *o, int enc) { + if (o->encoding == OBJ_ENCODING_LISTPACK) { + hashTypeConvertListpack(o, enc); + } else if (o->encoding == OBJ_ENCODING_HT) { + serverPanic("Not implemented"); + } else { + serverPanic("Unknown hash encoding"); + } +} + +/* This is a helper function for the COPY command. + * Duplicate a hash object, with the guarantee that the returned object + * has the same encoding as the original one. + * + * The resulting object always has refcount set to 1 */ +robj *hashTypeDup(robj *o) { + robj *hobj; + hashTypeIterator *hi; + + serverAssert(o->type == OBJ_HASH); + + if(o->encoding == OBJ_ENCODING_LISTPACK) { + unsigned char *zl = o->ptr; + size_t sz = lpBytes(zl); + unsigned char *new_zl = zmalloc(sz); + memcpy(new_zl, zl, sz); + hobj = createObject(OBJ_HASH, new_zl); + hobj->encoding = OBJ_ENCODING_LISTPACK; + } else if(o->encoding == OBJ_ENCODING_HT){ + dict *d = dictCreate(&hashDictType); + dictExpand(d, dictSize((const dict*)o->ptr)); + + hi = hashTypeInitIterator(o); + while (hashTypeNext(hi) != C_ERR) { + sds field, value; + sds newfield, newvalue; + /* Extract a field-value pair from an original hash object.*/ + field = hashTypeCurrentFromHashTable(hi, OBJ_HASH_KEY); + value = hashTypeCurrentFromHashTable(hi, OBJ_HASH_VALUE); + newfield = sdsdup(field); + newvalue = sdsdup(value); + + /* Add a field-value pair to a new hash object. */ + dictAdd(d,newfield,newvalue); + } + hashTypeReleaseIterator(hi); + + hobj = createObject(OBJ_HASH, d); + hobj->encoding = OBJ_ENCODING_HT; + } else { + serverPanic("Unknown hash encoding"); + } + return hobj; +} + +/* Create a new sds string from the listpack entry. */ +sds hashSdsFromListpackEntry(listpackEntry *e) { + return e->sval ? sdsnewlen(e->sval, e->slen) : sdsfromlonglong(e->lval); +} + +#ifdef ROMAN_CLIENT_DISABLE +/* Reply with bulk string from the listpack entry. */ +void hashReplyFromListpackEntry(client *c, listpackEntry *e) { + if (e->sval) + addReplyBulkCBuffer(c, e->sval, e->slen); + else + addReplyBulkLongLong(c, e->lval); +} + +#endif + +/* Return random element from a non empty hash. + * 'key' and 'val' will be set to hold the element. + * The memory in them is not to be freed or modified by the caller. + * 'val' can be NULL in which case it's not extracted. */ +void hashTypeRandomElement(robj *hashobj, unsigned long hashsize, listpackEntry *key, listpackEntry *val) { + if (hashobj->encoding == OBJ_ENCODING_HT) { + dictEntry *de = dictGetFairRandomKey(hashobj->ptr); + sds s = dictGetKey(de); + key->sval = (unsigned char*)s; + key->slen = sdslen(s); + if (val) { + sds s = dictGetVal(de); + val->sval = (unsigned char*)s; + val->slen = sdslen(s); + } + } else if (hashobj->encoding == OBJ_ENCODING_LISTPACK) { + lpRandomPair(hashobj->ptr, hashsize, key, val); + } else { + serverPanic("Unknown hash encoding"); + } +} + +#ifdef ROMAN_CLIENT_DISABLE +/*----------------------------------------------------------------------------- + * Hash type commands + *----------------------------------------------------------------------------*/ + +void hsetnxCommand(client *c) { + robj *o; + if ((o = hashTypeLookupWriteOrCreate(c,c->argv[1])) == NULL) return; + + if (hashTypeExists(o, c->argv[2]->ptr)) { + addReply(c, shared.czero); + } else { + hashTypeTryConversion(o,c->argv,2,3); + hashTypeSet(o,c->argv[2]->ptr,c->argv[3]->ptr,HASH_SET_COPY); + addReply(c, shared.cone); + signalModifiedKey(c,c->db,c->argv[1]); + notifyKeyspaceEvent(NOTIFY_HASH,"hset",c->argv[1],c->db->id); + server.dirty++; + } +} + +void hsetCommand(client *c) { + int i, created = 0; + robj *o; + + if ((c->argc % 2) == 1) { + addReplyErrorArity(c); + return; + } + + if ((o = hashTypeLookupWriteOrCreate(c,c->argv[1])) == NULL) return; + hashTypeTryConversion(o,c->argv,2,c->argc-1); + + for (i = 2; i < c->argc; i += 2) + created += !hashTypeSet(o,c->argv[i]->ptr,c->argv[i+1]->ptr,HASH_SET_COPY); + + /* HMSET (deprecated) and HSET return value is different. */ + char *cmdname = c->argv[0]->ptr; + if (cmdname[1] == 's' || cmdname[1] == 'S') { + /* HSET */ + addReplyLongLong(c, created); + } else { + /* HMSET */ + addReply(c, shared.ok); + } + signalModifiedKey(c,c->db,c->argv[1]); + notifyKeyspaceEvent(NOTIFY_HASH,"hset",c->argv[1],c->db->id); + server.dirty += (c->argc - 2)/2; +} + +void hincrbyCommand(client *c) { + long long value, incr, oldvalue; + robj *o; + sds new; + unsigned char *vstr; + unsigned int vlen; + + if (getLongLongFromObjectOrReply(c,c->argv[3],&incr,NULL) != C_OK) return; + if ((o = hashTypeLookupWriteOrCreate(c,c->argv[1])) == NULL) return; + if (hashTypeGetValue(o,c->argv[2]->ptr,&vstr,&vlen,&value) == C_OK) { + if (vstr) { + if (string2ll((char*)vstr,vlen,&value) == 0) { + addReplyError(c,"hash value is not an integer"); + return; + } + } /* Else hashTypeGetValue() already stored it into &value */ + } else { + value = 0; + } + + oldvalue = value; + if ((incr < 0 && oldvalue < 0 && incr < (LLONG_MIN-oldvalue)) || + (incr > 0 && oldvalue > 0 && incr > (LLONG_MAX-oldvalue))) { + addReplyError(c,"increment or decrement would overflow"); + return; + } + value += incr; + new = sdsfromlonglong(value); + hashTypeSet(o,c->argv[2]->ptr,new,HASH_SET_TAKE_VALUE); + addReplyLongLong(c,value); + signalModifiedKey(c,c->db,c->argv[1]); + notifyKeyspaceEvent(NOTIFY_HASH,"hincrby",c->argv[1],c->db->id); + server.dirty++; +} + +void hincrbyfloatCommand(client *c) { + long double value, incr; + long long ll; + robj *o; + sds new; + unsigned char *vstr; + unsigned int vlen; + + if (getLongDoubleFromObjectOrReply(c,c->argv[3],&incr,NULL) != C_OK) return; + if ((o = hashTypeLookupWriteOrCreate(c,c->argv[1])) == NULL) return; + if (hashTypeGetValue(o,c->argv[2]->ptr,&vstr,&vlen,&ll) == C_OK) { + if (vstr) { + if (string2ld((char*)vstr,vlen,&value) == 0) { + addReplyError(c,"hash value is not a float"); + return; + } + } else { + value = (long double)ll; + } + } else { + value = 0; + } + + value += incr; + if (isnan(value) || isinf(value)) { + addReplyError(c,"increment would produce NaN or Infinity"); + return; + } + + char buf[MAX_LONG_DOUBLE_CHARS]; + int len = ld2string(buf,sizeof(buf),value,LD_STR_HUMAN); + new = sdsnewlen(buf,len); + hashTypeSet(o,c->argv[2]->ptr,new,HASH_SET_TAKE_VALUE); + addReplyBulkCBuffer(c,buf,len); + signalModifiedKey(c,c->db,c->argv[1]); + notifyKeyspaceEvent(NOTIFY_HASH,"hincrbyfloat",c->argv[1],c->db->id); + server.dirty++; + + /* Always replicate HINCRBYFLOAT as an HSET command with the final value + * in order to make sure that differences in float precision or formatting + * will not create differences in replicas or after an AOF restart. */ + robj *newobj; + newobj = createRawStringObject(buf,len); + rewriteClientCommandArgument(c,0,shared.hset); + rewriteClientCommandArgument(c,3,newobj); + decrRefCount(newobj); +} + +static void addHashFieldToReply(client *c, robj *o, sds field) { + int ret; + + if (o == NULL) { + addReplyNull(c); + return; + } + + if (o->encoding == OBJ_ENCODING_LISTPACK) { + unsigned char *vstr = NULL; + unsigned int vlen = UINT_MAX; + long long vll = LLONG_MAX; + + ret = hashTypeGetFromListpack(o, field, &vstr, &vlen, &vll); + if (ret < 0) { + addReplyNull(c); + } else { + if (vstr) { + addReplyBulkCBuffer(c, vstr, vlen); + } else { + addReplyBulkLongLong(c, vll); + } + } + + } else if (o->encoding == OBJ_ENCODING_HT) { + sds value = hashTypeGetFromHashTable(o, field); + if (value == NULL) + addReplyNull(c); + else + addReplyBulkCBuffer(c, value, sdslen(value)); + } else { + serverPanic("Unknown hash encoding"); + } +} + +void hgetCommand(client *c) { + robj *o; + + if ((o = lookupKeyReadOrReply(c,c->argv[1],shared.null[c->resp])) == NULL || + checkType(c,o,OBJ_HASH)) return; + + addHashFieldToReply(c, o, c->argv[2]->ptr); +} + +void hmgetCommand(client *c) { + robj *o; + int i; + + /* Don't abort when the key cannot be found. Non-existing keys are empty + * hashes, where HMGET should respond with a series of null bulks. */ + o = lookupKeyRead(c->db, c->argv[1]); + if (checkType(c,o,OBJ_HASH)) return; + + addReplyArrayLen(c, c->argc-2); + for (i = 2; i < c->argc; i++) { + addHashFieldToReply(c, o, c->argv[i]->ptr); + } +} + +void hdelCommand(client *c) { + robj *o; + int j, deleted = 0, keyremoved = 0; + + if ((o = lookupKeyWriteOrReply(c,c->argv[1],shared.czero)) == NULL || + checkType(c,o,OBJ_HASH)) return; + + for (j = 2; j < c->argc; j++) { + if (hashTypeDelete(o,c->argv[j]->ptr)) { + deleted++; + if (hashTypeLength(o) == 0) { + dbDelete(c->db,c->argv[1]); + keyremoved = 1; + break; + } + } + } + if (deleted) { + signalModifiedKey(c,c->db,c->argv[1]); + notifyKeyspaceEvent(NOTIFY_HASH,"hdel",c->argv[1],c->db->id); + if (keyremoved) + notifyKeyspaceEvent(NOTIFY_GENERIC,"del",c->argv[1], + c->db->id); + server.dirty += deleted; + } + addReplyLongLong(c,deleted); +} + +void hlenCommand(client *c) { + robj *o; + + if ((o = lookupKeyReadOrReply(c,c->argv[1],shared.czero)) == NULL || + checkType(c,o,OBJ_HASH)) return; + + addReplyLongLong(c,hashTypeLength(o)); +} + +void hstrlenCommand(client *c) { + robj *o; + + if ((o = lookupKeyReadOrReply(c,c->argv[1],shared.czero)) == NULL || + checkType(c,o,OBJ_HASH)) return; + addReplyLongLong(c,hashTypeGetValueLength(o,c->argv[2]->ptr)); +} + +static void addHashIteratorCursorToReply(client *c, hashTypeIterator *hi, int what) { + if (hi->encoding == OBJ_ENCODING_LISTPACK) { + unsigned char *vstr = NULL; + unsigned int vlen = UINT_MAX; + long long vll = LLONG_MAX; + + hashTypeCurrentFromListpack(hi, what, &vstr, &vlen, &vll); + if (vstr) + addReplyBulkCBuffer(c, vstr, vlen); + else + addReplyBulkLongLong(c, vll); + } else if (hi->encoding == OBJ_ENCODING_HT) { + sds value = hashTypeCurrentFromHashTable(hi, what); + addReplyBulkCBuffer(c, value, sdslen(value)); + } else { + serverPanic("Unknown hash encoding"); + } +} + +void genericHgetallCommand(client *c, int flags) { + robj *o; + hashTypeIterator *hi; + int length, count = 0; + + robj *emptyResp = (flags & OBJ_HASH_KEY && flags & OBJ_HASH_VALUE) ? + shared.emptymap[c->resp] : shared.emptyarray; + if ((o = lookupKeyReadOrReply(c,c->argv[1],emptyResp)) + == NULL || checkType(c,o,OBJ_HASH)) return; + + /* We return a map if the user requested keys and values, like in the + * HGETALL case. Otherwise to use a flat array makes more sense. */ + length = hashTypeLength(o); + if (flags & OBJ_HASH_KEY && flags & OBJ_HASH_VALUE) { + addReplyMapLen(c, length); + } else { + addReplyArrayLen(c, length); + } + + hi = hashTypeInitIterator(o); + while (hashTypeNext(hi) != C_ERR) { + if (flags & OBJ_HASH_KEY) { + addHashIteratorCursorToReply(c, hi, OBJ_HASH_KEY); + count++; + } + if (flags & OBJ_HASH_VALUE) { + addHashIteratorCursorToReply(c, hi, OBJ_HASH_VALUE); + count++; + } + } + + hashTypeReleaseIterator(hi); + + /* Make sure we returned the right number of elements. */ + if (flags & OBJ_HASH_KEY && flags & OBJ_HASH_VALUE) count /= 2; + serverAssert(count == length); +} + +void hkeysCommand(client *c) { + genericHgetallCommand(c,OBJ_HASH_KEY); +} + +void hvalsCommand(client *c) { + genericHgetallCommand(c,OBJ_HASH_VALUE); +} + +void hgetallCommand(client *c) { + genericHgetallCommand(c,OBJ_HASH_KEY|OBJ_HASH_VALUE); +} + +void hexistsCommand(client *c) { + robj *o; + if ((o = lookupKeyReadOrReply(c,c->argv[1],shared.czero)) == NULL || + checkType(c,o,OBJ_HASH)) return; + + addReply(c, hashTypeExists(o,c->argv[2]->ptr) ? shared.cone : shared.czero); +} + +void hscanCommand(client *c) { + robj *o; + unsigned long cursor; + + if (parseScanCursorOrReply(c,c->argv[2],&cursor) == C_ERR) return; + if ((o = lookupKeyReadOrReply(c,c->argv[1],shared.emptyscan)) == NULL || + checkType(c,o,OBJ_HASH)) return; + scanGenericCommand(c,o,cursor); +} + +static void harndfieldReplyWithListpack(client *c, unsigned int count, listpackEntry *keys, listpackEntry *vals) { + for (unsigned long i = 0; i < count; i++) { + if (vals && c->resp > 2) + addReplyArrayLen(c,2); + if (keys[i].sval) + addReplyBulkCBuffer(c, keys[i].sval, keys[i].slen); + else + addReplyBulkLongLong(c, keys[i].lval); + if (vals) { + if (vals[i].sval) + addReplyBulkCBuffer(c, vals[i].sval, vals[i].slen); + else + addReplyBulkLongLong(c, vals[i].lval); + } + } +} + +/* How many times bigger should be the hash compared to the requested size + * for us to not use the "remove elements" strategy? Read later in the + * implementation for more info. */ +#define HRANDFIELD_SUB_STRATEGY_MUL 3 + +/* If client is trying to ask for a very large number of random elements, + * queuing may consume an unlimited amount of memory, so we want to limit + * the number of randoms per time. */ +#define HRANDFIELD_RANDOM_SAMPLE_LIMIT 1000 + +void hrandfieldWithCountCommand(client *c, long l, int withvalues) { + unsigned long count, size; + int uniq = 1; + robj *hash; + + if ((hash = lookupKeyReadOrReply(c,c->argv[1],shared.emptyarray)) + == NULL || checkType(c,hash,OBJ_HASH)) return; + size = hashTypeLength(hash); + + if(l >= 0) { + count = (unsigned long) l; + } else { + count = -l; + uniq = 0; + } + + /* If count is zero, serve it ASAP to avoid special cases later. */ + if (count == 0) { + addReply(c,shared.emptyarray); + return; + } + + /* CASE 1: The count was negative, so the extraction method is just: + * "return N random elements" sampling the whole set every time. + * This case is trivial and can be served without auxiliary data + * structures. This case is the only one that also needs to return the + * elements in random order. */ + if (!uniq || count == 1) { + if (withvalues && c->resp == 2) + addReplyArrayLen(c, count*2); + else + addReplyArrayLen(c, count); + if (hash->encoding == OBJ_ENCODING_HT) { + sds key, value; + while (count--) { + dictEntry *de = dictGetFairRandomKey(hash->ptr); + key = dictGetKey(de); + value = dictGetVal(de); + if (withvalues && c->resp > 2) + addReplyArrayLen(c,2); + addReplyBulkCBuffer(c, key, sdslen(key)); + if (withvalues) + addReplyBulkCBuffer(c, value, sdslen(value)); + } + } else if (hash->encoding == OBJ_ENCODING_LISTPACK) { + listpackEntry *keys, *vals = NULL; + unsigned long limit, sample_count; + + limit = count > HRANDFIELD_RANDOM_SAMPLE_LIMIT ? HRANDFIELD_RANDOM_SAMPLE_LIMIT : count; + keys = zmalloc(sizeof(listpackEntry)*limit); + if (withvalues) + vals = zmalloc(sizeof(listpackEntry)*limit); + while (count) { + sample_count = count > limit ? limit : count; + count -= sample_count; + lpRandomPairs(hash->ptr, sample_count, keys, vals); + harndfieldReplyWithListpack(c, sample_count, keys, vals); + } + zfree(keys); + zfree(vals); + } + return; + } + + /* Initiate reply count, RESP3 responds with nested array, RESP2 with flat one. */ + long reply_size = count < size ? count : size; + if (withvalues && c->resp == 2) + addReplyArrayLen(c, reply_size*2); + else + addReplyArrayLen(c, reply_size); + + /* CASE 2: + * The number of requested elements is greater than the number of + * elements inside the hash: simply return the whole hash. */ + if(count >= size) { + hashTypeIterator *hi = hashTypeInitIterator(hash); + while (hashTypeNext(hi) != C_ERR) { + if (withvalues && c->resp > 2) + addReplyArrayLen(c,2); + addHashIteratorCursorToReply(c, hi, OBJ_HASH_KEY); + if (withvalues) + addHashIteratorCursorToReply(c, hi, OBJ_HASH_VALUE); + } + hashTypeReleaseIterator(hi); + return; + } + + /* CASE 3: + * The number of elements inside the hash is not greater than + * HRANDFIELD_SUB_STRATEGY_MUL times the number of requested elements. + * In this case we create a hash from scratch with all the elements, and + * subtract random elements to reach the requested number of elements. + * + * This is done because if the number of requested elements is just + * a bit less than the number of elements in the hash, the natural approach + * used into CASE 4 is highly inefficient. */ + if (count*HRANDFIELD_SUB_STRATEGY_MUL > size) { + dict *d = dictCreate(&sdsReplyDictType); + dictExpand(d, size); + hashTypeIterator *hi = hashTypeInitIterator(hash); + + /* Add all the elements into the temporary dictionary. */ + while ((hashTypeNext(hi)) != C_ERR) { + int ret = DICT_ERR; + sds key, value = NULL; + + key = hashTypeCurrentObjectNewSds(hi,OBJ_HASH_KEY); + if (withvalues) + value = hashTypeCurrentObjectNewSds(hi,OBJ_HASH_VALUE); + ret = dictAdd(d, key, value); + + serverAssert(ret == DICT_OK); + } + serverAssert(dictSize(d) == size); + hashTypeReleaseIterator(hi); + + /* Remove random elements to reach the right count. */ + while (size > count) { + dictEntry *de; + de = dictGetFairRandomKey(d); + dictUnlink(d,dictGetKey(de)); + sdsfree(dictGetKey(de)); + sdsfree(dictGetVal(de)); + dictFreeUnlinkedEntry(d,de); + size--; + } + + /* Reply with what's in the dict and release memory */ + dictIterator *di; + dictEntry *de; + di = dictGetIterator(d); + while ((de = dictNext(di)) != NULL) { + sds key = dictGetKey(de); + sds value = dictGetVal(de); + if (withvalues && c->resp > 2) + addReplyArrayLen(c,2); + addReplyBulkSds(c, key); + if (withvalues) + addReplyBulkSds(c, value); + } + + dictReleaseIterator(di); + dictRelease(d); + } + + /* CASE 4: We have a big hash compared to the requested number of elements. + * In this case we can simply get random elements from the hash and add + * to the temporary hash, trying to eventually get enough unique elements + * to reach the specified count. */ + else { + if (hash->encoding == OBJ_ENCODING_LISTPACK) { + /* it is inefficient to repeatedly pick one random element from a + * listpack. so we use this instead: */ + listpackEntry *keys, *vals = NULL; + keys = zmalloc(sizeof(listpackEntry)*count); + if (withvalues) + vals = zmalloc(sizeof(listpackEntry)*count); + serverAssert(lpRandomPairsUnique(hash->ptr, count, keys, vals) == count); + harndfieldReplyWithListpack(c, count, keys, vals); + zfree(keys); + zfree(vals); + return; + } + + /* Hashtable encoding (generic implementation) */ + unsigned long added = 0; + listpackEntry key, value; + dict *d = dictCreate(&hashDictType); + dictExpand(d, count); + while(added < count) { + hashTypeRandomElement(hash, size, &key, withvalues? &value : NULL); + + /* Try to add the object to the dictionary. If it already exists + * free it, otherwise increment the number of objects we have + * in the result dictionary. */ + sds skey = hashSdsFromListpackEntry(&key); + if (dictAdd(d,skey,NULL) != DICT_OK) { + sdsfree(skey); + continue; + } + added++; + + /* We can reply right away, so that we don't need to store the value in the dict. */ + if (withvalues && c->resp > 2) + addReplyArrayLen(c,2); + hashReplyFromListpackEntry(c, &key); + if (withvalues) + hashReplyFromListpackEntry(c, &value); + } + + /* Release memory */ + dictRelease(d); + } +} + +/* HRANDFIELD key [ [WITHVALUES]] */ +void hrandfieldCommand(client *c) { + long l; + int withvalues = 0; + robj *hash; + listpackEntry ele; + + if (c->argc >= 3) { + if (getLongFromObjectOrReply(c,c->argv[2],&l,NULL) != C_OK) return; + if (c->argc > 4 || (c->argc == 4 && strcasecmp(c->argv[3]->ptr,"withvalues"))) { + addReplyErrorObject(c,shared.syntaxerr); + return; + } else if (c->argc == 4) + withvalues = 1; + hrandfieldWithCountCommand(c, l, withvalues); + return; + } + + /* Handle variant without argument. Reply with simple bulk string */ + if ((hash = lookupKeyReadOrReply(c,c->argv[1],shared.null[c->resp]))== NULL || + checkType(c,hash,OBJ_HASH)) { + return; + } + + hashTypeRandomElement(hash,hashTypeLength(hash),&ele,NULL); + hashReplyFromListpackEntry(c, &ele); +} + +#endif diff --git a/src/redis/t_zset.c b/src/redis/t_zset.c index 0308b6c..5d0c89e 100644 --- a/src/redis/t_zset.c +++ b/src/redis/t_zset.c @@ -72,6 +72,14 @@ * Skiplist implementation of the low level API *----------------------------------------------------------------------------*/ +// ROMAN: static representation of sds strings +static char kMinStrData[] = "\110" "minstring"; +static char kMaxStrData[] = "\110" "minstring"; + +#define cminstring (kMinStrData + 1) +#define cmaxstring (kMaxStrData + 1) + + int zslLexValueGteMin(sds value, const zlexrangespec *spec); int zslLexValueLteMax(sds value, const zlexrangespec *spec); @@ -586,12 +594,12 @@ int zslParseLexRangeItem(robj *item, sds *dest, int *ex) { case '+': if (c[1] != '\0') return C_ERR; *ex = 1; - *dest = shared.maxstring; + *dest = cmaxstring; return C_OK; case '-': if (c[1] != '\0') return C_ERR; *ex = 1; - *dest = shared.minstring; + *dest = cminstring; return C_OK; case '(': *ex = 1; @@ -609,10 +617,10 @@ int zslParseLexRangeItem(robj *item, sds *dest, int *ex) { /* Free a lex range structure, must be called only after zslParseLexRange() * populated the structure with success (C_OK returned). */ void zslFreeLexRange(zlexrangespec *spec) { - if (spec->min != shared.minstring && - spec->min != shared.maxstring) sdsfree(spec->min); - if (spec->max != shared.minstring && - spec->max != shared.maxstring) sdsfree(spec->max); + if (spec->min != cminstring && + spec->min != cmaxstring) sdsfree(spec->min); + if (spec->max != cminstring && + spec->max != cmaxstring) sdsfree(spec->max); } /* Populate the lex rangespec according to the objects min and max. @@ -641,8 +649,8 @@ int zslParseLexRange(robj *min, robj *max, zlexrangespec *spec) { * -inf and +inf for strings */ int sdscmplex(sds a, sds b) { if (a == b) return 0; - if (a == shared.minstring || b == shared.maxstring) return -1; - if (a == shared.maxstring || b == shared.minstring) return 1; + if (a == cminstring || b == cmaxstring) return -1; + if (a == cmaxstring || b == cminstring) return 1; return sdscmp(a,b); } diff --git a/src/redis/zset.h b/src/redis/zset.h index 4c25b80..53d86d0 100644 --- a/src/redis/zset.h +++ b/src/redis/zset.h @@ -97,7 +97,4 @@ int zslLexValueGteMin(sds value, const zlexrangespec* spec); int zslLexValueLteMax(sds value, const zlexrangespec* spec); int zsetZiplistValidateIntegrity(unsigned char* zl, size_t size, int deep); -extern size_t zset_max_ziplist_entries; -extern size_t zset_max_ziplist_value; - #endif