client.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. package transport
  2. import (
  3. "crypto/tls"
  4. "net"
  5. "os"
  6. "time"
  7. "github.com/cloudflare/backoff"
  8. "github.com/cloudflare/cfssl/csr"
  9. "github.com/cloudflare/cfssl/errors"
  10. "github.com/cloudflare/cfssl/log"
  11. "github.com/cloudflare/cfssl/revoke"
  12. "github.com/cloudflare/cfssl/transport/ca"
  13. "github.com/cloudflare/cfssl/transport/core"
  14. "github.com/cloudflare/cfssl/transport/kp"
  15. "github.com/cloudflare/cfssl/transport/roots"
  16. )
  17. func envOrDefault(key, def string) string {
  18. val := os.Getenv(key)
  19. if val == "" {
  20. return def
  21. }
  22. return val
  23. }
  24. var (
  25. // NewKeyProvider is the function used to build key providers
  26. // from some identity.
  27. NewKeyProvider = func(id *core.Identity) (kp.KeyProvider, error) {
  28. return kp.NewStandardProvider(id)
  29. }
  30. // NewCA is used to load a configuration for a certificate
  31. // authority.
  32. NewCA = func(id *core.Identity) (ca.CertificateAuthority, error) {
  33. return ca.NewCFSSLProvider(id, nil)
  34. }
  35. )
  36. // A Transport is capable of providing transport-layer security using
  37. // TLS.
  38. type Transport struct {
  39. // Before defines how long before the certificate expires the
  40. // transport should start attempting to refresh the
  41. // certificate. For example, if this is 24h, then 24 hours
  42. // before the certificate expires the Transport will start
  43. // attempting to replace it.
  44. Before time.Duration
  45. // Provider contains a key management provider.
  46. Provider kp.KeyProvider
  47. // CA contains a mechanism for obtaining signed certificates.
  48. CA ca.CertificateAuthority
  49. // TrustStore contains the certificates trusted by this
  50. // transport.
  51. TrustStore *roots.TrustStore
  52. // ClientTrustStore contains the certificate authorities to
  53. // use in verifying client authentication certificates.
  54. ClientTrustStore *roots.TrustStore
  55. // Identity contains information about the entity that will be
  56. // used to construct certificates.
  57. Identity *core.Identity
  58. // Backoff is used to control the behaviour of a Transport
  59. // when it is attempting to automatically update a certificate
  60. // as part of AutoUpdate.
  61. Backoff *backoff.Backoff
  62. // RevokeSoftFail, if true, will cause a failure to check
  63. // revocation (such that the revocation status of a
  64. // certificate cannot be checked) to not be treated as an
  65. // error.
  66. RevokeSoftFail bool
  67. }
  68. // TLSClientAuthClientConfig returns a new client authentication TLS
  69. // configuration that can be used for a client using client auth
  70. // connecting to the named host.
  71. func (tr *Transport) TLSClientAuthClientConfig(host string) (*tls.Config, error) {
  72. cert, err := tr.getCertificate()
  73. if err != nil {
  74. return nil, err
  75. }
  76. return &tls.Config{
  77. Certificates: []tls.Certificate{cert},
  78. RootCAs: tr.TrustStore.Pool(),
  79. ServerName: host,
  80. CipherSuites: core.CipherSuites,
  81. MinVersion: tls.VersionTLS12,
  82. ClientAuth: tls.RequireAndVerifyClientCert,
  83. }, nil
  84. }
  85. // TLSClientAuthServerConfig returns a new client authentication TLS
  86. // configuration for servers expecting mutually authenticated
  87. // clients. The clientAuth parameter should contain the root pool used
  88. // to authenticate clients.
  89. func (tr *Transport) TLSClientAuthServerConfig() (*tls.Config, error) {
  90. cert, err := tr.getCertificate()
  91. if err != nil {
  92. return nil, err
  93. }
  94. return &tls.Config{
  95. Certificates: []tls.Certificate{cert},
  96. RootCAs: tr.TrustStore.Pool(),
  97. ClientCAs: tr.ClientTrustStore.Pool(),
  98. ClientAuth: tls.RequireAndVerifyClientCert,
  99. CipherSuites: core.CipherSuites,
  100. MinVersion: tls.VersionTLS12,
  101. }, nil
  102. }
  103. // TLSServerConfig is a general server configuration that should be
  104. // used for non-client authentication purposes, such as HTTPS.
  105. func (tr *Transport) TLSServerConfig() (*tls.Config, error) {
  106. cert, err := tr.getCertificate()
  107. if err != nil {
  108. return nil, err
  109. }
  110. return &tls.Config{
  111. Certificates: []tls.Certificate{cert},
  112. CipherSuites: core.CipherSuites,
  113. MinVersion: tls.VersionTLS12,
  114. }, nil
  115. }
  116. // New builds a new transport from an identity and a before time. The
  117. // before time tells the transport how long before the certificate
  118. // expires to start attempting to update when auto-updating. If before
  119. // is longer than the certificate's lifetime, every update check will
  120. // trigger a new certificate to be generated.
  121. func New(before time.Duration, identity *core.Identity) (*Transport, error) {
  122. var tr = &Transport{
  123. Before: before,
  124. Identity: identity,
  125. Backoff: &backoff.Backoff{},
  126. }
  127. store, err := roots.New(identity.Roots)
  128. if err != nil {
  129. return nil, err
  130. }
  131. tr.TrustStore = store
  132. if len(identity.ClientRoots) > 0 {
  133. store, err = roots.New(identity.ClientRoots)
  134. if err != nil {
  135. return nil, err
  136. }
  137. tr.ClientTrustStore = store
  138. }
  139. tr.Provider, err = NewKeyProvider(identity)
  140. if err != nil {
  141. return nil, err
  142. }
  143. tr.CA, err = NewCA(identity)
  144. if err != nil {
  145. return nil, err
  146. }
  147. return tr, nil
  148. }
  149. // Lifespan returns how much time is left before the transport's
  150. // certificate expires, or 0 if the certificate is not present or
  151. // expired.
  152. func (tr *Transport) Lifespan() time.Duration {
  153. cert := tr.Provider.Certificate()
  154. if cert == nil {
  155. return 0
  156. }
  157. now := time.Now()
  158. if now.After(cert.NotAfter) {
  159. return 0
  160. }
  161. now = now.Add(tr.Before)
  162. ls := cert.NotAfter.Sub(now)
  163. log.Debugf(" LIFESPAN:\t%s", ls)
  164. if ls < 0 {
  165. return 0
  166. }
  167. return ls
  168. }
  169. // RefreshKeys will make sure the Transport has loaded keys and has a
  170. // valid certificate. It will handle any persistence, check that the
  171. // certificate is valid (i.e. that its expiry date is within the
  172. // Before date), and handle certificate reissuance as needed.
  173. func (tr *Transport) RefreshKeys() (err error) {
  174. if !tr.Provider.Ready() {
  175. log.Debug("key and certificate aren't ready, loading")
  176. err = tr.Provider.Load()
  177. if err != nil && err != kp.ErrCertificateUnavailable {
  178. log.Debugf("failed to load keypair: %v", err)
  179. kr := tr.Identity.Request.KeyRequest
  180. if kr == nil {
  181. kr = csr.NewBasicKeyRequest()
  182. }
  183. err = tr.Provider.Generate(kr.Algo(), kr.Size())
  184. if err != nil {
  185. log.Debugf("failed to generate key: %v", err)
  186. return err
  187. }
  188. }
  189. }
  190. lifespan := tr.Lifespan()
  191. if lifespan < tr.Before {
  192. log.Debugf("transport's certificate is out of date (lifespan %s)", lifespan)
  193. req, err := tr.Provider.CertificateRequest(tr.Identity.Request)
  194. if err != nil {
  195. log.Debugf("couldn't get a CSR: %v", err)
  196. if tr.Provider.SignalFailure(err) {
  197. return tr.RefreshKeys()
  198. }
  199. return err
  200. }
  201. log.Debug("requesting certificate from CA")
  202. cert, err := tr.CA.SignCSR(req)
  203. if err != nil {
  204. if tr.Provider.SignalFailure(err) {
  205. return tr.RefreshKeys()
  206. }
  207. log.Debugf("failed to get the certificate signed: %v", err)
  208. return err
  209. }
  210. log.Debug("giving the certificate to the provider")
  211. err = tr.Provider.SetCertificatePEM(cert)
  212. if err != nil {
  213. log.Debugf("failed to set the provider's certificate: %v", err)
  214. if tr.Provider.SignalFailure(err) {
  215. return tr.RefreshKeys()
  216. }
  217. return err
  218. }
  219. if tr.Provider.Persistent() {
  220. log.Debug("storing the certificate")
  221. err = tr.Provider.Store()
  222. if err != nil {
  223. log.Debugf("the provider failed to store the certificate: %v", err)
  224. if tr.Provider.SignalFailure(err) {
  225. return tr.RefreshKeys()
  226. }
  227. return err
  228. }
  229. }
  230. }
  231. return nil
  232. }
  233. func (tr *Transport) getCertificate() (cert tls.Certificate, err error) {
  234. if !tr.Provider.Ready() {
  235. log.Debug("transport isn't ready; attempting to refresh keypair")
  236. err = tr.RefreshKeys()
  237. if err != nil {
  238. log.Debugf("transport couldn't get a certificate: %v", err)
  239. return
  240. }
  241. }
  242. cert, err = tr.Provider.X509KeyPair()
  243. if err != nil {
  244. log.Debugf("couldn't generate an X.509 keypair: %v", err)
  245. }
  246. return
  247. }
  248. // Dial initiates a TLS connection to an outbound server. It returns a
  249. // TLS connection to the server.
  250. func Dial(address string, tr *Transport) (*tls.Conn, error) {
  251. host, _, err := net.SplitHostPort(address)
  252. if err != nil {
  253. // Assume address is a hostname, and that it should
  254. // use the HTTPS port number.
  255. host = address
  256. address = net.JoinHostPort(address, "443")
  257. }
  258. cfg, err := tr.TLSClientAuthClientConfig(host)
  259. if err != nil {
  260. return nil, err
  261. }
  262. conn, err := tls.Dial("tcp", address, cfg)
  263. if err != nil {
  264. return nil, err
  265. }
  266. state := conn.ConnectionState()
  267. if len(state.VerifiedChains) == 0 {
  268. return nil, errors.New(errors.CertificateError, errors.VerifyFailed)
  269. }
  270. for _, chain := range state.VerifiedChains {
  271. for _, cert := range chain {
  272. revoked, ok := revoke.VerifyCertificate(cert)
  273. if (!tr.RevokeSoftFail && !ok) || revoked {
  274. return nil, errors.New(errors.CertificateError, errors.VerifyFailed)
  275. }
  276. }
  277. }
  278. return conn, nil
  279. }
  280. // AutoUpdate will automatically update the listener. If a non-nil
  281. // certUpdates chan is provided, it will receive timestamps for
  282. // reissued certificates. If errChan is non-nil, any errors that occur
  283. // in the updater will be passed along.
  284. func (tr *Transport) AutoUpdate(certUpdates chan<- time.Time, errChan chan<- error) {
  285. defer func() {
  286. if r := recover(); r != nil {
  287. log.Criticalf("AutoUpdate panicked: %v", r)
  288. }
  289. }()
  290. for {
  291. // Wait until it's time to update the certificate.
  292. target := time.Now().Add(tr.Lifespan())
  293. if PollInterval == 0 {
  294. <-time.After(tr.Lifespan())
  295. } else {
  296. pollWait(target)
  297. }
  298. // Keep trying to update the certificate until it's
  299. // ready.
  300. for {
  301. log.Debugf("attempting to refresh keypair")
  302. err := tr.RefreshKeys()
  303. if err == nil {
  304. break
  305. }
  306. delay := tr.Backoff.Duration()
  307. log.Debugf("failed to update certificate, will try again in %s", delay)
  308. if errChan != nil {
  309. errChan <- err
  310. }
  311. <-time.After(delay)
  312. }
  313. log.Debugf("certificate updated")
  314. if certUpdates != nil {
  315. certUpdates <- time.Now()
  316. }
  317. tr.Backoff.Reset()
  318. }
  319. }