carrier.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  // Package carrier provides a WebSocket proxy that carries a connection from
  // the local client to the edge. Think of it as a wrapper around any protocol:
  // the protocol's traffic is packaged into a WebSocket connection to the edge.
  4. package carrier
  5. import (
  6. "io"
  7. "net"
  8. "net/http"
  9. "os"
  10. "strings"
  11. "github.com/cloudflare/cloudflared/cmd/cloudflared/token"
  12. "github.com/cloudflare/cloudflared/h2mux"
  13. "github.com/pkg/errors"
  14. "github.com/rs/zerolog"
  15. )
  // LogFieldOriginURL holds the field name "originURL". Judging by its name it
  // is meant as the structured-log key for an origin URL.
  // NOTE(review): its users are not visible in this file — confirm.
  const LogFieldOriginURL = "originURL"
  // StartOptions carries the settings used when opening a carried connection.
  type StartOptions struct {
  	// OriginURL is the URL of the origin reached through the edge.
  	// BuildAccessRequest passes it to http.NewRequest.
  	OriginURL string
  	// Headers are extra request headers. BuildAccessRequest copies only the
  	// first value of each key onto the request it builds.
  	Headers http.Header
  }
  21. // Connection wraps up all the needed functions to forward over the tunnel
  22. type Connection interface {
  23. // ServeStream is used to forward data from the client to the edge
  24. ServeStream(*StartOptions, io.ReadWriter) error
  25. // StartServer is used to listen for incoming connections from the edge to the origin
  26. StartServer(net.Listener, string, <-chan struct{}) error
  27. }
  // StdinoutStream joins the process's standard input and standard output into
  // a single reader/writer: reads come from os.Stdin and writes go to os.Stdout.
  // It holds no state, so the zero value is ready to use.
  type StdinoutStream struct{}

  // Read reads up to len(p) bytes from os.Stdin.
  func (*StdinoutStream) Read(p []byte) (int, error) {
  	n, err := os.Stdin.Read(p)
  	return n, err
  }

  // Write writes p to os.Stdout.
  func (*StdinoutStream) Write(p []byte) (int, error) {
  	n, err := os.Stdout.Write(p)
  	return n, err
  }
  // closeRespBody closes resp.Body and ignores any close error. It is a no-op
  // when resp is nil or resp.Body is nil, so it can be deferred right after an
  // HTTP call without first checking whether a response was returned.
  func closeRespBody(resp *http.Response) {
  	// A hand-built *http.Response may have a nil Body; calling Close on it
  	// would panic.
  	if resp == nil || resp.Body == nil {
  		return
  	}
  	// Best-effort cleanup: there is nothing useful to do with a close error.
  	_ = resp.Body.Close()
  }
  46. // StartForwarder will setup a listener on a specified address/port and then
  47. // forward connections to the origin by calling `Serve()`.
  48. func StartForwarder(conn Connection, address string, shutdownC <-chan struct{}, options *StartOptions) error {
  49. listener, err := net.Listen("tcp", address)
  50. if err != nil {
  51. return errors.Wrap(err, "failed to start forwarding server")
  52. }
  53. return Serve(conn, listener, shutdownC, options)
  54. }
  55. // StartClient will copy the data from stdin/stdout over a WebSocket connection
  56. // to the edge (originURL)
  57. func StartClient(conn Connection, stream io.ReadWriter, options *StartOptions) error {
  58. return conn.ServeStream(options, stream)
  59. }
  60. // Serve accepts incoming connections on the specified net.Listener.
  61. // Each connection is handled in a new goroutine: its data is copied over a
  62. // WebSocket connection to the edge (originURL).
  63. // `Serve` always closes `listener`.
  64. func Serve(remoteConn Connection, listener net.Listener, shutdownC <-chan struct{}, options *StartOptions) error {
  65. defer listener.Close()
  66. errChan := make(chan error)
  67. go func() {
  68. for {
  69. conn, err := listener.Accept()
  70. if err != nil {
  71. // don't block if parent goroutine quit early
  72. select {
  73. case errChan <- err:
  74. default:
  75. }
  76. return
  77. }
  78. go serveConnection(remoteConn, conn, options)
  79. }
  80. }()
  81. select {
  82. case <-shutdownC:
  83. return nil
  84. case err := <-errChan:
  85. return err
  86. }
  87. }
  88. // serveConnection handles connections for the Serve() call
  89. func serveConnection(remoteConn Connection, c net.Conn, options *StartOptions) {
  90. defer c.Close()
  91. _ = remoteConn.ServeStream(options, c)
  92. }
  // IsAccessResponse reports whether resp is a 302 Found redirect whose Location
  // path starts with "/cdn-cgi/access/login", which is the Access login path.
  // It returns false for a nil response, any other status code, or a missing
  // or unparsable Location header.
  func IsAccessResponse(resp *http.Response) bool {
  	if resp == nil {
  		return false
  	}
  	if resp.StatusCode != http.StatusFound {
  		return false
  	}
  	loc, err := resp.Location()
  	return err == nil && loc != nil && strings.HasPrefix(loc.Path, "/cdn-cgi/access/login")
  }
  108. // BuildAccessRequest builds an HTTP request with the Access token set
  109. func BuildAccessRequest(options *StartOptions, log *zerolog.Logger) (*http.Request, error) {
  110. req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil)
  111. if err != nil {
  112. return nil, err
  113. }
  114. token, err := token.FetchTokenWithRedirect(req.URL, log)
  115. if err != nil {
  116. return nil, err
  117. }
  118. // We need to create a new request as FetchToken will modify req (boo mutable)
  119. // as it has to follow redirect on the API and such, so here we init a new one
  120. originRequest, err := http.NewRequest(http.MethodGet, options.OriginURL, nil)
  121. if err != nil {
  122. return nil, err
  123. }
  124. originRequest.Header.Set(h2mux.CFAccessTokenHeader, token)
  125. for k, v := range options.Headers {
  126. if len(v) >= 1 {
  127. originRequest.Header.Set(k, v[0])
  128. }
  129. }
  130. return originRequest, nil
  131. }