base_client.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. package cfapi
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "net/url"
  9. "strings"
  10. "time"
  11. "github.com/pkg/errors"
  12. "github.com/rs/zerolog"
  13. "golang.org/x/net/http2"
  14. )
  15. const (
  16. defaultTimeout = 15 * time.Second
  17. jsonContentType = "application/json"
  18. )
  19. var (
  20. ErrUnauthorized = errors.New("unauthorized")
  21. ErrBadRequest = errors.New("incorrect request parameters")
  22. ErrNotFound = errors.New("not found")
  23. ErrAPINoSuccess = errors.New("API call failed")
  24. )
  25. type RESTClient struct {
  26. baseEndpoints *baseEndpoints
  27. authToken string
  28. userAgent string
  29. client http.Client
  30. log *zerolog.Logger
  31. }
  32. type baseEndpoints struct {
  33. accountLevel url.URL
  34. zoneLevel url.URL
  35. accountRoutes url.URL
  36. accountVnets url.URL
  37. }
  38. var _ Client = (*RESTClient)(nil)
  39. func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, log *zerolog.Logger) (*RESTClient, error) {
  40. if strings.HasSuffix(baseURL, "/") {
  41. baseURL = baseURL[:len(baseURL)-1]
  42. }
  43. accountLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/accounts/%s/cfd_tunnel", baseURL, accountTag))
  44. if err != nil {
  45. return nil, errors.Wrap(err, "failed to create account level endpoint")
  46. }
  47. accountRoutesEndpoint, err := url.Parse(fmt.Sprintf("%s/accounts/%s/teamnet/routes", baseURL, accountTag))
  48. if err != nil {
  49. return nil, errors.Wrap(err, "failed to create route account-level endpoint")
  50. }
  51. accountVnetsEndpoint, err := url.Parse(fmt.Sprintf("%s/accounts/%s/teamnet/virtual_networks", baseURL, accountTag))
  52. if err != nil {
  53. return nil, errors.Wrap(err, "failed to create virtual network account-level endpoint")
  54. }
  55. zoneLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneTag))
  56. if err != nil {
  57. return nil, errors.Wrap(err, "failed to create account level endpoint")
  58. }
  59. httpTransport := http.Transport{
  60. TLSHandshakeTimeout: defaultTimeout,
  61. ResponseHeaderTimeout: defaultTimeout,
  62. }
  63. http2.ConfigureTransport(&httpTransport)
  64. return &RESTClient{
  65. baseEndpoints: &baseEndpoints{
  66. accountLevel: *accountLevelEndpoint,
  67. zoneLevel: *zoneLevelEndpoint,
  68. accountRoutes: *accountRoutesEndpoint,
  69. accountVnets: *accountVnetsEndpoint,
  70. },
  71. authToken: authToken,
  72. userAgent: userAgent,
  73. client: http.Client{
  74. Transport: &httpTransport,
  75. Timeout: defaultTimeout,
  76. },
  77. log: log,
  78. }, nil
  79. }
  80. func (r *RESTClient) sendRequest(method string, url url.URL, body interface{}) (*http.Response, error) {
  81. var bodyReader io.Reader
  82. if body != nil {
  83. if bodyBytes, err := json.Marshal(body); err != nil {
  84. return nil, errors.Wrap(err, "failed to serialize json body")
  85. } else {
  86. bodyReader = bytes.NewBuffer(bodyBytes)
  87. }
  88. }
  89. req, err := http.NewRequest(method, url.String(), bodyReader)
  90. if err != nil {
  91. return nil, errors.Wrapf(err, "can't create %s request", method)
  92. }
  93. req.Header.Set("User-Agent", r.userAgent)
  94. if bodyReader != nil {
  95. req.Header.Set("Content-Type", jsonContentType)
  96. }
  97. req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.authToken))
  98. req.Header.Add("Accept", "application/json;version=1")
  99. return r.client.Do(req)
  100. }
  101. func parseResponseEnvelope(reader io.Reader) (*response, error) {
  102. // Schema for Tunnelstore responses in the v1 API.
  103. // Roughly, it's a wrapper around a particular result that adds failures/errors/etc
  104. var result response
  105. // First, parse the wrapper and check the API call succeeded
  106. if err := json.NewDecoder(reader).Decode(&result); err != nil {
  107. return nil, errors.Wrap(err, "failed to decode response")
  108. }
  109. if err := result.checkErrors(); err != nil {
  110. return nil, err
  111. }
  112. if !result.Success {
  113. return nil, ErrAPINoSuccess
  114. }
  115. return &result, nil
  116. }
  117. func parseResponse(reader io.Reader, data interface{}) error {
  118. result, err := parseResponseEnvelope(reader)
  119. if err != nil {
  120. return err
  121. }
  122. return parseResponseBody(result, data)
  123. }
  124. func parseResponseBody(result *response, data interface{}) error {
  125. // At this point we know the API call succeeded, so, parse out the inner
  126. // result into the datatype provided as a parameter.
  127. if err := json.Unmarshal(result.Result, &data); err != nil {
  128. return errors.Wrap(err, "the Cloudflare API response was an unexpected type")
  129. }
  130. return nil
  131. }
  132. func fetchExhaustively[T any](requestFn func(int) (*http.Response, error)) ([]*T, error) {
  133. page := 0
  134. var fullResponse []*T
  135. for {
  136. page += 1
  137. envelope, parsedBody, err := fetchPage[T](requestFn, page)
  138. if err != nil {
  139. return nil, errors.Wrap(err, fmt.Sprintf("Error Parsing page %d", page))
  140. }
  141. fullResponse = append(fullResponse, parsedBody...)
  142. if envelope.Pagination.Count < envelope.Pagination.PerPage || len(fullResponse) >= envelope.Pagination.TotalCount {
  143. break
  144. }
  145. }
  146. return fullResponse, nil
  147. }
  148. func fetchPage[T any](requestFn func(int) (*http.Response, error), page int) (*response, []*T, error) {
  149. pageResp, err := requestFn(page)
  150. if err != nil {
  151. return nil, nil, errors.Wrap(err, "REST request failed")
  152. }
  153. defer pageResp.Body.Close()
  154. if pageResp.StatusCode == http.StatusOK {
  155. envelope, err := parseResponseEnvelope(pageResp.Body)
  156. if err != nil {
  157. return nil, nil, err
  158. }
  159. var parsedRspBody []*T
  160. return envelope, parsedRspBody, parseResponseBody(envelope, &parsedRspBody)
  161. }
  162. return nil, nil, errors.New(fmt.Sprintf("Failed to fetch page. Server returned: %d", pageResp.StatusCode))
  163. }
  164. type response struct {
  165. Success bool `json:"success,omitempty"`
  166. Errors []apiErr `json:"errors,omitempty"`
  167. Messages []string `json:"messages,omitempty"`
  168. Result json.RawMessage `json:"result,omitempty"`
  169. Pagination Pagination `json:"result_info,omitempty"`
  170. }
  // Pagination mirrors the "result_info" object of a paged API response.
