testing_helpers_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480
  1. package testsuite
  2. import (
  3. "crypto/x509"
  4. "encoding/json"
  5. "io/ioutil"
  6. "math"
  7. "math/rand"
  8. "os"
  9. "os/exec"
  10. "reflect"
  11. "strconv"
  12. "strings"
  13. "testing"
  14. "time"
  15. "github.com/cloudflare/cfssl/csr"
  16. "github.com/cloudflare/cfssl/helpers"
  17. )
  18. const (
  19. testDataDirectory = "testdata"
  20. initCADirectory = testDataDirectory + string(os.PathSeparator) + "initCA"
  21. preMadeOutput = initCADirectory + string(os.PathSeparator) + "cfssl_output.pem"
  22. csrFile = testDataDirectory + string(os.PathSeparator) + "cert_csr.json"
  23. )
  24. var (
  25. keyRequest = csr.BasicKeyRequest{
  26. A: "rsa",
  27. S: 2048,
  28. }
  29. CAConfig = csr.CAConfig{
  30. PathLength: 1,
  31. Expiry: "1h", // issue a CA certificate only valid for 1 hour
  32. }
  33. baseRequest = csr.CertificateRequest{
  34. CN: "example.com",
  35. Names: []csr.Name{
  36. {
  37. C: "US",
  38. ST: "California",
  39. L: "San Francisco",
  40. O: "Internet Widgets, LLC",
  41. OU: "Certificate Authority",
  42. },
  43. },
  44. Hosts: []string{"ca.example.com"},
  45. KeyRequest: &keyRequest,
  46. }
  47. CARequest = csr.CertificateRequest{
  48. CN: "example.com",
  49. Names: []csr.Name{
  50. {
  51. C: "US",
  52. ST: "California",
  53. L: "San Francisco",
  54. O: "Internet Widgets, LLC",
  55. OU: "Certificate Authority",
  56. },
  57. },
  58. Hosts: []string{"ca.example.com"},
  59. KeyRequest: &keyRequest,
  60. CA: &CAConfig,
  61. }
  62. )
  63. func TestStartCFSSLServer(t *testing.T) {
  64. // We will test on this address and port. Be sure that these are free or
  65. // the test will fail.
  66. addressToTest := "127.0.0.1"
  67. portToTest := 9775
  68. CACert, CAKey, err := CreateSelfSignedCert(CARequest)
  69. if err != nil {
  70. t.Fatal(err.Error())
  71. }
  72. // Set up a test server using our CA certificate and key.
  73. serverData := CFSSLServerData{CA: CACert, CAKey: CAKey}
  74. server, err := StartCFSSLServer(addressToTest, portToTest, serverData)
  75. if err != nil {
  76. t.Fatal(err.Error())
  77. }
  78. // Try to start up a second server at the same address and port number. We
  79. // should get an 'address in use' error.
  80. _, err = StartCFSSLServer(addressToTest, portToTest, serverData)
  81. if err == nil || !strings.Contains(err.Error(), "Error occurred on server: address") {
  82. t.Fatal("Two servers allowed on same address and port.")
  83. }
  84. // Now make a request of our server and check that no error occurred.
  85. // First we need a request to send to our server. We marshall the request
  86. // into JSON format and write it to a temporary file.
  87. jsonBytes, err := json.Marshal(baseRequest)
  88. if err != nil {
  89. t.Fatal(err.Error())
  90. }
  91. tempFile, err := createTempFile(jsonBytes)
  92. if err != nil {
  93. os.Remove(tempFile)
  94. panic(err)
  95. }
  96. // Now we make the request and check the output.
  97. remoteServerString := "-remote=" + "http://" + addressToTest + ":" + strconv.Itoa(portToTest)
  98. command := exec.Command(
  99. "cfssl", "gencert", remoteServerString, "-hostname="+baseRequest.CN, tempFile)
  100. CLIOutput, err := command.CombinedOutput()
  101. os.Remove(tempFile)
  102. if err != nil {
  103. t.Fatalf("%v: %s", err.Error(), string(CLIOutput))
  104. }
  105. err = checkCLIOutput(CLIOutput)
  106. if err != nil {
  107. t.Fatal(err.Error())
  108. }
  109. // The output should contain the certificate, request, and private key.
  110. _, err = cleanCLIOutput(CLIOutput, "cert")
  111. if err != nil {
  112. t.Fatal(err.Error())
  113. }
  114. _, err = cleanCLIOutput(CLIOutput, "csr")
  115. if err != nil {
  116. t.Fatal(err.Error())
  117. }
  118. _, err = cleanCLIOutput(CLIOutput, "key")
  119. if err != nil {
  120. t.Fatal(err.Error())
  121. }
  122. // Finally, kill the server.
  123. err = server.Kill()
  124. if err != nil {
  125. t.Fatal(err.Error())
  126. }
  127. }
  128. func TestCreateCertificateChain(t *testing.T) {
  129. // N is the number of certificates that will be chained together.
  130. N := 10
  131. // --- TEST: Create a chain of one certificate. --- //
  132. encodedChainFromCode, _, err := CreateCertificateChain([]csr.CertificateRequest{CARequest})
  133. if err != nil {
  134. t.Fatal(err.Error())
  135. }
  136. // Now compare to a pre-made certificate chain using a JSON file containing
  137. // the same request data.
  138. CLIOutputFile := preMadeOutput
  139. CLIOutput, err := ioutil.ReadFile(CLIOutputFile)
  140. if err != nil {
  141. t.Fatal(err.Error())
  142. }
  143. encodedChainFromCLI, err := cleanCLIOutput(CLIOutput, "cert")
  144. if err != nil {
  145. t.Fatal(err.Error())
  146. }
  147. chainFromCode, err := helpers.ParseCertificatesPEM(encodedChainFromCode)
  148. if err != nil {
  149. t.Fatal(err.Error())
  150. }
  151. chainFromCLI, err := helpers.ParseCertificatesPEM(encodedChainFromCLI)
  152. if err != nil {
  153. t.Fatal(err.Error())
  154. }
  155. if !chainsEqual(chainFromCode, chainFromCLI) {
  156. unequalFieldSlices := checkFieldsOfChains(chainFromCode, chainFromCLI)
  157. for i, unequalFields := range unequalFieldSlices {
  158. if len(unequalFields) > 0 {
  159. t.Log("The certificate chains held unequal fields for chain " + strconv.Itoa(i))
  160. t.Log("The following fields were unequal:")
  161. for _, field := range unequalFields {
  162. t.Log("\t" + field)
  163. }
  164. }
  165. }
  166. t.Fatal("Certificate chains unequal.")
  167. }
  168. // --- TEST: Create a chain of N certificates. --- //
  169. // First we make a slice of N requests. We make each slightly different.
  170. cnGrabBag := []string{"example", "invalid", "test"}
  171. topLevelDomains := []string{".com", ".net", ".org"}
  172. subDomains := []string{"www.", "secure.", "ca.", ""}
  173. countryGrabBag := []string{"USA", "China", "England", "Vanuatu"}
  174. stateGrabBag := []string{"California", "Texas", "Alaska", "London"}
  175. localityGrabBag := []string{"San Francisco", "Houston", "London", "Oslo"}
  176. orgGrabBag := []string{"Internet Widgets, LLC", "CloudFlare, Inc."}
  177. orgUnitGrabBag := []string{"Certificate Authority", "Systems Engineering"}
  178. requests := make([]csr.CertificateRequest, N)
  179. requests[0] = CARequest
  180. for i := 1; i < N; i++ {
  181. requests[i] = baseRequest
  182. cn := randomElement(cnGrabBag)
  183. tld := randomElement(topLevelDomains)
  184. subDomain1 := randomElement(subDomains)
  185. subDomain2 := randomElement(subDomains)
  186. country := randomElement(countryGrabBag)
  187. state := randomElement(stateGrabBag)
  188. locality := randomElement(localityGrabBag)
  189. org := randomElement(orgGrabBag)
  190. orgUnit := randomElement(orgUnitGrabBag)
  191. requests[i].CN = cn + "." + tld
  192. requests[i].Names = []csr.Name{
  193. {C: country,
  194. ST: state,
  195. L: locality,
  196. O: org,
  197. OU: orgUnit,
  198. },
  199. }
  200. hosts := []string{subDomain1 + requests[i].CN}
  201. if subDomain2 != subDomain1 {
  202. hosts = append(hosts, subDomain2+requests[i].CN)
  203. }
  204. requests[i].Hosts = hosts
  205. }
  206. // Now we make a certificate chain out of these requests.
  207. encodedCertChain, _, err := CreateCertificateChain(requests)
  208. if err != nil {
  209. t.Fatal(err.Error())
  210. }
  211. // To test this chain, we compare the data encoded in each certificate to
  212. // each request we used to generate the chain.
  213. chain, err := helpers.ParseCertificatesPEM(encodedCertChain)
  214. if err != nil {
  215. t.Fatal(err.Error())
  216. }
  217. if len(chain) != len(requests) {
  218. t.Log("Length of chain: " + strconv.Itoa(len(chain)))
  219. t.Log("Length of requests: " + strconv.Itoa(len(requests)))
  220. t.Fatal("Length of chain not equal to length of requests.")
  221. }
  222. mismatchOccurred := false
  223. for i := 0; i < len(chain); i++ {
  224. certEqualsRequest, unequalFields := certEqualsRequest(chain[i], requests[i])
  225. if !certEqualsRequest {
  226. mismatchOccurred = true
  227. t.Log(
  228. "Certificate " + strconv.Itoa(i) + " and request " +
  229. strconv.Itoa(i) + " unequal.",
  230. )
  231. t.Log("Unequal fields for index " + strconv.Itoa(i) + ":")
  232. for _, field := range unequalFields {
  233. t.Log("\t" + field)
  234. }
  235. }
  236. }
  237. // TODO: check that each certificate is actually signed by the previous one
  238. if mismatchOccurred {
  239. t.Fatal("Unequal certificate(s) and request(s) found.")
  240. }
  241. // --- TEST: Create a chain of certificates with invalid path lengths. --- //
  242. // Other invalid chains?
  243. }
  244. func TestCreateSelfSignedCert(t *testing.T) {
  245. // --- TEST: Create a self-signed certificate from a CSR. --- //
  246. // Generate a self-signed certificate from the request.
  247. encodedCertFromCode, _, err := CreateSelfSignedCert(CARequest)
  248. if err != nil {
  249. t.Fatal(err.Error())
  250. }
  251. // Now compare to a pre-made certificate made using a JSON file with the
  252. // same request information. This JSON file is located in testdata/initCA
  253. // and is called ca_csr.json.
  254. CLIOutputFile := preMadeOutput
  255. CLIOutput, err := ioutil.ReadFile(CLIOutputFile)
  256. if err != nil {
  257. t.Fatal(err.Error())
  258. }
  259. encodedCertFromCLI, err := cleanCLIOutput(CLIOutput, "cert")
  260. if err != nil {
  261. t.Fatal(err.Error())
  262. }
  263. certFromCode, err := helpers.ParseSelfSignedCertificatePEM(encodedCertFromCode)
  264. if err != nil {
  265. t.Fatal(err.Error())
  266. }
  267. certFromCLI, err := helpers.ParseSelfSignedCertificatePEM(encodedCertFromCLI)
  268. if err != nil {
  269. t.Fatal(err.Error())
  270. }
  271. // Nullify any fields of the certificates which are dependent upon the time
  272. // of the certificate's creation.
  273. nullifyTimeDependency(certFromCode)
  274. nullifyTimeDependency(certFromCLI)
  275. if !reflect.DeepEqual(certFromCode, certFromCLI) {
  276. unequalFields := checkFields(
  277. *certFromCode, *certFromCLI, reflect.TypeOf(*certFromCode))
  278. t.Log("The following fields were unequal:")
  279. for _, field := range unequalFields {
  280. t.Log(field)
  281. }
  282. t.Fatal("Certificates unequal.")
  283. }
  284. }
  285. // Compare two x509 certificate chains. We only compare relevant data to
  286. // determine equality.
  287. func chainsEqual(chain1, chain2 []*x509.Certificate) bool {
  288. if len(chain1) != len(chain2) {
  289. return false
  290. }
  291. for i := 0; i < len(chain1); i++ {
  292. cert1 := nullifyTimeDependency(chain1[i])
  293. cert2 := nullifyTimeDependency(chain2[i])
  294. if !reflect.DeepEqual(cert1, cert2) {
  295. return false
  296. }
  297. }
  298. return true
  299. }
  300. // When comparing certificates created at different times for equality, we do
  301. // not want to worry about fields which are dependent on the time of creation.
  302. // Thus we nullify these fields before comparing the certificates.
  303. func nullifyTimeDependency(cert *x509.Certificate) *x509.Certificate {
  304. cert.Raw = nil
  305. cert.RawTBSCertificate = nil
  306. cert.RawSubject = nil
  307. cert.RawIssuer = nil
  308. cert.RawSubjectPublicKeyInfo = nil
  309. cert.Signature = nil
  310. cert.PublicKey = nil
  311. cert.SerialNumber = nil
  312. cert.NotBefore = time.Time{}
  313. cert.NotAfter = time.Time{}
  314. cert.Extensions = nil
  315. cert.SubjectKeyId = nil
  316. cert.AuthorityKeyId = nil
  317. cert.Subject.Names = nil
  318. cert.Subject.ExtraNames = nil
  319. cert.Issuer.Names = nil
  320. cert.Issuer.ExtraNames = nil
  321. return cert
  322. }
  323. // Compares two structs and returns a list containing the names of all fields
  324. // for which the two structs hold different values.
  325. func checkFields(struct1, struct2 interface{}, typeOfStructs reflect.Type) []string {
  326. v1 := reflect.ValueOf(struct1)
  327. v2 := reflect.ValueOf(struct2)
  328. var unequalFields []string
  329. for i := 0; i < v1.NumField(); i++ {
  330. if !reflect.DeepEqual(v1.Field(i).Interface(), v2.Field(i).Interface()) {
  331. unequalFields = append(unequalFields, typeOfStructs.Field(i).Name)
  332. }
  333. }
  334. return unequalFields
  335. }
  336. // Runs checkFields on the corresponding elements of chain1 and chain2. Element
  337. // i of the returned slice contains a slice of the fields for which certificate
  338. // i in chain1 had different values than certificate i of chain2.
  339. func checkFieldsOfChains(chain1, chain2 []*x509.Certificate) [][]string {
  340. minLen := math.Min(float64(len(chain1)), float64(len(chain2)))
  341. typeOfCert := reflect.TypeOf(*chain1[0])
  342. var unequalFields [][]string
  343. for i := 0; i < int(minLen); i++ {
  344. unequalFields = append(unequalFields, checkFields(
  345. *chain1[i], *chain2[i], typeOfCert))
  346. }
  347. return unequalFields
  348. }
  349. // Compares a certificate to a request. Returns (true, []) if both items
  350. // contain matching data (for the things that can match). Otherwise, returns
  351. // (false, unequalFields) where unequalFields contains the names of all fields
  352. // which did not match.
  353. func certEqualsRequest(cert *x509.Certificate, request csr.CertificateRequest) (bool, []string) {
  354. equal := true
  355. var unequalFields []string
  356. if cert.Subject.CommonName != request.CN {
  357. equal = false
  358. unequalFields = append(unequalFields, "Common Name")
  359. }
  360. nameData := make(map[string]map[string]bool)
  361. nameData["Country"] = make(map[string]bool)
  362. nameData["Organization"] = make(map[string]bool)
  363. nameData["OrganizationalUnit"] = make(map[string]bool)
  364. nameData["Locality"] = make(map[string]bool)
  365. nameData["Province"] = make(map[string]bool)
  366. for _, name := range request.Names {
  367. nameData["Country"][name.C] = true
  368. nameData["Organization"][name.O] = true
  369. nameData["OrganizationalUnit"][name.OU] = true
  370. nameData["Locality"][name.L] = true
  371. nameData["Province"][name.ST] = true
  372. }
  373. for _, country := range cert.Subject.Country {
  374. if _, exists := nameData["Country"][country]; !exists {
  375. equal = false
  376. unequalFields = append(unequalFields, "Country")
  377. }
  378. }
  379. for _, organization := range cert.Subject.Organization {
  380. if _, exists := nameData["Organization"][organization]; !exists {
  381. equal = false
  382. unequalFields = append(unequalFields, "Organization")
  383. }
  384. }
  385. for _, organizationalUnit := range cert.Subject.OrganizationalUnit {
  386. if _, exists := nameData["OrganizationalUnit"][organizationalUnit]; !exists {
  387. equal = false
  388. unequalFields = append(unequalFields, "OrganizationalUnit")
  389. }
  390. }
  391. for _, locality := range cert.Subject.Locality {
  392. if _, exists := nameData["Locality"][locality]; !exists {
  393. equal = false
  394. unequalFields = append(unequalFields, "Locality")
  395. }
  396. }
  397. for _, province := range cert.Subject.Province {
  398. if _, exists := nameData["Province"][province]; !exists {
  399. equal = false
  400. unequalFields = append(unequalFields, "Province")
  401. }
  402. }
  403. // TODO: check hosts
  404. if cert.BasicConstraintsValid && request.CA != nil {
  405. if cert.MaxPathLen != request.CA.PathLength {
  406. equal = false
  407. unequalFields = append(unequalFields, "Max Path Length")
  408. }
  409. // TODO: check expiry
  410. }
  411. // TODO: check isCA
  412. return equal, unequalFields
  413. }
  414. // Returns a random element of the input slice.
  415. func randomElement(set []string) string {
  416. return set[rand.Intn(len(set))]
  417. }