conntracker.go 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. package tunnelstate
  2. import (
  3. "sync"
  4. "github.com/rs/zerolog"
  5. "github.com/cloudflare/cloudflared/connection"
  6. )
  7. type ConnTracker struct {
  8. sync.RWMutex
  9. // int is the connection Index
  10. connectionInfo map[uint8]ConnectionInfo
  11. log *zerolog.Logger
  12. }
// ConnectionInfo is the tracked state of a single edge connection.
type ConnectionInfo struct {
	// IsConnected is set to true by a Connected event and to false by a
	// Disconnected, Reconnecting, RegisteringTunnel or Unregistering event
	// (see ConnTracker.OnTunnelEvent).
	IsConnected bool
	// Protocol is copied from the Connected event's Protocol; the other
	// events leave it unchanged.
	Protocol connection.Protocol
}
  17. func NewConnTracker(log *zerolog.Logger) *ConnTracker {
  18. return &ConnTracker{
  19. connectionInfo: make(map[uint8]ConnectionInfo, 0),
  20. log: log,
  21. }
  22. }
  23. func MockedConnTracker(mocked map[uint8]ConnectionInfo) *ConnTracker {
  24. return &ConnTracker{
  25. connectionInfo: mocked,
  26. }
  27. }
  28. func (ct *ConnTracker) OnTunnelEvent(c connection.Event) {
  29. switch c.EventType {
  30. case connection.Connected:
  31. ct.Lock()
  32. ci := ConnectionInfo{
  33. IsConnected: true,
  34. Protocol: c.Protocol,
  35. }
  36. ct.connectionInfo[c.Index] = ci
  37. ct.Unlock()
  38. case connection.Disconnected, connection.Reconnecting, connection.RegisteringTunnel, connection.Unregistering:
  39. ct.Lock()
  40. ci := ct.connectionInfo[c.Index]
  41. ci.IsConnected = false
  42. ct.connectionInfo[c.Index] = ci
  43. ct.Unlock()
  44. default:
  45. ct.log.Error().Msgf("Unknown connection event case %v", c)
  46. }
  47. }
  48. func (ct *ConnTracker) CountActiveConns() uint {
  49. ct.RLock()
  50. defer ct.RUnlock()
  51. active := uint(0)
  52. for _, ci := range ct.connectionInfo {
  53. if ci.IsConnected {
  54. active++
  55. }
  56. }
  57. return active
  58. }
  59. // HasConnectedWith checks if we've ever had a successful connection to the edge
  60. // with said protocol.
  61. func (ct *ConnTracker) HasConnectedWith(protocol connection.Protocol) bool {
  62. ct.RLock()
  63. defer ct.RUnlock()
  64. for _, ci := range ct.connectionInfo {
  65. if ci.Protocol == protocol {
  66. return true
  67. }
  68. }
  69. return false
  70. }