client.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. package tunnelstore
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "net/url"
  9. "path"
  10. "strings"
  11. "time"
  12. "github.com/google/uuid"
  13. "github.com/pkg/errors"
  14. "github.com/cloudflare/cloudflared/logger"
  15. )
const (
	// defaultTimeout is used by NewRESTClient for the HTTP client's overall
	// request timeout as well as its TLS-handshake and response-header timeouts.
	defaultTimeout = 15 * time.Second
	// jsonContentType is the Content-Type header value sent with JSON request bodies.
	jsonContentType = "application/json"
)
// Sentinel errors returned by RESTClient. The status-code based ones are only
// produced by statusCodeToError when the response body does not itself list
// API errors.
var (
	// ErrTunnelNameConflict is returned by CreateTunnel on a 409 Conflict response.
	ErrTunnelNameConflict = errors.New("tunnel with name already exists")
	// ErrUnauthorized is returned for 401 Unauthorized and 403 Forbidden responses.
	ErrUnauthorized = errors.New("unauthorized")
	// ErrBadRequest is returned for 400 Bad Request responses.
	ErrBadRequest = errors.New("incorrect request parameters")
	// ErrNotFound is returned for 404 Not Found responses.
	ErrNotFound = errors.New("not found")
	// ErrAPINoSuccess is returned by parseResponse when the response envelope
	// has success=false but lists no errors.
	ErrAPINoSuccess = errors.New("API call failed")
)
// Tunnel is a named tunnel as reported by the tunnelstore API. It is decoded
// from the "result" field of API responses by unmarshalTunnel (single tunnel)
// and parseListTunnels (list of tunnels).
type Tunnel struct {
	ID        uuid.UUID `json:"id"`
	Name      string    `json:"name"`
	CreatedAt time.Time `json:"created_at"`
	// DeletedAt stays the zero time.Time unless the API supplies a value;
	// presumably it is only set for deleted tunnels — TODO confirm with API docs.
	DeletedAt time.Time `json:"deleted_at"`
	// Connections lists the tunnel's connections as reported by the API.
	Connections []Connection `json:"connections"`
}
// Connection is one entry of Tunnel.Connections as returned by the API.
type Connection struct {
	// ColoName presumably names the Cloudflare data center holding the
	// connection — TODO confirm.
	ColoName string    `json:"colo_name"`
	ID       uuid.UUID `json:"id"`
	// IsPendingReconnect is the API's "is_pending_reconnect" flag.
	IsPendingReconnect bool `json:"is_pending_reconnect"`
}
// Change reports how a routing request affected one resource (a DNS CNAME,
// a load balancer or a pool); see DNSRouteResult and LBRouteResult. It is an
// alias of string, so plain strings and the constants below are interchangeable.
type Change = string

// The Change values understood by the SuccessSummary implementations.
const (
	ChangeNew       = "new"
	ChangeUpdated   = "updated"
	ChangeUnchanged = "unchanged"
)
// Route represents a record type that can route to a tunnel.
// RouteTunnel sends the route's MarshalJSON output as the body of a PUT
// request and hands the response body to UnmarshalResult.
type Route interface {
	json.Marshaler
	// RecordType returns the "type" value sent to the API ("dns" or "lb" for
	// the implementations in this package).
	RecordType() string
	// UnmarshalResult decodes a successful route API response body.
	UnmarshalResult(body io.Reader) (RouteResult, error)
}

