netlink: Extract message checks into reusable method

Signed-off-by: Jonathan Rudenberg <jonathan@titanous.com>
This commit is contained in:
Jonathan Rudenberg 2014-09-14 20:29:27 -04:00
parent 6e4334a68e
commit 65842f749b
1 changed files with 31 additions and 31 deletions

View File

@ -3,6 +3,7 @@ package netlink
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io"
"net" "net"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
@ -322,35 +323,44 @@ func (s *NetlinkSocket) GetPid() (uint32, error) {
return 0, ErrWrongSockType return 0, ErrWrongSockType
} }
func (s *NetlinkSocket) HandleAck(seq uint32) error { func (s *NetlinkSocket) CheckMessage(m syscall.NetlinkMessage, seq, pid uint32) error {
if m.Header.Seq != seq {
return fmt.Errorf("netlink: invalid seq %d, expected %d", m.Header.Seq, seq)
}
if m.Header.Pid != pid {
return fmt.Errorf("netlink: wrong pid %d, expected %d", m.Header.Pid, pid)
}
if m.Header.Type == syscall.NLMSG_DONE {
return io.EOF
}
if m.Header.Type == syscall.NLMSG_ERROR {
e := int32(native.Uint32(m.Data[0:4]))
if e == 0 {
return io.EOF
}
return syscall.Errno(-e)
}
return nil
}
func (s *NetlinkSocket) HandleAck(seq uint32) error {
pid, err := s.GetPid() pid, err := s.GetPid()
if err != nil { if err != nil {
return err return err
} }
done: outer:
for { for {
msgs, err := s.Receive() msgs, err := s.Receive()
if err != nil { if err != nil {
return err return err
} }
for _, m := range msgs { for _, m := range msgs {
if m.Header.Seq != seq { if err := s.CheckMessage(m, seq, pid); err != nil {
return fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, seq) if err == io.EOF {
} break outer
if m.Header.Pid != pid {
return fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid)
}
if m.Header.Type == syscall.NLMSG_DONE {
break done
}
if m.Header.Type == syscall.NLMSG_ERROR {
error := int32(native.Uint32(m.Data[0:4]))
if error == 0 {
break done
} }
return syscall.Errno(-error) return err
} }
} }
} }
@ -781,28 +791,18 @@ func NetworkGetRoutes() ([]Route, error) {
res := make([]Route, 0) res := make([]Route, 0)
done: outer:
for { for {
msgs, err := s.Receive() msgs, err := s.Receive()
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, m := range msgs { for _, m := range msgs {
if m.Header.Seq != wb.Seq { if err := s.CheckMessage(m, wb.Seq, pid); err != nil {
return nil, fmt.Errorf("Wrong Seq nr %d, expected 1", m.Header.Seq) if err == io.EOF {
} break outer
if m.Header.Pid != pid {
return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid)
}
if m.Header.Type == syscall.NLMSG_DONE {
break done
}
if m.Header.Type == syscall.NLMSG_ERROR {
error := int32(native.Uint32(m.Data[0:4]))
if error == 0 {
break done
} }
return nil, syscall.Errno(-error) return nil, err
} }
if m.Header.Type != syscall.RTM_NEWROUTE { if m.Header.Type != syscall.RTM_NEWROUTE {
continue continue