testdb.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. package testdb
  2. import (
  3. "os"
  4. "strings"
  5. _ "github.com/go-sql-driver/mysql" // register mysql driver
  6. "github.com/jmoiron/sqlx"
  7. _ "github.com/lib/pq" // register postgresql driver
  8. _ "github.com/mattn/go-sqlite3" // register sqlite3 driver
  9. )
// SQL scripts that empty the certdb test tables, one per supported driver.
// Truncate selects the script matching the handle's driver name.
const (
	// mysqlTruncateTables empties the certificates and ocsp_responses tables
	// on MySQL. It holds one TRUNCATE statement per table; Truncate executes
	// them as separate Exec calls rather than as a single script.
	mysqlTruncateTables = `
TRUNCATE certificates;
TRUNCATE ocsp_responses;
`
	// pgTruncateTables defines (or replaces) the PL/pgSQL function
	// truncate_tables and then calls it. The function runs
	// TRUNCATE TABLE <name> CASCADE on every table in the "public" schema that
	// is owned by the session user, skipping goose_db_version (presumably the
	// goose migration-version table, kept so applied migrations stay
	// recorded — confirm). Table names go through quote_ident before being
	// spliced into the dynamic statement. The body contains ';' characters,
	// so the script must be sent whole rather than split into statements.
	pgTruncateTables = `
CREATE OR REPLACE FUNCTION truncate_tables() RETURNS void AS $$
DECLARE
statements CURSOR FOR
SELECT tablename FROM pg_tables
WHERE tablename != 'goose_db_version'
AND tableowner = session_user
AND schemaname = 'public';
BEGIN
FOR stmt IN statements LOOP
EXECUTE 'TRUNCATE TABLE ' || quote_ident(stmt.tablename) || ' CASCADE;';
END LOOP;
END;
$$ LANGUAGE plpgsql;
SELECT truncate_tables();
`
	// sqliteTruncateTables empties the certificates and ocsp_responses tables
	// on SQLite. SQLite has no TRUNCATE statement, so DELETE FROM is used.
	sqliteTruncateTables = `
DELETE FROM certificates;
DELETE FROM ocsp_responses;
`
)
  36. // MySQLDB returns a MySQL db instance for certdb testing.
  37. func MySQLDB() *sqlx.DB {
  38. connStr := "root@tcp(localhost:3306)/certdb_development?parseTime=true"
  39. if dbURL := os.Getenv("DATABASE_URL"); dbURL != "" {
  40. connStr = dbURL
  41. }
  42. db, err := sqlx.Open("mysql", connStr)
  43. if err != nil {
  44. panic(err)
  45. }
  46. Truncate(db)
  47. return db
  48. }
  49. // PostgreSQLDB returns a PostgreSQL db instance for certdb testing.
  50. func PostgreSQLDB() *sqlx.DB {
  51. connStr := "dbname=certdb_development sslmode=disable user=postgres"
  52. if dbURL := os.Getenv("DATABASE_URL"); dbURL != "" {
  53. connStr = dbURL
  54. }
  55. db, err := sqlx.Open("postgres", connStr)
  56. if err != nil {
  57. panic(err)
  58. }
  59. Truncate(db)
  60. return db
  61. }
  62. // SQLiteDB returns a SQLite db instance for certdb testing.
  63. func SQLiteDB(dbpath string) *sqlx.DB {
  64. db, err := sqlx.Open("sqlite3", dbpath)
  65. if err != nil {
  66. panic(err)
  67. }
  68. Truncate(db)
  69. return db
  70. }
  71. // Truncate truncates the DB
  72. func Truncate(db *sqlx.DB) {
  73. var sql []string
  74. switch db.DriverName() {
  75. case "mysql":
  76. sql = strings.Split(mysqlTruncateTables, "\n")
  77. case "postgres":
  78. sql = []string{pgTruncateTables}
  79. case "sqlite3":
  80. sql = []string{sqliteTruncateTables}
  81. default:
  82. panic("Unknown driver")
  83. }
  84. for _, expr := range sql {
  85. if len(strings.TrimSpace(expr)) == 0 {
  86. continue
  87. }
  88. if _, err := db.Exec(expr); err != nil {
  89. panic(err)
  90. }
  91. }
  92. }