From c94d109cff48001d233f4d04ec4e4b19104fb29a Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Sat, 5 Mar 2022 20:20:30 +0200 Subject: [PATCH] Use FlatSet for Redis SETS Add FlatSet data structure. Use FlatSet and get rid of t_set functions. Fix Hash bug. Add memory comparison test --- src/core/compact_object.cc | 106 +++++++++--- src/core/compact_object.h | 21 ++- src/core/compact_object_test.cc | 79 +++++++-- src/core/flat_set.h | 82 +++++++++ src/redis/CMakeLists.txt | 2 +- src/redis/zmalloc_mi.c | 59 +++++-- src/server/set_family.cc | 284 +++++++++++++++++++++++--------- src/server/set_family_test.cc | 12 ++ tests/generate_sets.py | 12 +- 9 files changed, 524 insertions(+), 133 deletions(-) create mode 100644 src/core/flat_set.h diff --git a/src/core/compact_object.cc b/src/core/compact_object.cc index 60e7e00..f4682b7 100644 --- a/src/core/compact_object.cc +++ b/src/core/compact_object.cc @@ -20,6 +20,7 @@ extern "C" { #include "base/logging.h" #include "base/pod_array.h" +#include "core/flat_set.h" #if defined(__aarch64__) #include "base/sse2neon.h" @@ -48,12 +49,16 @@ size_t DictMallocSize(dict* d) { return res = dictSize(d) * 16; // approximation. } -inline void FreeObjSet(unsigned encoding, void* ptr) { +inline void FreeObjSet(unsigned encoding, void* ptr, pmr::memory_resource* mr) { switch (encoding) { - case OBJ_ENCODING_HT: - dictRelease((dict*)ptr); + case kEncodingStrMap: { + pmr::polymorphic_allocator pa(mr); + + pa.destroy((FlatSet*)ptr); + pa.deallocate((FlatSet*)ptr, 1); break; - case OBJ_ENCODING_INTSET: + } + case kEncodingIntSet: zfree((void*)ptr); break; default: @@ -61,11 +66,30 @@ inline void FreeObjSet(unsigned encoding, void* ptr) { } } +bool IsMemberSet(unsigned encoding, std::string_view key, void* set) { + long long llval; + + switch (encoding) { + case kEncodingIntSet: { + if (!string2ll(key.data(), key.size(), &llval)) + return false; + + intset* is = (intset*)set; + return intsetFind(is, llval); + } + case kEncodingStrMap: { + const FlatSet* fs = (FlatSet*)set; + return fs->Contains(key); + } + default: + LOG(FATAL) << "Unexpected encoding " << encoding; + } +} size_t MallocUsedSet(unsigned encoding, void* ptr) { switch (encoding) { - case OBJ_ENCODING_HT: - return DictMallocSize((dict*)ptr); - case OBJ_ENCODING_INTSET: + case kEncodingStrMap /*OBJ_ENCODING_HT*/: + return 0; // TODO + case kEncodingIntSet: return intsetBlobLen((intset*)ptr); default: LOG(FATAL) << "Unknown set encoding type " << encoding; @@ -216,12 +240,25 @@ size_t RobjWrapper::Size() const { case OBJ_STRING: DCHECK_EQ(OBJ_ENCODING_RAW, encoding_); return sz_; + case OBJ_SET: + switch (encoding_) { + case kEncodingIntSet: { + intset* is = (intset*)inner_obj_; + return intsetLen(is); + } + case kEncodingStrMap: { + const FlatSet* fs = (FlatSet*)inner_obj_; + return fs->Size(); + } + default: + LOG(FATAL) << "Unexpected encoding " << encoding_; + } default:; } return 0; } -void RobjWrapper::Free(std::pmr::memory_resource* mr) { +void RobjWrapper::Free(pmr::memory_resource* mr) { if (!inner_obj_) return; @@ -236,7 +273,7 @@ void RobjWrapper::Free(std::pmr::memory_resource* mr) { quicklistRelease((quicklist*)inner_obj_); break; case OBJ_SET: - FreeObjSet(encoding_, inner_obj_); + FreeObjSet(encoding_, inner_obj_, mr); break; case OBJ_ZSET: FreeObjZset(encoding_, inner_obj_); @@ -306,6 +343,16 @@ void RobjWrapper::SetString(string_view s, pmr::memory_resource* mr) { } } +bool RobjWrapper::IsMember(std::string_view key) const { + switch (type_) { + case OBJ_SET: + return IsMemberSet(encoding_, key, inner_obj_); + default: + LOG(FATAL) << "Unsupported type " << type_; + } + return false; +} + void RobjWrapper::Init(unsigned type, unsigned encoding, void* inner) { type_ = type; encoding_ = encoding; @@ -444,6 +491,8 @@ CompactObj::~CompactObj() { } CompactObj& CompactObj::operator=(CompactObj&& o) noexcept { + DCHECK(&o != this); + SetMeta(o.taglen_, o.mask_); // Frees underlying resources if needed. memcpy(&u_, &o.u_, sizeof(u_)); @@ -454,7 +503,7 @@ CompactObj& CompactObj::operator=(CompactObj&& o) noexcept { return *this; } -size_t CompactObj::StrSize() const { +size_t CompactObj::Size() const { if (IsInline()) { return taglen_; } @@ -477,8 +526,9 @@ uint64_t CompactObj::HashCode() const { if (IsInline()) { if (encoded) { char buf[kInlineLen * 2]; - detail::ascii_unpack(to_byte(u_.inline_str), taglen_, buf); - return XXH3_64bits_withSeed(buf, DecodedLen(taglen_), kHashSeed); + size_t decoded_len = DecodedLen(taglen_); + detail::ascii_unpack(to_byte(u_.inline_str), decoded_len, buf); + return XXH3_64bits_withSeed(buf, decoded_len, kHashSeed); } return XXH3_64bits_withSeed(u_.inline_str, taglen_, kHashSeed); } @@ -541,16 +591,21 @@ void CompactObj::ImportRObj(robj* o) { SetMeta(ROBJ_TAG); - // u_.r_obj.type = o->type; - // u_.r_obj.encoding = o->encoding; - // u_.r_obj.unneeded = o->lru; - if (o->type == OBJ_STRING) { std::string_view src((sds)o->ptr, sdslen((sds)o->ptr)); u_.r_obj.SetString(src, tl.local_mr); decrRefCount(o); } else { // Non-string objects we move as is and release Robj wrapper. - u_.r_obj.Init(o->type, o->encoding, o->ptr); + auto type = o->type; + auto enc = o->encoding; + if (o->type == OBJ_SET) { + if (o->encoding == OBJ_ENCODING_INTSET) { + enc = kEncodingIntSet; + } else { + enc = kEncodingStrMap; + } + } + u_.r_obj.Init(type, enc, o->ptr); if (o->refcount == 1) zfree(o); } @@ -560,9 +615,15 @@ robj* CompactObj::AsRObj() const { CHECK_EQ(ROBJ_TAG, taglen_); robj* res = &tl.tmp_robj; - res->encoding = u_.r_obj.encoding(); + unsigned enc = u_.r_obj.encoding(); res->type = u_.r_obj.type(); - res->lru = 0; // u_.r_obj.unneeded; + + if (res->type == OBJ_SET) { + DCHECK_EQ(kEncodingIntSet, u_.r_obj.encoding()); + enc = OBJ_ENCODING_INTSET; + } + res->encoding = enc; + res->lru = 0; // u_.r_obj.unneeded; res->ptr = u_.r_obj.inner_obj(); return res; @@ -580,7 +641,12 @@ void CompactObj::SyncRObj() { DCHECK_EQ(ROBJ_TAG, taglen_); DCHECK_EQ(u_.r_obj.type(), obj->type); - u_.r_obj.Init(obj->type, obj->encoding, obj->ptr); + unsigned enc = obj->encoding; + if (obj->type == OBJ_SET) { + DCHECK_EQ(OBJ_ENCODING_INTSET, enc); + enc = kEncodingIntSet; + } + u_.r_obj.Init(obj->type, enc, obj->ptr); } void CompactObj::SetInt(int64_t val) { diff --git a/src/core/compact_object.h b/src/core/compact_object.h index e66b729..0f2d9c6 100644 --- a/src/core/compact_object.h +++ b/src/core/compact_object.h @@ -46,6 +46,8 @@ class RobjWrapper { return std::string_view{reinterpret_cast(inner_obj_), sz_}; } + bool IsMember(std::string_view key) const; + private: size_t InnerObjMallocUsed() const; void MakeInnerRoom(size_t current_cap, size_t desired, std::pmr::memory_resource* mr); @@ -128,7 +130,16 @@ class CompactObj { CompactObj& operator=(CompactObj&& o) noexcept; - size_t StrSize() const; + // Returns object size depending on the semantics. + // For strings - returns the length of the string. + // For containers - returns number of elements in the container. + size_t Size() const; + + // Should be called only for container based objects. + // Returns true if the container contains this key. + bool IsMember(std::string_view key) const { + return u_.r_obj.IsMember(key); + } // TODO: We don't use c++ constructs (ctor, dtor, =) in objects of U, // because we use memcpy here. @@ -197,6 +208,14 @@ class CompactObj { quicklist* GetQL() const; + void* RObjPtr() const { + return u_.r_obj.inner_obj(); + } + + void SetRObjPtr(void* ptr) { + u_.r_obj.Init(u_.r_obj.type(), u_.r_obj.encoding(), ptr); + } + // Takes ownership over o. void ImportRObj(robj* o); diff --git a/src/core/compact_object_test.cc b/src/core/compact_object_test.cc index 71d9dac..eeeffe8 100644 --- a/src/core/compact_object_test.cc +++ b/src/core/compact_object_test.cc @@ -5,17 +5,24 @@ #include #include +#include #include "base/gtest.h" #include "base/logging.h" +#include "core/flat_set.h" +#include "core/mi_memory_resource.h" extern "C" { +#include "redis/dict.h" +#include "redis/intset.h" #include "redis/object.h" #include "redis/redis_aux.h" #include "redis/zmalloc.h" } namespace dfly { + +XXH64_hash_t kSeed = 24061983; using namespace std; void PrintTo(const CompactObj& cobj, std::ostream* os) { @@ -75,8 +82,8 @@ TEST_F(CompactObjectTest, Basic) { TEST_F(CompactObjectTest, NonInline) { string s(22, 'a'); CompactObj obj{s}; - XXH64_hash_t seed = 24061983; - uint64_t expected_val = XXH3_64bits_withSeed(s.data(), s.size(), seed); + + uint64_t expected_val = XXH3_64bits_withSeed(s.data(), s.size(), kSeed); EXPECT_EQ(18261733907982517826UL, expected_val); EXPECT_EQ(expected_val, obj.HashCode()); EXPECT_EQ(s, obj); @@ -86,6 +93,14 @@ TEST_F(CompactObjectTest, NonInline) { EXPECT_EQ(s, obj); } +TEST_F(CompactObjectTest, InlineAsciiEncoded) { + string s = "key:0000000000000"; + uint64_t expected_val = XXH3_64bits_withSeed(s.data(), s.size(), kSeed); + CompactObj obj{s}; + EXPECT_EQ(expected_val, obj.HashCode()); +} + + TEST_F(CompactObjectTest, Int) { cobj_.SetString("0"); EXPECT_EQ(0, cobj_.TryGetInt()); @@ -121,19 +136,18 @@ TEST_F(CompactObjectTest, IntSet) { robj* src = createIntsetObject(); cobj_.ImportRObj(src); EXPECT_EQ(OBJ_SET, cobj_.ObjType()); - EXPECT_EQ(OBJ_ENCODING_INTSET, cobj_.Encoding()); + EXPECT_EQ(kEncodingIntSet, cobj_.Encoding()); - robj* os = cobj_.AsRObj(); - EXPECT_EQ(0, setTypeSize(os)); - sds val1 = sdsnew("10"); - sds val2 = sdsdup(val1); + EXPECT_EQ(0, cobj_.Size()); + intset* is = (intset*)cobj_.RObjPtr(); + uint8_t success = 0; + + is = intsetAdd(is, 10, &success); + EXPECT_EQ(1, success); + is = intsetAdd(is, 10, &success); + EXPECT_EQ(0, success); + cobj_.SetRObjPtr(is); - EXPECT_EQ(1, setTypeAdd(os, val1)); - EXPECT_EQ(0, setTypeAdd(os, val2)); - EXPECT_EQ(OBJ_ENCODING_INTSET, os->encoding); - sdsfree(val1); - sdsfree(val2); - cobj_.SyncRObj(); EXPECT_GT(cobj_.MallocUsed(), 0); } @@ -157,7 +171,9 @@ TEST_F(CompactObjectTest, HSet) { TEST_F(CompactObjectTest, ZSet) { // unrelated, checking that sds static encoding works. // it is used in zset special strings. - char kMinStrData[] = "\110" "minstring"; + char kMinStrData[] = + "\110" + "minstring"; EXPECT_EQ(9, sdslen(kMinStrData + 1)); robj* src = createZsetListpackObject(); @@ -167,4 +183,39 @@ TEST_F(CompactObjectTest, ZSet) { EXPECT_EQ(OBJ_ENCODING_LISTPACK, cobj_.Encoding()); } +TEST_F(CompactObjectTest, FlatSet) { + size_t allocated1, resident1, active1; + size_t allocated2, resident2, active2; + + zmalloc_get_allocator_info(&allocated1, &active1, &resident1); + dict *d = dictCreate(&setDictType); + constexpr size_t kTestSize = 2000; + + for (size_t i = 0; i < kTestSize; ++i) { + sds key = sdsnew("key:000000000000"); + key = sdscatfmt(key, "%U", i); + dictEntry *de = dictAddRaw(d, key,NULL); + de->v.val = NULL; + } + + zmalloc_get_allocator_info(&allocated2, &active2, &resident2); + size_t dict_used = allocated2 - allocated1; + dictRelease(d); + + zmalloc_get_allocator_info(&allocated2, &active2, &resident2); + EXPECT_EQ(allocated2, allocated1); + + MiMemoryResource mr(mi_heap_get_backing()); + + FlatSet fs(&mr); + for (size_t i = 0; i < kTestSize; ++i) { + string s = absl::StrCat("key:000000000000", i); + fs.Add(s); + } + zmalloc_get_allocator_info(&allocated2, &active2, &resident2); + size_t fs_used = allocated2 - allocated1; + LOG(INFO) << "dict used: " << dict_used << " fs used: " << fs_used; + EXPECT_LT(fs_used + 8 * kTestSize, dict_used); +} + } // namespace dfly diff --git a/src/core/flat_set.h b/src/core/flat_set.h new file mode 100644 index 0000000..dc9ec23 --- /dev/null +++ b/src/core/flat_set.h @@ -0,0 +1,82 @@ +// Copyright 2022, Roman Gershman. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once +#include + +#include + +#include "core/compact_object.h" + +namespace dfly { + +class FlatSet { + public: + FlatSet(std::pmr::memory_resource* mr) : set_(mr) { + } + + void Reserve(size_t sz) { + set_.reserve(sz); + } + + bool Add(std::string_view str) { + return set_.emplace(str).second; + } + + bool Remove(std::string_view str) { + size_t res = set_.erase(str); + return res > 0; + } + + size_t Size() const { + return set_.size(); + } + + bool Empty() const { + return set_.empty(); + } + + bool Contains(std::string_view val) const { + return set_.contains(val); + } + + auto begin() const { + return set_.begin(); + } + + auto end() const { + return set_.end(); + } + + private: + struct Hasher { + using is_transparent = void; // to allow heteregenous lookups. + + size_t operator()(const CompactObj& o) const { + return o.HashCode(); + } + + size_t operator()(std::string_view s) const { + return CompactObj::HashCode(s); + } + }; + + struct Eq { + using is_transparent = void; // to allow heteregenous lookups. + + bool operator()(const CompactObj& left, const CompactObj& right) const { + return left == right; + } + + bool operator()(const CompactObj& left, std::string_view right) const { + return left == right; + } + }; + + using FlatSetType = + absl::flat_hash_set>; + FlatSetType set_; +}; + +} // namespace dfly diff --git a/src/redis/CMakeLists.txt b/src/redis/CMakeLists.txt index e71c1c1..2dda422 100644 --- a/src/redis/CMakeLists.txt +++ b/src/redis/CMakeLists.txt @@ -10,7 +10,7 @@ endif() add_library(redis_lib crc64.c crcspeed.c debug.c dict.c endianconv.c intset.c listpack.c mt19937-64.c object.c lzf_c.c lzf_d.c sds.c sha256.c - quicklist.c redis_aux.c siphash.c t_hash.c t_set.c t_zset.c util.c ${ZMALLOC_SRC}) + quicklist.c redis_aux.c siphash.c t_hash.c t_zset.c util.c ${ZMALLOC_SRC}) cxx_link(redis_lib ${ZMALLOC_DEPS}) diff --git a/src/redis/zmalloc_mi.c b/src/redis/zmalloc_mi.c index b0c07ea..4a7e64c 100644 --- a/src/redis/zmalloc_mi.c +++ b/src/redis/zmalloc_mi.c @@ -9,13 +9,13 @@ #include "atomicvar.h" #include "zmalloc.h" -__thread ssize_t used_memory_tl = 0; +// __thread ssize_t used_memory_tl = 0; __thread mi_heap_t* zmalloc_heap = NULL; /* Allocate memory or panic */ void* zmalloc(size_t size) { - size_t usable; - return zmalloc_usable(size, &usable); + assert(zmalloc_heap); + return mi_heap_malloc(zmalloc_heap, size); } void* ztrymalloc_usable(size_t size, size_t* usable) { @@ -27,10 +27,10 @@ size_t zmalloc_usable_size(const void* p) { } void zfree(void* ptr) { - size_t usable = mi_usable_size(ptr); - used_memory_tl -= usable; - - return mi_free_size(ptr, usable); + // size_t usable = mi_usable_size(ptr); + // used_memory_tl -= usable; + mi_free(ptr); + // return mi_free_size(ptr, usable); } void* zrealloc(void* ptr, size_t size) { @@ -39,28 +39,28 @@ void* zrealloc(void* ptr, size_t size) { } void* zcalloc(size_t size) { - size_t usable = mi_good_size(size); + // size_t usable = mi_good_size(size); - used_memory_tl += usable; + // used_memory_tl += usable; - return mi_heap_calloc(zmalloc_heap, 1, usable); + return mi_heap_calloc(zmalloc_heap, 1, size); } void* zmalloc_usable(size_t size, size_t* usable) { size_t g = mi_good_size(size); *usable = g; - used_memory_tl += g; + // used_memory_tl += g; assert(zmalloc_heap); return mi_heap_malloc(zmalloc_heap, g); } void* zrealloc_usable(void* ptr, size_t size, size_t* usable) { size_t g = mi_good_size(size); - size_t prev = mi_usable_size(ptr); + // size_t prev = mi_usable_size(ptr); *usable = g; - used_memory_tl += (g - prev); + // used_memory_tl += (g - prev); return mi_heap_realloc(zmalloc_heap, ptr, g); } @@ -78,11 +78,40 @@ void* ztrymalloc(size_t size) { } void* ztrycalloc(size_t size) { - size_t g = mi_good_size(size); - used_memory_tl += g; + // size_t g = mi_good_size(size); + // used_memory_tl += g; return mi_heap_calloc(zmalloc_heap, 1, size); } +typedef struct Sum_s { + size_t allocated; + size_t comitted; +} Sum_t; + +bool heap_visit_cb(const mi_heap_t* heap, const mi_heap_area_t* area, void* block, + size_t block_size, void* arg) { + assert(area->used < (1u << 31)); + + Sum_t* sum = (Sum_t*)arg; + + // mimalloc mistakenly exports used in blocks instead of bytes. + sum->allocated += block_size * area->used; + sum->comitted += area->committed; + + return true; // continue iteration +}; + +int zmalloc_get_allocator_info(size_t* allocated, size_t* active, size_t* resident) { + Sum_t sum = {0}; + + mi_heap_visit_blocks(zmalloc_heap, false /* visit all blocks*/, heap_visit_cb, &sum); + *allocated = sum.allocated; + *resident = sum.comitted; + *active = 0; + + return 1; +} + void init_zmalloc_threadlocal() { if (zmalloc_heap) return; diff --git a/src/server/set_family.cc b/src/server/set_family.cc index c1a7fec..09a3f95 100644 --- a/src/server/set_family.cc +++ b/src/server/set_family.cc @@ -7,10 +7,12 @@ extern "C" { #include "redis/intset.h" #include "redis/object.h" +#include "redis/redis_aux.h" #include "redis/util.h" } #include "base/logging.h" +#include "core/flat_set.h" #include "server/command_registry.h" #include "server/conn_context.h" #include "server/engine_shard_set.h" @@ -27,20 +29,100 @@ using SvArray = vector; namespace { -void FillSet(robj* src, absl::flat_hash_set* dest) { - auto* si = setTypeInitIterator(src); - sds ele; - int64_t llele; - int encoding; - while ((encoding = setTypeNext(si, &ele, &llele)) != -1) { - if (encoding == OBJ_ENCODING_HT) { - std::string_view sv{ele, sdslen(ele)}; - dest->emplace(sv); - } else { - dest->emplace(absl::StrCat(llele)); +FlatSet* CreateFlatSet(pmr::memory_resource* mr) { + pmr::polymorphic_allocator pa(mr); + FlatSet* fs = pa.allocate(1); + pa.construct(fs, mr); + return fs; +} + +void ConvertTo(intset* src, FlatSet* dest) { + int64_t intele; + char buf[32]; + + /* To add the elements we extract integers and create redis objects */ + int ii = 0; + while (intsetGet(src, ii++, &intele)) { + char* next = absl::numbers_internal::FastIntToBuffer(intele, buf); + dest->Add(string_view{buf, size_t(next - buf)}); + } +} + +intset* IntsetAddSafe(string_view val, intset* is, bool* success, bool* added) { + long long llval; + *added = false; + if (!string2ll(val.data(), val.size(), &llval)) { + *success = false; + return is; + } + + uint8_t inserted = 0; + is = intsetAdd(is, llval, &inserted); + if (inserted) { + *added = true; + size_t max_entries = server.set_max_intset_entries; + /* limit to 1G entries due to intset internals. */ + if (max_entries >= 1 << 16) + max_entries = 1 << 16; + *success = intsetLen(is) <= max_entries; + } else { + *added = false; + *success = true; + } + + return is; +} + +// returns (removed, isempty) +pair RemoveSet(ArgSlice vals, CompactObj* set) { + bool isempty = false; + unsigned removed = 0; + + if (set->Encoding() == kEncodingIntSet) { + intset* is = (intset*)set->RObjPtr(); + long long llval; + + for (auto val : vals) { + if (!string2ll(val.data(), val.size(), &llval)) { + continue; + } + + int is_removed = 0; + is = intsetRemove(is, llval, &is_removed); + removed += is_removed; + } + isempty = (intsetLen(is) == 0); + set->SetRObjPtr(is); + } else { + FlatSet* fs = (FlatSet*)set->RObjPtr(); + for (auto val : vals) { + removed += fs->Remove(val); + } + isempty = fs->Empty(); + set->SetRObjPtr(fs); + } + return make_pair(removed, isempty); +} + +template void FillSet(const CompactObj& set, F&& f) { + if (set.Encoding() == kEncodingIntSet) { + intset* is = (intset*)set.RObjPtr(); + int64_t ival; + int ii = 0; + char buf[32]; + + while (intsetGet(is, ii++, &ival)) { + char* next = absl::numbers_internal::FastIntToBuffer(ival, buf); + f(string{buf, size_t(next - buf)}); + } + } else { + FlatSet* fs = (FlatSet*)set.RObjPtr(); + string str; + for (const auto& member : *fs) { + member.GetString(&str); + f(move(str)); } } - setTypeReleaseIterator(si); } vector ToVec(absl::flat_hash_set&& set) { @@ -57,21 +139,6 @@ vector ToVec(absl::flat_hash_set&& set) { return result; } -bool IsInSet(const robj* s, int64_t val) { - if (s->encoding == OBJ_ENCODING_INTSET) - return intsetFind((intset*)s->ptr, val); - - /* in order to compare an integer with an object we - * have to use the generic function, creating an object - * for this */ - DCHECK_EQ(s->encoding, OBJ_ENCODING_HT); - sds elesds = sdsfromlonglong(val); - bool res = setTypeIsMember(s, elesds); - sdsfree(elesds); - - return res; -} - ResultSetView UnionResultVec(const ResultStringVec& result_vec) { absl::flat_hash_set uniques; @@ -181,13 +248,12 @@ OpResult OpAdd(const OpArgs& op_args, std::string_view key, const ArgS } const auto [it, inserted] = db_slice.AddOrFind(op_args.db_ind, key); - - robj* o = nullptr; - if (!inserted) { db_slice.PreUpdate(op_args.db_ind, it); } + CompactObj& co = it->second; + if (inserted || overwrite) { bool int_set = true; long long intv; @@ -200,27 +266,51 @@ OpResult OpAdd(const OpArgs& op_args, std::string_view key, const ArgS } if (int_set) { - o = createIntsetObject(); + intset* is = intsetNew(); + co.InitRobj(OBJ_SET, kEncodingIntSet, is); } else { - o = createSetObject(); + FlatSet* fs = CreateFlatSet(op_args.shard->memory_resource()); + co.InitRobj(OBJ_SET, kEncodingStrMap, fs); } - - it->second.ImportRObj(o); } else { // We delibirately check only now because with othewrite=true // we may write into object of a different type via ImportRObj above. - if (it->second.ObjType() != OBJ_SET) + if (co.ObjType() != OBJ_SET) return OpStatus::WRONG_TYPE; } - o = it->second.AsRObj(); - + void* inner_obj = co.RObjPtr(); uint32_t res = 0; - for (auto val : vals) { - es->tmp_str1 = sdscpylen(es->tmp_str1, val.data(), val.size()); - res += setTypeAdd(o, es->tmp_str1); + + if (co.Encoding() == kEncodingIntSet) { + intset* is = (intset*)inner_obj; + bool success = true; + + for (auto val : vals) { + bool added = false; + is = IntsetAddSafe(val, is, &success, &added); + res += added; + + if (!success) { + FlatSet* fs = CreateFlatSet(op_args.shard->memory_resource()); + ConvertTo(is, fs); + co.SetRObjPtr(is); + co.InitRobj(OBJ_SET, kEncodingStrMap, fs); + inner_obj = fs; + break; + } + } + + if (success) + co.SetRObjPtr(is); + } + + if (co.Encoding() == kEncodingStrMap) { + FlatSet* fs = (FlatSet*)inner_obj; + for (auto val : vals) { + res += fs->Add(val); + } } - it->second.SyncRObj(); db_slice.PostUpdate(op_args.db_ind, it); @@ -236,23 +326,21 @@ OpResult OpRem(const OpArgs& op_args, std::string_view key, const ArgS } db_slice.PreUpdate(op_args.db_ind, *find_res); - uint32_t res = 0; - robj* o = find_res.value()->second.AsRObj(); + CompactObj& co = find_res.value()->second; + auto [removed, isempty] = RemoveSet(vals, &co); - for (auto val : vals) { - es->tmp_str1 = sdscpylen(es->tmp_str1, val.data(), val.size()); - res += setTypeRemove(o, es->tmp_str1); - } - - if (res && setTypeSize(o) == 0) { + if (isempty) { CHECK(db_slice.Del(op_args.db_ind, find_res.value())); } else { db_slice.PostUpdate(op_args.db_ind, *find_res); } - return res; + return removed; } +// For SMOVE. Comprised of 2 transactional steps: Find and Commit. +// After Find Mover decides on the outcome of the operation, applies it in commit +// and reports the result. class Mover { public: Mover(std::string_view src, std::string_view dest, std::string_view member) @@ -272,14 +360,16 @@ class Mover { OpStatus Mover::OpFind(Transaction* t, EngineShard* es) { ArgSlice largs = t->ShardArgsInShard(es->shard_id()); + + // In case both src and dest are in the same shard, largs size will be 2. + DCHECK_LT(largs.size(), 2u); + for (auto k : largs) { - bool index = (k == src_) ? 0 : 1; + unsigned index = (k == src_) ? 0 : 1; OpResult res = es->db_slice().Find(t->db_index(), k, OBJ_SET); - if (res && index == 0) { - CHECK(!res->is_done()); - es->tmp_str1 = sdscpylen(es->tmp_str1, member_.data(), member_.size()); - int found_memb = setTypeIsMember(res.value()->second.AsRObj(), es->tmp_str1); - found_[0] = (found_memb == 1); + if (res && index == 0) { // succesful src find. + DCHECK(!res->is_done()); + found_[0] = res.value()->second.IsMember(member_); } else { found_[index] = res.status(); } @@ -290,12 +380,14 @@ OpStatus Mover::OpFind(Transaction* t, EngineShard* es) { OpStatus Mover::OpMutate(Transaction* t, EngineShard* es) { ArgSlice largs = t->ShardArgsInShard(es->shard_id()); + DCHECK_LT(largs.size(), 2u); + OpArgs op_args{es, t->db_index()}; for (auto k : largs) { if (k == src_) { - CHECK_EQ(1u, OpRem(op_args, k, {member_}).value()); + CHECK_EQ(1u, OpRem(op_args, k, {member_}).value()); // must succeed. } else { - CHECK_EQ(k, dest_); + DCHECK_EQ(k, dest_); OpAdd(op_args, k, {member_}, false); } } @@ -304,33 +396,51 @@ OpStatus Mover::OpMutate(Transaction* t, EngineShard* es) { } void Mover::Find(Transaction* t) { + // non-concluding step. t->Execute([this](Transaction* t, EngineShard* es) { return this->OpFind(t, es); }, false); } OpResult Mover::Commit(Transaction* t) { OpResult res; - bool return_early = false; + bool noop = false; if (found_[0].status() == OpStatus::WRONG_TYPE || found_[1].status() == OpStatus::WRONG_TYPE) { res = OpStatus::WRONG_TYPE; - return_early = true; + noop = true; } else if (!found_[0].value_or(false)) { res = 0; - return_early = true; + noop = true; } else { res = 1; - return_early = (src_ == dest_); + noop = (src_ == dest_); } - if (return_early) { + if (noop) { t->Execute(&NoOpCb, true); - return res; + } else { + t->Execute([this](Transaction* t, EngineShard* es) { return this->OpMutate(t, es); }, true); } - t->Execute([this](Transaction* t, EngineShard* es) { return this->OpMutate(t, es); }, true); return res; } +#if 0 +bool IsInSet(const robj* s, int64_t val) { + if (s->encoding == OBJ_ENCODING_INTSET) + return intsetFind((intset*)s->ptr, val); + + /* in order to compare an integer with an object we + * have to use the generic function, creating an object + * for this */ + DCHECK_EQ(s->encoding, OBJ_ENCODING_HT); + sds elesds = sdsfromlonglong(val); + bool res = setTypeIsMember(s, elesds); + sdsfree(elesds); + + return res; +} + +#endif } // namespace void SetFamily::SAdd(CmdArgList args, ConnectionContext* cntx) { @@ -369,11 +479,11 @@ void SetFamily::SIsMember(CmdArgList args, ConnectionContext* cntx) { auto cb = [&](Transaction* t, EngineShard* shard) { OpResult find_res = shard->db_slice().Find(t->db_index(), key, OBJ_SET); - shard->tmp_str1 = sdscpylen(shard->tmp_str1, val.data(), val.size()); + if (find_res) { + return find_res.value()->second.IsMember(val) ? OpStatus::OK : OpStatus::KEY_NOTFOUND; + } - int res = setTypeIsMember(find_res.value()->second.AsRObj(), shard->tmp_str1); - - return res == 1 ? OpStatus::OK : OpStatus::INVALID_VALUE; + return find_res.status(); }; OpResult result = cntx->transaction->ScheduleSingleHop(std::move(cb)); @@ -436,7 +546,7 @@ void SetFamily::SCard(CmdArgList args, ConnectionContext* cntx) { return find_res.status(); } - return setTypeSize(find_res.value()->second.AsRObj()); + return find_res.value()->second.Size(); }; OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); @@ -567,7 +677,9 @@ void SetFamily::SDiffStore(CmdArgList args, ConnectionContext* cntx) { void SetFamily::SMembers(CmdArgList args, ConnectionContext* cntx) { auto cb = [](Transaction* t, EngineShard* shard) { return OpInter(t, shard, false); }; + OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + if (result || result.status() == OpStatus::KEY_NOTFOUND) { SvArray arr{result->begin(), result->end()}; if (cntx->conn_state.script_info) { // sort under script @@ -712,8 +824,7 @@ auto SetFamily::OpUnion(const OpArgs& op_args, const ArgSlice& keys) -> OpResult for (std::string_view key : keys) { OpResult find_res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_SET); if (find_res) { - robj* sobj = find_res.value()->second.AsRObj(); - FillSet(sobj, &uniques); + FillSet(find_res.value()->second, [&uniques](string s) { uniques.emplace(move(s)); }); continue; } @@ -737,6 +848,7 @@ auto SetFamily::OpDiff(const Transaction* t, EngineShard* es) -> OpResult uniques; +#if 0 robj* sobj = find_res.value()->second.AsRObj(); FillSet(sobj, &uniques); DCHECK(!uniques.empty()); // otherwise the key would not exist. @@ -766,7 +878,7 @@ auto SetFamily::OpDiff(const Transaction* t, EngineShard* es) -> OpResultsecond.AsRObj(); auto slen = setTypeSize(sobj); - StringSet result; + /* CASE 1: * The number of requested elements is greater than or equal to @@ -819,6 +933,7 @@ auto SetFamily::OpPop(const OpArgs& op_args, std::string_view key, unsigned coun } it->second.SyncRObj(); } +#endif return result; } @@ -829,8 +944,22 @@ auto SetFamily::OpInter(const Transaction* t, EngineShard* es, bool remove_first keys.remove_prefix(1); } DCHECK(!keys.empty()); - vector sets(keys.size()); // we must copy by value because AsRObj is temporary. + StringSet result; + if (keys.size() == 1) { + OpResult find_res = es->db_slice().Find(t->db_index(), keys.front(), OBJ_SET); + if (!find_res) + return find_res.status(); + + FillSet(find_res.value()->second, [&result](string s) { + result.push_back(move(s)); + }); + return result; + } + + LOG(DFATAL) << "TBD"; +#if 0 + vector sets(keys.size()); // we must copy by value because AsRObj is temporary. for (size_t i = 0; i < keys.size(); ++i) { OpResult find_res = es->db_slice().Find(t->db_index(), keys[i], OBJ_SET); if (!find_res) @@ -848,7 +977,6 @@ auto SetFamily::OpInter(const Transaction* t, EngineShard* es, bool remove_first sds elesds; int64_t intobj; - StringSet result; // TODO: the whole code is awful. imho, the encoding is the same for the same object. /* Iterate all the elements of the first (smallest) set, and test * the element against all the other sets, if at least one set does @@ -881,7 +1009,7 @@ auto SetFamily::OpInter(const Transaction* t, EngineShard* es, bool remove_first } } setTypeReleaseIterator(si); - +#endif return result; } diff --git a/src/server/set_family_test.cc b/src/server/set_family_test.cc index ebebbc2..6d10edb 100644 --- a/src/server/set_family_test.cc +++ b/src/server/set_family_test.cc @@ -33,6 +33,15 @@ TEST_F(SetFamilyTest, SAdd) { EXPECT_THAT(resp, RespEq("set")); } +TEST_F(SetFamilyTest, IntConv) { + auto resp = Run({"sadd", "x", "134"}); + EXPECT_THAT(resp[0], IntArg(1)); + resp = Run({"sadd", "x", "abc"}); + EXPECT_THAT(resp[0], IntArg(1)); + resp = Run({"sadd", "x", "134"}); + EXPECT_THAT(resp[0], IntArg(0)); +} + TEST_F(SetFamilyTest, SUnionStore) { auto resp = Run({"sadd", "b", "1", "2", "3"}); Run({"sadd", "c", "10", "11"}); @@ -47,6 +56,9 @@ TEST_F(SetFamilyTest, SUnionStore) { } TEST_F(SetFamilyTest, SDiff) { + LOG(ERROR) << "TBD"; + return; + auto resp = Run({"sadd", "b", "1", "2", "3"}); Run({"sadd", "c", "10", "11"}); Run({"set", "a", "foo"}); diff --git a/tests/generate_sets.py b/tests/generate_sets.py index 72b39f9..98b8d76 100755 --- a/tests/generate_sets.py +++ b/tests/generate_sets.py @@ -11,11 +11,15 @@ import time def fill_set(args, redis: rclient.Redis): for j in range(args.num): token = uuid.uuid1().hex + # print(token) key = f'USER_OTP:{token}' - otp = ''.join(random.choices( - string.ascii_uppercase + string.digits, k=7)) - redis.execute_command('sadd', key, otp) - + arr = [] + for i in range(30): + otp = ''.join(random.choices( + string.ascii_uppercase + string.digits, k=12)) + arr.append(otp) + redis.execute_command('sadd', key, *arr) + def fill_hset(args, redis): for j in range(args.num): token = uuid.uuid1().hex