// fetchExhaustively uses Count, PerPage and TotalCount to decide when to stop.
type Pagination struct {
	Count      int `json:"count,omitempty"`       // items on this page
	Page       int `json:"page,omitempty"`        // page number
	PerPage    int `json:"per_page,omitempty"`    // requested page size
	TotalCount int `json:"total_count,omitempty"` // items across all pages
}
  177. func (r *response) checkErrors() error {
  178. if len(r.Errors) == 0 {
  179. return nil
  180. }
  181. if len(r.Errors) == 1 {
  182. return r.Errors[0]
  183. }
  184. var messages string
  185. for _, e := range r.Errors {
  186. messages += fmt.Sprintf("%s; ", e)
  187. }
  188. return fmt.Errorf("API errors: %s", messages)
  189. }
  // apiErr is one entry of the "errors" array in an API response envelope.
// It implements error so it can be returned directly.
type apiErr struct {
	Code    json.Number `json:"code,omitempty"`
	Message string      `json:"message,omitempty"`
}

// Error formats the entry as "code: <code>, reason: <message>".
func (e apiErr) Error() string {
	return "code: " + e.Code.String() + ", reason: " + e.Message
}
  197. func (r *RESTClient) statusCodeToError(op string, resp *http.Response) error {
  198. if resp.Header.Get("Content-Type") == "application/json" {
  199. var errorsResp response
  200. if json.NewDecoder(resp.Body).Decode(&errorsResp) == nil {
  201. if err := errorsResp.checkErrors(); err != nil {
  202. return errors.Errorf("Failed to %s: %s", op, err)
  203. }
  204. }
  205. }
  206. switch resp.StatusCode {
  207. case http.StatusOK:
  208. return nil
  209. case http.StatusBadRequest:
  210. return ErrBadRequest
  211. case http.StatusUnauthorized, http.StatusForbidden:
  212. return ErrUnauthorized
  213. case http.StatusNotFound:
  214. return ErrNotFound
  215. }
  216. return errors.Errorf("API call to %s failed with status %d: %s", op,
  217. resp.StatusCode, http.StatusText(resp.StatusCode))
  218. }