main.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. package main
  2. import (
  3. "fmt"
  4. "log"
  5. "os"
  6. "sort"
  7. "strings"
  8. "github.com/olekukonko/tablewriter"
  9. "github.com/pkg/errors"
  10. "gopkg.in/DATA-DOG/go-sqlmock.v2"
  11. "gorm.io/driver/mysql"
  12. "gorm.io/driver/postgres"
  13. "gorm.io/driver/sqlite"
  14. "gorm.io/gorm"
  15. "gorm.io/gorm/clause"
  16. "gorm.io/gorm/schema"
  17. "gogs.io/gogs/internal/db"
  18. )
  19. //go:generate go run main.go ../../../docs/dev/database_schema.md
  20. func main() {
  21. w, err := os.Create(os.Args[1])
  22. if err != nil {
  23. log.Fatalf("Failed to create file: %v", err)
  24. }
  25. defer func() { _ = w.Close() }()
  26. conn, _, err := sqlmock.New()
  27. if err != nil {
  28. log.Fatalf("Failed to get mock connection: %v", err)
  29. }
  30. defer func() { _ = conn.Close() }()
  31. dialectors := []gorm.Dialector{
  32. postgres.New(postgres.Config{
  33. Conn: conn,
  34. }),
  35. mysql.New(mysql.Config{
  36. Conn: conn,
  37. SkipInitializeWithVersion: true,
  38. }),
  39. sqlite.Open(""),
  40. }
  41. collected := make([][]*tableInfo, 0, len(dialectors))
  42. for i, dialector := range dialectors {
  43. tableInfos, err := generate(dialector)
  44. if err != nil {
  45. log.Fatalf("Failed to get table info of %d: %v", i, err)
  46. }
  47. collected = append(collected, tableInfos)
  48. }
  49. for i, ti := range collected[0] {
  50. _, _ = w.WriteString(`# Table "` + ti.Name + `"`)
  51. _, _ = w.WriteString("\n\n")
  52. _, _ = w.WriteString("```\n")
  53. table := tablewriter.NewWriter(w)
  54. table.SetHeader([]string{"Field", "Column", "PostgreSQL", "MySQL", "SQLite3"})
  55. table.SetBorder(false)
  56. for j, f := range ti.Fields {
  57. table.Append([]string{
  58. f.Name, f.Column,
  59. strings.ToUpper(f.Type), // PostgreSQL
  60. strings.ToUpper(collected[1][i].Fields[j].Type), // MySQL
  61. strings.ToUpper(collected[2][i].Fields[j].Type), // SQLite3
  62. })
  63. }
  64. table.Render()
  65. _, _ = w.WriteString("\n")
  66. _, _ = w.WriteString("Primary keys: ")
  67. _, _ = w.WriteString(strings.Join(ti.PrimaryKeys, ", "))
  68. _, _ = w.WriteString("\n")
  69. if len(ti.Indexes) > 0 {
  70. _, _ = w.WriteString("Indexes: \n")
  71. for _, index := range ti.Indexes {
  72. _, _ = w.WriteString(fmt.Sprintf("\t%q", index.Name))
  73. if index.Class != "" {
  74. _, _ = w.WriteString(fmt.Sprintf(" %s", index.Class))
  75. }
  76. if index.Type != "" {
  77. _, _ = w.WriteString(fmt.Sprintf(", %s", index.Type))
  78. }
  79. if len(index.Fields) > 0 {
  80. fields := make([]string, len(index.Fields))
  81. for i := range index.Fields {
  82. fields[i] = index.Fields[i].DBName
  83. }
  84. _, _ = w.WriteString(fmt.Sprintf(" (%s)", strings.Join(fields, ", ")))
  85. }
  86. _, _ = w.WriteString("\n")
  87. }
  88. }
  89. _, _ = w.WriteString("```\n\n")
  90. }
  91. }
  92. type tableField struct {
  93. Name string
  94. Column string
  95. Type string
  96. }
  97. type tableInfo struct {
  98. Name string
  99. Fields []*tableField
  100. PrimaryKeys []string
  101. Indexes []schema.Index
  102. }
  103. // This function is derived from gorm.io/gorm/migrator/migrator.go:Migrator.CreateTable.
  104. func generate(dialector gorm.Dialector) ([]*tableInfo, error) {
  105. conn, err := gorm.Open(dialector,
  106. &gorm.Config{
  107. SkipDefaultTransaction: true,
  108. NamingStrategy: schema.NamingStrategy{
  109. SingularTable: true,
  110. },
  111. DryRun: true,
  112. DisableAutomaticPing: true,
  113. },
  114. )
  115. if err != nil {
  116. return nil, errors.Wrap(err, "open database")
  117. }
  118. m := conn.Migrator().(interface {
  119. RunWithValue(value any, fc func(*gorm.Statement) error) error
  120. FullDataTypeOf(*schema.Field) clause.Expr
  121. })
  122. tableInfos := make([]*tableInfo, 0, len(db.Tables))
  123. for _, table := range db.Tables {
  124. err = m.RunWithValue(table, func(stmt *gorm.Statement) error {
  125. fields := make([]*tableField, 0, len(stmt.Schema.DBNames))
  126. for _, field := range stmt.Schema.Fields {
  127. if field.DBName == "" {
  128. continue
  129. }
  130. fields = append(fields, &tableField{
  131. Name: field.Name,
  132. Column: field.DBName,
  133. Type: m.FullDataTypeOf(field).SQL,
  134. })
  135. }
  136. primaryKeys := make([]string, 0, len(stmt.Schema.PrimaryFields))
  137. if len(stmt.Schema.PrimaryFields) > 0 {
  138. for _, field := range stmt.Schema.PrimaryFields {
  139. primaryKeys = append(primaryKeys, field.DBName)
  140. }
  141. }
  142. var indexes []schema.Index
  143. for _, index := range stmt.Schema.ParseIndexes() {
  144. indexes = append(indexes, index)
  145. }
  146. sort.Slice(indexes, func(i, j int) bool {
  147. return indexes[i].Name < indexes[j].Name
  148. })
  149. tableInfos = append(tableInfos, &tableInfo{
  150. Name: stmt.Table,
  151. Fields: fields,
  152. PrimaryKeys: primaryKeys,
  153. Indexes: indexes,
  154. })
  155. return nil
  156. })
  157. if err != nil {
  158. return nil, errors.Wrap(err, "gather table information")
  159. }
  160. }
  161. return tableInfos, nil
  162. }