copy.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. package pq
  2. import (
  3. "bytes"
  4. "context"
  5. "database/sql/driver"
  6. "encoding/binary"
  7. "errors"
  8. "fmt"
  9. "sync"
  10. )
  // Sentinel errors reported by the COPY FROM support in this file.
  var (
  	// errCopyInClosed is returned by Exec or CopyData once the statement is closed.
  	errCopyInClosed = errors.New("pq: copyin statement has already been closed")
  	// errBinaryCopyNotSupported is returned when the server's CopyInResponse
  	// announces a non-text (binary) overall format.
  	errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY")
  	// errCopyToNotSupported is returned when the server answers with a
  	// CopyOutResponse, i.e. the query was a COPY ... TO.
  	errCopyToNotSupported = errors.New("pq: COPY TO is not supported")
  	// errCopyNotSupportedOutsideTxn is returned when COPY is prepared while the
  	// connection is not inside a transaction.
  	errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction")
  	// errCopyInProgress: NOTE(review) not referenced in the code visible here;
  	// presumably returned when the connection is used while a COPY is active.
  	errCopyInProgress = errors.New("pq: COPY in progress")
  )
  18. // CopyIn creates a COPY FROM statement which can be prepared with
  19. // Tx.Prepare(). The target table should be visible in search_path.
  20. func CopyIn(table string, columns ...string) string {
  21. buffer := bytes.NewBufferString("COPY ")
  22. BufferQuoteIdentifier(table, buffer)
  23. buffer.WriteString(" (")
  24. makeStmt(buffer, columns...)
  25. return buffer.String()
  26. }
  27. // MakeStmt makes the stmt string for CopyIn and CopyInSchema.
  28. func makeStmt(buffer *bytes.Buffer, columns ...string) {
  29. //s := bytes.NewBufferString()
  30. for i, col := range columns {
  31. if i != 0 {
  32. buffer.WriteString(", ")
  33. }
  34. BufferQuoteIdentifier(col, buffer)
  35. }
  36. buffer.WriteString(") FROM STDIN")
  37. }
  38. // CopyInSchema creates a COPY FROM statement which can be prepared with
  39. // Tx.Prepare().
  40. func CopyInSchema(schema, table string, columns ...string) string {
  41. buffer := bytes.NewBufferString("COPY ")
  42. BufferQuoteIdentifier(schema, buffer)
  43. buffer.WriteRune('.')
  44. BufferQuoteIdentifier(table, buffer)
  45. buffer.WriteString(" (")
  46. makeStmt(buffer, columns...)
  47. return buffer.String()
  48. }
  49. type copyin struct {
  50. cn *conn
  51. buffer []byte
  52. rowData chan []byte
  53. done chan bool
  54. closed bool
  55. mu struct {
  56. sync.Mutex
  57. err error
  58. driver.Result
  59. }
  60. }
  const (
  	// ciBufferSize is the initial capacity of a copyin's CopyData buffer.
  	ciBufferSize = 64 * 1024
  	// ciBufferFlushSize is the fill level past which Exec/CopyData write the
  	// buffer out. It leaves 1 KiB of headroom so that the buffer is normally
  	// flushed before it fills up and needs reallocation.
  	ciBufferFlushSize = ciBufferSize - 1024
  )
  64. func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) {
  65. if !cn.isInTransaction() {
  66. return nil, errCopyNotSupportedOutsideTxn
  67. }
  68. ci := &copyin{
  69. cn: cn,
  70. buffer: make([]byte, 0, ciBufferSize),
  71. rowData: make(chan []byte),
  72. done: make(chan bool, 1),
  73. }
  74. // add CopyData identifier + 4 bytes for message length
  75. ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0)
  76. b := cn.writeBuf('Q')
  77. b.string(q)
  78. cn.send(b)
  79. awaitCopyInResponse:
  80. for {
  81. t, r := cn.recv1()
  82. switch t {
  83. case 'G':
  84. if r.byte() != 0 {
  85. err = errBinaryCopyNotSupported
  86. break awaitCopyInResponse
  87. }
  88. go ci.resploop()
  89. return ci, nil
  90. case 'H':
  91. err = errCopyToNotSupported
  92. break awaitCopyInResponse
  93. case 'E':
  94. err = parseError(r)
  95. case 'Z':
  96. if err == nil {
  97. ci.setBad(driver.ErrBadConn)
  98. errorf("unexpected ReadyForQuery in response to COPY")
  99. }
  100. cn.processReadyForQuery(r)
  101. return nil, err
  102. default:
  103. ci.setBad(driver.ErrBadConn)
  104. errorf("unknown response for copy query: %q", t)
  105. }
  106. }
  107. // something went wrong, abort COPY before we return
  108. b = cn.writeBuf('f')
  109. b.string(err.Error())
  110. cn.send(b)
  111. for {
  112. t, r := cn.recv1()
  113. switch t {
  114. case 'c', 'C', 'E':
  115. case 'Z':
  116. // correctly aborted, we're done
  117. cn.processReadyForQuery(r)
  118. return nil, err
  119. default:
  120. ci.setBad(driver.ErrBadConn)
  121. errorf("unknown response for CopyFail: %q", t)
  122. }
  123. }
  124. }
  125. func (ci *copyin) flush(buf []byte) {
  126. // set message length (without message identifier)
  127. binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1))
  128. _, err := ci.cn.c.Write(buf)
  129. if err != nil {
  130. panic(err)
  131. }
  132. }
  133. func (ci *copyin) resploop() {
  134. for {
  135. var r readBuf
  136. t, err := ci.cn.recvMessage(&r)
  137. if err != nil {
  138. ci.setBad(driver.ErrBadConn)
  139. ci.setError(err)
  140. ci.done <- true
  141. return
  142. }
  143. switch t {
  144. case 'C':
  145. // complete
  146. res, _ := ci.cn.parseComplete(r.string())
  147. ci.setResult(res)
  148. case 'N':
  149. if n := ci.cn.noticeHandler; n != nil {
  150. n(parseError(&r))
  151. }
  152. case 'Z':
  153. ci.cn.processReadyForQuery(&r)
  154. ci.done <- true
  155. return
  156. case 'E':
  157. err := parseError(&r)
  158. ci.setError(err)
  159. default:
  160. ci.setBad(driver.ErrBadConn)
  161. ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t))
  162. ci.done <- true
  163. return
  164. }
  165. }
  166. }
  167. func (ci *copyin) setBad(err error) {
  168. ci.cn.err.set(err)
  169. }
  170. func (ci *copyin) getBad() error {
  171. return ci.cn.err.get()
  172. }
  173. func (ci *copyin) err() error {
  174. ci.mu.Lock()
  175. err := ci.mu.err
  176. ci.mu.Unlock()
  177. return err
  178. }
  179. // setError() sets ci.err if one has not been set already. Caller must not be
  180. // holding ci.Mutex.
  181. func (ci *copyin) setError(err error) {
  182. ci.mu.Lock()
  183. if ci.mu.err == nil {
  184. ci.mu.err = err
  185. }
  186. ci.mu.Unlock()
  187. }
  188. func (ci *copyin) setResult(result driver.Result) {
  189. ci.mu.Lock()
  190. ci.mu.Result = result
  191. ci.mu.Unlock()
  192. }
  193. func (ci *copyin) getResult() driver.Result {
  194. ci.mu.Lock()
  195. result := ci.mu.Result
  196. ci.mu.Unlock()
  197. if result == nil {
  198. return driver.RowsAffected(0)
  199. }
  200. return result
  201. }
  202. func (ci *copyin) NumInput() int {
  203. return -1
  204. }
  205. func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
  206. return nil, ErrNotSupported
  207. }
  208. // Exec inserts values into the COPY stream. The insert is asynchronous
  209. // and Exec can return errors from previous Exec calls to the same
  210. // COPY stmt.
  211. //
  212. // You need to call Exec(nil) to sync the COPY stream and to get any
  213. // errors from pending data, since Stmt.Close() doesn't return errors
  214. // to the user.
  215. func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
  216. if ci.closed {
  217. return nil, errCopyInClosed
  218. }
  219. if err := ci.getBad(); err != nil {
  220. return nil, err
  221. }
  222. defer ci.cn.errRecover(&err)
  223. if err := ci.err(); err != nil {
  224. return nil, err
  225. }
  226. if len(v) == 0 {
  227. if err := ci.Close(); err != nil {
  228. return driver.RowsAffected(0), err
  229. }
  230. return ci.getResult(), nil
  231. }
  232. numValues := len(v)
  233. for i, value := range v {
  234. ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value)
  235. if i < numValues-1 {
  236. ci.buffer = append(ci.buffer, '\t')
  237. }
  238. }
  239. ci.buffer = append(ci.buffer, '\n')
  240. if len(ci.buffer) > ciBufferFlushSize {
  241. ci.flush(ci.buffer)
  242. // reset buffer, keep bytes for message identifier and length
  243. ci.buffer = ci.buffer[:5]
  244. }
  245. return driver.RowsAffected(0), nil
  246. }
  247. // CopyData inserts a raw string into the COPY stream. The insert is
  248. // asynchronous and CopyData can return errors from previous CopyData calls to
  249. // the same COPY stmt.
  250. //
  251. // You need to call Exec(nil) to sync the COPY stream and to get any
  252. // errors from pending data, since Stmt.Close() doesn't return errors
  253. // to the user.
  254. func (ci *copyin) CopyData(ctx context.Context, line string) (r driver.Result, err error) {
  255. if ci.closed {
  256. return nil, errCopyInClosed
  257. }
  258. if finish := ci.cn.watchCancel(ctx); finish != nil {
  259. defer finish()
  260. }
  261. if err := ci.getBad(); err != nil {
  262. return nil, err
  263. }
  264. defer ci.cn.errRecover(&err)
  265. if err := ci.err(); err != nil {
  266. return nil, err
  267. }
  268. ci.buffer = append(ci.buffer, []byte(line)...)
  269. ci.buffer = append(ci.buffer, '\n')
  270. if len(ci.buffer) > ciBufferFlushSize {
  271. ci.flush(ci.buffer)
  272. // reset buffer, keep bytes for message identifier and length
  273. ci.buffer = ci.buffer[:5]
  274. }
  275. return driver.RowsAffected(0), nil
  276. }
  277. func (ci *copyin) Close() (err error) {
  278. if ci.closed { // Don't do anything, we're already closed
  279. return nil
  280. }
  281. ci.closed = true
  282. if err := ci.getBad(); err != nil {
  283. return err
  284. }
  285. defer ci.cn.errRecover(&err)
  286. if len(ci.buffer) > 0 {
  287. ci.flush(ci.buffer)
  288. }
  289. // Avoid touching the scratch buffer as resploop could be using it.
  290. err = ci.cn.sendSimpleMessage('c')
  291. if err != nil {
  292. return err
  293. }
  294. <-ci.done
  295. ci.cn.inCopy = false
  296. if err := ci.err(); err != nil {
  297. return err
  298. }
  299. return nil
  300. }