quic_connection.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. package connection
  2. import (
  3. "bufio"
  4. "context"
  5. "fmt"
  6. "io"
  7. "net"
  8. "net/http"
  9. "strconv"
  10. "strings"
  11. "sync/atomic"
  12. "time"
  13. "github.com/pkg/errors"
  14. "github.com/quic-go/quic-go"
  15. "github.com/rs/zerolog"
  16. "golang.org/x/sync/errgroup"
  17. cfdquic "github.com/cloudflare/cloudflared/quic"
  18. "github.com/cloudflare/cloudflared/tracing"
  19. "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
  20. tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
  21. rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic"
  22. )
const (
	// HTTPHeaderKey is used to get or set http headers in QUIC ALPN if the underlying proxy connection type is HTTP.
	// Individual headers travel as metadata entries whose key has the form "HttpHeader:<header name>".
	HTTPHeaderKey = "HttpHeader"
	// HTTPMethodKey is used to get or set http method in QUIC ALPN if the underlying proxy connection type is HTTP.
	HTTPMethodKey = "HttpMethod"
	// HTTPHostKey is used to get or set http host in QUIC ALPN if the underlying proxy connection type is HTTP.
	HTTPHostKey = "HttpHost"

	// QUICMetadataFlowID is the connect-request metadata key whose value is forwarded as the FlowID of a
	// proxied TCP request.
	QUICMetadataFlowID = "FlowID"
)
  32. // quicConnection represents the type that facilitates Proxying via QUIC streams.
  33. type quicConnection struct {
  34. conn quic.Connection
  35. logger *zerolog.Logger
  36. orchestrator Orchestrator
  37. datagramHandler DatagramSessionHandler
  38. controlStreamHandler ControlStreamHandler
  39. connOptions *tunnelpogs.ConnectionOptions
  40. connIndex uint8
  41. rpcTimeout time.Duration
  42. streamWriteTimeout time.Duration
  43. gracePeriod time.Duration
  44. }
  45. // NewTunnelConnection takes a [quic.Connection] to wrap it for use with cloudflared application logic.
  46. func NewTunnelConnection(
  47. ctx context.Context,
  48. conn quic.Connection,
  49. connIndex uint8,
  50. orchestrator Orchestrator,
  51. datagramSessionHandler DatagramSessionHandler,
  52. controlStreamHandler ControlStreamHandler,
  53. connOptions *pogs.ConnectionOptions,
  54. rpcTimeout time.Duration,
  55. streamWriteTimeout time.Duration,
  56. gracePeriod time.Duration,
  57. logger *zerolog.Logger,
  58. ) (TunnelConnection, error) {
  59. return &quicConnection{
  60. conn: conn,
  61. logger: logger,
  62. orchestrator: orchestrator,
  63. datagramHandler: datagramSessionHandler,
  64. controlStreamHandler: controlStreamHandler,
  65. connOptions: connOptions,
  66. connIndex: connIndex,
  67. rpcTimeout: rpcTimeout,
  68. streamWriteTimeout: streamWriteTimeout,
  69. gracePeriod: gracePeriod,
  70. }, nil
  71. }
  72. // Serve starts a QUIC connection that begins accepting streams.
  73. func (q *quicConnection) Serve(ctx context.Context) error {
  74. // The edge assumes the first stream is used for the control plane
  75. controlStream, err := q.conn.OpenStream()
  76. if err != nil {
  77. return fmt.Errorf("failed to open a registration control stream: %w", err)
  78. }
  79. // If either goroutine returns nil error, we rely on this cancellation to make sure the other goroutine exits
  80. // as fast as possible as well. Nil error means we want to exit for good (caller code won't retry serving this
  81. // connection).
  82. // If either goroutine returns a non nil error, then the error group cancels the context, thus also canceling the
  83. // other goroutine as fast as possible.
  84. ctx, cancel := context.WithCancel(ctx)
  85. errGroup, ctx := errgroup.WithContext(ctx)
  86. // In the future, if cloudflared can autonomously push traffic to the edge, we have to make sure the control
  87. // stream is already fully registered before the other goroutines can proceed.
  88. errGroup.Go(func() error {
  89. // err is equal to nil if we exit due to unregistration. If that happens we want to wait the full
  90. // amount of the grace period, allowing requests to finish before we cancel the context, which will
  91. // make cloudflared exit.
  92. if err := q.serveControlStream(ctx, controlStream); err == nil {
  93. select {
  94. case <-ctx.Done():
  95. case <-time.Tick(q.gracePeriod):
  96. }
  97. }
  98. cancel()
  99. return err
  100. })
  101. errGroup.Go(func() error {
  102. defer cancel()
  103. return q.acceptStream(ctx)
  104. })
  105. errGroup.Go(func() error {
  106. defer cancel()
  107. return q.datagramHandler.Serve(ctx)
  108. })
  109. return errGroup.Wait()
  110. }
  111. // serveControlStream will serve the RPC; blocking until the control plane is done.
  112. func (q *quicConnection) serveControlStream(ctx context.Context, controlStream quic.Stream) error {
  113. return q.controlStreamHandler.ServeControlStream(ctx, controlStream, q.connOptions, q.orchestrator)
  114. }
  115. // Close the connection with no errors specified.
  116. func (q *quicConnection) Close() {
  117. q.conn.CloseWithError(0, "")
  118. }
  119. func (q *quicConnection) acceptStream(ctx context.Context) error {
  120. defer q.Close()
  121. for {
  122. quicStream, err := q.conn.AcceptStream(ctx)
  123. if err != nil {
  124. // context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional.
  125. if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() {
  126. return nil
  127. }
  128. return fmt.Errorf("failed to accept QUIC stream: %w", err)
  129. }
  130. go q.runStream(quicStream)
  131. }
  132. }
  133. func (q *quicConnection) runStream(quicStream quic.Stream) {
  134. ctx := quicStream.Context()
  135. stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
  136. defer stream.Close()
  137. // we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
  138. // code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream.
  139. // So, we wrap the stream with a no-op write closer and only this method can actually close write side of the stream.
  140. // A call to close will simulate a close to the read-side, which will fail subsequent reads.
  141. noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream}
  142. ss := rpcquic.NewCloudflaredServer(q.handleDataStream, q.datagramHandler, q, q.rpcTimeout)
  143. if err := ss.Serve(ctx, noCloseStream); err != nil {
  144. q.logger.Debug().Err(err).Msg("Failed to handle QUIC stream")
  145. // if we received an error at this level, then close write side of stream with an error, which will result in
  146. // RST_STREAM frame.
  147. quicStream.CancelWrite(0)
  148. }
  149. }
  150. func (q *quicConnection) handleDataStream(ctx context.Context, stream *rpcquic.RequestServerStream) error {
  151. request, err := stream.ReadConnectRequestData()
  152. if err != nil {
  153. return err
  154. }
  155. if err, connectResponseSent := q.dispatchRequest(ctx, stream, request); err != nil {
  156. q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed")
  157. // if the connectResponse was already sent and we had an error, we need to propagate it up, so that the stream is
  158. // closed with an RST_STREAM frame
  159. if connectResponseSent {
  160. return err
  161. }
  162. if writeRespErr := stream.WriteConnectResponseData(err); writeRespErr != nil {
  163. return writeRespErr
  164. }
  165. }
  166. return nil
  167. }
  168. // dispatchRequest will dispatch the request to the origin depending on the type and returns an error if it occurs.
  169. // Also returns if the connect response was sent to the downstream during processing of the origin request.
  170. func (q *quicConnection) dispatchRequest(ctx context.Context, stream *rpcquic.RequestServerStream, request *pogs.ConnectRequest) (err error, connectResponseSent bool) {
  171. originProxy, err := q.orchestrator.GetOriginProxy()
  172. if err != nil {
  173. return err, false
  174. }
  175. switch request.Type {
  176. case pogs.ConnectionTypeHTTP, pogs.ConnectionTypeWebsocket:
  177. tracedReq, err := buildHTTPRequest(ctx, request, stream, q.connIndex, q.logger)
  178. if err != nil {
  179. return err, false
  180. }
  181. w := newHTTPResponseAdapter(stream)
  182. return originProxy.ProxyHTTP(&w, tracedReq, request.Type == pogs.ConnectionTypeWebsocket), w.connectResponseSent
  183. case pogs.ConnectionTypeTCP:
  184. rwa := &streamReadWriteAcker{RequestServerStream: stream}
  185. metadata := request.MetadataMap()
  186. return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{
  187. Dest: request.Dest,
  188. FlowID: metadata[QUICMetadataFlowID],
  189. CfTraceID: metadata[tracing.TracerContextName],
  190. ConnIndex: q.connIndex,
  191. }), rwa.connectResponseSent
  192. default:
  193. return errors.Errorf("unsupported error type: %s", request.Type), false
  194. }
  195. }
  196. // UpdateConfiguration is the RPC method invoked by edge when there is a new configuration
  197. func (q *quicConnection) UpdateConfiguration(ctx context.Context, version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse {
  198. return q.orchestrator.UpdateConfig(version, config)
  199. }
  200. // streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
  201. // the client.
  202. type streamReadWriteAcker struct {
  203. *rpcquic.RequestServerStream
  204. connectResponseSent bool
  205. }
  206. // AckConnection acks response back to the proxy.
  207. func (s *streamReadWriteAcker) AckConnection(tracePropagation string) error {
  208. metadata := []pogs.Metadata{}
  209. // Only add tracing if provided by the edge request
  210. if tracePropagation != "" {
  211. metadata = append(metadata, pogs.Metadata{
  212. Key: tracing.CanonicalCloudflaredTracingHeader,
  213. Val: tracePropagation,
  214. })
  215. }
  216. s.connectResponseSent = true
  217. return s.WriteConnectResponseData(nil, metadata...)
  218. }
  219. // httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC.
  220. type httpResponseAdapter struct {
  221. *rpcquic.RequestServerStream
  222. headers http.Header
  223. connectResponseSent bool
  224. }
  225. func newHTTPResponseAdapter(s *rpcquic.RequestServerStream) httpResponseAdapter {
  226. return httpResponseAdapter{RequestServerStream: s, headers: make(http.Header)}
  227. }
  228. func (hrw *httpResponseAdapter) AddTrailer(trailerName, trailerValue string) {
  229. // we do not support trailers over QUIC
  230. }
  231. func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error {
  232. metadata := make([]pogs.Metadata, 0)
  233. metadata = append(metadata, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)})
  234. for k, vv := range header {
  235. for _, v := range vv {
  236. httpHeaderKey := fmt.Sprintf("%s:%s", HTTPHeaderKey, k)
  237. metadata = append(metadata, pogs.Metadata{Key: httpHeaderKey, Val: v})
  238. }
  239. }
  240. return hrw.WriteConnectResponseData(nil, metadata...)
  241. }
  242. func (hrw *httpResponseAdapter) Write(p []byte) (int, error) {
  243. // Make sure to send WriteHeader response if not called yet
  244. if !hrw.connectResponseSent {
  245. hrw.WriteRespHeaders(http.StatusOK, hrw.headers)
  246. }
  247. return hrw.RequestServerStream.Write(p)
  248. }
  249. func (hrw *httpResponseAdapter) Header() http.Header {
  250. return hrw.headers
  251. }
  252. // This is a no-op Flush because this adapter is over a quic.Stream and we don't need Flush here.
  253. func (hrw *httpResponseAdapter) Flush() {}
  254. func (hrw *httpResponseAdapter) WriteHeader(status int) {
  255. hrw.WriteRespHeaders(status, hrw.headers)
  256. }
  257. func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
  258. conn := &localProxyConnection{hrw.ReadWriteCloser}
  259. readWriter := bufio.NewReadWriter(
  260. bufio.NewReader(hrw.ReadWriteCloser),
  261. bufio.NewWriter(hrw.ReadWriteCloser),
  262. )
  263. return conn, readWriter, nil
  264. }
  265. func (hrw *httpResponseAdapter) WriteErrorResponse(err error) {
  266. hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
  267. }
  268. func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error {
  269. hrw.connectResponseSent = true
  270. return hrw.RequestServerStream.WriteConnectResponseData(respErr, metadata...)
  271. }
  272. func buildHTTPRequest(
  273. ctx context.Context,
  274. connectRequest *pogs.ConnectRequest,
  275. body io.ReadCloser,
  276. connIndex uint8,
  277. log *zerolog.Logger,
  278. ) (*tracing.TracedHTTPRequest, error) {
  279. metadata := connectRequest.MetadataMap()
  280. dest := connectRequest.Dest
  281. method := metadata[HTTPMethodKey]
  282. host := metadata[HTTPHostKey]
  283. isWebsocket := connectRequest.Type == pogs.ConnectionTypeWebsocket
  284. req, err := http.NewRequestWithContext(ctx, method, dest, body)
  285. if err != nil {
  286. return nil, err
  287. }
  288. req.Host = host
  289. for _, metadata := range connectRequest.Metadata {
  290. if strings.Contains(metadata.Key, HTTPHeaderKey) {
  291. // metadata.Key is off the format httpHeaderKey:<HTTPHeader>
  292. httpHeaderKey := strings.Split(metadata.Key, ":")
  293. if len(httpHeaderKey) != 2 {
  294. return nil, fmt.Errorf("header Key: %s malformed", metadata.Key)
  295. }
  296. req.Header.Add(httpHeaderKey[1], metadata.Val)
  297. }
  298. }
  299. // Go's http.Client automatically sends chunked request body if this value is not set on the
  300. // *http.Request struct regardless of header:
  301. // https://go.googlesource.com/go/+/go1.8rc2/src/net/http/transfer.go#154.
  302. if err := setContentLength(req); err != nil {
  303. return nil, fmt.Errorf("Error setting content-length: %w", err)
  304. }
  305. // Go's client defaults to chunked encoding after a 200ms delay if the following cases are true:
  306. // * the request body blocks
  307. // * the content length is not set (or set to -1)
  308. // * the method doesn't usually have a body (GET, HEAD, DELETE, ...)
  309. // * there is no transfer-encoding=chunked already set.
  310. // So, if transfer cannot be chunked and content length is 0, we dont set a request body.
  311. if !isWebsocket && !isTransferEncodingChunked(req) && req.ContentLength == 0 {
  312. req.Body = http.NoBody
  313. }
  314. stripWebsocketUpgradeHeader(req)
  315. // Check for tracing on request
  316. tracedReq := tracing.NewTracedHTTPRequest(req, connIndex, log)
  317. return tracedReq, err
  318. }
  319. func setContentLength(req *http.Request) error {
  320. var err error
  321. if contentLengthStr := req.Header.Get("Content-Length"); contentLengthStr != "" {
  322. req.ContentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
  323. }
  324. return err
  325. }
  326. func isTransferEncodingChunked(req *http.Request) bool {
  327. transferEncodingVal := req.Header.Get("Transfer-Encoding")
  328. // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Transfer-Encoding suggests that this can be a comma
  329. // separated value as well.
  330. return strings.Contains(strings.ToLower(transferEncodingVal), "chunked")
  331. }
  332. // A helper struct that guarantees a call to close only affects read side, but not write side.
  333. type nopCloserReadWriter struct {
  334. io.ReadWriteCloser
  335. // for use by Read only
  336. // we don't need a memory barrier here because there is an implicit assumption that
  337. // Read calls can't happen concurrently by different go-routines.
  338. sawEOF bool
  339. // should be updated and read using atomic primitives.
  340. // value is read in Read method and written in Close method, which could be done by different
  341. // go-routines.
  342. closed uint32
  343. }
  344. func (np *nopCloserReadWriter) Read(p []byte) (n int, err error) {
  345. if np.sawEOF {
  346. return 0, io.EOF
  347. }
  348. if atomic.LoadUint32(&np.closed) > 0 {
  349. return 0, fmt.Errorf("closed by handler")
  350. }
  351. n, err = np.ReadWriteCloser.Read(p)
  352. if err == io.EOF {
  353. np.sawEOF = true
  354. }
  355. return
  356. }
  357. func (np *nopCloserReadWriter) Close() error {
  358. atomic.StoreUint32(&np.closed, 1)
  359. return nil
  360. }