This commit is contained in:
UlricQin 2020-11-03 12:11:32 +08:00
commit b448dad860
5 changed files with 156 additions and 128 deletions

View File

@ -92,3 +92,5 @@ wechat:
corp_id: "xxxxxxxxxxxxx" corp_id: "xxxxxxxxxxxxx"
agent_id: 1000000 agent_id: 1000000
secret: "xxxxxxxxxxxxxxxxx" secret: "xxxxxxxxxxxxxxxxx"
captcha: false

View File

@ -1,5 +1,7 @@
package models package models
import "errors"
type LoginCode struct { type LoginCode struct {
Username string `json:"username"` Username string `json:"username"`
Code string `json:"code"` Code string `json:"code"`
@ -7,6 +9,10 @@ type LoginCode struct {
CreatedAt int64 `json:"created_at"` CreatedAt int64 `json:"created_at"`
} }
var (
errLoginCode = errors.New("invalid login code")
)
func LoginCodeGet(where string, args ...interface{}) (*LoginCode, error) { func LoginCodeGet(where string, args ...interface{}) (*LoginCode, error) {
var obj LoginCode var obj LoginCode
has, err := DB["rdb"].Where(where, args...).Get(&obj) has, err := DB["rdb"].Where(where, args...).Get(&obj)
@ -15,7 +21,7 @@ func LoginCodeGet(where string, args ...interface{}) (*LoginCode, error) {
} }
if !has { if !has {
return nil, nil return nil, errLoginCode
} }
return &obj, nil return &obj, nil

View File

@ -18,6 +18,15 @@ import (
"github.com/didi/nightingale/src/modules/rdb/config" "github.com/didi/nightingale/src/modules/rdb/config"
) )
const (
LOGIN_T_SMS = "sms-code"
LOGIN_T_EMAIL = "email-code"
LOGIN_T_RST = "rst-code"
LOGIN_T_PWD = "password"
LOGIN_T_LDAP = "ldap"
LOGIN_EXPIRES_IN = 300
)
type User struct { type User struct {
Id int64 `json:"id"` Id int64 `json:"id"`
UUID string `json:"-" xorm:"'uuid'"` UUID string `json:"-" xorm:"'uuid'"`
@ -82,18 +91,16 @@ func InitRooter() {
log.Println("user root init done") log.Println("user root init done")
} }
func LdapLogin(user, pass, clientIP string) error { func LdapLogin(user, pass string) (*User, error) {
sr, err := ldapReq(user, pass) sr, err := ldapReq(user, pass)
if err != nil { if err != nil {
return err return nil, err
} }
go LoginLogNew(user, clientIP, "in")
var u User var u User
has, err := DB["rdb"].Where("username=?", user).Get(&u) has, err := DB["rdb"].Where("username=?", user).Get(&u)
if err != nil { if err != nil {
return err return nil, err
} }
u.CopyLdapAttr(sr) u.CopyLdapAttr(sr)
@ -101,9 +108,9 @@ func LdapLogin(user, pass, clientIP string) error {
if has { if has {
if config.Config.LDAP.CoverAttributes { if config.Config.LDAP.CoverAttributes {
_, err := DB["rdb"].Where("id=?", u.Id).Update(u) _, err := DB["rdb"].Where("id=?", u.Id).Update(u)
return err return nil, err
} else { } else {
return nil return &u, err
} }
} }
@ -111,34 +118,76 @@ func LdapLogin(user, pass, clientIP string) error {
u.Password = "******" u.Password = "******"
u.UUID = GenUUIDForUser(user) u.UUID = GenUUIDForUser(user)
_, err = DB["rdb"].Insert(u) _, err = DB["rdb"].Insert(u)
return err return &u, nil
} }
func PassLogin(user, pass, clientIP string) error { func PassLogin(user, pass string) (*User, error) {
var u User var u User
has, err := DB["rdb"].Where("username=?", user).Cols("password").Get(&u) has, err := DB["rdb"].Where("username=?", user).Get(&u)
if err != nil { if err != nil {
return err return nil, err
} }
if !has { if !has {
logger.Infof("password auth fail, no such user: %s", user) logger.Infof("password auth fail, no such user: %s", user)
return fmt.Errorf("login fail, check your username and password") return nil, fmt.Errorf("login fail, check your username and password")
} }
loginPass, err := CryptoPass(pass) loginPass, err := CryptoPass(pass)
if err != nil { if err != nil {
return err return nil, err
} }
if loginPass != u.Password { if loginPass != u.Password {
logger.Infof("password auth fail, password error, user: %s", user) logger.Infof("password auth fail, password error, user: %s", user)
return fmt.Errorf("login fail, check your username and password") return nil, fmt.Errorf("login fail, check your username and password")
} }
go LoginLogNew(user, clientIP, "in") return &u, nil
}
return nil func SmsCodeLogin(phone, code string) (*User, error) {
user, _ := UserGet("phone=?", phone)
if user == nil {
return nil, fmt.Errorf("phone %s dose not exist", phone)
}
lc, err := LoginCodeGet("username=? and code=? and login_type=?", user.Username, code, LOGIN_T_SMS)
if err != nil {
logger.Infof("sms-code auth fail, user: %s", user.Username)
return nil, fmt.Errorf("login fail, check your sms-code")
}
if time.Now().Unix()-lc.CreatedAt > LOGIN_EXPIRES_IN {
logger.Infof("sms-code auth expired, user: %s", user.Username)
return nil, fmt.Errorf("login fail, the code has expired")
}
lc.Del()
return user, nil
}
func EmailCodeLogin(email, code string) (*User, error) {
user, _ := UserGet("email=?", email)
if user == nil {
return nil, fmt.Errorf("email %s dose not exist", email)
}
lc, err := LoginCodeGet("username=? and code=? and login_type=?", user.Username, code, LOGIN_T_EMAIL)
if err != nil {
logger.Infof("email-code auth fail, user: %s", user.Username)
return nil, fmt.Errorf("login fail, check your email-code")
}
if time.Now().Unix()-lc.CreatedAt > LOGIN_EXPIRES_IN {
logger.Infof("email-code auth expired, user: %s", user.Username)
return nil, fmt.Errorf("login fail, the code has expired")
}
lc.Del()
return user, nil
} }
func UserGet(where string, args ...interface{}) (*User, error) { func UserGet(where string, args ...interface{}) (*User, error) {

View File

@ -22,6 +22,8 @@ func Config(r *gin.Engine) {
notLogin.GET("/auth/v2/callback", authCallbackV2) notLogin.GET("/auth/v2/callback", authCallbackV2)
notLogin.GET("/auth/v2/logout", logoutV2) notLogin.GET("/auth/v2/logout", logoutV2)
notLogin.POST("/auth/send-login-code-by-sms", v1SendLoginCodeBySms)
notLogin.POST("/auth/send-login-code-by-email", v1SendLoginCodeByEmail)
notLogin.POST("/auth/send-rst-code-by-sms", sendRstCodeBySms) notLogin.POST("/auth/send-rst-code-by-sms", sendRstCodeBySms)
notLogin.POST("/auth/rst-password", rstPassword) notLogin.POST("/auth/rst-password", rstPassword)
notLogin.GET("/auth/captcha", captchaGet) notLogin.GET("/auth/captcha", captchaGet)

View File

@ -27,6 +27,7 @@ var (
loginCodeSmsTpl *template.Template loginCodeSmsTpl *template.Template
loginCodeEmailTpl *template.Template loginCodeEmailTpl *template.Template
errUnsupportCaptcha = errors.New("unsupported captcha") errUnsupportCaptcha = errors.New("unsupported captcha")
errInvalidAnswer = errors.New("Invalid captcha answer")
// https://captcha.mojotv.cn // https://captcha.mojotv.cn
captchaDirver = base64Captcha.DriverString{ captchaDirver = base64Captcha.DriverString{
@ -39,55 +40,60 @@ var (
} }
) )
func getConfigFile(name, ext string) (string, error) {
if p := path.Join(path.Join(file.SelfDir(), "etc", name+".local."+ext)); file.IsExist(p) {
return p, nil
}
if p := path.Join(path.Join(file.SelfDir(), "etc", name+"."+ext)); file.IsExist(p) {
return p, nil
} else {
return "", fmt.Errorf("file %s not found", p)
}
}
func init() { func init() {
var err error filename, err := getConfigFile("login-code-sms", "tpl")
filename := path.Join(file.SelfDir(), "etc", "login-code-sms.tpl") if err != nil {
log.Fatal(err)
}
loginCodeSmsTpl, err = template.ParseFiles(filename) loginCodeSmsTpl, err = template.ParseFiles(filename)
if err != nil { if err != nil {
log.Fatalf("open %s err: %s", filename, err) log.Fatalf("open %s err: %s", filename, err)
} }
filename = path.Join(file.SelfDir(), "etc", "login-code-email.tpl") filename, err = getConfigFile("login-code-email", "tpl")
if err != nil {
log.Fatal(err)
}
loginCodeEmailTpl, err = template.ParseFiles(filename) loginCodeEmailTpl, err = template.ParseFiles(filename)
if err != nil { if err != nil {
log.Fatalf("open %s err: %s", filename, err) log.Fatalf("open %s err: %s", filename, err)
} }
} }
type loginForm struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
IsLDAP int `json:"is_ldap"`
RemoteAddr string `json:"remote_addr"`
}
func (f *loginForm) validate() {
if str.Dangerous(f.Username) {
bomb("%s invalid", f.Username)
}
if len(f.Username) > 64 {
bomb("%s too long", f.Username)
}
}
func login(c *gin.Context) { func login(c *gin.Context) {
var f loginForm var f loginInput
bind(c, &f) bind(c, &f)
f.validate() f.validate()
if f.IsLDAP == 1 { if config.Config.Captcha {
dangerous(models.LdapLogin(f.Username, f.Password, c.ClientIP())) c, err := models.CaptchaGet("captcha_id=?", f.CaptchaId)
} else { dangerous(err)
dangerous(models.PassLogin(f.Username, f.Password, c.ClientIP())) if strings.ToLower(c.Answer) != strings.ToLower(f.Answer) {
dangerous(errInvalidAnswer)
}
} }
user, err := models.UserGet("username=?", f.Username) user, err := authLogin(f)
dangerous(err) dangerous(err)
writeCookieUser(c, user.UUID) writeCookieUser(c, user.UUID)
renderMessage(c, "") renderMessage(c, "")
go models.LoginLogNew(user.Username, c.ClientIP(), "in")
} }
func logout(c *gin.Context) { func logout(c *gin.Context) {
@ -105,14 +111,14 @@ func logout(c *gin.Context) {
writeCookieUser(c, "") writeCookieUser(c, "")
go models.LoginLogNew(username, c.ClientIP(), "out")
if config.Config.SSO.Enable { if config.Config.SSO.Enable {
redirect := queryStr(c, "redirect", "/") redirect := queryStr(c, "redirect", "/")
c.Redirect(302, ssoc.LogoutLocation(redirect)) c.Redirect(302, ssoc.LogoutLocation(redirect))
} else { } else {
c.String(200, "logout successfully") c.String(200, "logout successfully")
} }
go models.LoginLogNew(username, c.ClientIP(), "out")
} }
type authRedirect struct { type authRedirect struct {
@ -181,15 +187,16 @@ func logoutV2(c *gin.Context) {
writeCookieUser(c, "") writeCookieUser(c, "")
ret.Msg = "logout successfully" ret.Msg = "logout successfully"
go models.LoginLogNew(username, c.ClientIP(), "out")
if config.Config.SSO.Enable { if config.Config.SSO.Enable {
if redirect == "" { if redirect == "" {
redirect = "/" redirect = "/"
} }
ret.Redirect = ssoc.LogoutLocation(redirect) ret.Redirect = ssoc.LogoutLocation(redirect)
} }
renderData(c, ret, nil) renderData(c, ret, nil)
go models.LoginLogNew(username, c.ClientIP(), "out")
} }
type loginInput struct { type loginInput struct {
@ -198,51 +205,53 @@ type loginInput struct {
Phone string `json:"phone"` Phone string `json:"phone"`
Email string `json:"email"` Email string `json:"email"`
Code string `json:"code"` Code string `json:"code"`
Type string `json:"type"` CaptchaId string `json:"captcha_id"`
RemoteAddr string `json:"remote_addr"` Answer string `json:"answer" description:"captcha answer"`
Type string `json:"type" description:"sms-code|email-code|password|ldap"`
RemoteAddr string `json:"remote_addr" description:"use for server account(v1)"`
IsLDAP int `json:"is_ldap" description:"deprecated"`
} }
const ( func (f *loginInput) validate() {
LOGIN_T_SMS = "sms-code" if f.IsLDAP == 1 {
LOGIN_T_EMAIL = "email-code" f.Type = models.LOGIN_T_LDAP
LOGIN_T_RST = "rst-code" }
LOGIN_T_PWD = "password" if f.Type == "" {
LOGIN_T_LDAP = "ldap" f.Type = models.LOGIN_T_PWD
LOGIN_EXPIRES_IN = 300 }
) if f.Type == models.LOGIN_T_PWD {
if str.Dangerous(f.Username) {
bomb("%s invalid", f.Username)
}
if len(f.Username) > 64 {
bomb("%s too long", f.Username)
}
}
}
// v1Login called by sso.rdb module // v1Login called by sso.rdb module
func v1Login(c *gin.Context) { func v1Login(c *gin.Context) {
var f loginInput var f loginInput
bind(c, &f) bind(c, &f)
user, err := func() (*models.User, error) { user, err := authLogin(f)
switch strings.ToLower(f.Type) { renderData(c, *user, err)
case LOGIN_T_LDAP: }
err := models.LdapLogin(f.Username, f.Password, c.ClientIP())
if err != nil {
return nil, err
}
return models.UserGet("username=?", f.Username)
case LOGIN_T_PWD:
err := models.PassLogin(f.Username, f.Password, c.ClientIP())
if err != nil {
return nil, err
}
return models.UserGet("username=?", f.Username)
case LOGIN_T_SMS:
return smsCodeVerify(f.Phone, f.Code)
case LOGIN_T_EMAIL:
return emailCodeVerify(f.Email, f.Code)
default:
return nil, fmt.Errorf("invalid login type %s", f.Type)
}
}()
// TODO: implement remote address access control // authLogin called by /v1/rdb/login, /api/rdb/auth/login
go models.LoginLogNew(f.Username, f.RemoteAddr, "in") func authLogin(in loginInput) (user *models.User, err error) {
switch strings.ToLower(in.Type) {
renderData(c, user, err) case models.LOGIN_T_LDAP:
return models.LdapLogin(in.Username, in.Password)
case models.LOGIN_T_PWD:
return models.PassLogin(in.Username, in.Password)
case models.LOGIN_T_SMS:
return models.SmsCodeLogin(in.Phone, in.Code)
case models.LOGIN_T_EMAIL:
return models.EmailCodeLogin(in.Email, in.Code)
default:
return nil, fmt.Errorf("invalid login type %s", in.Type)
}
} }
type v1SendLoginCodeBySmsInput struct { type v1SendLoginCodeBySmsInput struct {
@ -269,7 +278,7 @@ func v1SendLoginCodeBySms(c *gin.Context) {
loginCode := &models.LoginCode{ loginCode := &models.LoginCode{
Username: user.Username, Username: user.Username,
Code: code, Code: code,
LoginType: LOGIN_T_SMS, LoginType: models.LOGIN_T_SMS,
CreatedAt: time.Now().Unix(), CreatedAt: time.Now().Unix(),
} }
@ -298,26 +307,6 @@ func v1SendLoginCodeBySms(c *gin.Context) {
renderData(c, msg, err) renderData(c, msg, err)
} }
func smsCodeVerify(phone, code string) (*models.User, error) {
user, _ := models.UserGet("phone=?", phone)
if user == nil {
return nil, fmt.Errorf("phone %s dose not exist", phone)
}
lc, err := models.LoginCodeGet("username=? and code=? and login_type=?", user.Username, code, LOGIN_T_SMS)
if err != nil {
return nil, fmt.Errorf("invalid code", phone)
}
if time.Now().Unix()-lc.CreatedAt > LOGIN_EXPIRES_IN {
return nil, fmt.Errorf("the code has expired", phone)
}
lc.Del()
return user, nil
}
type v1SendLoginCodeByEmailInput struct { type v1SendLoginCodeByEmailInput struct {
Email string `json:"email"` Email string `json:"email"`
} }
@ -342,7 +331,7 @@ func v1SendLoginCodeByEmail(c *gin.Context) {
loginCode := &models.LoginCode{ loginCode := &models.LoginCode{
Username: user.Username, Username: user.Username,
Code: code, Code: code,
LoginType: LOGIN_T_EMAIL, LoginType: models.LOGIN_T_EMAIL,
CreatedAt: time.Now().Unix(), CreatedAt: time.Now().Unix(),
} }
@ -369,26 +358,6 @@ func v1SendLoginCodeByEmail(c *gin.Context) {
renderData(c, msg, err) renderData(c, msg, err)
} }
func emailCodeVerify(email, code string) (*models.User, error) {
user, _ := models.UserGet("email=?", email)
if user == nil {
return nil, fmt.Errorf("email %s dose not exist", email)
}
lc, err := models.LoginCodeGet("username=? and code=? and login_type=?", user.Username, code, LOGIN_T_EMAIL)
if err != nil {
return nil, fmt.Errorf("invalid code", email)
}
if time.Now().Unix()-lc.CreatedAt > LOGIN_EXPIRES_IN {
return nil, fmt.Errorf("the code has expired", email)
}
lc.Del()
return user, nil
}
type sendRstCodeBySmsInput struct { type sendRstCodeBySmsInput struct {
Phone string `json:"phone"` Phone string `json:"phone"`
} }
@ -413,7 +382,7 @@ func sendRstCodeBySms(c *gin.Context) {
loginCode := &models.LoginCode{ loginCode := &models.LoginCode{
Username: user.Username, Username: user.Username,
Code: code, Code: code,
LoginType: LOGIN_T_RST, LoginType: models.LOGIN_T_RST,
CreatedAt: time.Now().Unix(), CreatedAt: time.Now().Unix(),
} }
@ -459,12 +428,12 @@ func rstPassword(c *gin.Context) {
} }
lc, err := models.LoginCodeGet("username=? and code=? and login_type=?", lc, err := models.LoginCodeGet("username=? and code=? and login_type=?",
user.Username, in.Code, LOGIN_T_RST) user.Username, in.Code, models.LOGIN_T_RST)
if err != nil { if err != nil {
return fmt.Errorf("invalid code", in.Phone) return fmt.Errorf("invalid code", in.Phone)
} }
if time.Now().Unix()-lc.CreatedAt > LOGIN_EXPIRES_IN { if time.Now().Unix()-lc.CreatedAt > models.LOGIN_EXPIRES_IN {
return fmt.Errorf("the code has expired", in.Phone) return fmt.Errorf("the code has expired", in.Phone)
} }