testing_helpers_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. package testsuite
  2. import (
  3. "crypto/x509"
  4. "encoding/json"
  5. "math"
  6. "math/rand"
  7. "os"
  8. "os/exec"
  9. "reflect"
  10. "strconv"
  11. "strings"
  12. "testing"
  13. "time"
  14. "github.com/cloudflare/cfssl/csr"
  15. "github.com/cloudflare/cfssl/helpers"
  16. )
  17. const (
  18. testDataDirectory = "testdata"
  19. initCADirectory = testDataDirectory + string(os.PathSeparator) + "initCA"
  20. preMadeOutput = initCADirectory + string(os.PathSeparator) + "cfssl_output.pem"
  21. csrFile = testDataDirectory + string(os.PathSeparator) + "cert_csr.json"
  22. )
  23. var (
  24. keyRequest = csr.KeyRequest{
  25. A: "rsa",
  26. S: 2048,
  27. }
  28. CAConfig = csr.CAConfig{
  29. PathLength: 1,
  30. Expiry: "1h", // issue a CA certificate only valid for 1 hour
  31. }
  32. baseRequest = csr.CertificateRequest{
  33. CN: "example.com",
  34. Names: []csr.Name{
  35. {
  36. C: "US",
  37. ST: "California",
  38. L: "San Francisco",
  39. O: "Internet Widgets, LLC",
  40. OU: "Certificate Authority",
  41. },
  42. },
  43. Hosts: []string{"ca.example.com"},
  44. KeyRequest: &keyRequest,
  45. }
  46. CARequest = csr.CertificateRequest{
  47. CN: "example.com",
  48. Names: []csr.Name{
  49. {
  50. C: "US",
  51. ST: "California",
  52. L: "San Francisco",
  53. O: "Internet Widgets, LLC",
  54. OU: "Certificate Authority",
  55. },
  56. },
  57. Hosts: []string{"ca.example.com"},
  58. KeyRequest: &keyRequest,
  59. CA: &CAConfig,
  60. }
  61. )
  62. func TestStartCFSSLServer(t *testing.T) {
  63. // We will test on this address and port. Be sure that these are free or
  64. // the test will fail.
  65. addressToTest := "127.0.0.1"
  66. portToTest := 9775
  67. CACert, CAKey, err := CreateSelfSignedCert(CARequest)
  68. if err != nil {
  69. t.Fatal(err.Error())
  70. }
  71. // Set up a test server using our CA certificate and key.
  72. serverData := CFSSLServerData{CA: CACert, CAKey: CAKey}
  73. server, err := StartCFSSLServer(addressToTest, portToTest, serverData)
  74. if err != nil {
  75. t.Fatal(err.Error())
  76. }
  77. // Try to start up a second server at the same address and port number. We
  78. // should get an 'address in use' error.
  79. _, err = StartCFSSLServer(addressToTest, portToTest, serverData)
  80. if err == nil || !strings.Contains(err.Error(), "Error occurred on server: address") {
  81. t.Fatal("Two servers allowed on same address and port.")
  82. }
  83. // Now make a request of our server and check that no error occurred.
  84. // First we need a request to send to our server. We marshall the request
  85. // into JSON format and write it to a temporary file.
  86. jsonBytes, err := json.Marshal(baseRequest)
  87. if err != nil {
  88. t.Fatal(err.Error())
  89. }
  90. tempFile, err := createTempFile(jsonBytes)
  91. if err != nil {
  92. os.Remove(tempFile)
  93. panic(err)
  94. }
  95. // Now we make the request and check the output.
  96. remoteServerString := "-remote=" + "http://" + addressToTest + ":" + strconv.Itoa(portToTest)
  97. command := exec.Command(
  98. "cfssl", "gencert", remoteServerString, "-hostname="+baseRequest.CN, tempFile)
  99. CLIOutput, err := command.CombinedOutput()
  100. os.Remove(tempFile)
  101. if err != nil {
  102. t.Fatalf("%v: %s", err.Error(), string(CLIOutput))
  103. }
  104. err = checkCLIOutput(CLIOutput)
  105. if err != nil {
  106. t.Fatal(err.Error())
  107. }
  108. // The output should contain the certificate, request, and private key.
  109. _, err = cleanCLIOutput(CLIOutput, "cert")
  110. if err != nil {
  111. t.Fatal(err.Error())
  112. }
  113. _, err = cleanCLIOutput(CLIOutput, "csr")
  114. if err != nil {
  115. t.Fatal(err.Error())
  116. }
  117. _, err = cleanCLIOutput(CLIOutput, "key")
  118. if err != nil {
  119. t.Fatal(err.Error())
  120. }
  121. // Finally, kill the server.
  122. err = server.Kill()
  123. if err != nil {
  124. t.Fatal(err.Error())
  125. }
  126. }
  127. func TestCreateCertificateChain(t *testing.T) {
  128. // N is the number of certificates that will be chained together.
  129. N := 10
  130. // --- TEST: Create a chain of one certificate. --- //
  131. encodedChainFromCode, _, err := CreateCertificateChain([]csr.CertificateRequest{CARequest})
  132. if err != nil {
  133. t.Fatal(err.Error())
  134. }
  135. // Now compare to a pre-made certificate chain using a JSON file containing
  136. // the same request data.
  137. CLIOutputFile := preMadeOutput
  138. CLIOutput, err := os.ReadFile(CLIOutputFile)
  139. if err != nil {
  140. t.Fatal(err.Error())
  141. }
  142. encodedChainFromCLI, err := cleanCLIOutput(CLIOutput, "cert")
  143. if err != nil {
  144. t.Fatal(err.Error())
  145. }
  146. chainFromCode, err := helpers.ParseCertificatesPEM(encodedChainFromCode)
  147. if err != nil {
  148. t.Fatal(err.Error())
  149. }
  150. chainFromCLI, err := helpers.ParseCertificatesPEM(encodedChainFromCLI)
  151. if err != nil {
  152. t.Fatal(err.Error())
  153. }
  154. if !chainsEqual(chainFromCode, chainFromCLI) {
  155. unequalFieldSlices := checkFieldsOfChains(chainFromCode, chainFromCLI)
  156. for i, unequalFields := range unequalFieldSlices {
  157. if len(unequalFields) > 0 {
  158. t.Log("The certificate chains held unequal fields for chain " + strconv.Itoa(i))
  159. t.Log("The following fields were unequal:")
  160. for _, field := range unequalFields {
  161. t.Log("\t" + field)
  162. }
  163. }
  164. }
  165. t.Fatal("Certificate chains unequal.")
  166. }
  167. // --- TEST: Create a chain of N certificates. --- //
  168. // First we make a slice of N requests. We make each slightly different.
  169. cnGrabBag := []string{"example", "invalid", "test"}
  170. topLevelDomains := []string{".com", ".net", ".org"}
  171. subDomains := []string{"www.", "secure.", "ca.", ""}
  172. countryGrabBag := []string{"USA", "China", "England", "Vanuatu"}
  173. stateGrabBag := []string{"California", "Texas", "Alaska", "London"}
  174. localityGrabBag := []string{"San Francisco", "Houston", "London", "Oslo"}
  175. orgGrabBag := []string{"Internet Widgets, LLC", "CloudFlare, Inc."}
  176. orgUnitGrabBag := []string{"Certificate Authority", "Systems Engineering"}
  177. requests := make([]csr.CertificateRequest, N)
  178. requests[0] = CARequest
  179. for i := 1; i < N; i++ {
  180. requests[i] = baseRequest
  181. cn := randomElement(cnGrabBag)
  182. tld := randomElement(topLevelDomains)
  183. subDomain1 := randomElement(subDomains)
  184. subDomain2 := randomElement(subDomains)
  185. country := randomElement(countryGrabBag)
  186. state := randomElement(stateGrabBag)
  187. locality := randomElement(localityGrabBag)
  188. org := randomElement(orgGrabBag)
  189. orgUnit := randomElement(orgUnitGrabBag)
  190. requests[i].CN = cn + tld
  191. requests[i].Names = []csr.Name{
  192. {C: country,
  193. ST: state,
  194. L: locality,
  195. O: org,
  196. OU: orgUnit,
  197. },
  198. }
  199. hosts := []string{subDomain1 + requests[i].CN}
  200. if subDomain2 != subDomain1 {
  201. hosts = append(hosts, subDomain2+requests[i].CN)
  202. }
  203. requests[i].Hosts = hosts
  204. }
  205. // Now we make a certificate chain out of these requests.
  206. encodedCertChain, _, err := CreateCertificateChain(requests)
  207. if err != nil {
  208. t.Fatal(err.Error())
  209. }
  210. // To test this chain, we compare the data encoded in each certificate to
  211. // each request we used to generate the chain.
  212. chain, err := helpers.ParseCertificatesPEM(encodedCertChain)
  213. if err != nil {
  214. t.Fatal(err.Error())
  215. }
  216. if len(chain) != len(requests) {
  217. t.Log("Length of chain: " + strconv.Itoa(len(chain)))
  218. t.Log("Length of requests: " + strconv.Itoa(len(requests)))
  219. t.Fatal("Length of chain not equal to length of requests.")
  220. }
  221. mismatchOccurred := false
  222. for i := 0; i < len(chain); i++ {
  223. certEqualsRequest, unequalFields := certEqualsRequest(chain[i], requests[i])
  224. if !certEqualsRequest {
  225. mismatchOccurred = true
  226. t.Log(
  227. "Certificate " + strconv.Itoa(i) + " and request " +
  228. strconv.Itoa(i) + " unequal.",
  229. )
  230. t.Log("Unequal fields for index " + strconv.Itoa(i) + ":")
  231. for _, field := range unequalFields {
  232. t.Log("\t" + field)
  233. }
  234. }
  235. }
  236. // TODO: check that each certificate is actually signed by the previous one
  237. if mismatchOccurred {
  238. t.Fatal("Unequal certificate(s) and request(s) found.")
  239. }
  240. // --- TEST: Create a chain of certificates with invalid path lengths. --- //
  241. // Other invalid chains?
  242. }
  243. func TestCreateSelfSignedCert(t *testing.T) {
  244. // --- TEST: Create a self-signed certificate from a CSR. --- //
  245. // Generate a self-signed certificate from the request.
  246. encodedCertFromCode, _, err := CreateSelfSignedCert(CARequest)
  247. if err != nil {
  248. t.Fatal(err.Error())
  249. }
  250. // Now compare to a pre-made certificate made using a JSON file with the
  251. // same request information. This JSON file is located in testdata/initCA
  252. // and is called ca_csr.json.
  253. CLIOutputFile := preMadeOutput
  254. CLIOutput, err := os.ReadFile(CLIOutputFile)
  255. if err != nil {
  256. t.Fatal(err.Error())
  257. }
  258. encodedCertFromCLI, err := cleanCLIOutput(CLIOutput, "cert")
  259. if err != nil {
  260. t.Fatal(err.Error())
  261. }
  262. certFromCode, err := helpers.ParseSelfSignedCertificatePEM(encodedCertFromCode)
  263. if err != nil {
  264. t.Fatal(err.Error())
  265. }
  266. certFromCLI, err := helpers.ParseSelfSignedCertificatePEM(encodedCertFromCLI)
  267. if err != nil {
  268. t.Fatal(err.Error())
  269. }
  270. // Nullify any fields of the certificates which are dependent upon the time
  271. // of the certificate's creation.
  272. nullifyTimeDependency(certFromCode)
  273. nullifyTimeDependency(certFromCLI)
  274. if !reflect.DeepEqual(certFromCode, certFromCLI) {
  275. unequalFields := checkFields(
  276. *certFromCode, *certFromCLI, reflect.TypeOf(*certFromCode))
  277. t.Log("The following fields were unequal:")
  278. for _, field := range unequalFields {
  279. t.Log(field)
  280. }
  281. t.Fatal("Certificates unequal.")
  282. }
  283. }
  284. // Compare two x509 certificate chains. We only compare relevant data to
  285. // determine equality.
  286. func chainsEqual(chain1, chain2 []*x509.Certificate) bool {
  287. if len(chain1) != len(chain2) {
  288. return false
  289. }
  290. for i := 0; i < len(chain1); i++ {
  291. cert1 := nullifyTimeDependency(chain1[i])
  292. cert2 := nullifyTimeDependency(chain2[i])
  293. if !reflect.DeepEqual(cert1, cert2) {
  294. return false
  295. }
  296. }
  297. return true
  298. }
  299. // When comparing certificates created at different times for equality, we do
  300. // not want to worry about fields which are dependent on the time of creation.
  301. // Thus we nullify these fields before comparing the certificates.
  302. func nullifyTimeDependency(cert *x509.Certificate) *x509.Certificate {
  303. cert.Raw = nil
  304. cert.RawTBSCertificate = nil
  305. cert.RawSubject = nil
  306. cert.RawIssuer = nil
  307. cert.RawSubjectPublicKeyInfo = nil
  308. cert.Signature = nil
  309. cert.PublicKey = nil
  310. cert.SerialNumber = nil
  311. cert.NotBefore = time.Time{}
  312. cert.NotAfter = time.Time{}
  313. cert.Extensions = nil
  314. cert.SubjectKeyId = nil
  315. cert.AuthorityKeyId = nil
  316. cert.Subject.Names = nil
  317. cert.Subject.ExtraNames = nil
  318. cert.Issuer.Names = nil
  319. cert.Issuer.ExtraNames = nil
  320. return cert
  321. }
  322. // Compares two structs and returns a list containing the names of all fields
  323. // for which the two structs hold different values.
  324. func checkFields(struct1, struct2 interface{}, typeOfStructs reflect.Type) []string {
  325. v1 := reflect.ValueOf(struct1)
  326. v2 := reflect.ValueOf(struct2)
  327. var unequalFields []string
  328. for i := 0; i < v1.NumField(); i++ {
  329. if !reflect.DeepEqual(v1.Field(i).Interface(), v2.Field(i).Interface()) {
  330. unequalFields = append(unequalFields, typeOfStructs.Field(i).Name)
  331. }
  332. }
  333. return unequalFields
  334. }
  335. // Runs checkFields on the corresponding elements of chain1 and chain2. Element
  336. // i of the returned slice contains a slice of the fields for which certificate
  337. // i in chain1 had different values than certificate i of chain2.
  338. func checkFieldsOfChains(chain1, chain2 []*x509.Certificate) [][]string {
  339. minLen := math.Min(float64(len(chain1)), float64(len(chain2)))
  340. typeOfCert := reflect.TypeOf(*chain1[0])
  341. var unequalFields [][]string
  342. for i := 0; i < int(minLen); i++ {
  343. unequalFields = append(unequalFields, checkFields(
  344. *chain1[i], *chain2[i], typeOfCert))
  345. }
  346. return unequalFields
  347. }
  348. // Compares a certificate to a request. Returns (true, []) if both items
  349. // contain matching data (for the things that can match). Otherwise, returns
  350. // (false, unequalFields) where unequalFields contains the names of all fields
  351. // which did not match.
  352. func certEqualsRequest(cert *x509.Certificate, request csr.CertificateRequest) (bool, []string) {
  353. equal := true
  354. var unequalFields []string
  355. if cert.Subject.CommonName != request.CN {
  356. equal = false
  357. unequalFields = append(unequalFields, "Common Name")
  358. }
  359. nameData := make(map[string]map[string]bool)
  360. nameData["Country"] = make(map[string]bool)
  361. nameData["Organization"] = make(map[string]bool)
  362. nameData["OrganizationalUnit"] = make(map[string]bool)
  363. nameData["Locality"] = make(map[string]bool)
  364. nameData["Province"] = make(map[string]bool)
  365. for _, name := range request.Names {
  366. nameData["Country"][name.C] = true
  367. nameData["Organization"][name.O] = true
  368. nameData["OrganizationalUnit"][name.OU] = true
  369. nameData["Locality"][name.L] = true
  370. nameData["Province"][name.ST] = true
  371. }
  372. for _, country := range cert.Subject.Country {
  373. if _, exists := nameData["Country"][country]; !exists {
  374. equal = false
  375. unequalFields = append(unequalFields, "Country")
  376. }
  377. }
  378. for _, organization := range cert.Subject.Organization {
  379. if _, exists := nameData["Organization"][organization]; !exists {
  380. equal = false
  381. unequalFields = append(unequalFields, "Organization")
  382. }
  383. }
  384. for _, organizationalUnit := range cert.Subject.OrganizationalUnit {
  385. if _, exists := nameData["OrganizationalUnit"][organizationalUnit]; !exists {
  386. equal = false
  387. unequalFields = append(unequalFields, "OrganizationalUnit")
  388. }
  389. }
  390. for _, locality := range cert.Subject.Locality {
  391. if _, exists := nameData["Locality"][locality]; !exists {
  392. equal = false
  393. unequalFields = append(unequalFields, "Locality")
  394. }
  395. }
  396. for _, province := range cert.Subject.Province {
  397. if _, exists := nameData["Province"][province]; !exists {
  398. equal = false
  399. unequalFields = append(unequalFields, "Province")
  400. }
  401. }
  402. // TODO: check hosts
  403. if cert.BasicConstraintsValid && request.CA != nil {
  404. if cert.MaxPathLen != request.CA.PathLength {
  405. equal = false
  406. unequalFields = append(unequalFields, "Max Path Length")
  407. }
  408. // TODO: check expiry
  409. }
  410. // TODO: check isCA
  411. return equal, unequalFields
  412. }
  413. // Returns a random element of the input slice.
  414. func randomElement(set []string) string {
  415. return set[rand.Intn(len(set))]
  416. }