diff --git a/libcontainer/state_linux_test.go b/libcontainer/state_linux_test.go index 65e2b850..6ef516b7 100644 --- a/libcontainer/state_linux_test.go +++ b/libcontainer/state_linux_test.go @@ -2,16 +2,21 @@ package libcontainer -import "testing" +import ( + "reflect" + "testing" +) + +var states = map[containerState]Status{ + &createdState{}: Created, + &runningState{}: Running, + &restoredState{}: Running, + &pausedState{}: Paused, + &stoppedState{}: Stopped, + &loadedState{s: Running}: Running, +} func TestStateStatus(t *testing.T) { - states := map[containerState]Status{ - &stoppedState{}: Stopped, - &runningState{}: Running, - &restoredState{}: Running, - &pausedState{}: Paused, - &createdState{}: Created, - } for s, status := range states { if s.status() != status { t.Fatalf("state returned %s but expected %s", s.status(), status) @@ -24,94 +29,88 @@ func isStateTransitionError(err error) bool { return ok } -func TestStoppedStateTransition(t *testing.T) { - s := &stoppedState{c: &linuxContainer{}} - valid := []containerState{ - &stoppedState{}, - &runningState{}, - &restoredState{}, +func testTransitions(t *testing.T, initialState containerState, valid []containerState) { + validMap := map[reflect.Type]interface{}{} + for _, validState := range valid { + validMap[reflect.TypeOf(validState)] = nil + t.Run(validState.status().String(), func(t *testing.T) { + if err := initialState.transition(validState); err != nil { + t.Fatal(err) + } + }) } - for _, v := range valid { - if err := s.transition(v); err != nil { - t.Fatal(err) + for state := range states { + if _, ok := validMap[reflect.TypeOf(state)]; ok { + continue } + t.Run(state.status().String(), func(t *testing.T) { + err := initialState.transition(state) + if err == nil { + t.Fatal("transition should fail") + } + if !isStateTransitionError(err) { + t.Fatal("expected stateTransitionError") + } + }) } - err := s.transition(&pausedState{}) - if err == nil { - t.Fatal("transition to paused state should fail") - } - if !isStateTransitionError(err) { - t.Fatal("expected stateTransitionError") - } +} + +func TestStoppedStateTransition(t *testing.T) { + testTransitions( + t, + &stoppedState{c: &linuxContainer{}}, + []containerState{ + &stoppedState{}, + &runningState{}, + &restoredState{}, + }, + ) } func TestPausedStateTransition(t *testing.T) { - s := &pausedState{c: &linuxContainer{}} - valid := []containerState{ - &pausedState{}, - &runningState{}, - &stoppedState{}, - } - for _, v := range valid { - if err := s.transition(v); err != nil { - t.Fatal(err) - } - } + testTransitions( + t, + &pausedState{c: &linuxContainer{}}, + []containerState{ + &pausedState{}, + &runningState{}, + &stoppedState{}, + }, + ) } func TestRestoredStateTransition(t *testing.T) { - s := &restoredState{c: &linuxContainer{}} - valid := []containerState{ - &stoppedState{}, - &runningState{}, - } - for _, v := range valid { - if err := s.transition(v); err != nil { - t.Fatal(err) - } - } - err := s.transition(&createdState{}) - if err == nil { - t.Fatal("transition to created state should fail") - } - if !isStateTransitionError(err) { - t.Fatal("expected stateTransitionError") - } + testTransitions( + t, + &restoredState{c: &linuxContainer{}}, + []containerState{ + &stoppedState{}, + &runningState{}, + }, + ) } func TestRunningStateTransition(t *testing.T) { - s := &runningState{c: &linuxContainer{}} - valid := []containerState{ - &stoppedState{}, - &pausedState{}, - &runningState{}, - } - for _, v := range valid { - if err := s.transition(v); err != nil { - t.Fatal(err) - } - } - - err := s.transition(&createdState{}) - if err == nil { - t.Fatal("transition to created state should fail") - } - if !isStateTransitionError(err) { - t.Fatal("expected stateTransitionError") - } + testTransitions( + t, + &runningState{c: &linuxContainer{}}, + []containerState{ + &stoppedState{}, + &pausedState{}, + &runningState{}, + }, + ) } func TestCreatedStateTransition(t *testing.T) { - s := &createdState{c: &linuxContainer{}} - valid := []containerState{ - &stoppedState{}, - &pausedState{}, - &runningState{}, - &createdState{}, - } - for _, v := range valid { - if err := s.transition(v); err != nil { - t.Fatal(err) - } - } + testTransitions( + t, + &createdState{c: &linuxContainer{}}, + []containerState{ + &stoppedState{}, + &pausedState{}, + &runningState{}, + &createdState{}, + }, + ) }