codecs.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508
  1. package native
  2. import (
  3. "bytes"
  4. "github.com/ziutek/mymysql/mysql"
  5. "io"
  6. "time"
  7. )
  8. // Integers
  9. func DecodeU16(buf []byte) uint16 {
  10. return uint16(buf[1])<<8 | uint16(buf[0])
  11. }
  12. func (pr *pktReader) readU16() uint16 {
  13. buf := pr.buf[:2]
  14. pr.readFull(buf)
  15. return DecodeU16(buf)
  16. }
  17. func DecodeU24(buf []byte) uint32 {
  18. return (uint32(buf[2])<<8|uint32(buf[1]))<<8 | uint32(buf[0])
  19. }
  20. func (pr *pktReader) readU24() uint32 {
  21. buf := pr.buf[:3]
  22. pr.readFull(buf)
  23. return DecodeU24(buf)
  24. }
  25. func DecodeU32(buf []byte) uint32 {
  26. return ((uint32(buf[3])<<8|uint32(buf[2]))<<8|
  27. uint32(buf[1]))<<8 | uint32(buf[0])
  28. }
  29. func (pr *pktReader) readU32() uint32 {
  30. buf := pr.buf[:4]
  31. pr.readFull(buf)
  32. return DecodeU32(buf)
  33. }
  34. func DecodeU64(buf []byte) (rv uint64) {
  35. for ii, vv := range buf {
  36. rv |= uint64(vv) << uint(ii*8)
  37. }
  38. return
  39. }
  40. func (pr *pktReader) readU64() (rv uint64) {
  41. buf := pr.buf[:8]
  42. pr.readFull(buf)
  43. return DecodeU64(buf)
  44. }
  45. func EncodeU16(buf []byte, val uint16) {
  46. buf[0] = byte(val)
  47. buf[1] = byte(val >> 8)
  48. }
  49. func (pw *pktWriter) writeU16(val uint16) {
  50. buf := pw.buf[:2]
  51. EncodeU16(buf, val)
  52. pw.write(buf)
  53. }
  54. func EncodeU24(buf []byte, val uint32) {
  55. buf[0] = byte(val)
  56. buf[1] = byte(val >> 8)
  57. buf[2] = byte(val >> 16)
  58. }
  59. func (pw *pktWriter) writeU24(val uint32) {
  60. buf := pw.buf[:3]
  61. EncodeU24(buf, val)
  62. pw.write(buf)
  63. }
  64. func EncodeU32(buf []byte, val uint32) {
  65. buf[0] = byte(val)
  66. buf[1] = byte(val >> 8)
  67. buf[2] = byte(val >> 16)
  68. buf[3] = byte(val >> 24)
  69. }
  70. func (pw *pktWriter) writeU32(val uint32) {
  71. buf := pw.buf[:4]
  72. EncodeU32(buf, val)
  73. pw.write(buf)
  74. }
  75. func EncodeU64(buf []byte, val uint64) {
  76. buf[0] = byte(val)
  77. buf[1] = byte(val >> 8)
  78. buf[2] = byte(val >> 16)
  79. buf[3] = byte(val >> 24)
  80. buf[4] = byte(val >> 32)
  81. buf[5] = byte(val >> 40)
  82. buf[6] = byte(val >> 48)
  83. buf[7] = byte(val >> 56)
  84. }
  85. func (pw *pktWriter) writeU64(val uint64) {
  86. buf := pw.buf[:8]
  87. EncodeU64(buf, val)
  88. pw.write(buf)
  89. }
  90. // Variable length values
  91. func (pr *pktReader) readNullLCB() (lcb uint64, null bool) {
  92. bb := pr.readByte()
  93. switch bb {
  94. case 251:
  95. null = true
  96. case 252:
  97. lcb = uint64(pr.readU16())
  98. case 253:
  99. lcb = uint64(pr.readU24())
  100. case 254:
  101. lcb = pr.readU64()
  102. default:
  103. lcb = uint64(bb)
  104. }
  105. return
  106. }
  107. func (pr *pktReader) readLCB() uint64 {
  108. lcb, null := pr.readNullLCB()
  109. if null {
  110. panic(mysql.ErrUnexpNullLCB)
  111. }
  112. return lcb
  113. }
  114. func (pw *pktWriter) writeLCB(val uint64) {
  115. switch {
  116. case val <= 250:
  117. pw.writeByte(byte(val))
  118. case val <= 0xffff:
  119. pw.writeByte(252)
  120. pw.writeU16(uint16(val))
  121. case val <= 0xffffff:
  122. pw.writeByte(253)
  123. pw.writeU24(uint32(val))
  124. default:
  125. pw.writeByte(254)
  126. pw.writeU64(val)
  127. }
  128. }
  129. func lenLCB(val uint64) int {
  130. switch {
  131. case val <= 250:
  132. return 1
  133. case val <= 0xffff:
  134. return 3
  135. case val <= 0xffffff:
  136. return 4
  137. }
  138. return 9
  139. }
  140. func (pr *pktReader) readNullBin() (buf []byte, null bool) {
  141. var l uint64
  142. l, null = pr.readNullLCB()
  143. if null {
  144. return
  145. }
  146. buf = make([]byte, l)
  147. pr.readFull(buf)
  148. return
  149. }
  150. func (pr *pktReader) readBin() []byte {
  151. buf, null := pr.readNullBin()
  152. if null {
  153. panic(mysql.ErrUnexpNullLCS)
  154. }
  155. return buf
  156. }
  157. func (pr *pktReader) skipBin() {
  158. n, _ := pr.readNullLCB()
  159. pr.skipN(int(n))
  160. }
  161. func (pw *pktWriter) writeBin(buf []byte) {
  162. pw.writeLCB(uint64(len(buf)))
  163. pw.write(buf)
  164. }
  165. func lenBin(buf []byte) int {
  166. return lenLCB(uint64(len(buf))) + len(buf)
  167. }
  168. func lenStr(str string) int {
  169. return lenLCB(uint64(len(str))) + len(str)
  170. }
  171. func (pw *pktWriter) writeLC(v interface{}) {
  172. switch val := v.(type) {
  173. case []byte:
  174. pw.writeBin(val)
  175. case *[]byte:
  176. pw.writeBin(*val)
  177. case string:
  178. pw.writeBin([]byte(val))
  179. case *string:
  180. pw.writeBin([]byte(*val))
  181. default:
  182. panic("Unknown data type for write as length coded string")
  183. }
  184. }
  185. func lenLC(v interface{}) int {
  186. switch val := v.(type) {
  187. case []byte:
  188. return lenBin(val)
  189. case *[]byte:
  190. return lenBin(*val)
  191. case string:
  192. return lenStr(val)
  193. case *string:
  194. return lenStr(*val)
  195. }
  196. panic("Unknown data type for write as length coded string")
  197. }
  198. func (pr *pktReader) readNTB() (buf []byte) {
  199. for {
  200. ch := pr.readByte()
  201. if ch == 0 {
  202. break
  203. }
  204. buf = append(buf, ch)
  205. }
  206. return
  207. }
  208. func (pw *pktWriter) writeNTB(buf []byte) {
  209. pw.write(buf)
  210. pw.writeByte(0)
  211. }
  212. func (pw *pktWriter) writeNT(v interface{}) {
  213. switch val := v.(type) {
  214. case []byte:
  215. pw.writeNTB(val)
  216. case string:
  217. pw.writeNTB([]byte(val))
  218. default:
  219. panic("Unknown type for write as null terminated data")
  220. }
  221. }
  222. // Date and time
  223. func (pr *pktReader) readDuration() time.Duration {
  224. dlen := pr.readByte()
  225. switch dlen {
  226. case 251:
  227. // Null
  228. panic(mysql.ErrUnexpNullTime)
  229. case 0:
  230. // 00:00:00
  231. return 0
  232. case 5, 8, 12:
  233. // Properly time length
  234. default:
  235. panic(mysql.ErrWrongDateLen)
  236. }
  237. buf := pr.buf[:dlen]
  238. pr.readFull(buf)
  239. tt := int64(0)
  240. switch dlen {
  241. case 12:
  242. // Nanosecond part
  243. tt += int64(DecodeU32(buf[8:]))
  244. fallthrough
  245. case 8:
  246. // HH:MM:SS part
  247. tt += int64(int(buf[5])*3600+int(buf[6])*60+int(buf[7])) * 1e9
  248. fallthrough
  249. case 5:
  250. // Day part
  251. tt += int64(DecodeU32(buf[1:5])) * (24 * 3600 * 1e9)
  252. }
  253. if buf[0] != 0 {
  254. tt = -tt
  255. }
  256. return time.Duration(tt)
  257. }
  258. func EncodeDuration(buf []byte, d time.Duration) int {
  259. buf[0] = 0
  260. if d < 0 {
  261. buf[1] = 1
  262. d = -d
  263. }
  264. if ns := uint32(d % 1e9); ns != 0 {
  265. EncodeU32(buf[9:13], ns) // nanosecond
  266. buf[0] += 4
  267. }
  268. d /= 1e9
  269. if hms := int(d % (24 * 3600)); buf[0] != 0 || hms != 0 {
  270. buf[8] = byte(hms % 60) // second
  271. hms /= 60
  272. buf[7] = byte(hms % 60) // minute
  273. buf[6] = byte(hms / 60) // hour
  274. buf[0] += 3
  275. }
  276. if day := uint32(d / (24 * 3600)); buf[0] != 0 || day != 0 {
  277. EncodeU32(buf[2:6], day) // day
  278. buf[0] += 4
  279. }
  280. buf[0]++ // For sign byte
  281. return int(buf[0] + 1)
  282. }
  283. func (pw *pktWriter) writeDuration(d time.Duration) {
  284. buf := pw.buf[:13]
  285. n := EncodeDuration(buf, d)
  286. pw.write(buf[:n])
  287. }
  288. func lenDuration(d time.Duration) int {
  289. if d == 0 {
  290. return 2
  291. }
  292. if d%1e9 != 0 {
  293. return 13
  294. }
  295. d /= 1e9
  296. if d%(24*3600) != 0 {
  297. return 9
  298. }
  299. return 6
  300. }
  301. func (pr *pktReader) readTime() time.Time {
  302. dlen := pr.readByte()
  303. switch dlen {
  304. case 251:
  305. // Null
  306. panic(mysql.ErrUnexpNullDate)
  307. case 0:
  308. // return 0000-00-00 converted to time.Time zero
  309. return time.Time{}
  310. case 4, 7, 11:
  311. // Properly datetime length
  312. default:
  313. panic(mysql.ErrWrongDateLen)
  314. }
  315. buf := pr.buf[:dlen]
  316. pr.readFull(buf)
  317. var y, mon, d, h, m, s, u int
  318. switch dlen {
  319. case 11:
  320. // 2006-01-02 15:04:05.001004005
  321. u = int(DecodeU32(buf[7:]))
  322. fallthrough
  323. case 7:
  324. // 2006-01-02 15:04:05
  325. h = int(buf[4])
  326. m = int(buf[5])
  327. s = int(buf[6])
  328. fallthrough
  329. case 4:
  330. // 2006-01-02
  331. y = int(DecodeU16(buf[0:2]))
  332. mon = int(buf[2])
  333. d = int(buf[3])
  334. }
  335. n := u * int(time.Microsecond)
  336. return time.Date(y, time.Month(mon), d, h, m, s, n, time.Local)
  337. }
  338. func encodeNonzeroTime(buf []byte, y int16, mon, d, h, m, s byte, n uint32) int {
  339. buf[0] = 0
  340. switch {
  341. case n != 0:
  342. EncodeU32(buf[8:12], n)
  343. buf[0] += 4
  344. fallthrough
  345. case s != 0 || m != 0 || h != 0:
  346. buf[7] = s
  347. buf[6] = m
  348. buf[5] = h
  349. buf[0] += 3
  350. }
  351. buf[4] = d
  352. buf[3] = mon
  353. EncodeU16(buf[1:3], uint16(y))
  354. buf[0] += 4
  355. return int(buf[0] + 1)
  356. }
  357. func getTimeMicroseconds(t time.Time) int {
  358. return t.Nanosecond()/int(time.Microsecond)
  359. }
  360. func EncodeTime(buf []byte, t time.Time) int {
  361. if t.IsZero() {
  362. // MySQL zero
  363. buf[0] = 0
  364. return 1 // MySQL zero
  365. }
  366. y, mon, d := t.Date()
  367. h, m, s := t.Clock()
  368. u:= getTimeMicroseconds(t)
  369. return encodeNonzeroTime(
  370. buf,
  371. int16(y), byte(mon), byte(d),
  372. byte(h), byte(m), byte(s), uint32(u),
  373. )
  374. }
  375. func (pw *pktWriter) writeTime(t time.Time) {
  376. buf := pw.buf[:12]
  377. n := EncodeTime(buf, t)
  378. pw.write(buf[:n])
  379. }
  380. func lenTime(t time.Time) int {
  381. switch {
  382. case t.IsZero():
  383. return 1
  384. case getTimeMicroseconds(t) != 0:
  385. return 12
  386. case t.Second() != 0 || t.Minute() != 0 || t.Hour() != 0:
  387. return 8
  388. }
  389. return 5
  390. }
  391. func (pr *pktReader) readDate() mysql.Date {
  392. y, m, d := pr.readTime().Date()
  393. return mysql.Date{int16(y), byte(m), byte(d)}
  394. }
  395. func EncodeDate(buf []byte, d mysql.Date) int {
  396. if d.IsZero() {
  397. // MySQL zero
  398. buf[0] = 0
  399. return 1
  400. }
  401. return encodeNonzeroTime(buf, d.Year, d.Month, d.Day, 0, 0, 0, 0)
  402. }
  403. func (pw *pktWriter) writeDate(d mysql.Date) {
  404. buf := pw.buf[:5]
  405. n := EncodeDate(buf, d)
  406. pw.write(buf[:n])
  407. }
  408. func lenDate(d mysql.Date) int {
  409. if d.IsZero() {
  410. return 1
  411. }
  412. return 5
  413. }
  414. func escapeString(txt string) string {
  415. var (
  416. esc string
  417. buf bytes.Buffer
  418. )
  419. last := 0
  420. for ii, bb := range txt {
  421. switch bb {
  422. case 0:
  423. esc = `\0`
  424. case '\n':
  425. esc = `\n`
  426. case '\r':
  427. esc = `\r`
  428. case '\\':
  429. esc = `\\`
  430. case '\'':
  431. esc = `\'`
  432. case '"':
  433. esc = `\"`
  434. case '\032':
  435. esc = `\Z`
  436. default:
  437. continue
  438. }
  439. io.WriteString(&buf, txt[last:ii])
  440. io.WriteString(&buf, esc)
  441. last = ii + 1
  442. }
  443. io.WriteString(&buf, txt[last:])
  444. return buf.String()
  445. }
  446. func escapeQuotes(txt string) string {
  447. var buf bytes.Buffer
  448. last := 0
  449. for ii, bb := range txt {
  450. if bb == '\'' {
  451. io.WriteString(&buf, txt[last:ii])
  452. io.WriteString(&buf, `''`)
  453. last = ii + 1
  454. }
  455. }
  456. io.WriteString(&buf, txt[last:])
  457. return buf.String()
  458. }