bind.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. package sqlx
  2. import (
  3. "bytes"
  4. "database/sql/driver"
  5. "errors"
  6. "reflect"
  7. "strconv"
  8. "strings"
  9. "sync"
  10. "github.com/jmoiron/sqlx/reflectx"
  11. )
  12. // Bindvar types supported by Rebind, BindMap and BindStruct.
  13. const (
  14. UNKNOWN = iota
  15. QUESTION
  16. DOLLAR
  17. NAMED
  18. AT
  19. )
  20. var defaultBinds = map[int][]string{
  21. DOLLAR: []string{"postgres", "pgx", "pq-timeouts", "cloudsqlpostgres", "ql", "nrpostgres", "cockroach"},
  22. QUESTION: []string{"mysql", "sqlite3", "nrmysql", "nrsqlite3"},
  23. NAMED: []string{"oci8", "ora", "goracle", "godror"},
  24. AT: []string{"sqlserver"},
  25. }
  26. var binds sync.Map
  27. func init() {
  28. for bind, drivers := range defaultBinds {
  29. for _, driver := range drivers {
  30. BindDriver(driver, bind)
  31. }
  32. }
  33. }
  34. // BindType returns the bindtype for a given database given a drivername.
  35. func BindType(driverName string) int {
  36. itype, ok := binds.Load(driverName)
  37. if !ok {
  38. return UNKNOWN
  39. }
  40. return itype.(int)
  41. }
