refactor(ledger): rollback ledger to a certain block according to the journals
This commit is contained in:
parent
67b059aa45
commit
e2b5302f22
|
@ -52,8 +52,13 @@ func (o *Account) GetState(key []byte) (bool, []byte) {
|
|||
}
|
||||
|
||||
val, err := o.ldb.Get(append(o.Addr.Bytes(), key...))
|
||||
if err != nil && err != errors.ErrNotFound {
|
||||
panic(err)
|
||||
if err != nil {
|
||||
if err != errors.ErrNotFound {
|
||||
panic(err)
|
||||
} else {
|
||||
o.originState[hexKey] = nil
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
o.originState[hexKey] = val
|
||||
|
@ -63,6 +68,7 @@ func (o *Account) GetState(key []byte) (bool, []byte) {
|
|||
|
||||
// SetState Set account state
|
||||
func (o *Account) SetState(key []byte, value []byte) {
|
||||
o.GetState(key)
|
||||
o.dirtyState[hex.EncodeToString(key)] = value
|
||||
}
|
||||
|
||||
|
@ -91,13 +97,18 @@ func (o *Account) Code() []byte {
|
|||
}
|
||||
|
||||
code, err := o.ldb.Get(compositeKey(codeKey, o.Addr.Hex()))
|
||||
if err != nil && err != errors.ErrNotFound {
|
||||
panic(err)
|
||||
if err != nil {
|
||||
if err != errors.ErrNotFound {
|
||||
panic(err)
|
||||
} else {
|
||||
o.originCode = nil
|
||||
o.dirtyCode = nil
|
||||
}
|
||||
} else {
|
||||
o.originCode = code
|
||||
o.dirtyCode = code
|
||||
}
|
||||
|
||||
o.originCode = code
|
||||
o.dirtyCode = code
|
||||
|
||||
return code
|
||||
}
|
||||
|
||||
|
@ -177,6 +188,19 @@ func (o *Account) getJournalIfModified(ldbBatch storage.Batch) *journal {
|
|||
entry.PrevAccount = o.originAccount
|
||||
}
|
||||
|
||||
if o.originCode == nil && !(o.originAccount == nil || o.originAccount.CodeHash == nil) {
|
||||
code, err := o.ldb.Get(compositeKey(codeKey, o.Addr.Hex()))
|
||||
if err != nil {
|
||||
if err != errors.ErrNotFound {
|
||||
panic(err)
|
||||
} else {
|
||||
o.originCode = nil
|
||||
}
|
||||
} else {
|
||||
o.originCode = code
|
||||
}
|
||||
}
|
||||
|
||||
if !bytes.Equal(o.originCode, o.dirtyCode) {
|
||||
if o.dirtyCode != nil {
|
||||
ldbBatch.Put(compositeKey(codeKey, o.Addr.Hex()), o.dirtyCode)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package ledger
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
|
@ -36,10 +37,15 @@ func (journal *journal) revert(batch storage.Batch) {
|
|||
}
|
||||
|
||||
for key, val := range journal.PrevStates {
|
||||
byteKey, err := hex.DecodeString(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if val != nil {
|
||||
batch.Put(compositeKey(journal.Address.Hex(), key), val)
|
||||
batch.Put(append(journal.Address.Bytes(), byteKey...), val)
|
||||
} else {
|
||||
batch.Delete(compositeKey(journal.Address.Hex(), key))
|
||||
batch.Delete(append(journal.Address.Bytes(), byteKey...))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -75,3 +81,17 @@ func getLatestJournal(ldb storage.Storage) (uint64, *BlockJournal, error) {
|
|||
|
||||
return maxHeight, journal, nil
|
||||
}
|
||||
|
||||
func getBlockJournal(height uint64, ldb storage.Storage) *BlockJournal {
|
||||
data, err := ldb.Get(compositeKey(journalKey, height))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
journal := &BlockJournal{}
|
||||
if err := json.Unmarshal(data, journal); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return journal
|
||||
}
|
||||
|
|
|
@ -65,32 +65,44 @@ func New(repoRoot string, blockchainStore storage.Storage, logger logrus.FieldLo
|
|||
}, nil
|
||||
}
|
||||
|
||||
// Rollback rollback to history version
|
||||
// Rollback rollback ledger to history version
|
||||
func (l *ChainLedger) Rollback(height uint64) error {
|
||||
if l.chainMeta.Height <= height {
|
||||
if l.height < height {
|
||||
return ErrorRollbackTohigherNumber
|
||||
}
|
||||
|
||||
block, err := l.GetBlock(height)
|
||||
if err != nil {
|
||||
return err
|
||||
if l.height == height {
|
||||
return nil
|
||||
}
|
||||
|
||||
count, err := getInterchainTxCount(block.BlockHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.UpdateChainMeta(&pb.ChainMeta{
|
||||
Height: height,
|
||||
BlockHash: block.BlockHash,
|
||||
InterchainTxCount: count,
|
||||
})
|
||||
|
||||
// clean cache account
|
||||
l.Clear()
|
||||
|
||||
// TODO
|
||||
for i := l.height; i > height; i-- {
|
||||
batch := l.ldb.NewBatch()
|
||||
blockJournal := getBlockJournal(i, l.ldb)
|
||||
|
||||
for _, journal := range blockJournal.Journals {
|
||||
journal.revert(batch)
|
||||
}
|
||||
|
||||
if err := batch.Commit(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := l.ldb.Delete(compositeKey(journalKey, i)); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
height, journal, err := getLatestJournal(l.ldb)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get journal during rollback: %w", err)
|
||||
}
|
||||
|
||||
l.prevJournalHash = journal.ChangedHash
|
||||
|
||||
l.height = height
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package ledger
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
|
@ -100,3 +101,75 @@ func TestLedger_Commit(t *testing.T) {
|
|||
ver := ldg.Version()
|
||||
assert.Equal(t, uint64(5), ver)
|
||||
}
|
||||
|
||||
func TestChainLedger_Rollback(t *testing.T) {
|
||||
repoRoot, err := ioutil.TempDir("", "ledger_rollback")
|
||||
assert.Nil(t, err)
|
||||
blockStorage, err := leveldb.New(repoRoot)
|
||||
assert.Nil(t, err)
|
||||
ledger, err := New(repoRoot, blockStorage, log.NewWithModule("executor"))
|
||||
assert.Nil(t, err)
|
||||
|
||||
// create an addr0
|
||||
addr0 := types.Bytes2Address(bytesutil.LeftPadBytes([]byte{100}, 20))
|
||||
addr1 := types.Bytes2Address(bytesutil.LeftPadBytes([]byte{101}, 20))
|
||||
|
||||
code := sha256.Sum256([]byte("code"))
|
||||
codeHash := sha256.Sum256(code[:])
|
||||
|
||||
ledger.SetBalance(addr0, 1)
|
||||
ledger.SetCode(addr0, code[:])
|
||||
|
||||
hash1, err := ledger.Commit(1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
ledger.SetBalance(addr0, 2)
|
||||
ledger.SetState(addr0, []byte("a"), []byte("2"))
|
||||
|
||||
code1 := sha256.Sum256([]byte("code1"))
|
||||
codeHash1 := sha256.Sum256(code1[:])
|
||||
ledger.SetCode(addr0, code1[:])
|
||||
|
||||
hash2, err := ledger.Commit(2)
|
||||
assert.Nil(t, err)
|
||||
|
||||
ledger.SetBalance(addr1, 3)
|
||||
ledger.SetBalance(addr0, 4)
|
||||
ledger.SetState(addr0, []byte("a"), []byte("3"))
|
||||
ledger.SetState(addr0, []byte("b"), []byte("4"))
|
||||
|
||||
hash3, err := ledger.Commit(3)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, hash3, ledger.prevJournalHash)
|
||||
|
||||
err = ledger.Rollback(2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, hash2, ledger.prevJournalHash)
|
||||
|
||||
account0 := ledger.GetAccount(addr0)
|
||||
assert.Equal(t, uint64(2), account0.GetBalance())
|
||||
assert.Equal(t, uint64(0), account0.GetNonce())
|
||||
assert.Equal(t, codeHash1[:], account0.CodeHash())
|
||||
assert.Equal(t, code1[:], account0.Code())
|
||||
ok, val := account0.GetState([]byte("a"))
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []byte("2"), val)
|
||||
|
||||
account1 := ledger.GetAccount(addr1)
|
||||
assert.Equal(t, uint64(0), account1.GetBalance())
|
||||
assert.Equal(t, uint64(0), account1.GetNonce())
|
||||
assert.Nil(t, account1.CodeHash())
|
||||
assert.Nil(t, account1.Code())
|
||||
|
||||
err = ledger.Rollback(1)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, hash1, ledger.prevJournalHash)
|
||||
|
||||
account0 = ledger.GetAccount(addr0)
|
||||
assert.Equal(t, uint64(1), account0.GetBalance())
|
||||
assert.Equal(t, uint64(0), account0.GetNonce())
|
||||
assert.Equal(t, codeHash[:], account0.CodeHash())
|
||||
assert.Equal(t, code[:], account0.Code())
|
||||
ok, _ = account0.GetState([]byte("a"))
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
|
|
@ -32,8 +32,12 @@ func (l *ChainLedger) GetAccount(addr types.Address) *Account {
|
|||
|
||||
account := newAccount(l.ldb, addr)
|
||||
data, err := l.ldb.Get(compositeKey(accountKey, addr.Hex()))
|
||||
if err != nil && err != errors.ErrNotFound {
|
||||
panic(err)
|
||||
if err != nil {
|
||||
if err != errors.ErrNotFound {
|
||||
panic(err)
|
||||
} else {
|
||||
return account
|
||||
}
|
||||
}
|
||||
if data != nil {
|
||||
account.originAccount = &innerAccount{}
|
||||
|
|
Loading…
Reference in New Issue