http2.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. package connection
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "math"
  7. "net"
  8. "net/http"
  9. "strings"
  10. "sync"
  11. "github.com/cloudflare/cloudflared/h2mux"
  12. tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
  13. "github.com/rs/zerolog"
  14. "golang.org/x/net/http2"
  15. )
  // internalUpgradeHeader is the request header whose value selects how
// ServeHTTP treats a stream (see isControlStreamUpgrade / isWebsocketUpgrade).
const internalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade"

// Recognised values of internalUpgradeHeader. Incoming values are lowercased
// before being compared against these.
const (
	websocketUpgrade     = "websocket"
	controlStreamUpgrade = "control-stream"
)

// errEdgeConnectionClosed is returned by Serve when the HTTP/2 connection
// ended without a graceful stop and without a control-stream error.
var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed")
  22. type http2Connection struct {
  23. conn net.Conn
  24. server *http2.Server
  25. config *Config
  26. namedTunnel *NamedTunnelConfig
  27. connOptions *tunnelpogs.ConnectionOptions
  28. observer *Observer
  29. connIndexStr string
  30. connIndex uint8
  31. // newRPCClientFunc allows us to mock RPCs during testing
  32. newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
  33. activeRequestsWG sync.WaitGroup
  34. connectedFuse ConnectedFuse
  35. gracefulShutdownC <-chan struct{}
  36. stoppedGracefully bool
  37. controlStreamErr error // result of running control stream handler
  38. }
  39. func NewHTTP2Connection(
  40. conn net.Conn,
  41. config *Config,
  42. namedTunnelConfig *NamedTunnelConfig,
  43. connOptions *tunnelpogs.ConnectionOptions,
  44. observer *Observer,
  45. connIndex uint8,
  46. connectedFuse ConnectedFuse,
  47. gracefulShutdownC <-chan struct{},
  48. ) *http2Connection {
  49. return &http2Connection{
  50. conn: conn,
  51. server: &http2.Server{
  52. MaxConcurrentStreams: math.MaxUint32,
  53. },
  54. config: config,
  55. namedTunnel: namedTunnelConfig,
  56. connOptions: connOptions,
  57. observer: observer,
  58. connIndexStr: uint8ToString(connIndex),
  59. connIndex: connIndex,
  60. newRPCClientFunc: newRegistrationRPCClient,
  61. connectedFuse: connectedFuse,
  62. gracefulShutdownC: gracefulShutdownC,
  63. }
  64. }
  65. func (c *http2Connection) Serve(ctx context.Context) error {
  66. go func() {
  67. <-ctx.Done()
  68. c.close()
  69. }()
  70. c.server.ServeConn(c.conn, &http2.ServeConnOpts{
  71. Context: ctx,
  72. Handler: c,
  73. })
  74. switch {
  75. case c.stoppedGracefully:
  76. return nil
  77. case c.controlStreamErr != nil:
  78. return c.controlStreamErr
  79. default:
  80. c.observer.log.Info().Uint8(LogFieldConnIndex, c.connIndex).Msg("Lost connection with the edge")
  81. return errEdgeConnectionClosed
  82. }
  83. }
  84. func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  85. c.activeRequestsWG.Add(1)
  86. defer c.activeRequestsWG.Done()
  87. respWriter := &http2RespWriter{
  88. r: r.Body,
  89. w: w,
  90. }
  91. flusher, isFlusher := w.(http.Flusher)
  92. if !isFlusher {
  93. c.observer.log.Error().Msgf("%T doesn't implement http.Flusher", w)
  94. respWriter.WriteErrorResponse()
  95. return
  96. }
  97. respWriter.flusher = flusher
  98. var err error
  99. if isControlStreamUpgrade(r) {
  100. respWriter.shouldFlush = true
  101. err = c.serveControlStream(r.Context(), respWriter)
  102. c.controlStreamErr = err
  103. } else if isWebsocketUpgrade(r) {
  104. respWriter.shouldFlush = true
  105. stripWebsocketUpgradeHeader(r)
  106. err = c.config.OriginClient.Proxy(respWriter, r, true)
  107. } else {
  108. err = c.config.OriginClient.Proxy(respWriter, r, false)
  109. }
  110. if err != nil {
  111. respWriter.WriteErrorResponse()
  112. }
  113. }
  114. func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
  115. rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer.log)
  116. defer rpcClient.Close()
  117. if err := rpcClient.RegisterConnection(ctx, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
  118. return err
  119. }
  120. c.connectedFuse.Connected()
  121. // wait for connection termination or start of graceful shutdown
  122. select {
  123. case <-ctx.Done():
  124. break
  125. case <-c.gracefulShutdownC:
  126. c.stoppedGracefully = true
  127. }
  128. c.observer.sendUnregisteringEvent(c.connIndex)
  129. rpcClient.GracefulShutdown(ctx, c.config.GracePeriod)
  130. c.observer.log.Info().Uint8(LogFieldConnIndex, c.connIndex).Msg("Unregistered tunnel connection")
  131. return nil
  132. }
  133. func (c *http2Connection) close() {
  134. // Wait for all serve HTTP handlers to return
  135. c.activeRequestsWG.Wait()
  136. c.conn.Close()
  137. }
  // http2RespWriter adapts a single HTTP/2 stream into a read/write/close
