diff --git a/libcontainer/container_linux.go b/libcontainer/container_linux.go index ae3d39f3..4c220958 100644 --- a/libcontainer/container_linux.go +++ b/libcontainer/container_linux.go @@ -1487,17 +1487,19 @@ func (c *linuxContainer) criuSwrk(process *Process, req *criurpc.CriuReq, opts * return err } criuServer.Close() + // cmd.Process will be replaced by a restored init. + criuProcess := cmd.Process defer func() { criuClientCon.Close() - _, err := cmd.Process.Wait() + _, err := criuProcess.Wait() if err != nil { return } }() if applyCgroups { - err := c.criuApplyCgroups(cmd.Process.Pid, req) + err := c.criuApplyCgroups(criuProcess.Pid, req) if err != nil { return err } @@ -1505,7 +1507,7 @@ func (c *linuxContainer) criuSwrk(process *Process, req *criurpc.CriuReq, opts * var extFds []string if process != nil { - extFds, err = getPipeFds(cmd.Process.Pid) + extFds, err = getPipeFds(criuProcess.Pid) if err != nil { return err } @@ -1579,7 +1581,7 @@ func (c *linuxContainer) criuSwrk(process *Process, req *criurpc.CriuReq, opts * logrus.Debugf("Feature check says: %s", resp) criuFeatures = resp.GetFeatures() case t == criurpc.CriuReqType_NOTIFY: - if err := c.criuNotifications(resp, process, opts, extFds, oob[:oobn]); err != nil { + if err := c.criuNotifications(resp, process, cmd, opts, extFds, oob[:oobn]); err != nil { return err } t = criurpc.CriuReqType_NOTIFY @@ -1609,7 +1611,7 @@ func (c *linuxContainer) criuSwrk(process *Process, req *criurpc.CriuReq, opts * criuClientCon.CloseWrite() // cmd.Wait() waits cmd.goroutines which are used for proxying file descriptors. // Here we want to wait only the CRIU process. - st, err := cmd.Process.Wait() + st, err := criuProcess.Wait() if err != nil { return err } @@ -1655,7 +1657,7 @@ func unlockNetwork(config *configs.Config) error { return nil } -func (c *linuxContainer) criuNotifications(resp *criurpc.CriuResp, process *Process, opts *CriuOpts, fds []string, oob []byte) error { +func (c *linuxContainer) criuNotifications(resp *criurpc.CriuResp, process *Process, cmd *exec.Cmd, opts *CriuOpts, fds []string, oob []byte) error { notify := resp.GetNotify() if notify == nil { return fmt.Errorf("invalid response: %s", resp.String()) @@ -1691,7 +1693,14 @@ func (c *linuxContainer) criuNotifications(resp *criurpc.CriuResp, process *Proc } case notify.GetScript() == "post-restore": pid := notify.GetPid() - r, err := newRestoredProcess(int(pid), fds) + + p, err := os.FindProcess(int(pid)) + if err != nil { + return err + } + cmd.Process = p + + r, err := newRestoredProcess(cmd, fds) if err != nil { return err } diff --git a/libcontainer/integration/checkpoint_test.go b/libcontainer/integration/checkpoint_test.go index 552ebe8c..51ed401f 100644 --- a/libcontainer/integration/checkpoint_test.go +++ b/libcontainer/integration/checkpoint_test.go @@ -42,7 +42,6 @@ func showFile(t *testing.T, fname string) error { } func TestUsernsCheckpoint(t *testing.T) { - t.Skip("Ubuntu kernel is broken to run criu (#2196, #2198)") if _, err := os.Stat("/proc/self/ns/user"); os.IsNotExist(err) { t.Skip("userns is unsupported") } @@ -54,7 +53,6 @@ func TestUsernsCheckpoint(t *testing.T) { } func TestCheckpoint(t *testing.T) { - t.Skip("Ubuntu kernel is broken to run criu (#2196, #2198)") testCheckpoint(t, false) } @@ -248,7 +246,7 @@ func testCheckpoint(t *testing.T, userns bool) { } restoreStdinW.Close() - s, err := process.Wait() + s, err := restoreProcessConfig.Wait() if err != nil { t.Fatal(err) } diff --git a/libcontainer/restored_process.go b/libcontainer/restored_process.go index 28d52ad0..f861e82d 100644 --- a/libcontainer/restored_process.go +++ b/libcontainer/restored_process.go @@ -5,31 +5,29 @@ package libcontainer import ( "fmt" "os" + "os/exec" "github.com/opencontainers/runc/libcontainer/system" ) -func newRestoredProcess(pid int, fds []string) (*restoredProcess, error) { +func newRestoredProcess(cmd *exec.Cmd, fds []string) (*restoredProcess, error) { var ( err error ) - proc, err := os.FindProcess(pid) - if err != nil { - return nil, err - } + pid := cmd.Process.Pid stat, err := system.Stat(pid) if err != nil { return nil, err } return &restoredProcess{ - proc: proc, + cmd: cmd, processStartTime: stat.StartTime, fds: fds, }, nil } type restoredProcess struct { - proc *os.Process + cmd *exec.Cmd processStartTime uint64 fds []string } @@ -39,11 +37,11 @@ func (p *restoredProcess) start() error { } func (p *restoredProcess) pid() int { - return p.proc.Pid + return p.cmd.Process.Pid } func (p *restoredProcess) terminate() error { - err := p.proc.Kill() + err := p.cmd.Process.Kill() if _, werr := p.wait(); err == nil { err = werr } @@ -53,10 +51,13 @@ func (p *restoredProcess) terminate() error { func (p *restoredProcess) wait() (*os.ProcessState, error) { // TODO: how do we wait on the actual process? // maybe use --exec-cmd in criu - st, err := p.proc.Wait() + err := p.cmd.Wait() if err != nil { - return nil, err + if _, ok := err.(*exec.ExitError); !ok { + return nil, err + } } + st := p.cmd.ProcessState return st, nil } @@ -65,7 +66,7 @@ func (p *restoredProcess) startTime() (uint64, error) { } func (p *restoredProcess) signal(s os.Signal) error { - return p.proc.Signal(s) + return p.cmd.Process.Signal(s) } func (p *restoredProcess) externalDescriptors() []string {