websocket.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. package carrier
  2. import (
  3. "io"
  4. "net/http"
  5. "net/http/httputil"
  6. "net/url"
  7. "github.com/gorilla/websocket"
  8. "github.com/rs/zerolog"
  9. "github.com/cloudflare/cloudflared/stream"
  10. "github.com/cloudflare/cloudflared/token"
  11. cfwebsocket "github.com/cloudflare/cloudflared/websocket"
  12. )
  13. // Websocket is used to carry data via WS binary frames over the tunnel from client to the origin
  14. // This implements the functions for glider proxy (sock5) and the carrier interface
  15. type Websocket struct {
  16. log *zerolog.Logger
  17. isSocks bool
  18. }
  19. // NewWSConnection returns a new connection object
  20. func NewWSConnection(log *zerolog.Logger) Connection {
  21. return &Websocket{
  22. log: log,
  23. }
  24. }
  25. // ServeStream will create a Websocket client stream connection to the edge
  26. // it blocks and writes the raw data from conn over the tunnel
  27. func (ws *Websocket) ServeStream(options *StartOptions, conn io.ReadWriter) error {
  28. wsConn, err := createWebsocketStream(options, ws.log)
  29. if err != nil {
  30. ws.log.Err(err).Str(LogFieldOriginURL, options.OriginURL).Msg("failed to connect to origin")
  31. return err
  32. }
  33. defer wsConn.Close()
  34. stream.Pipe(wsConn, conn, ws.log)
  35. return nil
  36. }
  37. // createWebsocketStream will create a WebSocket connection to stream data over
  38. // It also handles redirects from Access and will present that flow if
  39. // the token is not present on the request
  40. func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebsocket.GorillaConn, error) {
  41. req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil)
  42. if err != nil {
  43. return nil, err
  44. }
  45. req.Header = options.Headers
  46. if options.Host != "" {
  47. req.Host = options.Host
  48. }
  49. dump, err := httputil.DumpRequest(req, false)
  50. if err != nil {
  51. return nil, err
  52. }
  53. log.Debug().Msgf("Websocket request: %s", string(dump))
  54. dialer := &websocket.Dialer{
  55. TLSClientConfig: options.TLSClientConfig,
  56. Proxy: http.ProxyFromEnvironment,
  57. }
  58. wsConn, resp, err := clientConnect(req, dialer)
  59. defer closeRespBody(resp)
  60. if err != nil && IsAccessResponse(resp) {
  61. // Only get Access app info if we know the origin is protected by Access
  62. originReq, err := http.NewRequest(http.MethodGet, options.OriginURL, nil)
  63. if err != nil {
  64. return nil, err
  65. }
  66. appInfo, err := token.GetAppInfo(originReq.URL)
  67. if err != nil {
  68. return nil, err
  69. }
  70. options.AppInfo = appInfo
  71. wsConn, err = createAccessAuthenticatedStream(options, log)
  72. if err != nil {
  73. return nil, err
  74. }
  75. } else if err != nil {
  76. return nil, err
  77. }
  78. return &cfwebsocket.GorillaConn{Conn: wsConn}, nil
  79. }
  80. var stripWebsocketHeaders = []string{
  81. "Upgrade",
  82. "Connection",
  83. "Sec-Websocket-Key",
  84. "Sec-Websocket-Version",
  85. "Sec-Websocket-Extensions",
  86. }
  87. // the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key,
  88. // Sec-WebSocket-Version and Sec-Websocket-Extensions headers.
  89. // https://github.com/gorilla/websocket/blob/master/client.go#L189-L194.
  90. func websocketHeaders(req *http.Request) http.Header {
  91. wsHeaders := make(http.Header)
  92. for key, val := range req.Header {
  93. wsHeaders[key] = val
  94. }
  95. // Assume the header keys are in canonical format.
  96. for _, header := range stripWebsocketHeaders {
  97. wsHeaders.Del(header)
  98. }
  99. wsHeaders.Set("Host", req.Host) // See TUN-1097
  100. return wsHeaders
  101. }
  102. // clientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
  103. // the connection. The response body may not contain the entire response and does
  104. // not need to be closed by the application.
  105. func clientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Conn, *http.Response, error) {
  106. req.URL.Scheme = changeRequestScheme(req.URL)
  107. wsHeaders := websocketHeaders(req)
  108. if dialler == nil {
  109. dialler = &websocket.Dialer{
  110. Proxy: http.ProxyFromEnvironment,
  111. }
  112. }
  113. conn, response, err := dialler.Dial(req.URL.String(), wsHeaders)
  114. if err != nil {
  115. return nil, response, err
  116. }
  117. return conn, response, nil
  118. }
  119. // changeRequestScheme is needed as the gorilla websocket library requires the ws scheme.
  120. // (even though it changes it back to http/https, but ¯\_(ツ)_/¯.)
  121. func changeRequestScheme(reqURL *url.URL) string {
  122. switch reqURL.Scheme {
  123. case "https":
  124. return "wss"
  125. case "http":
  126. return "ws"
  127. case "":
  128. return "ws"
  129. default:
  130. return reqURL.Scheme
  131. }
  132. }
  133. // createAccessAuthenticatedStream will try load a token from storage and make
  134. // a connection with the token set on the request. If it still get redirect,
  135. // this probably means the token in storage is invalid (expired/revoked). If that
  136. // happens it deletes the token and runs the connection again, so the user can
  137. // login again and generate a new one.
  138. func createAccessAuthenticatedStream(options *StartOptions, log *zerolog.Logger) (*websocket.Conn, error) {
  139. wsConn, resp, err := createAccessWebSocketStream(options, log)
  140. defer closeRespBody(resp)
  141. if err == nil {
  142. return wsConn, nil
  143. }
  144. if !IsAccessResponse(resp) {
  145. return nil, err
  146. }
  147. // Access Token is invalid for some reason. Go through regen flow
  148. if err := token.RemoveTokenIfExists(options.AppInfo); err != nil {
  149. return nil, err
  150. }
  151. wsConn, resp, err = createAccessWebSocketStream(options, log)
  152. defer closeRespBody(resp)
  153. if err != nil {
  154. return nil, err
  155. }
  156. return wsConn, nil
  157. }
  158. // createAccessWebSocketStream builds an Access request and makes a connection
  159. func createAccessWebSocketStream(options *StartOptions, log *zerolog.Logger) (*websocket.Conn, *http.Response, error) {
  160. req, err := BuildAccessRequest(options, log)
  161. if err != nil {
  162. return nil, nil, err
  163. }
  164. dump, err := httputil.DumpRequest(req, false)
  165. if err != nil {
  166. return nil, nil, err
  167. }
  168. log.Debug().Msgf("Access Websocket request: %s", string(dump))
  169. conn, resp, err := clientConnect(req, nil)
  170. if resp != nil {
  171. r, err := httputil.DumpResponse(resp, true)
  172. if r != nil {
  173. log.Debug().Msgf("Websocket response: %q", r)
  174. } else if err != nil {
  175. log.Debug().Msgf("Websocket response error: %v", err)
  176. }
  177. }
  178. return conn, resp, err
  179. }