diff --git a/libcontainer/cgroups/fscommon/fscommon.go b/libcontainer/cgroups/fscommon/fscommon.go index dd92e8c8..dc53987e 100644 --- a/libcontainer/cgroups/fscommon/fscommon.go +++ b/libcontainer/cgroups/fscommon/fscommon.go @@ -4,9 +4,12 @@ package fscommon import ( "io/ioutil" + "os" + "syscall" securejoin "github.com/cyphar/filepath-securejoin" "github.com/pkg/errors" + "github.com/sirupsen/logrus" ) func WriteFile(dir, file, data string) error { @@ -17,7 +20,7 @@ func WriteFile(dir, file, data string) error { if err != nil { return err } - if err := ioutil.WriteFile(path, []byte(data), 0700); err != nil { + if err := retryingWriteFile(path, []byte(data), 0700); err != nil { return errors.Wrapf(err, "failed to write %q to %q", data, path) } return nil @@ -34,3 +37,24 @@ func ReadFile(dir, file string) (string, error) { data, err := ioutil.ReadFile(path) return string(data), err } + +func retryingWriteFile(filename string, data []byte, perm os.FileMode) error { + for { + err := ioutil.WriteFile(filename, data, perm) + if isInterruptedWriteFile(err) { + logrus.Infof("interrupted while writing %s to %s", string(data), filename) + continue + } + return err + } +} + +func isInterruptedWriteFile(err error) bool { + if patherr, ok := err.(*os.PathError); ok { + errno, ok2 := patherr.Err.(syscall.Errno) + if ok2 && errno == syscall.EINTR { + return true + } + } + return false +} diff --git a/libcontainer/cgroups/fscommon/fscommon_test.go b/libcontainer/cgroups/fscommon/fscommon_test.go new file mode 100644 index 00000000..e56052e0 --- /dev/null +++ b/libcontainer/cgroups/fscommon/fscommon_test.go @@ -0,0 +1,39 @@ +// +build linux + +package fscommon + +import ( + "fmt" + "os" + "path/filepath" + "strconv" + "testing" + "time" + + "github.com/opencontainers/runc/libcontainer/cgroups" +) + +func TestWriteCgroupFileHandlesInterrupt(t *testing.T) { + if cgroups.IsCgroup2UnifiedMode() { + t.Skip("cgroup v2 is not supported") + } + + memoryCgroupMount, err := cgroups.FindCgroupMountpoint("", "memory") + if err != nil { + t.Fatal(err) + } + + cgroupName := fmt.Sprintf("test-eint-%d", time.Now().Nanosecond()) + cgroupPath := filepath.Join(memoryCgroupMount, cgroupName) + if err := os.MkdirAll(cgroupPath, 0755); err != nil { + t.Fatal(err) + } + defer os.RemoveAll(cgroupPath) + + for i := 0; i < 100000; i++ { + limit := 1024*1024 + i + if err := WriteFile(cgroupPath, "memory.limit_in_bytes", strconv.Itoa(limit)); err != nil { + t.Fatalf("Failed to write %d on attempt %d: %+v", limit, i, err) + } + } +}