bitxhub/internal/ledger/account.go

400 lines
8.6 KiB
Go

package ledger
import (
"bytes"
"crypto/sha256"
"encoding/json"
"math/big"
"sort"
"strings"
"sync"
"github.com/ethereum/go-ethereum/crypto"
"github.com/meshplus/bitxhub-kit/storage"
"github.com/meshplus/bitxhub-kit/types"
)
type Account struct {
Addr *types.Address
originAccount *innerAccount
dirtyAccount *innerAccount
originState sync.Map
dirtyState sync.Map
originCode []byte
dirtyCode []byte
dirtyStateHash *types.Hash
ldb storage.Storage
cache *AccountCache
lock sync.RWMutex
changer *stateChanger
suicided bool
}
type innerAccount struct {
Nonce uint64 `json:"nonce"`
Balance *big.Int `json:"balance"`
CodeHash []byte `json:"code_hash"`
}
func newAccount(ldb storage.Storage, cache *AccountCache, addr *types.Address, changer *stateChanger) *Account {
return &Account{
Addr: addr,
ldb: ldb,
cache: cache,
changer: changer,
suicided: false,
}
}
// GetState Get state from local cache, if not found, then get it from DB
func (o *Account) GetState(key []byte) (bool, []byte) {
if val, exist := o.dirtyState.Load(string(key)); exist {
value := val.([]byte)
return value != nil, value
}
if val, exist := o.originState.Load(string(key)); exist {
value := val.([]byte)
return value != nil, value
}
val, ok := o.cache.getState(o.Addr, string(key))
if !ok {
val = o.ldb.Get(composeStateKey(o.Addr, key))
}
o.originState.Store(string(key), val)
return val != nil, val
}
func (o *Account) GetCommittedState(key []byte) []byte {
if val, exist := o.originState.Load(string(key)); exist {
value := val.([]byte)
if val != nil {
return (&types.Hash{}).Bytes()
}
return value
}
val, ok := o.cache.getState(o.Addr, string(key))
if !ok {
val = o.ldb.Get(composeStateKey(o.Addr, key))
}
o.originState.Store(string(key), val)
if val != nil {
return (&types.Hash{}).Bytes()
}
return val
}
// SetState Set account state
func (o *Account) SetState(key []byte, value []byte) {
o.GetState(key)
o.dirtyState.Store(string(key), value)
}
// AddState Add account state
func (o *Account) AddState(key []byte, value []byte) {
o.dirtyState.Store(string(key), value)
}
// SetCodeAndHash Set the contract code and hash
func (o *Account) SetCodeAndHash(code []byte) {
ret := crypto.Keccak256Hash(code)
o.changer.append(codeChange{
account: o.Addr,
prevcode: o.Code(),
})
if o.dirtyAccount == nil {
o.dirtyAccount = copyOrNewIfEmpty(o.originAccount)
}
o.dirtyAccount.CodeHash = ret.Bytes()
o.dirtyCode = code
}
// Code return the contract code
func (o *Account) Code() []byte {
o.lock.Lock()
defer o.lock.Unlock()
if o.dirtyCode != nil {
return o.dirtyCode
}
if o.originCode != nil {
return o.originCode
}
if bytes.Equal(o.CodeHash(), nil) {
return nil
}
code, ok := o.cache.getCode(o.Addr)
if !ok {
code = o.ldb.Get(compositeKey(codeKey, o.Addr))
}
o.originCode = code
o.dirtyCode = code
return code
}
func (o *Account) CodeHash() []byte {
if o.dirtyAccount != nil {
return o.dirtyAccount.CodeHash
}
if o.originAccount != nil {
return o.originAccount.CodeHash
}
return nil
}
// SetNonce Set the nonce which indicates the contract number
func (o *Account) SetNonce(nonce uint64) {
o.changer.append(nonceChange{
account: o.Addr,
prev: o.GetNonce(),
})
if o.dirtyAccount == nil {
o.dirtyAccount = copyOrNewIfEmpty(o.originAccount)
}
o.dirtyAccount.Nonce = nonce
}
// GetNonce Get the nonce from user account
func (o *Account) GetNonce() uint64 {
if o.dirtyAccount != nil {
return o.dirtyAccount.Nonce
}
if o.originAccount != nil {
return o.originAccount.Nonce
}
return 0
}
// GetBalance Get the balance from the account
func (o *Account) GetBalance() *big.Int {
if o.dirtyAccount != nil {
return o.dirtyAccount.Balance
}
if o.originAccount != nil {
return o.originAccount.Balance
}
return new(big.Int).SetInt64(0)
}
// SetBalance Set the balance to the account
func (o *Account) SetBalance(balance *big.Int) {
o.changer.append(balanceChange{
account: o.Addr,
prev: new(big.Int).Set(o.GetBalance()),
})
if o.dirtyAccount == nil {
o.dirtyAccount = copyOrNewIfEmpty(o.originAccount)
}
o.dirtyAccount.Balance = balance
}
func (o *Account) SubBalance(amount *big.Int) {
if amount.Sign() == 0 {
return
}
o.SetBalance(new(big.Int).Sub(o.GetBalance(), amount))
}
func (o *Account) AddBalance(amount *big.Int) {
if amount.Sign() == 0 {
return
}
o.SetBalance(new(big.Int).Add(o.GetBalance(), amount))
}
// Query Query the value using key
func (o *Account) Query(prefix string) (bool, [][]byte) {
var ret [][]byte
stored := make(map[string][]byte)
cached := o.cache.query(o.Addr, prefix)
begin, end := bytesPrefix(append(o.Addr.Bytes(), prefix...))
it := o.ldb.Iterator(begin, end)
for it.Next() {
key := make([]byte, len(it.Key()))
val := make([]byte, len(it.Value()))
copy(key, it.Key())
copy(val, it.Value())
stored[string(key)] = val
}
for key, val := range cached {
stored[key] = val
}
o.dirtyState.Range(func(key, value interface{}) bool {
if strings.HasPrefix(key.(string), prefix) {
stored[key.(string)] = value.([]byte)
}
return true
})
for _, val := range stored {
ret = append(ret, val)
}
sort.Slice(ret, func(i, j int) bool {
return bytes.Compare(ret[i], ret[j]) < 0
})
return len(ret) != 0, ret
}
func (o *Account) getJournalIfModified() *journal {
entry := &journal{Address: o.Addr}
if innerAccountChanged(o.originAccount, o.dirtyAccount) {
entry.AccountChanged = true
entry.PrevAccount = o.originAccount
}
if o.originCode == nil && !(o.originAccount == nil || o.originAccount.CodeHash == nil) {
o.originCode = o.ldb.Get(compositeKey(codeKey, o.Addr))
}
if !bytes.Equal(o.originCode, o.dirtyCode) {
entry.CodeChanged = true
entry.PrevCode = o.originCode
}
prevStates := o.getStateJournalAndComputeHash()
if len(prevStates) != 0 {
entry.PrevStates = prevStates
}
if entry.AccountChanged || entry.CodeChanged || len(entry.PrevStates) != 0 {
return entry
}
return nil
}
func (o *Account) getStateJournalAndComputeHash() map[string][]byte {
prevStates := make(map[string][]byte)
var dirtyStateKeys []string
var dirtyStateData []byte
o.dirtyState.Range(func(key, value interface{}) bool {
origVal, ok := o.originState.Load(key)
var origValBytes []byte
if ok {
origValBytes = origVal.([]byte)
}
valBytes := value.([]byte)
if !bytes.Equal(origValBytes, valBytes) {
prevStates[key.(string)] = origValBytes
dirtyStateKeys = append(dirtyStateKeys, key.(string))
}
return true
})
sort.Strings(dirtyStateKeys)
for _, key := range dirtyStateKeys {
dirtyStateData = append(dirtyStateData, key...)
dirtyVal, _ := o.dirtyState.Load(key)
dirtyStateData = append(dirtyStateData, dirtyVal.([]byte)...)
}
hash := sha256.Sum256(dirtyStateData)
o.dirtyStateHash = types.NewHash(hash[:])
return prevStates
}
func (o *Account) getDirtyData() []byte {
var dirtyData []byte
dirtyData = append(dirtyData, o.Addr.Bytes()...)
if o.dirtyAccount != nil {
data, err := o.dirtyAccount.Marshal()
if err != nil {
panic(err)
}
dirtyData = append(dirtyData, data...)
}
return append(dirtyData, o.dirtyStateHash.Bytes()...)
}
func (o *Account) markSuicided() {
o.suicided = true
}
func (o *Account) isEmpty() bool {
return o.GetBalance().Sign() == 0 && o.GetNonce() == 0 && o.Code() == nil && o.suicided == false
}
func innerAccountChanged(account0 *innerAccount, account1 *innerAccount) bool {
// If account1 is nil, the account does not change whatever account0 is.
if account1 == nil {
return false
}
// If account already exists, account0 is not nil. We should compare account0 and account1 to get the result.
if account0 != nil &&
account0.Nonce == account1.Nonce &&
account0.Balance.Cmp(account1.Balance) == 0 &&
bytes.Equal(account0.CodeHash, account1.CodeHash) {
return false
}
return true
}
// Marshal Marshal the account into byte
func (o *innerAccount) Marshal() ([]byte, error) {
obj := &innerAccount{
Nonce: o.Nonce,
Balance: o.Balance,
CodeHash: o.CodeHash,
}
return json.Marshal(obj)
}
// Unmarshal Unmarshal the account byte into structure
func (o *innerAccount) Unmarshal(data []byte) error {
return json.Unmarshal(data, o)
}
func copyOrNewIfEmpty(o *innerAccount) *innerAccount {
if o == nil {
return &innerAccount{Balance: big.NewInt(0)}
}
return &innerAccount{
Nonce: o.Nonce,
Balance: o.Balance,
CodeHash: o.CodeHash,
}
}
func bytesPrefix(prefix []byte) ([]byte, []byte) {
var limit []byte
for i := len(prefix) - 1; i >= 0; i-- {
c := prefix[i]
if c < 0xff {
limit = make([]byte, i+1)
copy(limit, prefix)
limit[i] = c + 1
break
}
}
return prefix, limit
}