Integrate UpdateHooks into set-family code

This commit is contained in:
Roman Gershman 2022-04-29 19:54:41 +03:00
parent f65d6308c7
commit 362f1f4ec4
2 changed files with 215 additions and 201 deletions

View File

@ -85,6 +85,26 @@ pair<unsigned, bool> RemoveSet(ArgSlice vals, CompactObj* set) {
return make_pair(removed, isempty);
}
void InitSet(ArgSlice vals, CompactObj* set) {
bool int_set = true;
long long intv;
for (auto v : vals) {
if (!string2ll(v.data(), v.size(), &intv)) {
int_set = false;
break;
}
}
if (int_set) {
intset* is = intsetNew();
set->InitRobj(OBJ_SET, kEncodingIntSet, is);
} else {
dict* ds = dictCreate(&setDictType);
set->InitRobj(OBJ_SET, kEncodingStrMap, ds);
}
}
vector<string> ToVec(absl::flat_hash_set<string>&& set) {
vector<string> result(set.size());
size_t i = 0;
@ -159,7 +179,7 @@ OpResult<SvArray> InterResultVec(const ResultStringVec& result_vec, unsigned req
bool first = true;
for (const auto& res : result_vec) {
if (res.status() == OpStatus::SKIPPED)
continue;
continue;
DCHECK(res); // we handled it above.
@ -280,10 +300,13 @@ template <typename F> void FillSet(const SetType& set, F&& f) {
}
// if overwrite is true then OpAdd writes vals into the key and discards its previous value.
OpResult<uint32_t> OpAdd(const OpArgs& op_args, std::string_view key, const ArgSlice& vals,
OpResult<uint32_t> OpAdd(const OpArgs& op_args, std::string_view key, ArgSlice vals,
bool overwrite) {
auto* es = op_args.shard;
auto& db_slice = es->db_slice();
// overwrite - meaning we run in the context of 2-hop operation and we had already
// ensured that the key exists.
if (overwrite && vals.empty()) {
auto it = db_slice.FindExt(op_args.db_ind, key).first;
db_slice.Del(op_args.db_ind, it);
@ -291,35 +314,22 @@ OpResult<uint32_t> OpAdd(const OpArgs& op_args, std::string_view key, const ArgS
}
const auto [it, inserted] = db_slice.AddOrFind(op_args.db_ind, key);
if (!inserted) {
db_slice.PreUpdate(op_args.db_ind, it);
}
CompactObj& co = it->second;
if (inserted || overwrite) {
bool int_set = true;
long long intv;
for (auto v : vals) {
if (!string2ll(v.data(), v.size(), &intv)) {
int_set = false;
break;
}
}
if (int_set) {
intset* is = intsetNew();
co.InitRobj(OBJ_SET, kEncodingIntSet, is);
} else {
dict* ds = dictCreate(&setDictType);
co.InitRobj(OBJ_SET, kEncodingStrMap, ds);
}
} else {
// We delibirately check only now because with othewrite=true
// we may write into object of a different type via ImportRObj above.
if (co.ObjType() != OBJ_SET)
if (!inserted) {
// for non-overwrite case it must be set.
if (!overwrite && co.ObjType() != OBJ_SET)
return OpStatus::WRONG_TYPE;
// Update stats and trigger any handle the old value if needed.
db_slice.PreUpdate(op_args.db_ind, it);
}
if (inserted || overwrite) {
// does not store the values, merely sets the encoding.
// TODO: why not store the values as well?
InitSet(vals, &co);
}
void* inner_obj = co.RObjPtr();
@ -376,13 +386,14 @@ OpResult<uint32_t> OpRem(const OpArgs& op_args, std::string_view key, const ArgS
}
db_slice.PreUpdate(op_args.db_ind, *find_res);
CompactObj& co = find_res.value()->second;
auto [removed, isempty] = RemoveSet(vals, &co);
db_slice.PostUpdate(op_args.db_ind, *find_res);
if (isempty) {
CHECK(db_slice.Del(op_args.db_ind, find_res.value()));
} else {
db_slice.PostUpdate(op_args.db_ind, *find_res);
}
return removed;
@ -482,6 +493,172 @@ void ScanCallback(void* privdata, const dictEntry* de) {
sv->push_back(string(key, sdslen(key)));
}
// Read-only OpUnion op on sets.
OpResult<StringVec> OpUnion(const OpArgs& op_args, ArgSlice keys) {
DCHECK(!keys.empty());
absl::flat_hash_set<string> uniques;
for (std::string_view key : keys) {
OpResult<PrimeIterator> find_res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_SET);
if (find_res) {
SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()};
FillSet(st, [&uniques](string s) { uniques.emplace(move(s)); });
continue;
}
if (find_res.status() != OpStatus::KEY_NOTFOUND) {
return find_res.status();
}
}
return ToVec(std::move(uniques));
}
// Read-only OpDiff op on sets.
OpResult<StringVec> OpDiff(const OpArgs& op_args, ArgSlice keys) {
DCHECK(!keys.empty());
DVLOG(1) << "OpDiff from " << keys.front();
EngineShard* es = op_args.shard;
OpResult<PrimeIterator> find_res = es->db_slice().Find(op_args.db_ind, keys.front(), OBJ_SET);
if (!find_res) {
return find_res.status();
}
absl::flat_hash_set<string> uniques;
SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()};
FillSet(st, [&uniques](string s) { uniques.insert(move(s)); });
DCHECK(!uniques.empty()); // otherwise the key would not exist.
for (size_t i = 1; i < keys.size(); ++i) {
OpResult<PrimeIterator> diff_res = es->db_slice().Find(op_args.db_ind, keys[i], OBJ_SET);
if (!diff_res) {
if (diff_res.status() == OpStatus::WRONG_TYPE) {
return OpStatus::WRONG_TYPE;
}
continue; // KEY_NOTFOUND
}
SetType st2{diff_res.value()->second.RObjPtr(), diff_res.value()->second.Encoding()};
if (st2.second == kEncodingIntSet) {
int ii = 0;
intset* is = (intset*)st2.first;
int64_t intele;
char buf[32];
while (intsetGet(is, ii++, &intele)) {
char* next = absl::numbers_internal::FastIntToBuffer(intele, buf);
uniques.erase(string_view{buf, size_t(next - buf)});
}
} else {
DCHECK_EQ(kEncodingStrMap, st2.second);
dict* ds = (dict*)st2.first;
dictIterator* di = dictGetIterator(ds);
dictEntry* de = nullptr;
while ((de = dictNext(di))) {
sds key = (sds)de->key;
uniques.erase(string_view{key, sdslen(key)});
}
dictReleaseIterator(di);
}
}
return ToVec(std::move(uniques));
}
// Read-only OpInter op on sets.
OpResult<StringVec> OpInter(const Transaction* t, EngineShard* es, bool remove_first) {
ArgSlice keys = t->ShardArgsInShard(es->shard_id());
if (remove_first) {
keys.remove_prefix(1);
}
DCHECK(!keys.empty());
StringVec result;
if (keys.size() == 1) {
OpResult<PrimeIterator> find_res = es->db_slice().Find(t->db_index(), keys.front(), OBJ_SET);
if (!find_res)
return find_res.status();
SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()};
FillSet(st, [&result](string s) { result.push_back(move(s)); });
return result;
}
// we must copy by value because AsRObj is temporary.
vector<SetType> sets(keys.size());
OpStatus status = OpStatus::OK;
for (size_t i = 0; i < keys.size(); ++i) {
OpResult<PrimeIterator> find_res = es->db_slice().Find(t->db_index(), keys[i], OBJ_SET);
if (!find_res) {
if (status == OpStatus::OK || status == OpStatus::KEY_NOTFOUND ||
find_res.status() != OpStatus::KEY_NOTFOUND) {
status = find_res.status();
}
continue;
}
const PrimeValue& pv = find_res.value()->second;
void* ptr = pv.RObjPtr();
sets[i] = make_pair(ptr, pv.Encoding());
}
if (status != OpStatus::OK)
return status;
auto comp = [](const SetType& left, const SetType& right) {
return SetTypeLen(left) < SetTypeLen(right);
};
std::sort(sets.begin(), sets.end(), comp);
int encoding = sets.front().second;
if (encoding == kEncodingIntSet) {
int ii = 0;
intset* is = (intset*)sets.front().first;
int64_t intele;
while (intsetGet(is, ii++, &intele)) {
size_t j = 1;
for (j = 1; j < sets.size(); j++) {
if (sets[j].first != is && !IsInSet(sets[j], intele))
break;
}
/* Only take action when all sets contain the member */
if (j == sets.size()) {
result.push_back(absl::StrCat(intele));
}
}
} else {
dict* ds = (dict*)sets.front().first;
dictIterator* di = dictGetIterator(ds);
dictEntry* de = nullptr;
while ((de = dictNext(di))) {
size_t j = 1;
sds key = (sds)de->key;
string_view member{key, sdslen(key)};
for (j = 1; j < sets.size(); j++) {
if (sets[j].first != ds && !IsInSet(sets[j], member))
break;
}
/* Only take action when all sets contain the member */
if (j == sets.size()) {
result.push_back(string(member));
}
}
dictReleaseIterator(di);
}
return result;
}
} // namespace
void SetFamily::SAdd(CmdArgList args, ConnectionContext* cntx) {
@ -676,6 +853,7 @@ void SetFamily::SDiffStore(CmdArgList args, ConnectionContext* cntx) {
VLOG(1) << "SDiffStore " << src_key << " " << src_shard;
// read-only op
auto diff_cb = [&](Transaction* t, EngineShard* shard) {
ArgSlice largs = t->ShardArgsInShard(shard->shard_id());
DCHECK(!largs.empty());
@ -690,10 +868,11 @@ void SetFamily::SDiffStore(CmdArgList args, ConnectionContext* cntx) {
OpArgs op_args{shard, t->db_index()};
if (shard->shard_id() == src_shard) {
CHECK_EQ(src_key, largs.front());
result_set[shard->shard_id()] = OpDiff(op_args, largs);
result_set[shard->shard_id()] = OpDiff(op_args, largs); // Diff
} else {
result_set[shard->shard_id()] = OpUnion(op_args, largs);
result_set[shard->shard_id()] = OpUnion(op_args, largs); // Union
}
return OpStatus::OK;
};
@ -893,82 +1072,9 @@ void SetFamily::SScan(CmdArgList args, ConnectionContext* cntx) {
}
}
OpResult<StringVec> SetFamily::OpUnion(const OpArgs& op_args, ArgSlice keys) {
DCHECK(!keys.empty());
absl::flat_hash_set<string> uniques;
for (std::string_view key : keys) {
OpResult<PrimeIterator> find_res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_SET);
if (find_res) {
SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()};
FillSet(st, [&uniques](string s) { uniques.emplace(move(s)); });
continue;
}
if (find_res.status() != OpStatus::KEY_NOTFOUND) {
return find_res.status();
}
}
return ToVec(std::move(uniques));
}
OpResult<StringVec> SetFamily::OpDiff(const OpArgs& op_args, ArgSlice keys) {
DCHECK(!keys.empty());
DVLOG(1) << "OpDiff from " << keys.front();
EngineShard* es = op_args.shard;
OpResult<PrimeIterator> find_res = es->db_slice().Find(op_args.db_ind, keys.front(), OBJ_SET);
if (!find_res) {
return find_res.status();
}
absl::flat_hash_set<string> uniques;
SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()};
FillSet(st, [&uniques](string s) { uniques.insert(move(s)); });
DCHECK(!uniques.empty()); // otherwise the key would not exist.
for (size_t i = 1; i < keys.size(); ++i) {
OpResult<PrimeIterator> diff_res = es->db_slice().Find(op_args.db_ind, keys[i], OBJ_SET);
if (!diff_res) {
if (diff_res.status() == OpStatus::WRONG_TYPE) {
return OpStatus::WRONG_TYPE;
}
continue; // KEY_NOTFOUND
}
SetType st2{diff_res.value()->second.RObjPtr(), diff_res.value()->second.Encoding()};
if (st2.second == kEncodingIntSet) {
int ii = 0;
intset* is = (intset*)st2.first;
int64_t intele;
char buf[32];
while (intsetGet(is, ii++, &intele)) {
char* next = absl::numbers_internal::FastIntToBuffer(intele, buf);
uniques.erase(string_view{buf, size_t(next - buf)});
}
} else {
DCHECK_EQ(kEncodingStrMap, st2.second);
dict* ds = (dict*)st2.first;
dictIterator* di = dictGetIterator(ds);
dictEntry* de = nullptr;
while ((de = dictNext(di))) {
sds key = (sds)de->key;
uniques.erase(string_view{key, sdslen(key)});
}
dictReleaseIterator(di);
}
}
return ToVec(std::move(uniques));
}
OpResult<StringVec> SetFamily::OpPop(const OpArgs& op_args, std::string_view key, unsigned count) {
auto* es = op_args.shard;
OpResult<PrimeIterator> find_res = es->db_slice().Find(op_args.db_ind, key, OBJ_SET);
auto& db_slice = op_args.shard->db_slice();
OpResult<PrimeIterator> find_res = db_slice.Find(op_args.db_ind, key, OBJ_SET);
if (!find_res)
return find_res.status();
@ -986,8 +1092,9 @@ OpResult<StringVec> SetFamily::OpPop(const OpArgs& op_args, std::string_view key
if (count >= slen) {
FillSet(st, [&result](string s) { result.push_back(move(s)); });
/* Delete the set as it is now empty */
CHECK(es->db_slice().Del(op_args.db_ind, it));
CHECK(db_slice.Del(op_args.db_ind, it));
} else {
db_slice.PreUpdate(op_args.db_ind, it);
if (st.second == kEncodingIntSet) {
intset* is = (intset*)st.first;
int64_t val = 0;
@ -1012,100 +1119,11 @@ OpResult<StringVec> SetFamily::OpPop(const OpArgs& op_args, std::string_view key
}
dictReleaseIterator(di);
}
db_slice.PostUpdate(op_args.db_ind, it);
}
return result;
}
OpResult<StringVec> SetFamily::OpInter(const Transaction* t, EngineShard* es, bool remove_first) {
ArgSlice keys = t->ShardArgsInShard(es->shard_id());
if (remove_first) {
keys.remove_prefix(1);
}
DCHECK(!keys.empty());
StringVec result;
if (keys.size() == 1) {
OpResult<PrimeIterator> find_res = es->db_slice().Find(t->db_index(), keys.front(), OBJ_SET);
if (!find_res)
return find_res.status();
SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()};
FillSet(st, [&result](string s) { result.push_back(move(s)); });
return result;
}
// we must copy by value because AsRObj is temporary.
vector<SetType> sets(keys.size());
OpStatus status = OpStatus::OK;
for (size_t i = 0; i < keys.size(); ++i) {
OpResult<PrimeIterator> find_res = es->db_slice().Find(t->db_index(), keys[i], OBJ_SET);
if (!find_res) {
if (status == OpStatus::OK || status == OpStatus::KEY_NOTFOUND ||
find_res.status() != OpStatus::KEY_NOTFOUND) {
status = find_res.status();
}
continue;
}
const PrimeValue& pv = find_res.value()->second;
void* ptr = pv.RObjPtr();
sets[i] = make_pair(ptr, pv.Encoding());
}
if (status != OpStatus::OK)
return status;
auto comp = [](const SetType& left, const SetType& right) {
return SetTypeLen(left) < SetTypeLen(right);
};
std::sort(sets.begin(), sets.end(), comp);
int encoding = sets.front().second;
if (encoding == kEncodingIntSet) {
int ii = 0;
intset* is = (intset*)sets.front().first;
int64_t intele;
while (intsetGet(is, ii++, &intele)) {
size_t j = 1;
for (j = 1; j < sets.size(); j++) {
if (sets[j].first != is && !IsInSet(sets[j], intele))
break;
}
/* Only take action when all sets contain the member */
if (j == sets.size()) {
result.push_back(absl::StrCat(intele));
}
}
} else {
dict* ds = (dict*)sets.front().first;
dictIterator* di = dictGetIterator(ds);
dictEntry* de = nullptr;
while ((de = dictNext(di))) {
size_t j = 1;
sds key = (sds)de->key;
string_view member{key, sdslen(key)};
for (j = 1; j < sets.size(); j++) {
if (sets[j].first != ds && !IsInSet(sets[j], member))
break;
}
/* Only take action when all sets contain the member */
if (j == sets.size()) {
result.push_back(string(member));
}
}
dictReleaseIterator(di);
}
return result;
}
OpResult<StringVec> SetFamily::OpScan(const OpArgs& op_args, std::string_view key,
uint64_t* cursor) {
OpResult<PrimeIterator> find_res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_SET);

View File

@ -43,10 +43,6 @@ class SetFamily {
static void SInterStore(CmdArgList args, ConnectionContext* cntx);
static void SScan(CmdArgList args, ConnectionContext* cntx);
static OpResult<StringVec> OpUnion(const OpArgs& op_args, ArgSlice args);
static OpResult<StringVec> OpDiff(const OpArgs& op_args, ArgSlice keys);
static OpResult<StringVec> OpInter(const Transaction* t, EngineShard* es, bool remove_first);
// count - how many elements to pop.
static OpResult<StringVec> OpPop(const OpArgs& op_args, std::string_view key, unsigned count);
static OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t* cursor);