support redis tls

This commit is contained in:
Ulric Qin 2022-04-22 21:48:56 +08:00
parent e00f102703
commit e0f0e08852
5 changed files with 287 additions and 6 deletions

View File

@ -115,10 +115,11 @@ Timeout = 3000
[Redis] [Redis]
# address, ip:port # address, ip:port
Address = "127.0.0.1:6379" Address = "127.0.0.1:6379"
# requirepass # Username = ""
Password = "" # Password = ""
# # db
# DB = 0 # DB = 0
# UseTLS = false
# MinVersion = "1.2"
[Gorm] [Gorm]
# enable debug mode or not # enable debug mode or not

38
src/pkg/tls/common.go Normal file
View File

@ -0,0 +1,38 @@
package tls
import "crypto/tls"
var tlsVersionMap = map[string]uint16{
"TLS10": tls.VersionTLS10,
"TLS11": tls.VersionTLS11,
"TLS12": tls.VersionTLS12,
"TLS13": tls.VersionTLS13,
}
var tlsCipherMap = map[string]uint16{
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
"TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256,
"TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA,
"TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA,
"TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
"TLS_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
"TLS_RSA_WITH_RC4_128_SHA": tls.TLS_RSA_WITH_RC4_128_SHA,
"TLS_ECDHE_RSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
"TLS_ECDHE_ECDSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
"TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256,
"TLS_AES_256_GCM_SHA384": tls.TLS_AES_256_GCM_SHA384,
"TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256,
}

196
src/pkg/tls/config.go Normal file
View File

@ -0,0 +1,196 @@
package tls
import (
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"strings"
"github.com/toolkits/pkg/slice"
)
// ClientConfig represents the standard client TLS config.
type ClientConfig struct {
TLSCA string
TLSCert string
TLSKey string
TLSKeyPwd string
InsecureSkipVerify bool
ServerName string
MinVersion string
}
// ServerConfig represents the standard server TLS config.
type ServerConfig struct {
TLSCert string
TLSKey string
TLSKeyPwd string
TLSAllowedCACerts []string
TLSCipherSuites []string
TLSMinVersion string
TLSMaxVersion string
TLSAllowedDNSNames []string
}
// TLSConfig returns a tls.Config, may be nil without error if TLS is not
// configured.
func (c *ClientConfig) TLSConfig() (*tls.Config, error) {
// This check returns a nil (aka, "use the default")
// tls.Config if no field is set that would have an effect on
// a TLS connection. That is, any of:
// * client certificate settings,
// * peer certificate authorities,
// * disabled security, or
// * an SNI server name.
if c.TLSCA == "" && c.TLSKey == "" && c.TLSCert == "" && !c.InsecureSkipVerify && c.ServerName == "" {
return nil, nil
}
tlsConfig := &tls.Config{
InsecureSkipVerify: c.InsecureSkipVerify,
Renegotiation: tls.RenegotiateNever,
}
if c.TLSCA != "" {
pool, err := makeCertPool([]string{c.TLSCA})
if err != nil {
return nil, err
}
tlsConfig.RootCAs = pool
}
if c.TLSCert != "" && c.TLSKey != "" {
err := loadCertificate(tlsConfig, c.TLSCert, c.TLSKey)
if err != nil {
return nil, err
}
}
if c.ServerName != "" {
tlsConfig.ServerName = c.ServerName
}
if c.MinVersion == "1.0" {
tlsConfig.MinVersion = tls.VersionTLS10
} else if c.MinVersion == "1.1" {
tlsConfig.MinVersion = tls.VersionTLS11
} else if c.MinVersion == "1.2" {
tlsConfig.MinVersion = tls.VersionTLS12
} else if c.MinVersion == "1.3" {
tlsConfig.MinVersion = tls.VersionTLS13
}
return tlsConfig, nil
}
// TLSConfig returns a tls.Config, may be nil without error if TLS is not
// configured.
func (c *ServerConfig) TLSConfig() (*tls.Config, error) {
if c.TLSCert == "" && c.TLSKey == "" && len(c.TLSAllowedCACerts) == 0 {
return nil, nil
}
tlsConfig := &tls.Config{}
if len(c.TLSAllowedCACerts) != 0 {
pool, err := makeCertPool(c.TLSAllowedCACerts)
if err != nil {
return nil, err
}
tlsConfig.ClientCAs = pool
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
if c.TLSCert != "" && c.TLSKey != "" {
err := loadCertificate(tlsConfig, c.TLSCert, c.TLSKey)
if err != nil {
return nil, err
}
}
if len(c.TLSCipherSuites) != 0 {
cipherSuites, err := ParseCiphers(c.TLSCipherSuites)
if err != nil {
return nil, fmt.Errorf(
"could not parse server cipher suites %s: %v", strings.Join(c.TLSCipherSuites, ","), err)
}
tlsConfig.CipherSuites = cipherSuites
}
if c.TLSMaxVersion != "" {
version, err := ParseTLSVersion(c.TLSMaxVersion)
if err != nil {
return nil, fmt.Errorf(
"could not parse tls max version %q: %v", c.TLSMaxVersion, err)
}
tlsConfig.MaxVersion = version
}
if c.TLSMinVersion != "" {
version, err := ParseTLSVersion(c.TLSMinVersion)
if err != nil {
return nil, fmt.Errorf(
"could not parse tls min version %q: %v", c.TLSMinVersion, err)
}
tlsConfig.MinVersion = version
}
if tlsConfig.MinVersion != 0 && tlsConfig.MaxVersion != 0 && tlsConfig.MinVersion > tlsConfig.MaxVersion {
return nil, fmt.Errorf(
"tls min version %q can't be greater than tls max version %q", tlsConfig.MinVersion, tlsConfig.MaxVersion)
}
// Since clientAuth is tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
// there must be certs to validate.
if len(c.TLSAllowedCACerts) > 0 && len(c.TLSAllowedDNSNames) > 0 {
tlsConfig.VerifyPeerCertificate = c.verifyPeerCertificate
}
return tlsConfig, nil
}
func makeCertPool(certFiles []string) (*x509.CertPool, error) {
pool := x509.NewCertPool()
for _, certFile := range certFiles {
pem, err := os.ReadFile(certFile)
if err != nil {
return nil, fmt.Errorf(
"could not read certificate %q: %v", certFile, err)
}
if !pool.AppendCertsFromPEM(pem) {
return nil, fmt.Errorf(
"could not parse any PEM certificates %q: %v", certFile, err)
}
}
return pool, nil
}
func loadCertificate(config *tls.Config, certFile, keyFile string) error {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf(
"could not load keypair %s:%s: %v", certFile, keyFile, err)
}
config.Certificates = []tls.Certificate{cert}
config.BuildNameToCertificate()
return nil
}
func (c *ServerConfig) verifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
// The certificate chain is client + intermediate + root.
// Let's review the client certificate.
cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
return fmt.Errorf("could not validate peer certificate: %v", err)
}
for _, name := range cert.DNSNames {
if slice.ContainsString(c.TLSAllowedDNSNames, name) {
return nil
}
}
return fmt.Errorf("peer certificate not in allowed DNS Name list: %v", cert.DNSNames)
}

