refactor(ledger): rollback ledger to a certain block according to the journals

This commit is contained in:
zhourong 2020-04-27 17:18:22 +08:00
parent 67b059aa45
commit e2b5302f22
5 changed files with 161 additions and 28 deletions

View File

@ -52,8 +52,13 @@ func (o *Account) GetState(key []byte) (bool, []byte) {
}
val, err := o.ldb.Get(append(o.Addr.Bytes(), key...))
if err != nil && err != errors.ErrNotFound {
panic(err)
if err != nil {
if err != errors.ErrNotFound {
panic(err)
} else {
o.originState[hexKey] = nil
return false, nil
}
}
o.originState[hexKey] = val
@ -63,6 +68,7 @@ func (o *Account) GetState(key []byte) (bool, []byte) {
// SetState Set account state
func (o *Account) SetState(key []byte, value []byte) {
o.GetState(key)
o.dirtyState[hex.EncodeToString(key)] = value
}
@ -91,13 +97,18 @@ func (o *Account) Code() []byte {
}
code, err := o.ldb.Get(compositeKey(codeKey, o.Addr.Hex()))
if err != nil && err != errors.ErrNotFound {
panic(err)
if err != nil {
if err != errors.ErrNotFound {
panic(err)
} else {
o.originCode = nil
o.dirtyCode = nil
}
} else {
o.originCode = code
o.dirtyCode = code
}
o.originCode = code
o.dirtyCode = code
return code
}
@ -177,6 +188,19 @@ func (o *Account) getJournalIfModified(ldbBatch storage.Batch) *journal {
entry.PrevAccount = o.originAccount
}
if o.originCode == nil && !(o.originAccount == nil || o.originAccount.CodeHash == nil) {
code, err := o.ldb.Get(compositeKey(codeKey, o.Addr.Hex()))
if err != nil {
if err != errors.ErrNotFound {
panic(err)
} else {
o.originCode = nil
}
} else {
o.originCode = code
}
}
if !bytes.Equal(o.originCode, o.dirtyCode) {
if o.dirtyCode != nil {
ldbBatch.Put(compositeKey(codeKey, o.Addr.Hex()), o.dirtyCode)

View File

@ -1,6 +1,7 @@
package ledger
import (
"encoding/hex"
"encoding/json"
"fmt"
@ -36,10 +37,15 @@ func (journal *journal) revert(batch storage.Batch) {
}
for key, val := range journal.PrevStates {
byteKey, err := hex.DecodeString(key)
if err != nil {
panic(err)
}
if val != nil {
batch.Put(compositeKey(journal.Address.Hex(), key), val)
batch.Put(append(journal.Address.Bytes(), byteKey...), val)
} else {
batch.Delete(compositeKey(journal.Address.Hex(), key))
batch.Delete(append(journal.Address.Bytes(), byteKey...))
}
}
@ -75,3 +81,17 @@ func getLatestJournal(ldb storage.Storage) (uint64, *BlockJournal, error) {
return maxHeight, journal, nil
}
func getBlockJournal(height uint64, ldb storage.Storage) *BlockJournal {
data, err := ldb.Get(compositeKey(journalKey, height))
if err != nil {
panic(err)
}
journal := &BlockJournal{}
if err := json.Unmarshal(data, journal); err != nil {
panic(err)
}
return journal
}

View File

@ -65,32 +65,44 @@ func New(repoRoot string, blockchainStore storage.Storage, logger logrus.FieldLo
}, nil
}
// Rollback rollback to history version
// Rollback rollback ledger to history version
func (l *ChainLedger) Rollback(height uint64) error {
if l.chainMeta.Height <= height {
if l.height < height {
return ErrorRollbackTohigherNumber
}
block, err := l.GetBlock(height)
if err != nil {
return err
if l.height == height {
return nil
}
count, err := getInterchainTxCount(block.BlockHeader)
if err != nil {
return err
}
l.UpdateChainMeta(&pb.ChainMeta{
Height: height,
BlockHash: block.BlockHash,
InterchainTxCount: count,
})
// clean cache account
l.Clear()
// TODO
for i := l.height; i > height; i-- {
batch := l.ldb.NewBatch()
blockJournal := getBlockJournal(i, l.ldb)
for _, journal := range blockJournal.Journals {
journal.revert(batch)
}
if err := batch.Commit(); err != nil {
panic(err)
}
if err := l.ldb.Delete(compositeKey(journalKey, i)); err != nil {
panic(err)
}
}
height, journal, err := getLatestJournal(l.ldb)
if err != nil {
return fmt.Errorf("get journal during rollback: %w", err)
}
l.prevJournalHash = journal.ChangedHash
l.height = height
return nil
}

View File

@ -1,6 +1,7 @@
package ledger
import (
"crypto/sha256"
"encoding/hex"
"io/ioutil"
"testing"
@ -100,3 +101,75 @@ func TestLedger_Commit(t *testing.T) {
ver := ldg.Version()
assert.Equal(t, uint64(5), ver)
}
func TestChainLedger_Rollback(t *testing.T) {
repoRoot, err := ioutil.TempDir("", "ledger_rollback")
assert.Nil(t, err)
blockStorage, err := leveldb.New(repoRoot)
assert.Nil(t, err)
ledger, err := New(repoRoot, blockStorage, log.NewWithModule("executor"))
assert.Nil(t, err)
// create an addr0
addr0 := types.Bytes2Address(bytesutil.LeftPadBytes([]byte{100}, 20))
addr1 := types.Bytes2Address(bytesutil.LeftPadBytes([]byte{101}, 20))
code := sha256.Sum256([]byte("code"))
codeHash := sha256.Sum256(code[:])
ledger.SetBalance(addr0, 1)
ledger.SetCode(addr0, code[:])
hash1, err := ledger.Commit(1)
assert.Nil(t, err)
ledger.SetBalance(addr0, 2)
ledger.SetState(addr0, []byte("a"), []byte("2"))
code1 := sha256.Sum256([]byte("code1"))
codeHash1 := sha256.Sum256(code1[:])
ledger.SetCode(addr0, code1[:])
hash2, err := ledger.Commit(2)
assert.Nil(t, err)
ledger.SetBalance(addr1, 3)
ledger.SetBalance(addr0, 4)
ledger.SetState(addr0, []byte("a"), []byte("3"))
ledger.SetState(addr0, []byte("b"), []byte("4"))
hash3, err := ledger.Commit(3)
assert.Nil(t, err)
assert.Equal(t, hash3, ledger.prevJournalHash)
err = ledger.Rollback(2)
assert.Nil(t, err)
assert.Equal(t, hash2, ledger.prevJournalHash)
account0 := ledger.GetAccount(addr0)
assert.Equal(t, uint64(2), account0.GetBalance())
assert.Equal(t, uint64(0), account0.GetNonce())
assert.Equal(t, codeHash1[:], account0.CodeHash())
assert.Equal(t, code1[:], account0.Code())
ok, val := account0.GetState([]byte("a"))
assert.True(t, ok)
assert.Equal(t, []byte("2"), val)
account1 := ledger.GetAccount(addr1)
assert.Equal(t, uint64(0), account1.GetBalance())
assert.Equal(t, uint64(0), account1.GetNonce())
assert.Nil(t, account1.CodeHash())
assert.Nil(t, account1.Code())
err = ledger.Rollback(1)
assert.Nil(t, err)
assert.Equal(t, hash1, ledger.prevJournalHash)
account0 = ledger.GetAccount(addr0)
assert.Equal(t, uint64(1), account0.GetBalance())
assert.Equal(t, uint64(0), account0.GetNonce())
assert.Equal(t, codeHash[:], account0.CodeHash())
assert.Equal(t, code[:], account0.Code())
ok, _ = account0.GetState([]byte("a"))
assert.False(t, ok)
}

View File

@ -32,8 +32,12 @@ func (l *ChainLedger) GetAccount(addr types.Address) *Account {
account := newAccount(l.ldb, addr)
data, err := l.ldb.Get(compositeKey(accountKey, addr.Hex()))
if err != nil && err != errors.ErrNotFound {
panic(err)
if err != nil {
if err != errors.ErrNotFound {
panic(err)
} else {
return account
}
}
if data != nil {
account.originAccount = &innerAccount{}