diff --git a/src/server/set_family.cc b/src/server/set_family.cc index ca62059..af74f4a 100644 --- a/src/server/set_family.cc +++ b/src/server/set_family.cc @@ -85,6 +85,26 @@ pair RemoveSet(ArgSlice vals, CompactObj* set) { return make_pair(removed, isempty); } +void InitSet(ArgSlice vals, CompactObj* set) { + bool int_set = true; + long long intv; + + for (auto v : vals) { + if (!string2ll(v.data(), v.size(), &intv)) { + int_set = false; + break; + } + } + + if (int_set) { + intset* is = intsetNew(); + set->InitRobj(OBJ_SET, kEncodingIntSet, is); + } else { + dict* ds = dictCreate(&setDictType); + set->InitRobj(OBJ_SET, kEncodingStrMap, ds); + } +} + vector ToVec(absl::flat_hash_set&& set) { vector result(set.size()); size_t i = 0; @@ -159,7 +179,7 @@ OpResult InterResultVec(const ResultStringVec& result_vec, unsigned req bool first = true; for (const auto& res : result_vec) { if (res.status() == OpStatus::SKIPPED) - continue; + continue; DCHECK(res); // we handled it above. @@ -280,10 +300,13 @@ template void FillSet(const SetType& set, F&& f) { } // if overwrite is true then OpAdd writes vals into the key and discards its previous value. -OpResult OpAdd(const OpArgs& op_args, std::string_view key, const ArgSlice& vals, +OpResult OpAdd(const OpArgs& op_args, std::string_view key, ArgSlice vals, bool overwrite) { auto* es = op_args.shard; auto& db_slice = es->db_slice(); + + // overwrite - meaning we run in the context of 2-hop operation and we had already + // ensured that the key exists. 
if (overwrite && vals.empty()) { auto it = db_slice.FindExt(op_args.db_ind, key).first; db_slice.Del(op_args.db_ind, it); @@ -291,35 +314,22 @@ OpResult OpAdd(const OpArgs& op_args, std::string_view key, const ArgS } const auto [it, inserted] = db_slice.AddOrFind(op_args.db_ind, key); - if (!inserted) { - db_slice.PreUpdate(op_args.db_ind, it); - } CompactObj& co = it->second; - if (inserted || overwrite) { - bool int_set = true; - long long intv; - - for (auto v : vals) { - if (!string2ll(v.data(), v.size(), &intv)) { - int_set = false; - break; - } - } - - if (int_set) { - intset* is = intsetNew(); - co.InitRobj(OBJ_SET, kEncodingIntSet, is); - } else { - dict* ds = dictCreate(&setDictType); - co.InitRobj(OBJ_SET, kEncodingStrMap, ds); - } - } else { - // We delibirately check only now because with othewrite=true - // we may write into object of a different type via ImportRObj above. - if (co.ObjType() != OBJ_SET) + if (!inserted) { + // for non-overwrite case it must be set. + if (!overwrite && co.ObjType() != OBJ_SET) return OpStatus::WRONG_TYPE; + + // Update stats and trigger any handling of the old value if needed. + db_slice.PreUpdate(op_args.db_ind, it); + } + + if (inserted || overwrite) { + // does not store the values, merely sets the encoding. + // TODO: why not store the values as well? + InitSet(vals, &co); } void* inner_obj = co.RObjPtr(); @@ -376,13 +386,14 @@ OpResult OpRem(const OpArgs& op_args, std::string_view key, const ArgS } db_slice.PreUpdate(op_args.db_ind, *find_res); + CompactObj& co = find_res.value()->second; auto [removed, isempty] = RemoveSet(vals, &co); + db_slice.PostUpdate(op_args.db_ind, *find_res); + if (isempty) { CHECK(db_slice.Del(op_args.db_ind, find_res.value())); - } else { - db_slice.PostUpdate(op_args.db_ind, *find_res); } return removed; @@ -482,6 +493,172 @@ void ScanCallback(void* privdata, const dictEntry* de) { sv->push_back(string(key, sdslen(key))); } +// Read-only OpUnion op on sets. 
+OpResult OpUnion(const OpArgs& op_args, ArgSlice keys) { + DCHECK(!keys.empty()); + absl::flat_hash_set uniques; + + for (std::string_view key : keys) { + OpResult find_res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_SET); + if (find_res) { + SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()}; + FillSet(st, [&uniques](string s) { uniques.emplace(move(s)); }); + continue; + } + + if (find_res.status() != OpStatus::KEY_NOTFOUND) { + return find_res.status(); + } + } + + return ToVec(std::move(uniques)); +} + +// Read-only OpDiff op on sets. +OpResult OpDiff(const OpArgs& op_args, ArgSlice keys) { + DCHECK(!keys.empty()); + DVLOG(1) << "OpDiff from " << keys.front(); + EngineShard* es = op_args.shard; + OpResult find_res = es->db_slice().Find(op_args.db_ind, keys.front(), OBJ_SET); + + if (!find_res) { + return find_res.status(); + } + + absl::flat_hash_set uniques; + SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()}; + + FillSet(st, [&uniques](string s) { uniques.insert(move(s)); }); + + DCHECK(!uniques.empty()); // otherwise the key would not exist. 
+ + for (size_t i = 1; i < keys.size(); ++i) { + OpResult diff_res = es->db_slice().Find(op_args.db_ind, keys[i], OBJ_SET); + if (!diff_res) { + if (diff_res.status() == OpStatus::WRONG_TYPE) { + return OpStatus::WRONG_TYPE; + } + continue; // KEY_NOTFOUND + } + + SetType st2{diff_res.value()->second.RObjPtr(), diff_res.value()->second.Encoding()}; + if (st2.second == kEncodingIntSet) { + int ii = 0; + intset* is = (intset*)st2.first; + int64_t intele; + char buf[32]; + + while (intsetGet(is, ii++, &intele)) { + char* next = absl::numbers_internal::FastIntToBuffer(intele, buf); + uniques.erase(string_view{buf, size_t(next - buf)}); + } + } else { + DCHECK_EQ(kEncodingStrMap, st2.second); + dict* ds = (dict*)st2.first; + dictIterator* di = dictGetIterator(ds); + dictEntry* de = nullptr; + while ((de = dictNext(di))) { + sds key = (sds)de->key; + uniques.erase(string_view{key, sdslen(key)}); + } + dictReleaseIterator(di); + } + } + + return ToVec(std::move(uniques)); +} + +// Read-only OpInter op on sets. +OpResult OpInter(const Transaction* t, EngineShard* es, bool remove_first) { + ArgSlice keys = t->ShardArgsInShard(es->shard_id()); + if (remove_first) { + keys.remove_prefix(1); + } + DCHECK(!keys.empty()); + + StringVec result; + if (keys.size() == 1) { + OpResult find_res = es->db_slice().Find(t->db_index(), keys.front(), OBJ_SET); + if (!find_res) + return find_res.status(); + + SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()}; + + FillSet(st, [&result](string s) { result.push_back(move(s)); }); + return result; + } + + // we must copy by value because AsRObj is temporary. 
+ vector sets(keys.size()); + + OpStatus status = OpStatus::OK; + + for (size_t i = 0; i < keys.size(); ++i) { + OpResult find_res = es->db_slice().Find(t->db_index(), keys[i], OBJ_SET); + if (!find_res) { + if (status == OpStatus::OK || status == OpStatus::KEY_NOTFOUND || + find_res.status() != OpStatus::KEY_NOTFOUND) { + status = find_res.status(); + } + continue; + } + const PrimeValue& pv = find_res.value()->second; + void* ptr = pv.RObjPtr(); + sets[i] = make_pair(ptr, pv.Encoding()); + } + + if (status != OpStatus::OK) + return status; + + auto comp = [](const SetType& left, const SetType& right) { + return SetTypeLen(left) < SetTypeLen(right); + }; + + std::sort(sets.begin(), sets.end(), comp); + + int encoding = sets.front().second; + if (encoding == kEncodingIntSet) { + int ii = 0; + intset* is = (intset*)sets.front().first; + int64_t intele; + + while (intsetGet(is, ii++, &intele)) { + size_t j = 1; + for (j = 1; j < sets.size(); j++) { + if (sets[j].first != is && !IsInSet(sets[j], intele)) + break; + } + + /* Only take action when all sets contain the member */ + if (j == sets.size()) { + result.push_back(absl::StrCat(intele)); + } + } + } else { + dict* ds = (dict*)sets.front().first; + dictIterator* di = dictGetIterator(ds); + dictEntry* de = nullptr; + while ((de = dictNext(di))) { + size_t j = 1; + sds key = (sds)de->key; + string_view member{key, sdslen(key)}; + + for (j = 1; j < sets.size(); j++) { + if (sets[j].first != ds && !IsInSet(sets[j], member)) + break; + } + + /* Only take action when all sets contain the member */ + if (j == sets.size()) { + result.push_back(string(member)); + } + } + dictReleaseIterator(di); + } + + return result; +} + } // namespace void SetFamily::SAdd(CmdArgList args, ConnectionContext* cntx) { @@ -676,6 +853,7 @@ void SetFamily::SDiffStore(CmdArgList args, ConnectionContext* cntx) { VLOG(1) << "SDiffStore " << src_key << " " << src_shard; + // read-only op auto diff_cb = [&](Transaction* t, EngineShard* shard) { 
ArgSlice largs = t->ShardArgsInShard(shard->shard_id()); DCHECK(!largs.empty()); @@ -690,10 +868,11 @@ void SetFamily::SDiffStore(CmdArgList args, ConnectionContext* cntx) { OpArgs op_args{shard, t->db_index()}; if (shard->shard_id() == src_shard) { CHECK_EQ(src_key, largs.front()); - result_set[shard->shard_id()] = OpDiff(op_args, largs); + result_set[shard->shard_id()] = OpDiff(op_args, largs); // Diff } else { - result_set[shard->shard_id()] = OpUnion(op_args, largs); + result_set[shard->shard_id()] = OpUnion(op_args, largs); // Union } + return OpStatus::OK; }; @@ -893,82 +1072,9 @@ void SetFamily::SScan(CmdArgList args, ConnectionContext* cntx) { } } -OpResult SetFamily::OpUnion(const OpArgs& op_args, ArgSlice keys) { - DCHECK(!keys.empty()); - absl::flat_hash_set uniques; - - for (std::string_view key : keys) { - OpResult find_res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_SET); - if (find_res) { - SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()}; - FillSet(st, [&uniques](string s) { uniques.emplace(move(s)); }); - continue; - } - - if (find_res.status() != OpStatus::KEY_NOTFOUND) { - return find_res.status(); - } - } - - return ToVec(std::move(uniques)); -} - -OpResult SetFamily::OpDiff(const OpArgs& op_args, ArgSlice keys) { - DCHECK(!keys.empty()); - DVLOG(1) << "OpDiff from " << keys.front(); - EngineShard* es = op_args.shard; - OpResult find_res = es->db_slice().Find(op_args.db_ind, keys.front(), OBJ_SET); - - if (!find_res) { - return find_res.status(); - } - - absl::flat_hash_set uniques; - SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()}; - - FillSet(st, [&uniques](string s) { uniques.insert(move(s)); }); - - DCHECK(!uniques.empty()); // otherwise the key would not exist. 
- - for (size_t i = 1; i < keys.size(); ++i) { - OpResult diff_res = es->db_slice().Find(op_args.db_ind, keys[i], OBJ_SET); - if (!diff_res) { - if (diff_res.status() == OpStatus::WRONG_TYPE) { - return OpStatus::WRONG_TYPE; - } - continue; // KEY_NOTFOUND - } - - SetType st2{diff_res.value()->second.RObjPtr(), diff_res.value()->second.Encoding()}; - if (st2.second == kEncodingIntSet) { - int ii = 0; - intset* is = (intset*)st2.first; - int64_t intele; - char buf[32]; - - while (intsetGet(is, ii++, &intele)) { - char* next = absl::numbers_internal::FastIntToBuffer(intele, buf); - uniques.erase(string_view{buf, size_t(next - buf)}); - } - } else { - DCHECK_EQ(kEncodingStrMap, st2.second); - dict* ds = (dict*)st2.first; - dictIterator* di = dictGetIterator(ds); - dictEntry* de = nullptr; - while ((de = dictNext(di))) { - sds key = (sds)de->key; - uniques.erase(string_view{key, sdslen(key)}); - } - dictReleaseIterator(di); - } - } - - return ToVec(std::move(uniques)); -} - OpResult SetFamily::OpPop(const OpArgs& op_args, std::string_view key, unsigned count) { - auto* es = op_args.shard; - OpResult find_res = es->db_slice().Find(op_args.db_ind, key, OBJ_SET); + auto& db_slice = op_args.shard->db_slice(); + OpResult find_res = db_slice.Find(op_args.db_ind, key, OBJ_SET); if (!find_res) return find_res.status(); @@ -986,8 +1092,9 @@ OpResult SetFamily::OpPop(const OpArgs& op_args, std::string_view key if (count >= slen) { FillSet(st, [&result](string s) { result.push_back(move(s)); }); /* Delete the set as it is now empty */ - CHECK(es->db_slice().Del(op_args.db_ind, it)); + CHECK(db_slice.Del(op_args.db_ind, it)); } else { + db_slice.PreUpdate(op_args.db_ind, it); if (st.second == kEncodingIntSet) { intset* is = (intset*)st.first; int64_t val = 0; @@ -1012,100 +1119,11 @@ OpResult SetFamily::OpPop(const OpArgs& op_args, std::string_view key } dictReleaseIterator(di); } + db_slice.PostUpdate(op_args.db_ind, it); } return result; } -OpResult SetFamily::OpInter(const 
Transaction* t, EngineShard* es, bool remove_first) { - ArgSlice keys = t->ShardArgsInShard(es->shard_id()); - if (remove_first) { - keys.remove_prefix(1); - } - DCHECK(!keys.empty()); - - StringVec result; - if (keys.size() == 1) { - OpResult find_res = es->db_slice().Find(t->db_index(), keys.front(), OBJ_SET); - if (!find_res) - return find_res.status(); - - SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()}; - - FillSet(st, [&result](string s) { result.push_back(move(s)); }); - return result; - } - - // we must copy by value because AsRObj is temporary. - vector sets(keys.size()); - - OpStatus status = OpStatus::OK; - - for (size_t i = 0; i < keys.size(); ++i) { - OpResult find_res = es->db_slice().Find(t->db_index(), keys[i], OBJ_SET); - if (!find_res) { - if (status == OpStatus::OK || status == OpStatus::KEY_NOTFOUND || - find_res.status() != OpStatus::KEY_NOTFOUND) { - status = find_res.status(); - } - continue; - } - const PrimeValue& pv = find_res.value()->second; - void* ptr = pv.RObjPtr(); - sets[i] = make_pair(ptr, pv.Encoding()); - } - - if (status != OpStatus::OK) - return status; - - auto comp = [](const SetType& left, const SetType& right) { - return SetTypeLen(left) < SetTypeLen(right); - }; - - std::sort(sets.begin(), sets.end(), comp); - - int encoding = sets.front().second; - if (encoding == kEncodingIntSet) { - int ii = 0; - intset* is = (intset*)sets.front().first; - int64_t intele; - - while (intsetGet(is, ii++, &intele)) { - size_t j = 1; - for (j = 1; j < sets.size(); j++) { - if (sets[j].first != is && !IsInSet(sets[j], intele)) - break; - } - - /* Only take action when all sets contain the member */ - if (j == sets.size()) { - result.push_back(absl::StrCat(intele)); - } - } - } else { - dict* ds = (dict*)sets.front().first; - dictIterator* di = dictGetIterator(ds); - dictEntry* de = nullptr; - while ((de = dictNext(di))) { - size_t j = 1; - sds key = (sds)de->key; - string_view member{key, sdslen(key)}; - - 
for (j = 1; j < sets.size(); j++) { - if (sets[j].first != ds && !IsInSet(sets[j], member)) - break; - } - - /* Only take action when all sets contain the member */ - if (j == sets.size()) { - result.push_back(string(member)); - } - } - dictReleaseIterator(di); - } - - return result; -} - OpResult SetFamily::OpScan(const OpArgs& op_args, std::string_view key, uint64_t* cursor) { OpResult find_res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_SET); diff --git a/src/server/set_family.h b/src/server/set_family.h index ee0fac0..c48a211 100644 --- a/src/server/set_family.h +++ b/src/server/set_family.h @@ -43,10 +43,6 @@ class SetFamily { static void SInterStore(CmdArgList args, ConnectionContext* cntx); static void SScan(CmdArgList args, ConnectionContext* cntx); - static OpResult OpUnion(const OpArgs& op_args, ArgSlice args); - static OpResult OpDiff(const OpArgs& op_args, ArgSlice keys); - static OpResult OpInter(const Transaction* t, EngineShard* es, bool remove_first); - // count - how many elements to pop. static OpResult OpPop(const OpArgs& op_args, std::string_view key, unsigned count); static OpResult OpScan(const OpArgs& op_args, std::string_view key, uint64_t* cursor);