This commit is contained in:
UlricQin 2020-11-03 12:11:32 +08:00
commit b448dad860
5 changed files with 156 additions and 128 deletions

View File

@ -92,3 +92,5 @@ wechat:
corp_id: "xxxxxxxxxxxxx"
agent_id: 1000000
secret: "xxxxxxxxxxxxxxxxx"
captcha: false

View File

@ -1,5 +1,7 @@
package models
import "errors"
type LoginCode struct {
Username string `json:"username"`
Code string `json:"code"`
@ -7,6 +9,10 @@ type LoginCode struct {
CreatedAt int64 `json:"created_at"`
}
var (
errLoginCode = errors.New("invalid login code")
)
func LoginCodeGet(where string, args ...interface{}) (*LoginCode, error) {
var obj LoginCode
has, err := DB["rdb"].Where(where, args...).Get(&obj)
@ -15,7 +21,7 @@ func LoginCodeGet(where string, args ...interface{}) (*LoginCode, error) {
}
if !has {
return nil, nil
return nil, errLoginCode
}
return &obj, nil

View File

@ -18,6 +18,15 @@ import (
"github.com/didi/nightingale/src/modules/rdb/config"
)
const (
LOGIN_T_SMS = "sms-code"
LOGIN_T_EMAIL = "email-code"
LOGIN_T_RST = "rst-code"
LOGIN_T_PWD = "password"
LOGIN_T_LDAP = "ldap"
LOGIN_EXPIRES_IN = 300
)
type User struct {
Id int64 `json:"id"`
UUID string `json:"-" xorm:"'uuid'"`
@ -82,18 +91,16 @@ func InitRooter() {
log.Println("user root init done")
}
func LdapLogin(user, pass, clientIP string) error {
func LdapLogin(user, pass string) (*User, error) {
sr, err := ldapReq(user, pass)
if err != nil {
return err
return nil, err
}
go LoginLogNew(user, clientIP, "in")
var u User
has, err := DB["rdb"].Where("username=?", user).Get(&u)
if err != nil {
return err
return nil, err
}
u.CopyLdapAttr(sr)
@ -101,9 +108,9 @@ func LdapLogin(user, pass, clientIP string) error {
if has {
if config.Config.LDAP.CoverAttributes {
_, err := DB["rdb"].Where("id=?", u.Id).Update(u)
return err
return nil, err
} else {
return nil
return &u, err
}
}
@ -111,34 +118,76 @@ func LdapLogin(user, pass, clientIP string) error {
u.Password = "******"
u.UUID = GenUUIDForUser(user)
_, err = DB["rdb"].Insert(u)
return err
return &u, nil
}
func PassLogin(user, pass, clientIP string) error {
func PassLogin(user, pass string) (*User, error) {
var u User
has, err := DB["rdb"].Where("username=?", user).Cols("password").Get(&u)
has, err := DB["rdb"].Where("username=?", user).Get(&u)
if err != nil {
return err
return nil, err
}
if !has {
logger.Infof("password auth fail, no such user: %s", user)
return fmt.Errorf("login fail, check your username and password")
return nil, fmt.Errorf("login fail, check your username and password")
}
loginPass, err := CryptoPass(pass)
if err != nil {
return err
return nil, err
}
if loginPass != u.Password {
logger.Infof("password auth fail, password error, user: %s", user)
return fmt.Errorf("login fail, check your username and password")
return nil, fmt.Errorf("login fail, check your username and password")
}
go LoginLogNew(user, clientIP, "in")
return &u, nil
}
return nil
func SmsCodeLogin(phone, code string) (*User, error) {
user, _ := UserGet("phone=?", phone)
if user == nil {
return nil, fmt.Errorf("phone %s dose not exist", phone)
}
lc, err := LoginCodeGet("username=? and code=? and login_type=?", user.Username, code, LOGIN_T_SMS)
if err != nil {
logger.Infof("sms-code auth fail, user: %s", user.Username)
return nil, fmt.Errorf("login fail, check your sms-code")
}
if time.Now().Unix()-lc.CreatedAt > LOGIN_EXPIRES_IN {
logger.Infof("sms-code auth expired, user: %s", user.Username)
return nil, fmt.Errorf("login fail, the code has expired")
}
lc.Del()
return user, nil
}
func EmailCodeLogin(email, code string) (*User, error) {
user, _ := UserGet("email=?", email)
if user == nil {
return nil, fmt.Errorf("email %s dose not exist", email)
}
lc, err := LoginCodeGet("username=? and code=? and login_type=?", user.Username, code, LOGIN_T_EMAIL)
if err != nil {
logger.Infof("email-code auth fail, user: %s", user.Username)
return nil, fmt.Errorf("login fail, check your email-code")
}
if time.Now().Unix()-lc.CreatedAt > LOGIN_EXPIRES_IN {
logger.Infof("email-code auth expired, user: %s", user.Username)
return nil, fmt.Errorf("login fail, the code has expired")
}
lc.Del()
return user, nil
}
func UserGet(where string, args ...interface{}) (*User, error) {

View File

@ -22,6 +22,8 @@ func Config(r *gin.Engine) {
notLogin.GET("/auth/v2/callback", authCallbackV2)
notLogin.GET("/auth/v2/logout", logoutV2)
notLogin.POST("/auth/send-login-code-by-sms", v1SendLoginCodeBySms)
notLogin.POST("/auth/send-login-code-by-email", v1SendLoginCodeByEmail)
notLogin.POST("/auth/send-rst-code-by-sms", sendRstCodeBySms)
notLogin.POST("/auth/rst-password", rstPassword)
notLogin.GET("/auth/captcha", captchaGet)

View File

@ -27,6 +27,7 @@ var (
loginCodeSmsTpl *template.Template
loginCodeEmailTpl *template.Template
errUnsupportCaptcha = errors.New("unsupported captcha")
errInvalidAnswer = errors.New("Invalid captcha answer")
// https://captcha.mojotv.cn
captchaDirver = base64Captcha.DriverString{
@ -39,55 +40,60 @@ var (
}
)
func getConfigFile(name, ext string) (string, error) {
if p := path.Join(path.Join(file.SelfDir(), "etc", name+".local."+ext)); file.IsExist(p) {
return p, nil
}
if p := path.Join(path.Join(file.SelfDir(), "etc", name+"."+ext)); file.IsExist(p) {
return p, nil
} else {
return "", fmt.Errorf("file %s not found", p)
}
}
func init() {
var err error
filename := path.Join(file.SelfDir(), "etc", "login-code-sms.tpl")
filename, err := getConfigFile("login-code-sms", "tpl")
if err != nil {
log.Fatal(err)
}
loginCodeSmsTpl, err = template.ParseFiles(filename)
if err != nil {
log.Fatalf("open %s err: %s", filename, err)
}
filename = path.Join(file.SelfDir(), "etc", "login-code-email.tpl")
filename, err = getConfigFile("login-code-email", "tpl")
if err != nil {
log.Fatal(err)
}
loginCodeEmailTpl, err = template.ParseFiles(filename)
if err != nil {
log.Fatalf("open %s err: %s", filename, err)
}
}
type loginForm struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
IsLDAP int `json:"is_ldap"`
RemoteAddr string `json:"remote_addr"`
}
func (f *loginForm) validate() {
if str.Dangerous(f.Username) {
bomb("%s invalid", f.Username)
}
if len(f.Username) > 64 {
bomb("%s too long", f.Username)
}
}
func login(c *gin.Context) {
var f loginForm
var f loginInput
bind(c, &f)
f.validate()
if f.IsLDAP == 1 {
dangerous(models.LdapLogin(f.Username, f.Password, c.ClientIP()))
} else {
dangerous(models.PassLogin(f.Username, f.Password, c.ClientIP()))
if config.Config.Captcha {
c, err := models.CaptchaGet("captcha_id=?", f.CaptchaId)
dangerous(err)
if strings.ToLower(c.Answer) != strings.ToLower(f.Answer) {
dangerous(errInvalidAnswer)
}
}
user, err := models.UserGet("username=?", f.Username)
user, err := authLogin(f)
dangerous(err)
writeCookieUser(c, user.UUID)
renderMessage(c, "")
go models.LoginLogNew(user.Username, c.ClientIP(), "in")
}
func logout(c *gin.Context) {
@ -105,14 +111,14 @@ func logout(c *gin.Context) {
writeCookieUser(c, "")
go models.LoginLogNew(username, c.ClientIP(), "out")
if config.Config.SSO.Enable {
redirect := queryStr(c, "redirect", "/")
c.Redirect(302, ssoc.LogoutLocation(redirect))
} else {
c.String(200, "logout successfully")
}
go models.LoginLogNew(username, c.ClientIP(), "out")
}
type authRedirect struct {
@ -181,15 +187,16 @@ func logoutV2(c *gin.Context) {
writeCookieUser(c, "")
ret.Msg = "logout successfully"
go models.LoginLogNew(username, c.ClientIP(), "out")
if config.Config.SSO.Enable {
if redirect == "" {
redirect = "/"
}
ret.Redirect = ssoc.LogoutLocation(redirect)
}
renderData(c, ret, nil)
go models.LoginLogNew(username, c.ClientIP(), "out")
}
type loginInput struct {
@ -198,51 +205,53 @@ type loginInput struct {
Phone string `json:"phone"`
Email string `json:"email"`
Code string `json:"code"`
Type string `json:"type"`
RemoteAddr string `json:"remote_addr"`
CaptchaId string `json:"captcha_id"`
Answer string `json:"answer" description:"captcha answer"`
Type string `json:"type" description:"sms-code|email-code|password|ldap"`
RemoteAddr string `json:"remote_addr" description:"use for server account(v1)"`
IsLDAP int `json:"is_ldap" description:"deprecated"`
}
const (
LOGIN_T_SMS = "sms-code"
LOGIN_T_EMAIL = "email-code"
LOGIN_T_RST = "rst-code"
LOGIN_T_PWD = "password"
LOGIN_T_LDAP = "ldap"
LOGIN_EXPIRES_IN = 300
)
func (f *loginInput) validate() {
if f.IsLDAP == 1 {
f.Type = models.LOGIN_T_LDAP
}
if f.Type == "" {
f.Type = models.LOGIN_T_PWD
}
if f.Type == models.LOGIN_T_PWD {
if str.Dangerous(f.Username) {
bomb("%s invalid", f.Username)
}
if len(f.Username) > 64 {
bomb("%s too long", f.Username)
}
}
}
// v1Login called by sso.rdb module
func v1Login(c *gin.Context) {
var f loginInput
bind(c, &f)
user, err := func() (*models.User, error) {
switch strings.ToLower(f.Type) {
case LOGIN_T_LDAP:
err := models.LdapLogin(f.Username, f.Password, c.ClientIP())
if err != nil {
return nil, err
}
return models.UserGet("username=?", f.Username)
case LOGIN_T_PWD:
err := models.PassLogin(f.Username, f.Password, c.ClientIP())
if err != nil {
return nil, err
}
return models.UserGet("username=?", f.Username)
case LOGIN_T_SMS:
return smsCodeVerify(f.Phone, f.Code)
case LOGIN_T_EMAIL:
return emailCodeVerify(f.Email, f.Code)
default:
return nil, fmt.Errorf("invalid login type %s", f.Type)
}
}()
user, err := authLogin(f)
renderData(c, *user, err)
}
// TODO: implement remote address access control
go models.LoginLogNew(f.Username, f.RemoteAddr, "in")
renderData(c, user, err)
// authLogin called by /v1/rdb/login, /api/rdb/auth/login
func authLogin(in loginInput) (user *models.User, err error) {
switch strings.ToLower(in.Type) {
case models.LOGIN_T_LDAP:
return models.LdapLogin(in.Username, in.Password)
case models.LOGIN_T_PWD:
return models.PassLogin(in.Username, in.Password)
case models.LOGIN_T_SMS:
return models.SmsCodeLogin(in.Phone, in.Code)
case models.LOGIN_T_EMAIL:
return models.EmailCodeLogin(in.Email, in.Code)
default:
return nil, fmt.Errorf("invalid login type %s", in.Type)
}
}
type v1SendLoginCodeBySmsInput struct {
@ -269,7 +278,7 @@ func v1SendLoginCodeBySms(c *gin.Context) {
loginCode := &models.LoginCode{
Username: user.Username,
Code: code,
LoginType: LOGIN_T_SMS,
LoginType: models.LOGIN_T_SMS,
CreatedAt: time.Now().Unix(),
}
@ -298,26 +307,6 @@ func v1SendLoginCodeBySms(c *gin.Context) {
renderData(c, msg, err)
}
func smsCodeVerify(phone, code string) (*models.User, error) {
user, _ := models.UserGet("phone=?", phone)
if user == nil {
return nil, fmt.Errorf("phone %s dose not exist", phone)
}
lc, err := models.LoginCodeGet("username=? and code=? and login_type=?", user.Username, code, LOGIN_T_SMS)
if err != nil {
return nil, fmt.Errorf("invalid code", phone)
}
if time.Now().Unix()-lc.CreatedAt > LOGIN_EXPIRES_IN {
return nil, fmt.Errorf("the code has expired", phone)
}
lc.Del()
return user, nil
}
type v1SendLoginCodeByEmailInput struct {
Email string `json:"email"`
}
@ -342,7 +331,7 @@ func v1SendLoginCodeByEmail(c *gin.Context) {
loginCode := &models.LoginCode{
Username: user.Username,
Code: code,
LoginType: LOGIN_T_EMAIL,
LoginType: models.LOGIN_T_EMAIL,
CreatedAt: time.Now().Unix(),
}
@ -369,26 +358,6 @@ func v1SendLoginCodeByEmail(c *gin.Context) {
renderData(c, msg, err)
}
func emailCodeVerify(email, code string) (*models.User, error) {
user, _ := models.UserGet("email=?", email)
if user == nil {
return nil, fmt.Errorf("email %s dose not exist", email)
}
lc, err := models.LoginCodeGet("username=? and code=? and login_type=?", user.Username, code, LOGIN_T_EMAIL)
if err != nil {
return nil, fmt.Errorf("invalid code", email)
}
if time.Now().Unix()-lc.CreatedAt > LOGIN_EXPIRES_IN {
return nil, fmt.Errorf("the code has expired", email)
}
lc.Del()
return user, nil
}
type sendRstCodeBySmsInput struct {
Phone string `json:"phone"`
}
@ -413,7 +382,7 @@ func sendRstCodeBySms(c *gin.Context) {
loginCode := &models.LoginCode{
Username: user.Username,
Code: code,
LoginType: LOGIN_T_RST,
LoginType: models.LOGIN_T_RST,
CreatedAt: time.Now().Unix(),
}
@ -459,12 +428,12 @@ func rstPassword(c *gin.Context) {
}
lc, err := models.LoginCodeGet("username=? and code=? and login_type=?",
user.Username, in.Code, LOGIN_T_RST)
user.Username, in.Code, models.LOGIN_T_RST)
if err != nil {
return fmt.Errorf("invalid code", in.Phone)
}
if time.Now().Unix()-lc.CreatedAt > LOGIN_EXPIRES_IN {
if time.Now().Unix()-lc.CreatedAt > models.LOGIN_EXPIRES_IN {
return fmt.Errorf("the code has expired", in.Phone)
}