pki.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. package scan
  2. import (
  3. "bytes"
  4. "crypto/x509"
  5. "fmt"
  6. "time"
  7. "github.com/cloudflare/cfssl/helpers"
  8. "github.com/cloudflare/cfssl/revoke"
  9. "github.com/cloudflare/cfssl/scan/crypto/tls"
  10. )
  11. // PKI contains scanners for the Public Key Infrastructure.
  12. var PKI = &Family{
  13. Description: "Scans for the Public Key Infrastructure",
  14. Scanners: map[string]*Scanner{
  15. "ChainExpiration": {
  16. "Host's chain hasn't expired and won't expire in the next 30 days",
  17. chainExpiration,
  18. },
  19. "ChainValidation": {
  20. "All certificates in host's chain are valid",
  21. chainValidation,
  22. },
  23. "MultipleCerts": {
  24. "Host serves same certificate chain across all IPs",
  25. multipleCerts,
  26. },
  27. },
  28. }
  29. // getChain is a helper function that retreives the host's certificate chain.
  30. func getChain(addr string, config *tls.Config) (chain []*x509.Certificate, err error) {
  31. var conn *tls.Conn
  32. conn, err = tls.DialWithDialer(Dialer, Network, addr, config)
  33. if err != nil {
  34. return
  35. }
  36. err = conn.Close()
  37. if err != nil {
  38. return
  39. }
  40. chain = conn.ConnectionState().PeerCertificates
  41. if len(chain) == 0 {
  42. err = fmt.Errorf("%s returned empty certificate chain", addr)
  43. }
  44. return
  45. }
  46. type expiration time.Time
  47. func (e expiration) String() string {
  48. return time.Time(e).Format("Jan 2 15:04:05 2006 MST")
  49. }
  50. func chainExpiration(addr, hostname string) (grade Grade, output Output, err error) {
  51. chain, err := getChain(addr, defaultTLSConfig(hostname))
  52. if err != nil {
  53. return
  54. }
  55. expirationTime := helpers.ExpiryTime(chain)
  56. output = expirationTime
  57. if time.Now().After(expirationTime) {
  58. return
  59. }
  60. // Warn if cert will expire in the next 30 days
  61. if time.Now().Add(time.Hour * 24 * 30).After(expirationTime) {
  62. grade = Warning
  63. return
  64. }
  65. grade = Good
  66. return
  67. }
  68. func chainValidation(addr, hostname string) (grade Grade, output Output, err error) {
  69. chain, err := getChain(addr, defaultTLSConfig(hostname))
  70. if err != nil {
  71. return
  72. }
  73. var warnings []string
  74. for i := 0; i < len(chain)-1; i++ {
  75. cert, parent := chain[i], chain[i+1]
  76. valid := helpers.ValidExpiry(cert)
  77. if !valid {
  78. warnings = append(warnings, fmt.Sprintf("Certificate for %s is valid for too long", cert.Subject.CommonName))
  79. }
  80. revoked, ok := revoke.VerifyCertificate(cert)
  81. if !ok {
  82. warnings = append(warnings, fmt.Sprintf("couldn't check if %s is revoked", cert.Subject.CommonName))
  83. }
  84. if revoked {
  85. err = fmt.Errorf("%s is revoked", cert.Subject.CommonName)
  86. return
  87. }
  88. if !parent.IsCA {
  89. err = fmt.Errorf("%s is not a CA", parent.Subject.CommonName)
  90. return
  91. }
  92. if !bytes.Equal(cert.AuthorityKeyId, parent.SubjectKeyId) {
  93. err = fmt.Errorf("%s AuthorityKeyId differs from %s SubjectKeyId", cert.Subject.CommonName, parent.Subject.CommonName)
  94. return
  95. }
  96. if err = cert.CheckSignatureFrom(parent); err != nil {
  97. return
  98. }
  99. switch cert.SignatureAlgorithm {
  100. case x509.ECDSAWithSHA1:
  101. warnings = append(warnings, fmt.Sprintf("%s is signed by ECDSAWithSHA1", cert.Subject.CommonName))
  102. case x509.SHA1WithRSA:
  103. warnings = append(warnings, fmt.Sprintf("%s is signed by RSAWithSHA1", cert.Subject.CommonName))
  104. }
  105. }
  106. if len(warnings) == 0 {
  107. grade = Good
  108. } else {
  109. grade = Warning
  110. output = warnings
  111. }
  112. return
  113. }
  114. func multipleCerts(addr, hostname string) (grade Grade, output Output, err error) {
  115. config := defaultTLSConfig(hostname)
  116. firstChain, err := getChain(addr, config)
  117. if err != nil {
  118. return
  119. }
  120. grade, _, err = multiscan(addr, func(addrport string) (g Grade, o Output, e error) {
  121. g = Good
  122. chain, e1 := getChain(addrport, config)
  123. if e1 != nil {
  124. return
  125. }
  126. if !chain[0].Equal(firstChain[0]) {
  127. e = fmt.Errorf("%s not equal to %s", chain[0].Subject.CommonName, firstChain[0].Subject.CommonName)
  128. g = Bad
  129. return
  130. }
  131. return
  132. })
  133. return
  134. }