// BindDriver sets the BindType for driverName to bindType. The pair is stored
// in the package-level binds registry, replacing any value stored earlier for
// the same driverName, so subsequent BindType(driverName) calls return
// bindType. bindType is stored as given; it is not checked against the
// declared bindvar constants.
func BindDriver(driverName string, bindType int) {
	binds.Store(driverName, bindType)
}
  46. // FIXME: this should be able to be tolerant of escaped ?'s in queries without
  47. // losing much speed, and should be to avoid confusion.
  48. // Rebind a query from the default bindtype (QUESTION) to the target bindtype.
  49. func Rebind(bindType int, query string) string {
  50. switch bindType {
  51. case QUESTION, UNKNOWN:
  52. return query
  53. }
  54. // Add space enough for 10 params before we have to allocate
  55. rqb := make([]byte, 0, len(query)+10)
  56. var i, j int
  57. for i = strings.Index(query, "?"); i != -1; i = strings.Index(query, "?") {
  58. rqb = append(rqb, query[:i]...)
  59. switch bindType {
  60. case DOLLAR:
  61. rqb = append(rqb, '$')
  62. case NAMED:
  63. rqb = append(rqb, ':', 'a', 'r', 'g')
  64. case AT:
  65. rqb = append(rqb, '@', 'p')
  66. }
  67. j++
  68. rqb = strconv.AppendInt(rqb, int64(j), 10)
  69. query = query[i+1:]
  70. }
  71. return string(append(rqb, query...))
  72. }
  73. // Experimental implementation of Rebind which uses a bytes.Buffer. The code is
  74. // much simpler and should be more resistant to odd unicode, but it is twice as
  75. // slow. Kept here for benchmarking purposes and to possibly replace Rebind if
  76. // problems arise with its somewhat naive handling of unicode.
  77. func rebindBuff(bindType int, query string) string {
  78. if bindType != DOLLAR {
  79. return query
  80. }
  81. b := make([]byte, 0, len(query))
  82. rqb := bytes.NewBuffer(b)
  83. j := 1
  84. for _, r := range query {
  85. if r == '?' {
  86. rqb.WriteRune('$')
  87. rqb.WriteString(strconv.Itoa(j))
  88. j++
  89. } else {
  90. rqb.WriteRune(r)
  91. }
  92. }
  93. return rqb.String()
  94. }
  95. func asSliceForIn(i interface{}) (v reflect.Value, ok bool) {
  96. if i == nil {
  97. return reflect.Value{}, false
  98. }
  99. v = reflect.ValueOf(i)
  100. t := reflectx.Deref(v.Type())
  101. // Only expand slices
  102. if t.Kind() != reflect.Slice {
  103. return reflect.Value{}, false
  104. }
  105. // []byte is a driver.Value type so it should not be expanded
  106. if t == reflect.TypeOf([]byte{}) {
  107. return reflect.Value{}, false
  108. }
  109. return v, true
  110. }
  111. // In expands slice values in args, returning the modified query string
  112. // and a new arg list that can be executed by a database. The `query` should
  113. // use the `?` bindVar. The return value uses the `?` bindVar.
  114. func In(query string, args ...interface{}) (string, []interface{}, error) {
  115. // argMeta stores reflect.Value and length for slices and
  116. // the value itself for non-slice arguments
  117. type argMeta struct {
  118. v reflect.Value
  119. i interface{}
  120. length int
  121. }
  122. var flatArgsCount int
  123. var anySlices bool
  124. var stackMeta [32]argMeta
  125. var meta []argMeta
  126. if len(args) <= len(stackMeta) {
  127. meta = stackMeta[:len(args)]
  128. } else {
  129. meta = make([]argMeta, len(args))
  130. }
  131. for i, arg := range args {
  132. if a, ok := arg.(driver.Valuer); ok {
  133. var err error
  134. arg, err = a.Value()
  135. if err != nil {
  136. return "", nil, err
  137. }
  138. }
  139. if v, ok := asSliceForIn(arg); ok {
  140. meta[i].length = v.Len()
  141. meta[i].v = v
  142. anySlices = true
  143. flatArgsCount += meta[i].length
  144. if meta[i].length == 0 {
  145. return "", nil, errors.New("empty slice passed to 'in' query")
  146. }
  147. } else {
  148. meta[i].i = arg
  149. flatArgsCount++
  150. }
  151. }
  152. // don't do any parsing if there aren't any slices; note that this means
  153. // some errors that we might have caught below will not be returned.
  154. if !anySlices {
  155. return query, args, nil
  156. }
  157. newArgs := make([]interface{}, 0, flatArgsCount)
  158. var buf strings.Builder
  159. buf.Grow(len(query) + len(", ?")*flatArgsCount)
  160. var arg, offset int
  161. for i := strings.IndexByte(query[offset:], '?'); i != -1; i = strings.IndexByte(query[offset:], '?') {
  162. if arg >= len(meta) {
  163. // if an argument wasn't passed, lets return an error; this is
  164. // not actually how database/sql Exec/Query works, but since we are
  165. // creating an argument list programmatically, we want to be able
  166. // to catch these programmer errors earlier.
  167. return "", nil, errors.New("number of bindVars exceeds arguments")
  168. }
  169. argMeta := meta[arg]
  170. arg++
  171. // not a slice, continue.
  172. // our questionmark will either be written before the next expansion
  173. // of a slice or after the loop when writing the rest of the query
  174. if argMeta.length == 0 {
  175. offset = offset + i + 1
  176. newArgs = append(newArgs, argMeta.i)
  177. continue
  178. }
  179. // write everything up to and including our ? character
  180. buf.WriteString(query[:offset+i+1])
  181. for si := 1; si < argMeta.length; si++ {
  182. buf.WriteString(", ?")
  183. }
  184. newArgs = appendReflectSlice(newArgs, argMeta.v, argMeta.length)
  185. // slice the query and reset the offset. this avoids some bookkeeping for
  186. // the write after the loop
  187. query = query[offset+i+1:]
  188. offset = 0
  189. }
  190. buf.WriteString(query)
  191. if arg < len(meta) {
  192. return "", nil, errors.New("number of bindVars less than number arguments")
  193. }
  194. return buf.String(), newArgs, nil
  195. }
  196. func appendReflectSlice(args []interface{}, v reflect.Value, vlen int) []interface{} {
  197. switch val := v.Interface().(type) {
  198. case []interface{}:
  199. args = append(args, val...)
  200. case []int:
  201. for i := range val {
  202. args = append(args, val[i])
  203. }
  204. case []string:
  205. for i := range val {
  206. args = append(args, val[i])
  207. }
  208. default:
  209. for si := 0; si < vlen; si++ {
  210. args = append(args, v.Index(si).Interface())
  211. }
  212. }
  213. return args
  214. }