// bitxhub/internal/ledger/state_accessor.go
//
// 279 lines
// 6.9 KiB
// Go

package ledger
import (
"bytes"
"crypto/sha256"
"encoding/json"
"sort"
"sync"
"github.com/meshplus/bitxhub-kit/types"
)
// Compile-time assertion that *ChainLedger implements the Ledger interface.
var _ Ledger = (*ChainLedger)(nil)
// GetOrCreateAccount returns the account for addr from the in-memory
// l.accounts map. On a miss it builds one with GetAccount and stores it in the
// map before returning it. GetAccount yields an account without origin data
// when nothing is cached or stored.
//
// The map is probed first under the read lock (fast path). It is probed again
// under the write lock, so concurrent callers asking for the same address all
// receive the same *Account.
func (l *ChainLedger) GetOrCreateAccount(addr *types.Address) *Account {
	key := addr.String()

	l.lock.RLock()
	cached, found := l.accounts[key]
	l.lock.RUnlock()
	if found {
		return cached
	}

	l.lock.Lock()
	defer l.lock.Unlock()

	// Another goroutine may have inserted the account between the two locks.
	if cached, found = l.accounts[key]; found {
		return cached
	}

	created := l.GetAccount(addr)
	l.accounts[key] = created
	return created
}
// GetAccount builds an Account for address without touching l.accounts.
//
// The account's origin data is taken from the account cache when present.
// Otherwise it is loaded from the ledger database. If neither has an entry,
// the freshly constructed account is returned with no origin account set.
// It panics if the stored account bytes cannot be unmarshaled.
func (l *ChainLedger) GetAccount(address *types.Address) *Account {
	acc := newAccount(l.ldb, l.accountCache, address)

	if cached, hit := l.accountCache.getInnerAccount(address); hit {
		acc.originAccount = cached
		return acc
	}

	raw := l.ldb.Get(compositeKey(accountKey, address))
	if raw == nil {
		return acc
	}

	acc.originAccount = &innerAccount{}
	if err := acc.originAccount.Unmarshal(raw); err != nil {
		panic(err)
	}
	return acc
}
// GetBalance returns the balance of the account at addr. It goes through
// GetOrCreateAccount, so the account is created and cached if it is not
// already tracked.
func (l *ChainLedger) GetBalance(addr *types.Address) uint64 {
	return l.GetOrCreateAccount(addr).GetBalance()
}
// SetBalance sets the balance of the account at addr to value. The account is
// obtained via GetOrCreateAccount.
func (l *ChainLedger) SetBalance(addr *types.Address, value uint64) {
	l.GetOrCreateAccount(addr).SetBalance(value)
}
// GetState looks up key in the state of the account at addr and returns the
// result of Account.GetState. The boolean presumably reports whether the key
// exists; confirm against Account.GetState.
func (l *ChainLedger) GetState(addr *types.Address, key []byte) (bool, []byte) {
	return l.GetOrCreateAccount(addr).GetState(key)
}
// SetState stores v under key in the state of the account at addr, delegating
// to Account.SetState.
func (l *ChainLedger) SetState(addr *types.Address, key []byte, v []byte) {
	l.GetOrCreateAccount(addr).SetState(key, v)
}
// AddState adds v under key to the state of the account at addr, delegating to
// Account.AddState. See Account.AddState for how it differs from SetState.
func (l *ChainLedger) AddState(addr *types.Address, key []byte, v []byte) {
	l.GetOrCreateAccount(addr).AddState(key, v)
}
// SetCode installs code as the contract code of the account at addr via
// Account.SetCodeAndHash, which presumably also updates the code hash.
func (l *ChainLedger) SetCode(addr *types.Address, code []byte) {
	l.GetOrCreateAccount(addr).SetCodeAndHash(code)
}
// GetCode returns the contract code of the account at addr, as reported by
// Account.Code.
func (l *ChainLedger) GetCode(addr *types.Address) []byte {
	return l.GetOrCreateAccount(addr).Code()
}
// GetNonce returns the nonce of the account at addr.
func (l *ChainLedger) GetNonce(addr *types.Address) uint64 {
	return l.GetOrCreateAccount(addr).GetNonce()
}
// SetNonce sets the nonce of the account at addr to nonce.
func (l *ChainLedger) SetNonce(addr *types.Address, nonce uint64) {
	l.GetOrCreateAccount(addr).SetNonce(nonce)
}
// QueryByPrefix returns the state values of the account at addr whose keys
// start with prefix, as produced by Account.Query. The boolean presumably
// reports whether anything matched; confirm against Account.Query.
func (l *ChainLedger) QueryByPrefix(addr *types.Address, prefix string) (bool, [][]byte) {
	return l.GetOrCreateAccount(addr).Query(prefix)
}
// Clear discards the in-memory account map and the recorded events by
// replacing both with empty containers.
//
// NOTE(review): unlike GetOrCreateAccount, this does not take l.lock. Callers
// must ensure no concurrent account access is in progress.
func (l *ChainLedger) Clear() {
	l.accounts = make(map[string]*Account)
	l.events = sync.Map{}
}
// FlushDirtyDataAndComputeJournal collects every account in l.accounts that
// reports a journal (i.e. was modified), builds the block journal for them,
// clears the in-memory accounts/events and hands the dirty accounts to the
// account cache.
//
// It returns the dirty accounts keyed by address string, and a BlockJournal
// whose ChangedHash is SHA-256 over:
//   - the dirty data of each changed account, in ascending address order,
//   - followed by the previous journal hash.
//
// That hash becomes the new l.prevJnlHash. The Journals slice is emitted in
// the same ascending address order, so the journal does not depend on map
// iteration order. It is nil when nothing changed.
func (l *ChainLedger) FlushDirtyDataAndComputeJournal() (map[string]*Account, *BlockJournal) {
	dirtyAccounts := make(map[string]*Account)
	dirtyJournals := make(map[string]*journal)
	dirtyData := make(map[string][]byte)
	var sortedAddr []string

	for addr, account := range l.accounts {
		jnl := account.getJournalIfModified()
		if jnl == nil {
			continue
		}
		dirtyJournals[addr] = jnl
		dirtyData[addr] = account.getDirtyData()
		dirtyAccounts[addr] = account
		sortedAddr = append(sortedAddr, addr)
	}
	// Map iteration order is random. Sort so the hash input and the journal
	// order are deterministic.
	sort.Strings(sortedAddr)

	var journals []*journal
	// hash.Hash.Write never returns an error, so the results are ignored.
	hasher := sha256.New()
	for _, addr := range sortedAddr {
		journals = append(journals, dirtyJournals[addr])
		hasher.Write(dirtyData[addr])
	}
	hasher.Write(l.prevJnlHash.Bytes())

	blockJournal := &BlockJournal{
		Journals:    journals,
		ChangedHash: types.NewHash(hasher.Sum(nil)),
	}
	l.prevJnlHash = blockJournal.ChangedHash

	l.Clear()
	l.accountCache.add(dirtyAccounts)
	return dirtyAccounts, blockJournal
}
// Commit writes the changes of the given dirty accounts, together with the
// block journal for height, to the ledger database in a single batch. It then
// records height as the newest journaled height.
//
// For every account it stages:
//   - the account record, if innerAccountChanged reports a change;
//   - the contract code, if it differs from the original (deleted when the
//     new code is nil);
//   - every state entry whose dirty value differs from the original (deleted
//     when the new value is nil).
//
// When no minimum journal height is recorded yet, height is stored as the
// minimum too.
//
// accounts and blockJournal are presumably the values returned by
// FlushDirtyDataAndComputeJournal; confirm against callers. If the journal or
// an account cannot be serialized, that error is returned and the batch is
// not committed.
func (l *ChainLedger) Commit(height uint64, accounts map[string]*Account, blockJournal *BlockJournal) error {
	// Serialize the journal first so a failure aborts before any write is staged.
	jnlData, err := json.Marshal(blockJournal)
	if err != nil {
		return err
	}

	ldbBatch := l.ldb.NewBatch()
	for _, account := range accounts {
		if innerAccountChanged(account.originAccount, account.dirtyAccount) {
			accData, marshalErr := account.dirtyAccount.Marshal()
			if marshalErr != nil {
				// Report via the error return instead of panicking; the batch
				// has not been committed, so nothing is persisted.
				return marshalErr
			}
			ldbBatch.Put(compositeKey(accountKey, account.Addr), accData)
		}

		if !bytes.Equal(account.originCode, account.dirtyCode) {
			if account.dirtyCode != nil {
				ldbBatch.Put(compositeKey(codeKey, account.Addr), account.dirtyCode)
			} else {
				ldbBatch.Delete(compositeKey(codeKey, account.Addr))
			}
		}

		account.dirtyState.Range(func(key, value interface{}) bool {
			valBytes := value.([]byte)
			var origValBytes []byte
			if origVal, ok := account.originState.Load(key); ok {
				origValBytes = origVal.([]byte)
			}
			if bytes.Equal(origValBytes, valBytes) {
				return true
			}
			stateKey := composeStateKey(account.Addr, []byte(key.(string)))
			if valBytes != nil {
				ldbBatch.Put(stateKey, valBytes)
			} else {
				ldbBatch.Delete(stateKey)
			}
			return true
		})
	}

	ldbBatch.Put(compositeKey(journalKey, height), jnlData)
	ldbBatch.Put(compositeKey(journalKey, maxHeightStr), marshalHeight(height))

	l.journalMutex.Lock()
	// minJnlHeight is the lower bound checked by rollbackState; set it on the
	// first commit after it was unset.
	if l.minJnlHeight == 0 {
		l.minJnlHeight = height
		ldbBatch.Put(compositeKey(journalKey, minHeightStr), marshalHeight(height))
	}
	ldbBatch.Commit()
	l.maxJnlHeight = height
	l.journalMutex.Unlock()

	l.accountCache.remove(accounts)
	return nil
}
// Version reports the current ledger version: the value of maxJnlHeight,
// which Commit sets to the latest committed height and rollbackState lowers.
// The read is performed under journalMutex.
func (l *ChainLedger) Version() uint64 {
	l.journalMutex.RLock()
	version := l.maxJnlHeight
	l.journalMutex.RUnlock()
	return version
}
// rollbackState reverts the persisted state to what it was right after
// block height was committed. It replays the stored block journals from
// maxJnlHeight down to height+1 in reverse, deleting each journal as it goes.
//
// Errors:
//   - ErrorRollbackToHigherNumber if height is above the current version.
//   - ErrorRollbackTooMuch if height is below the oldest journal. Height 0 is
//     still allowed when journaling started at height 1.
//   - ErrorRollbackWithoutJournal if a needed journal is missing.
//
// Each height is reverted in its own batch, and l.maxJnlHeight is advanced
// after every commit. A mid-way failure therefore leaves the in-memory height
// consistent with the persisted max height.
func (l *ChainLedger) rollbackState(height uint64) error {
	l.journalMutex.Lock()
	defer l.journalMutex.Unlock()

	if l.maxJnlHeight < height {
		return ErrorRollbackToHigherNumber
	}
	if l.minJnlHeight > height && !(l.minJnlHeight == 1 && height == 0) {
		return ErrorRollbackTooMuch
	}
	if l.maxJnlHeight == height {
		return nil
	}

	// Resolve the journal hash of the target height before touching anything.
	// A missing journal then aborts without side effects instead of causing a
	// nil dereference after the state has already been reverted.
	prevJnlHash := &types.Hash{}
	if height != 0 {
		targetJnl := getBlockJournal(height, l.ldb)
		if targetJnl == nil {
			return ErrorRollbackWithoutJournal
		}
		prevJnlHash = targetJnl.ChangedHash
	}

	// Drop cached accounts; they reflect state newer than the target height.
	l.Clear()
	l.accountCache.clear()

	for i := l.maxJnlHeight; i > height; i-- {
		blockJournal := getBlockJournal(i, l.ldb)
		if blockJournal == nil {
			return ErrorRollbackWithoutJournal
		}
		batch := l.ldb.NewBatch()
		for _, jnl := range blockJournal.Journals {
			jnl.revert(batch)
		}
		batch.Delete(compositeKey(journalKey, i))
		batch.Put(compositeKey(journalKey, maxHeightStr), marshalHeight(i-1))
		batch.Commit()
		// Keep the in-memory height in step with the persisted max height.
		l.maxJnlHeight = i - 1
	}

	l.prevJnlHash = prevJnlHash
	if height == 0 {
		l.minJnlHeight = 0
	}
	return nil
}