support redis tls
This commit is contained in:
parent
e00f102703
commit
e0f0e08852
|
@ -115,10 +115,11 @@ Timeout = 3000
|
|||
[Redis]
|
||||
# address, ip:port
|
||||
Address = "127.0.0.1:6379"
|
||||
# requirepass
|
||||
Password = ""
|
||||
# # db
|
||||
# Username = ""
|
||||
# Password = ""
|
||||
# DB = 0
|
||||
# UseTLS = false
|
||||
# MinVersion = "1.2"
|
||||
|
||||
[Gorm]
|
||||
# enable debug mode or not
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
package tls
|
||||
|
||||
import "crypto/tls"
|
||||
|
||||
var tlsVersionMap = map[string]uint16{
|
||||
"TLS10": tls.VersionTLS10,
|
||||
"TLS11": tls.VersionTLS11,
|
||||
"TLS12": tls.VersionTLS12,
|
||||
"TLS13": tls.VersionTLS13,
|
||||
}
|
||||
|
||||
var tlsCipherMap = map[string]uint16{
|
||||
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256,
|
||||
"TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||
"TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||
"TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
"TLS_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
"TLS_RSA_WITH_RC4_128_SHA": tls.TLS_RSA_WITH_RC4_128_SHA,
|
||||
"TLS_ECDHE_RSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
|
||||
"TLS_ECDHE_ECDSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
|
||||
"TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256,
|
||||
"TLS_AES_256_GCM_SHA384": tls.TLS_AES_256_GCM_SHA384,
|
||||
"TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256,
|
||||
}
|
|
@ -0,0 +1,196 @@
|
|||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/toolkits/pkg/slice"
|
||||
)
|
||||
|
||||
// ClientConfig represents the standard client TLS config.
|
||||
type ClientConfig struct {
|
||||
TLSCA string
|
||||
TLSCert string
|
||||
TLSKey string
|
||||
TLSKeyPwd string
|
||||
InsecureSkipVerify bool
|
||||
ServerName string
|
||||
MinVersion string
|
||||
}
|
||||
|
||||
// ServerConfig represents the standard server TLS config.
|
||||
type ServerConfig struct {
|
||||
TLSCert string
|
||||
TLSKey string
|
||||
TLSKeyPwd string
|
||||
TLSAllowedCACerts []string
|
||||
TLSCipherSuites []string
|
||||
TLSMinVersion string
|
||||
TLSMaxVersion string
|
||||
TLSAllowedDNSNames []string
|
||||
}
|
||||
|
||||
// TLSConfig returns a tls.Config, may be nil without error if TLS is not
|
||||
// configured.
|
||||
func (c *ClientConfig) TLSConfig() (*tls.Config, error) {
|
||||
// This check returns a nil (aka, "use the default")
|
||||
// tls.Config if no field is set that would have an effect on
|
||||
// a TLS connection. That is, any of:
|
||||
// * client certificate settings,
|
||||
// * peer certificate authorities,
|
||||
// * disabled security, or
|
||||
// * an SNI server name.
|
||||
if c.TLSCA == "" && c.TLSKey == "" && c.TLSCert == "" && !c.InsecureSkipVerify && c.ServerName == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: c.InsecureSkipVerify,
|
||||
Renegotiation: tls.RenegotiateNever,
|
||||
}
|
||||
|
||||
if c.TLSCA != "" {
|
||||
pool, err := makeCertPool([]string{c.TLSCA})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig.RootCAs = pool
|
||||
}
|
||||
|
||||
if c.TLSCert != "" && c.TLSKey != "" {
|
||||
err := loadCertificate(tlsConfig, c.TLSCert, c.TLSKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if c.ServerName != "" {
|
||||
tlsConfig.ServerName = c.ServerName
|
||||
}
|
||||
|
||||
if c.MinVersion == "1.0" {
|
||||
tlsConfig.MinVersion = tls.VersionTLS10
|
||||
} else if c.MinVersion == "1.1" {
|
||||
tlsConfig.MinVersion = tls.VersionTLS11
|
||||
} else if c.MinVersion == "1.2" {
|
||||
tlsConfig.MinVersion = tls.VersionTLS12
|
||||
} else if c.MinVersion == "1.3" {
|
||||
tlsConfig.MinVersion = tls.VersionTLS13
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// TLSConfig returns a tls.Config, may be nil without error if TLS is not
|
||||
// configured.
|
||||
func (c *ServerConfig) TLSConfig() (*tls.Config, error) {
|
||||
if c.TLSCert == "" && c.TLSKey == "" && len(c.TLSAllowedCACerts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
if len(c.TLSAllowedCACerts) != 0 {
|
||||
pool, err := makeCertPool(c.TLSAllowedCACerts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig.ClientCAs = pool
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
|
||||
if c.TLSCert != "" && c.TLSKey != "" {
|
||||
err := loadCertificate(tlsConfig, c.TLSCert, c.TLSKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.TLSCipherSuites) != 0 {
|
||||
cipherSuites, err := ParseCiphers(c.TLSCipherSuites)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"could not parse server cipher suites %s: %v", strings.Join(c.TLSCipherSuites, ","), err)
|
||||
}
|
||||
tlsConfig.CipherSuites = cipherSuites
|
||||
}
|
||||
|
||||
if c.TLSMaxVersion != "" {
|
||||
version, err := ParseTLSVersion(c.TLSMaxVersion)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"could not parse tls max version %q: %v", c.TLSMaxVersion, err)
|
||||
}
|
||||
tlsConfig.MaxVersion = version
|
||||
}
|
||||
|
||||
if c.TLSMinVersion != "" {
|
||||
version, err := ParseTLSVersion(c.TLSMinVersion)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"could not parse tls min version %q: %v", c.TLSMinVersion, err)
|
||||
}
|
||||
tlsConfig.MinVersion = version
|
||||
}
|
||||
|
||||
if tlsConfig.MinVersion != 0 && tlsConfig.MaxVersion != 0 && tlsConfig.MinVersion > tlsConfig.MaxVersion {
|
||||
return nil, fmt.Errorf(
|
||||
"tls min version %q can't be greater than tls max version %q", tlsConfig.MinVersion, tlsConfig.MaxVersion)
|
||||
}
|
||||
|
||||
// Since clientAuth is tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
// there must be certs to validate.
|
||||
if len(c.TLSAllowedCACerts) > 0 && len(c.TLSAllowedDNSNames) > 0 {
|
||||
tlsConfig.VerifyPeerCertificate = c.verifyPeerCertificate
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
func makeCertPool(certFiles []string) (*x509.CertPool, error) {
|
||||
pool := x509.NewCertPool()
|
||||
for _, certFile := range certFiles {
|
||||
pem, err := os.ReadFile(certFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"could not read certificate %q: %v", certFile, err)
|
||||
}
|
||||
if !pool.AppendCertsFromPEM(pem) {
|
||||
return nil, fmt.Errorf(
|
||||
"could not parse any PEM certificates %q: %v", certFile, err)
|
||||
}
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
func loadCertificate(config *tls.Config, certFile, keyFile string) error {
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"could not load keypair %s:%s: %v", certFile, keyFile, err)
|
||||
}
|
||||
|
||||
config.Certificates = []tls.Certificate{cert}
|
||||
config.BuildNameToCertificate()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ServerConfig) verifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
// The certificate chain is client + intermediate + root.
|
||||
// Let's review the client certificate.
|
||||
cert, err := x509.ParseCertificate(rawCerts[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not validate peer certificate: %v", err)
|
||||
}
|
||||
|
||||
for _, name := range cert.DNSNames {
|
||||
if slice.ContainsString(c.TLSAllowedDNSNames, name) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("peer certificate not in allowed DNS Name list: %v", cert.DNSNames)
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
package tls
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ParseCiphers returns a `[]uint16` by received `[]string` key that represents ciphers from crypto/tls.
|
||||
// If some of ciphers in received list doesn't exists ParseCiphers returns nil with error
|
||||
func ParseCiphers(ciphers []string) ([]uint16, error) {
|
||||
suites := []uint16{}
|
||||
|
||||
for _, cipher := range ciphers {
|
||||
v, ok := tlsCipherMap[cipher]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported cipher %q", cipher)
|
||||
}
|
||||
suites = append(suites, v)
|
||||
}
|
||||
|
||||
return suites, nil
|
||||
}
|
||||
|
||||
// ParseTLSVersion returns a `uint16` by received version string key that represents tls version from crypto/tls.
|
||||
// If version isn't supported ParseTLSVersion returns 0 with error
|
||||
func ParseTLSVersion(version string) (uint16, error) {
|
||||
if v, ok := tlsVersionMap[version]; ok {
|
||||
return v, nil
|
||||
}
|
||||
return 0, fmt.Errorf("unsupported version %q", version)
|
||||
}
|
|
@ -11,12 +11,16 @@ import (
|
|||
"gorm.io/gorm"
|
||||
|
||||
"github.com/didi/nightingale/v5/src/pkg/ormx"
|
||||
"github.com/didi/nightingale/v5/src/pkg/tls"
|
||||
)
|
||||
|
||||
type RedisConfig struct {
|
||||
Address string
|
||||
Username string
|
||||
Password string
|
||||
DB int
|
||||
UseTLS bool
|
||||
tls.ClientConfig
|
||||
}
|
||||
|
||||
type DBConfig struct {
|
||||
|
@ -101,15 +105,27 @@ func newGormDB(cfg DBConfig) (*gorm.DB, error) {
|
|||
var Redis *redis.Client
|
||||
|
||||
func InitRedis(cfg RedisConfig) (func(), error) {
|
||||
Redis = redis.NewClient(&redis.Options{
|
||||
redisOptions := &redis.Options{
|
||||
Addr: cfg.Address,
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
})
|
||||
}
|
||||
|
||||
if cfg.UseTLS {
|
||||
tlsConfig, err := cfg.TLSConfig()
|
||||
if err != nil {
|
||||
fmt.Println("failed to init redis tls config:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
redisOptions.TLSConfig = tlsConfig
|
||||
}
|
||||
|
||||
Redis = redis.NewClient(redisOptions)
|
||||
|
||||
err := Redis.Ping(context.Background()).Err()
|
||||
if err != nil {
|
||||
fmt.Println("ping redis failed:", err)
|
||||
fmt.Println("failed to ping redis:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue