Merge pull request #43 from meshplus/refactor/network-interface
refactor(network): add message handler
This commit is contained in:
commit
3e40509dbe
|
@ -1,4 +1,4 @@
|
|||
package p2p
|
||||
package network
|
||||
|
||||
import (
|
||||
"fmt"
|
|
@ -1,4 +1,4 @@
|
|||
package p2p
|
||||
package network
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -8,8 +8,7 @@ import (
|
|||
|
||||
ggio "github.com/gogo/protobuf/io"
|
||||
"github.com/libp2p/go-libp2p-core/network"
|
||||
net "github.com/meshplus/bitxhub/pkg/network"
|
||||
"github.com/meshplus/bitxhub/pkg/network/proto"
|
||||
"github.com/meshplus/bitxhub/pkg/network/pb"
|
||||
)
|
||||
|
||||
// handle newly connected stream
|
||||
|
@ -20,41 +19,44 @@ func (p2p *P2P) handleNewStream(s network.Stream) {
|
|||
}
|
||||
|
||||
reader := ggio.NewDelimitedReader(s, network.MessageSizeMax)
|
||||
|
||||
for {
|
||||
msg := &proto.Message{}
|
||||
msg := &pb.Message{}
|
||||
if err := reader.ReadMsg(msg); err != nil {
|
||||
if err != io.EOF {
|
||||
if err := s.Reset(); err != nil {
|
||||
p2p.logger.WithField("error", err).Error("Reset stream")
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
p2p.recvQ <- &net.MessageStream{
|
||||
Message: msg,
|
||||
Stream: s,
|
||||
if p2p.handleMessage != nil {
|
||||
p2p.handleMessage(s, msg.Data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// waitMsg wait the incoming messages within time duration.
|
||||
func waitMsg(stream io.Reader, timeout time.Duration) *proto.Message {
|
||||
func waitMsg(stream io.Reader, timeout time.Duration) *pb.Message {
|
||||
reader := ggio.NewDelimitedReader(stream, network.MessageSizeMax)
|
||||
rs := make(chan *proto.Message)
|
||||
|
||||
ch := make(chan *pb.Message)
|
||||
|
||||
go func() {
|
||||
msg := &proto.Message{}
|
||||
msg := &pb.Message{}
|
||||
if err := reader.ReadMsg(msg); err == nil {
|
||||
rs <- msg
|
||||
ch <- msg
|
||||
} else {
|
||||
rs <- nil
|
||||
ch <- nil
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
|
||||
select {
|
||||
case r := <-rs:
|
||||
case r := <-ch:
|
||||
cancel()
|
||||
return r
|
||||
case <-ctx.Done():
|
||||
|
@ -63,7 +65,7 @@ func waitMsg(stream io.Reader, timeout time.Duration) *proto.Message {
|
|||
}
|
||||
}
|
||||
|
||||
func (p2p *P2P) send(s network.Stream, message *proto.Message) error {
|
||||
func (p2p *P2P) send(s network.Stream, msg *pb.Message) error {
|
||||
deadline := time.Now().Add(sendTimeout)
|
||||
|
||||
if err := s.SetWriteDeadline(deadline); err != nil {
|
||||
|
@ -71,7 +73,7 @@ func (p2p *P2P) send(s network.Stream, message *proto.Message) error {
|
|||
}
|
||||
|
||||
writer := ggio.NewDelimitedWriter(s)
|
||||
if err := writer.WriteMsg(message); err != nil {
|
||||
if err := writer.WriteMsg(msg); err != nil {
|
||||
return fmt.Errorf("write msg: %w", err)
|
||||
}
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"github.com/meshplus/bitxhub/pkg/network/pb"
|
||||
)
|
||||
|
||||
func Message(data []byte) *pb.Message {
|
||||
return &pb.Message{
|
||||
Data: data,
|
||||
}
|
||||
}
|
|
@ -3,11 +3,13 @@ package network
|
|||
import (
|
||||
"github.com/libp2p/go-libp2p-core/network"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
"github.com/meshplus/bitxhub/pkg/network/proto"
|
||||
"github.com/meshplus/bitxhub/pkg/network/pb"
|
||||
)
|
||||
|
||||
type ConnectCallback func(*peer.AddrInfo) error
|
||||
|
||||
type MessageHandler func(network.Stream, []byte)
|
||||
|
||||
type Network interface {
|
||||
// Start start the network service.
|
||||
Start() error
|
||||
|
@ -21,21 +23,21 @@ type Network interface {
|
|||
// Disconnect peer with id
|
||||
Disconnect(*peer.AddrInfo) error
|
||||
|
||||
// SetConnectionCallback Sets the callback after connecting
|
||||
// SetConnectionCallback sets the callback after connecting
|
||||
SetConnectCallback(ConnectCallback)
|
||||
|
||||
// SetMessageHandler sets message handler
|
||||
SetMessageHandler(MessageHandler)
|
||||
|
||||
// AsyncSend sends message to peer with peer info.
|
||||
AsyncSend(*peer.AddrInfo, *proto.Message) error
|
||||
AsyncSend(*peer.AddrInfo, *pb.Message) error
|
||||
|
||||
// Send message using existed stream
|
||||
SendWithStream(s network.Stream, msg *proto.Message) error
|
||||
SendWithStream(network.Stream, *pb.Message) error
|
||||
|
||||
// Send sends message waiting response
|
||||
Send(*peer.AddrInfo, *proto.Message) (*proto.Message, error)
|
||||
Send(*peer.AddrInfo, *pb.Message) (*pb.Message, error)
|
||||
|
||||
// Broadcast message to all node
|
||||
Broadcast([]*peer.AddrInfo, *proto.Message) error
|
||||
|
||||
// Receive message from the channel
|
||||
Receive() <-chan *MessageStream
|
||||
Broadcast([]*peer.AddrInfo, *pb.Message) error
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package p2p
|
||||
package network
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -10,13 +10,12 @@ import (
|
|||
"github.com/libp2p/go-libp2p-core/network"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
"github.com/libp2p/go-libp2p-core/peerstore"
|
||||
net "github.com/meshplus/bitxhub/pkg/network"
|
||||
"github.com/meshplus/bitxhub/pkg/network/proto"
|
||||
"github.com/meshplus/bitxhub/pkg/network/pb"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var _ net.Network = (*P2P)(nil)
|
||||
var _ Network = (*P2P)(nil)
|
||||
|
||||
var (
|
||||
connectTimeout = 10 * time.Second
|
||||
|
@ -27,16 +26,16 @@ var (
|
|||
type P2P struct {
|
||||
config *Config
|
||||
host host.Host // manage all connections
|
||||
recvQ chan *net.MessageStream
|
||||
streamMng *streamMgr
|
||||
connectCallback net.ConnectCallback
|
||||
connectCallback ConnectCallback
|
||||
handleMessage MessageHandler
|
||||
logger logrus.FieldLogger
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func New(opts ...Option) (net.Network, error) {
|
||||
func New(opts ...Option) (*P2P, error) {
|
||||
config, err := generateConfig(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate config: %w", err)
|
||||
|
@ -55,7 +54,6 @@ func New(opts ...Option) (net.Network, error) {
|
|||
p2p := &P2P{
|
||||
config: config,
|
||||
host: h,
|
||||
recvQ: make(chan *net.MessageStream),
|
||||
streamMng: newStreamMng(ctx, h, config.protocolID),
|
||||
logger: config.logger,
|
||||
ctx: ctx,
|
||||
|
@ -92,12 +90,16 @@ func (p2p *P2P) Connect(addr *peer.AddrInfo) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p2p *P2P) SetConnectCallback(callback net.ConnectCallback) {
|
||||
func (p2p *P2P) SetConnectCallback(callback ConnectCallback) {
|
||||
p2p.connectCallback = callback
|
||||
}
|
||||
|
||||
func (p2p *P2P) SetMessageHandler(handler MessageHandler) {
|
||||
p2p.handleMessage = handler
|
||||
}
|
||||
|
||||
// AsyncSend message to peer with specific id.
|
||||
func (p2p *P2P) AsyncSend(addr *peer.AddrInfo, msg *proto.Message) error {
|
||||
func (p2p *P2P) AsyncSend(addr *peer.AddrInfo, msg *pb.Message) error {
|
||||
s, err := p2p.streamMng.get(addr.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get stream: %w", err)
|
||||
|
@ -111,11 +113,11 @@ func (p2p *P2P) AsyncSend(addr *peer.AddrInfo, msg *proto.Message) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p2p *P2P) SendWithStream(s network.Stream, msg *proto.Message) error {
|
||||
func (p2p *P2P) SendWithStream(s network.Stream, msg *pb.Message) error {
|
||||
return p2p.send(s, msg)
|
||||
}
|
||||
|
||||
func (p2p *P2P) Send(addr *peer.AddrInfo, msg *proto.Message) (*proto.Message, error) {
|
||||
func (p2p *P2P) Send(addr *peer.AddrInfo, msg *pb.Message) (*pb.Message, error) {
|
||||
s, err := p2p.streamMng.get(addr.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get stream: %w", err)
|
||||
|
@ -134,7 +136,7 @@ func (p2p *P2P) Send(addr *peer.AddrInfo, msg *proto.Message) (*proto.Message, e
|
|||
return recvMsg, nil
|
||||
}
|
||||
|
||||
func (p2p *P2P) Broadcast(ids []*peer.AddrInfo, msg *proto.Message) error {
|
||||
func (p2p *P2P) Broadcast(ids []*peer.AddrInfo, msg *pb.Message) error {
|
||||
for _, id := range ids {
|
||||
if err := p2p.AsyncSend(id, msg); err != nil {
|
||||
p2p.logger.WithFields(logrus.Fields{
|
||||
|
@ -148,10 +150,6 @@ func (p2p *P2P) Broadcast(ids []*peer.AddrInfo, msg *proto.Message) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p2p *P2P) Receive() <-chan *net.MessageStream {
|
||||
return p2p.recvQ
|
||||
}
|
||||
|
||||
// Stop stop the network service.
|
||||
func (p2p *P2P) Stop() error {
|
||||
p2p.cancel()
|
|
@ -1,4 +1,4 @@
|
|||
package p2p
|
||||
package network
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -7,11 +7,11 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
net "github.com/meshplus/bitxhub/pkg/network"
|
||||
|
||||
"github.com/libp2p/go-libp2p-core/crypto"
|
||||
"github.com/libp2p/go-libp2p-core/network"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
"github.com/libp2p/go-libp2p-core/protocol"
|
||||
"github.com/meshplus/bitxhub/pkg/network/pb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
@ -43,6 +43,15 @@ func TestP2P_Send(t *testing.T) {
|
|||
p1, addr1 := generateNetwork(t, 6005)
|
||||
p2, addr2 := generateNetwork(t, 6006)
|
||||
|
||||
msg := []byte("hello")
|
||||
|
||||
ch := make(chan struct{})
|
||||
|
||||
p2.SetMessageHandler(func(s network.Stream, data []byte) {
|
||||
assert.EqualValues(t, msg, data)
|
||||
close(ch)
|
||||
})
|
||||
|
||||
err := p1.Start()
|
||||
assert.Nil(t, err)
|
||||
err = p2.Start()
|
||||
|
@ -53,23 +62,18 @@ func TestP2P_Send(t *testing.T) {
|
|||
err = p2.Connect(addr1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
msg := []byte("hello")
|
||||
err = p1.AsyncSend(addr2, net.Message(msg))
|
||||
err = p1.AsyncSend(addr2, &pb.Message{Data: msg})
|
||||
assert.Nil(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ch := p2.Receive()
|
||||
for {
|
||||
select {
|
||||
case c := <-ch:
|
||||
assert.EqualValues(t, msg, c.Message.Data)
|
||||
return
|
||||
case <-ctx.Done():
|
||||
assert.Error(t, fmt.Errorf("timeout"))
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-ch:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
assert.Error(t, fmt.Errorf("timeout"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -89,12 +93,22 @@ func TestP2p_MultiSend(t *testing.T) {
|
|||
|
||||
N := 50
|
||||
msg := []byte("hello")
|
||||
ch := p2.Receive()
|
||||
count := 0
|
||||
ch := make(chan struct{})
|
||||
|
||||
p2.SetMessageHandler(func(s network.Stream, data []byte) {
|
||||
assert.EqualValues(t, msg, data)
|
||||
count++
|
||||
if count == N {
|
||||
close(ch)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
go func() {
|
||||
for i := 0; i < N; i++ {
|
||||
time.Sleep(200 * time.Microsecond)
|
||||
err = p1.AsyncSend(addr2, net.Message(msg))
|
||||
err = p1.AsyncSend(addr2, &pb.Message{Data: msg})
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
|
@ -102,22 +116,16 @@ func TestP2p_MultiSend(t *testing.T) {
|
|||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
count := 0
|
||||
for {
|
||||
select {
|
||||
case c := <-ch:
|
||||
assert.EqualValues(t, msg, c.Message.Data)
|
||||
case <-ctx.Done():
|
||||
assert.Error(t, fmt.Errorf("timeout"))
|
||||
}
|
||||
count++
|
||||
if count == N {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
assert.Error(t, fmt.Errorf("timeout"))
|
||||
}
|
||||
}
|
||||
|
||||
func generateNetwork(t *testing.T, port int) (net.Network, *peer.AddrInfo) {
|
||||
func generateNetwork(t *testing.T, port int) (Network, *peer.AddrInfo) {
|
||||
privKey, pubKey, err := crypto.GenerateECDSAKeyPair(rand.Reader)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -6,7 +6,7 @@ help: Makefile
|
|||
|
||||
## make pb: build network message protobuf
|
||||
proto:
|
||||
cd proto && protoc -I=. \
|
||||
protoc -I=. \
|
||||
-I${GOPATH}/src \
|
||||
-I${GOPATH}/src/github.com/gogo/protobuf/protobuf \
|
||||
--gogofast_out=:. network.proto
|
|
@ -1,7 +1,7 @@
|
|||
// Code generated by protoc-gen-gogo. DO NOT EDIT.
|
||||
// source: network.proto
|
||||
|
||||
package proto
|
||||
package pb
|
||||
|
||||
import (
|
||||
fmt "fmt"
|
||||
|
@ -23,8 +23,7 @@ var _ = math.Inf
|
|||
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
|
||||
|
||||
type Message struct {
|
||||
Data []byte `protobuf:"bytes,1,opt,name=Data,proto3" json:"Data,omitempty"`
|
||||
Version string `protobuf:"bytes,2,opt,name=Version,proto3" json:"Version,omitempty"`
|
||||
Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
|
@ -70,28 +69,20 @@ func (m *Message) GetData() []byte {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m *Message) GetVersion() string {
|
||||
if m != nil {
|
||||
return m.Version
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*Message)(nil), "proto.Message")
|
||||
proto.RegisterType((*Message)(nil), "pb.Message")
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("network.proto", fileDescriptor_8571034d60397816) }
|
||||
|
||||
var fileDescriptor_8571034d60397816 = []byte{
|
||||
// 110 bytes of a gzipped FileDescriptorProto
|
||||
// 91 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0xcd, 0x4b, 0x2d, 0x29,
|
||||
0xcf, 0x2f, 0xca, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x05, 0x53, 0x4a, 0xe6, 0x5c,
|
||||
0xec, 0xbe, 0xa9, 0xc5, 0xc5, 0x89, 0xe9, 0xa9, 0x42, 0x42, 0x5c, 0x2c, 0x2e, 0x89, 0x25, 0x89,
|
||||
0x12, 0x8c, 0x0a, 0x8c, 0x1a, 0x3c, 0x41, 0x60, 0xb6, 0x90, 0x04, 0x17, 0x7b, 0x58, 0x6a, 0x51,
|
||||
0x71, 0x66, 0x7e, 0x9e, 0x04, 0x93, 0x02, 0xa3, 0x06, 0x67, 0x10, 0x8c, 0xeb, 0xc4, 0x73, 0xe2,
|
||||
0x91, 0x1c, 0xe3, 0x85, 0x47, 0x72, 0x8c, 0x0f, 0x1e, 0xc9, 0x31, 0x26, 0xb1, 0x81, 0x4d, 0x33,
|
||||
0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0xf4, 0xb2, 0xbf, 0x15, 0x65, 0x00, 0x00, 0x00,
|
||||
0xcf, 0x2f, 0xca, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2a, 0x48, 0x52, 0x92, 0xe5,
|
||||
0x62, 0xf7, 0x4d, 0x2d, 0x2e, 0x4e, 0x4c, 0x4f, 0x15, 0x12, 0xe2, 0x62, 0x49, 0x49, 0x2c, 0x49,
|
||||
0x94, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x09, 0x02, 0xb3, 0x9d, 0x78, 0x4e, 0x3c, 0x92, 0x63, 0xbc,
|
||||
0xf0, 0x48, 0x8e, 0xf1, 0xc1, 0x23, 0x39, 0xc6, 0x24, 0x36, 0xb0, 0x3e, 0x63, 0x40, 0x00, 0x00,
|
||||
0x00, 0xff, 0xff, 0x0e, 0xd4, 0xfc, 0x1a, 0x48, 0x00, 0x00, 0x00,
|
||||
}
|
||||
|
||||
func (m *Message) Marshal() (dAtA []byte, err error) {
|
||||
|
@ -118,13 +109,6 @@ func (m *Message) MarshalToSizedBuffer(dAtA []byte) (int, error) {
|
|||
i -= len(m.XXX_unrecognized)
|
||||
copy(dAtA[i:], m.XXX_unrecognized)
|
||||
}
|
||||
if len(m.Version) > 0 {
|
||||
i -= len(m.Version)
|
||||
copy(dAtA[i:], m.Version)
|
||||
i = encodeVarintNetwork(dAtA, i, uint64(len(m.Version)))
|
||||
i--
|
||||
dAtA[i] = 0x12
|
||||
}
|
||||
if len(m.Data) > 0 {
|
||||
i -= len(m.Data)
|
||||
copy(dAtA[i:], m.Data)
|
||||
|
@ -156,10 +140,6 @@ func (m *Message) Size() (n int) {
|
|||
if l > 0 {
|
||||
n += 1 + l + sovNetwork(uint64(l))
|
||||
}
|
||||
l = len(m.Version)
|
||||
if l > 0 {
|
||||
n += 1 + l + sovNetwork(uint64(l))
|
||||
}
|
||||
if m.XXX_unrecognized != nil {
|
||||
n += len(m.XXX_unrecognized)
|
||||
}
|
||||
|
@ -235,38 +215,6 @@ func (m *Message) Unmarshal(dAtA []byte) error {
|
|||
m.Data = []byte{}
|
||||
}
|
||||
iNdEx = postIndex
|
||||
case 2:
|
||||
if wireType != 2 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field Version", wireType)
|
||||
}
|
||||
var stringLen uint64
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNetwork
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
stringLen |= uint64(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
intStringLen := int(stringLen)
|
||||
if intStringLen < 0 {
|
||||
return ErrInvalidLengthNetwork
|
||||
}
|
||||
postIndex := iNdEx + intStringLen
|
||||
if postIndex < 0 {
|
||||
return ErrInvalidLengthNetwork
|
||||
}
|
||||
if postIndex > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
m.Version = string(dAtA[iNdEx:postIndex])
|
||||
iNdEx = postIndex
|
||||
default:
|
||||
iNdEx = preIndex
|
||||
skippy, err := skipNetwork(dAtA[iNdEx:])
|
|
@ -0,0 +1,7 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package pb;
|
||||
|
||||
message Message {
|
||||
bytes data = 1;
|
||||
}
|
|
@ -1,8 +0,0 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package proto;
|
||||
|
||||
message Message {
|
||||
bytes Data = 1;
|
||||
string Version = 2;
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package p2p
|
||||
package network
|
||||
|
||||
import (
|
||||
"context"
|
|
@ -1,18 +0,0 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"github.com/libp2p/go-libp2p-core/network"
|
||||
"github.com/meshplus/bitxhub/pkg/network/proto"
|
||||
)
|
||||
|
||||
type MessageStream struct {
|
||||
Message *proto.Message
|
||||
Stream network.Stream
|
||||
}
|
||||
|
||||
func Message(data []byte) *proto.Message {
|
||||
return &proto.Message{
|
||||
Data: data,
|
||||
Version: "1.0",
|
||||
}
|
||||
}
|
|
@ -5,36 +5,46 @@ import (
|
|||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/meshplus/bitxhub/internal/model"
|
||||
|
||||
"github.com/libp2p/go-libp2p-core/network"
|
||||
network2 "github.com/libp2p/go-libp2p-core/network"
|
||||
"github.com/meshplus/bitxhub-model/pb"
|
||||
"github.com/meshplus/bitxhub/internal/model"
|
||||
"github.com/meshplus/bitxhub/internal/model/events"
|
||||
"github.com/meshplus/bitxhub/pkg/cert"
|
||||
proto2 "github.com/meshplus/bitxhub/pkg/network/proto"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (swarm *Swarm) handleMessage(s network2.Stream, msg *proto2.Message) error {
|
||||
func (swarm *Swarm) handleMessage(s network.Stream, data []byte) {
|
||||
m := &pb.Message{}
|
||||
if err := m.Unmarshal(msg.Data); err != nil {
|
||||
return err
|
||||
if err := m.Unmarshal(data); err != nil {
|
||||
swarm.logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
switch m.Type {
|
||||
case pb.Message_GET_BLOCK:
|
||||
return swarm.handleGetBlockPack(s, m)
|
||||
case pb.Message_FETCH_CERT:
|
||||
return swarm.handleFetchCertMessage(s)
|
||||
case pb.Message_CONSENSUS:
|
||||
go swarm.orderMessageFeed.Send(events.OrderMessageEvent{Data: m.Data})
|
||||
case pb.Message_FETCH_BLOCK_SIGN:
|
||||
swarm.handleFetchBlockSignMessage(s, m.Data)
|
||||
default:
|
||||
swarm.logger.WithField("module", "p2p").Errorf("can't handle msg[type: %v]", m.Type)
|
||||
handler := func() error {
|
||||
switch m.Type {
|
||||
case pb.Message_GET_BLOCK:
|
||||
return swarm.handleGetBlockPack(s, m)
|
||||
case pb.Message_FETCH_CERT:
|
||||
return swarm.handleFetchCertMessage(s)
|
||||
case pb.Message_CONSENSUS:
|
||||
go swarm.orderMessageFeed.Send(events.OrderMessageEvent{Data: m.Data})
|
||||
case pb.Message_FETCH_BLOCK_SIGN:
|
||||
swarm.handleFetchBlockSignMessage(s, m.Data)
|
||||
default:
|
||||
swarm.logger.WithField("module", "p2p").Errorf("can't handle msg[type: %v]", m.Type)
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
if err := handler(); err != nil {
|
||||
swarm.logger.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
"type": m.Type.String(),
|
||||
}).Error("Handle message")
|
||||
}
|
||||
}
|
||||
|
||||
func (swarm *Swarm) handleGetBlockPack(s network.Stream, msg *pb.Message) error {
|
||||
|
|
|
@ -18,7 +18,6 @@ import (
|
|||
"github.com/meshplus/bitxhub/internal/repo"
|
||||
"github.com/meshplus/bitxhub/pkg/cert"
|
||||
"github.com/meshplus/bitxhub/pkg/network"
|
||||
"github.com/meshplus/bitxhub/pkg/network/p2p"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
@ -41,11 +40,11 @@ type Swarm struct {
|
|||
}
|
||||
|
||||
func New(repo *repo.Repo, logger logrus.FieldLogger, ledger ledger.Ledger) (*Swarm, error) {
|
||||
p2p, err := p2p.New(
|
||||
p2p.WithLocalAddr(repo.NetworkConfig.LocalAddr),
|
||||
p2p.WithPrivateKey(repo.Key.Libp2pPrivKey),
|
||||
p2p.WithProtocolID(protocolID),
|
||||
p2p.WithLogger(logger),
|
||||
p2p, err := network.New(
|
||||
network.WithLocalAddr(repo.NetworkConfig.LocalAddr),
|
||||
network.WithPrivateKey(repo.Key.Libp2pPrivKey),
|
||||
network.WithProtocolID(protocolID),
|
||||
network.WithLogger(logger),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
|
@ -67,12 +66,12 @@ func New(repo *repo.Repo, logger logrus.FieldLogger, ledger ledger.Ledger) (*Swa
|
|||
}
|
||||
|
||||
func (swarm *Swarm) Start() error {
|
||||
swarm.p2p.SetMessageHandler(swarm.handleMessage)
|
||||
|
||||
if err := swarm.p2p.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go swarm.receiveMessage()
|
||||
|
||||
for id, addr := range swarm.peers {
|
||||
go func(id uint64, addr *peer.AddrInfo) {
|
||||
if err := retry.Retry(func(attempt uint) error {
|
||||
|
@ -129,9 +128,7 @@ func (swarm *Swarm) AsyncSend(id uint64, msg *pb.Message) error {
|
|||
return err
|
||||
}
|
||||
|
||||
m := network.Message(data)
|
||||
|
||||
return swarm.p2p.AsyncSend(swarm.peers[id], m)
|
||||
return swarm.p2p.AsyncSend(swarm.peers[id], network.Message(data))
|
||||
}
|
||||
|
||||
func (swarm *Swarm) SendWithStream(s network2.Stream, msg *pb.Message) error {
|
||||
|
@ -140,9 +137,7 @@ func (swarm *Swarm) SendWithStream(s network2.Stream, msg *pb.Message) error {
|
|||
return err
|
||||
}
|
||||
|
||||
m := network.Message(data)
|
||||
|
||||
return swarm.p2p.SendWithStream(s, m)
|
||||
return swarm.p2p.SendWithStream(s, network.Message(data))
|
||||
}
|
||||
|
||||
func (swarm *Swarm) Send(id uint64, msg *pb.Message) (*pb.Message, error) {
|
||||
|
@ -179,9 +174,7 @@ func (swarm *Swarm) Broadcast(msg *pb.Message) error {
|
|||
return err
|
||||
}
|
||||
|
||||
m := network.Message(data)
|
||||
|
||||
return swarm.p2p.Broadcast(addrs, m)
|
||||
return swarm.p2p.Broadcast(addrs, network.Message(data))
|
||||
}
|
||||
|
||||
func (swarm *Swarm) Peers() map[uint64]*peer.AddrInfo {
|
||||
|
@ -252,18 +245,3 @@ func (swarm *Swarm) checkID(id uint64) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (swarm *Swarm) receiveMessage() {
|
||||
for {
|
||||
select {
|
||||
case m := <-swarm.p2p.Receive():
|
||||
go func() {
|
||||
if err := swarm.handleMessage(m.Stream, m.Message); err != nil {
|
||||
swarm.logger.WithField("error", err).Error("Handle message")
|
||||
}
|
||||
}()
|
||||
case <-swarm.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue