certreloader.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. package tlsconfig
  2. import (
  3. "crypto/tls"
  4. "crypto/x509"
  5. "fmt"
  6. "os"
  7. "runtime"
  8. "sync"
  9. "github.com/getsentry/sentry-go"
  10. "github.com/pkg/errors"
  11. "github.com/rs/zerolog"
  12. "github.com/urfave/cli/v2"
  13. )
  // CLI flag names read by this package.
  const (
  	// OriginCAPoolFlag is the flag naming a PEM file of additional CAs for the
  	// origin CA pool; LoadOriginCA refers to it as --origin-ca-pool.
  	OriginCAPoolFlag = "origin-ca-pool"
  	// CaCertFlag is the flag whose value CreateTunnelConfig passes to GetConfig
  	// as a root CA.
  	CaCertFlag = "cacert"
  )
  // CertReloader holds a TLS key pair loaded from disk and can reload it at
  // runtime. Its Cert and ClientCert methods have the signatures of
  // tls.Config.GetCertificate and tls.Config.GetClientCertificate, so a TLS
  // endpoint using them picks up a renewed certificate without restarting.
  //
  // The embedded mutex guards certificate; a CertReloader must not be copied
  // after first use. Use NewCertReloader to construct one.
  type CertReloader struct {
  	sync.Mutex
  	// certificate is the most recently loaded key pair; nil until LoadCert succeeds.
  	certificate *tls.Certificate
  	// certPath and keyPath are the files passed to tls.LoadX509KeyPair.
  	certPath, keyPath string
  }
  26. // NewCertReloader makes a CertReloader. It loads the cert during initialization to make sure certPath and keyPath are valid
  27. func NewCertReloader(certPath, keyPath string) (*CertReloader, error) {
  28. cr := new(CertReloader)
  29. cr.certPath = certPath
  30. cr.keyPath = keyPath
  31. if err := cr.LoadCert(); err != nil {
  32. return nil, err
  33. }
  34. return cr, nil
  35. }
  36. // Cert returns the TLS certificate most recently read by the CertReloader.
  37. // This method works as a direct utility method for tls.Config#Cert.
  38. func (cr *CertReloader) Cert(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
  39. cr.Lock()
  40. defer cr.Unlock()
  41. return cr.certificate, nil
  42. }
  43. // ClientCert returns the TLS certificate most recently read by the CertReloader.
  44. // This method works as a direct utility method for tls.Config#ClientCert.
  45. func (cr *CertReloader) ClientCert(certRequestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
  46. cr.Lock()
  47. defer cr.Unlock()
  48. return cr.certificate, nil
  49. }
  50. // LoadCert loads a TLS certificate from the CertReloader's specified filepath.
  51. // Call this after writing a new certificate to the disk (e.g. after renewing a certificate)
  52. func (cr *CertReloader) LoadCert() error {
  53. cr.Lock()
  54. defer cr.Unlock()
  55. cert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath)
  56. // Keep the old certificate if there's a problem reading the new one.
  57. if err != nil {
  58. sentry.CaptureException(fmt.Errorf("Error parsing X509 key pair: %v", err))
  59. return err
  60. }
  61. cr.certificate = &cert
  62. return nil
  63. }
  64. func LoadOriginCA(originCAPoolFilename string, log *zerolog.Logger) (*x509.CertPool, error) {
  65. var originCustomCAPool []byte
  66. if originCAPoolFilename != "" {
  67. var err error
  68. originCustomCAPool, err = os.ReadFile(originCAPoolFilename)
  69. if err != nil {
  70. return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s for --%s", originCAPoolFilename, OriginCAPoolFlag))
  71. }
  72. }
  73. originCertPool, err := loadOriginCertPool(originCustomCAPool, log)
  74. if err != nil {
  75. return nil, errors.Wrap(err, "error loading the certificate pool")
  76. }
  77. // Windows users should be notified that they can use the flag
  78. if runtime.GOOS == "windows" && originCAPoolFilename == "" {
  79. log.Info().Msgf("cloudflared does not support loading the system root certificate pool on Windows. Please use --%s <PATH> to specify the path to the certificate pool", OriginCAPoolFlag)
  80. }
  81. return originCertPool, nil
  82. }
  83. func LoadCustomOriginCA(originCAFilename string) (*x509.CertPool, error) {
  84. // First, obtain the system certificate pool
  85. certPool, err := x509.SystemCertPool()
  86. if err != nil {
  87. certPool = x509.NewCertPool()
  88. }
  89. // Next, append the Cloudflare CAs into the system pool
  90. cfRootCA, err := GetCloudflareRootCA()
  91. if err != nil {
  92. return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
  93. }
  94. for _, cert := range cfRootCA {
  95. certPool.AddCert(cert)
  96. }
  97. if originCAFilename == "" {
  98. return certPool, nil
  99. }
  100. customOriginCA, err := os.ReadFile(originCAFilename)
  101. if err != nil {
  102. return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s", originCAFilename))
  103. }
  104. if !certPool.AppendCertsFromPEM(customOriginCA) {
  105. return nil, fmt.Errorf("error appending custom CA to cert pool")
  106. }
  107. return certPool, nil
  108. }
  109. func CreateTunnelConfig(c *cli.Context, serverName string) (*tls.Config, error) {
  110. var rootCAs []string
  111. if c.String(CaCertFlag) != "" {
  112. rootCAs = append(rootCAs, c.String(CaCertFlag))
  113. }
  114. userConfig := &TLSParameters{RootCAs: rootCAs, ServerName: serverName}
  115. tlsConfig, err := GetConfig(userConfig)
  116. if err != nil {
  117. return nil, err
  118. }
  119. if tlsConfig.RootCAs == nil {
  120. rootCAPool, err := x509.SystemCertPool()
  121. if err != nil {
  122. return nil, errors.Wrap(err, "unable to get x509 system cert pool")
  123. }
  124. cfRootCA, err := GetCloudflareRootCA()
  125. if err != nil {
  126. return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
  127. }
  128. for _, cert := range cfRootCA {
  129. rootCAPool.AddCert(cert)
  130. }
  131. tlsConfig.RootCAs = rootCAPool
  132. }
  133. if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify {
  134. return nil, fmt.Errorf("either ServerName or InsecureSkipVerify must be specified in the tls.Config")
  135. }
  136. return tlsConfig, nil
  137. }
  138. func loadOriginCertPool(originCAPoolPEM []byte, log *zerolog.Logger) (*x509.CertPool, error) {
  139. // Get the global pool
  140. certPool, err := loadGlobalCertPool(log)
  141. if err != nil {
  142. return nil, err
  143. }
  144. // Then, add any custom origin CA pool the user may have passed
  145. if originCAPoolPEM != nil {
  146. if !certPool.AppendCertsFromPEM(originCAPoolPEM) {
  147. log.Info().Msg("could not append the provided origin CA to the cloudflared certificate pool")
  148. }
  149. }
  150. return certPool, nil
  151. }
  152. func loadGlobalCertPool(log *zerolog.Logger) (*x509.CertPool, error) {
  153. // First, obtain the system certificate pool
  154. certPool, err := x509.SystemCertPool()
  155. if err != nil {
  156. if runtime.GOOS != "windows" { // See https://github.com/golang/go/issues/16736
  157. log.Err(err).Msg("error obtaining the system certificates")
  158. }
  159. certPool = x509.NewCertPool()
  160. }
  161. // Next, append the Cloudflare CAs into the system pool
  162. cfRootCA, err := GetCloudflareRootCA()
  163. if err != nil {
  164. return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
  165. }
  166. for _, cert := range cfRootCA {
  167. certPool.AddCert(cert)
  168. }
  169. // Finally, add the Hello certificate into the pool (since it's self-signed)
  170. helloCert, err := GetHelloCertificateX509()
  171. if err != nil {
  172. return nil, errors.Wrap(err, "could not append Hello server certificate to cloudflared certificate pool")
  173. }
  174. certPool.AddCert(helloCert)
  175. return certPool, nil
  176. }