ocspstapling.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
// Package ocspstapling embeds ("staples") Signed Certificate Timestamps
// (SCTs) into the OCSP responses stored in a certificate database, as
// described in section 3.3 of RFC 6962.
  3. package ocspstapling
  4. import (
  5. "crypto"
  6. "crypto/x509"
  7. "crypto/x509/pkix"
  8. "encoding/asn1"
  9. "encoding/base64"
  10. "errors"
  11. "github.com/cloudflare/cfssl/certdb"
  12. cferr "github.com/cloudflare/cfssl/errors"
  13. "github.com/cloudflare/cfssl/helpers"
  14. ct "github.com/google/certificate-transparency-go"
  15. "golang.org/x/crypto/ocsp"
  16. )
  17. // sctExtOid is the OID of the OCSP Stapling SCT extension (see section 3.3. of RFC 6962).
  18. var sctExtOid = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 5}
  19. // StapleSCTList inserts a list of Signed Certificate Timestamps into all OCSP
  20. // responses in a database wrapped by a given certdb.Accessor.
  21. //
  22. // NOTE: This function is patterned after the exported Sign method in
  23. // https://github.com/cloudflare/cfssl/blob/master/signer/local/local.go
  24. func StapleSCTList(acc certdb.Accessor, serial, aki string, scts []ct.SignedCertificateTimestamp,
  25. responderCert, issuer *x509.Certificate, priv crypto.Signer) error {
  26. ocspRecs, err := acc.GetOCSP(serial, aki)
  27. if err != nil {
  28. return err
  29. }
  30. if len(ocspRecs) == 0 {
  31. return cferr.Wrap(cferr.CertStoreError, cferr.RecordNotFound, errors.New("empty OCSPRecord"))
  32. }
  33. // This loop adds the SCTs to each OCSP response in ocspRecs.
  34. for _, rec := range ocspRecs {
  35. der, err := base64.StdEncoding.DecodeString(rec.Body)
  36. if err != nil {
  37. return cferr.Wrap(cferr.CertificateError, cferr.DecodeFailed,
  38. errors.New("failed to decode Base64-encoded OCSP response"))
  39. }
  40. response, err := ocsp.ParseResponse(der, nil)
  41. if err != nil {
  42. return cferr.Wrap(cferr.CertificateError, cferr.ParseFailed,
  43. errors.New("failed to parse DER-encoded OCSP response"))
  44. }
  45. serializedSCTList, err := helpers.SerializeSCTList(scts)
  46. if err != nil {
  47. return cferr.Wrap(cferr.CTError, cferr.Unknown,
  48. errors.New("failed to serialize SCT list"))
  49. }
  50. serializedSCTList, err = asn1.Marshal(serializedSCTList)
  51. if err != nil {
  52. return cferr.Wrap(cferr.CTError, cferr.Unknown,
  53. errors.New("failed to serialize SCT list"))
  54. }
  55. sctExtension := pkix.Extension{
  56. Id: sctExtOid,
  57. Critical: false,
  58. Value: serializedSCTList,
  59. }
  60. // This loop finds the SCTListExtension in the ocsp response.
  61. var idxExt int
  62. for _, ext := range response.Extensions {
  63. if ext.Id.Equal(sctExtOid) {
  64. break
  65. }
  66. idxExt++
  67. }
  68. newExtensions := make([]pkix.Extension, len(response.Extensions))
  69. copy(newExtensions, response.Extensions)
  70. if idxExt >= len(response.Extensions) {
  71. // No SCT extension was found.
  72. newExtensions = append(newExtensions, sctExtension)
  73. } else {
  74. newExtensions[idxExt] = sctExtension
  75. }
  76. // Here we write the updated extensions to replace the old
  77. // response extensions when re-marshalling.
  78. newSN := *response.SerialNumber
  79. template := ocsp.Response{
  80. Status: response.Status,
  81. SerialNumber: &newSN,
  82. ThisUpdate: response.ThisUpdate,
  83. NextUpdate: response.NextUpdate,
  84. Certificate: response.Certificate,
  85. ExtraExtensions: newExtensions,
  86. IssuerHash: response.IssuerHash,
  87. }
  88. // Finally, we re-sign the response to generate the new
  89. // DER-encoded response.
  90. der, err = ocsp.CreateResponse(issuer, responderCert, template, priv)
  91. if err != nil {
  92. return cferr.Wrap(cferr.CTError, cferr.Unknown,
  93. errors.New("failed to sign new OCSP response"))
  94. }
  95. body := base64.StdEncoding.EncodeToString(der)
  96. err = acc.UpdateOCSP(serial, aki, body, rec.Expiry)
  97. if err != nil {
  98. return err
  99. }
  100. }
  101. return nil
  102. }