diff --git a/src/core/compact_object.cc b/src/core/compact_object.cc index c37e4b8..3ce96f0 100644 --- a/src/core/compact_object.cc +++ b/src/core/compact_object.cc @@ -20,7 +20,6 @@ extern "C" { #include "base/logging.h" #include "base/pod_array.h" -#include "core/flat_set.h" #if defined(__aarch64__) #include "base/sse2neon.h" @@ -52,10 +51,7 @@ size_t DictMallocSize(dict* d) { inline void FreeObjSet(unsigned encoding, void* ptr, pmr::memory_resource* mr) { switch (encoding) { case kEncodingStrMap: { - pmr::polymorphic_allocator pa(mr); - - pa.destroy((FlatSet*)ptr); - pa.deallocate((FlatSet*)ptr, 1); + dictRelease((dict*)ptr); break; } case kEncodingIntSet: @@ -66,6 +62,24 @@ inline void FreeObjSet(unsigned encoding, void* ptr, pmr::memory_resource* mr) { } } +bool dictContains(const dict* d, string_view key) { + uint64_t h = dictGenHashFunction(key.data(), key.size()); + + for (unsigned table = 0; table <= 1; table++) { + uint64_t idx = h & DICTHT_SIZE_MASK(d->ht_size_exp[table]); + dictEntry* he = d->ht_table[table][idx]; + while (he) { + sds dkey = (sds)he->key; + if (sdslen(dkey) == key.size() && (key.empty() || memcmp(dkey, key.data(), key.size()) == 0)) + return true; + he = he->next; + } + if (!dictIsRehashing(d)) + break; + } + return false; +} + bool IsMemberSet(unsigned encoding, std::string_view key, void* set) { long long llval; @@ -78,13 +92,14 @@ bool IsMemberSet(unsigned encoding, std::string_view key, void* set) { return intsetFind(is, llval); } case kEncodingStrMap: { - const FlatSet* fs = (FlatSet*)set; - return fs->Contains(key); + const dict* ds = (dict*)set; + return dictContains(ds, key); } default: LOG(FATAL) << "Unexpected encoding " << encoding; } } + size_t MallocUsedSet(unsigned encoding, void* ptr) { switch (encoding) { case kEncodingStrMap /*OBJ_ENCODING_HT*/: @@ -116,8 +131,7 @@ size_t MallocUsedZSet(unsigned encoding, void* ptr) { case OBJ_ENCODING_SKIPLIST: { zset* zs = (zset*)ptr; return DictMallocSize(zs->dict); - } - break; + } break; default: LOG(FATAL) << "Unknown set encoding type " << encoding; } @@ -269,8 +283,8 @@ size_t RobjWrapper::Size() const { return intsetLen(is); } case kEncodingStrMap: { - const FlatSet* fs = (FlatSet*)inner_obj_; - return fs->Size(); + dict* d = (dict*)inner_obj_; + return dictSize(d); } default: LOG(FATAL) << "Unexpected encoding " << encoding_; @@ -997,4 +1011,8 @@ size_t CompactObj::DecodedLen(size_t sz) const { return ascii_len(sz) - ((mask_ & ASCII1_ENC_BIT) ? 1 : 0); } +pmr::memory_resource* CompactObj::memory_resource() { + return tl.local_mr; +} + } // namespace dfly diff --git a/src/core/compact_object.h b/src/core/compact_object.h index 0f2d9c6..1c90146 100644 --- a/src/core/compact_object.h +++ b/src/core/compact_object.h @@ -257,6 +257,7 @@ class CompactObj { static Stats GetStats(); static void InitThreadLocal(std::pmr::memory_resource* mr); + static std::pmr::memory_resource* memory_resource(); // thread-local. private: size_t DecodedLen(size_t sz) const; diff --git a/src/server/set_family.cc b/src/server/set_family.cc index de5a1d8..0b1f6fd 100644 --- a/src/server/set_family.cc +++ b/src/server/set_family.cc @@ -12,7 +12,6 @@ extern "C" { } #include "base/logging.h" -#include "core/flat_set.h" #include "server/command_registry.h" #include "server/conn_context.h" #include "server/engine_shard_set.h" @@ -29,14 +28,7 @@ using SvArray = vector; namespace { -FlatSet* CreateFlatSet(pmr::memory_resource* mr) { - pmr::polymorphic_allocator pa(mr); - FlatSet* fs = pa.allocate(1); - pa.construct(fs, mr); - return fs; -} - -void ConvertTo(intset* src, FlatSet* dest) { +void ConvertTo(intset* src, dict* dest) { int64_t intele; char buf[32]; @@ -44,7 +36,8 @@ void ConvertTo(intset* src, FlatSet* dest) { int ii = 0; while (intsetGet(src, ii++, &intele)) { char* next = absl::numbers_internal::FastIntToBuffer(intele, buf); - dest->Add(string_view{buf, size_t(next - buf)}); + sds s = sdsnewlen(buf, next - buf); + CHECK(dictAddRaw(dest, s, NULL)); } } @@ -94,12 +87,14 @@ pair RemoveSet(ArgSlice vals, CompactObj* set) { isempty = (intsetLen(is) == 0); set->SetRObjPtr(is); } else { - FlatSet* fs = (FlatSet*)set->RObjPtr(); - for (auto val : vals) { - removed += fs->Remove(val); + dict* d = (dict*)set->RObjPtr(); + auto* shard = EngineShard::tlocal(); + for (auto member : vals) { + shard->tmp_str1 = sdscpylen(shard->tmp_str1, member.data(), member.size()); + int result = dictDelete(d, shard->tmp_str1); + removed += (result == DICT_OK); } - isempty = fs->Empty(); - set->SetRObjPtr(fs); + isempty = (dictSize(d) == 0); } return make_pair(removed, isempty); } @@ -116,12 +111,15 @@ template void FillSet(const CompactObj& set, F&& f) { f(string{buf, size_t(next - buf)}); } } else { - FlatSet* fs = (FlatSet*)set.RObjPtr(); + dict* ds = (dict*)set.RObjPtr(); string str; - for (const auto& member : *fs) { - member.GetString(&str); + dictIterator* di = dictGetIterator(ds); + dictEntry* de = nullptr; + while ((de = dictNext(di))) { + str.assign((sds)de->key, sdslen((sds)de->key)); f(move(str)); } + dictReleaseIterator(di); } } @@ -269,8 +267,8 @@ OpResult OpAdd(const OpArgs& op_args, std::string_view key, const ArgS intset* is = intsetNew(); co.InitRobj(OBJ_SET, kEncodingIntSet, is); } else { - FlatSet* fs = CreateFlatSet(op_args.shard->memory_resource()); - co.InitRobj(OBJ_SET, kEncodingStrMap, fs); + dict* ds = dictCreate(&setDictType); + co.InitRobj(OBJ_SET, kEncodingStrMap, ds); } } else { // We delibirately check only now because with othewrite=true @@ -292,11 +290,11 @@ OpResult OpAdd(const OpArgs& op_args, std::string_view key, const ArgS res += added; if (!success) { - FlatSet* fs = CreateFlatSet(op_args.shard->memory_resource()); - ConvertTo(is, fs); + dict* ds = dictCreate(&setDictType); + ConvertTo(is, ds); co.SetRObjPtr(is); - co.InitRobj(OBJ_SET, kEncodingStrMap, fs); - inner_obj = fs; + co.InitRobj(OBJ_SET, kEncodingStrMap, ds); + inner_obj = ds; break; } } @@ -306,9 +304,15 @@ OpResult OpAdd(const OpArgs& op_args, std::string_view key, const ArgS } if (co.Encoding() == kEncodingStrMap) { - FlatSet* fs = (FlatSet*)inner_obj; - for (auto val : vals) { - res += fs->Add(val); + dict* ds = (dict*)inner_obj; + + for (auto member : vals) { + es->tmp_str1 = sdscpylen(es->tmp_str1, member.data(), member.size()); + dictEntry* de = dictAddRaw(ds, es->tmp_str1, NULL); + if (de) { + de->key = sdsdup(es->tmp_str1); + ++res; + } } } @@ -362,7 +366,7 @@ OpStatus Mover::OpFind(Transaction* t, EngineShard* es) { ArgSlice largs = t->ShardArgsInShard(es->shard_id()); // In case both src and dest are in the same shard, largs size will be 2. - DCHECK_LT(largs.size(), 2u); + DCHECK_LE(largs.size(), 2u); for (auto k : largs) { unsigned index = (k == src_) ? 0 : 1; @@ -380,7 +384,7 @@ OpStatus Mover::OpFind(Transaction* t, EngineShard* es) { OpStatus Mover::OpMutate(Transaction* t, EngineShard* es) { ArgSlice largs = t->ShardArgsInShard(es->shard_id()); - DCHECK_LT(largs.size(), 2u); + DCHECK_LE(largs.size(), 2u); OpArgs op_args{es, t->db_index()}; for (auto k : largs) { @@ -899,9 +903,7 @@ OpResult SetFamily::OpPop(const OpArgs& op_args, std::string_view key * The number of requested elements is greater than or equal to * the number of elements inside the set: simply return the whole set. */ if (count >= slen) { - FillSet(it->second, [&result](string s) { - result.push_back(move(s)); - }); + FillSet(it->second, [&result](string s) { result.push_back(move(s)); }); /* Delete the set as it is now empty */ CHECK(es->db_slice().Del(op_args.db_ind, it)); } else { @@ -918,17 +920,16 @@ OpResult SetFamily::OpPop(const OpArgs& op_args, std::string_view key is = intsetTrimTail(is, count); // now remove last count items it->second.SetRObjPtr(is); } else { - FlatSet* fs = (FlatSet*)it->second.RObjPtr(); + dict* ds = (dict*)it->second.RObjPtr(); string str; - + dictIterator* di = dictGetSafeIterator(ds); for (uint32_t i = 0; i < count; ++i) { - auto it = fs->begin(); - it->GetString(&str); - fs->Erase(it); - result.push_back(move(str)); + dictEntry* de = dictNext(di); + DCHECK(de); + result.emplace_back((sds)de->key, sdslen((sds)de->key)); + dictDelete(ds, de->key); } - - it->second.SetRObjPtr(fs); + dictReleaseIterator(di); } } return result; @@ -947,9 +948,7 @@ OpResult SetFamily::OpInter(const Transaction* t, EngineShard* es, bo if (!find_res) return find_res.status(); - FillSet(find_res.value()->second, [&result](string s) { - result.push_back(move(s)); - }); + FillSet(find_res.value()->second, [&result](string s) { result.push_back(move(s)); }); return result; } diff --git a/src/server/set_family_test.cc b/src/server/set_family_test.cc index 2050f5b..42fd814 100644 --- a/src/server/set_family_test.cc +++ b/src/server/set_family_test.cc @@ -82,6 +82,10 @@ TEST_F(SetFamilyTest, SMove) { Run({"sadd", "b", "3", "5", "6", "2"}); resp = Run({"smove", "a", "b", "1"}); EXPECT_THAT(resp[0], IntArg(1)); + + Run({"sadd", "x", "a", "b", "c"}); + Run({"sadd", "y", "c"}); + EXPECT_THAT(Run({"smove", "x", "y", "c"}), ElementsAre(IntArg(1))); } TEST_F(SetFamilyTest, SPop) {