From 8d71fd392f93e1016635a5faa82e7d6096101a31 Mon Sep 17 00:00:00 2001 From: zhourong Date: Thu, 7 May 2020 14:11:02 +0800 Subject: [PATCH] refactor(*): separate data persistence from block execution --- internal/executor/executor.go | 26 +- internal/executor/executor_test.go | 7 +- internal/executor/handle.go | 31 +-- internal/ledger/account.go | 71 ++--- internal/ledger/account_cache.go | 141 ++++++++++ internal/ledger/account_test.go | 8 +- internal/ledger/blockchain.go | 65 +++++ internal/ledger/genesis/genesis.go | 28 +- internal/ledger/journal.go | 10 +- internal/ledger/key.go | 10 +- internal/ledger/ledger.go | 93 ++++--- internal/ledger/ledger_test.go | 372 ++++++++++++++++++++----- internal/ledger/state_accessor.go | 127 ++++++++- internal/ledger/state_accessor_test.go | 128 --------- internal/ledger/types.go | 8 +- 15 files changed, 797 insertions(+), 328 deletions(-) create mode 100644 internal/ledger/account_cache.go delete mode 100644 internal/ledger/state_accessor_test.go diff --git a/internal/executor/executor.go b/internal/executor/executor.go index d919c02..ca09e5f 100755 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -19,7 +19,10 @@ import ( "github.com/wasmerio/go-ext-wasm/wasmer" ) -const blockChanNumber = 1024 +const ( + blockChanNumber = 1024 + persistChanNumber = 1024 +) var _ Executor = (*BlockExecutor)(nil) @@ -28,6 +31,7 @@ type BlockExecutor struct { ledger ledger.Ledger logger logrus.FieldLogger blockC chan *pb.Block + persistC chan *ledger.BlockData pendingBlockQ *cache.Cache interchainCounter map[string][]uint64 validationEngine validator.Engine @@ -43,29 +47,30 @@ type BlockExecutor struct { } // New creates executor instance -func New(ledger ledger.Ledger, logger logrus.FieldLogger) (*BlockExecutor, error) { +func New(chainLedger ledger.Ledger, logger logrus.FieldLogger) (*BlockExecutor, error) { pendingBlockQ, err := cache.NewCache() if err != nil { return nil, fmt.Errorf("create cache: %w", err) } - ve := validator.NewValidationEngine(ledger, logger) + ve := validator.NewValidationEngine(chainLedger, logger) boltContracts := registerBoltContracts() ctx, cancel := context.WithCancel(context.Background()) return &BlockExecutor{ - ledger: ledger, + ledger: chainLedger, logger: logger, interchainCounter: make(map[string][]uint64), ctx: ctx, cancel: cancel, blockC: make(chan *pb.Block, blockChanNumber), + persistC: make(chan *ledger.BlockData, persistChanNumber), pendingBlockQ: pendingBlockQ, validationEngine: ve, - currentHeight: ledger.GetChainMeta().Height, - currentBlockHash: ledger.GetChainMeta().BlockHash, + currentHeight: chainLedger.GetChainMeta().Height, + currentBlockHash: chainLedger.GetChainMeta().BlockHash, boltContracts: boltContracts, wasmInstances: make(map[string]wasmer.Instance), }, nil @@ -75,6 +80,8 @@ func New(ledger ledger.Ledger, logger logrus.FieldLogger) (*BlockExecutor, error func (exec *BlockExecutor) Start() error { go exec.listenExecuteEvent() + go exec.persistData() + exec.logger.WithFields(logrus.Fields{ "height": exec.currentHeight, "hash": exec.currentBlockHash.ShortString(), @@ -112,11 +119,18 @@ func (exec *BlockExecutor) listenExecuteEvent() { case block := <-exec.blockC: exec.handleExecuteEvent(block) case <-exec.ctx.Done(): + close(exec.persistC) return } } } +func (exec *BlockExecutor) persistData() { + for data := range exec.persistC { + exec.ledger.PersistBlockData(data) + } +} + func registerBoltContracts() map[string]boltvm.Contract { boltContracts := []*boltvm.BoltContract{ { diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go index 21948d4..8112e19 100644 --- a/internal/executor/executor_test.go +++ b/internal/executor/executor_test.go @@ -55,7 +55,7 @@ func TestBlockExecutor_ExecuteBlock(t *testing.T) { evs = append(evs, ev) mockLedger.EXPECT().GetChainMeta().Return(chainMeta).AnyTimes() mockLedger.EXPECT().Events(gomock.Any()).Return(evs).AnyTimes() - mockLedger.EXPECT().Commit(gomock.Any()).Return(types.String2Hash(from), nil).AnyTimes() + mockLedger.EXPECT().Commit(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockLedger.EXPECT().Clear().AnyTimes() mockLedger.EXPECT().GetState(gomock.Any(), gomock.Any()).Return(true, []byte("10")).AnyTimes() mockLedger.EXPECT().SetState(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() @@ -66,6 +66,8 @@ func TestBlockExecutor_ExecuteBlock(t *testing.T) { mockLedger.EXPECT().SetCode(gomock.Any(), gomock.Any()).AnyTimes() mockLedger.EXPECT().GetCode(gomock.Any()).Return([]byte("10")).AnyTimes() mockLedger.EXPECT().PersistExecutionResult(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockLedger.EXPECT().FlushDirtyDataAndComputeJournal().Return(make(map[string]*ledger.Account), &ledger.BlockJournal{}).AnyTimes() + mockLedger.EXPECT().PersistBlockData(gomock.Any()).AnyTimes() logger := log.NewWithModule("executor") exec, err := New(mockLedger, logger) @@ -178,7 +180,8 @@ func TestBlockExecutor_ExecuteBlock_Transfer(t *testing.T) { _, from := loadAdminKey(t) ledger.SetBalance(from, 100000000) - _, err = ledger.Commit(1) + account, journal := ledger.FlushDirtyDataAndComputeJournal() + err = ledger.Commit(1, account, journal) require.Nil(t, err) err = ledger.PersistExecutionResult(mockBlock(1, nil), nil) require.Nil(t, err) diff --git a/internal/executor/handle.go b/internal/executor/handle.go index 5547b06..dfb1151 100755 --- a/internal/executor/handle.go +++ b/internal/executor/handle.go @@ -11,6 +11,7 @@ import ( "github.com/meshplus/bitxhub-kit/merkle/merkletree" "github.com/meshplus/bitxhub-kit/types" "github.com/meshplus/bitxhub-model/pb" + "github.com/meshplus/bitxhub/internal/ledger" "github.com/meshplus/bitxhub/internal/model/events" "github.com/meshplus/bitxhub/pkg/vm" "github.com/meshplus/bitxhub/pkg/vm/boltvm" @@ -53,8 +54,9 @@ func (exec *BlockExecutor) processExecuteEvent(block *pb.Block) { validTxs, invalidReceipts := exec.verifySign(block) receipts := exec.applyTransactions(validTxs) + receipts = append(receipts, invalidReceipts...) - root, receiptRoot, err := exec.calcMerkleRoots(block.Transactions, append(receipts, invalidReceipts...)) + root, receiptRoot, err := exec.calcMerkleRoots(block.Transactions, receipts) if err != nil { panic(err) } @@ -68,11 +70,9 @@ func (exec *BlockExecutor) processExecuteEvent(block *pb.Block) { } block.BlockHeader.InterchainIndex = idx - hash, err := exec.ledger.Commit(block.BlockHeader.Number) - if err != nil { - panic(err) - } - block.BlockHeader.StateRoot = hash + accounts, journal := exec.ledger.FlushDirtyDataAndComputeJournal() + + block.BlockHeader.StateRoot = journal.ChangedHash block.BlockHash = block.Hash() exec.logger.WithFields(logrus.Fields{ @@ -81,23 +81,18 @@ func (exec *BlockExecutor) processExecuteEvent(block *pb.Block) { "state_root": block.BlockHeader.StateRoot.ShortString(), }).Debug("block meta") - // persist execution result - receipts = append(receipts, invalidReceipts...) - if err := exec.ledger.PersistExecutionResult(block, receipts); err != nil { - panic(err) - } - - exec.logger.WithFields(logrus.Fields{ - "height": block.BlockHeader.Number, - "hash": block.BlockHash.ShortString(), - "count": len(block.Transactions), - }).Info("Persist block") - exec.postBlockEvent(block) exec.clear() exec.currentHeight = block.BlockHeader.Number exec.currentBlockHash = block.BlockHash + + exec.persistC <- &ledger.BlockData{ + Block: block, + Receipts: receipts, + Accounts: accounts, + Journal: journal, + } } func (exec *BlockExecutor) verifySign(block *pb.Block) ([]*pb.Transaction, []*pb.Receipt) { diff --git a/internal/ledger/account.go b/internal/ledger/account.go index 3a840aa..ea32876 100644 --- a/internal/ledger/account.go +++ b/internal/ledger/account.go @@ -3,7 +3,6 @@ package ledger import ( "bytes" "crypto/sha256" - "encoding/hex" "encoding/json" "sort" @@ -21,6 +20,7 @@ type Account struct { dirtyCode []byte dirtyStateHash types.Hash ldb storage.Storage + cache *AccountCache } type innerAccount struct { @@ -29,30 +29,32 @@ type innerAccount struct { CodeHash []byte `json:"code_hash"` } -func newAccount(ldb storage.Storage, addr types.Address) *Account { +func newAccount(ldb storage.Storage, cache *AccountCache, addr types.Address) *Account { return &Account{ Addr: addr, originState: make(map[string][]byte), dirtyState: make(map[string][]byte), ldb: ldb, + cache: cache, } } // GetState Get state from local cache, if not found, then get it from DB func (o *Account) GetState(key []byte) (bool, []byte) { - hexKey := hex.EncodeToString(key) - - if val, exist := o.dirtyState[hexKey]; exist { + if val, exist := o.dirtyState[string(key)]; exist { return val != nil, val } - if val, exist := o.originState[hexKey]; exist { + if val, exist := o.originState[string(key)]; exist { return val != nil, val } - val := o.ldb.Get(append(o.Addr.Bytes(), key...)) + val, ok := o.cache.getState(o.Addr.Hex(), string(key)) + if !ok { + val = o.ldb.Get(composeStateKey(o.Addr, key)) + } - o.originState[hexKey] = val + o.originState[string(key)] = val return val != nil, val } @@ -60,7 +62,7 @@ func (o *Account) GetState(key []byte) (bool, []byte) { // SetState Set account state func (o *Account) SetState(key []byte, value []byte) { o.GetState(key) - o.dirtyState[hex.EncodeToString(key)] = value + o.dirtyState[string(key)] = value } // SetCodeAndHash Set the contract code and hash @@ -87,7 +89,11 @@ func (o *Account) Code() []byte { return nil } - code := o.ldb.Get(compositeKey(codeKey, o.Addr.Hex())) + code, ok := o.cache.getCode(o.Addr.Hex()) + if !ok { + code = o.ldb.Get(compositeKey(codeKey, o.Addr.Hex())) + } + o.originCode = code o.dirtyCode = code @@ -145,27 +151,39 @@ func (o *Account) SetBalance(balance uint64) { // Query Query the value using key func (o *Account) Query(prefix string) (bool, [][]byte) { var ret [][]byte + stored := make(map[string][]byte) + begin, end := bytesPrefix(append(o.Addr.Bytes(), prefix...)) it := o.ldb.Iterator(begin, end) for it.Next() { + key := make([]byte, len(it.Key())) val := make([]byte, len(it.Value())) + copy(key, it.Key()) copy(val, it.Value()) + stored[string(key)] = val + } + + cached := o.cache.query(o.Addr.Hex(), prefix) + for key, val := range cached { + stored[key] = val + } + + for _, val := range stored { ret = append(ret, val) } + sort.Slice(ret, func(i, j int) bool { + return bytes.Compare(ret[i], ret[j]) < 0 + }) + return len(ret) != 0, ret } -func (o *Account) getJournalIfModified(ldbBatch storage.Batch) *journal { +func (o *Account) getJournalIfModified() *journal { entry := &journal{Address: o.Addr} if innerAccountChanged(o.originAccount, o.dirtyAccount) { - data, err := o.dirtyAccount.Marshal() - if err != nil { - panic(err) - } - ldbBatch.Put(compositeKey(accountKey, o.Addr.Hex()), data) entry.AccountChanged = true entry.PrevAccount = o.originAccount } @@ -175,16 +193,11 @@ func (o *Account) getJournalIfModified(ldbBatch storage.Batch) *journal { } if !bytes.Equal(o.originCode, o.dirtyCode) { - if o.dirtyCode != nil { - ldbBatch.Put(compositeKey(codeKey, o.Addr.Hex()), o.dirtyCode) - } else { - ldbBatch.Delete(compositeKey(codeKey, o.Addr.Hex())) - } entry.CodeChanged = true entry.PrevCode = o.originCode } - prevStates := o.getStateJournalAndComputeHash(ldbBatch) + prevStates := o.getStateJournalAndComputeHash() if len(prevStates) != 0 { entry.PrevStates = prevStates } @@ -196,7 +209,7 @@ func (o *Account) getJournalIfModified(ldbBatch storage.Batch) *journal { return nil } -func (o *Account) getStateJournalAndComputeHash(ldbBatch storage.Batch) map[string][]byte { +func (o *Account) getStateJournalAndComputeHash() map[string][]byte { prevStates := make(map[string][]byte) var dirtyStateKeys []string var dirtyStateData []byte @@ -204,18 +217,8 @@ func (o *Account) getStateJournalAndComputeHash(ldbBatch storage.Batch) map[stri for key, val := range o.dirtyState { origVal := o.originState[key] if !bytes.Equal(origVal, val) { - dirtyStateKeys = append(dirtyStateKeys, key) - byteKey, err := hex.DecodeString(key) - if err != nil { - panic(err) - } - - if val != nil { - ldbBatch.Put(append(o.Addr.Bytes(), byteKey...), val) - } else { - ldbBatch.Delete(append(o.Addr.Bytes(), byteKey...)) - } prevStates[key] = origVal + dirtyStateKeys = append(dirtyStateKeys, key) } } diff --git a/internal/ledger/account_cache.go b/internal/ledger/account_cache.go new file mode 100644 index 0000000..9a131b7 --- /dev/null +++ b/internal/ledger/account_cache.go @@ -0,0 +1,141 @@ +package ledger + +import ( + "bytes" + "strings" + "sync" +) + +type AccountCache struct { + innerAccounts map[string]*innerAccount + states map[string]map[string][]byte + codes map[string][]byte + rwLock sync.RWMutex +} + +func NewAccountCache() *AccountCache { + return &AccountCache{ + innerAccounts: make(map[string]*innerAccount), + states: make(map[string]map[string][]byte), + codes: make(map[string][]byte), + rwLock: sync.RWMutex{}, + } +} + +func (ac *AccountCache) add(accounts map[string]*Account) { + ac.rwLock.Lock() + defer ac.rwLock.Unlock() + + for addr, account := range accounts { + ac.innerAccounts[addr] = account.dirtyAccount + if len(account.dirtyState) != 0 { + stateMap, ok := ac.states[addr] + if !ok { + stateMap = make(map[string][]byte) + ac.states[addr] = stateMap + } + for key, val := range account.dirtyState { + stateMap[key] = val + } + } + if !bytes.Equal(account.originCode, account.dirtyCode) { + ac.codes[addr] = account.dirtyCode + } + } +} + +func (ac *AccountCache) remove(accounts map[string]*Account) { + ac.rwLock.Lock() + defer ac.rwLock.Unlock() + + for addr, account := range accounts { + if innerAccount, ok := ac.innerAccounts[addr]; ok { + if !innerAccountChanged(innerAccount, account.dirtyAccount) { + delete(ac.innerAccounts, addr) + } + } + + if len(account.dirtyState) != 0 { + if stateMap, ok := ac.states[addr]; ok { + for key, val := range account.dirtyState { + if v, ok := stateMap[key]; ok { + if bytes.Equal(v, val) { + delete(stateMap, key) + } + } + } + if len(stateMap) == 0 { + delete(ac.states, addr) + } + } + } + + if !bytes.Equal(account.dirtyCode, account.originCode) { + if code, ok := ac.codes[addr]; ok { + if bytes.Equal(code, account.dirtyCode) { + delete(ac.codes, addr) + } + } + } + } +} + +func (ac *AccountCache) getInnerAccount(addr string) (*innerAccount, bool) { + ac.rwLock.RLock() + defer ac.rwLock.RUnlock() + + if innerAccount, ok := ac.innerAccounts[addr]; ok { + return innerAccount, true + } + + return nil, false +} + +func (ac *AccountCache) getState(addr string, key string) ([]byte, bool) { + ac.rwLock.RLock() + defer ac.rwLock.RUnlock() + + if stateMap, ok := ac.states[addr]; ok { + if val, ok := stateMap[key]; ok { + return val, true + } + } + + return nil, false +} + +func (ac *AccountCache) getCode(addr string) ([]byte, bool) { + ac.rwLock.RLock() + defer ac.rwLock.RUnlock() + + if code, ok := ac.codes[addr]; ok { + return code, true + } + + return nil, false +} + +func (ac *AccountCache) query(addr string, prefix string) map[string][]byte { + ac.rwLock.RLock() + defer ac.rwLock.RUnlock() + + ret := make(map[string][]byte) + + if stateMap, ok := ac.states[addr]; ok { + for key, val := range stateMap { + if strings.HasPrefix(key, prefix) { + ret[key] = val + } + } + } + return ret +} + +func (ac *AccountCache) clear() { + ac.rwLock.Lock() + defer ac.rwLock.Unlock() + + ac.innerAccounts = make(map[string]*innerAccount) + ac.states = make(map[string]map[string][]byte) + ac.codes = make(map[string][]byte) +} diff --git a/internal/ledger/account_test.go b/internal/ledger/account_test.go index d431595..655082f 100644 --- a/internal/ledger/account_test.go +++ b/internal/ledger/account_test.go @@ -23,17 +23,19 @@ func TestAccount_GetState(t *testing.T) { ledger, err := New(repoRoot, blockStorage, log.NewWithModule("ChainLedger")) assert.Nil(t, err) - ldb := ledger.ldb - h := hexutil.Encode(bytesutil.LeftPadBytes([]byte{11}, 20)) addr := types.String2Address(h) - account := newAccount(ldb, addr) + account := newAccount(ledger.ldb, ledger.accountCache, addr) account.SetState([]byte("a"), []byte("b")) ok, v := account.GetState([]byte("a")) assert.True(t, ok) assert.Equal(t, []byte("b"), v) + ok, v = account.GetState([]byte("a")) + assert.True(t, ok) + assert.Equal(t, []byte("b"), v) + account.SetState([]byte("a"), nil) ok, v = account.GetState([]byte("a")) assert.False(t, ok) diff --git a/internal/ledger/blockchain.go b/internal/ledger/blockchain.go index 9d3a8d5..94425c7 100644 --- a/internal/ledger/blockchain.go +++ b/internal/ledger/blockchain.go @@ -268,3 +268,68 @@ func (l *ChainLedger) persistChainMeta(batcher storage.Batch, meta *pb.ChainMeta return nil } + +func (l *ChainLedger) removeChainDataOnBlock(batch storage.Batch, height uint64) (uint64, error) { + block, err := l.GetBlock(height) + if err != nil { + return 0, err + } + + batch.Delete(compositeKey(blockKey, height)) + batch.Delete(compositeKey(blockHashKey, block.BlockHash.Hex())) + + for _, tx := range block.Transactions { + batch.Delete(compositeKey(transactionKey, tx.TransactionHash.Hex())) + batch.Delete(compositeKey(transactionMetaKey, tx.TransactionHash.Hex())) + batch.Delete(compositeKey(receiptKey, tx.TransactionHash.Hex())) + } + + return getInterchainTxCount(block.BlockHeader) +} + +func (l *ChainLedger) rollbackBlockChain(height uint64) error { + meta := l.GetChainMeta() + + if meta.Height < height { + return ErrorRollbackToHigherNumber + } + + if meta.Height == height { + return nil + } + + batch := l.blockchainStore.NewBatch() + + for i := meta.Height; i > height; i-- { + count, err := l.removeChainDataOnBlock(batch, i) + if err != nil { + return err + } + meta.InterchainTxCount -= count + } + + if height == 0 { + batch.Delete([]byte(chainMetaKey)) + meta = &pb.ChainMeta{} + } else { + block, err := l.GetBlock(height) + if err != nil { + return err + } + meta = &pb.ChainMeta{ + Height: block.BlockHeader.Number, + BlockHash: block.BlockHash, + InterchainTxCount: meta.InterchainTxCount, + } + + if err := l.persistChainMeta(batch, meta); err != nil { + return err + } + } + + batch.Commit() + + l.UpdateChainMeta(meta) + + return nil +} diff --git a/internal/ledger/genesis/genesis.go b/internal/ledger/genesis/genesis.go index 9ad493e..80af81b 100644 --- a/internal/ledger/genesis/genesis.go +++ b/internal/ledger/genesis/genesis.go @@ -3,10 +3,9 @@ package genesis import ( "encoding/json" - "github.com/meshplus/bitxhub-model/pb" - "github.com/meshplus/bitxhub-kit/bytesutil" "github.com/meshplus/bitxhub-kit/types" + "github.com/meshplus/bitxhub-model/pb" "github.com/meshplus/bitxhub/internal/ledger" "github.com/meshplus/bitxhub/internal/repo" ) @@ -16,9 +15,9 @@ var ( ) // Initialize initialize block -func Initialize(config *repo.Config, ledger ledger.Ledger) error { +func Initialize(config *repo.Config, lg ledger.Ledger) error { for _, addr := range config.Addresses { - ledger.SetBalance(types.String2Address(addr), 100000000) + lg.SetBalance(types.String2Address(addr), 100000000) } body, err := json.Marshal(config.Genesis.Addresses) @@ -26,21 +25,24 @@ func Initialize(config *repo.Config, ledger ledger.Ledger) error { return err } - ledger.SetState(roleAddr, []byte("admin-roles"), body) - - hash, err := ledger.Commit(1) - if err != nil { - return err - } + lg.SetState(roleAddr, []byte("admin-roles"), body) + accounts, journal := lg.FlushDirtyDataAndComputeJournal() block := &pb.Block{ BlockHeader: &pb.BlockHeader{ Number: 1, - StateRoot: hash, + StateRoot: journal.ChangedHash, }, } - block.BlockHash = block.Hash() + blockData := &ledger.BlockData{ + Block: block, + Receipts: nil, + Accounts: accounts, + Journal: journal, + } - return ledger.PersistExecutionResult(block, nil) + lg.PersistBlockData(blockData) + + return nil } diff --git a/internal/ledger/journal.go b/internal/ledger/journal.go index d0c366c..d1dbd7a 100644 --- a/internal/ledger/journal.go +++ b/internal/ledger/journal.go @@ -1,7 +1,6 @@ package ledger import ( - "encoding/hex" "encoding/json" "strconv" @@ -42,15 +41,10 @@ func (journal *journal) revert(batch storage.Batch) { } for key, val := range journal.PrevStates { - byteKey, err := hex.DecodeString(key) - if err != nil { - panic(err) - } - if val != nil { - batch.Put(append(journal.Address.Bytes(), byteKey...), val) + batch.Put(composeStateKey(journal.Address, []byte(key)), val) } else { - batch.Delete(append(journal.Address.Bytes(), byteKey...)) + batch.Delete(composeStateKey(journal.Address, []byte(key))) } } diff --git a/internal/ledger/key.go b/internal/ledger/key.go index 2632eb4..3d8e664 100644 --- a/internal/ledger/key.go +++ b/internal/ledger/key.go @@ -1,6 +1,10 @@ package ledger -import "fmt" +import ( + "fmt" + + "github.com/meshplus/bitxhub-kit/types" +) const ( blockKey = "block-" @@ -17,3 +21,7 @@ const ( func compositeKey(prefix string, value interface{}) []byte { return append([]byte(prefix), []byte(fmt.Sprintf("%v", value))...) } + +func composeStateKey(addr types.Address, key []byte) []byte { + return append(addr.Bytes(), key...) +} diff --git a/internal/ledger/ledger.go b/internal/ledger/ledger.go index 1de0444..57e737e 100644 --- a/internal/ledger/ledger.go +++ b/internal/ledger/ledger.go @@ -29,10 +29,20 @@ type ChainLedger struct { maxJnlHeight uint64 events map[string][]*pb.Event accounts map[string]*Account + accountCache *AccountCache prevJnlHash types.Hash chainMutex sync.RWMutex chainMeta *pb.ChainMeta + + journalMutex sync.RWMutex +} + +type BlockData struct { + Block *pb.Block + Receipts []*pb.Receipt + Accounts map[string]*Account + Journal *BlockJournal } // New create a new ledger instance @@ -49,18 +59,13 @@ func New(repoRoot string, blockchainStore storage.Storage, logger logrus.FieldLo minJnlHeight, maxJnlHeight := getJournalRange(ldb) - if maxJnlHeight < chainMeta.Height { - // TODO(xcc): how to handle this case - panic("state tree height is less than blockchain height") - } - prevJnlHash := types.Hash{} if maxJnlHeight != 0 { blockJournal := getBlockJournal(maxJnlHeight, ldb) prevJnlHash = blockJournal.ChangedHash } - return &ChainLedger{ + ledger := &ChainLedger{ logger: logger, chainMeta: chainMeta, blockchainStore: blockchainStore, @@ -69,54 +74,62 @@ func New(repoRoot string, blockchainStore storage.Storage, logger logrus.FieldLo maxJnlHeight: maxJnlHeight, events: make(map[string][]*pb.Event, 10), accounts: make(map[string]*Account), + accountCache: NewAccountCache(), prevJnlHash: prevJnlHash, - }, nil + } + + height := maxJnlHeight + if maxJnlHeight > chainMeta.Height { + height = chainMeta.Height + } + + if err := ledger.Rollback(height); err != nil { + return nil, err + } + + return ledger, nil +} + +// PersistBlockData persists block data +func (l *ChainLedger) PersistBlockData(blockData *BlockData) { + block := blockData.Block + receipts := blockData.Receipts + accounts := blockData.Accounts + journal := blockData.Journal + + if err := l.Commit(block.BlockHeader.Number, accounts, journal); err != nil { + panic(err) + } + + if err := l.PersistExecutionResult(block, receipts); err != nil { + panic(err) + } + + l.logger.WithFields(logrus.Fields{ + "height": block.BlockHeader.Number, + "hash": block.BlockHash.ShortString(), + "count": len(block.Transactions), + }).Info("Persist block") } // Rollback rollback ledger to history version func (l *ChainLedger) Rollback(height uint64) error { - if l.maxJnlHeight < height { - return ErrorRollbackToHigherNumber + if err := l.rollbackState(height); err != nil { + return err } - if l.minJnlHeight > height { - return ErrorRollbackTooMuch + if err := l.rollbackBlockChain(height); err != nil { + return err } - if l.maxJnlHeight == height { - return nil - } - - // clean cache account - l.Clear() - - for i := l.maxJnlHeight; i > height; i-- { - batch := l.ldb.NewBatch() - - blockJournal := getBlockJournal(i, l.ldb) - if blockJournal == nil { - return ErrorRollbackWithoutJournal - } - - for _, journal := range blockJournal.Journals { - journal.revert(batch) - } - - batch.Delete(compositeKey(journalKey, i)) - batch.Put(compositeKey(journalKey, maxHeightStr), marshalHeight(i-1)) - batch.Commit() - } - - journal := getBlockJournal(height, l.ldb) - - l.maxJnlHeight = height - l.prevJnlHash = journal.ChangedHash - return nil } // RemoveJournalsBeforeBlock removes ledger journals whose block number < height func (l *ChainLedger) RemoveJournalsBeforeBlock(height uint64) error { + l.journalMutex.Lock() + defer l.journalMutex.Unlock() + if height > l.maxJnlHeight { return ErrorRemoveJournalOutOfRange } diff --git a/internal/ledger/ledger_test.go b/internal/ledger/ledger_test.go index 4fdb448..a1bac65 100644 --- a/internal/ledger/ledger_test.go +++ b/internal/ledger/ledger_test.go @@ -6,9 +6,12 @@ import ( "io/ioutil" "testing" + "github.com/meshplus/bitxhub/pkg/storage" + "github.com/meshplus/bitxhub-kit/bytesutil" "github.com/meshplus/bitxhub-kit/log" "github.com/meshplus/bitxhub-kit/types" + "github.com/meshplus/bitxhub-model/pb" "github.com/meshplus/bitxhub/pkg/storage/leveldb" "github.com/stretchr/testify/assert" ) @@ -25,47 +28,47 @@ func TestLedger_Commit(t *testing.T) { account := types.Bytes2Address(bytesutil.LeftPadBytes([]byte{100}, 20)) ledger.SetState(account, []byte("a"), []byte("b")) - hash, err := ledger.Commit(1) - assert.Nil(t, err) + accounts, journal := ledger.FlushDirtyDataAndComputeJournal() + ledger.PersistBlockData(genBlockData(1, accounts, journal)) assert.Equal(t, uint64(1), ledger.Version()) - assert.Equal(t, "0xe5ace5cd035b4c3d9d73a3f4a4a64e6e306010c75c35558283847c7c6473d66c", hash.Hex()) + assert.Equal(t, "0xa1a6d35708fa6cf804b6cf9479f3a55d9a87fbfb83c55a64685aeabdba6116b1", journal.ChangedHash.Hex()) - hash, err = ledger.Commit(2) - assert.Nil(t, err) + accounts, journal = ledger.FlushDirtyDataAndComputeJournal() + ledger.PersistBlockData(genBlockData(2, accounts, journal)) assert.Equal(t, uint64(2), ledger.Version()) - assert.Equal(t, "0x4204720214cb812d802b2075c5fed85cd5dfe8a6065627489b6296108f0fedc2", hash.Hex()) + assert.Equal(t, "0xf09f0198c06d549316d4ee7c497c9eaef9d24f5b1075e7bcef3d0a82dfa742cf", journal.ChangedHash.Hex()) ledger.SetState(account, []byte("a"), []byte("3")) ledger.SetState(account, []byte("a"), []byte("2")) - hash, err = ledger.Commit(3) - assert.Nil(t, err) + accounts, journal = ledger.FlushDirtyDataAndComputeJournal() + ledger.PersistBlockData(genBlockData(3, accounts, journal)) assert.Equal(t, uint64(3), ledger.Version()) - assert.Equal(t, "0xf08cc4b2da3f277202dc50a094ff2021300375915c14894a53fe02540feb3411", hash.Hex()) + assert.Equal(t, "0xe9fc370dd36c9bd5f67ccfbc031c909f53a3d8bc7084c01362c55f2d42ba841c", journal.ChangedHash.Hex()) ledger.SetBalance(account, 100) - hash, err = ledger.Commit(4) - assert.Equal(t, uint64(4), ledger.maxJnlHeight) + accounts, journal = ledger.FlushDirtyDataAndComputeJournal() + ledger.PersistBlockData(genBlockData(4, accounts, journal)) assert.Nil(t, err) assert.Equal(t, uint64(4), ledger.Version()) - assert.Equal(t, "0x8ef7f408372406532c7060045d77fb67d322cea7aa49afdc3a741f4f340dc6d5", hash.Hex()) + assert.Equal(t, "0xc179056204ba33ed6cfc0bfe94ca03319beb522fd7b0773a589899817b49ec08", journal.ChangedHash.Hex()) code := bytesutil.RightPadBytes([]byte{100}, 100) ledger.SetCode(account, code) ledger.SetState(account, []byte("b"), []byte("3")) ledger.SetState(account, []byte("c"), []byte("2")) - hash, err = ledger.Commit(5) - assert.Nil(t, err) + accounts, journal = ledger.FlushDirtyDataAndComputeJournal() + ledger.PersistBlockData(genBlockData(5, accounts, journal)) assert.Equal(t, uint64(5), ledger.Version()) assert.Equal(t, uint64(5), ledger.maxJnlHeight) minHeight, maxHeight := getJournalRange(ledger.ldb) - journal := getBlockJournal(maxHeight, ledger.ldb) + journal5 := getBlockJournal(maxHeight, ledger.ldb) assert.Nil(t, err) assert.Equal(t, uint64(1), minHeight) assert.Equal(t, uint64(5), maxHeight) - assert.Equal(t, hash, journal.ChangedHash) - assert.Equal(t, 1, len(journal.Journals)) - entry := journal.Journals[0] + assert.Equal(t, journal.ChangedHash, journal5.ChangedHash) + assert.Equal(t, 1, len(journal5.Journals)) + entry := journal5.Journals[0] assert.Equal(t, account, entry.Address) assert.True(t, entry.AccountChanged) assert.Equal(t, uint64(100), entry.PrevAccount.Balance) @@ -83,7 +86,7 @@ func TestLedger_Commit(t *testing.T) { ldg, err := New(repoRoot, blockStorage, log.NewWithModule("executor")) assert.Nil(t, err) assert.Equal(t, uint64(5), ldg.maxJnlHeight) - assert.Equal(t, hash, ldg.prevJnlHash) + assert.Equal(t, journal.ChangedHash, ldg.prevJnlHash) ok, value := ldg.GetState(account, []byte("a")) assert.True(t, ok) @@ -119,45 +122,83 @@ func TestChainLedger_Rollback(t *testing.T) { hash0 := types.Hash{} assert.Equal(t, hash0, ledger.prevJnlHash) - code := sha256.Sum256([]byte("code")) - codeHash := sha256.Sum256(code[:]) - ledger.SetBalance(addr0, 1) - ledger.SetCode(addr0, code[:]) - - hash1, err := ledger.Commit(1) - assert.Nil(t, err) + accounts, journal1 := ledger.FlushDirtyDataAndComputeJournal() + ledger.PersistBlockData(genBlockData(1, accounts, journal1)) ledger.SetBalance(addr0, 2) ledger.SetState(addr0, []byte("a"), []byte("2")) - code1 := sha256.Sum256([]byte("code1")) - codeHash1 := sha256.Sum256(code1[:]) - ledger.SetCode(addr0, code1[:]) + code := sha256.Sum256([]byte("code")) + codeHash := sha256.Sum256(code[:]) + ledger.SetCode(addr0, code[:]) - hash2, err := ledger.Commit(2) - assert.Nil(t, err) + accounts, journal2 := ledger.FlushDirtyDataAndComputeJournal() + ledger.PersistBlockData(genBlockData(2, accounts, journal2)) + + account0 := ledger.GetAccount(addr0) + assert.Equal(t, uint64(2), account0.GetBalance()) ledger.SetBalance(addr1, 3) ledger.SetBalance(addr0, 4) ledger.SetState(addr0, []byte("a"), []byte("3")) ledger.SetState(addr0, []byte("b"), []byte("4")) - hash3, err := ledger.Commit(3) + code1 := sha256.Sum256([]byte("code1")) + codeHash1 := sha256.Sum256(code1[:]) + ledger.SetCode(addr0, code1[:]) + + accounts, journal3 := ledger.FlushDirtyDataAndComputeJournal() + ledger.PersistBlockData(genBlockData(3, accounts, journal3)) + + assert.Equal(t, journal3.ChangedHash, ledger.prevJnlHash) + block, err := ledger.GetBlock(3) assert.Nil(t, err) - assert.Equal(t, hash3, ledger.prevJnlHash) + assert.NotNil(t, block) + assert.Equal(t, uint64(3), ledger.chainMeta.Height) + + account0 = ledger.GetAccount(addr0) + assert.Equal(t, uint64(4), account0.GetBalance()) + + err = ledger.Rollback(4) + assert.Equal(t, ErrorRollbackToHigherNumber, err) + + err = ledger.RemoveJournalsBeforeBlock(2) + assert.Nil(t, err) + assert.Equal(t, uint64(2), ledger.minJnlHeight) + + err = ledger.Rollback(0) + assert.Equal(t, ErrorRollbackTooMuch, err) + + err = ledger.Rollback(1) + assert.Equal(t, ErrorRollbackTooMuch, err) + assert.Equal(t, uint64(3), ledger.chainMeta.Height) + + err = ledger.Rollback(3) + assert.Nil(t, err) + assert.Equal(t, journal3.ChangedHash, ledger.prevJnlHash) + block, err = ledger.GetBlock(3) + assert.Nil(t, err) + assert.NotNil(t, block) + assert.Equal(t, uint64(3), ledger.chainMeta.Height) + assert.Equal(t, codeHash1[:], account0.CodeHash()) + assert.Equal(t, code1[:], account0.Code()) err = ledger.Rollback(2) assert.Nil(t, err) - assert.Equal(t, hash2, ledger.prevJnlHash) - assert.Equal(t, uint64(1), ledger.minJnlHeight) + block, err = ledger.GetBlock(3) + assert.Equal(t, storage.ErrorNotFound, err) + assert.Nil(t, block) + assert.Equal(t, uint64(2), ledger.chainMeta.Height) + assert.Equal(t, journal2.ChangedHash, ledger.prevJnlHash) + assert.Equal(t, uint64(2), ledger.minJnlHeight) assert.Equal(t, uint64(2), ledger.maxJnlHeight) - account0 := ledger.GetAccount(addr0) + account0 = ledger.GetAccount(addr0) assert.Equal(t, uint64(2), account0.GetBalance()) assert.Equal(t, uint64(0), account0.GetNonce()) - assert.Equal(t, codeHash1[:], account0.CodeHash()) - assert.Equal(t, code1[:], account0.Code()) + assert.Equal(t, codeHash[:], account0.CodeHash()) + assert.Equal(t, code[:], account0.Code()) ok, val := account0.GetState([]byte("a")) assert.True(t, ok) assert.Equal(t, []byte("2"), val) @@ -171,24 +212,11 @@ func TestChainLedger_Rollback(t *testing.T) { ledger.Close() ledger, err = New(repoRoot, blockStorage, log.NewWithModule("executor")) assert.Nil(t, err) - assert.Equal(t, uint64(1), ledger.minJnlHeight) + assert.Equal(t, uint64(2), ledger.minJnlHeight) assert.Equal(t, uint64(2), ledger.maxJnlHeight) err = ledger.Rollback(1) - assert.Nil(t, err) - assert.Equal(t, hash1, ledger.prevJnlHash) - - account0 = ledger.GetAccount(addr0) - assert.Equal(t, uint64(1), account0.GetBalance()) - assert.Equal(t, uint64(0), account0.GetNonce()) - assert.Equal(t, codeHash[:], account0.CodeHash()) - assert.Equal(t, code[:], account0.Code()) - ok, _ = account0.GetState([]byte("a")) - assert.False(t, ok) - - err = ledger.Rollback(0) assert.Equal(t, ErrorRollbackTooMuch, err) - } func TestChainLedger_RemoveJournalsBeforeBlock(t *testing.T) { @@ -202,19 +230,23 @@ func TestChainLedger_RemoveJournalsBeforeBlock(t *testing.T) { assert.Equal(t, uint64(0), ledger.minJnlHeight) assert.Equal(t, uint64(0), ledger.maxJnlHeight) - _, _ = ledger.Commit(1) - _, _ = ledger.Commit(2) - _, _ = ledger.Commit(3) - hash, _ := ledger.Commit(4) + accounts, journal := ledger.FlushDirtyDataAndComputeJournal() + ledger.PersistBlockData(genBlockData(1, accounts, journal)) + accounts, journal = ledger.FlushDirtyDataAndComputeJournal() + ledger.PersistBlockData(genBlockData(2, accounts, journal)) + accounts, journal = ledger.FlushDirtyDataAndComputeJournal() + ledger.PersistBlockData(genBlockData(3, accounts, journal)) + accounts, journal4 := ledger.FlushDirtyDataAndComputeJournal() + ledger.PersistBlockData(genBlockData(4, accounts, journal4)) assert.Equal(t, uint64(1), ledger.minJnlHeight) assert.Equal(t, uint64(4), ledger.maxJnlHeight) minHeight, maxHeight := getJournalRange(ledger.ldb) - journal := getBlockJournal(maxHeight, ledger.ldb) + journal = getBlockJournal(maxHeight, ledger.ldb) assert.Equal(t, uint64(1), minHeight) assert.Equal(t, uint64(4), maxHeight) - assert.Equal(t, hash, journal.ChangedHash) + assert.Equal(t, journal4.ChangedHash, journal.ChangedHash) err = ledger.RemoveJournalsBeforeBlock(5) assert.Equal(t, ErrorRemoveJournalOutOfRange, err) @@ -228,7 +260,7 @@ func TestChainLedger_RemoveJournalsBeforeBlock(t *testing.T) { journal = getBlockJournal(maxHeight, ledger.ldb) assert.Equal(t, uint64(2), minHeight) assert.Equal(t, uint64(4), maxHeight) - assert.Equal(t, hash, journal.ChangedHash) + assert.Equal(t, journal4.ChangedHash, journal.ChangedHash) err = ledger.RemoveJournalsBeforeBlock(2) assert.Nil(t, err) @@ -247,18 +279,234 @@ func TestChainLedger_RemoveJournalsBeforeBlock(t *testing.T) { assert.Equal(t, uint64(4), ledger.minJnlHeight) assert.Equal(t, uint64(4), ledger.maxJnlHeight) - assert.Equal(t, hash, ledger.prevJnlHash) + assert.Equal(t, journal4.ChangedHash, ledger.prevJnlHash) minHeight, maxHeight = getJournalRange(ledger.ldb) journal = getBlockJournal(maxHeight, ledger.ldb) assert.Equal(t, uint64(4), minHeight) assert.Equal(t, uint64(4), maxHeight) - assert.Equal(t, hash, journal.ChangedHash) + assert.Equal(t, journal4.ChangedHash, journal.ChangedHash) ledger.Close() ledger, err = New(repoRoot, blockStorage, log.NewWithModule("executor")) assert.Nil(t, err) assert.Equal(t, uint64(4), ledger.minJnlHeight) assert.Equal(t, uint64(4), ledger.maxJnlHeight) - assert.Equal(t, hash, ledger.prevJnlHash) + assert.Equal(t, journal4.ChangedHash, ledger.prevJnlHash) +} + +func TestChainLedger_QueryByPrefix(t *testing.T) { + repoRoot, err := ioutil.TempDir("", "ledger_queryByPrefix") + assert.Nil(t, err) + blockStorage, err := leveldb.New(repoRoot) + assert.Nil(t, err) + ledger, err := New(repoRoot, blockStorage, log.NewWithModule("executor")) + assert.Nil(t, err) + + addr := types.Bytes2Address(bytesutil.LeftPadBytes([]byte{1}, 20)) + key0 := []byte{100, 100} + key1 := []byte{100, 101} + key2 := []byte{100, 102} + key3 := []byte{10, 102} + + ledger.SetState(addr, key0, []byte("0")) + ledger.SetState(addr, key1, []byte("1")) + ledger.SetState(addr, key2, []byte("2")) + ledger.SetState(addr, key3, []byte("2")) + + accounts, journal := ledger.FlushDirtyDataAndComputeJournal() + + ok, vals := ledger.QueryByPrefix(addr, string([]byte{100})) + assert.True(t, ok) + assert.Equal(t, 3, len(vals)) + assert.Equal(t, []byte("0"), vals[0]) + assert.Equal(t, []byte("1"), vals[1]) + assert.Equal(t, []byte("2"), vals[2]) + + err = ledger.Commit(1, accounts, journal) + assert.Nil(t, err) + + ok, vals = ledger.QueryByPrefix(addr, string([]byte{100})) + assert.True(t, ok) + assert.Equal(t, 3, len(vals)) + assert.Equal(t, []byte("0"), vals[0]) + assert.Equal(t, []byte("1"), vals[1]) + assert.Equal(t, []byte("2"), vals[2]) + +} + +func TestChainLedger_GetAccount(t *testing.T) { + repoRoot, err := ioutil.TempDir("", "ledger_getAccount") + assert.Nil(t, err) + blockStorage, err := leveldb.New(repoRoot) + assert.Nil(t, err) + ledger, err := New(repoRoot, blockStorage, log.NewWithModule("executor")) + assert.Nil(t, err) + + addr := types.Bytes2Address(bytesutil.LeftPadBytes([]byte{1}, 20)) + code := bytesutil.LeftPadBytes([]byte{1}, 120) + key0 := []byte{100, 100} + key1 := []byte{100, 101} + + account := ledger.GetOrCreateAccount(addr) + account.SetBalance(1) + account.SetNonce(2) + account.SetCodeAndHash(code) + + account.SetState(key0, key1) + account.SetState(key1, key0) + + accounts, journal := ledger.FlushDirtyDataAndComputeJournal() + err = ledger.Commit(1, accounts, journal) + assert.Nil(t, err) + + account1 := ledger.GetAccount(addr) + + assert.Equal(t, account.GetBalance(), ledger.GetBalance(addr)) + assert.Equal(t, account.GetBalance(), account1.GetBalance()) + assert.Equal(t, account.GetNonce(), account1.GetNonce()) + assert.Equal(t, account.CodeHash(), account1.CodeHash()) + assert.Equal(t, account.Code(), account1.Code()) + ok0, val0 := account.GetState(key0) + ok1, val1 := account.GetState(key1) + assert.Equal(t, ok0, ok1) + assert.Equal(t, val0, key1) + assert.Equal(t, val1, key0) + + key2 := []byte{100, 102} + val2 := []byte{111} + ledger.SetState(addr, key0, val0) + ledger.SetState(addr, key2, val2) + ledger.SetState(addr, key0, val1) + accounts, journal = ledger.FlushDirtyDataAndComputeJournal() + err = ledger.Commit(2, accounts, journal) + assert.Nil(t, err) + + ledger.SetState(addr, key0, val0) + ledger.SetState(addr, key0, val1) + ledger.SetState(addr, key2, nil) + accounts, journal = ledger.FlushDirtyDataAndComputeJournal() + err = ledger.Commit(3, accounts, journal) + assert.Nil(t, err) + + ok, val := ledger.GetState(addr, key0) + assert.True(t, ok) + assert.Equal(t, val1, val) + + ok, val2 = ledger.GetState(addr, key2) + assert.False(t, ok) + assert.Nil(t, val2) +} + +func TestChainLedger_GetCode(t *testing.T) { + repoRoot, err := ioutil.TempDir("", "ledger_getCode") + assert.Nil(t, err) + blockStorage, err := leveldb.New(repoRoot) + assert.Nil(t, err) + ledger, err := New(repoRoot, blockStorage, log.NewWithModule("executor")) + assert.Nil(t, err) + + addr := types.Bytes2Address(bytesutil.LeftPadBytes([]byte{1}, 20)) + code := bytesutil.LeftPadBytes([]byte{10}, 120) + + code0 := ledger.GetCode(addr) + assert.Nil(t, code0) + + ledger.SetCode(addr, code) + + accounts, journal := ledger.FlushDirtyDataAndComputeJournal() + err = ledger.Commit(1, accounts, journal) + assert.Nil(t, err) + + vals := ledger.GetCode(addr) + assert.Equal(t, code, vals) + + accounts, journal = ledger.FlushDirtyDataAndComputeJournal() + err = ledger.Commit(2, accounts, journal) + assert.Nil(t, err) + + vals = ledger.GetCode(addr) + assert.Equal(t, code, vals) + + vals = ledger.GetCode(addr) + assert.Equal(t, code, vals) +} + +func TestChainLedger_AddAccountsToCache(t *testing.T) { + repoRoot, err := ioutil.TempDir("", "ledger_addAccountToCache") + assert.Nil(t, err) + blockStorage, err := leveldb.New(repoRoot) + assert.Nil(t, err) + ledger, err := New(repoRoot, blockStorage, log.NewWithModule("executor")) + assert.Nil(t, err) + + addr := types.Bytes2Address(bytesutil.LeftPadBytes([]byte{1}, 20)) + key := []byte{1} + val := []byte{2} + code := bytesutil.RightPadBytes([]byte{1, 2, 3, 4}, 100) + + ledger.SetBalance(addr, 100) + ledger.SetNonce(addr, 1) + ledger.SetState(addr, key, val) + ledger.SetCode(addr, code) + + accounts, journal := ledger.FlushDirtyDataAndComputeJournal() + ledger.Clear() + + innerAccount, ok := ledger.accountCache.getInnerAccount(addr.Hex()) + assert.True(t, ok) + assert.Equal(t, uint64(100), innerAccount.Balance) + assert.Equal(t, uint64(1), innerAccount.Nonce) + assert.Equal(t, types.Hash(sha256.Sum256(code)).Bytes(), innerAccount.CodeHash) + + val1, ok := ledger.accountCache.getState(addr.Hex(), string(key)) + assert.True(t, ok) + assert.Equal(t, val, val1) + + code1, ok := ledger.accountCache.getCode(addr.Hex()) + assert.True(t, ok) + assert.Equal(t, code, code1) + + assert.Equal(t, uint64(100), ledger.GetBalance(addr)) + assert.Equal(t, uint64(1), ledger.GetNonce(addr)) + + ok, val1 = ledger.GetState(addr, key) + assert.Equal(t, true, ok) + assert.Equal(t, val, val1) + assert.Equal(t, code, ledger.GetCode(addr)) + + err = ledger.Commit(1, accounts, journal) + assert.Nil(t, err) + + assert.Equal(t, uint64(100), ledger.GetBalance(addr)) + assert.Equal(t, uint64(1), ledger.GetNonce(addr)) + + ok, val1 = ledger.GetState(addr, key) + assert.Equal(t, true, ok) + assert.Equal(t, val, val1) + assert.Equal(t, code, ledger.GetCode(addr)) + + _, ok = ledger.accountCache.getInnerAccount(addr.Hex()) + assert.False(t, ok) + + _, ok = ledger.accountCache.getState(addr.Hex(), string(key)) + assert.False(t, ok) + + _, ok = ledger.accountCache.getCode(addr.Hex()) + assert.False(t, ok) +} + +func genBlockData(height uint64, accounts map[string]*Account, journal *BlockJournal) *BlockData { + return &BlockData{ + Block: &pb.Block{ + BlockHeader: &pb.BlockHeader{ + Number: height, + }, + BlockHash: sha256.Sum256([]byte{1}), + Transactions: []*pb.Transaction{{}}, + }, + Receipts: nil, + Accounts: accounts, + Journal: journal, + } } diff --git a/internal/ledger/state_accessor.go b/internal/ledger/state_accessor.go index a2a87d5..96bb179 100644 --- a/internal/ledger/state_accessor.go +++ b/internal/ledger/state_accessor.go @@ -1,6 +1,7 @@ package ledger import ( + "bytes" "crypto/sha256" "encoding/json" "sort" @@ -20,14 +21,19 @@ func (l *ChainLedger) GetOrCreateAccount(addr types.Address) *Account { } account := l.GetAccount(addr) - l.accounts[addr.Hex()] = account + l.accounts[h] = account return account } // GetAccount get account info using account Address, if not found, create a new account func (l *ChainLedger) GetAccount(addr types.Address) *Account { - account := newAccount(l.ldb, addr) + account := newAccount(l.ldb, l.accountCache, addr) + + if innerAccount, ok := l.accountCache.getInnerAccount(addr.Hex()); ok { + account.originAccount = innerAccount + return account + } if data := l.ldb.Get(compositeKey(accountKey, addr.Hex())); data != nil { account.originAccount = &innerAccount{} @@ -98,20 +104,21 @@ func (l *ChainLedger) Clear() { l.accounts = make(map[string]*Account) } -// Commit commit the state -func (l *ChainLedger) Commit(height uint64) (types.Hash, error) { +// FlushDirtyDataAndComputeJournal gets dirty accounts and computes block journal +func (l *ChainLedger) FlushDirtyDataAndComputeJournal() (map[string]*Account, *BlockJournal) { + dirtyAccounts := make(map[string]*Account) var dirtyAccountData []byte var journals []*journal sortedAddr := make([]string, 0, len(l.accounts)) accountData := make(map[string][]byte) - ldbBatch := l.ldb.NewBatch() for addr, account := range l.accounts { - journal := account.getJournalIfModified(ldbBatch) + journal := account.getJournalIfModified() if journal != nil { + journals = append(journals, journal) sortedAddr = append(sortedAddr, addr) accountData[addr] = account.getDirtyData() - journals = append(journals, journal) + dirtyAccounts[addr] = account } } @@ -122,19 +129,61 @@ func (l *ChainLedger) Commit(height uint64) (types.Hash, error) { dirtyAccountData = append(dirtyAccountData, l.prevJnlHash[:]...) journalHash := sha256.Sum256(dirtyAccountData) - blockJournal := BlockJournal{ + blockJournal := &BlockJournal{ Journals: journals, ChangedHash: journalHash, } + l.prevJnlHash = journalHash + l.Clear() + l.accountCache.add(dirtyAccounts) + + return dirtyAccounts, blockJournal +} + +// Commit commit the state +func (l *ChainLedger) Commit(height uint64, accounts map[string]*Account, blockJournal *BlockJournal) error { + ldbBatch := l.ldb.NewBatch() + + for _, account := range accounts { + if innerAccountChanged(account.originAccount, account.dirtyAccount) { + data, err := account.dirtyAccount.Marshal() + if err != nil { + panic(err) + } + ldbBatch.Put(compositeKey(accountKey, account.Addr.Hex()), data) + } + + if !bytes.Equal(account.originCode, account.dirtyCode) { + if account.dirtyCode != nil { + ldbBatch.Put(compositeKey(codeKey, account.Addr.Hex()), account.dirtyCode) + } else { + ldbBatch.Delete(compositeKey(codeKey, account.Addr.Hex())) + } + } + + for key, val := range account.dirtyState { + origVal := account.originState[key] + if !bytes.Equal(origVal, val) { + if val != nil { + ldbBatch.Put(composeStateKey(account.Addr, []byte(key)), val) + } else { + ldbBatch.Delete(composeStateKey(account.Addr, []byte(key))) + } + } + } + } + data, err := json.Marshal(blockJournal) if err != nil { - return [32]byte{}, err + return err } ldbBatch.Put(compositeKey(journalKey, height), data) ldbBatch.Put(compositeKey(journalKey, maxHeightStr), marshalHeight(height)) + l.journalMutex.Lock() + if l.minJnlHeight == 0 { l.minJnlHeight = height ldbBatch.Put(compositeKey(journalKey, minHeightStr), marshalHeight(height)) @@ -143,13 +192,67 @@ func (l *ChainLedger) Commit(height uint64) (types.Hash, error) { ldbBatch.Commit() l.maxJnlHeight = height - l.prevJnlHash = journalHash - l.Clear() - return journalHash, nil + l.journalMutex.Unlock() + + l.accountCache.remove(accounts) + + return nil } // Version returns the current version func (l *ChainLedger) Version() uint64 { + l.journalMutex.RLock() + defer l.journalMutex.RUnlock() + return l.maxJnlHeight } + +func (l *ChainLedger) rollbackState(height uint64) error { + l.journalMutex.Lock() + defer l.journalMutex.Unlock() + + if l.maxJnlHeight < height { + return ErrorRollbackToHigherNumber + } + + if l.minJnlHeight > height && !(l.minJnlHeight == 1 && height == 0) { + return ErrorRollbackTooMuch + } + + if l.maxJnlHeight == height { + return nil + } + + // clean cache account + l.Clear() + l.accountCache.clear() + + for i := l.maxJnlHeight; i > height; i-- { + batch := l.ldb.NewBatch() + + blockJournal := getBlockJournal(i, l.ldb) + if blockJournal == nil { + return ErrorRollbackWithoutJournal + } + + for _, journal := range blockJournal.Journals { + journal.revert(batch) + } + + batch.Delete(compositeKey(journalKey, i)) + batch.Put(compositeKey(journalKey, maxHeightStr), marshalHeight(i-1)) + batch.Commit() + } + + if height != 0 { + journal := getBlockJournal(height, l.ldb) + l.prevJnlHash = journal.ChangedHash + } else { + l.prevJnlHash = types.Hash{} + l.minJnlHeight = 0 + } + l.maxJnlHeight = height + + return nil +} diff --git a/internal/ledger/state_accessor_test.go b/internal/ledger/state_accessor_test.go deleted file mode 100644 index ab9a343..0000000 --- a/internal/ledger/state_accessor_test.go +++ /dev/null @@ -1,128 +0,0 @@ -package ledger - -import ( - "io/ioutil" - "testing" - - "github.com/meshplus/bitxhub-kit/bytesutil" - "github.com/meshplus/bitxhub-kit/log" - "github.com/meshplus/bitxhub-kit/types" - "github.com/meshplus/bitxhub/pkg/storage/leveldb" - "github.com/stretchr/testify/assert" -) - -func TestChainLedger_QueryByPrefix(t *testing.T) { - repoRoot, err := ioutil.TempDir("", "ledger_commit") - assert.Nil(t, err) - blockStorage, err := leveldb.New(repoRoot) - assert.Nil(t, err) - ledger, err := New(repoRoot, blockStorage, log.NewWithModule("executor")) - assert.Nil(t, err) - - addr := types.Bytes2Address(bytesutil.LeftPadBytes([]byte{1}, 20)) - key0 := []byte{100, 100} - key1 := []byte{100, 101} - key2 := []byte{100, 102} - - ledger.SetState(addr, key0, []byte("0")) - ledger.SetState(addr, key1, []byte("1")) - ledger.SetState(addr, key2, []byte("2")) - - _, err = ledger.Commit(1) - assert.Nil(t, err) - - ok, vals := ledger.QueryByPrefix(addr, string([]byte{100})) - assert.True(t, ok) - assert.Equal(t, 3, len(vals)) - assert.Equal(t, []byte("0"), vals[0]) - assert.Equal(t, []byte("1"), vals[1]) - assert.Equal(t, []byte("2"), vals[2]) -} - -func TestChainLedger_GetAccount(t *testing.T) { - repoRoot, err := ioutil.TempDir("", "ledger_commit") - assert.Nil(t, err) - blockStorage, err := leveldb.New(repoRoot) - assert.Nil(t, err) - ledger, err := New(repoRoot, blockStorage, log.NewWithModule("executor")) - assert.Nil(t, err) - - addr := types.Bytes2Address(bytesutil.LeftPadBytes([]byte{1}, 20)) - code := bytesutil.LeftPadBytes([]byte{1}, 120) - key0 := []byte{100, 100} - key1 := []byte{100, 101} - - account := ledger.GetOrCreateAccount(addr) - account.SetBalance(1) - account.SetNonce(2) - account.SetCodeAndHash(code) - - account.SetState(key0, key1) - account.SetState(key1, key0) - - _, err = ledger.Commit(1) - assert.Nil(t, err) - - account1 := ledger.GetAccount(addr) - - assert.Equal(t, account.GetBalance(), ledger.GetBalance(addr)) - assert.Equal(t, account.GetBalance(), account1.GetBalance()) - assert.Equal(t, account.GetNonce(), account1.GetNonce()) - assert.Equal(t, account.CodeHash(), account1.CodeHash()) - assert.Equal(t, account.Code(), account1.Code()) - ok0, val0 := account.GetState(key0) - ok1, val1 := account.GetState(key1) - assert.Equal(t, ok0, ok1) - assert.Equal(t, val0, key1) - assert.Equal(t, val1, key0) - - key2 := []byte{100, 102} - val2 := []byte{111} - ledger.SetState(addr, key0, val0) - ledger.SetState(addr, key2, val2) - ledger.SetState(addr, key0, val1) - _, err = ledger.Commit(2) - assert.Nil(t, err) - - ledger.SetState(addr, key0, val0) - ledger.SetState(addr, key0, val1) - _, err = ledger.Commit(3) - assert.Nil(t, err) - - ok, val := ledger.GetState(addr, key0) - assert.True(t, ok) - assert.Equal(t, val1, val) -} - -func TestChainLedger_GetCode(t *testing.T) { - repoRoot, err := ioutil.TempDir("", "ledger_commit") - assert.Nil(t, err) - blockStorage, err := leveldb.New(repoRoot) - assert.Nil(t, err) - ledger, err := New(repoRoot, blockStorage, log.NewWithModule("executor")) - assert.Nil(t, err) - - addr := types.Bytes2Address(bytesutil.LeftPadBytes([]byte{1}, 20)) - key0 := []byte{100, 100} - key1 := []byte{100, 101} - key2 := []byte{100, 102} - - ledger.SetState(addr, key0, []byte("0")) - ledger.SetState(addr, key1, []byte("1")) - ledger.SetState(addr, key2, []byte("2")) - - code := bytesutil.LeftPadBytes([]byte{10}, 120) - ledger.SetCode(addr, code) - - _, err = ledger.Commit(1) - assert.Nil(t, err) - - vals := ledger.GetCode(addr) - assert.Equal(t, code, vals) - - _, err = ledger.Commit(2) - assert.Nil(t, err) - - vals = ledger.GetCode(addr) - assert.Equal(t, code, vals) -} diff --git a/internal/ledger/types.go b/internal/ledger/types.go index cc02ea5..2c7c96c 100644 --- a/internal/ledger/types.go +++ b/internal/ledger/types.go @@ -10,6 +10,9 @@ type Ledger interface { BlockchainLedger StateAccessor + // PersistBlockData + PersistBlockData(blockData *BlockData) + // AddEvent AddEvent(*pb.Event) @@ -62,7 +65,10 @@ type StateAccessor interface { QueryByPrefix(address types.Address, prefix string) (bool, [][]byte) // Commit commits the state data - Commit(height uint64) (types.Hash, error) + Commit(height uint64, accounts map[string]*Account, blockJournal *BlockJournal) error + + // FlushDirtyDataAndComputeJournal flushes the dirty data and computes block journal + FlushDirtyDataAndComputeJournal() (map[string]*Account, *BlockJournal) // Version Version() uint64