Merge pull request #2226 from avagin/runsc-restore-cmd-wait

restore: fix a race condition in process.Wait()
This commit is contained in:
Mrunal Patel 2020-03-15 18:48:16 -07:00 committed by GitHub
commit 981dbef514
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 22 deletions

View File

@ -1487,17 +1487,19 @@ func (c *linuxContainer) criuSwrk(process *Process, req *criurpc.CriuReq, opts *
return err return err
} }
criuServer.Close() criuServer.Close()
// cmd.Process will be replaced by a restored init.
criuProcess := cmd.Process
defer func() { defer func() {
criuClientCon.Close() criuClientCon.Close()
_, err := cmd.Process.Wait() _, err := criuProcess.Wait()
if err != nil { if err != nil {
return return
} }
}() }()
if applyCgroups { if applyCgroups {
err := c.criuApplyCgroups(cmd.Process.Pid, req) err := c.criuApplyCgroups(criuProcess.Pid, req)
if err != nil { if err != nil {
return err return err
} }
@ -1505,7 +1507,7 @@ func (c *linuxContainer) criuSwrk(process *Process, req *criurpc.CriuReq, opts *
var extFds []string var extFds []string
if process != nil { if process != nil {
extFds, err = getPipeFds(cmd.Process.Pid) extFds, err = getPipeFds(criuProcess.Pid)
if err != nil { if err != nil {
return err return err
} }
@ -1579,7 +1581,7 @@ func (c *linuxContainer) criuSwrk(process *Process, req *criurpc.CriuReq, opts *
logrus.Debugf("Feature check says: %s", resp) logrus.Debugf("Feature check says: %s", resp)
criuFeatures = resp.GetFeatures() criuFeatures = resp.GetFeatures()
case t == criurpc.CriuReqType_NOTIFY: case t == criurpc.CriuReqType_NOTIFY:
if err := c.criuNotifications(resp, process, opts, extFds, oob[:oobn]); err != nil { if err := c.criuNotifications(resp, process, cmd, opts, extFds, oob[:oobn]); err != nil {
return err return err
} }
t = criurpc.CriuReqType_NOTIFY t = criurpc.CriuReqType_NOTIFY
@ -1609,7 +1611,7 @@ func (c *linuxContainer) criuSwrk(process *Process, req *criurpc.CriuReq, opts *
criuClientCon.CloseWrite() criuClientCon.CloseWrite()
// cmd.Wait() waits for cmd.goroutines, which are used for proxying file descriptors. // cmd.Wait() waits for cmd.goroutines, which are used for proxying file descriptors.
// Here we want to wait only for the CRIU process. // Here we want to wait only for the CRIU process.
st, err := cmd.Process.Wait() st, err := criuProcess.Wait()
if err != nil { if err != nil {
return err return err
} }
@ -1655,7 +1657,7 @@ func unlockNetwork(config *configs.Config) error {
return nil return nil
} }
func (c *linuxContainer) criuNotifications(resp *criurpc.CriuResp, process *Process, opts *CriuOpts, fds []string, oob []byte) error { func (c *linuxContainer) criuNotifications(resp *criurpc.CriuResp, process *Process, cmd *exec.Cmd, opts *CriuOpts, fds []string, oob []byte) error {
notify := resp.GetNotify() notify := resp.GetNotify()
if notify == nil { if notify == nil {
return fmt.Errorf("invalid response: %s", resp.String()) return fmt.Errorf("invalid response: %s", resp.String())
@ -1691,7 +1693,14 @@ func (c *linuxContainer) criuNotifications(resp *criurpc.CriuResp, process *Proc
} }
case notify.GetScript() == "post-restore": case notify.GetScript() == "post-restore":
pid := notify.GetPid() pid := notify.GetPid()
r, err := newRestoredProcess(int(pid), fds)
p, err := os.FindProcess(int(pid))
if err != nil {
return err
}
cmd.Process = p
r, err := newRestoredProcess(cmd, fds)
if err != nil { if err != nil {
return err return err
} }

View File

@ -42,7 +42,6 @@ func showFile(t *testing.T, fname string) error {
} }
func TestUsernsCheckpoint(t *testing.T) { func TestUsernsCheckpoint(t *testing.T) {
t.Skip("Ubuntu kernel is broken to run criu (#2196, #2198)")
if _, err := os.Stat("/proc/self/ns/user"); os.IsNotExist(err) { if _, err := os.Stat("/proc/self/ns/user"); os.IsNotExist(err) {
t.Skip("userns is unsupported") t.Skip("userns is unsupported")
} }
@ -54,7 +53,6 @@ func TestUsernsCheckpoint(t *testing.T) {
} }
func TestCheckpoint(t *testing.T) { func TestCheckpoint(t *testing.T) {
t.Skip("Ubuntu kernel is broken to run criu (#2196, #2198)")
testCheckpoint(t, false) testCheckpoint(t, false)
} }
@ -248,7 +246,7 @@ func testCheckpoint(t *testing.T, userns bool) {
} }
restoreStdinW.Close() restoreStdinW.Close()
s, err := process.Wait() s, err := restoreProcessConfig.Wait()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -5,31 +5,29 @@ package libcontainer
import ( import (
"fmt" "fmt"
"os" "os"
"os/exec"
"github.com/opencontainers/runc/libcontainer/system" "github.com/opencontainers/runc/libcontainer/system"
) )
func newRestoredProcess(pid int, fds []string) (*restoredProcess, error) { func newRestoredProcess(cmd *exec.Cmd, fds []string) (*restoredProcess, error) {
var ( var (
err error err error
) )
proc, err := os.FindProcess(pid) pid := cmd.Process.Pid
if err != nil {
return nil, err
}
stat, err := system.Stat(pid) stat, err := system.Stat(pid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &restoredProcess{ return &restoredProcess{
proc: proc, cmd: cmd,
processStartTime: stat.StartTime, processStartTime: stat.StartTime,
fds: fds, fds: fds,
}, nil }, nil
} }
type restoredProcess struct { type restoredProcess struct {
proc *os.Process cmd *exec.Cmd
processStartTime uint64 processStartTime uint64
fds []string fds []string
} }
@ -39,11 +37,11 @@ func (p *restoredProcess) start() error {
} }
func (p *restoredProcess) pid() int { func (p *restoredProcess) pid() int {
return p.proc.Pid return p.cmd.Process.Pid
} }
func (p *restoredProcess) terminate() error { func (p *restoredProcess) terminate() error {
err := p.proc.Kill() err := p.cmd.Process.Kill()
if _, werr := p.wait(); err == nil { if _, werr := p.wait(); err == nil {
err = werr err = werr
} }
@ -53,10 +51,13 @@ func (p *restoredProcess) terminate() error {
func (p *restoredProcess) wait() (*os.ProcessState, error) { func (p *restoredProcess) wait() (*os.ProcessState, error) {
// TODO: how do we wait on the actual process? // TODO: how do we wait on the actual process?
// maybe use --exec-cmd in criu // maybe use --exec-cmd in criu
st, err := p.proc.Wait() err := p.cmd.Wait()
if err != nil { if err != nil {
return nil, err if _, ok := err.(*exec.ExitError); !ok {
return nil, err
}
} }
st := p.cmd.ProcessState
return st, nil return st, nil
} }
@ -65,7 +66,7 @@ func (p *restoredProcess) startTime() (uint64, error) {
} }
func (p *restoredProcess) signal(s os.Signal) error { func (p *restoredProcess) signal(s os.Signal) error {
return p.proc.Signal(s) return p.cmd.Process.Signal(s)
} }
func (p *restoredProcess) externalDescriptors() []string { func (p *restoredProcess) externalDescriptors() []string {