Fix SDIFF/SINTER commands

This commit is contained in:
Roman Gershman 2022-03-22 23:50:47 +02:00
parent c533ffb692
commit f1ea69c0b4
5 changed files with 158 additions and 164 deletions

View File

@ -111,7 +111,7 @@ API 1.0
- [X] ZRANGE
- [X] ZRANGEBYSCORE
- [X] ZREM
- [ ] ZREMRANGEBYSCORE
- [X] ZREMRANGEBYSCORE
- [ ] ZREVRANGE
- [X] ZSCORE
- [ ] Not sure whether these are required for the initial release.

View File

@ -62,44 +62,6 @@ inline void FreeObjSet(unsigned encoding, void* ptr, pmr::memory_resource* mr) {
}
}
bool dictContains(const dict* d, string_view key) {
uint64_t h = dictGenHashFunction(key.data(), key.size());
for (unsigned table = 0; table <= 1; table++) {
uint64_t idx = h & DICTHT_SIZE_MASK(d->ht_size_exp[table]);
dictEntry* he = d->ht_table[table][idx];
while (he) {
sds dkey = (sds)he->key;
if (sdslen(dkey) == key.size() && (key.empty() || memcmp(dkey, key.data(), key.size()) == 0))
return true;
he = he->next;
}
if (!dictIsRehashing(d))
break;
}
return false;
}
bool IsMemberSet(unsigned encoding, std::string_view key, void* set) {
long long llval;
switch (encoding) {
case kEncodingIntSet: {
if (!string2ll(key.data(), key.size(), &llval))
return false;
intset* is = (intset*)set;
return intsetFind(is, llval);
}
case kEncodingStrMap: {
const dict* ds = (dict*)set;
return dictContains(ds, key);
}
default:
LOG(FATAL) << "Unexpected encoding " << encoding;
}
}
size_t MallocUsedSet(unsigned encoding, void* ptr) {
switch (encoding) {
case kEncodingStrMap /*OBJ_ENCODING_HT*/:
@ -380,16 +342,6 @@ void RobjWrapper::SetString(string_view s, pmr::memory_resource* mr) {
}
}
bool RobjWrapper::IsMember(std::string_view key) const {
switch (type_) {
case OBJ_SET:
return IsMemberSet(encoding_, key, inner_obj_);
default:
LOG(FATAL) << "Unsupported type " << type_;
}
return false;
}
void RobjWrapper::Init(unsigned type, unsigned encoding, void* inner) {
type_ = type;
encoding_ = encoding;

View File

@ -46,8 +46,6 @@ class RobjWrapper {
return std::string_view{reinterpret_cast<char*>(inner_obj_), sz_};
}
bool IsMember(std::string_view key) const;
private:
size_t InnerObjMallocUsed() const;
void MakeInnerRoom(size_t current_cap, size_t desired, std::pmr::memory_resource* mr);
@ -135,12 +133,6 @@ class CompactObj {
// For containers - returns number of elements in the container.
size_t Size() const;
// Should be called only for container based objects.
// Returns true if the container contains this key.
bool IsMember(std::string_view key) const {
return u_.r_obj.IsMember(key);
}
// TODO: We don't use c++ constructs (ctor, dtor, =) in objects of U,
// because we use memcpy here.
CompactObj AsRef() const {

View File

@ -99,30 +99,6 @@ pair<unsigned, bool> RemoveSet(ArgSlice vals, CompactObj* set) {
return make_pair(removed, isempty);
}
template <typename F> void FillSet(const CompactObj& set, F&& f) {
if (set.Encoding() == kEncodingIntSet) {
intset* is = (intset*)set.RObjPtr();
int64_t ival;
int ii = 0;
char buf[32];
while (intsetGet(is, ii++, &ival)) {
char* next = absl::numbers_internal::FastIntToBuffer(ival, buf);
f(string{buf, size_t(next - buf)});
}
} else {
dict* ds = (dict*)set.RObjPtr();
string str;
dictIterator* di = dictGetIterator(ds);
dictEntry* de = nullptr;
while ((de = dictNext(di))) {
str.assign((sds)de->key, sdslen((sds)de->key));
f(move(str));
}
dictReleaseIterator(di);
}
}
vector<string> ToVec(absl::flat_hash_set<string>&& set) {
vector<string> result(set.size());
size_t i = 0;
@ -234,6 +210,82 @@ OpStatus NoOpCb(Transaction* t, EngineShard* shard) {
return OpStatus::OK;
};
using SetType = pair<void*, unsigned>;
uint32_t SetTypeLen(const SetType& set) {
if (set.second == kEncodingStrMap) {
return dictSize((const dict*)set.first);
}
DCHECK_EQ(set.second, kEncodingIntSet);
return intsetLen((const intset*)set.first);
};
bool dictContains(const dict* d, string_view key) {
uint64_t h = dictGenHashFunction(key.data(), key.size());
for (unsigned table = 0; table <= 1; table++) {
uint64_t idx = h & DICTHT_SIZE_MASK(d->ht_size_exp[table]);
dictEntry* he = d->ht_table[table][idx];
while (he) {
sds dkey = (sds)he->key;
if (sdslen(dkey) == key.size() && (key.empty() || memcmp(dkey, key.data(), key.size()) == 0))
return true;
he = he->next;
}
if (!dictIsRehashing(d))
break;
}
return false;
}
bool IsInSet(const SetType& st, int64_t val) {
if (st.second == kEncodingIntSet)
return intsetFind((intset*)st.first, val);
DCHECK_EQ(st.second, kEncodingStrMap);
char buf[32];
char* next = absl::numbers_internal::FastIntToBuffer(val, buf);
return dictContains((dict*)st.first, string_view{buf, size_t(next - buf)});
}
bool IsInSet(const SetType& st, string_view member) {
if (st.second == kEncodingIntSet) {
long long llval;
if (!string2ll(member.data(), member.size(), &llval))
return false;
return intsetFind((intset*)st.first, llval);
}
DCHECK_EQ(st.second, kEncodingStrMap);
return dictContains((dict*)st.first, member);
}
template <typename F> void FillSet(const SetType& set, F&& f) {
if (set.second == kEncodingIntSet) {
intset* is = (intset*)set.first;
int64_t ival;
int ii = 0;
char buf[32];
while (intsetGet(is, ii++, &ival)) {
char* next = absl::numbers_internal::FastIntToBuffer(ival, buf);
f(string{buf, size_t(next - buf)});
}
} else {
dict* ds = (dict*)set.first;
string str;
dictIterator* di = dictGetIterator(ds);
dictEntry* de = nullptr;
while ((de = dictNext(di))) {
str.assign((sds)de->key, sdslen((sds)de->key));
f(move(str));
}
dictReleaseIterator(di);
}
}
// if overwrite is true then OpAdd writes vals into the key and discards its previous value.
OpResult<uint32_t> OpAdd(const OpArgs& op_args, std::string_view key, const ArgSlice& vals,
bool overwrite) {
@ -373,7 +425,9 @@ OpStatus Mover::OpFind(Transaction* t, EngineShard* es) {
OpResult<MainIterator> res = es->db_slice().Find(t->db_index(), k, OBJ_SET);
if (res && index == 0) { // succesful src find.
DCHECK(!res->is_done());
found_[0] = res.value()->second.IsMember(member_);
const CompactObj& val = res.value()->second;
SetType st{val.RObjPtr(), val.Encoding()};
found_[0] = IsInSet(st, member_);
} else {
found_[index] = res.status();
}
@ -428,23 +482,6 @@ OpResult<unsigned> Mover::Commit(Transaction* t) {
return res;
}
#if 0
bool IsInSet(const robj* s, int64_t val) {
if (s->encoding == OBJ_ENCODING_INTSET)
return intsetFind((intset*)s->ptr, val);
/* in order to compare an integer with an object we
* have to use the generic function, creating an object
* for this */
DCHECK_EQ(s->encoding, OBJ_ENCODING_HT);
sds elesds = sdsfromlonglong(val);
bool res = setTypeIsMember(s, elesds);
sdsfree(elesds);
return res;
}
#endif
} // namespace
void SetFamily::SAdd(CmdArgList args, ConnectionContext* cntx) {
@ -484,7 +521,8 @@ void SetFamily::SIsMember(CmdArgList args, ConnectionContext* cntx) {
OpResult<MainIterator> find_res = shard->db_slice().Find(t->db_index(), key, OBJ_SET);
if (find_res) {
return find_res.value()->second.IsMember(val) ? OpStatus::OK : OpStatus::KEY_NOTFOUND;
SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()};
return IsInSet(st, val) ? OpStatus::OK : OpStatus::KEY_NOTFOUND;
}
return find_res.status();
@ -828,7 +866,8 @@ OpResult<StringVec> SetFamily::OpUnion(const OpArgs& op_args, const ArgSlice& ke
for (std::string_view key : keys) {
OpResult<MainIterator> find_res = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_SET);
if (find_res) {
FillSet(find_res.value()->second, [&uniques](string s) { uniques.emplace(move(s)); });
SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()};
FillSet(st, [&uniques](string s) { uniques.emplace(move(s)); });
continue;
}
@ -851,38 +890,45 @@ OpResult<StringVec> SetFamily::OpDiff(const Transaction* t, EngineShard* es) {
}
absl::flat_hash_set<string> uniques;
SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()};
FillSet(st, [&uniques](string s) { uniques.insert(move(s)); });
#if 0
robj* sobj = find_res.value()->second.AsRObj();
FillSet(sobj, &uniques);
DCHECK(!uniques.empty()); // otherwise the key would not exist.
for (size_t i = 1; i < keys.size(); ++i) {
OpResult<MainIterator> diff_res = es->db_slice().Find(t->db_index(), keys[i], OBJ_SET);
if (!find_res) {
if (find_res.status() == OpStatus::WRONG_TYPE) {
if (!diff_res) {
if (diff_res.status() == OpStatus::WRONG_TYPE) {
return OpStatus::WRONG_TYPE;
}
continue;
continue; // KEY_NOTFOUND
}
sobj = diff_res.value()->second.AsRObj();
auto* si = setTypeInitIterator(sobj);
sds ele;
int64_t llele;
int encoding;
while ((encoding = setTypeNext(si, &ele, &llele)) != -1) {
if (encoding == OBJ_ENCODING_HT) {
std::string_view sv{ele, sdslen(ele)};
uniques.erase(sv);
} else {
absl::AlphaNum an(llele);
uniques.erase(an.Piece());
SetType st2{diff_res.value()->second.RObjPtr(), diff_res.value()->second.Encoding()};
if (st2.second == kEncodingIntSet) {
int ii = 0;
intset* is = (intset*)st2.first;
int64_t intele;
char buf[32];
while (intsetGet(is, ii++, &intele)) {
char* next = absl::numbers_internal::FastIntToBuffer(intele, buf);
uniques.erase(string_view{buf, size_t(next - buf)});
}
} else {
DCHECK_EQ(kEncodingStrMap, st2.second);
dict* ds = (dict*)st2.first;
dictIterator* di = dictGetIterator(ds);
dictEntry* de = nullptr;
while ((de = dictNext(di))) {
sds key = (sds)de->key;
uniques.erase(string_view{key, sdslen(key)});
}
dictReleaseIterator(di);
}
setTypeReleaseIterator(si);
}
#endif
return ToVec(std::move(uniques));
}
@ -898,17 +944,18 @@ OpResult<StringVec> SetFamily::OpPop(const OpArgs& op_args, std::string_view key
MainIterator it = find_res.value();
size_t slen = it->second.Size();
SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()};
/* CASE 1:
* The number of requested elements is greater than or equal to
* the number of elements inside the set: simply return the whole set. */
if (count >= slen) {
FillSet(it->second, [&result](string s) { result.push_back(move(s)); });
FillSet(st, [&result](string s) { result.push_back(move(s)); });
/* Delete the set as it is now empty */
CHECK(es->db_slice().Del(op_args.db_ind, it));
} else {
if (it->second.Encoding() == kEncodingIntSet) {
intset* is = (intset*)it->second.RObjPtr();
if (st.second == kEncodingIntSet) {
intset* is = (intset*)st.first;
int64_t val = 0;
// copy last count values.
@ -920,7 +967,7 @@ OpResult<StringVec> SetFamily::OpPop(const OpArgs& op_args, std::string_view key
is = intsetTrimTail(is, count); // now remove last count items
it->second.SetRObjPtr(is);
} else {
dict* ds = (dict*)it->second.RObjPtr();
dict* ds = (dict*)st.first;
string str;
dictIterator* di = dictGetSafeIterator(ds);
for (uint32_t i = 0; i < count; ++i) {
@ -948,63 +995,69 @@ OpResult<StringVec> SetFamily::OpInter(const Transaction* t, EngineShard* es, bo
if (!find_res)
return find_res.status();
FillSet(find_res.value()->second, [&result](string s) { result.push_back(move(s)); });
SetType st{find_res.value()->second.RObjPtr(), find_res.value()->second.Encoding()};
FillSet(st, [&result](string s) { result.push_back(move(s)); });
return result;
}
LOG(DFATAL) << "TBD";
#if 0
vector<CompactObj> sets(keys.size()); // we must copy by value because AsRObj is temporary.
// we must copy by value because AsRObj is temporary.
vector<SetType> sets(keys.size());
for (size_t i = 0; i < keys.size(); ++i) {
OpResult<MainIterator> find_res = es->db_slice().Find(t->db_index(), keys[i], OBJ_SET);
if (!find_res)
return find_res.status();
robj* sobj = find_res.value()->second.AsRObj();
sets[i] = *sobj;
sets[i] = make_pair(sobj->ptr, unsigned(sobj->encoding));
}
auto comp = [](const robj& left, const robj& right) {
return setTypeSize(&left) < setTypeSize(&right);
auto comp = [](const SetType& left, const SetType& right) {
return SetTypeLen(left) < SetTypeLen(right);
};
std::sort(sets.begin(), sets.end(), comp);
int encoding;
sds elesds;
int64_t intobj;
// TODO: the whole code is awful. imho, the encoding is the same for the same object.
/* Iterate all the elements of the first (smallest) set, and test
* the element against all the other sets, if at least one set does
* not include the element it is discarded */
auto* si = setTypeInitIterator(&sets[0]);
while ((encoding = setTypeNext(si, &elesds, &intobj)) != -1) {
size_t j = 1;
for (; j < sets.size(); j++) {
if (sets[j].ptr == sets[0].ptr) // when provide the same key several times.
continue;
if (encoding == OBJ_ENCODING_INTSET) {
/* intset with intset is simple... and fast */
if (!IsInSet(&sets[j], intobj))
int encoding = sets.front().second;
if (encoding == kEncodingIntSet) {
int ii = 0;
intset* is = (intset*)sets.front().first;
int64_t intele;
while (intsetGet(is, ii++, &intele)) {
size_t j = 1;
for (j = 1; j < sets.size(); j++) {
if (sets[j].first != is && !IsInSet(sets[j], intele))
break;
} else if (encoding == OBJ_ENCODING_HT) {
if (!setTypeIsMember(&sets[j], elesds)) {
break;
}
}
/* Only take action when all sets contain the member */
if (j == sets.size()) {
result.push_back(absl::StrCat(intele));
}
}
} else {
dict* ds = (dict*)sets.front().first;
dictIterator* di = dictGetIterator(ds);
dictEntry* de = nullptr;
while ((de = dictNext(di))) {
size_t j = 1;
sds key = (sds)de->key;
string_view member{key, sdslen(key)};
/* Only take action when all sets contain the member */
if (j == sets.size()) {
if (encoding == OBJ_ENCODING_HT) {
result.emplace_back(std::string_view{elesds, sdslen(elesds)});
} else {
DCHECK_EQ(unsigned(encoding), OBJ_ENCODING_INTSET);
result.push_back(absl::StrCat(intobj));
for (j = 1; j < sets.size(); j++) {
if (sets[j].first != ds && !IsInSet(sets[j], member))
break;
}
/* Only take action when all sets contain the member */
if (j == sets.size()) {
result.push_back(string(member));
}
}
dictReleaseIterator(di);
}
setTypeReleaseIterator(si);
#endif
return result;
}

View File

@ -56,9 +56,6 @@ TEST_F(SetFamilyTest, SUnionStore) {
}
TEST_F(SetFamilyTest, SDiff) {
LOG(ERROR) << "TBD";
return;
auto resp = Run({"sadd", "b", "1", "2", "3"});
Run({"sadd", "c", "10", "11"});
Run({"set", "a", "foo"});