diff --git a/libcontainer/configs/config.go b/libcontainer/configs/config.go index 0325d58a..b79511d3 100644 --- a/libcontainer/configs/config.go +++ b/libcontainer/configs/config.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/json" "os/exec" + + "github.com/Sirupsen/logrus" ) type Rlimit struct { @@ -175,8 +177,8 @@ type Config struct { NoNewPrivileges bool `json:"no_new_privileges"` // Hooks are a collection of actions to perform at various container lifecycle events. - // Hooks are not able to be marshaled to json but they are also not needed to. - Hooks *Hooks `json:"-"` + // CommandHooks are serialized to JSON, but other hooks are not. + Hooks *Hooks // Version is the version of opencontainer specification that is supported. Version string `json:"version"` @@ -197,6 +199,52 @@ type Hooks struct { Poststop []Hook } +func (hooks *Hooks) UnmarshalJSON(b []byte) error { + var state struct { + Prestart []CommandHook + Poststart []CommandHook + Poststop []CommandHook + } + + if err := json.Unmarshal(b, &state); err != nil { + return err + } + + deserialize := func(shooks []CommandHook) (hooks []Hook) { + for _, shook := range shooks { + hooks = append(hooks, shook) + } + + return hooks + } + + hooks.Prestart = deserialize(state.Prestart) + hooks.Poststart = deserialize(state.Poststart) + hooks.Poststop = deserialize(state.Poststop) + return nil +} + +func (hooks Hooks) MarshalJSON() ([]byte, error) { + serialize := func(hooks []Hook) (serializableHooks []CommandHook) { + for _, hook := range hooks { + switch chook := hook.(type) { + case CommandHook: + serializableHooks = append(serializableHooks, chook) + default: + logrus.Warnf("cannot serialize hook of type %T, skipping", hook) + } + } + + return serializableHooks + } + + return json.Marshal(map[string]interface{}{ + "prestart": serialize(hooks.Prestart), + "poststart": serialize(hooks.Poststart), + "poststop": serialize(hooks.Poststop), + }) +} + // HookState is the payload provided to a hook on execution. type HookState struct { Version string `json:"version"` diff --git a/libcontainer/factory_linux_test.go b/libcontainer/factory_linux_test.go index b0c0f496..ea3b5132 100644 --- a/libcontainer/factory_linux_test.go +++ b/libcontainer/factory_linux_test.go @@ -6,6 +6,7 @@ import ( "io/ioutil" "os" "path/filepath" + "reflect" "syscall" "testing" @@ -132,9 +133,22 @@ func TestFactoryLoadContainer(t *testing.T) { defer os.RemoveAll(root) // setup default container config and state for mocking var ( - id = "1" + id = "1" + expectedHooks = &configs.Hooks{ + Prestart: []configs.Hook{ + configs.CommandHook{Command: configs.Command{Path: "prestart-hook"}}, + }, + Poststart: []configs.Hook{ + configs.CommandHook{Command: configs.Command{Path: "poststart-hook"}}, + }, + Poststop: []configs.Hook{ + unserializableHook{}, + configs.CommandHook{Command: configs.Command{Path: "poststop-hook"}}, + }, + } expectedConfig = &configs.Config{ Rootfs: "/mycontainer/root", + Hooks: expectedHooks, } expectedState = &State{ BaseState: BaseState{ @@ -164,6 +178,10 @@ func TestFactoryLoadContainer(t *testing.T) { if config.Rootfs != expectedConfig.Rootfs { t.Fatalf("expected rootfs %q but received %q", expectedConfig.Rootfs, config.Rootfs) } + expectedHooks.Poststop = expectedHooks.Poststop[1:] // expect unserializable hook to be skipped + if !reflect.DeepEqual(config.Hooks, expectedHooks) { + t.Fatalf("expects hooks %q but received %q", expectedHooks, config.Hooks) + } lcontainer, ok := container.(*linuxContainer) if !ok { t.Fatal("expected linux container on linux based systems") @@ -181,3 +199,9 @@ func marshal(path string, v interface{}) error { defer f.Close() return utils.WriteJSON(f, v) } + +type unserializableHook struct{} + +func (unserializableHook) Run(configs.HookState) error { + return nil +}