sqlstruct.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. // Copyright 2012 Kamil Kisiel. All rights reserved.
  2. // Use of this source code is governed by the MIT
  3. // license which can be found in the LICENSE file.
  4. /*
  5. Package sqlstruct provides some convenience functions for using structs with
  6. the Go standard library's database/sql package.
  7. The package matches struct field names to SQL query column names. A field can
  8. also specify a matching column with "sql" tag, if it's different from field
  9. name. Unexported fields or fields marked with `sql:"-"` are ignored, just like
  10. with "encoding/json" package.
  11. For example:
  12. type T struct {
  13. F1 string
  14. F2 string `sql:"field2"`
  15. F3 string `sql:"-"`
  16. }
  17. rows, err := db.Query(fmt.Sprintf("SELECT %s FROM tablename", sqlstruct.Columns(T{})))
  18. ...
  19. for rows.Next() {
  20. var t T
  21. err = sqlstruct.Scan(&t, rows)
  22. ...
  23. }
  24. err = rows.Err() // get any errors encountered during iteration
  25. Aliased tables in a SQL statement may be scanned into a specific structure identified
  26. by the same alias, using the ColumnsAliased and ScanAliased functions:
  27. type User struct {
  28. Id int `sql:"id"`
  29. Username string `sql:"username"`
  30. Email string `sql:"address"`
  31. Name string `sql:"name"`
  32. HomeAddress *Address `sql:"-"`
  33. }
  34. type Address struct {
  35. Id int `sql:"id"`
  36. City string `sql:"city"`
  37. Street string `sql:"address"`
  38. }
  39. ...
  40. var user User
  41. var address Address
  42. sql := `
  43. SELECT %s, %s FROM users AS u
  44. INNER JOIN address AS a ON a.id = u.address_id
  45. WHERE u.username = ?
  46. `
  47. sql = fmt.Sprintf(sql, sqlstruct.ColumnsAliased(user, "u"), sqlstruct.ColumnsAliased(address, "a"))
  48. rows, err := db.Query(sql, "gedi")
  49. if err != nil {
  50. log.Fatal(err)
  51. }
  52. defer rows.Close()
  53. if rows.Next() {
  54. err = sqlstruct.ScanAliased(&user, rows, "u")
  55. if err != nil {
  56. log.Fatal(err)
  57. }
  58. err = sqlstruct.ScanAliased(&address, rows, "a")
  59. if err != nil {
  60. log.Fatal(err)
  61. }
  62. user.HomeAddress = &address
  63. }
  64. fmt.Printf("%+v", user)
  65. // output: "{Id:1 Username:gedi Email:gediminas.morkevicius@gmail.com Name:Gedas HomeAddress:0xc21001f570}"
  66. fmt.Printf("%+v", *user.HomeAddress)
  67. // output: "{Id:2 City:Vilnius Street:Plento 34}"
  68. */
  69. package sqlstruct
  70. import (
  71. "bytes"
  72. "database/sql"
  73. "fmt"
  74. "reflect"
  75. "sort"
  76. "strings"
  77. "sync"
  78. )
  79. // NameMapper is the function used to convert struct fields which do not have sql tags
  80. // into database column names.
  81. //
  82. // The default mapper converts field names to lower case. If instead you would prefer
  83. // field names converted to snake case, simply assign sqlstruct.ToSnakeCase to the variable:
  84. //
  85. // sqlstruct.NameMapper = sqlstruct.ToSnakeCase
  86. //
  87. // Alternatively for a custom mapping, any func(string) string can be used instead.
  88. var NameMapper func(string) string = strings.ToLower
  89. // A cache of fieldInfos to save reflecting every time. Inspried by encoding/xml
  90. var finfos map[reflect.Type]fieldInfo
  91. var finfoLock sync.RWMutex
  92. // TagName is the name of the tag to use on struct fields
  93. var TagName = "sql"
  94. // fieldInfo is a mapping of field tag values to their indices
  95. type fieldInfo map[string][]int
  96. func init() {
  97. finfos = make(map[reflect.Type]fieldInfo)
  98. }
  99. // Rows defines the interface of types that are scannable with the Scan function.
  100. // It is implemented by the sql.Rows type from the standard library
  101. type Rows interface {
  102. Scan(...interface{}) error
  103. Columns() ([]string, error)
  104. }
  105. // getFieldInfo creates a fieldInfo for the provided type. Fields that are not tagged
  106. // with the "sql" tag and unexported fields are not included.
  107. func getFieldInfo(typ reflect.Type) fieldInfo {
  108. finfoLock.RLock()
  109. finfo, ok := finfos[typ]
  110. finfoLock.RUnlock()
  111. if ok {
  112. return finfo
  113. }
  114. finfo = make(fieldInfo)
  115. n := typ.NumField()
  116. for i := 0; i < n; i++ {
  117. f := typ.Field(i)
  118. tag := f.Tag.Get(TagName)
  119. // Skip unexported fields or fields marked with "-"
  120. if f.PkgPath != "" || tag == "-" {
  121. continue
  122. }
  123. // Handle embedded structs
  124. if f.Anonymous && f.Type.Kind() == reflect.Struct {
  125. for k, v := range getFieldInfo(f.Type) {
  126. finfo[k] = append([]int{i}, v...)
  127. }
  128. continue
  129. }
  130. // Use field name for untagged fields
  131. if tag == "" {
  132. tag = f.Name
  133. }
  134. tag = NameMapper(tag)
  135. finfo[tag] = []int{i}
  136. }
  137. finfoLock.Lock()
  138. finfos[typ] = finfo
  139. finfoLock.Unlock()
  140. return finfo
  141. }
  142. // Scan scans the next row from rows in to a struct pointed to by dest. The struct type
  143. // should have exported fields tagged with the "sql" tag. Columns from row which are not
  144. // mapped to any struct fields are ignored. Struct fields which have no matching column
  145. // in the result set are left unchanged.
  146. func Scan(dest interface{}, rows Rows) error {
  147. return doScan(dest, rows, "")
  148. }
  149. // ScanAliased works like scan, except that it expects the results in the query to be
  150. // prefixed by the given alias.
  151. //
  152. // For example, if scanning to a field named "name" with an alias of "user" it will
  153. // expect to find the result in a column named "user_name".
  154. //
  155. // See ColumnAliased for a convenient way to generate these queries.
  156. func ScanAliased(dest interface{}, rows Rows, alias string) error {
  157. return doScan(dest, rows, alias)
  158. }
  159. // Columns returns a string containing a sorted, comma-separated list of column names as
  160. // defined by the type s. s must be a struct that has exported fields tagged with the "sql" tag.
  161. func Columns(s interface{}) string {
  162. return strings.Join(cols(s), ", ")
  163. }
  164. // ColumnsAliased works like Columns except it prefixes the resulting column name with the
  165. // given alias.
  166. //
  167. // For each field in the given struct it will generate a statement like:
  168. // alias.field AS alias_field
  169. //
  170. // It is intended to be used in conjunction with the ScanAliased function.
  171. func ColumnsAliased(s interface{}, alias string) string {
  172. names := cols(s)
  173. aliased := make([]string, 0, len(names))
  174. for _, n := range names {
  175. aliased = append(aliased, alias+"."+n+" AS "+alias+"_"+n)
  176. }
  177. return strings.Join(aliased, ", ")
  178. }
  179. func cols(s interface{}) []string {
  180. v := reflect.ValueOf(s)
  181. fields := getFieldInfo(v.Type())
  182. names := make([]string, 0, len(fields))
  183. for f := range fields {
  184. names = append(names, f)
  185. }
  186. sort.Strings(names)
  187. return names
  188. }
  189. func doScan(dest interface{}, rows Rows, alias string) error {
  190. destv := reflect.ValueOf(dest)
  191. typ := destv.Type()
  192. if typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Struct {
  193. panic(fmt.Errorf("dest must be pointer to struct; got %T", destv))
  194. }
  195. fieldInfo := getFieldInfo(typ.Elem())
  196. elem := destv.Elem()
  197. var values []interface{}
  198. cols, err := rows.Columns()
  199. if err != nil {
  200. return err
  201. }
  202. for _, name := range cols {
  203. if len(alias) > 0 {
  204. name = strings.Replace(name, alias+"_", "", 1)
  205. }
  206. idx, ok := fieldInfo[strings.ToLower(name)]
  207. var v interface{}
  208. if !ok {
  209. // There is no field mapped to this column so we discard it
  210. v = &sql.RawBytes{}
  211. } else {
  212. v = elem.FieldByIndex(idx).Addr().Interface()
  213. }
  214. values = append(values, v)
  215. }
  216. return rows.Scan(values...)
  217. }
  218. // ToSnakeCase converts a string to snake case, words separated with underscores.
  219. // It's intended to be used with NameMapper to map struct field names to snake case database fields.
  220. func ToSnakeCase(src string) string {
  221. thisUpper := false
  222. prevUpper := false
  223. buf := bytes.NewBufferString("")
  224. for i, v := range src {
  225. if v >= 'A' && v <= 'Z' {
  226. thisUpper = true
  227. } else {
  228. thisUpper = false
  229. }
  230. if i > 0 && thisUpper && !prevUpper {
  231. buf.WriteRune('_')
  232. }
  233. prevUpper = thisUpper
  234. buf.WriteRune(v)
  235. }
  236. return strings.ToLower(buf.String())
  237. }