Merge pull request #59 from cyphar/603-fixup-API
libcontainer: user: fix GetAdditionalGroups* API
This commit is contained in:
commit
0a4ba80b71
|
@ -349,17 +349,12 @@ func GetExecUser(userSpec string, defaults *ExecUser, passwd, group io.Reader) (
|
|||
return user, nil
|
||||
}
|
||||
|
||||
// GetAdditionalGroupsPath looks up a list of groups by name or group id
|
||||
// against the group file. If a group name cannot be found, an error will be
|
||||
// returned. If a group id cannot be found, it will be returned as-is.
|
||||
func GetAdditionalGroupsPath(additionalGroups []string, groupPath string) ([]int, error) {
|
||||
groupReader, err := os.Open(groupPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to open group file: %v", err)
|
||||
}
|
||||
defer groupReader.Close()
|
||||
|
||||
groups, err := ParseGroupFilter(groupReader, func(g Group) bool {
|
||||
// GetAdditionalGroups looks up a list of groups by name or group id against
|
||||
// against the given /etc/group formatted data. If a group name cannot be found,
|
||||
// an error will be returned. If a group id cannot be found, it will be returned
|
||||
// as-is.
|
||||
func GetAdditionalGroups(additionalGroups []string, group io.Reader) ([]int, error) {
|
||||
groups, err := ParseGroupFilter(group, func(g Group) bool {
|
||||
for _, ag := range additionalGroups {
|
||||
if g.Name == ag || strconv.Itoa(g.Gid) == ag {
|
||||
return true
|
||||
|
@ -405,3 +400,14 @@ func GetAdditionalGroupsPath(additionalGroups []string, groupPath string) ([]int
|
|||
}
|
||||
return gids, nil
|
||||
}
|
||||
|
||||
// Wrapper around GetAdditionalGroups that opens the groupPath given and gives
|
||||
// it as an argument to GetAdditionalGroups.
|
||||
func GetAdditionalGroupsPath(additionalGroups []string, groupPath string) ([]int, error) {
|
||||
group, err := os.Open(groupPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to open group file: %v", err)
|
||||
}
|
||||
defer group.Close()
|
||||
return GetAdditionalGroups(additionalGroups, group)
|
||||
}
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
package user
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
@ -355,7 +353,7 @@ this is just some garbage data
|
|||
}
|
||||
}
|
||||
|
||||
func TestGetAdditionalGroupsPath(t *testing.T) {
|
||||
func TestGetAdditionalGroups(t *testing.T) {
|
||||
const groupContent = `
|
||||
root:x:0:root
|
||||
adm:x:43:
|
||||
|
@ -419,14 +417,9 @@ this is just some garbage data
|
|||
}
|
||||
|
||||
for _, test := range tests {
|
||||
tmpFile, err := ioutil.TempFile("", "get-additional-groups-path")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
fmt.Fprint(tmpFile, groupContent)
|
||||
tmpFile.Close()
|
||||
group := strings.NewReader(groupContent)
|
||||
|
||||
gids, err := GetAdditionalGroupsPath(test.groups, tmpFile.Name())
|
||||
gids, err := GetAdditionalGroups(test.groups, group)
|
||||
if test.hasError && err == nil {
|
||||
t.Errorf("Parse(%#v) expects error but has none", test)
|
||||
continue
|
||||
|
|
Loading…
Reference in New Issue