tlsconfig.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. // Package tlsconfig provides convenience functions for configuring TLS connections from the
  2. // command line.
  3. package tlsconfig
  4. import (
  5. "crypto/tls"
  6. "crypto/x509"
  7. "os"
  8. "github.com/pkg/errors"
  9. )
  10. // Config is the user provided parameters to create a tls.Config
  11. type TLSParameters struct {
  12. Cert string
  13. Key string
  14. GetCertificate *CertReloader
  15. GetClientCertificate *CertReloader
  16. ClientCAs []string
  17. RootCAs []string
  18. ServerName string
  19. CurvePreferences []tls.CurveID
  20. MinVersion uint16 // min tls version. If zero, TLS1.0 is defined as minimum.
  21. MaxVersion uint16 // max tls version. If zero, last TLS version is used defined as limit (currently TLS1.3)
  22. }
  23. // GetConfig returns a TLS configuration according to the Config set by the user.
  24. func GetConfig(p *TLSParameters) (*tls.Config, error) {
  25. tlsconfig := &tls.Config{}
  26. if p.Cert != "" && p.Key != "" {
  27. cert, err := tls.LoadX509KeyPair(p.Cert, p.Key)
  28. if err != nil {
  29. return nil, errors.Wrap(err, "Error parsing X509 key pair")
  30. }
  31. tlsconfig.Certificates = []tls.Certificate{cert}
  32. // BuildNameToCertificate parses Certificates and builds NameToCertificate from common name
  33. // and SAN fields of leaf certificates
  34. tlsconfig.BuildNameToCertificate()
  35. }
  36. if p.GetCertificate != nil {
  37. // GetCertificate is called when client supplies SNI info or Certificates is empty.
  38. // Order of retrieving certificate is GetCertificate, NameToCertificate and lastly first element of Certificates
  39. tlsconfig.GetCertificate = p.GetCertificate.Cert
  40. }
  41. if p.GetClientCertificate != nil {
  42. // GetClientCertificate is called when using an HTTP client library and mTLS is required.
  43. tlsconfig.GetClientCertificate = p.GetClientCertificate.ClientCert
  44. }
  45. if len(p.ClientCAs) > 0 {
  46. // set of root certificate authorities that servers use if required to verify a client certificate
  47. // by the policy in ClientAuth
  48. clientCAs, err := LoadCert(p.ClientCAs)
  49. if err != nil {
  50. return nil, errors.Wrap(err, "Error loading client CAs")
  51. }
  52. tlsconfig.ClientCAs = clientCAs
  53. // server's policy for TLS Client Authentication. Default is no client cert
  54. tlsconfig.ClientAuth = tls.RequireAndVerifyClientCert
  55. }
  56. if len(p.RootCAs) > 0 {
  57. rootCAs, err := LoadCert(p.RootCAs)
  58. if err != nil {
  59. return nil, errors.Wrap(err, "Error loading root CAs")
  60. }
  61. tlsconfig.RootCAs = rootCAs
  62. }
  63. if p.ServerName != "" {
  64. tlsconfig.ServerName = p.ServerName
  65. }
  66. if len(p.CurvePreferences) > 0 {
  67. tlsconfig.CurvePreferences = p.CurvePreferences
  68. } else {
  69. // Cloudflare optimize CurveP256
  70. tlsconfig.CurvePreferences = []tls.CurveID{tls.CurveP256}
  71. }
  72. tlsconfig.MinVersion = p.MinVersion
  73. tlsconfig.MaxVersion = p.MaxVersion
  74. return tlsconfig, nil
  75. }
  76. // LoadCert creates a CertPool containing all certificates in a PEM-format file.
  77. func LoadCert(certPaths []string) (*x509.CertPool, error) {
  78. ca := x509.NewCertPool()
  79. for _, certPath := range certPaths {
  80. caCert, err := os.ReadFile(certPath)
  81. if err != nil {
  82. return nil, errors.Wrapf(err, "Error reading certificate %s", certPath)
  83. }
  84. if !ca.AppendCertsFromPEM(caCert) {
  85. return nil, errors.Wrapf(err, "Error parsing certificate %s", certPath)
  86. }
  87. }
  88. return ca, nil
  89. }