quic_connection.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. package connection
  2. import (
  3. "bufio"
  4. "context"
  5. "fmt"
  6. "io"
  7. "net"
  8. "net/http"
  9. "strconv"
  10. "strings"
  11. "sync/atomic"
  12. "time"
  13. "github.com/pkg/errors"
  14. "github.com/quic-go/quic-go"
  15. "github.com/rs/zerolog"
  16. "golang.org/x/sync/errgroup"
  17. cfdflow "github.com/cloudflare/cloudflared/flow"
  18. cfdquic "github.com/cloudflare/cloudflared/quic"
  19. "github.com/cloudflare/cloudflared/tracing"
  20. "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
  21. tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
  22. rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic"
  23. )
  // Metadata keys carried on QUIC connect requests and responses.
  const (
  	// HTTPHeaderKey prefixes the metadata key of every proxied HTTP header; the full key has the
  	// form "HttpHeader:<header name>".
  	HTTPHeaderKey = "HttpHeader"
  	// HTTPMethodKey is the metadata key holding the HTTP method when the proxied connection type is HTTP.
  	HTTPMethodKey = "HttpMethod"
  	// HTTPHostKey is the metadata key holding the HTTP host when the proxied connection type is HTTP.
  	HTTPHostKey = "HttpHost"

  	// QUICMetadataFlowID is the metadata key whose value is forwarded as the flow ID of a TCP request.
  	QUICMetadataFlowID = "FlowID"
  )
  33. // quicConnection represents the type that facilitates Proxying via QUIC streams.
  34. type quicConnection struct {
  35. conn quic.Connection
  36. logger *zerolog.Logger
  37. orchestrator Orchestrator
  38. datagramHandler DatagramSessionHandler
  39. controlStreamHandler ControlStreamHandler
  40. connOptions *tunnelpogs.ConnectionOptions
  41. connIndex uint8
  42. rpcTimeout time.Duration
  43. streamWriteTimeout time.Duration
  44. gracePeriod time.Duration
  45. }
  46. // NewTunnelConnection takes a [quic.Connection] to wrap it for use with cloudflared application logic.
  47. func NewTunnelConnection(
  48. ctx context.Context,
  49. conn quic.Connection,
  50. connIndex uint8,
  51. orchestrator Orchestrator,
  52. datagramSessionHandler DatagramSessionHandler,
  53. controlStreamHandler ControlStreamHandler,
  54. connOptions *pogs.ConnectionOptions,
  55. rpcTimeout time.Duration,
  56. streamWriteTimeout time.Duration,
  57. gracePeriod time.Duration,
  58. logger *zerolog.Logger,
  59. ) (TunnelConnection, error) {
  60. return &quicConnection{
  61. conn: conn,
  62. logger: logger,
  63. orchestrator: orchestrator,
  64. datagramHandler: datagramSessionHandler,
  65. controlStreamHandler: controlStreamHandler,
  66. connOptions: connOptions,
  67. connIndex: connIndex,
  68. rpcTimeout: rpcTimeout,
  69. streamWriteTimeout: streamWriteTimeout,
  70. gracePeriod: gracePeriod,
  71. }, nil
  72. }
  73. // Serve starts a QUIC connection that begins accepting streams.
  74. func (q *quicConnection) Serve(ctx context.Context) error {
  75. // The edge assumes the first stream is used for the control plane
  76. controlStream, err := q.conn.OpenStream()
  77. if err != nil {
  78. return fmt.Errorf("failed to open a registration control stream: %w", err)
  79. }
  80. // If either goroutine returns nil error, we rely on this cancellation to make sure the other goroutine exits
  81. // as fast as possible as well. Nil error means we want to exit for good (caller code won't retry serving this
  82. // connection).
  83. // If either goroutine returns a non nil error, then the error group cancels the context, thus also canceling the
  84. // other goroutine as fast as possible.
  85. ctx, cancel := context.WithCancel(ctx)
  86. errGroup, ctx := errgroup.WithContext(ctx)
  87. // In the future, if cloudflared can autonomously push traffic to the edge, we have to make sure the control
  88. // stream is already fully registered before the other goroutines can proceed.
  89. errGroup.Go(func() error {
  90. // err is equal to nil if we exit due to unregistration. If that happens we want to wait the full
  91. // amount of the grace period, allowing requests to finish before we cancel the context, which will
  92. // make cloudflared exit.
  93. if err := q.serveControlStream(ctx, controlStream); err == nil {
  94. if q.gracePeriod > 0 {
  95. // In Go1.23 this can be removed and replaced with time.Ticker
  96. // see https://pkg.go.dev/time#Tick
  97. ticker := time.NewTicker(q.gracePeriod)
  98. defer ticker.Stop()
  99. select {
  100. case <-ctx.Done():
  101. case <-ticker.C:
  102. }
  103. }
  104. }
  105. cancel()
  106. return err
  107. })
  108. errGroup.Go(func() error {
  109. defer cancel()
  110. return q.acceptStream(ctx)
  111. })
  112. errGroup.Go(func() error {
  113. defer cancel()
  114. return q.datagramHandler.Serve(ctx)
  115. })
  116. return errGroup.Wait()
  117. }
  118. // serveControlStream will serve the RPC; blocking until the control plane is done.
  119. func (q *quicConnection) serveControlStream(ctx context.Context, controlStream quic.Stream) error {
  120. return q.controlStreamHandler.ServeControlStream(ctx, controlStream, q.connOptions, q.orchestrator)
  121. }
  122. // Close the connection with no errors specified.
  123. func (q *quicConnection) Close() {
  124. _ = q.conn.CloseWithError(0, "")
  125. }
  126. func (q *quicConnection) acceptStream(ctx context.Context) error {
  127. defer q.Close()
  128. for {
  129. quicStream, err := q.conn.AcceptStream(ctx)
  130. if err != nil {
  131. // context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional.
  132. if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() {
  133. return nil
  134. }
  135. return fmt.Errorf("failed to accept QUIC stream: %w", err)
  136. }
  137. go q.runStream(quicStream)
  138. }
  139. }
  140. func (q *quicConnection) runStream(quicStream quic.Stream) {
  141. ctx := quicStream.Context()
  142. stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
  143. defer stream.Close()
  144. // we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
  145. // code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream.
  146. // So, we wrap the stream with a no-op write closer and only this method can actually close write side of the stream.
  147. // A call to close will simulate a close to the read-side, which will fail subsequent reads.
  148. noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream}
  149. ss := rpcquic.NewCloudflaredServer(q.handleDataStream, q.datagramHandler, q, q.rpcTimeout)
  150. if err := ss.Serve(ctx, noCloseStream); err != nil {
  151. q.logger.Debug().Err(err).Msg("Failed to handle QUIC stream")
  152. // if we received an error at this level, then close write side of stream with an error, which will result in
  153. // RST_STREAM frame.
  154. quicStream.CancelWrite(0)
  155. }
  156. }
  157. func (q *quicConnection) handleDataStream(ctx context.Context, stream *rpcquic.RequestServerStream) error {
  158. request, err := stream.ReadConnectRequestData()
  159. if err != nil {
  160. return err
  161. }
  162. if err, connectResponseSent := q.dispatchRequest(ctx, stream, request); err != nil {
  163. q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed")
  164. // if the connectResponse was already sent and we had an error, we need to propagate it up, so that the stream is
  165. // closed with an RST_STREAM frame
  166. if connectResponseSent {
  167. return err
  168. }
  169. var metadata []pogs.Metadata
  170. // Check the type of error that was throw and add metadata that will help identify it on OTD.
  171. if errors.Is(err, cfdflow.ErrTooManyActiveFlows) {
  172. metadata = append(metadata, pogs.ErrorFlowConnectRateLimitedMetadata)
  173. }
  174. if writeRespErr := stream.WriteConnectResponseData(err, metadata...); writeRespErr != nil {
  175. return writeRespErr
  176. }
  177. }
  178. return nil
  179. }
  180. // dispatchRequest will dispatch the request to the origin depending on the type and returns an error if it occurs.
  181. // Also returns if the connect response was sent to the downstream during processing of the origin request.
  182. func (q *quicConnection) dispatchRequest(ctx context.Context, stream *rpcquic.RequestServerStream, request *pogs.ConnectRequest) (err error, connectResponseSent bool) {
  183. originProxy, err := q.orchestrator.GetOriginProxy()
  184. if err != nil {
  185. return err, false
  186. }
  187. switch request.Type {
  188. case pogs.ConnectionTypeHTTP, pogs.ConnectionTypeWebsocket:
  189. tracedReq, err := buildHTTPRequest(ctx, request, stream, q.connIndex, q.logger)
  190. if err != nil {
  191. return err, false
  192. }
  193. w := newHTTPResponseAdapter(stream)
  194. return originProxy.ProxyHTTP(&w, tracedReq, request.Type == pogs.ConnectionTypeWebsocket), w.connectResponseSent
  195. case pogs.ConnectionTypeTCP:
  196. rwa := &streamReadWriteAcker{RequestServerStream: stream}
  197. metadata := request.MetadataMap()
  198. return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{
  199. Dest: request.Dest,
  200. FlowID: metadata[QUICMetadataFlowID],
  201. CfTraceID: metadata[tracing.TracerContextName],
  202. ConnIndex: q.connIndex,
  203. }), rwa.connectResponseSent
  204. default:
  205. return errors.Errorf("unsupported error type: %s", request.Type), false
  206. }
  207. }
  208. // UpdateConfiguration is the RPC method invoked by edge when there is a new configuration
  209. func (q *quicConnection) UpdateConfiguration(ctx context.Context, version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse {
  210. return q.orchestrator.UpdateConfig(version, config)
  211. }
  212. // streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
  213. // the client.
  214. type streamReadWriteAcker struct {
  215. *rpcquic.RequestServerStream
  216. connectResponseSent bool
  217. }
  218. // AckConnection acks response back to the proxy.
  219. func (s *streamReadWriteAcker) AckConnection(tracePropagation string) error {
  220. metadata := []pogs.Metadata{}
  221. // Only add tracing if provided by the edge request
  222. if tracePropagation != "" {
  223. metadata = append(metadata, pogs.Metadata{
  224. Key: tracing.CanonicalCloudflaredTracingHeader,
  225. Val: tracePropagation,
  226. })
  227. }
  228. s.connectResponseSent = true
  229. return s.WriteConnectResponseData(nil, metadata...)
  230. }
  231. // httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC.
  232. type httpResponseAdapter struct {
  233. *rpcquic.RequestServerStream
  234. headers http.Header
  235. connectResponseSent bool
  236. }
  237. func newHTTPResponseAdapter(s *rpcquic.RequestServerStream) httpResponseAdapter {
  238. return httpResponseAdapter{RequestServerStream: s, headers: make(http.Header)}
  239. }
  240. func (hrw *httpResponseAdapter) AddTrailer(trailerName, trailerValue string) {
  241. // we do not support trailers over QUIC
  242. }
  243. func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error {
  244. metadata := make([]pogs.Metadata, 0)
  245. metadata = append(metadata, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)})
  246. for k, vv := range header {
  247. for _, v := range vv {
  248. httpHeaderKey := fmt.Sprintf("%s:%s", HTTPHeaderKey, k)
  249. metadata = append(metadata, pogs.Metadata{Key: httpHeaderKey, Val: v})
  250. }
  251. }
  252. return hrw.WriteConnectResponseData(nil, metadata...)
  253. }
  254. func (hrw *httpResponseAdapter) Write(p []byte) (int, error) {
  255. // Make sure to send WriteHeader response if not called yet
  256. if !hrw.connectResponseSent {
  257. _ = hrw.WriteRespHeaders(http.StatusOK, hrw.headers)
  258. }
  259. return hrw.RequestServerStream.Write(p)
  260. }
  261. func (hrw *httpResponseAdapter) Header() http.Header {
  262. return hrw.headers
  263. }
  264. // This is a no-op Flush because this adapter is over a quic.Stream and we don't need Flush here.
  265. func (hrw *httpResponseAdapter) Flush() {}
  266. func (hrw *httpResponseAdapter) WriteHeader(status int) {
  267. _ = hrw.WriteRespHeaders(status, hrw.headers)
  268. }
  269. func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
  270. conn := &localProxyConnection{hrw.ReadWriteCloser}
  271. readWriter := bufio.NewReadWriter(
  272. bufio.NewReader(hrw.ReadWriteCloser),
  273. bufio.NewWriter(hrw.ReadWriteCloser),
  274. )
  275. return conn, readWriter, nil
  276. }
  277. func (hrw *httpResponseAdapter) WriteErrorResponse(err error) {
  278. _ = hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
  279. }
  280. func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error {
  281. hrw.connectResponseSent = true
  282. return hrw.RequestServerStream.WriteConnectResponseData(respErr, metadata...)
  283. }
  284. func buildHTTPRequest(
  285. ctx context.Context,
  286. connectRequest *pogs.ConnectRequest,
  287. body io.ReadCloser,
  288. connIndex uint8,
  289. log *zerolog.Logger,
  290. ) (*tracing.TracedHTTPRequest, error) {
  291. metadata := connectRequest.MetadataMap()
  292. dest := connectRequest.Dest
  293. method := metadata[HTTPMethodKey]
  294. host := metadata[HTTPHostKey]
  295. isWebsocket := connectRequest.Type == pogs.ConnectionTypeWebsocket
  296. req, err := http.NewRequestWithContext(ctx, method, dest, body)
  297. if err != nil {
  298. return nil, err
  299. }
  300. req.Host = host
  301. for _, metadata := range connectRequest.Metadata {
  302. if strings.Contains(metadata.Key, HTTPHeaderKey) {
  303. // metadata.Key is off the format httpHeaderKey:<HTTPHeader>
  304. httpHeaderKey := strings.Split(metadata.Key, ":")
  305. if len(httpHeaderKey) != 2 {
  306. return nil, fmt.Errorf("header Key: %s malformed", metadata.Key)
  307. }
  308. req.Header.Add(httpHeaderKey[1], metadata.Val)
  309. }
  310. }
  311. // Go's http.Client automatically sends chunked request body if this value is not set on the
  312. // *http.Request struct regardless of header:
  313. // https://go.googlesource.com/go/+/go1.8rc2/src/net/http/transfer.go#154.
  314. if err := setContentLength(req); err != nil {
  315. return nil, fmt.Errorf("Error setting content-length: %w", err)
  316. }
  317. // Go's client defaults to chunked encoding after a 200ms delay if the following cases are true:
  318. // * the request body blocks
  319. // * the content length is not set (or set to -1)
  320. // * the method doesn't usually have a body (GET, HEAD, DELETE, ...)
  321. // * there is no transfer-encoding=chunked already set.
  322. // So, if transfer cannot be chunked and content length is 0, we dont set a request body.
  323. if !isWebsocket && !isTransferEncodingChunked(req) && req.ContentLength == 0 {
  324. req.Body = http.NoBody
  325. }
  326. stripWebsocketUpgradeHeader(req)
  327. // Check for tracing on request
  328. tracedReq := tracing.NewTracedHTTPRequest(req, connIndex, log)
  329. return tracedReq, err
  330. }
  // setContentLength copies the value of the request's Content-Length header into req.ContentLength.
  //
  // When the header is absent or empty the request is left untouched and nil is returned. When the
  // header cannot be parsed as a base-10 int64, the parse error is returned and req.ContentLength
  // keeps its previous value.
  func setContentLength(req *http.Request) error {
  	contentLengthStr := req.Header.Get("Content-Length")
  	if contentLengthStr == "" {
  		return nil
  	}

  	// Parse into a local first: on failure ParseInt still returns a value (0, or the nearest int64
  	// bound on overflow) and that must not leak into the request.
  	contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64)
  	if err != nil {
  		return err
  	}
  	req.ContentLength = contentLength
  	return nil
  }
  // isTransferEncodingChunked reports whether the request's Transfer-Encoding header mentions
  // "chunked", compared case-insensitively.
  func isTransferEncodingChunked(req *http.Request) bool {
  	// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Transfer-Encoding suggests that this can be a comma
  	// separated value as well, hence a substring match rather than an equality check.
  	normalized := strings.ToLower(req.Header.Get("Transfer-Encoding"))
  	return strings.Contains(normalized, "chunked")
  }
  // nopCloserReadWriter wraps a stream so that calling Close only affects the read side of the
  // wrapper and never closes the wrapped stream.
  type nopCloserReadWriter struct {
  	io.ReadWriteCloser

  	// sawEOF is used by Read only. It needs no memory barrier because Read is assumed never to be
  	// called concurrently from different goroutines.
  	sawEOF bool

  	// closed must be read and written with atomic primitives: Read loads it and Close stores it,
  	// possibly from different goroutines.
  	closed uint32
  }

  // Read reads from the wrapped stream.
  //
  // Once the wrapped stream has reported io.EOF, every later call returns io.EOF without touching
  // the stream again. Otherwise, once Close has been called, Read fails with an error.
  func (np *nopCloserReadWriter) Read(p []byte) (n int, err error) {
  	switch {
  	case np.sawEOF:
  		return 0, io.EOF
  	case atomic.LoadUint32(&np.closed) > 0:
  		return 0, fmt.Errorf("closed by handler")
  	}

  	n, err = np.ReadWriteCloser.Read(p)
  	if err == io.EOF {
  		// Remember the end of stream so it is reported consistently from now on.
  		np.sawEOF = true
  	}
  	return n, err
  }

  // Close marks the wrapper as closed so that later reads fail; the wrapped stream is left open.
  // It always returns nil.
  func (np *nopCloserReadWriter) Close() error {
  	atomic.StoreUint32(&np.closed, 1)
  	return nil
  }