diff --git a/network/stats.go b/network/stats.go index c8ece5c7..e2156c74 100644 --- a/network/stats.go +++ b/network/stats.go @@ -2,7 +2,6 @@ package network import ( "io/ioutil" - "os" "path/filepath" "strconv" "strings" @@ -25,45 +24,51 @@ func GetStats(networkState *NetworkState) (*NetworkStats, error) { if networkState.VethHost == "" { return &NetworkStats{}, nil } - data, err := readSysfsNetworkStats(networkState.VethHost) - if err != nil { - return nil, err + + out := &NetworkStats{} + + type netStatsPair struct { + // Where to write the output. + Out *uint64 + + // The network stats file to read. + File string } // Ingress for host veth is from the container. Hence tx_bytes stat on the host veth is actually number of bytes received by the container. - return &NetworkStats{ - RxBytes: data["tx_bytes"], - RxPackets: data["tx_packets"], - RxErrors: data["tx_errors"], - RxDropped: data["tx_dropped"], - TxBytes: data["rx_bytes"], - TxPackets: data["rx_packets"], - TxErrors: data["rx_errors"], - TxDropped: data["rx_dropped"], - }, nil + netStats := []netStatsPair{ + {Out: &out.RxBytes, File: "tx_bytes"}, + {Out: &out.RxPackets, File: "tx_packets"}, + {Out: &out.RxErrors, File: "tx_errors"}, + {Out: &out.RxDropped, File: "tx_dropped"}, + + {Out: &out.TxBytes, File: "rx_bytes"}, + {Out: &out.TxPackets, File: "rx_packets"}, + {Out: &out.TxErrors, File: "rx_errors"}, + {Out: &out.TxDropped, File: "rx_dropped"}, + } + for _, netStat := range netStats { + data, err := readSysfsNetworkStats(networkState.VethHost, netStat.File) + if err != nil { + return nil, err + } + *(netStat.Out) = data + } + + return out, nil } -// Reads all the statistics available under /sys/class/net//statistics as a map with file name as key and data as integers. -func readSysfsNetworkStats(ethInterface string) (map[string]uint64, error) { - out := make(map[string]uint64) +// Reads the specified statistics available under /sys/class/net//statistics +func readSysfsNetworkStats(ethInterface, statsFile string) (uint64, error) { + fullPath := filepath.Join("/sys/class/net", ethInterface, "statistics", statsFile) + data, err := ioutil.ReadFile(fullPath) + if err != nil { + return 0, err + } + value, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64) + if err != nil { + return 0, err + } - fullPath := filepath.Join("/sys/class/net", ethInterface, "statistics/") - err := filepath.Walk(fullPath, func(path string, _ os.FileInfo, _ error) error { - // skip fullPath. - if path == fullPath { - return nil - } - base := filepath.Base(path) - data, err := ioutil.ReadFile(path) - if err != nil { - return err - } - value, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64) - if err != nil { - return err - } - out[base] = value - return nil - }) - return out, err + return value, err }