Merge pull request from meshplus/refactor/network-interface

refactor(network): add message handler
This commit is contained in:
Aiden X 2020-04-22 15:49:36 +08:00 committed by GitHub
commit 3e40509dbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 149 additions and 211 deletions

View File

@ -1,4 +1,4 @@
package p2p
package network
import (
"fmt"

View File

@ -1,4 +1,4 @@
package p2p
package network
import (
"context"
@ -8,8 +8,7 @@ import (
ggio "github.com/gogo/protobuf/io"
"github.com/libp2p/go-libp2p-core/network"
net "github.com/meshplus/bitxhub/pkg/network"
"github.com/meshplus/bitxhub/pkg/network/proto"
"github.com/meshplus/bitxhub/pkg/network/pb"
)
// handle newly connected stream
@ -20,41 +19,44 @@ func (p2p *P2P) handleNewStream(s network.Stream) {
}
reader := ggio.NewDelimitedReader(s, network.MessageSizeMax)
for {
msg := &proto.Message{}
msg := &pb.Message{}
if err := reader.ReadMsg(msg); err != nil {
if err != io.EOF {
if err := s.Reset(); err != nil {
p2p.logger.WithField("error", err).Error("Reset stream")
}
}
return
}
p2p.recvQ <- &net.MessageStream{
Message: msg,
Stream: s,
if p2p.handleMessage != nil {
p2p.handleMessage(s, msg.Data)
}
}
}
// waitMsg wait the incoming messages within time duration.
func waitMsg(stream io.Reader, timeout time.Duration) *proto.Message {
func waitMsg(stream io.Reader, timeout time.Duration) *pb.Message {
reader := ggio.NewDelimitedReader(stream, network.MessageSizeMax)
rs := make(chan *proto.Message)
ch := make(chan *pb.Message)
go func() {
msg := &proto.Message{}
msg := &pb.Message{}
if err := reader.ReadMsg(msg); err == nil {
rs <- msg
ch <- msg
} else {
rs <- nil
ch <- nil
}
}()
ctx, cancel := context.WithTimeout(context.Background(), timeout)
select {
case r := <-rs:
case r := <-ch:
cancel()
return r
case <-ctx.Done():
@ -63,7 +65,7 @@ func waitMsg(stream io.Reader, timeout time.Duration) *proto.Message {
}
}
func (p2p *P2P) send(s network.Stream, message *proto.Message) error {
func (p2p *P2P) send(s network.Stream, msg *pb.Message) error {
deadline := time.Now().Add(sendTimeout)
if err := s.SetWriteDeadline(deadline); err != nil {
@ -71,7 +73,7 @@ func (p2p *P2P) send(s network.Stream, message *proto.Message) error {
}
writer := ggio.NewDelimitedWriter(s)
if err := writer.WriteMsg(message); err != nil {
if err := writer.WriteMsg(msg); err != nil {
return fmt.Errorf("write msg: %w", err)
}

11
pkg/network/helper.go Normal file
View File

@ -0,0 +1,11 @@
package network
import (
"github.com/meshplus/bitxhub/pkg/network/pb"
)
func Message(data []byte) *pb.Message {
return &pb.Message{
Data: data,
}
}

View File

@ -3,11 +3,13 @@ package network
import (
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/meshplus/bitxhub/pkg/network/proto"
"github.com/meshplus/bitxhub/pkg/network/pb"
)
type ConnectCallback func(*peer.AddrInfo) error
type MessageHandler func(network.Stream, []byte)
type Network interface {
// Start start the network service.
Start() error
@ -21,21 +23,21 @@ type Network interface {
// Disconnect peer with id
Disconnect(*peer.AddrInfo) error
// SetConnectionCallback Sets the callback after connecting
// SetConnectionCallback sets the callback after connecting
SetConnectCallback(ConnectCallback)
// SetMessageHandler sets message handler
SetMessageHandler(MessageHandler)
// AsyncSend sends message to peer with peer info.
AsyncSend(*peer.AddrInfo, *proto.Message) error
AsyncSend(*peer.AddrInfo, *pb.Message) error
// Send message using existed stream
SendWithStream(s network.Stream, msg *proto.Message) error
SendWithStream(network.Stream, *pb.Message) error
// Send sends message waiting response
Send(*peer.AddrInfo, *proto.Message) (*proto.Message, error)
Send(*peer.AddrInfo, *pb.Message) (*pb.Message, error)
// Broadcast message to all node
Broadcast([]*peer.AddrInfo, *proto.Message) error
// Receive message from the channel
Receive() <-chan *MessageStream
Broadcast([]*peer.AddrInfo, *pb.Message) error
}

View File

@ -1,4 +1,4 @@
package p2p
package network
import (
"context"
@ -10,13 +10,12 @@ import (
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/peerstore"
net "github.com/meshplus/bitxhub/pkg/network"
"github.com/meshplus/bitxhub/pkg/network/proto"
"github.com/meshplus/bitxhub/pkg/network/pb"
ma "github.com/multiformats/go-multiaddr"
"github.com/sirupsen/logrus"
)
var _ net.Network = (*P2P)(nil)
var _ Network = (*P2P)(nil)
var (
connectTimeout = 10 * time.Second
@ -27,16 +26,16 @@ var (
type P2P struct {
config *Config
host host.Host // manage all connections
recvQ chan *net.MessageStream
streamMng *streamMgr
connectCallback net.ConnectCallback
connectCallback ConnectCallback
handleMessage MessageHandler
logger logrus.FieldLogger
ctx context.Context
cancel context.CancelFunc
}
func New(opts ...Option) (net.Network, error) {
func New(opts ...Option) (*P2P, error) {
config, err := generateConfig(opts...)
if err != nil {
return nil, fmt.Errorf("generate config: %w", err)
@ -55,7 +54,6 @@ func New(opts ...Option) (net.Network, error) {
p2p := &P2P{
config: config,
host: h,
recvQ: make(chan *net.MessageStream),
streamMng: newStreamMng(ctx, h, config.protocolID),
logger: config.logger,
ctx: ctx,
@ -92,12 +90,16 @@ func (p2p *P2P) Connect(addr *peer.AddrInfo) error {
return nil
}
func (p2p *P2P) SetConnectCallback(callback net.ConnectCallback) {
func (p2p *P2P) SetConnectCallback(callback ConnectCallback) {
p2p.connectCallback = callback
}
func (p2p *P2P) SetMessageHandler(handler MessageHandler) {
p2p.handleMessage = handler
}
// AsyncSend message to peer with specific id.
func (p2p *P2P) AsyncSend(addr *peer.AddrInfo, msg *proto.Message) error {
func (p2p *P2P) AsyncSend(addr *peer.AddrInfo, msg *pb.Message) error {
s, err := p2p.streamMng.get(addr.ID)
if err != nil {
return fmt.Errorf("get stream: %w", err)
@ -111,11 +113,11 @@ func (p2p *P2P) AsyncSend(addr *peer.AddrInfo, msg *proto.Message) error {
return nil
}
func (p2p *P2P) SendWithStream(s network.Stream, msg *proto.Message) error {
func (p2p *P2P) SendWithStream(s network.Stream, msg *pb.Message) error {
return p2p.send(s, msg)
}
func (p2p *P2P) Send(addr *peer.AddrInfo, msg *proto.Message) (*proto.Message, error) {
func (p2p *P2P) Send(addr *peer.AddrInfo, msg *pb.Message) (*pb.Message, error) {
s, err := p2p.streamMng.get(addr.ID)
if err != nil {
return nil, fmt.Errorf("get stream: %w", err)
@ -134,7 +136,7 @@ func (p2p *P2P) Send(addr *peer.AddrInfo, msg *proto.Message) (*proto.Message, e
return recvMsg, nil
}
func (p2p *P2P) Broadcast(ids []*peer.AddrInfo, msg *proto.Message) error {
func (p2p *P2P) Broadcast(ids []*peer.AddrInfo, msg *pb.Message) error {
for _, id := range ids {
if err := p2p.AsyncSend(id, msg); err != nil {
p2p.logger.WithFields(logrus.Fields{
@ -148,10 +150,6 @@ func (p2p *P2P) Broadcast(ids []*peer.AddrInfo, msg *proto.Message) error {
return nil
}
func (p2p *P2P) Receive() <-chan *net.MessageStream {
return p2p.recvQ
}
// Stop stop the network service.
func (p2p *P2P) Stop() error {
p2p.cancel()

View File

@ -1,4 +1,4 @@
package p2p
package network
import (
"context"
@ -7,11 +7,11 @@ import (
"testing"
"time"
net "github.com/meshplus/bitxhub/pkg/network"
"github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/protocol"
"github.com/meshplus/bitxhub/pkg/network/pb"
"github.com/stretchr/testify/assert"
)
@ -43,6 +43,15 @@ func TestP2P_Send(t *testing.T) {
p1, addr1 := generateNetwork(t, 6005)
p2, addr2 := generateNetwork(t, 6006)
msg := []byte("hello")
ch := make(chan struct{})
p2.SetMessageHandler(func(s network.Stream, data []byte) {
assert.EqualValues(t, msg, data)
close(ch)
})
err := p1.Start()
assert.Nil(t, err)
err = p2.Start()
@ -53,23 +62,18 @@ func TestP2P_Send(t *testing.T) {
err = p2.Connect(addr1)
assert.Nil(t, err)
msg := []byte("hello")
err = p1.AsyncSend(addr2, net.Message(msg))
err = p1.AsyncSend(addr2, &pb.Message{Data: msg})
assert.Nil(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ch := p2.Receive()
for {
select {
case c := <-ch:
assert.EqualValues(t, msg, c.Message.Data)
return
case <-ctx.Done():
assert.Error(t, fmt.Errorf("timeout"))
return
}
select {
case <-ch:
return
case <-ctx.Done():
assert.Error(t, fmt.Errorf("timeout"))
return
}
}
@ -89,12 +93,22 @@ func TestP2p_MultiSend(t *testing.T) {
N := 50
msg := []byte("hello")
ch := p2.Receive()
count := 0
ch := make(chan struct{})
p2.SetMessageHandler(func(s network.Stream, data []byte) {
assert.EqualValues(t, msg, data)
count++
if count == N {
close(ch)
return
}
})
go func() {
for i := 0; i < N; i++ {
time.Sleep(200 * time.Microsecond)
err = p1.AsyncSend(addr2, net.Message(msg))
err = p1.AsyncSend(addr2, &pb.Message{Data: msg})
assert.Nil(t, err)
}
@ -102,22 +116,16 @@ func TestP2p_MultiSend(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
count := 0
for {
select {
case c := <-ch:
assert.EqualValues(t, msg, c.Message.Data)
case <-ctx.Done():
assert.Error(t, fmt.Errorf("timeout"))
}
count++
if count == N {
return
}
select {
case <-ch:
return
case <-ctx.Done():
assert.Error(t, fmt.Errorf("timeout"))
}
}
func generateNetwork(t *testing.T, port int) (net.Network, *peer.AddrInfo) {
func generateNetwork(t *testing.T, port int) (Network, *peer.AddrInfo) {
privKey, pubKey, err := crypto.GenerateECDSAKeyPair(rand.Reader)
assert.Nil(t, err)

View File

@ -6,7 +6,7 @@ help: Makefile
## make pb: build network message protobuf
proto:
cd proto && protoc -I=. \
protoc -I=. \
-I${GOPATH}/src \
-I${GOPATH}/src/github.com/gogo/protobuf/protobuf \
--gogofast_out=:. network.proto

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: network.proto
package proto
package pb
import (
fmt "fmt"
@ -23,8 +23,7 @@ var _ = math.Inf
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
type Message struct {
Data []byte `protobuf:"bytes,1,opt,name=Data,proto3" json:"Data,omitempty"`
Version string `protobuf:"bytes,2,opt,name=Version,proto3" json:"Version,omitempty"`
Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
@ -70,28 +69,20 @@ func (m *Message) GetData() []byte {
return nil
}
func (m *Message) GetVersion() string {
if m != nil {
return m.Version
}
return ""
}
func init() {
proto.RegisterType((*Message)(nil), "proto.Message")
proto.RegisterType((*Message)(nil), "pb.Message")
}
func init() { proto.RegisterFile("network.proto", fileDescriptor_8571034d60397816) }
var fileDescriptor_8571034d60397816 = []byte{
// 110 bytes of a gzipped FileDescriptorProto
// 91 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0xcd, 0x4b, 0x2d, 0x29,
0xcf, 0x2f, 0xca, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x05, 0x53, 0x4a, 0xe6, 0x5c,
0xec, 0xbe, 0xa9, 0xc5, 0xc5, 0x89, 0xe9, 0xa9, 0x42, 0x42, 0x5c, 0x2c, 0x2e, 0x89, 0x25, 0x89,
0x12, 0x8c, 0x0a, 0x8c, 0x1a, 0x3c, 0x41, 0x60, 0xb6, 0x90, 0x04, 0x17, 0x7b, 0x58, 0x6a, 0x51,
0x71, 0x66, 0x7e, 0x9e, 0x04, 0x93, 0x02, 0xa3, 0x06, 0x67, 0x10, 0x8c, 0xeb, 0xc4, 0x73, 0xe2,
0x91, 0x1c, 0xe3, 0x85, 0x47, 0x72, 0x8c, 0x0f, 0x1e, 0xc9, 0x31, 0x26, 0xb1, 0x81, 0x4d, 0x33,
0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0xf4, 0xb2, 0xbf, 0x15, 0x65, 0x00, 0x00, 0x00,
0xcf, 0x2f, 0xca, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2a, 0x48, 0x52, 0x92, 0xe5,
0x62, 0xf7, 0x4d, 0x2d, 0x2e, 0x4e, 0x4c, 0x4f, 0x15, 0x12, 0xe2, 0x62, 0x49, 0x49, 0x2c, 0x49,
0x94, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x09, 0x02, 0xb3, 0x9d, 0x78, 0x4e, 0x3c, 0x92, 0x63, 0xbc,
0xf0, 0x48, 0x8e, 0xf1, 0xc1, 0x23, 0x39, 0xc6, 0x24, 0x36, 0xb0, 0x3e, 0x63, 0x40, 0x00, 0x00,
0x00, 0xff, 0xff, 0x0e, 0xd4, 0xfc, 0x1a, 0x48, 0x00, 0x00, 0x00,
}
func (m *Message) Marshal() (dAtA []byte, err error) {
@ -118,13 +109,6 @@ func (m *Message) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i -= len(m.XXX_unrecognized)
copy(dAtA[i:], m.XXX_unrecognized)
}
if len(m.Version) > 0 {
i -= len(m.Version)
copy(dAtA[i:], m.Version)
i = encodeVarintNetwork(dAtA, i, uint64(len(m.Version)))
i--
dAtA[i] = 0x12
}
if len(m.Data) > 0 {
i -= len(m.Data)
copy(dAtA[i:], m.Data)
@ -156,10 +140,6 @@ func (m *Message) Size() (n int) {
if l > 0 {
n += 1 + l + sovNetwork(uint64(l))
}
l = len(m.Version)
if l > 0 {
n += 1 + l + sovNetwork(uint64(l))
}
if m.XXX_unrecognized != nil {
n += len(m.XXX_unrecognized)
}
@ -235,38 +215,6 @@ func (m *Message) Unmarshal(dAtA []byte) error {
m.Data = []byte{}
}
iNdEx = postIndex
case 2:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Version", wireType)
}
var stringLen uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowNetwork
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
stringLen |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
intStringLen := int(stringLen)
if intStringLen < 0 {
return ErrInvalidLengthNetwork
}
postIndex := iNdEx + intStringLen
if postIndex < 0 {
return ErrInvalidLengthNetwork
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Version = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipNetwork(dAtA[iNdEx:])

View File

@ -0,0 +1,7 @@
syntax = "proto3";
package pb;
message Message {
bytes data = 1;
}

View File

@ -1,8 +0,0 @@
syntax = "proto3";
package proto;
message Message {
bytes Data = 1;
string Version = 2;
}

View File

@ -1,4 +1,4 @@
package p2p
package network
import (
"context"

View File

@ -1,18 +0,0 @@
package network
import (
"github.com/libp2p/go-libp2p-core/network"
"github.com/meshplus/bitxhub/pkg/network/proto"
)
type MessageStream struct {
Message *proto.Message
Stream network.Stream
}
func Message(data []byte) *proto.Message {
return &proto.Message{
Data: data,
Version: "1.0",
}
}

View File

@ -5,36 +5,46 @@ import (
"fmt"
"strconv"
"github.com/meshplus/bitxhub/internal/model"
"github.com/libp2p/go-libp2p-core/network"
network2 "github.com/libp2p/go-libp2p-core/network"
"github.com/meshplus/bitxhub-model/pb"
"github.com/meshplus/bitxhub/internal/model"
"github.com/meshplus/bitxhub/internal/model/events"
"github.com/meshplus/bitxhub/pkg/cert"
proto2 "github.com/meshplus/bitxhub/pkg/network/proto"
"github.com/sirupsen/logrus"
)
func (swarm *Swarm) handleMessage(s network2.Stream, msg *proto2.Message) error {
func (swarm *Swarm) handleMessage(s network.Stream, data []byte) {
m := &pb.Message{}
if err := m.Unmarshal(msg.Data); err != nil {
return err
if err := m.Unmarshal(data); err != nil {
swarm.logger.Error(err)
return
}
switch m.Type {
case pb.Message_GET_BLOCK:
return swarm.handleGetBlockPack(s, m)
case pb.Message_FETCH_CERT:
return swarm.handleFetchCertMessage(s)
case pb.Message_CONSENSUS:
go swarm.orderMessageFeed.Send(events.OrderMessageEvent{Data: m.Data})
case pb.Message_FETCH_BLOCK_SIGN:
swarm.handleFetchBlockSignMessage(s, m.Data)
default:
swarm.logger.WithField("module", "p2p").Errorf("can't handle msg[type: %v]", m.Type)
handler := func() error {
switch m.Type {
case pb.Message_GET_BLOCK:
return swarm.handleGetBlockPack(s, m)
case pb.Message_FETCH_CERT:
return swarm.handleFetchCertMessage(s)
case pb.Message_CONSENSUS:
go swarm.orderMessageFeed.Send(events.OrderMessageEvent{Data: m.Data})
case pb.Message_FETCH_BLOCK_SIGN:
swarm.handleFetchBlockSignMessage(s, m.Data)
default:
swarm.logger.WithField("module", "p2p").Errorf("can't handle msg[type: %v]", m.Type)
return nil
}
return nil
}
return nil
if err := handler(); err != nil {
swarm.logger.WithFields(logrus.Fields{
"error": err,
"type": m.Type.String(),
}).Error("Handle message")
}
}
func (swarm *Swarm) handleGetBlockPack(s network.Stream, msg *pb.Message) error {

View File

@ -18,7 +18,6 @@ import (
"github.com/meshplus/bitxhub/internal/repo"
"github.com/meshplus/bitxhub/pkg/cert"
"github.com/meshplus/bitxhub/pkg/network"
"github.com/meshplus/bitxhub/pkg/network/p2p"
"github.com/sirupsen/logrus"
)
@ -41,11 +40,11 @@ type Swarm struct {
}
func New(repo *repo.Repo, logger logrus.FieldLogger, ledger ledger.Ledger) (*Swarm, error) {
p2p, err := p2p.New(
p2p.WithLocalAddr(repo.NetworkConfig.LocalAddr),
p2p.WithPrivateKey(repo.Key.Libp2pPrivKey),
p2p.WithProtocolID(protocolID),
p2p.WithLogger(logger),
p2p, err := network.New(
network.WithLocalAddr(repo.NetworkConfig.LocalAddr),
network.WithPrivateKey(repo.Key.Libp2pPrivKey),
network.WithProtocolID(protocolID),
network.WithLogger(logger),
)
if err != nil {
@ -67,12 +66,12 @@ func New(repo *repo.Repo, logger logrus.FieldLogger, ledger ledger.Ledger) (*Swa
}
func (swarm *Swarm) Start() error {
swarm.p2p.SetMessageHandler(swarm.handleMessage)
if err := swarm.p2p.Start(); err != nil {
return err
}
go swarm.receiveMessage()
for id, addr := range swarm.peers {
go func(id uint64, addr *peer.AddrInfo) {
if err := retry.Retry(func(attempt uint) error {
@ -129,9 +128,7 @@ func (swarm *Swarm) AsyncSend(id uint64, msg *pb.Message) error {
return err
}
m := network.Message(data)
return swarm.p2p.AsyncSend(swarm.peers[id], m)
return swarm.p2p.AsyncSend(swarm.peers[id], network.Message(data))
}
func (swarm *Swarm) SendWithStream(s network2.Stream, msg *pb.Message) error {
@ -140,9 +137,7 @@ func (swarm *Swarm) SendWithStream(s network2.Stream, msg *pb.Message) error {
return err
}
m := network.Message(data)
return swarm.p2p.SendWithStream(s, m)
return swarm.p2p.SendWithStream(s, network.Message(data))
}
func (swarm *Swarm) Send(id uint64, msg *pb.Message) (*pb.Message, error) {
@ -179,9 +174,7 @@ func (swarm *Swarm) Broadcast(msg *pb.Message) error {
return err
}
m := network.Message(data)
return swarm.p2p.Broadcast(addrs, m)
return swarm.p2p.Broadcast(addrs, network.Message(data))
}
func (swarm *Swarm) Peers() map[uint64]*peer.AddrInfo {
@ -252,18 +245,3 @@ func (swarm *Swarm) checkID(id uint64) error {
return nil
}
func (swarm *Swarm) receiveMessage() {
for {
select {
case m := <-swarm.p2p.Receive():
go func() {
if err := swarm.handleMessage(m.Stream, m.Message); err != nil {
swarm.logger.WithField("error", err).Error("Handle message")
}
}()
case <-swarm.ctx.Done():
return
}
}
}