feat: add user repo unit test

This commit is contained in:
LinkinStar 2022-10-28 19:28:08 +08:00
parent d353b8519b
commit 95cb136e3f
7 changed files with 202 additions and 38 deletions

View File

@ -112,7 +112,7 @@ func (uc *UserController) UserEmailLogin(ctx *gin.Context) {
return return
} }
captchaPass := uc.actionService.ActionRecordVerifyCaptcha(ctx, schema.ActionRecord_Type_Login, ctx.ClientIP(), req.CaptchaID, req.CaptchaCode) captchaPass := uc.actionService.ActionRecordVerifyCaptcha(ctx, schema.ActionRecordTypeLogin, ctx.ClientIP(), req.CaptchaID, req.CaptchaCode)
if !captchaPass { if !captchaPass {
resp := schema.UserVerifyEmailErrorResponse{ resp := schema.UserVerifyEmailErrorResponse{
Key: "captcha_code", Key: "captcha_code",
@ -125,7 +125,7 @@ func (uc *UserController) UserEmailLogin(ctx *gin.Context) {
resp, err := uc.userService.EmailLogin(ctx, req) resp, err := uc.userService.EmailLogin(ctx, req)
if err != nil { if err != nil {
_, _ = uc.actionService.ActionRecordAdd(ctx, schema.ActionRecord_Type_Login, ctx.ClientIP()) _, _ = uc.actionService.ActionRecordAdd(ctx, schema.ActionRecordTypeLogin, ctx.ClientIP())
resp := schema.UserVerifyEmailErrorResponse{ resp := schema.UserVerifyEmailErrorResponse{
Key: "e_mail", Key: "e_mail",
Value: "error.object.email_or_password_incorrect", Value: "error.object.email_or_password_incorrect",
@ -134,7 +134,7 @@ func (uc *UserController) UserEmailLogin(ctx *gin.Context) {
handler.HandleResponse(ctx, errors.BadRequest(reason.CaptchaVerificationFailed), resp) handler.HandleResponse(ctx, errors.BadRequest(reason.CaptchaVerificationFailed), resp)
return return
} }
uc.actionService.ActionRecordDel(ctx, schema.ActionRecord_Type_Login, ctx.ClientIP()) uc.actionService.ActionRecordDel(ctx, schema.ActionRecordTypeLogin, ctx.ClientIP())
handler.HandleResponse(ctx, nil, resp) handler.HandleResponse(ctx, nil, resp)
} }
@ -152,7 +152,7 @@ func (uc *UserController) RetrievePassWord(ctx *gin.Context) {
if handler.BindAndCheck(ctx, req) { if handler.BindAndCheck(ctx, req) {
return return
} }
captchaPass := uc.actionService.ActionRecordVerifyCaptcha(ctx, schema.ActionRecord_Type_Find_Pass, ctx.ClientIP(), req.CaptchaID, req.CaptchaCode) captchaPass := uc.actionService.ActionRecordVerifyCaptcha(ctx, schema.ActionRecordTypeFindPass, ctx.ClientIP(), req.CaptchaID, req.CaptchaCode)
if !captchaPass { if !captchaPass {
resp := schema.UserVerifyEmailErrorResponse{ resp := schema.UserVerifyEmailErrorResponse{
Key: "captcha_code", Key: "captcha_code",
@ -162,7 +162,7 @@ func (uc *UserController) RetrievePassWord(ctx *gin.Context) {
handler.HandleResponse(ctx, errors.BadRequest(reason.CaptchaVerificationFailed), resp) handler.HandleResponse(ctx, errors.BadRequest(reason.CaptchaVerificationFailed), resp)
return return
} }
_, _ = uc.actionService.ActionRecordAdd(ctx, schema.ActionRecord_Type_Find_Pass, ctx.ClientIP()) _, _ = uc.actionService.ActionRecordAdd(ctx, schema.ActionRecordTypeFindPass, ctx.ClientIP())
code, err := uc.userService.RetrievePassWord(ctx, req) code, err := uc.userService.RetrievePassWord(ctx, req)
handler.HandleResponse(ctx, err, code) handler.HandleResponse(ctx, err, code)
} }
@ -189,8 +189,8 @@ func (uc *UserController) UseRePassWord(ctx *gin.Context) {
return return
} }
resp, err := uc.userService.UseRePassWord(ctx, req) resp, err := uc.userService.UseRePassword(ctx, req)
uc.actionService.ActionRecordDel(ctx, schema.ActionRecord_Type_Find_Pass, ctx.ClientIP()) uc.actionService.ActionRecordDel(ctx, schema.ActionRecordTypeFindPass, ctx.ClientIP())
handler.HandleResponse(ctx, err, resp) handler.HandleResponse(ctx, err, resp)
} }
@ -256,7 +256,7 @@ func (uc *UserController) UserVerifyEmail(ctx *gin.Context) {
return return
} }
uc.actionService.ActionRecordDel(ctx, schema.ActionRecord_Type_Email, ctx.ClientIP()) uc.actionService.ActionRecordDel(ctx, schema.ActionRecordTypeEmail, ctx.ClientIP())
handler.HandleResponse(ctx, err, resp) handler.HandleResponse(ctx, err, resp)
} }
@ -282,7 +282,7 @@ func (uc *UserController) UserVerifyEmailSend(ctx *gin.Context) {
return return
} }
captchaPass := uc.actionService.ActionRecordVerifyCaptcha(ctx, schema.ActionRecord_Type_Email, ctx.ClientIP(), captchaPass := uc.actionService.ActionRecordVerifyCaptcha(ctx, schema.ActionRecordTypeEmail, ctx.ClientIP(),
req.CaptchaID, req.CaptchaCode) req.CaptchaID, req.CaptchaCode)
if !captchaPass { if !captchaPass {
resp := schema.UserVerifyEmailErrorResponse{ resp := schema.UserVerifyEmailErrorResponse{
@ -294,7 +294,7 @@ func (uc *UserController) UserVerifyEmailSend(ctx *gin.Context) {
return return
} }
uc.actionService.ActionRecordAdd(ctx, schema.ActionRecord_Type_Email, ctx.ClientIP()) uc.actionService.ActionRecordAdd(ctx, schema.ActionRecordTypeEmail, ctx.ClientIP())
err := uc.userService.UserVerifyEmailSend(ctx, userInfo.UserID) err := uc.userService.UserVerifyEmailSend(ctx, userInfo.UserID)
handler.HandleResponse(ctx, err, nil) handler.HandleResponse(ctx, err, nil)
} }
@ -340,7 +340,7 @@ func (uc *UserController) UserModifyPassWord(ctx *gin.Context) {
handler.HandleResponse(ctx, errors.BadRequest(reason.CaptchaVerificationFailed), resp) handler.HandleResponse(ctx, errors.BadRequest(reason.CaptchaVerificationFailed), resp)
return return
} }
err = uc.userService.UserModifyPassWord(ctx, req) err = uc.userService.UserModifyPassword(ctx, req)
handler.HandleResponse(ctx, err, nil) handler.HandleResponse(ctx, err, nil)
} }

