123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258 |
- /*
- Probe test server to check the reachability of Snowflake proxies from
- clients with symmetric NATs.
- The probe server receives an offer from a proxy, returns an answer, and then
- attempts to establish a datachannel connection to that proxy. The proxy will
- self-determine whether the connection opened successfully.
- */
- package main
import (
	"crypto/tls"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/pion/transport/v2/stdnet"
	"github.com/pion/webrtc/v3"
	"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/ptutil/safelog"
	"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
	"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/util"
	"golang.org/x/crypto/acme/autocert"
)
// Fixed parameters of the probe server.
const (
	// readLimit is the maximum number of bytes read from an HTTP request
	// body; larger bodies are rejected by http.MaxBytesReader.
	readLimit = 100000

	// dataChannelTimeout is how long the server waits for the proxy's data
	// channel to open before it gives up and closes the peer connection.
	dataChannelTimeout = 20 * time.Second

	// defaultStunUrl is the STUN server used when the -stun flag is not set.
	defaultStunUrl = "stun:stun.l.google.com:19302"
)
// ProbeHandler adapts a request-handling function that also needs the STUN
// server URL into an http.Handler.
type ProbeHandler struct {
	// stunURL is passed as the first argument to handle on every request.
	stunURL string
	// handle performs the actual request processing.
	handle func(string, http.ResponseWriter, *http.Request)
}

