dbtest.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. // Copyright 2020 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package dbtest
  5. import (
  6. "database/sql"
  7. "fmt"
  8. "os"
  9. "path/filepath"
  10. "testing"
  11. "time"
  12. "github.com/stretchr/testify/require"
  13. "gorm.io/gorm"
  14. "gorm.io/gorm/schema"
  15. "gogs.io/gogs/internal/conf"
  16. "gogs.io/gogs/internal/dbutil"
  17. )
  18. // NewDB creates a new test database and initializes the given list of tables
  19. // for the suite. The test database is dropped after testing is completed unless
  20. // failed.
  21. func NewDB(t *testing.T, suite string, tables ...any) *gorm.DB {
  22. dbType := os.Getenv("GOGS_DATABASE_TYPE")
  23. var dbName string
  24. var dbOpts conf.DatabaseOpts
  25. var cleanup func(db *gorm.DB)
  26. switch dbType {
  27. case "mysql":
  28. dbOpts = conf.DatabaseOpts{
  29. Type: "mysql",
  30. Host: os.ExpandEnv("$MYSQL_HOST:$MYSQL_PORT"),
  31. Name: dbName,
  32. User: os.Getenv("MYSQL_USER"),
  33. Password: os.Getenv("MYSQL_PASSWORD"),
  34. }
  35. dsn, err := dbutil.NewDSN(dbOpts)
  36. require.NoError(t, err)
  37. sqlDB, err := sql.Open("mysql", dsn)
  38. require.NoError(t, err)
  39. // Set up test database
  40. dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
  41. _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dbName))
  42. require.NoError(t, err)
  43. _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dbName))
  44. require.NoError(t, err)
  45. dbOpts.Name = dbName
  46. cleanup = func(db *gorm.DB) {
  47. testDB, err := db.DB()
  48. if err == nil {
  49. _ = testDB.Close()
  50. }
  51. _, _ = sqlDB.Exec(fmt.Sprintf("DROP DATABASE `%s`", dbName))
  52. _ = sqlDB.Close()
  53. }
  54. case "postgres":
  55. dbOpts = conf.DatabaseOpts{
  56. Type: "postgres",
  57. Host: os.ExpandEnv("$PGHOST:$PGPORT"),
  58. Name: dbName,
  59. Schema: "public",
  60. User: os.Getenv("PGUSER"),
  61. Password: os.Getenv("PGPASSWORD"),
  62. SSLMode: os.Getenv("PGSSLMODE"),
  63. }
  64. dsn, err := dbutil.NewDSN(dbOpts)
  65. require.NoError(t, err)
  66. sqlDB, err := sql.Open("pgx", dsn)
  67. require.NoError(t, err)
  68. // Set up test database
  69. dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
  70. _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %q", dbName))
  71. require.NoError(t, err)
  72. _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE %q", dbName))
  73. require.NoError(t, err)
  74. dbOpts.Name = dbName
  75. cleanup = func(db *gorm.DB) {
  76. testDB, err := db.DB()
  77. if err == nil {
  78. _ = testDB.Close()
  79. }
  80. _, _ = sqlDB.Exec(fmt.Sprintf(`DROP DATABASE %q`, dbName))
  81. _ = sqlDB.Close()
  82. }
  83. case "sqlite":
  84. dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
  85. dbOpts = conf.DatabaseOpts{
  86. Type: "sqlite",
  87. Path: dbName,
  88. }
  89. cleanup = func(db *gorm.DB) {
  90. sqlDB, err := db.DB()
  91. if err == nil {
  92. _ = sqlDB.Close()
  93. }
  94. _ = os.Remove(dbName)
  95. }
  96. default:
  97. dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
  98. dbOpts = conf.DatabaseOpts{
  99. Type: "sqlite3",
  100. Path: dbName,
  101. }
  102. cleanup = func(db *gorm.DB) {
  103. sqlDB, err := db.DB()
  104. if err == nil {
  105. _ = sqlDB.Close()
  106. }
  107. _ = os.Remove(dbName)
  108. }
  109. }
  110. now := time.Now().UTC().Truncate(time.Second)
  111. db, err := dbutil.OpenDB(
  112. dbOpts,
  113. &gorm.Config{
  114. SkipDefaultTransaction: true,
  115. NamingStrategy: schema.NamingStrategy{
  116. SingularTable: true,
  117. },
  118. NowFunc: func() time.Time {
  119. return now
  120. },
  121. },
  122. )
  123. require.NoError(t, err)
  124. t.Cleanup(func() {
  125. if t.Failed() {
  126. t.Logf("Database %q left intact for inspection", dbName)
  127. return
  128. }
  129. cleanup(db)
  130. })
  131. err = db.Migrator().AutoMigrate(tables...)
  132. require.NoError(t, err)
  133. return db
  134. }