notices_test.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. // Copyright 2023 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package db
  5. import (
  6. "context"
  7. "testing"
  8. "time"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/stretchr/testify/require"
  11. "gorm.io/gorm"
  12. "gogs.io/gogs/internal/dbtest"
  13. )
  14. func TestNotice_BeforeCreate(t *testing.T) {
  15. now := time.Now()
  16. db := &gorm.DB{
  17. Config: &gorm.Config{
  18. SkipDefaultTransaction: true,
  19. NowFunc: func() time.Time {
  20. return now
  21. },
  22. },
  23. }
  24. t.Run("CreatedUnix has been set", func(t *testing.T) {
  25. notice := &Notice{
  26. CreatedUnix: 1,
  27. }
  28. _ = notice.BeforeCreate(db)
  29. assert.Equal(t, int64(1), notice.CreatedUnix)
  30. })
  31. t.Run("CreatedUnix has not been set", func(t *testing.T) {
  32. notice := &Notice{}
  33. _ = notice.BeforeCreate(db)
  34. assert.Equal(t, db.NowFunc().Unix(), notice.CreatedUnix)
  35. })
  36. }
  37. func TestNotice_AfterFind(t *testing.T) {
  38. now := time.Now()
  39. db := &gorm.DB{
  40. Config: &gorm.Config{
  41. SkipDefaultTransaction: true,
  42. NowFunc: func() time.Time {
  43. return now
  44. },
  45. },
  46. }
  47. notice := &Notice{
  48. CreatedUnix: now.Unix(),
  49. }
  50. _ = notice.AfterFind(db)
  51. assert.Equal(t, notice.CreatedUnix, notice.Created.Unix())
  52. }
  53. func TestNotices(t *testing.T) {
  54. if testing.Short() {
  55. t.Skip()
  56. }
  57. t.Parallel()
  58. ctx := context.Background()
  59. tables := []any{new(Notice)}
  60. db := &notices{
  61. DB: dbtest.NewDB(t, "notices", tables...),
  62. }
  63. for _, tc := range []struct {
  64. name string
  65. test func(t *testing.T, ctx context.Context, db *notices)
  66. }{
  67. {"Create", noticesCreate},
  68. {"DeleteByIDs", noticesDeleteByIDs},
  69. {"DeleteAll", noticesDeleteAll},
  70. {"List", noticesList},
  71. {"Count", noticesCount},
  72. } {
  73. t.Run(tc.name, func(t *testing.T) {
  74. t.Cleanup(func() {
  75. err := clearTables(t, db.DB, tables...)
  76. require.NoError(t, err)
  77. })
  78. tc.test(t, ctx, db)
  79. })
  80. if t.Failed() {
  81. break
  82. }
  83. }
  84. }
  85. func noticesCreate(t *testing.T, ctx context.Context, db *notices) {
  86. err := db.Create(ctx, NoticeTypeRepository, "test")
  87. require.NoError(t, err)
  88. count := db.Count(ctx)
  89. assert.Equal(t, int64(1), count)
  90. }
  91. func noticesDeleteByIDs(t *testing.T, ctx context.Context, db *notices) {
  92. err := db.Create(ctx, NoticeTypeRepository, "test")
  93. require.NoError(t, err)
  94. notices, err := db.List(ctx, 1, 10)
  95. require.NoError(t, err)
  96. ids := make([]int64, 0, len(notices))
  97. for _, notice := range notices {
  98. ids = append(ids, notice.ID)
  99. }
  100. // Non-existing IDs should be ignored
  101. ids = append(ids, 404)
  102. err = db.DeleteByIDs(ctx, ids...)
  103. require.NoError(t, err)
  104. count := db.Count(ctx)
  105. assert.Equal(t, int64(0), count)
  106. }
  107. func noticesDeleteAll(t *testing.T, ctx context.Context, db *notices) {
  108. err := db.Create(ctx, NoticeTypeRepository, "test")
  109. require.NoError(t, err)
  110. err = db.DeleteAll(ctx)
  111. require.NoError(t, err)
  112. count := db.Count(ctx)
  113. assert.Equal(t, int64(0), count)
  114. }
  115. func noticesList(t *testing.T, ctx context.Context, db *notices) {
  116. err := db.Create(ctx, NoticeTypeRepository, "test 1")
  117. require.NoError(t, err)
  118. err = db.Create(ctx, NoticeTypeRepository, "test 2")
  119. require.NoError(t, err)
  120. got1, err := db.List(ctx, 1, 1)
  121. require.NoError(t, err)
  122. require.Len(t, got1, 1)
  123. got2, err := db.List(ctx, 2, 1)
  124. require.NoError(t, err)
  125. require.Len(t, got2, 1)
  126. assert.True(t, got1[0].ID > got2[0].ID)
  127. got, err := db.List(ctx, 1, 3)
  128. require.NoError(t, err)
  129. require.Len(t, got, 2)
  130. }
  131. func noticesCount(t *testing.T, ctx context.Context, db *notices) {
  132. count := db.Count(ctx)
  133. assert.Equal(t, int64(0), count)
  134. err := db.Create(ctx, NoticeTypeRepository, "test")
  135. require.NoError(t, err)
  136. count = db.Count(ctx)
  137. assert.Equal(t, int64(1), count)
  138. }