123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259 |
- package quic
- import (
- "context"
- "fmt"
- "github.com/pkg/errors"
- "github.com/quic-go/quic-go"
- "github.com/rs/zerolog"
- "github.com/cloudflare/cloudflared/packet"
- "github.com/cloudflare/cloudflared/tracing"
- )
// DatagramV2Type is the one-byte tag stored as the last byte of a version 2
// datagram. The receiver reads it to decide how the preceding bytes are parsed.
type DatagramV2Type byte

// The numeric values below travel on the wire, so they are spelled out
// explicitly rather than derived from declaration order.
const (
	// DatagramTypeUDP tags a datagram that carries a UDP payload.
	DatagramTypeUDP DatagramV2Type = 0
	// DatagramTypeIP tags a datagram that carries a full IP packet.
	DatagramTypeIP DatagramV2Type = 1
	// DatagramTypeIPWithTrace tags a full IP packet that also carries a tracing ID.
	DatagramTypeIPWithTrace DatagramV2Type = 2
	// DatagramTypeTracingSpan tags a datagram that carries tracing spans in protobuf format.
	DatagramTypeTracingSpan DatagramV2Type = 3
)
- type Packet interface {
- Type() DatagramV2Type
- Payload() []byte
- Metadata() []byte
- }
// typeIDLen is the number of bytes the datagram type tag occupies at the end
// of a datagram.
const typeIDLen = 1

