carrier.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
// Package carrier provides a WebSocket proxy that carries (proxies) a connection
// from the local client to the edge. Think of it as a wrapper around any protocol,
// packaging that protocol up in a WebSocket connection to the edge.
  4. package carrier
  5. import (
  6. "crypto/tls"
  7. "io"
  8. "net"
  9. "net/http"
  10. "os"
  11. "strings"
  12. "github.com/cloudflare/cloudflared/h2mux"
  13. "github.com/cloudflare/cloudflared/token"
  14. "github.com/pkg/errors"
  15. "github.com/rs/zerolog"
  16. )
  17. const LogFieldOriginURL = "originURL"
// StartOptions holds the settings passed to the carrier entry points
// (StartForwarder, StartClient, Serve) and to BuildAccessRequest.
type StartOptions struct {
	// AppInfo is handed to token.FetchTokenWithRedirect by BuildAccessRequest
	// when it fetches the Access token.
	AppInfo *token.AppInfo

	// OriginURL is the URL that BuildAccessRequest builds its GET requests for.
	OriginURL string

	// Headers are extra request headers. BuildAccessRequest copies the first
	// value of each non-empty entry onto the request it returns.
	Headers http.Header

	// Host is not read in this file.
	// NOTE(review): presumably a Host override for the origin request — confirm
	// against the Connection implementations.
	Host string

	// TLSClientConfig is not read in this file.
	// NOTE(review): presumably the TLS config used when dialing the origin —
	// confirm against the Connection implementations.
	TLSClientConfig *tls.Config
}
  25. // Connection wraps up all the needed functions to forward over the tunnel
  26. type Connection interface {
  27. // ServeStream is used to forward data from the client to the edge
  28. ServeStream(*StartOptions, io.ReadWriter) error
  29. // StartServer is used to listen for incoming connections from the edge to the origin
  30. StartServer(net.Listener, string, <-chan struct{}) error
  31. }
  32. // StdinoutStream is empty struct for wrapping stdin/stdout
  33. // into a single ReadWriter
  34. type StdinoutStream struct {
  35. }
  36. // Read will read from Stdin
  37. func (c *StdinoutStream) Read(p []byte) (int, error) {
  38. return os.Stdin.Read(p)
  39. }
  40. // Write will write to Stdout
  41. func (c *StdinoutStream) Write(p []byte) (int, error) {
  42. return os.Stdout.Write(p)
  43. }
  44. // Helper to allow defering the response close with a check that the resp is not nil
  45. func closeRespBody(resp *http.Response) {
  46. if resp != nil {
  47. _ = resp.Body.Close()
  48. }
  49. }
  50. // StartForwarder will setup a listener on a specified address/port and then
  51. // forward connections to the origin by calling `Serve()`.
  52. func StartForwarder(conn Connection, address string, shutdownC <-chan struct{}, options *StartOptions) error {
  53. listener, err := net.Listen("tcp", address)
  54. if err != nil {
  55. return errors.Wrap(err, "failed to start forwarding server")
  56. }
  57. return Serve(conn, listener, shutdownC, options)
  58. }
  59. // StartClient will copy the data from stdin/stdout over a WebSocket connection
  60. // to the edge (originURL)
  61. func StartClient(conn Connection, stream io.ReadWriter, options *StartOptions) error {
  62. return conn.ServeStream(options, stream)
  63. }
  64. // Serve accepts incoming connections on the specified net.Listener.
  65. // Each connection is handled in a new goroutine: its data is copied over a
  66. // WebSocket connection to the edge (originURL).
  67. // `Serve` always closes `listener`.
  68. func Serve(remoteConn Connection, listener net.Listener, shutdownC <-chan struct{}, options *StartOptions) error {
  69. defer listener.Close()
  70. errChan := make(chan error)
  71. go func() {
  72. for {
  73. conn, err := listener.Accept()
  74. if err != nil {
  75. // don't block if parent goroutine quit early
  76. select {
  77. case errChan <- err:
  78. default:
  79. }
  80. return
  81. }
  82. go serveConnection(remoteConn, conn, options)
  83. }
  84. }()
  85. select {
  86. case <-shutdownC:
  87. return nil
  88. case err := <-errChan:
  89. return err
  90. }
  91. }
  92. // serveConnection handles connections for the Serve() call
  93. func serveConnection(remoteConn Connection, c net.Conn, options *StartOptions) {
  94. defer c.Close()
  95. _ = remoteConn.ServeStream(options, c)
  96. }
  97. // IsAccessResponse checks the http Response to see if the url location
  98. // contains the Access structure.
  99. func IsAccessResponse(resp *http.Response) bool {
  100. if resp == nil || resp.StatusCode != http.StatusFound {
  101. return false
  102. }
  103. location, err := resp.Location()
  104. if err != nil || location == nil {
  105. return false
  106. }
  107. if strings.HasPrefix(location.Path, token.AccessLoginWorkerPath) {
  108. return true
  109. }
  110. return false
  111. }
  112. // BuildAccessRequest builds an HTTP request with the Access token set
  113. func BuildAccessRequest(options *StartOptions, log *zerolog.Logger) (*http.Request, error) {
  114. req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil)
  115. if err != nil {
  116. return nil, err
  117. }
  118. token, err := token.FetchTokenWithRedirect(req.URL, options.AppInfo, log)
  119. if err != nil {
  120. return nil, err
  121. }
  122. // We need to create a new request as FetchToken will modify req (boo mutable)
  123. // as it has to follow redirect on the API and such, so here we init a new one
  124. originRequest, err := http.NewRequest(http.MethodGet, options.OriginURL, nil)
  125. if err != nil {
  126. return nil, err
  127. }
  128. originRequest.Header.Set(h2mux.CFAccessTokenHeader, token)
  129. for k, v := range options.Headers {
  130. if len(v) >= 1 {
  131. originRequest.Header.Set(k, v[0])
  132. }
  133. }
  134. return originRequest, nil
  135. }