// object: reads come from the request body, writes go to the response.
type http2RespWriter struct {
	// r is the stream's request body.
	r io.Reader
	// w is the stream's response writer.
	w http.ResponseWriter
	// flusher is w viewed as an http.Flusher.
	flusher http.Flusher
	// shouldFlush makes every successful Write (and the header write) flush.
	shouldFlush bool
}
  144. func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
  145. dest := rp.w.Header()
  146. userHeaders := make(http.Header, len(resp.Header))
  147. for header, values := range resp.Header {
  148. // Since these are http2 headers, they're required to be lowercase
  149. h2name := strings.ToLower(header)
  150. for _, v := range values {
  151. if h2name == "content-length" {
  152. // This header has meaning in HTTP/2 and will be used by the edge,
  153. // so it should be sent as an HTTP/2 response header.
  154. dest.Add(h2name, v)
  155. // Since these are http2 headers, they're required to be lowercase
  156. } else if !h2mux.IsControlHeader(h2name) || h2mux.IsWebsocketClientHeader(h2name) {
  157. // User headers, on the other hand, must all be serialized so that
  158. // HTTP/2 header validation won't be applied to HTTP/1 header values
  159. userHeaders.Add(h2name, v)
  160. }
  161. }
  162. }
  163. // Perform user header serialization and set them in the single header
  164. dest.Set(canonicalResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders))
  165. rp.setResponseMetaHeader(responseMetaHeaderOrigin)
  166. status := resp.StatusCode
  167. // HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
  168. if status == http.StatusSwitchingProtocols {
  169. status = http.StatusOK
  170. }
  171. rp.w.WriteHeader(status)
  172. if IsServerSentEvent(resp.Header) {
  173. rp.shouldFlush = true
  174. }
  175. if rp.shouldFlush {
  176. rp.flusher.Flush()
  177. }
  178. return nil
  179. }
  180. func (rp *http2RespWriter) WriteErrorResponse() {
  181. rp.setResponseMetaHeader(responseMetaHeaderCfd)
  182. rp.w.WriteHeader(http.StatusBadGateway)
  183. }
  184. func (rp *http2RespWriter) setResponseMetaHeader(value string) {
  185. rp.w.Header().Set(canonicalResponseMetaHeaderField, value)
  186. }
  187. func (rp *http2RespWriter) Read(p []byte) (n int, err error) {
  188. return rp.r.Read(p)
  189. }
  190. func (rp *http2RespWriter) Write(p []byte) (n int, err error) {
  191. defer func() {
  192. // Implementer of OriginClient should make sure it doesn't write to the connection after Proxy returns
  193. // Register a recover routine just in case.
  194. if r := recover(); r != nil {
  195. println("Recover from http2 response writer panic, error", r)
  196. }
  197. }()
  198. n, err = rp.w.Write(p)
  199. if err == nil && rp.shouldFlush {
  200. rp.flusher.Flush()
  201. }
  202. return n, err
  203. }
  204. func (rp *http2RespWriter) Close() error {
  205. return nil
  206. }
  207. func isControlStreamUpgrade(r *http.Request) bool {
  208. return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == controlStreamUpgrade
  209. }
  210. func isWebsocketUpgrade(r *http.Request) bool {
  211. return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == websocketUpgrade
  212. }
  213. func stripWebsocketUpgradeHeader(r *http.Request) {
  214. r.Header.Del(internalUpgradeHeader)
  215. }