refactor GetAdditionalGroupsPath

This parses group file only once to process a list of groups instead of parsing
once for each group. Also added an unit test for GetAdditionalGroupsPath

Signed-off-by: Daniel, Dao Quang Minh <dqminh89@gmail.com>
This commit is contained in:
Daniel, Dao Quang Minh 2015-05-25 19:02:34 +00:00
parent 50603caabe
commit d4ece29c0b
2 changed files with 136 additions and 37 deletions

View File

@ -349,51 +349,59 @@ func GetExecUser(userSpec string, defaults *ExecUser, passwd, group io.Reader) (
return user, nil
}
// GetAdditionalGroupsPath is a wrapper for GetAdditionalGroups. It reads data from the
// given file path and uses that data as the arguments to GetAdditionalGroups.
// GetAdditionalGroupsPath looks up a list of groups by name or group id
// against the group file. If a group name cannot be found, an error will be
// returned. If a group id cannot be found, it will be returned as-is.
func GetAdditionalGroupsPath(additionalGroups []string, groupPath string) ([]int, error) {
var groupIds []int
for _, ag := range additionalGroups {
groupReader, err := os.Open(groupPath)
if err != nil {
return nil, fmt.Errorf("Failed to open group file: %v", err)
}
defer groupReader.Close()
groupId, err := GetAdditionalGroup(ag, groupReader)
if err != nil {
return nil, err
}
groupIds = append(groupIds, groupId)
groupReader, err := os.Open(groupPath)
if err != nil {
return nil, fmt.Errorf("Failed to open group file: %v", err)
}
defer groupReader.Close()
return groupIds, nil
}
// GetAdditionalGroup looks up the specified group in the passed groupReader.
func GetAdditionalGroup(additionalGroup string, groupReader io.Reader) (int, error) {
groups, err := ParseGroupFilter(groupReader, func(g Group) bool {
return g.Name == additionalGroup || strconv.Itoa(g.Gid) == additionalGroup
for _, ag := range additionalGroups {
if g.Name == ag || strconv.Itoa(g.Gid) == ag {
return true
}
}
return false
})
if err != nil {
return -1, fmt.Errorf("Unable to find additional groups %v: %v", additionalGroup, err)
return nil, fmt.Errorf("Unable to find additional groups %v: %v", additionalGroups, err)
}
if groups != nil && len(groups) > 0 {
// if we found any group entries that matched our filter, let's take the first one as "correct"
return groups[0].Gid, nil
} else {
// we asked for a group but didn't find id... let's check to see if we wanted a numeric group
addGroup, err := strconv.Atoi(additionalGroup)
if err != nil {
// not numeric - we have to bail
return -1, fmt.Errorf("Unable to find group %v", additionalGroup)
}
// Ensure gid is inside gid range.
if addGroup < minId || addGroup > maxId {
return -1, ErrRange
gidMap := make(map[int]struct{})
for _, ag := range additionalGroups {
var found bool
for _, g := range groups {
// if we found a matched group either by name or gid, take the
// first matched as correct
if g.Name == ag || strconv.Itoa(g.Gid) == ag {
if _, ok := gidMap[g.Gid]; !ok {
gidMap[g.Gid] = struct{}{}
found = true
break
}
}
}
// we asked for a group but didn't find it. let's check to see
// if we wanted a numeric group
if !found {
gid, err := strconv.Atoi(ag)
if err != nil {
return nil, fmt.Errorf("Unable to find group %s", ag)
}
// Ensure gid is inside gid range.
if gid < minId || gid > maxId {
return nil, ErrRange
}
gidMap[gid] = struct{}{}
}
return addGroup, nil
}
gids := []int{}
for gid := range gidMap {
gids = append(gids, gid)
}
return gids, nil
}

View File

@ -1,8 +1,12 @@
package user
import (
"fmt"
"io"
"io/ioutil"
"reflect"
"sort"
"strconv"
"strings"
"testing"
)
@ -350,3 +354,90 @@ this is just some garbage data
}
}
}
func TestGetAdditionalGroupsPath(t *testing.T) {
const groupContent = `
root:x:0:root
adm:x:43:
grp:x:1234:root,adm
adm:x:4343:root,adm-duplicate
this is just some garbage data
`
tests := []struct {
groups []string
expected []int
hasError bool
}{
{
// empty group
groups: []string{},
expected: []int{},
},
{
// single group
groups: []string{"adm"},
expected: []int{43},
},
{
// multiple groups
groups: []string{"adm", "grp"},
expected: []int{43, 1234},
},
{
// invalid group
groups: []string{"adm", "grp", "not-exist"},
expected: nil,
hasError: true,
},
{
// group with numeric id
groups: []string{"43"},
expected: []int{43},
},
{
// group with unknown numeric id
groups: []string{"adm", "10001"},
expected: []int{43, 10001},
},
{
// groups specified twice with numeric and name
groups: []string{"adm", "43"},
expected: []int{43},
},
{
// groups with too small id
groups: []string{"-1"},
expected: nil,
hasError: true,
},
{
// groups with too large id
groups: []string{strconv.Itoa(1 << 31)},
expected: nil,
hasError: true,
},
}
for _, test := range tests {
tmpFile, err := ioutil.TempFile("", "get-additional-groups-path")
if err != nil {
t.Error(err)
}
fmt.Fprint(tmpFile, groupContent)
tmpFile.Close()
gids, err := GetAdditionalGroupsPath(test.groups, tmpFile.Name())
if test.hasError && err == nil {
t.Errorf("Parse(%#v) expects error but has none", test)
continue
}
if !test.hasError && err != nil {
t.Errorf("Parse(%#v) has error %v", test, err)
continue
}
sort.Sort(sort.IntSlice(gids))
if !reflect.DeepEqual(gids, test.expected) {
t.Errorf("Gids(%v), expect %v from groups %v", gids, test.expected, test.groups)
}
}
}