datagram_test.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. package quic
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/rand"
  6. "crypto/rsa"
  7. "crypto/tls"
  8. "crypto/x509"
  9. "encoding/pem"
  10. "fmt"
  11. "math/big"
  12. "net"
  13. "net/netip"
  14. "testing"
  15. "time"
  16. "github.com/google/gopacket/layers"
  17. "github.com/google/uuid"
  18. "github.com/quic-go/quic-go"
  19. "github.com/rs/zerolog"
  20. "github.com/stretchr/testify/require"
  21. "golang.org/x/net/icmp"
  22. "golang.org/x/net/ipv4"
  23. "golang.org/x/sync/errgroup"
  24. "github.com/cloudflare/cloudflared/packet"
  25. "github.com/cloudflare/cloudflared/tracing"
  26. )
  27. var (
  28. testSessionID = uuid.New()
  29. )
  30. func TestSuffixThenRemoveSessionID(t *testing.T) {
  31. msg := []byte(t.Name())
  32. msgWithID, err := SuffixSessionID(testSessionID, msg)
  33. require.NoError(t, err)
  34. require.Len(t, msgWithID, len(msg)+sessionIDLen)
  35. sessionID, msgWithoutID, err := extractSessionID(msgWithID)
  36. require.NoError(t, err)
  37. require.Equal(t, msg, msgWithoutID)
  38. require.Equal(t, testSessionID, sessionID)
  39. }
  40. func TestRemoveSessionIDError(t *testing.T) {
  41. // message is too short to contain session ID
  42. msg := []byte("test")
  43. _, _, err := extractSessionID(msg)
  44. require.Error(t, err)
  45. }
  46. func TestSuffixSessionIDError(t *testing.T) {
  47. msg := make([]byte, MaxDatagramFrameSize-sessionIDLen)
  48. _, err := SuffixSessionID(testSessionID, msg)
  49. require.NoError(t, err)
  50. msg = make([]byte, MaxDatagramFrameSize-sessionIDLen+1)
  51. _, err = SuffixSessionID(testSessionID, msg)
  52. require.Error(t, err)
  53. }
  54. func TestDatagram(t *testing.T) {
  55. maxPayload := make([]byte, maxDatagramPayloadSize)
  56. noPayloadSession := uuid.New()
  57. maxPayloadSession := uuid.New()
  58. sessionToPayload := []*packet.Session{
  59. {
  60. ID: noPayloadSession,
  61. Payload: make([]byte, 0),
  62. },
  63. {
  64. ID: maxPayloadSession,
  65. Payload: maxPayload,
  66. },
  67. }
  68. packets := []packet.ICMP{
  69. {
  70. IP: &packet.IP{
  71. Src: netip.MustParseAddr("172.16.0.1"),
  72. Dst: netip.MustParseAddr("192.168.0.1"),
  73. Protocol: layers.IPProtocolICMPv4,
  74. },
  75. Message: &icmp.Message{
  76. Type: ipv4.ICMPTypeTimeExceeded,
  77. Code: 0,
  78. Body: &icmp.TimeExceeded{
  79. Data: []byte("original packet"),
  80. },
  81. },
  82. },
  83. {
  84. IP: &packet.IP{
  85. Src: netip.MustParseAddr("172.16.0.2"),
  86. Dst: netip.MustParseAddr("192.168.0.2"),
  87. Protocol: layers.IPProtocolICMPv4,
  88. },
  89. Message: &icmp.Message{
  90. Type: ipv4.ICMPTypeEcho,
  91. Code: 0,
  92. Body: &icmp.Echo{
  93. ID: 6182,
  94. Seq: 9151,
  95. Data: []byte("Test ICMP echo"),
  96. },
  97. },
  98. },
  99. }
  100. testDatagram(t, 1, sessionToPayload, nil)
  101. testDatagram(t, 2, sessionToPayload, packets)
  102. }
  103. func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Session, packets []packet.ICMP) {
  104. quicConfig := &quic.Config{
  105. KeepAlivePeriod: 5 * time.Millisecond,
  106. EnableDatagrams: true,
  107. }
  108. quicListener := newQUICListener(t, quicConfig)
  109. defer quicListener.Close()
  110. logger := zerolog.Nop()
  111. tracingIdentity, err := tracing.NewIdentity("ec31ad8a01fde11fdcabe2efdce36873:52726f6cabc144f5:0:1")
  112. require.NoError(t, err)
  113. serializedTracingID, err := tracingIdentity.MarshalBinary()
  114. require.NoError(t, err)
  115. tracingSpan := &TracingSpanPacket{
  116. Spans: []byte("tracing"),
  117. TracingIdentity: serializedTracingID,
  118. }
  119. errGroup, ctx := errgroup.WithContext(context.Background())
  120. // Run edge side of datagram muxer
  121. errGroup.Go(func() error {
  122. // Accept quic connection
  123. quicSession, err := quicListener.Accept(ctx)
  124. if err != nil {
  125. return err
  126. }
  127. sessionDemuxChan := make(chan *packet.Session, 16)
  128. switch version {
  129. case 1:
  130. muxer := NewDatagramMuxer(quicSession, &logger, sessionDemuxChan)
  131. muxer.ServeReceive(ctx)
  132. case 2:
  133. muxer := NewDatagramMuxerV2(quicSession, &logger, sessionDemuxChan)
  134. muxer.ServeReceive(ctx)
  135. for _, pk := range packets {
  136. received, err := muxer.ReceivePacket(ctx)
  137. require.NoError(t, err)
  138. validateIPPacket(t, received, &pk)
  139. received, err = muxer.ReceivePacket(ctx)
  140. require.NoError(t, err)
  141. validateIPPacketWithTracing(t, received, &pk, serializedTracingID)
  142. }
  143. received, err := muxer.ReceivePacket(ctx)
  144. require.NoError(t, err)
  145. validateTracingSpans(t, received, tracingSpan)
  146. default:
  147. return fmt.Errorf("unknown datagram version %d", version)
  148. }
  149. for _, expectedPayload := range sessionToPayloads {
  150. actualPayload := <-sessionDemuxChan
  151. require.Equal(t, expectedPayload, actualPayload)
  152. }
  153. return nil
  154. })
  155. largePayload := make([]byte, MaxDatagramFrameSize)
  156. // Run cloudflared side of datagram muxer
  157. errGroup.Go(func() error {
  158. tlsClientConfig := &tls.Config{
  159. InsecureSkipVerify: true,
  160. NextProtos: []string{"argotunnel"},
  161. }
  162. ctx, cancel := context.WithCancel(context.Background())
  163. defer cancel()
  164. // https://github.com/quic-go/quic-go/issues/3793 MTU discovery is disabled on OSX for dual stack listeners
  165. udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
  166. require.NoError(t, err)
  167. // Establish quic connection
  168. quicSession, err := quic.DialEarly(ctx, udpConn, quicListener.Addr(), tlsClientConfig, quicConfig)
  169. require.NoError(t, err)
  170. defer quicSession.CloseWithError(0, "")
  171. // Wait a few milliseconds for MTU discovery to take place
  172. time.Sleep(time.Millisecond * 100)
  173. var muxer BaseDatagramMuxer
  174. switch version {
  175. case 1:
  176. muxer = NewDatagramMuxer(quicSession, &logger, nil)
  177. case 2:
  178. muxerV2 := NewDatagramMuxerV2(quicSession, &logger, nil)
  179. encoder := packet.NewEncoder()
  180. for _, pk := range packets {
  181. encodedPacket, err := encoder.Encode(&pk)
  182. require.NoError(t, err)
  183. require.NoError(t, muxerV2.SendPacket(RawPacket(encodedPacket)))
  184. require.NoError(t, muxerV2.SendPacket(&TracedPacket{
  185. Packet: encodedPacket,
  186. TracingIdentity: serializedTracingID,
  187. }))
  188. }
  189. require.NoError(t, muxerV2.SendPacket(tracingSpan))
  190. // Payload larger than transport MTU, should not be sent
  191. require.Error(t, muxerV2.SendPacket(RawPacket{
  192. Data: largePayload,
  193. }))
  194. muxer = muxerV2
  195. default:
  196. return fmt.Errorf("unknown datagram version %d", version)
  197. }
  198. for _, session := range sessionToPayloads {
  199. require.NoError(t, muxer.SendToSession(session))
  200. }
  201. // Payload larger than transport MTU, should not be sent
  202. require.Error(t, muxer.SendToSession(&packet.Session{
  203. ID: testSessionID,
  204. Payload: largePayload,
  205. }))
  206. // Wait for edge to finish receiving the messages
  207. time.Sleep(time.Millisecond * 100)
  208. return nil
  209. })
  210. require.NoError(t, errGroup.Wait())
  211. }
  212. func validateIPPacket(t *testing.T, receivedPacket Packet, expectedICMP *packet.ICMP) {
  213. require.Equal(t, DatagramTypeIP, receivedPacket.Type())
  214. rawPacket := receivedPacket.(RawPacket)
  215. decoder := packet.NewICMPDecoder()
  216. receivedICMP, err := decoder.Decode(packet.RawPacket(rawPacket))
  217. require.NoError(t, err)
  218. validateICMP(t, expectedICMP, receivedICMP)
  219. }
  220. func validateIPPacketWithTracing(t *testing.T, receivedPacket Packet, expectedICMP *packet.ICMP, serializedTracingID []byte) {
  221. require.Equal(t, DatagramTypeIPWithTrace, receivedPacket.Type())
  222. tracedPacket := receivedPacket.(*TracedPacket)
  223. decoder := packet.NewICMPDecoder()
  224. receivedICMP, err := decoder.Decode(tracedPacket.Packet)
  225. require.NoError(t, err)
  226. validateICMP(t, expectedICMP, receivedICMP)
  227. require.True(t, bytes.Equal(tracedPacket.TracingIdentity, serializedTracingID))
  228. }
  229. func validateICMP(t *testing.T, expected, actual *packet.ICMP) {
  230. require.Equal(t, expected.IP, actual.IP)
  231. require.Equal(t, expected.Type, actual.Type)
  232. require.Equal(t, expected.Code, actual.Code)
  233. require.Equal(t, expected.Body, actual.Body)
  234. }
  235. func validateTracingSpans(t *testing.T, receivedPacket Packet, expectedSpan *TracingSpanPacket) {
  236. require.Equal(t, DatagramTypeTracingSpan, receivedPacket.Type())
  237. tracingSpans := receivedPacket.(*TracingSpanPacket)
  238. require.Equal(t, tracingSpans, expectedSpan)
  239. }
  240. func newQUICListener(t *testing.T, config *quic.Config) *quic.Listener {
  241. // Create a simple tls config.
  242. tlsConfig := generateTLSConfig()
  243. listener, err := quic.ListenAddr("127.0.0.1:0", tlsConfig, config)
  244. require.NoError(t, err)
  245. return listener
  246. }
  // generateTLSConfig builds a server TLS config around a throwaway certificate.
  // The certificate is self-signed (the template is its own parent) and backed
  // by a freshly generated 1024-bit RSA key. The config advertises the
  // "argotunnel" ALPN protocol. It panics on any failure because it is only
  // used as test setup.
  func generateTLSConfig() *tls.Config {
  	mustNotFail := func(err error) {
  		if err != nil {
  			panic(err)
  		}
  	}

  	key, err := rsa.GenerateKey(rand.Reader, 1024)
  	mustNotFail(err)

  	template := x509.Certificate{SerialNumber: big.NewInt(1)}
  	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
  	mustNotFail(err)

  	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
  	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})

  	tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
  	mustNotFail(err)

  	return &tls.Config{
  		Certificates: []tls.Certificate{tlsCert},
  		NextProtos:   []string{"argotunnel"},
  	}
  }