origin_connection.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. package ingress
  2. import (
  3. "context"
  4. "io"
  5. "net"
  6. "time"
  7. "github.com/rs/zerolog"
  8. "github.com/cloudflare/cloudflared/ipaccess"
  9. "github.com/cloudflare/cloudflared/socks"
  10. "github.com/cloudflare/cloudflared/stream"
  11. "github.com/cloudflare/cloudflared/websocket"
  12. )
  13. // OriginConnection is a way to stream to a service running on the user's origin.
  14. // Different concrete implementations will stream different protocols as long as they are io.ReadWriters.
  15. type OriginConnection interface {
  16. // Stream should generally be implemented as a bidirectional io.Copy.
  17. Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger)
  18. Close()
  19. }
  20. type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger)
  // DefaultStreamHandler is an implementation of streamHandlerFunc that
  // performs a two way io.Copy between originConn and remoteConn.
  //
  // It simply forwards its three arguments, unchanged and in the same order,
  // to stream.Pipe.
  func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger) {
  	stream.Pipe(originConn, remoteConn, log)
  }
  // tcpConnection is an OriginConnection that directly streams to raw TCP.
  type tcpConnection struct {
  	// Conn is the embedded connection to the origin; its methods are promoted
  	// except where this type declares its own (Write and Close).
  	net.Conn
  	// writeTimeout, when greater than zero, is added to the current time and
  	// set as the write deadline before every Write.
  	writeTimeout time.Duration
  	// logger receives deadline and write errors, and is the logger handed to
  	// stream.Pipe by Stream.
  	logger *zerolog.Logger
  }
  32. func (tc *tcpConnection) Stream(_ context.Context, tunnelConn io.ReadWriter, _ *zerolog.Logger) {
  33. stream.Pipe(tunnelConn, tc, tc.logger)
  34. }
  35. func (tc *tcpConnection) Write(b []byte) (int, error) {
  36. if tc.writeTimeout > 0 {
  37. if err := tc.Conn.SetWriteDeadline(time.Now().Add(tc.writeTimeout)); err != nil {
  38. tc.logger.Err(err).Msg("Error setting write deadline for TCP connection")
  39. }
  40. }
  41. nBytes, err := tc.Conn.Write(b)
  42. if err != nil {
  43. tc.logger.Err(err).Msg("Error writing to the TCP connection")
  44. }
  45. return nBytes, err
  46. }
  47. func (tc *tcpConnection) Close() {
  48. tc.Conn.Close()
  49. }
  // tcpOverWSConnection is an OriginConnection that streams to TCP over WS.
  type tcpOverWSConnection struct {
  	// conn is the connection to the origin; Stream passes it to streamHandler
  	// and Close closes it.
  	conn net.Conn
  	// streamHandler is invoked by Stream with the WebSocket-wrapped tunnel
  	// connection, conn, and the logger.
  	streamHandler streamHandlerFunc
  }
  55. func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
  56. wsCtx, cancel := context.WithCancel(ctx)
  57. wsConn := websocket.NewConn(wsCtx, tunnelConn, log)
  58. wc.streamHandler(wsConn, wc.conn, log)
  59. cancel()
  60. // Makes sure wsConn stops sending ping before terminating the stream
  61. wsConn.Close()
  62. }
  63. func (wc *tcpOverWSConnection) Close() {
  64. wc.conn.Close()
  65. }
  // socksProxyOverWSConnection is an OriginConnection that streams SOCKS connections over WS.
  // The connection to the origin happens inside the SOCKS code as the client specifies the origin
  // details in the packet.
  type socksProxyOverWSConnection struct {
  	// accessPolicy is passed to socks.StreamNetHandler on every Stream call.
  	accessPolicy *ipaccess.Policy
  }
  72. func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
  73. wsCtx, cancel := context.WithCancel(ctx)
  74. wsConn := websocket.NewConn(wsCtx, tunnelConn, log)
  75. socks.StreamNetHandler(wsConn, sp.accessPolicy, log)
  76. cancel()
  77. // Makes sure wsConn stops sending ping before terminating the stream
  78. wsConn.Close()
  79. }
  // Close is a no-op: this type holds only an access policy and no connection
  // of its own, so there is nothing to release here.
  func (sp *socksProxyOverWSConnection) Close() {
  }