30
src/pkg/tls/utils.go Normal file
View File

@ -0,0 +1,30 @@
package tls
import (
"fmt"
)
// ParseCiphers returns a `[]uint16` by received `[]string` key that represents ciphers from crypto/tls.
// If some of ciphers in received list doesn't exists ParseCiphers returns nil with error
func ParseCiphers(ciphers []string) ([]uint16, error) {
suites := []uint16{}
for _, cipher := range ciphers {
v, ok := tlsCipherMap[cipher]
if !ok {
return nil, fmt.Errorf("unsupported cipher %q", cipher)
}
suites = append(suites, v)
}
return suites, nil
}
// ParseTLSVersion returns a `uint16` by received version string key that represents tls version from crypto/tls.
// If version isn't supported ParseTLSVersion returns 0 with error
func ParseTLSVersion(version string) (uint16, error) {
if v, ok := tlsVersionMap[version]; ok {
return v, nil
}
return 0, fmt.Errorf("unsupported version %q", version)
}

View File

@ -11,12 +11,16 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"github.com/didi/nightingale/v5/src/pkg/ormx" "github.com/didi/nightingale/v5/src/pkg/ormx"
"github.com/didi/nightingale/v5/src/pkg/tls"
) )
type RedisConfig struct { type RedisConfig struct {
Address string Address string
Username string
Password string Password string
DB int DB int
UseTLS bool
tls.ClientConfig
} }
type DBConfig struct { type DBConfig struct {
@ -101,15 +105,27 @@ func newGormDB(cfg DBConfig) (*gorm.DB, error) {
var Redis *redis.Client var Redis *redis.Client
func InitRedis(cfg RedisConfig) (func(), error) { func InitRedis(cfg RedisConfig) (func(), error) {
Redis = redis.NewClient(&redis.Options{ redisOptions := &redis.Options{
Addr: cfg.Address, Addr: cfg.Address,
Username: cfg.Username,
Password: cfg.Password, Password: cfg.Password,
DB: cfg.DB, DB: cfg.DB,
}) }
if cfg.UseTLS {
tlsConfig, err := cfg.TLSConfig()
if err != nil {
fmt.Println("failed to init redis tls config:", err)
os.Exit(1)
}
redisOptions.TLSConfig = tlsConfig
}
Redis = redis.NewClient(redisOptions)
err := Redis.Ping(context.Background()).Err() err := Redis.Ping(context.Background()).Err()
if err != nil { if err != nil {
fmt.Println("ping redis failed:", err) fmt.Println("failed to ping redis:", err)
os.Exit(1) os.Exit(1)
} }