Fix SDIFFSTORE bug

This commit is contained in:
Roman Gershman 2022-03-30 02:04:49 +03:00
parent e29f76ad4d
commit 39ef7bf630
3 changed files with 19 additions and 11 deletions

View File

@ -649,7 +649,7 @@ void SetFamily::SDiff(CmdArgList args, ConnectionContext* cntx) {
ArgSlice largs = t->ShardArgsInShard(shard->shard_id());
if (shard->shard_id() == src_shard) {
CHECK_EQ(src_key, largs.front());
result_set[shard->shard_id()] = OpDiff(t, shard);
result_set[shard->shard_id()] = OpDiff(OpArgs{shard, t->db_index()}, largs);
} else {
result_set[shard->shard_id()] = OpUnion(OpArgs{shard, t->db_index()}, largs);
}
@ -678,6 +678,8 @@ void SetFamily::SDiffStore(CmdArgList args, ConnectionContext* cntx) {
std::string_view src_key = ArgS(args, 2);
ShardId src_shard = Shard(src_key, result_set.size());
VLOG(1) << "SDiffStore " << src_key << " " << src_shard;
auto diff_cb = [&](Transaction* t, EngineShard* shard) {
ArgSlice largs = t->ShardArgsInShard(shard->shard_id());
DCHECK(!largs.empty());
@ -689,11 +691,12 @@ void SetFamily::SDiffStore(CmdArgList args, ConnectionContext* cntx) {
return OpStatus::OK;
}
OpArgs op_args{shard, t->db_index()};
if (shard->shard_id() == src_shard) {
CHECK_EQ(src_key, largs.front());
result_set[shard->shard_id()] = OpDiff(t, shard);
result_set[shard->shard_id()] = OpDiff(op_args, largs);
} else {
result_set[shard->shard_id()] = OpUnion(OpArgs{shard, t->db_index()}, largs);
result_set[shard->shard_id()] = OpUnion(op_args, largs);
}
return OpStatus::OK;
};
@ -894,7 +897,7 @@ void SetFamily::SScan(CmdArgList args, ConnectionContext* cntx) {
}
}
OpResult<StringVec> SetFamily::OpUnion(const OpArgs& op_args, const ArgSlice& keys) {
OpResult<StringVec> SetFamily::OpUnion(const OpArgs& op_args, ArgSlice keys) {
DCHECK(!keys.empty());
absl::flat_hash_set<string> uniques;
@ -914,11 +917,11 @@ OpResult<StringVec> SetFamily::OpUnion(const OpArgs& op_args, const ArgSlice& ke
return ToVec(std::move(uniques));
}
OpResult<StringVec> SetFamily::OpDiff(const Transaction* t, EngineShard* es) {
ArgSlice keys = t->ShardArgsInShard(es->shard_id());
OpResult<StringVec> SetFamily::OpDiff(const OpArgs& op_args, ArgSlice keys) {
DCHECK(!keys.empty());
OpResult<MainIterator> find_res = es->db_slice().Find(t->db_index(), keys.front(), OBJ_SET);
DVLOG(1) << "OpDiff from " << keys.front();
EngineShard* es = op_args.shard;
OpResult<MainIterator> find_res = es->db_slice().Find(op_args.db_ind, keys.front(), OBJ_SET);
if (!find_res) {
return find_res.status();
@ -932,7 +935,7 @@ OpResult<StringVec> SetFamily::OpDiff(const Transaction* t, EngineShard* es) {
DCHECK(!uniques.empty()); // otherwise the key would not exist.
for (size_t i = 1; i < keys.size(); ++i) {
OpResult<MainIterator> diff_res = es->db_slice().Find(t->db_index(), keys[i], OBJ_SET);
OpResult<MainIterator> diff_res = es->db_slice().Find(op_args.db_ind, keys[i], OBJ_SET);
if (!diff_res) {
if (diff_res.status() == OpStatus::WRONG_TYPE) {
return OpStatus::WRONG_TYPE;

View File

@ -35,8 +35,8 @@ class SetFamily {
static void SInterStore(CmdArgList args, ConnectionContext* cntx);
static void SScan(CmdArgList args, ConnectionContext* cntx);
static OpResult<StringVec> OpUnion(const OpArgs& op_args, const ArgSlice& args);
static OpResult<StringVec> OpDiff(const Transaction* t, EngineShard* es);
static OpResult<StringVec> OpUnion(const OpArgs& op_args, ArgSlice args);
static OpResult<StringVec> OpDiff(const OpArgs& op_args, ArgSlice keys);
static OpResult<StringVec> OpInter(const Transaction* t, EngineShard* es, bool remove_first);
// count - how many elements to pop.

View File

@ -63,6 +63,11 @@ TEST_F(SetFamilyTest, SDiff) {
EXPECT_THAT(resp, UnorderedElementsAre("1", "2", "3"));
resp = Run({"sdiffstore", "a", "b", "c"});
EXPECT_THAT(resp[0], IntArg(3));
Run({"sadd", "bar", "x", "a", "b", "c"});
Run({"sadd", "foo", "c"});
Run({"sadd", "car", "a", "d"});
EXPECT_EQ(2, CheckedInt({"SDIFFSTORE", "tar", "bar", "foo", "car"}));
}
TEST_F(SetFamilyTest, SInter) {