View File

@ -0,0 +1,53 @@
package repo_test
import (
"context"
"testing"
"github.com/answerdev/answer/internal/entity"
"github.com/answerdev/answer/internal/repo/auth"
"github.com/answerdev/answer/internal/repo/user"
"github.com/stretchr/testify/assert"
)
func Test_userBackyardRepo_GetUserInfo(t *testing.T) {
userBackyardRepo := user.NewUserBackyardRepo(testDataSource, auth.NewAuthRepo(testDataSource))
got, exist, err := userBackyardRepo.GetUserInfo(context.TODO(), "1")
assert.NoError(t, err)
assert.True(t, exist)
assert.Equal(t, "1", got.ID)
}
func Test_userBackyardRepo_GetUserPage(t *testing.T) {
userBackyardRepo := user.NewUserBackyardRepo(testDataSource, auth.NewAuthRepo(testDataSource))
got, total, err := userBackyardRepo.GetUserPage(context.TODO(), 1, 1, &entity.User{Username: "admin"})
assert.NoError(t, err)
assert.Equal(t, int64(1), total)
assert.Equal(t, "1", got[0].ID)
}
func Test_userBackyardRepo_UpdateUserStatus(t *testing.T) {
userBackyardRepo := user.NewUserBackyardRepo(testDataSource, auth.NewAuthRepo(testDataSource))
got, exist, err := userBackyardRepo.GetUserInfo(context.TODO(), "1")
assert.NoError(t, err)
assert.True(t, exist)
assert.Equal(t, entity.UserStatusAvailable, got.Status)
err = userBackyardRepo.UpdateUserStatus(context.TODO(), "1", entity.UserStatusSuspended, entity.EmailStatusAvailable,
"admin@admin.com")
assert.NoError(t, err)
got, exist, err = userBackyardRepo.GetUserInfo(context.TODO(), "1")
assert.NoError(t, err)
assert.True(t, exist)
assert.Equal(t, entity.UserStatusSuspended, got.Status)
err = userBackyardRepo.UpdateUserStatus(context.TODO(), "1", entity.UserStatusAvailable, entity.EmailStatusAvailable,
"admin@admin.com")
assert.NoError(t, err)
got, exist, err = userBackyardRepo.GetUserInfo(context.TODO(), "1")
assert.NoError(t, err)
assert.True(t, exist)
assert.Equal(t, entity.UserStatusAvailable, got.Status)
}