// RouteResult is the decoded outcome of a RouteTunnel call.
type RouteResult interface {
	// SuccessSummary explains what will route to this tunnel when it's provisioned successfully
	SuccessSummary() string
}
  55. type DNSRoute struct {
  56. userHostname string
  57. }
  58. type DNSRouteResult struct {
  59. route *DNSRoute
  60. CName Change `json:"cname"`
  61. }
  62. func NewDNSRoute(userHostname string) Route {
  63. return &DNSRoute{
  64. userHostname: userHostname,
  65. }
  66. }
  67. func (dr *DNSRoute) MarshalJSON() ([]byte, error) {
  68. s := struct {
  69. Type string `json:"type"`
  70. UserHostname string `json:"user_hostname"`
  71. }{
  72. Type: dr.RecordType(),
  73. UserHostname: dr.userHostname,
  74. }
  75. return json.Marshal(&s)
  76. }
  77. func (dr *DNSRoute) UnmarshalResult(body io.Reader) (RouteResult, error) {
  78. var result DNSRouteResult
  79. err := parseResponse(body, &result)
  80. result.route = dr
  81. return &result, err
  82. }
  83. func (dr *DNSRoute) RecordType() string {
  84. return "dns"
  85. }
  86. func (res *DNSRouteResult) SuccessSummary() string {
  87. var msgFmt string
  88. switch res.CName {
  89. case ChangeNew:
  90. msgFmt = "Added CNAME %s which will route to this tunnel"
  91. case ChangeUpdated: // this is not currently returned by tunnelsore
  92. msgFmt = "%s updated to route to your tunnel"
  93. case ChangeUnchanged:
  94. msgFmt = "%s is already configured to route to your tunnel"
  95. }
  96. return fmt.Sprintf(msgFmt, res.route.userHostname)
  97. }
  98. type LBRoute struct {
  99. lbName string
  100. lbPool string
  101. }
  102. type LBRouteResult struct {
  103. route *LBRoute
  104. LoadBalancer Change `json:"load_balancer"`
  105. Pool Change `json:"pool"`
  106. }
  107. func NewLBRoute(lbName, lbPool string) Route {
  108. return &LBRoute{
  109. lbName: lbName,
  110. lbPool: lbPool,
  111. }
  112. }
  113. func (lr *LBRoute) MarshalJSON() ([]byte, error) {
  114. s := struct {
  115. Type string `json:"type"`
  116. LBName string `json:"lb_name"`
  117. LBPool string `json:"lb_pool"`
  118. }{
  119. Type: lr.RecordType(),
  120. LBName: lr.lbName,
  121. LBPool: lr.lbPool,
  122. }
  123. return json.Marshal(&s)
  124. }
  125. func (lr *LBRoute) RecordType() string {
  126. return "lb"
  127. }
  128. func (lr *LBRoute) UnmarshalResult(body io.Reader) (RouteResult, error) {
  129. var result LBRouteResult
  130. err := parseResponse(body, &result)
  131. result.route = lr
  132. return &result, err
  133. }
  134. func (res *LBRouteResult) SuccessSummary() string {
  135. var msg string
  136. switch res.LoadBalancer + "," + res.Pool {
  137. case "new,new":
  138. msg = "Created load balancer %s and added a new pool %s with this tunnel as an origin"
  139. case "new,updated":
  140. msg = "Created load balancer %s with an existing pool %s which was updated to use this tunnel as an origin"
  141. case "new,unchanged":
  142. msg = "Created load balancer %s with an existing pool %s which already has this tunnel as an origin"
  143. case "updated,new":
  144. msg = "Added new pool %[2]s with this tunnel as an origin to load balancer %[1]s"
  145. case "updated,updated":
  146. msg = "Updated pool %[2]s to use this tunnel as an origin and added it to load balancer %[1]s"
  147. case "updated,unchanged":
  148. msg = "Added pool %[2]s, which already has this tunnel as an origin, to load balancer %[1]s"
  149. case "unchanged,updated":
  150. msg = "Added this tunnel as an origin in pool %[2]s which is already used by load balancer %[1]s"
  151. case "unchanged,unchanged":
  152. msg = "Load balancer %s already uses pool %s which has this tunnel as an origin"
  153. case "unchanged,new":
  154. // this state is not possible
  155. fallthrough
  156. default:
  157. msg = "Something went wrong: failed to modify load balancer %s with pool %s; please check traffic manager configuration in the dashboard"
  158. }
  159. return fmt.Sprintf(msg, res.route.lbName, res.route.lbPool)
  160. }
// Client is the set of tunnelstore operations. RESTClient implements it over HTTP.
type Client interface {
	// CreateTunnel creates a tunnel with the given name and secret.
	CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error)
	// GetTunnel fetches a single tunnel by ID.
	GetTunnel(tunnelID uuid.UUID) (*Tunnel, error)
	// DeleteTunnel deletes the tunnel with the given ID.
	DeleteTunnel(tunnelID uuid.UUID) error
	// ListTunnels lists the tunnels selected by filter.
	ListTunnels(filter *Filter) ([]*Tunnel, error)
	// CleanupConnections asks the API to delete the tunnel's connections.
	CleanupConnections(tunnelID uuid.UUID) error
	// RouteTunnel routes traffic described by route to the tunnel.
	RouteTunnel(tunnelID uuid.UUID, route Route) (RouteResult, error)
}
// RESTClient implements Client by calling the tunnelstore REST API.
// Construct it with NewRESTClient.
type RESTClient struct {
	// baseEndpoints holds the account- and zone-level base URLs.
	baseEndpoints *baseEndpoints
	// authToken is sent in the X-Auth-User-Service-Key header of every request.
	authToken string
	// userAgent is sent as the User-Agent header of every request.
	userAgent string
	client    http.Client
	// NOTE(review): logger is not used by any method visible in this file.
	logger logger.Service
}

