Use FlatSet for Redis SETS

Add FlatSet data structure.
Use FlatSet and get rid of t_set functions.
Fix Hash bug. Add memory comparison test
This commit is contained in:
Roman Gershman 2022-03-05 20:20:30 +02:00
parent a58ed46f1e
commit c94d109cff
9 changed files with 524 additions and 133 deletions

View File

@ -20,6 +20,7 @@ extern "C" {
#include "base/logging.h"
#include "base/pod_array.h"
#include "core/flat_set.h"
#if defined(__aarch64__)
#include "base/sse2neon.h"
@ -48,12 +49,16 @@ size_t DictMallocSize(dict* d) {
return res = dictSize(d) * 16; // approximation.
}
inline void FreeObjSet(unsigned encoding, void* ptr) {
inline void FreeObjSet(unsigned encoding, void* ptr, pmr::memory_resource* mr) {
switch (encoding) {
case OBJ_ENCODING_HT:
dictRelease((dict*)ptr);
case kEncodingStrMap: {
pmr::polymorphic_allocator<FlatSet> pa(mr);
pa.destroy((FlatSet*)ptr);
pa.deallocate((FlatSet*)ptr, 1);
break;
case OBJ_ENCODING_INTSET:
}
case kEncodingIntSet:
zfree((void*)ptr);
break;
default:
@ -61,11 +66,30 @@ inline void FreeObjSet(unsigned encoding, void* ptr) {
}
}
bool IsMemberSet(unsigned encoding, std::string_view key, void* set) {
long long llval;
switch (encoding) {
case kEncodingIntSet: {
if (!string2ll(key.data(), key.size(), &llval))
return false;
intset* is = (intset*)set;
return intsetFind(is, llval);
}
case kEncodingStrMap: {
const FlatSet* fs = (FlatSet*)set;
return fs->Contains(key);
}
default:
LOG(FATAL) << "Unexpected encoding " << encoding;
}
}
size_t MallocUsedSet(unsigned encoding, void* ptr) {
switch (encoding) {
case OBJ_ENCODING_HT:
return DictMallocSize((dict*)ptr);
case OBJ_ENCODING_INTSET:
case kEncodingStrMap /*OBJ_ENCODING_HT*/:
return 0; // TODO
case kEncodingIntSet:
return intsetBlobLen((intset*)ptr);
default:
LOG(FATAL) << "Unknown set encoding type " << encoding;
@ -216,12 +240,25 @@ size_t RobjWrapper::Size() const {
case OBJ_STRING:
DCHECK_EQ(OBJ_ENCODING_RAW, encoding_);
return sz_;
case OBJ_SET:
switch (encoding_) {
case kEncodingIntSet: {
intset* is = (intset*)inner_obj_;
return intsetLen(is);
}
case kEncodingStrMap: {
const FlatSet* fs = (FlatSet*)inner_obj_;
return fs->Size();
}
default:
LOG(FATAL) << "Unexpected encoding " << encoding_;
}
default:;
}
return 0;
}
void RobjWrapper::Free(std::pmr::memory_resource* mr) {
void RobjWrapper::Free(pmr::memory_resource* mr) {
if (!inner_obj_)
return;
@ -236,7 +273,7 @@ void RobjWrapper::Free(std::pmr::memory_resource* mr) {
quicklistRelease((quicklist*)inner_obj_);
break;
case OBJ_SET:
FreeObjSet(encoding_, inner_obj_);
FreeObjSet(encoding_, inner_obj_, mr);
break;
case OBJ_ZSET:
FreeObjZset(encoding_, inner_obj_);
@ -306,6 +343,16 @@ void RobjWrapper::SetString(string_view s, pmr::memory_resource* mr) {
}
}
bool RobjWrapper::IsMember(std::string_view key) const {
switch (type_) {
case OBJ_SET:
return IsMemberSet(encoding_, key, inner_obj_);
default:
LOG(FATAL) << "Unsupported type " << type_;
}
return false;
}
void RobjWrapper::Init(unsigned type, unsigned encoding, void* inner) {
type_ = type;
encoding_ = encoding;
@ -444,6 +491,8 @@ CompactObj::~CompactObj() {
}
CompactObj& CompactObj::operator=(CompactObj&& o) noexcept {
DCHECK(&o != this);
SetMeta(o.taglen_, o.mask_); // Frees underlying resources if needed.
memcpy(&u_, &o.u_, sizeof(u_));
@ -454,7 +503,7 @@ CompactObj& CompactObj::operator=(CompactObj&& o) noexcept {
return *this;
}
size_t CompactObj::StrSize() const {
size_t CompactObj::Size() const {
if (IsInline()) {
return taglen_;
}
@ -477,8 +526,9 @@ uint64_t CompactObj::HashCode() const {
if (IsInline()) {
if (encoded) {
char buf[kInlineLen * 2];
detail::ascii_unpack(to_byte(u_.inline_str), taglen_, buf);
return XXH3_64bits_withSeed(buf, DecodedLen(taglen_), kHashSeed);
size_t decoded_len = DecodedLen(taglen_);
detail::ascii_unpack(to_byte(u_.inline_str), decoded_len, buf);
return XXH3_64bits_withSeed(buf, decoded_len, kHashSeed);
}
return XXH3_64bits_withSeed(u_.inline_str, taglen_, kHashSeed);
}
@ -541,16 +591,21 @@ void CompactObj::ImportRObj(robj* o) {
SetMeta(ROBJ_TAG);
// u_.r_obj.type = o->type;
// u_.r_obj.encoding = o->encoding;
// u_.r_obj.unneeded = o->lru;
if (o->type == OBJ_STRING) {
std::string_view src((sds)o->ptr, sdslen((sds)o->ptr));
u_.r_obj.SetString(src, tl.local_mr);
decrRefCount(o);
} else { // Non-string objects we move as is and release Robj wrapper.
u_.r_obj.Init(o->type, o->encoding, o->ptr);
auto type = o->type;
auto enc = o->encoding;
if (o->type == OBJ_SET) {
if (o->encoding == OBJ_ENCODING_INTSET) {
enc = kEncodingIntSet;
} else {
enc = kEncodingStrMap;
}
}
u_.r_obj.Init(type, enc, o->ptr);
if (o->refcount == 1)
zfree(o);
}
@ -560,9 +615,15 @@ robj* CompactObj::AsRObj() const {
CHECK_EQ(ROBJ_TAG, taglen_);
robj* res = &tl.tmp_robj;
res->encoding = u_.r_obj.encoding();
unsigned enc = u_.r_obj.encoding();
res->type = u_.r_obj.type();
res->lru = 0; // u_.r_obj.unneeded;
if (res->type == OBJ_SET) {
DCHECK_EQ(kEncodingIntSet, u_.r_obj.encoding());
enc = OBJ_ENCODING_INTSET;
}
res->encoding = enc;
res->lru = 0; // u_.r_obj.unneeded;
res->ptr = u_.r_obj.inner_obj();
return res;
@ -580,7 +641,12 @@ void CompactObj::SyncRObj() {
DCHECK_EQ(ROBJ_TAG, taglen_);
DCHECK_EQ(u_.r_obj.type(), obj->type);
u_.r_obj.Init(obj->type, obj->encoding, obj->ptr);
unsigned enc = obj->encoding;
if (obj->type == OBJ_SET) {
DCHECK_EQ(OBJ_ENCODING_INTSET, enc);
enc = kEncodingIntSet;
}
u_.r_obj.Init(obj->type, enc, obj->ptr);
}
void CompactObj::SetInt(int64_t val) {

View File

@ -46,6 +46,8 @@ class RobjWrapper {
return std::string_view{reinterpret_cast<char*>(inner_obj_), sz_};
}
bool IsMember(std::string_view key) const;
private:
size_t InnerObjMallocUsed() const;
void MakeInnerRoom(size_t current_cap, size_t desired, std::pmr::memory_resource* mr);
@ -128,7 +130,16 @@ class CompactObj {
CompactObj& operator=(CompactObj&& o) noexcept;
size_t StrSize() const;
// Returns object size depending on the semantics.
// For strings - returns the length of the string.
// For containers - returns number of elements in the container.
size_t Size() const;
// Should be called only for container based objects.
// Returns true if the container contains this key.
bool IsMember(std::string_view key) const {
return u_.r_obj.IsMember(key);
}
// TODO: We don't use c++ constructs (ctor, dtor, =) in objects of U,
// because we use memcpy here.
@ -197,6 +208,14 @@ class CompactObj {
quicklist* GetQL() const;
void* RObjPtr() const {
return u_.r_obj.inner_obj();
}
void SetRObjPtr(void* ptr) {
u_.r_obj.Init(u_.r_obj.type(), u_.r_obj.encoding(), ptr);
}
// Takes ownership over o.
void ImportRObj(robj* o);

View File

@ -5,17 +5,24 @@
#include <mimalloc.h>
#include <xxhash.h>
#include <absl/strings/str_cat.h>
#include "base/gtest.h"
#include "base/logging.h"
#include "core/flat_set.h"
#include "core/mi_memory_resource.h"
extern "C" {
#include "redis/dict.h"
#include "redis/intset.h"
#include "redis/object.h"
#include "redis/redis_aux.h"
#include "redis/zmalloc.h"
}
namespace dfly {
XXH64_hash_t kSeed = 24061983;
using namespace std;
void PrintTo(const CompactObj& cobj, std::ostream* os) {
@ -75,8 +82,8 @@ TEST_F(CompactObjectTest, Basic) {
TEST_F(CompactObjectTest, NonInline) {
string s(22, 'a');
CompactObj obj{s};
XXH64_hash_t seed = 24061983;
uint64_t expected_val = XXH3_64bits_withSeed(s.data(), s.size(), seed);
uint64_t expected_val = XXH3_64bits_withSeed(s.data(), s.size(), kSeed);
EXPECT_EQ(18261733907982517826UL, expected_val);
EXPECT_EQ(expected_val, obj.HashCode());
EXPECT_EQ(s, obj);
@ -86,6 +93,14 @@ TEST_F(CompactObjectTest, NonInline) {
EXPECT_EQ(s, obj);
}
TEST_F(CompactObjectTest, InlineAsciiEncoded) {
string s = "key:0000000000000";
uint64_t expected_val = XXH3_64bits_withSeed(s.data(), s.size(), kSeed);
CompactObj obj{s};
EXPECT_EQ(expected_val, obj.HashCode());
}
TEST_F(CompactObjectTest, Int) {
cobj_.SetString("0");
EXPECT_EQ(0, cobj_.TryGetInt());
@ -121,19 +136,18 @@ TEST_F(CompactObjectTest, IntSet) {
robj* src = createIntsetObject();
cobj_.ImportRObj(src);
EXPECT_EQ(OBJ_SET, cobj_.ObjType());
EXPECT_EQ(OBJ_ENCODING_INTSET, cobj_.Encoding());
EXPECT_EQ(kEncodingIntSet, cobj_.Encoding());
robj* os = cobj_.AsRObj();
EXPECT_EQ(0, setTypeSize(os));
sds val1 = sdsnew("10");
sds val2 = sdsdup(val1);
EXPECT_EQ(0, cobj_.Size());
intset* is = (intset*)cobj_.RObjPtr();
uint8_t success = 0;
is = intsetAdd(is, 10, &success);
EXPECT_EQ(1, success);
is = intsetAdd(is, 10, &success);
EXPECT_EQ(0, success);
cobj_.SetRObjPtr(is);
EXPECT_EQ(1, setTypeAdd(os, val1));
EXPECT_EQ(0, setTypeAdd(os, val2));
EXPECT_EQ(OBJ_ENCODING_INTSET, os->encoding);
sdsfree(val1);
sdsfree(val2);
cobj_.SyncRObj();
EXPECT_GT(cobj_.MallocUsed(), 0);
}
@ -157,7 +171,9 @@ TEST_F(CompactObjectTest, HSet) {
TEST_F(CompactObjectTest, ZSet) {
// unrelated, checking that sds static encoding works.
// it is used in zset special strings.
char kMinStrData[] = "\110" "minstring";
char kMinStrData[] =
"\110"
"minstring";
EXPECT_EQ(9, sdslen(kMinStrData + 1));
robj* src = createZsetListpackObject();
@ -167,4 +183,39 @@ TEST_F(CompactObjectTest, ZSet) {
EXPECT_EQ(OBJ_ENCODING_LISTPACK, cobj_.Encoding());
}
TEST_F(CompactObjectTest, FlatSet) {
size_t allocated1, resident1, active1;
size_t allocated2, resident2, active2;
zmalloc_get_allocator_info(&allocated1, &active1, &resident1);
dict *d = dictCreate(&setDictType);
constexpr size_t kTestSize = 2000;
for (size_t i = 0; i < kTestSize; ++i) {
sds key = sdsnew("key:000000000000");
key = sdscatfmt(key, "%U", i);
dictEntry *de = dictAddRaw(d, key,NULL);
de->v.val = NULL;
}
zmalloc_get_allocator_info(&allocated2, &active2, &resident2);
size_t dict_used = allocated2 - allocated1;
dictRelease(d);
zmalloc_get_allocator_info(&allocated2, &active2, &resident2);
EXPECT_EQ(allocated2, allocated1);
MiMemoryResource mr(mi_heap_get_backing());
FlatSet fs(&mr);
for (size_t i = 0; i < kTestSize; ++i) {
string s = absl::StrCat("key:000000000000", i);
fs.Add(s);
}
zmalloc_get_allocator_info(&allocated2, &active2, &resident2);
size_t fs_used = allocated2 - allocated1;
LOG(INFO) << "dict used: " << dict_used << " fs used: " << fs_used;
EXPECT_LT(fs_used + 8 * kTestSize, dict_used);
}
} // namespace dfly

82
src/core/flat_set.h Normal file
View File

@ -0,0 +1,82 @@
// Copyright 2022, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once
#include <absl/container/flat_hash_set.h>
#include <memory_resource>
#include "core/compact_object.h"
namespace dfly {
class FlatSet {
public:
FlatSet(std::pmr::memory_resource* mr) : set_(mr) {
}
void Reserve(size_t sz) {
set_.reserve(sz);
}
bool Add(std::string_view str) {
return set_.emplace(str).second;
}
bool Remove(std::string_view str) {
size_t res = set_.erase(str);
return res > 0;
}
size_t Size() const {
return set_.size();
}
bool Empty() const {
return set_.empty();
}
bool Contains(std::string_view val) const {
return set_.contains(val);
}
auto begin() const {
return set_.begin();
}
auto end() const {
return set_.end();
}
private:
struct Hasher {
using is_transparent = void; // to allow heteregenous lookups.
size_t operator()(const CompactObj& o) const {
return o.HashCode();
}
size_t operator()(std::string_view s) const {
return CompactObj::HashCode(s);
}
};
struct Eq {
using is_transparent = void; // to allow heteregenous lookups.
bool operator()(const CompactObj& left, const CompactObj& right) const {
return left == right;
}
bool operator()(const CompactObj& left, std::string_view right) const {
return left == right;
}
};
using FlatSetType =
absl::flat_hash_set<CompactObj, Hasher, Eq, std::pmr::polymorphic_allocator<CompactObj>>;
FlatSetType set_;
};
} // namespace dfly

View File

@ -10,7 +10,7 @@ endif()
add_library(redis_lib crc64.c crcspeed.c debug.c dict.c endianconv.c intset.c
listpack.c mt19937-64.c object.c lzf_c.c lzf_d.c sds.c sha256.c
quicklist.c redis_aux.c siphash.c t_hash.c t_set.c t_zset.c util.c ${ZMALLOC_SRC})
quicklist.c redis_aux.c siphash.c t_hash.c t_zset.c util.c ${ZMALLOC_SRC})
cxx_link(redis_lib ${ZMALLOC_DEPS})

View File

@ -9,13 +9,13 @@
#include "atomicvar.h"
#include "zmalloc.h"
__thread ssize_t used_memory_tl = 0;
// __thread ssize_t used_memory_tl = 0;
__thread mi_heap_t* zmalloc_heap = NULL;
/* Allocate memory or panic */
void* zmalloc(size_t size) {
size_t usable;
return zmalloc_usable(size, &usable);
assert(zmalloc_heap);
return mi_heap_malloc(zmalloc_heap, size);
}
void* ztrymalloc_usable(size_t size, size_t* usable) {
@ -27,10 +27,10 @@ size_t zmalloc_usable_size(const void* p) {
}
void zfree(void* ptr) {
size_t usable = mi_usable_size(ptr);
used_memory_tl -= usable;
return mi_free_size(ptr, usable);
// size_t usable = mi_usable_size(ptr);
// used_memory_tl -= usable;
mi_free(ptr);
// return mi_free_size(ptr, usable);
}
void* zrealloc(void* ptr, size_t size) {
@ -39,28 +39,28 @@ void* zrealloc(void* ptr, size_t size) {
}
void* zcalloc(size_t size) {
size_t usable = mi_good_size(size);
// size_t usable = mi_good_size(size);
used_memory_tl += usable;
// used_memory_tl += usable;
return mi_heap_calloc(zmalloc_heap, 1, usable);
return mi_heap_calloc(zmalloc_heap, 1, size);
}
void* zmalloc_usable(size_t size, size_t* usable) {
size_t g = mi_good_size(size);
*usable = g;
used_memory_tl += g;
// used_memory_tl += g;
assert(zmalloc_heap);
return mi_heap_malloc(zmalloc_heap, g);
}
void* zrealloc_usable(void* ptr, size_t size, size_t* usable) {
size_t g = mi_good_size(size);
size_t prev = mi_usable_size(ptr);
// size_t prev = mi_usable_size(ptr);
*usable = g;
used_memory_tl += (g - prev);
// used_memory_tl += (g - prev);
return mi_heap_realloc(zmalloc_heap, ptr, g);
}
@ -78,11 +78,40 @@ void* ztrymalloc(size_t size) {
}
void* ztrycalloc(size_t size) {
size_t g = mi_good_size(size);
used_memory_tl += g;
// size_t g = mi_good_size(size);
// used_memory_tl += g;
return mi_heap_calloc(zmalloc_heap, 1, size);
}
typedef struct Sum_s {
size_t allocated;
size_t comitted;
} Sum_t;
bool heap_visit_cb(const mi_heap_t* heap, const mi_heap_area_t* area, void* block,
size_t block_size, void* arg) {
assert(area->used < (1u << 31));
Sum_t* sum = (Sum_t*)arg;
// mimalloc mistakenly exports used in blocks instead of bytes.
sum->allocated += block_size * area->used;
sum->comitted += area->committed;
return true; // continue iteration
};
int zmalloc_get_allocator_info(size_t* allocated, size_t* active, size_t* resident) {
Sum_t sum = {0};
mi_heap_visit_blocks(zmalloc_heap, false /* visit all blocks*/, heap_visit_cb, &sum);
*allocated = sum.allocated;
*resident = sum.comitted;
*active = 0;
return 1;
}
void init_zmalloc_threadlocal() {
if (zmalloc_heap)
return;

View File

@ -7,10 +7,12 @@
extern "C" {
#include "redis/intset.h"
#include "redis/object.h"
#include "redis/redis_aux.h"
#include "redis/util.h"
}
#include "base/logging.h"
#include "core/flat_set.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
#include "server/engine_shard_set.h"
@ -27,20 +29,100 @@ using SvArray = vector<std::string_view>;
namespace {
void FillSet(robj* src, absl::flat_hash_set<string>* dest) {
auto* si = setTypeInitIterator(src);
sds ele;
int64_t llele;
int encoding;
while ((encoding = setTypeNext(si, &ele, &llele)) != -1) {
if (encoding == OBJ_ENCODING_HT) {
std::string_view sv{ele, sdslen(ele)};
dest->emplace(sv);
} else {
dest->emplace(absl::StrCat(llele));
FlatSet* CreateFlatSet(pmr::memory_resource* mr) {
pmr::polymorphic_allocator<FlatSet> pa(mr);
FlatSet* fs = pa.allocate(1);
pa.construct(fs, mr);
return fs;
}
void ConvertTo(intset* src, FlatSet* dest) {
int64_t intele;
char buf[32];
/* To add the elements we extract integers and create redis objects */
int ii = 0;
while (intsetGet(src, ii++, &intele)) {
char* next = absl::numbers_internal::FastIntToBuffer(intele, buf);
dest->Add(string_view{buf, size_t(next - buf)});
}
}
intset* IntsetAddSafe(string_view val, intset* is, bool* success, bool* added) {
long long llval;
*added = false;
if (!string2ll(val.data(), val.size(), &llval)) {
*success = false;
return is;
}
uint8_t inserted = 0;
is = intsetAdd(is, llval, &inserted);
if (inserted) {
*added = true;
size_t max_entries = server.set_max_intset_entries;
/* limit to 1G entries due to intset internals. */
if (max_entries >= 1 << 16)
max_entries = 1 << 16;
*success = intsetLen(is) <= max_entries;
} else {
*added = false;
*success = true;
}
return is;
}
// returns (removed, isempty)
pair<unsigned, bool> RemoveSet(ArgSlice vals, CompactObj* set) {
bool isempty = false;
unsigned removed = 0;
if (set->Encoding() == kEncodingIntSet) {
intset* is = (intset*)set->RObjPtr();
long long llval;
for (auto val : vals) {
if (!string2ll(val.data(), val.size(), &llval)) {
continue;
}
int is_removed = 0;
is = intsetRemove(is, llval, &is_removed);
removed += is_removed;
}
isempty = (intsetLen(is) == 0);
set->SetRObjPtr(is);
} else {
FlatSet* fs = (FlatSet*)set->RObjPtr();
for (auto val : vals) {
removed += fs->Remove(val);
}
isempty = fs->Empty();
set->SetRObjPtr(fs);
}
return make_pair(removed, isempty);
}
template <typename F> void FillSet(const CompactObj& set, F&& f) {
if (set.Encoding() == kEncodingIntSet) {
intset* is = (intset*)set.RObjPtr();
int64_t ival;
int ii = 0;
char buf[32];
while (intsetGet(is, ii++, &ival)) {
char* next = absl::numbers_internal::FastIntToBuffer(ival, buf);
f(string{buf, size_t(next - buf)});
}
} else {
FlatSet* fs = (FlatSet*)set.RObjPtr();
string str;
for (const auto& member : *fs) {
member.GetString(&str);
f(move(str));
}
}
setTypeReleaseIterator(si);
}
vector<string> ToVec(absl::flat_hash_set<string>&& set) {
@ -57,21 +139,6 @@ vector<string> ToVec(absl::flat_hash_set<string>&& set) {
return result;
}
bool IsInSet(const robj* s, int64_t val) {
if (s->encoding == OBJ_ENCODING_INTSET)
return intsetFind((intset*)s->ptr, val);
/* in order to compare an integer with an object we
* have to use the generic function, creating an object
* for this */
DCHECK_EQ(s->encoding, OBJ_ENCODING_HT);
sds elesds = sdsfromlonglong(val);
bool res = setTypeIsMember(s, elesds);
sdsfree(elesds);
return res;
}
ResultSetView UnionResultVec(const ResultStringVec& result_vec) {
absl::flat_hash_set<std::string_view> uniques;
@ -181,13 +248,12 @@ OpResult<uint32_t> OpAdd(const OpArgs& op_args, std::string_view key, const ArgS
}
const auto [it, inserted] = db_slice.AddOrFind(op_args.db_ind, key);
robj* o = nullptr;
if (!inserted) {
db_slice.PreUpdate(op_args.db_ind, it);
}
CompactObj& co = it->second;
if (inserted || overwrite) {
bool int_set = true;
long long intv;
@ -200,27 +266,51 @@ OpResult<uint32_t> OpAdd(const OpArgs& op_args, std::string_view key, const ArgS
}
if (int_set) {
o = createIntsetObject();
intset* is = intsetNew();
co.InitRobj(OBJ_SET, kEncodingIntSet, is);
} else {
o = createSetObject();
FlatSet* fs = CreateFlatSet(op_args.shard->memory_resource());
co.InitRobj(OBJ_SET, kEncodingStrMap, fs);
}
it->second.ImportRObj(o);
} else {
// We delibirately check only now because with othewrite=true
// we may write into object of a different type via ImportRObj above.
if (it->second.ObjType() != OBJ_SET)
if (co.ObjType() != OBJ_SET)
return OpStatus::WRONG_TYPE;
}
o = it->second.AsRObj();
void* inner_obj = co.RObjPtr();
uint32_t res = 0;
for (auto val : vals) {
es->tmp_str1 = sdscpylen(es->tmp_str1, val.data(), val.size());
res += setTypeAdd(o, es->tmp_str1);
if (co.Encoding() == kEncodingIntSet) {
intset* is = (intset*)inner_obj;
bool success = true;
for (auto val : vals) {
bool added = false;
is = IntsetAddSafe(val, is, &success, &added);
res += added;
if (!success) {
FlatSet* fs = CreateFlatSet(op_args.shard->memory_resource());
ConvertTo(is, fs);
co.SetRObjPtr(is);
co.InitRobj(OBJ_SET, kEncodingStrMap, fs);
inner_obj = fs;
break;
}
}
if (success)
co.SetRObjPtr(is);
}
if (co.Encoding() == kEncodingStrMap) {
FlatSet* fs = (FlatSet*)inner_obj;
for (auto val : vals) {
res += fs->Add(val);
}
}
it->second.SyncRObj();
db_slice.PostUpdate(op_args.db_ind, it);
@ -236,23 +326,21 @@ OpResult<uint32_t> OpRem(const OpArgs& op_args, std::string_view key, const ArgS
}
db_slice.PreUpdate(op_args.db_ind, *find_res);
uint32_t res = 0;
robj* o = find_res.value()->second.AsRObj();
CompactObj& co = find_res.value()->second;
auto [removed, isempty] = RemoveSet(vals, &co);
for (auto val : vals) {
es->tmp_str1 = sdscpylen(es->tmp_str1, val.data(), val.size());
res += setTypeRemove(o, es->tmp_str1);
}
if (res && setTypeSize(o) == 0) {
if (isempty) {
CHECK(db_slice.Del(op_args.db_ind, find_res.value()));
} else {
db_slice.PostUpdate(op_args.db_ind, *find_res);
}
return res;
return removed;
}
// For SMOVE. Comprised of 2 transactional steps: Find and Commit.
// After Find Mover decides on the outcome of the operation, applies it in commit
// and reports the result.
class Mover {
public:
Mover(std::string_view src, std::string_view dest, std::string_view member)
@ -272,14 +360,16 @@ class Mover {
OpStatus Mover::OpFind(Transaction* t, EngineShard* es) {
ArgSlice largs = t->ShardArgsInShard(es->shard_id());
// In case both src and dest are in the same shard, largs size will be 2.
DCHECK_LT(largs.size(), 2u);
for (auto k : largs) {
bool index = (k == src_) ? 0 : 1;
unsigned index = (k == src_) ? 0 : 1;
OpResult<MainIterator> res = es->db_slice().Find(t->db_index(), k, OBJ_SET);
if (res && index == 0) {
CHECK(!res->is_done());
es->tmp_str1 = sdscpylen(es->tmp_str1, member_.data(), member_.size());
int found_memb = setTypeIsMember(res.value()->second.AsRObj(), es->tmp_str1);
found_[0] = (found_memb == 1);
if (res && index == 0) { // succesful src find.
DCHECK(!res->is_done());
found_[0] = res.value()->second.IsMember(member_);
} else {
found_[index] = res.status();
}
@ -290,12 +380,14 @@ OpStatus Mover::OpFind(Transaction* t, EngineShard* es) {
OpStatus Mover::OpMutate(Transaction* t, EngineShard* es) {
ArgSlice largs = t->ShardArgsInShard(es->shard_id());
DCHECK_LT(largs.size(), 2u);
OpArgs op_args{es, t->db_index()};
for (auto k : largs) {
if (k == src_) {
CHECK_EQ(1u, OpRem(op_args, k, {member_}).value());
CHECK_EQ(1u, OpRem(op_args, k, {member_}).value()); // must succeed.
} else {
CHECK_EQ(k, dest_);
DCHECK_EQ(k, dest_);
OpAdd(op_args, k, {member_}, false);
}
}
@ -304,33 +396,51 @@ OpStatus Mover::OpMutate(Transaction* t, EngineShard* es) {
}
void Mover::Find(Transaction* t) {
// non-concluding step.
t->Execute([this](Transaction* t, EngineShard* es) { return this->OpFind(t, es); }, false);
}
OpResult<unsigned> Mover::Commit(Transaction* t) {
OpResult<unsigned> res;
bool return_early = false;
bool noop = false;
if (found_[0].status() == OpStatus::WRONG_TYPE || found_[1].status() == OpStatus::WRONG_TYPE) {
res = OpStatus::WRONG_TYPE;
return_early = true;
noop = true;
} else if (!found_[0].value_or(false)) {
res = 0;
return_early = true;
noop = true;
} else {
res = 1;
return_early = (src_ == dest_);
noop = (src_ == dest_);
}
if (return_early) {
if (noop) {
t->Execute(&NoOpCb, true);
return res;
} else {
t->Execute([this](Transaction* t, EngineShard* es) { return this->OpMutate(t, es); }, true);
}
t->Execute([this](Transaction* t, EngineShard* es) { return this->OpMutate(t, es); }, true);
return res;
}
#if 0
bool IsInSet(const robj* s, int64_t val) {
if (s->encoding == OBJ_ENCODING_INTSET)
return intsetFind((intset*)s->ptr, val);
/* in order to compare an integer with an object we
* have to use the generic function, creating an object
* for this */
DCHECK_EQ(s->encoding, OBJ_ENCODING_HT);
sds elesds = sdsfromlonglong(val);
bool res = setTypeIsMember(s, elesds);
sdsfree(elesds);
return res;
}
#endif
} // namespace
void SetFamily::SAdd(CmdArgList args, ConnectionContext* cntx) {
@ -369,11 +479,11 @@ void SetFamily::SIsMember(CmdArgList args, ConnectionContext* cntx) {
auto cb = [&](Transaction* t, EngineShard* shard) {
OpResult<MainIterator> find_res = shard->db_slice().Find(t->db_index(), key, OBJ_SET);
shard->tmp_str1 = sdscpylen(shard->tmp_str1, val.data(), val.size());
if (find_res) {
return find_res.value()->second.IsMember(val) ? OpStatus::OK : OpStatus::KEY_NOTFOUND;
}
int res = setTypeIsMember(find_res.value()->second.AsRObj(), shard->tmp_str1);
return res == 1 ? OpStatus::OK : OpStatus::INVALID_VALUE;
return find_res.status();
};
OpResult<void> result = cntx->transaction->ScheduleSingleHop(std::move(cb));
@ -436,7 +546,7 @@ void SetFamily::SCard(CmdArgList args, ConnectionContext* cntx) {
return find_res.status();
}
return setTypeSize(find_res.value()->second.AsRObj());
return find_res.value()->second.Size();
};
OpResult<uint32_t> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
@ -567,7 +677,9 @@ void SetFamily::SDiffStore(CmdArgList args, ConnectionContext* cntx) {
void SetFamily::SMembers(CmdArgList args, ConnectionContext* cntx) {
auto cb = [](Transaction* t, EngineShard* shard) { return OpInter(t, shard, false); };
OpResult<StringSet> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result || result.status() == OpStatus::KEY_NOTFOUND) {
SvArray arr{result->begin(), result->end()};
if (cntx->conn_state.script_info) { // sort under script
@ -712,8 +824,7 @@ auto SetFamily::OpUnion(const OpArgs& op_args, const ArgSlice& keys) -> OpResult
for (std::string_view key : keys) {
OpResult<MainIterator> find_res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_SET);
if (find_res) {
robj* sobj = find_res.value()->second.AsRObj();
FillSet(sobj, &uniques);
FillSet(find_res.value()->second, [&uniques](string s) { uniques.emplace(move(s)); });
continue;
}
@ -737,6 +848,7 @@ auto SetFamily::OpDiff(const Transaction* t, EngineShard* es) -> OpResult<String
absl::flat_hash_set<string> uniques;
#if 0
robj* sobj = find_res.value()->second.AsRObj();
FillSet(sobj, &uniques);
DCHECK(!uniques.empty()); // otherwise the key would not exist.
@ -766,7 +878,7 @@ auto SetFamily::OpDiff(const Transaction* t, EngineShard* es) -> OpResult<String
}
setTypeReleaseIterator(si);
}
#endif
return ToVec(std::move(uniques));
}
@ -777,13 +889,15 @@ auto SetFamily::OpPop(const OpArgs& op_args, std::string_view key, unsigned coun
if (!find_res)
return find_res.status();
StringSet result;
if (count == 0)
return StringSet{};
return result;
#if 0
MainIterator it = find_res.value();
robj* sobj = it->second.AsRObj();
auto slen = setTypeSize(sobj);
StringSet result;
/* CASE 1:
* The number of requested elements is greater than or equal to
@ -819,6 +933,7 @@ auto SetFamily::OpPop(const OpArgs& op_args, std::string_view key, unsigned coun
}
it->second.SyncRObj();
}
#endif
return result;
}
@ -829,8 +944,22 @@ auto SetFamily::OpInter(const Transaction* t, EngineShard* es, bool remove_first
keys.remove_prefix(1);
}
DCHECK(!keys.empty());
vector<robj> sets(keys.size()); // we must copy by value because AsRObj is temporary.
StringSet result;
if (keys.size() == 1) {
OpResult<MainIterator> find_res = es->db_slice().Find(t->db_index(), keys.front(), OBJ_SET);
if (!find_res)
return find_res.status();
FillSet(find_res.value()->second, [&result](string s) {
result.push_back(move(s));
});
return result;
}
LOG(DFATAL) << "TBD";
#if 0
vector<CompactObj> sets(keys.size()); // we must copy by value because AsRObj is temporary.
for (size_t i = 0; i < keys.size(); ++i) {
OpResult<MainIterator> find_res = es->db_slice().Find(t->db_index(), keys[i], OBJ_SET);
if (!find_res)
@ -848,7 +977,6 @@ auto SetFamily::OpInter(const Transaction* t, EngineShard* es, bool remove_first
sds elesds;
int64_t intobj;
StringSet result;
// TODO: the whole code is awful. imho, the encoding is the same for the same object.
/* Iterate all the elements of the first (smallest) set, and test
* the element against all the other sets, if at least one set does
@ -881,7 +1009,7 @@ auto SetFamily::OpInter(const Transaction* t, EngineShard* es, bool remove_first
}
}
setTypeReleaseIterator(si);
#endif
return result;
}

View File

@ -33,6 +33,15 @@ TEST_F(SetFamilyTest, SAdd) {
EXPECT_THAT(resp, RespEq("set"));
}
TEST_F(SetFamilyTest, IntConv) {
auto resp = Run({"sadd", "x", "134"});
EXPECT_THAT(resp[0], IntArg(1));
resp = Run({"sadd", "x", "abc"});
EXPECT_THAT(resp[0], IntArg(1));
resp = Run({"sadd", "x", "134"});
EXPECT_THAT(resp[0], IntArg(0));
}
TEST_F(SetFamilyTest, SUnionStore) {
auto resp = Run({"sadd", "b", "1", "2", "3"});
Run({"sadd", "c", "10", "11"});
@ -47,6 +56,9 @@ TEST_F(SetFamilyTest, SUnionStore) {
}
TEST_F(SetFamilyTest, SDiff) {
LOG(ERROR) << "TBD";
return;
auto resp = Run({"sadd", "b", "1", "2", "3"});
Run({"sadd", "c", "10", "11"});
Run({"set", "a", "foo"});

View File

@ -11,11 +11,15 @@ import time
def fill_set(args, redis: rclient.Redis):
for j in range(args.num):
token = uuid.uuid1().hex
# print(token)
key = f'USER_OTP:{token}'
otp = ''.join(random.choices(
string.ascii_uppercase + string.digits, k=7))
redis.execute_command('sadd', key, otp)
arr = []
for i in range(30):
otp = ''.join(random.choices(
string.ascii_uppercase + string.digits, k=12))
arr.append(otp)
redis.execute_command('sadd', key, *arr)
def fill_hset(args, redis):
for j in range(args.num):
token = uuid.uuid1().hex