View File

@ -0,0 +1,121 @@
package repo_test
import (
"context"
"testing"
"github.com/answerdev/answer/internal/entity"
"github.com/answerdev/answer/internal/repo/config"
"github.com/answerdev/answer/internal/repo/user"
"github.com/stretchr/testify/assert"
)
func Test_userRepo_AddUser(t *testing.T) {
userRepo := user.NewUserRepo(testDataSource, config.NewConfigRepo(testDataSource))
userInfo := &entity.User{
Username: "answer",
Pass: "answer",
EMail: "answer@example.com",
MailStatus: entity.EmailStatusAvailable,
Status: entity.UserStatusAvailable,
DisplayName: "answer",
IsAdmin: false,
}
err := userRepo.AddUser(context.TODO(), userInfo)
assert.NoError(t, err)
}
func Test_userRepo_BatchGetByID(t *testing.T) {
userRepo := user.NewUserRepo(testDataSource, config.NewConfigRepo(testDataSource))
got, err := userRepo.BatchGetByID(context.TODO(), []string{"1"})
assert.NoError(t, err)
assert.Equal(t, 1, len(got))
assert.Equal(t, "admin", got[0].Username)
}
func Test_userRepo_GetByEmail(t *testing.T) {
userRepo := user.NewUserRepo(testDataSource, config.NewConfigRepo(testDataSource))
got, exist, err := userRepo.GetByEmail(context.TODO(), "admin@admin.com")
assert.NoError(t, err)
assert.True(t, exist)
assert.Equal(t, "admin", got.Username)
}
func Test_userRepo_GetByUserID(t *testing.T) {
userRepo := user.NewUserRepo(testDataSource, config.NewConfigRepo(testDataSource))
got, exist, err := userRepo.GetByUserID(context.TODO(), "1")
assert.NoError(t, err)
assert.True(t, exist)
assert.Equal(t, "admin", got.Username)
}
func Test_userRepo_GetByUsername(t *testing.T) {
userRepo := user.NewUserRepo(testDataSource, config.NewConfigRepo(testDataSource))
got, exist, err := userRepo.GetByUsername(context.TODO(), "admin")
assert.NoError(t, err)
assert.True(t, exist)
assert.Equal(t, "admin", got.Username)
}
func Test_userRepo_IncreaseAnswerCount(t *testing.T) {
userRepo := user.NewUserRepo(testDataSource, config.NewConfigRepo(testDataSource))
err := userRepo.IncreaseAnswerCount(context.TODO(), "1", 1)
assert.NoError(t, err)
got, exist, err := userRepo.GetByUserID(context.TODO(), "1")
assert.NoError(t, err)
assert.True(t, exist)
assert.Equal(t, 1, got.AnswerCount)
}
func Test_userRepo_IncreaseQuestionCount(t *testing.T) {
userRepo := user.NewUserRepo(testDataSource, config.NewConfigRepo(testDataSource))
err := userRepo.IncreaseQuestionCount(context.TODO(), "1", 1)
assert.NoError(t, err)
got, exist, err := userRepo.GetByUserID(context.TODO(), "1")
assert.NoError(t, err)
assert.True(t, exist)
assert.Equal(t, 1, got.AnswerCount)
}
func Test_userRepo_UpdateEmail(t *testing.T) {
userRepo := user.NewUserRepo(testDataSource, config.NewConfigRepo(testDataSource))
err := userRepo.UpdateEmail(context.TODO(), "1", "admin@admin.com")
assert.NoError(t, err)
}
func Test_userRepo_UpdateEmailStatus(t *testing.T) {
userRepo := user.NewUserRepo(testDataSource, config.NewConfigRepo(testDataSource))
err := userRepo.UpdateEmailStatus(context.TODO(), "1", entity.EmailStatusToBeVerified)
assert.NoError(t, err)
}
func Test_userRepo_UpdateInfo(t *testing.T) {
userRepo := user.NewUserRepo(testDataSource, config.NewConfigRepo(testDataSource))
err := userRepo.UpdateInfo(context.TODO(), &entity.User{ID: "1", Bio: "test"})
assert.NoError(t, err)
got, exist, err := userRepo.GetByUserID(context.TODO(), "1")
assert.NoError(t, err)
assert.True(t, exist)
assert.Equal(t, "test", got.Bio)
}
func Test_userRepo_UpdateLastLoginDate(t *testing.T) {
userRepo := user.NewUserRepo(testDataSource, config.NewConfigRepo(testDataSource))
err := userRepo.UpdateLastLoginDate(context.TODO(), "1")
assert.NoError(t, err)
}
func Test_userRepo_UpdateNoticeStatus(t *testing.T) {
userRepo := user.NewUserRepo(testDataSource, config.NewConfigRepo(testDataSource))
err := userRepo.UpdateNoticeStatus(context.TODO(), "1", 1)
assert.NoError(t, err)
}
func Test_userRepo_UpdatePass(t *testing.T) {
userRepo := user.NewUserRepo(testDataSource, config.NewConfigRepo(testDataSource))
err := userRepo.UpdatePass(context.TODO(), "1", "admin")
assert.NoError(t, err)
}

