sqlite3_load_extension.go 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. // Copyright (C) 2019 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
  2. //
  3. // Use of this source code is governed by an MIT-style
  4. // license that can be found in the LICENSE file.
  5. // +build !sqlite_omit_load_extension
  6. package sqlite3
  7. /*
  8. #ifndef USE_LIBSQLITE3
  9. #include "sqlite3-binding.h"
  10. #else
  11. #include <sqlite3.h>
  12. #endif
  13. #include <stdlib.h>
  14. */
  15. import "C"
  16. import (
  17. "errors"
  18. "unsafe"
  19. )
  20. func (c *SQLiteConn) loadExtensions(extensions []string) error {
  21. rv := C.sqlite3_enable_load_extension(c.db, 1)
  22. if rv != C.SQLITE_OK {
  23. return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
  24. }
  25. for _, extension := range extensions {
  26. if err := c.loadExtension(extension, nil); err != nil {
  27. C.sqlite3_enable_load_extension(c.db, 0)
  28. return err
  29. }
  30. }
  31. rv = C.sqlite3_enable_load_extension(c.db, 0)
  32. if rv != C.SQLITE_OK {
  33. return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
  34. }
  35. return nil
  36. }
  37. // LoadExtension load the sqlite3 extension.
  38. func (c *SQLiteConn) LoadExtension(lib string, entry string) error {
  39. rv := C.sqlite3_enable_load_extension(c.db, 1)
  40. if rv != C.SQLITE_OK {
  41. return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
  42. }
  43. if err := c.loadExtension(lib, &entry); err != nil {
  44. C.sqlite3_enable_load_extension(c.db, 0)
  45. return err
  46. }
  47. rv = C.sqlite3_enable_load_extension(c.db, 0)
  48. if rv != C.SQLITE_OK {
  49. return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
  50. }
  51. return nil
  52. }
  53. func (c *SQLiteConn) loadExtension(lib string, entry *string) error {
  54. clib := C.CString(lib)
  55. defer C.free(unsafe.Pointer(clib))
  56. var centry *C.char
  57. if entry != nil {
  58. centry = C.CString(*entry)
  59. defer C.free(unsafe.Pointer(centry))
  60. }
  61. var errMsg *C.char
  62. defer C.sqlite3_free(unsafe.Pointer(errMsg))
  63. rv := C.sqlite3_load_extension(c.db, clib, centry, &errMsg)
  64. if rv != C.SQLITE_OK {
  65. return errors.New(C.GoString(errMsg))
  66. }
  67. return nil
  68. }