// packetChanCapacity is the buffer size of the channel that holds demuxed
// packets. It is kept the same as the sessionDemuxChan capacity.
const packetChanCapacity = 128
- func SuffixType(b []byte, datagramType DatagramV2Type) ([]byte, error) {
- if len(b)+typeIDLen > MaxDatagramFrameSize {
- return nil, fmt.Errorf("datagram size %d exceeds max frame size %d", len(b), MaxDatagramFrameSize)
- }
- b = append(b, byte(datagramType))
- return b, nil
- }
- // Maximum application payload to send to / receive from QUIC datagram frame
- func (dm *DatagramMuxerV2) mtu() int {
- return maxDatagramPayloadSize
- }
- type DatagramMuxerV2 struct {
- session quic.Connection
- logger *zerolog.Logger
- sessionDemuxChan chan<- *packet.Session
- packetDemuxChan chan Packet
- }
- func NewDatagramMuxerV2(
- quicSession quic.Connection,
- log *zerolog.Logger,
- sessionDemuxChan chan<- *packet.Session,
- ) *DatagramMuxerV2 {
- logger := log.With().Uint8("datagramVersion", 2).Logger()
- return &DatagramMuxerV2{
- session: quicSession,
- logger: &logger,
- sessionDemuxChan: sessionDemuxChan,
- packetDemuxChan: make(chan Packet, packetChanCapacity),
- }
- }
- // SendToSession suffix the session ID and datagram version to the payload so the other end of the QUIC connection can
- // demultiplex the payload from multiple datagram sessions
- func (dm *DatagramMuxerV2) SendToSession(session *packet.Session) error {
- if len(session.Payload) > dm.mtu() {
- packetTooBigDropped.Inc()
- return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(session.Payload), dm.mtu())
- }
- msgWithID, err := SuffixSessionID(session.ID, session.Payload)
- if err != nil {
- return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped")
- }
- msgWithIDAndType, err := SuffixType(msgWithID, DatagramTypeUDP)
- if err != nil {
- return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped")
- }
- if err := dm.session.SendDatagram(msgWithIDAndType); err != nil {
- return errors.Wrap(err, "Failed to send datagram back to edge")
- }
- return nil
- }
- // SendPacket sends a packet with datagram version in the suffix. If ctx is a TracedContext, it adds the tracing
- // context between payload and datagram version.
- // The other end of the QUIC connection can demultiplex by parsing the payload as IP and look at the source and destination.
- func (dm *DatagramMuxerV2) SendPacket(pk Packet) error {
- payloadWithMetadata, err := suffixMetadata(pk.Payload(), pk.Metadata())
- if err != nil {
- return err
- }
- payloadWithMetadataAndType, err := SuffixType(payloadWithMetadata, pk.Type())
- if err != nil {
- return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped")
- }
- if err := dm.session.SendDatagram(payloadWithMetadataAndType); err != nil {
- return errors.Wrap(err, "Failed to send datagram back to edge")
- }
- return nil
- }
- // Demux reads datagrams from the QUIC connection and demuxes depending on whether it's a session or packet
- func (dm *DatagramMuxerV2) ServeReceive(ctx context.Context) error {
- for {
- msg, err := dm.session.ReceiveDatagram(ctx)
- if err != nil {
- return err
- }
- if err := dm.demux(ctx, msg); err != nil {
- dm.logger.Error().Err(err).Msg("Failed to demux datagram")
- if err == context.Canceled {
- return err
- }
- }
- }
- }
- func (dm *DatagramMuxerV2) ReceivePacket(ctx context.Context) (pk Packet, err error) {
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- case pk := <-dm.packetDemuxChan:
- return pk, nil
- }
- }
- func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error {
- if len(msgWithType) < typeIDLen {
- return fmt.Errorf("QUIC datagram should have at least %d byte", typeIDLen)
- }
- msgType := DatagramV2Type(msgWithType[len(msgWithType)-typeIDLen])
- msg := msgWithType[0 : len(msgWithType)-typeIDLen]
- switch msgType {
- case DatagramTypeUDP:
- return dm.handleSession(ctx, msg)
- default:
- return dm.handlePacket(ctx, msg, msgType)
- }
- }
- func (dm *DatagramMuxerV2) handleSession(ctx context.Context, session []byte) error {
- sessionID, payload, err := extractSessionID(session)
- if err != nil {
- return err
- }
- sessionDatagram := packet.Session{
- ID: sessionID,
- Payload: payload,
- }
- select {
- case dm.sessionDemuxChan <- &sessionDatagram:
- return nil
- case <-ctx.Done():
- return ctx.Err()
- }
- }
- func (dm *DatagramMuxerV2) handlePacket(ctx context.Context, pk []byte, msgType DatagramV2Type) error {
- var demuxedPacket Packet
- switch msgType {
- case DatagramTypeIP:
- demuxedPacket = RawPacket(packet.RawPacket{Data: pk})
- case DatagramTypeIPWithTrace:
- tracingIdentity, payload, err := extractTracingIdentity(pk)
- if err != nil {
- return err
- }
- demuxedPacket = &TracedPacket{
- Packet: packet.RawPacket{Data: payload},
- TracingIdentity: tracingIdentity,
- }
- case DatagramTypeTracingSpan:
- tracingIdentity, spans, err := extractTracingIdentity(pk)
- if err != nil {
- return err
- }
- demuxedPacket = &TracingSpanPacket{
- Spans: spans,
- TracingIdentity: tracingIdentity,
- }
- default:
- return fmt.Errorf("Unexpected datagram type %d", msgType)
- }
- select {
- case <-ctx.Done():
- return ctx.Err()
- case dm.packetDemuxChan <- demuxedPacket:
- return nil
- }
- }
- func extractTracingIdentity(pk []byte) (tracingIdentity []byte, payload []byte, err error) {
- if len(pk) < tracing.IdentityLength {
- return nil, nil, fmt.Errorf("packet with tracing context should have at least %d bytes, got %v", tracing.IdentityLength, pk)
- }
- tracingIdentity = pk[len(pk)-tracing.IdentityLength:]
- payload = pk[:len(pk)-tracing.IdentityLength]
- return tracingIdentity, payload, nil
- }
- type RawPacket packet.RawPacket
- func (rw RawPacket) Type() DatagramV2Type {
- return DatagramTypeIP
- }
- func (rw RawPacket) Payload() []byte {
- return rw.Data
- }
- func (rw RawPacket) Metadata() []byte {
- return []byte{}
- }
- type TracedPacket struct {
- Packet packet.RawPacket
- TracingIdentity []byte
- }
- func (tp *TracedPacket) Type() DatagramV2Type {
- return DatagramTypeIPWithTrace
- }
- func (tp *TracedPacket) Payload() []byte {
- return tp.Packet.Data
- }
- func (tp *TracedPacket) Metadata() []byte {
- return tp.TracingIdentity
- }
- type TracingSpanPacket struct {
- Spans []byte
- TracingIdentity []byte
- }
- func (tsp *TracingSpanPacket) Type() DatagramV2Type {
- return DatagramTypeTracingSpan
- }
- func (tsp *TracingSpanPacket) Payload() []byte {
- return tsp.Spans
- }
- func (tsp *TracingSpanPacket) Metadata() []byte {
- return tsp.TracingIdentity
- }
|