sql_test.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. package sql
  2. import (
  3. "math"
  4. "testing"
  5. "time"
  6. "github.com/cloudflare/cfssl/certdb"
  7. "github.com/cloudflare/cfssl/certdb/testdb"
  8. "github.com/jmoiron/sqlx"
  9. )
  10. const (
  11. sqliteDBFile = "../testdb/certstore_development.db"
  12. fakeAKI = "fake_aki"
  13. )
  14. func TestNoDB(t *testing.T) {
  15. dba := &Accessor{}
  16. _, err := dba.GetCertificate("foobar serial", "random aki")
  17. if err == nil {
  18. t.Fatal("should return error")
  19. }
  20. }
// TestAccessor bundles the certdb.Accessor under test with the raw database
// handle it was built from, so the test helpers can reset the database
// (via Truncate) before each case.
type TestAccessor struct {
	// Accessor is the implementation exercised by the test helpers.
	Accessor certdb.Accessor
	// DB is the underlying handle; Truncate hands it to testdb.Truncate.
	DB *sqlx.DB
}
  25. func (ta *TestAccessor) Truncate() {
  26. testdb.Truncate(ta.DB)
  27. }
  28. func TestSQLite(t *testing.T) {
  29. db := testdb.SQLiteDB(sqliteDBFile)
  30. ta := TestAccessor{
  31. Accessor: NewAccessor(db),
  32. DB: db,
  33. }
  34. testEverything(ta, t)
  35. }
  36. // roughlySameTime decides if t1 and t2 are close enough.
  37. func roughlySameTime(t1, t2 time.Time) bool {
  38. // return true if the difference is smaller than 1 sec.
  39. return math.Abs(float64(t1.Sub(t2))) < float64(time.Second)
  40. }
  41. func testEverything(ta TestAccessor, t *testing.T) {
  42. testInsertCertificateAndGetCertificate(ta, t)
  43. testInsertCertificateAndGetUnexpiredCertificate(ta, t)
  44. testUpdateCertificateAndGetCertificate(ta, t)
  45. testInsertOCSPAndGetOCSP(ta, t)
  46. testInsertOCSPAndGetUnexpiredOCSP(ta, t)
  47. testUpdateOCSPAndGetOCSP(ta, t)
  48. testUpsertOCSPAndGetOCSP(ta, t)
  49. }
  50. func testInsertCertificateAndGetCertificate(ta TestAccessor, t *testing.T) {
  51. ta.Truncate()
  52. expiry := time.Date(2010, time.December, 25, 23, 0, 0, 0, time.UTC)
  53. want := certdb.CertificateRecord{
  54. PEM: "fake cert data",
  55. Serial: "fake serial",
  56. AKI: fakeAKI,
  57. Status: "good",
  58. Reason: 0,
  59. Expiry: expiry,
  60. }
  61. if err := ta.Accessor.InsertCertificate(want); err != nil {
  62. t.Fatal(err)
  63. }
  64. rets, err := ta.Accessor.GetCertificate(want.Serial, want.AKI)
  65. if err != nil {
  66. t.Fatal(err)
  67. }
  68. if len(rets) != 1 {
  69. t.Fatal("should only return one record.")
  70. }
  71. got := rets[0]
  72. // relfection comparison with zero time objects are not stable as it seems
  73. if want.Serial != got.Serial || want.Status != got.Status ||
  74. want.AKI != got.AKI || !got.RevokedAt.IsZero() ||
  75. want.PEM != got.PEM || !roughlySameTime(got.Expiry, expiry) {
  76. t.Errorf("want Certificate %+v, got %+v", want, got)
  77. }
  78. unexpired, err := ta.Accessor.GetUnexpiredCertificates()
  79. if err != nil {
  80. t.Fatal(err)
  81. }
  82. if len(unexpired) != 0 {
  83. t.Error("should not have unexpired certificate record")
  84. }
  85. }
  86. func testInsertCertificateAndGetUnexpiredCertificate(ta TestAccessor, t *testing.T) {
  87. ta.Truncate()
  88. expiry := time.Now().Add(time.Minute)
  89. want := certdb.CertificateRecord{
  90. PEM: "fake cert data",
  91. Serial: "fake serial 2",
  92. AKI: fakeAKI,
  93. Status: "good",
  94. Reason: 0,
  95. Expiry: expiry,
  96. }
  97. if err := ta.Accessor.InsertCertificate(want); err != nil {
  98. t.Fatal(err)
  99. }
  100. rets, err := ta.Accessor.GetCertificate(want.Serial, want.AKI)
  101. if err != nil {
  102. t.Fatal(err)
  103. }
  104. if len(rets) != 1 {
  105. t.Fatal("should return exactly one record")
  106. }
  107. got := rets[0]
  108. // relfection comparison with zero time objects are not stable as it seems
  109. if want.Serial != got.Serial || want.Status != got.Status ||
  110. want.AKI != got.AKI || !got.RevokedAt.IsZero() ||
  111. want.PEM != got.PEM || !roughlySameTime(got.Expiry, expiry) {
  112. t.Errorf("want Certificate %+v, got %+v", want, got)
  113. }
  114. unexpired, err := ta.Accessor.GetUnexpiredCertificates()
  115. if err != nil {
  116. t.Fatal(err)
  117. }
  118. if len(unexpired) != 1 {
  119. t.Error("Should have 1 unexpired certificate record:", len(unexpired))
  120. }
  121. }
  122. func testUpdateCertificateAndGetCertificate(ta TestAccessor, t *testing.T) {
  123. ta.Truncate()
  124. expiry := time.Now().Add(time.Hour)
  125. want := certdb.CertificateRecord{
  126. PEM: "fake cert data",
  127. Serial: "fake serial 3",
  128. AKI: fakeAKI,
  129. Status: "good",
  130. Reason: 0,
  131. Expiry: expiry,
  132. }
  133. // Make sure the revoke on a non-existent cert fails
  134. if err := ta.Accessor.RevokeCertificate(want.Serial, want.AKI, 2); err == nil {
  135. t.Fatal("Expected error")
  136. }
  137. if err := ta.Accessor.InsertCertificate(want); err != nil {
  138. t.Fatal(err)
  139. }
  140. // reason 2 is CACompromise
  141. if err := ta.Accessor.RevokeCertificate(want.Serial, want.AKI, 2); err != nil {
  142. t.Fatal(err)
  143. }
  144. rets, err := ta.Accessor.GetCertificate(want.Serial, want.AKI)
  145. if err != nil {
  146. t.Fatal(err)
  147. }
  148. if len(rets) != 1 {
  149. t.Fatal("should return exactly one record")
  150. }
  151. got := rets[0]
  152. // relfection comparison with zero time objects are not stable as it seems
  153. if want.Serial != got.Serial || got.Status != "revoked" ||
  154. want.AKI != got.AKI || got.RevokedAt.IsZero() ||
  155. want.PEM != got.PEM {
  156. t.Errorf("want Certificate %+v, got %+v", want, got)
  157. }
  158. rets, err = ta.Accessor.GetRevokedAndUnexpiredCertificates()
  159. if err != nil {
  160. t.Fatal(err)
  161. }
  162. got = rets[0]
  163. // relfection comparison with zero time objects are not stable as it seems
  164. if want.Serial != got.Serial || got.Status != "revoked" ||
  165. want.AKI != got.AKI || got.RevokedAt.IsZero() ||
  166. want.PEM != got.PEM {
  167. t.Errorf("want Certificate %+v, got %+v", want, got)
  168. }
  169. rets, err = ta.Accessor.GetRevokedAndUnexpiredCertificatesByLabel("")
  170. if err != nil {
  171. t.Fatal(err)
  172. }
  173. got = rets[0]
  174. // relfection comparison with zero time objects are not stable as it seems
  175. if want.Serial != got.Serial || got.Status != "revoked" ||
  176. want.AKI != got.AKI || got.RevokedAt.IsZero() ||
  177. want.PEM != got.PEM {
  178. t.Errorf("want Certificate %+v, got %+v", want, got)
  179. }
  180. }
  181. func testInsertOCSPAndGetOCSP(ta TestAccessor, t *testing.T) {
  182. ta.Truncate()
  183. expiry := time.Date(2010, time.December, 25, 23, 0, 0, 0, time.UTC)
  184. want := certdb.OCSPRecord{
  185. Serial: "fake serial",
  186. AKI: fakeAKI,
  187. Body: "fake body",
  188. Expiry: expiry,
  189. }
  190. setupGoodCert(ta, t, want)
  191. if err := ta.Accessor.InsertOCSP(want); err != nil {
  192. t.Fatal(err)
  193. }
  194. rets, err := ta.Accessor.GetOCSP(want.Serial, want.AKI)
  195. if err != nil {
  196. t.Fatal(err)
  197. }
  198. if len(rets) != 1 {
  199. t.Fatal("should return exactly one record")
  200. }
  201. got := rets[0]
  202. if want.Serial != got.Serial || want.Body != got.Body ||
  203. !roughlySameTime(want.Expiry, got.Expiry) {
  204. t.Errorf("want OCSP %+v, got %+v", want, got)
  205. }
  206. unexpired, err := ta.Accessor.GetUnexpiredOCSPs()
  207. if err != nil {
  208. t.Fatal(err)
  209. }
  210. if len(unexpired) != 0 {
  211. t.Error("should not have unexpired certificate record")
  212. }
  213. }
  214. func testInsertOCSPAndGetUnexpiredOCSP(ta TestAccessor, t *testing.T) {
  215. ta.Truncate()
  216. want := certdb.OCSPRecord{
  217. Serial: "fake serial 2",
  218. AKI: fakeAKI,
  219. Body: "fake body",
  220. Expiry: time.Now().Add(time.Minute),
  221. }
  222. setupGoodCert(ta, t, want)
  223. if err := ta.Accessor.InsertOCSP(want); err != nil {
  224. t.Fatal(err)
  225. }
  226. rets, err := ta.Accessor.GetOCSP(want.Serial, want.AKI)
  227. if err != nil {
  228. t.Fatal(err)
  229. }
  230. if len(rets) != 1 {
  231. t.Fatal("should return exactly one record")
  232. }
  233. got := rets[0]
  234. if want.Serial != got.Serial || want.Body != got.Body ||
  235. !roughlySameTime(want.Expiry, got.Expiry) {
  236. t.Errorf("want OCSP %+v, got %+v", want, got)
  237. }
  238. unexpired, err := ta.Accessor.GetUnexpiredOCSPs()
  239. if err != nil {
  240. t.Fatal(err)
  241. }
  242. if len(unexpired) != 1 {
  243. t.Error("should not have other than 1 unexpired certificate record:", len(unexpired))
  244. }
  245. }
  246. func testUpdateOCSPAndGetOCSP(ta TestAccessor, t *testing.T) {
  247. ta.Truncate()
  248. want := certdb.OCSPRecord{
  249. Serial: "fake serial 3",
  250. AKI: fakeAKI,
  251. Body: "fake body",
  252. Expiry: time.Date(2010, time.December, 25, 23, 0, 0, 0, time.UTC),
  253. }
  254. setupGoodCert(ta, t, want)
  255. // Make sure the update fails
  256. if err := ta.Accessor.UpdateOCSP(want.Serial, want.AKI, want.Body, want.Expiry); err == nil {
  257. t.Fatal("Expected error")
  258. }
  259. if err := ta.Accessor.InsertOCSP(want); err != nil {
  260. t.Fatal(err)
  261. }
  262. want.Body = "fake body revoked"
  263. newExpiry := time.Now().Add(time.Hour)
  264. if err := ta.Accessor.UpdateOCSP(want.Serial, want.AKI, want.Body, newExpiry); err != nil {
  265. t.Fatal(err)
  266. }
  267. rets, err := ta.Accessor.GetOCSP(want.Serial, want.AKI)
  268. if err != nil {
  269. t.Fatal(err)
  270. }
  271. if len(rets) != 1 {
  272. t.Fatal("should return exactly one record")
  273. }
  274. got := rets[0]
  275. want.Expiry = newExpiry
  276. if want.Serial != got.Serial || got.Body != "fake body revoked" ||
  277. !roughlySameTime(newExpiry, got.Expiry) {
  278. t.Errorf("want OCSP %+v, got %+v", want, got)
  279. }
  280. }
  281. func testUpsertOCSPAndGetOCSP(ta TestAccessor, t *testing.T) {
  282. ta.Truncate()
  283. want := certdb.OCSPRecord{
  284. Serial: "fake serial 3",
  285. AKI: fakeAKI,
  286. Body: "fake body",
  287. Expiry: time.Date(2010, time.December, 25, 23, 0, 0, 0, time.UTC),
  288. }
  289. setupGoodCert(ta, t, want)
  290. if err := ta.Accessor.UpsertOCSP(want.Serial, want.AKI, want.Body, want.Expiry); err != nil {
  291. t.Fatal(err)
  292. }
  293. rets, err := ta.Accessor.GetOCSP(want.Serial, want.AKI)
  294. if err != nil {
  295. t.Fatal(err)
  296. }
  297. if len(rets) != 1 {
  298. t.Fatal("should return exactly one record")
  299. }
  300. got := rets[0]
  301. if want.Serial != got.Serial || want.Body != got.Body ||
  302. !roughlySameTime(want.Expiry, got.Expiry) {
  303. t.Errorf("want OCSP %+v, got %+v", want, got)
  304. }
  305. newExpiry := time.Now().Add(time.Hour)
  306. if err := ta.Accessor.UpsertOCSP(want.Serial, want.AKI, "fake body revoked", newExpiry); err != nil {
  307. t.Fatal(err)
  308. }
  309. rets, err = ta.Accessor.GetOCSP(want.Serial, want.AKI)
  310. if err != nil {
  311. t.Fatal(err)
  312. }
  313. if len(rets) != 1 {
  314. t.Fatal("should return exactly one record")
  315. }
  316. got = rets[0]
  317. want.Expiry = newExpiry
  318. if want.Serial != got.Serial || got.Body != "fake body revoked" ||
  319. !roughlySameTime(newExpiry, got.Expiry) {
  320. t.Errorf("want OCSP %+v, got %+v", want, got)
  321. }
  322. }
  323. func setupGoodCert(ta TestAccessor, t *testing.T, r certdb.OCSPRecord) {
  324. certWant := certdb.CertificateRecord{
  325. AKI: r.AKI,
  326. CALabel: "default",
  327. Expiry: time.Now().Add(time.Minute),
  328. PEM: "fake cert data",
  329. Serial: r.Serial,
  330. Status: "good",
  331. Reason: 0,
  332. }
  333. if err := ta.Accessor.InsertCertificate(certWant); err != nil {
  334. t.Fatal(err)
  335. }
  336. }