probetest.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. /*
  2. Probe test server to check the reachability of Snowflake proxies from
  3. clients with symmetric NATs.
  4. The probe server receives an offer from a proxy, returns an answer, and then
  5. attempts to establish a datachannel connection to that proxy. The proxy will
  6. self-determine whether the connection opened successfully.
  7. */
  8. package main
import (
	"crypto/tls"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/pion/transport/v2/stdnet"
	"github.com/pion/webrtc/v3"
	"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/ptutil/safelog"
	"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
	"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/util"
	"golang.org/x/crypto/acme/autocert"
)
  26. const (
  27. readLimit = 100000 //Maximum number of bytes to be read from an HTTP request
  28. dataChannelTimeout = 20 * time.Second //time after which we assume proxy data channel will not open
  29. defaultStunUrl = "stun:stun.l.google.com:19302" //default STUN URL
  30. )
  31. type ProbeHandler struct {
  32. stunURL string
  33. handle func(string, http.ResponseWriter, *http.Request)
  34. }
  35. func (h ProbeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  36. h.handle(h.stunURL, w, r)
  37. }
  38. // Create a PeerConnection from an SDP offer. Blocks until the gathering of ICE
  39. // candidates is complete and the answer is available in LocalDescription.
  40. func makePeerConnectionFromOffer(stunURL string, sdp *webrtc.SessionDescription,
  41. dataChan chan struct{}) (*webrtc.PeerConnection, error) {
  42. settingsEngine := webrtc.SettingEngine{}
  43. // Use the SetNet setting https://pkg.go.dev/github.com/pion/webrtc/v3#SettingEngine.SetNet
  44. // to functionally revert a new change in pion by silently ignoring
  45. // when net.Interfaces() fails, rather than throwing an error
  46. vnet, _ := stdnet.NewNet()
  47. settingsEngine.SetNet(vnet)
  48. api := webrtc.NewAPI(webrtc.WithSettingEngine(settingsEngine))
  49. config := webrtc.Configuration{
  50. ICEServers: []webrtc.ICEServer{
  51. {
  52. URLs: []string{stunURL},
  53. },
  54. },
  55. }
  56. pc, err := api.NewPeerConnection(config)
  57. if err != nil {
  58. return nil, fmt.Errorf("accept: NewPeerConnection: %s", err)
  59. }
  60. pc.OnDataChannel(func(dc *webrtc.DataChannel) {
  61. dc.OnOpen(func() {
  62. close(dataChan)
  63. })
  64. dc.OnClose(func() {
  65. dc.Close()
  66. })
  67. })
  68. // As of v3.0.0, pion-webrtc uses trickle ICE by default.
  69. // We have to wait for candidate gathering to complete
  70. // before we send the offer
  71. done := webrtc.GatheringCompletePromise(pc)
  72. err = pc.SetRemoteDescription(*sdp)
  73. if err != nil {
  74. if inerr := pc.Close(); inerr != nil {
  75. log.Printf("unable to call pc.Close after pc.SetRemoteDescription with error: %v", inerr)
  76. }
  77. return nil, fmt.Errorf("accept: SetRemoteDescription: %s", err)
  78. }
  79. answer, err := pc.CreateAnswer(nil)
  80. if err != nil {
  81. if inerr := pc.Close(); inerr != nil {
  82. log.Printf("ICE gathering has generated an error when calling pc.Close: %v", inerr)
  83. }
  84. return nil, err
  85. }
  86. err = pc.SetLocalDescription(answer)
  87. if err != nil {
  88. if err = pc.Close(); err != nil {
  89. log.Printf("pc.Close after setting local description returned : %v", err)
  90. }
  91. return nil, err
  92. }
  93. // Wait for ICE candidate gathering to complete
  94. <-done
  95. return pc, nil
  96. }
  97. func probeHandler(stunURL string, w http.ResponseWriter, r *http.Request) {
  98. w.Header().Set("Access-Control-Allow-Origin", "*")
  99. resp, err := io.ReadAll(http.MaxBytesReader(w, r.Body, readLimit))
  100. if nil != err {
  101. log.Println("Invalid data.")
  102. w.WriteHeader(http.StatusBadRequest)
  103. return
  104. }
  105. offer, _, err := messages.DecodePollResponse(resp)
  106. if err != nil {
  107. log.Printf("Error reading offer: %s", err.Error())
  108. w.WriteHeader(http.StatusBadRequest)
  109. return
  110. }
  111. if offer == "" {
  112. log.Printf("Error processing session description: %s", err.Error())
  113. w.WriteHeader(http.StatusBadRequest)
  114. return
  115. }
  116. sdp, err := util.DeserializeSessionDescription(offer)
  117. if err != nil {
  118. log.Printf("Error processing session description: %s", err.Error())
  119. w.WriteHeader(http.StatusBadRequest)
  120. return
  121. }
  122. dataChan := make(chan struct{})
  123. pc, err := makePeerConnectionFromOffer(stunURL, sdp, dataChan)
  124. if err != nil {
  125. log.Printf("Error making WebRTC connection: %s", err)
  126. w.WriteHeader(http.StatusInternalServerError)
  127. return
  128. }
  129. sdp = &webrtc.SessionDescription{
  130. Type: pc.LocalDescription().Type,
  131. SDP: util.StripLocalAddresses(pc.LocalDescription().SDP),
  132. }
  133. answer, err := util.SerializeSessionDescription(sdp)
  134. if err != nil {
  135. log.Printf("Error making WebRTC connection: %s", err)
  136. w.WriteHeader(http.StatusInternalServerError)
  137. return
  138. }
  139. body, err := messages.EncodeAnswerRequest(answer, "stub-sid")
  140. if err != nil {
  141. log.Printf("Error making WebRTC connection: %s", err)
  142. w.WriteHeader(http.StatusInternalServerError)
  143. return
  144. }
  145. w.Write(body)
  146. // Set a timeout on peerconnection. If the connection state has not
  147. // advanced to PeerConnectionStateConnected in this time,
  148. // destroy the peer connection and return the token.
  149. go func() {
  150. timer := time.NewTimer(dataChannelTimeout)
  151. defer timer.Stop()
  152. select {
  153. case <-dataChan:
  154. case <-timer.C:
  155. }
  156. if err := pc.Close(); err != nil {
  157. log.Printf("Error calling pc.Close: %v", err)
  158. }
  159. }()
  160. return
  161. }
  162. func main() {
  163. var acmeEmail string
  164. var acmeHostnamesCommas string
  165. var acmeCertCacheDir string
  166. var addr string
  167. var disableTLS bool
  168. var certFilename, keyFilename string
  169. var unsafeLogging bool
  170. var stunURL string
  171. flag.StringVar(&acmeEmail, "acme-email", "", "optional contact email for Let's Encrypt notifications")
  172. flag.StringVar(&acmeHostnamesCommas, "acme-hostnames", "", "comma-separated hostnames for TLS certificate")
  173. flag.StringVar(&acmeCertCacheDir, "acme-cert-cache", "acme-cert-cache", "directory in which certificates should be cached")
  174. flag.StringVar(&certFilename, "cert", "", "TLS certificate file")
  175. flag.StringVar(&keyFilename, "key", "", "TLS private key file")
  176. flag.StringVar(&addr, "addr", ":8443", "address to listen on")
  177. flag.BoolVar(&disableTLS, "disable-tls", false, "don't use HTTPS")
  178. flag.BoolVar(&unsafeLogging, "unsafe-logging", false, "prevent logs from being scrubbed")
  179. flag.StringVar(&stunURL, "stun", defaultStunUrl, "STUN server to use for NAT traversal")
  180. flag.Parse()
  181. var logOutput io.Writer = os.Stderr
  182. if unsafeLogging {
  183. log.SetOutput(logOutput)
  184. } else {
  185. // Scrub log output just in case an address ends up there
  186. log.SetOutput(&safelog.LogScrubber{Output: logOutput})
  187. }
  188. log.SetFlags(log.LstdFlags | log.LUTC)
  189. http.Handle("/probe", ProbeHandler{stunURL, probeHandler})
  190. server := http.Server{
  191. Addr: addr,
  192. }
  193. var err error
  194. if acmeHostnamesCommas != "" {
  195. acmeHostnames := strings.Split(acmeHostnamesCommas, ",")
  196. log.Printf("ACME hostnames: %q", acmeHostnames)
  197. var cache autocert.Cache
  198. if err = os.MkdirAll(acmeCertCacheDir, 0700); err != nil {
  199. log.Printf("Warning: Couldn't create cache directory %q (reason: %s) so we're *not* using our certificate cache.", acmeCertCacheDir, err)
  200. } else {
  201. cache = autocert.DirCache(acmeCertCacheDir)
  202. }
  203. certManager := autocert.Manager{
  204. Cache: cache,
  205. Prompt: autocert.AcceptTOS,
  206. HostPolicy: autocert.HostWhitelist(acmeHostnames...),
  207. Email: acmeEmail,
  208. }
  209. // start certificate manager handler
  210. go func() {
  211. log.Printf("Starting HTTP-01 listener")
  212. log.Fatal(http.ListenAndServe(":80", certManager.HTTPHandler(nil)))
  213. }()
  214. server.TLSConfig = &tls.Config{GetCertificate: certManager.GetCertificate}
  215. err = server.ListenAndServeTLS("", "")
  216. } else if certFilename != "" && keyFilename != "" {
  217. err = server.ListenAndServeTLS(certFilename, keyFilename)
  218. } else if disableTLS {
  219. err = server.ListenAndServe()
  220. } else {
  221. log.Fatal("the --cert and --key, --acme-hostnames, or --disable-tls option is required")
  222. }
  223. if err != nil {
  224. log.Println(err)
  225. }
  226. }