download_file.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. // License: GPLv3 Copyright: 2022, Kovid Goyal, <kovid at kovidgoyal.net>
  2. package utils
  3. import (
  4. "bytes"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "os"
  9. "path/filepath"
  10. "strconv"
  11. )
  // Blank reference to fmt so the import always counts as used; presumably a
// guard against "imported and not used" errors while the file is edited.
var _ = fmt.Print

// ReportFunc is the progress-callback signature used by the download helpers.
// done is the cumulative number of bytes seen so far and total is the expected
// size. A non-nil return value makes the counting writer fail with that error.
type ReportFunc = func(done, total uint64) error

// write_counter is an io.Writer that keeps none of its input. It only tallies
// how many bytes have passed through it and forwards the running count, along
// with the expected total, to report after every write.
type write_counter struct {
	done   uint64
	total  uint64
	report ReportFunc
}

// Write adds len(p) to the running count and then invokes the report callback,
// if one is set. When the callback returns an error, Write reports 0 bytes
// written together with that error. The count has already been advanced by
// then.
func (wc *write_counter) Write(p []byte) (int, error) {
	size := len(p)
	wc.done += uint64(size)
	if wc.report == nil {
		return size, nil
	}
	if err := wc.report(wc.done, wc.total); err != nil {
		return 0, err
	}
	return size, nil
}
  29. func DownloadToWriter(url string, dest io.Writer, progress_callback ReportFunc) error {
  30. resp, err := http.Get(url)
  31. if err != nil {
  32. return err
  33. }
  34. defer resp.Body.Close()
  35. if resp.StatusCode != http.StatusOK {
  36. return fmt.Errorf("The server responded with the HTTP error: %s", resp.Status)
  37. }
  38. wc := write_counter{report: progress_callback}
  39. cl, err := strconv.Atoi(resp.Header.Get("Content-Length"))
  40. if err == nil {
  41. wc.total = uint64(cl)
  42. }
  43. _, err = io.Copy(dest, io.TeeReader(resp.Body, &wc))
  44. if err != nil {
  45. return err
  46. }
  47. return nil
  48. }
  49. func DownloadAsSlice(url string, progress_callback ReportFunc) (data []byte, err error) {
  50. b := bytes.Buffer{}
  51. b.Grow(4096)
  52. err = DownloadToWriter(url, &b, progress_callback)
  53. if err == nil {
  54. return b.Bytes(), nil
  55. }
  56. return nil, err
  57. }
  58. func DownloadToFile(destpath, url string, progress_callback ReportFunc, temp_file_path_callback func(string)) error {
  59. destpath, err := filepath.EvalSymlinks(destpath)
  60. if err != nil {
  61. return err
  62. }
  63. dest, err := os.CreateTemp(filepath.Dir(destpath), filepath.Base(destpath)+".partial-download.")
  64. if err != nil {
  65. return err
  66. }
  67. if temp_file_path_callback != nil {
  68. temp_file_path_callback(dest.Name())
  69. }
  70. dest_removed := false
  71. defer func() {
  72. dest.Close()
  73. if !dest_removed {
  74. os.Remove(dest.Name())
  75. }
  76. }()
  77. err = DownloadToWriter(url, dest, progress_callback)
  78. if err != nil {
  79. return err
  80. }
  81. dest.Close()
  82. fi, err := os.Stat(destpath)
  83. if err == nil {
  84. err = os.Chmod(dest.Name(), fi.Mode().Perm())
  85. if err != nil {
  86. return err
  87. }
  88. }
  89. if err != nil {
  90. return err
  91. }
  92. err = os.Rename(dest.Name(), destpath)
  93. if err != nil {
  94. return err
  95. }
  96. dest_removed = true
  97. return nil
  98. }