Implement zinterstore

This commit is contained in:
Roman Gershman 2022-05-10 21:48:24 +03:00
parent 53fb11b2fd
commit e9dda3aa64
3 changed files with 222 additions and 52 deletions

View File

@ -206,7 +206,7 @@ API 2.0
- [X] SSCAN
- [X] Sorted Set Family
- [X] ZCOUNT
- [ ] ZINTERSTORE
- [X] ZINTERSTORE
- [X] ZLEXCOUNT
- [X] ZRANGEBYLEX
- [X] ZRANK
@ -214,7 +214,7 @@ API 2.0
- [X] ZREMRANGEBYRANK
- [X] ZREVRANGEBYSCORE
- [X] ZREVRANK
- [ ] ZUNIONSTORE
- [X] ZUNIONSTORE
- [X] ZSCAN
- [ ] HYPERLOGLOG Family
- [ ] PFADD

View File

@ -632,6 +632,29 @@ void UnionScoredMap(ScoredMap* dest, ScoredMap* src, AggType agg_type) {
dest->swap(*src);
}
void InterScoredMap(ScoredMap* dest, ScoredMap* src, AggType agg_type) {
ScoredMap* target = dest;
ScoredMap* iter = src;
if (iter->size() > target->size())
swap(target, iter);
auto it = iter->begin();
while (it != iter->end()) {
auto inter_it = target->find(it->first);
if (inter_it == target->end()) {
auto copy_it = it++;
iter->erase(copy_it);
} else {
it->second = Aggregate(it->second, inter_it->second, agg_type);
++it;
}
}
if (iter != dest)
dest->swap(*src);
}
OpResult<ScoredMap> OpUnion(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type,
const vector<double>& weights, bool store) {
ArgSlice keys = t->ShardArgsInShard(shard->shard_id());
@ -678,6 +701,57 @@ OpResult<ScoredMap> OpUnion(EngineShard* shard, Transaction* t, string_view dest
return result;
}
OpResult<ScoredMap> OpInter(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type,
const vector<double>& weights, bool store) {
ArgSlice keys = t->ShardArgsInShard(shard->shard_id());
DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << vector(keys.begin(), keys.end());
DCHECK(!keys.empty());
unsigned start = 0;
if (keys.front() == dest) {
++start;
}
auto& db_slice = shard->db_slice();
vector<pair<PrimeIterator, double>> it_arr(keys.size() - start);
if (it_arr.empty()) // could be when only the dest key is hosted in this shard
return OpStatus::SKIPPED; // return noop
for (unsigned j = start; j < keys.size(); ++j) {
auto it_res = db_slice.Find(t->db_index(), keys[j], OBJ_ZSET);
if (it_res == OpStatus::WRONG_TYPE) // TODO: support sets with default score 1.
return it_res.status();
if (!it_res)
continue; // we exit in the next loop
// first global index is 2 after {destkey, numkeys}
unsigned src_indx = j - start;
unsigned windex = t->ReverseArgIndex(shard->shard_id(), j) - 2;
DCHECK_LT(windex, weights.size());
it_arr[src_indx] = {*it_res, weights[windex]};
}
ScoredMap result;
for (auto it = it_arr.begin(); it != it_arr.end(); ++it) {
if (it->first.is_done()) {
return ScoredMap{};
}
ScoredMap sm = FromObject(it->first->second, it->second);
if (result.empty())
result.swap(sm);
else
InterScoredMap(&result, &sm, agg_type);
if (result.empty())
return result;
}
return result;
}
using ScoredMemberView = std::pair<double, std::string_view>;
using ScoredMemberSpan = absl::Span<ScoredMemberView>;
@ -760,6 +834,62 @@ OpResult<AddResult> OpAdd(const OpArgs& op_args, const ZParams& zparams, string_
return aresult;
}
struct StoreArgs {
AggType agg_type = AggType::SUM;
unsigned num_keys;
vector<double> weights;
};
OpResult<StoreArgs> ParseStoreArgs(CmdArgList args) {
string_view num_str = ArgS(args, 2);
StoreArgs store_args;
// we parsed the structure before, when transaction has been initialized.
CHECK(absl::SimpleAtoi(num_str, &store_args.num_keys));
DCHECK_GE(args.size(), 3 + store_args.num_keys);
store_args.weights.resize(store_args.num_keys, 1);
for (size_t i = 3 + store_args.num_keys; i < args.size(); ++i) {
ToUpper(&args[i]);
string_view arg = ArgS(args, i);
if (arg == "WEIGHTS") {
if (args.size() <= i + store_args.num_keys) {
return OpStatus::SYNTAX_ERR;
}
for (unsigned j = 0; j < store_args.num_keys; ++j) {
string_view weight = ArgS(args, i + j + 1);
if (!absl::SimpleAtod(weight, &store_args.weights[j])) {
return OpStatus::INVALID_FLOAT;
}
}
i += store_args.num_keys;
} else if (arg == "AGGREGATE") {
if (i + 2 != args.size()) {
return OpStatus::SYNTAX_ERR;
}
ToUpper(&args[i + 1]);
string_view agg = ArgS(args, i + 1);
if (agg == "SUM") {
store_args.agg_type = AggType::SUM;
} else if (agg == "MIN") {
store_args.agg_type = AggType::MIN;
} else if (agg == "MAX") {
store_args.agg_type = AggType::MAX;
} else {
return OpStatus::SYNTAX_ERR;
}
break;
} else {
return OpStatus::SYNTAX_ERR;
}
}
return store_args;
};
} // namespace
void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) {
@ -943,6 +1073,69 @@ void ZSetFamily::ZIncrBy(CmdArgList args, ConnectionContext* cntx) {
}
void ZSetFamily::ZInterStore(CmdArgList args, ConnectionContext* cntx) {
string_view dest_key = ArgS(args, 1);
OpResult<StoreArgs> store_args_res = ParseStoreArgs(args);
if (!store_args_res) {
switch (store_args_res.status()) {
case OpStatus::INVALID_FLOAT:
return (*cntx)->SendError("weight value is not a float", kSyntaxErrType);
default:
return (*cntx)->SendError(store_args_res.status());
}
}
const auto& store_args = *store_args_res;
if (store_args.num_keys == 0) {
return SendAtLeastOneKeyError(cntx);
}
vector<OpResult<ScoredMap>> maps(cntx->shard_set->size(), OpStatus::SKIPPED);
auto cb = [&](Transaction* t, EngineShard* shard) {
maps[shard->shard_id()] =
OpInter(shard, t, dest_key, store_args.agg_type, store_args.weights, false);
return OpStatus::OK;
};
cntx->transaction->Schedule();
cntx->transaction->Execute(std::move(cb), false);
ScoredMap result;
for (auto& op_res : maps) {
if (op_res.status() == OpStatus::SKIPPED)
continue;
if (!op_res)
return (*cntx)->SendError(op_res.status());
if (result.empty())
result.swap(op_res.value());
else
InterScoredMap(&result, &op_res.value(), store_args.agg_type);
if (result.empty())
break;
}
ShardId dest_shard = Shard(dest_key, maps.size());
AddResult add_result;
vector<ScoredMemberView> smvec;
for (const auto& elem : result) {
smvec.emplace_back(elem.second, elem.first);
}
auto store_cb = [&](Transaction* t, EngineShard* shard) {
if (shard->shard_id() == dest_shard) {
ZParams zparams;
zparams.override = true;
add_result =
OpAdd(OpArgs{shard, t->db_index()}, zparams, dest_key, ScoredMemberSpan{smvec}).value();
}
return OpStatus::OK;
};
cntx->transaction->Execute(std::move(store_cb), true);
(*cntx)->SendLong(smvec.size());
}
void ZSetFamily::ZLexCount(CmdArgList args, ConnectionContext* cntx) {
@ -1178,59 +1371,26 @@ void ZSetFamily::ZScan(CmdArgList args, ConnectionContext* cntx) {
void ZSetFamily::ZUnionStore(CmdArgList args, ConnectionContext* cntx) {
string_view dest_key = ArgS(args, 1);
string_view num_str = ArgS(args, 2);
uint32_t num_keys;
AggType agg_type = AggType::SUM;
OpResult<StoreArgs> store_args_res = ParseStoreArgs(args);
// we parsed the structure before, when transaction has been initialized.
CHECK(absl::SimpleAtoi(num_str, &num_keys));
if (num_keys == 0) {
return SendAtLeastOneKeyError(cntx);
}
DCHECK_GE(args.size(), 3 + num_keys);
vector<double> weights(num_keys, 1);
for (size_t i = 3 + num_keys; i < args.size(); ++i) {
ToUpper(&args[i]);
string_view arg = ArgS(args, i);
if (arg == "WEIGHTS") {
if (args.size() <= i + num_keys) {
return (*cntx)->SendError(kSyntaxErr);
}
for (unsigned j = 0; j < num_keys; ++j) {
string_view weight = ArgS(args, i + j + 1);
if (!absl::SimpleAtod(weight, &weights[j])) {
return (*cntx)->SendError("weight value is not a float", kSyntaxErrType);
}
}
i += num_keys;
} else if (arg == "AGGREGATE") {
if (i + 2 != args.size()) {
return (*cntx)->SendError(kSyntaxErr);
}
ToUpper(&args[i + 1]);
string_view agg = ArgS(args, i + 1);
if (agg == "SUM") {
agg_type = AggType::SUM;
} else if (agg == "MIN") {
agg_type = AggType::MIN;
} else if (agg == "MAX") {
agg_type = AggType::MAX;
} else {
return (*cntx)->SendError(kSyntaxErr);
}
break;
} else {
return (*cntx)->SendError(kSyntaxErr);
if (!store_args_res) {
switch (store_args_res.status()) {
case OpStatus::INVALID_FLOAT:
return (*cntx)->SendError("weight value is not a float", kSyntaxErrType);
default:
return (*cntx)->SendError(store_args_res.status());
}
}
const auto& store_args = *store_args_res;
if (store_args.num_keys == 0) {
return SendAtLeastOneKeyError(cntx);
}
vector<OpResult<ScoredMap>> maps(cntx->shard_set->size());
auto cb = [&](Transaction* t, EngineShard* shard) {
maps[shard->shard_id()] = OpUnion(shard, t, dest_key, agg_type, weights, false);
maps[shard->shard_id()] =
OpUnion(shard, t, dest_key, store_args.agg_type, store_args.weights, false);
return OpStatus::OK;
};
@ -1242,7 +1402,7 @@ void ZSetFamily::ZUnionStore(CmdArgList args, ConnectionContext* cntx) {
for (auto& op_res : maps) {
if (!op_res)
return (*cntx)->SendError(op_res.status());
UnionScoredMap(&result, &op_res.value(), agg_type);
UnionScoredMap(&result, &op_res.value(), store_args.agg_type);
}
ShardId dest_shard = Shard(dest_key, maps.size());
AddResult add_result;

View File

@ -208,14 +208,14 @@ TEST_F(ZSetFamilyTest, ZUnionStoreOpts) {
EXPECT_EQ(2, CheckedInt({"zadd", "z2", "3", "c", "2", "b"}));
RespExpr resp;
EXPECT_EQ(3, CheckedInt({"zunionstore", "a", "2", "z1", "z2", "weights", "1", "3"}));
EXPECT_EQ(3, CheckedInt({"zunionstore", "a", "2", "z1", "z2", "weights", "1", "3"}));
resp = Run({"zrange", "a", "0", "-1", "withscores"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "b", "8", "c", "9"));
resp = Run({"zunionstore", "a", "2", "z1", "z2", "weights", "1"});
resp = Run({"zunionstore", "a", "2", "z1", "z2", "weights", "1"});
EXPECT_THAT(resp, ErrArg("syntax error"));
resp = Run({"zunionstore", "z1", "1", "z1", "weights", "2"});
resp = Run({"zunionstore", "z1", "1", "z1", "weights", "2"});
EXPECT_THAT(resp, IntArg(2));
resp = Run({"zrange", "z1", "0", "-1", "withscores"});
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "2", "b", "4"));
@ -226,4 +226,14 @@ TEST_F(ZSetFamilyTest, ZUnionStoreOpts) {
EXPECT_THAT(resp.GetVec(), ElementsAre("c", "0", "a", "2", "b", "4"));
}
TEST_F(ZSetFamilyTest, ZInterStore) {
EXPECT_EQ(2, CheckedInt({"zadd", "z1", "1", "a", "2", "b"}));
EXPECT_EQ(2, CheckedInt({"zadd", "z2", "3", "c", "2", "b"}));
RespExpr resp;
EXPECT_EQ(1, CheckedInt({"zinterstore", "a", "2", "z1", "z2"}));
resp = Run({"zrange", "a", "0", "-1", "withscores"});
EXPECT_THAT(resp.GetVec(), ElementsAre("b", "4"));
}
} // namespace dfly