netlink: Extract message checks into reusable method
Signed-off-by: Jonathan Rudenberg <jonathan@titanous.com>
This commit is contained in:
parent
6e4334a68e
commit
65842f749b
|
@ -3,6 +3,7 @@ package netlink
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
@ -322,35 +323,44 @@ func (s *NetlinkSocket) GetPid() (uint32, error) {
|
||||||
return 0, ErrWrongSockType
|
return 0, ErrWrongSockType
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *NetlinkSocket) HandleAck(seq uint32) error {
|
func (s *NetlinkSocket) CheckMessage(m syscall.NetlinkMessage, seq, pid uint32) error {
|
||||||
|
if m.Header.Seq != seq {
|
||||||
|
return fmt.Errorf("netlink: invalid seq %d, expected %d", m.Header.Seq, seq)
|
||||||
|
}
|
||||||
|
if m.Header.Pid != pid {
|
||||||
|
return fmt.Errorf("netlink: wrong pid %d, expected %d", m.Header.Pid, pid)
|
||||||
|
}
|
||||||
|
if m.Header.Type == syscall.NLMSG_DONE {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
if m.Header.Type == syscall.NLMSG_ERROR {
|
||||||
|
e := int32(native.Uint32(m.Data[0:4]))
|
||||||
|
if e == 0 {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
return syscall.Errno(-e)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *NetlinkSocket) HandleAck(seq uint32) error {
|
||||||
pid, err := s.GetPid()
|
pid, err := s.GetPid()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
done:
|
outer:
|
||||||
for {
|
for {
|
||||||
msgs, err := s.Receive()
|
msgs, err := s.Receive()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, m := range msgs {
|
for _, m := range msgs {
|
||||||
if m.Header.Seq != seq {
|
if err := s.CheckMessage(m, seq, pid); err != nil {
|
||||||
return fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, seq)
|
if err == io.EOF {
|
||||||
}
|
break outer
|
||||||
if m.Header.Pid != pid {
|
|
||||||
return fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid)
|
|
||||||
}
|
|
||||||
if m.Header.Type == syscall.NLMSG_DONE {
|
|
||||||
break done
|
|
||||||
}
|
|
||||||
if m.Header.Type == syscall.NLMSG_ERROR {
|
|
||||||
error := int32(native.Uint32(m.Data[0:4]))
|
|
||||||
if error == 0 {
|
|
||||||
break done
|
|
||||||
}
|
}
|
||||||
return syscall.Errno(-error)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -781,28 +791,18 @@ func NetworkGetRoutes() ([]Route, error) {
|
||||||
|
|
||||||
res := make([]Route, 0)
|
res := make([]Route, 0)
|
||||||
|
|
||||||
done:
|
outer:
|
||||||
for {
|
for {
|
||||||
msgs, err := s.Receive()
|
msgs, err := s.Receive()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
for _, m := range msgs {
|
for _, m := range msgs {
|
||||||
if m.Header.Seq != wb.Seq {
|
if err := s.CheckMessage(m, wb.Seq, pid); err != nil {
|
||||||
return nil, fmt.Errorf("Wrong Seq nr %d, expected 1", m.Header.Seq)
|
if err == io.EOF {
|
||||||
}
|
break outer
|
||||||
if m.Header.Pid != pid {
|
|
||||||
return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid)
|
|
||||||
}
|
|
||||||
if m.Header.Type == syscall.NLMSG_DONE {
|
|
||||||
break done
|
|
||||||
}
|
|
||||||
if m.Header.Type == syscall.NLMSG_ERROR {
|
|
||||||
error := int32(native.Uint32(m.Data[0:4]))
|
|
||||||
if error == 0 {
|
|
||||||
break done
|
|
||||||
}
|
}
|
||||||
return nil, syscall.Errno(-error)
|
return nil, err
|
||||||
}
|
}
|
||||||
if m.Header.Type != syscall.RTM_NEWROUTE {
|
if m.Header.Type != syscall.RTM_NEWROUTE {
|
||||||
continue
|
continue
|
||||||
|
|
Loading…
Reference in New Issue