probetest.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. /*
  2. Probe test server to check the reachability of Snowflake proxies from
  3. clients with symmetric NATs.
  4. The probe server receives an offer from a proxy, returns an answer, and then
  5. attempts to establish a datachannel connection to that proxy. The proxy will
  6. self-determine whether the connection opened successfully.
  7. */
  8. package main
  9. import (
  10. "crypto/tls"
  11. "flag"
  12. "fmt"
  13. "io"
  14. "io/ioutil"
  15. "log"
  16. "net/http"
  17. "os"
  18. "strings"
  19. "time"
  20. "git.torproject.org/pluggable-transports/snowflake.git/common/messages"
  21. "git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
  22. "git.torproject.org/pluggable-transports/snowflake.git/common/util"
  23. "github.com/pion/webrtc/v3"
  24. "golang.org/x/crypto/acme/autocert"
  25. )
  26. const (
  27. readLimit = 100000 //Maximum number of bytes to be read from an HTTP request
  28. dataChannelTimeout = 20 * time.Second //time after which we assume proxy data channel will not open
  29. stunUrl = "stun:stun.l.google.com:19302" //default STUN URL
  30. )
  31. // Create a PeerConnection from an SDP offer. Blocks until the gathering of ICE
  32. // candidates is complete and the answer is available in LocalDescription.
  33. func makePeerConnectionFromOffer(sdp *webrtc.SessionDescription,
  34. dataChan chan struct{}) (*webrtc.PeerConnection, error) {
  35. config := webrtc.Configuration{
  36. ICEServers: []webrtc.ICEServer{
  37. {
  38. URLs: []string{stunUrl},
  39. },
  40. },
  41. }
  42. pc, err := webrtc.NewPeerConnection(config)
  43. if err != nil {
  44. return nil, fmt.Errorf("accept: NewPeerConnection: %s", err)
  45. }
  46. pc.OnDataChannel(func(dc *webrtc.DataChannel) {
  47. dc.OnOpen(func() {
  48. close(dataChan)
  49. })
  50. dc.OnClose(func() {
  51. dc.Close()
  52. })
  53. })
  54. // As of v3.0.0, pion-webrtc uses trickle ICE by default.
  55. // We have to wait for candidate gathering to complete
  56. // before we send the offer
  57. done := webrtc.GatheringCompletePromise(pc)
  58. err = pc.SetRemoteDescription(*sdp)
  59. if err != nil {
  60. if inerr := pc.Close(); inerr != nil {
  61. log.Printf("unable to call pc.Close after pc.SetRemoteDescription with error: %v", inerr)
  62. }
  63. return nil, fmt.Errorf("accept: SetRemoteDescription: %s", err)
  64. }
  65. answer, err := pc.CreateAnswer(nil)
  66. if err != nil {
  67. if inerr := pc.Close(); inerr != nil {
  68. log.Printf("ICE gathering has generated an error when calling pc.Close: %v", inerr)
  69. }
  70. return nil, err
  71. }
  72. err = pc.SetLocalDescription(answer)
  73. if err != nil {
  74. if err = pc.Close(); err != nil {
  75. log.Printf("pc.Close after setting local description returned : %v", err)
  76. }
  77. return nil, err
  78. }
  79. // Wait for ICE candidate gathering to complete
  80. <-done
  81. return pc, nil
  82. }
  83. func probeHandler(w http.ResponseWriter, r *http.Request) {
  84. w.Header().Set("Access-Control-Allow-Origin", "*")
  85. resp, err := ioutil.ReadAll(http.MaxBytesReader(w, r.Body, readLimit))
  86. if nil != err {
  87. log.Println("Invalid data.")
  88. w.WriteHeader(http.StatusBadRequest)
  89. return
  90. }
  91. offer, _, err := messages.DecodePollResponse(resp)
  92. if err != nil {
  93. log.Printf("Error reading offer: %s", err.Error())
  94. w.WriteHeader(http.StatusBadRequest)
  95. return
  96. }
  97. if offer == "" {
  98. log.Printf("Error processing session description: %s", err.Error())
  99. w.WriteHeader(http.StatusBadRequest)
  100. return
  101. }
  102. sdp, err := util.DeserializeSessionDescription(offer)
  103. if err != nil {
  104. log.Printf("Error processing session description: %s", err.Error())
  105. w.WriteHeader(http.StatusBadRequest)
  106. return
  107. }
  108. dataChan := make(chan struct{})
  109. pc, err := makePeerConnectionFromOffer(sdp, dataChan)
  110. if err != nil {
  111. log.Printf("Error making WebRTC connection: %s", err)
  112. w.WriteHeader(http.StatusInternalServerError)
  113. return
  114. }
  115. sdp = &webrtc.SessionDescription{
  116. Type: pc.LocalDescription().Type,
  117. SDP: util.StripLocalAddresses(pc.LocalDescription().SDP),
  118. }
  119. answer, err := util.SerializeSessionDescription(sdp)
  120. if err != nil {
  121. log.Printf("Error making WebRTC connection: %s", err)
  122. w.WriteHeader(http.StatusInternalServerError)
  123. return
  124. }
  125. body, err := messages.EncodeAnswerRequest(answer, "stub-sid")
  126. if err != nil {
  127. log.Printf("Error making WebRTC connection: %s", err)
  128. w.WriteHeader(http.StatusInternalServerError)
  129. return
  130. }
  131. w.Write(body)
  132. // Set a timeout on peerconnection. If the connection state has not
  133. // advanced to PeerConnectionStateConnected in this time,
  134. // destroy the peer connection and return the token.
  135. go func() {
  136. timer := time.NewTimer(dataChannelTimeout)
  137. defer timer.Stop()
  138. select {
  139. case <-dataChan:
  140. case <-timer.C:
  141. }
  142. if err := pc.Close(); err != nil {
  143. log.Printf("Error calling pc.Close: %v", err)
  144. }
  145. }()
  146. return
  147. }
  148. func main() {
  149. var acmeEmail string
  150. var acmeHostnamesCommas string
  151. var acmeCertCacheDir string
  152. var addr string
  153. var disableTLS bool
  154. var certFilename, keyFilename string
  155. var unsafeLogging bool
  156. flag.StringVar(&acmeEmail, "acme-email", "", "optional contact email for Let's Encrypt notifications")
  157. flag.StringVar(&acmeHostnamesCommas, "acme-hostnames", "", "comma-separated hostnames for TLS certificate")
  158. flag.StringVar(&acmeCertCacheDir, "acme-cert-cache", "acme-cert-cache", "directory in which certificates should be cached")
  159. flag.StringVar(&certFilename, "cert", "", "TLS certificate file")
  160. flag.StringVar(&keyFilename, "key", "", "TLS private key file")
  161. flag.StringVar(&addr, "addr", ":8443", "address to listen on")
  162. flag.BoolVar(&disableTLS, "disable-tls", false, "don't use HTTPS")
  163. flag.BoolVar(&unsafeLogging, "unsafe-logging", false, "prevent logs from being scrubbed")
  164. flag.Parse()
  165. var logOutput io.Writer = os.Stderr
  166. if unsafeLogging {
  167. log.SetOutput(logOutput)
  168. } else {
  169. // Scrub log output just in case an address ends up there
  170. log.SetOutput(&safelog.LogScrubber{Output: logOutput})
  171. }
  172. log.SetFlags(log.LstdFlags | log.LUTC)
  173. http.HandleFunc("/probe", probeHandler)
  174. server := http.Server{
  175. Addr: addr,
  176. }
  177. var err error
  178. if acmeHostnamesCommas != "" {
  179. acmeHostnames := strings.Split(acmeHostnamesCommas, ",")
  180. log.Printf("ACME hostnames: %q", acmeHostnames)
  181. var cache autocert.Cache
  182. if err = os.MkdirAll(acmeCertCacheDir, 0700); err != nil {
  183. log.Printf("Warning: Couldn't create cache directory %q (reason: %s) so we're *not* using our certificate cache.", acmeCertCacheDir, err)
  184. } else {
  185. cache = autocert.DirCache(acmeCertCacheDir)
  186. }
  187. certManager := autocert.Manager{
  188. Cache: cache,
  189. Prompt: autocert.AcceptTOS,
  190. HostPolicy: autocert.HostWhitelist(acmeHostnames...),
  191. Email: acmeEmail,
  192. }
  193. // start certificate manager handler
  194. go func() {
  195. log.Printf("Starting HTTP-01 listener")
  196. log.Fatal(http.ListenAndServe(":80", certManager.HTTPHandler(nil)))
  197. }()
  198. server.TLSConfig = &tls.Config{GetCertificate: certManager.GetCertificate}
  199. err = server.ListenAndServeTLS("", "")
  200. } else if certFilename != "" && keyFilename != "" {
  201. err = server.ListenAndServeTLS(certFilename, keyFilename)
  202. } else if disableTLS {
  203. err = server.ListenAndServe()
  204. } else {
  205. log.Fatal("the --cert and --key, --acme-hostnames, or --disable-tls option is required")
  206. }
  207. if err != nil {
  208. log.Println(err)
  209. }
  210. }