restore: fix a race condition in process.Wait()

Adrian reported that the checkpoint test started failing:
=== RUN   TestCheckpoint
--- FAIL: TestCheckpoint (0.38s)
    checkpoint_test.go:297: Did not restore the pipe correctly:

The problem here is that when we start exec.Cmd, we don't call its Wait
method. This means that we don't wait for cmd.goroutines, and so we don't
know when all data has been read from the process pipes.

Signed-off-by: Andrei Vagin <avagin@gmail.com>
This commit is contained in:
Andrei Vagin 2020-02-09 22:32:56 -08:00
parent e6555cc01a
commit 269ea385a4
3 changed files with 30 additions and 22 deletions

View File

@ -1485,17 +1485,19 @@ func (c *linuxContainer) criuSwrk(process *Process, req *criurpc.CriuReq, opts *
return err return err
} }
criuServer.Close() criuServer.Close()
// cmd.Process will be replaced by a restored init.
criuProcess := cmd.Process
defer func() { defer func() {
criuClientCon.Close() criuClientCon.Close()
_, err := cmd.Process.Wait() _, err := criuProcess.Wait()
if err != nil { if err != nil {
return return
} }
}() }()
if applyCgroups { if applyCgroups {
err := c.criuApplyCgroups(cmd.Process.Pid, req) err := c.criuApplyCgroups(criuProcess.Pid, req)
if err != nil { if err != nil {
return err return err
} }
@ -1503,7 +1505,7 @@ func (c *linuxContainer) criuSwrk(process *Process, req *criurpc.CriuReq, opts *
var extFds []string var extFds []string
if process != nil { if process != nil {
extFds, err = getPipeFds(cmd.Process.Pid) extFds, err = getPipeFds(criuProcess.Pid)
if err != nil { if err != nil {
return err return err
} }
@ -1577,7 +1579,7 @@ func (c *linuxContainer) criuSwrk(process *Process, req *criurpc.CriuReq, opts *
logrus.Debugf("Feature check says: %s", resp) logrus.Debugf("Feature check says: %s", resp)
criuFeatures = resp.GetFeatures() criuFeatures = resp.GetFeatures()
case t == criurpc.CriuReqType_NOTIFY: case t == criurpc.CriuReqType_NOTIFY:
if err := c.criuNotifications(resp, process, opts, extFds, oob[:oobn]); err != nil { if err := c.criuNotifications(resp, process, cmd, opts, extFds, oob[:oobn]); err != nil {
return err return err
} }
t = criurpc.CriuReqType_NOTIFY t = criurpc.CriuReqType_NOTIFY
@ -1607,7 +1609,7 @@ func (c *linuxContainer) criuSwrk(process *Process, req *criurpc.CriuReq, opts *
criuClientCon.CloseWrite() criuClientCon.CloseWrite()
// cmd.Wait() waits cmd.goroutines which are used for proxying file descriptors. // cmd.Wait() waits cmd.goroutines which are used for proxying file descriptors.
// Here we want to wait only the CRIU process. // Here we want to wait only the CRIU process.
st, err := cmd.Process.Wait() st, err := criuProcess.Wait()
if err != nil { if err != nil {
return err return err
} }
@ -1653,7 +1655,7 @@ func unlockNetwork(config *configs.Config) error {
return nil return nil
} }
func (c *linuxContainer) criuNotifications(resp *criurpc.CriuResp, process *Process, opts *CriuOpts, fds []string, oob []byte) error { func (c *linuxContainer) criuNotifications(resp *criurpc.CriuResp, process *Process, cmd *exec.Cmd, opts *CriuOpts, fds []string, oob []byte) error {
notify := resp.GetNotify() notify := resp.GetNotify()
if notify == nil { if notify == nil {
return fmt.Errorf("invalid response: %s", resp.String()) return fmt.Errorf("invalid response: %s", resp.String())
@ -1689,7 +1691,14 @@ func (c *linuxContainer) criuNotifications(resp *criurpc.CriuResp, process *Proc
} }
case notify.GetScript() == "post-restore": case notify.GetScript() == "post-restore":
pid := notify.GetPid() pid := notify.GetPid()
r, err := newRestoredProcess(int(pid), fds)
p, err := os.FindProcess(int(pid))
if err != nil {
return err
}
cmd.Process = p
r, err := newRestoredProcess(cmd, fds)
if err != nil { if err != nil {
return err return err
} }

View File

@ -42,7 +42,6 @@ func showFile(t *testing.T, fname string) error {
} }
func TestUsernsCheckpoint(t *testing.T) { func TestUsernsCheckpoint(t *testing.T) {
t.Skip("Ubuntu kernel is broken to run criu (#2196, #2198)")
if _, err := os.Stat("/proc/self/ns/user"); os.IsNotExist(err) { if _, err := os.Stat("/proc/self/ns/user"); os.IsNotExist(err) {
t.Skip("userns is unsupported") t.Skip("userns is unsupported")
} }
@ -54,7 +53,6 @@ func TestUsernsCheckpoint(t *testing.T) {
} }
func TestCheckpoint(t *testing.T) { func TestCheckpoint(t *testing.T) {
t.Skip("Ubuntu kernel is broken to run criu (#2196, #2198)")
testCheckpoint(t, false) testCheckpoint(t, false)
} }
@ -248,7 +246,7 @@ func testCheckpoint(t *testing.T, userns bool) {
} }
restoreStdinW.Close() restoreStdinW.Close()
s, err := process.Wait() s, err := restoreProcessConfig.Wait()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -5,31 +5,29 @@ package libcontainer
import ( import (
"fmt" "fmt"
"os" "os"
"os/exec"
"github.com/opencontainers/runc/libcontainer/system" "github.com/opencontainers/runc/libcontainer/system"
) )
func newRestoredProcess(pid int, fds []string) (*restoredProcess, error) { func newRestoredProcess(cmd *exec.Cmd, fds []string) (*restoredProcess, error) {
var ( var (
err error err error
) )
proc, err := os.FindProcess(pid) pid := cmd.Process.Pid
if err != nil {
return nil, err
}
stat, err := system.Stat(pid) stat, err := system.Stat(pid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &restoredProcess{ return &restoredProcess{
proc: proc, cmd: cmd,
processStartTime: stat.StartTime, processStartTime: stat.StartTime,
fds: fds, fds: fds,
}, nil }, nil
} }
type restoredProcess struct { type restoredProcess struct {
proc *os.Process cmd *exec.Cmd
processStartTime uint64 processStartTime uint64
fds []string fds []string
} }
@ -39,11 +37,11 @@ func (p *restoredProcess) start() error {
} }
func (p *restoredProcess) pid() int { func (p *restoredProcess) pid() int {
return p.proc.Pid return p.cmd.Process.Pid
} }
func (p *restoredProcess) terminate() error { func (p *restoredProcess) terminate() error {
err := p.proc.Kill() err := p.cmd.Process.Kill()
if _, werr := p.wait(); err == nil { if _, werr := p.wait(); err == nil {
err = werr err = werr
} }
@ -53,10 +51,13 @@ func (p *restoredProcess) terminate() error {
func (p *restoredProcess) wait() (*os.ProcessState, error) { func (p *restoredProcess) wait() (*os.ProcessState, error) {
// TODO: how do we wait on the actual process? // TODO: how do we wait on the actual process?
// maybe use --exec-cmd in criu // maybe use --exec-cmd in criu
st, err := p.proc.Wait() err := p.cmd.Wait()
if err != nil { if err != nil {
return nil, err if _, ok := err.(*exec.ExitError); !ok {
return nil, err
}
} }
st := p.cmd.ProcessState
return st, nil return st, nil
} }
@ -65,7 +66,7 @@ func (p *restoredProcess) startTime() (uint64, error) {
} }
func (p *restoredProcess) signal(s os.Signal) error { func (p *restoredProcess) signal(s os.Signal) error {
return p.proc.Signal(s) return p.cmd.Process.Signal(s)
} }
func (p *restoredProcess) externalDescriptors() []string { func (p *restoredProcess) externalDescriptors() []string {