// baseEndpoints holds the base URLs requests are built from. The request
// methods in this file copy the url.URL value before changing Path/RawQuery,
// so these fields are not mutated after construction.
type baseEndpoints struct {
	// accountLevel is <baseURL>/accounts/<accountTag>/tunnels.
	accountLevel url.URL
	// zoneLevel is <baseURL>/zones/<zoneTag>/tunnels.
	zoneLevel url.URL
}

// Compile-time check that *RESTClient satisfies Client.
var _ Client = (*RESTClient)(nil)
  181. func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, logger logger.Service) (*RESTClient, error) {
  182. if strings.HasSuffix(baseURL, "/") {
  183. baseURL = baseURL[:len(baseURL)-1]
  184. }
  185. accountLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/accounts/%s/tunnels", baseURL, accountTag))
  186. if err != nil {
  187. return nil, errors.Wrap(err, "failed to create account level endpoint")
  188. }
  189. zoneLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneTag))
  190. if err != nil {
  191. return nil, errors.Wrap(err, "failed to create account level endpoint")
  192. }
  193. return &RESTClient{
  194. baseEndpoints: &baseEndpoints{
  195. accountLevel: *accountLevelEndpoint,
  196. zoneLevel: *zoneLevelEndpoint,
  197. },
  198. authToken: authToken,
  199. userAgent: userAgent,
  200. client: http.Client{
  201. Transport: &http.Transport{
  202. TLSHandshakeTimeout: defaultTimeout,
  203. ResponseHeaderTimeout: defaultTimeout,
  204. },
  205. Timeout: defaultTimeout,
  206. },
  207. logger: logger,
  208. }, nil
  209. }
