sql.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. package dbconnect
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "net/url"
  8. "reflect"
  9. "strings"
  10. "github.com/jmoiron/sqlx"
  11. "github.com/pkg/errors"
  12. "github.com/xo/dburl"
  13. // SQL drivers self-register with the database/sql package.
  14. // https://github.com/golang/go/wiki/SQLDrivers
  15. _ "github.com/denisenkom/go-mssqldb"
  16. _ "github.com/go-sql-driver/mysql"
  17. _ "github.com/mattn/go-sqlite3"
  18. "github.com/kshvakov/clickhouse"
  19. "github.com/lib/pq"
  20. )
  21. // SQLClient is a Client that talks to a SQL database.
  22. type SQLClient struct {
  23. Dialect string
  24. driver *sqlx.DB
  25. }
  26. // NewSQLClient creates a SQL client based on its URL scheme.
  27. func NewSQLClient(ctx context.Context, originURL *url.URL) (Client, error) {
  28. res, err := dburl.Parse(originURL.String())
  29. if err != nil {
  30. helpText := fmt.Sprintf("supported drivers: %+q, see documentation for more details: %s", sql.Drivers(), "https://godoc.org/github.com/xo/dburl")
  31. return nil, fmt.Errorf("could not parse sql database url '%s': %s\n%s", originURL, err.Error(), helpText)
  32. }
  33. // Establishes the driver, but does not test the connection.
  34. driver, err := sqlx.Open(res.Driver, res.DSN)
  35. if err != nil {
  36. return nil, fmt.Errorf("could not open sql driver %s: %s\n%s", res.Driver, err.Error(), res.DSN)
  37. }
  38. // Closes the driver, will occur when the context finishes.
  39. go func() {
  40. <-ctx.Done()
  41. driver.Close()
  42. }()
  43. return &SQLClient{driver.DriverName(), driver}, nil
  44. }
  45. // Ping verifies a connection to the database is still alive.
  46. func (client *SQLClient) Ping(ctx context.Context) error {
  47. return client.driver.PingContext(ctx)
  48. }
  49. // Submit queries or executes a command to the SQL database.
  50. func (client *SQLClient) Submit(ctx context.Context, cmd *Command) (interface{}, error) {
  51. txx, err := cmd.ValidateSQL(client.Dialect)
  52. if err != nil {
  53. return nil, err
  54. }
  55. ctx, cancel := context.WithTimeout(ctx, cmd.Timeout)
  56. defer cancel()
  57. var res interface{}
  58. // Get the next available sql.Conn and submit the Command.
  59. err = sqlConn(ctx, client.driver, txx, func(conn *sql.Conn) error {
  60. stmt := cmd.Statement
  61. args := cmd.Arguments.Positional
  62. if cmd.Mode == "query" {
  63. res, err = sqlQuery(ctx, conn, stmt, args)
  64. } else {
  65. res, err = sqlExec(ctx, conn, stmt, args)
  66. }
  67. return err
  68. })
  69. return res, err
  70. }
  71. // ValidateSQL extends the contract of Command for SQL dialects:
  72. // mode is conformed, arguments are []sql.NamedArg, and isolation is a sql.IsolationLevel.
  73. //
  74. // When the command should not be wrapped in a transaction, *sql.TxOptions and error will both be nil.
  75. func (cmd *Command) ValidateSQL(dialect string) (*sql.TxOptions, error) {
  76. err := cmd.Validate()
  77. if err != nil {
  78. return nil, err
  79. }
  80. mode, err := sqlMode(cmd.Mode)
  81. if err != nil {
  82. return nil, err
  83. }
  84. // Mutates Arguments to only use positional arguments with the type sql.NamedArg.
  85. // This is a required by the sql.Driver before submitting arguments.
  86. cmd.Arguments.sql(dialect)
  87. iso, err := sqlIsolation(cmd.Isolation)
  88. if err != nil {
  89. return nil, err
  90. }
  91. // When isolation is out-of-range, this is indicative that no
  92. // transaction should be executed and sql.TxOptions should be nil.
  93. if iso < sql.LevelDefault {
  94. return nil, nil
  95. }
  96. // In query mode, execute the transaction in read-only, unless it's Microsoft SQL
  97. // which does not support that type of transaction.
  98. readOnly := mode == "query" && dialect != "mssql"
  99. return &sql.TxOptions{Isolation: iso, ReadOnly: readOnly}, nil
  100. }
  101. // sqlConn gets the next available sql.Conn in the connection pool and runs a function to use it.
  102. //
  103. // If the transaction options are nil, run the useIt function outside a transaction.
  104. // This is potentially an unsafe operation if the command does not clean up its state.
  105. func sqlConn(ctx context.Context, driver *sqlx.DB, txx *sql.TxOptions, useIt func(*sql.Conn) error) error {
  106. conn, err := driver.Conn(ctx)
  107. if err != nil {
  108. return err
  109. }
  110. defer conn.Close()
  111. // If transaction options are specified, begin and defer a rollback to catch errors.
  112. var tx *sql.Tx
  113. if txx != nil {
  114. tx, err = conn.BeginTx(ctx, txx)
  115. if err != nil {
  116. return err
  117. }
  118. defer tx.Rollback()
  119. }
  120. err = useIt(conn)
  121. // Check if useIt was successful and a transaction exists before committing.
  122. if err == nil && tx != nil {
  123. err = tx.Commit()
  124. }
  125. return err
  126. }
  127. // sqlQuery queries rows on a sql.Conn and returns an array of result objects.
  128. func sqlQuery(ctx context.Context, conn *sql.Conn, stmt string, args []interface{}) ([]map[string]interface{}, error) {
  129. rows, err := conn.QueryContext(ctx, stmt, args...)
  130. if err == nil {
  131. return sqlRows(rows)
  132. }
  133. return nil, err
  134. }
  135. // sqlExec executes a command on a sql.Conn and returns the result of the operation.
  136. func sqlExec(ctx context.Context, conn *sql.Conn, stmt string, args []interface{}) (sqlResult, error) {
  137. exec, err := conn.ExecContext(ctx, stmt, args...)
  138. if err == nil {
  139. return sqlResultFrom(exec), nil
  140. }
  141. return sqlResult{}, err
  142. }
  143. // sql mutates Arguments to contain a positional []sql.NamedArg.
  144. //
  145. // The actual return type is []interface{} due to the native Golang
  146. // function signatures for sql.Exec and sql.Query being generic.
  147. func (args *Arguments) sql(dialect string) {
  148. result := args.Positional
  149. for i, val := range result {
  150. result[i] = sqlArg("", val, dialect)
  151. }
  152. for key, val := range args.Named {
  153. result = append(result, sqlArg(key, val, dialect))
  154. }
  155. args.Positional = result
  156. args.Named = map[string]interface{}{}
  157. }
  158. // sqlArg creates a sql.NamedArg from a key-value pair and an optional dialect.
  159. //
  160. // Certain dialects will need to wrap objects, such as arrays, to conform its driver requirements.
  161. func sqlArg(key, val interface{}, dialect string) sql.NamedArg {
  162. switch reflect.ValueOf(val).Kind() {
  163. // PostgreSQL and Clickhouse require arrays to be wrapped before
  164. // being inserted into the driver interface.
  165. case reflect.Slice, reflect.Array:
  166. switch dialect {
  167. case "postgres":
  168. val = pq.Array(val)
  169. case "clickhouse":
  170. val = clickhouse.Array(val)
  171. }
  172. }
  173. return sql.Named(fmt.Sprint(key), val)
  174. }
  // sqlIsolation resolves an isolation name to a sql.IsolationLevel.
  //
  // Special cases:
  //   - "" selects sql.LevelDefault.
  //   - "none" yields -1, which callers treat as "run without a transaction".
  //
  // Any other name is compared case-insensitively with IsolationLevel.String(),
  // and underscores may stand in for spaces (e.g. "read_committed"). Unknown names
  // return -1 and an error.
  func sqlIsolation(str string) (sql.IsolationLevel, error) {
  	switch str {
  	case "none":
  		return sql.IsolationLevel(-1), nil
  	case "":
  		return sql.LevelDefault, nil
  	}
  	wanted := strings.ReplaceAll(str, "_", " ")
  	for level := sql.LevelDefault; level <= sql.LevelLinearizable; level++ {
  		if strings.EqualFold(level.String(), wanted) {
  			return level, nil
  		}
  	}
  	return -1, fmt.Errorf("cannot provide an invalid sql isolation level: '%s'", str)
  }
  // sqlMode accepts the supported command modes, currently "query" (returns rows)
  // and "exec" (returns a result summary). The match is exact and case-sensitive;
  // anything else is rejected with an error.
  func sqlMode(str string) (string, error) {
  	if str != "query" && str != "exec" {
  		return "", fmt.Errorf("cannot provide invalid sql mode: '%s'", str)
  	}
  	return str, nil
  }
  198. // sqlRows scans through a SQL result set and returns an array of objects.
  199. func sqlRows(rows *sql.Rows) ([]map[string]interface{}, error) {
  200. columns, err := rows.Columns()
  201. if err != nil {
  202. return nil, errors.Wrap(err, "could not extract columns from result")
  203. }
  204. defer rows.Close()
  205. types, err := rows.ColumnTypes()
  206. if err != nil {
  207. // Some drivers do not support type extraction, so fail silently and continue.
  208. types = make([]*sql.ColumnType, len(columns))
  209. }
  210. values := make([]interface{}, len(columns))
  211. pointers := make([]interface{}, len(columns))
  212. var results []map[string]interface{}
  213. for rows.Next() {
  214. for i := range columns {
  215. pointers[i] = &values[i]
  216. }
  217. rows.Scan(pointers...)
  218. // Convert a row, an array of values, into an object where
  219. // each key is the name of its respective column.
  220. entry := make(map[string]interface{})
  221. for i, col := range columns {
  222. entry[col] = sqlValue(values[i], types[i])
  223. }
  224. results = append(results, entry)
  225. }
  226. return results, nil
  227. }
  // sqlValue makes a scanned column value "human-readable" for serialization.
  //
  // Only []byte values are transformed. If the bytes parse as JSON they are
  // returned decoded, so embedded JSON becomes a first-class object. Note that
  // text such as "123", "true" or "null" also decodes, to a float64, a bool or
  // nil respectively. Bytes that do not parse are returned as a string. Any other
  // value is returned unchanged. col is currently unused and may be nil.
  //
  // NOTE(review): json.Unmarshal decodes numbers as float64. Integer text beyond
  // 2^53 (presumably e.g. NUMERIC or BIGINT values delivered as []byte) therefore
  // loses precision. json.Decoder.UseNumber would keep the digits but changes the
  // Go type callers receive — confirm before changing.
  func sqlValue(val interface{}, col *sql.ColumnType) interface{} {
  	raw, isBytes := val.([]byte)
  	if !isBytes {
  		return val
  	}
  	var decoded interface{}
  	if err := json.Unmarshal(raw, &decoded); err != nil {
  		// STOR-604: investigate a way to coerce PostgreSQL arrays '{a, b, ...}' into JSON.
  		// Although easy with strings, it becomes more difficult with special types like INET[].
  		return string(raw)
  	}
  	return decoded
  }
  // sqlResult is a JSON-marshalable snapshot of a sql.Result. sqlResultFrom stores
  // -1 in any field the driver could not report. Field names are kept as-is for
  // compatibility; the JSON tags define the wire format.
  type sqlResult struct {
  	// LastInsertId is the ID generated by the statement, or -1 if unavailable.
  	LastInsertId int64 `json:"last_insert_id"`
  	// RowsAffected is the number of rows the statement affected, or -1 if unavailable.
  	RowsAffected int64 `json:"rows_affected"`
  }
  248. // sqlResultFrom converts sql.Result into a JSON-marshable sqlResult.
  249. func sqlResultFrom(res sql.Result) sqlResult {
  250. insertID, errID := res.LastInsertId()
  251. rowsAffected, errRows := res.RowsAffected()
  252. // If an error occurs when extracting the result, it is because the
  253. // driver does not support that specific field. Instead of passing this
  254. // to the user, omit the field in the response.
  255. if errID != nil {
  256. insertID = -1
  257. }
  258. if errRows != nil {
  259. rowsAffected = -1
  260. }
  261. return sqlResult{insertID, rowsAffected}
  262. }