Implement zinterstore
This commit is contained in:
parent
53fb11b2fd
commit
e9dda3aa64
|
@ -206,7 +206,7 @@ API 2.0
|
|||
- [X] SSCAN
|
||||
- [X] Sorted Set Family
|
||||
- [X] ZCOUNT
|
||||
- [ ] ZINTERSTORE
|
||||
- [X] ZINTERSTORE
|
||||
- [X] ZLEXCOUNT
|
||||
- [X] ZRANGEBYLEX
|
||||
- [X] ZRANK
|
||||
|
@ -214,7 +214,7 @@ API 2.0
|
|||
- [X] ZREMRANGEBYRANK
|
||||
- [X] ZREVRANGEBYSCORE
|
||||
- [X] ZREVRANK
|
||||
- [ ] ZUNIONSTORE
|
||||
- [X] ZUNIONSTORE
|
||||
- [X] ZSCAN
|
||||
- [ ] HYPERLOGLOG Family
|
||||
- [ ] PFADD
|
||||
|
|
|
@ -632,6 +632,29 @@ void UnionScoredMap(ScoredMap* dest, ScoredMap* src, AggType agg_type) {
|
|||
dest->swap(*src);
|
||||
}
|
||||
|
||||
void InterScoredMap(ScoredMap* dest, ScoredMap* src, AggType agg_type) {
|
||||
ScoredMap* target = dest;
|
||||
ScoredMap* iter = src;
|
||||
|
||||
if (iter->size() > target->size())
|
||||
swap(target, iter);
|
||||
|
||||
auto it = iter->begin();
|
||||
while (it != iter->end()) {
|
||||
auto inter_it = target->find(it->first);
|
||||
if (inter_it == target->end()) {
|
||||
auto copy_it = it++;
|
||||
iter->erase(copy_it);
|
||||
} else {
|
||||
it->second = Aggregate(it->second, inter_it->second, agg_type);
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
if (iter != dest)
|
||||
dest->swap(*src);
|
||||
}
|
||||
|
||||
OpResult<ScoredMap> OpUnion(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type,
|
||||
const vector<double>& weights, bool store) {
|
||||
ArgSlice keys = t->ShardArgsInShard(shard->shard_id());
|
||||
|
@ -678,6 +701,57 @@ OpResult<ScoredMap> OpUnion(EngineShard* shard, Transaction* t, string_view dest
|
|||
return result;
|
||||
}
|
||||
|
||||
OpResult<ScoredMap> OpInter(EngineShard* shard, Transaction* t, string_view dest, AggType agg_type,
|
||||
const vector<double>& weights, bool store) {
|
||||
ArgSlice keys = t->ShardArgsInShard(shard->shard_id());
|
||||
DVLOG(1) << "shard:" << shard->shard_id() << ", keys " << vector(keys.begin(), keys.end());
|
||||
DCHECK(!keys.empty());
|
||||
|
||||
unsigned start = 0;
|
||||
|
||||
if (keys.front() == dest) {
|
||||
++start;
|
||||
}
|
||||
|
||||
auto& db_slice = shard->db_slice();
|
||||
vector<pair<PrimeIterator, double>> it_arr(keys.size() - start);
|
||||
if (it_arr.empty()) // could be when only the dest key is hosted in this shard
|
||||
return OpStatus::SKIPPED; // return noop
|
||||
|
||||
for (unsigned j = start; j < keys.size(); ++j) {
|
||||
auto it_res = db_slice.Find(t->db_index(), keys[j], OBJ_ZSET);
|
||||
if (it_res == OpStatus::WRONG_TYPE) // TODO: support sets with default score 1.
|
||||
return it_res.status();
|
||||
|
||||
if (!it_res)
|
||||
continue; // we exit in the next loop
|
||||
|
||||
// first global index is 2 after {destkey, numkeys}
|
||||
unsigned src_indx = j - start;
|
||||
unsigned windex = t->ReverseArgIndex(shard->shard_id(), j) - 2;
|
||||
DCHECK_LT(windex, weights.size());
|
||||
it_arr[src_indx] = {*it_res, weights[windex]};
|
||||
}
|
||||
|
||||
ScoredMap result;
|
||||
for (auto it = it_arr.begin(); it != it_arr.end(); ++it) {
|
||||
if (it->first.is_done()) {
|
||||
return ScoredMap{};
|
||||
}
|
||||
|
||||
ScoredMap sm = FromObject(it->first->second, it->second);
|
||||
if (result.empty())
|
||||
result.swap(sm);
|
||||
else
|
||||
InterScoredMap(&result, &sm, agg_type);
|
||||
|
||||
if (result.empty())
|
||||
return result;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
using ScoredMemberView = std::pair<double, std::string_view>;
|
||||
using ScoredMemberSpan = absl::Span<ScoredMemberView>;
|
||||
|
||||
|
@ -760,6 +834,62 @@ OpResult<AddResult> OpAdd(const OpArgs& op_args, const ZParams& zparams, string_
|
|||
return aresult;
|
||||
}
|
||||
|
||||
struct StoreArgs {
|
||||
AggType agg_type = AggType::SUM;
|
||||
unsigned num_keys;
|
||||
vector<double> weights;
|
||||
};
|
||||
|
||||
OpResult<StoreArgs> ParseStoreArgs(CmdArgList args) {
|
||||
string_view num_str = ArgS(args, 2);
|
||||
StoreArgs store_args;
|
||||
|
||||
// we parsed the structure before, when transaction has been initialized.
|
||||
CHECK(absl::SimpleAtoi(num_str, &store_args.num_keys));
|
||||
DCHECK_GE(args.size(), 3 + store_args.num_keys);
|
||||
|
||||
store_args.weights.resize(store_args.num_keys, 1);
|
||||
for (size_t i = 3 + store_args.num_keys; i < args.size(); ++i) {
|
||||
ToUpper(&args[i]);
|
||||
string_view arg = ArgS(args, i);
|
||||
if (arg == "WEIGHTS") {
|
||||
if (args.size() <= i + store_args.num_keys) {
|
||||
return OpStatus::SYNTAX_ERR;
|
||||
}
|
||||
|
||||
for (unsigned j = 0; j < store_args.num_keys; ++j) {
|
||||
string_view weight = ArgS(args, i + j + 1);
|
||||
if (!absl::SimpleAtod(weight, &store_args.weights[j])) {
|
||||
return OpStatus::INVALID_FLOAT;
|
||||
}
|
||||
}
|
||||
i += store_args.num_keys;
|
||||
} else if (arg == "AGGREGATE") {
|
||||
if (i + 2 != args.size()) {
|
||||
return OpStatus::SYNTAX_ERR;
|
||||
}
|
||||
|
||||
ToUpper(&args[i + 1]);
|
||||
|
||||
string_view agg = ArgS(args, i + 1);
|
||||
if (agg == "SUM") {
|
||||
store_args.agg_type = AggType::SUM;
|
||||
} else if (agg == "MIN") {
|
||||
store_args.agg_type = AggType::MIN;
|
||||
} else if (agg == "MAX") {
|
||||
store_args.agg_type = AggType::MAX;
|
||||
} else {
|
||||
return OpStatus::SYNTAX_ERR;
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
return OpStatus::SYNTAX_ERR;
|
||||
}
|
||||
}
|
||||
|
||||
return store_args;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) {
|
||||
|
@ -943,6 +1073,69 @@ void ZSetFamily::ZIncrBy(CmdArgList args, ConnectionContext* cntx) {
|
|||
}
|
||||
|
||||
void ZSetFamily::ZInterStore(CmdArgList args, ConnectionContext* cntx) {
|
||||
string_view dest_key = ArgS(args, 1);
|
||||
OpResult<StoreArgs> store_args_res = ParseStoreArgs(args);
|
||||
|
||||
if (!store_args_res) {
|
||||
switch (store_args_res.status()) {
|
||||
case OpStatus::INVALID_FLOAT:
|
||||
return (*cntx)->SendError("weight value is not a float", kSyntaxErrType);
|
||||
default:
|
||||
return (*cntx)->SendError(store_args_res.status());
|
||||
}
|
||||
}
|
||||
const auto& store_args = *store_args_res;
|
||||
if (store_args.num_keys == 0) {
|
||||
return SendAtLeastOneKeyError(cntx);
|
||||
}
|
||||
|
||||
vector<OpResult<ScoredMap>> maps(cntx->shard_set->size(), OpStatus::SKIPPED);
|
||||
|
||||
auto cb = [&](Transaction* t, EngineShard* shard) {
|
||||
maps[shard->shard_id()] =
|
||||
OpInter(shard, t, dest_key, store_args.agg_type, store_args.weights, false);
|
||||
return OpStatus::OK;
|
||||
};
|
||||
|
||||
cntx->transaction->Schedule();
|
||||
cntx->transaction->Execute(std::move(cb), false);
|
||||
|
||||
ScoredMap result;
|
||||
for (auto& op_res : maps) {
|
||||
if (op_res.status() == OpStatus::SKIPPED)
|
||||
continue;
|
||||
|
||||
if (!op_res)
|
||||
return (*cntx)->SendError(op_res.status());
|
||||
|
||||
if (result.empty())
|
||||
result.swap(op_res.value());
|
||||
else
|
||||
InterScoredMap(&result, &op_res.value(), store_args.agg_type);
|
||||
if (result.empty())
|
||||
break;
|
||||
}
|
||||
|
||||
ShardId dest_shard = Shard(dest_key, maps.size());
|
||||
AddResult add_result;
|
||||
vector<ScoredMemberView> smvec;
|
||||
for (const auto& elem : result) {
|
||||
smvec.emplace_back(elem.second, elem.first);
|
||||
}
|
||||
|
||||
auto store_cb = [&](Transaction* t, EngineShard* shard) {
|
||||
if (shard->shard_id() == dest_shard) {
|
||||
ZParams zparams;
|
||||
zparams.override = true;
|
||||
add_result =
|
||||
OpAdd(OpArgs{shard, t->db_index()}, zparams, dest_key, ScoredMemberSpan{smvec}).value();
|
||||
}
|
||||
return OpStatus::OK;
|
||||
};
|
||||
|
||||
cntx->transaction->Execute(std::move(store_cb), true);
|
||||
|
||||
(*cntx)->SendLong(smvec.size());
|
||||
}
|
||||
|
||||
void ZSetFamily::ZLexCount(CmdArgList args, ConnectionContext* cntx) {
|
||||
|
@ -1178,59 +1371,26 @@ void ZSetFamily::ZScan(CmdArgList args, ConnectionContext* cntx) {
|
|||
|
||||
void ZSetFamily::ZUnionStore(CmdArgList args, ConnectionContext* cntx) {
|
||||
string_view dest_key = ArgS(args, 1);
|
||||
string_view num_str = ArgS(args, 2);
|
||||
uint32_t num_keys;
|
||||
AggType agg_type = AggType::SUM;
|
||||
OpResult<StoreArgs> store_args_res = ParseStoreArgs(args);
|
||||
|
||||
// we parsed the structure before, when transaction has been initialized.
|
||||
CHECK(absl::SimpleAtoi(num_str, &num_keys));
|
||||
if (num_keys == 0) {
|
||||
return SendAtLeastOneKeyError(cntx);
|
||||
}
|
||||
|
||||
DCHECK_GE(args.size(), 3 + num_keys);
|
||||
|
||||
vector<double> weights(num_keys, 1);
|
||||
for (size_t i = 3 + num_keys; i < args.size(); ++i) {
|
||||
ToUpper(&args[i]);
|
||||
string_view arg = ArgS(args, i);
|
||||
if (arg == "WEIGHTS") {
|
||||
if (args.size() <= i + num_keys) {
|
||||
return (*cntx)->SendError(kSyntaxErr);
|
||||
}
|
||||
for (unsigned j = 0; j < num_keys; ++j) {
|
||||
string_view weight = ArgS(args, i + j + 1);
|
||||
if (!absl::SimpleAtod(weight, &weights[j])) {
|
||||
return (*cntx)->SendError("weight value is not a float", kSyntaxErrType);
|
||||
}
|
||||
}
|
||||
i += num_keys;
|
||||
} else if (arg == "AGGREGATE") {
|
||||
if (i + 2 != args.size()) {
|
||||
return (*cntx)->SendError(kSyntaxErr);
|
||||
}
|
||||
ToUpper(&args[i + 1]);
|
||||
|
||||
string_view agg = ArgS(args, i + 1);
|
||||
if (agg == "SUM") {
|
||||
agg_type = AggType::SUM;
|
||||
} else if (agg == "MIN") {
|
||||
agg_type = AggType::MIN;
|
||||
} else if (agg == "MAX") {
|
||||
agg_type = AggType::MAX;
|
||||
} else {
|
||||
return (*cntx)->SendError(kSyntaxErr);
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
return (*cntx)->SendError(kSyntaxErr);
|
||||
if (!store_args_res) {
|
||||
switch (store_args_res.status()) {
|
||||
case OpStatus::INVALID_FLOAT:
|
||||
return (*cntx)->SendError("weight value is not a float", kSyntaxErrType);
|
||||
default:
|
||||
return (*cntx)->SendError(store_args_res.status());
|
||||
}
|
||||
}
|
||||
const auto& store_args = *store_args_res;
|
||||
if (store_args.num_keys == 0) {
|
||||
return SendAtLeastOneKeyError(cntx);
|
||||
}
|
||||
|
||||
vector<OpResult<ScoredMap>> maps(cntx->shard_set->size());
|
||||
|
||||
auto cb = [&](Transaction* t, EngineShard* shard) {
|
||||
maps[shard->shard_id()] = OpUnion(shard, t, dest_key, agg_type, weights, false);
|
||||
maps[shard->shard_id()] =
|
||||
OpUnion(shard, t, dest_key, store_args.agg_type, store_args.weights, false);
|
||||
return OpStatus::OK;
|
||||
};
|
||||
|
||||
|
@ -1242,7 +1402,7 @@ void ZSetFamily::ZUnionStore(CmdArgList args, ConnectionContext* cntx) {
|
|||
for (auto& op_res : maps) {
|
||||
if (!op_res)
|
||||
return (*cntx)->SendError(op_res.status());
|
||||
UnionScoredMap(&result, &op_res.value(), agg_type);
|
||||
UnionScoredMap(&result, &op_res.value(), store_args.agg_type);
|
||||
}
|
||||
ShardId dest_shard = Shard(dest_key, maps.size());
|
||||
AddResult add_result;
|
||||
|
|
|
@ -208,14 +208,14 @@ TEST_F(ZSetFamilyTest, ZUnionStoreOpts) {
|
|||
EXPECT_EQ(2, CheckedInt({"zadd", "z2", "3", "c", "2", "b"}));
|
||||
RespExpr resp;
|
||||
|
||||
EXPECT_EQ(3, CheckedInt({"zunionstore", "a", "2", "z1", "z2", "weights", "1", "3"}));
|
||||
EXPECT_EQ(3, CheckedInt({"zunionstore", "a", "2", "z1", "z2", "weights", "1", "3"}));
|
||||
resp = Run({"zrange", "a", "0", "-1", "withscores"});
|
||||
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "b", "8", "c", "9"));
|
||||
|
||||
resp = Run({"zunionstore", "a", "2", "z1", "z2", "weights", "1"});
|
||||
resp = Run({"zunionstore", "a", "2", "z1", "z2", "weights", "1"});
|
||||
EXPECT_THAT(resp, ErrArg("syntax error"));
|
||||
|
||||
resp = Run({"zunionstore", "z1", "1", "z1", "weights", "2"});
|
||||
resp = Run({"zunionstore", "z1", "1", "z1", "weights", "2"});
|
||||
EXPECT_THAT(resp, IntArg(2));
|
||||
resp = Run({"zrange", "z1", "0", "-1", "withscores"});
|
||||
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "2", "b", "4"));
|
||||
|
@ -226,4 +226,14 @@ TEST_F(ZSetFamilyTest, ZUnionStoreOpts) {
|
|||
EXPECT_THAT(resp.GetVec(), ElementsAre("c", "0", "a", "2", "b", "4"));
|
||||
}
|
||||
|
||||
TEST_F(ZSetFamilyTest, ZInterStore) {
|
||||
EXPECT_EQ(2, CheckedInt({"zadd", "z1", "1", "a", "2", "b"}));
|
||||
EXPECT_EQ(2, CheckedInt({"zadd", "z2", "3", "c", "2", "b"}));
|
||||
RespExpr resp;
|
||||
|
||||
EXPECT_EQ(1, CheckedInt({"zinterstore", "a", "2", "z1", "z2"}));
|
||||
resp = Run({"zrange", "a", "0", "-1", "withscores"});
|
||||
EXPECT_THAT(resp.GetVec(), ElementsAre("b", "4"));
|
||||
}
|
||||
|
||||
} // namespace dfly
|
||||
|
|
Loading…
Reference in New Issue