// newTunnel is the JSON body CreateTunnel POSTs to the account-level endpoint.
// encoding/json encodes the []byte TunnelSecret as a base64 string.
type newTunnel struct {
	Name         string `json:"name"`
	TunnelSecret []byte `json:"tunnel_secret"`
}
  214. func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) {
  215. if name == "" {
  216. return nil, errors.New("tunnel name required")
  217. }
  218. if _, err := uuid.Parse(name); err == nil {
  219. return nil, errors.New("you cannot use UUIDs as tunnel names")
  220. }
  221. body := &newTunnel{
  222. Name: name,
  223. TunnelSecret: tunnelSecret,
  224. }
  225. resp, err := r.sendRequest("POST", r.baseEndpoints.accountLevel, body)
  226. if err != nil {
  227. return nil, errors.Wrap(err, "REST request failed")
  228. }
  229. defer resp.Body.Close()
  230. switch resp.StatusCode {
  231. case http.StatusOK:
  232. return unmarshalTunnel(resp.Body)
  233. case http.StatusConflict:
  234. return nil, ErrTunnelNameConflict
  235. }
  236. return nil, r.statusCodeToError("create tunnel", resp)
  237. }
  238. func (r *RESTClient) GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) {
  239. endpoint := r.baseEndpoints.accountLevel
  240. endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v", tunnelID))
  241. resp, err := r.sendRequest("GET", endpoint, nil)
  242. if err != nil {
  243. return nil, errors.Wrap(err, "REST request failed")
  244. }
  245. defer resp.Body.Close()
  246. if resp.StatusCode == http.StatusOK {
  247. return unmarshalTunnel(resp.Body)
  248. }
  249. return nil, r.statusCodeToError("get tunnel", resp)
  250. }
  251. func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error {
  252. endpoint := r.baseEndpoints.accountLevel
  253. endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v", tunnelID))
  254. resp, err := r.sendRequest("DELETE", endpoint, nil)
  255. if err != nil {
  256. return errors.Wrap(err, "REST request failed")
  257. }
  258. defer resp.Body.Close()
  259. return r.statusCodeToError("delete tunnel", resp)
  260. }
  261. func (r *RESTClient) ListTunnels(filter *Filter) ([]*Tunnel, error) {
  262. endpoint := r.baseEndpoints.accountLevel
  263. endpoint.RawQuery = filter.encode()
  264. resp, err := r.sendRequest("GET", endpoint, nil)
  265. if err != nil {
  266. return nil, errors.Wrap(err, "REST request failed")
  267. }
  268. defer resp.Body.Close()
  269. if resp.StatusCode == http.StatusOK {
  270. return parseListTunnels(resp.Body)
  271. }
  272. return nil, r.statusCodeToError("list tunnels", resp)
  273. }
  274. func parseListTunnels(body io.ReadCloser) ([]*Tunnel, error) {
  275. var tunnels []*Tunnel
  276. err := parseResponse(body, &tunnels)
  277. return tunnels, err
  278. }
  279. func (r *RESTClient) CleanupConnections(tunnelID uuid.UUID) error {
  280. endpoint := r.baseEndpoints.accountLevel
  281. endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/connections", tunnelID))
  282. resp, err := r.sendRequest("DELETE", endpoint, nil)
  283. if err != nil {
  284. return errors.Wrap(err, "REST request failed")
  285. }
  286. defer resp.Body.Close()
  287. return r.statusCodeToError("cleanup connections", resp)
  288. }
  289. func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route Route) (RouteResult, error) {
  290. endpoint := r.baseEndpoints.zoneLevel
  291. endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/routes", tunnelID))
  292. resp, err := r.sendRequest("PUT", endpoint, route)
  293. if err != nil {
  294. return nil, errors.Wrap(err, "REST request failed")
  295. }
  296. defer resp.Body.Close()
  297. if resp.StatusCode == http.StatusOK {
  298. return route.UnmarshalResult(resp.Body)
  299. }
  300. return nil, r.statusCodeToError("add route", resp)
  301. }
  302. func (r *RESTClient) sendRequest(method string, url url.URL, body interface{}) (*http.Response, error) {
  303. var bodyReader io.Reader
  304. if body != nil {
  305. if bodyBytes, err := json.Marshal(body); err != nil {
  306. return nil, errors.Wrap(err, "failed to serialize json body")
  307. } else {
  308. bodyReader = bytes.NewBuffer(bodyBytes)
  309. }
  310. }
  311. req, err := http.NewRequest(method, url.String(), bodyReader)
  312. if err != nil {
  313. return nil, errors.Wrapf(err, "can't create %s request", method)
  314. }
  315. req.Header.Set("User-Agent", r.userAgent)
  316. if bodyReader != nil {
  317. req.Header.Set("Content-Type", jsonContentType)
  318. }
  319. req.Header.Add("X-Auth-User-Service-Key", r.authToken)
  320. req.Header.Add("Accept", "application/json;version=1")
  321. return r.client.Do(req)
  322. }
  323. func parseResponse(reader io.Reader, data interface{}) error {
  324. // Schema for Tunnelstore responses in the v1 API.
  325. // Roughly, it's a wrapper around a particular result that adds failures/errors/etc
  326. var result response
  327. // First, parse the wrapper and check the API call succeeded
  328. if err := json.NewDecoder(reader).Decode(&result); err != nil {
  329. return errors.Wrap(err, "failed to decode response")
  330. }
  331. if err := result.checkErrors(); err != nil {
  332. return err
  333. }
  334. if !result.Success {
  335. return ErrAPINoSuccess
  336. }
  337. // At this point we know the API call succeeded, so, parse out the inner
  338. // result into the datatype provided as a parameter.
  339. return json.Unmarshal(result.Result, &data)
  340. }
  341. func unmarshalTunnel(reader io.Reader) (*Tunnel, error) {
  342. var tunnel Tunnel
  343. err := parseResponse(reader, &tunnel)
  344. return &tunnel, err
  345. }
  346. type response struct {
  347. Success bool `json:"success,omitempty"`
  348. Errors []apiErr `json:"errors,omitempty"`
  349. Messages []string `json:"messages,omitempty"`
  350. Result json.RawMessage `json:"result,omitempty"`
  351. }
  352. func (r *response) checkErrors() error {
  353. if len(r.Errors) == 0 {
  354. return nil
  355. }
  356. if len(r.Errors) == 1 {
  357. return r.Errors[0]
  358. }
  359. var messages string
  360. for _, e := range r.Errors {
  361. messages += fmt.Sprintf("%s; ", e)
  362. }
  363. return fmt.Errorf("API errors: %s", messages)
  364. }
  365. type apiErr struct {
  366. Code json.Number `json:"code,omitempty"`
  367. Message string `json:"message,omitempty"`
  368. }
  369. func (e apiErr) Error() string {
  370. return fmt.Sprintf("code: %v, reason: %s", e.Code, e.Message)
  371. }
  372. func (r *RESTClient) statusCodeToError(op string, resp *http.Response) error {
  373. if resp.Header.Get("Content-Type") == "application/json" {
  374. var errorsResp response
  375. if json.NewDecoder(resp.Body).Decode(&errorsResp) == nil {
  376. if err := errorsResp.checkErrors(); err != nil {
  377. return errors.Errorf("Failed to %s: %s", op, err)
  378. }
  379. }
  380. }
  381. switch resp.StatusCode {
  382. case http.StatusOK:
  383. return nil
  384. case http.StatusBadRequest:
  385. return ErrBadRequest
  386. case http.StatusUnauthorized, http.StatusForbidden:
  387. return ErrUnauthorized
  388. case http.StatusNotFound:
  389. return ErrNotFound
  390. }
  391. return errors.Errorf("API call to %s failed with status %d: %s", op,
  392. resp.StatusCode, http.StatusText(resp.StatusCode))
  393. }