View File

@ -2,7 +2,6 @@ package user
import ( import (
"context" "context"
"fmt"
"time" "time"
"github.com/answerdev/answer/internal/base/data" "github.com/answerdev/answer/internal/base/data"
@ -86,13 +85,10 @@ func (ur *userRepo) UpdateNoticeStatus(ctx context.Context, userID string, notic
return nil return nil
} }
func (ur *userRepo) UpdatePass(ctx context.Context, Data *entity.User) error { func (ur *userRepo) UpdatePass(ctx context.Context, userID, pass string) error {
if Data.ID == "" { _, err := ur.data.DB.Where("id = ?", userID).Cols("pass").Update(&entity.User{Pass: pass})
return fmt.Errorf("input error")
}
_, err := ur.data.DB.Where("id = ?", Data.ID).Cols("pass").Update(Data)
if err != nil { if err != nil {
return err return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack()
} }
return nil return nil
} }

View File

@ -150,16 +150,12 @@ func (r *GetOtherUserInfoByUsernameResp) GetFromUserEntity(userInfo *entity.User
} }
const ( const (
Mail_State_Pass = 1 NoticeStatusOn = 1
Mail_State_Verifi = 2 NoticeStatusOff = 2
Notice_Status_On = 1 ActionRecordTypeLogin = "login"
Notice_Status_Off = 2 ActionRecordTypeEmail = "e_mail"
ActionRecordTypeFindPass = "find_pass"
//ActionRecord ReportType
ActionRecord_Type_Login = "login"
ActionRecord_Type_Email = "e_mail"
ActionRecord_Type_Find_Pass = "find_pass"
) )
var UserStatusShow = map[int]string{ var UserStatusShow = map[int]string{

View File

@ -15,7 +15,7 @@ type UserRepo interface {
UpdateEmailStatus(ctx context.Context, userID string, emailStatus int) error UpdateEmailStatus(ctx context.Context, userID string, emailStatus int) error
UpdateNoticeStatus(ctx context.Context, userID string, noticeStatus int) error UpdateNoticeStatus(ctx context.Context, userID string, noticeStatus int) error
UpdateEmail(ctx context.Context, userID, email string) error UpdateEmail(ctx context.Context, userID, email string) error
UpdatePass(ctx context.Context, Data *entity.User) error UpdatePass(ctx context.Context, userID, pass string) error
UpdateInfo(ctx context.Context, userInfo *entity.User) (err error) UpdateInfo(ctx context.Context, userInfo *entity.User) (err error)
GetByUserID(ctx context.Context, userID string) (userInfo *entity.User, exist bool, err error) GetByUserID(ctx context.Context, userID string) (userInfo *entity.User, exist bool, err error)
BatchGetByID(ctx context.Context, ids []string) ([]*entity.User, error) BatchGetByID(ctx context.Context, ids []string) ([]*entity.User, error)

View File

@ -174,8 +174,8 @@ func (us *UserService) RetrievePassWord(ctx context.Context, req *schema.UserRet
return code, nil return code, nil
} }
// UseRePassWord // UseRePassword
func (us *UserService) UseRePassWord(ctx context.Context, req *schema.UserRePassWordRequest) (resp *schema.GetUserResp, err error) { func (us *UserService) UseRePassword(ctx context.Context, req *schema.UserRePassWordRequest) (resp *schema.GetUserResp, err error) {
data := &schema.EmailCodeContent{} data := &schema.EmailCodeContent{}
err = data.FromJSONString(req.Content) err = data.FromJSONString(req.Content)
if err != nil { if err != nil {
@ -193,8 +193,7 @@ func (us *UserService) UseRePassWord(ctx context.Context, req *schema.UserRePass
if err != nil { if err != nil {
return nil, err return nil, err
} }
userInfo.Pass = enpass err = us.userRepo.UpdatePass(ctx, userInfo.ID, enpass)
err = us.userRepo.UpdatePass(ctx, userInfo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -219,8 +218,8 @@ func (us *UserService) UserModifyPassWordVerification(ctx context.Context, reque
return true, nil return true, nil
} }
// UserModifyPassWord // UserModifyPassword user modify password
func (us *UserService) UserModifyPassWord(ctx context.Context, request *schema.UserModifyPassWordRequest) error { func (us *UserService) UserModifyPassword(ctx context.Context, request *schema.UserModifyPassWordRequest) error {
enpass, err := us.encryptPassword(ctx, request.Pass) enpass, err := us.encryptPassword(ctx, request.Pass)
if err != nil { if err != nil {
return err return err
@ -236,8 +235,7 @@ func (us *UserService) UserModifyPassWord(ctx context.Context, request *schema.U
if !isPass { if !isPass {
return fmt.Errorf("the old password verification failed") return fmt.Errorf("the old password verification failed")
} }
userInfo.Pass = enpass err = us.userRepo.UpdatePass(ctx, userInfo.ID, enpass)
err = us.userRepo.UpdatePass(ctx, userInfo)
if err != nil { if err != nil {
return err return err
} }
@ -377,9 +375,9 @@ func (us *UserService) UserNoticeSet(ctx context.Context, userId string, noticeS
return nil, errors.BadRequest(reason.UserNotFound) return nil, errors.BadRequest(reason.UserNotFound)
} }
if noticeSwitch { if noticeSwitch {
userInfo.NoticeStatus = schema.Notice_Status_On userInfo.NoticeStatus = schema.NoticeStatusOn
} else { } else {
userInfo.NoticeStatus = schema.Notice_Status_Off userInfo.NoticeStatus = schema.NoticeStatusOff
} }
err = us.userRepo.UpdateNoticeStatus(ctx, userInfo.ID, userInfo.NoticeStatus) err = us.userRepo.UpdateNoticeStatus(ctx, userInfo.ID, userInfo.NoticeStatus)
return &schema.UserNoticeSetResp{NoticeSwitch: noticeSwitch}, err return &schema.UserNoticeSetResp{NoticeSwitch: noticeSwitch}, err