// ServeHTTP satisfies http.Handler by delegating to ph.handle, supplying
// ph.stunURL alongside the response writer and request.
func (ph ProbeHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	callback := ph.handle
	callback(ph.stunURL, rw, req)
}
- // Create a PeerConnection from an SDP offer. Blocks until the gathering of ICE
- // candidates is complete and the answer is available in LocalDescription.
- func makePeerConnectionFromOffer(stunURL string, sdp *webrtc.SessionDescription,
- dataChan chan struct{}) (*webrtc.PeerConnection, error) {
- settingsEngine := webrtc.SettingEngine{}
- // Use the SetNet setting https://pkg.go.dev/github.com/pion/webrtc/v3#SettingEngine.SetNet
- // to functionally revert a new change in pion by silently ignoring
- // when net.Interfaces() fails, rather than throwing an error
- vnet, _ := stdnet.NewNet()
- settingsEngine.SetNet(vnet)
- api := webrtc.NewAPI(webrtc.WithSettingEngine(settingsEngine))
- config := webrtc.Configuration{
- ICEServers: []webrtc.ICEServer{
- {
- URLs: []string{stunURL},
- },
- },
- }
- pc, err := api.NewPeerConnection(config)
- if err != nil {
- return nil, fmt.Errorf("accept: NewPeerConnection: %s", err)
- }
- pc.OnDataChannel(func(dc *webrtc.DataChannel) {
- dc.OnOpen(func() {
- close(dataChan)
- })
- dc.OnClose(func() {
- dc.Close()
- })
- })
- // As of v3.0.0, pion-webrtc uses trickle ICE by default.
- // We have to wait for candidate gathering to complete
- // before we send the offer
- done := webrtc.GatheringCompletePromise(pc)
- err = pc.SetRemoteDescription(*sdp)
- if err != nil {
- if inerr := pc.Close(); inerr != nil {
- log.Printf("unable to call pc.Close after pc.SetRemoteDescription with error: %v", inerr)
- }
- return nil, fmt.Errorf("accept: SetRemoteDescription: %s", err)
- }
- answer, err := pc.CreateAnswer(nil)
- if err != nil {
- if inerr := pc.Close(); inerr != nil {
- log.Printf("ICE gathering has generated an error when calling pc.Close: %v", inerr)
- }
- return nil, err
- }
- err = pc.SetLocalDescription(answer)
- if err != nil {
- if err = pc.Close(); err != nil {
- log.Printf("pc.Close after setting local description returned : %v", err)
- }
- return nil, err
- }
- // Wait for ICE candidate gathering to complete
- <-done
- return pc, nil
- }
- func probeHandler(stunURL string, w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Access-Control-Allow-Origin", "*")
- resp, err := io.ReadAll(http.MaxBytesReader(w, r.Body, readLimit))
- if nil != err {
- log.Println("Invalid data.")
- w.WriteHeader(http.StatusBadRequest)
- return
- }
- offer, _, err := messages.DecodePollResponse(resp)
- if err != nil {
- log.Printf("Error reading offer: %s", err.Error())
- w.WriteHeader(http.StatusBadRequest)
- return
- }
- if offer == "" {
- log.Printf("Error processing session description: %s", err.Error())
- w.WriteHeader(http.StatusBadRequest)
- return
- }
- sdp, err := util.DeserializeSessionDescription(offer)
- if err != nil {
- log.Printf("Error processing session description: %s", err.Error())
- w.WriteHeader(http.StatusBadRequest)
- return
- }
- dataChan := make(chan struct{})
- pc, err := makePeerConnectionFromOffer(stunURL, sdp, dataChan)
- if err != nil {
- log.Printf("Error making WebRTC connection: %s", err)
- w.WriteHeader(http.StatusInternalServerError)
- return
- }
- sdp = &webrtc.SessionDescription{
- Type: pc.LocalDescription().Type,
- SDP: util.StripLocalAddresses(pc.LocalDescription().SDP),
- }
- answer, err := util.SerializeSessionDescription(sdp)
- if err != nil {
- log.Printf("Error making WebRTC connection: %s", err)
- w.WriteHeader(http.StatusInternalServerError)
- return
- }
- body, err := messages.EncodeAnswerRequest(answer, "stub-sid")
- if err != nil {
- log.Printf("Error making WebRTC connection: %s", err)
- w.WriteHeader(http.StatusInternalServerError)
- return
- }
- w.Write(body)
- // Set a timeout on peerconnection. If the connection state has not
- // advanced to PeerConnectionStateConnected in this time,
- // destroy the peer connection and return the token.
- go func() {
- timer := time.NewTimer(dataChannelTimeout)
- defer timer.Stop()
- select {
- case <-dataChan:
- case <-timer.C:
- }
- if err := pc.Close(); err != nil {
- log.Printf("Error calling pc.Close: %v", err)
- }
- }()
- return
- }
- func main() {
- var acmeEmail string
- var acmeHostnamesCommas string
- var acmeCertCacheDir string
- var addr string
- var disableTLS bool
- var certFilename, keyFilename string
- var unsafeLogging bool
- var stunURL string
- flag.StringVar(&acmeEmail, "acme-email", "", "optional contact email for Let's Encrypt notifications")
- flag.StringVar(&acmeHostnamesCommas, "acme-hostnames", "", "comma-separated hostnames for TLS certificate")
- flag.StringVar(&acmeCertCacheDir, "acme-cert-cache", "acme-cert-cache", "directory in which certificates should be cached")
- flag.StringVar(&certFilename, "cert", "", "TLS certificate file")
- flag.StringVar(&keyFilename, "key", "", "TLS private key file")
- flag.StringVar(&addr, "addr", ":8443", "address to listen on")
- flag.BoolVar(&disableTLS, "disable-tls", false, "don't use HTTPS")
- flag.BoolVar(&unsafeLogging, "unsafe-logging", false, "prevent logs from being scrubbed")
- flag.StringVar(&stunURL, "stun", defaultStunUrl, "STUN server to use for NAT traversal")
- flag.Parse()
- var logOutput io.Writer = os.Stderr
- if unsafeLogging {
- log.SetOutput(logOutput)
- } else {
- // Scrub log output just in case an address ends up there
- log.SetOutput(&safelog.LogScrubber{Output: logOutput})
- }
- log.SetFlags(log.LstdFlags | log.LUTC)
- http.Handle("/probe", ProbeHandler{stunURL, probeHandler})
- server := http.Server{
- Addr: addr,
- }
- var err error
- if acmeHostnamesCommas != "" {
- acmeHostnames := strings.Split(acmeHostnamesCommas, ",")
- log.Printf("ACME hostnames: %q", acmeHostnames)
- var cache autocert.Cache
- if err = os.MkdirAll(acmeCertCacheDir, 0700); err != nil {
- log.Printf("Warning: Couldn't create cache directory %q (reason: %s) so we're *not* using our certificate cache.", acmeCertCacheDir, err)
- } else {
- cache = autocert.DirCache(acmeCertCacheDir)
- }
- certManager := autocert.Manager{
- Cache: cache,
- Prompt: autocert.AcceptTOS,
- HostPolicy: autocert.HostWhitelist(acmeHostnames...),
- Email: acmeEmail,
- }
- // start certificate manager handler
- go func() {
- log.Printf("Starting HTTP-01 listener")
- log.Fatal(http.ListenAndServe(":80", certManager.HTTPHandler(nil)))
- }()
- server.TLSConfig = &tls.Config{GetCertificate: certManager.GetCertificate}
- err = server.ListenAndServeTLS("", "")
- } else if certFilename != "" && keyFilename != "" {
- err = server.ListenAndServeTLS(certFilename, keyFilename)
- } else if disableTLS {
- err = server.ListenAndServe()
- } else {
- log.Fatal("the --cert and --key, --acme-hostnames, or --disable-tls option is required")
- }
- if err != nil {
- log.Println(err)
- }
- }
|