carrier.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
// Package carrier provides a WebSocket proxy that carries a connection from
// the local client to the edge. Think of it as a wrapper that packages any
// protocol inside a WebSocket connection to the edge.
  4. package carrier
  5. import (
  6. "io"
  7. "net"
  8. "net/http"
  9. "os"
  10. "strings"
  11. "github.com/cloudflare/cloudflared/cmd/cloudflared/token"
  12. "github.com/cloudflare/cloudflared/h2mux"
  13. "github.com/cloudflare/cloudflared/logger"
  14. "github.com/pkg/errors"
  15. )
  16. type StartOptions struct {
  17. OriginURL string
  18. Headers http.Header
  19. }
  20. // Connection wraps up all the needed functions to forward over the tunnel
  21. type Connection interface {
  22. // ServeStream is used to forward data from the client to the edge
  23. ServeStream(*StartOptions, io.ReadWriter) error
  24. // StartServer is used to listen for incoming connections from the edge to the origin
  25. StartServer(net.Listener, string, <-chan struct{}) error
  26. }
  27. // StdinoutStream is empty struct for wrapping stdin/stdout
  28. // into a single ReadWriter
  29. type StdinoutStream struct {
  30. }
  31. // Read will read from Stdin
  32. func (c *StdinoutStream) Read(p []byte) (int, error) {
  33. return os.Stdin.Read(p)
  34. }
  35. // Write will write to Stdout
  36. func (c *StdinoutStream) Write(p []byte) (int, error) {
  37. return os.Stdout.Write(p)
  38. }
  39. // Helper to allow defering the response close with a check that the resp is not nil
  40. func closeRespBody(resp *http.Response) {
  41. if resp != nil {
  42. resp.Body.Close()
  43. }
  44. }
  45. // StartForwarder will setup a listener on a specified address/port and then
  46. // forward connections to the origin by calling `Serve()`.
  47. func StartForwarder(conn Connection, address string, shutdownC <-chan struct{}, options *StartOptions) error {
  48. listener, err := net.Listen("tcp", address)
  49. if err != nil {
  50. return errors.Wrap(err, "failed to start forwarding server")
  51. }
  52. return Serve(conn, listener, shutdownC, options)
  53. }
  54. // StartClient will copy the data from stdin/stdout over a WebSocket connection
  55. // to the edge (originURL)
  56. func StartClient(conn Connection, stream io.ReadWriter, options *StartOptions) error {
  57. return conn.ServeStream(options, stream)
  58. }
  59. // Serve accepts incoming connections on the specified net.Listener.
  60. // Each connection is handled in a new goroutine: its data is copied over a
  61. // WebSocket connection to the edge (originURL).
  62. // `Serve` always closes `listener`.
  63. func Serve(remoteConn Connection, listener net.Listener, shutdownC <-chan struct{}, options *StartOptions) error {
  64. defer listener.Close()
  65. errChan := make(chan error)
  66. go func() {
  67. for {
  68. conn, err := listener.Accept()
  69. if err != nil {
  70. // don't block if parent goroutine quit early
  71. select {
  72. case errChan <- err:
  73. default:
  74. }
  75. return
  76. }
  77. go serveConnection(remoteConn, conn, options)
  78. }
  79. }()
  80. select {
  81. case <-shutdownC:
  82. return nil
  83. case err := <-errChan:
  84. return err
  85. }
  86. }
  87. // serveConnection handles connections for the Serve() call
  88. func serveConnection(remoteConn Connection, c net.Conn, options *StartOptions) {
  89. defer c.Close()
  90. remoteConn.ServeStream(options, c)
  91. }
  92. // IsAccessResponse checks the http Response to see if the url location
  93. // contains the Access structure.
  94. func IsAccessResponse(resp *http.Response) bool {
  95. if resp == nil || resp.StatusCode != http.StatusFound {
  96. return false
  97. }
  98. location, err := resp.Location()
  99. if err != nil || location == nil {
  100. return false
  101. }
  102. if strings.HasPrefix(location.Path, "/cdn-cgi/access/login") {
  103. return true
  104. }
  105. return false
  106. }
  107. // BuildAccessRequest builds an HTTP request with the Access token set
  108. func BuildAccessRequest(options *StartOptions, logger logger.Service) (*http.Request, error) {
  109. req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil)
  110. if err != nil {
  111. return nil, err
  112. }
  113. token, err := token.FetchTokenWithRedirect(req.URL, logger)
  114. if err != nil {
  115. return nil, err
  116. }
  117. // We need to create a new request as FetchToken will modify req (boo mutable)
  118. // as it has to follow redirect on the API and such, so here we init a new one
  119. originRequest, err := http.NewRequest(http.MethodGet, options.OriginURL, nil)
  120. if err != nil {
  121. return nil, err
  122. }
  123. originRequest.Header.Set(h2mux.CFAccessTokenHeader, token)
  124. for k, v := range options.Headers {
  125. if len(v) >= 1 {
  126. originRequest.Header.Set(k, v[0])
  127. }
  128. }
  129. return originRequest, nil
  130. }