support redis tls
This commit is contained in:
parent
e00f102703
commit
e0f0e08852
|
@ -115,10 +115,11 @@ Timeout = 3000
|
||||||
[Redis]
|
[Redis]
|
||||||
# address, ip:port
|
# address, ip:port
|
||||||
Address = "127.0.0.1:6379"
|
Address = "127.0.0.1:6379"
|
||||||
# requirepass
|
# Username = ""
|
||||||
Password = ""
|
# Password = ""
|
||||||
# # db
|
|
||||||
# DB = 0
|
# DB = 0
|
||||||
|
# UseTLS = false
|
||||||
|
# MinVersion = "1.2"
|
||||||
|
|
||||||
[Gorm]
|
[Gorm]
|
||||||
# enable debug mode or not
|
# enable debug mode or not
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
package tls
|
||||||
|
|
||||||
|
import "crypto/tls"
|
||||||
|
|
||||||
|
var tlsVersionMap = map[string]uint16{
|
||||||
|
"TLS10": tls.VersionTLS10,
|
||||||
|
"TLS11": tls.VersionTLS11,
|
||||||
|
"TLS12": tls.VersionTLS12,
|
||||||
|
"TLS13": tls.VersionTLS13,
|
||||||
|
}
|
||||||
|
|
||||||
|
var tlsCipherMap = map[string]uint16{
|
||||||
|
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||||
|
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||||
|
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||||
|
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||||
|
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
|
||||||
|
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||||
|
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
|
||||||
|
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||||
|
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||||
|
"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||||
|
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||||
|
"TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256,
|
||||||
|
"TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||||
|
"TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||||
|
"TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||||
|
"TLS_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||||
|
"TLS_RSA_WITH_RC4_128_SHA": tls.TLS_RSA_WITH_RC4_128_SHA,
|
||||||
|
"TLS_ECDHE_RSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
|
||||||
|
"TLS_ECDHE_ECDSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
|
||||||
|
"TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256,
|
||||||
|
"TLS_AES_256_GCM_SHA384": tls.TLS_AES_256_GCM_SHA384,
|
||||||
|
"TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256,
|
||||||
|
}
|
|
@ -0,0 +1,196 @@
|
||||||
|
package tls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/toolkits/pkg/slice"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClientConfig represents the standard client TLS config.
|
||||||
|
type ClientConfig struct {
|
||||||
|
TLSCA string
|
||||||
|
TLSCert string
|
||||||
|
TLSKey string
|
||||||
|
TLSKeyPwd string
|
||||||
|
InsecureSkipVerify bool
|
||||||
|
ServerName string
|
||||||
|
MinVersion string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerConfig represents the standard server TLS config.
|
||||||
|
type ServerConfig struct {
|
||||||
|
TLSCert string
|
||||||
|
TLSKey string
|
||||||
|
TLSKeyPwd string
|
||||||
|
TLSAllowedCACerts []string
|
||||||
|
TLSCipherSuites []string
|
||||||
|
TLSMinVersion string
|
||||||
|
TLSMaxVersion string
|
||||||
|
TLSAllowedDNSNames []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSConfig returns a tls.Config, may be nil without error if TLS is not
|
||||||
|
// configured.
|
||||||
|
func (c *ClientConfig) TLSConfig() (*tls.Config, error) {
|
||||||
|
// This check returns a nil (aka, "use the default")
|
||||||
|
// tls.Config if no field is set that would have an effect on
|
||||||
|
// a TLS connection. That is, any of:
|
||||||
|
// * client certificate settings,
|
||||||
|
// * peer certificate authorities,
|
||||||
|
// * disabled security, or
|
||||||
|
// * an SNI server name.
|
||||||
|
if c.TLSCA == "" && c.TLSKey == "" && c.TLSCert == "" && !c.InsecureSkipVerify && c.ServerName == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
InsecureSkipVerify: c.InsecureSkipVerify,
|
||||||
|
Renegotiation: tls.RenegotiateNever,
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.TLSCA != "" {
|
||||||
|
pool, err := makeCertPool([]string{c.TLSCA})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tlsConfig.RootCAs = pool
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.TLSCert != "" && c.TLSKey != "" {
|
||||||
|
err := loadCertificate(tlsConfig, c.TLSCert, c.TLSKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.ServerName != "" {
|
||||||
|
tlsConfig.ServerName = c.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.MinVersion == "1.0" {
|
||||||
|
tlsConfig.MinVersion = tls.VersionTLS10
|
||||||
|
} else if c.MinVersion == "1.1" {
|
||||||
|
tlsConfig.MinVersion = tls.VersionTLS11
|
||||||
|
} else if c.MinVersion == "1.2" {
|
||||||
|
tlsConfig.MinVersion = tls.VersionTLS12
|
||||||
|
} else if c.MinVersion == "1.3" {
|
||||||
|
tlsConfig.MinVersion = tls.VersionTLS13
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSConfig returns a tls.Config, may be nil without error if TLS is not
|
||||||
|
// configured.
|
||||||
|
func (c *ServerConfig) TLSConfig() (*tls.Config, error) {
|
||||||
|
if c.TLSCert == "" && c.TLSKey == "" && len(c.TLSAllowedCACerts) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
|
||||||
|
if len(c.TLSAllowedCACerts) != 0 {
|
||||||
|
pool, err := makeCertPool(c.TLSAllowedCACerts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tlsConfig.ClientCAs = pool
|
||||||
|
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.TLSCert != "" && c.TLSKey != "" {
|
||||||
|
err := loadCertificate(tlsConfig, c.TLSCert, c.TLSKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.TLSCipherSuites) != 0 {
|
||||||
|
cipherSuites, err := ParseCiphers(c.TLSCipherSuites)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"could not parse server cipher suites %s: %v", strings.Join(c.TLSCipherSuites, ","), err)
|
||||||
|
}
|
||||||
|
tlsConfig.CipherSuites = cipherSuites
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.TLSMaxVersion != "" {
|
||||||
|
version, err := ParseTLSVersion(c.TLSMaxVersion)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"could not parse tls max version %q: %v", c.TLSMaxVersion, err)
|
||||||
|
}
|
||||||
|
tlsConfig.MaxVersion = version
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.TLSMinVersion != "" {
|
||||||
|
version, err := ParseTLSVersion(c.TLSMinVersion)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"could not parse tls min version %q: %v", c.TLSMinVersion, err)
|
||||||
|
}
|
||||||
|
tlsConfig.MinVersion = version
|
||||||
|
}
|
||||||
|
|
||||||
|
if tlsConfig.MinVersion != 0 && tlsConfig.MaxVersion != 0 && tlsConfig.MinVersion > tlsConfig.MaxVersion {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"tls min version %q can't be greater than tls max version %q", tlsConfig.MinVersion, tlsConfig.MaxVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Since clientAuth is tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
|
// there must be certs to validate.
|
||||||
|
if len(c.TLSAllowedCACerts) > 0 && len(c.TLSAllowedDNSNames) > 0 {
|
||||||
|
tlsConfig.VerifyPeerCertificate = c.verifyPeerCertificate
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeCertPool(certFiles []string) (*x509.CertPool, error) {
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
for _, certFile := range certFiles {
|
||||||
|
pem, err := os.ReadFile(certFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"could not read certificate %q: %v", certFile, err)
|
||||||
|
}
|
||||||
|
if !pool.AppendCertsFromPEM(pem) {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"could not parse any PEM certificates %q: %v", certFile, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return pool, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadCertificate(config *tls.Config, certFile, keyFile string) error {
|
||||||
|
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"could not load keypair %s:%s: %v", certFile, keyFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Certificates = []tls.Certificate{cert}
|
||||||
|
config.BuildNameToCertificate()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ServerConfig) verifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||||
|
// The certificate chain is client + intermediate + root.
|
||||||
|
// Let's review the client certificate.
|
||||||
|
cert, err := x509.ParseCertificate(rawCerts[0])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not validate peer certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range cert.DNSNames {
|
||||||
|
if slice.ContainsString(c.TLSAllowedDNSNames, name) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("peer certificate not in allowed DNS Name list: %v", cert.DNSNames)
|
||||||
|
}
|
|
@ -0,0 +1,30 @@
|
||||||
|
package tls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseCiphers returns a `[]uint16` by received `[]string` key that represents ciphers from crypto/tls.
|
||||||
|
// If some of ciphers in received list doesn't exists ParseCiphers returns nil with error
|
||||||
|
func ParseCiphers(ciphers []string) ([]uint16, error) {
|
||||||
|
suites := []uint16{}
|
||||||
|
|
||||||
|
for _, cipher := range ciphers {
|
||||||
|
v, ok := tlsCipherMap[cipher]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unsupported cipher %q", cipher)
|
||||||
|
}
|
||||||
|
suites = append(suites, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
return suites, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseTLSVersion returns a `uint16` by received version string key that represents tls version from crypto/tls.
|
||||||
|
// If version isn't supported ParseTLSVersion returns 0 with error
|
||||||
|
func ParseTLSVersion(version string) (uint16, error) {
|
||||||
|
if v, ok := tlsVersionMap[version]; ok {
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("unsupported version %q", version)
|
||||||
|
}
|
|
@ -11,12 +11,16 @@ import (
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/didi/nightingale/v5/src/pkg/ormx"
|
"github.com/didi/nightingale/v5/src/pkg/ormx"
|
||||||
|
"github.com/didi/nightingale/v5/src/pkg/tls"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RedisConfig struct {
|
type RedisConfig struct {
|
||||||
Address string
|
Address string
|
||||||
|
Username string
|
||||||
Password string
|
Password string
|
||||||
DB int
|
DB int
|
||||||
|
UseTLS bool
|
||||||
|
tls.ClientConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type DBConfig struct {
|
type DBConfig struct {
|
||||||
|
@ -101,15 +105,27 @@ func newGormDB(cfg DBConfig) (*gorm.DB, error) {
|
||||||
var Redis *redis.Client
|
var Redis *redis.Client
|
||||||
|
|
||||||
func InitRedis(cfg RedisConfig) (func(), error) {
|
func InitRedis(cfg RedisConfig) (func(), error) {
|
||||||
Redis = redis.NewClient(&redis.Options{
|
redisOptions := &redis.Options{
|
||||||
Addr: cfg.Address,
|
Addr: cfg.Address,
|
||||||
|
Username: cfg.Username,
|
||||||
Password: cfg.Password,
|
Password: cfg.Password,
|
||||||
DB: cfg.DB,
|
DB: cfg.DB,
|
||||||
})
|
}
|
||||||
|
|
||||||
|
if cfg.UseTLS {
|
||||||
|
tlsConfig, err := cfg.TLSConfig()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("failed to init redis tls config:", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
redisOptions.TLSConfig = tlsConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
Redis = redis.NewClient(redisOptions)
|
||||||
|
|
||||||
err := Redis.Ping(context.Background()).Err()
|
err := Redis.Ping(context.Background()).Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("ping redis failed:", err)
|
fmt.Println("failed to ping redis:", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue