packet.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. package native
  2. import (
  3. "bufio"
  4. "github.com/ziutek/mymysql/mysql"
  5. "io"
  6. "io/ioutil"
  7. )
  // pktReader reads one logical protocol message that may be split across
  // several physical packets. Every packet starts with a 3-byte payload
  // length and a 1-byte sequence number (consumed by readHeader).
  type pktReader struct {
  	rd     *bufio.Reader // buffered reader taken from the Conn
  	seq    *byte         // points at the Conn's expected sequence number
  	remain int           // payload bytes of the current packet not yet consumed
  	last   bool          // the current packet is the final one of the message
  	buf    [8]byte       // scratch space; not used by the methods in this chunk — presumably used by decoders elsewhere
  	ibuf   [3]byte       // raw 3-byte length field of the most recent header
  }
  16. func (my *Conn) newPktReader() *pktReader {
  17. return &pktReader{rd: my.rd, seq: &my.seq}
  18. }
  19. func (pr *pktReader) readHeader() {
  20. // Read next packet header
  21. buf := pr.ibuf[:]
  22. for {
  23. n, err := pr.rd.Read(buf)
  24. if err != nil {
  25. panic(err)
  26. }
  27. buf = buf[n:]
  28. if len(buf) == 0 {
  29. break
  30. }
  31. }
  32. pr.remain = int(DecodeU24(pr.ibuf[:]))
  33. seq, err := pr.rd.ReadByte()
  34. if err != nil {
  35. panic(err)
  36. }
  37. // Chceck sequence number
  38. if *pr.seq != seq {
  39. panic(mysql.ErrSeq)
  40. }
  41. *pr.seq++
  42. // Last packet?
  43. pr.last = (pr.remain != 0xffffff)
  44. }
  45. func (pr *pktReader) readFull(buf []byte) {
  46. for len(buf) > 0 {
  47. if pr.remain == 0 {
  48. if pr.last {
  49. // No more packets
  50. panic(io.EOF)
  51. }
  52. pr.readHeader()
  53. }
  54. n := len(buf)
  55. if n > pr.remain {
  56. n = pr.remain
  57. }
  58. n, err := pr.rd.Read(buf[:n])
  59. pr.remain -= n
  60. if err != nil {
  61. panic(err)
  62. }
  63. buf = buf[n:]
  64. }
  65. return
  66. }
  67. func (pr *pktReader) readByte() byte {
  68. if pr.remain == 0 {
  69. if pr.last {
  70. // No more packets
  71. panic(io.EOF)
  72. }
  73. pr.readHeader()
  74. }
  75. b, err := pr.rd.ReadByte()
  76. if err != nil {
  77. panic(err)
  78. }
  79. pr.remain--
  80. return b
  81. }
  82. func (pr *pktReader) readAll() (buf []byte) {
  83. m := 0
  84. for {
  85. if pr.remain == 0 {
  86. if pr.last {
  87. break
  88. }
  89. pr.readHeader()
  90. }
  91. new_buf := make([]byte, m+pr.remain)
  92. copy(new_buf, buf)
  93. buf = new_buf
  94. n, err := pr.rd.Read(buf[m:])
  95. pr.remain -= n
  96. m += n
  97. if err != nil {
  98. panic(err)
  99. }
  100. }
  101. return
  102. }
  103. func (pr *pktReader) skipAll() {
  104. for {
  105. if pr.remain == 0 {
  106. if pr.last {
  107. break
  108. }
  109. pr.readHeader()
  110. }
  111. n, err := io.CopyN(ioutil.Discard, pr.rd, int64(pr.remain))
  112. pr.remain -= int(n)
  113. if err != nil {
  114. panic(err)
  115. }
  116. }
  117. return
  118. }
  119. func (pr *pktReader) skipN(n int) {
  120. for n > 0 {
  121. if pr.remain == 0 {
  122. if pr.last {
  123. panic(io.EOF)
  124. }
  125. pr.readHeader()
  126. }
  127. m := int64(n)
  128. if n > pr.remain {
  129. m = int64(pr.remain)
  130. }
  131. m, err := io.CopyN(ioutil.Discard, pr.rd, m)
  132. pr.remain -= int(m)
  133. n -= int(m)
  134. if err != nil {
  135. panic(err)
  136. }
  137. }
  138. return
  139. }
  140. func (pr *pktReader) unreadByte() {
  141. if err := pr.rd.UnreadByte(); err != nil {
  142. panic(err)
  143. }
  144. pr.remain++
  145. }
  146. func (pr *pktReader) eof() bool {
  147. return pr.remain == 0 && pr.last
  148. }
  149. func (pr *pktReader) checkEof() {
  150. if !pr.eof() {
  151. panic(mysql.ErrPktLong)
  152. }
  153. }
  // pktWriter writes one message of a known total size as one or more
  // packets, each holding at most 0xffffff payload bytes (see write).
  type pktWriter struct {
  	wr       *bufio.Writer // buffered writer taken from the Conn
  	seq      *byte         // points at the Conn's next sequence number
  	remain   int           // bytes still to write into the current packet
  	to_write int           // message bytes not yet assigned to any packet
  	last     bool          // the current packet is the final one
  	buf      [23]byte      // scratch space for small writes (writeByte, writeZeros)
  	ibuf     [3]byte       // encoded 3-byte length used by writeHeader
  }
  163. func (my *Conn) newPktWriter(to_write int) *pktWriter {
  164. return &pktWriter{wr: my.wr, seq: &my.seq, to_write: to_write}
  165. }
  166. func (pw *pktWriter) writeHeader(l int) {
  167. buf := pw.ibuf[:]
  168. EncodeU24(buf, uint32(l))
  169. if _, err := pw.wr.Write(buf); err != nil {
  170. panic(err)
  171. }
  172. if err := pw.wr.WriteByte(*pw.seq); err != nil {
  173. panic(err)
  174. }
  175. // Update sequence number
  176. *pw.seq++
  177. }
  178. func (pw *pktWriter) write(buf []byte) {
  179. if len(buf) == 0 {
  180. return
  181. }
  182. var nn int
  183. for len(buf) != 0 {
  184. if pw.remain == 0 {
  185. if pw.to_write == 0 {
  186. panic("too many data for write as packet")
  187. }
  188. if pw.to_write >= 0xffffff {
  189. pw.remain = 0xffffff
  190. } else {
  191. pw.remain = pw.to_write
  192. pw.last = true
  193. }
  194. pw.to_write -= pw.remain
  195. pw.writeHeader(pw.remain)
  196. }
  197. nn = len(buf)
  198. if nn > pw.remain {
  199. nn = pw.remain
  200. }
  201. var err error
  202. nn, err = pw.wr.Write(buf[0:nn])
  203. pw.remain -= nn
  204. if err != nil {
  205. panic(err)
  206. }
  207. buf = buf[nn:]
  208. }
  209. if pw.remain+pw.to_write == 0 {
  210. if !pw.last {
  211. // Write header for empty packet
  212. pw.writeHeader(0)
  213. }
  214. // Flush bufio buffers
  215. if err := pw.wr.Flush(); err != nil {
  216. panic(err)
  217. }
  218. }
  219. return
  220. }
  221. func (pw *pktWriter) writeByte(b byte) {
  222. pw.buf[0] = b
  223. pw.write(pw.buf[:1])
  224. }
  225. // n should be <= 23
  226. func (pw *pktWriter) writeZeros(n int) {
  227. buf := pw.buf[:n]
  228. for i := range buf {
  229. buf[i] = 0
  230. }
  231. pw.write(buf)
  232. }